Compare commits

..

14 Commits

Author SHA1 Message Date
Yichuan Wang
f83c97e6d1 Merge branch 'main' into readme-polish 2025-07-19 21:47:17 -07:00
Andy Lee
6e755f0402 docs: follow yichuan's suggestion 2025-07-19 21:44:31 -07:00
Andy Lee
cc6b904c44 docs: follow yichuan's suggestion 2025-07-19 21:21:41 -07:00
Andy Lee
bda028cc1b docs: polish 2025-07-19 21:02:25 -07:00
Andy Lee
bed814e7e6 docs: polish 2025-07-19 20:45:50 -07:00
Andy Lee
96f74973b1 docs: how it works earlier 2025-07-19 20:42:52 -07:00
Andy Lee
1f90cdfafb docs: polish 2025-07-19 20:35:15 -07:00
Andy Lee
8f4f66d871 docs: highlight applications 2025-07-19 20:23:29 -07:00
Andy Lee
43b52a8c0a docs: polish 2025-07-19 20:21:25 -07:00
Andy Lee
1a3180bc0f docs: readme effects 2025-07-19 19:54:21 -07:00
Andy Lee
fe4a748a69 docs: logo with text 2025-07-19 16:47:06 -07:00
Andy Lee
d296f372e0 docs: logo 2025-07-19 16:26:31 -07:00
Andy Lee
909835dd2d docs: logo 2025-07-19 16:24:40 -07:00
Andy Lee
1eea69e8d7 docs: polish 2025-07-19 16:16:24 -07:00
146 changed files with 16276 additions and 6543 deletions

View File

@@ -1,11 +0,0 @@
name: CI
on:
push:
branches: [ main ]
pull_request:
branches: [ main ]
jobs:
build:
uses: ./.github/workflows/build-reusable.yml

View File

@@ -1,167 +0,0 @@
name: Reusable Build
on:
workflow_call:
inputs:
ref:
description: 'Git ref to build'
required: false
type: string
default: ''
jobs:
build:
name: Build ${{ matrix.os }} Python ${{ matrix.python }}
strategy:
matrix:
include:
- os: ubuntu-22.04
python: '3.9'
- os: ubuntu-22.04
python: '3.10'
- os: ubuntu-22.04
python: '3.11'
- os: ubuntu-22.04
python: '3.12'
- os: ubuntu-22.04
python: '3.13'
- os: macos-latest
python: '3.9'
- os: macos-latest
python: '3.10'
- os: macos-latest
python: '3.11'
- os: macos-latest
python: '3.12'
- os: macos-latest
python: '3.13'
runs-on: ${{ matrix.os }}
steps:
- uses: actions/checkout@v4
with:
ref: ${{ inputs.ref }}
submodules: recursive
- name: Setup Python
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python }}
- name: Install uv
uses: astral-sh/setup-uv@v4
- name: Install system dependencies (Ubuntu)
if: runner.os == 'Linux'
run: |
sudo apt-get update
sudo apt-get install -y libomp-dev libboost-all-dev protobuf-compiler libzmq3-dev \
pkg-config libopenblas-dev patchelf libabsl-dev libaio-dev libprotobuf-dev
# Install Intel MKL for DiskANN
wget -q https://registrationcenter-download.intel.com/akdlm/IRC_NAS/79153e0f-74d7-45af-b8c2-258941adf58a/intel-onemkl-2025.0.0.940.sh
sudo sh intel-onemkl-2025.0.0.940.sh -a --components intel.oneapi.lin.mkl.devel --action install --eula accept -s
source /opt/intel/oneapi/setvars.sh
echo "MKLROOT=/opt/intel/oneapi/mkl/latest" >> $GITHUB_ENV
echo "LD_LIBRARY_PATH=/opt/intel/oneapi/mkl/latest/lib/intel64:$LD_LIBRARY_PATH" >> $GITHUB_ENV
- name: Install system dependencies (macOS)
if: runner.os == 'macOS'
run: |
brew install llvm libomp boost protobuf zeromq
- name: Install build dependencies
run: |
uv pip install --system scikit-build-core numpy swig Cython pybind11
if [[ "$RUNNER_OS" == "Linux" ]]; then
uv pip install --system auditwheel
else
uv pip install --system delocate
fi
- name: Build packages
run: |
# Build core (platform independent)
if [[ "${{ matrix.os }}" == ubuntu-* ]]; then
cd packages/leann-core
uv build
cd ../..
fi
# Build HNSW backend
cd packages/leann-backend-hnsw
if [ "${{ matrix.os }}" == "macos-latest" ]; then
CC=$(brew --prefix llvm)/bin/clang CXX=$(brew --prefix llvm)/bin/clang++ uv build --wheel --python python
else
uv build --wheel --python python
fi
cd ../..
# Build DiskANN backend
cd packages/leann-backend-diskann
if [ "${{ matrix.os }}" == "macos-latest" ]; then
CC=$(brew --prefix llvm)/bin/clang CXX=$(brew --prefix llvm)/bin/clang++ uv build --wheel --python python
else
uv build --wheel --python python
fi
cd ../..
# Build meta package (platform independent)
if [[ "${{ matrix.os }}" == ubuntu-* ]]; then
cd packages/leann
uv build
cd ../..
fi
- name: Repair wheels (Linux)
if: runner.os == 'Linux'
run: |
# Repair HNSW wheel
cd packages/leann-backend-hnsw
if [ -d dist ]; then
auditwheel repair dist/*.whl -w dist_repaired
rm -rf dist
mv dist_repaired dist
fi
cd ../..
# Repair DiskANN wheel
cd packages/leann-backend-diskann
if [ -d dist ]; then
auditwheel repair dist/*.whl -w dist_repaired
rm -rf dist
mv dist_repaired dist
fi
cd ../..
- name: Repair wheels (macOS)
if: runner.os == 'macOS'
run: |
# Repair HNSW wheel
cd packages/leann-backend-hnsw
if [ -d dist ]; then
delocate-wheel -w dist_repaired -v dist/*.whl
rm -rf dist
mv dist_repaired dist
fi
cd ../..
# Repair DiskANN wheel
cd packages/leann-backend-diskann
if [ -d dist ]; then
delocate-wheel -w dist_repaired -v dist/*.whl
rm -rf dist
mv dist_repaired dist
fi
cd ../..
- name: List built packages
run: |
echo "📦 Built packages:"
find packages/*/dist -name "*.whl" -o -name "*.tar.gz" | sort
- name: Upload artifacts
uses: actions/upload-artifact@v4
with:
name: packages-${{ matrix.os }}-py${{ matrix.python }}
path: packages/*/dist/

View File

@@ -1,129 +0,0 @@
name: Release
on:
workflow_dispatch:
inputs:
version:
description: 'Version to release (e.g., 0.1.2)'
required: true
type: string
jobs:
update-version:
name: Update Version
runs-on: ubuntu-latest
permissions:
contents: write
outputs:
commit-sha: ${{ steps.push.outputs.commit-sha }}
steps:
- uses: actions/checkout@v4
- name: Validate version
run: |
# Remove 'v' prefix if present for validation
VERSION_CLEAN="${{ inputs.version }}"
VERSION_CLEAN="${VERSION_CLEAN#v}"
if ! [[ "$VERSION_CLEAN" =~ ^[0-9]+\.[0-9]+\.[0-9]+$ ]]; then
echo "❌ Invalid version format. Expected format: X.Y.Z or vX.Y.Z"
exit 1
fi
echo "✅ Version format valid: ${{ inputs.version }}"
- name: Update versions and push
id: push
run: |
# Check current version
CURRENT_VERSION=$(grep "^version" packages/leann-core/pyproject.toml | cut -d'"' -f2)
echo "Current version: $CURRENT_VERSION"
echo "Target version: ${{ inputs.version }}"
if [ "$CURRENT_VERSION" = "${{ inputs.version }}" ]; then
echo "⚠️ Version is already ${{ inputs.version }}, skipping update"
COMMIT_SHA=$(git rev-parse HEAD)
else
./scripts/bump_version.sh ${{ inputs.version }}
git config user.name "GitHub Actions"
git config user.email "actions@github.com"
git add packages/*/pyproject.toml
git commit -m "chore: release v${{ inputs.version }}"
git push origin main
COMMIT_SHA=$(git rev-parse HEAD)
echo "✅ Pushed version update: $COMMIT_SHA"
fi
echo "commit-sha=$COMMIT_SHA" >> $GITHUB_OUTPUT
build-packages:
name: Build packages
needs: update-version
uses: ./.github/workflows/build-reusable.yml
with:
ref: 'main'
publish:
name: Publish and Release
needs: [update-version, build-packages]
if: always() && needs.update-version.result == 'success' && needs.build-packages.result == 'success'
runs-on: ubuntu-latest
permissions:
contents: write
steps:
- uses: actions/checkout@v4
with:
ref: 'main'
- name: Download all artifacts
uses: actions/download-artifact@v4
with:
path: dist-artifacts
- name: Collect packages
run: |
mkdir -p dist
find dist-artifacts -name "*.whl" -exec cp {} dist/ \;
find dist-artifacts -name "*.tar.gz" -exec cp {} dist/ \;
echo "📦 Packages to publish:"
ls -la dist/
- name: Publish to PyPI
env:
TWINE_USERNAME: __token__
TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}
run: |
if [ -z "$TWINE_PASSWORD" ]; then
echo "❌ PYPI_API_TOKEN not configured!"
exit 1
fi
pip install twine
twine upload dist/* --skip-existing --verbose
echo "✅ Published to PyPI!"
- name: Create release
run: |
# Check if tag already exists
if git rev-parse "v${{ inputs.version }}" >/dev/null 2>&1; then
echo "⚠️ Tag v${{ inputs.version }} already exists, skipping tag creation"
else
git tag "v${{ inputs.version }}"
git push origin "v${{ inputs.version }}"
echo "✅ Created and pushed tag v${{ inputs.version }}"
fi
# Check if release already exists
if gh release view "v${{ inputs.version }}" >/dev/null 2>&1; then
echo "⚠️ Release v${{ inputs.version }} already exists, skipping release creation"
else
gh release create "v${{ inputs.version }}" \
--title "Release v${{ inputs.version }}" \
--notes "🚀 Released to PyPI: https://pypi.org/project/leann/${{ inputs.version }}/" \
--latest
echo "✅ Created GitHub release v${{ inputs.version }}"
fi
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}

5
.gitignore vendored
View File

@@ -12,6 +12,7 @@ outputs/
*.idx
*.map
.history/
scripts/
lm_eval.egg-info/
demo/experiment_results/**/*.json
*.jsonl
@@ -83,6 +84,4 @@ test_*.py
packages/leann-backend-diskann/third_party/DiskANN/_deps/
*.meta.json
*.passages.json
batchtest.py
*.passages.json

4
.gitmodules vendored
View File

@@ -1,9 +1,9 @@
[submodule "packages/leann-backend-diskann/third_party/DiskANN"]
path = packages/leann-backend-diskann/third_party/DiskANN
url = https://github.com/yichuan-w/DiskANN.git
url = https://github.com/yichuan520030910320/DiskANN.git
[submodule "packages/leann-backend-hnsw/third_party/faiss"]
path = packages/leann-backend-hnsw/third_party/faiss
url = https://github.com/yichuan-w/faiss.git
url = https://github.com/yichuan520030910320/faiss.git
[submodule "packages/leann-backend-hnsw/third_party/msgpack-c"]
path = packages/leann-backend-hnsw/third_party/msgpack-c
url = https://github.com/msgpack/msgpack-c.git

9
.vscode/extensions.json vendored Normal file
View File

@@ -0,0 +1,9 @@
{
"recommendations": [
"llvm-vs-code-extensions.vscode-clangd",
"ms-python.python",
"ms-vscode.cmake-tools",
"vadimcn.vscode-lldb",
"eamodio.gitlens",
]
}

283
.vscode/launch.json vendored Executable file
View File

@@ -0,0 +1,283 @@
{
// Use IntelliSense to learn about possible attributes.
// Hover to view descriptions of existing attributes.
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": [
// new emdedder
{
"name": "New Embedder",
"type": "debugpy",
"request": "launch",
"program": "demo/main.py",
"console": "integratedTerminal",
"args": [
"--search",
"--use-original",
"--domain",
"dpr",
"--nprobe",
"5000",
"--load",
"flat",
"--embedder",
"intfloat/multilingual-e5-small"
]
}
//python /home/ubuntu/Power-RAG/faiss/demo/simple_build.py
{
"name": "main.py",
"type": "debugpy",
"request": "launch",
"program": "demo/main.py",
"console": "integratedTerminal",
"cwd": "${workspaceFolder}",
"args": [
"--query",
"1000",
"--load",
"bm25"
]
},
{
"name": "Simple Build",
"type": "lldb",
"request": "launch",
"program": "${workspaceFolder}/.venv/bin/python",
"console": "integratedTerminal",
"cwd": "${workspaceFolder}",
"args": [
"faiss/demo/simple_build.py"
],
"env": {
"LD_PRELOAD": "/lib/x86_64-linux-gnu/libmkl_core.so:/lib/x86_64-linux-gnu/libmkl_intel_thread.so:/lib/x86_64-linux-gnu/libmkl_intel_lp64.so:/lib/x86_64-linux-gnu/libiomp5.so"
}
},
//# Fix for Intel MKL error
//export LD_PRELOAD=/lib/x86_64-linux-gnu/libmkl_core.so:/lib/x86_64-linux-gnu/libmkl_intel_thread.so:/lib/x86_64-linux-gnu/libmkl_intel_lp64.so:/lib/x86_64-linux-gnu/libiomp5.so
//python faiss/demo/build_demo.py
{
"name": "Build Demo",
"type": "lldb",
"request": "launch",
"program": "${workspaceFolder}/.venv/bin/python",
"console": "integratedTerminal",
"cwd": "${workspaceFolder}",
"args": [
"faiss/demo/build_demo.py"
],
"env": {
"LD_PRELOAD": "/lib/x86_64-linux-gnu/libmkl_core.so:/lib/x86_64-linux-gnu/libmkl_intel_thread.so:/lib/x86_64-linux-gnu/libmkl_intel_lp64.so:/lib/x86_64-linux-gnu/libiomp5.so"
}
},
{
"name": "DiskANN Serve",
"type": "lldb",
"request": "launch",
"program": "${workspaceFolder}/.venv/bin/python",
"console": "integratedTerminal",
"cwd": "${workspaceFolder}",
"args": [
"demo/main.py",
"--mode",
"serve",
"--engine",
"sglang",
"--load-indices",
"diskann",
"--domain",
"rpj_wiki",
"--lazy-load",
"--recompute-beighbor-embeddings",
"--port",
"8082",
"--diskann-search-memory-maximum",
"2",
"--diskann-graph",
"240",
"--search-only"
],
"env": {
"PYTHONPATH": "${workspaceFolder}/faiss_repo/build/faiss/python:$PYTHONPATH"
},
"preLaunchTask": "CMake: build",
},
{
"name": "DiskANN Serve MAC",
"type": "lldb",
"request": "launch",
"program": "${workspaceFolder}/.venv/bin/python",
"console": "integratedTerminal",
"cwd": "${workspaceFolder}",
"args": [
"demo/main.py",
"--mode",
"serve",
"--engine",
"ollama",
"--load-indices",
"diskann",
"--domain",
"rpj_wiki",
"--lazy-load",
"--recompute-beighbor-embeddings"
],
"preLaunchTask": "CMake: build",
"env": {
"KMP_DUPLICATE_LIB_OK": "TRUE",
"OMP_NUM_THREADS": "1",
"MKL_NUM_THREADS": "1",
"DYLD_INSERT_LIBRARIES": "/Users/ec2-user/Power-RAG/.venv/lib/python3.10/site-packages/torch/lib/libomp.dylib",
"KMP_BLOCKTIME": "0"
}
},
{
"name": "Python Debugger: Current File with Arguments",
"type": "debugpy",
"request": "launch",
"program": "ric/main_ric.py",
"console": "integratedTerminal",
"cwd": "${workspaceFolder}",
"args": [
"--config-name",
"${input:configSelection}"
],
"justMyCode": false
},
//python ./demo/validate_equivalence.py sglang
{
"name": "Validate Equivalence",
"type": "debugpy",
"request": "launch",
"program": "demo/validate_equivalence.py",
"console": "integratedTerminal",
"args": [
"sglang"
],
},
//python demo/retrieval_demo.py --engine sglang --skip-embeddings --domain dpr --load-indices flat ivf_flat
{
"name": "Retrieval Demo",
"type": "debugpy",
"request": "launch",
"program": "demo/retrieval_demo.py",
"console": "integratedTerminal",
"args": [
"--engine",
"vllm",
"--skip-embeddings",
"--domain",
"dpr",
"--load-indices",
// "flat",
"ivf_flat"
],
},
//python demo/retrieval_demo.py --engine sglang --skip-embeddings --domain dpr --load-indices diskann --hnsw-M 64 --hnsw-efConstruction 150 --hnsw-efSearch 128 --hnsw-sq-bits 8
{
"name": "Retrieval Demo DiskANN",
"type": "debugpy",
"request": "launch",
"program": "demo/retrieval_demo.py",
"console": "integratedTerminal",
"args": [
"--engine",
"sglang",
"--skip-embeddings",
"--domain",
"dpr",
"--load-indices",
"diskann",
"--hnsw-M",
"64",
"--hnsw-efConstruction",
"150",
"--hnsw-efSearch",
"128",
"--hnsw-sq-bits",
"8"
],
},
{
"name": "Find Probe",
"type": "debugpy",
"request": "launch",
"program": "find_probe.py",
"console": "integratedTerminal",
"cwd": "${workspaceFolder}",
},
{
"name": "Python: Attach",
"type": "debugpy",
"request": "attach",
"processId": "${command:pickProcess}",
"justMyCode": true
},
{
"name": "Edge RAG",
"type": "lldb",
"request": "launch",
"program": "${workspaceFolder}/.venv/bin/python",
"console": "integratedTerminal",
"cwd": "${workspaceFolder}",
"args": [
"edgerag_demo.py"
],
"env": {
"LD_PRELOAD": "/lib/x86_64-linux-gnu/libiomp5.so /lib/x86_64-linux-gnu/libmkl_core.so /lib/x86_64-linux-gnu/libmkl_intel_lp64.so /lib/x86_64-linux-gnu/libmkl_intel_thread.so",
"MKL_NUM_THREADS": "1",
"OMP_NUM_THREADS": "1",
}
},
{
"name": "Launch Embedding Server",
"type": "debugpy",
"request": "launch",
"program": "demo/embedding_server.py",
"console": "integratedTerminal",
"cwd": "${workspaceFolder}",
"args": [
"--domain",
"rpj_wiki",
"--zmq-port",
"5556",
]
},
{
"name": "HNSW Serve",
"type": "lldb",
"request": "launch",
"program": "${workspaceFolder}/.venv/bin/python",
"console": "integratedTerminal",
"cwd": "${workspaceFolder}",
"args": [
"demo/main.py",
"--domain",
"rpj_wiki",
"--load",
"hnsw",
"--mode",
"serve",
"--search",
"--skip-pa",
"--recompute",
"--hnsw-old"
],
"env": {
"LD_PRELOAD": "/lib/x86_64-linux-gnu/libmkl_core.so:/lib/x86_64-linux-gnu/libmkl_intel_thread.so:/lib/x86_64-linux-gnu/libmkl_intel_lp64.so:/lib/x86_64-linux-gnu/libiomp5.so"
}
},
],
"inputs": [
{
"id": "configSelection",
"type": "pickString",
"description": "Select a configuration",
"options": [
"example_config",
"vllm_gritlm"
],
"default": "example_config"
}
],
}

43
.vscode/settings.json vendored Executable file
View File

@@ -0,0 +1,43 @@
{
"python.analysis.extraPaths": [
"./sglang_repo/python"
],
"cmake.sourceDirectory": "${workspaceFolder}/DiskANN",
"cmake.configureArgs": [
"-DPYBIND=True",
"-DUPDATE_EDITABLE_INSTALL=ON",
],
"cmake.environment": {
"PATH": "/Users/ec2-user/Power-RAG/.venv/bin:${env:PATH}"
},
"cmake.buildDirectory": "${workspaceFolder}/build",
"files.associations": {
"*.tcc": "cpp",
"deque": "cpp",
"string": "cpp",
"unordered_map": "cpp",
"vector": "cpp",
"map": "cpp",
"unordered_set": "cpp",
"atomic": "cpp",
"inplace_vector": "cpp",
"*.ipp": "cpp",
"forward_list": "cpp",
"list": "cpp",
"any": "cpp",
"system_error": "cpp",
"__hash_table": "cpp",
"__split_buffer": "cpp",
"__tree": "cpp",
"ios": "cpp",
"set": "cpp",
"__string": "cpp",
"string_view": "cpp",
"ranges": "cpp",
"iosfwd": "cpp"
},
"lldb.displayFormat": "auto",
"lldb.showDisassembly": "auto",
"lldb.dereferencePointers": true,
"lldb.consoleMode": "commands",
}

16
.vscode/tasks.json vendored Normal file
View File

@@ -0,0 +1,16 @@
{
"version": "2.0.0",
"tasks": [
{
"type": "cmake",
"label": "CMake: build",
"command": "build",
"targets": [
"all"
],
"group": "build",
"problemMatcher": [],
"detail": "CMake template build task"
}
]
}

356
README.md
View File

@@ -12,74 +12,67 @@
The smallest vector index in the world. RAG Everything with LEANN!
</h2>
LEANN is a revolutionary vector database that democratizes personal AI. Transform your laptop into a powerful RAG system that can index and search through millions of documents while using **97% less storage** than traditional solutions **without accuracy loss**.
LEANN achieves this through *graph-based selective recomputation* with *high-degree preserving pruning*, computing embeddings on-demand instead of storing them all. [Illustration Fig →](#-architecture--how-it-works) | [Paper →](https://arxiv.org/abs/2506.08276)
**Ready to RAG Everything?** Transform your laptop into a personal AI assistant that can search your **[file system](#-personal-data-manager-process-any-documents-pdf-txt-md)**, **[emails](#-your-personal-email-secretary-rag-on-apple-mail)**, **[browser history](#-time-machine-for-the-web-rag-your-entire-browser-history)**, **[chat history](#-wechat-detective-unlock-your-golden-memories)**, or external knowledge bases (i.e., 60M documents) - all on your laptop, with zero cloud costs and complete privacy.
LEANN is a revolutionary vector database that makes personal AI accessible to everyone. Transform your laptop into a powerful RAG system that can index and search through millions of documents while using **97% less storage** than traditional solutions **without accuracy loss**.
RAG your **[emails](#-search-your-entire-life)**, **[browser history](#-time-machine-for-the-web)**, **[WeChat](#-wechat-detective)**, or 60M documents on your laptop, in nearly zero cost. No cloud, no API keys, completely private.
LEANN achieves this through *graph-based selective recomputation* with *high-degree preserving pruning*, computing embeddings on-demand instead of storing them all. [Read more →](#-architecture--how-it-works) | [Paper →](https://arxiv.org/abs/2506.08276)
## Why LEANN?
<p align="center">
<img src="assets/effects.png" alt="LEANN vs Traditional Vector DB Storage Comparison" width="70%">
<img src="assets/effects.png" alt="LEANN vs Traditional Vector DB Storage Comparison" width="100%">
</p>
> **The numbers speak for themselves:** Index 60 million Wikipedia chunks in just 6GB instead of 201GB. From emails to browser history, everything fits on your laptop. [See detailed benchmarks for different applications below ↓](#storage-comparison)
**The numbers speak for themselves:** Index 60 million Wikipedia articles in just 6GB instead of 201GB. From emails to browser history, everything fits on your laptop. [See detailed benchmarks below ↓](#storage-usage-comparison)
## Why This Matters
🔒 **Privacy:** Your data never leaves your laptop. No OpenAI, no cloud, no "terms of service".
🪶 **Lightweight:** Graph-based recomputation eliminates heavy embedding storage, while smart graph pruning and CSR format minimize graph storage overhead. Always less storage, less memory usage!
🪶 **Lightweight:** Smart graph pruning means less storage, less memory usage, better performance on your existing hardware.
📈 **Scalability:** Handle messy personal data that would crash traditional vector DBs, easily managing your growing personalized data and agent generated memory!
📈 **Scalability:** Organize our messy personal data that would crash traditional vector DBs, with performance that gets better as your data grows more personalized.
**No Accuracy Loss:** Maintain the same search quality as heavyweight solutions while using 97% less storage.
## Installation
> `pip leann` coming soon!
## Quick Start in 1 minute
```bash
git clone git@github.com:yichuan-w/LEANN.git leann
git clone git@github.com:yichuan520030910320/LEANN-RAG.git leann
cd leann
git submodule update --init --recursive
```
**macOS:**
```bash
brew install llvm libomp boost protobuf zeromq pkgconf
# Install with HNSW backend (default, recommended for most users)
# Install uv first if you don't have it:
# curl -LsSf https://astral.sh/uv/install.sh | sh
# See: https://docs.astral.sh/uv/getting-started/installation/#installation-methods
CC=$(brew --prefix llvm)/bin/clang CXX=$(brew --prefix llvm)/bin/clang++ uv sync
```
**Linux:**
```bash
sudo apt-get install libomp-dev libboost-all-dev protobuf-compiler libabsl-dev libmkl-full-dev libaio-dev libzmq3-dev
# Install with HNSW backend (default, recommended for most users)
brew install llvm libomp boost protobuf
export CC=$(brew --prefix llvm)/bin/clang
export CXX=$(brew --prefix llvm)/bin/clang++
uv sync
```
**Linux (Ubuntu/Debian):**
```bash
sudo apt-get install libomp-dev libboost-all-dev protobuf-compiler libabsl-dev libmkl-full-dev libaio-dev
uv sync
```
**Ollama Setup (Recommended for full privacy):**
**Ollama Setup (Optional for Local LLM):**
> *You can skip this installation if you only want to use OpenAI API for generation.*
*We support both hf-transformers and Ollama for local LLMs. Ollama is recommended for faster performance.*
**macOS:**
First, [download Ollama for macOS](https://ollama.com/download/mac).
*macOS:*
```bash
# Install Ollama
brew install ollama
# Pull a lightweight model (recommended for consumer hardware)
ollama pull llama3.2:1b
```
**Linux:**
*Linux:*
```bash
# Install Ollama
curl -fsSL https://ollama.ai/install.sh | sh
@@ -91,78 +84,62 @@ ollama serve &
ollama pull llama3.2:1b
```
## Quick Start in 30s
You can also replace `llama3.2:1b` to `deepseek-r1:1.5b` or `qwen3:4b` for better performance but higher memory usage.
Our declarative API makes RAG as easy as writing a config file.
[Try in this ipynb file →](demo.ipynb) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/yichuan-w/LEANN/blob/main/demo.ipynb)
## Dead Simple API
Just 3 lines of code. Our declarative API makes RAG as easy as writing a config file:
```python
from leann.api import LeannBuilder, LeannSearcher, LeannChat
from leann.api import LeannBuilder, LeannSearcher
# 1. Build the index (no embeddings stored!)
# 1. Build index (no embeddings stored!)
builder = LeannBuilder(backend_name="hnsw")
builder.add_text("C# is a powerful programming language")
builder.add_text("Python is a powerful programming language and it is very popular")
builder.add_text("Machine learning transforms industries")
builder.add_text("Python is a powerful programming language")
builder.add_text("Machine learning transforms industries")
builder.add_text("Neural networks process complex data")
builder.add_text("Leann is a great storage saving engine for RAG on your MacBook")
builder.add_text("Leann is a great storage saving engine for RAG on your macbook")
builder.build_index("knowledge.leann")
# 2. Search with real-time embeddings
searcher = LeannSearcher("knowledge.leann")
results = searcher.search("programming languages", top_k=2)
# 3. Chat with LEANN using retrieved results
llm_config = {
"type": "ollama",
"model": "llama3.2:1b"
}
chat = LeannChat(index_path="knowledge.leann", llm_config=llm_config)
response = chat.ask(
"Compare the two retrieved programming languages and say which one is more popular today.",
top_k=2,
)
results = searcher.search("C++ programming languages", top_k=2, recompute_beighbor_embeddings=True)
print(results)
```
## RAG on Everything!
**That's it.** No cloud setup, no API keys, no "fine-tuning". Just your data, your questions, your laptop.
LEANN supports RAG on various data sources including documents (.pdf, .txt, .md), Apple Mail, Google Search History, WeChat, and more.
[Try the interactive demo →](demo.ipynb)
### 📄 Personal Data Manager: Process Any Documents (.pdf, .txt, .md)!
## Wild Things You Can Do
Ask questions directly about your personal PDFs, documents, and any directory containing your files!
LEANN supports RAGing a lot of data sources, like .pdf, .txt, .md, and also supports RAGing your WeChat, Google Search History, and more.
<p align="center">
<img src="videos/paper_clear.gif" alt="LEANN Document Search Demo" width="600">
</p>
### 📚 Process Any Documents (.pdf, .txt, .md)
The example below asks a question about summarizing two papers (uses default data in `examples/data`):
Above we showed the Python API, while this CLI script demonstrates the same concepts while directly processing PDFs and documents.
```bash
# Drop your PDFs, .txt, .md files into examples/data/
uv run ./examples/main_cli_example.py
```
```
# Or use python directly
source .venv/bin/activate
python ./examples/main_cli_example.py
```
Uses Ollama `qwen3:8b` by default. For other models: `--llm openai --model gpt-4o` (requires `OPENAI_API_KEY` environment variable) or `--llm hf --model Qwen/Qwen3-4B`.
**Works with any text format** - research papers, personal notes, presentations. Built with LlamaIndex for document parsing.
### 📧 Your Personal Email Secretary: RAG on Apple Mail!
<p align="center">
<img src="videos/mail_clear.gif" alt="LEANN Email Search Demo" width="600">
</p>
**Note:** You need to grant full disk access to your terminal/VS Code in System Preferences → Privacy & Security → Full Disk Access.
### 🕵️ Search Your Entire Life
```bash
python examples/mail_reader_leann.py --query "What's the food I ordered by doordash or Uber eat mostly?"
python examples/mail_reader_leann.py
# "What did my boss say about the Christmas party last year?"
# "Find all emails from my mom about birthday plans"
```
**780K email chunks → 78MB storage** Finally, search your email like you search Google.
**90K emails → 14MB.** Finally, search your email like you search Google.
<details>
<summary><strong>📋 Click to expand: Command Examples</strong></summary>
@@ -195,16 +172,13 @@ Once the index is built, you can ask questions like:
- "Show me emails about travel expenses"
</details>
### 🔍 Time Machine for the Web: RAG Your Entire Chrome Browser History!
<p align="center">
<img src="videos/google_clear.gif" alt="LEANN Browser History Search Demo" width="600">
</p>
### 🌐 Time Machine for the Web
```bash
python examples/google_history_reader_leann.py --query "Tell me my browser history about machine learning?"
python examples/google_history_reader_leann.py
# "What was that AI paper I read last month?"
# "Show me all the cooking videos I watched"
```
**38K browser entries → 6MB storage.** Your browser history becomes your personal search engine.
**38K browser entries → 6MB.** Your browser history becomes your personal search engine.
<details>
<summary><strong>📋 Click to expand: Command Examples</strong></summary>
@@ -253,17 +227,13 @@ Once the index is built, you can ask questions like:
</details>
### 💬 WeChat Detective: Unlock Your Golden Memories!
<p align="center">
<img src="videos/wechat_clear.gif" alt="LEANN WeChat Search Demo" width="600">
</p>
### 💬 WeChat Detective
```bash
python examples/wechat_history_reader_leann.py --query "Show me all group chats about weekend plans"
python examples/wechat_history_reader_leann.py
# "Show me all group chats about weekend plans"
```
**400K messages → 64MB storage** Search years of chat history in any language.
**400K messages → 64MB.** Search years of chat history in any language.
<details>
<summary><strong>🔧 Click to expand: Installation Requirements</strong></summary>
@@ -274,13 +244,7 @@ First, you need to install the WeChat exporter:
sudo packages/wechat-exporter/wechattweak-cli install
```
**Troubleshooting:**
- **Installation issues**: Check the [WeChatTweak-CLI issues page](https://github.com/sunnyyoung/WeChatTweak-CLI/issues/41)
- **Export errors**: If you encounter the error below, try restarting WeChat
```
Failed to export WeChat data. Please ensure WeChat is running and WeChatTweak is installed.
Failed to find or export WeChat data. Exiting.
```
**Troubleshooting**: If you encounter installation issues, check the [WeChatTweak-CLI issues page](https://github.com/sunnyyoung/WeChatTweak-CLI/issues/41).
</details>
<details>
@@ -315,73 +279,6 @@ Once the index is built, you can ask questions like:
</details>
## 🖥️ Command Line Interface
LEANN includes a powerful CLI for document processing and search. Perfect for quick document indexing and interactive chat.
```bash
# Build an index from documents
leann build my-docs --docs ./documents
# Search your documents
leann search my-docs "machine learning concepts"
# Interactive chat with your documents
leann ask my-docs --interactive
# List all your indexes
leann list
```
**Key CLI features:**
- Auto-detects document formats (PDF, TXT, MD, DOCX)
- Smart text chunking with overlap
- Multiple LLM providers (Ollama, OpenAI, HuggingFace)
- Organized index storage in `~/.leann/indexes/`
- Support for advanced search parameters
<details>
<summary><strong>📋 Click to expand: Complete CLI Reference</strong></summary>
**Build Command:**
```bash
leann build INDEX_NAME --docs DIRECTORY [OPTIONS]
Options:
--backend {hnsw,diskann} Backend to use (default: hnsw)
--embedding-model MODEL Embedding model (default: facebook/contriever)
--graph-degree N Graph degree (default: 32)
--complexity N Build complexity (default: 64)
--force Force rebuild existing index
--compact Use compact storage (default: true)
--recompute Enable recomputation (default: true)
```
**Search Command:**
```bash
leann search INDEX_NAME QUERY [OPTIONS]
Options:
--top-k N Number of results (default: 5)
--complexity N Search complexity (default: 64)
--recompute-embeddings Use recomputation for highest accuracy
--pruning-strategy {global,local,proportional}
```
**Ask Command:**
```bash
leann ask INDEX_NAME [OPTIONS]
Options:
--llm {ollama,openai,hf} LLM provider (default: ollama)
--model MODEL Model name (default: qwen3:8b)
--interactive Interactive chat mode
--top-k N Retrieval count (default: 20)
```
</details>
## 🏗️ Architecture & How It Works
<p align="center">
@@ -400,17 +297,18 @@ Options:
## Benchmarks
Run the comparison yourself:
```bash
python examples/compare_faiss_vs_leann.py
```
📊 **[Simple Example: Compare LEANN vs FAISS →](examples/compare_faiss_vs_leann.py)**
### Storage Comparison
| System | DPR (2.1M) | Wiki (60M) | Chat (400K) | Email (780K) | Browser (38K) |
|--------|-------------|------------|-------------|--------------|---------------|
| Traditional vector database (e.g., FAISS) | 3.8 GB | 201 GB | 1.8 GB | 2.4 GB | 130 MB |
| LEANN | 324 MB | 6 GB | 64 MB | 79 MB | 6.4 MB |
| Savings| 91% | 97% | 97% | 97% | 95% |
| System | Storage |
|--------|---------|
| FAISS HNSW | 5.5 MB |
| LEANN | 0.5 MB |
| **Savings** | **91%** |
Same dataset, same hardware, same embedding model. LEANN just works better.
## Reproduce Our Results
@@ -420,7 +318,33 @@ python examples/run_evaluation.py data/indices/dpr/dpr_diskann # DPR datase
python examples/run_evaluation.py data/indices/rpj_wiki/rpj_wiki.index # Wikipedia
```
The evaluation script downloads data automatically on first run. The last three results were tested with partial personal data, and you can reproduce them with your own data!
The evaluation script downloads data automatically on first run.
### Storage Usage Comparison
| System | DPR (2.1M chunks) | RPJ-wiki (60M chunks) | Chat history (400K messages) | Apple emails (90K messages chunks) |Google Search History (38K entries)
|-----------------------|------------------|------------------------|-----------------------------|------------------------------|------------------------------|
| Traditional Vector DB(FAISS) | 3.8 GB | 201 GB | 1.8G | 305.8 MB |130.4 MB |
| **LEANN** | **324 MB** | **6 GB** | **64 MB** | **14.8 MB** |**6.4MB** |
| **Reduction** | **91% smaller** | **97% smaller** | **97% smaller** | **95% smaller** |**95% smaller** |
<!-- ### Memory Usage Comparison
| System j | DPR(2M docs) | RPJ-wiki(60M docs) | Chat history() |
| --------------------- | ---------------- | ---------------- | ---------------- |
| Traditional Vector DB(LLamaindex faiss) | x GB | x GB | x GB |
| **Leann** | **xx MB** | **x GB** | **x GB** |
| **Reduction** | **x%** | **x%** | **x%** |
### Query Performance of LEANN
| Backend | Index Size | Query Time | Recall@3 |
| ------------------- | ---------- | ---------- | --------- |
| DiskANN | 1M docs | xms | 0.95 |
| HNSW | 1M docs | xms | 0.95 | -->
*Benchmarks run on Apple M3 Pro 36 GB*
## 🔬 Paper
If you find Leann useful, please cite:
@@ -439,15 +363,87 @@ If you find Leann useful, please cite:
}
```
## ✨ [Detailed Features →](docs/features.md)
## ✨ Features
## 🤝 [Contributing →](docs/contributing.md)
### 🔥 Core Features
- **🔄 Real-time Embeddings** - Eliminate heavy embedding storage with dynamic computation using optimized ZMQ servers and highly optimized search paradigm (overlapping and batching) with highly optimized embedding engine
- **📈 Scalable Architecture** - Handles millions of documents on consumer hardware; the larger your dataset, the more LEANN can save
- **🎯 Graph Pruning** - Advanced techniques to minimize the storage overhead of vector search to a limited footprint
- **🏗️ Pluggable Backends** - DiskANN, HNSW/FAISS with unified API
### 🛠️ Technical Highlights
- **🔄 Recompute Mode** - Highest accuracy scenarios while eliminating vector storage overhead
- **⚡ Zero-copy Operations** - Minimize IPC overhead by transferring distances instead of embeddings
- **🚀 High-throughput Embedding Pipeline** - Optimized batched processing for maximum efficiency
- **🎯 Two-level Search** - Novel coarse-to-fine search overlap for accelerated query processing (optional)
- **💾 Memory-mapped Indices** - Fast startup with raw text mapping to reduce memory overhead
- **🚀 MLX Support** - Ultra-fast recompute/build with quantized embedding models, accelerating building and search ([minimal example](test/build_mlx_index.py))
### 🎨 Developer Experience
- **Simple Python API** - Get started in minutes
- **Extensible backend system** - Easy to add new algorithms
- **Comprehensive examples** - From basic usage to production deployment
## 🤝 Contributing
We welcome contributions! Leann is built by the community, for the community.
### Ways to Contribute
- 🐛 **Bug Reports**: Found an issue? Let us know!
- 💡 **Feature Requests**: Have an idea? We'd love to hear it!
- 🔧 **Code Contributions**: PRs welcome for all skill levels
- 📖 **Documentation**: Help make Leann more accessible
- 🧪 **Benchmarks**: Share your performance results
## [FAQ →](docs/faq.md)
<!-- ## FAQ
### Common Issues
#### NCCL Topology Error
**Problem**: You encounter `ncclTopoComputePaths` error during document processing:
```
ncclTopoComputePaths (system=<optimized out>, comm=comm@entry=0x5555a82fa3c0) at graph/paths.cc:688
```
**Solution**: Set these environment variables before running your script:
```bash
export NCCL_TOPO_DUMP_FILE=/tmp/nccl_topo.xml
export NCCL_DEBUG=INFO
export NCCL_DEBUG_SUBSYS=INIT,GRAPH
export NCCL_IB_DISABLE=1
export NCCL_NET_PLUGIN=none
export NCCL_SOCKET_IFNAME=ens5
``` -->
## 📈 Roadmap
### 🎯 Q2 2025
- [X] DiskANN backend with MIPS/L2/Cosine support
- [X] HNSW backend integration
- [X] Real-time embedding pipeline
- [X] Memory-efficient graph pruning
### 🚀 Q3 2025
## 📈 [Roadmap →](docs/roadmap.md)
- [ ] Advanced caching strategies
- [ ] Add contextual-retrieval https://www.anthropic.com/news/contextual-retrieval
- [ ] Add sleep-time-compute and summarize agent! to summarilze the file on computer!
- [ ] Add OpenAI recompute API
### 🌟 Q4 2025
- [ ] Integration with LangChain/LlamaIndex
- [ ] Visual similarity search
- [ ] Query rewrtiting, rerank and expansion
## 📄 License
@@ -455,7 +451,11 @@ MIT License - see [LICENSE](LICENSE) for details.
## 🙏 Acknowledgments
This work is done at [**Berkeley Sky Computing Lab**](https://sky.cs.berkeley.edu/)
- **Microsoft Research** for the DiskANN algorithm
- **Meta AI** for FAISS and optimization insights
- **HuggingFace** for the transformer ecosystem
- **Our amazing contributors** who make this possible
---
<p align="center">

View File

Binary file not shown.

Before

Width:  |  Height:  |  Size: 339 KiB

After

Width:  |  Height:  |  Size: 206 KiB

View File

@@ -1,97 +1,35 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Quick Start in 30s"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# install this if you are using colab\n",
"! uv pip install leann-core leann-backend-hnsw --no-deps\n",
"! uv pip install leann --no-deps\n",
"# For Colab environment, we need to set some environment variables\n",
"import os\n",
"os.environ['LEANN_LOG_LEVEL'] = 'INFO' # Enable more detailed logging"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Build the index"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from leann.api import LeannBuilder\n",
"\n",
"from leann.api import LeannBuilder, LeannSearcher, LeannChat\n",
"# 1. Build index (no embeddings stored!)\n",
"builder = LeannBuilder(backend_name=\"hnsw\")\n",
"builder.add_text(\"C# is a powerful programming language and it is good at game development\")\n",
"builder.add_text(\"Python is a powerful programming language and it is good at machine learning tasks\")\n",
"builder.add_text(\"Machine learning transforms industries\")\n",
"builder.add_text(\"C# is a powerful programming language but it is not very popular\")\n",
"builder.add_text(\"Python is a powerful programming language and it is very popular\")\n",
"builder.add_text(\"Machine learning transforms industries\") \n",
"builder.add_text(\"Neural networks process complex data\")\n",
"builder.add_text(\"Leann is a great storage saving engine for RAG on your MacBook\")\n",
"builder.build_index(\"knowledge.leann\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Search with real-time embeddings"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from leann.api import LeannSearcher\n",
"\n",
"builder.add_text(\"Leann is a great storage saving engine for RAG on your macbook\")\n",
"builder.build_index(\"knowledge.leann\")\n",
"# 2. Search with real-time embeddings\n",
"searcher = LeannSearcher(\"knowledge.leann\")\n",
"results = searcher.search(\"programming languages\", top_k=2)\n",
"results"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Chat with LEANN using retrieved results"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from leann.api import LeannChat\n",
"results = searcher.search(\"programming languages\", top_k=2, recompute_beighbor_embeddings=True)\n",
"print(results)\n",
"\n",
"llm_config = {\n",
" \"type\": \"hf\",\n",
" \"model\": \"Qwen/Qwen3-0.6B\",\n",
"}\n",
"llm_config = {\"type\": \"ollama\", \"model\": \"qwen3:8b\"}\n",
"\n",
"chat = LeannChat(index_path=\"knowledge.leann\", llm_config=llm_config)\n",
"\n",
"response = chat.ask(\n",
" \"Compare the two retrieved programming languages and tell me their advantages.\",\n",
" \"Compare the two retrieved programming languages and say which one is more popular today. Respond in a single well-formed sentence.\",\n",
" top_k=2,\n",
" llm_kwargs={\"max_tokens\": 128}\n",
" recompute_beighbor_embeddings=True,\n",
")\n",
"response"
"print(response)"
]
}
],

View File

@@ -1,22 +0,0 @@
# Release Guide
## Setup (One-time)
Add `PYPI_API_TOKEN` to GitHub Secrets:
1. Get token: https://pypi.org/manage/account/token/
2. Add to secrets: Settings → Secrets → Actions → `PYPI_API_TOKEN`
## Release (One-click)
1. Go to: https://github.com/yichuan-w/LEANN/actions/workflows/release-manual.yml
2. Click "Run workflow"
3. Enter version: `0.1.2`
4. Click green "Run workflow" button
That's it! The workflow will automatically:
- ✅ Update version in all packages
- ✅ Build all packages
- ✅ Publish to PyPI
- ✅ Create GitHub tag and release
Check progress: https://github.com/yichuan-w/LEANN/actions

View File

@@ -1,11 +0,0 @@
# 🤝 Contributing
We welcome contributions! Leann is built by the community, for the community.
## Ways to Contribute
- 🐛 **Bug Reports**: Found an issue? Let us know!
- 💡 **Feature Requests**: Have an idea? We'd love to hear it!
- 🔧 **Code Contributions**: PRs welcome for all skill levels
- 📖 **Documentation**: Help make Leann more accessible
- 🧪 **Benchmarks**: Share your performance results

View File

@@ -1,10 +0,0 @@
# FAQ
## 1. My building time seems long
You can speed up the process by using a lightweight embedding model. Add this to your arguments:
```bash
--embedding-model sentence-transformers/all-MiniLM-L6-v2
```
**Model sizes:** `all-MiniLM-L6-v2` (30M parameters), `facebook/contriever` (~100M parameters), `Qwen3-0.6B` (600M parameters)

View File

@@ -1,22 +0,0 @@
# ✨ Detailed Features
## 🔥 Core Features
- **🔄 Real-time Embeddings** - Eliminate heavy embedding storage with dynamic computation using optimized ZMQ servers and highly optimized search paradigm (overlapping and batching) with highly optimized embedding engine
- **📈 Scalable Architecture** - Handles millions of documents on consumer hardware; the larger your dataset, the more LEANN can save
- **🎯 Graph Pruning** - Advanced techniques to minimize the storage overhead of vector search to a limited footprint
- **🏗️ Pluggable Backends** - DiskANN, HNSW/FAISS with unified API
## 🛠️ Technical Highlights
- **🔄 Recompute Mode** - Highest accuracy scenarios while eliminating vector storage overhead
- **⚡ Zero-copy Operations** - Minimize IPC overhead by transferring distances instead of embeddings
- **🚀 High-throughput Embedding Pipeline** - Optimized batched processing for maximum efficiency
- **🎯 Two-level Search** - Novel coarse-to-fine search overlap for accelerated query processing (optional)
- **💾 Memory-mapped Indices** - Fast startup with raw text mapping to reduce memory overhead
- **🚀 MLX Support** - Ultra-fast recompute/build with quantized embedding models, accelerating building and search ([minimal example](test/build_mlx_index.py))
## 🎨 Developer Experience
- **Simple Python API** - Get started in minutes
- **Extensible backend system** - Easy to add new algorithms
- **Comprehensive examples** - From basic usage to production deployment

View File

@@ -1,21 +0,0 @@
# 📈 Roadmap
## 🎯 Q2 2025
- [X] DiskANN backend with MIPS/L2/Cosine support
- [X] HNSW backend integration
- [X] Real-time embedding pipeline
- [X] Memory-efficient graph pruning
## 🚀 Q3 2025
- [ ] Advanced caching strategies
- [ ] Add contextual-retrieval https://www.anthropic.com/news/contextual-retrieval
- [ ] Add sleep-time-compute and summarize agent! to summarilze the file on computer!
- [ ] Add OpenAI recompute API
## 🌟 Q4 2025
- [ ] Integration with LangChain/LlamaIndex
- [ ] Visual similarity search
- [ ] Query rewrtiting, rerank and expansion

View File

@@ -135,7 +135,6 @@ def test_leann_hnsw():
nodes = node_parser.get_nodes_from_documents([doc])
for node in nodes:
all_texts.append(node.get_content())
print(f"Total number of chunks: {len(all_texts)}")
tracker.checkpoint("After text chunking")

View File

@@ -96,12 +96,14 @@ class EmlxReader(BaseReader):
# Create document content with metadata embedded in text
doc_content = f"""
[File]: {filename}
[From]: {from_addr}
[To]: {to_addr}
[Subject]: {subject}
[Date]: {date}
[EMAIL BODY Start]:
[EMAIL METADATA]
File: {filename}
From: {from_addr}
To: {to_addr}
Subject: {subject}
Date: {date}
[END METADATA]
{body}
"""

View File

@@ -37,7 +37,7 @@ def main():
import faiss
except ImportError:
print("Faiss is not installed.")
print("Please install it with `uv pip install faiss-cpu` and you can then run this script again")
print("Please install it with `uv pip install faiss-cpu`")
sys.exit(1)
from llama_index.core import (

View File

@@ -65,14 +65,12 @@ def create_leann_index_from_multiple_chrome_profiles(profile_dirs: List[Path], i
if not all_documents:
print("No documents loaded from any source. Exiting.")
# highlight info that you need to close all chrome browser before running this script and high light the instruction!!
print("\033[91mYou need to close or quit all chrome browser before running this script\033[0m")
return None
print(f"\nTotal loaded {len(all_documents)} history documents from {len(profile_dirs)} profiles")
# Create text splitter with 256 chunk size
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=128)
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=25)
# Convert Documents to text strings and chunk them
all_texts = []
@@ -80,9 +78,7 @@ def create_leann_index_from_multiple_chrome_profiles(profile_dirs: List[Path], i
# Split the document into chunks
nodes = text_splitter.get_nodes_from_documents([doc])
for node in nodes:
text = node.get_content()
# text = '[Title] ' + doc.metadata["title"] + '\n' + text
all_texts.append(text)
all_texts.append(node.get_content())
print(f"Created {len(all_texts)} text chunks from {len(all_documents)} documents")
@@ -97,13 +93,11 @@ def create_leann_index_from_multiple_chrome_profiles(profile_dirs: List[Path], i
# Use HNSW backend for better macOS compatibility
builder = LeannBuilder(
backend_name="hnsw",
embedding_model="text-embedding-3-small",
embedding_mode="openai",
embedding_model="facebook/contriever",
graph_degree=32,
complexity=64,
is_compact=False,
is_recompute=False,
is_compact=True,
is_recompute=True,
num_threads=1 # Force single-threaded mode
)
@@ -224,15 +218,14 @@ async def query_leann_index(index_path: str, query: str):
"max_tokens": 1000
}
)
print(f"Leann chat response: \033[36m{chat_response}\033[0m")
print(f"Leann: {chat_response}")
async def main():
# Parse command line arguments
parser = argparse.ArgumentParser(description='LEANN Chrome History Reader - Create and query browser history index')
parser.add_argument('--chrome-profile', type=str, default=DEFAULT_CHROME_PROFILE,
help=f'Path to Chrome profile directory (default: {DEFAULT_CHROME_PROFILE}), usually you dont need to change this')
parser.add_argument('--index-dir', type=str, default="./google_history_index",
parser.add_argument('--index-dir', type=str, default="./chrome_history_index_leann_test",
help='Directory to store the LEANN index (default: ./chrome_history_index_leann_test)')
parser.add_argument('--max-entries', type=int, default=1000,
help='Maximum number of history entries to process (default: 1000)')

View File

@@ -74,17 +74,22 @@ class ChromeHistoryReader(BaseReader):
# Create document content with metadata embedded in text
doc_content = f"""
[Title]: {title}
[URL of the page]: {url}
[Last visited time]: {last_visit}
[Visit times]: {visit_count}
[Typed times]: {typed_count}
[BROWSING HISTORY METADATA]
URL: {url}
Title: {title}
Last Visit: {last_visit}
Visit Count: {visit_count}
Typed Count: {typed_count}
Hidden: {hidden}
[END METADATA]
Title: {title}
URL: {url}
Last visited: {last_visit}
"""
# Create document with embedded metadata
doc = Document(text=doc_content, metadata={ "title": title[0:150]})
# if len(title) > 150:
# print(f"Title is too long: {title}")
doc = Document(text=doc_content, metadata={})
docs.append(doc)
count += 1

View File

@@ -197,8 +197,8 @@ class WeChatHistoryReader(BaseReader):
Args:
messages: List of message dictionaries
max_length: Maximum length for concatenated message groups. Use -1 to disable length constraint.
time_window_minutes: Time window in minutes to group messages together. Use -1 to disable time constraint.
max_length: Maximum length for concatenated message groups
time_window_minutes: Time window in minutes to group messages together
overlap_messages: Number of messages to overlap between consecutive groups
Returns:
@@ -230,8 +230,8 @@ class WeChatHistoryReader(BaseReader):
if not readable_text.strip():
continue
# Check time window constraint (only if time_window_minutes != -1)
if time_window_minutes != -1 and last_timestamp is not None and create_time > 0:
# Check time window constraint
if last_timestamp is not None and create_time > 0:
time_diff_minutes = (create_time - last_timestamp) / 60
if time_diff_minutes > time_window_minutes:
# Time gap too large, start new group
@@ -250,9 +250,9 @@ class WeChatHistoryReader(BaseReader):
current_group = []
current_length = 0
# Check length constraint (only if max_length != -1)
# Check length constraint
message_length = len(readable_text)
if max_length != -1 and current_length + message_length > max_length and current_group:
if current_length + message_length > max_length and current_group:
# Current group would exceed max length, save it and start new
concatenated_groups.append({
'messages': current_group,
@@ -335,15 +335,14 @@ class WeChatHistoryReader(BaseReader):
if create_time:
try:
timestamp = datetime.fromtimestamp(create_time)
# change to YYYY-MM-DD HH:MM:SS
time_str = timestamp.strftime('%Y-%m-%d %H:%M:%S')
time_str = timestamp.strftime('%H:%M:%S')
except:
time_str = str(create_time)
else:
time_str = "Unknown"
sender = "[Me]" if is_sent_from_self else "[Contact]"
message_parts.append(f"({time_str}) {sender}: {readable_text}")
sender = "Me" if is_sent_from_self else "Contact"
message_parts.append(f"[{time_str}] {sender}: {readable_text}")
concatenated_text = "\n".join(message_parts)
@@ -355,11 +354,13 @@ Messages ({len(messages)} messages, {message_group['total_length']} chars):
{concatenated_text}
"""
# TODO @yichuan give better format and rich info here!
doc_content = f"""
Contact: {contact_name}
{concatenated_text}
"""
return doc_content, contact_name
return doc_content
def load_data(self, input_dir: str = None, **load_kwargs: Any) -> List[Document]:
"""
@@ -430,9 +431,9 @@ Messages ({len(messages)} messages, {message_group['total_length']} chars):
# Concatenate messages based on rules
message_groups = self._concatenate_messages(
readable_messages,
max_length=-1,
time_window_minutes=-1,
overlap_messages=0 # Keep 2 messages overlap between groups
max_length=max_length,
time_window_minutes=time_window_minutes,
overlap_messages=2 # Keep 2 messages overlap between groups
)
# Create documents from concatenated groups
@@ -440,8 +441,8 @@ Messages ({len(messages)} messages, {message_group['total_length']} chars):
if count >= max_count and max_count > 0:
break
doc_content, contact_name = self._create_concatenated_content(message_group, contact_name)
doc = Document(text=doc_content, metadata={"contact_name": contact_name})
doc_content = self._create_concatenated_content(message_group, contact_name)
doc = Document(text=doc_content, metadata={})
docs.append(doc)
count += 1

View File

@@ -22,7 +22,7 @@ def get_mail_path():
return os.path.join(home_dir, "Library", "Mail")
# Default mail path for macOS
DEFAULT_MAIL_PATH = "/Users/yichuan/Library/Mail/V10/0FCA0879-FD8C-4B7E-83BF-FDDA930791C5/[Gmail].mbox/All Mail.mbox/78BA5BE1-8819-4F9A-9613-EB63772F1DD0/Data"
# DEFAULT_MAIL_PATH = "/Users/yichuan/Library/Mail/V10/0FCA0879-FD8C-4B7E-83BF-FDDA930791C5/[Gmail].mbox/All Mail.mbox/78BA5BE1-8819-4F9A-9613-EB63772F1DD0/Data"
def create_leann_index_from_multiple_sources(messages_dirs: List[Path], index_path: str = "mail_index.leann", max_count: int = -1, include_html: bool = False, embedding_model: str = "facebook/contriever"):
"""
@@ -74,7 +74,7 @@ def create_leann_index_from_multiple_sources(messages_dirs: List[Path], index_pa
print("No documents loaded from any source. Exiting.")
return None
print(f"\nTotal loaded {len(all_documents)} email documents from {len(messages_dirs)} directories and starting to split them into chunks")
print(f"\nTotal loaded {len(all_documents)} email documents from {len(messages_dirs)} directories")
# Create text splitter with 256 chunk size
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=25)
@@ -85,11 +85,9 @@ def create_leann_index_from_multiple_sources(messages_dirs: List[Path], index_pa
# Split the document into chunks
nodes = text_splitter.get_nodes_from_documents([doc])
for node in nodes:
text = node.get_content()
# text = '[subject] ' + doc.metadata["subject"] + '\n' + text
all_texts.append(text)
all_texts.append(node.get_content())
print(f"Finished splitting {len(all_documents)} documents into {len(all_texts)} text chunks")
print(f"Created {len(all_texts)} text chunks from {len(all_documents)} documents")
# Create LEANN index directory
@@ -158,7 +156,7 @@ def create_leann_index(mail_path: str, index_path: str = "mail_index.leann", max
print(f"Loaded {len(documents)} email documents")
# Create text splitter with 256 chunk size
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=128)
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=25)
# Convert Documents to text strings and chunk them
all_texts = []
@@ -218,22 +216,22 @@ async def query_leann_index(index_path: str, query: str):
start_time = time.time()
chat_response = chat.ask(
query,
top_k=20,
top_k=10,
recompute_beighbor_embeddings=True,
complexity=32,
complexity=12,
beam_width=1,
)
end_time = time.time()
# print(f"Time taken: {end_time - start_time} seconds")
# highlight the answer
print(f"Leann chat response: \033[36m{chat_response}\033[0m")
print(f"Time taken: {end_time - start_time} seconds")
print(f"Leann: {chat_response}")
async def main():
# Parse command line arguments
parser = argparse.ArgumentParser(description='LEANN Mail Reader - Create and query email index')
# Remove --mail-path argument and auto-detect all Messages directories
# Remove DEFAULT_MAIL_PATH
parser.add_argument('--index-dir', type=str, default="./mail_index",
parser.add_argument('--index-dir', type=str, default="./mail_index_leann_raw_text_all_dicts",
help='Directory to store the LEANN index (default: ./mail_index_leann_raw_text_all_dicts)')
parser.add_argument('--max-emails', type=int, default=1000,
help='Maximum number of emails to process (-1 means all)')
@@ -253,9 +251,6 @@ async def main():
mail_path = get_mail_path()
print(f"Searching for email data in: {mail_path}")
messages_dirs = find_all_messages_directories(mail_path)
# messages_dirs = find_all_messages_directories(DEFAULT_MAIL_PATH)
# messages_dirs = [DEFAULT_MAIL_PATH]
# messages_dirs = messages_dirs[:1]
print('len(messages_dirs): ', len(messages_dirs))

View File

@@ -1,40 +1,40 @@
import argparse
from llama_index.core import SimpleDirectoryReader
from llama_index.core import SimpleDirectoryReader, Settings
from llama_index.core.node_parser import SentenceSplitter
import asyncio
import dotenv
from leann.api import LeannBuilder, LeannChat
from leann.api import LeannBuilder, LeannSearcher, LeannChat
import shutil
from pathlib import Path
dotenv.load_dotenv()
node_parser = SentenceSplitter(
chunk_size=256, chunk_overlap=128, separator=" ", paragraph_separator="\n\n"
)
print("Loading documents...")
documents = SimpleDirectoryReader(
"examples/data",
recursive=True,
encoding="utf-8",
required_exts=[".pdf", ".txt", ".md"],
).load_data(show_progress=True)
print("Documents loaded.")
all_texts = []
for doc in documents:
nodes = node_parser.get_nodes_from_documents([doc])
for node in nodes:
all_texts.append(node.get_content())
async def main(args):
INDEX_DIR = Path(args.index_dir)
INDEX_PATH = str(INDEX_DIR / "pdf_documents.leann")
if not INDEX_DIR.exists():
node_parser = SentenceSplitter(
chunk_size=256, chunk_overlap=128, separator=" ", paragraph_separator="\n\n"
)
print(f"--- Index directory not found, building new index ---")
print("Loading documents...")
documents = SimpleDirectoryReader(
args.data_dir,
recursive=True,
encoding="utf-8",
required_exts=[".pdf", ".txt", ".md"],
).load_data(show_progress=True)
print("Documents loaded.")
all_texts = []
for doc in documents:
nodes = node_parser.get_nodes_from_documents([doc])
for node in nodes:
all_texts.append(node.get_content())
print("--- Index directory not found, building new index ---")
print("\n[PHASE 1] Building Leann index...")
print(f"\n[PHASE 1] Building Leann index...")
# Use HNSW backend for better macOS compatibility
builder = LeannBuilder(
@@ -58,19 +58,22 @@ async def main(args):
print(f"\n[PHASE 2] Starting Leann chat session...")
llm_config = {"type": "hf", "model": "Qwen/Qwen3-4B"}
# llm_config = {"type": "hf", "model": "Qwen/Qwen3-4B"}
llm_config = {"type": "ollama", "model": "qwen3:8b"}
llm_config = {"type": "openai", "model": "gpt-4o"}
chat = LeannChat(index_path=INDEX_PATH, llm_config=llm_config)
query = "Based on the paper, what are the main techniques LEANN explores to reduce the storage overhead and DLPM explore to achieve Fairness and Efiiciency trade-off?"
# query = (
# "什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面,任务令一般在什么城市颁发"
# )
query = args.query
print(f"You: {query}")
chat_response = chat.ask(query, top_k=20, recompute_embeddings=True, complexity=32)
print(f"Leann chat response: \033[36m{chat_response}\033[0m")
chat_response = chat.ask(
query, top_k=20, recompute_beighbor_embeddings=True, complexity=32
)
print(f"Leann: {chat_response}")
if __name__ == "__main__":
@@ -102,18 +105,6 @@ if __name__ == "__main__":
default="./test_doc_files",
help="Directory where the Leann index will be stored.",
)
parser.add_argument(
"--data-dir",
type=str,
default="examples/data",
help="Directory containing documents to index (PDF, TXT, MD files).",
)
parser.add_argument(
"--query",
type=str,
default="Based on the paper, what are the main techniques LEANN explores to reduce the storage overhead and DLPM explore to achieve Fairness and Efiiciency trade-off?",
help="The query to ask the Leann chat system.",
)
args = parser.parse_args()
asyncio.run(main(args))

View File

@@ -74,11 +74,11 @@ def create_leann_index_from_multiple_wechat_exports(
return None
print(
f"\nTotal loaded {len(all_documents)} chat documents from {len(export_dirs)} exports and starting to split them into chunks"
f"\nTotal loaded {len(all_documents)} chat documents from {len(export_dirs)} exports"
)
# Create text splitter with 256 chunk size
text_splitter = SentenceSplitter(chunk_size=192, chunk_overlap=64)
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=25)
# Convert Documents to text strings and chunk them
all_texts = []
@@ -86,11 +86,10 @@ def create_leann_index_from_multiple_wechat_exports(
# Split the document into chunks
nodes = text_splitter.get_nodes_from_documents([doc])
for node in nodes:
text = '[Contact] means the message is from: ' + doc.metadata["contact_name"] + '\n' + node.get_content()
all_texts.append(text)
all_texts.append(node.get_content())
print(
f"Finished splitting {len(all_documents)} documents into {len(all_texts)} text chunks"
f"Created {len(all_texts)} text chunks from {len(all_documents)} documents"
)
# Create LEANN index directory
@@ -225,7 +224,7 @@ async def query_leann_index(index_path: str, query: str):
query,
top_k=20,
recompute_beighbor_embeddings=True,
complexity=16,
complexity=64,
beam_width=1,
llm_config={
"type": "openai",
@@ -234,7 +233,7 @@ async def query_leann_index(index_path: str, query: str):
},
llm_kwargs={"temperature": 0.0, "max_tokens": 1000},
)
print(f"Leann chat response: \033[36m{chat_response}\033[0m")
print(f"Leann: {chat_response}")
async def main():
@@ -253,13 +252,13 @@ async def main():
parser.add_argument(
"--index-dir",
type=str,
default="./wechat_history_magic_test_11Debug_new",
default="./wechat_history_june19_test",
help="Directory to store the LEANN index (default: ./wechat_history_index_leann_test)",
)
parser.add_argument(
"--max-entries",
type=int,
default=50,
default=5000,
help="Maximum number of chat entries to process (default: 5000)",
)
parser.add_argument(

View File

@@ -1,8 +1,8 @@
# packages/leann-backend-diskann/CMakeLists.txt (simplified version)
# packages/leann-backend-diskann/CMakeLists.txt (最终简化版)
cmake_minimum_required(VERSION 3.20)
project(leann_backend_diskann_wrapper)
# Tell CMake to directly enter the DiskANN submodule and execute its own CMakeLists.txt
# DiskANN will handle everything itself, including compiling Python bindings
# 告诉 CMake 直接进入 DiskANN 子模块并执行它自己的 CMakeLists.txt
# DiskANN 会自己处理所有事情,包括编译 Python 绑定
add_subdirectory(src/third_party/DiskANN)

View File

@@ -1,12 +1,10 @@
import numpy as np
import os
import struct
import sys
from pathlib import Path
from typing import Dict, Any, List, Literal, Optional
from typing import Dict, Any, List, Literal
import contextlib
import logging
import pickle
from leann.searcher_base import BaseSearcher
from leann.registry import register_backend
@@ -16,46 +14,6 @@ from leann.interface import (
LeannBackendSearcherInterface,
)
logger = logging.getLogger(__name__)
@contextlib.contextmanager
def suppress_cpp_output_if_needed():
"""Suppress C++ stdout/stderr based on LEANN_LOG_LEVEL"""
log_level = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
# Only suppress if log level is WARNING or higher (ERROR, CRITICAL)
should_suppress = log_level in ["WARNING", "ERROR", "CRITICAL"]
if not should_suppress:
# Don't suppress, just yield
yield
return
# Save original file descriptors
stdout_fd = sys.stdout.fileno()
stderr_fd = sys.stderr.fileno()
# Save original stdout/stderr
stdout_dup = os.dup(stdout_fd)
stderr_dup = os.dup(stderr_fd)
try:
# Redirect to /dev/null
devnull = os.open(os.devnull, os.O_WRONLY)
os.dup2(devnull, stdout_fd)
os.dup2(devnull, stderr_fd)
os.close(devnull)
yield
finally:
# Restore original file descriptors
os.dup2(stdout_dup, stdout_fd)
os.dup2(stderr_dup, stderr_fd)
os.close(stdout_dup)
os.close(stderr_dup)
def _get_diskann_metrics():
from . import _diskannpy as diskannpy # type: ignore
@@ -107,20 +65,22 @@ class DiskannBuilder(LeannBackendBuilderInterface):
index_dir.mkdir(parents=True, exist_ok=True)
if data.dtype != np.float32:
logger.warning(f"Converting data to float32, shape: {data.shape}")
data = data.astype(np.float32)
data_filename = f"{index_prefix}_data.bin"
_write_vectors_to_bin(data, index_dir / data_filename)
label_map = {i: str_id for i, str_id in enumerate(ids)}
label_map_file = index_dir / "leann.labels.map"
with open(label_map_file, "wb") as f:
pickle.dump(label_map, f)
build_kwargs = {**self.build_params, **kwargs}
metric_enum = _get_diskann_metrics().get(
build_kwargs.get("distance_metric", "mips").lower()
)
if metric_enum is None:
raise ValueError(
f"Unsupported distance_metric '{build_kwargs.get('distance_metric', 'unknown')}'."
)
raise ValueError("Unsupported distance_metric.")
try:
from . import _diskannpy as diskannpy # type: ignore
@@ -142,40 +102,36 @@ class DiskannBuilder(LeannBackendBuilderInterface):
temp_data_file = index_dir / data_filename
if temp_data_file.exists():
os.remove(temp_data_file)
logger.debug(f"Cleaned up temporary data file: {temp_data_file}")
class DiskannSearcher(BaseSearcher):
def __init__(self, index_path: str, **kwargs):
super().__init__(
index_path,
backend_module_name="leann_backend_diskann.diskann_embedding_server",
backend_module_name="leann_backend_diskann.embedding_server",
**kwargs,
)
from . import _diskannpy as diskannpy # type: ignore
# Initialize DiskANN index with suppressed C++ output based on log level
with suppress_cpp_output_if_needed():
from . import _diskannpy as diskannpy # type: ignore
distance_metric = kwargs.get("distance_metric", "mips").lower()
metric_enum = _get_diskann_metrics().get(distance_metric)
if metric_enum is None:
raise ValueError(f"Unsupported distance_metric '{distance_metric}'.")
distance_metric = kwargs.get("distance_metric", "mips").lower()
metric_enum = _get_diskann_metrics().get(distance_metric)
if metric_enum is None:
raise ValueError(f"Unsupported distance_metric '{distance_metric}'.")
self.num_threads = kwargs.get("num_threads", 8)
self.zmq_port = kwargs.get("zmq_port", 6666)
self.num_threads = kwargs.get("num_threads", 8)
fake_zmq_port = 6666
full_index_prefix = str(self.index_dir / self.index_path.stem)
self._index = diskannpy.StaticDiskFloatIndex(
metric_enum,
full_index_prefix,
self.num_threads,
kwargs.get("num_nodes_to_cache", 0),
1,
fake_zmq_port, # Initial port, can be updated at runtime
"",
"",
)
full_index_prefix = str(self.index_dir / self.index_path.stem)
self._index = diskannpy.StaticDiskFloatIndex(
metric_enum,
full_index_prefix,
self.num_threads,
kwargs.get("num_nodes_to_cache", 0),
1,
self.zmq_port,
"",
"",
)
def search(
self,
@@ -186,7 +142,7 @@ class DiskannSearcher(BaseSearcher):
prune_ratio: float = 0.0,
recompute_embeddings: bool = False,
pruning_strategy: Literal["global", "local", "proportional"] = "global",
zmq_port: Optional[int] = None,
zmq_port: int = 5557,
batch_recompute: bool = False,
dedup_node_dis: bool = False,
**kwargs,
@@ -205,7 +161,7 @@ class DiskannSearcher(BaseSearcher):
- "global": Use global pruning strategy (default)
- "local": Use local pruning strategy
- "proportional": Not supported in DiskANN, falls back to global
zmq_port: ZMQ port for embedding server communication. Must be provided if recompute_embeddings is True.
zmq_port: ZMQ port for embedding server
batch_recompute: Whether to batch neighbor recomputation (DiskANN-specific)
dedup_node_dis: Whether to cache and reuse distance computations (DiskANN-specific)
**kwargs: Additional DiskANN-specific parameters (for legacy compatibility)
@@ -213,25 +169,22 @@ class DiskannSearcher(BaseSearcher):
Returns:
Dict with 'labels' (list of lists) and 'distances' (ndarray)
"""
# Handle zmq_port compatibility: DiskANN can now update port at runtime
if recompute_embeddings:
if zmq_port is None:
raise ValueError(
"zmq_port must be provided if recompute_embeddings is True"
)
current_port = self._index.get_zmq_port()
if zmq_port != current_port:
logger.debug(
f"Updating DiskANN zmq_port from {current_port} to {zmq_port}"
)
self._index.set_zmq_port(zmq_port)
# DiskANN doesn't support "proportional" strategy
if pruning_strategy == "proportional":
raise NotImplementedError(
"DiskANN backend does not support 'proportional' pruning strategy. Use 'global' or 'local' instead."
)
# Use recompute_embeddings parameter
use_recompute = recompute_embeddings
if use_recompute:
meta_file_path = self.index_dir / f"{self.index_path.name}.meta.json"
if not meta_file_path.exists():
raise RuntimeError(
f"FATAL: Recompute enabled but metadata file not found: {meta_file_path}"
)
self._ensure_server_running(str(meta_file_path), port=zmq_port, **kwargs)
if query.dtype != np.float32:
query = query.astype(np.float32)
@@ -241,26 +194,28 @@ class DiskannSearcher(BaseSearcher):
else: # "global"
use_global_pruning = True
# Perform search with suppressed C++ output based on log level
with suppress_cpp_output_if_needed():
labels, distances = self._index.batch_search(
query,
query.shape[0],
top_k,
complexity,
beam_width,
self.num_threads,
kwargs.get("USE_DEFERRED_FETCH", False),
kwargs.get("skip_search_reorder", False),
recompute_embeddings,
dedup_node_dis,
prune_ratio,
batch_recompute,
use_global_pruning,
)
labels, distances = self._index.batch_search(
query,
query.shape[0],
top_k,
complexity,
beam_width,
self.num_threads,
kwargs.get("USE_DEFERRED_FETCH", False),
kwargs.get("skip_search_reorder", False),
use_recompute,
dedup_node_dis,
prune_ratio,
batch_recompute,
use_global_pruning,
)
string_labels = [
[str(int_label) for int_label in batch_labels] for batch_labels in labels
[
self.label_map.get(int_label, f"unknown_{int_label}")
for int_label in batch_labels
]
for batch_labels in labels
]
return {"labels": string_labels, "distances": distances}

View File

@@ -1,283 +0,0 @@
"""
DiskANN-specific embedding server
"""
import argparse
import threading
import time
import os
import zmq
import numpy as np
import json
from pathlib import Path
from typing import Optional
import sys
import logging
# Set up logging based on environment variable
LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
logger = logging.getLogger(__name__)
# Force set logger level (don't rely on basicConfig in subprocess)
log_level = getattr(logging, LOG_LEVEL, logging.WARNING)
logger.setLevel(log_level)
# Ensure we have a handler if none exists
if not logger.handlers:
handler = logging.StreamHandler()
formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
handler.setFormatter(formatter)
logger.addHandler(handler)
logger.propagate = False
def create_diskann_embedding_server(
passages_file: Optional[str] = None,
zmq_port: int = 5555,
model_name: str = "sentence-transformers/all-mpnet-base-v2",
embedding_mode: str = "sentence-transformers",
):
"""
Create and start a ZMQ-based embedding server for DiskANN backend.
Uses ROUTER socket and protobuf communication as required by DiskANN C++ implementation.
"""
logger.info(f"Starting DiskANN server on port {zmq_port} with model {model_name}")
logger.info(f"Using embedding mode: {embedding_mode}")
# Add leann-core to path for unified embedding computation
current_dir = Path(__file__).parent
leann_core_path = current_dir.parent.parent / "leann-core" / "src"
sys.path.insert(0, str(leann_core_path))
try:
from leann.embedding_compute import compute_embeddings
from leann.api import PassageManager
logger.info("Successfully imported unified embedding computation module")
except ImportError as e:
logger.error(f"Failed to import embedding computation module: {e}")
return
finally:
sys.path.pop(0)
# Check port availability
import socket
def check_port(port):
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
return s.connect_ex(("localhost", port)) == 0
if check_port(zmq_port):
logger.error(f"Port {zmq_port} is already in use")
return
# Only support metadata file, fail fast for everything else
if not passages_file or not passages_file.endswith(".meta.json"):
raise ValueError("Only metadata files (.meta.json) are supported")
# Load metadata to get passage sources
with open(passages_file, "r") as f:
meta = json.load(f)
passages = PassageManager(meta["passage_sources"])
logger.info(
f"Loaded PassageManager with {len(passages.global_offset_map)} passages from metadata"
)
# Import protobuf after ensuring the path is correct
try:
from . import embedding_pb2
except ImportError as e:
logger.error(f"Failed to import protobuf module: {e}")
return
def zmq_server_thread():
"""ZMQ server thread using REP socket for universal compatibility"""
context = zmq.Context()
socket = context.socket(
zmq.REP
) # REP socket for both BaseSearcher and DiskANN C++ REQ clients
socket.bind(f"tcp://*:{zmq_port}")
logger.info(f"DiskANN ZMQ REP server listening on port {zmq_port}")
socket.setsockopt(zmq.RCVTIMEO, 300000)
socket.setsockopt(zmq.SNDTIMEO, 300000)
while True:
try:
# REP socket receives single-part messages
message = socket.recv()
# Check for empty messages - REP socket requires response to every request
if len(message) == 0:
logger.debug("Received empty message, sending empty response")
socket.send(b"") # REP socket must respond to every request
continue
logger.debug(f"Received ZMQ request of size {len(message)} bytes")
logger.debug(f"Message preview: {message[:50]}") # Show first 50 bytes
e2e_start = time.time()
# Try protobuf first (for DiskANN C++ node_ids requests - primary use case)
texts = []
node_ids = []
is_text_request = False
try:
req_proto = embedding_pb2.NodeEmbeddingRequest()
req_proto.ParseFromString(message)
node_ids = list(req_proto.node_ids)
if not node_ids:
raise RuntimeError(
f"PROTOBUF: Received empty node_ids! Message size: {len(message)}"
)
logger.info(
f"✅ PROTOBUF: Node ID request for {len(node_ids)} node embeddings: {node_ids[:10]}"
)
except Exception as protobuf_error:
logger.debug(f"Protobuf parsing failed: {protobuf_error}")
# Fallback to msgpack (for BaseSearcher direct text requests)
try:
import msgpack
request = msgpack.unpackb(message)
# For BaseSearcher compatibility, request is a list of texts directly
if isinstance(request, list) and all(
isinstance(item, str) for item in request
):
texts = request
is_text_request = True
logger.info(
f"✅ MSGPACK: Direct text request for {len(texts)} texts"
)
else:
raise ValueError("Not a valid msgpack text request")
except Exception as msgpack_error:
raise RuntimeError(
f"Both protobuf and msgpack parsing failed! Protobuf: {protobuf_error}, Msgpack: {msgpack_error}"
)
# Look up texts by node IDs (only if not direct text request)
if not is_text_request:
for nid in node_ids:
try:
passage_data = passages.get_passage(str(nid))
txt = passage_data["text"]
if not txt:
raise RuntimeError(
f"FATAL: Empty text for passage ID {nid}"
)
texts.append(txt)
except KeyError as e:
logger.error(f"Passage ID {nid} not found: {e}")
raise e
except Exception as e:
logger.error(f"Exception looking up passage ID {nid}: {e}")
raise
# Debug logging
logger.debug(f"Processing {len(texts)} texts")
logger.debug(
f"Text lengths: {[len(t) for t in texts[:5]]}"
) # Show first 5
# Process embeddings using unified computation
embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
logger.info(
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
)
# Prepare response based on request type
if is_text_request:
# For BaseSearcher compatibility: return msgpack format
import msgpack
response_data = msgpack.packb(embeddings.tolist())
else:
# For DiskANN C++ compatibility: return protobuf format
resp_proto = embedding_pb2.NodeEmbeddingResponse()
hidden_contiguous = np.ascontiguousarray(
embeddings, dtype=np.float32
)
# Serialize embeddings data
resp_proto.embeddings_data = hidden_contiguous.tobytes()
resp_proto.dimensions.append(hidden_contiguous.shape[0])
resp_proto.dimensions.append(hidden_contiguous.shape[1])
response_data = resp_proto.SerializeToString()
# Send response back to the client
socket.send(response_data)
e2e_end = time.time()
logger.info(f"⏱️ ZMQ E2E time: {e2e_end - e2e_start:.6f}s")
except zmq.Again:
logger.debug("ZMQ socket timeout, continuing to listen")
continue
except Exception as e:
logger.error(f"Error in ZMQ server loop: {e}")
import traceback
traceback.print_exc()
raise
zmq_thread = threading.Thread(target=zmq_server_thread, daemon=True)
zmq_thread.start()
logger.info(f"Started DiskANN ZMQ server thread on port {zmq_port}")
# Keep the main thread alive
try:
while True:
time.sleep(1)
except KeyboardInterrupt:
logger.info("DiskANN Server shutting down...")
return
if __name__ == "__main__":
import signal
import sys
def signal_handler(sig, frame):
logger.info(f"Received signal {sig}, shutting down gracefully...")
sys.exit(0)
# Register signal handlers for graceful shutdown
signal.signal(signal.SIGTERM, signal_handler)
signal.signal(signal.SIGINT, signal_handler)
parser = argparse.ArgumentParser(description="DiskANN Embedding service")
parser.add_argument("--zmq-port", type=int, default=5555, help="ZMQ port to run on")
parser.add_argument(
"--passages-file",
type=str,
help="Metadata JSON file containing passage sources",
)
parser.add_argument(
"--model-name",
type=str,
default="sentence-transformers/all-mpnet-base-v2",
help="Embedding model name",
)
parser.add_argument(
"--embedding-mode",
type=str,
default="sentence-transformers",
choices=["sentence-transformers", "openai", "mlx"],
help="Embedding backend mode",
)
args = parser.parse_args()
# Create and start the DiskANN embedding server
create_diskann_embedding_server(
passages_file=args.passages_file,
zmq_port=args.zmq_port,
model_name=args.model_name,
embedding_mode=args.embedding_mode,
)

View File

@@ -0,0 +1,741 @@
#!/usr/bin/env python3
"""
Embedding server for leann-backend-diskann - Fixed ZMQ REQ-REP pattern
"""
import pickle
import argparse
import time
import json
from typing import Dict, Any, Optional, Union
from transformers import AutoTokenizer, AutoModel
import os
from contextlib import contextmanager
import zmq
import numpy as np
import msgpack
from pathlib import Path
import logging
RED = "\033[91m"
# Set up logging based on environment variable
LOG_LEVEL = os.getenv('LEANN_LOG_LEVEL', 'INFO').upper()
logging.basicConfig(
level=getattr(logging, LOG_LEVEL, logging.INFO),
format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
RESET = "\033[0m"
# --- New Passage Loader from HNSW backend ---
class SimplePassageLoader:
"""
Simple passage loader that replaces config.py dependencies
"""
def __init__(self, passages_data: Optional[Dict[str, Any]] = None):
self.passages_data = passages_data or {}
self._meta_path = ''
def __getitem__(self, passage_id: Union[str, int]) -> Dict[str, str]:
"""Get passage by ID"""
str_id = str(passage_id)
if str_id in self.passages_data:
return {"text": self.passages_data[str_id]}
else:
# Return empty text for missing passages
return {"text": ""}
def __len__(self) -> int:
return len(self.passages_data)
def keys(self):
return self.passages_data.keys()
def load_passages_from_metadata(meta_file: str) -> SimplePassageLoader:
"""
Load passages using metadata file with PassageManager for lazy loading
"""
# Load metadata to get passage sources
with open(meta_file, 'r') as f:
meta = json.load(f)
# Import PassageManager dynamically to avoid circular imports
import sys
from pathlib import Path
# Find the leann package directory relative to this file
current_dir = Path(__file__).parent
leann_core_path = current_dir.parent.parent / "leann-core" / "src"
sys.path.insert(0, str(leann_core_path))
try:
from leann.api import PassageManager
passage_manager = PassageManager(meta['passage_sources'])
finally:
sys.path.pop(0)
# Load label map
passages_dir = Path(meta_file).parent
label_map_file = passages_dir / "leann.labels.map"
if label_map_file.exists():
import pickle
with open(label_map_file, 'rb') as f:
label_map = pickle.load(f)
print(f"Loaded label map with {len(label_map)} entries")
else:
raise FileNotFoundError(f"Label map file not found: {label_map_file}")
print(f"Initialized lazy passage loading for {len(label_map)} passages")
class LazyPassageLoader(SimplePassageLoader):
def __init__(self, passage_manager, label_map):
self.passage_manager = passage_manager
self.label_map = label_map
# Initialize parent with empty data
super().__init__({})
def __getitem__(self, passage_id: Union[str, int]) -> Dict[str, str]:
"""Get passage by ID with lazy loading"""
try:
int_id = int(passage_id)
if int_id in self.label_map:
string_id = self.label_map[int_id]
passage_data = self.passage_manager.get_passage(string_id)
if passage_data and passage_data.get("text"):
return {"text": passage_data["text"]}
else:
raise RuntimeError(f"FATAL: Empty text for ID {int_id} -> {string_id}")
else:
raise RuntimeError(f"FATAL: ID {int_id} not found in label_map")
except Exception as e:
raise RuntimeError(f"FATAL: Exception getting passage {passage_id}: {e}")
def __len__(self) -> int:
return len(self.label_map)
def keys(self):
return self.label_map.keys()
loader = LazyPassageLoader(passage_manager, label_map)
loader._meta_path = meta_file
return loader
def load_passages_from_file(passages_file: str) -> SimplePassageLoader:
"""
Load passages from a JSONL file with label map support
Expected format: {"id": "passage_id", "text": "passage_text", "metadata": {...}} (one per line)
"""
if not os.path.exists(passages_file):
raise FileNotFoundError(f"Passages file {passages_file} not found.")
if not passages_file.endswith('.jsonl'):
raise ValueError(f"Expected .jsonl file format, got: {passages_file}")
# Load label map (int -> string_id)
passages_dir = Path(passages_file).parent
label_map_file = passages_dir / "leann.labels.map"
label_map = {}
if label_map_file.exists():
with open(label_map_file, 'rb') as f:
label_map = pickle.load(f)
print(f"Loaded label map with {len(label_map)} entries")
else:
raise FileNotFoundError(f"Label map file not found: {label_map_file}")
# Load passages by string ID
string_id_passages = {}
with open(passages_file, 'r', encoding='utf-8') as f:
for line in f:
if line.strip():
passage = json.loads(line)
string_id_passages[passage['id']] = passage['text']
# Create int ID -> text mapping using label map
passages_data = {}
for int_id, string_id in label_map.items():
if string_id in string_id_passages:
passages_data[str(int_id)] = string_id_passages[string_id]
else:
print(f"WARNING: String ID {string_id} from label map not found in passages")
print(f"Loaded {len(passages_data)} passages from JSONL file {passages_file} using label map")
return SimplePassageLoader(passages_data)
def create_embedding_server_thread(
zmq_port=5555,
model_name="sentence-transformers/all-mpnet-base-v2",
max_batch_size=128,
passages_file: Optional[str] = None,
embedding_mode: str = "sentence-transformers",
enable_warmup: bool = False,
):
"""
Create and run embedding server in the current thread
This function is designed to be called in a separate thread
"""
logger.info(f"Initializing embedding server thread on port {zmq_port}")
try:
# Check if port is already occupied
import socket
def check_port(port):
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
return s.connect_ex(('localhost', port)) == 0
if check_port(zmq_port):
print(f"{RED}Port {zmq_port} is already in use{RESET}")
return
# Auto-detect mode based on model name if not explicitly set
if embedding_mode == "sentence-transformers" and model_name.startswith("text-embedding-"):
embedding_mode = "openai"
if embedding_mode == "mlx":
from leann.api import compute_embeddings_mlx
import torch
logger.info("Using MLX for embeddings")
# Set device to CPU for compatibility with DeviceTimer class
device = torch.device("cpu")
cuda_available = False
mps_available = False
elif embedding_mode == "openai":
from leann.api import compute_embeddings_openai
import torch
logger.info("Using OpenAI API for embeddings")
# Set device to CPU for compatibility with DeviceTimer class
device = torch.device("cpu")
cuda_available = False
mps_available = False
elif embedding_mode == "sentence-transformers":
# Initialize model
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
import torch
# Select device
mps_available = hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()
cuda_available = torch.cuda.is_available()
if cuda_available:
device = torch.device("cuda")
logger.info("Using CUDA device")
elif mps_available:
device = torch.device("mps")
logger.info("Using MPS device (Apple Silicon)")
else:
device = torch.device("cpu")
logger.info("Using CPU device")
# Load model
logger.info(f"Loading model {model_name}")
model = AutoModel.from_pretrained(model_name).to(device).eval()
# Optimize model
if cuda_available or mps_available:
try:
model = model.half()
model = torch.compile(model)
logger.info(f"Using FP16 precision with model: {model_name}")
except Exception as e:
print(f"WARNING: Model optimization failed: {e}")
else:
raise ValueError(f"Unsupported embedding mode: {embedding_mode}. Supported modes: sentence-transformers, mlx, openai")
# Load passages from file if provided
if passages_file and os.path.exists(passages_file):
# Check if it's a metadata file or a single passages file
if passages_file.endswith('.meta.json'):
passages = load_passages_from_metadata(passages_file)
else:
# Try to find metadata file in same directory
passages_dir = Path(passages_file).parent
meta_files = list(passages_dir.glob("*.meta.json"))
if meta_files:
print(f"Found metadata file: {meta_files[0]}, using lazy loading")
passages = load_passages_from_metadata(str(meta_files[0]))
else:
# Fallback to original single file loading (will cause warnings)
print("WARNING: No metadata file found, using single file loading (may cause missing passage warnings)")
passages = load_passages_from_file(passages_file)
else:
print("WARNING: No passages file provided or file not found. Using an empty passage loader.")
passages = SimplePassageLoader()
logger.info(f"Loaded {len(passages)} passages.")
def client_warmup(zmq_port):
"""Perform client-side warmup for DiskANN server"""
time.sleep(2)
print(f"Performing client-side warmup with model {model_name}...")
# Get actual passage IDs from the loaded passages
sample_ids = []
if hasattr(passages, 'keys') and len(passages) > 0:
available_ids = list(passages.keys())
# Take up to 5 actual IDs, but at least 1
sample_ids = available_ids[:min(5, len(available_ids))]
print(f"Using actual passage IDs for warmup: {sample_ids}")
else:
print("No passages available for warmup, skipping warmup...")
return
try:
context = zmq.Context()
socket = context.socket(zmq.REQ)
socket.connect(f"tcp://localhost:{zmq_port}")
socket.setsockopt(zmq.RCVTIMEO, 30000)
socket.setsockopt(zmq.SNDTIMEO, 30000)
try:
ids_to_send = [int(x) for x in sample_ids]
except ValueError:
print("Warning: Could not convert sample IDs to integers, skipping warmup")
return
if not ids_to_send:
print("Skipping warmup send.")
return
# Use protobuf format for warmup
from . import embedding_pb2
req_proto = embedding_pb2.NodeEmbeddingRequest()
req_proto.node_ids.extend(ids_to_send)
request_bytes = req_proto.SerializeToString()
for i in range(3):
print(f"Sending warmup request {i + 1}/3 via ZMQ (Protobuf)...")
socket.send(request_bytes)
response_bytes = socket.recv()
resp_proto = embedding_pb2.NodeEmbeddingResponse()
resp_proto.ParseFromString(response_bytes)
embeddings_count = resp_proto.dimensions[0] if resp_proto.dimensions else 0
print(f"Warmup request {i + 1}/3 successful, received {embeddings_count} embeddings")
time.sleep(0.1)
print("Client-side Protobuf ZMQ warmup complete")
socket.close()
context.term()
except Exception as e:
print(f"Error during Protobuf ZMQ warmup: {e}")
class DeviceTimer:
"""Device timer"""
def __init__(self, name="", device=device):
self.name = name
self.device = device
self.start_time = 0
self.end_time = 0
if embedding_mode == "sentence-transformers" and torch.cuda.is_available():
self.start_event = torch.cuda.Event(enable_timing=True)
self.end_event = torch.cuda.Event(enable_timing=True)
else:
self.start_event = None
self.end_event = None
@contextmanager
def timing(self):
self.start()
yield
self.end()
def start(self):
if embedding_mode == "sentence-transformers" and torch.cuda.is_available():
torch.cuda.synchronize()
self.start_event.record()
else:
if embedding_mode == "sentence-transformers" and self.device.type == "mps":
torch.mps.synchronize()
self.start_time = time.time()
def end(self):
if embedding_mode == "sentence-transformers" and torch.cuda.is_available():
self.end_event.record()
torch.cuda.synchronize()
else:
if embedding_mode == "sentence-transformers" and self.device.type == "mps":
torch.mps.synchronize()
self.end_time = time.time()
def elapsed_time(self):
if embedding_mode == "sentence-transformers" and torch.cuda.is_available():
return self.start_event.elapsed_time(self.end_event) / 1000.0
else:
return self.end_time - self.start_time
def print_elapsed(self):
elapsed = self.elapsed_time()
print(f"[{self.name}] Elapsed time: {elapsed:.3f}s")
def process_batch_pytorch(texts_batch, ids_batch, missing_ids):
"""Process text batch"""
if not texts_batch:
return np.array([])
# Filter out empty texts and their corresponding IDs
valid_texts = []
valid_ids = []
for i, text in enumerate(texts_batch):
if text.strip(): # Only include non-empty texts
valid_texts.append(text)
valid_ids.append(ids_batch[i])
if not valid_texts:
print("WARNING: No valid texts in batch")
return np.array([])
# Tokenize
token_timer = DeviceTimer("tokenization")
with token_timer.timing():
inputs = tokenizer(
valid_texts,
padding=True,
truncation=True,
max_length=512,
return_tensors="pt"
).to(device)
# Compute embeddings
embed_timer = DeviceTimer("embedding computation")
with embed_timer.timing():
with torch.no_grad():
outputs = model(**inputs)
hidden_states = outputs.last_hidden_state
# Mean pooling
attention_mask = inputs['attention_mask']
mask_expanded = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()
sum_embeddings = torch.sum(hidden_states * mask_expanded, 1)
sum_mask = torch.clamp(mask_expanded.sum(1), min=1e-9)
batch_embeddings = sum_embeddings / sum_mask
embed_timer.print_elapsed()
return batch_embeddings.cpu().numpy()
# ZMQ server main loop - modified to use REP socket
context = zmq.Context()
socket = context.socket(zmq.ROUTER) # Changed to REP socket
socket.bind(f"tcp://127.0.0.1:{zmq_port}")
print(f"INFO: ZMQ ROUTER server listening on port {zmq_port}")
# Set timeouts
socket.setsockopt(zmq.RCVTIMEO, 5000) # 5 second receive timeout
socket.setsockopt(zmq.SNDTIMEO, 300000) # 300 second send timeout
from . import embedding_pb2
print(f"INFO: Embedding server ready to serve requests")
# Start warmup thread if enabled
if enable_warmup and len(passages) > 0:
import threading
print(f"Warmup enabled: starting warmup thread")
warmup_thread = threading.Thread(target=client_warmup, args=(zmq_port,))
warmup_thread.daemon = True
warmup_thread.start()
else:
print(f"Warmup disabled or no passages available (enable_warmup={enable_warmup}, passages={len(passages)})")
while True:
try:
parts = socket.recv_multipart()
# --- Restore robust message format detection ---
# Must check parts length to avoid IndexError
if len(parts) >= 3:
identity = parts[0]
# empty = parts[1] # We usually don't care about the middle empty frame
message = parts[2]
elif len(parts) == 2:
# Can also handle cases without empty frame
identity = parts[0]
message = parts[1]
else:
# If received message format is wrong, print warning and ignore it instead of crashing
print(f"WARNING: Received unexpected message format with {len(parts)} parts. Ignoring.")
continue
print(f"INFO: Received ZMQ request from client {identity.hex()[:8]}, size {len(message)} bytes")
# Handle control messages (MessagePack format)
try:
request_payload = msgpack.unpackb(message)
if isinstance(request_payload, list) and len(request_payload) >= 1:
if request_payload[0] == "__QUERY_META_PATH__":
# Return the current meta path being used by the server
current_meta_path = getattr(passages, '_meta_path', '') if hasattr(passages, '_meta_path') else ''
response = [current_meta_path]
socket.send_multipart([identity, b'', msgpack.packb(response)])
continue
elif request_payload[0] == "__UPDATE_META_PATH__" and len(request_payload) >= 2:
# Update the server's meta path and reload passages
new_meta_path = request_payload[1]
try:
print(f"INFO: Updating server meta path to: {new_meta_path}")
# Reload passages from the new meta file
passages = load_passages_from_metadata(new_meta_path)
# Store the meta path for future queries
passages._meta_path = new_meta_path
response = ["SUCCESS"]
print(f"INFO: Successfully updated meta path and reloaded {len(passages)} passages")
except Exception as e:
print(f"ERROR: Failed to update meta path: {e}")
response = ["FAILED", str(e)]
socket.send_multipart([identity, b'', msgpack.packb(response)])
continue
elif request_payload[0] == "__QUERY_MODEL__":
# Return the current model being used by the server
response = [model_name]
socket.send_multipart([identity, b'', msgpack.packb(response)])
continue
elif request_payload[0] == "__UPDATE_MODEL__" and len(request_payload) >= 2:
# Update the server's embedding model
new_model_name = request_payload[1]
try:
print(f"INFO: Updating server model from {model_name} to: {new_model_name}")
# Clean up old model to free memory
if not use_mlx:
print("INFO: Releasing old model from memory...")
old_model = model
old_tokenizer = tokenizer
# Load new tokenizer first
print(f"Loading new tokenizer for {new_model_name}...")
tokenizer = AutoTokenizer.from_pretrained(new_model_name, use_fast=True)
# Load new model
print(f"Loading new model {new_model_name}...")
model = AutoModel.from_pretrained(new_model_name).to(device).eval()
# Optimize new model
if cuda_available or mps_available:
try:
model = model.half()
model = torch.compile(model)
print(f"INFO: Using FP16 precision with model: {new_model_name}")
except Exception as e:
print(f"WARNING: Model optimization failed: {e}")
# Now safely delete old model after new one is loaded
del old_model
del old_tokenizer
# Clear GPU cache if available
if device.type == "cuda":
torch.cuda.empty_cache()
print("INFO: Cleared CUDA cache")
elif device.type == "mps":
torch.mps.empty_cache()
print("INFO: Cleared MPS cache")
# Force garbage collection
import gc
gc.collect()
print("INFO: Memory cleanup completed")
# Update model name
model_name = new_model_name
response = ["SUCCESS"]
print(f"INFO: Successfully updated model to: {new_model_name}")
except Exception as e:
print(f"ERROR: Failed to update model: {e}")
response = ["FAILED", str(e)]
socket.send_multipart([identity, b'', msgpack.packb(response)])
continue
except:
# Not a control message, continue with normal protobuf processing
pass
e2e_start = time.time()
lookup_timer = DeviceTimer("text lookup")
# Parse request
req_proto = embedding_pb2.NodeEmbeddingRequest()
req_proto.ParseFromString(message)
node_ids = req_proto.node_ids
print(f"INFO: Request for {len(node_ids)} node embeddings: {list(node_ids)}")
# Add debug information
if len(node_ids) > 0:
print(f"DEBUG: Node ID range: {min(node_ids)} to {max(node_ids)}")
# Look up texts
texts = []
missing_ids = []
with lookup_timer.timing():
for nid in node_ids:
txtinfo = passages[nid]
txt = txtinfo["text"]
if txt:
texts.append(txt)
else:
# If text is empty, we still need a placeholder for batch processing,
# but record its ID as missing
texts.append("")
missing_ids.append(nid)
lookup_timer.print_elapsed()
if missing_ids:
print(f"WARNING: Missing passages for IDs: {missing_ids}")
# Process batch
total_size = len(texts)
print(f"INFO: Total batch size: {total_size}, max_batch_size: {max_batch_size}")
all_embeddings = []
if total_size > max_batch_size:
print(f"INFO: Splitting batch of size {total_size} into chunks of {max_batch_size}")
for i in range(0, total_size, max_batch_size):
end_idx = min(i + max_batch_size, total_size)
print(f"INFO: Processing chunk {i//max_batch_size + 1}/{(total_size + max_batch_size - 1)//max_batch_size}: items {i} to {end_idx-1}")
chunk_texts = texts[i:end_idx]
chunk_ids = node_ids[i:end_idx]
if embedding_mode == "mlx":
embeddings_chunk = compute_embeddings_mlx(chunk_texts, model_name, batch_size=16)
elif embedding_mode == "openai":
embeddings_chunk = compute_embeddings_openai(chunk_texts, model_name)
else: # sentence-transformers
embeddings_chunk = process_batch_pytorch(chunk_texts, chunk_ids, missing_ids)
all_embeddings.append(embeddings_chunk)
if embedding_mode == "sentence-transformers":
if cuda_available:
torch.cuda.empty_cache()
elif device.type == "mps":
torch.mps.empty_cache()
hidden = np.vstack(all_embeddings)
print(f"INFO: Combined embeddings shape: {hidden.shape}")
else:
if embedding_mode == "mlx":
hidden = compute_embeddings_mlx(texts, model_name, batch_size=16)
elif embedding_mode == "openai":
hidden = compute_embeddings_openai(texts, model_name)
else: # sentence-transformers
hidden = process_batch_pytorch(texts, node_ids, missing_ids)
# Serialize response
ser_start = time.time()
resp_proto = embedding_pb2.NodeEmbeddingResponse()
hidden_contiguous = np.ascontiguousarray(hidden, dtype=np.float32)
resp_proto.embeddings_data = hidden_contiguous.tobytes()
resp_proto.dimensions.append(hidden_contiguous.shape[0])
resp_proto.dimensions.append(hidden_contiguous.shape[1])
resp_proto.missing_ids.extend(missing_ids)
response_data = resp_proto.SerializeToString()
# REP socket sends a single response
socket.send_multipart([identity, b'', response_data])
ser_end = time.time()
print(f"INFO: Serialize time: {ser_end - ser_start:.6f} seconds")
if embedding_mode == "sentence-transformers":
if device.type == "cuda":
torch.cuda.synchronize()
elif device.type == "mps":
torch.mps.synchronize()
e2e_end = time.time()
print(f"INFO: ZMQ E2E time: {e2e_end - e2e_start:.6f} seconds")
except zmq.Again:
print("INFO: ZMQ socket timeout, continuing to listen")
continue
except Exception as e:
print(f"ERROR: Error in ZMQ server: {e}")
try:
# Send empty response to maintain REQ-REP state
empty_resp = embedding_pb2.NodeEmbeddingResponse()
socket.send(empty_resp.SerializeToString())
except:
# If sending fails, recreate socket
socket.close()
socket = context.socket(zmq.REP)
socket.bind(f"tcp://127.0.0.1:{zmq_port}")
socket.setsockopt(zmq.RCVTIMEO, 5000)
socket.setsockopt(zmq.SNDTIMEO, 300000)
print("INFO: ZMQ socket recreated after error")
except Exception as e:
print(f"ERROR: Failed to start embedding server: {e}")
raise
def create_embedding_server(
domain="demo",
load_passages=True,
load_embeddings=False,
use_fp16=True,
use_int8=False,
use_cuda_graphs=False,
zmq_port=5555,
max_batch_size=128,
lazy_load_passages=False,
model_name="sentence-transformers/all-mpnet-base-v2",
passages_file: Optional[str] = None,
embedding_mode: str = "sentence-transformers",
enable_warmup: bool = False,
):
"""
原有的 create_embedding_server 函数保持不变
这个是阻塞版本,用于直接运行
"""
create_embedding_server_thread(zmq_port, model_name, max_batch_size, passages_file, embedding_mode, enable_warmup)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Embedding service")
parser.add_argument("--zmq-port", type=int, default=5555, help="ZMQ port to run on")
parser.add_argument("--domain", type=str, default="demo", help="Domain name")
parser.add_argument("--passages-file", type=str, help="JSON file containing passage ID to text mapping")
parser.add_argument("--load-passages", action="store_true", default=True)
parser.add_argument("--load-embeddings", action="store_true", default=False)
parser.add_argument("--use-fp16", action="store_true", default=False)
parser.add_argument("--use-int8", action="store_true", default=False)
parser.add_argument("--use-cuda-graphs", action="store_true", default=False)
parser.add_argument("--max-batch-size", type=int, default=128, help="Maximum batch size before splitting")
parser.add_argument("--lazy-load-passages", action="store_true", default=True)
parser.add_argument("--model-name", type=str, default="sentence-transformers/all-mpnet-base-v2",
help="Embedding model name")
parser.add_argument("--embedding-mode", type=str, default="sentence-transformers",
choices=["sentence-transformers", "mlx", "openai"],
help="Embedding backend mode")
parser.add_argument("--use-mlx", action="store_true", default=False, help="Use MLX backend for embeddings (deprecated: use --embedding-mode mlx)")
parser.add_argument("--disable-warmup", action="store_true", default=False, help="Disable warmup requests on server start")
args = parser.parse_args()
# Handle backward compatibility with use_mlx
embedding_mode = args.embedding_mode
if args.use_mlx:
embedding_mode = "mlx"
create_embedding_server(
domain=args.domain,
load_passages=args.load_passages,
load_embeddings=args.load_embeddings,
use_fp16=args.use_fp16,
use_int8=args.use_int8,
use_cuda_graphs=args.use_cuda_graphs,
zmq_port=args.zmq_port,
max_batch_size=args.max_batch_size,
lazy_load_passages=args.lazy_load_passages,
model_name=args.model_name,
passages_file=args.passages_file,
embedding_mode=embedding_mode,
enable_warmup=not args.disable_warmup,
)

View File

@@ -4,16 +4,13 @@ build-backend = "scikit_build_core.build"
[project]
name = "leann-backend-diskann"
version = "0.1.13"
dependencies = ["leann-core==0.1.13", "numpy", "protobuf>=3.19.0"]
version = "0.1.0"
dependencies = ["leann-core==0.1.0", "numpy"]
[tool.scikit-build]
# Key: simplified CMake path
# 关键:简化的 CMake 路径
cmake.source-dir = "third_party/DiskANN"
# Key: Python package in root directory, paths match exactly
# 关键:Python 包在根目录,路径完全匹配
wheel.packages = ["leann_backend_diskann"]
# Use default redirect mode
editable.mode = "redirect"
cmake.build-type = "Release"
build.verbose = true
build.tool-args = ["-j8"]
# 使用默认的 redirect 模式
editable.mode = "redirect"

View File

@@ -1,7 +1,6 @@
# 最终简化版
cmake_minimum_required(VERSION 3.24)
project(leann_backend_hnsw_wrapper)
set(CMAKE_C_COMPILER_WORKS 1)
set(CMAKE_CXX_COMPILER_WORKS 1)
# Set OpenMP path for macOS
if(APPLE)
@@ -12,9 +11,15 @@ if(APPLE)
set(OpenMP_omp_LIBRARY "/opt/homebrew/opt/libomp/lib/libomp.dylib")
endif()
# Use system ZeroMQ instead of building from source
find_package(PkgConfig REQUIRED)
pkg_check_modules(ZMQ REQUIRED libzmq)
# Build ZeroMQ from source
set(ZMQ_BUILD_TESTS OFF CACHE BOOL "" FORCE)
set(ENABLE_DRAFTS OFF CACHE BOOL "" FORCE)
set(ENABLE_PRECOMPILED OFF CACHE BOOL "" FORCE)
set(WITH_PERF_TOOL OFF CACHE BOOL "" FORCE)
set(WITH_DOCS OFF CACHE BOOL "" FORCE)
set(BUILD_SHARED OFF CACHE BOOL "" FORCE)
set(BUILD_STATIC ON CACHE BOOL "" FORCE)
add_subdirectory(third_party/libzmq)
# Add cppzmq headers
include_directories(third_party/cppzmq)
@@ -24,7 +29,6 @@ set(MSGPACK_USE_BOOST OFF CACHE BOOL "" FORCE)
add_compile_definitions(MSGPACK_NO_BOOST)
include_directories(third_party/msgpack-c/include)
# Faiss configuration - streamlined build
set(FAISS_ENABLE_PYTHON ON CACHE BOOL "" FORCE)
set(FAISS_ENABLE_GPU OFF CACHE BOOL "" FORCE)
set(FAISS_ENABLE_EXTRAS OFF CACHE BOOL "" FORCE)
@@ -32,24 +36,4 @@ set(BUILD_TESTING OFF CACHE BOOL "" FORCE)
set(FAISS_ENABLE_C_API OFF CACHE BOOL "" FORCE)
set(FAISS_OPT_LEVEL "generic" CACHE STRING "" FORCE)
# Disable additional SIMD versions to speed up compilation
set(FAISS_ENABLE_AVX2 OFF CACHE BOOL "" FORCE)
set(FAISS_ENABLE_AVX512 OFF CACHE BOOL "" FORCE)
# Additional optimization options from INSTALL.md
set(CMAKE_BUILD_TYPE "Release" CACHE STRING "" FORCE)
set(BUILD_SHARED_LIBS OFF CACHE BOOL "" FORCE) # Static library is faster to build
# Avoid building demos and benchmarks
set(BUILD_DEMOS OFF CACHE BOOL "" FORCE)
set(BUILD_BENCHS OFF CACHE BOOL "" FORCE)
# NEW: Tell Faiss to only build the generic version
set(FAISS_BUILD_GENERIC ON CACHE BOOL "" FORCE)
set(FAISS_BUILD_AVX2 OFF CACHE BOOL "" FORCE)
set(FAISS_BUILD_AVX512 OFF CACHE BOOL "" FORCE)
# IMPORTANT: Disable building AVX versions to speed up compilation
set(FAISS_BUILD_AVX_VERSIONS OFF CACHE BOOL "" FORCE)
add_subdirectory(third_party/faiss)

View File

@@ -1,9 +1,10 @@
import numpy as np
import os
from pathlib import Path
from typing import Dict, Any, List, Literal, Optional
from typing import Dict, Any, List, Literal
import pickle
import shutil
import logging
import time
from leann.searcher_base import BaseSearcher
from .convert_to_csr import convert_hnsw_graph_to_csr
@@ -15,8 +16,6 @@ from leann.interface import (
LeannBackendSearcherInterface,
)
logger = logging.getLogger(__name__)
def get_metric_map():
from . import faiss # type: ignore
@@ -48,10 +47,6 @@ class HNSWBuilder(LeannBackendBuilderInterface):
self.efConstruction = self.build_params.setdefault("efConstruction", 200)
self.distance_metric = self.build_params.setdefault("distance_metric", "mips")
self.dimensions = self.build_params.get("dimensions")
if not self.is_recompute:
if self.is_compact:
# TODO: support this case @andy
raise ValueError("is_recompute is False, but is_compact is True. This is not compatible now. change is compact to False and you can use the original HNSW index.")
def build(self, data: np.ndarray, ids: List[str], index_path: str, **kwargs):
from . import faiss # type: ignore
@@ -62,9 +57,13 @@ class HNSWBuilder(LeannBackendBuilderInterface):
index_dir.mkdir(parents=True, exist_ok=True)
if data.dtype != np.float32:
logger.warning(f"Converting data to float32, shape: {data.shape}")
data = data.astype(np.float32)
label_map = {i: str_id for i, str_id in enumerate(ids)}
label_map_file = index_dir / "leann.labels.map"
with open(label_map_file, "wb") as f:
pickle.dump(label_map, f)
metric_enum = get_metric_map().get(self.distance_metric.lower())
if metric_enum is None:
raise ValueError(f"Unsupported distance_metric '{self.distance_metric}'.")
@@ -86,7 +85,7 @@ class HNSWBuilder(LeannBackendBuilderInterface):
def _convert_to_csr(self, index_file: Path):
"""Convert built index to CSR format"""
mode_str = "CSR-pruned" if self.is_recompute else "CSR-standard"
logger.info(f"INFO: Converting HNSW index to {mode_str} format...")
print(f"INFO: Converting HNSW index to {mode_str} format...")
csr_temp_file = index_file.with_suffix(".csr.tmp")
@@ -95,11 +94,11 @@ class HNSWBuilder(LeannBackendBuilderInterface):
)
if success:
logger.info("✅ CSR conversion successful.")
# index_file_old = index_file.with_suffix(".old")
# shutil.move(str(index_file), str(index_file_old))
print("✅ CSR conversion successful.")
index_file_old = index_file.with_suffix(".old")
shutil.move(str(index_file), str(index_file_old))
shutil.move(str(csr_temp_file), str(index_file))
logger.info(
print(
f"INFO: Replaced original index with {mode_str} version at '{index_file}'"
)
else:
@@ -136,22 +135,31 @@ class HNSWSearcher(BaseSearcher):
hnsw_config = faiss.HNSWIndexConfig()
hnsw_config.is_compact = self.is_compact
hnsw_config.is_recompute = (
self.is_pruned
) # In C++ code, it's called is_recompute, but it's only for loading IIUC.
hnsw_config.is_recompute = self.is_pruned or kwargs.get("is_recompute", False)
if self.is_pruned and not hnsw_config.is_recompute:
raise RuntimeError("Index is pruned but recompute is disabled.")
self._index = faiss.read_index(str(index_file), faiss.IO_FLAG_MMAP, hnsw_config)
# Load label mapping
label_map_file = self.index_dir / "leann.labels.map"
if not label_map_file.exists():
raise FileNotFoundError(f"Label map file not found at {label_map_file}")
with open(label_map_file, "rb") as f:
self.label_map = pickle.load(f)
def search(
self,
query: np.ndarray,
top_k: int,
zmq_port: Optional[int] = None,
complexity: int = 64,
beam_width: int = 1,
prune_ratio: float = 0.0,
recompute_embeddings: bool = True,
recompute_embeddings: bool = False,
pruning_strategy: Literal["global", "local", "proportional"] = "global",
zmq_port: int = 5557,
batch_size: int = 0,
**kwargs,
) -> Dict[str, Any]:
@@ -169,7 +177,7 @@ class HNSWSearcher(BaseSearcher):
- "global": Use global PQ queue size for selection (default)
- "local": Local pruning, sort and select best candidates
- "proportional": Base selection on new neighbor count ratio
zmq_port: ZMQ port for embedding server communication. Must be provided if recompute_embeddings is True.
zmq_port: ZMQ port for embedding server
batch_size: Neighbor processing batch size, 0=disabled (HNSW-specific)
**kwargs: Additional HNSW-specific parameters (for legacy compatibility)
@@ -178,14 +186,15 @@ class HNSWSearcher(BaseSearcher):
"""
from . import faiss # type: ignore
if not recompute_embeddings:
if self.is_pruned:
raise RuntimeError("Recompute is required for pruned index.")
if recompute_embeddings:
if zmq_port is None:
raise ValueError(
"zmq_port must be provided if recompute_embeddings is True"
# Use recompute_embeddings parameter
use_recompute = recompute_embeddings or self.is_pruned
if use_recompute:
meta_file_path = self.index_dir / f"{self.index_path.name}.meta.json"
if not meta_file_path.exists():
raise RuntimeError(
f"FATAL: Recompute enabled but metadata file not found: {meta_file_path}"
)
self._ensure_server_running(str(meta_file_path), port=zmq_port, **kwargs)
if query.dtype != np.float32:
query = query.astype(np.float32)
@@ -193,10 +202,7 @@ class HNSWSearcher(BaseSearcher):
faiss.normalize_L2(query)
params = faiss.SearchParametersHNSW()
if zmq_port is not None:
params.zmq_port = (
zmq_port # C++ code won't use this if recompute_embeddings is False
)
params.zmq_port = zmq_port
params.efSearch = complexity
params.beam_size = beam_width
@@ -233,7 +239,11 @@ class HNSWSearcher(BaseSearcher):
)
string_labels = [
[str(int_label) for int_label in batch_labels] for batch_labels in labels
[
self.label_map.get(int_label, f"unknown_{int_label}")
for int_label in batch_labels
]
for batch_labels in labels
]
return {"labels": string_labels, "distances": distances}

View File

File diff suppressed because it is too large Load Diff

View File

@@ -6,22 +6,12 @@ build-backend = "scikit_build_core.build"
[project]
name = "leann-backend-hnsw"
version = "0.1.13"
version = "0.1.0"
description = "Custom-built HNSW (Faiss) backend for the Leann toolkit."
dependencies = [
"leann-core==0.1.13",
"numpy",
"pyzmq>=23.0.0",
"msgpack>=1.0.0",
]
dependencies = ["leann-core==0.1.0", "numpy"]
[tool.scikit-build]
wheel.packages = ["leann_backend_hnsw"]
editable.mode = "redirect"
cmake.build-type = "Release"
build.verbose = true
build.tool-args = ["-j8"]
# CMake definitions to optimize compilation
[tool.scikit-build.cmake.define]
CMAKE_BUILD_PARALLEL_LEVEL = "8"
cmake.build-type = "Debug"
build.verbose = true

View File

@@ -4,44 +4,16 @@ build-backend = "setuptools.build_meta"
[project]
name = "leann-core"
version = "0.1.13"
description = "Core API and plugin system for LEANN"
version = "0.1.0"
description = "Core API and plugin system for Leann."
readme = "README.md"
requires-python = ">=3.9"
license = { text = "MIT" }
# All required dependencies included
dependencies = [
"numpy>=1.20.0",
"tqdm>=4.60.0",
"psutil>=5.8.0",
"pyzmq>=23.0.0",
"msgpack>=1.0.0",
"torch>=2.0.0",
"sentence-transformers>=2.2.0",
"llama-index-core>=0.12.0",
"python-dotenv>=1.0.0",
"openai>=1.0.0",
"huggingface-hub>=0.20.0",
"transformers>=4.30.0",
"requests>=2.25.0",
"accelerate>=0.20.0",
"PyPDF2>=3.0.0",
"pymupdf>=1.23.0",
"pdfplumber>=0.10.0",
"mlx>=0.26.3; sys_platform == 'darwin'",
"mlx-lm>=0.26.0; sys_platform == 'darwin'",
"tqdm>=4.60.0"
]
[project.optional-dependencies]
colab = [
"torch>=2.0.0,<3.0.0", # 限制torch版本避免冲突
"transformers>=4.30.0,<5.0.0", # 限制transformers版本
"accelerate>=0.20.0,<1.0.0", # 限制accelerate版本
]
[project.scripts]
leann = "leann.cli:main"
[tool.setuptools.packages.find]
where = ["src"]

View File

@@ -5,18 +5,16 @@ with the correct, original embedding logic from the user's reference code.
import json
import pickle
from leann.interface import LeannBackendSearcherInterface
import numpy as np
import time
from pathlib import Path
from typing import List, Dict, Any, Optional, Literal
from dataclasses import dataclass, field
import uuid
import torch
from .registry import BACKEND_REGISTRY
from .interface import LeannBackendFactoryInterface
from .chat import get_llm
import logging
logger = logging.getLogger(__name__)
def compute_embeddings(
@@ -24,8 +22,7 @@ def compute_embeddings(
model_name: str,
mode: str = "sentence-transformers",
use_server: bool = True,
port: Optional[int] = None,
is_build=False,
use_mlx: bool = False # Backward compatibility: if True, override mode to 'mlx',
) -> np.ndarray:
"""
Computes embeddings using different backends.
@@ -42,63 +39,251 @@ def compute_embeddings(
Returns:
numpy array of embeddings
"""
if use_server:
# Use embedding server (for search/query)
if port is None:
raise ValueError("port is required when use_server is True")
return compute_embeddings_via_server(chunks, model_name, port=port)
# Override mode for backward compatibility
if use_mlx:
mode = "mlx"
# Auto-detect mode based on model name if not explicitly set
if mode == "sentence-transformers" and model_name.startswith("text-embedding-"):
mode = "openai"
if mode == "mlx":
return compute_embeddings_mlx(chunks, model_name, batch_size=16)
elif mode == "openai":
return compute_embeddings_openai(chunks, model_name)
elif mode == "sentence-transformers":
return compute_embeddings_sentence_transformers(
chunks, model_name, use_server=use_server
)
else:
# Use direct computation (for build_index)
from .embedding_compute import (
compute_embeddings as compute_embeddings_direct,
)
return compute_embeddings_direct(
chunks,
model_name,
mode=mode,
is_build=is_build,
raise ValueError(
f"Unsupported embedding mode: {mode}. Supported modes: sentence-transformers, mlx, openai"
)
def compute_embeddings_via_server(
chunks: List[str], model_name: str, port: int
def compute_embeddings_sentence_transformers(
chunks: List[str], model_name: str, use_server: bool = True
) -> np.ndarray:
"""Computes embeddings using sentence-transformers.
Args:
chunks: List of text chunks to embed
model_name: Name of the sentence transformer model
use_server: If True, use embedding server (good for search). If False, use direct computation (good for build).
"""
logger.info(
f"Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}' (via embedding server)..."
if not use_server:
print(
f"INFO: Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}' (direct)..."
)
return _compute_embeddings_sentence_transformers_direct(chunks, model_name)
print(
f"INFO: Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}' (via embedding server)..."
)
import zmq
import msgpack
import numpy as np
# Connect to embedding server
context = zmq.Context()
socket = context.socket(zmq.REQ)
socket.connect(f"tcp://localhost:{port}")
# Use embedding server for sentence-transformers too
# This avoids loading the model twice (once in API, once in server)
try:
# Import ZMQ client functionality and server manager
import zmq
import msgpack
import numpy as np
from .embedding_server_manager import EmbeddingServerManager
# Send chunks to server for embedding computation
request = chunks
socket.send(msgpack.packb(request))
# Ensure embedding server is running
port = 5557
server_manager = EmbeddingServerManager(
backend_module_name="leann_backend_hnsw.hnsw_embedding_server"
)
# Receive embeddings from server
response = socket.recv()
embeddings_list = msgpack.unpackb(response)
server_started = server_manager.start_server(
port=port,
model_name=model_name,
embedding_mode="sentence-transformers",
enable_warmup=False,
)
# Convert back to numpy array
embeddings = np.array(embeddings_list, dtype=np.float32)
if not server_started:
raise RuntimeError(f"Failed to start embedding server on port {port}")
socket.close()
context.term()
# Connect to embedding server
context = zmq.Context()
socket = context.socket(zmq.REQ)
socket.connect(f"tcp://localhost:{port}")
# Send chunks to server for embedding computation
request = chunks
socket.send(msgpack.packb(request))
# Receive embeddings from server
response = socket.recv()
embeddings_list = msgpack.unpackb(response)
# Convert back to numpy array
embeddings = np.array(embeddings_list, dtype=np.float32)
socket.close()
context.term()
return embeddings
except Exception as e:
# Fallback to direct sentence-transformers if server connection fails
print(
f"Warning: Failed to connect to embedding server, falling back to direct computation: {e}"
)
return _compute_embeddings_sentence_transformers_direct(chunks, model_name)
def _compute_embeddings_sentence_transformers_direct(
chunks: List[str], model_name: str
) -> np.ndarray:
"""Direct sentence-transformers computation (fallback)."""
try:
from sentence_transformers import SentenceTransformer
except ImportError as e:
raise RuntimeError(
"sentence-transformers not available. Install with: uv pip install sentence-transformers"
) from e
# Load model using sentence-transformers
model = SentenceTransformer(model_name)
model = model.half()
print(
f"INFO: Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}' (direct)..."
)
# use acclerater GPU or MAC GPU
if torch.cuda.is_available():
model = model.to("cuda")
elif torch.backends.mps.is_available():
model = model.to("mps")
# Generate embeddings
# give use an warning if OOM here means we need to turn down the batch size
embeddings = model.encode(
chunks, convert_to_numpy=True, show_progress_bar=True, batch_size=16
)
return embeddings
def compute_embeddings_openai(chunks: List[str], model_name: str) -> np.ndarray:
"""Computes embeddings using OpenAI API."""
try:
import openai
import os
except ImportError as e:
raise RuntimeError(
"openai not available. Install with: uv pip install openai"
) from e
# Get API key from environment
api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
raise RuntimeError("OPENAI_API_KEY environment variable not set")
client = openai.OpenAI(api_key=api_key)
print(
f"INFO: Computing embeddings for {len(chunks)} chunks using OpenAI model '{model_name}'..."
)
# OpenAI has a limit on batch size and input length
max_batch_size = 100 # Conservative batch size
all_embeddings = []
try:
from tqdm import tqdm
total_batches = (len(chunks) + max_batch_size - 1) // max_batch_size
batch_range = range(0, len(chunks), max_batch_size)
batch_iterator = tqdm(batch_range, desc="Computing embeddings", unit="batch", total=total_batches)
except ImportError:
# Fallback without progress bar
batch_iterator = range(0, len(chunks), max_batch_size)
for i in batch_iterator:
batch_chunks = chunks[i:i + max_batch_size]
try:
response = client.embeddings.create(model=model_name, input=batch_chunks)
batch_embeddings = [embedding.embedding for embedding in response.data]
all_embeddings.extend(batch_embeddings)
except Exception as e:
print(f"ERROR: Failed to get embeddings for batch starting at {i}: {e}")
raise
embeddings = np.array(all_embeddings, dtype=np.float32)
print(
f"INFO: Generated {len(embeddings)} embeddings with dimension {embeddings.shape[1]}"
)
return embeddings
def compute_embeddings_mlx(chunks: List[str], model_name: str, batch_size: int = 16) -> np.ndarray:
"""Computes embeddings using an MLX model."""
try:
import mlx.core as mx
from mlx_lm.utils import load
from tqdm import tqdm
except ImportError as e:
raise RuntimeError(
"MLX or related libraries not available. Install with: uv pip install mlx mlx-lm"
) from e
print(
f"INFO: Computing embeddings for {len(chunks)} chunks using MLX model '{model_name}' with batch_size={batch_size}..."
)
# Load model and tokenizer
model, tokenizer = load(model_name)
# Process chunks in batches with progress bar
all_embeddings = []
try:
from tqdm import tqdm
batch_iterator = tqdm(range(0, len(chunks), batch_size), desc="Computing embeddings", unit="batch")
except ImportError:
batch_iterator = range(0, len(chunks), batch_size)
for i in batch_iterator:
batch_chunks = chunks[i:i + batch_size]
# Tokenize all chunks in the batch
batch_token_ids = []
for chunk in batch_chunks:
token_ids = tokenizer.encode(chunk) # type: ignore
batch_token_ids.append(token_ids)
# Pad sequences to the same length for batch processing
max_length = max(len(ids) for ids in batch_token_ids)
padded_token_ids = []
for token_ids in batch_token_ids:
# Pad with tokenizer.pad_token_id or 0
padded = token_ids + [0] * (max_length - len(token_ids))
padded_token_ids.append(padded)
# Convert to MLX array with batch dimension
input_ids = mx.array(padded_token_ids)
# Get embeddings for the batch
embeddings = model(input_ids)
# Mean pooling for each sequence in the batch
pooled = embeddings.mean(axis=1) # Shape: (batch_size, hidden_size)
# Convert batch embeddings to numpy
for j in range(len(batch_chunks)):
pooled_list = pooled[j].tolist() # Convert to list
pooled_numpy = np.array(pooled_list, dtype=np.float32)
all_embeddings.append(pooled_numpy)
# Stack numpy arrays
return np.stack(all_embeddings)
@dataclass
class SearchResult:
id: str
@@ -114,31 +299,25 @@ class PassageManager:
self.global_offset_map = {} # Combined map for fast lookup
for source in passage_sources:
assert source["type"] == "jsonl", "only jsonl is supported"
passage_file = source["path"]
index_file = source["index_path"] # .idx file
# Fix path resolution for Colab and other environments
if not Path(index_file).is_absolute():
# If relative path, try to resolve it properly
index_file = str(Path(index_file).resolve())
if not Path(index_file).exists():
raise FileNotFoundError(f"Passage index file not found: {index_file}")
with open(index_file, "rb") as f:
offset_map = pickle.load(f)
self.offset_maps[passage_file] = offset_map
self.passage_files[passage_file] = passage_file
if source["type"] == "jsonl":
passage_file = source["path"]
index_file = source["index_path"]
if not Path(index_file).exists():
raise FileNotFoundError(
f"Passage index file not found: {index_file}"
)
with open(index_file, "rb") as f:
offset_map = pickle.load(f)
self.offset_maps[passage_file] = offset_map
self.passage_files[passage_file] = passage_file
# Build global map for O(1) lookup
for passage_id, offset in offset_map.items():
self.global_offset_map[passage_id] = (passage_file, offset)
# Build global map for O(1) lookup
for passage_id, offset in offset_map.items():
self.global_offset_map[passage_id] = (passage_file, offset)
def get_passage(self, passage_id: str) -> Dict[str, Any]:
if passage_id in self.global_offset_map:
passage_file, offset = self.global_offset_map[passage_id]
# Lazy file opening - only open when needed
with open(passage_file, "r", encoding="utf-8") as f:
f.seek(offset)
return json.loads(f.readline())
@@ -149,7 +328,7 @@ class LeannBuilder:
def __init__(
self,
backend_name: str,
embedding_model: str = "facebook/contriever",
embedding_model: str = "facebook/contriever-msmarco",
dimensions: Optional[int] = None,
embedding_mode: str = "sentence-transformers",
**backend_kwargs,
@@ -165,12 +344,14 @@ class LeannBuilder:
self.dimensions = dimensions
self.embedding_mode = embedding_mode
self.backend_kwargs = backend_kwargs
if 'mlx' in self.embedding_model:
self.embedding_mode = "mlx"
self.chunks: List[Dict[str, Any]] = []
def add_text(self, text: str, metadata: Optional[Dict[str, Any]] = None):
if metadata is None:
metadata = {}
passage_id = metadata.get("id", str(len(self.chunks)))
passage_id = metadata.get("id", str(uuid.uuid4()))
chunk_data = {"id": passage_id, "text": text, "metadata": metadata}
self.chunks.append(chunk_data)
@@ -196,13 +377,10 @@ class LeannBuilder:
with open(passages_file, "w", encoding="utf-8") as f:
try:
from tqdm import tqdm
chunk_iterator = tqdm(
self.chunks, desc="Writing passages", unit="chunk"
)
chunk_iterator = tqdm(self.chunks, desc="Writing passages", unit="chunk")
except ImportError:
chunk_iterator = self.chunks
for chunk in chunk_iterator:
offset = f.tell()
json.dump(
@@ -220,11 +398,7 @@ class LeannBuilder:
pickle.dump(offset_map, f)
texts_to_embed = [c["text"] for c in self.chunks]
embeddings = compute_embeddings(
texts_to_embed,
self.embedding_model,
self.embedding_mode,
use_server=False,
is_build=True,
texts_to_embed, self.embedding_model, self.embedding_mode, use_server=False
)
string_ids = [chunk["id"] for chunk in self.chunks]
current_backend_kwargs = {**self.backend_kwargs, "dimensions": self.dimensions}
@@ -298,7 +472,7 @@ class LeannBuilder:
f"Dimension mismatch: expected {self.dimensions}, got {embedding_dim}"
)
logger.info(
print(
f"Building index from precomputed embeddings: {len(ids)} items, {embedding_dim} dimensions"
)
@@ -306,7 +480,7 @@ class LeannBuilder:
if len(self.chunks) != len(ids):
# If no text chunks provided, create placeholder text entries
if not self.chunks:
logger.info("No text chunks provided, creating placeholder entries...")
print("No text chunks provided, creating placeholder entries...")
for id_val in ids:
self.add_text(
f"Document {id_val}",
@@ -381,23 +555,15 @@ class LeannBuilder:
with open(leann_meta_path, "w", encoding="utf-8") as f:
json.dump(meta_data, f, indent=2)
logger.info(
f"Index built successfully from precomputed embeddings: {index_path}"
)
print(f"Index built successfully from precomputed embeddings: {index_path}")
class LeannSearcher:
def __init__(self, index_path: str, enable_warmup: bool = False, **backend_kwargs):
# Fix path resolution for Colab and other environments
if not Path(index_path).is_absolute():
index_path = str(Path(index_path).resolve())
self.meta_path_str = f"{index_path}.meta.json"
if not Path(self.meta_path_str).exists():
raise FileNotFoundError(
f"Leann metadata file not found at {self.meta_path_str}"
)
with open(self.meta_path_str, "r", encoding="utf-8") as f:
meta_path_str = f"{index_path}.meta.json"
if not Path(meta_path_str).exists():
raise FileNotFoundError(f"Leann metadata file not found at {meta_path_str}")
with open(meta_path_str, "r", encoding="utf-8") as f:
self.meta_data = json.load(f)
backend_name = self.meta_data["backend_name"]
self.embedding_model = self.meta_data["embedding_model"]
@@ -405,15 +571,16 @@ class LeannSearcher:
self.embedding_mode = self.meta_data.get(
"embedding_mode", "sentence-transformers"
)
# Backward compatibility with use_mlx
if self.meta_data.get("use_mlx", False):
self.embedding_mode = "mlx"
self.passage_manager = PassageManager(self.meta_data.get("passage_sources", []))
backend_factory = BACKEND_REGISTRY.get(backend_name)
if backend_factory is None:
raise ValueError(f"Backend '{backend_name}' not found.")
final_kwargs = {**self.meta_data.get("backend_kwargs", {}), **backend_kwargs}
final_kwargs["enable_warmup"] = enable_warmup
self.backend_impl: LeannBackendSearcherInterface = backend_factory.searcher(
index_path, **final_kwargs
)
self.backend_impl = backend_factory.searcher(index_path, **final_kwargs)
def search(
self,
@@ -422,39 +589,26 @@ class LeannSearcher:
complexity: int = 64,
beam_width: int = 1,
prune_ratio: float = 0.0,
recompute_embeddings: bool = True,
recompute_embeddings: bool = False,
pruning_strategy: Literal["global", "local", "proportional"] = "global",
expected_zmq_port: int = 5557,
zmq_port: int = 5557,
**kwargs,
) -> List[SearchResult]:
logger.info("🔍 LeannSearcher.search() called:")
logger.info(f" Query: '{query}'")
logger.info(f" Top_k: {top_k}")
logger.info(f" Additional kwargs: {kwargs}")
print("🔍 DEBUG LeannSearcher.search() called:")
print(f" Query: '{query}'")
print(f" Top_k: {top_k}")
print(f" Additional kwargs: {kwargs}")
zmq_port = None
start_time = time.time()
if recompute_embeddings:
zmq_port = self.backend_impl._ensure_server_running(
self.meta_path_str,
port=expected_zmq_port,
**kwargs,
)
del expected_zmq_port
zmq_time = time.time() - start_time
logger.info(f" Launching server time: {zmq_time} seconds")
# Use backend's compute_query_embedding method
# This will automatically use embedding server if available and needed
import time
start_time = time.time()
query_embedding = self.backend_impl.compute_query_embedding(
query,
use_server_if_available=recompute_embeddings,
zmq_port=zmq_port,
)
# logger.info(f" Generated embedding shape: {query_embedding.shape}")
query_embedding = self.backend_impl.compute_query_embedding(query, zmq_port)
print(f" Generated embedding shape: {query_embedding.shape}")
embedding_time = time.time() - start_time
# logger.info(f" Embedding time: {embedding_time} seconds")
print(f" Embedding time: {embedding_time} seconds")
start_time = time.time()
results = self.backend_impl.search(
@@ -469,14 +623,14 @@ class LeannSearcher:
**kwargs,
)
search_time = time.time() - start_time
# logger.info(f" Search time: {search_time} seconds")
logger.info(
print(f" Search time: {search_time} seconds")
print(
f" Backend returned: labels={len(results.get('labels', [[]])[0])} results"
)
enriched_results = []
if "labels" in results and "distances" in results:
logger.info(f" Processing {len(results['labels'][0])} passage IDs:")
print(f" Processing {len(results['labels'][0])} passage IDs:")
for i, (string_id, dist) in enumerate(
zip(results["labels"][0], results["distances"][0])
):
@@ -490,25 +644,15 @@ class LeannSearcher:
metadata=passage_data.get("metadata", {}),
)
)
# Color codes for better logging
GREEN = "\033[92m"
BLUE = "\033[94m"
YELLOW = "\033[93m"
RESET = "\033[0m"
# Truncate text for display (first 100 chars)
display_text = passage_data['text']
logger.info(
f" {GREEN}{RESET} {BLUE}[{i + 1:2d}]{RESET} {YELLOW}ID:{RESET} '{string_id}' {YELLOW}Score:{RESET} {dist:.4f} {YELLOW}Text:{RESET} {display_text}"
print(
f" {i + 1}. passage_id='{string_id}' -> SUCCESS: {passage_data['text']}..."
)
except KeyError:
RED = "\033[91m"
logger.error(
f" {RED}{RESET} [{i + 1:2d}] ID: '{string_id}' -> {RED}ERROR: Passage not found!{RESET}"
print(
f" {i + 1}. passage_id='{string_id}' -> ERROR: Passage not found in PassageManager!"
)
logger.info(f" {GREEN} Final enriched results: {len(enriched_results)} passages{RESET}")
print(f" Final enriched results: {len(enriched_results)} passages")
return enriched_results
@@ -530,15 +674,15 @@ class LeannChat:
complexity: int = 64,
beam_width: int = 1,
prune_ratio: float = 0.0,
recompute_embeddings: bool = True,
recompute_embeddings: bool = False,
pruning_strategy: Literal["global", "local", "proportional"] = "global",
zmq_port: int = 5557,
llm_kwargs: Optional[Dict[str, Any]] = None,
expected_zmq_port: int = 5557,
**search_kwargs,
):
if llm_kwargs is None:
llm_kwargs = {}
search_time = time.time()
results = self.searcher.search(
question,
top_k=top_k,
@@ -547,11 +691,9 @@ class LeannChat:
prune_ratio=prune_ratio,
recompute_embeddings=recompute_embeddings,
pruning_strategy=pruning_strategy,
expected_zmq_port=expected_zmq_port,
zmq_port=zmq_port,
**search_kwargs,
)
search_time = time.time() - search_time
# logger.info(f" Search time: {search_time} seconds")
context = "\n\n".join([r.text for r in results])
prompt = (
"Here is some retrieved context that might help answer your question:\n\n"

View File

@@ -9,7 +9,6 @@ from typing import Dict, Any, Optional, List
import logging
import os
import difflib
import torch
# Configure logging
logging.basicConfig(level=logging.INFO)
@@ -29,68 +28,6 @@ def check_ollama_models() -> List[str]:
return []
def check_ollama_model_exists_remotely(model_name: str) -> tuple[bool, list[str]]:
"""Check if a model exists in Ollama's remote library and return available tags
Returns:
(model_exists, available_tags): bool and list of matching tags
"""
try:
import requests
import re
# Split model name and tag
if ':' in model_name:
base_model, requested_tag = model_name.split(':', 1)
else:
base_model, requested_tag = model_name, None
# First check if base model exists in library
library_response = requests.get("https://ollama.com/library", timeout=8)
if library_response.status_code != 200:
return True, [] # Assume exists if can't check
# Extract model names from library page
models_in_library = re.findall(r'href="/library/([^"]+)"', library_response.text)
if base_model not in models_in_library:
return False, [] # Base model doesn't exist
# If base model exists, get available tags
tags_response = requests.get(f"https://ollama.com/library/{base_model}/tags", timeout=8)
if tags_response.status_code != 200:
return True, [] # Base model exists but can't get tags
# Extract tags for this model - be more specific to avoid HTML artifacts
tag_pattern = rf'{re.escape(base_model)}:[a-zA-Z0-9\.\-_]+'
raw_tags = re.findall(tag_pattern, tags_response.text)
# Clean up tags - remove HTML artifacts and duplicates
available_tags = []
seen = set()
for tag in raw_tags:
# Skip if it looks like HTML (contains < or >)
if '<' in tag or '>' in tag:
continue
if tag not in seen:
seen.add(tag)
available_tags.append(tag)
# Check if exact model exists
if requested_tag is None:
# User just requested base model, suggest tags
return True, available_tags[:10] # Return up to 10 tags
else:
exact_match = model_name in available_tags
return exact_match, available_tags[:10]
except Exception:
pass
# If scraping fails, assume model might exist (don't block user)
return True, []
def search_ollama_models_fuzzy(query: str, available_models: List[str]) -> List[str]:
"""Use intelligent fuzzy search for Ollama models"""
if not available_models:
@@ -306,66 +243,24 @@ def validate_model_and_suggest(model_name: str, llm_type: str) -> Optional[str]:
if llm_type == "ollama":
available_models = check_ollama_models()
if available_models and model_name not in available_models:
# Use intelligent fuzzy search based on locally installed models
suggestions = search_ollama_models_fuzzy(model_name, available_models)
error_msg = f"Model '{model_name}' not found in your local Ollama installation."
# Check if the model exists remotely and get available tags
model_exists_remotely, available_tags = check_ollama_model_exists_remotely(model_name)
if model_exists_remotely and model_name in available_tags:
# Exact model exists remotely - suggest pulling it
error_msg += f"\n\nTo install the requested model:\n"
error_msg += f" ollama pull {model_name}\n"
# Show local alternatives
suggestions = search_ollama_models_fuzzy(model_name, available_models)
if suggestions:
error_msg += "\nOr use one of these similar installed models:\n"
for i, suggestion in enumerate(suggestions, 1):
error_msg += f" {i}. {suggestion}\n"
elif model_exists_remotely and available_tags:
# Base model exists but requested tag doesn't - suggest correct tags
base_model = model_name.split(':')[0]
requested_tag = model_name.split(':', 1)[1] if ':' in model_name else None
error_msg += f"\n\nModel '{base_model}' exists, but tag '{requested_tag}' is not available."
error_msg += f"\n\nAvailable {base_model} models you can install:\n"
for i, tag in enumerate(available_tags[:8], 1):
error_msg += f" {i}. ollama pull {tag}\n"
if len(available_tags) > 8:
error_msg += f" ... and {len(available_tags) - 8} more variants\n"
# Also show local alternatives
suggestions = search_ollama_models_fuzzy(model_name, available_models)
if suggestions:
error_msg += "\nOr use one of these similar installed models:\n"
for i, suggestion in enumerate(suggestions, 1):
error_msg += f" {i}. {suggestion}\n"
if suggestions:
error_msg += "\n\nDid you mean one of these installed models?\n"
for i, suggestion in enumerate(suggestions, 1):
error_msg += f" {i}. {suggestion}\n"
else:
# Model doesn't exist remotely - show fuzzy suggestions
suggestions = search_ollama_models_fuzzy(model_name, available_models)
error_msg += f"\n\nModel '{model_name}' was not found in Ollama's library."
if suggestions:
error_msg += "\n\nDid you mean one of these installed models?\n"
for i, suggestion in enumerate(suggestions, 1):
error_msg += f" {i}. {suggestion}\n"
else:
error_msg += "\n\nYour installed models:\n"
for i, model in enumerate(available_models[:8], 1):
error_msg += f" {i}. {model}\n"
if len(available_models) > 8:
error_msg += f" ... and {len(available_models) - 8} more\n"
error_msg += "\n\nYour installed models:\n"
for i, model in enumerate(available_models[:8], 1):
error_msg += f" {i}. {model}\n"
if len(available_models) > 8:
error_msg += f" ... and {len(available_models) - 8} more\n"
error_msg += "\n\nCommands:"
error_msg += "\n ollama list # List installed models"
if model_exists_remotely and available_tags:
if model_name in available_tags:
error_msg += f"\n ollama pull {model_name} # Install requested model"
else:
error_msg += f"\n ollama pull {available_tags[0]} # Install recommended variant"
error_msg += "\n https://ollama.com/library # Browse available models"
error_msg += "\nTo list all models: ollama list"
error_msg += "\nTo download a new model: ollama pull <model_name>"
error_msg += "\nBrowse models: https://ollama.com/library"
return error_msg
elif llm_type == "hf":
@@ -480,9 +375,8 @@ class OllamaChat(LLMInterface):
"stream": False, # Keep it simple for now
"options": kwargs,
}
logger.debug(f"Sending request to Ollama: {payload}")
logger.info(f"Sending request to Ollama: {payload}")
try:
logger.info(f"Sending request to Ollama and waiting for response...")
response = requests.post(full_url, data=json.dumps(payload))
response.raise_for_status()
@@ -502,7 +396,7 @@ class OllamaChat(LLMInterface):
class HFChat(LLMInterface):
"""LLM interface for local Hugging Face Transformers models with proper chat templates."""
"""LLM interface for local Hugging Face Transformers models."""
def __init__(self, model_name: str = "deepseek-ai/deepseek-llm-7b-chat"):
logger.info(f"Initializing HFChat with model='{model_name}'")
@@ -513,7 +407,7 @@ class HFChat(LLMInterface):
raise ValueError(model_error)
try:
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.pipelines import pipeline
import torch
except ImportError:
raise ImportError(
@@ -522,101 +416,54 @@ class HFChat(LLMInterface):
# Auto-detect device
if torch.cuda.is_available():
self.device = "cuda"
device = "cuda"
logger.info("CUDA is available. Using GPU.")
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
self.device = "mps"
device = "mps"
logger.info("MPS is available. Using Apple Silicon GPU.")
else:
self.device = "cpu"
device = "cpu"
logger.info("No GPU detected. Using CPU.")
# Load tokenizer and model
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.float16 if self.device != "cpu" else torch.float32,
device_map="auto" if self.device != "cpu" else None,
trust_remote_code=True
)
# Move model to device if not using device_map
if self.device != "cpu" and "device_map" not in str(self.model):
self.model = self.model.to(self.device)
# Set pad token if not present
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
self.pipeline = pipeline("text-generation", model=model_name, device=device)
def ask(self, prompt: str, **kwargs) -> str:
print('kwargs in HF: ', kwargs)
# Check if this is a Qwen model and add /no_think by default
is_qwen_model = "qwen" in self.model.config._name_or_path.lower()
# For Qwen models, automatically add /no_think to the prompt
if is_qwen_model and "/no_think" not in prompt and "/think" not in prompt:
prompt = prompt + " /no_think"
# Prepare chat template
messages = [{"role": "user", "content": prompt}]
# Apply chat template if available
if hasattr(self.tokenizer, "apply_chat_template"):
try:
formatted_prompt = self.tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True
)
except Exception as e:
logger.warning(f"Chat template failed, using raw prompt: {e}")
formatted_prompt = prompt
else:
# Fallback for models without chat template
formatted_prompt = prompt
# Map OpenAI-style arguments to Hugging Face equivalents
if "max_tokens" in kwargs:
# Prefer user-provided max_new_tokens if both are present
kwargs.setdefault("max_new_tokens", kwargs["max_tokens"])
# Remove the unsupported key to avoid errors in Transformers
kwargs.pop("max_tokens")
# Tokenize input
inputs = self.tokenizer(
formatted_prompt,
return_tensors="pt",
padding=True,
truncation=True,
max_length=2048
)
# Move inputs to device
if self.device != "cpu":
inputs = {k: v.to(self.device) for k, v in inputs.items()}
# Handle temperature=0 edge-case for greedy decoding
if "temperature" in kwargs and kwargs["temperature"] == 0.0:
# Remove unsupported zero temperature and use deterministic generation
kwargs.pop("temperature")
kwargs.setdefault("do_sample", False)
# Set generation parameters
generation_config = {
"max_new_tokens": kwargs.get("max_tokens", kwargs.get("max_new_tokens", 512)),
"temperature": kwargs.get("temperature", 0.7),
"top_p": kwargs.get("top_p", 0.9),
"do_sample": kwargs.get("temperature", 0.7) > 0,
"pad_token_id": self.tokenizer.eos_token_id,
"eos_token_id": self.tokenizer.eos_token_id,
}
# Handle temperature=0 for greedy decoding
if generation_config["temperature"] == 0.0:
generation_config["do_sample"] = False
generation_config.pop("temperature")
# Sensible defaults for text generation
params = {"max_length": 500, "num_return_sequences": 1, **kwargs}
logger.info(f"Generating text with Hugging Face model with params: {params}")
results = self.pipeline(prompt, **params)
logger.info(f"Generating with HuggingFace model, config: {generation_config}")
# Generate
with torch.no_grad():
outputs = self.model.generate(
**inputs,
**generation_config
# Handle different response formats from transformers
if isinstance(results, list) and len(results) > 0:
generated_text = (
results[0].get("generated_text", "")
if isinstance(results[0], dict)
else str(results[0])
)
else:
generated_text = str(results)
# Decode response
generated_tokens = outputs[0][inputs["input_ids"].shape[1]:]
response = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
return response.strip()
# Extract only the newly generated portion by removing the original prompt
if isinstance(generated_text, str) and generated_text.startswith(prompt):
response = generated_text[len(prompt) :].strip()
else:
# Fallback: return the full response if prompt removal fails
response = str(generated_text)
return response
class OpenAIChat(LLMInterface):

View File

@@ -1,372 +0,0 @@
import argparse
import asyncio
from pathlib import Path
from llama_index.core import SimpleDirectoryReader
from llama_index.core.node_parser import SentenceSplitter
from .api import LeannBuilder, LeannSearcher, LeannChat
def extract_pdf_text_with_pymupdf(file_path: str) -> str:
"""Extract text from PDF using PyMuPDF for better quality."""
try:
import fitz # PyMuPDF
doc = fitz.open(file_path)
text = ""
for page in doc:
text += page.get_text()
doc.close()
return text
except ImportError:
# Fallback to default reader
return None
def extract_pdf_text_with_pdfplumber(file_path: str) -> str:
"""Extract text from PDF using pdfplumber for better quality."""
try:
import pdfplumber
text = ""
with pdfplumber.open(file_path) as pdf:
for page in pdf.pages:
text += page.extract_text() or ""
return text
except ImportError:
# Fallback to default reader
return None
class LeannCLI:
def __init__(self):
self.indexes_dir = Path.home() / ".leann" / "indexes"
self.indexes_dir.mkdir(parents=True, exist_ok=True)
self.node_parser = SentenceSplitter(
chunk_size=256, chunk_overlap=128, separator=" ", paragraph_separator="\n\n"
)
def get_index_path(self, index_name: str) -> str:
index_dir = self.indexes_dir / index_name
return str(index_dir / "documents.leann")
def index_exists(self, index_name: str) -> bool:
index_dir = self.indexes_dir / index_name
meta_file = index_dir / "documents.leann.meta.json"
return meta_file.exists()
def create_parser(self) -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(
prog="leann",
description="LEANN - Local Enhanced AI Navigation",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
Examples:
leann build my-docs --docs ./documents # Build index named my-docs
leann search my-docs "query" # Search in my-docs index
leann ask my-docs "question" # Ask my-docs index
leann list # List all stored indexes
""",
)
subparsers = parser.add_subparsers(dest="command", help="Available commands")
# Build command
build_parser = subparsers.add_parser("build", help="Build document index")
build_parser.add_argument("index_name", help="Index name")
build_parser.add_argument(
"--docs", type=str, required=True, help="Documents directory"
)
build_parser.add_argument(
"--backend", type=str, default="hnsw", choices=["hnsw", "diskann"]
)
build_parser.add_argument(
"--embedding-model", type=str, default="facebook/contriever"
)
build_parser.add_argument(
"--force", "-f", action="store_true", help="Force rebuild"
)
build_parser.add_argument("--graph-degree", type=int, default=32)
build_parser.add_argument("--complexity", type=int, default=64)
build_parser.add_argument("--num-threads", type=int, default=1)
build_parser.add_argument("--compact", action="store_true", default=True)
build_parser.add_argument("--recompute", action="store_true", default=True)
# Search command
search_parser = subparsers.add_parser("search", help="Search documents")
search_parser.add_argument("index_name", help="Index name")
search_parser.add_argument("query", help="Search query")
search_parser.add_argument("--top-k", type=int, default=5)
search_parser.add_argument("--complexity", type=int, default=64)
search_parser.add_argument("--beam-width", type=int, default=1)
search_parser.add_argument("--prune-ratio", type=float, default=0.0)
search_parser.add_argument("--recompute-embeddings", action="store_true")
search_parser.add_argument(
"--pruning-strategy",
choices=["global", "local", "proportional"],
default="global",
)
# Ask command
ask_parser = subparsers.add_parser("ask", help="Ask questions")
ask_parser.add_argument("index_name", help="Index name")
ask_parser.add_argument(
"--llm",
type=str,
default="ollama",
choices=["simulated", "ollama", "hf", "openai"],
)
ask_parser.add_argument("--model", type=str, default="qwen3:8b")
ask_parser.add_argument("--host", type=str, default="http://localhost:11434")
ask_parser.add_argument("--interactive", "-i", action="store_true")
ask_parser.add_argument("--top-k", type=int, default=20)
ask_parser.add_argument("--complexity", type=int, default=32)
ask_parser.add_argument("--beam-width", type=int, default=1)
ask_parser.add_argument("--prune-ratio", type=float, default=0.0)
ask_parser.add_argument("--recompute-embeddings", action="store_true")
ask_parser.add_argument(
"--pruning-strategy",
choices=["global", "local", "proportional"],
default="global",
)
# List command
list_parser = subparsers.add_parser("list", help="List all indexes")
return parser
def list_indexes(self):
print("Stored LEANN indexes:")
if not self.indexes_dir.exists():
print(
"No indexes found. Use 'leann build <name> --docs <dir>' to create one."
)
return
index_dirs = [d for d in self.indexes_dir.iterdir() if d.is_dir()]
if not index_dirs:
print(
"No indexes found. Use 'leann build <name> --docs <dir>' to create one."
)
return
print(f"Found {len(index_dirs)} indexes:")
for i, index_dir in enumerate(index_dirs, 1):
index_name = index_dir.name
status = "" if self.index_exists(index_name) else ""
print(f" {i}. {index_name} [{status}]")
if self.index_exists(index_name):
meta_file = index_dir / "documents.leann.meta.json"
size_mb = sum(
f.stat().st_size for f in index_dir.iterdir() if f.is_file()
) / (1024 * 1024)
print(f" Size: {size_mb:.1f} MB")
if index_dirs:
example_name = index_dirs[0].name
print(f"\nUsage:")
print(f' leann search {example_name} "your query"')
print(f" leann ask {example_name} --interactive")
def load_documents(self, docs_dir: str):
print(f"Loading documents from {docs_dir}...")
# Try to use better PDF parsers first
documents = []
docs_path = Path(docs_dir)
for file_path in docs_path.rglob("*.pdf"):
print(f"Processing PDF: {file_path}")
# Try PyMuPDF first (best quality)
text = extract_pdf_text_with_pymupdf(str(file_path))
if text is None:
# Try pdfplumber
text = extract_pdf_text_with_pdfplumber(str(file_path))
if text:
# Create a simple document structure
from llama_index.core import Document
doc = Document(text=text, metadata={"source": str(file_path)})
documents.append(doc)
else:
# Fallback to default reader
print(f"Using default reader for {file_path}")
default_docs = SimpleDirectoryReader(
str(file_path.parent),
filename_as_id=True,
required_exts=[file_path.suffix],
).load_data()
documents.extend(default_docs)
# Load other file types with default reader
other_docs = SimpleDirectoryReader(
docs_dir,
recursive=True,
encoding="utf-8",
required_exts=[".txt", ".md", ".docx"],
).load_data(show_progress=True)
documents.extend(other_docs)
all_texts = []
for doc in documents:
nodes = self.node_parser.get_nodes_from_documents([doc])
for node in nodes:
all_texts.append(node.get_content())
print(f"Loaded {len(documents)} documents, {len(all_texts)} chunks")
return all_texts
async def build_index(self, args):
docs_dir = args.docs
index_name = args.index_name
index_dir = self.indexes_dir / index_name
index_path = self.get_index_path(index_name)
if index_dir.exists() and not args.force:
print(f"Index '{index_name}' already exists. Use --force to rebuild.")
return
all_texts = self.load_documents(docs_dir)
if not all_texts:
print("No documents found")
return
index_dir.mkdir(parents=True, exist_ok=True)
print(f"Building index '{index_name}' with {args.backend} backend...")
builder = LeannBuilder(
backend_name=args.backend,
embedding_model=args.embedding_model,
graph_degree=args.graph_degree,
complexity=args.complexity,
is_compact=args.compact,
is_recompute=args.recompute,
num_threads=args.num_threads,
)
for chunk_text in all_texts:
builder.add_text(chunk_text)
builder.build_index(index_path)
print(f"Index built at {index_path}")
async def search_documents(self, args):
index_name = args.index_name
query = args.query
index_path = self.get_index_path(index_name)
if not self.index_exists(index_name):
print(
f"Index '{index_name}' not found. Use 'leann build {index_name} --docs <dir>' to create it."
)
return
searcher = LeannSearcher(index_path=index_path)
results = searcher.search(
query,
top_k=args.top_k,
complexity=args.complexity,
beam_width=args.beam_width,
prune_ratio=args.prune_ratio,
recompute_embeddings=args.recompute_embeddings,
pruning_strategy=args.pruning_strategy,
)
print(f"Search results for '{query}' (top {len(results)}):")
for i, result in enumerate(results, 1):
print(f"{i}. Score: {result.score:.3f}")
print(f" {result.text[:200]}...")
print()
async def ask_questions(self, args):
index_name = args.index_name
index_path = self.get_index_path(index_name)
if not self.index_exists(index_name):
print(
f"Index '{index_name}' not found. Use 'leann build {index_name} --docs <dir>' to create it."
)
return
print(f"Starting chat with index '{index_name}'...")
print(f"Using {args.model} ({args.llm})")
llm_config = {"type": args.llm, "model": args.model}
if args.llm == "ollama":
llm_config["host"] = args.host
chat = LeannChat(index_path=index_path, llm_config=llm_config)
if args.interactive:
print("LEANN Assistant ready! Type 'quit' to exit")
print("=" * 40)
while True:
user_input = input("\nYou: ").strip()
if user_input.lower() in ["quit", "exit", "q"]:
print("Goodbye!")
break
if not user_input:
continue
response = chat.ask(
user_input,
top_k=args.top_k,
complexity=args.complexity,
beam_width=args.beam_width,
prune_ratio=args.prune_ratio,
recompute_embeddings=args.recompute_embeddings,
pruning_strategy=args.pruning_strategy,
)
print(f"LEANN: {response}")
else:
query = input("Enter your question: ").strip()
if query:
response = chat.ask(
query,
top_k=args.top_k,
complexity=args.complexity,
beam_width=args.beam_width,
prune_ratio=args.prune_ratio,
recompute_embeddings=args.recompute_embeddings,
pruning_strategy=args.pruning_strategy,
)
print(f"LEANN: {response}")
async def run(self, args=None):
parser = self.create_parser()
if args is None:
args = parser.parse_args()
if not args.command:
parser.print_help()
return
if args.command == "list":
self.list_indexes()
elif args.command == "build":
await self.build_index(args)
elif args.command == "search":
await self.search_documents(args)
elif args.command == "ask":
await self.ask_questions(args)
else:
parser.print_help()
def main():
import dotenv
dotenv.load_dotenv()
cli = LeannCLI()
asyncio.run(cli.run())
if __name__ == "__main__":
main()

View File

@@ -1,377 +0,0 @@
"""
Unified embedding computation module
Consolidates all embedding computation logic using SentenceTransformer
Preserves all optimization parameters to ensure performance
"""
import numpy as np
import torch
from typing import List, Dict, Any
import logging
import os
# Set up logger with proper level
logger = logging.getLogger(__name__)
LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
log_level = getattr(logging, LOG_LEVEL, logging.WARNING)
logger.setLevel(log_level)
# Global model cache to avoid repeated loading
_model_cache: Dict[str, Any] = {}
def compute_embeddings(
texts: List[str],
model_name: str,
mode: str = "sentence-transformers",
is_build: bool = False,
batch_size: int = 32,
adaptive_optimization: bool = True,
) -> np.ndarray:
"""
Unified embedding computation entry point
Args:
texts: List of texts to compute embeddings for
model_name: Model name
mode: Computation mode ('sentence-transformers', 'openai', 'mlx')
is_build: Whether this is a build operation (shows progress bar)
batch_size: Batch size for processing
adaptive_optimization: Whether to use adaptive optimization based on batch size
Returns:
Normalized embeddings array, shape: (len(texts), embedding_dim)
"""
if mode == "sentence-transformers":
return compute_embeddings_sentence_transformers(
texts,
model_name,
is_build=is_build,
batch_size=batch_size,
adaptive_optimization=adaptive_optimization,
)
elif mode == "openai":
return compute_embeddings_openai(texts, model_name)
elif mode == "mlx":
return compute_embeddings_mlx(texts, model_name)
else:
raise ValueError(f"Unsupported embedding mode: {mode}")
def compute_embeddings_sentence_transformers(
texts: List[str],
model_name: str,
use_fp16: bool = True,
device: str = "auto",
batch_size: int = 32,
is_build: bool = False,
adaptive_optimization: bool = True,
) -> np.ndarray:
"""
Compute embeddings using SentenceTransformer with model caching and adaptive optimization
Args:
texts: List of texts to compute embeddings for
model_name: Model name
use_fp16: Whether to use FP16 precision
device: Device to use ('auto', 'cuda', 'mps', 'cpu')
batch_size: Batch size for processing
is_build: Whether this is a build operation (shows progress bar)
adaptive_optimization: Whether to use adaptive optimization based on batch size
"""
# Handle empty input
if not texts:
raise ValueError("Cannot compute embeddings for empty text list")
logger.info(
f"Computing embeddings for {len(texts)} texts using SentenceTransformer, model: '{model_name}'"
)
# Auto-detect device
if device == "auto":
if torch.cuda.is_available():
device = "cuda"
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
device = "mps"
else:
device = "cpu"
# Apply optimizations based on benchmark results
if adaptive_optimization:
# Use optimal batch_size constants for different devices based on benchmark results
if device == "mps":
batch_size = 128 # MPS optimal batch size from benchmark
if model_name == "Qwen/Qwen3-Embedding-0.6B":
batch_size = 32
elif device == "cuda":
batch_size = 256 # CUDA optimal batch size
# Keep original batch_size for CPU
# Create cache key
cache_key = f"sentence_transformers_{model_name}_{device}_{use_fp16}_optimized"
# Check if model is already cached
if cache_key in _model_cache:
logger.info(f"Using cached optimized model: {model_name}")
model = _model_cache[cache_key]
else:
logger.info(
f"Loading and caching optimized SentenceTransformer model: {model_name}"
)
from sentence_transformers import SentenceTransformer
logger.info(f"Using device: {device}")
# Apply hardware optimizations
if device == "cuda":
# TODO: Haven't tested this yet
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.deterministic = False
torch.cuda.set_per_process_memory_fraction(0.9)
elif device == "mps":
try:
if hasattr(torch.mps, "set_per_process_memory_fraction"):
torch.mps.set_per_process_memory_fraction(0.9)
except AttributeError:
logger.warning(
"Some MPS optimizations not available in this PyTorch version"
)
elif device == "cpu":
# TODO: Haven't tested this yet
torch.set_num_threads(min(8, os.cpu_count() or 4))
try:
torch.backends.mkldnn.enabled = True
except AttributeError:
pass
# Prepare optimized model and tokenizer parameters
model_kwargs = {
"torch_dtype": torch.float16 if use_fp16 else torch.float32,
"low_cpu_mem_usage": True,
"_fast_init": True,
"attn_implementation": "eager", # Use eager attention for speed
}
tokenizer_kwargs = {
"use_fast": True,
"padding": True,
"truncation": True,
}
try:
# Try local loading first
model_kwargs["local_files_only"] = True
tokenizer_kwargs["local_files_only"] = True
model = SentenceTransformer(
model_name,
device=device,
model_kwargs=model_kwargs,
tokenizer_kwargs=tokenizer_kwargs,
local_files_only=True,
)
logger.info("Model loaded successfully! (local + optimized)")
except Exception as e:
logger.warning(f"Local loading failed ({e}), trying network download...")
# Fallback to network loading
model_kwargs["local_files_only"] = False
tokenizer_kwargs["local_files_only"] = False
model = SentenceTransformer(
model_name,
device=device,
model_kwargs=model_kwargs,
tokenizer_kwargs=tokenizer_kwargs,
local_files_only=False,
)
logger.info("Model loaded successfully! (network + optimized)")
# Apply additional optimizations based on mode
if use_fp16 and device in ["cuda", "mps"]:
try:
model = model.half()
logger.info(f"Applied FP16 precision: {model_name}")
except Exception as e:
logger.warning(f"FP16 optimization failed: {e}")
# Apply torch.compile optimization
if device in ["cuda", "mps"]:
try:
model = torch.compile(model, mode="reduce-overhead", dynamic=True)
logger.info(f"Applied torch.compile optimization: {model_name}")
except Exception as e:
logger.warning(f"torch.compile optimization failed: {e}")
# Set model to eval mode and disable gradients for inference
model.eval()
for param in model.parameters():
param.requires_grad_(False)
# Cache the model
_model_cache[cache_key] = model
logger.info(f"Model cached: {cache_key}")
# Compute embeddings with optimized inference mode
logger.info(f"Starting embedding computation... (batch_size: {batch_size})")
# Use torch.inference_mode for optimal performance
with torch.inference_mode():
embeddings = model.encode(
texts,
batch_size=batch_size,
show_progress_bar=is_build, # Don't show progress bar in server environment
convert_to_numpy=True,
normalize_embeddings=False,
device=device,
)
logger.info(
f"Generated {len(embeddings)} embeddings, dimension: {embeddings.shape[1]}"
)
# Validate results
if np.isnan(embeddings).any() or np.isinf(embeddings).any():
raise RuntimeError(
f"Detected NaN or Inf values in embeddings, model: {model_name}"
)
return embeddings
def compute_embeddings_openai(texts: List[str], model_name: str) -> np.ndarray:
# TODO: @yichuan-w add progress bar only in build mode
"""Compute embeddings using OpenAI API"""
try:
import openai
import os
except ImportError as e:
raise ImportError(f"OpenAI package not installed: {e}")
api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
raise RuntimeError("OPENAI_API_KEY environment variable not set")
# Cache OpenAI client
cache_key = "openai_client"
if cache_key in _model_cache:
client = _model_cache[cache_key]
else:
client = openai.OpenAI(api_key=api_key)
_model_cache[cache_key] = client
logger.info("OpenAI client cached")
logger.info(
f"Computing embeddings for {len(texts)} texts using OpenAI API, model: '{model_name}'"
)
print(f"len of texts: {len(texts)}")
# OpenAI has limits on batch size and input length
max_batch_size = 1000 # Conservative batch size
all_embeddings = []
try:
from tqdm import tqdm
total_batches = (len(texts) + max_batch_size - 1) // max_batch_size
batch_range = range(0, len(texts), max_batch_size)
batch_iterator = tqdm(
batch_range, desc="Computing embeddings", unit="batch", total=total_batches
)
except ImportError:
# Fallback when tqdm is not available
batch_iterator = range(0, len(texts), max_batch_size)
for i in batch_iterator:
batch_texts = texts[i : i + max_batch_size]
try:
response = client.embeddings.create(model=model_name, input=batch_texts)
batch_embeddings = [embedding.embedding for embedding in response.data]
all_embeddings.extend(batch_embeddings)
except Exception as e:
logger.error(f"Batch {i} failed: {e}")
raise
embeddings = np.array(all_embeddings, dtype=np.float32)
logger.info(
f"Generated {len(embeddings)} embeddings, dimension: {embeddings.shape[1]}"
)
print(f"len of embeddings: {len(embeddings)}")
return embeddings
def compute_embeddings_mlx(
chunks: List[str], model_name: str, batch_size: int = 16
) -> np.ndarray:
# TODO: @yichuan-w add progress bar only in build mode
"""Computes embeddings using an MLX model."""
try:
import mlx.core as mx
from mlx_lm.utils import load
except ImportError as e:
raise RuntimeError(
"MLX or related libraries not available. Install with: uv pip install mlx mlx-lm"
) from e
logger.info(
f"Computing embeddings for {len(chunks)} chunks using MLX model '{model_name}' with batch_size={batch_size}..."
)
# Cache MLX model and tokenizer
cache_key = f"mlx_{model_name}"
if cache_key in _model_cache:
logger.info(f"Using cached MLX model: {model_name}")
model, tokenizer = _model_cache[cache_key]
else:
logger.info(f"Loading and caching MLX model: {model_name}")
model, tokenizer = load(model_name)
_model_cache[cache_key] = (model, tokenizer)
logger.info(f"MLX model cached: {cache_key}")
# Process chunks in batches with progress bar
all_embeddings = []
try:
from tqdm import tqdm
batch_iterator = tqdm(
range(0, len(chunks), batch_size), desc="Computing embeddings", unit="batch"
)
except ImportError:
batch_iterator = range(0, len(chunks), batch_size)
for i in batch_iterator:
batch_chunks = chunks[i : i + batch_size]
# Tokenize all chunks in the batch
batch_token_ids = []
for chunk in batch_chunks:
token_ids = tokenizer.encode(chunk) # type: ignore
batch_token_ids.append(token_ids)
# Pad sequences to the same length for batch processing
max_length = max(len(ids) for ids in batch_token_ids)
padded_token_ids = []
for token_ids in batch_token_ids:
# Pad with tokenizer.pad_token_id or 0
padded = token_ids + [0] * (max_length - len(token_ids))
padded_token_ids.append(padded)
# Convert to MLX array with batch dimension
input_ids = mx.array(padded_token_ids)
# Get embeddings for the batch
embeddings = model(input_ids)
# Mean pooling for each sequence in the batch
pooled = embeddings.mean(axis=1) # Shape: (batch_size, hidden_size)
# Convert batch embeddings to numpy
for j in range(len(batch_chunks)):
pooled_list = pooled[j].tolist() # Convert to list
pooled_numpy = np.array(pooled_list, dtype=np.float32)
all_embeddings.append(pooled_numpy)
# Stack numpy arrays
return np.stack(all_embeddings)

View File

@@ -1,39 +1,14 @@
import threading
import time
import atexit
import socket
import subprocess
import sys
import os
import logging
import zmq
import msgpack
from pathlib import Path
from typing import Optional
import psutil
# Set up logging based on environment variable
LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
logging.basicConfig(
level=getattr(logging, LOG_LEVEL, logging.INFO),
format="%(levelname)s - %(name)s - %(message)s",
)
logger = logging.getLogger(__name__)
def _is_colab_environment() -> bool:
"""Check if we're running in Google Colab environment."""
return "COLAB_GPU" in os.environ or "COLAB_TPU" in os.environ
def _get_available_port(start_port: int = 5557) -> int:
"""Get an available port starting from start_port."""
port = start_port
while port < start_port + 100: # Try up to 100 ports
try:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("localhost", port))
return port
except OSError:
port += 1
raise RuntimeError(f"No available ports found in range {start_port}-{start_port+100}")
import select
def _check_port(port: int) -> bool:
@@ -42,135 +17,151 @@ def _check_port(port: int) -> bool:
return s.connect_ex(("localhost", port)) == 0
def _check_process_matches_config(
port: int, expected_model: str, expected_passages_file: str
) -> bool:
def _check_server_meta_path(port: int, expected_meta_path: str) -> bool:
"""
Check if the process using the port matches our expected model and passages file.
Returns True if matches, False otherwise.
Check if the existing server on the port is using the correct meta file.
Returns True if the server has the right meta path, False otherwise.
"""
try:
for proc in psutil.process_iter(["pid", "cmdline"]):
if not _is_process_listening_on_port(proc, port):
continue
context = zmq.Context()
socket = context.socket(zmq.REQ)
socket.setsockopt(zmq.RCVTIMEO, 3000) # 3 second timeout
socket.connect(f"tcp://localhost:{port}")
cmdline = proc.info["cmdline"]
if not cmdline:
continue
# Send a special control message to query the server's meta path
control_request = ["__QUERY_META_PATH__"]
request_bytes = msgpack.packb(control_request)
socket.send(request_bytes)
return _check_cmdline_matches_config(
cmdline, port, expected_model, expected_passages_file
)
# Wait for response
response_bytes = socket.recv()
response = msgpack.unpackb(response_bytes)
socket.close()
context.term()
# Check if the response contains the meta path and if it matches
if isinstance(response, list) and len(response) > 0:
server_meta_path = response[0]
# Normalize paths for comparison
expected_path = Path(expected_meta_path).resolve()
server_path = Path(server_meta_path).resolve() if server_meta_path else None
return server_path == expected_path
logger.debug(f"No process found listening on port {port}")
return False
except Exception as e:
logger.warning(f"Could not check process on port {port}: {e}")
print(f"WARNING: Could not query server meta path on port {port}: {e}")
return False
def _is_process_listening_on_port(proc, port: int) -> bool:
"""Check if a process is listening on the given port."""
def _update_server_meta_path(port: int, new_meta_path: str) -> bool:
"""
Send a control message to update the server's meta path.
Returns True if successful, False otherwise.
"""
try:
connections = proc.net_connections()
for conn in connections:
if conn.laddr.port == port and conn.status == psutil.CONN_LISTEN:
return True
context = zmq.Context()
socket = context.socket(zmq.REQ)
socket.setsockopt(zmq.RCVTIMEO, 5000) # 5 second timeout
socket.connect(f"tcp://localhost:{port}")
# Send a control message to update the meta path
control_request = ["__UPDATE_META_PATH__", new_meta_path]
request_bytes = msgpack.packb(control_request)
socket.send(request_bytes)
# Wait for response
response_bytes = socket.recv()
response = msgpack.unpackb(response_bytes)
socket.close()
context.term()
# Check if the update was successful
if isinstance(response, list) and len(response) > 0:
return response[0] == "SUCCESS"
return False
except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess):
except Exception as e:
print(f"ERROR: Could not update server meta path on port {port}: {e}")
return False
def _check_cmdline_matches_config(
cmdline: list, port: int, expected_model: str, expected_passages_file: str
) -> bool:
"""Check if command line matches our expected configuration."""
cmdline_str = " ".join(cmdline)
logger.debug(f"Found process on port {port}: {cmdline_str}")
# Check if it's our embedding server
is_embedding_server = any(
server_type in cmdline_str
for server_type in [
"embedding_server",
"leann_backend_diskann.embedding_server",
"leann_backend_hnsw.hnsw_embedding_server",
]
)
if not is_embedding_server:
logger.debug(f"Process on port {port} is not our embedding server")
return False
# Check model name
model_matches = _check_model_in_cmdline(cmdline, expected_model)
# Check passages file if provided
passages_matches = _check_passages_in_cmdline(cmdline, expected_passages_file)
result = model_matches and passages_matches
logger.debug(
f"model_matches: {model_matches}, passages_matches: {passages_matches}, overall: {result}"
)
return result
def _check_model_in_cmdline(cmdline: list, expected_model: str) -> bool:
"""Check if the command line contains the expected model."""
if "--model-name" not in cmdline:
return False
model_idx = cmdline.index("--model-name")
if model_idx + 1 >= len(cmdline):
return False
actual_model = cmdline[model_idx + 1]
return actual_model == expected_model
def _check_passages_in_cmdline(cmdline: list, expected_passages_file: str) -> bool:
"""Check if the command line contains the expected passages file."""
if "--passages-file" not in cmdline:
return False # Expected but not found
passages_idx = cmdline.index("--passages-file")
if passages_idx + 1 >= len(cmdline):
return False
actual_passages = cmdline[passages_idx + 1]
expected_path = Path(expected_passages_file).resolve()
actual_path = Path(actual_passages).resolve()
return actual_path == expected_path
def _find_compatible_port_or_next_available(
start_port: int, model_name: str, passages_file: str, max_attempts: int = 100
) -> tuple[int, bool]:
def _check_server_model(port: int, expected_model: str) -> bool:
"""
Find a port that either has a compatible server or is available.
Returns (port, is_compatible) where is_compatible indicates if we found a matching server.
Check if the existing server on the port is using the correct embedding model.
Returns True if the server has the right model, False otherwise.
"""
for port in range(start_port, start_port + max_attempts):
if not _check_port(port):
# Port is available
return port, False
try:
context = zmq.Context()
socket = context.socket(zmq.REQ)
socket.setsockopt(zmq.RCVTIMEO, 3000) # 3 second timeout
socket.connect(f"tcp://localhost:{port}")
# Port is in use, check if it's compatible
if _check_process_matches_config(port, model_name, passages_file):
logger.info(f"Found compatible server on port {port}")
return port, True
else:
logger.info(f"Port {port} has incompatible server, trying next port...")
# Send a special control message to query the server's model
control_request = ["__QUERY_MODEL__"]
request_bytes = msgpack.packb(control_request)
socket.send(request_bytes)
raise RuntimeError(
f"Could not find compatible or available port in range {start_port}-{start_port + max_attempts}"
)
# Wait for response
response_bytes = socket.recv()
response = msgpack.unpackb(response_bytes)
socket.close()
context.term()
# Check if the response contains the model name and if it matches
if isinstance(response, list) and len(response) > 0:
server_model = response[0]
return server_model == expected_model
return False
except Exception as e:
print(f"WARNING: Could not query server model on port {port}: {e}")
return False
def _update_server_model(port: int, new_model: str) -> bool:
"""
Send a control message to update the server's embedding model.
Returns True if successful, False otherwise.
"""
try:
context = zmq.Context()
socket = context.socket(zmq.REQ)
socket.setsockopt(zmq.RCVTIMEO, 30000) # 30 second timeout for model loading
socket.setsockopt(zmq.SNDTIMEO, 5000) # 5 second timeout for sending
socket.connect(f"tcp://localhost:{port}")
# Send a control message to update the model
control_request = ["__UPDATE_MODEL__", new_model]
request_bytes = msgpack.packb(control_request)
socket.send(request_bytes)
# Wait for response
response_bytes = socket.recv()
response = msgpack.unpackb(response_bytes)
socket.close()
context.term()
# Check if the update was successful
if isinstance(response, list) and len(response) > 0:
return response[0] == "SUCCESS"
return False
except Exception as e:
print(f"ERROR: Could not update server model on port {port}: {e}")
return False
class EmbeddingServerManager:
"""
A simplified manager for embedding server processes that avoids complex update mechanisms.
A generic manager for handling the lifecycle of a backend-specific embedding server process.
"""
def __init__(self, backend_module_name: str):
@@ -184,238 +175,246 @@ class EmbeddingServerManager:
self.backend_module_name = backend_module_name
self.server_process: Optional[subprocess.Popen] = None
self.server_port: Optional[int] = None
self._atexit_registered = False
atexit.register(self.stop_server)
def start_server(
self,
port: int,
model_name: str,
embedding_mode: str = "sentence-transformers",
**kwargs,
) -> tuple[bool, int]:
"""Start the embedding server."""
passages_file = kwargs.get("passages_file")
def start_server(self, port: int, model_name: str, embedding_mode: str = "sentence-transformers", **kwargs) -> bool:
"""
Starts the embedding server process.
# Check if we have a compatible server already running
if self._has_compatible_running_server(model_name, passages_file):
logger.info("Found compatible running server!")
return True, port
Args:
port (int): The ZMQ port for the server.
model_name (str): The name of the embedding model to use.
**kwargs: Additional arguments for the server (e.g., passages_file, distance_metric, enable_warmup).
# For Colab environment, use a different strategy
if _is_colab_environment():
logger.info("Detected Colab environment, using alternative startup strategy")
return self._start_server_colab(port, model_name, embedding_mode, **kwargs)
Returns:
bool: True if the server is started successfully or already running, False otherwise.
"""
if self.server_process and self.server_process.poll() is None:
# Even if we have a running process, check if model/meta path match
if self.server_port is not None:
port_in_use = _check_port(self.server_port)
if port_in_use:
print(
f"INFO: Checking compatibility of existing server process (PID {self.server_process.pid})"
)
# Find a compatible port or next available
actual_port, is_compatible = _find_compatible_port_or_next_available(
port, model_name, passages_file
)
# Check model compatibility
model_matches = _check_server_model(self.server_port, model_name)
if model_matches:
print(
f"✅ Existing server already using correct model: {model_name}"
)
# Still check meta path if provided
passages_file = kwargs.get("passages_file")
if passages_file and str(passages_file).endswith(
".meta.json"
):
meta_matches = _check_server_meta_path(
self.server_port, str(passages_file)
)
if not meta_matches:
print("⚠️ Updating meta path to: {passages_file}")
_update_server_meta_path(
self.server_port, str(passages_file)
)
return True
else:
print(
f"⚠️ Existing server has different model. Attempting to update to: {model_name}"
)
if not _update_server_model(self.server_port, model_name):
print(
"❌ Failed to update existing server model. Restarting server..."
)
self.stop_server()
# Continue to start new server below
else:
print(
f"✅ Successfully updated existing server model to: {model_name}"
)
if is_compatible:
logger.info(f"Found compatible server on port {actual_port}")
return True, actual_port
# Also check meta path if provided
passages_file = kwargs.get("passages_file")
if passages_file and str(passages_file).endswith(
".meta.json"
):
meta_matches = _check_server_meta_path(
self.server_port, str(passages_file)
)
if not meta_matches:
print("⚠️ Updating meta path to: {passages_file}")
_update_server_meta_path(
self.server_port, str(passages_file)
)
# Start a new server
return self._start_new_server(actual_port, model_name, embedding_mode, **kwargs)
return True
else:
# Server process exists but port not responding - restart
print("⚠️ Server process exists but not responding. Restarting...")
self.stop_server()
# Continue to start new server below
else:
# No port stored - restart
print("⚠️ No port information stored. Restarting server...")
self.stop_server()
# Continue to start new server below
def _start_server_colab(
self,
port: int,
model_name: str,
embedding_mode: str = "sentence-transformers",
**kwargs,
) -> tuple[bool, int]:
"""Start server with Colab-specific configuration."""
# Try to find an available port
try:
actual_port = _get_available_port(port)
except RuntimeError:
logger.error("No available ports found")
return False, port
if _check_port(port):
# Port is in use, check if it's using the correct meta file and model
passages_file = kwargs.get("passages_file")
logger.info(f"Starting server on port {actual_port} for Colab environment")
# Use a simpler startup strategy for Colab
command = self._build_server_command(actual_port, model_name, embedding_mode, **kwargs)
try:
# In Colab, we'll use a more direct approach
self._launch_server_process_colab(command, actual_port)
return self._wait_for_server_ready_colab(actual_port)
except Exception as e:
logger.error(f"Failed to start embedding server in Colab: {e}")
return False, actual_port
print(f"INFO: Port {port} is in use. Checking server compatibility...")
def _has_compatible_running_server(
self, model_name: str, passages_file: str
) -> bool:
"""Check if we have a compatible running server."""
if not (
self.server_process
and self.server_process.poll() is None
and self.server_port
):
return False
# Check model compatibility first
model_matches = _check_server_model(port, model_name)
if model_matches:
print(
f"✅ Existing server on port {port} is using correct model: {model_name}"
)
else:
print(
f"⚠️ Existing server on port {port} has different model. Attempting to update to: {model_name}"
)
if not _update_server_model(port, model_name):
raise RuntimeError(
f"❌ Failed to update server model to {model_name}. Consider using a different port."
)
print(f"✅ Successfully updated server model to: {model_name}")
if _check_process_matches_config(self.server_port, model_name, passages_file):
logger.info(
f"Existing server process (PID {self.server_process.pid}) is compatible"
)
# Check meta path compatibility if provided
if passages_file and str(passages_file).endswith(".meta.json"):
meta_matches = _check_server_meta_path(port, str(passages_file))
if not meta_matches:
print(
f"⚠️ Existing server on port {port} has different meta path. Attempting to update..."
)
if not _update_server_meta_path(port, str(passages_file)):
raise RuntimeError(
"❌ Failed to update server meta path. This may cause data synchronization issues."
)
print(
f"✅ Successfully updated server meta path to: {passages_file}"
)
else:
print(
f"✅ Existing server on port {port} is using correct meta path: {passages_file}"
)
print(f"✅ Server on port {port} is compatible and ready to use.")
return True
logger.info(
"Existing server process is incompatible. Should start a new server."
print(
f"INFO: Starting session-level embedding server for '{self.backend_module_name}'..."
)
return False
def _start_new_server(
self, port: int, model_name: str, embedding_mode: str, **kwargs
) -> tuple[bool, int]:
"""Start a new embedding server on the given port."""
logger.info(f"Starting embedding server on port {port}...")
command = self._build_server_command(port, model_name, embedding_mode, **kwargs)
try:
self._launch_server_process(command, port)
return self._wait_for_server_ready(port)
command = [
sys.executable,
"-m",
self.backend_module_name,
"--zmq-port",
str(port),
"--model-name",
model_name,
]
# Add extra arguments for specific backends
if "passages_file" in kwargs and kwargs["passages_file"]:
command.extend(["--passages-file", str(kwargs["passages_file"])])
# if "distance_metric" in kwargs and kwargs["distance_metric"]:
# command.extend(["--distance-metric", kwargs["distance_metric"]])
if embedding_mode != "sentence-transformers":
command.extend(["--embedding-mode", embedding_mode])
if "enable_warmup" in kwargs and not kwargs["enable_warmup"]:
command.extend(["--disable-warmup"])
project_root = Path(__file__).parent.parent.parent.parent.parent
print(f"INFO: Running command from project root: {project_root}")
print(f"INFO: Command: {' '.join(command)}") # Debug: show actual command
self.server_process = subprocess.Popen(
command,
cwd=project_root,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT, # Merge stderr into stdout for easier monitoring
text=True,
encoding="utf-8",
bufsize=1, # Line buffered
universal_newlines=True,
)
self.server_port = port
print(f"INFO: Server process started with PID: {self.server_process.pid}")
max_wait, wait_interval = 120, 0.5
for _ in range(int(max_wait / wait_interval)):
if _check_port(port):
print("✅ Embedding server is up and ready for this session.")
log_thread = threading.Thread(target=self._log_monitor, daemon=True)
log_thread.start()
return True
if self.server_process.poll() is not None:
print(
"❌ ERROR: Server process terminated unexpectedly during startup."
)
self._print_recent_output()
return False
time.sleep(wait_interval)
print(
f"❌ ERROR: Server process failed to start listening within {max_wait} seconds."
)
self.stop_server()
return False
except Exception as e:
logger.error(f"Failed to start embedding server: {e}")
return False, port
print(f"❌ ERROR: Failed to start embedding server process: {e}")
return False
def _build_server_command(
self, port: int, model_name: str, embedding_mode: str, **kwargs
) -> list:
"""Build the command to start the embedding server."""
command = [
sys.executable,
"-m",
self.backend_module_name,
"--zmq-port",
str(port),
"--model-name",
model_name,
]
def _print_recent_output(self):
"""Print any recent output from the server process."""
if not self.server_process or not self.server_process.stdout:
return
try:
# Read any available output
if kwargs.get("passages_file"):
# Convert to absolute path to ensure subprocess can find the file
passages_file = Path(kwargs["passages_file"]).resolve()
command.extend(["--passages-file", str(passages_file)])
if embedding_mode != "sentence-transformers":
command.extend(["--embedding-mode", embedding_mode])
if select.select([self.server_process.stdout], [], [], 0)[0]:
output = self.server_process.stdout.read()
if output:
print(f"[{self.backend_module_name} OUTPUT]: {output}")
except Exception as e:
print(f"Error reading server output: {e}")
return command
def _launch_server_process(self, command: list, port: int) -> None:
"""Launch the server process."""
project_root = Path(__file__).parent.parent.parent.parent.parent
logger.info(f"Command: {' '.join(command)}")
# Let server output go directly to console
# The server will respect LEANN_LOG_LEVEL environment variable
self.server_process = subprocess.Popen(
command,
cwd=project_root,
stdout=None, # Direct to console
stderr=None, # Direct to console
)
self.server_port = port
logger.info(f"Server process started with PID: {self.server_process.pid}")
# Register atexit callback only when we actually start a process
if not self._atexit_registered:
# Use a lambda to avoid issues with bound methods
atexit.register(lambda: self.stop_server() if self.server_process else None)
self._atexit_registered = True
def _wait_for_server_ready(self, port: int) -> tuple[bool, int]:
"""Wait for the server to be ready."""
max_wait, wait_interval = 120, 0.5
for _ in range(int(max_wait / wait_interval)):
if _check_port(port):
logger.info("Embedding server is ready!")
return True, port
if self.server_process and self.server_process.poll() is not None:
logger.error("Server terminated during startup.")
return False, port
time.sleep(wait_interval)
logger.error(f"Server failed to start within {max_wait} seconds.")
self.stop_server()
return False, port
def _log_monitor(self):
"""Monitors and prints the server's stdout and stderr."""
if not self.server_process:
return
try:
if self.server_process.stdout:
while True:
line = self.server_process.stdout.readline()
if not line:
break
print(
f"[{self.backend_module_name} LOG]: {line.strip()}", flush=True
)
except Exception as e:
print(f"Log monitor error: {e}")
def stop_server(self):
"""Stops the embedding server process if it's running."""
if not self.server_process:
return
if self.server_process.poll() is not None:
# Process already terminated
self.server_process = None
return
logger.info(
f"Terminating server process (PID: {self.server_process.pid}) for backend {self.backend_module_name}..."
)
self.server_process.terminate()
try:
self.server_process.wait(timeout=5)
logger.info(f"Server process {self.server_process.pid} terminated.")
except subprocess.TimeoutExpired:
logger.warning(
f"Server process {self.server_process.pid} did not terminate gracefully, killing it."
if self.server_process and self.server_process.poll() is None:
print(
f"INFO: Terminating session server process (PID: {self.server_process.pid})..."
)
self.server_process.kill()
# Clean up process resources to prevent resource tracker warnings
try:
self.server_process.wait() # Ensure process is fully cleaned up
except Exception:
pass
self.server_process.terminate()
try:
self.server_process.wait(timeout=5)
print("INFO: Server process terminated.")
except subprocess.TimeoutExpired:
print(
"WARNING: Server process did not terminate gracefully, killing it."
)
self.server_process.kill()
self.server_process = None
def _launch_server_process_colab(self, command: list, port: int) -> None:
"""Launch the server process with Colab-specific settings."""
logger.info(f"Colab Command: {' '.join(command)}")
# In Colab, we need to be more careful about process management
self.server_process = subprocess.Popen(
command,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
)
self.server_port = port
logger.info(f"Colab server process started with PID: {self.server_process.pid}")
# Register atexit callback
if not self._atexit_registered:
atexit.register(lambda: self.stop_server() if self.server_process else None)
self._atexit_registered = True
def _wait_for_server_ready_colab(self, port: int) -> tuple[bool, int]:
"""Wait for the server to be ready with Colab-specific timeout."""
max_wait, wait_interval = 30, 0.5 # Shorter timeout for Colab
for _ in range(int(max_wait / wait_interval)):
if _check_port(port):
logger.info("Colab embedding server is ready!")
return True, port
if self.server_process and self.server_process.poll() is not None:
# Check for error output
stdout, stderr = self.server_process.communicate()
logger.error(f"Colab server terminated during startup.")
logger.error(f"stdout: {stdout}")
logger.error(f"stderr: {stderr}")
return False, port
time.sleep(wait_interval)
logger.error(f"Colab server failed to start within {max_wait} seconds.")
self.stop_server()
return False, port

View File

@@ -1,6 +1,6 @@
from abc import ABC, abstractmethod
import numpy as np
from typing import Dict, Any, List, Literal, Optional
from typing import Dict, Any, List, Literal
class LeannBackendBuilderInterface(ABC):
@@ -34,13 +34,6 @@ class LeannBackendSearcherInterface(ABC):
"""
pass
@abstractmethod
def _ensure_server_running(
self, passages_source_file: str, port: Optional[int], **kwargs
) -> int:
"""Ensure server is running"""
pass
@abstractmethod
def search(
self,
@@ -51,7 +44,7 @@ class LeannBackendSearcherInterface(ABC):
prune_ratio: float = 0.0,
recompute_embeddings: bool = False,
pruning_strategy: Literal["global", "local", "proportional"] = "global",
zmq_port: Optional[int] = None,
zmq_port: int = 5557,
**kwargs,
) -> Dict[str, Any]:
"""Search for nearest neighbors
@@ -64,7 +57,7 @@ class LeannBackendSearcherInterface(ABC):
prune_ratio: Ratio of neighbors to prune via approximate distance (0.0-1.0)
recompute_embeddings: Whether to fetch fresh embeddings from server vs use stored PQ codes
pruning_strategy: PQ candidate selection strategy - "global" (default), "local", or "proportional"
zmq_port: ZMQ port for embedding server communication. Must be provided if recompute_embeddings is True.
zmq_port: ZMQ port for embedding server communication
**kwargs: Backend-specific parameters
Returns:
@@ -74,10 +67,7 @@ class LeannBackendSearcherInterface(ABC):
@abstractmethod
def compute_query_embedding(
self,
query: str,
use_server_if_available: bool = True,
zmq_port: Optional[int] = None,
self, query: str, zmq_port: int = 5557, use_server_if_available: bool = True
) -> np.ndarray:
"""Compute embedding for a query string

View File

@@ -7,37 +7,30 @@ import importlib.metadata
if TYPE_CHECKING:
from leann.interface import LeannBackendFactoryInterface
BACKEND_REGISTRY: Dict[str, "LeannBackendFactoryInterface"] = {}
BACKEND_REGISTRY: Dict[str, 'LeannBackendFactoryInterface'] = {}
def register_backend(name: str):
"""A decorator to register a new backend class."""
def decorator(cls):
print(f"INFO: Registering backend '{name}'")
BACKEND_REGISTRY[name] = cls
return cls
return decorator
def autodiscover_backends():
"""Automatically discovers and imports all 'leann-backend-*' packages."""
# print("INFO: Starting backend auto-discovery...")
print("INFO: Starting backend auto-discovery...")
discovered_backends = []
for dist in importlib.metadata.distributions():
dist_name = dist.metadata["name"]
if dist_name.startswith("leann-backend-"):
backend_module_name = dist_name.replace("-", "_")
dist_name = dist.metadata['name']
if dist_name.startswith('leann-backend-'):
backend_module_name = dist_name.replace('-', '_')
discovered_backends.append(backend_module_name)
for backend_module_name in sorted(
discovered_backends
): # sort for deterministic loading
for backend_module_name in sorted(discovered_backends): # sort for deterministic loading
try:
importlib.import_module(backend_module_name)
# Registration message is printed by the decorator
except ImportError as e:
# print(f"WARN: Could not import backend module '{backend_module_name}': {e}")
pass
# print("INFO: Backend auto-discovery finished.")
print(f"WARN: Could not import backend module '{backend_module_name}': {e}")
print("INFO: Backend auto-discovery finished.")

View File

@@ -1,7 +1,8 @@
import json
import pickle
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Dict, Any, Literal, Optional
from typing import Dict, Any, Literal
import numpy as np
@@ -42,10 +43,10 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
"WARNING: embedding_model not found in meta.json. Recompute will fail."
)
self.embedding_mode = self.meta.get("embedding_mode", "sentence-transformers")
self.label_map = self._load_label_map()
self.embedding_server_manager = EmbeddingServerManager(
backend_module_name=backend_module_name,
backend_module_name=backend_module_name
)
def _load_meta(self) -> Dict[str, Any]:
@@ -57,9 +58,17 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
with open(meta_path, "r", encoding="utf-8") as f:
return json.load(f)
def _load_label_map(self) -> Dict[int, str]:
"""Loads the mapping from integer IDs to string IDs."""
label_map_file = self.index_dir / "leann.labels.map"
if not label_map_file.exists():
raise FileNotFoundError(f"Label map file not found: {label_map_file}")
with open(label_map_file, "rb") as f:
return pickle.load(f)
def _ensure_server_running(
self, passages_source_file: str, port: int, **kwargs
) -> int:
) -> None:
"""
Ensures the embedding server is running if recompute is needed.
This is a helper for subclasses.
@@ -69,26 +78,21 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
"Cannot use recompute mode without 'embedding_model' in meta.json."
)
server_started, actual_port = self.embedding_server_manager.start_server(
embedding_mode = self.meta.get("embedding_mode", "sentence-transformers")
server_started = self.embedding_server_manager.start_server(
port=port,
model_name=self.embedding_model,
embedding_mode=self.embedding_mode,
passages_file=passages_source_file,
distance_metric=kwargs.get("distance_metric"),
embedding_mode=embedding_mode,
enable_warmup=kwargs.get("enable_warmup", False),
)
if not server_started:
raise RuntimeError(
f"Failed to start embedding server on port {actual_port}"
)
return actual_port
raise RuntimeError(f"Failed to start embedding server on port {port}")
def compute_query_embedding(
self,
query: str,
use_server_if_available: bool = True,
zmq_port: int = 5557,
self, query: str, zmq_port: int = 5557, use_server_if_available: bool = True
) -> np.ndarray:
"""
Compute embedding for a query string.
@@ -102,21 +106,12 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
Query embedding as numpy array
"""
# Try to use embedding server if available and requested
if use_server_if_available:
if (
use_server_if_available
and self.embedding_server_manager
and self.embedding_server_manager.server_process
):
try:
# TODO: Maybe we can directly use this port here?
# For this internal method, it's ok to assume that the server is running
# on that port?
# Ensure we have a server with passages_file for compatibility
passages_source_file = (
self.index_dir / f"{self.index_path.name}.meta.json"
)
# Convert to absolute path to ensure server can find it
zmq_port = self._ensure_server_running(
str(passages_source_file.resolve()), zmq_port
)
return self._compute_embedding_via_server([query], zmq_port)[
0:1
] # Return (1, D) shape
@@ -125,7 +120,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
print("⏭️ Falling back to direct model loading...")
# Fallback to direct computation
from .embedding_compute import compute_embeddings
from .api import compute_embeddings
embedding_mode = self.meta.get("embedding_mode", "sentence-transformers")
return compute_embeddings([query], self.embedding_model, embedding_mode)
@@ -172,7 +167,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
prune_ratio: float = 0.0,
recompute_embeddings: bool = False,
pruning_strategy: Literal["global", "local", "proportional"] = "global",
zmq_port: Optional[int] = None,
zmq_port: int = 5557,
**kwargs,
) -> Dict[str, Any]:
"""
@@ -186,7 +181,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
prune_ratio: Ratio of neighbors to prune via approximate distance (0.0-1.0)
recompute_embeddings: Whether to fetch fresh embeddings from server vs use stored PQ codes
pruning_strategy: PQ candidate selection strategy - "global" (default), "local", or "proportional"
zmq_port: ZMQ port for embedding server communication. Must be provided if recompute_embeddings is True.
zmq_port: ZMQ port for embedding server communication
**kwargs: Backend-specific parameters (e.g., batch_size, dedup_node_dis, etc.)
Returns:

View File

@@ -1,40 +0,0 @@
# LEANN - The smallest vector index in the world
LEANN is a revolutionary vector database that democratizes personal AI. Transform your laptop into a powerful RAG system that can index and search through millions of documents while using **97% less storage** than traditional solutions **without accuracy loss**.
## Installation
```bash
# Default installation (HNSW backend, recommended)
uv pip install leann
# With DiskANN backend (for large-scale deployments)
uv pip install leann[diskann]
```
## Quick Start
```python
from leann import LeannBuilder, LeannSearcher, LeannChat
# Build an index
builder = LeannBuilder(backend_name="hnsw")
builder.add_text("LEANN saves 97% storage compared to traditional vector databases.")
builder.build_index("my_index.leann")
# Search
searcher = LeannSearcher("my_index.leann")
results = searcher.search("storage savings", top_k=3)
# Chat with your data
chat = LeannChat("my_index.leann", llm_config={"type": "ollama", "model": "llama3.2:1b"})
response = chat.ask("How much storage does LEANN save?")
```
## Documentation
For full documentation, visit [https://leann.readthedocs.io](https://leann.readthedocs.io)
## License
MIT License

View File

@@ -1,12 +0,0 @@
"""
LEANN - Low-storage Embedding Approximation for Neural Networks
A revolutionary vector database that democratizes personal AI.
"""
__version__ = "0.1.0"
# Re-export main API from leann-core
from leann_core import LeannBuilder, LeannSearcher, LeannChat
__all__ = ["LeannBuilder", "LeannSearcher", "LeannChat"]

View File

@@ -1,42 +0,0 @@
[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"
[project]
name = "leann"
version = "0.1.13"
description = "LEANN - The smallest vector index in the world. RAG Everything with LEANN!"
readme = "README.md"
requires-python = ">=3.9"
license = { text = "MIT" }
authors = [
{ name = "LEANN Team" }
]
keywords = ["vector-database", "rag", "embeddings", "search", "ai"]
classifiers = [
"Development Status :: 4 - Beta",
"Intended Audience :: Developers",
"License :: OSI Approved :: MIT License",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
]
# Default installation: core + hnsw
dependencies = [
"leann-core>=0.1.0",
"leann-backend-hnsw>=0.1.0",
]
[project.optional-dependencies]
diskann = [
"leann-backend-diskann>=0.1.0",
]
[project.urls]
Homepage = "https://github.com/yourusername/leann"
Documentation = "https://leann.readthedocs.io"
Repository = "https://github.com/yourusername/leann"
Issues = "https://github.com/yourusername/leann/issues"

View File

@@ -9,6 +9,7 @@ requires-python = ">=3.10"
dependencies = [
"leann-core",
"leann-backend-diskann",
"leann-backend-hnsw",
"numpy>=1.26.0",
"torch",
@@ -25,24 +26,16 @@ dependencies = [
"requests>=2.25.0",
"sentence-transformers>=2.2.0",
"openai>=1.0.0",
# PDF parsing dependencies - essential for document processing
"PyPDF2>=3.0.0",
"pdfplumber>=0.11.0",
"pymupdf>=1.26.0",
"pypdfium2>=4.30.0",
# LlamaIndex core and readers - updated versions
"llama-index>=0.12.44",
"llama-index-readers-file>=0.4.0", # Essential for PDF parsing
"llama-index-readers-docling",
"llama-index-node-parser-docling",
"llama-index-vector-stores-faiss>=0.4.0",
"llama-index-embeddings-huggingface>=0.5.5",
# Other dependencies
"ipykernel==6.29.5",
"msgpack>=1.1.1",
"mlx>=0.26.3; sys_platform == 'darwin'",
"mlx-lm>=0.26.0; sys_platform == 'darwin'",
"psutil>=5.8.0",
"llama-index-vector-stores-faiss>=0.4.0",
"llama-index-embeddings-huggingface>=0.5.5",
"mlx>=0.26.3",
"mlx-lm>=0.26.0",
]
[project.optional-dependencies]
@@ -55,18 +48,6 @@ dev = [
"huggingface-hub>=0.20.0",
]
diskann = [
"leann-backend-diskann",
]
# Add a new optional dependency group for document processing
documents = [
"beautifulsoup4>=4.13.0", # For HTML parsing
"python-docx>=0.8.11", # For Word documents
"openpyxl>=3.1.0", # For Excel files
"pandas>=2.2.0", # For data processing
]
[tool.setuptools]
py-modules = []

View File

@@ -0,0 +1,12 @@
import faiss
hnsw_index = faiss.read_index("/opt/dlami/nvme/scaling_out/indices/rpj_wiki/facebook/contriever-msmarco/hnsw/hnsw_IP_M30_efC128.index", faiss.IO_FLAG_ONDISK_SAME_DIR)
# print total number of nodes
print(hnsw_index.ntotal)
# print stats of the graph
print(hnsw_index.hnsw.print_neighbor_stats(0))
# save_degree_distribution
hnsw_index.hnsw.save_degree_distribution(0, "degree_distribution_HNSW_M30.txt")

View File

@@ -0,0 +1,11 @@
import faiss
nsg_index = faiss.read_index("/opt/dlami/nvme/scaling_out/indices/rpj_wiki/facebook/contriever-msmarco/nsg_R16.index", faiss.IO_FLAG_ONDISK_SAME_DIR)
# print total number of nodes
print(nsg_index.ntotal)
# print stats of the graph
print(nsg_index.nsg.print_neighbor_stats(0))
# save degree distribution
nsg_index.nsg.save_degree_distribution("degree_distribution_NSG_R60.txt")

63
research/micro/bnbtest.py Normal file
View File

@@ -0,0 +1,63 @@
import torch
import torch.nn as nn
import time
# import bitsandbytes as bnb
from bitsandbytes.nn import Linear8bitLt
# set default to half
import torch
torch.set_default_dtype(torch.float16)
M = 2048
N = 2048
bsz = 2048
import torch_int
from torch_int.nn.linear import W8A8BFP32OFP32Linear, W8A8B8O8Linear, W8A8B8O8LinearReLU
fp16_model = nn.Sequential(
nn.Linear(M, N),
# nn.Linear(2048, 2048)
)
int8_model = nn.Sequential(
Linear8bitLt(M, N, has_fp16_weights=False),
# Linear8bitLt(2048, 2048, has_fp16_weights=False)
)
int8_model.load_state_dict(fp16_model.state_dict())
int8_model = int8_model.to(0) # Quantization happens here
fp16_model = fp16_model.to(0) # Move fp16 model to GPU as well
# Create random input tensor
input_tensor = torch.randn(bsz, M, device=0) # Batch of 1000 vectors
# Speed test function
def speed_test(model, input_tensor, name, num_iterations=100):
# Warmup
for _ in range(10):
_ = model(input_tensor)
# Actual timing
torch.cuda.synchronize()
start_time = time.time()
for _ in range(num_iterations):
_ = model(input_tensor)
torch.cuda.synchronize()
end_time = time.time()
avg_time = (end_time - start_time) / num_iterations
print(f"{name} model: {avg_time:.6f} seconds per iteration")
return avg_time
# Run speed tests
with torch.no_grad(): # Disable gradient calculation for inference
fp16_time = speed_test(fp16_model, input_tensor, "FP16")
int8_time = speed_test(int8_model, input_tensor, "INT8")
# Calculate speedup
speedup = fp16_time / int8_time
print(f"INT8 is {speedup:.2f}x faster than FP16")

View File

@@ -0,0 +1,89 @@
n,d,seqlen,bs,latency,h,flop,io,intensity,throughput,series
3,256,256,2048,0.009623501679245285,768,618475290624,167.48502132816208,3692720015.912285,64267177503366.266,dense
3,256,256,1024,0.004853848615384615,768,309237645312,166.15392854317415,1861151572.059558,63709783682138.234,dense
3,256,256,512,0.0024687246971962615,768,154618822656,163.57953256539062,945221081.3366361,62631051097597.516,dense
3,256,256,256,0.0012845360838052097,768,77309411328,157.64931990085577,490388486.1451936,60184694149645.54,dense
3,256,256,128,0.0006901147179878049,768,38654705664,147.57393422494675,261934506.70684624,56012000116019.945,dense
3,256,256,64,0.0003363830693015702,768,19327352832,153.1328437752606,126212981.84970059,57456378146882.51,dense
3,256,256,32,0.00018671159748991485,768,9663676416,141.10249365427362,68486928.65540518,51757237075334.75,dense
3,256,256,16,0.00012353640857142858,768,4831838208,111.40488993609125,43371868.24359184,39112665358133.98,dense
3,256,256,8,9.774760007849294e-05,768,2415919104,76.43260800265766,31608487.09906635,24715891766754.14,dense
3,256,256,4,6.672271167474822e-05,768,1207959552,64.82614227498455,18633833.660438772,18104173551704.773,dense
3,256,256,2,4.9758770289855074e-05,768,603979776,55.317122669351576,10918495.880745342,12138157202874.861,dense
3,256,1,2048,9.785507940251571e-05,768,2415919104,76.34865809334705,31643242.518371396,24688745017132.86,dense
3,256,1,1024,6.692813470149253e-05,768,1207959552,64.62717090938949,18691202.70936228,18048606275785.867,dense
3,256,1,512,4.9680950036205655e-05,768,603979776,55.40377142534654,10901419.893658841,12157170415618.898,dense
3,256,1,256,4.2781118741058655e-05,768,301989888,45.95672244805227,6571179.83862661,7058952568020.829,dense
3,256,1,128,5.0662328255350016e-05,768,150994944,31.046026784880404,4863583.512513602,2980418571348.519,dense
3,256,1,64,4.475009253945481e-05,768,75497472,30.75426042497223,2454862.219307235,1687090857598.4766,dense
3,256,1,32,4.51682671454219e-05,768,37748736,28.29313765537115,1334201.1218340008,835735758435.5786,dense
3,256,1,16,5.03585186661834e-05,768,18874368,24.401035466223117,773506.846712577,374799904761.1871,dense
3,256,1,8,5.023459565217391e-05,768,9437184,23.972005435021096,393675.19858030166,187862246674.45105,dense
3,256,1,4,5.053219391083726e-05,768,4718592,23.58765586356967,200044.97383259286,93377936614.54384,dense
3,256,1,2,4.4607398995335484e-05,768,2359296,26.58285456464288,88752.54515134107,52890239133.797226,dense
12,256,256,2048,0.14480779847058822,3072,9895604649984,44.620009282941716,221775046868.20184,68336130750540.26,dense
12,256,256,1024,0.07254347629166667,3072,4947802324992,44.664248332585096,110777691547.58836,68204648824643.82,dense
12,256,256,512,0.036310761444444443,3072,2473901162496,44.876147984203506,55127306456.13385,68131349056975.164,dense
12,256,256,256,0.01821551906896552,3072,1236950581248,45.24607467289738,27338295977.947884,67906414116709.98,dense
12,256,256,128,0.009229417903030302,3072,618475290624,45.67217092440895,13541622351.335684,67011299859001.46,dense
12,256,256,64,0.004754550595394737,3072,309237645312,46.31372736116993,6677019167.566916,65040352207320.695,dense
12,256,256,32,0.002405752659340659,3072,154618822656,49.68826015254682,3111777755.5766335,64270456921525.82,dense
12,256,256,16,0.0012287219045005488,3072,77309411328,56.323579604557374,1372594069.3184311,62918558743709.18,dense
12,256,256,8,0.0006206816149425287,3072,38654705664,70.95456179103653,544781120.315271,62277832520589.78,dense
12,256,256,4,0.0003875502697142857,3072,19327352832,81.16954743236613,238110885.71245712,49870569942445.75,dense
12,256,256,2,0.00027502018627941914,3072,9663676416,91.50537035282076,105607751.53129694,35138062215483.168,dense
12,256,1,2048,0.0006202853873290136,3072,38654705664,70.99988634205897,544433345.6784943,62317614526515.766,dense
12,256,1,1024,0.00038721467732724153,3072,19327352832,81.2398957010995,237904697.74985722,49913791918755.53,dense
12,256,1,512,0.000274364799,3072,9663676416,91.72395326121995,105356082.81599998,35221998052308.45,dense
12,256,1,256,0.00012488918589482266,3072,4831838208,176.31707535146046,27404255.647778228,38689003962834.75,dense
12,256,1,128,8.976711102514506e-05,3072,2415919104,227.78088507574267,10606329.425740216,26913187652026.21,dense
12,256,1,64,8.715176287471176e-05,3072,1207959552,225.59268282689945,5354604.31102229,13860414432884.701,dense
12,256,1,32,8.523013435114503e-05,3072,603979776,226.06539514085782,2671703.8033338524,7086458100741.991,dense
12,256,1,16,7.901561645904116e-05,3072,301989888,241.35704882952732,1251216.3595988373,3821901309300.556,dense
12,256,1,8,7.827949114210329e-05,3072,150994944,242.37091635608994,622991.1833900034,1928920867994.581,dense
12,256,1,4,7.779445951035782e-05,3072,75497472,243.25022783249054,310369.58391664835,970473636235.5986,dense
12,256,1,2,7.758845406626506e-05,3072,37748736,243.57933441822672,154975.11761480253,486525172518.07056,dense
3,256,256,2048,0.00507974918466899,768,206158430208,475.59810852303485,433471930.42508715,40584371927298.98,qk_init
3,256,256,1024,0.0025616677649325623,768,103079215104,471.5519977009198,218595649.27424532,40239103803811.82,qk_init
3,256,256,512,0.0013029336670480549,768,51539607552,463.55374128015677,111183672.92143403,39556585922573.38,qk_init
3,256,256,256,0.0006738189029345373,768,25769803776,448.1766342333362,57499213.050413854,38244406121244.69,qk_init
3,256,256,128,0.000358254672959467,768,12884901888,421.47375986100144,30571065.425874516,35965760841472.125,qk_init
3,256,256,64,0.0002007051105022831,768,6442450944,376.1611839930762,17126836.096194826,32099087700742.5,qk_init
3,256,256,32,0.00012189697230142565,768,3221225472,309.6773881032524,10401874.969721656,26425803784810.87,qk_init
3,256,256,16,8.453561698040722e-05,768,1610612736,223.2711923587723,7213705.982328083,19052475081281.902,qk_init
3,256,256,8,6.407660705009276e-05,768,805306368,147.2797083750448,5467870.468274581,12567868448003.822,qk_init
3,256,256,4,5.036328747284576e-05,768,402653184,93.69110391262903,4297667.197682838,7994974200544.344,qk_init
3,256,256,2,4.5488761135057476e-05,768,201326592,51.865470527877875,3881707.616858238,4425853485045.578,qk_init
12,256,256,2048,0.020202365999999996,3072,824633720832,478.3437947812648,1723935231.9999998,40818670488001.266,qk_init
12,256,256,1024,0.010124155888157895,3072,412316860416,477.2583770318811,863927969.1228071,40726048173387.19,qk_init
12,256,256,512,0.005085633937062937,3072,206158430208,475.04777848703077,433974095.9627039,40537410430893.29,qk_init
12,256,256,256,0.0025654916853281853,3072,103079215104,470.84913933193053,218921957.14800516,40179126556324.74,qk_init
12,256,256,128,0.0013045765704467354,3072,51539607552,462.9699702434292,111323867.34478809,39506770794105.96,qk_init
12,256,256,64,0.0006742801519939804,3072,25769803776,447.87005387442576,57538572.970153,38218244597284.33,qk_init
12,256,256,32,0.00035831976790671853,3072,12884901888,421.3971919051604,30576620.194706645,35959227042573.69,qk_init
12,256,256,16,0.0002005369068918302,3072,6442450944,376.4766953382971,17112482.721436176,32126011335534.68,qk_init
12,256,256,8,0.00012179187250509165,3072,3221225472,309.94462293386505,10392906.453767821,26448607823689.82,qk_init
12,256,256,4,8.452507263643351e-05,3072,1610612736,223.2990450204527,7212806.198308992,19054851841745.297,qk_init
12,256,256,2,6.412381767545489e-05,3072,805306368,147.17127491946468,5471899.108305484,12558615459794.32,qk_init
3,256,256,2048,0.0016183739398395718,768,805306368,811597824.0,0.9922480620155039,1265467.7325087283,qk_ar
3,256,256,1024,0.0008322699728813558,768,402653184,405798912.0,0.9922480620155039,1230369.9921491416,qk_ar
3,256,256,512,0.00043886859397590365,768,201326592,202899456.0,0.9922480620155039,1166636.2255762408,qk_ar
3,256,256,256,0.00024185948322147648,768,100663296,101449728.0,0.9922480620155039,1058465.8355760013,qk_ar
3,256,256,128,0.00014308985100166944,768,50331648,50724864.0,0.9922480620155039,894542.82818777,qk_ar
3,256,256,64,9.382939365815932e-05,768,25165824,25362432.0,0.9922480620155039,682089.028872613,qk_ar
3,256,256,32,6.856070612244899e-05,768,12582912,12681216.0,0.9922480620155039,466739.6503012703,qk_ar
3,256,256,16,5.452260553129549e-05,768,6291456,6340608.0,0.9922480620155039,293456.26174846216,qk_ar
3,256,256,8,4.608557533261417e-05,768,3145728,3170304.0,0.9922480620155039,173590.1080166944,qk_ar
3,256,256,4,4.386146957766642e-05,768,1572864,1585152.0,0.9922480620155039,91196.21477609445,qk_ar
3,256,256,2,4.330941094420601e-05,768,786432,792576.0,0.9922480620155039,46179.33969539622,qk_ar
12,256,256,2048,0.006347041645299144,3072,3221225472,3246391296.0,0.9922480620155039,322670.011392918,qk_ar
12,256,256,1024,0.0031943104467592586,3072,1610612736,1623195648.0,0.9922480620155039,320569.96872013,qk_ar
12,256,256,512,0.0016183416350267381,3072,805306368,811597824.0,0.9922480620155039,316373.2483416833,qk_ar
12,256,256,256,0.0008325934893977947,3072,402653184,405798912.0,0.9922480620155039,307472.9784221131,qk_ar
12,256,256,128,0.0004389725746987952,3072,201326592,202899456.0,0.9922480620155039,291589.9702568624,qk_ar
12,256,256,64,0.00024191767449664432,3072,100663296,101449728.0,0.9922480620155039,264552.8076159138,qk_ar
12,256,256,32,0.0001431546143572621,3072,50331648,50724864.0,0.9922480620155039,223534.53392804778,qk_ar
12,256,256,16,9.404283597678917e-05,3072,25165824,25362432.0,0.9922480620155039,170135.23501087292,qk_ar
12,256,256,8,6.855550037091989e-05,3072,12582912,12681216.0,0.9922480620155039,116693.773026467,qk_ar
12,256,256,4,5.4802094978165945e-05,3072,6291456,6340608.0,0.9922480620155039,72989.91036006316,qk_ar
12,256,256,2,4.608510707869206e-05,3072,3145728,3170304.0,0.9922480620155039,43397.96795057727,qk_ar
1 n d seqlen bs latency h flop io intensity throughput series
2 3 256 256 2048 0.009623501679245285 768 618475290624 167.48502132816208 3692720015.912285 64267177503366.266 dense
3 3 256 256 1024 0.004853848615384615 768 309237645312 166.15392854317415 1861151572.059558 63709783682138.234 dense
4 3 256 256 512 0.0024687246971962615 768 154618822656 163.57953256539062 945221081.3366361 62631051097597.516 dense
5 3 256 256 256 0.0012845360838052097 768 77309411328 157.64931990085577 490388486.1451936 60184694149645.54 dense
6 3 256 256 128 0.0006901147179878049 768 38654705664 147.57393422494675 261934506.70684624 56012000116019.945 dense
7 3 256 256 64 0.0003363830693015702 768 19327352832 153.1328437752606 126212981.84970059 57456378146882.51 dense
8 3 256 256 32 0.00018671159748991485 768 9663676416 141.10249365427362 68486928.65540518 51757237075334.75 dense
9 3 256 256 16 0.00012353640857142858 768 4831838208 111.40488993609125 43371868.24359184 39112665358133.98 dense
10 3 256 256 8 9.774760007849294e-05 768 2415919104 76.43260800265766 31608487.09906635 24715891766754.14 dense
11 3 256 256 4 6.672271167474822e-05 768 1207959552 64.82614227498455 18633833.660438772 18104173551704.773 dense
12 3 256 256 2 4.9758770289855074e-05 768 603979776 55.317122669351576 10918495.880745342 12138157202874.861 dense
13 3 256 1 2048 9.785507940251571e-05 768 2415919104 76.34865809334705 31643242.518371396 24688745017132.86 dense
14 3 256 1 1024 6.692813470149253e-05 768 1207959552 64.62717090938949 18691202.70936228 18048606275785.867 dense
15 3 256 1 512 4.9680950036205655e-05 768 603979776 55.40377142534654 10901419.893658841 12157170415618.898 dense
16 3 256 1 256 4.2781118741058655e-05 768 301989888 45.95672244805227 6571179.83862661 7058952568020.829 dense
17 3 256 1 128 5.0662328255350016e-05 768 150994944 31.046026784880404 4863583.512513602 2980418571348.519 dense
18 3 256 1 64 4.475009253945481e-05 768 75497472 30.75426042497223 2454862.219307235 1687090857598.4766 dense
19 3 256 1 32 4.51682671454219e-05 768 37748736 28.29313765537115 1334201.1218340008 835735758435.5786 dense
20 3 256 1 16 5.03585186661834e-05 768 18874368 24.401035466223117 773506.846712577 374799904761.1871 dense
21 3 256 1 8 5.023459565217391e-05 768 9437184 23.972005435021096 393675.19858030166 187862246674.45105 dense
22 3 256 1 4 5.053219391083726e-05 768 4718592 23.58765586356967 200044.97383259286 93377936614.54384 dense
23 3 256 1 2 4.4607398995335484e-05 768 2359296 26.58285456464288 88752.54515134107 52890239133.797226 dense
24 12 256 256 2048 0.14480779847058822 3072 9895604649984 44.620009282941716 221775046868.20184 68336130750540.26 dense
25 12 256 256 1024 0.07254347629166667 3072 4947802324992 44.664248332585096 110777691547.58836 68204648824643.82 dense
26 12 256 256 512 0.036310761444444443 3072 2473901162496 44.876147984203506 55127306456.13385 68131349056975.164 dense
27 12 256 256 256 0.01821551906896552 3072 1236950581248 45.24607467289738 27338295977.947884 67906414116709.98 dense
28 12 256 256 128 0.009229417903030302 3072 618475290624 45.67217092440895 13541622351.335684 67011299859001.46 dense
29 12 256 256 64 0.004754550595394737 3072 309237645312 46.31372736116993 6677019167.566916 65040352207320.695 dense
30 12 256 256 32 0.002405752659340659 3072 154618822656 49.68826015254682 3111777755.5766335 64270456921525.82 dense
31 12 256 256 16 0.0012287219045005488 3072 77309411328 56.323579604557374 1372594069.3184311 62918558743709.18 dense
32 12 256 256 8 0.0006206816149425287 3072 38654705664 70.95456179103653 544781120.315271 62277832520589.78 dense
33 12 256 256 4 0.0003875502697142857 3072 19327352832 81.16954743236613 238110885.71245712 49870569942445.75 dense
34 12 256 256 2 0.00027502018627941914 3072 9663676416 91.50537035282076 105607751.53129694 35138062215483.168 dense
35 12 256 1 2048 0.0006202853873290136 3072 38654705664 70.99988634205897 544433345.6784943 62317614526515.766 dense
36 12 256 1 1024 0.00038721467732724153 3072 19327352832 81.2398957010995 237904697.74985722 49913791918755.53 dense
37 12 256 1 512 0.000274364799 3072 9663676416 91.72395326121995 105356082.81599998 35221998052308.45 dense
38 12 256 1 256 0.00012488918589482266 3072 4831838208 176.31707535146046 27404255.647778228 38689003962834.75 dense
39 12 256 1 128 8.976711102514506e-05 3072 2415919104 227.78088507574267 10606329.425740216 26913187652026.21 dense
40 12 256 1 64 8.715176287471176e-05 3072 1207959552 225.59268282689945 5354604.31102229 13860414432884.701 dense
41 12 256 1 32 8.523013435114503e-05 3072 603979776 226.06539514085782 2671703.8033338524 7086458100741.991 dense
42 12 256 1 16 7.901561645904116e-05 3072 301989888 241.35704882952732 1251216.3595988373 3821901309300.556 dense
43 12 256 1 8 7.827949114210329e-05 3072 150994944 242.37091635608994 622991.1833900034 1928920867994.581 dense
44 12 256 1 4 7.779445951035782e-05 3072 75497472 243.25022783249054 310369.58391664835 970473636235.5986 dense
45 12 256 1 2 7.758845406626506e-05 3072 37748736 243.57933441822672 154975.11761480253 486525172518.07056 dense
46 3 256 256 2048 0.00507974918466899 768 206158430208 475.59810852303485 433471930.42508715 40584371927298.98 qk_init
47 3 256 256 1024 0.0025616677649325623 768 103079215104 471.5519977009198 218595649.27424532 40239103803811.82 qk_init
48 3 256 256 512 0.0013029336670480549 768 51539607552 463.55374128015677 111183672.92143403 39556585922573.38 qk_init
49 3 256 256 256 0.0006738189029345373 768 25769803776 448.1766342333362 57499213.050413854 38244406121244.69 qk_init
50 3 256 256 128 0.000358254672959467 768 12884901888 421.47375986100144 30571065.425874516 35965760841472.125 qk_init
51 3 256 256 64 0.0002007051105022831 768 6442450944 376.1611839930762 17126836.096194826 32099087700742.5 qk_init
52 3 256 256 32 0.00012189697230142565 768 3221225472 309.6773881032524 10401874.969721656 26425803784810.87 qk_init
53 3 256 256 16 8.453561698040722e-05 768 1610612736 223.2711923587723 7213705.982328083 19052475081281.902 qk_init
54 3 256 256 8 6.407660705009276e-05 768 805306368 147.2797083750448 5467870.468274581 12567868448003.822 qk_init
55 3 256 256 4 5.036328747284576e-05 768 402653184 93.69110391262903 4297667.197682838 7994974200544.344 qk_init
56 3 256 256 2 4.5488761135057476e-05 768 201326592 51.865470527877875 3881707.616858238 4425853485045.578 qk_init
57 12 256 256 2048 0.020202365999999996 3072 824633720832 478.3437947812648 1723935231.9999998 40818670488001.266 qk_init
58 12 256 256 1024 0.010124155888157895 3072 412316860416 477.2583770318811 863927969.1228071 40726048173387.19 qk_init
59 12 256 256 512 0.005085633937062937 3072 206158430208 475.04777848703077 433974095.9627039 40537410430893.29 qk_init
60 12 256 256 256 0.0025654916853281853 3072 103079215104 470.84913933193053 218921957.14800516 40179126556324.74 qk_init
61 12 256 256 128 0.0013045765704467354 3072 51539607552 462.9699702434292 111323867.34478809 39506770794105.96 qk_init
62 12 256 256 64 0.0006742801519939804 3072 25769803776 447.87005387442576 57538572.970153 38218244597284.33 qk_init
63 12 256 256 32 0.00035831976790671853 3072 12884901888 421.3971919051604 30576620.194706645 35959227042573.69 qk_init
64 12 256 256 16 0.0002005369068918302 3072 6442450944 376.4766953382971 17112482.721436176 32126011335534.68 qk_init
65 12 256 256 8 0.00012179187250509165 3072 3221225472 309.94462293386505 10392906.453767821 26448607823689.82 qk_init
66 12 256 256 4 8.452507263643351e-05 3072 1610612736 223.2990450204527 7212806.198308992 19054851841745.297 qk_init
67 12 256 256 2 6.412381767545489e-05 3072 805306368 147.17127491946468 5471899.108305484 12558615459794.32 qk_init
68 3 256 256 2048 0.0016183739398395718 768 805306368 811597824.0 0.9922480620155039 1265467.7325087283 qk_ar
69 3 256 256 1024 0.0008322699728813558 768 402653184 405798912.0 0.9922480620155039 1230369.9921491416 qk_ar
70 3 256 256 512 0.00043886859397590365 768 201326592 202899456.0 0.9922480620155039 1166636.2255762408 qk_ar
71 3 256 256 256 0.00024185948322147648 768 100663296 101449728.0 0.9922480620155039 1058465.8355760013 qk_ar
72 3 256 256 128 0.00014308985100166944 768 50331648 50724864.0 0.9922480620155039 894542.82818777 qk_ar
73 3 256 256 64 9.382939365815932e-05 768 25165824 25362432.0 0.9922480620155039 682089.028872613 qk_ar
74 3 256 256 32 6.856070612244899e-05 768 12582912 12681216.0 0.9922480620155039 466739.6503012703 qk_ar
75 3 256 256 16 5.452260553129549e-05 768 6291456 6340608.0 0.9922480620155039 293456.26174846216 qk_ar
76 3 256 256 8 4.608557533261417e-05 768 3145728 3170304.0 0.9922480620155039 173590.1080166944 qk_ar
77 3 256 256 4 4.386146957766642e-05 768 1572864 1585152.0 0.9922480620155039 91196.21477609445 qk_ar
78 3 256 256 2 4.330941094420601e-05 768 786432 792576.0 0.9922480620155039 46179.33969539622 qk_ar
79 12 256 256 2048 0.006347041645299144 3072 3221225472 3246391296.0 0.9922480620155039 322670.011392918 qk_ar
80 12 256 256 1024 0.0031943104467592586 3072 1610612736 1623195648.0 0.9922480620155039 320569.96872013 qk_ar
81 12 256 256 512 0.0016183416350267381 3072 805306368 811597824.0 0.9922480620155039 316373.2483416833 qk_ar
82 12 256 256 256 0.0008325934893977947 3072 402653184 405798912.0 0.9922480620155039 307472.9784221131 qk_ar
83 12 256 256 128 0.0004389725746987952 3072 201326592 202899456.0 0.9922480620155039 291589.9702568624 qk_ar
84 12 256 256 64 0.00024191767449664432 3072 100663296 101449728.0 0.9922480620155039 264552.8076159138 qk_ar
85 12 256 256 32 0.0001431546143572621 3072 50331648 50724864.0 0.9922480620155039 223534.53392804778 qk_ar
86 12 256 256 16 9.404283597678917e-05 3072 25165824 25362432.0 0.9922480620155039 170135.23501087292 qk_ar
87 12 256 256 8 6.855550037091989e-05 3072 12582912 12681216.0 0.9922480620155039 116693.773026467 qk_ar
88 12 256 256 4 5.4802094978165945e-05 3072 6291456 6340608.0 0.9922480620155039 72989.91036006316 qk_ar
89 12 256 256 2 4.608510707869206e-05 3072 3145728 3170304.0 0.9922480620155039 43397.96795057727 qk_ar

View File

Binary file not shown.

After

Width:  |  Height:  |  Size: 45 KiB

594
research/micro/embedd_micro.py Executable file
View File

@@ -0,0 +1,594 @@
# python embedd_micro.py --use_int8 Fastest
import argparse
import time
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
import numpy as np
import torch
from torch import nn
from torchao import quantize_
from transformers import AutoModel, BitsAndBytesConfig
from tqdm import tqdm
from contextlib import contextmanager
@dataclass
class BenchmarkConfig:
model_path: str
batch_sizes: List[int]
seq_length: int
num_runs: int
use_fp16: bool = True
use_int4: bool = False
use_int8: bool = False # Add this parameter
use_cuda_graphs: bool = False
use_flash_attention: bool = False
use_linear8bitlt: bool = False
class CUDAGraphContainer:
"""Container for managing CUDA graphs for different batch sizes."""
def __init__(self, model: nn.Module, seq_length: int):
self.model = model
self.seq_length = seq_length
self.graphs: Dict[int, CUDAGraphWrapper] = {}
def get_or_create(self, batch_size: int) -> 'CUDAGraphWrapper':
if batch_size not in self.graphs:
self.graphs[batch_size] = CUDAGraphWrapper(
self.model, batch_size, self.seq_length
)
return self.graphs[batch_size]
class CUDAGraphWrapper:
"""Wrapper for CUDA graph capture and replay."""
def __init__(self, model: nn.Module, batch_size: int, seq_length: int):
self.model = model
self.static_input = self._create_random_batch(batch_size, seq_length)
self.static_attention_mask = torch.ones_like(self.static_input)
# Warm up
self._warmup()
# Capture graph
self.graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(self.graph):
self.static_output = self.model(
input_ids=self.static_input,
attention_mask=self.static_attention_mask
)
def _create_random_batch(self, batch_size: int, seq_length: int) -> torch.Tensor:
return torch.randint(
0, 1000, (batch_size, seq_length),
device="cuda",
dtype=torch.long
)
def _warmup(self, num_warmup: int = 3):
with torch.no_grad():
for _ in range(num_warmup):
self.model(
input_ids=self.static_input,
attention_mask=self.static_attention_mask
)
def __call__(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
self.static_input.copy_(input_ids)
self.static_attention_mask.copy_(attention_mask)
self.graph.replay()
return self.static_output
class ModelOptimizer:
"""Applies various optimizations to the model."""
@staticmethod
def optimize(model: nn.Module, config: BenchmarkConfig) -> nn.Module:
print("\nApplying model optimizations:")
if model is None:
raise ValueError("Cannot optimize None model")
# Move to GPU
model = model.cuda()
print("- Model moved to GPU")
# FP16
if config.use_fp16 and not config.use_int4:
model = model.half()
# use torch compile
model = torch.compile(model)
print("- Using FP16 precision")
# Check if using SDPA
if torch.version.cuda and float(torch.version.cuda[:3]) >= 11.6:
if hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
print("- Using PyTorch SDPA (scaled_dot_product_attention)")
else:
print("- PyTorch SDPA not available")
# Flash Attention
if config.use_flash_attention:
try:
from flash_attn.flash_attention import FlashAttention
print("- Flash Attention 2 available")
if hasattr(model.config, "attention_mode"):
model.config.attention_mode = "flash_attention_2"
print(" - Enabled Flash Attention 2 mode")
except ImportError:
print("- Flash Attention not available")
# Memory efficient attention
try:
from xformers.ops import memory_efficient_attention
if hasattr(model, 'enable_xformers_memory_efficient_attention'):
model.enable_xformers_memory_efficient_attention()
print("- Enabled xformers memory efficient attention")
else:
print("- Model doesn't support xformers")
except (ImportError, AttributeError):
print("- Xformers not available")
model.eval()
print("- Model set to eval mode")
return model
class Timer:
"""Handles accurate GPU timing using CUDA events."""
def __init__(self):
self.start_event = torch.cuda.Event(enable_timing=True)
self.end_event = torch.cuda.Event(enable_timing=True)
@contextmanager
def timing(self):
self.start_event.record()
yield
self.end_event.record()
self.end_event.synchronize()
def elapsed_time(self) -> float:
return self.start_event.elapsed_time(self.end_event) / 1000 # ms to seconds
class Benchmark:
"""Main benchmark runner."""
def __init__(self, config: BenchmarkConfig):
self.config = config
try:
self.model = self._load_model()
if self.model is None:
raise ValueError("Model initialization failed - model is None")
self.cuda_graphs = (
CUDAGraphContainer(self.model, config.seq_length)
if config.use_cuda_graphs
else None
)
self.timer = Timer()
except Exception as e:
print(f"ERROR in benchmark initialization: {str(e)}")
raise
def _load_model(self) -> nn.Module:
print(f"Loading model from {self.config.model_path}...")
try:
# Int4 quantization using HuggingFace integration
if self.config.use_int4:
import bitsandbytes as bnb
print(f"- bitsandbytes version: {bnb.__version__}")
# 检查是否使用自定义的8bit量化
if hasattr(self.config, 'use_linear8bitlt') and self.config.use_linear8bitlt:
print("- Using custom Linear8bitLt replacement for all linear layers")
# 加载原始模型(不使用量化配置)
import bitsandbytes as bnb
import torch
# set default to half
torch.set_default_dtype(torch.float16)
compute_dtype = torch.float16 if self.config.use_fp16 else torch.float32
model = AutoModel.from_pretrained(
self.config.model_path,
torch_dtype=compute_dtype,
)
# 定义替换函数
def replace_linear_with_linear8bitlt(model):
"""递归地将模型中的所有nn.Linear层替换为Linear8bitLt"""
for name, module in list(model.named_children()):
if isinstance(module, nn.Linear):
# 获取原始线性层的参数
in_features = module.in_features
out_features = module.out_features
bias = module.bias is not None
# 创建8bit线性层
# print size
print(f"in_features: {in_features}, out_features: {out_features}")
new_module = bnb.nn.Linear8bitLt(
in_features,
out_features,
bias=bias,
has_fp16_weights=False
)
# 复制权重和偏置
new_module.weight.data = module.weight.data
if bias:
new_module.bias.data = module.bias.data
# 替换模块
setattr(model, name, new_module)
else:
# 递归处理子模块
replace_linear_with_linear8bitlt(module)
return model
# 替换所有线性层
model = replace_linear_with_linear8bitlt(model)
# add torch compile
model = torch.compile(model)
# 将模型移到GPU量化发生在这里
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)
print("- All linear layers replaced with Linear8bitLt")
else:
# 使用原来的Int4量化方法
print("- Using bitsandbytes for Int4 quantization")
# Create quantization config
compute_dtype = torch.float16 if self.config.use_fp16 else torch.float32
quantization_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=compute_dtype,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4"
)
print("- Quantization config:", quantization_config)
# Load model directly with quantization config
model = AutoModel.from_pretrained(
self.config.model_path,
quantization_config=quantization_config,
torch_dtype=compute_dtype,
device_map="auto" # Let HF decide on device mapping
)
# Check if model loaded successfully
if model is None:
raise ValueError("Model loading returned None")
print(f"- Model type: {type(model)}")
# Apply optimizations directly here
print("\nApplying model optimizations:")
if hasattr(self.config, 'use_linear8bitlt') and self.config.use_linear8bitlt:
print("- Model moved to GPU with Linear8bitLt quantization")
else:
# Skip moving to GPU since device_map="auto" already did that
print("- Model already on GPU due to device_map='auto'")
# Skip FP16 conversion since we specified compute_dtype
print(f"- Using {compute_dtype} for compute dtype")
# Check CUDA and SDPA
if torch.version.cuda and float(torch.version.cuda[:3]) >= 11.6:
if hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
print("- Using PyTorch SDPA (scaled_dot_product_attention)")
else:
print("- PyTorch SDPA not available")
# Try xformers if available
try:
from xformers.ops import memory_efficient_attention
if hasattr(model, 'enable_xformers_memory_efficient_attention'):
model.enable_xformers_memory_efficient_attention()
print("- Enabled xformers memory efficient attention")
else:
print("- Model doesn't support xformers")
except (ImportError, AttributeError):
print("- Xformers not available")
# Set to eval mode
model.eval()
print("- Model set to eval mode")
# Int8 quantization using HuggingFace integration
# Int8 quantization using TorchAO
elif self.config.use_int8:
print("- Using TorchAO for Int8 dynamic activation and Int8 weight quantization")
# Import the quantize_ function and the quantization config
from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight
print("- Successfully imported TorchAO")
# Load model normally first
# set default to half
import torch
torch.set_default_dtype(torch.bfloat16)
model = AutoModel.from_pretrained(
self.config.model_path,
device_map="auto"
)
print("- Model loaded in full precision")
print(f"- Model type: {type(model)}")
# Apply quantization - call the function to get the config, then apply it
# quantize_(model, int8_dynamic_activation_int8_weight())
# from torchao.quantization import quantize_, Int8DynamicActivationInt8WeightConfig,int8_dynamic_activation_int8_semi_sparse_weight,int4_weight_only,Int8DynActInt4WeightGPTQQuantizer,int8_dynamic_activation_int4_weight,Int8DynamicActivationInt4WeightConfig,Int4DynamicActivationInt4WeightConfig
from torchao.quantization import quantize_, Int8DynamicActivationInt8WeightConfig
quantize_(model, Int8DynamicActivationInt8WeightConfig())
print("- Model successfully quantized with int8 weights and int8 activations")
# add torch compile
model = torch.compile(model)
# For older PyTorch versions that have issues with tensor subclasses
from torchao.utils import unwrap_tensor_subclass
import torch
if hasattr(torch, '_version') and not torch.version >= "2.5.0":
print("- Unwrapping tensor subclasses for compatibility with older PyTorch")
unwrap_tensor_subclass(model)
# Apply optimizations
if torch.version.cuda and float(torch.version.cuda[:3]) >= 11.6:
if hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
print("- Using PyTorch SDPA (scaled_dot_product_attention)")
else:
print("- PyTorch SDPA not available")
# Set to eval mode
model.eval()
print("- Model set to eval mode")
# For better performance with int8 dynamic quantization
torch._inductor.config.force_fuse_int_mm_with_mul = True
print("- Enabled fusion of int matmul with mul operations")
else:
# Standard loading for FP16/FP32
model = AutoModel.from_pretrained(self.config.model_path)
print("- Model loaded in standard precision")
print(f"- Model type: {type(model)}")
# Apply standard optimizations
# set default to half
import torch
torch.set_default_dtype(torch.bfloat16)
model = ModelOptimizer.optimize(model, self.config)
model = model.half()
# add torch compile
model = torch.compile(model)
# Final check to ensure model is not None
if model is None:
raise ValueError("Model is None after optimization")
print(f"- Final model type: {type(model)}")
return model
except Exception as e:
print(f"ERROR loading model: {str(e)}")
import traceback
traceback.print_exc()
raise
def _create_random_batch(self, batch_size: int) -> torch.Tensor:
return torch.randint(
0, 1000,
(batch_size, self.config.seq_length),
device="cuda",
dtype=torch.long
)
def _run_inference(
self,
input_ids: torch.Tensor,
cuda_graph_wrapper: Optional[CUDAGraphWrapper] = None
) -> Tuple[float, torch.Tensor]:
attention_mask = torch.ones_like(input_ids)
with torch.no_grad(), self.timer.timing():
if cuda_graph_wrapper is not None:
output = cuda_graph_wrapper(input_ids, attention_mask)
else:
output = self.model(input_ids=input_ids, attention_mask=attention_mask)
return self.timer.elapsed_time(), output
def run(self) -> Dict[int, Dict[str, float]]:
results = {}
# Reset peak memory stats
torch.cuda.reset_peak_memory_stats()
for batch_size in self.config.batch_sizes:
print(f"\nTesting batch size: {batch_size}")
times = []
# Get or create CUDA graph for this batch size
cuda_graph_wrapper = (
self.cuda_graphs.get_or_create(batch_size)
if self.cuda_graphs is not None
else None
)
# Pre-allocate input tensor
input_ids = self._create_random_batch(batch_size)
print(f"Input shape: {input_ids.shape}")
# Run benchmark
for i in tqdm(range(self.config.num_runs), desc=f"Batch size {batch_size}"):
try:
elapsed_time, output = self._run_inference(input_ids, cuda_graph_wrapper)
if i == 0: # Only print on first run
print(f"Output shape: {output.last_hidden_state.shape}")
times.append(elapsed_time)
except Exception as e:
print(f"Error during inference: {e}")
break
if not times:
print(f"No successful runs for batch size {batch_size}, skipping")
continue
# Calculate statistics
avg_time = np.mean(times)
std_time = np.std(times)
throughput = batch_size / avg_time
results[batch_size] = {
"avg_time": avg_time,
"std_time": std_time,
"throughput": throughput,
}
print(f"Avg Time: {avg_time:.4f}s ± {std_time:.4f}s")
print(f"Throughput: {throughput:.2f} sequences/second")
# Log memory usage
peak_memory_gb = torch.cuda.max_memory_allocated() / (1024 ** 3)
print(f"\nPeak GPU memory usage: {peak_memory_gb:.2f} GB")
# Add memory info to results
for batch_size in results:
results[batch_size]["peak_memory_gb"] = peak_memory_gb
return results
def main():
parser = argparse.ArgumentParser(description="Model Inference Benchmark")
parser.add_argument(
"--model_path",
type=str,
default="facebook/contriever",
help="Path to the model",
)
parser.add_argument(
"--batch_sizes",
type=str,
default="1,2,4,8,10,16,20,32,40,64,128,256,512,1024,2048,4096,8192",
help="Comma-separated list of batch sizes",
)
parser.add_argument(
"--seq_length",
type=int,
default=256,
help="Sequence length for input",
)
parser.add_argument(
"--num_runs",
type=int,
default=5,
help="Number of runs for each batch size",
)
parser.add_argument(
"--use_fp16",
action="store_true",
help="Enable FP16 inference",
)
parser.add_argument(
"--use_int4",
action="store_true",
help="Enable INT4 quantization using bitsandbytes",
)
parser.add_argument(
"--use_int8",
action="store_true",
help="Enable INT8 quantization for both activations and weights using bitsandbytes",
)
parser.add_argument(
"--use_cuda_graphs",
action="store_true",
help="Enable CUDA Graphs optimization",
)
parser.add_argument(
"--use_flash_attention",
action="store_true",
help="Enable Flash Attention 2 if available",
)
parser.add_argument(
"--use_linear8bitlt",
action="store_true",
help="Enable Linear8bitLt quantization for all linear layers",
)
args = parser.parse_args()
# Print arguments for debugging
print("\nCommand line arguments:")
for arg, value in vars(args).items():
print(f"- {arg}: {value}")
config = BenchmarkConfig(
model_path=args.model_path,
batch_sizes=[int(bs) for bs in args.batch_sizes.split(",")],
seq_length=args.seq_length,
num_runs=args.num_runs,
use_fp16=args.use_fp16,
use_int4=args.use_int4,
use_int8=args.use_int8, # Add this line
use_cuda_graphs=args.use_cuda_graphs,
use_flash_attention=args.use_flash_attention,
use_linear8bitlt=args.use_linear8bitlt,
)
# Print configuration for debugging
print("\nBenchmark configuration:")
for field, value in vars(config).items():
print(f"- {field}: {value}")
try:
benchmark = Benchmark(config)
results = benchmark.run()
# Save results to file
import json
import os
# Create results directory if it doesn't exist
os.makedirs("results", exist_ok=True)
# Generate filename based on configuration
precision_type = "int4" if config.use_int4 else "fp16" if config.use_fp16 else "fp32"
model_name = os.path.basename(config.model_path)
output_file = f"results/benchmark_{model_name}_{precision_type}.json"
# Save results
with open(output_file, "w") as f:
json.dump(
{
"config": {k: str(v) if isinstance(v, list) else v for k, v in vars(config).items()},
"results": {str(k): v for k, v in results.items()}
},
f,
indent=2
)
print(f"Results saved to {output_file}")
except Exception as e:
print(f"Benchmark failed: {e}")
import traceback
traceback.print_exc()
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,376 @@
import argparse
import time
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
import numpy as np
import torch
from torch import nn
from transformers import AutoModel
from tqdm import tqdm
from contextlib import contextmanager
import math
@dataclass
class BenchmarkConfig:
model_path: str
batch_sizes: List[int]
seq_length: int
num_runs: int
use_fp16: bool = True
use_cuda_graphs: bool = False
use_flash_attention: bool = False
max_batch_size: int = 256 # Maximum batch size before splitting
class CUDAGraphContainer:
"""Container for managing CUDA graphs for different batch sizes."""
def __init__(self, model: nn.Module, seq_length: int, max_batch_size: int):
self.model = model
self.seq_length = seq_length
self.max_batch_size = max_batch_size
self.graphs: Dict[int, CUDAGraphWrapper] = {}
def get_or_create(self, batch_size: int) -> 'CUDAGraphWrapper':
# For CUDA graphs, we always use the actual batch size or max_batch_size
effective_batch_size = min(batch_size, self.max_batch_size)
if effective_batch_size not in self.graphs:
self.graphs[effective_batch_size] = CUDAGraphWrapper(
self.model, effective_batch_size, self.seq_length
)
return self.graphs[effective_batch_size]
class CUDAGraphWrapper:
"""Wrapper for CUDA graph capture and replay."""
def __init__(self, model: nn.Module, batch_size: int, seq_length: int):
self.model = model
self.static_input = self._create_random_batch(batch_size, seq_length)
self.static_attention_mask = torch.ones_like(self.static_input)
# Warm up
self._warmup()
# Capture graph
self.graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(self.graph):
self.static_output = self.model(
input_ids=self.static_input,
attention_mask=self.static_attention_mask
)
def _create_random_batch(self, batch_size: int, seq_length: int) -> torch.Tensor:
return torch.randint(
0, 1000, (batch_size, seq_length),
device="cuda",
dtype=torch.long
)
def _warmup(self, num_warmup: int = 3):
with torch.no_grad():
for _ in range(num_warmup):
self.model(
input_ids=self.static_input,
attention_mask=self.static_attention_mask
)
def __call__(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
self.static_input.copy_(input_ids)
self.static_attention_mask.copy_(attention_mask)
self.graph.replay()
return self.static_output
class ModelOptimizer:
"""Applies various optimizations to the model."""
@staticmethod
def optimize(model: nn.Module, config: BenchmarkConfig) -> nn.Module:
print("\nApplying model optimizations:")
# Move to GPU
model = model.cuda()
print("- Model moved to GPU")
# FP16
if config.use_fp16:
model = model.half()
print("- Using FP16 precision")
# Check if using SDPA
if torch.version.cuda and float(torch.version.cuda[:3]) >= 11.6:
if hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
print("- Using PyTorch SDPA (scaled_dot_product_attention)")
# No need to do anything as it's automatically enabled
else:
print("- PyTorch SDPA not available")
# Flash Attention
if config.use_flash_attention:
try:
from flash_attn.flash_attention import FlashAttention
print("- Flash Attention 2 available")
if hasattr(model.config, "attention_mode"):
model.config.attention_mode = "flash_attention_2"
print(" - Enabled Flash Attention 2 mode")
except ImportError:
print("- Flash Attention not available")
# Optimize LayerNorm
try:
num_layernorms = 0
for module in model.modules():
if isinstance(module, torch.nn.LayerNorm):
module.forward = torch.jit.script(module.forward)
num_layernorms += 1
if num_layernorms > 0:
print(f"- Optimized {num_layernorms} LayerNorm modules with TorchScript")
except Exception as e:
print(f"- LayerNorm optimization failed: {e}")
# Memory efficient attention
try:
from xformers.ops import memory_efficient_attention
model.enable_xformers_memory_efficient_attention()
print("- Enabled xformers memory efficient attention")
except (ImportError, AttributeError):
print("- Xformers not available")
model.eval()
print("- Model set to eval mode")
return model
class Timer:
"""Handles accurate GPU timing using CUDA events."""
def __init__(self):
self.start_event = torch.cuda.Event(enable_timing=True)
self.end_event = torch.cuda.Event(enable_timing=True)
@contextmanager
def timing(self):
self.start_event.record()
yield
self.end_event.record()
self.end_event.synchronize()
def elapsed_time(self) -> float:
return self.start_event.elapsed_time(self.end_event) / 1000 # ms to seconds
class Benchmark:
"""Main benchmark runner."""
def __init__(self, config: BenchmarkConfig):
self.config = config
self.model = self._load_model()
self.cuda_graphs = (
CUDAGraphContainer(self.model, config.seq_length, config.max_batch_size)
if config.use_cuda_graphs
else None
)
self.timer = Timer()
def _load_model(self) -> nn.Module:
print(f"Loading model from {self.config.model_path}...")
model = AutoModel.from_pretrained(self.config.model_path)
return ModelOptimizer.optimize(model, self.config)
def _create_random_batch(self, batch_size: int) -> torch.Tensor:
return torch.randint(
0, 1000,
(batch_size, self.config.seq_length),
device="cuda",
dtype=torch.long
)
def _run_inference(
self,
input_ids: torch.Tensor,
cuda_graph_wrapper: Optional[CUDAGraphWrapper] = None
) -> Tuple[float, torch.Tensor]:
attention_mask = torch.ones_like(input_ids)
original_batch_size = input_ids.shape[0]
print(f"Original input_ids shape: {input_ids.shape}")
# Split large batches to avoid OOM
max_batch_size = self.config.max_batch_size
if original_batch_size > max_batch_size:
print(f"Splitting batch of size {original_batch_size} into chunks of {max_batch_size}")
total_time = 0
outputs = []
with torch.no_grad():
for i in range(0, original_batch_size, max_batch_size):
end_idx = min(i + max_batch_size, original_batch_size)
batch_slice = input_ids[i:end_idx]
mask_slice = attention_mask[i:end_idx]
print(f"Processing chunk {i//max_batch_size + 1}: shape {batch_slice.shape}")
# Use CUDA graph if available (with the smaller batch size)
chunk_cuda_graph = None
if cuda_graph_wrapper is not None:
chunk_cuda_graph = self.cuda_graphs.get_or_create(batch_slice.shape[0])
with self.timer.timing():
if chunk_cuda_graph is not None:
chunk_output = chunk_cuda_graph(batch_slice, mask_slice)
else:
chunk_output = self.model(input_ids=batch_slice, attention_mask=mask_slice)
total_time += self.timer.elapsed_time()
outputs.append(chunk_output.last_hidden_state)
# Combine outputs
combined_output = torch.cat(outputs, dim=0)
print(f"Combined output shape: {combined_output.shape}")
# Create a wrapper object similar to model output to maintain consistency
class DummyOutput:
def __init__(self, hidden_states):
self.last_hidden_state = hidden_states
output = DummyOutput(combined_output)
return total_time, output
else:
# Process normally for small batches
with torch.no_grad(), self.timer.timing():
if cuda_graph_wrapper is not None:
output = cuda_graph_wrapper(input_ids, attention_mask)
else:
output = self.model(input_ids=input_ids, attention_mask=attention_mask)
print(f"Output shape: {output.last_hidden_state.shape}")
return self.timer.elapsed_time(), output
def run(self) -> Dict[int, Dict[str, float]]:
results = {}
for batch_size in self.config.batch_sizes:
print(f"\nTesting batch size: {batch_size}")
times = []
# Get or create CUDA graph for this batch size
cuda_graph_wrapper = None
if self.cuda_graphs is not None:
if batch_size <= self.config.max_batch_size:
cuda_graph_wrapper = self.cuda_graphs.get_or_create(batch_size)
else:
# For large batches, we'll use the max_batch_size graph in chunks
cuda_graph_wrapper = True # Just a flag to indicate we want to use CUDA graphs
# Pre-allocate input tensor
input_ids = self._create_random_batch(batch_size)
# Run benchmark
for run_idx in tqdm(range(self.config.num_runs), desc=f"Batch size {batch_size}"):
elapsed_time, _ = self._run_inference(input_ids, cuda_graph_wrapper)
times.append(elapsed_time)
print(f"Run {run_idx+1}: {elapsed_time:.4f}s")
# Calculate statistics
avg_time = np.mean(times)
std_time = np.std(times)
throughput = batch_size / avg_time
results[batch_size] = {
"avg_time": avg_time,
"std_time": std_time,
"throughput": throughput,
}
print(f"Avg Time: {avg_time:.4f}s ± {std_time:.4f}s")
print(f"Throughput: {throughput:.2f} sequences/second")
return results
def main():
parser = argparse.ArgumentParser(description="Model Inference Benchmark")
parser.add_argument(
"--model_path",
type=str,
default="facebook/contriever",
help="Path to the model",
)
parser.add_argument(
"--batch_sizes",
type=str,
default="1,2,4,8,16,32,64,128,256,512,1024,2048,4096",
help="Comma-separated list of batch sizes",
)
parser.add_argument(
"--seq_length",
type=int,
default=256,
help="Sequence length for input",
)
parser.add_argument(
"--num_runs",
type=int,
default=5,
help="Number of runs for each batch size",
)
parser.add_argument(
"--no_fp16",
action="store_true",
help="Disable FP16 inference",
)
parser.add_argument(
"--use_cuda_graphs",
action="store_true",
help="Enable CUDA Graphs optimization",
)
parser.add_argument(
"--use_flash_attention",
action="store_true",
help="Enable Flash Attention 2 if available",
)
parser.add_argument(
"--max_batch_size",
type=int,
default=256,
help="Maximum batch size before splitting to prevent OOM",
)
args = parser.parse_args()
config = BenchmarkConfig(
model_path=args.model_path,
batch_sizes=[int(bs) for bs in args.batch_sizes.split(",")],
seq_length=args.seq_length,
num_runs=args.num_runs,
use_fp16=not args.no_fp16,
use_cuda_graphs=args.use_cuda_graphs,
use_flash_attention=args.use_flash_attention,
max_batch_size=args.max_batch_size,
)
benchmark = Benchmark(config)
results = benchmark.run()
# Print overall summary
print("\n===== BENCHMARK SUMMARY =====")
print(f"Model: {config.model_path}")
print(f"Sequence Length: {config.seq_length}")
print(f"FP16: {config.use_fp16}")
print(f"CUDA Graphs: {config.use_cuda_graphs}")
print(f"Flash Attention: {config.use_flash_attention}")
print(f"Max Batch Size: {config.max_batch_size}")
print("\nResults:")
print("\nBatch Size | Avg Time (s) | Throughput (seq/s)")
print("-" * 50)
for bs in sorted(results.keys()):
r = results[bs]
print(f"{bs:^10} | {r['avg_time']:^12.4f} | {r['throughput']:^17.2f}")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,218 @@
import torch
import torch.nn as nn
import time
import torch.nn.functional as F
# Import necessary functions from the quantize.py file
def get_group_qparams(w, n_bit=4, groupsize=128):
# needed for GPTQ with padding
if groupsize > w.shape[-1]:
groupsize = w.shape[-1]
assert groupsize > 1
assert w.shape[-1] % groupsize == 0
assert w.dim() == 2
to_quant = w.reshape(-1, groupsize)
assert torch.isnan(to_quant).sum() == 0
max_val = to_quant.amax(dim=1, keepdim=True)
min_val = to_quant.amin(dim=1, keepdim=True)
max_int = 2**n_bit - 1
scales = (max_val - min_val).clamp(min=1e-6) / max_int
zeros = min_val + scales * (2 ** (n_bit - 1))
return scales.to(torch.bfloat16).reshape(w.shape[0], -1), zeros.to(
torch.bfloat16
).reshape(w.shape[0], -1)
def pack_scales_and_zeros(scales, zeros):
assert scales.shape == zeros.shape
assert scales.dtype == torch.bfloat16
assert zeros.dtype == torch.bfloat16
return (
torch.cat(
[
scales.reshape(scales.size(0), scales.size(1), 1),
zeros.reshape(zeros.size(0), zeros.size(1), 1),
],
2,
)
.transpose(0, 1)
.contiguous()
)
def group_quantize_tensor(w, n_bit=4, groupsize=128):
scales, zeros = get_group_qparams(w, n_bit, groupsize)
w_int32 = group_quantize_tensor_from_qparams(w, scales, zeros, n_bit, groupsize)
scales_and_zeros = pack_scales_and_zeros(scales, zeros)
return w_int32, scales_and_zeros
def group_quantize_tensor_from_qparams(w, scales, zeros, n_bit=4, groupsize=128):
assert groupsize > 1
# needed for GPTQ single column quantize
if groupsize > w.shape[-1] and scales.shape[-1] == 1:
groupsize = w.shape[-1]
assert w.shape[-1] % groupsize == 0
assert w.dim() == 2
to_quant = w.reshape(-1, groupsize)
assert torch.isnan(to_quant).sum() == 0
scales = scales.reshape(-1, 1)
zeros = zeros.reshape(-1, 1)
min_val = zeros - scales * (2 ** (n_bit - 1))
max_int = 2**n_bit - 1
min_int = 0
w_int32 = (
to_quant.sub(min_val)
.div(scales)
.round()
.clamp_(min_int, max_int)
.to(torch.int32)
.reshape_as(w)
)
return w_int32
def prepare_int4_weight_and_scales_and_zeros(weight_bf16, groupsize, inner_k_tiles):
weight_int32, scales_and_zeros = group_quantize_tensor(
weight_bf16, n_bit=4, groupsize=groupsize
)
weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(weight_int32, inner_k_tiles)
return weight_int4pack, scales_and_zeros
def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, groupsize):
origin_x_size = x.size()
x = x.reshape(-1, origin_x_size[-1])
c = torch.ops.aten._weight_int4pack_mm(x, weight_int4pack, groupsize, scales_and_zeros)
new_shape = origin_x_size[:-1] + (out_features,)
c = c.reshape(new_shape)
return c
class WeightOnlyInt4Linear(torch.nn.Module):
__constants__ = ['in_features', 'out_features']
in_features: int
out_features: int
weight: torch.Tensor
def __init__(
self, in_features: int, out_features: int,
bias=False, device=None, dtype=None, groupsize: int = 128, inner_k_tiles: int = 8
) -> None:
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.groupsize = groupsize
self.inner_k_tiles = inner_k_tiles
assert out_features % 8 == 0, "require out_features % 8 == 0"
assert in_features % (inner_k_tiles * 16) == 0, "require in_features % (innerKTiles * 16) == 0"
self.register_buffer(
"weight",
torch.empty((out_features // 8, in_features // (inner_k_tiles * 16), 32, inner_k_tiles // 2), dtype=torch.int32)
)
self.register_buffer(
"scales_and_zeros",
torch.empty((in_features // groupsize, out_features, 2), dtype=torch.bfloat16)
)
def forward(self, input: torch.Tensor) -> torch.Tensor:
input = input.to(torch.bfloat16)
return linear_forward_int4(
input,
self.weight, self.scales_and_zeros, self.out_features, self.groupsize
)
# Define dimensions that satisfy the requirements for INT4 quantization
# in_features must be divisible by inner_k_tiles * 16
# out_features must be divisible by 8
in_features = 1024 # Must be divisible by inner_k_tiles * 16
out_features = 2048 # Must be divisible by 8
groupsize = 128
inner_k_tiles = 8
# Create models
fp16_model = nn.Sequential(
nn.Linear(in_features, out_features, bias=False)
)
# Create INT4 model
int4_model = nn.Sequential(
WeightOnlyInt4Linear(in_features, out_features, bias=False,
groupsize=groupsize, inner_k_tiles=inner_k_tiles)
)
# Quantize the weights and set up the INT4 model
with torch.no_grad():
# Convert FP16 weights to INT4
fp16_weight = fp16_model[0].weight.data.to(torch.bfloat16)
weight_int4pack, scales_and_zeros = prepare_int4_weight_and_scales_and_zeros(
fp16_weight, groupsize, inner_k_tiles
)
# Set the quantized weights in the INT4 model
int4_model[0].weight.copy_(weight_int4pack)
int4_model[0].scales_and_zeros.copy_(scales_and_zeros)
# Move models to GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
fp16_model = fp16_model.to(device)
int4_model = int4_model.to(device)
# Create random input tensor
batch_size = 1024
input_tensor = torch.randn(batch_size, in_features, device=device)
input_tensor_bf16 = input_tensor.to(torch.bfloat16)
# Speed test function
def speed_test(model, input_tensor, name, num_iterations=100):
# Warmup
for _ in range(10):
_ = model(input_tensor)
# Actual timing
torch.cuda.synchronize()
start_time = time.time()
for _ in range(num_iterations):
_ = model(input_tensor)
torch.cuda.synchronize()
end_time = time.time()
avg_time = (end_time - start_time) / num_iterations
print(f"{name} model: {avg_time:.6f} seconds per iteration")
return avg_time
# Run speed tests
with torch.no_grad(): # Disable gradient calculation for inference
print(f"Running benchmark with batch_size={batch_size}, in_features={in_features}, out_features={out_features}")
print(f"INT4 parameters: groupsize={groupsize}, inner_k_tiles={inner_k_tiles}")
fp16_time = speed_test(fp16_model, input_tensor_bf16, "FP16")
int4_time = speed_test(int4_model, input_tensor, "INT4")
# Calculate speedup
speedup = fp16_time / int4_time
print(f"INT4 is {speedup:.2f}x faster than FP16")
# Calculate memory savings
fp16_memory = fp16_model[0].weight.nelement() * fp16_model[0].weight.element_size()
int4_memory = (int4_model[0].weight.nelement() * int4_model[0].weight.element_size() +
int4_model[0].scales_and_zeros.nelement() * int4_model[0].scales_and_zeros.element_size())
memory_reduction = fp16_memory / int4_memory
print(f"Memory reduction: {memory_reduction:.2f}x ({fp16_memory/1024/1024:.2f} MB vs {int4_memory/1024/1024:.2f} MB)")
# Check accuracy
with torch.no_grad():
fp16_output = fp16_model(input_tensor_bf16)
int4_output = int4_model(input_tensor)
# Calculate error metrics
abs_error = torch.abs(fp16_output - int4_output)
rel_error = abs_error / (torch.abs(fp16_output) + 1e-7)
print(f"Mean absolute error: {abs_error.mean().item():.6f}")
print(f"Max absolute error: {abs_error.max().item():.6f}")
print(f"Mean relative error: {rel_error.mean().item():.6f}")

83
research/micro/int8.py Normal file
View File

@@ -0,0 +1,83 @@
import torch
import nvmath.bindings.cublas
import ctypes
# 创建 CUBLAS 句柄
handle = nvmath.bindings.cublas.create()
# 准备数据 - 使用 uint8 类型,并确保内存连续
m, n, k = 64, 32, 48
a = (torch.rand(m, k, device="cuda") * 255).to(torch.uint8).contiguous()
b = (torch.rand(k, n, device="cuda") * 255).to(torch.uint8).contiguous()
c = torch.zeros(m, n, device="cuda", dtype=torch.uint8).contiguous()
# 确保张量在 CUDA 上
assert a.is_cuda and b.is_cuda and c.is_cuda
# 确保张量是连续的
assert a.is_contiguous() and b.is_contiguous() and c.is_contiguous()
# 获取指针
a_ptr = a.data_ptr()
b_ptr = b.data_ptr()
c_ptr = c.data_ptr()
# 设置参数
transa = 0 # CUBLAS_OP_N (不转置)
transb = 0 # CUBLAS_OP_N (不转置)
transc = 0 # CUBLAS_OP_N (不转置)
# 设置偏置值
a_bias = 0
b_bias = 0
c_bias = 0
# 设置正确的 leading dimensions
lda = k # A 的 leading dimension
ldb = n # B 的 leading dimension
ldc = n # C 的 leading dimension
c_mult = 1
c_shift = 0
# 打印调试信息
print(f"a shape: {a.shape}, a_ptr: {a_ptr}")
print(f"b shape: {b.shape}, b_ptr: {b_ptr}")
print(f"c shape: {c.shape}, c_ptr: {c_ptr}")
try:
# 调用 uint8gemm_bias
nvmath.bindings.cublas.uint8gemm_bias(
handle,
transa, transb, transc,
m, n, k,
a_ptr, a_bias, lda,
b_ptr, b_bias, ldb,
c_ptr, c_bias, ldc,
c_mult, c_shift
)
except Exception as e:
print(f"Error: {e}")
# 尝试使用 ctypes 转换指针
a_ptr_c = ctypes.c_void_p(a_ptr).value
b_ptr_c = ctypes.c_void_p(b_ptr).value
c_ptr_c = ctypes.c_void_p(c_ptr).value
print(f"Using ctypes: a_ptr: {a_ptr_c}, b_ptr: {b_ptr_c}, c_ptr: {c_ptr_c}")
# 再次尝试调用
nvmath.bindings.cublas.uint8gemm_bias(
handle,
transa, transb, transc,
m, n, k,
a_ptr_c, a_bias, lda,
b_ptr_c, b_bias, ldb,
c_ptr_c, c_bias, ldc,
c_mult, c_shift
)
# 销毁 CUBLAS 句柄
nvmath.bindings.cublas.destroy(handle)
# 打印结果
print("Result:")
print(c)

View File

@@ -0,0 +1,23 @@
from llmcompressor.modifiers.smoothquant import SmoothQuantModifier
from llmcompressor.modifiers.quantization import GPTQModifier
from llmcompressor import oneshot
# Select quantization algorithm. In this case, we:
# * apply SmoothQuant to make the activations easier to quantize
# * quantize the weights to int8 with GPTQ (static per channel)
# * quantize the activations to int8 (dynamic per token)
recipe = [
SmoothQuantModifier(smoothing_strength=0.8),
GPTQModifier(scheme="W8A8", targets="Linear", ignore=["lm_head"]),
]
# Apply quantization using the built in open_platypus dataset.
# * See examples for demos showing how to pass a custom calibration set
oneshot(
model="facebook/contriever",
dataset="open_platypus",
recipe=recipe,
output_dir="contriever-INT4",
max_seq_length=2048,
num_calibration_samples=512,
)

View File

@@ -0,0 +1,41 @@
# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED.
#
# SPDX-License-Identifier: Apache-2.0
"""
This example demonstrates basic matrix multiplication of FP8 tensors.
In narrow-precision operations, quantization scales must be provided for each tensor. These
scales are used to dequantize input operands and quantize the result. Without proper
scaling, the results of FP8 operations will likely exceed the type's range.
FP8 is only supported with cuBLAS 12.8 or newer and on devices with compute
capability 8.9 or higher.
"""
import torch
import nvmath
# Prepare sample input data. Note that N, M and K must be divisible by 16 for FP8.
# cuBLAS requires B to be column-major, so we first create a row-major tensor and then
# transpose it.
m, n, k = 64, 32, 48
a = (torch.rand(m, k, device="cuda") * 10).type(torch.float8_e4m3fn)
b = (torch.rand(n, k, device="cuda") * 10).type(torch.float8_e4m3fn).T
# Prepare quantization scales. The scales must allow the result to fit within the dynamic
# range of the data type used. Scales can be provided either as a dictionary or as a
# MatmulQuantizationScales object. Note that scales are only allowed for FP8 operands.
scales = {"a": 1, "b": 1, "d": 0.1}
# Perform the multiplication. The result of the multiplication will be:
# (scales.a * A) @ (scales.b * B) * scales.d
result = nvmath.linalg.advanced.matmul(a, b, quantization_scales=scales)
# Check how scaling helped to fit into the dynamic range of float8_e4m3fn type.
result_without_scaling = nvmath.linalg.advanced.matmul(a, b, quantization_scales={"a": 1, "b": 1, "d": 1})
print("Without scaling, most of the elements were clamped to the maximum value of float8_e4m3fn type (448):")
print(result_without_scaling)
print(f"\nWith D scale set to {scales['d']}, they were scaled down to fit into the dynamic range of float8_e4m3fn:")
print(result)

0
research/micro/result.md Normal file
View File

View File

@@ -0,0 +1,58 @@
import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from pathlib import Path
def save_model_in_pth_format(model_name, output_dir):
"""
Download a model from Hugging Face and save it in PTH format
for use with quantization benchmarks.
Args:
model_name: Name of the model on Hugging Face
output_dir: Directory to save the model
"""
print(f"Loading model {model_name}...")
# Create output directory if it doesn't exist
os.makedirs(output_dir, exist_ok=True)
# Load tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.bfloat16,
low_cpu_mem_usage=True
)
# Save tokenizer
tokenizer.save_pretrained(output_dir)
# Extract and save the model weights in PTH format
model_state_dict = model.state_dict()
# Save the model weights
model_path = Path(output_dir) / "model.pth"
torch.save(model_state_dict, model_path)
print(f"Model saved to {model_path}")
# Print model size information
param_count = sum(p.numel() for p in model.parameters())
model_size_mb = sum(p.numel() * p.element_size() for p in model.parameters()) / (1024 * 1024)
print(f"Model parameters: {param_count:,}")
print(f"Model size: {model_size_mb:.2f} MB")
return model_path
if __name__ == "__main__":
# Use a small model for testing
model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
output_dir = "./tinyllama-1.1b-chat"
model_path = save_model_in_pth_format(model_name, output_dir)
print("\nYou can now use this model with the INT4 benchmark script.")
print("Example command:")
print(f"python int4benchmark.py --model_path {model_path}")

View File

@@ -0,0 +1,677 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "cab91cfc",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/ubuntu/Power-RAG/.venv/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
}
],
"source": [
"import copy\n",
"import dataclasses\n",
"import os\n",
"import time\n",
"import pathlib\n",
"import itertools\n",
"import multiprocessing\n",
"import scipy\n",
"import numpy as np\n",
"import pandas as pd\n",
"import pickle\n",
"import gzip\n",
"import threading\n",
"import queue\n",
"import pytz\n",
"import traceback\n",
"from datetime import datetime\n",
"from tqdm.auto import tqdm, trange\n",
"from typing import Any\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import matplotlib.ticker as mtick\n",
"%matplotlib inline\n",
"%config InlineBackend.figure_format='retina'"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "8d24fbd7",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Sat Apr 12 00:10:05 2025 \n",
"+-----------------------------------------------------------------------------------------+\n",
"| NVIDIA-SMI 550.120 Driver Version: 550.120 CUDA Version: 12.4 |\n",
"|-----------------------------------------+------------------------+----------------------+\n",
"| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |\n",
"| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |\n",
"| | | MIG M. |\n",
"|=========================================+========================+======================|\n",
"| 0 NVIDIA A10G Off | 00000000:00:1E.0 Off | 0 |\n",
"| 0% 27C P8 15W / 300W | 4MiB / 23028MiB | 0% Default |\n",
"| | | N/A |\n",
"+-----------------------------------------+------------------------+----------------------+\n",
" \n",
"+-----------------------------------------------------------------------------------------+\n",
"| Processes: |\n",
"| GPU GI CI PID Type Process name GPU Memory |\n",
"| ID ID Usage |\n",
"|=========================================================================================|\n",
"| No running processes found |\n",
"+-----------------------------------------------------------------------------------------+\n"
]
}
],
"source": [
"!nvidia-smi"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "538b2c11",
"metadata": {},
"outputs": [],
"source": [
"def benchmark(f, *, f_setup=None, min_repeat: int, min_secs: float, tqdm_kwargs: dict | None=None) -> np.ndarray:\n",
" latency = []\n",
" \n",
" # First run, ignore min_secs\n",
" if f_setup is not None:\n",
" f_setup()\n",
" st = time.perf_counter_ns()\n",
" f()\n",
" ed = time.perf_counter_ns()\n",
" latency.append((ed-st)/1e9)\n",
" \n",
" # Subsequent runs, until reaching both min_repeat and min_secs\n",
" min_nanos = int(min_secs * 1e9)\n",
" start_nanos = time.perf_counter_ns()\n",
" while True:\n",
" now_nanos = time.perf_counter_ns()\n",
" if len(latency) > min_repeat and now_nanos - start_nanos > min_nanos:\n",
" break\n",
" if f_setup is not None:\n",
" f_setup()\n",
" st = time.perf_counter_ns()\n",
" f()\n",
" ed = time.perf_counter_ns()\n",
" latency.append((ed-st)/1e9)\n",
" return np.array(latency)\n",
"\n",
"def tail_mean(xs, skip=0.2):\n",
" return xs[int(len(xs) * skip):].mean()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "02c9c9b1",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<torch.autograd.grad_mode.set_grad_enabled at 0x7c5afc12b850>"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import torch\n",
"torch.set_grad_enabled(False)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "3405fdc7",
"metadata": {},
"outputs": [],
"source": [
"nd_list = list(itertools.chain(itertools.product([12, 3], [256])))\n",
"seqlen_list = [256]\n",
"bs_list = [2,4,8,16,32,64,128,256,512,1024,2048]"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "10dc981a",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[(12, 256), (3, 256)]\n",
"[256]\n",
"[2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]\n"
]
}
],
"source": [
"print(nd_list)\n",
"print(seqlen_list)\n",
"print(bs_list)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "7e0ee385",
"metadata": {},
"outputs": [],
"source": [
"def benchmark_dense(out, nd_list, seqlen_list, bs_list):\n",
" seqlen_list = [1] + seqlen_list\n",
" total = len(list(itertools.product(nd_list, seqlen_list, bs_list)))\n",
" pbar = tqdm(total=total)\n",
" for (n, d), seqlen in reversed(list(itertools.product(nd_list, seqlen_list))):\n",
" h = n * d\n",
" maxbs = max(bs_list)\n",
" print(maxbs, n, d, seqlen)\n",
" cache = torch.empty(int(256e6 // 4), dtype=torch.int, device=\"cuda:0\")\n",
" X = torch.rand((maxbs, seqlen, h), dtype=torch.bfloat16, device=\"cuda:0\")\n",
" W = torch.rand((h, h), dtype=torch.bfloat16, device=\"cuda:0\")\n",
" torch.cuda.synchronize()\n",
" for bs in reversed(bs_list):\n",
" pbar.set_postfix(n=n, h=h, d=d, seqlen=seqlen, bs=bs)\n",
" def run():\n",
" torch.matmul(X[:bs], W)\n",
" torch.cuda.synchronize()\n",
" def clear_cache():\n",
" cache.zero_()\n",
" torch.cuda.synchronize()\n",
" latency = benchmark(run, f_setup=clear_cache, min_repeat=20, min_secs=2)\n",
" l = tail_mean(latency)\n",
" out.append({\n",
" \"n\": n,\n",
" \"d\": d,\n",
" \"seqlen\": seqlen,\n",
" \"bs\": bs,\n",
" \"latency\": l\n",
" })\n",
" pbar.update()\n",
" del cache, X, W\n",
" torch.cuda.empty_cache()\n",
" pbar.close()"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "c206a502",
"metadata": {},
"outputs": [],
"source": [
"def benchmark_qk_init(out, nd_list, seqlen_list, bs_list):\n",
" total = len(list(itertools.product(nd_list, seqlen_list, bs_list)))\n",
" pbar = tqdm(total=total)\n",
" for (n, d), seqlen in reversed(list(itertools.product(nd_list, seqlen_list))):\n",
" h = n * d\n",
" try:\n",
" maxbs = max(b for b in bs_list if b*n*seqlen*d*2*2+b*n*seqlen**2*2 < 80e9)\n",
" except ValueError:\n",
" pbar.update(len(bs_list))\n",
" continue\n",
" cache = torch.empty(int(256e6 // 4), dtype=torch.int, device=\"cuda:0\")\n",
" Qmax = torch.rand((maxbs, n, seqlen, d), dtype=torch.bfloat16, device=\"cuda:0\")\n",
" Kmax = torch.rand((maxbs, n, seqlen, d), dtype=torch.bfloat16, device=\"cuda:0\")\n",
" torch.cuda.synchronize()\n",
" for bs in reversed(bs_list):\n",
" pbar.set_postfix(n=n, h=h, d=d, seqlen=seqlen, bs=bs)\n",
" if bs > maxbs:\n",
" pbar.update()\n",
" continue\n",
" Q = Qmax[:bs]\n",
" K = Kmax[:bs]\n",
" def run():\n",
" torch.bmm(Q.view(bs * n, seqlen, d), K.view(bs * n, seqlen, d).transpose(1, 2))\n",
" torch.cuda.synchronize()\n",
" def clear_cache():\n",
" cache.zero_()\n",
" torch.cuda.synchronize()\n",
" latency = benchmark(run, f_setup=clear_cache, min_repeat=20, min_secs=2)\n",
" l = tail_mean(latency)\n",
" out.append({\n",
" \"n\": n,\n",
" \"d\": d,\n",
" \"seqlen\": seqlen,\n",
" \"bs\": bs,\n",
" \"latency\": l\n",
" })\n",
" pbar.update()\n",
" del cache, Q, K, Qmax, Kmax\n",
" torch.cuda.empty_cache()\n",
" pbar.close()"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "a3a2103c",
"metadata": {},
"outputs": [],
"source": [
"def benchmark_qk_ar(out, nd_list, seqlen_list, bs_list):\n",
" total = len(list(itertools.product(nd_list, seqlen_list, bs_list)))\n",
" pbar = tqdm(total=total)\n",
" for (n, d), seqlen in reversed(list(itertools.product(nd_list, seqlen_list))):\n",
" h = n * d\n",
" try:\n",
" maxbs = max(b for b in bs_list if b*n*(1+seqlen)*d*2+b*n*seqlen*2 < 80e9)\n",
" except ValueError:\n",
" pbar.update(len(bs_list))\n",
" continue\n",
" cache = torch.empty(int(256e6 // 4), dtype=torch.int, device=\"cuda:0\")\n",
" Qmax = torch.rand((maxbs, n, 1, d), dtype=torch.bfloat16, device=\"cuda:0\")\n",
" Kmax = torch.rand((maxbs, n, seqlen, d), dtype=torch.bfloat16, device=\"cuda:0\")\n",
" torch.cuda.synchronize()\n",
" for bs in reversed(bs_list):\n",
" pbar.set_postfix(n=n, h=h, d=d, seqlen=seqlen, bs=bs)\n",
" if bs > maxbs:\n",
" pbar.update()\n",
" continue\n",
" Q = Qmax[:bs]\n",
" K = Kmax[:bs]\n",
" def run():\n",
" torch.bmm(Q.view(bs * n, 1, d), K.view(bs * n, seqlen, d).transpose(1, 2))\n",
" torch.cuda.synchronize()\n",
" def clear_cache():\n",
" cache.zero_()\n",
" torch.cuda.synchronize()\n",
" latency = benchmark(run, f_setup=clear_cache, min_repeat=20, min_secs=2)\n",
" l = tail_mean(latency)\n",
" out.append({\n",
" \"n\": n,\n",
" \"d\": d,\n",
" \"seqlen\": seqlen,\n",
" \"bs\": bs,\n",
" \"latency\": l\n",
" })\n",
" pbar.update()\n",
" del cache, Q, K, Qmax, Kmax\n",
" torch.cuda.empty_cache()\n",
" pbar.close()"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "3aaad98a",
"metadata": {},
"outputs": [],
"source": [
"data = {}"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "18137de3",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
" 0%| | 0/22 [00:00<?, ?it/s]"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 22/22 [00:44<00:00, 2.04s/it, bs=2, d=256, h=3072, n=12, seqlen=256] \n"
]
}
],
"source": [
"db = []\n",
"benchmark_qk_init(db, nd_list, seqlen_list, bs_list)\n",
"data[\"qk_init\"] = db"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "26c76e15",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 22/22 [00:44<00:00, 2.01s/it, bs=2, d=256, h=3072, n=12, seqlen=256] \n"
]
}
],
"source": [
"db = []\n",
"benchmark_qk_ar(db, nd_list, seqlen_list, bs_list)\n",
"data[\"qk_ar\"] = db"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "313e36eb",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
" 0%| | 0/44 [00:00<?, ?it/s, bs=2048, d=256, h=768, n=3, seqlen=256]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"2048 3 256 256\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" 25%|██▌ | 11/44 [00:22<01:06, 2.00s/it, bs=2048, d=256, h=768, n=3, seqlen=1] "
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"2048 3 256 1\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" 50%|█████ | 22/44 [00:44<00:44, 2.00s/it, bs=2048, d=256, h=3072, n=12, seqlen=256]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"2048 12 256 256\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" 75%|███████▌ | 33/44 [01:07<00:22, 2.02s/it, bs=2048, d=256, h=3072, n=12, seqlen=1] "
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"2048 12 256 1\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 44/44 [01:29<00:00, 2.03s/it, bs=2, d=256, h=3072, n=12, seqlen=1] \n"
]
}
],
"source": [
"db = []\n",
"benchmark_dense(db, nd_list, seqlen_list, bs_list)\n",
"data[\"dense\"] = db"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "50c37959",
"metadata": {},
"outputs": [],
"source": [
"with gzip.open(\"data/20230516-transformer-batching1.pkl.gz\", \"wb\") as f:\n",
" pickle.dump(data, f)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "828ddb54",
"metadata": {},
"outputs": [],
"source": [
"df_dense = (\n",
" pd.DataFrame.from_dict(data[\"dense\"])\n",
" .assign(h=lambda x: x[\"n\"] * x[\"d\"])\n",
" .assign(flop=lambda x: (x[\"bs\"] * x[\"seqlen\"] * x[\"h\"]**2) * 2)\n",
" .assign(io=lambda x: (x[\"bs\"]*x[\"seqlen\"]*x[\"h\"]*2 + x[\"h\"]**2) * 2/x['latency']/1e9)\n",
" .assign(intensity=lambda x: x[\"flop\"] / x[\"io\"])\n",
" .assign(throughput=lambda x: x[\"flop\"] / x[\"latency\"])\n",
" .assign(series=\"dense\")\n",
")\n",
"df_qk_init = (\n",
" pd.DataFrame.from_dict(data[\"qk_init\"])\n",
" .assign(h=lambda x: x[\"n\"] * x[\"d\"])\n",
" .assign(flop=lambda x: (x[\"bs\"]*x[\"n\"]*x[\"d\"]*x[\"seqlen\"]**2) * 2)\n",
" .assign(io=lambda x: (x[\"bs\"]*x[\"n\"]*(x[\"seqlen\"]*x[\"d\"]*2 + x[\"seqlen\"]**2)) * 2/x['latency']/1e9)\n",
" .assign(intensity=lambda x: x[\"flop\"] / x[\"io\"])\n",
" .assign(throughput=lambda x: x[\"flop\"] / x[\"latency\"])\n",
" .assign(series=\"qk_init\")\n",
")\n",
"df_qk_ar = (\n",
" pd.DataFrame.from_dict(data[\"qk_ar\"])\n",
" .assign(h=lambda x: x[\"n\"] * x[\"d\"])\n",
" .assign(flop=lambda x: (x[\"bs\"]*x[\"n\"]*x[\"d\"]*x[\"seqlen\"]) * 2)\n",
" .assign(io=lambda x: (x[\"bs\"]*x[\"n\"]*(x[\"d\"] + x[\"seqlen\"]*x[\"d\"] + x[\"seqlen\"])) * 2)\n",
" .assign(intensity=lambda x: x[\"flop\"] / x[\"io\"])\n",
" .assign(throughput=lambda x: x[\"bs\"] / x[\"latency\"])\n",
" .assign(series=\"qk_ar\")\n",
")\n",
"pd.concat([df_dense, df_qk_init, df_qk_ar]).to_csv(\"data/transformer-batching-microbenchmarks.csv\", index=False)"
]
},
{
"cell_type": "code",
"execution_count": 39,
"id": "c296a395",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<module 'pandas' from '/home/ubuntu/Power-RAG/.venv/lib/python3.10/site-packages/pandas/__init__.py'>"
]
},
"execution_count": 39,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pd\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a25cdd5a",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "63b8a531",
"metadata": {},
"outputs": [],
"source": [
"import transformers"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "af90eff1",
"metadata": {},
"outputs": [],
"source": [
"def _gen_opt_cfg(n_layers: int, d_model: int, n_heads: int, **kwargs) -> transformers.OPTConfig:\n",
" return transformers.OPTConfig(\n",
" num_hidden_layers=n_layers,\n",
" hidden_size=d_model,\n",
" ffn_dim=d_model*4,\n",
" num_attention_heads=n_heads,\n",
" **kwargs\n",
" )\n",
"optcfg = {\n",
" # https://arxiv.org/pdf/2205.01068.pdf Table 2.1\n",
" \"125m\": _gen_opt_cfg(12, 768, 12),\n",
" \"350m\": _gen_opt_cfg(24, 1024, 16),\n",
" \"760m\": _gen_opt_cfg(24, 1536, 16),\n",
" \"1.3b\": _gen_opt_cfg(24, 2048, 32),\n",
" \"2.7b\": _gen_opt_cfg(32, 2560, 32),\n",
" \"6.7b\": _gen_opt_cfg(32, 4096, 32),\n",
" \"13b\": _gen_opt_cfg(40, 5120, 40),\n",
" \"13b_1layer\": _gen_opt_cfg(1, 5120, 40),\n",
" \"30b\": _gen_opt_cfg(48, 7168, 56),\n",
" \"66b\": _gen_opt_cfg(64, 9216, 72),\n",
" \"175b\": _gen_opt_cfg(96, 12288, 96),\n",
" \"175b_1layer\": _gen_opt_cfg(1, 12288, 96),\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5b9ebbec",
"metadata": {},
"outputs": [],
"source": [
"def greedy_sample_one(model, input_ids, attention_mask=None, past_key_values=None):\n",
" bs, tgt_len = input_ids.shape\n",
" if past_key_values is not None:\n",
" _bs, _num_heads, src_len, _head_dims = past_key_values[0][0].shape\n",
" assert bs == _bs\n",
" else:\n",
" src_len = 0\n",
" if attention_mask is None:\n",
" attention_mask = torch.ones((bs, src_len + tgt_len), device=model.device)\n",
" ret = model(\n",
" input_ids=input_ids,\n",
" attention_mask=attention_mask,\n",
" past_key_values=past_key_values,\n",
" use_cache=True, output_hidden_states=False, return_dict=True,\n",
" )\n",
" return ret\n",
"\n",
"def time_greedy_generate(model, input_ids, new_tokens):\n",
" ts = []\n",
" output = input_ids\n",
" past_key_values = None\n",
" cache = torch.empty(int(256e6 // 4), dtype=torch.int, device=model.device)\n",
" attention_mask = torch.ones(input_ids.shape, device=model.device) \n",
" for _ in range(new_tokens):\n",
" cache.zero_()\n",
" torch.cuda.synchronize()\n",
" st = time.perf_counter_ns()\n",
" \n",
" ret = greedy_sample_one(model, input_ids, attention_mask, past_key_values)\n",
" input_ids = torch.argmax(ret.logits[:, -1, :], axis=-1)[:, None]\n",
" output = torch.cat([output, input_ids], axis=1)\n",
" past_key_values = ret.past_key_values\n",
" attention_mask = torch.cat([attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1)\n",
" \n",
" torch.cuda.synchronize()\n",
" ed = time.perf_counter_ns()\n",
" ts.append((ed-st)/1e9)\n",
" return np.array(ts)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fc92f940",
"metadata": {},
"outputs": [],
"source": [
"opt_config = optcfg[\"6.7b\"]\n",
"\n",
"torch.set_default_dtype(torch.bfloat16)\n",
"with transformers.modeling_utils.no_init_weights():\n",
" model = transformers.models.opt.OPTForCausalLM(opt_config).to(\"cuda\")\n",
"torch.set_default_dtype(torch.float32)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c19fa396",
"metadata": {},
"outputs": [],
"source": [
"db = {}\n",
"input_tokens = 200\n",
"new_tokens = 500\n",
"for bs in tqdm(list(itertools.chain(range(1, 8), range(8, 16, 2), [16]))):\n",
" x = torch.randint(1000, 10000, (bs, input_tokens), device=model.device)\n",
" stack = []\n",
" for _ in range(10):\n",
" l = time_greedy_generate(model, x, new_tokens=new_tokens)\n",
" stack.append(l)\n",
" db[bs] = np.median(np.stack(stack), axis=0)\n",
" del x\n",
" torch.cuda.empty_cache()\n",
"del model\n",
"torch.cuda.empty_cache()\n",
"\n",
"with gzip.open(\"data/20230516-e2e-text-generation-batch.pkl.gz\", \"wb\") as f:\n",
" pickle.dump(db, f)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@@ -0,0 +1,165 @@
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
# Set plot parameters
plt.rcParams["font.family"] = "Helvetica"
plt.rcParams["ytick.direction"] = "in"
plt.rcParams["hatch.linewidth"] = 1.5
plt.rcParams["font.weight"] = "bold"
plt.rcParams["axes.labelweight"] = "bold"
plt.rcParams["text.usetex"] = True
# Path settings
FIGURE_PATH = "./paper_plot/figures"
# Load accuracy data
acc_data = pd.read_csv("./paper_plot/data/acc.csv")
# Create figure with 4 subplots (one for each dataset)
fig, axs = plt.subplots(1, 4)
fig.set_size_inches(9, 2.5)
# Reduce the spacing between subplots
# plt.subplots_adjust(wspace=0.2) # Reduced from 0.3 to 0.1
# Define datasets and their columns
datasets = ["NQ", "TriviaQA", "GPQA", "HotpotQA"]
metrics = ["Exact Match", "F1"]
# Define bar settings - make bars thicker
# total_width, n = 0.9, 3 # increased total width and n for three models
# width = total_width / n
# The 'width' variable below now defines the distance between the centers of adjacent bars within a group.
# It's also used as the base for calculating the actual plotted bar width.
# Original 2 bars had centers 1.0 apart. For 3 bars, we need a smaller distance.
# A value of 0.64 for distance between centers, with a scaling factor of 0.8 for bar width,
# results in an actual bar width of ~0.51, and a group span of ~1.79, similar to original's ~1.76.
n = 3 # Number of models
width = 0.64 # Distance between centers of adjacent bars in a group
bar_width_plotting_factor = 0.8 # Bar takes 80% of the space defined by 'width'
# Colors and hatches
edgecolors = ["dimgrey", "#63B8B6", "tomato"] # Added color for PQ 5
hatches = ["/////", "xxxxx", "\\\\\\\\\\"] # Added hatch for PQ 5
labels = ["BM25", "PQ Compressed", "Ours"] # Added PQ 5
# Create plots for each dataset
for i, dataset in enumerate(datasets):
ax = axs[i]
# Get data for this dataset and convert to percentages
em_values = [
acc_data.loc[0, f"{dataset} Exact Match"] * 100,
acc_data.loc[1, f"{dataset} Exact Match"] * 100,
acc_data.loc[2, f"{dataset} Exact Match"] * 100 # Added PQ 5 EM data
]
f1_values = [
acc_data.loc[0, f"{dataset} F1"] * 100,
acc_data.loc[1, f"{dataset} F1"] * 100,
acc_data.loc[2, f"{dataset} F1"] * 100 # Added PQ 5 F1 data
]
# Define x positions for bars
# For EM: center - width, center, center + width
# For F1: center - width, center, center + width
group_centers = [1.0, 3.0] # Centers for EM and F1 groups
bar_offsets = [-width, 0, width]
# Plot all bars on the same axis
for metric_idx, metric_group_center in enumerate(group_centers):
values_to_plot = em_values if metric_idx == 0 else f1_values
for j, model_label in enumerate(labels):
x_pos = metric_group_center + bar_offsets[j]
bar_value = values_to_plot[j]
ax.bar(
x_pos,
bar_value,
width=width * bar_width_plotting_factor, # Use the new factor for bar width
color="white",
edgecolor=edgecolors[j],
hatch=hatches[j],
linewidth=1.5,
label=model_label if i == 0 and metric_idx == 0 else None # Label only once
)
# Add value on top of bar
ax.text(x_pos, bar_value + (0.1 if dataset == "GPQA" else 0.1),
f"{bar_value:.1f}", ha='center', va='bottom',
fontsize=9, fontweight='bold') # Reduced fontsize for text on bars
# Set x-ticks and labels
ax.set_xticks(group_centers) # Position ticks at the center of each group
xticklabels = ax.set_xticklabels(metrics, fontsize=12)
# Now, shift these labels slightly to the right
# Adjust this value to control the amount of shift (in data coordinates)
# Given your group_centers are 1.0 and 3.0, a small value like 0.05 to 0.15 might be appropriate.
# horizontal_shift = 0.7 # Try adjusting this value
# for label in xticklabels:
# # Get the current x position (which is the tick location)
# current_x_pos = label.get_position()[0]
# # Set the new x position by adding the shift
# label.set_position((current_x_pos + horizontal_shift, label.get_position()[1]))
# # Ensure the label remains horizontally centered on this new x position
# # (set_xticklabels defaults to 'center', so this re-affirms it if needed)
# label.set_horizontalalignment('center')
# Set title
ax.set_title(dataset, fontsize=14)
# Set y-label for all subplots
if i == 0:
ax.set_ylabel("Accuracy (\%)", fontsize=12, fontweight="bold")
else:
# Hide y-tick labels for non-first subplots to save space
ax.tick_params(axis='y', labelsize=10)
# Set y-limits based on data range
all_values = em_values + f1_values
max_val = max(all_values)
min_val = min(all_values)
# Special handling for GPQA which has very low values
if dataset == "GPQA":
ax.set_ylim(0, 10.0) # Set a fixed range for GPQA
else:
# Reduce the extra space above the bars
ax.set_ylim(min_val * 0.9, max_val * 1.1) # Adjusted upper limit for text
# Format y-ticks as percentages
ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda y, _: ' {:.0f}'.format(y)))
# Set x-limits to properly space the bars with less blank space
# ax.set_xlim(group_centers[0] - total_width, group_centers[1] + total_width)
# Set xlim to be similar to original (0,4) for group_centers (1,3) => margin of 1.0
ax.set_xlim(group_centers[0] - 1.0, group_centers[1] + 1.0)
# Add a box around the subplot
# for spine in ax.spines.values():
# spine.set_visible(True)
# spine.set_linewidth(1.0)
# Add legend to first subplot
if i == 0:
ax.legend(
bbox_to_anchor=(2.21, 1.35), # Adjusted anchor if needed
ncol=3, # Changed to 3 columns for three labels
loc="upper center",
labelspacing=0.1,
edgecolor="black",
facecolor="white",
framealpha=1,
shadow=False,
fancybox=False,
handlelength=1.0,
handletextpad=0.6,
columnspacing=0.8,
prop={"weight": "bold", "size": 12},
)
# Save figure with tight layout but no additional padding
plt.savefig(FIGURE_PATH + "/accuracy_comparison.pdf", bbox_inches='tight', pad_inches=0.05)
plt.show()

View File

@@ -0,0 +1,309 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Motto: Were It to Benefit My Country, I Would Lay Down My Life!
# \file: /hnsw_degree_visit_plot_binned_academic.py
# \brief: Generates a binned bar plot of HNSW node average per-query visit probability
# per degree bin, styled for academic publications, with caching.
# Author: raphael hao (Original script by user, styling and caching adapted by Gemini)
# %%
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import re
from collections import Counter
import os # For robust filepath manipulation
import math # For calculating scaling factor
import pickle # For caching data
# %%
# --- Matplotlib parameters for academic paper style (from reference) ---
plt.rcParams["font.family"] = "Helvetica"
plt.rcParams["ytick.direction"] = "in"
plt.rcParams["hatch.linewidth"] = 1.5
plt.rcParams["font.weight"] = "bold"
plt.rcParams["axes.labelweight"] = "bold"
plt.rcParams["text.usetex"] = True # Use LaTeX for text rendering (if available)
# --- Define styles from reference ---
edgecolors_ref = ["dimgrey", "#63B8B6", "tomato", "silver", "slategray"]
# %%
# --- File Paths ---
degree_file = '/opt/dlami/nvme/scaling_out/indices/rpj_wiki/facebook/contriever-msmarco/hnsw/degree_distribution.txt'
visit_log_file = './re.log'
output_image_file = './paper_plot/figures/hnsw_visit_count_per_degree_corrected.pdf'
# --- CACHE FILE PATH: Keep this consistent ---
CACHE_FILE_PATH = './binned_plot_data_cache.pkl'
# --- Configuration ---
# Set to True to bypass cache and force recomputation.
# Otherwise, delete CACHE_FILE_PATH manually to force recomputation.
FORCE_RECOMPUTE = False
NUMBER_OF_QUERIES = 1000.0 # Number of queries the visit_counts are based on
# Create directory for figures if it doesn't exist
output_dir = os.path.dirname(output_image_file)
if output_dir and not os.path.exists(output_dir):
os.makedirs(output_dir)
print(f"Created directory: {output_dir}")
# %%
# --- Attempt to load data from cache or compute ---
df_plot_data = None
bin_size_for_plot = None # Will hold the bin_size associated with df_plot_data
if not FORCE_RECOMPUTE and os.path.exists(CACHE_FILE_PATH):
try:
with open(CACHE_FILE_PATH, 'rb') as f:
cache_content = pickle.load(f)
df_plot_data = cache_content['data']
bin_size_for_plot = cache_content['bin_size']
# Basic validation of cached data
# Expecting 'average_visit_count_per_node_in_bin' (raw average over NUMBER_OF_QUERIES)
if not isinstance(df_plot_data, pd.DataFrame) or \
'degree_bin_label' not in df_plot_data.columns or \
'average_visit_count_per_node_in_bin' not in df_plot_data.columns or \
not isinstance(bin_size_for_plot, int):
print("Cached data is not in the expected format or missing 'average_visit_count_per_node_in_bin'. Recomputing.")
df_plot_data = None # Invalidate to trigger recomputation
else:
print(f"Successfully loaded binned data from cache: {CACHE_FILE_PATH}")
# --- Modify the label loaded from cache for display purpose ---
# This modification only happens when data is loaded from cache and meets specific conditions.
# Assumption: If the bin_size_for_plot in cache is 5,
# then the original label "0-4" actually represents nodes with degree 1-4 (because you guarantee no 0-degree nodes).
if df_plot_data is not None and 'degree_bin_label' in df_plot_data.columns and bin_size_for_plot == 5:
# Check if "0-4" label exists
if '0-4' in df_plot_data['degree_bin_label'].values:
# Use .loc to ensure the modification is on the original DataFrame
df_plot_data.loc[df_plot_data['degree_bin_label'] == '0-4', 'degree_bin_label'] = '1-4'
print("Modified degree_bin_label from '0-4' to '1-4' for display purpose.")
except Exception as e:
print(f"Error loading from cache: {e}. Recomputing.")
df_plot_data = None # Invalidate to trigger recomputation
if df_plot_data is None:
print("Cache not found, invalid, or recompute forced. Computing data from scratch...")
# --- 1. Read Degree Distribution File ---
degrees_data = []
try:
with open(degree_file, 'r') as f:
for i, line in enumerate(f):
line_stripped = line.strip()
if line_stripped:
degrees_data.append({'node_id': i, 'degree': int(line_stripped)})
except FileNotFoundError:
print(f"Error: Degree file '{degree_file}' not found. Using dummy data for degrees.")
degrees_data = [{'node_id': i, 'degree': (i % 20) + 1 } for i in range(200)]
degrees_data.extend([{'node_id': 200+i, 'degree': i} for i in range(58, 67)]) # For 60-64 bin
degrees_data.extend([{'node_id': 300+i, 'degree': (i % 5)+1} for i in range(10)]) # Low degrees
degrees_data.extend([{'node_id': 400+i, 'degree': 80 + (i%5)} for i in range(10)]) # High degrees
if not degrees_data:
print(f"Critical Error: No data loaded or generated for degrees. Exiting.")
exit()
df_degrees = pd.DataFrame(degrees_data)
print(f"Successfully loaded/generated {len(df_degrees)} degree entries.")
# --- 2. Read Visit Log File and Count Frequencies ---
visit_counts = Counter()
node_id_pattern = re.compile(r"Vis(i)?ted node: (\d+)")
try:
with open(visit_log_file, 'r') as f_log:
for line_num, line in enumerate(f_log, 1):
match = node_id_pattern.search(line)
if match:
try:
node_id = int(match.group(2))
visit_counts[node_id] += 1 # Increment visit count for the node
except ValueError:
print(f"Warning: Non-integer node_id in log '{visit_log_file}' line {line_num}: {line.strip()}")
except FileNotFoundError:
print(f"Warning: Visit log file '{visit_log_file}' not found. Using dummy visit counts.")
if not df_degrees.empty:
for node_id_val in df_degrees['node_id'].sample(frac=0.9, random_state=1234): # Seed for reproducibility
degree_val = df_degrees[df_degrees['node_id'] == node_id_val]['degree'].iloc[0]
# Generate visit counts to test different probability magnitudes
if node_id_val % 23 == 0: # Very low probability
lambda_val = 0.0005 * (100 / (max(1,degree_val) + 1)) # avg visits over 1k queries
elif node_id_val % 11 == 0: # Low probability
lambda_val = 0.05 * (100 / (max(1,degree_val) + 1))
elif node_id_val % 5 == 0: # Moderate probability
lambda_val = 2.5 * (100 / (max(1,degree_val) + 1))
else: # Higher probability (but still < 1000 visits for a single node usually)
lambda_val = 50 * (100 / (max(1,degree_val) + 1))
visit_counts[node_id_val] = np.random.poisson(lambda_val)
if visit_counts[node_id_val] < 0: visit_counts[node_id_val] = 0
if not visit_counts:
print(f"Warning: No visit data parsed/generated. Plot may show zero visits.")
df_visits = pd.DataFrame(columns=['node_id', 'visit_count'])
else:
df_visits_list = [{'node_id': nid, 'visit_count': count} for nid, count in visit_counts.items()]
df_visits = pd.DataFrame(df_visits_list)
print(f"Parsed/generated {len(df_visits)} unique visited nodes, totaling {sum(visit_counts.values())} visits (simulated over {NUMBER_OF_QUERIES} queries).")
# --- 3. Merge Degree Data with Visit Data ---
df_merged = pd.merge(df_degrees, df_visits, on='node_id', how='left')
df_merged['visit_count'] = df_merged['visit_count'].fillna(0).astype(float) # visit_count is total over NUMBER_OF_QUERIES
print(f"Merged data contains {len(df_merged)} entries.")
# --- 5. Binning Degrees and Calculating Average Visit Count per Node in Bin (over NUMBER_OF_QUERIES) ---
current_bin_size = 5
bin_size_for_plot = current_bin_size
if not df_degrees.empty:
print(f"\nBinning degrees into groups of {current_bin_size} for average visit count calculation...")
df_merged_with_bins = df_merged.copy()
df_merged_with_bins['degree_bin_start'] = (df_merged_with_bins['degree'] // current_bin_size) * current_bin_size
df_binned_analysis = df_merged_with_bins.groupby('degree_bin_start').agg(
total_visit_count_in_bin=('visit_count', 'sum'),
node_count_in_bin=('node_id', 'nunique')
).reset_index()
# This is the average number of times a node in this bin was visited over NUMBER_OF_QUERIES queries.
# This value is what gets cached.
df_binned_analysis['average_visit_count_per_node_in_bin'] = 0.0
df_binned_analysis.loc[df_binned_analysis['node_count_in_bin'] > 0, 'average_visit_count_per_node_in_bin'] = \
df_binned_analysis['total_visit_count_in_bin'] / df_binned_analysis['node_count_in_bin']
df_binned_analysis['degree_bin_label'] = df_binned_analysis['degree_bin_start'].astype(str) + '-' + \
(df_binned_analysis['degree_bin_start'] + current_bin_size - 1).astype(str)
bin_to_drop_label = '60-64'
original_length = len(df_binned_analysis)
df_plot_data_intermediate = df_binned_analysis[df_binned_analysis['degree_bin_label'] != bin_to_drop_label].copy()
if len(df_plot_data_intermediate) < original_length:
print(f"\nManually dropped the bin: '{bin_to_drop_label}'")
else:
print(f"\nNote: Bin '{bin_to_drop_label}' not found for dropping or already removed.")
df_plot_data = df_plot_data_intermediate
print(f"\nBinned data (average visit count per node in bin over {NUMBER_OF_QUERIES} queries) for plotting prepared:")
print(df_plot_data[['degree_bin_label', 'average_visit_count_per_node_in_bin']].head())
if df_plot_data is not None and not df_plot_data.empty:
try:
with open(CACHE_FILE_PATH, 'wb') as f:
pickle.dump({'data': df_plot_data, 'bin_size': bin_size_for_plot}, f)
print(f"Saved computed binned data to cache: {CACHE_FILE_PATH}")
except Exception as e:
print(f"Error saving data to cache: {e}")
elif df_plot_data is None or df_plot_data.empty:
print("Computed data for binned plot is empty, not saving to cache.")
else:
print("Degree data (df_degrees) is empty. Cannot perform binning.")
df_plot_data = pd.DataFrame()
bin_size_for_plot = current_bin_size
# %%
# --- 6. Plotting (Binned Bar Chart - Academic Style) ---
if df_plot_data is not None and not df_plot_data.empty and 'average_visit_count_per_node_in_bin' in df_plot_data.columns:
base_name, ext = os.path.splitext(output_image_file)
# --- OUTPUT PDF FILE NAME: Keep this consistent ---
binned_output_image_file = base_name + ext
fig, ax = plt.subplots(figsize=(6, 2.5)) # Adjusted figure size
df_plot_data_plotting = df_plot_data.copy()
# Calculate per-query probability: (avg visits over N queries) / N
df_plot_data_plotting['per_query_visit_probability'] = \
df_plot_data_plotting['average_visit_count_per_node_in_bin'] / NUMBER_OF_QUERIES
max_probability = df_plot_data_plotting['per_query_visit_probability'].max()
y_axis_values_to_plot = df_plot_data_plotting['per_query_visit_probability']
y_axis_label = r"Per-Query Node Visit Probability in Bin" # Base label
apply_scaling_to_label_and_values = False # Initialize flag
exponent_for_label_display = 0 # Initialize exponent
if pd.notna(max_probability) and max_probability > 0:
potential_exponent = math.floor(math.log10(max_probability))
if potential_exponent <= -4 or potential_exponent >= 0:
apply_scaling_to_label_and_values = True
exponent_for_label_display = potential_exponent
# No specific adjustment for potential_exponent >=0 here, it's handled by the general logic.
if apply_scaling_to_label_and_values:
y_axis_label = rf"Visit Probability ($\times 10^{{{exponent_for_label_display}}}$)"
y_axis_values_to_plot = df_plot_data_plotting['per_query_visit_probability'] / (10**exponent_for_label_display)
print(f"Plotting with Max per-query probability: {max_probability:.2e}, Exponent for label: {exponent_for_label_display}. Y-axis values scaled for plot.")
else:
print(f"Plotting with Max per-query probability: {max_probability:.2e}. Plotting direct probabilities without label scaling (exponent {potential_exponent} is within no-scale range [-3, -1]).")
elif pd.notna(max_probability) and max_probability == 0:
print("Max per-query probability is 0. Plotting direct probabilities (all zeros).")
else:
print(f"Max per-query probability is NaN or invalid ({max_probability}). Plotting direct probabilities without scaling if possible.")
ax.bar(
df_plot_data_plotting['degree_bin_label'],
y_axis_values_to_plot,
color='white',
edgecolor=edgecolors_ref[0],
linewidth=1.5,
width=0.8
)
ax.set_xlabel('Node Degree', fontsize=10.5, labelpad=6)
# MODIFIED LINE: Added labelpad to move the y-axis label to the left
ax.set_ylabel(y_axis_label, fontsize=10.5, labelpad=10)
ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda x, pos: f"{x:.0f}%"))
num_bins = len(df_plot_data_plotting)
if num_bins > 12:
ax.set_xticks(ax.get_xticks())
ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha="right", fontsize=9)
elif num_bins > 8:
ax.tick_params(axis='x', labelsize=9)
else:
ax.tick_params(axis='x', labelsize=10)
ax.tick_params(axis='y', labelsize=10)
padding_factor = 0.05
current_max_y_on_axis = y_axis_values_to_plot.max()
upper_y_limit = 0.1 # Default small upper limit
if pd.notna(current_max_y_on_axis):
if current_max_y_on_axis > 0:
# Adjust minimum visible range based on whether scaling was applied and the exponent
min_meaningful_limit = 0.01
if apply_scaling_to_label_and_values and exponent_for_label_display >= 0 : # Numbers on axis are smaller due to positive exponent scaling
min_meaningful_limit = 0.1 # If original numbers were e.g. 2500 (2.5 x 10^3), scaled axis is 2.5, 0.1 is fine
elif not apply_scaling_to_label_and_values and pd.notna(max_probability) and max_probability >=1: # Direct large probabilities
min_meaningful_limit = 1 # If max prob is 2.5 (250%), axis value 2.5, needs larger base limit
upper_y_limit = max(min_meaningful_limit, current_max_y_on_axis * (1 + padding_factor))
else: # current_max_y_on_axis is 0
upper_y_limit = 0.1
ax.set_ylim(0, upper_y_limit)
else:
ax.set_ylim(0, 1.0) # Default for empty or NaN data
plt.tight_layout()
plt.savefig(binned_output_image_file, bbox_inches="tight", dpi=300)
print(f"Binned bar chart saved to {binned_output_image_file}")
plt.show()
plt.close(fig)
else:
if df_plot_data is None:
print("Data for plotting (df_plot_data) is None. Skipping plot generation.")
elif df_plot_data.empty:
print("Data for plotting (df_plot_data) is empty. Skipping plot generation.")
elif 'average_visit_count_per_node_in_bin' not in df_plot_data.columns:
print("Essential column 'average_visit_count_per_node_in_bin' is missing in df_plot_data. Skipping plot generation.")
# %%
print("Script finished.")

7
research/paper_plot/b.md Normal file
View File

@@ -0,0 +1,7 @@
In this paper, we present LiteANN, a storage-efficient approximate nearest neighbor (ANN) search index optimized for resource-constrained personal devices. LiteANN combines a compact graph-based structure with an efficient on-the-fly recomputation strategy to enable fast and accurate retrieval wih minimal storage overhead. Our evaluation shows that LiteANN reduces index size to under 5% of the original raw data up to 50× smaller than standard indexes while achieving 90% top-3 recall in under 2 seconds on real-world question-answering benchmarks.

View File

@@ -0,0 +1,81 @@
import numpy as np
import os
# --- Configuration for Data Paths and Labels (Mirrors plotting script for consistency) ---
BIG_GRAPH_PATHS = [
"/opt/dlami/nvme/scaling_out/indices/rpj_wiki/facebook/contriever-msmarco/hnsw/",
"/opt/dlami/nvme/scaling_out/embeddings/facebook/contriever-msmarco/rpj_wiki/1-shards/indices/99_4_degree_based_hnsw_IP_M32_efC256/",
"/opt/dlami/nvme/scaling_out/embeddings/facebook/contriever-msmarco/rpj_wiki/1-shards/indices/d9_hnsw_IP_M8_efC128/",
"/opt/dlami/nvme/scaling_out/embeddings/facebook/contriever-msmarco/rpj_wiki/1-shards/indices/half_edges_IP_M32_efC128/"
]
STATS_FILE_NAME = "degree_distribution.txt"
BIG_GRAPH_LABELS = [ # These will be used as keys in the cached file
"HNSW-Base",
"DegreeGuide",
"HNSW-D9",
"RandCut",
]
# Average degrees are static and can be directly used in the plotting script or also cached.
# For simplicity here, we'll focus on caching the dynamic degree arrays.
# BIG_GRAPH_AVG_DEG = [18, 9, 9, 9]
# --- Cache File Configuration ---
DATA_CACHE_DIR = "./paper_plot/data/"
CACHE_FILE_NAME = "big_graph_degree_data.npz" # Using .npz for multiple arrays
def create_degree_data_cache():
"""
Reads degree distribution data from specified text files and saves it
into a compressed NumPy (.npz) cache file.
"""
os.makedirs(DATA_CACHE_DIR, exist_ok=True)
cache_file_path = os.path.join(DATA_CACHE_DIR, CACHE_FILE_NAME)
cached_data = {}
print(f"Starting data caching process for {len(BIG_GRAPH_PATHS)} graph types...")
for i, base_path in enumerate(BIG_GRAPH_PATHS):
method_label = BIG_GRAPH_LABELS[i]
degree_file_path = os.path.join(base_path, STATS_FILE_NAME)
print(f"Processing: {method_label} from {degree_file_path}")
try:
# Load degrees as integers
degrees = np.loadtxt(degree_file_path, dtype=int)
if degrees.size == 0:
print(f" [WARN] Degree file is empty: {degree_file_path}. Storing as empty array for {method_label}.")
# Store an empty array or handle as needed. For npz, an empty array is fine.
cached_data[method_label] = np.array([], dtype=int)
else:
# Store the loaded degrees array with the method label as the key
cached_data[method_label] = degrees
print(f" [INFO] Loaded {len(degrees)} degrees for {method_label}. Max degree: {np.max(degrees) if degrees.size > 0 else 'N/A'}")
except FileNotFoundError:
print(f" [ERROR] Degree file not found: {degree_file_path}. Skipping {method_label}.")
# Optionally store a placeholder or skip. For robustness, store None or an empty array.
# Storing None might require special handling when loading. Empty array is safer for np.load.
cached_data[method_label] = np.array([], dtype=int) # Store empty array if file not found
except Exception as e:
print(f" [ERROR] An error occurred loading {degree_file_path} for {method_label}: {e}")
cached_data[method_label] = np.array([], dtype=int) # Store empty array on other errors
if not cached_data:
print("[ERROR] No data was successfully processed or loaded. Cache file will not be created.")
return
try:
# Save all collected degree arrays into a single .npz file.
# Using savez_compressed for potentially smaller file size.
np.savez_compressed(cache_file_path, **cached_data)
print(f"\n[SUCCESS] Degree distribution data successfully cached to: {os.path.abspath(cache_file_path)}")
print("Cached arrays (keys):", list(cached_data.keys()))
except Exception as e:
print(f"\n[ERROR] Failed to save data to cache file {cache_file_path}: {e}")
if __name__ == "__main__":
print("--- Degree Distribution Data Caching Script ---")
create_degree_data_cache()
print("--- Caching script finished. ---")

View File

@@ -0,0 +1,4 @@
Model,NQ Exact Match,NQ F1,TriviaQA Exact Match,TriviaQA F1,GPQA Exact Match,GPQA F1,HotpotQA Exact Match,HotpotQA F1
BM25,0.192,0.277,0.406,0.474,0.020089,0.04524,0.162,0.239
PQ 5,0.2075,0.291,0.422,0.495,0.0201,0.0445,0.148,0.219
Ours,0.265,0.361,0.533,0.604,0.02008,0.0452,0.182,0.2729
1 Model NQ Exact Match NQ F1 TriviaQA Exact Match TriviaQA F1 GPQA Exact Match GPQA F1 HotpotQA Exact Match HotpotQA F1
2 BM25 0.192 0.277 0.406 0.474 0.020089 0.04524 0.162 0.239
3 PQ 5 0.2075 0.291 0.422 0.495 0.0201 0.0445 0.148 0.219
4 Ours 0.265 0.361 0.533 0.604 0.02008 0.0452 0.182 0.2729

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1296720e79196bbdf38f051043c1b054667803726a24036c0b6a87cedb204ea5
size 227482438

View File

@@ -0,0 +1,21 @@
2,1,512,1024,0.541,0.326,1.659509202
2,2,512,1024,0.979,0.621,1.576489533
2,4,512,1024,1.846,0.977,1.889457523
2,8,512,1024,3.575,1.943,1.83993824
2,16,512,1024,7.035,3.733,1.884543263
2,32,512,1024,15.655,8.517,1.838088529
2,64,512,1024,32.772,17.43,1.88020654
4,1,512,1024,2.675,1.38,1.938405797
4,2,512,1024,5.397,2.339,2.307396323
4,4,512,1024,10.672,4.944,2.158576052
4,8,512,1024,21.061,9.266,2.272933305
4,16,512,1024,46.332,18.334,2.527108105
4,32,512,1024,99.607,36.156,2.754923111
4,64,512,1024,186.348,72.356,2.575432583
8,1,512,1024,7.325,4.087,1.792268167
8,2,512,1024,14.109,7.491,1.883460152
8,4,512,1024,28.499,14.013,2.033754371
8,8,512,1024,65.222,27.453,2.375769497
8,16,512,1024,146.294,52.55,2.783901047
8,32,512,1024,277.099,103.61,2.674442621
8,64,512,1024,512.979,208.36,2.461984066
1 2 1 512 1024 0.541 0.326 1.659509202
2 2 2 512 1024 0.979 0.621 1.576489533
3 2 4 512 1024 1.846 0.977 1.889457523
4 2 8 512 1024 3.575 1.943 1.83993824
5 2 16 512 1024 7.035 3.733 1.884543263
6 2 32 512 1024 15.655 8.517 1.838088529
7 2 64 512 1024 32.772 17.43 1.88020654
8 4 1 512 1024 2.675 1.38 1.938405797
9 4 2 512 1024 5.397 2.339 2.307396323
10 4 4 512 1024 10.672 4.944 2.158576052
11 4 8 512 1024 21.061 9.266 2.272933305
12 4 16 512 1024 46.332 18.334 2.527108105
13 4 32 512 1024 99.607 36.156 2.754923111
14 4 64 512 1024 186.348 72.356 2.575432583
15 8 1 512 1024 7.325 4.087 1.792268167
16 8 2 512 1024 14.109 7.491 1.883460152
17 8 4 512 1024 28.499 14.013 2.033754371
18 8 8 512 1024 65.222 27.453 2.375769497
19 8 16 512 1024 146.294 52.55 2.783901047
20 8 32 512 1024 277.099 103.61 2.674442621
21 8 64 512 1024 512.979 208.36 2.461984066

View File

@@ -0,0 +1,9 @@
Dataset,Metric,Original,original + batch,original + two_level,original + two_level + batch
NQ,Latency,6.9,5.8,4.2,3.7
NQ,SpeedUp,1,1.18965517,1.64285714,1.86486486
TriviaQA,Latency,17.054,14.542,12.046,10.83
TriviaQA,SpeedUp,1,1.17274103,1.41573967,1.57469990
GPQA,Latency,9.164,7.639,6.798,5.77
GPQA,SpeedUp,1,1.19963346,1.34804354,1.58821490
HotpotQA,Latency,60.279,39.827,50.664,29.868
HotpotQA,SpeedUp,1,1.51352098,1.18977972,2.01817999
1 Dataset Metric Original original + batch original + two_level original + two_level + batch
2 NQ Latency 6.9 5.8 4.2 3.7
3 NQ SpeedUp 1 1.18965517 1.64285714 1.86486486
4 TriviaQA Latency 17.054 14.542 12.046 10.83
5 TriviaQA SpeedUp 1 1.17274103 1.41573967 1.57469990
6 GPQA Latency 9.164 7.639 6.798 5.77
7 GPQA SpeedUp 1 1.19963346 1.34804354 1.58821490
8 HotpotQA Latency 60.279 39.827 50.664 29.868
9 HotpotQA SpeedUp 1 1.51352098 1.18977972 2.01817999

View File

@@ -0,0 +1,25 @@
Dataset,Hardware,Recall_target,HNSW,IVF,DiskANN,IVF-Disk,IVF-Recompute,Our,BM25,LLM_Gen_Time_1B,LLM_Gen_Time_3B,LLM_Gen_Time_7B
NQ,A10,85%,0.046,1.656,0.017,2.996,482.53,3.323,0.021,0.085,0.217,0.472
NQ,A10,90%,0.051,2.552,0.028,3.437,769.04,4.616,0,0.085,0.217,0.472
NQ,A10,95%,0.055,5.163,0.070,5.602,1436.26,19.494,0,0.085,0.217,0.472
NQ,MAC,85%,0,0,0.152,2.199,1535.10,7.971,0.033,0.316,0.717,1.468
NQ,MAC,90%,0,0,0.37,2.936,2446.60,13.843,0,0.316,0.717,1.468
NQ,MAC,95%,0,0,1.207,4.191,4569.29,44.363,0,0.316,0.717,1.468
TriviaQA,A10,85%,0.042,1.772,0.032,2.464,560.5,3.752,0.033,0.139,0.156,0.315
TriviaQA,A10,90%,0.043,3.541,0.057,3.651,997.81,5.777,0,0.139,0.156,0.315
TriviaQA,A10,95%,0.053,7.168,0.090,5.458,2005.33,20.944,0,0.139,0.156,0.315
TriviaQA,MAC,85%,0,0,0.481,1.875,1783.14787,8.889,0.036,0.325,0.692,1.415
TriviaQA,MAC,90%,0,0,0.984,2.639,3174.410301,17.145,0,0.325,0.692,1.415
TriviaQA,MAC,95%,0,0,1.578,3.884,6379.712245,47.909,0,0.325,0.692,1.415
GPQA,A10,85%,0.041,0.134,0.024,0.048,40.16,1.897,0.137,0.443,0.396,0.651
GPQA,A10,90%,0.042,0.174,0.034,0.06,54.71,1.733,0,0.443,0.396,0.651
GPQA,A10,95%,0.045,0.292,0.051,0.11,97.67,4.033,0,0.443,0.396,0.651
GPQA,MAC,85%,0,0,0.144,0.087,127.7707505,4.762,0.100,0.37,0.813,1.676
GPQA,MAC,90%,0,0,0.288,0.108,174.0647409,5.223,0,0.37,0.813,1.676
GPQA,MAC,95%,0,0,0.497,0.132,310.7380142,9.715,0,0.37,0.813,1.676
HotpotQA,A10,85%,0.044,2.519,0.054,4.048,724.26,10.358,0.70,0.144,0.196,0.420
HotpotQA,A10,90%,0.049,3.867,0.109,5.045,1173.67,15.515,0,0.144,0.196,0.420
HotpotQA,A10,95%,0.07,10.928,0.412,8.659,3079.57,61.757,0,0.144,0.196,0.420
HotpotQA,MAC,85%,0,0,0.974,2.844,2304.125187,23.636,0.052,0.144,0.196,0.420
HotpotQA,MAC,90%,0,0,1.913,3.542,3415.736201,44.803,0,0.144,0.196,0.420
HotpotQA,MAC,95%,0,0,5.783,6.764,9797.244043,140.62,0,0.144,0.196,0.420
1 Dataset Hardware Recall_target HNSW IVF DiskANN IVF-Disk IVF-Recompute Our BM25 LLM_Gen_Time_1B LLM_Gen_Time_3B LLM_Gen_Time_7B
2 NQ A10 85% 0.046 1.656 0.017 2.996 482.53 3.323 0.021 0.085 0.217 0.472
3 NQ A10 90% 0.051 2.552 0.028 3.437 769.04 4.616 0 0.085 0.217 0.472
4 NQ A10 95% 0.055 5.163 0.070 5.602 1436.26 19.494 0 0.085 0.217 0.472
5 NQ MAC 85% 0 0 0.152 2.199 1535.10 7.971 0.033 0.316 0.717 1.468
6 NQ MAC 90% 0 0 0.37 2.936 2446.60 13.843 0 0.316 0.717 1.468
7 NQ MAC 95% 0 0 1.207 4.191 4569.29 44.363 0 0.316 0.717 1.468
8 TriviaQA A10 85% 0.042 1.772 0.032 2.464 560.5 3.752 0.033 0.139 0.156 0.315
9 TriviaQA A10 90% 0.043 3.541 0.057 3.651 997.81 5.777 0 0.139 0.156 0.315
10 TriviaQA A10 95% 0.053 7.168 0.090 5.458 2005.33 20.944 0 0.139 0.156 0.315
11 TriviaQA MAC 85% 0 0 0.481 1.875 1783.14787 8.889 0.036 0.325 0.692 1.415
12 TriviaQA MAC 90% 0 0 0.984 2.639 3174.410301 17.145 0 0.325 0.692 1.415
13 TriviaQA MAC 95% 0 0 1.578 3.884 6379.712245 47.909 0 0.325 0.692 1.415
14 GPQA A10 85% 0.041 0.134 0.024 0.048 40.16 1.897 0.137 0.443 0.396 0.651
15 GPQA A10 90% 0.042 0.174 0.034 0.06 54.71 1.733 0 0.443 0.396 0.651
16 GPQA A10 95% 0.045 0.292 0.051 0.11 97.67 4.033 0 0.443 0.396 0.651
17 GPQA MAC 85% 0 0 0.144 0.087 127.7707505 4.762 0.100 0.37 0.813 1.676
18 GPQA MAC 90% 0 0 0.288 0.108 174.0647409 5.223 0 0.37 0.813 1.676
19 GPQA MAC 95% 0 0 0.497 0.132 310.7380142 9.715 0 0.37 0.813 1.676
20 HotpotQA A10 85% 0.044 2.519 0.054 4.048 724.26 10.358 0.70 0.144 0.196 0.420
21 HotpotQA A10 90% 0.049 3.867 0.109 5.045 1173.67 15.515 0 0.144 0.196 0.420
22 HotpotQA A10 95% 0.07 10.928 0.412 8.659 3079.57 61.757 0 0.144 0.196 0.420
23 HotpotQA MAC 85% 0 0 0.974 2.844 2304.125187 23.636 0.052 0.144 0.196 0.420
24 HotpotQA MAC 90% 0 0 1.913 3.542 3415.736201 44.803 0 0.144 0.196 0.420
25 HotpotQA MAC 95% 0 0 5.783 6.764 9797.244043 140.62 0 0.144 0.196 0.420

View File

@@ -0,0 +1,25 @@
Dataset,Hardware,Recall_target,HNSW,IVF,DiskANN,IVF-Disk,IVF-Recompute,Our,
NQ,A10,85%,0.046,1.656,0.017,2.996,482.53,4.243,
NQ,A10,90%,0.051,2.552,0.028,3.437,769.04,8.136,
NQ,A10,95%,0.055,5.163,0.070,5.602,1436.26,27.275,
NQ,MAC,85%,0,0,0.152,2.199,1535.10,10.672,
NQ,MAC,90%,0,0,0.37,2.936,2446.60,19.941,
NQ,MAC,95%,0,0,1.207,4.191,4569.29,61.383,
TriviaQA,A10,85%,0.042,1.772,0.032,2.464,560.5,5.612,
TriviaQA,A10,90%,0.043,3.541,0.057,3.651,997.81,10.737,
TriviaQA,A10,95%,0.053,7.168,0.090,5.458,2005.33,36.387,
TriviaQA,MAC,85%,0,0,0.481,1.875,1783.14787,12.825,
TriviaQA,MAC,90%,0,0,0.984,2.639,3174.410301,24.977,
TriviaQA,MAC,95%,0,0,1.578,3.884,6379.712245,85.734,
GPQA,A10,85%,0.041,0.134,0.024,0.048,40.16,2.269,
GPQA,A10,90%,0.042,0.174,0.034,0.06,54.71,3.200,
GPQA,A10,95%,0.045,0.292,0.051,0.11,97.67,7.445,
GPQA,MAC,85%,0,0,0.144,0.087,127.7707505,6.123,
GPQA,MAC,90%,0,0,0.288,0.108,174.0647409,8.507,
GPQA,MAC,95%,0,0,0.497,0.132,310.7380142,19.577,
HotpotQA,A10,85%,0.044,2.519,0.054,4.048,724.26,14.713,
HotpotQA,A10,90%,0.049,3.867,0.109,5.045,1173.67,33.561,
HotpotQA,A10,95%,0.07,10.928,0.412,8.659,3079.57,68.626,
HotpotQA,MAC,85%,0,0,0.974,2.844,2304.125187,34.783,
HotpotQA,MAC,90%,0,0,1.913,3.542,3415.736201,53.004,
HotpotQA,MAC,95%,0,0,5.783,6.764,9797.244043,95.413,
1 Dataset Hardware Recall_target HNSW IVF DiskANN IVF-Disk IVF-Recompute Our
2 NQ A10 85% 0.046 1.656 0.017 2.996 482.53 4.243
3 NQ A10 90% 0.051 2.552 0.028 3.437 769.04 8.136
4 NQ A10 95% 0.055 5.163 0.070 5.602 1436.26 27.275
5 NQ MAC 85% 0 0 0.152 2.199 1535.10 10.672
6 NQ MAC 90% 0 0 0.37 2.936 2446.60 19.941
7 NQ MAC 95% 0 0 1.207 4.191 4569.29 61.383
8 TriviaQA A10 85% 0.042 1.772 0.032 2.464 560.5 5.612
9 TriviaQA A10 90% 0.043 3.541 0.057 3.651 997.81 10.737
10 TriviaQA A10 95% 0.053 7.168 0.090 5.458 2005.33 36.387
11 TriviaQA MAC 85% 0 0 0.481 1.875 1783.14787 12.825
12 TriviaQA MAC 90% 0 0 0.984 2.639 3174.410301 24.977
13 TriviaQA MAC 95% 0 0 1.578 3.884 6379.712245 85.734
14 GPQA A10 85% 0.041 0.134 0.024 0.048 40.16 2.269
15 GPQA A10 90% 0.042 0.174 0.034 0.06 54.71 3.200
16 GPQA A10 95% 0.045 0.292 0.051 0.11 97.67 7.445
17 GPQA MAC 85% 0 0 0.144 0.087 127.7707505 6.123
18 GPQA MAC 90% 0 0 0.288 0.108 174.0647409 8.507
19 GPQA MAC 95% 0 0 0.497 0.132 310.7380142 19.577
20 HotpotQA A10 85% 0.044 2.519 0.054 4.048 724.26 14.713
21 HotpotQA A10 90% 0.049 3.867 0.109 5.045 1173.67 33.561
22 HotpotQA A10 95% 0.07 10.928 0.412 8.659 3079.57 68.626
23 HotpotQA MAC 85% 0 0 0.974 2.844 2304.125187 34.783
24 HotpotQA MAC 90% 0 0 1.913 3.542 3415.736201 53.004
25 HotpotQA MAC 95% 0 0 5.783 6.764 9797.244043 95.413

View File

@@ -0,0 +1,3 @@
Hardware,HNSW,IVF,DiskANN,IVF-Disk,IVF-Recompute,Our,BM25
RAM,190,171,10,0,0,0,0
Storage,185.4,171,240,171,0.5,5,59
1 Hardware HNSW IVF DiskANN IVF-Disk IVF-Recompute Our BM25
2 RAM 190 171 10 0 0 0 0
3 Storage 185.4 171 240 171 0.5 5 59

View File

@@ -0,0 +1,12 @@
Torch,8,55.592
Torch,16,75.439
Torch,32,110.025
Torch,64,186.496
Tutel,8,56.718
Tutel,16,82.121
Tutel,32,125.070
Tutel,64,216.191
BRT,8,56.725
BRT,16,79.291
BRT,32,93.180
BRT,64,118.923
1 Torch 8 55.592
2 Torch 16 75.439
3 Torch 32 110.025
4 Torch 64 186.496
5 Tutel 8 56.718
6 Tutel 16 82.121
7 Tutel 32 125.070
8 Tutel 64 216.191
9 BRT 8 56.725
10 BRT 16 79.291
11 BRT 32 93.180
12 BRT 64 118.923

View File

@@ -0,0 +1,6 @@
Disk cache size,0,2.5%(180G*2.5%),5%,8%,10%
Latency,,,,,
NQ,4.616,4.133,3.826,3.511,3.323
TriviaQA,5.777,4.979,4.553,4.141,3.916
GPQA,1.733,1.593,1.468,1.336,1.259
Hotpot,15.515,13.479,12.383,11.216,10.606
1 Disk cache size 0 2.5%(180G*2.5%) 5% 8% 10%
2 Latency
3 NQ 4.616 4.133 3.826 3.511 3.323
4 TriviaQA 5.777 4.979 4.553 4.141 3.916
5 GPQA 1.733 1.593 1.468 1.336 1.259
6 Hotpot 15.515 13.479 12.383 11.216 10.606

View File

@@ -0,0 +1,151 @@
import matplotlib
from matplotlib.axes import Axes
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.lines import Line2D
# plt.rcParams["font.family"] = "Helvetica"
plt.rcParams["ytick.direction"] = "in"
plt.rcParams["hatch.linewidth"] = 1
plt.rcParams["font.weight"] = "bold"
plt.rcParams["axes.labelweight"] = "bold"
plt.rcParams["text.usetex"] = True
plt.rcParams["font.family"] = "sans-serif" # Use generic sans-serif family
plt.rcParams['text.latex.preamble'] = r"""
\usepackage{helvet} % Use Helvetica font for text
\usepackage{sfmath} % Use sans-serif font for math
\renewcommand{\familydefault}{\sfdefault} % Set sans-serif as default text font
\usepackage[T1]{fontenc} % Recommended for font encoding
"""
# plt.rcParams['mathtext.fontset'] = 'dejavusans'
SAVE_PTH = "./paper_plot/figures"
font_size = 16
# New data in dictionary format
datasets = ["NQ", "TriviaQA", "GPQA", "Hotpot"]
cache_ratios = ["4.2G\n (0\%)", "8.7G\n (2.5\%)", "13.2G\n (5\%)", "18.6G\n (8\%)", "22.2G\n (10\%)"]
latency_data = {
"NQ": [4.616, 4.133, 3.826, 3.511, 3.323],
"TriviaQA": [5.777, 4.979, 4.553, 4.141, 3.916],
"GPQA": [1.733, 1.593, 1.468, 1.336, 1.259],
"Hotpot": [15.515, 13.479, 12.383, 11.216, 10.606],
}
cache_hit_counts = {
"NQ": [0, 14.81, 23.36, 31.99, 36.73],
"TriviaQA": [0, 18.55, 27.99, 37.06, 41.86],
"GPQA": [0, 10.99, 20.31, 29.71, 35.01],
"Hotpot": [0, 17.47, 26.91, 36.2, 41.06]
}
# Create the figure with 4 subplots in a 2x2 grid
fig, axes_grid = plt.subplots(2, 2, figsize=(7,6))
axes = axes_grid.flatten() # Flatten the 2x2 grid to a 1D array
# Bar style settings
width = 0.7
x = np.arange(len(cache_ratios))
# Define hatch patterns for different cache ratios
hatch_patterns = ['//', '//', '//', '//', '//']
# Find max cache hit value across all datasets for unified y-axis
all_hit_counts = []
for dataset in datasets:
all_hit_counts.extend(cache_hit_counts[dataset])
max_unified_hit = max(all_hit_counts) * 1.13
for i, dataset in enumerate(datasets):
latencies = latency_data[dataset]
hit_counts = cache_hit_counts[dataset]
for j, val in enumerate(latencies):
container = axes[i].bar(
x[j],
val,
width=width,
color="white",
edgecolor="black",
linewidth=1.0,
zorder=10,
)
axes[i].bar_label(
container,
[f"{val:.2f}"],
fontsize=10,
zorder=200,
fontweight="bold",
)
axes[i].set_title(dataset, fontsize=font_size)
axes[i].set_xticks(x)
axes[i].set_xticklabels(cache_ratios, fontsize=12, rotation=0, ha='center', fontweight="bold")
max_val_ratios = [1.35, 1.65, 1.45, 1.75]
max_val = max(latencies) * max_val_ratios[i]
axes[i].set_ylim(0, max_val)
axes[i].tick_params(axis='y', labelsize=12)
if i % 2 == 0:
axes[i].set_ylabel("Latency (s)", fontsize=font_size)
axes[i].yaxis.set_major_formatter(matplotlib.ticker.FormatStrFormatter('%.1f'))
ax2: Axes = axes[i].twinx()
ax2.plot(x, hit_counts,
linestyle='--',
marker='o',
markersize=6,
linewidth=1.5,
color='k',
markerfacecolor='none',
zorder=20)
ax2.set_ylim(0, max_unified_hit)
ax2.tick_params(axis='y', labelsize=12)
if i % 2 == 1:
ax2.set_ylabel(r"Cache Hit (\%)", fontsize=font_size)
for j, val in enumerate(hit_counts):
if val > 0:
ax2.annotate(f"{val:.1f}%",
(x[j], val),
textcoords="offset points",
xytext=(0, 5),
ha='center',
va='bottom',
fontsize=10,
fontweight='bold')
# Create legend for both plots
bar_patch = mpatches.Patch(facecolor='white', edgecolor='black', label='Latency')
line_patch = Line2D([0], [0], color='black', linestyle='--', label='Cache Hit Rate')
# --- MODIFICATION FOR LEGEND AT THE TOP ---
fig.legend(handles=[bar_patch, line_patch],
loc='upper center', # Position the legend at the upper center
bbox_to_anchor=(0.5, 0.995), # Anchor point (0.5 means horizontal center of figure,
# 0.97 means 97% from the bottom, so near the top)
ncol=3,
fontsize=font_size-2)
# --- END OF MODIFICATION ---
# Set common x-axis label - you might want to add this back if needed
# fig.text(0.5, 0.02, "Disk Cache Size", ha='center', fontsize=font_size, fontweight='bold') # Adjusted y for potential bottom label
# --- MODIFICATION FOR TIGHT LAYOUT ---
# Adjust rect to make space for the legend at the top.
# (left, bottom, right, top_for_subplots)
# We want subplots to occupy space from y=0 up to y=0.93 (or similar)
# leaving the top portion (0.93 to 1.0) for the legend.
plt.tight_layout(rect=(0, 0, 1, 0.93)) # Ensure subplots are below the legend
# --- END OF MODIFICATION ---
# Create directory if it doesn't exist (optional, good practice)
import os
if not os.path.exists(SAVE_PTH):
os.makedirs(SAVE_PTH)
plt.savefig(f"{SAVE_PTH}/disk_cache_latency.pdf", dpi=300) # Changed filename slightly for testing
print(f"Save to {SAVE_PTH}/disk_cache_latency.pdf")
# plt.show() # Optional: to display the plot

View File

Binary file not shown.

View File

Binary file not shown.

After

Width:  |  Height:  |  Size: 130 KiB

View File

Binary file not shown.

View File

Binary file not shown.

After

Width:  |  Height:  |  Size: 100 KiB

View File

Binary file not shown.

View File

Binary file not shown.

View File

Binary file not shown.

View File

Binary file not shown.

View File

Binary file not shown.

View File

Binary file not shown.

View File

Binary file not shown.

View File

Binary file not shown.

After

Width:  |  Height:  |  Size: 41 KiB

View File

Binary file not shown.

View File

Binary file not shown.

View File

Binary file not shown.

View File

Binary file not shown.

View File

Binary file not shown.

View File

Binary file not shown.

View File

Binary file not shown.

View File

Binary file not shown.

View File

Binary file not shown.

View File

@@ -0,0 +1,107 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Motto: Were It to Benefit My Country, I Would Lay Down My Life!
# \file: /gpu_utilization_plot.py
# \brief: Plots GPU throughput vs. batch size to show utilization with equally spaced x-axis.
# Author: AI Assistant
import numpy as np
import pandas as pd # Using pandas for data structuring, similar to example
from matplotlib import pyplot as plt
# Apply styling similar to the example script
plt.rcParams["font.family"] = "Helvetica"
plt.rcParams["ytick.direction"] = "in"
plt.rcParams["xtick.direction"] = "in"
# plt.rcParams["hatch.linewidth"] = 1.5 # Not used for line plots
plt.rcParams["font.weight"] = "bold"
plt.rcParams["axes.labelweight"] = "bold"
plt.rcParams["text.usetex"] = True # Enables LaTeX for text rendering
# New Benchmark data (4th set)
data = {
'batch_size': [1, 4, 8, 10, 16, 20, 32, 40, 64, 128, 256,],
'avg_time_s': [
0.0031, 0.0057, 0.0100, 0.0114, 0.0186, 0.0234,
0.0359, 0.0422, 0.0626, 0.1259, 0.2454,
],
'throughput_seq_s': [
318.10, 696.77, 798.95, 874.70, 859.58, 855.19,
890.80, 946.93, 1022.75, 1017.03, 1043.17,
]
}
benchmark_df = pd.DataFrame(data)
# Create the plot
# Increased width slightly for more x-axis labels
fig, ax = plt.subplots()
fig.set_size_inches(8, 5)
# Generate equally spaced x-coordinates (indices)
x_indices = np.arange(len(benchmark_df))
# Plotting throughput vs. batch size (using indices for x-axis)
ax.plot(
x_indices, # Use equally spaced indices for plotting
benchmark_df['throughput_seq_s'],
marker='o', # Add markers to data points
linestyle='-',
color="#63B8B6", # A color inspired by the example's 'edgecolors'
linewidth=2,
markersize=6,
# label="Model Throughput" # Label for legend if needed, but not showing legend by default
)
# Setting labels for axes
ax.set_xlabel("Batch Size", fontsize=14)
ax.set_ylabel("Throughput (sequences/second)", fontsize=14)
# Customizing Y-axis for the new data range:
# Start Y from 0 to include the anomalous low point and show full scale.
y_min_val = 200
# Round up y_max_val to the nearest 100, as max throughput > 1000
y_max_val = np.ceil(benchmark_df['throughput_seq_s'].max() / 100) * 100
ax.set_ylim((y_min_val, y_max_val))
# Set y-ticks every 100 units, ensuring the top tick is included.
ax.set_yticks(np.arange(y_min_val, y_max_val + 1, 100))
# Customizing X-axis for equally spaced ticks:
# Set tick positions to the indices
ax.set_xticks(x_indices)
# Set tick labels to the actual batch_size values
ax.set_xticklabels(benchmark_df['batch_size'])
ax.tick_params(axis='x', rotation=45, labelsize=10) # Rotate X-axis labels, fontsize 10
ax.tick_params(axis='y', labelsize=12)
# Add a light grid for better readability, common in academic plots
ax.grid(True, linestyle=':', linewidth=0.5, color='grey', alpha=0.7, zorder=0)
# Remove title (as requested)
# ax.set_title("GPU Throughput vs. Batch Size", fontsize=16) # Title would go here
# Optional: Add a legend if you have multiple lines or want to label the single line
# ax.legend(
# loc="center right", # Location might need adjustment due to data shape
# edgecolor="black",
# facecolor="white",
# framealpha=1.0,
# shadow=False,
# fancybox=False,
# prop={"weight": "bold", "size": 10}
# ).set_zorder(100)
# Adjust layout to prevent labels from being cut off
plt.tight_layout()
# Save the figure
output_filename = "./paper_plot/figures/gpu_throughput_vs_batch_size_equispaced.pdf"
plt.savefig(output_filename, bbox_inches="tight", dpi=300)
print(f"Plot saved to {output_filename}")
# Display the plot (optional, depending on environment)
plt.show()
# %%
# This is just to mimic the '%%' cell structure from the example.
# No actual code needed here for this script.

Some files were not shown because too many files have changed in this diff Show More