Compare commits: `perf-build...fix/manyli` (147 commits)
Commit SHAs (147):

971653fa1a, 02672c040d, f55108feda, 74d485c908, d1fefb6378, 732384f4f8, ae38e10d1b,
ca0fd88934, 3c8d32f156, b8ff00fc6a, 3c836766f8, b4a1dfb9c7, a4d66e95d8, cf58b3e31b,
e9c2ca7936, dab154a77b, 13413dfae5, 0543cc9816, fb53ed9a0e, 015f43733a, 2957c8bf5a,
a73194c3f6, 5eb893c62b, d91ce2e94d, 5c2ff8a641, d4f474c9b7, 170f7644e9, cd8b970eff,
52153bbb69, e1ae087207, 48c5e12ac1, f8b5c97190, d038c81b8b, 29cbbbd0d6, 179f30bc36,
c4a0a68581, 5c836ad08e, 673fd9b7cd, 84b24b233d, 499cdd7822, 800d4cf111, b6d43f5fd9,
3603cd5034, 6df7893173, e64b599276, 2dd59c4ba1, 166986d5e6, a6aec68f32, ed27a127d5,
d8b4ea7564, f0a2ef96b4, 7d73c2c803, e8d2ecab03, 32a374d094, d45c013806, 9000a7083d,
8307555d54, 20f2aece08, 43eb4f9a1d, 5461b71d8c, 374db0ebb8, cea1f6f87c, 6c0e39372b,
2bec67d2b6, 133e715832, 95cf2f16e2, 47a4c153eb, faf5ae3533, a44dccecac, 9cf9358b9c,
de252fef31, 9076bc27b8, 50686c0819, 1614203786, 3d4c75a56c, 2684ee71dc, 1d321953ba,
b3cb251369, 0a17d2c9d8, e3defbca84, e407f63977, 7add391b2c, efd6373b32, d502fa24b0,
258a9a5c7f, 5d41ac6115, 2a0fdb49b8, 9d1b7231b6, ed3095b478, 88eca75917, 42de27e16a,
c083bda5b7, e86da38726, 99076e38bc, 9698c1a02c, 851f0f04c3, ae16d9d888, 6e1af2eb0c,
7695dd0d50, c2065473ad, 5f3870564d, c214b2e33e, 2420c5fd35, f48f526f0a, 5dd74982ba,
e07aaf52a7, 30e5f12616, 594427bf87, a97d3ada1c, 6217bb5638, 2760e99e18, 0544f96b79,
2ebb29de65, 43762d44c7, cdaf0c98be, aa9a14a917, 9efcc6d95c, f3f5d91207, 6070160959,
43155d2811, d3f85678ec, 2a96d05b21, 851e888535, 90120d4dff, 8513471573, 71e5f1774c,
870a443446, cefaa2a4cc, ab72a2ab9d, 046d457d22, 7fd0a30fee, c2f35c8e73, 573313f0b6,
f7af6805fa, 966de3a399, 8a75829f3a, 0f7e34b9e2, be0322b616, 232a525a62, 587ce65cf6,
ccf6c8bfd7, c112956d2d, b3970793cf, 727724990e, 530f6e4af5, 2f224f5793, 1b6272ce0e
**`.github/workflows/build-and-publish.yml`** (new file, 11 lines)

```yaml
name: CI

on:
  push:
    branches: [ main ]
  pull_request:
    branches: [ main ]

jobs:
  build:
    uses: ./.github/workflows/build-reusable.yml
```
**`.github/workflows/build-cibuildwheel.yml`** (new file, 171 lines)

```yaml
name: Build with cibuildwheel

on:
  workflow_call:
    inputs:
      ref:
        description: 'Git ref to build'
        required: false
        type: string
        default: ''

jobs:
  build_wheels:
    name: Build wheels on ${{ matrix.os }}
    runs-on: ${{ matrix.os }}
    strategy:
      fail-fast: false
      matrix:
        os: [ubuntu-latest] # Focus on Linux/manylinux for Colab compatibility

    steps:
      - uses: actions/checkout@v4
        with:
          ref: ${{ inputs.ref }}
          submodules: recursive

      - name: Setup Python
        uses: actions/setup-python@v5
        with:
          python-version: '3.11'

      # Build pure Python packages separately
      - name: Build pure Python packages (leann-core, leann)
        if: matrix.os == 'ubuntu-latest' # Only build once
        run: |
          python -m pip install --upgrade pip build
          python -m build packages/leann-core --outdir wheelhouse/
          python -m build packages/leann --outdir wheelhouse/

      - name: Build leann-backend-hnsw wheels
        uses: pypa/cibuildwheel@v2.20.0
        with:
          package-dir: packages/leann-backend-hnsw
          output-dir: wheelhouse
        env:
          CIBW_BUILD: cp39-* cp310-* cp311-* cp312-* cp313-*
          CIBW_SKIP: "*-win32 *-manylinux_i686 pp* *musllinux*"

          # Use manylinux_2_35 for Colab compatibility with modern features
          CIBW_MANYLINUX_X86_64_IMAGE: manylinux_2_35
          CIBW_MANYLINUX_AARCH64_IMAGE: manylinux_2_35

          # Linux dependencies - using dnf for manylinux_2_35 (based on AlmaLinux 9)
          CIBW_BEFORE_ALL_LINUX: |
            dnf install -y epel-release
            dnf install -y gcc-c++ boost-devel zeromq-devel openblas-devel cmake python3-devel

          # Install numpy before building
          CIBW_BEFORE_BUILD: |
            pip install numpy
            pip install --upgrade pip setuptools wheel

          CIBW_BEFORE_BUILD_LINUX: |
            pip install numpy
            pip install --upgrade pip setuptools wheel swig

          CIBW_BEFORE_ALL_MACOS: |
            brew install boost zeromq openblas cmake libomp

          # Pre-install test dependencies to avoid compilation
          CIBW_BEFORE_TEST: |
            pip install --only-binary :all: "pyzmq>=23.0.0"

          # Test command to verify the wheel works
          CIBW_TEST_COMMAND: |
            python -c "import leann_backend_hnsw; print('HNSW backend imported successfully')"

          # Skip problematic configurations
          CIBW_TEST_SKIP: "*-macosx_arm64" # Skip ARM64 tests on GitHub Actions

          # Test dependencies
          CIBW_TEST_REQUIRES: "pytest numpy"

          # Environment variables for build
          CIBW_ENVIRONMENT: |
            CMAKE_BUILD_PARALLEL_LEVEL=8
            Python_FIND_VIRTUALENV=ONLY
            Python3_FIND_VIRTUALENV=ONLY

          # Linux-specific environment variables
          CIBW_ENVIRONMENT_LINUX: |
            CMAKE_BUILD_PARALLEL_LEVEL=8

          # macOS-specific environment variables
          CIBW_ENVIRONMENT_MACOS: |
            CMAKE_BUILD_PARALLEL_LEVEL=8
            MACOSX_DEPLOYMENT_TARGET=11.0
            CMAKE_OSX_DEPLOYMENT_TARGET=11.0
            Python_FIND_VIRTUALENV=ONLY
            Python3_FIND_VIRTUALENV=ONLY

      - name: Build leann-backend-diskann wheels
        uses: pypa/cibuildwheel@v2.20.0
        with:
          package-dir: packages/leann-backend-diskann
          output-dir: wheelhouse
        env:
          CIBW_BUILD: cp39-* cp310-* cp311-* cp312-* cp313-*
          CIBW_SKIP: "*-win32 *-manylinux_i686 pp* *musllinux*"

          CIBW_MANYLINUX_X86_64_IMAGE: manylinux_2_35
          CIBW_MANYLINUX_AARCH64_IMAGE: manylinux_2_35

          CIBW_BEFORE_ALL_LINUX: |
            dnf install -y epel-release
            dnf install -y gcc-c++ boost-devel zeromq-devel openblas-devel cmake python3-devel

          # Install numpy before building
          CIBW_BEFORE_BUILD: |
            pip install numpy
            pip install --upgrade pip setuptools wheel

          CIBW_BEFORE_BUILD_LINUX: |
            pip install numpy
            pip install --upgrade pip setuptools wheel swig

          CIBW_BEFORE_ALL_MACOS: |
            brew install boost zeromq openblas cmake libomp

          # Pre-install test dependencies to avoid compilation
          CIBW_BEFORE_TEST: |
            pip install --only-binary :all: "pyzmq>=23.0.0"

          # Test command to verify the wheel works
          CIBW_TEST_COMMAND: |
            python -c "import leann_backend_diskann; print('DiskANN backend imported successfully')"

          # Skip problematic configurations
          CIBW_TEST_SKIP: "*-macosx_arm64" # Skip ARM64 tests on GitHub Actions

          # Test dependencies - avoid pyzmq due to manylinux2014 compatibility issues
          CIBW_TEST_REQUIRES: "pytest numpy"

          CIBW_ENVIRONMENT: |
            CMAKE_BUILD_PARALLEL_LEVEL=8
            Python_FIND_VIRTUALENV=ONLY
            Python3_FIND_VIRTUALENV=ONLY

          # Linux-specific environment variables
          CIBW_ENVIRONMENT_LINUX: |
            CMAKE_BUILD_PARALLEL_LEVEL=8
            CMAKE_PREFIX_PATH=$VIRTUAL_ENV
            Python_FIND_VIRTUALENV=ONLY
            Python3_FIND_VIRTUALENV=ONLY
            Python_FIND_STRATEGY=LOCATION
            Python3_FIND_STRATEGY=LOCATION
            Python_EXECUTABLE=$VIRTUAL_ENV/bin/python
            Python3_EXECUTABLE=$VIRTUAL_ENV/bin/python

          # macOS-specific environment variables
          CIBW_ENVIRONMENT_MACOS: |
            CMAKE_BUILD_PARALLEL_LEVEL=8
            MACOSX_DEPLOYMENT_TARGET=11.0
            CMAKE_OSX_DEPLOYMENT_TARGET=11.0
            Python_FIND_VIRTUALENV=ONLY
            Python3_FIND_VIRTUALENV=ONLY

      - uses: actions/upload-artifact@v4
        with:
          name: wheels-${{ matrix.os }}
          path: ./wheelhouse/*.whl
```
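The `CIBW_TEST_COMMAND` entries above only check that each compiled backend imports. A minimal local equivalent, assuming the wheels from `wheelhouse/` have already been installed into the current environment (the file name `smoke_test_backends.py` is illustrative, and `pytest` is an assumed dev dependency), might look like this:

```python
# smoke_test_backends.py: import checks mirroring the CIBW_TEST_COMMAND entries above.
import importlib

import pytest


@pytest.mark.parametrize(
    "module_name",
    ["leann_backend_hnsw", "leann_backend_diskann"],
)
def test_backend_imports(module_name):
    # A wheel is considered usable here if its compiled extension imports cleanly.
    assert importlib.import_module(module_name) is not None
```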
**`.github/workflows/build-reusable.yml`** (new file, 313 lines)

```yaml
name: Reusable Build

on:
  workflow_call:
    inputs:
      ref:
        description: 'Git ref to build'
        required: false
        type: string
        default: ''

jobs:
  build:
    name: Build ${{ matrix.os }} Python ${{ matrix.python }}
    strategy:
      fail-fast: false
      matrix:
        include:
          - os: ubuntu-latest
            python: '3.9'
            container: 'quay.io/pypa/manylinux2014_x86_64'
          - os: ubuntu-latest
            python: '3.10'
            container: 'quay.io/pypa/manylinux2014_x86_64'
          - os: ubuntu-latest
            python: '3.11'
            container: 'quay.io/pypa/manylinux2014_x86_64'
          - os: ubuntu-latest
            python: '3.12'
            container: 'quay.io/pypa/manylinux2014_x86_64'
          - os: ubuntu-latest
            python: '3.13'
            container: 'quay.io/pypa/manylinux2014_x86_64'
          - os: macos-latest
            python: '3.9'
            container: ''
          - os: macos-latest
            python: '3.10'
            container: ''
          - os: macos-latest
            python: '3.11'
            container: ''
          - os: macos-latest
            python: '3.12'
            container: ''
          - os: macos-latest
            python: '3.13'
            container: ''
    runs-on: ${{ matrix.os }}
    container: ${{ matrix.container }}

    steps:
      # For manylinux2014 compatibility, we'll handle checkout differently
      - uses: actions/checkout@v4
        if: matrix.container == ''
        with:
          ref: ${{ inputs.ref }}
          submodules: recursive

      # Manual checkout for containers to avoid Node.js compatibility issues
      - name: Manual checkout in container
        if: matrix.container != ''
        run: |
          # Install git if not available
          yum install -y git || true

          # Configure git to handle the directory ownership issue
          git config --global --add safe.directory ${GITHUB_WORKSPACE}
          git config --global --add safe.directory /__w/LEANN/LEANN
          git config --global --add safe.directory /github/workspace
          git config --global --add safe.directory $(pwd)

          # Clone the repository manually in the container
          git init
          git remote add origin https://github.com/${GITHUB_REPOSITORY}.git

          # Fetch the appropriate ref
          if [ -n "${{ inputs.ref }}" ]; then
            git fetch --depth=1 origin ${{ inputs.ref }}
          else
            git fetch --depth=1 origin ${GITHUB_SHA}
          fi
          git checkout FETCH_HEAD

          # Initialize submodules
          git submodule update --init --recursive

      - name: Setup Python (macOS and regular Ubuntu)
        if: matrix.container == ''
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python }}

      - name: Setup Python (manylinux container)
        if: matrix.container != ''
        run: |
          # Use the pre-installed Python version in manylinux container
          # Convert Python version format (3.9 -> 39, 3.10 -> 310, etc.)
          PY_VER=$(echo "${{ matrix.python }}" | sed 's/\.//g')
          /opt/python/cp${PY_VER}-*/bin/python -m pip install --upgrade pip
          # Create symlinks for convenience
          ln -sf /opt/python/cp${PY_VER}-*/bin/python /usr/local/bin/python
          ln -sf /opt/python/cp${PY_VER}-*/bin/pip /usr/local/bin/pip

      - name: Install uv (macOS and regular Ubuntu)
        if: matrix.container == ''
        uses: astral-sh/setup-uv@v4

      - name: Install uv (manylinux container)
        if: matrix.container != ''
        run: |
          curl -LsSf https://astral.sh/uv/install.sh | sh
          echo "$HOME/.cargo/bin" >> $GITHUB_PATH

      - name: Install system dependencies (Ubuntu - regular)
        if: runner.os == 'Linux' && matrix.container == ''
        run: |
          sudo apt-get update
          sudo apt-get install -y libomp-dev libboost-all-dev protobuf-compiler libzmq3-dev \
            pkg-config libopenblas-dev patchelf libabsl-dev libaio-dev libprotobuf-dev

          # Install Intel MKL for DiskANN
          wget -q https://registrationcenter-download.intel.com/akdlm/IRC_NAS/79153e0f-74d7-45af-b8c2-258941adf58a/intel-onemkl-2025.0.0.940.sh
          sudo sh intel-onemkl-2025.0.0.940.sh -a --components intel.oneapi.lin.mkl.devel --action install --eula accept -s
          source /opt/intel/oneapi/setvars.sh
          echo "MKLROOT=/opt/intel/oneapi/mkl/latest" >> $GITHUB_ENV
          echo "LD_LIBRARY_PATH=/opt/intel/oneapi/mkl/latest/lib/intel64:$LD_LIBRARY_PATH" >> $GITHUB_ENV

      - name: Install system dependencies (manylinux container)
        if: runner.os == 'Linux' && matrix.container != ''
        run: |
          # manylinux2014 uses yum instead of apt
          # Update yum cache first
          yum clean all
          yum makecache

          # Install EPEL repository
          yum install -y epel-release || true

          # Update cache again after EPEL
          yum makecache || true

          # Install development packages
          # Note: Some packages might have different names in CentOS 7
          yum install -y \
            gcc-c++ \
            boost-devel \
            protobuf-compiler \
            protobuf-devel \
            zeromq-devel \
            pkgconfig \
            openblas-devel \
            cmake || {
            echo "Some packages failed to install, trying alternatives..."
            # Try alternative package names
            yum install -y libzmq3-devel || true
            yum install -y libzmq-devel || true
          }

          # Install optional packages that might not be available
          yum install -y libaio-devel || echo "libaio-devel not available, continuing..."

          # Verify zmq installation and create pkg-config file if needed
          if [ ! -f /usr/lib64/pkgconfig/libzmq.pc ] && [ ! -f /usr/lib/pkgconfig/libzmq.pc ]; then
            echo "Creating libzmq.pc file..."
            mkdir -p /usr/lib64/pkgconfig
            cat > /usr/lib64/pkgconfig/libzmq.pc << 'EOF'
          prefix=/usr
          exec_prefix=${prefix}
          libdir=${exec_prefix}/lib64
          includedir=${prefix}/include

          Name: libzmq
          Description: ZeroMQ library
          Version: 4.1.4
          Libs: -L${libdir} -lzmq
          Cflags: -I${includedir}
          EOF
          fi

          # Update PKG_CONFIG_PATH
          echo "PKG_CONFIG_PATH=/usr/lib64/pkgconfig:/usr/lib/pkgconfig:$PKG_CONFIG_PATH" >> $GITHUB_ENV

          # Build tools are pre-installed in manylinux
          # MKL is more complex in container, skip for now and use OpenBLAS

      - name: Install system dependencies (macOS)
        if: runner.os == 'macOS'
        run: |
          brew install llvm libomp boost protobuf zeromq

      - name: Install build dependencies
        run: |
          if [[ -n "${{ matrix.container }}" ]]; then
            # In manylinux container, use regular pip
            pip install scikit-build-core numpy swig Cython pybind11 auditwheel
          else
            # Regular environment, use uv
            uv pip install --system scikit-build-core numpy swig Cython pybind11
            if [[ "$RUNNER_OS" == "Linux" ]]; then
              uv pip install --system auditwheel
            else
              uv pip install --system delocate
            fi
          fi

      - name: Build packages
        run: |
          # Choose build command based on environment
          if [[ -n "${{ matrix.container }}" ]]; then
            BUILD_CMD="pip wheel . --no-deps -w dist"
          else
            BUILD_CMD="uv build --wheel --python python"
          fi

          # Build core (platform independent)
          if [ "${{ matrix.os }}" == "ubuntu-latest" ]; then
            cd packages/leann-core
            if [[ -n "${{ matrix.container }}" ]]; then
              pip wheel . --no-deps -w dist
            else
              uv build
            fi
            cd ../..
          fi

          # Build HNSW backend
          cd packages/leann-backend-hnsw
          if [ "${{ matrix.os }}" == "macos-latest" ]; then
            CC=$(brew --prefix llvm)/bin/clang CXX=$(brew --prefix llvm)/bin/clang++ $BUILD_CMD
          else
            eval $BUILD_CMD
          fi
          cd ../..

          # Build DiskANN backend
          cd packages/leann-backend-diskann
          if [ "${{ matrix.os }}" == "macos-latest" ]; then
            CC=$(brew --prefix llvm)/bin/clang CXX=$(brew --prefix llvm)/bin/clang++ $BUILD_CMD
          else
            eval $BUILD_CMD
          fi
          cd ../..

          # Build meta package (platform independent)
          if [ "${{ matrix.os }}" == "ubuntu-latest" ]; then
            cd packages/leann
            if [[ -n "${{ matrix.container }}" ]]; then
              pip wheel . --no-deps -w dist
            else
              uv build
            fi
            cd ../..
          fi

      - name: Repair wheels (Linux)
        if: runner.os == 'Linux'
        run: |
          # Repair HNSW wheel
          cd packages/leann-backend-hnsw
          if [ -d dist ]; then
            # Show what platform auditwheel will use
            auditwheel show dist/*.whl || true
            # Let auditwheel auto-detect the appropriate manylinux tag
            auditwheel repair dist/*.whl -w dist_repaired
            rm -rf dist
            mv dist_repaired dist
          fi
          cd ../..

          # Repair DiskANN wheel
          cd packages/leann-backend-diskann
          if [ -d dist ]; then
            # Show what platform auditwheel will use
            auditwheel show dist/*.whl || true
            # Let auditwheel auto-detect the appropriate manylinux tag
            auditwheel repair dist/*.whl -w dist_repaired
            rm -rf dist
            mv dist_repaired dist
          fi
          cd ../..

      - name: Repair wheels (macOS)
        if: runner.os == 'macOS'
        run: |
          # Repair HNSW wheel
          cd packages/leann-backend-hnsw
          if [ -d dist ]; then
            delocate-wheel -w dist_repaired -v dist/*.whl
            rm -rf dist
            mv dist_repaired dist
          fi
          cd ../..

          # Repair DiskANN wheel
          cd packages/leann-backend-diskann
          if [ -d dist ]; then
            delocate-wheel -w dist_repaired -v dist/*.whl
            rm -rf dist
            mv dist_repaired dist
          fi
          cd ../..

      - name: List built packages
        run: |
          echo "📦 Built packages:"
          find packages/*/dist -name "*.whl" -o -name "*.tar.gz" | sort

      - name: Upload artifacts
        uses: actions/upload-artifact@v4
        with:
          name: packages-${{ matrix.os }}-py${{ matrix.python }}${{ matrix.container && '-manylinux' || '' }}
          path: packages/*/dist/
```
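As a rough local counterpart to the `List built packages` step, a short script can enumerate the wheels under `packages/*/dist` and split out their filename tags. This is a sketch, not part of the repository; the path layout is the one used by the workflow above, and the split relies on the standard `{name}-{version}(-{build})?-{python}-{abi}-{platform}.whl` naming, where the last three dash-separated fields are always the tags:

```python
# list_built_wheels.py: rough equivalent of the "List built packages" step,
# additionally showing each wheel's python/abi/platform tags.
from pathlib import Path

for wheel in sorted(Path("packages").glob("*/dist/*.whl")):
    # The last three dash-separated fields of a wheel filename are always
    # the python tag, the ABI tag, and the platform tag.
    py_tag, abi_tag, platform_tag = wheel.stem.split("-")[-3:]
    print(f"{wheel.name}: python={py_tag} abi={abi_tag} platform={platform_tag}")
```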
**`.github/workflows/release-manual.yml`** (new file, 118 lines)

```yaml
name: Release

on:
  workflow_dispatch:
    inputs:
      version:
        description: 'Version to release (e.g., 0.1.2)'
        required: true
        type: string
      use_cibuildwheel:
        description: 'Use cibuildwheel for better compatibility (recommended for Colab)'
        required: false
        type: boolean
        default: false

jobs:
  update-version:
    name: Update Version
    runs-on: ubuntu-latest
    permissions:
      contents: write
    outputs:
      commit-sha: ${{ steps.push.outputs.commit-sha }}

    steps:
      - uses: actions/checkout@v4

      - name: Validate version
        run: |
          if ! [[ "${{ inputs.version }}" =~ ^[0-9]+\.[0-9]+\.[0-9]+$ ]]; then
            echo "❌ Invalid version format"
            exit 1
          fi
          echo "✅ Version format valid"

      - name: Update versions and push
        id: push
        run: |
          ./scripts/bump_version.sh ${{ inputs.version }}
          git config user.name "GitHub Actions"
          git config user.email "actions@github.com"
          git add packages/*/pyproject.toml
          git commit -m "chore: release v${{ inputs.version }}"
          git push origin main

          COMMIT_SHA=$(git rev-parse HEAD)
          echo "commit-sha=$COMMIT_SHA" >> $GITHUB_OUTPUT
          echo "✅ Pushed version update: $COMMIT_SHA"

  build-packages-reusable:
    name: Build packages (Standard)
    needs: update-version
    if: ${{ !inputs.use_cibuildwheel }}
    uses: ./.github/workflows/build-reusable.yml
    with:
      ref: ${{ needs.update-version.outputs.commit-sha }}

  build-packages-cibuildwheel:
    name: Build packages (cibuildwheel)
    needs: update-version
    if: ${{ inputs.use_cibuildwheel }}
    uses: ./.github/workflows/build-cibuildwheel.yml
    with:
      ref: ${{ needs.update-version.outputs.commit-sha }}

  publish:
    name: Publish and Release
    needs: [update-version, build-packages-reusable, build-packages-cibuildwheel]
    if: always() && needs.update-version.result == 'success' && (needs.build-packages-reusable.result == 'success' || needs.build-packages-cibuildwheel.result == 'success')
    runs-on: ubuntu-latest
    permissions:
      contents: write

    steps:
      - uses: actions/checkout@v4
        with:
          ref: ${{ needs.update-version.outputs.commit-sha }}

      - name: Download all artifacts
        uses: actions/download-artifact@v4
        with:
          path: dist-artifacts

      - name: Collect packages
        run: |
          mkdir -p dist
          find dist-artifacts -name "*.whl" -exec cp {} dist/ \;
          find dist-artifacts -name "*.tar.gz" -exec cp {} dist/ \;

          echo "📦 Packages to publish:"
          ls -la dist/

      - name: Publish to PyPI
        env:
          TWINE_USERNAME: __token__
          TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}
        run: |
          if [ -z "$TWINE_PASSWORD" ]; then
            echo "❌ PYPI_API_TOKEN not configured!"
            exit 1
          fi

          pip install twine
          twine upload dist/* --skip-existing --verbose

          echo "✅ Published to PyPI!"

      - name: Create release
        run: |
          git tag "v${{ inputs.version }}"
          git push origin "v${{ inputs.version }}"

          gh release create "v${{ inputs.version }}" \
            --title "Release v${{ inputs.version }}" \
            --notes "🚀 Released to PyPI: https://pypi.org/project/leann/${{ inputs.version }}/" \
            --latest
        env:
          GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
```
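The `Validate version` step accepts only plain `X.Y.Z` strings. The same check can be run locally before dispatching the workflow; this is a small sketch using the regex from the step above (the script name is illustrative):

```python
# check_version.py: local preflight for the X.Y.Z check in "Validate version".
import re
import sys

VERSION_RE = re.compile(r"^[0-9]+\.[0-9]+\.[0-9]+$")

if __name__ == "__main__":
    version = sys.argv[1] if len(sys.argv) > 1 else ""
    if VERSION_RE.fullmatch(version) is None:
        print(f"Invalid version format: {version!r} (expected something like 0.1.2)")
        sys.exit(1)
    print(f"Version format valid: {version}")
```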
**`.github/workflows/test-manylinux.yml`** (new file, 60 lines)

```yaml
name: Test Manylinux Build

on:
  workflow_dispatch:
  pull_request:
    branches: [ main ]
    paths:
      - '.github/workflows/**'
      - 'packages/**'
      - 'pyproject.toml'
  push:
    branches:
      - 'fix/manylinux-*'
      - 'test/build-*'

jobs:
  build:
    uses: ./.github/workflows/build-cibuildwheel.yml

  test-install:
    needs: build
    runs-on: ubuntu-22.04 # Simulating Colab environment
    strategy:
      matrix:
        python-version: ['3.10', '3.11', '3.12']
    steps:
      - uses: actions/checkout@v4

      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}

      - name: Download artifacts
        uses: actions/download-artifact@v4
        with:
          pattern: wheels-*
          path: dist
          merge-multiple: true

      - name: Test installation
        run: |
          python -m pip install --upgrade pip
          # Find and install the appropriate wheels
          pip install dist/leann_core-*.whl
          pip install dist/leann_backend_hnsw-*manylinux*.whl
          pip install dist/leann-*.whl

      - name: Test import
        run: |
          python -c "
          import leann
          from leann import LeannBuilder, LeannSearcher
          print('Successfully imported leann modules')

          # Quick functionality test
          builder = LeannBuilder(backend_name='hnsw')
          builder.add_text('Test document')
          print('LeannBuilder created and used successfully')
          "
```
**`.gitignore`** (modified, 5 changed lines)

```
@@ -12,7 +12,6 @@ outputs/
*.idx
*.map
.history/
scripts/
lm_eval.egg-info/
demo/experiment_results/**/*.json
*.jsonl
@@ -84,4 +83,6 @@ test_*.py
packages/leann-backend-diskann/third_party/DiskANN/_deps/

*.meta.json
*.passages.json
*.passages.json

batchtest.py
```
**`.vscode/extensions.json`** (deleted, 9 lines)

```json
{
    "recommendations": [
        "llvm-vs-code-extensions.vscode-clangd",
        "ms-python.python",
        "ms-vscode.cmake-tools",
        "vadimcn.vscode-lldb",
        "eamodio.gitlens",
    ]
}
```
**`.vscode/launch.json`** (deleted, 283 lines)

```jsonc
{
    // Use IntelliSense to learn about possible attributes.
    // Hover to view descriptions of existing attributes.
    // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
    "version": "0.2.0",
    "configurations": [
        // new emdedder
        {
            "name": "New Embedder",
            "type": "debugpy",
            "request": "launch",
            "program": "demo/main.py",
            "console": "integratedTerminal",
            "args": [
                "--search",
                "--use-original",
                "--domain",
                "dpr",
                "--nprobe",
                "5000",
                "--load",
                "flat",
                "--embedder",
                "intfloat/multilingual-e5-small"
            ]
        }
        //python /home/ubuntu/Power-RAG/faiss/demo/simple_build.py
        {
            "name": "main.py",
            "type": "debugpy",
            "request": "launch",
            "program": "demo/main.py",
            "console": "integratedTerminal",
            "cwd": "${workspaceFolder}",
            "args": [
                "--query",
                "1000",
                "--load",
                "bm25"
            ]
        },
        {
            "name": "Simple Build",
            "type": "lldb",
            "request": "launch",
            "program": "${workspaceFolder}/.venv/bin/python",
            "console": "integratedTerminal",
            "cwd": "${workspaceFolder}",
            "args": [
                "faiss/demo/simple_build.py"
            ],
            "env": {
                "LD_PRELOAD": "/lib/x86_64-linux-gnu/libmkl_core.so:/lib/x86_64-linux-gnu/libmkl_intel_thread.so:/lib/x86_64-linux-gnu/libmkl_intel_lp64.so:/lib/x86_64-linux-gnu/libiomp5.so"
            }
        },
        //# Fix for Intel MKL error
        //export LD_PRELOAD=/lib/x86_64-linux-gnu/libmkl_core.so:/lib/x86_64-linux-gnu/libmkl_intel_thread.so:/lib/x86_64-linux-gnu/libmkl_intel_lp64.so:/lib/x86_64-linux-gnu/libiomp5.so
        //python faiss/demo/build_demo.py
        {
            "name": "Build Demo",
            "type": "lldb",
            "request": "launch",
            "program": "${workspaceFolder}/.venv/bin/python",
            "console": "integratedTerminal",
            "cwd": "${workspaceFolder}",
            "args": [
                "faiss/demo/build_demo.py"
            ],
            "env": {
                "LD_PRELOAD": "/lib/x86_64-linux-gnu/libmkl_core.so:/lib/x86_64-linux-gnu/libmkl_intel_thread.so:/lib/x86_64-linux-gnu/libmkl_intel_lp64.so:/lib/x86_64-linux-gnu/libiomp5.so"
            }
        },
        {
            "name": "DiskANN Serve",
            "type": "lldb",
            "request": "launch",
            "program": "${workspaceFolder}/.venv/bin/python",
            "console": "integratedTerminal",
            "cwd": "${workspaceFolder}",
            "args": [
                "demo/main.py",
                "--mode",
                "serve",
                "--engine",
                "sglang",
                "--load-indices",
                "diskann",
                "--domain",
                "rpj_wiki",
                "--lazy-load",
                "--recompute-beighbor-embeddings",
                "--port",
                "8082",
                "--diskann-search-memory-maximum",
                "2",
                "--diskann-graph",
                "240",
                "--search-only"
            ],
            "env": {
                "PYTHONPATH": "${workspaceFolder}/faiss_repo/build/faiss/python:$PYTHONPATH"
            },
            "preLaunchTask": "CMake: build",
        },
        {
            "name": "DiskANN Serve MAC",
            "type": "lldb",
            "request": "launch",
            "program": "${workspaceFolder}/.venv/bin/python",
            "console": "integratedTerminal",
            "cwd": "${workspaceFolder}",
            "args": [
                "demo/main.py",
                "--mode",
                "serve",
                "--engine",
                "ollama",
                "--load-indices",
                "diskann",
                "--domain",
                "rpj_wiki",
                "--lazy-load",
                "--recompute-beighbor-embeddings"
            ],
            "preLaunchTask": "CMake: build",
            "env": {
                "KMP_DUPLICATE_LIB_OK": "TRUE",
                "OMP_NUM_THREADS": "1",
                "MKL_NUM_THREADS": "1",
                "DYLD_INSERT_LIBRARIES": "/Users/ec2-user/Power-RAG/.venv/lib/python3.10/site-packages/torch/lib/libomp.dylib",
                "KMP_BLOCKTIME": "0"
            }
        },
        {
            "name": "Python Debugger: Current File with Arguments",
            "type": "debugpy",
            "request": "launch",
            "program": "ric/main_ric.py",
            "console": "integratedTerminal",
            "cwd": "${workspaceFolder}",
            "args": [
                "--config-name",
                "${input:configSelection}"
            ],
            "justMyCode": false
        },
        //python ./demo/validate_equivalence.py sglang
        {
            "name": "Validate Equivalence",
            "type": "debugpy",
            "request": "launch",
            "program": "demo/validate_equivalence.py",
            "console": "integratedTerminal",
            "args": [
                "sglang"
            ],
        },
        //python demo/retrieval_demo.py --engine sglang --skip-embeddings --domain dpr --load-indices flat ivf_flat
        {
            "name": "Retrieval Demo",
            "type": "debugpy",
            "request": "launch",
            "program": "demo/retrieval_demo.py",
            "console": "integratedTerminal",
            "args": [
                "--engine",
                "vllm",
                "--skip-embeddings",
                "--domain",
                "dpr",
                "--load-indices",
                // "flat",
                "ivf_flat"
            ],
        },
        //python demo/retrieval_demo.py --engine sglang --skip-embeddings --domain dpr --load-indices diskann --hnsw-M 64 --hnsw-efConstruction 150 --hnsw-efSearch 128 --hnsw-sq-bits 8
        {
            "name": "Retrieval Demo DiskANN",
            "type": "debugpy",
            "request": "launch",
            "program": "demo/retrieval_demo.py",
            "console": "integratedTerminal",
            "args": [
                "--engine",
                "sglang",
                "--skip-embeddings",
                "--domain",
                "dpr",
                "--load-indices",
                "diskann",
                "--hnsw-M",
                "64",
                "--hnsw-efConstruction",
                "150",
                "--hnsw-efSearch",
                "128",
                "--hnsw-sq-bits",
                "8"
            ],
        },
        {
            "name": "Find Probe",
            "type": "debugpy",
            "request": "launch",
            "program": "find_probe.py",
            "console": "integratedTerminal",
            "cwd": "${workspaceFolder}",
        },
        {
            "name": "Python: Attach",
            "type": "debugpy",
            "request": "attach",
            "processId": "${command:pickProcess}",
            "justMyCode": true
        },
        {
            "name": "Edge RAG",
            "type": "lldb",
            "request": "launch",
            "program": "${workspaceFolder}/.venv/bin/python",
            "console": "integratedTerminal",
            "cwd": "${workspaceFolder}",
            "args": [
                "edgerag_demo.py"
            ],
            "env": {
                "LD_PRELOAD": "/lib/x86_64-linux-gnu/libiomp5.so /lib/x86_64-linux-gnu/libmkl_core.so /lib/x86_64-linux-gnu/libmkl_intel_lp64.so /lib/x86_64-linux-gnu/libmkl_intel_thread.so",
                "MKL_NUM_THREADS": "1",
                "OMP_NUM_THREADS": "1",
            }
        },
        {
            "name": "Launch Embedding Server",
            "type": "debugpy",
            "request": "launch",
            "program": "demo/embedding_server.py",
            "console": "integratedTerminal",
            "cwd": "${workspaceFolder}",
            "args": [
                "--domain",
                "rpj_wiki",
                "--zmq-port",
                "5556",
            ]
        },
        {
            "name": "HNSW Serve",
            "type": "lldb",
            "request": "launch",
            "program": "${workspaceFolder}/.venv/bin/python",
            "console": "integratedTerminal",
            "cwd": "${workspaceFolder}",
            "args": [
                "demo/main.py",
                "--domain",
                "rpj_wiki",
                "--load",
                "hnsw",
                "--mode",
                "serve",
                "--search",
                "--skip-pa",
                "--recompute",
                "--hnsw-old"
            ],
            "env": {
                "LD_PRELOAD": "/lib/x86_64-linux-gnu/libmkl_core.so:/lib/x86_64-linux-gnu/libmkl_intel_thread.so:/lib/x86_64-linux-gnu/libmkl_intel_lp64.so:/lib/x86_64-linux-gnu/libiomp5.so"
            }
        },
    ],
    "inputs": [
        {
            "id": "configSelection",
            "type": "pickString",
            "description": "Select a configuration",
            "options": [
                "example_config",
                "vllm_gritlm"
            ],
            "default": "example_config"
        }
    ],
}
```
**`.vscode/settings.json`** (deleted, 43 lines)

```jsonc
{
    "python.analysis.extraPaths": [
        "./sglang_repo/python"
    ],
    "cmake.sourceDirectory": "${workspaceFolder}/DiskANN",
    "cmake.configureArgs": [
        "-DPYBIND=True",
        "-DUPDATE_EDITABLE_INSTALL=ON",
    ],
    "cmake.environment": {
        "PATH": "/Users/ec2-user/Power-RAG/.venv/bin:${env:PATH}"
    },
    "cmake.buildDirectory": "${workspaceFolder}/build",
    "files.associations": {
        "*.tcc": "cpp",
        "deque": "cpp",
        "string": "cpp",
        "unordered_map": "cpp",
        "vector": "cpp",
        "map": "cpp",
        "unordered_set": "cpp",
        "atomic": "cpp",
        "inplace_vector": "cpp",
        "*.ipp": "cpp",
        "forward_list": "cpp",
        "list": "cpp",
        "any": "cpp",
        "system_error": "cpp",
        "__hash_table": "cpp",
        "__split_buffer": "cpp",
        "__tree": "cpp",
        "ios": "cpp",
        "set": "cpp",
        "__string": "cpp",
        "string_view": "cpp",
        "ranges": "cpp",
        "iosfwd": "cpp"
    },
    "lldb.displayFormat": "auto",
    "lldb.showDisassembly": "auto",
    "lldb.dereferencePointers": true,
    "lldb.consoleMode": "commands",
}
```
**`.vscode/tasks.json`** (deleted, 16 lines)

```json
{
    "version": "2.0.0",
    "tasks": [
        {
            "type": "cmake",
            "label": "CMake: build",
            "command": "build",
            "targets": [
                "all"
            ],
            "group": "build",
            "problemMatcher": [],
            "detail": "CMake template build task"
        }
    ]
}
```
**`MANYLINUX_BUILD_STRATEGY.md`** (new file, 50 lines)

````markdown
# Manylinux Build Strategy

## Problem
Google Colab requires wheels compatible with `manylinux_2_35_x86_64` or earlier. Our previous builds were producing `manylinux_2_39_x86_64` wheels, which are incompatible.

## Solution
We're using `cibuildwheel` with `manylinux_2_35` images to build wheels that are compatible with Google Colab while maintaining modern toolchain features.

### Key Changes

1. **cibuildwheel Configuration**
   - Using `manylinux2014` images (provides `manylinux_2_17` compatibility)
   - Using `yum` package manager (CentOS 7 based)
   - Installing `cmake3` and creating symlink for compatibility

2. **Build Matrix**
   - Python versions: 3.9, 3.10, 3.11, 3.12, 3.13
   - Platforms: Linux (x86_64), macOS
   - No Windows support (not required)

3. **Dependencies**
   - Linux: gcc-c++, boost-devel, zeromq-devel, openblas-devel, cmake3
   - macOS: boost, zeromq, openblas, cmake (via Homebrew)

4. **Environment Variables**
   - `CMAKE_BUILD_PARALLEL_LEVEL=8`: Speed up builds
   - `Python_FIND_VIRTUALENV=ONLY`: Help CMake find Python in cibuildwheel env
   - `Python3_FIND_VIRTUALENV=ONLY`: Alternative variable for compatibility

## Testing Strategy

1. **CI Pipeline**: `test-manylinux.yml`
   - Triggers on PR to main, manual dispatch, or push to `fix/manylinux-*` branches
   - Builds wheels using cibuildwheel
   - Tests installation on Ubuntu 22.04 (simulating Colab)

2. **Local Testing**
   ```bash
   # Download built wheels
   # Test in fresh environment
   python -m venv test_env
   source test_env/bin/activate
   pip install leann_core-*.whl leann_backend_hnsw-*manylinux*.whl leann-*.whl
   python -c "from leann import LeannBuilder; print('Success!')"
   ```

## References
- [cibuildwheel documentation](https://cibuildwheel.readthedocs.io/)
- [manylinux standards](https://github.com/pypa/manylinux)
- [PEP 599 - manylinux2014](https://peps.python.org/pep-0599/)
````
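To confirm that a freshly built wheel stays under the `manylinux_2_35` ceiling described above, the platform tag in the filename can be inspected directly. This is a sketch, not part of the repository: it only parses filenames in a `wheelhouse/` directory, treats the legacy `manylinux1/2010/2014` aliases as compatible, and ignores non-manylinux (e.g. macOS) wheels:

```python
# check_colab_compat.py: flag wheels whose manylinux tag is newer than 2.35.
import re
import sys
from pathlib import Path

CEILING = (2, 35)  # the Google Colab ceiling discussed above
TAG_RE = re.compile(r"manylinux_(\d+)_(\d+)|manylinux(?:1|2010|2014)")


def colab_compatible(wheel: Path) -> bool:
    platform_tag = wheel.stem.split("-")[-1]
    match = TAG_RE.search(platform_tag)
    if match is None:
        return False  # not a manylinux wheel (e.g. a macOS build); not relevant to Colab
    if match.group(1) is None:
        return True  # legacy aliases map to glibc 2.5/2.12/2.17, all below the ceiling
    return (int(match.group(1)), int(match.group(2))) <= CEILING


wheel_dir = Path(sys.argv[1]) if len(sys.argv) > 1 else Path("wheelhouse")
for wheel in sorted(wheel_dir.glob("*.whl")):
    status = "OK" if colab_compatible(wheel) else "too new for Colab"
    print(f"{wheel.name}: {status}")
```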
339
README.md
339
README.md
@@ -12,11 +12,11 @@
|
||||
The smallest vector index in the world. RAG Everything with LEANN!
|
||||
</h2>
|
||||
|
||||
LEANN is a revolutionary vector database that democratizes personal AI. Transform your laptop into a powerful RAG system that can index and search through millions of documents while using **[97% less storage]** than traditional solutions **without accuracy loss**.
|
||||
LEANN is a revolutionary vector database that democratizes personal AI. Transform your laptop into a powerful RAG system that can index and search through millions of documents while using **97% less storage** than traditional solutions **without accuracy loss**.
|
||||
|
||||
LEANN achieves this through *graph-based selective recomputation* with *high-degree preserving pruning*, computing embeddings on-demand instead of storing them all. [Illustration →](#️-architecture--how-it-works) | [Paper →](https://arxiv.org/abs/2506.08276)
|
||||
LEANN achieves this through *graph-based selective recomputation* with *high-degree preserving pruning*, computing embeddings on-demand instead of storing them all. [Illustration Fig →](#️-architecture--how-it-works) | [Paper →](https://arxiv.org/abs/2506.08276)
|
||||
|
||||
**Ready to RAG Everything?** Transform your laptop into a personal AI assistant that can search your **[file system](#process-any-documents-pdf-txt-md)**, **[emails](#search-your-entire-life)**, **[browser history](#time-machine-for-the-web)**, **[chat history](#wechat-detective)**, or external knowledge bases (i.e., 60M documents) - all on your laptop, with zero cloud costs and complete privacy.
|
||||
**Ready to RAG Everything?** Transform your laptop into a personal AI assistant that can search your **[file system](#-personal-data-manager-process-any-documents-pdf-txt-md)**, **[emails](#-your-personal-email-secretary-rag-on-apple-mail)**, **[browser history](#-time-machine-for-the-web-rag-your-entire-browser-history)**, **[chat history](#-wechat-detective-unlock-your-golden-memories)**, or external knowledge bases (i.e., 60M documents) - all on your laptop, with zero cloud costs and complete privacy.
|
||||
|
||||
|
||||
|
||||
@@ -26,9 +26,8 @@ LEANN achieves this through *graph-based selective recomputation* with *high-deg
|
||||
<img src="assets/effects.png" alt="LEANN vs Traditional Vector DB Storage Comparison" width="70%">
|
||||
</p>
|
||||
|
||||
**The numbers speak for themselves:** Index 60 million Wikipedia articles in just 6GB instead of 201GB. From emails to browser history, everything fits on your laptop. [See detailed benchmarks below ↓](#storage-usage-comparison)
|
||||
> **The numbers speak for themselves:** Index 60 million Wikipedia chunks in just 6GB instead of 201GB. From emails to browser history, everything fits on your laptop. [See detailed benchmarks for different applications below ↓](#storage-comparison)
|
||||
|
||||
## Why This Matters
|
||||
|
||||
🔒 **Privacy:** Your data never leaves your laptop. No OpenAI, no cloud, no "terms of service".
|
||||
|
||||
@@ -38,8 +37,8 @@ LEANN achieves this through *graph-based selective recomputation* with *high-deg
|
||||
|
||||
✨ **No Accuracy Loss:** Maintain the same search quality as heavyweight solutions while using 97% less storage.
|
||||
|
||||
## Quick Start in 1 minute
|
||||
|
||||
## Installation
|
||||
> `pip leann` coming soon!
|
||||
```bash
|
||||
git clone git@github.com:yichuan-w/LEANN.git leann
|
||||
cd leann
|
||||
@@ -48,33 +47,30 @@ git submodule update --init --recursive
|
||||
|
||||
**macOS:**
|
||||
```bash
|
||||
brew install llvm libomp boost protobuf
|
||||
export CC=$(brew --prefix llvm)/bin/clang
|
||||
export CXX=$(brew --prefix llvm)/bin/clang++
|
||||
brew install llvm libomp boost protobuf zeromq pkgconf
|
||||
|
||||
# Install with HNSW backend (default, recommended for most users)
|
||||
uv sync
|
||||
|
||||
# Or add DiskANN backend if you want to test more options
|
||||
uv sync --extra diskann
|
||||
# Install uv first if you don't have it:
|
||||
# curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
# See: https://docs.astral.sh/uv/getting-started/installation/#installation-methods
|
||||
CC=$(brew --prefix llvm)/bin/clang CXX=$(brew --prefix llvm)/bin/clang++ uv sync
|
||||
```
|
||||
|
||||
**Linux (Ubuntu/Debian):**
|
||||
**Linux:**
|
||||
```bash
|
||||
sudo apt-get install libomp-dev libboost-all-dev protobuf-compiler libabsl-dev libmkl-full-dev libaio-dev
|
||||
sudo apt-get install libomp-dev libboost-all-dev protobuf-compiler libabsl-dev libmkl-full-dev libaio-dev libzmq3-dev
|
||||
|
||||
# Install with HNSW backend (default, recommended for most users)
|
||||
uv sync
|
||||
|
||||
# Or add DiskANN backend if you want to test more options
|
||||
uv sync --extra diskann
|
||||
```
|
||||
|
||||
**Ollama Setup (Optional for Local LLM):**
|
||||
|
||||
*We support both hf-transformers and Ollama for local LLMs. Ollama is recommended for faster performance.*
|
||||
**Ollama Setup (Recommended for full privacy):**
|
||||
|
||||
*macOS:*
|
||||
> *You can skip this installation if you only want to use OpenAI API for generation.*
|
||||
|
||||
|
||||
**macOS:**
|
||||
|
||||
First, [download Ollama for macOS](https://ollama.com/download/mac).
|
||||
|
||||
@@ -83,7 +79,7 @@ First, [download Ollama for macOS](https://ollama.com/download/mac).
|
||||
ollama pull llama3.2:1b
|
||||
```
|
||||
|
||||
*Linux:*
|
||||
**Linux:**
|
||||
```bash
|
||||
# Install Ollama
|
||||
curl -fsSL https://ollama.ai/install.sh | sh
|
||||
@@ -95,62 +91,78 @@ ollama serve &
|
||||
ollama pull llama3.2:1b
|
||||
```
|
||||
|
||||
You can also replace `llama3.2:1b` to `deepseek-r1:1.5b` or `qwen3:4b` for better performance but higher memory usage.
|
||||
## Quick Start in 30s
|
||||
|
||||
## Dead Simple API
|
||||
|
||||
Just 3 lines of code. Our declarative API makes RAG as easy as writing a config file:
|
||||
Our declarative API makes RAG as easy as writing a config file.
|
||||
[Try in this ipynb file →](demo.ipynb) [](https://colab.research.google.com/github/yichuan-w/LEANN/blob/main/demo.ipynb)
|
||||
|
||||
```python
|
||||
from leann.api import LeannBuilder, LeannSearcher
|
||||
from leann.api import LeannBuilder, LeannSearcher, LeannChat
|
||||
|
||||
# 1. Build index (no embeddings stored!)
|
||||
# 1. Build the index (no embeddings stored!)
|
||||
builder = LeannBuilder(backend_name="hnsw")
|
||||
builder.add_text("C# is a powerful programming language")
|
||||
builder.add_text("Python is a powerful programming language")
|
||||
builder.add_text("Machine learning transforms industries")
|
||||
builder.add_text("Python is a powerful programming language and it is very popular")
|
||||
builder.add_text("Machine learning transforms industries")
|
||||
builder.add_text("Neural networks process complex data")
|
||||
builder.add_text("Leann is a great storage saving engine for RAG on your macbook")
|
||||
builder.add_text("Leann is a great storage saving engine for RAG on your MacBook")
|
||||
builder.build_index("knowledge.leann")
|
||||
|
||||
# 2. Search with real-time embeddings
|
||||
searcher = LeannSearcher("knowledge.leann")
|
||||
results = searcher.search("C++ programming languages", top_k=2, recompute_beighbor_embeddings=True)
|
||||
print(results)
|
||||
results = searcher.search("programming languages", top_k=2)
|
||||
|
||||
# 3. Chat with LEANN using retrieved results
|
||||
llm_config = {
|
||||
"type": "ollama",
|
||||
"model": "llama3.2:1b"
|
||||
}
|
||||
|
||||
chat = LeannChat(index_path="knowledge.leann", llm_config=llm_config)
|
||||
response = chat.ask(
|
||||
"Compare the two retrieved programming languages and say which one is more popular today.",
|
||||
top_k=2,
|
||||
)
|
||||
```
|
||||
|
||||
**That's it.** No cloud setup, no API keys, no "fine-tuning". Just your data, your questions, your laptop.
|
||||
## RAG on Everything!
|
||||
|
||||
[Try the interactive demo →](demo.ipynb)
|
||||
LEANN supports RAG on various data sources including documents (.pdf, .txt, .md), Apple Mail, Google Search History, WeChat, and more.
|
||||
|
||||
## Wild Things You Can Do
|
||||
### 📄 Personal Data Manager: Process Any Documents (.pdf, .txt, .md)!
|
||||
|
||||
LEANN supports RAGing a lot of data sources, like .pdf, .txt, .md, and also supports RAGing your WeChat, Google Search History, and more.
|
||||
Ask questions directly about your personal PDFs, documents, and any directory containing your files!
|
||||
|
||||
### Process Any Documents (.pdf, .txt, .md)
|
||||
<p align="center">
|
||||
<img src="videos/paper_clear.gif" alt="LEANN Document Search Demo" width="600">
|
||||
</p>
|
||||
|
||||
Above we showed the Python API, while this CLI script demonstrates the same concepts while directly processing PDFs and documents.
|
||||
The example below asks a question about summarizing two papers (uses default data in `examples/data`):
|
||||
|
||||
```bash
|
||||
# Drop your PDFs, .txt, .md files into examples/data/
|
||||
uv run ./examples/main_cli_example.py
|
||||
```
|
||||
|
||||
```
|
||||
# Or use python directly
|
||||
source .venv/bin/activate
|
||||
python ./examples/main_cli_example.py
|
||||
```
|
||||
|
||||
Uses Ollama `qwen3:8b` by default. For other models: `--llm openai --model gpt-4o` (requires `OPENAI_API_KEY` environment variable) or `--llm hf --model Qwen/Qwen3-4B`.
|
||||
|
||||
**Works with any text format** - research papers, personal notes, presentations. Built with LlamaIndex for document parsing.
|
||||
|
||||
### Search Your Entire Life
|
||||
### 📧 Your Personal Email Secretary: RAG on Apple Mail!
|
||||
|
||||
<p align="center">
|
||||
<img src="videos/mail_clear.gif" alt="LEANN Email Search Demo" width="600">
|
||||
</p>
|
||||
|
||||
**Note:** You need to grant full disk access to your terminal/VS Code in System Preferences → Privacy & Security → Full Disk Access.
|
||||
```bash
|
||||
python examples/mail_reader_leann.py
|
||||
# "What did my boss say about the Christmas party last year?"
|
||||
# "Find all emails from my mom about birthday plans"
|
||||
python examples/mail_reader_leann.py --query "What's the food I ordered by doordash or Uber eat mostly?"
|
||||
```
|
||||
**90K emails → 14MB.** Finally, search your email like you search Google.
|
||||
**780K email chunks → 78MB storage** Finally, search your email like you search Google.
|
||||
|
||||
<details>
|
||||
<summary><strong>📋 Click to expand: Command Examples</strong></summary>
|
||||
@@ -183,13 +195,16 @@ Once the index is built, you can ask questions like:
|
||||
- "Show me emails about travel expenses"
|
||||
</details>
|
||||
|
||||
### Time Machine for the Web
|
||||
### 🔍 Time Machine for the Web: RAG Your Entire Google Browser History!
|
||||
|
||||
<p align="center">
|
||||
<img src="videos/google_clear.gif" alt="LEANN Browser History Search Demo" width="600">
|
||||
</p>
|
||||
|
||||
```bash
|
||||
python examples/google_history_reader_leann.py
|
||||
# "What was that AI paper I read last month?"
|
||||
# "Show me all the cooking videos I watched"
|
||||
python examples/google_history_reader_leann.py --query "Tell me my browser history about machine learning?"
|
||||
```
|
||||
**38K browser entries → 6MB.** Your browser history becomes your personal search engine.
|
||||
**38K browser entries → 6MB storage.** Your browser history becomes your personal search engine.
|
||||
|
||||
<details>
|
||||
<summary><strong>📋 Click to expand: Command Examples</strong></summary>
|
||||
@@ -238,13 +253,17 @@ Once the index is built, you can ask questions like:
|
||||
|
||||
</details>
|
||||
|
||||
### WeChat Detective
|
||||
### 💬 WeChat Detective: Unlock Your Golden Memories!
|
||||
|
||||
<p align="center">
|
||||
<img src="videos/wechat_clear.gif" alt="LEANN WeChat Search Demo" width="600">
|
||||
</p>
|
||||
|
||||
```bash
|
||||
python examples/wechat_history_reader_leann.py
|
||||
# "Show me all group chats about weekend plans"
|
||||
python examples/wechat_history_reader_leann.py --query "Show me all group chats about weekend plans"
|
||||
```
|
||||
**400K messages → 64MB.** Search years of chat history in any language.
|
||||
**400K messages → 64MB storage** Search years of chat history in any language.
|
||||
|
||||
|
||||
<details>
|
||||
<summary><strong>🔧 Click to expand: Installation Requirements</strong></summary>
|
||||
@@ -255,7 +274,13 @@ First, you need to install the WeChat exporter:
|
||||
sudo packages/wechat-exporter/wechattweak-cli install
|
||||
```
|
||||
|
||||
**Troubleshooting**: If you encounter installation issues, check the [WeChatTweak-CLI issues page](https://github.com/sunnyyoung/WeChatTweak-CLI/issues/41).
|
||||
**Troubleshooting:**
|
||||
- **Installation issues**: Check the [WeChatTweak-CLI issues page](https://github.com/sunnyyoung/WeChatTweak-CLI/issues/41)
|
||||
- **Export errors**: If you encounter the error below, try restarting WeChat
|
||||
```
|
||||
Failed to export WeChat data. Please ensure WeChat is running and WeChatTweak is installed.
|
||||
Failed to find or export WeChat data. Exiting.
|
||||
```
|
||||
</details>
|
||||
|
||||
<details>
|
||||
@@ -290,6 +315,73 @@ Once the index is built, you can ask questions like:
|
||||
</details>
|
||||
|
||||
|
||||
|
||||
## 🖥️ Command Line Interface
|
||||
|
||||
LEANN includes a powerful CLI for document processing and search. Perfect for quick document indexing and interactive chat.
|
||||
|
||||
```bash
|
||||
# Build an index from documents
|
||||
leann build my-docs --docs ./documents
|
||||
|
||||
# Search your documents
|
||||
leann search my-docs "machine learning concepts"
|
||||
|
||||
# Interactive chat with your documents
|
||||
leann ask my-docs --interactive
|
||||
|
||||
# List all your indexes
|
||||
leann list
|
||||
```
|
||||
|
||||
**Key CLI features:**
|
||||
- Auto-detects document formats (PDF, TXT, MD, DOCX)
|
||||
- Smart text chunking with overlap
|
||||
- Multiple LLM providers (Ollama, OpenAI, HuggingFace)
|
||||
- Organized index storage in `~/.leann/indexes/`
|
||||
- Support for advanced search parameters
|
||||
|
||||
<details>
|
||||
<summary><strong>📋 Click to expand: Complete CLI Reference</strong></summary>
|
||||
|
||||
**Build Command:**
|
||||
```bash
|
||||
leann build INDEX_NAME --docs DIRECTORY [OPTIONS]
|
||||
|
||||
Options:
|
||||
--backend {hnsw,diskann} Backend to use (default: hnsw)
|
||||
--embedding-model MODEL Embedding model (default: facebook/contriever)
|
||||
--graph-degree N Graph degree (default: 32)
|
||||
--complexity N Build complexity (default: 64)
|
||||
--force Force rebuild existing index
|
||||
--compact Use compact storage (default: true)
|
||||
--recompute Enable recomputation (default: true)
|
||||
```
|
||||
|
||||
**Search Command:**
|
||||
```bash
|
||||
leann search INDEX_NAME QUERY [OPTIONS]
|
||||
|
||||
Options:
|
||||
--top-k N Number of results (default: 5)
|
||||
--complexity N Search complexity (default: 64)
|
||||
--recompute-embeddings Use recomputation for highest accuracy
|
||||
--pruning-strategy {global,local,proportional}
|
||||
```
|
||||
|
||||
**Ask Command:**
|
||||
```bash
|
||||
leann ask INDEX_NAME [OPTIONS]
|
||||
|
||||
Options:
|
||||
--llm {ollama,openai,hf} LLM provider (default: ollama)
|
||||
--model MODEL Model name (default: qwen3:8b)
|
||||
--interactive Interactive chat mode
|
||||
--top-k N Retrieval count (default: 20)
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
## 🏗️ Architecture & How It Works
|
||||
|
||||
<p align="center">
|
||||
@@ -308,18 +400,17 @@ Once the index is built, you can ask questions like:
|
||||
|
||||
## Benchmarks
|
||||
|
||||
Run the comparison yourself:
|
||||
```bash
|
||||
python examples/compare_faiss_vs_leann.py
|
||||
```
|
||||
|
||||
| System | Storage |
|
||||
|--------|---------|
|
||||
| FAISS HNSW | 5.5 MB |
|
||||
| LEANN | 0.5 MB |
|
||||
| **Savings** | **91%** |
|
||||
📊 **[Simple Example: Compare LEANN vs FAISS →](examples/compare_faiss_vs_leann.py)**
|
||||
### Storage Comparison
|
||||
|
||||
| System | DPR (2.1M) | Wiki (60M) | Chat (400K) | Email (780K) | Browser (38K) |
|
||||
|--------|-------------|------------|-------------|--------------|---------------|
|
||||
| Traditional vector database (e.g., FAISS) | 3.8 GB | 201 GB | 1.8 GB | 2.4 GB | 130 MB |
|
||||
| LEANN | 324 MB | 6 GB | 64 MB | 79 MB | 6.4 MB |
|
||||
| Savings| 91% | 97% | 97% | 97% | 95% |
|
||||
|
||||
|
||||
Same dataset, same hardware, same embedding model. LEANN just works better.
|
||||
|
||||
## Reproduce Our Results
|
||||
|
||||
@@ -329,33 +420,7 @@ python examples/run_evaluation.py data/indices/dpr/dpr_diskann # DPR datase
|
||||
python examples/run_evaluation.py data/indices/rpj_wiki/rpj_wiki.index # Wikipedia
|
||||
```
|
||||
|
||||
The evaluation script downloads data automatically on first run.
|
||||
|
||||
### Storage Usage Comparison
|
||||
|
||||
| System | DPR (2.1M chunks) | RPJ-wiki (60M chunks) | Chat history (400K messages) | Apple emails (90K messages chunks) |Google Search History (38K entries)
|
||||
|-----------------------|------------------|------------------------|-----------------------------|------------------------------|------------------------------|
|
||||
| Traditional Vector DB(FAISS) | 3.8 GB | 201 GB | 1.8G | 305.8 MB |130.4 MB |
|
||||
| **LEANN** | **324 MB** | **6 GB** | **64 MB** | **14.8 MB** |**6.4MB** |
|
||||
| **Reduction** | **91% smaller** | **97% smaller** | **97% smaller** | **95% smaller** |**95% smaller** |
|
||||
|
||||
<!-- ### Memory Usage Comparison
|
||||
|
||||
| System j | DPR(2M docs) | RPJ-wiki(60M docs) | Chat history() |
|
||||
| --------------------- | ---------------- | ---------------- | ---------------- |
|
||||
| Traditional Vector DB(LLamaindex faiss) | x GB | x GB | x GB |
|
||||
| **Leann** | **xx MB** | **x GB** | **x GB** |
|
||||
| **Reduction** | **x%** | **x%** | **x%** |
|
||||
|
||||
### Query Performance of LEANN
|
||||
|
||||
| Backend | Index Size | Query Time | Recall@3 |
|
||||
| ------------------- | ---------- | ---------- | --------- |
|
||||
| DiskANN | 1M docs | xms | 0.95 |
|
||||
| HNSW | 1M docs | xms | 0.95 | -->
|
||||
|
||||
*Benchmarks run on Apple M3 Pro 36 GB*
|
||||
|
||||
The evaluation script downloads data automatically on first run. The last three results were tested with partial personal data, and you can reproduce them with your own data!
|
||||
## 🔬 Paper
|
||||
|
||||
If you find Leann useful, please cite:
|
||||
@@ -374,87 +439,15 @@ If you find Leann useful, please cite:
|
||||
}
|
||||
```
|
||||
|
||||
## ✨ Features
|
||||
## ✨ [Detailed Features →](docs/features.md)
|
||||
|
||||
### 🔥 Core Features
|
||||
|
||||
- **🔄 Real-time Embeddings** - Eliminates heavy embedding storage by computing embeddings on the fly with optimized ZMQ servers, an overlapped and batched search paradigm, and a high-performance embedding engine
|
||||
- **📈 Scalable Architecture** - Handles millions of documents on consumer hardware; the larger your dataset, the more LEANN can save
|
||||
- **🎯 Graph Pruning** - Advanced techniques to minimize the storage overhead of vector search to a limited footprint
|
||||
- **🏗️ Pluggable Backends** - DiskANN, HNSW/FAISS with unified API
|
||||
|
||||
### 🛠️ Technical Highlights
|
||||
- **🔄 Recompute Mode** - Delivers the highest accuracy while eliminating vector storage overhead
|
||||
- **⚡ Zero-copy Operations** - Minimize IPC overhead by transferring distances instead of embeddings
|
||||
- **🚀 High-throughput Embedding Pipeline** - Optimized batched processing for maximum efficiency
|
||||
- **🎯 Two-level Search** - Novel coarse-to-fine search overlap for accelerated query processing (optional)
|
||||
- **💾 Memory-mapped Indices** - Fast startup with raw text mapping to reduce memory overhead
|
||||
- **🚀 MLX Support** - Ultra-fast recompute/build with quantized embedding models, accelerating building and search ([minimal example](test/build_mlx_index.py))
|
||||
|
||||
### 🎨 Developer Experience
|
||||
|
||||
- **Simple Python API** - Get started in minutes
|
||||
- **Extensible backend system** - Easy to add new algorithms
|
||||
- **Comprehensive examples** - From basic usage to production deployment
|
||||
|
||||
## 🤝 Contributing
|
||||
|
||||
We welcome contributions! Leann is built by the community, for the community.
|
||||
|
||||
### Ways to Contribute
|
||||
|
||||
- 🐛 **Bug Reports**: Found an issue? Let us know!
|
||||
- 💡 **Feature Requests**: Have an idea? We'd love to hear it!
|
||||
- 🔧 **Code Contributions**: PRs welcome for all skill levels
|
||||
- 📖 **Documentation**: Help make Leann more accessible
|
||||
- 🧪 **Benchmarks**: Share your performance results
|
||||
## 🤝 [Contributing →](docs/contributing.md)
|
||||
|
||||
|
||||
<!-- ## ❓ FAQ
|
||||
|
||||
### Common Issues
|
||||
|
||||
#### NCCL Topology Error
|
||||
|
||||
**Problem**: You encounter `ncclTopoComputePaths` error during document processing:
|
||||
|
||||
```
|
||||
ncclTopoComputePaths (system=<optimized out>, comm=comm@entry=0x5555a82fa3c0) at graph/paths.cc:688
|
||||
```
|
||||
|
||||
**Solution**: Set these environment variables before running your script:
|
||||
|
||||
```bash
|
||||
export NCCL_TOPO_DUMP_FILE=/tmp/nccl_topo.xml
|
||||
export NCCL_DEBUG=INFO
|
||||
export NCCL_DEBUG_SUBSYS=INIT,GRAPH
|
||||
export NCCL_IB_DISABLE=1
|
||||
export NCCL_NET_PLUGIN=none
|
||||
export NCCL_SOCKET_IFNAME=ens5
|
||||
``` -->
|
||||
|
||||
## 📈 Roadmap
|
||||
|
||||
### 🎯 Q2 2025
|
||||
|
||||
- [X] DiskANN backend with MIPS/L2/Cosine support
|
||||
- [X] HNSW backend integration
|
||||
- [X] Real-time embedding pipeline
|
||||
- [X] Memory-efficient graph pruning
|
||||
|
||||
### 🚀 Q3 2025
|
||||
## [FAQ →](docs/faq.md)
|
||||
|
||||
|
||||
- [ ] Advanced caching strategies
|
||||
- [ ] Add contextual retrieval (https://www.anthropic.com/news/contextual-retrieval)
|
||||
- [ ] Add a sleep-time-compute and summarization agent to summarize files on your computer
|
||||
- [ ] Add OpenAI recompute API
|
||||
|
||||
### 🌟 Q4 2025
|
||||
|
||||
- [ ] Integration with LangChain/LlamaIndex
|
||||
- [ ] Visual similarity search
|
||||
- [ ] Query rewriting, reranking, and expansion
|
||||
## 📈 [Roadmap →](docs/roadmap.md)
|
||||
|
||||
## 📄 License
|
||||
|
||||
@@ -462,11 +455,7 @@ MIT License - see [LICENSE](LICENSE) for details.
|
||||
|
||||
## 🙏 Acknowledgments
|
||||
|
||||
- **Microsoft Research** for the DiskANN algorithm
|
||||
- **Meta AI** for FAISS and optimization insights
|
||||
- **HuggingFace** for the transformer ecosystem
|
||||
- **Our amazing contributors** who make this possible
|
||||
|
||||
This work is done at [**Berkeley Sky Computing Lab**](https://sky.cs.berkeley.edu/)
|
||||
---
|
||||
|
||||
<p align="center">
|
||||
|
||||
322
demo.ipynb
@@ -1,35 +1,321 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Quick Start in 30s"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from leann.api import LeannBuilder, LeannSearcher, LeannChat\n",
|
||||
"# 1. Build index (no embeddings stored!)\n",
|
||||
"builder = LeannBuilder(backend_name=\"hnsw\")\n",
|
||||
"builder.add_text(\"C# is a powerful programming language but it is not very popular\")\n",
|
||||
"builder.add_text(\"Python is a powerful programming language and it is very popular\")\n",
|
||||
"builder.add_text(\"Machine learning transforms industries\") \n",
|
||||
"builder.add_text(\"Neural networks process complex data\")\n",
|
||||
"builder.add_text(\"Leann is a great storage saving engine for RAG on your macbook\")\n",
|
||||
"builder.build_index(\"knowledge.leann\")\n",
|
||||
"# 2. Search with real-time embeddings\n",
|
||||
"searcher = LeannSearcher(\"knowledge.leann\")\n",
|
||||
"results = searcher.search(\"programming languages\", top_k=2, recompute_beighbor_embeddings=True)\n",
|
||||
"print(results)\n",
|
||||
"# install this if you areusing colab\n",
|
||||
"! pip install leann"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Build the index"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"INFO: Registering backend 'hnsw'\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/Users/yichuan/Desktop/code/LEANN/leann/.venv/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
|
||||
" from .autonotebook import tqdm as notebook_tqdm\n",
|
||||
"INFO:sentence_transformers.SentenceTransformer:Load pretrained SentenceTransformer: facebook/contriever\n",
|
||||
"WARNING:sentence_transformers.SentenceTransformer:No sentence-transformers model found with name facebook/contriever. Creating a new one with mean pooling.\n",
|
||||
"Writing passages: 100%|██████████| 5/5 [00:00<00:00, 27887.66chunk/s]\n",
|
||||
"Batches: 100%|██████████| 1/1 [00:00<00:00, 13.51it/s]\n",
|
||||
"WARNING:leann_backend_hnsw.hnsw_backend:Converting data to float32, shape: (5, 768)\n",
|
||||
"INFO:leann_backend_hnsw.hnsw_backend:INFO: Converting HNSW index to CSR-pruned format...\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"M: 64 for level: 0\n",
|
||||
"Starting conversion: knowledge.index -> knowledge.csr.tmp\n",
|
||||
"[0.00s] Reading Index HNSW header...\n",
|
||||
"[0.00s] Header read: d=768, ntotal=5\n",
|
||||
"[0.00s] Reading HNSW struct vectors...\n",
|
||||
" Reading vector (dtype=<class 'numpy.float64'>, fmt='d')... Count=6, Bytes=48\n",
|
||||
"[0.00s] Read assign_probas (6)\n",
|
||||
" Reading vector (dtype=<class 'numpy.int32'>, fmt='i')... Count=7, Bytes=28\n",
|
||||
"[0.11s] Read cum_nneighbor_per_level (7)\n",
|
||||
" Reading vector (dtype=<class 'numpy.int32'>, fmt='i')... Count=5, Bytes=20\n",
|
||||
"[0.21s] Read levels (5)\n",
|
||||
"[0.30s] Probing for compact storage flag...\n",
|
||||
"[0.30s] Found compact flag: False\n",
|
||||
"[0.30s] Compact flag is False, reading original format...\n",
|
||||
"[0.30s] Probing for potential extra byte before non-compact offsets...\n",
|
||||
"[0.30s] Found and consumed an unexpected 0x00 byte.\n",
|
||||
" Reading vector (dtype=<class 'numpy.uint64'>, fmt='Q')... Count=6, Bytes=48\n",
|
||||
"[0.30s] Read offsets (6)\n",
|
||||
"[0.40s] Attempting to read neighbors vector...\n",
|
||||
" Reading vector (dtype=<class 'numpy.int32'>, fmt='i')... Count=320, Bytes=1280\n",
|
||||
"[0.40s] Read neighbors (320)\n",
|
||||
"[0.50s] Read scalar params (ep=4, max_lvl=0)\n",
|
||||
"[0.50s] Checking for storage data...\n",
|
||||
"[0.50s] Found storage fourcc: 49467849.\n",
|
||||
"[0.50s] Converting to CSR format...\n",
|
||||
"[0.50s] Conversion loop finished. \n",
|
||||
"[0.50s] Running validation checks...\n",
|
||||
" Checking total valid neighbor count...\n",
|
||||
" OK: Total valid neighbors = 20\n",
|
||||
" Checking final pointer indices...\n",
|
||||
" OK: Final pointers match data size.\n",
|
||||
"[0.50s] Deleting original neighbors and offsets arrays...\n",
|
||||
" CSR Stats: |data|=20, |level_ptr|=10\n",
|
||||
"[0.59s] Writing CSR HNSW graph data in FAISS-compatible order...\n",
|
||||
" Pruning embeddings: Writing NULL storage marker.\n",
|
||||
"[0.69s] Conversion complete.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"INFO:leann_backend_hnsw.hnsw_backend:✅ CSR conversion successful.\n",
|
||||
"INFO:leann_backend_hnsw.hnsw_backend:INFO: Replaced original index with CSR-pruned version at 'knowledge.index'\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from leann.api import LeannBuilder\n",
|
||||
"\n",
|
||||
"llm_config = {\"type\": \"ollama\", \"model\": \"qwen3:8b\"}\n",
|
||||
"builder = LeannBuilder(backend_name=\"hnsw\")\n",
|
||||
"builder.add_text(\"C# is a powerful programming language and it is good at game development\")\n",
|
||||
"builder.add_text(\"Python is a powerful programming language and it is good at machine learning tasks\")\n",
|
||||
"builder.add_text(\"Machine learning transforms industries\")\n",
|
||||
"builder.add_text(\"Neural networks process complex data\")\n",
|
||||
"builder.add_text(\"Leann is a great storage saving engine for RAG on your MacBook\")\n",
|
||||
"builder.build_index(\"knowledge.leann\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Search with real-time embeddings"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"INFO:leann.api:🔍 LeannSearcher.search() called:\n",
|
||||
"INFO:leann.api: Query: 'programming languages'\n",
|
||||
"INFO:leann.api: Top_k: 2\n",
|
||||
"INFO:leann.api: Additional kwargs: {}\n",
|
||||
"INFO:leann.embedding_server_manager:Port 5557 has incompatible server, trying next port...\n",
|
||||
"INFO:leann.embedding_server_manager:Port 5558 has incompatible server, trying next port...\n",
|
||||
"INFO:leann.embedding_server_manager:Port 5559 has incompatible server, trying next port...\n",
|
||||
"INFO:leann.embedding_server_manager:Using port 5560 instead of 5557\n",
|
||||
"INFO:leann.embedding_server_manager:Starting embedding server on port 5560...\n",
|
||||
"INFO:leann.embedding_server_manager:Command: /Users/yichuan/Desktop/code/LEANN/leann/.venv/bin/python -m leann_backend_hnsw.hnsw_embedding_server --zmq-port 5560 --model-name facebook/contriever --passages-file knowledge.leann.meta.json\n",
|
||||
"huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
|
||||
"To disable this warning, you can either:\n",
|
||||
"\t- Avoid using `tokenizers` before the fork if possible\n",
|
||||
"\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
|
||||
"INFO:leann.embedding_server_manager:Server process started with PID: 4574\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[read_HNSW - CSR NL v4] Reading metadata & CSR indices (manual offset)...\n",
|
||||
"[read_HNSW NL v4] Read levels vector, size: 5\n",
|
||||
"[read_HNSW NL v4] Reading Compact Storage format indices...\n",
|
||||
"[read_HNSW NL v4] Read compact_level_ptr, size: 10\n",
|
||||
"[read_HNSW NL v4] Read compact_node_offsets, size: 6\n",
|
||||
"[read_HNSW NL v4] Read entry_point: 4, max_level: 0\n",
|
||||
"[read_HNSW NL v4] Read storage fourcc: 0x6c6c756e\n",
|
||||
"[read_HNSW NL v4 FIX] Detected FileIOReader. Neighbors size field offset: 326\n",
|
||||
"[read_HNSW NL v4] Reading neighbors data into memory.\n",
|
||||
"[read_HNSW NL v4] Read neighbors data, size: 20\n",
|
||||
"[read_HNSW NL v4] Finished reading metadata and CSR indices.\n",
|
||||
"INFO: Skipping external storage loading, since is_recompute is true.\n",
|
||||
"INFO: Registering backend 'hnsw'\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"INFO:leann.embedding_server_manager:Embedding server is ready!\n",
|
||||
"INFO:leann.api: Launching server time: 1.078078269958496 seconds\n",
|
||||
"INFO:leann.embedding_server_manager:Existing server process (PID 4574) is compatible\n",
|
||||
"INFO:sentence_transformers.SentenceTransformer:Load pretrained SentenceTransformer: facebook/contriever\n",
|
||||
"WARNING:sentence_transformers.SentenceTransformer:No sentence-transformers model found with name facebook/contriever. Creating a new one with mean pooling.\n",
|
||||
"INFO:leann.api: Generated embedding shape: (1, 768)\n",
|
||||
"INFO:leann.api: Embedding time: 2.9307072162628174 seconds\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"ZmqDistanceComputer initialized: d=768, metric=0\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"INFO:leann.api: Search time: 0.27327895164489746 seconds\n",
|
||||
"INFO:leann.api: Backend returned: labels=2 results\n",
|
||||
"INFO:leann.api: Processing 2 passage IDs:\n",
|
||||
"INFO:leann.api: 1. passage_id='0' -> SUCCESS: C# is a powerful programming language and it is good at game development...\n",
|
||||
"INFO:leann.api: 2. passage_id='1' -> SUCCESS: Python is a powerful programming language and it is good at machine learning tasks...\n",
|
||||
"INFO:leann.api: Final enriched results: 2 passages\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"[SearchResult(id='0', score=np.float32(0.9874103), text='C# is a powerful programming language and it is good at game development', metadata={}),\n",
|
||||
" SearchResult(id='1', score=np.float32(0.8922168), text='Python is a powerful programming language and it is good at machine learning tasks', metadata={})]"
|
||||
]
|
||||
},
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from leann.api import LeannSearcher\n",
|
||||
"\n",
|
||||
"searcher = LeannSearcher(\"knowledge.leann\")\n",
|
||||
"results = searcher.search(\"programming languages\", top_k=2)\n",
|
||||
"results"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Chat with LEANN using retrieved results"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"INFO:leann.chat:Attempting to create LLM of type='hf' with model='Qwen/Qwen3-0.6B'\n",
|
||||
"INFO:leann.chat:Initializing HFChat with model='Qwen/Qwen3-0.6B'\n",
|
||||
"INFO:leann.chat:MPS is available. Using Apple Silicon GPU.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[read_HNSW - CSR NL v4] Reading metadata & CSR indices (manual offset)...\n",
|
||||
"[read_HNSW NL v4] Read levels vector, size: 5\n",
|
||||
"[read_HNSW NL v4] Reading Compact Storage format indices...\n",
|
||||
"[read_HNSW NL v4] Read compact_level_ptr, size: 10\n",
|
||||
"[read_HNSW NL v4] Read compact_node_offsets, size: 6\n",
|
||||
"[read_HNSW NL v4] Read entry_point: 4, max_level: 0\n",
|
||||
"[read_HNSW NL v4] Read storage fourcc: 0x6c6c756e\n",
|
||||
"[read_HNSW NL v4 FIX] Detected FileIOReader. Neighbors size field offset: 326\n",
|
||||
"[read_HNSW NL v4] Reading neighbors data into memory.\n",
|
||||
"[read_HNSW NL v4] Read neighbors data, size: 20\n",
|
||||
"[read_HNSW NL v4] Finished reading metadata and CSR indices.\n",
|
||||
"INFO: Skipping external storage loading, since is_recompute is true.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"INFO:leann.api:🔍 LeannSearcher.search() called:\n",
|
||||
"INFO:leann.api: Query: 'Compare the two retrieved programming languages and tell me their advantages.'\n",
|
||||
"INFO:leann.api: Top_k: 2\n",
|
||||
"INFO:leann.api: Additional kwargs: {}\n",
|
||||
"INFO:leann.embedding_server_manager:Port 5557 has incompatible server, trying next port...\n",
|
||||
"INFO:leann.embedding_server_manager:Port 5558 has incompatible server, trying next port...\n",
|
||||
"INFO:leann.embedding_server_manager:Port 5559 has incompatible server, trying next port...\n",
|
||||
"INFO:leann.embedding_server_manager:Found compatible server on port 5560\n",
|
||||
"INFO:leann.embedding_server_manager:Using existing compatible server on port 5560\n",
|
||||
"INFO:leann.api: Launching server time: 0.04932403564453125 seconds\n",
|
||||
"INFO:leann.embedding_server_manager:Found compatible server on port 5560\n",
|
||||
"INFO:leann.embedding_server_manager:Using existing compatible server on port 5560\n",
|
||||
"INFO:leann.api: Generated embedding shape: (1, 768)\n",
|
||||
"INFO:leann.api: Embedding time: 0.06902289390563965 seconds\n",
|
||||
"INFO:leann.api: Search time: 0.026793241500854492 seconds\n",
|
||||
"INFO:leann.api: Backend returned: labels=2 results\n",
|
||||
"INFO:leann.api: Processing 2 passage IDs:\n",
|
||||
"INFO:leann.api: 1. passage_id='0' -> SUCCESS: C# is a powerful programming language and it is good at game development...\n",
|
||||
"INFO:leann.api: 2. passage_id='1' -> SUCCESS: Python is a powerful programming language and it is good at machine learning tasks...\n",
|
||||
"INFO:leann.api: Final enriched results: 2 passages\n",
|
||||
"INFO:leann.chat:Generating with HuggingFace model, config: {'max_new_tokens': 128, 'temperature': 0.7, 'top_p': 0.9, 'do_sample': True, 'pad_token_id': 151645, 'eos_token_id': 151645}\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"ZmqDistanceComputer initialized: d=768, metric=0\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"\"<think>\\n\\n</think>\\n\\nBased on the context provided, here's a comparison of the two retrieved programming languages:\\n\\n**C#** is known for being a powerful programming language and is well-suited for game development. It is often used in game development and is popular among developers working on Windows applications.\\n\\n**Python**, on the other hand, is also a powerful language and is well-suited for machine learning tasks. It is widely used for data analysis, scientific computing, and other applications that require handling large datasets or performing complex calculations.\\n\\n**Advantages**:\\n- C#: Strong for game development and cross-platform compatibility.\\n- Python: Strong for\""
|
||||
]
|
||||
},
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from leann.api import LeannChat\n",
|
||||
"\n",
|
||||
"llm_config = {\n",
|
||||
" \"type\": \"hf\",\n",
|
||||
" \"model\": \"Qwen/Qwen3-0.6B\",\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"chat = LeannChat(index_path=\"knowledge.leann\", llm_config=llm_config)\n",
|
||||
"\n",
|
||||
"response = chat.ask(\n",
|
||||
" \"Compare the two retrieved programming languages and say which one is more popular today. Respond in a single well-formed sentence.\",\n",
|
||||
" \"Compare the two retrieved programming languages and tell me their advantages.\",\n",
|
||||
" top_k=2,\n",
|
||||
" recompute_beighbor_embeddings=True,\n",
|
||||
" llm_kwargs={\"max_tokens\": 128}\n",
|
||||
")\n",
|
||||
"print(response)"
|
||||
"response"
|
||||
]
|
||||
}
|
||||
],
|
||||
|
||||
22
docs/RELEASE.md
Normal file
@@ -0,0 +1,22 @@
|
||||
# Release Guide
|
||||
|
||||
## Setup (One-time)
|
||||
|
||||
Add `PYPI_API_TOKEN` to GitHub Secrets:
|
||||
1. Get token: https://pypi.org/manage/account/token/
|
||||
2. Add to secrets: Settings → Secrets → Actions → `PYPI_API_TOKEN`
|
||||
|
||||
## Release (One-click)
|
||||
|
||||
1. Go to: https://github.com/yichuan-w/LEANN/actions/workflows/release-manual.yml
|
||||
2. Click "Run workflow"
|
||||
3. Enter version: `0.1.2`
|
||||
4. Click green "Run workflow" button
|
||||
|
||||
That's it! The workflow will automatically:
|
||||
- ✅ Update version in all packages
|
||||
- ✅ Build all packages
|
||||
- ✅ Publish to PyPI
|
||||
- ✅ Create GitHub tag and release
|
||||
|
||||
Check progress: https://github.com/yichuan-w/LEANN/actions
|
||||
11
docs/contributing.md
Normal file
@@ -0,0 +1,11 @@
|
||||
# 🤝 Contributing
|
||||
|
||||
We welcome contributions! Leann is built by the community, for the community.
|
||||
|
||||
## Ways to Contribute
|
||||
|
||||
- 🐛 **Bug Reports**: Found an issue? Let us know!
|
||||
- 💡 **Feature Requests**: Have an idea? We'd love to hear it!
|
||||
- 🔧 **Code Contributions**: PRs welcome for all skill levels
|
||||
- 📖 **Documentation**: Help make Leann more accessible
|
||||
- 🧪 **Benchmarks**: Share your performance results
|
||||
10
docs/faq.md
Normal file
@@ -0,0 +1,10 @@
|
||||
# FAQ
|
||||
|
||||
## 1. My index build time seems too long
|
||||
|
||||
You can speed up the process by using a lightweight embedding model. Add this to your arguments:
|
||||
|
||||
```bash
|
||||
--embedding-model sentence-transformers/all-MiniLM-L6-v2
|
||||
```
|
||||
**Model sizes:** `all-MiniLM-L6-v2` (30M parameters), `facebook/contriever` (~100M parameters), `Qwen3-0.6B` (600M parameters)
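For the Python API, the lighter model can be selected at build time. A minimal sketch, assuming `LeannBuilder` accepts an `embedding_model` keyword mirroring the CLI flag above (check the API reference if your version differs):

```python
# Minimal sketch: building with a lightweight embedding model.
# Assumption: LeannBuilder takes an `embedding_model` keyword that mirrors
# the --embedding-model CLI flag shown above.
from leann.api import LeannBuilder

builder = LeannBuilder(
    backend_name="hnsw",
    embedding_model="sentence-transformers/all-MiniLM-L6-v2",  # ~30M parameters
)
builder.add_text("A short document to index.")
builder.build_index("knowledge.leann")
```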
|
||||
22
docs/features.md
Normal file
@@ -0,0 +1,22 @@
|
||||
# ✨ Detailed Features
|
||||
|
||||
## 🔥 Core Features
|
||||
|
||||
- **🔄 Real-time Embeddings** - Eliminates heavy embedding storage by computing embeddings on the fly with optimized ZMQ servers, an overlapped and batched search paradigm, and a high-performance embedding engine
|
||||
- **📈 Scalable Architecture** - Handles millions of documents on consumer hardware; the larger your dataset, the more LEANN can save
|
||||
- **🎯 Graph Pruning** - Advanced techniques to minimize the storage overhead of vector search to a limited footprint
|
||||
- **🏗️ Pluggable Backends** - DiskANN, HNSW/FAISS with unified API
|
||||
|
||||
## 🛠️ Technical Highlights
|
||||
- **🔄 Recompute Mode** - Delivers the highest accuracy while eliminating vector storage overhead (a minimal sketch follows this list)
|
||||
- **⚡ Zero-copy Operations** - Minimize IPC overhead by transferring distances instead of embeddings
|
||||
- **🚀 High-throughput Embedding Pipeline** - Optimized batched processing for maximum efficiency
|
||||
- **🎯 Two-level Search** - Novel coarse-to-fine search overlap for accelerated query processing (optional)
|
||||
- **💾 Memory-mapped Indices** - Fast startup with raw text mapping to reduce memory overhead
|
||||
- **🚀 MLX Support** - Ultra-fast recompute/build with quantized embedding models, accelerating building and search ([minimal example](test/build_mlx_index.py))
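The recompute path looks like this in practice. A minimal sketch, assuming an existing `knowledge.leann` index; note that the keyword name has varied across versions (`recompute_embeddings` in the current backends, `recompute_beighbor_embeddings` in older examples):

```python
# Minimal sketch of Recompute Mode: no embeddings are stored in the index;
# neighbor embeddings are recomputed on the fly during search.
from leann.api import LeannSearcher

searcher = LeannSearcher("knowledge.leann")
results = searcher.search(
    "programming languages",
    top_k=2,
    recompute_embeddings=True,  # trade a little latency for near-zero embedding storage
)
print(results)
```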
|
||||
|
||||
## 🎨 Developer Experience
|
||||
|
||||
- **Simple Python API** - Get started in minutes
|
||||
- **Extensible backend system** - Easy to add new algorithms
|
||||
- **Comprehensive examples** - From basic usage to production deployment
|
||||
21
docs/roadmap.md
Normal file
@@ -0,0 +1,21 @@
|
||||
# 📈 Roadmap
|
||||
|
||||
## 🎯 Q2 2025
|
||||
|
||||
- [X] DiskANN backend with MIPS/L2/Cosine support
|
||||
- [X] HNSW backend integration
|
||||
- [X] Real-time embedding pipeline
|
||||
- [X] Memory-efficient graph pruning
|
||||
|
||||
## 🚀 Q3 2025
|
||||
|
||||
- [ ] Advanced caching strategies
|
||||
- [ ] Add contextual retrieval (https://www.anthropic.com/news/contextual-retrieval)
|
||||
- [ ] Add a sleep-time-compute and summarization agent to summarize files on your computer
|
||||
- [ ] Add OpenAI recompute API
|
||||
|
||||
## 🌟 Q4 2025
|
||||
|
||||
- [ ] Integration with LangChain/LlamaIndex
|
||||
- [ ] Visual similarity search
|
||||
- [ ] Query rewriting, reranking, and expansion
|
||||
@@ -135,6 +135,7 @@ def test_leann_hnsw():
|
||||
nodes = node_parser.get_nodes_from_documents([doc])
|
||||
for node in nodes:
|
||||
all_texts.append(node.get_content())
|
||||
print(f"Total number of chunks: {len(all_texts)}")
|
||||
|
||||
tracker.checkpoint("After text chunking")
|
||||
|
||||
|
||||
@@ -96,14 +96,12 @@ class EmlxReader(BaseReader):
|
||||
|
||||
# Create document content with metadata embedded in text
|
||||
doc_content = f"""
|
||||
[EMAIL METADATA]
|
||||
File: {filename}
|
||||
From: {from_addr}
|
||||
To: {to_addr}
|
||||
Subject: {subject}
|
||||
Date: {date}
|
||||
[END METADATA]
|
||||
|
||||
[File]: {filename}
|
||||
[From]: {from_addr}
|
||||
[To]: {to_addr}
|
||||
[Subject]: {subject}
|
||||
[Date]: {date}
|
||||
[EMAIL BODY Start]:
|
||||
{body}
|
||||
"""
|
||||
|
||||
|
||||
@@ -37,7 +37,7 @@ def main():
|
||||
import faiss
|
||||
except ImportError:
|
||||
print("Faiss is not installed.")
|
||||
print("Please install it with `uv pip install faiss-cpu`")
|
||||
print("Please install it with `uv pip install faiss-cpu` and you can then run this script again")
|
||||
sys.exit(1)
|
||||
|
||||
from llama_index.core import (
|
||||
|
||||
@@ -65,12 +65,14 @@ def create_leann_index_from_multiple_chrome_profiles(profile_dirs: List[Path], i
|
||||
|
||||
if not all_documents:
|
||||
print("No documents loaded from any source. Exiting.")
|
||||
# Highlight that all Chrome browsers must be closed before running this script
|
||||
print("\033[91mYou need to close or quit all chrome browser before running this script\033[0m")
|
||||
return None
|
||||
|
||||
print(f"\nTotal loaded {len(all_documents)} history documents from {len(profile_dirs)} profiles")
|
||||
|
||||
# Create text splitter with 256 chunk size
|
||||
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=25)
|
||||
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=128)
|
||||
|
||||
# Convert Documents to text strings and chunk them
|
||||
all_texts = []
|
||||
@@ -78,7 +80,9 @@ def create_leann_index_from_multiple_chrome_profiles(profile_dirs: List[Path], i
|
||||
# Split the document into chunks
|
||||
nodes = text_splitter.get_nodes_from_documents([doc])
|
||||
for node in nodes:
|
||||
all_texts.append(node.get_content())
|
||||
text = node.get_content()
|
||||
# text = '[Title] ' + doc.metadata["title"] + '\n' + text
|
||||
all_texts.append(text)
|
||||
|
||||
print(f"Created {len(all_texts)} text chunks from {len(all_documents)} documents")
|
||||
|
||||
@@ -218,14 +222,15 @@ async def query_leann_index(index_path: str, query: str):
|
||||
"max_tokens": 1000
|
||||
}
|
||||
)
|
||||
print(f"Leann: {chat_response}")
|
||||
|
||||
print(f"Leann chat response: \033[36m{chat_response}\033[0m")
|
||||
|
||||
async def main():
|
||||
# Parse command line arguments
|
||||
parser = argparse.ArgumentParser(description='LEANN Chrome History Reader - Create and query browser history index')
|
||||
parser.add_argument('--chrome-profile', type=str, default=DEFAULT_CHROME_PROFILE,
|
||||
help=f'Path to Chrome profile directory (default: {DEFAULT_CHROME_PROFILE}), usually you dont need to change this')
|
||||
parser.add_argument('--index-dir', type=str, default="./chrome_history_index_leann_test",
|
||||
parser.add_argument('--index-dir', type=str, default="./google_history_index",
|
||||
help='Directory to store the LEANN index (default: ./chrome_history_index_leann_test)')
|
||||
parser.add_argument('--max-entries', type=int, default=1000,
|
||||
help='Maximum number of history entries to process (default: 1000)')
|
||||
|
||||
@@ -74,22 +74,17 @@ class ChromeHistoryReader(BaseReader):
|
||||
|
||||
# Create document content with metadata embedded in text
|
||||
doc_content = f"""
|
||||
[BROWSING HISTORY METADATA]
|
||||
URL: {url}
|
||||
Title: {title}
|
||||
Last Visit: {last_visit}
|
||||
Visit Count: {visit_count}
|
||||
Typed Count: {typed_count}
|
||||
Hidden: {hidden}
|
||||
[END METADATA]
|
||||
|
||||
Title: {title}
|
||||
URL: {url}
|
||||
Last visited: {last_visit}
|
||||
[Title]: {title}
|
||||
[URL of the page]: {url}
|
||||
[Last visited time]: {last_visit}
|
||||
[Visit times]: {visit_count}
|
||||
[Typed times]: {typed_count}
|
||||
"""
|
||||
|
||||
# Create document with embedded metadata
|
||||
doc = Document(text=doc_content, metadata={})
|
||||
doc = Document(text=doc_content, metadata={ "title": title[0:150]})
|
||||
# if len(title) > 150:
|
||||
# print(f"Title is too long: {title}")
|
||||
docs.append(doc)
|
||||
count += 1
|
||||
|
||||
|
||||
@@ -335,14 +335,15 @@ class WeChatHistoryReader(BaseReader):
|
||||
if create_time:
|
||||
try:
|
||||
timestamp = datetime.fromtimestamp(create_time)
|
||||
time_str = timestamp.strftime('%H:%M:%S')
|
||||
# change to YYYY-MM-DD HH:MM:SS
|
||||
time_str = timestamp.strftime('%Y-%m-%d %H:%M:%S')
|
||||
except:
|
||||
time_str = str(create_time)
|
||||
else:
|
||||
time_str = "Unknown"
|
||||
|
||||
sender = "Me" if is_sent_from_self else "Contact"
|
||||
message_parts.append(f"[{time_str}] {sender}: {readable_text}")
|
||||
sender = "[Me]" if is_sent_from_self else "[Contact]"
|
||||
message_parts.append(f"({time_str}) {sender}: {readable_text}")
|
||||
|
||||
concatenated_text = "\n".join(message_parts)
|
||||
|
||||
@@ -354,13 +355,11 @@ Messages ({len(messages)} messages, {message_group['total_length']} chars):
|
||||
|
||||
{concatenated_text}
|
||||
"""
|
||||
|
||||
# TODO @yichuan give better format and rich info here!
|
||||
doc_content = f"""
|
||||
Contact: {contact_name}
|
||||
|
||||
{concatenated_text}
|
||||
"""
|
||||
return doc_content
|
||||
return doc_content, contact_name
|
||||
|
||||
def load_data(self, input_dir: str = None, **load_kwargs: Any) -> List[Document]:
|
||||
"""
|
||||
@@ -441,8 +440,8 @@ Contact: {contact_name}
|
||||
if count >= max_count and max_count > 0:
|
||||
break
|
||||
|
||||
doc_content = self._create_concatenated_content(message_group, contact_name)
|
||||
doc = Document(text=doc_content, metadata={})
|
||||
doc_content, contact_name = self._create_concatenated_content(message_group, contact_name)
|
||||
doc = Document(text=doc_content, metadata={"contact_name": contact_name})
|
||||
docs.append(doc)
|
||||
count += 1
|
||||
|
||||
|
||||
@@ -22,7 +22,7 @@ def get_mail_path():
|
||||
return os.path.join(home_dir, "Library", "Mail")
|
||||
|
||||
# Default mail path for macOS
|
||||
# DEFAULT_MAIL_PATH = "/Users/yichuan/Library/Mail/V10/0FCA0879-FD8C-4B7E-83BF-FDDA930791C5/[Gmail].mbox/All Mail.mbox/78BA5BE1-8819-4F9A-9613-EB63772F1DD0/Data"
|
||||
DEFAULT_MAIL_PATH = "/Users/yichuan/Library/Mail/V10/0FCA0879-FD8C-4B7E-83BF-FDDA930791C5/[Gmail].mbox/All Mail.mbox/78BA5BE1-8819-4F9A-9613-EB63772F1DD0/Data"
|
||||
|
||||
def create_leann_index_from_multiple_sources(messages_dirs: List[Path], index_path: str = "mail_index.leann", max_count: int = -1, include_html: bool = False, embedding_model: str = "facebook/contriever"):
|
||||
"""
|
||||
@@ -74,7 +74,7 @@ def create_leann_index_from_multiple_sources(messages_dirs: List[Path], index_pa
|
||||
print("No documents loaded from any source. Exiting.")
|
||||
return None
|
||||
|
||||
print(f"\nTotal loaded {len(all_documents)} email documents from {len(messages_dirs)} directories")
|
||||
print(f"\nTotal loaded {len(all_documents)} email documents from {len(messages_dirs)} directories and starting to split them into chunks")
|
||||
|
||||
# Create text splitter with 256 chunk size
|
||||
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=25)
|
||||
@@ -85,9 +85,11 @@ def create_leann_index_from_multiple_sources(messages_dirs: List[Path], index_pa
|
||||
# Split the document into chunks
|
||||
nodes = text_splitter.get_nodes_from_documents([doc])
|
||||
for node in nodes:
|
||||
all_texts.append(node.get_content())
|
||||
text = node.get_content()
|
||||
# text = '[subject] ' + doc.metadata["subject"] + '\n' + text
|
||||
all_texts.append(text)
|
||||
|
||||
print(f"Created {len(all_texts)} text chunks from {len(all_documents)} documents")
|
||||
print(f"Finished splitting {len(all_documents)} documents into {len(all_texts)} text chunks")
|
||||
|
||||
# Create LEANN index directory
|
||||
|
||||
@@ -156,7 +158,7 @@ def create_leann_index(mail_path: str, index_path: str = "mail_index.leann", max
|
||||
print(f"Loaded {len(documents)} email documents")
|
||||
|
||||
# Create text splitter with 256 chunk size
|
||||
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=25)
|
||||
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=128)
|
||||
|
||||
# Convert Documents to text strings and chunk them
|
||||
all_texts = []
|
||||
@@ -216,22 +218,22 @@ async def query_leann_index(index_path: str, query: str):
|
||||
start_time = time.time()
|
||||
chat_response = chat.ask(
|
||||
query,
|
||||
top_k=10,
|
||||
top_k=20,
|
||||
recompute_beighbor_embeddings=True,
|
||||
complexity=12,
|
||||
complexity=32,
|
||||
beam_width=1,
|
||||
|
||||
)
|
||||
end_time = time.time()
|
||||
print(f"Time taken: {end_time - start_time} seconds")
|
||||
print(f"Leann: {chat_response}")
|
||||
# print(f"Time taken: {end_time - start_time} seconds")
|
||||
# highlight the answer
|
||||
print(f"Leann chat response: \033[36m{chat_response}\033[0m")
|
||||
|
||||
async def main():
|
||||
# Parse command line arguments
|
||||
parser = argparse.ArgumentParser(description='LEANN Mail Reader - Create and query email index')
|
||||
# Remove --mail-path argument and auto-detect all Messages directories
|
||||
# Remove DEFAULT_MAIL_PATH
|
||||
parser.add_argument('--index-dir', type=str, default="./mail_index_leann_raw_text_all_dicts",
|
||||
parser.add_argument('--index-dir', type=str, default="./mail_index",
|
||||
help='Directory to store the LEANN index (default: ./mail_index_leann_raw_text_all_dicts)')
|
||||
parser.add_argument('--max-emails', type=int, default=1000,
|
||||
help='Maximum number of emails to process (-1 means all)')
|
||||
@@ -251,6 +253,9 @@ async def main():
|
||||
mail_path = get_mail_path()
|
||||
print(f"Searching for email data in: {mail_path}")
|
||||
messages_dirs = find_all_messages_directories(mail_path)
|
||||
# messages_dirs = find_all_messages_directories(DEFAULT_MAIL_PATH)
|
||||
# messages_dirs = [DEFAULT_MAIL_PATH]
|
||||
# messages_dirs = messages_dirs[:1]
|
||||
|
||||
print('len(messages_dirs): ', len(messages_dirs))
|
||||
|
||||
|
||||
@@ -1,40 +1,40 @@
|
||||
import argparse
|
||||
from llama_index.core import SimpleDirectoryReader, Settings
|
||||
from llama_index.core import SimpleDirectoryReader
|
||||
from llama_index.core.node_parser import SentenceSplitter
|
||||
import asyncio
|
||||
import dotenv
|
||||
from leann.api import LeannBuilder, LeannSearcher, LeannChat
|
||||
import shutil
|
||||
from leann.api import LeannBuilder, LeannChat
|
||||
from pathlib import Path
|
||||
|
||||
dotenv.load_dotenv()
|
||||
|
||||
node_parser = SentenceSplitter(
|
||||
chunk_size=256, chunk_overlap=128, separator=" ", paragraph_separator="\n\n"
|
||||
)
|
||||
print("Loading documents...")
|
||||
documents = SimpleDirectoryReader(
|
||||
"examples/data",
|
||||
recursive=True,
|
||||
encoding="utf-8",
|
||||
required_exts=[".pdf", ".txt", ".md"],
|
||||
).load_data(show_progress=True)
|
||||
print("Documents loaded.")
|
||||
all_texts = []
|
||||
for doc in documents:
|
||||
nodes = node_parser.get_nodes_from_documents([doc])
|
||||
for node in nodes:
|
||||
all_texts.append(node.get_content())
|
||||
|
||||
|
||||
async def main(args):
|
||||
INDEX_DIR = Path(args.index_dir)
|
||||
INDEX_PATH = str(INDEX_DIR / "pdf_documents.leann")
|
||||
|
||||
if not INDEX_DIR.exists():
|
||||
print(f"--- Index directory not found, building new index ---")
|
||||
node_parser = SentenceSplitter(
|
||||
chunk_size=256, chunk_overlap=128, separator=" ", paragraph_separator="\n\n"
|
||||
)
|
||||
|
||||
print(f"\n[PHASE 1] Building Leann index...")
|
||||
print("Loading documents...")
|
||||
documents = SimpleDirectoryReader(
|
||||
args.data_dir,
|
||||
recursive=True,
|
||||
encoding="utf-8",
|
||||
required_exts=[".pdf", ".txt", ".md"],
|
||||
).load_data(show_progress=True)
|
||||
print("Documents loaded.")
|
||||
all_texts = []
|
||||
for doc in documents:
|
||||
nodes = node_parser.get_nodes_from_documents([doc])
|
||||
for node in nodes:
|
||||
all_texts.append(node.get_content())
|
||||
|
||||
print("--- Index directory not found, building new index ---")
|
||||
|
||||
print("\n[PHASE 1] Building Leann index...")
|
||||
|
||||
# Use HNSW backend for better macOS compatibility
|
||||
builder = LeannBuilder(
|
||||
@@ -58,22 +58,19 @@ async def main(args):
|
||||
|
||||
print(f"\n[PHASE 2] Starting Leann chat session...")
|
||||
|
||||
# llm_config = {"type": "hf", "model": "Qwen/Qwen3-4B"}
|
||||
llm_config = {"type": "hf", "model": "Qwen/Qwen3-4B"}
|
||||
llm_config = {"type": "ollama", "model": "qwen3:8b"}
|
||||
llm_config = {"type": "openai", "model": "gpt-4o"}
|
||||
|
||||
chat = LeannChat(index_path=INDEX_PATH, llm_config=llm_config)
|
||||
|
||||
query = "Based on the paper, what are the main techniques LEANN explores to reduce the storage overhead and DLPM explore to achieve Fairness and Efiiciency trade-off?"
|
||||
|
||||
# query = (
|
||||
# "什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面,任务令一般在什么城市颁发"
|
||||
# )
|
||||
query = args.query
|
||||
|
||||
print(f"You: {query}")
|
||||
chat_response = chat.ask(
|
||||
query, top_k=20, recompute_beighbor_embeddings=True, complexity=32
|
||||
)
|
||||
print(f"Leann: {chat_response}")
|
||||
chat_response = chat.ask(query, top_k=20, recompute_embeddings=True, complexity=32)
|
||||
print(f"Leann chat response: \033[36m{chat_response}\033[0m")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@@ -105,6 +102,18 @@ if __name__ == "__main__":
|
||||
default="./test_doc_files",
|
||||
help="Directory where the Leann index will be stored.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--data-dir",
|
||||
type=str,
|
||||
default="examples/data",
|
||||
help="Directory containing documents to index (PDF, TXT, MD files).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--query",
|
||||
type=str,
|
||||
default="Based on the paper, what are the main techniques LEANN explores to reduce the storage overhead and DLPM explore to achieve Fairness and Efiiciency trade-off?",
|
||||
help="The query to ask the Leann chat system.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
asyncio.run(main(args))
|
||||
|
||||
@@ -52,7 +52,7 @@ def create_leann_index_from_multiple_wechat_exports(
|
||||
documents = reader.load_data(
|
||||
wechat_export_dir=str(export_dir),
|
||||
max_count=max_count,
|
||||
concatenate_messages=False, # Disable concatenation - one message per document
|
||||
concatenate_messages=True, # Disable concatenation - one message per document
|
||||
)
|
||||
if documents:
|
||||
print(f"Loaded {len(documents)} chat documents from {export_dir}")
|
||||
@@ -74,11 +74,11 @@ def create_leann_index_from_multiple_wechat_exports(
|
||||
return None
|
||||
|
||||
print(
|
||||
f"\nTotal loaded {len(all_documents)} chat documents from {len(export_dirs)} exports"
|
||||
f"\nTotal loaded {len(all_documents)} chat documents from {len(export_dirs)} exports and starting to split them into chunks"
|
||||
)
|
||||
|
||||
# Create text splitter with 256 chunk size
|
||||
text_splitter = SentenceSplitter(chunk_size=128, chunk_overlap=64)
|
||||
text_splitter = SentenceSplitter(chunk_size=192, chunk_overlap=64)
|
||||
|
||||
# Convert Documents to text strings and chunk them
|
||||
all_texts = []
|
||||
@@ -86,10 +86,11 @@ def create_leann_index_from_multiple_wechat_exports(
|
||||
# Split the document into chunks
|
||||
nodes = text_splitter.get_nodes_from_documents([doc])
|
||||
for node in nodes:
|
||||
all_texts.append(node.get_content())
|
||||
text = '[Contact] means the message is from: ' + doc.metadata["contact_name"] + '\n' + node.get_content()
|
||||
all_texts.append(text)
|
||||
|
||||
print(
|
||||
f"Created {len(all_texts)} text chunks from {len(all_documents)} documents"
|
||||
f"Finished splitting {len(all_documents)} documents into {len(all_texts)} text chunks"
|
||||
)
|
||||
|
||||
# Create LEANN index directory
|
||||
@@ -224,7 +225,7 @@ async def query_leann_index(index_path: str, query: str):
|
||||
query,
|
||||
top_k=20,
|
||||
recompute_beighbor_embeddings=True,
|
||||
complexity=128,
|
||||
complexity=16,
|
||||
beam_width=1,
|
||||
llm_config={
|
||||
"type": "openai",
|
||||
@@ -233,7 +234,7 @@ async def query_leann_index(index_path: str, query: str):
|
||||
},
|
||||
llm_kwargs={"temperature": 0.0, "max_tokens": 1000},
|
||||
)
|
||||
print(f"Leann: {chat_response}")
|
||||
print(f"Leann chat response: \033[36m{chat_response}\033[0m")
|
||||
|
||||
|
||||
async def main():
|
||||
@@ -252,13 +253,13 @@ async def main():
|
||||
parser.add_argument(
|
||||
"--index-dir",
|
||||
type=str,
|
||||
default="./wechat_history_june19_test",
|
||||
default="./wechat_history_magic_test_11Debug_new",
|
||||
help="Directory to store the LEANN index (default: ./wechat_history_index_leann_test)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-entries",
|
||||
type=int,
|
||||
default=5000,
|
||||
default=50,
|
||||
help="Maximum number of chat entries to process (default: 5000)",
|
||||
)
|
||||
parser.add_argument(
|
||||
|
||||
@@ -1,8 +1,36 @@
|
||||
# packages/leann-backend-diskann/CMakeLists.txt (最终简化版)
|
||||
# packages/leann-backend-diskann/CMakeLists.txt (simplified version)
|
||||
|
||||
cmake_minimum_required(VERSION 3.20)
|
||||
project(leann_backend_diskann_wrapper)
|
||||
|
||||
# 告诉 CMake 直接进入 DiskANN 子模块并执行它自己的 CMakeLists.txt
|
||||
# DiskANN 会自己处理所有事情,包括编译 Python 绑定
|
||||
add_subdirectory(src/third_party/DiskANN)
|
||||
# Find Python - scikit-build-core should provide this
|
||||
find_package(Python REQUIRED COMPONENTS Interpreter Development.Module)
|
||||
|
||||
# Print Python information for debugging
|
||||
message(STATUS "Python_FOUND: ${Python_FOUND}")
|
||||
message(STATUS "Python_VERSION: ${Python_VERSION}")
|
||||
message(STATUS "Python_EXECUTABLE: ${Python_EXECUTABLE}")
|
||||
message(STATUS "Python_INCLUDE_DIRS: ${Python_INCLUDE_DIRS}")
|
||||
|
||||
# Pass Python information to DiskANN through cache variables
|
||||
set(Python_EXECUTABLE ${Python_EXECUTABLE} CACHE FILEPATH "Python executable" FORCE)
|
||||
set(Python_INCLUDE_DIRS ${Python_INCLUDE_DIRS} CACHE PATH "Python include dirs" FORCE)
|
||||
set(Python_LIBRARIES ${Python_LIBRARIES} CACHE FILEPATH "Python libraries" FORCE)
|
||||
set(Python_VERSION ${Python_VERSION} CACHE STRING "Python version" FORCE)
|
||||
set(Python_FOUND ${Python_FOUND} CACHE BOOL "Python found" FORCE)
|
||||
|
||||
# Also set Python3 variables for compatibility
|
||||
set(Python3_EXECUTABLE ${Python_EXECUTABLE} CACHE FILEPATH "Python3 executable" FORCE)
|
||||
set(Python3_INCLUDE_DIRS ${Python_INCLUDE_DIRS} CACHE PATH "Python3 include dirs" FORCE)
|
||||
set(Python3_LIBRARIES ${Python_LIBRARIES} CACHE FILEPATH "Python3 libraries" FORCE)
|
||||
set(Python3_VERSION ${Python_VERSION} CACHE STRING "Python3 version" FORCE)
|
||||
set(Python3_FOUND ${Python_FOUND} CACHE BOOL "Python3 found" FORCE)
|
||||
set(Python3_Development_FOUND TRUE CACHE BOOL "Python3 development found" FORCE)
|
||||
|
||||
# Set Python finding strategy
|
||||
set(Python_FIND_VIRTUALENV ONLY CACHE STRING "" FORCE)
|
||||
set(Python3_FIND_VIRTUALENV ONLY CACHE STRING "" FORCE)
|
||||
|
||||
# Tell CMake to directly enter the DiskANN submodule and execute its own CMakeLists.txt
|
||||
# DiskANN will handle everything itself, including compiling Python bindings
|
||||
add_subdirectory(third_party/DiskANN)
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
import numpy as np
|
||||
import os
|
||||
import struct
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, List, Literal
|
||||
from typing import Dict, Any, List, Literal, Optional
|
||||
import contextlib
|
||||
import pickle
|
||||
|
||||
import logging
|
||||
|
||||
from leann.searcher_base import BaseSearcher
|
||||
from leann.registry import register_backend
|
||||
@@ -14,6 +16,46 @@ from leann.interface import (
|
||||
LeannBackendSearcherInterface,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def suppress_cpp_output_if_needed():
|
||||
"""Suppress C++ stdout/stderr based on LEANN_LOG_LEVEL"""
|
||||
log_level = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
|
||||
|
||||
# Only suppress if log level is WARNING or higher (ERROR, CRITICAL)
|
||||
should_suppress = log_level in ["WARNING", "ERROR", "CRITICAL"]
|
||||
|
||||
if not should_suppress:
|
||||
# Don't suppress, just yield
|
||||
yield
|
||||
return
|
||||
|
||||
# Save original file descriptors
|
||||
stdout_fd = sys.stdout.fileno()
|
||||
stderr_fd = sys.stderr.fileno()
|
||||
|
||||
# Save original stdout/stderr
|
||||
stdout_dup = os.dup(stdout_fd)
|
||||
stderr_dup = os.dup(stderr_fd)
|
||||
|
||||
try:
|
||||
# Redirect to /dev/null
|
||||
devnull = os.open(os.devnull, os.O_WRONLY)
|
||||
os.dup2(devnull, stdout_fd)
|
||||
os.dup2(devnull, stderr_fd)
|
||||
os.close(devnull)
|
||||
|
||||
yield
|
||||
|
||||
finally:
|
||||
# Restore original file descriptors
|
||||
os.dup2(stdout_dup, stdout_fd)
|
||||
os.dup2(stderr_dup, stderr_fd)
|
||||
os.close(stdout_dup)
|
||||
os.close(stderr_dup)
|
||||
|
||||
|
||||
def _get_diskann_metrics():
|
||||
from . import _diskannpy as diskannpy # type: ignore
|
||||
@@ -65,22 +107,20 @@ class DiskannBuilder(LeannBackendBuilderInterface):
|
||||
index_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if data.dtype != np.float32:
|
||||
logger.warning(f"Converting data to float32, shape: {data.shape}")
|
||||
data = data.astype(np.float32)
|
||||
|
||||
data_filename = f"{index_prefix}_data.bin"
|
||||
_write_vectors_to_bin(data, index_dir / data_filename)
|
||||
|
||||
label_map = {i: str_id for i, str_id in enumerate(ids)}
|
||||
label_map_file = index_dir / "leann.labels.map"
|
||||
with open(label_map_file, "wb") as f:
|
||||
pickle.dump(label_map, f)
|
||||
|
||||
build_kwargs = {**self.build_params, **kwargs}
|
||||
metric_enum = _get_diskann_metrics().get(
|
||||
build_kwargs.get("distance_metric", "mips").lower()
|
||||
)
|
||||
if metric_enum is None:
|
||||
raise ValueError("Unsupported distance_metric.")
|
||||
raise ValueError(
|
||||
f"Unsupported distance_metric '{build_kwargs.get('distance_metric', 'unknown')}'."
|
||||
)
|
||||
|
||||
try:
|
||||
from . import _diskannpy as diskannpy # type: ignore
|
||||
@@ -102,36 +142,40 @@ class DiskannBuilder(LeannBackendBuilderInterface):
|
||||
temp_data_file = index_dir / data_filename
|
||||
if temp_data_file.exists():
|
||||
os.remove(temp_data_file)
|
||||
logger.debug(f"Cleaned up temporary data file: {temp_data_file}")
|
||||
|
||||
|
||||
class DiskannSearcher(BaseSearcher):
|
||||
def __init__(self, index_path: str, **kwargs):
|
||||
super().__init__(
|
||||
index_path,
|
||||
backend_module_name="leann_backend_diskann.embedding_server",
|
||||
backend_module_name="leann_backend_diskann.diskann_embedding_server",
|
||||
**kwargs,
|
||||
)
|
||||
from . import _diskannpy as diskannpy # type: ignore
|
||||
|
||||
distance_metric = kwargs.get("distance_metric", "mips").lower()
|
||||
metric_enum = _get_diskann_metrics().get(distance_metric)
|
||||
if metric_enum is None:
|
||||
raise ValueError(f"Unsupported distance_metric '{distance_metric}'.")
|
||||
# Initialize DiskANN index with suppressed C++ output based on log level
|
||||
with suppress_cpp_output_if_needed():
|
||||
from . import _diskannpy as diskannpy # type: ignore
|
||||
|
||||
self.num_threads = kwargs.get("num_threads", 8)
|
||||
self.zmq_port = kwargs.get("zmq_port", 6666)
|
||||
distance_metric = kwargs.get("distance_metric", "mips").lower()
|
||||
metric_enum = _get_diskann_metrics().get(distance_metric)
|
||||
if metric_enum is None:
|
||||
raise ValueError(f"Unsupported distance_metric '{distance_metric}'.")
|
||||
|
||||
full_index_prefix = str(self.index_dir / self.index_path.stem)
|
||||
self._index = diskannpy.StaticDiskFloatIndex(
|
||||
metric_enum,
|
||||
full_index_prefix,
|
||||
self.num_threads,
|
||||
kwargs.get("num_nodes_to_cache", 0),
|
||||
1,
|
||||
self.zmq_port,
|
||||
"",
|
||||
"",
|
||||
)
|
||||
self.num_threads = kwargs.get("num_threads", 8)
|
||||
|
||||
fake_zmq_port = 6666
|
||||
full_index_prefix = str(self.index_dir / self.index_path.stem)
|
||||
self._index = diskannpy.StaticDiskFloatIndex(
|
||||
metric_enum,
|
||||
full_index_prefix,
|
||||
self.num_threads,
|
||||
kwargs.get("num_nodes_to_cache", 0),
|
||||
1,
|
||||
fake_zmq_port, # Initial port, can be updated at runtime
|
||||
"",
|
||||
"",
|
||||
)
|
||||
|
||||
def search(
|
||||
self,
|
||||
@@ -142,7 +186,7 @@ class DiskannSearcher(BaseSearcher):
|
||||
prune_ratio: float = 0.0,
|
||||
recompute_embeddings: bool = False,
|
||||
pruning_strategy: Literal["global", "local", "proportional"] = "global",
|
||||
zmq_port: int = 5557,
|
||||
zmq_port: Optional[int] = None,
|
||||
batch_recompute: bool = False,
|
||||
dedup_node_dis: bool = False,
|
||||
**kwargs,
|
||||
@@ -161,7 +205,7 @@ class DiskannSearcher(BaseSearcher):
|
||||
- "global": Use global pruning strategy (default)
|
||||
- "local": Use local pruning strategy
|
||||
- "proportional": Not supported in DiskANN, falls back to global
|
||||
zmq_port: ZMQ port for embedding server
|
||||
zmq_port: ZMQ port for embedding server communication. Must be provided if recompute_embeddings is True.
|
||||
batch_recompute: Whether to batch neighbor recomputation (DiskANN-specific)
|
||||
dedup_node_dis: Whether to cache and reuse distance computations (DiskANN-specific)
|
||||
**kwargs: Additional DiskANN-specific parameters (for legacy compatibility)
|
||||
@@ -169,22 +213,25 @@ class DiskannSearcher(BaseSearcher):
|
||||
Returns:
|
||||
Dict with 'labels' (list of lists) and 'distances' (ndarray)
|
||||
"""
|
||||
# Handle zmq_port compatibility: DiskANN can now update port at runtime
|
||||
if recompute_embeddings:
|
||||
if zmq_port is None:
|
||||
raise ValueError(
|
||||
"zmq_port must be provided if recompute_embeddings is True"
|
||||
)
|
||||
current_port = self._index.get_zmq_port()
|
||||
if zmq_port != current_port:
|
||||
logger.debug(
|
||||
f"Updating DiskANN zmq_port from {current_port} to {zmq_port}"
|
||||
)
|
||||
self._index.set_zmq_port(zmq_port)
|
||||
|
||||
# DiskANN doesn't support "proportional" strategy
|
||||
if pruning_strategy == "proportional":
|
||||
raise NotImplementedError(
|
||||
"DiskANN backend does not support 'proportional' pruning strategy. Use 'global' or 'local' instead."
|
||||
)
|
||||
|
||||
# Use recompute_embeddings parameter
|
||||
use_recompute = recompute_embeddings
|
||||
if use_recompute:
|
||||
meta_file_path = self.index_dir / f"{self.index_path.name}.meta.json"
|
||||
if not meta_file_path.exists():
|
||||
raise RuntimeError(
|
||||
f"FATAL: Recompute enabled but metadata file not found: {meta_file_path}"
|
||||
)
|
||||
self._ensure_server_running(str(meta_file_path), port=zmq_port, **kwargs)
|
||||
|
||||
if query.dtype != np.float32:
|
||||
query = query.astype(np.float32)
|
||||
|
||||
@@ -194,28 +241,26 @@ class DiskannSearcher(BaseSearcher):
|
||||
else: # "global"
|
||||
use_global_pruning = True
|
||||
|
||||
labels, distances = self._index.batch_search(
|
||||
query,
|
||||
query.shape[0],
|
||||
top_k,
|
||||
complexity,
|
||||
beam_width,
|
||||
self.num_threads,
|
||||
kwargs.get("USE_DEFERRED_FETCH", False),
|
||||
kwargs.get("skip_search_reorder", False),
|
||||
use_recompute,
|
||||
dedup_node_dis,
|
||||
prune_ratio,
|
||||
batch_recompute,
|
||||
use_global_pruning,
|
||||
)
|
||||
# Perform search with suppressed C++ output based on log level
|
||||
with suppress_cpp_output_if_needed():
|
||||
labels, distances = self._index.batch_search(
|
||||
query,
|
||||
query.shape[0],
|
||||
top_k,
|
||||
complexity,
|
||||
beam_width,
|
||||
self.num_threads,
|
||||
kwargs.get("USE_DEFERRED_FETCH", False),
|
||||
kwargs.get("skip_search_reorder", False),
|
||||
recompute_embeddings,
|
||||
dedup_node_dis,
|
||||
prune_ratio,
|
||||
batch_recompute,
|
||||
use_global_pruning,
|
||||
)
|
||||
|
||||
string_labels = [
|
||||
[
|
||||
self.label_map.get(int_label, f"unknown_{int_label}")
|
||||
for int_label in batch_labels
|
||||
]
|
||||
for batch_labels in labels
|
||||
[str(int_label) for int_label in batch_labels] for batch_labels in labels
|
||||
]
|
||||
|
||||
return {"labels": string_labels, "distances": distances}
|
||||
|
||||
@@ -0,0 +1,283 @@
|
||||
"""
|
||||
DiskANN-specific embedding server
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import threading
|
||||
import time
|
||||
import os
|
||||
import zmq
|
||||
import numpy as np
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
import sys
|
||||
import logging
|
||||
|
||||
# Set up logging based on environment variable
|
||||
LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Force set logger level (don't rely on basicConfig in subprocess)
|
||||
log_level = getattr(logging, LOG_LEVEL, logging.WARNING)
|
||||
logger.setLevel(log_level)
|
||||
|
||||
# Ensure we have a handler if none exists
|
||||
if not logger.handlers:
|
||||
handler = logging.StreamHandler()
|
||||
formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
|
||||
handler.setFormatter(formatter)
|
||||
logger.addHandler(handler)
|
||||
logger.propagate = False
|
||||
|
||||
|
||||
def create_diskann_embedding_server(
|
||||
passages_file: Optional[str] = None,
|
||||
zmq_port: int = 5555,
|
||||
model_name: str = "sentence-transformers/all-mpnet-base-v2",
|
||||
embedding_mode: str = "sentence-transformers",
|
||||
):
|
||||
"""
|
||||
Create and start a ZMQ-based embedding server for DiskANN backend.
|
||||
Uses a ZMQ REP socket and protobuf (with a msgpack fallback) for communication with the DiskANN C++ client.
|
||||
"""
|
||||
logger.info(f"Starting DiskANN server on port {zmq_port} with model {model_name}")
|
||||
logger.info(f"Using embedding mode: {embedding_mode}")
|
||||
|
||||
# Add leann-core to path for unified embedding computation
|
||||
current_dir = Path(__file__).parent
|
||||
leann_core_path = current_dir.parent.parent / "leann-core" / "src"
|
||||
sys.path.insert(0, str(leann_core_path))
|
||||
|
||||
try:
|
||||
from leann.embedding_compute import compute_embeddings
|
||||
from leann.api import PassageManager
|
||||
|
||||
logger.info("Successfully imported unified embedding computation module")
|
||||
except ImportError as e:
|
||||
logger.error(f"Failed to import embedding computation module: {e}")
|
||||
return
|
||||
finally:
|
||||
sys.path.pop(0)
|
||||
|
||||
# Check port availability
|
||||
import socket
|
||||
|
||||
def check_port(port):
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
return s.connect_ex(("localhost", port)) == 0
|
||||
|
||||
if check_port(zmq_port):
|
||||
logger.error(f"Port {zmq_port} is already in use")
|
||||
return
|
||||
|
||||
# Only support metadata file, fail fast for everything else
|
||||
if not passages_file or not passages_file.endswith(".meta.json"):
|
||||
raise ValueError("Only metadata files (.meta.json) are supported")
|
||||
|
||||
# Load metadata to get passage sources
|
||||
with open(passages_file, "r") as f:
|
||||
meta = json.load(f)
|
||||
|
||||
passages = PassageManager(meta["passage_sources"])
|
||||
logger.info(
|
||||
f"Loaded PassageManager with {len(passages.global_offset_map)} passages from metadata"
|
||||
)
|
||||
|
||||
# Import protobuf after ensuring the path is correct
|
||||
try:
|
||||
from . import embedding_pb2
|
||||
except ImportError as e:
|
||||
logger.error(f"Failed to import protobuf module: {e}")
|
||||
return
|
||||
|
||||
def zmq_server_thread():
|
||||
"""ZMQ server thread using REP socket for universal compatibility"""
|
||||
context = zmq.Context()
|
||||
socket = context.socket(
|
||||
zmq.REP
|
||||
) # REP socket for both BaseSearcher and DiskANN C++ REQ clients
|
||||
socket.bind(f"tcp://*:{zmq_port}")
|
||||
logger.info(f"DiskANN ZMQ REP server listening on port {zmq_port}")
|
||||
|
||||
socket.setsockopt(zmq.RCVTIMEO, 300000)
|
||||
socket.setsockopt(zmq.SNDTIMEO, 300000)
|
||||
|
||||
while True:
|
||||
try:
|
||||
# REP socket receives single-part messages
|
||||
message = socket.recv()
|
||||
|
||||
# Check for empty messages - REP socket requires response to every request
|
||||
if len(message) == 0:
|
||||
logger.debug("Received empty message, sending empty response")
|
||||
socket.send(b"") # REP socket must respond to every request
|
||||
continue
|
||||
|
||||
logger.debug(f"Received ZMQ request of size {len(message)} bytes")
|
||||
logger.debug(f"Message preview: {message[:50]}") # Show first 50 bytes
|
||||
|
||||
e2e_start = time.time()
|
||||
|
||||
# Try protobuf first (for DiskANN C++ node_ids requests - primary use case)
|
||||
texts = []
|
||||
node_ids = []
|
||||
is_text_request = False
|
||||
|
||||
try:
|
||||
req_proto = embedding_pb2.NodeEmbeddingRequest()
|
||||
req_proto.ParseFromString(message)
|
||||
node_ids = list(req_proto.node_ids)
|
||||
|
||||
if not node_ids:
|
||||
raise RuntimeError(
|
||||
f"PROTOBUF: Received empty node_ids! Message size: {len(message)}"
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"✅ PROTOBUF: Node ID request for {len(node_ids)} node embeddings: {node_ids[:10]}"
|
||||
)
|
||||
except Exception as protobuf_error:
|
||||
logger.debug(f"Protobuf parsing failed: {protobuf_error}")
|
||||
# Fallback to msgpack (for BaseSearcher direct text requests)
|
||||
try:
|
||||
import msgpack
|
||||
|
||||
request = msgpack.unpackb(message)
|
||||
# For BaseSearcher compatibility, request is a list of texts directly
|
||||
if isinstance(request, list) and all(
|
||||
isinstance(item, str) for item in request
|
||||
):
|
||||
texts = request
|
||||
is_text_request = True
|
||||
logger.info(
|
||||
f"✅ MSGPACK: Direct text request for {len(texts)} texts"
|
||||
)
|
||||
else:
|
||||
raise ValueError("Not a valid msgpack text request")
|
||||
except Exception as msgpack_error:
|
||||
raise RuntimeError(
|
||||
f"Both protobuf and msgpack parsing failed! Protobuf: {protobuf_error}, Msgpack: {msgpack_error}"
|
||||
)
|
||||
|
||||
# Look up texts by node IDs (only if not direct text request)
|
||||
if not is_text_request:
|
||||
for nid in node_ids:
|
||||
try:
|
||||
passage_data = passages.get_passage(str(nid))
|
||||
txt = passage_data["text"]
|
||||
if not txt:
|
||||
raise RuntimeError(
|
||||
f"FATAL: Empty text for passage ID {nid}"
|
||||
)
|
||||
texts.append(txt)
|
||||
except KeyError as e:
|
||||
logger.error(f"Passage ID {nid} not found: {e}")
|
||||
raise e
|
||||
except Exception as e:
|
||||
logger.error(f"Exception looking up passage ID {nid}: {e}")
|
||||
raise
|
||||
|
||||
# Debug logging
|
||||
logger.debug(f"Processing {len(texts)} texts")
|
||||
logger.debug(
|
||||
f"Text lengths: {[len(t) for t in texts[:5]]}"
|
||||
) # Show first 5
|
||||
|
||||
# Process embeddings using unified computation
|
||||
embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
|
||||
logger.info(
|
||||
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
|
||||
)
|
||||
|
||||
# Prepare response based on request type
|
||||
if is_text_request:
|
||||
# For BaseSearcher compatibility: return msgpack format
|
||||
import msgpack
|
||||
|
||||
response_data = msgpack.packb(embeddings.tolist())
|
||||
else:
|
||||
# For DiskANN C++ compatibility: return protobuf format
|
||||
resp_proto = embedding_pb2.NodeEmbeddingResponse()
|
||||
hidden_contiguous = np.ascontiguousarray(
|
||||
embeddings, dtype=np.float32
|
||||
)
|
||||
|
||||
# Serialize embeddings data
|
||||
resp_proto.embeddings_data = hidden_contiguous.tobytes()
|
||||
resp_proto.dimensions.append(hidden_contiguous.shape[0])
|
||||
resp_proto.dimensions.append(hidden_contiguous.shape[1])
|
||||
|
||||
response_data = resp_proto.SerializeToString()
|
||||
|
||||
# Send response back to the client
|
||||
socket.send(response_data)
|
||||
|
||||
e2e_end = time.time()
|
||||
logger.info(f"⏱️ ZMQ E2E time: {e2e_end - e2e_start:.6f}s")
|
||||
|
||||
except zmq.Again:
|
||||
logger.debug("ZMQ socket timeout, continuing to listen")
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.error(f"Error in ZMQ server loop: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
raise
|
||||
|
||||
zmq_thread = threading.Thread(target=zmq_server_thread, daemon=True)
|
||||
zmq_thread.start()
|
||||
logger.info(f"Started DiskANN ZMQ server thread on port {zmq_port}")
|
||||
|
||||
# Keep the main thread alive
|
||||
try:
|
||||
while True:
|
||||
time.sleep(1)
|
||||
except KeyboardInterrupt:
|
||||
logger.info("DiskANN Server shutting down...")
|
||||
return
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import signal
|
||||
import sys
|
||||
|
||||
def signal_handler(sig, frame):
|
||||
logger.info(f"Received signal {sig}, shutting down gracefully...")
|
||||
sys.exit(0)
|
||||
|
||||
# Register signal handlers for graceful shutdown
|
||||
signal.signal(signal.SIGTERM, signal_handler)
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
|
||||
parser = argparse.ArgumentParser(description="DiskANN Embedding service")
|
||||
parser.add_argument("--zmq-port", type=int, default=5555, help="ZMQ port to run on")
|
||||
parser.add_argument(
|
||||
"--passages-file",
|
||||
type=str,
|
||||
help="Metadata JSON file containing passage sources",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model-name",
|
||||
type=str,
|
||||
default="sentence-transformers/all-mpnet-base-v2",
|
||||
help="Embedding model name",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--embedding-mode",
|
||||
type=str,
|
||||
default="sentence-transformers",
|
||||
choices=["sentence-transformers", "openai", "mlx"],
|
||||
help="Embedding backend mode",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Create and start the DiskANN embedding server
|
||||
create_diskann_embedding_server(
|
||||
passages_file=args.passages_file,
|
||||
zmq_port=args.zmq_port,
|
||||
model_name=args.model_name,
|
||||
embedding_mode=args.embedding_mode,
|
||||
)
|
||||
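A minimal client sketch for the REP server defined above, exercising the msgpack text path used by BaseSearcher-style callers; the port, sample texts, and standalone-script form are illustrative assumptions rather than code from the repository:

```python
# Hypothetical client for the DiskANN embedding server's msgpack text path.
# Assumes a server started as above is already listening on port 5555.
import msgpack
import numpy as np
import zmq

ctx = zmq.Context()
sock = ctx.socket(zmq.REQ)  # REQ pairs with the server's REP socket
sock.connect("tcp://localhost:5555")

# A plain msgpack list of strings is detected as a direct text request.
sock.send(msgpack.packb(["hello world", "another passage"]))

# The reply is msgpack-encoded embeddings (a list of float lists).
embeddings = np.array(msgpack.unpackb(sock.recv()), dtype=np.float32)
print(embeddings.shape)  # (2, embedding_dim)

sock.close()
ctx.term()
```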
@@ -1,741 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Embedding server for leann-backend-diskann - Fixed ZMQ REQ-REP pattern
|
||||
"""
|
||||
|
||||
import pickle
|
||||
import argparse
|
||||
import time
|
||||
import json
|
||||
from typing import Dict, Any, Optional, Union
|
||||
|
||||
from transformers import AutoTokenizer, AutoModel
|
||||
import os
|
||||
from contextlib import contextmanager
|
||||
import zmq
|
||||
import numpy as np
|
||||
import msgpack
|
||||
from pathlib import Path
|
||||
import logging
|
||||
|
||||
RED = "\033[91m"
|
||||
|
||||
# Set up logging based on environment variable
|
||||
LOG_LEVEL = os.getenv('LEANN_LOG_LEVEL', 'INFO').upper()
|
||||
logging.basicConfig(
|
||||
level=getattr(logging, LOG_LEVEL, logging.INFO),
|
||||
format='%(asctime)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
RESET = "\033[0m"
|
||||
|
||||
# --- New Passage Loader from HNSW backend ---
|
||||
class SimplePassageLoader:
|
||||
"""
|
||||
Simple passage loader that replaces config.py dependencies
|
||||
"""
|
||||
def __init__(self, passages_data: Optional[Dict[str, Any]] = None):
|
||||
self.passages_data = passages_data or {}
|
||||
self._meta_path = ''
|
||||
|
||||
def __getitem__(self, passage_id: Union[str, int]) -> Dict[str, str]:
|
||||
"""Get passage by ID"""
|
||||
str_id = str(passage_id)
|
||||
if str_id in self.passages_data:
|
||||
return {"text": self.passages_data[str_id]}
|
||||
else:
|
||||
# Return empty text for missing passages
|
||||
return {"text": ""}
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.passages_data)
|
||||
|
||||
def keys(self):
|
||||
return self.passages_data.keys()
|
||||
|
||||
def load_passages_from_metadata(meta_file: str) -> SimplePassageLoader:
|
||||
"""
|
||||
Load passages using metadata file with PassageManager for lazy loading
|
||||
"""
|
||||
# Load metadata to get passage sources
|
||||
with open(meta_file, 'r') as f:
|
||||
meta = json.load(f)
|
||||
|
||||
# Import PassageManager dynamically to avoid circular imports
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Find the leann package directory relative to this file
|
||||
current_dir = Path(__file__).parent
|
||||
leann_core_path = current_dir.parent.parent / "leann-core" / "src"
|
||||
sys.path.insert(0, str(leann_core_path))
|
||||
|
||||
try:
|
||||
from leann.api import PassageManager
|
||||
passage_manager = PassageManager(meta['passage_sources'])
|
||||
finally:
|
||||
sys.path.pop(0)
|
||||
|
||||
# Load label map
|
||||
passages_dir = Path(meta_file).parent
|
||||
label_map_file = passages_dir / "leann.labels.map"
|
||||
|
||||
if label_map_file.exists():
|
||||
import pickle
|
||||
with open(label_map_file, 'rb') as f:
|
||||
label_map = pickle.load(f)
|
||||
print(f"Loaded label map with {len(label_map)} entries")
|
||||
else:
|
||||
raise FileNotFoundError(f"Label map file not found: {label_map_file}")
|
||||
|
||||
print(f"Initialized lazy passage loading for {len(label_map)} passages")
|
||||
|
||||
class LazyPassageLoader(SimplePassageLoader):
|
||||
def __init__(self, passage_manager, label_map):
|
||||
self.passage_manager = passage_manager
|
||||
self.label_map = label_map
|
||||
# Initialize parent with empty data
|
||||
super().__init__({})
|
||||
|
||||
def __getitem__(self, passage_id: Union[str, int]) -> Dict[str, str]:
|
||||
"""Get passage by ID with lazy loading"""
|
||||
try:
|
||||
int_id = int(passage_id)
|
||||
if int_id in self.label_map:
|
||||
string_id = self.label_map[int_id]
|
||||
passage_data = self.passage_manager.get_passage(string_id)
|
||||
if passage_data and passage_data.get("text"):
|
||||
return {"text": passage_data["text"]}
|
||||
else:
|
||||
raise RuntimeError(f"FATAL: Empty text for ID {int_id} -> {string_id}")
|
||||
else:
|
||||
raise RuntimeError(f"FATAL: ID {int_id} not found in label_map")
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"FATAL: Exception getting passage {passage_id}: {e}")
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.label_map)
|
||||
|
||||
def keys(self):
|
||||
return self.label_map.keys()
|
||||
|
||||
loader = LazyPassageLoader(passage_manager, label_map)
|
||||
loader._meta_path = meta_file
|
||||
return loader
|
||||
|
||||
def load_passages_from_file(passages_file: str) -> SimplePassageLoader:
|
||||
"""
|
||||
Load passages from a JSONL file with label map support
|
||||
Expected format: {"id": "passage_id", "text": "passage_text", "metadata": {...}} (one per line)
|
||||
"""
|
||||
|
||||
if not os.path.exists(passages_file):
|
||||
raise FileNotFoundError(f"Passages file {passages_file} not found.")
|
||||
|
||||
if not passages_file.endswith('.jsonl'):
|
||||
raise ValueError(f"Expected .jsonl file format, got: {passages_file}")
|
||||
|
||||
# Load label map (int -> string_id)
|
||||
passages_dir = Path(passages_file).parent
|
||||
label_map_file = passages_dir / "leann.labels.map"
|
||||
|
||||
label_map = {}
|
||||
if label_map_file.exists():
|
||||
with open(label_map_file, 'rb') as f:
|
||||
label_map = pickle.load(f)
|
||||
print(f"Loaded label map with {len(label_map)} entries")
|
||||
else:
|
||||
raise FileNotFoundError(f"Label map file not found: {label_map_file}")
|
||||
|
||||
# Load passages by string ID
|
||||
string_id_passages = {}
|
||||
with open(passages_file, 'r', encoding='utf-8') as f:
|
||||
for line in f:
|
||||
if line.strip():
|
||||
passage = json.loads(line)
|
||||
string_id_passages[passage['id']] = passage['text']
|
||||
|
||||
# Create int ID -> text mapping using label map
|
||||
passages_data = {}
|
||||
for int_id, string_id in label_map.items():
|
||||
if string_id in string_id_passages:
|
||||
passages_data[str(int_id)] = string_id_passages[string_id]
|
||||
else:
|
||||
print(f"WARNING: String ID {string_id} from label map not found in passages")
|
||||
|
||||
print(f"Loaded {len(passages_data)} passages from JSONL file {passages_file} using label map")
|
||||
return SimplePassageLoader(passages_data)
|
||||
|
||||
def create_embedding_server_thread(
|
||||
zmq_port=5555,
|
||||
model_name="sentence-transformers/all-mpnet-base-v2",
|
||||
max_batch_size=128,
|
||||
passages_file: Optional[str] = None,
|
||||
embedding_mode: str = "sentence-transformers",
|
||||
enable_warmup: bool = False,
|
||||
):
|
||||
"""
|
||||
Create and run embedding server in the current thread
|
||||
This function is designed to be called in a separate thread
|
||||
"""
|
||||
logger.info(f"Initializing embedding server thread on port {zmq_port}")
|
||||
|
||||
try:
|
||||
# Check if port is already occupied
|
||||
import socket
|
||||
def check_port(port):
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
return s.connect_ex(('localhost', port)) == 0
|
||||
|
||||
if check_port(zmq_port):
|
||||
print(f"{RED}Port {zmq_port} is already in use{RESET}")
|
||||
return
|
||||
|
||||
# Auto-detect mode based on model name if not explicitly set
|
||||
if embedding_mode == "sentence-transformers" and model_name.startswith("text-embedding-"):
|
||||
embedding_mode = "openai"
|
||||
|
||||
if embedding_mode == "mlx":
|
||||
from leann.api import compute_embeddings_mlx
|
||||
import torch
|
||||
logger.info("Using MLX for embeddings")
|
||||
# Set device to CPU for compatibility with DeviceTimer class
|
||||
device = torch.device("cpu")
|
||||
cuda_available = False
|
||||
mps_available = False
|
||||
elif embedding_mode == "openai":
|
||||
from leann.api import compute_embeddings_openai
|
||||
import torch
|
||||
logger.info("Using OpenAI API for embeddings")
|
||||
# Set device to CPU for compatibility with DeviceTimer class
|
||||
device = torch.device("cpu")
|
||||
cuda_available = False
|
||||
mps_available = False
|
||||
elif embedding_mode == "sentence-transformers":
|
||||
# Initialize model
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
|
||||
import torch
|
||||
|
||||
# Select device
|
||||
mps_available = hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()
|
||||
cuda_available = torch.cuda.is_available()
|
||||
|
||||
if cuda_available:
|
||||
device = torch.device("cuda")
|
||||
logger.info("Using CUDA device")
|
||||
elif mps_available:
|
||||
device = torch.device("mps")
|
||||
logger.info("Using MPS device (Apple Silicon)")
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
logger.info("Using CPU device")
|
||||
|
||||
# Load model
|
||||
logger.info(f"Loading model {model_name}")
|
||||
model = AutoModel.from_pretrained(model_name).to(device).eval()
|
||||
|
||||
# Optimize model
|
||||
if cuda_available or mps_available:
|
||||
try:
|
||||
model = model.half()
|
||||
model = torch.compile(model)
|
||||
logger.info(f"Using FP16 precision with model: {model_name}")
|
||||
except Exception as e:
|
||||
print(f"WARNING: Model optimization failed: {e}")
|
||||
else:
|
||||
raise ValueError(f"Unsupported embedding mode: {embedding_mode}. Supported modes: sentence-transformers, mlx, openai")
|
||||
|
||||
# Load passages from file if provided
|
||||
if passages_file and os.path.exists(passages_file):
|
||||
# Check if it's a metadata file or a single passages file
|
||||
if passages_file.endswith('.meta.json'):
|
||||
passages = load_passages_from_metadata(passages_file)
|
||||
else:
|
||||
# Try to find metadata file in same directory
|
||||
passages_dir = Path(passages_file).parent
|
||||
meta_files = list(passages_dir.glob("*.meta.json"))
|
||||
if meta_files:
|
||||
print(f"Found metadata file: {meta_files[0]}, using lazy loading")
|
||||
passages = load_passages_from_metadata(str(meta_files[0]))
|
||||
else:
|
||||
# Fallback to original single file loading (will cause warnings)
|
||||
print("WARNING: No metadata file found, using single file loading (may cause missing passage warnings)")
|
||||
passages = load_passages_from_file(passages_file)
|
||||
else:
|
||||
print("WARNING: No passages file provided or file not found. Using an empty passage loader.")
|
||||
passages = SimplePassageLoader()
|
||||
|
||||
logger.info(f"Loaded {len(passages)} passages.")
|
||||
|
||||
def client_warmup(zmq_port):
|
||||
"""Perform client-side warmup for DiskANN server"""
|
||||
time.sleep(2)
|
||||
print(f"Performing client-side warmup with model {model_name}...")
|
||||
|
||||
# Get actual passage IDs from the loaded passages
|
||||
sample_ids = []
|
||||
if hasattr(passages, 'keys') and len(passages) > 0:
|
||||
available_ids = list(passages.keys())
|
||||
# Take up to 5 actual IDs, but at least 1
|
||||
sample_ids = available_ids[:min(5, len(available_ids))]
|
||||
print(f"Using actual passage IDs for warmup: {sample_ids}")
|
||||
else:
|
||||
print("No passages available for warmup, skipping warmup...")
|
||||
return
|
||||
|
||||
try:
|
||||
context = zmq.Context()
|
||||
socket = context.socket(zmq.REQ)
|
||||
socket.connect(f"tcp://localhost:{zmq_port}")
|
||||
socket.setsockopt(zmq.RCVTIMEO, 30000)
|
||||
socket.setsockopt(zmq.SNDTIMEO, 30000)
|
||||
|
||||
try:
|
||||
ids_to_send = [int(x) for x in sample_ids]
|
||||
except ValueError:
|
||||
print("Warning: Could not convert sample IDs to integers, skipping warmup")
|
||||
return
|
||||
|
||||
if not ids_to_send:
|
||||
print("Skipping warmup send.")
|
||||
return
|
||||
|
||||
# Use protobuf format for warmup
|
||||
from . import embedding_pb2
|
||||
req_proto = embedding_pb2.NodeEmbeddingRequest()
|
||||
req_proto.node_ids.extend(ids_to_send)
|
||||
request_bytes = req_proto.SerializeToString()
|
||||
|
||||
for i in range(3):
|
||||
print(f"Sending warmup request {i + 1}/3 via ZMQ (Protobuf)...")
|
||||
socket.send(request_bytes)
|
||||
response_bytes = socket.recv()
|
||||
|
||||
resp_proto = embedding_pb2.NodeEmbeddingResponse()
|
||||
resp_proto.ParseFromString(response_bytes)
|
||||
embeddings_count = resp_proto.dimensions[0] if resp_proto.dimensions else 0
|
||||
print(f"Warmup request {i + 1}/3 successful, received {embeddings_count} embeddings")
|
||||
time.sleep(0.1)
|
||||
|
||||
print("Client-side Protobuf ZMQ warmup complete")
|
||||
socket.close()
|
||||
context.term()
|
||||
except Exception as e:
|
||||
print(f"Error during Protobuf ZMQ warmup: {e}")
|
||||
|
||||
class DeviceTimer:
|
||||
"""Device timer"""
|
||||
def __init__(self, name="", device=device):
|
||||
self.name = name
|
||||
self.device = device
|
||||
self.start_time = 0
|
||||
self.end_time = 0
|
||||
|
||||
if embedding_mode == "sentence-transformers" and torch.cuda.is_available():
|
||||
self.start_event = torch.cuda.Event(enable_timing=True)
|
||||
self.end_event = torch.cuda.Event(enable_timing=True)
|
||||
else:
|
||||
self.start_event = None
|
||||
self.end_event = None
|
||||
|
||||
@contextmanager
|
||||
def timing(self):
|
||||
self.start()
|
||||
yield
|
||||
self.end()
|
||||
|
||||
def start(self):
|
||||
if embedding_mode == "sentence-transformers" and torch.cuda.is_available():
|
||||
torch.cuda.synchronize()
|
||||
self.start_event.record()
|
||||
else:
|
||||
if embedding_mode == "sentence-transformers" and self.device.type == "mps":
|
||||
torch.mps.synchronize()
|
||||
self.start_time = time.time()
|
||||
|
||||
def end(self):
|
||||
if embedding_mode == "sentence-transformers" and torch.cuda.is_available():
|
||||
self.end_event.record()
|
||||
torch.cuda.synchronize()
|
||||
else:
|
||||
if embedding_mode == "sentence-transformers" and self.device.type == "mps":
|
||||
torch.mps.synchronize()
|
||||
self.end_time = time.time()
|
||||
|
||||
def elapsed_time(self):
|
||||
if embedding_mode == "sentence-transformers" and torch.cuda.is_available():
|
||||
return self.start_event.elapsed_time(self.end_event) / 1000.0
|
||||
else:
|
||||
return self.end_time - self.start_time
|
||||
|
||||
def print_elapsed(self):
|
||||
elapsed = self.elapsed_time()
|
||||
print(f"[{self.name}] Elapsed time: {elapsed:.3f}s")
|
||||
|
||||
def process_batch_pytorch(texts_batch, ids_batch, missing_ids):
|
||||
"""Process text batch"""
|
||||
if not texts_batch:
|
||||
return np.array([])
|
||||
|
||||
# Filter out empty texts and their corresponding IDs
|
||||
valid_texts = []
|
||||
valid_ids = []
|
||||
for i, text in enumerate(texts_batch):
|
||||
if text.strip(): # Only include non-empty texts
|
||||
valid_texts.append(text)
|
||||
valid_ids.append(ids_batch[i])
|
||||
|
||||
if not valid_texts:
|
||||
print("WARNING: No valid texts in batch")
|
||||
return np.array([])
|
||||
|
||||
# Tokenize
|
||||
token_timer = DeviceTimer("tokenization")
|
||||
with token_timer.timing():
|
||||
inputs = tokenizer(
|
||||
valid_texts,
|
||||
padding=True,
|
||||
truncation=True,
|
||||
max_length=512,
|
||||
return_tensors="pt"
|
||||
).to(device)
|
||||
|
||||
# Compute embeddings
|
||||
embed_timer = DeviceTimer("embedding computation")
|
||||
with embed_timer.timing():
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs)
|
||||
hidden_states = outputs.last_hidden_state
|
||||
|
||||
# Mean pooling
|
||||
attention_mask = inputs['attention_mask']
|
||||
mask_expanded = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()
|
||||
sum_embeddings = torch.sum(hidden_states * mask_expanded, 1)
|
||||
sum_mask = torch.clamp(mask_expanded.sum(1), min=1e-9)
|
||||
batch_embeddings = sum_embeddings / sum_mask
|
||||
embed_timer.print_elapsed()
|
||||
|
||||
return batch_embeddings.cpu().numpy()
|
||||
|
||||
# ZMQ server main loop - modified to use REP socket
|
||||
context = zmq.Context()
|
||||
socket = context.socket(zmq.ROUTER) # Changed to REP socket
|
||||
socket.bind(f"tcp://127.0.0.1:{zmq_port}")
|
||||
print(f"INFO: ZMQ ROUTER server listening on port {zmq_port}")
|
||||
|
||||
# Set timeouts
|
||||
socket.setsockopt(zmq.RCVTIMEO, 5000) # 5 second receive timeout
|
||||
socket.setsockopt(zmq.SNDTIMEO, 300000) # 300 second send timeout
|
||||
|
||||
from . import embedding_pb2
|
||||
|
||||
print(f"INFO: Embedding server ready to serve requests")
|
||||
|
||||
# Start warmup thread if enabled
|
||||
if enable_warmup and len(passages) > 0:
|
||||
import threading
|
||||
print(f"Warmup enabled: starting warmup thread")
|
||||
warmup_thread = threading.Thread(target=client_warmup, args=(zmq_port,))
|
||||
warmup_thread.daemon = True
|
||||
warmup_thread.start()
|
||||
else:
|
||||
print(f"Warmup disabled or no passages available (enable_warmup={enable_warmup}, passages={len(passages)})")
|
||||
|
||||
while True:
|
||||
try:
|
||||
parts = socket.recv_multipart()
|
||||
|
||||
# --- Restore robust message format detection ---
|
||||
# Must check parts length to avoid IndexError
|
||||
if len(parts) >= 3:
|
||||
identity = parts[0]
|
||||
# empty = parts[1] # We usually don't care about the middle empty frame
|
||||
message = parts[2]
|
||||
elif len(parts) == 2:
|
||||
# Can also handle cases without empty frame
|
||||
identity = parts[0]
|
||||
message = parts[1]
|
||||
else:
|
||||
# If received message format is wrong, print warning and ignore it instead of crashing
|
||||
print(f"WARNING: Received unexpected message format with {len(parts)} parts. Ignoring.")
|
||||
continue
|
||||
print(f"INFO: Received ZMQ request from client {identity.hex()[:8]}, size {len(message)} bytes")
|
||||
|
||||
# Handle control messages (MessagePack format)
|
||||
try:
|
||||
request_payload = msgpack.unpackb(message)
|
||||
if isinstance(request_payload, list) and len(request_payload) >= 1:
|
||||
if request_payload[0] == "__QUERY_META_PATH__":
|
||||
# Return the current meta path being used by the server
|
||||
current_meta_path = getattr(passages, '_meta_path', '') if hasattr(passages, '_meta_path') else ''
|
||||
response = [current_meta_path]
|
||||
socket.send_multipart([identity, b'', msgpack.packb(response)])
|
||||
continue
|
||||
|
||||
elif request_payload[0] == "__UPDATE_META_PATH__" and len(request_payload) >= 2:
|
||||
# Update the server's meta path and reload passages
|
||||
new_meta_path = request_payload[1]
|
||||
try:
|
||||
print(f"INFO: Updating server meta path to: {new_meta_path}")
|
||||
# Reload passages from the new meta file
|
||||
passages = load_passages_from_metadata(new_meta_path)
|
||||
# Store the meta path for future queries
|
||||
passages._meta_path = new_meta_path
|
||||
response = ["SUCCESS"]
|
||||
print(f"INFO: Successfully updated meta path and reloaded {len(passages)} passages")
|
||||
except Exception as e:
|
||||
print(f"ERROR: Failed to update meta path: {e}")
|
||||
response = ["FAILED", str(e)]
|
||||
socket.send_multipart([identity, b'', msgpack.packb(response)])
|
||||
continue
|
||||
|
||||
elif request_payload[0] == "__QUERY_MODEL__":
|
||||
# Return the current model being used by the server
|
||||
response = [model_name]
|
||||
socket.send_multipart([identity, b'', msgpack.packb(response)])
|
||||
continue
|
||||
|
||||
elif request_payload[0] == "__UPDATE_MODEL__" and len(request_payload) >= 2:
|
||||
# Update the server's embedding model
|
||||
new_model_name = request_payload[1]
|
||||
try:
|
||||
print(f"INFO: Updating server model from {model_name} to: {new_model_name}")
|
||||
|
||||
# Clean up old model to free memory
|
||||
if not use_mlx:
|
||||
print("INFO: Releasing old model from memory...")
|
||||
old_model = model
|
||||
old_tokenizer = tokenizer
|
||||
|
||||
# Load new tokenizer first
|
||||
print(f"Loading new tokenizer for {new_model_name}...")
|
||||
tokenizer = AutoTokenizer.from_pretrained(new_model_name, use_fast=True)
|
||||
|
||||
# Load new model
|
||||
print(f"Loading new model {new_model_name}...")
|
||||
model = AutoModel.from_pretrained(new_model_name).to(device).eval()
|
||||
|
||||
# Optimize new model
|
||||
if cuda_available or mps_available:
|
||||
try:
|
||||
model = model.half()
|
||||
model = torch.compile(model)
|
||||
print(f"INFO: Using FP16 precision with model: {new_model_name}")
|
||||
except Exception as e:
|
||||
print(f"WARNING: Model optimization failed: {e}")
|
||||
|
||||
# Now safely delete old model after new one is loaded
|
||||
del old_model
|
||||
del old_tokenizer
|
||||
|
||||
# Clear GPU cache if available
|
||||
if device.type == "cuda":
|
||||
torch.cuda.empty_cache()
|
||||
print("INFO: Cleared CUDA cache")
|
||||
elif device.type == "mps":
|
||||
torch.mps.empty_cache()
|
||||
print("INFO: Cleared MPS cache")
|
||||
|
||||
# Force garbage collection
|
||||
import gc
|
||||
gc.collect()
|
||||
print("INFO: Memory cleanup completed")
|
||||
|
||||
# Update model name
|
||||
model_name = new_model_name
|
||||
|
||||
response = ["SUCCESS"]
|
||||
print(f"INFO: Successfully updated model to: {new_model_name}")
|
||||
except Exception as e:
|
||||
print(f"ERROR: Failed to update model: {e}")
|
||||
response = ["FAILED", str(e)]
|
||||
socket.send_multipart([identity, b'', msgpack.packb(response)])
|
||||
continue
|
||||
except:
|
||||
# Not a control message, continue with normal protobuf processing
|
||||
pass
|
||||
|
||||
e2e_start = time.time()
|
||||
lookup_timer = DeviceTimer("text lookup")
|
||||
|
||||
# Parse request
|
||||
req_proto = embedding_pb2.NodeEmbeddingRequest()
|
||||
req_proto.ParseFromString(message)
|
||||
node_ids = req_proto.node_ids
|
||||
print(f"INFO: Request for {len(node_ids)} node embeddings: {list(node_ids)}")
|
||||
|
||||
# Add debug information
|
||||
if len(node_ids) > 0:
|
||||
print(f"DEBUG: Node ID range: {min(node_ids)} to {max(node_ids)}")
|
||||
|
||||
# Look up texts
|
||||
texts = []
|
||||
missing_ids = []
|
||||
with lookup_timer.timing():
|
||||
for nid in node_ids:
|
||||
txtinfo = passages[nid]
|
||||
txt = txtinfo["text"]
|
||||
if txt:
|
||||
texts.append(txt)
|
||||
else:
|
||||
# If text is empty, we still need a placeholder for batch processing,
|
||||
# but record its ID as missing
|
||||
texts.append("")
|
||||
missing_ids.append(nid)
|
||||
lookup_timer.print_elapsed()
|
||||
|
||||
if missing_ids:
|
||||
print(f"WARNING: Missing passages for IDs: {missing_ids}")
|
||||
|
||||
# Process batch
|
||||
total_size = len(texts)
|
||||
print(f"INFO: Total batch size: {total_size}, max_batch_size: {max_batch_size}")
|
||||
|
||||
all_embeddings = []
|
||||
|
||||
if total_size > max_batch_size:
|
||||
print(f"INFO: Splitting batch of size {total_size} into chunks of {max_batch_size}")
|
||||
for i in range(0, total_size, max_batch_size):
|
||||
end_idx = min(i + max_batch_size, total_size)
|
||||
print(f"INFO: Processing chunk {i//max_batch_size + 1}/{(total_size + max_batch_size - 1)//max_batch_size}: items {i} to {end_idx-1}")
|
||||
|
||||
chunk_texts = texts[i:end_idx]
|
||||
chunk_ids = node_ids[i:end_idx]
|
||||
|
||||
if embedding_mode == "mlx":
|
||||
embeddings_chunk = compute_embeddings_mlx(chunk_texts, model_name, batch_size=16)
|
||||
elif embedding_mode == "openai":
|
||||
embeddings_chunk = compute_embeddings_openai(chunk_texts, model_name)
|
||||
else: # sentence-transformers
|
||||
embeddings_chunk = process_batch_pytorch(chunk_texts, chunk_ids, missing_ids)
|
||||
all_embeddings.append(embeddings_chunk)
|
||||
|
||||
if embedding_mode == "sentence-transformers":
|
||||
if cuda_available:
|
||||
torch.cuda.empty_cache()
|
||||
elif device.type == "mps":
|
||||
torch.mps.empty_cache()
|
||||
|
||||
hidden = np.vstack(all_embeddings)
|
||||
print(f"INFO: Combined embeddings shape: {hidden.shape}")
|
||||
else:
|
||||
if embedding_mode == "mlx":
|
||||
hidden = compute_embeddings_mlx(texts, model_name, batch_size=16)
|
||||
elif embedding_mode == "openai":
|
||||
hidden = compute_embeddings_openai(texts, model_name)
|
||||
else: # sentence-transformers
|
||||
hidden = process_batch_pytorch(texts, node_ids, missing_ids)
|
||||
|
||||
# Serialize response
|
||||
ser_start = time.time()
|
||||
|
||||
resp_proto = embedding_pb2.NodeEmbeddingResponse()
|
||||
hidden_contiguous = np.ascontiguousarray(hidden, dtype=np.float32)
|
||||
resp_proto.embeddings_data = hidden_contiguous.tobytes()
|
||||
resp_proto.dimensions.append(hidden_contiguous.shape[0])
|
||||
resp_proto.dimensions.append(hidden_contiguous.shape[1])
|
||||
resp_proto.missing_ids.extend(missing_ids)
|
||||
|
||||
response_data = resp_proto.SerializeToString()
|
||||
|
||||
# REP socket sends a single response
|
||||
socket.send_multipart([identity, b'', response_data])
|
||||
|
||||
ser_end = time.time()
|
||||
|
||||
print(f"INFO: Serialize time: {ser_end - ser_start:.6f} seconds")
|
||||
|
||||
if embedding_mode == "sentence-transformers":
|
||||
if device.type == "cuda":
|
||||
torch.cuda.synchronize()
|
||||
elif device.type == "mps":
|
||||
torch.mps.synchronize()
|
||||
e2e_end = time.time()
|
||||
print(f"INFO: ZMQ E2E time: {e2e_end - e2e_start:.6f} seconds")
|
||||
|
||||
except zmq.Again:
|
||||
print("INFO: ZMQ socket timeout, continuing to listen")
|
||||
continue
|
||||
except Exception as e:
|
||||
print(f"ERROR: Error in ZMQ server: {e}")
|
||||
try:
|
||||
# Send empty response to maintain REQ-REP state
|
||||
empty_resp = embedding_pb2.NodeEmbeddingResponse()
|
||||
socket.send(empty_resp.SerializeToString())
|
||||
except:
|
||||
# If sending fails, recreate socket
|
||||
socket.close()
|
||||
socket = context.socket(zmq.REP)
|
||||
socket.bind(f"tcp://127.0.0.1:{zmq_port}")
|
||||
socket.setsockopt(zmq.RCVTIMEO, 5000)
|
||||
socket.setsockopt(zmq.SNDTIMEO, 300000)
|
||||
print("INFO: ZMQ socket recreated after error")
|
||||
|
||||
except Exception as e:
|
||||
print(f"ERROR: Failed to start embedding server: {e}")
|
||||
raise
|
||||
|
||||
|
||||
def create_embedding_server(
|
||||
domain="demo",
|
||||
load_passages=True,
|
||||
load_embeddings=False,
|
||||
use_fp16=True,
|
||||
use_int8=False,
|
||||
use_cuda_graphs=False,
|
||||
zmq_port=5555,
|
||||
max_batch_size=128,
|
||||
lazy_load_passages=False,
|
||||
model_name="sentence-transformers/all-mpnet-base-v2",
|
||||
passages_file: Optional[str] = None,
|
||||
embedding_mode: str = "sentence-transformers",
|
||||
enable_warmup: bool = False,
|
||||
):
|
||||
"""
|
||||
The original create_embedding_server function is kept unchanged
|
||||
This is the blocking version, meant to be run directly
|
||||
"""
|
||||
create_embedding_server_thread(zmq_port, model_name, max_batch_size, passages_file, embedding_mode, enable_warmup)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Embedding service")
|
||||
parser.add_argument("--zmq-port", type=int, default=5555, help="ZMQ port to run on")
|
||||
parser.add_argument("--domain", type=str, default="demo", help="Domain name")
|
||||
parser.add_argument("--passages-file", type=str, help="JSON file containing passage ID to text mapping")
|
||||
parser.add_argument("--load-passages", action="store_true", default=True)
|
||||
parser.add_argument("--load-embeddings", action="store_true", default=False)
|
||||
parser.add_argument("--use-fp16", action="store_true", default=False)
|
||||
parser.add_argument("--use-int8", action="store_true", default=False)
|
||||
parser.add_argument("--use-cuda-graphs", action="store_true", default=False)
|
||||
parser.add_argument("--max-batch-size", type=int, default=128, help="Maximum batch size before splitting")
|
||||
parser.add_argument("--lazy-load-passages", action="store_true", default=True)
|
||||
parser.add_argument("--model-name", type=str, default="sentence-transformers/all-mpnet-base-v2",
|
||||
help="Embedding model name")
|
||||
parser.add_argument("--embedding-mode", type=str, default="sentence-transformers",
|
||||
choices=["sentence-transformers", "mlx", "openai"],
|
||||
help="Embedding backend mode")
|
||||
parser.add_argument("--use-mlx", action="store_true", default=False, help="Use MLX backend for embeddings (deprecated: use --embedding-mode mlx)")
|
||||
parser.add_argument("--disable-warmup", action="store_true", default=False, help="Disable warmup requests on server start")
|
||||
args = parser.parse_args()
|
||||
|
||||
# Handle backward compatibility with use_mlx
|
||||
embedding_mode = args.embedding_mode
|
||||
if args.use_mlx:
|
||||
embedding_mode = "mlx"
|
||||
|
||||
create_embedding_server(
|
||||
domain=args.domain,
|
||||
load_passages=args.load_passages,
|
||||
load_embeddings=args.load_embeddings,
|
||||
use_fp16=args.use_fp16,
|
||||
use_int8=args.use_int8,
|
||||
use_cuda_graphs=args.use_cuda_graphs,
|
||||
zmq_port=args.zmq_port,
|
||||
max_batch_size=args.max_batch_size,
|
||||
lazy_load_passages=args.lazy_load_passages,
|
||||
model_name=args.model_name,
|
||||
passages_file=args.passages_file,
|
||||
embedding_mode=embedding_mode,
|
||||
enable_warmup=not args.disable_warmup,
|
||||
)
|
||||
@@ -4,16 +4,21 @@ build-backend = "scikit_build_core.build"
|
||||
|
||||
[project]
|
||||
name = "leann-backend-diskann"
|
||||
version = "0.1.0"
|
||||
dependencies = ["leann-core==0.1.0", "numpy"]
|
||||
version = "0.1.8"
|
||||
dependencies = ["leann-core==0.1.8", "numpy"]
|
||||
|
||||
[tool.scikit-build]
|
||||
# Key: simplified CMake path
|
||||
# Key: simplified CMake path
|
||||
cmake.source-dir = "third_party/DiskANN"
|
||||
# Key: Python package in the root directory, paths match exactly
|
||||
# Key: Python package in root directory, paths match exactly
|
||||
wheel.packages = ["leann_backend_diskann"]
|
||||
# Use the default redirect mode
|
||||
# Use default redirect mode
|
||||
editable.mode = "redirect"
|
||||
cmake.build-type = "Release"
|
||||
build.verbose = true
|
||||
build.tool-args = ["-j8"]
|
||||
build.tool-args = ["-j8"]
|
||||
wheel.exclude = ["CMakeLists.txt", "src", "third_party/**", "*.o", "*.so"]
|
||||
sdist.include = ["CMakeLists.txt", "src", "third_party", "leann_backend_diskann/*.txt"]
|
||||
|
||||
[tool.scikit-build.cmake.define]
|
||||
CMAKE_BUILD_PARALLEL_LEVEL = "8"
|
||||
Submodule packages/leann-backend-diskann/third_party/DiskANN updated: af2a26481e...81666b7758
@@ -1,6 +1,7 @@
|
||||
# Final simplified version
|
||||
cmake_minimum_required(VERSION 3.24)
|
||||
project(leann_backend_hnsw_wrapper)
|
||||
set(CMAKE_C_COMPILER_WORKS 1)
|
||||
set(CMAKE_CXX_COMPILER_WORKS 1)
|
||||
|
||||
# Set OpenMP path for macOS
|
||||
if(APPLE)
|
||||
@@ -11,15 +12,9 @@ if(APPLE)
|
||||
set(OpenMP_omp_LIBRARY "/opt/homebrew/opt/libomp/lib/libomp.dylib")
|
||||
endif()
|
||||
|
||||
# Build ZeroMQ from source
|
||||
set(ZMQ_BUILD_TESTS OFF CACHE BOOL "" FORCE)
|
||||
set(ENABLE_DRAFTS OFF CACHE BOOL "" FORCE)
|
||||
set(ENABLE_PRECOMPILED OFF CACHE BOOL "" FORCE)
|
||||
set(WITH_PERF_TOOL OFF CACHE BOOL "" FORCE)
|
||||
set(WITH_DOCS OFF CACHE BOOL "" FORCE)
|
||||
set(BUILD_SHARED OFF CACHE BOOL "" FORCE)
|
||||
set(BUILD_STATIC ON CACHE BOOL "" FORCE)
|
||||
add_subdirectory(third_party/libzmq)
|
||||
# Use system ZeroMQ instead of building from source
|
||||
find_package(PkgConfig REQUIRED)
|
||||
pkg_check_modules(ZMQ REQUIRED libzmq)
|
||||
|
||||
# Add cppzmq headers
|
||||
include_directories(third_party/cppzmq)
|
||||
@@ -29,6 +24,39 @@ set(MSGPACK_USE_BOOST OFF CACHE BOOL "" FORCE)
|
||||
add_compile_definitions(MSGPACK_NO_BOOST)
|
||||
include_directories(third_party/msgpack-c/include)
|
||||
|
||||
# Find Python for our own use (not for Faiss)
|
||||
if(DEFINED SKBUILD)
|
||||
message(STATUS "Building with scikit-build")
|
||||
# scikit-build-core provides Python information
|
||||
endif()
|
||||
|
||||
# Find Python - scikit-build-core should provide this
|
||||
find_package(Python REQUIRED COMPONENTS Interpreter Development.Module NumPy)
|
||||
|
||||
# Print Python information for debugging
|
||||
message(STATUS "Python_FOUND: ${Python_FOUND}")
|
||||
message(STATUS "Python_VERSION: ${Python_VERSION}")
|
||||
message(STATUS "Python_EXECUTABLE: ${Python_EXECUTABLE}")
|
||||
message(STATUS "Python_INCLUDE_DIRS: ${Python_INCLUDE_DIRS}")
|
||||
message(STATUS "Python_NumPy_INCLUDE_DIRS: ${Python_NumPy_INCLUDE_DIRS}")
|
||||
|
||||
# Pass Python information to faiss through cache variables
|
||||
set(Python_EXECUTABLE ${Python_EXECUTABLE} CACHE FILEPATH "Python executable" FORCE)
|
||||
set(Python_INCLUDE_DIRS ${Python_INCLUDE_DIRS} CACHE PATH "Python include dirs" FORCE)
|
||||
set(Python_NumPy_INCLUDE_DIRS ${Python_NumPy_INCLUDE_DIRS} CACHE PATH "NumPy include dirs" FORCE)
|
||||
set(Python_VERSION ${Python_VERSION} CACHE STRING "Python version" FORCE)
|
||||
set(Python_FOUND ${Python_FOUND} CACHE BOOL "Python found" FORCE)
|
||||
|
||||
# Also set Python3 variables for compatibility
|
||||
set(Python3_EXECUTABLE ${Python_EXECUTABLE} CACHE FILEPATH "Python3 executable" FORCE)
|
||||
set(Python3_INCLUDE_DIRS ${Python_INCLUDE_DIRS} CACHE PATH "Python3 include dirs" FORCE)
|
||||
set(Python3_NumPy_INCLUDE_DIRS ${Python_NumPy_INCLUDE_DIRS} CACHE PATH "NumPy include dirs" FORCE)
|
||||
set(Python3_VERSION ${Python_VERSION} CACHE STRING "Python3 version" FORCE)
|
||||
set(Python3_FOUND ${Python_FOUND} CACHE BOOL "Python3 found" FORCE)
|
||||
set(Python3_Development_FOUND TRUE CACHE BOOL "Python3 development found" FORCE)
|
||||
set(Python3_NumPy_FOUND TRUE CACHE BOOL "Python3 NumPy found" FORCE)
|
||||
|
||||
# Faiss configuration - streamlined build
|
||||
set(FAISS_ENABLE_PYTHON ON CACHE BOOL "" FORCE)
|
||||
set(FAISS_ENABLE_GPU OFF CACHE BOOL "" FORCE)
|
||||
set(FAISS_ENABLE_EXTRAS OFF CACHE BOOL "" FORCE)
|
||||
@@ -36,4 +64,28 @@ set(BUILD_TESTING OFF CACHE BOOL "" FORCE)
|
||||
set(FAISS_ENABLE_C_API OFF CACHE BOOL "" FORCE)
|
||||
set(FAISS_OPT_LEVEL "generic" CACHE STRING "" FORCE)
|
||||
|
||||
# Disable additional SIMD versions to speed up compilation
|
||||
set(FAISS_ENABLE_AVX2 OFF CACHE BOOL "" FORCE)
|
||||
set(FAISS_ENABLE_AVX512 OFF CACHE BOOL "" FORCE)
|
||||
|
||||
# Additional optimization options from INSTALL.md
|
||||
set(CMAKE_BUILD_TYPE "Release" CACHE STRING "" FORCE)
|
||||
set(BUILD_SHARED_LIBS OFF CACHE BOOL "" FORCE) # Static library is faster to build
|
||||
|
||||
# Avoid building demos and benchmarks
|
||||
set(BUILD_DEMOS OFF CACHE BOOL "" FORCE)
|
||||
set(BUILD_BENCHS OFF CACHE BOOL "" FORCE)
|
||||
|
||||
# NEW: Tell Faiss to only build the generic version
|
||||
set(FAISS_BUILD_GENERIC ON CACHE BOOL "" FORCE)
|
||||
set(FAISS_BUILD_AVX2 OFF CACHE BOOL "" FORCE)
|
||||
set(FAISS_BUILD_AVX512 OFF CACHE BOOL "" FORCE)
|
||||
|
||||
# IMPORTANT: Disable building AVX versions to speed up compilation
|
||||
set(FAISS_BUILD_AVX_VERSIONS OFF CACHE BOOL "" FORCE)
|
||||
|
||||
# Force Faiss to use our Python settings
|
||||
set(Python_FIND_VIRTUALENV ONLY CACHE STRING "" FORCE)
|
||||
set(Python3_FIND_VIRTUALENV ONLY CACHE STRING "" FORCE)
|
||||
|
||||
add_subdirectory(third_party/faiss)
|
||||
@@ -1,10 +1,9 @@
|
||||
import numpy as np
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, List, Literal
|
||||
import pickle
|
||||
from typing import Dict, Any, List, Literal, Optional
|
||||
import shutil
|
||||
import time
|
||||
import logging
|
||||
|
||||
from leann.searcher_base import BaseSearcher
|
||||
from .convert_to_csr import convert_hnsw_graph_to_csr
|
||||
@@ -16,6 +15,8 @@ from leann.interface import (
|
||||
LeannBackendSearcherInterface,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_metric_map():
|
||||
from . import faiss # type: ignore
|
||||
@@ -57,13 +58,9 @@ class HNSWBuilder(LeannBackendBuilderInterface):
|
||||
index_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if data.dtype != np.float32:
|
||||
logger.warning(f"Converting data to float32, shape: {data.shape}")
|
||||
data = data.astype(np.float32)
|
||||
|
||||
label_map = {i: str_id for i, str_id in enumerate(ids)}
|
||||
label_map_file = index_dir / "leann.labels.map"
|
||||
with open(label_map_file, "wb") as f:
|
||||
pickle.dump(label_map, f)
|
||||
|
||||
metric_enum = get_metric_map().get(self.distance_metric.lower())
|
||||
if metric_enum is None:
|
||||
raise ValueError(f"Unsupported distance_metric '{self.distance_metric}'.")
|
||||
@@ -85,7 +82,7 @@ class HNSWBuilder(LeannBackendBuilderInterface):
|
||||
def _convert_to_csr(self, index_file: Path):
|
||||
"""Convert built index to CSR format"""
|
||||
mode_str = "CSR-pruned" if self.is_recompute else "CSR-standard"
|
||||
print(f"INFO: Converting HNSW index to {mode_str} format...")
|
||||
logger.info(f"INFO: Converting HNSW index to {mode_str} format...")
|
||||
|
||||
csr_temp_file = index_file.with_suffix(".csr.tmp")
|
||||
|
||||
@@ -94,11 +91,11 @@ class HNSWBuilder(LeannBackendBuilderInterface):
|
||||
)
|
||||
|
||||
if success:
|
||||
print("✅ CSR conversion successful.")
|
||||
index_file_old = index_file.with_suffix(".old")
|
||||
shutil.move(str(index_file), str(index_file_old))
|
||||
logger.info("✅ CSR conversion successful.")
|
||||
# index_file_old = index_file.with_suffix(".old")
|
||||
# shutil.move(str(index_file), str(index_file_old))
|
||||
shutil.move(str(csr_temp_file), str(index_file))
|
||||
print(
|
||||
logger.info(
|
||||
f"INFO: Replaced original index with {mode_str} version at '{index_file}'"
|
||||
)
|
||||
else:
|
||||
@@ -135,31 +132,22 @@ class HNSWSearcher(BaseSearcher):
|
||||
|
||||
hnsw_config = faiss.HNSWIndexConfig()
|
||||
hnsw_config.is_compact = self.is_compact
|
||||
hnsw_config.is_recompute = self.is_pruned or kwargs.get("is_recompute", False)
|
||||
|
||||
if self.is_pruned and not hnsw_config.is_recompute:
|
||||
raise RuntimeError("Index is pruned but recompute is disabled.")
|
||||
hnsw_config.is_recompute = (
|
||||
self.is_pruned
|
||||
) # In C++ code, it's called is_recompute, but it's only for loading IIUC.
|
||||
|
||||
self._index = faiss.read_index(str(index_file), faiss.IO_FLAG_MMAP, hnsw_config)
|
||||
|
||||
# Load label mapping
|
||||
label_map_file = self.index_dir / "leann.labels.map"
|
||||
if not label_map_file.exists():
|
||||
raise FileNotFoundError(f"Label map file not found at {label_map_file}")
|
||||
|
||||
with open(label_map_file, "rb") as f:
|
||||
self.label_map = pickle.load(f)
|
||||
|
||||
def search(
|
||||
self,
|
||||
query: np.ndarray,
|
||||
top_k: int,
|
||||
zmq_port: Optional[int] = None,
|
||||
complexity: int = 64,
|
||||
beam_width: int = 1,
|
||||
prune_ratio: float = 0.0,
|
||||
recompute_embeddings: bool = False,
|
||||
recompute_embeddings: bool = True,
|
||||
pruning_strategy: Literal["global", "local", "proportional"] = "global",
|
||||
zmq_port: int = 5557,
|
||||
batch_size: int = 0,
|
||||
**kwargs,
|
||||
) -> Dict[str, Any]:
|
||||
@@ -177,7 +165,7 @@ class HNSWSearcher(BaseSearcher):
|
||||
- "global": Use global PQ queue size for selection (default)
|
||||
- "local": Local pruning, sort and select best candidates
|
||||
- "proportional": Base selection on new neighbor count ratio
|
||||
zmq_port: ZMQ port for embedding server
|
||||
zmq_port: ZMQ port for embedding server communication. Must be provided if recompute_embeddings is True.
|
||||
batch_size: Neighbor processing batch size, 0=disabled (HNSW-specific)
|
||||
**kwargs: Additional HNSW-specific parameters (for legacy compatibility)
|
||||
|
||||
@@ -186,15 +174,14 @@ class HNSWSearcher(BaseSearcher):
|
||||
"""
|
||||
from . import faiss # type: ignore
|
||||
|
||||
# Use recompute_embeddings parameter
|
||||
use_recompute = recompute_embeddings or self.is_pruned
|
||||
if use_recompute:
|
||||
meta_file_path = self.index_dir / f"{self.index_path.name}.meta.json"
|
||||
if not meta_file_path.exists():
|
||||
raise RuntimeError(
|
||||
f"FATAL: Recompute enabled but metadata file not found: {meta_file_path}"
|
||||
if not recompute_embeddings:
|
||||
if self.is_pruned:
|
||||
raise RuntimeError("Recompute is required for pruned index.")
|
||||
if recompute_embeddings:
|
||||
if zmq_port is None:
|
||||
raise ValueError(
|
||||
"zmq_port must be provided if recompute_embeddings is True"
|
||||
)
|
||||
self._ensure_server_running(str(meta_file_path), port=zmq_port, **kwargs)
|
||||
|
||||
if query.dtype != np.float32:
|
||||
query = query.astype(np.float32)
|
||||
@@ -202,7 +189,10 @@ class HNSWSearcher(BaseSearcher):
|
||||
faiss.normalize_L2(query)
|
||||
|
||||
params = faiss.SearchParametersHNSW()
|
||||
params.zmq_port = zmq_port
|
||||
if zmq_port is not None:
|
||||
params.zmq_port = (
|
||||
zmq_port # C++ code won't use this if recompute_embeddings is False
|
||||
)
|
||||
params.efSearch = complexity
|
||||
params.beam_size = beam_width
|
||||
|
||||
@@ -239,11 +229,7 @@ class HNSWSearcher(BaseSearcher):
|
||||
)
|
||||
|
||||
string_labels = [
|
||||
[
|
||||
self.label_map.get(int_label, f"unknown_{int_label}")
|
||||
for int_label in batch_labels
|
||||
]
|
||||
for batch_labels in labels
|
||||
[str(int_label) for int_label in batch_labels] for batch_labels in labels
|
||||
]
|
||||
|
||||
return {"labels": string_labels, "distances": distances}
|
||||
|
||||
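For clarity, a hedged usage sketch of the updated HNSWSearcher.search() signature above; the searcher instance, query dimensionality, and port are illustrative assumptions:

```python
# Hypothetical call against an already-constructed HNSWSearcher instance.
import numpy as np

query = np.random.rand(1, 768).astype(np.float32)  # embedding dimension is illustrative

results = searcher.search(
    query,
    top_k=10,
    recompute_embeddings=True,   # required when the index is pruned/compact
    zmq_port=5557,               # must be set whenever recompute_embeddings=True
    complexity=64,
    beam_width=1,
    pruning_strategy="global",
)
print(results["labels"][0], results["distances"][0])
```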
File diff suppressed because it is too large
@@ -6,13 +6,24 @@ build-backend = "scikit_build_core.build"
|
||||
|
||||
[project]
|
||||
name = "leann-backend-hnsw"
|
||||
version = "0.1.0"
|
||||
version = "0.1.8"
|
||||
description = "Custom-built HNSW (Faiss) backend for the Leann toolkit."
|
||||
dependencies = ["leann-core==0.1.0", "numpy"]
|
||||
dependencies = [
|
||||
"leann-core==0.1.8",
|
||||
"numpy",
|
||||
"pyzmq>=23.0.0",
|
||||
"msgpack>=1.0.0",
|
||||
]
|
||||
|
||||
[tool.scikit-build]
|
||||
wheel.packages = ["leann_backend_hnsw"]
|
||||
editable.mode = "redirect"
|
||||
cmake.build-type = "Release"
|
||||
build.verbose = true
|
||||
build.tool-args = ["-j8"]
|
||||
wheel.exclude = ["CMakeLists.txt", "src", "third_party"]
|
||||
sdist.include = ["CMakeLists.txt", "src", "third_party", "leann_backend_hnsw/*.txt"]
|
||||
cmake.args = ["-DCMAKE_BUILD_TYPE=Release"]
|
||||
# Ensure CMake can find system libraries
|
||||
build-dir = "build/{cache_tag}"
|
||||
minimum-version = "build-system.requires"
|
||||
|
||||
# CMake definitions to optimize compilation
|
||||
[tool.scikit-build.cmake.define]
|
||||
CMAKE_BUILD_PARALLEL_LEVEL = "8"
|
||||
SKBUILD_SOABI = "YES"
|
||||
Submodule packages/leann-backend-hnsw/third_party/faiss updated: 2547df4377...553f937220
Submodule packages/leann-backend-hnsw/third_party/msgpack-c updated: 9b801f087a...a0b2ec09da
@@ -4,16 +4,27 @@ build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "leann-core"
|
||||
version = "0.1.0"
|
||||
description = "Core API and plugin system for Leann."
|
||||
version = "0.1.8"
|
||||
description = "Core API and plugin system for LEANN"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.9"
|
||||
license = { text = "MIT" }
|
||||
|
||||
# All required dependencies included
|
||||
dependencies = [
|
||||
"numpy>=1.20.0",
|
||||
"tqdm>=4.60.0"
|
||||
"tqdm>=4.60.0",
|
||||
"psutil>=5.8.0",
|
||||
"pyzmq>=23.0.0,<27", # Cap at 26.x for manylinux2014 compatibility
|
||||
"msgpack>=1.0.0",
|
||||
"torch>=2.0.0",
|
||||
"sentence-transformers>=2.2.0",
|
||||
"llama-index-core>=0.12.0",
|
||||
"python-dotenv>=1.0.0",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
leann = "leann.cli:main"
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
where = ["src"]
|
||||
@@ -5,16 +5,18 @@ with the correct, original embedding logic from the user's reference code.
|
||||
|
||||
import json
|
||||
import pickle
|
||||
from leann.interface import LeannBackendSearcherInterface
|
||||
import numpy as np
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Any, Optional, Literal
|
||||
from dataclasses import dataclass, field
|
||||
import uuid
|
||||
import torch
|
||||
|
||||
from .registry import BACKEND_REGISTRY
|
||||
from .interface import LeannBackendFactoryInterface
|
||||
from .chat import get_llm
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def compute_embeddings(
|
||||
@@ -22,7 +24,8 @@ def compute_embeddings(
|
||||
model_name: str,
|
||||
mode: str = "sentence-transformers",
|
||||
use_server: bool = True,
|
||||
use_mlx: bool = False # Backward compatibility: if True, override mode to 'mlx',
|
||||
port: Optional[int] = None,
|
||||
is_build=False,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Computes embeddings using different backends.
|
||||
@@ -39,251 +42,63 @@ def compute_embeddings(
|
||||
Returns:
|
||||
numpy array of embeddings
|
||||
"""
|
||||
# Override mode for backward compatibility
|
||||
if use_mlx:
|
||||
mode = "mlx"
|
||||
|
||||
# Auto-detect mode based on model name if not explicitly set
|
||||
if mode == "sentence-transformers" and model_name.startswith("text-embedding-"):
|
||||
mode = "openai"
|
||||
|
||||
if mode == "mlx":
|
||||
return compute_embeddings_mlx(chunks, model_name, batch_size=16)
|
||||
elif mode == "openai":
|
||||
return compute_embeddings_openai(chunks, model_name)
|
||||
elif mode == "sentence-transformers":
|
||||
return compute_embeddings_sentence_transformers(
|
||||
chunks, model_name, use_server=use_server
|
||||
)
|
||||
if use_server:
|
||||
# Use embedding server (for search/query)
|
||||
if port is None:
|
||||
raise ValueError("port is required when use_server is True")
|
||||
return compute_embeddings_via_server(chunks, model_name, port=port)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported embedding mode: {mode}. Supported modes: sentence-transformers, mlx, openai"
|
||||
# Use direct computation (for build_index)
|
||||
from .embedding_compute import (
|
||||
compute_embeddings as compute_embeddings_direct,
|
||||
)
|
||||
|
||||
return compute_embeddings_direct(
|
||||
chunks,
|
||||
model_name,
|
||||
mode=mode,
|
||||
is_build=is_build,
|
||||
)
|
||||
|
||||
|
||||
def compute_embeddings_sentence_transformers(
|
||||
chunks: List[str], model_name: str, use_server: bool = True
|
||||
def compute_embeddings_via_server(
|
||||
chunks: List[str], model_name: str, port: int
|
||||
) -> np.ndarray:
|
||||
"""Computes embeddings using sentence-transformers.
|
||||
|
||||
Args:
|
||||
chunks: List of text chunks to embed
|
||||
model_name: Name of the sentence transformer model
|
||||
use_server: If True, use embedding server (good for search). If False, use direct computation (good for build).
|
||||
"""
|
||||
if not use_server:
|
||||
print(
|
||||
f"INFO: Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}' (direct)..."
|
||||
)
|
||||
return _compute_embeddings_sentence_transformers_direct(chunks, model_name)
|
||||
|
||||
print(
|
||||
f"INFO: Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}' (via embedding server)..."
|
||||
logger.info(
|
||||
f"Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}' (via embedding server)..."
|
||||
)
|
||||
import zmq
|
||||
import msgpack
|
||||
import numpy as np
|
||||
|
||||
# Use embedding server for sentence-transformers too
|
||||
# This avoids loading the model twice (once in API, once in server)
|
||||
try:
|
||||
# Import ZMQ client functionality and server manager
|
||||
import zmq
|
||||
import msgpack
|
||||
import numpy as np
|
||||
from .embedding_server_manager import EmbeddingServerManager
|
||||
# Connect to embedding server
|
||||
context = zmq.Context()
|
||||
socket = context.socket(zmq.REQ)
|
||||
socket.connect(f"tcp://localhost:{port}")
|
||||
|
||||
# Ensure embedding server is running
|
||||
port = 5557
|
||||
server_manager = EmbeddingServerManager(
|
||||
backend_module_name="leann_backend_hnsw.hnsw_embedding_server"
|
||||
)
|
||||
# Send chunks to server for embedding computation
|
||||
request = chunks
|
||||
socket.send(msgpack.packb(request))
|
||||
|
||||
server_started = server_manager.start_server(
|
||||
port=port,
|
||||
model_name=model_name,
|
||||
embedding_mode="sentence-transformers",
|
||||
enable_warmup=False,
|
||||
)
|
||||
# Receive embeddings from server
|
||||
response = socket.recv()
|
||||
embeddings_list = msgpack.unpackb(response)
|
||||
|
||||
if not server_started:
|
||||
raise RuntimeError(f"Failed to start embedding server on port {port}")
|
||||
# Convert back to numpy array
|
||||
embeddings = np.array(embeddings_list, dtype=np.float32)
|
||||
|
||||
# Connect to embedding server
|
||||
context = zmq.Context()
|
||||
socket = context.socket(zmq.REQ)
|
||||
socket.connect(f"tcp://localhost:{port}")
|
||||
|
||||
# Send chunks to server for embedding computation
|
||||
request = chunks
|
||||
socket.send(msgpack.packb(request))
|
||||
|
||||
# Receive embeddings from server
|
||||
response = socket.recv()
|
||||
embeddings_list = msgpack.unpackb(response)
|
||||
|
||||
# Convert back to numpy array
|
||||
embeddings = np.array(embeddings_list, dtype=np.float32)
|
||||
|
||||
socket.close()
|
||||
context.term()
|
||||
|
||||
return embeddings
|
||||
|
||||
except Exception as e:
|
||||
# Fallback to direct sentence-transformers if server connection fails
|
||||
print(
|
||||
f"Warning: Failed to connect to embedding server, falling back to direct computation: {e}"
|
||||
)
|
||||
return _compute_embeddings_sentence_transformers_direct(chunks, model_name)
|
||||
|
||||
|
||||
def _compute_embeddings_sentence_transformers_direct(
|
||||
chunks: List[str], model_name: str
|
||||
) -> np.ndarray:
|
||||
"""Direct sentence-transformers computation (fallback)."""
|
||||
try:
|
||||
from sentence_transformers import SentenceTransformer
|
||||
except ImportError as e:
|
||||
raise RuntimeError(
|
||||
"sentence-transformers not available. Install with: uv pip install sentence-transformers"
|
||||
) from e
|
||||
|
||||
# Load model using sentence-transformers
|
||||
model = SentenceTransformer(model_name)
|
||||
|
||||
model = model.half()
|
||||
print(
|
||||
f"INFO: Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}' (direct)..."
|
||||
)
|
||||
# Use an accelerator GPU (CUDA) or the Mac GPU (MPS) if available
|
||||
|
||||
if torch.cuda.is_available():
|
||||
model = model.to("cuda")
|
||||
elif torch.backends.mps.is_available():
|
||||
model = model.to("mps")
|
||||
|
||||
# Generate embeddings
|
||||
# Warn the user: an OOM here means the batch size needs to be turned down
|
||||
embeddings = model.encode(
|
||||
chunks, convert_to_numpy=True, show_progress_bar=True, batch_size=16
|
||||
)
|
||||
socket.close()
|
||||
context.term()
|
||||
|
||||
return embeddings
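To make the refactored dispatch above concrete, a small usage sketch, assuming the compute_embeddings dispatcher is the one exported from leann.api (as the surrounding imports suggest); the import path, model name, port, and texts are illustrative:

```python
# Hypothetical calls through the refactored compute_embeddings dispatcher.
from leann.api import compute_embeddings  # assumed import path

chunks = ["first passage", "second passage"]
model = "sentence-transformers/all-mpnet-base-v2"

# Search path: route through an already-running embedding server; port is now required.
vecs = compute_embeddings(chunks, model, mode="sentence-transformers",
                          use_server=True, port=5557)

# Build path: compute in-process via leann.embedding_compute; no server involved.
vecs = compute_embeddings(chunks, model, mode="sentence-transformers",
                          use_server=False, is_build=True)

print(vecs.shape)  # (2, embedding_dim)
```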
|
||||
|
||||
|
||||
def compute_embeddings_openai(chunks: List[str], model_name: str) -> np.ndarray:
|
||||
"""Computes embeddings using OpenAI API."""
|
||||
try:
|
||||
import openai
|
||||
import os
|
||||
except ImportError as e:
|
||||
raise RuntimeError(
|
||||
"openai not available. Install with: uv pip install openai"
|
||||
) from e
|
||||
|
||||
# Get API key from environment
|
||||
api_key = os.getenv("OPENAI_API_KEY")
|
||||
if not api_key:
|
||||
raise RuntimeError("OPENAI_API_KEY environment variable not set")
|
||||
|
||||
client = openai.OpenAI(api_key=api_key)
|
||||
|
||||
print(
|
||||
f"INFO: Computing embeddings for {len(chunks)} chunks using OpenAI model '{model_name}'..."
|
||||
)
|
||||
|
||||
# OpenAI has a limit on batch size and input length
|
||||
max_batch_size = 100 # Conservative batch size
|
||||
all_embeddings = []
|
||||
|
||||
try:
|
||||
from tqdm import tqdm
|
||||
total_batches = (len(chunks) + max_batch_size - 1) // max_batch_size
|
||||
batch_range = range(0, len(chunks), max_batch_size)
|
||||
batch_iterator = tqdm(batch_range, desc="Computing embeddings", unit="batch", total=total_batches)
|
||||
except ImportError:
|
||||
# Fallback without progress bar
|
||||
batch_iterator = range(0, len(chunks), max_batch_size)
|
||||
|
||||
for i in batch_iterator:
|
||||
batch_chunks = chunks[i:i + max_batch_size]
|
||||
|
||||
try:
|
||||
response = client.embeddings.create(model=model_name, input=batch_chunks)
|
||||
batch_embeddings = [embedding.embedding for embedding in response.data]
|
||||
all_embeddings.extend(batch_embeddings)
|
||||
except Exception as e:
|
||||
print(f"ERROR: Failed to get embeddings for batch starting at {i}: {e}")
|
||||
raise
|
||||
|
||||
embeddings = np.array(all_embeddings, dtype=np.float32)
|
||||
print(
|
||||
f"INFO: Generated {len(embeddings)} embeddings with dimension {embeddings.shape[1]}"
|
||||
)
|
||||
return embeddings
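# --- Illustrative sketch (not part of the diff above): a retry-with-backoff wrapper
# around the per-batch OpenAI call, so a transient API error does not abort the whole
# run. client.embeddings.create(...) is the same call used in the loop above; the
# helper name, retry count, and sleep times are assumptions.
import time

def embed_batch_with_retry(client, model_name, batch_chunks, max_retries=3):
    for attempt in range(max_retries):
        try:
            response = client.embeddings.create(model=model_name, input=batch_chunks)
            return [item.embedding for item in response.data]
        except Exception as e:
            if attempt == max_retries - 1:
                raise
            wait = 2 ** attempt  # 1s, 2s, 4s, ...
            print(f"WARNING: batch failed ({e}), retrying in {wait}s...")
            time.sleep(wait)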
|
||||
|
||||
|
||||
def compute_embeddings_mlx(chunks: List[str], model_name: str, batch_size: int = 16) -> np.ndarray:
|
||||
"""Computes embeddings using an MLX model."""
|
||||
try:
|
||||
import mlx.core as mx
|
||||
from mlx_lm.utils import load
|
||||
from tqdm import tqdm
|
||||
except ImportError as e:
|
||||
raise RuntimeError(
|
||||
"MLX or related libraries not available. Install with: uv pip install mlx mlx-lm"
|
||||
) from e
|
||||
|
||||
print(
|
||||
f"INFO: Computing embeddings for {len(chunks)} chunks using MLX model '{model_name}' with batch_size={batch_size}..."
|
||||
)
|
||||
|
||||
# Load model and tokenizer
|
||||
model, tokenizer = load(model_name)
|
||||
|
||||
# Process chunks in batches with progress bar
|
||||
all_embeddings = []
|
||||
|
||||
try:
|
||||
from tqdm import tqdm
|
||||
batch_iterator = tqdm(range(0, len(chunks), batch_size), desc="Computing embeddings", unit="batch")
|
||||
except ImportError:
|
||||
batch_iterator = range(0, len(chunks), batch_size)
|
||||
|
||||
for i in batch_iterator:
|
||||
batch_chunks = chunks[i:i + batch_size]
|
||||
|
||||
# Tokenize all chunks in the batch
|
||||
batch_token_ids = []
|
||||
for chunk in batch_chunks:
|
||||
token_ids = tokenizer.encode(chunk) # type: ignore
|
||||
batch_token_ids.append(token_ids)
|
||||
|
||||
# Pad sequences to the same length for batch processing
|
||||
max_length = max(len(ids) for ids in batch_token_ids)
|
||||
padded_token_ids = []
|
||||
for token_ids in batch_token_ids:
|
||||
# Pad with tokenizer.pad_token_id or 0
|
||||
padded = token_ids + [0] * (max_length - len(token_ids))
|
||||
padded_token_ids.append(padded)
|
||||
|
||||
# Convert to MLX array with batch dimension
|
||||
input_ids = mx.array(padded_token_ids)
|
||||
|
||||
# Get embeddings for the batch
|
||||
embeddings = model(input_ids)
|
||||
|
||||
# Mean pooling for each sequence in the batch
|
||||
pooled = embeddings.mean(axis=1) # Shape: (batch_size, hidden_size)
|
||||
|
||||
# Convert batch embeddings to numpy
|
||||
for j in range(len(batch_chunks)):
|
||||
pooled_list = pooled[j].tolist() # Convert to list
|
||||
pooled_numpy = np.array(pooled_list, dtype=np.float32)
|
||||
all_embeddings.append(pooled_numpy)
|
||||
|
||||
# Stack numpy arrays
|
||||
return np.stack(all_embeddings)
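# --- Illustrative sketch (not part of the diff above): the mean pooling above also
# averages over the zero-padded positions. A length mask keeps only real tokens in
# the average. The helper name and shapes are assumptions; model(input_ids) is
# expected to return (batch, max_len, hidden) as in the loop above.
import mlx.core as mx
import numpy as np

def masked_mean_pool(model, padded_token_ids, lengths):
    input_ids = mx.array(padded_token_ids)                      # (batch, max_len)
    hidden = model(input_ids)                                   # (batch, max_len, dim)
    max_len = input_ids.shape[1]
    positions = mx.arange(max_len).reshape(1, max_len)
    mask = positions < mx.array(lengths).reshape(len(lengths), 1)
    mask = mx.expand_dims(mask.astype(hidden.dtype), axis=-1)   # (batch, max_len, 1)
    pooled = (hidden * mask).sum(axis=1) / mask.sum(axis=1)     # ignore pad tokens
    return np.array(pooled.tolist(), dtype=np.float32)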
|
||||
|
||||
|
||||
@dataclass
|
||||
class SearchResult:
|
||||
id: str
|
||||
@@ -299,25 +114,24 @@ class PassageManager:
|
||||
self.global_offset_map = {} # Combined map for fast lookup
|
||||
|
||||
for source in passage_sources:
|
||||
if source["type"] == "jsonl":
|
||||
passage_file = source["path"]
|
||||
index_file = source["index_path"]
|
||||
if not Path(index_file).exists():
|
||||
raise FileNotFoundError(
|
||||
f"Passage index file not found: {index_file}"
|
||||
)
|
||||
with open(index_file, "rb") as f:
|
||||
offset_map = pickle.load(f)
|
||||
self.offset_maps[passage_file] = offset_map
|
||||
self.passage_files[passage_file] = passage_file
|
||||
assert source["type"] == "jsonl", "only jsonl is supported"
|
||||
passage_file = source["path"]
|
||||
index_file = source["index_path"] # .idx file
|
||||
if not Path(index_file).exists():
|
||||
raise FileNotFoundError(f"Passage index file not found: {index_file}")
|
||||
with open(index_file, "rb") as f:
|
||||
offset_map = pickle.load(f)
|
||||
self.offset_maps[passage_file] = offset_map
|
||||
self.passage_files[passage_file] = passage_file
|
||||
|
||||
# Build global map for O(1) lookup
|
||||
for passage_id, offset in offset_map.items():
|
||||
self.global_offset_map[passage_id] = (passage_file, offset)
|
||||
# Build global map for O(1) lookup
|
||||
for passage_id, offset in offset_map.items():
|
||||
self.global_offset_map[passage_id] = (passage_file, offset)
|
||||
|
||||
def get_passage(self, passage_id: str) -> Dict[str, Any]:
|
||||
if passage_id in self.global_offset_map:
|
||||
passage_file, offset = self.global_offset_map[passage_id]
|
||||
# Lazy file opening - only open when needed
|
||||
with open(passage_file, "r", encoding="utf-8") as f:
|
||||
f.seek(offset)
|
||||
return json.loads(f.readline())
|
||||
@@ -328,7 +142,7 @@ class LeannBuilder:
|
||||
def __init__(
|
||||
self,
|
||||
backend_name: str,
|
||||
embedding_model: str = "facebook/contriever-msmarco",
|
||||
embedding_model: str = "facebook/contriever",
|
||||
dimensions: Optional[int] = None,
|
||||
embedding_mode: str = "sentence-transformers",
|
||||
**backend_kwargs,
|
||||
@@ -344,14 +158,12 @@ class LeannBuilder:
|
||||
self.dimensions = dimensions
|
||||
self.embedding_mode = embedding_mode
|
||||
self.backend_kwargs = backend_kwargs
|
||||
if 'mlx' in self.embedding_model:
|
||||
self.embedding_mode = "mlx"
|
||||
self.chunks: List[Dict[str, Any]] = []
|
||||
|
||||
def add_text(self, text: str, metadata: Optional[Dict[str, Any]] = None):
|
||||
if metadata is None:
|
||||
metadata = {}
|
||||
passage_id = metadata.get("id", str(uuid.uuid4()))
|
||||
passage_id = metadata.get("id", str(len(self.chunks)))
|
||||
chunk_data = {"id": passage_id, "text": text, "metadata": metadata}
|
||||
self.chunks.append(chunk_data)
|
||||
|
||||
@@ -377,10 +189,13 @@ class LeannBuilder:
|
||||
with open(passages_file, "w", encoding="utf-8") as f:
|
||||
try:
|
||||
from tqdm import tqdm
|
||||
chunk_iterator = tqdm(self.chunks, desc="Writing passages", unit="chunk")
|
||||
|
||||
chunk_iterator = tqdm(
|
||||
self.chunks, desc="Writing passages", unit="chunk"
|
||||
)
|
||||
except ImportError:
|
||||
chunk_iterator = self.chunks
|
||||
|
||||
|
||||
for chunk in chunk_iterator:
|
||||
offset = f.tell()
|
||||
json.dump(
|
||||
@@ -398,7 +213,11 @@ class LeannBuilder:
|
||||
pickle.dump(offset_map, f)
|
||||
texts_to_embed = [c["text"] for c in self.chunks]
|
||||
embeddings = compute_embeddings(
|
||||
texts_to_embed, self.embedding_model, self.embedding_mode, use_server=False
|
||||
texts_to_embed,
|
||||
self.embedding_model,
|
||||
self.embedding_mode,
|
||||
use_server=False,
|
||||
is_build=True,
|
||||
)
|
||||
string_ids = [chunk["id"] for chunk in self.chunks]
|
||||
current_backend_kwargs = {**self.backend_kwargs, "dimensions": self.dimensions}
|
||||
@@ -472,7 +291,7 @@ class LeannBuilder:
|
||||
f"Dimension mismatch: expected {self.dimensions}, got {embedding_dim}"
|
||||
)
|
||||
|
||||
print(
|
||||
logger.info(
|
||||
f"Building index from precomputed embeddings: {len(ids)} items, {embedding_dim} dimensions"
|
||||
)
|
||||
|
||||
@@ -480,7 +299,7 @@ class LeannBuilder:
|
||||
if len(self.chunks) != len(ids):
|
||||
# If no text chunks provided, create placeholder text entries
|
||||
if not self.chunks:
|
||||
print("No text chunks provided, creating placeholder entries...")
|
||||
logger.info("No text chunks provided, creating placeholder entries...")
|
||||
for id_val in ids:
|
||||
self.add_text(
|
||||
f"Document {id_val}",
|
||||
@@ -555,15 +374,19 @@ class LeannBuilder:
|
||||
with open(leann_meta_path, "w", encoding="utf-8") as f:
|
||||
json.dump(meta_data, f, indent=2)
|
||||
|
||||
print(f"Index built successfully from precomputed embeddings: {index_path}")
|
||||
logger.info(
|
||||
f"Index built successfully from precomputed embeddings: {index_path}"
|
||||
)
|
||||
|
||||
|
||||
class LeannSearcher:
|
||||
def __init__(self, index_path: str, enable_warmup: bool = False, **backend_kwargs):
|
||||
meta_path_str = f"{index_path}.meta.json"
|
||||
if not Path(meta_path_str).exists():
|
||||
raise FileNotFoundError(f"Leann metadata file not found at {meta_path_str}")
|
||||
with open(meta_path_str, "r", encoding="utf-8") as f:
|
||||
self.meta_path_str = f"{index_path}.meta.json"
|
||||
if not Path(self.meta_path_str).exists():
|
||||
raise FileNotFoundError(
|
||||
f"Leann metadata file not found at {self.meta_path_str}"
|
||||
)
|
||||
with open(self.meta_path_str, "r", encoding="utf-8") as f:
|
||||
self.meta_data = json.load(f)
|
||||
backend_name = self.meta_data["backend_name"]
|
||||
self.embedding_model = self.meta_data["embedding_model"]
|
||||
@@ -571,16 +394,15 @@ class LeannSearcher:
|
||||
self.embedding_mode = self.meta_data.get(
|
||||
"embedding_mode", "sentence-transformers"
|
||||
)
|
||||
# Backward compatibility with use_mlx
|
||||
if self.meta_data.get("use_mlx", False):
|
||||
self.embedding_mode = "mlx"
|
||||
self.passage_manager = PassageManager(self.meta_data.get("passage_sources", []))
|
||||
backend_factory = BACKEND_REGISTRY.get(backend_name)
|
||||
if backend_factory is None:
|
||||
raise ValueError(f"Backend '{backend_name}' not found.")
|
||||
final_kwargs = {**self.meta_data.get("backend_kwargs", {}), **backend_kwargs}
|
||||
final_kwargs["enable_warmup"] = enable_warmup
|
||||
self.backend_impl = backend_factory.searcher(index_path, **final_kwargs)
|
||||
self.backend_impl: LeannBackendSearcherInterface = backend_factory.searcher(
|
||||
index_path, **final_kwargs
|
||||
)
|
||||
|
||||
def search(
|
||||
self,
|
||||
@@ -589,26 +411,39 @@ class LeannSearcher:
|
||||
complexity: int = 64,
|
||||
beam_width: int = 1,
|
||||
prune_ratio: float = 0.0,
|
||||
recompute_embeddings: bool = False,
|
||||
recompute_embeddings: bool = True,
|
||||
pruning_strategy: Literal["global", "local", "proportional"] = "global",
|
||||
zmq_port: int = 5557,
|
||||
expected_zmq_port: int = 5557,
|
||||
**kwargs,
|
||||
) -> List[SearchResult]:
|
||||
print("🔍 DEBUG LeannSearcher.search() called:")
|
||||
print(f" Query: '{query}'")
|
||||
print(f" Top_k: {top_k}")
|
||||
print(f" Additional kwargs: {kwargs}")
|
||||
logger.info("🔍 LeannSearcher.search() called:")
|
||||
logger.info(f" Query: '{query}'")
|
||||
logger.info(f" Top_k: {top_k}")
|
||||
logger.info(f" Additional kwargs: {kwargs}")
|
||||
|
||||
# Use backend's compute_query_embedding method
|
||||
# This will automatically use embedding server if available and needed
|
||||
import time
|
||||
zmq_port = None
|
||||
|
||||
start_time = time.time()
|
||||
if recompute_embeddings:
|
||||
zmq_port = self.backend_impl._ensure_server_running(
|
||||
self.meta_path_str,
|
||||
port=expected_zmq_port,
|
||||
**kwargs,
|
||||
)
|
||||
del expected_zmq_port
|
||||
zmq_time = time.time() - start_time
|
||||
logger.info(f" Launching server time: {zmq_time} seconds")
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
query_embedding = self.backend_impl.compute_query_embedding(query, zmq_port)
|
||||
print(f" Generated embedding shape: {query_embedding.shape}")
|
||||
query_embedding = self.backend_impl.compute_query_embedding(
|
||||
query,
|
||||
use_server_if_available=recompute_embeddings,
|
||||
zmq_port=zmq_port,
|
||||
)
|
||||
# logger.info(f" Generated embedding shape: {query_embedding.shape}")
|
||||
embedding_time = time.time() - start_time
|
||||
print(f" Embedding time: {embedding_time} seconds")
|
||||
# logger.info(f" Embedding time: {embedding_time} seconds")
|
||||
|
||||
start_time = time.time()
|
||||
results = self.backend_impl.search(
|
||||
@@ -623,14 +458,14 @@ class LeannSearcher:
|
||||
**kwargs,
|
||||
)
|
||||
search_time = time.time() - start_time
|
||||
print(f" Search time: {search_time} seconds")
|
||||
print(
|
||||
# logger.info(f" Search time: {search_time} seconds")
|
||||
logger.info(
|
||||
f" Backend returned: labels={len(results.get('labels', [[]])[0])} results"
|
||||
)
|
||||
|
||||
enriched_results = []
|
||||
if "labels" in results and "distances" in results:
|
||||
print(f" Processing {len(results['labels'][0])} passage IDs:")
|
||||
logger.info(f" Processing {len(results['labels'][0])} passage IDs:")
|
||||
for i, (string_id, dist) in enumerate(
|
||||
zip(results["labels"][0], results["distances"][0])
|
||||
):
|
||||
@@ -644,15 +479,25 @@ class LeannSearcher:
|
||||
metadata=passage_data.get("metadata", {}),
|
||||
)
|
||||
)
|
||||
print(
|
||||
f" {i + 1}. passage_id='{string_id}' -> SUCCESS: {passage_data['text']}..."
|
||||
|
||||
# Color codes for better logging
|
||||
GREEN = "\033[92m"
|
||||
BLUE = "\033[94m"
|
||||
YELLOW = "\033[93m"
|
||||
RESET = "\033[0m"
|
||||
|
||||
# Truncate text for display (first 100 chars)
|
||||
display_text = passage_data['text'][:100]
|
||||
logger.info(
|
||||
f" {GREEN}✓{RESET} {BLUE}[{i + 1:2d}]{RESET} {YELLOW}ID:{RESET} '{string_id}' {YELLOW}Score:{RESET} {dist:.4f} {YELLOW}Text:{RESET} {display_text}"
|
||||
)
|
||||
except KeyError:
|
||||
print(
|
||||
f" {i + 1}. passage_id='{string_id}' -> ERROR: Passage not found in PassageManager!"
|
||||
RED = "\033[91m"
|
||||
logger.error(
|
||||
f" {RED}✗{RESET} [{i + 1:2d}] ID: '{string_id}' -> {RED}ERROR: Passage not found!{RESET}"
|
||||
)
|
||||
|
||||
print(f" Final enriched results: {len(enriched_results)} passages")
|
||||
logger.info(f" {GREEN}✓ Final enriched results: {len(enriched_results)} passages{RESET}")
|
||||
return enriched_results
|
||||
|
||||
|
||||
@@ -674,15 +519,15 @@ class LeannChat:
|
||||
complexity: int = 64,
|
||||
beam_width: int = 1,
|
||||
prune_ratio: float = 0.0,
|
||||
recompute_embeddings: bool = False,
|
||||
recompute_embeddings: bool = True,
|
||||
pruning_strategy: Literal["global", "local", "proportional"] = "global",
|
||||
zmq_port: int = 5557,
|
||||
llm_kwargs: Optional[Dict[str, Any]] = None,
|
||||
expected_zmq_port: int = 5557,
|
||||
**search_kwargs,
|
||||
):
|
||||
if llm_kwargs is None:
|
||||
llm_kwargs = {}
|
||||
|
||||
search_time = time.time()
|
||||
results = self.searcher.search(
|
||||
question,
|
||||
top_k=top_k,
|
||||
@@ -691,9 +536,11 @@ class LeannChat:
|
||||
prune_ratio=prune_ratio,
|
||||
recompute_embeddings=recompute_embeddings,
|
||||
pruning_strategy=pruning_strategy,
|
||||
zmq_port=zmq_port,
|
||||
expected_zmq_port=expected_zmq_port,
|
||||
**search_kwargs,
|
||||
)
|
||||
search_time = time.time() - search_time
|
||||
# logger.info(f" Search time: {search_time} seconds")
|
||||
context = "\n\n".join([r.text for r in results])
|
||||
prompt = (
|
||||
"Here is some retrieved context that might help answer your question:\n\n"
|
||||
|
||||
@@ -9,6 +9,7 @@ from typing import Dict, Any, Optional, List
|
||||
import logging
|
||||
import os
|
||||
import difflib
|
||||
import torch
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
@@ -28,6 +29,68 @@ def check_ollama_models() -> List[str]:
|
||||
return []
|
||||
|
||||
|
||||
def check_ollama_model_exists_remotely(model_name: str) -> tuple[bool, list[str]]:
|
||||
"""Check if a model exists in Ollama's remote library and return available tags
|
||||
|
||||
Returns:
|
||||
(model_exists, available_tags): bool and list of matching tags
|
||||
"""
|
||||
try:
|
||||
import requests
|
||||
import re
|
||||
|
||||
# Split model name and tag
|
||||
if ':' in model_name:
|
||||
base_model, requested_tag = model_name.split(':', 1)
|
||||
else:
|
||||
base_model, requested_tag = model_name, None
|
||||
|
||||
# First check if base model exists in library
|
||||
library_response = requests.get("https://ollama.com/library", timeout=8)
|
||||
if library_response.status_code != 200:
|
||||
return True, [] # Assume exists if can't check
|
||||
|
||||
# Extract model names from library page
|
||||
models_in_library = re.findall(r'href="/library/([^"]+)"', library_response.text)
|
||||
|
||||
if base_model not in models_in_library:
|
||||
return False, [] # Base model doesn't exist
|
||||
|
||||
# If base model exists, get available tags
|
||||
tags_response = requests.get(f"https://ollama.com/library/{base_model}/tags", timeout=8)
|
||||
if tags_response.status_code != 200:
|
||||
return True, [] # Base model exists but can't get tags
|
||||
|
||||
# Extract tags for this model - be more specific to avoid HTML artifacts
|
||||
tag_pattern = rf'{re.escape(base_model)}:[a-zA-Z0-9\.\-_]+'
|
||||
raw_tags = re.findall(tag_pattern, tags_response.text)
|
||||
|
||||
# Clean up tags - remove HTML artifacts and duplicates
|
||||
available_tags = []
|
||||
seen = set()
|
||||
for tag in raw_tags:
|
||||
# Skip if it looks like HTML (contains < or >)
|
||||
if '<' in tag or '>' in tag:
|
||||
continue
|
||||
if tag not in seen:
|
||||
seen.add(tag)
|
||||
available_tags.append(tag)
|
||||
|
||||
# Check if exact model exists
|
||||
if requested_tag is None:
|
||||
# User just requested base model, suggest tags
|
||||
return True, available_tags[:10] # Return up to 10 tags
|
||||
else:
|
||||
exact_match = model_name in available_tags
|
||||
return exact_match, available_tags[:10]
|
||||
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# If scraping fails, assume model might exist (don't block user)
|
||||
return True, []
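# --- Illustrative usage sketch (not part of the diff above): interpreting the
# (exists, tags) tuple returned by check_ollama_model_exists_remotely.
# The model name is only an example.
exists, tags = check_ollama_model_exists_remotely("llama3:70b")
if exists:
    print("Model (or tag) is available in the Ollama library")
elif tags:
    print("Tag not found; available variants include:", ", ".join(tags[:3]))
else:
    print("Model not found in the Ollama library")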
|
||||
|
||||
|
||||
def search_ollama_models_fuzzy(query: str, available_models: List[str]) -> List[str]:
|
||||
"""Use intelligent fuzzy search for Ollama models"""
|
||||
if not available_models:
|
||||
@@ -243,24 +306,66 @@ def validate_model_and_suggest(model_name: str, llm_type: str) -> Optional[str]:
|
||||
if llm_type == "ollama":
|
||||
available_models = check_ollama_models()
|
||||
if available_models and model_name not in available_models:
|
||||
# Use intelligent fuzzy search based on locally installed models
|
||||
suggestions = search_ollama_models_fuzzy(model_name, available_models)
|
||||
|
||||
error_msg = f"Model '{model_name}' not found in your local Ollama installation."
|
||||
if suggestions:
|
||||
error_msg += "\n\nDid you mean one of these installed models?\n"
|
||||
for i, suggestion in enumerate(suggestions, 1):
|
||||
error_msg += f" {i}. {suggestion}\n"
|
||||
else:
|
||||
error_msg += "\n\nYour installed models:\n"
|
||||
for i, model in enumerate(available_models[:8], 1):
|
||||
error_msg += f" {i}. {model}\n"
|
||||
if len(available_models) > 8:
|
||||
error_msg += f" ... and {len(available_models) - 8} more\n"
|
||||
|
||||
error_msg += "\nTo list all models: ollama list"
|
||||
error_msg += "\nTo download a new model: ollama pull <model_name>"
|
||||
error_msg += "\nBrowse models: https://ollama.com/library"
|
||||
# Check if the model exists remotely and get available tags
|
||||
model_exists_remotely, available_tags = check_ollama_model_exists_remotely(model_name)
|
||||
|
||||
if model_exists_remotely and model_name in available_tags:
|
||||
# Exact model exists remotely - suggest pulling it
|
||||
error_msg += f"\n\nTo install the requested model:\n"
|
||||
error_msg += f" ollama pull {model_name}\n"
|
||||
|
||||
# Show local alternatives
|
||||
suggestions = search_ollama_models_fuzzy(model_name, available_models)
|
||||
if suggestions:
|
||||
error_msg += "\nOr use one of these similar installed models:\n"
|
||||
for i, suggestion in enumerate(suggestions, 1):
|
||||
error_msg += f" {i}. {suggestion}\n"
|
||||
|
||||
elif model_exists_remotely and available_tags:
|
||||
# Base model exists but requested tag doesn't - suggest correct tags
|
||||
base_model = model_name.split(':')[0]
|
||||
requested_tag = model_name.split(':', 1)[1] if ':' in model_name else None
|
||||
|
||||
error_msg += f"\n\nModel '{base_model}' exists, but tag '{requested_tag}' is not available."
|
||||
error_msg += f"\n\nAvailable {base_model} models you can install:\n"
|
||||
for i, tag in enumerate(available_tags[:8], 1):
|
||||
error_msg += f" {i}. ollama pull {tag}\n"
|
||||
if len(available_tags) > 8:
|
||||
error_msg += f" ... and {len(available_tags) - 8} more variants\n"
|
||||
|
||||
# Also show local alternatives
|
||||
suggestions = search_ollama_models_fuzzy(model_name, available_models)
|
||||
if suggestions:
|
||||
error_msg += "\nOr use one of these similar installed models:\n"
|
||||
for i, suggestion in enumerate(suggestions, 1):
|
||||
error_msg += f" {i}. {suggestion}\n"
|
||||
|
||||
else:
|
||||
# Model doesn't exist remotely - show fuzzy suggestions
|
||||
suggestions = search_ollama_models_fuzzy(model_name, available_models)
|
||||
error_msg += f"\n\nModel '{model_name}' was not found in Ollama's library."
|
||||
|
||||
if suggestions:
|
||||
error_msg += "\n\nDid you mean one of these installed models?\n"
|
||||
for i, suggestion in enumerate(suggestions, 1):
|
||||
error_msg += f" {i}. {suggestion}\n"
|
||||
else:
|
||||
error_msg += "\n\nYour installed models:\n"
|
||||
for i, model in enumerate(available_models[:8], 1):
|
||||
error_msg += f" {i}. {model}\n"
|
||||
if len(available_models) > 8:
|
||||
error_msg += f" ... and {len(available_models) - 8} more\n"
|
||||
|
||||
error_msg += "\n\nCommands:"
|
||||
error_msg += "\n ollama list # List installed models"
|
||||
if model_exists_remotely and available_tags:
|
||||
if model_name in available_tags:
|
||||
error_msg += f"\n ollama pull {model_name} # Install requested model"
|
||||
else:
|
||||
error_msg += f"\n ollama pull {available_tags[0]} # Install recommended variant"
|
||||
error_msg += "\n https://ollama.com/library # Browse available models"
|
||||
return error_msg
|
||||
|
||||
elif llm_type == "hf":
|
||||
@@ -375,8 +480,9 @@ class OllamaChat(LLMInterface):
|
||||
"stream": False, # Keep it simple for now
|
||||
"options": kwargs,
|
||||
}
|
||||
logger.info(f"Sending request to Ollama: {payload}")
|
||||
logger.debug(f"Sending request to Ollama: {payload}")
|
||||
try:
|
||||
logger.info(f"Sending request to Ollama and waiting for response...")
|
||||
response = requests.post(full_url, data=json.dumps(payload))
|
||||
response.raise_for_status()
|
||||
|
||||
@@ -396,7 +502,7 @@ class OllamaChat(LLMInterface):
|
||||
|
||||
|
||||
class HFChat(LLMInterface):
|
||||
"""LLM interface for local Hugging Face Transformers models."""
|
||||
"""LLM interface for local Hugging Face Transformers models with proper chat templates."""
|
||||
|
||||
def __init__(self, model_name: str = "deepseek-ai/deepseek-llm-7b-chat"):
|
||||
logger.info(f"Initializing HFChat with model='{model_name}'")
|
||||
@@ -407,7 +513,7 @@ class HFChat(LLMInterface):
|
||||
raise ValueError(model_error)
|
||||
|
||||
try:
|
||||
from transformers.pipelines import pipeline
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
import torch
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
@@ -416,54 +522,101 @@ class HFChat(LLMInterface):
|
||||
|
||||
# Auto-detect device
|
||||
if torch.cuda.is_available():
|
||||
device = "cuda"
|
||||
self.device = "cuda"
|
||||
logger.info("CUDA is available. Using GPU.")
|
||||
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
|
||||
device = "mps"
|
||||
self.device = "mps"
|
||||
logger.info("MPS is available. Using Apple Silicon GPU.")
|
||||
else:
|
||||
device = "cpu"
|
||||
self.device = "cpu"
|
||||
logger.info("No GPU detected. Using CPU.")
|
||||
|
||||
self.pipeline = pipeline("text-generation", model=model_name, device=device)
|
||||
# Load tokenizer and model
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
self.model = AutoModelForCausalLM.from_pretrained(
|
||||
model_name,
|
||||
torch_dtype=torch.float16 if self.device != "cpu" else torch.float32,
|
||||
device_map="auto" if self.device != "cpu" else None,
|
||||
trust_remote_code=True
|
||||
)
|
||||
|
||||
# Move model to device if not using device_map
|
||||
if self.device != "cpu" and "device_map" not in str(self.model):
|
||||
self.model = self.model.to(self.device)
|
||||
|
||||
# Set pad token if not present
|
||||
if self.tokenizer.pad_token is None:
|
||||
self.tokenizer.pad_token = self.tokenizer.eos_token
|
||||
|
||||
def ask(self, prompt: str, **kwargs) -> str:
|
||||
# Map OpenAI-style arguments to Hugging Face equivalents
|
||||
if "max_tokens" in kwargs:
|
||||
# Prefer user-provided max_new_tokens if both are present
|
||||
kwargs.setdefault("max_new_tokens", kwargs["max_tokens"])
|
||||
# Remove the unsupported key to avoid errors in Transformers
|
||||
kwargs.pop("max_tokens")
|
||||
print('kwargs in HF: ', kwargs)
|
||||
# Check if this is a Qwen model and add /no_think by default
|
||||
is_qwen_model = "qwen" in self.model.config._name_or_path.lower()
|
||||
|
||||
# For Qwen models, automatically add /no_think to the prompt
|
||||
if is_qwen_model and "/no_think" not in prompt and "/think" not in prompt:
|
||||
prompt = prompt + " /no_think"
|
||||
|
||||
# Prepare chat template
|
||||
messages = [{"role": "user", "content": prompt}]
|
||||
|
||||
# Apply chat template if available
|
||||
if hasattr(self.tokenizer, "apply_chat_template"):
|
||||
try:
|
||||
formatted_prompt = self.tokenizer.apply_chat_template(
|
||||
messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Chat template failed, using raw prompt: {e}")
|
||||
formatted_prompt = prompt
|
||||
else:
|
||||
# Fallback for models without chat template
|
||||
formatted_prompt = prompt
|
||||
|
||||
# Handle temperature=0 edge-case for greedy decoding
|
||||
if "temperature" in kwargs and kwargs["temperature"] == 0.0:
|
||||
# Remove unsupported zero temperature and use deterministic generation
|
||||
kwargs.pop("temperature")
|
||||
kwargs.setdefault("do_sample", False)
|
||||
# Tokenize input
|
||||
inputs = self.tokenizer(
|
||||
formatted_prompt,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
truncation=True,
|
||||
max_length=2048
|
||||
)
|
||||
|
||||
# Move inputs to device
|
||||
if self.device != "cpu":
|
||||
inputs = {k: v.to(self.device) for k, v in inputs.items()}
|
||||
|
||||
# Sensible defaults for text generation
|
||||
params = {"max_length": 500, "num_return_sequences": 1, **kwargs}
|
||||
logger.info(f"Generating text with Hugging Face model with params: {params}")
|
||||
results = self.pipeline(prompt, **params)
|
||||
# Set generation parameters
|
||||
generation_config = {
|
||||
"max_new_tokens": kwargs.get("max_tokens", kwargs.get("max_new_tokens", 512)),
|
||||
"temperature": kwargs.get("temperature", 0.7),
|
||||
"top_p": kwargs.get("top_p", 0.9),
|
||||
"do_sample": kwargs.get("temperature", 0.7) > 0,
|
||||
"pad_token_id": self.tokenizer.eos_token_id,
|
||||
"eos_token_id": self.tokenizer.eos_token_id,
|
||||
}
|
||||
|
||||
# Handle temperature=0 for greedy decoding
|
||||
if generation_config["temperature"] == 0.0:
|
||||
generation_config["do_sample"] = False
|
||||
generation_config.pop("temperature")
|
||||
|
||||
# Handle different response formats from transformers
|
||||
if isinstance(results, list) and len(results) > 0:
|
||||
generated_text = (
|
||||
results[0].get("generated_text", "")
|
||||
if isinstance(results[0], dict)
|
||||
else str(results[0])
|
||||
logger.info(f"Generating with HuggingFace model, config: {generation_config}")
|
||||
|
||||
# Generate
|
||||
with torch.no_grad():
|
||||
outputs = self.model.generate(
|
||||
**inputs,
|
||||
**generation_config
|
||||
)
|
||||
else:
|
||||
generated_text = str(results)
|
||||
|
||||
# Extract only the newly generated portion by removing the original prompt
|
||||
if isinstance(generated_text, str) and generated_text.startswith(prompt):
|
||||
response = generated_text[len(prompt) :].strip()
|
||||
else:
|
||||
# Fallback: return the full response if prompt removal fails
|
||||
response = str(generated_text)
|
||||
|
||||
return response
|
||||
# Decode response
|
||||
generated_tokens = outputs[0][inputs["input_ids"].shape[1]:]
|
||||
response = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
|
||||
|
||||
return response.strip()
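# --- Illustrative usage sketch (not part of the class above). The model name and
# generation arguments are only examples; max_tokens is mapped to max_new_tokens
# and temperature=0.0 triggers greedy decoding, as handled in ask() above.
def _demo_hf_chat():
    chat = HFChat(model_name="Qwen/Qwen2.5-1.5B-Instruct")
    answer = chat.ask(
        "Summarize the retrieved context in two sentences.",
        max_tokens=128,
        temperature=0.0,
    )
    print(answer)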
|
||||
|
||||
|
||||
class OpenAIChat(LLMInterface):
|
||||
|
||||
315
packages/leann-core/src/leann/cli.py
Normal file
@@ -0,0 +1,315 @@
|
||||
import argparse
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
|
||||
from llama_index.core import SimpleDirectoryReader
|
||||
from llama_index.core.node_parser import SentenceSplitter
|
||||
|
||||
from .api import LeannBuilder, LeannSearcher, LeannChat
|
||||
|
||||
|
||||
class LeannCLI:
|
||||
def __init__(self):
|
||||
self.indexes_dir = Path.home() / ".leann" / "indexes"
|
||||
self.indexes_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self.node_parser = SentenceSplitter(
|
||||
chunk_size=256, chunk_overlap=128, separator=" ", paragraph_separator="\n\n"
|
||||
)
|
||||
|
||||
def get_index_path(self, index_name: str) -> str:
|
||||
index_dir = self.indexes_dir / index_name
|
||||
return str(index_dir / "documents.leann")
|
||||
|
||||
def index_exists(self, index_name: str) -> bool:
|
||||
index_dir = self.indexes_dir / index_name
|
||||
meta_file = index_dir / "documents.leann.meta.json"
|
||||
return meta_file.exists()
|
||||
|
||||
def create_parser(self) -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser(
|
||||
prog="leann",
|
||||
description="LEANN - Local Enhanced AI Navigation",
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
epilog="""
|
||||
Examples:
|
||||
leann build my-docs --docs ./documents # Build index named my-docs
|
||||
leann search my-docs "query" # Search in my-docs index
|
||||
leann ask my-docs "question" # Ask my-docs index
|
||||
leann list # List all stored indexes
|
||||
""",
|
||||
)
|
||||
|
||||
subparsers = parser.add_subparsers(dest="command", help="Available commands")
|
||||
|
||||
# Build command
|
||||
build_parser = subparsers.add_parser("build", help="Build document index")
|
||||
build_parser.add_argument("index_name", help="Index name")
|
||||
build_parser.add_argument(
|
||||
"--docs", type=str, required=True, help="Documents directory"
|
||||
)
|
||||
build_parser.add_argument(
|
||||
"--backend", type=str, default="hnsw", choices=["hnsw", "diskann"]
|
||||
)
|
||||
build_parser.add_argument(
|
||||
"--embedding-model", type=str, default="facebook/contriever"
|
||||
)
|
||||
build_parser.add_argument(
|
||||
"--force", "-f", action="store_true", help="Force rebuild"
|
||||
)
|
||||
build_parser.add_argument("--graph-degree", type=int, default=32)
|
||||
build_parser.add_argument("--complexity", type=int, default=64)
|
||||
build_parser.add_argument("--num-threads", type=int, default=1)
|
||||
build_parser.add_argument("--compact", action="store_true", default=True)
|
||||
build_parser.add_argument("--recompute", action="store_true", default=True)
|
||||
|
||||
# Search command
|
||||
search_parser = subparsers.add_parser("search", help="Search documents")
|
||||
search_parser.add_argument("index_name", help="Index name")
|
||||
search_parser.add_argument("query", help="Search query")
|
||||
search_parser.add_argument("--top-k", type=int, default=5)
|
||||
search_parser.add_argument("--complexity", type=int, default=64)
|
||||
search_parser.add_argument("--beam-width", type=int, default=1)
|
||||
search_parser.add_argument("--prune-ratio", type=float, default=0.0)
|
||||
search_parser.add_argument("--recompute-embeddings", action="store_true")
|
||||
search_parser.add_argument(
|
||||
"--pruning-strategy",
|
||||
choices=["global", "local", "proportional"],
|
||||
default="global",
|
||||
)
|
||||
|
||||
# Ask command
|
||||
ask_parser = subparsers.add_parser("ask", help="Ask questions")
|
||||
ask_parser.add_argument("index_name", help="Index name")
|
||||
ask_parser.add_argument(
|
||||
"--llm",
|
||||
type=str,
|
||||
default="ollama",
|
||||
choices=["simulated", "ollama", "hf", "openai"],
|
||||
)
|
||||
ask_parser.add_argument("--model", type=str, default="qwen3:8b")
|
||||
ask_parser.add_argument("--host", type=str, default="http://localhost:11434")
|
||||
ask_parser.add_argument("--interactive", "-i", action="store_true")
|
||||
ask_parser.add_argument("--top-k", type=int, default=20)
|
||||
ask_parser.add_argument("--complexity", type=int, default=32)
|
||||
ask_parser.add_argument("--beam-width", type=int, default=1)
|
||||
ask_parser.add_argument("--prune-ratio", type=float, default=0.0)
|
||||
ask_parser.add_argument("--recompute-embeddings", action="store_true")
|
||||
ask_parser.add_argument(
|
||||
"--pruning-strategy",
|
||||
choices=["global", "local", "proportional"],
|
||||
default="global",
|
||||
)
|
||||
|
||||
# List command
|
||||
list_parser = subparsers.add_parser("list", help="List all indexes")
|
||||
|
||||
return parser
|
||||
|
||||
def list_indexes(self):
|
||||
print("Stored LEANN indexes:")
|
||||
|
||||
if not self.indexes_dir.exists():
|
||||
print(
|
||||
"No indexes found. Use 'leann build <name> --docs <dir>' to create one."
|
||||
)
|
||||
return
|
||||
|
||||
index_dirs = [d for d in self.indexes_dir.iterdir() if d.is_dir()]
|
||||
|
||||
if not index_dirs:
|
||||
print(
|
||||
"No indexes found. Use 'leann build <name> --docs <dir>' to create one."
|
||||
)
|
||||
return
|
||||
|
||||
print(f"Found {len(index_dirs)} indexes:")
|
||||
for i, index_dir in enumerate(index_dirs, 1):
|
||||
index_name = index_dir.name
|
||||
status = "✓" if self.index_exists(index_name) else "✗"
|
||||
|
||||
print(f" {i}. {index_name} [{status}]")
|
||||
if self.index_exists(index_name):
|
||||
meta_file = index_dir / "documents.leann.meta.json"
|
||||
size_mb = sum(
|
||||
f.stat().st_size for f in index_dir.iterdir() if f.is_file()
|
||||
) / (1024 * 1024)
|
||||
print(f" Size: {size_mb:.1f} MB")
|
||||
|
||||
if index_dirs:
|
||||
example_name = index_dirs[0].name
|
||||
print(f"\nUsage:")
|
||||
print(f' leann search {example_name} "your query"')
|
||||
print(f" leann ask {example_name} --interactive")
|
||||
|
||||
def load_documents(self, docs_dir: str):
|
||||
print(f"Loading documents from {docs_dir}...")
|
||||
|
||||
documents = SimpleDirectoryReader(
|
||||
docs_dir,
|
||||
recursive=True,
|
||||
encoding="utf-8",
|
||||
required_exts=[".pdf", ".txt", ".md", ".docx"],
|
||||
).load_data(show_progress=True)
|
||||
|
||||
all_texts = []
|
||||
for doc in documents:
|
||||
nodes = self.node_parser.get_nodes_from_documents([doc])
|
||||
for node in nodes:
|
||||
all_texts.append(node.get_content())
|
||||
|
||||
print(f"Loaded {len(documents)} documents, {len(all_texts)} chunks")
|
||||
return all_texts
|
||||
|
||||
async def build_index(self, args):
|
||||
docs_dir = args.docs
|
||||
index_name = args.index_name
|
||||
index_dir = self.indexes_dir / index_name
|
||||
index_path = self.get_index_path(index_name)
|
||||
|
||||
if index_dir.exists() and not args.force:
|
||||
print(f"Index '{index_name}' already exists. Use --force to rebuild.")
|
||||
return
|
||||
|
||||
all_texts = self.load_documents(docs_dir)
|
||||
if not all_texts:
|
||||
print("No documents found")
|
||||
return
|
||||
|
||||
index_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
print(f"Building index '{index_name}' with {args.backend} backend...")
|
||||
|
||||
builder = LeannBuilder(
|
||||
backend_name=args.backend,
|
||||
embedding_model=args.embedding_model,
|
||||
graph_degree=args.graph_degree,
|
||||
complexity=args.complexity,
|
||||
is_compact=args.compact,
|
||||
is_recompute=args.recompute,
|
||||
num_threads=args.num_threads,
|
||||
)
|
||||
|
||||
for chunk_text in all_texts:
|
||||
builder.add_text(chunk_text)
|
||||
|
||||
builder.build_index(index_path)
|
||||
print(f"Index built at {index_path}")
|
||||
|
||||
async def search_documents(self, args):
|
||||
index_name = args.index_name
|
||||
query = args.query
|
||||
index_path = self.get_index_path(index_name)
|
||||
|
||||
if not self.index_exists(index_name):
|
||||
print(
|
||||
f"Index '{index_name}' not found. Use 'leann build {index_name} --docs <dir>' to create it."
|
||||
)
|
||||
return
|
||||
|
||||
searcher = LeannSearcher(index_path=index_path)
|
||||
results = searcher.search(
|
||||
query,
|
||||
top_k=args.top_k,
|
||||
complexity=args.complexity,
|
||||
beam_width=args.beam_width,
|
||||
prune_ratio=args.prune_ratio,
|
||||
recompute_embeddings=args.recompute_embeddings,
|
||||
pruning_strategy=args.pruning_strategy,
|
||||
)
|
||||
|
||||
print(f"Search results for '{query}' (top {len(results)}):")
|
||||
for i, result in enumerate(results, 1):
|
||||
print(f"{i}. Score: {result.score:.3f}")
|
||||
print(f" {result.text[:200]}...")
|
||||
print()
|
||||
|
||||
async def ask_questions(self, args):
|
||||
index_name = args.index_name
|
||||
index_path = self.get_index_path(index_name)
|
||||
|
||||
if not self.index_exists(index_name):
|
||||
print(
|
||||
f"Index '{index_name}' not found. Use 'leann build {index_name} --docs <dir>' to create it."
|
||||
)
|
||||
return
|
||||
|
||||
print(f"Starting chat with index '{index_name}'...")
|
||||
print(f"Using {args.model} ({args.llm})")
|
||||
|
||||
llm_config = {"type": args.llm, "model": args.model}
|
||||
if args.llm == "ollama":
|
||||
llm_config["host"] = args.host
|
||||
|
||||
chat = LeannChat(index_path=index_path, llm_config=llm_config)
|
||||
|
||||
if args.interactive:
|
||||
print("LEANN Assistant ready! Type 'quit' to exit")
|
||||
print("=" * 40)
|
||||
|
||||
while True:
|
||||
user_input = input("\nYou: ").strip()
|
||||
if user_input.lower() in ["quit", "exit", "q"]:
|
||||
print("Goodbye!")
|
||||
break
|
||||
|
||||
if not user_input:
|
||||
continue
|
||||
|
||||
response = chat.ask(
|
||||
user_input,
|
||||
top_k=args.top_k,
|
||||
complexity=args.complexity,
|
||||
beam_width=args.beam_width,
|
||||
prune_ratio=args.prune_ratio,
|
||||
recompute_embeddings=args.recompute_embeddings,
|
||||
pruning_strategy=args.pruning_strategy,
|
||||
)
|
||||
print(f"LEANN: {response}")
|
||||
else:
|
||||
query = input("Enter your question: ").strip()
|
||||
if query:
|
||||
response = chat.ask(
|
||||
query,
|
||||
top_k=args.top_k,
|
||||
complexity=args.complexity,
|
||||
beam_width=args.beam_width,
|
||||
prune_ratio=args.prune_ratio,
|
||||
recompute_embeddings=args.recompute_embeddings,
|
||||
pruning_strategy=args.pruning_strategy,
|
||||
)
|
||||
print(f"LEANN: {response}")
|
||||
|
||||
async def run(self, args=None):
|
||||
parser = self.create_parser()
|
||||
|
||||
if args is None:
|
||||
args = parser.parse_args()
|
||||
|
||||
if not args.command:
|
||||
parser.print_help()
|
||||
return
|
||||
|
||||
if args.command == "list":
|
||||
self.list_indexes()
|
||||
elif args.command == "build":
|
||||
await self.build_index(args)
|
||||
elif args.command == "search":
|
||||
await self.search_documents(args)
|
||||
elif args.command == "ask":
|
||||
await self.ask_questions(args)
|
||||
else:
|
||||
parser.print_help()
|
||||
|
||||
|
||||
def main():
|
||||
import dotenv
|
||||
|
||||
dotenv.load_dotenv()
|
||||
|
||||
cli = LeannCLI()
|
||||
asyncio.run(cli.run())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
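# --- Illustrative sketch (not part of the diff above): driving the CLI
# programmatically instead of through the console script. The index name and
# query are only examples; asyncio is already imported at the top of this file.
def run_query_programmatically():
    cli = LeannCLI()
    parser = cli.create_parser()
    args = parser.parse_args(
        ["search", "my-docs", "embedding server protocol", "--top-k", "3"]
    )
    asyncio.run(cli.run(args))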
|
||||
375
packages/leann-core/src/leann/embedding_compute.py
Normal file
@@ -0,0 +1,375 @@
|
||||
"""
|
||||
Unified embedding computation module
|
||||
Consolidates all embedding computation logic using SentenceTransformer
|
||||
Preserves all optimization parameters to ensure performance
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from typing import List, Dict, Any
|
||||
import logging
|
||||
import os
|
||||
|
||||
# Set up logger with proper level
|
||||
logger = logging.getLogger(__name__)
|
||||
LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
|
||||
log_level = getattr(logging, LOG_LEVEL, logging.WARNING)
|
||||
logger.setLevel(log_level)
|
||||
|
||||
# Global model cache to avoid repeated loading
|
||||
_model_cache: Dict[str, Any] = {}
|
||||
|
||||
|
||||
def compute_embeddings(
|
||||
texts: List[str],
|
||||
model_name: str,
|
||||
mode: str = "sentence-transformers",
|
||||
is_build: bool = False,
|
||||
batch_size: int = 32,
|
||||
adaptive_optimization: bool = True,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Unified embedding computation entry point
|
||||
|
||||
Args:
|
||||
texts: List of texts to compute embeddings for
|
||||
model_name: Model name
|
||||
mode: Computation mode ('sentence-transformers', 'openai', 'mlx')
|
||||
is_build: Whether this is a build operation (shows progress bar)
|
||||
batch_size: Batch size for processing
|
||||
adaptive_optimization: Whether to use adaptive optimization based on batch size
|
||||
|
||||
Returns:
|
||||
Normalized embeddings array, shape: (len(texts), embedding_dim)
|
||||
"""
|
||||
if mode == "sentence-transformers":
|
||||
return compute_embeddings_sentence_transformers(
|
||||
texts,
|
||||
model_name,
|
||||
is_build=is_build,
|
||||
batch_size=batch_size,
|
||||
adaptive_optimization=adaptive_optimization,
|
||||
)
|
||||
elif mode == "openai":
|
||||
return compute_embeddings_openai(texts, model_name)
|
||||
elif mode == "mlx":
|
||||
return compute_embeddings_mlx(texts, model_name)
|
||||
else:
|
||||
raise ValueError(f"Unsupported embedding mode: {mode}")
|
||||
|
||||
|
||||
def compute_embeddings_sentence_transformers(
|
||||
texts: List[str],
|
||||
model_name: str,
|
||||
use_fp16: bool = True,
|
||||
device: str = "auto",
|
||||
batch_size: int = 32,
|
||||
is_build: bool = False,
|
||||
adaptive_optimization: bool = True,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Compute embeddings using SentenceTransformer with model caching and adaptive optimization
|
||||
|
||||
Args:
|
||||
texts: List of texts to compute embeddings for
|
||||
model_name: Model name
|
||||
use_fp16: Whether to use FP16 precision
|
||||
device: Device to use ('auto', 'cuda', 'mps', 'cpu')
|
||||
batch_size: Batch size for processing
|
||||
is_build: Whether this is a build operation (shows progress bar)
|
||||
adaptive_optimization: Whether to use adaptive optimization based on batch size
|
||||
"""
|
||||
# Handle empty input
|
||||
if not texts:
|
||||
raise ValueError("Cannot compute embeddings for empty text list")
|
||||
logger.info(
|
||||
f"Computing embeddings for {len(texts)} texts using SentenceTransformer, model: '{model_name}'"
|
||||
)
|
||||
|
||||
# Auto-detect device
|
||||
if device == "auto":
|
||||
if torch.cuda.is_available():
|
||||
device = "cuda"
|
||||
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
|
||||
device = "mps"
|
||||
else:
|
||||
device = "cpu"
|
||||
|
||||
# Apply optimizations based on benchmark results
|
||||
if adaptive_optimization:
|
||||
# Use optimal batch_size constants for different devices based on benchmark results
|
||||
if device == "mps":
|
||||
batch_size = 128 # MPS optimal batch size from benchmark
|
||||
if model_name == "Qwen/Qwen3-Embedding-0.6B":
|
||||
batch_size = 32
|
||||
elif device == "cuda":
|
||||
batch_size = 256 # CUDA optimal batch size
|
||||
# Keep original batch_size for CPU
|
||||
|
||||
# Create cache key
|
||||
cache_key = f"sentence_transformers_{model_name}_{device}_{use_fp16}_optimized"
|
||||
|
||||
# Check if model is already cached
|
||||
if cache_key in _model_cache:
|
||||
logger.info(f"Using cached optimized model: {model_name}")
|
||||
model = _model_cache[cache_key]
|
||||
else:
|
||||
logger.info(
|
||||
f"Loading and caching optimized SentenceTransformer model: {model_name}"
|
||||
)
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
logger.info(f"Using device: {device}")
|
||||
|
||||
# Apply hardware optimizations
|
||||
if device == "cuda":
|
||||
# TODO: Haven't tested this yet
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
torch.backends.cudnn.allow_tf32 = True
|
||||
torch.backends.cudnn.benchmark = True
|
||||
torch.backends.cudnn.deterministic = False
|
||||
torch.cuda.set_per_process_memory_fraction(0.9)
|
||||
elif device == "mps":
|
||||
try:
|
||||
if hasattr(torch.mps, "set_per_process_memory_fraction"):
|
||||
torch.mps.set_per_process_memory_fraction(0.9)
|
||||
except AttributeError:
|
||||
logger.warning(
|
||||
"Some MPS optimizations not available in this PyTorch version"
|
||||
)
|
||||
elif device == "cpu":
|
||||
# TODO: Haven't tested this yet
|
||||
torch.set_num_threads(min(8, os.cpu_count() or 4))
|
||||
try:
|
||||
torch.backends.mkldnn.enabled = True
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
# Prepare optimized model and tokenizer parameters
|
||||
model_kwargs = {
|
||||
"torch_dtype": torch.float16 if use_fp16 else torch.float32,
|
||||
"low_cpu_mem_usage": True,
|
||||
"_fast_init": True,
|
||||
"attn_implementation": "eager", # Use eager attention for speed
|
||||
}
|
||||
|
||||
tokenizer_kwargs = {
|
||||
"use_fast": True,
|
||||
"padding": True,
|
||||
"truncation": True,
|
||||
}
|
||||
|
||||
try:
|
||||
# Try local loading first
|
||||
model_kwargs["local_files_only"] = True
|
||||
tokenizer_kwargs["local_files_only"] = True
|
||||
|
||||
model = SentenceTransformer(
|
||||
model_name,
|
||||
device=device,
|
||||
model_kwargs=model_kwargs,
|
||||
tokenizer_kwargs=tokenizer_kwargs,
|
||||
local_files_only=True,
|
||||
)
|
||||
logger.info("Model loaded successfully! (local + optimized)")
|
||||
except Exception as e:
|
||||
logger.warning(f"Local loading failed ({e}), trying network download...")
|
||||
# Fallback to network loading
|
||||
model_kwargs["local_files_only"] = False
|
||||
tokenizer_kwargs["local_files_only"] = False
|
||||
|
||||
model = SentenceTransformer(
|
||||
model_name,
|
||||
device=device,
|
||||
model_kwargs=model_kwargs,
|
||||
tokenizer_kwargs=tokenizer_kwargs,
|
||||
local_files_only=False,
|
||||
)
|
||||
logger.info("Model loaded successfully! (network + optimized)")
|
||||
|
||||
# Apply additional optimizations based on mode
|
||||
if use_fp16 and device in ["cuda", "mps"]:
|
||||
try:
|
||||
model = model.half()
|
||||
logger.info(f"Applied FP16 precision: {model_name}")
|
||||
except Exception as e:
|
||||
logger.warning(f"FP16 optimization failed: {e}")
|
||||
|
||||
# Apply torch.compile optimization
|
||||
if device in ["cuda", "mps"]:
|
||||
try:
|
||||
model = torch.compile(model, mode="reduce-overhead", dynamic=True)
|
||||
logger.info(f"Applied torch.compile optimization: {model_name}")
|
||||
except Exception as e:
|
||||
logger.warning(f"torch.compile optimization failed: {e}")
|
||||
|
||||
# Set model to eval mode and disable gradients for inference
|
||||
model.eval()
|
||||
for param in model.parameters():
|
||||
param.requires_grad_(False)
|
||||
|
||||
# Cache the model
|
||||
_model_cache[cache_key] = model
|
||||
logger.info(f"Model cached: {cache_key}")
|
||||
|
||||
# Compute embeddings with optimized inference mode
|
||||
logger.info(f"Starting embedding computation... (batch_size: {batch_size})")
|
||||
|
||||
# Use torch.inference_mode for optimal performance
|
||||
with torch.inference_mode():
|
||||
embeddings = model.encode(
|
||||
texts,
|
||||
batch_size=batch_size,
|
||||
show_progress_bar=is_build, # Don't show progress bar in server environment
|
||||
convert_to_numpy=True,
|
||||
normalize_embeddings=False,
|
||||
device=device,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Generated {len(embeddings)} embeddings, dimension: {embeddings.shape[1]}"
|
||||
)
|
||||
|
||||
# Validate results
|
||||
if np.isnan(embeddings).any() or np.isinf(embeddings).any():
|
||||
raise RuntimeError(
|
||||
f"Detected NaN or Inf values in embeddings, model: {model_name}"
|
||||
)
|
||||
|
||||
return embeddings
|
||||
|
||||
|
||||
def compute_embeddings_openai(texts: List[str], model_name: str) -> np.ndarray:
|
||||
# TODO: @yichuan-w add progress bar only in build mode
|
||||
"""Compute embeddings using OpenAI API"""
|
||||
try:
|
||||
import openai
|
||||
import os
|
||||
except ImportError as e:
|
||||
raise ImportError(f"OpenAI package not installed: {e}")
|
||||
|
||||
api_key = os.getenv("OPENAI_API_KEY")
|
||||
if not api_key:
|
||||
raise RuntimeError("OPENAI_API_KEY environment variable not set")
|
||||
|
||||
# Cache OpenAI client
|
||||
cache_key = "openai_client"
|
||||
if cache_key in _model_cache:
|
||||
client = _model_cache[cache_key]
|
||||
else:
|
||||
client = openai.OpenAI(api_key=api_key)
|
||||
_model_cache[cache_key] = client
|
||||
logger.info("OpenAI client cached")
|
||||
|
||||
logger.info(
|
||||
f"Computing embeddings for {len(texts)} texts using OpenAI API, model: '{model_name}'"
|
||||
)
|
||||
|
||||
# OpenAI has limits on batch size and input length
|
||||
max_batch_size = 100 # Conservative batch size
|
||||
all_embeddings = []
|
||||
|
||||
try:
|
||||
from tqdm import tqdm
|
||||
|
||||
total_batches = (len(texts) + max_batch_size - 1) // max_batch_size
|
||||
batch_range = range(0, len(texts), max_batch_size)
|
||||
batch_iterator = tqdm(
|
||||
batch_range, desc="Computing embeddings", unit="batch", total=total_batches
|
||||
)
|
||||
except ImportError:
|
||||
# Fallback when tqdm is not available
|
||||
batch_iterator = range(0, len(texts), max_batch_size)
|
||||
|
||||
for i in batch_iterator:
|
||||
batch_texts = texts[i : i + max_batch_size]
|
||||
|
||||
try:
|
||||
response = client.embeddings.create(model=model_name, input=batch_texts)
|
||||
batch_embeddings = [embedding.embedding for embedding in response.data]
|
||||
all_embeddings.extend(batch_embeddings)
|
||||
except Exception as e:
|
||||
logger.error(f"Batch {i} failed: {e}")
|
||||
raise
|
||||
|
||||
embeddings = np.array(all_embeddings, dtype=np.float32)
|
||||
logger.info(
|
||||
f"Generated {len(embeddings)} embeddings, dimension: {embeddings.shape[1]}"
|
||||
)
|
||||
return embeddings
|
||||
|
||||
|
||||
def compute_embeddings_mlx(
|
||||
chunks: List[str], model_name: str, batch_size: int = 16
|
||||
) -> np.ndarray:
|
||||
# TODO: @yichuan-w add progress bar only in build mode
|
||||
"""Computes embeddings using an MLX model."""
|
||||
try:
|
||||
import mlx.core as mx
|
||||
from mlx_lm.utils import load
|
||||
except ImportError as e:
|
||||
raise RuntimeError(
|
||||
"MLX or related libraries not available. Install with: uv pip install mlx mlx-lm"
|
||||
) from e
|
||||
|
||||
logger.info(
|
||||
f"Computing embeddings for {len(chunks)} chunks using MLX model '{model_name}' with batch_size={batch_size}..."
|
||||
)
|
||||
|
||||
# Cache MLX model and tokenizer
|
||||
cache_key = f"mlx_{model_name}"
|
||||
if cache_key in _model_cache:
|
||||
logger.info(f"Using cached MLX model: {model_name}")
|
||||
model, tokenizer = _model_cache[cache_key]
|
||||
else:
|
||||
logger.info(f"Loading and caching MLX model: {model_name}")
|
||||
model, tokenizer = load(model_name)
|
||||
_model_cache[cache_key] = (model, tokenizer)
|
||||
logger.info(f"MLX model cached: {cache_key}")
|
||||
|
||||
# Process chunks in batches with progress bar
|
||||
all_embeddings = []
|
||||
|
||||
try:
|
||||
from tqdm import tqdm
|
||||
|
||||
batch_iterator = tqdm(
|
||||
range(0, len(chunks), batch_size), desc="Computing embeddings", unit="batch"
|
||||
)
|
||||
except ImportError:
|
||||
batch_iterator = range(0, len(chunks), batch_size)
|
||||
|
||||
for i in batch_iterator:
|
||||
batch_chunks = chunks[i : i + batch_size]
|
||||
|
||||
# Tokenize all chunks in the batch
|
||||
batch_token_ids = []
|
||||
for chunk in batch_chunks:
|
||||
token_ids = tokenizer.encode(chunk) # type: ignore
|
||||
batch_token_ids.append(token_ids)
|
||||
|
||||
# Pad sequences to the same length for batch processing
|
||||
max_length = max(len(ids) for ids in batch_token_ids)
|
||||
padded_token_ids = []
|
||||
for token_ids in batch_token_ids:
|
||||
# Pad with tokenizer.pad_token_id or 0
|
||||
padded = token_ids + [0] * (max_length - len(token_ids))
|
||||
padded_token_ids.append(padded)
|
||||
|
||||
# Convert to MLX array with batch dimension
|
||||
input_ids = mx.array(padded_token_ids)
|
||||
|
||||
# Get embeddings for the batch
|
||||
embeddings = model(input_ids)
|
||||
|
||||
# Mean pooling for each sequence in the batch
|
||||
pooled = embeddings.mean(axis=1) # Shape: (batch_size, hidden_size)
|
||||
|
||||
# Convert batch embeddings to numpy
|
||||
for j in range(len(batch_chunks)):
|
||||
pooled_list = pooled[j].tolist() # Convert to list
|
||||
pooled_numpy = np.array(pooled_list, dtype=np.float32)
|
||||
all_embeddings.append(pooled_numpy)
|
||||
|
||||
# Stack numpy arrays
|
||||
return np.stack(all_embeddings)
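# --- Illustrative usage sketch (not part of the diff above) for the unified
# compute_embeddings() entry point defined earlier in this file. The model name
# is only an example.
if __name__ == "__main__":
    vectors = compute_embeddings(
        ["what is leann?", "how does graph-based pruning work?"],
        model_name="facebook/contriever",
        mode="sentence-transformers",
        is_build=False,
    )
    print(vectors.shape)  # (2, embedding_dim)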
|
||||
@@ -1,14 +1,21 @@
|
||||
import threading
|
||||
import time
|
||||
import atexit
|
||||
import socket
|
||||
import subprocess
|
||||
import sys
|
||||
import zmq
|
||||
import msgpack
|
||||
import os
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
import select
|
||||
import psutil
|
||||
|
||||
# Set up logging based on environment variable
|
||||
LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
|
||||
logging.basicConfig(
|
||||
level=getattr(logging, LOG_LEVEL, logging.INFO),
|
||||
format="%(levelname)s - %(name)s - %(message)s",
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _check_port(port: int) -> bool:
|
||||
@@ -17,151 +24,135 @@ def _check_port(port: int) -> bool:
|
||||
return s.connect_ex(("localhost", port)) == 0
|
||||
|
||||
|
||||
def _check_server_meta_path(port: int, expected_meta_path: str) -> bool:
|
||||
def _check_process_matches_config(
|
||||
port: int, expected_model: str, expected_passages_file: str
|
||||
) -> bool:
|
||||
"""
|
||||
Check if the existing server on the port is using the correct meta file.
|
||||
Returns True if the server has the right meta path, False otherwise.
|
||||
Check if the process using the port matches our expected model and passages file.
|
||||
Returns True if matches, False otherwise.
|
||||
"""
|
||||
try:
|
||||
context = zmq.Context()
|
||||
socket = context.socket(zmq.REQ)
|
||||
socket.setsockopt(zmq.RCVTIMEO, 3000) # 3 second timeout
|
||||
socket.connect(f"tcp://localhost:{port}")
|
||||
for proc in psutil.process_iter(["pid", "cmdline"]):
|
||||
if not _is_process_listening_on_port(proc, port):
|
||||
continue
|
||||
|
||||
# Send a special control message to query the server's meta path
|
||||
control_request = ["__QUERY_META_PATH__"]
|
||||
request_bytes = msgpack.packb(control_request)
|
||||
socket.send(request_bytes)
|
||||
cmdline = proc.info["cmdline"]
|
||||
if not cmdline:
|
||||
continue
|
||||
|
||||
# Wait for response
|
||||
response_bytes = socket.recv()
|
||||
response = msgpack.unpackb(response_bytes)
|
||||
|
||||
socket.close()
|
||||
context.term()
|
||||
|
||||
# Check if the response contains the meta path and if it matches
|
||||
if isinstance(response, list) and len(response) > 0:
|
||||
server_meta_path = response[0]
|
||||
# Normalize paths for comparison
|
||||
expected_path = Path(expected_meta_path).resolve()
|
||||
server_path = Path(server_meta_path).resolve() if server_meta_path else None
|
||||
return server_path == expected_path
|
||||
return _check_cmdline_matches_config(
|
||||
cmdline, port, expected_model, expected_passages_file
|
||||
)
|
||||
|
||||
logger.debug(f"No process found listening on port {port}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
print(f"WARNING: Could not query server meta path on port {port}: {e}")
|
||||
logger.warning(f"Could not check process on port {port}: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def _update_server_meta_path(port: int, new_meta_path: str) -> bool:
|
||||
"""
|
||||
Send a control message to update the server's meta path.
|
||||
Returns True if successful, False otherwise.
|
||||
"""
|
||||
def _is_process_listening_on_port(proc, port: int) -> bool:
|
||||
"""Check if a process is listening on the given port."""
|
||||
try:
|
||||
context = zmq.Context()
|
||||
socket = context.socket(zmq.REQ)
|
||||
socket.setsockopt(zmq.RCVTIMEO, 5000) # 5 second timeout
|
||||
socket.connect(f"tcp://localhost:{port}")
|
||||
|
||||
# Send a control message to update the meta path
|
||||
control_request = ["__UPDATE_META_PATH__", new_meta_path]
|
||||
request_bytes = msgpack.packb(control_request)
|
||||
socket.send(request_bytes)
|
||||
|
||||
# Wait for response
|
||||
response_bytes = socket.recv()
|
||||
response = msgpack.unpackb(response_bytes)
|
||||
|
||||
socket.close()
|
||||
context.term()
|
||||
|
||||
# Check if the update was successful
|
||||
if isinstance(response, list) and len(response) > 0:
|
||||
return response[0] == "SUCCESS"
|
||||
|
||||
connections = proc.net_connections()
|
||||
for conn in connections:
|
||||
if conn.laddr.port == port and conn.status == psutil.CONN_LISTEN:
|
||||
return True
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
print(f"ERROR: Could not update server meta path on port {port}: {e}")
|
||||
except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess):
|
||||
return False
|
||||
|
||||
|
||||
def _check_server_model(port: int, expected_model: str) -> bool:
|
||||
def _check_cmdline_matches_config(
|
||||
cmdline: list, port: int, expected_model: str, expected_passages_file: str
|
||||
) -> bool:
|
||||
"""Check if command line matches our expected configuration."""
|
||||
cmdline_str = " ".join(cmdline)
|
||||
logger.debug(f"Found process on port {port}: {cmdline_str}")
|
||||
|
||||
# Check if it's our embedding server
|
||||
is_embedding_server = any(
|
||||
server_type in cmdline_str
|
||||
for server_type in [
|
||||
"embedding_server",
|
||||
"leann_backend_diskann.embedding_server",
|
||||
"leann_backend_hnsw.hnsw_embedding_server",
|
||||
]
|
||||
)
|
||||
|
||||
if not is_embedding_server:
|
||||
logger.debug(f"Process on port {port} is not our embedding server")
|
||||
return False
|
||||
|
||||
# Check model name
|
||||
model_matches = _check_model_in_cmdline(cmdline, expected_model)
|
||||
|
||||
# Check passages file if provided
|
||||
passages_matches = _check_passages_in_cmdline(cmdline, expected_passages_file)
|
||||
|
||||
result = model_matches and passages_matches
|
||||
logger.debug(
|
||||
f"model_matches: {model_matches}, passages_matches: {passages_matches}, overall: {result}"
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
def _check_model_in_cmdline(cmdline: list, expected_model: str) -> bool:
|
||||
"""Check if the command line contains the expected model."""
|
||||
if "--model-name" not in cmdline:
|
||||
return False
|
||||
|
||||
model_idx = cmdline.index("--model-name")
|
||||
if model_idx + 1 >= len(cmdline):
|
||||
return False
|
||||
|
||||
actual_model = cmdline[model_idx + 1]
|
||||
return actual_model == expected_model
|
||||
|
||||
|
||||
def _check_passages_in_cmdline(cmdline: list, expected_passages_file: str) -> bool:
|
||||
"""Check if the command line contains the expected passages file."""
|
||||
if "--passages-file" not in cmdline:
|
||||
return False # Expected but not found
|
||||
|
||||
passages_idx = cmdline.index("--passages-file")
|
||||
if passages_idx + 1 >= len(cmdline):
|
||||
return False
|
||||
|
||||
actual_passages = cmdline[passages_idx + 1]
|
||||
expected_path = Path(expected_passages_file).resolve()
|
||||
actual_path = Path(actual_passages).resolve()
|
||||
return actual_path == expected_path
|
||||
|
||||
|
||||
def _find_compatible_port_or_next_available(
|
||||
start_port: int, model_name: str, passages_file: str, max_attempts: int = 100
|
||||
) -> tuple[int, bool]:
|
||||
"""
|
||||
Check if the existing server on the port is using the correct embedding model.
|
||||
Returns True if the server has the right model, False otherwise.
|
||||
Find a port that either has a compatible server or is available.
|
||||
Returns (port, is_compatible) where is_compatible indicates if we found a matching server.
|
||||
"""
|
||||
try:
|
||||
context = zmq.Context()
|
||||
socket = context.socket(zmq.REQ)
|
||||
socket.setsockopt(zmq.RCVTIMEO, 3000) # 3 second timeout
|
||||
socket.connect(f"tcp://localhost:{port}")
|
||||
for port in range(start_port, start_port + max_attempts):
|
||||
if not _check_port(port):
|
||||
# Port is available
|
||||
return port, False
|
||||
|
||||
# Send a special control message to query the server's model
|
||||
control_request = ["__QUERY_MODEL__"]
|
||||
request_bytes = msgpack.packb(control_request)
|
||||
socket.send(request_bytes)
|
||||
# Port is in use, check if it's compatible
|
||||
if _check_process_matches_config(port, model_name, passages_file):
|
||||
logger.info(f"Found compatible server on port {port}")
|
||||
return port, True
|
||||
else:
|
||||
logger.info(f"Port {port} has incompatible server, trying next port...")
|
||||
|
||||
# Wait for response
|
||||
response_bytes = socket.recv()
|
||||
response = msgpack.unpackb(response_bytes)
|
||||
|
||||
socket.close()
|
||||
context.term()
|
||||
|
||||
# Check if the response contains the model name and if it matches
|
||||
if isinstance(response, list) and len(response) > 0:
|
||||
server_model = response[0]
|
||||
return server_model == expected_model
|
||||
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
print(f"WARNING: Could not query server model on port {port}: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def _update_server_model(port: int, new_model: str) -> bool:
|
||||
"""
|
||||
Send a control message to update the server's embedding model.
|
||||
Returns True if successful, False otherwise.
|
||||
"""
|
||||
try:
|
||||
context = zmq.Context()
|
||||
socket = context.socket(zmq.REQ)
|
||||
socket.setsockopt(zmq.RCVTIMEO, 30000) # 30 second timeout for model loading
|
||||
socket.setsockopt(zmq.SNDTIMEO, 5000) # 5 second timeout for sending
|
||||
socket.connect(f"tcp://localhost:{port}")
|
||||
|
||||
# Send a control message to update the model
|
||||
control_request = ["__UPDATE_MODEL__", new_model]
|
||||
request_bytes = msgpack.packb(control_request)
|
||||
socket.send(request_bytes)
|
||||
|
||||
# Wait for response
|
||||
response_bytes = socket.recv()
|
||||
response = msgpack.unpackb(response_bytes)
|
||||
|
||||
socket.close()
|
||||
context.term()
|
||||
|
||||
# Check if the update was successful
|
||||
if isinstance(response, list) and len(response) > 0:
|
||||
return response[0] == "SUCCESS"
|
||||
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
print(f"ERROR: Could not update server model on port {port}: {e}")
|
||||
return False
|
||||
raise RuntimeError(
|
||||
f"Could not find compatible or available port in range {start_port}-{start_port + max_attempts}"
|
||||
)
|
||||
|
||||
|
||||
class EmbeddingServerManager:
|
||||
"""
|
||||
A generic manager for handling the lifecycle of a backend-specific embedding server process.
|
||||
A simplified manager for embedding server processes that avoids complex update mechanisms.
|
||||
"""
|
||||
|
||||
def __init__(self, backend_module_name: str):
|
||||
@@ -175,246 +166,183 @@ class EmbeddingServerManager:
|
||||
self.backend_module_name = backend_module_name
|
||||
self.server_process: Optional[subprocess.Popen] = None
|
||||
self.server_port: Optional[int] = None
|
||||
atexit.register(self.stop_server)
|
||||
self._atexit_registered = False
|
||||
|
||||
def start_server(self, port: int, model_name: str, embedding_mode: str = "sentence-transformers", **kwargs) -> bool:
|
||||
def start_server(
|
||||
self,
|
||||
port: int,
|
||||
model_name: str,
|
||||
embedding_mode: str = "sentence-transformers",
|
||||
**kwargs,
|
||||
) -> tuple[bool, int]:
|
||||
"""
|
||||
Starts the embedding server process.
|
||||
|
||||
Args:
|
||||
port (int): The ZMQ port for the server.
|
||||
port (int): The preferred ZMQ port for the server.
|
||||
model_name (str): The name of the embedding model to use.
|
||||
**kwargs: Additional arguments for the server (e.g., passages_file, distance_metric, enable_warmup).
|
||||
**kwargs: Additional arguments for the server.
|
||||
|
||||
Returns:
|
||||
bool: True if the server is started successfully or already running, False otherwise.
|
||||
tuple[bool, int]: (success, actual_port_used)
|
||||
"""
|
||||
if self.server_process and self.server_process.poll() is None:
|
||||
# Even if we have a running process, check if model/meta path match
|
||||
if self.server_port is not None:
|
||||
port_in_use = _check_port(self.server_port)
|
||||
if port_in_use:
|
||||
print(
|
||||
f"INFO: Checking compatibility of existing server process (PID {self.server_process.pid})"
|
||||
)
|
||||
passages_file = kwargs.get("passages_file")
|
||||
assert isinstance(passages_file, str), "passages_file must be a string"
|
||||
|
||||
# Check model compatibility
|
||||
model_matches = _check_server_model(self.server_port, model_name)
|
||||
if model_matches:
|
||||
print(
|
||||
f"✅ Existing server already using correct model: {model_name}"
|
||||
)
|
||||
|
||||
# Still check meta path if provided
|
||||
passages_file = kwargs.get("passages_file")
|
||||
if passages_file and str(passages_file).endswith(
|
||||
".meta.json"
|
||||
):
|
||||
meta_matches = _check_server_meta_path(
|
||||
self.server_port, str(passages_file)
|
||||
)
|
||||
if not meta_matches:
|
||||
print("⚠️ Updating meta path to: {passages_file}")
|
||||
_update_server_meta_path(
|
||||
self.server_port, str(passages_file)
|
||||
)
|
||||
|
||||
return True
|
||||
else:
|
||||
print(
|
||||
f"⚠️ Existing server has different model. Attempting to update to: {model_name}"
|
||||
)
|
||||
if not _update_server_model(self.server_port, model_name):
|
||||
print(
|
||||
"❌ Failed to update existing server model. Restarting server..."
|
||||
)
|
||||
self.stop_server()
|
||||
# Continue to start new server below
|
||||
else:
|
||||
print(
|
||||
f"✅ Successfully updated existing server model to: {model_name}"
|
||||
)
|
||||
# Check if we have a compatible running server
|
||||
if self._has_compatible_running_server(model_name, passages_file):
|
||||
assert self.server_port is not None, (
|
||||
"a compatible running server should set server_port"
|
||||
)
|
||||
return True, self.server_port
|
||||
|
||||
# Also check meta path if provided
|
||||
passages_file = kwargs.get("passages_file")
|
||||
if passages_file and str(passages_file).endswith(
|
||||
".meta.json"
|
||||
):
|
||||
meta_matches = _check_server_meta_path(
|
||||
self.server_port, str(passages_file)
|
||||
)
|
||||
if not meta_matches:
|
||||
print("⚠️ Updating meta path to: {passages_file}")
|
||||
_update_server_meta_path(
|
||||
self.server_port, str(passages_file)
|
||||
)
|
||||
# Find available port (compatible or free)
|
||||
try:
|
||||
actual_port, is_compatible = _find_compatible_port_or_next_available(
|
||||
port, model_name, passages_file
|
||||
)
|
||||
except RuntimeError as e:
|
||||
logger.error(str(e))
|
||||
return False, port
|
||||
|
||||
return True
|
||||
else:
|
||||
# Server process exists but port not responding - restart
|
||||
print("⚠️ Server process exists but not responding. Restarting...")
|
||||
self.stop_server()
|
||||
# Continue to start new server below
|
||||
else:
|
||||
# No port stored - restart
|
||||
print("⚠️ No port information stored. Restarting server...")
|
||||
self.stop_server()
|
||||
# Continue to start new server below
|
||||
if is_compatible:
|
||||
logger.info(f"Using existing compatible server on port {actual_port}")
|
||||
self.server_port = actual_port
|
||||
self.server_process = None # We don't own this process
|
||||
return True, actual_port
|
||||
|
||||
if _check_port(port):
|
||||
# Port is in use, check if it's using the correct meta file and model
|
||||
passages_file = kwargs.get("passages_file")
|
||||
if actual_port != port:
|
||||
logger.info(f"Using port {actual_port} instead of {port}")
|
||||
|
||||
print(f"INFO: Port {port} is in use. Checking server compatibility...")
|
||||
# Start new server
|
||||
return self._start_new_server(actual_port, model_name, embedding_mode, **kwargs)
|
||||
|
||||
# Check model compatibility first
|
||||
model_matches = _check_server_model(port, model_name)
|
||||
if model_matches:
|
||||
print(
|
||||
f"✅ Existing server on port {port} is using correct model: {model_name}"
|
||||
)
|
||||
else:
|
||||
print(
|
||||
f"⚠️ Existing server on port {port} has different model. Attempting to update to: {model_name}"
|
||||
)
|
||||
if not _update_server_model(port, model_name):
|
||||
raise RuntimeError(
|
||||
f"❌ Failed to update server model to {model_name}. Consider using a different port."
|
||||
)
|
||||
print(f"✅ Successfully updated server model to: {model_name}")
|
||||
def _has_compatible_running_server(
|
||||
self, model_name: str, passages_file: str
|
||||
) -> bool:
|
||||
"""Check if we have a compatible running server."""
|
||||
if not (
|
||||
self.server_process
|
||||
and self.server_process.poll() is None
|
||||
and self.server_port
|
||||
):
|
||||
return False
|
||||
|
||||
# Check meta path compatibility if provided
|
||||
if passages_file and str(passages_file).endswith(".meta.json"):
|
||||
meta_matches = _check_server_meta_path(port, str(passages_file))
|
||||
if not meta_matches:
|
||||
print(
|
||||
f"⚠️ Existing server on port {port} has different meta path. Attempting to update..."
|
||||
)
|
||||
if not _update_server_meta_path(port, str(passages_file)):
|
||||
raise RuntimeError(
|
||||
"❌ Failed to update server meta path. This may cause data synchronization issues."
|
||||
)
|
||||
print(
|
||||
f"✅ Successfully updated server meta path to: {passages_file}"
|
||||
)
|
||||
else:
|
||||
print(
|
||||
f"✅ Existing server on port {port} is using correct meta path: {passages_file}"
|
||||
)
|
||||
|
||||
print(f"✅ Server on port {port} is compatible and ready to use.")
|
||||
if _check_process_matches_config(self.server_port, model_name, passages_file):
|
||||
logger.info(
|
||||
f"Existing server process (PID {self.server_process.pid}) is compatible"
|
||||
)
|
||||
return True
|
||||
|
||||
print(
|
||||
f"INFO: Starting session-level embedding server for '{self.backend_module_name}'..."
|
||||
logger.info(
|
||||
"Existing server process is incompatible. Should start a new server."
|
||||
)
|
||||
return False
|
||||
|
||||
def _start_new_server(
|
||||
self, port: int, model_name: str, embedding_mode: str, **kwargs
|
||||
) -> tuple[bool, int]:
|
||||
"""Start a new embedding server on the given port."""
|
||||
logger.info(f"Starting embedding server on port {port}...")
|
||||
|
||||
command = self._build_server_command(port, model_name, embedding_mode, **kwargs)
|
||||
|
||||
try:
|
||||
command = [
|
||||
sys.executable,
|
||||
"-m",
|
||||
self.backend_module_name,
|
||||
"--zmq-port",
|
||||
str(port),
|
||||
"--model-name",
|
||||
model_name,
|
||||
]
|
||||
|
||||
# Add extra arguments for specific backends
|
||||
if "passages_file" in kwargs and kwargs["passages_file"]:
|
||||
command.extend(["--passages-file", str(kwargs["passages_file"])])
|
||||
# if "distance_metric" in kwargs and kwargs["distance_metric"]:
|
||||
# command.extend(["--distance-metric", kwargs["distance_metric"]])
|
||||
if embedding_mode != "sentence-transformers":
|
||||
command.extend(["--embedding-mode", embedding_mode])
|
||||
if "enable_warmup" in kwargs and not kwargs["enable_warmup"]:
|
||||
command.extend(["--disable-warmup"])
|
||||
|
||||
project_root = Path(__file__).parent.parent.parent.parent.parent
|
||||
print(f"INFO: Running command from project root: {project_root}")
|
||||
print(f"INFO: Command: {' '.join(command)}") # Debug: show actual command
|
||||
|
||||
self.server_process = subprocess.Popen(
|
||||
command,
|
||||
cwd=project_root,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT, # Merge stderr into stdout for easier monitoring
|
||||
text=True,
|
||||
encoding="utf-8",
|
||||
bufsize=1, # Line buffered
|
||||
universal_newlines=True,
|
||||
)
|
||||
self.server_port = port
|
||||
print(f"INFO: Server process started with PID: {self.server_process.pid}")
|
||||
|
||||
max_wait, wait_interval = 120, 0.5
|
||||
for _ in range(int(max_wait / wait_interval)):
|
||||
if _check_port(port):
|
||||
print("✅ Embedding server is up and ready for this session.")
|
||||
log_thread = threading.Thread(target=self._log_monitor, daemon=True)
|
||||
log_thread.start()
|
||||
return True
|
||||
if self.server_process.poll() is not None:
|
||||
print(
|
||||
"❌ ERROR: Server process terminated unexpectedly during startup."
|
||||
)
|
||||
self._print_recent_output()
|
||||
return False
|
||||
time.sleep(wait_interval)
|
||||
|
||||
print(
|
||||
f"❌ ERROR: Server process failed to start listening within {max_wait} seconds."
|
||||
)
|
||||
self.stop_server()
|
||||
return False
|
||||
|
||||
self._launch_server_process(command, port)
|
||||
return self._wait_for_server_ready(port)
|
||||
except Exception as e:
|
||||
print(f"❌ ERROR: Failed to start embedding server process: {e}")
|
||||
return False
|
||||
logger.error(f"Failed to start embedding server: {e}")
|
||||
return False, port
|
||||
|
||||
def _print_recent_output(self):
|
||||
"""Print any recent output from the server process."""
|
||||
if not self.server_process or not self.server_process.stdout:
|
||||
return
|
||||
try:
|
||||
# Read any available output
|
||||
def _build_server_command(
|
||||
self, port: int, model_name: str, embedding_mode: str, **kwargs
|
||||
) -> list:
|
||||
"""Build the command to start the embedding server."""
|
||||
command = [
|
||||
sys.executable,
|
||||
"-m",
|
||||
self.backend_module_name,
|
||||
"--zmq-port",
|
||||
str(port),
|
||||
"--model-name",
|
||||
model_name,
|
||||
]
|
||||
|
||||
if select.select([self.server_process.stdout], [], [], 0)[0]:
|
||||
output = self.server_process.stdout.read()
|
||||
if output:
|
||||
print(f"[{self.backend_module_name} OUTPUT]: {output}")
|
||||
except Exception as e:
|
||||
print(f"Error reading server output: {e}")
|
||||
if kwargs.get("passages_file"):
|
||||
command.extend(["--passages-file", str(kwargs["passages_file"])])
|
||||
if embedding_mode != "sentence-transformers":
|
||||
command.extend(["--embedding-mode", embedding_mode])
|
||||
|
||||
def _log_monitor(self):
|
||||
"""Monitors and prints the server's stdout and stderr."""
|
||||
if not self.server_process:
|
||||
return
|
||||
try:
|
||||
if self.server_process.stdout:
|
||||
while True:
|
||||
line = self.server_process.stdout.readline()
|
||||
if not line:
|
||||
break
|
||||
print(
|
||||
f"[{self.backend_module_name} LOG]: {line.strip()}", flush=True
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Log monitor error: {e}")
|
||||
return command
|
||||
|
||||
def _launch_server_process(self, command: list, port: int) -> None:
|
||||
"""Launch the server process."""
|
||||
project_root = Path(__file__).parent.parent.parent.parent.parent
|
||||
logger.info(f"Command: {' '.join(command)}")
|
||||
|
||||
# Let server output go directly to console
|
||||
# The server will respect LEANN_LOG_LEVEL environment variable
|
||||
self.server_process = subprocess.Popen(
|
||||
command,
|
||||
cwd=project_root,
|
||||
stdout=None, # Direct to console
|
||||
stderr=None, # Direct to console
|
||||
)
|
||||
self.server_port = port
|
||||
logger.info(f"Server process started with PID: {self.server_process.pid}")
|
||||
|
||||
# Register atexit callback only when we actually start a process
|
||||
if not self._atexit_registered:
|
||||
# Use a lambda to avoid issues with bound methods
|
||||
atexit.register(lambda: self.stop_server() if self.server_process else None)
|
||||
self._atexit_registered = True
|
||||
|
||||
def _wait_for_server_ready(self, port: int) -> tuple[bool, int]:
|
||||
"""Wait for the server to be ready."""
|
||||
max_wait, wait_interval = 120, 0.5
|
||||
for _ in range(int(max_wait / wait_interval)):
|
||||
if _check_port(port):
|
||||
logger.info("Embedding server is ready!")
|
||||
return True, port
|
||||
|
||||
if self.server_process and self.server_process.poll() is not None:
|
||||
logger.error("Server terminated during startup.")
|
||||
return False, port
|
||||
|
||||
time.sleep(wait_interval)
|
||||
|
||||
logger.error(f"Server failed to start within {max_wait} seconds.")
|
||||
self.stop_server()
|
||||
return False, port
|
||||
|
||||
def stop_server(self):
|
||||
"""Stops the embedding server process if it's running."""
|
||||
if self.server_process and self.server_process.poll() is None:
|
||||
print(
|
||||
f"INFO: Terminating session server process (PID: {self.server_process.pid})..."
|
||||
if not self.server_process:
|
||||
return
|
||||
|
||||
if self.server_process.poll() is not None:
|
||||
# Process already terminated
|
||||
self.server_process = None
|
||||
return
|
||||
|
||||
logger.info(
|
||||
f"Terminating server process (PID: {self.server_process.pid}) for backend {self.backend_module_name}..."
|
||||
)
|
||||
self.server_process.terminate()
|
||||
|
||||
try:
|
||||
self.server_process.wait(timeout=5)
|
||||
logger.info(f"Server process {self.server_process.pid} terminated.")
|
||||
except subprocess.TimeoutExpired:
|
||||
logger.warning(
|
||||
f"Server process {self.server_process.pid} did not terminate gracefully, killing it."
|
||||
)
|
||||
self.server_process.terminate()
|
||||
try:
|
||||
self.server_process.wait(timeout=5)
|
||||
print("INFO: Server process terminated.")
|
||||
except subprocess.TimeoutExpired:
|
||||
print(
|
||||
"WARNING: Server process did not terminate gracefully, killing it."
|
||||
)
|
||||
self.server_process.kill()
|
||||
self.server_process.kill()
|
||||
|
||||
# Clean up process resources to prevent resource tracker warnings
|
||||
try:
|
||||
self.server_process.wait() # Ensure process is fully cleaned up
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
self.server_process = None
|
||||
|
||||
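Since `start_server` now reports the port it actually bound (it may reuse a compatible server or skip past an incompatible one), callers should always work with the returned port. A minimal usage sketch, assuming a hypothetical backend module name, model, and meta-file path:

```python
# Sketch only: the module name, model, and meta-file path below are illustrative.
manager = EmbeddingServerManager(
    backend_module_name="leann_backend_hnsw.hnsw_embedding_server"
)

ok, actual_port = manager.start_server(
    port=5557,                                 # preferred port, not guaranteed
    model_name="facebook/contriever-msmarco",  # example embedding model
    passages_file="my_index.leann.meta.json",  # used for compatibility checks
)
if not ok:
    raise RuntimeError(f"Embedding server failed to start (tried port {actual_port})")

# ... send ZMQ requests to tcp://localhost:{actual_port} ...

manager.stop_server()  # only terminates a process this manager actually owns
```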
@@ -1,6 +1,6 @@
|
||||
from abc import ABC, abstractmethod
|
||||
import numpy as np
|
||||
from typing import Dict, Any, List, Literal
|
||||
from typing import Dict, Any, List, Literal, Optional
|
||||
|
||||
|
||||
class LeannBackendBuilderInterface(ABC):
|
||||
@@ -34,6 +34,13 @@ class LeannBackendSearcherInterface(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _ensure_server_running(
|
||||
self, passages_source_file: str, port: Optional[int], **kwargs
|
||||
) -> int:
|
||||
"""Ensure server is running"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def search(
|
||||
self,
|
||||
@@ -44,7 +51,7 @@ class LeannBackendSearcherInterface(ABC):
|
||||
prune_ratio: float = 0.0,
|
||||
recompute_embeddings: bool = False,
|
||||
pruning_strategy: Literal["global", "local", "proportional"] = "global",
|
||||
zmq_port: int = 5557,
|
||||
zmq_port: Optional[int] = None,
|
||||
**kwargs,
|
||||
) -> Dict[str, Any]:
|
||||
"""Search for nearest neighbors
|
||||
@@ -57,7 +64,7 @@ class LeannBackendSearcherInterface(ABC):
|
||||
prune_ratio: Ratio of neighbors to prune via approximate distance (0.0-1.0)
|
||||
recompute_embeddings: Whether to fetch fresh embeddings from server vs use stored PQ codes
|
||||
pruning_strategy: PQ candidate selection strategy - "global" (default), "local", or "proportional"
|
||||
zmq_port: ZMQ port for embedding server communication
|
||||
zmq_port: ZMQ port for embedding server communication. Must be provided if recompute_embeddings is True.
|
||||
**kwargs: Backend-specific parameters
|
||||
|
||||
Returns:
|
||||
@@ -67,7 +74,10 @@ class LeannBackendSearcherInterface(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def compute_query_embedding(
|
||||
self, query: str, zmq_port: int = 5557, use_server_if_available: bool = True
|
||||
self,
|
||||
query: str,
|
||||
use_server_if_available: bool = True,
|
||||
zmq_port: Optional[int] = None,
|
||||
) -> np.ndarray:
|
||||
"""Compute embedding for a query string
|
||||
|
||||
|
||||
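The interface change above makes `zmq_port` optional and adds `_ensure_server_running`, which returns the port that ended up being used. A minimal sketch of how a concrete searcher might honor that contract; the class name and result shape are placeholders, not taken from this diff:

```python
# Sketch only: illustrates the Optional[int] zmq_port contract, not a real backend.
class ExampleSearcher(BaseSearcher):
    def search(
        self,
        query: str,
        top_k: int,
        recompute_embeddings: bool = False,
        zmq_port: Optional[int] = None,
        **kwargs,
    ) -> Dict[str, Any]:
        if recompute_embeddings and zmq_port is None:
            meta_file = self.index_dir / f"{self.index_path.name}.meta.json"
            # BaseSearcher._ensure_server_running returns the port actually in use
            zmq_port = self._ensure_server_running(str(meta_file.resolve()), port=5557)
        # ... backend-specific traversal; contact the server on zmq_port when recomputing ...
        return {"labels": [], "distances": []}  # placeholder result shape
```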
@@ -7,30 +7,37 @@ import importlib.metadata
|
||||
if TYPE_CHECKING:
|
||||
from leann.interface import LeannBackendFactoryInterface
|
||||
|
||||
BACKEND_REGISTRY: Dict[str, 'LeannBackendFactoryInterface'] = {}
|
||||
BACKEND_REGISTRY: Dict[str, "LeannBackendFactoryInterface"] = {}
|
||||
|
||||
|
||||
def register_backend(name: str):
|
||||
"""A decorator to register a new backend class."""
|
||||
|
||||
def decorator(cls):
|
||||
print(f"INFO: Registering backend '{name}'")
|
||||
BACKEND_REGISTRY[name] = cls
|
||||
return cls
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def autodiscover_backends():
|
||||
"""Automatically discovers and imports all 'leann-backend-*' packages."""
|
||||
print("INFO: Starting backend auto-discovery...")
|
||||
# print("INFO: Starting backend auto-discovery...")
|
||||
discovered_backends = []
|
||||
for dist in importlib.metadata.distributions():
|
||||
dist_name = dist.metadata['name']
|
||||
if dist_name.startswith('leann-backend-'):
|
||||
backend_module_name = dist_name.replace('-', '_')
|
||||
dist_name = dist.metadata["name"]
|
||||
if dist_name.startswith("leann-backend-"):
|
||||
backend_module_name = dist_name.replace("-", "_")
|
||||
discovered_backends.append(backend_module_name)
|
||||
|
||||
for backend_module_name in sorted(discovered_backends): # sort for deterministic loading
|
||||
|
||||
for backend_module_name in sorted(
|
||||
discovered_backends
|
||||
): # sort for deterministic loading
|
||||
try:
|
||||
importlib.import_module(backend_module_name)
|
||||
# Registration message is printed by the decorator
|
||||
except ImportError as e:
|
||||
print(f"WARN: Could not import backend module '{backend_module_name}': {e}")
|
||||
print("INFO: Backend auto-discovery finished.")
|
||||
# print(f"WARN: Could not import backend module '{backend_module_name}': {e}")
|
||||
pass
|
||||
# print("INFO: Backend auto-discovery finished.")
|
||||
|
||||
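For context, the registry above is driven by the `@register_backend` decorator plus package auto-discovery. A rough usage sketch; the import path and factory class are assumptions for illustration only:

```python
# Sketch only: "leann.registry" and HNSWFactory are assumed names.
from leann.registry import BACKEND_REGISTRY, autodiscover_backends, register_backend

@register_backend("hnsw")      # stores the class under BACKEND_REGISTRY["hnsw"]
class HNSWFactory:             # would implement LeannBackendFactoryInterface
    pass

autodiscover_backends()        # imports every installed leann-backend-* distribution
factory_cls = BACKEND_REGISTRY["hnsw"]
```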
@@ -1,8 +1,7 @@
|
||||
import json
|
||||
import pickle
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, Literal
|
||||
from typing import Dict, Any, Literal, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -43,10 +42,10 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
|
||||
"WARNING: embedding_model not found in meta.json. Recompute will fail."
|
||||
)
|
||||
|
||||
self.label_map = self._load_label_map()
|
||||
self.embedding_mode = self.meta.get("embedding_mode", "sentence-transformers")
|
||||
|
||||
self.embedding_server_manager = EmbeddingServerManager(
|
||||
backend_module_name=backend_module_name
|
||||
backend_module_name=backend_module_name,
|
||||
)
|
||||
|
||||
def _load_meta(self) -> Dict[str, Any]:
|
||||
@@ -58,17 +57,9 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
|
||||
with open(meta_path, "r", encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
|
||||
def _load_label_map(self) -> Dict[int, str]:
|
||||
"""Loads the mapping from integer IDs to string IDs."""
|
||||
label_map_file = self.index_dir / "leann.labels.map"
|
||||
if not label_map_file.exists():
|
||||
raise FileNotFoundError(f"Label map file not found: {label_map_file}")
|
||||
with open(label_map_file, "rb") as f:
|
||||
return pickle.load(f)
|
||||
|
||||
def _ensure_server_running(
|
||||
self, passages_source_file: str, port: int, **kwargs
|
||||
) -> None:
|
||||
) -> int:
|
||||
"""
|
||||
Ensures the embedding server is running if recompute is needed.
|
||||
This is a helper for subclasses.
|
||||
@@ -78,21 +69,26 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
|
||||
"Cannot use recompute mode without 'embedding_model' in meta.json."
|
||||
)
|
||||
|
||||
embedding_mode = self.meta.get("embedding_mode", "sentence-transformers")
|
||||
|
||||
server_started = self.embedding_server_manager.start_server(
|
||||
server_started, actual_port = self.embedding_server_manager.start_server(
|
||||
port=port,
|
||||
model_name=self.embedding_model,
|
||||
embedding_mode=self.embedding_mode,
|
||||
passages_file=passages_source_file,
|
||||
distance_metric=kwargs.get("distance_metric"),
|
||||
embedding_mode=embedding_mode,
|
||||
enable_warmup=kwargs.get("enable_warmup", False),
|
||||
)
|
||||
if not server_started:
|
||||
raise RuntimeError(f"Failed to start embedding server on port {port}")
|
||||
raise RuntimeError(
|
||||
f"Failed to start embedding server on port {actual_port}"
|
||||
)
|
||||
|
||||
return actual_port
|
||||
|
||||
def compute_query_embedding(
|
||||
self, query: str, zmq_port: int = 5557, use_server_if_available: bool = True
|
||||
self,
|
||||
query: str,
|
||||
use_server_if_available: bool = True,
|
||||
zmq_port: int = 5557,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Compute embedding for a query string.
|
||||
@@ -106,12 +102,21 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
|
||||
Query embedding as numpy array
|
||||
"""
|
||||
# Try to use embedding server if available and requested
|
||||
if (
|
||||
use_server_if_available
|
||||
and self.embedding_server_manager
|
||||
and self.embedding_server_manager.server_process
|
||||
):
|
||||
if use_server_if_available:
|
||||
try:
|
||||
# TODO: Maybe we can directly use this port here?
|
||||
# For this internal method, it's ok to assume that the server is running
|
||||
# on that port?
|
||||
|
||||
# Ensure we have a server with passages_file for compatibility
|
||||
passages_source_file = (
|
||||
self.index_dir / f"{self.index_path.name}.meta.json"
|
||||
)
|
||||
# Convert to absolute path to ensure server can find it
|
||||
zmq_port = self._ensure_server_running(
|
||||
str(passages_source_file.resolve()), zmq_port
|
||||
)
|
||||
|
||||
return self._compute_embedding_via_server([query], zmq_port)[
|
||||
0:1
|
||||
] # Return (1, D) shape
|
||||
@@ -120,7 +125,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
|
||||
print("⏭️ Falling back to direct model loading...")
|
||||
|
||||
# Fallback to direct computation
|
||||
from .api import compute_embeddings
|
||||
from .embedding_compute import compute_embeddings
|
||||
|
||||
embedding_mode = self.meta.get("embedding_mode", "sentence-transformers")
|
||||
return compute_embeddings([query], self.embedding_model, embedding_mode)
|
||||
@@ -167,7 +172,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
|
||||
prune_ratio: float = 0.0,
|
||||
recompute_embeddings: bool = False,
|
||||
pruning_strategy: Literal["global", "local", "proportional"] = "global",
|
||||
zmq_port: int = 5557,
|
||||
zmq_port: Optional[int] = None,
|
||||
**kwargs,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
@@ -181,7 +186,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
|
||||
prune_ratio: Ratio of neighbors to prune via approximate distance (0.0-1.0)
|
||||
recompute_embeddings: Whether to fetch fresh embeddings from server vs use stored PQ codes
|
||||
pruning_strategy: PQ candidate selection strategy - "global" (default), "local", or "proportional"
|
||||
zmq_port: ZMQ port for embedding server communication
|
||||
zmq_port: ZMQ port for embedding server communication. Must be provided if recompute_embeddings is True.
|
||||
**kwargs: Backend-specific parameters (e.g., batch_size, dedup_node_dis, etc.)
|
||||
|
||||
Returns:
|
||||
|
||||
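Net effect for callers of the public API: `zmq_port` no longer needs to be hard-coded, since the backend resolves (or starts) a compatible embedding server on demand. A hedged sketch of the intended call pattern, with a placeholder index path and query, assuming the searcher forwards these keyword arguments to the backend:

```python
# Sketch only: demonstrates that zmq_port can now be omitted by callers.
searcher = LeannSearcher("my_index.leann")      # placeholder index path
results = searcher.search(
    "how much storage does LEANN save?",
    top_k=3,
    recompute_embeddings=True,  # embeddings are fetched from the session server
    # zmq_port omitted: the searcher finds or starts a compatible server itself
)
```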
40  packages/leann/README.md  Normal file
@@ -0,0 +1,40 @@
# LEANN - The smallest vector index in the world

LEANN is a revolutionary vector database that democratizes personal AI. Transform your laptop into a powerful RAG system that can index and search through millions of documents while using **97% less storage** than traditional solutions **without accuracy loss**.

## Installation

```bash
# Default installation (HNSW backend, recommended)
uv pip install leann

# With DiskANN backend (for large-scale deployments)
uv pip install leann[diskann]
```

## Quick Start

```python
from leann import LeannBuilder, LeannSearcher, LeannChat

# Build an index
builder = LeannBuilder(backend_name="hnsw")
builder.add_text("LEANN saves 97% storage compared to traditional vector databases.")
builder.build_index("my_index.leann")

# Search
searcher = LeannSearcher("my_index.leann")
results = searcher.search("storage savings", top_k=3)

# Chat with your data
chat = LeannChat("my_index.leann", llm_config={"type": "ollama", "model": "llama3.2:1b"})
response = chat.ask("How much storage does LEANN save?")
```

## Documentation

For full documentation, visit [https://leann.readthedocs.io](https://leann.readthedocs.io)

## License

MIT License
12  packages/leann/__init__.py  Normal file
@@ -0,0 +1,12 @@
"""
LEANN - Low-storage Embedding Approximation for Neural Networks

A revolutionary vector database that democratizes personal AI.
"""

__version__ = "0.1.0"

# Re-export main API from leann-core
from leann_core import LeannBuilder, LeannSearcher, LeannChat

__all__ = ["LeannBuilder", "LeannSearcher", "LeannChat"]
42  packages/leann/pyproject.toml  Normal file
@@ -0,0 +1,42 @@
[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"

[project]
name = "leann"
version = "0.1.8"
description = "LEANN - The smallest vector index in the world. RAG Everything with LEANN!"
readme = "README.md"
requires-python = ">=3.9"
license = { text = "MIT" }
authors = [
    { name = "LEANN Team" }
]
keywords = ["vector-database", "rag", "embeddings", "search", "ai"]
classifiers = [
    "Development Status :: 4 - Beta",
    "Intended Audience :: Developers",
    "License :: OSI Approved :: MIT License",
    "Programming Language :: Python :: 3",
    "Programming Language :: Python :: 3.9",
    "Programming Language :: Python :: 3.10",
    "Programming Language :: Python :: 3.11",
    "Programming Language :: Python :: 3.12",
]

# Default installation: core + hnsw
dependencies = [
    "leann-core>=0.1.0",
    "leann-backend-hnsw>=0.1.0",
]

[project.optional-dependencies]
diskann = [
    "leann-backend-diskann>=0.1.0",
]

[project.urls]
Homepage = "https://github.com/yourusername/leann"
Documentation = "https://leann.readthedocs.io"
Repository = "https://github.com/yourusername/leann"
Issues = "https://github.com/yourusername/leann/issues"
@@ -33,8 +33,9 @@ dependencies = [
    "msgpack>=1.1.1",
    "llama-index-vector-stores-faiss>=0.4.0",
    "llama-index-embeddings-huggingface>=0.5.5",
    "mlx>=0.26.3",
    "mlx-lm>=0.26.0",
    "mlx>=0.26.3; sys_platform == 'darwin'",
    "mlx-lm>=0.26.0; sys_platform == 'darwin'",
    "psutil>=5.8.0",
]

[project.optional-dependencies]
@@ -59,3 +60,29 @@ py-modules = []
leann-core = { path = "packages/leann-core", editable = true }
leann-backend-diskann = { path = "packages/leann-backend-diskann", editable = true }
leann-backend-hnsw = { path = "packages/leann-backend-hnsw", editable = true }

[tool.cibuildwheel]
# Skip 32-bit and PyPy builds
skip = "*-win32 *-manylinux_i686 pp* *musllinux*"

# Use manylinux_2_35 for Colab compatibility while keeping modern features
manylinux-x86_64-image = "manylinux_2_35"
manylinux-aarch64-image = "manylinux_2_35"

# Linux system dependencies
[tool.cibuildwheel.linux]
before-all = """
yum install -y epel-release
yum install -y gcc-c++ boost-devel zeromq-devel openblas-devel cmake3 python3-devel
ln -sf /usr/bin/cmake3 /usr/bin/cmake
"""

# macOS system dependencies
[tool.cibuildwheel.macos]
before-all = "brew install boost zeromq openblas cmake libomp"
# Set minimum macOS version
environment = { MACOSX_DEPLOYMENT_TARGET = "11.0", CMAKE_OSX_DEPLOYMENT_TARGET = "11.0" }

# Environment variables configuration
[tool.cibuildwheel.environment]
CMAKE_BUILD_PARALLEL_LEVEL = "8"
@@ -1,12 +0,0 @@
|
||||
import faiss
|
||||
hnsw_index = faiss.read_index("/opt/dlami/nvme/scaling_out/indices/rpj_wiki/facebook/contriever-msmarco/hnsw/hnsw_IP_M30_efC128.index", faiss.IO_FLAG_ONDISK_SAME_DIR)
|
||||
|
||||
# print total number of nodes
|
||||
print(hnsw_index.ntotal)
|
||||
|
||||
# print stats of the graph
|
||||
print(hnsw_index.hnsw.print_neighbor_stats(0))
|
||||
|
||||
|
||||
# save_degree_distribution
|
||||
hnsw_index.hnsw.save_degree_distribution(0, "degree_distribution_HNSW_M30.txt")
|
||||
@@ -1,11 +0,0 @@
|
||||
import faiss
|
||||
nsg_index = faiss.read_index("/opt/dlami/nvme/scaling_out/indices/rpj_wiki/facebook/contriever-msmarco/nsg_R16.index", faiss.IO_FLAG_ONDISK_SAME_DIR)
|
||||
|
||||
# print total number of nodes
|
||||
print(nsg_index.ntotal)
|
||||
|
||||
# print stats of the graph
|
||||
print(nsg_index.nsg.print_neighbor_stats(0))
|
||||
|
||||
# save degree distribution
|
||||
nsg_index.nsg.save_degree_distribution("degree_distribution_NSG_R60.txt")
|
||||
@@ -1,63 +0,0 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import time
|
||||
|
||||
# import bitsandbytes as bnb
|
||||
from bitsandbytes.nn import Linear8bitLt
|
||||
|
||||
# set default to half
|
||||
import torch
|
||||
torch.set_default_dtype(torch.float16)
|
||||
|
||||
M = 2048
|
||||
N = 2048
|
||||
|
||||
bsz = 2048
|
||||
import torch_int
|
||||
from torch_int.nn.linear import W8A8BFP32OFP32Linear, W8A8B8O8Linear, W8A8B8O8LinearReLU
|
||||
|
||||
fp16_model = nn.Sequential(
|
||||
nn.Linear(M, N),
|
||||
# nn.Linear(2048, 2048)
|
||||
)
|
||||
|
||||
int8_model = nn.Sequential(
|
||||
Linear8bitLt(M, N, has_fp16_weights=False),
|
||||
# Linear8bitLt(2048, 2048, has_fp16_weights=False)
|
||||
)
|
||||
|
||||
int8_model.load_state_dict(fp16_model.state_dict())
|
||||
int8_model = int8_model.to(0) # Quantization happens here
|
||||
fp16_model = fp16_model.to(0) # Move fp16 model to GPU as well
|
||||
|
||||
# Create random input tensor
|
||||
input_tensor = torch.randn(bsz, M, device=0) # Batch of 1000 vectors
|
||||
|
||||
# Speed test function
|
||||
def speed_test(model, input_tensor, name, num_iterations=100):
|
||||
# Warmup
|
||||
for _ in range(10):
|
||||
_ = model(input_tensor)
|
||||
|
||||
# Actual timing
|
||||
torch.cuda.synchronize()
|
||||
start_time = time.time()
|
||||
|
||||
for _ in range(num_iterations):
|
||||
_ = model(input_tensor)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
end_time = time.time()
|
||||
|
||||
avg_time = (end_time - start_time) / num_iterations
|
||||
print(f"{name} model: {avg_time:.6f} seconds per iteration")
|
||||
return avg_time
|
||||
|
||||
# Run speed tests
|
||||
with torch.no_grad(): # Disable gradient calculation for inference
|
||||
fp16_time = speed_test(fp16_model, input_tensor, "FP16")
|
||||
int8_time = speed_test(int8_model, input_tensor, "INT8")
|
||||
|
||||
# Calculate speedup
|
||||
speedup = fp16_time / int8_time
|
||||
print(f"INT8 is {speedup:.2f}x faster than FP16")
|
||||
@@ -1,89 +0,0 @@
|
||||
n,d,seqlen,bs,latency,h,flop,io,intensity,throughput,series
|
||||
3,256,256,2048,0.009623501679245285,768,618475290624,167.48502132816208,3692720015.912285,64267177503366.266,dense
|
||||
3,256,256,1024,0.004853848615384615,768,309237645312,166.15392854317415,1861151572.059558,63709783682138.234,dense
|
||||
3,256,256,512,0.0024687246971962615,768,154618822656,163.57953256539062,945221081.3366361,62631051097597.516,dense
|
||||
3,256,256,256,0.0012845360838052097,768,77309411328,157.64931990085577,490388486.1451936,60184694149645.54,dense
|
||||
3,256,256,128,0.0006901147179878049,768,38654705664,147.57393422494675,261934506.70684624,56012000116019.945,dense
|
||||
3,256,256,64,0.0003363830693015702,768,19327352832,153.1328437752606,126212981.84970059,57456378146882.51,dense
|
||||
3,256,256,32,0.00018671159748991485,768,9663676416,141.10249365427362,68486928.65540518,51757237075334.75,dense
|
||||
3,256,256,16,0.00012353640857142858,768,4831838208,111.40488993609125,43371868.24359184,39112665358133.98,dense
|
||||
3,256,256,8,9.774760007849294e-05,768,2415919104,76.43260800265766,31608487.09906635,24715891766754.14,dense
|
||||
3,256,256,4,6.672271167474822e-05,768,1207959552,64.82614227498455,18633833.660438772,18104173551704.773,dense
|
||||
3,256,256,2,4.9758770289855074e-05,768,603979776,55.317122669351576,10918495.880745342,12138157202874.861,dense
|
||||
3,256,1,2048,9.785507940251571e-05,768,2415919104,76.34865809334705,31643242.518371396,24688745017132.86,dense
|
||||
3,256,1,1024,6.692813470149253e-05,768,1207959552,64.62717090938949,18691202.70936228,18048606275785.867,dense
|
||||
3,256,1,512,4.9680950036205655e-05,768,603979776,55.40377142534654,10901419.893658841,12157170415618.898,dense
|
||||
3,256,1,256,4.2781118741058655e-05,768,301989888,45.95672244805227,6571179.83862661,7058952568020.829,dense
|
||||
3,256,1,128,5.0662328255350016e-05,768,150994944,31.046026784880404,4863583.512513602,2980418571348.519,dense
|
||||
3,256,1,64,4.475009253945481e-05,768,75497472,30.75426042497223,2454862.219307235,1687090857598.4766,dense
|
||||
3,256,1,32,4.51682671454219e-05,768,37748736,28.29313765537115,1334201.1218340008,835735758435.5786,dense
|
||||
3,256,1,16,5.03585186661834e-05,768,18874368,24.401035466223117,773506.846712577,374799904761.1871,dense
|
||||
3,256,1,8,5.023459565217391e-05,768,9437184,23.972005435021096,393675.19858030166,187862246674.45105,dense
|
||||
3,256,1,4,5.053219391083726e-05,768,4718592,23.58765586356967,200044.97383259286,93377936614.54384,dense
|
||||
3,256,1,2,4.4607398995335484e-05,768,2359296,26.58285456464288,88752.54515134107,52890239133.797226,dense
|
||||
12,256,256,2048,0.14480779847058822,3072,9895604649984,44.620009282941716,221775046868.20184,68336130750540.26,dense
|
||||
12,256,256,1024,0.07254347629166667,3072,4947802324992,44.664248332585096,110777691547.58836,68204648824643.82,dense
|
||||
12,256,256,512,0.036310761444444443,3072,2473901162496,44.876147984203506,55127306456.13385,68131349056975.164,dense
|
||||
12,256,256,256,0.01821551906896552,3072,1236950581248,45.24607467289738,27338295977.947884,67906414116709.98,dense
|
||||
12,256,256,128,0.009229417903030302,3072,618475290624,45.67217092440895,13541622351.335684,67011299859001.46,dense
|
||||
12,256,256,64,0.004754550595394737,3072,309237645312,46.31372736116993,6677019167.566916,65040352207320.695,dense
|
||||
12,256,256,32,0.002405752659340659,3072,154618822656,49.68826015254682,3111777755.5766335,64270456921525.82,dense
|
||||
12,256,256,16,0.0012287219045005488,3072,77309411328,56.323579604557374,1372594069.3184311,62918558743709.18,dense
|
||||
12,256,256,8,0.0006206816149425287,3072,38654705664,70.95456179103653,544781120.315271,62277832520589.78,dense
|
||||
12,256,256,4,0.0003875502697142857,3072,19327352832,81.16954743236613,238110885.71245712,49870569942445.75,dense
|
||||
12,256,256,2,0.00027502018627941914,3072,9663676416,91.50537035282076,105607751.53129694,35138062215483.168,dense
|
||||
12,256,1,2048,0.0006202853873290136,3072,38654705664,70.99988634205897,544433345.6784943,62317614526515.766,dense
|
||||
12,256,1,1024,0.00038721467732724153,3072,19327352832,81.2398957010995,237904697.74985722,49913791918755.53,dense
|
||||
12,256,1,512,0.000274364799,3072,9663676416,91.72395326121995,105356082.81599998,35221998052308.45,dense
|
||||
12,256,1,256,0.00012488918589482266,3072,4831838208,176.31707535146046,27404255.647778228,38689003962834.75,dense
|
||||
12,256,1,128,8.976711102514506e-05,3072,2415919104,227.78088507574267,10606329.425740216,26913187652026.21,dense
|
||||
12,256,1,64,8.715176287471176e-05,3072,1207959552,225.59268282689945,5354604.31102229,13860414432884.701,dense
|
||||
12,256,1,32,8.523013435114503e-05,3072,603979776,226.06539514085782,2671703.8033338524,7086458100741.991,dense
|
||||
12,256,1,16,7.901561645904116e-05,3072,301989888,241.35704882952732,1251216.3595988373,3821901309300.556,dense
|
||||
12,256,1,8,7.827949114210329e-05,3072,150994944,242.37091635608994,622991.1833900034,1928920867994.581,dense
|
||||
12,256,1,4,7.779445951035782e-05,3072,75497472,243.25022783249054,310369.58391664835,970473636235.5986,dense
|
||||
12,256,1,2,7.758845406626506e-05,3072,37748736,243.57933441822672,154975.11761480253,486525172518.07056,dense
|
||||
3,256,256,2048,0.00507974918466899,768,206158430208,475.59810852303485,433471930.42508715,40584371927298.98,qk_init
|
||||
3,256,256,1024,0.0025616677649325623,768,103079215104,471.5519977009198,218595649.27424532,40239103803811.82,qk_init
|
||||
3,256,256,512,0.0013029336670480549,768,51539607552,463.55374128015677,111183672.92143403,39556585922573.38,qk_init
|
||||
3,256,256,256,0.0006738189029345373,768,25769803776,448.1766342333362,57499213.050413854,38244406121244.69,qk_init
|
||||
3,256,256,128,0.000358254672959467,768,12884901888,421.47375986100144,30571065.425874516,35965760841472.125,qk_init
|
||||
3,256,256,64,0.0002007051105022831,768,6442450944,376.1611839930762,17126836.096194826,32099087700742.5,qk_init
|
||||
3,256,256,32,0.00012189697230142565,768,3221225472,309.6773881032524,10401874.969721656,26425803784810.87,qk_init
|
||||
3,256,256,16,8.453561698040722e-05,768,1610612736,223.2711923587723,7213705.982328083,19052475081281.902,qk_init
|
||||
3,256,256,8,6.407660705009276e-05,768,805306368,147.2797083750448,5467870.468274581,12567868448003.822,qk_init
|
||||
3,256,256,4,5.036328747284576e-05,768,402653184,93.69110391262903,4297667.197682838,7994974200544.344,qk_init
|
||||
3,256,256,2,4.5488761135057476e-05,768,201326592,51.865470527877875,3881707.616858238,4425853485045.578,qk_init
|
||||
12,256,256,2048,0.020202365999999996,3072,824633720832,478.3437947812648,1723935231.9999998,40818670488001.266,qk_init
|
||||
12,256,256,1024,0.010124155888157895,3072,412316860416,477.2583770318811,863927969.1228071,40726048173387.19,qk_init
|
||||
12,256,256,512,0.005085633937062937,3072,206158430208,475.04777848703077,433974095.9627039,40537410430893.29,qk_init
|
||||
12,256,256,256,0.0025654916853281853,3072,103079215104,470.84913933193053,218921957.14800516,40179126556324.74,qk_init
|
||||
12,256,256,128,0.0013045765704467354,3072,51539607552,462.9699702434292,111323867.34478809,39506770794105.96,qk_init
|
||||
12,256,256,64,0.0006742801519939804,3072,25769803776,447.87005387442576,57538572.970153,38218244597284.33,qk_init
|
||||
12,256,256,32,0.00035831976790671853,3072,12884901888,421.3971919051604,30576620.194706645,35959227042573.69,qk_init
|
||||
12,256,256,16,0.0002005369068918302,3072,6442450944,376.4766953382971,17112482.721436176,32126011335534.68,qk_init
|
||||
12,256,256,8,0.00012179187250509165,3072,3221225472,309.94462293386505,10392906.453767821,26448607823689.82,qk_init
|
||||
12,256,256,4,8.452507263643351e-05,3072,1610612736,223.2990450204527,7212806.198308992,19054851841745.297,qk_init
|
||||
12,256,256,2,6.412381767545489e-05,3072,805306368,147.17127491946468,5471899.108305484,12558615459794.32,qk_init
|
||||
3,256,256,2048,0.0016183739398395718,768,805306368,811597824.0,0.9922480620155039,1265467.7325087283,qk_ar
|
||||
3,256,256,1024,0.0008322699728813558,768,402653184,405798912.0,0.9922480620155039,1230369.9921491416,qk_ar
|
||||
3,256,256,512,0.00043886859397590365,768,201326592,202899456.0,0.9922480620155039,1166636.2255762408,qk_ar
|
||||
3,256,256,256,0.00024185948322147648,768,100663296,101449728.0,0.9922480620155039,1058465.8355760013,qk_ar
|
||||
3,256,256,128,0.00014308985100166944,768,50331648,50724864.0,0.9922480620155039,894542.82818777,qk_ar
|
||||
3,256,256,64,9.382939365815932e-05,768,25165824,25362432.0,0.9922480620155039,682089.028872613,qk_ar
|
||||
3,256,256,32,6.856070612244899e-05,768,12582912,12681216.0,0.9922480620155039,466739.6503012703,qk_ar
|
||||
3,256,256,16,5.452260553129549e-05,768,6291456,6340608.0,0.9922480620155039,293456.26174846216,qk_ar
|
||||
3,256,256,8,4.608557533261417e-05,768,3145728,3170304.0,0.9922480620155039,173590.1080166944,qk_ar
|
||||
3,256,256,4,4.386146957766642e-05,768,1572864,1585152.0,0.9922480620155039,91196.21477609445,qk_ar
|
||||
3,256,256,2,4.330941094420601e-05,768,786432,792576.0,0.9922480620155039,46179.33969539622,qk_ar
|
||||
12,256,256,2048,0.006347041645299144,3072,3221225472,3246391296.0,0.9922480620155039,322670.011392918,qk_ar
|
||||
12,256,256,1024,0.0031943104467592586,3072,1610612736,1623195648.0,0.9922480620155039,320569.96872013,qk_ar
|
||||
12,256,256,512,0.0016183416350267381,3072,805306368,811597824.0,0.9922480620155039,316373.2483416833,qk_ar
|
||||
12,256,256,256,0.0008325934893977947,3072,402653184,405798912.0,0.9922480620155039,307472.9784221131,qk_ar
|
||||
12,256,256,128,0.0004389725746987952,3072,201326592,202899456.0,0.9922480620155039,291589.9702568624,qk_ar
|
||||
12,256,256,64,0.00024191767449664432,3072,100663296,101449728.0,0.9922480620155039,264552.8076159138,qk_ar
|
||||
12,256,256,32,0.0001431546143572621,3072,50331648,50724864.0,0.9922480620155039,223534.53392804778,qk_ar
|
||||
12,256,256,16,9.404283597678917e-05,3072,25165824,25362432.0,0.9922480620155039,170135.23501087292,qk_ar
|
||||
12,256,256,8,6.855550037091989e-05,3072,12582912,12681216.0,0.9922480620155039,116693.773026467,qk_ar
|
||||
12,256,256,4,5.4802094978165945e-05,3072,6291456,6340608.0,0.9922480620155039,72989.91036006316,qk_ar
|
||||
12,256,256,2,4.608510707869206e-05,3072,3145728,3170304.0,0.9922480620155039,43397.96795057727,qk_ar
|
||||
|
Binary file not shown. (Before: 45 KiB image)
@@ -1,594 +0,0 @@
|
||||
# python embedd_micro.py --use_int8 Fastest
|
||||
|
||||
import argparse
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
from torchao import quantize_
|
||||
from transformers import AutoModel, BitsAndBytesConfig
|
||||
from tqdm import tqdm
|
||||
from contextlib import contextmanager
|
||||
|
||||
@dataclass
|
||||
class BenchmarkConfig:
|
||||
model_path: str
|
||||
batch_sizes: List[int]
|
||||
seq_length: int
|
||||
num_runs: int
|
||||
use_fp16: bool = True
|
||||
use_int4: bool = False
|
||||
use_int8: bool = False # Add this parameter
|
||||
use_cuda_graphs: bool = False
|
||||
use_flash_attention: bool = False
|
||||
use_linear8bitlt: bool = False
|
||||
|
||||
|
||||
class CUDAGraphContainer:
|
||||
"""Container for managing CUDA graphs for different batch sizes."""
|
||||
|
||||
def __init__(self, model: nn.Module, seq_length: int):
|
||||
self.model = model
|
||||
self.seq_length = seq_length
|
||||
self.graphs: Dict[int, CUDAGraphWrapper] = {}
|
||||
|
||||
def get_or_create(self, batch_size: int) -> 'CUDAGraphWrapper':
|
||||
if batch_size not in self.graphs:
|
||||
self.graphs[batch_size] = CUDAGraphWrapper(
|
||||
self.model, batch_size, self.seq_length
|
||||
)
|
||||
return self.graphs[batch_size]
|
||||
|
||||
|
||||
class CUDAGraphWrapper:
|
||||
"""Wrapper for CUDA graph capture and replay."""
|
||||
|
||||
def __init__(self, model: nn.Module, batch_size: int, seq_length: int):
|
||||
self.model = model
|
||||
self.static_input = self._create_random_batch(batch_size, seq_length)
|
||||
self.static_attention_mask = torch.ones_like(self.static_input)
|
||||
|
||||
# Warm up
|
||||
self._warmup()
|
||||
|
||||
# Capture graph
|
||||
self.graph = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(self.graph):
|
||||
self.static_output = self.model(
|
||||
input_ids=self.static_input,
|
||||
attention_mask=self.static_attention_mask
|
||||
)
|
||||
|
||||
def _create_random_batch(self, batch_size: int, seq_length: int) -> torch.Tensor:
|
||||
return torch.randint(
|
||||
0, 1000, (batch_size, seq_length),
|
||||
device="cuda",
|
||||
dtype=torch.long
|
||||
)
|
||||
|
||||
def _warmup(self, num_warmup: int = 3):
|
||||
with torch.no_grad():
|
||||
for _ in range(num_warmup):
|
||||
self.model(
|
||||
input_ids=self.static_input,
|
||||
attention_mask=self.static_attention_mask
|
||||
)
|
||||
|
||||
def __call__(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
|
||||
self.static_input.copy_(input_ids)
|
||||
self.static_attention_mask.copy_(attention_mask)
|
||||
self.graph.replay()
|
||||
return self.static_output
|
||||
|
||||
|
||||
class ModelOptimizer:
|
||||
"""Applies various optimizations to the model."""
|
||||
|
||||
@staticmethod
|
||||
def optimize(model: nn.Module, config: BenchmarkConfig) -> nn.Module:
|
||||
print("\nApplying model optimizations:")
|
||||
|
||||
if model is None:
|
||||
raise ValueError("Cannot optimize None model")
|
||||
|
||||
# Move to GPU
|
||||
model = model.cuda()
|
||||
print("- Model moved to GPU")
|
||||
|
||||
# FP16
|
||||
if config.use_fp16 and not config.use_int4:
|
||||
model = model.half()
|
||||
# use torch compile
|
||||
model = torch.compile(model)
|
||||
print("- Using FP16 precision")
|
||||
|
||||
# Check if using SDPA
|
||||
if torch.version.cuda and float(torch.version.cuda[:3]) >= 11.6:
|
||||
if hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
|
||||
print("- Using PyTorch SDPA (scaled_dot_product_attention)")
|
||||
else:
|
||||
print("- PyTorch SDPA not available")
|
||||
|
||||
# Flash Attention
|
||||
if config.use_flash_attention:
|
||||
try:
|
||||
from flash_attn.flash_attention import FlashAttention
|
||||
print("- Flash Attention 2 available")
|
||||
if hasattr(model.config, "attention_mode"):
|
||||
model.config.attention_mode = "flash_attention_2"
|
||||
print(" - Enabled Flash Attention 2 mode")
|
||||
except ImportError:
|
||||
print("- Flash Attention not available")
|
||||
|
||||
# Memory efficient attention
|
||||
try:
|
||||
from xformers.ops import memory_efficient_attention
|
||||
if hasattr(model, 'enable_xformers_memory_efficient_attention'):
|
||||
model.enable_xformers_memory_efficient_attention()
|
||||
print("- Enabled xformers memory efficient attention")
|
||||
else:
|
||||
print("- Model doesn't support xformers")
|
||||
except (ImportError, AttributeError):
|
||||
print("- Xformers not available")
|
||||
|
||||
model.eval()
|
||||
print("- Model set to eval mode")
|
||||
|
||||
return model
|
||||
|
||||
|
||||
class Timer:
|
||||
"""Handles accurate GPU timing using CUDA events."""
|
||||
|
||||
def __init__(self):
|
||||
self.start_event = torch.cuda.Event(enable_timing=True)
|
||||
self.end_event = torch.cuda.Event(enable_timing=True)
|
||||
|
||||
@contextmanager
|
||||
def timing(self):
|
||||
self.start_event.record()
|
||||
yield
|
||||
self.end_event.record()
|
||||
self.end_event.synchronize()
|
||||
|
||||
def elapsed_time(self) -> float:
|
||||
return self.start_event.elapsed_time(self.end_event) / 1000 # ms to seconds
|
||||
|
||||
|
||||
class Benchmark:
|
||||
"""Main benchmark runner."""
|
||||
|
||||
def __init__(self, config: BenchmarkConfig):
|
||||
self.config = config
|
||||
try:
|
||||
self.model = self._load_model()
|
||||
if self.model is None:
|
||||
raise ValueError("Model initialization failed - model is None")
|
||||
|
||||
self.cuda_graphs = (
|
||||
CUDAGraphContainer(self.model, config.seq_length)
|
||||
if config.use_cuda_graphs
|
||||
else None
|
||||
)
|
||||
self.timer = Timer()
|
||||
except Exception as e:
|
||||
print(f"ERROR in benchmark initialization: {str(e)}")
|
||||
raise
|
||||
|
||||
def _load_model(self) -> nn.Module:
|
||||
print(f"Loading model from {self.config.model_path}...")
|
||||
|
||||
try:
|
||||
# Int4 quantization using HuggingFace integration
|
||||
if self.config.use_int4:
|
||||
import bitsandbytes as bnb
|
||||
print(f"- bitsandbytes version: {bnb.__version__}")
|
||||
|
||||
# Check whether to use the custom 8-bit quantization
|
||||
if hasattr(self.config, 'use_linear8bitlt') and self.config.use_linear8bitlt:
|
||||
print("- Using custom Linear8bitLt replacement for all linear layers")
|
||||
|
||||
# Load the original model (without a quantization config)
|
||||
import bitsandbytes as bnb
|
||||
import torch
|
||||
# set default to half
|
||||
torch.set_default_dtype(torch.float16)
|
||||
compute_dtype = torch.float16 if self.config.use_fp16 else torch.float32
|
||||
model = AutoModel.from_pretrained(
|
||||
self.config.model_path,
|
||||
torch_dtype=compute_dtype,
|
||||
)
|
||||
|
||||
# 定义替换函数
|
||||
def replace_linear_with_linear8bitlt(model):
|
||||
"""递归地将模型中的所有nn.Linear层替换为Linear8bitLt"""
|
||||
for name, module in list(model.named_children()):
|
||||
if isinstance(module, nn.Linear):
|
||||
# 获取原始线性层的参数
|
||||
in_features = module.in_features
|
||||
out_features = module.out_features
|
||||
bias = module.bias is not None
|
||||
|
||||
# 创建8bit线性层
|
||||
# print size
|
||||
print(f"in_features: {in_features}, out_features: {out_features}")
|
||||
new_module = bnb.nn.Linear8bitLt(
|
||||
in_features,
|
||||
out_features,
|
||||
bias=bias,
|
||||
has_fp16_weights=False
|
||||
)
|
||||
|
||||
# 复制权重和偏置
|
||||
new_module.weight.data = module.weight.data
|
||||
if bias:
|
||||
new_module.bias.data = module.bias.data
|
||||
|
||||
# 替换模块
|
||||
setattr(model, name, new_module)
|
||||
else:
|
||||
# 递归处理子模块
|
||||
replace_linear_with_linear8bitlt(module)
|
||||
|
||||
return model
|
||||
|
||||
# 替换所有线性层
|
||||
model = replace_linear_with_linear8bitlt(model)
|
||||
# add torch compile
|
||||
model = torch.compile(model)
|
||||
|
||||
# 将模型移到GPU(量化发生在这里)
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
model = model.to(device)
|
||||
|
||||
print("- All linear layers replaced with Linear8bitLt")
|
||||
|
||||
else:
|
||||
# 使用原来的Int4量化方法
|
||||
print("- Using bitsandbytes for Int4 quantization")
|
||||
|
||||
# Create quantization config
|
||||
|
||||
compute_dtype = torch.float16 if self.config.use_fp16 else torch.float32
|
||||
quantization_config = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_compute_dtype=compute_dtype,
|
||||
bnb_4bit_use_double_quant=True,
|
||||
bnb_4bit_quant_type="nf4"
|
||||
)
|
||||
|
||||
print("- Quantization config:", quantization_config)
|
||||
|
||||
# Load model directly with quantization config
|
||||
model = AutoModel.from_pretrained(
|
||||
self.config.model_path,
|
||||
quantization_config=quantization_config,
|
||||
torch_dtype=compute_dtype,
|
||||
device_map="auto" # Let HF decide on device mapping
|
||||
)
|
||||
|
||||
# Check if model loaded successfully
|
||||
if model is None:
|
||||
raise ValueError("Model loading returned None")
|
||||
|
||||
print(f"- Model type: {type(model)}")
|
||||
|
||||
# Apply optimizations directly here
|
||||
print("\nApplying model optimizations:")
|
||||
|
||||
if hasattr(self.config, 'use_linear8bitlt') and self.config.use_linear8bitlt:
|
||||
print("- Model moved to GPU with Linear8bitLt quantization")
|
||||
else:
|
||||
# Skip moving to GPU since device_map="auto" already did that
|
||||
print("- Model already on GPU due to device_map='auto'")
|
||||
|
||||
# Skip FP16 conversion since we specified compute_dtype
|
||||
print(f"- Using {compute_dtype} for compute dtype")
|
||||
|
||||
# Check CUDA and SDPA
|
||||
if torch.version.cuda and tuple(int(v) for v in torch.version.cuda.split(".")[:2]) >= (11, 6):
|
||||
if hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
|
||||
print("- Using PyTorch SDPA (scaled_dot_product_attention)")
|
||||
else:
|
||||
print("- PyTorch SDPA not available")
|
||||
|
||||
# Try xformers if available
|
||||
try:
|
||||
from xformers.ops import memory_efficient_attention
|
||||
if hasattr(model, 'enable_xformers_memory_efficient_attention'):
|
||||
model.enable_xformers_memory_efficient_attention()
|
||||
print("- Enabled xformers memory efficient attention")
|
||||
else:
|
||||
print("- Model doesn't support xformers")
|
||||
except (ImportError, AttributeError):
|
||||
print("- Xformers not available")
|
||||
|
||||
# Set to eval mode
|
||||
model.eval()
|
||||
print("- Model set to eval mode")
|
||||
# Int8 quantization using TorchAO
|
||||
elif self.config.use_int8:
|
||||
print("- Using TorchAO for Int8 dynamic activation and Int8 weight quantization")
|
||||
|
||||
# Import the quantize_ function and the quantization config
|
||||
from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight
|
||||
print("- Successfully imported TorchAO")
|
||||
|
||||
# Load model normally first
|
||||
# set default to half
|
||||
import torch
|
||||
torch.set_default_dtype(torch.bfloat16)
|
||||
model = AutoModel.from_pretrained(
|
||||
self.config.model_path,
|
||||
device_map="auto"
|
||||
)
|
||||
|
||||
print("- Model loaded in full precision")
|
||||
print(f"- Model type: {type(model)}")
|
||||
|
||||
# Apply quantization - call the function to get the config, then apply it
|
||||
# quantize_(model, int8_dynamic_activation_int8_weight())
|
||||
# from torchao.quantization import quantize_, Int8DynamicActivationInt8WeightConfig,int8_dynamic_activation_int8_semi_sparse_weight,int4_weight_only,Int8DynActInt4WeightGPTQQuantizer,int8_dynamic_activation_int4_weight,Int8DynamicActivationInt4WeightConfig,Int4DynamicActivationInt4WeightConfig
|
||||
from torchao.quantization import quantize_, Int8DynamicActivationInt8WeightConfig
|
||||
quantize_(model, Int8DynamicActivationInt8WeightConfig())
|
||||
print("- Model successfully quantized with int8 weights and int8 activations")
|
||||
# add torch compile
model = torch.compile(model)

# For older PyTorch versions that have issues with tensor subclasses
from torchao.utils import unwrap_tensor_subclass
if tuple(int(v) for v in torch.__version__.split("+")[0].split(".")[:2]) < (2, 5):
    print("- Unwrapping tensor subclasses for compatibility with older PyTorch")
    unwrap_tensor_subclass(model)
|
||||
|
||||
# Apply optimizations
|
||||
if torch.version.cuda and tuple(int(v) for v in torch.version.cuda.split(".")[:2]) >= (11, 6):
|
||||
if hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
|
||||
print("- Using PyTorch SDPA (scaled_dot_product_attention)")
|
||||
else:
|
||||
print("- PyTorch SDPA not available")
|
||||
|
||||
# Set to eval mode
|
||||
model.eval()
|
||||
print("- Model set to eval mode")
|
||||
|
||||
# For better performance with int8 dynamic quantization
|
||||
torch._inductor.config.force_fuse_int_mm_with_mul = True
|
||||
print("- Enabled fusion of int matmul with mul operations")
|
||||
|
||||
|
||||
|
||||
else:
|
||||
# Standard loading for FP16/FP32
|
||||
model = AutoModel.from_pretrained(self.config.model_path)
|
||||
print("- Model loaded in standard precision")
|
||||
print(f"- Model type: {type(model)}")
|
||||
|
||||
# Apply standard optimizations
|
||||
# set default to half
|
||||
import torch
|
||||
torch.set_default_dtype(torch.bfloat16)
|
||||
model = ModelOptimizer.optimize(model, self.config)
|
||||
model = model.half()
|
||||
# add torch compile
|
||||
model = torch.compile(model)
|
||||
|
||||
# Final check to ensure model is not None
|
||||
if model is None:
|
||||
raise ValueError("Model is None after optimization")
|
||||
|
||||
print(f"- Final model type: {type(model)}")
|
||||
return model
|
||||
|
||||
except Exception as e:
|
||||
print(f"ERROR loading model: {str(e)}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
raise
|
||||
|
||||
def _create_random_batch(self, batch_size: int) -> torch.Tensor:
|
||||
return torch.randint(
|
||||
0, 1000,
|
||||
(batch_size, self.config.seq_length),
|
||||
device="cuda",
|
||||
dtype=torch.long
|
||||
)
|
||||
|
||||
def _run_inference(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
cuda_graph_wrapper: Optional[CUDAGraphWrapper] = None
|
||||
) -> Tuple[float, torch.Tensor]:
|
||||
attention_mask = torch.ones_like(input_ids)
|
||||
|
||||
with torch.no_grad(), self.timer.timing():
|
||||
if cuda_graph_wrapper is not None:
|
||||
output = cuda_graph_wrapper(input_ids, attention_mask)
|
||||
else:
|
||||
output = self.model(input_ids=input_ids, attention_mask=attention_mask)
|
||||
|
||||
return self.timer.elapsed_time(), output
|
||||
|
||||
def run(self) -> Dict[int, Dict[str, float]]:
|
||||
results = {}
|
||||
|
||||
# Reset peak memory stats
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
|
||||
for batch_size in self.config.batch_sizes:
|
||||
print(f"\nTesting batch size: {batch_size}")
|
||||
times = []
|
||||
|
||||
# Get or create CUDA graph for this batch size
|
||||
cuda_graph_wrapper = (
|
||||
self.cuda_graphs.get_or_create(batch_size)
|
||||
if self.cuda_graphs is not None
|
||||
else None
|
||||
)
|
||||
|
||||
# Pre-allocate input tensor
|
||||
input_ids = self._create_random_batch(batch_size)
|
||||
print(f"Input shape: {input_ids.shape}")
|
||||
|
||||
# Run benchmark
|
||||
for i in tqdm(range(self.config.num_runs), desc=f"Batch size {batch_size}"):
|
||||
try:
|
||||
elapsed_time, output = self._run_inference(input_ids, cuda_graph_wrapper)
|
||||
if i == 0: # Only print on first run
|
||||
print(f"Output shape: {output.last_hidden_state.shape}")
|
||||
times.append(elapsed_time)
|
||||
except Exception as e:
|
||||
print(f"Error during inference: {e}")
|
||||
break
|
||||
|
||||
if not times:
|
||||
print(f"No successful runs for batch size {batch_size}, skipping")
|
||||
continue
|
||||
|
||||
# Calculate statistics
|
||||
avg_time = np.mean(times)
|
||||
std_time = np.std(times)
|
||||
throughput = batch_size / avg_time
|
||||
|
||||
results[batch_size] = {
|
||||
"avg_time": avg_time,
|
||||
"std_time": std_time,
|
||||
"throughput": throughput,
|
||||
}
|
||||
|
||||
print(f"Avg Time: {avg_time:.4f}s ± {std_time:.4f}s")
|
||||
print(f"Throughput: {throughput:.2f} sequences/second")
|
||||
|
||||
# Log memory usage
|
||||
peak_memory_gb = torch.cuda.max_memory_allocated() / (1024 ** 3)
|
||||
print(f"\nPeak GPU memory usage: {peak_memory_gb:.2f} GB")
|
||||
|
||||
# Add memory info to results
|
||||
for batch_size in results:
|
||||
results[batch_size]["peak_memory_gb"] = peak_memory_gb
|
||||
|
||||
return results
|
||||
|
||||
|
||||
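
# Minimal sketch (an assumption added for illustration, mirroring what
# _load_model does per layer): how a single nn.Linear is swapped for a
# bitsandbytes Linear8bitLt. The int8 quantization of the weights is triggered
# when the module is moved to a CUDA device.
def _example_linear8bitlt_swap(layer: nn.Linear) -> nn.Module:
    import bitsandbytes as bnb  # assumed available, as in _load_model
    new_layer = bnb.nn.Linear8bitLt(
        layer.in_features,
        layer.out_features,
        bias=layer.bias is not None,
        has_fp16_weights=False,
    )
    new_layer.weight.data = layer.weight.data
    if layer.bias is not None:
        new_layer.bias.data = layer.bias.data
    return new_layer.to("cuda")  # quantization happens on this move
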
def main():
|
||||
parser = argparse.ArgumentParser(description="Model Inference Benchmark")
|
||||
parser.add_argument(
|
||||
"--model_path",
|
||||
type=str,
|
||||
default="facebook/contriever",
|
||||
help="Path to the model",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--batch_sizes",
|
||||
type=str,
|
||||
default="1,2,4,8,10,16,20,32,40,64,128,256,512,1024,2048,4096,8192",
|
||||
help="Comma-separated list of batch sizes",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--seq_length",
|
||||
type=int,
|
||||
default=256,
|
||||
help="Sequence length for input",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_runs",
|
||||
type=int,
|
||||
default=5,
|
||||
help="Number of runs for each batch size",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_fp16",
|
||||
action="store_true",
|
||||
help="Enable FP16 inference",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_int4",
|
||||
action="store_true",
|
||||
help="Enable INT4 quantization using bitsandbytes",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_int8",
|
||||
action="store_true",
|
||||
help="Enable INT8 quantization for both activations and weights using bitsandbytes",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_cuda_graphs",
|
||||
action="store_true",
|
||||
help="Enable CUDA Graphs optimization",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_flash_attention",
|
||||
action="store_true",
|
||||
help="Enable Flash Attention 2 if available",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_linear8bitlt",
|
||||
action="store_true",
|
||||
help="Enable Linear8bitLt quantization for all linear layers",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Print arguments for debugging
|
||||
print("\nCommand line arguments:")
|
||||
for arg, value in vars(args).items():
|
||||
print(f"- {arg}: {value}")
|
||||
|
||||
config = BenchmarkConfig(
|
||||
model_path=args.model_path,
|
||||
batch_sizes=[int(bs) for bs in args.batch_sizes.split(",")],
|
||||
seq_length=args.seq_length,
|
||||
num_runs=args.num_runs,
|
||||
use_fp16=args.use_fp16,
|
||||
use_int4=args.use_int4,
|
||||
use_int8=args.use_int8,
|
||||
use_cuda_graphs=args.use_cuda_graphs,
|
||||
use_flash_attention=args.use_flash_attention,
|
||||
use_linear8bitlt=args.use_linear8bitlt,
|
||||
)
|
||||
|
||||
# Print configuration for debugging
|
||||
print("\nBenchmark configuration:")
|
||||
for field, value in vars(config).items():
|
||||
print(f"- {field}: {value}")
|
||||
|
||||
try:
|
||||
benchmark = Benchmark(config)
|
||||
results = benchmark.run()
|
||||
|
||||
# Save results to file
|
||||
import json
|
||||
import os
|
||||
|
||||
# Create results directory if it doesn't exist
|
||||
os.makedirs("results", exist_ok=True)
|
||||
|
||||
# Generate filename based on configuration
|
||||
precision_type = "int4" if config.use_int4 else "int8" if config.use_int8 else "fp16" if config.use_fp16 else "fp32"
|
||||
model_name = os.path.basename(config.model_path)
|
||||
output_file = f"results/benchmark_{model_name}_{precision_type}.json"
|
||||
|
||||
# Save results
|
||||
with open(output_file, "w") as f:
|
||||
json.dump(
|
||||
{
|
||||
"config": {k: str(v) if isinstance(v, list) else v for k, v in vars(config).items()},
|
||||
"results": {str(k): v for k, v in results.items()}
|
||||
},
|
||||
f,
|
||||
indent=2
|
||||
)
|
||||
print(f"Results saved to {output_file}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Benchmark failed: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
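
# Example invocations (illustrative; the flag names are defined in main() above,
# while the script filename is an assumption):
#
#   python benchmark.py --model_path facebook/contriever --use_fp16
#   python benchmark.py --model_path facebook/contriever --use_int8 --batch_sizes 1,8,64
#   python benchmark.py --model_path facebook/contriever --use_int4 --use_linear8bitlt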
|
||||
@@ -1,376 +0,0 @@
|
||||
import argparse
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import AutoModel
|
||||
from tqdm import tqdm
|
||||
from contextlib import contextmanager
|
||||
import math
|
||||
|
||||
@dataclass
|
||||
class BenchmarkConfig:
|
||||
model_path: str
|
||||
batch_sizes: List[int]
|
||||
seq_length: int
|
||||
num_runs: int
|
||||
use_fp16: bool = True
|
||||
use_cuda_graphs: bool = False
|
||||
use_flash_attention: bool = False
|
||||
max_batch_size: int = 256 # Maximum batch size before splitting
|
||||
|
||||
|
||||
class CUDAGraphContainer:
|
||||
"""Container for managing CUDA graphs for different batch sizes."""
|
||||
|
||||
def __init__(self, model: nn.Module, seq_length: int, max_batch_size: int):
|
||||
self.model = model
|
||||
self.seq_length = seq_length
|
||||
self.max_batch_size = max_batch_size
|
||||
self.graphs: Dict[int, CUDAGraphWrapper] = {}
|
||||
|
||||
def get_or_create(self, batch_size: int) -> 'CUDAGraphWrapper':
|
||||
# For CUDA graphs, we always use the actual batch size or max_batch_size
|
||||
effective_batch_size = min(batch_size, self.max_batch_size)
|
||||
|
||||
if effective_batch_size not in self.graphs:
|
||||
self.graphs[effective_batch_size] = CUDAGraphWrapper(
|
||||
self.model, effective_batch_size, self.seq_length
|
||||
)
|
||||
return self.graphs[effective_batch_size]
|
||||
|
||||
|
||||
class CUDAGraphWrapper:
|
||||
"""Wrapper for CUDA graph capture and replay."""
|
||||
|
||||
def __init__(self, model: nn.Module, batch_size: int, seq_length: int):
|
||||
self.model = model
|
||||
self.static_input = self._create_random_batch(batch_size, seq_length)
|
||||
self.static_attention_mask = torch.ones_like(self.static_input)
|
||||
|
||||
# Warm up
|
||||
self._warmup()
|
||||
|
||||
# Capture graph
|
||||
self.graph = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(self.graph):
|
||||
self.static_output = self.model(
|
||||
input_ids=self.static_input,
|
||||
attention_mask=self.static_attention_mask
|
||||
)
|
||||
|
||||
def _create_random_batch(self, batch_size: int, seq_length: int) -> torch.Tensor:
|
||||
return torch.randint(
|
||||
0, 1000, (batch_size, seq_length),
|
||||
device="cuda",
|
||||
dtype=torch.long
|
||||
)
|
||||
|
||||
def _warmup(self, num_warmup: int = 3):
|
||||
with torch.no_grad():
|
||||
for _ in range(num_warmup):
|
||||
self.model(
|
||||
input_ids=self.static_input,
|
||||
attention_mask=self.static_attention_mask
|
||||
)
|
||||
|
||||
def __call__(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
|
||||
self.static_input.copy_(input_ids)
|
||||
self.static_attention_mask.copy_(attention_mask)
|
||||
self.graph.replay()
|
||||
return self.static_output
|
||||
|
||||
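
# Illustrative sketch (an assumption, not part of the original file): how the
# wrapper above is used. The graph is captured once for a fixed
# (batch_size, seq_length); new inputs are then copied into the static buffers
# and the captured graph is replayed.
def _example_cuda_graph_replay(model: nn.Module, seq_length: int = 16):
    wrapper = CUDAGraphWrapper(model, batch_size=2, seq_length=seq_length)
    input_ids = torch.randint(0, 1000, (2, seq_length), device="cuda", dtype=torch.long)
    attention_mask = torch.ones_like(input_ids)
    return wrapper(input_ids, attention_mask)
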
|
||||
class ModelOptimizer:
|
||||
"""Applies various optimizations to the model."""
|
||||
|
||||
@staticmethod
|
||||
def optimize(model: nn.Module, config: BenchmarkConfig) -> nn.Module:
|
||||
print("\nApplying model optimizations:")
|
||||
|
||||
# Move to GPU
|
||||
model = model.cuda()
|
||||
print("- Model moved to GPU")
|
||||
|
||||
# FP16
|
||||
if config.use_fp16:
|
||||
model = model.half()
|
||||
print("- Using FP16 precision")
|
||||
|
||||
# Check if using SDPA
|
||||
if torch.version.cuda and float(torch.version.cuda[:3]) >= 11.6:
|
||||
if hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
|
||||
print("- Using PyTorch SDPA (scaled_dot_product_attention)")
|
||||
# No need to do anything as it's automatically enabled
|
||||
else:
|
||||
print("- PyTorch SDPA not available")
|
||||
|
||||
# Flash Attention
|
||||
if config.use_flash_attention:
|
||||
try:
|
||||
from flash_attn.flash_attention import FlashAttention
|
||||
print("- Flash Attention 2 available")
|
||||
if hasattr(model.config, "attention_mode"):
|
||||
model.config.attention_mode = "flash_attention_2"
|
||||
print(" - Enabled Flash Attention 2 mode")
|
||||
except ImportError:
|
||||
print("- Flash Attention not available")
|
||||
|
||||
# Optimize LayerNorm
|
||||
try:
|
||||
num_layernorms = 0
|
||||
for module in model.modules():
|
||||
if isinstance(module, torch.nn.LayerNorm):
|
||||
module.forward = torch.jit.script(module.forward)
|
||||
num_layernorms += 1
|
||||
if num_layernorms > 0:
|
||||
print(f"- Optimized {num_layernorms} LayerNorm modules with TorchScript")
|
||||
except Exception as e:
|
||||
print(f"- LayerNorm optimization failed: {e}")
|
||||
|
||||
# Memory efficient attention
|
||||
try:
|
||||
from xformers.ops import memory_efficient_attention
|
||||
model.enable_xformers_memory_efficient_attention()
|
||||
print("- Enabled xformers memory efficient attention")
|
||||
except (ImportError, AttributeError):
|
||||
print("- Xformers not available")
|
||||
|
||||
model.eval()
|
||||
print("- Model set to eval mode")
|
||||
|
||||
return model
|
||||
|
||||
|
||||
class Timer:
|
||||
"""Handles accurate GPU timing using CUDA events."""
|
||||
|
||||
def __init__(self):
|
||||
self.start_event = torch.cuda.Event(enable_timing=True)
|
||||
self.end_event = torch.cuda.Event(enable_timing=True)
|
||||
|
||||
@contextmanager
|
||||
def timing(self):
|
||||
self.start_event.record()
|
||||
yield
|
||||
self.end_event.record()
|
||||
self.end_event.synchronize()
|
||||
|
||||
def elapsed_time(self) -> float:
|
||||
return self.start_event.elapsed_time(self.end_event) / 1000 # ms to seconds
|
||||
|
||||
|
||||
class Benchmark:
|
||||
"""Main benchmark runner."""
|
||||
|
||||
def __init__(self, config: BenchmarkConfig):
|
||||
self.config = config
|
||||
self.model = self._load_model()
|
||||
self.cuda_graphs = (
|
||||
CUDAGraphContainer(self.model, config.seq_length, config.max_batch_size)
|
||||
if config.use_cuda_graphs
|
||||
else None
|
||||
)
|
||||
self.timer = Timer()
|
||||
|
||||
def _load_model(self) -> nn.Module:
|
||||
print(f"Loading model from {self.config.model_path}...")
|
||||
model = AutoModel.from_pretrained(self.config.model_path)
|
||||
return ModelOptimizer.optimize(model, self.config)
|
||||
|
||||
def _create_random_batch(self, batch_size: int) -> torch.Tensor:
|
||||
return torch.randint(
|
||||
0, 1000,
|
||||
(batch_size, self.config.seq_length),
|
||||
device="cuda",
|
||||
dtype=torch.long
|
||||
)
|
||||
|
||||
def _run_inference(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
cuda_graph_wrapper: Optional[CUDAGraphWrapper] = None
|
||||
) -> Tuple[float, torch.Tensor]:
|
||||
attention_mask = torch.ones_like(input_ids)
|
||||
original_batch_size = input_ids.shape[0]
|
||||
print(f"Original input_ids shape: {input_ids.shape}")
|
||||
|
||||
# Split large batches to avoid OOM
|
||||
max_batch_size = self.config.max_batch_size
|
||||
if original_batch_size > max_batch_size:
|
||||
print(f"Splitting batch of size {original_batch_size} into chunks of {max_batch_size}")
|
||||
total_time = 0
|
||||
outputs = []
|
||||
|
||||
with torch.no_grad():
|
||||
for i in range(0, original_batch_size, max_batch_size):
|
||||
end_idx = min(i + max_batch_size, original_batch_size)
|
||||
batch_slice = input_ids[i:end_idx]
|
||||
mask_slice = attention_mask[i:end_idx]
|
||||
|
||||
print(f"Processing chunk {i//max_batch_size + 1}: shape {batch_slice.shape}")
|
||||
|
||||
# Use CUDA graph if available (with the smaller batch size)
|
||||
chunk_cuda_graph = None
|
||||
if cuda_graph_wrapper is not None:
|
||||
chunk_cuda_graph = self.cuda_graphs.get_or_create(batch_slice.shape[0])
|
||||
|
||||
with self.timer.timing():
|
||||
if chunk_cuda_graph is not None:
|
||||
chunk_output = chunk_cuda_graph(batch_slice, mask_slice)
|
||||
else:
|
||||
chunk_output = self.model(input_ids=batch_slice, attention_mask=mask_slice)
|
||||
|
||||
total_time += self.timer.elapsed_time()
|
||||
outputs.append(chunk_output.last_hidden_state)
|
||||
|
||||
# Combine outputs
|
||||
combined_output = torch.cat(outputs, dim=0)
|
||||
print(f"Combined output shape: {combined_output.shape}")
|
||||
|
||||
# Create a wrapper object similar to model output to maintain consistency
|
||||
class DummyOutput:
|
||||
def __init__(self, hidden_states):
|
||||
self.last_hidden_state = hidden_states
|
||||
|
||||
output = DummyOutput(combined_output)
|
||||
return total_time, output
|
||||
else:
|
||||
# Process normally for small batches
|
||||
with torch.no_grad(), self.timer.timing():
|
||||
if cuda_graph_wrapper is not None:
|
||||
output = cuda_graph_wrapper(input_ids, attention_mask)
|
||||
else:
|
||||
output = self.model(input_ids=input_ids, attention_mask=attention_mask)
|
||||
|
||||
print(f"Output shape: {output.last_hidden_state.shape}")
|
||||
return self.timer.elapsed_time(), output
|
||||
|
||||
def run(self) -> Dict[int, Dict[str, float]]:
|
||||
results = {}
|
||||
|
||||
for batch_size in self.config.batch_sizes:
|
||||
print(f"\nTesting batch size: {batch_size}")
|
||||
times = []
|
||||
|
||||
# Get or create CUDA graph for this batch size
|
||||
cuda_graph_wrapper = None
|
||||
if self.cuda_graphs is not None:
|
||||
if batch_size <= self.config.max_batch_size:
|
||||
cuda_graph_wrapper = self.cuda_graphs.get_or_create(batch_size)
|
||||
else:
|
||||
# For large batches, we'll use the max_batch_size graph in chunks
|
||||
cuda_graph_wrapper = True # Just a flag to indicate we want to use CUDA graphs
|
||||
|
||||
# Pre-allocate input tensor
|
||||
input_ids = self._create_random_batch(batch_size)
|
||||
|
||||
# Run benchmark
|
||||
for run_idx in tqdm(range(self.config.num_runs), desc=f"Batch size {batch_size}"):
|
||||
elapsed_time, _ = self._run_inference(input_ids, cuda_graph_wrapper)
|
||||
times.append(elapsed_time)
|
||||
print(f"Run {run_idx+1}: {elapsed_time:.4f}s")
|
||||
|
||||
# Calculate statistics
|
||||
avg_time = np.mean(times)
|
||||
std_time = np.std(times)
|
||||
throughput = batch_size / avg_time
|
||||
|
||||
results[batch_size] = {
|
||||
"avg_time": avg_time,
|
||||
"std_time": std_time,
|
||||
"throughput": throughput,
|
||||
}
|
||||
|
||||
print(f"Avg Time: {avg_time:.4f}s ± {std_time:.4f}s")
|
||||
print(f"Throughput: {throughput:.2f} sequences/second")
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Model Inference Benchmark")
|
||||
parser.add_argument(
|
||||
"--model_path",
|
||||
type=str,
|
||||
default="facebook/contriever",
|
||||
help="Path to the model",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--batch_sizes",
|
||||
type=str,
|
||||
default="1,2,4,8,16,32,64,128,256,512,1024,2048,4096",
|
||||
help="Comma-separated list of batch sizes",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--seq_length",
|
||||
type=int,
|
||||
default=256,
|
||||
help="Sequence length for input",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_runs",
|
||||
type=int,
|
||||
default=5,
|
||||
help="Number of runs for each batch size",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no_fp16",
|
||||
action="store_true",
|
||||
help="Disable FP16 inference",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_cuda_graphs",
|
||||
action="store_true",
|
||||
help="Enable CUDA Graphs optimization",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_flash_attention",
|
||||
action="store_true",
|
||||
help="Enable Flash Attention 2 if available",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_batch_size",
|
||||
type=int,
|
||||
default=256,
|
||||
help="Maximum batch size before splitting to prevent OOM",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
config = BenchmarkConfig(
|
||||
model_path=args.model_path,
|
||||
batch_sizes=[int(bs) for bs in args.batch_sizes.split(",")],
|
||||
seq_length=args.seq_length,
|
||||
num_runs=args.num_runs,
|
||||
use_fp16=not args.no_fp16,
|
||||
use_cuda_graphs=args.use_cuda_graphs,
|
||||
use_flash_attention=args.use_flash_attention,
|
||||
max_batch_size=args.max_batch_size,
|
||||
)
|
||||
|
||||
benchmark = Benchmark(config)
|
||||
results = benchmark.run()
|
||||
|
||||
# Print overall summary
|
||||
print("\n===== BENCHMARK SUMMARY =====")
|
||||
print(f"Model: {config.model_path}")
|
||||
print(f"Sequence Length: {config.seq_length}")
|
||||
print(f"FP16: {config.use_fp16}")
|
||||
print(f"CUDA Graphs: {config.use_cuda_graphs}")
|
||||
print(f"Flash Attention: {config.use_flash_attention}")
|
||||
print(f"Max Batch Size: {config.max_batch_size}")
|
||||
print("\nResults:")
|
||||
|
||||
print("\nBatch Size | Avg Time (s) | Throughput (seq/s)")
|
||||
print("-" * 50)
|
||||
for bs in sorted(results.keys()):
|
||||
r = results[bs]
|
||||
print(f"{bs:^10} | {r['avg_time']:^12.4f} | {r['throughput']:^17.2f}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,218 +0,0 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import time
|
||||
import torch.nn.functional as F
|
||||
|
||||
# Import necessary functions from the quantize.py file
|
||||
def get_group_qparams(w, n_bit=4, groupsize=128):
|
||||
# needed for GPTQ with padding
|
||||
if groupsize > w.shape[-1]:
|
||||
groupsize = w.shape[-1]
|
||||
assert groupsize > 1
|
||||
assert w.shape[-1] % groupsize == 0
|
||||
assert w.dim() == 2
|
||||
|
||||
to_quant = w.reshape(-1, groupsize)
|
||||
assert torch.isnan(to_quant).sum() == 0
|
||||
|
||||
max_val = to_quant.amax(dim=1, keepdim=True)
|
||||
min_val = to_quant.amin(dim=1, keepdim=True)
|
||||
max_int = 2**n_bit - 1
|
||||
scales = (max_val - min_val).clamp(min=1e-6) / max_int
|
||||
zeros = min_val + scales * (2 ** (n_bit - 1))
|
||||
return scales.to(torch.bfloat16).reshape(w.shape[0], -1), zeros.to(
|
||||
torch.bfloat16
|
||||
).reshape(w.shape[0], -1)
|
||||
|
||||
def pack_scales_and_zeros(scales, zeros):
|
||||
assert scales.shape == zeros.shape
|
||||
assert scales.dtype == torch.bfloat16
|
||||
assert zeros.dtype == torch.bfloat16
|
||||
return (
|
||||
torch.cat(
|
||||
[
|
||||
scales.reshape(scales.size(0), scales.size(1), 1),
|
||||
zeros.reshape(zeros.size(0), zeros.size(1), 1),
|
||||
],
|
||||
2,
|
||||
)
|
||||
.transpose(0, 1)
|
||||
.contiguous()
|
||||
)
|
||||
|
||||
def group_quantize_tensor(w, n_bit=4, groupsize=128):
|
||||
scales, zeros = get_group_qparams(w, n_bit, groupsize)
|
||||
w_int32 = group_quantize_tensor_from_qparams(w, scales, zeros, n_bit, groupsize)
|
||||
scales_and_zeros = pack_scales_and_zeros(scales, zeros)
|
||||
return w_int32, scales_and_zeros
|
||||
|
||||
def group_quantize_tensor_from_qparams(w, scales, zeros, n_bit=4, groupsize=128):
|
||||
assert groupsize > 1
|
||||
# needed for GPTQ single column quantize
|
||||
if groupsize > w.shape[-1] and scales.shape[-1] == 1:
|
||||
groupsize = w.shape[-1]
|
||||
|
||||
assert w.shape[-1] % groupsize == 0
|
||||
assert w.dim() == 2
|
||||
|
||||
to_quant = w.reshape(-1, groupsize)
|
||||
assert torch.isnan(to_quant).sum() == 0
|
||||
|
||||
scales = scales.reshape(-1, 1)
|
||||
zeros = zeros.reshape(-1, 1)
|
||||
min_val = zeros - scales * (2 ** (n_bit - 1))
|
||||
max_int = 2**n_bit - 1
|
||||
min_int = 0
|
||||
w_int32 = (
|
||||
to_quant.sub(min_val)
|
||||
.div(scales)
|
||||
.round()
|
||||
.clamp_(min_int, max_int)
|
||||
.to(torch.int32)
|
||||
.reshape_as(w)
|
||||
)
|
||||
|
||||
return w_int32
|
||||
|
||||
def prepare_int4_weight_and_scales_and_zeros(weight_bf16, groupsize, inner_k_tiles):
|
||||
weight_int32, scales_and_zeros = group_quantize_tensor(
|
||||
weight_bf16, n_bit=4, groupsize=groupsize
|
||||
)
|
||||
weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(weight_int32, inner_k_tiles)
|
||||
return weight_int4pack, scales_and_zeros
|
||||
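
# Tiny illustration (an assumption, for documentation only): group quantization
# works on groups of `groupsize` values along the last dimension, producing one
# bfloat16 scale and zero-point per group.
def _example_group_quantize():
    w = torch.randn(4, 256, dtype=torch.bfloat16)
    scales, zeros = get_group_qparams(w, n_bit=4, groupsize=128)
    # 256 columns / groupsize 128 -> 2 groups per output row
    assert scales.shape == (4, 2) and zeros.shape == (4, 2)
    w_int32 = group_quantize_tensor_from_qparams(w, scales, zeros, n_bit=4, groupsize=128)
    assert w_int32.shape == w.shape and w_int32.dtype == torch.int32
    return w_int32, scales, zeros
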
|
||||
def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, groupsize):
|
||||
origin_x_size = x.size()
|
||||
x = x.reshape(-1, origin_x_size[-1])
|
||||
c = torch.ops.aten._weight_int4pack_mm(x, weight_int4pack, groupsize, scales_and_zeros)
|
||||
new_shape = origin_x_size[:-1] + (out_features,)
|
||||
c = c.reshape(new_shape)
|
||||
return c
|
||||
|
||||
class WeightOnlyInt4Linear(torch.nn.Module):
|
||||
__constants__ = ['in_features', 'out_features']
|
||||
in_features: int
|
||||
out_features: int
|
||||
weight: torch.Tensor
|
||||
|
||||
def __init__(
|
||||
self, in_features: int, out_features: int,
|
||||
bias=False, device=None, dtype=None, groupsize: int = 128, inner_k_tiles: int = 8
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.in_features = in_features
|
||||
self.out_features = out_features
|
||||
self.groupsize = groupsize
|
||||
self.inner_k_tiles = inner_k_tiles
|
||||
|
||||
assert out_features % 8 == 0, "require out_features % 8 == 0"
|
||||
assert in_features % (inner_k_tiles * 16) == 0, "require in_features % (innerKTiles * 16) == 0"
|
||||
self.register_buffer(
|
||||
"weight",
|
||||
torch.empty((out_features // 8, in_features // (inner_k_tiles * 16), 32, inner_k_tiles // 2), dtype=torch.int32)
|
||||
)
|
||||
self.register_buffer(
|
||||
"scales_and_zeros",
|
||||
torch.empty((in_features // groupsize, out_features, 2), dtype=torch.bfloat16)
|
||||
)
|
||||
|
||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||
input = input.to(torch.bfloat16)
|
||||
return linear_forward_int4(
|
||||
input,
|
||||
self.weight, self.scales_and_zeros, self.out_features, self.groupsize
|
||||
)
|
||||
|
||||
# Define dimensions that satisfy the requirements for INT4 quantization
|
||||
# in_features must be divisible by inner_k_tiles * 16
|
||||
# out_features must be divisible by 8
|
||||
in_features = 1024 # Must be divisible by inner_k_tiles * 16
|
||||
out_features = 2048 # Must be divisible by 8
|
||||
groupsize = 128
|
||||
inner_k_tiles = 8
|
||||
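
# Small sanity-check helper (an assumption added for illustration): the packed
# int4 kernel used above requires out_features % 8 == 0,
# in_features % (inner_k_tiles * 16) == 0, and groupsize must divide in_features.
def _check_int4_dims(in_features, out_features, groupsize, inner_k_tiles):
    assert out_features % 8 == 0, "out_features must be divisible by 8"
    assert in_features % (inner_k_tiles * 16) == 0, \
        "in_features must be divisible by inner_k_tiles * 16"
    assert in_features % groupsize == 0, "groupsize must divide in_features"

_check_int4_dims(in_features, out_features, groupsize, inner_k_tiles)
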
|
||||
# Create models
|
||||
fp16_model = nn.Sequential(
|
||||
nn.Linear(in_features, out_features, bias=False)
|
||||
)
|
||||
|
||||
# Create INT4 model
|
||||
int4_model = nn.Sequential(
|
||||
WeightOnlyInt4Linear(in_features, out_features, bias=False,
|
||||
groupsize=groupsize, inner_k_tiles=inner_k_tiles)
|
||||
)
|
||||
|
||||
# Quantize the weights and set up the INT4 model
|
||||
with torch.no_grad():
|
||||
# Convert FP16 weights to INT4
|
||||
fp16_weight = fp16_model[0].weight.data.to(torch.bfloat16)
|
||||
weight_int4pack, scales_and_zeros = prepare_int4_weight_and_scales_and_zeros(
|
||||
fp16_weight, groupsize, inner_k_tiles
|
||||
)
|
||||
|
||||
# Set the quantized weights in the INT4 model
|
||||
int4_model[0].weight.copy_(weight_int4pack)
|
||||
int4_model[0].scales_and_zeros.copy_(scales_and_zeros)
|
||||
|
||||
# Move models to GPU
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
fp16_model = fp16_model.to(device)
|
||||
int4_model = int4_model.to(device)
|
||||
|
||||
# Create random input tensor
|
||||
batch_size = 1024
|
||||
input_tensor = torch.randn(batch_size, in_features, device=device)
|
||||
input_tensor_bf16 = input_tensor.to(torch.bfloat16)
|
||||
|
||||
# Speed test function
|
||||
def speed_test(model, input_tensor, name, num_iterations=100):
|
||||
# Warmup
|
||||
for _ in range(10):
|
||||
_ = model(input_tensor)
|
||||
|
||||
# Actual timing
|
||||
torch.cuda.synchronize()
|
||||
start_time = time.time()
|
||||
|
||||
for _ in range(num_iterations):
|
||||
_ = model(input_tensor)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
end_time = time.time()
|
||||
|
||||
avg_time = (end_time - start_time) / num_iterations
|
||||
print(f"{name} model: {avg_time:.6f} seconds per iteration")
|
||||
return avg_time
|
||||
|
||||
# Run speed tests
|
||||
with torch.no_grad(): # Disable gradient calculation for inference
|
||||
print(f"Running benchmark with batch_size={batch_size}, in_features={in_features}, out_features={out_features}")
|
||||
print(f"INT4 parameters: groupsize={groupsize}, inner_k_tiles={inner_k_tiles}")
|
||||
|
||||
fp16_time = speed_test(fp16_model, input_tensor_bf16, "FP16")
|
||||
int4_time = speed_test(int4_model, input_tensor, "INT4")
|
||||
|
||||
# Calculate speedup
|
||||
speedup = fp16_time / int4_time
|
||||
print(f"INT4 is {speedup:.2f}x faster than FP16")
|
||||
|
||||
# Calculate memory savings
|
||||
fp16_memory = fp16_model[0].weight.nelement() * fp16_model[0].weight.element_size()
|
||||
int4_memory = (int4_model[0].weight.nelement() * int4_model[0].weight.element_size() +
|
||||
int4_model[0].scales_and_zeros.nelement() * int4_model[0].scales_and_zeros.element_size())
|
||||
|
||||
memory_reduction = fp16_memory / int4_memory
|
||||
print(f"Memory reduction: {memory_reduction:.2f}x ({fp16_memory/1024/1024:.2f} MB vs {int4_memory/1024/1024:.2f} MB)")
|
||||
|
||||
# Check accuracy
|
||||
with torch.no_grad():
|
||||
fp16_output = fp16_model(input_tensor_bf16)
|
||||
int4_output = int4_model(input_tensor)
|
||||
|
||||
# Calculate error metrics
|
||||
abs_error = torch.abs(fp16_output - int4_output)
|
||||
rel_error = abs_error / (torch.abs(fp16_output) + 1e-7)
|
||||
|
||||
print(f"Mean absolute error: {abs_error.mean().item():.6f}")
|
||||
print(f"Max absolute error: {abs_error.max().item():.6f}")
|
||||
print(f"Mean relative error: {rel_error.mean().item():.6f}")
|
||||
@@ -1,83 +0,0 @@
|
||||
import torch
import nvmath.bindings.cublas
import ctypes

# Create the cuBLAS handle
handle = nvmath.bindings.cublas.create()

# Prepare the data: use uint8 and make sure the memory is contiguous
m, n, k = 64, 32, 48
a = (torch.rand(m, k, device="cuda") * 255).to(torch.uint8).contiguous()
b = (torch.rand(k, n, device="cuda") * 255).to(torch.uint8).contiguous()
c = torch.zeros(m, n, device="cuda", dtype=torch.uint8).contiguous()

# Make sure the tensors are on CUDA
assert a.is_cuda and b.is_cuda and c.is_cuda
# Make sure the tensors are contiguous
assert a.is_contiguous() and b.is_contiguous() and c.is_contiguous()

# Get the raw device pointers
a_ptr = a.data_ptr()
b_ptr = b.data_ptr()
c_ptr = c.data_ptr()

# Set the operation parameters
transa = 0  # CUBLAS_OP_N (no transpose)
transb = 0  # CUBLAS_OP_N (no transpose)
transc = 0  # CUBLAS_OP_N (no transpose)

# Set the bias values
a_bias = 0
b_bias = 0
c_bias = 0

# Set the correct leading dimensions
lda = k  # leading dimension of A
ldb = n  # leading dimension of B
ldc = n  # leading dimension of C

c_mult = 1
c_shift = 0

# Print debug information
print(f"a shape: {a.shape}, a_ptr: {a_ptr}")
print(f"b shape: {b.shape}, b_ptr: {b_ptr}")
print(f"c shape: {c.shape}, c_ptr: {c_ptr}")

try:
    # Call uint8gemm_bias
    nvmath.bindings.cublas.uint8gemm_bias(
        handle,
        transa, transb, transc,
        m, n, k,
        a_ptr, a_bias, lda,
        b_ptr, b_bias, ldb,
        c_ptr, c_bias, ldc,
        c_mult, c_shift
    )
except Exception as e:
    print(f"Error: {e}")
    # Try converting the pointers with ctypes
    a_ptr_c = ctypes.c_void_p(a_ptr).value
    b_ptr_c = ctypes.c_void_p(b_ptr).value
    c_ptr_c = ctypes.c_void_p(c_ptr).value

    print(f"Using ctypes: a_ptr: {a_ptr_c}, b_ptr: {b_ptr_c}, c_ptr: {c_ptr_c}")

    # Try the call again
    nvmath.bindings.cublas.uint8gemm_bias(
        handle,
        transa, transb, transc,
        m, n, k,
        a_ptr_c, a_bias, lda,
        b_ptr_c, b_bias, ldb,
        c_ptr_c, c_bias, ldc,
        c_mult, c_shift
    )

# Destroy the cuBLAS handle
nvmath.bindings.cublas.destroy(handle)

# Print the result
print("Result:")
print(c)
|
||||
@@ -1,23 +0,0 @@
|
||||
from llmcompressor.modifiers.smoothquant import SmoothQuantModifier
|
||||
from llmcompressor.modifiers.quantization import GPTQModifier
|
||||
from llmcompressor import oneshot
|
||||
|
||||
# Select quantization algorithm. In this case, we:
|
||||
# * apply SmoothQuant to make the activations easier to quantize
|
||||
# * quantize the weights to int8 with GPTQ (static per channel)
|
||||
# * quantize the activations to int8 (dynamic per token)
|
||||
recipe = [
|
||||
SmoothQuantModifier(smoothing_strength=0.8),
|
||||
GPTQModifier(scheme="W8A8", targets="Linear", ignore=["lm_head"]),
|
||||
]
|
||||
|
||||
# Apply quantization using the built in open_platypus dataset.
|
||||
# * See examples for demos showing how to pass a custom calibration set
|
||||
oneshot(
|
||||
model="facebook/contriever",
|
||||
dataset="open_platypus",
|
||||
recipe=recipe,
|
||||
output_dir="contriever-INT4",
|
||||
max_seq_length=2048,
|
||||
num_calibration_samples=512,
|
||||
)
|
||||
@@ -1,41 +0,0 @@
|
||||
# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED.
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
"""
|
||||
This example demonstrates basic matrix multiplication of FP8 tensors.
|
||||
|
||||
In narrow-precision operations, quantization scales must be provided for each tensor. These
|
||||
scales are used to dequantize input operands and quantize the result. Without proper
|
||||
scaling, the results of FP8 operations will likely exceed the type's range.
|
||||
|
||||
FP8 is only supported with cuBLAS 12.8 or newer and on devices with compute
|
||||
capability 8.9 or higher.
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
import nvmath
|
||||
|
||||
# Prepare sample input data. Note that N, M and K must be divisible by 16 for FP8.
|
||||
# cuBLAS requires B to be column-major, so we first create a row-major tensor and then
|
||||
# transpose it.
|
||||
m, n, k = 64, 32, 48
|
||||
a = (torch.rand(m, k, device="cuda") * 10).type(torch.float8_e4m3fn)
|
||||
b = (torch.rand(n, k, device="cuda") * 10).type(torch.float8_e4m3fn).T
|
||||
|
||||
# Prepare quantization scales. The scales must allow the result to fit within the dynamic
|
||||
# range of the data type used. Scales can be provided either as a dictionary or as a
|
||||
# MatmulQuantizationScales object. Note that scales are only allowed for FP8 operands.
|
||||
scales = {"a": 1, "b": 1, "d": 0.1}
|
||||
|
||||
# Perform the multiplication. The result of the multiplication will be:
|
||||
# (scales.a * A) @ (scales.b * B) * scales.d
|
||||
result = nvmath.linalg.advanced.matmul(a, b, quantization_scales=scales)
|
||||
|
||||
# Check how scaling helped to fit into the dynamic range of float8_e4m3fn type.
|
||||
result_without_scaling = nvmath.linalg.advanced.matmul(a, b, quantization_scales={"a": 1, "b": 1, "d": 1})
|
||||
print("Without scaling, most of the elements were clamped to the maximum value of float8_e4m3fn type (448):")
|
||||
print(result_without_scaling)
|
||||
print(f"\nWith D scale set to {scales['d']}, they were scaled down to fit into the dynamic range of float8_e4m3fn:")
|
||||
print(result)
|
||||
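
# Sketch of one way to pick the "d" scale (an assumption, not part of the
# original example): compute the reference product in float32 and scale it so
# the largest value fits inside float8_e4m3fn's range (max finite value 448).
def _suggest_d_scale(a, b, scale_a=1.0, scale_b=1.0, fp8_max=448.0):
    ref = (scale_a * a.to(torch.float32)) @ (scale_b * b.to(torch.float32))
    return fp8_max / ref.abs().max().item()
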
@@ -1,58 +0,0 @@
|
||||
import os
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from pathlib import Path
|
||||
|
||||
def save_model_in_pth_format(model_name, output_dir):
|
||||
"""
|
||||
Download a model from Hugging Face and save it in PTH format
|
||||
for use with quantization benchmarks.
|
||||
|
||||
Args:
|
||||
model_name: Name of the model on Hugging Face
|
||||
output_dir: Directory to save the model
|
||||
"""
|
||||
print(f"Loading model {model_name}...")
|
||||
|
||||
# Create output directory if it doesn't exist
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
# Load tokenizer and model
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_name,
|
||||
torch_dtype=torch.bfloat16,
|
||||
low_cpu_mem_usage=True
|
||||
)
|
||||
|
||||
# Save tokenizer
|
||||
tokenizer.save_pretrained(output_dir)
|
||||
|
||||
# Extract and save the model weights in PTH format
|
||||
model_state_dict = model.state_dict()
|
||||
|
||||
# Save the model weights
|
||||
model_path = Path(output_dir) / "model.pth"
|
||||
torch.save(model_state_dict, model_path)
|
||||
|
||||
print(f"Model saved to {model_path}")
|
||||
|
||||
# Print model size information
|
||||
param_count = sum(p.numel() for p in model.parameters())
|
||||
model_size_mb = sum(p.numel() * p.element_size() for p in model.parameters()) / (1024 * 1024)
|
||||
|
||||
print(f"Model parameters: {param_count:,}")
|
||||
print(f"Model size: {model_size_mb:.2f} MB")
|
||||
|
||||
return model_path
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Use a small model for testing
|
||||
model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
|
||||
output_dir = "./tinyllama-1.1b-chat"
|
||||
|
||||
model_path = save_model_in_pth_format(model_name, output_dir)
|
||||
|
||||
print("\nYou can now use this model with the INT4 benchmark script.")
|
||||
print("Example command:")
|
||||
print(f"python int4benchmark.py --model_path {model_path}")
|
||||
@@ -1,677 +0,0 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "cab91cfc",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/home/ubuntu/Power-RAG/.venv/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
|
||||
" from .autonotebook import tqdm as notebook_tqdm\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import copy\n",
|
||||
"import dataclasses\n",
|
||||
"import os\n",
|
||||
"import time\n",
|
||||
"import pathlib\n",
|
||||
"import itertools\n",
|
||||
"import multiprocessing\n",
|
||||
"import scipy\n",
|
||||
"import numpy as np\n",
|
||||
"import pandas as pd\n",
|
||||
"import pickle\n",
|
||||
"import gzip\n",
|
||||
"import threading\n",
|
||||
"import queue\n",
|
||||
"import pytz\n",
|
||||
"import traceback\n",
|
||||
"from datetime import datetime\n",
|
||||
"from tqdm.auto import tqdm, trange\n",
|
||||
"from typing import Any\n",
|
||||
"\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"import matplotlib.ticker as mtick\n",
|
||||
"%matplotlib inline\n",
|
||||
"%config InlineBackend.figure_format='retina'"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "8d24fbd7",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Sat Apr 12 00:10:05 2025 \n",
|
||||
"+-----------------------------------------------------------------------------------------+\n",
|
||||
"| NVIDIA-SMI 550.120 Driver Version: 550.120 CUDA Version: 12.4 |\n",
|
||||
"|-----------------------------------------+------------------------+----------------------+\n",
|
||||
"| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |\n",
|
||||
"| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |\n",
|
||||
"| | | MIG M. |\n",
|
||||
"|=========================================+========================+======================|\n",
|
||||
"| 0 NVIDIA A10G Off | 00000000:00:1E.0 Off | 0 |\n",
|
||||
"| 0% 27C P8 15W / 300W | 4MiB / 23028MiB | 0% Default |\n",
|
||||
"| | | N/A |\n",
|
||||
"+-----------------------------------------+------------------------+----------------------+\n",
|
||||
" \n",
|
||||
"+-----------------------------------------------------------------------------------------+\n",
|
||||
"| Processes: |\n",
|
||||
"| GPU GI CI PID Type Process name GPU Memory |\n",
|
||||
"| ID ID Usage |\n",
|
||||
"|=========================================================================================|\n",
|
||||
"| No running processes found |\n",
|
||||
"+-----------------------------------------------------------------------------------------+\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"!nvidia-smi"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "538b2c11",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def benchmark(f, *, f_setup=None, min_repeat: int, min_secs: float, tqdm_kwargs: dict | None=None) -> np.ndarray:\n",
|
||||
" latency = []\n",
|
||||
" \n",
|
||||
" # First run, ignore min_secs\n",
|
||||
" if f_setup is not None:\n",
|
||||
" f_setup()\n",
|
||||
" st = time.perf_counter_ns()\n",
|
||||
" f()\n",
|
||||
" ed = time.perf_counter_ns()\n",
|
||||
" latency.append((ed-st)/1e9)\n",
|
||||
" \n",
|
||||
" # Subsequent runs, until reaching both min_repeat and min_secs\n",
|
||||
" min_nanos = int(min_secs * 1e9)\n",
|
||||
" start_nanos = time.perf_counter_ns()\n",
|
||||
" while True:\n",
|
||||
" now_nanos = time.perf_counter_ns()\n",
|
||||
" if len(latency) > min_repeat and now_nanos - start_nanos > min_nanos:\n",
|
||||
" break\n",
|
||||
" if f_setup is not None:\n",
|
||||
" f_setup()\n",
|
||||
" st = time.perf_counter_ns()\n",
|
||||
" f()\n",
|
||||
" ed = time.perf_counter_ns()\n",
|
||||
" latency.append((ed-st)/1e9)\n",
|
||||
" return np.array(latency)\n",
|
||||
"\n",
|
||||
"def tail_mean(xs, skip=0.2):\n",
|
||||
" return xs[int(len(xs) * skip):].mean()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "02c9c9b1",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"<torch.autograd.grad_mode.set_grad_enabled at 0x7c5afc12b850>"
|
||||
]
|
||||
},
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import torch\n",
|
||||
"torch.set_grad_enabled(False)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"id": "3405fdc7",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"nd_list = list(itertools.chain(itertools.product([12, 3], [256])))\n",
|
||||
"seqlen_list = [256]\n",
|
||||
"bs_list = [2,4,8,16,32,64,128,256,512,1024,2048]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"id": "10dc981a",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[(12, 256), (3, 256)]\n",
|
||||
"[256]\n",
|
||||
"[2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print(nd_list)\n",
|
||||
"print(seqlen_list)\n",
|
||||
"print(bs_list)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"id": "7e0ee385",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def benchmark_dense(out, nd_list, seqlen_list, bs_list):\n",
|
||||
" seqlen_list = [1] + seqlen_list\n",
|
||||
" total = len(list(itertools.product(nd_list, seqlen_list, bs_list)))\n",
|
||||
" pbar = tqdm(total=total)\n",
|
||||
" for (n, d), seqlen in reversed(list(itertools.product(nd_list, seqlen_list))):\n",
|
||||
" h = n * d\n",
|
||||
" maxbs = max(bs_list)\n",
|
||||
" print(maxbs, n, d, seqlen)\n",
|
||||
" cache = torch.empty(int(256e6 // 4), dtype=torch.int, device=\"cuda:0\")\n",
|
||||
" X = torch.rand((maxbs, seqlen, h), dtype=torch.bfloat16, device=\"cuda:0\")\n",
|
||||
" W = torch.rand((h, h), dtype=torch.bfloat16, device=\"cuda:0\")\n",
|
||||
" torch.cuda.synchronize()\n",
|
||||
" for bs in reversed(bs_list):\n",
|
||||
" pbar.set_postfix(n=n, h=h, d=d, seqlen=seqlen, bs=bs)\n",
|
||||
" def run():\n",
|
||||
" torch.matmul(X[:bs], W)\n",
|
||||
" torch.cuda.synchronize()\n",
|
||||
" def clear_cache():\n",
|
||||
" cache.zero_()\n",
|
||||
" torch.cuda.synchronize()\n",
|
||||
" latency = benchmark(run, f_setup=clear_cache, min_repeat=20, min_secs=2)\n",
|
||||
" l = tail_mean(latency)\n",
|
||||
" out.append({\n",
|
||||
" \"n\": n,\n",
|
||||
" \"d\": d,\n",
|
||||
" \"seqlen\": seqlen,\n",
|
||||
" \"bs\": bs,\n",
|
||||
" \"latency\": l\n",
|
||||
" })\n",
|
||||
" pbar.update()\n",
|
||||
" del cache, X, W\n",
|
||||
" torch.cuda.empty_cache()\n",
|
||||
" pbar.close()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"id": "c206a502",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def benchmark_qk_init(out, nd_list, seqlen_list, bs_list):\n",
|
||||
" total = len(list(itertools.product(nd_list, seqlen_list, bs_list)))\n",
|
||||
" pbar = tqdm(total=total)\n",
|
||||
" for (n, d), seqlen in reversed(list(itertools.product(nd_list, seqlen_list))):\n",
|
||||
" h = n * d\n",
|
||||
" try:\n",
|
||||
" maxbs = max(b for b in bs_list if b*n*seqlen*d*2*2+b*n*seqlen**2*2 < 80e9)\n",
|
||||
" except ValueError:\n",
|
||||
" pbar.update(len(bs_list))\n",
|
||||
" continue\n",
|
||||
" cache = torch.empty(int(256e6 // 4), dtype=torch.int, device=\"cuda:0\")\n",
|
||||
" Qmax = torch.rand((maxbs, n, seqlen, d), dtype=torch.bfloat16, device=\"cuda:0\")\n",
|
||||
" Kmax = torch.rand((maxbs, n, seqlen, d), dtype=torch.bfloat16, device=\"cuda:0\")\n",
|
||||
" torch.cuda.synchronize()\n",
|
||||
" for bs in reversed(bs_list):\n",
|
||||
" pbar.set_postfix(n=n, h=h, d=d, seqlen=seqlen, bs=bs)\n",
|
||||
" if bs > maxbs:\n",
|
||||
" pbar.update()\n",
|
||||
" continue\n",
|
||||
" Q = Qmax[:bs]\n",
|
||||
" K = Kmax[:bs]\n",
|
||||
" def run():\n",
|
||||
" torch.bmm(Q.view(bs * n, seqlen, d), K.view(bs * n, seqlen, d).transpose(1, 2))\n",
|
||||
" torch.cuda.synchronize()\n",
|
||||
" def clear_cache():\n",
|
||||
" cache.zero_()\n",
|
||||
" torch.cuda.synchronize()\n",
|
||||
" latency = benchmark(run, f_setup=clear_cache, min_repeat=20, min_secs=2)\n",
|
||||
" l = tail_mean(latency)\n",
|
||||
" out.append({\n",
|
||||
" \"n\": n,\n",
|
||||
" \"d\": d,\n",
|
||||
" \"seqlen\": seqlen,\n",
|
||||
" \"bs\": bs,\n",
|
||||
" \"latency\": l\n",
|
||||
" })\n",
|
||||
" pbar.update()\n",
|
||||
" del cache, Q, K, Qmax, Kmax\n",
|
||||
" torch.cuda.empty_cache()\n",
|
||||
" pbar.close()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"id": "a3a2103c",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def benchmark_qk_ar(out, nd_list, seqlen_list, bs_list):\n",
|
||||
" total = len(list(itertools.product(nd_list, seqlen_list, bs_list)))\n",
|
||||
" pbar = tqdm(total=total)\n",
|
||||
" for (n, d), seqlen in reversed(list(itertools.product(nd_list, seqlen_list))):\n",
|
||||
" h = n * d\n",
|
||||
" try:\n",
|
||||
" maxbs = max(b for b in bs_list if b*n*(1+seqlen)*d*2+b*n*seqlen*2 < 80e9)\n",
|
||||
" except ValueError:\n",
|
||||
" pbar.update(len(bs_list))\n",
|
||||
" continue\n",
|
||||
" cache = torch.empty(int(256e6 // 4), dtype=torch.int, device=\"cuda:0\")\n",
|
||||
" Qmax = torch.rand((maxbs, n, 1, d), dtype=torch.bfloat16, device=\"cuda:0\")\n",
|
||||
" Kmax = torch.rand((maxbs, n, seqlen, d), dtype=torch.bfloat16, device=\"cuda:0\")\n",
|
||||
" torch.cuda.synchronize()\n",
|
||||
" for bs in reversed(bs_list):\n",
|
||||
" pbar.set_postfix(n=n, h=h, d=d, seqlen=seqlen, bs=bs)\n",
|
||||
" if bs > maxbs:\n",
|
||||
" pbar.update()\n",
|
||||
" continue\n",
|
||||
" Q = Qmax[:bs]\n",
|
||||
" K = Kmax[:bs]\n",
|
||||
" def run():\n",
|
||||
" torch.bmm(Q.view(bs * n, 1, d), K.view(bs * n, seqlen, d).transpose(1, 2))\n",
|
||||
" torch.cuda.synchronize()\n",
|
||||
" def clear_cache():\n",
|
||||
" cache.zero_()\n",
|
||||
" torch.cuda.synchronize()\n",
|
||||
" latency = benchmark(run, f_setup=clear_cache, min_repeat=20, min_secs=2)\n",
|
||||
" l = tail_mean(latency)\n",
|
||||
" out.append({\n",
|
||||
" \"n\": n,\n",
|
||||
" \"d\": d,\n",
|
||||
" \"seqlen\": seqlen,\n",
|
||||
" \"bs\": bs,\n",
|
||||
" \"latency\": l\n",
|
||||
" })\n",
|
||||
" pbar.update()\n",
|
||||
" del cache, Q, K, Qmax, Kmax\n",
|
||||
" torch.cuda.empty_cache()\n",
|
||||
" pbar.close()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"id": "3aaad98a",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"data = {}"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"id": "18137de3",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
" 0%| | 0/22 [00:00<?, ?it/s]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"100%|██████████| 22/22 [00:44<00:00, 2.04s/it, bs=2, d=256, h=3072, n=12, seqlen=256] \n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"db = []\n",
|
||||
"benchmark_qk_init(db, nd_list, seqlen_list, bs_list)\n",
|
||||
"data[\"qk_init\"] = db"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"id": "26c76e15",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"100%|██████████| 22/22 [00:44<00:00, 2.01s/it, bs=2, d=256, h=3072, n=12, seqlen=256] \n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"db = []\n",
|
||||
"benchmark_qk_ar(db, nd_list, seqlen_list, bs_list)\n",
|
||||
"data[\"qk_ar\"] = db"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"id": "313e36eb",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
" 0%| | 0/44 [00:00<?, ?it/s, bs=2048, d=256, h=768, n=3, seqlen=256]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"2048 3 256 256\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
" 25%|██▌ | 11/44 [00:22<01:06, 2.00s/it, bs=2048, d=256, h=768, n=3, seqlen=1] "
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"2048 3 256 1\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
" 50%|█████ | 22/44 [00:44<00:44, 2.00s/it, bs=2048, d=256, h=3072, n=12, seqlen=256]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"2048 12 256 256\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
" 75%|███████▌ | 33/44 [01:07<00:22, 2.02s/it, bs=2048, d=256, h=3072, n=12, seqlen=1] "
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"2048 12 256 1\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"100%|██████████| 44/44 [01:29<00:00, 2.03s/it, bs=2, d=256, h=3072, n=12, seqlen=1] \n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"db = []\n",
|
||||
"benchmark_dense(db, nd_list, seqlen_list, bs_list)\n",
|
||||
"data[\"dense\"] = db"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"id": "50c37959",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"with gzip.open(\"data/20230516-transformer-batching1.pkl.gz\", \"wb\") as f:\n",
|
||||
" pickle.dump(data, f)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 15,
|
||||
"id": "828ddb54",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"df_dense = (\n",
|
||||
" pd.DataFrame.from_dict(data[\"dense\"])\n",
|
||||
" .assign(h=lambda x: x[\"n\"] * x[\"d\"])\n",
|
||||
" .assign(flop=lambda x: (x[\"bs\"] * x[\"seqlen\"] * x[\"h\"]**2) * 2)\n",
|
||||
" .assign(io=lambda x: (x[\"bs\"]*x[\"seqlen\"]*x[\"h\"]*2 + x[\"h\"]**2) * 2/x['latency']/1e9)\n",
|
||||
" .assign(intensity=lambda x: x[\"flop\"] / x[\"io\"])\n",
|
||||
" .assign(throughput=lambda x: x[\"flop\"] / x[\"latency\"])\n",
|
||||
" .assign(series=\"dense\")\n",
|
||||
")\n",
|
||||
"df_qk_init = (\n",
|
||||
" pd.DataFrame.from_dict(data[\"qk_init\"])\n",
|
||||
" .assign(h=lambda x: x[\"n\"] * x[\"d\"])\n",
|
||||
" .assign(flop=lambda x: (x[\"bs\"]*x[\"n\"]*x[\"d\"]*x[\"seqlen\"]**2) * 2)\n",
|
||||
" .assign(io=lambda x: (x[\"bs\"]*x[\"n\"]*(x[\"seqlen\"]*x[\"d\"]*2 + x[\"seqlen\"]**2)) * 2/x['latency']/1e9)\n",
|
||||
" .assign(intensity=lambda x: x[\"flop\"] / x[\"io\"])\n",
|
||||
" .assign(throughput=lambda x: x[\"flop\"] / x[\"latency\"])\n",
|
||||
" .assign(series=\"qk_init\")\n",
|
||||
")\n",
|
||||
"df_qk_ar = (\n",
|
||||
" pd.DataFrame.from_dict(data[\"qk_ar\"])\n",
|
||||
" .assign(h=lambda x: x[\"n\"] * x[\"d\"])\n",
|
||||
" .assign(flop=lambda x: (x[\"bs\"]*x[\"n\"]*x[\"d\"]*x[\"seqlen\"]) * 2)\n",
|
||||
" .assign(io=lambda x: (x[\"bs\"]*x[\"n\"]*(x[\"d\"] + x[\"seqlen\"]*x[\"d\"] + x[\"seqlen\"])) * 2)\n",
|
||||
" .assign(intensity=lambda x: x[\"flop\"] / x[\"io\"])\n",
|
||||
" .assign(throughput=lambda x: x[\"bs\"] / x[\"latency\"])\n",
|
||||
" .assign(series=\"qk_ar\")\n",
|
||||
")\n",
|
||||
"pd.concat([df_dense, df_qk_init, df_qk_ar]).to_csv(\"data/transformer-batching-microbenchmarks.csv\", index=False)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 39,
|
||||
"id": "c296a395",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"<module 'pandas' from '/home/ubuntu/Power-RAG/.venv/lib/python3.10/site-packages/pandas/__init__.py'>"
|
||||
]
|
||||
},
|
||||
"execution_count": 39,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"pd\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "a25cdd5a",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "63b8a531",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import transformers"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "af90eff1",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def _gen_opt_cfg(n_layers: int, d_model: int, n_heads: int, **kwargs) -> transformers.OPTConfig:\n",
|
||||
" return transformers.OPTConfig(\n",
|
||||
" num_hidden_layers=n_layers,\n",
|
||||
" hidden_size=d_model,\n",
|
||||
" ffn_dim=d_model*4,\n",
|
||||
" num_attention_heads=n_heads,\n",
|
||||
" **kwargs\n",
|
||||
" )\n",
|
||||
"optcfg = {\n",
|
||||
" # https://arxiv.org/pdf/2205.01068.pdf Table 2.1\n",
|
||||
" \"125m\": _gen_opt_cfg(12, 768, 12),\n",
|
||||
" \"350m\": _gen_opt_cfg(24, 1024, 16),\n",
|
||||
" \"760m\": _gen_opt_cfg(24, 1536, 16),\n",
|
||||
" \"1.3b\": _gen_opt_cfg(24, 2048, 32),\n",
|
||||
" \"2.7b\": _gen_opt_cfg(32, 2560, 32),\n",
|
||||
" \"6.7b\": _gen_opt_cfg(32, 4096, 32),\n",
|
||||
" \"13b\": _gen_opt_cfg(40, 5120, 40),\n",
|
||||
" \"13b_1layer\": _gen_opt_cfg(1, 5120, 40),\n",
|
||||
" \"30b\": _gen_opt_cfg(48, 7168, 56),\n",
|
||||
" \"66b\": _gen_opt_cfg(64, 9216, 72),\n",
|
||||
" \"175b\": _gen_opt_cfg(96, 12288, 96),\n",
|
||||
" \"175b_1layer\": _gen_opt_cfg(1, 12288, 96),\n",
|
||||
"}"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "5b9ebbec",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def greedy_sample_one(model, input_ids, attention_mask=None, past_key_values=None):\n",
|
||||
" bs, tgt_len = input_ids.shape\n",
|
||||
" if past_key_values is not None:\n",
|
||||
" _bs, _num_heads, src_len, _head_dims = past_key_values[0][0].shape\n",
|
||||
" assert bs == _bs\n",
|
||||
" else:\n",
|
||||
" src_len = 0\n",
|
||||
" if attention_mask is None:\n",
|
||||
" attention_mask = torch.ones((bs, src_len + tgt_len), device=model.device)\n",
|
||||
" ret = model(\n",
|
||||
" input_ids=input_ids,\n",
|
||||
" attention_mask=attention_mask,\n",
|
||||
" past_key_values=past_key_values,\n",
|
||||
" use_cache=True, output_hidden_states=False, return_dict=True,\n",
|
||||
" )\n",
|
||||
" return ret\n",
|
||||
"\n",
|
||||
"def time_greedy_generate(model, input_ids, new_tokens):\n",
|
||||
" ts = []\n",
|
||||
" output = input_ids\n",
|
||||
" past_key_values = None\n",
|
||||
" cache = torch.empty(int(256e6 // 4), dtype=torch.int, device=model.device)\n",
|
||||
" attention_mask = torch.ones(input_ids.shape, device=model.device) \n",
|
||||
" for _ in range(new_tokens):\n",
|
||||
" cache.zero_()\n",
|
||||
" torch.cuda.synchronize()\n",
|
||||
" st = time.perf_counter_ns()\n",
|
||||
" \n",
|
||||
" ret = greedy_sample_one(model, input_ids, attention_mask, past_key_values)\n",
|
||||
" input_ids = torch.argmax(ret.logits[:, -1, :], axis=-1)[:, None]\n",
|
||||
" output = torch.cat([output, input_ids], axis=1)\n",
|
||||
" past_key_values = ret.past_key_values\n",
|
||||
" attention_mask = torch.cat([attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1)\n",
|
||||
" \n",
|
||||
" torch.cuda.synchronize()\n",
|
||||
" ed = time.perf_counter_ns()\n",
|
||||
" ts.append((ed-st)/1e9)\n",
|
||||
" return np.array(ts)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "fc92f940",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"opt_config = optcfg[\"6.7b\"]\n",
|
||||
"\n",
|
||||
"torch.set_default_dtype(torch.bfloat16)\n",
|
||||
"with transformers.modeling_utils.no_init_weights():\n",
|
||||
" model = transformers.models.opt.OPTForCausalLM(opt_config).to(\"cuda\")\n",
|
||||
"torch.set_default_dtype(torch.float32)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "c19fa396",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"db = {}\n",
|
||||
"input_tokens = 200\n",
|
||||
"new_tokens = 500\n",
|
||||
"for bs in tqdm(list(itertools.chain(range(1, 8), range(8, 16, 2), [16]))):\n",
|
||||
" x = torch.randint(1000, 10000, (bs, input_tokens), device=model.device)\n",
|
||||
" stack = []\n",
|
||||
" for _ in range(10):\n",
|
||||
" l = time_greedy_generate(model, x, new_tokens=new_tokens)\n",
|
||||
" stack.append(l)\n",
|
||||
" db[bs] = np.median(np.stack(stack), axis=0)\n",
|
||||
" del x\n",
|
||||
" torch.cuda.empty_cache()\n",
|
||||
"del model\n",
|
||||
"torch.cuda.empty_cache()\n",
|
||||
"\n",
|
||||
"with gzip.open(\"data/20230516-e2e-text-generation-batch.pkl.gz\", \"wb\") as f:\n",
|
||||
" pickle.dump(db, f)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": ".venv",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.12"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
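A minimal sketch of how the exported data/transformer-batching-microbenchmarks.csv could be inspected afterwards (illustration only, not part of the committed notebook; the column names come from the .assign() chains above, and note that throughput is FLOP/s for the dense and qk_init series but requests/s, i.e. bs/latency, for qk_ar):
# Illustrative sketch only: load the CSV written by the notebook above and
# compare throughput across the three benchmark series as batch size grows.
import pandas as pd
import matplotlib.pyplot as plt

df = pd.read_csv("data/transformer-batching-microbenchmarks.csv")

fig, ax = plt.subplots()
for series, grp in df.groupby("series"):  # "dense", "qk_init", "qk_ar"
    grp = grp.sort_values("bs")
    # "throughput" is FLOP/s for dense/qk_init and requests/s for qk_ar,
    # as defined in the notebook cell above.
    ax.plot(grp["bs"], grp["throughput"], marker="o", label=series)
ax.set_xscale("log", base=2)
ax.set_yscale("log")
ax.set_xlabel("batch size")
ax.set_ylabel("throughput (as defined per series in the notebook)")
ax.legend()
fig.tight_layout()
plt.show()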
||||
@@ -1,165 +0,0 @@
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
# Set plot parameters
|
||||
plt.rcParams["font.family"] = "Helvetica"
|
||||
plt.rcParams["ytick.direction"] = "in"
|
||||
plt.rcParams["hatch.linewidth"] = 1.5
|
||||
plt.rcParams["font.weight"] = "bold"
|
||||
plt.rcParams["axes.labelweight"] = "bold"
|
||||
plt.rcParams["text.usetex"] = True
|
||||
|
||||
# Path settings
|
||||
FIGURE_PATH = "./paper_plot/figures"
|
||||
|
||||
# Load accuracy data
|
||||
acc_data = pd.read_csv("./paper_plot/data/acc.csv")
|
||||
|
||||
# Create figure with 4 subplots (one for each dataset)
|
||||
fig, axs = plt.subplots(1, 4)
|
||||
fig.set_size_inches(9, 2.5)
|
||||
|
||||
# Reduce the spacing between subplots
|
||||
# plt.subplots_adjust(wspace=0.2) # Reduced from 0.3 to 0.1
|
||||
|
||||
# Define datasets and their columns
|
||||
datasets = ["NQ", "TriviaQA", "GPQA", "HotpotQA"]
|
||||
metrics = ["Exact Match", "F1"]
|
||||
|
||||
# Define bar settings - make bars thicker
|
||||
# total_width, n = 0.9, 3 # increased total width and n for three models
|
||||
# width = total_width / n
|
||||
# The 'width' variable below now defines the distance between the centers of adjacent bars within a group.
|
||||
# It's also used as the base for calculating the actual plotted bar width.
|
||||
# Original 2 bars had centers 1.0 apart. For 3 bars, we need a smaller distance.
|
||||
# A value of 0.64 for distance between centers, with a scaling factor of 0.8 for bar width,
|
||||
# results in an actual bar width of ~0.51, and a group span of ~1.79, similar to original's ~1.76.
|
||||
n = 3 # Number of models
|
||||
width = 0.64 # Distance between centers of adjacent bars in a group
|
||||
bar_width_plotting_factor = 0.8 # Bar takes 80% of the space defined by 'width'
|
||||
|
||||
# Colors and hatches
|
||||
edgecolors = ["dimgrey", "#63B8B6", "tomato"] # Added color for PQ 5
|
||||
hatches = ["/////", "xxxxx", "\\\\\\\\\\"] # Added hatch for PQ 5
|
||||
labels = ["BM25", "PQ Compressed", "Ours"] # Added PQ 5
|
||||
|
||||
# Create plots for each dataset
|
||||
for i, dataset in enumerate(datasets):
|
||||
ax = axs[i]
|
||||
|
||||
# Get data for this dataset and convert to percentages
|
||||
em_values = [
|
||||
acc_data.loc[0, f"{dataset} Exact Match"] * 100,
|
||||
acc_data.loc[1, f"{dataset} Exact Match"] * 100,
|
||||
acc_data.loc[2, f"{dataset} Exact Match"] * 100 # Added PQ 5 EM data
|
||||
]
|
||||
f1_values = [
|
||||
acc_data.loc[0, f"{dataset} F1"] * 100,
|
||||
acc_data.loc[1, f"{dataset} F1"] * 100,
|
||||
acc_data.loc[2, f"{dataset} F1"] * 100 # Added PQ 5 F1 data
|
||||
]
|
||||
|
||||
# Define x positions for bars
|
||||
# For EM: center - width, center, center + width
|
||||
# For F1: center - width, center, center + width
|
||||
group_centers = [1.0, 3.0] # Centers for EM and F1 groups
|
||||
bar_offsets = [-width, 0, width]
|
||||
|
||||
# Plot all bars on the same axis
|
||||
for metric_idx, metric_group_center in enumerate(group_centers):
|
||||
values_to_plot = em_values if metric_idx == 0 else f1_values
|
||||
for j, model_label in enumerate(labels):
|
||||
x_pos = metric_group_center + bar_offsets[j]
|
||||
bar_value = values_to_plot[j]
|
||||
|
||||
ax.bar(
|
||||
x_pos,
|
||||
bar_value,
|
||||
width=width * bar_width_plotting_factor, # Use the new factor for bar width
|
||||
color="white",
|
||||
edgecolor=edgecolors[j],
|
||||
hatch=hatches[j],
|
||||
linewidth=1.5,
|
||||
label=model_label if i == 0 and metric_idx == 0 else None # Label only once
|
||||
)
|
||||
|
||||
# Add value on top of bar
|
||||
ax.text(x_pos, bar_value + (0.1 if dataset == "GPQA" else 0.1),
|
||||
f"{bar_value:.1f}", ha='center', va='bottom',
|
||||
fontsize=9, fontweight='bold') # Reduced fontsize for text on bars
|
||||
|
||||
# Set x-ticks and labels
|
||||
ax.set_xticks(group_centers) # Position ticks at the center of each group
|
||||
xticklabels = ax.set_xticklabels(metrics, fontsize=12)
|
||||
|
||||
# Now, shift these labels slightly to the right
|
||||
# Adjust this value to control the amount of shift (in data coordinates)
|
||||
# Given your group_centers are 1.0 and 3.0, a small value like 0.05 to 0.15 might be appropriate.
|
||||
# horizontal_shift = 0.7 # Try adjusting this value
|
||||
|
||||
# for label in xticklabels:
|
||||
# # Get the current x position (which is the tick location)
|
||||
# current_x_pos = label.get_position()[0]
|
||||
# # Set the new x position by adding the shift
|
||||
# label.set_position((current_x_pos + horizontal_shift, label.get_position()[1]))
|
||||
# # Ensure the label remains horizontally centered on this new x position
|
||||
# # (set_xticklabels defaults to 'center', so this re-affirms it if needed)
|
||||
# label.set_horizontalalignment('center')
|
||||
|
||||
# Set title
|
||||
ax.set_title(dataset, fontsize=14)
|
||||
|
||||
# Set y-label for all subplots
|
||||
if i == 0:
|
||||
ax.set_ylabel("Accuracy (\%)", fontsize=12, fontweight="bold")
|
||||
else:
|
||||
# Hide y-tick labels for non-first subplots to save space
|
||||
ax.tick_params(axis='y', labelsize=10)
|
||||
|
||||
# Set y-limits based on data range
|
||||
all_values = em_values + f1_values
|
||||
max_val = max(all_values)
|
||||
min_val = min(all_values)
|
||||
|
||||
# Special handling for GPQA which has very low values
|
||||
if dataset == "GPQA":
|
||||
ax.set_ylim(0, 10.0) # Set a fixed range for GPQA
|
||||
else:
|
||||
# Reduce the extra space above the bars
|
||||
ax.set_ylim(min_val * 0.9, max_val * 1.1) # Adjusted upper limit for text
|
||||
|
||||
# Format y-ticks as percentages
|
||||
ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda y, _: ' {:.0f}'.format(y)))
|
||||
|
||||
# Set x-limits to properly space the bars with less blank space
|
||||
# ax.set_xlim(group_centers[0] - total_width, group_centers[1] + total_width)
|
||||
# Set xlim to be similar to original (0,4) for group_centers (1,3) => margin of 1.0
|
||||
ax.set_xlim(group_centers[0] - 1.0, group_centers[1] + 1.0)
|
||||
|
||||
# Add a box around the subplot
|
||||
# for spine in ax.spines.values():
|
||||
# spine.set_visible(True)
|
||||
# spine.set_linewidth(1.0)
|
||||
|
||||
# Add legend to first subplot
|
||||
if i == 0:
|
||||
ax.legend(
|
||||
bbox_to_anchor=(2.21, 1.35), # Adjusted anchor if needed
|
||||
ncol=3, # Changed to 3 columns for three labels
|
||||
loc="upper center",
|
||||
labelspacing=0.1,
|
||||
edgecolor="black",
|
||||
facecolor="white",
|
||||
framealpha=1,
|
||||
shadow=False,
|
||||
fancybox=False,
|
||||
handlelength=1.0,
|
||||
handletextpad=0.6,
|
||||
columnspacing=0.8,
|
||||
prop={"weight": "bold", "size": 12},
|
||||
)
|
||||
|
||||
# Save figure with tight layout but no additional padding
|
||||
plt.savefig(FIGURE_PATH + "/accuracy_comparison.pdf", bbox_inches='tight', pad_inches=0.05)
|
||||
plt.show()
|
||||
@@ -1,309 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
# Motto: Were It to Benefit My Country, I Would Lay Down My Life!
|
||||
# \file: /hnsw_degree_visit_plot_binned_academic.py
|
||||
# \brief: Generates a binned bar plot of HNSW node average per-query visit probability
|
||||
# per degree bin, styled for academic publications, with caching.
|
||||
# Author: raphael hao (Original script by user, styling and caching adapted by Gemini)
|
||||
|
||||
# %%
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import matplotlib.pyplot as plt
|
||||
import re
|
||||
from collections import Counter
|
||||
import os # For robust filepath manipulation
|
||||
import math # For calculating scaling factor
|
||||
import pickle # For caching data
|
||||
|
||||
# %%
|
||||
# --- Matplotlib parameters for academic paper style (from reference) ---
|
||||
plt.rcParams["font.family"] = "Helvetica"
|
||||
plt.rcParams["ytick.direction"] = "in"
|
||||
plt.rcParams["hatch.linewidth"] = 1.5
|
||||
plt.rcParams["font.weight"] = "bold"
|
||||
plt.rcParams["axes.labelweight"] = "bold"
|
||||
plt.rcParams["text.usetex"] = True # Use LaTeX for text rendering (if available)
|
||||
|
||||
# --- Define styles from reference ---
|
||||
edgecolors_ref = ["dimgrey", "#63B8B6", "tomato", "silver", "slategray"]
|
||||
|
||||
# %%
|
||||
# --- File Paths ---
|
||||
degree_file = '/opt/dlami/nvme/scaling_out/indices/rpj_wiki/facebook/contriever-msmarco/hnsw/degree_distribution.txt'
|
||||
visit_log_file = './re.log'
|
||||
output_image_file = './paper_plot/figures/hnsw_visit_count_per_degree_corrected.pdf'
|
||||
# --- CACHE FILE PATH: Keep this consistent ---
|
||||
CACHE_FILE_PATH = './binned_plot_data_cache.pkl'
|
||||
|
||||
# --- Configuration ---
|
||||
# Set to True to bypass cache and force recomputation.
|
||||
# Otherwise, delete CACHE_FILE_PATH manually to force recomputation.
|
||||
FORCE_RECOMPUTE = False
|
||||
NUMBER_OF_QUERIES = 1000.0 # Number of queries the visit_counts are based on
|
||||
|
||||
# Create directory for figures if it doesn't exist
|
||||
output_dir = os.path.dirname(output_image_file)
|
||||
if output_dir and not os.path.exists(output_dir):
|
||||
os.makedirs(output_dir)
|
||||
print(f"Created directory: {output_dir}")
|
||||
|
||||
# %%
|
||||
# --- Attempt to load data from cache or compute ---
|
||||
df_plot_data = None
|
||||
bin_size_for_plot = None # Will hold the bin_size associated with df_plot_data
|
||||
|
||||
if not FORCE_RECOMPUTE and os.path.exists(CACHE_FILE_PATH):
|
||||
try:
|
||||
with open(CACHE_FILE_PATH, 'rb') as f:
|
||||
cache_content = pickle.load(f)
|
||||
df_plot_data = cache_content['data']
|
||||
bin_size_for_plot = cache_content['bin_size']
|
||||
# Basic validation of cached data
|
||||
# Expecting 'average_visit_count_per_node_in_bin' (raw average over NUMBER_OF_QUERIES)
|
||||
if not isinstance(df_plot_data, pd.DataFrame) or \
|
||||
'degree_bin_label' not in df_plot_data.columns or \
|
||||
'average_visit_count_per_node_in_bin' not in df_plot_data.columns or \
|
||||
not isinstance(bin_size_for_plot, int):
|
||||
print("Cached data is not in the expected format or missing 'average_visit_count_per_node_in_bin'. Recomputing.")
|
||||
df_plot_data = None # Invalidate to trigger recomputation
|
||||
else:
|
||||
print(f"Successfully loaded binned data from cache: {CACHE_FILE_PATH}")
|
||||
|
||||
# --- Modify the label loaded from cache for display purpose ---
|
||||
# This modification only happens when data is loaded from cache and meets specific conditions.
|
||||
# Assumption: If the bin_size_for_plot in cache is 5,
|
||||
# then the original label "0-4" actually represents nodes with degree 1-4 (because you guarantee no 0-degree nodes).
|
||||
if df_plot_data is not None and 'degree_bin_label' in df_plot_data.columns and bin_size_for_plot == 5:
|
||||
# Check if "0-4" label exists
|
||||
if '0-4' in df_plot_data['degree_bin_label'].values:
|
||||
# Use .loc to ensure the modification is on the original DataFrame
|
||||
df_plot_data.loc[df_plot_data['degree_bin_label'] == '0-4', 'degree_bin_label'] = '1-4'
|
||||
print("Modified degree_bin_label from '0-4' to '1-4' for display purpose.")
|
||||
except Exception as e:
|
||||
print(f"Error loading from cache: {e}. Recomputing.")
|
||||
df_plot_data = None # Invalidate to trigger recomputation
|
||||
|
||||
if df_plot_data is None:
|
||||
print("Cache not found, invalid, or recompute forced. Computing data from scratch...")
|
||||
# --- 1. Read Degree Distribution File ---
|
||||
degrees_data = []
|
||||
try:
|
||||
with open(degree_file, 'r') as f:
|
||||
for i, line in enumerate(f):
|
||||
line_stripped = line.strip()
|
||||
if line_stripped:
|
||||
degrees_data.append({'node_id': i, 'degree': int(line_stripped)})
|
||||
except FileNotFoundError:
|
||||
print(f"Error: Degree file '{degree_file}' not found. Using dummy data for degrees.")
|
||||
degrees_data = [{'node_id': i, 'degree': (i % 20) + 1 } for i in range(200)]
|
||||
degrees_data.extend([{'node_id': 200+i, 'degree': i} for i in range(58, 67)]) # For 60-64 bin
|
||||
degrees_data.extend([{'node_id': 300+i, 'degree': (i % 5)+1} for i in range(10)]) # Low degrees
|
||||
degrees_data.extend([{'node_id': 400+i, 'degree': 80 + (i%5)} for i in range(10)]) # High degrees
|
||||
|
||||
|
||||
if not degrees_data:
|
||||
print(f"Critical Error: No data loaded or generated for degrees. Exiting.")
|
||||
exit()
|
||||
df_degrees = pd.DataFrame(degrees_data)
|
||||
print(f"Successfully loaded/generated {len(df_degrees)} degree entries.")
|
||||
|
||||
# --- 2. Read Visit Log File and Count Frequencies ---
|
||||
visit_counts = Counter()
|
||||
node_id_pattern = re.compile(r"Vis(i)?ted node: (\d+)")
|
||||
try:
|
||||
with open(visit_log_file, 'r') as f_log:
|
||||
for line_num, line in enumerate(f_log, 1):
|
||||
match = node_id_pattern.search(line)
|
||||
if match:
|
||||
try:
|
||||
node_id = int(match.group(2))
|
||||
visit_counts[node_id] += 1 # Increment visit count for the node
|
||||
except ValueError:
|
||||
print(f"Warning: Non-integer node_id in log '{visit_log_file}' line {line_num}: {line.strip()}")
|
||||
except FileNotFoundError:
|
||||
print(f"Warning: Visit log file '{visit_log_file}' not found. Using dummy visit counts.")
|
||||
if not df_degrees.empty:
|
||||
for node_id_val in df_degrees['node_id'].sample(frac=0.9, random_state=1234): # Seed for reproducibility
|
||||
degree_val = df_degrees[df_degrees['node_id'] == node_id_val]['degree'].iloc[0]
|
||||
# Generate visit counts to test different probability magnitudes
|
||||
if node_id_val % 23 == 0: # Very low probability
|
||||
lambda_val = 0.0005 * (100 / (max(1,degree_val) + 1)) # avg visits over 1k queries
|
||||
elif node_id_val % 11 == 0: # Low probability
|
||||
lambda_val = 0.05 * (100 / (max(1,degree_val) + 1))
|
||||
elif node_id_val % 5 == 0: # Moderate probability
|
||||
lambda_val = 2.5 * (100 / (max(1,degree_val) + 1))
|
||||
else: # Higher probability (but still < 1000 visits for a single node usually)
|
||||
lambda_val = 50 * (100 / (max(1,degree_val) + 1))
|
||||
visit_counts[node_id_val] = np.random.poisson(lambda_val)
|
||||
if visit_counts[node_id_val] < 0: visit_counts[node_id_val] = 0
|
||||
|
||||
if not visit_counts:
|
||||
print(f"Warning: No visit data parsed/generated. Plot may show zero visits.")
|
||||
df_visits = pd.DataFrame(columns=['node_id', 'visit_count'])
|
||||
else:
|
||||
df_visits_list = [{'node_id': nid, 'visit_count': count} for nid, count in visit_counts.items()]
|
||||
df_visits = pd.DataFrame(df_visits_list)
|
||||
print(f"Parsed/generated {len(df_visits)} unique visited nodes, totaling {sum(visit_counts.values())} visits (simulated over {NUMBER_OF_QUERIES} queries).")
|
||||
|
||||
# --- 3. Merge Degree Data with Visit Data ---
|
||||
df_merged = pd.merge(df_degrees, df_visits, on='node_id', how='left')
|
||||
df_merged['visit_count'] = df_merged['visit_count'].fillna(0).astype(float) # visit_count is total over NUMBER_OF_QUERIES
|
||||
print(f"Merged data contains {len(df_merged)} entries.")
|
||||
|
||||
# --- 5. Binning Degrees and Calculating Average Visit Count per Node in Bin (over NUMBER_OF_QUERIES) ---
|
||||
current_bin_size = 5
|
||||
bin_size_for_plot = current_bin_size
|
||||
|
||||
if not df_degrees.empty:
|
||||
print(f"\nBinning degrees into groups of {current_bin_size} for average visit count calculation...")
|
||||
|
||||
df_merged_with_bins = df_merged.copy()
|
||||
df_merged_with_bins['degree_bin_start'] = (df_merged_with_bins['degree'] // current_bin_size) * current_bin_size
|
||||
|
||||
df_binned_analysis = df_merged_with_bins.groupby('degree_bin_start').agg(
|
||||
total_visit_count_in_bin=('visit_count', 'sum'),
|
||||
node_count_in_bin=('node_id', 'nunique')
|
||||
).reset_index()
|
||||
|
||||
# This is the average number of times a node in this bin was visited over NUMBER_OF_QUERIES queries.
|
||||
# This value is what gets cached.
|
||||
df_binned_analysis['average_visit_count_per_node_in_bin'] = 0.0
|
||||
df_binned_analysis.loc[df_binned_analysis['node_count_in_bin'] > 0, 'average_visit_count_per_node_in_bin'] = \
|
||||
df_binned_analysis['total_visit_count_in_bin'] / df_binned_analysis['node_count_in_bin']
|
||||
|
||||
df_binned_analysis['degree_bin_label'] = df_binned_analysis['degree_bin_start'].astype(str) + '-' + \
|
||||
(df_binned_analysis['degree_bin_start'] + current_bin_size - 1).astype(str)
|
||||
|
||||
bin_to_drop_label = '60-64'
|
||||
original_length = len(df_binned_analysis)
|
||||
df_plot_data_intermediate = df_binned_analysis[df_binned_analysis['degree_bin_label'] != bin_to_drop_label].copy()
|
||||
if len(df_plot_data_intermediate) < original_length:
|
||||
print(f"\nManually dropped the bin: '{bin_to_drop_label}'")
|
||||
else:
|
||||
print(f"\nNote: Bin '{bin_to_drop_label}' not found for dropping or already removed.")
|
||||
|
||||
df_plot_data = df_plot_data_intermediate
|
||||
|
||||
print(f"\nBinned data (average visit count per node in bin over {NUMBER_OF_QUERIES} queries) for plotting prepared:")
|
||||
print(df_plot_data[['degree_bin_label', 'average_visit_count_per_node_in_bin']].head())
|
||||
|
||||
if df_plot_data is not None and not df_plot_data.empty:
|
||||
try:
|
||||
with open(CACHE_FILE_PATH, 'wb') as f:
|
||||
pickle.dump({'data': df_plot_data, 'bin_size': bin_size_for_plot}, f)
|
||||
print(f"Saved computed binned data to cache: {CACHE_FILE_PATH}")
|
||||
except Exception as e:
|
||||
print(f"Error saving data to cache: {e}")
|
||||
elif df_plot_data is None or df_plot_data.empty:
|
||||
print("Computed data for binned plot is empty, not saving to cache.")
|
||||
else:
|
||||
print("Degree data (df_degrees) is empty. Cannot perform binning.")
|
||||
df_plot_data = pd.DataFrame()
|
||||
bin_size_for_plot = current_bin_size
|
||||
|
||||
# %%
|
||||
# --- 6. Plotting (Binned Bar Chart - Academic Style) ---
|
||||
|
||||
if df_plot_data is not None and not df_plot_data.empty and 'average_visit_count_per_node_in_bin' in df_plot_data.columns:
|
||||
base_name, ext = os.path.splitext(output_image_file)
|
||||
# --- OUTPUT PDF FILE NAME: Keep this consistent ---
|
||||
binned_output_image_file = base_name + ext
|
||||
|
||||
fig, ax = plt.subplots(figsize=(6, 2.5)) # Adjusted figure size
|
||||
|
||||
df_plot_data_plotting = df_plot_data.copy()
|
||||
# Calculate per-query probability: (avg visits over N queries) / N
|
||||
df_plot_data_plotting['per_query_visit_probability'] = \
|
||||
df_plot_data_plotting['average_visit_count_per_node_in_bin'] / NUMBER_OF_QUERIES
|
||||
|
||||
max_probability = df_plot_data_plotting['per_query_visit_probability'].max()
|
||||
|
||||
y_axis_values_to_plot = df_plot_data_plotting['per_query_visit_probability']
|
||||
y_axis_label = r"Per-Query Node Visit Probability in Bin" # Base label
|
||||
|
||||
apply_scaling_to_label_and_values = False # Initialize flag
|
||||
exponent_for_label_display = 0 # Initialize exponent
|
||||
|
||||
if pd.notna(max_probability) and max_probability > 0:
|
||||
potential_exponent = math.floor(math.log10(max_probability))
|
||||
|
||||
if potential_exponent <= -4 or potential_exponent >= 0:
|
||||
apply_scaling_to_label_and_values = True
|
||||
exponent_for_label_display = potential_exponent
|
||||
# No specific adjustment for potential_exponent >=0 here, it's handled by the general logic.
|
||||
|
||||
if apply_scaling_to_label_and_values:
|
||||
y_axis_label = rf"Visit Probability ($\times 10^{{{exponent_for_label_display}}}$)"
|
||||
y_axis_values_to_plot = df_plot_data_plotting['per_query_visit_probability'] / (10**exponent_for_label_display)
|
||||
print(f"Plotting with Max per-query probability: {max_probability:.2e}, Exponent for label: {exponent_for_label_display}. Y-axis values scaled for plot.")
|
||||
else:
|
||||
print(f"Plotting with Max per-query probability: {max_probability:.2e}. Plotting direct probabilities without label scaling (exponent {potential_exponent} is within no-scale range [-3, -1]).")
|
||||
|
||||
elif pd.notna(max_probability) and max_probability == 0:
|
||||
print("Max per-query probability is 0. Plotting direct probabilities (all zeros).")
|
||||
else:
|
||||
print(f"Max per-query probability is NaN or invalid ({max_probability}). Plotting direct probabilities without scaling if possible.")
|
||||
|
||||
ax.bar(
|
||||
df_plot_data_plotting['degree_bin_label'],
|
||||
y_axis_values_to_plot,
|
||||
color='white',
|
||||
edgecolor=edgecolors_ref[0],
|
||||
linewidth=1.5,
|
||||
width=0.8
|
||||
)
|
||||
|
||||
ax.set_xlabel('Node Degree', fontsize=10.5, labelpad=6)
|
||||
# MODIFIED LINE: Added labelpad to move the y-axis label to the left
|
||||
ax.set_ylabel(y_axis_label, fontsize=10.5, labelpad=10)
|
||||
|
||||
ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda x, pos: f"{x:.0f}%"))
|
||||
|
||||
num_bins = len(df_plot_data_plotting)
|
||||
if num_bins > 12:
|
||||
ax.set_xticks(ax.get_xticks())
|
||||
ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha="right", fontsize=9)
|
||||
elif num_bins > 8:
|
||||
ax.tick_params(axis='x', labelsize=9)
|
||||
else:
|
||||
ax.tick_params(axis='x', labelsize=10)
|
||||
|
||||
ax.tick_params(axis='y', labelsize=10)
|
||||
|
||||
padding_factor = 0.05
|
||||
current_max_y_on_axis = y_axis_values_to_plot.max()
|
||||
|
||||
upper_y_limit = 0.1 # Default small upper limit
|
||||
if pd.notna(current_max_y_on_axis):
|
||||
if current_max_y_on_axis > 0:
|
||||
# Adjust minimum visible range based on whether scaling was applied and the exponent
|
||||
min_meaningful_limit = 0.01
|
||||
if apply_scaling_to_label_and_values and exponent_for_label_display >= 0 : # Numbers on axis are smaller due to positive exponent scaling
|
||||
min_meaningful_limit = 0.1 # If original numbers were e.g. 2500 (2.5 x 10^3), scaled axis is 2.5, 0.1 is fine
|
||||
elif not apply_scaling_to_label_and_values and pd.notna(max_probability) and max_probability >=1: # Direct large probabilities
|
||||
min_meaningful_limit = 1 # If max prob is 2.5 (250%), axis value 2.5, needs larger base limit
|
||||
|
||||
upper_y_limit = max(min_meaningful_limit, current_max_y_on_axis * (1 + padding_factor))
|
||||
|
||||
else: # current_max_y_on_axis is 0
|
||||
upper_y_limit = 0.1
|
||||
ax.set_ylim(0, upper_y_limit)
|
||||
else:
|
||||
ax.set_ylim(0, 1.0) # Default for empty or NaN data
|
||||
|
||||
plt.tight_layout()
|
||||
plt.savefig(binned_output_image_file, bbox_inches="tight", dpi=300)
|
||||
print(f"Binned bar chart saved to {binned_output_image_file}")
|
||||
plt.show()
|
||||
plt.close(fig)
|
||||
else:
|
||||
if df_plot_data is None:
|
||||
print("Data for plotting (df_plot_data) is None. Skipping plot generation.")
|
||||
elif df_plot_data.empty:
|
||||
print("Data for plotting (df_plot_data) is empty. Skipping plot generation.")
|
||||
elif 'average_visit_count_per_node_in_bin' not in df_plot_data.columns:
|
||||
print("Essential column 'average_visit_count_per_node_in_bin' is missing in df_plot_data. Skipping plot generation.")
|
||||
|
||||
# %%
|
||||
print("Script finished.")
|
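Worked example (illustration only, with assumed numbers) of the per-query probability and axis-scaling logic used in the script above:
# A bin whose nodes were visited on average 37 times across
# NUMBER_OF_QUERIES = 1000 queries has per-query probability 37 / 1000 = 0.037.
import math

avg_visit_count_per_node_in_bin = 37.0  # assumed value for illustration
per_query_probability = avg_visit_count_per_node_in_bin / 1000.0  # 0.037
exponent = math.floor(math.log10(per_query_probability))          # -2
# -2 does not satisfy the scaling condition (exponent <= -4 or exponent >= 0),
# so the script plots 0.037 directly, without a "x 10^k" factor in the y label.
print(per_query_probability, exponent)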
||||
@@ -1,7 +0,0 @@
In this paper, we present LiteANN, a storage-efficient approximate nearest neighbor (ANN) search index optimized for resource-constrained personal devices. LiteANN combines a compact graph-based structure with an efficient on-the-fly recomputation strategy to enable fast and accurate retrieval with minimal storage overhead. Our evaluation shows that LiteANN reduces index size to under 5% of the original raw data – up to 50× smaller than standard indexes – while achieving 90% top-3 recall in under 2 seconds on real-world question-answering benchmarks.
@@ -1,81 +0,0 @@
|
||||
import numpy as np
|
||||
import os
|
||||
|
||||
# --- Configuration for Data Paths and Labels (Mirrors plotting script for consistency) ---
|
||||
BIG_GRAPH_PATHS = [
|
||||
"/opt/dlami/nvme/scaling_out/indices/rpj_wiki/facebook/contriever-msmarco/hnsw/",
|
||||
"/opt/dlami/nvme/scaling_out/embeddings/facebook/contriever-msmarco/rpj_wiki/1-shards/indices/99_4_degree_based_hnsw_IP_M32_efC256/",
|
||||
"/opt/dlami/nvme/scaling_out/embeddings/facebook/contriever-msmarco/rpj_wiki/1-shards/indices/d9_hnsw_IP_M8_efC128/",
|
||||
"/opt/dlami/nvme/scaling_out/embeddings/facebook/contriever-msmarco/rpj_wiki/1-shards/indices/half_edges_IP_M32_efC128/"
|
||||
]
|
||||
STATS_FILE_NAME = "degree_distribution.txt"
|
||||
BIG_GRAPH_LABELS = [ # These will be used as keys in the cached file
|
||||
"HNSW-Base",
|
||||
"DegreeGuide",
|
||||
"HNSW-D9",
|
||||
"RandCut",
|
||||
]
|
||||
# Average degrees are static and can be directly used in the plotting script or also cached.
|
||||
# For simplicity here, we'll focus on caching the dynamic degree arrays.
|
||||
# BIG_GRAPH_AVG_DEG = [18, 9, 9, 9]
|
||||
|
||||
# --- Cache File Configuration ---
|
||||
DATA_CACHE_DIR = "./paper_plot/data/"
|
||||
CACHE_FILE_NAME = "big_graph_degree_data.npz" # Using .npz for multiple arrays
|
||||
|
||||
def create_degree_data_cache():
|
||||
"""
|
||||
Reads degree distribution data from specified text files and saves it
|
||||
into a compressed NumPy (.npz) cache file.
|
||||
"""
|
||||
os.makedirs(DATA_CACHE_DIR, exist_ok=True)
|
||||
cache_file_path = os.path.join(DATA_CACHE_DIR, CACHE_FILE_NAME)
|
||||
|
||||
cached_data = {}
|
||||
print(f"Starting data caching process for {len(BIG_GRAPH_PATHS)} graph types...")
|
||||
|
||||
for i, base_path in enumerate(BIG_GRAPH_PATHS):
|
||||
method_label = BIG_GRAPH_LABELS[i]
|
||||
degree_file_path = os.path.join(base_path, STATS_FILE_NAME)
|
||||
|
||||
print(f"Processing: {method_label} from {degree_file_path}")
|
||||
|
||||
try:
|
||||
# Load degrees as integers
|
||||
degrees = np.loadtxt(degree_file_path, dtype=int)
|
||||
|
||||
if degrees.size == 0:
|
||||
print(f" [WARN] Degree file is empty: {degree_file_path}. Storing as empty array for {method_label}.")
|
||||
# Store an empty array or handle as needed. For npz, an empty array is fine.
|
||||
cached_data[method_label] = np.array([], dtype=int)
|
||||
else:
|
||||
# Store the loaded degrees array with the method label as the key
|
||||
cached_data[method_label] = degrees
|
||||
print(f" [INFO] Loaded {len(degrees)} degrees for {method_label}. Max degree: {np.max(degrees) if degrees.size > 0 else 'N/A'}")
|
||||
|
||||
except FileNotFoundError:
|
||||
print(f" [ERROR] Degree file not found: {degree_file_path}. Skipping {method_label}.")
|
||||
# Optionally store a placeholder or skip. For robustness, store None or an empty array.
|
||||
# Storing None might require special handling when loading. Empty array is safer for np.load.
|
||||
cached_data[method_label] = np.array([], dtype=int) # Store empty array if file not found
|
||||
except Exception as e:
|
||||
print(f" [ERROR] An error occurred loading {degree_file_path} for {method_label}: {e}")
|
||||
cached_data[method_label] = np.array([], dtype=int) # Store empty array on other errors
|
||||
|
||||
if not cached_data:
|
||||
print("[ERROR] No data was successfully processed or loaded. Cache file will not be created.")
|
||||
return
|
||||
|
||||
try:
|
||||
# Save all collected degree arrays into a single .npz file.
|
||||
# Using savez_compressed for potentially smaller file size.
|
||||
np.savez_compressed(cache_file_path, **cached_data)
|
||||
print(f"\n[SUCCESS] Degree distribution data successfully cached to: {os.path.abspath(cache_file_path)}")
|
||||
print("Cached arrays (keys):", list(cached_data.keys()))
|
||||
except Exception as e:
|
||||
print(f"\n[ERROR] Failed to save data to cache file {cache_file_path}: {e}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("--- Degree Distribution Data Caching Script ---")
|
||||
create_degree_data_cache()
|
||||
print("--- Caching script finished. ---")
|
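A minimal sketch (illustration only, not part of this diff) of reading back the .npz cache written by create_degree_data_cache() above; the keys are the BIG_GRAPH_LABELS strings used when saving:
# Illustrative loading sketch; keys match BIG_GRAPH_LABELS.
import numpy as np

with np.load("./paper_plot/data/big_graph_degree_data.npz") as cache:
    for label in ["HNSW-Base", "DegreeGuide", "HNSW-D9", "RandCut"]:
        degrees = cache[label]
        if degrees.size == 0:
            print(f"{label}: no degree data cached")
        else:
            print(f"{label}: {degrees.size} nodes, max degree {degrees.max()}, "
                  f"mean degree {degrees.mean():.2f}")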
||||
@@ -1,4 +0,0 @@
Model,NQ Exact Match,NQ F1,TriviaQA Exact Match,TriviaQA F1,GPQA Exact Match,GPQA F1,HotpotQA Exact Match,HotpotQA F1
BM25,0.192,0.277,0.406,0.474,0.020089,0.04524,0.162,0.239
PQ 5,0.2075,0.291,0.422,0.495,0.0201,0.0445,0.148,0.219
Ours,0.265,0.361,0.533,0.604,0.02008,0.0452,0.182,0.2729
@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1296720e79196bbdf38f051043c1b054667803726a24036c0b6a87cedb204ea5
size 227482438
@@ -1,21 +0,0 @@
2,1,512,1024,0.541,0.326,1.659509202
2,2,512,1024,0.979,0.621,1.576489533
2,4,512,1024,1.846,0.977,1.889457523
2,8,512,1024,3.575,1.943,1.83993824
2,16,512,1024,7.035,3.733,1.884543263
2,32,512,1024,15.655,8.517,1.838088529
2,64,512,1024,32.772,17.43,1.88020654
4,1,512,1024,2.675,1.38,1.938405797
4,2,512,1024,5.397,2.339,2.307396323
4,4,512,1024,10.672,4.944,2.158576052
4,8,512,1024,21.061,9.266,2.272933305
4,16,512,1024,46.332,18.334,2.527108105
4,32,512,1024,99.607,36.156,2.754923111
4,64,512,1024,186.348,72.356,2.575432583
8,1,512,1024,7.325,4.087,1.792268167
8,2,512,1024,14.109,7.491,1.883460152
8,4,512,1024,28.499,14.013,2.033754371
8,8,512,1024,65.222,27.453,2.375769497
8,16,512,1024,146.294,52.55,2.783901047
8,32,512,1024,277.099,103.61,2.674442621
8,64,512,1024,512.979,208.36,2.461984066
@@ -1,9 +0,0 @@
Dataset,Metric,Original,original + batch,original + two_level,original + two_level + batch
NQ,Latency,6.9,5.8,4.2,3.7
NQ,SpeedUp,1,1.18965517,1.64285714,1.86486486
TriviaQA,Latency,17.054,14.542,12.046,10.83
TriviaQA,SpeedUp,1,1.17274103,1.41573967,1.57469990
GPQA,Latency,9.164,7.639,6.798,5.77
GPQA,SpeedUp,1,1.19963346,1.34804354,1.58821490
HotpotQA,Latency,60.279,39.827,50.664,29.868
HotpotQA,SpeedUp,1,1.51352098,1.18977972,2.01817999
@@ -1,25 +0,0 @@
Dataset,Hardware,Recall_target,HNSW,IVF,DiskANN,IVF-Disk,IVF-Recompute,Our,BM25,LLM_Gen_Time_1B,LLM_Gen_Time_3B,LLM_Gen_Time_7B
NQ,A10,85%,0.046,1.656,0.017,2.996,482.53,3.323,0.021,0.085,0.217,0.472
NQ,A10,90%,0.051,2.552,0.028,3.437,769.04,4.616,0,0.085,0.217,0.472
NQ,A10,95%,0.055,5.163,0.070,5.602,1436.26,19.494,0,0.085,0.217,0.472
NQ,MAC,85%,0,0,0.152,2.199,1535.10,7.971,0.033,0.316,0.717,1.468
NQ,MAC,90%,0,0,0.37,2.936,2446.60,13.843,0,0.316,0.717,1.468
NQ,MAC,95%,0,0,1.207,4.191,4569.29,44.363,0,0.316,0.717,1.468
TriviaQA,A10,85%,0.042,1.772,0.032,2.464,560.5,3.752,0.033,0.139,0.156,0.315
TriviaQA,A10,90%,0.043,3.541,0.057,3.651,997.81,5.777,0,0.139,0.156,0.315
TriviaQA,A10,95%,0.053,7.168,0.090,5.458,2005.33,20.944,0,0.139,0.156,0.315
TriviaQA,MAC,85%,0,0,0.481,1.875,1783.14787,8.889,0.036,0.325,0.692,1.415
TriviaQA,MAC,90%,0,0,0.984,2.639,3174.410301,17.145,0,0.325,0.692,1.415
TriviaQA,MAC,95%,0,0,1.578,3.884,6379.712245,47.909,0,0.325,0.692,1.415
GPQA,A10,85%,0.041,0.134,0.024,0.048,40.16,1.897,0.137,0.443,0.396,0.651
GPQA,A10,90%,0.042,0.174,0.034,0.06,54.71,1.733,0,0.443,0.396,0.651
GPQA,A10,95%,0.045,0.292,0.051,0.11,97.67,4.033,0,0.443,0.396,0.651
GPQA,MAC,85%,0,0,0.144,0.087,127.7707505,4.762,0.100,0.37,0.813,1.676
GPQA,MAC,90%,0,0,0.288,0.108,174.0647409,5.223,0,0.37,0.813,1.676
GPQA,MAC,95%,0,0,0.497,0.132,310.7380142,9.715,0,0.37,0.813,1.676
HotpotQA,A10,85%,0.044,2.519,0.054,4.048,724.26,10.358,0.70,0.144,0.196,0.420
HotpotQA,A10,90%,0.049,3.867,0.109,5.045,1173.67,15.515,0,0.144,0.196,0.420
HotpotQA,A10,95%,0.07,10.928,0.412,8.659,3079.57,61.757,0,0.144,0.196,0.420
HotpotQA,MAC,85%,0,0,0.974,2.844,2304.125187,23.636,0.052,0.144,0.196,0.420
HotpotQA,MAC,90%,0,0,1.913,3.542,3415.736201,44.803,0,0.144,0.196,0.420
HotpotQA,MAC,95%,0,0,5.783,6.764,9797.244043,140.62,0,0.144,0.196,0.420
@@ -1,25 +0,0 @@
Dataset,Hardware,Recall_target,HNSW,IVF,DiskANN,IVF-Disk,IVF-Recompute,Our,
NQ,A10,85%,0.046,1.656,0.017,2.996,482.53,4.243,
NQ,A10,90%,0.051,2.552,0.028,3.437,769.04,8.136,
NQ,A10,95%,0.055,5.163,0.070,5.602,1436.26,27.275,
NQ,MAC,85%,0,0,0.152,2.199,1535.10,10.672,
NQ,MAC,90%,0,0,0.37,2.936,2446.60,19.941,
NQ,MAC,95%,0,0,1.207,4.191,4569.29,61.383,
TriviaQA,A10,85%,0.042,1.772,0.032,2.464,560.5,5.612,
TriviaQA,A10,90%,0.043,3.541,0.057,3.651,997.81,10.737,
TriviaQA,A10,95%,0.053,7.168,0.090,5.458,2005.33,36.387,
TriviaQA,MAC,85%,0,0,0.481,1.875,1783.14787,12.825,
TriviaQA,MAC,90%,0,0,0.984,2.639,3174.410301,24.977,
TriviaQA,MAC,95%,0,0,1.578,3.884,6379.712245,85.734,
GPQA,A10,85%,0.041,0.134,0.024,0.048,40.16,2.269,
GPQA,A10,90%,0.042,0.174,0.034,0.06,54.71,3.200,
GPQA,A10,95%,0.045,0.292,0.051,0.11,97.67,7.445,
GPQA,MAC,85%,0,0,0.144,0.087,127.7707505,6.123,
GPQA,MAC,90%,0,0,0.288,0.108,174.0647409,8.507,
GPQA,MAC,95%,0,0,0.497,0.132,310.7380142,19.577,
HotpotQA,A10,85%,0.044,2.519,0.054,4.048,724.26,14.713,
HotpotQA,A10,90%,0.049,3.867,0.109,5.045,1173.67,33.561,
HotpotQA,A10,95%,0.07,10.928,0.412,8.659,3079.57,68.626,
HotpotQA,MAC,85%,0,0,0.974,2.844,2304.125187,34.783,
HotpotQA,MAC,90%,0,0,1.913,3.542,3415.736201,53.004,
HotpotQA,MAC,95%,0,0,5.783,6.764,9797.244043,95.413,
@@ -1,3 +0,0 @@
Hardware,HNSW,IVF,DiskANN,IVF-Disk,IVF-Recompute,Our,BM25
RAM,190,171,10,0,0,0,0
Storage,185.4,171,240,171,0.5,5,59
@@ -1,12 +0,0 @@
Torch,8,55.592
Torch,16,75.439
Torch,32,110.025
Torch,64,186.496
Tutel,8,56.718
Tutel,16,82.121
Tutel,32,125.070
Tutel,64,216.191
BRT,8,56.725
BRT,16,79.291
BRT,32,93.180
BRT,64,118.923
@@ -1,6 +0,0 @@
Disk cache size,0,2.5%(180G*2.5%),5%,8%,10%
Latency,,,,,
NQ,4.616,4.133,3.826,3.511,3.323
TriviaQA,5.777,4.979,4.553,4.141,3.916
GPQA,1.733,1.593,1.468,1.336,1.259
Hotpot,15.515,13.479,12.383,11.216,10.606
@@ -1,151 +0,0 @@
|
||||
import matplotlib
|
||||
from matplotlib.axes import Axes
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib.patches as mpatches
|
||||
from matplotlib.lines import Line2D
|
||||
|
||||
# plt.rcParams["font.family"] = "Helvetica"
|
||||
plt.rcParams["ytick.direction"] = "in"
|
||||
plt.rcParams["hatch.linewidth"] = 1
|
||||
plt.rcParams["font.weight"] = "bold"
|
||||
plt.rcParams["axes.labelweight"] = "bold"
|
||||
plt.rcParams["text.usetex"] = True
|
||||
plt.rcParams["font.family"] = "sans-serif" # Use generic sans-serif family
|
||||
plt.rcParams['text.latex.preamble'] = r"""
|
||||
\usepackage{helvet} % Use Helvetica font for text
|
||||
\usepackage{sfmath} % Use sans-serif font for math
|
||||
\renewcommand{\familydefault}{\sfdefault} % Set sans-serif as default text font
|
||||
\usepackage[T1]{fontenc} % Recommended for font encoding
|
||||
"""
|
||||
# plt.rcParams['mathtext.fontset'] = 'dejavusans'
|
||||
SAVE_PTH = "./paper_plot/figures"
|
||||
font_size = 16
|
||||
|
||||
# New data in dictionary format
|
||||
datasets = ["NQ", "TriviaQA", "GPQA", "Hotpot"]
|
||||
|
||||
cache_ratios = ["4.2G\n (0\%)", "8.7G\n (2.5\%)", "13.2G\n (5\%)", "18.6G\n (8\%)", "22.2G\n (10\%)"]
|
||||
latency_data = {
|
||||
"NQ": [4.616, 4.133, 3.826, 3.511, 3.323],
|
||||
"TriviaQA": [5.777, 4.979, 4.553, 4.141, 3.916],
|
||||
"GPQA": [1.733, 1.593, 1.468, 1.336, 1.259],
|
||||
"Hotpot": [15.515, 13.479, 12.383, 11.216, 10.606],
|
||||
}
|
||||
cache_hit_counts = {
|
||||
"NQ": [0, 14.81, 23.36, 31.99, 36.73],
|
||||
"TriviaQA": [0, 18.55, 27.99, 37.06, 41.86],
|
||||
"GPQA": [0, 10.99, 20.31, 29.71, 35.01],
|
||||
"Hotpot": [0, 17.47, 26.91, 36.2, 41.06]
|
||||
}
|
||||
|
||||
# Create the figure with 4 subplots in a 2x2 grid
|
||||
fig, axes_grid = plt.subplots(2, 2, figsize=(7,6))
|
||||
axes = axes_grid.flatten() # Flatten the 2x2 grid to a 1D array
|
||||
|
||||
# Bar style settings
|
||||
width = 0.7
|
||||
x = np.arange(len(cache_ratios))
|
||||
|
||||
# Define hatch patterns for different cache ratios
|
||||
hatch_patterns = ['//', '//', '//', '//', '//']
|
||||
|
||||
# Find max cache hit value across all datasets for unified y-axis
|
||||
all_hit_counts = []
|
||||
for dataset in datasets:
|
||||
all_hit_counts.extend(cache_hit_counts[dataset])
|
||||
max_unified_hit = max(all_hit_counts) * 1.13
|
||||
|
||||
for i, dataset in enumerate(datasets):
|
||||
latencies = latency_data[dataset]
|
||||
hit_counts = cache_hit_counts[dataset]
|
||||
|
||||
for j, val in enumerate(latencies):
|
||||
container = axes[i].bar(
|
||||
x[j],
|
||||
val,
|
||||
width=width,
|
||||
color="white",
|
||||
edgecolor="black",
|
||||
linewidth=1.0,
|
||||
zorder=10,
|
||||
)
|
||||
axes[i].bar_label(
|
||||
container,
|
||||
[f"{val:.2f}"],
|
||||
fontsize=10,
|
||||
zorder=200,
|
||||
fontweight="bold",
|
||||
)
|
||||
|
||||
axes[i].set_title(dataset, fontsize=font_size)
|
||||
axes[i].set_xticks(x)
|
||||
axes[i].set_xticklabels(cache_ratios, fontsize=12, rotation=0, ha='center', fontweight="bold")
|
||||
|
||||
max_val_ratios = [1.35, 1.65, 1.45, 1.75]
|
||||
max_val = max(latencies) * max_val_ratios[i]
|
||||
axes[i].set_ylim(0, max_val)
|
||||
axes[i].tick_params(axis='y', labelsize=12)
|
||||
|
||||
if i % 2 == 0:
|
||||
axes[i].set_ylabel("Latency (s)", fontsize=font_size)
|
||||
axes[i].yaxis.set_major_formatter(matplotlib.ticker.FormatStrFormatter('%.1f'))
|
||||
|
||||
ax2: Axes = axes[i].twinx()
|
||||
ax2.plot(x, hit_counts,
|
||||
linestyle='--',
|
||||
marker='o',
|
||||
markersize=6,
|
||||
linewidth=1.5,
|
||||
color='k',
|
||||
markerfacecolor='none',
|
||||
zorder=20)
|
||||
|
||||
ax2.set_ylim(0, max_unified_hit)
|
||||
ax2.tick_params(axis='y', labelsize=12)
|
||||
if i % 2 == 1:
|
||||
ax2.set_ylabel(r"Cache Hit (\%)", fontsize=font_size)
|
||||
|
||||
for j, val in enumerate(hit_counts):
|
||||
if val > 0:
|
||||
ax2.annotate(f"{val:.1f}%",
|
||||
(x[j], val),
|
||||
textcoords="offset points",
|
||||
xytext=(0, 5),
|
||||
ha='center',
|
||||
va='bottom',
|
||||
fontsize=10,
|
||||
fontweight='bold')
|
||||
|
||||
# Create legend for both plots
|
||||
bar_patch = mpatches.Patch(facecolor='white', edgecolor='black', label='Latency')
|
||||
line_patch = Line2D([0], [0], color='black', linestyle='--', label='Cache Hit Rate')
|
||||
|
||||
# --- MODIFICATION FOR LEGEND AT THE TOP ---
|
||||
fig.legend(handles=[bar_patch, line_patch],
|
||||
loc='upper center', # Position the legend at the upper center
|
||||
bbox_to_anchor=(0.5, 0.995), # Anchor point (0.5 means horizontal center of figure,
|
||||
# 0.97 means 97% from the bottom, so near the top)
|
||||
ncol=3,
|
||||
fontsize=font_size-2)
|
||||
# --- END OF MODIFICATION ---
|
||||
|
||||
# Set common x-axis label - you might want to add this back if needed
|
||||
# fig.text(0.5, 0.02, "Disk Cache Size", ha='center', fontsize=font_size, fontweight='bold') # Adjusted y for potential bottom label
|
||||
|
||||
# --- MODIFICATION FOR TIGHT LAYOUT ---
|
||||
# Adjust rect to make space for the legend at the top.
|
||||
# (left, bottom, right, top_for_subplots)
|
||||
# We want subplots to occupy space from y=0 up to y=0.93 (or similar)
|
||||
# leaving the top portion (0.93 to 1.0) for the legend.
|
||||
plt.tight_layout(rect=(0, 0, 1, 0.93)) # Ensure subplots are below the legend
|
||||
# --- END OF MODIFICATION ---
|
||||
|
||||
# Create directory if it doesn't exist (optional, good practice)
|
||||
import os
|
||||
if not os.path.exists(SAVE_PTH):
|
||||
os.makedirs(SAVE_PTH)
|
||||
|
||||
plt.savefig(f"{SAVE_PTH}/disk_cache_latency.pdf", dpi=300) # Changed filename slightly for testing
|
||||
print(f"Save to {SAVE_PTH}/disk_cache_latency.pdf")
|
||||
# plt.show() # Optional: to display the plot
|
||||