Compare commits
93 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
fc9c5cb39d | ||
|
|
8f2a1e87ea | ||
|
|
50caf65f28 | ||
|
|
1b48794ca8 | ||
|
|
4aef1d814e | ||
|
|
75ddcd6158 | ||
|
|
2a4df11f5c | ||
|
|
5eb893c62b | ||
|
|
d91ce2e94d | ||
|
|
5c2ff8a641 | ||
|
|
d4f474c9b7 | ||
|
|
170f7644e9 | ||
|
|
cd8b970eff | ||
|
|
52153bbb69 | ||
|
|
e1ae087207 | ||
|
|
48c5e12ac1 | ||
|
|
f8b5c97190 | ||
|
|
d038c81b8b | ||
|
|
29cbbbd0d6 | ||
|
|
179f30bc36 | ||
|
|
c4a0a68581 | ||
|
|
5c836ad08e | ||
|
|
673fd9b7cd | ||
|
|
84b24b233d | ||
|
|
499cdd7822 | ||
|
|
800d4cf111 | ||
|
|
b6d43f5fd9 | ||
|
|
3603cd5034 | ||
|
|
6df7893173 | ||
|
|
e64b599276 | ||
|
|
2dd59c4ba1 | ||
|
|
166986d5e6 | ||
|
|
a6aec68f32 | ||
|
|
ed27a127d5 | ||
|
|
d8b4ea7564 | ||
|
|
f0a2ef96b4 | ||
|
|
7d73c2c803 | ||
|
|
e8d2ecab03 | ||
|
|
32a374d094 | ||
|
|
d45c013806 | ||
|
|
9000a7083d | ||
|
|
8307555d54 | ||
|
|
20f2aece08 | ||
|
|
43eb4f9a1d | ||
|
|
5461b71d8c | ||
|
|
374db0ebb8 | ||
|
|
cea1f6f87c | ||
|
|
6c0e39372b | ||
|
|
2bec67d2b6 | ||
|
|
133e715832 | ||
|
|
95cf2f16e2 | ||
|
|
47a4c153eb | ||
|
|
faf5ae3533 | ||
|
|
a44dccecac | ||
|
|
9cf9358b9c | ||
|
|
de252fef31 | ||
|
|
9076bc27b8 | ||
|
|
50686c0819 | ||
|
|
1614203786 | ||
|
|
3d4c75a56c | ||
|
|
2684ee71dc | ||
|
|
1d321953ba | ||
|
|
b3cb251369 | ||
|
|
0a17d2c9d8 | ||
|
|
e3defbca84 | ||
|
|
e407f63977 | ||
|
|
7add391b2c | ||
|
|
efd6373b32 | ||
|
|
d502fa24b0 | ||
|
|
258a9a5c7f | ||
|
|
5d41ac6115 | ||
|
|
2a0fdb49b8 | ||
|
|
9d1b7231b6 | ||
|
|
ed3095b478 | ||
|
|
88eca75917 | ||
|
|
42de27e16a | ||
|
|
c083bda5b7 | ||
|
|
e86da38726 | ||
|
|
99076e38bc | ||
|
|
9698c1a02c | ||
|
|
851f0f04c3 | ||
|
|
ae16d9d888 | ||
|
|
6e1af2eb0c | ||
|
|
7695dd0d50 | ||
|
|
c2065473ad | ||
|
|
5f3870564d | ||
|
|
c214b2e33e | ||
|
|
2420c5fd35 | ||
|
|
f48f526f0a | ||
|
|
5dd74982ba | ||
|
|
e07aaf52a7 | ||
|
|
30e5f12616 | ||
|
|
594427bf87 |
11
.github/workflows/build-and-publish.yml
vendored
Normal file
11
.github/workflows/build-and-publish.yml
vendored
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
name: CI
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches: [ main ]
|
||||||
|
pull_request:
|
||||||
|
branches: [ main ]
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
build:
|
||||||
|
uses: ./.github/workflows/build-reusable.yml
|
||||||
167
.github/workflows/build-reusable.yml
vendored
Normal file
167
.github/workflows/build-reusable.yml
vendored
Normal file
@@ -0,0 +1,167 @@
|
|||||||
|
name: Reusable Build
|
||||||
|
|
||||||
|
on:
|
||||||
|
workflow_call:
|
||||||
|
inputs:
|
||||||
|
ref:
|
||||||
|
description: 'Git ref to build'
|
||||||
|
required: false
|
||||||
|
type: string
|
||||||
|
default: ''
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
build:
|
||||||
|
name: Build ${{ matrix.os }} Python ${{ matrix.python }}
|
||||||
|
strategy:
|
||||||
|
matrix:
|
||||||
|
include:
|
||||||
|
- os: ubuntu-22.04
|
||||||
|
python: '3.9'
|
||||||
|
- os: ubuntu-22.04
|
||||||
|
python: '3.10'
|
||||||
|
- os: ubuntu-22.04
|
||||||
|
python: '3.11'
|
||||||
|
- os: ubuntu-22.04
|
||||||
|
python: '3.12'
|
||||||
|
- os: ubuntu-22.04
|
||||||
|
python: '3.13'
|
||||||
|
- os: macos-latest
|
||||||
|
python: '3.9'
|
||||||
|
- os: macos-latest
|
||||||
|
python: '3.10'
|
||||||
|
- os: macos-latest
|
||||||
|
python: '3.11'
|
||||||
|
- os: macos-latest
|
||||||
|
python: '3.12'
|
||||||
|
- os: macos-latest
|
||||||
|
python: '3.13'
|
||||||
|
runs-on: ${{ matrix.os }}
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
with:
|
||||||
|
ref: ${{ inputs.ref }}
|
||||||
|
submodules: recursive
|
||||||
|
|
||||||
|
- name: Setup Python
|
||||||
|
uses: actions/setup-python@v5
|
||||||
|
with:
|
||||||
|
python-version: ${{ matrix.python }}
|
||||||
|
|
||||||
|
- name: Install uv
|
||||||
|
uses: astral-sh/setup-uv@v4
|
||||||
|
|
||||||
|
- name: Install system dependencies (Ubuntu)
|
||||||
|
if: runner.os == 'Linux'
|
||||||
|
run: |
|
||||||
|
sudo apt-get update
|
||||||
|
sudo apt-get install -y libomp-dev libboost-all-dev protobuf-compiler libzmq3-dev \
|
||||||
|
pkg-config libopenblas-dev patchelf libabsl-dev libaio-dev libprotobuf-dev
|
||||||
|
|
||||||
|
# Install Intel MKL for DiskANN
|
||||||
|
wget -q https://registrationcenter-download.intel.com/akdlm/IRC_NAS/79153e0f-74d7-45af-b8c2-258941adf58a/intel-onemkl-2025.0.0.940.sh
|
||||||
|
sudo sh intel-onemkl-2025.0.0.940.sh -a --components intel.oneapi.lin.mkl.devel --action install --eula accept -s
|
||||||
|
source /opt/intel/oneapi/setvars.sh
|
||||||
|
echo "MKLROOT=/opt/intel/oneapi/mkl/latest" >> $GITHUB_ENV
|
||||||
|
echo "LD_LIBRARY_PATH=/opt/intel/oneapi/mkl/latest/lib/intel64:$LD_LIBRARY_PATH" >> $GITHUB_ENV
|
||||||
|
|
||||||
|
- name: Install system dependencies (macOS)
|
||||||
|
if: runner.os == 'macOS'
|
||||||
|
run: |
|
||||||
|
brew install llvm libomp boost protobuf zeromq
|
||||||
|
|
||||||
|
- name: Install build dependencies
|
||||||
|
run: |
|
||||||
|
uv pip install --system scikit-build-core numpy swig Cython pybind11
|
||||||
|
if [[ "$RUNNER_OS" == "Linux" ]]; then
|
||||||
|
uv pip install --system auditwheel
|
||||||
|
else
|
||||||
|
uv pip install --system delocate
|
||||||
|
fi
|
||||||
|
|
||||||
|
- name: Build packages
|
||||||
|
run: |
|
||||||
|
# Build core (platform independent)
|
||||||
|
if [ "${{ matrix.os }}" == "ubuntu-latest" ]; then
|
||||||
|
cd packages/leann-core
|
||||||
|
uv build
|
||||||
|
cd ../..
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Build HNSW backend
|
||||||
|
cd packages/leann-backend-hnsw
|
||||||
|
if [ "${{ matrix.os }}" == "macos-latest" ]; then
|
||||||
|
CC=$(brew --prefix llvm)/bin/clang CXX=$(brew --prefix llvm)/bin/clang++ uv build --wheel --python python
|
||||||
|
else
|
||||||
|
uv build --wheel --python python
|
||||||
|
fi
|
||||||
|
cd ../..
|
||||||
|
|
||||||
|
# Build DiskANN backend
|
||||||
|
cd packages/leann-backend-diskann
|
||||||
|
if [ "${{ matrix.os }}" == "macos-latest" ]; then
|
||||||
|
CC=$(brew --prefix llvm)/bin/clang CXX=$(brew --prefix llvm)/bin/clang++ uv build --wheel --python python
|
||||||
|
else
|
||||||
|
uv build --wheel --python python
|
||||||
|
fi
|
||||||
|
cd ../..
|
||||||
|
|
||||||
|
# Build meta package (platform independent)
|
||||||
|
if [ "${{ matrix.os }}" == "ubuntu-latest" ]; then
|
||||||
|
cd packages/leann
|
||||||
|
uv build
|
||||||
|
cd ../..
|
||||||
|
fi
|
||||||
|
|
||||||
|
- name: Repair wheels (Linux)
|
||||||
|
if: runner.os == 'Linux'
|
||||||
|
run: |
|
||||||
|
# Repair HNSW wheel
|
||||||
|
cd packages/leann-backend-hnsw
|
||||||
|
if [ -d dist ]; then
|
||||||
|
auditwheel repair dist/*.whl -w dist_repaired
|
||||||
|
rm -rf dist
|
||||||
|
mv dist_repaired dist
|
||||||
|
fi
|
||||||
|
cd ../..
|
||||||
|
|
||||||
|
# Repair DiskANN wheel
|
||||||
|
cd packages/leann-backend-diskann
|
||||||
|
if [ -d dist ]; then
|
||||||
|
auditwheel repair dist/*.whl -w dist_repaired
|
||||||
|
rm -rf dist
|
||||||
|
mv dist_repaired dist
|
||||||
|
fi
|
||||||
|
cd ../..
|
||||||
|
|
||||||
|
- name: Repair wheels (macOS)
|
||||||
|
if: runner.os == 'macOS'
|
||||||
|
run: |
|
||||||
|
# Repair HNSW wheel
|
||||||
|
cd packages/leann-backend-hnsw
|
||||||
|
if [ -d dist ]; then
|
||||||
|
delocate-wheel -w dist_repaired -v dist/*.whl
|
||||||
|
rm -rf dist
|
||||||
|
mv dist_repaired dist
|
||||||
|
fi
|
||||||
|
cd ../..
|
||||||
|
|
||||||
|
# Repair DiskANN wheel
|
||||||
|
cd packages/leann-backend-diskann
|
||||||
|
if [ -d dist ]; then
|
||||||
|
delocate-wheel -w dist_repaired -v dist/*.whl
|
||||||
|
rm -rf dist
|
||||||
|
mv dist_repaired dist
|
||||||
|
fi
|
||||||
|
cd ../..
|
||||||
|
|
||||||
|
- name: List built packages
|
||||||
|
run: |
|
||||||
|
echo "📦 Built packages:"
|
||||||
|
find packages/*/dist -name "*.whl" -o -name "*.tar.gz" | sort
|
||||||
|
|
||||||
|
- name: Upload artifacts
|
||||||
|
uses: actions/upload-artifact@v4
|
||||||
|
with:
|
||||||
|
name: packages-${{ matrix.os }}-py${{ matrix.python }}
|
||||||
|
path: packages/*/dist/
|
||||||
126
.github/workflows/release-manual.yml
vendored
Normal file
126
.github/workflows/release-manual.yml
vendored
Normal file
@@ -0,0 +1,126 @@
|
|||||||
|
name: Release
|
||||||
|
|
||||||
|
on:
|
||||||
|
workflow_dispatch:
|
||||||
|
inputs:
|
||||||
|
version:
|
||||||
|
description: 'Version to release (e.g., 0.1.2)'
|
||||||
|
required: true
|
||||||
|
type: string
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
update-version:
|
||||||
|
name: Update Version
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
permissions:
|
||||||
|
contents: write
|
||||||
|
outputs:
|
||||||
|
commit-sha: ${{ steps.push.outputs.commit-sha }}
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Validate version
|
||||||
|
run: |
|
||||||
|
if ! [[ "${{ inputs.version }}" =~ ^[0-9]+\.[0-9]+\.[0-9]+$ ]]; then
|
||||||
|
echo "❌ Invalid version format"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
echo "✅ Version format valid"
|
||||||
|
|
||||||
|
- name: Update versions and push
|
||||||
|
id: push
|
||||||
|
run: |
|
||||||
|
# Check current version
|
||||||
|
CURRENT_VERSION=$(grep "^version" packages/leann-core/pyproject.toml | cut -d'"' -f2)
|
||||||
|
echo "Current version: $CURRENT_VERSION"
|
||||||
|
echo "Target version: ${{ inputs.version }}"
|
||||||
|
|
||||||
|
if [ "$CURRENT_VERSION" = "${{ inputs.version }}" ]; then
|
||||||
|
echo "⚠️ Version is already ${{ inputs.version }}, skipping update"
|
||||||
|
COMMIT_SHA=$(git rev-parse HEAD)
|
||||||
|
else
|
||||||
|
./scripts/bump_version.sh ${{ inputs.version }}
|
||||||
|
git config user.name "GitHub Actions"
|
||||||
|
git config user.email "actions@github.com"
|
||||||
|
git add packages/*/pyproject.toml
|
||||||
|
git commit -m "chore: release v${{ inputs.version }}"
|
||||||
|
git push origin main
|
||||||
|
COMMIT_SHA=$(git rev-parse HEAD)
|
||||||
|
echo "✅ Pushed version update: $COMMIT_SHA"
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "commit-sha=$COMMIT_SHA" >> $GITHUB_OUTPUT
|
||||||
|
|
||||||
|
build-packages:
|
||||||
|
name: Build packages
|
||||||
|
needs: update-version
|
||||||
|
uses: ./.github/workflows/build-reusable.yml
|
||||||
|
with:
|
||||||
|
ref: ${{ needs.update-version.outputs.commit-sha }}
|
||||||
|
|
||||||
|
publish:
|
||||||
|
name: Publish and Release
|
||||||
|
needs: [update-version, build-packages]
|
||||||
|
if: always() && needs.update-version.result == 'success' && needs.build-packages.result == 'success'
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
permissions:
|
||||||
|
contents: write
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
with:
|
||||||
|
ref: ${{ needs.update-version.outputs.commit-sha }}
|
||||||
|
|
||||||
|
- name: Download all artifacts
|
||||||
|
uses: actions/download-artifact@v4
|
||||||
|
with:
|
||||||
|
path: dist-artifacts
|
||||||
|
|
||||||
|
- name: Collect packages
|
||||||
|
run: |
|
||||||
|
mkdir -p dist
|
||||||
|
find dist-artifacts -name "*.whl" -exec cp {} dist/ \;
|
||||||
|
find dist-artifacts -name "*.tar.gz" -exec cp {} dist/ \;
|
||||||
|
|
||||||
|
echo "📦 Packages to publish:"
|
||||||
|
ls -la dist/
|
||||||
|
|
||||||
|
- name: Publish to PyPI
|
||||||
|
env:
|
||||||
|
TWINE_USERNAME: __token__
|
||||||
|
TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}
|
||||||
|
run: |
|
||||||
|
if [ -z "$TWINE_PASSWORD" ]; then
|
||||||
|
echo "❌ PYPI_API_TOKEN not configured!"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
pip install twine
|
||||||
|
twine upload dist/* --skip-existing --verbose
|
||||||
|
|
||||||
|
echo "✅ Published to PyPI!"
|
||||||
|
|
||||||
|
- name: Create release
|
||||||
|
run: |
|
||||||
|
# Check if tag already exists
|
||||||
|
if git rev-parse "v${{ inputs.version }}" >/dev/null 2>&1; then
|
||||||
|
echo "⚠️ Tag v${{ inputs.version }} already exists, skipping tag creation"
|
||||||
|
else
|
||||||
|
git tag "v${{ inputs.version }}"
|
||||||
|
git push origin "v${{ inputs.version }}"
|
||||||
|
echo "✅ Created and pushed tag v${{ inputs.version }}"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Check if release already exists
|
||||||
|
if gh release view "v${{ inputs.version }}" >/dev/null 2>&1; then
|
||||||
|
echo "⚠️ Release v${{ inputs.version }} already exists, skipping release creation"
|
||||||
|
else
|
||||||
|
gh release create "v${{ inputs.version }}" \
|
||||||
|
--title "Release v${{ inputs.version }}" \
|
||||||
|
--notes "🚀 Released to PyPI: https://pypi.org/project/leann/${{ inputs.version }}/" \
|
||||||
|
--latest
|
||||||
|
echo "✅ Created GitHub release v${{ inputs.version }}"
|
||||||
|
fi
|
||||||
|
env:
|
||||||
|
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -12,7 +12,6 @@ outputs/
|
|||||||
*.idx
|
*.idx
|
||||||
*.map
|
*.map
|
||||||
.history/
|
.history/
|
||||||
scripts/
|
|
||||||
lm_eval.egg-info/
|
lm_eval.egg-info/
|
||||||
demo/experiment_results/**/*.json
|
demo/experiment_results/**/*.json
|
||||||
*.jsonl
|
*.jsonl
|
||||||
|
|||||||
241
README.md
241
README.md
@@ -12,11 +12,11 @@
|
|||||||
The smallest vector index in the world. RAG Everything with LEANN!
|
The smallest vector index in the world. RAG Everything with LEANN!
|
||||||
</h2>
|
</h2>
|
||||||
|
|
||||||
LEANN is a revolutionary vector database that democratizes personal AI. Transform your laptop into a powerful RAG system that can index and search through millions of documents while using **[97% less storage]** than traditional solutions **without accuracy loss**.
|
LEANN is a revolutionary vector database that democratizes personal AI. Transform your laptop into a powerful RAG system that can index and search through millions of documents while using **97% less storage** than traditional solutions **without accuracy loss**.
|
||||||
|
|
||||||
LEANN achieves this through *graph-based selective recomputation* with *high-degree preserving pruning*, computing embeddings on-demand instead of storing them all. [Illustration →](#️-architecture--how-it-works) | [Paper →](https://arxiv.org/abs/2506.08276)
|
LEANN achieves this through *graph-based selective recomputation* with *high-degree preserving pruning*, computing embeddings on-demand instead of storing them all. [Illustration Fig →](#️-architecture--how-it-works) | [Paper →](https://arxiv.org/abs/2506.08276)
|
||||||
|
|
||||||
**Ready to RAG Everything?** Transform your laptop into a personal AI assistant that can search your **[file system](#process-any-documents-pdf-txt-md)**, **[emails](#search-your-entire-life)**, **[browser history](#time-machine-for-the-web)**, **[chat history](#wechat-detective)**, or external knowledge bases (i.e., 60M documents) - all on your laptop, with zero cloud costs and complete privacy.
|
**Ready to RAG Everything?** Transform your laptop into a personal AI assistant that can search your **[file system](#-personal-data-manager-process-any-documents-pdf-txt-md)**, **[emails](#-your-personal-email-secretary-rag-on-apple-mail)**, **[browser history](#-time-machine-for-the-web-rag-your-entire-browser-history)**, **[chat history](#-wechat-detective-unlock-your-golden-memories)**, or external knowledge bases (i.e., 60M documents) - all on your laptop, with zero cloud costs and complete privacy.
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@@ -26,7 +26,7 @@ LEANN achieves this through *graph-based selective recomputation* with *high-deg
|
|||||||
<img src="assets/effects.png" alt="LEANN vs Traditional Vector DB Storage Comparison" width="70%">
|
<img src="assets/effects.png" alt="LEANN vs Traditional Vector DB Storage Comparison" width="70%">
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
> **The numbers speak for themselves:** Index 60 million Wikipedia chunks in just 6GB instead of 201GB. From emails to browser history, everything fits on your laptop. [See detailed benchmarks for different applications below ↓](#storage-usage-comparison)
|
> **The numbers speak for themselves:** Index 60 million Wikipedia chunks in just 6GB instead of 201GB. From emails to browser history, everything fits on your laptop. [See detailed benchmarks for different applications below ↓](#storage-comparison)
|
||||||
|
|
||||||
|
|
||||||
🔒 **Privacy:** Your data never leaves your laptop. No OpenAI, no cloud, no "terms of service".
|
🔒 **Privacy:** Your data never leaves your laptop. No OpenAI, no cloud, no "terms of service".
|
||||||
@@ -37,8 +37,8 @@ LEANN achieves this through *graph-based selective recomputation* with *high-deg
|
|||||||
|
|
||||||
✨ **No Accuracy Loss:** Maintain the same search quality as heavyweight solutions while using 97% less storage.
|
✨ **No Accuracy Loss:** Maintain the same search quality as heavyweight solutions while using 97% less storage.
|
||||||
|
|
||||||
## Quick Start in 1 minute
|
## Installation
|
||||||
|
> `pip leann` coming soon!
|
||||||
```bash
|
```bash
|
||||||
git clone git@github.com:yichuan-w/LEANN.git leann
|
git clone git@github.com:yichuan-w/LEANN.git leann
|
||||||
cd leann
|
cd leann
|
||||||
@@ -47,36 +47,30 @@ git submodule update --init --recursive
|
|||||||
|
|
||||||
**macOS:**
|
**macOS:**
|
||||||
```bash
|
```bash
|
||||||
brew install llvm libomp boost protobuf zeromq
|
brew install llvm libomp boost protobuf zeromq pkgconf
|
||||||
export CC=$(brew --prefix llvm)/bin/clang
|
|
||||||
export CXX=$(brew --prefix llvm)/bin/clang++
|
|
||||||
|
|
||||||
# Install with HNSW backend (default, recommended for most users)
|
# Install with HNSW backend (default, recommended for most users)
|
||||||
uv sync
|
# Install uv first if you don't have it:
|
||||||
|
# curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||||
# Or add DiskANN backend if you want to test more options
|
# See: https://docs.astral.sh/uv/getting-started/installation/#installation-methods
|
||||||
uv sync --extra diskann
|
CC=$(brew --prefix llvm)/bin/clang CXX=$(brew --prefix llvm)/bin/clang++ uv sync
|
||||||
```
|
```
|
||||||
|
|
||||||
**Linux (Ubuntu/Debian):**
|
**Linux:**
|
||||||
```bash
|
```bash
|
||||||
sudo apt-get install libomp-dev libboost-all-dev protobuf-compiler libabsl-dev libmkl-full-dev libaio-dev libzmq3-dev
|
sudo apt-get install libomp-dev libboost-all-dev protobuf-compiler libabsl-dev libmkl-full-dev libaio-dev libzmq3-dev
|
||||||
|
|
||||||
# Install with HNSW backend (default, recommended for most users)
|
# Install with HNSW backend (default, recommended for most users)
|
||||||
uv sync
|
uv sync
|
||||||
|
|
||||||
# Or add DiskANN backend if you want to test more options
|
|
||||||
uv sync --extra diskann
|
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
**Ollama Setup (Recommended for full privacy):**
|
**Ollama Setup (Recommended for full privacy):**
|
||||||
|
|
||||||
> *You can skip this installation if you only want to use OpenAI API for generation.*
|
> *You can skip this installation if you only want to use OpenAI API for generation.*
|
||||||
|
|
||||||
|
|
||||||
*macOS:*
|
**macOS:**
|
||||||
|
|
||||||
First, [download Ollama for macOS](https://ollama.com/download/mac).
|
First, [download Ollama for macOS](https://ollama.com/download/mac).
|
||||||
|
|
||||||
@@ -85,7 +79,7 @@ First, [download Ollama for macOS](https://ollama.com/download/mac).
|
|||||||
ollama pull llama3.2:1b
|
ollama pull llama3.2:1b
|
||||||
```
|
```
|
||||||
|
|
||||||
*Linux:*
|
**Linux:**
|
||||||
```bash
|
```bash
|
||||||
# Install Ollama
|
# Install Ollama
|
||||||
curl -fsSL https://ollama.ai/install.sh | sh
|
curl -fsSL https://ollama.ai/install.sh | sh
|
||||||
@@ -97,9 +91,10 @@ ollama serve &
|
|||||||
ollama pull llama3.2:1b
|
ollama pull llama3.2:1b
|
||||||
```
|
```
|
||||||
|
|
||||||
## Dead Simple API
|
## Quick Start in 30s
|
||||||
|
|
||||||
Just 3 lines of code. Our declarative API makes RAG as easy as writing a config file:
|
Our declarative API makes RAG as easy as writing a config file.
|
||||||
|
[Try in this ipynb file →](demo.ipynb) [](https://colab.research.google.com/github/yichuan-w/LEANN/blob/main/demo.ipynb)
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from leann.api import LeannBuilder, LeannSearcher, LeannChat
|
from leann.api import LeannBuilder, LeannSearcher, LeannChat
|
||||||
@@ -130,24 +125,26 @@ response = chat.ask(
|
|||||||
)
|
)
|
||||||
```
|
```
|
||||||
|
|
||||||
**That's it.** No cloud setup, no API keys, no "fine-tuning". Just your data, your questions, your laptop.
|
## RAG on Everything!
|
||||||
|
|
||||||
[Try the interactive demo →](demo.ipynb)
|
LEANN supports RAG on various data sources including documents (.pdf, .txt, .md), Apple Mail, Google Search History, WeChat, and more.
|
||||||
|
|
||||||
## Wild Things You Can Do
|
### 📄 Personal Data Manager: Process Any Documents (.pdf, .txt, .md)!
|
||||||
|
|
||||||
LEANN supports RAGing a lot of data sources, like .pdf, .txt, .md, and also supports RAGing your WeChat, Google Search History, and more.
|
Ask questions directly about your personal PDFs, documents, and any directory containing your files!
|
||||||
|
|
||||||
### Process Any Documents (.pdf, .txt, .md)
|
<p align="center">
|
||||||
|
<img src="videos/paper_clear.gif" alt="LEANN Document Search Demo" width="600">
|
||||||
|
</p>
|
||||||
|
|
||||||
Above we showed the Python API, while this CLI script demonstrates the same concepts while directly processing PDFs and documents, and even any directory that stores your personal files!
|
The example below asks a question about summarizing two papers (uses default data in `examples/data`):
|
||||||
|
|
||||||
The following scripts use Ollama `qwen3:8b` by default, so you need `ollama pull qwen3:8b` first. For other models: `--llm openai --model gpt-4o` (requires `OPENAI_API_KEY` environment variable) or `--llm hf --model Qwen/Qwen3-4B`.
|
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Drop your PDFs, .txt, .md files into examples/data/
|
# Drop your PDFs, .txt, .md files into examples/data/
|
||||||
uv run ./examples/main_cli_example.py
|
uv run ./examples/main_cli_example.py
|
||||||
|
```
|
||||||
|
|
||||||
|
```
|
||||||
# Or use python directly
|
# Or use python directly
|
||||||
source .venv/bin/activate
|
source .venv/bin/activate
|
||||||
python ./examples/main_cli_example.py
|
python ./examples/main_cli_example.py
|
||||||
@@ -155,14 +152,17 @@ python ./examples/main_cli_example.py
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
**Works with any text format** - research papers, personal notes, presentations. Built with LlamaIndex for document parsing.
|
### 📧 Your Personal Email Secretary: RAG on Apple Mail!
|
||||||
|
|
||||||
### Search Your Entire Life
|
<p align="center">
|
||||||
|
<img src="videos/mail_clear.gif" alt="LEANN Email Search Demo" width="600">
|
||||||
|
</p>
|
||||||
|
|
||||||
|
**Note:** You need to grant full disk access to your terminal/VS Code in System Preferences → Privacy & Security → Full Disk Access.
|
||||||
```bash
|
```bash
|
||||||
python examples/mail_reader_leann.py
|
python examples/mail_reader_leann.py --query "What's the food I ordered by doordash or Uber eat mostly?"
|
||||||
# "What's the number of class recommend to take per semester for incoming EECS students?"
|
|
||||||
```
|
```
|
||||||
**90K emails → 14MB.** Finally, search your email like you search Google.
|
**780K email chunks → 78MB storage** Finally, search your email like you search Google.
|
||||||
|
|
||||||
<details>
|
<details>
|
||||||
<summary><strong>📋 Click to expand: Command Examples</strong></summary>
|
<summary><strong>📋 Click to expand: Command Examples</strong></summary>
|
||||||
@@ -195,12 +195,16 @@ Once the index is built, you can ask questions like:
|
|||||||
- "Show me emails about travel expenses"
|
- "Show me emails about travel expenses"
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
### Time Machine for the Web
|
### 🔍 Time Machine for the Web: RAG Your Entire Google Browser History!
|
||||||
|
|
||||||
|
<p align="center">
|
||||||
|
<img src="videos/google_clear.gif" alt="LEANN Browser History Search Demo" width="600">
|
||||||
|
</p>
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python examples/google_history_reader_leann.py
|
python examples/google_history_reader_leann.py --query "Tell me my browser history about machine learning?"
|
||||||
# "Tell me my browser history about machine learning system stuff?"
|
|
||||||
```
|
```
|
||||||
**38K browser entries → 6MB.** Your browser history becomes your personal search engine.
|
**38K browser entries → 6MB storage.** Your browser history becomes your personal search engine.
|
||||||
|
|
||||||
<details>
|
<details>
|
||||||
<summary><strong>📋 Click to expand: Command Examples</strong></summary>
|
<summary><strong>📋 Click to expand: Command Examples</strong></summary>
|
||||||
@@ -249,13 +253,17 @@ Once the index is built, you can ask questions like:
|
|||||||
|
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
### WeChat Detective
|
### 💬 WeChat Detective: Unlock Your Golden Memories!
|
||||||
|
|
||||||
|
<p align="center">
|
||||||
|
<img src="videos/wechat_clear.gif" alt="LEANN WeChat Search Demo" width="600">
|
||||||
|
</p>
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python examples/wechat_history_reader_leann.py
|
python examples/wechat_history_reader_leann.py --query "Show me all group chats about weekend plans"
|
||||||
# "Show me all group chats about weekend plans"
|
|
||||||
```
|
```
|
||||||
**400K messages → 64MB.** Search years of chat history in any language.
|
**400K messages → 64MB storage** Search years of chat history in any language.
|
||||||
|
|
||||||
|
|
||||||
<details>
|
<details>
|
||||||
<summary><strong>🔧 Click to expand: Installation Requirements</strong></summary>
|
<summary><strong>🔧 Click to expand: Installation Requirements</strong></summary>
|
||||||
@@ -266,7 +274,13 @@ First, you need to install the WeChat exporter:
|
|||||||
sudo packages/wechat-exporter/wechattweak-cli install
|
sudo packages/wechat-exporter/wechattweak-cli install
|
||||||
```
|
```
|
||||||
|
|
||||||
**Troubleshooting**: If you encounter installation issues, check the [WeChatTweak-CLI issues page](https://github.com/sunnyyoung/WeChatTweak-CLI/issues/41).
|
**Troubleshooting:**
|
||||||
|
- **Installation issues**: Check the [WeChatTweak-CLI issues page](https://github.com/sunnyyoung/WeChatTweak-CLI/issues/41)
|
||||||
|
- **Export errors**: If you encounter the error below, try restarting WeChat
|
||||||
|
```
|
||||||
|
Failed to export WeChat data. Please ensure WeChat is running and WeChatTweak is installed.
|
||||||
|
Failed to find or export WeChat data. Exiting.
|
||||||
|
```
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
<details>
|
<details>
|
||||||
@@ -386,46 +400,18 @@ Options:
|
|||||||
|
|
||||||
## Benchmarks
|
## Benchmarks
|
||||||
|
|
||||||
Run the comparison yourself:
|
|
||||||
```bash
|
|
||||||
python examples/compare_faiss_vs_leann.py
|
|
||||||
```
|
|
||||||
|
|
||||||
| System | Storage |
|
📊 **[Simple Example: Compare LEANN vs FAISS →](examples/compare_faiss_vs_leann.py)**
|
||||||
|--------|---------|
|
### Storage Comparison
|
||||||
| FAISS HNSW | 5.5 MB |
|
|
||||||
| LEANN | 0.5 MB |
|
|
||||||
| **Savings** | **91%** |
|
|
||||||
|
|
||||||
Same dataset, same hardware, same embedding model. LEANN just works better.
|
| System | DPR (2.1M) | Wiki (60M) | Chat (400K) | Email (780K) | Browser (38K) |
|
||||||
|
|--------|-------------|------------|-------------|--------------|---------------|
|
||||||
|
| Traditional vector database (e.g., FAISS) | 3.8 GB | 201 GB | 1.8 GB | 2.4 GB | 130 MB |
|
||||||
|
| LEANN | 324 MB | 6 GB | 64 MB | 79 MB | 6.4 MB |
|
||||||
|
| Savings| 91% | 97% | 97% | 97% | 95% |
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
### Storage Usage Comparison
|
|
||||||
|
|
||||||
| System | DPR (2.1M chunks) | RPJ-wiki (60M chunks) | Chat history (400K messages) | Apple emails (90K messages chunks) |Google Search History (38K entries)
|
|
||||||
|-----------------------|------------------|------------------------|-----------------------------|------------------------------|------------------------------|
|
|
||||||
| Traditional Vector DB(FAISS) | 3.8 GB | 201 GB | 1.8G | 305.8 MB |130.4 MB |
|
|
||||||
| **LEANN** | **324 MB** | **6 GB** | **64 MB** | **14.8 MB** |**6.4MB** |
|
|
||||||
| **Reduction** | **91% smaller** | **97% smaller** | **97% smaller** | **95% smaller** |**95% smaller** |
|
|
||||||
|
|
||||||
<!-- ### Memory Usage Comparison
|
|
||||||
|
|
||||||
| System j | DPR(2M docs) | RPJ-wiki(60M docs) | Chat history() |
|
|
||||||
| --------------------- | ---------------- | ---------------- | ---------------- |
|
|
||||||
| Traditional Vector DB(LLamaindex faiss) | x GB | x GB | x GB |
|
|
||||||
| **Leann** | **xx MB** | **x GB** | **x GB** |
|
|
||||||
| **Reduction** | **x%** | **x%** | **x%** |
|
|
||||||
|
|
||||||
### Query Performance of LEANN
|
|
||||||
|
|
||||||
| Backend | Index Size | Query Time | Recall@3 |
|
|
||||||
| ------------------- | ---------- | ---------- | --------- |
|
|
||||||
| DiskANN | 1M docs | xms | 0.95 |
|
|
||||||
| HNSW | 1M docs | xms | 0.95 | -->
|
|
||||||
|
|
||||||
*Benchmarks run on Apple M3 Pro 36 GB*
|
|
||||||
|
|
||||||
## Reproduce Our Results
|
## Reproduce Our Results
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
@@ -453,98 +439,15 @@ If you find Leann useful, please cite:
|
|||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
## ✨ Features
|
## ✨ [Detailed Features →](docs/features.md)
|
||||||
|
|
||||||
### 🔥 Core Features
|
## 🤝 [Contributing →](docs/contributing.md)
|
||||||
|
|
||||||
- **🔄 Real-time Embeddings** - Eliminate heavy embedding storage with dynamic computation using optimized ZMQ servers and highly optimized search paradigm (overlapping and batching) with highly optimized embedding engine
|
|
||||||
- **📈 Scalable Architecture** - Handles millions of documents on consumer hardware; the larger your dataset, the more LEANN can save
|
|
||||||
- **🎯 Graph Pruning** - Advanced techniques to minimize the storage overhead of vector search to a limited footprint
|
|
||||||
- **🏗️ Pluggable Backends** - DiskANN, HNSW/FAISS with unified API
|
|
||||||
|
|
||||||
### 🛠️ Technical Highlights
|
|
||||||
- **🔄 Recompute Mode** - Highest accuracy scenarios while eliminating vector storage overhead
|
|
||||||
- **⚡ Zero-copy Operations** - Minimize IPC overhead by transferring distances instead of embeddings
|
|
||||||
- **🚀 High-throughput Embedding Pipeline** - Optimized batched processing for maximum efficiency
|
|
||||||
- **🎯 Two-level Search** - Novel coarse-to-fine search overlap for accelerated query processing (optional)
|
|
||||||
- **💾 Memory-mapped Indices** - Fast startup with raw text mapping to reduce memory overhead
|
|
||||||
- **🚀 MLX Support** - Ultra-fast recompute/build with quantized embedding models, accelerating building and search ([minimal example](test/build_mlx_index.py))
|
|
||||||
|
|
||||||
### 🎨 Developer Experience
|
|
||||||
|
|
||||||
- **Simple Python API** - Get started in minutes
|
|
||||||
- **Extensible backend system** - Easy to add new algorithms
|
|
||||||
- **Comprehensive examples** - From basic usage to production deployment
|
|
||||||
|
|
||||||
## 🤝 Contributing
|
|
||||||
|
|
||||||
We welcome contributions! Leann is built by the community, for the community.
|
|
||||||
|
|
||||||
### Ways to Contribute
|
|
||||||
|
|
||||||
- 🐛 **Bug Reports**: Found an issue? Let us know!
|
|
||||||
- 💡 **Feature Requests**: Have an idea? We'd love to hear it!
|
|
||||||
- 🔧 **Code Contributions**: PRs welcome for all skill levels
|
|
||||||
- 📖 **Documentation**: Help make Leann more accessible
|
|
||||||
- 🧪 **Benchmarks**: Share your performance results
|
|
||||||
|
|
||||||
|
|
||||||
<!-- ## ❓ FAQ
|
## [FAQ →](docs/faq.md)
|
||||||
|
|
||||||
### Common Issues
|
|
||||||
|
|
||||||
#### NCCL Topology Error
|
|
||||||
|
|
||||||
**Problem**: You encounter `ncclTopoComputePaths` error during document processing:
|
|
||||||
|
|
||||||
```
|
|
||||||
ncclTopoComputePaths (system=<optimized out>, comm=comm@entry=0x5555a82fa3c0) at graph/paths.cc:688
|
|
||||||
```
|
|
||||||
|
|
||||||
**Solution**: Set these environment variables before running your script:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
export NCCL_TOPO_DUMP_FILE=/tmp/nccl_topo.xml
|
|
||||||
export NCCL_DEBUG=INFO
|
|
||||||
export NCCL_DEBUG_SUBSYS=INIT,GRAPH
|
|
||||||
export NCCL_IB_DISABLE=1
|
|
||||||
export NCCL_NET_PLUGIN=none
|
|
||||||
export NCCL_SOCKET_IFNAME=ens5
|
|
||||||
``` -->
|
|
||||||
## FAQ
|
|
||||||
|
|
||||||
### 1. My building time seems long
|
|
||||||
|
|
||||||
You can speed up the process by using a lightweight embedding model. Add this to your arguments:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
--embedding-model sentence-transformers/all-MiniLM-L6-v2
|
|
||||||
```
|
|
||||||
**Model sizes:** `all-MiniLM-L6-v2` (30M parameters), `facebook/contriever` (~100M parameters), `Qwen3-0.6B` (600M parameters)
|
|
||||||
|
|
||||||
|
|
||||||
## 📈 Roadmap
|
## 📈 [Roadmap →](docs/roadmap.md)
|
||||||
|
|
||||||
### 🎯 Q2 2025
|
|
||||||
|
|
||||||
- [X] DiskANN backend with MIPS/L2/Cosine support
|
|
||||||
- [X] HNSW backend integration
|
|
||||||
- [X] Real-time embedding pipeline
|
|
||||||
- [X] Memory-efficient graph pruning
|
|
||||||
|
|
||||||
### 🚀 Q3 2025
|
|
||||||
|
|
||||||
|
|
||||||
- [ ] Advanced caching strategies
|
|
||||||
- [ ] Add contextual-retrieval https://www.anthropic.com/news/contextual-retrieval
|
|
||||||
- [ ] Add sleep-time-compute and summarize agent! to summarilze the file on computer!
|
|
||||||
- [ ] Add OpenAI recompute API
|
|
||||||
|
|
||||||
### 🌟 Q4 2025
|
|
||||||
|
|
||||||
- [ ] Integration with LangChain/LlamaIndex
|
|
||||||
- [ ] Visual similarity search
|
|
||||||
- [ ] Query rewrtiting, rerank and expansion
|
|
||||||
|
|
||||||
## 📄 License
|
## 📄 License
|
||||||
|
|
||||||
@@ -552,11 +455,7 @@ MIT License - see [LICENSE](LICENSE) for details.
|
|||||||
|
|
||||||
## 🙏 Acknowledgments
|
## 🙏 Acknowledgments
|
||||||
|
|
||||||
- **Microsoft Research** for the DiskANN algorithm
|
This work is done at [**Berkeley Sky Computing Lab**](https://sky.cs.berkeley.edu/)
|
||||||
- **Meta AI** for FAISS and optimization insights
|
|
||||||
- **HuggingFace** for the transformer ecosystem
|
|
||||||
- **Our amazing contributors** who make this possible
|
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
<p align="center">
|
<p align="center">
|
||||||
|
|||||||
306
demo.ipynb
306
demo.ipynb
@@ -1,37 +1,321 @@
|
|||||||
{
|
{
|
||||||
"cells": [
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"# Quick Start in 30s"
|
||||||
|
]
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"from leann.api import LeannBuilder, LeannSearcher, LeannChat\n",
|
"# install this if you areusing colab\n",
|
||||||
|
"! pip install leann"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Build the index"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 1,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"INFO: Registering backend 'hnsw'\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"/Users/yichuan/Desktop/code/LEANN/leann/.venv/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
|
||||||
|
" from .autonotebook import tqdm as notebook_tqdm\n",
|
||||||
|
"INFO:sentence_transformers.SentenceTransformer:Load pretrained SentenceTransformer: facebook/contriever\n",
|
||||||
|
"WARNING:sentence_transformers.SentenceTransformer:No sentence-transformers model found with name facebook/contriever. Creating a new one with mean pooling.\n",
|
||||||
|
"Writing passages: 100%|██████████| 5/5 [00:00<00:00, 27887.66chunk/s]\n",
|
||||||
|
"Batches: 100%|██████████| 1/1 [00:00<00:00, 13.51it/s]\n",
|
||||||
|
"WARNING:leann_backend_hnsw.hnsw_backend:Converting data to float32, shape: (5, 768)\n",
|
||||||
|
"INFO:leann_backend_hnsw.hnsw_backend:INFO: Converting HNSW index to CSR-pruned format...\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"M: 64 for level: 0\n",
|
||||||
|
"Starting conversion: knowledge.index -> knowledge.csr.tmp\n",
|
||||||
|
"[0.00s] Reading Index HNSW header...\n",
|
||||||
|
"[0.00s] Header read: d=768, ntotal=5\n",
|
||||||
|
"[0.00s] Reading HNSW struct vectors...\n",
|
||||||
|
" Reading vector (dtype=<class 'numpy.float64'>, fmt='d')... Count=6, Bytes=48\n",
|
||||||
|
"[0.00s] Read assign_probas (6)\n",
|
||||||
|
" Reading vector (dtype=<class 'numpy.int32'>, fmt='i')... Count=7, Bytes=28\n",
|
||||||
|
"[0.11s] Read cum_nneighbor_per_level (7)\n",
|
||||||
|
" Reading vector (dtype=<class 'numpy.int32'>, fmt='i')... Count=5, Bytes=20\n",
|
||||||
|
"[0.21s] Read levels (5)\n",
|
||||||
|
"[0.30s] Probing for compact storage flag...\n",
|
||||||
|
"[0.30s] Found compact flag: False\n",
|
||||||
|
"[0.30s] Compact flag is False, reading original format...\n",
|
||||||
|
"[0.30s] Probing for potential extra byte before non-compact offsets...\n",
|
||||||
|
"[0.30s] Found and consumed an unexpected 0x00 byte.\n",
|
||||||
|
" Reading vector (dtype=<class 'numpy.uint64'>, fmt='Q')... Count=6, Bytes=48\n",
|
||||||
|
"[0.30s] Read offsets (6)\n",
|
||||||
|
"[0.40s] Attempting to read neighbors vector...\n",
|
||||||
|
" Reading vector (dtype=<class 'numpy.int32'>, fmt='i')... Count=320, Bytes=1280\n",
|
||||||
|
"[0.40s] Read neighbors (320)\n",
|
||||||
|
"[0.50s] Read scalar params (ep=4, max_lvl=0)\n",
|
||||||
|
"[0.50s] Checking for storage data...\n",
|
||||||
|
"[0.50s] Found storage fourcc: 49467849.\n",
|
||||||
|
"[0.50s] Converting to CSR format...\n",
|
||||||
|
"[0.50s] Conversion loop finished. \n",
|
||||||
|
"[0.50s] Running validation checks...\n",
|
||||||
|
" Checking total valid neighbor count...\n",
|
||||||
|
" OK: Total valid neighbors = 20\n",
|
||||||
|
" Checking final pointer indices...\n",
|
||||||
|
" OK: Final pointers match data size.\n",
|
||||||
|
"[0.50s] Deleting original neighbors and offsets arrays...\n",
|
||||||
|
" CSR Stats: |data|=20, |level_ptr|=10\n",
|
||||||
|
"[0.59s] Writing CSR HNSW graph data in FAISS-compatible order...\n",
|
||||||
|
" Pruning embeddings: Writing NULL storage marker.\n",
|
||||||
|
"[0.69s] Conversion complete.\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"INFO:leann_backend_hnsw.hnsw_backend:✅ CSR conversion successful.\n",
|
||||||
|
"INFO:leann_backend_hnsw.hnsw_backend:INFO: Replaced original index with CSR-pruned version at 'knowledge.index'\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"from leann.api import LeannBuilder\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# 1. Build the index (no embeddings stored!)\n",
|
|
||||||
"builder = LeannBuilder(backend_name=\"hnsw\")\n",
|
"builder = LeannBuilder(backend_name=\"hnsw\")\n",
|
||||||
"builder.add_text(\"C# is a powerful programming language\")\n",
|
"builder.add_text(\"C# is a powerful programming language and it is good at game development\")\n",
|
||||||
"builder.add_text(\"Python is a powerful programming language and it is very popular\")\n",
|
"builder.add_text(\"Python is a powerful programming language and it is good at machine learning tasks\")\n",
|
||||||
"builder.add_text(\"Machine learning transforms industries\")\n",
|
"builder.add_text(\"Machine learning transforms industries\")\n",
|
||||||
"builder.add_text(\"Neural networks process complex data\")\n",
|
"builder.add_text(\"Neural networks process complex data\")\n",
|
||||||
"builder.add_text(\"Leann is a great storage saving engine for RAG on your MacBook\")\n",
|
"builder.add_text(\"Leann is a great storage saving engine for RAG on your MacBook\")\n",
|
||||||
"builder.build_index(\"knowledge.leann\")\n",
|
"builder.build_index(\"knowledge.leann\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Search with real-time embeddings"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 2,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"INFO:leann.api:🔍 LeannSearcher.search() called:\n",
|
||||||
|
"INFO:leann.api: Query: 'programming languages'\n",
|
||||||
|
"INFO:leann.api: Top_k: 2\n",
|
||||||
|
"INFO:leann.api: Additional kwargs: {}\n",
|
||||||
|
"INFO:leann.embedding_server_manager:Port 5557 has incompatible server, trying next port...\n",
|
||||||
|
"INFO:leann.embedding_server_manager:Port 5558 has incompatible server, trying next port...\n",
|
||||||
|
"INFO:leann.embedding_server_manager:Port 5559 has incompatible server, trying next port...\n",
|
||||||
|
"INFO:leann.embedding_server_manager:Using port 5560 instead of 5557\n",
|
||||||
|
"INFO:leann.embedding_server_manager:Starting embedding server on port 5560...\n",
|
||||||
|
"INFO:leann.embedding_server_manager:Command: /Users/yichuan/Desktop/code/LEANN/leann/.venv/bin/python -m leann_backend_hnsw.hnsw_embedding_server --zmq-port 5560 --model-name facebook/contriever --passages-file knowledge.leann.meta.json\n",
|
||||||
|
"huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
|
||||||
|
"To disable this warning, you can either:\n",
|
||||||
|
"\t- Avoid using `tokenizers` before the fork if possible\n",
|
||||||
|
"\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
|
||||||
|
"INFO:leann.embedding_server_manager:Server process started with PID: 4574\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"[read_HNSW - CSR NL v4] Reading metadata & CSR indices (manual offset)...\n",
|
||||||
|
"[read_HNSW NL v4] Read levels vector, size: 5\n",
|
||||||
|
"[read_HNSW NL v4] Reading Compact Storage format indices...\n",
|
||||||
|
"[read_HNSW NL v4] Read compact_level_ptr, size: 10\n",
|
||||||
|
"[read_HNSW NL v4] Read compact_node_offsets, size: 6\n",
|
||||||
|
"[read_HNSW NL v4] Read entry_point: 4, max_level: 0\n",
|
||||||
|
"[read_HNSW NL v4] Read storage fourcc: 0x6c6c756e\n",
|
||||||
|
"[read_HNSW NL v4 FIX] Detected FileIOReader. Neighbors size field offset: 326\n",
|
||||||
|
"[read_HNSW NL v4] Reading neighbors data into memory.\n",
|
||||||
|
"[read_HNSW NL v4] Read neighbors data, size: 20\n",
|
||||||
|
"[read_HNSW NL v4] Finished reading metadata and CSR indices.\n",
|
||||||
|
"INFO: Skipping external storage loading, since is_recompute is true.\n",
|
||||||
|
"INFO: Registering backend 'hnsw'\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"INFO:leann.embedding_server_manager:Embedding server is ready!\n",
|
||||||
|
"INFO:leann.api: Launching server time: 1.078078269958496 seconds\n",
|
||||||
|
"INFO:leann.embedding_server_manager:Existing server process (PID 4574) is compatible\n",
|
||||||
|
"INFO:sentence_transformers.SentenceTransformer:Load pretrained SentenceTransformer: facebook/contriever\n",
|
||||||
|
"WARNING:sentence_transformers.SentenceTransformer:No sentence-transformers model found with name facebook/contriever. Creating a new one with mean pooling.\n",
|
||||||
|
"INFO:leann.api: Generated embedding shape: (1, 768)\n",
|
||||||
|
"INFO:leann.api: Embedding time: 2.9307072162628174 seconds\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"ZmqDistanceComputer initialized: d=768, metric=0\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"INFO:leann.api: Search time: 0.27327895164489746 seconds\n",
|
||||||
|
"INFO:leann.api: Backend returned: labels=2 results\n",
|
||||||
|
"INFO:leann.api: Processing 2 passage IDs:\n",
|
||||||
|
"INFO:leann.api: 1. passage_id='0' -> SUCCESS: C# is a powerful programming language and it is good at game development...\n",
|
||||||
|
"INFO:leann.api: 2. passage_id='1' -> SUCCESS: Python is a powerful programming language and it is good at machine learning tasks...\n",
|
||||||
|
"INFO:leann.api: Final enriched results: 2 passages\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"[SearchResult(id='0', score=np.float32(0.9874103), text='C# is a powerful programming language and it is good at game development', metadata={}),\n",
|
||||||
|
" SearchResult(id='1', score=np.float32(0.8922168), text='Python is a powerful programming language and it is good at machine learning tasks', metadata={})]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 2,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"from leann.api import LeannSearcher\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# 2. Search with real-time embeddings\n",
|
|
||||||
"searcher = LeannSearcher(\"knowledge.leann\")\n",
|
"searcher = LeannSearcher(\"knowledge.leann\")\n",
|
||||||
"results = searcher.search(\"programming languages\", top_k=2)\n",
|
"results = searcher.search(\"programming languages\", top_k=2)\n",
|
||||||
|
"results"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Chat with LEANN using retrieved results"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"INFO:leann.chat:Attempting to create LLM of type='hf' with model='Qwen/Qwen3-0.6B'\n",
|
||||||
|
"INFO:leann.chat:Initializing HFChat with model='Qwen/Qwen3-0.6B'\n",
|
||||||
|
"INFO:leann.chat:MPS is available. Using Apple Silicon GPU.\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"[read_HNSW - CSR NL v4] Reading metadata & CSR indices (manual offset)...\n",
|
||||||
|
"[read_HNSW NL v4] Read levels vector, size: 5\n",
|
||||||
|
"[read_HNSW NL v4] Reading Compact Storage format indices...\n",
|
||||||
|
"[read_HNSW NL v4] Read compact_level_ptr, size: 10\n",
|
||||||
|
"[read_HNSW NL v4] Read compact_node_offsets, size: 6\n",
|
||||||
|
"[read_HNSW NL v4] Read entry_point: 4, max_level: 0\n",
|
||||||
|
"[read_HNSW NL v4] Read storage fourcc: 0x6c6c756e\n",
|
||||||
|
"[read_HNSW NL v4 FIX] Detected FileIOReader. Neighbors size field offset: 326\n",
|
||||||
|
"[read_HNSW NL v4] Reading neighbors data into memory.\n",
|
||||||
|
"[read_HNSW NL v4] Read neighbors data, size: 20\n",
|
||||||
|
"[read_HNSW NL v4] Finished reading metadata and CSR indices.\n",
|
||||||
|
"INFO: Skipping external storage loading, since is_recompute is true.\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"INFO:leann.api:🔍 LeannSearcher.search() called:\n",
|
||||||
|
"INFO:leann.api: Query: 'Compare the two retrieved programming languages and tell me their advantages.'\n",
|
||||||
|
"INFO:leann.api: Top_k: 2\n",
|
||||||
|
"INFO:leann.api: Additional kwargs: {}\n",
|
||||||
|
"INFO:leann.embedding_server_manager:Port 5557 has incompatible server, trying next port...\n",
|
||||||
|
"INFO:leann.embedding_server_manager:Port 5558 has incompatible server, trying next port...\n",
|
||||||
|
"INFO:leann.embedding_server_manager:Port 5559 has incompatible server, trying next port...\n",
|
||||||
|
"INFO:leann.embedding_server_manager:Found compatible server on port 5560\n",
|
||||||
|
"INFO:leann.embedding_server_manager:Using existing compatible server on port 5560\n",
|
||||||
|
"INFO:leann.api: Launching server time: 0.04932403564453125 seconds\n",
|
||||||
|
"INFO:leann.embedding_server_manager:Found compatible server on port 5560\n",
|
||||||
|
"INFO:leann.embedding_server_manager:Using existing compatible server on port 5560\n",
|
||||||
|
"INFO:leann.api: Generated embedding shape: (1, 768)\n",
|
||||||
|
"INFO:leann.api: Embedding time: 0.06902289390563965 seconds\n",
|
||||||
|
"INFO:leann.api: Search time: 0.026793241500854492 seconds\n",
|
||||||
|
"INFO:leann.api: Backend returned: labels=2 results\n",
|
||||||
|
"INFO:leann.api: Processing 2 passage IDs:\n",
|
||||||
|
"INFO:leann.api: 1. passage_id='0' -> SUCCESS: C# is a powerful programming language and it is good at game development...\n",
|
||||||
|
"INFO:leann.api: 2. passage_id='1' -> SUCCESS: Python is a powerful programming language and it is good at machine learning tasks...\n",
|
||||||
|
"INFO:leann.api: Final enriched results: 2 passages\n",
|
||||||
|
"INFO:leann.chat:Generating with HuggingFace model, config: {'max_new_tokens': 128, 'temperature': 0.7, 'top_p': 0.9, 'do_sample': True, 'pad_token_id': 151645, 'eos_token_id': 151645}\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"ZmqDistanceComputer initialized: d=768, metric=0\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"\"<think>\\n\\n</think>\\n\\nBased on the context provided, here's a comparison of the two retrieved programming languages:\\n\\n**C#** is known for being a powerful programming language and is well-suited for game development. It is often used in game development and is popular among developers working on Windows applications.\\n\\n**Python**, on the other hand, is also a powerful language and is well-suited for machine learning tasks. It is widely used for data analysis, scientific computing, and other applications that require handling large datasets or performing complex calculations.\\n\\n**Advantages**:\\n- C#: Strong for game development and cross-platform compatibility.\\n- Python: Strong for\""
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 8,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"from leann.api import LeannChat\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# 3. Chat with LEANN using retrieved results\n",
|
|
||||||
"llm_config = {\n",
|
"llm_config = {\n",
|
||||||
" \"type\": \"ollama\",\n",
|
" \"type\": \"hf\",\n",
|
||||||
" \"model\": \"llama3.2:1b\"\n",
|
" \"model\": \"Qwen/Qwen3-0.6B\",\n",
|
||||||
"}\n",
|
"}\n",
|
||||||
"\n",
|
"\n",
|
||||||
"chat = LeannChat(index_path=\"knowledge.leann\", llm_config=llm_config)\n",
|
"chat = LeannChat(index_path=\"knowledge.leann\", llm_config=llm_config)\n",
|
||||||
"response = chat.ask(\n",
|
"response = chat.ask(\n",
|
||||||
" \"Compare the two retrieved programming languages and say which one is more popular today.\",\n",
|
" \"Compare the two retrieved programming languages and tell me their advantages.\",\n",
|
||||||
" top_k=2,\n",
|
" top_k=2,\n",
|
||||||
")"
|
" llm_kwargs={\"max_tokens\": 128}\n",
|
||||||
|
")\n",
|
||||||
|
"response"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
|
|||||||
22
docs/RELEASE.md
Normal file
22
docs/RELEASE.md
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
# Release Guide
|
||||||
|
|
||||||
|
## Setup (One-time)
|
||||||
|
|
||||||
|
Add `PYPI_API_TOKEN` to GitHub Secrets:
|
||||||
|
1. Get token: https://pypi.org/manage/account/token/
|
||||||
|
2. Add to secrets: Settings → Secrets → Actions → `PYPI_API_TOKEN`
|
||||||
|
|
||||||
|
## Release (One-click)
|
||||||
|
|
||||||
|
1. Go to: https://github.com/yichuan-w/LEANN/actions/workflows/release-manual.yml
|
||||||
|
2. Click "Run workflow"
|
||||||
|
3. Enter version: `0.1.2`
|
||||||
|
4. Click green "Run workflow" button
|
||||||
|
|
||||||
|
That's it! The workflow will automatically:
|
||||||
|
- ✅ Update version in all packages
|
||||||
|
- ✅ Build all packages
|
||||||
|
- ✅ Publish to PyPI
|
||||||
|
- ✅ Create GitHub tag and release
|
||||||
|
|
||||||
|
Check progress: https://github.com/yichuan-w/LEANN/actions
|
||||||
11
docs/contributing.md
Normal file
11
docs/contributing.md
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
# 🤝 Contributing
|
||||||
|
|
||||||
|
We welcome contributions! Leann is built by the community, for the community.
|
||||||
|
|
||||||
|
## Ways to Contribute
|
||||||
|
|
||||||
|
- 🐛 **Bug Reports**: Found an issue? Let us know!
|
||||||
|
- 💡 **Feature Requests**: Have an idea? We'd love to hear it!
|
||||||
|
- 🔧 **Code Contributions**: PRs welcome for all skill levels
|
||||||
|
- 📖 **Documentation**: Help make Leann more accessible
|
||||||
|
- 🧪 **Benchmarks**: Share your performance results
|
||||||
10
docs/faq.md
Normal file
10
docs/faq.md
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
# FAQ
|
||||||
|
|
||||||
|
## 1. My building time seems long
|
||||||
|
|
||||||
|
You can speed up the process by using a lightweight embedding model. Add this to your arguments:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
--embedding-model sentence-transformers/all-MiniLM-L6-v2
|
||||||
|
```
|
||||||
|
**Model sizes:** `all-MiniLM-L6-v2` (30M parameters), `facebook/contriever` (~100M parameters), `Qwen3-0.6B` (600M parameters)
|
||||||
22
docs/features.md
Normal file
22
docs/features.md
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
# ✨ Detailed Features
|
||||||
|
|
||||||
|
## 🔥 Core Features
|
||||||
|
|
||||||
|
- **🔄 Real-time Embeddings** - Eliminate heavy embedding storage with dynamic computation using optimized ZMQ servers and highly optimized search paradigm (overlapping and batching) with highly optimized embedding engine
|
||||||
|
- **📈 Scalable Architecture** - Handles millions of documents on consumer hardware; the larger your dataset, the more LEANN can save
|
||||||
|
- **🎯 Graph Pruning** - Advanced techniques to minimize the storage overhead of vector search to a limited footprint
|
||||||
|
- **🏗️ Pluggable Backends** - DiskANN, HNSW/FAISS with unified API
|
||||||
|
|
||||||
|
## 🛠️ Technical Highlights
|
||||||
|
- **🔄 Recompute Mode** - Highest accuracy scenarios while eliminating vector storage overhead
|
||||||
|
- **⚡ Zero-copy Operations** - Minimize IPC overhead by transferring distances instead of embeddings
|
||||||
|
- **🚀 High-throughput Embedding Pipeline** - Optimized batched processing for maximum efficiency
|
||||||
|
- **🎯 Two-level Search** - Novel coarse-to-fine search overlap for accelerated query processing (optional)
|
||||||
|
- **💾 Memory-mapped Indices** - Fast startup with raw text mapping to reduce memory overhead
|
||||||
|
- **🚀 MLX Support** - Ultra-fast recompute/build with quantized embedding models, accelerating building and search ([minimal example](test/build_mlx_index.py))
|
||||||
|
|
||||||
|
## 🎨 Developer Experience
|
||||||
|
|
||||||
|
- **Simple Python API** - Get started in minutes
|
||||||
|
- **Extensible backend system** - Easy to add new algorithms
|
||||||
|
- **Comprehensive examples** - From basic usage to production deployment
|
||||||
21
docs/roadmap.md
Normal file
21
docs/roadmap.md
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
# 📈 Roadmap
|
||||||
|
|
||||||
|
## 🎯 Q2 2025
|
||||||
|
|
||||||
|
- [X] DiskANN backend with MIPS/L2/Cosine support
|
||||||
|
- [X] HNSW backend integration
|
||||||
|
- [X] Real-time embedding pipeline
|
||||||
|
- [X] Memory-efficient graph pruning
|
||||||
|
|
||||||
|
## 🚀 Q3 2025
|
||||||
|
|
||||||
|
- [ ] Advanced caching strategies
|
||||||
|
- [ ] Add contextual-retrieval https://www.anthropic.com/news/contextual-retrieval
|
||||||
|
- [ ] Add sleep-time-compute and summarize agent! to summarilze the file on computer!
|
||||||
|
- [ ] Add OpenAI recompute API
|
||||||
|
|
||||||
|
## 🌟 Q4 2025
|
||||||
|
|
||||||
|
- [ ] Integration with LangChain/LlamaIndex
|
||||||
|
- [ ] Visual similarity search
|
||||||
|
- [ ] Query rewrtiting, rerank and expansion
|
||||||
@@ -135,6 +135,7 @@ def test_leann_hnsw():
|
|||||||
nodes = node_parser.get_nodes_from_documents([doc])
|
nodes = node_parser.get_nodes_from_documents([doc])
|
||||||
for node in nodes:
|
for node in nodes:
|
||||||
all_texts.append(node.get_content())
|
all_texts.append(node.get_content())
|
||||||
|
print(f"Total number of chunks: {len(all_texts)}")
|
||||||
|
|
||||||
tracker.checkpoint("After text chunking")
|
tracker.checkpoint("After text chunking")
|
||||||
|
|
||||||
|
|||||||
@@ -37,7 +37,7 @@ def main():
|
|||||||
import faiss
|
import faiss
|
||||||
except ImportError:
|
except ImportError:
|
||||||
print("Faiss is not installed.")
|
print("Faiss is not installed.")
|
||||||
print("Please install it with `uv pip install faiss-cpu`")
|
print("Please install it with `uv pip install faiss-cpu` and you can then run this script again")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
from llama_index.core import (
|
from llama_index.core import (
|
||||||
|
|||||||
@@ -222,14 +222,15 @@ async def query_leann_index(index_path: str, query: str):
|
|||||||
"max_tokens": 1000
|
"max_tokens": 1000
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
print(f"Leann: {chat_response}")
|
|
||||||
|
print(f"Leann chat response: \033[36m{chat_response}\033[0m")
|
||||||
|
|
||||||
async def main():
|
async def main():
|
||||||
# Parse command line arguments
|
# Parse command line arguments
|
||||||
parser = argparse.ArgumentParser(description='LEANN Chrome History Reader - Create and query browser history index')
|
parser = argparse.ArgumentParser(description='LEANN Chrome History Reader - Create and query browser history index')
|
||||||
parser.add_argument('--chrome-profile', type=str, default=DEFAULT_CHROME_PROFILE,
|
parser.add_argument('--chrome-profile', type=str, default=DEFAULT_CHROME_PROFILE,
|
||||||
help=f'Path to Chrome profile directory (default: {DEFAULT_CHROME_PROFILE}), usually you dont need to change this')
|
help=f'Path to Chrome profile directory (default: {DEFAULT_CHROME_PROFILE}), usually you dont need to change this')
|
||||||
parser.add_argument('--index-dir', type=str, default="./all_google_new",
|
parser.add_argument('--index-dir', type=str, default="./google_history_index",
|
||||||
help='Directory to store the LEANN index (default: ./chrome_history_index_leann_test)')
|
help='Directory to store the LEANN index (default: ./chrome_history_index_leann_test)')
|
||||||
parser.add_argument('--max-entries', type=int, default=1000,
|
parser.add_argument('--max-entries', type=int, default=1000,
|
||||||
help='Maximum number of history entries to process (default: 1000)')
|
help='Maximum number of history entries to process (default: 1000)')
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ def get_mail_path():
|
|||||||
return os.path.join(home_dir, "Library", "Mail")
|
return os.path.join(home_dir, "Library", "Mail")
|
||||||
|
|
||||||
# Default mail path for macOS
|
# Default mail path for macOS
|
||||||
# DEFAULT_MAIL_PATH = "/Users/yichuan/Library/Mail/V10/0FCA0879-FD8C-4B7E-83BF-FDDA930791C5/[Gmail].mbox/All Mail.mbox/78BA5BE1-8819-4F9A-9613-EB63772F1DD0/Data"
|
DEFAULT_MAIL_PATH = "/Users/yichuan/Library/Mail/V10/0FCA0879-FD8C-4B7E-83BF-FDDA930791C5/[Gmail].mbox/All Mail.mbox/78BA5BE1-8819-4F9A-9613-EB63772F1DD0/Data"
|
||||||
|
|
||||||
def create_leann_index_from_multiple_sources(messages_dirs: List[Path], index_path: str = "mail_index.leann", max_count: int = -1, include_html: bool = False, embedding_model: str = "facebook/contriever"):
|
def create_leann_index_from_multiple_sources(messages_dirs: List[Path], index_path: str = "mail_index.leann", max_count: int = -1, include_html: bool = False, embedding_model: str = "facebook/contriever"):
|
||||||
"""
|
"""
|
||||||
@@ -77,7 +77,7 @@ def create_leann_index_from_multiple_sources(messages_dirs: List[Path], index_pa
|
|||||||
print(f"\nTotal loaded {len(all_documents)} email documents from {len(messages_dirs)} directories and starting to split them into chunks")
|
print(f"\nTotal loaded {len(all_documents)} email documents from {len(messages_dirs)} directories and starting to split them into chunks")
|
||||||
|
|
||||||
# Create text splitter with 256 chunk size
|
# Create text splitter with 256 chunk size
|
||||||
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=128)
|
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=25)
|
||||||
|
|
||||||
# Convert Documents to text strings and chunk them
|
# Convert Documents to text strings and chunk them
|
||||||
all_texts = []
|
all_texts = []
|
||||||
@@ -158,7 +158,7 @@ def create_leann_index(mail_path: str, index_path: str = "mail_index.leann", max
|
|||||||
print(f"Loaded {len(documents)} email documents")
|
print(f"Loaded {len(documents)} email documents")
|
||||||
|
|
||||||
# Create text splitter with 256 chunk size
|
# Create text splitter with 256 chunk size
|
||||||
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=25)
|
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=128)
|
||||||
|
|
||||||
# Convert Documents to text strings and chunk them
|
# Convert Documents to text strings and chunk them
|
||||||
all_texts = []
|
all_texts = []
|
||||||
@@ -218,22 +218,22 @@ async def query_leann_index(index_path: str, query: str):
|
|||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
chat_response = chat.ask(
|
chat_response = chat.ask(
|
||||||
query,
|
query,
|
||||||
top_k=10,
|
top_k=20,
|
||||||
recompute_beighbor_embeddings=True,
|
recompute_beighbor_embeddings=True,
|
||||||
complexity=12,
|
complexity=32,
|
||||||
beam_width=1,
|
beam_width=1,
|
||||||
|
|
||||||
)
|
)
|
||||||
end_time = time.time()
|
end_time = time.time()
|
||||||
print(f"Time taken: {end_time - start_time} seconds")
|
# print(f"Time taken: {end_time - start_time} seconds")
|
||||||
print(f"Leann: {chat_response}")
|
# highlight the answer
|
||||||
|
print(f"Leann chat response: \033[36m{chat_response}\033[0m")
|
||||||
|
|
||||||
async def main():
|
async def main():
|
||||||
# Parse command line arguments
|
# Parse command line arguments
|
||||||
parser = argparse.ArgumentParser(description='LEANN Mail Reader - Create and query email index')
|
parser = argparse.ArgumentParser(description='LEANN Mail Reader - Create and query email index')
|
||||||
# Remove --mail-path argument and auto-detect all Messages directories
|
# Remove --mail-path argument and auto-detect all Messages directories
|
||||||
# Remove DEFAULT_MAIL_PATH
|
# Remove DEFAULT_MAIL_PATH
|
||||||
parser.add_argument('--index-dir', type=str, default="./mail_index_leann_debug",
|
parser.add_argument('--index-dir', type=str, default="./mail_index",
|
||||||
help='Directory to store the LEANN index (default: ./mail_index_leann_raw_text_all_dicts)')
|
help='Directory to store the LEANN index (default: ./mail_index_leann_raw_text_all_dicts)')
|
||||||
parser.add_argument('--max-emails', type=int, default=1000,
|
parser.add_argument('--max-emails', type=int, default=1000,
|
||||||
help='Maximum number of emails to process (-1 means all)')
|
help='Maximum number of emails to process (-1 means all)')
|
||||||
@@ -253,6 +253,9 @@ async def main():
|
|||||||
mail_path = get_mail_path()
|
mail_path = get_mail_path()
|
||||||
print(f"Searching for email data in: {mail_path}")
|
print(f"Searching for email data in: {mail_path}")
|
||||||
messages_dirs = find_all_messages_directories(mail_path)
|
messages_dirs = find_all_messages_directories(mail_path)
|
||||||
|
# messages_dirs = find_all_messages_directories(DEFAULT_MAIL_PATH)
|
||||||
|
# messages_dirs = [DEFAULT_MAIL_PATH]
|
||||||
|
# messages_dirs = messages_dirs[:1]
|
||||||
|
|
||||||
print('len(messages_dirs): ', len(messages_dirs))
|
print('len(messages_dirs): ', len(messages_dirs))
|
||||||
|
|
||||||
|
|||||||
@@ -63,16 +63,14 @@ async def main(args):
|
|||||||
llm_config = {"type": "openai", "model": "gpt-4o"}
|
llm_config = {"type": "openai", "model": "gpt-4o"}
|
||||||
|
|
||||||
chat = LeannChat(index_path=INDEX_PATH, llm_config=llm_config)
|
chat = LeannChat(index_path=INDEX_PATH, llm_config=llm_config)
|
||||||
|
|
||||||
query = "Based on the paper, what are the main techniques LEANN explores to reduce the storage overhead and DLPM explore to achieve Fairness and Efiiciency trade-off?"
|
|
||||||
|
|
||||||
# query = (
|
# query = (
|
||||||
# "什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面,任务令一般在什么城市颁发"
|
# "什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面,任务令一般在什么城市颁发"
|
||||||
# )
|
# )
|
||||||
|
query = args.query
|
||||||
|
|
||||||
print(f"You: {query}")
|
print(f"You: {query}")
|
||||||
chat_response = chat.ask(query, top_k=20, recompute_embeddings=True, complexity=32)
|
chat_response = chat.ask(query, top_k=20, recompute_embeddings=True, complexity=32)
|
||||||
print(f"Leann: {chat_response}")
|
print(f"Leann chat response: \033[36m{chat_response}\033[0m")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
@@ -110,6 +108,12 @@ if __name__ == "__main__":
|
|||||||
default="examples/data",
|
default="examples/data",
|
||||||
help="Directory containing documents to index (PDF, TXT, MD files).",
|
help="Directory containing documents to index (PDF, TXT, MD files).",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--query",
|
||||||
|
type=str,
|
||||||
|
default="Based on the paper, what are the main techniques LEANN explores to reduce the storage overhead and DLPM explore to achieve Fairness and Efiiciency trade-off?",
|
||||||
|
help="The query to ask the Leann chat system.",
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
asyncio.run(main(args))
|
asyncio.run(main(args))
|
||||||
|
|||||||
@@ -78,7 +78,7 @@ def create_leann_index_from_multiple_wechat_exports(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Create text splitter with 256 chunk size
|
# Create text splitter with 256 chunk size
|
||||||
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=128)
|
text_splitter = SentenceSplitter(chunk_size=192, chunk_overlap=64)
|
||||||
|
|
||||||
# Convert Documents to text strings and chunk them
|
# Convert Documents to text strings and chunk them
|
||||||
all_texts = []
|
all_texts = []
|
||||||
@@ -234,7 +234,7 @@ async def query_leann_index(index_path: str, query: str):
|
|||||||
},
|
},
|
||||||
llm_kwargs={"temperature": 0.0, "max_tokens": 1000},
|
llm_kwargs={"temperature": 0.0, "max_tokens": 1000},
|
||||||
)
|
)
|
||||||
print(f"Leann: {chat_response}")
|
print(f"Leann chat response: \033[36m{chat_response}\033[0m")
|
||||||
|
|
||||||
|
|
||||||
async def main():
|
async def main():
|
||||||
|
|||||||
@@ -4,8 +4,8 @@ build-backend = "scikit_build_core.build"
|
|||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "leann-backend-diskann"
|
name = "leann-backend-diskann"
|
||||||
version = "0.1.0"
|
version = "0.1.9"
|
||||||
dependencies = ["leann-core==0.1.0", "numpy"]
|
dependencies = ["leann-core==0.1.9", "numpy"]
|
||||||
|
|
||||||
[tool.scikit-build]
|
[tool.scikit-build]
|
||||||
# Key: simplified CMake path
|
# Key: simplified CMake path
|
||||||
|
|||||||
@@ -92,8 +92,8 @@ class HNSWBuilder(LeannBackendBuilderInterface):
|
|||||||
|
|
||||||
if success:
|
if success:
|
||||||
logger.info("✅ CSR conversion successful.")
|
logger.info("✅ CSR conversion successful.")
|
||||||
index_file_old = index_file.with_suffix(".old")
|
# index_file_old = index_file.with_suffix(".old")
|
||||||
shutil.move(str(index_file), str(index_file_old))
|
# shutil.move(str(index_file), str(index_file_old))
|
||||||
shutil.move(str(csr_temp_file), str(index_file))
|
shutil.move(str(csr_temp_file), str(index_file))
|
||||||
logger.info(
|
logger.info(
|
||||||
f"INFO: Replaced original index with {mode_str} version at '{index_file}'"
|
f"INFO: Replaced original index with {mode_str} version at '{index_file}'"
|
||||||
|
|||||||
@@ -6,9 +6,14 @@ build-backend = "scikit_build_core.build"
|
|||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "leann-backend-hnsw"
|
name = "leann-backend-hnsw"
|
||||||
version = "0.1.0"
|
version = "0.1.9"
|
||||||
description = "Custom-built HNSW (Faiss) backend for the Leann toolkit."
|
description = "Custom-built HNSW (Faiss) backend for the Leann toolkit."
|
||||||
dependencies = ["leann-core==0.1.0", "numpy"]
|
dependencies = [
|
||||||
|
"leann-core==0.1.9",
|
||||||
|
"numpy",
|
||||||
|
"pyzmq>=23.0.0",
|
||||||
|
"msgpack>=1.0.0",
|
||||||
|
]
|
||||||
|
|
||||||
[tool.scikit-build]
|
[tool.scikit-build]
|
||||||
wheel.packages = ["leann_backend_hnsw"]
|
wheel.packages = ["leann_backend_hnsw"]
|
||||||
|
|||||||
@@ -4,15 +4,23 @@ build-backend = "setuptools.build_meta"
|
|||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "leann-core"
|
name = "leann-core"
|
||||||
version = "0.1.0"
|
version = "0.1.9"
|
||||||
description = "Core API and plugin system for Leann."
|
description = "Core API and plugin system for LEANN"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
requires-python = ">=3.9"
|
requires-python = ">=3.9"
|
||||||
license = { text = "MIT" }
|
license = { text = "MIT" }
|
||||||
|
|
||||||
|
# All required dependencies included
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"numpy>=1.20.0",
|
"numpy>=1.20.0",
|
||||||
"tqdm>=4.60.0"
|
"tqdm>=4.60.0",
|
||||||
|
"psutil>=5.8.0",
|
||||||
|
"pyzmq>=23.0.0",
|
||||||
|
"msgpack>=1.0.0",
|
||||||
|
"torch>=2.0.0",
|
||||||
|
"sentence-transformers>=2.2.0",
|
||||||
|
"llama-index-core>=0.12.0",
|
||||||
|
"python-dotenv>=1.0.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.scripts]
|
[project.scripts]
|
||||||
|
|||||||
@@ -142,7 +142,7 @@ class LeannBuilder:
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
backend_name: str,
|
backend_name: str,
|
||||||
embedding_model: str = "facebook/contriever-msmarco",
|
embedding_model: str = "facebook/contriever",
|
||||||
dimensions: Optional[int] = None,
|
dimensions: Optional[int] = None,
|
||||||
embedding_mode: str = "sentence-transformers",
|
embedding_mode: str = "sentence-transformers",
|
||||||
**backend_kwargs,
|
**backend_kwargs,
|
||||||
@@ -441,9 +441,9 @@ class LeannSearcher:
|
|||||||
use_server_if_available=recompute_embeddings,
|
use_server_if_available=recompute_embeddings,
|
||||||
zmq_port=zmq_port,
|
zmq_port=zmq_port,
|
||||||
)
|
)
|
||||||
logger.info(f" Generated embedding shape: {query_embedding.shape}")
|
# logger.info(f" Generated embedding shape: {query_embedding.shape}")
|
||||||
embedding_time = time.time() - start_time
|
embedding_time = time.time() - start_time
|
||||||
logger.info(f" Embedding time: {embedding_time} seconds")
|
# logger.info(f" Embedding time: {embedding_time} seconds")
|
||||||
|
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
results = self.backend_impl.search(
|
results = self.backend_impl.search(
|
||||||
@@ -458,7 +458,7 @@ class LeannSearcher:
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
search_time = time.time() - start_time
|
search_time = time.time() - start_time
|
||||||
logger.info(f" Search time: {search_time} seconds")
|
# logger.info(f" Search time: {search_time} seconds")
|
||||||
logger.info(
|
logger.info(
|
||||||
f" Backend returned: labels={len(results.get('labels', [[]])[0])} results"
|
f" Backend returned: labels={len(results.get('labels', [[]])[0])} results"
|
||||||
)
|
)
|
||||||
@@ -479,15 +479,25 @@ class LeannSearcher:
|
|||||||
metadata=passage_data.get("metadata", {}),
|
metadata=passage_data.get("metadata", {}),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Color codes for better logging
|
||||||
|
GREEN = "\033[92m"
|
||||||
|
BLUE = "\033[94m"
|
||||||
|
YELLOW = "\033[93m"
|
||||||
|
RESET = "\033[0m"
|
||||||
|
|
||||||
|
# Truncate text for display (first 100 chars)
|
||||||
|
display_text = passage_data['text']
|
||||||
logger.info(
|
logger.info(
|
||||||
f" {i + 1}. passage_id='{string_id}' -> SUCCESS: {passage_data['text']}..."
|
f" {GREEN}✓{RESET} {BLUE}[{i + 1:2d}]{RESET} {YELLOW}ID:{RESET} '{string_id}' {YELLOW}Score:{RESET} {dist:.4f} {YELLOW}Text:{RESET} {display_text}"
|
||||||
)
|
)
|
||||||
except KeyError:
|
except KeyError:
|
||||||
|
RED = "\033[91m"
|
||||||
logger.error(
|
logger.error(
|
||||||
f" {i + 1}. passage_id='{string_id}' -> ERROR: Passage not found in PassageManager!"
|
f" {RED}✗{RESET} [{i + 1:2d}] ID: '{string_id}' -> {RED}ERROR: Passage not found!{RESET}"
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f" Final enriched results: {len(enriched_results)} passages")
|
logger.info(f" {GREEN}✓ Final enriched results: {len(enriched_results)} passages{RESET}")
|
||||||
return enriched_results
|
return enriched_results
|
||||||
|
|
||||||
|
|
||||||
@@ -517,7 +527,7 @@ class LeannChat:
|
|||||||
):
|
):
|
||||||
if llm_kwargs is None:
|
if llm_kwargs is None:
|
||||||
llm_kwargs = {}
|
llm_kwargs = {}
|
||||||
|
search_time = time.time()
|
||||||
results = self.searcher.search(
|
results = self.searcher.search(
|
||||||
question,
|
question,
|
||||||
top_k=top_k,
|
top_k=top_k,
|
||||||
@@ -529,6 +539,8 @@ class LeannChat:
|
|||||||
expected_zmq_port=expected_zmq_port,
|
expected_zmq_port=expected_zmq_port,
|
||||||
**search_kwargs,
|
**search_kwargs,
|
||||||
)
|
)
|
||||||
|
search_time = time.time() - search_time
|
||||||
|
# logger.info(f" Search time: {search_time} seconds")
|
||||||
context = "\n\n".join([r.text for r in results])
|
context = "\n\n".join([r.text for r in results])
|
||||||
prompt = (
|
prompt = (
|
||||||
"Here is some retrieved context that might help answer your question:\n\n"
|
"Here is some retrieved context that might help answer your question:\n\n"
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ from typing import Dict, Any, Optional, List
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import difflib
|
import difflib
|
||||||
|
import torch
|
||||||
|
|
||||||
# Configure logging
|
# Configure logging
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
@@ -28,6 +29,68 @@ def check_ollama_models() -> List[str]:
|
|||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
def check_ollama_model_exists_remotely(model_name: str) -> tuple[bool, list[str]]:
|
||||||
|
"""Check if a model exists in Ollama's remote library and return available tags
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(model_exists, available_tags): bool and list of matching tags
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
import requests
|
||||||
|
import re
|
||||||
|
|
||||||
|
# Split model name and tag
|
||||||
|
if ':' in model_name:
|
||||||
|
base_model, requested_tag = model_name.split(':', 1)
|
||||||
|
else:
|
||||||
|
base_model, requested_tag = model_name, None
|
||||||
|
|
||||||
|
# First check if base model exists in library
|
||||||
|
library_response = requests.get("https://ollama.com/library", timeout=8)
|
||||||
|
if library_response.status_code != 200:
|
||||||
|
return True, [] # Assume exists if can't check
|
||||||
|
|
||||||
|
# Extract model names from library page
|
||||||
|
models_in_library = re.findall(r'href="/library/([^"]+)"', library_response.text)
|
||||||
|
|
||||||
|
if base_model not in models_in_library:
|
||||||
|
return False, [] # Base model doesn't exist
|
||||||
|
|
||||||
|
# If base model exists, get available tags
|
||||||
|
tags_response = requests.get(f"https://ollama.com/library/{base_model}/tags", timeout=8)
|
||||||
|
if tags_response.status_code != 200:
|
||||||
|
return True, [] # Base model exists but can't get tags
|
||||||
|
|
||||||
|
# Extract tags for this model - be more specific to avoid HTML artifacts
|
||||||
|
tag_pattern = rf'{re.escape(base_model)}:[a-zA-Z0-9\.\-_]+'
|
||||||
|
raw_tags = re.findall(tag_pattern, tags_response.text)
|
||||||
|
|
||||||
|
# Clean up tags - remove HTML artifacts and duplicates
|
||||||
|
available_tags = []
|
||||||
|
seen = set()
|
||||||
|
for tag in raw_tags:
|
||||||
|
# Skip if it looks like HTML (contains < or >)
|
||||||
|
if '<' in tag or '>' in tag:
|
||||||
|
continue
|
||||||
|
if tag not in seen:
|
||||||
|
seen.add(tag)
|
||||||
|
available_tags.append(tag)
|
||||||
|
|
||||||
|
# Check if exact model exists
|
||||||
|
if requested_tag is None:
|
||||||
|
# User just requested base model, suggest tags
|
||||||
|
return True, available_tags[:10] # Return up to 10 tags
|
||||||
|
else:
|
||||||
|
exact_match = model_name in available_tags
|
||||||
|
return exact_match, available_tags[:10]
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# If scraping fails, assume model might exist (don't block user)
|
||||||
|
return True, []
|
||||||
|
|
||||||
|
|
||||||
def search_ollama_models_fuzzy(query: str, available_models: List[str]) -> List[str]:
|
def search_ollama_models_fuzzy(query: str, available_models: List[str]) -> List[str]:
|
||||||
"""Use intelligent fuzzy search for Ollama models"""
|
"""Use intelligent fuzzy search for Ollama models"""
|
||||||
if not available_models:
|
if not available_models:
|
||||||
@@ -243,24 +306,66 @@ def validate_model_and_suggest(model_name: str, llm_type: str) -> Optional[str]:
|
|||||||
if llm_type == "ollama":
|
if llm_type == "ollama":
|
||||||
available_models = check_ollama_models()
|
available_models = check_ollama_models()
|
||||||
if available_models and model_name not in available_models:
|
if available_models and model_name not in available_models:
|
||||||
# Use intelligent fuzzy search based on locally installed models
|
|
||||||
suggestions = search_ollama_models_fuzzy(model_name, available_models)
|
|
||||||
|
|
||||||
error_msg = f"Model '{model_name}' not found in your local Ollama installation."
|
error_msg = f"Model '{model_name}' not found in your local Ollama installation."
|
||||||
if suggestions:
|
|
||||||
error_msg += "\n\nDid you mean one of these installed models?\n"
|
|
||||||
for i, suggestion in enumerate(suggestions, 1):
|
|
||||||
error_msg += f" {i}. {suggestion}\n"
|
|
||||||
else:
|
|
||||||
error_msg += "\n\nYour installed models:\n"
|
|
||||||
for i, model in enumerate(available_models[:8], 1):
|
|
||||||
error_msg += f" {i}. {model}\n"
|
|
||||||
if len(available_models) > 8:
|
|
||||||
error_msg += f" ... and {len(available_models) - 8} more\n"
|
|
||||||
|
|
||||||
error_msg += "\nTo list all models: ollama list"
|
# Check if the model exists remotely and get available tags
|
||||||
error_msg += "\nTo download a new model: ollama pull <model_name>"
|
model_exists_remotely, available_tags = check_ollama_model_exists_remotely(model_name)
|
||||||
error_msg += "\nBrowse models: https://ollama.com/library"
|
|
||||||
|
if model_exists_remotely and model_name in available_tags:
|
||||||
|
# Exact model exists remotely - suggest pulling it
|
||||||
|
error_msg += f"\n\nTo install the requested model:\n"
|
||||||
|
error_msg += f" ollama pull {model_name}\n"
|
||||||
|
|
||||||
|
# Show local alternatives
|
||||||
|
suggestions = search_ollama_models_fuzzy(model_name, available_models)
|
||||||
|
if suggestions:
|
||||||
|
error_msg += "\nOr use one of these similar installed models:\n"
|
||||||
|
for i, suggestion in enumerate(suggestions, 1):
|
||||||
|
error_msg += f" {i}. {suggestion}\n"
|
||||||
|
|
||||||
|
elif model_exists_remotely and available_tags:
|
||||||
|
# Base model exists but requested tag doesn't - suggest correct tags
|
||||||
|
base_model = model_name.split(':')[0]
|
||||||
|
requested_tag = model_name.split(':', 1)[1] if ':' in model_name else None
|
||||||
|
|
||||||
|
error_msg += f"\n\nModel '{base_model}' exists, but tag '{requested_tag}' is not available."
|
||||||
|
error_msg += f"\n\nAvailable {base_model} models you can install:\n"
|
||||||
|
for i, tag in enumerate(available_tags[:8], 1):
|
||||||
|
error_msg += f" {i}. ollama pull {tag}\n"
|
||||||
|
if len(available_tags) > 8:
|
||||||
|
error_msg += f" ... and {len(available_tags) - 8} more variants\n"
|
||||||
|
|
||||||
|
# Also show local alternatives
|
||||||
|
suggestions = search_ollama_models_fuzzy(model_name, available_models)
|
||||||
|
if suggestions:
|
||||||
|
error_msg += "\nOr use one of these similar installed models:\n"
|
||||||
|
for i, suggestion in enumerate(suggestions, 1):
|
||||||
|
error_msg += f" {i}. {suggestion}\n"
|
||||||
|
|
||||||
|
else:
|
||||||
|
# Model doesn't exist remotely - show fuzzy suggestions
|
||||||
|
suggestions = search_ollama_models_fuzzy(model_name, available_models)
|
||||||
|
error_msg += f"\n\nModel '{model_name}' was not found in Ollama's library."
|
||||||
|
|
||||||
|
if suggestions:
|
||||||
|
error_msg += "\n\nDid you mean one of these installed models?\n"
|
||||||
|
for i, suggestion in enumerate(suggestions, 1):
|
||||||
|
error_msg += f" {i}. {suggestion}\n"
|
||||||
|
else:
|
||||||
|
error_msg += "\n\nYour installed models:\n"
|
||||||
|
for i, model in enumerate(available_models[:8], 1):
|
||||||
|
error_msg += f" {i}. {model}\n"
|
||||||
|
if len(available_models) > 8:
|
||||||
|
error_msg += f" ... and {len(available_models) - 8} more\n"
|
||||||
|
|
||||||
|
error_msg += "\n\nCommands:"
|
||||||
|
error_msg += "\n ollama list # List installed models"
|
||||||
|
if model_exists_remotely and available_tags:
|
||||||
|
if model_name in available_tags:
|
||||||
|
error_msg += f"\n ollama pull {model_name} # Install requested model"
|
||||||
|
else:
|
||||||
|
error_msg += f"\n ollama pull {available_tags[0]} # Install recommended variant"
|
||||||
|
error_msg += "\n https://ollama.com/library # Browse available models"
|
||||||
return error_msg
|
return error_msg
|
||||||
|
|
||||||
elif llm_type == "hf":
|
elif llm_type == "hf":
|
||||||
@@ -397,7 +502,7 @@ class OllamaChat(LLMInterface):
|
|||||||
|
|
||||||
|
|
||||||
class HFChat(LLMInterface):
|
class HFChat(LLMInterface):
|
||||||
"""LLM interface for local Hugging Face Transformers models."""
|
"""LLM interface for local Hugging Face Transformers models with proper chat templates."""
|
||||||
|
|
||||||
def __init__(self, model_name: str = "deepseek-ai/deepseek-llm-7b-chat"):
|
def __init__(self, model_name: str = "deepseek-ai/deepseek-llm-7b-chat"):
|
||||||
logger.info(f"Initializing HFChat with model='{model_name}'")
|
logger.info(f"Initializing HFChat with model='{model_name}'")
|
||||||
@@ -408,7 +513,7 @@ class HFChat(LLMInterface):
|
|||||||
raise ValueError(model_error)
|
raise ValueError(model_error)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from transformers.pipelines import pipeline
|
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||||
import torch
|
import torch
|
||||||
except ImportError:
|
except ImportError:
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
@@ -417,54 +522,101 @@ class HFChat(LLMInterface):
|
|||||||
|
|
||||||
# Auto-detect device
|
# Auto-detect device
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
device = "cuda"
|
self.device = "cuda"
|
||||||
logger.info("CUDA is available. Using GPU.")
|
logger.info("CUDA is available. Using GPU.")
|
||||||
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
|
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
|
||||||
device = "mps"
|
self.device = "mps"
|
||||||
logger.info("MPS is available. Using Apple Silicon GPU.")
|
logger.info("MPS is available. Using Apple Silicon GPU.")
|
||||||
else:
|
else:
|
||||||
device = "cpu"
|
self.device = "cpu"
|
||||||
logger.info("No GPU detected. Using CPU.")
|
logger.info("No GPU detected. Using CPU.")
|
||||||
|
|
||||||
self.pipeline = pipeline("text-generation", model=model_name, device=device)
|
# Load tokenizer and model
|
||||||
|
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||||
|
self.model = AutoModelForCausalLM.from_pretrained(
|
||||||
|
model_name,
|
||||||
|
torch_dtype=torch.float16 if self.device != "cpu" else torch.float32,
|
||||||
|
device_map="auto" if self.device != "cpu" else None,
|
||||||
|
trust_remote_code=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# Move model to device if not using device_map
|
||||||
|
if self.device != "cpu" and "device_map" not in str(self.model):
|
||||||
|
self.model = self.model.to(self.device)
|
||||||
|
|
||||||
|
# Set pad token if not present
|
||||||
|
if self.tokenizer.pad_token is None:
|
||||||
|
self.tokenizer.pad_token = self.tokenizer.eos_token
|
||||||
|
|
||||||
def ask(self, prompt: str, **kwargs) -> str:
|
def ask(self, prompt: str, **kwargs) -> str:
|
||||||
# Map OpenAI-style arguments to Hugging Face equivalents
|
print('kwargs in HF: ', kwargs)
|
||||||
if "max_tokens" in kwargs:
|
# Check if this is a Qwen model and add /no_think by default
|
||||||
# Prefer user-provided max_new_tokens if both are present
|
is_qwen_model = "qwen" in self.model.config._name_or_path.lower()
|
||||||
kwargs.setdefault("max_new_tokens", kwargs["max_tokens"])
|
|
||||||
# Remove the unsupported key to avoid errors in Transformers
|
# For Qwen models, automatically add /no_think to the prompt
|
||||||
kwargs.pop("max_tokens")
|
if is_qwen_model and "/no_think" not in prompt and "/think" not in prompt:
|
||||||
|
prompt = prompt + " /no_think"
|
||||||
|
|
||||||
|
# Prepare chat template
|
||||||
|
messages = [{"role": "user", "content": prompt}]
|
||||||
|
|
||||||
|
# Apply chat template if available
|
||||||
|
if hasattr(self.tokenizer, "apply_chat_template"):
|
||||||
|
try:
|
||||||
|
formatted_prompt = self.tokenizer.apply_chat_template(
|
||||||
|
messages,
|
||||||
|
tokenize=False,
|
||||||
|
add_generation_prompt=True
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Chat template failed, using raw prompt: {e}")
|
||||||
|
formatted_prompt = prompt
|
||||||
|
else:
|
||||||
|
# Fallback for models without chat template
|
||||||
|
formatted_prompt = prompt
|
||||||
|
|
||||||
# Handle temperature=0 edge-case for greedy decoding
|
# Tokenize input
|
||||||
if "temperature" in kwargs and kwargs["temperature"] == 0.0:
|
inputs = self.tokenizer(
|
||||||
# Remove unsupported zero temperature and use deterministic generation
|
formatted_prompt,
|
||||||
kwargs.pop("temperature")
|
return_tensors="pt",
|
||||||
kwargs.setdefault("do_sample", False)
|
padding=True,
|
||||||
|
truncation=True,
|
||||||
|
max_length=2048
|
||||||
|
)
|
||||||
|
|
||||||
|
# Move inputs to device
|
||||||
|
if self.device != "cpu":
|
||||||
|
inputs = {k: v.to(self.device) for k, v in inputs.items()}
|
||||||
|
|
||||||
# Sensible defaults for text generation
|
# Set generation parameters
|
||||||
params = {"max_length": 500, "num_return_sequences": 1, **kwargs}
|
generation_config = {
|
||||||
logger.info(f"Generating text with Hugging Face model with params: {params}")
|
"max_new_tokens": kwargs.get("max_tokens", kwargs.get("max_new_tokens", 512)),
|
||||||
results = self.pipeline(prompt, **params)
|
"temperature": kwargs.get("temperature", 0.7),
|
||||||
|
"top_p": kwargs.get("top_p", 0.9),
|
||||||
|
"do_sample": kwargs.get("temperature", 0.7) > 0,
|
||||||
|
"pad_token_id": self.tokenizer.eos_token_id,
|
||||||
|
"eos_token_id": self.tokenizer.eos_token_id,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Handle temperature=0 for greedy decoding
|
||||||
|
if generation_config["temperature"] == 0.0:
|
||||||
|
generation_config["do_sample"] = False
|
||||||
|
generation_config.pop("temperature")
|
||||||
|
|
||||||
# Handle different response formats from transformers
|
logger.info(f"Generating with HuggingFace model, config: {generation_config}")
|
||||||
if isinstance(results, list) and len(results) > 0:
|
|
||||||
generated_text = (
|
# Generate
|
||||||
results[0].get("generated_text", "")
|
with torch.no_grad():
|
||||||
if isinstance(results[0], dict)
|
outputs = self.model.generate(
|
||||||
else str(results[0])
|
**inputs,
|
||||||
|
**generation_config
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
generated_text = str(results)
|
|
||||||
|
|
||||||
# Extract only the newly generated portion by removing the original prompt
|
# Decode response
|
||||||
if isinstance(generated_text, str) and generated_text.startswith(prompt):
|
generated_tokens = outputs[0][inputs["input_ids"].shape[1]:]
|
||||||
response = generated_text[len(prompt) :].strip()
|
response = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
|
||||||
else:
|
|
||||||
# Fallback: return the full response if prompt removal fails
|
return response.strip()
|
||||||
response = str(generated_text)
|
|
||||||
|
|
||||||
return response
|
|
||||||
|
|
||||||
|
|
||||||
class OpenAIChat(LLMInterface):
|
class OpenAIChat(LLMInterface):
|
||||||
|
|||||||
@@ -101,7 +101,7 @@ def compute_embeddings_sentence_transformers(
|
|||||||
if device == "mps":
|
if device == "mps":
|
||||||
batch_size = 128 # MPS optimal batch size from benchmark
|
batch_size = 128 # MPS optimal batch size from benchmark
|
||||||
if model_name == "Qwen/Qwen3-Embedding-0.6B":
|
if model_name == "Qwen/Qwen3-Embedding-0.6B":
|
||||||
batch_size = 64
|
batch_size = 32
|
||||||
elif device == "cuda":
|
elif device == "cuda":
|
||||||
batch_size = 256 # CUDA optimal batch size
|
batch_size = 256 # CUDA optimal batch size
|
||||||
# Keep original batch_size for CPU
|
# Keep original batch_size for CPU
|
||||||
|
|||||||
@@ -269,7 +269,9 @@ class EmbeddingServerManager:
|
|||||||
]
|
]
|
||||||
|
|
||||||
if kwargs.get("passages_file"):
|
if kwargs.get("passages_file"):
|
||||||
command.extend(["--passages-file", str(kwargs["passages_file"])])
|
# Convert to absolute path to ensure subprocess can find the file
|
||||||
|
passages_file = Path(kwargs["passages_file"]).resolve()
|
||||||
|
command.extend(["--passages-file", str(passages_file)])
|
||||||
if embedding_mode != "sentence-transformers":
|
if embedding_mode != "sentence-transformers":
|
||||||
command.extend(["--embedding-mode", embedding_mode])
|
command.extend(["--embedding-mode", embedding_mode])
|
||||||
|
|
||||||
|
|||||||
@@ -112,8 +112,9 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
|
|||||||
passages_source_file = (
|
passages_source_file = (
|
||||||
self.index_dir / f"{self.index_path.name}.meta.json"
|
self.index_dir / f"{self.index_path.name}.meta.json"
|
||||||
)
|
)
|
||||||
|
# Convert to absolute path to ensure server can find it
|
||||||
zmq_port = self._ensure_server_running(
|
zmq_port = self._ensure_server_running(
|
||||||
str(passages_source_file), zmq_port
|
str(passages_source_file.resolve()), zmq_port
|
||||||
)
|
)
|
||||||
|
|
||||||
return self._compute_embedding_via_server([query], zmq_port)[
|
return self._compute_embedding_via_server([query], zmq_port)[
|
||||||
|
|||||||
40
packages/leann/README.md
Normal file
40
packages/leann/README.md
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
# LEANN - The smallest vector index in the world
|
||||||
|
|
||||||
|
LEANN is a revolutionary vector database that democratizes personal AI. Transform your laptop into a powerful RAG system that can index and search through millions of documents while using **97% less storage** than traditional solutions **without accuracy loss**.
|
||||||
|
|
||||||
|
## Installation
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Default installation (HNSW backend, recommended)
|
||||||
|
uv pip install leann
|
||||||
|
|
||||||
|
# With DiskANN backend (for large-scale deployments)
|
||||||
|
uv pip install leann[diskann]
|
||||||
|
```
|
||||||
|
|
||||||
|
## Quick Start
|
||||||
|
|
||||||
|
```python
|
||||||
|
from leann import LeannBuilder, LeannSearcher, LeannChat
|
||||||
|
|
||||||
|
# Build an index
|
||||||
|
builder = LeannBuilder(backend_name="hnsw")
|
||||||
|
builder.add_text("LEANN saves 97% storage compared to traditional vector databases.")
|
||||||
|
builder.build_index("my_index.leann")
|
||||||
|
|
||||||
|
# Search
|
||||||
|
searcher = LeannSearcher("my_index.leann")
|
||||||
|
results = searcher.search("storage savings", top_k=3)
|
||||||
|
|
||||||
|
# Chat with your data
|
||||||
|
chat = LeannChat("my_index.leann", llm_config={"type": "ollama", "model": "llama3.2:1b"})
|
||||||
|
response = chat.ask("How much storage does LEANN save?")
|
||||||
|
```
|
||||||
|
|
||||||
|
## Documentation
|
||||||
|
|
||||||
|
For full documentation, visit [https://leann.readthedocs.io](https://leann.readthedocs.io)
|
||||||
|
|
||||||
|
## License
|
||||||
|
|
||||||
|
MIT License
|
||||||
12
packages/leann/__init__.py
Normal file
12
packages/leann/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
"""
|
||||||
|
LEANN - Low-storage Embedding Approximation for Neural Networks
|
||||||
|
|
||||||
|
A revolutionary vector database that democratizes personal AI.
|
||||||
|
"""
|
||||||
|
|
||||||
|
__version__ = "0.1.0"
|
||||||
|
|
||||||
|
# Re-export main API from leann-core
|
||||||
|
from leann_core import LeannBuilder, LeannSearcher, LeannChat
|
||||||
|
|
||||||
|
__all__ = ["LeannBuilder", "LeannSearcher", "LeannChat"]
|
||||||
42
packages/leann/pyproject.toml
Normal file
42
packages/leann/pyproject.toml
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
[build-system]
|
||||||
|
requires = ["setuptools>=61.0"]
|
||||||
|
build-backend = "setuptools.build_meta"
|
||||||
|
|
||||||
|
[project]
|
||||||
|
name = "leann"
|
||||||
|
version = "0.1.9"
|
||||||
|
description = "LEANN - The smallest vector index in the world. RAG Everything with LEANN!"
|
||||||
|
readme = "README.md"
|
||||||
|
requires-python = ">=3.9"
|
||||||
|
license = { text = "MIT" }
|
||||||
|
authors = [
|
||||||
|
{ name = "LEANN Team" }
|
||||||
|
]
|
||||||
|
keywords = ["vector-database", "rag", "embeddings", "search", "ai"]
|
||||||
|
classifiers = [
|
||||||
|
"Development Status :: 4 - Beta",
|
||||||
|
"Intended Audience :: Developers",
|
||||||
|
"License :: OSI Approved :: MIT License",
|
||||||
|
"Programming Language :: Python :: 3",
|
||||||
|
"Programming Language :: Python :: 3.9",
|
||||||
|
"Programming Language :: Python :: 3.10",
|
||||||
|
"Programming Language :: Python :: 3.11",
|
||||||
|
"Programming Language :: Python :: 3.12",
|
||||||
|
]
|
||||||
|
|
||||||
|
# Default installation: core + hnsw
|
||||||
|
dependencies = [
|
||||||
|
"leann-core>=0.1.0",
|
||||||
|
"leann-backend-hnsw>=0.1.0",
|
||||||
|
]
|
||||||
|
|
||||||
|
[project.optional-dependencies]
|
||||||
|
diskann = [
|
||||||
|
"leann-backend-diskann>=0.1.0",
|
||||||
|
]
|
||||||
|
|
||||||
|
[project.urls]
|
||||||
|
Homepage = "https://github.com/yourusername/leann"
|
||||||
|
Documentation = "https://leann.readthedocs.io"
|
||||||
|
Repository = "https://github.com/yourusername/leann"
|
||||||
|
Issues = "https://github.com/yourusername/leann/issues"
|
||||||
@@ -33,8 +33,8 @@ dependencies = [
|
|||||||
"msgpack>=1.1.1",
|
"msgpack>=1.1.1",
|
||||||
"llama-index-vector-stores-faiss>=0.4.0",
|
"llama-index-vector-stores-faiss>=0.4.0",
|
||||||
"llama-index-embeddings-huggingface>=0.5.5",
|
"llama-index-embeddings-huggingface>=0.5.5",
|
||||||
"mlx>=0.26.3",
|
"mlx>=0.26.3; sys_platform == 'darwin'",
|
||||||
"mlx-lm>=0.26.0",
|
"mlx-lm>=0.26.0; sys_platform == 'darwin'",
|
||||||
"psutil>=5.8.0",
|
"psutil>=5.8.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
@@ -1,12 +0,0 @@
|
|||||||
import faiss
|
|
||||||
hnsw_index = faiss.read_index("/opt/dlami/nvme/scaling_out/indices/rpj_wiki/facebook/contriever-msmarco/hnsw/hnsw_IP_M30_efC128.index", faiss.IO_FLAG_ONDISK_SAME_DIR)
|
|
||||||
|
|
||||||
# print total number of nodes
|
|
||||||
print(hnsw_index.ntotal)
|
|
||||||
|
|
||||||
# print stats of the graph
|
|
||||||
print(hnsw_index.hnsw.print_neighbor_stats(0))
|
|
||||||
|
|
||||||
|
|
||||||
# save_degree_distribution
|
|
||||||
hnsw_index.hnsw.save_degree_distribution(0, "degree_distribution_HNSW_M30.txt")
|
|
||||||
@@ -1,11 +0,0 @@
|
|||||||
import faiss
|
|
||||||
nsg_index = faiss.read_index("/opt/dlami/nvme/scaling_out/indices/rpj_wiki/facebook/contriever-msmarco/nsg_R16.index", faiss.IO_FLAG_ONDISK_SAME_DIR)
|
|
||||||
|
|
||||||
# print total number of nodes
|
|
||||||
print(nsg_index.ntotal)
|
|
||||||
|
|
||||||
# print stats of the graph
|
|
||||||
print(nsg_index.nsg.print_neighbor_stats(0))
|
|
||||||
|
|
||||||
# save degree distribution
|
|
||||||
nsg_index.nsg.save_degree_distribution("degree_distribution_NSG_R60.txt")
|
|
||||||
@@ -1,63 +0,0 @@
|
|||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import time
|
|
||||||
|
|
||||||
# import bitsandbytes as bnb
|
|
||||||
from bitsandbytes.nn import Linear8bitLt
|
|
||||||
|
|
||||||
# set default to half
|
|
||||||
import torch
|
|
||||||
torch.set_default_dtype(torch.float16)
|
|
||||||
|
|
||||||
M = 2048
|
|
||||||
N = 2048
|
|
||||||
|
|
||||||
bsz = 2048
|
|
||||||
import torch_int
|
|
||||||
from torch_int.nn.linear import W8A8BFP32OFP32Linear, W8A8B8O8Linear, W8A8B8O8LinearReLU
|
|
||||||
|
|
||||||
fp16_model = nn.Sequential(
|
|
||||||
nn.Linear(M, N),
|
|
||||||
# nn.Linear(2048, 2048)
|
|
||||||
)
|
|
||||||
|
|
||||||
int8_model = nn.Sequential(
|
|
||||||
Linear8bitLt(M, N, has_fp16_weights=False),
|
|
||||||
# Linear8bitLt(2048, 2048, has_fp16_weights=False)
|
|
||||||
)
|
|
||||||
|
|
||||||
int8_model.load_state_dict(fp16_model.state_dict())
|
|
||||||
int8_model = int8_model.to(0) # Quantization happens here
|
|
||||||
fp16_model = fp16_model.to(0) # Move fp16 model to GPU as well
|
|
||||||
|
|
||||||
# Create random input tensor
|
|
||||||
input_tensor = torch.randn(bsz, M, device=0) # Batch of 1000 vectors
|
|
||||||
|
|
||||||
# Speed test function
|
|
||||||
def speed_test(model, input_tensor, name, num_iterations=100):
|
|
||||||
# Warmup
|
|
||||||
for _ in range(10):
|
|
||||||
_ = model(input_tensor)
|
|
||||||
|
|
||||||
# Actual timing
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
start_time = time.time()
|
|
||||||
|
|
||||||
for _ in range(num_iterations):
|
|
||||||
_ = model(input_tensor)
|
|
||||||
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
end_time = time.time()
|
|
||||||
|
|
||||||
avg_time = (end_time - start_time) / num_iterations
|
|
||||||
print(f"{name} model: {avg_time:.6f} seconds per iteration")
|
|
||||||
return avg_time
|
|
||||||
|
|
||||||
# Run speed tests
|
|
||||||
with torch.no_grad(): # Disable gradient calculation for inference
|
|
||||||
fp16_time = speed_test(fp16_model, input_tensor, "FP16")
|
|
||||||
int8_time = speed_test(int8_model, input_tensor, "INT8")
|
|
||||||
|
|
||||||
# Calculate speedup
|
|
||||||
speedup = fp16_time / int8_time
|
|
||||||
print(f"INT8 is {speedup:.2f}x faster than FP16")
|
|
||||||
@@ -1,89 +0,0 @@
|
|||||||
n,d,seqlen,bs,latency,h,flop,io,intensity,throughput,series
|
|
||||||
3,256,256,2048,0.009623501679245285,768,618475290624,167.48502132816208,3692720015.912285,64267177503366.266,dense
|
|
||||||
3,256,256,1024,0.004853848615384615,768,309237645312,166.15392854317415,1861151572.059558,63709783682138.234,dense
|
|
||||||
3,256,256,512,0.0024687246971962615,768,154618822656,163.57953256539062,945221081.3366361,62631051097597.516,dense
|
|
||||||
3,256,256,256,0.0012845360838052097,768,77309411328,157.64931990085577,490388486.1451936,60184694149645.54,dense
|
|
||||||
3,256,256,128,0.0006901147179878049,768,38654705664,147.57393422494675,261934506.70684624,56012000116019.945,dense
|
|
||||||
3,256,256,64,0.0003363830693015702,768,19327352832,153.1328437752606,126212981.84970059,57456378146882.51,dense
|
|
||||||
3,256,256,32,0.00018671159748991485,768,9663676416,141.10249365427362,68486928.65540518,51757237075334.75,dense
|
|
||||||
3,256,256,16,0.00012353640857142858,768,4831838208,111.40488993609125,43371868.24359184,39112665358133.98,dense
|
|
||||||
3,256,256,8,9.774760007849294e-05,768,2415919104,76.43260800265766,31608487.09906635,24715891766754.14,dense
|
|
||||||
3,256,256,4,6.672271167474822e-05,768,1207959552,64.82614227498455,18633833.660438772,18104173551704.773,dense
|
|
||||||
3,256,256,2,4.9758770289855074e-05,768,603979776,55.317122669351576,10918495.880745342,12138157202874.861,dense
|
|
||||||
3,256,1,2048,9.785507940251571e-05,768,2415919104,76.34865809334705,31643242.518371396,24688745017132.86,dense
|
|
||||||
3,256,1,1024,6.692813470149253e-05,768,1207959552,64.62717090938949,18691202.70936228,18048606275785.867,dense
|
|
||||||
3,256,1,512,4.9680950036205655e-05,768,603979776,55.40377142534654,10901419.893658841,12157170415618.898,dense
|
|
||||||
3,256,1,256,4.2781118741058655e-05,768,301989888,45.95672244805227,6571179.83862661,7058952568020.829,dense
|
|
||||||
3,256,1,128,5.0662328255350016e-05,768,150994944,31.046026784880404,4863583.512513602,2980418571348.519,dense
|
|
||||||
3,256,1,64,4.475009253945481e-05,768,75497472,30.75426042497223,2454862.219307235,1687090857598.4766,dense
|
|
||||||
3,256,1,32,4.51682671454219e-05,768,37748736,28.29313765537115,1334201.1218340008,835735758435.5786,dense
|
|
||||||
3,256,1,16,5.03585186661834e-05,768,18874368,24.401035466223117,773506.846712577,374799904761.1871,dense
|
|
||||||
3,256,1,8,5.023459565217391e-05,768,9437184,23.972005435021096,393675.19858030166,187862246674.45105,dense
|
|
||||||
3,256,1,4,5.053219391083726e-05,768,4718592,23.58765586356967,200044.97383259286,93377936614.54384,dense
|
|
||||||
3,256,1,2,4.4607398995335484e-05,768,2359296,26.58285456464288,88752.54515134107,52890239133.797226,dense
|
|
||||||
12,256,256,2048,0.14480779847058822,3072,9895604649984,44.620009282941716,221775046868.20184,68336130750540.26,dense
|
|
||||||
12,256,256,1024,0.07254347629166667,3072,4947802324992,44.664248332585096,110777691547.58836,68204648824643.82,dense
|
|
||||||
12,256,256,512,0.036310761444444443,3072,2473901162496,44.876147984203506,55127306456.13385,68131349056975.164,dense
|
|
||||||
12,256,256,256,0.01821551906896552,3072,1236950581248,45.24607467289738,27338295977.947884,67906414116709.98,dense
|
|
||||||
12,256,256,128,0.009229417903030302,3072,618475290624,45.67217092440895,13541622351.335684,67011299859001.46,dense
|
|
||||||
12,256,256,64,0.004754550595394737,3072,309237645312,46.31372736116993,6677019167.566916,65040352207320.695,dense
|
|
||||||
12,256,256,32,0.002405752659340659,3072,154618822656,49.68826015254682,3111777755.5766335,64270456921525.82,dense
|
|
||||||
12,256,256,16,0.0012287219045005488,3072,77309411328,56.323579604557374,1372594069.3184311,62918558743709.18,dense
|
|
||||||
12,256,256,8,0.0006206816149425287,3072,38654705664,70.95456179103653,544781120.315271,62277832520589.78,dense
|
|
||||||
12,256,256,4,0.0003875502697142857,3072,19327352832,81.16954743236613,238110885.71245712,49870569942445.75,dense
|
|
||||||
12,256,256,2,0.00027502018627941914,3072,9663676416,91.50537035282076,105607751.53129694,35138062215483.168,dense
|
|
||||||
12,256,1,2048,0.0006202853873290136,3072,38654705664,70.99988634205897,544433345.6784943,62317614526515.766,dense
|
|
||||||
12,256,1,1024,0.00038721467732724153,3072,19327352832,81.2398957010995,237904697.74985722,49913791918755.53,dense
|
|
||||||
12,256,1,512,0.000274364799,3072,9663676416,91.72395326121995,105356082.81599998,35221998052308.45,dense
|
|
||||||
12,256,1,256,0.00012488918589482266,3072,4831838208,176.31707535146046,27404255.647778228,38689003962834.75,dense
|
|
||||||
12,256,1,128,8.976711102514506e-05,3072,2415919104,227.78088507574267,10606329.425740216,26913187652026.21,dense
|
|
||||||
12,256,1,64,8.715176287471176e-05,3072,1207959552,225.59268282689945,5354604.31102229,13860414432884.701,dense
|
|
||||||
12,256,1,32,8.523013435114503e-05,3072,603979776,226.06539514085782,2671703.8033338524,7086458100741.991,dense
|
|
||||||
12,256,1,16,7.901561645904116e-05,3072,301989888,241.35704882952732,1251216.3595988373,3821901309300.556,dense
|
|
||||||
12,256,1,8,7.827949114210329e-05,3072,150994944,242.37091635608994,622991.1833900034,1928920867994.581,dense
|
|
||||||
12,256,1,4,7.779445951035782e-05,3072,75497472,243.25022783249054,310369.58391664835,970473636235.5986,dense
|
|
||||||
12,256,1,2,7.758845406626506e-05,3072,37748736,243.57933441822672,154975.11761480253,486525172518.07056,dense
|
|
||||||
3,256,256,2048,0.00507974918466899,768,206158430208,475.59810852303485,433471930.42508715,40584371927298.98,qk_init
|
|
||||||
3,256,256,1024,0.0025616677649325623,768,103079215104,471.5519977009198,218595649.27424532,40239103803811.82,qk_init
|
|
||||||
3,256,256,512,0.0013029336670480549,768,51539607552,463.55374128015677,111183672.92143403,39556585922573.38,qk_init
|
|
||||||
3,256,256,256,0.0006738189029345373,768,25769803776,448.1766342333362,57499213.050413854,38244406121244.69,qk_init
|
|
||||||
3,256,256,128,0.000358254672959467,768,12884901888,421.47375986100144,30571065.425874516,35965760841472.125,qk_init
|
|
||||||
3,256,256,64,0.0002007051105022831,768,6442450944,376.1611839930762,17126836.096194826,32099087700742.5,qk_init
|
|
||||||
3,256,256,32,0.00012189697230142565,768,3221225472,309.6773881032524,10401874.969721656,26425803784810.87,qk_init
|
|
||||||
3,256,256,16,8.453561698040722e-05,768,1610612736,223.2711923587723,7213705.982328083,19052475081281.902,qk_init
|
|
||||||
3,256,256,8,6.407660705009276e-05,768,805306368,147.2797083750448,5467870.468274581,12567868448003.822,qk_init
|
|
||||||
3,256,256,4,5.036328747284576e-05,768,402653184,93.69110391262903,4297667.197682838,7994974200544.344,qk_init
|
|
||||||
3,256,256,2,4.5488761135057476e-05,768,201326592,51.865470527877875,3881707.616858238,4425853485045.578,qk_init
|
|
||||||
12,256,256,2048,0.020202365999999996,3072,824633720832,478.3437947812648,1723935231.9999998,40818670488001.266,qk_init
|
|
||||||
12,256,256,1024,0.010124155888157895,3072,412316860416,477.2583770318811,863927969.1228071,40726048173387.19,qk_init
|
|
||||||
12,256,256,512,0.005085633937062937,3072,206158430208,475.04777848703077,433974095.9627039,40537410430893.29,qk_init
|
|
||||||
12,256,256,256,0.0025654916853281853,3072,103079215104,470.84913933193053,218921957.14800516,40179126556324.74,qk_init
|
|
||||||
12,256,256,128,0.0013045765704467354,3072,51539607552,462.9699702434292,111323867.34478809,39506770794105.96,qk_init
|
|
||||||
12,256,256,64,0.0006742801519939804,3072,25769803776,447.87005387442576,57538572.970153,38218244597284.33,qk_init
|
|
||||||
12,256,256,32,0.00035831976790671853,3072,12884901888,421.3971919051604,30576620.194706645,35959227042573.69,qk_init
|
|
||||||
12,256,256,16,0.0002005369068918302,3072,6442450944,376.4766953382971,17112482.721436176,32126011335534.68,qk_init
|
|
||||||
12,256,256,8,0.00012179187250509165,3072,3221225472,309.94462293386505,10392906.453767821,26448607823689.82,qk_init
|
|
||||||
12,256,256,4,8.452507263643351e-05,3072,1610612736,223.2990450204527,7212806.198308992,19054851841745.297,qk_init
|
|
||||||
12,256,256,2,6.412381767545489e-05,3072,805306368,147.17127491946468,5471899.108305484,12558615459794.32,qk_init
|
|
||||||
3,256,256,2048,0.0016183739398395718,768,805306368,811597824.0,0.9922480620155039,1265467.7325087283,qk_ar
|
|
||||||
3,256,256,1024,0.0008322699728813558,768,402653184,405798912.0,0.9922480620155039,1230369.9921491416,qk_ar
|
|
||||||
3,256,256,512,0.00043886859397590365,768,201326592,202899456.0,0.9922480620155039,1166636.2255762408,qk_ar
|
|
||||||
3,256,256,256,0.00024185948322147648,768,100663296,101449728.0,0.9922480620155039,1058465.8355760013,qk_ar
|
|
||||||
3,256,256,128,0.00014308985100166944,768,50331648,50724864.0,0.9922480620155039,894542.82818777,qk_ar
|
|
||||||
3,256,256,64,9.382939365815932e-05,768,25165824,25362432.0,0.9922480620155039,682089.028872613,qk_ar
|
|
||||||
3,256,256,32,6.856070612244899e-05,768,12582912,12681216.0,0.9922480620155039,466739.6503012703,qk_ar
|
|
||||||
3,256,256,16,5.452260553129549e-05,768,6291456,6340608.0,0.9922480620155039,293456.26174846216,qk_ar
|
|
||||||
3,256,256,8,4.608557533261417e-05,768,3145728,3170304.0,0.9922480620155039,173590.1080166944,qk_ar
|
|
||||||
3,256,256,4,4.386146957766642e-05,768,1572864,1585152.0,0.9922480620155039,91196.21477609445,qk_ar
|
|
||||||
3,256,256,2,4.330941094420601e-05,768,786432,792576.0,0.9922480620155039,46179.33969539622,qk_ar
|
|
||||||
12,256,256,2048,0.006347041645299144,3072,3221225472,3246391296.0,0.9922480620155039,322670.011392918,qk_ar
|
|
||||||
12,256,256,1024,0.0031943104467592586,3072,1610612736,1623195648.0,0.9922480620155039,320569.96872013,qk_ar
|
|
||||||
12,256,256,512,0.0016183416350267381,3072,805306368,811597824.0,0.9922480620155039,316373.2483416833,qk_ar
|
|
||||||
12,256,256,256,0.0008325934893977947,3072,402653184,405798912.0,0.9922480620155039,307472.9784221131,qk_ar
|
|
||||||
12,256,256,128,0.0004389725746987952,3072,201326592,202899456.0,0.9922480620155039,291589.9702568624,qk_ar
|
|
||||||
12,256,256,64,0.00024191767449664432,3072,100663296,101449728.0,0.9922480620155039,264552.8076159138,qk_ar
|
|
||||||
12,256,256,32,0.0001431546143572621,3072,50331648,50724864.0,0.9922480620155039,223534.53392804778,qk_ar
|
|
||||||
12,256,256,16,9.404283597678917e-05,3072,25165824,25362432.0,0.9922480620155039,170135.23501087292,qk_ar
|
|
||||||
12,256,256,8,6.855550037091989e-05,3072,12582912,12681216.0,0.9922480620155039,116693.773026467,qk_ar
|
|
||||||
12,256,256,4,5.4802094978165945e-05,3072,6291456,6340608.0,0.9922480620155039,72989.91036006316,qk_ar
|
|
||||||
12,256,256,2,4.608510707869206e-05,3072,3145728,3170304.0,0.9922480620155039,43397.96795057727,qk_ar
|
|
||||||
|
Binary file not shown.
|
Before Width: | Height: | Size: 45 KiB |
@@ -1,594 +0,0 @@
|
|||||||
# python embedd_micro.py --use_int8 Fastest
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
import time
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import Dict, List, Optional, Tuple
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
from torch import nn
|
|
||||||
from torchao import quantize_
|
|
||||||
from transformers import AutoModel, BitsAndBytesConfig
|
|
||||||
from tqdm import tqdm
|
|
||||||
from contextlib import contextmanager
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class BenchmarkConfig:
|
|
||||||
model_path: str
|
|
||||||
batch_sizes: List[int]
|
|
||||||
seq_length: int
|
|
||||||
num_runs: int
|
|
||||||
use_fp16: bool = True
|
|
||||||
use_int4: bool = False
|
|
||||||
use_int8: bool = False # Add this parameter
|
|
||||||
use_cuda_graphs: bool = False
|
|
||||||
use_flash_attention: bool = False
|
|
||||||
use_linear8bitlt: bool = False
|
|
||||||
|
|
||||||
|
|
||||||
class CUDAGraphContainer:
|
|
||||||
"""Container for managing CUDA graphs for different batch sizes."""
|
|
||||||
|
|
||||||
def __init__(self, model: nn.Module, seq_length: int):
|
|
||||||
self.model = model
|
|
||||||
self.seq_length = seq_length
|
|
||||||
self.graphs: Dict[int, CUDAGraphWrapper] = {}
|
|
||||||
|
|
||||||
def get_or_create(self, batch_size: int) -> 'CUDAGraphWrapper':
|
|
||||||
if batch_size not in self.graphs:
|
|
||||||
self.graphs[batch_size] = CUDAGraphWrapper(
|
|
||||||
self.model, batch_size, self.seq_length
|
|
||||||
)
|
|
||||||
return self.graphs[batch_size]
|
|
||||||
|
|
||||||
|
|
||||||
class CUDAGraphWrapper:
|
|
||||||
"""Wrapper for CUDA graph capture and replay."""
|
|
||||||
|
|
||||||
def __init__(self, model: nn.Module, batch_size: int, seq_length: int):
|
|
||||||
self.model = model
|
|
||||||
self.static_input = self._create_random_batch(batch_size, seq_length)
|
|
||||||
self.static_attention_mask = torch.ones_like(self.static_input)
|
|
||||||
|
|
||||||
# Warm up
|
|
||||||
self._warmup()
|
|
||||||
|
|
||||||
# Capture graph
|
|
||||||
self.graph = torch.cuda.CUDAGraph()
|
|
||||||
with torch.cuda.graph(self.graph):
|
|
||||||
self.static_output = self.model(
|
|
||||||
input_ids=self.static_input,
|
|
||||||
attention_mask=self.static_attention_mask
|
|
||||||
)
|
|
||||||
|
|
||||||
def _create_random_batch(self, batch_size: int, seq_length: int) -> torch.Tensor:
|
|
||||||
return torch.randint(
|
|
||||||
0, 1000, (batch_size, seq_length),
|
|
||||||
device="cuda",
|
|
||||||
dtype=torch.long
|
|
||||||
)
|
|
||||||
|
|
||||||
def _warmup(self, num_warmup: int = 3):
|
|
||||||
with torch.no_grad():
|
|
||||||
for _ in range(num_warmup):
|
|
||||||
self.model(
|
|
||||||
input_ids=self.static_input,
|
|
||||||
attention_mask=self.static_attention_mask
|
|
||||||
)
|
|
||||||
|
|
||||||
def __call__(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
|
|
||||||
self.static_input.copy_(input_ids)
|
|
||||||
self.static_attention_mask.copy_(attention_mask)
|
|
||||||
self.graph.replay()
|
|
||||||
return self.static_output
|
|
||||||
|
|
||||||
|
|
||||||
class ModelOptimizer:
|
|
||||||
"""Applies various optimizations to the model."""
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def optimize(model: nn.Module, config: BenchmarkConfig) -> nn.Module:
|
|
||||||
print("\nApplying model optimizations:")
|
|
||||||
|
|
||||||
if model is None:
|
|
||||||
raise ValueError("Cannot optimize None model")
|
|
||||||
|
|
||||||
# Move to GPU
|
|
||||||
model = model.cuda()
|
|
||||||
print("- Model moved to GPU")
|
|
||||||
|
|
||||||
# FP16
|
|
||||||
if config.use_fp16 and not config.use_int4:
|
|
||||||
model = model.half()
|
|
||||||
# use torch compile
|
|
||||||
model = torch.compile(model)
|
|
||||||
print("- Using FP16 precision")
|
|
||||||
|
|
||||||
# Check if using SDPA
|
|
||||||
if torch.version.cuda and float(torch.version.cuda[:3]) >= 11.6:
|
|
||||||
if hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
|
|
||||||
print("- Using PyTorch SDPA (scaled_dot_product_attention)")
|
|
||||||
else:
|
|
||||||
print("- PyTorch SDPA not available")
|
|
||||||
|
|
||||||
# Flash Attention
|
|
||||||
if config.use_flash_attention:
|
|
||||||
try:
|
|
||||||
from flash_attn.flash_attention import FlashAttention
|
|
||||||
print("- Flash Attention 2 available")
|
|
||||||
if hasattr(model.config, "attention_mode"):
|
|
||||||
model.config.attention_mode = "flash_attention_2"
|
|
||||||
print(" - Enabled Flash Attention 2 mode")
|
|
||||||
except ImportError:
|
|
||||||
print("- Flash Attention not available")
|
|
||||||
|
|
||||||
# Memory efficient attention
|
|
||||||
try:
|
|
||||||
from xformers.ops import memory_efficient_attention
|
|
||||||
if hasattr(model, 'enable_xformers_memory_efficient_attention'):
|
|
||||||
model.enable_xformers_memory_efficient_attention()
|
|
||||||
print("- Enabled xformers memory efficient attention")
|
|
||||||
else:
|
|
||||||
print("- Model doesn't support xformers")
|
|
||||||
except (ImportError, AttributeError):
|
|
||||||
print("- Xformers not available")
|
|
||||||
|
|
||||||
model.eval()
|
|
||||||
print("- Model set to eval mode")
|
|
||||||
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
class Timer:
|
|
||||||
"""Handles accurate GPU timing using CUDA events."""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.start_event = torch.cuda.Event(enable_timing=True)
|
|
||||||
self.end_event = torch.cuda.Event(enable_timing=True)
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def timing(self):
|
|
||||||
self.start_event.record()
|
|
||||||
yield
|
|
||||||
self.end_event.record()
|
|
||||||
self.end_event.synchronize()
|
|
||||||
|
|
||||||
def elapsed_time(self) -> float:
|
|
||||||
return self.start_event.elapsed_time(self.end_event) / 1000 # ms to seconds
|
|
||||||
|
|
||||||
|
|
||||||
class Benchmark:
|
|
||||||
"""Main benchmark runner."""
|
|
||||||
|
|
||||||
def __init__(self, config: BenchmarkConfig):
|
|
||||||
self.config = config
|
|
||||||
try:
|
|
||||||
self.model = self._load_model()
|
|
||||||
if self.model is None:
|
|
||||||
raise ValueError("Model initialization failed - model is None")
|
|
||||||
|
|
||||||
self.cuda_graphs = (
|
|
||||||
CUDAGraphContainer(self.model, config.seq_length)
|
|
||||||
if config.use_cuda_graphs
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
self.timer = Timer()
|
|
||||||
except Exception as e:
|
|
||||||
print(f"ERROR in benchmark initialization: {str(e)}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
def _load_model(self) -> nn.Module:
|
|
||||||
print(f"Loading model from {self.config.model_path}...")
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Int4 quantization using HuggingFace integration
|
|
||||||
if self.config.use_int4:
|
|
||||||
import bitsandbytes as bnb
|
|
||||||
print(f"- bitsandbytes version: {bnb.__version__}")
|
|
||||||
|
|
||||||
# 检查是否使用自定义的8bit量化
|
|
||||||
if hasattr(self.config, 'use_linear8bitlt') and self.config.use_linear8bitlt:
|
|
||||||
print("- Using custom Linear8bitLt replacement for all linear layers")
|
|
||||||
|
|
||||||
# 加载原始模型(不使用量化配置)
|
|
||||||
import bitsandbytes as bnb
|
|
||||||
import torch
|
|
||||||
# set default to half
|
|
||||||
torch.set_default_dtype(torch.float16)
|
|
||||||
compute_dtype = torch.float16 if self.config.use_fp16 else torch.float32
|
|
||||||
model = AutoModel.from_pretrained(
|
|
||||||
self.config.model_path,
|
|
||||||
torch_dtype=compute_dtype,
|
|
||||||
)
|
|
||||||
|
|
||||||
# 定义替换函数
|
|
||||||
def replace_linear_with_linear8bitlt(model):
|
|
||||||
"""递归地将模型中的所有nn.Linear层替换为Linear8bitLt"""
|
|
||||||
for name, module in list(model.named_children()):
|
|
||||||
if isinstance(module, nn.Linear):
|
|
||||||
# 获取原始线性层的参数
|
|
||||||
in_features = module.in_features
|
|
||||||
out_features = module.out_features
|
|
||||||
bias = module.bias is not None
|
|
||||||
|
|
||||||
# 创建8bit线性层
|
|
||||||
# print size
|
|
||||||
print(f"in_features: {in_features}, out_features: {out_features}")
|
|
||||||
new_module = bnb.nn.Linear8bitLt(
|
|
||||||
in_features,
|
|
||||||
out_features,
|
|
||||||
bias=bias,
|
|
||||||
has_fp16_weights=False
|
|
||||||
)
|
|
||||||
|
|
||||||
# 复制权重和偏置
|
|
||||||
new_module.weight.data = module.weight.data
|
|
||||||
if bias:
|
|
||||||
new_module.bias.data = module.bias.data
|
|
||||||
|
|
||||||
# 替换模块
|
|
||||||
setattr(model, name, new_module)
|
|
||||||
else:
|
|
||||||
# 递归处理子模块
|
|
||||||
replace_linear_with_linear8bitlt(module)
|
|
||||||
|
|
||||||
return model
|
|
||||||
|
|
||||||
# 替换所有线性层
|
|
||||||
model = replace_linear_with_linear8bitlt(model)
|
|
||||||
# add torch compile
|
|
||||||
model = torch.compile(model)
|
|
||||||
|
|
||||||
# 将模型移到GPU(量化发生在这里)
|
|
||||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
||||||
model = model.to(device)
|
|
||||||
|
|
||||||
print("- All linear layers replaced with Linear8bitLt")
|
|
||||||
|
|
||||||
else:
|
|
||||||
# 使用原来的Int4量化方法
|
|
||||||
print("- Using bitsandbytes for Int4 quantization")
|
|
||||||
|
|
||||||
# Create quantization config
|
|
||||||
|
|
||||||
compute_dtype = torch.float16 if self.config.use_fp16 else torch.float32
|
|
||||||
quantization_config = BitsAndBytesConfig(
|
|
||||||
load_in_4bit=True,
|
|
||||||
bnb_4bit_compute_dtype=compute_dtype,
|
|
||||||
bnb_4bit_use_double_quant=True,
|
|
||||||
bnb_4bit_quant_type="nf4"
|
|
||||||
)
|
|
||||||
|
|
||||||
print("- Quantization config:", quantization_config)
|
|
||||||
|
|
||||||
# Load model directly with quantization config
|
|
||||||
model = AutoModel.from_pretrained(
|
|
||||||
self.config.model_path,
|
|
||||||
quantization_config=quantization_config,
|
|
||||||
torch_dtype=compute_dtype,
|
|
||||||
device_map="auto" # Let HF decide on device mapping
|
|
||||||
)
|
|
||||||
|
|
||||||
# Check if model loaded successfully
|
|
||||||
if model is None:
|
|
||||||
raise ValueError("Model loading returned None")
|
|
||||||
|
|
||||||
print(f"- Model type: {type(model)}")
|
|
||||||
|
|
||||||
# Apply optimizations directly here
|
|
||||||
print("\nApplying model optimizations:")
|
|
||||||
|
|
||||||
if hasattr(self.config, 'use_linear8bitlt') and self.config.use_linear8bitlt:
|
|
||||||
print("- Model moved to GPU with Linear8bitLt quantization")
|
|
||||||
else:
|
|
||||||
# Skip moving to GPU since device_map="auto" already did that
|
|
||||||
print("- Model already on GPU due to device_map='auto'")
|
|
||||||
|
|
||||||
# Skip FP16 conversion since we specified compute_dtype
|
|
||||||
print(f"- Using {compute_dtype} for compute dtype")
|
|
||||||
|
|
||||||
# Check CUDA and SDPA
|
|
||||||
if torch.version.cuda and float(torch.version.cuda[:3]) >= 11.6:
|
|
||||||
if hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
|
|
||||||
print("- Using PyTorch SDPA (scaled_dot_product_attention)")
|
|
||||||
else:
|
|
||||||
print("- PyTorch SDPA not available")
|
|
||||||
|
|
||||||
# Try xformers if available
|
|
||||||
try:
|
|
||||||
from xformers.ops import memory_efficient_attention
|
|
||||||
if hasattr(model, 'enable_xformers_memory_efficient_attention'):
|
|
||||||
model.enable_xformers_memory_efficient_attention()
|
|
||||||
print("- Enabled xformers memory efficient attention")
|
|
||||||
else:
|
|
||||||
print("- Model doesn't support xformers")
|
|
||||||
except (ImportError, AttributeError):
|
|
||||||
print("- Xformers not available")
|
|
||||||
|
|
||||||
# Set to eval mode
|
|
||||||
model.eval()
|
|
||||||
print("- Model set to eval mode")
|
|
||||||
# Int8 quantization using HuggingFace integration
|
|
||||||
# Int8 quantization using TorchAO
|
|
||||||
elif self.config.use_int8:
|
|
||||||
print("- Using TorchAO for Int8 dynamic activation and Int8 weight quantization")
|
|
||||||
|
|
||||||
# Import the quantize_ function and the quantization config
|
|
||||||
from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight
|
|
||||||
print("- Successfully imported TorchAO")
|
|
||||||
|
|
||||||
# Load model normally first
|
|
||||||
# set default to half
|
|
||||||
import torch
|
|
||||||
torch.set_default_dtype(torch.bfloat16)
|
|
||||||
model = AutoModel.from_pretrained(
|
|
||||||
self.config.model_path,
|
|
||||||
device_map="auto"
|
|
||||||
)
|
|
||||||
|
|
||||||
print("- Model loaded in full precision")
|
|
||||||
print(f"- Model type: {type(model)}")
|
|
||||||
|
|
||||||
# Apply quantization - call the function to get the config, then apply it
|
|
||||||
# quantize_(model, int8_dynamic_activation_int8_weight())
|
|
||||||
# from torchao.quantization import quantize_, Int8DynamicActivationInt8WeightConfig,int8_dynamic_activation_int8_semi_sparse_weight,int4_weight_only,Int8DynActInt4WeightGPTQQuantizer,int8_dynamic_activation_int4_weight,Int8DynamicActivationInt4WeightConfig,Int4DynamicActivationInt4WeightConfig
|
|
||||||
from torchao.quantization import quantize_, Int8DynamicActivationInt8WeightConfig
|
|
||||||
quantize_(model, Int8DynamicActivationInt8WeightConfig())
|
|
||||||
print("- Model successfully quantized with int8 weights and int8 activations")
|
|
||||||
# add torch compile
|
|
||||||
model = torch.compile(model)
|
|
||||||
# For older PyTorch versions that have issues with tensor subclasses
|
|
||||||
from torchao.utils import unwrap_tensor_subclass
|
|
||||||
import torch
|
|
||||||
if hasattr(torch, '_version') and not torch.version >= "2.5.0":
|
|
||||||
print("- Unwrapping tensor subclasses for compatibility with older PyTorch")
|
|
||||||
unwrap_tensor_subclass(model)
|
|
||||||
|
|
||||||
# Apply optimizations
|
|
||||||
if torch.version.cuda and float(torch.version.cuda[:3]) >= 11.6:
|
|
||||||
if hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
|
|
||||||
print("- Using PyTorch SDPA (scaled_dot_product_attention)")
|
|
||||||
else:
|
|
||||||
print("- PyTorch SDPA not available")
|
|
||||||
|
|
||||||
# Set to eval mode
|
|
||||||
model.eval()
|
|
||||||
print("- Model set to eval mode")
|
|
||||||
|
|
||||||
# For better performance with int8 dynamic quantization
|
|
||||||
torch._inductor.config.force_fuse_int_mm_with_mul = True
|
|
||||||
print("- Enabled fusion of int matmul with mul operations")
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
else:
|
|
||||||
# Standard loading for FP16/FP32
|
|
||||||
model = AutoModel.from_pretrained(self.config.model_path)
|
|
||||||
print("- Model loaded in standard precision")
|
|
||||||
print(f"- Model type: {type(model)}")
|
|
||||||
|
|
||||||
# Apply standard optimizations
|
|
||||||
# set default to half
|
|
||||||
import torch
|
|
||||||
torch.set_default_dtype(torch.bfloat16)
|
|
||||||
model = ModelOptimizer.optimize(model, self.config)
|
|
||||||
model = model.half()
|
|
||||||
# add torch compile
|
|
||||||
model = torch.compile(model)
|
|
||||||
|
|
||||||
# Final check to ensure model is not None
|
|
||||||
if model is None:
|
|
||||||
raise ValueError("Model is None after optimization")
|
|
||||||
|
|
||||||
print(f"- Final model type: {type(model)}")
|
|
||||||
return model
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"ERROR loading model: {str(e)}")
|
|
||||||
import traceback
|
|
||||||
traceback.print_exc()
|
|
||||||
raise
|
|
||||||
|
|
||||||
def _create_random_batch(self, batch_size: int) -> torch.Tensor:
|
|
||||||
return torch.randint(
|
|
||||||
0, 1000,
|
|
||||||
(batch_size, self.config.seq_length),
|
|
||||||
device="cuda",
|
|
||||||
dtype=torch.long
|
|
||||||
)
|
|
||||||
|
|
||||||
def _run_inference(
|
|
||||||
self,
|
|
||||||
input_ids: torch.Tensor,
|
|
||||||
cuda_graph_wrapper: Optional[CUDAGraphWrapper] = None
|
|
||||||
) -> Tuple[float, torch.Tensor]:
|
|
||||||
attention_mask = torch.ones_like(input_ids)
|
|
||||||
|
|
||||||
with torch.no_grad(), self.timer.timing():
|
|
||||||
if cuda_graph_wrapper is not None:
|
|
||||||
output = cuda_graph_wrapper(input_ids, attention_mask)
|
|
||||||
else:
|
|
||||||
output = self.model(input_ids=input_ids, attention_mask=attention_mask)
|
|
||||||
|
|
||||||
return self.timer.elapsed_time(), output
|
|
||||||
|
|
||||||
def run(self) -> Dict[int, Dict[str, float]]:
|
|
||||||
results = {}
|
|
||||||
|
|
||||||
# Reset peak memory stats
|
|
||||||
torch.cuda.reset_peak_memory_stats()
|
|
||||||
|
|
||||||
for batch_size in self.config.batch_sizes:
|
|
||||||
print(f"\nTesting batch size: {batch_size}")
|
|
||||||
times = []
|
|
||||||
|
|
||||||
# Get or create CUDA graph for this batch size
|
|
||||||
cuda_graph_wrapper = (
|
|
||||||
self.cuda_graphs.get_or_create(batch_size)
|
|
||||||
if self.cuda_graphs is not None
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
|
|
||||||
# Pre-allocate input tensor
|
|
||||||
input_ids = self._create_random_batch(batch_size)
|
|
||||||
print(f"Input shape: {input_ids.shape}")
|
|
||||||
|
|
||||||
# Run benchmark
|
|
||||||
for i in tqdm(range(self.config.num_runs), desc=f"Batch size {batch_size}"):
|
|
||||||
try:
|
|
||||||
elapsed_time, output = self._run_inference(input_ids, cuda_graph_wrapper)
|
|
||||||
if i == 0: # Only print on first run
|
|
||||||
print(f"Output shape: {output.last_hidden_state.shape}")
|
|
||||||
times.append(elapsed_time)
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error during inference: {e}")
|
|
||||||
break
|
|
||||||
|
|
||||||
if not times:
|
|
||||||
print(f"No successful runs for batch size {batch_size}, skipping")
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Calculate statistics
|
|
||||||
avg_time = np.mean(times)
|
|
||||||
std_time = np.std(times)
|
|
||||||
throughput = batch_size / avg_time
|
|
||||||
|
|
||||||
results[batch_size] = {
|
|
||||||
"avg_time": avg_time,
|
|
||||||
"std_time": std_time,
|
|
||||||
"throughput": throughput,
|
|
||||||
}
|
|
||||||
|
|
||||||
print(f"Avg Time: {avg_time:.4f}s ± {std_time:.4f}s")
|
|
||||||
print(f"Throughput: {throughput:.2f} sequences/second")
|
|
||||||
|
|
||||||
# Log memory usage
|
|
||||||
peak_memory_gb = torch.cuda.max_memory_allocated() / (1024 ** 3)
|
|
||||||
print(f"\nPeak GPU memory usage: {peak_memory_gb:.2f} GB")
|
|
||||||
|
|
||||||
# Add memory info to results
|
|
||||||
for batch_size in results:
|
|
||||||
results[batch_size]["peak_memory_gb"] = peak_memory_gb
|
|
||||||
|
|
||||||
return results
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
parser = argparse.ArgumentParser(description="Model Inference Benchmark")
|
|
||||||
parser.add_argument(
|
|
||||||
"--model_path",
|
|
||||||
type=str,
|
|
||||||
default="facebook/contriever",
|
|
||||||
help="Path to the model",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--batch_sizes",
|
|
||||||
type=str,
|
|
||||||
default="1,2,4,8,10,16,20,32,40,64,128,256,512,1024,2048,4096,8192",
|
|
||||||
help="Comma-separated list of batch sizes",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--seq_length",
|
|
||||||
type=int,
|
|
||||||
default=256,
|
|
||||||
help="Sequence length for input",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--num_runs",
|
|
||||||
type=int,
|
|
||||||
default=5,
|
|
||||||
help="Number of runs for each batch size",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--use_fp16",
|
|
||||||
action="store_true",
|
|
||||||
help="Enable FP16 inference",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--use_int4",
|
|
||||||
action="store_true",
|
|
||||||
help="Enable INT4 quantization using bitsandbytes",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--use_int8",
|
|
||||||
action="store_true",
|
|
||||||
help="Enable INT8 quantization for both activations and weights using bitsandbytes",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--use_cuda_graphs",
|
|
||||||
action="store_true",
|
|
||||||
help="Enable CUDA Graphs optimization",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--use_flash_attention",
|
|
||||||
action="store_true",
|
|
||||||
help="Enable Flash Attention 2 if available",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--use_linear8bitlt",
|
|
||||||
action="store_true",
|
|
||||||
help="Enable Linear8bitLt quantization for all linear layers",
|
|
||||||
)
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
# Print arguments for debugging
|
|
||||||
print("\nCommand line arguments:")
|
|
||||||
for arg, value in vars(args).items():
|
|
||||||
print(f"- {arg}: {value}")
|
|
||||||
|
|
||||||
config = BenchmarkConfig(
|
|
||||||
model_path=args.model_path,
|
|
||||||
batch_sizes=[int(bs) for bs in args.batch_sizes.split(",")],
|
|
||||||
seq_length=args.seq_length,
|
|
||||||
num_runs=args.num_runs,
|
|
||||||
use_fp16=args.use_fp16,
|
|
||||||
use_int4=args.use_int4,
|
|
||||||
use_int8=args.use_int8, # Add this line
|
|
||||||
use_cuda_graphs=args.use_cuda_graphs,
|
|
||||||
use_flash_attention=args.use_flash_attention,
|
|
||||||
use_linear8bitlt=args.use_linear8bitlt,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Print configuration for debugging
|
|
||||||
print("\nBenchmark configuration:")
|
|
||||||
for field, value in vars(config).items():
|
|
||||||
print(f"- {field}: {value}")
|
|
||||||
|
|
||||||
try:
|
|
||||||
benchmark = Benchmark(config)
|
|
||||||
results = benchmark.run()
|
|
||||||
|
|
||||||
# Save results to file
|
|
||||||
import json
|
|
||||||
import os
|
|
||||||
|
|
||||||
# Create results directory if it doesn't exist
|
|
||||||
os.makedirs("results", exist_ok=True)
|
|
||||||
|
|
||||||
# Generate filename based on configuration
|
|
||||||
precision_type = "int4" if config.use_int4 else "fp16" if config.use_fp16 else "fp32"
|
|
||||||
model_name = os.path.basename(config.model_path)
|
|
||||||
output_file = f"results/benchmark_{model_name}_{precision_type}.json"
|
|
||||||
|
|
||||||
# Save results
|
|
||||||
with open(output_file, "w") as f:
|
|
||||||
json.dump(
|
|
||||||
{
|
|
||||||
"config": {k: str(v) if isinstance(v, list) else v for k, v in vars(config).items()},
|
|
||||||
"results": {str(k): v for k, v in results.items()}
|
|
||||||
},
|
|
||||||
f,
|
|
||||||
indent=2
|
|
||||||
)
|
|
||||||
print(f"Results saved to {output_file}")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Benchmark failed: {e}")
|
|
||||||
import traceback
|
|
||||||
traceback.print_exc()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
@@ -1,376 +0,0 @@
|
|||||||
import argparse
|
|
||||||
import time
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import Dict, List, Optional, Tuple
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
from torch import nn
|
|
||||||
from transformers import AutoModel
|
|
||||||
from tqdm import tqdm
|
|
||||||
from contextlib import contextmanager
|
|
||||||
import math
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class BenchmarkConfig:
|
|
||||||
model_path: str
|
|
||||||
batch_sizes: List[int]
|
|
||||||
seq_length: int
|
|
||||||
num_runs: int
|
|
||||||
use_fp16: bool = True
|
|
||||||
use_cuda_graphs: bool = False
|
|
||||||
use_flash_attention: bool = False
|
|
||||||
max_batch_size: int = 256 # Maximum batch size before splitting
|
|
||||||
|
|
||||||
|
|
||||||
class CUDAGraphContainer:
|
|
||||||
"""Container for managing CUDA graphs for different batch sizes."""
|
|
||||||
|
|
||||||
def __init__(self, model: nn.Module, seq_length: int, max_batch_size: int):
|
|
||||||
self.model = model
|
|
||||||
self.seq_length = seq_length
|
|
||||||
self.max_batch_size = max_batch_size
|
|
||||||
self.graphs: Dict[int, CUDAGraphWrapper] = {}
|
|
||||||
|
|
||||||
def get_or_create(self, batch_size: int) -> 'CUDAGraphWrapper':
|
|
||||||
# For CUDA graphs, we always use the actual batch size or max_batch_size
|
|
||||||
effective_batch_size = min(batch_size, self.max_batch_size)
|
|
||||||
|
|
||||||
if effective_batch_size not in self.graphs:
|
|
||||||
self.graphs[effective_batch_size] = CUDAGraphWrapper(
|
|
||||||
self.model, effective_batch_size, self.seq_length
|
|
||||||
)
|
|
||||||
return self.graphs[effective_batch_size]
|
|
||||||
|
|
||||||
|
|
||||||
class CUDAGraphWrapper:
|
|
||||||
"""Wrapper for CUDA graph capture and replay."""
|
|
||||||
|
|
||||||
def __init__(self, model: nn.Module, batch_size: int, seq_length: int):
|
|
||||||
self.model = model
|
|
||||||
self.static_input = self._create_random_batch(batch_size, seq_length)
|
|
||||||
self.static_attention_mask = torch.ones_like(self.static_input)
|
|
||||||
|
|
||||||
# Warm up
|
|
||||||
self._warmup()
|
|
||||||
|
|
||||||
# Capture graph
|
|
||||||
self.graph = torch.cuda.CUDAGraph()
|
|
||||||
with torch.cuda.graph(self.graph):
|
|
||||||
self.static_output = self.model(
|
|
||||||
input_ids=self.static_input,
|
|
||||||
attention_mask=self.static_attention_mask
|
|
||||||
)
|
|
||||||
|
|
||||||
def _create_random_batch(self, batch_size: int, seq_length: int) -> torch.Tensor:
|
|
||||||
return torch.randint(
|
|
||||||
0, 1000, (batch_size, seq_length),
|
|
||||||
device="cuda",
|
|
||||||
dtype=torch.long
|
|
||||||
)
|
|
||||||
|
|
||||||
def _warmup(self, num_warmup: int = 3):
|
|
||||||
with torch.no_grad():
|
|
||||||
for _ in range(num_warmup):
|
|
||||||
self.model(
|
|
||||||
input_ids=self.static_input,
|
|
||||||
attention_mask=self.static_attention_mask
|
|
||||||
)
|
|
||||||
|
|
||||||
def __call__(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
|
|
||||||
self.static_input.copy_(input_ids)
|
|
||||||
self.static_attention_mask.copy_(attention_mask)
|
|
||||||
self.graph.replay()
|
|
||||||
return self.static_output
|
|
||||||
|
|
||||||
|
|
||||||
class ModelOptimizer:
|
|
||||||
"""Applies various optimizations to the model."""
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def optimize(model: nn.Module, config: BenchmarkConfig) -> nn.Module:
|
|
||||||
print("\nApplying model optimizations:")
|
|
||||||
|
|
||||||
# Move to GPU
|
|
||||||
model = model.cuda()
|
|
||||||
print("- Model moved to GPU")
|
|
||||||
|
|
||||||
# FP16
|
|
||||||
if config.use_fp16:
|
|
||||||
model = model.half()
|
|
||||||
print("- Using FP16 precision")
|
|
||||||
|
|
||||||
# Check if using SDPA
|
|
||||||
if torch.version.cuda and float(torch.version.cuda[:3]) >= 11.6:
|
|
||||||
if hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
|
|
||||||
print("- Using PyTorch SDPA (scaled_dot_product_attention)")
|
|
||||||
# No need to do anything as it's automatically enabled
|
|
||||||
else:
|
|
||||||
print("- PyTorch SDPA not available")
|
|
||||||
|
|
||||||
# Flash Attention
|
|
||||||
if config.use_flash_attention:
|
|
||||||
try:
|
|
||||||
from flash_attn.flash_attention import FlashAttention
|
|
||||||
print("- Flash Attention 2 available")
|
|
||||||
if hasattr(model.config, "attention_mode"):
|
|
||||||
model.config.attention_mode = "flash_attention_2"
|
|
||||||
print(" - Enabled Flash Attention 2 mode")
|
|
||||||
except ImportError:
|
|
||||||
print("- Flash Attention not available")
|
|
||||||
|
|
||||||
# Optimize LayerNorm
|
|
||||||
try:
|
|
||||||
num_layernorms = 0
|
|
||||||
for module in model.modules():
|
|
||||||
if isinstance(module, torch.nn.LayerNorm):
|
|
||||||
module.forward = torch.jit.script(module.forward)
|
|
||||||
num_layernorms += 1
|
|
||||||
if num_layernorms > 0:
|
|
||||||
print(f"- Optimized {num_layernorms} LayerNorm modules with TorchScript")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"- LayerNorm optimization failed: {e}")
|
|
||||||
|
|
||||||
# Memory efficient attention
|
|
||||||
try:
|
|
||||||
from xformers.ops import memory_efficient_attention
|
|
||||||
model.enable_xformers_memory_efficient_attention()
|
|
||||||
print("- Enabled xformers memory efficient attention")
|
|
||||||
except (ImportError, AttributeError):
|
|
||||||
print("- Xformers not available")
|
|
||||||
|
|
||||||
model.eval()
|
|
||||||
print("- Model set to eval mode")
|
|
||||||
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
class Timer:
|
|
||||||
"""Handles accurate GPU timing using CUDA events."""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.start_event = torch.cuda.Event(enable_timing=True)
|
|
||||||
self.end_event = torch.cuda.Event(enable_timing=True)
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def timing(self):
|
|
||||||
self.start_event.record()
|
|
||||||
yield
|
|
||||||
self.end_event.record()
|
|
||||||
self.end_event.synchronize()
|
|
||||||
|
|
||||||
def elapsed_time(self) -> float:
|
|
||||||
return self.start_event.elapsed_time(self.end_event) / 1000 # ms to seconds
|
|
||||||
|
|
||||||
|
|
||||||
class Benchmark:
|
|
||||||
"""Main benchmark runner."""
|
|
||||||
|
|
||||||
def __init__(self, config: BenchmarkConfig):
|
|
||||||
self.config = config
|
|
||||||
self.model = self._load_model()
|
|
||||||
self.cuda_graphs = (
|
|
||||||
CUDAGraphContainer(self.model, config.seq_length, config.max_batch_size)
|
|
||||||
if config.use_cuda_graphs
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
self.timer = Timer()
|
|
||||||
|
|
||||||
def _load_model(self) -> nn.Module:
|
|
||||||
print(f"Loading model from {self.config.model_path}...")
|
|
||||||
model = AutoModel.from_pretrained(self.config.model_path)
|
|
||||||
return ModelOptimizer.optimize(model, self.config)
|
|
||||||
|
|
||||||
def _create_random_batch(self, batch_size: int) -> torch.Tensor:
|
|
||||||
return torch.randint(
|
|
||||||
0, 1000,
|
|
||||||
(batch_size, self.config.seq_length),
|
|
||||||
device="cuda",
|
|
||||||
dtype=torch.long
|
|
||||||
)
|
|
||||||
|
|
||||||
def _run_inference(
|
|
||||||
self,
|
|
||||||
input_ids: torch.Tensor,
|
|
||||||
cuda_graph_wrapper: Optional[CUDAGraphWrapper] = None
|
|
||||||
) -> Tuple[float, torch.Tensor]:
|
|
||||||
attention_mask = torch.ones_like(input_ids)
|
|
||||||
original_batch_size = input_ids.shape[0]
|
|
||||||
print(f"Original input_ids shape: {input_ids.shape}")
|
|
||||||
|
|
||||||
# Split large batches to avoid OOM
|
|
||||||
max_batch_size = self.config.max_batch_size
|
|
||||||
if original_batch_size > max_batch_size:
|
|
||||||
print(f"Splitting batch of size {original_batch_size} into chunks of {max_batch_size}")
|
|
||||||
total_time = 0
|
|
||||||
outputs = []
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
for i in range(0, original_batch_size, max_batch_size):
|
|
||||||
end_idx = min(i + max_batch_size, original_batch_size)
|
|
||||||
batch_slice = input_ids[i:end_idx]
|
|
||||||
mask_slice = attention_mask[i:end_idx]
|
|
||||||
|
|
||||||
print(f"Processing chunk {i//max_batch_size + 1}: shape {batch_slice.shape}")
|
|
||||||
|
|
||||||
# Use CUDA graph if available (with the smaller batch size)
|
|
||||||
chunk_cuda_graph = None
|
|
||||||
if cuda_graph_wrapper is not None:
|
|
||||||
chunk_cuda_graph = self.cuda_graphs.get_or_create(batch_slice.shape[0])
|
|
||||||
|
|
||||||
with self.timer.timing():
|
|
||||||
if chunk_cuda_graph is not None:
|
|
||||||
chunk_output = chunk_cuda_graph(batch_slice, mask_slice)
|
|
||||||
else:
|
|
||||||
chunk_output = self.model(input_ids=batch_slice, attention_mask=mask_slice)
|
|
||||||
|
|
||||||
total_time += self.timer.elapsed_time()
|
|
||||||
outputs.append(chunk_output.last_hidden_state)
|
|
||||||
|
|
||||||
# Combine outputs
|
|
||||||
combined_output = torch.cat(outputs, dim=0)
|
|
||||||
print(f"Combined output shape: {combined_output.shape}")
|
|
||||||
|
|
||||||
# Create a wrapper object similar to model output to maintain consistency
|
|
||||||
class DummyOutput:
|
|
||||||
def __init__(self, hidden_states):
|
|
||||||
self.last_hidden_state = hidden_states
|
|
||||||
|
|
||||||
output = DummyOutput(combined_output)
|
|
||||||
return total_time, output
|
|
||||||
else:
|
|
||||||
# Process normally for small batches
|
|
||||||
with torch.no_grad(), self.timer.timing():
|
|
||||||
if cuda_graph_wrapper is not None:
|
|
||||||
output = cuda_graph_wrapper(input_ids, attention_mask)
|
|
||||||
else:
|
|
||||||
output = self.model(input_ids=input_ids, attention_mask=attention_mask)
|
|
||||||
|
|
||||||
print(f"Output shape: {output.last_hidden_state.shape}")
|
|
||||||
return self.timer.elapsed_time(), output
|
|
||||||
|
|
||||||
def run(self) -> Dict[int, Dict[str, float]]:
|
|
||||||
results = {}
|
|
||||||
|
|
||||||
for batch_size in self.config.batch_sizes:
|
|
||||||
print(f"\nTesting batch size: {batch_size}")
|
|
||||||
times = []
|
|
||||||
|
|
||||||
# Get or create CUDA graph for this batch size
|
|
||||||
cuda_graph_wrapper = None
|
|
||||||
if self.cuda_graphs is not None:
|
|
||||||
if batch_size <= self.config.max_batch_size:
|
|
||||||
cuda_graph_wrapper = self.cuda_graphs.get_or_create(batch_size)
|
|
||||||
else:
|
|
||||||
# For large batches, we'll use the max_batch_size graph in chunks
|
|
||||||
cuda_graph_wrapper = True # Just a flag to indicate we want to use CUDA graphs
|
|
||||||
|
|
||||||
# Pre-allocate input tensor
|
|
||||||
input_ids = self._create_random_batch(batch_size)
|
|
||||||
|
|
||||||
# Run benchmark
|
|
||||||
for run_idx in tqdm(range(self.config.num_runs), desc=f"Batch size {batch_size}"):
|
|
||||||
elapsed_time, _ = self._run_inference(input_ids, cuda_graph_wrapper)
|
|
||||||
times.append(elapsed_time)
|
|
||||||
print(f"Run {run_idx+1}: {elapsed_time:.4f}s")
|
|
||||||
|
|
||||||
# Calculate statistics
|
|
||||||
avg_time = np.mean(times)
|
|
||||||
std_time = np.std(times)
|
|
||||||
throughput = batch_size / avg_time
|
|
||||||
|
|
||||||
results[batch_size] = {
|
|
||||||
"avg_time": avg_time,
|
|
||||||
"std_time": std_time,
|
|
||||||
"throughput": throughput,
|
|
||||||
}
|
|
||||||
|
|
||||||
print(f"Avg Time: {avg_time:.4f}s ± {std_time:.4f}s")
|
|
||||||
print(f"Throughput: {throughput:.2f} sequences/second")
|
|
||||||
|
|
||||||
return results
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
parser = argparse.ArgumentParser(description="Model Inference Benchmark")
|
|
||||||
parser.add_argument(
|
|
||||||
"--model_path",
|
|
||||||
type=str,
|
|
||||||
default="facebook/contriever",
|
|
||||||
help="Path to the model",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--batch_sizes",
|
|
||||||
type=str,
|
|
||||||
default="1,2,4,8,16,32,64,128,256,512,1024,2048,4096",
|
|
||||||
help="Comma-separated list of batch sizes",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--seq_length",
|
|
||||||
type=int,
|
|
||||||
default=256,
|
|
||||||
help="Sequence length for input",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--num_runs",
|
|
||||||
type=int,
|
|
||||||
default=5,
|
|
||||||
help="Number of runs for each batch size",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--no_fp16",
|
|
||||||
action="store_true",
|
|
||||||
help="Disable FP16 inference",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--use_cuda_graphs",
|
|
||||||
action="store_true",
|
|
||||||
help="Enable CUDA Graphs optimization",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--use_flash_attention",
|
|
||||||
action="store_true",
|
|
||||||
help="Enable Flash Attention 2 if available",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--max_batch_size",
|
|
||||||
type=int,
|
|
||||||
default=256,
|
|
||||||
help="Maximum batch size before splitting to prevent OOM",
|
|
||||||
)
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
config = BenchmarkConfig(
|
|
||||||
model_path=args.model_path,
|
|
||||||
batch_sizes=[int(bs) for bs in args.batch_sizes.split(",")],
|
|
||||||
seq_length=args.seq_length,
|
|
||||||
num_runs=args.num_runs,
|
|
||||||
use_fp16=not args.no_fp16,
|
|
||||||
use_cuda_graphs=args.use_cuda_graphs,
|
|
||||||
use_flash_attention=args.use_flash_attention,
|
|
||||||
max_batch_size=args.max_batch_size,
|
|
||||||
)
|
|
||||||
|
|
||||||
benchmark = Benchmark(config)
|
|
||||||
results = benchmark.run()
|
|
||||||
|
|
||||||
# Print overall summary
|
|
||||||
print("\n===== BENCHMARK SUMMARY =====")
|
|
||||||
print(f"Model: {config.model_path}")
|
|
||||||
print(f"Sequence Length: {config.seq_length}")
|
|
||||||
print(f"FP16: {config.use_fp16}")
|
|
||||||
print(f"CUDA Graphs: {config.use_cuda_graphs}")
|
|
||||||
print(f"Flash Attention: {config.use_flash_attention}")
|
|
||||||
print(f"Max Batch Size: {config.max_batch_size}")
|
|
||||||
print("\nResults:")
|
|
||||||
|
|
||||||
print("\nBatch Size | Avg Time (s) | Throughput (seq/s)")
|
|
||||||
print("-" * 50)
|
|
||||||
for bs in sorted(results.keys()):
|
|
||||||
r = results[bs]
|
|
||||||
print(f"{bs:^10} | {r['avg_time']:^12.4f} | {r['throughput']:^17.2f}")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
@@ -1,218 +0,0 @@
|
|||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import time
|
|
||||||
import torch.nn.functional as F
|
|
||||||
|
|
||||||
# Import necessary functions from the quantize.py file
|
|
||||||
def get_group_qparams(w, n_bit=4, groupsize=128):
|
|
||||||
# needed for GPTQ with padding
|
|
||||||
if groupsize > w.shape[-1]:
|
|
||||||
groupsize = w.shape[-1]
|
|
||||||
assert groupsize > 1
|
|
||||||
assert w.shape[-1] % groupsize == 0
|
|
||||||
assert w.dim() == 2
|
|
||||||
|
|
||||||
to_quant = w.reshape(-1, groupsize)
|
|
||||||
assert torch.isnan(to_quant).sum() == 0
|
|
||||||
|
|
||||||
max_val = to_quant.amax(dim=1, keepdim=True)
|
|
||||||
min_val = to_quant.amin(dim=1, keepdim=True)
|
|
||||||
max_int = 2**n_bit - 1
|
|
||||||
scales = (max_val - min_val).clamp(min=1e-6) / max_int
|
|
||||||
zeros = min_val + scales * (2 ** (n_bit - 1))
|
|
||||||
return scales.to(torch.bfloat16).reshape(w.shape[0], -1), zeros.to(
|
|
||||||
torch.bfloat16
|
|
||||||
).reshape(w.shape[0], -1)
|
|
||||||
|
|
||||||
def pack_scales_and_zeros(scales, zeros):
|
|
||||||
assert scales.shape == zeros.shape
|
|
||||||
assert scales.dtype == torch.bfloat16
|
|
||||||
assert zeros.dtype == torch.bfloat16
|
|
||||||
return (
|
|
||||||
torch.cat(
|
|
||||||
[
|
|
||||||
scales.reshape(scales.size(0), scales.size(1), 1),
|
|
||||||
zeros.reshape(zeros.size(0), zeros.size(1), 1),
|
|
||||||
],
|
|
||||||
2,
|
|
||||||
)
|
|
||||||
.transpose(0, 1)
|
|
||||||
.contiguous()
|
|
||||||
)
|
|
||||||
|
|
||||||
def group_quantize_tensor(w, n_bit=4, groupsize=128):
|
|
||||||
scales, zeros = get_group_qparams(w, n_bit, groupsize)
|
|
||||||
w_int32 = group_quantize_tensor_from_qparams(w, scales, zeros, n_bit, groupsize)
|
|
||||||
scales_and_zeros = pack_scales_and_zeros(scales, zeros)
|
|
||||||
return w_int32, scales_and_zeros
|
|
||||||
|
|
||||||
def group_quantize_tensor_from_qparams(w, scales, zeros, n_bit=4, groupsize=128):
|
|
||||||
assert groupsize > 1
|
|
||||||
# needed for GPTQ single column quantize
|
|
||||||
if groupsize > w.shape[-1] and scales.shape[-1] == 1:
|
|
||||||
groupsize = w.shape[-1]
|
|
||||||
|
|
||||||
assert w.shape[-1] % groupsize == 0
|
|
||||||
assert w.dim() == 2
|
|
||||||
|
|
||||||
to_quant = w.reshape(-1, groupsize)
|
|
||||||
assert torch.isnan(to_quant).sum() == 0
|
|
||||||
|
|
||||||
scales = scales.reshape(-1, 1)
|
|
||||||
zeros = zeros.reshape(-1, 1)
|
|
||||||
min_val = zeros - scales * (2 ** (n_bit - 1))
|
|
||||||
max_int = 2**n_bit - 1
|
|
||||||
min_int = 0
|
|
||||||
w_int32 = (
|
|
||||||
to_quant.sub(min_val)
|
|
||||||
.div(scales)
|
|
||||||
.round()
|
|
||||||
.clamp_(min_int, max_int)
|
|
||||||
.to(torch.int32)
|
|
||||||
.reshape_as(w)
|
|
||||||
)
|
|
||||||
|
|
||||||
return w_int32
|
|
||||||
|
|
||||||
def prepare_int4_weight_and_scales_and_zeros(weight_bf16, groupsize, inner_k_tiles):
|
|
||||||
weight_int32, scales_and_zeros = group_quantize_tensor(
|
|
||||||
weight_bf16, n_bit=4, groupsize=groupsize
|
|
||||||
)
|
|
||||||
weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(weight_int32, inner_k_tiles)
|
|
||||||
return weight_int4pack, scales_and_zeros
|
|
||||||
|
|
||||||
def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, groupsize):
|
|
||||||
origin_x_size = x.size()
|
|
||||||
x = x.reshape(-1, origin_x_size[-1])
|
|
||||||
c = torch.ops.aten._weight_int4pack_mm(x, weight_int4pack, groupsize, scales_and_zeros)
|
|
||||||
new_shape = origin_x_size[:-1] + (out_features,)
|
|
||||||
c = c.reshape(new_shape)
|
|
||||||
return c
|
|
||||||
|
|
||||||
class WeightOnlyInt4Linear(torch.nn.Module):
|
|
||||||
__constants__ = ['in_features', 'out_features']
|
|
||||||
in_features: int
|
|
||||||
out_features: int
|
|
||||||
weight: torch.Tensor
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self, in_features: int, out_features: int,
|
|
||||||
bias=False, device=None, dtype=None, groupsize: int = 128, inner_k_tiles: int = 8
|
|
||||||
) -> None:
|
|
||||||
super().__init__()
|
|
||||||
self.in_features = in_features
|
|
||||||
self.out_features = out_features
|
|
||||||
self.groupsize = groupsize
|
|
||||||
self.inner_k_tiles = inner_k_tiles
|
|
||||||
|
|
||||||
assert out_features % 8 == 0, "require out_features % 8 == 0"
|
|
||||||
assert in_features % (inner_k_tiles * 16) == 0, "require in_features % (innerKTiles * 16) == 0"
|
|
||||||
self.register_buffer(
|
|
||||||
"weight",
|
|
||||||
torch.empty((out_features // 8, in_features // (inner_k_tiles * 16), 32, inner_k_tiles // 2), dtype=torch.int32)
|
|
||||||
)
|
|
||||||
self.register_buffer(
|
|
||||||
"scales_and_zeros",
|
|
||||||
torch.empty((in_features // groupsize, out_features, 2), dtype=torch.bfloat16)
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
|
||||||
input = input.to(torch.bfloat16)
|
|
||||||
return linear_forward_int4(
|
|
||||||
input,
|
|
||||||
self.weight, self.scales_and_zeros, self.out_features, self.groupsize
|
|
||||||
)
|
|
||||||
|
|
||||||
# Define dimensions that satisfy the requirements for INT4 quantization
|
|
||||||
# in_features must be divisible by inner_k_tiles * 16
|
|
||||||
# out_features must be divisible by 8
|
|
||||||
in_features = 1024 # Must be divisible by inner_k_tiles * 16
|
|
||||||
out_features = 2048 # Must be divisible by 8
|
|
||||||
groupsize = 128
|
|
||||||
inner_k_tiles = 8
|
|
||||||
|
|
||||||
# Create models
|
|
||||||
fp16_model = nn.Sequential(
|
|
||||||
nn.Linear(in_features, out_features, bias=False)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create INT4 model
|
|
||||||
int4_model = nn.Sequential(
|
|
||||||
WeightOnlyInt4Linear(in_features, out_features, bias=False,
|
|
||||||
groupsize=groupsize, inner_k_tiles=inner_k_tiles)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Quantize the weights and set up the INT4 model
|
|
||||||
with torch.no_grad():
|
|
||||||
# Convert FP16 weights to INT4
|
|
||||||
fp16_weight = fp16_model[0].weight.data.to(torch.bfloat16)
|
|
||||||
weight_int4pack, scales_and_zeros = prepare_int4_weight_and_scales_and_zeros(
|
|
||||||
fp16_weight, groupsize, inner_k_tiles
|
|
||||||
)
|
|
||||||
|
|
||||||
# Set the quantized weights in the INT4 model
|
|
||||||
int4_model[0].weight.copy_(weight_int4pack)
|
|
||||||
int4_model[0].scales_and_zeros.copy_(scales_and_zeros)
|
|
||||||
|
|
||||||
# Move models to GPU
|
|
||||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
||||||
fp16_model = fp16_model.to(device)
|
|
||||||
int4_model = int4_model.to(device)
|
|
||||||
|
|
||||||
# Create random input tensor
|
|
||||||
batch_size = 1024
|
|
||||||
input_tensor = torch.randn(batch_size, in_features, device=device)
|
|
||||||
input_tensor_bf16 = input_tensor.to(torch.bfloat16)
|
|
||||||
|
|
||||||
# Speed test function
|
|
||||||
def speed_test(model, input_tensor, name, num_iterations=100):
|
|
||||||
# Warmup
|
|
||||||
for _ in range(10):
|
|
||||||
_ = model(input_tensor)
|
|
||||||
|
|
||||||
# Actual timing
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
start_time = time.time()
|
|
||||||
|
|
||||||
for _ in range(num_iterations):
|
|
||||||
_ = model(input_tensor)
|
|
||||||
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
end_time = time.time()
|
|
||||||
|
|
||||||
avg_time = (end_time - start_time) / num_iterations
|
|
||||||
print(f"{name} model: {avg_time:.6f} seconds per iteration")
|
|
||||||
return avg_time
|
|
||||||
|
|
||||||
# Run speed tests
|
|
||||||
with torch.no_grad(): # Disable gradient calculation for inference
|
|
||||||
print(f"Running benchmark with batch_size={batch_size}, in_features={in_features}, out_features={out_features}")
|
|
||||||
print(f"INT4 parameters: groupsize={groupsize}, inner_k_tiles={inner_k_tiles}")
|
|
||||||
|
|
||||||
fp16_time = speed_test(fp16_model, input_tensor_bf16, "FP16")
|
|
||||||
int4_time = speed_test(int4_model, input_tensor, "INT4")
|
|
||||||
|
|
||||||
# Calculate speedup
|
|
||||||
speedup = fp16_time / int4_time
|
|
||||||
print(f"INT4 is {speedup:.2f}x faster than FP16")
|
|
||||||
|
|
||||||
# Calculate memory savings
|
|
||||||
fp16_memory = fp16_model[0].weight.nelement() * fp16_model[0].weight.element_size()
|
|
||||||
int4_memory = (int4_model[0].weight.nelement() * int4_model[0].weight.element_size() +
|
|
||||||
int4_model[0].scales_and_zeros.nelement() * int4_model[0].scales_and_zeros.element_size())
|
|
||||||
|
|
||||||
memory_reduction = fp16_memory / int4_memory
|
|
||||||
print(f"Memory reduction: {memory_reduction:.2f}x ({fp16_memory/1024/1024:.2f} MB vs {int4_memory/1024/1024:.2f} MB)")
|
|
||||||
|
|
||||||
# Check accuracy
|
|
||||||
with torch.no_grad():
|
|
||||||
fp16_output = fp16_model(input_tensor_bf16)
|
|
||||||
int4_output = int4_model(input_tensor)
|
|
||||||
|
|
||||||
# Calculate error metrics
|
|
||||||
abs_error = torch.abs(fp16_output - int4_output)
|
|
||||||
rel_error = abs_error / (torch.abs(fp16_output) + 1e-7)
|
|
||||||
|
|
||||||
print(f"Mean absolute error: {abs_error.mean().item():.6f}")
|
|
||||||
print(f"Max absolute error: {abs_error.max().item():.6f}")
|
|
||||||
print(f"Mean relative error: {rel_error.mean().item():.6f}")
|
|
||||||
@@ -1,83 +0,0 @@
|
|||||||
import torch
|
|
||||||
import nvmath.bindings.cublas
|
|
||||||
import ctypes
|
|
||||||
|
|
||||||
# 创建 CUBLAS 句柄
|
|
||||||
handle = nvmath.bindings.cublas.create()
|
|
||||||
|
|
||||||
# 准备数据 - 使用 uint8 类型,并确保内存连续
|
|
||||||
m, n, k = 64, 32, 48
|
|
||||||
a = (torch.rand(m, k, device="cuda") * 255).to(torch.uint8).contiguous()
|
|
||||||
b = (torch.rand(k, n, device="cuda") * 255).to(torch.uint8).contiguous()
|
|
||||||
c = torch.zeros(m, n, device="cuda", dtype=torch.uint8).contiguous()
|
|
||||||
|
|
||||||
# 确保张量在 CUDA 上
|
|
||||||
assert a.is_cuda and b.is_cuda and c.is_cuda
|
|
||||||
# 确保张量是连续的
|
|
||||||
assert a.is_contiguous() and b.is_contiguous() and c.is_contiguous()
|
|
||||||
|
|
||||||
# 获取指针
|
|
||||||
a_ptr = a.data_ptr()
|
|
||||||
b_ptr = b.data_ptr()
|
|
||||||
c_ptr = c.data_ptr()
|
|
||||||
|
|
||||||
# 设置参数
|
|
||||||
transa = 0 # CUBLAS_OP_N (不转置)
|
|
||||||
transb = 0 # CUBLAS_OP_N (不转置)
|
|
||||||
transc = 0 # CUBLAS_OP_N (不转置)
|
|
||||||
|
|
||||||
# 设置偏置值
|
|
||||||
a_bias = 0
|
|
||||||
b_bias = 0
|
|
||||||
c_bias = 0
|
|
||||||
|
|
||||||
# 设置正确的 leading dimensions
|
|
||||||
lda = k # A 的 leading dimension
|
|
||||||
ldb = n # B 的 leading dimension
|
|
||||||
ldc = n # C 的 leading dimension
|
|
||||||
|
|
||||||
c_mult = 1
|
|
||||||
c_shift = 0
|
|
||||||
|
|
||||||
# 打印调试信息
|
|
||||||
print(f"a shape: {a.shape}, a_ptr: {a_ptr}")
|
|
||||||
print(f"b shape: {b.shape}, b_ptr: {b_ptr}")
|
|
||||||
print(f"c shape: {c.shape}, c_ptr: {c_ptr}")
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 调用 uint8gemm_bias
|
|
||||||
nvmath.bindings.cublas.uint8gemm_bias(
|
|
||||||
handle,
|
|
||||||
transa, transb, transc,
|
|
||||||
m, n, k,
|
|
||||||
a_ptr, a_bias, lda,
|
|
||||||
b_ptr, b_bias, ldb,
|
|
||||||
c_ptr, c_bias, ldc,
|
|
||||||
c_mult, c_shift
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error: {e}")
|
|
||||||
# 尝试使用 ctypes 转换指针
|
|
||||||
a_ptr_c = ctypes.c_void_p(a_ptr).value
|
|
||||||
b_ptr_c = ctypes.c_void_p(b_ptr).value
|
|
||||||
c_ptr_c = ctypes.c_void_p(c_ptr).value
|
|
||||||
|
|
||||||
print(f"Using ctypes: a_ptr: {a_ptr_c}, b_ptr: {b_ptr_c}, c_ptr: {c_ptr_c}")
|
|
||||||
|
|
||||||
# 再次尝试调用
|
|
||||||
nvmath.bindings.cublas.uint8gemm_bias(
|
|
||||||
handle,
|
|
||||||
transa, transb, transc,
|
|
||||||
m, n, k,
|
|
||||||
a_ptr_c, a_bias, lda,
|
|
||||||
b_ptr_c, b_bias, ldb,
|
|
||||||
c_ptr_c, c_bias, ldc,
|
|
||||||
c_mult, c_shift
|
|
||||||
)
|
|
||||||
|
|
||||||
# 销毁 CUBLAS 句柄
|
|
||||||
nvmath.bindings.cublas.destroy(handle)
|
|
||||||
|
|
||||||
# 打印结果
|
|
||||||
print("Result:")
|
|
||||||
print(c)
|
|
||||||
@@ -1,23 +0,0 @@
|
|||||||
from llmcompressor.modifiers.smoothquant import SmoothQuantModifier
|
|
||||||
from llmcompressor.modifiers.quantization import GPTQModifier
|
|
||||||
from llmcompressor import oneshot
|
|
||||||
|
|
||||||
# Select quantization algorithm. In this case, we:
|
|
||||||
# * apply SmoothQuant to make the activations easier to quantize
|
|
||||||
# * quantize the weights to int8 with GPTQ (static per channel)
|
|
||||||
# * quantize the activations to int8 (dynamic per token)
|
|
||||||
recipe = [
|
|
||||||
SmoothQuantModifier(smoothing_strength=0.8),
|
|
||||||
GPTQModifier(scheme="W8A8", targets="Linear", ignore=["lm_head"]),
|
|
||||||
]
|
|
||||||
|
|
||||||
# Apply quantization using the built in open_platypus dataset.
|
|
||||||
# * See examples for demos showing how to pass a custom calibration set
|
|
||||||
oneshot(
|
|
||||||
model="facebook/contriever",
|
|
||||||
dataset="open_platypus",
|
|
||||||
recipe=recipe,
|
|
||||||
output_dir="contriever-INT4",
|
|
||||||
max_seq_length=2048,
|
|
||||||
num_calibration_samples=512,
|
|
||||||
)
|
|
||||||
@@ -1,41 +0,0 @@
|
|||||||
# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED.
|
|
||||||
#
|
|
||||||
# SPDX-License-Identifier: Apache-2.0
|
|
||||||
|
|
||||||
"""
|
|
||||||
This example demonstrates basic matrix multiplication of FP8 tensors.
|
|
||||||
|
|
||||||
In narrow-precision operations, quantization scales must be provided for each tensor. These
|
|
||||||
scales are used to dequantize input operands and quantize the result. Without proper
|
|
||||||
scaling, the results of FP8 operations will likely exceed the type's range.
|
|
||||||
|
|
||||||
FP8 is only supported with cuBLAS 12.8 or newer and on devices with compute
|
|
||||||
capability 8.9 or higher.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
import nvmath
|
|
||||||
|
|
||||||
# Prepare sample input data. Note that N, M and K must be divisible by 16 for FP8.
|
|
||||||
# cuBLAS requires B to be column-major, so we first create a row-major tensor and then
|
|
||||||
# transpose it.
|
|
||||||
m, n, k = 64, 32, 48
|
|
||||||
a = (torch.rand(m, k, device="cuda") * 10).type(torch.float8_e4m3fn)
|
|
||||||
b = (torch.rand(n, k, device="cuda") * 10).type(torch.float8_e4m3fn).T
|
|
||||||
|
|
||||||
# Prepare quantization scales. The scales must allow the result to fit within the dynamic
|
|
||||||
# range of the data type used. Scales can be provided either as a dictionary or as a
|
|
||||||
# MatmulQuantizationScales object. Note that scales are only allowed for FP8 operands.
|
|
||||||
scales = {"a": 1, "b": 1, "d": 0.1}
|
|
||||||
|
|
||||||
# Perform the multiplication. The result of the multiplication will be:
|
|
||||||
# (scales.a * A) @ (scales.b * B) * scales.d
|
|
||||||
result = nvmath.linalg.advanced.matmul(a, b, quantization_scales=scales)
|
|
||||||
|
|
||||||
# Check how scaling helped to fit into the dynamic range of float8_e4m3fn type.
|
|
||||||
result_without_scaling = nvmath.linalg.advanced.matmul(a, b, quantization_scales={"a": 1, "b": 1, "d": 1})
|
|
||||||
print("Without scaling, most of the elements were clamped to the maximum value of float8_e4m3fn type (448):")
|
|
||||||
print(result_without_scaling)
|
|
||||||
print(f"\nWith D scale set to {scales['d']}, they were scaled down to fit into the dynamic range of float8_e4m3fn:")
|
|
||||||
print(result)
|
|
||||||
@@ -1,58 +0,0 @@
|
|||||||
import os
|
|
||||||
import torch
|
|
||||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
def save_model_in_pth_format(model_name, output_dir):
|
|
||||||
"""
|
|
||||||
Download a model from Hugging Face and save it in PTH format
|
|
||||||
for use with quantization benchmarks.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model_name: Name of the model on Hugging Face
|
|
||||||
output_dir: Directory to save the model
|
|
||||||
"""
|
|
||||||
print(f"Loading model {model_name}...")
|
|
||||||
|
|
||||||
# Create output directory if it doesn't exist
|
|
||||||
os.makedirs(output_dir, exist_ok=True)
|
|
||||||
|
|
||||||
# Load tokenizer and model
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
|
||||||
model = AutoModelForCausalLM.from_pretrained(
|
|
||||||
model_name,
|
|
||||||
torch_dtype=torch.bfloat16,
|
|
||||||
low_cpu_mem_usage=True
|
|
||||||
)
|
|
||||||
|
|
||||||
# Save tokenizer
|
|
||||||
tokenizer.save_pretrained(output_dir)
|
|
||||||
|
|
||||||
# Extract and save the model weights in PTH format
|
|
||||||
model_state_dict = model.state_dict()
|
|
||||||
|
|
||||||
# Save the model weights
|
|
||||||
model_path = Path(output_dir) / "model.pth"
|
|
||||||
torch.save(model_state_dict, model_path)
|
|
||||||
|
|
||||||
print(f"Model saved to {model_path}")
|
|
||||||
|
|
||||||
# Print model size information
|
|
||||||
param_count = sum(p.numel() for p in model.parameters())
|
|
||||||
model_size_mb = sum(p.numel() * p.element_size() for p in model.parameters()) / (1024 * 1024)
|
|
||||||
|
|
||||||
print(f"Model parameters: {param_count:,}")
|
|
||||||
print(f"Model size: {model_size_mb:.2f} MB")
|
|
||||||
|
|
||||||
return model_path
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
# Use a small model for testing
|
|
||||||
model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
|
|
||||||
output_dir = "./tinyllama-1.1b-chat"
|
|
||||||
|
|
||||||
model_path = save_model_in_pth_format(model_name, output_dir)
|
|
||||||
|
|
||||||
print("\nYou can now use this model with the INT4 benchmark script.")
|
|
||||||
print("Example command:")
|
|
||||||
print(f"python int4benchmark.py --model_path {model_path}")
|
|
||||||
@@ -1,677 +0,0 @@
|
|||||||
{
|
|
||||||
"cells": [
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 1,
|
|
||||||
"id": "cab91cfc",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"name": "stderr",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"/home/ubuntu/Power-RAG/.venv/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
|
|
||||||
" from .autonotebook import tqdm as notebook_tqdm\n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"import copy\n",
|
|
||||||
"import dataclasses\n",
|
|
||||||
"import os\n",
|
|
||||||
"import time\n",
|
|
||||||
"import pathlib\n",
|
|
||||||
"import itertools\n",
|
|
||||||
"import multiprocessing\n",
|
|
||||||
"import scipy\n",
|
|
||||||
"import numpy as np\n",
|
|
||||||
"import pandas as pd\n",
|
|
||||||
"import pickle\n",
|
|
||||||
"import gzip\n",
|
|
||||||
"import threading\n",
|
|
||||||
"import queue\n",
|
|
||||||
"import pytz\n",
|
|
||||||
"import traceback\n",
|
|
||||||
"from datetime import datetime\n",
|
|
||||||
"from tqdm.auto import tqdm, trange\n",
|
|
||||||
"from typing import Any\n",
|
|
||||||
"\n",
|
|
||||||
"import matplotlib.pyplot as plt\n",
|
|
||||||
"import matplotlib.ticker as mtick\n",
|
|
||||||
"%matplotlib inline\n",
|
|
||||||
"%config InlineBackend.figure_format='retina'"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 2,
|
|
||||||
"id": "8d24fbd7",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"Sat Apr 12 00:10:05 2025 \n",
|
|
||||||
"+-----------------------------------------------------------------------------------------+\n",
|
|
||||||
"| NVIDIA-SMI 550.120 Driver Version: 550.120 CUDA Version: 12.4 |\n",
|
|
||||||
"|-----------------------------------------+------------------------+----------------------+\n",
|
|
||||||
"| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |\n",
|
|
||||||
"| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |\n",
|
|
||||||
"| | | MIG M. |\n",
|
|
||||||
"|=========================================+========================+======================|\n",
|
|
||||||
"| 0 NVIDIA A10G Off | 00000000:00:1E.0 Off | 0 |\n",
|
|
||||||
"| 0% 27C P8 15W / 300W | 4MiB / 23028MiB | 0% Default |\n",
|
|
||||||
"| | | N/A |\n",
|
|
||||||
"+-----------------------------------------+------------------------+----------------------+\n",
|
|
||||||
" \n",
|
|
||||||
"+-----------------------------------------------------------------------------------------+\n",
|
|
||||||
"| Processes: |\n",
|
|
||||||
"| GPU GI CI PID Type Process name GPU Memory |\n",
|
|
||||||
"| ID ID Usage |\n",
|
|
||||||
"|=========================================================================================|\n",
|
|
||||||
"| No running processes found |\n",
|
|
||||||
"+-----------------------------------------------------------------------------------------+\n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"!nvidia-smi"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 3,
|
|
||||||
"id": "538b2c11",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"def benchmark(f, *, f_setup=None, min_repeat: int, min_secs: float, tqdm_kwargs: dict | None=None) -> np.ndarray:\n",
|
|
||||||
" latency = []\n",
|
|
||||||
" \n",
|
|
||||||
" # First run, ignore min_secs\n",
|
|
||||||
" if f_setup is not None:\n",
|
|
||||||
" f_setup()\n",
|
|
||||||
" st = time.perf_counter_ns()\n",
|
|
||||||
" f()\n",
|
|
||||||
" ed = time.perf_counter_ns()\n",
|
|
||||||
" latency.append((ed-st)/1e9)\n",
|
|
||||||
" \n",
|
|
||||||
" # Subsequent runs, until reaching both min_repeat and min_secs\n",
|
|
||||||
" min_nanos = int(min_secs * 1e9)\n",
|
|
||||||
" start_nanos = time.perf_counter_ns()\n",
|
|
||||||
" while True:\n",
|
|
||||||
" now_nanos = time.perf_counter_ns()\n",
|
|
||||||
" if len(latency) > min_repeat and now_nanos - start_nanos > min_nanos:\n",
|
|
||||||
" break\n",
|
|
||||||
" if f_setup is not None:\n",
|
|
||||||
" f_setup()\n",
|
|
||||||
" st = time.perf_counter_ns()\n",
|
|
||||||
" f()\n",
|
|
||||||
" ed = time.perf_counter_ns()\n",
|
|
||||||
" latency.append((ed-st)/1e9)\n",
|
|
||||||
" return np.array(latency)\n",
|
|
||||||
"\n",
|
|
||||||
"def tail_mean(xs, skip=0.2):\n",
|
|
||||||
" return xs[int(len(xs) * skip):].mean()"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 4,
|
|
||||||
"id": "02c9c9b1",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"text/plain": [
|
|
||||||
"<torch.autograd.grad_mode.set_grad_enabled at 0x7c5afc12b850>"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"execution_count": 4,
|
|
||||||
"metadata": {},
|
|
||||||
"output_type": "execute_result"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"import torch\n",
|
|
||||||
"torch.set_grad_enabled(False)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 5,
|
|
||||||
"id": "3405fdc7",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"nd_list = list(itertools.chain(itertools.product([12, 3], [256])))\n",
|
|
||||||
"seqlen_list = [256]\n",
|
|
||||||
"bs_list = [2,4,8,16,32,64,128,256,512,1024,2048]"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 6,
|
|
||||||
"id": "10dc981a",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"[(12, 256), (3, 256)]\n",
|
|
||||||
"[256]\n",
|
|
||||||
"[2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]\n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"print(nd_list)\n",
|
|
||||||
"print(seqlen_list)\n",
|
|
||||||
"print(bs_list)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 7,
|
|
||||||
"id": "7e0ee385",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"def benchmark_dense(out, nd_list, seqlen_list, bs_list):\n",
|
|
||||||
" seqlen_list = [1] + seqlen_list\n",
|
|
||||||
" total = len(list(itertools.product(nd_list, seqlen_list, bs_list)))\n",
|
|
||||||
" pbar = tqdm(total=total)\n",
|
|
||||||
" for (n, d), seqlen in reversed(list(itertools.product(nd_list, seqlen_list))):\n",
|
|
||||||
" h = n * d\n",
|
|
||||||
" maxbs = max(bs_list)\n",
|
|
||||||
" print(maxbs, n, d, seqlen)\n",
|
|
||||||
" cache = torch.empty(int(256e6 // 4), dtype=torch.int, device=\"cuda:0\")\n",
|
|
||||||
" X = torch.rand((maxbs, seqlen, h), dtype=torch.bfloat16, device=\"cuda:0\")\n",
|
|
||||||
" W = torch.rand((h, h), dtype=torch.bfloat16, device=\"cuda:0\")\n",
|
|
||||||
" torch.cuda.synchronize()\n",
|
|
||||||
" for bs in reversed(bs_list):\n",
|
|
||||||
" pbar.set_postfix(n=n, h=h, d=d, seqlen=seqlen, bs=bs)\n",
|
|
||||||
" def run():\n",
|
|
||||||
" torch.matmul(X[:bs], W)\n",
|
|
||||||
" torch.cuda.synchronize()\n",
|
|
||||||
" def clear_cache():\n",
|
|
||||||
" cache.zero_()\n",
|
|
||||||
" torch.cuda.synchronize()\n",
|
|
||||||
" latency = benchmark(run, f_setup=clear_cache, min_repeat=20, min_secs=2)\n",
|
|
||||||
" l = tail_mean(latency)\n",
|
|
||||||
" out.append({\n",
|
|
||||||
" \"n\": n,\n",
|
|
||||||
" \"d\": d,\n",
|
|
||||||
" \"seqlen\": seqlen,\n",
|
|
||||||
" \"bs\": bs,\n",
|
|
||||||
" \"latency\": l\n",
|
|
||||||
" })\n",
|
|
||||||
" pbar.update()\n",
|
|
||||||
" del cache, X, W\n",
|
|
||||||
" torch.cuda.empty_cache()\n",
|
|
||||||
" pbar.close()"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 8,
|
|
||||||
"id": "c206a502",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"def benchmark_qk_init(out, nd_list, seqlen_list, bs_list):\n",
|
|
||||||
" total = len(list(itertools.product(nd_list, seqlen_list, bs_list)))\n",
|
|
||||||
" pbar = tqdm(total=total)\n",
|
|
||||||
" for (n, d), seqlen in reversed(list(itertools.product(nd_list, seqlen_list))):\n",
|
|
||||||
" h = n * d\n",
|
|
||||||
" try:\n",
|
|
||||||
" maxbs = max(b for b in bs_list if b*n*seqlen*d*2*2+b*n*seqlen**2*2 < 80e9)\n",
|
|
||||||
" except ValueError:\n",
|
|
||||||
" pbar.update(len(bs_list))\n",
|
|
||||||
" continue\n",
|
|
||||||
" cache = torch.empty(int(256e6 // 4), dtype=torch.int, device=\"cuda:0\")\n",
|
|
||||||
" Qmax = torch.rand((maxbs, n, seqlen, d), dtype=torch.bfloat16, device=\"cuda:0\")\n",
|
|
||||||
" Kmax = torch.rand((maxbs, n, seqlen, d), dtype=torch.bfloat16, device=\"cuda:0\")\n",
|
|
||||||
" torch.cuda.synchronize()\n",
|
|
||||||
" for bs in reversed(bs_list):\n",
|
|
||||||
" pbar.set_postfix(n=n, h=h, d=d, seqlen=seqlen, bs=bs)\n",
|
|
||||||
" if bs > maxbs:\n",
|
|
||||||
" pbar.update()\n",
|
|
||||||
" continue\n",
|
|
||||||
" Q = Qmax[:bs]\n",
|
|
||||||
" K = Kmax[:bs]\n",
|
|
||||||
" def run():\n",
|
|
||||||
" torch.bmm(Q.view(bs * n, seqlen, d), K.view(bs * n, seqlen, d).transpose(1, 2))\n",
|
|
||||||
" torch.cuda.synchronize()\n",
|
|
||||||
" def clear_cache():\n",
|
|
||||||
" cache.zero_()\n",
|
|
||||||
" torch.cuda.synchronize()\n",
|
|
||||||
" latency = benchmark(run, f_setup=clear_cache, min_repeat=20, min_secs=2)\n",
|
|
||||||
" l = tail_mean(latency)\n",
|
|
||||||
" out.append({\n",
|
|
||||||
" \"n\": n,\n",
|
|
||||||
" \"d\": d,\n",
|
|
||||||
" \"seqlen\": seqlen,\n",
|
|
||||||
" \"bs\": bs,\n",
|
|
||||||
" \"latency\": l\n",
|
|
||||||
" })\n",
|
|
||||||
" pbar.update()\n",
|
|
||||||
" del cache, Q, K, Qmax, Kmax\n",
|
|
||||||
" torch.cuda.empty_cache()\n",
|
|
||||||
" pbar.close()"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 9,
|
|
||||||
"id": "a3a2103c",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"def benchmark_qk_ar(out, nd_list, seqlen_list, bs_list):\n",
|
|
||||||
" total = len(list(itertools.product(nd_list, seqlen_list, bs_list)))\n",
|
|
||||||
" pbar = tqdm(total=total)\n",
|
|
||||||
" for (n, d), seqlen in reversed(list(itertools.product(nd_list, seqlen_list))):\n",
|
|
||||||
" h = n * d\n",
|
|
||||||
" try:\n",
|
|
||||||
" maxbs = max(b for b in bs_list if b*n*(1+seqlen)*d*2+b*n*seqlen*2 < 80e9)\n",
|
|
||||||
" except ValueError:\n",
|
|
||||||
" pbar.update(len(bs_list))\n",
|
|
||||||
" continue\n",
|
|
||||||
" cache = torch.empty(int(256e6 // 4), dtype=torch.int, device=\"cuda:0\")\n",
|
|
||||||
" Qmax = torch.rand((maxbs, n, 1, d), dtype=torch.bfloat16, device=\"cuda:0\")\n",
|
|
||||||
" Kmax = torch.rand((maxbs, n, seqlen, d), dtype=torch.bfloat16, device=\"cuda:0\")\n",
|
|
||||||
" torch.cuda.synchronize()\n",
|
|
||||||
" for bs in reversed(bs_list):\n",
|
|
||||||
" pbar.set_postfix(n=n, h=h, d=d, seqlen=seqlen, bs=bs)\n",
|
|
||||||
" if bs > maxbs:\n",
|
|
||||||
" pbar.update()\n",
|
|
||||||
" continue\n",
|
|
||||||
" Q = Qmax[:bs]\n",
|
|
||||||
" K = Kmax[:bs]\n",
|
|
||||||
" def run():\n",
|
|
||||||
" torch.bmm(Q.view(bs * n, 1, d), K.view(bs * n, seqlen, d).transpose(1, 2))\n",
|
|
||||||
" torch.cuda.synchronize()\n",
|
|
||||||
" def clear_cache():\n",
|
|
||||||
" cache.zero_()\n",
|
|
||||||
" torch.cuda.synchronize()\n",
|
|
||||||
" latency = benchmark(run, f_setup=clear_cache, min_repeat=20, min_secs=2)\n",
|
|
||||||
" l = tail_mean(latency)\n",
|
|
||||||
" out.append({\n",
|
|
||||||
" \"n\": n,\n",
|
|
||||||
" \"d\": d,\n",
|
|
||||||
" \"seqlen\": seqlen,\n",
|
|
||||||
" \"bs\": bs,\n",
|
|
||||||
" \"latency\": l\n",
|
|
||||||
" })\n",
|
|
||||||
" pbar.update()\n",
|
|
||||||
" del cache, Q, K, Qmax, Kmax\n",
|
|
||||||
" torch.cuda.empty_cache()\n",
|
|
||||||
" pbar.close()"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 10,
|
|
||||||
"id": "3aaad98a",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"data = {}"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 11,
|
|
||||||
"id": "18137de3",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"name": "stderr",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
" 0%| | 0/22 [00:00<?, ?it/s]"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stderr",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"100%|██████████| 22/22 [00:44<00:00, 2.04s/it, bs=2, d=256, h=3072, n=12, seqlen=256] \n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"db = []\n",
|
|
||||||
"benchmark_qk_init(db, nd_list, seqlen_list, bs_list)\n",
|
|
||||||
"data[\"qk_init\"] = db"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 12,
|
|
||||||
"id": "26c76e15",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"name": "stderr",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"100%|██████████| 22/22 [00:44<00:00, 2.01s/it, bs=2, d=256, h=3072, n=12, seqlen=256] \n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"db = []\n",
|
|
||||||
"benchmark_qk_ar(db, nd_list, seqlen_list, bs_list)\n",
|
|
||||||
"data[\"qk_ar\"] = db"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 13,
|
|
||||||
"id": "313e36eb",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"name": "stderr",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
" 0%| | 0/44 [00:00<?, ?it/s, bs=2048, d=256, h=768, n=3, seqlen=256]"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"2048 3 256 256\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stderr",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
" 25%|██▌ | 11/44 [00:22<01:06, 2.00s/it, bs=2048, d=256, h=768, n=3, seqlen=1] "
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"2048 3 256 1\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stderr",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
" 50%|█████ | 22/44 [00:44<00:44, 2.00s/it, bs=2048, d=256, h=3072, n=12, seqlen=256]"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"2048 12 256 256\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stderr",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
" 75%|███████▌ | 33/44 [01:07<00:22, 2.02s/it, bs=2048, d=256, h=3072, n=12, seqlen=1] "
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"2048 12 256 1\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stderr",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"100%|██████████| 44/44 [01:29<00:00, 2.03s/it, bs=2, d=256, h=3072, n=12, seqlen=1] \n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"db = []\n",
|
|
||||||
"benchmark_dense(db, nd_list, seqlen_list, bs_list)\n",
|
|
||||||
"data[\"dense\"] = db"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 14,
|
|
||||||
"id": "50c37959",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"with gzip.open(\"data/20230516-transformer-batching1.pkl.gz\", \"wb\") as f:\n",
|
|
||||||
" pickle.dump(data, f)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 15,
|
|
||||||
"id": "828ddb54",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"df_dense = (\n",
|
|
||||||
" pd.DataFrame.from_dict(data[\"dense\"])\n",
|
|
||||||
" .assign(h=lambda x: x[\"n\"] * x[\"d\"])\n",
|
|
||||||
" .assign(flop=lambda x: (x[\"bs\"] * x[\"seqlen\"] * x[\"h\"]**2) * 2)\n",
|
|
||||||
" .assign(io=lambda x: (x[\"bs\"]*x[\"seqlen\"]*x[\"h\"]*2 + x[\"h\"]**2) * 2/x['latency']/1e9)\n",
|
|
||||||
" .assign(intensity=lambda x: x[\"flop\"] / x[\"io\"])\n",
|
|
||||||
" .assign(throughput=lambda x: x[\"flop\"] / x[\"latency\"])\n",
|
|
||||||
" .assign(series=\"dense\")\n",
|
|
||||||
")\n",
|
|
||||||
"df_qk_init = (\n",
|
|
||||||
" pd.DataFrame.from_dict(data[\"qk_init\"])\n",
|
|
||||||
" .assign(h=lambda x: x[\"n\"] * x[\"d\"])\n",
|
|
||||||
" .assign(flop=lambda x: (x[\"bs\"]*x[\"n\"]*x[\"d\"]*x[\"seqlen\"]**2) * 2)\n",
|
|
||||||
" .assign(io=lambda x: (x[\"bs\"]*x[\"n\"]*(x[\"seqlen\"]*x[\"d\"]*2 + x[\"seqlen\"]**2)) * 2/x['latency']/1e9)\n",
|
|
||||||
" .assign(intensity=lambda x: x[\"flop\"] / x[\"io\"])\n",
|
|
||||||
" .assign(throughput=lambda x: x[\"flop\"] / x[\"latency\"])\n",
|
|
||||||
" .assign(series=\"qk_init\")\n",
|
|
||||||
")\n",
|
|
||||||
"df_qk_ar = (\n",
|
|
||||||
" pd.DataFrame.from_dict(data[\"qk_ar\"])\n",
|
|
||||||
" .assign(h=lambda x: x[\"n\"] * x[\"d\"])\n",
|
|
||||||
" .assign(flop=lambda x: (x[\"bs\"]*x[\"n\"]*x[\"d\"]*x[\"seqlen\"]) * 2)\n",
|
|
||||||
" .assign(io=lambda x: (x[\"bs\"]*x[\"n\"]*(x[\"d\"] + x[\"seqlen\"]*x[\"d\"] + x[\"seqlen\"])) * 2)\n",
|
|
||||||
" .assign(intensity=lambda x: x[\"flop\"] / x[\"io\"])\n",
|
|
||||||
" .assign(throughput=lambda x: x[\"bs\"] / x[\"latency\"])\n",
|
|
||||||
" .assign(series=\"qk_ar\")\n",
|
|
||||||
")\n",
|
|
||||||
"pd.concat([df_dense, df_qk_init, df_qk_ar]).to_csv(\"data/transformer-batching-microbenchmarks.csv\", index=False)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 39,
|
|
||||||
"id": "c296a395",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"text/plain": [
|
|
||||||
"<module 'pandas' from '/home/ubuntu/Power-RAG/.venv/lib/python3.10/site-packages/pandas/__init__.py'>"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"execution_count": 39,
|
|
||||||
"metadata": {},
|
|
||||||
"output_type": "execute_result"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"pd\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"id": "a25cdd5a",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": []
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"id": "63b8a531",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"import transformers"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"id": "af90eff1",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"def _gen_opt_cfg(n_layers: int, d_model: int, n_heads: int, **kwargs) -> transformers.OPTConfig:\n",
|
|
||||||
" return transformers.OPTConfig(\n",
|
|
||||||
" num_hidden_layers=n_layers,\n",
|
|
||||||
" hidden_size=d_model,\n",
|
|
||||||
" ffn_dim=d_model*4,\n",
|
|
||||||
" num_attention_heads=n_heads,\n",
|
|
||||||
" **kwargs\n",
|
|
||||||
" )\n",
|
|
||||||
"optcfg = {\n",
|
|
||||||
" # https://arxiv.org/pdf/2205.01068.pdf Table 2.1\n",
|
|
||||||
" \"125m\": _gen_opt_cfg(12, 768, 12),\n",
|
|
||||||
" \"350m\": _gen_opt_cfg(24, 1024, 16),\n",
|
|
||||||
" \"760m\": _gen_opt_cfg(24, 1536, 16),\n",
|
|
||||||
" \"1.3b\": _gen_opt_cfg(24, 2048, 32),\n",
|
|
||||||
" \"2.7b\": _gen_opt_cfg(32, 2560, 32),\n",
|
|
||||||
" \"6.7b\": _gen_opt_cfg(32, 4096, 32),\n",
|
|
||||||
" \"13b\": _gen_opt_cfg(40, 5120, 40),\n",
|
|
||||||
" \"13b_1layer\": _gen_opt_cfg(1, 5120, 40),\n",
|
|
||||||
" \"30b\": _gen_opt_cfg(48, 7168, 56),\n",
|
|
||||||
" \"66b\": _gen_opt_cfg(64, 9216, 72),\n",
|
|
||||||
" \"175b\": _gen_opt_cfg(96, 12288, 96),\n",
|
|
||||||
" \"175b_1layer\": _gen_opt_cfg(1, 12288, 96),\n",
|
|
||||||
"}"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"id": "5b9ebbec",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"def greedy_sample_one(model, input_ids, attention_mask=None, past_key_values=None):\n",
|
|
||||||
" bs, tgt_len = input_ids.shape\n",
|
|
||||||
" if past_key_values is not None:\n",
|
|
||||||
" _bs, _num_heads, src_len, _head_dims = past_key_values[0][0].shape\n",
|
|
||||||
" assert bs == _bs\n",
|
|
||||||
" else:\n",
|
|
||||||
" src_len = 0\n",
|
|
||||||
" if attention_mask is None:\n",
|
|
||||||
" attention_mask = torch.ones((bs, src_len + tgt_len), device=model.device)\n",
|
|
||||||
" ret = model(\n",
|
|
||||||
" input_ids=input_ids,\n",
|
|
||||||
" attention_mask=attention_mask,\n",
|
|
||||||
" past_key_values=past_key_values,\n",
|
|
||||||
" use_cache=True, output_hidden_states=False, return_dict=True,\n",
|
|
||||||
" )\n",
|
|
||||||
" return ret\n",
|
|
||||||
"\n",
|
|
||||||
"def time_greedy_generate(model, input_ids, new_tokens):\n",
|
|
||||||
" ts = []\n",
|
|
||||||
" output = input_ids\n",
|
|
||||||
" past_key_values = None\n",
|
|
||||||
" cache = torch.empty(int(256e6 // 4), dtype=torch.int, device=model.device)\n",
|
|
||||||
" attention_mask = torch.ones(input_ids.shape, device=model.device) \n",
|
|
||||||
" for _ in range(new_tokens):\n",
|
|
||||||
" cache.zero_()\n",
|
|
||||||
" torch.cuda.synchronize()\n",
|
|
||||||
" st = time.perf_counter_ns()\n",
|
|
||||||
" \n",
|
|
||||||
" ret = greedy_sample_one(model, input_ids, attention_mask, past_key_values)\n",
|
|
||||||
" input_ids = torch.argmax(ret.logits[:, -1, :], axis=-1)[:, None]\n",
|
|
||||||
" output = torch.cat([output, input_ids], axis=1)\n",
|
|
||||||
" past_key_values = ret.past_key_values\n",
|
|
||||||
" attention_mask = torch.cat([attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1)\n",
|
|
||||||
" \n",
|
|
||||||
" torch.cuda.synchronize()\n",
|
|
||||||
" ed = time.perf_counter_ns()\n",
|
|
||||||
" ts.append((ed-st)/1e9)\n",
|
|
||||||
" return np.array(ts)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"id": "fc92f940",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"opt_config = optcfg[\"6.7b\"]\n",
|
|
||||||
"\n",
|
|
||||||
"torch.set_default_dtype(torch.bfloat16)\n",
|
|
||||||
"with transformers.modeling_utils.no_init_weights():\n",
|
|
||||||
" model = transformers.models.opt.OPTForCausalLM(opt_config).to(\"cuda\")\n",
|
|
||||||
"torch.set_default_dtype(torch.float32)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"id": "c19fa396",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"db = {}\n",
|
|
||||||
"input_tokens = 200\n",
|
|
||||||
"new_tokens = 500\n",
|
|
||||||
"for bs in tqdm(list(itertools.chain(range(1, 8), range(8, 16, 2), [16]))):\n",
|
|
||||||
" x = torch.randint(1000, 10000, (bs, input_tokens), device=model.device)\n",
|
|
||||||
" stack = []\n",
|
|
||||||
" for _ in range(10):\n",
|
|
||||||
" l = time_greedy_generate(model, x, new_tokens=new_tokens)\n",
|
|
||||||
" stack.append(l)\n",
|
|
||||||
" db[bs] = np.median(np.stack(stack), axis=0)\n",
|
|
||||||
" del x\n",
|
|
||||||
" torch.cuda.empty_cache()\n",
|
|
||||||
"del model\n",
|
|
||||||
"torch.cuda.empty_cache()\n",
|
|
||||||
"\n",
|
|
||||||
"with gzip.open(\"data/20230516-e2e-text-generation-batch.pkl.gz\", \"wb\") as f:\n",
|
|
||||||
" pickle.dump(db, f)"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"metadata": {
|
|
||||||
"kernelspec": {
|
|
||||||
"display_name": ".venv",
|
|
||||||
"language": "python",
|
|
||||||
"name": "python3"
|
|
||||||
},
|
|
||||||
"language_info": {
|
|
||||||
"codemirror_mode": {
|
|
||||||
"name": "ipython",
|
|
||||||
"version": 3
|
|
||||||
},
|
|
||||||
"file_extension": ".py",
|
|
||||||
"mimetype": "text/x-python",
|
|
||||||
"name": "python",
|
|
||||||
"nbconvert_exporter": "python",
|
|
||||||
"pygments_lexer": "ipython3",
|
|
||||||
"version": "3.10.12"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"nbformat": 4,
|
|
||||||
"nbformat_minor": 5
|
|
||||||
}
|
|
||||||
@@ -1,165 +0,0 @@
|
|||||||
import matplotlib.pyplot as plt
|
|
||||||
import numpy as np
|
|
||||||
import pandas as pd
|
|
||||||
|
|
||||||
# Set plot parameters
|
|
||||||
plt.rcParams["font.family"] = "Helvetica"
|
|
||||||
plt.rcParams["ytick.direction"] = "in"
|
|
||||||
plt.rcParams["hatch.linewidth"] = 1.5
|
|
||||||
plt.rcParams["font.weight"] = "bold"
|
|
||||||
plt.rcParams["axes.labelweight"] = "bold"
|
|
||||||
plt.rcParams["text.usetex"] = True
|
|
||||||
|
|
||||||
# Path settings
|
|
||||||
FIGURE_PATH = "./paper_plot/figures"
|
|
||||||
|
|
||||||
# Load accuracy data
|
|
||||||
acc_data = pd.read_csv("./paper_plot/data/acc.csv")
|
|
||||||
|
|
||||||
# Create figure with 4 subplots (one for each dataset)
|
|
||||||
fig, axs = plt.subplots(1, 4)
|
|
||||||
fig.set_size_inches(9, 2.5)
|
|
||||||
|
|
||||||
# Reduce the spacing between subplots
|
|
||||||
# plt.subplots_adjust(wspace=0.2) # Reduced from 0.3 to 0.1
|
|
||||||
|
|
||||||
# Define datasets and their columns
|
|
||||||
datasets = ["NQ", "TriviaQA", "GPQA", "HotpotQA"]
|
|
||||||
metrics = ["Exact Match", "F1"]
|
|
||||||
|
|
||||||
# Define bar settings - make bars thicker
|
|
||||||
# total_width, n = 0.9, 3 # increased total width and n for three models
|
|
||||||
# width = total_width / n
|
|
||||||
# The 'width' variable below now defines the distance between the centers of adjacent bars within a group.
|
|
||||||
# It's also used as the base for calculating the actual plotted bar width.
|
|
||||||
# Original 2 bars had centers 1.0 apart. For 3 bars, we need a smaller distance.
|
|
||||||
# A value of 0.64 for distance between centers, with a scaling factor of 0.8 for bar width,
|
|
||||||
# results in an actual bar width of ~0.51, and a group span of ~1.79, similar to original's ~1.76.
|
|
||||||
n = 3 # Number of models
|
|
||||||
width = 0.64 # Distance between centers of adjacent bars in a group
|
|
||||||
bar_width_plotting_factor = 0.8 # Bar takes 80% of the space defined by 'width'
|
|
||||||
|
|
||||||
# Colors and hatches
|
|
||||||
edgecolors = ["dimgrey", "#63B8B6", "tomato"] # Added color for PQ 5
|
|
||||||
hatches = ["/////", "xxxxx", "\\\\\\\\\\"] # Added hatch for PQ 5
|
|
||||||
labels = ["BM25", "PQ Compressed", "Ours"] # Added PQ 5
|
|
||||||
|
|
||||||
# Create plots for each dataset
|
|
||||||
for i, dataset in enumerate(datasets):
|
|
||||||
ax = axs[i]
|
|
||||||
|
|
||||||
# Get data for this dataset and convert to percentages
|
|
||||||
em_values = [
|
|
||||||
acc_data.loc[0, f"{dataset} Exact Match"] * 100,
|
|
||||||
acc_data.loc[1, f"{dataset} Exact Match"] * 100,
|
|
||||||
acc_data.loc[2, f"{dataset} Exact Match"] * 100 # Added PQ 5 EM data
|
|
||||||
]
|
|
||||||
f1_values = [
|
|
||||||
acc_data.loc[0, f"{dataset} F1"] * 100,
|
|
||||||
acc_data.loc[1, f"{dataset} F1"] * 100,
|
|
||||||
acc_data.loc[2, f"{dataset} F1"] * 100 # Added PQ 5 F1 data
|
|
||||||
]
|
|
||||||
|
|
||||||
# Define x positions for bars
|
|
||||||
# For EM: center - width, center, center + width
|
|
||||||
# For F1: center - width, center, center + width
|
|
||||||
group_centers = [1.0, 3.0] # Centers for EM and F1 groups
|
|
||||||
bar_offsets = [-width, 0, width]
|
|
||||||
|
|
||||||
# Plot all bars on the same axis
|
|
||||||
for metric_idx, metric_group_center in enumerate(group_centers):
|
|
||||||
values_to_plot = em_values if metric_idx == 0 else f1_values
|
|
||||||
for j, model_label in enumerate(labels):
|
|
||||||
x_pos = metric_group_center + bar_offsets[j]
|
|
||||||
bar_value = values_to_plot[j]
|
|
||||||
|
|
||||||
ax.bar(
|
|
||||||
x_pos,
|
|
||||||
bar_value,
|
|
||||||
width=width * bar_width_plotting_factor, # Use the new factor for bar width
|
|
||||||
color="white",
|
|
||||||
edgecolor=edgecolors[j],
|
|
||||||
hatch=hatches[j],
|
|
||||||
linewidth=1.5,
|
|
||||||
label=model_label if i == 0 and metric_idx == 0 else None # Label only once
|
|
||||||
)
|
|
||||||
|
|
||||||
# Add value on top of bar
|
|
||||||
ax.text(x_pos, bar_value + (0.1 if dataset == "GPQA" else 0.1),
|
|
||||||
f"{bar_value:.1f}", ha='center', va='bottom',
|
|
||||||
fontsize=9, fontweight='bold') # Reduced fontsize for text on bars
|
|
||||||
|
|
||||||
# Set x-ticks and labels
|
|
||||||
ax.set_xticks(group_centers) # Position ticks at the center of each group
|
|
||||||
xticklabels = ax.set_xticklabels(metrics, fontsize=12)
|
|
||||||
|
|
||||||
# Now, shift these labels slightly to the right
|
|
||||||
# Adjust this value to control the amount of shift (in data coordinates)
|
|
||||||
# Given your group_centers are 1.0 and 3.0, a small value like 0.05 to 0.15 might be appropriate.
|
|
||||||
# horizontal_shift = 0.7 # Try adjusting this value
|
|
||||||
|
|
||||||
# for label in xticklabels:
|
|
||||||
# # Get the current x position (which is the tick location)
|
|
||||||
# current_x_pos = label.get_position()[0]
|
|
||||||
# # Set the new x position by adding the shift
|
|
||||||
# label.set_position((current_x_pos + horizontal_shift, label.get_position()[1]))
|
|
||||||
# # Ensure the label remains horizontally centered on this new x position
|
|
||||||
# # (set_xticklabels defaults to 'center', so this re-affirms it if needed)
|
|
||||||
# label.set_horizontalalignment('center')
|
|
||||||
|
|
||||||
# Set title
|
|
||||||
ax.set_title(dataset, fontsize=14)
|
|
||||||
|
|
||||||
# Set y-label for all subplots
|
|
||||||
if i == 0:
|
|
||||||
ax.set_ylabel("Accuracy (\%)", fontsize=12, fontweight="bold")
|
|
||||||
else:
|
|
||||||
# Hide y-tick labels for non-first subplots to save space
|
|
||||||
ax.tick_params(axis='y', labelsize=10)
|
|
||||||
|
|
||||||
# Set y-limits based on data range
|
|
||||||
all_values = em_values + f1_values
|
|
||||||
max_val = max(all_values)
|
|
||||||
min_val = min(all_values)
|
|
||||||
|
|
||||||
# Special handling for GPQA which has very low values
|
|
||||||
if dataset == "GPQA":
|
|
||||||
ax.set_ylim(0, 10.0) # Set a fixed range for GPQA
|
|
||||||
else:
|
|
||||||
# Reduce the extra space above the bars
|
|
||||||
ax.set_ylim(min_val * 0.9, max_val * 1.1) # Adjusted upper limit for text
|
|
||||||
|
|
||||||
# Format y-ticks as percentages
|
|
||||||
ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda y, _: ' {:.0f}'.format(y)))
|
|
||||||
|
|
||||||
# Set x-limits to properly space the bars with less blank space
|
|
||||||
# ax.set_xlim(group_centers[0] - total_width, group_centers[1] + total_width)
|
|
||||||
# Set xlim to be similar to original (0,4) for group_centers (1,3) => margin of 1.0
|
|
||||||
ax.set_xlim(group_centers[0] - 1.0, group_centers[1] + 1.0)
|
|
||||||
|
|
||||||
# Add a box around the subplot
|
|
||||||
# for spine in ax.spines.values():
|
|
||||||
# spine.set_visible(True)
|
|
||||||
# spine.set_linewidth(1.0)
|
|
||||||
|
|
||||||
# Add legend to first subplot
|
|
||||||
if i == 0:
|
|
||||||
ax.legend(
|
|
||||||
bbox_to_anchor=(2.21, 1.35), # Adjusted anchor if needed
|
|
||||||
ncol=3, # Changed to 3 columns for three labels
|
|
||||||
loc="upper center",
|
|
||||||
labelspacing=0.1,
|
|
||||||
edgecolor="black",
|
|
||||||
facecolor="white",
|
|
||||||
framealpha=1,
|
|
||||||
shadow=False,
|
|
||||||
fancybox=False,
|
|
||||||
handlelength=1.0,
|
|
||||||
handletextpad=0.6,
|
|
||||||
columnspacing=0.8,
|
|
||||||
prop={"weight": "bold", "size": 12},
|
|
||||||
)
|
|
||||||
|
|
||||||
# Save figure with tight layout but no additional padding
|
|
||||||
plt.savefig(FIGURE_PATH + "/accuracy_comparison.pdf", bbox_inches='tight', pad_inches=0.05)
|
|
||||||
plt.show()
|
|
||||||
@@ -1,309 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
# -*- coding:utf-8 -*-
|
|
||||||
# Motto: Were It to Benefit My Country, I Would Lay Down My Life!
|
|
||||||
# \file: /hnsw_degree_visit_plot_binned_academic.py
|
|
||||||
# \brief: Generates a binned bar plot of HNSW node average per-query visit probability
|
|
||||||
# per degree bin, styled for academic publications, with caching.
|
|
||||||
# Author: raphael hao (Original script by user, styling and caching adapted by Gemini)
|
|
||||||
|
|
||||||
# %%
|
|
||||||
import numpy as np
|
|
||||||
import pandas as pd
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
import re
|
|
||||||
from collections import Counter
|
|
||||||
import os # For robust filepath manipulation
|
|
||||||
import math # For calculating scaling factor
|
|
||||||
import pickle # For caching data
|
|
||||||
|
|
||||||
# %%
|
|
||||||
# --- Matplotlib parameters for academic paper style (from reference) ---
|
|
||||||
plt.rcParams["font.family"] = "Helvetica"
|
|
||||||
plt.rcParams["ytick.direction"] = "in"
|
|
||||||
plt.rcParams["hatch.linewidth"] = 1.5
|
|
||||||
plt.rcParams["font.weight"] = "bold"
|
|
||||||
plt.rcParams["axes.labelweight"] = "bold"
|
|
||||||
plt.rcParams["text.usetex"] = True # Use LaTeX for text rendering (if available)
|
|
||||||
|
|
||||||
# --- Define styles from reference ---
|
|
||||||
edgecolors_ref = ["dimgrey", "#63B8B6", "tomato", "silver", "slategray"]
|
|
||||||
|
|
||||||
# %%
|
|
||||||
# --- File Paths ---
|
|
||||||
degree_file = '/opt/dlami/nvme/scaling_out/indices/rpj_wiki/facebook/contriever-msmarco/hnsw/degree_distribution.txt'
|
|
||||||
visit_log_file = './re.log'
|
|
||||||
output_image_file = './paper_plot/figures/hnsw_visit_count_per_degree_corrected.pdf'
|
|
||||||
# --- CACHE FILE PATH: Keep this consistent ---
|
|
||||||
CACHE_FILE_PATH = './binned_plot_data_cache.pkl'
|
|
||||||
|
|
||||||
# --- Configuration ---
|
|
||||||
# Set to True to bypass cache and force recomputation.
|
|
||||||
# Otherwise, delete CACHE_FILE_PATH manually to force recomputation.
|
|
||||||
FORCE_RECOMPUTE = False
|
|
||||||
NUMBER_OF_QUERIES = 1000.0 # Number of queries the visit_counts are based on
|
|
||||||
|
|
||||||
# Create directory for figures if it doesn't exist
|
|
||||||
output_dir = os.path.dirname(output_image_file)
|
|
||||||
if output_dir and not os.path.exists(output_dir):
|
|
||||||
os.makedirs(output_dir)
|
|
||||||
print(f"Created directory: {output_dir}")
|
|
||||||
|
|
||||||
# %%
|
|
||||||
# --- Attempt to load data from cache or compute ---
|
|
||||||
df_plot_data = None
|
|
||||||
bin_size_for_plot = None # Will hold the bin_size associated with df_plot_data
|
|
||||||
|
|
||||||
if not FORCE_RECOMPUTE and os.path.exists(CACHE_FILE_PATH):
|
|
||||||
try:
|
|
||||||
with open(CACHE_FILE_PATH, 'rb') as f:
|
|
||||||
cache_content = pickle.load(f)
|
|
||||||
df_plot_data = cache_content['data']
|
|
||||||
bin_size_for_plot = cache_content['bin_size']
|
|
||||||
# Basic validation of cached data
|
|
||||||
# Expecting 'average_visit_count_per_node_in_bin' (raw average over NUMBER_OF_QUERIES)
|
|
||||||
if not isinstance(df_plot_data, pd.DataFrame) or \
|
|
||||||
'degree_bin_label' not in df_plot_data.columns or \
|
|
||||||
'average_visit_count_per_node_in_bin' not in df_plot_data.columns or \
|
|
||||||
not isinstance(bin_size_for_plot, int):
|
|
||||||
print("Cached data is not in the expected format or missing 'average_visit_count_per_node_in_bin'. Recomputing.")
|
|
||||||
df_plot_data = None # Invalidate to trigger recomputation
|
|
||||||
else:
|
|
||||||
print(f"Successfully loaded binned data from cache: {CACHE_FILE_PATH}")
|
|
||||||
|
|
||||||
# --- Modify the label loaded from cache for display purpose ---
|
|
||||||
# This modification only happens when data is loaded from cache and meets specific conditions.
|
|
||||||
# Assumption: If the bin_size_for_plot in cache is 5,
|
|
||||||
# then the original label "0-4" actually represents nodes with degree 1-4 (because you guarantee no 0-degree nodes).
|
|
||||||
if df_plot_data is not None and 'degree_bin_label' in df_plot_data.columns and bin_size_for_plot == 5:
|
|
||||||
# Check if "0-4" label exists
|
|
||||||
if '0-4' in df_plot_data['degree_bin_label'].values:
|
|
||||||
# Use .loc to ensure the modification is on the original DataFrame
|
|
||||||
df_plot_data.loc[df_plot_data['degree_bin_label'] == '0-4', 'degree_bin_label'] = '1-4'
|
|
||||||
print("Modified degree_bin_label from '0-4' to '1-4' for display purpose.")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error loading from cache: {e}. Recomputing.")
|
|
||||||
df_plot_data = None # Invalidate to trigger recomputation
|
|
||||||
|
|
||||||
if df_plot_data is None:
|
|
||||||
print("Cache not found, invalid, or recompute forced. Computing data from scratch...")
|
|
||||||
# --- 1. Read Degree Distribution File ---
|
|
||||||
degrees_data = []
|
|
||||||
try:
|
|
||||||
with open(degree_file, 'r') as f:
|
|
||||||
for i, line in enumerate(f):
|
|
||||||
line_stripped = line.strip()
|
|
||||||
if line_stripped:
|
|
||||||
degrees_data.append({'node_id': i, 'degree': int(line_stripped)})
|
|
||||||
except FileNotFoundError:
|
|
||||||
print(f"Error: Degree file '{degree_file}' not found. Using dummy data for degrees.")
|
|
||||||
degrees_data = [{'node_id': i, 'degree': (i % 20) + 1 } for i in range(200)]
|
|
||||||
degrees_data.extend([{'node_id': 200+i, 'degree': i} for i in range(58, 67)]) # For 60-64 bin
|
|
||||||
degrees_data.extend([{'node_id': 300+i, 'degree': (i % 5)+1} for i in range(10)]) # Low degrees
|
|
||||||
degrees_data.extend([{'node_id': 400+i, 'degree': 80 + (i%5)} for i in range(10)]) # High degrees
|
|
||||||
|
|
||||||
|
|
||||||
if not degrees_data:
|
|
||||||
print(f"Critical Error: No data loaded or generated for degrees. Exiting.")
|
|
||||||
exit()
|
|
||||||
df_degrees = pd.DataFrame(degrees_data)
|
|
||||||
print(f"Successfully loaded/generated {len(df_degrees)} degree entries.")
|
|
||||||
|
|
||||||
# --- 2. Read Visit Log File and Count Frequencies ---
|
|
||||||
visit_counts = Counter()
|
|
||||||
node_id_pattern = re.compile(r"Vis(i)?ted node: (\d+)")
|
|
||||||
try:
|
|
||||||
with open(visit_log_file, 'r') as f_log:
|
|
||||||
for line_num, line in enumerate(f_log, 1):
|
|
||||||
match = node_id_pattern.search(line)
|
|
||||||
if match:
|
|
||||||
try:
|
|
||||||
node_id = int(match.group(2))
|
|
||||||
visit_counts[node_id] += 1 # Increment visit count for the node
|
|
||||||
except ValueError:
|
|
||||||
print(f"Warning: Non-integer node_id in log '{visit_log_file}' line {line_num}: {line.strip()}")
|
|
||||||
except FileNotFoundError:
|
|
||||||
print(f"Warning: Visit log file '{visit_log_file}' not found. Using dummy visit counts.")
|
|
||||||
if not df_degrees.empty:
|
|
||||||
for node_id_val in df_degrees['node_id'].sample(frac=0.9, random_state=1234): # Seed for reproducibility
|
|
||||||
degree_val = df_degrees[df_degrees['node_id'] == node_id_val]['degree'].iloc[0]
|
|
||||||
# Generate visit counts to test different probability magnitudes
|
|
||||||
if node_id_val % 23 == 0: # Very low probability
|
|
||||||
lambda_val = 0.0005 * (100 / (max(1,degree_val) + 1)) # avg visits over 1k queries
|
|
||||||
elif node_id_val % 11 == 0: # Low probability
|
|
||||||
lambda_val = 0.05 * (100 / (max(1,degree_val) + 1))
|
|
||||||
elif node_id_val % 5 == 0: # Moderate probability
|
|
||||||
lambda_val = 2.5 * (100 / (max(1,degree_val) + 1))
|
|
||||||
else: # Higher probability (but still < 1000 visits for a single node usually)
|
|
||||||
lambda_val = 50 * (100 / (max(1,degree_val) + 1))
|
|
||||||
visit_counts[node_id_val] = np.random.poisson(lambda_val)
|
|
||||||
if visit_counts[node_id_val] < 0: visit_counts[node_id_val] = 0
|
|
||||||
|
|
||||||
if not visit_counts:
|
|
||||||
print(f"Warning: No visit data parsed/generated. Plot may show zero visits.")
|
|
||||||
df_visits = pd.DataFrame(columns=['node_id', 'visit_count'])
|
|
||||||
else:
|
|
||||||
df_visits_list = [{'node_id': nid, 'visit_count': count} for nid, count in visit_counts.items()]
|
|
||||||
df_visits = pd.DataFrame(df_visits_list)
|
|
||||||
print(f"Parsed/generated {len(df_visits)} unique visited nodes, totaling {sum(visit_counts.values())} visits (simulated over {NUMBER_OF_QUERIES} queries).")
|
|
||||||
|
|
||||||
# --- 3. Merge Degree Data with Visit Data ---
|
|
||||||
df_merged = pd.merge(df_degrees, df_visits, on='node_id', how='left')
|
|
||||||
df_merged['visit_count'] = df_merged['visit_count'].fillna(0).astype(float) # visit_count is total over NUMBER_OF_QUERIES
|
|
||||||
print(f"Merged data contains {len(df_merged)} entries.")
|
|
||||||
|
|
||||||
# --- 5. Binning Degrees and Calculating Average Visit Count per Node in Bin (over NUMBER_OF_QUERIES) ---
|
|
||||||
current_bin_size = 5
|
|
||||||
bin_size_for_plot = current_bin_size
|
|
||||||
|
|
||||||
if not df_degrees.empty:
|
|
||||||
print(f"\nBinning degrees into groups of {current_bin_size} for average visit count calculation...")
|
|
||||||
|
|
||||||
df_merged_with_bins = df_merged.copy()
|
|
||||||
df_merged_with_bins['degree_bin_start'] = (df_merged_with_bins['degree'] // current_bin_size) * current_bin_size
|
|
||||||
|
|
||||||
df_binned_analysis = df_merged_with_bins.groupby('degree_bin_start').agg(
|
|
||||||
total_visit_count_in_bin=('visit_count', 'sum'),
|
|
||||||
node_count_in_bin=('node_id', 'nunique')
|
|
||||||
).reset_index()
|
|
||||||
|
|
||||||
# This is the average number of times a node in this bin was visited over NUMBER_OF_QUERIES queries.
|
|
||||||
# This value is what gets cached.
|
|
||||||
df_binned_analysis['average_visit_count_per_node_in_bin'] = 0.0
|
|
||||||
df_binned_analysis.loc[df_binned_analysis['node_count_in_bin'] > 0, 'average_visit_count_per_node_in_bin'] = \
|
|
||||||
df_binned_analysis['total_visit_count_in_bin'] / df_binned_analysis['node_count_in_bin']
|
|
||||||
|
|
||||||
df_binned_analysis['degree_bin_label'] = df_binned_analysis['degree_bin_start'].astype(str) + '-' + \
|
|
||||||
(df_binned_analysis['degree_bin_start'] + current_bin_size - 1).astype(str)
|
|
||||||
|
|
||||||
bin_to_drop_label = '60-64'
|
|
||||||
original_length = len(df_binned_analysis)
|
|
||||||
df_plot_data_intermediate = df_binned_analysis[df_binned_analysis['degree_bin_label'] != bin_to_drop_label].copy()
|
|
||||||
if len(df_plot_data_intermediate) < original_length:
|
|
||||||
print(f"\nManually dropped the bin: '{bin_to_drop_label}'")
|
|
||||||
else:
|
|
||||||
print(f"\nNote: Bin '{bin_to_drop_label}' not found for dropping or already removed.")
|
|
||||||
|
|
||||||
df_plot_data = df_plot_data_intermediate
|
|
||||||
|
|
||||||
print(f"\nBinned data (average visit count per node in bin over {NUMBER_OF_QUERIES} queries) for plotting prepared:")
|
|
||||||
print(df_plot_data[['degree_bin_label', 'average_visit_count_per_node_in_bin']].head())
|
|
||||||
|
|
||||||
if df_plot_data is not None and not df_plot_data.empty:
|
|
||||||
try:
|
|
||||||
with open(CACHE_FILE_PATH, 'wb') as f:
|
|
||||||
pickle.dump({'data': df_plot_data, 'bin_size': bin_size_for_plot}, f)
|
|
||||||
print(f"Saved computed binned data to cache: {CACHE_FILE_PATH}")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error saving data to cache: {e}")
|
|
||||||
elif df_plot_data is None or df_plot_data.empty:
|
|
||||||
print("Computed data for binned plot is empty, not saving to cache.")
|
|
||||||
else:
|
|
||||||
print("Degree data (df_degrees) is empty. Cannot perform binning.")
|
|
||||||
df_plot_data = pd.DataFrame()
|
|
||||||
bin_size_for_plot = current_bin_size
|
|
||||||
|
|
||||||
# %%
|
|
||||||
# --- 6. Plotting (Binned Bar Chart - Academic Style) ---
|
|
||||||
|
|
||||||
if df_plot_data is not None and not df_plot_data.empty and 'average_visit_count_per_node_in_bin' in df_plot_data.columns:
|
|
||||||
base_name, ext = os.path.splitext(output_image_file)
|
|
||||||
# --- OUTPUT PDF FILE NAME: Keep this consistent ---
|
|
||||||
binned_output_image_file = base_name + ext
|
|
||||||
|
|
||||||
fig, ax = plt.subplots(figsize=(6, 2.5)) # Adjusted figure size
|
|
||||||
|
|
||||||
df_plot_data_plotting = df_plot_data.copy()
|
|
||||||
# Calculate per-query probability: (avg visits over N queries) / N
|
|
||||||
df_plot_data_plotting['per_query_visit_probability'] = \
|
|
||||||
df_plot_data_plotting['average_visit_count_per_node_in_bin'] / NUMBER_OF_QUERIES
|
|
||||||
|
|
||||||
max_probability = df_plot_data_plotting['per_query_visit_probability'].max()
|
|
||||||
|
|
||||||
y_axis_values_to_plot = df_plot_data_plotting['per_query_visit_probability']
|
|
||||||
y_axis_label = r"Per-Query Node Visit Probability in Bin" # Base label
|
|
||||||
|
|
||||||
apply_scaling_to_label_and_values = False # Initialize flag
|
|
||||||
exponent_for_label_display = 0 # Initialize exponent
|
|
||||||
|
|
||||||
if pd.notna(max_probability) and max_probability > 0:
|
|
||||||
potential_exponent = math.floor(math.log10(max_probability))
|
|
||||||
|
|
||||||
if potential_exponent <= -4 or potential_exponent >= 0:
|
|
||||||
apply_scaling_to_label_and_values = True
|
|
||||||
exponent_for_label_display = potential_exponent
|
|
||||||
# No specific adjustment for potential_exponent >=0 here, it's handled by the general logic.
|
|
||||||
|
|
||||||
if apply_scaling_to_label_and_values:
|
|
||||||
y_axis_label = rf"Visit Probability ($\times 10^{{{exponent_for_label_display}}}$)"
|
|
||||||
y_axis_values_to_plot = df_plot_data_plotting['per_query_visit_probability'] / (10**exponent_for_label_display)
|
|
||||||
print(f"Plotting with Max per-query probability: {max_probability:.2e}, Exponent for label: {exponent_for_label_display}. Y-axis values scaled for plot.")
|
|
||||||
else:
|
|
||||||
print(f"Plotting with Max per-query probability: {max_probability:.2e}. Plotting direct probabilities without label scaling (exponent {potential_exponent} is within no-scale range [-3, -1]).")
|
|
||||||
|
|
||||||
elif pd.notna(max_probability) and max_probability == 0:
|
|
||||||
print("Max per-query probability is 0. Plotting direct probabilities (all zeros).")
|
|
||||||
else:
|
|
||||||
print(f"Max per-query probability is NaN or invalid ({max_probability}). Plotting direct probabilities without scaling if possible.")
|
|
||||||
|
|
||||||
ax.bar(
|
|
||||||
df_plot_data_plotting['degree_bin_label'],
|
|
||||||
y_axis_values_to_plot,
|
|
||||||
color='white',
|
|
||||||
edgecolor=edgecolors_ref[0],
|
|
||||||
linewidth=1.5,
|
|
||||||
width=0.8
|
|
||||||
)
|
|
||||||
|
|
||||||
ax.set_xlabel('Node Degree', fontsize=10.5, labelpad=6)
|
|
||||||
# MODIFIED LINE: Added labelpad to move the y-axis label to the left
|
|
||||||
ax.set_ylabel(y_axis_label, fontsize=10.5, labelpad=10)
|
|
||||||
|
|
||||||
ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda x, pos: f"{x:.0f}%"))
|
|
||||||
|
|
||||||
num_bins = len(df_plot_data_plotting)
|
|
||||||
if num_bins > 12:
|
|
||||||
ax.set_xticks(ax.get_xticks())
|
|
||||||
ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha="right", fontsize=9)
|
|
||||||
elif num_bins > 8:
|
|
||||||
ax.tick_params(axis='x', labelsize=9)
|
|
||||||
else:
|
|
||||||
ax.tick_params(axis='x', labelsize=10)
|
|
||||||
|
|
||||||
ax.tick_params(axis='y', labelsize=10)
|
|
||||||
|
|
||||||
padding_factor = 0.05
|
|
||||||
current_max_y_on_axis = y_axis_values_to_plot.max()
|
|
||||||
|
|
||||||
upper_y_limit = 0.1 # Default small upper limit
|
|
||||||
if pd.notna(current_max_y_on_axis):
|
|
||||||
if current_max_y_on_axis > 0:
|
|
||||||
# Adjust minimum visible range based on whether scaling was applied and the exponent
|
|
||||||
min_meaningful_limit = 0.01
|
|
||||||
if apply_scaling_to_label_and_values and exponent_for_label_display >= 0 : # Numbers on axis are smaller due to positive exponent scaling
|
|
||||||
min_meaningful_limit = 0.1 # If original numbers were e.g. 2500 (2.5 x 10^3), scaled axis is 2.5, 0.1 is fine
|
|
||||||
elif not apply_scaling_to_label_and_values and pd.notna(max_probability) and max_probability >=1: # Direct large probabilities
|
|
||||||
min_meaningful_limit = 1 # If max prob is 2.5 (250%), axis value 2.5, needs larger base limit
|
|
||||||
|
|
||||||
upper_y_limit = max(min_meaningful_limit, current_max_y_on_axis * (1 + padding_factor))
|
|
||||||
|
|
||||||
else: # current_max_y_on_axis is 0
|
|
||||||
upper_y_limit = 0.1
|
|
||||||
ax.set_ylim(0, upper_y_limit)
|
|
||||||
else:
|
|
||||||
ax.set_ylim(0, 1.0) # Default for empty or NaN data
|
|
||||||
|
|
||||||
plt.tight_layout()
|
|
||||||
plt.savefig(binned_output_image_file, bbox_inches="tight", dpi=300)
|
|
||||||
print(f"Binned bar chart saved to {binned_output_image_file}")
|
|
||||||
plt.show()
|
|
||||||
plt.close(fig)
|
|
||||||
else:
|
|
||||||
if df_plot_data is None:
|
|
||||||
print("Data for plotting (df_plot_data) is None. Skipping plot generation.")
|
|
||||||
elif df_plot_data.empty:
|
|
||||||
print("Data for plotting (df_plot_data) is empty. Skipping plot generation.")
|
|
||||||
elif 'average_visit_count_per_node_in_bin' not in df_plot_data.columns:
|
|
||||||
print("Essential column 'average_visit_count_per_node_in_bin' is missing in df_plot_data. Skipping plot generation.")
|
|
||||||
|
|
||||||
# %%
|
|
||||||
print("Script finished.")
|
|
||||||
@@ -1,7 +0,0 @@
|
|||||||
In this paper, we present LiteANN, a storage-efficient approximate nearest neighbor (ANN) search index optimized for resource-constrained personal devices. LiteANN combines a compact graph-based structure with an efficient on-the-fly recomputation strategy to enable fast and accurate retrieval wih minimal storage overhead. Our evaluation shows that LiteANN reduces index size to under 5% of the original raw data – up to 50× smaller than standard indexes – while achieving 90% top-3 recall in under 2 seconds on real-world question-answering benchmarks.
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@@ -1,81 +0,0 @@
|
|||||||
import numpy as np
|
|
||||||
import os
|
|
||||||
|
|
||||||
# --- Configuration for Data Paths and Labels (Mirrors plotting script for consistency) ---
|
|
||||||
BIG_GRAPH_PATHS = [
|
|
||||||
"/opt/dlami/nvme/scaling_out/indices/rpj_wiki/facebook/contriever-msmarco/hnsw/",
|
|
||||||
"/opt/dlami/nvme/scaling_out/embeddings/facebook/contriever-msmarco/rpj_wiki/1-shards/indices/99_4_degree_based_hnsw_IP_M32_efC256/",
|
|
||||||
"/opt/dlami/nvme/scaling_out/embeddings/facebook/contriever-msmarco/rpj_wiki/1-shards/indices/d9_hnsw_IP_M8_efC128/",
|
|
||||||
"/opt/dlami/nvme/scaling_out/embeddings/facebook/contriever-msmarco/rpj_wiki/1-shards/indices/half_edges_IP_M32_efC128/"
|
|
||||||
]
|
|
||||||
STATS_FILE_NAME = "degree_distribution.txt"
|
|
||||||
BIG_GRAPH_LABELS = [ # These will be used as keys in the cached file
|
|
||||||
"HNSW-Base",
|
|
||||||
"DegreeGuide",
|
|
||||||
"HNSW-D9",
|
|
||||||
"RandCut",
|
|
||||||
]
|
|
||||||
# Average degrees are static and can be directly used in the plotting script or also cached.
|
|
||||||
# For simplicity here, we'll focus on caching the dynamic degree arrays.
|
|
||||||
# BIG_GRAPH_AVG_DEG = [18, 9, 9, 9]
|
|
||||||
|
|
||||||
# --- Cache File Configuration ---
|
|
||||||
DATA_CACHE_DIR = "./paper_plot/data/"
|
|
||||||
CACHE_FILE_NAME = "big_graph_degree_data.npz" # Using .npz for multiple arrays
|
|
||||||
|
|
||||||
def create_degree_data_cache():
|
|
||||||
"""
|
|
||||||
Reads degree distribution data from specified text files and saves it
|
|
||||||
into a compressed NumPy (.npz) cache file.
|
|
||||||
"""
|
|
||||||
os.makedirs(DATA_CACHE_DIR, exist_ok=True)
|
|
||||||
cache_file_path = os.path.join(DATA_CACHE_DIR, CACHE_FILE_NAME)
|
|
||||||
|
|
||||||
cached_data = {}
|
|
||||||
print(f"Starting data caching process for {len(BIG_GRAPH_PATHS)} graph types...")
|
|
||||||
|
|
||||||
for i, base_path in enumerate(BIG_GRAPH_PATHS):
|
|
||||||
method_label = BIG_GRAPH_LABELS[i]
|
|
||||||
degree_file_path = os.path.join(base_path, STATS_FILE_NAME)
|
|
||||||
|
|
||||||
print(f"Processing: {method_label} from {degree_file_path}")
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Load degrees as integers
|
|
||||||
degrees = np.loadtxt(degree_file_path, dtype=int)
|
|
||||||
|
|
||||||
if degrees.size == 0:
|
|
||||||
print(f" [WARN] Degree file is empty: {degree_file_path}. Storing as empty array for {method_label}.")
|
|
||||||
# Store an empty array or handle as needed. For npz, an empty array is fine.
|
|
||||||
cached_data[method_label] = np.array([], dtype=int)
|
|
||||||
else:
|
|
||||||
# Store the loaded degrees array with the method label as the key
|
|
||||||
cached_data[method_label] = degrees
|
|
||||||
print(f" [INFO] Loaded {len(degrees)} degrees for {method_label}. Max degree: {np.max(degrees) if degrees.size > 0 else 'N/A'}")
|
|
||||||
|
|
||||||
except FileNotFoundError:
|
|
||||||
print(f" [ERROR] Degree file not found: {degree_file_path}. Skipping {method_label}.")
|
|
||||||
# Optionally store a placeholder or skip. For robustness, store None or an empty array.
|
|
||||||
# Storing None might require special handling when loading. Empty array is safer for np.load.
|
|
||||||
cached_data[method_label] = np.array([], dtype=int) # Store empty array if file not found
|
|
||||||
except Exception as e:
|
|
||||||
print(f" [ERROR] An error occurred loading {degree_file_path} for {method_label}: {e}")
|
|
||||||
cached_data[method_label] = np.array([], dtype=int) # Store empty array on other errors
|
|
||||||
|
|
||||||
if not cached_data:
|
|
||||||
print("[ERROR] No data was successfully processed or loaded. Cache file will not be created.")
|
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Save all collected degree arrays into a single .npz file.
|
|
||||||
# Using savez_compressed for potentially smaller file size.
|
|
||||||
np.savez_compressed(cache_file_path, **cached_data)
|
|
||||||
print(f"\n[SUCCESS] Degree distribution data successfully cached to: {os.path.abspath(cache_file_path)}")
|
|
||||||
print("Cached arrays (keys):", list(cached_data.keys()))
|
|
||||||
except Exception as e:
|
|
||||||
print(f"\n[ERROR] Failed to save data to cache file {cache_file_path}: {e}")
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
print("--- Degree Distribution Data Caching Script ---")
|
|
||||||
create_degree_data_cache()
|
|
||||||
print("--- Caching script finished. ---")
|
|
||||||
@@ -1,4 +0,0 @@
|
|||||||
Model,NQ Exact Match,NQ F1,TriviaQA Exact Match,TriviaQA F1,GPQA Exact Match,GPQA F1,HotpotQA Exact Match,HotpotQA F1
|
|
||||||
BM25,0.192,0.277,0.406,0.474,0.020089,0.04524,0.162,0.239
|
|
||||||
PQ 5,0.2075,0.291,0.422,0.495,0.0201,0.0445,0.148,0.219
|
|
||||||
Ours,0.265,0.361,0.533,0.604,0.02008,0.0452,0.182,0.2729
|
|
||||||
|
@@ -1,3 +0,0 @@
|
|||||||
version https://git-lfs.github.com/spec/v1
|
|
||||||
oid sha256:1296720e79196bbdf38f051043c1b054667803726a24036c0b6a87cedb204ea5
|
|
||||||
size 227482438
|
|
||||||
@@ -1,21 +0,0 @@
|
|||||||
2,1,512,1024,0.541,0.326,1.659509202
|
|
||||||
2,2,512,1024,0.979,0.621,1.576489533
|
|
||||||
2,4,512,1024,1.846,0.977,1.889457523
|
|
||||||
2,8,512,1024,3.575,1.943,1.83993824
|
|
||||||
2,16,512,1024,7.035,3.733,1.884543263
|
|
||||||
2,32,512,1024,15.655,8.517,1.838088529
|
|
||||||
2,64,512,1024,32.772,17.43,1.88020654
|
|
||||||
4,1,512,1024,2.675,1.38,1.938405797
|
|
||||||
4,2,512,1024,5.397,2.339,2.307396323
|
|
||||||
4,4,512,1024,10.672,4.944,2.158576052
|
|
||||||
4,8,512,1024,21.061,9.266,2.272933305
|
|
||||||
4,16,512,1024,46.332,18.334,2.527108105
|
|
||||||
4,32,512,1024,99.607,36.156,2.754923111
|
|
||||||
4,64,512,1024,186.348,72.356,2.575432583
|
|
||||||
8,1,512,1024,7.325,4.087,1.792268167
|
|
||||||
8,2,512,1024,14.109,7.491,1.883460152
|
|
||||||
8,4,512,1024,28.499,14.013,2.033754371
|
|
||||||
8,8,512,1024,65.222,27.453,2.375769497
|
|
||||||
8,16,512,1024,146.294,52.55,2.783901047
|
|
||||||
8,32,512,1024,277.099,103.61,2.674442621
|
|
||||||
8,64,512,1024,512.979,208.36,2.461984066
|
|
||||||
|
@@ -1,9 +0,0 @@
|
|||||||
Dataset,Metric,Original,original + batch,original + two_level,original + two_level + batch
|
|
||||||
NQ,Latency,6.9,5.8,4.2,3.7
|
|
||||||
NQ,SpeedUp,1,1.18965517,1.64285714,1.86486486
|
|
||||||
TriviaQA,Latency,17.054,14.542,12.046,10.83
|
|
||||||
TriviaQA,SpeedUp,1,1.17274103,1.41573967,1.57469990
|
|
||||||
GPQA,Latency,9.164,7.639,6.798,5.77
|
|
||||||
GPQA,SpeedUp,1,1.19963346,1.34804354,1.58821490
|
|
||||||
HotpotQA,Latency,60.279,39.827,50.664,29.868
|
|
||||||
HotpotQA,SpeedUp,1,1.51352098,1.18977972,2.01817999
|
|
||||||
|
@@ -1,25 +0,0 @@
|
|||||||
Dataset,Hardware,Recall_target,HNSW,IVF,DiskANN,IVF-Disk,IVF-Recompute,Our,BM25,LLM_Gen_Time_1B,LLM_Gen_Time_3B,LLM_Gen_Time_7B
|
|
||||||
NQ,A10,85%,0.046,1.656,0.017,2.996,482.53,3.323,0.021,0.085,0.217,0.472
|
|
||||||
NQ,A10,90%,0.051,2.552,0.028,3.437,769.04,4.616,0,0.085,0.217,0.472
|
|
||||||
NQ,A10,95%,0.055,5.163,0.070,5.602,1436.26,19.494,0,0.085,0.217,0.472
|
|
||||||
NQ,MAC,85%,0,0,0.152,2.199,1535.10,7.971,0.033,0.316,0.717,1.468
|
|
||||||
NQ,MAC,90%,0,0,0.37,2.936,2446.60,13.843,0,0.316,0.717,1.468
|
|
||||||
NQ,MAC,95%,0,0,1.207,4.191,4569.29,44.363,0,0.316,0.717,1.468
|
|
||||||
TriviaQA,A10,85%,0.042,1.772,0.032,2.464,560.5,3.752,0.033,0.139,0.156,0.315
|
|
||||||
TriviaQA,A10,90%,0.043,3.541,0.057,3.651,997.81,5.777,0,0.139,0.156,0.315
|
|
||||||
TriviaQA,A10,95%,0.053,7.168,0.090,5.458,2005.33,20.944,0,0.139,0.156,0.315
|
|
||||||
TriviaQA,MAC,85%,0,0,0.481,1.875,1783.14787,8.889,0.036,0.325,0.692,1.415
|
|
||||||
TriviaQA,MAC,90%,0,0,0.984,2.639,3174.410301,17.145,0,0.325,0.692,1.415
|
|
||||||
TriviaQA,MAC,95%,0,0,1.578,3.884,6379.712245,47.909,0,0.325,0.692,1.415
|
|
||||||
GPQA,A10,85%,0.041,0.134,0.024,0.048,40.16,1.897,0.137,0.443,0.396,0.651
|
|
||||||
GPQA,A10,90%,0.042,0.174,0.034,0.06,54.71,1.733,0,0.443,0.396,0.651
|
|
||||||
GPQA,A10,95%,0.045,0.292,0.051,0.11,97.67,4.033,0,0.443,0.396,0.651
|
|
||||||
GPQA,MAC,85%,0,0,0.144,0.087,127.7707505,4.762,0.100,0.37,0.813,1.676
|
|
||||||
GPQA,MAC,90%,0,0,0.288,0.108,174.0647409,5.223,0,0.37,0.813,1.676
|
|
||||||
GPQA,MAC,95%,0,0,0.497,0.132,310.7380142,9.715,0,0.37,0.813,1.676
|
|
||||||
HotpotQA,A10,85%,0.044,2.519,0.054,4.048,724.26,10.358,0.70,0.144,0.196,0.420
|
|
||||||
HotpotQA,A10,90%,0.049,3.867,0.109,5.045,1173.67,15.515,0,0.144,0.196,0.420
|
|
||||||
HotpotQA,A10,95%,0.07,10.928,0.412,8.659,3079.57,61.757,0,0.144,0.196,0.420
|
|
||||||
HotpotQA,MAC,85%,0,0,0.974,2.844,2304.125187,23.636,0.052,0.144,0.196,0.420
|
|
||||||
HotpotQA,MAC,90%,0,0,1.913,3.542,3415.736201,44.803,0,0.144,0.196,0.420
|
|
||||||
HotpotQA,MAC,95%,0,0,5.783,6.764,9797.244043,140.62,0,0.144,0.196,0.420
|
|
||||||
|
@@ -1,25 +0,0 @@
|
|||||||
Dataset,Hardware,Recall_target,HNSW,IVF,DiskANN,IVF-Disk,IVF-Recompute,Our,
|
|
||||||
NQ,A10,85%,0.046,1.656,0.017,2.996,482.53,4.243,
|
|
||||||
NQ,A10,90%,0.051,2.552,0.028,3.437,769.04,8.136,
|
|
||||||
NQ,A10,95%,0.055,5.163,0.070,5.602,1436.26,27.275,
|
|
||||||
NQ,MAC,85%,0,0,0.152,2.199,1535.10,10.672,
|
|
||||||
NQ,MAC,90%,0,0,0.37,2.936,2446.60,19.941,
|
|
||||||
NQ,MAC,95%,0,0,1.207,4.191,4569.29,61.383,
|
|
||||||
TriviaQA,A10,85%,0.042,1.772,0.032,2.464,560.5,5.612,
|
|
||||||
TriviaQA,A10,90%,0.043,3.541,0.057,3.651,997.81,10.737,
|
|
||||||
TriviaQA,A10,95%,0.053,7.168,0.090,5.458,2005.33,36.387,
|
|
||||||
TriviaQA,MAC,85%,0,0,0.481,1.875,1783.14787,12.825,
|
|
||||||
TriviaQA,MAC,90%,0,0,0.984,2.639,3174.410301,24.977,
|
|
||||||
TriviaQA,MAC,95%,0,0,1.578,3.884,6379.712245,85.734,
|
|
||||||
GPQA,A10,85%,0.041,0.134,0.024,0.048,40.16,2.269,
|
|
||||||
GPQA,A10,90%,0.042,0.174,0.034,0.06,54.71,3.200,
|
|
||||||
GPQA,A10,95%,0.045,0.292,0.051,0.11,97.67,7.445,
|
|
||||||
GPQA,MAC,85%,0,0,0.144,0.087,127.7707505,6.123,
|
|
||||||
GPQA,MAC,90%,0,0,0.288,0.108,174.0647409,8.507,
|
|
||||||
GPQA,MAC,95%,0,0,0.497,0.132,310.7380142,19.577,
|
|
||||||
HotpotQA,A10,85%,0.044,2.519,0.054,4.048,724.26,14.713,
|
|
||||||
HotpotQA,A10,90%,0.049,3.867,0.109,5.045,1173.67,33.561,
|
|
||||||
HotpotQA,A10,95%,0.07,10.928,0.412,8.659,3079.57,68.626,
|
|
||||||
HotpotQA,MAC,85%,0,0,0.974,2.844,2304.125187,34.783,
|
|
||||||
HotpotQA,MAC,90%,0,0,1.913,3.542,3415.736201,53.004,
|
|
||||||
HotpotQA,MAC,95%,0,0,5.783,6.764,9797.244043,95.413,
|
|
||||||
|
@@ -1,3 +0,0 @@
|
|||||||
Hardware,HNSW,IVF,DiskANN,IVF-Disk,IVF-Recompute,Our,BM25
|
|
||||||
RAM,190,171,10,0,0,0,0
|
|
||||||
Storage,185.4,171,240,171,0.5,5,59
|
|
||||||
|
@@ -1,12 +0,0 @@
|
|||||||
Torch,8,55.592
|
|
||||||
Torch,16,75.439
|
|
||||||
Torch,32,110.025
|
|
||||||
Torch,64,186.496
|
|
||||||
Tutel,8,56.718
|
|
||||||
Tutel,16,82.121
|
|
||||||
Tutel,32,125.070
|
|
||||||
Tutel,64,216.191
|
|
||||||
BRT,8,56.725
|
|
||||||
BRT,16,79.291
|
|
||||||
BRT,32,93.180
|
|
||||||
BRT,64,118.923
|
|
||||||
|
@@ -1,6 +0,0 @@
|
|||||||
Disk cache size,0,2.5%(180G*2.5%),5%,8%,10%
|
|
||||||
Latency,,,,,
|
|
||||||
NQ,4.616,4.133,3.826,3.511,3.323
|
|
||||||
TriviaQA,5.777,4.979,4.553,4.141,3.916
|
|
||||||
GPQA,1.733,1.593,1.468,1.336,1.259
|
|
||||||
Hotpot,15.515,13.479,12.383,11.216,10.606
|
|
||||||
|
@@ -1,151 +0,0 @@
|
|||||||
import matplotlib
|
|
||||||
from matplotlib.axes import Axes
|
|
||||||
import numpy as np
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
import matplotlib.patches as mpatches
|
|
||||||
from matplotlib.lines import Line2D
|
|
||||||
|
|
||||||
# plt.rcParams["font.family"] = "Helvetica"
|
|
||||||
plt.rcParams["ytick.direction"] = "in"
|
|
||||||
plt.rcParams["hatch.linewidth"] = 1
|
|
||||||
plt.rcParams["font.weight"] = "bold"
|
|
||||||
plt.rcParams["axes.labelweight"] = "bold"
|
|
||||||
plt.rcParams["text.usetex"] = True
|
|
||||||
plt.rcParams["font.family"] = "sans-serif" # Use generic sans-serif family
|
|
||||||
plt.rcParams['text.latex.preamble'] = r"""
|
|
||||||
\usepackage{helvet} % Use Helvetica font for text
|
|
||||||
\usepackage{sfmath} % Use sans-serif font for math
|
|
||||||
\renewcommand{\familydefault}{\sfdefault} % Set sans-serif as default text font
|
|
||||||
\usepackage[T1]{fontenc} % Recommended for font encoding
|
|
||||||
"""
|
|
||||||
# plt.rcParams['mathtext.fontset'] = 'dejavusans'
|
|
||||||
SAVE_PTH = "./paper_plot/figures"
|
|
||||||
font_size = 16
|
|
||||||
|
|
||||||
# New data in dictionary format
|
|
||||||
datasets = ["NQ", "TriviaQA", "GPQA", "Hotpot"]
|
|
||||||
|
|
||||||
cache_ratios = ["4.2G\n (0\%)", "8.7G\n (2.5\%)", "13.2G\n (5\%)", "18.6G\n (8\%)", "22.2G\n (10\%)"]
|
|
||||||
latency_data = {
|
|
||||||
"NQ": [4.616, 4.133, 3.826, 3.511, 3.323],
|
|
||||||
"TriviaQA": [5.777, 4.979, 4.553, 4.141, 3.916],
|
|
||||||
"GPQA": [1.733, 1.593, 1.468, 1.336, 1.259],
|
|
||||||
"Hotpot": [15.515, 13.479, 12.383, 11.216, 10.606],
|
|
||||||
}
|
|
||||||
cache_hit_counts = {
|
|
||||||
"NQ": [0, 14.81, 23.36, 31.99, 36.73],
|
|
||||||
"TriviaQA": [0, 18.55, 27.99, 37.06, 41.86],
|
|
||||||
"GPQA": [0, 10.99, 20.31, 29.71, 35.01],
|
|
||||||
"Hotpot": [0, 17.47, 26.91, 36.2, 41.06]
|
|
||||||
}
|
|
||||||
|
|
||||||
# Create the figure with 4 subplots in a 2x2 grid
|
|
||||||
fig, axes_grid = plt.subplots(2, 2, figsize=(7,6))
|
|
||||||
axes = axes_grid.flatten() # Flatten the 2x2 grid to a 1D array
|
|
||||||
|
|
||||||
# Bar style settings
|
|
||||||
width = 0.7
|
|
||||||
x = np.arange(len(cache_ratios))
|
|
||||||
|
|
||||||
# Define hatch patterns for different cache ratios
|
|
||||||
hatch_patterns = ['//', '//', '//', '//', '//']
|
|
||||||
|
|
||||||
# Find max cache hit value across all datasets for unified y-axis
|
|
||||||
all_hit_counts = []
|
|
||||||
for dataset in datasets:
|
|
||||||
all_hit_counts.extend(cache_hit_counts[dataset])
|
|
||||||
max_unified_hit = max(all_hit_counts) * 1.13
|
|
||||||
|
|
||||||
for i, dataset in enumerate(datasets):
|
|
||||||
latencies = latency_data[dataset]
|
|
||||||
hit_counts = cache_hit_counts[dataset]
|
|
||||||
|
|
||||||
for j, val in enumerate(latencies):
|
|
||||||
container = axes[i].bar(
|
|
||||||
x[j],
|
|
||||||
val,
|
|
||||||
width=width,
|
|
||||||
color="white",
|
|
||||||
edgecolor="black",
|
|
||||||
linewidth=1.0,
|
|
||||||
zorder=10,
|
|
||||||
)
|
|
||||||
axes[i].bar_label(
|
|
||||||
container,
|
|
||||||
[f"{val:.2f}"],
|
|
||||||
fontsize=10,
|
|
||||||
zorder=200,
|
|
||||||
fontweight="bold",
|
|
||||||
)
|
|
||||||
|
|
||||||
axes[i].set_title(dataset, fontsize=font_size)
|
|
||||||
axes[i].set_xticks(x)
|
|
||||||
axes[i].set_xticklabels(cache_ratios, fontsize=12, rotation=0, ha='center', fontweight="bold")
|
|
||||||
|
|
||||||
max_val_ratios = [1.35, 1.65, 1.45, 1.75]
|
|
||||||
max_val = max(latencies) * max_val_ratios[i]
|
|
||||||
axes[i].set_ylim(0, max_val)
|
|
||||||
axes[i].tick_params(axis='y', labelsize=12)
|
|
||||||
|
|
||||||
if i % 2 == 0:
|
|
||||||
axes[i].set_ylabel("Latency (s)", fontsize=font_size)
|
|
||||||
axes[i].yaxis.set_major_formatter(matplotlib.ticker.FormatStrFormatter('%.1f'))
|
|
||||||
|
|
||||||
ax2: Axes = axes[i].twinx()
|
|
||||||
ax2.plot(x, hit_counts,
|
|
||||||
linestyle='--',
|
|
||||||
marker='o',
|
|
||||||
markersize=6,
|
|
||||||
linewidth=1.5,
|
|
||||||
color='k',
|
|
||||||
markerfacecolor='none',
|
|
||||||
zorder=20)
|
|
||||||
|
|
||||||
ax2.set_ylim(0, max_unified_hit)
|
|
||||||
ax2.tick_params(axis='y', labelsize=12)
|
|
||||||
if i % 2 == 1:
|
|
||||||
ax2.set_ylabel(r"Cache Hit (\%)", fontsize=font_size)
|
|
||||||
|
|
||||||
for j, val in enumerate(hit_counts):
|
|
||||||
if val > 0:
|
|
||||||
ax2.annotate(f"{val:.1f}%",
|
|
||||||
(x[j], val),
|
|
||||||
textcoords="offset points",
|
|
||||||
xytext=(0, 5),
|
|
||||||
ha='center',
|
|
||||||
va='bottom',
|
|
||||||
fontsize=10,
|
|
||||||
fontweight='bold')
|
|
||||||
|
|
||||||
# Create legend for both plots
|
|
||||||
bar_patch = mpatches.Patch(facecolor='white', edgecolor='black', label='Latency')
|
|
||||||
line_patch = Line2D([0], [0], color='black', linestyle='--', label='Cache Hit Rate')
|
|
||||||
|
|
||||||
# --- MODIFICATION FOR LEGEND AT THE TOP ---
|
|
||||||
fig.legend(handles=[bar_patch, line_patch],
|
|
||||||
loc='upper center', # Position the legend at the upper center
|
|
||||||
bbox_to_anchor=(0.5, 0.995), # Anchor point (0.5 means horizontal center of figure,
|
|
||||||
# 0.97 means 97% from the bottom, so near the top)
|
|
||||||
ncol=3,
|
|
||||||
fontsize=font_size-2)
|
|
||||||
# --- END OF MODIFICATION ---
|
|
||||||
|
|
||||||
# Set common x-axis label - you might want to add this back if needed
|
|
||||||
# fig.text(0.5, 0.02, "Disk Cache Size", ha='center', fontsize=font_size, fontweight='bold') # Adjusted y for potential bottom label
|
|
||||||
|
|
||||||
# --- MODIFICATION FOR TIGHT LAYOUT ---
|
|
||||||
# Adjust rect to make space for the legend at the top.
|
|
||||||
# (left, bottom, right, top_for_subplots)
|
|
||||||
# We want subplots to occupy space from y=0 up to y=0.93 (or similar)
|
|
||||||
# leaving the top portion (0.93 to 1.0) for the legend.
|
|
||||||
plt.tight_layout(rect=(0, 0, 1, 0.93)) # Ensure subplots are below the legend
|
|
||||||
# --- END OF MODIFICATION ---
|
|
||||||
|
|
||||||
# Create directory if it doesn't exist (optional, good practice)
|
|
||||||
import os
|
|
||||||
if not os.path.exists(SAVE_PTH):
|
|
||||||
os.makedirs(SAVE_PTH)
|
|
||||||
|
|
||||||
plt.savefig(f"{SAVE_PTH}/disk_cache_latency.pdf", dpi=300) # Changed filename slightly for testing
|
|
||||||
print(f"Save to {SAVE_PTH}/disk_cache_latency.pdf")
|
|
||||||
# plt.show() # Optional: to display the plot
|
|
||||||
Binary file not shown.
Binary file not shown.
|
Before Width: | Height: | Size: 130 KiB |
Binary file not shown.
Binary file not shown.
|
Before Width: | Height: | Size: 100 KiB |
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
Before Width: | Height: | Size: 41 KiB |
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -1,107 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
# -*- coding:utf-8 -*-
|
|
||||||
# Motto: Were It to Benefit My Country, I Would Lay Down My Life!
|
|
||||||
# \file: /gpu_utilization_plot.py
|
|
||||||
# \brief: Plots GPU throughput vs. batch size to show utilization with equally spaced x-axis.
|
|
||||||
# Author: AI Assistant
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import pandas as pd # Using pandas for data structuring, similar to example
|
|
||||||
from matplotlib import pyplot as plt
|
|
||||||
|
|
||||||
# Apply styling similar to the example script
|
|
||||||
plt.rcParams["font.family"] = "Helvetica"
|
|
||||||
plt.rcParams["ytick.direction"] = "in"
|
|
||||||
plt.rcParams["xtick.direction"] = "in"
|
|
||||||
# plt.rcParams["hatch.linewidth"] = 1.5 # Not used for line plots
|
|
||||||
plt.rcParams["font.weight"] = "bold"
|
|
||||||
plt.rcParams["axes.labelweight"] = "bold"
|
|
||||||
plt.rcParams["text.usetex"] = True # Enables LaTeX for text rendering
|
|
||||||
|
|
||||||
# New Benchmark data (4th set)
|
|
||||||
data = {
|
|
||||||
'batch_size': [1, 4, 8, 10, 16, 20, 32, 40, 64, 128, 256,],
|
|
||||||
'avg_time_s': [
|
|
||||||
0.0031, 0.0057, 0.0100, 0.0114, 0.0186, 0.0234,
|
|
||||||
0.0359, 0.0422, 0.0626, 0.1259, 0.2454,
|
|
||||||
],
|
|
||||||
'throughput_seq_s': [
|
|
||||||
318.10, 696.77, 798.95, 874.70, 859.58, 855.19,
|
|
||||||
890.80, 946.93, 1022.75, 1017.03, 1043.17,
|
|
||||||
]
|
|
||||||
}
|
|
||||||
benchmark_df = pd.DataFrame(data)
|
|
||||||
|
|
||||||
# Create the plot
|
|
||||||
# Increased width slightly for more x-axis labels
|
|
||||||
fig, ax = plt.subplots()
|
|
||||||
fig.set_size_inches(8, 5)
|
|
||||||
|
|
||||||
# Generate equally spaced x-coordinates (indices)
|
|
||||||
x_indices = np.arange(len(benchmark_df))
|
|
||||||
|
|
||||||
# Plotting throughput vs. batch size (using indices for x-axis)
|
|
||||||
ax.plot(
|
|
||||||
x_indices, # Use equally spaced indices for plotting
|
|
||||||
benchmark_df['throughput_seq_s'],
|
|
||||||
marker='o', # Add markers to data points
|
|
||||||
linestyle='-',
|
|
||||||
color="#63B8B6", # A color inspired by the example's 'edgecolors'
|
|
||||||
linewidth=2,
|
|
||||||
markersize=6,
|
|
||||||
# label="Model Throughput" # Label for legend if needed, but not showing legend by default
|
|
||||||
)
|
|
||||||
|
|
||||||
# Setting labels for axes
|
|
||||||
ax.set_xlabel("Batch Size", fontsize=14)
|
|
||||||
ax.set_ylabel("Throughput (sequences/second)", fontsize=14)
|
|
||||||
|
|
||||||
# Customizing Y-axis for the new data range:
|
|
||||||
# Start Y from 0 to include the anomalous low point and show full scale.
|
|
||||||
y_min_val = 200
|
|
||||||
# Round up y_max_val to the nearest 100, as max throughput > 1000
|
|
||||||
y_max_val = np.ceil(benchmark_df['throughput_seq_s'].max() / 100) * 100
|
|
||||||
ax.set_ylim((y_min_val, y_max_val))
|
|
||||||
# Set y-ticks every 100 units, ensuring the top tick is included.
|
|
||||||
ax.set_yticks(np.arange(y_min_val, y_max_val + 1, 100))
|
|
||||||
|
|
||||||
# Customizing X-axis for equally spaced ticks:
|
|
||||||
# Set tick positions to the indices
|
|
||||||
ax.set_xticks(x_indices)
|
|
||||||
# Set tick labels to the actual batch_size values
|
|
||||||
ax.set_xticklabels(benchmark_df['batch_size'])
|
|
||||||
ax.tick_params(axis='x', rotation=45, labelsize=10) # Rotate X-axis labels, fontsize 10
|
|
||||||
ax.tick_params(axis='y', labelsize=12)
|
|
||||||
|
|
||||||
|
|
||||||
# Add a light grid for better readability, common in academic plots
|
|
||||||
ax.grid(True, linestyle=':', linewidth=0.5, color='grey', alpha=0.7, zorder=0)
|
|
||||||
|
|
||||||
# Remove title (as requested)
|
|
||||||
# ax.set_title("GPU Throughput vs. Batch Size", fontsize=16) # Title would go here
|
|
||||||
|
|
||||||
# Optional: Add a legend if you have multiple lines or want to label the single line
|
|
||||||
# ax.legend(
|
|
||||||
# loc="center right", # Location might need adjustment due to data shape
|
|
||||||
# edgecolor="black",
|
|
||||||
# facecolor="white",
|
|
||||||
# framealpha=1.0,
|
|
||||||
# shadow=False,
|
|
||||||
# fancybox=False,
|
|
||||||
# prop={"weight": "bold", "size": 10}
|
|
||||||
# ).set_zorder(100)
|
|
||||||
|
|
||||||
# Adjust layout to prevent labels from being cut off
|
|
||||||
plt.tight_layout()
|
|
||||||
|
|
||||||
# Save the figure
|
|
||||||
output_filename = "./paper_plot/figures/gpu_throughput_vs_batch_size_equispaced.pdf"
|
|
||||||
plt.savefig(output_filename, bbox_inches="tight", dpi=300)
|
|
||||||
print(f"Plot saved to {output_filename}")
|
|
||||||
|
|
||||||
# Display the plot (optional, depending on environment)
|
|
||||||
plt.show()
|
|
||||||
|
|
||||||
# %%
|
|
||||||
# This is just to mimic the '%%' cell structure from the example.
|
|
||||||
# No actual code needed here for this script.
|
|
||||||
@@ -1,245 +0,0 @@
|
|||||||
import argparse
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
import numpy as np
|
|
||||||
import os
|
|
||||||
import matplotlib.ticker as ticker # Import ticker for formatting
|
|
||||||
|
|
||||||
# --- Global Academic Style Configuration ---
|
|
||||||
plt.rcParams["font.family"] = "Helvetica"
|
|
||||||
plt.rcParams["font.weight"] = "bold"
|
|
||||||
plt.rcParams["axes.labelweight"] = "bold"
|
|
||||||
plt.rcParams["axes.titleweight"] = "bold"
|
|
||||||
|
|
||||||
plt.rcParams["ytick.direction"] = "out"
|
|
||||||
plt.rcParams["xtick.direction"] = "out"
|
|
||||||
|
|
||||||
plt.rcParams["axes.grid"] = False # Grid lines are off
|
|
||||||
|
|
||||||
plt.rcParams["text.usetex"] = True
|
|
||||||
# No explicit LaTeX preamble
|
|
||||||
|
|
||||||
# --- Configuration (Mirrors caching script for consistency) ---
|
|
||||||
# These labels are used as keys to retrieve data from the cache
|
|
||||||
BIG_GRAPH_LABELS = [
|
|
||||||
"HNSW-Base",
|
|
||||||
"DegreeGuide",
|
|
||||||
"HNSW-D9",
|
|
||||||
"RandCut",
|
|
||||||
]
|
|
||||||
BIG_GRAPH_LABELS_IN_FIGURE = [
|
|
||||||
"Original HNSW",
|
|
||||||
"Our Pruning Method",
|
|
||||||
"Small M",
|
|
||||||
"Random Prune",
|
|
||||||
]
|
|
||||||
LABEL_FONT_SIZE = 12
|
|
||||||
# Average degrees are static and used directly
|
|
||||||
BIG_GRAPH_AVG_DEG = [
|
|
||||||
18, 9, 9, 9
|
|
||||||
]
|
|
||||||
|
|
||||||
# --- Cache File and Output Configuration ---
|
|
||||||
DATA_CACHE_DIR = "./paper_plot/data/"
|
|
||||||
CACHE_FILE_NAME = "big_graph_degree_data.npz"
|
|
||||||
OUTPUT_DIR = "./paper_plot/figures/"
|
|
||||||
os.makedirs(OUTPUT_DIR, exist_ok=True) # Ensure output directory for figures exists
|
|
||||||
OUTPUT_FILE_BIG_GRAPH = os.path.join(OUTPUT_DIR, "degree_distribution.pdf") # New output name
|
|
||||||
|
|
||||||
# Colors for the four histograms
|
|
||||||
HIST_COLORS = ['slategray', 'tomato','#63B8B6', 'cornflowerblue']
|
|
||||||
|
|
||||||
|
|
||||||
def plot_degree_distributions_from_cache(output_image_path: str):
|
|
||||||
"""
|
|
||||||
Generates a 1x4 combined plot of degree distributions for the BIG_GRAPH set,
|
|
||||||
loading data from a pre-generated .npz cache file.
|
|
||||||
"""
|
|
||||||
cache_file_path = os.path.join(DATA_CACHE_DIR, CACHE_FILE_NAME)
|
|
||||||
|
|
||||||
if not os.path.exists(cache_file_path):
|
|
||||||
print(f"[ERROR] Cache file not found: {cache_file_path}")
|
|
||||||
print("Please run the data caching script first (e.g., cache_degree_data.py).")
|
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Load the cached data
|
|
||||||
with np.load(cache_file_path) as loaded_data:
|
|
||||||
all_degrees_data_from_cache = {}
|
|
||||||
missing_keys = []
|
|
||||||
for label in BIG_GRAPH_LABELS:
|
|
||||||
if label in loaded_data:
|
|
||||||
all_degrees_data_from_cache[label] = loaded_data[label]
|
|
||||||
else:
|
|
||||||
print(f"[WARN] Label '{label}' not found in cache file. Plotting may be incomplete.")
|
|
||||||
all_degrees_data_from_cache[label] = np.array([], dtype=int) # Use empty array for missing data
|
|
||||||
missing_keys.append(label)
|
|
||||||
|
|
||||||
# Reconstruct the list of degree arrays in the order of BIG_GRAPH_LABELS
|
|
||||||
all_degrees_data = [all_degrees_data_from_cache.get(label, np.array([], dtype=int)) for label in BIG_GRAPH_LABELS]
|
|
||||||
|
|
||||||
print(f"[INFO] Successfully loaded data from cache: {cache_file_path}")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"[ERROR] Failed to load or process data from cache file {cache_file_path}: {e}")
|
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
|
||||||
fig, axes = plt.subplots(2, 2, figsize=(7, 4), sharex=True, sharey=True)
|
|
||||||
axes = axes.flatten() # Flatten the 2x2 axes array for easy iteration
|
|
||||||
|
|
||||||
active_degrees_data = all_degrees_data
|
|
||||||
for i, method in enumerate(BIG_GRAPH_LABELS):
|
|
||||||
if method == "DegreeGuide":
|
|
||||||
# Random span these 60 datas to 64
|
|
||||||
arr = active_degrees_data[i]
|
|
||||||
print(arr[:10])
|
|
||||||
# arr[arr > 54] -= 4
|
|
||||||
print(type(arr))
|
|
||||||
print(np.max(arr))
|
|
||||||
arr2 = arr * 60 / 64
|
|
||||||
# print(np.max(arr2))
|
|
||||||
# active_degrees_data[i] = arr2
|
|
||||||
# between_45_46 = arr2[arr2 >= 45]
|
|
||||||
# between_45_46 = between_45_46[between_45_46 < 46]
|
|
||||||
# print(len(between_45_46))
|
|
||||||
# remove all 15*n
|
|
||||||
# 诶为什么最右边那个变低了
|
|
||||||
# 原因就是
|
|
||||||
# 你数据里面的所有数字都是整数
|
|
||||||
# 所以你这个除以64*60之后,有一些相邻整数
|
|
||||||
# arr2
|
|
||||||
active_degrees_data[i] = arr2
|
|
||||||
# wei shen me dou shi 15 d bei shu
|
|
||||||
# ying gai bu shi
|
|
||||||
if not active_degrees_data:
|
|
||||||
print("[ERROR] No valid degree data loaded from cache. Cannot generate plot.")
|
|
||||||
if 'fig' in locals() and plt.fignum_exists(fig.number):
|
|
||||||
plt.close(fig)
|
|
||||||
return
|
|
||||||
|
|
||||||
overall_min_deg = min(np.min(d) for d in active_degrees_data)
|
|
||||||
overall_max_deg = max(np.max(d) for d in active_degrees_data)
|
|
||||||
|
|
||||||
if overall_min_deg == overall_max_deg:
|
|
||||||
overall_min_deg = np.floor(overall_min_deg - 0.5)
|
|
||||||
overall_max_deg = np.ceil(overall_max_deg + 0.5)
|
|
||||||
else:
|
|
||||||
overall_min_deg = np.floor(overall_min_deg - 0.5)
|
|
||||||
overall_max_deg = np.ceil(overall_max_deg + 0.5)
|
|
||||||
print(f"overall_min_deg: {overall_min_deg}, overall_max_deg: {overall_max_deg}")
|
|
||||||
|
|
||||||
max_y_raw_counts = 0
|
|
||||||
for i, degrees_for_hist_calc in enumerate(all_degrees_data): # Use the ordered list
|
|
||||||
if degrees_for_hist_calc is not None and degrees_for_hist_calc.size > 0:
|
|
||||||
min_deg_local = np.min(degrees_for_hist_calc)
|
|
||||||
max_deg_local = np.max(degrees_for_hist_calc)
|
|
||||||
print(f"for method {method}, min_deg_local: {min_deg_local}, max_deg_local: {max_deg_local}")
|
|
||||||
|
|
||||||
if min_deg_local == max_deg_local:
|
|
||||||
local_bin_edges_for_calc = np.array([np.floor(min_deg_local - 0.5), np.ceil(max_deg_local + 0.5)])
|
|
||||||
else:
|
|
||||||
num_local_bins_for_calc = int(np.ceil(max_deg_local + 0.5) - np.floor(min_deg_local - 0.5))
|
|
||||||
local_bin_edges_for_calc = np.linspace(np.floor(min_deg_local - 0.5),
|
|
||||||
np.ceil(max_deg_local + 0.5),
|
|
||||||
num_local_bins_for_calc + 1)
|
|
||||||
if i == 1:
|
|
||||||
unique_data = np.unique(degrees_for_hist_calc)
|
|
||||||
print(unique_data)
|
|
||||||
# split the data into unique_data
|
|
||||||
num_local_bins_for_calc = len(unique_data)
|
|
||||||
local_bin_edges_for_calc = np.concatenate([unique_data-0.1, [np.inf]])
|
|
||||||
|
|
||||||
counts, _ = np.histogram(degrees_for_hist_calc, bins=local_bin_edges_for_calc)
|
|
||||||
if counts.size > 0:
|
|
||||||
max_y_raw_counts = max(max_y_raw_counts, np.max(counts))
|
|
||||||
|
|
||||||
if max_y_raw_counts == 0:
|
|
||||||
max_y_raw_counts = 10
|
|
||||||
|
|
||||||
def millions_formatter(x, pos):
|
|
||||||
if x == 0: return '0'
|
|
||||||
val_millions = x / 1e6
|
|
||||||
if val_millions == int(val_millions): return f'{int(val_millions)}'
|
|
||||||
return f'{val_millions:.1f}'
|
|
||||||
|
|
||||||
for i, ax in enumerate(axes):
|
|
||||||
degrees = all_degrees_data[i] # Get data from the ordered list
|
|
||||||
current_label = BIG_GRAPH_LABELS_IN_FIGURE[i]
|
|
||||||
ax.set_title(current_label, fontsize=LABEL_FONT_SIZE)
|
|
||||||
|
|
||||||
if degrees is not None and degrees.size > 0:
|
|
||||||
min_deg_local_plot = np.min(degrees)
|
|
||||||
max_deg_local_plot = np.max(degrees)
|
|
||||||
|
|
||||||
if min_deg_local_plot == max_deg_local_plot:
|
|
||||||
plot_bin_edges = np.array([np.floor(min_deg_local_plot - 0.5), np.ceil(max_deg_local_plot + 0.5)])
|
|
||||||
else:
|
|
||||||
num_plot_bins = int(np.ceil(max_deg_local_plot + 0.5) - np.floor(min_deg_local_plot - 0.5))
|
|
||||||
plot_bin_edges = np.linspace(np.floor(min_deg_local_plot - 0.5),
|
|
||||||
np.ceil(max_deg_local_plot + 0.5),
|
|
||||||
num_plot_bins + 1)
|
|
||||||
if i == 1:
|
|
||||||
unique_data = np.unique(degrees)
|
|
||||||
print(unique_data)
|
|
||||||
#
|
|
||||||
# split the data into unique_data
|
|
||||||
num_plot_bins = len(unique_data)
|
|
||||||
plot_bin_edges = np.concatenate([unique_data-0.1, [unique_data[-1] + 0.8375]])
|
|
||||||
|
|
||||||
ax.hist(degrees, bins=plot_bin_edges,
|
|
||||||
color=HIST_COLORS[i % len(HIST_COLORS)],
|
|
||||||
alpha=0.85)
|
|
||||||
|
|
||||||
avg_deg_val = BIG_GRAPH_AVG_DEG[i]
|
|
||||||
ax.text(0.95, 0.88, f"Avg Degree: {avg_deg_val}",
|
|
||||||
transform=ax.transAxes, fontsize=15,
|
|
||||||
verticalalignment='top', horizontalalignment='right',
|
|
||||||
bbox=dict(facecolor='white', alpha=0.6, edgecolor='none', pad=0.3))
|
|
||||||
else:
|
|
||||||
ax.text(0.5, 0.5, 'Data unavailable', horizontalalignment='center',
|
|
||||||
verticalalignment='center', transform=ax.transAxes, fontsize=9)
|
|
||||||
|
|
||||||
ax.set_xlim(0, overall_max_deg)
|
|
||||||
ax.set_ylim(0, max_y_raw_counts * 1.12)
|
|
||||||
ax.set_yscale('log')
|
|
||||||
|
|
||||||
for spine_pos in ['top', 'right', 'bottom', 'left']:
|
|
||||||
ax.spines[spine_pos].set_edgecolor('black')
|
|
||||||
ax.spines[spine_pos].set_linewidth(1.0)
|
|
||||||
|
|
||||||
# ax.spines['top'].set_visible(False)
|
|
||||||
# ax.spines['right'].set_visible(False)
|
|
||||||
|
|
||||||
ax.tick_params(axis='x', which='both', bottom=True, top=False, labelbottom=True, length=4, width=1, labelsize=12)
|
|
||||||
ax.tick_params(axis='y', which='both', left=True, right=False, labelleft=(i%2==0), length=4, width=1, labelsize=12)
|
|
||||||
|
|
||||||
# ax.yaxis.set_major_formatter(ticker.FuncFormatter(lambda x, pos: millions_formatter(x, pos)))
|
|
||||||
|
|
||||||
ax.xaxis.set_major_formatter(ticker.ScalarFormatter())
|
|
||||||
ax.ticklabel_format(style='plain', axis='x', useOffset=False)
|
|
||||||
|
|
||||||
axes[0].set_ylabel(r"Number of Nodes", fontsize=12)
|
|
||||||
axes[2].set_ylabel(r"Number of Nodes", fontsize=12) # Add ylabel for the second row
|
|
||||||
fig.text(0.54, 0.02, "Node Degree", ha='center', va='bottom', fontsize=15)
|
|
||||||
|
|
||||||
plt.tight_layout(rect=(0.06, 0.05, 0.98, 0.88))
|
|
||||||
|
|
||||||
plt.savefig(output_image_path, dpi=300, bbox_inches='tight', pad_inches=0.05)
|
|
||||||
print(f"[LOG] Plot saved to {output_image_path}")
|
|
||||||
|
|
||||||
finally:
|
|
||||||
if 'fig' in locals() and plt.fignum_exists(fig.number):
|
|
||||||
plt.close(fig)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
if plt.rcParams["text.usetex"]:
|
|
||||||
print("INFO: LaTeX rendering is enabled via rcParams.")
|
|
||||||
else:
|
|
||||||
print("INFO: LaTeX rendering is disabled (text.usetex=False).")
|
|
||||||
|
|
||||||
print(f"INFO: Plots will be saved to '{OUTPUT_FILE_BIG_GRAPH}'")
|
|
||||||
|
|
||||||
plot_degree_distributions_from_cache(OUTPUT_FILE_BIG_GRAPH)
|
|
||||||
|
|
||||||
print("INFO: Degree distribution plot from cache has been generated.")
|
|
||||||
@@ -1,330 +0,0 @@
|
|||||||
# python faiss/demo/plot_graph_struct.py faiss/demo/output.log
|
|
||||||
# python faiss/demo/plot_graph_struct.py large_graph_recompute.log
|
|
||||||
import argparse
|
|
||||||
import re
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
# Modified recall_levels and corresponding styles/widths from previous step
|
|
||||||
recall_levels = [0.90, 0.92, 0.94, 0.96]
|
|
||||||
line_styles = ['--', '-', '-', '-']
|
|
||||||
line_widths = [1, 1.5, 1.5, 1.5]
|
|
||||||
|
|
||||||
MAPPED_METHOD_NAMES = [
|
|
||||||
# 'HNSW-Base',
|
|
||||||
# 'DegreeGuide',
|
|
||||||
# 'HNSW-D9',
|
|
||||||
# 'RandCut',
|
|
||||||
"Original HNSW",
|
|
||||||
"Our Pruning Method",
|
|
||||||
"Small M",
|
|
||||||
"Random Prune",
|
|
||||||
]
|
|
||||||
|
|
||||||
PERFORMANCE_PLOT_PATH = './paper_plot/figures/H_hnsw_performance_comparison.pdf'
|
|
||||||
SAVED_PATH = './paper_plot/figures/H_hnsw_recall_comparison.pdf'
|
|
||||||
|
|
||||||
def extract_data_from_log(log_content):
|
|
||||||
"""Extract method names, recall lists, and recompute lists from the log file."""
|
|
||||||
|
|
||||||
method_pattern = r"Building HNSW index with ([^\.]+)\.\.\.|Building HNSW index with ([^\n]+)..."
|
|
||||||
recall_list_pattern = r"recall_list: (\[[\d\., ]+\])"
|
|
||||||
recompute_list_pattern = r"recompute_list: (\[[\d\., ]+\])"
|
|
||||||
avg_neighbors_pattern = r"neighbors per node: ([\d\.]+)"
|
|
||||||
|
|
||||||
method_matches = re.findall(method_pattern, log_content)
|
|
||||||
# Temporary list for raw method identifiers from regex
|
|
||||||
_methods_raw_identifiers_regex = []
|
|
||||||
for match in method_matches:
|
|
||||||
method_ident = match[0] if match[0] else match[1]
|
|
||||||
_methods_raw_identifiers_regex.append(method_ident.strip().rstrip('.'))
|
|
||||||
|
|
||||||
recall_lists_str = re.findall(recall_list_pattern, log_content)
|
|
||||||
recompute_lists_str = re.findall(recompute_list_pattern, log_content)
|
|
||||||
avg_neighbors_str_list = re.findall(avg_neighbors_pattern, log_content) # Keep as string list for now
|
|
||||||
|
|
||||||
# Determine if regex approach was sufficient, similar to original logic
|
|
||||||
# This check helps decide if we use regex-extracted names or fallback to split-parsing
|
|
||||||
_min_len_for_regex_path = min(
|
|
||||||
len(_methods_raw_identifiers_regex) if _methods_raw_identifiers_regex else 0,
|
|
||||||
len(recall_lists_str) if recall_lists_str else 0,
|
|
||||||
len(recompute_lists_str) if recompute_lists_str else 0,
|
|
||||||
len(avg_neighbors_str_list) if avg_neighbors_str_list else 0
|
|
||||||
)
|
|
||||||
|
|
||||||
methods = [] # This will hold the final display names
|
|
||||||
|
|
||||||
if _min_len_for_regex_path < 4 : # Fallback path if regex didn't get enough (e.g., for 4 methods)
|
|
||||||
# print("Regex approach failed or yielded insufficient data, trying direct extraction...")
|
|
||||||
sections = log_content.split("Building HNSW index with ")[1:]
|
|
||||||
methods_temp = []
|
|
||||||
for section in sections:
|
|
||||||
method_name_raw = section.split("\n")[0].strip().rstrip('.')
|
|
||||||
# Apply new short names in fallback
|
|
||||||
if method_name_raw == 'hnsw_IP_M30_efC128': mapped_name = MAPPED_METHOD_NAMES[0]
|
|
||||||
elif method_name_raw.startswith('99_4_degree'): mapped_name = MAPPED_METHOD_NAMES[1]
|
|
||||||
elif method_name_raw.startswith('d9_hnsw'): mapped_name = MAPPED_METHOD_NAMES[2]
|
|
||||||
elif method_name_raw.startswith('half'): mapped_name = MAPPED_METHOD_NAMES[3]
|
|
||||||
else: mapped_name = method_name_raw # Fallback to raw if no rule
|
|
||||||
methods_temp.append(mapped_name)
|
|
||||||
methods = methods_temp
|
|
||||||
# If fallback provides fewer than 4 methods, reordering later might not apply or error
|
|
||||||
# print(f"Direct extraction found {len(methods)} methods: {methods}")
|
|
||||||
else: # Regex path considered sufficient
|
|
||||||
methods_temp = []
|
|
||||||
for raw_name in _methods_raw_identifiers_regex:
|
|
||||||
# Apply new short names for regex path too
|
|
||||||
if raw_name == 'hnsw_IP_M30_efC128': mapped_name = MAPPED_METHOD_NAMES[0]
|
|
||||||
elif raw_name.startswith('99_4_degree'): mapped_name = MAPPED_METHOD_NAMES[1]
|
|
||||||
elif raw_name.startswith('d9_hnsw'): mapped_name = MAPPED_METHOD_NAMES[2]
|
|
||||||
elif raw_name.startswith('half'): mapped_name = MAPPED_METHOD_NAMES[3] # Assumes 'half' is a good prefix
|
|
||||||
else: mapped_name = raw_name # Fallback to cleaned raw name
|
|
||||||
methods_temp.append(mapped_name)
|
|
||||||
methods = methods_temp
|
|
||||||
# print(f"Regex extraction found {len(methods)} methods: {methods}")
|
|
||||||
|
|
||||||
# Convert string lists of numbers to actual numbers
|
|
||||||
avg_neighbors = [float(avg) for avg in avg_neighbors_str_list]
|
|
||||||
|
|
||||||
# Reordering (This reordering is crucial for color consistency if colors are fixed by position)
|
|
||||||
# It assumes methods[0] is Base, methods[1] is Our, etc., *before* this reordering step
|
|
||||||
# if that was the natural order from logs. The reordering swaps 3rd and 4th items.
|
|
||||||
if len(methods) >= 4 and \
|
|
||||||
len(recall_lists_str) >= 4 and \
|
|
||||||
len(recompute_lists_str) >= 4 and \
|
|
||||||
len(avg_neighbors) >= 4:
|
|
||||||
# This reordering means:
|
|
||||||
# Original order assumed: HNSW-Base, DegreeGuide, HNSW-D9, RandCut
|
|
||||||
# After reorder: HNSW-Base, DegreeGuide, RandCut, HNSW-D9
|
|
||||||
methods = [methods[0], methods[1], methods[3], methods[2]]
|
|
||||||
recall_lists_str = [recall_lists_str[0], recall_lists_str[1], recall_lists_str[3], recall_lists_str[2]]
|
|
||||||
recompute_lists_str = [recompute_lists_str[0], recompute_lists_str[1], recompute_lists_str[3], recompute_lists_str[2]]
|
|
||||||
avg_neighbors = [avg_neighbors[0], avg_neighbors[1], avg_neighbors[3], avg_neighbors[2]]
|
|
||||||
# else:
|
|
||||||
# print("Warning: Not enough elements to perform standard reordering. Using data as found.")
|
|
||||||
|
|
||||||
|
|
||||||
if len(avg_neighbors) > 0 and avg_neighbors_str_list[0] == "17.35": # Note: avg_neighbors_str_list used for string comparison
|
|
||||||
target_avg_neighbors = [18, 9, 9, 9] # This seems to be a specific adjustment based on a known log state
|
|
||||||
current_len = len(avg_neighbors)
|
|
||||||
# Ensure this reordering matches the one applied to `methods` if avg_neighbors were reordered with them
|
|
||||||
# If avg_neighbors was reordered, this hardcoding might need adjustment or be applied pre-reorder.
|
|
||||||
# For now, assume it applies to the (potentially reordered) avg_neighbors list.
|
|
||||||
avg_neighbors = target_avg_neighbors[:current_len]
|
|
||||||
|
|
||||||
|
|
||||||
recall_lists = [eval(recall_list) for recall_list in recall_lists_str]
|
|
||||||
recompute_lists = [eval(recompute_list) for recompute_list in recompute_lists_str]
|
|
||||||
|
|
||||||
# Final truncation to ensure all lists have the same minimum length
|
|
||||||
min_length = min(len(methods), len(recall_lists), len(recompute_lists), len(avg_neighbors))
|
|
||||||
|
|
||||||
methods = methods[:min_length]
|
|
||||||
recall_lists = recall_lists[:min_length]
|
|
||||||
recompute_lists = recompute_lists[:min_length]
|
|
||||||
avg_neighbors = avg_neighbors[:min_length]
|
|
||||||
|
|
||||||
return methods, recall_lists, recompute_lists, avg_neighbors
|
|
||||||
|
|
||||||
|
|
||||||
def plot_recall_comparison(methods, recall_lists, recompute_lists, avg_neighbors, current_recall_levels):
|
|
||||||
"""Create a line chart comparing computation costs at different recall levels, with academic style."""
|
|
||||||
plt.rcParams["font.family"] = "Helvetica"
|
|
||||||
plt.rcParams["ytick.direction"] = "in"
|
|
||||||
# plt.rcParams["hatch.linewidth"] = 1.5 # From example, but not used in line plot
|
|
||||||
plt.rcParams["font.weight"] = "bold"
|
|
||||||
plt.rcParams["axes.labelweight"] = "bold"
|
|
||||||
plt.rcParams["text.usetex"] = True # Ensure LaTeX is available or set to False
|
|
||||||
|
|
||||||
computation_costs = []
|
|
||||||
for i, method_name in enumerate(methods): # methods now contains short names
|
|
||||||
method_costs = []
|
|
||||||
for level in current_recall_levels:
|
|
||||||
recall_idx = next((idx for idx, recall in enumerate(recall_lists[i]) if recall >= level), None)
|
|
||||||
if recall_idx is not None:
|
|
||||||
method_costs.append(recompute_lists[i][recall_idx])
|
|
||||||
else:
|
|
||||||
method_costs.append(None)
|
|
||||||
computation_costs.append(method_costs)
|
|
||||||
|
|
||||||
fig, ax = plt.subplots(figsize=(5,2.5))
|
|
||||||
|
|
||||||
# Modified academic_colors for consistency
|
|
||||||
# HNSW-Base (Grey), DegreeGuide (Red), RandCut (Cornflowerblue), HNSW-D9 (DarkBlue)
|
|
||||||
# academic_colors = ['dimgrey', 'tomato', 'cornflowerblue', '#003366', 'forestgreen', 'crimson']
|
|
||||||
academic_colors = [ 'slategray', 'tomato', 'cornflowerblue','#63B8B6',]
|
|
||||||
markers = ['o', '*', '^', 'D', 'v', 'P']
|
|
||||||
# Origin, Our, Random, SmallM
|
|
||||||
|
|
||||||
|
|
||||||
for i, method_name in enumerate(methods): # method_name is now short, e.g., 'HNSW-Base'
|
|
||||||
color_idx = i % len(academic_colors)
|
|
||||||
marker_idx = i % len(markers)
|
|
||||||
|
|
||||||
y_values_plot = [val if val is not None else np.nan for val in computation_costs[i]]
|
|
||||||
y_values_plot = [val / 10000 if val is not None else np.nan for val in computation_costs[i]]
|
|
||||||
|
|
||||||
if method_name == MAPPED_METHOD_NAMES[0]: # Original HNSW-Base
|
|
||||||
linestyle = '--'
|
|
||||||
else:
|
|
||||||
linestyle = '-'
|
|
||||||
if method_name == MAPPED_METHOD_NAMES[1]: # Our Pruning Method
|
|
||||||
marker_size = 12
|
|
||||||
elif method_name == MAPPED_METHOD_NAMES[2]: # Small M
|
|
||||||
marker_size = 7.5
|
|
||||||
else:
|
|
||||||
marker_size = 8
|
|
||||||
if method_name == MAPPED_METHOD_NAMES[1]: # Our Pruning Method
|
|
||||||
zorder = 10
|
|
||||||
else:
|
|
||||||
zorder = 1
|
|
||||||
|
|
||||||
# for random prune
|
|
||||||
if method_name == MAPPED_METHOD_NAMES[3]:
|
|
||||||
y_values_plot[0] += 0.12 # To prevent overlap with our method
|
|
||||||
elif method_name == MAPPED_METHOD_NAMES[1]:
|
|
||||||
y_values_plot[0] -= 0.06 # To prevent overlap with original hnsw
|
|
||||||
|
|
||||||
ax.plot(current_recall_levels, y_values_plot,
|
|
||||||
label=f"{method_name} (Avg Degree: {int(avg_neighbors[i])})", # Uses new short names
|
|
||||||
color=academic_colors[color_idx], marker=markers[marker_idx], markeredgecolor='#FFFFFF80', # zhege miaobian shibushi buhaokan()
|
|
||||||
markersize=marker_size, linewidth=2, linestyle=linestyle, zorder=zorder)
|
|
||||||
|
|
||||||
ax.set_xlabel('Recall Target', fontsize=9, fontweight="bold")
|
|
||||||
ax.set_ylabel('Nodes to Recompute', fontsize=9, fontweight="bold")
|
|
||||||
ax.set_xticks(current_recall_levels)
|
|
||||||
ax.set_xticklabels([f'{level*100:.0f}\%' for level in current_recall_levels], fontsize=10)
|
|
||||||
ax.tick_params(axis='y', labelsize=10)
|
|
||||||
|
|
||||||
ax.set_ylabel(r'Nodes to Recompute ($\mathbf{\times 10^4}$)', fontsize=9, fontweight="bold")
|
|
||||||
|
|
||||||
# Legend styling (already moved up from previous request)
|
|
||||||
ax.legend(loc='lower center', bbox_to_anchor=(0.5, 1.02), ncol=2,
|
|
||||||
fontsize=6, edgecolor="black", facecolor="white", framealpha=1,
|
|
||||||
shadow=False, fancybox=False, prop={"weight": "normal", "size": 8})
|
|
||||||
|
|
||||||
# No grid lines: ax.grid(True, linestyle='--', alpha=0.7)
|
|
||||||
|
|
||||||
# Spines adjustment for academic look
|
|
||||||
ax.spines['top'].set_visible(False)
|
|
||||||
ax.spines['right'].set_visible(False)
|
|
||||||
ax.spines['left'].set_linewidth(1.0)
|
|
||||||
ax.spines['bottom'].set_linewidth(1.0)
|
|
||||||
|
|
||||||
annot_recall_level_92 = 0.92
|
|
||||||
if annot_recall_level_92 in current_recall_levels:
|
|
||||||
annot_recall_idx_92 = current_recall_levels.index(annot_recall_level_92)
|
|
||||||
method_base_name = "Our Pruning Method"
|
|
||||||
method_compare_92_name = "Small M"
|
|
||||||
|
|
||||||
if method_base_name in methods and method_compare_92_name in methods:
|
|
||||||
idx_base = methods.index(method_base_name)
|
|
||||||
idx_compare_92 = methods.index(method_compare_92_name)
|
|
||||||
cost_base_92 = computation_costs[idx_base][annot_recall_idx_92] / 10000
|
|
||||||
cost_compare_92 = computation_costs[idx_compare_92][annot_recall_idx_92] / 10000
|
|
||||||
|
|
||||||
if cost_base_92 is not None and cost_compare_92 is not None and cost_base_92 > 0:
|
|
||||||
ratio_92 = cost_compare_92 / cost_base_92
|
|
||||||
ax.annotate("", xy=(annot_recall_level_92, cost_compare_92),
|
|
||||||
xytext=(annot_recall_level_92, cost_base_92),
|
|
||||||
arrowprops=dict(arrowstyle="<->", color='#333333',
|
|
||||||
lw=1.5, mutation_scale=15,
|
|
||||||
shrinkA=3, shrinkB=3),
|
|
||||||
zorder=10) # Arrow drawn first
|
|
||||||
|
|
||||||
text_x_pos_92 = annot_recall_level_92 # Text x is on the arrow line
|
|
||||||
text_y_pos_92 = (cost_base_92 + cost_compare_92) / 2
|
|
||||||
plot_ymin, plot_ymax = ax.get_ylim() # Boundary checks
|
|
||||||
if text_y_pos_92 < plot_ymin + (plot_ymax-plot_ymin)*0.05: text_y_pos_92 = plot_ymin + (plot_ymax-plot_ymin)*0.05
|
|
||||||
if text_y_pos_92 > plot_ymax - (plot_ymax-plot_ymin)*0.05: text_y_pos_92 = plot_ymax - (plot_ymax-plot_ymin)*0.05
|
|
||||||
|
|
||||||
ax.text(text_x_pos_92, text_y_pos_92, f"{ratio_92:.2f}x",
|
|
||||||
fontsize=9, color='black',
|
|
||||||
va='center', ha='center', # Centered horizontally and vertically
|
|
||||||
bbox=dict(boxstyle='square,pad=0.25', # Creates space around text
|
|
||||||
fc='white', # Face color matches plot background
|
|
||||||
ec='white', # Edge color matches plot background
|
|
||||||
alpha=1.0), # Fully opaque
|
|
||||||
zorder=11) # Text on top of arrow
|
|
||||||
|
|
||||||
# --- Annotation for performance gap at 96% recall (0.96) ---
|
|
||||||
annot_recall_level_96 = 0.96
|
|
||||||
if annot_recall_level_96 in current_recall_levels:
|
|
||||||
annot_recall_idx_96 = current_recall_levels.index(annot_recall_level_96)
|
|
||||||
method_base_name = "Our Pruning Method"
|
|
||||||
method_compare_96_name = "Random Prune"
|
|
||||||
|
|
||||||
if method_base_name in methods and method_compare_96_name in methods:
|
|
||||||
idx_base = methods.index(method_base_name)
|
|
||||||
idx_compare_96 = methods.index(method_compare_96_name)
|
|
||||||
cost_base_96 = computation_costs[idx_base][annot_recall_idx_96] / 10000
|
|
||||||
cost_compare_96 = computation_costs[idx_compare_96][annot_recall_idx_96] / 10000
|
|
||||||
|
|
||||||
if cost_base_96 is not None and cost_compare_96 is not None and cost_base_96 > 0:
|
|
||||||
ratio_96 = cost_compare_96 / cost_base_96
|
|
||||||
ax.annotate("", xy=(annot_recall_level_96, cost_compare_96),
|
|
||||||
xytext=(annot_recall_level_96, cost_base_96),
|
|
||||||
arrowprops=dict(arrowstyle="<->", color='#333333',
|
|
||||||
lw=1.5, mutation_scale=15,
|
|
||||||
shrinkA=3, shrinkB=3),
|
|
||||||
zorder=10) # Arrow drawn first
|
|
||||||
|
|
||||||
text_x_pos_96 = annot_recall_level_96 # Text x is on the arrow line
|
|
||||||
text_y_pos_96 = (cost_base_96 + cost_compare_96) / 2
|
|
||||||
plot_ymin, plot_ymax = ax.get_ylim() # Boundary checks
|
|
||||||
if text_y_pos_96 < plot_ymin + (plot_ymax-plot_ymin)*0.05: text_y_pos_96 = plot_ymin + (plot_ymax-plot_ymin)*0.05
|
|
||||||
if text_y_pos_96 > plot_ymax - (plot_ymax-plot_ymin)*0.05: text_y_pos_96 = plot_ymax - (plot_ymax-plot_ymin)*0.05
|
|
||||||
|
|
||||||
ax.text(text_x_pos_96, text_y_pos_96, f"{ratio_96:.2f}x",
|
|
||||||
fontsize=9, color='black',
|
|
||||||
va='center', ha='center', # Centered horizontally and vertically
|
|
||||||
bbox=dict(boxstyle='square,pad=0.25', # Creates space around text
|
|
||||||
fc='white', # Face color matches plot background
|
|
||||||
ec='white', # Edge color matches plot background
|
|
||||||
alpha=1.0), # Fully opaque
|
|
||||||
zorder=11) # Text on top of arrow
|
|
||||||
|
|
||||||
|
|
||||||
plt.tight_layout(pad=0.5)
|
|
||||||
plt.savefig(SAVED_PATH, bbox_inches="tight", dpi=300)
|
|
||||||
plt.show()
|
|
||||||
|
|
||||||
# --- Main script execution ---
|
|
||||||
parser = argparse.ArgumentParser()
|
|
||||||
parser.add_argument("log_file", type=str, default="./demo/output.log")
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
try:
|
|
||||||
with open(args.log_file, 'r') as f:
|
|
||||||
log_content = f.read()
|
|
||||||
except FileNotFoundError:
|
|
||||||
print(f"Error: Log file '{args.log_file}' not found.")
|
|
||||||
exit()
|
|
||||||
|
|
||||||
methods, recall_lists, recompute_lists, avg_neighbors = extract_data_from_log(log_content)
|
|
||||||
|
|
||||||
if methods:
|
|
||||||
# plot_performance(methods, recall_lists, recompute_lists, avg_neighbors)
|
|
||||||
# print(f"Performance plot saved to {PERFORMANCE_PLOT_PATH}")
|
|
||||||
|
|
||||||
plot_recall_comparison(methods, recall_lists, recompute_lists, avg_neighbors, recall_levels)
|
|
||||||
print(f"Recall comparison plot saved to {SAVED_PATH}")
|
|
||||||
|
|
||||||
print("\nMethod Summary:")
|
|
||||||
for i, method in enumerate(methods):
|
|
||||||
print(f"{method}:")
|
|
||||||
if i < len(avg_neighbors): # Check index bounds
|
|
||||||
print(f" - Average neighbors per node: {avg_neighbors[i]:.2f}")
|
|
||||||
|
|
||||||
for level in recall_levels:
|
|
||||||
if i < len(recall_lists) and i < len(recompute_lists): # Check index bounds
|
|
||||||
recall_idx = next((idx for idx, recall_val in enumerate(recall_lists[i]) if recall_val >= level), None)
|
|
||||||
if recall_idx is not None:
|
|
||||||
print(f" - Computations needed for {level*100:.0f}% recall: {recompute_lists[i][recall_idx]:.0f}")
|
|
||||||
else:
|
|
||||||
print(f" - Does not reach {level*100:.0f}% recall in the test")
|
|
||||||
else:
|
|
||||||
print(f" - Data missing for recall/recompute lists for method {method}")
|
|
||||||
print()
|
|
||||||
else:
|
|
||||||
print("No data extracted from the log file. Cannot generate plots or summary.")
|
|
||||||
@@ -1,441 +0,0 @@
|
|||||||
import matplotlib.pyplot as plt
|
|
||||||
import seaborn as sns
|
|
||||||
import matplotlib.lines as mlines
|
|
||||||
import pandas as pd
|
|
||||||
import numpy as np
|
|
||||||
from matplotlib.patches import FancyArrowPatch
|
|
||||||
|
|
||||||
sns.set_theme(style="ticks", font_scale=1.2)
|
|
||||||
plt.rcParams['axes.grid'] = True
|
|
||||||
plt.rcParams['axes.grid.which'] = 'major'
|
|
||||||
plt.rcParams['grid.linestyle'] = '--'
|
|
||||||
plt.rcParams['grid.color'] = 'gray'
|
|
||||||
plt.rcParams['grid.alpha'] = 0.3
|
|
||||||
plt.rcParams['xtick.minor.visible'] = False
|
|
||||||
plt.rcParams['ytick.minor.visible'] = False
|
|
||||||
plt.rcParams["font.family"] = "Helvetica"
|
|
||||||
plt.rcParams["text.usetex"] = True
|
|
||||||
plt.rcParams["font.weight"] = "bold"
|
|
||||||
plt.rcParams["axes.labelweight"] = "bold"
|
|
||||||
# Generation(LLama 1B) Generation(LLama 3B) Generation(LLama 7B)
|
|
||||||
# 0.085s 0.217s 0.472s
|
|
||||||
# llm_inference_time=[0.085, 0.217, 0.472, 0] # Will be replaced by CSV data
|
|
||||||
# llm_inference_time_for_mac = [0.316, 0.717, 1.468, 0] # Will be replaced by CSV data
|
|
||||||
|
|
||||||
def parse_latency_data(csv_path):
|
|
||||||
df = pd.read_csv(csv_path)
|
|
||||||
latency_data = {}
|
|
||||||
llm_gen_times = {} # To store LLM generation times: (dataset, hardware) -> time
|
|
||||||
|
|
||||||
for _, row in df.iterrows():
|
|
||||||
dataset = row['Dataset']
|
|
||||||
hardware = row['Hardware']
|
|
||||||
recall_target_str = row['Recall_target'].replace('%', '')
|
|
||||||
try:
|
|
||||||
recall_target = float(recall_target_str)
|
|
||||||
except ValueError:
|
|
||||||
print(f"Warning: Could not parse recall_target '{row['Recall_target']}'. Skipping row.")
|
|
||||||
continue
|
|
||||||
|
|
||||||
if (dataset, hardware) not in llm_gen_times: # Read once per (dataset, hardware)
|
|
||||||
llm_time_val = pd.to_numeric(row.get('LLM_Gen_Time_1B'), errors='coerce')
|
|
||||||
if not pd.isna(llm_time_val):
|
|
||||||
llm_gen_times[(dataset, hardware)] = llm_time_val
|
|
||||||
else:
|
|
||||||
llm_gen_times[(dataset, hardware)] = np.nan # Store NaN if unparsable/missing
|
|
||||||
|
|
||||||
cols_to_skip = ['Dataset', 'Hardware', 'Recall_target',
|
|
||||||
'LLM_Gen_Time_1B', 'LLM_Gen_Time_3B', 'LLM_Gen_Time_7B']
|
|
||||||
|
|
||||||
for col in df.columns:
|
|
||||||
if col not in cols_to_skip:
|
|
||||||
method_name = col
|
|
||||||
key = (dataset, hardware, method_name)
|
|
||||||
if key not in latency_data:
|
|
||||||
latency_data[key] = []
|
|
||||||
try:
|
|
||||||
latency_value = float(row[method_name])
|
|
||||||
latency_data[key].append((recall_target, latency_value))
|
|
||||||
except ValueError:
|
|
||||||
# Handle cases where latency might be non-numeric (e.g., 'N/A' or empty)
|
|
||||||
print(f"Warning: Could not parse latency for {method_name} at {dataset}/{hardware}/Recall {recall_target} ('{row[method_name]}'). Skipping this point.")
|
|
||||||
latency_data[key].append((recall_target, np.nan)) # Or skip appending
|
|
||||||
|
|
||||||
# Sort by recall for consistent plotting
|
|
||||||
for key in latency_data:
|
|
||||||
latency_data[key].sort(key=lambda x: x[0])
|
|
||||||
return latency_data, llm_gen_times
|
|
||||||
|
|
||||||
def parse_storage_data(csv_path):
|
|
||||||
df = pd.read_csv(csv_path)
|
|
||||||
storage_data = {}
|
|
||||||
# Assuming the first column is 'MetricType' (RAM/Storage) and subsequent columns are methods
|
|
||||||
# And the header row is like: MetricType, Method1, Method2, ...
|
|
||||||
# Transpose to make methods as rows for easier lookup might be an option,
|
|
||||||
# but let's try direct parsing.
|
|
||||||
|
|
||||||
# Find the row for RAM and Storage
|
|
||||||
ram_row = df[df.iloc[:, 0] == 'RAM'].iloc[0]
|
|
||||||
storage_row = df[df.iloc[:, 0] == 'Storage'].iloc[0]
|
|
||||||
|
|
||||||
methods = df.columns[1:] # First column is the metric type label
|
|
||||||
for method in methods:
|
|
||||||
storage_data[method] = {
|
|
||||||
'RAM': pd.to_numeric(ram_row[method], errors='coerce'),
|
|
||||||
'Storage': pd.to_numeric(storage_row[method], errors='coerce')
|
|
||||||
}
|
|
||||||
return storage_data
|
|
||||||
|
|
||||||
# Load data
|
|
||||||
latency_csv_path = 'paper_plot/data/main_latency.csv'
|
|
||||||
storage_csv_path = 'paper_plot/data/ram_storage.csv'
|
|
||||||
latency_data, llm_generation_times = parse_latency_data(latency_csv_path)
|
|
||||||
storage_info = parse_storage_data(storage_csv_path)
|
|
||||||
|
|
||||||
# --- Determine unique Datasets and Hardware combinations to plot for ---
|
|
||||||
unique_dataset_hardware_configs = sorted(list(set((d, h) for d, h, m in latency_data.keys())))
|
|
||||||
|
|
||||||
if not unique_dataset_hardware_configs:
|
|
||||||
print("Error: No (Dataset, Hardware) combinations found in latency data. Check CSV paths and content.")
|
|
||||||
exit()
|
|
||||||
|
|
||||||
# --- Define constants for plotting ---
|
|
||||||
all_method_names = sorted(list(set(m for d,h,m in latency_data.keys())))
|
|
||||||
if not all_method_names:
|
|
||||||
# Fallback if latency_data is empty but storage_info might have method names
|
|
||||||
all_method_names = sorted(list(storage_info.keys()))
|
|
||||||
|
|
||||||
if not all_method_names:
|
|
||||||
print("Error: No method names found in data. Cannot proceed with plotting.")
|
|
||||||
exit()
|
|
||||||
|
|
||||||
method_markers = {
|
|
||||||
'HNSW': 'o',
|
|
||||||
'IVF': 'X',
|
|
||||||
'DiskANN': 's',
|
|
||||||
'IVF-Disk': 'P',
|
|
||||||
'IVF-Recompute': '^',
|
|
||||||
'Our': '*',
|
|
||||||
'BM25': "v"
|
|
||||||
# Add more if necessary, or make it dynamic
|
|
||||||
}
|
|
||||||
method_display_names = {
|
|
||||||
'IVF-Recompute': 'IVF-Recompute (EdgeRAG)',
|
|
||||||
# 其他方法保持原名
|
|
||||||
}
|
|
||||||
|
|
||||||
# Ensure all methods have a marker
|
|
||||||
default_markers = ['^', 'v', '<', '>', 'H', 'h', '+', 'x', '|', '_']
|
|
||||||
next_default_marker = 0
|
|
||||||
for mn in all_method_names:
|
|
||||||
if mn not in method_markers:
|
|
||||||
print(f"mn: {mn}")
|
|
||||||
method_markers[mn] = default_markers[next_default_marker % len(default_markers)]
|
|
||||||
next_default_marker +=1
|
|
||||||
|
|
||||||
recall_levels_present = sorted(list(set(r for key in latency_data for r, l in latency_data[key])))
|
|
||||||
# Define colors for up to a few common recall levels, add more if needed
|
|
||||||
base_recall_colors = {
|
|
||||||
85.0: "#1f77b4", # Blue
|
|
||||||
90.0: "#ff7f0e", # Orange
|
|
||||||
95.0: "#2ca02c", # Green
|
|
||||||
# Add more if other recall % values exist
|
|
||||||
}
|
|
||||||
recall_colors = {}
|
|
||||||
color_palette = sns.color_palette("viridis", n_colors=len(recall_levels_present))
|
|
||||||
|
|
||||||
for idx, r_level in enumerate(recall_levels_present):
|
|
||||||
recall_colors[r_level] = base_recall_colors.get(r_level, color_palette[idx % len(color_palette)])
|
|
||||||
|
|
||||||
|
|
||||||
# --- Determine global x (latency) and y (storage) limits for consistent axes ---
|
|
||||||
all_latency_values = []
|
|
||||||
all_storage_values = []
|
|
||||||
raw_data_size = 76 # Raw data size in GB
|
|
||||||
|
|
||||||
for ds_hw_key in unique_dataset_hardware_configs:
|
|
||||||
current_ds, current_hw = ds_hw_key
|
|
||||||
for method_name in all_method_names:
|
|
||||||
# Get storage for this method
|
|
||||||
disk_storage = storage_info.get(method_name, {}).get('Storage', np.nan)
|
|
||||||
if not np.isnan(disk_storage):
|
|
||||||
all_storage_values.append(disk_storage)
|
|
||||||
|
|
||||||
# Get latencies for this method under current_ds, current_hw
|
|
||||||
latency_key = (current_ds, current_hw, method_name)
|
|
||||||
if latency_key in latency_data:
|
|
||||||
for recall, latency in latency_data[latency_key]:
|
|
||||||
if not np.isnan(latency):
|
|
||||||
all_latency_values.append(latency)
|
|
||||||
|
|
||||||
# Add padding to limits
|
|
||||||
min_lat = min(all_latency_values) if all_latency_values else 0.001
|
|
||||||
max_lat = max(all_latency_values) if all_latency_values else 1
|
|
||||||
min_store = min(all_storage_values) if all_storage_values else 0
|
|
||||||
max_store = max(all_storage_values) if all_storage_values else 1
|
|
||||||
|
|
||||||
# Convert storage values to proportion of raw data
|
|
||||||
min_store_proportion = min_store / raw_data_size if all_storage_values else 0
|
|
||||||
max_store_proportion = max_store / raw_data_size if all_storage_values else 0.1
|
|
||||||
|
|
||||||
# Padding for log scale latency - adjust minimum to be more reasonable
|
|
||||||
lat_log_min = -1 # Changed from -2 to -1 to set minimum to 10^-1 (0.1s)
|
|
||||||
lat_log_max = np.log10(max_lat) if max_lat > 0 else 3 # default to 1000 s
|
|
||||||
lat_padding = (lat_log_max - lat_log_min) * 0.05
|
|
||||||
global_xlim = [10**(lat_log_min - lat_padding), 10**(lat_log_max + lat_padding)]
|
|
||||||
if global_xlim[0] <= 0: global_xlim[0] = 0.1 # Changed from 0.01 to 0.1
|
|
||||||
|
|
||||||
# Padding for linear scale storage proportion
|
|
||||||
store_padding = (max_store_proportion - min_store_proportion) * 0.05
|
|
||||||
global_ylim = [max(0, min_store_proportion - store_padding), max_store_proportion + store_padding]
|
|
||||||
if global_ylim[0] >= global_ylim[1]: # Avoid inverted or zero range
|
|
||||||
global_ylim[1] = global_ylim[0] + 0.1
|
|
||||||
|
|
||||||
# After loading the data and before plotting, add this code to reorder the datasets
|
|
||||||
# Find where you define all_datasets (around line 95)
|
|
||||||
|
|
||||||
# Original code:
|
|
||||||
all_datasets = sorted(list(set(ds for ds, _ in unique_dataset_hardware_configs)))
|
|
||||||
|
|
||||||
# Replace with this to specify the exact order:
|
|
||||||
all_datasets_unsorted = list(set(ds for ds, _ in unique_dataset_hardware_configs))
|
|
||||||
desired_order = ['NQ', 'TriviaQA', 'GPQA','HotpotQA']
|
|
||||||
all_datasets = [ds for ds in desired_order if ds in all_datasets_unsorted]
|
|
||||||
# Add any datasets that might be in the data but not in our desired_order list
|
|
||||||
all_datasets.extend([ds for ds in all_datasets_unsorted if ds not in desired_order])
|
|
||||||
|
|
||||||
# Then the rest of your code remains the same:
|
|
||||||
a10_configs = [(ds, 'A10') for ds in all_datasets if (ds, 'A10') in unique_dataset_hardware_configs]
|
|
||||||
mac_configs = [(ds, 'MAC') for ds in all_datasets if (ds, 'MAC') in unique_dataset_hardware_configs]
|
|
||||||
|
|
||||||
# Create two figures - one for A10 and one for MAC
|
|
||||||
hardware_configs = [a10_configs, mac_configs]
|
|
||||||
hardware_names = ['A10', 'MAC']
|
|
||||||
|
|
||||||
for fig_idx, configs_for_this_figure in enumerate(hardware_configs):
|
|
||||||
if not configs_for_this_figure:
|
|
||||||
continue
|
|
||||||
|
|
||||||
num_cols_this_figure = len(configs_for_this_figure)
|
|
||||||
# 1 row, num_cols_this_figure columns
|
|
||||||
fig, axs = plt.subplots(1, num_cols_this_figure, figsize=(7 * num_cols_this_figure, 6), sharex=True, sharey=True, squeeze=False)
|
|
||||||
|
|
||||||
# fig.suptitle(f"Latency vs. Storage ({hardware_names[fig_idx]})", fontsize=18, y=0.98)
|
|
||||||
|
|
||||||
for subplot_idx, (current_ds, current_hw) in enumerate(configs_for_this_figure):
|
|
||||||
ax = axs[0, subplot_idx] # Accessing column in the first row
|
|
||||||
ax.set_title(f"{current_ds}", fontsize=25) # No need to show hardware in title since it's in suptitle
|
|
||||||
|
|
||||||
for method_name in all_method_names:
|
|
||||||
marker = method_markers.get(method_name, '+')
|
|
||||||
disk_storage = storage_info.get(method_name, {}).get('Storage', np.nan)
|
|
||||||
|
|
||||||
latency_points_key = (current_ds, current_hw, method_name)
|
|
||||||
if latency_points_key in latency_data:
|
|
||||||
points_for_method = latency_data[latency_points_key]
|
|
||||||
print(f"points_for_method: {points_for_method}")
|
|
||||||
for recall, latency in points_for_method:
|
|
||||||
# Only skip if latency is invalid (since we need log scale for x-axis)
|
|
||||||
# But allow zero storage since y-axis is now linear
|
|
||||||
if np.isnan(latency) or np.isnan(disk_storage) or latency <= 0:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Add LLM generation time from CSV
|
|
||||||
current_llm_add_time = llm_generation_times.get((current_ds, current_hw))
|
|
||||||
if current_llm_add_time is not None and not np.isnan(current_llm_add_time):
|
|
||||||
latency = latency + current_llm_add_time
|
|
||||||
else:
|
|
||||||
raise ValueError(f"No LLM generation time found for {current_ds} on {current_hw}")
|
|
||||||
|
|
||||||
# Special handling for BM25
|
|
||||||
if method_name == 'BM25':
|
|
||||||
# BM25 is only valid for 85% recall points (other points are 0)
|
|
||||||
if recall != 85.0:
|
|
||||||
continue
|
|
||||||
color = 'grey'
|
|
||||||
else:
|
|
||||||
# Use the color for target recall
|
|
||||||
color = recall_colors.get(recall, 'grey')
|
|
||||||
|
|
||||||
# Convert storage to proportion
|
|
||||||
disk_storage_proportion = disk_storage / raw_data_size
|
|
||||||
size = 80
|
|
||||||
|
|
||||||
x_offset = -50
|
|
||||||
if current_ds == 'GPQA':
|
|
||||||
x_offset = -32
|
|
||||||
|
|
||||||
# Apply a small vertical offset to IVF-Recompute points to make them more visible
|
|
||||||
if method_name == 'IVF-Recompute':
|
|
||||||
# Add a small vertical offset (adjust the 0.05 value as needed)
|
|
||||||
disk_storage_proportion += 0.07
|
|
||||||
size = 80
|
|
||||||
if method_name == 'DiskANN':
|
|
||||||
size = 50
|
|
||||||
if method_name == 'Our':
|
|
||||||
size = 140
|
|
||||||
disk_storage_proportion += 0.05
|
|
||||||
# Add "Pareto Frontier" label to Our method points
|
|
||||||
|
|
||||||
if recall == 95:
|
|
||||||
ax.annotate('Ours',
|
|
||||||
(latency, disk_storage_proportion),
|
|
||||||
xytext=(x_offset, 25), # Increased leftward offset from -65 to -120
|
|
||||||
textcoords='offset points',
|
|
||||||
fontsize=20,
|
|
||||||
color='red',
|
|
||||||
weight='bold',
|
|
||||||
bbox=dict(boxstyle="round,pad=0.3", fc="white", ec="red", alpha=0.7))
|
|
||||||
# Increase size for BM25 points
|
|
||||||
if method_name == 'BM25':
|
|
||||||
size = 70
|
|
||||||
size*=5
|
|
||||||
|
|
||||||
ax.scatter(latency, disk_storage_proportion, marker=marker, color=color,
|
|
||||||
s=size, alpha=0.85, edgecolors='black', linewidths=0.7)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
ax.set_xscale("log")
|
|
||||||
ax.set_yscale("linear") # CHANGED from log scale to linear scale for Y-axis
|
|
||||||
|
|
||||||
# Generate appropriate powers of 10 based on your data range
|
|
||||||
min_power = -1
|
|
||||||
max_power = 4
|
|
||||||
log_ticks = [10**i for i in range(min_power, max_power+1)]
|
|
||||||
|
|
||||||
# Set custom tick positions
|
|
||||||
ax.set_xticks(log_ticks)
|
|
||||||
|
|
||||||
# Create custom bold LaTeX labels with 10^n format
|
|
||||||
log_tick_labels = [fr'$\mathbf{{10^{{{i}}}}}$' for i in range(min_power, max_power+1)]
|
|
||||||
ax.set_xticklabels(log_tick_labels, fontsize=24)
|
|
||||||
|
|
||||||
# Apply global limits
|
|
||||||
if subplot_idx == 0:
|
|
||||||
ax.set_xlim(global_xlim)
|
|
||||||
ax.set_ylim(global_ylim)
|
|
||||||
|
|
||||||
ax.grid(True, which="major", linestyle="--", linewidth=0.6, alpha=0.7)
|
|
||||||
# Remove minor grid lines completely
|
|
||||||
ax.grid(False, which="minor")
|
|
||||||
|
|
||||||
# Remove ticks
|
|
||||||
# First set the shared parameters for both axes
|
|
||||||
ax.tick_params(axis='both', which='both', length=0, labelsize=24)
|
|
||||||
|
|
||||||
# Then set the padding only for the x-axis
|
|
||||||
ax.tick_params(axis='x', which='both', pad=10)
|
|
||||||
|
|
||||||
if subplot_idx == 0: # Y-label only for the leftmost subplot
|
|
||||||
ax.set_ylabel("Proportional Size", fontsize=24)
|
|
||||||
|
|
||||||
# X-label for all subplots in a 1xN layout can be okay, or just the middle/last one.
|
|
||||||
# Let's put it on all for now.
|
|
||||||
ax.set_xlabel("Latency (s)", fontsize=25)
|
|
||||||
|
|
||||||
# Display 100%, 200%, 300% for yaxis
|
|
||||||
ax.set_yticks([1, 2, 3])
|
|
||||||
ax.set_yticklabels(['100\%', '200\\%', '300\\%'])
|
|
||||||
|
|
||||||
# Create a custom arrow with "Better" text inside
|
|
||||||
# Create the arrow patch with a wider shaft
|
|
||||||
arrow = FancyArrowPatch(
|
|
||||||
(0.8, 0.8), # Start point (top-right)
|
|
||||||
(0.65, 0.6), # End point (toward bottom-left)
|
|
||||||
transform=ax.transAxes,
|
|
||||||
arrowstyle='simple,head_width=40,head_length=35,tail_width=20', # Increased arrow dimensions
|
|
||||||
facecolor='white',
|
|
||||||
edgecolor='black',
|
|
||||||
linewidth=3, # Thicker outline
|
|
||||||
zorder=5
|
|
||||||
)
|
|
||||||
|
|
||||||
# Add the arrow to the plot
|
|
||||||
ax.add_patch(arrow)
|
|
||||||
|
|
||||||
# Calculate the midpoint of the arrow for text placement
|
|
||||||
mid_x = (0.8 + 0.65) / 2 + 0.002 + 0.01
|
|
||||||
mid_y = (0.8 + 0.6) / 2 + 0.01
|
|
||||||
|
|
||||||
# Add the "Better" text at the midpoint of the arrow
|
|
||||||
ax.text(mid_x, mid_y, 'Better',
|
|
||||||
transform=ax.transAxes,
|
|
||||||
ha='center',
|
|
||||||
va='center',
|
|
||||||
fontsize=16, # Increased font size from 12 to 16
|
|
||||||
fontweight='bold',
|
|
||||||
rotation=40, # Rotate to match arrow direction
|
|
||||||
zorder=6) # Ensure text is on top of arrow
|
|
||||||
|
|
||||||
# Create legends (once per figure)
|
|
||||||
method_legend_handles = []
|
|
||||||
for method, marker_style in method_markers.items():
|
|
||||||
if method in all_method_names:
|
|
||||||
print(f"method: {method}")
|
|
||||||
# Use black color for BM25 in the legend
|
|
||||||
if method == 'BM25':
|
|
||||||
method_legend_handles.append(mlines.Line2D([], [], color='black', marker=marker_style, linestyle='None',
|
|
||||||
markersize=10, label=method))
|
|
||||||
else:
|
|
||||||
if method in method_display_names:
|
|
||||||
method = method_display_names[method]
|
|
||||||
method_legend_handles.append(mlines.Line2D([], [], color='black', marker=marker_style, linestyle='None',
|
|
||||||
markersize=10, label=method))
|
|
||||||
|
|
||||||
recall_legend_handles = []
|
|
||||||
sorted_recall_levels = sorted(recall_colors.keys())
|
|
||||||
for r_level in sorted_recall_levels:
|
|
||||||
recall_legend_handles.append(mlines.Line2D([], [], color=recall_colors[r_level], marker='o', linestyle='None',
|
|
||||||
markersize=20, label=f"Target Recall={r_level:.0f}\%"))
|
|
||||||
|
|
||||||
# 将图例分成两行:第一行是方法,第二行是召回率
|
|
||||||
if fig_idx == 0:
|
|
||||||
# 从方法列表中先排除'Our'
|
|
||||||
other_methods = [m for m in all_method_names if m != 'Our']
|
|
||||||
# 按照需要的顺序创建方法列表(将'Our'放在最后)
|
|
||||||
ordered_methods = other_methods + (['Our'] if 'Our' in all_method_names else [])
|
|
||||||
|
|
||||||
# 按照新顺序创建方法图例句柄
|
|
||||||
method_legend_handles = []
|
|
||||||
for method in ordered_methods:
|
|
||||||
if method in method_markers:
|
|
||||||
marker_style = method_markers[method]
|
|
||||||
# 使用显示名称映射
|
|
||||||
display_name = method_display_names.get(method, method)
|
|
||||||
color = 'black'
|
|
||||||
marker_size = 22
|
|
||||||
if method == 'Our':
|
|
||||||
marker_size = 27
|
|
||||||
elif 'IVF-Recompute' in method or 'EdgeRAG' in method:
|
|
||||||
marker_size = 17
|
|
||||||
elif 'DiskANN' in method:
|
|
||||||
marker_size = 19
|
|
||||||
elif 'BM25' in method:
|
|
||||||
marker_size = 20
|
|
||||||
method_legend_handles.append(mlines.Line2D([], [], color=color, marker=marker_style,
|
|
||||||
linestyle='None', markersize=marker_size, label=display_name))
|
|
||||||
|
|
||||||
# 创建召回率图例(第二行)- 注意位置调整,放在方法图例下方
|
|
||||||
recall_legend = fig.legend(handles=recall_legend_handles,
|
|
||||||
loc='upper center', bbox_to_anchor=(0.5, 1.05), # y坐标降低,放在第一行下方
|
|
||||||
ncol=len(recall_legend_handles), fontsize=28)
|
|
||||||
|
|
||||||
|
|
||||||
# 创建方法图例(第一行)
|
|
||||||
method_legend = fig.legend(handles=method_legend_handles,
|
|
||||||
loc='upper center', bbox_to_anchor=(0.5, 0.91),
|
|
||||||
ncol=len(method_legend_handles), fontsize=28)
|
|
||||||
|
|
||||||
# 添加图例到渲染器
|
|
||||||
fig.add_artist(method_legend)
|
|
||||||
fig.add_artist(recall_legend)
|
|
||||||
|
|
||||||
# 调整布局,为顶部的两行图例留出更多空间
|
|
||||||
plt.tight_layout(rect=(0, 0, 1.0, 0.74)) # 顶部空间从0.9调整到0.85,给两行图例留出更多空间
|
|
||||||
|
|
||||||
save_path = f'./paper_plot/figures/main_exp_fig_{fig_idx+1}.pdf'
|
|
||||||
plt.savefig(save_path, dpi=300, bbox_inches='tight')
|
|
||||||
print(f"Saved figure {fig_idx+1} to {save_path}")
|
|
||||||
plt.show()
|
|
||||||
@@ -1,163 +0,0 @@
|
|||||||
import csv
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
import csv
|
|
||||||
|
|
||||||
plt.rcParams["font.family"] = "Helvetica"
|
|
||||||
plt.rcParams["ytick.direction"] = "in"
|
|
||||||
plt.rcParams["hatch.linewidth"] = 1
|
|
||||||
plt.rcParams["font.weight"] = "bold"
|
|
||||||
plt.rcParams["axes.labelweight"] = "bold"
|
|
||||||
plt.rcParams["text.usetex"] = True
|
|
||||||
SAVE_PTH = "./paper_plot/figures"
|
|
||||||
font_size = 16
|
|
||||||
|
|
||||||
# Generation(LLama 1B) Generation(LLama 3B) Generation(LLama 7B)
|
|
||||||
# 0.085s 0.217s 0.472s
|
|
||||||
llm_inference_time=[0.085, 0.217, 0.472, 0]
|
|
||||||
|
|
||||||
USE_LLM_INDEX = 3 # +0
|
|
||||||
|
|
||||||
file_path = "./paper_plot/data/main_latency.csv"
|
|
||||||
|
|
||||||
with open(file_path, mode="r", newline="") as file:
|
|
||||||
reader = csv.reader(file)
|
|
||||||
data = list(reader)
|
|
||||||
|
|
||||||
# 打印原始数据
|
|
||||||
for row in data:
|
|
||||||
print(",".join(row))
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
models = ["A10", "MAC"]
|
|
||||||
datasets = ["NQ", "TriviaQA", "GPQA", "HotpotQA"]
|
|
||||||
data = [[float(cell) if cell.isdigit() else cell for cell in row] for row in data[1:]]
|
|
||||||
for k, model in enumerate(models):
|
|
||||||
|
|
||||||
fig, axes = plt.subplots(1, 4)
|
|
||||||
fig.set_size_inches(20, 3)
|
|
||||||
plt.subplots_adjust(wspace=0, hspace=0)
|
|
||||||
|
|
||||||
total_width, n = 6, 6
|
|
||||||
group = 1
|
|
||||||
width = total_width * 0.9 / n
|
|
||||||
x = np.arange(group) * n
|
|
||||||
exit_idx_x = x + (total_width - width) / n
|
|
||||||
edgecolors = ["dimgrey", "#63B8B6", "tomato", "slategray", "mediumpurple", "green", "red", "blue", "yellow", "silver"]
|
|
||||||
# hatches = ["", "\\\\", "//", "||", "x", "--", "..", "", "\\\\", "//", "||", "x", "--", ".."]
|
|
||||||
hatches =["\\\\\\","\\\\"]
|
|
||||||
|
|
||||||
labels = [
|
|
||||||
"HNSW",
|
|
||||||
"IVF",
|
|
||||||
"DiskANN",
|
|
||||||
"IVF-Disk",
|
|
||||||
"IVF-Recompute",
|
|
||||||
"Our",
|
|
||||||
# "DGL-OnDisk",
|
|
||||||
]
|
|
||||||
if k == 0:
|
|
||||||
x_labels = "GraphSAGE"
|
|
||||||
else:
|
|
||||||
x_labels = "GAT"
|
|
||||||
|
|
||||||
yticks = [0.01, 0.1, 1, 10, 100, 1000,10000] # Log scale ticks
|
|
||||||
val_limit = 15000 # Upper limit for the plot
|
|
||||||
|
|
||||||
for i in range(4):
|
|
||||||
axes[i].set_yscale('log') # Set y-axis to logarithmic scale
|
|
||||||
axes[i].set_yticks(yticks)
|
|
||||||
axes[i].set_ylim(0.01, val_limit) # Lower limit should be > 0 for log scale
|
|
||||||
|
|
||||||
axes[i].tick_params(axis="y", labelsize=10)
|
|
||||||
|
|
||||||
axes[i].set_xticks([])
|
|
||||||
# axes[i].set_xticklabels()
|
|
||||||
axes[i].set_xlabel(datasets[i], fontsize=font_size)
|
|
||||||
axes[i].grid(axis="y", linestyle="--")
|
|
||||||
axes[i].set_xlim(exit_idx_x[0] - 0.15 * width - 0.2, exit_idx_x[0] + (n-0.25)* width + 0.2)
|
|
||||||
for j in range(n):
|
|
||||||
##TODO add label
|
|
||||||
|
|
||||||
# num = float(data[i * 2 + k][j + 3])
|
|
||||||
# plot_label = [num]
|
|
||||||
# if j == 6 and i == 3:
|
|
||||||
# plot_label = ["N/A"]
|
|
||||||
# num = 0
|
|
||||||
local_hatches=["////","\\\\","xxxx"]
|
|
||||||
# here add 3 bars rather than one bar TODO
|
|
||||||
print('exit_idx_x',exit_idx_x)
|
|
||||||
|
|
||||||
# Check if all three models for this algorithm are OOM (data = 0)
|
|
||||||
is_oom = True
|
|
||||||
for m in range(3):
|
|
||||||
if float(data[i * 6 + k*3 + m][j + 3]) != 0:
|
|
||||||
is_oom = False
|
|
||||||
break
|
|
||||||
|
|
||||||
if is_oom:
|
|
||||||
# Draw a cross for OOM instead of bars
|
|
||||||
pos = exit_idx_x + j * width + width * 0.3 # Center position for cross
|
|
||||||
marker_size = width * 150 # Size of the cross
|
|
||||||
axes[i].scatter(pos, 0.02, marker='x', color=edgecolors[j], s=marker_size,
|
|
||||||
linewidth=4, label=labels[j] if j < len(labels) else "", zorder=20)
|
|
||||||
else:
|
|
||||||
# Create three separate bar calls instead of trying to plot multiple bars at once
|
|
||||||
for m in range(3):
|
|
||||||
num = float(data[i * 6 + k*3 +m][j + 3]) +llm_inference_time[USE_LLM_INDEX]
|
|
||||||
plot_label = [num]
|
|
||||||
pos = exit_idx_x + j * width + width * 0.3 * m
|
|
||||||
print(f"j: {j}, m: {m}, pos: {pos}")
|
|
||||||
# For log scale, we need to ensure values are positive
|
|
||||||
plot_value = max(0.01, num) if num < val_limit else val_limit
|
|
||||||
container = axes[i].bar(
|
|
||||||
pos,
|
|
||||||
plot_value,
|
|
||||||
width=width * 0.3,
|
|
||||||
color="white",
|
|
||||||
edgecolor=edgecolors[j],
|
|
||||||
# edgecolor="k",
|
|
||||||
hatch=local_hatches[m], # Use different hatches for each of the 3 bars
|
|
||||||
linewidth=1.0,
|
|
||||||
label=labels[j] if m == 0 else "", # Only add label for the first bar
|
|
||||||
zorder=10,
|
|
||||||
)
|
|
||||||
# axes[i].bar_label(
|
|
||||||
# container,
|
|
||||||
# plot_label,
|
|
||||||
# fontsize=font_size - 2,
|
|
||||||
# zorder=200,
|
|
||||||
# fontweight="bold",
|
|
||||||
# )
|
|
||||||
|
|
||||||
if k == 0:
|
|
||||||
axes[0].legend(
|
|
||||||
bbox_to_anchor=(3.25, 1.02),
|
|
||||||
ncol=7,
|
|
||||||
loc="lower right",
|
|
||||||
# fontsize=font_size,
|
|
||||||
# markerscale=3,
|
|
||||||
labelspacing=0.2,
|
|
||||||
edgecolor="black",
|
|
||||||
facecolor="white",
|
|
||||||
framealpha=1,
|
|
||||||
shadow=False,
|
|
||||||
# fancybox=False,
|
|
||||||
handlelength=2,
|
|
||||||
handletextpad=0.5,
|
|
||||||
columnspacing=0.5,
|
|
||||||
prop={"weight": "bold", "size": font_size},
|
|
||||||
).set_zorder(100)
|
|
||||||
|
|
||||||
axes[0].set_ylabel("Runtime (log scale)", fontsize=font_size, fontweight="bold")
|
|
||||||
axes[0].set_yticklabels([r"$10^{-2}$", r"$10^{-1}$", r"$10^{0}$", r"$10^{1}$", r"$10^{2}$", r"$10^{3}$",r"$10^{4}$"], fontsize=font_size)
|
|
||||||
axes[1].set_yticklabels([])
|
|
||||||
axes[2].set_yticklabels([])
|
|
||||||
axes[3].set_yticklabels([])
|
|
||||||
|
|
||||||
plt.savefig(f"{SAVE_PTH }/speed_{model}_revised.pdf", bbox_inches="tight", dpi=300)
|
|
||||||
## print save
|
|
||||||
print(f"{SAVE_PTH }/speed_{model}_revised.pdf")
|
|
||||||
@@ -1,85 +0,0 @@
|
|||||||
from matplotlib import pyplot as plt
|
|
||||||
from matplotlib.gridspec import GridSpec
|
|
||||||
|
|
||||||
# Comment Test
|
|
||||||
|
|
||||||
# om script.settings import DATA_PATH, FIGURE_PATH
|
|
||||||
# DATA_PATH ="/home/ubuntu/Power-RAG/paper_plot/data"
|
|
||||||
# FIGURE_PATH = "/home/ubuntu/Power-RAG/paper_plot/figures"
|
|
||||||
plt.rcParams["font.family"] = "Helvetica"
|
|
||||||
plt.rcParams["ytick.direction"] = "in"
|
|
||||||
plt.rcParams["hatch.linewidth"] = 2
|
|
||||||
plt.rcParams["font.weight"] = "bold"
|
|
||||||
plt.rcParams["axes.labelweight"] = "bold"
|
|
||||||
plt.rcParams["text.usetex"] = True
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import pandas as pd
|
|
||||||
|
|
||||||
# Load the RAM and Storage data directly from CSV
|
|
||||||
data = pd.read_csv("./paper_plot/data/ram_storage.csv")
|
|
||||||
|
|
||||||
# Explicitly reorder columns to ensure "Our" is at the end
|
|
||||||
cols = list(data.columns)
|
|
||||||
if "Our" in cols and cols[-1] != "Our":
|
|
||||||
cols.remove("Our")
|
|
||||||
cols.append("Our")
|
|
||||||
data = data[cols]
|
|
||||||
|
|
||||||
# Set up the figure with two columns
|
|
||||||
fig = plt.figure(figsize=(12, 3))
|
|
||||||
gs = GridSpec(1, 2, figure=fig)
|
|
||||||
ax1 = fig.add_subplot(gs[0, 0]) # Left panel for RAM
|
|
||||||
ax2 = fig.add_subplot(gs[0, 1]) # Right panel for Storage
|
|
||||||
|
|
||||||
# Define the visual style elements
|
|
||||||
edgecolors = ["dimgrey", "#63B8B6", "tomato", "slategray", "silver", "navy"]
|
|
||||||
hatches = ["/////", "\\\\\\\\\\"]
|
|
||||||
|
|
||||||
# Calculate positions for the bars
|
|
||||||
methods = data.columns[1:] # Skip the 'Hardware' column
|
|
||||||
num_methods = len(methods)
|
|
||||||
# Reverse the order of methods for display (to have "Our" at the bottom)
|
|
||||||
methods = list(methods)[::-1]
|
|
||||||
y_positions = np.arange(num_methods)
|
|
||||||
bar_width = 0.6
|
|
||||||
|
|
||||||
# Plot RAM data in left panel
|
|
||||||
ram_bars = ax1.barh(
|
|
||||||
y_positions,
|
|
||||||
data.iloc[0, 1:].values[::-1], # Reverse the data to match reversed methods
|
|
||||||
height=bar_width,
|
|
||||||
color="white",
|
|
||||||
edgecolor=edgecolors[0],
|
|
||||||
hatch=hatches[0],
|
|
||||||
linewidth=1.0,
|
|
||||||
label="RAM",
|
|
||||||
zorder=10,
|
|
||||||
)
|
|
||||||
ax1.set_title("RAM Usage", fontsize=14, fontweight='bold')
|
|
||||||
ax1.set_yticks(y_positions)
|
|
||||||
ax1.set_yticklabels(methods, fontsize=14)
|
|
||||||
ax1.set_xlabel("Size (\\textit{GB})", fontsize=14)
|
|
||||||
ax1.xaxis.set_tick_params(labelsize=14)
|
|
||||||
|
|
||||||
# Plot Storage data in right panel
|
|
||||||
storage_bars = ax2.barh(
|
|
||||||
y_positions,
|
|
||||||
data.iloc[1, 1:].values[::-1], # Reverse the data to match reversed methods
|
|
||||||
height=bar_width,
|
|
||||||
color="white",
|
|
||||||
edgecolor=edgecolors[1],
|
|
||||||
hatch=hatches[1],
|
|
||||||
linewidth=1.0,
|
|
||||||
label="Storage",
|
|
||||||
zorder=10,
|
|
||||||
)
|
|
||||||
ax2.set_title("Storage Usage", fontsize=14, fontweight='bold')
|
|
||||||
ax2.set_yticks(y_positions)
|
|
||||||
ax2.set_yticklabels(methods, fontsize=14)
|
|
||||||
ax2.set_xlabel("Size (\\textit{GB})", fontsize=14)
|
|
||||||
ax2.xaxis.set_tick_params(labelsize=14)
|
|
||||||
|
|
||||||
plt.tight_layout()
|
|
||||||
plt.savefig("./paper_plot/figures/ram_storage_double_column.pdf", bbox_inches="tight", dpi=300)
|
|
||||||
print("Saving the figure to ./paper_plot/figures/ram_storage_double_column.pdf")
|
|
||||||
@@ -1,141 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
# -*- coding:utf-8 -*-
|
|
||||||
# Motto: Were It to Benefit My Country, I Would Lay Down My Life!
|
|
||||||
# \file: /bottleneck_breakdown.py
|
|
||||||
# \brief: Illustrates the query time bottleneck on consumer devices (Final Version - Font & Legend Adjust).
|
|
||||||
# Author: Gemini Assistant (adapted from user's style and feedback)
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
from matplotlib.ticker import FuncFormatter # Not strictly needed for just font, but imported if user wants to try
|
|
||||||
|
|
||||||
# Set matplotlib styles similar to the example
|
|
||||||
plt.rcParams["font.family"] = "Helvetica" # Primary font family
|
|
||||||
plt.rcParams["ytick.direction"] = "in"
|
|
||||||
plt.rcParams["xtick.direction"] = "in"
|
|
||||||
plt.rcParams["hatch.linewidth"] = 1.0
|
|
||||||
plt.rcParams["font.weight"] = "bold"
|
|
||||||
plt.rcParams["axes.labelweight"] = "bold"
|
|
||||||
|
|
||||||
plt.rcParams["text.usetex"] = True
|
|
||||||
# Attempt to make LaTeX use Helvetica as the main font
|
|
||||||
plt.rcParams['text.latex.preamble'] = r"""
|
|
||||||
\usepackage{helvet} % helvetica font
|
|
||||||
\usepackage{sansmath} % helvetica for math
|
|
||||||
\sansmath % activate sansmath
|
|
||||||
\renewcommand{\familydefault}{\sfdefault} % make sans-serif the default family
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
# Final Data for the breakdown (3 Segments)
|
|
||||||
labels_raw = [ # Raw labels before potential LaTeX escaping
|
|
||||||
'IO: Text + PQ Lookup',
|
|
||||||
'CPU: Tokenize + Distance Compute',
|
|
||||||
'GPU: Embedding Recompute',
|
|
||||||
]
|
|
||||||
# Times in ms, ordered for stacking
|
|
||||||
times_ms = np.array([
|
|
||||||
8.009, # Quantization
|
|
||||||
16.197, # Search
|
|
||||||
76.512, # Embedding Recomputation
|
|
||||||
])
|
|
||||||
|
|
||||||
total_time_ms = times_ms.sum()
|
|
||||||
percentages = (times_ms / total_time_ms) * 100
|
|
||||||
|
|
||||||
# Prepare labels for legend, escaping for LaTeX if active
|
|
||||||
labels_legend = []
|
|
||||||
# st1 = r'&' # Not needed as current labels_raw don't have '&'
|
|
||||||
for label, time, perc in zip(labels_raw, times_ms, percentages):
|
|
||||||
# Construct the percentage string carefully for LaTeX
|
|
||||||
perc_str = f"{perc:.1f}" + r"\%" # Correct way to form 'NN.N\%'
|
|
||||||
# label_tex = label.replace('&', st1) # Use if '&' is in labels_raw
|
|
||||||
label_tex = label # Current labels_raw are clean for LaTeX
|
|
||||||
labels_legend.append(
|
|
||||||
f"{label_tex}\n({time:.1f}ms, {perc_str})"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Styling based on user's script
|
|
||||||
# Using first 3 from the provided lists
|
|
||||||
edgecolors_list = ["dimgrey", "#63B8B6", "tomato", "silver", "slategray"]
|
|
||||||
hatches_list = ["/////", "xxxxx", "\\\\\\\\\\"]
|
|
||||||
|
|
||||||
edgecolors = edgecolors_list[:3]
|
|
||||||
hatches = hatches_list[:3]
|
|
||||||
fill_color = "white"
|
|
||||||
|
|
||||||
# Create the figure and axes
|
|
||||||
# Adjusted figure size to potentially accommodate legend on the right
|
|
||||||
fig, ax = plt.subplots()
|
|
||||||
fig.set_size_inches(7, 1.5) # Width increased slightly, height adjusted
|
|
||||||
# Adjusted right margin for external legend, bottom for x-label
|
|
||||||
plt.subplots_adjust(left=0.12, right=0.72, top=0.95, bottom=0.25)
|
|
||||||
|
|
||||||
# Create the horizontal stacked bar
|
|
||||||
bar_height = 0.2
|
|
||||||
y_pos = 0
|
|
||||||
|
|
||||||
left_offset = 0
|
|
||||||
for i in range(len(times_ms)):
|
|
||||||
ax.barh(
|
|
||||||
y_pos,
|
|
||||||
times_ms[i],
|
|
||||||
height=bar_height,
|
|
||||||
left=left_offset,
|
|
||||||
color=fill_color,
|
|
||||||
edgecolor=edgecolors[i],
|
|
||||||
hatch=hatches[i],
|
|
||||||
linewidth=1.5,
|
|
||||||
label=labels_legend[i],
|
|
||||||
zorder=10
|
|
||||||
)
|
|
||||||
text_x_pos = left_offset + times_ms[i] / 2
|
|
||||||
if times_ms[i] > total_time_ms * 0.03: # Threshold for displaying text
|
|
||||||
ax.text(
|
|
||||||
text_x_pos,
|
|
||||||
y_pos,
|
|
||||||
f"{times_ms[i]:.1f}ms",
|
|
||||||
ha='center',
|
|
||||||
va='center',
|
|
||||||
fontsize=8,
|
|
||||||
fontweight='bold',
|
|
||||||
color='black',
|
|
||||||
zorder=20,
|
|
||||||
bbox=dict(facecolor='white', edgecolor='none', pad=0.5, alpha=0.8)
|
|
||||||
)
|
|
||||||
left_offset += times_ms[i]
|
|
||||||
|
|
||||||
# Set plot limits and labels
|
|
||||||
ax.set_xlim([0, total_time_ms * 1.02])
|
|
||||||
ax.set_xlabel("Time (ms)", fontsize=14, fontweight='bold', x=0.75, )
|
|
||||||
|
|
||||||
# Y-axis: Remove y-ticks and labels
|
|
||||||
ax.set_yticks([])
|
|
||||||
ax.set_yticklabels([])
|
|
||||||
|
|
||||||
# Legend: Placed to the right of the plot
|
|
||||||
ax.legend(
|
|
||||||
# (x, y) for anchor, (0,0) is bottom left, (1,1) is top right of AXES
|
|
||||||
# To place outside on the right, x should be > 1
|
|
||||||
bbox_to_anchor=(1.03, 0.5), # x > 1 means outside to the right, y=0.5 for vertical center
|
|
||||||
ncol=1, # Single column for a taller, narrower legend
|
|
||||||
loc="center left", # Anchor the legend's left-center to bbox_to_anchor point
|
|
||||||
labelspacing=0.5, # Adjust spacing
|
|
||||||
edgecolor="black",
|
|
||||||
facecolor="white",
|
|
||||||
framealpha=1,
|
|
||||||
shadow=False,
|
|
||||||
fancybox=False,
|
|
||||||
handlelength=1.5,
|
|
||||||
handletextpad=0.6,
|
|
||||||
columnspacing=1.5,
|
|
||||||
prop={"weight": "bold", "size": 9},
|
|
||||||
).set_zorder(100)
|
|
||||||
|
|
||||||
# Save the figure (using the original generic name as requested)
|
|
||||||
output_filename = "./bottleneck_breakdown.pdf"
|
|
||||||
# plt.tight_layout() # tight_layout might conflict with external legend; adjust subplots_adjust instead
|
|
||||||
plt.savefig(output_filename, bbox_inches="tight", dpi=300)
|
|
||||||
print(f"Saved plot to {output_filename}")
|
|
||||||
|
|
||||||
# plt.show() # Uncomment to display plot interactively
|
|
||||||
@@ -1,226 +0,0 @@
|
|||||||
import matplotlib.pyplot as plt
|
|
||||||
import numpy as np
|
|
||||||
# import matplotlib.ticker as mticker # Not actively used
|
|
||||||
import os
|
|
||||||
|
|
||||||
FIGURE_PATH = "paper_plot/figures"
|
|
||||||
|
|
||||||
try:
|
|
||||||
os.makedirs(FIGURE_PATH, exist_ok=True)
|
|
||||||
print(f"Images will be saved to: {os.path.abspath(FIGURE_PATH)}")
|
|
||||||
except OSError as e:
|
|
||||||
print(f"Create {FIGURE_PATH} failed: {e}. Images will be saved in the current working directory.")
|
|
||||||
FIGURE_PATH = "."
|
|
||||||
|
|
||||||
plt.rcParams["font.family"] = "Helvetica"
|
|
||||||
plt.rcParams["ytick.direction"] = "in"
|
|
||||||
plt.rcParams["hatch.linewidth"] = 2
|
|
||||||
plt.rcParams["font.weight"] = "bold"
|
|
||||||
plt.rcParams["axes.labelweight"] = "bold"
|
|
||||||
plt.rcParams["text.usetex"] = True
|
|
||||||
|
|
||||||
method_labels = ["gte-small (33M)", "contriever-msmarco (110M)"]
|
|
||||||
dataset_names = ["NQ", "TriviaQA"]
|
|
||||||
metrics_plot1 = ["Exact Match", "F1"]
|
|
||||||
|
|
||||||
small_nq_f1 = 0.2621040899
|
|
||||||
small_tq_f1 = 0.4698198059
|
|
||||||
small_nq_em_score = 0.1845
|
|
||||||
small_tq_em_score = 0.4015
|
|
||||||
small_nq_time = 1.137
|
|
||||||
small_tq_time = 1.173
|
|
||||||
|
|
||||||
large_nq_f1 = 0.2841386117
|
|
||||||
large_tq_f1 = 0.4548340289
|
|
||||||
large_nq_em_score = 0.206
|
|
||||||
large_tq_em_score = 0.382
|
|
||||||
large_nq_time = 2.632
|
|
||||||
large_tq_time = 2.684
|
|
||||||
|
|
||||||
data_scores_plot1 = {
|
|
||||||
"NQ": {"Exact Match": [small_nq_em_score, large_nq_em_score], "F1": [small_nq_f1, large_nq_f1]},
|
|
||||||
"TriviaQA": {"Exact Match": [small_tq_em_score, large_tq_em_score], "F1": [small_tq_f1, large_tq_f1]}
|
|
||||||
}
|
|
||||||
latency_data_plot2 = {
|
|
||||||
"NQ": [small_nq_time, large_nq_time],
|
|
||||||
"TriviaQA": [small_tq_time, large_tq_time]
|
|
||||||
}
|
|
||||||
|
|
||||||
edgecolors = ["dimgrey", "tomato"]
|
|
||||||
hatches = ["/////", "\\\\\\\\\\"]
|
|
||||||
|
|
||||||
# Changed: bar_center_separation_in_group increased for larger gap
|
|
||||||
bar_center_separation_in_group = 0.42
|
|
||||||
# Changed: bar_visual_width decreased for narrower bars
|
|
||||||
bar_visual_width = 0.28
|
|
||||||
|
|
||||||
figsize_plot1 = (4, 2.5)
|
|
||||||
# Changed: figsize_plot2 width adjusted to match figsize_plot1 for legend/caption alignment
|
|
||||||
figsize_plot2 = (2.5, 2.5)
|
|
||||||
|
|
||||||
# Define plot1_xlim_per_subplot globally so it can be accessed by create_plot2_latency
|
|
||||||
plot1_xlim_per_subplot = (0.0, 2.0) # Explicit xlim for plot 1 subplots
|
|
||||||
|
|
||||||
common_subplots_adjust_params = dict(wspace=0.30, top=0.80, bottom=0.22, left=0.09, right=0.96)
|
|
||||||
|
|
||||||
|
|
||||||
def create_plot1_em_f1():
|
|
||||||
fig, axs = plt.subplots(1, 2, figsize=figsize_plot1)
|
|
||||||
fig.subplots_adjust(**common_subplots_adjust_params)
|
|
||||||
|
|
||||||
num_methods = len(method_labels)
|
|
||||||
metric_group_centers = np.array([0.5, 1.5])
|
|
||||||
# plot1_xlim_per_subplot is now global
|
|
||||||
|
|
||||||
for i, dataset_name in enumerate(dataset_names):
|
|
||||||
ax = axs[i]
|
|
||||||
for metric_idx, metric_name in enumerate(metrics_plot1):
|
|
||||||
metric_center_pos = metric_group_centers[metric_idx]
|
|
||||||
current_scores_raw = data_scores_plot1[dataset_name][metric_name]
|
|
||||||
current_scores_percent = [val * 100 for val in current_scores_raw]
|
|
||||||
|
|
||||||
for j, method_label in enumerate(method_labels):
|
|
||||||
offset = (j - (num_methods - 1) / 2.0) * bar_center_separation_in_group
|
|
||||||
bar_center_pos = metric_center_pos + offset
|
|
||||||
|
|
||||||
ax.bar(
|
|
||||||
bar_center_pos, current_scores_percent[j], width=bar_visual_width, color="white",
|
|
||||||
edgecolor=edgecolors[j], hatch=hatches[j], linewidth=1.5,
|
|
||||||
label=method_label if i == 0 and metric_idx == 0 else None
|
|
||||||
)
|
|
||||||
ax.text(
|
|
||||||
bar_center_pos, current_scores_percent[j] + 0.8, f"{current_scores_percent[j]:.1f}",
|
|
||||||
ha='center', va='bottom', fontsize=8, fontweight='bold'
|
|
||||||
)
|
|
||||||
|
|
||||||
ax.set_xticks(metric_group_centers)
|
|
||||||
ax.set_xticklabels(metrics_plot1, fontsize=9, fontweight='bold')
|
|
||||||
ax.set_title(dataset_name, fontsize=12, fontweight='bold')
|
|
||||||
ax.set_xlim(plot1_xlim_per_subplot) # Apply consistent xlim
|
|
||||||
ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda y, _: ' {:.0f}'.format(y)))
|
|
||||||
|
|
||||||
if i == 0:
|
|
||||||
ax.set_ylabel("Accuracy (\%)", fontsize=12, fontweight="bold")
|
|
||||||
|
|
||||||
all_subplot_scores_percent = []
|
|
||||||
for metric_name_iter in metrics_plot1:
|
|
||||||
all_subplot_scores_percent.extend([val * 100 for val in data_scores_plot1[dataset_name][metric_name_iter]])
|
|
||||||
|
|
||||||
max_val = max(all_subplot_scores_percent) if all_subplot_scores_percent else 0
|
|
||||||
ax.set_ylim(0, max_val * 1.22 if max_val > 0 else 10)
|
|
||||||
ax.tick_params(axis='y', labelsize=12)
|
|
||||||
|
|
||||||
for spine in ax.spines.values():
|
|
||||||
spine.set_visible(True)
|
|
||||||
spine.set_linewidth(1.0)
|
|
||||||
spine.set_edgecolor("black")
|
|
||||||
|
|
||||||
handles, labels = axs[0].get_legend_handles_labels()
|
|
||||||
fig.legend(
|
|
||||||
handles, labels, loc="upper center", bbox_to_anchor=(0.5, 0.97), ncol=len(method_labels),
|
|
||||||
edgecolor="black", facecolor="white", framealpha=1, shadow=False, fancybox=False,
|
|
||||||
handlelength=1.5, handletextpad=0.4, columnspacing=0.8,
|
|
||||||
prop={"weight": "bold", "size": 9}
|
|
||||||
)
|
|
||||||
|
|
||||||
# fig.text(0.5, 0.06, "(a) EM \& F1", ha='center', va='center', fontweight='bold', fontsize=11)
|
|
||||||
|
|
||||||
|
|
||||||
save_path = os.path.join(FIGURE_PATH, "plot1_em_f1.pdf")
|
|
||||||
# plt.tight_layout() # Adjusted call below
|
|
||||||
fig.tight_layout(rect=(0.0, 0.0, 1.0, 0.88)) # Adjusted to make space for fig.text and fig.legend
|
|
||||||
plt.savefig(save_path, bbox_inches='tight', pad_inches=0.03)
|
|
||||||
plt.close(fig)
|
|
||||||
print(f"Figure 1 (Exact Match & F1) has been saved to: {save_path}")
|
|
||||||
|
|
||||||
def create_plot2_latency():
|
|
||||||
fig, axs = plt.subplots(1, 2, figsize=figsize_plot2) # figsize_plot2 width is now 8.0
|
|
||||||
fig.subplots_adjust(**common_subplots_adjust_params)
|
|
||||||
|
|
||||||
num_methods = len(method_labels)
|
|
||||||
method_group_center_in_subplot = 0.5
|
|
||||||
|
|
||||||
# Calculate bar extents to determine focused xlim
|
|
||||||
bar_positions_calc = []
|
|
||||||
for j_idx in range(num_methods):
|
|
||||||
offset_calc = (j_idx - (num_methods - 1) / 2.0) * bar_center_separation_in_group
|
|
||||||
bar_center_pos_calc = method_group_center_in_subplot + offset_calc
|
|
||||||
bar_positions_calc.append(bar_center_pos_calc)
|
|
||||||
|
|
||||||
min_bar_actual_edge = min(bar_positions_calc) - bar_visual_width / 2.0
|
|
||||||
max_bar_actual_edge = max(bar_positions_calc) + bar_visual_width / 2.0
|
|
||||||
|
|
||||||
# Define padding around the bars
|
|
||||||
# Option 1: Fixed padding (e.g., 0.15 as derived from plot 1 visual)
|
|
||||||
# padding_val = 0.15
|
|
||||||
# plot2_xlim_calculated = (min_bar_actual_edge - padding_val, max_bar_actual_edge + padding_val)
|
|
||||||
# This would be (0.15 - 0.15, 0.85 + 0.15) = (0.0, 1.0)
|
|
||||||
|
|
||||||
# Option 2: Center the group (0.5) in a span of 1.0
|
|
||||||
plot2_xlim_calculated = (method_group_center_in_subplot - 0.5, method_group_center_in_subplot + 0.5)
|
|
||||||
# This is (0.5 - 0.5, 0.5 + 0.5) = (0.0, 1.0)
|
|
||||||
# This is simpler and achieves the (0.0, 1.0) directly.
|
|
||||||
|
|
||||||
for i, dataset_name in enumerate(dataset_names):
|
|
||||||
ax = axs[i]
|
|
||||||
current_latencies = latency_data_plot2[dataset_name]
|
|
||||||
|
|
||||||
for j, method_label in enumerate(method_labels):
|
|
||||||
offset = (j - (num_methods - 1) / 2.0) * bar_center_separation_in_group
|
|
||||||
bar_center_pos = method_group_center_in_subplot + offset
|
|
||||||
|
|
||||||
ax.bar(
|
|
||||||
bar_center_pos, current_latencies[j], width=bar_visual_width, color="white",
|
|
||||||
edgecolor=edgecolors[j], hatch=hatches[j], linewidth=1.5,
|
|
||||||
label=method_label if i == 0 else None
|
|
||||||
)
|
|
||||||
ax.text(
|
|
||||||
bar_center_pos, current_latencies[j] + 0.05, f"{current_latencies[j]:.2f}",
|
|
||||||
ha='center', va='bottom', fontsize=10, fontweight='bold'
|
|
||||||
)
|
|
||||||
ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda y, _: ' {:.0f}'.format(y)))
|
|
||||||
|
|
||||||
ax.set_xticks([0.5])
|
|
||||||
ax.set_xticklabels(["Latency"], color="white", fontsize=12)
|
|
||||||
# set tick hatches
|
|
||||||
ax.tick_params(axis='x', colors="white")
|
|
||||||
ax.set_title(dataset_name, fontsize=13, fontweight='bold')
|
|
||||||
ax.set_xlim(plot2_xlim_calculated)
|
|
||||||
|
|
||||||
if i == 0:
|
|
||||||
ax.set_ylabel("Latency (s)", fontsize=12, fontweight="bold")
|
|
||||||
|
|
||||||
max_latency_in_subplot = max(current_latencies) if current_latencies else 0
|
|
||||||
ax.set_ylim(0, max_latency_in_subplot * 1.22 if max_latency_in_subplot > 0 else 1)
|
|
||||||
ax.tick_params(axis='y', labelsize=12)
|
|
||||||
|
|
||||||
for spine in ax.spines.values():
|
|
||||||
spine.set_visible(True)
|
|
||||||
spine.set_linewidth(1.0)
|
|
||||||
spine.set_edgecolor("black")
|
|
||||||
|
|
||||||
handles, labels = axs[0].get_legend_handles_labels()
|
|
||||||
fig.legend(
|
|
||||||
handles, labels, loc="upper center", bbox_to_anchor=(0.5, 0.97), ncol=num_methods,
|
|
||||||
edgecolor="black", facecolor="white", framealpha=1, shadow=False, fancybox=False,
|
|
||||||
handlelength=1.5, handletextpad=0.4, columnspacing=0.8,
|
|
||||||
prop={"weight": "bold", "size": 9}
|
|
||||||
)
|
|
||||||
|
|
||||||
# fig.text(0.5, 0.06, "(b) Latency", ha='center', va='center', fontweight='bold', fontsize=11)
|
|
||||||
|
|
||||||
save_path = os.path.join(FIGURE_PATH, "plot2_latency.pdf")
|
|
||||||
# plt.tight_layout() # Adjusted call below
|
|
||||||
fig.tight_layout(rect=(0.0, 0.0, 1.0, 0.88)) # Adjusted to make space for fig.text and fig.legend
|
|
||||||
plt.savefig(save_path, bbox_inches='tight', pad_inches=0.03)
|
|
||||||
plt.close(fig)
|
|
||||||
print(f"Figure 2 (Latency) has been saved to: {save_path}")
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
print("Start generating figures...")
|
|
||||||
if plt.rcParams["text.usetex"]:
|
|
||||||
print("Info: LaTeX rendering is enabled. Ensure LaTeX is installed and configured if issues arise, or set plt.rcParams['text.usetex'] to False.")
|
|
||||||
|
|
||||||
create_plot1_em_f1()
|
|
||||||
create_plot2_latency()
|
|
||||||
print("All figures have been generated.")
|
|
||||||
@@ -1,111 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
# -*- coding:utf-8 -*-
|
|
||||||
# Motto: Were It to Benefit My Country, I Would Lay Down My Life!
|
|
||||||
# \file: /speed_ablation.py
|
|
||||||
# \brief:
|
|
||||||
# Author: raphael hao
|
|
||||||
|
|
||||||
# %%
|
|
||||||
import numpy as np
|
|
||||||
import pandas as pd
|
|
||||||
|
|
||||||
# %%
|
|
||||||
# from script.settings import DATA_PATH, FIGURE_PATH
|
|
||||||
|
|
||||||
# Load the latency ablation data
|
|
||||||
latency_data = pd.read_csv("./paper_plot/data/latency_ablation.csv")
|
|
||||||
# Filter for SpeedUp metric only
|
|
||||||
speedup_data = latency_data[latency_data['Metric'] == 'SpeedUp']
|
|
||||||
|
|
||||||
# %%
|
|
||||||
from matplotlib import pyplot as plt
|
|
||||||
|
|
||||||
plt.rcParams["font.family"] = "Helvetica"
|
|
||||||
plt.rcParams["ytick.direction"] = "in"
|
|
||||||
plt.rcParams["hatch.linewidth"] = 1.5
|
|
||||||
plt.rcParams["font.weight"] = "bold"
|
|
||||||
plt.rcParams["axes.labelweight"] = "bold"
|
|
||||||
plt.rcParams["text.usetex"] = True
|
|
||||||
|
|
||||||
# %%
|
|
||||||
|
|
||||||
fig, ax = plt.subplots()
|
|
||||||
fig.set_size_inches(5, 1.5)
|
|
||||||
plt.subplots_adjust(wspace=0, hspace=0)
|
|
||||||
|
|
||||||
total_width, n = 3, 3
|
|
||||||
group = len(speedup_data['Dataset'].unique())
|
|
||||||
width = total_width * 0.9 / n
|
|
||||||
x = np.arange(group) * n
|
|
||||||
exit_idx_x = x + (total_width - width) / n
|
|
||||||
edgecolors = ["dimgrey", "#63B8B6", "tomato", "silver", "slategray"]
|
|
||||||
hatches = ["/////", "xxxxx", "\\\\\\\\\\"]
|
|
||||||
labels = ["Base", "Base + Two-level", "Base + Two-level + Batch"]
|
|
||||||
|
|
||||||
datasets = speedup_data['Dataset'].unique()
|
|
||||||
|
|
||||||
for i, dataset in enumerate(datasets):
|
|
||||||
dataset_data = speedup_data[speedup_data['Dataset'] == dataset]
|
|
||||||
|
|
||||||
for j in range(n):
|
|
||||||
if j == 0:
|
|
||||||
value = dataset_data['Original'].values[0]
|
|
||||||
elif j == 1:
|
|
||||||
value = dataset_data['original + two_level'].values[0]
|
|
||||||
else:
|
|
||||||
value = dataset_data['original + two_level + batch'].values[0]
|
|
||||||
|
|
||||||
ax.text(
|
|
||||||
exit_idx_x[i] + j * width,
|
|
||||||
value + 0.05,
|
|
||||||
f"{value:.2f}",
|
|
||||||
ha='center',
|
|
||||||
va='bottom',
|
|
||||||
fontsize=10,
|
|
||||||
fontweight='bold',
|
|
||||||
rotation=0,
|
|
||||||
zorder=20,
|
|
||||||
)
|
|
||||||
|
|
||||||
ax.bar(
|
|
||||||
exit_idx_x[i] + j * width,
|
|
||||||
value,
|
|
||||||
width=width * 0.8,
|
|
||||||
color="white",
|
|
||||||
edgecolor=edgecolors[j],
|
|
||||||
hatch=hatches[j],
|
|
||||||
linewidth=1.5,
|
|
||||||
label=labels[j] if i == 0 else None,
|
|
||||||
zorder=10,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
ax.set_ylim([0.5, 2.3])
|
|
||||||
ax.set_yticks(np.arange(0.5, 2.2, 0.5))
|
|
||||||
ax.set_yticklabels(np.arange(0.5, 2.2, 0.5), fontsize=12)
|
|
||||||
ax.set_xticks(exit_idx_x + width)
|
|
||||||
ax.set_xticklabels(datasets, fontsize=10)
|
|
||||||
# ax.set_xlabel("Different Datasets", fontsize=14)
|
|
||||||
ax.legend(
|
|
||||||
bbox_to_anchor=(-0.03, 1.4),
|
|
||||||
ncol=3,
|
|
||||||
loc="upper left",
|
|
||||||
labelspacing=0.1,
|
|
||||||
edgecolor="black",
|
|
||||||
facecolor="white",
|
|
||||||
framealpha=1,
|
|
||||||
shadow=False,
|
|
||||||
fancybox=False,
|
|
||||||
handlelength=0.8,
|
|
||||||
handletextpad=0.6,
|
|
||||||
columnspacing=0.8,
|
|
||||||
prop={"weight": "bold", "size": 10},
|
|
||||||
).set_zorder(100)
|
|
||||||
ax.set_ylabel("Speedup", fontsize=11)
|
|
||||||
|
|
||||||
plt.savefig("./paper_plot/figures/latency_speedup.pdf", bbox_inches="tight", dpi=300)
|
|
||||||
|
|
||||||
# %%
|
|
||||||
|
|
||||||
print(f"Save to ./paper_plot/figures/latency_speedup.pdf")
|
|
||||||
1
research/utils/.gitignore
vendored
1
research/utils/.gitignore
vendored
@@ -1 +0,0 @@
|
|||||||
analyze_diskann_graph
|
|
||||||
@@ -1,227 +0,0 @@
|
|||||||
#include <cassert>
|
|
||||||
#include <cstdint>
|
|
||||||
#include <cstring>
|
|
||||||
#include <fstream>
|
|
||||||
#include <iostream>
|
|
||||||
#include <limits>
|
|
||||||
#include <string>
|
|
||||||
#include <vector>
|
|
||||||
|
|
||||||
static const size_t DISKANN_SECTOR_LEN = 4096; // Typical sector size
|
|
||||||
|
|
||||||
// ! Use float as CoordT
|
|
||||||
using CoordT = float;
|
|
||||||
|
|
||||||
int main(int argc, char **argv) {
|
|
||||||
if (argc < 3) {
|
|
||||||
std::cerr << "Usage: " << argv[0]
|
|
||||||
<< " <diskann_index_file> <output_degree_file>" << std::endl;
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::string disk_index_path = argv[1];
|
|
||||||
std::string output_degree_path = argv[2];
|
|
||||||
std::ifstream in(disk_index_path, std::ios::binary);
|
|
||||||
if (!in.is_open()) {
|
|
||||||
std::cerr << "Failed to open file: " << disk_index_path << std::endl;
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
|
|
||||||
// =========== 1) Read meta information (corresponds to
|
|
||||||
// save_bin<uint64_t>(...,...,...,1,0)) ============== Read bin header:
|
|
||||||
// (npts_i32, dim_i32)
|
|
||||||
int32_t meta_count_i32 = 0, meta_dim_i32 = 0;
|
|
||||||
in.read(reinterpret_cast<char *>(&meta_count_i32), sizeof(int32_t));
|
|
||||||
in.read(reinterpret_cast<char *>(&meta_dim_i32), sizeof(int32_t));
|
|
||||||
size_t meta_count = static_cast<size_t>(meta_count_i32);
|
|
||||||
size_t meta_dim = static_cast<size_t>(meta_dim_i32);
|
|
||||||
|
|
||||||
// According to the diskann::save_bin writing method, here meta_dim is usually
|
|
||||||
// 1
|
|
||||||
std::cout << "[LOG] meta_count=" << meta_count << ", meta_dim=" << meta_dim
|
|
||||||
<< std::endl;
|
|
||||||
if (meta_dim != 1) {
|
|
||||||
std::cerr << "[ERROR] meta_dim != 1,不符合 create_disk_layout 的写盘约定。"
|
|
||||||
<< std::endl;
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Read meta array
|
|
||||||
std::vector<uint64_t> meta(meta_count);
|
|
||||||
in.read(reinterpret_cast<char *>(meta.data()), meta_count * sizeof(uint64_t));
|
|
||||||
if (!in.good()) {
|
|
||||||
std::cerr << "[ERROR] Failed to read meta array, file is incomplete."
|
|
||||||
<< std::endl;
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
|
|
||||||
// meta[0..] Metadata
|
|
||||||
// 0: npts_64, 1: ndims_64, 2: medoid, 3: max_node_len, 4: nnodes_per_sector,
|
|
||||||
// 5: vamana_frozen_num, 6: vamana_frozen_loc, 7: append_reorder_data, ...
|
|
||||||
const uint64_t npts_64 = meta[0];
|
|
||||||
const uint64_t ndims_64 = meta[1];
|
|
||||||
const uint64_t medoid = meta[2];
|
|
||||||
const uint64_t max_node_len = meta[3];
|
|
||||||
const uint64_t nnodes_per_sector = meta[4];
|
|
||||||
const uint64_t vamana_frozen_num = meta[5];
|
|
||||||
const uint64_t vamana_frozen_loc = meta[6];
|
|
||||||
const uint64_t append_reorder_data = meta[7];
|
|
||||||
|
|
||||||
std::cout << "[LOG] npts_64=" << npts_64 << " ndims_64=" << ndims_64
|
|
||||||
<< " max_node_len=" << max_node_len
|
|
||||||
<< " nnodes_per_sector=" << nnodes_per_sector << std::endl;
|
|
||||||
// If append_reorder_data==1, it means that reorder_data is appended at the
|
|
||||||
// end of the index, but it does not affect the degree statistics, we can
|
|
||||||
// ignore that part of the vector.
|
|
||||||
|
|
||||||
// =========== 2) Skip the first sector (all empty/placeholder information)
|
|
||||||
// ==============
|
|
||||||
in.seekg(DISKANN_SECTOR_LEN, std::ios::beg);
|
|
||||||
if (!in.good()) {
|
|
||||||
std::cerr << "[ERROR] Failed to seek to the first sector." << std::endl;
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
|
|
||||||
// =========== 3) Calculate the total number of sectors ==============
|
|
||||||
// In create_disk_layout:
|
|
||||||
// If nnodes_per_sector > 0, then n_sectors = ceil(npts_64 /
|
|
||||||
// nnodes_per_sector) Otherwise nsectors_per_node = ceil(max_node_len /
|
|
||||||
// 4096), n_sectors = nsectors_per_node * npts_64
|
|
||||||
uint64_t n_sectors = 0;
|
|
||||||
if (nnodes_per_sector > 0) {
|
|
||||||
// Equivalent to Roundup(npts_64, nnodes_per_sector) / nnodes_per_sector
|
|
||||||
n_sectors = (npts_64 + nnodes_per_sector - 1) / nnodes_per_sector;
|
|
||||||
} else {
|
|
||||||
// multi-sector per node
|
|
||||||
uint64_t nsectors_per_node =
|
|
||||||
(max_node_len + DISKANN_SECTOR_LEN - 1) / DISKANN_SECTOR_LEN;
|
|
||||||
n_sectors = nsectors_per_node * npts_64;
|
|
||||||
}
|
|
||||||
std::cout << "[LOG] estimated #sectors storing adjacency = " << n_sectors
|
|
||||||
<< std::endl;
|
|
||||||
|
|
||||||
// =========== 4) Read the degree of all nodes in order ==============
|
|
||||||
// The memory layout of adjacency_count in each node: offset = ndims_64 *
|
|
||||||
// sizeof(CoordT) This is followed by 4 bytes for the number of neighbors
|
|
||||||
// uint32_t If you want to read the complete neighbor list, it is
|
|
||||||
// adjacency_count * sizeof(uint32_t) But we only count the count
|
|
||||||
std::vector<uint32_t> degrees(npts_64, 0); // Store the degree of each node
|
|
||||||
size_t node_id = 0; // Current node number
|
|
||||||
// Buffer for reading one sector at a time
|
|
||||||
std::vector<char> sector_buf(DISKANN_SECTOR_LEN, 0);
|
|
||||||
// If nnodes_per_sector>0, it means that one sector holds multiple nodes
|
|
||||||
// Otherwise, one node occupies nsectors_per_node sectors
|
|
||||||
if (nnodes_per_sector > 0) {
|
|
||||||
// Read one sector at a time
|
|
||||||
for (uint64_t s = 0; s < n_sectors; s++) {
|
|
||||||
in.read((char *)sector_buf.data(), DISKANN_SECTOR_LEN);
|
|
||||||
if (!in.good()) {
|
|
||||||
if (node_id < npts_64) {
|
|
||||||
std::cerr << "[ERROR] Failed to read sector " << s
|
|
||||||
<< ", nodes not finished, file error or incomplete."
|
|
||||||
<< std::endl;
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
break; // If all nodes are read, you can exit
|
|
||||||
}
|
|
||||||
// Parse each node in sector_buf
|
|
||||||
for (uint64_t i = 0; i < nnodes_per_sector; i++) {
|
|
||||||
if (node_id >= npts_64)
|
|
||||||
break; // All node degrees have been obtained
|
|
||||||
// The starting offset of the node in sector_buf
|
|
||||||
size_t node_offset = i * max_node_len;
|
|
||||||
// offset first skips ndims_64 * sizeof(CoordT)
|
|
||||||
size_t degree_offset = node_offset + ndims_64 * sizeof(CoordT);
|
|
||||||
// Ensure not out of bounds
|
|
||||||
if (degree_offset + sizeof(uint32_t) > sector_buf.size()) {
|
|
||||||
std::cerr << "[ERROR] 不应该发生: 读取degree越过了扇区边界."
|
|
||||||
<< std::endl;
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
uint32_t deg = 0;
|
|
||||||
memcpy(°, sector_buf.data() + degree_offset, sizeof(uint32_t));
|
|
||||||
degrees[node_id] = deg;
|
|
||||||
node_id++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// Each node occupies nsectors_per_node sectors
|
|
||||||
uint64_t nsectors_per_node =
|
|
||||||
(max_node_len + DISKANN_SECTOR_LEN - 1) / DISKANN_SECTOR_LEN;
|
|
||||||
// Read each node
|
|
||||||
for (uint64_t n = 0; n < npts_64; n++) {
|
|
||||||
// Read multiple sectors into a multi-sector buffer
|
|
||||||
std::vector<char> node_buf(nsectors_per_node * DISKANN_SECTOR_LEN, 0);
|
|
||||||
in.read((char *)node_buf.data(), node_buf.size());
|
|
||||||
if (!in.good()) {
|
|
||||||
std::cerr << "[ERROR] Failed to read sector corresponding to node " << n
|
|
||||||
<< ", file error or incomplete." << std::endl;
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
// Parse the degree in node_buf
|
|
||||||
size_t degree_offset = ndims_64 * sizeof(CoordT);
|
|
||||||
if (degree_offset + sizeof(uint32_t) > node_buf.size()) {
|
|
||||||
std::cerr << "[ERROR] Should not happen: reading degree beyond node "
|
|
||||||
"region."
|
|
||||||
<< std::endl;
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
uint32_t deg = 0;
|
|
||||||
memcpy(°, node_buf.data() + degree_offset, sizeof(uint32_t));
|
|
||||||
degrees[n] = deg;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// We assert here: node_id should equal npts_64 (in multi-node mode)
|
|
||||||
if (nnodes_per_sector > 0) {
|
|
||||||
if (node_id != npts_64) {
|
|
||||||
std::cerr << "[ERROR] Actually read " << node_id
|
|
||||||
<< " nodes, but meta npts_64=" << npts_64
|
|
||||||
<< ", file may be incorrect or parsing method is wrong."
|
|
||||||
<< std::endl;
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// =========== 5) Calculate min / max / average degree ==============
|
|
||||||
uint64_t sum_deg = 0;
|
|
||||||
uint32_t min_deg = std::numeric_limits<uint32_t>::max();
|
|
||||||
uint32_t max_deg = 0;
|
|
||||||
|
|
||||||
for (uint64_t n = 0; n < npts_64; n++) {
|
|
||||||
uint32_t d = degrees[n];
|
|
||||||
sum_deg += d;
|
|
||||||
if (d < min_deg)
|
|
||||||
min_deg = d;
|
|
||||||
if (d > max_deg)
|
|
||||||
max_deg = d;
|
|
||||||
}
|
|
||||||
double avg_deg = (npts_64 == 0) ? 0.0 : double(sum_deg) / double(npts_64);
|
|
||||||
|
|
||||||
// =========== 6) Output results ==============
|
|
||||||
std::cout << "DiskANN index file: " << disk_index_path << std::endl;
|
|
||||||
std::cout << "Total points: " << npts_64 << std::endl;
|
|
||||||
std::cout << "Min degree : " << min_deg << std::endl;
|
|
||||||
std::cout << "Max degree : " << max_deg << std::endl;
|
|
||||||
std::cout << "Avg degree : " << avg_deg << std::endl;
|
|
||||||
|
|
||||||
// =========== 7) Write degrees to output file ==============
|
|
||||||
std::ofstream out_deg(output_degree_path);
|
|
||||||
if (!out_deg.is_open()) {
|
|
||||||
std::cerr << "[ERROR] Failed to open output file: " << output_degree_path
|
|
||||||
<< std::endl;
|
|
||||||
// Don't necessarily exit, maybe just warn? Depends on desired behavior.
|
|
||||||
// For now, we continue closing the input file.
|
|
||||||
} else {
|
|
||||||
std::cout << "[LOG] Writing degrees to " << output_degree_path << "..."
|
|
||||||
<< std::endl;
|
|
||||||
for (uint64_t n = 0; n < npts_64; n++) {
|
|
||||||
out_deg << degrees[n] << std::endl;
|
|
||||||
}
|
|
||||||
out_deg.close();
|
|
||||||
std::cout << "[LOG] Finished writing degrees." << std::endl;
|
|
||||||
}
|
|
||||||
|
|
||||||
in.close();
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
@@ -1,187 +0,0 @@
|
|||||||
import pandas as pd
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
import seaborn as sns
|
|
||||||
import numpy as np
|
|
||||||
import os
|
|
||||||
import re
|
|
||||||
|
|
||||||
# 设置风格
|
|
||||||
plt.style.use('ggplot')
|
|
||||||
sns.set(font_scale=1.2)
|
|
||||||
|
|
||||||
# 读取数据 - 修改为自定义读取逻辑
|
|
||||||
log_file = './top3_positions_log.txt'
|
|
||||||
|
|
||||||
# 手动解析文件
|
|
||||||
data = []
|
|
||||||
header = None
|
|
||||||
with open(log_file, 'r') as f:
|
|
||||||
lines = f.readlines()
|
|
||||||
header = lines[0].strip().split(',')
|
|
||||||
|
|
||||||
# 检查是否存在ThreadID列
|
|
||||||
has_thread_id = 'ThreadID' in header
|
|
||||||
|
|
||||||
for line in lines[1:]:
|
|
||||||
# 跳过非数据行,如"Search X results:"
|
|
||||||
if 'results:' in line or not ',' in line:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 分割并解析数据行
|
|
||||||
parts = line.strip().split(',')
|
|
||||||
|
|
||||||
# 检查数据是否符合格式
|
|
||||||
if len(parts) >= 7: # 至少需要7个字段
|
|
||||||
# 对于旧格式(无ThreadID)的数据
|
|
||||||
if not has_thread_id and len(parts) == 7:
|
|
||||||
data.append([parts[0], 0, parts[1], parts[2], parts[3], parts[4], parts[5], parts[6]])
|
|
||||||
# 对于新格式(有ThreadID)的数据
|
|
||||||
elif has_thread_id and len(parts) == 8:
|
|
||||||
data.append(parts)
|
|
||||||
# 处理不一致的格式
|
|
||||||
elif not has_thread_id and len(parts) == 8:
|
|
||||||
# 假设第二列是ThreadID
|
|
||||||
data.append(parts)
|
|
||||||
if not has_thread_id:
|
|
||||||
has_thread_id = True
|
|
||||||
header.insert(1, 'ThreadID')
|
|
||||||
|
|
||||||
# 确保header正确
|
|
||||||
if not has_thread_id:
|
|
||||||
header.insert(1, 'ThreadID')
|
|
||||||
|
|
||||||
# 创建DataFrame并确保列名正确
|
|
||||||
if len(header) == 8: # 确保有8列
|
|
||||||
df = pd.DataFrame(data, columns=header)
|
|
||||||
else:
|
|
||||||
# 如果header不正确,则使用默认列名
|
|
||||||
default_header = ['Search#', 'ThreadID', 'FullSetSize', 'Rank', 'ID', 'PQ_Rank', 'PQ_Distance', 'Exact_Distance']
|
|
||||||
df = pd.DataFrame(data, columns=default_header)
|
|
||||||
|
|
||||||
# 转换数值列
|
|
||||||
df['Search#'] = pd.to_numeric(df['Search#'], errors='coerce').fillna(0).astype(int)
|
|
||||||
df['ThreadID'] = pd.to_numeric(df['ThreadID'], errors='coerce').fillna(0).astype(int)
|
|
||||||
df['FullSetSize'] = pd.to_numeric(df['FullSetSize'], errors='coerce').fillna(0).astype(int)
|
|
||||||
df['Rank'] = pd.to_numeric(df['Rank'], errors='coerce').fillna(0).astype(int)
|
|
||||||
df['ID'] = pd.to_numeric(df['ID'], errors='coerce').fillna(0).astype(int)
|
|
||||||
df['PQ_Rank'] = pd.to_numeric(df['PQ_Rank'], errors='coerce').fillna(0).astype(int)
|
|
||||||
df['PQ_Distance'] = pd.to_numeric(df['PQ_Distance'], errors='coerce').fillna(0).astype(float)
|
|
||||||
df['Exact_Distance'] = pd.to_numeric(df['Exact_Distance'], errors='coerce').fillna(0).astype(float)
|
|
||||||
|
|
||||||
print(f"读取了 {len(df)} 行数据")
|
|
||||||
print(f"搜索次数: {df['Search#'].nunique()}")
|
|
||||||
print(f"线程数: {df['ThreadID'].nunique()}")
|
|
||||||
|
|
||||||
# 提取前3名的结果
|
|
||||||
top3_df = df[df['Rank'] <= 3].copy()
|
|
||||||
|
|
||||||
# 分析PQ Rank的分布
|
|
||||||
pq_positions = []
|
|
||||||
for rank in [1, 2, 3]:
|
|
||||||
rank_df = top3_df[top3_df['Rank'] == rank]
|
|
||||||
pq_positions.append(rank_df['PQ_Rank'].values)
|
|
||||||
|
|
||||||
# 创建结果目录
|
|
||||||
result_dir = './analysis_results'
|
|
||||||
os.makedirs(result_dir, exist_ok=True)
|
|
||||||
|
|
||||||
# 1. 箱型图:展示top-3结果在PQ排序中的位置分布
|
|
||||||
plt.figure(figsize=(10, 6))
|
|
||||||
box_data = [top3_df[top3_df['Rank'] == i]['PQ_Rank'].values for i in [1, 2, 3]]
|
|
||||||
sns.boxplot(data=box_data)
|
|
||||||
plt.xticks([0, 1, 2], ['Top 1', 'Top 2', 'Top 3'])
|
|
||||||
plt.ylabel('PQ Rank Position')
|
|
||||||
plt.title('Distribution of PQ Ranks for Top-3 Exact Results')
|
|
||||||
plt.savefig(os.path.join(result_dir, 'pq_rank_boxplot.png'), dpi=300)
|
|
||||||
|
|
||||||
# 2. 直方图:每个排名在PQ结果中的位置分布
|
|
||||||
fig, axs = plt.subplots(1, 3, figsize=(18, 6))
|
|
||||||
for i, rank in enumerate([1, 2, 3]):
|
|
||||||
rank_df = top3_df[top3_df['Rank'] == rank]
|
|
||||||
sns.histplot(x=rank_df['PQ_Rank'].values, bins=20, ax=axs[i])
|
|
||||||
axs[i].set_title(f'Exact Rank {rank}')
|
|
||||||
axs[i].set_xlabel('PQ Rank')
|
|
||||||
axs[i].set_ylabel('Frequency')
|
|
||||||
plt.tight_layout()
|
|
||||||
plt.savefig(os.path.join(result_dir, 'pq_rank_histogram.png'), dpi=300)
|
|
||||||
|
|
||||||
# 3. 热力图:PQ排名与精确排名的关系
|
|
||||||
plt.figure(figsize=(10, 8))
|
|
||||||
# 只关注Top 20的排名
|
|
||||||
bins = list(range(0, 22))
|
|
||||||
pq_rank_bins = pd.cut(top3_df['PQ_Rank'], bins=bins)
|
|
||||||
heatmap_data = pd.crosstab(pq_rank_bins, top3_df['Rank'])
|
|
||||||
sns.heatmap(heatmap_data, cmap='YlGnBu', annot=True, fmt='d')
|
|
||||||
plt.title('Heatmap of Exact Rank vs PQ Rank (Top 20)')
|
|
||||||
plt.xlabel('Exact Rank')
|
|
||||||
plt.ylabel('PQ Rank Range')
|
|
||||||
plt.savefig(os.path.join(result_dir, 'rank_heatmap.png'), dpi=300)
|
|
||||||
|
|
||||||
# 4. 散点图:比较PQ距离和精确距离的关系
|
|
||||||
plt.figure(figsize=(10, 8))
|
|
||||||
sns.scatterplot(x=top3_df['Exact_Distance'], y=top3_df['PQ_Distance'], hue=top3_df['Rank'], palette='viridis')
|
|
||||||
plt.title('PQ Distance vs Exact Distance')
|
|
||||||
plt.xlabel('Exact Distance')
|
|
||||||
plt.ylabel('PQ Distance')
|
|
||||||
plt.legend(title='Exact Rank')
|
|
||||||
# 添加对角线表示完美匹配
|
|
||||||
min_val = min(top3_df['Exact_Distance'].min(), top3_df['PQ_Distance'].min())
|
|
||||||
max_val = max(top3_df['Exact_Distance'].max(), top3_df['PQ_Distance'].max())
|
|
||||||
plt.plot([min_val, max_val], [min_val, max_val], 'r--', alpha=0.5)
|
|
||||||
plt.savefig(os.path.join(result_dir, 'distance_scatter.png'), dpi=300)
|
|
||||||
|
|
||||||
# 5. 折线图:PQ Rank随结果集大小的变化
|
|
||||||
plt.figure(figsize=(12, 6))
|
|
||||||
size_grouped = top3_df.groupby(['FullSetSize', 'Rank'])['PQ_Rank'].mean().reset_index()
|
|
||||||
for rank in [1, 2, 3]:
|
|
||||||
rank_data = size_grouped[size_grouped['Rank'] == rank]
|
|
||||||
plt.plot(rank_data['FullSetSize'], rank_data['PQ_Rank'], marker='o', label=f'Rank {rank}')
|
|
||||||
plt.xlabel('Result Set Size')
|
|
||||||
plt.ylabel('Average PQ Rank')
|
|
||||||
plt.title('Average PQ Rank by Result Set Size')
|
|
||||||
plt.legend()
|
|
||||||
plt.grid(True)
|
|
||||||
plt.savefig(os.path.join(result_dir, 'pq_rank_by_size.png'), dpi=300)
|
|
||||||
|
|
||||||
# 6. 百分比热力图:在PQ排名前K的概率
|
|
||||||
top_k_values = [1, 5, 10, 20, 50, 100, 200, 300, 500, 700, 800, 900]
|
|
||||||
top_k_probs = []
|
|
||||||
|
|
||||||
for rank in [1, 2, 3]:
|
|
||||||
rank_df = top3_df[top3_df['Rank'] == rank]
|
|
||||||
probs = []
|
|
||||||
for k in top_k_values:
|
|
||||||
prob = (rank_df['PQ_Rank'] <= k).mean() * 100
|
|
||||||
probs.append(prob)
|
|
||||||
top_k_probs.append(probs)
|
|
||||||
|
|
||||||
plt.figure(figsize=(10, 6))
|
|
||||||
sns.heatmap(top_k_probs, annot=True, fmt='.1f', cmap='YlGnBu',
|
|
||||||
xticklabels=[f'Top-{k}' for k in top_k_values],
|
|
||||||
yticklabels=['Rank 1', 'Rank 2', 'Rank 3'])
|
|
||||||
plt.title('Probability (%) of Finding Exact Top-K Results in PQ Top-K')
|
|
||||||
plt.xlabel('PQ Top-K')
|
|
||||||
plt.ylabel('Exact Rank')
|
|
||||||
plt.savefig(os.path.join(result_dir, 'topk_probability.png'), dpi=300)
|
|
||||||
|
|
||||||
# 7. 生成统计摘要报告
|
|
||||||
with open(os.path.join(result_dir, 'summary_report.txt'), 'w') as f:
|
|
||||||
f.write(f"数据分析摘要\n")
|
|
||||||
f.write(f"=================\n")
|
|
||||||
f.write(f"总搜索次数: {df['Search#'].nunique()}\n")
|
|
||||||
f.write(f"使用线程数: {df['ThreadID'].nunique()}\n\n")
|
|
||||||
|
|
||||||
f.write("精确排名前3的结果在PQ排序中的平均位置:\n")
|
|
||||||
for rank in [1, 2, 3]:
|
|
||||||
avg_pq_rank = top3_df[top3_df['Rank'] == rank]['PQ_Rank'].mean()
|
|
||||||
median_pq_rank = top3_df[top3_df['Rank'] == rank]['PQ_Rank'].median()
|
|
||||||
f.write(f" 排名 {rank}: 平均位置 = {avg_pq_rank:.2f}, 中位数位置 = {median_pq_rank:.1f}\n")
|
|
||||||
|
|
||||||
f.write("\n各排名结果在PQ排序前K的命中率:\n")
|
|
||||||
for rank in [1, 2, 3]:
|
|
||||||
f.write(f" 排名 {rank}:\n")
|
|
||||||
for k in top_k_values:
|
|
||||||
hit_rate = (top3_df[top3_df['Rank'] == rank]['PQ_Rank'] <= k).mean() * 100
|
|
||||||
f.write(f" 在PQ前 {k} 中的命中率: {hit_rate:.2f}%\n")
|
|
||||||
|
|
||||||
print(f"分析完成! 结果已保存到 {result_dir} 目录")
|
|
||||||
@@ -1,137 +0,0 @@
|
|||||||
import os
|
|
||||||
import torch
|
|
||||||
import numpy as np
|
|
||||||
import argparse
|
|
||||||
from tqdm import tqdm
|
|
||||||
import json
|
|
||||||
from contriever.src.contriever import load_retriever
|
|
||||||
|
|
||||||
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
|
|
||||||
os.environ["OMP_NUM_THREADS"] = "1"
|
|
||||||
os.environ["KMP_BLOCKTIME"] = "0"
|
|
||||||
|
|
||||||
torch.set_num_threads(1)
|
|
||||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
||||||
|
|
||||||
def embed_queries(queries, model, tokenizer, model_name_or_path, per_gpu_batch_size=64):
|
|
||||||
"""Embed queries using the model with batching"""
|
|
||||||
model = model.half()
|
|
||||||
model.eval()
|
|
||||||
embeddings = []
|
|
||||||
batch_question = []
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
for k, query in tqdm(enumerate(queries), desc="Encoding queries"):
|
|
||||||
batch_question.append(query)
|
|
||||||
|
|
||||||
# Process when batch is full or at the end
|
|
||||||
if len(batch_question) == per_gpu_batch_size or k == len(queries) - 1:
|
|
||||||
encoded_batch = tokenizer.batch_encode_plus(
|
|
||||||
batch_question,
|
|
||||||
return_tensors="pt",
|
|
||||||
max_length=512,
|
|
||||||
padding=True,
|
|
||||||
truncation=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
encoded_batch = {k: v.to(device) for k, v in encoded_batch.items()}
|
|
||||||
output = model(**encoded_batch)
|
|
||||||
|
|
||||||
# Contriever typically uses output.last_hidden_state pooling or something specialized
|
|
||||||
# if "contriever" not in model_name_or_path:
|
|
||||||
# output = output.last_hidden_state[:, 0, :]
|
|
||||||
|
|
||||||
embeddings.append(output.cpu())
|
|
||||||
batch_question = [] # Reset batch
|
|
||||||
|
|
||||||
embeddings = torch.cat(embeddings, dim=0).numpy()
|
|
||||||
print(f"Query embeddings shape: {embeddings.shape}")
|
|
||||||
return embeddings
|
|
||||||
|
|
||||||
def main():
|
|
||||||
parser = argparse.ArgumentParser(description="Debug embedding tool")
|
|
||||||
parser.add_argument("--model", type=str, default="Alibaba-NLP/gte-Qwen2-1.5B-instruct",
|
|
||||||
help="Model name for embedding (default: facebook/contriever-msmarco)")
|
|
||||||
parser.add_argument("--batch-size", type=int, default=32,
|
|
||||||
help="Batch size for encoding (default: 32)")
|
|
||||||
parser.add_argument("--input-file", type=str,
|
|
||||||
help="Input file with queries (JSON lines format with 'query' field)")
|
|
||||||
parser.add_argument("--output-file", type=str, default="embeddings.npy",
|
|
||||||
help="Output file to save embeddings (default: embeddings.npy)")
|
|
||||||
parser.add_argument("--text", type=str, nargs="+",
|
|
||||||
help="Direct text input to embed (can provide multiple)")
|
|
||||||
parser.add_argument("--save-text", action="store_true",
|
|
||||||
help="Save the input text alongside embeddings")
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
# Load model
|
|
||||||
print(f"Loading query encoder: {args.model}")
|
|
||||||
query_encoder, query_tokenizer, _ = load_retriever(args.model)
|
|
||||||
query_encoder = query_encoder.to(device)
|
|
||||||
query_encoder.eval()
|
|
||||||
|
|
||||||
# Get queries
|
|
||||||
queries = []
|
|
||||||
|
|
||||||
# From file if provided
|
|
||||||
if args.input_file:
|
|
||||||
print(f"Loading queries from: {args.input_file}")
|
|
||||||
with open(args.input_file, "r") as f:
|
|
||||||
for line in f:
|
|
||||||
data = json.loads(line)
|
|
||||||
queries.append(data["query"])
|
|
||||||
|
|
||||||
# From command line if provided
|
|
||||||
if args.text:
|
|
||||||
print(f"Using {len(args.text)} queries from command line")
|
|
||||||
queries.extend(args.text)
|
|
||||||
|
|
||||||
# If no queries, use some examples
|
|
||||||
if not queries:
|
|
||||||
print("No queries provided, using example queries")
|
|
||||||
queries = [
|
|
||||||
"Were there any variances detected for hour 6 on 3/9/01?"
|
|
||||||
]
|
|
||||||
|
|
||||||
print(f"Embedding {len(queries)} queries")
|
|
||||||
for i, q in enumerate(queries[:5]): # Print first 5 queries
|
|
||||||
print(f"Query {i+1}: {q}")
|
|
||||||
if len(queries) > 5:
|
|
||||||
print(f"... and {len(queries)-5} more")
|
|
||||||
|
|
||||||
# Encode queries
|
|
||||||
embeddings = embed_queries(
|
|
||||||
queries, query_encoder, query_tokenizer, args.model, per_gpu_batch_size=args.batch_size
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
passages = [
|
|
||||||
"Start Date: 3/9/01; HourAhead hour: 6; No ancillary schedules awarded. Variances detected. Variances detected in Generation schedule. Variances detected in Energy Import/Export schedule. LOG MESSAGES: PARSING FILE -->> O:\\Portland\\WestDesk\\California Scheduling\\ISO Final Schedules\\2001030906.txt ---- Generation Schedule ---- $$$ Variance found in table tblGEN_SCHEDULE. Details: (Hour: 6 / Preferred: 20.00 / Final: 19.80) TRANS_TYPE: FINAL SC_ID: TOSCO MKT_TYPE: 2 TRANS_DATE: 3/9/01 UNIT_ID: UNCHEM_1_UNIT $$$ Variance found in table tblGEN_SCHEDULE. Details: (Hour: 6 / Preferred: 29.00 / Final: 28.20) TRANS_TYPE: FINAL SC_ID: ARCO MKT_TYPE: 2 TRANS_DATE: 3/9/01 UNIT_ID: CARBGN_6_UNIT 1 $$$ Variance found in table tblGEN_SCHEDULE. Details: (Hour: 6 / Preferred: 45.00 / Final: 43.80) TRANS_TYPE: FINAL SC_ID: DELANO MKT_TYPE: 2 TRANS_DATE: 3/9/01 UNIT_ID: PANDOL_6_UNIT $$$ Variance found in table tblGEN_SCHEDULE. Details: (Hour: 6 / Preferred: 13.00 / Final: 12.60) TRANS_TYPE: FINAL SC_ID: Wheelabrat MKT_TYPE: 2 TRANS_DATE: 3/9/01 UNIT_ID: MARTEL_2_AMFOR ---- Energy Import/Export Schedule ---- $$$ Variance found in table tblINTCHG_IMPEXP. Details: (Hour: 6 / Preferred: 62.00 / Final: 60.40) TRANS_TYPE: FINAL SC_ID: ECTstCA MKT_TYPE: 2 TRANS_DATE: 3/9/01 TIE_POINT: PVERDE_5_DEVERS INTERCHG_ID: EPMI_CISO_5001 ENGY_TYPE: FIRM $$$ Variance found in table tblINTCHG_IMPEXP. Details: (Hour: 6 / Preferred: 63.00 / Final: 61.23) TRANS_TYPE: FINAL SC_ID: ECTstSW MKT_TYPE: 2 TRANS_DATE: 3/9/01 TIE_POINT: PVERDE_5_DEVERS INTERCHG_ID: EPMI_CISO_5000 ENGY_TYPE: FIRM $$$ Variance found in table tblINTCHG_IMPEXP. Details: (Hour: 6 / Preferred: 17.00 / Final: 11.00) TRANS_TYPE: FINAL SC_ID: ECTRT MKT_TYPE: 2 TRANS_DATE: 3/9/01 TIE_POINT: SYLMAR_2_NOB INTERCHG_ID: EPMI_CISO_LUCKY ENGY_TYPE: NFRM",
|
|
||||||
"Start Date: 3/30/01; HourAhead hour: 15; No ancillary schedules awarded. Variances detected. Variances detected in Generation schedule. LOG MESSAGES: PARSING FILE -->> O:\\Portland\\WestDesk\\California Scheduling\\ISO Final Schedules\\2001033015.txt ---- Generation Schedule ---- $$$ Variance found in table tblGEN_SCHEDULE. Details: (Hour: 15 / Preferred: 0.00 / Final: 0.00) TRANS_TYPE: FINAL SC_ID: ARCO MKT_TYPE: 2 TRANS_DATE: 3/30/01 UNIT_ID: CARBGN_6_UNIT 1 $$$ Variance found in table tblGEN_SCHEDULE. Details: (Hour: 15 / Preferred: 45.00 / Final: 44.00) TRANS_TYPE: FINAL SC_ID: DELANO MKT_TYPE: 2 TRANS_DATE: 3/30/01 UNIT_ID: PANDOL_6_UNIT"
|
|
||||||
]
|
|
||||||
|
|
||||||
# Embed passages
|
|
||||||
passage_embeddings = embed_queries(passages, query_encoder, query_tokenizer, args.model, per_gpu_batch_size=args.batch_size)
|
|
||||||
|
|
||||||
|
|
||||||
# distance with passages 0 and query
|
|
||||||
distance_0 = np.linalg.norm(embeddings[0] - passage_embeddings[0])
|
|
||||||
print(f"Distance between query 0 and passage 0: {distance_0}")
|
|
||||||
|
|
||||||
# distance with passages 1 and query
|
|
||||||
distance_1 = np.linalg.norm(embeddings[0] - passage_embeddings[1])
|
|
||||||
print(f"Distance between query 0 and passage 1: {distance_1}")
|
|
||||||
|
|
||||||
# print which one is closer
|
|
||||||
if distance_0 < distance_1:
|
|
||||||
print("Query 0 is closer to passage 0")
|
|
||||||
else:
|
|
||||||
print("Query 0 is closer to passage 1")
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
print("Done!")
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
@@ -1,33 +0,0 @@
|
|||||||
import json
|
|
||||||
import os
|
|
||||||
|
|
||||||
|
|
||||||
input_file = "/gscratch/zlab/rulins/data/lm-eval-data/raw_mmlu.jsonl"
|
|
||||||
output_file = "/gscratch/zlab/rulins/data/lm-eval-data/mmlu.jsonl"
|
|
||||||
|
|
||||||
|
|
||||||
raw_data = []
|
|
||||||
|
|
||||||
with open(input_file, "r") as fin:
|
|
||||||
for line in fin:
|
|
||||||
raw_data.append(json.loads(line))
|
|
||||||
|
|
||||||
|
|
||||||
def deduplicate_dicts(dict_list):
|
|
||||||
unique_dicts = set()
|
|
||||||
unique_items = []
|
|
||||||
for item in dict_list:
|
|
||||||
# Make a hashable version of the dictionary by sorting it
|
|
||||||
hashable_item = tuple(sorted(item.items()))
|
|
||||||
if hashable_item not in unique_dicts:
|
|
||||||
unique_dicts.add(hashable_item)
|
|
||||||
unique_items.append(item)
|
|
||||||
return unique_items
|
|
||||||
|
|
||||||
|
|
||||||
unique_data = deduplicate_dicts(raw_data)
|
|
||||||
print(len(unique_data))
|
|
||||||
|
|
||||||
with open(output_file, "w") as fout:
|
|
||||||
for ex in unique_data:
|
|
||||||
fout.write(json.dumps(ex) + "\n")
|
|
||||||
@@ -1,167 +0,0 @@
|
|||||||
import time
|
|
||||||
import multiprocessing
|
|
||||||
from datasketch import MinHash, MinHashLSH
|
|
||||||
|
|
||||||
|
|
||||||
def shingle_document(text, shingle_size=13):
|
|
||||||
"""Generate word-based shingles from a document."""
|
|
||||||
# Split the text into words
|
|
||||||
words = text.split()
|
|
||||||
# Generate shingles that are sequences of 'shingle_size' consecutive words
|
|
||||||
shingles = set(
|
|
||||||
" ".join(words[i : i + shingle_size])
|
|
||||||
for i in range(len(words) - shingle_size + 1)
|
|
||||||
)
|
|
||||||
return shingles
|
|
||||||
|
|
||||||
|
|
||||||
m = MinHash(num_perm=128)
|
|
||||||
perm = m.permutations
|
|
||||||
|
|
||||||
|
|
||||||
def create_minhash(shingles, num_perm=128):
|
|
||||||
"""Create a MinHash object from the set of shingles."""
|
|
||||||
m = MinHash(permutations=perm)
|
|
||||||
m.update_batch(map(lambda x: x.encode("utf-8"), shingles))
|
|
||||||
# for shingle in shingles:
|
|
||||||
# m.update(shingle.encode('utf-8'))
|
|
||||||
return m
|
|
||||||
|
|
||||||
|
|
||||||
def abstein_string_for_decon(string):
|
|
||||||
# Abstein the reading comprehension subject in MMLU where a paragraph from Wikipedia is given in the question
|
|
||||||
return "refers to the following information" in string
|
|
||||||
|
|
||||||
|
|
||||||
def remove_duplicates_with_minhash(
|
|
||||||
documents, string_for_decontamination=None, threshold=0.8, num_perm=128
|
|
||||||
):
|
|
||||||
# Apply 13-gram Jaccard similarity deduplication and removes ones with similarity > 80% compared to former docs.
|
|
||||||
# Remove chunks shorter than 13 words.
|
|
||||||
|
|
||||||
# Create an LSH index
|
|
||||||
lsh = MinHashLSH(threshold=threshold, num_perm=num_perm)
|
|
||||||
|
|
||||||
# Dictionary to store the MinHash of each document
|
|
||||||
minhashes = {}
|
|
||||||
|
|
||||||
# Hash string for decontamination first so contaminated samples will be removed
|
|
||||||
decon_offset = 0
|
|
||||||
if string_for_decontamination is not None and not abstein_string_for_decon(
|
|
||||||
string_for_decontamination
|
|
||||||
):
|
|
||||||
shingles = shingle_document(string_for_decontamination)
|
|
||||||
m_decon = create_minhash(shingles, num_perm)
|
|
||||||
lsh.insert(f"doc_{decon_offset}", m_decon)
|
|
||||||
minhashes[decon_offset] = m_decon
|
|
||||||
decon_offset = 1
|
|
||||||
|
|
||||||
# Populate the LSH index
|
|
||||||
short_chunk_indices = []
|
|
||||||
for idx, ctx in enumerate(documents, start=decon_offset):
|
|
||||||
doc = ctx["retrieval text"]
|
|
||||||
shingles = shingle_document(doc)
|
|
||||||
if not shingles:
|
|
||||||
short_chunk_indices.append(idx - decon_offset)
|
|
||||||
m = create_minhash(shingles, num_perm)
|
|
||||||
lsh.insert(f"doc_{idx}", m)
|
|
||||||
minhashes[idx] = m
|
|
||||||
|
|
||||||
# List to keep track of non-duplicate document indices
|
|
||||||
non_duplicate_indices = []
|
|
||||||
|
|
||||||
# Check each document against the LSH index
|
|
||||||
for idx, m in minhashes.items():
|
|
||||||
if idx < decon_offset:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Query the LSH for near-duplicate candidates
|
|
||||||
result = lsh.query(m)
|
|
||||||
|
|
||||||
# print(result)
|
|
||||||
# print([minhashes[int(doc_id.split("_")[1])].jaccard(m) for doc_id in result])
|
|
||||||
|
|
||||||
# If the document is the only one in its bucket or it appears first in the list
|
|
||||||
if all(
|
|
||||||
minhashes[int(doc_id.split("_")[1])].jaccard(m) <= threshold
|
|
||||||
or int(doc_id.split("_")[1]) >= idx
|
|
||||||
for doc_id in result
|
|
||||||
):
|
|
||||||
non_duplicate_indices.append(idx - decon_offset)
|
|
||||||
|
|
||||||
# Return non-duplicate documents
|
|
||||||
deduplicated_documents = [
|
|
||||||
documents[i] for i in non_duplicate_indices if i not in short_chunk_indices
|
|
||||||
]
|
|
||||||
[doc.update({"quality score": 1}) for doc in deduplicated_documents]
|
|
||||||
removed_documents = [doc for doc in documents if doc not in deduplicated_documents]
|
|
||||||
[doc.update({"quality score": 0}) for doc in removed_documents]
|
|
||||||
|
|
||||||
print(f"Non-deduplication ctxs num: {len(deduplicated_documents)}")
|
|
||||||
# for c in deduplicated_documents:
|
|
||||||
# try:
|
|
||||||
# print(c['retrieval text'][:10])
|
|
||||||
# except:
|
|
||||||
# print(c)
|
|
||||||
# if len(deduplicated_documents[0]['retrieval text'].split(' ')) < 13:
|
|
||||||
# import pdb; pdb.set_trace()
|
|
||||||
return deduplicated_documents # + removed_documents
|
|
||||||
|
|
||||||
|
|
||||||
def process_item(data_item):
|
|
||||||
time.sleep(0.0001)
|
|
||||||
id_, ex = data_item
|
|
||||||
ex["ctxs"] = remove_duplicates_with_minhash(
|
|
||||||
ex["ctxs"], string_for_decontamination=ex["raw_query"]
|
|
||||||
)
|
|
||||||
return id_, ex
|
|
||||||
|
|
||||||
|
|
||||||
def multiprocess_deduplication(data):
|
|
||||||
items_to_process = list(enumerate(data))
|
|
||||||
pool = multiprocessing.Pool(processes=32)
|
|
||||||
for result in pool.imap(process_item, items_to_process):
|
|
||||||
id_, updated_ex = result
|
|
||||||
data[id_] = updated_ex
|
|
||||||
return data
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
# Example usage:
|
|
||||||
question = (
|
|
||||||
"Answer these questions:\n\nQ: when did the eagles win last super bowl?\nA:"
|
|
||||||
)
|
|
||||||
docs = [
|
|
||||||
"Eagles won the Super Bowl.",
|
|
||||||
"Machine learning provides the ability to automatically learn and improve from experience without being explicitly programmed."
|
|
||||||
* 20,
|
|
||||||
"Machine learning provides the ability to automatically learn and improve from experience without being explicitly programmed."
|
|
||||||
* 20
|
|
||||||
+ ".",
|
|
||||||
"An entirely different document looks nothing like the others and should not be considered a duplicate."
|
|
||||||
* 20,
|
|
||||||
"Short sentence." * 1,
|
|
||||||
"As someone who lived in Philly for about five years, I agree about the city\u2019s greatness \u2014 which makes the juxtaposition between its friendly day-to-day interactions and sometimes psychotic sports fandom even more jarring. The Eagles did win three NFL championships before the Super Bowl existed, most recently in 1960. But any fan who was following the team back then is now at least into their mid-60s, if not much older. It is, to say the least, a distant memory from another era. Granted, the Sixers went on their infamous tanking expedition during this span.",
|
|
||||||
] * 1
|
|
||||||
import time
|
|
||||||
|
|
||||||
num_ex = 1
|
|
||||||
|
|
||||||
start = time.time()
|
|
||||||
data1 = []
|
|
||||||
for _ in range(num_ex):
|
|
||||||
cleaned_ex = remove_duplicates_with_minhash(
|
|
||||||
[{"retrieval text": doc} for doc in docs], question
|
|
||||||
)
|
|
||||||
data1.append(cleaned_ex)
|
|
||||||
time1 = time.time() - start
|
|
||||||
|
|
||||||
# ori_data = [{'raw_query': docs[0], 'ctxs': [{'retrieval text': doc} for doc in docs]}] * num_ex
|
|
||||||
# start = time.time()
|
|
||||||
# data2 = multiprocess_deduplication(ori_data)
|
|
||||||
# time2 = time.time()-start
|
|
||||||
|
|
||||||
# assert data2[0]['ctxs'] == data1[0]
|
|
||||||
|
|
||||||
# print(time1)
|
|
||||||
# print(time2)
|
|
||||||
@@ -1,387 +0,0 @@
|
|||||||
/*
|
|
||||||
Run with
|
|
||||||
g++ ./demo_reader.cpp -o ./demo_reader && ./demo_reader --stats \
|
|
||||||
/powerrag/scaling_out/indices/rpj_wiki/facebook/contriever-msmarco/diskann/_partition.bin
|
|
||||||
\
|
|
||||||
/powerrag/scaling_out/indices/rpj_wiki/facebook/contriever-msmarco/diskann/_disk_graph.index
|
|
||||||
*/
|
|
||||||
|
|
||||||
#include <algorithm>
|
|
||||||
#include <cstdint>
|
|
||||||
#include <cstdlib>
|
|
||||||
#include <cstring>
|
|
||||||
#include <fstream>
|
|
||||||
#include <iomanip>
|
|
||||||
#include <iostream>
|
|
||||||
#include <limits> // Include for std::numeric_limits
|
|
||||||
#include <string> // Include for std::string comparison
|
|
||||||
#include <vector>
|
|
||||||
|
|
||||||
#define READ_U64(f, val) \
|
|
||||||
f.read(reinterpret_cast<char *>(&val), sizeof(uint64_t))
|
|
||||||
#define READ_U32(f, val) \
|
|
||||||
f.read(reinterpret_cast<char *>(&val), sizeof(uint32_t))
|
|
||||||
#define SECTOR_SIZE 4096
|
|
||||||
|
|
||||||
// Helper: Get file size
|
|
||||||
static size_t get_file_size(const std::string &fname) {
|
|
||||||
std::ifstream ifs(fname, std::ios::binary | std::ios::ate);
|
|
||||||
if (ifs.fail() || !ifs.is_open()) {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
return static_cast<size_t>(ifs.tellg());
|
|
||||||
}
|
|
||||||
|
|
||||||
// Print first few hex of sector for debug
|
|
||||||
static void print_hex(const char *buf, size_t len, size_t max_len = 64) {
|
|
||||||
size_t show_len = (len < max_len) ? len : max_len;
|
|
||||||
for (size_t i = 0; i < show_len; i++) {
|
|
||||||
unsigned char c = (unsigned char)buf[i];
|
|
||||||
std::cout << std::hex << std::setw(2) << std::setfill('0') << (unsigned)c
|
|
||||||
<< " ";
|
|
||||||
if ((i + 1) % 16 == 0)
|
|
||||||
std::cout << "\n ";
|
|
||||||
}
|
|
||||||
std::cout << std::dec << "\n";
|
|
||||||
}
|
|
||||||
|
|
||||||
/*
|
|
||||||
Corrected demo_reader:
|
|
||||||
1) Read from partition.bin:
|
|
||||||
- C, partition_nums, nd
|
|
||||||
- graph_partitions[i]: all nodeIDs in partition i
|
|
||||||
- id2partition[nodeID]: nodeID => partition i
|
|
||||||
2) Read from _disk_graph.index:
|
|
||||||
a) sector0 first has 2 ints: meta_n, meta_dim
|
|
||||||
b) then meta_n uint64_t
|
|
||||||
e.g.: [0]=nd, [1]=dim, [2]=??, [3]=max_node_len, [4]=C, [5]..??,
|
|
||||||
[8]=file_size... specific positions need to be combined with relayout writing c) graph_node_len =
|
|
||||||
max_node_len - dim_in_meta*sizeof(float) 3) User given target_node_id =>
|
|
||||||
partition_id= id2partition[node_id]
|
|
||||||
find node index j in graph_partitions[partition_id]
|
|
||||||
offset = (partition_id+1)*4096 => sector
|
|
||||||
adjacency_offset= j*graph_node_len => neighbor_count => neighbors
|
|
||||||
*/
|
|
||||||
int main(int argc, char **argv) {
|
|
||||||
bool calculate_stats = false;
|
|
||||||
// int arg_offset = 0; // Offset for positional arguments
|
|
||||||
std::string partition_bin;
|
|
||||||
std::string graph_index;
|
|
||||||
uint64_t target_node_id = 0; // Initialize
|
|
||||||
|
|
||||||
if (argc != 4) {
|
|
||||||
std::cerr << "Usage:\n"
|
|
||||||
<< " " << argv[0]
|
|
||||||
<< " <partition.bin> <disk_graph.index> <target_node_id> (Reads "
|
|
||||||
"adjacency for a node)\n"
|
|
||||||
<< " " << argv[0]
|
|
||||||
<< " --stats <partition.bin> <disk_graph.index> "
|
|
||||||
"(Calculates degree statistics)\n";
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if the first argument is the stats flag
|
|
||||||
if (std::string(argv[1]) == "--stats") {
|
|
||||||
calculate_stats = true;
|
|
||||||
partition_bin = argv[2];
|
|
||||||
graph_index = argv[3];
|
|
||||||
std::cout << "Mode: Calculating Degree Statistics\n";
|
|
||||||
} else {
|
|
||||||
// Assume default mode (single node lookup)
|
|
||||||
calculate_stats = false;
|
|
||||||
partition_bin = argv[1];
|
|
||||||
graph_index = argv[2];
|
|
||||||
try { // Add error handling for stoull
|
|
||||||
target_node_id = std::stoull(argv[3]);
|
|
||||||
} catch (const std::invalid_argument &ia) {
|
|
||||||
std::cerr << "Error: Invalid target_node_id: " << argv[3] << std::endl;
|
|
||||||
return 1;
|
|
||||||
} catch (const std::out_of_range &oor) {
|
|
||||||
std::cerr << "Error: target_node_id out of range: " << argv[3]
|
|
||||||
<< std::endl;
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
std::cout << "Mode: Single Node Lookup for Node ID " << target_node_id
|
|
||||||
<< "\n";
|
|
||||||
}
|
|
||||||
|
|
||||||
// 1) Read partition.bin
|
|
||||||
std::ifstream pf(partition_bin, std::ios::binary);
|
|
||||||
if (!pf.is_open()) {
|
|
||||||
std::cerr << "Cannot open partition.bin: " << partition_bin << std::endl;
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
uint64_t C, partition_nums, nd;
|
|
||||||
READ_U64(pf, C);
|
|
||||||
READ_U64(pf, partition_nums);
|
|
||||||
READ_U64(pf, nd);
|
|
||||||
std::cout << "[partition.bin header] C=" << C
|
|
||||||
<< ", partition_nums=" << partition_nums << ", nd=" << nd
|
|
||||||
<< std::endl;
|
|
||||||
|
|
||||||
// Read partition node lists
|
|
||||||
std::vector<std::vector<uint32_t> > graph_partitions(partition_nums);
|
|
||||||
for (uint64_t i = 0; i < partition_nums; i++) {
|
|
||||||
uint32_t psize;
|
|
||||||
READ_U32(pf, psize);
|
|
||||||
graph_partitions[i].resize(psize);
|
|
||||||
pf.read(reinterpret_cast<char *>(graph_partitions[i].data()),
|
|
||||||
psize * sizeof(uint32_t));
|
|
||||||
}
|
|
||||||
// Read _id2partition[node], size= nd
|
|
||||||
std::vector<uint32_t> id2partition(nd);
|
|
||||||
pf.read(reinterpret_cast<char *>(id2partition.data()), nd * sizeof(uint32_t));
|
|
||||||
pf.close();
|
|
||||||
std::cout << "Done loading partition info.\n";
|
|
||||||
|
|
||||||
if (target_node_id >= nd) {
|
|
||||||
std::cerr << "target_node_id=" << target_node_id
|
|
||||||
<< " out of range nd=" << nd << std::endl;
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
// 2) Parse _disk_graph.index
|
|
||||||
std::ifstream gf(graph_index, std::ios::binary);
|
|
||||||
if (!gf.is_open()) {
|
|
||||||
std::cerr << "Cannot open disk_graph.index: " << graph_index << std::endl;
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
// (a) sector0 => first read 2 ints
|
|
||||||
int meta_n, meta_dim;
|
|
||||||
gf.read((char *)&meta_n, sizeof(int));
|
|
||||||
gf.read((char *)&meta_dim, sizeof(int));
|
|
||||||
std::cout << "[debug] meta_n=" << meta_n << ", meta_dim=" << meta_dim << "\n";
|
|
||||||
|
|
||||||
// (b) Read meta_n uint64_t
|
|
||||||
std::vector<uint64_t> meta_info(meta_n);
|
|
||||||
gf.read(reinterpret_cast<char *>(meta_info.data()),
|
|
||||||
meta_n * sizeof(uint64_t));
|
|
||||||
// Print
|
|
||||||
for (int i = 0; i < meta_n; i++) {
|
|
||||||
std::cout << " meta_info[" << i << "]= " << meta_info[i] << "\n";
|
|
||||||
}
|
|
||||||
|
|
||||||
size_t file_size = get_file_size(graph_index);
|
|
||||||
std::cout << "[disk_graph.index size] " << file_size << " bytes\n";
|
|
||||||
|
|
||||||
// **According to relayout log** you said: meta_info[0]=nd=60450220, meta_info[1]=dim=769,
|
|
||||||
// meta_info[2]=??(16495248?), meta_info[3]=max_node_len=3320,
|
|
||||||
// meta_info[4]=16 (C),
|
|
||||||
// meta_info[8]= 15475261440(file size)
|
|
||||||
// We manually parse here first:
|
|
||||||
uint64_t nd_in_meta = meta_info[0];
|
|
||||||
uint64_t dim_in_meta = meta_info[1];
|
|
||||||
uint64_t max_node_len = meta_info[3];
|
|
||||||
uint64_t c_in_meta = meta_info[4];
|
|
||||||
uint64_t entire_file_sz = meta_info[8];
|
|
||||||
|
|
||||||
std::cout << "Based on meta_info:\n"
|
|
||||||
<< " nd_in_meta= " << nd_in_meta
|
|
||||||
<< ", dim_in_meta= " << dim_in_meta
|
|
||||||
<< ", max_node_len= " << max_node_len
|
|
||||||
<< ", c_in_meta= " << c_in_meta
|
|
||||||
<< ", entire_file_size= " << entire_file_sz << "\n";
|
|
||||||
|
|
||||||
// Calculate graph_node_len
|
|
||||||
uint64_t dim_size = dim_in_meta * sizeof(float);
|
|
||||||
uint64_t graph_node_len = max_node_len - dim_size;
|
|
||||||
std::cout << " => graph_node_len= " << graph_node_len << "\n\n";
|
|
||||||
|
|
||||||
if (calculate_stats) {
|
|
||||||
// --- Degree Statistics Calculation Mode ---
|
|
||||||
std::cout << " Calculated graph_node_len = " << graph_node_len << "\n\n";
|
|
||||||
|
|
||||||
if (nd == 0) {
|
|
||||||
std::cerr << "Graph has 0 nodes (nd=0). Cannot calculate stats."
|
|
||||||
<< std::endl;
|
|
||||||
gf.close();
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
uint32_t min_degree = std::numeric_limits<uint32_t>::max();
|
|
||||||
uint32_t max_degree = 0;
|
|
||||||
uint64_t total_degree = 0;
|
|
||||||
uint64_t nodes_processed = 0;
|
|
||||||
std::vector<char> sectorBuf(SECTOR_SIZE);
|
|
||||||
|
|
||||||
std::cout << "Calculating degrees for " << nd << " nodes across "
|
|
||||||
<< partition_nums << " partitions..." << std::endl;
|
|
||||||
|
|
||||||
for (uint32_t p = 0; p < partition_nums; ++p) {
|
|
||||||
uint64_t sector_offset = uint64_t(p + 1) * SECTOR_SIZE;
|
|
||||||
gf.seekg(sector_offset, std::ios::beg);
|
|
||||||
if (gf.fail()) {
|
|
||||||
std::cerr << "Error seeking to sector offset for partition " << p
|
|
||||||
<< std::endl;
|
|
||||||
gf.close();
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
gf.read(sectorBuf.data(), SECTOR_SIZE);
|
|
||||||
if (gf.fail() && !gf.eof()) {
|
|
||||||
std::cerr << "Error reading sector data for partition " << p
|
|
||||||
<< std::endl;
|
|
||||||
gf.close();
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
gf.clear(); // Reset fail bits
|
|
||||||
|
|
||||||
const auto &part_list = graph_partitions[p];
|
|
||||||
for (size_t j = 0; j < part_list.size(); ++j) {
|
|
||||||
uint64_t node_offset = j * graph_node_len;
|
|
||||||
if (node_offset + sizeof(uint32_t) > SECTOR_SIZE) {
|
|
||||||
std::cerr << "Error: Node offset out of sector bounds.\n"
|
|
||||||
<< " Partition=" << p << ", node_subIndex=" << j
|
|
||||||
<< ", node_offset=" << node_offset
|
|
||||||
<< ", graph_node_len=" << graph_node_len << std::endl;
|
|
||||||
gf.close();
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
char *adjacency_ptr = sectorBuf.data() + node_offset;
|
|
||||||
uint32_t neighbor_count = *reinterpret_cast<uint32_t *>(adjacency_ptr);
|
|
||||||
min_degree = std::min(min_degree, neighbor_count);
|
|
||||||
max_degree = std::max(max_degree, neighbor_count);
|
|
||||||
total_degree += neighbor_count;
|
|
||||||
nodes_processed++;
|
|
||||||
}
|
|
||||||
if (p % 10 == 0 || p == partition_nums - 1) {
|
|
||||||
std::cout << " Processed partition " << p + 1 << " / "
|
|
||||||
<< partition_nums << "...\r" << std::flush;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
std::cout << "\nFinished processing partitions." << std::endl;
|
|
||||||
|
|
||||||
if (nodes_processed != nd) {
|
|
||||||
std::cerr << "Warning: Processed " << nodes_processed
|
|
||||||
<< " nodes, but expected " << nd << std::endl;
|
|
||||||
}
|
|
||||||
|
|
||||||
double avg_degree = (nd > 0) ? static_cast<double>(total_degree) / nd : 0.0;
|
|
||||||
std::cout << "\n--- Degree Statistics ---\n";
|
|
||||||
std::cout << "Min Degree: "
|
|
||||||
<< (min_degree == std::numeric_limits<uint32_t>::max()
|
|
||||||
? 0
|
|
||||||
: min_degree)
|
|
||||||
<< std::endl; // Handle case of 0 nodes
|
|
||||||
std::cout << "Max Degree: " << max_degree << std::endl;
|
|
||||||
std::cout << "Avg Degree: " << std::fixed << std::setprecision(2)
|
|
||||||
<< avg_degree << std::endl;
|
|
||||||
std::cout << "Total Degree (Sum): " << total_degree << std::endl;
|
|
||||||
std::cout << "Nodes Processed: " << nodes_processed << std::endl;
|
|
||||||
|
|
||||||
} else {
|
|
||||||
uint64_t nd_in_meta = meta_info[0];
|
|
||||||
uint64_t c_in_meta = meta_info[4];
|
|
||||||
uint64_t entire_file_sz = meta_info[8];
|
|
||||||
std::cout << "Based on meta_info:\n"
|
|
||||||
<< " nd_in_meta= " << nd_in_meta
|
|
||||||
<< ", dim_in_meta= " << dim_in_meta
|
|
||||||
<< ", max_node_len= " << max_node_len
|
|
||||||
<< ", c_in_meta= " << c_in_meta
|
|
||||||
<< ", entire_file_size= " << entire_file_sz << "\n";
|
|
||||||
std::cout << " => graph_node_len= " << graph_node_len << "\n\n";
|
|
||||||
|
|
||||||
if (target_node_id >= nd) {
|
|
||||||
std::cerr << "target_node_id=" << target_node_id
|
|
||||||
<< " out of range nd=" << nd << std::endl;
|
|
||||||
gf.close();
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
// We need id2partition only for single-node lookup
|
|
||||||
std::vector<uint32_t> id2partition(nd);
|
|
||||||
{ // Read id2partition again as it was skipped before
|
|
||||||
std::ifstream pf_again(partition_bin, std::ios::binary);
|
|
||||||
uint64_t header_offset =
|
|
||||||
3 * sizeof(uint64_t); // Skip C, partition_nums, nd
|
|
||||||
uint64_t partition_list_offset = 0;
|
|
||||||
for (uint64_t i = 0; i < partition_nums; i++) {
|
|
||||||
partition_list_offset += sizeof(uint32_t); // Size field
|
|
||||||
partition_list_offset +=
|
|
||||||
graph_partitions[i].size() * sizeof(uint32_t); // Data
|
|
||||||
}
|
|
||||||
pf_again.seekg(header_offset + partition_list_offset, std::ios::beg);
|
|
||||||
pf_again.read(reinterpret_cast<char *>(id2partition.data()),
|
|
||||||
nd * sizeof(uint32_t));
|
|
||||||
// Error check pf_again if needed
|
|
||||||
}
|
|
||||||
|
|
||||||
// 3) Find target_node_id => partition_id => subIndex
|
|
||||||
uint32_t partition_id = id2partition[target_node_id];
|
|
||||||
if (partition_id >= partition_nums) {
|
|
||||||
std::cerr << "Partition ID out-of-range for target node.\n";
|
|
||||||
gf.close();
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
const auto &part_list = graph_partitions[partition_id]; // Use const ref
|
|
||||||
auto it =
|
|
||||||
std::find(part_list.begin(), part_list.end(), (uint32_t)target_node_id);
|
|
||||||
if (it == part_list.end()) {
|
|
||||||
std::cerr << "Cannot find node " << target_node_id << " in partition "
|
|
||||||
<< partition_id << std::endl;
|
|
||||||
gf.close();
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
size_t j = std::distance(part_list.begin(), it);
|
|
||||||
|
|
||||||
// 4) sector => (partition_id+1)* 4096
|
|
||||||
uint64_t sector_offset = uint64_t(partition_id + 1) * SECTOR_SIZE;
|
|
||||||
gf.seekg(sector_offset, std::ios::beg);
|
|
||||||
std::vector<char> sectorBuf(SECTOR_SIZE);
|
|
||||||
gf.read(sectorBuf.data(), SECTOR_SIZE);
|
|
||||||
if (gf.fail() && !gf.eof()) {
|
|
||||||
std::cerr << "Error reading sector data for partition " << partition_id
|
|
||||||
<< std::endl;
|
|
||||||
gf.close();
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
gf.clear(); // Reset fail bits
|
|
||||||
|
|
||||||
std::cout << "Partition #" << partition_id
|
|
||||||
<< ", nodeCount= " << part_list.size()
|
|
||||||
<< ", offset= " << sector_offset << "\n"
|
|
||||||
<< " first64 hex:\n ";
|
|
||||||
print_hex(sectorBuf.data(), SECTOR_SIZE, 64);
|
|
||||||
|
|
||||||
// adjacency_offset= j* graph_node_len
|
|
||||||
uint64_t node_offset = j * graph_node_len;
|
|
||||||
if (node_offset + sizeof(uint32_t) >
|
|
||||||
SECTOR_SIZE) { // Check only for neighbor_count read first
|
|
||||||
std::cerr << "Out-of-range. j=" << j << ", node_offset=" << node_offset
|
|
||||||
<< ", node_offset+4=" << (node_offset + sizeof(uint32_t))
|
|
||||||
<< "> 4096\n";
|
|
||||||
gf.close();
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
char *adjacency_ptr = sectorBuf.data() + node_offset;
|
|
||||||
uint32_t neighbor_count = *reinterpret_cast<uint32_t *>(adjacency_ptr);
|
|
||||||
std::cout << "[Node " << target_node_id << "] partition=" << partition_id
|
|
||||||
<< ", subIndex=" << j << ", adjacency_offset=" << node_offset
|
|
||||||
<< ", neighbor_count=" << neighbor_count << "\n";
|
|
||||||
|
|
||||||
size_t needed = neighbor_count * sizeof(uint32_t);
|
|
||||||
if (node_offset + sizeof(uint32_t) + needed > SECTOR_SIZE) {
|
|
||||||
std::cerr << "Neighbors partly out-of-range => neighbor_count="
|
|
||||||
<< neighbor_count << "\n";
|
|
||||||
// Option: Can still print partial list if needed, but indicating it's
|
|
||||||
// truncated
|
|
||||||
gf.close();
|
|
||||||
return 1; // Or handle differently
|
|
||||||
}
|
|
||||||
std::vector<uint32_t> neighbors(neighbor_count);
|
|
||||||
memcpy(neighbors.data(), adjacency_ptr + sizeof(uint32_t), needed);
|
|
||||||
|
|
||||||
std::cout << " neighbors=[";
|
|
||||||
for (size_t kk = 0; kk < std::min<size_t>(10, neighbor_count); kk++) {
|
|
||||||
std::cout << neighbors[kk];
|
|
||||||
if (kk + 1 < std::min<size_t>(10, neighbor_count))
|
|
||||||
std::cout << ", ";
|
|
||||||
}
|
|
||||||
if (neighbor_count > 10)
|
|
||||||
std::cout << " ... (total " << neighbor_count << ")";
|
|
||||||
std::cout << "]\n";
|
|
||||||
} // End of else (single node lookup mode)
|
|
||||||
|
|
||||||
gf.close();
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
@@ -1,15 +0,0 @@
|
|||||||
#! /bin/fish
|
|
||||||
|
|
||||||
# get the dir of this script
|
|
||||||
set -x SCRIPT_DIR (dirname (realpath $0))
|
|
||||||
|
|
||||||
g++ $SCRIPT_DIR/analyze_diskann_graph.cpp -o $SCRIPT_DIR/analyze_diskann_graph
|
|
||||||
|
|
||||||
# get args
|
|
||||||
set -x INDEX_PATH $argv[1]
|
|
||||||
|
|
||||||
./analyze_diskann_graph $INDEX_PATH $INDEX_PATH.degree_distribution.txt
|
|
||||||
|
|
||||||
python plot_degree_distribution.py $INDEX_PATH.degree_distribution.txt
|
|
||||||
|
|
||||||
rm $INDEX_PATH.degree_distribution.txt
|
|
||||||
@@ -1,30 +0,0 @@
|
|||||||
#!/usr/bin/env fish
|
|
||||||
|
|
||||||
set scaling_out_dir "/Users/ec2-user/scaling_out"
|
|
||||||
|
|
||||||
# Define an array of paths to download
|
|
||||||
set paths \
|
|
||||||
"examples/" \
|
|
||||||
"indices/rpj_wiki/facebook/contriever-msmarco/diskann/_disk_graph.index" \
|
|
||||||
"indices/rpj_wiki/facebook/contriever-msmarco/diskann/_partition.bin" \
|
|
||||||
"indices/rpj_wiki/facebook/contriever-msmarco/diskann/ann_disk.index_medoids.bin" \
|
|
||||||
"indices/rpj_wiki/facebook/contriever-msmarco/diskann/ann_disk.index_centroids.bin" \
|
|
||||||
"indices/rpj_wiki/facebook/contriever-msmarco/diskann/ann_disk.index_max_base_norm.bin" \
|
|
||||||
"embeddings/facebook/contriever-msmarco/rpj_wiki/compressed_10/" \
|
|
||||||
"passages/rpj_wiki/8-shards/" \
|
|
||||||
"indices/rpj_wiki/facebook/contriever-msmarco/flat_results_nq_k3.json"
|
|
||||||
|
|
||||||
# Download each path using a for loop
|
|
||||||
for path in $paths
|
|
||||||
echo "Downloading $path..."
|
|
||||||
# if ends with /, then create the directory
|
|
||||||
if string match -q "*/" $path
|
|
||||||
echo "Creating directory $scaling_out_dir/$path"
|
|
||||||
mkdir -p "$scaling_out_dir/$path"
|
|
||||||
aws s3 cp "s3://retrieval-scaling-out/$path" "$scaling_out_dir/$path" --recursive
|
|
||||||
else
|
|
||||||
aws s3 cp "s3://retrieval-scaling-out/$path" "$scaling_out_dir/$path"
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
echo "Download completed."
|
|
||||||
@@ -1,422 +0,0 @@
|
|||||||
import torch
|
|
||||||
import numpy as np
|
|
||||||
from transformers import AutoTokenizer, AutoModel
|
|
||||||
import pandas as pd
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
import seaborn as sns
|
|
||||||
from tqdm import tqdm
|
|
||||||
from scipy.stats import kendalltau, spearmanr
|
|
||||||
|
|
||||||
# 设置设备
|
|
||||||
device = torch.device("cuda" if torch.cuda.is_available() else
|
|
||||||
("mps" if torch.backends.mps.is_available() else "cpu"))
|
|
||||||
print(f"使用设备: {device}")
|
|
||||||
|
|
||||||
# 定义自定义比较函数(基于内积)
|
|
||||||
def compare(a, b):
|
|
||||||
"""
|
|
||||||
计算两个向量的内积,并返回其负值作为距离度量
|
|
||||||
数值越小表示越相似(与提供的代码一致)
|
|
||||||
"""
|
|
||||||
result = np.dot(a, b)
|
|
||||||
return -result # 返回负值,与原代码一致
|
|
||||||
|
|
||||||
# 批量计算相似度
|
|
||||||
def compute_similarities(queries, corpus):
|
|
||||||
"""计算查询向量与语料库向量之间的相似度矩阵"""
|
|
||||||
similarities = np.zeros((len(queries), len(corpus)))
|
|
||||||
for i, query in enumerate(queries):
|
|
||||||
for j, doc in enumerate(corpus):
|
|
||||||
similarities[i, j] = compare(query, doc)
|
|
||||||
return similarities
|
|
||||||
|
|
||||||
# 加载两个模型
|
|
||||||
model_names = [
|
|
||||||
"facebook/contriever-msmarco", # Contriever模型
|
|
||||||
"facebook/contriever-msmarco-int4" # Contriever模型 (int4)
|
|
||||||
]
|
|
||||||
|
|
||||||
# 扩展的样本文本 - 分为多个主题组
|
|
||||||
texts = [
|
|
||||||
# 组1: 关于狐狸和动物 (0-9)
|
|
||||||
"The quick brown fox jumps over the lazy dog.",
|
|
||||||
"A rapid auburn fox leaps above the inactive canine.",
|
|
||||||
"The sly fox outsmarts the hunting hounds in the forest.",
|
|
||||||
"Foxes are known for their cunning behavior and bushy tails.",
|
|
||||||
"The red fox is the largest of the true foxes and the most common fox species.",
|
|
||||||
"Dogs have been companions to humans for thousands of years.",
|
|
||||||
"The lazy dog slept through the commotion of the playful fox.",
|
|
||||||
"Wolves and foxes belong to the same family, Canidae.",
|
|
||||||
"The arctic fox changes its coat color with the seasons.",
|
|
||||||
"Domestic dogs come in hundreds of breeds of various sizes and appearances.",
|
|
||||||
|
|
||||||
# 组2: 人工智能和机器学习 (10-19)
|
|
||||||
"Machine learning is a branch of artificial intelligence.",
|
|
||||||
"Deep learning is a subset of machine learning.",
|
|
||||||
"Neural networks are computing systems inspired by biological neural networks.",
|
|
||||||
"AI systems can now beat human champions at complex games like chess and Go.",
|
|
||||||
"Natural language processing allows computers to understand human language.",
|
|
||||||
"Reinforcement learning involves training agents to make sequences of decisions.",
|
|
||||||
"Computer vision enables machines to derive information from images and videos.",
|
|
||||||
"The Turing test measures a machine's ability to exhibit intelligent behavior.",
|
|
||||||
"Supervised learning uses labeled training data to learn the mapping function.",
|
|
||||||
"Unsupervised learning finds patterns in data without pre-existing labels.",
|
|
||||||
|
|
||||||
# 组3: 巴黎和法国地标 (20-29)
|
|
||||||
"The Eiffel Tower is located in Paris, France.",
|
|
||||||
"The Louvre Museum is in the city of Paris.",
|
|
||||||
"Notre-Dame Cathedral is a medieval Catholic cathedral on the Île de la Cité in Paris.",
|
|
||||||
"The Arc de Triomphe stands at the center of the Place Charles de Gaulle in Paris.",
|
|
||||||
"The Seine River flows through the heart of Paris.",
|
|
||||||
"Montmartre is a large hill in Paris's 18th arrondissement known for its artistic history.",
|
|
||||||
"The Palace of Versailles is located in the Île-de-France region of France.",
|
|
||||||
"The Champs-Élysées is an avenue in Paris famous for its theatres, cafés, and luxury shops.",
|
|
||||||
"The Sacré-Cœur Basilica offers one of the most beautiful panoramic views of Paris.",
|
|
||||||
"The Musée d'Orsay houses the largest collection of impressionist masterpieces in the world.",
|
|
||||||
|
|
||||||
# 组4: 可再生能源 (30-39)
|
|
||||||
"Solar panels convert sunlight into electricity.",
|
|
||||||
"Wind turbines generate power from moving air.",
|
|
||||||
"Hydroelectric power is generated from flowing water.",
|
|
||||||
"Geothermal energy harnesses heat from within the Earth.",
|
|
||||||
"Biomass energy comes from organic materials like plants and waste.",
|
|
||||||
"Tidal energy uses the natural rise and fall of coastal tidal waters.",
|
|
||||||
"Renewable energy sources can help reduce greenhouse gas emissions.",
|
|
||||||
"Solar farms can span hundreds of acres with thousands of panels.",
|
|
||||||
"Offshore wind farms are built in bodies of water to harvest wind energy.",
|
|
||||||
"Energy storage systems are crucial for balancing renewable energy supply and demand.",
|
|
||||||
|
|
||||||
# 组5: 编程语言 (40-49)
|
|
||||||
"Python is a popular programming language for data science.",
|
|
||||||
"JavaScript is commonly used for web development.",
|
|
||||||
"Java is known for its 'write once, run anywhere' capability.",
|
|
||||||
"C++ provides high-performance and close hardware control.",
|
|
||||||
"Ruby is praised for its simplicity and productivity.",
|
|
||||||
"PHP is a server-side scripting language designed for web development.",
|
|
||||||
"Swift is used to develop applications for Apple platforms.",
|
|
||||||
"Rust offers memory safety without using garbage collection.",
|
|
||||||
"Go was designed at Google to improve programming productivity.",
|
|
||||||
"Kotlin is fully interoperable with Java and provides more concise syntax.",
|
|
||||||
]
|
|
||||||
|
|
||||||
# 扩展的查询句子
|
|
||||||
query_texts = [
|
|
||||||
# 动物相关查询
|
|
||||||
"A fox jumped over a dog.",
|
|
||||||
"Wild animals and their behaviors in forests.",
|
|
||||||
"Different species of foxes around the world.",
|
|
||||||
|
|
||||||
# AI相关查询
|
|
||||||
"Artificial intelligence and neural networks.",
|
|
||||||
"Machine learning algorithms and applications.",
|
|
||||||
"The future of deep learning technology.",
|
|
||||||
|
|
||||||
# 巴黎相关查询
|
|
||||||
"Famous landmarks in Paris, France.",
|
|
||||||
"Tourist attractions along the Seine River.",
|
|
||||||
"Historical buildings and museums in Paris.",
|
|
||||||
|
|
||||||
# 能源相关查询
|
|
||||||
"Renewable energy sources and sustainability.",
|
|
||||||
"Solar and wind power generation technologies.",
|
|
||||||
"Alternative clean energy solutions for the future.",
|
|
||||||
|
|
||||||
# 编程相关查询
|
|
||||||
"Computer programming languages comparison.",
|
|
||||||
"Best languages for web development.",
|
|
||||||
"Programming tools for data science applications."
|
|
||||||
]
|
|
||||||
|
|
||||||
# 函数:获取BGE模型的嵌入
|
|
||||||
def get_bge_embeddings(model, tokenizer, texts, device):
|
|
||||||
# 处理大量文本时分批进行
|
|
||||||
batch_size = 16
|
|
||||||
all_embeddings = []
|
|
||||||
|
|
||||||
for i in range(0, len(texts), batch_size):
|
|
||||||
batch_texts = texts[i:i+batch_size]
|
|
||||||
encoded_input = tokenizer(batch_texts, padding=True, truncation=True,
|
|
||||||
max_length=512, return_tensors='pt').to(device)
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
model_output = model(**encoded_input)
|
|
||||||
|
|
||||||
# BGE使用[CLS]标记
|
|
||||||
embeddings = model_output.last_hidden_state[:, 0]
|
|
||||||
# 归一化嵌入
|
|
||||||
normalized_embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
|
|
||||||
all_embeddings.append(normalized_embeddings.cpu().numpy())
|
|
||||||
|
|
||||||
return np.vstack(all_embeddings)
|
|
||||||
|
|
||||||
# 函数:获取Contriever模型的嵌入
|
|
||||||
def get_contriever_embeddings(model, tokenizer, texts, device, use_int4=False):
|
|
||||||
# 处理大量文本时分批进行
|
|
||||||
batch_size = 16
|
|
||||||
all_embeddings = []
|
|
||||||
|
|
||||||
for i in range(0, len(texts), batch_size):
|
|
||||||
batch_texts = texts[i:i+batch_size]
|
|
||||||
encoded_input = tokenizer(batch_texts, padding=True, truncation=True,
|
|
||||||
max_length=512, return_tensors='pt').to(device)
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
model_output = model(**encoded_input)
|
|
||||||
|
|
||||||
# Contriever使用平均池化
|
|
||||||
attention_mask = encoded_input['attention_mask'].unsqueeze(-1)
|
|
||||||
embeddings = (model_output.last_hidden_state * attention_mask).sum(1) / attention_mask.sum(1)
|
|
||||||
# 归一化嵌入
|
|
||||||
normalized_embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
|
|
||||||
all_embeddings.append(normalized_embeddings.cpu().numpy())
|
|
||||||
|
|
||||||
return np.vstack(all_embeddings)
|
|
||||||
|
|
||||||
# 主函数
|
|
||||||
def compare_embeddings():
|
|
||||||
results = {}
|
|
||||||
|
|
||||||
for i, model_name in enumerate(model_names):
|
|
||||||
model_display_name = model_name
|
|
||||||
# 给第二个模型一个不同的显示名称,以便区分
|
|
||||||
if i == 1:
|
|
||||||
model_display_name = "facebook/contriever-msmarco-int4"
|
|
||||||
|
|
||||||
print(f"\n======= 加载模型 {i+1}: {model_display_name} =======")
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(model_names[0]) # 两个模型使用相同的tokenizer
|
|
||||||
|
|
||||||
# 如果是第二个模型(int4版本),进行量化
|
|
||||||
if i == 1:
|
|
||||||
print("应用int4量化...")
|
|
||||||
try:
|
|
||||||
from transformers import BitsAndBytesConfig
|
|
||||||
quantization_config = BitsAndBytesConfig(
|
|
||||||
load_in_4bit=True,
|
|
||||||
bnb_4bit_compute_dtype=torch.float16,
|
|
||||||
bnb_4bit_use_double_quant=True,
|
|
||||||
bnb_4bit_quant_type="nf4"
|
|
||||||
)
|
|
||||||
model = AutoModel.from_pretrained(
|
|
||||||
model_names[0], # 使用相同的基础模型
|
|
||||||
quantization_config=quantization_config,
|
|
||||||
device_map="auto"
|
|
||||||
)
|
|
||||||
print("成功加载int4模型")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"int4加载失败: {e}")
|
|
||||||
print("回退到标准模型...")
|
|
||||||
model = AutoModel.from_pretrained(model_names[0]).to(device)
|
|
||||||
else:
|
|
||||||
model = AutoModel.from_pretrained(model_names[0]).to(device)
|
|
||||||
|
|
||||||
model.eval()
|
|
||||||
|
|
||||||
print(f"使用 {model_display_name} 生成嵌入...")
|
|
||||||
# 所有模型都使用contriever
|
|
||||||
use_int4 = i == 1
|
|
||||||
corpus_embeddings = get_contriever_embeddings(model, tokenizer, texts, device, use_int4)
|
|
||||||
query_embeddings = get_contriever_embeddings(model, tokenizer, query_texts, device, use_int4)
|
|
||||||
|
|
||||||
print(f"语料库嵌入形状: {corpus_embeddings.shape}")
|
|
||||||
print(f"查询嵌入形状: {query_embeddings.shape}")
|
|
||||||
|
|
||||||
# 使用自定义函数计算相似度
|
|
||||||
similarity_scores = compute_similarities(query_embeddings, corpus_embeddings)
|
|
||||||
|
|
||||||
# 对每个查询,按相似度排序文本索引(较小的值表示更相似)
|
|
||||||
ranked_indices = {}
|
|
||||||
for j, scores in enumerate(similarity_scores):
|
|
||||||
# 按相似度从低到高排序(因为我们返回的是负内积值)
|
|
||||||
sorted_indices = np.argsort(scores)
|
|
||||||
ranked_indices[f"query_{j+1}"] = sorted_indices
|
|
||||||
|
|
||||||
results[model_display_name] = {
|
|
||||||
'corpus_embeddings': corpus_embeddings,
|
|
||||||
'query_embeddings': query_embeddings,
|
|
||||||
'similarity_scores': similarity_scores,
|
|
||||||
'ranked_indices': ranked_indices
|
|
||||||
}
|
|
||||||
|
|
||||||
# 立即打印这个模型的一些结果作为验证
|
|
||||||
print(f"\n=== {model_display_name} 初步结果 ===")
|
|
||||||
# 显示第一个查询的前3个结果
|
|
||||||
query_idx = 0
|
|
||||||
ranked_idx = ranked_indices[f"query_{query_idx+1}"]
|
|
||||||
top_texts = [texts[idx] for idx in ranked_idx[:3]]
|
|
||||||
print(f"查询: '{query_texts[query_idx]}'")
|
|
||||||
print(f"排名前3位的文本:")
|
|
||||||
for j, text in enumerate(top_texts):
|
|
||||||
idx = ranked_idx[j]
|
|
||||||
score = similarity_scores[query_idx][idx]
|
|
||||||
print(f" {j+1}. [ID:{idx}] {text} (分数: {score:.4f})")
|
|
||||||
|
|
||||||
return results
|
|
||||||
|
|
||||||
# 分析结果
|
|
||||||
def analyze_results(results):
|
|
||||||
models = list(results.keys())
|
|
||||||
|
|
||||||
# 1. 比较相似度分数
|
|
||||||
print("\n=== 相似度分数比较 ===")
|
|
||||||
for model_name, result in results.items():
|
|
||||||
similarities = result['similarity_scores'].flatten()
|
|
||||||
print(f"{model_name} 相似度统计:")
|
|
||||||
print(f" 平均值: {similarities.mean():.4f}")
|
|
||||||
print(f" 最小值: {similarities.min():.4f}")
|
|
||||||
print(f" 最大值: {similarities.max():.4f}")
|
|
||||||
print(f" 标准差: {similarities.std():.4f}")
|
|
||||||
|
|
||||||
# 2. 比较排序结果(针对每个查询显示前5个结果)
|
|
||||||
print("\n=== 排序结果比较 ===")
|
|
||||||
for query_idx in range(len(query_texts)):
|
|
||||||
query_key = f"query_{query_idx+1}"
|
|
||||||
print(f"\n查询 {query_idx+1}: '{query_texts[query_idx]}'")
|
|
||||||
|
|
||||||
for model_name in models:
|
|
||||||
ranked_idx = results[model_name]['ranked_indices'][query_key]
|
|
||||||
top_texts = [texts[idx] for idx in ranked_idx[:5]]
|
|
||||||
print(f"{model_name} 排名前5位的文本:")
|
|
||||||
for i, text in enumerate(top_texts):
|
|
||||||
idx = ranked_idx[i]
|
|
||||||
score = results[model_name]['similarity_scores'][query_idx][idx]
|
|
||||||
print(f" {i+1}. [ID:{idx}] {text} (分数: {score:.4f})")
|
|
||||||
|
|
||||||
# 3. 排序一致性分析
|
|
||||||
print("\n=== 模型间排序一致性分析 ===")
|
|
||||||
kendall_tau_scores = []
|
|
||||||
spearman_scores = []
|
|
||||||
|
|
||||||
for query_idx in range(len(query_texts)):
|
|
||||||
query_key = f"query_{query_idx+1}"
|
|
||||||
|
|
||||||
# 获取各模型的排序结果(只比较前10个结果)
|
|
||||||
model1_top10 = results[models[0]]['ranked_indices'][query_key][:10]
|
|
||||||
model2_top10 = results[models[1]]['ranked_indices'][query_key][:10]
|
|
||||||
|
|
||||||
# 计算排序一致性
|
|
||||||
kt, _ = kendalltau(model1_top10, model2_top10)
|
|
||||||
sr, _ = spearmanr(model1_top10, model2_top10)
|
|
||||||
|
|
||||||
kendall_tau_scores.append(kt)
|
|
||||||
spearman_scores.append(sr)
|
|
||||||
|
|
||||||
# 计算前10个结果的重叠率
|
|
||||||
overlap = len(set(model1_top10) & set(model2_top10))
|
|
||||||
overlap_rate = overlap / 10.0
|
|
||||||
|
|
||||||
print(f"查询 {query_idx+1} '{query_texts[query_idx]}':")
|
|
||||||
print(f" Kendall's Tau = {kt:.4f}, Spearman's rank correlation = {sr:.4f}")
|
|
||||||
print(f" 前10结果重叠率: {overlap_rate:.2f} ({overlap}/10)")
|
|
||||||
|
|
||||||
print(f"\n平均 Kendall's Tau: {np.mean(kendall_tau_scores):.4f}")
|
|
||||||
print(f"平均 Spearman's rank correlation: {np.mean(spearman_scores):.4f}")
|
|
||||||
|
|
||||||
# 4. 可视化相似度分布差异
|
|
||||||
plt.figure(figsize=(12, 6))
|
|
||||||
for i, model_name in enumerate(models):
|
|
||||||
sns.histplot(results[model_name]['similarity_scores'].flatten(),
|
|
||||||
kde=True, label=model_name, alpha=0.6)
|
|
||||||
|
|
||||||
plt.title('不同模型的相似度分布')
|
|
||||||
plt.xlabel('相似度得分(越小越相似)')
|
|
||||||
plt.ylabel('频率')
|
|
||||||
plt.legend()
|
|
||||||
plt.savefig('similarity_distribution.png')
|
|
||||||
print("已保存相似度分布图表到 'similarity_distribution.png'")
|
|
||||||
|
|
||||||
# 5. 可视化主题相关性
|
|
||||||
plt.figure(figsize=(15, 10))
|
|
||||||
|
|
||||||
# 为每个主题组定义颜色
|
|
||||||
topic_colors = {
|
|
||||||
'动物': 'blue',
|
|
||||||
'AI': 'red',
|
|
||||||
'巴黎': 'green',
|
|
||||||
'能源': 'purple',
|
|
||||||
'编程': 'orange'
|
|
||||||
}
|
|
||||||
|
|
||||||
# 定义主题组范围
|
|
||||||
topic_ranges = {
|
|
||||||
'动物': (0, 10),
|
|
||||||
'AI': (10, 20),
|
|
||||||
'巴黎': (20, 30),
|
|
||||||
'能源': (30, 40),
|
|
||||||
'编程': (40, 50)
|
|
||||||
}
|
|
||||||
|
|
||||||
# 对每个查询显示前10个结果的主题分布
|
|
||||||
query_groups = [
|
|
||||||
[0, 1, 2], # 动物查询组
|
|
||||||
[3, 4, 5], # AI查询组
|
|
||||||
[6, 7, 8], # 巴黎查询组
|
|
||||||
[9, 10, 11], # 能源查询组
|
|
||||||
[12, 13, 14] # 编程查询组
|
|
||||||
]
|
|
||||||
|
|
||||||
for group_idx, group in enumerate(query_groups):
|
|
||||||
plt.subplot(len(query_groups), 1, group_idx+1)
|
|
||||||
|
|
||||||
# 为每个模型计算主题分布
|
|
||||||
bar_width = 0.35
|
|
||||||
bar_positions = np.arange(len(topic_ranges))
|
|
||||||
|
|
||||||
for model_idx, model_name in enumerate(models):
|
|
||||||
# 统计每个主题在前10个结果中的出现次数
|
|
||||||
topic_counts = {topic: 0 for topic in topic_ranges.keys()}
|
|
||||||
|
|
||||||
for query_idx in group:
|
|
||||||
query_key = f"query_{query_idx+1}"
|
|
||||||
top10 = results[model_name]['ranked_indices'][query_key][:10]
|
|
||||||
|
|
||||||
for idx in top10:
|
|
||||||
for topic, (start, end) in topic_ranges.items():
|
|
||||||
if start <= idx < end:
|
|
||||||
topic_counts[topic] += 1
|
|
||||||
|
|
||||||
# 绘制主题分布柱状图
|
|
||||||
plt.bar(bar_positions + (model_idx * bar_width),
|
|
||||||
list(topic_counts.values()),
|
|
||||||
bar_width,
|
|
||||||
label=model_name)
|
|
||||||
|
|
||||||
plt.title(f"查询组 {group_idx+1}: {', '.join([query_texts[i] for i in group[:1]])}")
|
|
||||||
plt.xticks(bar_positions + bar_width/2, list(topic_ranges.keys()))
|
|
||||||
plt.ylabel('前10结果中的出现次数')
|
|
||||||
plt.legend()
|
|
||||||
|
|
||||||
plt.tight_layout()
|
|
||||||
plt.savefig('topic_distribution.png')
|
|
||||||
print("已保存主题分布图表到 'topic_distribution.png'")
|
|
||||||
|
|
||||||
# 6. 可视化查询与相关文档的相似度热图
|
|
||||||
plt.figure(figsize=(15, 12))
|
|
||||||
|
|
||||||
for i, model_name in enumerate(models):
|
|
||||||
plt.subplot(2, 1, i+1)
|
|
||||||
|
|
||||||
# 获取相似度矩阵(负数越小表示越相似)
|
|
||||||
sim_matrix = results[model_name]['similarity_scores']
|
|
||||||
|
|
||||||
# 将负值转换为正值以便可视化(越大表示越相似)
|
|
||||||
sim_matrix_viz = -sim_matrix
|
|
||||||
|
|
||||||
# 绘制热图
|
|
||||||
sns.heatmap(sim_matrix_viz, cmap='YlGnBu',
|
|
||||||
xticklabels=[f"Doc{i}" for i in range(len(texts))],
|
|
||||||
yticklabels=[f"Q{i+1}" for i in range(len(query_texts))],
|
|
||||||
cbar_kws={'label': '相似度(越高越相似)'})
|
|
||||||
|
|
||||||
plt.title(f"{model_name} 相似度热图")
|
|
||||||
plt.xlabel('文档ID')
|
|
||||||
plt.ylabel('查询ID')
|
|
||||||
|
|
||||||
plt.tight_layout()
|
|
||||||
plt.savefig('similarity_heatmap.png')
|
|
||||||
print("已保存相似度热图到 'similarity_heatmap.png'")
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
print("开始比较嵌入模型...")
|
|
||||||
results = compare_embeddings()
|
|
||||||
analyze_results(results)
|
|
||||||
print("\n比较完成!")
|
|
||||||
@@ -1,444 +0,0 @@
|
|||||||
# Filename: evaluate_results_xai_line_sync.py
|
|
||||||
import openai
|
|
||||||
import json
|
|
||||||
import os
|
|
||||||
import time
|
|
||||||
from dotenv import load_dotenv
|
|
||||||
from tqdm import tqdm
|
|
||||||
from collections import defaultdict
|
|
||||||
import concurrent.futures
|
|
||||||
from typing import List, Dict, Any, Tuple
|
|
||||||
|
|
||||||
# --- Configuration ---
|
|
||||||
load_dotenv()
|
|
||||||
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
|
|
||||||
if not OPENAI_API_KEY:
|
|
||||||
raise ValueError("Please set the OPENAI_API_KEY in your .env file")
|
|
||||||
|
|
||||||
try:
|
|
||||||
client = openai.OpenAI(
|
|
||||||
api_key=OPENAI_API_KEY,
|
|
||||||
)
|
|
||||||
except ImportError:
|
|
||||||
print("Please install the latest OpenAI library: pip install --upgrade openai")
|
|
||||||
exit()
|
|
||||||
except openai.AuthenticationError:
|
|
||||||
print("OpenAI library reported an AuthenticationError. Ensure OPENAI_API_KEY is correct.")
|
|
||||||
exit()
|
|
||||||
|
|
||||||
LLM_MODEL = "gpt-3.5-turbo" # Using OpenAI's standard model
|
|
||||||
MAX_RETRIES = 5
|
|
||||||
INITIAL_RETRY_DELAY_SECONDS = 5
|
|
||||||
REQUEST_TIMEOUT_SECONDS = 90
|
|
||||||
MAX_WORKERS = 10 # Number of parallel workers
|
|
||||||
|
|
||||||
# --- File Paths (Adjust as needed) ---
|
|
||||||
# User provided paths
|
|
||||||
QUERIES_FILE_PATH = "/opt/dlami/nvme/scaling_out/examples/enron_eval_retrieval.jsonl"
|
|
||||||
RAW_PASSAGES_FILE_PATH = "/opt/dlami/nvme/scaling_out/passages/enron_emails/1-shards/raw_passages-0-of-1.jsonl"
|
|
||||||
RESULTS_FILE_PATH = "search_results_top10_bm25.jsonl" # This file's Nth line corresponds to QUERIES_FILE_PATH's Nth line
|
|
||||||
OUTPUT_EVALUATION_FILE = "llm_containment_evaluations_xai_line_sync.jsonl"
|
|
||||||
|
|
||||||
# --- LLM Prompt Definitions for Containment (Same as before) ---
|
|
||||||
CONTAINMENT_SYSTEM_PROMPT = """You are an AI evaluator. Your task is to determine if the core information presented in the 'Retrieved Passage' is directly contained within *any* of the text snippets provided in the 'Ground Truth Email Snippets' list."""
|
|
||||||
CONTAINMENT_USER_TEMPLATE = """Retrieved Passage:
|
|
||||||
"{retrieved_passage_text}"
|
|
||||||
|
|
||||||
---
|
|
||||||
Ground Truth Email Snippets (Parts of the correct source email):
|
|
||||||
{ground_truth_snippets_formatted_list}
|
|
||||||
---
|
|
||||||
|
|
||||||
Is the core information of the 'Retrieved Passage' directly present or fully contained within *any* of the 'Ground Truth Email Snippets' listed above?
|
|
||||||
- Focus on whether the specific facts or statements in the 'Retrieved Passage' can be found within the ground truth snippets.
|
|
||||||
- Ignore minor formatting differences. If the retrieved passage is a direct quote or a very close paraphrase of content within the ground truth snippets, answer YES.
|
|
||||||
- Respond YES if the Retrieved Passage's content is clearly represented in one or more of the ground truth snippets.
|
|
||||||
- Respond NO if the Retrieved Passage's content is not found, is contradictory, or introduces significant information not present in the ground truth snippets.
|
|
||||||
|
|
||||||
Your response must be a single word: YES or NO.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# --- Data Loading Functions ---
|
|
||||||
|
|
||||||
def load_queries_as_list(file_path):
|
|
||||||
"""
|
|
||||||
Loads queries from a jsonl file into a list, preserving order.
|
|
||||||
Each item in the list is a dict containing original_id, query_text, and ground_truth_message_ids.
|
|
||||||
"""
|
|
||||||
queries_list = []
|
|
||||||
try:
|
|
||||||
with open(file_path, 'r', encoding='utf-8') as f:
|
|
||||||
for line_num, line in enumerate(f):
|
|
||||||
try:
|
|
||||||
data = json.loads(line)
|
|
||||||
required_keys = ["id", "query", "ground_truth_message_ids"]
|
|
||||||
if not all(key in data for key in required_keys):
|
|
||||||
print(f"Warning: Skipping line {line_num + 1} in query file due to missing keys: {line.strip()}")
|
|
||||||
continue
|
|
||||||
if not isinstance(data["ground_truth_message_ids"], list):
|
|
||||||
print(f"Warning: 'ground_truth_message_ids' is not a list in line {line_num + 1}. Skipping: {line.strip()}")
|
|
||||||
continue
|
|
||||||
queries_list.append({
|
|
||||||
"original_id": data["id"], # Store the original ID from the file
|
|
||||||
"query_text": data["query"],
|
|
||||||
"ground_truth_message_ids": data["ground_truth_message_ids"]
|
|
||||||
})
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
print(f"Warning: Skipping malformed JSON line {line_num + 1} in query file: {line.strip()}")
|
|
||||||
except FileNotFoundError:
|
|
||||||
print(f"Error: Queries file not found at {file_path}")
|
|
||||||
exit()
|
|
||||||
print(f"Loaded {len(queries_list)} queries (as a list) from {file_path}")
|
|
||||||
return queries_list
|
|
||||||
|
|
||||||
def load_all_passages_by_message_id(raw_passages_file_path):
|
|
||||||
"""Loads all raw passages into memory, grouped by message_id. (Same as before)"""
|
|
||||||
passages_dict = defaultdict(list)
|
|
||||||
# ... (implementation from previous script, no changes needed here) ...
|
|
||||||
print(f"Loading all raw passages from {raw_passages_file_path} into memory...")
|
|
||||||
try:
|
|
||||||
with open(raw_passages_file_path, 'r', encoding='utf-8') as f:
|
|
||||||
for line_num, line in enumerate(f):
|
|
||||||
try:
|
|
||||||
data = json.loads(line)
|
|
||||||
if "message_id" in data and "text" in data:
|
|
||||||
passages_dict[data["message_id"]].append(data["text"])
|
|
||||||
else:
|
|
||||||
print(f"Warning: Skipping line {line_num+1} in raw passages file due to missing 'message_id' or 'text'.")
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
print(f"Warning: Skipping malformed JSON line {line_num + 1} in raw passages file: {line.strip()}")
|
|
||||||
print(f"Finished loading raw passages. Found {len(passages_dict)} unique message IDs.")
|
|
||||||
except FileNotFoundError:
|
|
||||||
print(f"Error: Raw passages file not found at {raw_passages_file_path}")
|
|
||||||
exit()
|
|
||||||
except MemoryError:
|
|
||||||
print("Error: Ran out of memory loading all raw passages. Consider an indexed approach.")
|
|
||||||
exit()
|
|
||||||
return dict(passages_dict)
|
|
||||||
|
|
||||||
def load_search_results_as_list(file_path):
|
|
||||||
"""Loads search results from a jsonl file into a list, preserving order."""
|
|
||||||
results_list = []
|
|
||||||
# ... (implementation similar to load_queries_as_list, parsing each line as a dict) ...
|
|
||||||
try:
|
|
||||||
with open(file_path, 'r', encoding='utf-8') as f:
|
|
||||||
for line_num, line in enumerate(f):
|
|
||||||
try:
|
|
||||||
data = json.loads(line)
|
|
||||||
# We expect "query_id" (though not used for matching) and "passages"
|
|
||||||
if "passages" not in data: # query_id might be implicitly by order
|
|
||||||
print(f"Warning: Skipping line {line_num + 1} in search results file due to missing 'passages' key: {line.strip()}")
|
|
||||||
continue
|
|
||||||
results_list.append(data)
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
print(f"Warning: Skipping malformed JSON line {line_num + 1} in search results file: {line.strip()}")
|
|
||||||
except FileNotFoundError:
|
|
||||||
print(f"Error: Search results file not found at {file_path}")
|
|
||||||
exit()
|
|
||||||
print(f"Loaded {len(results_list)} search result sets (as a list) from {file_path}")
|
|
||||||
return results_list
|
|
||||||
|
|
||||||
|
|
||||||
def format_ground_truth_snippets(snippet_list):
|
|
||||||
"""Formats the list of ground truth snippets for the prompt. (Same as before)"""
|
|
||||||
# ... (implementation from previous script) ...
|
|
||||||
if not snippet_list:
|
|
||||||
return " [No ground truth snippets found for the target message ID(s)]"
|
|
||||||
formatted = []
|
|
||||||
for i, snippet in enumerate(snippet_list):
|
|
||||||
display_snippet = (snippet[:500] + '...') if len(snippet) > 500 else snippet
|
|
||||||
formatted.append(f" {i+1}. {display_snippet}")
|
|
||||||
return "\n".join(formatted)
|
|
||||||
|
|
||||||
# --- LLM API Call Function ---
|
|
||||||
def get_llm_containment_evaluation(retrieved_passage_text: str, ground_truth_snippets_list: List[str], query_id_for_log: str, passage_identifier_info: str, query_text_for_context: str = "") -> str:
|
|
||||||
"""Calls the OpenAI API with retry logic."""
|
|
||||||
formatted_gt_snippets = format_ground_truth_snippets(ground_truth_snippets_list)
|
|
||||||
# max_gt_chars_in_prompt = 5000 # Arbitrary limit, adjust as needed
|
|
||||||
# if len(formatted_gt_snippets) > max_gt_chars_in_prompt:
|
|
||||||
# print(f"Warning: Ground truth snippets for Q_log_id:{query_id_for_log} are too long ({len(formatted_gt_snippets)} chars), truncating for LLM prompt.")
|
|
||||||
# formatted_gt_snippets = formatted_gt_snippets[:max_gt_chars_in_prompt] + "\n [... Snippets Truncated ...]"
|
|
||||||
|
|
||||||
user_prompt = CONTAINMENT_USER_TEMPLATE.format(
|
|
||||||
retrieved_passage_text=retrieved_passage_text,
|
|
||||||
ground_truth_snippets_formatted_list=formatted_gt_snippets
|
|
||||||
)
|
|
||||||
messages = [
|
|
||||||
{"role": "system", "content": CONTAINMENT_SYSTEM_PROMPT},
|
|
||||||
{"role": "user", "content": user_prompt}
|
|
||||||
]
|
|
||||||
|
|
||||||
current_retry_delay = INITIAL_RETRY_DELAY_SECONDS
|
|
||||||
for attempt in range(MAX_RETRIES):
|
|
||||||
try:
|
|
||||||
response = client.chat.completions.create(
|
|
||||||
model=LLM_MODEL,
|
|
||||||
messages=messages,
|
|
||||||
temperature=0.0,
|
|
||||||
max_tokens=10,
|
|
||||||
timeout=REQUEST_TIMEOUT_SECONDS
|
|
||||||
)
|
|
||||||
answer = response.choices[0].message.content.strip().upper()
|
|
||||||
if answer in ["YES", "NO"]:
|
|
||||||
return answer
|
|
||||||
else:
|
|
||||||
print(f"Warning: Unexpected LLM response content '{answer[:100]}' for Q_log_id:{query_id_for_log} P:{passage_identifier_info}. Defaulting to NO.")
|
|
||||||
return "NO"
|
|
||||||
except openai.APIConnectionError as e:
|
|
||||||
error_message = f"API Connection Error (Attempt {attempt + 1}/{MAX_RETRIES}): {e}"
|
|
||||||
except openai.RateLimitError as e:
|
|
||||||
error_message = f"API Rate Limit Error (Attempt {attempt + 1}/{MAX_RETRIES}): {e}"
|
|
||||||
except openai.APIStatusError as e:
|
|
||||||
error_message = f"API Status Error (Attempt {attempt + 1}/{MAX_RETRIES}): {e.status_code} - {e.response}"
|
|
||||||
if e.status_code == 401:
|
|
||||||
return "ERROR_AUTH"
|
|
||||||
if e.status_code == 500:
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
return "ERROR_API_CLIENT"
|
|
||||||
except Exception as e:
|
|
||||||
error_message = f"Unexpected error with OpenAI lib (Attempt {attempt + 1}/{MAX_RETRIES}): {type(e).__name__} - {e}"
|
|
||||||
|
|
||||||
print(f"{error_message}. Query Log ID: {query_id_for_log}, Passage: {passage_identifier_info}")
|
|
||||||
if "ERROR_AUTH" in error_message or "ERROR_API_CLIENT" in error_message:
|
|
||||||
break
|
|
||||||
|
|
||||||
if attempt < MAX_RETRIES - 1:
|
|
||||||
print(f"Retrying in {current_retry_delay} seconds...")
|
|
||||||
time.sleep(current_retry_delay)
|
|
||||||
current_retry_delay = min(current_retry_delay * 2, 60)
|
|
||||||
else:
|
|
||||||
print(f"Max retries ({MAX_RETRIES}) reached for Q_log_id:{query_id_for_log} P:{passage_identifier_info}. Skipping.")
|
|
||||||
return "ERROR_MAX_RETRIES"
|
|
||||||
return "ERROR_MAX_RETRIES"
|
|
||||||
|
|
||||||
def process_query_passage_pair(args: Tuple[Dict[str, Any], Dict[str, Any], Dict[str, List[str]], set]) -> List[Dict[str, Any]]:
|
|
||||||
"""Process a single query-passage pair for parallel execution."""
|
|
||||||
query_info, result_item, passages_lookup, already_evaluated = args
|
|
||||||
evaluations = []
|
|
||||||
|
|
||||||
query_original_id = query_info["original_id"]
|
|
||||||
query_text = query_info["query_text"]
|
|
||||||
target_message_ids = query_info.get("ground_truth_message_ids", [])
|
|
||||||
|
|
||||||
if not target_message_ids:
|
|
||||||
return evaluations
|
|
||||||
|
|
||||||
ground_truth_snippets = []
|
|
||||||
for msg_id_in_query_file in target_message_ids:
|
|
||||||
msg_id_to_lookup = msg_id_in_query_file
|
|
||||||
if msg_id_in_query_file.startswith("<") and msg_id_in_query_file.endswith(">"):
|
|
||||||
msg_id_to_lookup = msg_id_in_query_file[1:-1]
|
|
||||||
|
|
||||||
snippets = passages_lookup.get(msg_id_to_lookup)
|
|
||||||
if snippets:
|
|
||||||
ground_truth_snippets.extend(snippets)
|
|
||||||
|
|
||||||
if not ground_truth_snippets:
|
|
||||||
return evaluations
|
|
||||||
|
|
||||||
retrieved_passages = result_item.get("passages", [])
|
|
||||||
if not retrieved_passages:
|
|
||||||
return evaluations
|
|
||||||
|
|
||||||
for passage_idx, passage_obj in enumerate(retrieved_passages):
|
|
||||||
if not isinstance(passage_obj, dict):
|
|
||||||
print(f"Warning: Invalid passage format for Q_original_id:{query_original_id}, passage index {passage_idx}. Skipping passage.")
|
|
||||||
continue
|
|
||||||
|
|
||||||
retrieved_passage_text = passage_obj.get("text", "").strip()
|
|
||||||
passage_identifier = passage_obj.get("passage_id", passage_obj.get("id", f"retrieved_idx_{passage_idx}"))
|
|
||||||
|
|
||||||
evaluation_key = (query_original_id, passage_identifier)
|
|
||||||
if evaluation_key in already_evaluated:
|
|
||||||
continue
|
|
||||||
|
|
||||||
passage_text_preview = (retrieved_passage_text[:75] + '...') if len(retrieved_passage_text) > 75 else retrieved_passage_text
|
|
||||||
|
|
||||||
if not retrieved_passage_text:
|
|
||||||
evaluation = "NO"
|
|
||||||
else:
|
|
||||||
evaluation = get_llm_containment_evaluation(
|
|
||||||
retrieved_passage_text,
|
|
||||||
ground_truth_snippets,
|
|
||||||
query_original_id,
|
|
||||||
passage_identifier,
|
|
||||||
query_text
|
|
||||||
)
|
|
||||||
if evaluation == "ERROR_AUTH":
|
|
||||||
print("Authentication error with OpenAI API. Stopping script.")
|
|
||||||
return evaluations
|
|
||||||
|
|
||||||
evaluation_record = {
|
|
||||||
"query_original_id": query_original_id,
|
|
||||||
"passage_identifier": passage_identifier,
|
|
||||||
"passage_text_preview": passage_text_preview,
|
|
||||||
"evaluation": evaluation,
|
|
||||||
"model_used": LLM_MODEL,
|
|
||||||
"ground_truth_message_ids_checked": target_message_ids
|
|
||||||
}
|
|
||||||
evaluations.append(evaluation_record)
|
|
||||||
|
|
||||||
return evaluations
|
|
||||||
|
|
||||||
# --- Resume Logic ---
|
|
||||||
def load_existing_evaluations(output_file):
|
|
||||||
"""Loads already evaluated query-passage pairs using 'passage_identifier' and 'query_original_id'. (Same as before, but keying with original_id)"""
|
|
||||||
# ... (implementation from previous script, ensure it uses the correct ID for keys) ...
|
|
||||||
evaluated_pairs = set()
|
|
||||||
if os.path.exists(output_file):
|
|
||||||
print(f"Loading existing containment evaluations from {output_file}...")
|
|
||||||
with open(output_file, 'r', encoding='utf-8') as f:
|
|
||||||
for line_num, line in enumerate(f):
|
|
||||||
try:
|
|
||||||
data = json.loads(line)
|
|
||||||
# Key for resuming should be based on the logged original query ID
|
|
||||||
query_original_id = data.get('query_original_id')
|
|
||||||
passage_identifier = data.get('passage_identifier')
|
|
||||||
if query_original_id is not None and passage_identifier is not None:
|
|
||||||
evaluated_pairs.add((query_original_id, passage_identifier))
|
|
||||||
else:
|
|
||||||
print(f"Warning: Could not identify query_original_id/passage_identifier in existing file line {line_num + 1}.")
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
print(f"Warning: Skipping malformed line {line_num + 1} in existing file: {line.strip()}")
|
|
||||||
except KeyError as e:
|
|
||||||
print(f"Warning: Skipping line {line_num + 1} with missing key '{e}' in existing file: {line.strip()}")
|
|
||||||
print(f"Loaded {len(evaluated_pairs)} existing evaluation records.")
|
|
||||||
else:
|
|
||||||
print(f"No existing evaluation file found at {output_file}. Starting fresh.")
|
|
||||||
return evaluated_pairs
|
|
||||||
|
|
||||||
# --- Main Execution Logic ---
|
|
||||||
|
|
||||||
def main():
|
|
||||||
"""Main function to run the containment evaluation process using parallel processing."""
|
|
||||||
print(f"Starting containment evaluation using OpenAI model: {LLM_MODEL} via OpenAI library interface.")
|
|
||||||
|
|
||||||
# Load data as lists
|
|
||||||
queries_list = load_queries_as_list(QUERIES_FILE_PATH)
|
|
||||||
passages_lookup = load_all_passages_by_message_id(RAW_PASSAGES_FILE_PATH)
|
|
||||||
search_results_list = load_search_results_as_list(RESULTS_FILE_PATH)
|
|
||||||
|
|
||||||
if not queries_list or not search_results_list or not passages_lookup:
|
|
||||||
print("Error loading one or more input files or raw passages. Exiting.")
|
|
||||||
return
|
|
||||||
|
|
||||||
# Determine the number of items to process
|
|
||||||
num_items_to_process = min(len(queries_list), len(search_results_list))
|
|
||||||
print(f"Will process {num_items_to_process} query-result pairs.")
|
|
||||||
|
|
||||||
already_evaluated = load_existing_evaluations(OUTPUT_EVALUATION_FILE)
|
|
||||||
|
|
||||||
try:
|
|
||||||
with open(OUTPUT_EVALUATION_FILE, 'a', encoding='utf-8') as outfile:
|
|
||||||
# Prepare arguments for parallel processing
|
|
||||||
process_args = [
|
|
||||||
(queries_list[i], search_results_list[i], passages_lookup, already_evaluated)
|
|
||||||
for i in range(num_items_to_process)
|
|
||||||
]
|
|
||||||
|
|
||||||
# Use ThreadPoolExecutor for parallel processing
|
|
||||||
with concurrent.futures.ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
|
|
||||||
# Submit all tasks and get futures
|
|
||||||
futures = [executor.submit(process_query_passage_pair, args) for args in process_args]
|
|
||||||
|
|
||||||
# Process results as they complete
|
|
||||||
for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures), desc="Processing query-result pairs"):
|
|
||||||
try:
|
|
||||||
evaluations = future.result()
|
|
||||||
for evaluation in evaluations:
|
|
||||||
outfile.write(json.dumps(evaluation) + "\n")
|
|
||||||
outfile.flush()
|
|
||||||
# Update already_evaluated set
|
|
||||||
already_evaluated.add((evaluation["query_original_id"], evaluation["passage_identifier"]))
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error processing query-result pair: {e}")
|
|
||||||
|
|
||||||
except IOError as e:
|
|
||||||
print(f"Error writing to output file {OUTPUT_EVALUATION_FILE}: {e}")
|
|
||||||
return
|
|
||||||
except Exception as e:
|
|
||||||
print(f"An unexpected error occurred during the main processing loop: {e}")
|
|
||||||
return
|
|
||||||
|
|
||||||
print("\n--- Containment Evaluation Script Finished ---")
|
|
||||||
|
|
||||||
# --- Final Summary Calculation ---
|
|
||||||
print(f"Calculating final summary statistics from: {OUTPUT_EVALUATION_FILE}")
|
|
||||||
final_query_containment_found = {}
|
|
||||||
total_evaluated_pairs = 0
|
|
||||||
error_count = 0
|
|
||||||
evaluated_query_original_ids = set()
|
|
||||||
|
|
||||||
try:
|
|
||||||
with open(OUTPUT_EVALUATION_FILE, 'r', encoding='utf-8') as f:
|
|
||||||
for line_num, line in enumerate(f):
|
|
||||||
total_evaluated_pairs += 1
|
|
||||||
try:
|
|
||||||
data = json.loads(line)
|
|
||||||
q_original_id = data['query_original_id']
|
|
||||||
eval_result = data['evaluation']
|
|
||||||
evaluated_query_original_ids.add(q_original_id)
|
|
||||||
|
|
||||||
if eval_result == "YES":
|
|
||||||
final_query_containment_found[q_original_id] = True
|
|
||||||
elif q_original_id not in final_query_containment_found:
|
|
||||||
final_query_containment_found[q_original_id] = False
|
|
||||||
if eval_result not in ["YES", "NO"]:
|
|
||||||
error_count += 1
|
|
||||||
except (json.JSONDecodeError, KeyError) as e:
|
|
||||||
print(f"Error reading line {line_num + 1} during summary: {e} - Line: {line.strip()}")
|
|
||||||
error_count += 1
|
|
||||||
|
|
||||||
num_queries_with_any_contained = sum(1 for contained in final_query_containment_found.values() if contained)
|
|
||||||
total_unique_queries_evaluated = len(evaluated_query_original_ids)
|
|
||||||
|
|
||||||
if total_unique_queries_evaluated > 0:
|
|
||||||
containment_rate_at_10 = num_queries_with_any_contained / total_unique_queries_evaluated
|
|
||||||
print(f"\n--- Final Statistics (Containment Check) ---")
|
|
||||||
print(f"Total unique queries processed (based on output file entries): {total_unique_queries_evaluated}")
|
|
||||||
print(f"Number of queries with at least one contained passage (YES): {num_queries_with_any_contained}")
|
|
||||||
print(f"Containment Match Rate @ Top 10 (Any YES): {containment_rate_at_10:.4f}")
|
|
||||||
print(f"Total query-passage pairs processed (lines in output file): {total_evaluated_pairs}")
|
|
||||||
if error_count > 0:
|
|
||||||
print(f"Number of evaluation errors or non-YES/NO results: {error_count}")
|
|
||||||
else:
|
|
||||||
print("No evaluation results found to summarize.")
|
|
||||||
except FileNotFoundError:
|
|
||||||
print(f"Error: Output file {OUTPUT_EVALUATION_FILE} not found for summary.")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"An unexpected error occurred during summary calculation: {e}")
|
|
||||||
|
|
||||||
print(f"\nDetailed containment evaluations saved to: {OUTPUT_EVALUATION_FILE}")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
# Dummy files for testing the line sync logic
|
|
||||||
if not os.path.exists(QUERIES_FILE_PATH):
|
|
||||||
print(f"Warning: {QUERIES_FILE_PATH} not found. Creating dummy file.")
|
|
||||||
with open(QUERIES_FILE_PATH, 'w', encoding='utf-8') as f:
|
|
||||||
json.dump({"id": "q_alpha", "query": "Query Alpha Text", "ground_truth_message_ids": ["<msg_A>"]}, f); f.write("\n") # Line 0
|
|
||||||
json.dump({"id": "q_beta", "query": "Query Beta Text", "ground_truth_message_ids": ["<msg_B>"]}, f); f.write("\n") # Line 1
|
|
||||||
json.dump({"id": "q_gamma", "query": "Query Gamma Text", "ground_truth_message_ids": ["<msg_C>"]}, f); f.write("\n")# Line 2
|
|
||||||
|
|
||||||
if not os.path.exists(RAW_PASSAGES_FILE_PATH):
|
|
||||||
print(f"Warning: {RAW_PASSAGES_FILE_PATH} not found. Creating dummy file.")
|
|
||||||
with open(RAW_PASSAGES_FILE_PATH, 'w', encoding='utf-8') as f:
|
|
||||||
json.dump({"text": "Content from message A snippet 1.", "id": 100, "message_id": "<msg_A>"}, f); f.write("\n")
|
|
||||||
json.dump({"text": "Content from message A snippet 2.", "id": 101, "message_id": "<msg_A>"}, f); f.write("\n")
|
|
||||||
json.dump({"text": "Content from message B.", "id": 200, "message_id": "<msg_B>"}, f); f.write("\n")
|
|
||||||
json.dump({"text": "Content from message D (unrelated).", "id": 300, "message_id": "<msg_D>"}, f); f.write("\n")
|
|
||||||
|
|
||||||
# RESULTS_FILE_PATH should have results corresponding line-by-line to QUERIES_FILE_PATH
|
|
||||||
if not os.path.exists(RESULTS_FILE_PATH):
|
|
||||||
print(f"Warning: {RESULTS_FILE_PATH} not found. Creating dummy file (2 entries).")
|
|
||||||
with open(RESULTS_FILE_PATH, 'w', encoding='utf-8') as f:
|
|
||||||
# Result for query "q_alpha" (line 0 in queries file)
|
|
||||||
json.dump({"query_id": "this_can_be_ignored_if_line_sync", "passages": [{"id": 101, "text": "Content from message A snippet 2."}, {"id": 300, "text": "Content from message D (unrelated)."}]}, f); f.write("\n")
|
|
||||||
# Result for query "q_beta" (line 1 in queries file)
|
|
||||||
json.dump({"query_id": "this_too", "passages": [{"id": 999, "text": "Some other text."}, {"id": 200, "text": "Content from message B."}]}, f); f.write("\n")
|
|
||||||
# Note: Only 2 result sets, but 3 queries in dummy QUERIES_FILE_PATH.
|
|
||||||
# The script will process min(len(queries_list), len(search_results_list)) if you uncomment that logic,
|
|
||||||
# or just len(search_results_list) as it's currently written for tqdm.
|
|
||||||
|
|
||||||
main()
|
|
||||||
@@ -1,44 +0,0 @@
|
|||||||
# Recompute Embeddings Saved
|
|
||||||
|
|
||||||
```console
|
|
||||||
python ./demo/main.py --mode serve --engine sglang --load-indices diskann --port 8082 --domain rpj_wiki --lazy --recompute --dedup --use-partition
|
|
||||||
python ./demo/embedding_server.py --domain rpj_wiki
|
|
||||||
python ./demo/test_serve.py --port 8082 --nprobe 80 --re --dedup
|
|
||||||
```
|
|
||||||
|
|
||||||
Result:
|
|
||||||
```
|
|
||||||
Evaluation Results for nprobe = 80:
|
|
||||||
Final Recall Rate: 0.9333
|
|
||||||
Average total latency: 2.427s
|
|
||||||
Average search time: 2.414s
|
|
||||||
```
|
|
||||||
|
|
||||||
其中,use-partition也可以不加,也可以跑。不加的效果如下:
|
|
||||||
```
|
|
||||||
Results for nprobe = 80:
|
|
||||||
Final Recall Rate: 0.9333
|
|
||||||
Average total latency: 2.434s
|
|
||||||
Average search time: 2.421s
|
|
||||||
```
|
|
||||||
|
|
||||||
# Recompute Embeddings + Loading from disk
|
|
||||||
|
|
||||||
Remove `--dedup --use-partition`
|
|
||||||
|
|
||||||
```console
|
|
||||||
python ./demo/main.py --mode serve --engine sglang --load-indices diskann --port 8082 --domain rpj_wiki --lazy --recompute
|
|
||||||
python ./demo/embedding_server.py --domain rpj_wiki
|
|
||||||
python ./demo/test_serve.py --port 8082 --nprobe 80 --re
|
|
||||||
```
|
|
||||||
|
|
||||||
Result:
|
|
||||||
```
|
|
||||||
Evaluation Results for nprobe = 80:
|
|
||||||
Evaluation Results for nprobe = 80:
|
|
||||||
Average F1 Score: 0.5708
|
|
||||||
Average Exact Match Score: 0.4500
|
|
||||||
Average Recall Rate: 0.9333
|
|
||||||
Average total latency: 3.709s
|
|
||||||
Average search time: 3.696s
|
|
||||||
```
|
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user