Compare commits: v0.1.10...readme-pol

14 Commits

| SHA1 |
|---|
| f83c97e6d1 |
| 6e755f0402 |
| cc6b904c44 |
| bda028cc1b |
| bed814e7e6 |
| 96f74973b1 |
| 1f90cdfafb |
| 8f4f66d871 |
| 43b52a8c0a |
| 1a3180bc0f |
| fe4a748a69 |
| d296f372e0 |
| 909835dd2d |
| 1eea69e8d7 |

11  .github/workflows/build-and-publish.yml (vendored, file deleted)

@@ -1,11 +0,0 @@
name: CI

on:
  push:
    branches: [ main ]
  pull_request:
    branches: [ main ]

jobs:
  build:
    uses: ./.github/workflows/build-reusable.yml

167  .github/workflows/build-reusable.yml (vendored, file deleted)

@@ -1,167 +0,0 @@
name: Reusable Build

on:
  workflow_call:
    inputs:
      ref:
        description: 'Git ref to build'
        required: false
        type: string
        default: ''

jobs:
  build:
    name: Build ${{ matrix.os }} Python ${{ matrix.python }}
    strategy:
      matrix:
        include:
          - os: ubuntu-22.04
            python: '3.9'
          - os: ubuntu-22.04
            python: '3.10'
          - os: ubuntu-22.04
            python: '3.11'
          - os: ubuntu-22.04
            python: '3.12'
          - os: ubuntu-22.04
            python: '3.13'
          - os: macos-latest
            python: '3.9'
          - os: macos-latest
            python: '3.10'
          - os: macos-latest
            python: '3.11'
          - os: macos-latest
            python: '3.12'
          - os: macos-latest
            python: '3.13'
    runs-on: ${{ matrix.os }}

    steps:
      - uses: actions/checkout@v4
        with:
          ref: ${{ inputs.ref }}
          submodules: recursive

      - name: Setup Python
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python }}

      - name: Install uv
        uses: astral-sh/setup-uv@v4

      - name: Install system dependencies (Ubuntu)
        if: runner.os == 'Linux'
        run: |
          sudo apt-get update
          sudo apt-get install -y libomp-dev libboost-all-dev protobuf-compiler libzmq3-dev \
            pkg-config libopenblas-dev patchelf libabsl-dev libaio-dev libprotobuf-dev

          # Install Intel MKL for DiskANN
          wget -q https://registrationcenter-download.intel.com/akdlm/IRC_NAS/79153e0f-74d7-45af-b8c2-258941adf58a/intel-onemkl-2025.0.0.940.sh
          sudo sh intel-onemkl-2025.0.0.940.sh -a --components intel.oneapi.lin.mkl.devel --action install --eula accept -s
          source /opt/intel/oneapi/setvars.sh
          echo "MKLROOT=/opt/intel/oneapi/mkl/latest" >> $GITHUB_ENV
          echo "LD_LIBRARY_PATH=/opt/intel/oneapi/mkl/latest/lib/intel64:$LD_LIBRARY_PATH" >> $GITHUB_ENV

      - name: Install system dependencies (macOS)
        if: runner.os == 'macOS'
        run: |
          brew install llvm libomp boost protobuf zeromq

      - name: Install build dependencies
        run: |
          uv pip install --system scikit-build-core numpy swig Cython pybind11
          if [[ "$RUNNER_OS" == "Linux" ]]; then
            uv pip install --system auditwheel
          else
            uv pip install --system delocate
          fi

      - name: Build packages
        run: |
          # Build core (platform independent)
          if [[ "${{ matrix.os }}" == ubuntu-* ]]; then
            cd packages/leann-core
            uv build
            cd ../..
          fi

          # Build HNSW backend
          cd packages/leann-backend-hnsw
          if [ "${{ matrix.os }}" == "macos-latest" ]; then
            CC=$(brew --prefix llvm)/bin/clang CXX=$(brew --prefix llvm)/bin/clang++ uv build --wheel --python python
          else
            uv build --wheel --python python
          fi
          cd ../..

          # Build DiskANN backend
          cd packages/leann-backend-diskann
          if [ "${{ matrix.os }}" == "macos-latest" ]; then
            CC=$(brew --prefix llvm)/bin/clang CXX=$(brew --prefix llvm)/bin/clang++ uv build --wheel --python python
          else
            uv build --wheel --python python
          fi
          cd ../..

          # Build meta package (platform independent)
          if [[ "${{ matrix.os }}" == ubuntu-* ]]; then
            cd packages/leann
            uv build
            cd ../..
          fi

      - name: Repair wheels (Linux)
        if: runner.os == 'Linux'
        run: |
          # Repair HNSW wheel
          cd packages/leann-backend-hnsw
          if [ -d dist ]; then
            auditwheel repair dist/*.whl -w dist_repaired
            rm -rf dist
            mv dist_repaired dist
          fi
          cd ../..

          # Repair DiskANN wheel
          cd packages/leann-backend-diskann
          if [ -d dist ]; then
            auditwheel repair dist/*.whl -w dist_repaired
            rm -rf dist
            mv dist_repaired dist
          fi
          cd ../..

      - name: Repair wheels (macOS)
        if: runner.os == 'macOS'
        run: |
          # Repair HNSW wheel
          cd packages/leann-backend-hnsw
          if [ -d dist ]; then
            delocate-wheel -w dist_repaired -v dist/*.whl
            rm -rf dist
            mv dist_repaired dist
          fi
          cd ../..

          # Repair DiskANN wheel
          cd packages/leann-backend-diskann
          if [ -d dist ]; then
            delocate-wheel -w dist_repaired -v dist/*.whl
            rm -rf dist
            mv dist_repaired dist
          fi
          cd ../..

      - name: List built packages
        run: |
          echo "📦 Built packages:"
          find packages/*/dist -name "*.whl" -o -name "*.tar.gz" | sort

      - name: Upload artifacts
        uses: actions/upload-artifact@v4
        with:
          name: packages-${{ matrix.os }}-py${{ matrix.python }}
          path: packages/*/dist/

126  .github/workflows/release-manual.yml (vendored, file deleted)

@@ -1,126 +0,0 @@
name: Release

on:
  workflow_dispatch:
    inputs:
      version:
        description: 'Version to release (e.g., 0.1.2)'
        required: true
        type: string

jobs:
  update-version:
    name: Update Version
    runs-on: ubuntu-latest
    permissions:
      contents: write
    outputs:
      commit-sha: ${{ steps.push.outputs.commit-sha }}

    steps:
      - uses: actions/checkout@v4

      - name: Validate version
        run: |
          if ! [[ "${{ inputs.version }}" =~ ^[0-9]+\.[0-9]+\.[0-9]+$ ]]; then
            echo "❌ Invalid version format"
            exit 1
          fi
          echo "✅ Version format valid"

      - name: Update versions and push
        id: push
        run: |
          # Check current version
          CURRENT_VERSION=$(grep "^version" packages/leann-core/pyproject.toml | cut -d'"' -f2)
          echo "Current version: $CURRENT_VERSION"
          echo "Target version: ${{ inputs.version }}"

          if [ "$CURRENT_VERSION" = "${{ inputs.version }}" ]; then
            echo "⚠️ Version is already ${{ inputs.version }}, skipping update"
            COMMIT_SHA=$(git rev-parse HEAD)
          else
            ./scripts/bump_version.sh ${{ inputs.version }}
            git config user.name "GitHub Actions"
            git config user.email "actions@github.com"
            git add packages/*/pyproject.toml
            git commit -m "chore: release v${{ inputs.version }}"
            git push origin main
            COMMIT_SHA=$(git rev-parse HEAD)
            echo "✅ Pushed version update: $COMMIT_SHA"
          fi

          echo "commit-sha=$COMMIT_SHA" >> $GITHUB_OUTPUT

  build-packages:
    name: Build packages
    needs: update-version
    uses: ./.github/workflows/build-reusable.yml
    with:
      ref: ${{ needs.update-version.outputs.commit-sha }}

  publish:
    name: Publish and Release
    needs: [update-version, build-packages]
    if: always() && needs.update-version.result == 'success' && needs.build-packages.result == 'success'
    runs-on: ubuntu-latest
    permissions:
      contents: write

    steps:
      - uses: actions/checkout@v4
        with:
          ref: ${{ needs.update-version.outputs.commit-sha }}

      - name: Download all artifacts
        uses: actions/download-artifact@v4
        with:
          path: dist-artifacts

      - name: Collect packages
        run: |
          mkdir -p dist
          find dist-artifacts -name "*.whl" -exec cp {} dist/ \;
          find dist-artifacts -name "*.tar.gz" -exec cp {} dist/ \;

          echo "📦 Packages to publish:"
          ls -la dist/

      - name: Publish to PyPI
        env:
          TWINE_USERNAME: __token__
          TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}
        run: |
          if [ -z "$TWINE_PASSWORD" ]; then
            echo "❌ PYPI_API_TOKEN not configured!"
            exit 1
          fi

          pip install twine
          twine upload dist/* --skip-existing --verbose

          echo "✅ Published to PyPI!"

      - name: Create release
        run: |
          # Check if tag already exists
          if git rev-parse "v${{ inputs.version }}" >/dev/null 2>&1; then
            echo "⚠️ Tag v${{ inputs.version }} already exists, skipping tag creation"
          else
            git tag "v${{ inputs.version }}"
            git push origin "v${{ inputs.version }}"
            echo "✅ Created and pushed tag v${{ inputs.version }}"
          fi

          # Check if release already exists
          if gh release view "v${{ inputs.version }}" >/dev/null 2>&1; then
            echo "⚠️ Release v${{ inputs.version }} already exists, skipping release creation"
          else
            gh release create "v${{ inputs.version }}" \
              --title "Release v${{ inputs.version }}" \
              --notes "🚀 Released to PyPI: https://pypi.org/project/leann/${{ inputs.version }}/" \
              --latest
            echo "✅ Created GitHub release v${{ inputs.version }}"
          fi
        env:
          GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}

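The "Validate version" step above gates the release on a strict MAJOR.MINOR.PATCH pattern. For reference, an equivalent check in Python (a standalone sketch, not part of the repository; the regex is copied from the workflow) would look like this:

```python
import re

# Mirrors the workflow's version check: ^[0-9]+\.[0-9]+\.[0-9]+$
def is_valid_version(version: str) -> bool:
    return re.fullmatch(r"[0-9]+\.[0-9]+\.[0-9]+", version) is not None

assert is_valid_version("0.1.2")       # accepted, as in the workflow's happy path
assert not is_valid_version("v0.1.2")  # rejected: the leading "v" is added only when tagging
```
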
5  .gitignore (vendored)

@@ -12,6 +12,7 @@ outputs/
*.idx
*.map
.history/
scripts/
lm_eval.egg-info/
demo/experiment_results/**/*.json
*.jsonl

@@ -83,6 +84,4 @@ test_*.py
packages/leann-backend-diskann/third_party/DiskANN/_deps/

*.meta.json
*.passages.json

batchtest.py
*.passages.json

4  .gitmodules (vendored)

@@ -1,9 +1,9 @@
[submodule "packages/leann-backend-diskann/third_party/DiskANN"]
	path = packages/leann-backend-diskann/third_party/DiskANN
	url = https://github.com/yichuan-w/DiskANN.git
	url = https://github.com/yichuan520030910320/DiskANN.git
[submodule "packages/leann-backend-hnsw/third_party/faiss"]
	path = packages/leann-backend-hnsw/third_party/faiss
	url = https://github.com/yichuan-w/faiss.git
	url = https://github.com/yichuan520030910320/faiss.git
[submodule "packages/leann-backend-hnsw/third_party/msgpack-c"]
	path = packages/leann-backend-hnsw/third_party/msgpack-c
	url = https://github.com/msgpack/msgpack-c.git

9  .vscode/extensions.json (vendored, new file)

@@ -0,0 +1,9 @@
{
  "recommendations": [
    "llvm-vs-code-extensions.vscode-clangd",
    "ms-python.python",
    "ms-vscode.cmake-tools",
    "vadimcn.vscode-lldb",
    "eamodio.gitlens",
  ]
}

283  .vscode/launch.json (vendored, new executable file)

@@ -0,0 +1,283 @@
|
||||
{
|
||||
// Use IntelliSense to learn about possible attributes.
|
||||
// Hover to view descriptions of existing attributes.
|
||||
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
|
||||
"version": "0.2.0",
|
||||
"configurations": [
|
||||
// new embedder
|
||||
{
|
||||
"name": "New Embedder",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"program": "demo/main.py",
|
||||
"console": "integratedTerminal",
|
||||
"args": [
|
||||
"--search",
|
||||
"--use-original",
|
||||
"--domain",
|
||||
"dpr",
|
||||
"--nprobe",
|
||||
"5000",
|
||||
"--load",
|
||||
"flat",
|
||||
"--embedder",
|
||||
"intfloat/multilingual-e5-small"
|
||||
]
|
||||
}
|
||||
//python /home/ubuntu/Power-RAG/faiss/demo/simple_build.py
|
||||
{
|
||||
"name": "main.py",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"program": "demo/main.py",
|
||||
"console": "integratedTerminal",
|
||||
"cwd": "${workspaceFolder}",
|
||||
"args": [
|
||||
"--query",
|
||||
"1000",
|
||||
"--load",
|
||||
"bm25"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "Simple Build",
|
||||
"type": "lldb",
|
||||
"request": "launch",
|
||||
"program": "${workspaceFolder}/.venv/bin/python",
|
||||
"console": "integratedTerminal",
|
||||
"cwd": "${workspaceFolder}",
|
||||
"args": [
|
||||
"faiss/demo/simple_build.py"
|
||||
],
|
||||
"env": {
|
||||
"LD_PRELOAD": "/lib/x86_64-linux-gnu/libmkl_core.so:/lib/x86_64-linux-gnu/libmkl_intel_thread.so:/lib/x86_64-linux-gnu/libmkl_intel_lp64.so:/lib/x86_64-linux-gnu/libiomp5.so"
|
||||
}
|
||||
},
|
||||
//# Fix for Intel MKL error
|
||||
//export LD_PRELOAD=/lib/x86_64-linux-gnu/libmkl_core.so:/lib/x86_64-linux-gnu/libmkl_intel_thread.so:/lib/x86_64-linux-gnu/libmkl_intel_lp64.so:/lib/x86_64-linux-gnu/libiomp5.so
|
||||
//python faiss/demo/build_demo.py
|
||||
{
|
||||
"name": "Build Demo",
|
||||
"type": "lldb",
|
||||
"request": "launch",
|
||||
"program": "${workspaceFolder}/.venv/bin/python",
|
||||
"console": "integratedTerminal",
|
||||
"cwd": "${workspaceFolder}",
|
||||
"args": [
|
||||
"faiss/demo/build_demo.py"
|
||||
],
|
||||
"env": {
|
||||
"LD_PRELOAD": "/lib/x86_64-linux-gnu/libmkl_core.so:/lib/x86_64-linux-gnu/libmkl_intel_thread.so:/lib/x86_64-linux-gnu/libmkl_intel_lp64.so:/lib/x86_64-linux-gnu/libiomp5.so"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "DiskANN Serve",
|
||||
"type": "lldb",
|
||||
"request": "launch",
|
||||
"program": "${workspaceFolder}/.venv/bin/python",
|
||||
"console": "integratedTerminal",
|
||||
"cwd": "${workspaceFolder}",
|
||||
"args": [
|
||||
"demo/main.py",
|
||||
"--mode",
|
||||
"serve",
|
||||
"--engine",
|
||||
"sglang",
|
||||
"--load-indices",
|
||||
"diskann",
|
||||
"--domain",
|
||||
"rpj_wiki",
|
||||
"--lazy-load",
|
||||
"--recompute-beighbor-embeddings",
|
||||
"--port",
|
||||
"8082",
|
||||
"--diskann-search-memory-maximum",
|
||||
"2",
|
||||
"--diskann-graph",
|
||||
"240",
|
||||
"--search-only"
|
||||
],
|
||||
"env": {
|
||||
"PYTHONPATH": "${workspaceFolder}/faiss_repo/build/faiss/python:$PYTHONPATH"
|
||||
},
|
||||
"preLaunchTask": "CMake: build",
|
||||
},
|
||||
{
|
||||
"name": "DiskANN Serve MAC",
|
||||
"type": "lldb",
|
||||
"request": "launch",
|
||||
"program": "${workspaceFolder}/.venv/bin/python",
|
||||
"console": "integratedTerminal",
|
||||
"cwd": "${workspaceFolder}",
|
||||
"args": [
|
||||
"demo/main.py",
|
||||
"--mode",
|
||||
"serve",
|
||||
"--engine",
|
||||
"ollama",
|
||||
"--load-indices",
|
||||
"diskann",
|
||||
"--domain",
|
||||
"rpj_wiki",
|
||||
"--lazy-load",
|
||||
"--recompute-beighbor-embeddings"
|
||||
],
|
||||
"preLaunchTask": "CMake: build",
|
||||
"env": {
|
||||
"KMP_DUPLICATE_LIB_OK": "TRUE",
|
||||
"OMP_NUM_THREADS": "1",
|
||||
"MKL_NUM_THREADS": "1",
|
||||
"DYLD_INSERT_LIBRARIES": "/Users/ec2-user/Power-RAG/.venv/lib/python3.10/site-packages/torch/lib/libomp.dylib",
|
||||
"KMP_BLOCKTIME": "0"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Python Debugger: Current File with Arguments",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"program": "ric/main_ric.py",
|
||||
"console": "integratedTerminal",
|
||||
"cwd": "${workspaceFolder}",
|
||||
"args": [
|
||||
"--config-name",
|
||||
"${input:configSelection}"
|
||||
],
|
||||
"justMyCode": false
|
||||
},
|
||||
//python ./demo/validate_equivalence.py sglang
|
||||
{
|
||||
"name": "Validate Equivalence",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"program": "demo/validate_equivalence.py",
|
||||
"console": "integratedTerminal",
|
||||
"args": [
|
||||
"sglang"
|
||||
],
|
||||
},
|
||||
//python demo/retrieval_demo.py --engine sglang --skip-embeddings --domain dpr --load-indices flat ivf_flat
|
||||
{
|
||||
"name": "Retrieval Demo",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"program": "demo/retrieval_demo.py",
|
||||
"console": "integratedTerminal",
|
||||
"args": [
|
||||
"--engine",
|
||||
"vllm",
|
||||
"--skip-embeddings",
|
||||
"--domain",
|
||||
"dpr",
|
||||
"--load-indices",
|
||||
// "flat",
|
||||
"ivf_flat"
|
||||
],
|
||||
},
|
||||
//python demo/retrieval_demo.py --engine sglang --skip-embeddings --domain dpr --load-indices diskann --hnsw-M 64 --hnsw-efConstruction 150 --hnsw-efSearch 128 --hnsw-sq-bits 8
|
||||
{
|
||||
"name": "Retrieval Demo DiskANN",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"program": "demo/retrieval_demo.py",
|
||||
"console": "integratedTerminal",
|
||||
"args": [
|
||||
"--engine",
|
||||
"sglang",
|
||||
"--skip-embeddings",
|
||||
"--domain",
|
||||
"dpr",
|
||||
"--load-indices",
|
||||
"diskann",
|
||||
"--hnsw-M",
|
||||
"64",
|
||||
"--hnsw-efConstruction",
|
||||
"150",
|
||||
"--hnsw-efSearch",
|
||||
"128",
|
||||
"--hnsw-sq-bits",
|
||||
"8"
|
||||
],
|
||||
},
|
||||
{
|
||||
"name": "Find Probe",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"program": "find_probe.py",
|
||||
"console": "integratedTerminal",
|
||||
"cwd": "${workspaceFolder}",
|
||||
},
|
||||
{
|
||||
"name": "Python: Attach",
|
||||
"type": "debugpy",
|
||||
"request": "attach",
|
||||
"processId": "${command:pickProcess}",
|
||||
"justMyCode": true
|
||||
},
|
||||
{
|
||||
"name": "Edge RAG",
|
||||
"type": "lldb",
|
||||
"request": "launch",
|
||||
"program": "${workspaceFolder}/.venv/bin/python",
|
||||
"console": "integratedTerminal",
|
||||
"cwd": "${workspaceFolder}",
|
||||
"args": [
|
||||
"edgerag_demo.py"
|
||||
],
|
||||
"env": {
|
||||
"LD_PRELOAD": "/lib/x86_64-linux-gnu/libiomp5.so /lib/x86_64-linux-gnu/libmkl_core.so /lib/x86_64-linux-gnu/libmkl_intel_lp64.so /lib/x86_64-linux-gnu/libmkl_intel_thread.so",
|
||||
"MKL_NUM_THREADS": "1",
|
||||
"OMP_NUM_THREADS": "1",
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Launch Embedding Server",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"program": "demo/embedding_server.py",
|
||||
"console": "integratedTerminal",
|
||||
"cwd": "${workspaceFolder}",
|
||||
"args": [
|
||||
"--domain",
|
||||
"rpj_wiki",
|
||||
"--zmq-port",
|
||||
"5556",
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "HNSW Serve",
|
||||
"type": "lldb",
|
||||
"request": "launch",
|
||||
"program": "${workspaceFolder}/.venv/bin/python",
|
||||
"console": "integratedTerminal",
|
||||
"cwd": "${workspaceFolder}",
|
||||
"args": [
|
||||
"demo/main.py",
|
||||
"--domain",
|
||||
"rpj_wiki",
|
||||
"--load",
|
||||
"hnsw",
|
||||
"--mode",
|
||||
"serve",
|
||||
"--search",
|
||||
"--skip-pa",
|
||||
"--recompute",
|
||||
"--hnsw-old"
|
||||
],
|
||||
"env": {
|
||||
"LD_PRELOAD": "/lib/x86_64-linux-gnu/libmkl_core.so:/lib/x86_64-linux-gnu/libmkl_intel_thread.so:/lib/x86_64-linux-gnu/libmkl_intel_lp64.so:/lib/x86_64-linux-gnu/libiomp5.so"
|
||||
}
|
||||
},
|
||||
],
|
||||
"inputs": [
|
||||
{
|
||||
"id": "configSelection",
|
||||
"type": "pickString",
|
||||
"description": "Select a configuration",
|
||||
"options": [
|
||||
"example_config",
|
||||
"vllm_gritlm"
|
||||
],
|
||||
"default": "example_config"
|
||||
}
|
||||
],
|
||||
}
|
||||
43  .vscode/settings.json (vendored, new executable file)

@@ -0,0 +1,43 @@
|
||||
{
|
||||
"python.analysis.extraPaths": [
|
||||
"./sglang_repo/python"
|
||||
],
|
||||
"cmake.sourceDirectory": "${workspaceFolder}/DiskANN",
|
||||
"cmake.configureArgs": [
|
||||
"-DPYBIND=True",
|
||||
"-DUPDATE_EDITABLE_INSTALL=ON",
|
||||
],
|
||||
"cmake.environment": {
|
||||
"PATH": "/Users/ec2-user/Power-RAG/.venv/bin:${env:PATH}"
|
||||
},
|
||||
"cmake.buildDirectory": "${workspaceFolder}/build",
|
||||
"files.associations": {
|
||||
"*.tcc": "cpp",
|
||||
"deque": "cpp",
|
||||
"string": "cpp",
|
||||
"unordered_map": "cpp",
|
||||
"vector": "cpp",
|
||||
"map": "cpp",
|
||||
"unordered_set": "cpp",
|
||||
"atomic": "cpp",
|
||||
"inplace_vector": "cpp",
|
||||
"*.ipp": "cpp",
|
||||
"forward_list": "cpp",
|
||||
"list": "cpp",
|
||||
"any": "cpp",
|
||||
"system_error": "cpp",
|
||||
"__hash_table": "cpp",
|
||||
"__split_buffer": "cpp",
|
||||
"__tree": "cpp",
|
||||
"ios": "cpp",
|
||||
"set": "cpp",
|
||||
"__string": "cpp",
|
||||
"string_view": "cpp",
|
||||
"ranges": "cpp",
|
||||
"iosfwd": "cpp"
|
||||
},
|
||||
"lldb.displayFormat": "auto",
|
||||
"lldb.showDisassembly": "auto",
|
||||
"lldb.dereferencePointers": true,
|
||||
"lldb.consoleMode": "commands",
|
||||
}
|
||||
16  .vscode/tasks.json (vendored, new file)

@@ -0,0 +1,16 @@
{
  "version": "2.0.0",
  "tasks": [
    {
      "type": "cmake",
      "label": "CMake: build",
      "command": "build",
      "targets": [
        "all"
      ],
      "group": "build",
      "problemMatcher": [],
      "detail": "CMake template build task"
    }
  ]
}

356  README.md

@@ -12,74 +12,67 @@
|
||||
The smallest vector index in the world. RAG Everything with LEANN!
|
||||
</h2>
|
||||
|
||||
LEANN is a revolutionary vector database that democratizes personal AI. Transform your laptop into a powerful RAG system that can index and search through millions of documents while using **97% less storage** than traditional solutions **without accuracy loss**.
|
||||
|
||||
LEANN achieves this through *graph-based selective recomputation* with *high-degree preserving pruning*, computing embeddings on-demand instead of storing them all. [Illustration Fig →](#️-architecture--how-it-works) | [Paper →](https://arxiv.org/abs/2506.08276)
|
||||
|
||||
**Ready to RAG Everything?** Transform your laptop into a personal AI assistant that can search your **[file system](#-personal-data-manager-process-any-documents-pdf-txt-md)**, **[emails](#-your-personal-email-secretary-rag-on-apple-mail)**, **[browser history](#-time-machine-for-the-web-rag-your-entire-browser-history)**, **[chat history](#-wechat-detective-unlock-your-golden-memories)**, or external knowledge bases (i.e., 60M documents) - all on your laptop, with zero cloud costs and complete privacy.
|
||||
LEANN is a revolutionary vector database that makes personal AI accessible to everyone. Transform your laptop into a powerful RAG system that can index and search through millions of documents while using **97% less storage** than traditional solutions **without accuracy loss**.
|
||||
|
||||
RAG your **[emails](#-search-your-entire-life)**, **[browser history](#-time-machine-for-the-web)**, **[WeChat](#-wechat-detective)**, or 60M documents on your laptop, in nearly zero cost. No cloud, no API keys, completely private.
|
||||
|
||||
LEANN achieves this through *graph-based selective recomputation* with *high-degree preserving pruning*, computing embeddings on-demand instead of storing them all. [Read more →](#️-architecture--how-it-works) | [Paper →](https://arxiv.org/abs/2506.08276)
|
||||
|
||||
## Why LEANN?
|
||||
|
||||
<p align="center">
|
||||
<img src="assets/effects.png" alt="LEANN vs Traditional Vector DB Storage Comparison" width="70%">
|
||||
<img src="assets/effects.png" alt="LEANN vs Traditional Vector DB Storage Comparison" width="100%">
|
||||
</p>
|
||||
|
||||
> **The numbers speak for themselves:** Index 60 million Wikipedia chunks in just 6GB instead of 201GB. From emails to browser history, everything fits on your laptop. [See detailed benchmarks for different applications below ↓](#storage-comparison)
|
||||
**The numbers speak for themselves:** Index 60 million Wikipedia articles in just 6GB instead of 201GB. From emails to browser history, everything fits on your laptop. [See detailed benchmarks below ↓](#storage-usage-comparison)
|
||||
|
||||
## Why This Matters
|
||||
|
||||
🔒 **Privacy:** Your data never leaves your laptop. No OpenAI, no cloud, no "terms of service".
|
||||
|
||||
🪶 **Lightweight:** Graph-based recomputation eliminates heavy embedding storage, while smart graph pruning and CSR format minimize graph storage overhead. Always less storage, less memory usage!
|
||||
🪶 **Lightweight:** Smart graph pruning means less storage, less memory usage, better performance on your existing hardware.
|
||||
|
||||
📈 **Scalability:** Handle messy personal data that would crash traditional vector DBs, easily managing your growing personalized data and agent generated memory!
|
||||
📈 **Scalability:** Organize your messy personal data that would crash traditional vector DBs, with performance that gets better as your data grows more personalized.
|
||||
|
||||
✨ **No Accuracy Loss:** Maintain the same search quality as heavyweight solutions while using 97% less storage.
|
||||
|
||||
## Installation
|
||||
> `pip install leann` coming soon!
|
||||
## Quick Start in 1 minute
|
||||
|
||||
```bash
|
||||
git clone git@github.com:yichuan-w/LEANN.git leann
|
||||
git clone git@github.com:yichuan520030910320/LEANN-RAG.git leann
|
||||
cd leann
|
||||
git submodule update --init --recursive
|
||||
```
|
||||
|
||||
**macOS:**
|
||||
```bash
|
||||
brew install llvm libomp boost protobuf zeromq pkgconf
|
||||
|
||||
# Install with HNSW backend (default, recommended for most users)
|
||||
# Install uv first if you don't have it:
|
||||
# curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
# See: https://docs.astral.sh/uv/getting-started/installation/#installation-methods
|
||||
CC=$(brew --prefix llvm)/bin/clang CXX=$(brew --prefix llvm)/bin/clang++ uv sync
|
||||
```
|
||||
|
||||
**Linux:**
|
||||
```bash
|
||||
sudo apt-get install libomp-dev libboost-all-dev protobuf-compiler libabsl-dev libmkl-full-dev libaio-dev libzmq3-dev
|
||||
|
||||
# Install with HNSW backend (default, recommended for most users)
|
||||
brew install llvm libomp boost protobuf
|
||||
export CC=$(brew --prefix llvm)/bin/clang
|
||||
export CXX=$(brew --prefix llvm)/bin/clang++
|
||||
uv sync
|
||||
```
|
||||
|
||||
**Linux (Ubuntu/Debian):**
|
||||
```bash
|
||||
sudo apt-get install libomp-dev libboost-all-dev protobuf-compiler libabsl-dev libmkl-full-dev libaio-dev
|
||||
uv sync
|
||||
```
|
||||
|
||||
**Ollama Setup (Recommended for full privacy):**
|
||||
**Ollama Setup (Optional for Local LLM):**
|
||||
|
||||
> *You can skip this installation if you only want to use OpenAI API for generation.*
|
||||
*We support both hf-transformers and Ollama for local LLMs. Ollama is recommended for faster performance.*
|
||||
|
||||
|
||||
**macOS:**
|
||||
|
||||
First, [download Ollama for macOS](https://ollama.com/download/mac).
|
||||
*macOS:*
|
||||
|
||||
```bash
|
||||
# Install Ollama
|
||||
brew install ollama
|
||||
|
||||
# Pull a lightweight model (recommended for consumer hardware)
|
||||
ollama pull llama3.2:1b
|
||||
```
|
||||
|
||||
**Linux:**
|
||||
*Linux:*
|
||||
```bash
|
||||
# Install Ollama
|
||||
curl -fsSL https://ollama.ai/install.sh | sh
|
||||
@@ -91,78 +84,62 @@ ollama serve &
|
||||
ollama pull llama3.2:1b
|
||||
```
|
||||
|
||||
## Quick Start in 30s
|
||||
You can also replace `llama3.2:1b` with `deepseek-r1:1.5b` or `qwen3:4b` for better performance at the cost of higher memory usage.
|
||||
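For reference, switching models only changes the `model` field of the `llm_config` dict used later in this README. A minimal sketch (reusing the `LeannChat` call shown further down, with a hypothetical question string and index path) is:

```python
from leann.api import LeannChat

# Any locally pulled Ollama model can be used; larger models trade memory for quality.
llm_config = {
    "type": "ollama",
    "model": "deepseek-r1:1.5b",  # or "llama3.2:1b", "qwen3:4b"
}

chat = LeannChat(index_path="knowledge.leann", llm_config=llm_config)
response = chat.ask("What does my knowledge base say about Python?", top_k=2)
print(response)
```
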
|
||||
Our declarative API makes RAG as easy as writing a config file.
|
||||
[Try in this ipynb file →](demo.ipynb) [](https://colab.research.google.com/github/yichuan-w/LEANN/blob/main/demo.ipynb)
|
||||
## Dead Simple API
|
||||
|
||||
Just 3 lines of code. Our declarative API makes RAG as easy as writing a config file:
|
||||
|
||||
```python
|
||||
from leann.api import LeannBuilder, LeannSearcher, LeannChat
|
||||
from leann.api import LeannBuilder, LeannSearcher
|
||||
|
||||
# 1. Build the index (no embeddings stored!)
|
||||
# 1. Build index (no embeddings stored!)
|
||||
builder = LeannBuilder(backend_name="hnsw")
|
||||
builder.add_text("C# is a powerful programming language")
|
||||
builder.add_text("Python is a powerful programming language and it is very popular")
|
||||
builder.add_text("Machine learning transforms industries")
|
||||
builder.add_text("Python is a powerful programming language")
|
||||
builder.add_text("Machine learning transforms industries")
|
||||
builder.add_text("Neural networks process complex data")
|
||||
builder.add_text("Leann is a great storage saving engine for RAG on your MacBook")
|
||||
builder.add_text("Leann is a great storage saving engine for RAG on your macbook")
|
||||
builder.build_index("knowledge.leann")
|
||||
|
||||
# 2. Search with real-time embeddings
|
||||
searcher = LeannSearcher("knowledge.leann")
|
||||
results = searcher.search("programming languages", top_k=2)
|
||||
|
||||
# 3. Chat with LEANN using retrieved results
|
||||
llm_config = {
|
||||
"type": "ollama",
|
||||
"model": "llama3.2:1b"
|
||||
}
|
||||
|
||||
chat = LeannChat(index_path="knowledge.leann", llm_config=llm_config)
|
||||
response = chat.ask(
|
||||
"Compare the two retrieved programming languages and say which one is more popular today.",
|
||||
top_k=2,
|
||||
)
|
||||
results = searcher.search("C++ programming languages", top_k=2, recompute_beighbor_embeddings=True)
|
||||
print(results)
|
||||
```
|
||||
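Because the old and new versions of the snippet above are interleaved in this diff, here is a consolidated sketch of the flow as the updated README presents it; the `recompute_beighbor_embeddings` keyword is spelled exactly as it appears in the diff:

```python
from leann.api import LeannBuilder, LeannSearcher

# 1. Build the index (no embeddings stored!)
builder = LeannBuilder(backend_name="hnsw")
builder.add_text("Python is a powerful programming language and it is very popular")
builder.add_text("Machine learning transforms industries")
builder.add_text("Neural networks process complex data")
builder.add_text("Leann is a great storage saving engine for RAG on your macbook")
builder.build_index("knowledge.leann")

# 2. Search with real-time (recomputed) embeddings
searcher = LeannSearcher("knowledge.leann")
results = searcher.search("C++ programming languages", top_k=2, recompute_beighbor_embeddings=True)
print(results)
```
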
|
||||
## RAG on Everything!
|
||||
**That's it.** No cloud setup, no API keys, no "fine-tuning". Just your data, your questions, your laptop.
|
||||
|
||||
LEANN supports RAG on various data sources including documents (.pdf, .txt, .md), Apple Mail, Google Search History, WeChat, and more.
|
||||
[Try the interactive demo →](demo.ipynb)
|
||||
|
||||
### 📄 Personal Data Manager: Process Any Documents (.pdf, .txt, .md)!
|
||||
## Wild Things You Can Do
|
||||
|
||||
Ask questions directly about your personal PDFs, documents, and any directory containing your files!
|
||||
LEANN supports RAGing a lot of data sources, like .pdf, .txt, .md, and also supports RAGing your WeChat, Google Search History, and more.
|
||||
|
||||
<p align="center">
|
||||
<img src="videos/paper_clear.gif" alt="LEANN Document Search Demo" width="600">
|
||||
</p>
|
||||
### 📚 Process Any Documents (.pdf, .txt, .md)
|
||||
|
||||
The example below asks a question about summarizing two papers (uses default data in `examples/data`):
|
||||
Above we showed the Python API; this CLI script demonstrates the same concepts while directly processing PDFs and documents.
|
||||
|
||||
```bash
|
||||
# Drop your PDFs, .txt, .md files into examples/data/
|
||||
uv run ./examples/main_cli_example.py
|
||||
```
|
||||
|
||||
```
|
||||
# Or use python directly
|
||||
source .venv/bin/activate
|
||||
python ./examples/main_cli_example.py
|
||||
```
|
||||
|
||||
Uses Ollama `qwen3:8b` by default. For other models: `--llm openai --model gpt-4o` (requires `OPENAI_API_KEY` environment variable) or `--llm hf --model Qwen/Qwen3-4B`.
|
||||
|
||||
**Works with any text format** - research papers, personal notes, presentations. Built with LlamaIndex for document parsing.
|
||||
|
||||
### 📧 Your Personal Email Secretary: RAG on Apple Mail!
|
||||
|
||||
<p align="center">
|
||||
<img src="videos/mail_clear.gif" alt="LEANN Email Search Demo" width="600">
|
||||
</p>
|
||||
|
||||
**Note:** You need to grant full disk access to your terminal/VS Code in System Preferences → Privacy & Security → Full Disk Access.
|
||||
### 🕵️ Search Your Entire Life
|
||||
```bash
|
||||
python examples/mail_reader_leann.py --query "What's the food I ordered by doordash or Uber eat mostly?"
|
||||
python examples/mail_reader_leann.py
|
||||
# "What did my boss say about the Christmas party last year?"
|
||||
# "Find all emails from my mom about birthday plans"
|
||||
```
|
||||
**780K email chunks → 78MB storage.** Finally, search your email like you search Google.
|
||||
**90K emails → 14MB.** Finally, search your email like you search Google.
|
||||
|
||||
<details>
|
||||
<summary><strong>📋 Click to expand: Command Examples</strong></summary>
|
||||
@@ -195,16 +172,13 @@ Once the index is built, you can ask questions like:
|
||||
- "Show me emails about travel expenses"
|
||||
</details>
|
||||
|
||||
### 🔍 Time Machine for the Web: RAG Your Entire Google Browser History!
|
||||
|
||||
<p align="center">
|
||||
<img src="videos/google_clear.gif" alt="LEANN Browser History Search Demo" width="600">
|
||||
</p>
|
||||
|
||||
### 🌐 Time Machine for the Web
|
||||
```bash
|
||||
python examples/google_history_reader_leann.py --query "Tell me my browser history about machine learning?"
|
||||
python examples/google_history_reader_leann.py
|
||||
# "What was that AI paper I read last month?"
|
||||
# "Show me all the cooking videos I watched"
|
||||
```
|
||||
**38K browser entries → 6MB storage.** Your browser history becomes your personal search engine.
|
||||
**38K browser entries → 6MB.** Your browser history becomes your personal search engine.
|
||||
|
||||
<details>
|
||||
<summary><strong>📋 Click to expand: Command Examples</strong></summary>
|
||||
@@ -253,17 +227,13 @@ Once the index is built, you can ask questions like:
|
||||
|
||||
</details>
|
||||
|
||||
### 💬 WeChat Detective: Unlock Your Golden Memories!
|
||||
|
||||
<p align="center">
|
||||
<img src="videos/wechat_clear.gif" alt="LEANN WeChat Search Demo" width="600">
|
||||
</p>
|
||||
### 💬 WeChat Detective
|
||||
|
||||
```bash
|
||||
python examples/wechat_history_reader_leann.py --query "Show me all group chats about weekend plans"
|
||||
python examples/wechat_history_reader_leann.py
|
||||
# "Show me all group chats about weekend plans"
|
||||
```
|
||||
**400K messages → 64MB storage.** Search years of chat history in any language.
|
||||
|
||||
**400K messages → 64MB.** Search years of chat history in any language.
|
||||
|
||||
<details>
|
||||
<summary><strong>🔧 Click to expand: Installation Requirements</strong></summary>
|
||||
@@ -274,13 +244,7 @@ First, you need to install the WeChat exporter:
|
||||
sudo packages/wechat-exporter/wechattweak-cli install
|
||||
```
|
||||
|
||||
**Troubleshooting:**
|
||||
- **Installation issues**: Check the [WeChatTweak-CLI issues page](https://github.com/sunnyyoung/WeChatTweak-CLI/issues/41)
|
||||
- **Export errors**: If you encounter the error below, try restarting WeChat
|
||||
```
|
||||
Failed to export WeChat data. Please ensure WeChat is running and WeChatTweak is installed.
|
||||
Failed to find or export WeChat data. Exiting.
|
||||
```
|
||||
**Troubleshooting**: If you encounter installation issues, check the [WeChatTweak-CLI issues page](https://github.com/sunnyyoung/WeChatTweak-CLI/issues/41).
|
||||
</details>
|
||||
|
||||
<details>
|
||||
@@ -315,73 +279,6 @@ Once the index is built, you can ask questions like:
|
||||
</details>
|
||||
|
||||
|
||||
|
||||
## 🖥️ Command Line Interface
|
||||
|
||||
LEANN includes a powerful CLI for document processing and search. Perfect for quick document indexing and interactive chat.
|
||||
|
||||
```bash
|
||||
# Build an index from documents
|
||||
leann build my-docs --docs ./documents
|
||||
|
||||
# Search your documents
|
||||
leann search my-docs "machine learning concepts"
|
||||
|
||||
# Interactive chat with your documents
|
||||
leann ask my-docs --interactive
|
||||
|
||||
# List all your indexes
|
||||
leann list
|
||||
```
|
||||
|
||||
**Key CLI features:**
|
||||
- Auto-detects document formats (PDF, TXT, MD, DOCX)
|
||||
- Smart text chunking with overlap
|
||||
- Multiple LLM providers (Ollama, OpenAI, HuggingFace)
|
||||
- Organized index storage in `~/.leann/indexes/`
|
||||
- Support for advanced search parameters
|
||||
|
||||
<details>
|
||||
<summary><strong>📋 Click to expand: Complete CLI Reference</strong></summary>
|
||||
|
||||
**Build Command:**
|
||||
```bash
|
||||
leann build INDEX_NAME --docs DIRECTORY [OPTIONS]
|
||||
|
||||
Options:
|
||||
--backend {hnsw,diskann} Backend to use (default: hnsw)
|
||||
--embedding-model MODEL Embedding model (default: facebook/contriever)
|
||||
--graph-degree N Graph degree (default: 32)
|
||||
--complexity N Build complexity (default: 64)
|
||||
--force Force rebuild existing index
|
||||
--compact Use compact storage (default: true)
|
||||
--recompute Enable recomputation (default: true)
|
||||
```
|
||||
|
||||
**Search Command:**
|
||||
```bash
|
||||
leann search INDEX_NAME QUERY [OPTIONS]
|
||||
|
||||
Options:
|
||||
--top-k N Number of results (default: 5)
|
||||
--complexity N Search complexity (default: 64)
|
||||
--recompute-embeddings Use recomputation for highest accuracy
|
||||
--pruning-strategy {global,local,proportional}
|
||||
```
|
||||
|
||||
**Ask Command:**
|
||||
```bash
|
||||
leann ask INDEX_NAME [OPTIONS]
|
||||
|
||||
Options:
|
||||
--llm {ollama,openai,hf} LLM provider (default: ollama)
|
||||
--model MODEL Model name (default: qwen3:8b)
|
||||
--interactive Interactive chat mode
|
||||
--top-k N Retrieval count (default: 20)
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
## 🏗️ Architecture & How It Works
|
||||
|
||||
<p align="center">
|
||||
@@ -400,17 +297,18 @@ Options:
|
||||
|
||||
## Benchmarks
|
||||
|
||||
Run the comparison yourself:
|
||||
```bash
|
||||
python examples/compare_faiss_vs_leann.py
|
||||
```
|
||||
|
||||
📊 **[Simple Example: Compare LEANN vs FAISS →](examples/compare_faiss_vs_leann.py)**
|
||||
### Storage Comparison
|
||||
|
||||
| System | DPR (2.1M) | Wiki (60M) | Chat (400K) | Email (780K) | Browser (38K) |
|
||||
|--------|-------------|------------|-------------|--------------|---------------|
|
||||
| Traditional vector database (e.g., FAISS) | 3.8 GB | 201 GB | 1.8 GB | 2.4 GB | 130 MB |
|
||||
| LEANN | 324 MB | 6 GB | 64 MB | 79 MB | 6.4 MB |
|
||||
| Savings | 91% | 97% | 97% | 97% | 95% |
|
||||
|
||||
| System | Storage |
|
||||
|--------|---------|
|
||||
| FAISS HNSW | 5.5 MB |
|
||||
| LEANN | 0.5 MB |
|
||||
| **Savings** | **91%** |
|
||||
|
||||
Same dataset, same hardware, same embedding model. LEANN just works better.
|
||||
|
||||
## Reproduce Our Results
|
||||
|
||||
@@ -420,7 +318,33 @@ python examples/run_evaluation.py data/indices/dpr/dpr_diskann # DPR datase
|
||||
python examples/run_evaluation.py data/indices/rpj_wiki/rpj_wiki.index # Wikipedia
|
||||
```
|
||||
|
||||
The evaluation script downloads data automatically on first run. The last three results were tested with partial personal data, and you can reproduce them with your own data!
|
||||
The evaluation script downloads data automatically on first run.
|
||||
|
||||
### Storage Usage Comparison
|
||||
|
||||
| System | DPR (2.1M chunks) | RPJ-wiki (60M chunks) | Chat history (400K messages) | Apple emails (90K message chunks) | Google Search History (38K entries) |
|------|------|------|------|------|------|
| Traditional Vector DB (FAISS) | 3.8 GB | 201 GB | 1.8 GB | 305.8 MB | 130.4 MB |
| **LEANN** | **324 MB** | **6 GB** | **64 MB** | **14.8 MB** | **6.4 MB** |
| **Reduction** | **91% smaller** | **97% smaller** | **97% smaller** | **95% smaller** | **95% smaller** |
|
||||
|
||||
<!-- ### Memory Usage Comparison
|
||||
|
||||
| System j | DPR(2M docs) | RPJ-wiki(60M docs) | Chat history() |
|
||||
| --------------------- | ---------------- | ---------------- | ---------------- |
|
||||
| Traditional Vector DB(LLamaindex faiss) | x GB | x GB | x GB |
|
||||
| **Leann** | **xx MB** | **x GB** | **x GB** |
|
||||
| **Reduction** | **x%** | **x%** | **x%** |
|
||||
|
||||
### Query Performance of LEANN
|
||||
|
||||
| Backend | Index Size | Query Time | Recall@3 |
|
||||
| ------------------- | ---------- | ---------- | --------- |
|
||||
| DiskANN | 1M docs | xms | 0.95 |
|
||||
| HNSW | 1M docs | xms | 0.95 | -->
|
||||
|
||||
*Benchmarks run on Apple M3 Pro 36 GB*
|
||||
|
||||
## 🔬 Paper
|
||||
|
||||
If you find Leann useful, please cite:
|
||||
@@ -439,15 +363,87 @@ If you find Leann useful, please cite:
|
||||
}
|
||||
```
|
||||
|
||||
## ✨ [Detailed Features →](docs/features.md)
|
||||
## ✨ Features
|
||||
|
||||
## 🤝 [Contributing →](docs/contributing.md)
|
||||
### 🔥 Core Features
|
||||
|
||||
- **🔄 Real-time Embeddings** - Eliminates heavy embedding storage by computing embeddings on demand, using optimized ZMQ servers and an overlapping, batched search paradigm backed by a high-performance embedding engine
|
||||
- **📈 Scalable Architecture** - Handles millions of documents on consumer hardware; the larger your dataset, the more LEANN can save
|
||||
- **🎯 Graph Pruning** - Advanced techniques to minimize the storage overhead of vector search to a limited footprint
|
||||
- **🏗️ Pluggable Backends** - DiskANN, HNSW/FAISS with unified API
|
||||
|
||||
### 🛠️ Technical Highlights
|
||||
- **🔄 Recompute Mode** - Highest accuracy scenarios while eliminating vector storage overhead
|
||||
- **⚡ Zero-copy Operations** - Minimize IPC overhead by transferring distances instead of embeddings
|
||||
- **🚀 High-throughput Embedding Pipeline** - Optimized batched processing for maximum efficiency
|
||||
- **🎯 Two-level Search** - Novel coarse-to-fine search overlap for accelerated query processing (optional)
|
||||
- **💾 Memory-mapped Indices** - Fast startup with raw text mapping to reduce memory overhead
|
||||
- **🚀 MLX Support** - Ultra-fast recompute/build with quantized embedding models, accelerating building and search ([minimal example](test/build_mlx_index.py))
|
||||
|
||||
### 🎨 Developer Experience
|
||||
|
||||
- **Simple Python API** - Get started in minutes
|
||||
- **Extensible backend system** - Easy to add new algorithms
|
||||
- **Comprehensive examples** - From basic usage to production deployment
|
||||
|
||||
## 🤝 Contributing
|
||||
|
||||
We welcome contributions! Leann is built by the community, for the community.
|
||||
|
||||
### Ways to Contribute
|
||||
|
||||
- 🐛 **Bug Reports**: Found an issue? Let us know!
|
||||
- 💡 **Feature Requests**: Have an idea? We'd love to hear it!
|
||||
- 🔧 **Code Contributions**: PRs welcome for all skill levels
|
||||
- 📖 **Documentation**: Help make Leann more accessible
|
||||
- 🧪 **Benchmarks**: Share your performance results
|
||||
|
||||
|
||||
## [FAQ →](docs/faq.md)
|
||||
<!-- ## ❓ FAQ
|
||||
|
||||
### Common Issues
|
||||
|
||||
#### NCCL Topology Error
|
||||
|
||||
**Problem**: You encounter `ncclTopoComputePaths` error during document processing:
|
||||
|
||||
```
|
||||
ncclTopoComputePaths (system=<optimized out>, comm=comm@entry=0x5555a82fa3c0) at graph/paths.cc:688
|
||||
```
|
||||
|
||||
**Solution**: Set these environment variables before running your script:
|
||||
|
||||
```bash
|
||||
export NCCL_TOPO_DUMP_FILE=/tmp/nccl_topo.xml
|
||||
export NCCL_DEBUG=INFO
|
||||
export NCCL_DEBUG_SUBSYS=INIT,GRAPH
|
||||
export NCCL_IB_DISABLE=1
|
||||
export NCCL_NET_PLUGIN=none
|
||||
export NCCL_SOCKET_IFNAME=ens5
|
||||
``` -->
|
||||
|
||||
## 📈 Roadmap
|
||||
|
||||
### 🎯 Q2 2025
|
||||
|
||||
- [X] DiskANN backend with MIPS/L2/Cosine support
|
||||
- [X] HNSW backend integration
|
||||
- [X] Real-time embedding pipeline
|
||||
- [X] Memory-efficient graph pruning
|
||||
|
||||
### 🚀 Q3 2025
|
||||
|
||||
|
||||
## 📈 [Roadmap →](docs/roadmap.md)
|
||||
- [ ] Advanced caching strategies
|
||||
- [ ] Add contextual-retrieval https://www.anthropic.com/news/contextual-retrieval
|
||||
- [ ] Add sleep-time compute and a summarization agent to summarize files on your computer
|
||||
- [ ] Add OpenAI recompute API
|
||||
|
||||
### 🌟 Q4 2025
|
||||
|
||||
- [ ] Integration with LangChain/LlamaIndex
|
||||
- [ ] Visual similarity search
|
||||
- [ ] Query rewriting, reranking and expansion
|
||||
|
||||
## 📄 License
|
||||
|
||||
@@ -455,7 +451,11 @@ MIT License - see [LICENSE](LICENSE) for details.
|
||||
|
||||
## 🙏 Acknowledgments
|
||||
|
||||
This work is done at the [**Berkeley Sky Computing Lab**](https://sky.cs.berkeley.edu/).
|
||||
- **Microsoft Research** for the DiskANN algorithm
|
||||
- **Meta AI** for FAISS and optimization insights
|
||||
- **HuggingFace** for the transformer ecosystem
|
||||
- **Our amazing contributors** who make this possible
|
||||
|
||||
---
|
||||
|
||||
<p align="center">
|
||||
|
||||
Binary file not shown. (Size before: 339 KiB, after: 206 KiB)

316  demo.ipynb

@@ -1,321 +1,35 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Quick Start in 30s"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# install this if you areusing colab\n",
|
||||
"! pip install leann"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Build the index"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"INFO: Registering backend 'hnsw'\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/Users/yichuan/Desktop/code/LEANN/leann/.venv/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
|
||||
" from .autonotebook import tqdm as notebook_tqdm\n",
|
||||
"INFO:sentence_transformers.SentenceTransformer:Load pretrained SentenceTransformer: facebook/contriever\n",
|
||||
"WARNING:sentence_transformers.SentenceTransformer:No sentence-transformers model found with name facebook/contriever. Creating a new one with mean pooling.\n",
|
||||
"Writing passages: 100%|██████████| 5/5 [00:00<00:00, 27887.66chunk/s]\n",
|
||||
"Batches: 100%|██████████| 1/1 [00:00<00:00, 13.51it/s]\n",
|
||||
"WARNING:leann_backend_hnsw.hnsw_backend:Converting data to float32, shape: (5, 768)\n",
|
||||
"INFO:leann_backend_hnsw.hnsw_backend:INFO: Converting HNSW index to CSR-pruned format...\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"M: 64 for level: 0\n",
|
||||
"Starting conversion: knowledge.index -> knowledge.csr.tmp\n",
|
||||
"[0.00s] Reading Index HNSW header...\n",
|
||||
"[0.00s] Header read: d=768, ntotal=5\n",
|
||||
"[0.00s] Reading HNSW struct vectors...\n",
|
||||
" Reading vector (dtype=<class 'numpy.float64'>, fmt='d')... Count=6, Bytes=48\n",
|
||||
"[0.00s] Read assign_probas (6)\n",
|
||||
" Reading vector (dtype=<class 'numpy.int32'>, fmt='i')... Count=7, Bytes=28\n",
|
||||
"[0.11s] Read cum_nneighbor_per_level (7)\n",
|
||||
" Reading vector (dtype=<class 'numpy.int32'>, fmt='i')... Count=5, Bytes=20\n",
|
||||
"[0.21s] Read levels (5)\n",
|
||||
"[0.30s] Probing for compact storage flag...\n",
|
||||
"[0.30s] Found compact flag: False\n",
|
||||
"[0.30s] Compact flag is False, reading original format...\n",
|
||||
"[0.30s] Probing for potential extra byte before non-compact offsets...\n",
|
||||
"[0.30s] Found and consumed an unexpected 0x00 byte.\n",
|
||||
" Reading vector (dtype=<class 'numpy.uint64'>, fmt='Q')... Count=6, Bytes=48\n",
|
||||
"[0.30s] Read offsets (6)\n",
|
||||
"[0.40s] Attempting to read neighbors vector...\n",
|
||||
" Reading vector (dtype=<class 'numpy.int32'>, fmt='i')... Count=320, Bytes=1280\n",
|
||||
"[0.40s] Read neighbors (320)\n",
|
||||
"[0.50s] Read scalar params (ep=4, max_lvl=0)\n",
|
||||
"[0.50s] Checking for storage data...\n",
|
||||
"[0.50s] Found storage fourcc: 49467849.\n",
|
||||
"[0.50s] Converting to CSR format...\n",
|
||||
"[0.50s] Conversion loop finished. \n",
|
||||
"[0.50s] Running validation checks...\n",
|
||||
" Checking total valid neighbor count...\n",
|
||||
" OK: Total valid neighbors = 20\n",
|
||||
" Checking final pointer indices...\n",
|
||||
" OK: Final pointers match data size.\n",
|
||||
"[0.50s] Deleting original neighbors and offsets arrays...\n",
|
||||
" CSR Stats: |data|=20, |level_ptr|=10\n",
|
||||
"[0.59s] Writing CSR HNSW graph data in FAISS-compatible order...\n",
|
||||
" Pruning embeddings: Writing NULL storage marker.\n",
|
||||
"[0.69s] Conversion complete.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"INFO:leann_backend_hnsw.hnsw_backend:✅ CSR conversion successful.\n",
|
||||
"INFO:leann_backend_hnsw.hnsw_backend:INFO: Replaced original index with CSR-pruned version at 'knowledge.index'\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from leann.api import LeannBuilder\n",
|
||||
"\n",
|
||||
"from leann.api import LeannBuilder, LeannSearcher, LeannChat\n",
|
||||
"# 1. Build index (no embeddings stored!)\n",
|
||||
"builder = LeannBuilder(backend_name=\"hnsw\")\n",
|
||||
"builder.add_text(\"C# is a powerful programming language and it is good at game development\")\n",
|
||||
"builder.add_text(\"Python is a powerful programming language and it is good at machine learning tasks\")\n",
|
||||
"builder.add_text(\"Machine learning transforms industries\")\n",
|
||||
"builder.add_text(\"C# is a powerful programming language but it is not very popular\")\n",
|
||||
"builder.add_text(\"Python is a powerful programming language and it is very popular\")\n",
|
||||
"builder.add_text(\"Machine learning transforms industries\") \n",
|
||||
"builder.add_text(\"Neural networks process complex data\")\n",
|
||||
"builder.add_text(\"Leann is a great storage saving engine for RAG on your MacBook\")\n",
|
||||
"builder.build_index(\"knowledge.leann\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Search with real-time embeddings"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"INFO:leann.api:🔍 LeannSearcher.search() called:\n",
|
||||
"INFO:leann.api: Query: 'programming languages'\n",
|
||||
"INFO:leann.api: Top_k: 2\n",
|
||||
"INFO:leann.api: Additional kwargs: {}\n",
|
||||
"INFO:leann.embedding_server_manager:Port 5557 has incompatible server, trying next port...\n",
|
||||
"INFO:leann.embedding_server_manager:Port 5558 has incompatible server, trying next port...\n",
|
||||
"INFO:leann.embedding_server_manager:Port 5559 has incompatible server, trying next port...\n",
|
||||
"INFO:leann.embedding_server_manager:Using port 5560 instead of 5557\n",
|
||||
"INFO:leann.embedding_server_manager:Starting embedding server on port 5560...\n",
|
||||
"INFO:leann.embedding_server_manager:Command: /Users/yichuan/Desktop/code/LEANN/leann/.venv/bin/python -m leann_backend_hnsw.hnsw_embedding_server --zmq-port 5560 --model-name facebook/contriever --passages-file knowledge.leann.meta.json\n",
|
||||
"huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
|
||||
"To disable this warning, you can either:\n",
|
||||
"\t- Avoid using `tokenizers` before the fork if possible\n",
|
||||
"\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
|
||||
"INFO:leann.embedding_server_manager:Server process started with PID: 4574\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[read_HNSW - CSR NL v4] Reading metadata & CSR indices (manual offset)...\n",
|
||||
"[read_HNSW NL v4] Read levels vector, size: 5\n",
|
||||
"[read_HNSW NL v4] Reading Compact Storage format indices...\n",
|
||||
"[read_HNSW NL v4] Read compact_level_ptr, size: 10\n",
|
||||
"[read_HNSW NL v4] Read compact_node_offsets, size: 6\n",
|
||||
"[read_HNSW NL v4] Read entry_point: 4, max_level: 0\n",
|
||||
"[read_HNSW NL v4] Read storage fourcc: 0x6c6c756e\n",
|
||||
"[read_HNSW NL v4 FIX] Detected FileIOReader. Neighbors size field offset: 326\n",
|
||||
"[read_HNSW NL v4] Reading neighbors data into memory.\n",
|
||||
"[read_HNSW NL v4] Read neighbors data, size: 20\n",
|
||||
"[read_HNSW NL v4] Finished reading metadata and CSR indices.\n",
|
||||
"INFO: Skipping external storage loading, since is_recompute is true.\n",
|
||||
"INFO: Registering backend 'hnsw'\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"INFO:leann.embedding_server_manager:Embedding server is ready!\n",
|
||||
"INFO:leann.api: Launching server time: 1.078078269958496 seconds\n",
|
||||
"INFO:leann.embedding_server_manager:Existing server process (PID 4574) is compatible\n",
|
||||
"INFO:sentence_transformers.SentenceTransformer:Load pretrained SentenceTransformer: facebook/contriever\n",
|
||||
"WARNING:sentence_transformers.SentenceTransformer:No sentence-transformers model found with name facebook/contriever. Creating a new one with mean pooling.\n",
|
||||
"INFO:leann.api: Generated embedding shape: (1, 768)\n",
|
||||
"INFO:leann.api: Embedding time: 2.9307072162628174 seconds\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"ZmqDistanceComputer initialized: d=768, metric=0\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"INFO:leann.api: Search time: 0.27327895164489746 seconds\n",
|
||||
"INFO:leann.api: Backend returned: labels=2 results\n",
|
||||
"INFO:leann.api: Processing 2 passage IDs:\n",
|
||||
"INFO:leann.api: 1. passage_id='0' -> SUCCESS: C# is a powerful programming language and it is good at game development...\n",
|
||||
"INFO:leann.api: 2. passage_id='1' -> SUCCESS: Python is a powerful programming language and it is good at machine learning tasks...\n",
|
||||
"INFO:leann.api: Final enriched results: 2 passages\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"[SearchResult(id='0', score=np.float32(0.9874103), text='C# is a powerful programming language and it is good at game development', metadata={}),\n",
|
||||
" SearchResult(id='1', score=np.float32(0.8922168), text='Python is a powerful programming language and it is good at machine learning tasks', metadata={})]"
|
||||
]
|
||||
},
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from leann.api import LeannSearcher\n",
|
||||
"\n",
|
||||
"builder.add_text(\"Leann is a great storage saving engine for RAG on your macbook\")\n",
|
||||
"builder.build_index(\"knowledge.leann\")\n",
|
||||
"# 2. Search with real-time embeddings\n",
|
||||
"searcher = LeannSearcher(\"knowledge.leann\")\n",
|
||||
"results = searcher.search(\"programming languages\", top_k=2)\n",
|
||||
"results"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Chat with LEANN using retrieved results"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"INFO:leann.chat:Attempting to create LLM of type='hf' with model='Qwen/Qwen3-0.6B'\n",
|
||||
"INFO:leann.chat:Initializing HFChat with model='Qwen/Qwen3-0.6B'\n",
|
||||
"INFO:leann.chat:MPS is available. Using Apple Silicon GPU.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[read_HNSW - CSR NL v4] Reading metadata & CSR indices (manual offset)...\n",
|
||||
"[read_HNSW NL v4] Read levels vector, size: 5\n",
|
||||
"[read_HNSW NL v4] Reading Compact Storage format indices...\n",
|
||||
"[read_HNSW NL v4] Read compact_level_ptr, size: 10\n",
|
||||
"[read_HNSW NL v4] Read compact_node_offsets, size: 6\n",
|
||||
"[read_HNSW NL v4] Read entry_point: 4, max_level: 0\n",
|
||||
"[read_HNSW NL v4] Read storage fourcc: 0x6c6c756e\n",
|
||||
"[read_HNSW NL v4 FIX] Detected FileIOReader. Neighbors size field offset: 326\n",
|
||||
"[read_HNSW NL v4] Reading neighbors data into memory.\n",
|
||||
"[read_HNSW NL v4] Read neighbors data, size: 20\n",
|
||||
"[read_HNSW NL v4] Finished reading metadata and CSR indices.\n",
|
||||
"INFO: Skipping external storage loading, since is_recompute is true.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"INFO:leann.api:🔍 LeannSearcher.search() called:\n",
|
||||
"INFO:leann.api: Query: 'Compare the two retrieved programming languages and tell me their advantages.'\n",
|
||||
"INFO:leann.api: Top_k: 2\n",
|
||||
"INFO:leann.api: Additional kwargs: {}\n",
|
||||
"INFO:leann.embedding_server_manager:Port 5557 has incompatible server, trying next port...\n",
|
||||
"INFO:leann.embedding_server_manager:Port 5558 has incompatible server, trying next port...\n",
|
||||
"INFO:leann.embedding_server_manager:Port 5559 has incompatible server, trying next port...\n",
|
||||
"INFO:leann.embedding_server_manager:Found compatible server on port 5560\n",
|
||||
"INFO:leann.embedding_server_manager:Using existing compatible server on port 5560\n",
|
||||
"INFO:leann.api: Launching server time: 0.04932403564453125 seconds\n",
|
||||
"INFO:leann.embedding_server_manager:Found compatible server on port 5560\n",
|
||||
"INFO:leann.embedding_server_manager:Using existing compatible server on port 5560\n",
|
||||
"INFO:leann.api: Generated embedding shape: (1, 768)\n",
|
||||
"INFO:leann.api: Embedding time: 0.06902289390563965 seconds\n",
|
||||
"INFO:leann.api: Search time: 0.026793241500854492 seconds\n",
|
||||
"INFO:leann.api: Backend returned: labels=2 results\n",
|
||||
"INFO:leann.api: Processing 2 passage IDs:\n",
|
||||
"INFO:leann.api: 1. passage_id='0' -> SUCCESS: C# is a powerful programming language and it is good at game development...\n",
|
||||
"INFO:leann.api: 2. passage_id='1' -> SUCCESS: Python is a powerful programming language and it is good at machine learning tasks...\n",
|
||||
"INFO:leann.api: Final enriched results: 2 passages\n",
|
||||
"INFO:leann.chat:Generating with HuggingFace model, config: {'max_new_tokens': 128, 'temperature': 0.7, 'top_p': 0.9, 'do_sample': True, 'pad_token_id': 151645, 'eos_token_id': 151645}\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"ZmqDistanceComputer initialized: d=768, metric=0\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"\"<think>\\n\\n</think>\\n\\nBased on the context provided, here's a comparison of the two retrieved programming languages:\\n\\n**C#** is known for being a powerful programming language and is well-suited for game development. It is often used in game development and is popular among developers working on Windows applications.\\n\\n**Python**, on the other hand, is also a powerful language and is well-suited for machine learning tasks. It is widely used for data analysis, scientific computing, and other applications that require handling large datasets or performing complex calculations.\\n\\n**Advantages**:\\n- C#: Strong for game development and cross-platform compatibility.\\n- Python: Strong for\""
|
||||
]
|
||||
},
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from leann.api import LeannChat\n",
|
||||
"results = searcher.search(\"programming languages\", top_k=2, recompute_beighbor_embeddings=True)\n",
|
||||
"print(results)\n",
|
||||
"\n",
|
||||
"llm_config = {\n",
|
||||
" \"type\": \"hf\",\n",
|
||||
" \"model\": \"Qwen/Qwen3-0.6B\",\n",
|
||||
"}\n",
|
||||
"llm_config = {\"type\": \"ollama\", \"model\": \"qwen3:8b\"}\n",
|
||||
"\n",
|
||||
"chat = LeannChat(index_path=\"knowledge.leann\", llm_config=llm_config)\n",
|
||||
"\n",
|
||||
"response = chat.ask(\n",
|
||||
" \"Compare the two retrieved programming languages and tell me their advantages.\",\n",
|
||||
" \"Compare the two retrieved programming languages and say which one is more popular today. Respond in a single well-formed sentence.\",\n",
|
||||
" top_k=2,\n",
|
||||
" llm_kwargs={\"max_tokens\": 128}\n",
|
||||
" recompute_beighbor_embeddings=True,\n",
|
||||
")\n",
|
||||
"response"
|
||||
"print(response)"
|
||||
]
|
||||
}
|
||||
],
|
||||
|
||||
@@ -1,22 +0,0 @@
|
||||
# Release Guide
|
||||
|
||||
## Setup (One-time)
|
||||
|
||||
Add `PYPI_API_TOKEN` to GitHub Secrets:
|
||||
1. Get token: https://pypi.org/manage/account/token/
|
||||
2. Add to secrets: Settings → Secrets → Actions → `PYPI_API_TOKEN`
|
||||
|
||||
## Release (One-click)
|
||||
|
||||
1. Go to: https://github.com/yichuan-w/LEANN/actions/workflows/release-manual.yml
|
||||
2. Click "Run workflow"
|
||||
3. Enter version: `0.1.2`
|
||||
4. Click green "Run workflow" button
|
||||
|
||||
That's it! The workflow will automatically:
|
||||
- ✅ Update version in all packages
|
||||
- ✅ Build all packages
|
||||
- ✅ Publish to PyPI
|
||||
- ✅ Create GitHub tag and release
|
||||
|
||||
Check progress: https://github.com/yichuan-w/LEANN/actions
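
If you prefer the command line, the same workflow can be dispatched with the GitHub CLI. A minimal sketch, assuming the workflow input field is named `version`:

```bash
# Trigger the manual release workflow (the input name "version" is an assumption)
gh workflow run release-manual.yml -f version=0.1.2

# Follow the run's progress
gh run watch
```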
|
||||
@@ -1,11 +0,0 @@
|
||||
# 🤝 Contributing
|
||||
|
||||
We welcome contributions! Leann is built by the community, for the community.
|
||||
|
||||
## Ways to Contribute
|
||||
|
||||
- 🐛 **Bug Reports**: Found an issue? Let us know!
|
||||
- 💡 **Feature Requests**: Have an idea? We'd love to hear it!
|
||||
- 🔧 **Code Contributions**: PRs welcome for all skill levels
|
||||
- 📖 **Documentation**: Help make Leann more accessible
|
||||
- 🧪 **Benchmarks**: Share your performance results
|
||||
10
docs/faq.md
10
docs/faq.md
@@ -1,10 +0,0 @@
|
||||
# FAQ
|
||||
|
||||
## 1. My index build time seems long
|
||||
|
||||
You can speed up the process by using a lightweight embedding model. Add this to your arguments:
|
||||
|
||||
```bash
|
||||
--embedding-model sentence-transformers/all-MiniLM-L6-v2
|
||||
```
|
||||
**Model sizes:** `all-MiniLM-L6-v2` (30M parameters), `facebook/contriever` (~100M parameters), `Qwen3-0.6B` (600M parameters)
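
For example, if you are building an index with one of the example scripts, the flag slots in next to the other CLI options (the script name below is a placeholder for whichever example you run):

```bash
# Placeholder script name; substitute the example reader you are actually running
python examples/document_reader.py \
  --index-dir ./my_index \
  --embedding-model sentence-transformers/all-MiniLM-L6-v2
```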
|
||||
@@ -1,22 +0,0 @@
|
||||
# ✨ Detailed Features
|
||||
|
||||
## 🔥 Core Features
|
||||
|
||||
- **🔄 Real-time Embeddings** - Eliminate heavy embedding storage by computing embeddings on the fly, using optimized ZMQ servers, an overlapped and batched search paradigm, and a highly optimized embedding engine
|
||||
- **📈 Scalable Architecture** - Handles millions of documents on consumer hardware; the larger your dataset, the more LEANN can save
|
||||
- **🎯 Graph Pruning** - Advanced techniques that keep the storage overhead of vector search to a small footprint
|
||||
- **🏗️ Pluggable Backends** - DiskANN, HNSW/FAISS with unified API
|
||||
|
||||
## 🛠️ Technical Highlights
|
||||
- **🔄 Recompute Mode** - Delivers the highest accuracy while eliminating vector storage overhead (see the sketch after this list)
|
||||
- **⚡ Zero-copy Operations** - Minimize IPC overhead by transferring distances instead of embeddings
|
||||
- **🚀 High-throughput Embedding Pipeline** - Optimized batched processing for maximum efficiency
|
||||
- **🎯 Two-level Search** - Novel coarse-to-fine search overlap for accelerated query processing (optional)
|
||||
- **💾 Memory-mapped Indices** - Fast startup with raw text mapping to reduce memory overhead
|
||||
- **🚀 MLX Support** - Ultra-fast recompute/build with quantized embedding models, accelerating building and search ([minimal example](test/build_mlx_index.py))
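
Below is a minimal sketch of recompute mode based on the API calls used in this repo's examples; the `complexity` and `beam_width` values are illustrative, not recommended defaults.

```python
from leann.api import LeannSearcher

# Embeddings for candidate neighbors are recomputed on the fly by the
# embedding server instead of being stored alongside the graph.
searcher = LeannSearcher("knowledge.leann")
results = searcher.search(
    "programming languages",
    top_k=2,
    recompute_beighbor_embeddings=True,  # parameter name as spelled in the examples
    complexity=32,   # illustrative search complexity
    beam_width=1,    # illustrative beam width
)
for r in results:
    print(r.id, r.score, r.text)
```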
|
||||
|
||||
## 🎨 Developer Experience
|
||||
|
||||
- **Simple Python API** - Get started in minutes
|
||||
- **Extensible backend system** - Easy to add new algorithms
|
||||
- **Comprehensive examples** - From basic usage to production deployment
|
||||
@@ -1,21 +0,0 @@
|
||||
# 📈 Roadmap
|
||||
|
||||
## 🎯 Q2 2025
|
||||
|
||||
- [X] DiskANN backend with MIPS/L2/Cosine support
|
||||
- [X] HNSW backend integration
|
||||
- [X] Real-time embedding pipeline
|
||||
- [X] Memory-efficient graph pruning
|
||||
|
||||
## 🚀 Q3 2025
|
||||
|
||||
- [ ] Advanced caching strategies
|
||||
- [ ] Add contextual retrieval (https://www.anthropic.com/news/contextual-retrieval)
|
||||
- [ ] Add sleep-time compute and a summarization agent to summarize files on your computer
|
||||
- [ ] Add OpenAI recompute API
|
||||
|
||||
## 🌟 Q4 2025
|
||||
|
||||
- [ ] Integration with LangChain/LlamaIndex
|
||||
- [ ] Visual similarity search
|
||||
- [ ] Query rewriting, reranking, and expansion
|
||||
@@ -135,7 +135,6 @@ def test_leann_hnsw():
|
||||
nodes = node_parser.get_nodes_from_documents([doc])
|
||||
for node in nodes:
|
||||
all_texts.append(node.get_content())
|
||||
print(f"Total number of chunks: {len(all_texts)}")
|
||||
|
||||
tracker.checkpoint("After text chunking")
|
||||
|
||||
|
||||
@@ -96,12 +96,14 @@ class EmlxReader(BaseReader):
|
||||
|
||||
# Create document content with metadata embedded in text
|
||||
doc_content = f"""
|
||||
[File]: {filename}
|
||||
[From]: {from_addr}
|
||||
[To]: {to_addr}
|
||||
[Subject]: {subject}
|
||||
[Date]: {date}
|
||||
[EMAIL BODY Start]:
|
||||
[EMAIL METADATA]
|
||||
File: {filename}
|
||||
From: {from_addr}
|
||||
To: {to_addr}
|
||||
Subject: {subject}
|
||||
Date: {date}
|
||||
[END METADATA]
|
||||
|
||||
{body}
|
||||
"""
|
||||
|
||||
|
||||
@@ -37,7 +37,7 @@ def main():
|
||||
import faiss
|
||||
except ImportError:
|
||||
print("Faiss is not installed.")
|
||||
print("Please install it with `uv pip install faiss-cpu` and you can then run this script again")
|
||||
print("Please install it with `uv pip install faiss-cpu`")
|
||||
sys.exit(1)
|
||||
|
||||
from llama_index.core import (
|
||||
|
||||
@@ -65,14 +65,12 @@ def create_leann_index_from_multiple_chrome_profiles(profile_dirs: List[Path], i
|
||||
|
||||
if not all_documents:
|
||||
print("No documents loaded from any source. Exiting.")
|
||||
# Highlight that all Chrome browser windows must be closed before running this script
|
||||
print("\033[91mYou need to close or quit all chrome browser before running this script\033[0m")
|
||||
return None
|
||||
|
||||
print(f"\nTotal loaded {len(all_documents)} history documents from {len(profile_dirs)} profiles")
|
||||
|
||||
# Create text splitter with 256 chunk size
|
||||
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=128)
|
||||
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=25)
|
||||
|
||||
# Convert Documents to text strings and chunk them
|
||||
all_texts = []
|
||||
@@ -80,9 +78,7 @@ def create_leann_index_from_multiple_chrome_profiles(profile_dirs: List[Path], i
|
||||
# Split the document into chunks
|
||||
nodes = text_splitter.get_nodes_from_documents([doc])
|
||||
for node in nodes:
|
||||
text = node.get_content()
|
||||
# text = '[Title] ' + doc.metadata["title"] + '\n' + text
|
||||
all_texts.append(text)
|
||||
all_texts.append(node.get_content())
|
||||
|
||||
print(f"Created {len(all_texts)} text chunks from {len(all_documents)} documents")
|
||||
|
||||
@@ -222,15 +218,14 @@ async def query_leann_index(index_path: str, query: str):
|
||||
"max_tokens": 1000
|
||||
}
|
||||
)
|
||||
|
||||
print(f"Leann chat response: \033[36m{chat_response}\033[0m")
|
||||
print(f"Leann: {chat_response}")
|
||||
|
||||
async def main():
|
||||
# Parse command line arguments
|
||||
parser = argparse.ArgumentParser(description='LEANN Chrome History Reader - Create and query browser history index')
|
||||
parser.add_argument('--chrome-profile', type=str, default=DEFAULT_CHROME_PROFILE,
|
||||
help=f'Path to Chrome profile directory (default: {DEFAULT_CHROME_PROFILE}); you usually do not need to change this')
|
||||
parser.add_argument('--index-dir', type=str, default="./google_history_index",
|
||||
parser.add_argument('--index-dir', type=str, default="./chrome_history_index_leann_test",
|
||||
help='Directory to store the LEANN index (default: ./chrome_history_index_leann_test)')
|
||||
parser.add_argument('--max-entries', type=int, default=1000,
|
||||
help='Maximum number of history entries to process (default: 1000)')
|
||||
|
||||
@@ -74,17 +74,22 @@ class ChromeHistoryReader(BaseReader):
|
||||
|
||||
# Create document content with metadata embedded in text
|
||||
doc_content = f"""
|
||||
[Title]: {title}
|
||||
[URL of the page]: {url}
|
||||
[Last visited time]: {last_visit}
|
||||
[Visit times]: {visit_count}
|
||||
[Typed times]: {typed_count}
|
||||
[BROWSING HISTORY METADATA]
|
||||
URL: {url}
|
||||
Title: {title}
|
||||
Last Visit: {last_visit}
|
||||
Visit Count: {visit_count}
|
||||
Typed Count: {typed_count}
|
||||
Hidden: {hidden}
|
||||
[END METADATA]
|
||||
|
||||
Title: {title}
|
||||
URL: {url}
|
||||
Last visited: {last_visit}
|
||||
"""
|
||||
|
||||
# Create document with embedded metadata
|
||||
doc = Document(text=doc_content, metadata={ "title": title[0:150]})
|
||||
# if len(title) > 150:
|
||||
# print(f"Title is too long: {title}")
|
||||
doc = Document(text=doc_content, metadata={})
|
||||
docs.append(doc)
|
||||
count += 1
|
||||
|
||||
|
||||
@@ -197,8 +197,8 @@ class WeChatHistoryReader(BaseReader):
|
||||
|
||||
Args:
|
||||
messages: List of message dictionaries
|
||||
max_length: Maximum length for concatenated message groups. Use -1 to disable length constraint.
|
||||
time_window_minutes: Time window in minutes to group messages together. Use -1 to disable time constraint.
|
||||
max_length: Maximum length for concatenated message groups
|
||||
time_window_minutes: Time window in minutes to group messages together
|
||||
overlap_messages: Number of messages to overlap between consecutive groups
|
||||
|
||||
Returns:
|
||||
@@ -230,8 +230,8 @@ class WeChatHistoryReader(BaseReader):
|
||||
if not readable_text.strip():
|
||||
continue
|
||||
|
||||
# Check time window constraint (only if time_window_minutes != -1)
|
||||
if time_window_minutes != -1 and last_timestamp is not None and create_time > 0:
|
||||
# Check time window constraint
|
||||
if last_timestamp is not None and create_time > 0:
|
||||
time_diff_minutes = (create_time - last_timestamp) / 60
|
||||
if time_diff_minutes > time_window_minutes:
|
||||
# Time gap too large, start new group
|
||||
@@ -250,9 +250,9 @@ class WeChatHistoryReader(BaseReader):
|
||||
current_group = []
|
||||
current_length = 0
|
||||
|
||||
# Check length constraint (only if max_length != -1)
|
||||
# Check length constraint
|
||||
message_length = len(readable_text)
|
||||
if max_length != -1 and current_length + message_length > max_length and current_group:
|
||||
if current_length + message_length > max_length and current_group:
|
||||
# Current group would exceed max length, save it and start new
|
||||
concatenated_groups.append({
|
||||
'messages': current_group,
|
||||
@@ -335,15 +335,14 @@ class WeChatHistoryReader(BaseReader):
|
||||
if create_time:
|
||||
try:
|
||||
timestamp = datetime.fromtimestamp(create_time)
|
||||
# change to YYYY-MM-DD HH:MM:SS
|
||||
time_str = timestamp.strftime('%Y-%m-%d %H:%M:%S')
|
||||
time_str = timestamp.strftime('%H:%M:%S')
|
||||
except:
|
||||
time_str = str(create_time)
|
||||
else:
|
||||
time_str = "Unknown"
|
||||
|
||||
sender = "[Me]" if is_sent_from_self else "[Contact]"
|
||||
message_parts.append(f"({time_str}) {sender}: {readable_text}")
|
||||
sender = "Me" if is_sent_from_self else "Contact"
|
||||
message_parts.append(f"[{time_str}] {sender}: {readable_text}")
|
||||
|
||||
concatenated_text = "\n".join(message_parts)
|
||||
|
||||
@@ -355,11 +354,13 @@ Messages ({len(messages)} messages, {message_group['total_length']} chars):
|
||||
|
||||
{concatenated_text}
|
||||
"""
|
||||
# TODO @yichuan give better format and rich info here!
|
||||
|
||||
doc_content = f"""
|
||||
Contact: {contact_name}
|
||||
|
||||
{concatenated_text}
|
||||
"""
|
||||
return doc_content, contact_name
|
||||
return doc_content
|
||||
|
||||
def load_data(self, input_dir: str = None, **load_kwargs: Any) -> List[Document]:
|
||||
"""
|
||||
@@ -430,9 +431,9 @@ Messages ({len(messages)} messages, {message_group['total_length']} chars):
|
||||
# Concatenate messages based on rules
|
||||
message_groups = self._concatenate_messages(
|
||||
readable_messages,
|
||||
max_length=-1,
|
||||
time_window_minutes=-1,
|
||||
overlap_messages=0 # Keep 2 messages overlap between groups
|
||||
max_length=max_length,
|
||||
time_window_minutes=time_window_minutes,
|
||||
overlap_messages=2 # Keep 2 messages overlap between groups
|
||||
)
|
||||
|
||||
# Create documents from concatenated groups
|
||||
@@ -440,8 +441,8 @@ Messages ({len(messages)} messages, {message_group['total_length']} chars):
|
||||
if count >= max_count and max_count > 0:
|
||||
break
|
||||
|
||||
doc_content, contact_name = self._create_concatenated_content(message_group, contact_name)
|
||||
doc = Document(text=doc_content, metadata={"contact_name": contact_name})
|
||||
doc_content = self._create_concatenated_content(message_group, contact_name)
|
||||
doc = Document(text=doc_content, metadata={})
|
||||
docs.append(doc)
|
||||
count += 1
|
||||
|
||||
|
||||
@@ -22,7 +22,7 @@ def get_mail_path():
|
||||
return os.path.join(home_dir, "Library", "Mail")
|
||||
|
||||
# Default mail path for macOS
|
||||
DEFAULT_MAIL_PATH = "/Users/yichuan/Library/Mail/V10/0FCA0879-FD8C-4B7E-83BF-FDDA930791C5/[Gmail].mbox/All Mail.mbox/78BA5BE1-8819-4F9A-9613-EB63772F1DD0/Data"
|
||||
# DEFAULT_MAIL_PATH = "/Users/yichuan/Library/Mail/V10/0FCA0879-FD8C-4B7E-83BF-FDDA930791C5/[Gmail].mbox/All Mail.mbox/78BA5BE1-8819-4F9A-9613-EB63772F1DD0/Data"
|
||||
|
||||
def create_leann_index_from_multiple_sources(messages_dirs: List[Path], index_path: str = "mail_index.leann", max_count: int = -1, include_html: bool = False, embedding_model: str = "facebook/contriever"):
|
||||
"""
|
||||
@@ -74,7 +74,7 @@ def create_leann_index_from_multiple_sources(messages_dirs: List[Path], index_pa
|
||||
print("No documents loaded from any source. Exiting.")
|
||||
return None
|
||||
|
||||
print(f"\nTotal loaded {len(all_documents)} email documents from {len(messages_dirs)} directories and starting to split them into chunks")
|
||||
print(f"\nTotal loaded {len(all_documents)} email documents from {len(messages_dirs)} directories")
|
||||
|
||||
# Create text splitter with 256 chunk size
|
||||
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=25)
|
||||
@@ -85,11 +85,9 @@ def create_leann_index_from_multiple_sources(messages_dirs: List[Path], index_pa
|
||||
# Split the document into chunks
|
||||
nodes = text_splitter.get_nodes_from_documents([doc])
|
||||
for node in nodes:
|
||||
text = node.get_content()
|
||||
# text = '[subject] ' + doc.metadata["subject"] + '\n' + text
|
||||
all_texts.append(text)
|
||||
all_texts.append(node.get_content())
|
||||
|
||||
print(f"Finished splitting {len(all_documents)} documents into {len(all_texts)} text chunks")
|
||||
print(f"Created {len(all_texts)} text chunks from {len(all_documents)} documents")
|
||||
|
||||
# Create LEANN index directory
|
||||
|
||||
@@ -158,7 +156,7 @@ def create_leann_index(mail_path: str, index_path: str = "mail_index.leann", max
|
||||
print(f"Loaded {len(documents)} email documents")
|
||||
|
||||
# Create text splitter with 256 chunk size
|
||||
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=128)
|
||||
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=25)
|
||||
|
||||
# Convert Documents to text strings and chunk them
|
||||
all_texts = []
|
||||
@@ -218,22 +216,22 @@ async def query_leann_index(index_path: str, query: str):
|
||||
start_time = time.time()
|
||||
chat_response = chat.ask(
|
||||
query,
|
||||
top_k=20,
|
||||
top_k=10,
|
||||
recompute_beighbor_embeddings=True,
|
||||
complexity=32,
|
||||
complexity=12,
|
||||
beam_width=1,
|
||||
|
||||
)
|
||||
end_time = time.time()
|
||||
# print(f"Time taken: {end_time - start_time} seconds")
|
||||
# highlight the answer
|
||||
print(f"Leann chat response: \033[36m{chat_response}\033[0m")
|
||||
print(f"Time taken: {end_time - start_time} seconds")
|
||||
print(f"Leann: {chat_response}")
|
||||
|
||||
async def main():
|
||||
# Parse command line arguments
|
||||
parser = argparse.ArgumentParser(description='LEANN Mail Reader - Create and query email index')
|
||||
# Remove --mail-path argument and auto-detect all Messages directories
|
||||
# Remove DEFAULT_MAIL_PATH
|
||||
parser.add_argument('--index-dir', type=str, default="./mail_index",
|
||||
parser.add_argument('--index-dir', type=str, default="./mail_index_leann_raw_text_all_dicts",
|
||||
help='Directory to store the LEANN index (default: ./mail_index_leann_raw_text_all_dicts)')
|
||||
parser.add_argument('--max-emails', type=int, default=1000,
|
||||
help='Maximum number of emails to process (-1 means all)')
|
||||
@@ -253,9 +251,6 @@ async def main():
|
||||
mail_path = get_mail_path()
|
||||
print(f"Searching for email data in: {mail_path}")
|
||||
messages_dirs = find_all_messages_directories(mail_path)
|
||||
# messages_dirs = find_all_messages_directories(DEFAULT_MAIL_PATH)
|
||||
# messages_dirs = [DEFAULT_MAIL_PATH]
|
||||
# messages_dirs = messages_dirs[:1]
|
||||
|
||||
print('len(messages_dirs): ', len(messages_dirs))
|
||||
|
||||
|
||||
@@ -1,40 +1,40 @@
|
||||
import argparse
|
||||
from llama_index.core import SimpleDirectoryReader
|
||||
from llama_index.core import SimpleDirectoryReader, Settings
|
||||
from llama_index.core.node_parser import SentenceSplitter
|
||||
import asyncio
|
||||
import dotenv
|
||||
from leann.api import LeannBuilder, LeannChat
|
||||
from leann.api import LeannBuilder, LeannSearcher, LeannChat
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
dotenv.load_dotenv()
|
||||
|
||||
node_parser = SentenceSplitter(
|
||||
chunk_size=256, chunk_overlap=128, separator=" ", paragraph_separator="\n\n"
|
||||
)
|
||||
print("Loading documents...")
|
||||
documents = SimpleDirectoryReader(
|
||||
"examples/data",
|
||||
recursive=True,
|
||||
encoding="utf-8",
|
||||
required_exts=[".pdf", ".txt", ".md"],
|
||||
).load_data(show_progress=True)
|
||||
print("Documents loaded.")
|
||||
all_texts = []
|
||||
for doc in documents:
|
||||
nodes = node_parser.get_nodes_from_documents([doc])
|
||||
for node in nodes:
|
||||
all_texts.append(node.get_content())
|
||||
|
||||
|
||||
async def main(args):
|
||||
INDEX_DIR = Path(args.index_dir)
|
||||
INDEX_PATH = str(INDEX_DIR / "pdf_documents.leann")
|
||||
|
||||
if not INDEX_DIR.exists():
|
||||
node_parser = SentenceSplitter(
|
||||
chunk_size=256, chunk_overlap=128, separator=" ", paragraph_separator="\n\n"
|
||||
)
|
||||
print(f"--- Index directory not found, building new index ---")
|
||||
|
||||
print("Loading documents...")
|
||||
documents = SimpleDirectoryReader(
|
||||
args.data_dir,
|
||||
recursive=True,
|
||||
encoding="utf-8",
|
||||
required_exts=[".pdf", ".txt", ".md"],
|
||||
).load_data(show_progress=True)
|
||||
print("Documents loaded.")
|
||||
all_texts = []
|
||||
for doc in documents:
|
||||
nodes = node_parser.get_nodes_from_documents([doc])
|
||||
for node in nodes:
|
||||
all_texts.append(node.get_content())
|
||||
|
||||
print("--- Index directory not found, building new index ---")
|
||||
|
||||
print("\n[PHASE 1] Building Leann index...")
|
||||
print(f"\n[PHASE 1] Building Leann index...")
|
||||
|
||||
# Use HNSW backend for better macOS compatibility
|
||||
builder = LeannBuilder(
|
||||
@@ -58,19 +58,22 @@ async def main(args):
|
||||
|
||||
print(f"\n[PHASE 2] Starting Leann chat session...")
|
||||
|
||||
llm_config = {"type": "hf", "model": "Qwen/Qwen3-4B"}
|
||||
# llm_config = {"type": "hf", "model": "Qwen/Qwen3-4B"}
|
||||
llm_config = {"type": "ollama", "model": "qwen3:8b"}
|
||||
llm_config = {"type": "openai", "model": "gpt-4o"}
|
||||
|
||||
chat = LeannChat(index_path=INDEX_PATH, llm_config=llm_config)
|
||||
|
||||
query = "Based on the paper, what are the main techniques LEANN explores to reduce the storage overhead and DLPM explore to achieve Fairness and Efiiciency trade-off?"
|
||||
|
||||
# query = (
|
||||
# "什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面,任务令一般在什么城市颁发"
|
||||
# )
|
||||
query = args.query
|
||||
|
||||
print(f"You: {query}")
|
||||
chat_response = chat.ask(query, top_k=20, recompute_embeddings=True, complexity=32)
|
||||
print(f"Leann chat response: \033[36m{chat_response}\033[0m")
|
||||
chat_response = chat.ask(
|
||||
query, top_k=20, recompute_beighbor_embeddings=True, complexity=32
|
||||
)
|
||||
print(f"Leann: {chat_response}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@@ -102,18 +105,6 @@ if __name__ == "__main__":
|
||||
default="./test_doc_files",
|
||||
help="Directory where the Leann index will be stored.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--data-dir",
|
||||
type=str,
|
||||
default="examples/data",
|
||||
help="Directory containing documents to index (PDF, TXT, MD files).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--query",
|
||||
type=str,
|
||||
default="Based on the paper, what are the main techniques LEANN explores to reduce the storage overhead and DLPM explore to achieve Fairness and Efiiciency trade-off?",
|
||||
help="The query to ask the Leann chat system.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
asyncio.run(main(args))
|
||||
|
||||
@@ -74,11 +74,11 @@ def create_leann_index_from_multiple_wechat_exports(
|
||||
return None
|
||||
|
||||
print(
|
||||
f"\nTotal loaded {len(all_documents)} chat documents from {len(export_dirs)} exports and starting to split them into chunks"
|
||||
f"\nTotal loaded {len(all_documents)} chat documents from {len(export_dirs)} exports"
|
||||
)
|
||||
|
||||
# Create text splitter with 256 chunk size
|
||||
text_splitter = SentenceSplitter(chunk_size=192, chunk_overlap=64)
|
||||
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=25)
|
||||
|
||||
# Convert Documents to text strings and chunk them
|
||||
all_texts = []
|
||||
@@ -86,11 +86,10 @@ def create_leann_index_from_multiple_wechat_exports(
|
||||
# Split the document into chunks
|
||||
nodes = text_splitter.get_nodes_from_documents([doc])
|
||||
for node in nodes:
|
||||
text = '[Contact] means the message is from: ' + doc.metadata["contact_name"] + '\n' + node.get_content()
|
||||
all_texts.append(text)
|
||||
all_texts.append(node.get_content())
|
||||
|
||||
print(
|
||||
f"Finished splitting {len(all_documents)} documents into {len(all_texts)} text chunks"
|
||||
f"Created {len(all_texts)} text chunks from {len(all_documents)} documents"
|
||||
)
|
||||
|
||||
# Create LEANN index directory
|
||||
@@ -225,7 +224,7 @@ async def query_leann_index(index_path: str, query: str):
|
||||
query,
|
||||
top_k=20,
|
||||
recompute_beighbor_embeddings=True,
|
||||
complexity=16,
|
||||
complexity=64,
|
||||
beam_width=1,
|
||||
llm_config={
|
||||
"type": "openai",
|
||||
@@ -234,7 +233,7 @@ async def query_leann_index(index_path: str, query: str):
|
||||
},
|
||||
llm_kwargs={"temperature": 0.0, "max_tokens": 1000},
|
||||
)
|
||||
print(f"Leann chat response: \033[36m{chat_response}\033[0m")
|
||||
print(f"Leann: {chat_response}")
|
||||
|
||||
|
||||
async def main():
|
||||
@@ -253,13 +252,13 @@ async def main():
|
||||
parser.add_argument(
|
||||
"--index-dir",
|
||||
type=str,
|
||||
default="./wechat_history_magic_test_11Debug_new",
|
||||
default="./wechat_history_june19_test",
|
||||
help="Directory to store the LEANN index (default: ./wechat_history_index_leann_test)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-entries",
|
||||
type=int,
|
||||
default=50,
|
||||
default=5000,
|
||||
help="Maximum number of chat entries to process (default: 5000)",
|
||||
)
|
||||
parser.add_argument(
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
# packages/leann-backend-diskann/CMakeLists.txt (simplified version)
|
||||
# packages/leann-backend-diskann/CMakeLists.txt (final simplified version)
|
||||
|
||||
cmake_minimum_required(VERSION 3.20)
|
||||
project(leann_backend_diskann_wrapper)
|
||||
|
||||
# Tell CMake to directly enter the DiskANN submodule and execute its own CMakeLists.txt
|
||||
# DiskANN will handle everything itself, including compiling Python bindings
|
||||
# Tell CMake to enter the DiskANN submodule directly and execute its own CMakeLists.txt
# DiskANN handles everything itself, including compiling the Python bindings
|
||||
add_subdirectory(src/third_party/DiskANN)
|
||||
|
||||
@@ -1,12 +1,10 @@
|
||||
import numpy as np
|
||||
import os
|
||||
import struct
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, List, Literal, Optional
|
||||
from typing import Dict, Any, List, Literal
|
||||
import contextlib
|
||||
|
||||
import logging
|
||||
import pickle
|
||||
|
||||
from leann.searcher_base import BaseSearcher
|
||||
from leann.registry import register_backend
|
||||
@@ -16,46 +14,6 @@ from leann.interface import (
|
||||
LeannBackendSearcherInterface,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def suppress_cpp_output_if_needed():
|
||||
"""Suppress C++ stdout/stderr based on LEANN_LOG_LEVEL"""
|
||||
log_level = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
|
||||
|
||||
# Only suppress if log level is WARNING or higher (ERROR, CRITICAL)
|
||||
should_suppress = log_level in ["WARNING", "ERROR", "CRITICAL"]
|
||||
|
||||
if not should_suppress:
|
||||
# Don't suppress, just yield
|
||||
yield
|
||||
return
|
||||
|
||||
# Save original file descriptors
|
||||
stdout_fd = sys.stdout.fileno()
|
||||
stderr_fd = sys.stderr.fileno()
|
||||
|
||||
# Save original stdout/stderr
|
||||
stdout_dup = os.dup(stdout_fd)
|
||||
stderr_dup = os.dup(stderr_fd)
|
||||
|
||||
try:
|
||||
# Redirect to /dev/null
|
||||
devnull = os.open(os.devnull, os.O_WRONLY)
|
||||
os.dup2(devnull, stdout_fd)
|
||||
os.dup2(devnull, stderr_fd)
|
||||
os.close(devnull)
|
||||
|
||||
yield
|
||||
|
||||
finally:
|
||||
# Restore original file descriptors
|
||||
os.dup2(stdout_dup, stdout_fd)
|
||||
os.dup2(stderr_dup, stderr_fd)
|
||||
os.close(stdout_dup)
|
||||
os.close(stderr_dup)
|
||||
|
||||
|
||||
def _get_diskann_metrics():
|
||||
from . import _diskannpy as diskannpy # type: ignore
|
||||
@@ -107,20 +65,22 @@ class DiskannBuilder(LeannBackendBuilderInterface):
|
||||
index_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if data.dtype != np.float32:
|
||||
logger.warning(f"Converting data to float32, shape: {data.shape}")
|
||||
data = data.astype(np.float32)
|
||||
|
||||
data_filename = f"{index_prefix}_data.bin"
|
||||
_write_vectors_to_bin(data, index_dir / data_filename)
|
||||
|
||||
label_map = {i: str_id for i, str_id in enumerate(ids)}
|
||||
label_map_file = index_dir / "leann.labels.map"
|
||||
with open(label_map_file, "wb") as f:
|
||||
pickle.dump(label_map, f)
|
||||
|
||||
build_kwargs = {**self.build_params, **kwargs}
|
||||
metric_enum = _get_diskann_metrics().get(
|
||||
build_kwargs.get("distance_metric", "mips").lower()
|
||||
)
|
||||
if metric_enum is None:
|
||||
raise ValueError(
|
||||
f"Unsupported distance_metric '{build_kwargs.get('distance_metric', 'unknown')}'."
|
||||
)
|
||||
raise ValueError("Unsupported distance_metric.")
|
||||
|
||||
try:
|
||||
from . import _diskannpy as diskannpy # type: ignore
|
||||
@@ -142,40 +102,36 @@ class DiskannBuilder(LeannBackendBuilderInterface):
|
||||
temp_data_file = index_dir / data_filename
|
||||
if temp_data_file.exists():
|
||||
os.remove(temp_data_file)
|
||||
logger.debug(f"Cleaned up temporary data file: {temp_data_file}")
|
||||
|
||||
|
||||
class DiskannSearcher(BaseSearcher):
|
||||
def __init__(self, index_path: str, **kwargs):
|
||||
super().__init__(
|
||||
index_path,
|
||||
backend_module_name="leann_backend_diskann.diskann_embedding_server",
|
||||
backend_module_name="leann_backend_diskann.embedding_server",
|
||||
**kwargs,
|
||||
)
|
||||
from . import _diskannpy as diskannpy # type: ignore
|
||||
|
||||
# Initialize DiskANN index with suppressed C++ output based on log level
|
||||
with suppress_cpp_output_if_needed():
|
||||
from . import _diskannpy as diskannpy # type: ignore
|
||||
distance_metric = kwargs.get("distance_metric", "mips").lower()
|
||||
metric_enum = _get_diskann_metrics().get(distance_metric)
|
||||
if metric_enum is None:
|
||||
raise ValueError(f"Unsupported distance_metric '{distance_metric}'.")
|
||||
|
||||
distance_metric = kwargs.get("distance_metric", "mips").lower()
|
||||
metric_enum = _get_diskann_metrics().get(distance_metric)
|
||||
if metric_enum is None:
|
||||
raise ValueError(f"Unsupported distance_metric '{distance_metric}'.")
|
||||
self.num_threads = kwargs.get("num_threads", 8)
|
||||
self.zmq_port = kwargs.get("zmq_port", 6666)
|
||||
|
||||
self.num_threads = kwargs.get("num_threads", 8)
|
||||
|
||||
fake_zmq_port = 6666
|
||||
full_index_prefix = str(self.index_dir / self.index_path.stem)
|
||||
self._index = diskannpy.StaticDiskFloatIndex(
|
||||
metric_enum,
|
||||
full_index_prefix,
|
||||
self.num_threads,
|
||||
kwargs.get("num_nodes_to_cache", 0),
|
||||
1,
|
||||
fake_zmq_port, # Initial port, can be updated at runtime
|
||||
"",
|
||||
"",
|
||||
)
|
||||
full_index_prefix = str(self.index_dir / self.index_path.stem)
|
||||
self._index = diskannpy.StaticDiskFloatIndex(
|
||||
metric_enum,
|
||||
full_index_prefix,
|
||||
self.num_threads,
|
||||
kwargs.get("num_nodes_to_cache", 0),
|
||||
1,
|
||||
self.zmq_port,
|
||||
"",
|
||||
"",
|
||||
)
|
||||
|
||||
def search(
|
||||
self,
|
||||
@@ -186,7 +142,7 @@ class DiskannSearcher(BaseSearcher):
|
||||
prune_ratio: float = 0.0,
|
||||
recompute_embeddings: bool = False,
|
||||
pruning_strategy: Literal["global", "local", "proportional"] = "global",
|
||||
zmq_port: Optional[int] = None,
|
||||
zmq_port: int = 5557,
|
||||
batch_recompute: bool = False,
|
||||
dedup_node_dis: bool = False,
|
||||
**kwargs,
|
||||
@@ -205,7 +161,7 @@ class DiskannSearcher(BaseSearcher):
|
||||
- "global": Use global pruning strategy (default)
|
||||
- "local": Use local pruning strategy
|
||||
- "proportional": Not supported in DiskANN, falls back to global
|
||||
zmq_port: ZMQ port for embedding server communication. Must be provided if recompute_embeddings is True.
|
||||
zmq_port: ZMQ port for embedding server
|
||||
batch_recompute: Whether to batch neighbor recomputation (DiskANN-specific)
|
||||
dedup_node_dis: Whether to cache and reuse distance computations (DiskANN-specific)
|
||||
**kwargs: Additional DiskANN-specific parameters (for legacy compatibility)
|
||||
@@ -213,25 +169,22 @@ class DiskannSearcher(BaseSearcher):
|
||||
Returns:
|
||||
Dict with 'labels' (list of lists) and 'distances' (ndarray)
|
||||
"""
|
||||
# Handle zmq_port compatibility: DiskANN can now update port at runtime
|
||||
if recompute_embeddings:
|
||||
if zmq_port is None:
|
||||
raise ValueError(
|
||||
"zmq_port must be provided if recompute_embeddings is True"
|
||||
)
|
||||
current_port = self._index.get_zmq_port()
|
||||
if zmq_port != current_port:
|
||||
logger.debug(
|
||||
f"Updating DiskANN zmq_port from {current_port} to {zmq_port}"
|
||||
)
|
||||
self._index.set_zmq_port(zmq_port)
|
||||
|
||||
# DiskANN doesn't support "proportional" strategy
|
||||
if pruning_strategy == "proportional":
|
||||
raise NotImplementedError(
|
||||
"DiskANN backend does not support 'proportional' pruning strategy. Use 'global' or 'local' instead."
|
||||
)
|
||||
|
||||
# Use recompute_embeddings parameter
|
||||
use_recompute = recompute_embeddings
|
||||
if use_recompute:
|
||||
meta_file_path = self.index_dir / f"{self.index_path.name}.meta.json"
|
||||
if not meta_file_path.exists():
|
||||
raise RuntimeError(
|
||||
f"FATAL: Recompute enabled but metadata file not found: {meta_file_path}"
|
||||
)
|
||||
self._ensure_server_running(str(meta_file_path), port=zmq_port, **kwargs)
|
||||
|
||||
if query.dtype != np.float32:
|
||||
query = query.astype(np.float32)
|
||||
|
||||
@@ -241,26 +194,28 @@ class DiskannSearcher(BaseSearcher):
|
||||
else: # "global"
|
||||
use_global_pruning = True
|
||||
|
||||
# Perform search with suppressed C++ output based on log level
|
||||
with suppress_cpp_output_if_needed():
|
||||
labels, distances = self._index.batch_search(
|
||||
query,
|
||||
query.shape[0],
|
||||
top_k,
|
||||
complexity,
|
||||
beam_width,
|
||||
self.num_threads,
|
||||
kwargs.get("USE_DEFERRED_FETCH", False),
|
||||
kwargs.get("skip_search_reorder", False),
|
||||
recompute_embeddings,
|
||||
dedup_node_dis,
|
||||
prune_ratio,
|
||||
batch_recompute,
|
||||
use_global_pruning,
|
||||
)
|
||||
labels, distances = self._index.batch_search(
|
||||
query,
|
||||
query.shape[0],
|
||||
top_k,
|
||||
complexity,
|
||||
beam_width,
|
||||
self.num_threads,
|
||||
kwargs.get("USE_DEFERRED_FETCH", False),
|
||||
kwargs.get("skip_search_reorder", False),
|
||||
use_recompute,
|
||||
dedup_node_dis,
|
||||
prune_ratio,
|
||||
batch_recompute,
|
||||
use_global_pruning,
|
||||
)
|
||||
|
||||
string_labels = [
|
||||
[str(int_label) for int_label in batch_labels] for batch_labels in labels
|
||||
[
|
||||
self.label_map.get(int_label, f"unknown_{int_label}")
|
||||
for int_label in batch_labels
|
||||
]
|
||||
for batch_labels in labels
|
||||
]
|
||||
|
||||
return {"labels": string_labels, "distances": distances}
|
||||
|
||||
@@ -1,283 +0,0 @@
|
||||
"""
|
||||
DiskANN-specific embedding server
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import threading
|
||||
import time
|
||||
import os
|
||||
import zmq
|
||||
import numpy as np
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
import sys
|
||||
import logging
|
||||
|
||||
# Set up logging based on environment variable
|
||||
LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Force set logger level (don't rely on basicConfig in subprocess)
|
||||
log_level = getattr(logging, LOG_LEVEL, logging.WARNING)
|
||||
logger.setLevel(log_level)
|
||||
|
||||
# Ensure we have a handler if none exists
|
||||
if not logger.handlers:
|
||||
handler = logging.StreamHandler()
|
||||
formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
|
||||
handler.setFormatter(formatter)
|
||||
logger.addHandler(handler)
|
||||
logger.propagate = False
|
||||
|
||||
|
||||
def create_diskann_embedding_server(
|
||||
passages_file: Optional[str] = None,
|
||||
zmq_port: int = 5555,
|
||||
model_name: str = "sentence-transformers/all-mpnet-base-v2",
|
||||
embedding_mode: str = "sentence-transformers",
|
||||
):
|
||||
"""
|
||||
Create and start a ZMQ-based embedding server for DiskANN backend.
|
||||
Uses ROUTER socket and protobuf communication as required by DiskANN C++ implementation.
|
||||
"""
|
||||
logger.info(f"Starting DiskANN server on port {zmq_port} with model {model_name}")
|
||||
logger.info(f"Using embedding mode: {embedding_mode}")
|
||||
|
||||
# Add leann-core to path for unified embedding computation
|
||||
current_dir = Path(__file__).parent
|
||||
leann_core_path = current_dir.parent.parent / "leann-core" / "src"
|
||||
sys.path.insert(0, str(leann_core_path))
|
||||
|
||||
try:
|
||||
from leann.embedding_compute import compute_embeddings
|
||||
from leann.api import PassageManager
|
||||
|
||||
logger.info("Successfully imported unified embedding computation module")
|
||||
except ImportError as e:
|
||||
logger.error(f"Failed to import embedding computation module: {e}")
|
||||
return
|
||||
finally:
|
||||
sys.path.pop(0)
|
||||
|
||||
# Check port availability
|
||||
import socket
|
||||
|
||||
def check_port(port):
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
return s.connect_ex(("localhost", port)) == 0
|
||||
|
||||
if check_port(zmq_port):
|
||||
logger.error(f"Port {zmq_port} is already in use")
|
||||
return
|
||||
|
||||
# Only support metadata file, fail fast for everything else
|
||||
if not passages_file or not passages_file.endswith(".meta.json"):
|
||||
raise ValueError("Only metadata files (.meta.json) are supported")
|
||||
|
||||
# Load metadata to get passage sources
|
||||
with open(passages_file, "r") as f:
|
||||
meta = json.load(f)
|
||||
|
||||
passages = PassageManager(meta["passage_sources"])
|
||||
logger.info(
|
||||
f"Loaded PassageManager with {len(passages.global_offset_map)} passages from metadata"
|
||||
)
|
||||
|
||||
# Import protobuf after ensuring the path is correct
|
||||
try:
|
||||
from . import embedding_pb2
|
||||
except ImportError as e:
|
||||
logger.error(f"Failed to import protobuf module: {e}")
|
||||
return
|
||||
|
||||
def zmq_server_thread():
|
||||
"""ZMQ server thread using REP socket for universal compatibility"""
|
||||
context = zmq.Context()
|
||||
socket = context.socket(
|
||||
zmq.REP
|
||||
) # REP socket for both BaseSearcher and DiskANN C++ REQ clients
|
||||
socket.bind(f"tcp://*:{zmq_port}")
|
||||
logger.info(f"DiskANN ZMQ REP server listening on port {zmq_port}")
|
||||
|
||||
socket.setsockopt(zmq.RCVTIMEO, 300000)
|
||||
socket.setsockopt(zmq.SNDTIMEO, 300000)
|
||||
|
||||
while True:
|
||||
try:
|
||||
# REP socket receives single-part messages
|
||||
message = socket.recv()
|
||||
|
||||
# Check for empty messages - REP socket requires response to every request
|
||||
if len(message) == 0:
|
||||
logger.debug("Received empty message, sending empty response")
|
||||
socket.send(b"") # REP socket must respond to every request
|
||||
continue
|
||||
|
||||
logger.debug(f"Received ZMQ request of size {len(message)} bytes")
|
||||
logger.debug(f"Message preview: {message[:50]}") # Show first 50 bytes
|
||||
|
||||
e2e_start = time.time()
|
||||
|
||||
# Try protobuf first (for DiskANN C++ node_ids requests - primary use case)
|
||||
texts = []
|
||||
node_ids = []
|
||||
is_text_request = False
|
||||
|
||||
try:
|
||||
req_proto = embedding_pb2.NodeEmbeddingRequest()
|
||||
req_proto.ParseFromString(message)
|
||||
node_ids = list(req_proto.node_ids)
|
||||
|
||||
if not node_ids:
|
||||
raise RuntimeError(
|
||||
f"PROTOBUF: Received empty node_ids! Message size: {len(message)}"
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"✅ PROTOBUF: Node ID request for {len(node_ids)} node embeddings: {node_ids[:10]}"
|
||||
)
|
||||
except Exception as protobuf_error:
|
||||
logger.debug(f"Protobuf parsing failed: {protobuf_error}")
|
||||
# Fallback to msgpack (for BaseSearcher direct text requests)
|
||||
try:
|
||||
import msgpack
|
||||
|
||||
request = msgpack.unpackb(message)
|
||||
# For BaseSearcher compatibility, request is a list of texts directly
|
||||
if isinstance(request, list) and all(
|
||||
isinstance(item, str) for item in request
|
||||
):
|
||||
texts = request
|
||||
is_text_request = True
|
||||
logger.info(
|
||||
f"✅ MSGPACK: Direct text request for {len(texts)} texts"
|
||||
)
|
||||
else:
|
||||
raise ValueError("Not a valid msgpack text request")
|
||||
except Exception as msgpack_error:
|
||||
raise RuntimeError(
|
||||
f"Both protobuf and msgpack parsing failed! Protobuf: {protobuf_error}, Msgpack: {msgpack_error}"
|
||||
)
|
||||
|
||||
# Look up texts by node IDs (only if not direct text request)
|
||||
if not is_text_request:
|
||||
for nid in node_ids:
|
||||
try:
|
||||
passage_data = passages.get_passage(str(nid))
|
||||
txt = passage_data["text"]
|
||||
if not txt:
|
||||
raise RuntimeError(
|
||||
f"FATAL: Empty text for passage ID {nid}"
|
||||
)
|
||||
texts.append(txt)
|
||||
except KeyError as e:
|
||||
logger.error(f"Passage ID {nid} not found: {e}")
|
||||
raise e
|
||||
except Exception as e:
|
||||
logger.error(f"Exception looking up passage ID {nid}: {e}")
|
||||
raise
|
||||
|
||||
# Debug logging
|
||||
logger.debug(f"Processing {len(texts)} texts")
|
||||
logger.debug(
|
||||
f"Text lengths: {[len(t) for t in texts[:5]]}"
|
||||
) # Show first 5
|
||||
|
||||
# Process embeddings using unified computation
|
||||
embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
|
||||
logger.info(
|
||||
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
|
||||
)
|
||||
|
||||
# Prepare response based on request type
|
||||
if is_text_request:
|
||||
# For BaseSearcher compatibility: return msgpack format
|
||||
import msgpack
|
||||
|
||||
response_data = msgpack.packb(embeddings.tolist())
|
||||
else:
|
||||
# For DiskANN C++ compatibility: return protobuf format
|
||||
resp_proto = embedding_pb2.NodeEmbeddingResponse()
|
||||
hidden_contiguous = np.ascontiguousarray(
|
||||
embeddings, dtype=np.float32
|
||||
)
|
||||
|
||||
# Serialize embeddings data
|
||||
resp_proto.embeddings_data = hidden_contiguous.tobytes()
|
||||
resp_proto.dimensions.append(hidden_contiguous.shape[0])
|
||||
resp_proto.dimensions.append(hidden_contiguous.shape[1])
|
||||
|
||||
response_data = resp_proto.SerializeToString()
|
||||
|
||||
# Send response back to the client
|
||||
socket.send(response_data)
|
||||
|
||||
e2e_end = time.time()
|
||||
logger.info(f"⏱️ ZMQ E2E time: {e2e_end - e2e_start:.6f}s")
|
||||
|
||||
except zmq.Again:
|
||||
logger.debug("ZMQ socket timeout, continuing to listen")
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.error(f"Error in ZMQ server loop: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
raise
|
||||
|
||||
zmq_thread = threading.Thread(target=zmq_server_thread, daemon=True)
|
||||
zmq_thread.start()
|
||||
logger.info(f"Started DiskANN ZMQ server thread on port {zmq_port}")
|
||||
|
||||
# Keep the main thread alive
|
||||
try:
|
||||
while True:
|
||||
time.sleep(1)
|
||||
except KeyboardInterrupt:
|
||||
logger.info("DiskANN Server shutting down...")
|
||||
return
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import signal
|
||||
import sys
|
||||
|
||||
def signal_handler(sig, frame):
|
||||
logger.info(f"Received signal {sig}, shutting down gracefully...")
|
||||
sys.exit(0)
|
||||
|
||||
# Register signal handlers for graceful shutdown
|
||||
signal.signal(signal.SIGTERM, signal_handler)
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
|
||||
parser = argparse.ArgumentParser(description="DiskANN Embedding service")
|
||||
parser.add_argument("--zmq-port", type=int, default=5555, help="ZMQ port to run on")
|
||||
parser.add_argument(
|
||||
"--passages-file",
|
||||
type=str,
|
||||
help="Metadata JSON file containing passage sources",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model-name",
|
||||
type=str,
|
||||
default="sentence-transformers/all-mpnet-base-v2",
|
||||
help="Embedding model name",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--embedding-mode",
|
||||
type=str,
|
||||
default="sentence-transformers",
|
||||
choices=["sentence-transformers", "openai", "mlx"],
|
||||
help="Embedding backend mode",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Create and start the DiskANN embedding server
|
||||
create_diskann_embedding_server(
|
||||
passages_file=args.passages_file,
|
||||
zmq_port=args.zmq_port,
|
||||
model_name=args.model_name,
|
||||
embedding_mode=args.embedding_mode,
|
||||
)
|
||||
@@ -0,0 +1,741 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Embedding server for leann-backend-diskann - Fixed ZMQ REQ-REP pattern
|
||||
"""
|
||||
|
||||
import pickle
|
||||
import argparse
|
||||
import time
|
||||
import json
|
||||
from typing import Dict, Any, Optional, Union
|
||||
|
||||
from transformers import AutoTokenizer, AutoModel
|
||||
import os
|
||||
from contextlib import contextmanager
|
||||
import zmq
|
||||
import numpy as np
|
||||
import msgpack
|
||||
from pathlib import Path
|
||||
import logging
|
||||
|
||||
RED = "\033[91m"
|
||||
|
||||
# Set up logging based on environment variable
|
||||
LOG_LEVEL = os.getenv('LEANN_LOG_LEVEL', 'INFO').upper()
|
||||
logging.basicConfig(
|
||||
level=getattr(logging, LOG_LEVEL, logging.INFO),
|
||||
format='%(asctime)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
RESET = "\033[0m"
|
||||
|
||||
# --- New Passage Loader from HNSW backend ---
|
||||
class SimplePassageLoader:
|
||||
"""
|
||||
Simple passage loader that replaces config.py dependencies
|
||||
"""
|
||||
def __init__(self, passages_data: Optional[Dict[str, Any]] = None):
|
||||
self.passages_data = passages_data or {}
|
||||
self._meta_path = ''
|
||||
|
||||
def __getitem__(self, passage_id: Union[str, int]) -> Dict[str, str]:
|
||||
"""Get passage by ID"""
|
||||
str_id = str(passage_id)
|
||||
if str_id in self.passages_data:
|
||||
return {"text": self.passages_data[str_id]}
|
||||
else:
|
||||
# Return empty text for missing passages
|
||||
return {"text": ""}
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.passages_data)
|
||||
|
||||
def keys(self):
|
||||
return self.passages_data.keys()
|
||||
|
||||
def load_passages_from_metadata(meta_file: str) -> SimplePassageLoader:
|
||||
"""
|
||||
Load passages using metadata file with PassageManager for lazy loading
|
||||
"""
|
||||
# Load metadata to get passage sources
|
||||
with open(meta_file, 'r') as f:
|
||||
meta = json.load(f)
|
||||
|
||||
# Import PassageManager dynamically to avoid circular imports
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Find the leann package directory relative to this file
|
||||
current_dir = Path(__file__).parent
|
||||
leann_core_path = current_dir.parent.parent / "leann-core" / "src"
|
||||
sys.path.insert(0, str(leann_core_path))
|
||||
|
||||
try:
|
||||
from leann.api import PassageManager
|
||||
passage_manager = PassageManager(meta['passage_sources'])
|
||||
finally:
|
||||
sys.path.pop(0)
|
||||
|
||||
# Load label map
|
||||
passages_dir = Path(meta_file).parent
|
||||
label_map_file = passages_dir / "leann.labels.map"
|
||||
|
||||
if label_map_file.exists():
|
||||
import pickle
|
||||
with open(label_map_file, 'rb') as f:
|
||||
label_map = pickle.load(f)
|
||||
print(f"Loaded label map with {len(label_map)} entries")
|
||||
else:
|
||||
raise FileNotFoundError(f"Label map file not found: {label_map_file}")
|
||||
|
||||
print(f"Initialized lazy passage loading for {len(label_map)} passages")
|
||||
|
||||
class LazyPassageLoader(SimplePassageLoader):
|
||||
def __init__(self, passage_manager, label_map):
|
||||
self.passage_manager = passage_manager
|
||||
self.label_map = label_map
|
||||
# Initialize parent with empty data
|
||||
super().__init__({})
|
||||
|
||||
def __getitem__(self, passage_id: Union[str, int]) -> Dict[str, str]:
|
||||
"""Get passage by ID with lazy loading"""
|
||||
try:
|
||||
int_id = int(passage_id)
|
||||
if int_id in self.label_map:
|
||||
string_id = self.label_map[int_id]
|
||||
passage_data = self.passage_manager.get_passage(string_id)
|
||||
if passage_data and passage_data.get("text"):
|
||||
return {"text": passage_data["text"]}
|
||||
else:
|
||||
raise RuntimeError(f"FATAL: Empty text for ID {int_id} -> {string_id}")
|
||||
else:
|
||||
raise RuntimeError(f"FATAL: ID {int_id} not found in label_map")
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"FATAL: Exception getting passage {passage_id}: {e}")
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.label_map)
|
||||
|
||||
def keys(self):
|
||||
return self.label_map.keys()
|
||||
|
||||
loader = LazyPassageLoader(passage_manager, label_map)
|
||||
loader._meta_path = meta_file
|
||||
return loader
|
||||
|
||||
def load_passages_from_file(passages_file: str) -> SimplePassageLoader:
|
||||
"""
|
||||
Load passages from a JSONL file with label map support
|
||||
Expected format: {"id": "passage_id", "text": "passage_text", "metadata": {...}} (one per line)
|
||||
"""
|
||||
|
||||
if not os.path.exists(passages_file):
|
||||
raise FileNotFoundError(f"Passages file {passages_file} not found.")
|
||||
|
||||
if not passages_file.endswith('.jsonl'):
|
||||
raise ValueError(f"Expected .jsonl file format, got: {passages_file}")
|
||||
|
||||
# Load label map (int -> string_id)
|
||||
passages_dir = Path(passages_file).parent
|
||||
label_map_file = passages_dir / "leann.labels.map"
|
||||
|
||||
label_map = {}
|
||||
if label_map_file.exists():
|
||||
with open(label_map_file, 'rb') as f:
|
||||
label_map = pickle.load(f)
|
||||
print(f"Loaded label map with {len(label_map)} entries")
|
||||
else:
|
||||
raise FileNotFoundError(f"Label map file not found: {label_map_file}")
|
||||
|
||||
# Load passages by string ID
|
||||
string_id_passages = {}
|
||||
with open(passages_file, 'r', encoding='utf-8') as f:
|
||||
for line in f:
|
||||
if line.strip():
|
||||
passage = json.loads(line)
|
||||
string_id_passages[passage['id']] = passage['text']
|
||||
|
||||
# Create int ID -> text mapping using label map
|
||||
passages_data = {}
|
||||
for int_id, string_id in label_map.items():
|
||||
if string_id in string_id_passages:
|
||||
passages_data[str(int_id)] = string_id_passages[string_id]
|
||||
else:
|
||||
print(f"WARNING: String ID {string_id} from label map not found in passages")
|
||||
|
||||
print(f"Loaded {len(passages_data)} passages from JSONL file {passages_file} using label map")
|
||||
return SimplePassageLoader(passages_data)
|
||||
|
||||
def create_embedding_server_thread(
|
||||
zmq_port=5555,
|
||||
model_name="sentence-transformers/all-mpnet-base-v2",
|
||||
max_batch_size=128,
|
||||
passages_file: Optional[str] = None,
|
||||
embedding_mode: str = "sentence-transformers",
|
||||
enable_warmup: bool = False,
|
||||
):
|
||||
"""
|
||||
Create and run embedding server in the current thread
|
||||
This function is designed to be called in a separate thread
|
||||
"""
|
||||
logger.info(f"Initializing embedding server thread on port {zmq_port}")
|
||||
|
||||
try:
|
||||
# Check if port is already occupied
|
||||
import socket
|
||||
def check_port(port):
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
return s.connect_ex(('localhost', port)) == 0
|
||||
|
||||
if check_port(zmq_port):
|
||||
print(f"{RED}Port {zmq_port} is already in use{RESET}")
|
||||
return
|
||||
|
||||
# Auto-detect mode based on model name if not explicitly set
|
||||
if embedding_mode == "sentence-transformers" and model_name.startswith("text-embedding-"):
|
||||
embedding_mode = "openai"
|
||||
|
||||
if embedding_mode == "mlx":
|
||||
from leann.api import compute_embeddings_mlx
|
||||
import torch
|
||||
logger.info("Using MLX for embeddings")
|
||||
# Set device to CPU for compatibility with DeviceTimer class
|
||||
device = torch.device("cpu")
|
||||
cuda_available = False
|
||||
mps_available = False
|
||||
elif embedding_mode == "openai":
|
||||
from leann.api import compute_embeddings_openai
|
||||
import torch
|
||||
logger.info("Using OpenAI API for embeddings")
|
||||
# Set device to CPU for compatibility with DeviceTimer class
|
||||
device = torch.device("cpu")
|
||||
cuda_available = False
|
||||
mps_available = False
|
||||
elif embedding_mode == "sentence-transformers":
|
||||
# Initialize model
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
|
||||
import torch
|
||||
|
||||
# Select device
|
||||
mps_available = hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()
|
||||
cuda_available = torch.cuda.is_available()
|
||||
|
||||
if cuda_available:
|
||||
device = torch.device("cuda")
|
||||
logger.info("Using CUDA device")
|
||||
elif mps_available:
|
||||
device = torch.device("mps")
|
||||
logger.info("Using MPS device (Apple Silicon)")
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
logger.info("Using CPU device")
|
||||
|
||||
# Load model
|
||||
logger.info(f"Loading model {model_name}")
|
||||
model = AutoModel.from_pretrained(model_name).to(device).eval()
|
||||
|
||||
# Optimize model
|
||||
if cuda_available or mps_available:
|
||||
try:
|
||||
model = model.half()
|
||||
model = torch.compile(model)
|
||||
logger.info(f"Using FP16 precision with model: {model_name}")
|
||||
except Exception as e:
|
||||
print(f"WARNING: Model optimization failed: {e}")
|
||||
else:
|
||||
raise ValueError(f"Unsupported embedding mode: {embedding_mode}. Supported modes: sentence-transformers, mlx, openai")
|
||||
|
||||
# Load passages from file if provided
|
||||
if passages_file and os.path.exists(passages_file):
|
||||
# Check if it's a metadata file or a single passages file
|
||||
if passages_file.endswith('.meta.json'):
|
||||
passages = load_passages_from_metadata(passages_file)
|
||||
else:
|
||||
# Try to find metadata file in same directory
|
||||
passages_dir = Path(passages_file).parent
|
||||
meta_files = list(passages_dir.glob("*.meta.json"))
|
||||
if meta_files:
|
||||
print(f"Found metadata file: {meta_files[0]}, using lazy loading")
|
||||
passages = load_passages_from_metadata(str(meta_files[0]))
|
||||
else:
|
||||
# Fallback to original single file loading (will cause warnings)
|
||||
print("WARNING: No metadata file found, using single file loading (may cause missing passage warnings)")
|
||||
passages = load_passages_from_file(passages_file)
|
||||
else:
|
||||
print("WARNING: No passages file provided or file not found. Using an empty passage loader.")
|
||||
passages = SimplePassageLoader()
|
||||
|
||||
logger.info(f"Loaded {len(passages)} passages.")
|
||||
|
||||
def client_warmup(zmq_port):
|
||||
"""Perform client-side warmup for DiskANN server"""
|
||||
time.sleep(2)
|
||||
print(f"Performing client-side warmup with model {model_name}...")
|
||||
|
||||
# Get actual passage IDs from the loaded passages
|
||||
sample_ids = []
|
||||
if hasattr(passages, 'keys') and len(passages) > 0:
|
||||
available_ids = list(passages.keys())
|
||||
# Take up to 5 actual IDs, but at least 1
|
||||
sample_ids = available_ids[:min(5, len(available_ids))]
|
||||
print(f"Using actual passage IDs for warmup: {sample_ids}")
|
||||
else:
|
||||
print("No passages available for warmup, skipping warmup...")
|
||||
return
|
||||
|
||||
try:
|
||||
context = zmq.Context()
|
||||
socket = context.socket(zmq.REQ)
|
||||
socket.connect(f"tcp://localhost:{zmq_port}")
|
||||
socket.setsockopt(zmq.RCVTIMEO, 30000)
|
||||
socket.setsockopt(zmq.SNDTIMEO, 30000)
|
||||
|
||||
try:
|
||||
ids_to_send = [int(x) for x in sample_ids]
|
||||
except ValueError:
|
||||
print("Warning: Could not convert sample IDs to integers, skipping warmup")
|
||||
return
|
||||
|
||||
if not ids_to_send:
|
||||
print("Skipping warmup send.")
|
||||
return
|
||||
|
||||
# Use protobuf format for warmup
|
||||
from . import embedding_pb2
|
||||
req_proto = embedding_pb2.NodeEmbeddingRequest()
|
||||
req_proto.node_ids.extend(ids_to_send)
|
||||
request_bytes = req_proto.SerializeToString()
|
||||
|
||||
for i in range(3):
|
||||
print(f"Sending warmup request {i + 1}/3 via ZMQ (Protobuf)...")
|
||||
socket.send(request_bytes)
|
||||
response_bytes = socket.recv()
|
||||
|
||||
resp_proto = embedding_pb2.NodeEmbeddingResponse()
|
||||
resp_proto.ParseFromString(response_bytes)
|
||||
embeddings_count = resp_proto.dimensions[0] if resp_proto.dimensions else 0
|
||||
print(f"Warmup request {i + 1}/3 successful, received {embeddings_count} embeddings")
|
||||
time.sleep(0.1)
|
||||
|
||||
print("Client-side Protobuf ZMQ warmup complete")
|
||||
socket.close()
|
||||
context.term()
|
||||
except Exception as e:
|
||||
print(f"Error during Protobuf ZMQ warmup: {e}")
|
||||
|
||||
class DeviceTimer:
|
||||
"""Device timer"""
|
||||
def __init__(self, name="", device=device):
|
||||
self.name = name
|
||||
self.device = device
|
||||
self.start_time = 0
|
||||
self.end_time = 0
|
||||
|
||||
if embedding_mode == "sentence-transformers" and torch.cuda.is_available():
|
||||
self.start_event = torch.cuda.Event(enable_timing=True)
|
||||
self.end_event = torch.cuda.Event(enable_timing=True)
|
||||
else:
|
||||
self.start_event = None
|
||||
self.end_event = None
|
||||
|
||||
@contextmanager
|
||||
def timing(self):
|
||||
self.start()
|
||||
yield
|
||||
self.end()
|
||||
|
||||
def start(self):
|
||||
if embedding_mode == "sentence-transformers" and torch.cuda.is_available():
|
||||
torch.cuda.synchronize()
|
||||
self.start_event.record()
|
||||
else:
|
||||
if embedding_mode == "sentence-transformers" and self.device.type == "mps":
|
||||
torch.mps.synchronize()
|
||||
self.start_time = time.time()
|
||||
|
||||
def end(self):
|
||||
if embedding_mode == "sentence-transformers" and torch.cuda.is_available():
|
||||
self.end_event.record()
|
||||
torch.cuda.synchronize()
|
||||
else:
|
||||
if embedding_mode == "sentence-transformers" and self.device.type == "mps":
|
||||
torch.mps.synchronize()
|
||||
self.end_time = time.time()
|
||||
|
||||
def elapsed_time(self):
|
||||
if embedding_mode == "sentence-transformers" and torch.cuda.is_available():
|
||||
return self.start_event.elapsed_time(self.end_event) / 1000.0
|
||||
else:
|
||||
return self.end_time - self.start_time
|
||||
|
||||
def print_elapsed(self):
|
||||
elapsed = self.elapsed_time()
|
||||
print(f"[{self.name}] Elapsed time: {elapsed:.3f}s")
|
||||
|
||||
def process_batch_pytorch(texts_batch, ids_batch, missing_ids):
|
||||
"""Process text batch"""
|
||||
if not texts_batch:
|
||||
return np.array([])
|
||||
|
||||
# Filter out empty texts and their corresponding IDs
|
||||
valid_texts = []
|
||||
valid_ids = []
|
||||
for i, text in enumerate(texts_batch):
|
||||
if text.strip(): # Only include non-empty texts
|
||||
valid_texts.append(text)
|
||||
valid_ids.append(ids_batch[i])
|
||||
|
||||
if not valid_texts:
|
||||
print("WARNING: No valid texts in batch")
|
||||
return np.array([])
|
||||
|
||||
# Tokenize
|
||||
token_timer = DeviceTimer("tokenization")
|
||||
with token_timer.timing():
|
||||
inputs = tokenizer(
|
||||
valid_texts,
|
||||
padding=True,
|
||||
truncation=True,
|
||||
max_length=512,
|
||||
return_tensors="pt"
|
||||
).to(device)
|
||||
|
||||
# Compute embeddings
|
||||
embed_timer = DeviceTimer("embedding computation")
|
||||
with embed_timer.timing():
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs)
|
||||
hidden_states = outputs.last_hidden_state
|
||||
|
||||
# Mean pooling
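# i.e. sum(hidden_state * attention_mask) / sum(attention_mask): the average of the
# non-padding token embeddings for each sequence.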
|
||||
attention_mask = inputs['attention_mask']
|
||||
mask_expanded = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()
|
||||
sum_embeddings = torch.sum(hidden_states * mask_expanded, 1)
|
||||
sum_mask = torch.clamp(mask_expanded.sum(1), min=1e-9)
|
||||
batch_embeddings = sum_embeddings / sum_mask
|
||||
embed_timer.print_elapsed()
|
||||
|
||||
return batch_embeddings.cpu().numpy()
|
||||
|
||||
# ZMQ server main loop - uses a ROUTER socket
context = zmq.Context()
socket = context.socket(zmq.ROUTER)  # ROUTER so replies can be routed back to each REQ client
|
||||
socket.bind(f"tcp://127.0.0.1:{zmq_port}")
|
||||
print(f"INFO: ZMQ ROUTER server listening on port {zmq_port}")
|
||||
|
||||
# Set timeouts
|
||||
socket.setsockopt(zmq.RCVTIMEO, 5000) # 5 second receive timeout
|
||||
socket.setsockopt(zmq.SNDTIMEO, 300000) # 300 second send timeout
|
||||
|
||||
from . import embedding_pb2
|
||||
|
||||
print(f"INFO: Embedding server ready to serve requests")
|
||||
|
||||
# Start warmup thread if enabled
|
||||
if enable_warmup and len(passages) > 0:
|
||||
import threading
|
||||
print(f"Warmup enabled: starting warmup thread")
|
||||
warmup_thread = threading.Thread(target=client_warmup, args=(zmq_port,))
|
||||
warmup_thread.daemon = True
|
||||
warmup_thread.start()
|
||||
else:
|
||||
print(f"Warmup disabled or no passages available (enable_warmup={enable_warmup}, passages={len(passages)})")
|
||||
|
||||
while True:
|
||||
try:
|
||||
parts = socket.recv_multipart()
|
||||
|
||||
# --- Restore robust message format detection ---
|
||||
# Must check parts length to avoid IndexError
|
||||
if len(parts) >= 3:
|
||||
identity = parts[0]
|
||||
# empty = parts[1] # We usually don't care about the middle empty frame
|
||||
message = parts[2]
|
||||
elif len(parts) == 2:
|
||||
# Can also handle cases without empty frame
|
||||
identity = parts[0]
|
||||
message = parts[1]
|
||||
else:
|
||||
# If received message format is wrong, print warning and ignore it instead of crashing
|
||||
print(f"WARNING: Received unexpected message format with {len(parts)} parts. Ignoring.")
|
||||
continue
|
||||
print(f"INFO: Received ZMQ request from client {identity.hex()[:8]}, size {len(message)} bytes")
|
||||
|
||||
# Handle control messages (MessagePack format)
|
||||
try:
|
||||
request_payload = msgpack.unpackb(message)
|
||||
if isinstance(request_payload, list) and len(request_payload) >= 1:
|
||||
if request_payload[0] == "__QUERY_META_PATH__":
|
||||
# Return the current meta path being used by the server
|
||||
current_meta_path = getattr(passages, '_meta_path', '') if hasattr(passages, '_meta_path') else ''
|
||||
response = [current_meta_path]
|
||||
socket.send_multipart([identity, b'', msgpack.packb(response)])
|
||||
continue
|
||||
|
||||
elif request_payload[0] == "__UPDATE_META_PATH__" and len(request_payload) >= 2:
|
||||
# Update the server's meta path and reload passages
|
||||
new_meta_path = request_payload[1]
|
||||
try:
|
||||
print(f"INFO: Updating server meta path to: {new_meta_path}")
|
||||
# Reload passages from the new meta file
|
||||
passages = load_passages_from_metadata(new_meta_path)
|
||||
# Store the meta path for future queries
|
||||
passages._meta_path = new_meta_path
|
||||
response = ["SUCCESS"]
|
||||
print(f"INFO: Successfully updated meta path and reloaded {len(passages)} passages")
|
||||
except Exception as e:
|
||||
print(f"ERROR: Failed to update meta path: {e}")
|
||||
response = ["FAILED", str(e)]
|
||||
socket.send_multipart([identity, b'', msgpack.packb(response)])
|
||||
continue
|
||||
|
||||
elif request_payload[0] == "__QUERY_MODEL__":
|
||||
# Return the current model being used by the server
|
||||
response = [model_name]
|
||||
socket.send_multipart([identity, b'', msgpack.packb(response)])
|
||||
continue
|
||||
|
||||
elif request_payload[0] == "__UPDATE_MODEL__" and len(request_payload) >= 2:
|
||||
# Update the server's embedding model
|
||||
new_model_name = request_payload[1]
|
||||
try:
|
||||
print(f"INFO: Updating server model from {model_name} to: {new_model_name}")
|
||||
|
||||
# Clean up old model to free memory
|
||||
if embedding_mode != "mlx":
|
||||
print("INFO: Releasing old model from memory...")
|
||||
old_model = model
|
||||
old_tokenizer = tokenizer
|
||||
|
||||
# Load new tokenizer first
|
||||
print(f"Loading new tokenizer for {new_model_name}...")
|
||||
tokenizer = AutoTokenizer.from_pretrained(new_model_name, use_fast=True)
|
||||
|
||||
# Load new model
|
||||
print(f"Loading new model {new_model_name}...")
|
||||
model = AutoModel.from_pretrained(new_model_name).to(device).eval()
|
||||
|
||||
# Optimize new model
|
||||
if cuda_available or mps_available:
|
||||
try:
|
||||
model = model.half()
|
||||
model = torch.compile(model)
|
||||
print(f"INFO: Using FP16 precision with model: {new_model_name}")
|
||||
except Exception as e:
|
||||
print(f"WARNING: Model optimization failed: {e}")
|
||||
|
||||
# Now safely delete old model after new one is loaded
|
||||
del old_model
|
||||
del old_tokenizer
|
||||
|
||||
# Clear GPU cache if available
|
||||
if device.type == "cuda":
|
||||
torch.cuda.empty_cache()
|
||||
print("INFO: Cleared CUDA cache")
|
||||
elif device.type == "mps":
|
||||
torch.mps.empty_cache()
|
||||
print("INFO: Cleared MPS cache")
|
||||
|
||||
# Force garbage collection
|
||||
import gc
|
||||
gc.collect()
|
||||
print("INFO: Memory cleanup completed")
|
||||
|
||||
# Update model name
|
||||
model_name = new_model_name
|
||||
|
||||
response = ["SUCCESS"]
|
||||
print(f"INFO: Successfully updated model to: {new_model_name}")
|
||||
except Exception as e:
|
||||
print(f"ERROR: Failed to update model: {e}")
|
||||
response = ["FAILED", str(e)]
|
||||
socket.send_multipart([identity, b'', msgpack.packb(response)])
|
||||
continue
|
||||
except Exception:
|
||||
# Not a control message, continue with normal protobuf processing
|
||||
pass
|
||||
|
||||
e2e_start = time.time()
|
||||
lookup_timer = DeviceTimer("text lookup")
|
||||
|
||||
# Parse request
|
||||
req_proto = embedding_pb2.NodeEmbeddingRequest()
|
||||
req_proto.ParseFromString(message)
|
||||
node_ids = req_proto.node_ids
|
||||
print(f"INFO: Request for {len(node_ids)} node embeddings: {list(node_ids)}")
|
||||
|
||||
# Add debug information
|
||||
if len(node_ids) > 0:
|
||||
print(f"DEBUG: Node ID range: {min(node_ids)} to {max(node_ids)}")
|
||||
|
||||
# Look up texts
|
||||
texts = []
|
||||
missing_ids = []
|
||||
with lookup_timer.timing():
|
||||
for nid in node_ids:
|
||||
txtinfo = passages[nid]
|
||||
txt = txtinfo["text"]
|
||||
if txt:
|
||||
texts.append(txt)
|
||||
else:
|
||||
# If text is empty, we still need a placeholder for batch processing,
|
||||
# but record its ID as missing
|
||||
texts.append("")
|
||||
missing_ids.append(nid)
|
||||
lookup_timer.print_elapsed()
|
||||
|
||||
if missing_ids:
|
||||
print(f"WARNING: Missing passages for IDs: {missing_ids}")
|
||||
|
||||
# Process batch
|
||||
total_size = len(texts)
|
||||
print(f"INFO: Total batch size: {total_size}, max_batch_size: {max_batch_size}")
|
||||
|
||||
all_embeddings = []
|
||||
|
||||
if total_size > max_batch_size:
|
||||
print(f"INFO: Splitting batch of size {total_size} into chunks of {max_batch_size}")
|
||||
for i in range(0, total_size, max_batch_size):
|
||||
end_idx = min(i + max_batch_size, total_size)
|
||||
print(f"INFO: Processing chunk {i//max_batch_size + 1}/{(total_size + max_batch_size - 1)//max_batch_size}: items {i} to {end_idx-1}")
|
||||
|
||||
chunk_texts = texts[i:end_idx]
|
||||
chunk_ids = node_ids[i:end_idx]
|
||||
|
||||
if embedding_mode == "mlx":
|
||||
embeddings_chunk = compute_embeddings_mlx(chunk_texts, model_name, batch_size=16)
|
||||
elif embedding_mode == "openai":
|
||||
embeddings_chunk = compute_embeddings_openai(chunk_texts, model_name)
|
||||
else: # sentence-transformers
|
||||
embeddings_chunk = process_batch_pytorch(chunk_texts, chunk_ids, missing_ids)
|
||||
all_embeddings.append(embeddings_chunk)
|
||||
|
||||
if embedding_mode == "sentence-transformers":
|
||||
if cuda_available:
|
||||
torch.cuda.empty_cache()
|
||||
elif device.type == "mps":
|
||||
torch.mps.empty_cache()
|
||||
|
||||
hidden = np.vstack(all_embeddings)
|
||||
print(f"INFO: Combined embeddings shape: {hidden.shape}")
|
||||
else:
|
||||
if embedding_mode == "mlx":
|
||||
hidden = compute_embeddings_mlx(texts, model_name, batch_size=16)
|
||||
elif embedding_mode == "openai":
|
||||
hidden = compute_embeddings_openai(texts, model_name)
|
||||
else: # sentence-transformers
|
||||
hidden = process_batch_pytorch(texts, node_ids, missing_ids)
|
||||
|
||||
# Serialize response
|
||||
ser_start = time.time()
|
||||
|
||||
resp_proto = embedding_pb2.NodeEmbeddingResponse()
|
||||
hidden_contiguous = np.ascontiguousarray(hidden, dtype=np.float32)
|
||||
resp_proto.embeddings_data = hidden_contiguous.tobytes()
|
||||
resp_proto.dimensions.append(hidden_contiguous.shape[0])
|
||||
resp_proto.dimensions.append(hidden_contiguous.shape[1])
|
||||
resp_proto.missing_ids.extend(missing_ids)
|
||||
|
||||
response_data = resp_proto.SerializeToString()
|
||||
|
||||
# ROUTER reply: [client identity, empty delimiter frame, payload]
|
||||
socket.send_multipart([identity, b'', response_data])
|
||||
|
||||
ser_end = time.time()
|
||||
|
||||
print(f"INFO: Serialize time: {ser_end - ser_start:.6f} seconds")
|
||||
|
||||
if embedding_mode == "sentence-transformers":
|
||||
if device.type == "cuda":
|
||||
torch.cuda.synchronize()
|
||||
elif device.type == "mps":
|
||||
torch.mps.synchronize()
|
||||
e2e_end = time.time()
|
||||
print(f"INFO: ZMQ E2E time: {e2e_end - e2e_start:.6f} seconds")
|
||||
|
||||
except zmq.Again:
|
||||
print("INFO: ZMQ socket timeout, continuing to listen")
|
||||
continue
|
||||
except Exception as e:
|
||||
print(f"ERROR: Error in ZMQ server: {e}")
|
||||
try:
# Send an empty response so the blocked REQ client is not left hanging
empty_resp = embedding_pb2.NodeEmbeddingResponse()
socket.send_multipart([identity, b'', empty_resp.SerializeToString()])
except Exception:
# If sending fails, recreate the ROUTER socket
socket.close()
socket = context.socket(zmq.ROUTER)
socket.bind(f"tcp://127.0.0.1:{zmq_port}")
socket.setsockopt(zmq.RCVTIMEO, 5000)
socket.setsockopt(zmq.SNDTIMEO, 300000)
print("INFO: ZMQ socket recreated after error")
|
||||
|
||||
except Exception as e:
|
||||
print(f"ERROR: Failed to start embedding server: {e}")
|
||||
raise
|
||||
|
||||
|
||||
def create_embedding_server(
|
||||
domain="demo",
|
||||
load_passages=True,
|
||||
load_embeddings=False,
|
||||
use_fp16=True,
|
||||
use_int8=False,
|
||||
use_cuda_graphs=False,
|
||||
zmq_port=5555,
|
||||
max_batch_size=128,
|
||||
lazy_load_passages=False,
|
||||
model_name="sentence-transformers/all-mpnet-base-v2",
|
||||
passages_file: Optional[str] = None,
|
||||
embedding_mode: str = "sentence-transformers",
|
||||
enable_warmup: bool = False,
|
||||
):
|
||||
"""
|
||||
原有的 create_embedding_server 函数保持不变
|
||||
这个是阻塞版本,用于直接运行
|
||||
"""
|
||||
create_embedding_server_thread(zmq_port, model_name, max_batch_size, passages_file, embedding_mode, enable_warmup)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Embedding service")
|
||||
parser.add_argument("--zmq-port", type=int, default=5555, help="ZMQ port to run on")
|
||||
parser.add_argument("--domain", type=str, default="demo", help="Domain name")
|
||||
parser.add_argument("--passages-file", type=str, help="JSON file containing passage ID to text mapping")
|
||||
parser.add_argument("--load-passages", action="store_true", default=True)
|
||||
parser.add_argument("--load-embeddings", action="store_true", default=False)
|
||||
parser.add_argument("--use-fp16", action="store_true", default=False)
|
||||
parser.add_argument("--use-int8", action="store_true", default=False)
|
||||
parser.add_argument("--use-cuda-graphs", action="store_true", default=False)
|
||||
parser.add_argument("--max-batch-size", type=int, default=128, help="Maximum batch size before splitting")
|
||||
parser.add_argument("--lazy-load-passages", action="store_true", default=True)
|
||||
parser.add_argument("--model-name", type=str, default="sentence-transformers/all-mpnet-base-v2",
|
||||
help="Embedding model name")
|
||||
parser.add_argument("--embedding-mode", type=str, default="sentence-transformers",
|
||||
choices=["sentence-transformers", "mlx", "openai"],
|
||||
help="Embedding backend mode")
|
||||
parser.add_argument("--use-mlx", action="store_true", default=False, help="Use MLX backend for embeddings (deprecated: use --embedding-mode mlx)")
|
||||
parser.add_argument("--disable-warmup", action="store_true", default=False, help="Disable warmup requests on server start")
|
||||
args = parser.parse_args()
|
||||
|
||||
# Handle backward compatibility with use_mlx
|
||||
embedding_mode = args.embedding_mode
|
||||
if args.use_mlx:
|
||||
embedding_mode = "mlx"
|
||||
|
||||
create_embedding_server(
|
||||
domain=args.domain,
|
||||
load_passages=args.load_passages,
|
||||
load_embeddings=args.load_embeddings,
|
||||
use_fp16=args.use_fp16,
|
||||
use_int8=args.use_int8,
|
||||
use_cuda_graphs=args.use_cuda_graphs,
|
||||
zmq_port=args.zmq_port,
|
||||
max_batch_size=args.max_batch_size,
|
||||
lazy_load_passages=args.lazy_load_passages,
|
||||
model_name=args.model_name,
|
||||
passages_file=args.passages_file,
|
||||
embedding_mode=embedding_mode,
|
||||
enable_warmup=not args.disable_warmup,
|
||||
)
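For reference, a minimal client for the server above might look like the following sketch (illustrative only: it assumes the server is already listening on port 5557 and that embedding_pb2 is importable from the same backend package).

import zmq
import msgpack
import numpy as np
from leann_backend_diskann import embedding_pb2  # assumed import path

def fetch_embeddings(node_ids, port=5555):
    ctx = zmq.Context()
    sock = ctx.socket(zmq.REQ)  # REQ pairs with the server's ROUTER socket
    sock.connect(f"tcp://localhost:{port}")

    # Control message (MessagePack): ask which embedding model the server is running
    sock.send(msgpack.packb(["__QUERY_MODEL__"]))
    print("server model:", msgpack.unpackb(sock.recv()))

    # Embedding request (Protobuf): send node IDs, get back a flat float32 buffer
    req = embedding_pb2.NodeEmbeddingRequest()
    req.node_ids.extend(node_ids)
    sock.send(req.SerializeToString())

    resp = embedding_pb2.NodeEmbeddingResponse()
    resp.ParseFromString(sock.recv())
    rows, dim = resp.dimensions
    embeddings = np.frombuffer(resp.embeddings_data, dtype=np.float32).reshape(rows, dim)

    sock.close()
    ctx.term()
    return embeddings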
|
||||
@@ -4,16 +4,13 @@ build-backend = "scikit_build_core.build"
|
||||
|
||||
[project]
|
||||
name = "leann-backend-diskann"
|
||||
version = "0.1.10"
|
||||
dependencies = ["leann-core==0.1.10", "numpy", "protobuf>=3.19.0"]
|
||||
version = "0.1.0"
|
||||
dependencies = ["leann-core==0.1.0", "numpy"]
|
||||
|
||||
[tool.scikit-build]
|
||||
# Key: simplified CMake path
|
||||
|
||||
cmake.source-dir = "third_party/DiskANN"
|
||||
# Key: Python package in root directory, paths match exactly
|
||||
|
||||
wheel.packages = ["leann_backend_diskann"]
|
||||
# Use default redirect mode
|
||||
editable.mode = "redirect"
|
||||
cmake.build-type = "Release"
|
||||
build.verbose = true
|
||||
build.tool-args = ["-j8"]
|
||||
|
||||
editable.mode = "redirect"
|
||||
@@ -1,7 +1,6 @@
|
||||
# Final simplified version
|
||||
cmake_minimum_required(VERSION 3.24)
|
||||
project(leann_backend_hnsw_wrapper)
|
||||
set(CMAKE_C_COMPILER_WORKS 1)
|
||||
set(CMAKE_CXX_COMPILER_WORKS 1)
|
||||
|
||||
# Set OpenMP path for macOS
|
||||
if(APPLE)
|
||||
@@ -12,9 +11,15 @@ if(APPLE)
|
||||
set(OpenMP_omp_LIBRARY "/opt/homebrew/opt/libomp/lib/libomp.dylib")
|
||||
endif()
|
||||
|
||||
# Use system ZeroMQ instead of building from source
|
||||
find_package(PkgConfig REQUIRED)
|
||||
pkg_check_modules(ZMQ REQUIRED libzmq)
|
||||
# Build ZeroMQ from source
|
||||
set(ZMQ_BUILD_TESTS OFF CACHE BOOL "" FORCE)
|
||||
set(ENABLE_DRAFTS OFF CACHE BOOL "" FORCE)
|
||||
set(ENABLE_PRECOMPILED OFF CACHE BOOL "" FORCE)
|
||||
set(WITH_PERF_TOOL OFF CACHE BOOL "" FORCE)
|
||||
set(WITH_DOCS OFF CACHE BOOL "" FORCE)
|
||||
set(BUILD_SHARED OFF CACHE BOOL "" FORCE)
|
||||
set(BUILD_STATIC ON CACHE BOOL "" FORCE)
|
||||
add_subdirectory(third_party/libzmq)
|
||||
|
||||
# Add cppzmq headers
|
||||
include_directories(third_party/cppzmq)
|
||||
@@ -24,7 +29,6 @@ set(MSGPACK_USE_BOOST OFF CACHE BOOL "" FORCE)
|
||||
add_compile_definitions(MSGPACK_NO_BOOST)
|
||||
include_directories(third_party/msgpack-c/include)
|
||||
|
||||
# Faiss configuration - streamlined build
|
||||
set(FAISS_ENABLE_PYTHON ON CACHE BOOL "" FORCE)
|
||||
set(FAISS_ENABLE_GPU OFF CACHE BOOL "" FORCE)
|
||||
set(FAISS_ENABLE_EXTRAS OFF CACHE BOOL "" FORCE)
|
||||
@@ -32,24 +36,4 @@ set(BUILD_TESTING OFF CACHE BOOL "" FORCE)
|
||||
set(FAISS_ENABLE_C_API OFF CACHE BOOL "" FORCE)
|
||||
set(FAISS_OPT_LEVEL "generic" CACHE STRING "" FORCE)
|
||||
|
||||
# Disable additional SIMD versions to speed up compilation
|
||||
set(FAISS_ENABLE_AVX2 OFF CACHE BOOL "" FORCE)
|
||||
set(FAISS_ENABLE_AVX512 OFF CACHE BOOL "" FORCE)
|
||||
|
||||
# Additional optimization options from INSTALL.md
|
||||
set(CMAKE_BUILD_TYPE "Release" CACHE STRING "" FORCE)
|
||||
set(BUILD_SHARED_LIBS OFF CACHE BOOL "" FORCE) # Static library is faster to build
|
||||
|
||||
# Avoid building demos and benchmarks
|
||||
set(BUILD_DEMOS OFF CACHE BOOL "" FORCE)
|
||||
set(BUILD_BENCHS OFF CACHE BOOL "" FORCE)
|
||||
|
||||
# NEW: Tell Faiss to only build the generic version
|
||||
set(FAISS_BUILD_GENERIC ON CACHE BOOL "" FORCE)
|
||||
set(FAISS_BUILD_AVX2 OFF CACHE BOOL "" FORCE)
|
||||
set(FAISS_BUILD_AVX512 OFF CACHE BOOL "" FORCE)
|
||||
|
||||
# IMPORTANT: Disable building AVX versions to speed up compilation
|
||||
set(FAISS_BUILD_AVX_VERSIONS OFF CACHE BOOL "" FORCE)
|
||||
|
||||
add_subdirectory(third_party/faiss)
|
||||
@@ -1,9 +1,10 @@
|
||||
import numpy as np
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, List, Literal, Optional
|
||||
from typing import Dict, Any, List, Literal
|
||||
import pickle
|
||||
import shutil
|
||||
import logging
|
||||
import time
|
||||
|
||||
from leann.searcher_base import BaseSearcher
|
||||
from .convert_to_csr import convert_hnsw_graph_to_csr
|
||||
@@ -15,8 +16,6 @@ from leann.interface import (
|
||||
LeannBackendSearcherInterface,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_metric_map():
|
||||
from . import faiss # type: ignore
|
||||
@@ -58,9 +57,13 @@ class HNSWBuilder(LeannBackendBuilderInterface):
|
||||
index_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if data.dtype != np.float32:
|
||||
logger.warning(f"Converting data to float32, shape: {data.shape}")
|
||||
data = data.astype(np.float32)
|
||||
|
||||
label_map = {i: str_id for i, str_id in enumerate(ids)}
|
||||
label_map_file = index_dir / "leann.labels.map"
|
||||
with open(label_map_file, "wb") as f:
|
||||
pickle.dump(label_map, f)
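# The label map lets the searcher translate Faiss's integer labels back to the
# original string passage IDs at query time (see HNSWSearcher below).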
|
||||
|
||||
metric_enum = get_metric_map().get(self.distance_metric.lower())
|
||||
if metric_enum is None:
|
||||
raise ValueError(f"Unsupported distance_metric '{self.distance_metric}'.")
|
||||
@@ -82,7 +85,7 @@ class HNSWBuilder(LeannBackendBuilderInterface):
|
||||
def _convert_to_csr(self, index_file: Path):
|
||||
"""Convert built index to CSR format"""
|
||||
mode_str = "CSR-pruned" if self.is_recompute else "CSR-standard"
|
||||
logger.info(f"INFO: Converting HNSW index to {mode_str} format...")
|
||||
print(f"INFO: Converting HNSW index to {mode_str} format...")
|
||||
|
||||
csr_temp_file = index_file.with_suffix(".csr.tmp")
|
||||
|
||||
@@ -91,11 +94,11 @@ class HNSWBuilder(LeannBackendBuilderInterface):
|
||||
)
|
||||
|
||||
if success:
|
||||
logger.info("✅ CSR conversion successful.")
|
||||
# index_file_old = index_file.with_suffix(".old")
|
||||
# shutil.move(str(index_file), str(index_file_old))
|
||||
print("✅ CSR conversion successful.")
|
||||
index_file_old = index_file.with_suffix(".old")
|
||||
shutil.move(str(index_file), str(index_file_old))
|
||||
shutil.move(str(csr_temp_file), str(index_file))
|
||||
logger.info(
|
||||
print(
|
||||
f"INFO: Replaced original index with {mode_str} version at '{index_file}'"
|
||||
)
|
||||
else:
|
||||
@@ -132,22 +135,31 @@ class HNSWSearcher(BaseSearcher):
|
||||
|
||||
hnsw_config = faiss.HNSWIndexConfig()
|
||||
hnsw_config.is_compact = self.is_compact
|
||||
hnsw_config.is_recompute = (
|
||||
self.is_pruned
|
||||
) # In C++ code, it's called is_recompute, but it's only for loading IIUC.
|
||||
hnsw_config.is_recompute = self.is_pruned or kwargs.get("is_recompute", False)
|
||||
|
||||
if self.is_pruned and not hnsw_config.is_recompute:
|
||||
raise RuntimeError("Index is pruned but recompute is disabled.")
|
||||
|
||||
self._index = faiss.read_index(str(index_file), faiss.IO_FLAG_MMAP, hnsw_config)
|
||||
|
||||
# Load label mapping
|
||||
label_map_file = self.index_dir / "leann.labels.map"
|
||||
if not label_map_file.exists():
|
||||
raise FileNotFoundError(f"Label map file not found at {label_map_file}")
|
||||
|
||||
with open(label_map_file, "rb") as f:
|
||||
self.label_map = pickle.load(f)
|
||||
|
||||
def search(
|
||||
self,
|
||||
query: np.ndarray,
|
||||
top_k: int,
|
||||
zmq_port: Optional[int] = None,
|
||||
complexity: int = 64,
|
||||
beam_width: int = 1,
|
||||
prune_ratio: float = 0.0,
|
||||
recompute_embeddings: bool = True,
|
||||
recompute_embeddings: bool = False,
|
||||
pruning_strategy: Literal["global", "local", "proportional"] = "global",
|
||||
zmq_port: int = 5557,
|
||||
batch_size: int = 0,
|
||||
**kwargs,
|
||||
) -> Dict[str, Any]:
|
||||
@@ -165,7 +177,7 @@ class HNSWSearcher(BaseSearcher):
|
||||
- "global": Use global PQ queue size for selection (default)
|
||||
- "local": Local pruning, sort and select best candidates
|
||||
- "proportional": Base selection on new neighbor count ratio
|
||||
zmq_port: ZMQ port for embedding server communication. Must be provided if recompute_embeddings is True.
|
||||
zmq_port: ZMQ port for embedding server
|
||||
batch_size: Neighbor processing batch size, 0=disabled (HNSW-specific)
|
||||
**kwargs: Additional HNSW-specific parameters (for legacy compatibility)
|
||||
|
||||
@@ -174,14 +186,15 @@ class HNSWSearcher(BaseSearcher):
|
||||
"""
|
||||
from . import faiss # type: ignore
|
||||
|
||||
if not recompute_embeddings:
|
||||
if self.is_pruned:
|
||||
raise RuntimeError("Recompute is required for pruned index.")
|
||||
if recompute_embeddings:
|
||||
if zmq_port is None:
|
||||
raise ValueError(
|
||||
"zmq_port must be provided if recompute_embeddings is True"
|
||||
# Use recompute_embeddings parameter
|
||||
use_recompute = recompute_embeddings or self.is_pruned
|
||||
if use_recompute:
|
||||
meta_file_path = self.index_dir / f"{self.index_path.name}.meta.json"
|
||||
if not meta_file_path.exists():
|
||||
raise RuntimeError(
|
||||
f"FATAL: Recompute enabled but metadata file not found: {meta_file_path}"
|
||||
)
|
||||
self._ensure_server_running(str(meta_file_path), port=zmq_port, **kwargs)
|
||||
|
||||
if query.dtype != np.float32:
|
||||
query = query.astype(np.float32)
|
||||
@@ -189,10 +202,7 @@ class HNSWSearcher(BaseSearcher):
|
||||
faiss.normalize_L2(query)
|
||||
|
||||
params = faiss.SearchParametersHNSW()
|
||||
if zmq_port is not None:
|
||||
params.zmq_port = (
|
||||
zmq_port # C++ code won't use this if recompute_embeddings is False
|
||||
)
|
||||
params.zmq_port = zmq_port
|
||||
params.efSearch = complexity
|
||||
params.beam_size = beam_width
|
||||
|
||||
@@ -229,7 +239,11 @@ class HNSWSearcher(BaseSearcher):
|
||||
)
|
||||
|
||||
string_labels = [
|
||||
[str(int_label) for int_label in batch_labels] for batch_labels in labels
|
||||
[
|
||||
self.label_map.get(int_label, f"unknown_{int_label}")
|
||||
for int_label in batch_labels
|
||||
]
|
||||
for batch_labels in labels
|
||||
]
|
||||
|
||||
return {"labels": string_labels, "distances": distances}
|
||||
|
||||
File diff suppressed because it is too large
@@ -6,22 +6,12 @@ build-backend = "scikit_build_core.build"
|
||||
|
||||
[project]
|
||||
name = "leann-backend-hnsw"
|
||||
version = "0.1.10"
|
||||
version = "0.1.0"
|
||||
description = "Custom-built HNSW (Faiss) backend for the Leann toolkit."
|
||||
dependencies = [
|
||||
"leann-core==0.1.10",
|
||||
"numpy",
|
||||
"pyzmq>=23.0.0",
|
||||
"msgpack>=1.0.0",
|
||||
]
|
||||
dependencies = ["leann-core==0.1.0", "numpy"]
|
||||
|
||||
[tool.scikit-build]
|
||||
wheel.packages = ["leann_backend_hnsw"]
|
||||
editable.mode = "redirect"
|
||||
cmake.build-type = "Release"
|
||||
build.verbose = true
|
||||
build.tool-args = ["-j8"]
|
||||
|
||||
# CMake definitions to optimize compilation
|
||||
[tool.scikit-build.cmake.define]
|
||||
CMAKE_BUILD_PARALLEL_LEVEL = "8"
|
||||
cmake.build-type = "Debug"
|
||||
build.verbose = true
|
||||
Submodule packages/leann-backend-hnsw/third_party/faiss updated: ff22e2c86b...2547df4377
Submodule packages/leann-backend-hnsw/third_party/msgpack-c updated: a0b2ec09da...9b801f087a
@@ -4,27 +4,16 @@ build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "leann-core"
|
||||
version = "0.1.10"
|
||||
description = "Core API and plugin system for LEANN"
|
||||
version = "0.1.0"
|
||||
description = "Core API and plugin system for Leann."
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.9"
|
||||
license = { text = "MIT" }
|
||||
|
||||
# All required dependencies included
|
||||
dependencies = [
|
||||
"numpy>=1.20.0",
|
||||
"tqdm>=4.60.0",
|
||||
"psutil>=5.8.0",
|
||||
"pyzmq>=23.0.0",
|
||||
"msgpack>=1.0.0",
|
||||
"torch>=2.0.0",
|
||||
"sentence-transformers>=2.2.0",
|
||||
"llama-index-core>=0.12.0",
|
||||
"python-dotenv>=1.0.0",
|
||||
"tqdm>=4.60.0"
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
leann = "leann.cli:main"
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
where = ["src"]
|
||||
@@ -5,18 +5,16 @@ with the correct, original embedding logic from the user's reference code.
|
||||
|
||||
import json
|
||||
import pickle
|
||||
from leann.interface import LeannBackendSearcherInterface
|
||||
import numpy as np
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Any, Optional, Literal
|
||||
from dataclasses import dataclass, field
|
||||
import uuid
|
||||
import torch
|
||||
|
||||
from .registry import BACKEND_REGISTRY
|
||||
from .interface import LeannBackendFactoryInterface
|
||||
from .chat import get_llm
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def compute_embeddings(
|
||||
@@ -24,8 +22,7 @@ def compute_embeddings(
|
||||
model_name: str,
|
||||
mode: str = "sentence-transformers",
|
||||
use_server: bool = True,
|
||||
port: Optional[int] = None,
|
||||
is_build=False,
|
||||
use_mlx: bool = False,  # Backward compatibility: if True, override mode to 'mlx'
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Computes embeddings using different backends.
|
||||
@@ -42,63 +39,251 @@ def compute_embeddings(
|
||||
Returns:
|
||||
numpy array of embeddings
|
||||
"""
|
||||
if use_server:
|
||||
# Use embedding server (for search/query)
|
||||
if port is None:
|
||||
raise ValueError("port is required when use_server is True")
|
||||
return compute_embeddings_via_server(chunks, model_name, port=port)
|
||||
# Override mode for backward compatibility
|
||||
if use_mlx:
|
||||
mode = "mlx"
|
||||
|
||||
# Auto-detect mode based on model name if not explicitly set
|
||||
if mode == "sentence-transformers" and model_name.startswith("text-embedding-"):
|
||||
mode = "openai"
|
||||
|
||||
if mode == "mlx":
|
||||
return compute_embeddings_mlx(chunks, model_name, batch_size=16)
|
||||
elif mode == "openai":
|
||||
return compute_embeddings_openai(chunks, model_name)
|
||||
elif mode == "sentence-transformers":
|
||||
return compute_embeddings_sentence_transformers(
|
||||
chunks, model_name, use_server=use_server
|
||||
)
|
||||
else:
|
||||
# Use direct computation (for build_index)
|
||||
from .embedding_compute import (
|
||||
compute_embeddings as compute_embeddings_direct,
|
||||
)
|
||||
|
||||
return compute_embeddings_direct(
|
||||
chunks,
|
||||
model_name,
|
||||
mode=mode,
|
||||
is_build=is_build,
|
||||
raise ValueError(
|
||||
f"Unsupported embedding mode: {mode}. Supported modes: sentence-transformers, mlx, openai"
|
||||
)
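A minimal call into this dispatcher (illustrative; the model name is just an example) looks like:

from leann.api import compute_embeddings

vectors = compute_embeddings(
    ["hello world", "second chunk"],   # chunks
    "facebook/contriever-msmarco",     # model_name (example)
    "sentence-transformers",           # mode
    use_server=False,                  # direct computation, as during index build
)
print(vectors.shape)                   # (2, embedding_dim)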
|
||||
|
||||
|
||||
def compute_embeddings_via_server(
|
||||
chunks: List[str], model_name: str, port: int
|
||||
def compute_embeddings_sentence_transformers(
|
||||
chunks: List[str], model_name: str, use_server: bool = True
|
||||
) -> np.ndarray:
|
||||
"""Computes embeddings using sentence-transformers.
|
||||
|
||||
Args:
|
||||
chunks: List of text chunks to embed
|
||||
model_name: Name of the sentence transformer model
|
||||
use_server: If True, use embedding server (good for search). If False, use direct computation (good for build).
|
||||
"""
|
||||
logger.info(
|
||||
f"Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}' (via embedding server)..."
|
||||
if not use_server:
|
||||
print(
|
||||
f"INFO: Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}' (direct)..."
|
||||
)
|
||||
return _compute_embeddings_sentence_transformers_direct(chunks, model_name)
|
||||
|
||||
print(
|
||||
f"INFO: Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}' (via embedding server)..."
|
||||
)
|
||||
import zmq
|
||||
import msgpack
|
||||
import numpy as np
|
||||
|
||||
# Connect to embedding server
|
||||
context = zmq.Context()
|
||||
socket = context.socket(zmq.REQ)
|
||||
socket.connect(f"tcp://localhost:{port}")
|
||||
# Use embedding server for sentence-transformers too
|
||||
# This avoids loading the model twice (once in API, once in server)
|
||||
try:
|
||||
# Import ZMQ client functionality and server manager
|
||||
import zmq
|
||||
import msgpack
|
||||
import numpy as np
|
||||
from .embedding_server_manager import EmbeddingServerManager
|
||||
|
||||
# Send chunks to server for embedding computation
|
||||
request = chunks
|
||||
socket.send(msgpack.packb(request))
|
||||
# Ensure embedding server is running
|
||||
port = 5557
|
||||
server_manager = EmbeddingServerManager(
|
||||
backend_module_name="leann_backend_hnsw.hnsw_embedding_server"
|
||||
)
|
||||
|
||||
# Receive embeddings from server
|
||||
response = socket.recv()
|
||||
embeddings_list = msgpack.unpackb(response)
|
||||
server_started = server_manager.start_server(
|
||||
port=port,
|
||||
model_name=model_name,
|
||||
embedding_mode="sentence-transformers",
|
||||
enable_warmup=False,
|
||||
)
|
||||
|
||||
# Convert back to numpy array
|
||||
embeddings = np.array(embeddings_list, dtype=np.float32)
|
||||
if not server_started:
|
||||
raise RuntimeError(f"Failed to start embedding server on port {port}")
|
||||
|
||||
socket.close()
|
||||
context.term()
|
||||
# Connect to embedding server
|
||||
context = zmq.Context()
|
||||
socket = context.socket(zmq.REQ)
|
||||
socket.connect(f"tcp://localhost:{port}")
|
||||
|
||||
# Send chunks to server for embedding computation
|
||||
request = chunks
|
||||
socket.send(msgpack.packb(request))
|
||||
|
||||
# Receive embeddings from server
|
||||
response = socket.recv()
|
||||
embeddings_list = msgpack.unpackb(response)
|
||||
|
||||
# Convert back to numpy array
|
||||
embeddings = np.array(embeddings_list, dtype=np.float32)
|
||||
|
||||
socket.close()
|
||||
context.term()
|
||||
|
||||
return embeddings
|
||||
|
||||
except Exception as e:
|
||||
# Fallback to direct sentence-transformers if server connection fails
|
||||
print(
|
||||
f"Warning: Failed to connect to embedding server, falling back to direct computation: {e}"
|
||||
)
|
||||
return _compute_embeddings_sentence_transformers_direct(chunks, model_name)
|
||||
|
||||
|
||||
def _compute_embeddings_sentence_transformers_direct(
|
||||
chunks: List[str], model_name: str
|
||||
) -> np.ndarray:
|
||||
"""Direct sentence-transformers computation (fallback)."""
|
||||
try:
|
||||
from sentence_transformers import SentenceTransformer
|
||||
except ImportError as e:
|
||||
raise RuntimeError(
|
||||
"sentence-transformers not available. Install with: uv pip install sentence-transformers"
|
||||
) from e
|
||||
|
||||
# Load model using sentence-transformers
|
||||
model = SentenceTransformer(model_name)
|
||||
|
||||
model = model.half()
|
||||
print(
|
||||
f"INFO: Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}' (direct)..."
|
||||
)
|
||||
# Use an accelerator if available: CUDA GPU or Apple GPU (MPS)
|
||||
|
||||
if torch.cuda.is_available():
|
||||
model = model.to("cuda")
|
||||
elif torch.backends.mps.is_available():
|
||||
model = model.to("mps")
|
||||
|
||||
# Generate embeddings
|
||||
# If this runs out of memory (OOM), reduce the batch size
|
||||
embeddings = model.encode(
|
||||
chunks, convert_to_numpy=True, show_progress_bar=True, batch_size=16
|
||||
)
|
||||
|
||||
return embeddings
|
||||
|
||||
|
||||
def compute_embeddings_openai(chunks: List[str], model_name: str) -> np.ndarray:
|
||||
"""Computes embeddings using OpenAI API."""
|
||||
try:
|
||||
import openai
|
||||
import os
|
||||
except ImportError as e:
|
||||
raise RuntimeError(
|
||||
"openai not available. Install with: uv pip install openai"
|
||||
) from e
|
||||
|
||||
# Get API key from environment
|
||||
api_key = os.getenv("OPENAI_API_KEY")
|
||||
if not api_key:
|
||||
raise RuntimeError("OPENAI_API_KEY environment variable not set")
|
||||
|
||||
client = openai.OpenAI(api_key=api_key)
|
||||
|
||||
print(
|
||||
f"INFO: Computing embeddings for {len(chunks)} chunks using OpenAI model '{model_name}'..."
|
||||
)
|
||||
|
||||
# OpenAI has a limit on batch size and input length
|
||||
max_batch_size = 100 # Conservative batch size
|
||||
all_embeddings = []
|
||||
|
||||
try:
|
||||
from tqdm import tqdm
|
||||
total_batches = (len(chunks) + max_batch_size - 1) // max_batch_size
|
||||
batch_range = range(0, len(chunks), max_batch_size)
|
||||
batch_iterator = tqdm(batch_range, desc="Computing embeddings", unit="batch", total=total_batches)
|
||||
except ImportError:
|
||||
# Fallback without progress bar
|
||||
batch_iterator = range(0, len(chunks), max_batch_size)
|
||||
|
||||
for i in batch_iterator:
|
||||
batch_chunks = chunks[i:i + max_batch_size]
|
||||
|
||||
try:
|
||||
response = client.embeddings.create(model=model_name, input=batch_chunks)
|
||||
batch_embeddings = [embedding.embedding for embedding in response.data]
|
||||
all_embeddings.extend(batch_embeddings)
|
||||
except Exception as e:
|
||||
print(f"ERROR: Failed to get embeddings for batch starting at {i}: {e}")
|
||||
raise
|
||||
|
||||
embeddings = np.array(all_embeddings, dtype=np.float32)
|
||||
print(
|
||||
f"INFO: Generated {len(embeddings)} embeddings with dimension {embeddings.shape[1]}"
|
||||
)
|
||||
return embeddings
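Assuming OPENAI_API_KEY is set, the helper above can be exercised directly (the model name below is only an example):

import os
from leann.api import compute_embeddings_openai

assert os.getenv("OPENAI_API_KEY"), "export OPENAI_API_KEY first"
emb = compute_embeddings_openai(["a short test sentence"], "text-embedding-3-small")
print(emb.shape)  # (1, model-dependent dimension)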
|
||||
|
||||
|
||||
def compute_embeddings_mlx(chunks: List[str], model_name: str, batch_size: int = 16) -> np.ndarray:
|
||||
"""Computes embeddings using an MLX model."""
|
||||
try:
|
||||
import mlx.core as mx
|
||||
from mlx_lm.utils import load
|
||||
from tqdm import tqdm
|
||||
except ImportError as e:
|
||||
raise RuntimeError(
|
||||
"MLX or related libraries not available. Install with: uv pip install mlx mlx-lm"
|
||||
) from e
|
||||
|
||||
print(
|
||||
f"INFO: Computing embeddings for {len(chunks)} chunks using MLX model '{model_name}' with batch_size={batch_size}..."
|
||||
)
|
||||
|
||||
# Load model and tokenizer
|
||||
model, tokenizer = load(model_name)
|
||||
|
||||
# Process chunks in batches with progress bar
|
||||
all_embeddings = []
|
||||
|
||||
try:
|
||||
from tqdm import tqdm
|
||||
batch_iterator = tqdm(range(0, len(chunks), batch_size), desc="Computing embeddings", unit="batch")
|
||||
except ImportError:
|
||||
batch_iterator = range(0, len(chunks), batch_size)
|
||||
|
||||
for i in batch_iterator:
|
||||
batch_chunks = chunks[i:i + batch_size]
|
||||
|
||||
# Tokenize all chunks in the batch
|
||||
batch_token_ids = []
|
||||
for chunk in batch_chunks:
|
||||
token_ids = tokenizer.encode(chunk) # type: ignore
|
||||
batch_token_ids.append(token_ids)
|
||||
|
||||
# Pad sequences to the same length for batch processing
|
||||
max_length = max(len(ids) for ids in batch_token_ids)
|
||||
padded_token_ids = []
|
||||
for token_ids in batch_token_ids:
|
||||
# Pad with tokenizer.pad_token_id or 0
|
||||
padded = token_ids + [0] * (max_length - len(token_ids))
|
||||
padded_token_ids.append(padded)
|
||||
|
||||
# Convert to MLX array with batch dimension
|
||||
input_ids = mx.array(padded_token_ids)
|
||||
|
||||
# Get embeddings for the batch
|
||||
embeddings = model(input_ids)
|
||||
|
||||
# Mean pooling for each sequence in the batch
|
||||
pooled = embeddings.mean(axis=1) # Shape: (batch_size, hidden_size)
|
||||
|
||||
# Convert batch embeddings to numpy
|
||||
for j in range(len(batch_chunks)):
|
||||
pooled_list = pooled[j].tolist() # Convert to list
|
||||
pooled_numpy = np.array(pooled_list, dtype=np.float32)
|
||||
all_embeddings.append(pooled_numpy)
|
||||
|
||||
# Stack numpy arrays
|
||||
return np.stack(all_embeddings)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SearchResult:
|
||||
id: str
|
||||
@@ -114,24 +299,25 @@ class PassageManager:
|
||||
self.global_offset_map = {} # Combined map for fast lookup
|
||||
|
||||
for source in passage_sources:
|
||||
assert source["type"] == "jsonl", "only jsonl is supported"
|
||||
passage_file = source["path"]
|
||||
index_file = source["index_path"] # .idx file
|
||||
if not Path(index_file).exists():
|
||||
raise FileNotFoundError(f"Passage index file not found: {index_file}")
|
||||
with open(index_file, "rb") as f:
|
||||
offset_map = pickle.load(f)
|
||||
self.offset_maps[passage_file] = offset_map
|
||||
self.passage_files[passage_file] = passage_file
|
||||
if source["type"] == "jsonl":
|
||||
passage_file = source["path"]
|
||||
index_file = source["index_path"]
|
||||
if not Path(index_file).exists():
|
||||
raise FileNotFoundError(
|
||||
f"Passage index file not found: {index_file}"
|
||||
)
|
||||
with open(index_file, "rb") as f:
|
||||
offset_map = pickle.load(f)
|
||||
self.offset_maps[passage_file] = offset_map
|
||||
self.passage_files[passage_file] = passage_file
|
||||
|
||||
# Build global map for O(1) lookup
|
||||
for passage_id, offset in offset_map.items():
|
||||
self.global_offset_map[passage_id] = (passage_file, offset)
|
||||
# Build global map for O(1) lookup
|
||||
for passage_id, offset in offset_map.items():
|
||||
self.global_offset_map[passage_id] = (passage_file, offset)
|
||||
|
||||
def get_passage(self, passage_id: str) -> Dict[str, Any]:
|
||||
if passage_id in self.global_offset_map:
|
||||
passage_file, offset = self.global_offset_map[passage_id]
|
||||
# Lazy file opening - only open when needed
|
||||
with open(passage_file, "r", encoding="utf-8") as f:
|
||||
f.seek(offset)
|
||||
return json.loads(f.readline())
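The offset maps consumed here are plain pickled dicts mapping passage_id to a byte offset in the JSONL file; a minimal sketch of building and using such an index (illustrative file names):

import json
import pickle

# Build: record each line's byte offset while writing the JSONL file
offset_map = {}
with open("passages.jsonl", "w", encoding="utf-8") as f:
    for chunk in [{"id": "0", "text": "hello"}, {"id": "1", "text": "world"}]:
        offset_map[chunk["id"]] = f.tell()
        json.dump(chunk, f, ensure_ascii=False)
        f.write("\n")
with open("passages.idx", "wb") as f:
    pickle.dump(offset_map, f)

# Lookup: seek straight to the recorded offset and parse a single line
with open("passages.jsonl", "r", encoding="utf-8") as f:
    f.seek(offset_map["1"])
    print(json.loads(f.readline()))  # {'id': '1', 'text': 'world'}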
|
||||
@@ -142,7 +328,7 @@ class LeannBuilder:
|
||||
def __init__(
|
||||
self,
|
||||
backend_name: str,
|
||||
embedding_model: str = "facebook/contriever",
|
||||
embedding_model: str = "facebook/contriever-msmarco",
|
||||
dimensions: Optional[int] = None,
|
||||
embedding_mode: str = "sentence-transformers",
|
||||
**backend_kwargs,
|
||||
@@ -158,12 +344,14 @@ class LeannBuilder:
|
||||
self.dimensions = dimensions
|
||||
self.embedding_mode = embedding_mode
|
||||
self.backend_kwargs = backend_kwargs
|
||||
if 'mlx' in self.embedding_model:
|
||||
self.embedding_mode = "mlx"
|
||||
self.chunks: List[Dict[str, Any]] = []
|
||||
|
||||
def add_text(self, text: str, metadata: Optional[Dict[str, Any]] = None):
|
||||
if metadata is None:
|
||||
metadata = {}
|
||||
passage_id = metadata.get("id", str(len(self.chunks)))
|
||||
passage_id = metadata.get("id", str(uuid.uuid4()))
|
||||
chunk_data = {"id": passage_id, "text": text, "metadata": metadata}
|
||||
self.chunks.append(chunk_data)
|
||||
|
||||
@@ -189,13 +377,10 @@ class LeannBuilder:
|
||||
with open(passages_file, "w", encoding="utf-8") as f:
|
||||
try:
|
||||
from tqdm import tqdm
|
||||
|
||||
chunk_iterator = tqdm(
|
||||
self.chunks, desc="Writing passages", unit="chunk"
|
||||
)
|
||||
chunk_iterator = tqdm(self.chunks, desc="Writing passages", unit="chunk")
|
||||
except ImportError:
|
||||
chunk_iterator = self.chunks
|
||||
|
||||
|
||||
for chunk in chunk_iterator:
|
||||
offset = f.tell()
|
||||
json.dump(
|
||||
@@ -213,11 +398,7 @@ class LeannBuilder:
|
||||
pickle.dump(offset_map, f)
|
||||
texts_to_embed = [c["text"] for c in self.chunks]
|
||||
embeddings = compute_embeddings(
|
||||
texts_to_embed,
|
||||
self.embedding_model,
|
||||
self.embedding_mode,
|
||||
use_server=False,
|
||||
is_build=True,
|
||||
texts_to_embed, self.embedding_model, self.embedding_mode, use_server=False
|
||||
)
|
||||
string_ids = [chunk["id"] for chunk in self.chunks]
|
||||
current_backend_kwargs = {**self.backend_kwargs, "dimensions": self.dimensions}
|
||||
@@ -291,7 +472,7 @@ class LeannBuilder:
|
||||
f"Dimension mismatch: expected {self.dimensions}, got {embedding_dim}"
|
||||
)
|
||||
|
||||
logger.info(
|
||||
print(
|
||||
f"Building index from precomputed embeddings: {len(ids)} items, {embedding_dim} dimensions"
|
||||
)
|
||||
|
||||
@@ -299,7 +480,7 @@ class LeannBuilder:
|
||||
if len(self.chunks) != len(ids):
|
||||
# If no text chunks provided, create placeholder text entries
|
||||
if not self.chunks:
|
||||
logger.info("No text chunks provided, creating placeholder entries...")
|
||||
print("No text chunks provided, creating placeholder entries...")
|
||||
for id_val in ids:
|
||||
self.add_text(
|
||||
f"Document {id_val}",
|
||||
@@ -374,19 +555,15 @@ class LeannBuilder:
|
||||
with open(leann_meta_path, "w", encoding="utf-8") as f:
|
||||
json.dump(meta_data, f, indent=2)
|
||||
|
||||
logger.info(
|
||||
f"Index built successfully from precomputed embeddings: {index_path}"
|
||||
)
|
||||
print(f"Index built successfully from precomputed embeddings: {index_path}")
|
||||
|
||||
|
||||
class LeannSearcher:
|
||||
def __init__(self, index_path: str, enable_warmup: bool = False, **backend_kwargs):
|
||||
self.meta_path_str = f"{index_path}.meta.json"
|
||||
if not Path(self.meta_path_str).exists():
|
||||
raise FileNotFoundError(
|
||||
f"Leann metadata file not found at {self.meta_path_str}"
|
||||
)
|
||||
with open(self.meta_path_str, "r", encoding="utf-8") as f:
|
||||
meta_path_str = f"{index_path}.meta.json"
|
||||
if not Path(meta_path_str).exists():
|
||||
raise FileNotFoundError(f"Leann metadata file not found at {meta_path_str}")
|
||||
with open(meta_path_str, "r", encoding="utf-8") as f:
|
||||
self.meta_data = json.load(f)
|
||||
backend_name = self.meta_data["backend_name"]
|
||||
self.embedding_model = self.meta_data["embedding_model"]
|
||||
@@ -394,15 +571,16 @@ class LeannSearcher:
|
||||
self.embedding_mode = self.meta_data.get(
|
||||
"embedding_mode", "sentence-transformers"
|
||||
)
|
||||
# Backward compatibility with use_mlx
|
||||
if self.meta_data.get("use_mlx", False):
|
||||
self.embedding_mode = "mlx"
|
||||
self.passage_manager = PassageManager(self.meta_data.get("passage_sources", []))
|
||||
backend_factory = BACKEND_REGISTRY.get(backend_name)
|
||||
if backend_factory is None:
|
||||
raise ValueError(f"Backend '{backend_name}' not found.")
|
||||
final_kwargs = {**self.meta_data.get("backend_kwargs", {}), **backend_kwargs}
|
||||
final_kwargs["enable_warmup"] = enable_warmup
|
||||
self.backend_impl: LeannBackendSearcherInterface = backend_factory.searcher(
|
||||
index_path, **final_kwargs
|
||||
)
|
||||
self.backend_impl = backend_factory.searcher(index_path, **final_kwargs)
|
||||
|
||||
def search(
|
||||
self,
|
||||
@@ -411,39 +589,26 @@ class LeannSearcher:
|
||||
complexity: int = 64,
|
||||
beam_width: int = 1,
|
||||
prune_ratio: float = 0.0,
|
||||
recompute_embeddings: bool = True,
|
||||
recompute_embeddings: bool = False,
|
||||
pruning_strategy: Literal["global", "local", "proportional"] = "global",
|
||||
expected_zmq_port: int = 5557,
|
||||
zmq_port: int = 5557,
|
||||
**kwargs,
|
||||
) -> List[SearchResult]:
|
||||
logger.info("🔍 LeannSearcher.search() called:")
|
||||
logger.info(f" Query: '{query}'")
|
||||
logger.info(f" Top_k: {top_k}")
|
||||
logger.info(f" Additional kwargs: {kwargs}")
|
||||
print("🔍 DEBUG LeannSearcher.search() called:")
|
||||
print(f" Query: '{query}'")
|
||||
print(f" Top_k: {top_k}")
|
||||
print(f" Additional kwargs: {kwargs}")
|
||||
|
||||
zmq_port = None
|
||||
|
||||
start_time = time.time()
|
||||
if recompute_embeddings:
|
||||
zmq_port = self.backend_impl._ensure_server_running(
|
||||
self.meta_path_str,
|
||||
port=expected_zmq_port,
|
||||
**kwargs,
|
||||
)
|
||||
del expected_zmq_port
|
||||
zmq_time = time.time() - start_time
|
||||
logger.info(f" Launching server time: {zmq_time} seconds")
|
||||
# Use backend's compute_query_embedding method
|
||||
# This will automatically use embedding server if available and needed
|
||||
import time
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
query_embedding = self.backend_impl.compute_query_embedding(
|
||||
query,
|
||||
use_server_if_available=recompute_embeddings,
|
||||
zmq_port=zmq_port,
|
||||
)
|
||||
# logger.info(f" Generated embedding shape: {query_embedding.shape}")
|
||||
query_embedding = self.backend_impl.compute_query_embedding(query, zmq_port)
|
||||
print(f" Generated embedding shape: {query_embedding.shape}")
|
||||
embedding_time = time.time() - start_time
|
||||
# logger.info(f" Embedding time: {embedding_time} seconds")
|
||||
print(f" Embedding time: {embedding_time} seconds")
|
||||
|
||||
start_time = time.time()
|
||||
results = self.backend_impl.search(
|
||||
@@ -458,14 +623,14 @@ class LeannSearcher:
|
||||
**kwargs,
|
||||
)
|
||||
search_time = time.time() - start_time
|
||||
# logger.info(f" Search time: {search_time} seconds")
|
||||
logger.info(
|
||||
print(f" Search time: {search_time} seconds")
|
||||
print(
|
||||
f" Backend returned: labels={len(results.get('labels', [[]])[0])} results"
|
||||
)
|
||||
|
||||
enriched_results = []
|
||||
if "labels" in results and "distances" in results:
|
||||
logger.info(f" Processing {len(results['labels'][0])} passage IDs:")
|
||||
print(f" Processing {len(results['labels'][0])} passage IDs:")
|
||||
for i, (string_id, dist) in enumerate(
|
||||
zip(results["labels"][0], results["distances"][0])
|
||||
):
|
||||
@@ -479,25 +644,15 @@ class LeannSearcher:
|
||||
metadata=passage_data.get("metadata", {}),
|
||||
)
|
||||
)
|
||||
|
||||
# Color codes for better logging
|
||||
GREEN = "\033[92m"
|
||||
BLUE = "\033[94m"
|
||||
YELLOW = "\033[93m"
|
||||
RESET = "\033[0m"
|
||||
|
||||
# Truncate text for display (first 100 chars)
|
||||
display_text = passage_data['text']
|
||||
logger.info(
|
||||
f" {GREEN}✓{RESET} {BLUE}[{i + 1:2d}]{RESET} {YELLOW}ID:{RESET} '{string_id}' {YELLOW}Score:{RESET} {dist:.4f} {YELLOW}Text:{RESET} {display_text}"
|
||||
print(
|
||||
f" {i + 1}. passage_id='{string_id}' -> SUCCESS: {passage_data['text']}..."
|
||||
)
|
||||
except KeyError:
|
||||
RED = "\033[91m"
|
||||
logger.error(
|
||||
f" {RED}✗{RESET} [{i + 1:2d}] ID: '{string_id}' -> {RED}ERROR: Passage not found!{RESET}"
|
||||
print(
|
||||
f" {i + 1}. passage_id='{string_id}' -> ERROR: Passage not found in PassageManager!"
|
||||
)
|
||||
|
||||
logger.info(f" {GREEN}✓ Final enriched results: {len(enriched_results)} passages{RESET}")
|
||||
print(f" Final enriched results: {len(enriched_results)} passages")
|
||||
return enriched_results
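End to end, the searcher above is typically driven like this (the index path is a placeholder):

from leann.api import LeannSearcher

searcher = LeannSearcher("indexes/demo.leann")   # expects indexes/demo.leann.meta.json next to the index
results = searcher.search("how does leann prune the hnsw graph?", top_k=5)
for r in results:
    print(r.id, r.text[:80])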
|
||||
|
||||
|
||||
@@ -519,15 +674,15 @@ class LeannChat:
|
||||
complexity: int = 64,
|
||||
beam_width: int = 1,
|
||||
prune_ratio: float = 0.0,
|
||||
recompute_embeddings: bool = True,
|
||||
recompute_embeddings: bool = False,
|
||||
pruning_strategy: Literal["global", "local", "proportional"] = "global",
|
||||
zmq_port: int = 5557,
|
||||
llm_kwargs: Optional[Dict[str, Any]] = None,
|
||||
expected_zmq_port: int = 5557,
|
||||
**search_kwargs,
|
||||
):
|
||||
if llm_kwargs is None:
|
||||
llm_kwargs = {}
|
||||
search_time = time.time()
|
||||
|
||||
results = self.searcher.search(
|
||||
question,
|
||||
top_k=top_k,
|
||||
@@ -536,11 +691,9 @@ class LeannChat:
|
||||
prune_ratio=prune_ratio,
|
||||
recompute_embeddings=recompute_embeddings,
|
||||
pruning_strategy=pruning_strategy,
|
||||
expected_zmq_port=expected_zmq_port,
|
||||
zmq_port=zmq_port,
|
||||
**search_kwargs,
|
||||
)
|
||||
search_time = time.time() - search_time
|
||||
# logger.info(f" Search time: {search_time} seconds")
|
||||
context = "\n\n".join([r.text for r in results])
|
||||
prompt = (
|
||||
"Here is some retrieved context that might help answer your question:\n\n"
|
||||
|
||||
@@ -9,7 +9,6 @@ from typing import Dict, Any, Optional, List
|
||||
import logging
|
||||
import os
|
||||
import difflib
|
||||
import torch
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
@@ -29,68 +28,6 @@ def check_ollama_models() -> List[str]:
|
||||
return []
|
||||
|
||||
|
||||
def check_ollama_model_exists_remotely(model_name: str) -> tuple[bool, list[str]]:
|
||||
"""Check if a model exists in Ollama's remote library and return available tags
|
||||
|
||||
Returns:
|
||||
(model_exists, available_tags): bool and list of matching tags
|
||||
"""
|
||||
try:
|
||||
import requests
|
||||
import re
|
||||
|
||||
# Split model name and tag
|
||||
if ':' in model_name:
|
||||
base_model, requested_tag = model_name.split(':', 1)
|
||||
else:
|
||||
base_model, requested_tag = model_name, None
|
||||
|
||||
# First check if base model exists in library
|
||||
library_response = requests.get("https://ollama.com/library", timeout=8)
|
||||
if library_response.status_code != 200:
|
||||
return True, [] # Assume exists if can't check
|
||||
|
||||
# Extract model names from library page
|
||||
models_in_library = re.findall(r'href="/library/([^"]+)"', library_response.text)
|
||||
|
||||
if base_model not in models_in_library:
|
||||
return False, [] # Base model doesn't exist
|
||||
|
||||
# If base model exists, get available tags
|
||||
tags_response = requests.get(f"https://ollama.com/library/{base_model}/tags", timeout=8)
|
||||
if tags_response.status_code != 200:
|
||||
return True, [] # Base model exists but can't get tags
|
||||
|
||||
# Extract tags for this model - be more specific to avoid HTML artifacts
|
||||
tag_pattern = rf'{re.escape(base_model)}:[a-zA-Z0-9\.\-_]+'
|
||||
raw_tags = re.findall(tag_pattern, tags_response.text)
|
||||
|
||||
# Clean up tags - remove HTML artifacts and duplicates
|
||||
available_tags = []
|
||||
seen = set()
|
||||
for tag in raw_tags:
|
||||
# Skip if it looks like HTML (contains < or >)
|
||||
if '<' in tag or '>' in tag:
|
||||
continue
|
||||
if tag not in seen:
|
||||
seen.add(tag)
|
||||
available_tags.append(tag)
|
||||
|
||||
# Check if exact model exists
|
||||
if requested_tag is None:
|
||||
# User just requested base model, suggest tags
|
||||
return True, available_tags[:10] # Return up to 10 tags
|
||||
else:
|
||||
exact_match = model_name in available_tags
|
||||
return exact_match, available_tags[:10]
|
||||
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# If scraping fails, assume model might exist (don't block user)
|
||||
return True, []
|
||||
|
||||
|
||||
def search_ollama_models_fuzzy(query: str, available_models: List[str]) -> List[str]:
|
||||
"""Use intelligent fuzzy search for Ollama models"""
|
||||
if not available_models:
|
||||
@@ -306,66 +243,24 @@ def validate_model_and_suggest(model_name: str, llm_type: str) -> Optional[str]:
|
||||
if llm_type == "ollama":
|
||||
available_models = check_ollama_models()
|
||||
if available_models and model_name not in available_models:
|
||||
# Use intelligent fuzzy search based on locally installed models
|
||||
suggestions = search_ollama_models_fuzzy(model_name, available_models)
|
||||
|
||||
error_msg = f"Model '{model_name}' not found in your local Ollama installation."
|
||||
|
||||
# Check if the model exists remotely and get available tags
|
||||
model_exists_remotely, available_tags = check_ollama_model_exists_remotely(model_name)
|
||||
|
||||
if model_exists_remotely and model_name in available_tags:
|
||||
# Exact model exists remotely - suggest pulling it
|
||||
error_msg += f"\n\nTo install the requested model:\n"
|
||||
error_msg += f" ollama pull {model_name}\n"
|
||||
|
||||
# Show local alternatives
|
||||
suggestions = search_ollama_models_fuzzy(model_name, available_models)
|
||||
if suggestions:
|
||||
error_msg += "\nOr use one of these similar installed models:\n"
|
||||
for i, suggestion in enumerate(suggestions, 1):
|
||||
error_msg += f" {i}. {suggestion}\n"
|
||||
|
||||
elif model_exists_remotely and available_tags:
|
||||
# Base model exists but requested tag doesn't - suggest correct tags
|
||||
base_model = model_name.split(':')[0]
|
||||
requested_tag = model_name.split(':', 1)[1] if ':' in model_name else None
|
||||
|
||||
error_msg += f"\n\nModel '{base_model}' exists, but tag '{requested_tag}' is not available."
|
||||
error_msg += f"\n\nAvailable {base_model} models you can install:\n"
|
||||
for i, tag in enumerate(available_tags[:8], 1):
|
||||
error_msg += f" {i}. ollama pull {tag}\n"
|
||||
if len(available_tags) > 8:
|
||||
error_msg += f" ... and {len(available_tags) - 8} more variants\n"
|
||||
|
||||
# Also show local alternatives
|
||||
suggestions = search_ollama_models_fuzzy(model_name, available_models)
|
||||
if suggestions:
|
||||
error_msg += "\nOr use one of these similar installed models:\n"
|
||||
for i, suggestion in enumerate(suggestions, 1):
|
||||
error_msg += f" {i}. {suggestion}\n"
|
||||
|
||||
if suggestions:
|
||||
error_msg += "\n\nDid you mean one of these installed models?\n"
|
||||
for i, suggestion in enumerate(suggestions, 1):
|
||||
error_msg += f" {i}. {suggestion}\n"
|
||||
else:
|
||||
# Model doesn't exist remotely - show fuzzy suggestions
|
||||
suggestions = search_ollama_models_fuzzy(model_name, available_models)
|
||||
error_msg += f"\n\nModel '{model_name}' was not found in Ollama's library."
|
||||
|
||||
if suggestions:
|
||||
error_msg += "\n\nDid you mean one of these installed models?\n"
|
||||
for i, suggestion in enumerate(suggestions, 1):
|
||||
error_msg += f" {i}. {suggestion}\n"
|
||||
else:
|
||||
error_msg += "\n\nYour installed models:\n"
|
||||
for i, model in enumerate(available_models[:8], 1):
|
||||
error_msg += f" {i}. {model}\n"
|
||||
if len(available_models) > 8:
|
||||
error_msg += f" ... and {len(available_models) - 8} more\n"
|
||||
error_msg += "\n\nYour installed models:\n"
|
||||
for i, model in enumerate(available_models[:8], 1):
|
||||
error_msg += f" {i}. {model}\n"
|
||||
if len(available_models) > 8:
|
||||
error_msg += f" ... and {len(available_models) - 8} more\n"
|
||||
|
||||
error_msg += "\n\nCommands:"
|
||||
error_msg += "\n ollama list # List installed models"
|
||||
if model_exists_remotely and available_tags:
|
||||
if model_name in available_tags:
|
||||
error_msg += f"\n ollama pull {model_name} # Install requested model"
|
||||
else:
|
||||
error_msg += f"\n ollama pull {available_tags[0]} # Install recommended variant"
|
||||
error_msg += "\n https://ollama.com/library # Browse available models"
|
||||
error_msg += "\nTo list all models: ollama list"
|
||||
error_msg += "\nTo download a new model: ollama pull <model_name>"
|
||||
error_msg += "\nBrowse models: https://ollama.com/library"
|
||||
return error_msg
elif llm_type == "hf":
|
||||
@@ -480,9 +375,8 @@ class OllamaChat(LLMInterface):
|
||||
"stream": False, # Keep it simple for now
|
||||
"options": kwargs,
|
||||
}
|
||||
logger.debug(f"Sending request to Ollama: {payload}")
|
||||
logger.info(f"Sending request to Ollama: {payload}")
|
||||
try:
|
||||
logger.info(f"Sending request to Ollama and waiting for response...")
|
||||
response = requests.post(full_url, data=json.dumps(payload))
|
||||
response.raise_for_status()
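# Illustrative sketch (not from the diff): a standalone non-streaming request with
# the same payload shape used above, assuming Ollama's /api/generate endpoint.
# Host, model name, and options are examples only.
import json
import requests

payload = {
    "model": "qwen3:8b",
    "prompt": "Say hello in one sentence.",
    "stream": False,  # single JSON response instead of a token stream
    "options": {"temperature": 0.2, "num_predict": 64},
}
resp = requests.post("http://localhost:11434/api/generate", data=json.dumps(payload), timeout=120)
resp.raise_for_status()
print(resp.json().get("response", ""))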
@@ -502,7 +396,7 @@ class OllamaChat(LLMInterface):
|
||||
|
||||
|
||||
class HFChat(LLMInterface):
|
||||
"""LLM interface for local Hugging Face Transformers models with proper chat templates."""
|
||||
"""LLM interface for local Hugging Face Transformers models."""
|
||||
|
||||
def __init__(self, model_name: str = "deepseek-ai/deepseek-llm-7b-chat"):
|
||||
logger.info(f"Initializing HFChat with model='{model_name}'")
|
||||
@@ -513,7 +407,7 @@ class HFChat(LLMInterface):
|
||||
raise ValueError(model_error)
|
||||
|
||||
try:
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
from transformers.pipelines import pipeline
|
||||
import torch
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
@@ -522,101 +416,54 @@ class HFChat(LLMInterface):
|
||||
|
||||
# Auto-detect device
|
||||
if torch.cuda.is_available():
|
||||
self.device = "cuda"
|
||||
device = "cuda"
|
||||
logger.info("CUDA is available. Using GPU.")
|
||||
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
|
||||
self.device = "mps"
|
||||
device = "mps"
|
||||
logger.info("MPS is available. Using Apple Silicon GPU.")
|
||||
else:
|
||||
self.device = "cpu"
|
||||
device = "cpu"
|
||||
logger.info("No GPU detected. Using CPU.")
|
||||
|
||||
# Load tokenizer and model
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
self.model = AutoModelForCausalLM.from_pretrained(
|
||||
model_name,
|
||||
torch_dtype=torch.float16 if self.device != "cpu" else torch.float32,
|
||||
device_map="auto" if self.device != "cpu" else None,
|
||||
trust_remote_code=True
|
||||
)
|
||||
|
||||
# Move model to device if not using device_map
|
||||
if self.device != "cpu" and "device_map" not in str(self.model):
|
||||
self.model = self.model.to(self.device)
|
||||
|
||||
# Set pad token if not present
|
||||
if self.tokenizer.pad_token is None:
|
||||
self.tokenizer.pad_token = self.tokenizer.eos_token
|
||||
self.pipeline = pipeline("text-generation", model=model_name, device=device)
|
||||
|
||||
def ask(self, prompt: str, **kwargs) -> str:
|
||||
print('kwargs in HF: ', kwargs)
|
||||
# Check if this is a Qwen model and add /no_think by default
|
||||
is_qwen_model = "qwen" in self.model.config._name_or_path.lower()
|
||||
|
||||
# For Qwen models, automatically add /no_think to the prompt
|
||||
if is_qwen_model and "/no_think" not in prompt and "/think" not in prompt:
|
||||
prompt = prompt + " /no_think"
|
||||
|
||||
# Prepare chat template
|
||||
messages = [{"role": "user", "content": prompt}]
|
||||
|
||||
# Apply chat template if available
|
||||
if hasattr(self.tokenizer, "apply_chat_template"):
|
||||
try:
|
||||
formatted_prompt = self.tokenizer.apply_chat_template(
|
||||
messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Chat template failed, using raw prompt: {e}")
|
||||
formatted_prompt = prompt
|
||||
else:
|
||||
# Fallback for models without chat template
|
||||
formatted_prompt = prompt
|
||||
# Map OpenAI-style arguments to Hugging Face equivalents
|
||||
if "max_tokens" in kwargs:
|
||||
# Prefer user-provided max_new_tokens if both are present
|
||||
kwargs.setdefault("max_new_tokens", kwargs["max_tokens"])
|
||||
# Remove the unsupported key to avoid errors in Transformers
|
||||
kwargs.pop("max_tokens")
|
||||
|
||||
# Tokenize input
|
||||
inputs = self.tokenizer(
|
||||
formatted_prompt,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
truncation=True,
|
||||
max_length=2048
|
||||
)
|
||||
|
||||
# Move inputs to device
|
||||
if self.device != "cpu":
|
||||
inputs = {k: v.to(self.device) for k, v in inputs.items()}
|
||||
# Handle temperature=0 edge-case for greedy decoding
|
||||
if "temperature" in kwargs and kwargs["temperature"] == 0.0:
|
||||
# Remove unsupported zero temperature and use deterministic generation
|
||||
kwargs.pop("temperature")
|
||||
kwargs.setdefault("do_sample", False)
|
||||
|
||||
# Set generation parameters
|
||||
generation_config = {
|
||||
"max_new_tokens": kwargs.get("max_tokens", kwargs.get("max_new_tokens", 512)),
|
||||
"temperature": kwargs.get("temperature", 0.7),
|
||||
"top_p": kwargs.get("top_p", 0.9),
|
||||
"do_sample": kwargs.get("temperature", 0.7) > 0,
|
||||
"pad_token_id": self.tokenizer.eos_token_id,
|
||||
"eos_token_id": self.tokenizer.eos_token_id,
|
||||
}
|
||||
|
||||
# Handle temperature=0 for greedy decoding
|
||||
if generation_config["temperature"] == 0.0:
|
||||
generation_config["do_sample"] = False
|
||||
generation_config.pop("temperature")
|
||||
# Sensible defaults for text generation
|
||||
params = {"max_length": 500, "num_return_sequences": 1, **kwargs}
|
||||
logger.info(f"Generating text with Hugging Face model with params: {params}")
|
||||
results = self.pipeline(prompt, **params)
|
||||
|
||||
logger.info(f"Generating with HuggingFace model, config: {generation_config}")
|
||||
|
||||
# Generate
|
||||
with torch.no_grad():
|
||||
outputs = self.model.generate(
|
||||
**inputs,
|
||||
**generation_config
|
||||
# Handle different response formats from transformers
|
||||
if isinstance(results, list) and len(results) > 0:
|
||||
generated_text = (
|
||||
results[0].get("generated_text", "")
|
||||
if isinstance(results[0], dict)
|
||||
else str(results[0])
|
||||
)
|
||||
else:
|
||||
generated_text = str(results)
|
||||
|
||||
# Decode response
|
||||
generated_tokens = outputs[0][inputs["input_ids"].shape[1]:]
|
||||
response = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
|
||||
|
||||
return response.strip()
|
||||
# Extract only the newly generated portion by removing the original prompt
|
||||
if isinstance(generated_text, str) and generated_text.startswith(prompt):
|
||||
response = generated_text[len(prompt) :].strip()
|
||||
else:
|
||||
# Fallback: return the full response if prompt removal fails
|
||||
response = str(generated_text)
|
||||
|
||||
return response
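# Illustrative sketch (not from the diff): the pipeline-based flow the new HFChat
# uses, in isolation. The model name and generation parameters are examples only;
# the string device argument mirrors the device detection above.
from transformers import pipeline

generator = pipeline("text-generation", model="Qwen/Qwen2.5-0.5B-Instruct", device="cpu")
prompt = "Explain vector search in one sentence."
outputs = generator(prompt, max_new_tokens=64, num_return_sequences=1)
text = outputs[0]["generated_text"]
# The pipeline echoes the prompt, so strip it to keep only the completion.
print(text[len(prompt):].strip() if text.startswith(prompt) else text)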
class OpenAIChat(LLMInterface):
@@ -1,315 +0,0 @@
|
||||
import argparse
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
|
||||
from llama_index.core import SimpleDirectoryReader
|
||||
from llama_index.core.node_parser import SentenceSplitter
|
||||
|
||||
from .api import LeannBuilder, LeannSearcher, LeannChat
|
||||
|
||||
|
||||
class LeannCLI:
|
||||
def __init__(self):
|
||||
self.indexes_dir = Path.home() / ".leann" / "indexes"
|
||||
self.indexes_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self.node_parser = SentenceSplitter(
|
||||
chunk_size=256, chunk_overlap=128, separator=" ", paragraph_separator="\n\n"
|
||||
)
|
||||
|
||||
def get_index_path(self, index_name: str) -> str:
|
||||
index_dir = self.indexes_dir / index_name
|
||||
return str(index_dir / "documents.leann")
|
||||
|
||||
def index_exists(self, index_name: str) -> bool:
|
||||
index_dir = self.indexes_dir / index_name
|
||||
meta_file = index_dir / "documents.leann.meta.json"
|
||||
return meta_file.exists()
|
||||
|
||||
def create_parser(self) -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser(
|
||||
prog="leann",
|
||||
description="LEANN - Local Enhanced AI Navigation",
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
epilog="""
|
||||
Examples:
|
||||
leann build my-docs --docs ./documents # Build index named my-docs
|
||||
leann search my-docs "query" # Search in my-docs index
|
||||
leann ask my-docs "question" # Ask my-docs index
|
||||
leann list # List all stored indexes
|
||||
""",
|
||||
)
|
||||
|
||||
subparsers = parser.add_subparsers(dest="command", help="Available commands")
|
||||
|
||||
# Build command
|
||||
build_parser = subparsers.add_parser("build", help="Build document index")
|
||||
build_parser.add_argument("index_name", help="Index name")
|
||||
build_parser.add_argument(
|
||||
"--docs", type=str, required=True, help="Documents directory"
|
||||
)
|
||||
build_parser.add_argument(
|
||||
"--backend", type=str, default="hnsw", choices=["hnsw", "diskann"]
|
||||
)
|
||||
build_parser.add_argument(
|
||||
"--embedding-model", type=str, default="facebook/contriever"
|
||||
)
|
||||
build_parser.add_argument(
|
||||
"--force", "-f", action="store_true", help="Force rebuild"
|
||||
)
|
||||
build_parser.add_argument("--graph-degree", type=int, default=32)
|
||||
build_parser.add_argument("--complexity", type=int, default=64)
|
||||
build_parser.add_argument("--num-threads", type=int, default=1)
|
||||
build_parser.add_argument("--compact", action="store_true", default=True)
|
||||
build_parser.add_argument("--recompute", action="store_true", default=True)
|
||||
|
||||
# Search command
|
||||
search_parser = subparsers.add_parser("search", help="Search documents")
|
||||
search_parser.add_argument("index_name", help="Index name")
|
||||
search_parser.add_argument("query", help="Search query")
|
||||
search_parser.add_argument("--top-k", type=int, default=5)
|
||||
search_parser.add_argument("--complexity", type=int, default=64)
|
||||
search_parser.add_argument("--beam-width", type=int, default=1)
|
||||
search_parser.add_argument("--prune-ratio", type=float, default=0.0)
|
||||
search_parser.add_argument("--recompute-embeddings", action="store_true")
|
||||
search_parser.add_argument(
|
||||
"--pruning-strategy",
|
||||
choices=["global", "local", "proportional"],
|
||||
default="global",
|
||||
)
|
||||
|
||||
# Ask command
|
||||
ask_parser = subparsers.add_parser("ask", help="Ask questions")
|
||||
ask_parser.add_argument("index_name", help="Index name")
|
||||
ask_parser.add_argument(
|
||||
"--llm",
|
||||
type=str,
|
||||
default="ollama",
|
||||
choices=["simulated", "ollama", "hf", "openai"],
|
||||
)
|
||||
ask_parser.add_argument("--model", type=str, default="qwen3:8b")
|
||||
ask_parser.add_argument("--host", type=str, default="http://localhost:11434")
|
||||
ask_parser.add_argument("--interactive", "-i", action="store_true")
|
||||
ask_parser.add_argument("--top-k", type=int, default=20)
|
||||
ask_parser.add_argument("--complexity", type=int, default=32)
|
||||
ask_parser.add_argument("--beam-width", type=int, default=1)
|
||||
ask_parser.add_argument("--prune-ratio", type=float, default=0.0)
|
||||
ask_parser.add_argument("--recompute-embeddings", action="store_true")
|
||||
ask_parser.add_argument(
|
||||
"--pruning-strategy",
|
||||
choices=["global", "local", "proportional"],
|
||||
default="global",
|
||||
)
|
||||
|
||||
# List command
|
||||
list_parser = subparsers.add_parser("list", help="List all indexes")
|
||||
|
||||
return parser
|
||||
|
||||
def list_indexes(self):
|
||||
print("Stored LEANN indexes:")
|
||||
|
||||
if not self.indexes_dir.exists():
|
||||
print(
|
||||
"No indexes found. Use 'leann build <name> --docs <dir>' to create one."
|
||||
)
|
||||
return
|
||||
|
||||
index_dirs = [d for d in self.indexes_dir.iterdir() if d.is_dir()]
|
||||
|
||||
if not index_dirs:
|
||||
print(
|
||||
"No indexes found. Use 'leann build <name> --docs <dir>' to create one."
|
||||
)
|
||||
return
|
||||
|
||||
print(f"Found {len(index_dirs)} indexes:")
|
||||
for i, index_dir in enumerate(index_dirs, 1):
|
||||
index_name = index_dir.name
|
||||
status = "✓" if self.index_exists(index_name) else "✗"
|
||||
|
||||
print(f" {i}. {index_name} [{status}]")
|
||||
if self.index_exists(index_name):
|
||||
meta_file = index_dir / "documents.leann.meta.json"
|
||||
size_mb = sum(
|
||||
f.stat().st_size for f in index_dir.iterdir() if f.is_file()
|
||||
) / (1024 * 1024)
|
||||
print(f" Size: {size_mb:.1f} MB")
|
||||
|
||||
if index_dirs:
|
||||
example_name = index_dirs[0].name
|
||||
print(f"\nUsage:")
|
||||
print(f' leann search {example_name} "your query"')
|
||||
print(f" leann ask {example_name} --interactive")
|
||||
|
||||
def load_documents(self, docs_dir: str):
|
||||
print(f"Loading documents from {docs_dir}...")
|
||||
|
||||
documents = SimpleDirectoryReader(
|
||||
docs_dir,
|
||||
recursive=True,
|
||||
encoding="utf-8",
|
||||
required_exts=[".pdf", ".txt", ".md", ".docx"],
|
||||
).load_data(show_progress=True)
|
||||
|
||||
all_texts = []
|
||||
for doc in documents:
|
||||
nodes = self.node_parser.get_nodes_from_documents([doc])
|
||||
for node in nodes:
|
||||
all_texts.append(node.get_content())
|
||||
|
||||
print(f"Loaded {len(documents)} documents, {len(all_texts)} chunks")
|
||||
return all_texts
|
||||
|
||||
async def build_index(self, args):
|
||||
docs_dir = args.docs
|
||||
index_name = args.index_name
|
||||
index_dir = self.indexes_dir / index_name
|
||||
index_path = self.get_index_path(index_name)
|
||||
|
||||
if index_dir.exists() and not args.force:
|
||||
print(f"Index '{index_name}' already exists. Use --force to rebuild.")
|
||||
return
|
||||
|
||||
all_texts = self.load_documents(docs_dir)
|
||||
if not all_texts:
|
||||
print("No documents found")
|
||||
return
|
||||
|
||||
index_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
print(f"Building index '{index_name}' with {args.backend} backend...")
|
||||
|
||||
builder = LeannBuilder(
|
||||
backend_name=args.backend,
|
||||
embedding_model=args.embedding_model,
|
||||
graph_degree=args.graph_degree,
|
||||
complexity=args.complexity,
|
||||
is_compact=args.compact,
|
||||
is_recompute=args.recompute,
|
||||
num_threads=args.num_threads,
|
||||
)
|
||||
|
||||
for chunk_text in all_texts:
|
||||
builder.add_text(chunk_text)
|
||||
|
||||
builder.build_index(index_path)
|
||||
print(f"Index built at {index_path}")
|
||||
|
||||
async def search_documents(self, args):
|
||||
index_name = args.index_name
|
||||
query = args.query
|
||||
index_path = self.get_index_path(index_name)
|
||||
|
||||
if not self.index_exists(index_name):
|
||||
print(
|
||||
f"Index '{index_name}' not found. Use 'leann build {index_name} --docs <dir>' to create it."
|
||||
)
|
||||
return
|
||||
|
||||
searcher = LeannSearcher(index_path=index_path)
|
||||
results = searcher.search(
|
||||
query,
|
||||
top_k=args.top_k,
|
||||
complexity=args.complexity,
|
||||
beam_width=args.beam_width,
|
||||
prune_ratio=args.prune_ratio,
|
||||
recompute_embeddings=args.recompute_embeddings,
|
||||
pruning_strategy=args.pruning_strategy,
|
||||
)
|
||||
|
||||
print(f"Search results for '{query}' (top {len(results)}):")
|
||||
for i, result in enumerate(results, 1):
|
||||
print(f"{i}. Score: {result.score:.3f}")
|
||||
print(f" {result.text[:200]}...")
|
||||
print()
|
||||
|
||||
async def ask_questions(self, args):
|
||||
index_name = args.index_name
|
||||
index_path = self.get_index_path(index_name)
|
||||
|
||||
if not self.index_exists(index_name):
|
||||
print(
|
||||
f"Index '{index_name}' not found. Use 'leann build {index_name} --docs <dir>' to create it."
|
||||
)
|
||||
return
|
||||
|
||||
print(f"Starting chat with index '{index_name}'...")
|
||||
print(f"Using {args.model} ({args.llm})")
|
||||
|
||||
llm_config = {"type": args.llm, "model": args.model}
|
||||
if args.llm == "ollama":
|
||||
llm_config["host"] = args.host
|
||||
|
||||
chat = LeannChat(index_path=index_path, llm_config=llm_config)
|
||||
|
||||
if args.interactive:
|
||||
print("LEANN Assistant ready! Type 'quit' to exit")
|
||||
print("=" * 40)
|
||||
|
||||
while True:
|
||||
user_input = input("\nYou: ").strip()
|
||||
if user_input.lower() in ["quit", "exit", "q"]:
|
||||
print("Goodbye!")
|
||||
break
|
||||
|
||||
if not user_input:
|
||||
continue
|
||||
|
||||
response = chat.ask(
|
||||
user_input,
|
||||
top_k=args.top_k,
|
||||
complexity=args.complexity,
|
||||
beam_width=args.beam_width,
|
||||
prune_ratio=args.prune_ratio,
|
||||
recompute_embeddings=args.recompute_embeddings,
|
||||
pruning_strategy=args.pruning_strategy,
|
||||
)
|
||||
print(f"LEANN: {response}")
|
||||
else:
|
||||
query = input("Enter your question: ").strip()
|
||||
if query:
|
||||
response = chat.ask(
|
||||
query,
|
||||
top_k=args.top_k,
|
||||
complexity=args.complexity,
|
||||
beam_width=args.beam_width,
|
||||
prune_ratio=args.prune_ratio,
|
||||
recompute_embeddings=args.recompute_embeddings,
|
||||
pruning_strategy=args.pruning_strategy,
|
||||
)
|
||||
print(f"LEANN: {response}")
|
||||
|
||||
async def run(self, args=None):
|
||||
parser = self.create_parser()
|
||||
|
||||
if args is None:
|
||||
args = parser.parse_args()
|
||||
|
||||
if not args.command:
|
||||
parser.print_help()
|
||||
return
|
||||
|
||||
if args.command == "list":
|
||||
self.list_indexes()
|
||||
elif args.command == "build":
|
||||
await self.build_index(args)
|
||||
elif args.command == "search":
|
||||
await self.search_documents(args)
|
||||
elif args.command == "ask":
|
||||
await self.ask_questions(args)
|
||||
else:
|
||||
parser.print_help()
|
||||
|
||||
|
||||
def main():
|
||||
import dotenv
|
||||
|
||||
dotenv.load_dotenv()
|
||||
|
||||
cli = LeannCLI()
|
||||
asyncio.run(cli.run())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
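# Illustrative sketch (not part of the deleted CLI): the same build/search flow the
# CLI wraps, called directly through the API. The import path, file paths, and
# parameters are assumptions; only arguments that appear in the CLI code above are used.
from leann.api import LeannBuilder, LeannSearcher  # import path is an assumption

builder = LeannBuilder(backend_name="hnsw", embedding_model="facebook/contriever")
for chunk in ["LEANN builds a vector index.", "Search happens over text chunks."]:
    builder.add_text(chunk)
builder.build_index("/tmp/demo/documents.leann")

searcher = LeannSearcher(index_path="/tmp/demo/documents.leann")
for result in searcher.search("vector index", top_k=2):
    print(f"{result.score:.3f}  {result.text}")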
@@ -1,375 +0,0 @@
|
||||
"""
|
||||
Unified embedding computation module
|
||||
Consolidates all embedding computation logic using SentenceTransformer
|
||||
Preserves all optimization parameters to ensure performance
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from typing import List, Dict, Any
|
||||
import logging
|
||||
import os
|
||||
|
||||
# Set up logger with proper level
|
||||
logger = logging.getLogger(__name__)
|
||||
LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
|
||||
log_level = getattr(logging, LOG_LEVEL, logging.WARNING)
|
||||
logger.setLevel(log_level)
|
||||
|
||||
# Global model cache to avoid repeated loading
|
||||
_model_cache: Dict[str, Any] = {}
|
||||
|
||||
|
||||
def compute_embeddings(
|
||||
texts: List[str],
|
||||
model_name: str,
|
||||
mode: str = "sentence-transformers",
|
||||
is_build: bool = False,
|
||||
batch_size: int = 32,
|
||||
adaptive_optimization: bool = True,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Unified embedding computation entry point
|
||||
|
||||
Args:
|
||||
texts: List of texts to compute embeddings for
|
||||
model_name: Model name
|
||||
mode: Computation mode ('sentence-transformers', 'openai', 'mlx')
|
||||
is_build: Whether this is a build operation (shows progress bar)
|
||||
batch_size: Batch size for processing
|
||||
adaptive_optimization: Whether to use adaptive optimization based on batch size
|
||||
|
||||
Returns:
|
||||
Normalized embeddings array, shape: (len(texts), embedding_dim)
|
||||
"""
|
||||
if mode == "sentence-transformers":
|
||||
return compute_embeddings_sentence_transformers(
|
||||
texts,
|
||||
model_name,
|
||||
is_build=is_build,
|
||||
batch_size=batch_size,
|
||||
adaptive_optimization=adaptive_optimization,
|
||||
)
|
||||
elif mode == "openai":
|
||||
return compute_embeddings_openai(texts, model_name)
|
||||
elif mode == "mlx":
|
||||
return compute_embeddings_mlx(texts, model_name)
|
||||
else:
|
||||
raise ValueError(f"Unsupported embedding mode: {mode}")
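# Illustrative sketch (not from the diff): calling the unified entry point in its
# default sentence-transformers mode. The model name is an example.
embeddings = compute_embeddings(
    ["first passage", "second passage"],
    model_name="facebook/contriever",
    mode="sentence-transformers",
    is_build=True,  # show the progress bar, as during an index build
    batch_size=8,
)
print(embeddings.shape)  # (2, embedding_dim)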
def compute_embeddings_sentence_transformers(
|
||||
texts: List[str],
|
||||
model_name: str,
|
||||
use_fp16: bool = True,
|
||||
device: str = "auto",
|
||||
batch_size: int = 32,
|
||||
is_build: bool = False,
|
||||
adaptive_optimization: bool = True,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Compute embeddings using SentenceTransformer with model caching and adaptive optimization
|
||||
|
||||
Args:
|
||||
texts: List of texts to compute embeddings for
|
||||
model_name: Model name
|
||||
use_fp16: Whether to use FP16 precision
|
||||
device: Device to use ('auto', 'cuda', 'mps', 'cpu')
|
||||
batch_size: Batch size for processing
|
||||
is_build: Whether this is a build operation (shows progress bar)
|
||||
adaptive_optimization: Whether to use adaptive optimization based on batch size
|
||||
"""
|
||||
# Handle empty input
|
||||
if not texts:
|
||||
raise ValueError("Cannot compute embeddings for empty text list")
|
||||
logger.info(
|
||||
f"Computing embeddings for {len(texts)} texts using SentenceTransformer, model: '{model_name}'"
|
||||
)
|
||||
|
||||
# Auto-detect device
|
||||
if device == "auto":
|
||||
if torch.cuda.is_available():
|
||||
device = "cuda"
|
||||
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
|
||||
device = "mps"
|
||||
else:
|
||||
device = "cpu"
|
||||
|
||||
# Apply optimizations based on benchmark results
|
||||
if adaptive_optimization:
|
||||
# Use optimal batch_size constants for different devices based on benchmark results
|
||||
if device == "mps":
|
||||
batch_size = 128 # MPS optimal batch size from benchmark
|
||||
if model_name == "Qwen/Qwen3-Embedding-0.6B":
|
||||
batch_size = 32
|
||||
elif device == "cuda":
|
||||
batch_size = 256 # CUDA optimal batch size
|
||||
# Keep original batch_size for CPU
|
||||
|
||||
# Create cache key
|
||||
cache_key = f"sentence_transformers_{model_name}_{device}_{use_fp16}_optimized"
|
||||
|
||||
# Check if model is already cached
|
||||
if cache_key in _model_cache:
|
||||
logger.info(f"Using cached optimized model: {model_name}")
|
||||
model = _model_cache[cache_key]
|
||||
else:
|
||||
logger.info(
|
||||
f"Loading and caching optimized SentenceTransformer model: {model_name}"
|
||||
)
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
logger.info(f"Using device: {device}")
|
||||
|
||||
# Apply hardware optimizations
|
||||
if device == "cuda":
|
||||
# TODO: Haven't tested this yet
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
torch.backends.cudnn.allow_tf32 = True
|
||||
torch.backends.cudnn.benchmark = True
|
||||
torch.backends.cudnn.deterministic = False
|
||||
torch.cuda.set_per_process_memory_fraction(0.9)
|
||||
elif device == "mps":
|
||||
try:
|
||||
if hasattr(torch.mps, "set_per_process_memory_fraction"):
|
||||
torch.mps.set_per_process_memory_fraction(0.9)
|
||||
except AttributeError:
|
||||
logger.warning(
|
||||
"Some MPS optimizations not available in this PyTorch version"
|
||||
)
|
||||
elif device == "cpu":
|
||||
# TODO: Haven't tested this yet
|
||||
torch.set_num_threads(min(8, os.cpu_count() or 4))
|
||||
try:
|
||||
torch.backends.mkldnn.enabled = True
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
# Prepare optimized model and tokenizer parameters
|
||||
model_kwargs = {
|
||||
"torch_dtype": torch.float16 if use_fp16 else torch.float32,
|
||||
"low_cpu_mem_usage": True,
|
||||
"_fast_init": True,
|
||||
"attn_implementation": "eager", # Use eager attention for speed
|
||||
}
|
||||
|
||||
tokenizer_kwargs = {
|
||||
"use_fast": True,
|
||||
"padding": True,
|
||||
"truncation": True,
|
||||
}
|
||||
|
||||
try:
|
||||
# Try local loading first
|
||||
model_kwargs["local_files_only"] = True
|
||||
tokenizer_kwargs["local_files_only"] = True
|
||||
|
||||
model = SentenceTransformer(
|
||||
model_name,
|
||||
device=device,
|
||||
model_kwargs=model_kwargs,
|
||||
tokenizer_kwargs=tokenizer_kwargs,
|
||||
local_files_only=True,
|
||||
)
|
||||
logger.info("Model loaded successfully! (local + optimized)")
|
||||
except Exception as e:
|
||||
logger.warning(f"Local loading failed ({e}), trying network download...")
|
||||
# Fallback to network loading
|
||||
model_kwargs["local_files_only"] = False
|
||||
tokenizer_kwargs["local_files_only"] = False
|
||||
|
||||
model = SentenceTransformer(
|
||||
model_name,
|
||||
device=device,
|
||||
model_kwargs=model_kwargs,
|
||||
tokenizer_kwargs=tokenizer_kwargs,
|
||||
local_files_only=False,
|
||||
)
|
||||
logger.info("Model loaded successfully! (network + optimized)")
|
||||
|
||||
# Apply additional optimizations based on mode
|
||||
if use_fp16 and device in ["cuda", "mps"]:
|
||||
try:
|
||||
model = model.half()
|
||||
logger.info(f"Applied FP16 precision: {model_name}")
|
||||
except Exception as e:
|
||||
logger.warning(f"FP16 optimization failed: {e}")
|
||||
|
||||
# Apply torch.compile optimization
|
||||
if device in ["cuda", "mps"]:
|
||||
try:
|
||||
model = torch.compile(model, mode="reduce-overhead", dynamic=True)
|
||||
logger.info(f"Applied torch.compile optimization: {model_name}")
|
||||
except Exception as e:
|
||||
logger.warning(f"torch.compile optimization failed: {e}")
|
||||
|
||||
# Set model to eval mode and disable gradients for inference
|
||||
model.eval()
|
||||
for param in model.parameters():
|
||||
param.requires_grad_(False)
|
||||
|
||||
# Cache the model
|
||||
_model_cache[cache_key] = model
|
||||
logger.info(f"Model cached: {cache_key}")
|
||||
|
||||
# Compute embeddings with optimized inference mode
|
||||
logger.info(f"Starting embedding computation... (batch_size: {batch_size})")
|
||||
|
||||
# Use torch.inference_mode for optimal performance
|
||||
with torch.inference_mode():
|
||||
embeddings = model.encode(
|
||||
texts,
|
||||
batch_size=batch_size,
|
||||
show_progress_bar=is_build,  # Only show the progress bar during index builds, not in the server environment
|
||||
convert_to_numpy=True,
|
||||
normalize_embeddings=False,
|
||||
device=device,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Generated {len(embeddings)} embeddings, dimension: {embeddings.shape[1]}"
|
||||
)
|
||||
|
||||
# Validate results
|
||||
if np.isnan(embeddings).any() or np.isinf(embeddings).any():
|
||||
raise RuntimeError(
|
||||
f"Detected NaN or Inf values in embeddings, model: {model_name}"
|
||||
)
|
||||
|
||||
return embeddings
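# Illustrative sketch (not from the diff): repeated calls with the same model reuse
# the cached SentenceTransformer, so only the first call pays the model-load cost.
import time

for attempt in range(2):
    start = time.perf_counter()
    compute_embeddings_sentence_transformers(["warm-up text"], "facebook/contriever")
    print(f"call {attempt + 1}: {time.perf_counter() - start:.2f}s")
# The second call is expected to be much faster because the model comes from _model_cache.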
def compute_embeddings_openai(texts: List[str], model_name: str) -> np.ndarray:
|
||||
# TODO: @yichuan-w add progress bar only in build mode
|
||||
"""Compute embeddings using OpenAI API"""
|
||||
try:
|
||||
import openai
|
||||
import os
|
||||
except ImportError as e:
|
||||
raise ImportError(f"OpenAI package not installed: {e}")
|
||||
|
||||
api_key = os.getenv("OPENAI_API_KEY")
|
||||
if not api_key:
|
||||
raise RuntimeError("OPENAI_API_KEY environment variable not set")
|
||||
|
||||
# Cache OpenAI client
|
||||
cache_key = "openai_client"
|
||||
if cache_key in _model_cache:
|
||||
client = _model_cache[cache_key]
|
||||
else:
|
||||
client = openai.OpenAI(api_key=api_key)
|
||||
_model_cache[cache_key] = client
|
||||
logger.info("OpenAI client cached")
|
||||
|
||||
logger.info(
|
||||
f"Computing embeddings for {len(texts)} texts using OpenAI API, model: '{model_name}'"
|
||||
)
|
||||
|
||||
# OpenAI has limits on batch size and input length
|
||||
max_batch_size = 100 # Conservative batch size
|
||||
all_embeddings = []
|
||||
|
||||
try:
|
||||
from tqdm import tqdm
|
||||
|
||||
total_batches = (len(texts) + max_batch_size - 1) // max_batch_size
|
||||
batch_range = range(0, len(texts), max_batch_size)
|
||||
batch_iterator = tqdm(
|
||||
batch_range, desc="Computing embeddings", unit="batch", total=total_batches
|
||||
)
|
||||
except ImportError:
|
||||
# Fallback when tqdm is not available
|
||||
batch_iterator = range(0, len(texts), max_batch_size)
|
||||
|
||||
for i in batch_iterator:
|
||||
batch_texts = texts[i : i + max_batch_size]
|
||||
|
||||
try:
|
||||
response = client.embeddings.create(model=model_name, input=batch_texts)
|
||||
batch_embeddings = [embedding.embedding for embedding in response.data]
|
||||
all_embeddings.extend(batch_embeddings)
|
||||
except Exception as e:
|
||||
logger.error(f"Batch {i} failed: {e}")
|
||||
raise
|
||||
|
||||
embeddings = np.array(all_embeddings, dtype=np.float32)
|
||||
logger.info(
|
||||
f"Generated {len(embeddings)} embeddings, dimension: {embeddings.shape[1]}"
|
||||
)
|
||||
return embeddings
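# Illustrative sketch (not from the diff): the OpenAI path reads OPENAI_API_KEY from
# the environment; the model name below is only an example.
vectors = compute_embeddings_openai(
    ["what is a vector index?", "how does pruning work?"],
    model_name="text-embedding-3-small",
)
print(vectors.shape, vectors.dtype)  # (2, embedding_dim) float32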
def compute_embeddings_mlx(
|
||||
chunks: List[str], model_name: str, batch_size: int = 16
|
||||
) -> np.ndarray:
|
||||
# TODO: @yichuan-w add progress bar only in build mode
|
||||
"""Computes embeddings using an MLX model."""
|
||||
try:
|
||||
import mlx.core as mx
|
||||
from mlx_lm.utils import load
|
||||
except ImportError as e:
|
||||
raise RuntimeError(
|
||||
"MLX or related libraries not available. Install with: uv pip install mlx mlx-lm"
|
||||
) from e
|
||||
|
||||
logger.info(
|
||||
f"Computing embeddings for {len(chunks)} chunks using MLX model '{model_name}' with batch_size={batch_size}..."
|
||||
)
|
||||
|
||||
# Cache MLX model and tokenizer
|
||||
cache_key = f"mlx_{model_name}"
|
||||
if cache_key in _model_cache:
|
||||
logger.info(f"Using cached MLX model: {model_name}")
|
||||
model, tokenizer = _model_cache[cache_key]
|
||||
else:
|
||||
logger.info(f"Loading and caching MLX model: {model_name}")
|
||||
model, tokenizer = load(model_name)
|
||||
_model_cache[cache_key] = (model, tokenizer)
|
||||
logger.info(f"MLX model cached: {cache_key}")
|
||||
|
||||
# Process chunks in batches with progress bar
|
||||
all_embeddings = []
|
||||
|
||||
try:
|
||||
from tqdm import tqdm
|
||||
|
||||
batch_iterator = tqdm(
|
||||
range(0, len(chunks), batch_size), desc="Computing embeddings", unit="batch"
|
||||
)
|
||||
except ImportError:
|
||||
batch_iterator = range(0, len(chunks), batch_size)
|
||||
|
||||
for i in batch_iterator:
|
||||
batch_chunks = chunks[i : i + batch_size]
|
||||
|
||||
# Tokenize all chunks in the batch
|
||||
batch_token_ids = []
|
||||
for chunk in batch_chunks:
|
||||
token_ids = tokenizer.encode(chunk) # type: ignore
|
||||
batch_token_ids.append(token_ids)
|
||||
|
||||
# Pad sequences to the same length for batch processing
|
||||
max_length = max(len(ids) for ids in batch_token_ids)
|
||||
padded_token_ids = []
|
||||
for token_ids in batch_token_ids:
|
||||
# Right-pad with 0 so every sequence in the batch has the same length
|
||||
padded = token_ids + [0] * (max_length - len(token_ids))
|
||||
padded_token_ids.append(padded)
|
||||
|
||||
# Convert to MLX array with batch dimension
|
||||
input_ids = mx.array(padded_token_ids)
|
||||
|
||||
# Get embeddings for the batch
|
||||
embeddings = model(input_ids)
|
||||
|
||||
# Mean pooling for each sequence in the batch
|
||||
pooled = embeddings.mean(axis=1) # Shape: (batch_size, hidden_size)
|
||||
|
||||
# Convert batch embeddings to numpy
|
||||
for j in range(len(batch_chunks)):
|
||||
pooled_list = pooled[j].tolist() # Convert to list
|
||||
pooled_numpy = np.array(pooled_list, dtype=np.float32)
|
||||
all_embeddings.append(pooled_numpy)
|
||||
|
||||
# Stack numpy arrays
|
||||
return np.stack(all_embeddings)
|
||||
@@ -1,21 +1,14 @@
|
||||
import threading
|
||||
import time
|
||||
import atexit
|
||||
import socket
|
||||
import subprocess
|
||||
import sys
|
||||
import os
|
||||
import logging
|
||||
import zmq
|
||||
import msgpack
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
import psutil
|
||||
|
||||
# Set up logging based on environment variable
|
||||
LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
|
||||
logging.basicConfig(
|
||||
level=getattr(logging, LOG_LEVEL, logging.INFO),
|
||||
format="%(levelname)s - %(name)s - %(message)s",
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
import select
|
||||
|
||||
|
||||
def _check_port(port: int) -> bool:
|
||||
@@ -24,135 +17,151 @@ def _check_port(port: int) -> bool:
|
||||
return s.connect_ex(("localhost", port)) == 0
|
||||
|
||||
|
||||
def _check_process_matches_config(
|
||||
port: int, expected_model: str, expected_passages_file: str
|
||||
) -> bool:
|
||||
def _check_server_meta_path(port: int, expected_meta_path: str) -> bool:
|
||||
"""
|
||||
Check if the process using the port matches our expected model and passages file.
|
||||
Returns True if matches, False otherwise.
|
||||
Check if the existing server on the port is using the correct meta file.
|
||||
Returns True if the server has the right meta path, False otherwise.
|
||||
"""
|
||||
try:
|
||||
for proc in psutil.process_iter(["pid", "cmdline"]):
|
||||
if not _is_process_listening_on_port(proc, port):
|
||||
continue
|
||||
context = zmq.Context()
|
||||
socket = context.socket(zmq.REQ)
|
||||
socket.setsockopt(zmq.RCVTIMEO, 3000) # 3 second timeout
|
||||
socket.connect(f"tcp://localhost:{port}")
|
||||
|
||||
cmdline = proc.info["cmdline"]
|
||||
if not cmdline:
|
||||
continue
|
||||
# Send a special control message to query the server's meta path
|
||||
control_request = ["__QUERY_META_PATH__"]
|
||||
request_bytes = msgpack.packb(control_request)
|
||||
socket.send(request_bytes)
|
||||
|
||||
return _check_cmdline_matches_config(
|
||||
cmdline, port, expected_model, expected_passages_file
|
||||
)
|
||||
# Wait for response
|
||||
response_bytes = socket.recv()
|
||||
response = msgpack.unpackb(response_bytes)
|
||||
|
||||
socket.close()
|
||||
context.term()
|
||||
|
||||
# Check if the response contains the meta path and if it matches
|
||||
if isinstance(response, list) and len(response) > 0:
|
||||
server_meta_path = response[0]
|
||||
# Normalize paths for comparison
|
||||
expected_path = Path(expected_meta_path).resolve()
|
||||
server_path = Path(server_meta_path).resolve() if server_meta_path else None
|
||||
return server_path == expected_path
|
||||
|
||||
logger.debug(f"No process found listening on port {port}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not check process on port {port}: {e}")
|
||||
print(f"WARNING: Could not query server meta path on port {port}: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def _is_process_listening_on_port(proc, port: int) -> bool:
|
||||
"""Check if a process is listening on the given port."""
|
||||
def _update_server_meta_path(port: int, new_meta_path: str) -> bool:
|
||||
"""
|
||||
Send a control message to update the server's meta path.
|
||||
Returns True if successful, False otherwise.
|
||||
"""
|
||||
try:
|
||||
connections = proc.net_connections()
|
||||
for conn in connections:
|
||||
if conn.laddr.port == port and conn.status == psutil.CONN_LISTEN:
|
||||
return True
|
||||
context = zmq.Context()
|
||||
socket = context.socket(zmq.REQ)
|
||||
socket.setsockopt(zmq.RCVTIMEO, 5000) # 5 second timeout
|
||||
socket.connect(f"tcp://localhost:{port}")
|
||||
|
||||
# Send a control message to update the meta path
|
||||
control_request = ["__UPDATE_META_PATH__", new_meta_path]
|
||||
request_bytes = msgpack.packb(control_request)
|
||||
socket.send(request_bytes)
|
||||
|
||||
# Wait for response
|
||||
response_bytes = socket.recv()
|
||||
response = msgpack.unpackb(response_bytes)
|
||||
|
||||
socket.close()
|
||||
context.term()
|
||||
|
||||
# Check if the update was successful
|
||||
if isinstance(response, list) and len(response) > 0:
|
||||
return response[0] == "SUCCESS"
|
||||
|
||||
return False
|
||||
except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess):
|
||||
|
||||
except Exception as e:
|
||||
print(f"ERROR: Could not update server meta path on port {port}: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def _check_cmdline_matches_config(
|
||||
cmdline: list, port: int, expected_model: str, expected_passages_file: str
|
||||
) -> bool:
|
||||
"""Check if command line matches our expected configuration."""
|
||||
cmdline_str = " ".join(cmdline)
|
||||
logger.debug(f"Found process on port {port}: {cmdline_str}")
|
||||
|
||||
# Check if it's our embedding server
|
||||
is_embedding_server = any(
|
||||
server_type in cmdline_str
|
||||
for server_type in [
|
||||
"embedding_server",
|
||||
"leann_backend_diskann.embedding_server",
|
||||
"leann_backend_hnsw.hnsw_embedding_server",
|
||||
]
|
||||
)
|
||||
|
||||
if not is_embedding_server:
|
||||
logger.debug(f"Process on port {port} is not our embedding server")
|
||||
return False
|
||||
|
||||
# Check model name
|
||||
model_matches = _check_model_in_cmdline(cmdline, expected_model)
|
||||
|
||||
# Check passages file if provided
|
||||
passages_matches = _check_passages_in_cmdline(cmdline, expected_passages_file)
|
||||
|
||||
result = model_matches and passages_matches
|
||||
logger.debug(
|
||||
f"model_matches: {model_matches}, passages_matches: {passages_matches}, overall: {result}"
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
def _check_model_in_cmdline(cmdline: list, expected_model: str) -> bool:
|
||||
"""Check if the command line contains the expected model."""
|
||||
if "--model-name" not in cmdline:
|
||||
return False
|
||||
|
||||
model_idx = cmdline.index("--model-name")
|
||||
if model_idx + 1 >= len(cmdline):
|
||||
return False
|
||||
|
||||
actual_model = cmdline[model_idx + 1]
|
||||
return actual_model == expected_model
|
||||
|
||||
|
||||
def _check_passages_in_cmdline(cmdline: list, expected_passages_file: str) -> bool:
|
||||
"""Check if the command line contains the expected passages file."""
|
||||
if "--passages-file" not in cmdline:
|
||||
return False # Expected but not found
|
||||
|
||||
passages_idx = cmdline.index("--passages-file")
|
||||
if passages_idx + 1 >= len(cmdline):
|
||||
return False
|
||||
|
||||
actual_passages = cmdline[passages_idx + 1]
|
||||
expected_path = Path(expected_passages_file).resolve()
|
||||
actual_path = Path(actual_passages).resolve()
|
||||
return actual_path == expected_path
|
||||
|
||||
|
||||
def _find_compatible_port_or_next_available(
|
||||
start_port: int, model_name: str, passages_file: str, max_attempts: int = 100
|
||||
) -> tuple[int, bool]:
|
||||
def _check_server_model(port: int, expected_model: str) -> bool:
|
||||
"""
|
||||
Find a port that either has a compatible server or is available.
|
||||
Returns (port, is_compatible) where is_compatible indicates if we found a matching server.
|
||||
Check if the existing server on the port is using the correct embedding model.
|
||||
Returns True if the server has the right model, False otherwise.
|
||||
"""
|
||||
for port in range(start_port, start_port + max_attempts):
|
||||
if not _check_port(port):
|
||||
# Port is available
|
||||
return port, False
|
||||
try:
|
||||
context = zmq.Context()
|
||||
socket = context.socket(zmq.REQ)
|
||||
socket.setsockopt(zmq.RCVTIMEO, 3000) # 3 second timeout
|
||||
socket.connect(f"tcp://localhost:{port}")
|
||||
|
||||
# Port is in use, check if it's compatible
|
||||
if _check_process_matches_config(port, model_name, passages_file):
|
||||
logger.info(f"Found compatible server on port {port}")
|
||||
return port, True
|
||||
else:
|
||||
logger.info(f"Port {port} has incompatible server, trying next port...")
|
||||
# Send a special control message to query the server's model
|
||||
control_request = ["__QUERY_MODEL__"]
|
||||
request_bytes = msgpack.packb(control_request)
|
||||
socket.send(request_bytes)
|
||||
|
||||
raise RuntimeError(
|
||||
f"Could not find compatible or available port in range {start_port}-{start_port + max_attempts}"
|
||||
)
|
||||
# Wait for response
|
||||
response_bytes = socket.recv()
|
||||
response = msgpack.unpackb(response_bytes)
|
||||
|
||||
socket.close()
|
||||
context.term()
|
||||
|
||||
# Check if the response contains the model name and if it matches
|
||||
if isinstance(response, list) and len(response) > 0:
|
||||
server_model = response[0]
|
||||
return server_model == expected_model
|
||||
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
print(f"WARNING: Could not query server model on port {port}: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def _update_server_model(port: int, new_model: str) -> bool:
|
||||
"""
|
||||
Send a control message to update the server's embedding model.
|
||||
Returns True if successful, False otherwise.
|
||||
"""
|
||||
try:
|
||||
context = zmq.Context()
|
||||
socket = context.socket(zmq.REQ)
|
||||
socket.setsockopt(zmq.RCVTIMEO, 30000) # 30 second timeout for model loading
|
||||
socket.setsockopt(zmq.SNDTIMEO, 5000) # 5 second timeout for sending
|
||||
socket.connect(f"tcp://localhost:{port}")
|
||||
|
||||
# Send a control message to update the model
|
||||
control_request = ["__UPDATE_MODEL__", new_model]
|
||||
request_bytes = msgpack.packb(control_request)
|
||||
socket.send(request_bytes)
|
||||
|
||||
# Wait for response
|
||||
response_bytes = socket.recv()
|
||||
response = msgpack.unpackb(response_bytes)
|
||||
|
||||
socket.close()
|
||||
context.term()
|
||||
|
||||
# Check if the update was successful
|
||||
if isinstance(response, list) and len(response) > 0:
|
||||
return response[0] == "SUCCESS"
|
||||
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
print(f"ERROR: Could not update server model on port {port}: {e}")
|
||||
return False
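# Illustrative sketch (not from the diff): how an embedding server's ZMQ REP loop
# might answer the control messages sent by the helpers above. The real servers live
# in the backend packages; this only mirrors the request/response shapes used here.
import msgpack
import zmq

def serve_control_messages(port: int, state: dict) -> None:
    context = zmq.Context()
    sock = context.socket(zmq.REP)
    sock.bind(f"tcp://*:{port}")
    while True:
        request = msgpack.unpackb(sock.recv())
        if request[0] == "__QUERY_MODEL__":
            sock.send(msgpack.packb([state.get("model_name")]))
        elif request[0] == "__QUERY_META_PATH__":
            sock.send(msgpack.packb([state.get("meta_path")]))
        elif request[0] == "__UPDATE_MODEL__":
            state["model_name"] = request[1]
            sock.send(msgpack.packb(["SUCCESS"]))
        elif request[0] == "__UPDATE_META_PATH__":
            state["meta_path"] = request[1]
            sock.send(msgpack.packb(["SUCCESS"]))
        else:
            # Anything else would be a normal embedding request in the real server.
            sock.send(msgpack.packb(["UNSUPPORTED"]))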
class EmbeddingServerManager:
|
||||
"""
|
||||
A simplified manager for embedding server processes that avoids complex update mechanisms.
|
||||
A generic manager for handling the lifecycle of a backend-specific embedding server process.
|
||||
"""
|
||||
|
||||
def __init__(self, backend_module_name: str):
|
||||
@@ -166,185 +175,246 @@ class EmbeddingServerManager:
|
||||
self.backend_module_name = backend_module_name
|
||||
self.server_process: Optional[subprocess.Popen] = None
|
||||
self.server_port: Optional[int] = None
|
||||
self._atexit_registered = False
|
||||
atexit.register(self.stop_server)
|
||||
|
||||
def start_server(
|
||||
self,
|
||||
port: int,
|
||||
model_name: str,
|
||||
embedding_mode: str = "sentence-transformers",
|
||||
**kwargs,
|
||||
) -> tuple[bool, int]:
|
||||
def start_server(self, port: int, model_name: str, embedding_mode: str = "sentence-transformers", **kwargs) -> bool:
|
||||
"""
|
||||
Starts the embedding server process.
|
||||
|
||||
Args:
|
||||
port (int): The preferred ZMQ port for the server.
|
||||
port (int): The ZMQ port for the server.
|
||||
model_name (str): The name of the embedding model to use.
|
||||
**kwargs: Additional arguments for the server.
|
||||
**kwargs: Additional arguments for the server (e.g., passages_file, distance_metric, enable_warmup).
|
||||
|
||||
Returns:
|
||||
tuple[bool, int]: (success, actual_port_used)
|
||||
bool: True if the server is started successfully or already running, False otherwise.
|
||||
"""
|
||||
passages_file = kwargs.get("passages_file")
|
||||
assert isinstance(passages_file, str), "passages_file must be a string"
|
||||
if self.server_process and self.server_process.poll() is None:
|
||||
# Even if we have a running process, check if model/meta path match
|
||||
if self.server_port is not None:
|
||||
port_in_use = _check_port(self.server_port)
|
||||
if port_in_use:
|
||||
print(
|
||||
f"INFO: Checking compatibility of existing server process (PID {self.server_process.pid})"
|
||||
)
|
||||
|
||||
# Check if we have a compatible running server
|
||||
if self._has_compatible_running_server(model_name, passages_file):
|
||||
assert self.server_port is not None, (
|
||||
"a compatible running server should set server_port"
|
||||
)
|
||||
return True, self.server_port
|
||||
# Check model compatibility
|
||||
model_matches = _check_server_model(self.server_port, model_name)
|
||||
if model_matches:
|
||||
print(
|
||||
f"✅ Existing server already using correct model: {model_name}"
|
||||
)
|
||||
|
||||
# Still check meta path if provided
|
||||
passages_file = kwargs.get("passages_file")
|
||||
if passages_file and str(passages_file).endswith(
|
||||
".meta.json"
|
||||
):
|
||||
meta_matches = _check_server_meta_path(
|
||||
self.server_port, str(passages_file)
|
||||
)
|
||||
if not meta_matches:
|
||||
print("⚠️ Updating meta path to: {passages_file}")
|
||||
_update_server_meta_path(
|
||||
self.server_port, str(passages_file)
|
||||
)
|
||||
|
||||
return True
|
||||
else:
|
||||
print(
|
||||
f"⚠️ Existing server has different model. Attempting to update to: {model_name}"
|
||||
)
|
||||
if not _update_server_model(self.server_port, model_name):
|
||||
print(
|
||||
"❌ Failed to update existing server model. Restarting server..."
|
||||
)
|
||||
self.stop_server()
|
||||
# Continue to start new server below
|
||||
else:
|
||||
print(
|
||||
f"✅ Successfully updated existing server model to: {model_name}"
|
||||
)
|
||||
|
||||
# Find available port (compatible or free)
|
||||
try:
|
||||
actual_port, is_compatible = _find_compatible_port_or_next_available(
|
||||
port, model_name, passages_file
|
||||
)
|
||||
except RuntimeError as e:
|
||||
logger.error(str(e))
|
||||
return False, port
|
||||
# Also check meta path if provided
|
||||
passages_file = kwargs.get("passages_file")
|
||||
if passages_file and str(passages_file).endswith(
|
||||
".meta.json"
|
||||
):
|
||||
meta_matches = _check_server_meta_path(
|
||||
self.server_port, str(passages_file)
|
||||
)
|
||||
if not meta_matches:
|
||||
print(f"⚠️ Updating meta path to: {passages_file}")
|
||||
_update_server_meta_path(
|
||||
self.server_port, str(passages_file)
|
||||
)
|
||||
|
||||
if is_compatible:
|
||||
logger.info(f"Using existing compatible server on port {actual_port}")
|
||||
self.server_port = actual_port
|
||||
self.server_process = None # We don't own this process
|
||||
return True, actual_port
|
||||
return True
|
||||
else:
|
||||
# Server process exists but port not responding - restart
|
||||
print("⚠️ Server process exists but not responding. Restarting...")
|
||||
self.stop_server()
|
||||
# Continue to start new server below
|
||||
else:
|
||||
# No port stored - restart
|
||||
print("⚠️ No port information stored. Restarting server...")
|
||||
self.stop_server()
|
||||
# Continue to start new server below
|
||||
|
||||
if actual_port != port:
|
||||
logger.info(f"Using port {actual_port} instead of {port}")
|
||||
if _check_port(port):
|
||||
# Port is in use, check if it's using the correct meta file and model
|
||||
passages_file = kwargs.get("passages_file")
|
||||
|
||||
# Start new server
|
||||
return self._start_new_server(actual_port, model_name, embedding_mode, **kwargs)
|
||||
print(f"INFO: Port {port} is in use. Checking server compatibility...")
|
||||
|
||||
def _has_compatible_running_server(
|
||||
self, model_name: str, passages_file: str
|
||||
) -> bool:
|
||||
"""Check if we have a compatible running server."""
|
||||
if not (
|
||||
self.server_process
|
||||
and self.server_process.poll() is None
|
||||
and self.server_port
|
||||
):
|
||||
return False
|
||||
# Check model compatibility first
|
||||
model_matches = _check_server_model(port, model_name)
|
||||
if model_matches:
|
||||
print(
|
||||
f"✅ Existing server on port {port} is using correct model: {model_name}"
|
||||
)
|
||||
else:
|
||||
print(
|
||||
f"⚠️ Existing server on port {port} has different model. Attempting to update to: {model_name}"
|
||||
)
|
||||
if not _update_server_model(port, model_name):
|
||||
raise RuntimeError(
|
||||
f"❌ Failed to update server model to {model_name}. Consider using a different port."
|
||||
)
|
||||
print(f"✅ Successfully updated server model to: {model_name}")
|
||||
|
||||
if _check_process_matches_config(self.server_port, model_name, passages_file):
|
||||
logger.info(
|
||||
f"Existing server process (PID {self.server_process.pid}) is compatible"
|
||||
)
|
||||
# Check meta path compatibility if provided
|
||||
if passages_file and str(passages_file).endswith(".meta.json"):
|
||||
meta_matches = _check_server_meta_path(port, str(passages_file))
|
||||
if not meta_matches:
|
||||
print(
|
||||
f"⚠️ Existing server on port {port} has different meta path. Attempting to update..."
|
||||
)
|
||||
if not _update_server_meta_path(port, str(passages_file)):
|
||||
raise RuntimeError(
|
||||
"❌ Failed to update server meta path. This may cause data synchronization issues."
|
||||
)
|
||||
print(
|
||||
f"✅ Successfully updated server meta path to: {passages_file}"
|
||||
)
|
||||
else:
|
||||
print(
|
||||
f"✅ Existing server on port {port} is using correct meta path: {passages_file}"
|
||||
)
|
||||
|
||||
print(f"✅ Server on port {port} is compatible and ready to use.")
|
||||
return True
|
||||
|
||||
logger.info(
|
||||
"Existing server process is incompatible. Should start a new server."
|
||||
print(
|
||||
f"INFO: Starting session-level embedding server for '{self.backend_module_name}'..."
|
||||
)
|
||||
return False
|
||||
|
||||
def _start_new_server(
|
||||
self, port: int, model_name: str, embedding_mode: str, **kwargs
|
||||
) -> tuple[bool, int]:
|
||||
"""Start a new embedding server on the given port."""
|
||||
logger.info(f"Starting embedding server on port {port}...")
|
||||
|
||||
command = self._build_server_command(port, model_name, embedding_mode, **kwargs)
|
||||
|
||||
try:
|
||||
self._launch_server_process(command, port)
|
||||
return self._wait_for_server_ready(port)
|
||||
command = [
|
||||
sys.executable,
|
||||
"-m",
|
||||
self.backend_module_name,
|
||||
"--zmq-port",
|
||||
str(port),
|
||||
"--model-name",
|
||||
model_name,
|
||||
]
|
||||
|
||||
# Add extra arguments for specific backends
|
||||
if "passages_file" in kwargs and kwargs["passages_file"]:
|
||||
command.extend(["--passages-file", str(kwargs["passages_file"])])
|
||||
# if "distance_metric" in kwargs and kwargs["distance_metric"]:
|
||||
# command.extend(["--distance-metric", kwargs["distance_metric"]])
|
||||
if embedding_mode != "sentence-transformers":
|
||||
command.extend(["--embedding-mode", embedding_mode])
|
||||
if "enable_warmup" in kwargs and not kwargs["enable_warmup"]:
|
||||
command.extend(["--disable-warmup"])
|
||||
|
||||
project_root = Path(__file__).parent.parent.parent.parent.parent
|
||||
print(f"INFO: Running command from project root: {project_root}")
|
||||
print(f"INFO: Command: {' '.join(command)}") # Debug: show actual command
|
||||
|
||||
self.server_process = subprocess.Popen(
|
||||
command,
|
||||
cwd=project_root,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT, # Merge stderr into stdout for easier monitoring
|
||||
text=True,
|
||||
encoding="utf-8",
|
||||
bufsize=1, # Line buffered
|
||||
universal_newlines=True,
|
||||
)
|
||||
self.server_port = port
|
||||
print(f"INFO: Server process started with PID: {self.server_process.pid}")
|
||||
|
||||
max_wait, wait_interval = 120, 0.5
|
||||
for _ in range(int(max_wait / wait_interval)):
|
||||
if _check_port(port):
|
||||
print("✅ Embedding server is up and ready for this session.")
|
||||
log_thread = threading.Thread(target=self._log_monitor, daemon=True)
|
||||
log_thread.start()
|
||||
return True
|
||||
if self.server_process.poll() is not None:
|
||||
print(
|
||||
"❌ ERROR: Server process terminated unexpectedly during startup."
|
||||
)
|
||||
self._print_recent_output()
|
||||
return False
|
||||
time.sleep(wait_interval)
|
||||
|
||||
print(
|
||||
f"❌ ERROR: Server process failed to start listening within {max_wait} seconds."
|
||||
)
|
||||
self.stop_server()
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to start embedding server: {e}")
|
||||
return False, port
|
||||
print(f"❌ ERROR: Failed to start embedding server process: {e}")
|
||||
return False
|
||||
|
||||
def _build_server_command(
|
||||
self, port: int, model_name: str, embedding_mode: str, **kwargs
|
||||
) -> list:
|
||||
"""Build the command to start the embedding server."""
|
||||
command = [
|
||||
sys.executable,
|
||||
"-m",
|
||||
self.backend_module_name,
|
||||
"--zmq-port",
|
||||
str(port),
|
||||
"--model-name",
|
||||
model_name,
|
||||
]
|
||||
def _print_recent_output(self):
|
||||
"""Print any recent output from the server process."""
|
||||
if not self.server_process or not self.server_process.stdout:
|
||||
return
|
||||
try:
|
||||
# Read any available output
|
||||
|
||||
if kwargs.get("passages_file"):
|
||||
# Convert to absolute path to ensure subprocess can find the file
|
||||
passages_file = Path(kwargs["passages_file"]).resolve()
|
||||
command.extend(["--passages-file", str(passages_file)])
|
||||
if embedding_mode != "sentence-transformers":
|
||||
command.extend(["--embedding-mode", embedding_mode])
|
||||
if select.select([self.server_process.stdout], [], [], 0)[0]:
|
||||
output = self.server_process.stdout.read()
|
||||
if output:
|
||||
print(f"[{self.backend_module_name} OUTPUT]: {output}")
|
||||
except Exception as e:
|
||||
print(f"Error reading server output: {e}")
|
||||
|
||||
return command
|
||||
|
||||
def _launch_server_process(self, command: list, port: int) -> None:
|
||||
"""Launch the server process."""
|
||||
project_root = Path(__file__).parent.parent.parent.parent.parent
|
||||
logger.info(f"Command: {' '.join(command)}")
|
||||
|
||||
# Let server output go directly to console
|
||||
# The server will respect LEANN_LOG_LEVEL environment variable
|
||||
self.server_process = subprocess.Popen(
|
||||
command,
|
||||
cwd=project_root,
|
||||
stdout=None, # Direct to console
|
||||
stderr=None, # Direct to console
|
||||
)
|
||||
self.server_port = port
|
||||
logger.info(f"Server process started with PID: {self.server_process.pid}")
|
||||
|
||||
# Register atexit callback only when we actually start a process
|
||||
if not self._atexit_registered:
|
||||
# Use a lambda to avoid issues with bound methods
|
||||
atexit.register(lambda: self.stop_server() if self.server_process else None)
|
||||
self._atexit_registered = True
|
||||
|
||||
    def _wait_for_server_ready(self, port: int) -> tuple[bool, int]:
        """Wait for the server to be ready."""
        max_wait, wait_interval = 120, 0.5
        for _ in range(int(max_wait / wait_interval)):
            if _check_port(port):
                logger.info("Embedding server is ready!")
                return True, port

            if self.server_process and self.server_process.poll() is not None:
                logger.error("Server terminated during startup.")
                return False, port

            time.sleep(wait_interval)

        logger.error(f"Server failed to start within {max_wait} seconds.")
        self.stop_server()
        return False, port

    def _log_monitor(self):
        """Monitors and prints the server's stdout and stderr."""
        if not self.server_process:
            return
        try:
            if self.server_process.stdout:
                while True:
                    line = self.server_process.stdout.readline()
                    if not line:
                        break
                    print(
                        f"[{self.backend_module_name} LOG]: {line.strip()}", flush=True
                    )
        except Exception as e:
            print(f"Log monitor error: {e}")

    def stop_server(self):
        """Stops the embedding server process if it's running."""
        if not self.server_process:
            return

        if self.server_process.poll() is not None:
            # Process already terminated
            self.server_process = None
            return

        logger.info(
            f"Terminating server process (PID: {self.server_process.pid}) for backend {self.backend_module_name}..."
        )
        self.server_process.terminate()

        try:
            self.server_process.wait(timeout=5)
            logger.info(f"Server process {self.server_process.pid} terminated.")
        except subprocess.TimeoutExpired:
            logger.warning(
                f"Server process {self.server_process.pid} did not terminate gracefully, killing it."
            )
            self.server_process.kill()

            # Clean up process resources to prevent resource tracker warnings
            try:
                self.server_process.wait()  # Ensure process is fully cleaned up
            except Exception:
                pass

        if self.server_process and self.server_process.poll() is None:
            print(
                f"INFO: Terminating session server process (PID: {self.server_process.pid})..."
            )
            self.server_process.terminate()
            try:
                self.server_process.wait(timeout=5)
                print("INFO: Server process terminated.")
            except subprocess.TimeoutExpired:
                print(
                    "WARNING: Server process did not terminate gracefully, killing it."
                )
                self.server_process.kill()
        self.server_process = None

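Taken together, the methods above cover the whole per-session lifecycle of the embedding server: build the launch command, spawn the process, wait for the ZMQ port to open, mirror its logs, and terminate it (with an atexit fallback). Below is a minimal sketch of that lifecycle from a caller's point of view, with names taken from the call sites visible in this diff; the import path is a guess and this is an illustration, not a documented public API.

```python
# Illustrative sketch only: mirrors the call sites visible in this diff.
# The import path is hypothetical, and start_server() returns either a bool
# or a (bool, port) tuple depending on which side of the diff you read.
from leann.embedding_server_manager import EmbeddingServerManager  # assumed path

manager = EmbeddingServerManager(backend_module_name="leann_backend_hnsw")

started = manager.start_server(
    port=5557,
    model_name="your-embedding-model",
    embedding_mode="sentence-transformers",
    passages_file="/path/to/my_index.leann.meta.json",
)
ok = started[0] if isinstance(started, tuple) else started
if not ok:
    raise RuntimeError("Embedding server failed to start")

try:
    pass  # run searches that recompute embeddings over the ZMQ port
finally:
    manager.stop_server()  # also registered via atexit as a safety net
```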
@@ -1,6 +1,6 @@
from abc import ABC, abstractmethod
import numpy as np
from typing import Dict, Any, List, Literal, Optional
from typing import Dict, Any, List, Literal


class LeannBackendBuilderInterface(ABC):
@@ -34,13 +34,6 @@ class LeannBackendSearcherInterface(ABC):
        """
        pass

    @abstractmethod
    def _ensure_server_running(
        self, passages_source_file: str, port: Optional[int], **kwargs
    ) -> int:
        """Ensure server is running"""
        pass

    @abstractmethod
    def search(
        self,
@@ -51,7 +44,7 @@ class LeannBackendSearcherInterface(ABC):
        prune_ratio: float = 0.0,
        recompute_embeddings: bool = False,
        pruning_strategy: Literal["global", "local", "proportional"] = "global",
        zmq_port: Optional[int] = None,
        zmq_port: int = 5557,
        **kwargs,
    ) -> Dict[str, Any]:
        """Search for nearest neighbors
@@ -64,7 +57,7 @@ class LeannBackendSearcherInterface(ABC):
            prune_ratio: Ratio of neighbors to prune via approximate distance (0.0-1.0)
            recompute_embeddings: Whether to fetch fresh embeddings from server vs use stored PQ codes
            pruning_strategy: PQ candidate selection strategy - "global" (default), "local", or "proportional"
            zmq_port: ZMQ port for embedding server communication. Must be provided if recompute_embeddings is True.
            zmq_port: ZMQ port for embedding server communication
            **kwargs: Backend-specific parameters

        Returns:
@@ -74,10 +67,7 @@ class LeannBackendSearcherInterface(ABC):

    @abstractmethod
    def compute_query_embedding(
        self,
        query: str,
        use_server_if_available: bool = True,
        zmq_port: Optional[int] = None,
        self, query: str, zmq_port: int = 5557, use_server_if_available: bool = True
    ) -> np.ndarray:
        """Compute embedding for a query string

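The hunks above are the interface-level half of a signature change: `zmq_port` is either an `Optional[int]` that must be supplied when `recompute_embeddings=True`, or an `int` defaulting to 5557 (and `_ensure_server_running` is dropped from the abstract interface on one side of the diff). Here is a minimal sketch of a call that exercises these parameters, assuming the high-level `LeannSearcher` from the README forwards them to the backend `search()`.

```python
# Illustrative sketch only: assumes LeannSearcher forwards these keyword
# arguments unchanged to the backend search() declared in this interface.
from leann import LeannSearcher

searcher = LeannSearcher("my_index.leann")
results = searcher.search(
    "storage savings",
    top_k=3,
    recompute_embeddings=True,   # fetch fresh embeddings from the embedding server
    pruning_strategy="global",   # or "local" / "proportional"
    zmq_port=5557,               # required with the Optional[int] signature;
                                 # the other signature already defaults to 5557
)
```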
@@ -7,37 +7,30 @@ import importlib.metadata
if TYPE_CHECKING:
    from leann.interface import LeannBackendFactoryInterface

BACKEND_REGISTRY: Dict[str, "LeannBackendFactoryInterface"] = {}

BACKEND_REGISTRY: Dict[str, 'LeannBackendFactoryInterface'] = {}


def register_backend(name: str):
    """A decorator to register a new backend class."""

    def decorator(cls):
        print(f"INFO: Registering backend '{name}'")
        BACKEND_REGISTRY[name] = cls
        return cls

    return decorator


def autodiscover_backends():
    """Automatically discovers and imports all 'leann-backend-*' packages."""
    # print("INFO: Starting backend auto-discovery...")
    print("INFO: Starting backend auto-discovery...")
    discovered_backends = []
    for dist in importlib.metadata.distributions():
        dist_name = dist.metadata["name"]
        if dist_name.startswith("leann-backend-"):
            backend_module_name = dist_name.replace("-", "_")
        dist_name = dist.metadata['name']
        if dist_name.startswith('leann-backend-'):
            backend_module_name = dist_name.replace('-', '_')
            discovered_backends.append(backend_module_name)

    for backend_module_name in sorted(
        discovered_backends
    ):  # sort for deterministic loading

    for backend_module_name in sorted(discovered_backends):  # sort for deterministic loading
        try:
            importlib.import_module(backend_module_name)
            # Registration message is printed by the decorator
        except ImportError as e:
            # print(f"WARN: Could not import backend module '{backend_module_name}': {e}")
            pass
    # print("INFO: Backend auto-discovery finished.")
            print(f"WARN: Could not import backend module '{backend_module_name}': {e}")
    print("INFO: Backend auto-discovery finished.")
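For context, the decorator and the `leann-backend-*` distribution prefix above are the whole plug-in contract: a backend package only has to import cleanly and register a factory class. A minimal sketch of a backend package's `__init__.py` follows (the registry import path and the factory class are assumptions):

```python
# Illustrative sketch only: the registry import path and the factory class are
# assumptions; the decorator and interface names come from the diff above.
from leann.registry import register_backend             # assumed module path
from leann.interface import LeannBackendFactoryInterface


@register_backend("hnsw")  # prints "INFO: Registering backend 'hnsw'" at import time
class HNSWBackendFactory(LeannBackendFactoryInterface):
    """Hypothetical factory; autodiscover_backends() imports this module because
    its distribution name starts with 'leann-backend-'."""
```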
@@ -1,7 +1,8 @@
|
||||
import json
|
||||
import pickle
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, Literal, Optional
|
||||
from typing import Dict, Any, Literal
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -42,10 +43,10 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
|
||||
"WARNING: embedding_model not found in meta.json. Recompute will fail."
|
||||
)
|
||||
|
||||
self.embedding_mode = self.meta.get("embedding_mode", "sentence-transformers")
|
||||
self.label_map = self._load_label_map()
|
||||
|
||||
self.embedding_server_manager = EmbeddingServerManager(
|
||||
backend_module_name=backend_module_name,
|
||||
backend_module_name=backend_module_name
|
||||
)
|
||||
|
||||
def _load_meta(self) -> Dict[str, Any]:
|
||||
@@ -57,9 +58,17 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
|
||||
with open(meta_path, "r", encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
|
||||
def _load_label_map(self) -> Dict[int, str]:
|
||||
"""Loads the mapping from integer IDs to string IDs."""
|
||||
label_map_file = self.index_dir / "leann.labels.map"
|
||||
if not label_map_file.exists():
|
||||
raise FileNotFoundError(f"Label map file not found: {label_map_file}")
|
||||
with open(label_map_file, "rb") as f:
|
||||
return pickle.load(f)
|
||||
|
||||
def _ensure_server_running(
|
||||
self, passages_source_file: str, port: int, **kwargs
|
||||
) -> int:
|
||||
) -> None:
|
||||
"""
|
||||
Ensures the embedding server is running if recompute is needed.
|
||||
This is a helper for subclasses.
|
||||
@@ -69,26 +78,21 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
|
||||
"Cannot use recompute mode without 'embedding_model' in meta.json."
|
||||
)
|
||||
|
||||
server_started, actual_port = self.embedding_server_manager.start_server(
|
||||
embedding_mode = self.meta.get("embedding_mode", "sentence-transformers")
|
||||
|
||||
server_started = self.embedding_server_manager.start_server(
|
||||
port=port,
|
||||
model_name=self.embedding_model,
|
||||
embedding_mode=self.embedding_mode,
|
||||
passages_file=passages_source_file,
|
||||
distance_metric=kwargs.get("distance_metric"),
|
||||
embedding_mode=embedding_mode,
|
||||
enable_warmup=kwargs.get("enable_warmup", False),
|
||||
)
|
||||
if not server_started:
|
||||
raise RuntimeError(
|
||||
f"Failed to start embedding server on port {actual_port}"
|
||||
)
|
||||
|
||||
return actual_port
|
||||
raise RuntimeError(f"Failed to start embedding server on port {port}")
|
||||
|
||||
def compute_query_embedding(
|
||||
self,
|
||||
query: str,
|
||||
use_server_if_available: bool = True,
|
||||
zmq_port: int = 5557,
|
||||
self, query: str, zmq_port: int = 5557, use_server_if_available: bool = True
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Compute embedding for a query string.
|
||||
@@ -102,21 +106,12 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
|
||||
Query embedding as numpy array
|
||||
"""
|
||||
# Try to use embedding server if available and requested
|
||||
if use_server_if_available:
|
||||
if (
|
||||
use_server_if_available
|
||||
and self.embedding_server_manager
|
||||
and self.embedding_server_manager.server_process
|
||||
):
|
||||
try:
|
||||
# TODO: Maybe we can directly use this port here?
|
||||
# For this internal method, it's ok to assume that the server is running
|
||||
# on that port?
|
||||
|
||||
# Ensure we have a server with passages_file for compatibility
|
||||
passages_source_file = (
|
||||
self.index_dir / f"{self.index_path.name}.meta.json"
|
||||
)
|
||||
# Convert to absolute path to ensure server can find it
|
||||
zmq_port = self._ensure_server_running(
|
||||
str(passages_source_file.resolve()), zmq_port
|
||||
)
|
||||
|
||||
return self._compute_embedding_via_server([query], zmq_port)[
|
||||
0:1
|
||||
] # Return (1, D) shape
|
||||
@@ -125,7 +120,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
|
||||
print("⏭️ Falling back to direct model loading...")
|
||||
|
||||
# Fallback to direct computation
|
||||
from .embedding_compute import compute_embeddings
|
||||
from .api import compute_embeddings
|
||||
|
||||
embedding_mode = self.meta.get("embedding_mode", "sentence-transformers")
|
||||
return compute_embeddings([query], self.embedding_model, embedding_mode)
|
||||
@@ -172,7 +167,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
|
||||
prune_ratio: float = 0.0,
|
||||
recompute_embeddings: bool = False,
|
||||
pruning_strategy: Literal["global", "local", "proportional"] = "global",
|
||||
zmq_port: Optional[int] = None,
|
||||
zmq_port: int = 5557,
|
||||
**kwargs,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
@@ -186,7 +181,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
|
||||
prune_ratio: Ratio of neighbors to prune via approximate distance (0.0-1.0)
|
||||
recompute_embeddings: Whether to fetch fresh embeddings from server vs use stored PQ codes
|
||||
pruning_strategy: PQ candidate selection strategy - "global" (default), "local", or "proportional"
|
||||
zmq_port: ZMQ port for embedding server communication. Must be provided if recompute_embeddings is True.
|
||||
zmq_port: ZMQ port for embedding server communication
|
||||
**kwargs: Backend-specific parameters (e.g., batch_size, dedup_node_dis, etc.)
|
||||
|
||||
Returns:
|
||||
|
||||
@@ -1,40 +0,0 @@
|
||||
# LEANN - The smallest vector index in the world
|
||||
|
||||
LEANN is a revolutionary vector database that democratizes personal AI. Transform your laptop into a powerful RAG system that can index and search through millions of documents while using **97% less storage** than traditional solutions **without accuracy loss**.
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
# Default installation (HNSW backend, recommended)
|
||||
uv pip install leann
|
||||
|
||||
# With DiskANN backend (for large-scale deployments)
|
||||
uv pip install leann[diskann]
|
||||
```
|
||||
|
||||
## Quick Start
|
||||
|
||||
```python
|
||||
from leann import LeannBuilder, LeannSearcher, LeannChat
|
||||
|
||||
# Build an index
|
||||
builder = LeannBuilder(backend_name="hnsw")
|
||||
builder.add_text("LEANN saves 97% storage compared to traditional vector databases.")
|
||||
builder.build_index("my_index.leann")
|
||||
|
||||
# Search
|
||||
searcher = LeannSearcher("my_index.leann")
|
||||
results = searcher.search("storage savings", top_k=3)
|
||||
|
||||
# Chat with your data
|
||||
chat = LeannChat("my_index.leann", llm_config={"type": "ollama", "model": "llama3.2:1b"})
|
||||
response = chat.ask("How much storage does LEANN save?")
|
||||
```
|
||||
|
||||
## Documentation
|
||||
|
||||
For full documentation, visit [https://leann.readthedocs.io](https://leann.readthedocs.io)
|
||||
|
||||
## License
|
||||
|
||||
MIT License
|
||||
@@ -1,12 +0,0 @@
|
||||
"""
|
||||
LEANN - Low-storage Embedding Approximation for Neural Networks
|
||||
|
||||
A revolutionary vector database that democratizes personal AI.
|
||||
"""
|
||||
|
||||
__version__ = "0.1.0"
|
||||
|
||||
# Re-export main API from leann-core
|
||||
from leann_core import LeannBuilder, LeannSearcher, LeannChat
|
||||
|
||||
__all__ = ["LeannBuilder", "LeannSearcher", "LeannChat"]
|
||||
@@ -1,42 +0,0 @@
|
||||
[build-system]
|
||||
requires = ["setuptools>=61.0"]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "leann"
|
||||
version = "0.1.10"
|
||||
description = "LEANN - The smallest vector index in the world. RAG Everything with LEANN!"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.9"
|
||||
license = { text = "MIT" }
|
||||
authors = [
|
||||
{ name = "LEANN Team" }
|
||||
]
|
||||
keywords = ["vector-database", "rag", "embeddings", "search", "ai"]
|
||||
classifiers = [
|
||||
"Development Status :: 4 - Beta",
|
||||
"Intended Audience :: Developers",
|
||||
"License :: OSI Approved :: MIT License",
|
||||
"Programming Language :: Python :: 3",
|
||||
"Programming Language :: Python :: 3.9",
|
||||
"Programming Language :: Python :: 3.10",
|
||||
"Programming Language :: Python :: 3.11",
|
||||
"Programming Language :: Python :: 3.12",
|
||||
]
|
||||
|
||||
# Default installation: core + hnsw
|
||||
dependencies = [
|
||||
"leann-core>=0.1.0",
|
||||
"leann-backend-hnsw>=0.1.0",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
diskann = [
|
||||
"leann-backend-diskann>=0.1.0",
|
||||
]
|
||||
|
||||
[project.urls]
|
||||
Homepage = "https://github.com/yourusername/leann"
|
||||
Documentation = "https://leann.readthedocs.io"
|
||||
Repository = "https://github.com/yourusername/leann"
|
||||
Issues = "https://github.com/yourusername/leann/issues"
|
||||
@@ -9,6 +9,7 @@ requires-python = ">=3.10"
|
||||
|
||||
dependencies = [
|
||||
"leann-core",
|
||||
"leann-backend-diskann",
|
||||
"leann-backend-hnsw",
|
||||
"numpy>=1.26.0",
|
||||
"torch",
|
||||
@@ -33,9 +34,8 @@ dependencies = [
|
||||
"msgpack>=1.1.1",
|
||||
"llama-index-vector-stores-faiss>=0.4.0",
|
||||
"llama-index-embeddings-huggingface>=0.5.5",
|
||||
"mlx>=0.26.3; sys_platform == 'darwin'",
|
||||
"mlx-lm>=0.26.0; sys_platform == 'darwin'",
|
||||
"psutil>=5.8.0",
|
||||
"mlx>=0.26.3",
|
||||
"mlx-lm>=0.26.0",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
@@ -48,10 +48,6 @@ dev = [
|
||||
"huggingface-hub>=0.20.0",
|
||||
]
|
||||
|
||||
diskann = [
|
||||
"leann-backend-diskann",
|
||||
]
|
||||
|
||||
[tool.setuptools]
|
||||
py-modules = []
|
||||
|
||||
|
||||
12
research/micro/analyze_HNSW.py
Normal file
@@ -0,0 +1,12 @@
import faiss
hnsw_index = faiss.read_index("/opt/dlami/nvme/scaling_out/indices/rpj_wiki/facebook/contriever-msmarco/hnsw/hnsw_IP_M30_efC128.index", faiss.IO_FLAG_ONDISK_SAME_DIR)

# print total number of nodes
print(hnsw_index.ntotal)

# print stats of the graph
print(hnsw_index.hnsw.print_neighbor_stats(0))


# save_degree_distribution
hnsw_index.hnsw.save_degree_distribution(0, "degree_distribution_HNSW_M30.txt")
11
research/micro/analyze_NSG.py
Normal file
@@ -0,0 +1,11 @@
import faiss
nsg_index = faiss.read_index("/opt/dlami/nvme/scaling_out/indices/rpj_wiki/facebook/contriever-msmarco/nsg_R16.index", faiss.IO_FLAG_ONDISK_SAME_DIR)

# print total number of nodes
print(nsg_index.ntotal)

# print stats of the graph
print(nsg_index.nsg.print_neighbor_stats(0))

# save degree distribution
nsg_index.nsg.save_degree_distribution("degree_distribution_NSG_R60.txt")
63
research/micro/bnbtest.py
Normal file
@@ -0,0 +1,63 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import time
|
||||
|
||||
# import bitsandbytes as bnb
|
||||
from bitsandbytes.nn import Linear8bitLt
|
||||
|
||||
# set default to half
|
||||
import torch
|
||||
torch.set_default_dtype(torch.float16)
|
||||
|
||||
M = 2048
|
||||
N = 2048
|
||||
|
||||
bsz = 2048
|
||||
import torch_int
|
||||
from torch_int.nn.linear import W8A8BFP32OFP32Linear, W8A8B8O8Linear, W8A8B8O8LinearReLU
|
||||
|
||||
fp16_model = nn.Sequential(
|
||||
nn.Linear(M, N),
|
||||
# nn.Linear(2048, 2048)
|
||||
)
|
||||
|
||||
int8_model = nn.Sequential(
|
||||
Linear8bitLt(M, N, has_fp16_weights=False),
|
||||
# Linear8bitLt(2048, 2048, has_fp16_weights=False)
|
||||
)
|
||||
|
||||
int8_model.load_state_dict(fp16_model.state_dict())
|
||||
int8_model = int8_model.to(0) # Quantization happens here
|
||||
fp16_model = fp16_model.to(0) # Move fp16 model to GPU as well
|
||||
|
||||
# Create random input tensor
|
||||
input_tensor = torch.randn(bsz, M, device=0) # Batch of 1000 vectors
|
||||
|
||||
# Speed test function
|
||||
def speed_test(model, input_tensor, name, num_iterations=100):
|
||||
# Warmup
|
||||
for _ in range(10):
|
||||
_ = model(input_tensor)
|
||||
|
||||
# Actual timing
|
||||
torch.cuda.synchronize()
|
||||
start_time = time.time()
|
||||
|
||||
for _ in range(num_iterations):
|
||||
_ = model(input_tensor)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
end_time = time.time()
|
||||
|
||||
avg_time = (end_time - start_time) / num_iterations
|
||||
print(f"{name} model: {avg_time:.6f} seconds per iteration")
|
||||
return avg_time
|
||||
|
||||
# Run speed tests
|
||||
with torch.no_grad(): # Disable gradient calculation for inference
|
||||
fp16_time = speed_test(fp16_model, input_tensor, "FP16")
|
||||
int8_time = speed_test(int8_model, input_tensor, "INT8")
|
||||
|
||||
# Calculate speedup
|
||||
speedup = fp16_time / int8_time
|
||||
print(f"INT8 is {speedup:.2f}x faster than FP16")
|
||||
89
research/micro/data/transformer-batching-microbenchmarks.csv
Normal file
@@ -0,0 +1,89 @@
|
||||
n,d,seqlen,bs,latency,h,flop,io,intensity,throughput,series
|
||||
3,256,256,2048,0.009623501679245285,768,618475290624,167.48502132816208,3692720015.912285,64267177503366.266,dense
|
||||
3,256,256,1024,0.004853848615384615,768,309237645312,166.15392854317415,1861151572.059558,63709783682138.234,dense
|
||||
3,256,256,512,0.0024687246971962615,768,154618822656,163.57953256539062,945221081.3366361,62631051097597.516,dense
|
||||
3,256,256,256,0.0012845360838052097,768,77309411328,157.64931990085577,490388486.1451936,60184694149645.54,dense
|
||||
3,256,256,128,0.0006901147179878049,768,38654705664,147.57393422494675,261934506.70684624,56012000116019.945,dense
|
||||
3,256,256,64,0.0003363830693015702,768,19327352832,153.1328437752606,126212981.84970059,57456378146882.51,dense
|
||||
3,256,256,32,0.00018671159748991485,768,9663676416,141.10249365427362,68486928.65540518,51757237075334.75,dense
|
||||
3,256,256,16,0.00012353640857142858,768,4831838208,111.40488993609125,43371868.24359184,39112665358133.98,dense
|
||||
3,256,256,8,9.774760007849294e-05,768,2415919104,76.43260800265766,31608487.09906635,24715891766754.14,dense
|
||||
3,256,256,4,6.672271167474822e-05,768,1207959552,64.82614227498455,18633833.660438772,18104173551704.773,dense
|
||||
3,256,256,2,4.9758770289855074e-05,768,603979776,55.317122669351576,10918495.880745342,12138157202874.861,dense
|
||||
3,256,1,2048,9.785507940251571e-05,768,2415919104,76.34865809334705,31643242.518371396,24688745017132.86,dense
|
||||
3,256,1,1024,6.692813470149253e-05,768,1207959552,64.62717090938949,18691202.70936228,18048606275785.867,dense
|
||||
3,256,1,512,4.9680950036205655e-05,768,603979776,55.40377142534654,10901419.893658841,12157170415618.898,dense
|
||||
3,256,1,256,4.2781118741058655e-05,768,301989888,45.95672244805227,6571179.83862661,7058952568020.829,dense
|
||||
3,256,1,128,5.0662328255350016e-05,768,150994944,31.046026784880404,4863583.512513602,2980418571348.519,dense
|
||||
3,256,1,64,4.475009253945481e-05,768,75497472,30.75426042497223,2454862.219307235,1687090857598.4766,dense
|
||||
3,256,1,32,4.51682671454219e-05,768,37748736,28.29313765537115,1334201.1218340008,835735758435.5786,dense
|
||||
3,256,1,16,5.03585186661834e-05,768,18874368,24.401035466223117,773506.846712577,374799904761.1871,dense
|
||||
3,256,1,8,5.023459565217391e-05,768,9437184,23.972005435021096,393675.19858030166,187862246674.45105,dense
|
||||
3,256,1,4,5.053219391083726e-05,768,4718592,23.58765586356967,200044.97383259286,93377936614.54384,dense
|
||||
3,256,1,2,4.4607398995335484e-05,768,2359296,26.58285456464288,88752.54515134107,52890239133.797226,dense
|
||||
12,256,256,2048,0.14480779847058822,3072,9895604649984,44.620009282941716,221775046868.20184,68336130750540.26,dense
|
||||
12,256,256,1024,0.07254347629166667,3072,4947802324992,44.664248332585096,110777691547.58836,68204648824643.82,dense
|
||||
12,256,256,512,0.036310761444444443,3072,2473901162496,44.876147984203506,55127306456.13385,68131349056975.164,dense
|
||||
12,256,256,256,0.01821551906896552,3072,1236950581248,45.24607467289738,27338295977.947884,67906414116709.98,dense
|
||||
12,256,256,128,0.009229417903030302,3072,618475290624,45.67217092440895,13541622351.335684,67011299859001.46,dense
|
||||
12,256,256,64,0.004754550595394737,3072,309237645312,46.31372736116993,6677019167.566916,65040352207320.695,dense
|
||||
12,256,256,32,0.002405752659340659,3072,154618822656,49.68826015254682,3111777755.5766335,64270456921525.82,dense
|
||||
12,256,256,16,0.0012287219045005488,3072,77309411328,56.323579604557374,1372594069.3184311,62918558743709.18,dense
|
||||
12,256,256,8,0.0006206816149425287,3072,38654705664,70.95456179103653,544781120.315271,62277832520589.78,dense
|
||||
12,256,256,4,0.0003875502697142857,3072,19327352832,81.16954743236613,238110885.71245712,49870569942445.75,dense
|
||||
12,256,256,2,0.00027502018627941914,3072,9663676416,91.50537035282076,105607751.53129694,35138062215483.168,dense
|
||||
12,256,1,2048,0.0006202853873290136,3072,38654705664,70.99988634205897,544433345.6784943,62317614526515.766,dense
|
||||
12,256,1,1024,0.00038721467732724153,3072,19327352832,81.2398957010995,237904697.74985722,49913791918755.53,dense
|
||||
12,256,1,512,0.000274364799,3072,9663676416,91.72395326121995,105356082.81599998,35221998052308.45,dense
|
||||
12,256,1,256,0.00012488918589482266,3072,4831838208,176.31707535146046,27404255.647778228,38689003962834.75,dense
|
||||
12,256,1,128,8.976711102514506e-05,3072,2415919104,227.78088507574267,10606329.425740216,26913187652026.21,dense
|
||||
12,256,1,64,8.715176287471176e-05,3072,1207959552,225.59268282689945,5354604.31102229,13860414432884.701,dense
|
||||
12,256,1,32,8.523013435114503e-05,3072,603979776,226.06539514085782,2671703.8033338524,7086458100741.991,dense
|
||||
12,256,1,16,7.901561645904116e-05,3072,301989888,241.35704882952732,1251216.3595988373,3821901309300.556,dense
|
||||
12,256,1,8,7.827949114210329e-05,3072,150994944,242.37091635608994,622991.1833900034,1928920867994.581,dense
|
||||
12,256,1,4,7.779445951035782e-05,3072,75497472,243.25022783249054,310369.58391664835,970473636235.5986,dense
|
||||
12,256,1,2,7.758845406626506e-05,3072,37748736,243.57933441822672,154975.11761480253,486525172518.07056,dense
|
||||
3,256,256,2048,0.00507974918466899,768,206158430208,475.59810852303485,433471930.42508715,40584371927298.98,qk_init
|
||||
3,256,256,1024,0.0025616677649325623,768,103079215104,471.5519977009198,218595649.27424532,40239103803811.82,qk_init
|
||||
3,256,256,512,0.0013029336670480549,768,51539607552,463.55374128015677,111183672.92143403,39556585922573.38,qk_init
|
||||
3,256,256,256,0.0006738189029345373,768,25769803776,448.1766342333362,57499213.050413854,38244406121244.69,qk_init
|
||||
3,256,256,128,0.000358254672959467,768,12884901888,421.47375986100144,30571065.425874516,35965760841472.125,qk_init
|
||||
3,256,256,64,0.0002007051105022831,768,6442450944,376.1611839930762,17126836.096194826,32099087700742.5,qk_init
|
||||
3,256,256,32,0.00012189697230142565,768,3221225472,309.6773881032524,10401874.969721656,26425803784810.87,qk_init
|
||||
3,256,256,16,8.453561698040722e-05,768,1610612736,223.2711923587723,7213705.982328083,19052475081281.902,qk_init
|
||||
3,256,256,8,6.407660705009276e-05,768,805306368,147.2797083750448,5467870.468274581,12567868448003.822,qk_init
|
||||
3,256,256,4,5.036328747284576e-05,768,402653184,93.69110391262903,4297667.197682838,7994974200544.344,qk_init
|
||||
3,256,256,2,4.5488761135057476e-05,768,201326592,51.865470527877875,3881707.616858238,4425853485045.578,qk_init
|
||||
12,256,256,2048,0.020202365999999996,3072,824633720832,478.3437947812648,1723935231.9999998,40818670488001.266,qk_init
|
||||
12,256,256,1024,0.010124155888157895,3072,412316860416,477.2583770318811,863927969.1228071,40726048173387.19,qk_init
|
||||
12,256,256,512,0.005085633937062937,3072,206158430208,475.04777848703077,433974095.9627039,40537410430893.29,qk_init
|
||||
12,256,256,256,0.0025654916853281853,3072,103079215104,470.84913933193053,218921957.14800516,40179126556324.74,qk_init
|
||||
12,256,256,128,0.0013045765704467354,3072,51539607552,462.9699702434292,111323867.34478809,39506770794105.96,qk_init
|
||||
12,256,256,64,0.0006742801519939804,3072,25769803776,447.87005387442576,57538572.970153,38218244597284.33,qk_init
|
||||
12,256,256,32,0.00035831976790671853,3072,12884901888,421.3971919051604,30576620.194706645,35959227042573.69,qk_init
|
||||
12,256,256,16,0.0002005369068918302,3072,6442450944,376.4766953382971,17112482.721436176,32126011335534.68,qk_init
|
||||
12,256,256,8,0.00012179187250509165,3072,3221225472,309.94462293386505,10392906.453767821,26448607823689.82,qk_init
|
||||
12,256,256,4,8.452507263643351e-05,3072,1610612736,223.2990450204527,7212806.198308992,19054851841745.297,qk_init
|
||||
12,256,256,2,6.412381767545489e-05,3072,805306368,147.17127491946468,5471899.108305484,12558615459794.32,qk_init
|
||||
3,256,256,2048,0.0016183739398395718,768,805306368,811597824.0,0.9922480620155039,1265467.7325087283,qk_ar
|
||||
3,256,256,1024,0.0008322699728813558,768,402653184,405798912.0,0.9922480620155039,1230369.9921491416,qk_ar
|
||||
3,256,256,512,0.00043886859397590365,768,201326592,202899456.0,0.9922480620155039,1166636.2255762408,qk_ar
|
||||
3,256,256,256,0.00024185948322147648,768,100663296,101449728.0,0.9922480620155039,1058465.8355760013,qk_ar
|
||||
3,256,256,128,0.00014308985100166944,768,50331648,50724864.0,0.9922480620155039,894542.82818777,qk_ar
|
||||
3,256,256,64,9.382939365815932e-05,768,25165824,25362432.0,0.9922480620155039,682089.028872613,qk_ar
|
||||
3,256,256,32,6.856070612244899e-05,768,12582912,12681216.0,0.9922480620155039,466739.6503012703,qk_ar
|
||||
3,256,256,16,5.452260553129549e-05,768,6291456,6340608.0,0.9922480620155039,293456.26174846216,qk_ar
|
||||
3,256,256,8,4.608557533261417e-05,768,3145728,3170304.0,0.9922480620155039,173590.1080166944,qk_ar
|
||||
3,256,256,4,4.386146957766642e-05,768,1572864,1585152.0,0.9922480620155039,91196.21477609445,qk_ar
|
||||
3,256,256,2,4.330941094420601e-05,768,786432,792576.0,0.9922480620155039,46179.33969539622,qk_ar
|
||||
12,256,256,2048,0.006347041645299144,3072,3221225472,3246391296.0,0.9922480620155039,322670.011392918,qk_ar
|
||||
12,256,256,1024,0.0031943104467592586,3072,1610612736,1623195648.0,0.9922480620155039,320569.96872013,qk_ar
|
||||
12,256,256,512,0.0016183416350267381,3072,805306368,811597824.0,0.9922480620155039,316373.2483416833,qk_ar
|
||||
12,256,256,256,0.0008325934893977947,3072,402653184,405798912.0,0.9922480620155039,307472.9784221131,qk_ar
|
||||
12,256,256,128,0.0004389725746987952,3072,201326592,202899456.0,0.9922480620155039,291589.9702568624,qk_ar
|
||||
12,256,256,64,0.00024191767449664432,3072,100663296,101449728.0,0.9922480620155039,264552.8076159138,qk_ar
|
||||
12,256,256,32,0.0001431546143572621,3072,50331648,50724864.0,0.9922480620155039,223534.53392804778,qk_ar
|
||||
12,256,256,16,9.404283597678917e-05,3072,25165824,25362432.0,0.9922480620155039,170135.23501087292,qk_ar
|
||||
12,256,256,8,6.855550037091989e-05,3072,12582912,12681216.0,0.9922480620155039,116693.773026467,qk_ar
|
||||
12,256,256,4,5.4802094978165945e-05,3072,6291456,6340608.0,0.9922480620155039,72989.91036006316,qk_ar
|
||||
12,256,256,2,4.608510707869206e-05,3072,3145728,3170304.0,0.9922480620155039,43397.96795057727,qk_ar
|
||||
Binary file not shown. (Image, 45 KiB)
594
research/micro/embedd_micro.py
Executable file
@@ -0,0 +1,594 @@
|
||||
# python embedd_micro.py --use_int8 Fastest
|
||||
|
||||
import argparse
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
from torchao import quantize_
|
||||
from transformers import AutoModel, BitsAndBytesConfig
|
||||
from tqdm import tqdm
|
||||
from contextlib import contextmanager
|
||||
|
||||
@dataclass
|
||||
class BenchmarkConfig:
|
||||
model_path: str
|
||||
batch_sizes: List[int]
|
||||
seq_length: int
|
||||
num_runs: int
|
||||
use_fp16: bool = True
|
||||
use_int4: bool = False
|
||||
use_int8: bool = False # Add this parameter
|
||||
use_cuda_graphs: bool = False
|
||||
use_flash_attention: bool = False
|
||||
use_linear8bitlt: bool = False
|
||||
|
||||
|
||||
class CUDAGraphContainer:
|
||||
"""Container for managing CUDA graphs for different batch sizes."""
|
||||
|
||||
def __init__(self, model: nn.Module, seq_length: int):
|
||||
self.model = model
|
||||
self.seq_length = seq_length
|
||||
self.graphs: Dict[int, CUDAGraphWrapper] = {}
|
||||
|
||||
def get_or_create(self, batch_size: int) -> 'CUDAGraphWrapper':
|
||||
if batch_size not in self.graphs:
|
||||
self.graphs[batch_size] = CUDAGraphWrapper(
|
||||
self.model, batch_size, self.seq_length
|
||||
)
|
||||
return self.graphs[batch_size]
|
||||
|
||||
|
||||
class CUDAGraphWrapper:
|
||||
"""Wrapper for CUDA graph capture and replay."""
|
||||
|
||||
def __init__(self, model: nn.Module, batch_size: int, seq_length: int):
|
||||
self.model = model
|
||||
self.static_input = self._create_random_batch(batch_size, seq_length)
|
||||
self.static_attention_mask = torch.ones_like(self.static_input)
|
||||
|
||||
# Warm up
|
||||
self._warmup()
|
||||
|
||||
# Capture graph
|
||||
self.graph = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(self.graph):
|
||||
self.static_output = self.model(
|
||||
input_ids=self.static_input,
|
||||
attention_mask=self.static_attention_mask
|
||||
)
|
||||
|
||||
def _create_random_batch(self, batch_size: int, seq_length: int) -> torch.Tensor:
|
||||
return torch.randint(
|
||||
0, 1000, (batch_size, seq_length),
|
||||
device="cuda",
|
||||
dtype=torch.long
|
||||
)
|
||||
|
||||
def _warmup(self, num_warmup: int = 3):
|
||||
with torch.no_grad():
|
||||
for _ in range(num_warmup):
|
||||
self.model(
|
||||
input_ids=self.static_input,
|
||||
attention_mask=self.static_attention_mask
|
||||
)
|
||||
|
||||
def __call__(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
|
||||
self.static_input.copy_(input_ids)
|
||||
self.static_attention_mask.copy_(attention_mask)
|
||||
self.graph.replay()
|
||||
return self.static_output
|
||||
|
||||
|
||||
class ModelOptimizer:
|
||||
"""Applies various optimizations to the model."""
|
||||
|
||||
@staticmethod
|
||||
def optimize(model: nn.Module, config: BenchmarkConfig) -> nn.Module:
|
||||
print("\nApplying model optimizations:")
|
||||
|
||||
if model is None:
|
||||
raise ValueError("Cannot optimize None model")
|
||||
|
||||
# Move to GPU
|
||||
model = model.cuda()
|
||||
print("- Model moved to GPU")
|
||||
|
||||
# FP16
|
||||
if config.use_fp16 and not config.use_int4:
|
||||
model = model.half()
|
||||
# use torch compile
|
||||
model = torch.compile(model)
|
||||
print("- Using FP16 precision")
|
||||
|
||||
# Check if using SDPA
|
||||
if torch.version.cuda and float(torch.version.cuda[:3]) >= 11.6:
|
||||
if hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
|
||||
print("- Using PyTorch SDPA (scaled_dot_product_attention)")
|
||||
else:
|
||||
print("- PyTorch SDPA not available")
|
||||
|
||||
# Flash Attention
|
||||
if config.use_flash_attention:
|
||||
try:
|
||||
from flash_attn.flash_attention import FlashAttention
|
||||
print("- Flash Attention 2 available")
|
||||
if hasattr(model.config, "attention_mode"):
|
||||
model.config.attention_mode = "flash_attention_2"
|
||||
print(" - Enabled Flash Attention 2 mode")
|
||||
except ImportError:
|
||||
print("- Flash Attention not available")
|
||||
|
||||
# Memory efficient attention
|
||||
try:
|
||||
from xformers.ops import memory_efficient_attention
|
||||
if hasattr(model, 'enable_xformers_memory_efficient_attention'):
|
||||
model.enable_xformers_memory_efficient_attention()
|
||||
print("- Enabled xformers memory efficient attention")
|
||||
else:
|
||||
print("- Model doesn't support xformers")
|
||||
except (ImportError, AttributeError):
|
||||
print("- Xformers not available")
|
||||
|
||||
model.eval()
|
||||
print("- Model set to eval mode")
|
||||
|
||||
return model
|
||||
|
||||
|
||||
class Timer:
|
||||
"""Handles accurate GPU timing using CUDA events."""
|
||||
|
||||
def __init__(self):
|
||||
self.start_event = torch.cuda.Event(enable_timing=True)
|
||||
self.end_event = torch.cuda.Event(enable_timing=True)
|
||||
|
||||
@contextmanager
|
||||
def timing(self):
|
||||
self.start_event.record()
|
||||
yield
|
||||
self.end_event.record()
|
||||
self.end_event.synchronize()
|
||||
|
||||
def elapsed_time(self) -> float:
|
||||
return self.start_event.elapsed_time(self.end_event) / 1000 # ms to seconds
|
||||
|
||||
|
||||
class Benchmark:
|
||||
"""Main benchmark runner."""
|
||||
|
||||
def __init__(self, config: BenchmarkConfig):
|
||||
self.config = config
|
||||
try:
|
||||
self.model = self._load_model()
|
||||
if self.model is None:
|
||||
raise ValueError("Model initialization failed - model is None")
|
||||
|
||||
self.cuda_graphs = (
|
||||
CUDAGraphContainer(self.model, config.seq_length)
|
||||
if config.use_cuda_graphs
|
||||
else None
|
||||
)
|
||||
self.timer = Timer()
|
||||
except Exception as e:
|
||||
print(f"ERROR in benchmark initialization: {str(e)}")
|
||||
raise
|
||||
|
||||
def _load_model(self) -> nn.Module:
|
||||
print(f"Loading model from {self.config.model_path}...")
|
||||
|
||||
try:
|
||||
# Int4 quantization using HuggingFace integration
|
||||
if self.config.use_int4:
|
||||
import bitsandbytes as bnb
|
||||
print(f"- bitsandbytes version: {bnb.__version__}")
|
||||
|
||||
# 检查是否使用自定义的8bit量化
|
||||
if hasattr(self.config, 'use_linear8bitlt') and self.config.use_linear8bitlt:
|
||||
print("- Using custom Linear8bitLt replacement for all linear layers")
|
||||
|
||||
# 加载原始模型(不使用量化配置)
|
||||
import bitsandbytes as bnb
|
||||
import torch
|
||||
# set default to half
|
||||
torch.set_default_dtype(torch.float16)
|
||||
compute_dtype = torch.float16 if self.config.use_fp16 else torch.float32
|
||||
model = AutoModel.from_pretrained(
|
||||
self.config.model_path,
|
||||
torch_dtype=compute_dtype,
|
||||
)
|
||||
|
||||
# 定义替换函数
|
||||
def replace_linear_with_linear8bitlt(model):
|
||||
"""递归地将模型中的所有nn.Linear层替换为Linear8bitLt"""
|
||||
for name, module in list(model.named_children()):
|
||||
if isinstance(module, nn.Linear):
|
||||
# 获取原始线性层的参数
|
||||
in_features = module.in_features
|
||||
out_features = module.out_features
|
||||
bias = module.bias is not None
|
||||
|
||||
# 创建8bit线性层
|
||||
# print size
|
||||
print(f"in_features: {in_features}, out_features: {out_features}")
|
||||
new_module = bnb.nn.Linear8bitLt(
|
||||
in_features,
|
||||
out_features,
|
||||
bias=bias,
|
||||
has_fp16_weights=False
|
||||
)
|
||||
|
||||
# 复制权重和偏置
|
||||
new_module.weight.data = module.weight.data
|
||||
if bias:
|
||||
new_module.bias.data = module.bias.data
|
||||
|
||||
# 替换模块
|
||||
setattr(model, name, new_module)
|
||||
else:
|
||||
# 递归处理子模块
|
||||
replace_linear_with_linear8bitlt(module)
|
||||
|
||||
return model
|
||||
|
||||
# 替换所有线性层
|
||||
model = replace_linear_with_linear8bitlt(model)
|
||||
# add torch compile
|
||||
model = torch.compile(model)
|
||||
|
||||
# 将模型移到GPU(量化发生在这里)
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
model = model.to(device)
|
||||
|
||||
print("- All linear layers replaced with Linear8bitLt")
|
||||
|
||||
else:
|
||||
# 使用原来的Int4量化方法
|
||||
print("- Using bitsandbytes for Int4 quantization")
|
||||
|
||||
# Create quantization config
|
||||
|
||||
compute_dtype = torch.float16 if self.config.use_fp16 else torch.float32
|
||||
quantization_config = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_compute_dtype=compute_dtype,
|
||||
bnb_4bit_use_double_quant=True,
|
||||
bnb_4bit_quant_type="nf4"
|
||||
)
|
||||
|
||||
print("- Quantization config:", quantization_config)
|
||||
|
||||
# Load model directly with quantization config
|
||||
model = AutoModel.from_pretrained(
|
||||
self.config.model_path,
|
||||
quantization_config=quantization_config,
|
||||
torch_dtype=compute_dtype,
|
||||
device_map="auto" # Let HF decide on device mapping
|
||||
)
|
||||
|
||||
# Check if model loaded successfully
|
||||
if model is None:
|
||||
raise ValueError("Model loading returned None")
|
||||
|
||||
print(f"- Model type: {type(model)}")
|
||||
|
||||
# Apply optimizations directly here
|
||||
print("\nApplying model optimizations:")
|
||||
|
||||
if hasattr(self.config, 'use_linear8bitlt') and self.config.use_linear8bitlt:
|
||||
print("- Model moved to GPU with Linear8bitLt quantization")
|
||||
else:
|
||||
# Skip moving to GPU since device_map="auto" already did that
|
||||
print("- Model already on GPU due to device_map='auto'")
|
||||
|
||||
# Skip FP16 conversion since we specified compute_dtype
|
||||
print(f"- Using {compute_dtype} for compute dtype")
|
||||
|
||||
# Check CUDA and SDPA
|
||||
if torch.version.cuda and float(torch.version.cuda[:3]) >= 11.6:
|
||||
if hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
|
||||
print("- Using PyTorch SDPA (scaled_dot_product_attention)")
|
||||
else:
|
||||
print("- PyTorch SDPA not available")
|
||||
|
||||
# Try xformers if available
|
||||
try:
|
||||
from xformers.ops import memory_efficient_attention
|
||||
if hasattr(model, 'enable_xformers_memory_efficient_attention'):
|
||||
model.enable_xformers_memory_efficient_attention()
|
||||
print("- Enabled xformers memory efficient attention")
|
||||
else:
|
||||
print("- Model doesn't support xformers")
|
||||
except (ImportError, AttributeError):
|
||||
print("- Xformers not available")
|
||||
|
||||
# Set to eval mode
|
||||
model.eval()
|
||||
print("- Model set to eval mode")
|
||||
# Int8 quantization using HuggingFace integration
|
||||
# Int8 quantization using TorchAO
|
||||
elif self.config.use_int8:
|
||||
print("- Using TorchAO for Int8 dynamic activation and Int8 weight quantization")
|
||||
|
||||
# Import the quantize_ function and the quantization config
|
||||
from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight
|
||||
print("- Successfully imported TorchAO")
|
||||
|
||||
# Load model normally first
|
||||
# set default to half
|
||||
import torch
|
||||
torch.set_default_dtype(torch.bfloat16)
|
||||
model = AutoModel.from_pretrained(
|
||||
self.config.model_path,
|
||||
device_map="auto"
|
||||
)
|
||||
|
||||
print("- Model loaded in full precision")
|
||||
print(f"- Model type: {type(model)}")
|
||||
|
||||
# Apply quantization - call the function to get the config, then apply it
|
||||
# quantize_(model, int8_dynamic_activation_int8_weight())
|
||||
# from torchao.quantization import quantize_, Int8DynamicActivationInt8WeightConfig,int8_dynamic_activation_int8_semi_sparse_weight,int4_weight_only,Int8DynActInt4WeightGPTQQuantizer,int8_dynamic_activation_int4_weight,Int8DynamicActivationInt4WeightConfig,Int4DynamicActivationInt4WeightConfig
|
||||
from torchao.quantization import quantize_, Int8DynamicActivationInt8WeightConfig
|
||||
quantize_(model, Int8DynamicActivationInt8WeightConfig())
|
||||
print("- Model successfully quantized with int8 weights and int8 activations")
|
||||
# add torch compile
|
||||
model = torch.compile(model)
|
||||
# For older PyTorch versions that have issues with tensor subclasses
|
||||
from torchao.utils import unwrap_tensor_subclass
|
||||
import torch
|
||||
if hasattr(torch, '_version') and not torch.version >= "2.5.0":
|
||||
print("- Unwrapping tensor subclasses for compatibility with older PyTorch")
|
||||
unwrap_tensor_subclass(model)
|
||||
|
||||
# Apply optimizations
|
||||
if torch.version.cuda and float(torch.version.cuda[:3]) >= 11.6:
|
||||
if hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
|
||||
print("- Using PyTorch SDPA (scaled_dot_product_attention)")
|
||||
else:
|
||||
print("- PyTorch SDPA not available")
|
||||
|
||||
# Set to eval mode
|
||||
model.eval()
|
||||
print("- Model set to eval mode")
|
||||
|
||||
# For better performance with int8 dynamic quantization
|
||||
torch._inductor.config.force_fuse_int_mm_with_mul = True
|
||||
print("- Enabled fusion of int matmul with mul operations")
|
||||
|
||||
|
||||
|
||||
else:
|
||||
# Standard loading for FP16/FP32
|
||||
model = AutoModel.from_pretrained(self.config.model_path)
|
||||
print("- Model loaded in standard precision")
|
||||
print(f"- Model type: {type(model)}")
|
||||
|
||||
# Apply standard optimizations
|
||||
# set default to half
|
||||
import torch
|
||||
torch.set_default_dtype(torch.bfloat16)
|
||||
model = ModelOptimizer.optimize(model, self.config)
|
||||
model = model.half()
|
||||
# add torch compile
|
||||
model = torch.compile(model)
|
||||
|
||||
# Final check to ensure model is not None
|
||||
if model is None:
|
||||
raise ValueError("Model is None after optimization")
|
||||
|
||||
print(f"- Final model type: {type(model)}")
|
||||
return model
|
||||
|
||||
except Exception as e:
|
||||
print(f"ERROR loading model: {str(e)}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
raise
|
||||
|
||||
def _create_random_batch(self, batch_size: int) -> torch.Tensor:
|
||||
return torch.randint(
|
||||
0, 1000,
|
||||
(batch_size, self.config.seq_length),
|
||||
device="cuda",
|
||||
dtype=torch.long
|
||||
)
|
||||
|
||||
def _run_inference(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
cuda_graph_wrapper: Optional[CUDAGraphWrapper] = None
|
||||
) -> Tuple[float, torch.Tensor]:
|
||||
attention_mask = torch.ones_like(input_ids)
|
||||
|
||||
with torch.no_grad(), self.timer.timing():
|
||||
if cuda_graph_wrapper is not None:
|
||||
output = cuda_graph_wrapper(input_ids, attention_mask)
|
||||
else:
|
||||
output = self.model(input_ids=input_ids, attention_mask=attention_mask)
|
||||
|
||||
return self.timer.elapsed_time(), output
|
||||
|
||||
def run(self) -> Dict[int, Dict[str, float]]:
|
||||
results = {}
|
||||
|
||||
# Reset peak memory stats
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
|
||||
for batch_size in self.config.batch_sizes:
|
||||
print(f"\nTesting batch size: {batch_size}")
|
||||
times = []
|
||||
|
||||
# Get or create CUDA graph for this batch size
|
||||
cuda_graph_wrapper = (
|
||||
self.cuda_graphs.get_or_create(batch_size)
|
||||
if self.cuda_graphs is not None
|
||||
else None
|
||||
)
|
||||
|
||||
# Pre-allocate input tensor
|
||||
input_ids = self._create_random_batch(batch_size)
|
||||
print(f"Input shape: {input_ids.shape}")
|
||||
|
||||
# Run benchmark
|
||||
for i in tqdm(range(self.config.num_runs), desc=f"Batch size {batch_size}"):
|
||||
try:
|
||||
elapsed_time, output = self._run_inference(input_ids, cuda_graph_wrapper)
|
||||
if i == 0: # Only print on first run
|
||||
print(f"Output shape: {output.last_hidden_state.shape}")
|
||||
times.append(elapsed_time)
|
||||
except Exception as e:
|
||||
print(f"Error during inference: {e}")
|
||||
break
|
||||
|
||||
if not times:
|
||||
print(f"No successful runs for batch size {batch_size}, skipping")
|
||||
continue
|
||||
|
||||
# Calculate statistics
|
||||
avg_time = np.mean(times)
|
||||
std_time = np.std(times)
|
||||
throughput = batch_size / avg_time
|
||||
|
||||
results[batch_size] = {
|
||||
"avg_time": avg_time,
|
||||
"std_time": std_time,
|
||||
"throughput": throughput,
|
||||
}
|
||||
|
||||
print(f"Avg Time: {avg_time:.4f}s ± {std_time:.4f}s")
|
||||
print(f"Throughput: {throughput:.2f} sequences/second")
|
||||
|
||||
# Log memory usage
|
||||
peak_memory_gb = torch.cuda.max_memory_allocated() / (1024 ** 3)
|
||||
print(f"\nPeak GPU memory usage: {peak_memory_gb:.2f} GB")
|
||||
|
||||
# Add memory info to results
|
||||
for batch_size in results:
|
||||
results[batch_size]["peak_memory_gb"] = peak_memory_gb
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Model Inference Benchmark")
|
||||
parser.add_argument(
|
||||
"--model_path",
|
||||
type=str,
|
||||
default="facebook/contriever",
|
||||
help="Path to the model",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--batch_sizes",
|
||||
type=str,
|
||||
default="1,2,4,8,10,16,20,32,40,64,128,256,512,1024,2048,4096,8192",
|
||||
help="Comma-separated list of batch sizes",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--seq_length",
|
||||
type=int,
|
||||
default=256,
|
||||
help="Sequence length for input",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_runs",
|
||||
type=int,
|
||||
default=5,
|
||||
help="Number of runs for each batch size",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_fp16",
|
||||
action="store_true",
|
||||
help="Enable FP16 inference",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_int4",
|
||||
action="store_true",
|
||||
help="Enable INT4 quantization using bitsandbytes",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_int8",
|
||||
action="store_true",
|
||||
help="Enable INT8 quantization for both activations and weights using bitsandbytes",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_cuda_graphs",
|
||||
action="store_true",
|
||||
help="Enable CUDA Graphs optimization",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_flash_attention",
|
||||
action="store_true",
|
||||
help="Enable Flash Attention 2 if available",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_linear8bitlt",
|
||||
action="store_true",
|
||||
help="Enable Linear8bitLt quantization for all linear layers",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Print arguments for debugging
|
||||
print("\nCommand line arguments:")
|
||||
for arg, value in vars(args).items():
|
||||
print(f"- {arg}: {value}")
|
||||
|
||||
config = BenchmarkConfig(
|
||||
model_path=args.model_path,
|
||||
batch_sizes=[int(bs) for bs in args.batch_sizes.split(",")],
|
||||
seq_length=args.seq_length,
|
||||
num_runs=args.num_runs,
|
||||
use_fp16=args.use_fp16,
|
||||
use_int4=args.use_int4,
|
||||
use_int8=args.use_int8, # Add this line
|
||||
use_cuda_graphs=args.use_cuda_graphs,
|
||||
use_flash_attention=args.use_flash_attention,
|
||||
use_linear8bitlt=args.use_linear8bitlt,
|
||||
)
|
||||
|
||||
# Print configuration for debugging
|
||||
print("\nBenchmark configuration:")
|
||||
for field, value in vars(config).items():
|
||||
print(f"- {field}: {value}")
|
||||
|
||||
try:
|
||||
benchmark = Benchmark(config)
|
||||
results = benchmark.run()
|
||||
|
||||
# Save results to file
|
||||
import json
|
||||
import os
|
||||
|
||||
# Create results directory if it doesn't exist
|
||||
os.makedirs("results", exist_ok=True)
|
||||
|
||||
# Generate filename based on configuration
|
||||
precision_type = "int4" if config.use_int4 else "fp16" if config.use_fp16 else "fp32"
|
||||
model_name = os.path.basename(config.model_path)
|
||||
output_file = f"results/benchmark_{model_name}_{precision_type}.json"
|
||||
|
||||
# Save results
|
||||
with open(output_file, "w") as f:
|
||||
json.dump(
|
||||
{
|
||||
"config": {k: str(v) if isinstance(v, list) else v for k, v in vars(config).items()},
|
||||
"results": {str(k): v for k, v in results.items()}
|
||||
},
|
||||
f,
|
||||
indent=2
|
||||
)
|
||||
print(f"Results saved to {output_file}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Benchmark failed: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
376
research/micro/embedd_micro_seq.py
Normal file
@@ -0,0 +1,376 @@
|
||||
import argparse
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import AutoModel
|
||||
from tqdm import tqdm
|
||||
from contextlib import contextmanager
|
||||
import math
|
||||
|
||||
@dataclass
|
||||
class BenchmarkConfig:
|
||||
model_path: str
|
||||
batch_sizes: List[int]
|
||||
seq_length: int
|
||||
num_runs: int
|
||||
use_fp16: bool = True
|
||||
use_cuda_graphs: bool = False
|
||||
use_flash_attention: bool = False
|
||||
max_batch_size: int = 256 # Maximum batch size before splitting
|
||||
|
||||
|
||||
class CUDAGraphContainer:
|
||||
"""Container for managing CUDA graphs for different batch sizes."""
|
||||
|
||||
def __init__(self, model: nn.Module, seq_length: int, max_batch_size: int):
|
||||
self.model = model
|
||||
self.seq_length = seq_length
|
||||
self.max_batch_size = max_batch_size
|
||||
self.graphs: Dict[int, CUDAGraphWrapper] = {}
|
||||
|
||||
def get_or_create(self, batch_size: int) -> 'CUDAGraphWrapper':
|
||||
# For CUDA graphs, we always use the actual batch size or max_batch_size
|
||||
effective_batch_size = min(batch_size, self.max_batch_size)
|
||||
|
||||
if effective_batch_size not in self.graphs:
|
||||
self.graphs[effective_batch_size] = CUDAGraphWrapper(
|
||||
self.model, effective_batch_size, self.seq_length
|
||||
)
|
||||
return self.graphs[effective_batch_size]
|
||||
|
||||
|
||||
class CUDAGraphWrapper:
|
||||
"""Wrapper for CUDA graph capture and replay."""
|
||||
|
||||
def __init__(self, model: nn.Module, batch_size: int, seq_length: int):
|
||||
self.model = model
|
||||
self.static_input = self._create_random_batch(batch_size, seq_length)
|
||||
self.static_attention_mask = torch.ones_like(self.static_input)
|
||||
|
||||
# Warm up
|
||||
self._warmup()
|
||||
|
||||
# Capture graph
|
||||
self.graph = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(self.graph):
|
||||
self.static_output = self.model(
|
||||
input_ids=self.static_input,
|
||||
attention_mask=self.static_attention_mask
|
||||
)
|
||||
|
||||
def _create_random_batch(self, batch_size: int, seq_length: int) -> torch.Tensor:
|
||||
return torch.randint(
|
||||
0, 1000, (batch_size, seq_length),
|
||||
device="cuda",
|
||||
dtype=torch.long
|
||||
)
|
||||
|
||||
def _warmup(self, num_warmup: int = 3):
|
||||
with torch.no_grad():
|
||||
for _ in range(num_warmup):
|
||||
self.model(
|
||||
input_ids=self.static_input,
|
||||
attention_mask=self.static_attention_mask
|
||||
)
|
||||
|
||||
def __call__(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
|
||||
self.static_input.copy_(input_ids)
|
||||
self.static_attention_mask.copy_(attention_mask)
|
||||
self.graph.replay()
|
||||
return self.static_output
|
||||
|
||||
|
||||
class ModelOptimizer:
|
||||
"""Applies various optimizations to the model."""
|
||||
|
||||
@staticmethod
|
||||
def optimize(model: nn.Module, config: BenchmarkConfig) -> nn.Module:
|
||||
print("\nApplying model optimizations:")
|
||||
|
||||
# Move to GPU
|
||||
model = model.cuda()
|
||||
print("- Model moved to GPU")
|
||||
|
||||
# FP16
|
||||
if config.use_fp16:
|
||||
model = model.half()
|
||||
print("- Using FP16 precision")
|
||||
|
||||
# Check if using SDPA
|
||||
if torch.version.cuda and float(torch.version.cuda[:3]) >= 11.6:
|
||||
if hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
|
||||
print("- Using PyTorch SDPA (scaled_dot_product_attention)")
|
||||
# No need to do anything as it's automatically enabled
|
||||
else:
|
||||
print("- PyTorch SDPA not available")
|
||||
|
||||
# Flash Attention
|
||||
if config.use_flash_attention:
|
||||
try:
|
||||
from flash_attn.flash_attention import FlashAttention
|
||||
print("- Flash Attention 2 available")
|
||||
if hasattr(model.config, "attention_mode"):
|
||||
model.config.attention_mode = "flash_attention_2"
|
||||
print(" - Enabled Flash Attention 2 mode")
|
||||
except ImportError:
|
||||
print("- Flash Attention not available")
|
||||
|
||||
# Optimize LayerNorm
|
||||
try:
|
||||
num_layernorms = 0
|
||||
for module in model.modules():
|
||||
if isinstance(module, torch.nn.LayerNorm):
|
||||
module.forward = torch.jit.script(module.forward)
|
||||
num_layernorms += 1
|
||||
if num_layernorms > 0:
|
||||
print(f"- Optimized {num_layernorms} LayerNorm modules with TorchScript")
|
||||
except Exception as e:
|
||||
print(f"- LayerNorm optimization failed: {e}")
|
||||
|
||||
# Memory efficient attention
|
||||
try:
|
||||
from xformers.ops import memory_efficient_attention
|
||||
model.enable_xformers_memory_efficient_attention()
|
||||
print("- Enabled xformers memory efficient attention")
|
||||
except (ImportError, AttributeError):
|
||||
print("- Xformers not available")
|
||||
|
||||
model.eval()
|
||||
print("- Model set to eval mode")
|
||||
|
||||
return model
|
||||
|
||||
|
||||
class Timer:
|
||||
"""Handles accurate GPU timing using CUDA events."""
|
||||
|
||||
def __init__(self):
|
||||
self.start_event = torch.cuda.Event(enable_timing=True)
|
||||
self.end_event = torch.cuda.Event(enable_timing=True)
|
||||
|
||||
@contextmanager
|
||||
def timing(self):
|
||||
self.start_event.record()
|
||||
yield
|
||||
self.end_event.record()
|
||||
self.end_event.synchronize()
|
||||
|
||||
def elapsed_time(self) -> float:
|
||||
return self.start_event.elapsed_time(self.end_event) / 1000 # ms to seconds
|
||||
|
||||
|
||||
class Benchmark:
|
||||
"""Main benchmark runner."""
|
||||
|
||||
def __init__(self, config: BenchmarkConfig):
|
||||
self.config = config
|
||||
self.model = self._load_model()
|
||||
self.cuda_graphs = (
|
||||
CUDAGraphContainer(self.model, config.seq_length, config.max_batch_size)
|
||||
if config.use_cuda_graphs
|
||||
else None
|
||||
)
|
||||
self.timer = Timer()
|
||||
|
||||
def _load_model(self) -> nn.Module:
|
||||
print(f"Loading model from {self.config.model_path}...")
|
||||
model = AutoModel.from_pretrained(self.config.model_path)
|
||||
return ModelOptimizer.optimize(model, self.config)
|
||||
|
||||
def _create_random_batch(self, batch_size: int) -> torch.Tensor:
|
||||
return torch.randint(
|
||||
0, 1000,
|
||||
(batch_size, self.config.seq_length),
|
||||
device="cuda",
|
||||
dtype=torch.long
|
||||
)
|
||||
|
||||
def _run_inference(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
cuda_graph_wrapper: Optional[CUDAGraphWrapper] = None
|
||||
) -> Tuple[float, torch.Tensor]:
|
||||
attention_mask = torch.ones_like(input_ids)
|
||||
original_batch_size = input_ids.shape[0]
|
||||
print(f"Original input_ids shape: {input_ids.shape}")
|
||||
|
||||
# Split large batches to avoid OOM
|
||||
max_batch_size = self.config.max_batch_size
|
||||
if original_batch_size > max_batch_size:
|
||||
print(f"Splitting batch of size {original_batch_size} into chunks of {max_batch_size}")
|
||||
total_time = 0
|
||||
outputs = []
|
||||
|
||||
with torch.no_grad():
|
||||
for i in range(0, original_batch_size, max_batch_size):
|
||||
end_idx = min(i + max_batch_size, original_batch_size)
|
||||
batch_slice = input_ids[i:end_idx]
|
||||
mask_slice = attention_mask[i:end_idx]
|
||||
|
||||
print(f"Processing chunk {i//max_batch_size + 1}: shape {batch_slice.shape}")
|
||||
|
||||
# Use CUDA graph if available (with the smaller batch size)
|
||||
chunk_cuda_graph = None
|
||||
if cuda_graph_wrapper is not None:
|
||||
chunk_cuda_graph = self.cuda_graphs.get_or_create(batch_slice.shape[0])
|
||||
|
||||
with self.timer.timing():
|
||||
if chunk_cuda_graph is not None:
|
||||
chunk_output = chunk_cuda_graph(batch_slice, mask_slice)
|
||||
else:
|
||||
chunk_output = self.model(input_ids=batch_slice, attention_mask=mask_slice)
|
||||
|
||||
total_time += self.timer.elapsed_time()
|
||||
outputs.append(chunk_output.last_hidden_state)
|
||||
|
||||
# Combine outputs
|
||||
combined_output = torch.cat(outputs, dim=0)
|
||||
print(f"Combined output shape: {combined_output.shape}")
|
||||
|
||||
# Create a wrapper object similar to model output to maintain consistency
|
||||
class DummyOutput:
|
||||
def __init__(self, hidden_states):
|
||||
self.last_hidden_state = hidden_states
|
||||
|
||||
output = DummyOutput(combined_output)
|
||||
return total_time, output
|
||||
else:
|
||||
# Process normally for small batches
|
||||
with torch.no_grad(), self.timer.timing():
|
||||
if cuda_graph_wrapper is not None:
|
||||
output = cuda_graph_wrapper(input_ids, attention_mask)
|
||||
else:
|
||||
output = self.model(input_ids=input_ids, attention_mask=attention_mask)
|
||||
|
||||
print(f"Output shape: {output.last_hidden_state.shape}")
|
||||
return self.timer.elapsed_time(), output
|
||||
|
||||
def run(self) -> Dict[int, Dict[str, float]]:
|
||||
results = {}
|
||||
|
||||
for batch_size in self.config.batch_sizes:
|
||||
print(f"\nTesting batch size: {batch_size}")
|
||||
times = []
|
||||
|
||||
# Get or create CUDA graph for this batch size
|
||||
cuda_graph_wrapper = None
|
||||
if self.cuda_graphs is not None:
|
||||
if batch_size <= self.config.max_batch_size:
|
||||
cuda_graph_wrapper = self.cuda_graphs.get_or_create(batch_size)
|
||||
else:
|
||||
# For large batches, we'll use the max_batch_size graph in chunks
|
||||
cuda_graph_wrapper = True # Just a flag to indicate we want to use CUDA graphs
|
||||
|
||||
# Pre-allocate input tensor
|
||||
input_ids = self._create_random_batch(batch_size)
|
||||
|
||||
# Run benchmark
|
||||
for run_idx in tqdm(range(self.config.num_runs), desc=f"Batch size {batch_size}"):
|
||||
elapsed_time, _ = self._run_inference(input_ids, cuda_graph_wrapper)
|
||||
times.append(elapsed_time)
|
||||
print(f"Run {run_idx+1}: {elapsed_time:.4f}s")
|
||||
|
||||
# Calculate statistics
|
||||
avg_time = np.mean(times)
|
||||
std_time = np.std(times)
|
||||
throughput = batch_size / avg_time
|
||||
|
||||
results[batch_size] = {
|
||||
"avg_time": avg_time,
|
||||
"std_time": std_time,
|
||||
"throughput": throughput,
|
||||
}
|
||||
|
||||
print(f"Avg Time: {avg_time:.4f}s ± {std_time:.4f}s")
|
||||
print(f"Throughput: {throughput:.2f} sequences/second")
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Model Inference Benchmark")
|
||||
parser.add_argument(
|
||||
"--model_path",
|
||||
type=str,
|
||||
default="facebook/contriever",
|
||||
help="Path to the model",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--batch_sizes",
|
||||
type=str,
|
||||
default="1,2,4,8,16,32,64,128,256,512,1024,2048,4096",
|
||||
help="Comma-separated list of batch sizes",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--seq_length",
|
||||
type=int,
|
||||
default=256,
|
||||
help="Sequence length for input",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_runs",
|
||||
type=int,
|
||||
default=5,
|
||||
help="Number of runs for each batch size",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no_fp16",
|
||||
action="store_true",
|
||||
help="Disable FP16 inference",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_cuda_graphs",
|
||||
action="store_true",
|
||||
help="Enable CUDA Graphs optimization",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_flash_attention",
|
||||
action="store_true",
|
||||
help="Enable Flash Attention 2 if available",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_batch_size",
|
||||
type=int,
|
||||
default=256,
|
||||
help="Maximum batch size before splitting to prevent OOM",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
config = BenchmarkConfig(
|
||||
model_path=args.model_path,
|
||||
batch_sizes=[int(bs) for bs in args.batch_sizes.split(",")],
|
||||
seq_length=args.seq_length,
|
||||
num_runs=args.num_runs,
|
||||
use_fp16=not args.no_fp16,
|
||||
use_cuda_graphs=args.use_cuda_graphs,
|
||||
use_flash_attention=args.use_flash_attention,
|
||||
max_batch_size=args.max_batch_size,
|
||||
)
|
||||
|
||||
benchmark = Benchmark(config)
|
||||
results = benchmark.run()
|
||||
|
||||
# Print overall summary
|
||||
print("\n===== BENCHMARK SUMMARY =====")
|
||||
print(f"Model: {config.model_path}")
|
||||
print(f"Sequence Length: {config.seq_length}")
|
||||
print(f"FP16: {config.use_fp16}")
|
||||
print(f"CUDA Graphs: {config.use_cuda_graphs}")
|
||||
print(f"Flash Attention: {config.use_flash_attention}")
|
||||
print(f"Max Batch Size: {config.max_batch_size}")
|
||||
print("\nResults:")
|
||||
|
||||
print("\nBatch Size | Avg Time (s) | Throughput (seq/s)")
|
||||
print("-" * 50)
|
||||
for bs in sorted(results.keys()):
|
||||
r = results[bs]
|
||||
print(f"{bs:^10} | {r['avg_time']:^12.4f} | {r['throughput']:^17.2f}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
218
research/micro/int4benchmark.py
Normal file
218
research/micro/int4benchmark.py
Normal file
@@ -0,0 +1,218 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import time
|
||||
import torch.nn.functional as F
|
||||
|
||||
# Import necessary functions from the quantize.py file
|
||||
def get_group_qparams(w, n_bit=4, groupsize=128):
|
||||
# needed for GPTQ with padding
|
||||
if groupsize > w.shape[-1]:
|
||||
groupsize = w.shape[-1]
|
||||
assert groupsize > 1
|
||||
assert w.shape[-1] % groupsize == 0
|
||||
assert w.dim() == 2
|
||||
|
||||
to_quant = w.reshape(-1, groupsize)
|
||||
assert torch.isnan(to_quant).sum() == 0
|
||||
|
||||
max_val = to_quant.amax(dim=1, keepdim=True)
|
||||
min_val = to_quant.amin(dim=1, keepdim=True)
|
||||
max_int = 2**n_bit - 1
|
||||
scales = (max_val - min_val).clamp(min=1e-6) / max_int
|
||||
zeros = min_val + scales * (2 ** (n_bit - 1))
|
||||
return scales.to(torch.bfloat16).reshape(w.shape[0], -1), zeros.to(
|
||||
torch.bfloat16
|
||||
).reshape(w.shape[0], -1)
|
||||
|
||||
def pack_scales_and_zeros(scales, zeros):
|
||||
assert scales.shape == zeros.shape
|
||||
assert scales.dtype == torch.bfloat16
|
||||
assert zeros.dtype == torch.bfloat16
|
||||
return (
|
||||
torch.cat(
|
||||
[
|
||||
scales.reshape(scales.size(0), scales.size(1), 1),
|
||||
zeros.reshape(zeros.size(0), zeros.size(1), 1),
|
||||
],
|
||||
2,
|
||||
)
|
||||
.transpose(0, 1)
|
||||
.contiguous()
|
||||
)
|
||||
|
||||
def group_quantize_tensor(w, n_bit=4, groupsize=128):
|
||||
scales, zeros = get_group_qparams(w, n_bit, groupsize)
|
||||
w_int32 = group_quantize_tensor_from_qparams(w, scales, zeros, n_bit, groupsize)
|
||||
scales_and_zeros = pack_scales_and_zeros(scales, zeros)
|
||||
return w_int32, scales_and_zeros
|
||||
|
||||
def group_quantize_tensor_from_qparams(w, scales, zeros, n_bit=4, groupsize=128):
|
||||
assert groupsize > 1
|
||||
# needed for GPTQ single column quantize
|
||||
if groupsize > w.shape[-1] and scales.shape[-1] == 1:
|
||||
groupsize = w.shape[-1]
|
||||
|
||||
assert w.shape[-1] % groupsize == 0
|
||||
assert w.dim() == 2
|
||||
|
||||
to_quant = w.reshape(-1, groupsize)
|
||||
assert torch.isnan(to_quant).sum() == 0
|
||||
|
||||
scales = scales.reshape(-1, 1)
|
||||
zeros = zeros.reshape(-1, 1)
|
||||
min_val = zeros - scales * (2 ** (n_bit - 1))
|
||||
max_int = 2**n_bit - 1
|
||||
min_int = 0
|
||||
w_int32 = (
|
||||
to_quant.sub(min_val)
|
||||
.div(scales)
|
||||
.round()
|
||||
.clamp_(min_int, max_int)
|
||||
.to(torch.int32)
|
||||
.reshape_as(w)
|
||||
)
|
||||
|
||||
return w_int32
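# --- Illustrative sanity check (a minimal sketch, not part of the original benchmark) ---
# Dequantization follows from the definitions above: w ≈ (q - 2**(n_bit - 1)) * scale + zero,
# applied per group. The rows/cols defaults below are assumptions chosen to fit groupsize=128.
def _check_group_quant_round_trip(rows: int = 4, cols: int = 256) -> None:
    w = torch.randn(rows, cols, dtype=torch.bfloat16)
    scales, zeros = get_group_qparams(w, n_bit=4, groupsize=128)
    q = group_quantize_tensor_from_qparams(w, scales, zeros, n_bit=4, groupsize=128)
    # Reshape into per-group rows and apply the inverse mapping.
    q_groups = q.reshape(-1, 128).to(torch.bfloat16)
    w_hat = ((q_groups - 8) * scales.reshape(-1, 1) + zeros.reshape(-1, 1)).reshape_as(w)
    print(f"group-quant round-trip max abs error: {(w_hat - w).abs().max().item():.4f}")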
|
||||
|
||||
def prepare_int4_weight_and_scales_and_zeros(weight_bf16, groupsize, inner_k_tiles):
|
||||
weight_int32, scales_and_zeros = group_quantize_tensor(
|
||||
weight_bf16, n_bit=4, groupsize=groupsize
|
||||
)
|
||||
weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(weight_int32, inner_k_tiles)
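# Note: _convert_weight_to_int4pack and _weight_int4pack_mm are private (underscore) aten ops;
# their availability and expected packed layout can vary across PyTorch releases.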
|
||||
return weight_int4pack, scales_and_zeros
|
||||
|
||||
def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, groupsize):
|
||||
origin_x_size = x.size()
|
||||
x = x.reshape(-1, origin_x_size[-1])
|
||||
c = torch.ops.aten._weight_int4pack_mm(x, weight_int4pack, groupsize, scales_and_zeros)
|
||||
new_shape = origin_x_size[:-1] + (out_features,)
|
||||
c = c.reshape(new_shape)
|
||||
return c
|
||||
|
||||
class WeightOnlyInt4Linear(torch.nn.Module):
|
||||
__constants__ = ['in_features', 'out_features']
|
||||
in_features: int
|
||||
out_features: int
|
||||
weight: torch.Tensor
|
||||
|
||||
def __init__(
|
||||
self, in_features: int, out_features: int,
|
||||
bias=False, device=None, dtype=None, groupsize: int = 128, inner_k_tiles: int = 8
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.in_features = in_features
|
||||
self.out_features = out_features
|
||||
self.groupsize = groupsize
|
||||
self.inner_k_tiles = inner_k_tiles
|
||||
|
||||
assert out_features % 8 == 0, "require out_features % 8 == 0"
|
||||
assert in_features % (inner_k_tiles * 16) == 0, "require in_features % (innerKTiles * 16) == 0"
|
||||
self.register_buffer(
|
||||
"weight",
|
||||
torch.empty((out_features // 8, in_features // (inner_k_tiles * 16), 32, inner_k_tiles // 2), dtype=torch.int32)
|
||||
)
|
||||
self.register_buffer(
|
||||
"scales_and_zeros",
|
||||
torch.empty((in_features // groupsize, out_features, 2), dtype=torch.bfloat16)
|
||||
)
|
||||
|
||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||
input = input.to(torch.bfloat16)
|
||||
return linear_forward_int4(
|
||||
input,
|
||||
self.weight, self.scales_and_zeros, self.out_features, self.groupsize
|
||||
)
|
||||
|
||||
# Define dimensions that satisfy the requirements for INT4 quantization
|
||||
# in_features must be divisible by inner_k_tiles * 16
|
||||
# out_features must be divisible by 8
|
||||
in_features = 1024 # Must be divisible by inner_k_tiles * 16
|
||||
out_features = 2048 # Must be divisible by 8
|
||||
groupsize = 128
|
||||
inner_k_tiles = 8
|
||||
|
||||
# Create models
|
||||
fp16_model = nn.Sequential(
|
||||
nn.Linear(in_features, out_features, bias=False)
|
||||
)
|
||||
|
||||
# Create INT4 model
|
||||
int4_model = nn.Sequential(
|
||||
WeightOnlyInt4Linear(in_features, out_features, bias=False,
|
||||
groupsize=groupsize, inner_k_tiles=inner_k_tiles)
|
||||
)
|
||||
|
||||
# Quantize the weights and set up the INT4 model
|
||||
with torch.no_grad():
|
||||
# Convert FP16 weights to INT4
|
||||
fp16_weight = fp16_model[0].weight.data.to(torch.bfloat16)
|
||||
weight_int4pack, scales_and_zeros = prepare_int4_weight_and_scales_and_zeros(
|
||||
fp16_weight, groupsize, inner_k_tiles
|
||||
)
|
||||
|
||||
# Set the quantized weights in the INT4 model
|
||||
int4_model[0].weight.copy_(weight_int4pack)
|
||||
int4_model[0].scales_and_zeros.copy_(scales_and_zeros)
|
||||
|
||||
# Move models to GPU
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
fp16_model = fp16_model.to(device)
|
||||
int4_model = int4_model.to(device)
|
||||
|
||||
# Create random input tensor
|
||||
batch_size = 1024
|
||||
input_tensor = torch.randn(batch_size, in_features, device=device)
|
||||
input_tensor_bf16 = input_tensor.to(torch.bfloat16)
|
||||
|
||||
# Speed test function
|
||||
def speed_test(model, input_tensor, name, num_iterations=100):
|
||||
# Warmup
|
||||
for _ in range(10):
|
||||
_ = model(input_tensor)
|
||||
|
||||
# Actual timing
|
||||
torch.cuda.synchronize()
|
||||
start_time = time.time()
|
||||
|
||||
for _ in range(num_iterations):
|
||||
_ = model(input_tensor)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
end_time = time.time()
|
||||
|
||||
avg_time = (end_time - start_time) / num_iterations
|
||||
print(f"{name} model: {avg_time:.6f} seconds per iteration")
|
||||
return avg_time
|
||||
|
||||
# Run speed tests
|
||||
with torch.no_grad(): # Disable gradient calculation for inference
|
||||
print(f"Running benchmark with batch_size={batch_size}, in_features={in_features}, out_features={out_features}")
|
||||
print(f"INT4 parameters: groupsize={groupsize}, inner_k_tiles={inner_k_tiles}")
|
||||
|
||||
fp16_time = speed_test(fp16_model, input_tensor_bf16, "FP16")
|
||||
int4_time = speed_test(int4_model, input_tensor, "INT4")
|
||||
|
||||
# Calculate speedup
|
||||
speedup = fp16_time / int4_time
|
||||
print(f"INT4 is {speedup:.2f}x faster than FP16")
|
||||
|
||||
# Calculate memory savings
|
||||
fp16_memory = fp16_model[0].weight.nelement() * fp16_model[0].weight.element_size()
|
||||
int4_memory = (int4_model[0].weight.nelement() * int4_model[0].weight.element_size() +
|
||||
int4_model[0].scales_and_zeros.nelement() * int4_model[0].scales_and_zeros.element_size())
|
||||
|
||||
memory_reduction = fp16_memory / int4_memory
|
||||
print(f"Memory reduction: {memory_reduction:.2f}x ({fp16_memory/1024/1024:.2f} MB vs {int4_memory/1024/1024:.2f} MB)")
|
||||
|
||||
# Check accuracy
|
||||
with torch.no_grad():
|
||||
fp16_output = fp16_model(input_tensor_bf16)
|
||||
int4_output = int4_model(input_tensor)
|
||||
|
||||
# Calculate error metrics
|
||||
abs_error = torch.abs(fp16_output - int4_output)
|
||||
rel_error = abs_error / (torch.abs(fp16_output) + 1e-7)
|
||||
|
||||
print(f"Mean absolute error: {abs_error.mean().item():.6f}")
|
||||
print(f"Max absolute error: {abs_error.max().item():.6f}")
|
||||
print(f"Mean relative error: {rel_error.mean().item():.6f}")
|
||||
83
research/micro/int8.py
Normal file
83
research/micro/int8.py
Normal file
@@ -0,0 +1,83 @@
|
||||
import torch
|
||||
import nvmath.bindings.cublas
|
||||
import ctypes
|
||||
|
||||
# Create a cuBLAS handle
|
||||
handle = nvmath.bindings.cublas.create()
|
||||
|
||||
# Prepare data - use uint8 dtype and make sure the memory is contiguous
|
||||
m, n, k = 64, 32, 48
|
||||
a = (torch.rand(m, k, device="cuda") * 255).to(torch.uint8).contiguous()
|
||||
b = (torch.rand(k, n, device="cuda") * 255).to(torch.uint8).contiguous()
|
||||
c = torch.zeros(m, n, device="cuda", dtype=torch.uint8).contiguous()
|
||||
|
||||
# Make sure the tensors are on CUDA
|
||||
assert a.is_cuda and b.is_cuda and c.is_cuda
|
||||
# Make sure the tensors are contiguous
|
||||
assert a.is_contiguous() and b.is_contiguous() and c.is_contiguous()
|
||||
|
||||
# Get raw data pointers
|
||||
a_ptr = a.data_ptr()
|
||||
b_ptr = b.data_ptr()
|
||||
c_ptr = c.data_ptr()
|
||||
|
||||
# Set operation parameters
|
||||
transa = 0  # CUBLAS_OP_N (no transpose)
|
||||
transb = 0  # CUBLAS_OP_N (no transpose)
|
||||
transc = 0  # CUBLAS_OP_N (no transpose)
|
||||
|
||||
# Set bias values
|
||||
a_bias = 0
|
||||
b_bias = 0
|
||||
c_bias = 0
|
||||
|
||||
# Set the correct leading dimensions
|
||||
lda = k  # leading dimension of A
|
||||
ldb = n  # leading dimension of B
|
||||
ldc = n  # leading dimension of C
|
||||
|
||||
c_mult = 1
|
||||
c_shift = 0
|
||||
|
||||
# Print debug information
|
||||
print(f"a shape: {a.shape}, a_ptr: {a_ptr}")
|
||||
print(f"b shape: {b.shape}, b_ptr: {b_ptr}")
|
||||
print(f"c shape: {c.shape}, c_ptr: {c_ptr}")
|
||||
|
||||
try:
|
||||
# Call uint8gemm_bias
|
||||
nvmath.bindings.cublas.uint8gemm_bias(
|
||||
handle,
|
||||
transa, transb, transc,
|
||||
m, n, k,
|
||||
a_ptr, a_bias, lda,
|
||||
b_ptr, b_bias, ldb,
|
||||
c_ptr, c_bias, ldc,
|
||||
c_mult, c_shift
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
# Try converting the pointers with ctypes
|
||||
a_ptr_c = ctypes.c_void_p(a_ptr).value
|
||||
b_ptr_c = ctypes.c_void_p(b_ptr).value
|
||||
c_ptr_c = ctypes.c_void_p(c_ptr).value
|
||||
|
||||
print(f"Using ctypes: a_ptr: {a_ptr_c}, b_ptr: {b_ptr_c}, c_ptr: {c_ptr_c}")
|
||||
|
||||
# Try the call again
|
||||
nvmath.bindings.cublas.uint8gemm_bias(
|
||||
handle,
|
||||
transa, transb, transc,
|
||||
m, n, k,
|
||||
a_ptr_c, a_bias, lda,
|
||||
b_ptr_c, b_bias, ldb,
|
||||
c_ptr_c, c_bias, ldc,
|
||||
c_mult, c_shift
|
||||
)
|
||||
|
||||
# Destroy the cuBLAS handle
|
||||
nvmath.bindings.cublas.destroy(handle)
|
||||
|
||||
# Print the result
|
||||
print("Result:")
|
||||
print(c)
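# --- Optional cross-check (illustrative sketch, assuming the cuBLAS call above succeeded) ---
# A plain int32 matmul on CPU gives a wider-precision reference; the cuBLAS output buffer
# above is uint8, so any accumulated value beyond 255 will not match this reference exactly.
ref = torch.matmul(a.to(torch.int32).cpu(), b.to(torch.int32).cpu())
print("int32 reference (CPU):")
print(ref)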
|
||||
23
research/micro/llm_compress.py
Normal file
23
research/micro/llm_compress.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from llmcompressor.modifiers.smoothquant import SmoothQuantModifier
|
||||
from llmcompressor.modifiers.quantization import GPTQModifier
|
||||
from llmcompressor import oneshot
|
||||
|
||||
# Select quantization algorithm. In this case, we:
|
||||
# * apply SmoothQuant to make the activations easier to quantize
|
||||
# * quantize the weights to int8 with GPTQ (static per channel)
|
||||
# * quantize the activations to int8 (dynamic per token)
|
||||
recipe = [
|
||||
SmoothQuantModifier(smoothing_strength=0.8),
|
||||
GPTQModifier(scheme="W8A8", targets="Linear", ignore=["lm_head"]),
|
||||
]
|
||||
|
||||
# Apply quantization using the built-in open_platypus dataset.
|
||||
# * See examples for demos showing how to pass a custom calibration set
|
||||
oneshot(
|
||||
model="facebook/contriever",
|
||||
dataset="open_platypus",
|
||||
recipe=recipe,
|
||||
output_dir="contriever-INT4",
|
||||
max_seq_length=2048,
|
||||
num_calibration_samples=512,
|
||||
)
|
||||
41
research/micro/nvmath_test.py
Normal file
41
research/micro/nvmath_test.py
Normal file
@@ -0,0 +1,41 @@
|
||||
# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED.
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
"""
|
||||
This example demonstrates basic matrix multiplication of FP8 tensors.
|
||||
|
||||
In narrow-precision operations, quantization scales must be provided for each tensor. These
|
||||
scales are used to dequantize input operands and quantize the result. Without proper
|
||||
scaling, the results of FP8 operations will likely exceed the type's range.
|
||||
|
||||
FP8 is only supported with cuBLAS 12.8 or newer and on devices with compute
|
||||
capability 8.9 or higher.
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
import nvmath
|
||||
|
||||
# Prepare sample input data. Note that N, M and K must be divisible by 16 for FP8.
|
||||
# cuBLAS requires B to be column-major, so we first create a row-major tensor and then
|
||||
# transpose it.
|
||||
m, n, k = 64, 32, 48
|
||||
a = (torch.rand(m, k, device="cuda") * 10).type(torch.float8_e4m3fn)
|
||||
b = (torch.rand(n, k, device="cuda") * 10).type(torch.float8_e4m3fn).T
|
||||
|
||||
# Prepare quantization scales. The scales must allow the result to fit within the dynamic
|
||||
# range of the data type used. Scales can be provided either as a dictionary or as a
|
||||
# MatmulQuantizationScales object. Note that scales are only allowed for FP8 operands.
|
||||
scales = {"a": 1, "b": 1, "d": 0.1}
|
||||
|
||||
# Perform the multiplication. The result of the multiplication will be:
|
||||
# (scales.a * A) @ (scales.b * B) * scales.d
|
||||
result = nvmath.linalg.advanced.matmul(a, b, quantization_scales=scales)
|
||||
|
||||
# Check how scaling helped to fit into the dynamic range of float8_e4m3fn type.
|
||||
result_without_scaling = nvmath.linalg.advanced.matmul(a, b, quantization_scales={"a": 1, "b": 1, "d": 1})
|
||||
print("Without scaling, most of the elements were clamped to the maximum value of float8_e4m3fn type (448):")
|
||||
print(result_without_scaling)
|
||||
print(f"\nWith D scale set to {scales['d']}, they were scaled down to fit into the dynamic range of float8_e4m3fn:")
|
||||
print(result)
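# --- Optional cross-check (illustrative sketch; assumes float8 -> bfloat16 casts are supported
# in the installed PyTorch build) ---
# Compare the scaled FP8 result against a bfloat16 matmul on the same operands.
ref = torch.matmul(a.to(torch.bfloat16), b.to(torch.bfloat16)) * scales["d"]
print("\nbfloat16 reference (scaled by d):")
print(ref)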
|
||||
0
research/micro/result.md
Normal file
0
research/micro/result.md
Normal file
58
research/micro/save_small_model.py
Normal file
58
research/micro/save_small_model.py
Normal file
@@ -0,0 +1,58 @@
|
||||
import os
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from pathlib import Path
|
||||
|
||||
def save_model_in_pth_format(model_name, output_dir):
|
||||
"""
|
||||
Download a model from Hugging Face and save it in PTH format
|
||||
for use with quantization benchmarks.
|
||||
|
||||
Args:
|
||||
model_name: Name of the model on Hugging Face
|
||||
output_dir: Directory to save the model
|
||||
"""
|
||||
print(f"Loading model {model_name}...")
|
||||
|
||||
# Create output directory if it doesn't exist
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
# Load tokenizer and model
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_name,
|
||||
torch_dtype=torch.bfloat16,
|
||||
low_cpu_mem_usage=True
|
||||
)
|
||||
|
||||
# Save tokenizer
|
||||
tokenizer.save_pretrained(output_dir)
|
||||
|
||||
# Extract and save the model weights in PTH format
|
||||
model_state_dict = model.state_dict()
|
||||
|
||||
# Save the model weights
|
||||
model_path = Path(output_dir) / "model.pth"
|
||||
torch.save(model_state_dict, model_path)
|
||||
|
||||
print(f"Model saved to {model_path}")
|
||||
|
||||
# Print model size information
|
||||
param_count = sum(p.numel() for p in model.parameters())
|
||||
model_size_mb = sum(p.numel() * p.element_size() for p in model.parameters()) / (1024 * 1024)
|
||||
|
||||
print(f"Model parameters: {param_count:,}")
|
||||
print(f"Model size: {model_size_mb:.2f} MB")
|
||||
|
||||
return model_path
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Use a small model for testing
|
||||
model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
|
||||
output_dir = "./tinyllama-1.1b-chat"
|
||||
|
||||
model_path = save_model_in_pth_format(model_name, output_dir)
|
||||
|
||||
print("\nYou can now use this model with the INT4 benchmark script.")
|
||||
print("Example command:")
|
||||
print(f"python int4benchmark.py --model_path {model_path}")
|
||||
677
research/micro/transformer-batching-benchmark.ipynb
Normal file
677
research/micro/transformer-batching-benchmark.ipynb
Normal file
@@ -0,0 +1,677 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "cab91cfc",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/home/ubuntu/Power-RAG/.venv/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
|
||||
" from .autonotebook import tqdm as notebook_tqdm\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import copy\n",
|
||||
"import dataclasses\n",
|
||||
"import os\n",
|
||||
"import time\n",
|
||||
"import pathlib\n",
|
||||
"import itertools\n",
|
||||
"import multiprocessing\n",
|
||||
"import scipy\n",
|
||||
"import numpy as np\n",
|
||||
"import pandas as pd\n",
|
||||
"import pickle\n",
|
||||
"import gzip\n",
|
||||
"import threading\n",
|
||||
"import queue\n",
|
||||
"import pytz\n",
|
||||
"import traceback\n",
|
||||
"from datetime import datetime\n",
|
||||
"from tqdm.auto import tqdm, trange\n",
|
||||
"from typing import Any\n",
|
||||
"\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"import matplotlib.ticker as mtick\n",
|
||||
"%matplotlib inline\n",
|
||||
"%config InlineBackend.figure_format='retina'"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "8d24fbd7",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Sat Apr 12 00:10:05 2025 \n",
|
||||
"+-----------------------------------------------------------------------------------------+\n",
|
||||
"| NVIDIA-SMI 550.120 Driver Version: 550.120 CUDA Version: 12.4 |\n",
|
||||
"|-----------------------------------------+------------------------+----------------------+\n",
|
||||
"| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |\n",
|
||||
"| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |\n",
|
||||
"| | | MIG M. |\n",
|
||||
"|=========================================+========================+======================|\n",
|
||||
"| 0 NVIDIA A10G Off | 00000000:00:1E.0 Off | 0 |\n",
|
||||
"| 0% 27C P8 15W / 300W | 4MiB / 23028MiB | 0% Default |\n",
|
||||
"| | | N/A |\n",
|
||||
"+-----------------------------------------+------------------------+----------------------+\n",
|
||||
" \n",
|
||||
"+-----------------------------------------------------------------------------------------+\n",
|
||||
"| Processes: |\n",
|
||||
"| GPU GI CI PID Type Process name GPU Memory |\n",
|
||||
"| ID ID Usage |\n",
|
||||
"|=========================================================================================|\n",
|
||||
"| No running processes found |\n",
|
||||
"+-----------------------------------------------------------------------------------------+\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"!nvidia-smi"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "538b2c11",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def benchmark(f, *, f_setup=None, min_repeat: int, min_secs: float, tqdm_kwargs: dict | None=None) -> np.ndarray:\n",
|
||||
" latency = []\n",
|
||||
" \n",
|
||||
" # First run, ignore min_secs\n",
|
||||
" if f_setup is not None:\n",
|
||||
" f_setup()\n",
|
||||
" st = time.perf_counter_ns()\n",
|
||||
" f()\n",
|
||||
" ed = time.perf_counter_ns()\n",
|
||||
" latency.append((ed-st)/1e9)\n",
|
||||
" \n",
|
||||
" # Subsequent runs, until reaching both min_repeat and min_secs\n",
|
||||
" min_nanos = int(min_secs * 1e9)\n",
|
||||
" start_nanos = time.perf_counter_ns()\n",
|
||||
" while True:\n",
|
||||
" now_nanos = time.perf_counter_ns()\n",
|
||||
" if len(latency) > min_repeat and now_nanos - start_nanos > min_nanos:\n",
|
||||
" break\n",
|
||||
" if f_setup is not None:\n",
|
||||
" f_setup()\n",
|
||||
" st = time.perf_counter_ns()\n",
|
||||
" f()\n",
|
||||
" ed = time.perf_counter_ns()\n",
|
||||
" latency.append((ed-st)/1e9)\n",
|
||||
" return np.array(latency)\n",
|
||||
"\n",
|
||||
"def tail_mean(xs, skip=0.2):\n",
|
||||
" return xs[int(len(xs) * skip):].mean()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "02c9c9b1",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"<torch.autograd.grad_mode.set_grad_enabled at 0x7c5afc12b850>"
|
||||
]
|
||||
},
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import torch\n",
|
||||
"torch.set_grad_enabled(False)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"id": "3405fdc7",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"nd_list = list(itertools.chain(itertools.product([12, 3], [256])))\n",
|
||||
"seqlen_list = [256]\n",
|
||||
"bs_list = [2,4,8,16,32,64,128,256,512,1024,2048]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"id": "10dc981a",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[(12, 256), (3, 256)]\n",
|
||||
"[256]\n",
|
||||
"[2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print(nd_list)\n",
|
||||
"print(seqlen_list)\n",
|
||||
"print(bs_list)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"id": "7e0ee385",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def benchmark_dense(out, nd_list, seqlen_list, bs_list):\n",
|
||||
" seqlen_list = [1] + seqlen_list\n",
|
||||
" total = len(list(itertools.product(nd_list, seqlen_list, bs_list)))\n",
|
||||
" pbar = tqdm(total=total)\n",
|
||||
" for (n, d), seqlen in reversed(list(itertools.product(nd_list, seqlen_list))):\n",
|
||||
" h = n * d\n",
|
||||
" maxbs = max(bs_list)\n",
|
||||
" print(maxbs, n, d, seqlen)\n",
|
||||
" cache = torch.empty(int(256e6 // 4), dtype=torch.int, device=\"cuda:0\")\n",
|
||||
" X = torch.rand((maxbs, seqlen, h), dtype=torch.bfloat16, device=\"cuda:0\")\n",
|
||||
" W = torch.rand((h, h), dtype=torch.bfloat16, device=\"cuda:0\")\n",
|
||||
" torch.cuda.synchronize()\n",
|
||||
" for bs in reversed(bs_list):\n",
|
||||
" pbar.set_postfix(n=n, h=h, d=d, seqlen=seqlen, bs=bs)\n",
|
||||
" def run():\n",
|
||||
" torch.matmul(X[:bs], W)\n",
|
||||
" torch.cuda.synchronize()\n",
|
||||
" def clear_cache():\n",
|
||||
" cache.zero_()\n",
|
||||
" torch.cuda.synchronize()\n",
|
||||
" latency = benchmark(run, f_setup=clear_cache, min_repeat=20, min_secs=2)\n",
|
||||
" l = tail_mean(latency)\n",
|
||||
" out.append({\n",
|
||||
" \"n\": n,\n",
|
||||
" \"d\": d,\n",
|
||||
" \"seqlen\": seqlen,\n",
|
||||
" \"bs\": bs,\n",
|
||||
" \"latency\": l\n",
|
||||
" })\n",
|
||||
" pbar.update()\n",
|
||||
" del cache, X, W\n",
|
||||
" torch.cuda.empty_cache()\n",
|
||||
" pbar.close()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"id": "c206a502",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def benchmark_qk_init(out, nd_list, seqlen_list, bs_list):\n",
|
||||
" total = len(list(itertools.product(nd_list, seqlen_list, bs_list)))\n",
|
||||
" pbar = tqdm(total=total)\n",
|
||||
" for (n, d), seqlen in reversed(list(itertools.product(nd_list, seqlen_list))):\n",
|
||||
" h = n * d\n",
|
||||
" try:\n",
|
||||
" maxbs = max(b for b in bs_list if b*n*seqlen*d*2*2+b*n*seqlen**2*2 < 80e9)\n",
|
||||
" except ValueError:\n",
|
||||
" pbar.update(len(bs_list))\n",
|
||||
" continue\n",
|
||||
" cache = torch.empty(int(256e6 // 4), dtype=torch.int, device=\"cuda:0\")\n",
|
||||
" Qmax = torch.rand((maxbs, n, seqlen, d), dtype=torch.bfloat16, device=\"cuda:0\")\n",
|
||||
" Kmax = torch.rand((maxbs, n, seqlen, d), dtype=torch.bfloat16, device=\"cuda:0\")\n",
|
||||
" torch.cuda.synchronize()\n",
|
||||
" for bs in reversed(bs_list):\n",
|
||||
" pbar.set_postfix(n=n, h=h, d=d, seqlen=seqlen, bs=bs)\n",
|
||||
" if bs > maxbs:\n",
|
||||
" pbar.update()\n",
|
||||
" continue\n",
|
||||
" Q = Qmax[:bs]\n",
|
||||
" K = Kmax[:bs]\n",
|
||||
" def run():\n",
|
||||
" torch.bmm(Q.view(bs * n, seqlen, d), K.view(bs * n, seqlen, d).transpose(1, 2))\n",
|
||||
" torch.cuda.synchronize()\n",
|
||||
" def clear_cache():\n",
|
||||
" cache.zero_()\n",
|
||||
" torch.cuda.synchronize()\n",
|
||||
" latency = benchmark(run, f_setup=clear_cache, min_repeat=20, min_secs=2)\n",
|
||||
" l = tail_mean(latency)\n",
|
||||
" out.append({\n",
|
||||
" \"n\": n,\n",
|
||||
" \"d\": d,\n",
|
||||
" \"seqlen\": seqlen,\n",
|
||||
" \"bs\": bs,\n",
|
||||
" \"latency\": l\n",
|
||||
" })\n",
|
||||
" pbar.update()\n",
|
||||
" del cache, Q, K, Qmax, Kmax\n",
|
||||
" torch.cuda.empty_cache()\n",
|
||||
" pbar.close()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"id": "a3a2103c",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def benchmark_qk_ar(out, nd_list, seqlen_list, bs_list):\n",
|
||||
" total = len(list(itertools.product(nd_list, seqlen_list, bs_list)))\n",
|
||||
" pbar = tqdm(total=total)\n",
|
||||
" for (n, d), seqlen in reversed(list(itertools.product(nd_list, seqlen_list))):\n",
|
||||
" h = n * d\n",
|
||||
" try:\n",
|
||||
" maxbs = max(b for b in bs_list if b*n*(1+seqlen)*d*2+b*n*seqlen*2 < 80e9)\n",
|
||||
" except ValueError:\n",
|
||||
" pbar.update(len(bs_list))\n",
|
||||
" continue\n",
|
||||
" cache = torch.empty(int(256e6 // 4), dtype=torch.int, device=\"cuda:0\")\n",
|
||||
" Qmax = torch.rand((maxbs, n, 1, d), dtype=torch.bfloat16, device=\"cuda:0\")\n",
|
||||
" Kmax = torch.rand((maxbs, n, seqlen, d), dtype=torch.bfloat16, device=\"cuda:0\")\n",
|
||||
" torch.cuda.synchronize()\n",
|
||||
" for bs in reversed(bs_list):\n",
|
||||
" pbar.set_postfix(n=n, h=h, d=d, seqlen=seqlen, bs=bs)\n",
|
||||
" if bs > maxbs:\n",
|
||||
" pbar.update()\n",
|
||||
" continue\n",
|
||||
" Q = Qmax[:bs]\n",
|
||||
" K = Kmax[:bs]\n",
|
||||
" def run():\n",
|
||||
" torch.bmm(Q.view(bs * n, 1, d), K.view(bs * n, seqlen, d).transpose(1, 2))\n",
|
||||
" torch.cuda.synchronize()\n",
|
||||
" def clear_cache():\n",
|
||||
" cache.zero_()\n",
|
||||
" torch.cuda.synchronize()\n",
|
||||
" latency = benchmark(run, f_setup=clear_cache, min_repeat=20, min_secs=2)\n",
|
||||
" l = tail_mean(latency)\n",
|
||||
" out.append({\n",
|
||||
" \"n\": n,\n",
|
||||
" \"d\": d,\n",
|
||||
" \"seqlen\": seqlen,\n",
|
||||
" \"bs\": bs,\n",
|
||||
" \"latency\": l\n",
|
||||
" })\n",
|
||||
" pbar.update()\n",
|
||||
" del cache, Q, K, Qmax, Kmax\n",
|
||||
" torch.cuda.empty_cache()\n",
|
||||
" pbar.close()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"id": "3aaad98a",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"data = {}"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"id": "18137de3",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
" 0%| | 0/22 [00:00<?, ?it/s]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"100%|██████████| 22/22 [00:44<00:00, 2.04s/it, bs=2, d=256, h=3072, n=12, seqlen=256] \n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"db = []\n",
|
||||
"benchmark_qk_init(db, nd_list, seqlen_list, bs_list)\n",
|
||||
"data[\"qk_init\"] = db"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"id": "26c76e15",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"100%|██████████| 22/22 [00:44<00:00, 2.01s/it, bs=2, d=256, h=3072, n=12, seqlen=256] \n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"db = []\n",
|
||||
"benchmark_qk_ar(db, nd_list, seqlen_list, bs_list)\n",
|
||||
"data[\"qk_ar\"] = db"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"id": "313e36eb",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
" 0%| | 0/44 [00:00<?, ?it/s, bs=2048, d=256, h=768, n=3, seqlen=256]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"2048 3 256 256\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
" 25%|██▌ | 11/44 [00:22<01:06, 2.00s/it, bs=2048, d=256, h=768, n=3, seqlen=1] "
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"2048 3 256 1\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
" 50%|█████ | 22/44 [00:44<00:44, 2.00s/it, bs=2048, d=256, h=3072, n=12, seqlen=256]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"2048 12 256 256\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
" 75%|███████▌ | 33/44 [01:07<00:22, 2.02s/it, bs=2048, d=256, h=3072, n=12, seqlen=1] "
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"2048 12 256 1\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"100%|██████████| 44/44 [01:29<00:00, 2.03s/it, bs=2, d=256, h=3072, n=12, seqlen=1] \n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"db = []\n",
|
||||
"benchmark_dense(db, nd_list, seqlen_list, bs_list)\n",
|
||||
"data[\"dense\"] = db"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"id": "50c37959",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"with gzip.open(\"data/20230516-transformer-batching1.pkl.gz\", \"wb\") as f:\n",
|
||||
" pickle.dump(data, f)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 15,
|
||||
"id": "828ddb54",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"df_dense = (\n",
|
||||
" pd.DataFrame.from_dict(data[\"dense\"])\n",
|
||||
" .assign(h=lambda x: x[\"n\"] * x[\"d\"])\n",
|
||||
" .assign(flop=lambda x: (x[\"bs\"] * x[\"seqlen\"] * x[\"h\"]**2) * 2)\n",
|
||||
" .assign(io=lambda x: (x[\"bs\"]*x[\"seqlen\"]*x[\"h\"]*2 + x[\"h\"]**2) * 2/x['latency']/1e9)\n",
|
||||
" .assign(intensity=lambda x: x[\"flop\"] / x[\"io\"])\n",
|
||||
" .assign(throughput=lambda x: x[\"flop\"] / x[\"latency\"])\n",
|
||||
" .assign(series=\"dense\")\n",
|
||||
")\n",
|
||||
"df_qk_init = (\n",
|
||||
" pd.DataFrame.from_dict(data[\"qk_init\"])\n",
|
||||
" .assign(h=lambda x: x[\"n\"] * x[\"d\"])\n",
|
||||
" .assign(flop=lambda x: (x[\"bs\"]*x[\"n\"]*x[\"d\"]*x[\"seqlen\"]**2) * 2)\n",
|
||||
" .assign(io=lambda x: (x[\"bs\"]*x[\"n\"]*(x[\"seqlen\"]*x[\"d\"]*2 + x[\"seqlen\"]**2)) * 2/x['latency']/1e9)\n",
|
||||
" .assign(intensity=lambda x: x[\"flop\"] / x[\"io\"])\n",
|
||||
" .assign(throughput=lambda x: x[\"flop\"] / x[\"latency\"])\n",
|
||||
" .assign(series=\"qk_init\")\n",
|
||||
")\n",
|
||||
"df_qk_ar = (\n",
|
||||
" pd.DataFrame.from_dict(data[\"qk_ar\"])\n",
|
||||
" .assign(h=lambda x: x[\"n\"] * x[\"d\"])\n",
|
||||
" .assign(flop=lambda x: (x[\"bs\"]*x[\"n\"]*x[\"d\"]*x[\"seqlen\"]) * 2)\n",
|
||||
" .assign(io=lambda x: (x[\"bs\"]*x[\"n\"]*(x[\"d\"] + x[\"seqlen\"]*x[\"d\"] + x[\"seqlen\"])) * 2)\n",
|
||||
" .assign(intensity=lambda x: x[\"flop\"] / x[\"io\"])\n",
|
||||
" .assign(throughput=lambda x: x[\"bs\"] / x[\"latency\"])\n",
|
||||
" .assign(series=\"qk_ar\")\n",
|
||||
")\n",
|
||||
"pd.concat([df_dense, df_qk_init, df_qk_ar]).to_csv(\"data/transformer-batching-microbenchmarks.csv\", index=False)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 39,
|
||||
"id": "c296a395",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"<module 'pandas' from '/home/ubuntu/Power-RAG/.venv/lib/python3.10/site-packages/pandas/__init__.py'>"
|
||||
]
|
||||
},
|
||||
"execution_count": 39,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"pd\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "a25cdd5a",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "63b8a531",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import transformers"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "af90eff1",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def _gen_opt_cfg(n_layers: int, d_model: int, n_heads: int, **kwargs) -> transformers.OPTConfig:\n",
|
||||
" return transformers.OPTConfig(\n",
|
||||
" num_hidden_layers=n_layers,\n",
|
||||
" hidden_size=d_model,\n",
|
||||
" ffn_dim=d_model*4,\n",
|
||||
" num_attention_heads=n_heads,\n",
|
||||
" **kwargs\n",
|
||||
" )\n",
|
||||
"optcfg = {\n",
|
||||
" # https://arxiv.org/pdf/2205.01068.pdf Table 2.1\n",
|
||||
" \"125m\": _gen_opt_cfg(12, 768, 12),\n",
|
||||
" \"350m\": _gen_opt_cfg(24, 1024, 16),\n",
|
||||
" \"760m\": _gen_opt_cfg(24, 1536, 16),\n",
|
||||
" \"1.3b\": _gen_opt_cfg(24, 2048, 32),\n",
|
||||
" \"2.7b\": _gen_opt_cfg(32, 2560, 32),\n",
|
||||
" \"6.7b\": _gen_opt_cfg(32, 4096, 32),\n",
|
||||
" \"13b\": _gen_opt_cfg(40, 5120, 40),\n",
|
||||
" \"13b_1layer\": _gen_opt_cfg(1, 5120, 40),\n",
|
||||
" \"30b\": _gen_opt_cfg(48, 7168, 56),\n",
|
||||
" \"66b\": _gen_opt_cfg(64, 9216, 72),\n",
|
||||
" \"175b\": _gen_opt_cfg(96, 12288, 96),\n",
|
||||
" \"175b_1layer\": _gen_opt_cfg(1, 12288, 96),\n",
|
||||
"}"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "5b9ebbec",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def greedy_sample_one(model, input_ids, attention_mask=None, past_key_values=None):\n",
|
||||
" bs, tgt_len = input_ids.shape\n",
|
||||
" if past_key_values is not None:\n",
|
||||
" _bs, _num_heads, src_len, _head_dims = past_key_values[0][0].shape\n",
|
||||
" assert bs == _bs\n",
|
||||
" else:\n",
|
||||
" src_len = 0\n",
|
||||
" if attention_mask is None:\n",
|
||||
" attention_mask = torch.ones((bs, src_len + tgt_len), device=model.device)\n",
|
||||
" ret = model(\n",
|
||||
" input_ids=input_ids,\n",
|
||||
" attention_mask=attention_mask,\n",
|
||||
" past_key_values=past_key_values,\n",
|
||||
" use_cache=True, output_hidden_states=False, return_dict=True,\n",
|
||||
" )\n",
|
||||
" return ret\n",
|
||||
"\n",
|
||||
"def time_greedy_generate(model, input_ids, new_tokens):\n",
|
||||
" ts = []\n",
|
||||
" output = input_ids\n",
|
||||
" past_key_values = None\n",
|
||||
" cache = torch.empty(int(256e6 // 4), dtype=torch.int, device=model.device)\n",
|
||||
" attention_mask = torch.ones(input_ids.shape, device=model.device) \n",
|
||||
" for _ in range(new_tokens):\n",
|
||||
" cache.zero_()\n",
|
||||
" torch.cuda.synchronize()\n",
|
||||
" st = time.perf_counter_ns()\n",
|
||||
" \n",
|
||||
" ret = greedy_sample_one(model, input_ids, attention_mask, past_key_values)\n",
|
||||
" input_ids = torch.argmax(ret.logits[:, -1, :], axis=-1)[:, None]\n",
|
||||
" output = torch.cat([output, input_ids], axis=1)\n",
|
||||
" past_key_values = ret.past_key_values\n",
|
||||
" attention_mask = torch.cat([attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1)\n",
|
||||
" \n",
|
||||
" torch.cuda.synchronize()\n",
|
||||
" ed = time.perf_counter_ns()\n",
|
||||
" ts.append((ed-st)/1e9)\n",
|
||||
" return np.array(ts)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "fc92f940",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"opt_config = optcfg[\"6.7b\"]\n",
|
||||
"\n",
|
||||
"torch.set_default_dtype(torch.bfloat16)\n",
|
||||
"with transformers.modeling_utils.no_init_weights():\n",
|
||||
" model = transformers.models.opt.OPTForCausalLM(opt_config).to(\"cuda\")\n",
|
||||
"torch.set_default_dtype(torch.float32)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "c19fa396",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"db = {}\n",
|
||||
"input_tokens = 200\n",
|
||||
"new_tokens = 500\n",
|
||||
"for bs in tqdm(list(itertools.chain(range(1, 8), range(8, 16, 2), [16]))):\n",
|
||||
" x = torch.randint(1000, 10000, (bs, input_tokens), device=model.device)\n",
|
||||
" stack = []\n",
|
||||
" for _ in range(10):\n",
|
||||
" l = time_greedy_generate(model, x, new_tokens=new_tokens)\n",
|
||||
" stack.append(l)\n",
|
||||
" db[bs] = np.median(np.stack(stack), axis=0)\n",
|
||||
" del x\n",
|
||||
" torch.cuda.empty_cache()\n",
|
||||
"del model\n",
|
||||
"torch.cuda.empty_cache()\n",
|
||||
"\n",
|
||||
"with gzip.open(\"data/20230516-e2e-text-generation-batch.pkl.gz\", \"wb\") as f:\n",
|
||||
" pickle.dump(db, f)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": ".venv",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.12"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
165
research/paper_plot/acc_fig.py
Normal file
165
research/paper_plot/acc_fig.py
Normal file
@@ -0,0 +1,165 @@
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
# Set plot parameters
|
||||
plt.rcParams["font.family"] = "Helvetica"
|
||||
plt.rcParams["ytick.direction"] = "in"
|
||||
plt.rcParams["hatch.linewidth"] = 1.5
|
||||
plt.rcParams["font.weight"] = "bold"
|
||||
plt.rcParams["axes.labelweight"] = "bold"
|
||||
plt.rcParams["text.usetex"] = True
|
||||
|
||||
# Path settings
|
||||
FIGURE_PATH = "./paper_plot/figures"
|
||||
|
||||
# Load accuracy data
|
||||
acc_data = pd.read_csv("./paper_plot/data/acc.csv")
|
||||
|
||||
# Create figure with 4 subplots (one for each dataset)
|
||||
fig, axs = plt.subplots(1, 4)
|
||||
fig.set_size_inches(9, 2.5)
|
||||
|
||||
# Reduce the spacing between subplots
|
||||
# plt.subplots_adjust(wspace=0.2) # Reduced from 0.3 to 0.1
|
||||
|
||||
# Define datasets and their columns
|
||||
datasets = ["NQ", "TriviaQA", "GPQA", "HotpotQA"]
|
||||
metrics = ["Exact Match", "F1"]
|
||||
|
||||
# Define bar settings - make bars thicker
|
||||
# total_width, n = 0.9, 3 # increased total width and n for three models
|
||||
# width = total_width / n
|
||||
# The 'width' variable below now defines the distance between the centers of adjacent bars within a group.
|
||||
# It's also used as the base for calculating the actual plotted bar width.
|
||||
# Original 2 bars had centers 1.0 apart. For 3 bars, we need a smaller distance.
|
||||
# A value of 0.64 for distance between centers, with a scaling factor of 0.8 for bar width,
|
||||
# results in an actual bar width of ~0.51, and a group span of ~1.79, similar to original's ~1.76.
|
||||
n = 3 # Number of models
|
||||
width = 0.64 # Distance between centers of adjacent bars in a group
|
||||
bar_width_plotting_factor = 0.8 # Bar takes 80% of the space defined by 'width'
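# Quick arithmetic check of the comment above: plotted bar width = 0.64 * 0.8 = 0.512, and a
# three-bar group spans 2 * 0.64 + 0.512 = 1.792, close to the original two-bar span of ~1.76.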
|
||||
|
||||
# Colors and hatches
|
||||
edgecolors = ["dimgrey", "#63B8B6", "tomato"] # Added color for PQ 5
|
||||
hatches = ["/////", "xxxxx", "\\\\\\\\\\"] # Added hatch for PQ 5
|
||||
labels = ["BM25", "PQ Compressed", "Ours"] # Added PQ 5
|
||||
|
||||
# Create plots for each dataset
|
||||
for i, dataset in enumerate(datasets):
|
||||
ax = axs[i]
|
||||
|
||||
# Get data for this dataset and convert to percentages
|
||||
em_values = [
|
||||
acc_data.loc[0, f"{dataset} Exact Match"] * 100,
|
||||
acc_data.loc[1, f"{dataset} Exact Match"] * 100,
|
||||
acc_data.loc[2, f"{dataset} Exact Match"] * 100 # Added PQ 5 EM data
|
||||
]
|
||||
f1_values = [
|
||||
acc_data.loc[0, f"{dataset} F1"] * 100,
|
||||
acc_data.loc[1, f"{dataset} F1"] * 100,
|
||||
acc_data.loc[2, f"{dataset} F1"] * 100 # Added PQ 5 F1 data
|
||||
]
|
||||
|
||||
# Define x positions for bars
|
||||
# For EM: center - width, center, center + width
|
||||
# For F1: center - width, center, center + width
|
||||
group_centers = [1.0, 3.0] # Centers for EM and F1 groups
|
||||
bar_offsets = [-width, 0, width]
|
||||
|
||||
# Plot all bars on the same axis
|
||||
for metric_idx, metric_group_center in enumerate(group_centers):
|
||||
values_to_plot = em_values if metric_idx == 0 else f1_values
|
||||
for j, model_label in enumerate(labels):
|
||||
x_pos = metric_group_center + bar_offsets[j]
|
||||
bar_value = values_to_plot[j]
|
||||
|
||||
ax.bar(
|
||||
x_pos,
|
||||
bar_value,
|
||||
width=width * bar_width_plotting_factor, # Use the new factor for bar width
|
||||
color="white",
|
||||
edgecolor=edgecolors[j],
|
||||
hatch=hatches[j],
|
||||
linewidth=1.5,
|
||||
label=model_label if i == 0 and metric_idx == 0 else None # Label only once
|
||||
)
|
||||
|
||||
# Add value on top of bar
|
||||
ax.text(x_pos, bar_value + (0.1 if dataset == "GPQA" else 0.1),
|
||||
f"{bar_value:.1f}", ha='center', va='bottom',
|
||||
fontsize=9, fontweight='bold') # Reduced fontsize for text on bars
|
||||
|
||||
# Set x-ticks and labels
|
||||
ax.set_xticks(group_centers) # Position ticks at the center of each group
|
||||
xticklabels = ax.set_xticklabels(metrics, fontsize=12)
|
||||
|
||||
# Now, shift these labels slightly to the right
|
||||
# Adjust this value to control the amount of shift (in data coordinates)
|
||||
# Given your group_centers are 1.0 and 3.0, a small value like 0.05 to 0.15 might be appropriate.
|
||||
# horizontal_shift = 0.7 # Try adjusting this value
|
||||
|
||||
# for label in xticklabels:
|
||||
# # Get the current x position (which is the tick location)
|
||||
# current_x_pos = label.get_position()[0]
|
||||
# # Set the new x position by adding the shift
|
||||
# label.set_position((current_x_pos + horizontal_shift, label.get_position()[1]))
|
||||
# # Ensure the label remains horizontally centered on this new x position
|
||||
# # (set_xticklabels defaults to 'center', so this re-affirms it if needed)
|
||||
# label.set_horizontalalignment('center')
|
||||
|
||||
# Set title
|
||||
ax.set_title(dataset, fontsize=14)
|
||||
|
||||
# Set y-label for all subplots
|
||||
if i == 0:
|
||||
ax.set_ylabel("Accuracy (\%)", fontsize=12, fontweight="bold")
|
||||
else:
|
||||
# Hide y-tick labels for non-first subplots to save space
|
||||
ax.tick_params(axis='y', labelsize=10)
|
||||
|
||||
# Set y-limits based on data range
|
||||
all_values = em_values + f1_values
|
||||
max_val = max(all_values)
|
||||
min_val = min(all_values)
|
||||
|
||||
# Special handling for GPQA which has very low values
|
||||
if dataset == "GPQA":
|
||||
ax.set_ylim(0, 10.0) # Set a fixed range for GPQA
|
||||
else:
|
||||
# Reduce the extra space above the bars
|
||||
ax.set_ylim(min_val * 0.9, max_val * 1.1) # Adjusted upper limit for text
|
||||
|
||||
# Format y-ticks as percentages
|
||||
ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda y, _: ' {:.0f}'.format(y)))
|
||||
|
||||
# Set x-limits to properly space the bars with less blank space
|
||||
# ax.set_xlim(group_centers[0] - total_width, group_centers[1] + total_width)
|
||||
# Set xlim to be similar to original (0,4) for group_centers (1,3) => margin of 1.0
|
||||
ax.set_xlim(group_centers[0] - 1.0, group_centers[1] + 1.0)
|
||||
|
||||
# Add a box around the subplot
|
||||
# for spine in ax.spines.values():
|
||||
# spine.set_visible(True)
|
||||
# spine.set_linewidth(1.0)
|
||||
|
||||
# Add legend to first subplot
|
||||
if i == 0:
|
||||
ax.legend(
|
||||
bbox_to_anchor=(2.21, 1.35), # Adjusted anchor if needed
|
||||
ncol=3, # Changed to 3 columns for three labels
|
||||
loc="upper center",
|
||||
labelspacing=0.1,
|
||||
edgecolor="black",
|
||||
facecolor="white",
|
||||
framealpha=1,
|
||||
shadow=False,
|
||||
fancybox=False,
|
||||
handlelength=1.0,
|
||||
handletextpad=0.6,
|
||||
columnspacing=0.8,
|
||||
prop={"weight": "bold", "size": 12},
|
||||
)
|
||||
|
||||
# Save figure with tight layout but no additional padding
|
||||
plt.savefig(FIGURE_PATH + "/accuracy_comparison.pdf", bbox_inches='tight', pad_inches=0.05)
|
||||
plt.show()
|
||||
309
research/paper_plot/analyze_visits.py
Normal file
309
research/paper_plot/analyze_visits.py
Normal file
@@ -0,0 +1,309 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
# Motto: Were It to Benefit My Country, I Would Lay Down My Life!
|
||||
# \file: /hnsw_degree_visit_plot_binned_academic.py
|
||||
# \brief: Generates a binned bar plot of HNSW node average per-query visit probability
|
||||
# per degree bin, styled for academic publications, with caching.
|
||||
# Author: raphael hao (Original script by user, styling and caching adapted by Gemini)
|
||||
|
||||
# %%
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import matplotlib.pyplot as plt
|
||||
import re
|
||||
from collections import Counter
|
||||
import os # For robust filepath manipulation
|
||||
import math # For calculating scaling factor
|
||||
import pickle # For caching data
|
||||
|
||||
# %%
|
||||
# --- Matplotlib parameters for academic paper style (from reference) ---
|
||||
plt.rcParams["font.family"] = "Helvetica"
|
||||
plt.rcParams["ytick.direction"] = "in"
|
||||
plt.rcParams["hatch.linewidth"] = 1.5
|
||||
plt.rcParams["font.weight"] = "bold"
|
||||
plt.rcParams["axes.labelweight"] = "bold"
|
||||
plt.rcParams["text.usetex"] = True # Use LaTeX for text rendering (if available)
|
||||
|
||||
# --- Define styles from reference ---
|
||||
edgecolors_ref = ["dimgrey", "#63B8B6", "tomato", "silver", "slategray"]
|
||||
|
||||
# %%
|
||||
# --- File Paths ---
|
||||
degree_file = '/opt/dlami/nvme/scaling_out/indices/rpj_wiki/facebook/contriever-msmarco/hnsw/degree_distribution.txt'
visit_log_file = './re.log'
output_image_file = './paper_plot/figures/hnsw_visit_count_per_degree_corrected.pdf'
# --- CACHE FILE PATH: Keep this consistent ---
CACHE_FILE_PATH = './binned_plot_data_cache.pkl'

# --- Configuration ---
# Set to True to bypass cache and force recomputation.
# Otherwise, delete CACHE_FILE_PATH manually to force recomputation.
FORCE_RECOMPUTE = False
NUMBER_OF_QUERIES = 1000.0  # Number of queries the visit_counts are based on

# Create directory for figures if it doesn't exist
output_dir = os.path.dirname(output_image_file)
if output_dir and not os.path.exists(output_dir):
    os.makedirs(output_dir)
    print(f"Created directory: {output_dir}")

# %%
# --- Attempt to load data from cache or compute ---
df_plot_data = None
bin_size_for_plot = None  # Will hold the bin_size associated with df_plot_data

if not FORCE_RECOMPUTE and os.path.exists(CACHE_FILE_PATH):
    try:
        with open(CACHE_FILE_PATH, 'rb') as f:
            cache_content = pickle.load(f)
        df_plot_data = cache_content['data']
        bin_size_for_plot = cache_content['bin_size']
        # Basic validation of cached data
        # Expecting 'average_visit_count_per_node_in_bin' (raw average over NUMBER_OF_QUERIES)
        if not isinstance(df_plot_data, pd.DataFrame) or \
                'degree_bin_label' not in df_plot_data.columns or \
                'average_visit_count_per_node_in_bin' not in df_plot_data.columns or \
                not isinstance(bin_size_for_plot, int):
            print("Cached data is not in the expected format or missing 'average_visit_count_per_node_in_bin'. Recomputing.")
            df_plot_data = None  # Invalidate to trigger recomputation
        else:
            print(f"Successfully loaded binned data from cache: {CACHE_FILE_PATH}")

        # --- Modify the label loaded from cache for display purposes ---
        # This modification only happens when data is loaded from cache and meets specific conditions.
        # Assumption: if the cached bin_size_for_plot is 5, the original label "0-4" actually
        # represents nodes with degree 1-4 (the graph is guaranteed to contain no 0-degree nodes).
        if df_plot_data is not None and 'degree_bin_label' in df_plot_data.columns and bin_size_for_plot == 5:
            # Check if the "0-4" label exists
            if '0-4' in df_plot_data['degree_bin_label'].values:
                # Use .loc to ensure the modification is applied to the original DataFrame
                df_plot_data.loc[df_plot_data['degree_bin_label'] == '0-4', 'degree_bin_label'] = '1-4'
                print("Modified degree_bin_label from '0-4' to '1-4' for display purposes.")
    except Exception as e:
        print(f"Error loading from cache: {e}. Recomputing.")
        df_plot_data = None  # Invalidate to trigger recomputation

if df_plot_data is None:
    print("Cache not found, invalid, or recompute forced. Computing data from scratch...")
    # --- 1. Read Degree Distribution File ---
    degrees_data = []
    try:
        with open(degree_file, 'r') as f:
            for i, line in enumerate(f):
                line_stripped = line.strip()
                if line_stripped:
                    degrees_data.append({'node_id': i, 'degree': int(line_stripped)})
    except FileNotFoundError:
        print(f"Error: Degree file '{degree_file}' not found. Using dummy data for degrees.")
        degrees_data = [{'node_id': i, 'degree': (i % 20) + 1} for i in range(200)]
        degrees_data.extend([{'node_id': 200 + i, 'degree': i} for i in range(58, 67)])  # For 60-64 bin
        degrees_data.extend([{'node_id': 300 + i, 'degree': (i % 5) + 1} for i in range(10)])  # Low degrees
        degrees_data.extend([{'node_id': 400 + i, 'degree': 80 + (i % 5)} for i in range(10)])  # High degrees

    if not degrees_data:
        print("Critical Error: No data loaded or generated for degrees. Exiting.")
        exit()
    df_degrees = pd.DataFrame(degrees_data)
    print(f"Successfully loaded/generated {len(df_degrees)} degree entries.")

    # --- 2. Read Visit Log File and Count Frequencies ---
    visit_counts = Counter()
    node_id_pattern = re.compile(r"Vis(i)?ted node: (\d+)")
    try:
        with open(visit_log_file, 'r') as f_log:
            for line_num, line in enumerate(f_log, 1):
                match = node_id_pattern.search(line)
                if match:
                    try:
                        node_id = int(match.group(2))
                        visit_counts[node_id] += 1  # Increment visit count for the node
                    except ValueError:
                        print(f"Warning: Non-integer node_id in log '{visit_log_file}' line {line_num}: {line.strip()}")
    except FileNotFoundError:
        print(f"Warning: Visit log file '{visit_log_file}' not found. Using dummy visit counts.")
        if not df_degrees.empty:
            for node_id_val in df_degrees['node_id'].sample(frac=0.9, random_state=1234):  # Seed for reproducibility
                degree_val = df_degrees[df_degrees['node_id'] == node_id_val]['degree'].iloc[0]
                # Generate visit counts to test different probability magnitudes
                if node_id_val % 23 == 0:  # Very low probability
                    lambda_val = 0.0005 * (100 / (max(1, degree_val) + 1))  # avg visits over 1k queries
                elif node_id_val % 11 == 0:  # Low probability
                    lambda_val = 0.05 * (100 / (max(1, degree_val) + 1))
                elif node_id_val % 5 == 0:  # Moderate probability
                    lambda_val = 2.5 * (100 / (max(1, degree_val) + 1))
                else:  # Higher probability (but still < 1000 visits for a single node, usually)
                    lambda_val = 50 * (100 / (max(1, degree_val) + 1))
                visit_counts[node_id_val] = np.random.poisson(lambda_val)
                if visit_counts[node_id_val] < 0:
                    visit_counts[node_id_val] = 0

    if not visit_counts:
        print("Warning: No visit data parsed/generated. Plot may show zero visits.")
        df_visits = pd.DataFrame(columns=['node_id', 'visit_count'])
    else:
        df_visits_list = [{'node_id': nid, 'visit_count': count} for nid, count in visit_counts.items()]
        df_visits = pd.DataFrame(df_visits_list)
        print(f"Parsed/generated {len(df_visits)} unique visited nodes, totaling {sum(visit_counts.values())} visits (simulated over {NUMBER_OF_QUERIES} queries).")

    # --- 3. Merge Degree Data with Visit Data ---
    df_merged = pd.merge(df_degrees, df_visits, on='node_id', how='left')
    df_merged['visit_count'] = df_merged['visit_count'].fillna(0).astype(float)  # visit_count is the total over NUMBER_OF_QUERIES
    print(f"Merged data contains {len(df_merged)} entries.")

    # --- 5. Binning Degrees and Calculating Average Visit Count per Node in Bin (over NUMBER_OF_QUERIES) ---
    current_bin_size = 5
    bin_size_for_plot = current_bin_size

    if not df_degrees.empty:
        print(f"\nBinning degrees into groups of {current_bin_size} for average visit count calculation...")

        df_merged_with_bins = df_merged.copy()
        df_merged_with_bins['degree_bin_start'] = (df_merged_with_bins['degree'] // current_bin_size) * current_bin_size

        df_binned_analysis = df_merged_with_bins.groupby('degree_bin_start').agg(
            total_visit_count_in_bin=('visit_count', 'sum'),
            node_count_in_bin=('node_id', 'nunique')
        ).reset_index()

        # This is the average number of times a node in this bin was visited over NUMBER_OF_QUERIES queries.
        # This value is what gets cached.
        df_binned_analysis['average_visit_count_per_node_in_bin'] = 0.0
        df_binned_analysis.loc[df_binned_analysis['node_count_in_bin'] > 0, 'average_visit_count_per_node_in_bin'] = \
            df_binned_analysis['total_visit_count_in_bin'] / df_binned_analysis['node_count_in_bin']

        df_binned_analysis['degree_bin_label'] = df_binned_analysis['degree_bin_start'].astype(str) + '-' + \
            (df_binned_analysis['degree_bin_start'] + current_bin_size - 1).astype(str)

        bin_to_drop_label = '60-64'
        original_length = len(df_binned_analysis)
        df_plot_data_intermediate = df_binned_analysis[df_binned_analysis['degree_bin_label'] != bin_to_drop_label].copy()
        if len(df_plot_data_intermediate) < original_length:
            print(f"\nManually dropped the bin: '{bin_to_drop_label}'")
        else:
            print(f"\nNote: Bin '{bin_to_drop_label}' not found for dropping or already removed.")

        df_plot_data = df_plot_data_intermediate

        print(f"\nBinned data (average visit count per node in bin over {NUMBER_OF_QUERIES} queries) for plotting prepared:")
        print(df_plot_data[['degree_bin_label', 'average_visit_count_per_node_in_bin']].head())

        if df_plot_data is not None and not df_plot_data.empty:
            try:
                with open(CACHE_FILE_PATH, 'wb') as f:
                    pickle.dump({'data': df_plot_data, 'bin_size': bin_size_for_plot}, f)
                print(f"Saved computed binned data to cache: {CACHE_FILE_PATH}")
            except Exception as e:
                print(f"Error saving data to cache: {e}")
        elif df_plot_data is None or df_plot_data.empty:
            print("Computed data for binned plot is empty, not saving to cache.")
    else:
        print("Degree data (df_degrees) is empty. Cannot perform binning.")
        df_plot_data = pd.DataFrame()
        bin_size_for_plot = current_bin_size
# %%
# --- 6. Plotting (Binned Bar Chart - Academic Style) ---

if df_plot_data is not None and not df_plot_data.empty and 'average_visit_count_per_node_in_bin' in df_plot_data.columns:
    base_name, ext = os.path.splitext(output_image_file)
    # --- OUTPUT PDF FILE NAME: Keep this consistent ---
    binned_output_image_file = base_name + ext

    fig, ax = plt.subplots(figsize=(6, 2.5))  # Adjusted figure size

    df_plot_data_plotting = df_plot_data.copy()
    # Calculate per-query probability: (avg visits over N queries) / N
    df_plot_data_plotting['per_query_visit_probability'] = \
        df_plot_data_plotting['average_visit_count_per_node_in_bin'] / NUMBER_OF_QUERIES

    max_probability = df_plot_data_plotting['per_query_visit_probability'].max()

    y_axis_values_to_plot = df_plot_data_plotting['per_query_visit_probability']
    y_axis_label = r"Per-Query Node Visit Probability in Bin"  # Base label

    apply_scaling_to_label_and_values = False  # Initialize flag
    exponent_for_label_display = 0  # Initialize exponent

    if pd.notna(max_probability) and max_probability > 0:
        potential_exponent = math.floor(math.log10(max_probability))

        if potential_exponent <= -4 or potential_exponent >= 0:
            apply_scaling_to_label_and_values = True
            exponent_for_label_display = potential_exponent
            # No specific adjustment for potential_exponent >= 0 here; it's handled by the general logic.

        if apply_scaling_to_label_and_values:
            y_axis_label = rf"Visit Probability ($\times 10^{{{exponent_for_label_display}}}$)"
            y_axis_values_to_plot = df_plot_data_plotting['per_query_visit_probability'] / (10**exponent_for_label_display)
            print(f"Plotting with Max per-query probability: {max_probability:.2e}, Exponent for label: {exponent_for_label_display}. Y-axis values scaled for plot.")
        else:
            print(f"Plotting with Max per-query probability: {max_probability:.2e}. Plotting direct probabilities without label scaling (exponent {potential_exponent} is within no-scale range [-3, -1]).")

    elif pd.notna(max_probability) and max_probability == 0:
        print("Max per-query probability is 0. Plotting direct probabilities (all zeros).")
    else:
        print(f"Max per-query probability is NaN or invalid ({max_probability}). Plotting direct probabilities without scaling if possible.")

    ax.bar(
        df_plot_data_plotting['degree_bin_label'],
        y_axis_values_to_plot,
        color='white',
        edgecolor=edgecolors_ref[0],
        linewidth=1.5,
        width=0.8
    )

    ax.set_xlabel('Node Degree', fontsize=10.5, labelpad=6)
    # MODIFIED LINE: Added labelpad to move the y-axis label to the left
    ax.set_ylabel(y_axis_label, fontsize=10.5, labelpad=10)

    ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda x, pos: f"{x:.0f}%"))

    num_bins = len(df_plot_data_plotting)
    if num_bins > 12:
        ax.set_xticks(ax.get_xticks())
        ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha="right", fontsize=9)
    elif num_bins > 8:
        ax.tick_params(axis='x', labelsize=9)
    else:
        ax.tick_params(axis='x', labelsize=10)

    ax.tick_params(axis='y', labelsize=10)

    padding_factor = 0.05
    current_max_y_on_axis = y_axis_values_to_plot.max()

    upper_y_limit = 0.1  # Default small upper limit
    if pd.notna(current_max_y_on_axis):
        if current_max_y_on_axis > 0:
            # Adjust minimum visible range based on whether scaling was applied and the exponent
            min_meaningful_limit = 0.01
            if apply_scaling_to_label_and_values and exponent_for_label_display >= 0:  # Numbers on axis are smaller due to positive exponent scaling
                min_meaningful_limit = 0.1  # If original numbers were e.g. 2500 (2.5 x 10^3), the scaled axis is 2.5, so 0.1 is fine
            elif not apply_scaling_to_label_and_values and pd.notna(max_probability) and max_probability >= 1:  # Direct large probabilities
                min_meaningful_limit = 1  # If max prob is 2.5 (250%), the axis value is 2.5 and needs a larger base limit

            upper_y_limit = max(min_meaningful_limit, current_max_y_on_axis * (1 + padding_factor))

        else:  # current_max_y_on_axis is 0
            upper_y_limit = 0.1
        ax.set_ylim(0, upper_y_limit)
    else:
        ax.set_ylim(0, 1.0)  # Default for empty or NaN data

    plt.tight_layout()
    plt.savefig(binned_output_image_file, bbox_inches="tight", dpi=300)
    print(f"Binned bar chart saved to {binned_output_image_file}")
    plt.show()
    plt.close(fig)
else:
    if df_plot_data is None:
        print("Data for plotting (df_plot_data) is None. Skipping plot generation.")
    elif df_plot_data.empty:
        print("Data for plotting (df_plot_data) is empty. Skipping plot generation.")
    elif 'average_visit_count_per_node_in_bin' not in df_plot_data.columns:
        print("Essential column 'average_visit_count_per_node_in_bin' is missing in df_plot_data. Skipping plot generation.")

# %%
print("Script finished.")
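The log parser in this script assumes the HNSW search log at `./re.log` contains one line per traversal step of the form `Visited node: <id>` (the regex also tolerates the misspelling "Visted"). For a dry run without a real search log, a minimal hedged sketch that writes a synthetic log in that format (the node IDs are random placeholders, not real traversal data):

```python
import random

# Write a tiny synthetic visit log in the format the regex above expects.
random.seed(0)
with open("./re.log", "w") as f_out:
    for _ in range(10_000):
        f_out.write(f"Visited node: {random.randint(0, 199)}\n")
```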
7  research/paper_plot/b.md  Normal file
@@ -0,0 +1,7 @@
In this paper, we present LiteANN, a storage-efficient approximate nearest neighbor (ANN) search index optimized for resource-constrained personal devices. LiteANN combines a compact graph-based structure with an efficient on-the-fly recomputation strategy to enable fast and accurate retrieval with minimal storage overhead. Our evaluation shows that LiteANN reduces index size to under 5% of the original raw data – up to 50× smaller than standard indexes – while achieving 90% top-3 recall in under 2 seconds on real-world question-answering benchmarks.
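A rough, hedged illustration of the idea described in the abstract (not the paper's actual implementation): a graph index that stores only the adjacency structure and raw text, and recomputes embeddings on demand during best-first search. Here `embed`, `neighbors`, `texts`, and `entry_point` are hypothetical stand-ins for the encoder, the compact adjacency lists, the corpus, and a start node:

```python
import heapq
import numpy as np

def search(query_vec, embed, neighbors, texts, entry_point, ef=50, k=3):
    """Best-first graph search that recomputes node embeddings on the fly.

    Only the graph structure and the raw texts are persisted; per-node vectors
    are computed when a node is first visited instead of being stored.
    """
    def dist(node_id):
        # Negative inner product, so smaller means more similar.
        return -float(np.dot(query_vec, embed(texts[node_id])))

    visited = {entry_point}
    d0 = dist(entry_point)
    candidates = [(d0, entry_point)]   # min-heap: closest frontier node first
    best = [(-d0, entry_point)]        # max-heap (negated): worst kept result on top

    while candidates:
        d, node = heapq.heappop(candidates)
        if len(best) >= ef and d > -best[0][0]:
            break  # the frontier can no longer improve the kept results
        for nb in neighbors[node]:
            if nb in visited:
                continue
            visited.add(nb)
            d_nb = dist(nb)  # embedding recomputed on demand, never stored
            if len(best) < ef or d_nb < -best[0][0]:
                heapq.heappush(candidates, (d_nb, nb))
                heapq.heappush(best, (-d_nb, nb))
                if len(best) > ef:
                    heapq.heappop(best)
    return [n for _, n in sorted((-negd, n) for negd, n in best)[:k]]
```

Because only `neighbors` and `texts` are persisted, the index footprint is dominated by the graph structure rather than by stored vectors, which is the kind of storage saving the abstract refers to.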
81  research/paper_plot/cache_degree_data.py  Normal file
@@ -0,0 +1,81 @@
import numpy as np
import os

# --- Configuration for Data Paths and Labels (Mirrors plotting script for consistency) ---
BIG_GRAPH_PATHS = [
    "/opt/dlami/nvme/scaling_out/indices/rpj_wiki/facebook/contriever-msmarco/hnsw/",
    "/opt/dlami/nvme/scaling_out/embeddings/facebook/contriever-msmarco/rpj_wiki/1-shards/indices/99_4_degree_based_hnsw_IP_M32_efC256/",
    "/opt/dlami/nvme/scaling_out/embeddings/facebook/contriever-msmarco/rpj_wiki/1-shards/indices/d9_hnsw_IP_M8_efC128/",
    "/opt/dlami/nvme/scaling_out/embeddings/facebook/contriever-msmarco/rpj_wiki/1-shards/indices/half_edges_IP_M32_efC128/"
]
STATS_FILE_NAME = "degree_distribution.txt"
BIG_GRAPH_LABELS = [  # These will be used as keys in the cached file
    "HNSW-Base",
    "DegreeGuide",
    "HNSW-D9",
    "RandCut",
]
# Average degrees are static and can be directly used in the plotting script or also cached.
# For simplicity here, we'll focus on caching the dynamic degree arrays.
# BIG_GRAPH_AVG_DEG = [18, 9, 9, 9]

# --- Cache File Configuration ---
DATA_CACHE_DIR = "./paper_plot/data/"
CACHE_FILE_NAME = "big_graph_degree_data.npz"  # Using .npz for multiple arrays

def create_degree_data_cache():
    """
    Reads degree distribution data from specified text files and saves it
    into a compressed NumPy (.npz) cache file.
    """
    os.makedirs(DATA_CACHE_DIR, exist_ok=True)
    cache_file_path = os.path.join(DATA_CACHE_DIR, CACHE_FILE_NAME)

    cached_data = {}
    print(f"Starting data caching process for {len(BIG_GRAPH_PATHS)} graph types...")

    for i, base_path in enumerate(BIG_GRAPH_PATHS):
        method_label = BIG_GRAPH_LABELS[i]
        degree_file_path = os.path.join(base_path, STATS_FILE_NAME)

        print(f"Processing: {method_label} from {degree_file_path}")

        try:
            # Load degrees as integers
            degrees = np.loadtxt(degree_file_path, dtype=int)

            if degrees.size == 0:
                print(f" [WARN] Degree file is empty: {degree_file_path}. Storing as empty array for {method_label}.")
                # Store an empty array or handle as needed. For npz, an empty array is fine.
                cached_data[method_label] = np.array([], dtype=int)
            else:
                # Store the loaded degrees array with the method label as the key
                cached_data[method_label] = degrees
                print(f" [INFO] Loaded {len(degrees)} degrees for {method_label}. Max degree: {np.max(degrees) if degrees.size > 0 else 'N/A'}")

        except FileNotFoundError:
            print(f" [ERROR] Degree file not found: {degree_file_path}. Skipping {method_label}.")
            # Optionally store a placeholder or skip. For robustness, store None or an empty array.
            # Storing None might require special handling when loading. An empty array is safer for np.load.
            cached_data[method_label] = np.array([], dtype=int)  # Store empty array if file not found
        except Exception as e:
            print(f" [ERROR] An error occurred loading {degree_file_path} for {method_label}: {e}")
            cached_data[method_label] = np.array([], dtype=int)  # Store empty array on other errors

    if not cached_data:
        print("[ERROR] No data was successfully processed or loaded. Cache file will not be created.")
        return

    try:
        # Save all collected degree arrays into a single .npz file.
        # Using savez_compressed for potentially smaller file size.
        np.savez_compressed(cache_file_path, **cached_data)
        print(f"\n[SUCCESS] Degree distribution data successfully cached to: {os.path.abspath(cache_file_path)}")
        print("Cached arrays (keys):", list(cached_data.keys()))
    except Exception as e:
        print(f"\n[ERROR] Failed to save data to cache file {cache_file_path}: {e}")

if __name__ == "__main__":
    print("--- Degree Distribution Data Caching Script ---")
    create_degree_data_cache()
    print("--- Caching script finished. ---")
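A quick, hedged usage sketch for the cache this script writes (path and key names as defined above; nothing here is part of the script itself):

```python
import numpy as np

# Load the cached degree arrays written by create_degree_data_cache().
data = np.load("./paper_plot/data/big_graph_degree_data.npz")
for label in data.files:  # e.g. "HNSW-Base", "DegreeGuide", "HNSW-D9", "RandCut"
    degrees = data[label]
    if degrees.size:
        print(f"{label}: {degrees.size} nodes, mean degree {degrees.mean():.2f}, max {degrees.max()}")
    else:
        print(f"{label}: no degree data cached")
```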
4  research/paper_plot/data/acc.csv  Normal file
@@ -0,0 +1,4 @@
Model,NQ Exact Match,NQ F1,TriviaQA Exact Match,TriviaQA F1,GPQA Exact Match,GPQA F1,HotpotQA Exact Match,HotpotQA F1
BM25,0.192,0.277,0.406,0.474,0.020089,0.04524,0.162,0.239
PQ 5,0.2075,0.291,0.422,0.495,0.0201,0.0445,0.148,0.219
Ours,0.265,0.361,0.533,0.604,0.02008,0.0452,0.182,0.2729
3  research/paper_plot/data/big_graph_degree_data.npz  Normal file
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1296720e79196bbdf38f051043c1b054667803726a24036c0b6a87cedb204ea5
size 227482438
21  research/paper_plot/data/branches.csv  Normal file
@@ -0,0 +1,21 @@
2,1,512,1024,0.541,0.326,1.659509202
2,2,512,1024,0.979,0.621,1.576489533
2,4,512,1024,1.846,0.977,1.889457523
2,8,512,1024,3.575,1.943,1.83993824
2,16,512,1024,7.035,3.733,1.884543263
2,32,512,1024,15.655,8.517,1.838088529
2,64,512,1024,32.772,17.43,1.88020654
4,1,512,1024,2.675,1.38,1.938405797
4,2,512,1024,5.397,2.339,2.307396323
4,4,512,1024,10.672,4.944,2.158576052
4,8,512,1024,21.061,9.266,2.272933305
4,16,512,1024,46.332,18.334,2.527108105
4,32,512,1024,99.607,36.156,2.754923111
4,64,512,1024,186.348,72.356,2.575432583
8,1,512,1024,7.325,4.087,1.792268167
8,2,512,1024,14.109,7.491,1.883460152
8,4,512,1024,28.499,14.013,2.033754371
8,8,512,1024,65.222,27.453,2.375769497
8,16,512,1024,146.294,52.55,2.783901047
8,32,512,1024,277.099,103.61,2.674442621
8,64,512,1024,512.979,208.36,2.461984066
9  research/paper_plot/data/latency_ablation.csv  Normal file
@@ -0,0 +1,9 @@
Dataset,Metric,Original,original + batch,original + two_level,original + two_level + batch
NQ,Latency,6.9,5.8,4.2,3.7
NQ,SpeedUp,1,1.18965517,1.64285714,1.86486486
TriviaQA,Latency,17.054,14.542,12.046,10.83
TriviaQA,SpeedUp,1,1.17274103,1.41573967,1.57469990
GPQA,Latency,9.164,7.639,6.798,5.77
GPQA,SpeedUp,1,1.19963346,1.34804354,1.58821490
HotpotQA,Latency,60.279,39.827,50.664,29.868
HotpotQA,SpeedUp,1,1.51352098,1.18977972,2.01817999
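The SpeedUp rows are simply the Original latency divided by each variant's latency (e.g. 6.9 / 5.8 ≈ 1.19 for NQ with batching). A small hedged check of that relationship, assuming the file is read from the data directory used elsewhere in this folder:

```python
import pandas as pd

df = pd.read_csv("./paper_plot/data/latency_ablation.csv")
lat = df[df["Metric"] == "Latency"].set_index("Dataset")
spd = df[df["Metric"] == "SpeedUp"].set_index("Dataset")
variant_cols = [c for c in df.columns if c not in ("Dataset", "Metric")]
# SpeedUp ≈ Original latency / variant latency for every variant column;
# any differences should be at rounding level.
recomputed = lat[variant_cols].rdiv(lat["Original"], axis=0)
print((recomputed - spd[variant_cols]).abs().max())
```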
25  research/paper_plot/data/main_latency.csv  Normal file
@@ -0,0 +1,25 @@
Dataset,Hardware,Recall_target,HNSW,IVF,DiskANN,IVF-Disk,IVF-Recompute,Our,BM25,LLM_Gen_Time_1B,LLM_Gen_Time_3B,LLM_Gen_Time_7B
NQ,A10,85%,0.046,1.656,0.017,2.996,482.53,3.323,0.021,0.085,0.217,0.472
NQ,A10,90%,0.051,2.552,0.028,3.437,769.04,4.616,0,0.085,0.217,0.472
NQ,A10,95%,0.055,5.163,0.070,5.602,1436.26,19.494,0,0.085,0.217,0.472
NQ,MAC,85%,0,0,0.152,2.199,1535.10,7.971,0.033,0.316,0.717,1.468
NQ,MAC,90%,0,0,0.37,2.936,2446.60,13.843,0,0.316,0.717,1.468
NQ,MAC,95%,0,0,1.207,4.191,4569.29,44.363,0,0.316,0.717,1.468
TriviaQA,A10,85%,0.042,1.772,0.032,2.464,560.5,3.752,0.033,0.139,0.156,0.315
TriviaQA,A10,90%,0.043,3.541,0.057,3.651,997.81,5.777,0,0.139,0.156,0.315
TriviaQA,A10,95%,0.053,7.168,0.090,5.458,2005.33,20.944,0,0.139,0.156,0.315
TriviaQA,MAC,85%,0,0,0.481,1.875,1783.14787,8.889,0.036,0.325,0.692,1.415
TriviaQA,MAC,90%,0,0,0.984,2.639,3174.410301,17.145,0,0.325,0.692,1.415
TriviaQA,MAC,95%,0,0,1.578,3.884,6379.712245,47.909,0,0.325,0.692,1.415
GPQA,A10,85%,0.041,0.134,0.024,0.048,40.16,1.897,0.137,0.443,0.396,0.651
GPQA,A10,90%,0.042,0.174,0.034,0.06,54.71,1.733,0,0.443,0.396,0.651
GPQA,A10,95%,0.045,0.292,0.051,0.11,97.67,4.033,0,0.443,0.396,0.651
GPQA,MAC,85%,0,0,0.144,0.087,127.7707505,4.762,0.100,0.37,0.813,1.676
GPQA,MAC,90%,0,0,0.288,0.108,174.0647409,5.223,0,0.37,0.813,1.676
GPQA,MAC,95%,0,0,0.497,0.132,310.7380142,9.715,0,0.37,0.813,1.676
HotpotQA,A10,85%,0.044,2.519,0.054,4.048,724.26,10.358,0.70,0.144,0.196,0.420
HotpotQA,A10,90%,0.049,3.867,0.109,5.045,1173.67,15.515,0,0.144,0.196,0.420
HotpotQA,A10,95%,0.07,10.928,0.412,8.659,3079.57,61.757,0,0.144,0.196,0.420
HotpotQA,MAC,85%,0,0,0.974,2.844,2304.125187,23.636,0.052,0.144,0.196,0.420
HotpotQA,MAC,90%,0,0,1.913,3.542,3415.736201,44.803,0,0.144,0.196,0.420
HotpotQA,MAC,95%,0,0,5.783,6.764,9797.244043,140.62,0,0.144,0.196,0.420
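A minimal, hedged example of slicing this table for one hardware/recall setting (column names as in the header above; the path assumes the repository layout shown in this diff):

```python
import pandas as pd

df = pd.read_csv("./paper_plot/data/main_latency.csv")
a10_90 = df[(df["Hardware"] == "A10") & (df["Recall_target"] == "90%")]
# Compare our latency against the IVF-Recompute baseline per dataset.
print(a10_90[["Dataset", "IVF-Recompute", "Our"]].assign(
    speedup=a10_90["IVF-Recompute"] / a10_90["Our"]
))
```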
25  research/paper_plot/data/main_latency_small.csv  Normal file
@@ -0,0 +1,25 @@
Dataset,Hardware,Recall_target,HNSW,IVF,DiskANN,IVF-Disk,IVF-Recompute,Our,
NQ,A10,85%,0.046,1.656,0.017,2.996,482.53,4.243,
NQ,A10,90%,0.051,2.552,0.028,3.437,769.04,8.136,
NQ,A10,95%,0.055,5.163,0.070,5.602,1436.26,27.275,
NQ,MAC,85%,0,0,0.152,2.199,1535.10,10.672,
NQ,MAC,90%,0,0,0.37,2.936,2446.60,19.941,
NQ,MAC,95%,0,0,1.207,4.191,4569.29,61.383,
TriviaQA,A10,85%,0.042,1.772,0.032,2.464,560.5,5.612,
TriviaQA,A10,90%,0.043,3.541,0.057,3.651,997.81,10.737,
TriviaQA,A10,95%,0.053,7.168,0.090,5.458,2005.33,36.387,
TriviaQA,MAC,85%,0,0,0.481,1.875,1783.14787,12.825,
TriviaQA,MAC,90%,0,0,0.984,2.639,3174.410301,24.977,
TriviaQA,MAC,95%,0,0,1.578,3.884,6379.712245,85.734,
GPQA,A10,85%,0.041,0.134,0.024,0.048,40.16,2.269,
GPQA,A10,90%,0.042,0.174,0.034,0.06,54.71,3.200,
GPQA,A10,95%,0.045,0.292,0.051,0.11,97.67,7.445,
GPQA,MAC,85%,0,0,0.144,0.087,127.7707505,6.123,
GPQA,MAC,90%,0,0,0.288,0.108,174.0647409,8.507,
GPQA,MAC,95%,0,0,0.497,0.132,310.7380142,19.577,
HotpotQA,A10,85%,0.044,2.519,0.054,4.048,724.26,14.713,
HotpotQA,A10,90%,0.049,3.867,0.109,5.045,1173.67,33.561,
HotpotQA,A10,95%,0.07,10.928,0.412,8.659,3079.57,68.626,
HotpotQA,MAC,85%,0,0,0.974,2.844,2304.125187,34.783,
HotpotQA,MAC,90%,0,0,1.913,3.542,3415.736201,53.004,
HotpotQA,MAC,95%,0,0,5.783,6.764,9797.244043,95.413,
3  research/paper_plot/data/ram_storage.csv  Normal file
@@ -0,0 +1,3 @@
Hardware,HNSW,IVF,DiskANN,IVF-Disk,IVF-Recompute,Our,BM25
RAM,190,171,10,0,0,0,0
Storage,185.4,171,240,171,0.5,5,59
12  research/paper_plot/data/swithc_e2e.csv  Normal file
@@ -0,0 +1,12 @@
Torch,8,55.592
Torch,16,75.439
Torch,32,110.025
Torch,64,186.496
Tutel,8,56.718
Tutel,16,82.121
Tutel,32,125.070
Tutel,64,216.191
BRT,8,56.725
BRT,16,79.291
BRT,32,93.180
BRT,64,118.923
6  research/paper_plot/data/vary_cache.csv  Normal file
@@ -0,0 +1,6 @@
Disk cache size,0,2.5%(180G*2.5%),5%,8%,10%
Latency,,,,,
NQ,4.616,4.133,3.826,3.511,3.323
TriviaQA,5.777,4.979,4.553,4.141,3.916
GPQA,1.733,1.593,1.468,1.336,1.259
Hotpot,15.515,13.479,12.383,11.216,10.606
151  research/paper_plot/disk_cache.py  Normal file
@@ -0,0 +1,151 @@
import matplotlib
from matplotlib.axes import Axes
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.lines import Line2D

# plt.rcParams["font.family"] = "Helvetica"
plt.rcParams["ytick.direction"] = "in"
plt.rcParams["hatch.linewidth"] = 1
plt.rcParams["font.weight"] = "bold"
plt.rcParams["axes.labelweight"] = "bold"
plt.rcParams["text.usetex"] = True
plt.rcParams["font.family"] = "sans-serif"  # Use generic sans-serif family
plt.rcParams['text.latex.preamble'] = r"""
\usepackage{helvet} % Use Helvetica font for text
\usepackage{sfmath} % Use sans-serif font for math
\renewcommand{\familydefault}{\sfdefault} % Set sans-serif as default text font
\usepackage[T1]{fontenc} % Recommended for font encoding
"""
# plt.rcParams['mathtext.fontset'] = 'dejavusans'
SAVE_PTH = "./paper_plot/figures"
font_size = 16

# New data in dictionary format
datasets = ["NQ", "TriviaQA", "GPQA", "Hotpot"]

cache_ratios = ["4.2G\n (0\%)", "8.7G\n (2.5\%)", "13.2G\n (5\%)", "18.6G\n (8\%)", "22.2G\n (10\%)"]
latency_data = {
    "NQ": [4.616, 4.133, 3.826, 3.511, 3.323],
    "TriviaQA": [5.777, 4.979, 4.553, 4.141, 3.916],
    "GPQA": [1.733, 1.593, 1.468, 1.336, 1.259],
    "Hotpot": [15.515, 13.479, 12.383, 11.216, 10.606],
}
cache_hit_counts = {
    "NQ": [0, 14.81, 23.36, 31.99, 36.73],
    "TriviaQA": [0, 18.55, 27.99, 37.06, 41.86],
    "GPQA": [0, 10.99, 20.31, 29.71, 35.01],
    "Hotpot": [0, 17.47, 26.91, 36.2, 41.06]
}

# Create the figure with 4 subplots in a 2x2 grid
fig, axes_grid = plt.subplots(2, 2, figsize=(7, 6))
axes = axes_grid.flatten()  # Flatten the 2x2 grid to a 1D array

# Bar style settings
width = 0.7
x = np.arange(len(cache_ratios))

# Define hatch patterns for different cache ratios
hatch_patterns = ['//', '//', '//', '//', '//']

# Find the max cache hit value across all datasets for a unified y-axis
all_hit_counts = []
for dataset in datasets:
    all_hit_counts.extend(cache_hit_counts[dataset])
max_unified_hit = max(all_hit_counts) * 1.13

for i, dataset in enumerate(datasets):
    latencies = latency_data[dataset]
    hit_counts = cache_hit_counts[dataset]

    for j, val in enumerate(latencies):
        container = axes[i].bar(
            x[j],
            val,
            width=width,
            color="white",
            edgecolor="black",
            linewidth=1.0,
            zorder=10,
        )
        axes[i].bar_label(
            container,
            [f"{val:.2f}"],
            fontsize=10,
            zorder=200,
            fontweight="bold",
        )

    axes[i].set_title(dataset, fontsize=font_size)
    axes[i].set_xticks(x)
    axes[i].set_xticklabels(cache_ratios, fontsize=12, rotation=0, ha='center', fontweight="bold")

    max_val_ratios = [1.35, 1.65, 1.45, 1.75]
    max_val = max(latencies) * max_val_ratios[i]
    axes[i].set_ylim(0, max_val)
    axes[i].tick_params(axis='y', labelsize=12)

    if i % 2 == 0:
        axes[i].set_ylabel("Latency (s)", fontsize=font_size)
    axes[i].yaxis.set_major_formatter(matplotlib.ticker.FormatStrFormatter('%.1f'))

    ax2: Axes = axes[i].twinx()
    ax2.plot(x, hit_counts,
             linestyle='--',
             marker='o',
             markersize=6,
             linewidth=1.5,
             color='k',
             markerfacecolor='none',
             zorder=20)

    ax2.set_ylim(0, max_unified_hit)
    ax2.tick_params(axis='y', labelsize=12)
    if i % 2 == 1:
        ax2.set_ylabel(r"Cache Hit (\%)", fontsize=font_size)

    for j, val in enumerate(hit_counts):
        if val > 0:
            ax2.annotate(f"{val:.1f}%",
                         (x[j], val),
                         textcoords="offset points",
                         xytext=(0, 5),
                         ha='center',
                         va='bottom',
                         fontsize=10,
                         fontweight='bold')

# Create legend entries for both plot types
bar_patch = mpatches.Patch(facecolor='white', edgecolor='black', label='Latency')
line_patch = Line2D([0], [0], color='black', linestyle='--', label='Cache Hit Rate')

# --- MODIFICATION FOR LEGEND AT THE TOP ---
fig.legend(handles=[bar_patch, line_patch],
           loc='upper center',  # Position the legend at the upper center
           bbox_to_anchor=(0.5, 0.995),  # Anchor point: 0.5 is the horizontal center of the figure,
                                         # 0.995 places the legend just below the top edge
           ncol=3,
           fontsize=font_size - 2)
# --- END OF MODIFICATION ---

# Set common x-axis label - you might want to add this back if needed
# fig.text(0.5, 0.02, "Disk Cache Size", ha='center', fontsize=font_size, fontweight='bold')  # Adjusted y for a potential bottom label

# --- MODIFICATION FOR TIGHT LAYOUT ---
# Adjust rect to make space for the legend at the top.
# (left, bottom, right, top_for_subplots)
# We want subplots to occupy space from y=0 up to y=0.93 (or similar),
# leaving the top portion (0.93 to 1.0) for the legend.
plt.tight_layout(rect=(0, 0, 1, 0.93))  # Ensure subplots are below the legend
# --- END OF MODIFICATION ---

# Create directory if it doesn't exist (optional, good practice)
import os
if not os.path.exists(SAVE_PTH):
    os.makedirs(SAVE_PTH)

plt.savefig(f"{SAVE_PTH}/disk_cache_latency.pdf", dpi=300)  # Changed filename slightly for testing
print(f"Save to {SAVE_PTH}/disk_cache_latency.pdf")
# plt.show()  # Optional: to display the plot
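The hard-coded latency_data above mirrors research/paper_plot/data/vary_cache.csv. A hedged sketch of reading those numbers from the CSV instead of inlining them (assuming the row layout shown earlier, with the first column naming the dataset):

```python
import pandas as pd

# The CSV has a "Disk cache size" header row and an empty "Latency" spacer row,
# so read it without treating any row as a header and keep only the dataset rows.
raw = pd.read_csv("./paper_plot/data/vary_cache.csv", header=None)
rows = raw[raw[0].isin(["NQ", "TriviaQA", "GPQA", "Hotpot"])]
latency_from_csv = {
    row.iloc[0]: [float(v) for v in row.iloc[1:]] for _, row in rows.iterrows()
}
print(latency_from_csv["NQ"])  # expected: [4.616, 4.133, 3.826, 3.511, 3.323]
```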
New binary figure files (contents not shown in the diff):
BIN  research/paper_plot/figures/H_hnsw_performance_comparison.pdf
BIN  research/paper_plot/figures/H_hnsw_performance_comparison.png  (130 KiB)
BIN  research/paper_plot/figures/H_hnsw_recall_comparison.pdf
BIN  research/paper_plot/figures/H_hnsw_recall_comparison.png  (100 KiB)
BIN  research/paper_plot/figures/accuracy_comparison.pdf
BIN  research/paper_plot/figures/degree_distribution.pdf
BIN  research/paper_plot/figures/degree_distribution_small.pdf
BIN  research/paper_plot/figures/disk_cache_latency.pdf
BIN  research/paper_plot/figures/figure15.pdf
BIN  research/paper_plot/figures/gpu_throughput_vs_batch_size.pdf
BIN  (image file, name not shown in the diff; 41 KiB)
BIN  research/paper_plot/figures/latency_speedup.pdf
BIN  research/paper_plot/figures/main_exp_fig_1.pdf
BIN  research/paper_plot/figures/main_exp_fig_2.pdf
BIN  research/paper_plot/figures/plot1_em_f1.pdf
BIN  research/paper_plot/figures/plot2_latency.pdf
BIN  research/paper_plot/figures/ram_storage_double_column.pdf
BIN  research/paper_plot/figures/sparse_a2a_branches.pdf
BIN  research/paper_plot/figures/speed_A10_revised.pdf
BIN  research/paper_plot/figures/speed_MAC_revised.pdf
107  research/paper_plot/gpu_under.py  Normal file
@@ -0,0 +1,107 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Motto: Were It to Benefit My Country, I Would Lay Down My Life!
# \file: /gpu_utilization_plot.py
# \brief: Plots GPU throughput vs. batch size to show utilization, with an equally spaced x-axis.
# Author: AI Assistant

import numpy as np
import pandas as pd  # Using pandas for data structuring, similar to the example
from matplotlib import pyplot as plt

# Apply styling similar to the example script
plt.rcParams["font.family"] = "Helvetica"
plt.rcParams["ytick.direction"] = "in"
plt.rcParams["xtick.direction"] = "in"
# plt.rcParams["hatch.linewidth"] = 1.5  # Not used for line plots
plt.rcParams["font.weight"] = "bold"
plt.rcParams["axes.labelweight"] = "bold"
plt.rcParams["text.usetex"] = True  # Enables LaTeX for text rendering

# New benchmark data (4th set)
data = {
    'batch_size': [1, 4, 8, 10, 16, 20, 32, 40, 64, 128, 256],
    'avg_time_s': [
        0.0031, 0.0057, 0.0100, 0.0114, 0.0186, 0.0234,
        0.0359, 0.0422, 0.0626, 0.1259, 0.2454,
    ],
    'throughput_seq_s': [
        318.10, 696.77, 798.95, 874.70, 859.58, 855.19,
        890.80, 946.93, 1022.75, 1017.03, 1043.17,
    ]
}
benchmark_df = pd.DataFrame(data)

# Create the plot
# Increased width slightly to fit more x-axis labels
fig, ax = plt.subplots()
fig.set_size_inches(8, 5)

# Generate equally spaced x-coordinates (indices)
x_indices = np.arange(len(benchmark_df))

# Plot throughput vs. batch size (using indices for the x-axis)
ax.plot(
    x_indices,  # Use equally spaced indices for plotting
    benchmark_df['throughput_seq_s'],
    marker='o',  # Add markers to data points
    linestyle='-',
    color="#63B8B6",  # A color inspired by the example's 'edgecolors'
    linewidth=2,
    markersize=6,
    # label="Model Throughput"  # Label for the legend if needed; the legend is not shown by default
)

# Set labels for the axes
ax.set_xlabel("Batch Size", fontsize=14)
ax.set_ylabel("Throughput (sequences/second)", fontsize=14)

# Customize the y-axis for the data range:
# Start Y at 200 to trim the empty space below the smallest throughput value.
y_min_val = 200
# Round up y_max_val to the nearest 100, as max throughput > 1000
y_max_val = np.ceil(benchmark_df['throughput_seq_s'].max() / 100) * 100
ax.set_ylim((y_min_val, y_max_val))
# Set y-ticks every 100 units, ensuring the top tick is included.
ax.set_yticks(np.arange(y_min_val, y_max_val + 1, 100))

# Customize the x-axis for equally spaced ticks:
# Set tick positions to the indices
ax.set_xticks(x_indices)
# Set tick labels to the actual batch_size values
ax.set_xticklabels(benchmark_df['batch_size'])
ax.tick_params(axis='x', rotation=45, labelsize=10)  # Rotate x-axis labels, fontsize 10
ax.tick_params(axis='y', labelsize=12)

# Add a light grid for better readability, common in academic plots
ax.grid(True, linestyle=':', linewidth=0.5, color='grey', alpha=0.7, zorder=0)

# No title (as requested)
# ax.set_title("GPU Throughput vs. Batch Size", fontsize=16)  # Title would go here

# Optional: add a legend if you have multiple lines or want to label the single line
# ax.legend(
#     loc="center right",  # Location might need adjustment due to data shape
#     edgecolor="black",
#     facecolor="white",
#     framealpha=1.0,
#     shadow=False,
#     fancybox=False,
#     prop={"weight": "bold", "size": 10}
# ).set_zorder(100)

# Adjust layout to prevent labels from being cut off
plt.tight_layout()

# Save the figure
output_filename = "./paper_plot/figures/gpu_throughput_vs_batch_size_equispaced.pdf"
plt.savefig(output_filename, bbox_inches="tight", dpi=300)
print(f"Plot saved to {output_filename}")

# Display the plot (optional, depending on environment)
plt.show()

# %%
# This is just to mimic the '%%' cell structure from the example.
# No actual code needed here for this script.
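The throughput column in the hard-coded table above is consistent with batch_size / avg_time_s (for example, 64 / 0.0626 ≈ 1022 seq/s); the small discrepancies of roughly 1-2% presumably come from rounding of avg_time_s. A quick check:

```python
import numpy as np

batch_size = np.array([1, 4, 8, 10, 16, 20, 32, 40, 64, 128, 256])
avg_time_s = np.array([0.0031, 0.0057, 0.0100, 0.0114, 0.0186, 0.0234,
                       0.0359, 0.0422, 0.0626, 0.1259, 0.2454])
reported = np.array([318.10, 696.77, 798.95, 874.70, 859.58, 855.19,
                     890.80, 946.93, 1022.75, 1017.03, 1043.17])
# Recompute throughput from the timing column and compare with the reported values.
recomputed = batch_size / avg_time_s
print(np.max(np.abs(recomputed - reported) / reported))  # expect a small relative error
```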
Some files were not shown because too many files have changed in this diff.