Compare commits
206 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
14f096dfe3 | ||
|
|
31b4973141 | ||
|
|
dde2221513 | ||
|
|
6d11e86e71 | ||
|
|
13bb561aad | ||
|
|
0174ba5571 | ||
|
|
03af82d695 | ||
|
|
738f1dbab8 | ||
|
|
37d990d51c | ||
|
|
a6f07a54f1 | ||
|
|
46905e0687 | ||
|
|
838ade231e | ||
|
|
da6540decd | ||
|
|
39e18a7c11 | ||
|
|
6bde28584b | ||
|
|
f62632c41f | ||
|
|
27708243ca | ||
|
|
9a1e4652ca | ||
|
|
14e84d9e2d | ||
|
|
2dcfca19ff | ||
|
|
bee2167ee3 | ||
|
|
ef980d70b3 | ||
|
|
db3c63c441 | ||
|
|
00eeadb9dd | ||
|
|
42c8370709 | ||
|
|
fafdf8fcbe | ||
|
|
21f7d8e031 | ||
|
|
46565b9249 | ||
|
|
3dad76126a | ||
|
|
18e28bda32 | ||
|
|
609fa62fd5 | ||
|
|
eab13434ef | ||
|
|
b2390ccc14 | ||
|
|
e8fca2c84a | ||
|
|
790ae14f69 | ||
|
|
ac363072e6 | ||
|
|
93465af46c | ||
|
|
792ece67dc | ||
|
|
239e35e2e6 | ||
|
|
2fac0c6fbf | ||
|
|
9801aa581b | ||
|
|
5e97916608 | ||
|
|
8b9c2be8c9 | ||
|
|
3ff5aac8e0 | ||
|
|
67fef60466 | ||
|
|
b6ab6f1993 | ||
|
|
9f2e82a838 | ||
|
|
0b2b799d5a | ||
|
|
0f790fbbd9 | ||
|
|
387ae21eba | ||
|
|
3cc329c3e7 | ||
|
|
5567302316 | ||
|
|
075d4bd167 | ||
|
|
e4bcc76f88 | ||
|
|
710e83b1fd | ||
|
|
c96d653072 | ||
|
|
8b22d2b5d3 | ||
|
|
4cb544ee38 | ||
|
|
f94ce63d51 | ||
|
|
4271ff9d84 | ||
|
|
0d448c4a41 | ||
|
|
af5599e33c | ||
|
|
efdf6d917a | ||
|
|
dd71ac8d71 | ||
|
|
8bee1d4100 | ||
|
|
33521d6d00 | ||
|
|
8899734952 | ||
|
|
54df6310c5 | ||
|
|
19bcc07814 | ||
|
|
8356e3c668 | ||
|
|
08eac5c821 | ||
|
|
4671ed9b36 | ||
|
|
055c086398 | ||
|
|
d505dcc5e3 | ||
|
|
261006c36a | ||
|
|
b2eba23e21 | ||
|
|
e9ee687472 | ||
|
|
6f5d5e4a77 | ||
|
|
5c8921673a | ||
|
|
e9d2d420bd | ||
|
|
ebabfad066 | ||
|
|
e6f612b5e8 | ||
|
|
51c41acd82 | ||
|
|
455f93fb7c | ||
|
|
48207c3b69 | ||
|
|
4de1caa40f | ||
|
|
60eaa8165c | ||
|
|
c1a5d0c624 | ||
|
|
af1790395a | ||
|
|
383c6d8d7e | ||
|
|
bc0d839693 | ||
|
|
8596562de5 | ||
|
|
5d09586853 | ||
|
|
a7cba078dd | ||
|
|
b3e9ee96fa | ||
|
|
8537a6b17e | ||
|
|
7c8d7dc5c2 | ||
|
|
8e23d663e6 | ||
|
|
8a3994bf80 | ||
|
|
8375f601ba | ||
|
|
c87c0fe662 | ||
|
|
73927b68ef | ||
|
|
cc1a62e5aa | ||
|
|
802020cb41 | ||
|
|
cdb92f7cf4 | ||
|
|
dc69bdec00 | ||
|
|
98073e9868 | ||
|
|
cf2ef48967 | ||
|
|
0692bbf7a2 | ||
|
|
52584a171f | ||
|
|
efd6b5324b | ||
|
|
2baaa4549b | ||
|
|
35310ddd52 | ||
|
|
fc9c5cb39d | ||
|
|
8f2a1e87ea | ||
|
|
50caf65f28 | ||
|
|
1b48794ca8 | ||
|
|
4aef1d814e | ||
|
|
75ddcd6158 | ||
|
|
2a4df11f5c | ||
|
|
5eb893c62b | ||
|
|
d91ce2e94d | ||
|
|
5c2ff8a641 | ||
|
|
d4f474c9b7 | ||
|
|
170f7644e9 | ||
|
|
cd8b970eff | ||
|
|
52153bbb69 | ||
|
|
e1ae087207 | ||
|
|
48c5e12ac1 | ||
|
|
f8b5c97190 | ||
|
|
d038c81b8b | ||
|
|
29cbbbd0d6 | ||
|
|
179f30bc36 | ||
|
|
c4a0a68581 | ||
|
|
5c836ad08e | ||
|
|
673fd9b7cd | ||
|
|
84b24b233d | ||
|
|
499cdd7822 | ||
|
|
800d4cf111 | ||
|
|
b6d43f5fd9 | ||
|
|
3603cd5034 | ||
|
|
6df7893173 | ||
|
|
e64b599276 | ||
|
|
2dd59c4ba1 | ||
|
|
166986d5e6 | ||
|
|
a6aec68f32 | ||
|
|
ed27a127d5 | ||
|
|
d8b4ea7564 | ||
|
|
f0a2ef96b4 | ||
|
|
7d73c2c803 | ||
|
|
e8d2ecab03 | ||
|
|
32a374d094 | ||
|
|
d45c013806 | ||
|
|
9000a7083d | ||
|
|
8307555d54 | ||
|
|
20f2aece08 | ||
|
|
43eb4f9a1d | ||
|
|
5461b71d8c | ||
|
|
374db0ebb8 | ||
|
|
cea1f6f87c | ||
|
|
6c0e39372b | ||
|
|
2bec67d2b6 | ||
|
|
133e715832 | ||
|
|
95cf2f16e2 | ||
|
|
47a4c153eb | ||
|
|
faf5ae3533 | ||
|
|
a44dccecac | ||
|
|
9cf9358b9c | ||
|
|
de252fef31 | ||
|
|
9076bc27b8 | ||
|
|
50686c0819 | ||
|
|
1614203786 | ||
|
|
3d4c75a56c | ||
|
|
2684ee71dc | ||
|
|
1d321953ba | ||
|
|
b3cb251369 | ||
|
|
0a17d2c9d8 | ||
|
|
e3defbca84 | ||
|
|
e407f63977 | ||
|
|
7add391b2c | ||
|
|
efd6373b32 | ||
|
|
d502fa24b0 | ||
|
|
258a9a5c7f | ||
|
|
5d41ac6115 | ||
|
|
2a0fdb49b8 | ||
|
|
9d1b7231b6 | ||
|
|
ed3095b478 | ||
|
|
88eca75917 | ||
|
|
42de27e16a | ||
|
|
c083bda5b7 | ||
|
|
e86da38726 | ||
|
|
99076e38bc | ||
|
|
9698c1a02c | ||
|
|
851f0f04c3 | ||
|
|
ae16d9d888 | ||
|
|
6e1af2eb0c | ||
|
|
7695dd0d50 | ||
|
|
c2065473ad | ||
|
|
5f3870564d | ||
|
|
c214b2e33e | ||
|
|
2420c5fd35 | ||
|
|
f48f526f0a | ||
|
|
5dd74982ba | ||
|
|
e07aaf52a7 | ||
|
|
30e5f12616 | ||
|
|
594427bf87 |
1
.gitattributes
vendored
1
.gitattributes
vendored
@@ -1 +0,0 @@
|
||||
paper_plot/data/big_graph_degree_data.npz filter=lfs diff=lfs merge=lfs -text
|
||||
12
.github/workflows/build-and-publish.yml
vendored
Normal file
12
.github/workflows/build-and-publish.yml
vendored
Normal file
@@ -0,0 +1,12 @@
|
||||
name: CI
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ main ]
|
||||
pull_request:
|
||||
branches: [ main ]
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
build:
|
||||
uses: ./.github/workflows/build-reusable.yml
|
||||
358
.github/workflows/build-reusable.yml
vendored
Normal file
358
.github/workflows/build-reusable.yml
vendored
Normal file
@@ -0,0 +1,358 @@
|
||||
name: Reusable Build
|
||||
|
||||
on:
|
||||
workflow_call:
|
||||
inputs:
|
||||
ref:
|
||||
description: 'Git ref to build'
|
||||
required: false
|
||||
type: string
|
||||
default: ''
|
||||
|
||||
jobs:
|
||||
lint:
|
||||
name: Lint and Format Check
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
ref: ${{ inputs.ref }}
|
||||
|
||||
- name: Setup Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.11'
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v4
|
||||
|
||||
- name: Install ruff
|
||||
run: |
|
||||
uv tool install ruff
|
||||
|
||||
- name: Run ruff check
|
||||
run: |
|
||||
ruff check .
|
||||
|
||||
- name: Run ruff format check
|
||||
run: |
|
||||
ruff format --check .
|
||||
|
||||
build:
|
||||
needs: lint
|
||||
name: Build ${{ matrix.os }} Python ${{ matrix.python }}
|
||||
strategy:
|
||||
matrix:
|
||||
include:
|
||||
- os: ubuntu-22.04
|
||||
python: '3.9'
|
||||
- os: ubuntu-22.04
|
||||
python: '3.10'
|
||||
- os: ubuntu-22.04
|
||||
python: '3.11'
|
||||
- os: ubuntu-22.04
|
||||
python: '3.12'
|
||||
- os: ubuntu-22.04
|
||||
python: '3.13'
|
||||
- os: macos-14
|
||||
python: '3.9'
|
||||
- os: macos-14
|
||||
python: '3.10'
|
||||
- os: macos-14
|
||||
python: '3.11'
|
||||
- os: macos-14
|
||||
python: '3.12'
|
||||
- os: macos-14
|
||||
python: '3.13'
|
||||
- os: macos-15
|
||||
python: '3.9'
|
||||
- os: macos-15
|
||||
python: '3.10'
|
||||
- os: macos-15
|
||||
python: '3.11'
|
||||
- os: macos-15
|
||||
python: '3.12'
|
||||
- os: macos-15
|
||||
python: '3.13'
|
||||
- os: macos-13
|
||||
python: '3.9'
|
||||
- os: macos-13
|
||||
python: '3.10'
|
||||
- os: macos-13
|
||||
python: '3.11'
|
||||
- os: macos-13
|
||||
python: '3.12'
|
||||
# Note: macos-13 + Python 3.13 excluded due to PyTorch compatibility
|
||||
# (PyTorch 2.5+ supports Python 3.13 but not Intel Mac x86_64)
|
||||
runs-on: ${{ matrix.os }}
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v5
|
||||
with:
|
||||
ref: ${{ inputs.ref }}
|
||||
submodules: recursive
|
||||
|
||||
- name: Setup Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ matrix.python }}
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v6
|
||||
|
||||
- name: Install system dependencies (Ubuntu)
|
||||
if: runner.os == 'Linux'
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y libomp-dev libboost-all-dev protobuf-compiler libzmq3-dev \
|
||||
pkg-config libabsl-dev libaio-dev libprotobuf-dev \
|
||||
patchelf
|
||||
|
||||
# Install Intel MKL for DiskANN
|
||||
wget -q https://registrationcenter-download.intel.com/akdlm/IRC_NAS/79153e0f-74d7-45af-b8c2-258941adf58a/intel-onemkl-2025.0.0.940.sh
|
||||
sudo sh intel-onemkl-2025.0.0.940.sh -a --components intel.oneapi.lin.mkl.devel --action install --eula accept -s
|
||||
source /opt/intel/oneapi/setvars.sh
|
||||
echo "MKLROOT=/opt/intel/oneapi/mkl/latest" >> $GITHUB_ENV
|
||||
echo "LD_LIBRARY_PATH=/opt/intel/oneapi/compiler/latest/linux/compiler/lib/intel64_lin" >> $GITHUB_ENV
|
||||
echo "LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/intel/oneapi/mkl/latest/lib/intel64" >> $GITHUB_ENV
|
||||
|
||||
- name: Install system dependencies (macOS)
|
||||
if: runner.os == 'macOS'
|
||||
run: |
|
||||
# Don't install LLVM, use system clang for better compatibility
|
||||
brew install libomp boost protobuf zeromq
|
||||
|
||||
- name: Install build dependencies
|
||||
run: |
|
||||
uv pip install --system scikit-build-core numpy swig Cython pybind11
|
||||
if [[ "$RUNNER_OS" == "Linux" ]]; then
|
||||
uv pip install --system auditwheel
|
||||
else
|
||||
uv pip install --system delocate
|
||||
fi
|
||||
|
||||
- name: Set macOS environment variables
|
||||
if: runner.os == 'macOS'
|
||||
run: |
|
||||
# Use brew --prefix to automatically detect Homebrew installation path
|
||||
HOMEBREW_PREFIX=$(brew --prefix)
|
||||
echo "HOMEBREW_PREFIX=${HOMEBREW_PREFIX}" >> $GITHUB_ENV
|
||||
echo "OpenMP_ROOT=${HOMEBREW_PREFIX}/opt/libomp" >> $GITHUB_ENV
|
||||
|
||||
# Set CMAKE_PREFIX_PATH to let CMake find all packages automatically
|
||||
echo "CMAKE_PREFIX_PATH=${HOMEBREW_PREFIX}" >> $GITHUB_ENV
|
||||
|
||||
# Set compiler flags for OpenMP (required for both backends)
|
||||
echo "LDFLAGS=-L${HOMEBREW_PREFIX}/opt/libomp/lib" >> $GITHUB_ENV
|
||||
echo "CPPFLAGS=-I${HOMEBREW_PREFIX}/opt/libomp/include" >> $GITHUB_ENV
|
||||
|
||||
- name: Build packages
|
||||
run: |
|
||||
# Build core (platform independent)
|
||||
cd packages/leann-core
|
||||
uv build
|
||||
cd ../..
|
||||
|
||||
# Build HNSW backend
|
||||
cd packages/leann-backend-hnsw
|
||||
if [[ "${{ matrix.os }}" == macos-* ]]; then
|
||||
# Use system clang for better compatibility
|
||||
export CC=clang
|
||||
export CXX=clang++
|
||||
# Homebrew libraries on each macOS version require matching minimum version
|
||||
if [[ "${{ matrix.os }}" == "macos-13" ]]; then
|
||||
export MACOSX_DEPLOYMENT_TARGET=13.0
|
||||
elif [[ "${{ matrix.os }}" == "macos-14" ]]; then
|
||||
export MACOSX_DEPLOYMENT_TARGET=14.0
|
||||
elif [[ "${{ matrix.os }}" == "macos-15" ]]; then
|
||||
export MACOSX_DEPLOYMENT_TARGET=15.0
|
||||
fi
|
||||
uv build --wheel --python ${{ matrix.python }} --find-links ${GITHUB_WORKSPACE}/packages/leann-core/dist
|
||||
else
|
||||
uv build --wheel --python ${{ matrix.python }} --find-links ${GITHUB_WORKSPACE}/packages/leann-core/dist
|
||||
fi
|
||||
cd ../..
|
||||
|
||||
# Build DiskANN backend
|
||||
cd packages/leann-backend-diskann
|
||||
if [[ "${{ matrix.os }}" == macos-* ]]; then
|
||||
# Use system clang for better compatibility
|
||||
export CC=clang
|
||||
export CXX=clang++
|
||||
# DiskANN requires macOS 13.3+ for sgesdd_ LAPACK function
|
||||
# But Homebrew libraries on each macOS version require matching minimum version
|
||||
if [[ "${{ matrix.os }}" == "macos-13" ]]; then
|
||||
export MACOSX_DEPLOYMENT_TARGET=13.3
|
||||
elif [[ "${{ matrix.os }}" == "macos-14" ]]; then
|
||||
export MACOSX_DEPLOYMENT_TARGET=14.0
|
||||
elif [[ "${{ matrix.os }}" == "macos-15" ]]; then
|
||||
export MACOSX_DEPLOYMENT_TARGET=15.0
|
||||
fi
|
||||
uv build --wheel --python ${{ matrix.python }} --find-links ${GITHUB_WORKSPACE}/packages/leann-core/dist
|
||||
else
|
||||
uv build --wheel --python ${{ matrix.python }} --find-links ${GITHUB_WORKSPACE}/packages/leann-core/dist
|
||||
fi
|
||||
cd ../..
|
||||
|
||||
# Build meta package (platform independent)
|
||||
cd packages/leann
|
||||
uv build
|
||||
cd ../..
|
||||
|
||||
- name: Repair wheels (Linux)
|
||||
if: runner.os == 'Linux'
|
||||
run: |
|
||||
# Repair HNSW wheel
|
||||
cd packages/leann-backend-hnsw
|
||||
if [ -d dist ]; then
|
||||
auditwheel repair dist/*.whl -w dist_repaired
|
||||
rm -rf dist
|
||||
mv dist_repaired dist
|
||||
fi
|
||||
cd ../..
|
||||
|
||||
# Repair DiskANN wheel
|
||||
cd packages/leann-backend-diskann
|
||||
if [ -d dist ]; then
|
||||
auditwheel repair dist/*.whl -w dist_repaired
|
||||
rm -rf dist
|
||||
mv dist_repaired dist
|
||||
fi
|
||||
cd ../..
|
||||
|
||||
- name: Repair wheels (macOS)
|
||||
if: runner.os == 'macOS'
|
||||
run: |
|
||||
# Determine deployment target based on runner OS
|
||||
# Must match the Homebrew libraries for each macOS version
|
||||
if [[ "${{ matrix.os }}" == "macos-13" ]]; then
|
||||
HNSW_TARGET="13.0"
|
||||
DISKANN_TARGET="13.3"
|
||||
elif [[ "${{ matrix.os }}" == "macos-14" ]]; then
|
||||
HNSW_TARGET="14.0"
|
||||
DISKANN_TARGET="14.0"
|
||||
elif [[ "${{ matrix.os }}" == "macos-15" ]]; then
|
||||
HNSW_TARGET="15.0"
|
||||
DISKANN_TARGET="15.0"
|
||||
fi
|
||||
|
||||
# Repair HNSW wheel
|
||||
cd packages/leann-backend-hnsw
|
||||
if [ -d dist ]; then
|
||||
export MACOSX_DEPLOYMENT_TARGET=$HNSW_TARGET
|
||||
delocate-wheel -w dist_repaired -v --require-target-macos-version $HNSW_TARGET dist/*.whl
|
||||
rm -rf dist
|
||||
mv dist_repaired dist
|
||||
fi
|
||||
cd ../..
|
||||
|
||||
# Repair DiskANN wheel
|
||||
cd packages/leann-backend-diskann
|
||||
if [ -d dist ]; then
|
||||
export MACOSX_DEPLOYMENT_TARGET=$DISKANN_TARGET
|
||||
delocate-wheel -w dist_repaired -v --require-target-macos-version $DISKANN_TARGET dist/*.whl
|
||||
rm -rf dist
|
||||
mv dist_repaired dist
|
||||
fi
|
||||
cd ../..
|
||||
|
||||
- name: List built packages
|
||||
run: |
|
||||
echo "📦 Built packages:"
|
||||
find packages/*/dist -name "*.whl" -o -name "*.tar.gz" | sort
|
||||
|
||||
|
||||
- name: Install built packages for testing
|
||||
run: |
|
||||
# Create a virtual environment with the correct Python version
|
||||
uv venv --python ${{ matrix.python }}
|
||||
source .venv/bin/activate || source .venv/Scripts/activate
|
||||
|
||||
# Install packages using --find-links to prioritize local builds
|
||||
uv pip install --find-links packages/leann-core/dist --find-links packages/leann-backend-hnsw/dist --find-links packages/leann-backend-diskann/dist packages/leann-core/dist/*.whl || uv pip install --find-links packages/leann-core/dist packages/leann-core/dist/*.tar.gz
|
||||
uv pip install --find-links packages/leann-core/dist packages/leann-backend-hnsw/dist/*.whl
|
||||
uv pip install --find-links packages/leann-core/dist packages/leann-backend-diskann/dist/*.whl
|
||||
uv pip install packages/leann/dist/*.whl || uv pip install packages/leann/dist/*.tar.gz
|
||||
|
||||
# Install test dependencies using extras
|
||||
uv pip install -e ".[test]"
|
||||
|
||||
- name: Run tests with pytest
|
||||
env:
|
||||
CI: true
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
HF_HUB_DISABLE_SYMLINKS: 1
|
||||
TOKENIZERS_PARALLELISM: false
|
||||
PYTORCH_ENABLE_MPS_FALLBACK: 0
|
||||
OMP_NUM_THREADS: 1
|
||||
MKL_NUM_THREADS: 1
|
||||
run: |
|
||||
source .venv/bin/activate || source .venv/Scripts/activate
|
||||
pytest tests/ -v --tb=short
|
||||
|
||||
- name: Run sanity checks (optional)
|
||||
run: |
|
||||
# Activate virtual environment
|
||||
source .venv/bin/activate || source .venv/Scripts/activate
|
||||
|
||||
# Run distance function tests if available
|
||||
if [ -f test/sanity_checks/test_distance_functions.py ]; then
|
||||
echo "Running distance function sanity checks..."
|
||||
python test/sanity_checks/test_distance_functions.py || echo "⚠️ Distance function test failed, continuing..."
|
||||
fi
|
||||
|
||||
- name: Upload artifacts
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: packages-${{ matrix.os }}-py${{ matrix.python }}
|
||||
path: packages/*/dist/
|
||||
|
||||
|
||||
arch-smoke:
|
||||
name: Arch Linux smoke test (install & import)
|
||||
needs: build
|
||||
runs-on: ubuntu-latest
|
||||
container:
|
||||
image: archlinux:latest
|
||||
|
||||
steps:
|
||||
- name: Prepare system
|
||||
run: |
|
||||
pacman -Syu --noconfirm
|
||||
pacman -S --noconfirm python python-pip gcc git zlib openssl
|
||||
|
||||
- name: Download ALL wheel artifacts from this run
|
||||
uses: actions/download-artifact@v5
|
||||
with:
|
||||
# Don't specify name, download all artifacts
|
||||
path: ./wheels
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v6
|
||||
|
||||
- name: Create virtual environment and install wheels
|
||||
run: |
|
||||
uv venv
|
||||
source .venv/bin/activate || source .venv/Scripts/activate
|
||||
uv pip install --find-links wheels leann-core
|
||||
uv pip install --find-links wheels leann-backend-hnsw
|
||||
uv pip install --find-links wheels leann-backend-diskann
|
||||
uv pip install --find-links wheels leann
|
||||
|
||||
- name: Import & tiny runtime check
|
||||
env:
|
||||
OMP_NUM_THREADS: 1
|
||||
MKL_NUM_THREADS: 1
|
||||
run: |
|
||||
source .venv/bin/activate || source .venv/Scripts/activate
|
||||
python - <<'PY'
|
||||
import leann
|
||||
import leann_backend_hnsw as h
|
||||
import leann_backend_diskann as d
|
||||
from leann import LeannBuilder, LeannSearcher
|
||||
b = LeannBuilder(backend_name="hnsw")
|
||||
b.add_text("hello arch")
|
||||
b.build_index("arch_demo.leann")
|
||||
s = LeannSearcher("arch_demo.leann")
|
||||
print("search:", s.search("hello", top_k=1))
|
||||
PY
|
||||
19
.github/workflows/link-check.yml
vendored
Normal file
19
.github/workflows/link-check.yml
vendored
Normal file
@@ -0,0 +1,19 @@
|
||||
name: Link Check
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ main, master ]
|
||||
pull_request:
|
||||
schedule:
|
||||
- cron: "0 3 * * 1"
|
||||
|
||||
jobs:
|
||||
link-check:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: lycheeverse/lychee-action@v2
|
||||
with:
|
||||
args: --no-progress --insecure --user-agent 'curl/7.68.0' README.md docs/ apps/ examples/ benchmarks/
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
129
.github/workflows/release-manual.yml
vendored
Normal file
129
.github/workflows/release-manual.yml
vendored
Normal file
@@ -0,0 +1,129 @@
|
||||
name: Release
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
version:
|
||||
description: 'Version to release (e.g., 0.1.2)'
|
||||
required: true
|
||||
type: string
|
||||
|
||||
jobs:
|
||||
update-version:
|
||||
name: Update Version
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: write
|
||||
outputs:
|
||||
commit-sha: ${{ steps.push.outputs.commit-sha }}
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Validate version
|
||||
run: |
|
||||
# Remove 'v' prefix if present for validation
|
||||
VERSION_CLEAN="${{ inputs.version }}"
|
||||
VERSION_CLEAN="${VERSION_CLEAN#v}"
|
||||
if ! [[ "$VERSION_CLEAN" =~ ^[0-9]+\.[0-9]+\.[0-9]+$ ]]; then
|
||||
echo "❌ Invalid version format. Expected format: X.Y.Z or vX.Y.Z"
|
||||
exit 1
|
||||
fi
|
||||
echo "✅ Version format valid: ${{ inputs.version }}"
|
||||
|
||||
- name: Update versions and push
|
||||
id: push
|
||||
run: |
|
||||
# Check current version
|
||||
CURRENT_VERSION=$(grep "^version" packages/leann-core/pyproject.toml | cut -d'"' -f2)
|
||||
echo "Current version: $CURRENT_VERSION"
|
||||
echo "Target version: ${{ inputs.version }}"
|
||||
|
||||
if [ "$CURRENT_VERSION" = "${{ inputs.version }}" ]; then
|
||||
echo "⚠️ Version is already ${{ inputs.version }}, skipping update"
|
||||
COMMIT_SHA=$(git rev-parse HEAD)
|
||||
else
|
||||
./scripts/bump_version.sh ${{ inputs.version }}
|
||||
git config user.name "GitHub Actions"
|
||||
git config user.email "actions@github.com"
|
||||
git add packages/*/pyproject.toml
|
||||
git commit -m "chore: release v${{ inputs.version }}"
|
||||
git push origin main
|
||||
COMMIT_SHA=$(git rev-parse HEAD)
|
||||
echo "✅ Pushed version update: $COMMIT_SHA"
|
||||
fi
|
||||
|
||||
echo "commit-sha=$COMMIT_SHA" >> $GITHUB_OUTPUT
|
||||
|
||||
build-packages:
|
||||
name: Build packages
|
||||
needs: update-version
|
||||
uses: ./.github/workflows/build-reusable.yml
|
||||
with:
|
||||
ref: 'main'
|
||||
|
||||
publish:
|
||||
name: Publish and Release
|
||||
needs: [update-version, build-packages]
|
||||
if: always() && needs.update-version.result == 'success' && needs.build-packages.result == 'success'
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: write
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
ref: 'main'
|
||||
|
||||
- name: Download all artifacts
|
||||
uses: actions/download-artifact@v4
|
||||
with:
|
||||
path: dist-artifacts
|
||||
|
||||
- name: Collect packages
|
||||
run: |
|
||||
mkdir -p dist
|
||||
find dist-artifacts -name "*.whl" -exec cp {} dist/ \;
|
||||
find dist-artifacts -name "*.tar.gz" -exec cp {} dist/ \;
|
||||
|
||||
echo "📦 Packages to publish:"
|
||||
ls -la dist/
|
||||
|
||||
- name: Publish to PyPI
|
||||
env:
|
||||
TWINE_USERNAME: __token__
|
||||
TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}
|
||||
run: |
|
||||
if [ -z "$TWINE_PASSWORD" ]; then
|
||||
echo "❌ PYPI_API_TOKEN not configured!"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
pip install twine
|
||||
twine upload dist/* --skip-existing --verbose
|
||||
|
||||
echo "✅ Published to PyPI!"
|
||||
|
||||
- name: Create release
|
||||
run: |
|
||||
# Check if tag already exists
|
||||
if git rev-parse "v${{ inputs.version }}" >/dev/null 2>&1; then
|
||||
echo "⚠️ Tag v${{ inputs.version }} already exists, skipping tag creation"
|
||||
else
|
||||
git tag "v${{ inputs.version }}"
|
||||
git push origin "v${{ inputs.version }}"
|
||||
echo "✅ Created and pushed tag v${{ inputs.version }}"
|
||||
fi
|
||||
|
||||
# Check if release already exists
|
||||
if gh release view "v${{ inputs.version }}" >/dev/null 2>&1; then
|
||||
echo "⚠️ Release v${{ inputs.version }} already exists, skipping release creation"
|
||||
else
|
||||
gh release create "v${{ inputs.version }}" \
|
||||
--title "Release v${{ inputs.version }}" \
|
||||
--notes "🚀 Released to PyPI: https://pypi.org/project/leann/${{ inputs.version }}/" \
|
||||
--latest
|
||||
echo "✅ Created GitHub release v${{ inputs.version }}"
|
||||
fi
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
37
.gitignore
vendored
37
.gitignore
vendored
@@ -9,16 +9,16 @@ demo/indices/
|
||||
outputs/
|
||||
*.pkl
|
||||
*.pdf
|
||||
*.idx
|
||||
*.idx
|
||||
*.map
|
||||
.history/
|
||||
scripts/
|
||||
lm_eval.egg-info/
|
||||
demo/experiment_results/**/*.json
|
||||
*.jsonl
|
||||
*.eml
|
||||
*.emlx
|
||||
*.json
|
||||
!.vscode/*.json
|
||||
*.sh
|
||||
*.txt
|
||||
!CMakeLists.txt
|
||||
@@ -35,11 +35,15 @@ build/
|
||||
nprobe_logs/
|
||||
micro/results
|
||||
micro/contriever-INT8
|
||||
examples/data/*
|
||||
!examples/data/2501.14312v1 (1).pdf
|
||||
!examples/data/2506.08276v1.pdf
|
||||
!examples/data/PrideandPrejudice.txt
|
||||
!examples/data/README.md
|
||||
data/*
|
||||
!data/2501.14312v1 (1).pdf
|
||||
!data/2506.08276v1.pdf
|
||||
!data/PrideandPrejudice.txt
|
||||
!data/huawei_pangu.md
|
||||
!data/ground_truth/
|
||||
!data/indices/
|
||||
!data/queries/
|
||||
!data/.gitattributes
|
||||
*.qdstrm
|
||||
benchmark_results/
|
||||
results/
|
||||
@@ -86,4 +90,21 @@ packages/leann-backend-diskann/third_party/DiskANN/_deps/
|
||||
*.meta.json
|
||||
*.passages.json
|
||||
|
||||
batchtest.py
|
||||
batchtest.py
|
||||
tests/__pytest_cache__/
|
||||
tests/__pycache__/
|
||||
paru-bin/
|
||||
|
||||
CLAUDE.md
|
||||
CLAUDE.local.md
|
||||
.claude/*.local.*
|
||||
.claude/local/*
|
||||
|
||||
benchmarks/data/
|
||||
!benchmarks/data/prompts_g5/*.txt
|
||||
!benchmarks/run_all.sh
|
||||
!benchmarks/run_speed_bench_all.sh
|
||||
!benchmarks/simple_mac_tpt_test.py
|
||||
!benchmarks/run_speed_bench_all.sh
|
||||
!benchmarks/run_speed_bench_all.sh
|
||||
!benchmarks/run_speed_bench_all.sh
|
||||
|
||||
17
.pre-commit-config.yaml
Normal file
17
.pre-commit-config.yaml
Normal file
@@ -0,0 +1,17 @@
|
||||
repos:
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: v5.0.0
|
||||
hooks:
|
||||
- id: trailing-whitespace
|
||||
- id: end-of-file-fixer
|
||||
- id: check-yaml
|
||||
- id: check-added-large-files
|
||||
- id: check-merge-conflict
|
||||
- id: debug-statements
|
||||
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: v0.12.7 # Fixed version to match pyproject.toml
|
||||
hooks:
|
||||
- id: ruff
|
||||
args: [--fix, --exit-non-zero-on-fix]
|
||||
- id: ruff-format
|
||||
5
.vscode/extensions.json
vendored
Normal file
5
.vscode/extensions.json
vendored
Normal file
@@ -0,0 +1,5 @@
|
||||
{
|
||||
"recommendations": [
|
||||
"charliermarsh.ruff",
|
||||
]
|
||||
}
|
||||
22
.vscode/settings.json
vendored
Normal file
22
.vscode/settings.json
vendored
Normal file
@@ -0,0 +1,22 @@
|
||||
{
|
||||
"python.defaultInterpreterPath": ".venv/bin/python",
|
||||
"python.terminal.activateEnvironment": true,
|
||||
"[python]": {
|
||||
"editor.defaultFormatter": "charliermarsh.ruff",
|
||||
"editor.formatOnSave": true,
|
||||
"editor.codeActionsOnSave": {
|
||||
"source.organizeImports": "explicit",
|
||||
"source.fixAll": "explicit"
|
||||
},
|
||||
"editor.insertSpaces": true,
|
||||
"editor.tabSize": 4
|
||||
},
|
||||
"ruff.enable": true,
|
||||
"files.watcherExclude": {
|
||||
"**/.venv/**": true,
|
||||
"**/__pycache__/**": true,
|
||||
"**/*.egg-info/**": true,
|
||||
"**/build/**": true,
|
||||
"**/dist/**": true
|
||||
}
|
||||
}
|
||||
744
README.md
744
README.md
@@ -3,20 +3,25 @@
|
||||
</p>
|
||||
|
||||
<p align="center">
|
||||
<img src="https://img.shields.io/badge/Python-3.9%2B-blue.svg" alt="Python 3.9+">
|
||||
<img src="https://img.shields.io/badge/Python-3.9%20%7C%203.10%20%7C%203.11%20%7C%203.12%20%7C%203.13-blue.svg" alt="Python Versions">
|
||||
<img src="https://github.com/yichuan-w/LEANN/actions/workflows/build-and-publish.yml/badge.svg" alt="CI Status">
|
||||
<img src="https://img.shields.io/badge/Platform-Ubuntu%20%26%20Arch%20%26%20WSL%20%7C%20macOS%20(ARM64%2FIntel)-lightgrey" alt="Platform">
|
||||
<img src="https://img.shields.io/badge/License-MIT-green.svg" alt="MIT License">
|
||||
<img src="https://img.shields.io/badge/Platform-Linux%20%7C%20macOS-lightgrey" alt="Platform">
|
||||
<img src="https://img.shields.io/badge/MCP-Native%20Integration-blue" alt="MCP Integration">
|
||||
</p>
|
||||
|
||||
<h2 align="center" tabindex="-1" class="heading-element" dir="auto">
|
||||
The smallest vector index in the world. RAG Everything with LEANN!
|
||||
</h2>
|
||||
|
||||
LEANN is a revolutionary vector database that democratizes personal AI. Transform your laptop into a powerful RAG system that can index and search through millions of documents while using **[97% less storage]** than traditional solutions **without accuracy loss**.
|
||||
LEANN is an innovative vector database that democratizes personal AI. Transform your laptop into a powerful RAG system that can index and search through millions of documents while using **97% less storage** than traditional solutions **without accuracy loss**.
|
||||
|
||||
LEANN achieves this through *graph-based selective recomputation* with *high-degree preserving pruning*, computing embeddings on-demand instead of storing them all. [Illustration →](#️-architecture--how-it-works) | [Paper →](https://arxiv.org/abs/2506.08276)
|
||||
LEANN achieves this through *graph-based selective recomputation* with *high-degree preserving pruning*, computing embeddings on-demand instead of storing them all. [Illustration Fig →](#️-architecture--how-it-works) | [Paper →](https://arxiv.org/abs/2506.08276)
|
||||
|
||||
**Ready to RAG Everything?** Transform your laptop into a personal AI assistant that can search your **[file system](#process-any-documents-pdf-txt-md)**, **[emails](#search-your-entire-life)**, **[browser history](#time-machine-for-the-web)**, **[chat history](#wechat-detective)**, or external knowledge bases (i.e., 60M documents) - all on your laptop, with zero cloud costs and complete privacy.
|
||||
**Ready to RAG Everything?** Transform your laptop into a personal AI assistant that can semantic search your **[file system](#-personal-data-manager-process-any-documents-pdf-txt-md)**, **[emails](#-your-personal-email-secretary-rag-on-apple-mail)**, **[browser history](#-time-machine-for-the-web-rag-your-entire-browser-history)**, **[chat history](#-wechat-detective-unlock-your-golden-memories)**, **[codebase](#-claude-code-integration-transform-your-development-workflow)**\* , or external knowledge bases (i.e., 60M documents) - all on your laptop, with zero cloud costs and complete privacy.
|
||||
|
||||
|
||||
\* Claude Code only supports basic `grep`-style keyword search. **LEANN** is a drop-in **semantic search MCP service fully compatible with Claude Code**, unlocking intelligent retrieval without changing your workflow. 🔥 Check out [the easy setup →](packages/leann-mcp/README.md)
|
||||
|
||||
|
||||
|
||||
@@ -26,57 +31,172 @@ LEANN achieves this through *graph-based selective recomputation* with *high-deg
|
||||
<img src="assets/effects.png" alt="LEANN vs Traditional Vector DB Storage Comparison" width="70%">
|
||||
</p>
|
||||
|
||||
> **The numbers speak for themselves:** Index 60 million Wikipedia chunks in just 6GB instead of 201GB. From emails to browser history, everything fits on your laptop. [See detailed benchmarks for different applications below ↓](#storage-usage-comparison)
|
||||
> **The numbers speak for themselves:** Index 60 million text chunks in just 6GB instead of 201GB. From emails to browser history, everything fits on your laptop. [See detailed benchmarks for different applications below ↓](#-storage-comparison)
|
||||
|
||||
|
||||
🔒 **Privacy:** Your data never leaves your laptop. No OpenAI, no cloud, no "terms of service".
|
||||
|
||||
🪶 **Lightweight:** Graph-based recomputation eliminates heavy embedding storage, while smart graph pruning and CSR format minimize graph storage overhead. Always less storage, less memory usage!
|
||||
|
||||
📦 **Portable:** Transfer your entire knowledge base between devices (even with others) with minimal cost - your personal AI memory travels with you.
|
||||
|
||||
📈 **Scalability:** Handle messy personal data that would crash traditional vector DBs, easily managing your growing personalized data and agent generated memory!
|
||||
|
||||
✨ **No Accuracy Loss:** Maintain the same search quality as heavyweight solutions while using 97% less storage.
|
||||
|
||||
## Quick Start in 1 minute
|
||||
## Installation
|
||||
|
||||
### 📦 Prerequisites: Install uv
|
||||
|
||||
[Install uv](https://docs.astral.sh/uv/getting-started/installation/#installation-methods) first if you don't have it. Typically, you can install it with:
|
||||
|
||||
```bash
|
||||
git clone git@github.com:yichuan-w/LEANN.git leann
|
||||
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
```
|
||||
|
||||
### 🚀 Quick Install
|
||||
|
||||
Clone the repository to access all examples and try amazing applications,
|
||||
|
||||
```bash
|
||||
git clone https://github.com/yichuan-w/LEANN.git leann
|
||||
cd leann
|
||||
```
|
||||
|
||||
and install LEANN from [PyPI](https://pypi.org/project/leann/) to run them immediately:
|
||||
|
||||
```bash
|
||||
uv venv
|
||||
source .venv/bin/activate
|
||||
uv pip install leann
|
||||
```
|
||||
<!--
|
||||
> Low-resource? See “Low-resource setups” in the [Configuration Guide](docs/configuration-guide.md#low-resource-setups). -->
|
||||
|
||||
<details>
|
||||
<summary>
|
||||
<strong>🔧 Build from Source (Recommended for development)</strong>
|
||||
</summary>
|
||||
|
||||
|
||||
|
||||
```bash
|
||||
git clone https://github.com/yichuan-w/LEANN.git leann
|
||||
cd leann
|
||||
git submodule update --init --recursive
|
||||
```
|
||||
|
||||
**macOS:**
|
||||
|
||||
Note: DiskANN requires MacOS 13.3 or later.
|
||||
|
||||
```bash
|
||||
brew install llvm libomp boost protobuf zeromq
|
||||
export CC=$(brew --prefix llvm)/bin/clang
|
||||
export CXX=$(brew --prefix llvm)/bin/clang++
|
||||
|
||||
# Install with HNSW backend (default, recommended for most users)
|
||||
uv sync
|
||||
|
||||
# Or add DiskANN backend if you want to test more options
|
||||
brew install libomp boost protobuf zeromq pkgconf
|
||||
uv sync --extra diskann
|
||||
```
|
||||
|
||||
**Linux (Ubuntu/Debian):**
|
||||
|
||||
Note: On Ubuntu 20.04, you may need to build a newer Abseil and pin Protobuf (e.g., v3.20.x) for building DiskANN. See [Issue #30](https://github.com/yichuan-w/LEANN/issues/30) for a step-by-step note.
|
||||
|
||||
You can manually install [Intel oneAPI MKL](https://www.intel.com/content/www/us/en/developer/tools/oneapi/onemkl.html) instead of `libmkl-full-dev` for DiskANN. You can also use `libopenblas-dev` for building HNSW only, by removing `--extra diskann` in the command below.
|
||||
|
||||
```bash
|
||||
sudo apt-get install libomp-dev libboost-all-dev protobuf-compiler libabsl-dev libmkl-full-dev libaio-dev libzmq3-dev
|
||||
sudo apt-get update && sudo apt-get install -y \
|
||||
libomp-dev libboost-all-dev protobuf-compiler libzmq3-dev \
|
||||
pkg-config libabsl-dev libaio-dev libprotobuf-dev \
|
||||
libmkl-full-dev
|
||||
|
||||
# Install with HNSW backend (default, recommended for most users)
|
||||
uv sync
|
||||
|
||||
# Or add DiskANN backend if you want to test more options
|
||||
uv sync --extra diskann
|
||||
```
|
||||
|
||||
**Linux (Arch Linux):**
|
||||
|
||||
```bash
|
||||
sudo pacman -Syu && sudo pacman -S --needed base-devel cmake pkgconf git gcc \
|
||||
boost boost-libs protobuf abseil-cpp libaio zeromq
|
||||
|
||||
# For MKL in DiskANN
|
||||
sudo pacman -S --needed base-devel git
|
||||
git clone https://aur.archlinux.org/paru-bin.git
|
||||
cd paru-bin && makepkg -si
|
||||
paru -S intel-oneapi-mkl intel-oneapi-compiler
|
||||
source /opt/intel/oneapi/setvars.sh
|
||||
|
||||
uv sync --extra diskann
|
||||
```
|
||||
|
||||
**Linux (RHEL / CentOS Stream / Oracle / Rocky / AlmaLinux):**
|
||||
|
||||
See [Issue #50](https://github.com/yichuan-w/LEANN/issues/50) for more details.
|
||||
|
||||
```bash
|
||||
sudo dnf groupinstall -y "Development Tools"
|
||||
sudo dnf install -y libomp-devel boost-devel protobuf-compiler protobuf-devel \
|
||||
abseil-cpp-devel libaio-devel zeromq-devel pkgconf-pkg-config
|
||||
|
||||
# For MKL in DiskANN
|
||||
sudo dnf install -y intel-oneapi-mkl intel-oneapi-mkl-devel \
|
||||
intel-oneapi-openmp || sudo dnf install -y intel-oneapi-compiler
|
||||
source /opt/intel/oneapi/setvars.sh
|
||||
|
||||
uv sync --extra diskann
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
**Ollama Setup (Recommended for full privacy):**
|
||||
## Quick Start
|
||||
|
||||
> *You can skip this installation if you only want to use OpenAI API for generation.*
|
||||
Our declarative API makes RAG as easy as writing a config file.
|
||||
|
||||
Check out [demo.ipynb](demo.ipynb) or [](https://colab.research.google.com/github/yichuan-w/LEANN/blob/main/demo.ipynb)
|
||||
|
||||
```python
|
||||
from leann import LeannBuilder, LeannSearcher, LeannChat
|
||||
from pathlib import Path
|
||||
INDEX_PATH = str(Path("./").resolve() / "demo.leann")
|
||||
|
||||
# Build an index
|
||||
builder = LeannBuilder(backend_name="hnsw")
|
||||
builder.add_text("LEANN saves 97% storage compared to traditional vector databases.")
|
||||
builder.add_text("Tung Tung Tung Sahur called—they need their banana‑crocodile hybrid back")
|
||||
builder.build_index(INDEX_PATH)
|
||||
|
||||
# Search
|
||||
searcher = LeannSearcher(INDEX_PATH)
|
||||
results = searcher.search("fantastical AI-generated creatures", top_k=1)
|
||||
|
||||
# Chat with your data
|
||||
chat = LeannChat(INDEX_PATH, llm_config={"type": "hf", "model": "Qwen/Qwen3-0.6B"})
|
||||
response = chat.ask("How much storage does LEANN save?", top_k=1)
|
||||
```
|
||||
|
||||
## RAG on Everything!
|
||||
|
||||
LEANN supports RAG on various data sources including documents (`.pdf`, `.txt`, `.md`), Apple Mail, Google Search History, WeChat, and more.
|
||||
|
||||
|
||||
*macOS:*
|
||||
|
||||
### Generation Model Setup
|
||||
|
||||
LEANN supports multiple LLM providers for text generation (OpenAI API, HuggingFace, Ollama).
|
||||
|
||||
<details>
|
||||
<summary><strong>🔑 OpenAI API Setup (Default)</strong></summary>
|
||||
|
||||
Set your OpenAI API key as an environment variable:
|
||||
|
||||
```bash
|
||||
export OPENAI_API_KEY="your-api-key-here"
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary><strong>🔧 Ollama Setup (Recommended for full privacy)</strong></summary>
|
||||
|
||||
**macOS:**
|
||||
|
||||
First, [download Ollama for macOS](https://ollama.com/download/mac).
|
||||
|
||||
@@ -85,7 +205,8 @@ First, [download Ollama for macOS](https://ollama.com/download/mac).
|
||||
ollama pull llama3.2:1b
|
||||
```
|
||||
|
||||
*Linux:*
|
||||
**Linux:**
|
||||
|
||||
```bash
|
||||
# Install Ollama
|
||||
curl -fsSL https://ollama.ai/install.sh | sh
|
||||
@@ -97,90 +218,127 @@ ollama serve &
|
||||
ollama pull llama3.2:1b
|
||||
```
|
||||
|
||||
## Dead Simple API
|
||||
|
||||
Just 3 lines of code. Our declarative API makes RAG as easy as writing a config file:
|
||||
|
||||
```python
|
||||
from leann.api import LeannBuilder, LeannSearcher, LeannChat
|
||||
|
||||
# 1. Build the index (no embeddings stored!)
|
||||
builder = LeannBuilder(backend_name="hnsw")
|
||||
builder.add_text("C# is a powerful programming language")
|
||||
builder.add_text("Python is a powerful programming language and it is very popular")
|
||||
builder.add_text("Machine learning transforms industries")
|
||||
builder.add_text("Neural networks process complex data")
|
||||
builder.add_text("Leann is a great storage saving engine for RAG on your MacBook")
|
||||
builder.build_index("knowledge.leann")
|
||||
|
||||
# 2. Search with real-time embeddings
|
||||
searcher = LeannSearcher("knowledge.leann")
|
||||
results = searcher.search("programming languages", top_k=2)
|
||||
|
||||
# 3. Chat with LEANN using retrieved results
|
||||
llm_config = {
|
||||
"type": "ollama",
|
||||
"model": "llama3.2:1b"
|
||||
}
|
||||
|
||||
chat = LeannChat(index_path="knowledge.leann", llm_config=llm_config)
|
||||
response = chat.ask(
|
||||
"Compare the two retrieved programming languages and say which one is more popular today.",
|
||||
top_k=2,
|
||||
)
|
||||
```
|
||||
|
||||
**That's it.** No cloud setup, no API keys, no "fine-tuning". Just your data, your questions, your laptop.
|
||||
|
||||
[Try the interactive demo →](demo.ipynb)
|
||||
|
||||
## Wild Things You Can Do
|
||||
|
||||
LEANN supports RAGing a lot of data sources, like .pdf, .txt, .md, and also supports RAGing your WeChat, Google Search History, and more.
|
||||
|
||||
### Process Any Documents (.pdf, .txt, .md)
|
||||
|
||||
Above we showed the Python API, while this CLI script demonstrates the same concepts while directly processing PDFs and documents, and even any directory that stores your personal files!
|
||||
|
||||
The following scripts use Ollama `qwen3:8b` by default, so you need `ollama pull qwen3:8b` first. For other models: `--llm openai --model gpt-4o` (requires `OPENAI_API_KEY` environment variable) or `--llm hf --model Qwen/Qwen3-4B`.
|
||||
|
||||
```bash
|
||||
# Drop your PDFs, .txt, .md files into apps/documents/data/
|
||||
python -m apps.documents
|
||||
|
||||
# Or with uv
|
||||
uv run python -m apps.documents
|
||||
```
|
||||
</details>
|
||||
|
||||
|
||||
## ⭐ Flexible Configuration
|
||||
|
||||
**Works with any text format** - research papers, personal notes, presentations. Built with LlamaIndex for document parsing.
|
||||
LEANN provides flexible parameters for embedding models, search strategies, and data processing to fit your specific needs.
|
||||
|
||||
### Search Your Entire Life
|
||||
```bash
|
||||
python -m apps.email
|
||||
# "What's the number of class recommend to take per semester for incoming EECS students?"
|
||||
```
|
||||
**90K emails → 14MB.** Finally, search your email like you search Google.
|
||||
📚 **Need configuration best practices?** Check our [Configuration Guide](docs/configuration-guide.md) for detailed optimization tips, model selection advice, and solutions to common issues like slow embeddings or poor search quality.
|
||||
|
||||
<details>
|
||||
<summary><strong>📋 Click to expand: Command Examples</strong></summary>
|
||||
<summary><strong>📋 Click to expand: Common Parameters (Available in All Examples)</strong></summary>
|
||||
|
||||
All RAG examples share these common parameters. **Interactive mode** is available in all examples - simply run without `--query` to start a continuous Q&A session where you can ask multiple questions. Type 'quit' to exit.
|
||||
|
||||
```bash
|
||||
# Use default mail path (works for most macOS setups)
|
||||
python -m apps.email
|
||||
# Core Parameters (General preprocessing for all examples)
|
||||
--index-dir DIR # Directory to store the index (default: current directory)
|
||||
--query "YOUR QUESTION" # Single query mode. Omit for interactive chat (type 'quit' to exit), and now you can play with your index interactively
|
||||
--max-items N # Limit data preprocessing (default: -1, process all data)
|
||||
--force-rebuild # Force rebuild index even if it exists
|
||||
|
||||
# Run with custom index directory
|
||||
python -m apps.email --index-dir "./my_mail_index"
|
||||
# Embedding Parameters
|
||||
--embedding-model MODEL # e.g., facebook/contriever, text-embedding-3-small, mlx-community/Qwen3-Embedding-0.6B-8bit or nomic-embed-text
|
||||
--embedding-mode MODE # sentence-transformers, openai, mlx, or ollama
|
||||
|
||||
# Process all emails (may take time but indexes everything)
|
||||
python -m apps.email --max-emails -1
|
||||
# LLM Parameters (Text generation models)
|
||||
--llm TYPE # LLM backend: openai, ollama, or hf (default: openai)
|
||||
--llm-model MODEL # Model name (default: gpt-4o) e.g., gpt-4o-mini, llama3.2:1b, Qwen/Qwen2.5-1.5B-Instruct
|
||||
--thinking-budget LEVEL # Thinking budget for reasoning models: low/medium/high (supported by o3, o3-mini, GPT-Oss:20b, and other reasoning models)
|
||||
|
||||
# Limit number of emails processed (useful for testing)
|
||||
python -m apps.email --max-emails 1000
|
||||
# Search Parameters
|
||||
--top-k N # Number of results to retrieve (default: 20)
|
||||
--search-complexity N # Search complexity for graph traversal (default: 32)
|
||||
|
||||
# Run a single query
|
||||
python -m apps.email --query "What did my boss say about deadlines?"
|
||||
# Chunking Parameters
|
||||
--chunk-size N # Size of text chunks (default varies by source: 256 for most, 192 for WeChat)
|
||||
--chunk-overlap N # Overlap between chunks (default varies: 25-128 depending on source)
|
||||
|
||||
# Index Building Parameters
|
||||
--backend-name NAME # Backend to use: hnsw or diskann (default: hnsw)
|
||||
--graph-degree N # Graph degree for index construction (default: 32)
|
||||
--build-complexity N # Build complexity for index construction (default: 64)
|
||||
--compact / --no-compact # Use compact storage (default: true). Must be `no-compact` for `no-recompute` build.
|
||||
--recompute / --no-recompute # Enable/disable embedding recomputation (default: enabled). Should not do a `no-recompute` search in a `recompute` build.
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
### 📄 Personal Data Manager: Process Any Documents (`.pdf`, `.txt`, `.md`)!
|
||||
|
||||
Ask questions directly about your personal PDFs, documents, and any directory containing your files!
|
||||
|
||||
<p align="center">
|
||||
<img src="videos/paper_clear.gif" alt="LEANN Document Search Demo" width="600">
|
||||
</p>
|
||||
|
||||
The example below asks a question about summarizing our paper (uses default data in `data/`, which is a directory with diverse data sources: two papers, Pride and Prejudice, and a Technical report about LLM in Huawei in Chinese), and this is the **easiest example** to run here:
|
||||
|
||||
```bash
|
||||
source .venv/bin/activate # Don't forget to activate the virtual environment
|
||||
python -m apps.document_rag --query "What are the main techniques LEANN explores?"
|
||||
```
|
||||
|
||||
<details>
|
||||
<summary><strong>📋 Click to expand: Document-Specific Arguments</strong></summary>
|
||||
|
||||
#### Parameters
|
||||
```bash
|
||||
--data-dir DIR # Directory containing documents to process (default: data)
|
||||
--file-types .ext .ext # Filter by specific file types (optional - all LlamaIndex supported types if omitted)
|
||||
```
|
||||
|
||||
#### Example Commands
|
||||
```bash
|
||||
# Process all documents with larger chunks for academic papers
|
||||
python -m apps.document_rag --data-dir "~/Documents/Papers" --chunk-size 1024
|
||||
|
||||
# Filter only markdown and Python files with smaller chunks
|
||||
python -m apps.document_rag --data-dir "./docs" --chunk-size 256 --file-types .md .py
|
||||
|
||||
# Enable AST-aware chunking for code files
|
||||
python -m apps.document_rag --enable-code-chunking --data-dir "./my_project"
|
||||
|
||||
# Or use the specialized code RAG for better code understanding
|
||||
python -m apps.code_rag --repo-dir "./my_codebase" --query "How does authentication work?"
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
### 📧 Your Personal Email Secretary: RAG on Apple Mail!
|
||||
|
||||
> **Note:** The examples below currently support macOS only. Windows support coming soon.
|
||||
|
||||
|
||||
<p align="center">
|
||||
<img src="videos/mail_clear.gif" alt="LEANN Email Search Demo" width="600">
|
||||
</p>
|
||||
|
||||
Before running the example below, you need to grant full disk access to your terminal/VS Code in System Preferences → Privacy & Security → Full Disk Access.
|
||||
|
||||
```bash
|
||||
python -m apps.email_rag --query "What's the food I ordered by DoorDash or Uber Eats mostly?"
|
||||
```
|
||||
**780K email chunks → 78MB storage.** Finally, search your email like you search Google.
|
||||
|
||||
<details>
|
||||
<summary><strong>📋 Click to expand: Email-Specific Arguments</strong></summary>
|
||||
|
||||
#### Parameters
|
||||
```bash
|
||||
--mail-path PATH # Path to specific mail directory (auto-detects if omitted)
|
||||
--include-html # Include HTML content in processing (useful for newsletters)
|
||||
```
|
||||
|
||||
#### Example Commands
|
||||
```bash
|
||||
# Search work emails from a specific account
|
||||
python -m apps.email_rag --mail-path "~/Library/Mail/V10/WORK_ACCOUNT"
|
||||
|
||||
# Find all receipts and order confirmations (includes HTML)
|
||||
python -m apps.email_rag --query "receipt order confirmation invoice" --include-html
|
||||
```
|
||||
|
||||
</details>
|
||||
@@ -194,28 +352,32 @@ Once the index is built, you can ask questions like:
|
||||
- "Show me emails about travel expenses"
|
||||
</details>
|
||||
|
||||
### Time Machine for the Web
|
||||
### 🔍 Time Machine for the Web: RAG Your Entire Chrome Browser History!
|
||||
|
||||
<p align="center">
|
||||
<img src="videos/google_clear.gif" alt="LEANN Browser History Search Demo" width="600">
|
||||
</p>
|
||||
|
||||
```bash
|
||||
python -m apps.browser
|
||||
# "Tell me my browser history about machine learning system stuff?"
|
||||
python -m apps.browser_rag --query "Tell me my browser history about machine learning?"
|
||||
```
|
||||
**38K browser entries → 6MB.** Your browser history becomes your personal search engine.
|
||||
**38K browser entries → 6MB storage.** Your browser history becomes your personal search engine.
|
||||
|
||||
<details>
|
||||
<summary><strong>📋 Click to expand: Command Examples</strong></summary>
|
||||
<summary><strong>📋 Click to expand: Browser-Specific Arguments</strong></summary>
|
||||
|
||||
#### Parameters
|
||||
```bash
|
||||
# Use default Chrome profile (auto-finds all profiles)
|
||||
python -m apps.browser
|
||||
--chrome-profile PATH # Path to Chrome profile directory (auto-detects if omitted)
|
||||
```
|
||||
|
||||
# Run with custom index directory
|
||||
python -m apps.browser --index-dir "./my_chrome_index"
|
||||
#### Example Commands
|
||||
```bash
|
||||
# Search academic research from your browsing history
|
||||
python -m apps.browser_rag --query "arxiv papers machine learning transformer architecture"
|
||||
|
||||
# Limit number of history entries processed (useful for testing)
|
||||
python -m apps.browser --max-entries 500
|
||||
|
||||
# Run a single query
|
||||
python -m apps.browser --query "What websites did I visit about machine learning?"
|
||||
# Track competitor analysis across work profile
|
||||
python -m apps.browser_rag --chrome-profile "~/Library/Application Support/Google/Chrome/Work Profile" --max-items 5000
|
||||
```
|
||||
|
||||
</details>
|
||||
@@ -248,44 +410,58 @@ Once the index is built, you can ask questions like:
|
||||
|
||||
</details>
|
||||
|
||||
### WeChat Detective
|
||||
### 💬 WeChat Detective: Unlock Your Golden Memories!
|
||||
|
||||
<p align="center">
|
||||
<img src="videos/wechat_clear.gif" alt="LEANN WeChat Search Demo" width="600">
|
||||
</p>
|
||||
|
||||
```bash
|
||||
python -m apps.wechat
|
||||
# "Show me all group chats about weekend plans"
|
||||
python -m apps.wechat_rag --query "Show me all group chats about weekend plans"
|
||||
```
|
||||
**400K messages → 64MB.** Search years of chat history in any language.
|
||||
**400K messages → 64MB storage** Search years of chat history in any language.
|
||||
|
||||
|
||||
<details>
|
||||
<summary><strong>🔧 Click to expand: Installation Requirements</strong></summary>
|
||||
|
||||
First, you need to install the WeChat exporter:
|
||||
First, you need to install the [WeChat exporter](https://github.com/sunnyyoung/WeChatTweak-CLI),
|
||||
|
||||
```bash
|
||||
brew install sunnyyoung/repo/wechattweak-cli
|
||||
```
|
||||
|
||||
or install it manually (if you have issues with Homebrew):
|
||||
|
||||
```bash
|
||||
sudo packages/wechat-exporter/wechattweak-cli install
|
||||
```
|
||||
|
||||
**Troubleshooting**: If you encounter installation issues, check the [WeChatTweak-CLI issues page](https://github.com/sunnyyoung/WeChatTweak-CLI/issues/41).
|
||||
**Troubleshooting:**
|
||||
- **Installation issues**: Check the [WeChatTweak-CLI issues page](https://github.com/sunnyyoung/WeChatTweak-CLI/issues/41)
|
||||
- **Export errors**: If you encounter the error below, try restarting WeChat
|
||||
```bash
|
||||
Failed to export WeChat data. Please ensure WeChat is running and WeChatTweak is installed.
|
||||
Failed to find or export WeChat data. Exiting.
|
||||
```
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary><strong>📋 Click to expand: Command Examples</strong></summary>
|
||||
<summary><strong>📋 Click to expand: WeChat-Specific Arguments</strong></summary>
|
||||
|
||||
#### Parameters
|
||||
```bash
|
||||
# Use default settings (recommended for first run)
|
||||
python -m apps.wechat
|
||||
--export-dir DIR # Directory to store exported WeChat data (default: wechat_export_direct)
|
||||
--force-export # Force re-export even if data exists
|
||||
```
|
||||
|
||||
# Run with custom export directory and wehn we run the first time, LEANN will export all chat history automatically for you
|
||||
python -m apps.wechat --export-dir "./my_wechat_exports"
|
||||
#### Example Commands
|
||||
```bash
|
||||
# Search for travel plans discussed in group chats
|
||||
python -m apps.wechat_rag --query "travel plans" --max-items 10000
|
||||
|
||||
# Run with custom index directory
|
||||
python -m apps.wechat --index-dir "./my_wechat_index"
|
||||
|
||||
# Limit number of chat entries processed (useful for testing)
|
||||
python -m apps.wechat --max-entries 1000
|
||||
|
||||
# Run a single query
|
||||
python -m apps.wechat --query "Show me conversations about travel plans"
|
||||
# Re-export and search recent chats (useful after new messages)
|
||||
python -m apps.wechat_rag --force-export --query "work schedule"
|
||||
```
|
||||
|
||||
</details>
|
||||
@@ -299,17 +475,70 @@ Once the index is built, you can ask questions like:
|
||||
|
||||
</details>
|
||||
|
||||
### 🚀 Claude Code Integration: Transform Your Development Workflow!
|
||||
|
||||
<details>
|
||||
<summary><strong>NEW!! AST‑Aware Code Chunking</strong></summary>
|
||||
|
||||
LEANN features intelligent code chunking that preserves semantic boundaries (functions, classes, methods) for Python, Java, C#, and TypeScript, improving code understanding compared to text-based chunking.
|
||||
|
||||
📖 Read the [AST Chunking Guide →](docs/ast_chunking_guide.md)
|
||||
|
||||
</details>
|
||||
|
||||
**The future of code assistance is here.** Transform your development workflow with LEANN's native MCP integration for Claude Code. Index your entire codebase and get intelligent code assistance directly in your IDE.
|
||||
|
||||
**Key features:**
|
||||
- 🔍 **Semantic code search** across your entire project, fully local index and lightweight
|
||||
- 🧠 **AST-aware chunking** preserves code structure (functions, classes)
|
||||
- 📚 **Context-aware assistance** for debugging and development
|
||||
- 🚀 **Zero-config setup** with automatic language detection
|
||||
|
||||
```bash
|
||||
# Install LEANN globally for MCP integration
|
||||
uv tool install leann-core --with leann
|
||||
claude mcp add --scope user leann-server -- leann_mcp
|
||||
# Setup is automatic - just start using Claude Code!
|
||||
```
|
||||
Try our fully agentic pipeline with auto query rewriting, semantic search planning, and more:
|
||||
|
||||

|
||||
|
||||
**🔥 Ready to supercharge your coding?** [Complete Setup Guide →](packages/leann-mcp/README.md)
|
||||
|
||||
## 🖥️ Command Line Interface
|
||||
|
||||
LEANN includes a powerful CLI for document processing and search. Perfect for quick document indexing and interactive chat.
|
||||
|
||||
```bash
|
||||
# Build an index from documents
|
||||
leann build my-docs --docs ./documents
|
||||
### Installation
|
||||
|
||||
# Search your documents
|
||||
If you followed the Quick Start, `leann` is already installed in your virtual environment:
|
||||
```bash
|
||||
source .venv/bin/activate
|
||||
leann --help
|
||||
```
|
||||
|
||||
**To make it globally available:**
|
||||
```bash
|
||||
# Install the LEANN CLI globally using uv tool
|
||||
uv tool install leann-core --with leann
|
||||
|
||||
|
||||
# Now you can use leann from anywhere without activating venv
|
||||
leann --help
|
||||
```
|
||||
|
||||
> **Note**: Global installation is required for Claude Code integration. The `leann_mcp` server depends on the globally available `leann` command.
|
||||
|
||||
|
||||
|
||||
### Usage Examples
|
||||
|
||||
```bash
|
||||
# build from a specific directory, and my_docs is the index name(Here you can also build from multiple dict or multiple files)
|
||||
leann build my-docs --docs ./your_documents
|
||||
|
||||
# Search your documents
|
||||
leann search my-docs "machine learning concepts"
|
||||
|
||||
# Interactive chat with your documents
|
||||
@@ -317,30 +546,36 @@ leann ask my-docs --interactive
|
||||
|
||||
# List all your indexes
|
||||
leann list
|
||||
|
||||
# Remove an index
|
||||
leann remove my-docs
|
||||
```
|
||||
|
||||
**Key CLI features:**
|
||||
- Auto-detects document formats (PDF, TXT, MD, DOCX)
|
||||
- Smart text chunking with overlap
|
||||
- Auto-detects document formats (PDF, TXT, MD, DOCX, PPTX + code files)
|
||||
- **🧠 AST-aware chunking** for Python, Java, C#, TypeScript files
|
||||
- Smart text chunking with overlap for all other content
|
||||
- Multiple LLM providers (Ollama, OpenAI, HuggingFace)
|
||||
- Organized index storage in `~/.leann/indexes/`
|
||||
- Organized index storage in `.leann/indexes/` (project-local)
|
||||
- Support for advanced search parameters
|
||||
|
||||
<details>
|
||||
<summary><strong>📋 Click to expand: Complete CLI Reference</strong></summary>
|
||||
|
||||
You can use `leann --help`, or `leann build --help`, `leann search --help`, `leann ask --help`, `leann list --help`, `leann remove --help` to get the complete CLI reference.
|
||||
|
||||
**Build Command:**
|
||||
```bash
|
||||
leann build INDEX_NAME --docs DIRECTORY [OPTIONS]
|
||||
leann build INDEX_NAME --docs DIRECTORY|FILE [DIRECTORY|FILE ...] [OPTIONS]
|
||||
|
||||
Options:
|
||||
--backend {hnsw,diskann} Backend to use (default: hnsw)
|
||||
--embedding-model MODEL Embedding model (default: facebook/contriever)
|
||||
--graph-degree N Graph degree (default: 32)
|
||||
--complexity N Build complexity (default: 64)
|
||||
--force Force rebuild existing index
|
||||
--compact Use compact storage (default: true)
|
||||
--recompute Enable recomputation (default: true)
|
||||
--graph-degree N Graph degree (default: 32)
|
||||
--complexity N Build complexity (default: 64)
|
||||
--force Force rebuild existing index
|
||||
--compact / --no-compact Use compact storage (default: true). Must be `no-compact` for `no-recompute` build.
|
||||
--recompute / --no-recompute Enable recomputation (default: true)
|
||||
```
|
||||
|
||||
**Search Command:**
|
||||
@@ -348,9 +583,9 @@ Options:
|
||||
leann search INDEX_NAME QUERY [OPTIONS]
|
||||
|
||||
Options:
|
||||
--top-k N Number of results (default: 5)
|
||||
--complexity N Search complexity (default: 64)
|
||||
--recompute-embeddings Use recomputation for highest accuracy
|
||||
--top-k N Number of results (default: 5)
|
||||
--complexity N Search complexity (default: 64)
|
||||
--recompute / --no-recompute Enable/disable embedding recomputation (default: enabled). Should not do a `no-recompute` search in a `recompute` build.
|
||||
--pruning-strategy {global,local,proportional}
|
||||
```
|
||||
|
||||
@@ -365,8 +600,60 @@ Options:
|
||||
--top-k N Retrieval count (default: 20)
|
||||
```
|
||||
|
||||
**List Command:**
|
||||
```bash
|
||||
leann list
|
||||
|
||||
# Lists all indexes across all projects with status indicators:
|
||||
# ✅ - Index is complete and ready to use
|
||||
# ❌ - Index is incomplete or corrupted
|
||||
# 📁 - CLI-created index (in .leann/indexes/)
|
||||
# 📄 - App-created index (*.leann.meta.json files)
|
||||
```
|
||||
|
||||
**Remove Command:**
|
||||
```bash
|
||||
leann remove INDEX_NAME [OPTIONS]
|
||||
|
||||
Options:
|
||||
--force, -f Force removal without confirmation
|
||||
|
||||
# Smart removal: automatically finds and safely removes indexes
|
||||
# - Shows all matching indexes across projects
|
||||
# - Requires confirmation for cross-project removal
|
||||
# - Interactive selection when multiple matches found
|
||||
# - Supports both CLI and app-created indexes
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
## 🚀 Advanced Features
|
||||
|
||||
### 🎯 Metadata Filtering
|
||||
|
||||
LEANN supports a simple metadata filtering system to enable sophisticated use cases like document filtering by date/type, code search by file extension, and content management based on custom criteria.
|
||||
|
||||
```python
|
||||
# Add metadata during indexing
|
||||
builder.add_text(
|
||||
"def authenticate_user(token): ...",
|
||||
metadata={"file_extension": ".py", "lines_of_code": 25}
|
||||
)
|
||||
|
||||
# Search with filters
|
||||
results = searcher.search(
|
||||
query="authentication function",
|
||||
metadata_filters={
|
||||
"file_extension": {"==": ".py"},
|
||||
"lines_of_code": {"<": 100}
|
||||
}
|
||||
)
|
||||
```
|
||||
|
||||
**Supported operators**: `==`, `!=`, `<`, `<=`, `>`, `>=`, `in`, `not_in`, `contains`, `starts_with`, `ends_with`, `is_true`, `is_false`
|
||||
|
||||
📖 **[Complete Metadata filtering guide →](docs/metadata_filtering.md)**
|
||||
|
||||
## 🏗️ Architecture & How It Works
|
||||
|
||||
<p align="center">
|
||||
@@ -377,60 +664,36 @@ Options:
|
||||
|
||||
**Core techniques:**
|
||||
- **Graph-based selective recomputation:** Only compute embeddings for nodes in the search path
|
||||
- **High-degree preserving pruning:** Keep important "hub" nodes while removing redundant connections
|
||||
- **High-degree preserving pruning:** Keep important "hub" nodes while removing redundant connections
|
||||
- **Dynamic batching:** Efficiently batch embedding computations for GPU utilization
|
||||
- **Two-level search:** Smart graph traversal that prioritizes promising nodes
|
||||
|
||||
**Backends:** DiskANN or HNSW - pick what works for your data size.
|
||||
**Backends:**
|
||||
- **HNSW** (default): Ideal for most datasets with maximum storage savings through full recomputation
|
||||
- **DiskANN**: Advanced option with superior search performance, using PQ-based graph traversal with real-time reranking for the best speed-accuracy trade-off
|
||||
|
||||
## Benchmarks
|
||||
|
||||
Run the comparison yourself:
|
||||
```bash
|
||||
python -m apps.benchmarks
|
||||
```
|
||||
**[DiskANN vs HNSW Performance Comparison →](benchmarks/diskann_vs_hnsw_speed_comparison.py)** - Compare search performance between both backends
|
||||
|
||||
| System | Storage |
|
||||
|--------|---------|
|
||||
| FAISS HNSW | 5.5 MB |
|
||||
| LEANN | 0.5 MB |
|
||||
| **Savings** | **91%** |
|
||||
**[Simple Example: Compare LEANN vs FAISS →](benchmarks/compare_faiss_vs_leann.py)** - See storage savings in action
|
||||
|
||||
Same dataset, same hardware, same embedding model. LEANN just works better.
|
||||
### 📊 Storage Comparison
|
||||
|
||||
| System | DPR (2.1M) | Wiki (60M) | Chat (400K) | Email (780K) | Browser (38K) |
|
||||
|--------|-------------|------------|-------------|--------------|---------------|
|
||||
| Traditional vector database (e.g., FAISS) | 3.8 GB | 201 GB | 1.8 GB | 2.4 GB | 130 MB |
|
||||
| LEANN | 324 MB | 6 GB | 64 MB | 79 MB | 6.4 MB |
|
||||
| Savings| 91% | 97% | 97% | 97% | 95% |
|
||||
|
||||
|
||||
|
||||
### Storage Usage Comparison
|
||||
|
||||
| System | DPR (2.1M chunks) | RPJ-wiki (60M chunks) | Chat history (400K messages) | Apple emails (90K messages chunks) |Google Search History (38K entries)
|
||||
|-----------------------|------------------|------------------------|-----------------------------|------------------------------|------------------------------|
|
||||
| Traditional Vector DB(FAISS) | 3.8 GB | 201 GB | 1.8G | 305.8 MB |130.4 MB |
|
||||
| **LEANN** | **324 MB** | **6 GB** | **64 MB** | **14.8 MB** |**6.4MB** |
|
||||
| **Reduction** | **91% smaller** | **97% smaller** | **97% smaller** | **95% smaller** |**95% smaller** |
|
||||
|
||||
<!-- ### Memory Usage Comparison
|
||||
|
||||
| System j | DPR(2M docs) | RPJ-wiki(60M docs) | Chat history() |
|
||||
| --------------------- | ---------------- | ---------------- | ---------------- |
|
||||
| Traditional Vector DB(LLamaindex faiss) | x GB | x GB | x GB |
|
||||
| **Leann** | **xx MB** | **x GB** | **x GB** |
|
||||
| **Reduction** | **x%** | **x%** | **x%** |
|
||||
|
||||
### Query Performance of LEANN
|
||||
|
||||
| Backend | Index Size | Query Time | Recall@3 |
|
||||
| ------------------- | ---------- | ---------- | --------- |
|
||||
| DiskANN | 1M docs | xms | 0.95 |
|
||||
| HNSW | 1M docs | xms | 0.95 | -->
|
||||
|
||||
*Benchmarks run on Apple M3 Pro 36 GB*
|
||||
|
||||
## Reproduce Our Results
|
||||
|
||||
```bash
|
||||
uv pip install -e ".[dev]" # Install dev dependencies
|
||||
python -m apps.evaluation data/indices/dpr/dpr_diskann # DPR dataset
|
||||
python -m apps.evaluation data/indices/rpj_wiki/rpj_wiki.index # Wikipedia
|
||||
python benchmarks/run_evaluation.py # Will auto-download evaluation data and run benchmarks
|
||||
python benchmarks/run_evaluation.py benchmarks/data/indices/rpj_wiki/rpj_wiki --num-queries 2000 # After downloading data, you can run the benchmark with our biggest index
|
||||
```
|
||||
|
||||
The evaluation script downloads data automatically on first run. The last three results were tested with partial personal data, and you can reproduce them with your own data!
|
||||
@@ -442,108 +705,25 @@ If you find Leann useful, please cite:
|
||||
|
||||
```bibtex
|
||||
@misc{wang2025leannlowstoragevectorindex,
|
||||
title={LEANN: A Low-Storage Vector Index},
|
||||
title={LEANN: A Low-Storage Vector Index},
|
||||
author={Yichuan Wang and Shu Liu and Zhifei Li and Yongji Wu and Ziming Mao and Yilong Zhao and Xiao Yan and Zhiying Xu and Yang Zhou and Ion Stoica and Sewon Min and Matei Zaharia and Joseph E. Gonzalez},
|
||||
year={2025},
|
||||
eprint={2506.08276},
|
||||
archivePrefix={arXiv},
|
||||
primaryClass={cs.DB},
|
||||
url={https://arxiv.org/abs/2506.08276},
|
||||
url={https://arxiv.org/abs/2506.08276},
|
||||
}
|
||||
```
|
||||
|
||||
## ✨ Features
|
||||
## ✨ [Detailed Features →](docs/features.md)
|
||||
|
||||
### 🔥 Core Features
|
||||
|
||||
- **🔄 Real-time Embeddings** - Eliminate heavy embedding storage with dynamic computation using optimized ZMQ servers and highly optimized search paradigm (overlapping and batching) with highly optimized embedding engine
|
||||
- **📈 Scalable Architecture** - Handles millions of documents on consumer hardware; the larger your dataset, the more LEANN can save
|
||||
- **🎯 Graph Pruning** - Advanced techniques to minimize the storage overhead of vector search to a limited footprint
|
||||
- **🏗️ Pluggable Backends** - DiskANN, HNSW/FAISS with unified API
|
||||
|
||||
### 🛠️ Technical Highlights
|
||||
- **🔄 Recompute Mode** - Highest accuracy scenarios while eliminating vector storage overhead
|
||||
- **⚡ Zero-copy Operations** - Minimize IPC overhead by transferring distances instead of embeddings
|
||||
- **🚀 High-throughput Embedding Pipeline** - Optimized batched processing for maximum efficiency
|
||||
- **🎯 Two-level Search** - Novel coarse-to-fine search overlap for accelerated query processing (optional)
|
||||
- **💾 Memory-mapped Indices** - Fast startup with raw text mapping to reduce memory overhead
|
||||
- **🚀 MLX Support** - Ultra-fast recompute/build with quantized embedding models, accelerating building and search ([minimal example](test/build_mlx_index.py))
|
||||
|
||||
### 🎨 Developer Experience
|
||||
|
||||
- **Simple Python API** - Get started in minutes
|
||||
- **Extensible backend system** - Easy to add new algorithms
|
||||
- **Comprehensive examples** - From basic usage to production deployment
|
||||
|
||||
## 🤝 Contributing
|
||||
|
||||
We welcome contributions! Leann is built by the community, for the community.
|
||||
|
||||
### Ways to Contribute
|
||||
|
||||
- 🐛 **Bug Reports**: Found an issue? Let us know!
|
||||
- 💡 **Feature Requests**: Have an idea? We'd love to hear it!
|
||||
- 🔧 **Code Contributions**: PRs welcome for all skill levels
|
||||
- 📖 **Documentation**: Help make Leann more accessible
|
||||
- 🧪 **Benchmarks**: Share your performance results
|
||||
## 🤝 [CONTRIBUTING →](docs/CONTRIBUTING.md)
|
||||
|
||||
|
||||
<!-- ## ❓ FAQ
|
||||
|
||||
### Common Issues
|
||||
|
||||
#### NCCL Topology Error
|
||||
|
||||
**Problem**: You encounter `ncclTopoComputePaths` error during document processing:
|
||||
|
||||
```
|
||||
ncclTopoComputePaths (system=<optimized out>, comm=comm@entry=0x5555a82fa3c0) at graph/paths.cc:688
|
||||
```
|
||||
|
||||
**Solution**: Set these environment variables before running your script:
|
||||
|
||||
```bash
|
||||
export NCCL_TOPO_DUMP_FILE=/tmp/nccl_topo.xml
|
||||
export NCCL_DEBUG=INFO
|
||||
export NCCL_DEBUG_SUBSYS=INIT,GRAPH
|
||||
export NCCL_IB_DISABLE=1
|
||||
export NCCL_NET_PLUGIN=none
|
||||
export NCCL_SOCKET_IFNAME=ens5
|
||||
``` -->
|
||||
## FAQ
|
||||
|
||||
### 1. My building time seems long
|
||||
|
||||
You can speed up the process by using a lightweight embedding model. Add this to your arguments:
|
||||
|
||||
```bash
|
||||
--embedding-model sentence-transformers/all-MiniLM-L6-v2
|
||||
```
|
||||
**Model sizes:** `all-MiniLM-L6-v2` (30M parameters), `facebook/contriever` (~100M parameters), `Qwen3-0.6B` (600M parameters)
|
||||
## ❓ [FAQ →](docs/faq.md)
|
||||
|
||||
|
||||
## 📈 Roadmap
|
||||
|
||||
### 🎯 Q2 2025
|
||||
|
||||
- [X] DiskANN backend with MIPS/L2/Cosine support
|
||||
- [X] HNSW backend integration
|
||||
- [X] Real-time embedding pipeline
|
||||
- [X] Memory-efficient graph pruning
|
||||
|
||||
### 🚀 Q3 2025
|
||||
|
||||
|
||||
- [ ] Advanced caching strategies
|
||||
- [ ] Add contextual-retrieval https://www.anthropic.com/news/contextual-retrieval
|
||||
- [ ] Add sleep-time-compute and summarize agent! to summarilze the file on computer!
|
||||
- [ ] Add OpenAI recompute API
|
||||
|
||||
### 🌟 Q4 2025
|
||||
|
||||
- [ ] Integration with LangChain/LlamaIndex
|
||||
- [ ] Visual similarity search
|
||||
- [ ] Query rewrtiting, rerank and expansion
|
||||
## 📈 [Roadmap →](docs/roadmap.md)
|
||||
|
||||
## 📄 License
|
||||
|
||||
@@ -551,13 +731,18 @@ MIT License - see [LICENSE](LICENSE) for details.
|
||||
|
||||
## 🙏 Acknowledgments
|
||||
|
||||
- **Microsoft Research** for the DiskANN algorithm
|
||||
- **Meta AI** for FAISS and optimization insights
|
||||
- **HuggingFace** for the transformer ecosystem
|
||||
- **Our amazing contributors** who make this possible
|
||||
Core Contributors: [Yichuan Wang](https://yichuan-w.github.io/) & [Zhifei Li](https://github.com/andylizf).
|
||||
|
||||
---
|
||||
Active Contributors: [Gabriel Dehan](https://github.com/gabriel-dehan)
|
||||
|
||||
|
||||
We welcome more contributors! Feel free to open issues or submit PRs.
|
||||
|
||||
This work is done at [**Berkeley Sky Computing Lab**](https://sky.cs.berkeley.edu/).
|
||||
|
||||
## Star History
|
||||
|
||||
[](https://www.star-history.com/#yichuan-w/LEANN&Date)
|
||||
<p align="center">
|
||||
<strong>⭐ Star us on GitHub if Leann is useful for your research or applications!</strong>
|
||||
</p>
|
||||
@@ -565,4 +750,3 @@ MIT License - see [LICENSE](LICENSE) for details.
|
||||
<p align="center">
|
||||
Made with ❤️ by the Leann team
|
||||
</p>
|
||||
|
||||
|
||||
342
apps/base_rag_example.py
Normal file
342
apps/base_rag_example.py
Normal file
@@ -0,0 +1,342 @@
|
||||
"""
|
||||
Base class for unified RAG examples interface.
|
||||
Provides common parameters and functionality for all RAG examples.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import dotenv
|
||||
from leann.api import LeannBuilder, LeannChat
|
||||
from leann.registry import register_project_directory
|
||||
|
||||
dotenv.load_dotenv()
|
||||
|
||||
|
||||
class BaseRAGExample(ABC):
|
||||
"""Base class for all RAG examples with unified interface."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
description: str,
|
||||
default_index_name: str,
|
||||
):
|
||||
self.name = name
|
||||
self.description = description
|
||||
self.default_index_name = default_index_name
|
||||
self.parser = self._create_parser()
|
||||
|
||||
def _create_parser(self) -> argparse.ArgumentParser:
|
||||
"""Create argument parser with common parameters."""
|
||||
parser = argparse.ArgumentParser(
|
||||
description=self.description, formatter_class=argparse.RawDescriptionHelpFormatter
|
||||
)
|
||||
|
||||
# Core parameters (all examples share these)
|
||||
core_group = parser.add_argument_group("Core Parameters")
|
||||
core_group.add_argument(
|
||||
"--index-dir",
|
||||
type=str,
|
||||
default=f"./{self.default_index_name}",
|
||||
help=f"Directory to store the index (default: ./{self.default_index_name})",
|
||||
)
|
||||
core_group.add_argument(
|
||||
"--query",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Query to run (if not provided, will run in interactive mode)",
|
||||
)
|
||||
# Allow subclasses to override default max_items
|
||||
max_items_default = getattr(self, "max_items_default", -1)
|
||||
core_group.add_argument(
|
||||
"--max-items",
|
||||
type=int,
|
||||
default=max_items_default,
|
||||
help="Maximum number of items to process -1 for all, means index all documents, and you should set it to a reasonable number if you have a large dataset and try at the first time)",
|
||||
)
|
||||
core_group.add_argument(
|
||||
"--force-rebuild", action="store_true", help="Force rebuild index even if it exists"
|
||||
)
|
||||
|
||||
# Embedding parameters
|
||||
embedding_group = parser.add_argument_group("Embedding Parameters")
|
||||
# Allow subclasses to override default embedding_model
|
||||
embedding_model_default = getattr(self, "embedding_model_default", "facebook/contriever")
|
||||
embedding_group.add_argument(
|
||||
"--embedding-model",
|
||||
type=str,
|
||||
default=embedding_model_default,
|
||||
help=f"Embedding model to use (default: {embedding_model_default}), we provide facebook/contriever, text-embedding-3-small,mlx-community/Qwen3-Embedding-0.6B-8bit or nomic-embed-text",
|
||||
)
|
||||
embedding_group.add_argument(
|
||||
"--embedding-mode",
|
||||
type=str,
|
||||
default="sentence-transformers",
|
||||
choices=["sentence-transformers", "openai", "mlx", "ollama"],
|
||||
help="Embedding backend mode (default: sentence-transformers), we provide sentence-transformers, openai, mlx, or ollama",
|
||||
)
|
||||
|
||||
# LLM parameters
|
||||
llm_group = parser.add_argument_group("LLM Parameters")
|
||||
llm_group.add_argument(
|
||||
"--llm",
|
||||
type=str,
|
||||
default="openai",
|
||||
choices=["openai", "ollama", "hf", "simulated"],
|
||||
help="LLM backend: openai, ollama, or hf (default: openai)",
|
||||
)
|
||||
llm_group.add_argument(
|
||||
"--llm-model",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Model name (default: gpt-4o) e.g., gpt-4o-mini, llama3.2:1b, Qwen/Qwen2.5-1.5B-Instruct",
|
||||
)
|
||||
llm_group.add_argument(
|
||||
"--llm-host",
|
||||
type=str,
|
||||
default="http://localhost:11434",
|
||||
help="Host for Ollama API (default: http://localhost:11434)",
|
||||
)
|
||||
llm_group.add_argument(
|
||||
"--thinking-budget",
|
||||
type=str,
|
||||
choices=["low", "medium", "high"],
|
||||
default=None,
|
||||
help="Thinking budget for reasoning models (low/medium/high). Supported by GPT-Oss:20b and other reasoning models.",
|
||||
)
|
||||
|
||||
# AST Chunking parameters
|
||||
ast_group = parser.add_argument_group("AST Chunking Parameters")
|
||||
ast_group.add_argument(
|
||||
"--use-ast-chunking",
|
||||
action="store_true",
|
||||
help="Enable AST-aware chunking for code files (requires astchunk)",
|
||||
)
|
||||
ast_group.add_argument(
|
||||
"--ast-chunk-size",
|
||||
type=int,
|
||||
default=512,
|
||||
help="Maximum characters per AST chunk (default: 512)",
|
||||
)
|
||||
ast_group.add_argument(
|
||||
"--ast-chunk-overlap",
|
||||
type=int,
|
||||
default=64,
|
||||
help="Overlap between AST chunks (default: 64)",
|
||||
)
|
||||
ast_group.add_argument(
|
||||
"--code-file-extensions",
|
||||
nargs="+",
|
||||
default=None,
|
||||
help="Additional code file extensions to process with AST chunking (e.g., .py .java .cs .ts)",
|
||||
)
|
||||
ast_group.add_argument(
|
||||
"--ast-fallback-traditional",
|
||||
action="store_true",
|
||||
default=True,
|
||||
help="Fall back to traditional chunking if AST chunking fails (default: True)",
|
||||
)
|
||||
|
||||
# Search parameters
|
||||
search_group = parser.add_argument_group("Search Parameters")
|
||||
search_group.add_argument(
|
||||
"--top-k", type=int, default=20, help="Number of results to retrieve (default: 20)"
|
||||
)
|
||||
search_group.add_argument(
|
||||
"--search-complexity",
|
||||
type=int,
|
||||
default=32,
|
||||
help="Search complexity for graph traversal (default: 64)",
|
||||
)
|
||||
|
||||
# Index building parameters
|
||||
index_group = parser.add_argument_group("Index Building Parameters")
|
||||
index_group.add_argument(
|
||||
"--backend-name",
|
||||
type=str,
|
||||
default="hnsw",
|
||||
choices=["hnsw", "diskann"],
|
||||
help="Backend to use for index (default: hnsw)",
|
||||
)
|
||||
index_group.add_argument(
|
||||
"--graph-degree",
|
||||
type=int,
|
||||
default=32,
|
||||
help="Graph degree for index construction (default: 32)",
|
||||
)
|
||||
index_group.add_argument(
|
||||
"--build-complexity",
|
||||
type=int,
|
||||
default=64,
|
||||
help="Build complexity for index construction (default: 64)",
|
||||
)
|
||||
index_group.add_argument(
|
||||
"--no-compact",
|
||||
action="store_true",
|
||||
help="Disable compact index storage",
|
||||
)
|
||||
index_group.add_argument(
|
||||
"--no-recompute",
|
||||
action="store_true",
|
||||
help="Disable embedding recomputation",
|
||||
)
|
||||
|
||||
# Add source-specific parameters
|
||||
self._add_specific_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
@abstractmethod
|
||||
def _add_specific_arguments(self, parser: argparse.ArgumentParser):
|
||||
"""Add source-specific arguments. Override in subclasses."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def load_data(self, args) -> list[str]:
|
||||
"""Load data from the source. Returns list of text chunks."""
|
||||
pass
|
||||
|
||||
def get_llm_config(self, args) -> dict[str, Any]:
|
||||
"""Get LLM configuration based on arguments."""
|
||||
config = {"type": args.llm}
|
||||
|
||||
if args.llm == "openai":
|
||||
config["model"] = args.llm_model or "gpt-4o"
|
||||
elif args.llm == "ollama":
|
||||
config["model"] = args.llm_model or "llama3.2:1b"
|
||||
config["host"] = args.llm_host
|
||||
elif args.llm == "hf":
|
||||
config["model"] = args.llm_model or "Qwen/Qwen2.5-1.5B-Instruct"
|
||||
elif args.llm == "simulated":
|
||||
# Simulated LLM doesn't need additional configuration
|
||||
pass
|
||||
|
||||
return config
|
||||
|
||||
async def build_index(self, args, texts: list[str]) -> str:
|
||||
"""Build LEANN index from texts."""
|
||||
index_path = str(Path(args.index_dir) / f"{self.default_index_name}.leann")
|
||||
|
||||
print(f"\n[Building Index] Creating {self.name} index...")
|
||||
print(f"Total text chunks: {len(texts)}")
|
||||
|
||||
builder = LeannBuilder(
|
||||
backend_name=args.backend_name,
|
||||
embedding_model=args.embedding_model,
|
||||
embedding_mode=args.embedding_mode,
|
||||
graph_degree=args.graph_degree,
|
||||
complexity=args.build_complexity,
|
||||
is_compact=not args.no_compact,
|
||||
is_recompute=not args.no_recompute,
|
||||
num_threads=1, # Force single-threaded mode
|
||||
)
|
||||
|
||||
# Add texts in batches for better progress tracking
|
||||
batch_size = 1000
|
||||
for i in range(0, len(texts), batch_size):
|
||||
batch = texts[i : i + batch_size]
|
||||
for text in batch:
|
||||
builder.add_text(text)
|
||||
print(f"Added {min(i + batch_size, len(texts))}/{len(texts)} texts...")
|
||||
|
||||
print("Building index structure...")
|
||||
builder.build_index(index_path)
|
||||
print(f"Index saved to: {index_path}")
|
||||
|
||||
# Register project directory so leann list can discover this index
|
||||
# The index is saved as args.index_dir/index_name.leann
|
||||
# We want to register the current working directory where the app is run
|
||||
register_project_directory(Path.cwd())
|
||||
|
||||
return index_path
|
||||
|
||||
async def run_interactive_chat(self, args, index_path: str):
|
||||
"""Run interactive chat with the index."""
|
||||
chat = LeannChat(
|
||||
index_path,
|
||||
llm_config=self.get_llm_config(args),
|
||||
system_prompt=f"You are a helpful assistant that answers questions about {self.name} data.",
|
||||
complexity=args.search_complexity,
|
||||
)
|
||||
|
||||
print(f"\n[Interactive Mode] Chat with your {self.name} data!")
|
||||
print("Type 'quit' or 'exit' to stop.\n")
|
||||
|
||||
while True:
|
||||
try:
|
||||
query = input("You: ").strip()
|
||||
if query.lower() in ["quit", "exit", "q"]:
|
||||
print("Goodbye!")
|
||||
break
|
||||
|
||||
if not query:
|
||||
continue
|
||||
|
||||
# Prepare LLM kwargs with thinking budget if specified
|
||||
llm_kwargs = {}
|
||||
if hasattr(args, "thinking_budget") and args.thinking_budget:
|
||||
llm_kwargs["thinking_budget"] = args.thinking_budget
|
||||
|
||||
response = chat.ask(
|
||||
query,
|
||||
top_k=args.top_k,
|
||||
complexity=args.search_complexity,
|
||||
llm_kwargs=llm_kwargs,
|
||||
)
|
||||
print(f"\nAssistant: {response}\n")
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\nGoodbye!")
|
||||
break
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
|
||||
async def run_single_query(self, args, index_path: str, query: str):
|
||||
"""Run a single query against the index."""
|
||||
chat = LeannChat(
|
||||
index_path,
|
||||
llm_config=self.get_llm_config(args),
|
||||
complexity=args.search_complexity,
|
||||
)
|
||||
|
||||
print(f"\n[Query]: \033[36m{query}\033[0m")
|
||||
|
||||
# Prepare LLM kwargs with thinking budget if specified
|
||||
llm_kwargs = {}
|
||||
if hasattr(args, "thinking_budget") and args.thinking_budget:
|
||||
llm_kwargs["thinking_budget"] = args.thinking_budget
|
||||
|
||||
response = chat.ask(
|
||||
query, top_k=args.top_k, complexity=args.search_complexity, llm_kwargs=llm_kwargs
|
||||
)
|
||||
print(f"\n[Response]: \033[36m{response}\033[0m")
|
||||
|
||||
async def run(self):
|
||||
"""Main entry point for the example."""
|
||||
args = self.parser.parse_args()
|
||||
|
||||
# Check if index exists
|
||||
index_path = str(Path(args.index_dir) / f"{self.default_index_name}.leann")
|
||||
index_exists = Path(args.index_dir).exists()
|
||||
|
||||
if not index_exists or args.force_rebuild:
|
||||
# Load data and build index
|
||||
print(f"\n{'Rebuilding' if index_exists else 'Building'} index...")
|
||||
texts = await self.load_data(args)
|
||||
|
||||
if not texts:
|
||||
print("No data found to index!")
|
||||
return
|
||||
|
||||
index_path = await self.build_index(args, texts)
|
||||
else:
|
||||
print(f"\nUsing existing index in {args.index_dir}")
|
||||
|
||||
# Run query or interactive mode
|
||||
if args.query:
|
||||
await self.run_single_query(args, index_path, args.query)
|
||||
else:
|
||||
await self.run_interactive_chat(args, index_path)
|
||||
@@ -1,338 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Memory comparison between Faiss HNSW and LEANN HNSW backend
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import psutil
|
||||
import gc
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from llama_index.core.node_parser import SentenceSplitter
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_memory_usage():
|
||||
"""Get current memory usage in MB"""
|
||||
process = psutil.Process()
|
||||
return process.memory_info().rss / 1024 / 1024
|
||||
|
||||
|
||||
def print_memory_stats(stage: str, start_mem: float):
|
||||
"""Print memory statistics"""
|
||||
current_mem = get_memory_usage()
|
||||
diff = current_mem - start_mem
|
||||
print(f"[{stage}] Memory: {current_mem:.1f} MB (+{diff:.1f} MB)")
|
||||
return current_mem
|
||||
|
||||
|
||||
class MemoryTracker:
|
||||
def __init__(self, name: str):
|
||||
self.name = name
|
||||
self.start_mem = get_memory_usage()
|
||||
self.stages = []
|
||||
|
||||
def checkpoint(self, stage: str):
|
||||
current_mem = print_memory_stats(f"{self.name} - {stage}", self.start_mem)
|
||||
self.stages.append((stage, current_mem))
|
||||
return current_mem
|
||||
|
||||
def summary(self):
|
||||
print(f"\n=== {self.name} Memory Summary ===")
|
||||
for stage, mem in self.stages:
|
||||
print(f"{stage}: {mem:.1f} MB")
|
||||
peak_mem = max(mem for _, mem in self.stages)
|
||||
print(f"Peak Memory: {peak_mem:.1f} MB")
|
||||
print(f"Total Memory Increase: {peak_mem - self.start_mem:.1f} MB")
|
||||
return peak_mem
|
||||
|
||||
|
||||
def test_faiss_hnsw():
|
||||
"""Test Faiss HNSW Vector Store in subprocess"""
|
||||
print("\n" + "=" * 50)
|
||||
print("TESTING FAISS HNSW VECTOR STORE")
|
||||
print("=" * 50)
|
||||
|
||||
try:
|
||||
# Get the directory of this script
|
||||
script_dir = Path(__file__).parent
|
||||
faiss_script = script_dir / "faiss_only.py"
|
||||
result = subprocess.run(
|
||||
[sys.executable, str(faiss_script)],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=300,
|
||||
)
|
||||
|
||||
print(result.stdout)
|
||||
if result.stderr:
|
||||
print("Stderr:", result.stderr)
|
||||
|
||||
if result.returncode != 0:
|
||||
return {
|
||||
"peak_memory": float("inf"),
|
||||
"error": f"Process failed with code {result.returncode}",
|
||||
}
|
||||
|
||||
# Parse peak memory from output
|
||||
lines = result.stdout.split("\n")
|
||||
peak_memory = 0.0
|
||||
|
||||
for line in lines:
|
||||
if "Peak Memory:" in line:
|
||||
peak_memory = float(
|
||||
line.split("Peak Memory:")[1].split("MB")[0].strip()
|
||||
)
|
||||
|
||||
return {"peak_memory": peak_memory}
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"peak_memory": float("inf"),
|
||||
"error": str(e),
|
||||
}
|
||||
|
||||
|
||||
def test_leann_hnsw():
|
||||
"""Test LEANN HNSW Search Memory (load existing index)"""
|
||||
print("\n" + "=" * 50)
|
||||
print("TESTING LEANN HNSW SEARCH MEMORY")
|
||||
print("=" * 50)
|
||||
|
||||
tracker = MemoryTracker("LEANN HNSW Search")
|
||||
|
||||
# Import and setup
|
||||
tracker.checkpoint("Initial")
|
||||
|
||||
from leann.api import LeannSearcher
|
||||
|
||||
tracker.checkpoint("After imports")
|
||||
|
||||
from llama_index.core import SimpleDirectoryReader
|
||||
from leann.api import LeannBuilder, LeannSearcher
|
||||
|
||||
|
||||
# Load and parse documents
|
||||
documents = SimpleDirectoryReader(
|
||||
"../documents/data",
|
||||
recursive=True,
|
||||
encoding="utf-8",
|
||||
required_exts=[".pdf", ".txt", ".md"],
|
||||
).load_data()
|
||||
|
||||
tracker.checkpoint("After document loading")
|
||||
|
||||
# Parse into chunks
|
||||
node_parser = SentenceSplitter(
|
||||
chunk_size=256, chunk_overlap=20, separator=" ", paragraph_separator="\n\n"
|
||||
)
|
||||
|
||||
all_texts = []
|
||||
for doc in documents:
|
||||
nodes = node_parser.get_nodes_from_documents([doc])
|
||||
for node in nodes:
|
||||
all_texts.append(node.get_content())
|
||||
|
||||
tracker.checkpoint("After text chunking")
|
||||
|
||||
# Build LEANN index
|
||||
INDEX_DIR = Path("./test_leann_comparison")
|
||||
INDEX_PATH = str(INDEX_DIR / "comparison.leann")
|
||||
|
||||
# Check if index already exists
|
||||
if os.path.exists(INDEX_PATH + ".meta.json"):
|
||||
print("Loading existing LEANN HNSW index...")
|
||||
tracker.checkpoint("After loading existing index")
|
||||
else:
|
||||
print("Building new LEANN HNSW index...")
|
||||
# Clean up previous index
|
||||
import shutil
|
||||
|
||||
if INDEX_DIR.exists():
|
||||
shutil.rmtree(INDEX_DIR)
|
||||
|
||||
builder = LeannBuilder(
|
||||
backend_name="hnsw",
|
||||
embedding_model="facebook/contriever",
|
||||
graph_degree=32,
|
||||
complexity=64,
|
||||
is_compact=True,
|
||||
is_recompute=True,
|
||||
num_threads=1,
|
||||
)
|
||||
|
||||
tracker.checkpoint("After builder setup")
|
||||
|
||||
print("Building LEANN HNSW index...")
|
||||
|
||||
for chunk_text in all_texts:
|
||||
builder.add_text(chunk_text)
|
||||
|
||||
builder.build_index(INDEX_PATH)
|
||||
del builder
|
||||
gc.collect()
|
||||
|
||||
tracker.checkpoint("After index building")
|
||||
|
||||
# Find existing LEANN index
|
||||
index_paths = [
|
||||
"./test_leann_comparison/comparison.leann",
|
||||
]
|
||||
index_path = None
|
||||
for path in index_paths:
|
||||
if os.path.exists(path + ".meta.json"):
|
||||
index_path = path
|
||||
break
|
||||
|
||||
if not index_path:
|
||||
print("❌ LEANN index not found. Please build it first")
|
||||
return {"peak_memory": float("inf"), "error": "Index not found"}
|
||||
|
||||
# Measure runtime memory overhead
|
||||
print("\nMeasuring runtime memory overhead...")
|
||||
runtime_start_mem = get_memory_usage()
|
||||
print(f"Before load memory: {runtime_start_mem:.1f} MB")
|
||||
tracker.checkpoint("Before load memory")
|
||||
|
||||
# Load searcher
|
||||
searcher = LeannSearcher(index_path)
|
||||
tracker.checkpoint("After searcher loading")
|
||||
|
||||
|
||||
|
||||
print("Running search queries...")
|
||||
queries = [
|
||||
"什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面,任务令一般在什么城市颁发",
|
||||
"What is LEANN and how does it work?",
|
||||
"华为诺亚方舟实验室的主要研究内容",
|
||||
]
|
||||
|
||||
for i, query in enumerate(queries):
|
||||
start_time = time.time()
|
||||
# Use same parameters as Faiss: top_k=20, ef=120 (complexity parameter)
|
||||
_ = searcher.search(query, top_k=20, ef=120)
|
||||
query_time = time.time() - start_time
|
||||
print(f"Query {i + 1} time: {query_time:.3f}s")
|
||||
tracker.checkpoint(f"After query {i + 1}")
|
||||
|
||||
runtime_end_mem = get_memory_usage()
|
||||
runtime_overhead = runtime_end_mem - runtime_start_mem
|
||||
|
||||
peak_memory = tracker.summary()
|
||||
print(f"Runtime Memory Overhead: {runtime_overhead:.1f} MB")
|
||||
|
||||
# Get storage size before cleanup
|
||||
storage_size = 0
|
||||
INDEX_DIR = Path(index_path).parent
|
||||
if INDEX_DIR.exists():
|
||||
total_size = 0
|
||||
for dirpath, _, filenames in os.walk(str(INDEX_DIR)):
|
||||
for filename in filenames:
|
||||
# Only count actual index files, skip text data and backups
|
||||
if filename.endswith((".old", ".tmp", ".bak", ".jsonl", ".json")):
|
||||
continue
|
||||
# Count .index, .idx, .map files (actual index structures)
|
||||
if filename.endswith((".index", ".idx", ".map")):
|
||||
filepath = os.path.join(dirpath, filename)
|
||||
total_size += os.path.getsize(filepath)
|
||||
storage_size = total_size / (1024 * 1024) # Convert to MB
|
||||
|
||||
# Clean up
|
||||
del searcher
|
||||
gc.collect()
|
||||
|
||||
return {
|
||||
"peak_memory": peak_memory,
|
||||
"storage_size": storage_size,
|
||||
}
|
||||
|
||||
|
||||
def main():
|
||||
"""Run comparison tests"""
|
||||
print("Storage + Search Memory Comparison: Faiss HNSW vs LEANN HNSW")
|
||||
print("=" * 60)
|
||||
|
||||
# Test Faiss HNSW
|
||||
faiss_results = test_faiss_hnsw()
|
||||
|
||||
# Force garbage collection
|
||||
gc.collect()
|
||||
time.sleep(2)
|
||||
|
||||
# Test LEANN HNSW
|
||||
leann_results = test_leann_hnsw()
|
||||
|
||||
# Final comparison
|
||||
print("\n" + "=" * 60)
|
||||
print("STORAGE + SEARCH MEMORY COMPARISON")
|
||||
print("=" * 60)
|
||||
|
||||
# Get storage sizes
|
||||
faiss_storage_size = 0
|
||||
leann_storage_size = leann_results.get("storage_size", 0)
|
||||
|
||||
# Get Faiss storage size using Python
|
||||
if os.path.exists("./storage_faiss"):
|
||||
total_size = 0
|
||||
for dirpath, _, filenames in os.walk("./storage_faiss"):
|
||||
for filename in filenames:
|
||||
filepath = os.path.join(dirpath, filename)
|
||||
total_size += os.path.getsize(filepath)
|
||||
faiss_storage_size = total_size / (1024 * 1024) # Convert to MB
|
||||
|
||||
print("Faiss HNSW:")
|
||||
if "error" in faiss_results:
|
||||
print(f" ❌ Failed: {faiss_results['error']}")
|
||||
else:
|
||||
print(f" Search Memory: {faiss_results['peak_memory']:.1f} MB")
|
||||
print(f" Storage Size: {faiss_storage_size:.1f} MB")
|
||||
|
||||
print("\nLEANN HNSW:")
|
||||
if "error" in leann_results:
|
||||
print(f" ❌ Failed: {leann_results['error']}")
|
||||
else:
|
||||
print(f" Search Memory: {leann_results['peak_memory']:.1f} MB")
|
||||
print(f" Storage Size: {leann_storage_size:.1f} MB")
|
||||
|
||||
# Calculate improvements only if both tests succeeded
|
||||
if "error" not in faiss_results and "error" not in leann_results:
|
||||
memory_ratio = faiss_results["peak_memory"] / leann_results["peak_memory"]
|
||||
|
||||
print("\nLEANN vs Faiss Performance:")
|
||||
memory_saving = faiss_results["peak_memory"] - leann_results["peak_memory"]
|
||||
print(
|
||||
f" Search Memory: {memory_ratio:.1f}x less ({memory_saving:.1f} MB saved)"
|
||||
)
|
||||
|
||||
# Storage comparison
|
||||
if leann_storage_size > faiss_storage_size:
|
||||
storage_ratio = leann_storage_size / faiss_storage_size
|
||||
print(
|
||||
f" Storage Size: {storage_ratio:.1f}x larger (LEANN uses more storage)"
|
||||
)
|
||||
elif faiss_storage_size > leann_storage_size:
|
||||
storage_ratio = faiss_storage_size / leann_storage_size
|
||||
print(
|
||||
f" Storage Size: {storage_ratio:.1f}x smaller (LEANN uses less storage)"
|
||||
)
|
||||
else:
|
||||
print(" Storage Size: similar")
|
||||
else:
|
||||
if "error" not in leann_results:
|
||||
print("\n✅ LEANN HNSW completed successfully!")
|
||||
print(f"📊 Search Memory: {leann_results['peak_memory']:.1f} MB")
|
||||
print(f"📊 Storage Size: {leann_storage_size:.1f} MB")
|
||||
if "error" not in faiss_results:
|
||||
print("\n✅ Faiss HNSW completed successfully!")
|
||||
print(f"📊 Search Memory: {faiss_results['peak_memory']:.1f} MB")
|
||||
print(f"📊 Storage Size: {faiss_storage_size:.1f} MB")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,151 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Test only Faiss HNSW"""
|
||||
|
||||
import sys
|
||||
import time
|
||||
import psutil
|
||||
import gc
|
||||
import os
|
||||
|
||||
|
||||
def get_memory_usage():
|
||||
process = psutil.Process()
|
||||
return process.memory_info().rss / 1024 / 1024
|
||||
|
||||
|
||||
class MemoryTracker:
|
||||
def __init__(self, name: str):
|
||||
self.name = name
|
||||
self.start_mem = get_memory_usage()
|
||||
self.stages = []
|
||||
|
||||
def checkpoint(self, stage: str):
|
||||
current_mem = get_memory_usage()
|
||||
diff = current_mem - self.start_mem
|
||||
print(f"[{self.name} - {stage}] Memory: {current_mem:.1f} MB (+{diff:.1f} MB)")
|
||||
self.stages.append((stage, current_mem))
|
||||
return current_mem
|
||||
|
||||
def summary(self):
|
||||
peak_mem = max(mem for _, mem in self.stages)
|
||||
print(f"Peak Memory: {peak_mem:.1f} MB")
|
||||
return peak_mem
|
||||
|
||||
|
||||
def main():
|
||||
try:
|
||||
import faiss
|
||||
except ImportError:
|
||||
print("Faiss is not installed.")
|
||||
print("Please install it with `uv pip install faiss-cpu`")
|
||||
sys.exit(1)
|
||||
|
||||
from llama_index.core import (
|
||||
SimpleDirectoryReader,
|
||||
VectorStoreIndex,
|
||||
StorageContext,
|
||||
Settings,
|
||||
node_parser,
|
||||
Document,
|
||||
)
|
||||
from llama_index.core.node_parser import SentenceSplitter
|
||||
from llama_index.vector_stores.faiss import FaissVectorStore
|
||||
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
|
||||
|
||||
tracker = MemoryTracker("Faiss HNSW")
|
||||
tracker.checkpoint("Initial")
|
||||
|
||||
embed_model = HuggingFaceEmbedding(model_name="facebook/contriever")
|
||||
Settings.embed_model = embed_model
|
||||
tracker.checkpoint("After embedding model setup")
|
||||
|
||||
d = 768
|
||||
faiss_index = faiss.IndexHNSWFlat(d, 32)
|
||||
faiss_index.hnsw.efConstruction = 64
|
||||
tracker.checkpoint("After Faiss index creation")
|
||||
|
||||
documents = SimpleDirectoryReader(
|
||||
"../documents/data",
|
||||
recursive=True,
|
||||
encoding="utf-8",
|
||||
required_exts=[".pdf", ".txt", ".md"],
|
||||
).load_data()
|
||||
tracker.checkpoint("After document loading")
|
||||
|
||||
# Parse into chunks using the same splitter as LEANN
|
||||
node_parser = SentenceSplitter(
|
||||
chunk_size=256, chunk_overlap=20, separator=" ", paragraph_separator="\n\n"
|
||||
)
|
||||
|
||||
tracker.checkpoint("After text splitter setup")
|
||||
|
||||
# Check if index already exists and try to load it
|
||||
index_loaded = False
|
||||
if os.path.exists("./storage_faiss"):
|
||||
print("Loading existing Faiss HNSW index...")
|
||||
try:
|
||||
# Use the correct Faiss loading pattern from the example
|
||||
vector_store = FaissVectorStore.from_persist_dir("./storage_faiss")
|
||||
storage_context = StorageContext.from_defaults(
|
||||
vector_store=vector_store, persist_dir="./storage_faiss"
|
||||
)
|
||||
from llama_index.core import load_index_from_storage
|
||||
index = load_index_from_storage(storage_context=storage_context)
|
||||
print(f"Index loaded from ./storage_faiss")
|
||||
tracker.checkpoint("After loading existing index")
|
||||
index_loaded = True
|
||||
except Exception as e:
|
||||
print(f"Failed to load existing index: {e}")
|
||||
print("Cleaning up corrupted index and building new one...")
|
||||
# Clean up corrupted index
|
||||
import shutil
|
||||
if os.path.exists("./storage_faiss"):
|
||||
shutil.rmtree("./storage_faiss")
|
||||
|
||||
if not index_loaded:
|
||||
print("Building new Faiss HNSW index...")
|
||||
|
||||
# Use the correct Faiss building pattern from the example
|
||||
vector_store = FaissVectorStore(faiss_index=faiss_index)
|
||||
storage_context = StorageContext.from_defaults(vector_store=vector_store)
|
||||
index = VectorStoreIndex.from_documents(
|
||||
documents,
|
||||
storage_context=storage_context,
|
||||
transformations=[node_parser]
|
||||
)
|
||||
tracker.checkpoint("After index building")
|
||||
|
||||
# Save index to disk using the correct pattern
|
||||
index.storage_context.persist(persist_dir="./storage_faiss")
|
||||
tracker.checkpoint("After index saving")
|
||||
|
||||
# Measure runtime memory overhead
|
||||
print("\nMeasuring runtime memory overhead...")
|
||||
runtime_start_mem = get_memory_usage()
|
||||
print(f"Before load memory: {runtime_start_mem:.1f} MB")
|
||||
tracker.checkpoint("Before load memory")
|
||||
|
||||
query_engine = index.as_query_engine(similarity_top_k=20)
|
||||
queries = [
|
||||
"什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面,任务令一般在什么城市颁发",
|
||||
"What is LEANN and how does it work?",
|
||||
"华为诺亚方舟实验室的主要研究内容",
|
||||
]
|
||||
|
||||
for i, query in enumerate(queries):
|
||||
start_time = time.time()
|
||||
_ = query_engine.query(query)
|
||||
query_time = time.time() - start_time
|
||||
print(f"Query {i + 1} time: {query_time:.3f}s")
|
||||
tracker.checkpoint(f"After query {i + 1}")
|
||||
|
||||
runtime_end_mem = get_memory_usage()
|
||||
runtime_overhead = runtime_end_mem - runtime_start_mem
|
||||
|
||||
peak_memory = tracker.summary()
|
||||
print(f"Peak Memory: {peak_memory:.1f} MB")
|
||||
print(f"Runtime Memory Overhead: {runtime_overhead:.1f} MB")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,201 +0,0 @@
|
||||
import os
|
||||
import asyncio
|
||||
import argparse
|
||||
try:
|
||||
import dotenv
|
||||
dotenv.load_dotenv()
|
||||
except ModuleNotFoundError:
|
||||
# python-dotenv is not installed; skip loading environment variables
|
||||
dotenv = None
|
||||
from pathlib import Path
|
||||
from typing import List, Any
|
||||
from leann.api import LeannBuilder, LeannSearcher, LeannChat
|
||||
from llama_index.core.node_parser import SentenceSplitter
|
||||
|
||||
# Default Chrome profile path
|
||||
DEFAULT_CHROME_PROFILE = os.path.expanduser("~/Library/Application Support/Google/Chrome/Default")
|
||||
|
||||
def create_leann_index_from_multiple_chrome_profiles(profile_dirs: List[Path], index_path: str = "chrome_history_index.leann", max_count: int = -1):
|
||||
"""
|
||||
Create LEANN index from multiple Chrome profile data sources.
|
||||
|
||||
Args:
|
||||
profile_dirs: List of Path objects pointing to Chrome profile directories
|
||||
index_path: Path to save the LEANN index
|
||||
max_count: Maximum number of history entries to process per profile
|
||||
"""
|
||||
print("Creating LEANN index from multiple Chrome profile data sources...")
|
||||
|
||||
# Load documents using ChromeHistoryReader from local readers module
|
||||
from .readers import ChromeHistoryReader
|
||||
reader = ChromeHistoryReader()
|
||||
|
||||
INDEX_DIR = Path(index_path).parent
|
||||
|
||||
if not INDEX_DIR.exists():
|
||||
print(f"--- Index directory not found, building new index ---")
|
||||
all_documents = []
|
||||
total_processed = 0
|
||||
|
||||
# Process each Chrome profile directory
|
||||
for i, profile_dir in enumerate(profile_dirs):
|
||||
print(f"\nProcessing Chrome profile {i+1}/{len(profile_dirs)}: {profile_dir}")
|
||||
|
||||
try:
|
||||
documents = reader.load_data(
|
||||
chrome_profile_path=str(profile_dir),
|
||||
max_count=max_count
|
||||
)
|
||||
if documents:
|
||||
print(f"Loaded {len(documents)} history documents from {profile_dir}")
|
||||
all_documents.extend(documents)
|
||||
total_processed += len(documents)
|
||||
|
||||
# Check if we've reached the max count
|
||||
if max_count > 0 and total_processed >= max_count:
|
||||
print(f"Reached max count of {max_count} documents")
|
||||
break
|
||||
else:
|
||||
print(f"No documents loaded from {profile_dir}")
|
||||
except Exception as e:
|
||||
print(f"Error processing {profile_dir}: {e}")
|
||||
continue
|
||||
|
||||
if not all_documents:
|
||||
print("No documents loaded from any source. Exiting.")
|
||||
return None
|
||||
|
||||
print(f"\nTotal loaded {len(all_documents)} history documents from {len(profile_dirs)} profiles")
|
||||
|
||||
# Create text splitter with 256 chunk size
|
||||
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=25)
|
||||
|
||||
# Convert Documents to text strings and chunk them
|
||||
all_texts = []
|
||||
for doc in all_documents:
|
||||
# Split the document into chunks
|
||||
nodes = text_splitter.get_nodes_from_documents([doc])
|
||||
for node in nodes:
|
||||
all_texts.append(node.get_content())
|
||||
|
||||
print(f"Created {len(all_texts)} text chunks from {len(all_documents)} documents")
|
||||
|
||||
# Create LEANN index directory
|
||||
print(f"--- Index directory not found, building new index ---")
|
||||
INDEX_DIR.mkdir(exist_ok=True)
|
||||
|
||||
print(f"--- Building new LEANN index ---")
|
||||
|
||||
print(f"\n[PHASE 1] Building Leann index...")
|
||||
|
||||
# Use HNSW backend for better macOS compatibility
|
||||
builder = LeannBuilder(
|
||||
backend_name="hnsw",
|
||||
embedding_model="facebook/contriever",
|
||||
graph_degree=32,
|
||||
complexity=64,
|
||||
is_compact=True,
|
||||
is_recompute=True,
|
||||
num_threads=1 # Force single-threaded mode
|
||||
)
|
||||
|
||||
print(f"Adding {len(all_texts)} history chunks to index...")
|
||||
for chunk_text in all_texts:
|
||||
builder.add_text(chunk_text)
|
||||
|
||||
builder.build_index(index_path)
|
||||
print(f"\nLEANN index built at {index_path}!")
|
||||
else:
|
||||
print(f"--- Using existing index at {INDEX_DIR} ---")
|
||||
|
||||
return index_path
|
||||
|
||||
async def query_leann_index(index_path: str, query: str):
|
||||
"""
|
||||
Query the LEANN index.
|
||||
|
||||
Args:
|
||||
index_path: Path to the LEANN index
|
||||
query: The query string
|
||||
"""
|
||||
print(f"\n[PHASE 2] Starting Leann chat session...")
|
||||
chat = LeannChat(index_path=index_path)
|
||||
|
||||
print(f"You: {query}")
|
||||
chat_response = chat.ask(
|
||||
query,
|
||||
top_k=10,
|
||||
recompute_beighbor_embeddings=True,
|
||||
complexity=32,
|
||||
beam_width=1,
|
||||
llm_config={
|
||||
"type": "openai",
|
||||
"model": "gpt-4o",
|
||||
"api_key": os.getenv("OPENAI_API_KEY"),
|
||||
},
|
||||
llm_kwargs={
|
||||
"temperature": 0.0,
|
||||
"max_tokens": 1000
|
||||
}
|
||||
)
|
||||
print(f"Leann: {chat_response}")
|
||||
|
||||
async def main():
|
||||
# Parse command line arguments
|
||||
parser = argparse.ArgumentParser(description='LEANN Chrome History Reader - Create and query browser history index')
|
||||
parser.add_argument('--chrome-profile', type=str, default=DEFAULT_CHROME_PROFILE,
|
||||
help=f'Path to Chrome profile directory (default: {DEFAULT_CHROME_PROFILE}), usually you dont need to change this')
|
||||
parser.add_argument('--index-dir', type=str, default="./chrome_history_index_leann_test",
|
||||
help='Directory to store the LEANN index (default: ./chrome_history_index_leann_test)')
|
||||
parser.add_argument('--max-entries', type=int, default=1000,
|
||||
help='Maximum number of history entries to process (default: 1000)')
|
||||
parser.add_argument('--query', type=str, default=None,
|
||||
help='Single query to run (default: runs example queries)')
|
||||
parser.add_argument('--auto-find-profiles', action='store_true', default=True,
|
||||
help='Automatically find all Chrome profiles (default: True)')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
INDEX_DIR = Path(args.index_dir)
|
||||
INDEX_PATH = str(INDEX_DIR / "chrome_history.leann")
|
||||
|
||||
print(f"Using Chrome profile: {args.chrome_profile}")
|
||||
print(f"Index directory: {INDEX_DIR}")
|
||||
print(f"Max entries: {args.max_entries}")
|
||||
|
||||
# Find Chrome profile directories
|
||||
from .readers import ChromeHistoryReader
|
||||
|
||||
if args.auto_find_profiles:
|
||||
profile_dirs = ChromeHistoryReader.find_chrome_profiles()
|
||||
if not profile_dirs:
|
||||
print("No Chrome profiles found automatically. Exiting.")
|
||||
return
|
||||
else:
|
||||
# Use single specified profile
|
||||
profile_path = Path(args.chrome_profile)
|
||||
if not profile_path.exists():
|
||||
print(f"Chrome profile not found: {profile_path}")
|
||||
return
|
||||
profile_dirs = [profile_path]
|
||||
|
||||
# Create or load the LEANN index from all sources
|
||||
index_path = create_leann_index_from_multiple_chrome_profiles(profile_dirs, INDEX_PATH, args.max_entries)
|
||||
|
||||
if index_path:
|
||||
if args.query:
|
||||
# Run single query
|
||||
await query_leann_index(index_path, args.query)
|
||||
else:
|
||||
# Example queries
|
||||
queries = [
|
||||
"What websites did I visit about machine learning?",
|
||||
"Find my search history about programming"
|
||||
]
|
||||
|
||||
for query in queries:
|
||||
print("\n" + "="*60)
|
||||
await query_leann_index(index_path, query)
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -1,176 +0,0 @@
|
||||
import sqlite3
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import List, Any
|
||||
from llama_index.core import Document
|
||||
from llama_index.core.readers.base import BaseReader
|
||||
|
||||
class ChromeHistoryReader(BaseReader):
|
||||
"""
|
||||
Chrome browser history reader that extracts browsing data from SQLite database.
|
||||
|
||||
Reads Chrome history from the default Chrome profile location and creates documents
|
||||
with embedded metadata similar to the email reader structure.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize."""
|
||||
pass
|
||||
|
||||
def load_data(self, input_dir: str = None, **load_kwargs: Any) -> List[Document]:
|
||||
"""
|
||||
Load Chrome history data from the default Chrome profile location.
|
||||
|
||||
Args:
|
||||
input_dir: Not used for Chrome history (kept for compatibility)
|
||||
**load_kwargs:
|
||||
max_count (int): Maximum amount of history entries to read.
|
||||
chrome_profile_path (str): Custom path to Chrome profile directory.
|
||||
"""
|
||||
docs: List[Document] = []
|
||||
max_count = load_kwargs.get('max_count', 1000)
|
||||
chrome_profile_path = load_kwargs.get('chrome_profile_path', None)
|
||||
|
||||
# Default Chrome profile path on macOS
|
||||
if chrome_profile_path is None:
|
||||
chrome_profile_path = os.path.expanduser("~/Library/Application Support/Google/Chrome/Default")
|
||||
|
||||
history_db_path = os.path.join(chrome_profile_path, "History")
|
||||
|
||||
if not os.path.exists(history_db_path):
|
||||
print(f"Chrome history database not found at: {history_db_path}")
|
||||
return docs
|
||||
|
||||
try:
|
||||
# Connect to the Chrome history database
|
||||
print(f"Connecting to database: {history_db_path}")
|
||||
conn = sqlite3.connect(history_db_path)
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Query to get browsing history with metadata (removed created_time column)
|
||||
query = """
|
||||
SELECT
|
||||
datetime(last_visit_time/1000000-11644473600,'unixepoch','localtime') as last_visit,
|
||||
url,
|
||||
title,
|
||||
visit_count,
|
||||
typed_count,
|
||||
hidden
|
||||
FROM urls
|
||||
ORDER BY last_visit_time DESC
|
||||
"""
|
||||
|
||||
print(f"Executing query on database: {history_db_path}")
|
||||
cursor.execute(query)
|
||||
rows = cursor.fetchall()
|
||||
print(f"Query returned {len(rows)} rows")
|
||||
|
||||
count = 0
|
||||
for row in rows:
|
||||
if count >= max_count and max_count > 0:
|
||||
break
|
||||
|
||||
last_visit, url, title, visit_count, typed_count, hidden = row
|
||||
|
||||
# Create document content with metadata embedded in text
|
||||
doc_content = f"""
|
||||
[BROWSING HISTORY METADATA]
|
||||
URL: {url}
|
||||
Title: {title}
|
||||
Last Visit: {last_visit}
|
||||
Visit Count: {visit_count}
|
||||
Typed Count: {typed_count}
|
||||
Hidden: {hidden}
|
||||
[END METADATA]
|
||||
|
||||
Title: {title}
|
||||
URL: {url}
|
||||
Last visited: {last_visit}
|
||||
"""
|
||||
|
||||
# Create document with embedded metadata
|
||||
doc = Document(text=doc_content, metadata={})
|
||||
docs.append(doc)
|
||||
count += 1
|
||||
|
||||
conn.close()
|
||||
print(f"Loaded {len(docs)} Chrome history documents")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error reading Chrome history: {e}")
|
||||
return docs
|
||||
|
||||
return docs
|
||||
|
||||
@staticmethod
|
||||
def find_chrome_profiles() -> List[Path]:
|
||||
"""
|
||||
Find all Chrome profile directories.
|
||||
|
||||
Returns:
|
||||
List of Path objects pointing to Chrome profile directories
|
||||
"""
|
||||
chrome_base_path = Path(os.path.expanduser("~/Library/Application Support/Google/Chrome"))
|
||||
profile_dirs = []
|
||||
|
||||
if not chrome_base_path.exists():
|
||||
print(f"Chrome directory not found at: {chrome_base_path}")
|
||||
return profile_dirs
|
||||
|
||||
# Find all profile directories
|
||||
for profile_dir in chrome_base_path.iterdir():
|
||||
if profile_dir.is_dir() and profile_dir.name != "System Profile":
|
||||
history_path = profile_dir / "History"
|
||||
if history_path.exists():
|
||||
profile_dirs.append(profile_dir)
|
||||
print(f"Found Chrome profile: {profile_dir}")
|
||||
|
||||
print(f"Found {len(profile_dirs)} Chrome profiles")
|
||||
return profile_dirs
|
||||
|
||||
@staticmethod
|
||||
def export_history_to_file(output_file: str = "chrome_history_export.txt", max_count: int = 1000):
|
||||
"""
|
||||
Export Chrome history to a text file using the same SQL query format.
|
||||
|
||||
Args:
|
||||
output_file: Path to the output file
|
||||
max_count: Maximum number of entries to export
|
||||
"""
|
||||
chrome_profile_path = os.path.expanduser("~/Library/Application Support/Google/Chrome/Default")
|
||||
history_db_path = os.path.join(chrome_profile_path, "History")
|
||||
|
||||
if not os.path.exists(history_db_path):
|
||||
print(f"Chrome history database not found at: {history_db_path}")
|
||||
return
|
||||
|
||||
try:
|
||||
conn = sqlite3.connect(history_db_path)
|
||||
cursor = conn.cursor()
|
||||
|
||||
query = """
|
||||
SELECT
|
||||
datetime(last_visit_time/1000000-11644473600,'unixepoch','localtime') as last_visit,
|
||||
url,
|
||||
title,
|
||||
visit_count,
|
||||
typed_count,
|
||||
hidden
|
||||
FROM urls
|
||||
ORDER BY last_visit_time DESC
|
||||
LIMIT ?
|
||||
"""
|
||||
|
||||
cursor.execute(query, (max_count,))
|
||||
rows = cursor.fetchall()
|
||||
|
||||
with open(output_file, 'w', encoding='utf-8') as f:
|
||||
for row in rows:
|
||||
last_visit, url, title, visit_count, typed_count, hidden = row
|
||||
f.write(f"{last_visit}\t{url}\t{title}\t{visit_count}\t{typed_count}\t{hidden}\n")
|
||||
|
||||
conn.close()
|
||||
print(f"Exported {len(rows)} history entries to {output_file}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error exporting Chrome history: {e}")
|
||||
170
apps/browser_rag.py
Normal file
170
apps/browser_rag.py
Normal file
@@ -0,0 +1,170 @@
|
||||
"""
|
||||
Browser History RAG example using the unified interface.
|
||||
Supports Chrome browser history.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Add parent directory to path for imports
|
||||
sys.path.insert(0, str(Path(__file__).parent))
|
||||
|
||||
from base_rag_example import BaseRAGExample, create_text_chunks
|
||||
|
||||
from .history_data.history import ChromeHistoryReader
|
||||
|
||||
|
||||
class BrowserRAG(BaseRAGExample):
|
||||
"""RAG example for Chrome browser history."""
|
||||
|
||||
def __init__(self):
|
||||
# Set default values BEFORE calling super().__init__
|
||||
self.embedding_model_default = (
|
||||
"sentence-transformers/all-MiniLM-L6-v2" # Fast 384-dim model
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
name="Browser History",
|
||||
description="Process and query Chrome browser history with LEANN",
|
||||
default_index_name="google_history_index",
|
||||
)
|
||||
|
||||
def _add_specific_arguments(self, parser):
|
||||
"""Add browser-specific arguments."""
|
||||
browser_group = parser.add_argument_group("Browser Parameters")
|
||||
browser_group.add_argument(
|
||||
"--chrome-profile",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to Chrome profile directory (auto-detected if not specified)",
|
||||
)
|
||||
browser_group.add_argument(
|
||||
"--auto-find-profiles",
|
||||
action="store_true",
|
||||
default=True,
|
||||
help="Automatically find all Chrome profiles (default: True)",
|
||||
)
|
||||
browser_group.add_argument(
|
||||
"--chunk-size", type=int, default=256, help="Text chunk size (default: 256)"
|
||||
)
|
||||
browser_group.add_argument(
|
||||
"--chunk-overlap", type=int, default=128, help="Text chunk overlap (default: 128)"
|
||||
)
|
||||
|
||||
def _get_chrome_base_path(self) -> Path:
|
||||
"""Get the base Chrome profile path based on OS."""
|
||||
if sys.platform == "darwin":
|
||||
return Path.home() / "Library" / "Application Support" / "Google" / "Chrome"
|
||||
elif sys.platform.startswith("linux"):
|
||||
return Path.home() / ".config" / "google-chrome"
|
||||
elif sys.platform == "win32":
|
||||
return Path(os.environ["LOCALAPPDATA"]) / "Google" / "Chrome" / "User Data"
|
||||
else:
|
||||
raise ValueError(f"Unsupported platform: {sys.platform}")
|
||||
|
||||
def _find_chrome_profiles(self) -> list[Path]:
|
||||
"""Auto-detect all Chrome profiles."""
|
||||
base_path = self._get_chrome_base_path()
|
||||
if not base_path.exists():
|
||||
return []
|
||||
|
||||
profiles = []
|
||||
|
||||
# Check Default profile
|
||||
default_profile = base_path / "Default"
|
||||
if default_profile.exists() and (default_profile / "History").exists():
|
||||
profiles.append(default_profile)
|
||||
|
||||
# Check numbered profiles
|
||||
for item in base_path.iterdir():
|
||||
if item.is_dir() and item.name.startswith("Profile "):
|
||||
if (item / "History").exists():
|
||||
profiles.append(item)
|
||||
|
||||
return profiles
|
||||
|
||||
async def load_data(self, args) -> list[str]:
|
||||
"""Load browser history and convert to text chunks."""
|
||||
# Determine Chrome profiles
|
||||
if args.chrome_profile and not args.auto_find_profiles:
|
||||
profile_dirs = [Path(args.chrome_profile)]
|
||||
else:
|
||||
print("Auto-detecting Chrome profiles...")
|
||||
profile_dirs = self._find_chrome_profiles()
|
||||
|
||||
# If specific profile given, filter to just that one
|
||||
if args.chrome_profile:
|
||||
profile_path = Path(args.chrome_profile)
|
||||
profile_dirs = [p for p in profile_dirs if p == profile_path]
|
||||
|
||||
if not profile_dirs:
|
||||
print("No Chrome profiles found!")
|
||||
print("Please specify --chrome-profile manually")
|
||||
return []
|
||||
|
||||
print(f"Found {len(profile_dirs)} Chrome profiles")
|
||||
|
||||
# Create reader
|
||||
reader = ChromeHistoryReader()
|
||||
|
||||
# Process each profile
|
||||
all_documents = []
|
||||
total_processed = 0
|
||||
|
||||
for i, profile_dir in enumerate(profile_dirs):
|
||||
print(f"\nProcessing profile {i + 1}/{len(profile_dirs)}: {profile_dir.name}")
|
||||
|
||||
try:
|
||||
# Apply max_items limit per profile
|
||||
max_per_profile = -1
|
||||
if args.max_items > 0:
|
||||
remaining = args.max_items - total_processed
|
||||
if remaining <= 0:
|
||||
break
|
||||
max_per_profile = remaining
|
||||
|
||||
# Load history
|
||||
documents = reader.load_data(
|
||||
chrome_profile_path=str(profile_dir),
|
||||
max_count=max_per_profile,
|
||||
)
|
||||
|
||||
if documents:
|
||||
all_documents.extend(documents)
|
||||
total_processed += len(documents)
|
||||
print(f"Processed {len(documents)} history entries from this profile")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error processing {profile_dir}: {e}")
|
||||
continue
|
||||
|
||||
if not all_documents:
|
||||
print("No browser history found to process!")
|
||||
return []
|
||||
|
||||
print(f"\nTotal history entries processed: {len(all_documents)}")
|
||||
|
||||
# Convert to text chunks
|
||||
all_texts = create_text_chunks(
|
||||
all_documents, chunk_size=args.chunk_size, chunk_overlap=args.chunk_overlap
|
||||
)
|
||||
|
||||
return all_texts
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
||||
# Example queries for browser history RAG
|
||||
print("\n🌐 Browser History RAG Example")
|
||||
print("=" * 50)
|
||||
print("\nExample queries you can try:")
|
||||
print("- 'What websites did I visit about machine learning?'")
|
||||
print("- 'Find my search history about programming'")
|
||||
print("- 'What YouTube videos did I watch recently?'")
|
||||
print("- 'Show me websites about travel planning'")
|
||||
print("\nNote: Make sure Chrome is closed before running\n")
|
||||
|
||||
rag = BrowserRAG()
|
||||
asyncio.run(rag.run())
|
||||
22
apps/chunking/__init__.py
Normal file
22
apps/chunking/__init__.py
Normal file
@@ -0,0 +1,22 @@
|
||||
"""
|
||||
Chunking utilities for LEANN RAG applications.
|
||||
Provides AST-aware and traditional text chunking functionality.
|
||||
"""
|
||||
|
||||
from .utils import (
|
||||
CODE_EXTENSIONS,
|
||||
create_ast_chunks,
|
||||
create_text_chunks,
|
||||
create_traditional_chunks,
|
||||
detect_code_files,
|
||||
get_language_from_extension,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"CODE_EXTENSIONS",
|
||||
"create_ast_chunks",
|
||||
"create_text_chunks",
|
||||
"create_traditional_chunks",
|
||||
"detect_code_files",
|
||||
"get_language_from_extension",
|
||||
]
|
||||
320
apps/chunking/utils.py
Normal file
320
apps/chunking/utils.py
Normal file
@@ -0,0 +1,320 @@
|
||||
"""
|
||||
Enhanced chunking utilities with AST-aware code chunking support.
|
||||
Provides unified interface for both traditional and AST-based text chunking.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from llama_index.core.node_parser import SentenceSplitter
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Code file extensions supported by astchunk
|
||||
CODE_EXTENSIONS = {
|
||||
".py": "python",
|
||||
".java": "java",
|
||||
".cs": "csharp",
|
||||
".ts": "typescript",
|
||||
".tsx": "typescript",
|
||||
".js": "typescript",
|
||||
".jsx": "typescript",
|
||||
}
|
||||
|
||||
# Default chunk parameters for different content types
|
||||
DEFAULT_CHUNK_PARAMS = {
|
||||
"code": {
|
||||
"max_chunk_size": 512,
|
||||
"chunk_overlap": 64,
|
||||
},
|
||||
"text": {
|
||||
"chunk_size": 256,
|
||||
"chunk_overlap": 128,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def detect_code_files(documents, code_extensions=None) -> tuple[list, list]:
|
||||
"""
|
||||
Separate documents into code files and regular text files.
|
||||
|
||||
Args:
|
||||
documents: List of LlamaIndex Document objects
|
||||
code_extensions: Dict mapping file extensions to languages (defaults to CODE_EXTENSIONS)
|
||||
|
||||
Returns:
|
||||
Tuple of (code_documents, text_documents)
|
||||
"""
|
||||
if code_extensions is None:
|
||||
code_extensions = CODE_EXTENSIONS
|
||||
|
||||
code_docs = []
|
||||
text_docs = []
|
||||
|
||||
for doc in documents:
|
||||
# Get file path from metadata
|
||||
file_path = doc.metadata.get("file_path", "")
|
||||
if not file_path:
|
||||
# Fallback to file_name
|
||||
file_path = doc.metadata.get("file_name", "")
|
||||
|
||||
if file_path:
|
||||
file_ext = Path(file_path).suffix.lower()
|
||||
if file_ext in code_extensions:
|
||||
# Add language info to metadata
|
||||
doc.metadata["language"] = code_extensions[file_ext]
|
||||
doc.metadata["is_code"] = True
|
||||
code_docs.append(doc)
|
||||
else:
|
||||
doc.metadata["is_code"] = False
|
||||
text_docs.append(doc)
|
||||
else:
|
||||
# If no file path, treat as text
|
||||
doc.metadata["is_code"] = False
|
||||
text_docs.append(doc)
|
||||
|
||||
logger.info(f"Detected {len(code_docs)} code files and {len(text_docs)} text files")
|
||||
return code_docs, text_docs
|
||||
|
||||
|
||||
def get_language_from_extension(file_path: str) -> Optional[str]:
|
||||
"""Get the programming language from file extension."""
|
||||
ext = Path(file_path).suffix.lower()
|
||||
return CODE_EXTENSIONS.get(ext)
|
||||
|
||||
|
||||
def create_ast_chunks(
|
||||
documents,
|
||||
max_chunk_size: int = 512,
|
||||
chunk_overlap: int = 64,
|
||||
metadata_template: str = "default",
|
||||
) -> list[str]:
|
||||
"""
|
||||
Create AST-aware chunks from code documents using astchunk.
|
||||
|
||||
Args:
|
||||
documents: List of code documents
|
||||
max_chunk_size: Maximum characters per chunk
|
||||
chunk_overlap: Number of AST nodes to overlap between chunks
|
||||
metadata_template: Template for chunk metadata
|
||||
|
||||
Returns:
|
||||
List of text chunks with preserved code structure
|
||||
"""
|
||||
try:
|
||||
from astchunk import ASTChunkBuilder
|
||||
except ImportError as e:
|
||||
logger.error(f"astchunk not available: {e}")
|
||||
logger.info("Falling back to traditional chunking for code files")
|
||||
return create_traditional_chunks(documents, max_chunk_size, chunk_overlap)
|
||||
|
||||
all_chunks = []
|
||||
|
||||
for doc in documents:
|
||||
# Get language from metadata (set by detect_code_files)
|
||||
language = doc.metadata.get("language")
|
||||
if not language:
|
||||
logger.warning(
|
||||
"No language detected for document, falling back to traditional chunking"
|
||||
)
|
||||
traditional_chunks = create_traditional_chunks([doc], max_chunk_size, chunk_overlap)
|
||||
all_chunks.extend(traditional_chunks)
|
||||
continue
|
||||
|
||||
try:
|
||||
# Configure astchunk
|
||||
configs = {
|
||||
"max_chunk_size": max_chunk_size,
|
||||
"language": language,
|
||||
"metadata_template": metadata_template,
|
||||
"chunk_overlap": chunk_overlap if chunk_overlap > 0 else 0,
|
||||
}
|
||||
|
||||
# Add repository-level metadata if available
|
||||
repo_metadata = {
|
||||
"file_path": doc.metadata.get("file_path", ""),
|
||||
"file_name": doc.metadata.get("file_name", ""),
|
||||
"creation_date": doc.metadata.get("creation_date", ""),
|
||||
"last_modified_date": doc.metadata.get("last_modified_date", ""),
|
||||
}
|
||||
configs["repo_level_metadata"] = repo_metadata
|
||||
|
||||
# Create chunk builder and process
|
||||
chunk_builder = ASTChunkBuilder(**configs)
|
||||
code_content = doc.get_content()
|
||||
|
||||
if not code_content or not code_content.strip():
|
||||
logger.warning("Empty code content, skipping")
|
||||
continue
|
||||
|
||||
chunks = chunk_builder.chunkify(code_content)
|
||||
|
||||
# Extract text content from chunks
|
||||
for chunk in chunks:
|
||||
if hasattr(chunk, "text"):
|
||||
chunk_text = chunk.text
|
||||
elif isinstance(chunk, dict) and "text" in chunk:
|
||||
chunk_text = chunk["text"]
|
||||
elif isinstance(chunk, str):
|
||||
chunk_text = chunk
|
||||
else:
|
||||
# Try to convert to string
|
||||
chunk_text = str(chunk)
|
||||
|
||||
if chunk_text and chunk_text.strip():
|
||||
all_chunks.append(chunk_text.strip())
|
||||
|
||||
logger.info(
|
||||
f"Created {len(chunks)} AST chunks from {language} file: {doc.metadata.get('file_name', 'unknown')}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"AST chunking failed for {language} file: {e}")
|
||||
logger.info("Falling back to traditional chunking")
|
||||
traditional_chunks = create_traditional_chunks([doc], max_chunk_size, chunk_overlap)
|
||||
all_chunks.extend(traditional_chunks)
|
||||
|
||||
return all_chunks
|
||||
|
||||
|
||||
def create_traditional_chunks(
|
||||
documents, chunk_size: int = 256, chunk_overlap: int = 128
|
||||
) -> list[str]:
|
||||
"""
|
||||
Create traditional text chunks using LlamaIndex SentenceSplitter.
|
||||
|
||||
Args:
|
||||
documents: List of documents to chunk
|
||||
chunk_size: Size of each chunk in characters
|
||||
chunk_overlap: Overlap between chunks
|
||||
|
||||
Returns:
|
||||
List of text chunks
|
||||
"""
|
||||
# Handle invalid chunk_size values
|
||||
if chunk_size <= 0:
|
||||
logger.warning(f"Invalid chunk_size={chunk_size}, using default value of 256")
|
||||
chunk_size = 256
|
||||
|
||||
# Ensure chunk_overlap is not negative and not larger than chunk_size
|
||||
if chunk_overlap < 0:
|
||||
chunk_overlap = 0
|
||||
if chunk_overlap >= chunk_size:
|
||||
chunk_overlap = chunk_size // 2
|
||||
|
||||
node_parser = SentenceSplitter(
|
||||
chunk_size=chunk_size,
|
||||
chunk_overlap=chunk_overlap,
|
||||
separator=" ",
|
||||
paragraph_separator="\n\n",
|
||||
)
|
||||
|
||||
all_texts = []
|
||||
for doc in documents:
|
||||
try:
|
||||
nodes = node_parser.get_nodes_from_documents([doc])
|
||||
if nodes:
|
||||
chunk_texts = [node.get_content() for node in nodes]
|
||||
all_texts.extend(chunk_texts)
|
||||
logger.debug(f"Created {len(chunk_texts)} traditional chunks from document")
|
||||
except Exception as e:
|
||||
logger.error(f"Traditional chunking failed for document: {e}")
|
||||
# As last resort, add the raw content
|
||||
content = doc.get_content()
|
||||
if content and content.strip():
|
||||
all_texts.append(content.strip())
|
||||
|
||||
return all_texts
|
||||
|
||||
|
||||
def create_text_chunks(
|
||||
documents,
|
||||
chunk_size: int = 256,
|
||||
chunk_overlap: int = 128,
|
||||
use_ast_chunking: bool = False,
|
||||
ast_chunk_size: int = 512,
|
||||
ast_chunk_overlap: int = 64,
|
||||
code_file_extensions: Optional[list[str]] = None,
|
||||
ast_fallback_traditional: bool = True,
|
||||
) -> list[str]:
|
||||
"""
|
||||
Create text chunks from documents with optional AST support for code files.
|
||||
|
||||
Args:
|
||||
documents: List of LlamaIndex Document objects
|
||||
chunk_size: Size for traditional text chunks
|
||||
chunk_overlap: Overlap for traditional text chunks
|
||||
use_ast_chunking: Whether to use AST chunking for code files
|
||||
ast_chunk_size: Size for AST chunks
|
||||
ast_chunk_overlap: Overlap for AST chunks
|
||||
code_file_extensions: Custom list of code file extensions
|
||||
ast_fallback_traditional: Fall back to traditional chunking on AST errors
|
||||
|
||||
Returns:
|
||||
List of text chunks
|
||||
"""
|
||||
if not documents:
|
||||
logger.warning("No documents provided for chunking")
|
||||
return []
|
||||
|
||||
# Create a local copy of supported extensions for this function call
|
||||
local_code_extensions = CODE_EXTENSIONS.copy()
|
||||
|
||||
# Update supported extensions if provided
|
||||
if code_file_extensions:
|
||||
# Map extensions to languages (simplified mapping)
|
||||
ext_mapping = {
|
||||
".py": "python",
|
||||
".java": "java",
|
||||
".cs": "c_sharp",
|
||||
".ts": "typescript",
|
||||
".tsx": "typescript",
|
||||
}
|
||||
for ext in code_file_extensions:
|
||||
if ext.lower() not in local_code_extensions:
|
||||
# Try to guess language from extension
|
||||
if ext.lower() in ext_mapping:
|
||||
local_code_extensions[ext.lower()] = ext_mapping[ext.lower()]
|
||||
else:
|
||||
logger.warning(f"Unsupported extension {ext}, will use traditional chunking")
|
||||
|
||||
all_chunks = []
|
||||
|
||||
if use_ast_chunking:
|
||||
# Separate code and text documents using local extensions
|
||||
code_docs, text_docs = detect_code_files(documents, local_code_extensions)
|
||||
|
||||
# Process code files with AST chunking
|
||||
if code_docs:
|
||||
logger.info(f"Processing {len(code_docs)} code files with AST chunking")
|
||||
try:
|
||||
ast_chunks = create_ast_chunks(
|
||||
code_docs, max_chunk_size=ast_chunk_size, chunk_overlap=ast_chunk_overlap
|
||||
)
|
||||
all_chunks.extend(ast_chunks)
|
||||
logger.info(f"Created {len(ast_chunks)} AST chunks from code files")
|
||||
except Exception as e:
|
||||
logger.error(f"AST chunking failed: {e}")
|
||||
if ast_fallback_traditional:
|
||||
logger.info("Falling back to traditional chunking for code files")
|
||||
traditional_code_chunks = create_traditional_chunks(
|
||||
code_docs, chunk_size, chunk_overlap
|
||||
)
|
||||
all_chunks.extend(traditional_code_chunks)
|
||||
else:
|
||||
raise
|
||||
|
||||
# Process text files with traditional chunking
|
||||
if text_docs:
|
||||
logger.info(f"Processing {len(text_docs)} text files with traditional chunking")
|
||||
text_chunks = create_traditional_chunks(text_docs, chunk_size, chunk_overlap)
|
||||
all_chunks.extend(text_chunks)
|
||||
logger.info(f"Created {len(text_chunks)} traditional chunks from text files")
|
||||
else:
|
||||
# Use traditional chunking for all files
|
||||
logger.info(f"Processing {len(documents)} documents with traditional chunking")
|
||||
all_chunks = create_traditional_chunks(documents, chunk_size, chunk_overlap)
|
||||
|
||||
logger.info(f"Total chunks created: {len(all_chunks)}")
|
||||
return all_chunks
|
||||
211
apps/code_rag.py
Normal file
211
apps/code_rag.py
Normal file
@@ -0,0 +1,211 @@
|
||||
"""
|
||||
Code RAG example using AST-aware chunking for optimal code understanding.
|
||||
Specialized for code repositories with automatic language detection and
|
||||
optimized chunking parameters.
|
||||
"""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Add parent directory to path for imports
|
||||
sys.path.insert(0, str(Path(__file__).parent))
|
||||
|
||||
from base_rag_example import BaseRAGExample
|
||||
from chunking import CODE_EXTENSIONS, create_text_chunks
|
||||
from llama_index.core import SimpleDirectoryReader
|
||||
|
||||
|
||||
class CodeRAG(BaseRAGExample):
|
||||
"""Specialized RAG example for code repositories with AST-aware chunking."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
name="Code",
|
||||
description="Process and query code repositories with AST-aware chunking",
|
||||
default_index_name="code_index",
|
||||
)
|
||||
# Override defaults for code-specific usage
|
||||
self.embedding_model_default = "facebook/contriever" # Good for code
|
||||
self.max_items_default = -1 # Process all code files by default
|
||||
|
||||
def _add_specific_arguments(self, parser):
|
||||
"""Add code-specific arguments."""
|
||||
code_group = parser.add_argument_group("Code Repository Parameters")
|
||||
|
||||
code_group.add_argument(
|
||||
"--repo-dir",
|
||||
type=str,
|
||||
default=".",
|
||||
help="Code repository directory to index (default: current directory)",
|
||||
)
|
||||
code_group.add_argument(
|
||||
"--include-extensions",
|
||||
nargs="+",
|
||||
default=list(CODE_EXTENSIONS.keys()),
|
||||
help="File extensions to include (default: supported code extensions)",
|
||||
)
|
||||
code_group.add_argument(
|
||||
"--exclude-dirs",
|
||||
nargs="+",
|
||||
default=[
|
||||
".git",
|
||||
"__pycache__",
|
||||
"node_modules",
|
||||
"venv",
|
||||
".venv",
|
||||
"build",
|
||||
"dist",
|
||||
"target",
|
||||
],
|
||||
help="Directories to exclude from indexing",
|
||||
)
|
||||
code_group.add_argument(
|
||||
"--max-file-size",
|
||||
type=int,
|
||||
default=1000000, # 1MB
|
||||
help="Maximum file size in bytes to process (default: 1MB)",
|
||||
)
|
||||
code_group.add_argument(
|
||||
"--include-comments",
|
||||
action="store_true",
|
||||
help="Include comments in chunking (useful for documentation)",
|
||||
)
|
||||
code_group.add_argument(
|
||||
"--preserve-imports",
|
||||
action="store_true",
|
||||
default=True,
|
||||
help="Try to preserve import statements in chunks (default: True)",
|
||||
)
|
||||
|
||||
async def load_data(self, args) -> list[str]:
|
||||
"""Load code files and convert to AST-aware chunks."""
|
||||
print(f"🔍 Scanning code repository: {args.repo_dir}")
|
||||
print(f"📁 Including extensions: {args.include_extensions}")
|
||||
print(f"🚫 Excluding directories: {args.exclude_dirs}")
|
||||
|
||||
# Check if repository directory exists
|
||||
repo_path = Path(args.repo_dir)
|
||||
if not repo_path.exists():
|
||||
raise ValueError(f"Repository directory not found: {args.repo_dir}")
|
||||
|
||||
# Load code files with filtering
|
||||
reader_kwargs = {
|
||||
"recursive": True,
|
||||
"encoding": "utf-8",
|
||||
"required_exts": args.include_extensions,
|
||||
"exclude_hidden": True,
|
||||
}
|
||||
|
||||
# Create exclusion filter
|
||||
def file_filter(file_path: str) -> bool:
|
||||
"""Filter out unwanted files and directories."""
|
||||
path = Path(file_path)
|
||||
|
||||
# Check file size
|
||||
try:
|
||||
if path.stat().st_size > args.max_file_size:
|
||||
print(f"⚠️ Skipping large file: {path.name} ({path.stat().st_size} bytes)")
|
||||
return False
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
# Check if in excluded directory
|
||||
for exclude_dir in args.exclude_dirs:
|
||||
if exclude_dir in path.parts:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
try:
|
||||
# Load documents with file filtering
|
||||
documents = SimpleDirectoryReader(
|
||||
args.repo_dir,
|
||||
file_extractor=None, # Use default extractors
|
||||
**reader_kwargs,
|
||||
).load_data(show_progress=True)
|
||||
|
||||
# Apply custom filtering
|
||||
filtered_docs = []
|
||||
for doc in documents:
|
||||
file_path = doc.metadata.get("file_path", "")
|
||||
if file_filter(file_path):
|
||||
filtered_docs.append(doc)
|
||||
|
||||
documents = filtered_docs
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error loading code files: {e}")
|
||||
return []
|
||||
|
||||
if not documents:
|
||||
print(
|
||||
f"❌ No code files found in {args.repo_dir} with extensions {args.include_extensions}"
|
||||
)
|
||||
return []
|
||||
|
||||
print(f"✅ Loaded {len(documents)} code files")
|
||||
|
||||
# Show breakdown by language/extension
|
||||
ext_counts = {}
|
||||
for doc in documents:
|
||||
file_path = doc.metadata.get("file_path", "")
|
||||
if file_path:
|
||||
ext = Path(file_path).suffix.lower()
|
||||
ext_counts[ext] = ext_counts.get(ext, 0) + 1
|
||||
|
||||
print("📊 Files by extension:")
|
||||
for ext, count in sorted(ext_counts.items()):
|
||||
print(f" {ext}: {count} files")
|
||||
|
||||
# Use AST-aware chunking by default for code
|
||||
print(
|
||||
f"🧠 Using AST-aware chunking (chunk_size: {args.ast_chunk_size}, overlap: {args.ast_chunk_overlap})"
|
||||
)
|
||||
|
||||
all_texts = create_text_chunks(
|
||||
documents,
|
||||
chunk_size=256, # Fallback for non-code files
|
||||
chunk_overlap=64,
|
||||
use_ast_chunking=True, # Always use AST for code RAG
|
||||
ast_chunk_size=args.ast_chunk_size,
|
||||
ast_chunk_overlap=args.ast_chunk_overlap,
|
||||
code_file_extensions=args.include_extensions,
|
||||
ast_fallback_traditional=True,
|
||||
)
|
||||
|
||||
# Apply max_items limit if specified
|
||||
if args.max_items > 0 and len(all_texts) > args.max_items:
|
||||
print(f"⏳ Limiting to {args.max_items} chunks (from {len(all_texts)})")
|
||||
all_texts = all_texts[: args.max_items]
|
||||
|
||||
print(f"✅ Generated {len(all_texts)} code chunks")
|
||||
return all_texts
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
||||
# Example queries for code RAG
|
||||
print("\n💻 Code RAG Example")
|
||||
print("=" * 50)
|
||||
print("\nExample queries you can try:")
|
||||
print("- 'How does the embedding computation work?'")
|
||||
print("- 'What are the main classes in this codebase?'")
|
||||
print("- 'Show me the search implementation'")
|
||||
print("- 'How is error handling implemented?'")
|
||||
print("- 'What design patterns are used?'")
|
||||
print("- 'Explain the chunking logic'")
|
||||
print("\n🚀 Features:")
|
||||
print("- ✅ AST-aware chunking preserves code structure")
|
||||
print("- ✅ Automatic language detection")
|
||||
print("- ✅ Smart filtering of large files and common excludes")
|
||||
print("- ✅ Optimized for code understanding")
|
||||
print("\nUsage examples:")
|
||||
print(" python -m apps.code_rag --repo-dir ./my_project")
|
||||
print(
|
||||
" python -m apps.code_rag --include-extensions .py .js --query 'How does authentication work?'"
|
||||
)
|
||||
print("\nOr run without --query for interactive mode\n")
|
||||
|
||||
rag = CodeRAG()
|
||||
asyncio.run(rag.run())
|
||||
131
apps/document_rag.py
Normal file
131
apps/document_rag.py
Normal file
@@ -0,0 +1,131 @@
|
||||
"""
|
||||
Document RAG example using the unified interface.
|
||||
Supports PDF, TXT, MD, and other document formats.
|
||||
"""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Add parent directory to path for imports
|
||||
sys.path.insert(0, str(Path(__file__).parent))
|
||||
|
||||
from base_rag_example import BaseRAGExample
|
||||
from chunking import create_text_chunks
|
||||
from llama_index.core import SimpleDirectoryReader
|
||||
|
||||
|
||||
class DocumentRAG(BaseRAGExample):
|
||||
"""RAG example for document processing (PDF, TXT, MD, etc.)."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
name="Document",
|
||||
description="Process and query documents (PDF, TXT, MD, etc.) with LEANN",
|
||||
default_index_name="test_doc_files",
|
||||
)
|
||||
|
||||
def _add_specific_arguments(self, parser):
|
||||
"""Add document-specific arguments."""
|
||||
doc_group = parser.add_argument_group("Document Parameters")
|
||||
doc_group.add_argument(
|
||||
"--data-dir",
|
||||
type=str,
|
||||
default="data",
|
||||
help="Directory containing documents to index (default: data)",
|
||||
)
|
||||
doc_group.add_argument(
|
||||
"--file-types",
|
||||
nargs="+",
|
||||
default=None,
|
||||
help="Filter by file types (e.g., .pdf .txt .md). If not specified, all supported types are processed",
|
||||
)
|
||||
doc_group.add_argument(
|
||||
"--chunk-size", type=int, default=256, help="Text chunk size (default: 256)"
|
||||
)
|
||||
doc_group.add_argument(
|
||||
"--chunk-overlap", type=int, default=128, help="Text chunk overlap (default: 128)"
|
||||
)
|
||||
doc_group.add_argument(
|
||||
"--enable-code-chunking",
|
||||
action="store_true",
|
||||
help="Enable AST-aware chunking for code files in the data directory",
|
||||
)
|
||||
|
||||
async def load_data(self, args) -> list[str]:
|
||||
"""Load documents and convert to text chunks."""
|
||||
print(f"Loading documents from: {args.data_dir}")
|
||||
if args.file_types:
|
||||
print(f"Filtering by file types: {args.file_types}")
|
||||
else:
|
||||
print("Processing all supported file types")
|
||||
|
||||
# Check if data directory exists
|
||||
data_path = Path(args.data_dir)
|
||||
if not data_path.exists():
|
||||
raise ValueError(f"Data directory not found: {args.data_dir}")
|
||||
|
||||
# Load documents
|
||||
reader_kwargs = {
|
||||
"recursive": True,
|
||||
"encoding": "utf-8",
|
||||
}
|
||||
if args.file_types:
|
||||
reader_kwargs["required_exts"] = args.file_types
|
||||
|
||||
documents = SimpleDirectoryReader(args.data_dir, **reader_kwargs).load_data(
|
||||
show_progress=True
|
||||
)
|
||||
|
||||
if not documents:
|
||||
print(f"No documents found in {args.data_dir} with extensions {args.file_types}")
|
||||
return []
|
||||
|
||||
print(f"Loaded {len(documents)} documents")
|
||||
|
||||
# Determine chunking strategy
|
||||
use_ast = args.enable_code_chunking or getattr(args, "use_ast_chunking", False)
|
||||
|
||||
if use_ast:
|
||||
print("Using AST-aware chunking for code files")
|
||||
|
||||
# Convert to text chunks with optional AST support
|
||||
all_texts = create_text_chunks(
|
||||
documents,
|
||||
chunk_size=args.chunk_size,
|
||||
chunk_overlap=args.chunk_overlap,
|
||||
use_ast_chunking=use_ast,
|
||||
ast_chunk_size=getattr(args, "ast_chunk_size", 512),
|
||||
ast_chunk_overlap=getattr(args, "ast_chunk_overlap", 64),
|
||||
code_file_extensions=getattr(args, "code_file_extensions", None),
|
||||
ast_fallback_traditional=getattr(args, "ast_fallback_traditional", True),
|
||||
)
|
||||
|
||||
# Apply max_items limit if specified
|
||||
if args.max_items > 0 and len(all_texts) > args.max_items:
|
||||
print(f"Limiting to {args.max_items} chunks (from {len(all_texts)})")
|
||||
all_texts = all_texts[: args.max_items]
|
||||
|
||||
return all_texts
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
||||
# Example queries for document RAG
|
||||
print("\n📄 Document RAG Example")
|
||||
print("=" * 50)
|
||||
print("\nExample queries you can try:")
|
||||
print("- 'What are the main techniques LEANN uses?'")
|
||||
print("- 'What is the technique DLPM?'")
|
||||
print("- 'Who does Elizabeth Bennet marry?'")
|
||||
print(
|
||||
"- 'What is the problem of developing pan gu model Huawei meets? (盘古大模型开发中遇到什么问题?)'"
|
||||
)
|
||||
print("\n🚀 NEW: Code-aware chunking available!")
|
||||
print("- Use --enable-code-chunking to enable AST-aware chunking for code files")
|
||||
print("- Supports Python, Java, C#, TypeScript files")
|
||||
print("- Better semantic understanding of code structure")
|
||||
print("\nOr run without --query for interactive mode\n")
|
||||
|
||||
rag = DocumentRAG()
|
||||
asyncio.run(rag.run())
|
||||
@@ -1,113 +0,0 @@
|
||||
import argparse
|
||||
from llama_index.core import SimpleDirectoryReader
|
||||
from llama_index.core.node_parser import SentenceSplitter
|
||||
import asyncio
|
||||
import dotenv
|
||||
from leann.api import LeannBuilder, LeannChat
|
||||
from pathlib import Path
|
||||
import os
|
||||
|
||||
dotenv.load_dotenv()
|
||||
|
||||
|
||||
async def main(args):
|
||||
INDEX_DIR = Path(args.index_dir)
|
||||
INDEX_PATH = str(INDEX_DIR / "pdf_documents.leann")
|
||||
|
||||
if not INDEX_DIR.exists():
|
||||
node_parser = SentenceSplitter(
|
||||
chunk_size=256, chunk_overlap=128, separator=" ", paragraph_separator="\n\n"
|
||||
)
|
||||
|
||||
print("Loading documents...")
|
||||
# Get the data directory relative to this module
|
||||
current_dir = Path(__file__).parent
|
||||
data_dir = current_dir / "data"
|
||||
|
||||
documents = SimpleDirectoryReader(
|
||||
str(data_dir),
|
||||
recursive=True,
|
||||
encoding="utf-8",
|
||||
required_exts=[".pdf", ".txt", ".md"],
|
||||
).load_data(show_progress=True)
|
||||
print("Documents loaded.")
|
||||
all_texts = []
|
||||
for doc in documents:
|
||||
nodes = node_parser.get_nodes_from_documents([doc])
|
||||
for node in nodes:
|
||||
all_texts.append(node.get_content())
|
||||
|
||||
print("--- Index directory not found, building new index ---")
|
||||
|
||||
print("\n[PHASE 1] Building Leann index...")
|
||||
|
||||
# Use HNSW backend for better macOS compatibility
|
||||
builder = LeannBuilder(
|
||||
backend_name="hnsw",
|
||||
embedding_model="facebook/contriever",
|
||||
graph_degree=32,
|
||||
complexity=64,
|
||||
is_compact=True,
|
||||
is_recompute=True,
|
||||
num_threads=1, # Force single-threaded mode
|
||||
)
|
||||
|
||||
print(f"Loaded {len(all_texts)} text chunks from documents.")
|
||||
for chunk_text in all_texts:
|
||||
builder.add_text(chunk_text)
|
||||
|
||||
builder.build_index(INDEX_PATH)
|
||||
print(f"\nLeann index built at {INDEX_PATH}!")
|
||||
else:
|
||||
print(f"--- Using existing index at {INDEX_DIR} ---")
|
||||
|
||||
print(f"\n[PHASE 2] Starting Leann chat session...")
|
||||
|
||||
# llm_config = {"type": "hf", "model": "Qwen/Qwen3-4B"}
|
||||
llm_config = {"type": "ollama", "model": "qwen3:8b"}
|
||||
|
||||
chat = LeannChat(index_path=INDEX_PATH, llm_config=llm_config)
|
||||
|
||||
query = "Based on the paper, what are the main techniques LEANN explores to reduce the storage overhead and DLPM explore to achieve Fairness and Efiiciency trade-off?"
|
||||
|
||||
# query = (
|
||||
# "什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面,任务令一般在什么城市颁发"
|
||||
# )
|
||||
|
||||
print(f"You: {query}")
|
||||
chat_response = chat.ask(query, top_k=20, recompute_embeddings=True, complexity=32)
|
||||
print(f"Leann: {chat_response}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Run Leann Chat with various LLM backends."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--llm",
|
||||
type=str,
|
||||
default="hf",
|
||||
choices=["simulated", "ollama", "hf", "openai"],
|
||||
help="The LLM backend to use.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
type=str,
|
||||
default="Qwen/Qwen3-0.6B",
|
||||
help="The model name to use (e.g., 'llama3:8b' for ollama, 'deepseek-ai/deepseek-llm-7b-chat' for hf, 'gpt-4o' for openai).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--host",
|
||||
type=str,
|
||||
default="http://localhost:11434",
|
||||
help="The host for the Ollama API.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--index-dir",
|
||||
type=str,
|
||||
default="./test_doc_files",
|
||||
help="Directory where the Leann index will be stored.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
asyncio.run(main(args))
|
||||
@@ -1,82 +0,0 @@
|
||||
# 盘古之殇:华为诺亚盘古大模型研发历程的心酸与黑暗
|
||||
|
||||
各位好,
|
||||
|
||||
我是一名盘古大模型团队,华为诺亚方舟实验室的员工。
|
||||
|
||||
首先为自证身份,列举一些细节:
|
||||
|
||||
1. 现诺亚主任,前算法应用部部长,后改名为小模型实验室的主任王云鹤。前诺亚主任:姚骏(大家称姚老师)。几个实验室主任:唐睿明(明哥,明队,已离职),尚利峰,张维(维哥),郝建业(郝老师),刘武龙(称呼为武龙所)等。其他骨干成员和专家陆续有很多人离职。
|
||||
2. 我们隶属于“四野”这个组织。四野下属有许多纵队,基础语言大模型是四纵。王云鹤的小模型是十六纵队。我们参加过苏州的集结,有各种月份的时间节点。在苏州攻关会颁发任务令,需要在节点前达成目标。苏州集结会把各地的人员都集中在苏州研究所,平常住宾馆,比如在甪直的酒店,与家人孩子天各一方。
|
||||
3. 在苏州集结的时候周六默认上班,非常辛苦,不过周六有下午茶,有一次还有小龙虾。在苏州研究所的工位搬迁过一次,从一栋楼换到了另一栋。苏州研究所楼栋都是欧式装修,门口有大坡,里面景色很不错。去苏州集结一般至少要去一周,甚至更久,多的人甚至一两个月都回不了家。
|
||||
4. 诺亚曾经传说是研究型的,但是来了之后因为在四野做大模型项目,项目成员完全变成了交付型的,且充满了例会,评审,汇报。很多时候做实验都要申请。团队需要对接终端小艺,华为云,ICT等诸多业务线,交付压力不小。
|
||||
5. 诺亚研发的盘古模型早期内部代号叫做“盘古智子”,一开始只有内部需要申请试用的网页版,到后续迫于压力在welink上接入和公测开放。
|
||||
|
||||
这些天发生关于质疑盘古大模型抄袭千问的事情闹的沸沸扬扬。作为一个盘古团队的成员,我最近夜夜辗转反侧,难以入眠。盘古的品牌受到如此大的影响,一方面,我自私的为我的职业发展担忧,也为自己过去的努力工作感到不值。另一方面,由于有人开始揭露这些事情我内心又感到大快人心。在多少个日日夜夜,我们对内部某些人一次次靠着造假而又获得了无数利益的行为咬牙切齿而又无能为力。这种压抑和羞辱也逐渐消磨了我对华为的感情,让我在这里的时日逐渐浑浑噩噩,迷茫无措,时常怀疑自己的人生和自我价值。
|
||||
|
||||
我承认我是一个懦弱的人,作为一个小小的打工人,我不仅不敢和王云鹤等内部手眼通天的人做对,更不敢和华为这样的庞然大物做对。我很怕失去我的工作,毕竟我也有家人和孩子,所以我打心眼里很佩服揭露者。但是,看到内部还在试图洗地掩盖事实,蒙蔽公众的时候,我实在不能容忍了。我也希望勇敢一次,顺从自己本心。就算自损八百,我也希望能伤敌一千。我决定把我在这里的所见所闻(部分来自于同事口述)公布出来,关于盘古大模型的“传奇故事”:
|
||||
|
||||
华为确实主要在昇腾卡上训练大模型(小模型实验室有不少英伟达的卡,他们之前也会用来训练,后面转移到昇腾)。曾经我被华为“打造世界第二选择”的决心而折服,我本身也曾经对华为有深厚的感情。我们陪着昇腾一步步摸爬滚打,从充满bug到现在能训出模型,付出了巨大的心血和代价。
|
||||
|
||||
最初我们的算力非常有限,在910A上训练模型。那会只支持fp16,训练的稳定性远不如bf16。盘古的moe开始很早,23年就主要是训练38Bmoe模型和后续的71B dense模型。71B的dense模型通过扩增变成了第一代的135Bdense模型,后面主力模型也逐渐在910B上训练。
|
||||
|
||||
71B和135B模型都有一个巨大的硬伤就是tokenizer。当时使用的tokenizer编码效率极低,每个单个的符号,数字,空格,乃至汉字都会占用一个token。可想而知这会非常浪费算力,且使得模型的效果很差。这时候小模型实验室正好有个自己训的词表。姚老师当时怀疑是不是模型的tokenizer不好(虽然事后来看,他的怀疑是无疑正确的),于是就决定,让71B和135B换tokenizer,因为小模型实验室曾经尝试过。团队缝合了两个tokenizer,开始了tokenizer的更换。71B模型的更换失败了,而135B因为采用了更精细的embedding初始化策略,续训了至少1T的数据后词表总算更换成功,但可想而知,效果并不会变好。
|
||||
|
||||
于此同期,阿里和智谱等国内其他公司在GPU上训练,且已经摸索出了正确的方法,盘古和竞品的差距越来越大。内部一个230B从头训练的dense模型又因为各种原因训练失败,导致项目的状况几乎陷入绝境。面临几个节点的压力以及内部对盘古的强烈质疑时,团队的士气低迷到了极点。团队在算力极其有限的时候,做出了很多努力和挣扎。比如,团队偶然发现当时的38B moe并没有预期moe的效果。于是去掉了moe参数,还原为了13B的dense模型。由于38B的moe源自很早的pangu alpha 13B,架构相对落后,团队进行了一系列的操作,比如切换绝对位置编码到rope,去掉bias,切换为rmsnorm。同时鉴于tokenizer的一些失败和换词表的经验,这个模型的词表也更换为了王云鹤的小模型实验室7B模型所使用的词表。后面这个13B模型进行了扩增续训,变成了第二代38B dense模型(在几个月内这个模型都是主要的盘古中档位模型),曾经具有一定的竞争力。但是,由于更大的135B模型架构落后,且更换词表模型损伤巨大(后续分析发现当时更换的缝合词表有更严重的bug),续训后也与千问等当时国内领先模型存在很大差距。这时由于内部的质疑声和领导的压力也越来越大。团队的状态几乎陷入了绝境。
|
||||
|
||||
在这种情况下,王云鹤和他的小模型实验室出手了。他们声称是从旧的135B参数继承改造而来,通过训练短短的几百B数据,各项指标平均提升了十个点左右。实际上,这就是他们套壳应用到大模型的第一次杰作。华为的外行领导内行,使得领导完全对于这种扯淡的事情没有概念,他们只会觉得肯定是有什么算法创新。经过内部的分析,他们实际上是使用Qwen 1.5 110B续训而来,通过加层,扩增ffn维度,添加盘古pi论文的一些机制得来,凑够了大概135B的参数。实际上,旧的135B有107层,而这个模型只有82层,各种配置也都不一样。新的来路不明的135B训练完很多参数的分布也和Qwen 110B几乎一模一样。连模型代码的类名当时都是Qwen,甚至懒得改名。后续这个模型就是所谓的135B V2。而这个模型当时也提供给了很多下游,甚至包括外部客户。
|
||||
|
||||
这件事对于我们这些认真诚实做事的同事们带来了巨大的冲击,内部很多人其实都知道这件事,甚至包括终端和华为云。我们都戏称以后别叫盘古模型了,叫千古吧。当时团队成员就想向bcg举报了,毕竟这已经是重大的业务造假了。但是后面据说被领导拦了下来,因为更高级别的领导(比如姚老师,以及可能熊总和查老)其实后面也知道了,但是并不管,因为通过套壳拿出好的结果,对他们也是有利的。这件事使得当时团队几位最强的同事开始心灰意冷,离职跑路也逐渐成为挂在嘴边的事。
|
||||
|
||||
此时,盘古似乎迎来了转机。由于前面所述的这些盘古模型基本都是续训和改造而来,当时诺亚完全没有掌握从头训练的技术,何况还是在昇腾的NPU上进行训练。在当时团队的核心成员的极力争取下,盘古开始了第三代模型的训练,付出了巨大的努力后,在数据架构和训练算法方面都与业界逐渐接轨,而这其中的艰辛和小模型实验室的人一点关系都没有。
|
||||
|
||||
一开始团队成员毫无信心,只从一个13B的模型开始训练,但是后面发现效果还不错,于是这个模型后续再次进行了一次参数扩增,变成了第三代的38B,代号38B V3。想必很多产品线的兄弟都对这个模型很熟悉。当时这个模型的tokenizer是基于llama的词表进行扩展的(也是业界常见的做法)。而当时王云鹤的实验室做出来了另一个词表(也就是后续pangu系列的词表)。当时两个词表还被迫进行了一次赛马,最终没有明显的好坏结论。于是,领导当即决定,应该统一词表,使用王云鹤他们的。于是,在后续从头训练的135B V3(也就是对外的Pangu Ultra),便是采用了这个tokenizer。这也解释了很多使用我们模型的兄弟的疑惑,为什么当时同为V3代的两个不同档位的模型,会使用不同的tokenizer。
|
||||
|
||||
|
||||
我们打心眼里觉得,135B V3是我们四纵团队当时的骄傲。这是第一个真正意义上的,华为全栈自研,正经从头训练的千亿级别的模型,且效果与24年同期竞品可比的。写到这里我已经热泪盈眶,太不容易了。当时为了稳定训练,团队做了大量实验对比,并且多次在模型梯度出现异常的时候进行及时回退重启。这个模型真正做到了后面技术报告所说的训练全程没有一个loss spike。我们克服了不知道多少困难,我们做到了,我们愿用生命和荣誉保证这个模型训练的真实性。多少个凌晨,我们为了它的训练而不眠。在被内部心声骂的一文不值的时候,我们有多么不甘,有多少的委屈,我们挺住了。
|
||||
|
||||
我们这帮人是真的在为打磨国产算力底座燃烧自己的青春啊……客居他乡,我们放弃了家庭,放弃了假期,放弃了健康,放弃了娱乐,抛头颅洒热血,其中的艰辛与困苦,寥寥数笔不足以概括其万一。在各种动员大会上,当时口号中喊出的盘古必胜,华为必胜,我们心里是真的深深被感动。
|
||||
|
||||
然而,我们的所有辛苦的成果,经常被小模型实验室轻飘飘的拿走了。数据,直接要走。代码,直接要走,还要求我们配合适配到能一键运行。我们当时戏称小模型实验室为点鼠标实验室。我们付出辛苦,他们取得荣耀。果然应了那句话,你在负重前行是因为有人替你岁月静好。在这种情况下,越来越多的战友再也坚持不下去了,选择了离开。看到身边那些优秀的同事一个个离职,我的内心又感叹又难过。在这种作战一样的环境下,我们比起同事来说更像是战友。他们在技术上也有无数值得我学习的地方,堪称良师。看到他们去了诸如字节Seed,Deepseek,月之暗面,腾讯和快手等等很多出色的团队,我打心眼里为他们高兴和祝福,脱离了这个辛苦却肮脏的地方。我至今还对一位离职同事的话记忆犹新,ta说:“来这里是我技术生涯中的耻辱,在这里再呆每一天都是浪费生命”。话虽难听却让我无言以对。我担心我自己技术方面的积累不足,以及没法适应互联网公司高淘汰的环境,让我多次想离职的心始终没有迈出这一步。
|
||||
|
||||
盘古除了dense模型,后续也启动了moe的探索。一开始训练的是一个224B的moe模型。而与之平行的,小模型实验室也开启了第二次主要的套壳行动(次要的插曲可能还包括一些别的模型,比如math模型),即这次流传甚广的pangu pro moe 72B。这个模型内部自称是从小模型实验室的7B扩增上来的(就算如此,这也与技术报告不符,何况是套壳qwen 2.5的14b续训)。还记得他们训了没几天,内部的评测就立刻追上了当时的38B V3。AI系统实验室很多兄弟因为需要适配模型,都知道他们的套壳行动,只是迫于各种原因,无法伸张正义。实际上,对于后续训了很久很久的这个模型,Honestagi能够分析出这个量级的相似性我已经很诧异了,因为这个模型为了续训洗参数,所付出的算力甚至早就足够从头训一个同档位的模型了。听同事说他们为了洗掉千问的水印,采取了不少办法,甚至包括故意训了脏数据。这也为学术界研究模型血缘提供了一个前所未有的特殊模范吧。以后新的血缘方法提出可以拿出来溜溜。
|
||||
|
||||
24年底和25年初,在Deepseek v3和r1发布之后,由于其惊艳的技术水平,团队受到了巨大的冲击,也受到了更大的质疑。于是为了紧跟潮流,盘古模仿Deepseek的模型尺寸,开启了718B moe的训练。这个时候,小模型实验室再次出手了。他们选择了套壳Deepseekv3续训。他们通过冻住Deepseek加载的参数,进行训练。连任务加载ckpt的目录都是deepseekv3,改都不改,何其嚣张?与之相反,一些有真正技术信仰的同事,在从头训练另一个718B的moe。但其中出现了各种各样的问题。但是很显然,这个模型怎么可能比直接套壳的好呢?如果不是团队leader坚持,早就被叫停了。
|
||||
|
||||
华为的流程管理之繁重,严重拖累了大模型的研发节奏,例如版本管理,模型血缘,各种流程化,各种可追溯。讽刺的是,小模型实验室的模型似乎从来不受这些流程的约束,想套壳就套壳,想续训就续训,算力源源不断的伸手拿走。这种强烈到近乎魔幻的对比,说明了当前流程管理的情况:只许州官放火,不许百姓点灯。何其可笑?何其可悲?何其可恶?何其可耻!
|
||||
|
||||
HonestAGI的事情出来后,内部让大家不停的研讨分析,如何公关和“回应”。诚然,这个原文的分析也许不够有力,给了王云鹤与小模型实验室他们狡辩和颠倒黑白的机会。为此,这两天我内心感到作呕,时时怀疑自己的人生意义以及苍天无眼。我不奉陪了,我要离职了,同时我也在申请从盘古部分技术报告的作者名单中移除。曾经在这些技术报告上署名是我一生都无法抹除的污点。当时我没想到,他们竟然猖狂到敢开源。我没想到,他们敢如此愚弄世人,大肆宣发。当时,我也许是存了侥幸心理,没有拒绝署名。我相信很多扎实做事的战友,也只是被迫上了贼船,或者不知情。但这件事已经无法挽回,我希望我的余生能够坚持扎实做真正有意义的事,为我当时的软弱和不坚定赎罪。
|
||||
|
||||
深夜写到这里,我已经泪流满面,泣不成声。还记得一些出色的同事离职时,我苦笑问他们要不要发个长长的心声惯例帖,揭露一下现状。对方说:不了,浪费时间,而且我也怕揭露出来你们过的更糟。我当时一下黯然神伤,因为曾经共同为了理想奋斗过的战友已经彻底对华为彻底灰心了。当时大家调侃,我们用着当年共产党的小米加步枪,组织却有着堪比当年国民党的作风。
|
||||
|
||||
曾几何时,我为我们用着小米加步枪打败洋枪洋炮而自豪。
|
||||
|
||||
现在,我累了,我想投降。
|
||||
|
||||
其实时至今日,我还是真心希望华为能认真吸取教训,能做好盘古,把盘古做到世界一流,把昇腾变成英伟达的水平。内部的劣币驱逐良币,使得诺亚乃至华为在短时间内急剧流失了大量出色的大模型人才。相信他们也正在如Deepseek等各个团队闪耀着,施展着他们的抱负才华,为中美在AI的激烈竞赛中奉献力量。我时常感叹,华为不是没有人才,而是根本不知道怎么留住人才。如果给这些人合适的环境,合适的资源,更少的枷锁,更少的政治斗争,盘古何愁不成?
|
||||
|
||||
最后:我以生命,人格和荣誉发誓,我写的以上所有内容均为真实(至少在我有限的认知范围内)。我没有那么高的技术水平以及机会去做详尽扎实的分析,也不敢直接用内部记录举证,怕因为信息安全抓到。但是我相信我很多曾经的战友,会为我作证。在华为内部的兄弟,包括我们曾经服务过的产品线兄弟们,相信本文的无数细节能和你们的印象对照,印证我的说法。你们可能也曾经被蒙骗,但这些残酷的真相不会被尘封。我们奋战过的痕迹,也不应该被扭曲和埋葬。
|
||||
|
||||
写了这么多,某些人肯定想把我找出来,抹杀掉。公司搞不好也想让我噤声乃至追责。如果真的这样,我,乃至我的家人的人身乃至生命安全可能都会受到威胁。为了自我保护,我近期每天会跟大家报平安。
|
||||
|
||||
如果我消失了,就当是我为了真理和理想,为了华为乃至中国能够更好地发展算力和AI而牺牲了吧,我愿埋葬于那片曾经奋斗过的地方。
|
||||
|
||||
诺亚,再见
|
||||
|
||||
2025年7月6日凌晨 写于深圳
|
||||
|
||||
---
|
||||
|
||||
各位好,
|
||||
|
||||
感谢大家的关心与祝福。我目前暂时安全,但公司应该在进行排查与某些名单收集,后续情况未知。
|
||||
|
||||
我补充一些细节,以免某些人继续颠倒黑白。
|
||||
|
||||
关于135B V2,小模型实验室在迅速地完成套壳并拿完所有套壳带来的好处后(比如任务令表彰和及时激励),因为不想继续支撑下游应用和模型迭代,又把这个烫手山芋甩给了四纵。确实技高一筹,直接把四纵的兄弟们拉下水。同事提供过去一个老旧的模型,最终拿回了一个当时一个魔改的先进的千问。做大模型的人,自己做的模型就像自己孩子一样熟悉,不要把别人都当傻子。就像自家儿子出门一趟,回来个别人家孩子。
|
||||
|
||||
盘古report的署名是不符合学术规范的。例如,135B V3有不少有技术贡献的人,因为作者名额数量限制,劳动成果没有得到应有的回报,团队内曾经有不小的意见。这个模型当时是大家智慧和汗水的结晶,甚至是团队当时的精神支柱,支撑着不少兄弟们继续留在诺亚。所谓的名额限制,以及挂名了一些毫无技术贡献的人(如一些小模型实验室的人),让兄弟们何其心寒。
|
||||
|
||||
---
|
||||
|
||||
暂时平安。另外,支持我勇于说出真相的战友们 https://github.com/HW-whistleblower/True-Story-of-Pangu/issues/317
|
||||
@@ -1,193 +0,0 @@
|
||||
import os
|
||||
import sys
|
||||
import asyncio
|
||||
import dotenv
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
from typing import List, Any
|
||||
|
||||
from leann.api import LeannBuilder, LeannSearcher, LeannChat
|
||||
from llama_index.core.node_parser import SentenceSplitter
|
||||
|
||||
dotenv.load_dotenv()
|
||||
|
||||
# Auto-detect user's mail path
|
||||
def get_mail_path():
|
||||
"""Get the mail path for the current user"""
|
||||
home_dir = os.path.expanduser("~")
|
||||
return os.path.join(home_dir, "Library", "Mail")
|
||||
|
||||
def create_leann_index_from_multiple_sources(messages_dirs: List[Path], index_path: str = "mail_index.leann", max_count: int = -1, include_html: bool = False, embedding_model: str = "facebook/contriever"):
|
||||
"""
|
||||
Create LEANN index from multiple mail data sources.
|
||||
|
||||
Args:
|
||||
messages_dirs: List of Path objects pointing to Messages directories
|
||||
index_path: Path to save the LEANN index
|
||||
max_count: Maximum number of emails to process per directory
|
||||
include_html: Whether to include HTML content in email processing
|
||||
"""
|
||||
print("Creating LEANN index from multiple mail data sources...")
|
||||
|
||||
# Load documents using EmlxReader from local readers module
|
||||
from .readers import EmlxReader, find_all_messages_directories
|
||||
reader = EmlxReader(include_html=include_html)
|
||||
INDEX_DIR = Path(index_path).parent
|
||||
|
||||
if not INDEX_DIR.exists():
|
||||
print(f"--- Index directory not found, building new index ---")
|
||||
all_documents = []
|
||||
total_processed = 0
|
||||
|
||||
# Process each Messages directory
|
||||
for i, messages_dir in enumerate(messages_dirs):
|
||||
print(f"\nProcessing Messages directory {i+1}/{len(messages_dirs)}: {messages_dir}")
|
||||
|
||||
try:
|
||||
documents = reader.load_data(messages_dir)
|
||||
if documents:
|
||||
print(f"Loaded {len(documents)} email documents from {messages_dir}")
|
||||
all_documents.extend(documents)
|
||||
total_processed += len(documents)
|
||||
|
||||
# Check if we've reached the max count
|
||||
if max_count > 0 and total_processed >= max_count:
|
||||
print(f"Reached max count of {max_count} documents")
|
||||
break
|
||||
else:
|
||||
print(f"No documents loaded from {messages_dir}")
|
||||
except Exception as e:
|
||||
print(f"Error processing {messages_dir}: {e}")
|
||||
continue
|
||||
|
||||
if not all_documents:
|
||||
print("No documents loaded from any source. Exiting.")
|
||||
return None
|
||||
|
||||
print(f"\nTotal loaded {len(all_documents)} email documents from {len(messages_dirs)} directories")
|
||||
|
||||
# Create text splitter with 256 chunk size
|
||||
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=25)
|
||||
|
||||
# Convert Documents to text strings and chunk them
|
||||
all_texts = []
|
||||
for doc in all_documents:
|
||||
# Split the document into chunks
|
||||
nodes = text_splitter.get_nodes_from_documents([doc])
|
||||
for node in nodes:
|
||||
all_texts.append(node.get_content())
|
||||
|
||||
print(f"Created {len(all_texts)} text chunks from {len(all_documents)} documents")
|
||||
|
||||
# Create LEANN index directory
|
||||
print(f"--- Index directory not found, building new index ---")
|
||||
INDEX_DIR.mkdir(exist_ok=True)
|
||||
|
||||
print(f"--- Building new LEANN index ---")
|
||||
|
||||
print(f"\n[PHASE 1] Building Leann index...")
|
||||
|
||||
# Use HNSW backend for better macOS compatibility
|
||||
builder = LeannBuilder(
|
||||
backend_name="hnsw",
|
||||
embedding_model=embedding_model,
|
||||
graph_degree=32,
|
||||
complexity=64,
|
||||
is_compact=True,
|
||||
is_recompute=True,
|
||||
num_threads=1 # Force single-threaded mode
|
||||
)
|
||||
|
||||
print(f"Adding {len(all_texts)} email chunks to index...")
|
||||
for chunk_text in all_texts:
|
||||
builder.add_text(chunk_text)
|
||||
|
||||
builder.build_index(index_path)
|
||||
print(f"\nLEANN index built at {index_path}!")
|
||||
else:
|
||||
print(f"--- Using existing index at {INDEX_DIR} ---")
|
||||
|
||||
return index_path
|
||||
|
||||
async def query_leann_index(index_path: str, query: str):
|
||||
"""
|
||||
Query the LEANN index.
|
||||
|
||||
Args:
|
||||
index_path: Path to the LEANN index
|
||||
query: The query string
|
||||
"""
|
||||
print(f"\n[PHASE 2] Starting Leann chat session...")
|
||||
chat = LeannChat(index_path=index_path,
|
||||
llm_config={"type": "openai", "model": "gpt-4o"})
|
||||
|
||||
print(f"You: {query}")
|
||||
import time
|
||||
start_time = time.time()
|
||||
chat_response = chat.ask(
|
||||
query,
|
||||
top_k=10,
|
||||
recompute_beighbor_embeddings=True,
|
||||
complexity=12,
|
||||
beam_width=1,
|
||||
|
||||
)
|
||||
end_time = time.time()
|
||||
print(f"Time taken: {end_time - start_time} seconds")
|
||||
print(f"Leann: {chat_response}")
|
||||
|
||||
async def main():
|
||||
# Parse command line arguments
|
||||
parser = argparse.ArgumentParser(description='LEANN Mail Reader - Create and query email index')
|
||||
parser.add_argument('--index-dir', type=str, default="./mail_index_leann_raw_text_all_dicts",
|
||||
help='Directory to store the LEANN index (default: ./mail_index_leann_raw_text_all_dicts)')
|
||||
parser.add_argument('--max-emails', type=int, default=1000,
|
||||
help='Maximum number of emails to process (-1 means all)')
|
||||
parser.add_argument('--query', type=str, default="Give me some funny advertisement about apple or other companies",
|
||||
help='Single query to run (default: runs example queries)')
|
||||
parser.add_argument('--include-html', action='store_true', default=False,
|
||||
help='Include HTML content in email processing (default: False)')
|
||||
parser.add_argument('--embedding-model', type=str, default="facebook/contriever",
|
||||
help='Embedding model to use (default: facebook/contriever)')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
print(f"args: {args}")
|
||||
|
||||
# Automatically find all Messages directories under the current user's Mail directory
|
||||
from .readers import find_all_messages_directories
|
||||
mail_path = get_mail_path()
|
||||
print(f"Searching for email data in: {mail_path}")
|
||||
messages_dirs = find_all_messages_directories(mail_path)
|
||||
|
||||
print('len(messages_dirs): ', len(messages_dirs))
|
||||
|
||||
if not messages_dirs:
|
||||
print("No Messages directories found. Exiting.")
|
||||
return
|
||||
|
||||
INDEX_DIR = Path(args.index_dir)
|
||||
INDEX_PATH = str(INDEX_DIR / "mail_documents.leann")
|
||||
print(f"Index directory: {INDEX_DIR}")
|
||||
print(f"Found {len(messages_dirs)} Messages directories.")
|
||||
|
||||
# Create or load the LEANN index from all sources
|
||||
index_path = create_leann_index_from_multiple_sources(messages_dirs, INDEX_PATH, args.max_emails, args.include_html, args.embedding_model)
|
||||
|
||||
if index_path:
|
||||
if args.query:
|
||||
# Run single query
|
||||
await query_leann_index(index_path, args.query)
|
||||
else:
|
||||
# Example queries
|
||||
queries = [
|
||||
"Hows Berkeley Graduate Student Instructor",
|
||||
"how's the icloud related advertisement saying",
|
||||
"Whats the number of class recommend to take per semester for incoming EECS students"
|
||||
]
|
||||
for query in queries:
|
||||
print("\n" + "="*60)
|
||||
await query_leann_index(index_path, query)
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -1,192 +0,0 @@
|
||||
"""
|
||||
Mbox parser.
|
||||
|
||||
Contains simple parser for mbox files.
|
||||
|
||||
"""
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
from fsspec import AbstractFileSystem
|
||||
|
||||
from llama_index.core.readers.base import BaseReader
|
||||
from llama_index.core.schema import Document
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MboxReader(BaseReader):
|
||||
"""
|
||||
Mbox parser.
|
||||
|
||||
Extract messages from mailbox files.
|
||||
Returns string including date, subject, sender, receiver and
|
||||
content for each message.
|
||||
|
||||
"""
|
||||
|
||||
DEFAULT_MESSAGE_FORMAT: str = (
|
||||
"Date: {_date}\n"
|
||||
"From: {_from}\n"
|
||||
"To: {_to}\n"
|
||||
"Subject: {_subject}\n"
|
||||
"Content: {_content}"
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*args: Any,
|
||||
max_count: int = 0,
|
||||
message_format: str = DEFAULT_MESSAGE_FORMAT,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Init params."""
|
||||
try:
|
||||
from bs4 import BeautifulSoup # noqa
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"`beautifulsoup4` package not found: `pip install beautifulsoup4`"
|
||||
)
|
||||
|
||||
super().__init__(*args, **kwargs)
|
||||
self.max_count = max_count
|
||||
self.message_format = message_format
|
||||
|
||||
def load_data(
|
||||
self,
|
||||
file: Path,
|
||||
extra_info: Optional[Dict] = None,
|
||||
fs: Optional[AbstractFileSystem] = None,
|
||||
) -> List[Document]:
|
||||
"""Parse file into string."""
|
||||
# Import required libraries
|
||||
import mailbox
|
||||
from email.parser import BytesParser
|
||||
from email.policy import default
|
||||
|
||||
from bs4 import BeautifulSoup
|
||||
|
||||
if fs:
|
||||
logger.warning(
|
||||
"fs was specified but MboxReader doesn't support loading "
|
||||
"from fsspec filesystems. Will load from local filesystem instead."
|
||||
)
|
||||
|
||||
i = 0
|
||||
results: List[str] = []
|
||||
# Load file using mailbox
|
||||
bytes_parser = BytesParser(policy=default).parse
|
||||
mbox = mailbox.mbox(file, factory=bytes_parser) # type: ignore
|
||||
|
||||
# Iterate through all messages
|
||||
for _, _msg in enumerate(mbox):
|
||||
try:
|
||||
msg: mailbox.mboxMessage = _msg
|
||||
# Parse multipart messages
|
||||
if msg.is_multipart():
|
||||
for part in msg.walk():
|
||||
ctype = part.get_content_type()
|
||||
cdispo = str(part.get("Content-Disposition"))
|
||||
if "attachment" in cdispo:
|
||||
print(f"Attachment found: {part.get_filename()}")
|
||||
if ctype == "text/plain" and "attachment" not in cdispo:
|
||||
content = part.get_payload(decode=True) # decode
|
||||
break
|
||||
# Get plain message payload for non-multipart messages
|
||||
else:
|
||||
content = msg.get_payload(decode=True)
|
||||
|
||||
# Parse message HTML content and remove unneeded whitespace
|
||||
soup = BeautifulSoup(content)
|
||||
stripped_content = " ".join(soup.get_text().split())
|
||||
# Format message to include date, sender, receiver and subject
|
||||
msg_string = self.message_format.format(
|
||||
_date=msg["date"],
|
||||
_from=msg["from"],
|
||||
_to=msg["to"],
|
||||
_subject=msg["subject"],
|
||||
_content=stripped_content,
|
||||
)
|
||||
# Add message string to results
|
||||
results.append(msg_string)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to parse message:\n{_msg}\n with exception {e}")
|
||||
|
||||
# Increment counter and return if max count is met
|
||||
i += 1
|
||||
if self.max_count > 0 and i >= self.max_count:
|
||||
break
|
||||
|
||||
return [Document(text=result, metadata=extra_info or {}) for result in results]
|
||||
|
||||
|
||||
class EmlxMboxReader(MboxReader):
|
||||
"""
|
||||
EmlxMboxReader - Modified MboxReader that handles directories of .emlx files.
|
||||
|
||||
Extends MboxReader to work with Apple Mail's .emlx format by:
|
||||
1. Reading .emlx files from a directory
|
||||
2. Converting them to mbox format in memory
|
||||
3. Using the parent MboxReader's parsing logic
|
||||
"""
|
||||
|
||||
def load_data(
|
||||
self,
|
||||
directory: Path,
|
||||
extra_info: Optional[Dict] = None,
|
||||
fs: Optional[AbstractFileSystem] = None,
|
||||
) -> List[Document]:
|
||||
"""Parse .emlx files from directory into strings using MboxReader logic."""
|
||||
import tempfile
|
||||
import os
|
||||
|
||||
if fs:
|
||||
logger.warning(
|
||||
"fs was specified but EmlxMboxReader doesn't support loading "
|
||||
"from fsspec filesystems. Will load from local filesystem instead."
|
||||
)
|
||||
|
||||
# Find all .emlx files in the directory
|
||||
emlx_files = list(directory.glob("*.emlx"))
|
||||
logger.info(f"Found {len(emlx_files)} .emlx files in {directory}")
|
||||
|
||||
if not emlx_files:
|
||||
logger.warning(f"No .emlx files found in {directory}")
|
||||
return []
|
||||
|
||||
# Create a temporary mbox file
|
||||
with tempfile.NamedTemporaryFile(mode='w', suffix='.mbox', delete=False) as temp_mbox:
|
||||
temp_mbox_path = temp_mbox.name
|
||||
|
||||
# Convert .emlx files to mbox format
|
||||
for emlx_file in emlx_files:
|
||||
try:
|
||||
# Read the .emlx file
|
||||
with open(emlx_file, 'r', encoding='utf-8', errors='ignore') as f:
|
||||
content = f.read()
|
||||
|
||||
# .emlx format: first line is length, rest is email content
|
||||
lines = content.split('\n', 1)
|
||||
if len(lines) >= 2:
|
||||
email_content = lines[1] # Skip the length line
|
||||
|
||||
# Write to mbox format (each message starts with "From " and ends with blank line)
|
||||
temp_mbox.write(f"From {emlx_file.name} {email_content}\n\n")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to process {emlx_file}: {e}")
|
||||
continue
|
||||
|
||||
# Close the temporary file so MboxReader can read it
|
||||
temp_mbox.close()
|
||||
|
||||
try:
|
||||
# Use the parent MboxReader's logic to parse the mbox file
|
||||
return super().load_data(Path(temp_mbox_path), extra_info, fs)
|
||||
finally:
|
||||
# Clean up temporary file
|
||||
try:
|
||||
os.unlink(temp_mbox_path)
|
||||
except:
|
||||
pass
|
||||
@@ -1,124 +0,0 @@
|
||||
import os
|
||||
import email
|
||||
from pathlib import Path
|
||||
from typing import List, Any
|
||||
from llama_index.core import Document
|
||||
from llama_index.core.readers.base import BaseReader
|
||||
|
||||
def find_all_messages_directories(root: str = None) -> List[Path]:
|
||||
"""
|
||||
Recursively find all 'Messages' directories under the given root.
|
||||
Returns a list of Path objects.
|
||||
"""
|
||||
if root is None:
|
||||
# Auto-detect user's mail path
|
||||
home_dir = os.path.expanduser("~")
|
||||
root = os.path.join(home_dir, "Library", "Mail")
|
||||
|
||||
messages_dirs = []
|
||||
for dirpath, dirnames, filenames in os.walk(root):
|
||||
if os.path.basename(dirpath) == "Messages":
|
||||
messages_dirs.append(Path(dirpath))
|
||||
return messages_dirs
|
||||
|
||||
class EmlxReader(BaseReader):
|
||||
"""
|
||||
Apple Mail .emlx file reader with embedded metadata.
|
||||
|
||||
Reads individual .emlx files from Apple Mail's storage format.
|
||||
"""
|
||||
|
||||
def __init__(self, include_html: bool = False) -> None:
|
||||
"""
|
||||
Initialize.
|
||||
|
||||
Args:
|
||||
include_html: Whether to include HTML content in the email body (default: False)
|
||||
"""
|
||||
self.include_html = include_html
|
||||
|
||||
def load_data(self, input_dir: str, **load_kwargs: Any) -> List[Document]:
|
||||
"""
|
||||
Load data from the input directory containing .emlx files.
|
||||
|
||||
Args:
|
||||
input_dir: Directory containing .emlx files
|
||||
**load_kwargs:
|
||||
max_count (int): Maximum amount of messages to read.
|
||||
"""
|
||||
docs: List[Document] = []
|
||||
max_count = load_kwargs.get('max_count', 1000)
|
||||
count = 0
|
||||
|
||||
# Walk through the directory recursively
|
||||
for dirpath, dirnames, filenames in os.walk(input_dir):
|
||||
# Skip hidden directories
|
||||
dirnames[:] = [d for d in dirnames if not d.startswith(".")]
|
||||
|
||||
for filename in filenames:
|
||||
if count >= max_count:
|
||||
break
|
||||
|
||||
if filename.endswith(".emlx"):
|
||||
filepath = os.path.join(dirpath, filename)
|
||||
try:
|
||||
# Read the .emlx file
|
||||
with open(filepath, 'r', encoding='utf-8', errors='ignore') as f:
|
||||
content = f.read()
|
||||
|
||||
# .emlx files have a length prefix followed by the email content
|
||||
# The first line contains the length, followed by the email
|
||||
lines = content.split('\n', 1)
|
||||
if len(lines) >= 2:
|
||||
email_content = lines[1]
|
||||
|
||||
# Parse the email using Python's email module
|
||||
try:
|
||||
msg = email.message_from_string(email_content)
|
||||
|
||||
# Extract email metadata
|
||||
subject = msg.get('Subject', 'No Subject')
|
||||
from_addr = msg.get('From', 'Unknown')
|
||||
to_addr = msg.get('To', 'Unknown')
|
||||
date = msg.get('Date', 'Unknown')
|
||||
|
||||
# Extract email body
|
||||
body = ""
|
||||
if msg.is_multipart():
|
||||
for part in msg.walk():
|
||||
if part.get_content_type() == "text/plain" or part.get_content_type() == "text/html":
|
||||
if part.get_content_type() == "text/html" and not self.include_html:
|
||||
continue
|
||||
body += part.get_payload(decode=True).decode('utf-8', errors='ignore')
|
||||
# break
|
||||
else:
|
||||
body = msg.get_payload(decode=True).decode('utf-8', errors='ignore')
|
||||
|
||||
# Create document content with metadata embedded in text
|
||||
doc_content = f"""
|
||||
[EMAIL METADATA]
|
||||
File: {filename}
|
||||
From: {from_addr}
|
||||
To: {to_addr}
|
||||
Subject: {subject}
|
||||
Date: {date}
|
||||
[END METADATA]
|
||||
|
||||
{body}
|
||||
"""
|
||||
|
||||
# No separate metadata - everything is in the text
|
||||
doc = Document(text=doc_content, metadata={})
|
||||
docs.append(doc)
|
||||
count += 1
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error parsing email from {filepath}: {e}")
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error reading file {filepath}: {e}")
|
||||
continue
|
||||
|
||||
print(f"Loaded {len(docs)} email documents")
|
||||
return docs
|
||||
167
apps/email_data/LEANN_email_reader.py
Normal file
167
apps/email_data/LEANN_email_reader.py
Normal file
@@ -0,0 +1,167 @@
|
||||
import email
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from llama_index.core import Document
|
||||
from llama_index.core.readers.base import BaseReader
|
||||
|
||||
|
||||
def find_all_messages_directories(root: str | None = None) -> list[Path]:
|
||||
"""
|
||||
Recursively find all 'Messages' directories under the given root.
|
||||
Returns a list of Path objects.
|
||||
"""
|
||||
if root is None:
|
||||
# Auto-detect user's mail path
|
||||
home_dir = os.path.expanduser("~")
|
||||
root = os.path.join(home_dir, "Library", "Mail")
|
||||
|
||||
messages_dirs = []
|
||||
for dirpath, _dirnames, _filenames in os.walk(root):
|
||||
if os.path.basename(dirpath) == "Messages":
|
||||
messages_dirs.append(Path(dirpath))
|
||||
return messages_dirs
|
||||
|
||||
|
||||
class EmlxReader(BaseReader):
|
||||
"""
|
||||
Apple Mail .emlx file reader with embedded metadata.
|
||||
|
||||
Reads individual .emlx files from Apple Mail's storage format.
|
||||
"""
|
||||
|
||||
def __init__(self, include_html: bool = False) -> None:
|
||||
"""
|
||||
Initialize.
|
||||
|
||||
Args:
|
||||
include_html: Whether to include HTML content in the email body (default: False)
|
||||
"""
|
||||
self.include_html = include_html
|
||||
|
||||
def load_data(self, input_dir: str, **load_kwargs: Any) -> list[Document]:
|
||||
"""
|
||||
Load data from the input directory containing .emlx files.
|
||||
|
||||
Args:
|
||||
input_dir: Directory containing .emlx files
|
||||
**load_kwargs:
|
||||
max_count (int): Maximum amount of messages to read.
|
||||
"""
|
||||
docs: list[Document] = []
|
||||
max_count = load_kwargs.get("max_count", 1000)
|
||||
count = 0
|
||||
total_files = 0
|
||||
successful_files = 0
|
||||
failed_files = 0
|
||||
|
||||
print(f"Starting to process directory: {input_dir}")
|
||||
|
||||
# Walk through the directory recursively
|
||||
for dirpath, dirnames, filenames in os.walk(input_dir):
|
||||
# Skip hidden directories
|
||||
dirnames[:] = [d for d in dirnames if not d.startswith(".")]
|
||||
|
||||
for filename in filenames:
|
||||
# Check if we've reached the max count (skip if max_count == -1)
|
||||
if max_count > 0 and count >= max_count:
|
||||
break
|
||||
|
||||
if filename.endswith(".emlx"):
|
||||
total_files += 1
|
||||
filepath = os.path.join(dirpath, filename)
|
||||
try:
|
||||
# Read the .emlx file
|
||||
with open(filepath, encoding="utf-8", errors="ignore") as f:
|
||||
content = f.read()
|
||||
|
||||
# .emlx files have a length prefix followed by the email content
|
||||
# The first line contains the length, followed by the email
|
||||
lines = content.split("\n", 1)
|
||||
if len(lines) >= 2:
|
||||
email_content = lines[1]
|
||||
|
||||
# Parse the email using Python's email module
|
||||
try:
|
||||
msg = email.message_from_string(email_content)
|
||||
|
||||
# Extract email metadata
|
||||
subject = msg.get("Subject", "No Subject")
|
||||
from_addr = msg.get("From", "Unknown")
|
||||
to_addr = msg.get("To", "Unknown")
|
||||
date = msg.get("Date", "Unknown")
|
||||
|
||||
# Extract email body
|
||||
body = ""
|
||||
if msg.is_multipart():
|
||||
for part in msg.walk():
|
||||
if (
|
||||
part.get_content_type() == "text/plain"
|
||||
or part.get_content_type() == "text/html"
|
||||
):
|
||||
if (
|
||||
part.get_content_type() == "text/html"
|
||||
and not self.include_html
|
||||
):
|
||||
continue
|
||||
try:
|
||||
payload = part.get_payload(decode=True)
|
||||
if payload:
|
||||
body += payload.decode("utf-8", errors="ignore")
|
||||
except Exception as e:
|
||||
print(f"Error decoding payload: {e}")
|
||||
continue
|
||||
else:
|
||||
try:
|
||||
payload = msg.get_payload(decode=True)
|
||||
if payload:
|
||||
body = payload.decode("utf-8", errors="ignore")
|
||||
except Exception as e:
|
||||
print(f"Error decoding single part payload: {e}")
|
||||
body = ""
|
||||
|
||||
# Only create document if we have some content
|
||||
if body.strip() or subject != "No Subject":
|
||||
# Create document content with metadata embedded in text
|
||||
doc_content = f"""
|
||||
[File]: {filename}
|
||||
[From]: {from_addr}
|
||||
[To]: {to_addr}
|
||||
[Subject]: {subject}
|
||||
[Date]: {date}
|
||||
[EMAIL BODY Start]:
|
||||
{body}
|
||||
"""
|
||||
|
||||
# No separate metadata - everything is in the text
|
||||
doc = Document(text=doc_content, metadata={})
|
||||
docs.append(doc)
|
||||
count += 1
|
||||
successful_files += 1
|
||||
|
||||
# Print first few successful files for debugging
|
||||
if successful_files <= 3:
|
||||
print(
|
||||
f"Successfully loaded: {filename} - Subject: {subject[:50]}..."
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
failed_files += 1
|
||||
if failed_files <= 5: # Only print first few errors
|
||||
print(f"Error parsing email from {filepath}: {e}")
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
failed_files += 1
|
||||
if failed_files <= 5: # Only print first few errors
|
||||
print(f"Error reading file {filepath}: {e}")
|
||||
continue
|
||||
|
||||
print("Processing summary:")
|
||||
print(f" Total .emlx files found: {total_files}")
|
||||
print(f" Successfully loaded: {successful_files}")
|
||||
print(f" Failed to load: {failed_files}")
|
||||
print(f" Final documents: {len(docs)}")
|
||||
|
||||
return docs
|
||||
@@ -7,9 +7,9 @@ Contains simple parser for mbox files.
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
from fsspec import AbstractFileSystem
|
||||
from typing import Any
|
||||
|
||||
from fsspec import AbstractFileSystem
|
||||
from llama_index.core.readers.base import BaseReader
|
||||
from llama_index.core.schema import Document
|
||||
|
||||
@@ -27,11 +27,7 @@ class MboxReader(BaseReader):
|
||||
"""
|
||||
|
||||
DEFAULT_MESSAGE_FORMAT: str = (
|
||||
"Date: {_date}\n"
|
||||
"From: {_from}\n"
|
||||
"To: {_to}\n"
|
||||
"Subject: {_subject}\n"
|
||||
"Content: {_content}"
|
||||
"Date: {_date}\nFrom: {_from}\nTo: {_to}\nSubject: {_subject}\nContent: {_content}"
|
||||
)
|
||||
|
||||
def __init__(
|
||||
@@ -45,9 +41,7 @@ class MboxReader(BaseReader):
|
||||
try:
|
||||
from bs4 import BeautifulSoup # noqa
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"`beautifulsoup4` package not found: `pip install beautifulsoup4`"
|
||||
)
|
||||
raise ImportError("`beautifulsoup4` package not found: `pip install beautifulsoup4`")
|
||||
|
||||
super().__init__(*args, **kwargs)
|
||||
self.max_count = max_count
|
||||
@@ -56,9 +50,9 @@ class MboxReader(BaseReader):
|
||||
def load_data(
|
||||
self,
|
||||
file: Path,
|
||||
extra_info: Optional[Dict] = None,
|
||||
fs: Optional[AbstractFileSystem] = None,
|
||||
) -> List[Document]:
|
||||
extra_info: dict | None = None,
|
||||
fs: AbstractFileSystem | None = None,
|
||||
) -> list[Document]:
|
||||
"""Parse file into string."""
|
||||
# Import required libraries
|
||||
import mailbox
|
||||
@@ -74,7 +68,7 @@ class MboxReader(BaseReader):
|
||||
)
|
||||
|
||||
i = 0
|
||||
results: List[str] = []
|
||||
results: list[str] = []
|
||||
# Load file using mailbox
|
||||
bytes_parser = BytesParser(policy=default).parse
|
||||
mbox = mailbox.mbox(file, factory=bytes_parser) # type: ignore
|
||||
@@ -124,7 +118,7 @@ class MboxReader(BaseReader):
|
||||
class EmlxMboxReader(MboxReader):
|
||||
"""
|
||||
EmlxMboxReader - Modified MboxReader that handles directories of .emlx files.
|
||||
|
||||
|
||||
Extends MboxReader to work with Apple Mail's .emlx format by:
|
||||
1. Reading .emlx files from a directory
|
||||
2. Converting them to mbox format in memory
|
||||
@@ -134,13 +128,13 @@ class EmlxMboxReader(MboxReader):
|
||||
def load_data(
|
||||
self,
|
||||
directory: Path,
|
||||
extra_info: Optional[Dict] = None,
|
||||
fs: Optional[AbstractFileSystem] = None,
|
||||
) -> List[Document]:
|
||||
extra_info: dict | None = None,
|
||||
fs: AbstractFileSystem | None = None,
|
||||
) -> list[Document]:
|
||||
"""Parse .emlx files from directory into strings using MboxReader logic."""
|
||||
import tempfile
|
||||
import os
|
||||
|
||||
import tempfile
|
||||
|
||||
if fs:
|
||||
logger.warning(
|
||||
"fs was specified but EmlxMboxReader doesn't support loading "
|
||||
@@ -150,37 +144,37 @@ class EmlxMboxReader(MboxReader):
|
||||
# Find all .emlx files in the directory
|
||||
emlx_files = list(directory.glob("*.emlx"))
|
||||
logger.info(f"Found {len(emlx_files)} .emlx files in {directory}")
|
||||
|
||||
|
||||
if not emlx_files:
|
||||
logger.warning(f"No .emlx files found in {directory}")
|
||||
return []
|
||||
|
||||
# Create a temporary mbox file
|
||||
with tempfile.NamedTemporaryFile(mode='w', suffix='.mbox', delete=False) as temp_mbox:
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".mbox", delete=False) as temp_mbox:
|
||||
temp_mbox_path = temp_mbox.name
|
||||
|
||||
|
||||
# Convert .emlx files to mbox format
|
||||
for emlx_file in emlx_files:
|
||||
try:
|
||||
# Read the .emlx file
|
||||
with open(emlx_file, 'r', encoding='utf-8', errors='ignore') as f:
|
||||
with open(emlx_file, encoding="utf-8", errors="ignore") as f:
|
||||
content = f.read()
|
||||
|
||||
|
||||
# .emlx format: first line is length, rest is email content
|
||||
lines = content.split('\n', 1)
|
||||
lines = content.split("\n", 1)
|
||||
if len(lines) >= 2:
|
||||
email_content = lines[1] # Skip the length line
|
||||
|
||||
|
||||
# Write to mbox format (each message starts with "From " and ends with blank line)
|
||||
temp_mbox.write(f"From {emlx_file.name} {email_content}\n\n")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to process {emlx_file}: {e}")
|
||||
continue
|
||||
|
||||
|
||||
# Close the temporary file so MboxReader can read it
|
||||
temp_mbox.close()
|
||||
|
||||
|
||||
try:
|
||||
# Use the parent MboxReader's logic to parse the mbox file
|
||||
return super().load_data(Path(temp_mbox_path), extra_info, fs)
|
||||
@@ -188,5 +182,5 @@ class EmlxMboxReader(MboxReader):
|
||||
# Clean up temporary file
|
||||
try:
|
||||
os.unlink(temp_mbox_path)
|
||||
except:
|
||||
pass
|
||||
except OSError:
|
||||
pass
|
||||
156
apps/email_rag.py
Normal file
156
apps/email_rag.py
Normal file
@@ -0,0 +1,156 @@
|
||||
"""
|
||||
Email RAG example using the unified interface.
|
||||
Supports Apple Mail on macOS.
|
||||
"""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Add parent directory to path for imports
|
||||
sys.path.insert(0, str(Path(__file__).parent))
|
||||
|
||||
from base_rag_example import BaseRAGExample, create_text_chunks
|
||||
|
||||
from .email_data.LEANN_email_reader import EmlxReader
|
||||
|
||||
|
||||
class EmailRAG(BaseRAGExample):
|
||||
"""RAG example for Apple Mail processing."""
|
||||
|
||||
def __init__(self):
|
||||
# Set default values BEFORE calling super().__init__
|
||||
self.max_items_default = -1 # Process all emails by default
|
||||
self.embedding_model_default = (
|
||||
"sentence-transformers/all-MiniLM-L6-v2" # Fast 384-dim model
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
name="Email",
|
||||
description="Process and query Apple Mail emails with LEANN",
|
||||
default_index_name="mail_index",
|
||||
)
|
||||
|
||||
def _add_specific_arguments(self, parser):
|
||||
"""Add email-specific arguments."""
|
||||
email_group = parser.add_argument_group("Email Parameters")
|
||||
email_group.add_argument(
|
||||
"--mail-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to Apple Mail directory (auto-detected if not specified)",
|
||||
)
|
||||
email_group.add_argument(
|
||||
"--include-html", action="store_true", help="Include HTML content in email processing"
|
||||
)
|
||||
email_group.add_argument(
|
||||
"--chunk-size", type=int, default=256, help="Text chunk size (default: 256)"
|
||||
)
|
||||
email_group.add_argument(
|
||||
"--chunk-overlap", type=int, default=25, help="Text chunk overlap (default: 25)"
|
||||
)
|
||||
|
||||
def _find_mail_directories(self) -> list[Path]:
|
||||
"""Auto-detect all Apple Mail directories."""
|
||||
mail_base = Path.home() / "Library" / "Mail"
|
||||
if not mail_base.exists():
|
||||
return []
|
||||
|
||||
# Find all Messages directories
|
||||
messages_dirs = []
|
||||
for item in mail_base.rglob("Messages"):
|
||||
if item.is_dir():
|
||||
messages_dirs.append(item)
|
||||
|
||||
return messages_dirs
|
||||
|
||||
async def load_data(self, args) -> list[str]:
|
||||
"""Load emails and convert to text chunks."""
|
||||
# Determine mail directories
|
||||
if args.mail_path:
|
||||
messages_dirs = [Path(args.mail_path)]
|
||||
else:
|
||||
print("Auto-detecting Apple Mail directories...")
|
||||
messages_dirs = self._find_mail_directories()
|
||||
|
||||
if not messages_dirs:
|
||||
print("No Apple Mail directories found!")
|
||||
print("Please specify --mail-path manually")
|
||||
return []
|
||||
|
||||
print(f"Found {len(messages_dirs)} mail directories")
|
||||
|
||||
# Create reader
|
||||
reader = EmlxReader(include_html=args.include_html)
|
||||
|
||||
# Process each directory
|
||||
all_documents = []
|
||||
total_processed = 0
|
||||
|
||||
for i, messages_dir in enumerate(messages_dirs):
|
||||
print(f"\nProcessing directory {i + 1}/{len(messages_dirs)}: {messages_dir}")
|
||||
|
||||
try:
|
||||
# Count emlx files
|
||||
emlx_files = list(messages_dir.glob("*.emlx"))
|
||||
print(f"Found {len(emlx_files)} email files")
|
||||
|
||||
# Apply max_items limit per directory
|
||||
max_per_dir = -1 # Default to process all
|
||||
if args.max_items > 0:
|
||||
remaining = args.max_items - total_processed
|
||||
if remaining <= 0:
|
||||
break
|
||||
max_per_dir = remaining
|
||||
# If args.max_items == -1, max_per_dir stays -1 (process all)
|
||||
|
||||
# Load emails - fix the parameter passing
|
||||
documents = reader.load_data(
|
||||
input_dir=str(messages_dir),
|
||||
max_count=max_per_dir,
|
||||
)
|
||||
|
||||
if documents:
|
||||
all_documents.extend(documents)
|
||||
total_processed += len(documents)
|
||||
print(f"Processed {len(documents)} emails from this directory")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error processing {messages_dir}: {e}")
|
||||
continue
|
||||
|
||||
if not all_documents:
|
||||
print("No emails found to process!")
|
||||
return []
|
||||
|
||||
print(f"\nTotal emails processed: {len(all_documents)}")
|
||||
print("now starting to split into text chunks ... take some time")
|
||||
|
||||
# Convert to text chunks
|
||||
# Email reader uses chunk_overlap=25 as in original
|
||||
all_texts = create_text_chunks(
|
||||
all_documents, chunk_size=args.chunk_size, chunk_overlap=args.chunk_overlap
|
||||
)
|
||||
|
||||
return all_texts
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
||||
# Check platform
|
||||
if sys.platform != "darwin":
|
||||
print("\n⚠️ Warning: This example is designed for macOS (Apple Mail)")
|
||||
print(" Windows/Linux support coming soon!\n")
|
||||
|
||||
# Example queries for email RAG
|
||||
print("\n📧 Email RAG Example")
|
||||
print("=" * 50)
|
||||
print("\nExample queries you can try:")
|
||||
print("- 'What did my boss say about deadlines?'")
|
||||
print("- 'Find emails about travel expenses'")
|
||||
print("- 'Show me emails from last month about the project'")
|
||||
print("- 'What food did I order from DoorDash?'")
|
||||
print("\nNote: You may need to grant Full Disk Access to your terminal\n")
|
||||
|
||||
rag = EmailRAG()
|
||||
asyncio.run(rag.run())
|
||||
@@ -1,382 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
This script runs a recall evaluation on a given LEANN index.
|
||||
It correctly compares results by fetching the text content for both the new search
|
||||
results and the golden standard results, making the comparison robust to ID changes.
|
||||
"""
|
||||
|
||||
import json
|
||||
import argparse
|
||||
import time
|
||||
from pathlib import Path
|
||||
import sys
|
||||
import numpy as np
|
||||
from typing import List
|
||||
|
||||
from leann.api import LeannSearcher, LeannBuilder
|
||||
|
||||
|
||||
def download_data_if_needed(data_root: Path, download_embeddings: bool = False):
|
||||
"""Checks if the data directory exists, and if not, downloads it from HF Hub."""
|
||||
if not data_root.exists():
|
||||
print(f"Data directory '{data_root}' not found.")
|
||||
print(
|
||||
"Downloading evaluation data from Hugging Face Hub... (this may take a moment)"
|
||||
)
|
||||
try:
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
if download_embeddings:
|
||||
# Download everything including embeddings (large files)
|
||||
snapshot_download(
|
||||
repo_id="LEANN-RAG/leann-rag-evaluation-data",
|
||||
repo_type="dataset",
|
||||
local_dir=data_root,
|
||||
local_dir_use_symlinks=False,
|
||||
)
|
||||
print("Data download complete (including embeddings)!")
|
||||
else:
|
||||
# Download only specific folders, excluding embeddings
|
||||
allow_patterns = [
|
||||
"ground_truth/**",
|
||||
"indices/**",
|
||||
"queries/**",
|
||||
"*.md",
|
||||
"*.txt",
|
||||
]
|
||||
snapshot_download(
|
||||
repo_id="LEANN-RAG/leann-rag-evaluation-data",
|
||||
repo_type="dataset",
|
||||
local_dir=data_root,
|
||||
local_dir_use_symlinks=False,
|
||||
allow_patterns=allow_patterns,
|
||||
)
|
||||
print("Data download complete (excluding embeddings)!")
|
||||
except ImportError:
|
||||
print(
|
||||
"Error: huggingface_hub is not installed. Please install it to download the data:"
|
||||
)
|
||||
print("uv pip install -e '.[dev]'")
|
||||
sys.exit(1)
|
||||
except Exception as e:
|
||||
print(f"An error occurred during data download: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def download_embeddings_if_needed(data_root: Path, dataset_type: str = None):
|
||||
"""Download embeddings files specifically."""
|
||||
embeddings_dir = data_root / "embeddings"
|
||||
|
||||
if dataset_type:
|
||||
# Check if specific dataset embeddings exist
|
||||
target_file = embeddings_dir / dataset_type / "passages_00.pkl"
|
||||
if target_file.exists():
|
||||
print(f"Embeddings for {dataset_type} already exist")
|
||||
return str(target_file)
|
||||
|
||||
print("Downloading embeddings from HuggingFace Hub...")
|
||||
try:
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
# Download only embeddings folder
|
||||
snapshot_download(
|
||||
repo_id="LEANN-RAG/leann-rag-evaluation-data",
|
||||
repo_type="dataset",
|
||||
local_dir=data_root,
|
||||
local_dir_use_symlinks=False,
|
||||
allow_patterns=["embeddings/**/*.pkl"],
|
||||
)
|
||||
print("Embeddings download complete!")
|
||||
|
||||
if dataset_type:
|
||||
target_file = embeddings_dir / dataset_type / "passages_00.pkl"
|
||||
if target_file.exists():
|
||||
return str(target_file)
|
||||
|
||||
return str(embeddings_dir)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error downloading embeddings: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
# --- Helper Function to get Golden Passages ---
|
||||
def get_golden_texts(searcher: LeannSearcher, golden_ids: List[int]) -> set:
|
||||
"""
|
||||
Retrieves the text for golden passage IDs directly from the LeannSearcher's
|
||||
passage manager.
|
||||
"""
|
||||
golden_texts = set()
|
||||
for gid in golden_ids:
|
||||
try:
|
||||
# PassageManager uses string IDs
|
||||
passage_data = searcher.passage_manager.get_passage(str(gid))
|
||||
golden_texts.add(passage_data["text"])
|
||||
except KeyError:
|
||||
print(
|
||||
f"Warning: Golden passage ID '{gid}' not found in the index's passage data."
|
||||
)
|
||||
return golden_texts
|
||||
|
||||
|
||||
def load_queries(file_path: Path) -> List[str]:
|
||||
queries = []
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
data = json.loads(line)
|
||||
queries.append(data["query"])
|
||||
return queries
|
||||
|
||||
|
||||
def build_index_from_embeddings(
|
||||
embeddings_file: str, output_path: str, backend: str = "hnsw"
|
||||
):
|
||||
"""
|
||||
Build a LEANN index from pre-computed embeddings.
|
||||
|
||||
Args:
|
||||
embeddings_file: Path to pickle file with (ids, embeddings) tuple
|
||||
output_path: Path where to save the index
|
||||
backend: Backend to use ("hnsw" or "diskann")
|
||||
"""
|
||||
print(f"Building {backend} index from embeddings: {embeddings_file}")
|
||||
|
||||
# Create builder with appropriate parameters
|
||||
if backend == "hnsw":
|
||||
builder_kwargs = {
|
||||
"M": 32, # Graph degree
|
||||
"efConstruction": 256, # Construction complexity
|
||||
"is_compact": True, # Use compact storage
|
||||
"is_recompute": True, # Enable pruning for better recall
|
||||
}
|
||||
elif backend == "diskann":
|
||||
builder_kwargs = {
|
||||
"complexity": 64,
|
||||
"graph_degree": 32,
|
||||
"search_memory_maximum": 8.0, # GB
|
||||
"build_memory_maximum": 16.0, # GB
|
||||
}
|
||||
else:
|
||||
builder_kwargs = {}
|
||||
|
||||
builder = LeannBuilder(
|
||||
backend_name=backend,
|
||||
embedding_model="facebook/contriever-msmarco", # Model used to create embeddings
|
||||
dimensions=768, # Will be auto-detected from embeddings
|
||||
**builder_kwargs,
|
||||
)
|
||||
|
||||
# Build index from precomputed embeddings
|
||||
builder.build_index_from_embeddings(output_path, embeddings_file)
|
||||
print(f"Index saved to: {output_path}")
|
||||
return output_path
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Run recall evaluation on a LEANN index."
|
||||
)
|
||||
parser.add_argument(
|
||||
"index_path",
|
||||
type=str,
|
||||
nargs="?",
|
||||
help="Path to the LEANN index to evaluate or build (optional).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mode",
|
||||
choices=["evaluate", "build"],
|
||||
default="evaluate",
|
||||
help="Mode: 'evaluate' existing index or 'build' from embeddings",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--embeddings-file",
|
||||
type=str,
|
||||
help="Path to embeddings pickle file (optional for build mode)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--backend",
|
||||
choices=["hnsw", "diskann"],
|
||||
default="hnsw",
|
||||
help="Backend to use for building index (default: hnsw)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-queries", type=int, default=10, help="Number of queries to evaluate."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--top-k", type=int, default=3, help="The 'k' value for recall@k."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ef-search", type=int, default=120, help="The 'efSearch' parameter for HNSW."
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# --- Path Configuration ---
|
||||
# Assumes a project structure where the script is in 'examples/'
|
||||
# and data is in 'data/' at the project root.
|
||||
project_root = Path(__file__).resolve().parent.parent
|
||||
data_root = project_root / "data"
|
||||
|
||||
# Download data based on mode
|
||||
if args.mode == "build":
|
||||
# For building mode, we need embeddings
|
||||
download_data_if_needed(
|
||||
data_root, download_embeddings=False
|
||||
) # Basic data first
|
||||
|
||||
# Auto-detect dataset type and download embeddings
|
||||
if args.embeddings_file:
|
||||
embeddings_file = args.embeddings_file
|
||||
# Try to detect dataset type from embeddings file path
|
||||
if "rpj_wiki" in str(embeddings_file):
|
||||
dataset_type = "rpj_wiki"
|
||||
elif "dpr" in str(embeddings_file):
|
||||
dataset_type = "dpr"
|
||||
else:
|
||||
dataset_type = "dpr" # Default
|
||||
else:
|
||||
# Auto-detect from index path if provided, otherwise default to DPR
|
||||
if args.index_path:
|
||||
index_path_str = str(args.index_path)
|
||||
if "rpj_wiki" in index_path_str:
|
||||
dataset_type = "rpj_wiki"
|
||||
elif "dpr" in index_path_str:
|
||||
dataset_type = "dpr"
|
||||
else:
|
||||
dataset_type = "dpr" # Default to DPR
|
||||
else:
|
||||
dataset_type = "dpr" # Default to DPR
|
||||
|
||||
embeddings_file = download_embeddings_if_needed(data_root, dataset_type)
|
||||
|
||||
# Auto-generate index path if not provided
|
||||
if not args.index_path:
|
||||
indices_dir = data_root / "indices" / dataset_type
|
||||
indices_dir.mkdir(parents=True, exist_ok=True)
|
||||
args.index_path = str(indices_dir / f"{dataset_type}_from_embeddings")
|
||||
print(f"Auto-generated index path: {args.index_path}")
|
||||
|
||||
print(f"Building index from embeddings: {embeddings_file}")
|
||||
built_index_path = build_index_from_embeddings(
|
||||
embeddings_file, args.index_path, args.backend
|
||||
)
|
||||
print(f"Index built successfully: {built_index_path}")
|
||||
|
||||
# Ask if user wants to run evaluation
|
||||
eval_response = (
|
||||
input("Run evaluation on the built index? (y/n): ").strip().lower()
|
||||
)
|
||||
if eval_response != "y":
|
||||
print("Index building complete. Exiting.")
|
||||
return
|
||||
else:
|
||||
# For evaluation mode, don't need embeddings
|
||||
download_data_if_needed(data_root, download_embeddings=False)
|
||||
|
||||
# Auto-detect index path if not provided
|
||||
if not args.index_path:
|
||||
# Default to using downloaded indices
|
||||
indices_dir = data_root / "indices"
|
||||
|
||||
# Try common datasets in order of preference
|
||||
for dataset in ["dpr", "rpj_wiki"]:
|
||||
dataset_dir = indices_dir / dataset
|
||||
if dataset_dir.exists():
|
||||
# Look for index files
|
||||
index_files = list(dataset_dir.glob("*.index")) + list(
|
||||
dataset_dir.glob("*_disk.index")
|
||||
)
|
||||
if index_files:
|
||||
args.index_path = str(
|
||||
index_files[0].with_suffix("")
|
||||
) # Remove .index extension
|
||||
print(f"Using index: {args.index_path}")
|
||||
break
|
||||
|
||||
if not args.index_path:
|
||||
print(
|
||||
"No indices found. The data download should have included pre-built indices."
|
||||
)
|
||||
print(
|
||||
"Please check the data/indices/ directory or provide --index-path manually."
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
# Detect dataset type from index path to select the correct ground truth
|
||||
index_path_str = str(args.index_path)
|
||||
if "rpj_wiki" in index_path_str:
|
||||
dataset_type = "rpj_wiki"
|
||||
elif "dpr" in index_path_str:
|
||||
dataset_type = "dpr"
|
||||
else:
|
||||
# Fallback: try to infer from the index directory name
|
||||
dataset_type = Path(args.index_path).name
|
||||
print(
|
||||
f"WARNING: Could not detect dataset type from path, inferred '{dataset_type}'."
|
||||
)
|
||||
|
||||
queries_file = data_root / "queries" / "nq_open.jsonl"
|
||||
golden_results_file = (
|
||||
data_root / "ground_truth" / dataset_type / "flat_results_nq_k3.json"
|
||||
)
|
||||
|
||||
print(f"INFO: Detected dataset type: {dataset_type}")
|
||||
print(f"INFO: Using queries file: {queries_file}")
|
||||
print(f"INFO: Using ground truth file: {golden_results_file}")
|
||||
|
||||
try:
|
||||
searcher = LeannSearcher(args.index_path)
|
||||
queries = load_queries(queries_file)
|
||||
|
||||
with open(golden_results_file, "r") as f:
|
||||
golden_results_data = json.load(f)
|
||||
|
||||
num_eval_queries = min(args.num_queries, len(queries))
|
||||
queries = queries[:num_eval_queries]
|
||||
|
||||
print(f"\nRunning evaluation on {num_eval_queries} queries...")
|
||||
recall_scores = []
|
||||
search_times = []
|
||||
|
||||
for i in range(num_eval_queries):
|
||||
start_time = time.time()
|
||||
new_results = searcher.search(
|
||||
queries[i], top_k=args.top_k, ef=args.ef_search
|
||||
)
|
||||
search_times.append(time.time() - start_time)
|
||||
|
||||
# Correct Recall Calculation: Based on TEXT content
|
||||
new_texts = {result.text for result in new_results}
|
||||
|
||||
# Get golden texts directly from the searcher's passage manager
|
||||
golden_ids = golden_results_data["indices"][i][: args.top_k]
|
||||
golden_texts = get_golden_texts(searcher, golden_ids)
|
||||
|
||||
overlap = len(new_texts & golden_texts)
|
||||
recall = overlap / len(golden_texts) if golden_texts else 0
|
||||
recall_scores.append(recall)
|
||||
|
||||
print("\n--- EVALUATION RESULTS ---")
|
||||
print(f"Query: {queries[i]}")
|
||||
print(f"New Results: {new_texts}")
|
||||
print(f"Golden Results: {golden_texts}")
|
||||
print(f"Overlap: {overlap}")
|
||||
print(f"Recall: {recall}")
|
||||
print(f"Search Time: {search_times[-1]:.4f}s")
|
||||
print("--------------------------------")
|
||||
|
||||
avg_recall = np.mean(recall_scores) if recall_scores else 0
|
||||
avg_time = np.mean(search_times) if search_times else 0
|
||||
|
||||
print("\n🎉 --- Evaluation Complete ---")
|
||||
print(f"Avg. Recall@{args.top_k} (efSearch={args.ef_search}): {avg_recall:.4f}")
|
||||
print(f"Avg. Search Time: {avg_time:.4f}s")
|
||||
|
||||
except Exception as e:
|
||||
print(f"\n❌ An error occurred during evaluation: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,3 +1,3 @@
|
||||
from .history import ChromeHistoryReader
|
||||
|
||||
__all__ = ['ChromeHistoryReader']
|
||||
__all__ = ["ChromeHistoryReader"]
|
||||
@@ -1,77 +1,81 @@
|
||||
import sqlite3
|
||||
import os
|
||||
import sqlite3
|
||||
from pathlib import Path
|
||||
from typing import List, Any
|
||||
from typing import Any
|
||||
|
||||
from llama_index.core import Document
|
||||
from llama_index.core.readers.base import BaseReader
|
||||
|
||||
|
||||
class ChromeHistoryReader(BaseReader):
|
||||
"""
|
||||
Chrome browser history reader that extracts browsing data from SQLite database.
|
||||
|
||||
|
||||
Reads Chrome history from the default Chrome profile location and creates documents
|
||||
with embedded metadata similar to the email reader structure.
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize."""
|
||||
pass
|
||||
|
||||
def load_data(self, input_dir: str = None, **load_kwargs: Any) -> List[Document]:
|
||||
|
||||
def load_data(self, input_dir: str | None = None, **load_kwargs: Any) -> list[Document]:
|
||||
"""
|
||||
Load Chrome history data from the default Chrome profile location.
|
||||
|
||||
|
||||
Args:
|
||||
input_dir: Not used for Chrome history (kept for compatibility)
|
||||
**load_kwargs:
|
||||
max_count (int): Maximum amount of history entries to read.
|
||||
chrome_profile_path (str): Custom path to Chrome profile directory.
|
||||
"""
|
||||
docs: List[Document] = []
|
||||
max_count = load_kwargs.get('max_count', 1000)
|
||||
chrome_profile_path = load_kwargs.get('chrome_profile_path', None)
|
||||
|
||||
docs: list[Document] = []
|
||||
max_count = load_kwargs.get("max_count", 1000)
|
||||
chrome_profile_path = load_kwargs.get("chrome_profile_path", None)
|
||||
|
||||
# Default Chrome profile path on macOS
|
||||
if chrome_profile_path is None:
|
||||
chrome_profile_path = os.path.expanduser("~/Library/Application Support/Google/Chrome/Default")
|
||||
|
||||
chrome_profile_path = os.path.expanduser(
|
||||
"~/Library/Application Support/Google/Chrome/Default"
|
||||
)
|
||||
|
||||
history_db_path = os.path.join(chrome_profile_path, "History")
|
||||
|
||||
|
||||
if not os.path.exists(history_db_path):
|
||||
print(f"Chrome history database not found at: {history_db_path}")
|
||||
return docs
|
||||
|
||||
|
||||
try:
|
||||
# Connect to the Chrome history database
|
||||
print(f"Connecting to database: {history_db_path}")
|
||||
conn = sqlite3.connect(history_db_path)
|
||||
cursor = conn.cursor()
|
||||
|
||||
|
||||
# Query to get browsing history with metadata (removed created_time column)
|
||||
query = """
|
||||
SELECT
|
||||
SELECT
|
||||
datetime(last_visit_time/1000000-11644473600,'unixepoch','localtime') as last_visit,
|
||||
url,
|
||||
title,
|
||||
visit_count,
|
||||
typed_count,
|
||||
url,
|
||||
title,
|
||||
visit_count,
|
||||
typed_count,
|
||||
hidden
|
||||
FROM urls
|
||||
FROM urls
|
||||
ORDER BY last_visit_time DESC
|
||||
"""
|
||||
|
||||
|
||||
print(f"Executing query on database: {history_db_path}")
|
||||
cursor.execute(query)
|
||||
rows = cursor.fetchall()
|
||||
print(f"Query returned {len(rows)} rows")
|
||||
|
||||
|
||||
count = 0
|
||||
for row in rows:
|
||||
if count >= max_count and max_count > 0:
|
||||
break
|
||||
|
||||
|
||||
last_visit, url, title, visit_count, typed_count, hidden = row
|
||||
|
||||
|
||||
# Create document content with metadata embedded in text
|
||||
doc_content = f"""
|
||||
[Title]: {title}
|
||||
@@ -80,38 +84,43 @@ class ChromeHistoryReader(BaseReader):
|
||||
[Visit times]: {visit_count}
|
||||
[Typed times]: {typed_count}
|
||||
"""
|
||||
|
||||
|
||||
# Create document with embedded metadata
|
||||
doc = Document(text=doc_content, metadata={ "title": title[0:150]})
|
||||
doc = Document(text=doc_content, metadata={"title": title[0:150]})
|
||||
# if len(title) > 150:
|
||||
# print(f"Title is too long: {title}")
|
||||
docs.append(doc)
|
||||
count += 1
|
||||
|
||||
|
||||
conn.close()
|
||||
print(f"Loaded {len(docs)} Chrome history documents")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error reading Chrome history: {e}")
|
||||
# add you may need to close your browser to make the database file available
|
||||
# also highlight in red
|
||||
print(
|
||||
"\033[91mYou may need to close your browser to make the database file available\033[0m"
|
||||
)
|
||||
return docs
|
||||
|
||||
|
||||
return docs
|
||||
|
||||
@staticmethod
|
||||
def find_chrome_profiles() -> List[Path]:
|
||||
def find_chrome_profiles() -> list[Path]:
|
||||
"""
|
||||
Find all Chrome profile directories.
|
||||
|
||||
|
||||
Returns:
|
||||
List of Path objects pointing to Chrome profile directories
|
||||
"""
|
||||
chrome_base_path = Path(os.path.expanduser("~/Library/Application Support/Google/Chrome"))
|
||||
profile_dirs = []
|
||||
|
||||
|
||||
if not chrome_base_path.exists():
|
||||
print(f"Chrome directory not found at: {chrome_base_path}")
|
||||
return profile_dirs
|
||||
|
||||
|
||||
# Find all profile directories
|
||||
for profile_dir in chrome_base_path.iterdir():
|
||||
if profile_dir.is_dir() and profile_dir.name != "System Profile":
|
||||
@@ -119,53 +128,59 @@ class ChromeHistoryReader(BaseReader):
|
||||
if history_path.exists():
|
||||
profile_dirs.append(profile_dir)
|
||||
print(f"Found Chrome profile: {profile_dir}")
|
||||
|
||||
|
||||
print(f"Found {len(profile_dirs)} Chrome profiles")
|
||||
return profile_dirs
|
||||
|
||||
@staticmethod
|
||||
def export_history_to_file(output_file: str = "chrome_history_export.txt", max_count: int = 1000):
|
||||
def export_history_to_file(
|
||||
output_file: str = "chrome_history_export.txt", max_count: int = 1000
|
||||
):
|
||||
"""
|
||||
Export Chrome history to a text file using the same SQL query format.
|
||||
|
||||
|
||||
Args:
|
||||
output_file: Path to the output file
|
||||
max_count: Maximum number of entries to export
|
||||
"""
|
||||
chrome_profile_path = os.path.expanduser("~/Library/Application Support/Google/Chrome/Default")
|
||||
chrome_profile_path = os.path.expanduser(
|
||||
"~/Library/Application Support/Google/Chrome/Default"
|
||||
)
|
||||
history_db_path = os.path.join(chrome_profile_path, "History")
|
||||
|
||||
|
||||
if not os.path.exists(history_db_path):
|
||||
print(f"Chrome history database not found at: {history_db_path}")
|
||||
return
|
||||
|
||||
|
||||
try:
|
||||
conn = sqlite3.connect(history_db_path)
|
||||
cursor = conn.cursor()
|
||||
|
||||
|
||||
query = """
|
||||
SELECT
|
||||
SELECT
|
||||
datetime(last_visit_time/1000000-11644473600,'unixepoch','localtime') as last_visit,
|
||||
url,
|
||||
title,
|
||||
visit_count,
|
||||
typed_count,
|
||||
url,
|
||||
title,
|
||||
visit_count,
|
||||
typed_count,
|
||||
hidden
|
||||
FROM urls
|
||||
FROM urls
|
||||
ORDER BY last_visit_time DESC
|
||||
LIMIT ?
|
||||
"""
|
||||
|
||||
|
||||
cursor.execute(query, (max_count,))
|
||||
rows = cursor.fetchall()
|
||||
|
||||
with open(output_file, 'w', encoding='utf-8') as f:
|
||||
|
||||
with open(output_file, "w", encoding="utf-8") as f:
|
||||
for row in rows:
|
||||
last_visit, url, title, visit_count, typed_count, hidden = row
|
||||
f.write(f"{last_visit}\t{url}\t{title}\t{visit_count}\t{typed_count}\t{hidden}\n")
|
||||
|
||||
f.write(
|
||||
f"{last_visit}\t{url}\t{title}\t{visit_count}\t{typed_count}\t{hidden}\n"
|
||||
)
|
||||
|
||||
conn.close()
|
||||
print(f"Exported {len(rows)} history entries to {output_file}")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error exporting Chrome history: {e}")
|
||||
print(f"Error exporting Chrome history: {e}")
|
||||
@@ -2,30 +2,31 @@ import json
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import List, Any, Dict, Optional
|
||||
from typing import Any
|
||||
|
||||
from llama_index.core import Document
|
||||
from llama_index.core.readers.base import BaseReader
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
class WeChatHistoryReader(BaseReader):
|
||||
"""
|
||||
WeChat chat history reader that extracts chat data from exported JSON files.
|
||||
|
||||
|
||||
Reads WeChat chat history from exported JSON files (from wechat-exporter tool)
|
||||
and creates documents with embedded metadata similar to the Chrome history reader structure.
|
||||
|
||||
|
||||
Also includes utilities for automatic WeChat chat history export.
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize."""
|
||||
self.packages_dir = Path(__file__).parent.parent.parent / "packages"
|
||||
self.wechat_exporter_dir = self.packages_dir / "wechat-exporter"
|
||||
self.wechat_decipher_dir = self.packages_dir / "wechat-decipher-macos"
|
||||
|
||||
|
||||
def check_wechat_running(self) -> bool:
|
||||
"""Check if WeChat is currently running."""
|
||||
try:
|
||||
@@ -33,24 +34,30 @@ class WeChatHistoryReader(BaseReader):
|
||||
return result.returncode == 0
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def install_wechattweak(self) -> bool:
|
||||
"""Install WeChatTweak CLI tool."""
|
||||
try:
|
||||
# Create wechat-exporter directory if it doesn't exist
|
||||
self.wechat_exporter_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
wechattweak_path = self.wechat_exporter_dir / "wechattweak-cli"
|
||||
if not wechattweak_path.exists():
|
||||
print("Downloading WeChatTweak CLI...")
|
||||
subprocess.run([
|
||||
"curl", "-L", "-o", str(wechattweak_path),
|
||||
"https://github.com/JettChenT/WeChatTweak-CLI/releases/latest/download/wechattweak-cli"
|
||||
], check=True)
|
||||
|
||||
subprocess.run(
|
||||
[
|
||||
"curl",
|
||||
"-L",
|
||||
"-o",
|
||||
str(wechattweak_path),
|
||||
"https://github.com/JettChenT/WeChatTweak-CLI/releases/latest/download/wechattweak-cli",
|
||||
],
|
||||
check=True,
|
||||
)
|
||||
|
||||
# Make executable
|
||||
wechattweak_path.chmod(0o755)
|
||||
|
||||
|
||||
# Install WeChatTweak
|
||||
print("Installing WeChatTweak...")
|
||||
subprocess.run(["sudo", str(wechattweak_path), "install"], check=True)
|
||||
@@ -58,7 +65,7 @@ class WeChatHistoryReader(BaseReader):
|
||||
except Exception as e:
|
||||
print(f"Error installing WeChatTweak: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def restart_wechat(self):
|
||||
"""Restart WeChat to apply WeChatTweak."""
|
||||
try:
|
||||
@@ -69,302 +76,325 @@ class WeChatHistoryReader(BaseReader):
|
||||
time.sleep(5) # Wait for WeChat to start
|
||||
except Exception as e:
|
||||
print(f"Error restarting WeChat: {e}")
|
||||
|
||||
|
||||
def check_api_available(self) -> bool:
|
||||
"""Check if WeChatTweak API is available."""
|
||||
try:
|
||||
result = subprocess.run([
|
||||
"curl", "-s", "http://localhost:48065/wechat/allcontacts"
|
||||
], capture_output=True, text=True, timeout=5)
|
||||
result = subprocess.run(
|
||||
["curl", "-s", "http://localhost:48065/wechat/allcontacts"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=5,
|
||||
)
|
||||
return result.returncode == 0 and result.stdout.strip()
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
|
||||
|
||||
def _extract_readable_text(self, content: str) -> str:
|
||||
"""
|
||||
Extract readable text from message content, removing XML and system messages.
|
||||
|
||||
|
||||
Args:
|
||||
content: The raw message content (can be string or dict)
|
||||
|
||||
|
||||
Returns:
|
||||
Cleaned, readable text
|
||||
"""
|
||||
if not content:
|
||||
return ""
|
||||
|
||||
|
||||
# Handle dictionary content (like quoted messages)
|
||||
if isinstance(content, dict):
|
||||
# Extract text from dictionary structure
|
||||
text_parts = []
|
||||
if 'title' in content:
|
||||
text_parts.append(str(content['title']))
|
||||
if 'quoted' in content:
|
||||
text_parts.append(str(content['quoted']))
|
||||
if 'content' in content:
|
||||
text_parts.append(str(content['content']))
|
||||
if 'text' in content:
|
||||
text_parts.append(str(content['text']))
|
||||
|
||||
if "title" in content:
|
||||
text_parts.append(str(content["title"]))
|
||||
if "quoted" in content:
|
||||
text_parts.append(str(content["quoted"]))
|
||||
if "content" in content:
|
||||
text_parts.append(str(content["content"]))
|
||||
if "text" in content:
|
||||
text_parts.append(str(content["text"]))
|
||||
|
||||
if text_parts:
|
||||
return " | ".join(text_parts)
|
||||
else:
|
||||
# If we can't extract meaningful text from dict, return empty
|
||||
return ""
|
||||
|
||||
|
||||
# Handle string content
|
||||
if not isinstance(content, str):
|
||||
return ""
|
||||
|
||||
|
||||
# Remove common prefixes like "wxid_xxx:\n"
|
||||
clean_content = re.sub(r'^wxid_[^:]+:\s*', '', content)
|
||||
clean_content = re.sub(r'^[^:]+:\s*', '', clean_content)
|
||||
|
||||
clean_content = re.sub(r"^wxid_[^:]+:\s*", "", content)
|
||||
clean_content = re.sub(r"^[^:]+:\s*", "", clean_content)
|
||||
|
||||
# If it's just XML or system message, return empty
|
||||
if clean_content.strip().startswith('<') or 'recalled a message' in clean_content:
|
||||
if clean_content.strip().startswith("<") or "recalled a message" in clean_content:
|
||||
return ""
|
||||
|
||||
|
||||
return clean_content.strip()
|
||||
|
||||
|
||||
def _is_text_message(self, content: str) -> bool:
|
||||
"""
|
||||
Check if a message contains readable text content.
|
||||
|
||||
|
||||
Args:
|
||||
content: The message content (can be string or dict)
|
||||
|
||||
|
||||
Returns:
|
||||
True if the message contains readable text, False otherwise
|
||||
"""
|
||||
if not content:
|
||||
return False
|
||||
|
||||
|
||||
# Handle dictionary content
|
||||
if isinstance(content, dict):
|
||||
# Check if dict has any readable text fields
|
||||
text_fields = ['title', 'quoted', 'content', 'text']
|
||||
text_fields = ["title", "quoted", "content", "text"]
|
||||
for field in text_fields:
|
||||
if field in content and content[field]:
|
||||
if content.get(field):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
# Handle string content
|
||||
if not isinstance(content, str):
|
||||
return False
|
||||
|
||||
|
||||
# Skip image messages (contain XML with img tags)
|
||||
if '<img' in content and 'cdnurl' in content:
|
||||
if "<img" in content and "cdnurl" in content:
|
||||
return False
|
||||
|
||||
|
||||
# Skip emoji messages (contain emoji XML tags)
|
||||
if '<emoji' in content and 'productid' in content:
|
||||
if "<emoji" in content and "productid" in content:
|
||||
return False
|
||||
|
||||
|
||||
# Skip voice messages
|
||||
if '<voice' in content:
|
||||
if "<voice" in content:
|
||||
return False
|
||||
|
||||
|
||||
# Skip video messages
|
||||
if '<video' in content:
|
||||
if "<video" in content:
|
||||
return False
|
||||
|
||||
|
||||
# Skip file messages
|
||||
if '<appmsg' in content and 'appid' in content:
|
||||
if "<appmsg" in content and "appid" in content:
|
||||
return False
|
||||
|
||||
|
||||
# Skip system messages (like "recalled a message")
|
||||
if 'recalled a message' in content:
|
||||
if "recalled a message" in content:
|
||||
return False
|
||||
|
||||
|
||||
# Check if there's actual readable text (not just XML or system messages)
|
||||
# Remove common prefixes like "wxid_xxx:\n" and check for actual content
|
||||
clean_content = re.sub(r'^wxid_[^:]+:\s*', '', content)
|
||||
clean_content = re.sub(r'^[^:]+:\s*', '', clean_content)
|
||||
|
||||
clean_content = re.sub(r"^wxid_[^:]+:\s*", "", content)
|
||||
clean_content = re.sub(r"^[^:]+:\s*", "", clean_content)
|
||||
|
||||
# If after cleaning we have meaningful text, consider it readable
|
||||
if len(clean_content.strip()) > 0 and not clean_content.strip().startswith('<'):
|
||||
if len(clean_content.strip()) > 0 and not clean_content.strip().startswith("<"):
|
||||
return True
|
||||
|
||||
|
||||
return False
|
||||
|
||||
def _concatenate_messages(self, messages: List[Dict], max_length: int = 128,
|
||||
time_window_minutes: int = 30, overlap_messages: int = 0) -> List[Dict]:
|
||||
|
||||
def _concatenate_messages(
|
||||
self,
|
||||
messages: list[dict],
|
||||
max_length: int = 128,
|
||||
time_window_minutes: int = 30,
|
||||
overlap_messages: int = 0,
|
||||
) -> list[dict]:
|
||||
"""
|
||||
Concatenate messages based on length and time rules.
|
||||
|
||||
|
||||
Args:
|
||||
messages: List of message dictionaries
|
||||
max_length: Maximum length for concatenated message groups. Use -1 to disable length constraint.
|
||||
time_window_minutes: Time window in minutes to group messages together. Use -1 to disable time constraint.
|
||||
overlap_messages: Number of messages to overlap between consecutive groups
|
||||
|
||||
|
||||
Returns:
|
||||
List of concatenated message groups
|
||||
"""
|
||||
if not messages:
|
||||
return []
|
||||
|
||||
|
||||
concatenated_groups = []
|
||||
current_group = []
|
||||
current_length = 0
|
||||
last_timestamp = None
|
||||
|
||||
|
||||
for message in messages:
|
||||
# Extract message info
|
||||
content = message.get('content', '')
|
||||
message_text = message.get('message', '')
|
||||
create_time = message.get('createTime', 0)
|
||||
from_user = message.get('fromUser', '')
|
||||
to_user = message.get('toUser', '')
|
||||
is_sent_from_self = message.get('isSentFromSelf', False)
|
||||
|
||||
content = message.get("content", "")
|
||||
message_text = message.get("message", "")
|
||||
create_time = message.get("createTime", 0)
|
||||
message.get("fromUser", "")
|
||||
message.get("toUser", "")
|
||||
message.get("isSentFromSelf", False)
|
||||
|
||||
# Extract readable text
|
||||
readable_text = self._extract_readable_text(content)
|
||||
if not readable_text:
|
||||
readable_text = message_text
|
||||
|
||||
|
||||
# Skip empty messages
|
||||
if not readable_text.strip():
|
||||
continue
|
||||
|
||||
|
||||
# Check time window constraint (only if time_window_minutes != -1)
|
||||
if time_window_minutes != -1 and last_timestamp is not None and create_time > 0:
|
||||
time_diff_minutes = (create_time - last_timestamp) / 60
|
||||
if time_diff_minutes > time_window_minutes:
|
||||
# Time gap too large, start new group
|
||||
if current_group:
|
||||
concatenated_groups.append({
|
||||
'messages': current_group,
|
||||
'total_length': current_length,
|
||||
'start_time': current_group[0].get('createTime', 0),
|
||||
'end_time': current_group[-1].get('createTime', 0)
|
||||
})
|
||||
concatenated_groups.append(
|
||||
{
|
||||
"messages": current_group,
|
||||
"total_length": current_length,
|
||||
"start_time": current_group[0].get("createTime", 0),
|
||||
"end_time": current_group[-1].get("createTime", 0),
|
||||
}
|
||||
)
|
||||
# Keep last few messages for overlap
|
||||
if overlap_messages > 0 and len(current_group) > overlap_messages:
|
||||
current_group = current_group[-overlap_messages:]
|
||||
current_length = sum(len(self._extract_readable_text(msg.get('content', '')) or msg.get('message', '')) for msg in current_group)
|
||||
current_length = sum(
|
||||
len(
|
||||
self._extract_readable_text(msg.get("content", ""))
|
||||
or msg.get("message", "")
|
||||
)
|
||||
for msg in current_group
|
||||
)
|
||||
else:
|
||||
current_group = []
|
||||
current_length = 0
|
||||
|
||||
|
||||
# Check length constraint (only if max_length != -1)
|
||||
message_length = len(readable_text)
|
||||
if max_length != -1 and current_length + message_length > max_length and current_group:
|
||||
# Current group would exceed max length, save it and start new
|
||||
concatenated_groups.append({
|
||||
'messages': current_group,
|
||||
'total_length': current_length,
|
||||
'start_time': current_group[0].get('createTime', 0),
|
||||
'end_time': current_group[-1].get('createTime', 0)
|
||||
})
|
||||
concatenated_groups.append(
|
||||
{
|
||||
"messages": current_group,
|
||||
"total_length": current_length,
|
||||
"start_time": current_group[0].get("createTime", 0),
|
||||
"end_time": current_group[-1].get("createTime", 0),
|
||||
}
|
||||
)
|
||||
# Keep last few messages for overlap
|
||||
if overlap_messages > 0 and len(current_group) > overlap_messages:
|
||||
current_group = current_group[-overlap_messages:]
|
||||
current_length = sum(len(self._extract_readable_text(msg.get('content', '')) or msg.get('message', '')) for msg in current_group)
|
||||
current_length = sum(
|
||||
len(
|
||||
self._extract_readable_text(msg.get("content", ""))
|
||||
or msg.get("message", "")
|
||||
)
|
||||
for msg in current_group
|
||||
)
|
||||
else:
|
||||
current_group = []
|
||||
current_length = 0
|
||||
|
||||
|
||||
# Add message to current group
|
||||
current_group.append(message)
|
||||
current_length += message_length
|
||||
last_timestamp = create_time
|
||||
|
||||
|
||||
# Add the last group if it exists
|
||||
if current_group:
|
||||
concatenated_groups.append({
|
||||
'messages': current_group,
|
||||
'total_length': current_length,
|
||||
'start_time': current_group[0].get('createTime', 0),
|
||||
'end_time': current_group[-1].get('createTime', 0)
|
||||
})
|
||||
|
||||
concatenated_groups.append(
|
||||
{
|
||||
"messages": current_group,
|
||||
"total_length": current_length,
|
||||
"start_time": current_group[0].get("createTime", 0),
|
||||
"end_time": current_group[-1].get("createTime", 0),
|
||||
}
|
||||
)
|
||||
|
||||
return concatenated_groups
|
||||
|
||||
def _create_concatenated_content(self, message_group: Dict, contact_name: str) -> str:
|
||||
|
||||
def _create_concatenated_content(self, message_group: dict, contact_name: str) -> str:
|
||||
"""
|
||||
Create concatenated content from a group of messages.
|
||||
|
||||
|
||||
Args:
|
||||
message_group: Dictionary containing messages and metadata
|
||||
contact_name: Name of the contact
|
||||
|
||||
|
||||
Returns:
|
||||
Formatted concatenated content
|
||||
"""
|
||||
messages = message_group['messages']
|
||||
start_time = message_group['start_time']
|
||||
end_time = message_group['end_time']
|
||||
|
||||
messages = message_group["messages"]
|
||||
start_time = message_group["start_time"]
|
||||
end_time = message_group["end_time"]
|
||||
|
||||
# Format timestamps
|
||||
if start_time:
|
||||
try:
|
||||
start_timestamp = datetime.fromtimestamp(start_time)
|
||||
start_time_str = start_timestamp.strftime('%Y-%m-%d %H:%M:%S')
|
||||
except:
|
||||
start_time_str = start_timestamp.strftime("%Y-%m-%d %H:%M:%S")
|
||||
except (ValueError, OSError):
|
||||
start_time_str = str(start_time)
|
||||
else:
|
||||
start_time_str = "Unknown"
|
||||
|
||||
|
||||
if end_time:
|
||||
try:
|
||||
end_timestamp = datetime.fromtimestamp(end_time)
|
||||
end_time_str = end_timestamp.strftime('%Y-%m-%d %H:%M:%S')
|
||||
except:
|
||||
end_time_str = end_timestamp.strftime("%Y-%m-%d %H:%M:%S")
|
||||
except (ValueError, OSError):
|
||||
end_time_str = str(end_time)
|
||||
else:
|
||||
end_time_str = "Unknown"
|
||||
|
||||
|
||||
# Build concatenated message content
|
||||
message_parts = []
|
||||
for message in messages:
|
||||
content = message.get('content', '')
|
||||
message_text = message.get('message', '')
|
||||
create_time = message.get('createTime', 0)
|
||||
is_sent_from_self = message.get('isSentFromSelf', False)
|
||||
|
||||
content = message.get("content", "")
|
||||
message_text = message.get("message", "")
|
||||
create_time = message.get("createTime", 0)
|
||||
is_sent_from_self = message.get("isSentFromSelf", False)
|
||||
|
||||
# Extract readable text
|
||||
readable_text = self._extract_readable_text(content)
|
||||
if not readable_text:
|
||||
readable_text = message_text
|
||||
|
||||
|
||||
# Format individual message
|
||||
if create_time:
|
||||
try:
|
||||
timestamp = datetime.fromtimestamp(create_time)
|
||||
# change to YYYY-MM-DD HH:MM:SS
|
||||
time_str = timestamp.strftime('%Y-%m-%d %H:%M:%S')
|
||||
except:
|
||||
time_str = timestamp.strftime("%Y-%m-%d %H:%M:%S")
|
||||
except (ValueError, OSError):
|
||||
time_str = str(create_time)
|
||||
else:
|
||||
time_str = "Unknown"
|
||||
|
||||
|
||||
sender = "[Me]" if is_sent_from_self else "[Contact]"
|
||||
message_parts.append(f"({time_str}) {sender}: {readable_text}")
|
||||
|
||||
|
||||
concatenated_text = "\n".join(message_parts)
|
||||
|
||||
|
||||
# Create final document content
|
||||
doc_content = f"""
|
||||
Contact: {contact_name}
|
||||
Time Range: {start_time_str} - {end_time_str}
|
||||
Messages ({len(messages)} messages, {message_group['total_length']} chars):
|
||||
Messages ({len(messages)} messages, {message_group["total_length"]} chars):
|
||||
|
||||
{concatenated_text}
|
||||
"""
|
||||
# TODO @yichuan give better format and rich info here!
|
||||
# TODO @yichuan give better format and rich info here!
|
||||
doc_content = f"""
|
||||
{concatenated_text}
|
||||
"""
|
||||
return doc_content, contact_name
|
||||
|
||||
def load_data(self, input_dir: str = None, **load_kwargs: Any) -> List[Document]:
|
||||
|
||||
def load_data(self, input_dir: str | None = None, **load_kwargs: Any) -> list[Document]:
|
||||
"""
|
||||
Load WeChat chat history data from exported JSON files.
|
||||
|
||||
|
||||
Args:
|
||||
input_dir: Directory containing exported WeChat JSON files
|
||||
**load_kwargs:
|
||||
@@ -376,97 +406,104 @@ Messages ({len(messages)} messages, {message_group['total_length']} chars):
|
||||
time_window_minutes (int): Time window in minutes to group messages together (default: 30).
|
||||
overlap_messages (int): Number of messages to overlap between consecutive groups (default: 2).
|
||||
"""
|
||||
docs: List[Document] = []
|
||||
max_count = load_kwargs.get('max_count', 1000)
|
||||
wechat_export_dir = load_kwargs.get('wechat_export_dir', None)
|
||||
include_non_text = load_kwargs.get('include_non_text', False)
|
||||
concatenate_messages = load_kwargs.get('concatenate_messages', False)
|
||||
max_length = load_kwargs.get('max_length', 1000)
|
||||
time_window_minutes = load_kwargs.get('time_window_minutes', 30)
|
||||
|
||||
docs: list[Document] = []
|
||||
max_count = load_kwargs.get("max_count", 1000)
|
||||
wechat_export_dir = load_kwargs.get("wechat_export_dir", None)
|
||||
include_non_text = load_kwargs.get("include_non_text", False)
|
||||
concatenate_messages = load_kwargs.get("concatenate_messages", False)
|
||||
max_length = load_kwargs.get("max_length", 1000)
|
||||
time_window_minutes = load_kwargs.get("time_window_minutes", 30)
|
||||
|
||||
# Default WeChat export path
|
||||
if wechat_export_dir is None:
|
||||
wechat_export_dir = "./wechat_export_test"
|
||||
|
||||
|
||||
if not os.path.exists(wechat_export_dir):
|
||||
print(f"WeChat export directory not found at: {wechat_export_dir}")
|
||||
return docs
|
||||
|
||||
|
||||
try:
|
||||
# Find all JSON files in the export directory
|
||||
json_files = list(Path(wechat_export_dir).glob("*.json"))
|
||||
print(f"Found {len(json_files)} WeChat chat history files")
|
||||
|
||||
|
||||
count = 0
|
||||
for json_file in json_files:
|
||||
if count >= max_count and max_count > 0:
|
||||
break
|
||||
|
||||
|
||||
try:
|
||||
with open(json_file, 'r', encoding='utf-8') as f:
|
||||
with open(json_file, encoding="utf-8") as f:
|
||||
chat_data = json.load(f)
|
||||
|
||||
|
||||
# Extract contact name from filename
|
||||
contact_name = json_file.stem
|
||||
|
||||
|
||||
if concatenate_messages:
|
||||
# Filter messages to only include readable text messages
|
||||
readable_messages = []
|
||||
for message in chat_data:
|
||||
try:
|
||||
content = message.get('content', '')
|
||||
content = message.get("content", "")
|
||||
if not include_non_text and not self._is_text_message(content):
|
||||
continue
|
||||
|
||||
|
||||
readable_text = self._extract_readable_text(content)
|
||||
if not readable_text and not include_non_text:
|
||||
continue
|
||||
|
||||
|
||||
readable_messages.append(message)
|
||||
except Exception as e:
|
||||
print(f"Error processing message in {json_file}: {e}")
|
||||
continue
|
||||
|
||||
|
||||
# Concatenate messages based on rules
|
||||
message_groups = self._concatenate_messages(
|
||||
readable_messages,
|
||||
max_length=-1,
|
||||
time_window_minutes=-1,
|
||||
overlap_messages=0 # Keep 2 messages overlap between groups
|
||||
readable_messages,
|
||||
max_length=max_length,
|
||||
time_window_minutes=time_window_minutes,
|
||||
overlap_messages=0, # No overlap between groups
|
||||
)
|
||||
|
||||
|
||||
# Create documents from concatenated groups
|
||||
for message_group in message_groups:
|
||||
if count >= max_count and max_count > 0:
|
||||
break
|
||||
|
||||
doc_content, contact_name = self._create_concatenated_content(message_group, contact_name)
|
||||
doc = Document(text=doc_content, metadata={"contact_name": contact_name})
|
||||
|
||||
doc_content, contact_name = self._create_concatenated_content(
|
||||
message_group, contact_name
|
||||
)
|
||||
doc = Document(
|
||||
text=doc_content,
|
||||
metadata={"contact_name": contact_name},
|
||||
)
|
||||
docs.append(doc)
|
||||
count += 1
|
||||
|
||||
print(f"Created {len(message_groups)} concatenated message groups for {contact_name}")
|
||||
|
||||
|
||||
print(
|
||||
f"Created {len(message_groups)} concatenated message groups for {contact_name}"
|
||||
)
|
||||
|
||||
else:
|
||||
# Original single-message processing
|
||||
for message in chat_data:
|
||||
if count >= max_count and max_count > 0:
|
||||
break
|
||||
|
||||
|
||||
# Extract message information
|
||||
from_user = message.get('fromUser', '')
|
||||
to_user = message.get('toUser', '')
|
||||
content = message.get('content', '')
|
||||
message_text = message.get('message', '')
|
||||
create_time = message.get('createTime', 0)
|
||||
is_sent_from_self = message.get('isSentFromSelf', False)
|
||||
|
||||
message.get("fromUser", "")
|
||||
message.get("toUser", "")
|
||||
content = message.get("content", "")
|
||||
message_text = message.get("message", "")
|
||||
create_time = message.get("createTime", 0)
|
||||
is_sent_from_self = message.get("isSentFromSelf", False)
|
||||
|
||||
# Handle content that might be dict or string
|
||||
try:
|
||||
# Check if this is a readable text message
|
||||
if not include_non_text and not self._is_text_message(content):
|
||||
continue
|
||||
|
||||
|
||||
# Extract readable text
|
||||
readable_text = self._extract_readable_text(content)
|
||||
if not readable_text and not include_non_text:
|
||||
@@ -475,17 +512,17 @@ Messages ({len(messages)} messages, {message_group['total_length']} chars):
|
||||
# Skip messages that cause processing errors
|
||||
print(f"Error processing message in {json_file}: {e}")
|
||||
continue
|
||||
|
||||
|
||||
# Convert timestamp to readable format
|
||||
if create_time:
|
||||
try:
|
||||
timestamp = datetime.fromtimestamp(create_time)
|
||||
time_str = timestamp.strftime('%Y-%m-%d %H:%M:%S')
|
||||
except:
|
||||
time_str = timestamp.strftime("%Y-%m-%d %H:%M:%S")
|
||||
except (ValueError, OSError):
|
||||
time_str = str(create_time)
|
||||
else:
|
||||
time_str = "Unknown"
|
||||
|
||||
|
||||
# Create document content with metadata header and contact info
|
||||
doc_content = f"""
|
||||
Contact: {contact_name}
|
||||
@@ -493,57 +530,66 @@ Is sent from self: {is_sent_from_self}
|
||||
Time: {time_str}
|
||||
Message: {readable_text if readable_text else message_text}
|
||||
"""
|
||||
|
||||
|
||||
# Create document with embedded metadata
|
||||
doc = Document(text=doc_content, metadata={})
|
||||
doc = Document(
|
||||
text=doc_content, metadata={"contact_name": contact_name}
|
||||
)
|
||||
docs.append(doc)
|
||||
count += 1
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error reading {json_file}: {e}")
|
||||
continue
|
||||
|
||||
|
||||
print(f"Loaded {len(docs)} WeChat chat documents")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error reading WeChat history: {e}")
|
||||
return docs
|
||||
|
||||
|
||||
return docs
|
||||
|
||||
@staticmethod
|
||||
def find_wechat_export_dirs() -> List[Path]:
|
||||
def find_wechat_export_dirs() -> list[Path]:
|
||||
"""
|
||||
Find all WeChat export directories.
|
||||
|
||||
|
||||
Returns:
|
||||
List of Path objects pointing to WeChat export directories
|
||||
"""
|
||||
export_dirs = []
|
||||
|
||||
|
||||
# Look for common export directory names
|
||||
possible_dirs = [
|
||||
Path("./wechat_export_test"),
|
||||
Path("./wechat_export"),
|
||||
Path("./wechat_export_direct"),
|
||||
Path("./wechat_chat_history"),
|
||||
Path("./chat_export")
|
||||
Path("./chat_export"),
|
||||
]
|
||||
|
||||
|
||||
for export_dir in possible_dirs:
|
||||
if export_dir.exists() and export_dir.is_dir():
|
||||
json_files = list(export_dir.glob("*.json"))
|
||||
if json_files:
|
||||
export_dirs.append(export_dir)
|
||||
print(f"Found WeChat export directory: {export_dir} with {len(json_files)} files")
|
||||
|
||||
print(
|
||||
f"Found WeChat export directory: {export_dir} with {len(json_files)} files"
|
||||
)
|
||||
|
||||
print(f"Found {len(export_dirs)} WeChat export directories")
|
||||
return export_dirs
|
||||
|
||||
@staticmethod
|
||||
def export_chat_to_file(output_file: str = "wechat_chat_export.txt", max_count: int = 1000, export_dir: str = None, include_non_text: bool = False):
|
||||
def export_chat_to_file(
|
||||
output_file: str = "wechat_chat_export.txt",
|
||||
max_count: int = 1000,
|
||||
export_dir: str | None = None,
|
||||
include_non_text: bool = False,
|
||||
):
|
||||
"""
|
||||
Export WeChat chat history to a text file.
|
||||
|
||||
|
||||
Args:
|
||||
output_file: Path to the output file
|
||||
max_count: Maximum number of entries to export
|
||||
@@ -552,36 +598,36 @@ Message: {readable_text if readable_text else message_text}
|
||||
"""
|
||||
if export_dir is None:
|
||||
export_dir = "./wechat_export_test"
|
||||
|
||||
|
||||
if not os.path.exists(export_dir):
|
||||
print(f"WeChat export directory not found at: {export_dir}")
|
||||
return
|
||||
|
||||
|
||||
try:
|
||||
json_files = list(Path(export_dir).glob("*.json"))
|
||||
|
||||
with open(output_file, 'w', encoding='utf-8') as f:
|
||||
|
||||
with open(output_file, "w", encoding="utf-8") as f:
|
||||
count = 0
|
||||
for json_file in json_files:
|
||||
if count >= max_count and max_count > 0:
|
||||
break
|
||||
|
||||
|
||||
try:
|
||||
with open(json_file, 'r', encoding='utf-8') as json_f:
|
||||
with open(json_file, encoding="utf-8") as json_f:
|
||||
chat_data = json.load(json_f)
|
||||
|
||||
|
||||
contact_name = json_file.stem
|
||||
f.write(f"\n=== Chat with {contact_name} ===\n")
|
||||
|
||||
|
||||
for message in chat_data:
|
||||
if count >= max_count and max_count > 0:
|
||||
break
|
||||
|
||||
from_user = message.get('fromUser', '')
|
||||
content = message.get('content', '')
|
||||
message_text = message.get('message', '')
|
||||
create_time = message.get('createTime', 0)
|
||||
|
||||
|
||||
from_user = message.get("fromUser", "")
|
||||
content = message.get("content", "")
|
||||
message_text = message.get("message", "")
|
||||
create_time = message.get("createTime", 0)
|
||||
|
||||
# Skip non-text messages unless requested
|
||||
if not include_non_text:
|
||||
reader = WeChatHistoryReader()
|
||||
@@ -591,83 +637,90 @@ Message: {readable_text if readable_text else message_text}
|
||||
if not readable_text:
|
||||
continue
|
||||
message_text = readable_text
|
||||
|
||||
|
||||
if create_time:
|
||||
try:
|
||||
timestamp = datetime.fromtimestamp(create_time)
|
||||
time_str = timestamp.strftime('%Y-%m-%d %H:%M:%S')
|
||||
except:
|
||||
time_str = timestamp.strftime("%Y-%m-%d %H:%M:%S")
|
||||
except (ValueError, OSError):
|
||||
time_str = str(create_time)
|
||||
else:
|
||||
time_str = "Unknown"
|
||||
|
||||
|
||||
f.write(f"[{time_str}] {from_user}: {message_text}\n")
|
||||
count += 1
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error processing {json_file}: {e}")
|
||||
continue
|
||||
|
||||
|
||||
print(f"Exported {count} chat entries to {output_file}")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error exporting WeChat chat history: {e}")
|
||||
|
||||
def export_wechat_chat_history(self, export_dir: str = "./wechat_export_direct") -> Optional[Path]:
|
||||
def export_wechat_chat_history(self, export_dir: str = "./wechat_export_direct") -> Path | None:
|
||||
"""
|
||||
Export WeChat chat history using wechat-exporter tool.
|
||||
|
||||
|
||||
Args:
|
||||
export_dir: Directory to save exported chat history
|
||||
|
||||
|
||||
Returns:
|
||||
Path to export directory if successful, None otherwise
|
||||
"""
|
||||
try:
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
|
||||
# Create export directory
|
||||
export_path = Path(export_dir)
|
||||
export_path.mkdir(exist_ok=True)
|
||||
|
||||
|
||||
print(f"Exporting WeChat chat history to {export_path}...")
|
||||
|
||||
|
||||
# Check if wechat-exporter directory exists
|
||||
if not self.wechat_exporter_dir.exists():
|
||||
print(f"wechat-exporter directory not found at: {self.wechat_exporter_dir}")
|
||||
return None
|
||||
|
||||
|
||||
# Install requirements if needed
|
||||
requirements_file = self.wechat_exporter_dir / "requirements.txt"
|
||||
if requirements_file.exists():
|
||||
print("Installing wechat-exporter requirements...")
|
||||
subprocess.run([
|
||||
"uv", "pip", "install", "-r", str(requirements_file)
|
||||
], check=True)
|
||||
|
||||
subprocess.run(["uv", "pip", "install", "-r", str(requirements_file)], check=True)
|
||||
|
||||
# Run the export command
|
||||
print("Running wechat-exporter...")
|
||||
result = subprocess.run([
|
||||
sys.executable, str(self.wechat_exporter_dir / "main.py"),
|
||||
"export-all", str(export_path)
|
||||
], capture_output=True, text=True, check=True)
|
||||
|
||||
result = subprocess.run(
|
||||
[
|
||||
sys.executable,
|
||||
str(self.wechat_exporter_dir / "main.py"),
|
||||
"export-all",
|
||||
str(export_path),
|
||||
],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
check=True,
|
||||
)
|
||||
|
||||
print("Export command output:")
|
||||
print(result.stdout)
|
||||
if result.stderr:
|
||||
print("Export errors:")
|
||||
print(result.stderr)
|
||||
|
||||
|
||||
# Check if export was successful
|
||||
if export_path.exists() and any(export_path.glob("*.json")):
|
||||
json_files = list(export_path.glob("*.json"))
|
||||
print(f"Successfully exported {len(json_files)} chat history files to {export_path}")
|
||||
print(
|
||||
f"Successfully exported {len(json_files)} chat history files to {export_path}"
|
||||
)
|
||||
return export_path
|
||||
else:
|
||||
print("Export completed but no JSON files found")
|
||||
return None
|
||||
|
||||
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(f"Export command failed: {e}")
|
||||
print(f"Command output: {e.stdout}")
|
||||
@@ -678,18 +731,18 @@ Message: {readable_text if readable_text else message_text}
|
||||
print("Please ensure WeChat is running and WeChatTweak is installed.")
|
||||
return None
|
||||
|
||||
def find_or_export_wechat_data(self, export_dir: str = "./wechat_export_direct") -> List[Path]:
|
||||
def find_or_export_wechat_data(self, export_dir: str = "./wechat_export_direct") -> list[Path]:
|
||||
"""
|
||||
Find existing WeChat exports or create new ones.
|
||||
|
||||
|
||||
Args:
|
||||
export_dir: Directory to save exported chat history if needed
|
||||
|
||||
|
||||
Returns:
|
||||
List of Path objects pointing to WeChat export directories
|
||||
"""
|
||||
export_dirs = []
|
||||
|
||||
|
||||
# Look for existing exports in common locations
|
||||
possible_export_dirs = [
|
||||
Path("./wechat_database_export"),
|
||||
@@ -697,23 +750,25 @@ Message: {readable_text if readable_text else message_text}
|
||||
Path("./wechat_export"),
|
||||
Path("./wechat_export_direct"),
|
||||
Path("./wechat_chat_history"),
|
||||
Path("./chat_export")
|
||||
Path("./chat_export"),
|
||||
]
|
||||
|
||||
|
||||
for export_dir_path in possible_export_dirs:
|
||||
if export_dir_path.exists() and any(export_dir_path.glob("*.json")):
|
||||
export_dirs.append(export_dir_path)
|
||||
print(f"Found existing export: {export_dir_path}")
|
||||
|
||||
|
||||
# If no existing exports, try to export automatically
|
||||
if not export_dirs:
|
||||
print("No existing WeChat exports found. Starting direct export...")
|
||||
|
||||
|
||||
# Try to export using wechat-exporter
|
||||
exported_path = self.export_wechat_chat_history(export_dir)
|
||||
if exported_path:
|
||||
export_dirs = [exported_path]
|
||||
else:
|
||||
print("Failed to export WeChat data. Please ensure WeChat is running and WeChatTweak is installed.")
|
||||
|
||||
return export_dirs
|
||||
print(
|
||||
"Failed to export WeChat data. Please ensure WeChat is running and WeChatTweak is installed."
|
||||
)
|
||||
|
||||
return export_dirs
|
||||
@@ -1,230 +0,0 @@
|
||||
import os
|
||||
import asyncio
|
||||
import dotenv
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
from typing import List, Any, Optional
|
||||
from leann.api import LeannBuilder, LeannSearcher, LeannChat
|
||||
from llama_index.core.node_parser import SentenceSplitter
|
||||
import requests
|
||||
import time
|
||||
|
||||
dotenv.load_dotenv()
|
||||
|
||||
# Default WeChat export directory
|
||||
DEFAULT_WECHAT_EXPORT_DIR = "./wechat_export_direct"
|
||||
|
||||
def create_leann_index_from_multiple_wechat_exports(
|
||||
export_dirs: List[Path],
|
||||
index_path: str = "wechat_history_index.leann",
|
||||
max_count: int = -1,
|
||||
):
|
||||
"""
|
||||
Create LEANN index from multiple WeChat export data sources.
|
||||
|
||||
Args:
|
||||
export_dirs: List of Path objects pointing to WeChat export directories
|
||||
index_path: Path to save the LEANN index
|
||||
max_count: Maximum number of chat entries to process per export
|
||||
"""
|
||||
print("Creating LEANN index from multiple WeChat export data sources...")
|
||||
|
||||
# Load documents using WeChatHistoryReader from local readers module
|
||||
from .readers import WeChatHistoryReader
|
||||
|
||||
reader = WeChatHistoryReader()
|
||||
|
||||
INDEX_DIR = Path(index_path).parent
|
||||
|
||||
if not INDEX_DIR.exists():
|
||||
print(f"--- Index directory not found, building new index ---")
|
||||
all_documents = []
|
||||
total_processed = 0
|
||||
|
||||
# Process each WeChat export directory
|
||||
for i, export_dir in enumerate(export_dirs):
|
||||
print(
|
||||
f"\nProcessing WeChat export {i + 1}/{len(export_dirs)}: {export_dir}"
|
||||
)
|
||||
|
||||
try:
|
||||
documents = reader.load_data(
|
||||
wechat_export_dir=str(export_dir),
|
||||
max_count=max_count,
|
||||
concatenate_messages=True, # Disable concatenation - one message per document
|
||||
)
|
||||
if documents:
|
||||
print(f"Loaded {len(documents)} chat documents from {export_dir}")
|
||||
all_documents.extend(documents)
|
||||
total_processed += len(documents)
|
||||
|
||||
# Check if we've reached the max count
|
||||
if max_count > 0 and total_processed >= max_count:
|
||||
print(f"Reached max count of {max_count} documents")
|
||||
break
|
||||
else:
|
||||
print(f"No documents loaded from {export_dir}")
|
||||
except Exception as e:
|
||||
print(f"Error processing {export_dir}: {e}")
|
||||
continue
|
||||
|
||||
if not all_documents:
|
||||
print("No documents loaded from any source. Exiting.")
|
||||
return None
|
||||
|
||||
print(
|
||||
f"\nTotal loaded {len(all_documents)} chat documents from {len(export_dirs)} exports"
|
||||
)
|
||||
|
||||
# Create text splitter with 256 chunk size
|
||||
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=128)
|
||||
|
||||
# Convert Documents to text strings and chunk them
|
||||
all_texts = []
|
||||
for doc in all_documents:
|
||||
# Split the document into chunks
|
||||
nodes = text_splitter.get_nodes_from_documents([doc])
|
||||
for node in nodes:
|
||||
text = '[Contact] means the message is from: ' + doc.metadata["contact_name"] + '\n' + node.get_content()
|
||||
all_texts.append(text)
|
||||
|
||||
print(
|
||||
f"Created {len(all_texts)} text chunks from {len(all_documents)} documents"
|
||||
)
|
||||
|
||||
# Create LEANN index directory
|
||||
print(f"--- Index directory not found, building new index ---")
|
||||
INDEX_DIR.mkdir(exist_ok=True)
|
||||
|
||||
print(f"--- Building new LEANN index ---")
|
||||
|
||||
print(f"\n[PHASE 1] Building Leann index...")
|
||||
|
||||
# Use HNSW backend for better macOS compatibility
|
||||
builder = LeannBuilder(
|
||||
backend_name="hnsw",
|
||||
embedding_model="Qwen/Qwen3-Embedding-0.6B",
|
||||
graph_degree=32,
|
||||
complexity=64,
|
||||
is_compact=True,
|
||||
is_recompute=True,
|
||||
num_threads=1, # Force single-threaded mode
|
||||
)
|
||||
|
||||
print(f"Adding {len(all_texts)} chat chunks to index...")
|
||||
for chunk_text in all_texts:
|
||||
builder.add_text(chunk_text)
|
||||
|
||||
builder.build_index(index_path)
|
||||
print(f"\nLEANN index built at {index_path}!")
|
||||
else:
|
||||
print(f"--- Using existing index at {INDEX_DIR} ---")
|
||||
|
||||
return index_path
|
||||
|
||||
async def query_leann_index(index_path: str, query: str):
|
||||
"""
|
||||
Query the LEANN index.
|
||||
|
||||
Args:
|
||||
index_path: Path to the LEANN index
|
||||
query: The query string
|
||||
"""
|
||||
print(f"\n[PHASE 2] Starting Leann chat session...")
|
||||
chat = LeannChat(index_path=index_path)
|
||||
|
||||
print(f"You: {query}")
|
||||
chat_response = chat.ask(
|
||||
query,
|
||||
top_k=20,
|
||||
recompute_beighbor_embeddings=True,
|
||||
complexity=16,
|
||||
beam_width=1,
|
||||
llm_config={
|
||||
"type": "openai",
|
||||
"model": "gpt-4o",
|
||||
"api_key": os.getenv("OPENAI_API_KEY"),
|
||||
},
|
||||
llm_kwargs={"temperature": 0.0, "max_tokens": 1000},
|
||||
)
|
||||
print(f"Leann: {chat_response}")
|
||||
|
||||
async def main():
|
||||
"""Main function with integrated WeChat export functionality."""
|
||||
|
||||
# Parse command line arguments
|
||||
parser = argparse.ArgumentParser(
|
||||
description="LEANN WeChat History Reader - Create and query WeChat chat history index"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--export-dir",
|
||||
type=str,
|
||||
default=DEFAULT_WECHAT_EXPORT_DIR,
|
||||
help=f"Directory to store WeChat exports (default: {DEFAULT_WECHAT_EXPORT_DIR})",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--index-dir",
|
||||
type=str,
|
||||
default="./wechat_history_magic_test_11Debug_new",
|
||||
help="Directory to store the LEANN index (default: ./wechat_history_index_leann_test)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-entries",
|
||||
type=int,
|
||||
default=50,
|
||||
help="Maximum number of chat entries to process (default: 5000)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--query",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Single query to run (default: runs example queries)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--force-export",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Force re-export of WeChat data even if exports exist",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
INDEX_DIR = Path(args.index_dir)
|
||||
INDEX_PATH = str(INDEX_DIR / "wechat_history.leann")
|
||||
|
||||
print(f"Using WeChat export directory: {args.export_dir}")
|
||||
print(f"Index directory: {INDEX_DIR}")
|
||||
print(f"Max entries: {args.max_entries}")
|
||||
|
||||
# Initialize WeChat reader with export capabilities
|
||||
from .readers import WeChatHistoryReader
|
||||
|
||||
reader = WeChatHistoryReader()
|
||||
|
||||
# Find existing exports or create new ones using the centralized method
|
||||
export_dirs = reader.find_or_export_wechat_data(args.export_dir)
|
||||
if not export_dirs:
|
||||
print("Failed to find or export WeChat data. Exiting.")
|
||||
return
|
||||
|
||||
# Create or load the LEANN index from all sources
|
||||
index_path = create_leann_index_from_multiple_wechat_exports(
|
||||
export_dirs, INDEX_PATH, max_count=args.max_entries
|
||||
)
|
||||
|
||||
if index_path:
|
||||
if args.query:
|
||||
# Run single query
|
||||
await query_leann_index(index_path, args.query)
|
||||
else:
|
||||
# Example queries
|
||||
queries = [
|
||||
"我想买魔术师约翰逊的球衣,给我一些对应聊天记录?",
|
||||
]
|
||||
|
||||
for query in queries:
|
||||
print("\n" + "=" * 60)
|
||||
await query_leann_index(index_path, query)
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -1,719 +0,0 @@
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import List, Any, Dict, Optional
|
||||
from llama_index.core import Document
|
||||
from llama_index.core.readers.base import BaseReader
|
||||
from datetime import datetime
|
||||
|
||||
class WeChatHistoryReader(BaseReader):
|
||||
"""
|
||||
WeChat chat history reader that extracts chat data from exported JSON files.
|
||||
|
||||
Reads WeChat chat history from exported JSON files (from wechat-exporter tool)
|
||||
and creates documents with embedded metadata similar to the Chrome history reader structure.
|
||||
|
||||
Also includes utilities for automatic WeChat chat history export.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize."""
|
||||
self.packages_dir = Path(__file__).parent.parent.parent / "packages"
|
||||
self.wechat_exporter_dir = self.packages_dir / "wechat-exporter"
|
||||
self.wechat_decipher_dir = self.packages_dir / "wechat-decipher-macos"
|
||||
|
||||
def check_wechat_running(self) -> bool:
|
||||
"""Check if WeChat is currently running."""
|
||||
try:
|
||||
result = subprocess.run(["pgrep", "-f", "WeChat"], capture_output=True, text=True)
|
||||
return result.returncode == 0
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def install_wechattweak(self) -> bool:
|
||||
"""Install WeChatTweak CLI tool."""
|
||||
try:
|
||||
# Create wechat-exporter directory if it doesn't exist
|
||||
self.wechat_exporter_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
wechattweak_path = self.wechat_exporter_dir / "wechattweak-cli"
|
||||
if not wechattweak_path.exists():
|
||||
print("Downloading WeChatTweak CLI...")
|
||||
subprocess.run([
|
||||
"curl", "-L", "-o", str(wechattweak_path),
|
||||
"https://github.com/JettChenT/WeChatTweak-CLI/releases/latest/download/wechattweak-cli"
|
||||
], check=True)
|
||||
|
||||
# Make executable
|
||||
wechattweak_path.chmod(0o755)
|
||||
|
||||
# Install WeChatTweak
|
||||
print("Installing WeChatTweak...")
|
||||
subprocess.run(["sudo", str(wechattweak_path), "install"], check=True)
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"Error installing WeChatTweak: {e}")
|
||||
return False
|
||||
|
||||
def restart_wechat(self):
|
||||
"""Restart WeChat to apply WeChatTweak."""
|
||||
try:
|
||||
print("Restarting WeChat...")
|
||||
subprocess.run(["pkill", "-f", "WeChat"], check=False)
|
||||
time.sleep(2)
|
||||
subprocess.run(["open", "-a", "WeChat"], check=True)
|
||||
time.sleep(5) # Wait for WeChat to start
|
||||
except Exception as e:
|
||||
print(f"Error restarting WeChat: {e}")
|
||||
|
||||
def check_api_available(self) -> bool:
|
||||
"""Check if WeChatTweak API is available."""
|
||||
try:
|
||||
result = subprocess.run([
|
||||
"curl", "-s", "http://localhost:48065/wechat/allcontacts"
|
||||
], capture_output=True, text=True, timeout=5)
|
||||
return result.returncode == 0 and result.stdout.strip()
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
|
||||
|
||||
def _extract_readable_text(self, content: str) -> str:
|
||||
"""
|
||||
Extract readable text from message content, removing XML and system messages.
|
||||
|
||||
Args:
|
||||
content: The raw message content (can be string or dict)
|
||||
|
||||
Returns:
|
||||
Cleaned, readable text
|
||||
"""
|
||||
if not content:
|
||||
return ""
|
||||
|
||||
# Handle dictionary content (like quoted messages)
|
||||
if isinstance(content, dict):
|
||||
# Extract text from dictionary structure
|
||||
text_parts = []
|
||||
if 'title' in content:
|
||||
text_parts.append(str(content['title']))
|
||||
if 'quoted' in content:
|
||||
text_parts.append(str(content['quoted']))
|
||||
if 'content' in content:
|
||||
text_parts.append(str(content['content']))
|
||||
if 'text' in content:
|
||||
text_parts.append(str(content['text']))
|
||||
|
||||
if text_parts:
|
||||
return " | ".join(text_parts)
|
||||
else:
|
||||
# If we can't extract meaningful text from dict, return empty
|
||||
return ""
|
||||
|
||||
# Handle string content
|
||||
if not isinstance(content, str):
|
||||
return ""
|
||||
|
||||
# Remove common prefixes like "wxid_xxx:\n"
|
||||
clean_content = re.sub(r'^wxid_[^:]+:\s*', '', content)
|
||||
clean_content = re.sub(r'^[^:]+:\s*', '', clean_content)
|
||||
|
||||
# If it's just XML or system message, return empty
|
||||
if clean_content.strip().startswith('<') or 'recalled a message' in clean_content:
|
||||
return ""
|
||||
|
||||
return clean_content.strip()
|
||||
|
||||
def _is_text_message(self, content: str) -> bool:
|
||||
"""
|
||||
Check if a message contains readable text content.
|
||||
|
||||
Args:
|
||||
content: The message content (can be string or dict)
|
||||
|
||||
Returns:
|
||||
True if the message contains readable text, False otherwise
|
||||
"""
|
||||
if not content:
|
||||
return False
|
||||
|
||||
# Handle dictionary content
|
||||
if isinstance(content, dict):
|
||||
# Check if dict has any readable text fields
|
||||
text_fields = ['title', 'quoted', 'content', 'text']
|
||||
for field in text_fields:
|
||||
if field in content and content[field]:
|
||||
return True
|
||||
return False
|
||||
|
||||
# Handle string content
|
||||
if not isinstance(content, str):
|
||||
return False
|
||||
|
||||
# Skip image messages (contain XML with img tags)
|
||||
if '<img' in content and 'cdnurl' in content:
|
||||
return False
|
||||
|
||||
# Skip emoji messages (contain emoji XML tags)
|
||||
if '<emoji' in content and 'productid' in content:
|
||||
return False
|
||||
|
||||
# Skip voice messages
|
||||
if '<voice' in content:
|
||||
return False
|
||||
|
||||
# Skip video messages
|
||||
if '<video' in content:
|
||||
return False
|
||||
|
||||
# Skip file messages
|
||||
if '<appmsg' in content and 'appid' in content:
|
||||
return False
|
||||
|
||||
# Skip system messages (like "recalled a message")
|
||||
if 'recalled a message' in content:
|
||||
return False
|
||||
|
||||
# Check if there's actual readable text (not just XML or system messages)
|
||||
# Remove common prefixes like "wxid_xxx:\n" and check for actual content
|
||||
clean_content = re.sub(r'^wxid_[^:]+:\s*', '', content)
|
||||
clean_content = re.sub(r'^[^:]+:\s*', '', clean_content)
|
||||
|
||||
# If after cleaning we have meaningful text, consider it readable
|
||||
if len(clean_content.strip()) > 0 and not clean_content.strip().startswith('<'):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _concatenate_messages(self, messages: List[Dict], max_length: int = 128,
|
||||
time_window_minutes: int = 30, overlap_messages: int = 0) -> List[Dict]:
|
||||
"""
|
||||
Concatenate messages based on length and time rules.
|
||||
|
||||
Args:
|
||||
messages: List of message dictionaries
|
||||
max_length: Maximum length for concatenated message groups. Use -1 to disable length constraint.
|
||||
time_window_minutes: Time window in minutes to group messages together. Use -1 to disable time constraint.
|
||||
overlap_messages: Number of messages to overlap between consecutive groups
|
||||
|
||||
Returns:
|
||||
List of concatenated message groups
|
||||
"""
|
||||
if not messages:
|
||||
return []
|
||||
|
||||
concatenated_groups = []
|
||||
current_group = []
|
||||
current_length = 0
|
||||
last_timestamp = None
|
||||
|
||||
for message in messages:
|
||||
# Extract message info
|
||||
content = message.get('content', '')
|
||||
message_text = message.get('message', '')
|
||||
create_time = message.get('createTime', 0)
|
||||
from_user = message.get('fromUser', '')
|
||||
to_user = message.get('toUser', '')
|
||||
is_sent_from_self = message.get('isSentFromSelf', False)
|
||||
|
||||
# Extract readable text
|
||||
readable_text = self._extract_readable_text(content)
|
||||
if not readable_text:
|
||||
readable_text = message_text
|
||||
|
||||
# Skip empty messages
|
||||
if not readable_text.strip():
|
||||
continue
|
||||
|
||||
# Check time window constraint (only if time_window_minutes != -1)
|
||||
if time_window_minutes != -1 and last_timestamp is not None and create_time > 0:
|
||||
time_diff_minutes = (create_time - last_timestamp) / 60
|
||||
if time_diff_minutes > time_window_minutes:
|
||||
# Time gap too large, start new group
|
||||
if current_group:
|
||||
concatenated_groups.append({
|
||||
'messages': current_group,
|
||||
'total_length': current_length,
|
||||
'start_time': current_group[0].get('createTime', 0),
|
||||
'end_time': current_group[-1].get('createTime', 0)
|
||||
})
|
||||
# Keep last few messages for overlap
|
||||
if overlap_messages > 0 and len(current_group) > overlap_messages:
|
||||
current_group = current_group[-overlap_messages:]
|
||||
current_length = sum(len(self._extract_readable_text(msg.get('content', '')) or msg.get('message', '')) for msg in current_group)
|
||||
else:
|
||||
current_group = []
|
||||
current_length = 0
|
||||
|
||||
# Check length constraint (only if max_length != -1)
|
||||
message_length = len(readable_text)
|
||||
if max_length != -1 and current_length + message_length > max_length and current_group:
|
||||
# Current group would exceed max length, save it and start new
|
||||
concatenated_groups.append({
|
||||
'messages': current_group,
|
||||
'total_length': current_length,
|
||||
'start_time': current_group[0].get('createTime', 0),
|
||||
'end_time': current_group[-1].get('createTime', 0)
|
||||
})
|
||||
# Keep last few messages for overlap
|
||||
if overlap_messages > 0 and len(current_group) > overlap_messages:
|
||||
current_group = current_group[-overlap_messages:]
|
||||
current_length = sum(len(self._extract_readable_text(msg.get('content', '')) or msg.get('message', '')) for msg in current_group)
|
||||
else:
|
||||
current_group = []
|
||||
current_length = 0
|
||||
|
||||
# Add message to current group
|
||||
current_group.append(message)
|
||||
current_length += message_length
|
||||
last_timestamp = create_time
|
||||
|
||||
# Add the last group if it exists
|
||||
if current_group:
|
||||
concatenated_groups.append({
|
||||
'messages': current_group,
|
||||
'total_length': current_length,
|
||||
'start_time': current_group[0].get('createTime', 0),
|
||||
'end_time': current_group[-1].get('createTime', 0)
|
||||
})
|
||||
|
||||
return concatenated_groups
|
||||
|
||||
def _create_concatenated_content(self, message_group: Dict, contact_name: str) -> str:
|
||||
"""
|
||||
Create concatenated content from a group of messages.
|
||||
|
||||
Args:
|
||||
message_group: Dictionary containing messages and metadata
|
||||
contact_name: Name of the contact
|
||||
|
||||
Returns:
|
||||
Formatted concatenated content
|
||||
"""
|
||||
messages = message_group['messages']
|
||||
start_time = message_group['start_time']
|
||||
end_time = message_group['end_time']
|
||||
|
||||
# Format timestamps
|
||||
if start_time:
|
||||
try:
|
||||
start_timestamp = datetime.fromtimestamp(start_time)
|
||||
start_time_str = start_timestamp.strftime('%Y-%m-%d %H:%M:%S')
|
||||
except:
|
||||
start_time_str = str(start_time)
|
||||
else:
|
||||
start_time_str = "Unknown"
|
||||
|
||||
if end_time:
|
||||
try:
|
||||
end_timestamp = datetime.fromtimestamp(end_time)
|
||||
end_time_str = end_timestamp.strftime('%Y-%m-%d %H:%M:%S')
|
||||
except:
|
||||
end_time_str = str(end_time)
|
||||
else:
|
||||
end_time_str = "Unknown"
|
||||
|
||||
# Build concatenated message content
|
||||
message_parts = []
|
||||
for message in messages:
|
||||
content = message.get('content', '')
|
||||
message_text = message.get('message', '')
|
||||
create_time = message.get('createTime', 0)
|
||||
is_sent_from_self = message.get('isSentFromSelf', False)
|
||||
|
||||
# Extract readable text
|
||||
readable_text = self._extract_readable_text(content)
|
||||
if not readable_text:
|
||||
readable_text = message_text
|
||||
|
||||
# Format individual message
|
||||
if create_time:
|
||||
try:
|
||||
timestamp = datetime.fromtimestamp(create_time)
|
||||
# change to YYYY-MM-DD HH:MM:SS
|
||||
time_str = timestamp.strftime('%Y-%m-%d %H:%M:%S')
|
||||
except:
|
||||
time_str = str(create_time)
|
||||
else:
|
||||
time_str = "Unknown"
|
||||
|
||||
sender = "[Me]" if is_sent_from_self else "[Contact]"
|
||||
message_parts.append(f"({time_str}) {sender}: {readable_text}")
|
||||
|
||||
concatenated_text = "\n".join(message_parts)
|
||||
|
||||
# Create final document content
|
||||
doc_content = f"""
|
||||
Contact: {contact_name}
|
||||
Time Range: {start_time_str} - {end_time_str}
|
||||
Messages ({len(messages)} messages, {message_group['total_length']} chars):
|
||||
|
||||
{concatenated_text}
|
||||
"""
|
||||
# TODO @yichuan give better format and rich info here!
|
||||
doc_content = f"""
|
||||
{concatenated_text}
|
||||
"""
|
||||
return doc_content, contact_name
|
||||
|
||||
def load_data(self, input_dir: str = None, **load_kwargs: Any) -> List[Document]:
|
||||
"""
|
||||
Load WeChat chat history data from exported JSON files.
|
||||
|
||||
Args:
|
||||
input_dir: Directory containing exported WeChat JSON files
|
||||
**load_kwargs:
|
||||
max_count (int): Maximum amount of chat entries to read.
|
||||
wechat_export_dir (str): Custom path to WeChat export directory.
|
||||
include_non_text (bool): Whether to include non-text messages (images, emojis, etc.)
|
||||
concatenate_messages (bool): Whether to concatenate messages based on length rules.
|
||||
max_length (int): Maximum length for concatenated message groups (default: 1000).
|
||||
time_window_minutes (int): Time window in minutes to group messages together (default: 30).
|
||||
overlap_messages (int): Number of messages to overlap between consecutive groups (default: 2).
|
||||
"""
|
||||
docs: List[Document] = []
|
||||
max_count = load_kwargs.get('max_count', 1000)
|
||||
wechat_export_dir = load_kwargs.get('wechat_export_dir', None)
|
||||
include_non_text = load_kwargs.get('include_non_text', False)
|
||||
concatenate_messages = load_kwargs.get('concatenate_messages', False)
|
||||
max_length = load_kwargs.get('max_length', 1000)
|
||||
time_window_minutes = load_kwargs.get('time_window_minutes', 30)
|
||||
|
||||
# Default WeChat export path
|
||||
if wechat_export_dir is None:
|
||||
wechat_export_dir = "./wechat_export_test"
|
||||
|
||||
if not os.path.exists(wechat_export_dir):
|
||||
print(f"WeChat export directory not found at: {wechat_export_dir}")
|
||||
return docs
|
||||
|
||||
try:
|
||||
# Find all JSON files in the export directory
|
||||
json_files = list(Path(wechat_export_dir).glob("*.json"))
|
||||
print(f"Found {len(json_files)} WeChat chat history files")
|
||||
|
||||
count = 0
|
||||
for json_file in json_files:
|
||||
if count >= max_count and max_count > 0:
|
||||
break
|
||||
|
||||
try:
|
||||
with open(json_file, 'r', encoding='utf-8') as f:
|
||||
chat_data = json.load(f)
|
||||
|
||||
# Extract contact name from filename
|
||||
contact_name = json_file.stem
|
||||
|
||||
if concatenate_messages:
|
||||
# Filter messages to only include readable text messages
|
||||
readable_messages = []
|
||||
for message in chat_data:
|
||||
try:
|
||||
content = message.get('content', '')
|
||||
if not include_non_text and not self._is_text_message(content):
|
||||
continue
|
||||
|
||||
readable_text = self._extract_readable_text(content)
|
||||
if not readable_text and not include_non_text:
|
||||
continue
|
||||
|
||||
readable_messages.append(message)
|
||||
except Exception as e:
|
||||
print(f"Error processing message in {json_file}: {e}")
|
||||
continue
|
||||
|
||||
# Concatenate messages based on rules
|
||||
message_groups = self._concatenate_messages(
|
||||
readable_messages,
|
||||
max_length=-1,
|
||||
time_window_minutes=-1,
|
||||
overlap_messages=0 # Keep 2 messages overlap between groups
|
||||
)
|
||||
|
||||
# Create documents from concatenated groups
|
||||
for message_group in message_groups:
|
||||
if count >= max_count and max_count > 0:
|
||||
break
|
||||
|
||||
doc_content, contact_name = self._create_concatenated_content(message_group, contact_name)
|
||||
doc = Document(text=doc_content, metadata={"contact_name": contact_name})
|
||||
docs.append(doc)
|
||||
count += 1
|
||||
|
||||
print(f"Created {len(message_groups)} concatenated message groups for {contact_name}")
|
||||
|
||||
else:
|
||||
# Original single-message processing
|
||||
for message in chat_data:
|
||||
if count >= max_count and max_count > 0:
|
||||
break
|
||||
|
||||
# Extract message information
|
||||
from_user = message.get('fromUser', '')
|
||||
to_user = message.get('toUser', '')
|
||||
content = message.get('content', '')
|
||||
message_text = message.get('message', '')
|
||||
create_time = message.get('createTime', 0)
|
||||
is_sent_from_self = message.get('isSentFromSelf', False)
|
||||
|
||||
# Handle content that might be dict or string
|
||||
try:
|
||||
# Check if this is a readable text message
|
||||
if not include_non_text and not self._is_text_message(content):
|
||||
continue
|
||||
|
||||
# Extract readable text
|
||||
readable_text = self._extract_readable_text(content)
|
||||
if not readable_text and not include_non_text:
|
||||
continue
|
||||
except Exception as e:
|
||||
# Skip messages that cause processing errors
|
||||
print(f"Error processing message in {json_file}: {e}")
|
||||
continue
|
||||
|
||||
# Convert timestamp to readable format
|
||||
if create_time:
|
||||
try:
|
||||
timestamp = datetime.fromtimestamp(create_time)
|
||||
time_str = timestamp.strftime('%Y-%m-%d %H:%M:%S')
|
||||
except:
|
||||
time_str = str(create_time)
|
||||
else:
|
||||
time_str = "Unknown"
|
||||
|
||||
# Create document content with metadata header and contact info
|
||||
doc_content = f"""
|
||||
Contact: {contact_name}
|
||||
Is sent from self: {is_sent_from_self}
|
||||
Time: {time_str}
|
||||
Message: {readable_text if readable_text else message_text}
|
||||
"""
|
||||
|
||||
# Create document with embedded metadata
|
||||
doc = Document(text=doc_content, metadata={})
|
||||
docs.append(doc)
|
||||
count += 1
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error reading {json_file}: {e}")
|
||||
continue
|
||||
|
||||
print(f"Loaded {len(docs)} WeChat chat documents")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error reading WeChat history: {e}")
|
||||
return docs
|
||||
|
||||
return docs
|
||||
|
||||
@staticmethod
|
||||
def find_wechat_export_dirs() -> List[Path]:
|
||||
"""
|
||||
Find all WeChat export directories.
|
||||
|
||||
Returns:
|
||||
List of Path objects pointing to WeChat export directories
|
||||
"""
|
||||
export_dirs = []
|
||||
|
||||
# Look for common export directory names
|
||||
possible_dirs = [
|
||||
Path("./wechat_export_test"),
|
||||
Path("./wechat_export"),
|
||||
Path("./wechat_chat_history"),
|
||||
Path("./chat_export")
|
||||
]
|
||||
|
||||
for export_dir in possible_dirs:
|
||||
if export_dir.exists() and export_dir.is_dir():
|
||||
json_files = list(export_dir.glob("*.json"))
|
||||
if json_files:
|
||||
export_dirs.append(export_dir)
|
||||
print(f"Found WeChat export directory: {export_dir} with {len(json_files)} files")
|
||||
|
||||
print(f"Found {len(export_dirs)} WeChat export directories")
|
||||
return export_dirs
|
||||
|
||||
@staticmethod
|
||||
def export_chat_to_file(output_file: str = "wechat_chat_export.txt", max_count: int = 1000, export_dir: str = None, include_non_text: bool = False):
|
||||
"""
|
||||
Export WeChat chat history to a text file.
|
||||
|
||||
Args:
|
||||
output_file: Path to the output file
|
||||
max_count: Maximum number of entries to export
|
||||
export_dir: Directory containing WeChat JSON files
|
||||
include_non_text: Whether to include non-text messages
|
||||
"""
|
||||
if export_dir is None:
|
||||
export_dir = "./wechat_export_test"
|
||||
|
||||
if not os.path.exists(export_dir):
|
||||
print(f"WeChat export directory not found at: {export_dir}")
|
||||
return
|
||||
|
||||
try:
|
||||
json_files = list(Path(export_dir).glob("*.json"))
|
||||
|
||||
with open(output_file, 'w', encoding='utf-8') as f:
|
||||
count = 0
|
||||
for json_file in json_files:
|
||||
if count >= max_count and max_count > 0:
|
||||
break
|
||||
|
||||
try:
|
||||
with open(json_file, 'r', encoding='utf-8') as json_f:
|
||||
chat_data = json.load(json_f)
|
||||
|
||||
contact_name = json_file.stem
|
||||
f.write(f"\n=== Chat with {contact_name} ===\n")
|
||||
|
||||
for message in chat_data:
|
||||
if count >= max_count and max_count > 0:
|
||||
break
|
||||
|
||||
from_user = message.get('fromUser', '')
|
||||
content = message.get('content', '')
|
||||
message_text = message.get('message', '')
|
||||
create_time = message.get('createTime', 0)
|
||||
|
||||
# Skip non-text messages unless requested
|
||||
if not include_non_text:
|
||||
reader = WeChatHistoryReader()
|
||||
if not reader._is_text_message(content):
|
||||
continue
|
||||
readable_text = reader._extract_readable_text(content)
|
||||
if not readable_text:
|
||||
continue
|
||||
message_text = readable_text
|
||||
|
||||
if create_time:
|
||||
try:
|
||||
timestamp = datetime.fromtimestamp(create_time)
|
||||
time_str = timestamp.strftime('%Y-%m-%d %H:%M:%S')
|
||||
except:
|
||||
time_str = str(create_time)
|
||||
else:
|
||||
time_str = "Unknown"
|
||||
|
||||
f.write(f"[{time_str}] {from_user}: {message_text}\n")
|
||||
count += 1
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error processing {json_file}: {e}")
|
||||
continue
|
||||
|
||||
print(f"Exported {count} chat entries to {output_file}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error exporting WeChat chat history: {e}")
|
||||
|
||||
def export_wechat_chat_history(self, export_dir: str = "./wechat_export_direct") -> Optional[Path]:
|
||||
"""
|
||||
Export WeChat chat history using wechat-exporter tool.
|
||||
|
||||
Args:
|
||||
export_dir: Directory to save exported chat history
|
||||
|
||||
Returns:
|
||||
Path to export directory if successful, None otherwise
|
||||
"""
|
||||
try:
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
# Create export directory
|
||||
export_path = Path(export_dir)
|
||||
export_path.mkdir(exist_ok=True)
|
||||
|
||||
print(f"Exporting WeChat chat history to {export_path}...")
|
||||
|
||||
# Check if wechat-exporter directory exists
|
||||
if not self.wechat_exporter_dir.exists():
|
||||
print(f"wechat-exporter directory not found at: {self.wechat_exporter_dir}")
|
||||
return None
|
||||
|
||||
# Install requirements if needed
|
||||
requirements_file = self.wechat_exporter_dir / "requirements.txt"
|
||||
if requirements_file.exists():
|
||||
print("Installing wechat-exporter requirements...")
|
||||
subprocess.run([
|
||||
"uv", "pip", "install", "-r", str(requirements_file)
|
||||
], check=True)
|
||||
|
||||
# Run the export command
|
||||
print("Running wechat-exporter...")
|
||||
result = subprocess.run([
|
||||
sys.executable, str(self.wechat_exporter_dir / "main.py"),
|
||||
"export-all", str(export_path)
|
||||
], capture_output=True, text=True, check=True)
|
||||
|
||||
print("Export command output:")
|
||||
print(result.stdout)
|
||||
if result.stderr:
|
||||
print("Export errors:")
|
||||
print(result.stderr)
|
||||
|
||||
# Check if export was successful
|
||||
if export_path.exists() and any(export_path.glob("*.json")):
|
||||
json_files = list(export_path.glob("*.json"))
|
||||
print(f"Successfully exported {len(json_files)} chat history files to {export_path}")
|
||||
return export_path
|
||||
else:
|
||||
print("Export completed but no JSON files found")
|
||||
return None
|
||||
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(f"Export command failed: {e}")
|
||||
print(f"Command output: {e.stdout}")
|
||||
print(f"Command errors: {e.stderr}")
|
||||
return None
|
||||
except Exception as e:
|
||||
print(f"Export failed: {e}")
|
||||
print("Please ensure WeChat is running and WeChatTweak is installed.")
|
||||
return None
|
||||
|
||||
def find_or_export_wechat_data(self, export_dir: str = "./wechat_export_direct") -> List[Path]:
|
||||
"""
|
||||
Find existing WeChat exports or create new ones.
|
||||
|
||||
Args:
|
||||
export_dir: Directory to save exported chat history if needed
|
||||
|
||||
Returns:
|
||||
List of Path objects pointing to WeChat export directories
|
||||
"""
|
||||
export_dirs = []
|
||||
|
||||
# Look for existing exports in common locations
|
||||
possible_export_dirs = [
|
||||
Path("./wechat_database_export"),
|
||||
Path("./wechat_export_test"),
|
||||
Path("./wechat_export"),
|
||||
Path("./wechat_export_direct"),
|
||||
Path("./wechat_chat_history"),
|
||||
Path("./chat_export")
|
||||
]
|
||||
|
||||
for export_dir_path in possible_export_dirs:
|
||||
if export_dir_path.exists() and any(export_dir_path.glob("*.json")):
|
||||
export_dirs.append(export_dir_path)
|
||||
print(f"Found existing export: {export_dir_path}")
|
||||
|
||||
# If no existing exports, try to export automatically
|
||||
if not export_dirs:
|
||||
print("No existing WeChat exports found. Starting direct export...")
|
||||
|
||||
# Try to export using wechat-exporter
|
||||
exported_path = self.export_wechat_chat_history(export_dir)
|
||||
if exported_path:
|
||||
export_dirs = [exported_path]
|
||||
else:
|
||||
print("Failed to export WeChat data. Please ensure WeChat is running and WeChatTweak is installed.")
|
||||
|
||||
return export_dirs
|
||||
189
apps/wechat_rag.py
Normal file
189
apps/wechat_rag.py
Normal file
@@ -0,0 +1,189 @@
|
||||
"""
|
||||
WeChat History RAG example using the unified interface.
|
||||
Supports WeChat chat history export and search.
|
||||
"""
|
||||
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Add parent directory to path for imports
|
||||
sys.path.insert(0, str(Path(__file__).parent))
|
||||
|
||||
from base_rag_example import BaseRAGExample
|
||||
|
||||
from .history_data.wechat_history import WeChatHistoryReader
|
||||
|
||||
|
||||
class WeChatRAG(BaseRAGExample):
|
||||
"""RAG example for WeChat chat history."""
|
||||
|
||||
def __init__(self):
|
||||
# Set default values BEFORE calling super().__init__
|
||||
self.max_items_default = -1 # Match original default
|
||||
self.embedding_model_default = (
|
||||
"sentence-transformers/all-MiniLM-L6-v2" # Fast 384-dim model
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
name="WeChat History",
|
||||
description="Process and query WeChat chat history with LEANN",
|
||||
default_index_name="wechat_history_magic_test_11Debug_new",
|
||||
)
|
||||
|
||||
def _add_specific_arguments(self, parser):
|
||||
"""Add WeChat-specific arguments."""
|
||||
wechat_group = parser.add_argument_group("WeChat Parameters")
|
||||
wechat_group.add_argument(
|
||||
"--export-dir",
|
||||
type=str,
|
||||
default="./wechat_export",
|
||||
help="Directory to store WeChat exports (default: ./wechat_export)",
|
||||
)
|
||||
wechat_group.add_argument(
|
||||
"--force-export",
|
||||
action="store_true",
|
||||
help="Force re-export of WeChat data even if exports exist",
|
||||
)
|
||||
wechat_group.add_argument(
|
||||
"--chunk-size", type=int, default=192, help="Text chunk size (default: 192)"
|
||||
)
|
||||
wechat_group.add_argument(
|
||||
"--chunk-overlap", type=int, default=64, help="Text chunk overlap (default: 64)"
|
||||
)
|
||||
|
||||
def _export_wechat_data(self, export_dir: Path) -> bool:
|
||||
"""Export WeChat data using wechattweak-cli."""
|
||||
print("Exporting WeChat data...")
|
||||
|
||||
# Check if WeChat is running
|
||||
try:
|
||||
result = subprocess.run(["pgrep", "WeChat"], capture_output=True, text=True)
|
||||
if result.returncode != 0:
|
||||
print("WeChat is not running. Please start WeChat first.")
|
||||
return False
|
||||
except Exception:
|
||||
pass # pgrep might not be available on all systems
|
||||
|
||||
# Create export directory
|
||||
export_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Run export command
|
||||
cmd = ["packages/wechat-exporter/wechattweak-cli", "export", str(export_dir)]
|
||||
|
||||
try:
|
||||
print(f"Running: {' '.join(cmd)}")
|
||||
result = subprocess.run(cmd, capture_output=True, text=True)
|
||||
|
||||
if result.returncode == 0:
|
||||
print("WeChat data exported successfully!")
|
||||
return True
|
||||
else:
|
||||
print(f"Export failed: {result.stderr}")
|
||||
return False
|
||||
|
||||
except FileNotFoundError:
|
||||
print("\nError: wechattweak-cli not found!")
|
||||
print("Please install it first:")
|
||||
print(" sudo packages/wechat-exporter/wechattweak-cli install")
|
||||
return False
|
||||
except Exception as e:
|
||||
print(f"Export error: {e}")
|
||||
return False
|
||||
|
||||
async def load_data(self, args) -> list[str]:
|
||||
"""Load WeChat history and convert to text chunks."""
|
||||
# Initialize WeChat reader with export capabilities
|
||||
reader = WeChatHistoryReader()
|
||||
|
||||
# Find existing exports or create new ones using the centralized method
|
||||
export_dirs = reader.find_or_export_wechat_data(args.export_dir)
|
||||
if not export_dirs:
|
||||
print("Failed to find or export WeChat data. Trying to find any existing exports...")
|
||||
# Try to find any existing exports in common locations
|
||||
export_dirs = reader.find_wechat_export_dirs()
|
||||
if not export_dirs:
|
||||
print("No WeChat data found. Please ensure WeChat exports exist.")
|
||||
return []
|
||||
|
||||
# Load documents from all found export directories
|
||||
all_documents = []
|
||||
total_processed = 0
|
||||
|
||||
for i, export_dir in enumerate(export_dirs):
|
||||
print(f"\nProcessing WeChat export {i + 1}/{len(export_dirs)}: {export_dir}")
|
||||
|
||||
try:
|
||||
# Apply max_items limit per export
|
||||
max_per_export = -1
|
||||
if args.max_items > 0:
|
||||
remaining = args.max_items - total_processed
|
||||
if remaining <= 0:
|
||||
break
|
||||
max_per_export = remaining
|
||||
|
||||
documents = reader.load_data(
|
||||
wechat_export_dir=str(export_dir),
|
||||
max_count=max_per_export,
|
||||
concatenate_messages=True, # Enable message concatenation for better context
|
||||
)
|
||||
|
||||
if documents:
|
||||
print(f"Loaded {len(documents)} chat documents from {export_dir}")
|
||||
all_documents.extend(documents)
|
||||
total_processed += len(documents)
|
||||
else:
|
||||
print(f"No documents loaded from {export_dir}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error processing {export_dir}: {e}")
|
||||
continue
|
||||
|
||||
if not all_documents:
|
||||
print("No documents loaded from any source. Exiting.")
|
||||
return []
|
||||
|
||||
print(f"\nTotal loaded {len(all_documents)} chat documents from {len(export_dirs)} exports")
|
||||
print("now starting to split into text chunks ... take some time")
|
||||
|
||||
# Convert to text chunks with contact information
|
||||
all_texts = []
|
||||
for doc in all_documents:
|
||||
# Split the document into chunks
|
||||
from llama_index.core.node_parser import SentenceSplitter
|
||||
|
||||
text_splitter = SentenceSplitter(
|
||||
chunk_size=args.chunk_size, chunk_overlap=args.chunk_overlap
|
||||
)
|
||||
nodes = text_splitter.get_nodes_from_documents([doc])
|
||||
|
||||
for node in nodes:
|
||||
# Add contact information to each chunk
|
||||
contact_name = doc.metadata.get("contact_name", "Unknown")
|
||||
text = f"[Contact] means the message is from: {contact_name}\n" + node.get_content()
|
||||
all_texts.append(text)
|
||||
|
||||
print(f"Created {len(all_texts)} text chunks from {len(all_documents)} documents")
|
||||
return all_texts
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
||||
# Check platform
|
||||
if sys.platform != "darwin":
|
||||
print("\n⚠️ Warning: WeChat export is only supported on macOS")
|
||||
print(" You can still query existing exports on other platforms\n")
|
||||
|
||||
# Example queries for WeChat RAG
|
||||
print("\n💬 WeChat History RAG Example")
|
||||
print("=" * 50)
|
||||
print("\nExample queries you can try:")
|
||||
print("- 'Show me conversations about travel plans'")
|
||||
print("- 'Find group chats about weekend activities'")
|
||||
print("- '我想买魔术师约翰逊的球衣,给我一些对应聊天记录?'")
|
||||
print("- 'What did we discuss about the project last month?'")
|
||||
print("\nNote: WeChat must be running for export to work\n")
|
||||
|
||||
rag = WeChatRAG()
|
||||
asyncio.run(rag.run())
|
||||
BIN
assets/claude_code_leann.png
Normal file
BIN
assets/claude_code_leann.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 73 KiB |
BIN
assets/mcp_leann.png
Normal file
BIN
assets/mcp_leann.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 224 KiB |
@@ -1,13 +1,28 @@
|
||||
# 🧪 Leann Sanity Checks
|
||||
# 🧪 LEANN Benchmarks & Testing
|
||||
|
||||
This directory contains comprehensive sanity checks for the Leann system, ensuring all components work correctly across different configurations.
|
||||
This directory contains performance benchmarks and comprehensive tests for the LEANN system, including backend comparisons and sanity checks across different configurations.
|
||||
|
||||
## 📁 Test Files
|
||||
|
||||
### `diskann_vs_hnsw_speed_comparison.py`
|
||||
Performance comparison between DiskANN and HNSW backends:
|
||||
- ✅ **Search latency** comparison with both backends using recompute
|
||||
- ✅ **Index size** and **build time** measurements
|
||||
- ✅ **Score validity** testing (ensures no -inf scores)
|
||||
- ✅ **Configurable dataset sizes** for different scales
|
||||
|
||||
```bash
|
||||
# Quick comparison with 500 docs, 10 queries
|
||||
python benchmarks/diskann_vs_hnsw_speed_comparison.py
|
||||
|
||||
# Large-scale comparison with 2000 docs, 20 queries
|
||||
python benchmarks/diskann_vs_hnsw_speed_comparison.py 2000 20
|
||||
```
|
||||
|
||||
### `test_distance_functions.py`
|
||||
Tests all supported distance functions across DiskANN backend:
|
||||
- ✅ **MIPS** (Maximum Inner Product Search)
|
||||
- ✅ **L2** (Euclidean Distance)
|
||||
- ✅ **L2** (Euclidean Distance)
|
||||
- ✅ **Cosine** (Cosine Similarity)
|
||||
|
||||
```bash
|
||||
@@ -27,7 +42,7 @@ uv run python tests/sanity_checks/test_l2_verification.py
|
||||
### `test_sanity_check.py`
|
||||
Comprehensive end-to-end verification including:
|
||||
- Distance function testing
|
||||
- Embedding model compatibility
|
||||
- Embedding model compatibility
|
||||
- Search result correctness validation
|
||||
- Backend integration testing
|
||||
|
||||
@@ -64,7 +79,7 @@ When all tests pass, you should see:
|
||||
```
|
||||
📊 测试结果总结:
|
||||
mips : ✅ 通过
|
||||
l2 : ✅ 通过
|
||||
l2 : ✅ 通过
|
||||
cosine : ✅ 通过
|
||||
|
||||
🎉 测试完成!
|
||||
@@ -98,7 +113,7 @@ pkill -f "embedding_server"
|
||||
|
||||
### Typical Timing (3 documents, consumer hardware):
|
||||
- **Index Building**: 2-5 seconds per distance function
|
||||
- **Search Query**: 50-200ms
|
||||
- **Search Query**: 50-200ms
|
||||
- **Recompute Mode**: 5-15 seconds (higher accuracy)
|
||||
|
||||
### Memory Usage:
|
||||
@@ -117,4 +132,4 @@ These tests are designed to be run in automated environments:
|
||||
uv run python tests/sanity_checks/test_l2_verification.py
|
||||
```
|
||||
|
||||
The tests are deterministic and should produce consistent results across different platforms.
|
||||
The tests are deterministic and should produce consistent results across different platforms.
|
||||
@@ -1,43 +1,46 @@
|
||||
import time
|
||||
import numpy as np
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import torch
|
||||
from sentence_transformers import SentenceTransformer
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
import torch
|
||||
from mlx_lm import load
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
# --- Configuration ---
|
||||
MODEL_NAME_TORCH = "Qwen/Qwen3-Embedding-0.6B"
|
||||
MODEL_NAME_MLX = "mlx-community/Qwen3-Embedding-0.6B-4bit-DWQ"
|
||||
BATCH_SIZES = [1, 8, 16, 32, 64, 128]
|
||||
NUM_RUNS = 10 # Number of runs to average for each batch size
|
||||
WARMUP_RUNS = 2 # Number of warm-up runs
|
||||
WARMUP_RUNS = 2 # Number of warm-up runs
|
||||
|
||||
# --- Generate Dummy Data ---
|
||||
DUMMY_SENTENCES = ["This is a test sentence for benchmarking." * 5] * max(BATCH_SIZES)
|
||||
|
||||
# --- Benchmark Functions ---b
|
||||
|
||||
|
||||
def benchmark_torch(model, sentences):
|
||||
start_time = time.time()
|
||||
model.encode(sentences, convert_to_numpy=True)
|
||||
end_time = time.time()
|
||||
return (end_time - start_time) * 1000 # Return time in ms
|
||||
|
||||
|
||||
def benchmark_mlx(model, tokenizer, sentences):
|
||||
start_time = time.time()
|
||||
|
||||
|
||||
# Tokenize sentences using MLX tokenizer
|
||||
tokens = []
|
||||
for sentence in sentences:
|
||||
token_ids = tokenizer.encode(sentence)
|
||||
tokens.append(token_ids)
|
||||
|
||||
|
||||
# Pad sequences to the same length
|
||||
max_len = max(len(t) for t in tokens)
|
||||
input_ids = []
|
||||
attention_mask = []
|
||||
|
||||
|
||||
for token_seq in tokens:
|
||||
# Pad sequence
|
||||
padded = token_seq + [tokenizer.eos_token_id] * (max_len - len(token_seq))
|
||||
@@ -45,24 +48,25 @@ def benchmark_mlx(model, tokenizer, sentences):
|
||||
# Create attention mask (1 for real tokens, 0 for padding)
|
||||
mask = [1] * len(token_seq) + [0] * (max_len - len(token_seq))
|
||||
attention_mask.append(mask)
|
||||
|
||||
|
||||
# Convert to MLX arrays
|
||||
input_ids = mx.array(input_ids)
|
||||
attention_mask = mx.array(attention_mask)
|
||||
|
||||
|
||||
# Get embeddings
|
||||
embeddings = model(input_ids)
|
||||
|
||||
|
||||
# Mean pooling
|
||||
mask = mx.expand_dims(attention_mask, -1)
|
||||
sum_embeddings = (embeddings * mask).sum(axis=1)
|
||||
sum_mask = mask.sum(axis=1)
|
||||
_ = sum_embeddings / sum_mask
|
||||
|
||||
|
||||
mx.eval() # Ensure computation is finished
|
||||
end_time = time.time()
|
||||
return (end_time - start_time) * 1000 # Return time in ms
|
||||
|
||||
|
||||
# --- Main Execution ---
|
||||
def main():
|
||||
print("--- Initializing Models ---")
|
||||
@@ -92,13 +96,15 @@ def main():
|
||||
for batch_size in BATCH_SIZES:
|
||||
print(f"Benchmarking batch size: {batch_size}")
|
||||
sentences_batch = DUMMY_SENTENCES[:batch_size]
|
||||
|
||||
|
||||
# Benchmark PyTorch
|
||||
torch_times = [benchmark_torch(model_torch, sentences_batch) for _ in range(NUM_RUNS)]
|
||||
results_torch.append(np.mean(torch_times))
|
||||
|
||||
|
||||
# Benchmark MLX
|
||||
mlx_times = [benchmark_mlx(model_mlx, tokenizer_mlx, sentences_batch) for _ in range(NUM_RUNS)]
|
||||
mlx_times = [
|
||||
benchmark_mlx(model_mlx, tokenizer_mlx, sentences_batch) for _ in range(NUM_RUNS)
|
||||
]
|
||||
results_mlx.append(np.mean(mlx_times))
|
||||
|
||||
print("\n--- Benchmark Results (Average time per batch in ms) ---")
|
||||
@@ -109,20 +115,27 @@ def main():
|
||||
# --- Plotting ---
|
||||
print("\n--- Generating Plot ---")
|
||||
plt.figure(figsize=(10, 6))
|
||||
plt.plot(BATCH_SIZES, results_torch, marker='o', linestyle='-', label=f'PyTorch ({device})')
|
||||
plt.plot(BATCH_SIZES, results_mlx, marker='s', linestyle='-', label='MLX')
|
||||
plt.plot(
|
||||
BATCH_SIZES,
|
||||
results_torch,
|
||||
marker="o",
|
||||
linestyle="-",
|
||||
label=f"PyTorch ({device})",
|
||||
)
|
||||
plt.plot(BATCH_SIZES, results_mlx, marker="s", linestyle="-", label="MLX")
|
||||
|
||||
plt.title(f'Embedding Performance: MLX vs PyTorch\nModel: {MODEL_NAME_TORCH}')
|
||||
plt.title(f"Embedding Performance: MLX vs PyTorch\nModel: {MODEL_NAME_TORCH}")
|
||||
plt.xlabel("Batch Size")
|
||||
plt.ylabel("Average Time per Batch (ms)")
|
||||
plt.xticks(BATCH_SIZES)
|
||||
plt.grid(True)
|
||||
plt.legend()
|
||||
|
||||
|
||||
# Save the plot
|
||||
output_filename = "embedding_benchmark.png"
|
||||
plt.savefig(output_filename)
|
||||
print(f"Plot saved to {output_filename}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
148
benchmarks/benchmark_no_recompute.py
Normal file
148
benchmarks/benchmark_no_recompute.py
Normal file
@@ -0,0 +1,148 @@
|
||||
import argparse
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
from leann import LeannBuilder, LeannSearcher
|
||||
|
||||
|
||||
def _meta_exists(index_path: str) -> bool:
|
||||
p = Path(index_path)
|
||||
return (p.parent / f"{p.stem}.meta.json").exists()
|
||||
|
||||
|
||||
def ensure_index(index_path: str, backend_name: str, num_docs: int, is_recompute: bool) -> None:
|
||||
# if _meta_exists(index_path):
|
||||
# return
|
||||
kwargs = {}
|
||||
if backend_name == "hnsw":
|
||||
kwargs["is_compact"] = is_recompute
|
||||
builder = LeannBuilder(
|
||||
backend_name=backend_name,
|
||||
embedding_model=os.getenv("LEANN_EMBED_MODEL", "facebook/contriever"),
|
||||
embedding_mode=os.getenv("LEANN_EMBED_MODE", "sentence-transformers"),
|
||||
graph_degree=32,
|
||||
complexity=64,
|
||||
is_recompute=is_recompute,
|
||||
num_threads=4,
|
||||
**kwargs,
|
||||
)
|
||||
for i in range(num_docs):
|
||||
builder.add_text(
|
||||
f"This is a test document number {i}. It contains some repeated text for benchmarking."
|
||||
)
|
||||
builder.build_index(index_path)
|
||||
|
||||
|
||||
def _bench_group(
|
||||
index_path: str,
|
||||
recompute: bool,
|
||||
query: str,
|
||||
repeats: int,
|
||||
complexity: int = 32,
|
||||
top_k: int = 10,
|
||||
) -> float:
|
||||
# Independent searcher per group; fixed port when recompute
|
||||
searcher = LeannSearcher(index_path=index_path)
|
||||
|
||||
# Warm-up once
|
||||
_ = searcher.search(
|
||||
query,
|
||||
top_k=top_k,
|
||||
complexity=complexity,
|
||||
recompute_embeddings=recompute,
|
||||
)
|
||||
|
||||
def _once() -> float:
|
||||
t0 = time.time()
|
||||
_ = searcher.search(
|
||||
query,
|
||||
top_k=top_k,
|
||||
complexity=complexity,
|
||||
recompute_embeddings=recompute,
|
||||
)
|
||||
return time.time() - t0
|
||||
|
||||
if repeats <= 1:
|
||||
t = _once()
|
||||
else:
|
||||
vals = [_once() for _ in range(repeats)]
|
||||
vals.sort()
|
||||
t = vals[len(vals) // 2]
|
||||
|
||||
searcher.cleanup()
|
||||
return t
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--num-docs", type=int, default=5000)
|
||||
parser.add_argument("--repeats", type=int, default=3)
|
||||
parser.add_argument("--complexity", type=int, default=32)
|
||||
args = parser.parse_args()
|
||||
|
||||
base = Path.cwd() / ".leann" / "indexes" / f"bench_n{args.num_docs}"
|
||||
base.parent.mkdir(parents=True, exist_ok=True)
|
||||
# ---------- Build HNSW variants ----------
|
||||
hnsw_r = str(base / f"hnsw_recompute_n{args.num_docs}.leann")
|
||||
hnsw_nr = str(base / f"hnsw_norecompute_n{args.num_docs}.leann")
|
||||
ensure_index(hnsw_r, "hnsw", args.num_docs, True)
|
||||
ensure_index(hnsw_nr, "hnsw", args.num_docs, False)
|
||||
|
||||
# ---------- Build DiskANN variants ----------
|
||||
diskann_r = str(base / "diskann_r.leann")
|
||||
diskann_nr = str(base / "diskann_nr.leann")
|
||||
ensure_index(diskann_r, "diskann", args.num_docs, True)
|
||||
ensure_index(diskann_nr, "diskann", args.num_docs, False)
|
||||
|
||||
# ---------- Helpers ----------
|
||||
def _size_for(prefix: str) -> int:
|
||||
p = Path(prefix)
|
||||
base_dir = p.parent
|
||||
stem = p.stem
|
||||
total = 0
|
||||
for f in base_dir.iterdir():
|
||||
if f.is_file() and f.name.startswith(stem):
|
||||
total += f.stat().st_size
|
||||
return total
|
||||
|
||||
# ---------- HNSW benchmark ----------
|
||||
t_hnsw_r = _bench_group(
|
||||
hnsw_r, True, "test document number 42", repeats=args.repeats, complexity=args.complexity
|
||||
)
|
||||
t_hnsw_nr = _bench_group(
|
||||
hnsw_nr, False, "test document number 42", repeats=args.repeats, complexity=args.complexity
|
||||
)
|
||||
size_hnsw_r = _size_for(hnsw_r)
|
||||
size_hnsw_nr = _size_for(hnsw_nr)
|
||||
|
||||
print("Benchmark results (HNSW):")
|
||||
print(f" recompute=True: search_time={t_hnsw_r:.3f}s, size={size_hnsw_r / 1024 / 1024:.1f}MB")
|
||||
print(
|
||||
f" recompute=False: search_time={t_hnsw_nr:.3f}s, size={size_hnsw_nr / 1024 / 1024:.1f}MB"
|
||||
)
|
||||
print(" Expectation: no-recompute should be faster but larger on disk.")
|
||||
|
||||
# ---------- DiskANN benchmark ----------
|
||||
t_diskann_r = _bench_group(
|
||||
diskann_r, True, "DiskANN R test doc 123", repeats=args.repeats, complexity=args.complexity
|
||||
)
|
||||
t_diskann_nr = _bench_group(
|
||||
diskann_nr,
|
||||
False,
|
||||
"DiskANN NR test doc 123",
|
||||
repeats=args.repeats,
|
||||
complexity=args.complexity,
|
||||
)
|
||||
size_diskann_r = _size_for(diskann_r)
|
||||
size_diskann_nr = _size_for(diskann_nr)
|
||||
|
||||
print("\nBenchmark results (DiskANN):")
|
||||
print(f" build(recompute=True, partition): size={size_diskann_r / 1024 / 1024:.1f}MB")
|
||||
print(f" build(recompute=False): size={size_diskann_nr / 1024 / 1024:.1f}MB")
|
||||
print(f" search recompute=True (final rerank): {t_diskann_r:.3f}s")
|
||||
print(f" search recompute=False (PQ only): {t_diskann_nr:.3f}s")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -3,14 +3,15 @@
|
||||
Memory comparison between Faiss HNSW and LEANN HNSW backend
|
||||
"""
|
||||
|
||||
import gc
|
||||
import logging
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
import psutil
|
||||
import gc
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
|
||||
import psutil
|
||||
from llama_index.core.node_parser import SentenceSplitter
|
||||
|
||||
# Setup logging
|
||||
@@ -61,7 +62,7 @@ def test_faiss_hnsw():
|
||||
|
||||
try:
|
||||
result = subprocess.run(
|
||||
[sys.executable, "examples/faiss_only.py"],
|
||||
[sys.executable, "benchmarks/faiss_only.py"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=300,
|
||||
@@ -83,9 +84,7 @@ def test_faiss_hnsw():
|
||||
|
||||
for line in lines:
|
||||
if "Peak Memory:" in line:
|
||||
peak_memory = float(
|
||||
line.split("Peak Memory:")[1].split("MB")[0].strip()
|
||||
)
|
||||
peak_memory = float(line.split("Peak Memory:")[1].split("MB")[0].strip())
|
||||
|
||||
return {"peak_memory": peak_memory}
|
||||
|
||||
@@ -111,13 +110,12 @@ def test_leann_hnsw():
|
||||
|
||||
tracker.checkpoint("After imports")
|
||||
|
||||
from leann.api import LeannBuilder
|
||||
from llama_index.core import SimpleDirectoryReader
|
||||
from leann.api import LeannBuilder, LeannSearcher
|
||||
|
||||
|
||||
# Load and parse documents
|
||||
documents = SimpleDirectoryReader(
|
||||
"examples/data",
|
||||
"data",
|
||||
recursive=True,
|
||||
encoding="utf-8",
|
||||
required_exts=[".pdf", ".txt", ".md"],
|
||||
@@ -135,6 +133,7 @@ def test_leann_hnsw():
|
||||
nodes = node_parser.get_nodes_from_documents([doc])
|
||||
for node in nodes:
|
||||
all_texts.append(node.get_content())
|
||||
print(f"Total number of chunks: {len(all_texts)}")
|
||||
|
||||
tracker.checkpoint("After text chunking")
|
||||
|
||||
@@ -196,16 +195,14 @@ def test_leann_hnsw():
|
||||
runtime_start_mem = get_memory_usage()
|
||||
print(f"Before load memory: {runtime_start_mem:.1f} MB")
|
||||
tracker.checkpoint("Before load memory")
|
||||
|
||||
|
||||
# Load searcher
|
||||
searcher = LeannSearcher(index_path)
|
||||
tracker.checkpoint("After searcher loading")
|
||||
|
||||
|
||||
|
||||
print("Running search queries...")
|
||||
queries = [
|
||||
"什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面,任务令一般在什么城市颁发",
|
||||
"什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面,任务令一般在什么城市颁发",
|
||||
"What is LEANN and how does it work?",
|
||||
"华为诺亚方舟实验室的主要研究内容",
|
||||
]
|
||||
@@ -303,21 +300,15 @@ def main():
|
||||
|
||||
print("\nLEANN vs Faiss Performance:")
|
||||
memory_saving = faiss_results["peak_memory"] - leann_results["peak_memory"]
|
||||
print(
|
||||
f" Search Memory: {memory_ratio:.1f}x less ({memory_saving:.1f} MB saved)"
|
||||
)
|
||||
print(f" Search Memory: {memory_ratio:.1f}x less ({memory_saving:.1f} MB saved)")
|
||||
|
||||
# Storage comparison
|
||||
if leann_storage_size > faiss_storage_size:
|
||||
storage_ratio = leann_storage_size / faiss_storage_size
|
||||
print(
|
||||
f" Storage Size: {storage_ratio:.1f}x larger (LEANN uses more storage)"
|
||||
)
|
||||
print(f" Storage Size: {storage_ratio:.1f}x larger (LEANN uses more storage)")
|
||||
elif faiss_storage_size > leann_storage_size:
|
||||
storage_ratio = faiss_storage_size / leann_storage_size
|
||||
print(
|
||||
f" Storage Size: {storage_ratio:.1f}x smaller (LEANN uses less storage)"
|
||||
)
|
||||
print(f" Storage Size: {storage_ratio:.1f}x smaller (LEANN uses less storage)")
|
||||
else:
|
||||
print(" Storage Size: similar")
|
||||
else:
|
||||
0
data/README.md → benchmarks/data/README.md
Normal file → Executable file
0
data/README.md → benchmarks/data/README.md
Normal file → Executable file
1562
benchmarks/data/prompts_g5/prompt_dump_gpqa_hnsw.txt
Normal file
1562
benchmarks/data/prompts_g5/prompt_dump_gpqa_hnsw.txt
Normal file
File diff suppressed because it is too large
Load Diff
484
benchmarks/data/prompts_g5/prompt_dump_hotpot_hnsw.txt
Normal file
484
benchmarks/data/prompts_g5/prompt_dump_hotpot_hnsw.txt
Normal file
File diff suppressed because one or more lines are too long
484
benchmarks/data/prompts_g5/prompt_dump_nq_hnsw.txt
Normal file
484
benchmarks/data/prompts_g5/prompt_dump_nq_hnsw.txt
Normal file
File diff suppressed because one or more lines are too long
484
benchmarks/data/prompts_g5/prompt_dump_trivia_hnsw.txt
Normal file
484
benchmarks/data/prompts_g5/prompt_dump_trivia_hnsw.txt
Normal file
@@ -0,0 +1,484 @@
|
||||
=== Prompt Dump for TRIVIA + HNSW ===
|
||||
Total prompts: 50
|
||||
Showing first 20 prompts:
|
||||
|
||||
==================================================
|
||||
PROMPT #1:
|
||||
==================================================
|
||||
Jason Lee also portrays David Seville in live action/CGI films starring Alvin and the Chipmunks, which use a combination of live-action acting and computer animation. While Ross Bagdasarian Jr. does not do any voices for the film series, the films are all produced in association with Bagdasarian Productions, which owns the rights to all of the characters. Portrayed by Filmography Films Television See also References Fictional characters introduced in 1958 Alter egos Alvin and the Chipmunks Fictional managers Fictional producers American male characters in televisionRoss Dickran Bagdasarian (born May 6, 1949) is an American actor, animator and producer, known for his work on the Alvin and the Chipmunks franchise. He is the son of the franchise's creator, Ross Bagdasarian. Early life Bagdasarian was born in Fresno, California, the son of Armenian-American parents Armenuhi Bagdasarian (née Kulhanjian) and Ross Bagdasarian (1919–1972). As a child, he worked with his father on The Alvin Show by helping edit and coordinate the soundtracks and falsetto voice-overs of the Chipmunks. Career Bagdasarian graduated from law school. He succeeded his father as president of Bagdasarian Productions in 1972 after the death of the elder Bagdasarian. The company had fallen into obscurity after significant success between 1958 and the late 1960s. Bagdasarian was also admitted to the California bar as an attorney in 1975. Under Bagdasarian's supervision, new Chipmunks records were created shortly after his marriage to Karman, including Chipmunk Punk. In 1981, the Chipmunks returned to television in the cartoon special A Chipmunk Christmas. Two years later, Ruby-Spears Productions' Alvin and the Chipmunks Saturday morning cartoon series debuted on NBC. Based on that series, a feature film, The Chipmunk Adventure was released in 1987. Bagdasarian voices Alvin, Simon, and Dave Seville, and Karman voices Theodore and the Chipettes (Brittany, Jeanette, and Eleanor). Bagdasarian and Karman hold tight creative and financial control over the Chipmunk franchise, reviewing each and every business contract in great detail. In the mid-90s, Bagdasarian bought out his brother's and sister's portions of the Chipmunk rights, to take complete control of the franchise.Alvin and the Chipmunks, originally David Seville and the Chipmunks or simply The Chipmunks, are an American animated virtual band and media franchise first created by Ross Bagdasarian for novelty records in 1958. The group consists of three singing animated anthropomorphic chipmunks named Alvin, Simon, and Theodore who are originally managed by their human adoptive father, David "Dave" Seville. Bagdasarian provided the group's voices by producing sped-up recordings of his own, a technique pioneered on the successful "Witch Doctor". Later in 1958, Bagdasarian released the similarly-engineered "The Chipmunk Song" for which he came up with the chipmunk characters and their human father, attributing the track to them. David Seville and the Chipmunks released several more records over the following decade until Bagdasarian's death in 1972. The franchise was revived in 1979 with the characters' voices provided by his son Ross Bagdasarian Jr. and the latter's wife Janice Karman. Through the successful franchise, the Chipmunks have become one of the most successful children's artists of all time. It has garnered two number-one singles on the Billboard Hot 100 and won five Grammy Awards, having four Top 10 albums on the Billboard 200 and three certified platinum albums. "The Chipmunk Song" became one of the best-selling singles of all time at 5 million physical copies sold. The Chipmunks were first depicted in animated form in The Alvin Show (1961). The characters have since featured in several television series and films, as well as other media. In 2019, The Chipmunks received a star on the Hollywood Walk of Fame.
|
||||
Think hard, but answer shortly and concisely. Only give direct answers to the questions. No additional explanations. Directly answer these questions:
|
||||
Q: Rita Coolidge sang the title song for which Bond film??
|
||||
A: Octopussy
|
||||
|
||||
Q: Which Lloyd Webber musical premiered in the US on 10th December 1993??
|
||||
A: Sunset Boulevard
|
||||
|
||||
Q: Who was the next British Prime Minister after Arthur Balfour??
|
||||
A: Campbell-Bannerman
|
||||
|
||||
Q: Who had a 70s No 1 hit with Kiss You All Over??
|
||||
A: Exile
|
||||
|
||||
Q: What claimed the life of singer Kathleen Ferrier??
|
||||
A: Cancer
|
||||
|
||||
Q: Who was the man behind The Chipmunks?
|
||||
A:
|
||||
==================================================
|
||||
|
||||
==================================================
|
||||
PROMPT #2:
|
||||
==================================================
|
||||
and the drum set. Their film counterparts are Michelle and Eleni. Production history Broadway (2015-2019) Auditions began on January 19, 2015 for children ages nine through fifteen. Some recruiting was done through the School of Rock after-school educational program (which predated the film by several years) and open calls were held in New York at the Winter Garden, in Chicago and in Los Angeles. The production closed on January 20, 2019, after 1,309 performances. West End (2016–2020) On 7 December 2015, following the show's Broadway opening, it was announced by Andrew Lloyd Webber that the show would transfer to London's West End in autumn 2016, with the intention to open at the London Palladium. On 20 May 2016, the musical was confirmed at the Gillian Lynne Theatre instead of the Palladium with previews starting on 24 October 2016, opening night on 14 November 2016, and public booking opening on 25 May 2016. Lloyd Webber revealed that the production was able to open several months earlier than anticipated due to finding the child musician actors easily. Anna Louizos' scenery has been modified to fit the architecture of the Gillian Lynne Theatre from the traditional proscenium arch stage at Winter Garden Theatre. Changes include the removal of the pre-show curtain, the use of a revolving stage and action taking place in the aisles of the stalls. While the show remains to be set in America, the script has been adapted to include some minor references for a British audience. The original London cast includes David Fynn as DeweyThe Sound of Music, Camelot and Fiddler on the Roof played at the theatre in the early 1980s. In 1984, the interior was extensively modified by the introduction of a 'race track' that ran through the audience, for the show Starlight Express with performers on roller skates. The show premièred on 27 March, composed by Andrew Lloyd Webber and directed by Trevor Nunn and ran for 7,406 performances, over 18 years. With the removal of the 'tracks', the interior was extensively restored by architects Jaques Muir and Partners. This included the removal of 3,500 incandescent lamps that had become difficult to maintain and consumed a considerable amount of power. These were replaced by 88,000 low power LEDs specially designed for the theatre, creating the first auditorium completely lit in this way. Another Lloyd Webber production followed, Bombay Dreams premièred on 19 June 2002. It was created by A. R. Rahman with lyrics by Don Black and was directed by Steven Pimlott, closing after 1,500 performances on 13 June 2004. This was followed by the return to the West End of the Bee Gee's musical Saturday Night Fever on 6 July 2004, closing 22 October 2005 to tour. This was followed on 10 April 2006 by the jukebox musical Movin' Out, featuring the music of Billy Joel. This starred James Fox but ran for only two months. The Broadway musical Wicked received its London première at the venue on 27 September 2006 with a cast featuring Idina Menzel as Elphaba, Helen Dallimore as Glinda, Nigel Planer asand also starred comedian Tim Minchin as Judas Iscariot, former Spice Girl Melanie C as Mary Magdalene and BBC Radio 1 DJ Chris Moyles as King Herod. Tickets for most venues went on sale on 18 May 2012. In 2013, Lloyd Webber reunited with Christopher Hampton and Don Black on Stephen Ward the Musical. For his next project, a 2015 musical adaptation of the 2003 film School of Rock, auditions were held for children aged nine to fifteen in cooperation with the School of Rock music education program, which predated the film by several years. In April 2016, the English National Opera staged a revival of Sunset Boulevard at the London Coliseum. The limited run, semi-staged production directed by Lonny Price brought Glenn Close to reprise her star turn as "Norma Desmond", which was her first time performing the role in London; she had originated the role in Los Angeles in December 1993 and then on Broadway in November 1994 (which won her the 1995 Tony Award for Best Actress in a Musical). The 2016 London revival was so well-received that the production transferred to the Palace Theatre on Broadway in February 2017, making Lloyd Webber the first musical-theatre composer since 1953 to have four musicals running simultaneously on Broadway – a feat that his heroes Rodgers and Hammerstein had previously achieved. Lloyd Webber's memoir, Unmasked, was published in 2018. On 9 September 2018, Lloyd Webber, along with Tim Rice and John Legend each won an Emmy for Jesus Christ Superstar Live in Concert. With this
|
||||
Think hard, but answer shortly and concisely. Only give direct answers to the questions. No additional explanations. Directly answer these questions:
|
||||
Q: Who was the man behind The Chipmunks??
|
||||
A: David Seville
|
||||
|
||||
Q: Rita Coolidge sang the title song for which Bond film??
|
||||
A: Octopussy
|
||||
|
||||
Q: Who was the next British Prime Minister after Arthur Balfour??
|
||||
A: Campbell-Bannerman
|
||||
|
||||
Q: Who had a 70s No 1 hit with Kiss You All Over??
|
||||
A: Exile
|
||||
|
||||
Q: What claimed the life of singer Kathleen Ferrier??
|
||||
A: Cancer
|
||||
|
||||
Q: Which Lloyd Webber musical premiered in the US on 10th December 1993?
|
||||
A:
|
||||
==================================================
|
||||
|
||||
==================================================
|
||||
PROMPT #3:
|
||||
==================================================
|
||||
Cabinet Louis Botha, Prime Minister of the Union of South Africa (1910–1919) Behind Churchill are: George Barnes, leader of the National Democratic and Labour Party Sir Robert Borden, Prime Minister of Canada (1911–1920) To their right are: Arthur Balfour, 1st Earl of Balfour, former Prime Minister of the United Kingdom (1902–1905); First Lord of the Admiralty (1915–1916) and Foreign Secretary (1916–1919) (standing adlocutio in a black suit) H. H. Asquith, 1st Earl of Oxford and Asquith, Prime Minister of the United Kingdom (1908–1916) (sitting in front) Sir Eric Geddes, First Lord of the Admiralty (1917–1919) (behind, cleanshaven) Bonar Law, Leader of the Opposition (United Kingdom) (1911–1915), Secretary of State for the Colonies (1915–1916), Chancellor of the Exchequer (1916–1919) (later Prime Minister of the United Kingdom, 1922–1923) (dark moustache) Edward Morris, 1st Baron Morris, Prime Minister of Newfoundland (1909–1917) (white moustache, in the shadows) Herbert Kitchener, 1st Earl Kitchener, Secretary of State for War (1914–1916) (in the shadows) Bailey decided that the painting should include British and Dominion civilian leaders in office at the beginning and the end of the First World War. It includes Prime Ministers of Australia, Canada, Newfoundland, and New Zealand, and the Prime Ministers, Foreign Secretaries, Secretaries of War, and First Lords of the Admiralty of the United Kingdom, together with two leaders of the British Conservative and Labour parties. The Maharaja of Bikaner, a member of the Imperial War Cabinet and the Indian delegate to the Versailles Peace Conference, stands to the left next to Botha, both in military uniform. Kitchener standsArthur James Balfour, 1st Earl of Balfour, (, ; 25 July 184819 March 1930), also known as Lord Balfour, was a British Conservative statesman who served as Prime Minister of the United Kingdom from 1902 to 1905. As foreign secretary in the Lloyd George ministry, he issued the Balfour Declaration of 1917 on behalf of the cabinet, which supported a "home for the Jewish people" in Palestine. Entering Parliament in 1874, Balfour achieved prominence as Chief Secretary for Ireland, in which position he suppressed agrarian unrest whilst taking measures against absentee landlords. He opposed Irish Home Rule, saying there could be no half-way house between Ireland remaining within the United Kingdom or becoming independent. From 1891 he led the Conservative Party in the House of Commons, serving under his uncle, Lord Salisbury, whose government won large majorities in 1895 and 1900. An esteemed debater, he was bored by the mundane tasks of party management. In July 1902, he succeeded his uncle as prime minister. In domestic policy he passed the Land Purchase (Ireland) Act 1903, which bought out most of the Anglo-Irish land owners. The Education Act 1902 had a major long-term impact in modernising the school system in England and Wales and provided financial support for schools operated by the Church of England and by the Catholic Church. Nonconformists were outraged and mobilised their voters, but were unable to reverse it. In foreign and defence policy, he oversaw reform of British defence policy and supported Jackie Fisher's naval innovations. He secured the Entente Cordiale withthe county of Haddington. In October 1922 he, with most of the Conservative leadership, resigned with Lloyd George's government following the Carlton Club meeting, a Conservative back-bench revolt against continuance of the coalition. Bonar Law became prime minister. Like many Coalition leaders, he did not hold office in the Conservative governments of 1922–1924, but as an elder statesman, he was consulted by the King in the choice of Stanley Baldwin as Bonar Law's successor as Conservative leader in May 1923. His advice was strongly in favour of Baldwin, ostensibly due to Baldwin's being an MP but in reality motivated by his personal dislike of Curzon. Later that evening, he met a mutual friend who asked 'Will dear George be chosen?' to which he replied with 'feline Balfourian satisfaction,' 'No, dear George will not.' His hostess replied, 'Oh, I am so sorry to hear that. He will be terribly disappointed.' Balfour retorted, 'Oh, I don't know. After all, even if he has lost the hope of glory he still possesses the means of Grace.' Balfour was not initially included in Baldwin's second government in 1924, but in 1925, he returned to the Cabinet, in place of the late Lord Curzon as Lord President of the Council, until the government ended in 1929. With 28 years of government service, Balfour had one of the longest ministerial careers in modern British politics, second only to Winston Churchill . Last years Lord Balfour had generally good health until 1928 and remained until then a regular tennis player. Four years previously
|
||||
Think hard, but answer shortly and concisely. Only give direct answers to the questions. No additional explanations. Directly answer these questions:
|
||||
Q: Who was the man behind The Chipmunks??
|
||||
A: David Seville
|
||||
|
||||
Q: Which Lloyd Webber musical premiered in the US on 10th December 1993??
|
||||
A: Sunset Boulevard
|
||||
|
||||
Q: Rita Coolidge sang the title song for which Bond film??
|
||||
A: Octopussy
|
||||
|
||||
Q: Who had a 70s No 1 hit with Kiss You All Over??
|
||||
A: Exile
|
||||
|
||||
Q: What claimed the life of singer Kathleen Ferrier??
|
||||
A: Cancer
|
||||
|
||||
Q: Who was the next British Prime Minister after Arthur Balfour?
|
||||
A:
|
||||
==================================================
|
||||
|
||||
==================================================
|
||||
PROMPT #4:
|
||||
==================================================
|
||||
classic '70s pop song." In 1992, Mexican trio Pandora released a cover version titled "Pierdo el Control" on their album Ilegal. In 1979 Ginger Rogers sang this song on The Love Boat in the episode "Critical Success / The Love Lamp Is Lit / Take My Boyfriend, Please / Rent a Family / The Man in Her Life: Parts 1 & 2" In 2001, the film Get Over It featured a dance to this song at the beginning by some of the cast. References 1973 songs 1975 debut singles Songs written by Neil Sedaka Songs with lyrics by Howard Greenfield Neil Sedaka songs Captain & Tennille songs Andy Williams songs Number-one singles in Australia Billboard Hot 100 number-one singles Cashbox number-one singles RPM Top Singles number-one singles Grammy Award for Record of the Year A&M Records singles Juno Award for Best Selling Single singlesMusic Week rated the song four out of five, concluding, "A third huge hit for the boys." Tracklisting CD single "Kiss You All Over" (Radio Edit) - 4:31 "Kiss You All Over" (Club Mix) - 5:53 "Bonita" (Radio Edit) - 3:54 "Bonita" (Club Mix) - 7:08 Charts Release history References 1978 songs 1978 singles 1997 singles 1998 singles Billboard Hot 100 number-one singles Cashbox number-one singles Exile (American band) songs Number-one singles in New Zealand Number-one singles in South Africa Number-one singles in Australia Songs written by Mike Chapman Song recordings produced by Frank Farian Song recordings produced by Mike Chapman Songs written by Nicky Chinn RAK Records singles Curb Records singles Hilltak Records singles Warner Records singles Arista Records singles No Mercy (pop band) songs Songs about kissing Phyllis Hyman songs"Kiss You All Over" is a 1978 song performed by American group Exile, written by Mike Chapman and Nicky Chinn. It was included on the band's third album, Mixed Emotions (1978), and featured lead vocalist Jimmy Stokley and guitarist J.P. Pennington on vocals. On the American Top 40 broadcast of May 26, 1979, Casey Kasem reported that Chapman stated his source of inspiration for "Kiss You All Over" was "It's Ecstasy When You Lay Down Next to Me" by Barry White. The song was a number one single in the United States, but proved to be Exile's only big hit in the pop market (they would later have great success on the country music charts). It held the number one spot on the Billboard Hot 100 for four weeks (starting September 30), and Billboard ranked it as the No. 5 song for 1978. The track also reached number-one in at least three other nations. In the United Kingdom, the song was released on Mickie Most's RAK Records, and peaked at number 6 on the UK Singles Chart. The strings are played with a synthesizer in a backing track. In 2010, Billboard ranked the song tenth on its list of "The 50 Sexiest Songs of All Time". Lead vocalist on the number, Stokley was ousted from the band in 1979, his health declining thereafter until he died at the age of 41 in 1985. After the success of soft rock singles from the albums Mixed Emotions and All There Is, the band moved into country music in
|
||||
Think hard, but answer shortly and concisely. Only give direct answers to the questions. No additional explanations. Directly answer these questions:
|
||||
Q: Who was the man behind The Chipmunks??
|
||||
A: David Seville
|
||||
|
||||
Q: Which Lloyd Webber musical premiered in the US on 10th December 1993??
|
||||
A: Sunset Boulevard
|
||||
|
||||
Q: Who was the next British Prime Minister after Arthur Balfour??
|
||||
A: Campbell-Bannerman
|
||||
|
||||
Q: Rita Coolidge sang the title song for which Bond film??
|
||||
A: Octopussy
|
||||
|
||||
Q: What claimed the life of singer Kathleen Ferrier??
|
||||
A: Cancer
|
||||
|
||||
Q: Who had a 70s No 1 hit with Kiss You All Over?
|
||||
A:
|
||||
==================================================
|
||||
|
||||
==================================================
|
||||
PROMPT #5:
|
||||
==================================================
|
||||
21st century world: "We dislike low-lying voices, for one thing— contraltos now sound freakish and headmistressy, and even the majority of mezzo-sopranos should more accurately be categorised as almost-sopranos". However, she was "a singer of, and for, her time — a time of grief and weariness, national self-respect and a belief in human nobility". In this context "her artistry stands upright, austere, unfussy, fundamental and sincere". Shortly after Ferrier's death an appeal was launched by Barbirolli, Walter, Myra Hess and others, to establish a cancer research fund in Ferrier's name. Donations were received from all over the world. To publicise the fund a special concert was given at the Royal Festival Hall on 7 May 1954, at which Barbirolli and Walter shared the conducting duties without payment. Among the items was a rendition of Purcell's When I am laid in earth, which Ferrier had often sung; on this occasion the vocal part was played by a solo cor anglais. The Kathleen Ferrier Cancer Research Fund helped establish the Kathleen Ferrier Chair of Clinical Oncology at University College Hospital, in 1984. , it was continuing to fund oncology research. As the result of a separate appeal, augmented by the sales proceeds of a memoir edited by Neville Cardus, the Kathleen Ferrier Memorial Scholarship Fund was created to encourage young British and Commonwealth singers of either sex. The Fund, which has operated from 1956 under the auspices of the Royal Philharmonic Society, initially provided an annual award covering the cost of a year's study to a single prizewinner.In the course of her professional life the English contralto Kathleen Ferrier made a large number of recordings. In the summer of 1944 she signed a contract with Columbia, which lasted until February 1946. She then transferred to Decca, and remained with them until her death in October 1953. Apart from her studio recordings, many of her live performances and broadcast recitals were recorded, sometimes privately. Some of these were later issued as commercial recordings; others are held by individuals or in the archives of broadcasting companies. The following list is neither up to date nor entirely accurate, particularly in regard to a CD issue, entitled 'Kathleen Ferrier Remembered', released in June 2017, on SOMM264, comprising 26 tracks, 19 of which have never previously been issued. Most of these 19 are not listed below. They include Lieder by Schubert, Brahms, Wolf and Mahler and songs by Stanford, Parry, Jacobson and Rubbra, all taken from BBC broadcasts between 1947 and 1952. In April 2019, a recording of Ferrier singing in Bach's 'Magnificat' during the 1950 Vienna International Bach Festival was issued for the first time. The CD catalogue number is SOMM Ariadne 5004 and it also features Irmgard Seefried and Friedl Riegler (sopranos), Hugo Meyer-Welfing (tenor) and Otto Edelmann (bass). The Vienna Philharmonic Orchestra and Chorus of the Vienna State Opera are conducted by Volkmar Andreae. The existence of this recording was not known until a vinyl disc was offered for sale on an internet auction site in 2018. In superb recorded sound, this discovery is aKathleen Mary Ferrier, CBE (22 April 19128 October 1953) was an English contralto singer who achieved an international reputation as a stage, concert and recording artist, with a repertoire extending from folksong and popular ballads to the classical works of Bach, Brahms, Mahler and Elgar. Her death from cancer, at the height of her fame, was a shock to the musical world and particularly to the general public, which was kept in ignorance of the nature of her illness until after her death. The daughter of a Lancashire village schoolmaster, Ferrier showed early talent as a pianist, and won numerous amateur piano competitions while working as a telephonist with the General Post Office. She did not take up singing seriously until 1937, when after winning a prestigious singing competition at the Carlisle Festival she began to receive offers of professional engagements as a vocalist. Thereafter she took singing lessons, first with J.E. Hutchinson and later with Roy Henderson. After the outbreak of the Second World War Ferrier was recruited by the Council for the Encouragement of Music and the Arts (CEMA), and in the following years sang at concerts and recitals throughout the UK. In 1942 her career was boosted when she met the conductor Malcolm Sargent, who recommended her to the influential Ibbs and Tillett concert management agency. She became a regular performer at leading London and provincial venues, and made numerous BBC radio broadcasts. In 1946, Ferrier made her stage debut, in the Glyndebourne Festival premiere of Benjamin Britten's opera The Rape of Lucretia.
|
||||
Think hard, but answer shortly and concisely. Only give direct answers to the questions. No additional explanations. Directly answer these questions:
|
||||
Q: Who was the man behind The Chipmunks??
|
||||
A: David Seville
|
||||
|
||||
Q: Which Lloyd Webber musical premiered in the US on 10th December 1993??
|
||||
A: Sunset Boulevard
|
||||
|
||||
Q: Who was the next British Prime Minister after Arthur Balfour??
|
||||
A: Campbell-Bannerman
|
||||
|
||||
Q: Who had a 70s No 1 hit with Kiss You All Over??
|
||||
A: Exile
|
||||
|
||||
Q: Rita Coolidge sang the title song for which Bond film??
|
||||
A: Octopussy
|
||||
|
||||
Q: What claimed the life of singer Kathleen Ferrier?
|
||||
A:
|
||||
==================================================
|
||||
|
||||
==================================================
|
||||
PROMPT #6:
|
||||
==================================================
|
||||
"You Only Live Twice", performed by Nancy Sinatra, is the theme song to the 1967 James Bond film of the same name. The music was by veteran Bond film composer John Barry, with lyrics by Leslie Bricusse. The song is widely recognized for its striking opening bars, featuring a simple 2-bar theme in the high octaves of the violins and lush harmonies from French horns. It is considered by some to be among the best James Bond theme songs, and has become one of Nancy Sinatra's best known hits. Shortly after Barry's production, Sinatra's producer Lee Hazlewood released a more guitar-based single version. The song has been covered by many artists including Coldplay, Soft Cell, Björk and Shirley Bassey. In 1998, Robbie Williams re-recorded portions of the song (including the opening strings) for use in his UK number-one single "Millennium". Background James Bond veteran John Barry returned to the franchise to produce the score. The lyrics were by Leslie Bricusse, who had previously cowritten the lyrics for the theme to Goldfinger. An initial version of the song was performed by Julie Rogers and recorded with a 50 or 60 piece orchestra at CTS Studios. However, this version was not used since Barry decided to re-write and re-record the song: "It was usually the producers that said 'this isn't working, there's a certain something that it needed'. If that energy wasn't there, if that mysterioso kind of thing wasn't there, then it wasn't going to work for the movie." The Rogers song shares only two lines withBassey belting out the fantastic title song." He added that the remastered edition's sound quality was "impeccable". Chart positions Track listing Credits Project manager: Herb Agner Creative director: Michelle Azzopardi Composer, conductor, primary artist: John Barry Primary artist, vocals: Shirley Bassey Liner notes: Jeff Bond Composer, lyricist: Leslie Bricusse Project manager: Wendy Brueder Producer, reissue producer: Frank Collura Remastering: Bob Fisher Guitar, soloist: Vic Flick Art direction, design: Peter Grant Orchestra contractor: Sid Margo Lyricist: Anthony Newley A&R: Gregg Ogorzelec Engineer: John Richards Saxophone, soloist: John Scott Source: Aftermath Following the success of her performance on the title track, Shirley Bassey sang the title songs for two later Bond films, Diamonds Are Forever and Moonraker. John Barry used the Goldfinger theme on his 1965 John Barry Plays Goldfinger album that featured Robert Brownjohn artwork. References Footnotes Citations Bibliography Soundtrack albums from James Bond films Soundtrack 1964 soundtrack albums EMI Records soundtracks John Barry (composer) soundtracksJames Bond (Roger Moore), and the title evidently refers to the key aerial sequences featured in the movie. Prior to Rita Coolidge being assigned the Octopussy theme, Mari Wilson was a contender, a British singer whose retro-image evoked the mid-'60s when the Bond series originated; but Wilson's lack of a US-profile led to a negative decision. In January 1983, the producer of Octopussy: Cubby Broccoli, stated that he hoped to have current hitmaker Laura Branigan sing the movie's theme song, an artist choice which both Barry and Rice have stated would have pleased them. However, on March 29, 1983 Rita Coolidge was revealed as the singer, a seemingly surprising choice in that Coolidge's career peak had occurred some six years previously. Coolidge recalls that Barbara Broccoli, daughter of Cubby Broccoli and herself the assistant director of Octopussy, was a fan of Coolidge and made a point of playing Coolidge records around her father until "one day [he said], "Who is that? That's the voice I want for the movie." Rice still had to complete his contribution as the singer arrived in the studio, with Coolidge stating that "we were waiting for the lyrics as the instrumental track had already been done." The chorus of "All Time High" features a lyric similar to that of Coolidge's #2 hit "(Your Love Has Lifted Me) Higher and Higher" whose lyric "When you wrap your loving arms around me I can stand up and face the world again" is echoed by the "All Time High" lyric "We'll take on the
|
||||
Think hard, but answer shortly and concisely. Only give direct answers to the questions. No additional explanations. Directly answer these questions:
|
||||
Q: Who was the man behind The Chipmunks??
|
||||
A: David Seville
|
||||
|
||||
Q: Which Lloyd Webber musical premiered in the US on 10th December 1993??
|
||||
A: Sunset Boulevard
|
||||
|
||||
Q: Who was the next British Prime Minister after Arthur Balfour??
|
||||
A: Campbell-Bannerman
|
||||
|
||||
Q: Who had a 70s No 1 hit with Kiss You All Over??
|
||||
A: Exile
|
||||
|
||||
Q: What claimed the life of singer Kathleen Ferrier??
|
||||
A: Cancer
|
||||
|
||||
Q: Rita Coolidge sang the title song for which Bond film?
|
||||
A:
|
||||
==================================================
|
||||
|
||||
==================================================
|
||||
PROMPT #7:
|
||||
==================================================
|
||||
which allowed the first legal beer sales since the beginning of Prohibition on January 16, 1920. In 1933 state conventions ratified the Twenty-first Amendment, which repealed Prohibition. The Amendment was fully ratified on December 5, 1933. Federal laws enforcing Prohibition were then repealed. Dry counties Following repeal some states continued prohibition within their own jurisdictions. Almost two-thirds of the states adopted some form of local option which enabled residents in political subdivisions to vote for or against local prohibition. For a time, 38 percent of Americans lived in areas with Prohibition. By 1966, however, all states had repealed their statewide prohibition laws, with Mississippi the last state to do so. Notes Sources Walker, Robert S. and Samuel C. Patterson, Oklahoma Goes Wet: The Repeal of Prohibiton, Eagleton Institute, Rutgers University, (1961). External links Repeal Day is December Fifth See more related images by selecting the "Alcohol" subject at the Persuasive Cartography, The PJ Mode Collection, Cornell University Library Prohibition in the United States Economic history of the United States 1933 in the United States Articles containing video clipsimportation of alcoholic beverages in the United States. The resolution was sent to the states for ratification and became the Eighteenth Amendment to the U.S. Constitution. On January 8, 1918, Mississippi became the first state to ratify the amendment and on January 16, 1919, Nebraska became the 36th state to do so, securing its passage with the required three-fourths of the states. By the end of February 1919, only three states remained as hold-outs to ratification: New Jersey, Connecticut and Rhode Island. The National Prohibition Act, also known as the Volstead Act, was enacted on October 18, 1919. Prohibition in the United States went into effect on January 17, 1920. Nationwide prohibition was repealed in 1933 with the passage of the Twenty-first Amendment on February 20 and its ratification on December 5. List of formerly dry states This table lists the effective dates each state went dry and any dates of repeal that do not coincide with the end of national prohibition in 1933. See also Dry county Alcoholic beverage control state List of alcohol laws of the United States by state Notes Alcohol law in the United States Prohibition in the United StatesAugust 19. PPS functionals were completed August 21. GATV 5006 was then transferred to complex 14 for mating with the Atlas. July 27, 1966 (Wednesday) Following the announcement of his austerity programme, British Prime Minister Harold Wilson survived a vote of censure in the House of Commons, as members of his Labour Party (with an 88-seat majority) supported him. The final result was 246 votes in favor, and 325 against. On the same day, the nation's chief labor union, the Trades Union Congress, voted 20 to 12 in support of a resolution pledging to halt strikes that had been threatened during the six-month freeze against raising wages. For the first time in 58 years, liquor was legally served in Mississippi, the last of the United States to have repealed its prohibition laws. Effective July 1, individual local governments were allowed to hold referendum elections on whether to allow the sale of liquor at state-approved resorts, and Harrison County voters had endorsed the measure. At 6:55 p.m., after police cars escorted a liquor delivery truck into Biloxi. The first drink in the state was poured at the Broadwater Beach Hotel, and Louis Cobb, the first legal bartender in Mississippi, sold a glass of scotch whiskey to hotel manager T.M. Dorsett. Biloxi Mayor Dan Guice then cut the ribbon to open the entrance to the hotel's bar.Died: Brenda Sue Brown, 11, was beaten to death after walking with her sister to summer school in Shelby, North Carolina. Police were unable to charge a suspect with the crime, until
|
||||
Think hard, but answer shortly and concisely. Only give direct answers to the questions. No additional explanations. Directly answer these questions:
|
||||
Q: Who was the man behind The Chipmunks??
|
||||
A: David Seville
|
||||
|
||||
Q: Which Lloyd Webber musical premiered in the US on 10th December 1993??
|
||||
A: Sunset Boulevard
|
||||
|
||||
Q: Who was the next British Prime Minister after Arthur Balfour??
|
||||
A: Campbell-Bannerman
|
||||
|
||||
Q: Who had a 70s No 1 hit with Kiss You All Over??
|
||||
A: Exile
|
||||
|
||||
Q: What claimed the life of singer Kathleen Ferrier??
|
||||
A: Cancer
|
||||
|
||||
Q: What was the last US state to reintroduce alcohol after prohibition?
|
||||
A:
|
||||
==================================================
|
||||
|
||||
==================================================
|
||||
PROMPT #8:
|
||||
==================================================
|
||||
to New York City for work in summer stock theatre shortly before winning a supporting role in MGM's These Glamour Girls (1939) opposite Lana Turner and Lew Ayres. The role of Betty was said to have been written especially with Hunt in mind. Other roles in major studio productions soon followed, including supporting roles as Mary Bennet in MGM's version of Pride and Prejudice (1940) with Laurence Olivier, and as Martha Scott's surrogate child Hope Thompson in Cheers for Miss Bishop (1941). Years at MGM In 1941, Hunt signed a contract with MGM, where she remained for the next six years. While filming Blossoms in the Dust, film director Mervyn LeRoy lauded Hunt for her heartfelt and genuine acting ability. During this period she had starring roles in 21 films, including The Penalty (1941) opposite Lionel Barrymore, Panama Hattie (1942) opposite Ann Sothern and Red Skelton, and the war drama Pilot No. 5 (1943) in which she was cast as the love interest of Franchot Tone, and The Valley of Decision (1945). In 1944 she polled seventh in a list by exhibitors of "Stars of Tomorrow". She previously did a screen test to play Melanie Hamilton in Gone with the Wind (1939) and was told by David O. Selznick she would play the role, but to "keep it a secret for now." Three days later, it was announced that Olivia de Havilland was cast. In 1944, she appeared in None Shall Escape, a film that is now regarded as the first about the Holocaust. She playedMiss America 1941, the 15th Miss America pageant, was held at the Boardwalk Hall in Atlantic City, New Jersey on September 6, 1941. Shortly after the crowning of Miss California, Rosemary LaPlanche, who had been first runner-up in 1940, the pageant committee adopted this rule: "No contestant can compete in Atlantic City for the title of Miss America more than once", thus eliminating future state winners with more than one attempt at the national title. LaPlanche became a film actress, as did her sister, Louise LaPlanche. 1941 was also the first year that the special award, “Miss Congeniality” was created. It went to Mifaunwy Shunatona, a member of the Otoe and Pawnee tribes — she was also the first American Indian contestant in the pageant’s history. Results Awards Preliminary awards Other awards Contestants References Secondary sources External links Miss America official website 1941 1941 in the United States 1941 in New Jersey September 1941 events Events in Atlantic City, New JerseyMiss America 1942, the 16th Miss America pageant, was held at the Warner Theater in Atlantic City, New Jersey on September 12, 1942. Miss Texas, Jo-Carroll Dennison won the title after winning the swimsuit and talent categories. She was the first Miss Texas to win the Miss America title. Dennison became an actress and had roles in films such as Winged Victory. She was married at one time to comedian Phil Silvers. Results Awards Preliminary awards Other awards Contestants References Secondary sources External links Miss America (1942) 1942 1942 in the United States 1942 in New Jersey September 1942 events Events in Atlantic City, New Jersey
|
||||
Think hard, but answer shortly and concisely. Only give direct answers to the questions. No additional explanations. Directly answer these questions:
|
||||
Q: Who was the man behind The Chipmunks??
|
||||
A: David Seville
|
||||
|
||||
Q: Which Lloyd Webber musical premiered in the US on 10th December 1993??
|
||||
A: Sunset Boulevard
|
||||
|
||||
Q: Who was the next British Prime Minister after Arthur Balfour??
|
||||
A: Campbell-Bannerman
|
||||
|
||||
Q: Who had a 70s No 1 hit with Kiss You All Over??
|
||||
A: Exile
|
||||
|
||||
Q: What claimed the life of singer Kathleen Ferrier??
|
||||
A: Cancer
|
||||
|
||||
Q: Which actress was voted Miss Greenwich Village in 1942?
|
||||
A:
|
||||
==================================================
|
||||
|
||||
==================================================
|
||||
PROMPT #9:
|
||||
==================================================
|
||||
De Tokyo Stock Price Index (Japans: 東証株価指数) of TOPIX is een belangrijke aandelenindex van de Tokyo Stock Exchange. Berekening In deze index zijn alle bedrijven opgenomen die op de beurs van Tokio staan genoteerd in de First Section. Dit zijn de grootste en meest liquide aandelen die op de beurs worden verhandeld. Tot medio 2006 werd het gewicht van de individuele bedrijven in de index bepaald op basis van de marktkapitalisatie, hierna wordt ook de free float in de berekening meegenomen. Het effect van deze verandering was significant, daar veel Japanse bedrijven aandelen houden in andere Japanse bedrijven, ook wel bekend als crossholdings, om daarmee de langdurige zakenrelatie te onderstrepen. Deze belangen worden voor lange tijd gehouden en worden niet tot de free float gerekend. De index heeft 4 januari 1968 als startdatum, maar ging op 1 juli 1969 daadwerkelijk van start. Een andere belangrijke beursindex in Japan is de Nikkei 225. In deze index zijn 225 bedrijven opgenomen en dit is een prijsgewogen index. Samenstelling Eind maart 2021 bestond de index uit 2187 aandelen. Door het grote aantal aandelen is het gewicht van de individuele namen zeer klein. De top 10 aandelen hebben een gezamenlijk gewicht in de index van slechts 18,4% en de lijst zag er als volgt uit, met de gewichten tussen de haakjes: De belangrijkste drie sectoren zijn: elektronische apparatuur, informatie technologie en chemie. Deze drie vertegenwoordigen tezamen zo'n 34% van de index, waarvan de sector elektronische apparatuur het grootst is met een gewicht van 17,5%. Koershistorie De hoogste stand van deTOPIX steht für Tōkyō Stock Price Index (jap. , Tōshō kabuka shisū) und ist neben dem Nikkei 225 ein Kursindex der Tokioter Börse. Berechnet wird der TOPIX seit dem 1. Juli 1969. Die Index-Basis liegt bei 100 Punkten per 4. Januar 1968. Er enthält alle japanischen Aktien, welche im amtlichen Handel zugelassen sind. Die Gewichtung der einzelnen Unternehmen im Index erfolgt anhand der Marktkapitalisierung. Gegenwärtig (8. September 2021) setzt sich der Index aus 2.189 Aktien zusammen. Wegen dieser hohen Zahl an vertretenen Unternehmen wird der TOPIX als aussagekräftiger für den Zustand der japanischen Wirtschaft angesehen als der Nikkei 225. Weblinks Beschreibung des TOPIX (engl.) TOPIX in Echtzeit Jährliche Entwicklung des TOPIX seit 1949 (Daten vor 1969 – dem Einführungsjahr des TOPIX – sind rückgerechnet; XLS-Format, 31,5 KB; abgerufen am 12. Oktober 2017) Einzelnachweise Aktienindex Wirtschaft (Japan) Abkürzung, commonly known as TOPIX, along with the Nikkei 225, is an important stock market index for the Tokyo Stock Exchange (TSE) in Japan, tracking all domestic companies of the exchange's Prime market division. It is calculated and published by the TSE. , there were 1,669 companies listed on the First Section of the TSE, and the market value for the index was ¥197.4 trillion. The index transitioned from a system where a company's weighting is based on the total number of shares outstanding to a weighting based on the number of shares available for trading (called the free float). This transition took place in three phases starting in October 2005 and was completed in June 2006. Although the change is a technicality, it had a significant effect on the weighting of many companies in the index, because many companies in Japan hold a significant number of shares of their business partners as a part of intricate business alliances, and such shares are no longer included in calculating the weight of companies in the index. The TOPIX index is traded as a future on the Osaka Exchange under the ticker symbol JTPX. The CQG contract specifications for the TOPIX Index are listed below. TSE currently calculates and distributes TOPIX every second and further plans to launch a new High-Speed Index dissemination service provided at the millisecond level starting from February 28, 2011. History of TOPIX 1969-07-01 TSE to begin calculating and publishing “TOPIX” and “TOPIX Sector Indices” 1969-08-18 TSE to begin calculating and publishing “Tokyo Stock
|
||||
Think hard, but answer shortly and concisely. Only give direct answers to the questions. No additional explanations. Directly answer these questions:
|
||||
Q: Who was the man behind The Chipmunks??
|
||||
A: David Seville
|
||||
|
||||
Q: Which Lloyd Webber musical premiered in the US on 10th December 1993??
|
||||
A: Sunset Boulevard
|
||||
|
||||
Q: Who was the next British Prime Minister after Arthur Balfour??
|
||||
A: Campbell-Bannerman
|
||||
|
||||
Q: Who had a 70s No 1 hit with Kiss You All Over??
|
||||
A: Exile
|
||||
|
||||
Q: What claimed the life of singer Kathleen Ferrier??
|
||||
A: Cancer
|
||||
|
||||
Q: What is the Japanese share index called?
|
||||
A:
|
||||
==================================================
|
||||
|
||||
==================================================
|
||||
PROMPT #10:
|
||||
==================================================
|
||||
Man in the Music: The Creative Life and Work of Michael Jackson is a non-fiction book written by Joseph Vogel, published in June 2011 by the Sterling Publishing. Reception Man in the Music: The Creative Life and Work of Michael Jackson, was described by the Associated Press as "a fascinating read and really a must have for any fan of Jackson." Filmmaker Spike Lee characterized it as having "brilliantly cracked the DNA, the code, the artistry of Michael Joseph Jackson." References Works about Michael Jackson 2011 non-fiction books Sterling Publishing booksMoonwalk is a 1988 autobiography written by American recording artist Michael Jackson. The book was first published by Doubleday on February 1, 1988, five months after the release of Jackson's 1987 Bad album, and named after Jackson's signature dance move, the moonwalk. The book contains a foreword by Jacqueline Onassis. It reached number one on the New York Times Best Seller list. The book was reissued by Doubleday on October 13, 2009, following Jackson's death on June 25, 2009. Production Jacqueline Onassis, who was an editor at Doubleday, secured the book deal and paid Jackson a $300,000 advance. As part of the deal Jackson wanted Onassis to write a foreword, which she initially refused not wanting her name on any books she worked on but agreed to three paragraphs. She also edited the book. The first manuscript of the book was written by Robert Hilburn and was refused by the publishers, Doubleday, because it lacked "juicy details". A second manuscript was written by Stephen Davis, which Jackson drastically edited. Jackson finally decided to write the book himself, with help from Shaye Areheart. Due to the public interest in Jackson, Moonwalk was prepared for publication in secret. Relatives of Doubleday employees were hired as couriers, to deliver portions of the book from the company's head office in Manhattan to the printing plant in Fairfield, Pennsylvania. At the printing plant, the book was given the code name "Neil Armstrong", after the first "moonwalker". Narrative Dedicated to Fred Astaire, the book discusses Jackson's show business friends, girlfriends and hisMichael Jackson: Unauthorized in a 1994 biography of the late pop star Michael Jackson, written by celebrity biographer Christopher Andersen. Development According to Andersen, work started on the book in early 1991 when he received a call from a fellow journalist, who told him that two workers at Jackson's Neverland Ranch allegedly witnessed Jackson fondling a young celebrity. Andersen tried to interview Jackson several times, but was turned down. When Michael was publicly accused of child molestation in 1993, Andersen was told that he was under surveillance from investigators. Reception The book was largely overlooked by the public. Dana Kennedy of Entertainment Weekly felt that, with its "killer material", Anderson "probably could have retired from the celebrity-bio grind for good" had it been released five years before. People magazine found it to be a "sad book", considering its dark revelations about Jackson's behaviour. References 1994 non-fiction books Unauthorized biographies Works about the Michael Jackson sexual abuse allegations Biographies about musicians
|
||||
Think hard, but answer shortly and concisely. Only give direct answers to the questions. No additional explanations. Directly answer these questions:
|
||||
Q: Who was the man behind The Chipmunks??
|
||||
A: David Seville
|
||||
|
||||
Q: Which Lloyd Webber musical premiered in the US on 10th December 1993??
|
||||
A: Sunset Boulevard
|
||||
|
||||
Q: Who was the next British Prime Minister after Arthur Balfour??
|
||||
A: Campbell-Bannerman
|
||||
|
||||
Q: Who had a 70s No 1 hit with Kiss You All Over??
|
||||
A: Exile
|
||||
|
||||
Q: What claimed the life of singer Kathleen Ferrier??
|
||||
A: Cancer
|
||||
|
||||
Q: What was the name of Michael Jackson's autobiography written in 1988?
|
||||
A:
|
||||
==================================================
|
||||
|
||||
==================================================
|
||||
PROMPT #11:
|
||||
==================================================
|
||||
including popular titles by Sérgio Mendes and Herb Alpert were released with this audio process starting in September 1968. Other record labels soon followed suit, and an estimated 10% of all stereophonic albums released during the late 1960s and early 1970s employed the system. Other labels known to have used the system include Warner Bros. Records and Reprise Records. One of the biggest selling albums using the process is The Association's Greatest Hits, released in 1968. This recording has sold more than 2 million copies in the United States. The process was also used on the 1968 Frank Sinatra album Cycles as well as on most of the studio recordings on Wheels of Fire by Cream. Early 1968 copies of Neil Young's self-titled debut album also used the system. Use of Haeco-CSG in promotional recordings for radio The original intention of using Haeco-CSG on commercial LP releases was rather short lived, however, use of the process continued well into the mid-1970s on promotional records sent to radio stations. Many commercial FM Rock stations did not transition from mono to stereo broadcasting until the mid to late 1970s. AM Pop music stations continued to broadcast in mono, as AM stereo broadcasting was not introduced until 1982 and was never widely adopted. Many promotional singles and some commercial singles from the Warner/Reprise/Atlantic label group from this era had "CSG Mono Process" or "CSG Process" printed on the labels. Artists included Frank Sinatra, Gordon Lightfoot, James Taylor, Seals and Crofts. Warner subsidiary labels such as Atlantic issued a serieswas introduced to the public on December 13, 1957, at the Times Auditorium in New York City. 500 copies of this initial demonstration record were pressed. On December 16, 1957, Frey advertised in the trade magazine Billboard that he would send a free copy to anyone in the industry who wrote to him on company letterhead. Frey became known as "Mr. Stereo" during that era. Stereophonic sound was not entirely new to the public. In 1952 sound engineer Emory Cook developed a "Binaural" disk that used two separate grooves and playback needles to produce stereophonic sound; the following year he had a catalog of about 25 disks available for audiophiles. Multi-channel sound was integral to the widescreen motion picture processes Cinerama (1952) and CinemaScope (1953). Stereophonic audio tapes had been commercially available to audiophiles, although expensive, since the early-1950s. After the release of the Audio Fidelity demonstration disks, the other spur to the popularity of stereo disks was the reduction in price of a stereo magnetic cartridge, for playing the disks, from $250 to $29.95 in June 1958. The first four stereo discs available to the general public were released by Audio Fidelity in March, 1958--Johnny Puleo and his Harmonica Gang Volume 1 (AFSD 5830), Railroad - Sounds of a Vanishing Era (AFSD 5843), Lionel - Lionel Hampton and his Orchestra (AFSD 5849) and Marching Along with the Dukes of Dixieland Volume 3 (AFSD 5851). By the end of March the company had four more stereo LPs available. In the summer of 1958, Audio Fidelity recordedin 1957, with his Essex Records office manager George Phillips, he founded Somerset Records and Somerset Stereo Fidelity Records budget albums. His greatest claim to fame was selling large amounts of cheaply priced albums, with Somerset claiming to have manufactured the first stereo budget albums. The name of Somerset high fidelity albums was suggested by Miller International's West Coast distributor, Jimmy Warren, with the name of Stereo Fidelity (stereo albums) thought of by Wally Hill to capitalize on the public's interest in both high fidelity and stereophonic sound. The economy came from Miller starting his own record factory in Swarthmore, Pennsylvania, using public domain music and non union musicians from outside the United States to record cover versions of hit songs of the time. Many original tunes were written by Monty Kelly, Robert Lowden, and Joseph Kuhn with the music published by Miller's own music publisher, Chesdel Music created in 1962. Miller had his own distribution channels of his records in supermarkets and drugstores with the cheap albums being sold in metal racks similar to those holding paperback books or cardboard record holders called "dumps" that could be placed anywhere. Miller's record albums were sold wholesale for 93 cents to salesmen who sold them to merchants who sold them to the public for $1.98. Somerset Records used artist Anthony "Chic" Laganella to create attractive eye catching album covers. Miller used the name 101 Strings for several German orchestras; their first album appearing in September 1957. In 1958 Somerset released 24 101 Strings titles. Miller International's philosophy
|
||||
Think hard, but answer shortly and concisely. Only give direct answers to the questions. No additional explanations. Directly answer these questions:
|
||||
Q: Who was the man behind The Chipmunks??
|
||||
A: David Seville
|
||||
|
||||
Q: Which Lloyd Webber musical premiered in the US on 10th December 1993??
|
||||
A: Sunset Boulevard
|
||||
|
||||
Q: Who was the next British Prime Minister after Arthur Balfour??
|
||||
A: Campbell-Bannerman
|
||||
|
||||
Q: Who had a 70s No 1 hit with Kiss You All Over??
|
||||
A: Exile
|
||||
|
||||
Q: What claimed the life of singer Kathleen Ferrier??
|
||||
A: Cancer
|
||||
|
||||
Q: In which decade did stereo records first go on sale?
|
||||
A:
|
||||
==================================================
|
||||
|
||||
==================================================
|
||||
PROMPT #12:
|
||||
==================================================
|
||||
Flack in 1896) to win gold medals in both the 800 m and 1500 m in the same Olympics. Billy Mills, an unfancied runner, became the only American to win the gold in the men's 10,000 m. Bob Hayes won the 100 metre title in a time of 10.06 seconds, equaling the world record, and set the current record for the fastest relay leg in the 4×100 m. Joe Frazier, future heavyweight champion of the world, won a gold medal in heavyweight boxing while competing with a broken thumb. This was the last Summer Olympics to use a cinder running track for athletic events, and the first to use fiberglass poles for pole vaulting. Zambia declared its independence on the day of the closing ceremony of the 1964 Summer Olympics, thereby becoming the first country ever to have entered an Olympic games as one country, and left it as another. This was celebrated in the ceremony itself by the team using a placard with "Zambia" instead of the "Northern Rhodesia" placard from the opening ceremony. Zambia was the only team to use a placard in the closing ceremony. The start of operations for the first Japanese "bullet train" (the Tōkaidō Shinkansen) between Tokyo Station and Shin-Ōsaka Station was scheduled to coincide with the Olympic games. The first regularly scheduled train ran on 1 October 1964, just nine days before the opening of the games, transporting passengers in about four hours, and connecting the three major metropolitan areas of Tokyo, Nagoya, and Osaka. Ranatunge Karunananda who representedsystems were used: official hand timing, hand started photo-finish times, and the Gustavus Town Kirby timing device, which was designed by Kirby to determine the correct order of finish in horse races. The official report for 1932 Olympics states: "In addition to hand timing, two auxiliary electrical timing devices were used. Both were started by an attachment to the starters gun. One was stopped by hand at the time the runners hit the tape. The other was provided with a motion picture camera which photographed the runner at the tape and the dial of the time indicator simultaneously." Kirby's system was also used at the 1932 US. Olympic Trials, where Ralph Metcalfe's winning time of 10.62 in the 100 meters is considered possibly the first automatically timed world record. FAT was also used in 1936, but very few times have been found. In 1948, Bulova began developing the Phototimer, a unique combination of photo-finish camera and precision electronic timing instrument. The Phototimer was the first automatic timing device to be used in competitive sports. It was used extensively in North America, including at the 1948 US Olympic trials. The Bulova device was activated by the sound of the starting gun firing, rather than by a direct connection, which means that the times were around 0.02 seconds faster than reality. The 1948 Olympics, however, continued to use Omega timing with a device called the 'Magic Eye', developed by British Race Finish Recording Co. Ltd. The automatic times produced in the 1948 Olympics have never been released, butWhile the most notable story coming out of 1968 was socio-political, politics involved with the Olympics was not something unique to this year. However, the year marked the beginning of several emerging elements of contemporary track and field. Automatic timing While timing to the 100th of a second had been experimented with for many years, the 1968 Summer Olympics were the first to use Fully Automatic Timing, in not only athletics, but in canoeing, rowing, cycling, equestrian and swimming competitions. Subsequently, systems to record such times became more common and thus the accuracy of Fully Automatic Timing became mandated for World Record acceptance. While this rule was officially put into place in 1977, many 1968 records still stood as the first Automatically timed record. All weather tracks This technology too had been developing, but Tartan tracks were used as the competition surface for the first time at an Olympics. Since then an all-weather running track was required for all top-level competition. Subsequently, the inconsistency of the running surface became a significantly smaller factor in athletic performance. Altitude With the Olympics happening in Mexico City, at high altitude, the effect of the thin air on athletic performance became a factor on world records. This was already a known phenomenon, and the American team was selected by holding the Olympic Trials at high altitude at Echo Summit, California. In 1955, Lou Jones set the world record in the 400 meters at altitude in Mexico City. Following the 1968 Summer Olympics the: Men's 100 meters record, set by Jim
|
||||
Think hard, but answer shortly and concisely. Only give direct answers to the questions. No additional explanations. Directly answer these questions:
|
||||
Q: Who was the man behind The Chipmunks??
|
||||
A: David Seville
|
||||
|
||||
Q: Which Lloyd Webber musical premiered in the US on 10th December 1993??
|
||||
A: Sunset Boulevard
|
||||
|
||||
Q: Who was the next British Prime Minister after Arthur Balfour??
|
||||
A: Campbell-Bannerman
|
||||
|
||||
Q: Who had a 70s No 1 hit with Kiss You All Over??
|
||||
A: Exile
|
||||
|
||||
Q: What claimed the life of singer Kathleen Ferrier??
|
||||
A: Cancer
|
||||
|
||||
Q: In what year's Olympics were electric timing devices and a public-address system used for the first time?
|
||||
A:
|
||||
==================================================
|
||||
|
||||
==================================================
|
||||
PROMPT #13:
|
||||
==================================================
|
||||
A list of stratovolcanoes follows below. Africa Cameroon Mount Cameroon Democratic Republic of Congo Mount Nyiragongo, Goma; designated as a Decade Volcano It contains an active lava lake inside its crater which overflowed due to cracks in 2002. Mount Mikeno Eritrea Alid Volcano Dubbi Volcano Nabro Volcano Ethiopia Adwa Borawli, Afar Region Dabbahu Volcano Mount Fentale Kenya Mount Kenya, which contains several volcanic plugs on its peak. Mount Longonot Rwanda Mount Bisoke, on the border between Rwanda and the Democratic Republic of the Congo. Mount Gahinga, on the border between Rwanda and Uganda. Mount Karisimbi, on the border between Rwanda and the Democratic Republic of the Congo. Mount Muhabura, on the border between Rwanda and Uganda. Mount Sabyinyo, marks the border between Rwanda, Uganda, and the Democratic Republic of the Congo. Tanzania Ol Doinyo Lengai, the Earth's only active carbonatite lava-producing volcano. Mount Kilimanjaro, a dormant stratovolcano. It is the highest point of Africa. Mount Meru Mid-Atlantic Ridge Mount Pico in Pico Island, Azores, Portugal Teide in Tenerife, Canary Islands, Spain; designated as a Decade Volcano Cumbre Vieja in La Palma, Canary Islands, Spain Mount Fogo in Fogo, Cape Verde Green Mountain, Ascension Island Pico de las Nieves in Gran Canaria, Canary Islands, Spain Americas Caribbean La Grande Soufrière on Basse-Terre Island, Guadeloupe Soufriere Hills on the island Montserrat Its 1995 eruptions resulted in the abandonment of its capital city, Plymouth. Soufrière on the island Saint Vincent Mount Pelée on the island Martinique Its devastating eruption on 8 May 1902 resulted in the complete destruction ofMount Kilimanjaro is a volcano in Tanzania and the highest mountain in Africa. Kilimanjaro may also refer to: Tanzania Kilimanjaro National Park comprises the whole of Mount Kilimanjaro above the tree line and six forest corridors stretching down Kilimanjaro Region, a region in Tanzania Kilimanjaro (ward), a ward in the Moshi Urban district of Kilimanjaro Region, Tanzania Kilimanjaro International Airport in Tanzania a Tanzanian beer, see Beer in Africa#Eastern Africa a Tanzanite jewellery brand owned by F. Hinds Music Killamanjaro, a Jamaican reggae sound system Albums Kilimanjaro, an album by German artist Superpitcher Kilimanjaro (The Rippingtons album), a 1988 album by The Rippingtons Kilimanjaro (The Teardrop Explodes album), an album by The Teardrop Explodes Songs "Kilimanjaro", song by The Del Vikings 1962 "Kilimanjaro", song by Manhattan Brothers 1955 "Kilimanjaro", song by The Teardrop Explodes 1980 "Kilimanjaro", song by Juluka 1984 "Kilimandjaro" (song), a 1966 French-language song by French singer Pascal Danel "Kilimanjaro" (song), a 2010 song by A.R. Rahman from the film Enthiran "Kilimanjaro", a song by KSI from the 2016 extended play Keep Up Film Kilimanjaro (film), a 2013 American film Nigeria Kilimanjaro restaurant, a fast-food chain in Nigeria. See also The Snows of Kilimanjaro (disambiguation)Mount Kilimanjaro () is a dormant volcano located in Kilimanjaro Region of Tanzania. It has three volcanic cones: Kibo, Mawenzi, and Shira. It is the highest mountain in Africa and the highest single free-standing mountain above sea level in the world: above sea level and about above its plateau base. It is the highest volcano in Africa and the Eastern Hemisphere. Kilimanjaro is the fourth most topographically prominent peak on Earth. It is part of Kilimanjaro National Park and is a major hiking and climbing destination. Because of its shrinking glaciers and ice fields, which are projected to disappear between 2025 and 2035, it has been the subject of many scientific studies. Toponymy The origin of the name Kilimanjaro is not known, but a number of theories exist. European explorers had adopted the name by 1860 and reported that Kilimanjaro was the mountain's Kiswahili name. The 1907 edition of The Nuttall Encyclopædia also records the name of the mountain as Kilima-Njaro. Johann Ludwig Krapf wrote in 1860 that Swahilis along the coast called the mountain Kilimanjaro. Although he did not offer any support, he claimed that Kilimanjaro meant either mountain of greatness or mountain of caravans. Under the latter meaning, kilima meant mountain and jaro meant caravans. Jim Thompson claimed in 1885, again without support, that the term Kilima-Njaro "has generally been understood to mean" the mountain (kilima) of greatness (njaro). He also suggested "though not improbably it may mean" the white mountain. Njaro is an ancient Kiswahili word for shining. Similarly, Krapf wrote that a
|
||||
Think hard, but answer shortly and concisely. Only give direct answers to the questions. No additional explanations. Directly answer these questions:
|
||||
Q: Who was the man behind The Chipmunks??
|
||||
A: David Seville
|
||||
|
||||
Q: Which Lloyd Webber musical premiered in the US on 10th December 1993??
|
||||
A: Sunset Boulevard
|
||||
|
||||
Q: Who was the next British Prime Minister after Arthur Balfour??
|
||||
A: Campbell-Bannerman
|
||||
|
||||
Q: Who had a 70s No 1 hit with Kiss You All Over??
|
||||
A: Exile
|
||||
|
||||
Q: What claimed the life of singer Kathleen Ferrier??
|
||||
A: Cancer
|
||||
|
||||
Q: Which volcano in Tanzania is the highest mountain in Africa?
|
||||
A:
|
||||
==================================================
|
||||
|
||||
==================================================
|
||||
PROMPT #14:
|
||||
==================================================
|
||||
of the Libyan Draft Constitutional Charter for the Transitional Stage: The national flag shall have the following shape and dimensions: Its length shall be double its width, its shall be divided into three parallel coloured stripes, the uppermost being red, the centre black and lowest green, the black stripe shall be equal in area to the other two stripes together and shall bear in its centre a white crescent, between the two extremities of which there shall be a five-pointed white star. On 10 March 2011, France was the first country to recognise the council as the official government of Libya, as well as the first to allow the Libyan embassy staff to raise the flag. On 21 March, the flag was flown by the Permanent Mission of Libya to the United Nations and appeared on their official website, and thereafter in late August by the Arab League and by Libya's own telecommunications authority, the Libya Telecom & Technology, on its own website. In the following months many other Libyan embassies replaced the green flag of Gaddafi with the tricolour flag. This original flag of Libya is now the only flag used by the United Nations to represent Libya, according to the following UN statement: "Following the adoption by the General Assembly of resolution 66/1, the Permanent Mission of Libya to the United Nations formally notified the United Nations of a Declaration by the National Transitional Council of 3 August 2011 changing the official name of the Libyan Arab Jamahiriya to 'Libya' as well as athe flag's colours and symbols. According to Omar Faiek Shennib, "red was selected for the blood sacrificed for the freedom of Libya, black to remember the dark days that Libyans lived under the occupation of the Italians and green to represent its primary wealth, agriculture, [Libya once being referred to as the 'agricultural basket' or 'breadbasket' of the Ottoman Empire] and the future prosperity of the country. The star and crescent were placed within the black central strip of the flag as a reference to the Senussi flag and the role of King Idris in leading the country to independence". The flag's colours also echo the colours of the flags of the three regions of Libya: Fezzan (red), Cyrenaica (black), and Tripolitania (green). Under Muammar Gaddafi's dictatorship, Libya had a red-white-black flag from 1969 to 1977, and it was replaced by the all-green flag from 1977 to 2011, during which it was the only flag in the world to have one color and no design. During the Libyan Civil War against the rule of Muammar Gaddafi, the 1951–69 flag – as well as various makeshift versions without the crescent and star symbol, or without the green stripe – came back into use in areas held by the Libyan opposition and by protesters at several Libyan diplomatic missions abroad. The National Transitional Council, formed on 27 February 2011, adopted the flag previously used in the Kingdom of Libya between 1951 and 1969 as the "emblem of the Libyan Republic". The flag was officially defined in article threeThe flag of Libya from 1977 to 2011 was used by the Socialist People's Libyan Arab Jamahiriya from 1977 to 1986 and later the Great Socialist People's Libyan Arab Jamahiriya until 2011. The design is a green field in 1:2 ratio and was considered the only solid colour national flag in the world during its time. In 2011, after the collapse of Gaddafi's government, the 1951–1969 flag from the Kingdom of Libya was re-adopted but the flag introduced by Gaddafi remained in use by Pro-Gaddafists and Gaddafi loyalists. Before 1977, the country was called the Libyan Arab Republic from 1969 to 1977 and used a red-white-black flag similar to most traditional Arab national flags bearing a resemblance to the modern flag of Yemen. in 1977 after the Egyptian-Libyan War, the blank green flag was introduced to replace the red-white-black flag to avoid similarities with Egypt. History of Libya under Muammar Gaddafi Flags introduced in 1977 1977 establishments in Libya 2011 disestablishments in Libya
|
||||
Think hard, but answer shortly and concisely. Only give direct answers to the questions. No additional explanations. Directly answer these questions:
|
||||
Q: Who was the man behind The Chipmunks??
|
||||
A: David Seville
|
||||
|
||||
Q: Which Lloyd Webber musical premiered in the US on 10th December 1993??
|
||||
A: Sunset Boulevard
|
||||
|
||||
Q: Who was the next British Prime Minister after Arthur Balfour??
|
||||
A: Campbell-Bannerman
|
||||
|
||||
Q: Who had a 70s No 1 hit with Kiss You All Over??
|
||||
A: Exile
|
||||
|
||||
Q: What claimed the life of singer Kathleen Ferrier??
|
||||
A: Cancer
|
||||
|
||||
Q: The flag of Libya is a plain rectangle of which color?
|
||||
A:
|
||||
==================================================
|
||||
|
||||
==================================================
|
||||
PROMPT #15:
|
||||
==================================================
|
||||
la Francophonie. Places of worship Niger being a predominantly Muslim country, mosques are the most common places of worship, with the Grande Mosquée being the largest in the city. There are also various Christian churches, most notably Our Lady of Perpetual Help Cathedral and the Cathedral de Maourey. Governance Administration Niamey makes up a special capital district of Niger, which is surrounded by the Region of Tillabéri. The city of Niamey itself is governed as an autonomous first-level administrative block, the Niamey Urban Community (Fr. Communauté Urbaine de Niamey, or CUN). It includes five Urban Communes, divided into 44 "Districts" and 99 "Quartiers", including formerly independent towns. It is a co-equal first division subdivision with the seven Regions of Niger. The Niamey Urban Community includes an administration and Governor appointed by national leaders. Like the rest of Niger, Niamey has seen a decentralisation of governance since 2000. Government Ordinance n°2010–56 and Presidential Decree n°2010-679 of September 2010 mandated an elected City Council for the city of Niamey, subsumed under the CUN. This excludes some outlying areas of the CUN. Forty-five councillors are popularly elected and in turn elect the Mayor of the City of Niamey. In July 2011, the first Mayor under the new system, Oumarou Dogari Moumouni, was installed by the Governor of the CUN Mrs. Aïchatou Boulama Kané and the City Council. The City Council and Mayor have limited roles compared to the CUN Governor. Niamey has a third layer of government in the Commune system. Each Commune elects its own council, and outsidein Niger Niamey NigerNiamey () is the capital and largest city of Niger. Niamey lies on the Niger River, primarily situated on the east bank. Niamey's population was counted as 1,026,848 as of the 2012 census. As of 2017, population projections show the capital district growing at a slower rate than the country as a whole, which has the world's highest fertility rate. The city is located in a pearl millet growing region, while manufacturing industries include bricks, ceramic goods, cement, and weaving. History Niamey was probably founded in the 18th century and originated as a cluster of small villages (Gaweye, Kalley, Maourey, Zongo and Foulani Koira). Niamey was of little importance until the French developed it as a colonial centre in the late 1890s. The town, then with an estimated population of some 1,800, was chosen as the capital of the newly created Military Territory of Niger in 1905, however, the capital was shifted to the more established city of Zinder in 1912. Zinder's proximity to the Nigerian border and distance from French-controlled ports prompted the French to move the capital back to Niamey in 1926, by which time the city had some 3,000 inhabitants. A series of devastating droughts prompted significant population growth during this period, and by 1945 the population was about 8,000. Prior to 1926-27 the Upper Volta-Niger border ran along the Niger river, meaning that Niamey lay directly on the boundary. At the time of independence in 1960 the population had grown to around 30,000. The period from 1970 to 1988 was one in
|
||||
Think hard, but answer shortly and concisely. Only give direct answers to the questions. No additional explanations. Directly answer these questions:
|
||||
Q: Who was the man behind The Chipmunks??
|
||||
A: David Seville
|
||||
|
||||
Q: Which Lloyd Webber musical premiered in the US on 10th December 1993??
|
||||
A: Sunset Boulevard
|
||||
|
||||
Q: Who was the next British Prime Minister after Arthur Balfour??
|
||||
A: Campbell-Bannerman
|
||||
|
||||
Q: Who had a 70s No 1 hit with Kiss You All Over??
|
||||
A: Exile
|
||||
|
||||
Q: What claimed the life of singer Kathleen Ferrier??
|
||||
A: Cancer
|
||||
|
||||
Q: Of which African country is Niamey the capital?
|
||||
A:
|
||||
==================================================
|
||||
|
||||
==================================================
|
||||
PROMPT #16:
|
||||
==================================================
|
||||
James Walter McCord Jr. (January 26, 1924 – June 15, 2017) was an American CIA officer, later head of security for President Richard Nixon's 1972 reelection campaign. He was involved as an electronics expert in the burglaries which precipitated the Watergate scandal. Career McCord was born in Waurika, Oklahoma. He served as a bombardier with the rank of second lieutenant in the Army Air Forces during World War II. He briefly attended Baylor University before receiving a B.B.A. from the University of Texas at Austin in 1949. In 1965, he received an M.S. in international affairs from George Washington University. After beginning his career at the Federal Bureau of Investigation (FBI), McCord worked for the Central Intelligence Agency (CIA), ultimately ascending to the GS-15 directorship of the Agency's Office of Security. For a period of time, he was in charge of physical security at the Agency's Langley headquarters. L. Fletcher Prouty, a former colonel in the United States Air Force, claimed then-Director of Central Intelligence Allen Dulles introduced McCord to him as "my top man.". In 1961, under his direction, a counter-intelligence program was launched against the Fair Play for Cuba Committee. He also held the rank of lieutenant colonel in the United States Air Force Reserve. Watergate scandal Shortly after resigning from the CIA, McCord was interviewed and then hired by Jack Caulfield in January 1972 "for strict, solely defensive security work at the Republican National Committee (RNC) and the Committee to Re-Elect the President (CRP)." Some of the money from this contract came fromadministration as assistant director of the Bureau of the Budget, devoting most of his time to Defense matters. In 1971, President Nixon appointed Schlesinger a member of the Atomic Energy Commission (AEC) and designated him as chairman. Serving in this position for about a year and a half, Schlesinger instituted extensive organizational and management changes in an effort to improve the AEC's regulatory performance. CIA Director Schlesinger was CIA Director from February 2, 1973, to July 2, 1973. He was succeeded by William Colby. Schlesinger was extremely unpopular with CIA staff, as he reduced CIA staff by 7%, and was considered a Nixon loyalist seeking to make the agency more obedient to Nixon. He had a CCTV camera installed near his official portrait at the CIA headquarters in Langley, Va., as it was believed that vandalism of the portrait by disgruntled staff was likely. Secretary of Defense (1973–1975) Schlesinger left the CIA to become Secretary of Defense on July 2, aged 44. As a university professor, researcher at Rand, and government official in three agencies, he had acquired an impressive resume in national security affairs. Nuclear strategy Shortly after assuming office, Schlesinger outlined the basic objectives that would guide his administration: maintain a "strong defense establishment"; "assure the military balance so necessary to deterrence and a more enduring peace"; obtain for members of the military "the respect, dignity and support that are their due"; assume "an . . . obligation to use our citizens' resources wisely"; and "become increasingly competitive with potential adversaries.... [W]e must nota conventional North Vietnamese assault in 1975. The CORDS model and its approach influenced U.S. strategy and thinking on counterinsurgency in the 2000s in Iraq and Afghanistan. CIA HQ: Director Colby returned to Washington in July 1971 and became executive director of CIA. After long-time DCI Richard Helms was dismissed by President Nixon in 1973, James Schlesinger assumed the helm at the Agency. A strong believer in reform of the CIA and the intelligence community more broadly, Schlesinger had written a 1971 Bureau of the Budget report outlining his views on the subject. Colby, who had had a somewhat unorthodox career in the CIA focused on political action and counterinsurgency, agreed with Schlesinger's reformist approach. Schlesinger appointed him head of the clandestine branch in early 1973. When Nixon reshuffled his agency heads and made Schlesinger secretary of defense, Colby emerged as a natural candidate for DCI—apparently on the basis of the recommendation that he was a professional who would not make waves. Colby was known as a media-friendly CIA director. His tenure as DCI, which lasted two and a half tumultuous years, was overshadowed by the Church and Pike congressional investigations into alleged U.S. intelligence malfeasance over the preceding 25 years, including 1975, the so-called Year of Intelligence. Colby's time as DCI was also eventful on the world stage. Shortly after he assumed leadership, the Yom Kippur War broke out, an event that surprised not only the American intelligence agencies but also the Israelis. This intelligence surprise reportedly affected Colby's credibility with the Nixon administration. Colby
|
||||
Think hard, but answer shortly and concisely. Only give direct answers to the questions. No additional explanations. Directly answer these questions:
|
||||
Q: Who was the man behind The Chipmunks??
|
||||
A: David Seville
|
||||
|
||||
Q: Which Lloyd Webber musical premiered in the US on 10th December 1993??
|
||||
A: Sunset Boulevard
|
||||
|
||||
Q: Who was the next British Prime Minister after Arthur Balfour??
|
||||
A: Campbell-Bannerman
|
||||
|
||||
Q: Who had a 70s No 1 hit with Kiss You All Over??
|
||||
A: Exile
|
||||
|
||||
Q: What claimed the life of singer Kathleen Ferrier??
|
||||
A: Cancer
|
||||
|
||||
Q: Who was the director of the CIA from 1976-81?
|
||||
A:
|
||||
==================================================
|
||||
|
||||
==================================================
|
||||
PROMPT #17:
|
||||
==================================================
|
||||
"On the Street Where You Live" is a song with music by Frederick Loewe and lyrics by Alan Jay Lerner from the 1956 Broadway musical My Fair Lady. It is sung in the musical by the character Freddy Eynsford-Hill, who was portrayed by John Michael King in the original production. In the 1964 film version, it was sung by Bill Shirley, dubbing for actor Jeremy Brett. Recorded versions The most popular single of the song was recorded by Vic Damone in 1956 for Columbia Records. It reached No. 4 on the Billboard chart and No. 6 on Cashbox magazine's chart. It was a No. 1 hit in the UK Singles Chart in 1958. Eddie Fisher also had a top 20 Billboard hit with the song in 1956, reaching No. 18. Lawrence Welk and His Orchestra released a version that went to No. 96 in 1956. Andy Williams' recording appeared in the Billboard top 40 in 1964, reaching No. 3 on the adult contemporary chart and No. 28 on the Billboard Hot 100. The song has been recorded by a wide variety of other performers, including Ray Conniff and Bing Crosby, who recorded the song in 1956 for use on his radio show and it was subsequently included in the boxed set The Bing Crosby CBS Radio Recordings (1954–56) issued by Mosaic Records (catalog MD7-245) in 2009, Lawrence Welk (whose band also performed it on his weekly TV series numerous times), Shirley Horn, Doris Day, George Shearing, Frank Chacksfield, Alfie Boe, Bobby Darin, Dean Martin, Mario Lanza,The Times praised it as "Alan Jay Lerner's terrific autobiography". The Street Where I Live was reissued in 1989 by Columbus Books and in 1994 by the Da Capo Press. In 2000, BBC radio broadcast a serialization of the book, read by Henry Goodman, which The Times called "one of the delights of the evening schedule". References Sources Non-fiction books about musical theatre"On the Street Where You Live" is a song from the 1956 Broadway musical My Fair Lady. On the Street Where You Live may also refer to: On the Street Where You Live (TV series), an Irish documentary television series On The Street Where You Live, a 2001 novel by Mary Higgins Clark
|
||||
Think hard, but answer shortly and concisely. Only give direct answers to the questions. No additional explanations. Directly answer these questions:
|
||||
Q: Who was the man behind The Chipmunks??
|
||||
A: David Seville
|
||||
|
||||
Q: Which Lloyd Webber musical premiered in the US on 10th December 1993??
|
||||
A: Sunset Boulevard
|
||||
|
||||
Q: Who was the next British Prime Minister after Arthur Balfour??
|
||||
A: Campbell-Bannerman
|
||||
|
||||
Q: Who had a 70s No 1 hit with Kiss You All Over??
|
||||
A: Exile
|
||||
|
||||
Q: What claimed the life of singer Kathleen Ferrier??
|
||||
A: Cancer
|
||||
|
||||
Q: Which musical featured the song The Street Where You Live?
|
||||
A:
|
||||
==================================================
|
||||
|
||||
==================================================
|
||||
PROMPT #18:
|
||||
==================================================
|
||||
engineers were ordered to end construction work. The Allies were unaware of this and mounted further attacks on the site as part of the United States Army Air Forces experimental Operation Aphrodite, involving radio-controlled B-24 Liberators packed with explosives. Two such attacks were mounted but failed; in the second such attack, on 12 August, Lt Joseph P. Kennedy, Jr. – the elder brother of future US President John F. Kennedy – was killed when the drone aircraft exploded prematurely. By the end of the bombing campaign, over 4,100 tons of bombs had been dropped on Mimoyecques, more than on any other V-weapons site. The Mimoyecques site was never formally abandoned, but German forces left it at the start of September 1944 as the Allies advanced northeast from Normandy towards the Pas de Calais. It was captured on 5 September by the Canadian 3rd Infantry Division. Subsequent investigations and attempted demolition In September 1944, Duncan Sandys ordered the constitution of a Technical Inter-Services Mission under Colonel T.R.B. Sanders. It was given the task of investigating the V-weapons sites at Mimoyecques, Siracourt, Watten, and Wizernes, collectively known to the Allies as the "Heavy Crossbow" sites. Sanders' report was submitted to the War Cabinet on 19 March 1945. Even at this stage the true purpose of the site was unclear. Claims that it had been intended to be used for "electro-magnetic projectors" (railguns), firing huge shells at London, were debunked by Lord Cherwell, Winston Churchill's scientific adviser, who calculated that it would take sixty times the output of Battersearesearched at a facility in Peenemünde along with the V-1 flying bomb. The V-2's first target was Paris on 8 September 1944. The program while advanced proved to be an impediment to the war economy. The large capital investment was not repaid in military effectiveness. The rockets were built at an underground factory at Mittelwerk. Labor to build the A4 rockets came from the Mittelbau-Dora concentration camp. Of the 60,000 people who ended up at the camp 20,000 died, due to the appalling conditions. On 14 April 1944, Speer lost control of Organisation Todt to his Deputy, Franz Xaver Dorsch. He opposed the assassination attempt against Hitler on 20 July 1944. He was not involved in the plot, and played a minor role in the regime's efforts to regain control over Berlin after Hitler survived. After the plot Speer's rivals attacked some of his closest allies and his management system fell out of favor with radicals in the party. He lost yet more authority. Defeat of Nazi Germany Losses of territory and a dramatic expansion of the Allied strategic bombing campaign caused the collapse of the German economy from late 1944. Air attacks on the transport network were particularly effective, as they cut the main centres of production off from essential coal supplies. In January 1945, Speer told Goebbels that armaments production could be sustained for at least a year. However, he concluded that the war was lost after Soviet forces captured the important Silesian industrial region later that month. Nevertheless, Speer believed that Germany shouldof 1944 the Allies continued their gains in the Mediterranean Theatre and massed men and materiel for a European invasion along the French channel coastline. The conspirators began to organize for another attempt to assassinate Hitler and take over both German civil government and its military. The von Stauffenberg bomb attempt and aftermath By the summer of 1944 unrest in the German military and diplomatic ranks was widespread. The Allied landing at Normandy in June and failed German response raised the specter of doom among the upper ranks even of German field marshals. The Schwarze Kapelle responded by organizing a deadly attempt on Hitler's life at his Wolf's Lair compound in East Prussia. Undertaken by an aristocratic member of a hereditarily military family, Colonel Claus von Stauffenberg, the July 20 Plot nearly succeeded. Although surrounded by fatalities from the bomb Hitler escaped with a concussion and various injuries. In the aftermath he was determined to get vengeance upon the plotters. The Gestapo rounded up the members of the Schwarze Kapelle and many, many more it believed were either implicated in or sympathetic to it; according to its records it put 7,000 of them to death. Stauffenberg and three others were summarily shot that night. Most of the conspirators were put on trial in the Volksgerichtshof (People's Court) between August 1944 to February 1945. Many were executed the day after their convictions by hanging from meat hooks at Plötzensee Prison. Architect of the 1943 bomb plot on Hitler's plane Fabian von Schlabrendorff only escaped death because an
|
||||
Think hard, but answer shortly and concisely. Only give direct answers to the questions. No additional explanations. Directly answer these questions:
|
||||
Q: Who was the man behind The Chipmunks??
|
||||
A: David Seville
|
||||
|
||||
Q: Which Lloyd Webber musical premiered in the US on 10th December 1993??
|
||||
A: Sunset Boulevard
|
||||
|
||||
Q: Who was the next British Prime Minister after Arthur Balfour??
|
||||
A: Campbell-Bannerman
|
||||
|
||||
Q: Who had a 70s No 1 hit with Kiss You All Over??
|
||||
A: Exile
|
||||
|
||||
Q: What claimed the life of singer Kathleen Ferrier??
|
||||
A: Cancer
|
||||
|
||||
Q: "Who was the target of the failed ""Bomb Plot"" of 1944?"
|
||||
A:
|
||||
==================================================
|
||||
|
||||
==================================================
|
||||
PROMPT #19:
|
||||
==================================================
|
||||
propelling him into the first rank of international superstars. The album contained the number-one hit "All Night Long", a Caribbean-flavored dance number that was promoted by a colorful music video produced by former Monkee Michael Nesmith. In 1984, he performed "All Night Long" at the ending ceremony of the XXIII Olympic Games in Los Angeles. Several more Top 10 hits followed, the most successful of which was the ballad "Hello" (1984), a sentimental love song that showed how far he had moved from his R&B roots. Richie had three more top ten hits in 1984, "Stuck on You" (No. 3), "Running with the Night" (No. 7) and "Penny Lover" (No. 8), as well as writing and producing "Missing You" for former labelmate and duet partner Diana Ross (No. 10 Pop, No. 1 R&B). In 1985, he wrote and performed "Say You, Say Me" for the film White Nights. The song won an Academy Award and reached No. 1 on the U.S. charts, staying there for four weeks, making it the number-two song of 1986 according to Billboards Year-End Hot 100 chart, behind the charity single "That's What Friends Are For" by Dionne and Friends. He also collaborated with Michael Jackson on the charity single "We Are the World" by USA for Africa, another number-one hit. In 1986, Richie released Dancing on the Ceiling, his last widely popular album, which produced a run of five US and UK hits, "Say You, Say Me" (U.S. No. 1), "Dancing on the Ceiling" (U.S. No. 2), "Love Will Conquer All"top 20 US R&B chart hit in 1972. Their first few recordings were released on Buddah Records, including "Hold Back the Night", which was a hit on the Billboard R&B chart in 1973, before a re-release saw it climb in the UK two years later. Several R&B hits followed during a stay with Philadelphia International subsidiary Golden Fleece (run by Baker-Harris-Young) before they signed to Atlantic Records. Their single "Disco Inferno" (1976), which was included on the Grammy Award-winning Saturday Night Fever: The Original Movie Sound Track in 1977, reached No. 11 on the Billboard Hot 100 chart in May 1978. Other major hits included "Hold Back the Night" (1975) (UK No. 5) and "That's Where the Happy People Go" (1976). In late 1977, the Trammps released the song "The Night the Lights Went Out" to commemorate the electrical blackout that affected New York City on July 13–14, 1977. Their signature song "Disco Inferno" has been covered by Tina Turner and Cyndi Lauper. In addition, Graham Parker covered "Hold Back the Night" on "The Pink Parker EP" in 1977, and reached No. 24 in the UK Singles Chart, and top 60 in the US. In 2021, "Disco Inferno" was certified Silver by the British Phonographic Industry, together with "Can We Come Together" (from the album Where the Happy People Go). Dissolution and aftermath On September 19, 2005, the group's "Disco Inferno" was inducted into the Dance Music Hall of Fame at a ceremony held in New York. The song was part-written by Ron Kersey, a producer-arranger"Hold On to the Nights" is a power ballad written and performed by American rock singer/songwriter/musician Richard Marx. This was the fourth and final single released from his self-titled debut album, and his first to reach number one on the US Billboard Hot 100 chart. The song has been re-released on numerous albums and is included on Marx's live performance DVD A Night Out with Friends (2012). Release "Hold On to the Nights" reached the Billboard Hot 100 number 1 position on July 23, 1988, preventing Def Leppard's "Pour Some Sugar on Me" from reaching the top spot that same week. The song was on the chart for twenty-one weeks, and left the chart at number 91. The song also reached at number three on the Billboard Adult Contemporary chart. Chart performance Charts Personnel Richard Marx – vocals, keyboards, acoustic piano Michael Landau – guitars Patrick O'Hearn – bass Tris Imboden – drums Paulinho da Costa – percussion Other performances Marx appeared as lounge singer/piano player Buddy Daquiri in the "Poison Fire Teats Universe" episode of the TV series Life in Pieces in 2017, in which he played the song on the piano while whistling. References 1987 songs 1988 singles Richard Marx songs Billboard Hot 100 number-one singles Songs written by Richard Marx Pop ballads Rock ballads EMI Records singles Songs about nights
|
||||
Think hard, but answer shortly and concisely. Only give direct answers to the questions. No additional explanations. Directly answer these questions:
|
||||
Q: Who was the man behind The Chipmunks??
|
||||
A: David Seville
|
||||
|
||||
Q: Which Lloyd Webber musical premiered in the US on 10th December 1993??
|
||||
A: Sunset Boulevard
|
||||
|
||||
Q: Who was the next British Prime Minister after Arthur Balfour??
|
||||
A: Campbell-Bannerman
|
||||
|
||||
Q: Who had a 70s No 1 hit with Kiss You All Over??
|
||||
A: Exile
|
||||
|
||||
Q: What claimed the life of singer Kathleen Ferrier??
|
||||
A: Cancer
|
||||
|
||||
Q: Who had an 80s No 1 hit with Hold On To The Nights?
|
||||
A:
|
||||
==================================================
|
||||
|
||||
==================================================
|
||||
PROMPT #20:
|
||||
==================================================
|
||||
Turner Classic Movies in November 2006 features directors Steven Spielberg, Clint Eastwood, and Martin Scorsese, who suggest that the string of classic films Ford directed during 1936 to 1941 was due in part to an intense six-month extramarital affair with Katharine Hepburn, the star of Mary of Scotland (1936), an Elizabethan costume drama. 1939–1941 Stagecoach (1939) was Ford's first western since 3 Bad Men in 1926, and it was his first with sound. Orson Welles claimed that he watched Stagecoach forty times in preparation for making Citizen Kane. It remains one of the most admired and imitated of all Hollywood movies, not least for its climactic stagecoach chase and the hair-raising horse-jumping scene, performed by the stuntman Yakima Canutt. The Dudley Nichols–Ben Hecht screenplay was based on an Ernest Haycox story that Ford had spotted in Collier's magazine and he purchased the screen rights for just $2500. Production chief Walter Wanger urged Ford to hire Gary Cooper and Marlene Dietrich for the lead roles, but eventually accepted Ford's decision to cast Claire Trevor as Dallas and a virtual unknown, his friend John Wayne, as Ringo; Wanger reportedly had little further influence over the production. In making Stagecoach, Ford faced entrenched industry prejudice about the now-hackneyed genre which he had helped to make so popular. Although low-budget western features and serials were still being churned out in large numbers by "Poverty Row" studios, the genre had fallen out of favor with the big studios during the 1930s and they were regarded as B-grade "pulp" movies at best.Stagecoach is a 1986 American made-for-television Western action drama film and remake of the classic 1939 film Stagecoach, directed by Ted Post and starring Kris Kristofferson as the Ringo Kid, the role originally played by John Wayne. Willie Nelson portrays famous gunslinger and dentist Doc Holliday, Johnny Cash portrays Marshal Curly Wilcox and Waylon Jennings plays the gambler Hatfield. The four main stars of the film (Nelson, Kristofferson, Cash and Jennings) were associated as members of the country music supergroup The Highwaymen. The supporting cast features Elizabeth Ashley, Anthony Newley, Tony Franciosa, Mary Crosby, June Carter Cash and Jessi Colter. Plot In 1880, a group of strangers boards the east-bound stagecoach from Tonto, Arizona Territory, to Lordsburg, New Mexico Territory. The travelers seem ordinary, but many have secrets from which they are running. Among them are Dallas, a prostitute, who is being driven out of town; an alcoholic dentist, Doc Holliday; pregnant Lucy Mallory, who is meeting her cavalry officer husband; and whiskey salesman Trevor Peacock. As the stage sets out, U.S. Cavalry Lieutenant Blanchard announces that Geronimo and his Apaches are on the warpath; his small troop will provide an escort to Dry Fork. Cast Willie Nelson as Doc Holliday Kris Kristofferson as Ringo / Ringo Kid / Bill Williams Johnny Cash as Marshal Curly Wilcox Waylon Jennings as Hatfield (Gambler) John Schneider as Buck (Overland Stage Driver) Elizabeth Ashley as Dallas Anthony Newley as Trevor Peacock (Old John's Whiskey Salesman) Tony Franciosa as Henry Gatewood (Tonto Banker) Merritt Butrick as Lieutenant Blanchard Mary CrosbyStagecoach is a 1939 American Western film directed by John Ford and starring Claire Trevor and John Wayne in his breakthrough role. The screenplay by Dudley Nichols is an adaptation of "The Stage to Lordsburg", a 1937 short story by Ernest Haycox. The film follows a group of strangers riding on a stagecoach through dangerous Apache territory. The film has long been recognized as an important work that transcends the Western genre. Philosopher Robert B. Pippin has observed that both the collection of characters and their journey "are archetypal rather than merely individual" and that the film is a "mythic representation of the American aspiration toward a form of politically meaningful equality." In 1995, the film was deemed "culturally, historically, or aesthetically significant" by the United States Library of Congress and selected for preservation in their National Film Registry. Still, Stagecoach has not avoided controversy. Like most Westerns of the era, its depiction of Native Americans as simplistic savages has been criticized. Stagecoach was the first of many Westerns that Ford shot in Monument Valley, on the Arizona–Utah border in the American Southwest. Many of the movies Ford shot there also starred John Wayne. Scenes from Stagecoach, including a sequence introducing John Wayne's character the Ringo Kid, blended shots of Monument Valley with shots filmed on the Iverson Movie Ranch in Chatsworth, California, RKO Encino Movie Ranch, and other locations. Geographic incongruities are visible throughout the film, including the closing scene where Ringo (Wayne) and Dallas (Trevor) depart Lordsburg, in southwestern New Mexico, by way of
|
||||
Think hard, but answer shortly and concisely. Only give direct answers to the questions. No additional explanations. Directly answer these questions:
|
||||
Q: Who was the man behind The Chipmunks??
|
||||
A: David Seville
|
||||
|
||||
Q: Which Lloyd Webber musical premiered in the US on 10th December 1993??
|
||||
A: Sunset Boulevard
|
||||
|
||||
Q: Who was the next British Prime Minister after Arthur Balfour??
|
||||
A: Campbell-Bannerman
|
||||
|
||||
Q: Who had a 70s No 1 hit with Kiss You All Over??
|
||||
A: Exile
|
||||
|
||||
Q: What claimed the life of singer Kathleen Ferrier??
|
||||
A: Cancer
|
||||
|
||||
Q: Who directed the classic 30s western Stagecoach?
|
||||
A:
|
||||
==================================================
|
||||
|
||||
286
benchmarks/diskann_vs_hnsw_speed_comparison.py
Normal file
286
benchmarks/diskann_vs_hnsw_speed_comparison.py
Normal file
@@ -0,0 +1,286 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
DiskANN vs HNSW Search Performance Comparison
|
||||
|
||||
This benchmark compares search performance between DiskANN and HNSW backends:
|
||||
- DiskANN: With graph partitioning enabled (is_recompute=True)
|
||||
- HNSW: With recompute enabled (is_recompute=True)
|
||||
- Tests performance across different dataset sizes
|
||||
- Measures search latency, recall, and index size
|
||||
"""
|
||||
|
||||
import gc
|
||||
import multiprocessing as mp
|
||||
import tempfile
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
|
||||
# Prefer 'fork' start method to avoid POSIX semaphore leaks on macOS
|
||||
try:
|
||||
mp.set_start_method("fork", force=True)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def create_test_texts(n_docs: int) -> list[str]:
|
||||
"""Create synthetic test documents for benchmarking."""
|
||||
np.random.seed(42)
|
||||
topics = [
|
||||
"machine learning and artificial intelligence",
|
||||
"natural language processing and text analysis",
|
||||
"computer vision and image recognition",
|
||||
"data science and statistical analysis",
|
||||
"deep learning and neural networks",
|
||||
"information retrieval and search engines",
|
||||
"database systems and data management",
|
||||
"software engineering and programming",
|
||||
"cybersecurity and network protection",
|
||||
"cloud computing and distributed systems",
|
||||
]
|
||||
|
||||
texts = []
|
||||
for i in range(n_docs):
|
||||
topic = topics[i % len(topics)]
|
||||
variation = np.random.randint(1, 100)
|
||||
text = (
|
||||
f"This is document {i} about {topic}. Content variation {variation}. "
|
||||
f"Additional information about {topic} with details and examples. "
|
||||
f"Technical discussion of {topic} including implementation aspects."
|
||||
)
|
||||
texts.append(text)
|
||||
|
||||
return texts
|
||||
|
||||
|
||||
def benchmark_backend(
|
||||
backend_name: str, texts: list[str], test_queries: list[str], backend_kwargs: dict[str, Any]
|
||||
) -> dict[str, float]:
|
||||
"""Benchmark a specific backend with the given configuration."""
|
||||
from leann.api import LeannBuilder, LeannSearcher
|
||||
|
||||
print(f"\n🔧 Testing {backend_name.upper()} backend...")
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
index_path = str(Path(temp_dir) / f"benchmark_{backend_name}.leann")
|
||||
|
||||
# Build index
|
||||
print(f"📦 Building {backend_name} index with {len(texts)} documents...")
|
||||
start_time = time.time()
|
||||
|
||||
builder = LeannBuilder(
|
||||
backend_name=backend_name,
|
||||
embedding_model="facebook/contriever",
|
||||
embedding_mode="sentence-transformers",
|
||||
**backend_kwargs,
|
||||
)
|
||||
|
||||
for text in texts:
|
||||
builder.add_text(text)
|
||||
|
||||
builder.build_index(index_path)
|
||||
build_time = time.time() - start_time
|
||||
|
||||
# Measure index size
|
||||
index_dir = Path(index_path).parent
|
||||
index_files = list(index_dir.glob(f"{Path(index_path).stem}.*"))
|
||||
total_size = sum(f.stat().st_size for f in index_files if f.is_file())
|
||||
size_mb = total_size / (1024 * 1024)
|
||||
|
||||
print(f" ✅ Build completed in {build_time:.2f}s, index size: {size_mb:.1f}MB")
|
||||
|
||||
# Search benchmark
|
||||
print("🔍 Running search benchmark...")
|
||||
searcher = LeannSearcher(index_path)
|
||||
|
||||
search_times = []
|
||||
all_results = []
|
||||
|
||||
for query in test_queries:
|
||||
start_time = time.time()
|
||||
results = searcher.search(query, top_k=5)
|
||||
search_time = time.time() - start_time
|
||||
search_times.append(search_time)
|
||||
all_results.append(results)
|
||||
|
||||
avg_search_time = np.mean(search_times) * 1000 # Convert to ms
|
||||
print(f" ✅ Average search time: {avg_search_time:.1f}ms")
|
||||
|
||||
# Check for valid scores (detect -inf issues)
|
||||
all_scores = [
|
||||
result.score
|
||||
for results in all_results
|
||||
for result in results
|
||||
if result.score is not None
|
||||
]
|
||||
valid_scores = [
|
||||
score for score in all_scores if score != float("-inf") and score != float("inf")
|
||||
]
|
||||
score_validity_rate = len(valid_scores) / len(all_scores) if all_scores else 0
|
||||
|
||||
# Clean up (ensure embedding server shutdown and object GC)
|
||||
try:
|
||||
if hasattr(searcher, "cleanup"):
|
||||
searcher.cleanup()
|
||||
del searcher
|
||||
del builder
|
||||
gc.collect()
|
||||
except Exception as e:
|
||||
print(f"⚠️ Warning: Resource cleanup error: {e}")
|
||||
|
||||
return {
|
||||
"build_time": build_time,
|
||||
"avg_search_time_ms": avg_search_time,
|
||||
"index_size_mb": size_mb,
|
||||
"score_validity_rate": score_validity_rate,
|
||||
}
|
||||
|
||||
|
||||
def run_comparison(n_docs: int = 500, n_queries: int = 10):
|
||||
"""Run performance comparison between DiskANN and HNSW."""
|
||||
print("🚀 Starting DiskANN vs HNSW Performance Comparison")
|
||||
print(f"📊 Dataset: {n_docs} documents, {n_queries} test queries")
|
||||
|
||||
# Create test data
|
||||
texts = create_test_texts(n_docs)
|
||||
test_queries = [
|
||||
"machine learning algorithms",
|
||||
"natural language processing",
|
||||
"computer vision techniques",
|
||||
"data analysis methods",
|
||||
"neural network architectures",
|
||||
"database query optimization",
|
||||
"software development practices",
|
||||
"security vulnerabilities",
|
||||
"cloud infrastructure",
|
||||
"distributed computing",
|
||||
][:n_queries]
|
||||
|
||||
# HNSW benchmark
|
||||
hnsw_results = benchmark_backend(
|
||||
backend_name="hnsw",
|
||||
texts=texts,
|
||||
test_queries=test_queries,
|
||||
backend_kwargs={
|
||||
"is_recompute": True, # Enable recompute for fair comparison
|
||||
"M": 16,
|
||||
"efConstruction": 200,
|
||||
},
|
||||
)
|
||||
|
||||
# DiskANN benchmark
|
||||
diskann_results = benchmark_backend(
|
||||
backend_name="diskann",
|
||||
texts=texts,
|
||||
test_queries=test_queries,
|
||||
backend_kwargs={
|
||||
"is_recompute": True, # Enable graph partitioning
|
||||
"num_neighbors": 32,
|
||||
"search_list_size": 50,
|
||||
},
|
||||
)
|
||||
|
||||
# Performance comparison
|
||||
print("\n📈 Performance Comparison Results")
|
||||
print(f"{'=' * 60}")
|
||||
print(f"{'Metric':<25} {'HNSW':<15} {'DiskANN':<15} {'Speedup':<10}")
|
||||
print(f"{'-' * 60}")
|
||||
|
||||
# Build time comparison
|
||||
build_speedup = hnsw_results["build_time"] / diskann_results["build_time"]
|
||||
print(
|
||||
f"{'Build Time (s)':<25} {hnsw_results['build_time']:<15.2f} {diskann_results['build_time']:<15.2f} {build_speedup:<10.2f}x"
|
||||
)
|
||||
|
||||
# Search time comparison
|
||||
search_speedup = hnsw_results["avg_search_time_ms"] / diskann_results["avg_search_time_ms"]
|
||||
print(
|
||||
f"{'Search Time (ms)':<25} {hnsw_results['avg_search_time_ms']:<15.1f} {diskann_results['avg_search_time_ms']:<15.1f} {search_speedup:<10.2f}x"
|
||||
)
|
||||
|
||||
# Index size comparison
|
||||
size_ratio = diskann_results["index_size_mb"] / hnsw_results["index_size_mb"]
|
||||
print(
|
||||
f"{'Index Size (MB)':<25} {hnsw_results['index_size_mb']:<15.1f} {diskann_results['index_size_mb']:<15.1f} {size_ratio:<10.2f}x"
|
||||
)
|
||||
|
||||
# Score validity
|
||||
print(
|
||||
f"{'Score Validity (%)':<25} {hnsw_results['score_validity_rate'] * 100:<15.1f} {diskann_results['score_validity_rate'] * 100:<15.1f}"
|
||||
)
|
||||
|
||||
print(f"{'=' * 60}")
|
||||
print("\n🎯 Summary:")
|
||||
if search_speedup > 1:
|
||||
print(f" DiskANN is {search_speedup:.2f}x faster than HNSW for search")
|
||||
else:
|
||||
print(f" HNSW is {1 / search_speedup:.2f}x faster than DiskANN for search")
|
||||
|
||||
if size_ratio > 1:
|
||||
print(f" DiskANN uses {size_ratio:.2f}x more storage than HNSW")
|
||||
else:
|
||||
print(f" DiskANN uses {1 / size_ratio:.2f}x less storage than HNSW")
|
||||
|
||||
print(
|
||||
f" Both backends achieved {min(hnsw_results['score_validity_rate'], diskann_results['score_validity_rate']) * 100:.1f}% score validity"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
|
||||
try:
|
||||
# Handle help request
|
||||
if len(sys.argv) > 1 and sys.argv[1] in ["-h", "--help", "help"]:
|
||||
print("DiskANN vs HNSW Performance Comparison")
|
||||
print("=" * 50)
|
||||
print(f"Usage: python {sys.argv[0]} [n_docs] [n_queries]")
|
||||
print()
|
||||
print("Arguments:")
|
||||
print(" n_docs Number of documents to index (default: 500)")
|
||||
print(" n_queries Number of test queries to run (default: 10)")
|
||||
print()
|
||||
print("Examples:")
|
||||
print(" python benchmarks/diskann_vs_hnsw_speed_comparison.py")
|
||||
print(" python benchmarks/diskann_vs_hnsw_speed_comparison.py 1000")
|
||||
print(" python benchmarks/diskann_vs_hnsw_speed_comparison.py 2000 20")
|
||||
sys.exit(0)
|
||||
|
||||
# Parse command line arguments
|
||||
n_docs = int(sys.argv[1]) if len(sys.argv) > 1 else 500
|
||||
n_queries = int(sys.argv[2]) if len(sys.argv) > 2 else 10
|
||||
|
||||
print("DiskANN vs HNSW Performance Comparison")
|
||||
print("=" * 50)
|
||||
print(f"Dataset: {n_docs} documents, {n_queries} queries")
|
||||
print()
|
||||
|
||||
run_comparison(n_docs=n_docs, n_queries=n_queries)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n⚠️ Benchmark interrupted by user")
|
||||
sys.exit(130)
|
||||
except Exception as e:
|
||||
print(f"\n❌ Benchmark failed: {e}")
|
||||
sys.exit(1)
|
||||
finally:
|
||||
# Ensure clean exit (forceful to prevent rare hangs from atexit/threads)
|
||||
try:
|
||||
gc.collect()
|
||||
print("\n🧹 Cleanup completed")
|
||||
# Flush stdio to ensure message is visible before hard-exit
|
||||
try:
|
||||
import sys as _sys
|
||||
|
||||
_sys.stdout.flush()
|
||||
_sys.stderr.flush()
|
||||
except Exception:
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
# Use os._exit to bypass atexit handlers that may hang in rare cases
|
||||
import os as _os
|
||||
|
||||
_os._exit(0)
|
||||
@@ -1,11 +1,11 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Test only Faiss HNSW"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
|
||||
import psutil
|
||||
import gc
|
||||
import os
|
||||
|
||||
|
||||
def get_memory_usage():
|
||||
@@ -37,20 +37,20 @@ def main():
|
||||
import faiss
|
||||
except ImportError:
|
||||
print("Faiss is not installed.")
|
||||
print("Please install it with `uv pip install faiss-cpu`")
|
||||
print(
|
||||
"Please install it with `uv pip install faiss-cpu` and you can then run this script again"
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
from llama_index.core import (
|
||||
SimpleDirectoryReader,
|
||||
VectorStoreIndex,
|
||||
StorageContext,
|
||||
Settings,
|
||||
node_parser,
|
||||
Document,
|
||||
SimpleDirectoryReader,
|
||||
StorageContext,
|
||||
VectorStoreIndex,
|
||||
)
|
||||
from llama_index.core.node_parser import SentenceSplitter
|
||||
from llama_index.vector_stores.faiss import FaissVectorStore
|
||||
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
|
||||
from llama_index.vector_stores.faiss import FaissVectorStore
|
||||
|
||||
tracker = MemoryTracker("Faiss HNSW")
|
||||
tracker.checkpoint("Initial")
|
||||
@@ -65,7 +65,7 @@ def main():
|
||||
tracker.checkpoint("After Faiss index creation")
|
||||
|
||||
documents = SimpleDirectoryReader(
|
||||
"examples/data",
|
||||
"data",
|
||||
recursive=True,
|
||||
encoding="utf-8",
|
||||
required_exts=[".pdf", ".txt", ".md"],
|
||||
@@ -90,8 +90,9 @@ def main():
|
||||
vector_store=vector_store, persist_dir="./storage_faiss"
|
||||
)
|
||||
from llama_index.core import load_index_from_storage
|
||||
|
||||
index = load_index_from_storage(storage_context=storage_context)
|
||||
print(f"Index loaded from ./storage_faiss")
|
||||
print("Index loaded from ./storage_faiss")
|
||||
tracker.checkpoint("After loading existing index")
|
||||
index_loaded = True
|
||||
except Exception as e:
|
||||
@@ -99,19 +100,18 @@ def main():
|
||||
print("Cleaning up corrupted index and building new one...")
|
||||
# Clean up corrupted index
|
||||
import shutil
|
||||
|
||||
if os.path.exists("./storage_faiss"):
|
||||
shutil.rmtree("./storage_faiss")
|
||||
|
||||
|
||||
if not index_loaded:
|
||||
print("Building new Faiss HNSW index...")
|
||||
|
||||
|
||||
# Use the correct Faiss building pattern from the example
|
||||
vector_store = FaissVectorStore(faiss_index=faiss_index)
|
||||
storage_context = StorageContext.from_defaults(vector_store=vector_store)
|
||||
index = VectorStoreIndex.from_documents(
|
||||
documents,
|
||||
storage_context=storage_context,
|
||||
transformations=[node_parser]
|
||||
documents, storage_context=storage_context, transformations=[node_parser]
|
||||
)
|
||||
tracker.checkpoint("After index building")
|
||||
|
||||
@@ -124,10 +124,10 @@ def main():
|
||||
runtime_start_mem = get_memory_usage()
|
||||
print(f"Before load memory: {runtime_start_mem:.1f} MB")
|
||||
tracker.checkpoint("Before load memory")
|
||||
|
||||
|
||||
query_engine = index.as_query_engine(similarity_top_k=20)
|
||||
queries = [
|
||||
"什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面,任务令一般在什么城市颁发",
|
||||
"什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面,任务令一般在什么城市颁发",
|
||||
"What is LEANN and how does it work?",
|
||||
"华为诺亚方舟实验室的主要研究内容",
|
||||
]
|
||||
@@ -141,7 +141,7 @@ def main():
|
||||
|
||||
runtime_end_mem = get_memory_usage()
|
||||
runtime_overhead = runtime_end_mem - runtime_start_mem
|
||||
|
||||
|
||||
peak_memory = tracker.summary()
|
||||
print(f"Peak Memory: {peak_memory:.1f} MB")
|
||||
print(f"Runtime Memory Overhead: {runtime_overhead:.1f} MB")
|
||||
114
benchmarks/generation_speed_bench.py
Normal file
114
benchmarks/generation_speed_bench.py
Normal file
@@ -0,0 +1,114 @@
|
||||
import argparse
|
||||
import re
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
from statistics import mean
|
||||
|
||||
from leann.chat import get_llm
|
||||
|
||||
|
||||
def parse_prompts_from_file(file_path: str) -> list[str]:
|
||||
"""
|
||||
Parse a prompt dump file into individual prompt strings.
|
||||
|
||||
Splits by lines that look like: "PROMPT #<n>:".
|
||||
Keeps the content from each marker up to the next marker (or EOF).
|
||||
"""
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
text = f.read()
|
||||
|
||||
matches = list(re.finditer(r"^PROMPT\s+#\d+:\s*$", text, flags=re.MULTILINE))
|
||||
if not matches:
|
||||
# Fallback: try a more permissive pattern
|
||||
matches = list(
|
||||
re.finditer(r"^=+\nPROMPT\s+#\d+:\n=+\s*$", text, flags=re.MULTILINE)
|
||||
)
|
||||
|
||||
prompts: list[str] = []
|
||||
if not matches:
|
||||
# No explicit markers; treat the whole file as a single prompt
|
||||
return [text]
|
||||
|
||||
for i, m in enumerate(matches):
|
||||
start = m.end()
|
||||
end = matches[i + 1].start() if i + 1 < len(matches) else len(text)
|
||||
block = text[start:end].strip()
|
||||
# Reattach the marker line content above the block for full context
|
||||
header_line_start = text.rfind("\n", 0, m.start()) + 1
|
||||
header = text[header_line_start : m.end()].strip()
|
||||
prompts.append(f"{header}\n{block}".strip())
|
||||
|
||||
return prompts
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description=(
|
||||
"Iterate prompts in a dump file, time generations, print outputs, and report last-10 average time."
|
||||
)
|
||||
)
|
||||
parser.add_argument(
|
||||
"--path",
|
||||
default="benchmarks/data/prompts_g5/prompt_dump_nq_hnsw.txt",
|
||||
help="Path to the prompt dump file",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--type",
|
||||
default="ollama",
|
||||
choices=["hf", "openai", "ollama", "gemini", "simulated"],
|
||||
help="LLM backend type",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
default="Qwen/Qwen3-4B",
|
||||
help="Model identifier (depends on backend)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_tokens",
|
||||
type=int,
|
||||
default=512,
|
||||
help="Max new tokens to generate per prompt",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
llm_config = {"type": args.type, "model": args.model}
|
||||
chat = get_llm(llm_config)
|
||||
|
||||
prompts = parse_prompts_from_file(args.path)
|
||||
print(f"Found {len(prompts)} prompts in {args.path}")
|
||||
|
||||
times: list[float] = []
|
||||
for idx, prompt in enumerate(prompts, start=1):
|
||||
print("\n" + "=" * 80)
|
||||
print(f"PROMPT {idx}/{len(prompts)}")
|
||||
print("-" * 80)
|
||||
start = time.perf_counter()
|
||||
try:
|
||||
output = chat.ask(prompt, max_tokens=args.max_tokens)
|
||||
except Exception as e:
|
||||
output = f"<error: {e}>"
|
||||
elapsed = time.perf_counter() - start
|
||||
times.append(elapsed)
|
||||
print(f"Time: {elapsed:.3f}s")
|
||||
print("-" * 80)
|
||||
print(output)
|
||||
print("=" * 80)
|
||||
|
||||
if times:
|
||||
window = times[-10:] if len(times) >= 10 else times
|
||||
avg_last_10 = mean(window)
|
||||
print(
|
||||
f"\nAverage time over last {len(window)} prompts: {avg_last_10:.3f}s"
|
||||
)
|
||||
else:
|
||||
print("No prompts processed.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -2,20 +2,20 @@
|
||||
|
||||
import argparse
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import AutoModel, BitsAndBytesConfig
|
||||
from tqdm import tqdm
|
||||
from contextlib import contextmanager
|
||||
from transformers import AutoModel, BitsAndBytesConfig
|
||||
|
||||
|
||||
@dataclass
|
||||
class BenchmarkConfig:
|
||||
model_path: str
|
||||
batch_sizes: List[int]
|
||||
batch_sizes: list[int]
|
||||
seq_length: int
|
||||
num_runs: int
|
||||
use_fp16: bool = True
|
||||
@@ -28,47 +28,45 @@ class BenchmarkConfig:
|
||||
|
||||
class GraphContainer:
|
||||
"""Container for managing graphs for different batch sizes (CUDA graphs on NVIDIA, regular on others)."""
|
||||
|
||||
|
||||
def __init__(self, model: nn.Module, seq_length: int):
|
||||
self.model = model
|
||||
self.seq_length = seq_length
|
||||
self.graphs: Dict[int, 'GraphWrapper'] = {}
|
||||
|
||||
def get_or_create(self, batch_size: int) -> 'GraphWrapper':
|
||||
self.graphs: dict[int, GraphWrapper] = {}
|
||||
|
||||
def get_or_create(self, batch_size: int) -> "GraphWrapper":
|
||||
if batch_size not in self.graphs:
|
||||
self.graphs[batch_size] = GraphWrapper(
|
||||
self.model, batch_size, self.seq_length
|
||||
)
|
||||
self.graphs[batch_size] = GraphWrapper(self.model, batch_size, self.seq_length)
|
||||
return self.graphs[batch_size]
|
||||
|
||||
|
||||
class GraphWrapper:
|
||||
"""Wrapper for graph capture and replay (CUDA graphs on NVIDIA, regular on others)."""
|
||||
|
||||
|
||||
def __init__(self, model: nn.Module, batch_size: int, seq_length: int):
|
||||
self.model = model
|
||||
self.device = self._get_device()
|
||||
self.static_input = self._create_random_batch(batch_size, seq_length)
|
||||
self.static_attention_mask = torch.ones_like(self.static_input)
|
||||
|
||||
|
||||
# Warm up
|
||||
self._warmup()
|
||||
|
||||
|
||||
# Only use CUDA graphs on NVIDIA GPUs
|
||||
if torch.cuda.is_available() and hasattr(torch.cuda, 'CUDAGraph'):
|
||||
if torch.cuda.is_available() and hasattr(torch.cuda, "CUDAGraph"):
|
||||
# Capture graph
|
||||
self.graph = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(self.graph):
|
||||
self.static_output = self.model(
|
||||
input_ids=self.static_input,
|
||||
attention_mask=self.static_attention_mask
|
||||
attention_mask=self.static_attention_mask,
|
||||
)
|
||||
self.use_cuda_graph = True
|
||||
else:
|
||||
# For MPS or CPU, just store the model
|
||||
self.use_cuda_graph = False
|
||||
self.static_output = None
|
||||
|
||||
|
||||
def _get_device(self) -> str:
|
||||
if torch.cuda.is_available():
|
||||
return "cuda"
|
||||
@@ -76,22 +74,20 @@ class GraphWrapper:
|
||||
return "mps"
|
||||
else:
|
||||
return "cpu"
|
||||
|
||||
|
||||
def _create_random_batch(self, batch_size: int, seq_length: int) -> torch.Tensor:
|
||||
return torch.randint(
|
||||
0, 1000, (batch_size, seq_length),
|
||||
device=self.device,
|
||||
dtype=torch.long
|
||||
0, 1000, (batch_size, seq_length), device=self.device, dtype=torch.long
|
||||
)
|
||||
|
||||
|
||||
def _warmup(self, num_warmup: int = 3):
|
||||
with torch.no_grad():
|
||||
for _ in range(num_warmup):
|
||||
self.model(
|
||||
input_ids=self.static_input,
|
||||
attention_mask=self.static_attention_mask
|
||||
attention_mask=self.static_attention_mask,
|
||||
)
|
||||
|
||||
|
||||
def __call__(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
|
||||
if self.use_cuda_graph:
|
||||
self.static_input.copy_(input_ids)
|
||||
@@ -105,14 +101,14 @@ class GraphWrapper:
|
||||
|
||||
class ModelOptimizer:
|
||||
"""Applies various optimizations to the model."""
|
||||
|
||||
|
||||
@staticmethod
|
||||
def optimize(model: nn.Module, config: BenchmarkConfig) -> nn.Module:
|
||||
print("\nApplying model optimizations:")
|
||||
|
||||
|
||||
if model is None:
|
||||
raise ValueError("Cannot optimize None model")
|
||||
|
||||
|
||||
# Move to GPU
|
||||
if torch.cuda.is_available():
|
||||
model = model.cuda()
|
||||
@@ -124,53 +120,59 @@ class ModelOptimizer:
|
||||
model = model.cpu()
|
||||
device = "cpu"
|
||||
print(f"- Model moved to {device}")
|
||||
|
||||
|
||||
# FP16
|
||||
if config.use_fp16 and not config.use_int4:
|
||||
model = model.half()
|
||||
# use torch compile
|
||||
model = torch.compile(model)
|
||||
print("- Using FP16 precision")
|
||||
|
||||
|
||||
# Check if using SDPA (only on CUDA)
|
||||
if torch.cuda.is_available() and torch.version.cuda and float(torch.version.cuda[:3]) >= 11.6:
|
||||
if hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
|
||||
if (
|
||||
torch.cuda.is_available()
|
||||
and torch.version.cuda
|
||||
and float(torch.version.cuda[:3]) >= 11.6
|
||||
):
|
||||
if hasattr(torch.nn.functional, "scaled_dot_product_attention"):
|
||||
print("- Using PyTorch SDPA (scaled_dot_product_attention)")
|
||||
else:
|
||||
print("- PyTorch SDPA not available")
|
||||
|
||||
|
||||
# Flash Attention (only on CUDA)
|
||||
if config.use_flash_attention and torch.cuda.is_available():
|
||||
try:
|
||||
from flash_attn.flash_attention import FlashAttention
|
||||
from flash_attn.flash_attention import FlashAttention # noqa: F401
|
||||
|
||||
print("- Flash Attention 2 available")
|
||||
if hasattr(model.config, "attention_mode"):
|
||||
model.config.attention_mode = "flash_attention_2"
|
||||
print(" - Enabled Flash Attention 2 mode")
|
||||
except ImportError:
|
||||
print("- Flash Attention not available")
|
||||
|
||||
|
||||
# Memory efficient attention (only on CUDA)
|
||||
if torch.cuda.is_available():
|
||||
try:
|
||||
from xformers.ops import memory_efficient_attention
|
||||
if hasattr(model, 'enable_xformers_memory_efficient_attention'):
|
||||
from xformers.ops import memory_efficient_attention # noqa: F401
|
||||
|
||||
if hasattr(model, "enable_xformers_memory_efficient_attention"):
|
||||
model.enable_xformers_memory_efficient_attention()
|
||||
print("- Enabled xformers memory efficient attention")
|
||||
else:
|
||||
print("- Model doesn't support xformers")
|
||||
except (ImportError, AttributeError):
|
||||
print("- Xformers not available")
|
||||
|
||||
|
||||
model.eval()
|
||||
print("- Model set to eval mode")
|
||||
|
||||
|
||||
return model
|
||||
|
||||
|
||||
class Timer:
|
||||
"""Handles accurate GPU timing using GPU events or CPU timing."""
|
||||
|
||||
|
||||
def __init__(self):
|
||||
if torch.cuda.is_available():
|
||||
self.start_event = torch.cuda.Event(enable_timing=True)
|
||||
@@ -182,7 +184,7 @@ class Timer:
|
||||
else:
|
||||
# CPU timing
|
||||
self.use_gpu_timing = False
|
||||
|
||||
|
||||
@contextmanager
|
||||
def timing(self):
|
||||
if self.use_gpu_timing:
|
||||
@@ -195,7 +197,7 @@ class Timer:
|
||||
start_time = time.time()
|
||||
yield
|
||||
self.cpu_elapsed = time.time() - start_time
|
||||
|
||||
|
||||
def elapsed_time(self) -> float:
|
||||
if self.use_gpu_timing:
|
||||
return self.start_event.elapsed_time(self.end_event) / 1000 # ms to seconds
|
||||
@@ -205,14 +207,14 @@ class Timer:
|
||||
|
||||
class Benchmark:
|
||||
"""Main benchmark runner."""
|
||||
|
||||
|
||||
def __init__(self, config: BenchmarkConfig):
|
||||
self.config = config
|
||||
try:
|
||||
self.model = self._load_model()
|
||||
if self.model is None:
|
||||
raise ValueError("Model initialization failed - model is None")
|
||||
|
||||
|
||||
# Only use CUDA graphs on NVIDIA GPUs
|
||||
if config.use_cuda_graphs and torch.cuda.is_available():
|
||||
self.graphs = GraphContainer(self.model, config.seq_length)
|
||||
@@ -220,25 +222,27 @@ class Benchmark:
|
||||
self.graphs = None
|
||||
self.timer = Timer()
|
||||
except Exception as e:
|
||||
print(f"ERROR in benchmark initialization: {str(e)}")
|
||||
print(f"ERROR in benchmark initialization: {e!s}")
|
||||
raise
|
||||
|
||||
|
||||
def _load_model(self) -> nn.Module:
|
||||
print(f"Loading model from {self.config.model_path}...")
|
||||
|
||||
|
||||
try:
|
||||
# Int4 quantization using HuggingFace integration
|
||||
if self.config.use_int4:
|
||||
import bitsandbytes as bnb
|
||||
|
||||
print(f"- bitsandbytes version: {bnb.__version__}")
|
||||
|
||||
# 检查是否使用自定义的8bit量化
|
||||
if hasattr(self.config, 'use_linear8bitlt') and self.config.use_linear8bitlt:
|
||||
|
||||
# Check if using custom 8bit quantization
|
||||
if hasattr(self.config, "use_linear8bitlt") and self.config.use_linear8bitlt:
|
||||
print("- Using custom Linear8bitLt replacement for all linear layers")
|
||||
|
||||
# 加载原始模型(不使用量化配置)
|
||||
|
||||
# Load original model (without quantization config)
|
||||
import bitsandbytes as bnb
|
||||
import torch
|
||||
|
||||
# set default to half
|
||||
torch.set_default_dtype(torch.float16)
|
||||
compute_dtype = torch.float16 if self.config.use_fp16 else torch.float32
|
||||
@@ -246,112 +250,121 @@ class Benchmark:
|
||||
self.config.model_path,
|
||||
torch_dtype=compute_dtype,
|
||||
)
|
||||
|
||||
# 定义替换函数
|
||||
|
||||
# Define replacement function
|
||||
def replace_linear_with_linear8bitlt(model):
|
||||
"""递归地将模型中的所有nn.Linear层替换为Linear8bitLt"""
|
||||
"""Recursively replace all nn.Linear layers with Linear8bitLt"""
|
||||
for name, module in list(model.named_children()):
|
||||
if isinstance(module, nn.Linear):
|
||||
# 获取原始线性层的参数
|
||||
# Get original linear layer parameters
|
||||
in_features = module.in_features
|
||||
out_features = module.out_features
|
||||
bias = module.bias is not None
|
||||
|
||||
# 创建8bit线性层
|
||||
|
||||
# Create 8bit linear layer
|
||||
# print size
|
||||
print(f"in_features: {in_features}, out_features: {out_features}")
|
||||
new_module = bnb.nn.Linear8bitLt(
|
||||
in_features,
|
||||
out_features,
|
||||
bias=bias,
|
||||
has_fp16_weights=False
|
||||
in_features,
|
||||
out_features,
|
||||
bias=bias,
|
||||
has_fp16_weights=False,
|
||||
)
|
||||
|
||||
# 复制权重和偏置
|
||||
|
||||
# Copy weights and bias
|
||||
new_module.weight.data = module.weight.data
|
||||
if bias:
|
||||
new_module.bias.data = module.bias.data
|
||||
|
||||
# 替换模块
|
||||
|
||||
# Replace module
|
||||
setattr(model, name, new_module)
|
||||
else:
|
||||
# 递归处理子模块
|
||||
# Process child modules recursively
|
||||
replace_linear_with_linear8bitlt(module)
|
||||
|
||||
|
||||
return model
|
||||
|
||||
# 替换所有线性层
|
||||
|
||||
# Replace all linear layers
|
||||
model = replace_linear_with_linear8bitlt(model)
|
||||
# add torch compile
|
||||
model = torch.compile(model)
|
||||
|
||||
# 将模型移到GPU(量化发生在这里)
|
||||
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
|
||||
|
||||
# Move model to GPU (quantization happens here)
|
||||
device = (
|
||||
"cuda"
|
||||
if torch.cuda.is_available()
|
||||
else "mps"
|
||||
if torch.backends.mps.is_available()
|
||||
else "cpu"
|
||||
)
|
||||
model = model.to(device)
|
||||
|
||||
|
||||
print("- All linear layers replaced with Linear8bitLt")
|
||||
|
||||
|
||||
else:
|
||||
# 使用原来的Int4量化方法
|
||||
# Use original Int4 quantization method
|
||||
print("- Using bitsandbytes for Int4 quantization")
|
||||
|
||||
|
||||
# Create quantization config
|
||||
|
||||
|
||||
compute_dtype = torch.float16 if self.config.use_fp16 else torch.float32
|
||||
quantization_config = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_compute_dtype=compute_dtype,
|
||||
bnb_4bit_use_double_quant=True,
|
||||
bnb_4bit_quant_type="nf4"
|
||||
bnb_4bit_quant_type="nf4",
|
||||
)
|
||||
|
||||
|
||||
print("- Quantization config:", quantization_config)
|
||||
|
||||
|
||||
# Load model directly with quantization config
|
||||
model = AutoModel.from_pretrained(
|
||||
self.config.model_path,
|
||||
quantization_config=quantization_config,
|
||||
torch_dtype=compute_dtype,
|
||||
device_map="auto" # Let HF decide on device mapping
|
||||
device_map="auto", # Let HF decide on device mapping
|
||||
)
|
||||
|
||||
|
||||
# Check if model loaded successfully
|
||||
if model is None:
|
||||
raise ValueError("Model loading returned None")
|
||||
|
||||
|
||||
print(f"- Model type: {type(model)}")
|
||||
|
||||
|
||||
# Apply optimizations directly here
|
||||
print("\nApplying model optimizations:")
|
||||
|
||||
if hasattr(self.config, 'use_linear8bitlt') and self.config.use_linear8bitlt:
|
||||
|
||||
if hasattr(self.config, "use_linear8bitlt") and self.config.use_linear8bitlt:
|
||||
print("- Model moved to GPU with Linear8bitLt quantization")
|
||||
else:
|
||||
# Skip moving to GPU since device_map="auto" already did that
|
||||
print("- Model already on GPU due to device_map='auto'")
|
||||
|
||||
|
||||
# Skip FP16 conversion since we specified compute_dtype
|
||||
print(f"- Using {compute_dtype} for compute dtype")
|
||||
|
||||
|
||||
# Check CUDA and SDPA
|
||||
if torch.cuda.is_available() and torch.version.cuda and float(torch.version.cuda[:3]) >= 11.6:
|
||||
if hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
|
||||
if (
|
||||
torch.cuda.is_available()
|
||||
and torch.version.cuda
|
||||
and float(torch.version.cuda[:3]) >= 11.6
|
||||
):
|
||||
if hasattr(torch.nn.functional, "scaled_dot_product_attention"):
|
||||
print("- Using PyTorch SDPA (scaled_dot_product_attention)")
|
||||
else:
|
||||
print("- PyTorch SDPA not available")
|
||||
|
||||
|
||||
# Try xformers if available (only on CUDA)
|
||||
if torch.cuda.is_available():
|
||||
try:
|
||||
from xformers.ops import memory_efficient_attention
|
||||
if hasattr(model, 'enable_xformers_memory_efficient_attention'):
|
||||
if hasattr(model, "enable_xformers_memory_efficient_attention"):
|
||||
model.enable_xformers_memory_efficient_attention()
|
||||
print("- Enabled xformers memory efficient attention")
|
||||
else:
|
||||
print("- Model doesn't support xformers")
|
||||
except (ImportError, AttributeError):
|
||||
print("- Xformers not available")
|
||||
|
||||
|
||||
# Set to eval mode
|
||||
model.eval()
|
||||
print("- Model set to eval mode")
|
||||
@@ -365,76 +378,83 @@ class Benchmark:
|
||||
llm_int8_threshold=6.0,
|
||||
llm_int8_has_fp16_weight=False,
|
||||
)
|
||||
|
||||
|
||||
model = AutoModel.from_pretrained(
|
||||
self.config.model_path,
|
||||
quantization_config=quantization_config,
|
||||
torch_dtype=compute_dtype,
|
||||
device_map="auto"
|
||||
device_map="auto",
|
||||
)
|
||||
|
||||
|
||||
if model is None:
|
||||
raise ValueError("Model loading returned None")
|
||||
|
||||
|
||||
print(f"- Model type: {type(model)}")
|
||||
model.eval()
|
||||
print("- Model set to eval mode")
|
||||
|
||||
|
||||
else:
|
||||
# Standard loading for FP16/FP32
|
||||
model = AutoModel.from_pretrained(self.config.model_path)
|
||||
print("- Model loaded in standard precision")
|
||||
print(f"- Model type: {type(model)}")
|
||||
|
||||
|
||||
# Apply standard optimizations
|
||||
# set default to half
|
||||
import torch
|
||||
|
||||
torch.set_default_dtype(torch.bfloat16)
|
||||
model = ModelOptimizer.optimize(model, self.config)
|
||||
model = model.half()
|
||||
# add torch compile
|
||||
model = torch.compile(model)
|
||||
|
||||
|
||||
# Final check to ensure model is not None
|
||||
if model is None:
|
||||
raise ValueError("Model is None after optimization")
|
||||
|
||||
|
||||
print(f"- Final model type: {type(model)}")
|
||||
return model
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"ERROR loading model: {str(e)}")
|
||||
print(f"ERROR loading model: {e!s}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
raise
|
||||
|
||||
|
||||
def _create_random_batch(self, batch_size: int) -> torch.Tensor:
|
||||
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
|
||||
device = (
|
||||
"cuda"
|
||||
if torch.cuda.is_available()
|
||||
else "mps"
|
||||
if torch.backends.mps.is_available()
|
||||
else "cpu"
|
||||
)
|
||||
return torch.randint(
|
||||
0, 1000,
|
||||
0,
|
||||
1000,
|
||||
(batch_size, self.config.seq_length),
|
||||
device=device,
|
||||
dtype=torch.long
|
||||
dtype=torch.long,
|
||||
)
|
||||
|
||||
|
||||
def _run_inference(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
graph_wrapper: Optional[GraphWrapper] = None
|
||||
) -> Tuple[float, torch.Tensor]:
|
||||
self, input_ids: torch.Tensor, graph_wrapper: GraphWrapper | None = None
|
||||
) -> tuple[float, torch.Tensor]:
|
||||
attention_mask = torch.ones_like(input_ids)
|
||||
|
||||
|
||||
with torch.no_grad(), self.timer.timing():
|
||||
if graph_wrapper is not None:
|
||||
output = graph_wrapper(input_ids, attention_mask)
|
||||
else:
|
||||
output = self.model(input_ids=input_ids, attention_mask=attention_mask)
|
||||
|
||||
|
||||
return self.timer.elapsed_time(), output
|
||||
|
||||
def run(self) -> Dict[int, Dict[str, float]]:
|
||||
|
||||
def run(self) -> dict[int, dict[str, float]]:
|
||||
results = {}
|
||||
|
||||
|
||||
# Reset peak memory stats
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
@@ -443,22 +463,20 @@ class Benchmark:
|
||||
pass
|
||||
else:
|
||||
print("- No GPU memory stats available")
|
||||
|
||||
|
||||
for batch_size in self.config.batch_sizes:
|
||||
print(f"\nTesting batch size: {batch_size}")
|
||||
times = []
|
||||
|
||||
|
||||
# Get or create graph for this batch size
|
||||
graph_wrapper = (
|
||||
self.graphs.get_or_create(batch_size)
|
||||
if self.graphs is not None
|
||||
else None
|
||||
self.graphs.get_or_create(batch_size) if self.graphs is not None else None
|
||||
)
|
||||
|
||||
|
||||
# Pre-allocate input tensor
|
||||
input_ids = self._create_random_batch(batch_size)
|
||||
print(f"Input shape: {input_ids.shape}")
|
||||
|
||||
|
||||
# Run benchmark
|
||||
for i in tqdm(range(self.config.num_runs), desc=f"Batch size {batch_size}"):
|
||||
try:
|
||||
@@ -469,44 +487,44 @@ class Benchmark:
|
||||
except Exception as e:
|
||||
print(f"Error during inference: {e}")
|
||||
break
|
||||
|
||||
|
||||
if not times:
|
||||
print(f"No successful runs for batch size {batch_size}, skipping")
|
||||
continue
|
||||
|
||||
|
||||
# Calculate statistics
|
||||
avg_time = np.mean(times)
|
||||
std_time = np.std(times)
|
||||
throughput = batch_size / avg_time
|
||||
|
||||
|
||||
results[batch_size] = {
|
||||
"avg_time": avg_time,
|
||||
"std_time": std_time,
|
||||
"throughput": throughput,
|
||||
}
|
||||
|
||||
|
||||
print(f"Avg Time: {avg_time:.4f}s ± {std_time:.4f}s")
|
||||
print(f"Throughput: {throughput:.2f} sequences/second")
|
||||
|
||||
|
||||
# Log memory usage
|
||||
if torch.cuda.is_available():
|
||||
peak_memory_gb = torch.cuda.max_memory_allocated() / (1024 ** 3)
|
||||
peak_memory_gb = torch.cuda.max_memory_allocated() / (1024**3)
|
||||
elif torch.backends.mps.is_available():
|
||||
# MPS doesn't have max_memory_allocated, use 0
|
||||
peak_memory_gb = 0.0
|
||||
else:
|
||||
peak_memory_gb = 0.0
|
||||
print("- No GPU memory usage available")
|
||||
|
||||
|
||||
if peak_memory_gb > 0:
|
||||
print(f"\nPeak GPU memory usage: {peak_memory_gb:.2f} GB")
|
||||
else:
|
||||
print("\n- GPU memory usage not available")
|
||||
|
||||
|
||||
# Add memory info to results
|
||||
for batch_size in results:
|
||||
results[batch_size]["peak_memory_gb"] = peak_memory_gb
|
||||
|
||||
|
||||
return results
|
||||
|
||||
|
||||
@@ -566,14 +584,14 @@ def main():
|
||||
action="store_true",
|
||||
help="Enable Linear8bitLt quantization for all linear layers",
|
||||
)
|
||||
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
# Print arguments for debugging
|
||||
print("\nCommand line arguments:")
|
||||
for arg, value in vars(args).items():
|
||||
print(f"- {arg}: {value}")
|
||||
|
||||
|
||||
config = BenchmarkConfig(
|
||||
model_path=args.model_path,
|
||||
batch_sizes=[int(bs) for bs in args.batch_sizes.split(",")],
|
||||
@@ -586,45 +604,56 @@ def main():
|
||||
use_flash_attention=args.use_flash_attention,
|
||||
use_linear8bitlt=args.use_linear8bitlt,
|
||||
)
|
||||
|
||||
|
||||
# Print configuration for debugging
|
||||
print("\nBenchmark configuration:")
|
||||
for field, value in vars(config).items():
|
||||
print(f"- {field}: {value}")
|
||||
|
||||
|
||||
try:
|
||||
benchmark = Benchmark(config)
|
||||
results = benchmark.run()
|
||||
|
||||
|
||||
# Save results to file
|
||||
import json
|
||||
import os
|
||||
|
||||
|
||||
# Create results directory if it doesn't exist
|
||||
os.makedirs("results", exist_ok=True)
|
||||
|
||||
|
||||
# Generate filename based on configuration
|
||||
precision_type = "int4" if config.use_int4 else "int8" if config.use_int8 else "fp16" if config.use_fp16 else "fp32"
|
||||
precision_type = (
|
||||
"int4"
|
||||
if config.use_int4
|
||||
else "int8"
|
||||
if config.use_int8
|
||||
else "fp16"
|
||||
if config.use_fp16
|
||||
else "fp32"
|
||||
)
|
||||
model_name = os.path.basename(config.model_path)
|
||||
output_file = f"results/benchmark_{model_name}_{precision_type}.json"
|
||||
|
||||
|
||||
# Save results
|
||||
with open(output_file, "w") as f:
|
||||
json.dump(
|
||||
{
|
||||
"config": {k: str(v) if isinstance(v, list) else v for k, v in vars(config).items()},
|
||||
"results": {str(k): v for k, v in results.items()}
|
||||
},
|
||||
f,
|
||||
indent=2
|
||||
"config": {
|
||||
k: str(v) if isinstance(v, list) else v for k, v in vars(config).items()
|
||||
},
|
||||
"results": {str(k): v for k, v in results.items()},
|
||||
},
|
||||
f,
|
||||
indent=2,
|
||||
)
|
||||
print(f"Results saved to {output_file}")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"Benchmark failed: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
main()
|
||||
49
benchmarks/run_all.sh
Executable file
49
benchmarks/run_all.sh
Executable file
@@ -0,0 +1,49 @@
|
||||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
|
||||
# 公共参数
|
||||
INDEX_PATH="benchmarks/data/indices/rpj_wiki/rpj_wiki"
|
||||
NUM_QUERIES=20
|
||||
BATCH_SIZE=128
|
||||
LLM_MODEL="qwen3:4b"
|
||||
TOP_K=3
|
||||
|
||||
# 日志目录(带时间戳)
|
||||
LOG_DIR="logs/eval_runs_$(date +%Y%m%d_%H%M%S)"
|
||||
mkdir -p "$LOG_DIR"
|
||||
|
||||
# dataset -> ef 列表
|
||||
declare -A EF_MAP=(
|
||||
[nq_open.jsonl]="32 62 190"
|
||||
[trivia_qa.jsonl]="77 150 249"
|
||||
[gpqa.jsonl]="41 72 124"
|
||||
[hotpot_qa.jsonl]="137 299 1199"
|
||||
)
|
||||
|
||||
# 按指定顺序遍历
|
||||
ORDERED_DATASETS=(nq_open.jsonl trivia_qa.jsonl gpqa.jsonl hotpot_qa.jsonl)
|
||||
|
||||
for dataset in "${ORDERED_DATASETS[@]}"; do
|
||||
for ef in ${EF_MAP[$dataset]}; do
|
||||
log_file="${LOG_DIR}/${dataset%.jsonl}_ef${ef}.log"
|
||||
|
||||
# 展示并记录将要执行的命令
|
||||
cmd=(python benchmarks/run_evaluation.py "$INDEX_PATH" \
|
||||
--num-queries "$NUM_QUERIES" \
|
||||
--ef "$ef" \
|
||||
--batch-size "$BATCH_SIZE" \
|
||||
--llm-model "$LLM_MODEL" \
|
||||
--top-k "$TOP_K" \
|
||||
--queries-file "$dataset")
|
||||
|
||||
echo "=== Running dataset=${dataset} ef=${ef} ===" | tee -a "$log_file"
|
||||
printf 'CMD: '; printf '%q ' "${cmd[@]}" | tee -a "$log_file"; echo | tee -a "$log_file"
|
||||
|
||||
# 同时输出到命令行和日志文件
|
||||
"${cmd[@]}" 2>&1 | tee -a "$log_file"
|
||||
|
||||
echo | tee -a "$log_file"
|
||||
done
|
||||
done
|
||||
|
||||
echo "All runs completed. Logs in: $LOG_DIR"
|
||||
@@ -5,24 +5,21 @@ It correctly compares results by fetching the text content for both the new sear
|
||||
results and the golden standard results, making the comparison robust to ID changes.
|
||||
"""
|
||||
|
||||
import json
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
import sys
|
||||
import numpy as np
|
||||
from typing import List
|
||||
|
||||
from leann.api import LeannSearcher, LeannBuilder
|
||||
import numpy as np
|
||||
from leann.api import LeannBuilder, LeannChat, LeannSearcher
|
||||
|
||||
|
||||
def download_data_if_needed(data_root: Path, download_embeddings: bool = False):
|
||||
"""Checks if the data directory exists, and if not, downloads it from HF Hub."""
|
||||
if not data_root.exists():
|
||||
print(f"Data directory '{data_root}' not found.")
|
||||
print(
|
||||
"Downloading evaluation data from Hugging Face Hub... (this may take a moment)"
|
||||
)
|
||||
print("Downloading evaluation data from Hugging Face Hub... (this may take a moment)")
|
||||
try:
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
@@ -63,7 +60,7 @@ def download_data_if_needed(data_root: Path, download_embeddings: bool = False):
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def download_embeddings_if_needed(data_root: Path, dataset_type: str = None):
|
||||
def download_embeddings_if_needed(data_root: Path, dataset_type: str | None = None):
|
||||
"""Download embeddings files specifically."""
|
||||
embeddings_dir = data_root / "embeddings"
|
||||
|
||||
@@ -101,7 +98,7 @@ def download_embeddings_if_needed(data_root: Path, dataset_type: str = None):
|
||||
|
||||
|
||||
# --- Helper Function to get Golden Passages ---
|
||||
def get_golden_texts(searcher: LeannSearcher, golden_ids: List[int]) -> set:
|
||||
def get_golden_texts(searcher: LeannSearcher, golden_ids: list[int]) -> set:
|
||||
"""
|
||||
Retrieves the text for golden passage IDs directly from the LeannSearcher's
|
||||
passage manager.
|
||||
@@ -113,24 +110,20 @@ def get_golden_texts(searcher: LeannSearcher, golden_ids: List[int]) -> set:
|
||||
passage_data = searcher.passage_manager.get_passage(str(gid))
|
||||
golden_texts.add(passage_data["text"])
|
||||
except KeyError:
|
||||
print(
|
||||
f"Warning: Golden passage ID '{gid}' not found in the index's passage data."
|
||||
)
|
||||
print(f"Warning: Golden passage ID '{gid}' not found in the index's passage data.")
|
||||
return golden_texts
|
||||
|
||||
|
||||
def load_queries(file_path: Path) -> List[str]:
|
||||
def load_queries(file_path: Path) -> list[str]:
|
||||
queries = []
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
with open(file_path, encoding="utf-8") as f:
|
||||
for line in f:
|
||||
data = json.loads(line)
|
||||
queries.append(data["query"])
|
||||
return queries
|
||||
|
||||
|
||||
def build_index_from_embeddings(
|
||||
embeddings_file: str, output_path: str, backend: str = "hnsw"
|
||||
):
|
||||
def build_index_from_embeddings(embeddings_file: str, output_path: str, backend: str = "hnsw"):
|
||||
"""
|
||||
Build a LEANN index from pre-computed embeddings.
|
||||
|
||||
@@ -173,9 +166,7 @@ def build_index_from_embeddings(
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Run recall evaluation on a LEANN index."
|
||||
)
|
||||
parser = argparse.ArgumentParser(description="Run recall evaluation on a LEANN index.")
|
||||
parser.add_argument(
|
||||
"index_path",
|
||||
type=str,
|
||||
@@ -202,26 +193,50 @@ def main():
|
||||
parser.add_argument(
|
||||
"--num-queries", type=int, default=10, help="Number of queries to evaluate."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--top-k", type=int, default=3, help="The 'k' value for recall@k."
|
||||
)
|
||||
parser.add_argument("--top-k", type=int, default=3, help="The 'k' value for recall@k.")
|
||||
parser.add_argument(
|
||||
"--ef-search", type=int, default=120, help="The 'efSearch' parameter for HNSW."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--batch-size",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Batch size for HNSW batched search (0 disables batching)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--queries-file",
|
||||
type=str,
|
||||
default="nq_open.jsonl",
|
||||
help=(
|
||||
"Queries file to use. Provide a filename under benchmarks/data/queries "
|
||||
"or an absolute path to a .jsonl file (default: nq_open.jsonl)."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--llm-type",
|
||||
type=str,
|
||||
choices=["ollama", "hf", "openai", "gemini", "simulated"],
|
||||
default="ollama",
|
||||
help="LLM backend type to optionally query during evaluation (default: ollama)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--llm-model",
|
||||
type=str,
|
||||
default="qwen3:1.7b",
|
||||
help="LLM model identifier for the chosen backend (default: qwen3:1.7b)",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# --- Path Configuration ---
|
||||
# Assumes a project structure where the script is in 'examples/'
|
||||
# and data is in 'data/' at the project root.
|
||||
project_root = Path(__file__).resolve().parent.parent
|
||||
data_root = project_root / "data"
|
||||
# Assumes a project structure where the script is in 'benchmarks/'
|
||||
# and evaluation data is in 'benchmarks/data/'.
|
||||
script_dir = Path(__file__).resolve().parent
|
||||
data_root = script_dir / "data"
|
||||
|
||||
# Download data based on mode
|
||||
if args.mode == "build":
|
||||
# For building mode, we need embeddings
|
||||
download_data_if_needed(
|
||||
data_root, download_embeddings=False
|
||||
) # Basic data first
|
||||
download_data_if_needed(data_root, download_embeddings=False) # Basic data first
|
||||
|
||||
# Auto-detect dataset type and download embeddings
|
||||
if args.embeddings_file:
|
||||
@@ -262,9 +277,7 @@ def main():
|
||||
print(f"Index built successfully: {built_index_path}")
|
||||
|
||||
# Ask if user wants to run evaluation
|
||||
eval_response = (
|
||||
input("Run evaluation on the built index? (y/n): ").strip().lower()
|
||||
)
|
||||
eval_response = input("Run evaluation on the built index? (y/n): ").strip().lower()
|
||||
if eval_response != "y":
|
||||
print("Index building complete. Exiting.")
|
||||
return
|
||||
@@ -293,11 +306,9 @@ def main():
|
||||
break
|
||||
|
||||
if not args.index_path:
|
||||
print("No indices found. The data download should have included pre-built indices.")
|
||||
print(
|
||||
"No indices found. The data download should have included pre-built indices."
|
||||
)
|
||||
print(
|
||||
"Please check the data/indices/ directory or provide --index-path manually."
|
||||
"Please check the benchmarks/data/indices/ directory or provide --index-path manually."
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
@@ -310,14 +321,54 @@ def main():
|
||||
else:
|
||||
# Fallback: try to infer from the index directory name
|
||||
dataset_type = Path(args.index_path).name
|
||||
print(
|
||||
f"WARNING: Could not detect dataset type from path, inferred '{dataset_type}'."
|
||||
)
|
||||
print(f"WARNING: Could not detect dataset type from path, inferred '{dataset_type}'.")
|
||||
|
||||
queries_file = data_root / "queries" / "nq_open.jsonl"
|
||||
golden_results_file = (
|
||||
data_root / "ground_truth" / dataset_type / "flat_results_nq_k3.json"
|
||||
)
|
||||
# Resolve queries file (supports absolute path or name under data/queries)
|
||||
queries_file_candidate = Path(args.queries_file)
|
||||
if queries_file_candidate.is_absolute():
|
||||
queries_file = queries_file_candidate
|
||||
else:
|
||||
queries_file = data_root / "queries" / args.queries_file
|
||||
|
||||
if not queries_file.exists():
|
||||
print(f"Error: Queries file not found: {queries_file}")
|
||||
print("Tip: Use --queries-file with a filename under benchmarks/data/queries or an absolute path.")
|
||||
sys.exit(1)
|
||||
|
||||
# Infer ground-truth file from the queries filename
|
||||
qname = queries_file.name.lower()
|
||||
if "hotpot" in qname:
|
||||
task_key = "hotpot"
|
||||
elif "trivia" in qname:
|
||||
task_key = "trivia"
|
||||
elif "gpqa" in qname:
|
||||
task_key = "gpqa"
|
||||
elif "nq" in qname:
|
||||
task_key = "nq"
|
||||
else:
|
||||
print(
|
||||
"Error: Could not infer task from queries filename. Supported names include 'nq', 'hotpot', 'trivia', 'gpqa'."
|
||||
)
|
||||
print(f"Filename was: {queries_file.name}")
|
||||
sys.exit(1)
|
||||
|
||||
golden_results_file = data_root / "ground_truth" / dataset_type / f"flat_results_{task_key}_k3.json"
|
||||
if not golden_results_file.exists():
|
||||
gt_dir = data_root / "ground_truth" / dataset_type
|
||||
try:
|
||||
available = sorted(p.name for p in gt_dir.glob("flat_results_*_k3.json"))
|
||||
except Exception:
|
||||
available = []
|
||||
print(
|
||||
f"Error: Ground truth file not found for task '{task_key}' under dataset '{dataset_type}': {golden_results_file}"
|
||||
)
|
||||
if available:
|
||||
print("Available ground truth files:")
|
||||
for name in available:
|
||||
print(f" - {name}")
|
||||
else:
|
||||
print(f"No ground truth files found in {gt_dir}")
|
||||
sys.exit(1)
|
||||
|
||||
print(f"INFO: Detected dataset type: {dataset_type}")
|
||||
print(f"INFO: Using queries file: {queries_file}")
|
||||
@@ -327,7 +378,7 @@ def main():
|
||||
searcher = LeannSearcher(args.index_path)
|
||||
queries = load_queries(queries_file)
|
||||
|
||||
with open(golden_results_file, "r") as f:
|
||||
with open(golden_results_file) as f:
|
||||
golden_results_data = json.load(f)
|
||||
|
||||
num_eval_queries = min(args.num_queries, len(queries))
|
||||
@@ -340,10 +391,23 @@ def main():
|
||||
for i in range(num_eval_queries):
|
||||
start_time = time.time()
|
||||
new_results = searcher.search(
|
||||
queries[i], top_k=args.top_k, ef=args.ef_search
|
||||
queries[i],
|
||||
top_k=args.top_k,
|
||||
complexity=args.ef_search,
|
||||
batch_size=args.batch_size,
|
||||
)
|
||||
search_times.append(time.time() - start_time)
|
||||
|
||||
# Optional: also call the LLM with configurable backend/model (does not affect recall)
|
||||
# llm_config = {"type": args.llm_type, "model": args.llm_model}
|
||||
# chat = LeannChat(args.index_path, llm_config=llm_config, searcher=searcher)
|
||||
# answer = chat.ask(
|
||||
# queries[i],
|
||||
# top_k=args.top_k,
|
||||
# complexity=args.ef_search,
|
||||
# batch_size=args.batch_size,
|
||||
# )
|
||||
# print(f"Answer: {answer}")
|
||||
# Correct Recall Calculation: Based on TEXT content
|
||||
new_texts = {result.text for result in new_results}
|
||||
|
||||
@@ -367,10 +431,16 @@ def main():
|
||||
avg_recall = np.mean(recall_scores) if recall_scores else 0
|
||||
avg_time = np.mean(search_times) if search_times else 0
|
||||
|
||||
print(f"search time: {search_times}")
|
||||
|
||||
print("\n🎉 --- Evaluation Complete ---")
|
||||
print(f"Avg. Recall@{args.top_k} (efSearch={args.ef_search}): {avg_recall:.4f}")
|
||||
print(f"Avg. Search Time: {avg_time:.4f}s")
|
||||
|
||||
# avg last 10 search times
|
||||
avg_last_10_search_times = np.mean(search_times[-10:])
|
||||
print(f"Avg. Last 10 Search Times: {avg_last_10_search_times:.4f}s")
|
||||
|
||||
except Exception as e:
|
||||
print(f"\n❌ An error occurred during evaluation: {e}")
|
||||
import traceback
|
||||
55
benchmarks/run_speed_bench_all.sh
Executable file
55
benchmarks/run_speed_bench_all.sh
Executable file
@@ -0,0 +1,55 @@
|
||||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
|
||||
# Absolute paths (adjust if needed)
|
||||
PROMPTS_DIR="/home/tony/yichuan/leann/benchmarks/data/prompts_g5"
|
||||
SCRIPT_PATH="/home/tony/yichuan/leann/benchmarks/generation_speed_bench.py"
|
||||
|
||||
# Common args
|
||||
MAX_TOKENS=2048
|
||||
OLLAMA_MODEL="qwen3:4b"
|
||||
HF_MODEL="Qwen/Qwen3-4B"
|
||||
|
||||
# Logs
|
||||
LOG_DIR="/home/tony/yichuan/leann/logs/speed_bench_$(date +%Y%m%d_%H%M%S)"
|
||||
mkdir -p "$LOG_DIR"
|
||||
|
||||
echo "Scanning: $PROMPTS_DIR"
|
||||
|
||||
# Iterate all .txt files under PROMPTS_DIR
|
||||
while IFS= read -r -d '' file; do
|
||||
base_name=$(basename "$file")
|
||||
stem_name="${base_name%.*}"
|
||||
|
||||
# 1) Ollama
|
||||
log_ollama="${LOG_DIR}/${stem_name}_ollama.log"
|
||||
cmd_ollama=(python "$SCRIPT_PATH" \
|
||||
--path "$file" \
|
||||
--type ollama \
|
||||
--model "$OLLAMA_MODEL" \
|
||||
--max_tokens "$MAX_TOKENS")
|
||||
|
||||
echo "=== Running (ollama) file=${file} model=${OLLAMA_MODEL} ===" | tee -a "$log_ollama"
|
||||
printf 'CMD: '; printf '%q ' "${cmd_ollama[@]}" | tee -a "$log_ollama"; echo | tee -a "$log_ollama"
|
||||
"${cmd_ollama[@]}" 2>&1 | tee -a "$log_ollama"
|
||||
echo | tee -a "$log_ollama"
|
||||
|
||||
# 2) HF
|
||||
log_hf="${LOG_DIR}/${stem_name}_hf.log"
|
||||
cmd_hf=(python "$SCRIPT_PATH" \
|
||||
--path "$file" \
|
||||
--type hf \
|
||||
--model "$HF_MODEL" \
|
||||
--max_tokens "$MAX_TOKENS")
|
||||
|
||||
echo "=== Running (hf) file=${file} model=${HF_MODEL} ===" | tee -a "$log_hf"
|
||||
printf 'CMD: '; printf '%q ' "${cmd_hf[@]}" | tee -a "$log_hf"; echo | tee -a "$log_hf"
|
||||
"${cmd_hf[@]}" 2>&1 | tee -a "$log_hf"
|
||||
echo | tee -a "$log_hf"
|
||||
|
||||
done < <(find "$PROMPTS_DIR" -type f -name '*.txt' -print0)
|
||||
|
||||
|
||||
echo "All runs completed. Logs in: $LOG_DIR"
|
||||
|
||||
|
||||
@@ -1,26 +1,27 @@
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import AutoModel, BitsAndBytesConfig
|
||||
from tqdm import tqdm
|
||||
from transformers import AutoModel
|
||||
|
||||
# Add MLX imports
|
||||
try:
|
||||
import mlx.core as mx
|
||||
from mlx_lm.utils import load
|
||||
|
||||
MLX_AVAILABLE = True
|
||||
except ImportError as e:
|
||||
except ImportError:
|
||||
print("MLX not available. Install with: uv pip install mlx mlx-lm")
|
||||
MLX_AVAILABLE = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class BenchmarkConfig:
|
||||
model_path: str = "facebook/contriever"
|
||||
batch_sizes: List[int] = None
|
||||
model_path: str = "facebook/contriever-msmarco"
|
||||
batch_sizes: list[int] = None
|
||||
seq_length: int = 256
|
||||
num_runs: int = 5
|
||||
use_fp16: bool = True
|
||||
@@ -30,18 +31,19 @@ class BenchmarkConfig:
|
||||
use_flash_attention: bool = False
|
||||
use_linear8bitlt: bool = False
|
||||
use_mlx: bool = False # New flag for MLX testing
|
||||
|
||||
|
||||
def __post_init__(self):
|
||||
if self.batch_sizes is None:
|
||||
self.batch_sizes = [1, 2, 4, 8, 16, 32, 64]
|
||||
self.batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024]
|
||||
|
||||
|
||||
class MLXBenchmark:
|
||||
"""MLX-specific benchmark for embedding models"""
|
||||
|
||||
|
||||
def __init__(self, config: BenchmarkConfig):
|
||||
self.config = config
|
||||
self.model, self.tokenizer = self._load_model()
|
||||
|
||||
|
||||
def _load_model(self):
|
||||
"""Load MLX model and tokenizer following the API pattern"""
|
||||
print(f"Loading MLX model from {self.config.model_path}...")
|
||||
@@ -52,55 +54,51 @@ class MLXBenchmark:
|
||||
except Exception as e:
|
||||
print(f"Error loading MLX model: {e}")
|
||||
raise
|
||||
|
||||
|
||||
def _create_random_batch(self, batch_size: int):
|
||||
"""Create random input batches for MLX testing - same as PyTorch"""
|
||||
return torch.randint(
|
||||
0, 1000,
|
||||
(batch_size, self.config.seq_length),
|
||||
dtype=torch.long
|
||||
)
|
||||
|
||||
return torch.randint(0, 1000, (batch_size, self.config.seq_length), dtype=torch.long)
|
||||
|
||||
def _run_inference(self, input_ids: torch.Tensor) -> float:
|
||||
"""Run MLX inference with same input as PyTorch"""
|
||||
start_time = time.time()
|
||||
try:
|
||||
# Convert PyTorch tensor to MLX array
|
||||
input_ids_mlx = mx.array(input_ids.numpy())
|
||||
|
||||
|
||||
# Get embeddings
|
||||
embeddings = self.model(input_ids_mlx)
|
||||
|
||||
|
||||
# Mean pooling (following the API pattern)
|
||||
pooled = embeddings.mean(axis=1)
|
||||
|
||||
|
||||
# Convert to numpy (following the API pattern)
|
||||
pooled_numpy = np.array(pooled.tolist(), dtype=np.float32)
|
||||
|
||||
|
||||
# Force computation
|
||||
_ = pooled_numpy.shape
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"MLX inference error: {e}")
|
||||
return float('inf')
|
||||
return float("inf")
|
||||
end_time = time.time()
|
||||
|
||||
|
||||
return end_time - start_time
|
||||
|
||||
def run(self) -> Dict[int, Dict[str, float]]:
|
||||
|
||||
def run(self) -> dict[int, dict[str, float]]:
|
||||
"""Run the MLX benchmark across all batch sizes"""
|
||||
results = {}
|
||||
|
||||
|
||||
print(f"Starting MLX benchmark with model: {self.config.model_path}")
|
||||
print(f"Testing batch sizes: {self.config.batch_sizes}")
|
||||
|
||||
|
||||
for batch_size in self.config.batch_sizes:
|
||||
print(f"\n=== Testing MLX batch size: {batch_size} ===")
|
||||
times = []
|
||||
|
||||
|
||||
# Create input batch (same as PyTorch)
|
||||
input_ids = self._create_random_batch(batch_size)
|
||||
|
||||
|
||||
# Warm up
|
||||
print("Warming up...")
|
||||
for _ in range(3):
|
||||
@@ -109,26 +107,26 @@ class MLXBenchmark:
|
||||
except Exception as e:
|
||||
print(f"Warmup error: {e}")
|
||||
break
|
||||
|
||||
|
||||
# Run benchmark
|
||||
for i in tqdm(range(self.config.num_runs), desc=f"MLX Batch size {batch_size}"):
|
||||
for _i in tqdm(range(self.config.num_runs), desc=f"MLX Batch size {batch_size}"):
|
||||
try:
|
||||
elapsed_time = self._run_inference(input_ids)
|
||||
if elapsed_time != float('inf'):
|
||||
if elapsed_time != float("inf"):
|
||||
times.append(elapsed_time)
|
||||
except Exception as e:
|
||||
print(f"Error during MLX inference: {e}")
|
||||
break
|
||||
|
||||
|
||||
if not times:
|
||||
print(f"Skipping batch size {batch_size} due to errors")
|
||||
continue
|
||||
|
||||
|
||||
# Calculate statistics
|
||||
avg_time = np.mean(times)
|
||||
std_time = np.std(times)
|
||||
throughput = batch_size / avg_time
|
||||
|
||||
|
||||
results[batch_size] = {
|
||||
"avg_time": avg_time,
|
||||
"std_time": std_time,
|
||||
@@ -136,122 +134,133 @@ class MLXBenchmark:
|
||||
"min_time": np.min(times),
|
||||
"max_time": np.max(times),
|
||||
}
|
||||
|
||||
|
||||
print(f"MLX Results for batch size {batch_size}:")
|
||||
print(f" Avg Time: {avg_time:.4f}s ± {std_time:.4f}s")
|
||||
print(f" Min Time: {np.min(times):.4f}s")
|
||||
print(f" Max Time: {np.max(times):.4f}s")
|
||||
print(f" Throughput: {throughput:.2f} sequences/second")
|
||||
|
||||
|
||||
return results
|
||||
|
||||
|
||||
class Benchmark:
|
||||
def __init__(self, config: BenchmarkConfig):
|
||||
self.config = config
|
||||
self.device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
|
||||
self.device = (
|
||||
"cuda"
|
||||
if torch.cuda.is_available()
|
||||
else "mps"
|
||||
if torch.backends.mps.is_available()
|
||||
else "cpu"
|
||||
)
|
||||
self.model = self._load_model()
|
||||
|
||||
|
||||
def _load_model(self) -> nn.Module:
|
||||
print(f"Loading model from {self.config.model_path}...")
|
||||
|
||||
|
||||
|
||||
model = AutoModel.from_pretrained(self.config.model_path)
|
||||
if self.config.use_fp16:
|
||||
model = model.half()
|
||||
model = torch.compile(model)
|
||||
model = model.to(self.device)
|
||||
|
||||
|
||||
model.eval()
|
||||
return model
|
||||
|
||||
|
||||
def _create_random_batch(self, batch_size: int) -> torch.Tensor:
|
||||
return torch.randint(
|
||||
0, 1000,
|
||||
0,
|
||||
1000,
|
||||
(batch_size, self.config.seq_length),
|
||||
device=self.device,
|
||||
dtype=torch.long
|
||||
dtype=torch.long,
|
||||
)
|
||||
|
||||
|
||||
def _run_inference(self, input_ids: torch.Tensor) -> float:
|
||||
attention_mask = torch.ones_like(input_ids)
|
||||
|
||||
# print shape of input_ids and attention_mask
|
||||
print(f"input_ids shape: {input_ids.shape}")
|
||||
print(f"attention_mask shape: {attention_mask.shape}")
|
||||
start_time = time.time()
|
||||
with torch.no_grad():
|
||||
output = self.model(input_ids=input_ids, attention_mask=attention_mask)
|
||||
self.model(input_ids=input_ids, attention_mask=attention_mask)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.synchronize()
|
||||
if torch.backends.mps.is_available():
|
||||
torch.mps.synchronize()
|
||||
end_time = time.time()
|
||||
|
||||
|
||||
return end_time - start_time
|
||||
|
||||
def run(self) -> Dict[int, Dict[str, float]]:
|
||||
|
||||
def run(self) -> dict[int, dict[str, float]]:
|
||||
results = {}
|
||||
|
||||
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
|
||||
|
||||
for batch_size in self.config.batch_sizes:
|
||||
print(f"\nTesting batch size: {batch_size}")
|
||||
times = []
|
||||
|
||||
|
||||
input_ids = self._create_random_batch(batch_size)
|
||||
|
||||
for i in tqdm(range(self.config.num_runs), desc=f"Batch size {batch_size}"):
|
||||
|
||||
for _i in tqdm(range(self.config.num_runs), desc=f"Batch size {batch_size}"):
|
||||
try:
|
||||
elapsed_time = self._run_inference(input_ids)
|
||||
times.append(elapsed_time)
|
||||
except Exception as e:
|
||||
print(f"Error during inference: {e}")
|
||||
break
|
||||
|
||||
|
||||
if not times:
|
||||
continue
|
||||
|
||||
|
||||
avg_time = np.mean(times)
|
||||
std_time = np.std(times)
|
||||
throughput = batch_size / avg_time
|
||||
|
||||
|
||||
results[batch_size] = {
|
||||
"avg_time": avg_time,
|
||||
"std_time": std_time,
|
||||
"throughput": throughput,
|
||||
}
|
||||
|
||||
|
||||
print(f"Avg Time: {avg_time:.4f}s ± {std_time:.4f}s")
|
||||
print(f"Throughput: {throughput:.2f} sequences/second")
|
||||
|
||||
|
||||
if torch.cuda.is_available():
|
||||
peak_memory_gb = torch.cuda.max_memory_allocated() / (1024 ** 3)
|
||||
peak_memory_gb = torch.cuda.max_memory_allocated() / (1024**3)
|
||||
else:
|
||||
peak_memory_gb = 0.0
|
||||
|
||||
|
||||
for batch_size in results:
|
||||
results[batch_size]["peak_memory_gb"] = peak_memory_gb
|
||||
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def run_benchmark():
|
||||
"""Main function to run the benchmark with optimized parameters."""
|
||||
config = BenchmarkConfig()
|
||||
|
||||
|
||||
try:
|
||||
benchmark = Benchmark(config)
|
||||
results = benchmark.run()
|
||||
|
||||
|
||||
max_throughput = max(results[batch_size]["throughput"] for batch_size in results)
|
||||
avg_throughput = np.mean([results[batch_size]["throughput"] for batch_size in results])
|
||||
|
||||
|
||||
return {
|
||||
"max_throughput": max_throughput,
|
||||
"avg_throughput": avg_throughput,
|
||||
"results": results
|
||||
"results": results,
|
||||
}
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"Benchmark failed: {e}")
|
||||
return {
|
||||
"max_throughput": 0.0,
|
||||
"avg_throughput": 0.0,
|
||||
"error": str(e)
|
||||
}
|
||||
return {"max_throughput": 0.0, "avg_throughput": 0.0, "error": str(e)}
|
||||
|
||||
|
||||
def run_mlx_benchmark():
|
||||
"""Run MLX-specific benchmark"""
|
||||
@@ -260,55 +269,49 @@ def run_mlx_benchmark():
|
||||
return {
|
||||
"max_throughput": 0.0,
|
||||
"avg_throughput": 0.0,
|
||||
"error": "MLX not available"
|
||||
"error": "MLX not available",
|
||||
}
|
||||
|
||||
config = BenchmarkConfig(
|
||||
model_path="mlx-community/all-MiniLM-L6-v2-4bit",
|
||||
use_mlx=True
|
||||
)
|
||||
|
||||
|
||||
config = BenchmarkConfig(model_path="mlx-community/all-MiniLM-L6-v2-4bit", use_mlx=True)
|
||||
|
||||
try:
|
||||
benchmark = MLXBenchmark(config)
|
||||
results = benchmark.run()
|
||||
|
||||
|
||||
if not results:
|
||||
return {
|
||||
"max_throughput": 0.0,
|
||||
"avg_throughput": 0.0,
|
||||
"error": "No valid results"
|
||||
"error": "No valid results",
|
||||
}
|
||||
|
||||
|
||||
max_throughput = max(results[batch_size]["throughput"] for batch_size in results)
|
||||
avg_throughput = np.mean([results[batch_size]["throughput"] for batch_size in results])
|
||||
|
||||
|
||||
return {
|
||||
"max_throughput": max_throughput,
|
||||
"avg_throughput": avg_throughput,
|
||||
"results": results
|
||||
"results": results,
|
||||
}
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"MLX benchmark failed: {e}")
|
||||
return {
|
||||
"max_throughput": 0.0,
|
||||
"avg_throughput": 0.0,
|
||||
"error": str(e)
|
||||
}
|
||||
return {"max_throughput": 0.0, "avg_throughput": 0.0, "error": str(e)}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("=== PyTorch Benchmark ===")
|
||||
pytorch_result = run_benchmark()
|
||||
print(f"PyTorch Max throughput: {pytorch_result['max_throughput']:.2f} sequences/second")
|
||||
print(f"PyTorch Average throughput: {pytorch_result['avg_throughput']:.2f} sequences/second")
|
||||
|
||||
|
||||
print("\n=== MLX Benchmark ===")
|
||||
mlx_result = run_mlx_benchmark()
|
||||
print(f"MLX Max throughput: {mlx_result['max_throughput']:.2f} sequences/second")
|
||||
print(f"MLX Average throughput: {mlx_result['avg_throughput']:.2f} sequences/second")
|
||||
|
||||
|
||||
# Compare results
|
||||
if pytorch_result['max_throughput'] > 0 and mlx_result['max_throughput'] > 0:
|
||||
speedup = mlx_result['max_throughput'] / pytorch_result['max_throughput']
|
||||
print(f"\n=== Comparison ===")
|
||||
print(f"MLX is {speedup:.2f}x {'faster' if speedup > 1 else 'slower'} than PyTorch")
|
||||
if pytorch_result["max_throughput"] > 0 and mlx_result["max_throughput"] > 0:
|
||||
speedup = mlx_result["max_throughput"] / pytorch_result["max_throughput"]
|
||||
print("\n=== Comparison ===")
|
||||
print(f"MLX is {speedup:.2f}x {'faster' if speedup > 1 else 'slower'} than PyTorch")
|
||||
82
data/.gitattributes
vendored
82
data/.gitattributes
vendored
@@ -1,82 +0,0 @@
|
||||
*.7z filter=lfs diff=lfs merge=lfs -text
|
||||
*.arrow filter=lfs diff=lfs merge=lfs -text
|
||||
*.bin filter=lfs diff=lfs merge=lfs -text
|
||||
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
||||
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
||||
*.ftz filter=lfs diff=lfs merge=lfs -text
|
||||
*.gz filter=lfs diff=lfs merge=lfs -text
|
||||
*.h5 filter=lfs diff=lfs merge=lfs -text
|
||||
*.joblib filter=lfs diff=lfs merge=lfs -text
|
||||
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
||||
*.lz4 filter=lfs diff=lfs merge=lfs -text
|
||||
*.mds filter=lfs diff=lfs merge=lfs -text
|
||||
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
||||
*.model filter=lfs diff=lfs merge=lfs -text
|
||||
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
||||
*.npy filter=lfs diff=lfs merge=lfs -text
|
||||
*.npz filter=lfs diff=lfs merge=lfs -text
|
||||
*.onnx filter=lfs diff=lfs merge=lfs -text
|
||||
*.ot filter=lfs diff=lfs merge=lfs -text
|
||||
*.parquet filter=lfs diff=lfs merge=lfs -text
|
||||
*.pb filter=lfs diff=lfs merge=lfs -text
|
||||
*.pickle filter=lfs diff=lfs merge=lfs -text
|
||||
*.pkl filter=lfs diff=lfs merge=lfs -text
|
||||
*.pt filter=lfs diff=lfs merge=lfs -text
|
||||
*.pth filter=lfs diff=lfs merge=lfs -text
|
||||
*.rar filter=lfs diff=lfs merge=lfs -text
|
||||
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
||||
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
||||
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
||||
*.tar filter=lfs diff=lfs merge=lfs -text
|
||||
*.tflite filter=lfs diff=lfs merge=lfs -text
|
||||
*.tgz filter=lfs diff=lfs merge=lfs -text
|
||||
*.wasm filter=lfs diff=lfs merge=lfs -text
|
||||
*.xz filter=lfs diff=lfs merge=lfs -text
|
||||
*.zip filter=lfs diff=lfs merge=lfs -text
|
||||
*.zst filter=lfs diff=lfs merge=lfs -text
|
||||
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
||||
# Audio files - uncompressed
|
||||
*.pcm filter=lfs diff=lfs merge=lfs -text
|
||||
*.sam filter=lfs diff=lfs merge=lfs -text
|
||||
*.raw filter=lfs diff=lfs merge=lfs -text
|
||||
# Audio files - compressed
|
||||
*.aac filter=lfs diff=lfs merge=lfs -text
|
||||
*.flac filter=lfs diff=lfs merge=lfs -text
|
||||
*.mp3 filter=lfs diff=lfs merge=lfs -text
|
||||
*.ogg filter=lfs diff=lfs merge=lfs -text
|
||||
*.wav filter=lfs diff=lfs merge=lfs -text
|
||||
# Image files - uncompressed
|
||||
*.bmp filter=lfs diff=lfs merge=lfs -text
|
||||
*.gif filter=lfs diff=lfs merge=lfs -text
|
||||
*.png filter=lfs diff=lfs merge=lfs -text
|
||||
*.tiff filter=lfs diff=lfs merge=lfs -text
|
||||
# Image files - compressed
|
||||
*.jpg filter=lfs diff=lfs merge=lfs -text
|
||||
*.jpeg filter=lfs diff=lfs merge=lfs -text
|
||||
*.webp filter=lfs diff=lfs merge=lfs -text
|
||||
# Video files - compressed
|
||||
*.mp4 filter=lfs diff=lfs merge=lfs -text
|
||||
*.webm filter=lfs diff=lfs merge=lfs -text
|
||||
ground_truth/dpr/id_map.json filter=lfs diff=lfs merge=lfs -text
|
||||
indices/dpr/dpr_diskann.passages.idx filter=lfs diff=lfs merge=lfs -text
|
||||
indices/dpr/dpr_diskann.passages.jsonl filter=lfs diff=lfs merge=lfs -text
|
||||
indices/dpr/dpr_diskann_disk.index filter=lfs diff=lfs merge=lfs -text
|
||||
indices/dpr/leann.labels.map filter=lfs diff=lfs merge=lfs -text
|
||||
indices/rpj_wiki/leann.labels.map filter=lfs diff=lfs merge=lfs -text
|
||||
indices/rpj_wiki/rpj_wiki.index filter=lfs diff=lfs merge=lfs -text
|
||||
indices/rpj_wiki/rpj_wiki.passages.0.idx filter=lfs diff=lfs merge=lfs -text
|
||||
indices/rpj_wiki/rpj_wiki.passages.0.jsonl filter=lfs diff=lfs merge=lfs -text
|
||||
indices/rpj_wiki/rpj_wiki.passages.1.idx filter=lfs diff=lfs merge=lfs -text
|
||||
indices/rpj_wiki/rpj_wiki.passages.1.jsonl filter=lfs diff=lfs merge=lfs -text
|
||||
indices/rpj_wiki/rpj_wiki.passages.2.idx filter=lfs diff=lfs merge=lfs -text
|
||||
indices/rpj_wiki/rpj_wiki.passages.2.jsonl filter=lfs diff=lfs merge=lfs -text
|
||||
indices/rpj_wiki/rpj_wiki.passages.3.idx filter=lfs diff=lfs merge=lfs -text
|
||||
indices/rpj_wiki/rpj_wiki.passages.3.jsonl filter=lfs diff=lfs merge=lfs -text
|
||||
indices/rpj_wiki/rpj_wiki.passages.4.idx filter=lfs diff=lfs merge=lfs -text
|
||||
indices/rpj_wiki/rpj_wiki.passages.4.jsonl filter=lfs diff=lfs merge=lfs -text
|
||||
indices/rpj_wiki/rpj_wiki.passages.5.idx filter=lfs diff=lfs merge=lfs -text
|
||||
indices/rpj_wiki/rpj_wiki.passages.5.jsonl filter=lfs diff=lfs merge=lfs -text
|
||||
indices/rpj_wiki/rpj_wiki.passages.6.idx filter=lfs diff=lfs merge=lfs -text
|
||||
indices/rpj_wiki/rpj_wiki.passages.6.jsonl filter=lfs diff=lfs merge=lfs -text
|
||||
indices/rpj_wiki/rpj_wiki.passages.7.idx filter=lfs diff=lfs merge=lfs -text
|
||||
indices/rpj_wiki/rpj_wiki.passages.7.jsonl filter=lfs diff=lfs merge=lfs -text
|
||||
BIN
data/2501.14312v1 (1).pdf
Normal file
BIN
data/2501.14312v1 (1).pdf
Normal file
Binary file not shown.
7905
data/2506.08276v1.pdf
Normal file
7905
data/2506.08276v1.pdf
Normal file
File diff suppressed because it is too large
Load Diff
14905
data/PrideandPrejudice.txt
Normal file
14905
data/PrideandPrejudice.txt
Normal file
File diff suppressed because it is too large
Load Diff
105
demo.ipynb
105
demo.ipynb
@@ -1,37 +1,116 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Quick Start \n",
|
||||
"\n",
|
||||
"**Home GitHub Repository:** [LEANN on GitHub](https://github.com/yichuan-w/LEANN)\n",
|
||||
"\n",
|
||||
"**Important for Colab users:** Set your runtime type to T4 GPU for optimal performance. Go to Runtime → Change runtime type → Hardware accelerator → T4 GPU."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from leann.api import LeannBuilder, LeannSearcher, LeannChat\n",
|
||||
"# install this if you are using colab\n",
|
||||
"! uv pip install leann-core leann-backend-hnsw --no-deps\n",
|
||||
"! uv pip install leann --no-deps\n",
|
||||
"# For Colab environment, we need to set some environment variables\n",
|
||||
"import os\n",
|
||||
"\n",
|
||||
"os.environ[\"LEANN_LOG_LEVEL\"] = \"INFO\" # Enable more detailed logging"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from pathlib import Path\n",
|
||||
"\n",
|
||||
"INDEX_DIR = Path(\"./\").resolve()\n",
|
||||
"INDEX_PATH = str(INDEX_DIR / \"demo.leann\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Build the index"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from leann.api import LeannBuilder\n",
|
||||
"\n",
|
||||
"# 1. Build the index (no embeddings stored!)\n",
|
||||
"builder = LeannBuilder(backend_name=\"hnsw\")\n",
|
||||
"builder.add_text(\"C# is a powerful programming language\")\n",
|
||||
"builder.add_text(\"Python is a powerful programming language and it is very popular\")\n",
|
||||
"builder.add_text(\"C# is a powerful programming language and it is good at game development\")\n",
|
||||
"builder.add_text(\n",
|
||||
" \"Python is a powerful programming language and it is good at machine learning tasks\"\n",
|
||||
")\n",
|
||||
"builder.add_text(\"Machine learning transforms industries\")\n",
|
||||
"builder.add_text(\"Neural networks process complex data\")\n",
|
||||
"builder.add_text(\"Leann is a great storage saving engine for RAG on your MacBook\")\n",
|
||||
"builder.build_index(\"knowledge.leann\")\n",
|
||||
"builder.build_index(INDEX_PATH)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Search with real-time embeddings"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from leann.api import LeannSearcher\n",
|
||||
"\n",
|
||||
"# 2. Search with real-time embeddings\n",
|
||||
"searcher = LeannSearcher(\"knowledge.leann\")\n",
|
||||
"searcher = LeannSearcher(INDEX_PATH)\n",
|
||||
"results = searcher.search(\"programming languages\", top_k=2)\n",
|
||||
"results"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Chat with LEANN using retrieved results"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from leann.api import LeannChat\n",
|
||||
"\n",
|
||||
"# 3. Chat with LEANN using retrieved results\n",
|
||||
"llm_config = {\n",
|
||||
" \"type\": \"ollama\",\n",
|
||||
" \"model\": \"llama3.2:1b\"\n",
|
||||
" \"type\": \"hf\",\n",
|
||||
" \"model\": \"Qwen/Qwen3-0.6B\",\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"chat = LeannChat(index_path=\"knowledge.leann\", llm_config=llm_config)\n",
|
||||
"chat = LeannChat(index_path=INDEX_PATH, llm_config=llm_config)\n",
|
||||
"response = chat.ask(\n",
|
||||
" \"Compare the two retrieved programming languages and say which one is more popular today.\",\n",
|
||||
" \"Compare the two retrieved programming languages and tell me their advantages.\",\n",
|
||||
" top_k=2,\n",
|
||||
")"
|
||||
" llm_kwargs={\"max_tokens\": 128},\n",
|
||||
")\n",
|
||||
"response"
|
||||
]
|
||||
}
|
||||
],
|
||||
|
||||
220
docs/CONTRIBUTING.md
Normal file
220
docs/CONTRIBUTING.md
Normal file
@@ -0,0 +1,220 @@
|
||||
# 🤝 Contributing
|
||||
|
||||
We welcome contributions! Leann is built by the community, for the community.
|
||||
|
||||
## Ways to Contribute
|
||||
|
||||
- 🐛 **Bug Reports**: Found an issue? Let us know!
|
||||
- 💡 **Feature Requests**: Have an idea? We'd love to hear it!
|
||||
- 🔧 **Code Contributions**: PRs welcome for all skill levels
|
||||
- 📖 **Documentation**: Help make Leann more accessible
|
||||
- 🧪 **Benchmarks**: Share your performance results
|
||||
|
||||
## 🚀 Development Setup
|
||||
|
||||
### Prerequisites
|
||||
|
||||
1. **Install uv** (fast Python package installer):
|
||||
```bash
|
||||
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
```
|
||||
|
||||
2. **Clone the repository**:
|
||||
```bash
|
||||
git clone https://github.com/LEANN-RAG/LEANN-RAG.git
|
||||
cd LEANN-RAG
|
||||
```
|
||||
|
||||
3. **Install system dependencies**:
|
||||
|
||||
**macOS:**
|
||||
```bash
|
||||
brew install llvm libomp boost protobuf zeromq pkgconf
|
||||
```
|
||||
|
||||
**Ubuntu/Debian:**
|
||||
```bash
|
||||
sudo apt-get install libomp-dev libboost-all-dev protobuf-compiler \
|
||||
libabsl-dev libmkl-full-dev libaio-dev libzmq3-dev
|
||||
```
|
||||
|
||||
4. **Build from source**:
|
||||
```bash
|
||||
# macOS
|
||||
CC=$(brew --prefix llvm)/bin/clang CXX=$(brew --prefix llvm)/bin/clang++ uv sync
|
||||
|
||||
# Ubuntu/Debian
|
||||
uv sync
|
||||
```
|
||||
|
||||
## 🔨 Pre-commit Hooks
|
||||
|
||||
We use pre-commit hooks to ensure code quality and consistency. This runs automatically before each commit.
|
||||
|
||||
### Setup Pre-commit
|
||||
|
||||
1. **Install pre-commit** (already included when you run `uv sync`):
|
||||
```bash
|
||||
uv pip install pre-commit
|
||||
```
|
||||
|
||||
2. **Install the git hooks**:
|
||||
```bash
|
||||
pre-commit install
|
||||
```
|
||||
|
||||
3. **Run pre-commit manually** (optional):
|
||||
```bash
|
||||
pre-commit run --all-files
|
||||
```
|
||||
|
||||
### Pre-commit Checks
|
||||
|
||||
Our pre-commit configuration includes:
|
||||
- **Trailing whitespace removal**
|
||||
- **End-of-file fixing**
|
||||
- **YAML validation**
|
||||
- **Large file prevention**
|
||||
- **Merge conflict detection**
|
||||
- **Debug statement detection**
|
||||
- **Code formatting with ruff**
|
||||
- **Code linting with ruff**
|
||||
|
||||
## 🧪 Testing
|
||||
|
||||
### Running Tests
|
||||
|
||||
```bash
|
||||
# Run all tests
|
||||
uv run pytest
|
||||
|
||||
# Run specific test file
|
||||
uv run pytest test/test_filename.py
|
||||
|
||||
# Run with coverage
|
||||
uv run pytest --cov=leann
|
||||
```
|
||||
|
||||
### Writing Tests
|
||||
|
||||
- Place tests in the `test/` directory
|
||||
- Follow the naming convention `test_*.py`
|
||||
- Use descriptive test names that explain what's being tested
|
||||
- Include both positive and negative test cases
|
||||
|
||||
## 📝 Code Style
|
||||
|
||||
We use `ruff` for both linting and formatting to ensure consistent code style.
|
||||
|
||||
### Format Your Code
|
||||
|
||||
```bash
|
||||
# Format all files
|
||||
ruff format
|
||||
|
||||
# Check formatting without changing files
|
||||
ruff format --check
|
||||
```
|
||||
|
||||
### Lint Your Code
|
||||
|
||||
```bash
|
||||
# Run linter with auto-fix
|
||||
ruff check --fix
|
||||
|
||||
# Just check without fixing
|
||||
ruff check
|
||||
```
|
||||
|
||||
### Style Guidelines
|
||||
|
||||
- Follow PEP 8 conventions
|
||||
- Use descriptive variable names
|
||||
- Add type hints where appropriate
|
||||
- Write docstrings for all public functions and classes
|
||||
- Keep functions focused and single-purpose
|
||||
|
||||
## 🚦 CI/CD
|
||||
|
||||
Our CI pipeline runs automatically on all pull requests. It includes:
|
||||
|
||||
1. **Linting and Formatting**: Ensures code follows our style guidelines
|
||||
2. **Multi-platform builds**: Tests on Ubuntu and macOS
|
||||
3. **Python version matrix**: Tests on Python 3.9-3.13
|
||||
4. **Wheel building**: Ensures packages can be built and distributed
|
||||
|
||||
### CI Commands
|
||||
|
||||
The CI uses the same commands as pre-commit to ensure consistency:
|
||||
```bash
|
||||
# Linting
|
||||
ruff check .
|
||||
|
||||
# Format checking
|
||||
ruff format --check .
|
||||
```
|
||||
|
||||
Make sure your code passes these checks locally before pushing!
|
||||
|
||||
## 🔄 Pull Request Process
|
||||
|
||||
1. **Fork the repository** and create your branch from `main`:
|
||||
```bash
|
||||
git checkout -b feature/your-feature-name
|
||||
```
|
||||
|
||||
2. **Make your changes**:
|
||||
- Write clean, documented code
|
||||
- Add tests for new functionality
|
||||
- Update documentation as needed
|
||||
|
||||
3. **Run pre-commit checks**:
|
||||
```bash
|
||||
pre-commit run --all-files
|
||||
```
|
||||
|
||||
4. **Test your changes**:
|
||||
```bash
|
||||
uv run pytest
|
||||
```
|
||||
|
||||
5. **Commit with descriptive messages**:
|
||||
```bash
|
||||
git commit -m "feat: add new search algorithm"
|
||||
```
|
||||
|
||||
Follow [Conventional Commits](https://www.conventionalcommits.org/):
|
||||
- `feat:` for new features
|
||||
- `fix:` for bug fixes
|
||||
- `docs:` for documentation changes
|
||||
- `test:` for test additions/changes
|
||||
- `refactor:` for code refactoring
|
||||
- `perf:` for performance improvements
|
||||
|
||||
6. **Push and create a pull request**:
|
||||
- Provide a clear description of your changes
|
||||
- Reference any related issues
|
||||
- Include examples or screenshots if applicable
|
||||
|
||||
## 📚 Documentation
|
||||
|
||||
When adding new features or making significant changes:
|
||||
|
||||
1. Update relevant documentation in `/docs`
|
||||
2. Add docstrings to new functions/classes
|
||||
3. Update README.md if needed
|
||||
4. Include usage examples
|
||||
|
||||
## 🤔 Getting Help
|
||||
|
||||
- **Discord**: Join our community for discussions
|
||||
- **Issues**: Check existing issues or create a new one
|
||||
- **Discussions**: For general questions and ideas
|
||||
|
||||
## 📄 License
|
||||
|
||||
By contributing, you agree that your contributions will be licensed under the same license as the project (MIT).
|
||||
|
||||
---
|
||||
|
||||
Thank you for contributing to LEANN! Every contribution, no matter how small, helps make the project better for everyone. 🌟
|
||||
22
docs/RELEASE.md
Normal file
22
docs/RELEASE.md
Normal file
@@ -0,0 +1,22 @@
|
||||
# Release Guide
|
||||
|
||||
## Setup (One-time)
|
||||
|
||||
Add `PYPI_API_TOKEN` to GitHub Secrets:
|
||||
1. Get token: https://pypi.org/manage/account/token/
|
||||
2. Add to secrets: Settings → Secrets → Actions → `PYPI_API_TOKEN`
|
||||
|
||||
## Release (One-click)
|
||||
|
||||
1. Go to: https://github.com/yichuan-w/LEANN/actions/workflows/release-manual.yml
|
||||
2. Click "Run workflow"
|
||||
3. Enter version: `0.1.2`
|
||||
4. Click green "Run workflow" button
|
||||
|
||||
That's it! The workflow will automatically:
|
||||
- ✅ Update version in all packages
|
||||
- ✅ Build all packages
|
||||
- ✅ Publish to PyPI
|
||||
- ✅ Create GitHub tag and release
|
||||
|
||||
Check progress: https://github.com/yichuan-w/LEANN/actions
|
||||
123
docs/THINKING_BUDGET_FEATURE.md
Normal file
123
docs/THINKING_BUDGET_FEATURE.md
Normal file
@@ -0,0 +1,123 @@
|
||||
# Thinking Budget Feature Implementation
|
||||
|
||||
## Overview
|
||||
|
||||
This document describes the implementation of the **thinking budget** feature for LEANN, which allows users to control the computational effort for reasoning models like GPT-Oss:20b.
|
||||
|
||||
## Feature Description
|
||||
|
||||
The thinking budget feature provides three levels of computational effort for reasoning models:
|
||||
- **`low`**: Fast responses, basic reasoning (default for simple queries)
|
||||
- **`medium`**: Balanced speed and reasoning depth
|
||||
- **`high`**: Maximum reasoning effort, best for complex analytical questions
|
||||
|
||||
## Implementation Details
|
||||
|
||||
### 1. Command Line Interface
|
||||
|
||||
Added `--thinking-budget` parameter to both CLI and RAG examples:
|
||||
|
||||
```bash
|
||||
# LEANN CLI
|
||||
leann ask my-index --llm ollama --model gpt-oss:20b --thinking-budget high
|
||||
|
||||
# RAG Examples
|
||||
python apps/email_rag.py --llm ollama --llm-model gpt-oss:20b --thinking-budget high
|
||||
python apps/document_rag.py --llm openai --llm-model o3 --thinking-budget medium
|
||||
```
|
||||
|
||||
### 2. LLM Backend Support
|
||||
|
||||
#### Ollama Backend (`packages/leann-core/src/leann/chat.py`)
|
||||
|
||||
```python
|
||||
def ask(self, prompt: str, **kwargs) -> str:
|
||||
# Handle thinking budget for reasoning models
|
||||
options = kwargs.copy()
|
||||
thinking_budget = kwargs.get("thinking_budget")
|
||||
if thinking_budget:
|
||||
options.pop("thinking_budget", None)
|
||||
if thinking_budget in ["low", "medium", "high"]:
|
||||
options["reasoning"] = {"effort": thinking_budget, "exclude": False}
|
||||
```
|
||||
|
||||
**API Format**: Uses Ollama's `reasoning` parameter with `effort` and `exclude` fields.
|
||||
|
||||
#### OpenAI Backend (`packages/leann-core/src/leann/chat.py`)
|
||||
|
||||
```python
|
||||
def ask(self, prompt: str, **kwargs) -> str:
|
||||
# Handle thinking budget for reasoning models
|
||||
thinking_budget = kwargs.get("thinking_budget")
|
||||
if thinking_budget and thinking_budget in ["low", "medium", "high"]:
|
||||
# Check if this is an o-series model
|
||||
o_series_models = ["o3", "o3-mini", "o4-mini", "o1", "o3-pro", "o3-deep-research"]
|
||||
if any(model in self.model for model in o_series_models):
|
||||
params["reasoning_effort"] = thinking_budget
|
||||
```
|
||||
|
||||
**API Format**: Uses OpenAI's `reasoning_effort` parameter for o-series models.
|
||||
|
||||
### 3. Parameter Propagation
|
||||
|
||||
The thinking budget parameter is properly propagated through the LEANN architecture:
|
||||
|
||||
1. **CLI** (`packages/leann-core/src/leann/cli.py`): Captures `--thinking-budget` argument
|
||||
2. **Base RAG** (`apps/base_rag_example.py`): Adds parameter to argument parser
|
||||
3. **LeannChat** (`packages/leann-core/src/leann/api.py`): Passes `llm_kwargs` to LLM
|
||||
4. **LLM Interface**: Handles the parameter in backend-specific implementations
|
||||
|
||||
## Files Modified
|
||||
|
||||
### Core Implementation
|
||||
- `packages/leann-core/src/leann/chat.py`: Added thinking budget support to OllamaChat and OpenAIChat
|
||||
- `packages/leann-core/src/leann/cli.py`: Added `--thinking-budget` argument
|
||||
- `apps/base_rag_example.py`: Added thinking budget parameter to RAG examples
|
||||
|
||||
### Documentation
|
||||
- `README.md`: Added thinking budget parameter to usage examples
|
||||
- `docs/configuration-guide.md`: Added detailed documentation and usage guidelines
|
||||
|
||||
### Examples
|
||||
- `examples/thinking_budget_demo.py`: Comprehensive demo script with usage examples
|
||||
|
||||
## Usage Examples
|
||||
|
||||
### Basic Usage
|
||||
```bash
|
||||
# High reasoning effort for complex questions
|
||||
leann ask my-index --llm ollama --model gpt-oss:20b --thinking-budget high
|
||||
|
||||
# Medium reasoning for balanced performance
|
||||
leann ask my-index --llm openai --model gpt-4o --thinking-budget medium
|
||||
|
||||
# Low reasoning for fast responses
|
||||
leann ask my-index --llm ollama --model gpt-oss:20b --thinking-budget low
|
||||
```
|
||||
|
||||
### RAG Examples
|
||||
```bash
|
||||
# Email RAG with high reasoning
|
||||
python apps/email_rag.py --llm ollama --llm-model gpt-oss:20b --thinking-budget high
|
||||
|
||||
# Document RAG with medium reasoning
|
||||
python apps/document_rag.py --llm openai --llm-model gpt-4o --thinking-budget medium
|
||||
```
|
||||
|
||||
## Supported Models
|
||||
|
||||
### Ollama Models
|
||||
- **GPT-Oss:20b**: Primary target model with reasoning capabilities
|
||||
- **Other reasoning models**: Any Ollama model that supports the `reasoning` parameter
|
||||
|
||||
### OpenAI Models
|
||||
- **o3, o3-mini, o4-mini, o1**: o-series reasoning models with `reasoning_effort` parameter
|
||||
- **GPT-OSS models**: Models that support reasoning capabilities
|
||||
|
||||
## Testing
|
||||
|
||||
The implementation includes comprehensive testing:
|
||||
- Parameter handling verification
|
||||
- Backend-specific API format validation
|
||||
- CLI argument parsing tests
|
||||
- Integration with existing LEANN architecture
|
||||
128
docs/ast_chunking_guide.md
Normal file
128
docs/ast_chunking_guide.md
Normal file
@@ -0,0 +1,128 @@
|
||||
# AST-Aware Code chunking guide
|
||||
|
||||
## Overview
|
||||
|
||||
This guide covers best practices for using AST-aware code chunking in LEANN. AST chunking provides better semantic understanding of code structure compared to traditional text-based chunking.
|
||||
|
||||
## Quick Start
|
||||
|
||||
### Basic Usage
|
||||
|
||||
```bash
|
||||
# Enable AST chunking for mixed content (code + docs)
|
||||
python -m apps.document_rag --enable-code-chunking --data-dir ./my_project
|
||||
|
||||
# Specialized code repository indexing
|
||||
python -m apps.code_rag --repo-dir ./my_codebase
|
||||
|
||||
# Global CLI with AST support
|
||||
leann build my-code-index --docs ./src --use-ast-chunking
|
||||
```
|
||||
|
||||
### Installation
|
||||
|
||||
```bash
|
||||
# Install LEANN with AST chunking support
|
||||
uv pip install -e "."
|
||||
```
|
||||
|
||||
## Best Practices
|
||||
|
||||
### When to Use AST Chunking
|
||||
|
||||
✅ **Recommended for:**
|
||||
- Code repositories with multiple languages
|
||||
- Mixed documentation and code content
|
||||
- Complex codebases with deep function/class hierarchies
|
||||
- When working with Claude Code for code assistance
|
||||
|
||||
❌ **Not recommended for:**
|
||||
- Pure text documents
|
||||
- Very large files (>1MB)
|
||||
- Languages not supported by tree-sitter
|
||||
|
||||
### Optimal Configuration
|
||||
|
||||
```bash
|
||||
# Recommended settings for most codebases
|
||||
python -m apps.code_rag \
|
||||
--repo-dir ./src \
|
||||
--ast-chunk-size 768 \
|
||||
--ast-chunk-overlap 96 \
|
||||
--exclude-dirs .git __pycache__ node_modules build dist
|
||||
```
|
||||
|
||||
### Supported Languages
|
||||
|
||||
| Extension | Language | Status |
|
||||
|-----------|----------|--------|
|
||||
| `.py` | Python | ✅ Full support |
|
||||
| `.java` | Java | ✅ Full support |
|
||||
| `.cs` | C# | ✅ Full support |
|
||||
| `.ts`, `.tsx` | TypeScript | ✅ Full support |
|
||||
| `.js`, `.jsx` | JavaScript | ✅ Via TypeScript parser |
|
||||
|
||||
## Integration Examples
|
||||
|
||||
### Document RAG with Code Support
|
||||
|
||||
```python
|
||||
# Enable code chunking in document RAG
|
||||
python -m apps.document_rag \
|
||||
--enable-code-chunking \
|
||||
--data-dir ./project \
|
||||
--query "How does authentication work in the codebase?"
|
||||
```
|
||||
|
||||
### Claude Code Integration
|
||||
|
||||
When using with Claude Code MCP server, AST chunking provides better context for:
|
||||
- Code completion and suggestions
|
||||
- Bug analysis and debugging
|
||||
- Architecture understanding
|
||||
- Refactoring assistance
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Common Issues
|
||||
|
||||
1. **Fallback to Traditional Chunking**
|
||||
- Normal behavior for unsupported languages
|
||||
- Check logs for specific language support
|
||||
|
||||
2. **Performance with Large Files**
|
||||
- Adjust `--max-file-size` parameter
|
||||
- Use `--exclude-dirs` to skip unnecessary directories
|
||||
|
||||
3. **Quality Issues**
|
||||
- Try different `--ast-chunk-size` values (512, 768, 1024)
|
||||
- Adjust overlap for better context preservation
|
||||
|
||||
### Debug Mode
|
||||
|
||||
```bash
|
||||
export LEANN_LOG_LEVEL=DEBUG
|
||||
python -m apps.code_rag --repo-dir ./my_code
|
||||
```
|
||||
|
||||
## Migration from Traditional Chunking
|
||||
|
||||
Existing workflows continue to work without changes. To enable AST chunking:
|
||||
|
||||
```bash
|
||||
# Before
|
||||
python -m apps.document_rag --chunk-size 256
|
||||
|
||||
# After (maintains traditional chunking for non-code files)
|
||||
python -m apps.document_rag --enable-code-chunking --chunk-size 256 --ast-chunk-size 768
|
||||
```
|
||||
|
||||
## References
|
||||
|
||||
- [astchunk GitHub Repository](https://github.com/yilinjz/astchunk)
|
||||
- [LEANN MCP Integration](../packages/leann-mcp/README.md)
|
||||
- [Research Paper](https://arxiv.org/html/2506.15655v1)
|
||||
|
||||
---
|
||||
|
||||
**Note**: AST chunking maintains full backward compatibility while enhancing code understanding capabilities.
|
||||
98
docs/code/embedding_model_compare.py
Normal file
98
docs/code/embedding_model_compare.py
Normal file
@@ -0,0 +1,98 @@
|
||||
"""
|
||||
Comparison between Sentence Transformers and OpenAI embeddings
|
||||
|
||||
This example shows how different embedding models handle complex queries
|
||||
and demonstrates the differences between local and API-based embeddings.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
from leann.embedding_compute import compute_embeddings
|
||||
|
||||
# OpenAI API key should be set as environment variable
|
||||
# export OPENAI_API_KEY="your-api-key-here"
|
||||
|
||||
# Test data
|
||||
conference_text = "[Title]: COLING 2025 Conference\n[URL]: https://coling2025.org/"
|
||||
browser_text = "[Title]: Browser Use Tool\n[URL]: https://github.com/browser-use"
|
||||
|
||||
# Two queries with same intent but different wording
|
||||
query1 = "Tell me my browser history about some conference i often visit"
|
||||
query2 = "browser history about conference I often visit"
|
||||
|
||||
texts = [query1, query2, conference_text, browser_text]
|
||||
|
||||
|
||||
def cosine_similarity(a, b):
|
||||
return np.dot(a, b) # Already normalized
|
||||
|
||||
|
||||
def analyze_embeddings(embeddings, model_name):
|
||||
print(f"\n=== {model_name} Results ===")
|
||||
|
||||
# Results for Query 1
|
||||
sim1_conf = cosine_similarity(embeddings[0], embeddings[2])
|
||||
sim1_browser = cosine_similarity(embeddings[0], embeddings[3])
|
||||
|
||||
print(f"Query 1: '{query1}'")
|
||||
print(f" → Conference similarity: {sim1_conf:.4f} {'✓' if sim1_conf > sim1_browser else ''}")
|
||||
print(
|
||||
f" → Browser similarity: {sim1_browser:.4f} {'✓' if sim1_browser > sim1_conf else ''}"
|
||||
)
|
||||
print(f" Winner: {'Conference' if sim1_conf > sim1_browser else 'Browser'}")
|
||||
|
||||
# Results for Query 2
|
||||
sim2_conf = cosine_similarity(embeddings[1], embeddings[2])
|
||||
sim2_browser = cosine_similarity(embeddings[1], embeddings[3])
|
||||
|
||||
print(f"\nQuery 2: '{query2}'")
|
||||
print(f" → Conference similarity: {sim2_conf:.4f} {'✓' if sim2_conf > sim2_browser else ''}")
|
||||
print(
|
||||
f" → Browser similarity: {sim2_browser:.4f} {'✓' if sim2_browser > sim2_conf else ''}"
|
||||
)
|
||||
print(f" Winner: {'Conference' if sim2_conf > sim2_browser else 'Browser'}")
|
||||
|
||||
# Show the impact
|
||||
print("\n=== Impact Analysis ===")
|
||||
print(f"Conference similarity change: {sim2_conf - sim1_conf:+.4f}")
|
||||
print(f"Browser similarity change: {sim2_browser - sim1_browser:+.4f}")
|
||||
|
||||
if sim1_conf > sim1_browser and sim2_browser > sim2_conf:
|
||||
print("❌ FLIP: Adding 'browser history' flips winner from Conference to Browser!")
|
||||
elif sim1_conf > sim1_browser and sim2_conf > sim2_browser:
|
||||
print("✅ STABLE: Conference remains winner in both queries")
|
||||
elif sim1_browser > sim1_conf and sim2_browser > sim2_conf:
|
||||
print("✅ STABLE: Browser remains winner in both queries")
|
||||
else:
|
||||
print("🔄 MIXED: Results vary between queries")
|
||||
|
||||
return {
|
||||
"query1_conf": sim1_conf,
|
||||
"query1_browser": sim1_browser,
|
||||
"query2_conf": sim2_conf,
|
||||
"query2_browser": sim2_browser,
|
||||
}
|
||||
|
||||
|
||||
# Test Sentence Transformers
|
||||
print("Testing Sentence Transformers (facebook/contriever)...")
|
||||
try:
|
||||
st_embeddings = compute_embeddings(texts, "facebook/contriever", mode="sentence-transformers")
|
||||
st_results = analyze_embeddings(st_embeddings, "Sentence Transformers (facebook/contriever)")
|
||||
except Exception as e:
|
||||
print(f"❌ Sentence Transformers failed: {e}")
|
||||
st_results = None
|
||||
|
||||
# Test OpenAI
|
||||
print("\n" + "=" * 60)
|
||||
print("Testing OpenAI (text-embedding-3-small)...")
|
||||
try:
|
||||
openai_embeddings = compute_embeddings(texts, "text-embedding-3-small", mode="openai")
|
||||
openai_results = analyze_embeddings(openai_embeddings, "OpenAI (text-embedding-3-small)")
|
||||
except Exception as e:
|
||||
print(f"❌ OpenAI failed: {e}")
|
||||
openai_results = None
|
||||
|
||||
# Compare results
|
||||
if st_results and openai_results:
|
||||
print("\n" + "=" * 60)
|
||||
print("=== COMPARISON SUMMARY ===")
|
||||
384
docs/configuration-guide.md
Normal file
384
docs/configuration-guide.md
Normal file
@@ -0,0 +1,384 @@
|
||||
# LEANN Configuration Guide
|
||||
|
||||
This guide helps you optimize LEANN for different use cases and understand the trade-offs between various configuration options.
|
||||
|
||||
## Getting Started: Simple is Better
|
||||
|
||||
When first trying LEANN, start with a small dataset to quickly validate your approach:
|
||||
|
||||
**For document RAG**: The default `data/` directory works perfectly - includes 2 AI research papers, Pride and Prejudice literature, and a technical report
|
||||
```bash
|
||||
python -m apps.document_rag --query "What techniques does LEANN use?"
|
||||
```
|
||||
|
||||
**For other data sources**: Limit the dataset size for quick testing
|
||||
```bash
|
||||
# WeChat: Test with recent messages only
|
||||
python -m apps.wechat_rag --max-items 100 --query "What did we discuss about the project timeline?"
|
||||
|
||||
# Browser history: Last few days
|
||||
python -m apps.browser_rag --max-items 500 --query "Find documentation about vector databases"
|
||||
|
||||
# Email: Recent inbox
|
||||
python -m apps.email_rag --max-items 200 --query "Who sent updates about the deployment status?"
|
||||
```
|
||||
|
||||
Once validated, scale up gradually:
|
||||
- 100 documents → 1,000 → 10,000 → full dataset (`--max-items -1`)
|
||||
- This helps identify issues early before committing to long processing times
|
||||
|
||||
## Embedding Model Selection: Understanding the Trade-offs
|
||||
|
||||
Based on our experience developing LEANN, embedding models fall into three categories:
|
||||
|
||||
### Small Models (< 100M parameters)
|
||||
**Example**: `sentence-transformers/all-MiniLM-L6-v2` (22M params)
|
||||
- **Pros**: Lightweight, fast for both indexing and inference
|
||||
- **Cons**: Lower semantic understanding, may miss nuanced relationships
|
||||
- **Use when**: Speed is critical, handling simple queries, interactive mode, or just experimenting with LEANN. If time is not a constraint, consider using a larger/better embedding model
|
||||
|
||||
### Medium Models (100M-500M parameters)
|
||||
**Example**: `facebook/contriever` (110M params), `BAAI/bge-base-en-v1.5` (110M params)
|
||||
- **Pros**: Balanced performance, good multilingual support, reasonable speed
|
||||
- **Cons**: Requires more compute than small models
|
||||
- **Use when**: Need quality results without extreme compute requirements, general-purpose RAG applications
|
||||
|
||||
### Large Models (500M+ parameters)
|
||||
**Example**: `Qwen/Qwen3-Embedding-0.6B` (600M params), `intfloat/multilingual-e5-large` (560M params)
|
||||
- **Pros**: Best semantic understanding, captures complex relationships, excellent multilingual support. **Qwen3-Embedding-0.6B achieves nearly OpenAI API performance!**
|
||||
- **Cons**: Slower inference, longer index build times
|
||||
- **Use when**: Quality is paramount and you have sufficient compute resources. **Highly recommended** for production use
|
||||
|
||||
### Quick Start: Cloud and Local Embedding Options
|
||||
|
||||
**OpenAI Embeddings (Fastest Setup)**
|
||||
For immediate testing without local model downloads(also if you [do not have GPU](https://github.com/yichuan-w/LEANN/issues/43) and do not care that much about your document leak, you should use this, we compute the embedding and recompute using openai API):
|
||||
```bash
|
||||
# Set OpenAI embeddings (requires OPENAI_API_KEY)
|
||||
--embedding-mode openai --embedding-model text-embedding-3-small
|
||||
```
|
||||
|
||||
**Ollama Embeddings (Privacy-Focused)**
|
||||
For local embeddings with complete privacy:
|
||||
```bash
|
||||
# First, pull an embedding model
|
||||
ollama pull nomic-embed-text
|
||||
|
||||
# Use Ollama embeddings
|
||||
--embedding-mode ollama --embedding-model nomic-embed-text
|
||||
```
|
||||
|
||||
<details>
|
||||
<summary><strong>Cloud vs Local Trade-offs</strong></summary>
|
||||
|
||||
**OpenAI Embeddings** (`text-embedding-3-small/large`)
|
||||
- **Pros**: No local compute needed, consistently fast, high quality
|
||||
- **Cons**: Requires API key, costs money, data leaves your system, [known limitations with certain languages](https://yichuan-w.github.io/blog/lessons_learned_in_dev_leann/)
|
||||
- **When to use**: Prototyping, non-sensitive data, need immediate results
|
||||
|
||||
**Local Embeddings**
|
||||
- **Pros**: Complete privacy, no ongoing costs, full control, can sometimes outperform OpenAI embeddings
|
||||
- **Cons**: Slower than cloud APIs, requires local compute resources
|
||||
- **When to use**: Production systems, sensitive data, cost-sensitive applications
|
||||
|
||||
</details>
|
||||
|
||||
## Index Selection: Matching Your Scale
|
||||
|
||||
### HNSW (Hierarchical Navigable Small World)
|
||||
**Best for**: Small to medium datasets (< 10M vectors) - **Default and recommended for extreme low storage**
|
||||
- Full recomputation required
|
||||
- High memory usage during build phase
|
||||
- Excellent recall (95%+)
|
||||
|
||||
```bash
|
||||
# Optimal for most use cases
|
||||
--backend-name hnsw --graph-degree 32 --build-complexity 64
|
||||
```
|
||||
|
||||
### DiskANN
|
||||
**Best for**: Large datasets, especially when you want `recompute=True`.
|
||||
|
||||
**Key advantages:**
|
||||
- **Faster search** on large datasets (3x+ speedup vs HNSW in many cases)
|
||||
- **Smart storage**: `recompute=True` enables automatic graph partitioning for smaller indexes
|
||||
- **Better scaling**: Designed for 100k+ documents
|
||||
|
||||
**Recompute behavior:**
|
||||
- `recompute=True` (recommended): Pure PQ traversal + final reranking - faster and enables partitioning
|
||||
- `recompute=False`: PQ + partial real distances during traversal - slower but higher accuracy
|
||||
|
||||
```bash
|
||||
# Recommended for most use cases
|
||||
--backend-name diskann --graph-degree 32 --build-complexity 64
|
||||
```
|
||||
|
||||
**Performance Benchmark**: Run `uv run benchmarks/diskann_vs_hnsw_speed_comparison.py` to compare DiskANN and HNSW on your system.
|
||||
|
||||
## LLM Selection: Engine and Model Comparison
|
||||
|
||||
### LLM Engines
|
||||
|
||||
**OpenAI** (`--llm openai`)
|
||||
- **Pros**: Best quality, consistent performance, no local resources needed
|
||||
- **Cons**: Costs money ($0.15-2.5 per million tokens), requires internet, data privacy concerns
|
||||
- **Models**: `gpt-4o-mini` (fast, cheap), `gpt-4o` (best quality), `o3` (reasoning), `o3-mini` (reasoning, cheaper)
|
||||
- **Thinking Budget**: Use `--thinking-budget low/medium/high` for o-series reasoning models (o3, o3-mini, o4-mini)
|
||||
- **Note**: Our current default, but we recommend switching to Ollama for most use cases
|
||||
|
||||
**Ollama** (`--llm ollama`)
|
||||
- **Pros**: Fully local, free, privacy-preserving, good model variety
|
||||
- **Cons**: Requires local GPU/CPU resources, slower than cloud APIs, need to install extra [ollama app](https://github.com/ollama/ollama?tab=readme-ov-file#ollama) and pre-download models by `ollama pull`
|
||||
- **Models**: `qwen3:0.6b` (ultra-fast), `qwen3:1.7b` (balanced), `qwen3:4b` (good quality), `qwen3:7b` (high quality), `deepseek-r1:1.5b` (reasoning)
|
||||
- **Thinking Budget**: Use `--thinking-budget low/medium/high` for reasoning models like GPT-Oss:20b
|
||||
|
||||
**HuggingFace** (`--llm hf`)
|
||||
- **Pros**: Free tier available, huge model selection, direct model loading (vs Ollama's server-based approach)
|
||||
- **Cons**: More complex initial setup
|
||||
- **Models**: `Qwen/Qwen3-1.7B-FP8`
|
||||
|
||||
## Parameter Tuning Guide
|
||||
|
||||
### Search Complexity Parameters
|
||||
|
||||
**`--build-complexity`** (index building)
|
||||
- Controls thoroughness during index construction
|
||||
- Higher = better recall but slower build
|
||||
- Recommendations:
|
||||
- 32: Quick prototyping
|
||||
- 64: Balanced (default)
|
||||
- 128: Production systems
|
||||
- 256: Maximum quality
|
||||
|
||||
**`--search-complexity`** (query time)
|
||||
- Controls search thoroughness
|
||||
- Higher = better results but slower
|
||||
- Recommendations:
|
||||
- 16: Fast/Interactive search
|
||||
- 32: High quality with diversity
|
||||
- 64+: Maximum accuracy
|
||||
|
||||
### Top-K Selection
|
||||
|
||||
**`--top-k`** (number of retrieved chunks)
|
||||
- More chunks = better context but slower LLM processing
|
||||
- Should be always smaller than `--search-complexity`
|
||||
- Guidelines:
|
||||
- 10-20: General questions (default: 20)
|
||||
- 30+: Complex multi-hop reasoning requiring comprehensive context
|
||||
|
||||
**Trade-off formula**:
|
||||
- Retrieval time ∝ log(n) × search_complexity
|
||||
- LLM processing time ∝ top_k × chunk_size
|
||||
- Total context = top_k × chunk_size tokens
|
||||
|
||||
### Thinking Budget for Reasoning Models
|
||||
|
||||
**`--thinking-budget`** (reasoning effort level)
|
||||
- Controls the computational effort for reasoning models
|
||||
- Options: `low`, `medium`, `high`
|
||||
- Guidelines:
|
||||
- `low`: Fast responses, basic reasoning (default for simple queries)
|
||||
- `medium`: Balanced speed and reasoning depth
|
||||
- `high`: Maximum reasoning effort, best for complex analytical questions
|
||||
- **Supported Models**:
|
||||
- **Ollama**: `gpt-oss:20b`, `gpt-oss:120b`
|
||||
- **OpenAI**: `o3`, `o3-mini`, `o4-mini`, `o1` (o-series reasoning models)
|
||||
- **Note**: Models without reasoning support will show a warning and proceed without reasoning parameters
|
||||
- **Example**: `--thinking-budget high` for complex analytical questions
|
||||
|
||||
**📖 For detailed usage examples and implementation details, check out [Thinking Budget Documentation](THINKING_BUDGET_FEATURE.md)**
|
||||
|
||||
**💡 Quick Examples:**
|
||||
```bash
|
||||
# OpenAI o-series reasoning model
|
||||
python apps/document_rag.py --query "What are the main techniques LEANN explores?" \
|
||||
--index-dir hnswbuild --backend hnsw \
|
||||
--llm openai --llm-model o3 --thinking-budget medium
|
||||
|
||||
# Ollama reasoning model
|
||||
python apps/document_rag.py --query "What are the main techniques LEANN explores?" \
|
||||
--index-dir hnswbuild --backend hnsw \
|
||||
--llm ollama --llm-model gpt-oss:20b --thinking-budget high
|
||||
```
|
||||
|
||||
### Graph Degree (HNSW/DiskANN)
|
||||
|
||||
**`--graph-degree`**
|
||||
- Number of connections per node in the graph
|
||||
- Higher = better recall but more memory
|
||||
- HNSW: 16-32 (default: 32)
|
||||
- DiskANN: 32-128 (default: 64)
|
||||
|
||||
|
||||
## Performance Optimization Checklist
|
||||
|
||||
### If Embedding is Too Slow
|
||||
|
||||
1. **Switch to smaller model**:
|
||||
```bash
|
||||
# From large model
|
||||
--embedding-model Qwen/Qwen3-Embedding-0.6B
|
||||
# To small model
|
||||
--embedding-model sentence-transformers/all-MiniLM-L6-v2
|
||||
```
|
||||
|
||||
2. **Limit dataset size for testing**:
|
||||
```bash
|
||||
--max-items 1000 # Process first 1k items only
|
||||
```
|
||||
|
||||
3. **Use MLX on Apple Silicon** (optional optimization):
|
||||
```bash
|
||||
--embedding-mode mlx --embedding-model mlx-community/Qwen3-Embedding-0.6B-8bit
|
||||
```
|
||||
MLX might not be the best choice, as we tested and found that it only offers 1.3x acceleration compared to HF, so maybe using ollama is a better choice for embedding generation
|
||||
|
||||
4. **Use Ollama**
|
||||
```bash
|
||||
--embedding-mode ollama --embedding-model nomic-embed-text
|
||||
```
|
||||
To discover additional embedding models in ollama, check out https://ollama.com/search?c=embedding or read more about embedding models at https://ollama.com/blog/embedding-models, please do check the model size that works best for you
|
||||
### If Search Quality is Poor
|
||||
|
||||
1. **Increase retrieval count**:
|
||||
```bash
|
||||
--top-k 30 # Retrieve more candidates
|
||||
```
|
||||
|
||||
2. **Upgrade embedding model**:
|
||||
```bash
|
||||
# For English
|
||||
--embedding-model BAAI/bge-base-en-v1.5
|
||||
# For multilingual
|
||||
--embedding-model intfloat/multilingual-e5-large
|
||||
```
|
||||
|
||||
## Understanding the Trade-offs
|
||||
|
||||
Every configuration choice involves trade-offs:
|
||||
|
||||
| Factor | Small/Fast | Large/Quality |
|
||||
|--------|------------|---------------|
|
||||
| Embedding Model | `all-MiniLM-L6-v2` | `Qwen/Qwen3-Embedding-0.6B` |
|
||||
| Chunk Size | 512 tokens | 128 tokens |
|
||||
| Index Type | HNSW | DiskANN |
|
||||
| LLM | `qwen3:1.7b` | `gpt-4o` |
|
||||
|
||||
The key is finding the right balance for your specific use case. Start small and simple, measure performance, then scale up only where needed.
|
||||
|
||||
## Low-resource setups
|
||||
|
||||
If you don’t have a local GPU or builds/searches are too slow, use one or more of the options below.
|
||||
|
||||
### 1) Use OpenAI embeddings (no local compute)
|
||||
|
||||
Fastest path with zero local GPU requirements. Set your API key and use OpenAI embeddings during build and search:
|
||||
|
||||
```bash
|
||||
export OPENAI_API_KEY=sk-...
|
||||
|
||||
# Build with OpenAI embeddings
|
||||
leann build my-index \
|
||||
--embedding-mode openai \
|
||||
--embedding-model text-embedding-3-small
|
||||
|
||||
# Search with OpenAI embeddings (recompute at query time)
|
||||
leann search my-index "your query" \
|
||||
--recompute
|
||||
```
|
||||
|
||||
### 2) Run remote builds with SkyPilot (cloud GPU)
|
||||
|
||||
Offload embedding generation and index building to a GPU VM using [SkyPilot](https://skypilot.readthedocs.io/en/latest/). A template is provided at `sky/leann-build.yaml`.
|
||||
|
||||
```bash
|
||||
# One-time: install and configure SkyPilot
|
||||
pip install skypilot
|
||||
|
||||
# Launch with defaults (L4:1) and mount ./data to ~/leann-data; the build runs automatically
|
||||
sky launch -c leann-gpu sky/leann-build.yaml
|
||||
|
||||
# Override parameters via -e key=value (optional)
|
||||
sky launch -c leann-gpu sky/leann-build.yaml \
|
||||
-e index_name=my-index \
|
||||
-e backend=hnsw \
|
||||
-e embedding_mode=sentence-transformers \
|
||||
-e embedding_model=Qwen/Qwen3-Embedding-0.6B
|
||||
|
||||
# Copy the built index back to your local .leann (use rsync)
|
||||
rsync -Pavz leann-gpu:~/.leann/indexes/my-index ./.leann/indexes/
|
||||
```
|
||||
|
||||
### 3) Disable recomputation to trade storage for speed
|
||||
|
||||
If you need lower latency and have more storage/memory, disable recomputation. This stores full embeddings and avoids recomputing at search time.
|
||||
|
||||
```bash
|
||||
# Build without recomputation (HNSW requires non-compact in this mode)
|
||||
leann build my-index --no-recompute --no-compact
|
||||
|
||||
# Search without recomputation
|
||||
leann search my-index "your query" --no-recompute
|
||||
```
|
||||
|
||||
When to use:
|
||||
- Extreme low latency requirements (high QPS, interactive assistants)
|
||||
- Read-heavy workloads where storage is cheaper than latency
|
||||
- No always-available GPU
|
||||
|
||||
Constraints:
|
||||
- HNSW: when `--no-recompute` is set, LEANN automatically disables compact mode during build
|
||||
- DiskANN: supported; `--no-recompute` skips selective recompute during search
|
||||
|
||||
Storage impact:
|
||||
- Storing N embeddings of dimension D with float32 requires approximately N × D × 4 bytes
|
||||
- Example: 1,000,000 chunks × 768 dims × 4 bytes ≈ 2.86 GB (plus graph/metadata)
|
||||
|
||||
Converting an existing index (rebuild required):
|
||||
```bash
|
||||
# Rebuild in-place (ensure you still have original docs or can regenerate chunks)
|
||||
leann build my-index --force --no-recompute --no-compact
|
||||
```
|
||||
|
||||
Python API usage:
|
||||
```python
|
||||
from leann import LeannSearcher
|
||||
|
||||
searcher = LeannSearcher("/path/to/my-index.leann")
|
||||
results = searcher.search("your query", top_k=10, recompute_embeddings=False)
|
||||
```
|
||||
|
||||
Trade-offs:
|
||||
- Lower latency and fewer network hops at query time
|
||||
- Significantly higher storage (10–100× vs selective recomputation)
|
||||
- Slightly larger memory footprint during build and search
|
||||
|
||||
Quick benchmark results (`benchmarks/benchmark_no_recompute.py` with 5k texts, complexity=32):
|
||||
|
||||
- HNSW
|
||||
|
||||
```text
|
||||
recompute=True: search_time=0.818s, size=1.1MB
|
||||
recompute=False: search_time=0.012s, size=16.6MB
|
||||
```
|
||||
|
||||
- DiskANN
|
||||
|
||||
```text
|
||||
recompute=True: search_time=0.041s, size=5.9MB
|
||||
recompute=False: search_time=0.013s, size=24.6MB
|
||||
```
|
||||
|
||||
Conclusion:
|
||||
- **HNSW**: `no-recompute` is significantly faster (no embedding recomputation) but requires much more storage (stores all embeddings)
|
||||
- **DiskANN**: `no-recompute` uses PQ + partial real distances during traversal (slower but higher accuracy), while `recompute=True` uses pure PQ traversal + final reranking (faster traversal, enables build-time partitioning for smaller storage)
|
||||
|
||||
|
||||
|
||||
## Further Reading
|
||||
|
||||
- [Lessons Learned Developing LEANN](https://yichuan-w.github.io/blog/lessons_learned_in_dev_leann/)
|
||||
- [LEANN Technical Paper](https://arxiv.org/abs/2506.08276)
|
||||
- [DiskANN Original Paper](https://papers.nips.cc/paper/2019/file/09853c7fb1d3f8ee67a61b6bf4a7f8e6-Paper.pdf)
|
||||
- [SSD-based Graph Partitioning](https://github.com/SonglinLife/SSD_BASED_PLAN)
|
||||
10
docs/faq.md
Normal file
10
docs/faq.md
Normal file
@@ -0,0 +1,10 @@
|
||||
# FAQ
|
||||
|
||||
## 1. My building time seems long
|
||||
|
||||
You can speed up the process by using a lightweight embedding model. Add this to your arguments:
|
||||
|
||||
```bash
|
||||
--embedding-model sentence-transformers/all-MiniLM-L6-v2
|
||||
```
|
||||
**Model sizes:** `all-MiniLM-L6-v2` (30M parameters), `facebook/contriever` (~100M parameters), `Qwen3-0.6B` (600M parameters)
|
||||
23
docs/features.md
Normal file
23
docs/features.md
Normal file
@@ -0,0 +1,23 @@
|
||||
# ✨ Detailed Features
|
||||
|
||||
## 🔥 Core Features
|
||||
|
||||
- **🔄 Real-time Embeddings** - Eliminate heavy embedding storage with dynamic computation using optimized ZMQ servers and highly optimized search paradigm (overlapping and batching) with highly optimized embedding engine
|
||||
- **🧠 AST-Aware Code Chunking** - Intelligent code chunking that preserves semantic boundaries (functions, classes, methods) for Python, Java, C#, and TypeScript files
|
||||
- **📈 Scalable Architecture** - Handles millions of documents on consumer hardware; the larger your dataset, the more LEANN can save
|
||||
- **🎯 Graph Pruning** - Advanced techniques to minimize the storage overhead of vector search to a limited footprint
|
||||
- **🏗️ Pluggable Backends** - HNSW/FAISS (default), with optional DiskANN for large-scale deployments
|
||||
|
||||
## 🛠️ Technical Highlights
|
||||
- **🔄 Recompute Mode** - Highest accuracy scenarios while eliminating vector storage overhead
|
||||
- **⚡ Zero-copy Operations** - Minimize IPC overhead by transferring distances instead of embeddings
|
||||
- **🚀 High-throughput Embedding Pipeline** - Optimized batched processing for maximum efficiency
|
||||
- **🎯 Two-level Search** - Novel coarse-to-fine search overlap for accelerated query processing (optional)
|
||||
- **💾 Memory-mapped Indices** - Fast startup with raw text mapping to reduce memory overhead
|
||||
- **🚀 MLX Support** - Ultra-fast recompute/build with quantized embedding models, accelerating building and search ([minimal example](../examples/mlx_demo.py))
|
||||
|
||||
## 🎨 Developer Experience
|
||||
|
||||
- **Simple Python API** - Get started in minutes
|
||||
- **Extensible backend system** - Easy to add new algorithms
|
||||
- **Comprehensive examples** - From basic usage to production deployment
|
||||
300
docs/metadata_filtering.md
Normal file
300
docs/metadata_filtering.md
Normal file
@@ -0,0 +1,300 @@
|
||||
# LEANN Metadata Filtering Usage Guide
|
||||
|
||||
## Overview
|
||||
|
||||
Leann possesses metadata filtering capabilities that allow you to filter search results based on arbitrary metadata fields set during chunking. This feature enables use cases like spoiler-free book search, document filtering by date/type, code search by file type, and potentially much more.
|
||||
|
||||
## Basic Usage
|
||||
|
||||
### Adding Metadata to Your Documents
|
||||
|
||||
When building your index, add metadata to each text chunk:
|
||||
|
||||
```python
|
||||
from leann.api import LeannBuilder
|
||||
|
||||
builder = LeannBuilder("hnsw")
|
||||
|
||||
# Add text with metadata
|
||||
builder.add_text(
|
||||
text="Chapter 1: Alice falls down the rabbit hole",
|
||||
metadata={
|
||||
"chapter": 1,
|
||||
"character": "Alice",
|
||||
"themes": ["adventure", "curiosity"],
|
||||
"word_count": 150
|
||||
}
|
||||
)
|
||||
|
||||
builder.build_index("alice_in_wonderland_index")
|
||||
```
|
||||
|
||||
### Searching with Metadata Filters
|
||||
|
||||
Use the `metadata_filters` parameter in search calls:
|
||||
|
||||
```python
|
||||
from leann.api import LeannSearcher
|
||||
|
||||
searcher = LeannSearcher("alice_in_wonderland_index")
|
||||
|
||||
# Search with filters
|
||||
results = searcher.search(
|
||||
query="What happens to Alice?",
|
||||
top_k=10,
|
||||
metadata_filters={
|
||||
"chapter": {"<=": 5}, # Only chapters 1-5
|
||||
"spoiler_level": {"!=": "high"} # No high spoilers
|
||||
}
|
||||
)
|
||||
```
|
||||
|
||||
## Filter Syntax
|
||||
|
||||
### Basic Structure
|
||||
|
||||
```python
|
||||
metadata_filters = {
|
||||
"field_name": {"operator": value},
|
||||
"another_field": {"operator": value}
|
||||
}
|
||||
```
|
||||
|
||||
### Supported Operators
|
||||
|
||||
#### Comparison Operators
|
||||
- `"=="`: Equal to
|
||||
- `"!="`: Not equal to
|
||||
- `"<"`: Less than
|
||||
- `"<="`: Less than or equal
|
||||
- `">"`: Greater than
|
||||
- `">="`: Greater than or equal
|
||||
|
||||
```python
|
||||
# Examples
|
||||
{"chapter": {"==": 1}} # Exactly chapter 1
|
||||
{"page": {">": 100}} # Pages after 100
|
||||
{"rating": {">=": 4.0}} # Rating 4.0 or higher
|
||||
{"word_count": {"<": 500}} # Short passages
|
||||
```
|
||||
|
||||
#### Membership Operators
|
||||
- `"in"`: Value is in list
|
||||
- `"not_in"`: Value is not in list
|
||||
|
||||
```python
|
||||
# Examples
|
||||
{"character": {"in": ["Alice", "Bob"]}} # Alice OR Bob
|
||||
{"genre": {"not_in": ["horror", "thriller"]}} # Exclude genres
|
||||
{"tags": {"in": ["fiction", "adventure"]}} # Any of these tags
|
||||
```
|
||||
|
||||
#### String Operators
|
||||
- `"contains"`: String contains substring
|
||||
- `"starts_with"`: String starts with prefix
|
||||
- `"ends_with"`: String ends with suffix
|
||||
|
||||
```python
|
||||
# Examples
|
||||
{"title": {"contains": "alice"}} # Title contains "alice"
|
||||
{"filename": {"ends_with": ".py"}} # Python files
|
||||
{"author": {"starts_with": "Dr."}} # Authors with "Dr." prefix
|
||||
```
|
||||
|
||||
#### Boolean Operators
|
||||
- `"is_true"`: Field is truthy
|
||||
- `"is_false"`: Field is falsy
|
||||
|
||||
```python
|
||||
# Examples
|
||||
{"is_published": {"is_true": True}} # Published content
|
||||
{"is_draft": {"is_false": False}} # Not drafts
|
||||
```
|
||||
|
||||
### Multiple Operators on Same Field
|
||||
|
||||
You can apply multiple operators to the same field (AND logic):
|
||||
|
||||
```python
|
||||
metadata_filters = {
|
||||
"word_count": {
|
||||
">=": 100, # At least 100 words
|
||||
"<=": 500 # At most 500 words
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Compound Filters
|
||||
|
||||
Multiple fields are combined with AND logic:
|
||||
|
||||
```python
|
||||
metadata_filters = {
|
||||
"chapter": {"<=": 10}, # Up to chapter 10
|
||||
"character": {"==": "Alice"}, # About Alice
|
||||
"spoiler_level": {"!=": "high"} # No major spoilers
|
||||
}
|
||||
```
|
||||
|
||||
## Use Case Examples
|
||||
|
||||
### 1. Spoiler-Free Book Search
|
||||
|
||||
```python
|
||||
# Reader has only read up to chapter 5
|
||||
def search_spoiler_free(query, max_chapter):
|
||||
return searcher.search(
|
||||
query=query,
|
||||
metadata_filters={
|
||||
"chapter": {"<=": max_chapter},
|
||||
"spoiler_level": {"in": ["none", "low"]}
|
||||
}
|
||||
)
|
||||
|
||||
results = search_spoiler_free("What happens to Alice?", max_chapter=5)
|
||||
```
|
||||
|
||||
### 2. Document Management by Date
|
||||
|
||||
```python
|
||||
# Find recent documents
|
||||
recent_docs = searcher.search(
|
||||
query="project updates",
|
||||
metadata_filters={
|
||||
"date": {">=": "2024-01-01"},
|
||||
"document_type": {"==": "report"}
|
||||
}
|
||||
)
|
||||
```
|
||||
|
||||
### 3. Code Search by File Type
|
||||
|
||||
```python
|
||||
# Search only Python files
|
||||
python_code = searcher.search(
|
||||
query="authentication function",
|
||||
metadata_filters={
|
||||
"file_extension": {"==": ".py"},
|
||||
"lines_of_code": {"<": 100}
|
||||
}
|
||||
)
|
||||
```
|
||||
|
||||
### 4. Content Filtering by Audience
|
||||
|
||||
```python
|
||||
# Age-appropriate content
|
||||
family_content = searcher.search(
|
||||
query="adventure stories",
|
||||
metadata_filters={
|
||||
"age_rating": {"in": ["G", "PG"]},
|
||||
"content_warnings": {"not_in": ["violence", "adult_themes"]}
|
||||
}
|
||||
)
|
||||
```
|
||||
|
||||
### 5. Multi-Book Series Management
|
||||
|
||||
```python
|
||||
# Search across first 3 books only
|
||||
early_series = searcher.search(
|
||||
query="character development",
|
||||
metadata_filters={
|
||||
"series": {"==": "Harry Potter"},
|
||||
"book_number": {"<=": 3}
|
||||
}
|
||||
)
|
||||
```
|
||||
|
||||
## Running the Example
|
||||
|
||||
You can see metadata filtering in action with our spoiler-free book RAG example:
|
||||
|
||||
```bash
|
||||
# Don't forget to set up the environment
|
||||
uv venv
|
||||
source .venv/bin/activate
|
||||
|
||||
# Set your OpenAI API key (required for embeddings, but you can update the example locally and use ollama instead)
|
||||
export OPENAI_API_KEY="your-api-key-here"
|
||||
|
||||
# Run the spoiler-free book RAG example
|
||||
uv run examples/spoiler_free_book_rag.py
|
||||
```
|
||||
|
||||
This example demonstrates:
|
||||
- Building an index with metadata (chapter numbers, characters, themes, locations)
|
||||
- Searching with filters to avoid spoilers (e.g., only show results up to chapter 5)
|
||||
- Different scenarios for readers at various points in the book
|
||||
|
||||
The example uses Alice's Adventures in Wonderland as sample data and shows how you can search for information without revealing plot points from later chapters.
|
||||
|
||||
## Advanced Patterns
|
||||
|
||||
### Custom Chunking with metadata
|
||||
|
||||
```python
|
||||
def chunk_book_with_metadata(book_text, book_info):
|
||||
chunks = []
|
||||
|
||||
for chapter_num, chapter_text in parse_chapters(book_text):
|
||||
# Extract entities, themes, etc.
|
||||
characters = extract_characters(chapter_text)
|
||||
themes = classify_themes(chapter_text)
|
||||
spoiler_level = assess_spoiler_level(chapter_text, chapter_num)
|
||||
|
||||
# Create chunks with rich metadata
|
||||
for paragraph in split_paragraphs(chapter_text):
|
||||
chunks.append({
|
||||
"text": paragraph,
|
||||
"metadata": {
|
||||
"book_title": book_info["title"],
|
||||
"chapter": chapter_num,
|
||||
"characters": characters,
|
||||
"themes": themes,
|
||||
"spoiler_level": spoiler_level,
|
||||
"word_count": len(paragraph.split()),
|
||||
"reading_level": calculate_reading_level(paragraph)
|
||||
}
|
||||
})
|
||||
|
||||
return chunks
|
||||
```
|
||||
|
||||
## Performance Considerations
|
||||
|
||||
### Efficient Filtering Strategies
|
||||
|
||||
1. **Post-search filtering**: Applies filters after vector search, which should be efficient for typical result sets (10-100 results).
|
||||
|
||||
2. **Metadata design**: Keep metadata fields simple and avoid deeply nested structures.
|
||||
|
||||
### Best Practices
|
||||
|
||||
1. **Consistent metadata schema**: Use consistent field names and value types across your documents.
|
||||
|
||||
2. **Reasonable metadata size**: Keep metadata reasonably sized to avoid storage overhead.
|
||||
|
||||
3. **Type consistency**: Use consistent data types for the same fields (e.g., always integers for chapter numbers).
|
||||
|
||||
4. **Index multiple granularities**: Consider chunking at different levels (paragraph, section, chapter) with appropriate metadata.
|
||||
|
||||
### Adding Metadata to Existing Indices
|
||||
|
||||
To add metadata filtering to existing indices, you'll need to rebuild them with metadata:
|
||||
|
||||
```python
|
||||
# Read existing passages and add metadata
|
||||
def add_metadata_to_existing_chunks(chunks):
|
||||
for chunk in chunks:
|
||||
# Extract or assign metadata based on content
|
||||
chunk["metadata"] = extract_metadata(chunk["text"])
|
||||
return chunks
|
||||
|
||||
# Rebuild index with metadata
|
||||
enhanced_chunks = add_metadata_to_existing_chunks(existing_chunks)
|
||||
builder = LeannBuilder("hnsw")
|
||||
for chunk in enhanced_chunks:
|
||||
builder.add_text(chunk["text"], chunk["metadata"])
|
||||
builder.build_index("enhanced_index")
|
||||
```
|
||||
75
docs/normalized_embeddings.md
Normal file
75
docs/normalized_embeddings.md
Normal file
@@ -0,0 +1,75 @@
|
||||
# Normalized Embeddings Support in LEANN
|
||||
|
||||
LEANN now automatically detects normalized embedding models and sets the appropriate distance metric for optimal performance.
|
||||
|
||||
## What are Normalized Embeddings?
|
||||
|
||||
Normalized embeddings are vectors with L2 norm = 1 (unit vectors). These embeddings are optimized for cosine similarity rather than Maximum Inner Product Search (MIPS).
|
||||
|
||||
## Automatic Detection
|
||||
|
||||
When you create a `LeannBuilder` instance with a normalized embedding model, LEANN will:
|
||||
|
||||
1. **Automatically set `distance_metric="cosine"`** if not specified
|
||||
2. **Show a warning** if you manually specify a different distance metric
|
||||
3. **Provide optimal search performance** with the correct metric
|
||||
|
||||
## Supported Normalized Embedding Models
|
||||
|
||||
### OpenAI
|
||||
All OpenAI text embedding models are normalized:
|
||||
- `text-embedding-ada-002`
|
||||
- `text-embedding-3-small`
|
||||
- `text-embedding-3-large`
|
||||
|
||||
### Voyage AI
|
||||
All Voyage AI embedding models are normalized:
|
||||
- `voyage-2`
|
||||
- `voyage-3`
|
||||
- `voyage-large-2`
|
||||
- `voyage-multilingual-2`
|
||||
- `voyage-code-2`
|
||||
|
||||
### Cohere
|
||||
All Cohere embedding models are normalized:
|
||||
- `embed-english-v3.0`
|
||||
- `embed-multilingual-v3.0`
|
||||
- `embed-english-light-v3.0`
|
||||
- `embed-multilingual-light-v3.0`
|
||||
|
||||
## Example Usage
|
||||
|
||||
```python
|
||||
from leann.api import LeannBuilder
|
||||
|
||||
# Automatic detection - will use cosine distance
|
||||
builder = LeannBuilder(
|
||||
backend_name="hnsw",
|
||||
embedding_model="text-embedding-3-small",
|
||||
embedding_mode="openai"
|
||||
)
|
||||
# Warning: Detected normalized embeddings model 'text-embedding-3-small'...
|
||||
# Automatically setting distance_metric='cosine'
|
||||
|
||||
# Manual override (not recommended)
|
||||
builder = LeannBuilder(
|
||||
backend_name="hnsw",
|
||||
embedding_model="text-embedding-3-small",
|
||||
embedding_mode="openai",
|
||||
distance_metric="mips" # Will show warning
|
||||
)
|
||||
# Warning: Using 'mips' distance metric with normalized embeddings...
|
||||
```
|
||||
|
||||
## Non-Normalized Embeddings
|
||||
|
||||
Models like `facebook/contriever` and other sentence-transformers models that are not normalized will continue to use MIPS by default, which is optimal for them.
|
||||
|
||||
## Why This Matters
|
||||
|
||||
Using the wrong distance metric with normalized embeddings can lead to:
|
||||
- **Poor search quality** due to HNSW's early termination with narrow score ranges
|
||||
- **Incorrect ranking** of search results
|
||||
- **Suboptimal performance** compared to using the correct metric
|
||||
|
||||
For more details on why this happens, see our analysis in the [embedding detection code](../packages/leann-core/src/leann/api.py) which automatically handles normalized embeddings and MIPS distance metric issues.
|
||||
21
docs/roadmap.md
Normal file
21
docs/roadmap.md
Normal file
@@ -0,0 +1,21 @@
|
||||
# 📈 Roadmap
|
||||
|
||||
## 🎯 Q2 2025
|
||||
|
||||
- [X] HNSW backend integration
|
||||
- [X] DiskANN backend with MIPS/L2/Cosine support
|
||||
- [X] Real-time embedding pipeline
|
||||
- [X] Memory-efficient graph pruning
|
||||
|
||||
## 🚀 Q3 2025
|
||||
|
||||
- [ ] Advanced caching strategies
|
||||
- [ ] Add contextual-retrieval https://www.anthropic.com/news/contextual-retrieval
|
||||
- [ ] Add sleep-time-compute and summarize agent! to summarilze the file on computer!
|
||||
- [ ] Add OpenAI recompute API
|
||||
|
||||
## 🌟 Q4 2025
|
||||
|
||||
- [ ] Integration with LangChain/LlamaIndex
|
||||
- [ ] Visual similarity search
|
||||
- [ ] Query rewrtiting, rerank and expansion
|
||||
88
examples/basic_demo.py
Normal file
88
examples/basic_demo.py
Normal file
@@ -0,0 +1,88 @@
|
||||
"""
|
||||
Simple demo showing basic leann usage
|
||||
Run: uv run python examples/basic_demo.py
|
||||
"""
|
||||
|
||||
import argparse
|
||||
|
||||
from leann import LeannBuilder, LeannChat, LeannSearcher
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Simple demo of Leann with selectable embedding models."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--embedding_model",
|
||||
type=str,
|
||||
default="sentence-transformers/all-mpnet-base-v2",
|
||||
help="The embedding model to use, e.g., 'sentence-transformers/all-mpnet-base-v2' or 'text-embedding-ada-002'.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
print(f"=== Leann Simple Demo with {args.embedding_model} ===")
|
||||
print()
|
||||
|
||||
# Sample knowledge base
|
||||
chunks = [
|
||||
"Machine learning is a subset of artificial intelligence that enables computers to learn without being explicitly programmed.",
|
||||
"Deep learning uses neural networks with multiple layers to process data and make decisions.",
|
||||
"Natural language processing helps computers understand and generate human language.",
|
||||
"Computer vision enables machines to interpret and understand visual information from images and videos.",
|
||||
"Reinforcement learning teaches agents to make decisions by receiving rewards or penalties for their actions.",
|
||||
"Data science combines statistics, programming, and domain expertise to extract insights from data.",
|
||||
"Big data refers to extremely large datasets that require special tools and techniques to process.",
|
||||
"Cloud computing provides on-demand access to computing resources over the internet.",
|
||||
]
|
||||
|
||||
print("1. Building index (no embeddings stored)...")
|
||||
builder = LeannBuilder(
|
||||
embedding_model=args.embedding_model,
|
||||
backend_name="hnsw",
|
||||
)
|
||||
for chunk in chunks:
|
||||
builder.add_text(chunk)
|
||||
builder.build_index("demo_knowledge.leann")
|
||||
print()
|
||||
|
||||
print("2. Searching with real-time embeddings...")
|
||||
searcher = LeannSearcher("demo_knowledge.leann")
|
||||
|
||||
queries = [
|
||||
"What is machine learning?",
|
||||
"How does neural network work?",
|
||||
"Tell me about data processing",
|
||||
]
|
||||
|
||||
for query in queries:
|
||||
print(f"Query: {query}")
|
||||
results = searcher.search(query, top_k=2)
|
||||
|
||||
for i, result in enumerate(results, 1):
|
||||
print(f" {i}. Score: {result.score:.3f}")
|
||||
print(f" Text: {result.text[:100]}...")
|
||||
print()
|
||||
|
||||
print("3. Interactive chat demo:")
|
||||
print(" (Note: Requires OpenAI API key for real responses)")
|
||||
|
||||
chat = LeannChat("demo_knowledge.leann")
|
||||
|
||||
# Demo questions
|
||||
demo_questions: list[str] = [
|
||||
"What is the difference between machine learning and deep learning?",
|
||||
"How is data science related to big data?",
|
||||
]
|
||||
|
||||
for question in demo_questions:
|
||||
print(f" Q: {question}")
|
||||
response = chat.ask(question)
|
||||
print(f" A: {response}")
|
||||
print()
|
||||
|
||||
print("Demo completed! Try running:")
|
||||
print(" uv run python apps/document_rag.py")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,122 +0,0 @@
|
||||
import os
|
||||
import email
|
||||
from pathlib import Path
|
||||
from typing import List, Any
|
||||
from llama_index.core import Document
|
||||
from llama_index.core.readers.base import BaseReader
|
||||
|
||||
def find_all_messages_directories(root: str = None) -> List[Path]:
|
||||
"""
|
||||
Recursively find all 'Messages' directories under the given root.
|
||||
Returns a list of Path objects.
|
||||
"""
|
||||
if root is None:
|
||||
# Auto-detect user's mail path
|
||||
home_dir = os.path.expanduser("~")
|
||||
root = os.path.join(home_dir, "Library", "Mail")
|
||||
|
||||
messages_dirs = []
|
||||
for dirpath, dirnames, filenames in os.walk(root):
|
||||
if os.path.basename(dirpath) == "Messages":
|
||||
messages_dirs.append(Path(dirpath))
|
||||
return messages_dirs
|
||||
|
||||
class EmlxReader(BaseReader):
|
||||
"""
|
||||
Apple Mail .emlx file reader with embedded metadata.
|
||||
|
||||
Reads individual .emlx files from Apple Mail's storage format.
|
||||
"""
|
||||
|
||||
def __init__(self, include_html: bool = False) -> None:
|
||||
"""
|
||||
Initialize.
|
||||
|
||||
Args:
|
||||
include_html: Whether to include HTML content in the email body (default: False)
|
||||
"""
|
||||
self.include_html = include_html
|
||||
|
||||
def load_data(self, input_dir: str, **load_kwargs: Any) -> List[Document]:
|
||||
"""
|
||||
Load data from the input directory containing .emlx files.
|
||||
|
||||
Args:
|
||||
input_dir: Directory containing .emlx files
|
||||
**load_kwargs:
|
||||
max_count (int): Maximum amount of messages to read.
|
||||
"""
|
||||
docs: List[Document] = []
|
||||
max_count = load_kwargs.get('max_count', 1000)
|
||||
count = 0
|
||||
|
||||
# Walk through the directory recursively
|
||||
for dirpath, dirnames, filenames in os.walk(input_dir):
|
||||
# Skip hidden directories
|
||||
dirnames[:] = [d for d in dirnames if not d.startswith(".")]
|
||||
|
||||
for filename in filenames:
|
||||
if count >= max_count:
|
||||
break
|
||||
|
||||
if filename.endswith(".emlx"):
|
||||
filepath = os.path.join(dirpath, filename)
|
||||
try:
|
||||
# Read the .emlx file
|
||||
with open(filepath, 'r', encoding='utf-8', errors='ignore') as f:
|
||||
content = f.read()
|
||||
|
||||
# .emlx files have a length prefix followed by the email content
|
||||
# The first line contains the length, followed by the email
|
||||
lines = content.split('\n', 1)
|
||||
if len(lines) >= 2:
|
||||
email_content = lines[1]
|
||||
|
||||
# Parse the email using Python's email module
|
||||
try:
|
||||
msg = email.message_from_string(email_content)
|
||||
|
||||
# Extract email metadata
|
||||
subject = msg.get('Subject', 'No Subject')
|
||||
from_addr = msg.get('From', 'Unknown')
|
||||
to_addr = msg.get('To', 'Unknown')
|
||||
date = msg.get('Date', 'Unknown')
|
||||
|
||||
# Extract email body
|
||||
body = ""
|
||||
if msg.is_multipart():
|
||||
for part in msg.walk():
|
||||
if part.get_content_type() == "text/plain" or part.get_content_type() == "text/html":
|
||||
if part.get_content_type() == "text/html" and not self.include_html:
|
||||
continue
|
||||
body += part.get_payload(decode=True).decode('utf-8', errors='ignore')
|
||||
# break
|
||||
else:
|
||||
body = msg.get_payload(decode=True).decode('utf-8', errors='ignore')
|
||||
|
||||
# Create document content with metadata embedded in text
|
||||
doc_content = f"""
|
||||
[File]: {filename}
|
||||
[From]: {from_addr}
|
||||
[To]: {to_addr}
|
||||
[Subject]: {subject}
|
||||
[Date]: {date}
|
||||
[EMAIL BODY Start]:
|
||||
{body}
|
||||
"""
|
||||
|
||||
# No separate metadata - everything is in the text
|
||||
doc = Document(text=doc_content, metadata={})
|
||||
docs.append(doc)
|
||||
count += 1
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error parsing email from {filepath}: {e}")
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error reading file {filepath}: {e}")
|
||||
continue
|
||||
|
||||
print(f"Loaded {len(docs)} email documents")
|
||||
return docs
|
||||
@@ -1,285 +0,0 @@
|
||||
import os
|
||||
import asyncio
|
||||
import argparse
|
||||
try:
|
||||
import dotenv
|
||||
dotenv.load_dotenv()
|
||||
except ModuleNotFoundError:
|
||||
# python-dotenv is not installed; skip loading environment variables
|
||||
dotenv = None
|
||||
from pathlib import Path
|
||||
from typing import List, Any
|
||||
from leann.api import LeannBuilder, LeannSearcher, LeannChat
|
||||
from llama_index.core.node_parser import SentenceSplitter
|
||||
|
||||
# dotenv.load_dotenv() # handled above if python-dotenv is available
|
||||
|
||||
# Default Chrome profile path
|
||||
DEFAULT_CHROME_PROFILE = os.path.expanduser("~/Library/Application Support/Google/Chrome/Default")
|
||||
|
||||
def create_leann_index_from_multiple_chrome_profiles(profile_dirs: List[Path], index_path: str = "chrome_history_index.leann", max_count: int = -1):
|
||||
"""
|
||||
Create LEANN index from multiple Chrome profile data sources.
|
||||
|
||||
Args:
|
||||
profile_dirs: List of Path objects pointing to Chrome profile directories
|
||||
index_path: Path to save the LEANN index
|
||||
max_count: Maximum number of history entries to process per profile
|
||||
"""
|
||||
print("Creating LEANN index from multiple Chrome profile data sources...")
|
||||
|
||||
# Load documents using ChromeHistoryReader from history_data
|
||||
from history_data.history import ChromeHistoryReader
|
||||
reader = ChromeHistoryReader()
|
||||
|
||||
INDEX_DIR = Path(index_path).parent
|
||||
|
||||
if not INDEX_DIR.exists():
|
||||
print(f"--- Index directory not found, building new index ---")
|
||||
all_documents = []
|
||||
total_processed = 0
|
||||
|
||||
# Process each Chrome profile directory
|
||||
for i, profile_dir in enumerate(profile_dirs):
|
||||
print(f"\nProcessing Chrome profile {i+1}/{len(profile_dirs)}: {profile_dir}")
|
||||
|
||||
try:
|
||||
documents = reader.load_data(
|
||||
chrome_profile_path=str(profile_dir),
|
||||
max_count=max_count
|
||||
)
|
||||
if documents:
|
||||
print(f"Loaded {len(documents)} history documents from {profile_dir}")
|
||||
all_documents.extend(documents)
|
||||
total_processed += len(documents)
|
||||
|
||||
# Check if we've reached the max count
|
||||
if max_count > 0 and total_processed >= max_count:
|
||||
print(f"Reached max count of {max_count} documents")
|
||||
break
|
||||
else:
|
||||
print(f"No documents loaded from {profile_dir}")
|
||||
except Exception as e:
|
||||
print(f"Error processing {profile_dir}: {e}")
|
||||
continue
|
||||
|
||||
if not all_documents:
|
||||
print("No documents loaded from any source. Exiting.")
|
||||
# highlight info that you need to close all chrome browser before running this script and high light the instruction!!
|
||||
print("\033[91mYou need to close or quit all chrome browser before running this script\033[0m")
|
||||
return None
|
||||
|
||||
print(f"\nTotal loaded {len(all_documents)} history documents from {len(profile_dirs)} profiles")
|
||||
|
||||
# Create text splitter with 256 chunk size
|
||||
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=128)
|
||||
|
||||
# Convert Documents to text strings and chunk them
|
||||
all_texts = []
|
||||
for doc in all_documents:
|
||||
# Split the document into chunks
|
||||
nodes = text_splitter.get_nodes_from_documents([doc])
|
||||
for node in nodes:
|
||||
text = node.get_content()
|
||||
# text = '[Title] ' + doc.metadata["title"] + '\n' + text
|
||||
all_texts.append(text)
|
||||
|
||||
print(f"Created {len(all_texts)} text chunks from {len(all_documents)} documents")
|
||||
|
||||
# Create LEANN index directory
|
||||
print(f"--- Index directory not found, building new index ---")
|
||||
INDEX_DIR.mkdir(exist_ok=True)
|
||||
|
||||
print(f"--- Building new LEANN index ---")
|
||||
|
||||
print(f"\n[PHASE 1] Building Leann index...")
|
||||
|
||||
# Use HNSW backend for better macOS compatibility
|
||||
builder = LeannBuilder(
|
||||
backend_name="hnsw",
|
||||
embedding_model="facebook/contriever",
|
||||
graph_degree=32,
|
||||
complexity=64,
|
||||
is_compact=True,
|
||||
is_recompute=True,
|
||||
num_threads=1 # Force single-threaded mode
|
||||
)
|
||||
|
||||
print(f"Adding {len(all_texts)} history chunks to index...")
|
||||
for chunk_text in all_texts:
|
||||
builder.add_text(chunk_text)
|
||||
|
||||
builder.build_index(index_path)
|
||||
print(f"\nLEANN index built at {index_path}!")
|
||||
else:
|
||||
print(f"--- Using existing index at {INDEX_DIR} ---")
|
||||
|
||||
return index_path
|
||||
|
||||
def create_leann_index(profile_path: str = None, index_path: str = "chrome_history_index.leann", max_count: int = 1000):
|
||||
"""
|
||||
Create LEANN index from Chrome history data.
|
||||
|
||||
Args:
|
||||
profile_path: Path to the Chrome profile directory (optional, uses default if None)
|
||||
index_path: Path to save the LEANN index
|
||||
max_count: Maximum number of history entries to process
|
||||
"""
|
||||
print("Creating LEANN index from Chrome history data...")
|
||||
INDEX_DIR = Path(index_path).parent
|
||||
|
||||
if not INDEX_DIR.exists():
|
||||
print(f"--- Index directory not found, building new index ---")
|
||||
INDEX_DIR.mkdir(exist_ok=True)
|
||||
|
||||
print(f"--- Building new LEANN index ---")
|
||||
|
||||
print(f"\n[PHASE 1] Building Leann index...")
|
||||
|
||||
# Load documents using ChromeHistoryReader from history_data
|
||||
from history_data.history import ChromeHistoryReader
|
||||
reader = ChromeHistoryReader()
|
||||
|
||||
documents = reader.load_data(
|
||||
chrome_profile_path=profile_path,
|
||||
max_count=max_count
|
||||
)
|
||||
|
||||
if not documents:
|
||||
print("No documents loaded. Exiting.")
|
||||
return None
|
||||
|
||||
print(f"Loaded {len(documents)} history documents")
|
||||
|
||||
# Create text splitter with 256 chunk size
|
||||
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=25)
|
||||
|
||||
# Convert Documents to text strings and chunk them
|
||||
all_texts = []
|
||||
for doc in documents:
|
||||
# Split the document into chunks
|
||||
nodes = text_splitter.get_nodes_from_documents([doc])
|
||||
for node in nodes:
|
||||
all_texts.append(node.get_content())
|
||||
|
||||
print(f"Created {len(all_texts)} text chunks from {len(documents)} documents")
|
||||
|
||||
# Create LEANN index directory
|
||||
print(f"--- Index directory not found, building new index ---")
|
||||
INDEX_DIR.mkdir(exist_ok=True)
|
||||
|
||||
print(f"--- Building new LEANN index ---")
|
||||
|
||||
print(f"\n[PHASE 1] Building Leann index...")
|
||||
|
||||
# Use HNSW backend for better macOS compatibility
|
||||
builder = LeannBuilder(
|
||||
backend_name="hnsw",
|
||||
embedding_model="facebook/contriever",
|
||||
graph_degree=32,
|
||||
complexity=64,
|
||||
is_compact=True,
|
||||
is_recompute=True,
|
||||
num_threads=1 # Force single-threaded mode
|
||||
)
|
||||
|
||||
print(f"Adding {len(all_texts)} history chunks to index...")
|
||||
for chunk_text in all_texts:
|
||||
builder.add_text(chunk_text)
|
||||
|
||||
builder.build_index(index_path)
|
||||
print(f"\nLEANN index built at {index_path}!")
|
||||
else:
|
||||
print(f"--- Using existing index at {INDEX_DIR} ---")
|
||||
|
||||
return index_path
|
||||
|
||||
async def query_leann_index(index_path: str, query: str):
|
||||
"""
|
||||
Query the LEANN index.
|
||||
|
||||
Args:
|
||||
index_path: Path to the LEANN index
|
||||
query: The query string
|
||||
"""
|
||||
print(f"\n[PHASE 2] Starting Leann chat session...")
|
||||
chat = LeannChat(index_path=index_path)
|
||||
|
||||
print(f"You: {query}")
|
||||
chat_response = chat.ask(
|
||||
query,
|
||||
top_k=10,
|
||||
recompute_beighbor_embeddings=True,
|
||||
complexity=32,
|
||||
beam_width=1,
|
||||
llm_config={
|
||||
"type": "openai",
|
||||
"model": "gpt-4o",
|
||||
"api_key": os.getenv("OPENAI_API_KEY"),
|
||||
},
|
||||
llm_kwargs={
|
||||
"temperature": 0.0,
|
||||
"max_tokens": 1000
|
||||
}
|
||||
)
|
||||
print(f"Leann: {chat_response}")
|
||||
|
||||
async def main():
|
||||
# Parse command line arguments
|
||||
parser = argparse.ArgumentParser(description='LEANN Chrome History Reader - Create and query browser history index')
|
||||
parser.add_argument('--chrome-profile', type=str, default=DEFAULT_CHROME_PROFILE,
|
||||
help=f'Path to Chrome profile directory (default: {DEFAULT_CHROME_PROFILE}), usually you dont need to change this')
|
||||
parser.add_argument('--index-dir', type=str, default="./all_google_new",
|
||||
help='Directory to store the LEANN index (default: ./chrome_history_index_leann_test)')
|
||||
parser.add_argument('--max-entries', type=int, default=1000,
|
||||
help='Maximum number of history entries to process (default: 1000)')
|
||||
parser.add_argument('--query', type=str, default=None,
|
||||
help='Single query to run (default: runs example queries)')
|
||||
parser.add_argument('--auto-find-profiles', action='store_true', default=True,
|
||||
help='Automatically find all Chrome profiles (default: True)')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
INDEX_DIR = Path(args.index_dir)
|
||||
INDEX_PATH = str(INDEX_DIR / "chrome_history.leann")
|
||||
|
||||
print(f"Using Chrome profile: {args.chrome_profile}")
|
||||
print(f"Index directory: {INDEX_DIR}")
|
||||
print(f"Max entries: {args.max_entries}")
|
||||
|
||||
# Find Chrome profile directories
|
||||
from history_data.history import ChromeHistoryReader
|
||||
|
||||
if args.auto_find_profiles:
|
||||
profile_dirs = ChromeHistoryReader.find_chrome_profiles()
|
||||
if not profile_dirs:
|
||||
print("No Chrome profiles found automatically. Exiting.")
|
||||
return
|
||||
else:
|
||||
# Use single specified profile
|
||||
profile_path = Path(args.chrome_profile)
|
||||
if not profile_path.exists():
|
||||
print(f"Chrome profile not found: {profile_path}")
|
||||
return
|
||||
profile_dirs = [profile_path]
|
||||
|
||||
# Create or load the LEANN index from all sources
|
||||
index_path = create_leann_index_from_multiple_chrome_profiles(profile_dirs, INDEX_PATH, args.max_entries)
|
||||
|
||||
if index_path:
|
||||
if args.query:
|
||||
# Run single query
|
||||
await query_leann_index(index_path, args.query)
|
||||
else:
|
||||
# Example queries
|
||||
queries = [
|
||||
"What websites did I visit about machine learning?",
|
||||
"Find my search history about programming"
|
||||
]
|
||||
|
||||
for query in queries:
|
||||
print("\n" + "="*60)
|
||||
await query_leann_index(index_path, query)
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -1,288 +0,0 @@
|
||||
import os
|
||||
import sys
|
||||
import asyncio
|
||||
import dotenv
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
from typing import List, Any
|
||||
|
||||
# Add the project root to Python path so we can import from examples
|
||||
project_root = Path(__file__).parent.parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
from leann.api import LeannBuilder, LeannSearcher, LeannChat
|
||||
from llama_index.core.node_parser import SentenceSplitter
|
||||
|
||||
dotenv.load_dotenv()
|
||||
|
||||
# Auto-detect user's mail path
|
||||
def get_mail_path():
|
||||
"""Get the mail path for the current user"""
|
||||
home_dir = os.path.expanduser("~")
|
||||
return os.path.join(home_dir, "Library", "Mail")
|
||||
|
||||
# Default mail path for macOS
|
||||
# DEFAULT_MAIL_PATH = "/Users/yichuan/Library/Mail/V10/0FCA0879-FD8C-4B7E-83BF-FDDA930791C5/[Gmail].mbox/All Mail.mbox/78BA5BE1-8819-4F9A-9613-EB63772F1DD0/Data"
|
||||
|
||||
def create_leann_index_from_multiple_sources(messages_dirs: List[Path], index_path: str = "mail_index.leann", max_count: int = -1, include_html: bool = False, embedding_model: str = "facebook/contriever"):
|
||||
"""
|
||||
Create LEANN index from multiple mail data sources.
|
||||
|
||||
Args:
|
||||
messages_dirs: List of Path objects pointing to Messages directories
|
||||
index_path: Path to save the LEANN index
|
||||
max_count: Maximum number of emails to process per directory
|
||||
include_html: Whether to include HTML content in email processing
|
||||
"""
|
||||
print("Creating LEANN index from multiple mail data sources...")
|
||||
|
||||
# Load documents using EmlxReader from LEANN_email_reader
|
||||
from examples.email_data.LEANN_email_reader import EmlxReader
|
||||
reader = EmlxReader(include_html=include_html)
|
||||
# from email_data.email import EmlxMboxReader
|
||||
# from pathlib import Path
|
||||
# reader = EmlxMboxReader()
|
||||
INDEX_DIR = Path(index_path).parent
|
||||
|
||||
if not INDEX_DIR.exists():
|
||||
print(f"--- Index directory not found, building new index ---")
|
||||
all_documents = []
|
||||
total_processed = 0
|
||||
|
||||
# Process each Messages directory
|
||||
for i, messages_dir in enumerate(messages_dirs):
|
||||
print(f"\nProcessing Messages directory {i+1}/{len(messages_dirs)}: {messages_dir}")
|
||||
|
||||
try:
|
||||
documents = reader.load_data(messages_dir)
|
||||
if documents:
|
||||
print(f"Loaded {len(documents)} email documents from {messages_dir}")
|
||||
all_documents.extend(documents)
|
||||
total_processed += len(documents)
|
||||
|
||||
# Check if we've reached the max count
|
||||
if max_count > 0 and total_processed >= max_count:
|
||||
print(f"Reached max count of {max_count} documents")
|
||||
break
|
||||
else:
|
||||
print(f"No documents loaded from {messages_dir}")
|
||||
except Exception as e:
|
||||
print(f"Error processing {messages_dir}: {e}")
|
||||
continue
|
||||
|
||||
if not all_documents:
|
||||
print("No documents loaded from any source. Exiting.")
|
||||
return None
|
||||
|
||||
print(f"\nTotal loaded {len(all_documents)} email documents from {len(messages_dirs)} directories and starting to split them into chunks")
|
||||
|
||||
# Create text splitter with 256 chunk size
|
||||
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=128)
|
||||
|
||||
# Convert Documents to text strings and chunk them
|
||||
all_texts = []
|
||||
for doc in all_documents:
|
||||
# Split the document into chunks
|
||||
nodes = text_splitter.get_nodes_from_documents([doc])
|
||||
for node in nodes:
|
||||
text = node.get_content()
|
||||
# text = '[subject] ' + doc.metadata["subject"] + '\n' + text
|
||||
all_texts.append(text)
|
||||
|
||||
print(f"Finished splitting {len(all_documents)} documents into {len(all_texts)} text chunks")
|
||||
|
||||
# Create LEANN index directory
|
||||
|
||||
print(f"--- Index directory not found, building new index ---")
|
||||
INDEX_DIR.mkdir(exist_ok=True)
|
||||
|
||||
print(f"--- Building new LEANN index ---")
|
||||
|
||||
print(f"\n[PHASE 1] Building Leann index...")
|
||||
|
||||
# Use HNSW backend for better macOS compatibility
|
||||
builder = LeannBuilder(
|
||||
backend_name="hnsw",
|
||||
embedding_model=embedding_model,
|
||||
graph_degree=32,
|
||||
complexity=64,
|
||||
is_compact=True,
|
||||
is_recompute=True,
|
||||
num_threads=1 # Force single-threaded mode
|
||||
)
|
||||
|
||||
print(f"Adding {len(all_texts)} email chunks to index...")
|
||||
for chunk_text in all_texts:
|
||||
builder.add_text(chunk_text)
|
||||
|
||||
builder.build_index(index_path)
|
||||
print(f"\nLEANN index built at {index_path}!")
|
||||
else:
|
||||
print(f"--- Using existing index at {INDEX_DIR} ---")
|
||||
|
||||
return index_path
|
||||
|
||||
def create_leann_index(mail_path: str, index_path: str = "mail_index.leann", max_count: int = 1000, include_html: bool = False, embedding_model: str = "facebook/contriever"):
|
||||
"""
|
||||
Create LEANN index from mail data.
|
||||
|
||||
Args:
|
||||
mail_path: Path to the mail directory
|
||||
index_path: Path to save the LEANN index
|
||||
max_count: Maximum number of emails to process
|
||||
include_html: Whether to include HTML content in email processing
|
||||
"""
|
||||
print("Creating LEANN index from mail data...")
|
||||
INDEX_DIR = Path(index_path).parent
|
||||
|
||||
if not INDEX_DIR.exists():
|
||||
print(f"--- Index directory not found, building new index ---")
|
||||
INDEX_DIR.mkdir(exist_ok=True)
|
||||
|
||||
print(f"--- Building new LEANN index ---")
|
||||
|
||||
print(f"\n[PHASE 1] Building Leann index...")
|
||||
|
||||
# Load documents using EmlxReader from LEANN_email_reader
|
||||
from examples.email_data.LEANN_email_reader import EmlxReader
|
||||
reader = EmlxReader(include_html=include_html)
|
||||
# from email_data.email import EmlxMboxReader
|
||||
# from pathlib import Path
|
||||
# reader = EmlxMboxReader()
|
||||
documents = reader.load_data(Path(mail_path))
|
||||
|
||||
if not documents:
|
||||
print("No documents loaded. Exiting.")
|
||||
return None
|
||||
|
||||
print(f"Loaded {len(documents)} email documents")
|
||||
|
||||
# Create text splitter with 256 chunk size
|
||||
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=25)
|
||||
|
||||
# Convert Documents to text strings and chunk them
|
||||
all_texts = []
|
||||
for doc in documents:
|
||||
# Split the document into chunks
|
||||
nodes = text_splitter.get_nodes_from_documents([doc])
|
||||
for node in nodes:
|
||||
all_texts.append(node.get_content())
|
||||
|
||||
print(f"Created {len(all_texts)} text chunks from {len(documents)} documents")
|
||||
|
||||
# Create LEANN index directory
|
||||
|
||||
print(f"--- Index directory not found, building new index ---")
|
||||
INDEX_DIR.mkdir(exist_ok=True)
|
||||
|
||||
print(f"--- Building new LEANN index ---")
|
||||
|
||||
print(f"\n[PHASE 1] Building Leann index...")
|
||||
|
||||
# Use HNSW backend for better macOS compatibility
|
||||
builder = LeannBuilder(
|
||||
backend_name="hnsw",
|
||||
embedding_model=embedding_model,
|
||||
graph_degree=32,
|
||||
complexity=64,
|
||||
is_compact=True,
|
||||
is_recompute=True,
|
||||
num_threads=1 # Force single-threaded mode
|
||||
)
|
||||
|
||||
print(f"Adding {len(all_texts)} email chunks to index...")
|
||||
for chunk_text in all_texts:
|
||||
builder.add_text(chunk_text)
|
||||
|
||||
builder.build_index(index_path)
|
||||
print(f"\nLEANN index built at {index_path}!")
|
||||
else:
|
||||
print(f"--- Using existing index at {INDEX_DIR} ---")
|
||||
|
||||
return index_path
|
||||
|
||||
async def query_leann_index(index_path: str, query: str):
|
||||
"""
|
||||
Query the LEANN index.
|
||||
|
||||
Args:
|
||||
index_path: Path to the LEANN index
|
||||
query: The query string
|
||||
"""
|
||||
print(f"\n[PHASE 2] Starting Leann chat session...")
|
||||
chat = LeannChat(index_path=index_path,
|
||||
llm_config={"type": "openai", "model": "gpt-4o"})
|
||||
|
||||
print(f"You: {query}")
|
||||
import time
|
||||
start_time = time.time()
|
||||
chat_response = chat.ask(
|
||||
query,
|
||||
top_k=10,
|
||||
recompute_beighbor_embeddings=True,
|
||||
complexity=12,
|
||||
beam_width=1,
|
||||
|
||||
)
|
||||
end_time = time.time()
|
||||
print(f"Time taken: {end_time - start_time} seconds")
|
||||
print(f"Leann: {chat_response}")
|
||||
|
||||
async def main():
|
||||
# Parse command line arguments
|
||||
parser = argparse.ArgumentParser(description='LEANN Mail Reader - Create and query email index')
|
||||
# Remove --mail-path argument and auto-detect all Messages directories
|
||||
# Remove DEFAULT_MAIL_PATH
|
||||
parser.add_argument('--index-dir', type=str, default="./mail_index_leann_debug",
|
||||
help='Directory to store the LEANN index (default: ./mail_index_leann_raw_text_all_dicts)')
|
||||
parser.add_argument('--max-emails', type=int, default=1000,
|
||||
help='Maximum number of emails to process (-1 means all)')
|
||||
parser.add_argument('--query', type=str, default="Give me some funny advertisement about apple or other companies",
|
||||
help='Single query to run (default: runs example queries)')
|
||||
parser.add_argument('--include-html', action='store_true', default=False,
|
||||
help='Include HTML content in email processing (default: False)')
|
||||
parser.add_argument('--embedding-model', type=str, default="facebook/contriever",
|
||||
help='Embedding model to use (default: facebook/contriever)')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
print(f"args: {args}")
|
||||
|
||||
# Automatically find all Messages directories under the current user's Mail directory
|
||||
from examples.email_data.LEANN_email_reader import find_all_messages_directories
|
||||
mail_path = get_mail_path()
|
||||
print(f"Searching for email data in: {mail_path}")
|
||||
messages_dirs = find_all_messages_directories(mail_path)
|
||||
|
||||
print('len(messages_dirs): ', len(messages_dirs))
|
||||
|
||||
|
||||
if not messages_dirs:
|
||||
print("No Messages directories found. Exiting.")
|
||||
return
|
||||
|
||||
INDEX_DIR = Path(args.index_dir)
|
||||
INDEX_PATH = str(INDEX_DIR / "mail_documents.leann")
|
||||
print(f"Index directory: {INDEX_DIR}")
|
||||
print(f"Found {len(messages_dirs)} Messages directories.")
|
||||
|
||||
# Create or load the LEANN index from all sources
|
||||
index_path = create_leann_index_from_multiple_sources(messages_dirs, INDEX_PATH, args.max_emails, args.include_html, args.embedding_model)
|
||||
|
||||
if index_path:
|
||||
if args.query:
|
||||
# Run single query
|
||||
await query_leann_index(index_path, args.query)
|
||||
else:
|
||||
# Example queries
|
||||
queries = [
|
||||
"Hows Berkeley Graduate Student Instructor",
|
||||
"how's the icloud related advertisement saying",
|
||||
"Whats the number of class recommend to take per semester for incoming EECS students"
|
||||
]
|
||||
for query in queries:
|
||||
print("\n" + "="*60)
|
||||
await query_leann_index(index_path, query)
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -1,115 +0,0 @@
|
||||
import argparse
|
||||
from llama_index.core import SimpleDirectoryReader
|
||||
from llama_index.core.node_parser import SentenceSplitter
|
||||
import asyncio
|
||||
import dotenv
|
||||
from leann.api import LeannBuilder, LeannChat
|
||||
from pathlib import Path
|
||||
|
||||
dotenv.load_dotenv()
|
||||
|
||||
|
||||
async def main(args):
|
||||
INDEX_DIR = Path(args.index_dir)
|
||||
INDEX_PATH = str(INDEX_DIR / "pdf_documents.leann")
|
||||
|
||||
if not INDEX_DIR.exists():
|
||||
node_parser = SentenceSplitter(
|
||||
chunk_size=256, chunk_overlap=128, separator=" ", paragraph_separator="\n\n"
|
||||
)
|
||||
|
||||
print("Loading documents...")
|
||||
documents = SimpleDirectoryReader(
|
||||
args.data_dir,
|
||||
recursive=True,
|
||||
encoding="utf-8",
|
||||
required_exts=[".pdf", ".txt", ".md"],
|
||||
).load_data(show_progress=True)
|
||||
print("Documents loaded.")
|
||||
all_texts = []
|
||||
for doc in documents:
|
||||
nodes = node_parser.get_nodes_from_documents([doc])
|
||||
for node in nodes:
|
||||
all_texts.append(node.get_content())
|
||||
|
||||
print("--- Index directory not found, building new index ---")
|
||||
|
||||
print("\n[PHASE 1] Building Leann index...")
|
||||
|
||||
# Use HNSW backend for better macOS compatibility
|
||||
builder = LeannBuilder(
|
||||
backend_name="hnsw",
|
||||
embedding_model="facebook/contriever",
|
||||
graph_degree=32,
|
||||
complexity=64,
|
||||
is_compact=True,
|
||||
is_recompute=True,
|
||||
num_threads=1, # Force single-threaded mode
|
||||
)
|
||||
|
||||
print(f"Loaded {len(all_texts)} text chunks from documents.")
|
||||
for chunk_text in all_texts:
|
||||
builder.add_text(chunk_text)
|
||||
|
||||
builder.build_index(INDEX_PATH)
|
||||
print(f"\nLeann index built at {INDEX_PATH}!")
|
||||
else:
|
||||
print(f"--- Using existing index at {INDEX_DIR} ---")
|
||||
|
||||
print(f"\n[PHASE 2] Starting Leann chat session...")
|
||||
|
||||
llm_config = {"type": "hf", "model": "Qwen/Qwen3-4B"}
|
||||
llm_config = {"type": "ollama", "model": "qwen3:8b"}
|
||||
llm_config = {"type": "openai", "model": "gpt-4o"}
|
||||
|
||||
chat = LeannChat(index_path=INDEX_PATH, llm_config=llm_config)
|
||||
|
||||
query = "Based on the paper, what are the main techniques LEANN explores to reduce the storage overhead and DLPM explore to achieve Fairness and Efiiciency trade-off?"
|
||||
|
||||
# query = (
|
||||
# "什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面,任务令一般在什么城市颁发"
|
||||
# )
|
||||
|
||||
print(f"You: {query}")
|
||||
chat_response = chat.ask(query, top_k=20, recompute_embeddings=True, complexity=32)
|
||||
print(f"Leann: {chat_response}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Run Leann Chat with various LLM backends."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--llm",
|
||||
type=str,
|
||||
default="hf",
|
||||
choices=["simulated", "ollama", "hf", "openai"],
|
||||
help="The LLM backend to use.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
type=str,
|
||||
default="Qwen/Qwen3-0.6B",
|
||||
help="The model name to use (e.g., 'llama3:8b' for ollama, 'deepseek-ai/deepseek-llm-7b-chat' for hf, 'gpt-4o' for openai).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--host",
|
||||
type=str,
|
||||
default="http://localhost:11434",
|
||||
help="The host for the Ollama API.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--index-dir",
|
||||
type=str,
|
||||
default="./test_doc_files",
|
||||
help="Directory where the Leann index will be stored.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--data-dir",
|
||||
type=str,
|
||||
default="examples/data",
|
||||
help="Directory containing documents to index (PDF, TXT, MD files).",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
asyncio.run(main(args))
|
||||
@@ -1,5 +1,6 @@
|
||||
import os
|
||||
from leann.api import LeannBuilder, LeannSearcher, LeannChat
|
||||
|
||||
from leann.api import LeannBuilder, LeannChat
|
||||
|
||||
# Define the path for our new MLX-based index
|
||||
INDEX_PATH = "./mlx_diskann_index/leann"
|
||||
@@ -38,7 +39,5 @@ chat = LeannChat(index_path=INDEX_PATH)
|
||||
# add query
|
||||
query = "MLX is an array framework for machine learning on Apple silicon."
|
||||
print(f"Query: {query}")
|
||||
response = chat.ask(
|
||||
query, top_k=3, recompute_beighbor_embeddings=True, complexity=3, beam_width=1
|
||||
)
|
||||
response = chat.ask(query, top_k=3, recompute_beighbor_embeddings=True, complexity=3, beam_width=1)
|
||||
print(f"Response: {response}")
|
||||
250
examples/spoiler_free_book_rag.py
Normal file
250
examples/spoiler_free_book_rag.py
Normal file
@@ -0,0 +1,250 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Spoiler-Free Book RAG Example using LEANN Metadata Filtering
|
||||
|
||||
This example demonstrates how to use LEANN's metadata filtering to create
|
||||
a spoiler-free book RAG system where users can search for information
|
||||
up to a specific chapter they've read.
|
||||
|
||||
Usage:
|
||||
python spoiler_free_book_rag.py
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
from typing import Any, Optional
|
||||
|
||||
# Add LEANN to path (adjust path as needed)
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../packages/leann-core/src"))
|
||||
|
||||
from leann.api import LeannBuilder, LeannSearcher
|
||||
|
||||
|
||||
def chunk_book_with_metadata(book_title: str = "Sample Book") -> list[dict[str, Any]]:
|
||||
"""
|
||||
Create sample book chunks with metadata for demonstration.
|
||||
|
||||
In a real implementation, this would parse actual book files (epub, txt, etc.)
|
||||
and extract chapter boundaries, character mentions, etc.
|
||||
|
||||
Args:
|
||||
book_title: Title of the book
|
||||
|
||||
Returns:
|
||||
List of chunk dictionaries with text and metadata
|
||||
"""
|
||||
# Sample book chunks with metadata
|
||||
# In practice, you'd use proper text processing libraries
|
||||
|
||||
sample_chunks = [
|
||||
{
|
||||
"text": "Alice was beginning to get very tired of sitting by her sister on the bank, and of having nothing to do.",
|
||||
"metadata": {
|
||||
"book": book_title,
|
||||
"chapter": 1,
|
||||
"page": 1,
|
||||
"characters": ["Alice", "Sister"],
|
||||
"themes": ["boredom", "curiosity"],
|
||||
"location": "riverbank",
|
||||
},
|
||||
},
|
||||
{
|
||||
"text": "So she was considering in her own mind (as well as she could, for the hot day made her feel very sleepy and stupid), whether the pleasure of making a daisy-chain would be worth the trouble of getting up and picking the daisies, when suddenly a White Rabbit with pink eyes ran close by her.",
|
||||
"metadata": {
|
||||
"book": book_title,
|
||||
"chapter": 1,
|
||||
"page": 2,
|
||||
"characters": ["Alice", "White Rabbit"],
|
||||
"themes": ["decision", "surprise", "magic"],
|
||||
"location": "riverbank",
|
||||
},
|
||||
},
|
||||
{
|
||||
"text": "Alice found herself falling down a very deep well. Either the well was very deep, or she fell very slowly, for she had plenty of time as she fell to look about her and to wonder what was going to happen next.",
|
||||
"metadata": {
|
||||
"book": book_title,
|
||||
"chapter": 2,
|
||||
"page": 15,
|
||||
"characters": ["Alice"],
|
||||
"themes": ["falling", "wonder", "transformation"],
|
||||
"location": "rabbit hole",
|
||||
},
|
||||
},
|
||||
{
|
||||
"text": "Alice meets the Cheshire Cat, who tells her that everyone in Wonderland is mad, including Alice herself.",
|
||||
"metadata": {
|
||||
"book": book_title,
|
||||
"chapter": 6,
|
||||
"page": 85,
|
||||
"characters": ["Alice", "Cheshire Cat"],
|
||||
"themes": ["madness", "philosophy", "identity"],
|
||||
"location": "Duchess's house",
|
||||
},
|
||||
},
|
||||
{
|
||||
"text": "At the Queen's croquet ground, Alice witnesses the absurd trial that reveals the arbitrary nature of Wonderland's justice system.",
|
||||
"metadata": {
|
||||
"book": book_title,
|
||||
"chapter": 8,
|
||||
"page": 120,
|
||||
"characters": ["Alice", "Queen of Hearts", "King of Hearts"],
|
||||
"themes": ["justice", "absurdity", "authority"],
|
||||
"location": "Queen's court",
|
||||
},
|
||||
},
|
||||
{
|
||||
"text": "Alice realizes that Wonderland was all a dream, even the Rabbit, as she wakes up on the riverbank next to her sister.",
|
||||
"metadata": {
|
||||
"book": book_title,
|
||||
"chapter": 12,
|
||||
"page": 180,
|
||||
"characters": ["Alice", "Sister", "Rabbit"],
|
||||
"themes": ["revelation", "reality", "growth"],
|
||||
"location": "riverbank",
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
return sample_chunks
|
||||
|
||||
|
||||
def build_spoiler_free_index(book_chunks: list[dict[str, Any]], index_name: str) -> str:
|
||||
"""
|
||||
Build a LEANN index with book chunks that include spoiler metadata.
|
||||
|
||||
Args:
|
||||
book_chunks: List of book chunks with metadata
|
||||
index_name: Name for the index
|
||||
|
||||
Returns:
|
||||
Path to the built index
|
||||
"""
|
||||
print(f"📚 Building spoiler-free book index: {index_name}")
|
||||
|
||||
# Initialize LEANN builder
|
||||
builder = LeannBuilder(
|
||||
backend_name="hnsw", embedding_model="text-embedding-3-small", embedding_mode="openai"
|
||||
)
|
||||
|
||||
# Add each chunk with its metadata
|
||||
for chunk in book_chunks:
|
||||
builder.add_text(text=chunk["text"], metadata=chunk["metadata"])
|
||||
|
||||
# Build the index
|
||||
index_path = f"{index_name}_book_index"
|
||||
builder.build_index(index_path)
|
||||
|
||||
print(f"✅ Index built successfully: {index_path}")
|
||||
return index_path
|
||||
|
||||
|
||||
def spoiler_free_search(
|
||||
index_path: str,
|
||||
query: str,
|
||||
max_chapter: int,
|
||||
character_filter: Optional[list[str]] = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Perform a spoiler-free search on the book index.
|
||||
|
||||
Args:
|
||||
index_path: Path to the LEANN index
|
||||
query: Search query
|
||||
max_chapter: Maximum chapter number to include
|
||||
character_filter: Optional list of characters to focus on
|
||||
|
||||
Returns:
|
||||
List of search results safe for the reader
|
||||
"""
|
||||
print(f"🔍 Searching: '{query}' (up to chapter {max_chapter})")
|
||||
|
||||
searcher = LeannSearcher(index_path)
|
||||
|
||||
metadata_filters = {"chapter": {"<=": max_chapter}}
|
||||
|
||||
if character_filter:
|
||||
metadata_filters["characters"] = {"contains": character_filter[0]}
|
||||
|
||||
results = searcher.search(query=query, top_k=10, metadata_filters=metadata_filters)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def demo_spoiler_free_rag():
|
||||
"""
|
||||
Demonstrate the spoiler-free book RAG system.
|
||||
"""
|
||||
print("🎭 Spoiler-Free Book RAG Demo")
|
||||
print("=" * 40)
|
||||
|
||||
# Step 1: Prepare book data
|
||||
book_title = "Alice's Adventures in Wonderland"
|
||||
book_chunks = chunk_book_with_metadata(book_title)
|
||||
|
||||
print(f"📖 Loaded {len(book_chunks)} chunks from '{book_title}'")
|
||||
|
||||
# Step 2: Build the index (in practice, this would be done once)
|
||||
try:
|
||||
index_path = build_spoiler_free_index(book_chunks, "alice_wonderland")
|
||||
except Exception as e:
|
||||
print(f"❌ Failed to build index (likely missing dependencies): {e}")
|
||||
print(
|
||||
"💡 This demo shows the filtering logic - actual indexing requires LEANN dependencies"
|
||||
)
|
||||
return
|
||||
|
||||
# Step 3: Demonstrate various spoiler-free searches
|
||||
search_scenarios = [
|
||||
{
|
||||
"description": "Reader who has only read Chapter 1",
|
||||
"query": "What can you tell me about the rabbit?",
|
||||
"max_chapter": 1,
|
||||
},
|
||||
{
|
||||
"description": "Reader who has read up to Chapter 5",
|
||||
"query": "Tell me about Alice's adventures",
|
||||
"max_chapter": 5,
|
||||
},
|
||||
{
|
||||
"description": "Reader who has read most of the book",
|
||||
"query": "What does the Cheshire Cat represent?",
|
||||
"max_chapter": 10,
|
||||
},
|
||||
{
|
||||
"description": "Reader who has read the whole book",
|
||||
"query": "What can you tell me about the rabbit?",
|
||||
"max_chapter": 12,
|
||||
},
|
||||
]
|
||||
|
||||
for scenario in search_scenarios:
|
||||
print(f"\n📚 Scenario: {scenario['description']}")
|
||||
print(f" Query: {scenario['query']}")
|
||||
|
||||
try:
|
||||
results = spoiler_free_search(
|
||||
index_path=index_path,
|
||||
query=scenario["query"],
|
||||
max_chapter=scenario["max_chapter"],
|
||||
)
|
||||
|
||||
print(f" 📄 Found {len(results)} results:")
|
||||
for i, result in enumerate(results[:3], 1): # Show top 3
|
||||
chapter = result.metadata.get("chapter", "?")
|
||||
location = result.metadata.get("location", "?")
|
||||
print(f" {i}. Chapter {chapter} ({location}): {result.text[:80]}...")
|
||||
|
||||
except Exception as e:
|
||||
print(f" ❌ Search failed: {e}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("📚 LEANN Spoiler-Free Book RAG Example")
|
||||
print("=====================================")
|
||||
|
||||
try:
|
||||
demo_spoiler_free_rag()
|
||||
except ImportError as e:
|
||||
print(f"❌ Cannot run demo due to missing dependencies: {e}")
|
||||
except Exception as e:
|
||||
print(f"❌ Error running demo: {e}")
|
||||
@@ -1,319 +0,0 @@
|
||||
import os
|
||||
import asyncio
|
||||
import dotenv
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
from typing import List, Any, Optional
|
||||
from leann.api import LeannBuilder, LeannSearcher, LeannChat
|
||||
from llama_index.core.node_parser import SentenceSplitter
|
||||
import requests
|
||||
import time
|
||||
|
||||
dotenv.load_dotenv()
|
||||
|
||||
# Default WeChat export directory
|
||||
DEFAULT_WECHAT_EXPORT_DIR = "./wechat_export_direct"
|
||||
|
||||
|
||||
def create_leann_index_from_multiple_wechat_exports(
|
||||
export_dirs: List[Path],
|
||||
index_path: str = "wechat_history_index.leann",
|
||||
max_count: int = -1,
|
||||
):
|
||||
"""
|
||||
Create LEANN index from multiple WeChat export data sources.
|
||||
|
||||
Args:
|
||||
export_dirs: List of Path objects pointing to WeChat export directories
|
||||
index_path: Path to save the LEANN index
|
||||
max_count: Maximum number of chat entries to process per export
|
||||
"""
|
||||
print("Creating LEANN index from multiple WeChat export data sources...")
|
||||
|
||||
# Load documents using WeChatHistoryReader from history_data
|
||||
from history_data.wechat_history import WeChatHistoryReader
|
||||
|
||||
reader = WeChatHistoryReader()
|
||||
|
||||
INDEX_DIR = Path(index_path).parent
|
||||
|
||||
if not INDEX_DIR.exists():
|
||||
print(f"--- Index directory not found, building new index ---")
|
||||
all_documents = []
|
||||
total_processed = 0
|
||||
|
||||
# Process each WeChat export directory
|
||||
for i, export_dir in enumerate(export_dirs):
|
||||
print(
|
||||
f"\nProcessing WeChat export {i + 1}/{len(export_dirs)}: {export_dir}"
|
||||
)
|
||||
|
||||
try:
|
||||
documents = reader.load_data(
|
||||
wechat_export_dir=str(export_dir),
|
||||
max_count=max_count,
|
||||
concatenate_messages=True, # Disable concatenation - one message per document
|
||||
)
|
||||
if documents:
|
||||
print(f"Loaded {len(documents)} chat documents from {export_dir}")
|
||||
all_documents.extend(documents)
|
||||
total_processed += len(documents)
|
||||
|
||||
# Check if we've reached the max count
|
||||
if max_count > 0 and total_processed >= max_count:
|
||||
print(f"Reached max count of {max_count} documents")
|
||||
break
|
||||
else:
|
||||
print(f"No documents loaded from {export_dir}")
|
||||
except Exception as e:
|
||||
print(f"Error processing {export_dir}: {e}")
|
||||
continue
|
||||
|
||||
if not all_documents:
|
||||
print("No documents loaded from any source. Exiting.")
|
||||
return None
|
||||
|
||||
print(
|
||||
f"\nTotal loaded {len(all_documents)} chat documents from {len(export_dirs)} exports and starting to split them into chunks"
|
||||
)
|
||||
|
||||
# Create text splitter with 256 chunk size
|
||||
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=128)
|
||||
|
||||
# Convert Documents to text strings and chunk them
|
||||
all_texts = []
|
||||
for doc in all_documents:
|
||||
# Split the document into chunks
|
||||
nodes = text_splitter.get_nodes_from_documents([doc])
|
||||
for node in nodes:
|
||||
text = '[Contact] means the message is from: ' + doc.metadata["contact_name"] + '\n' + node.get_content()
|
||||
all_texts.append(text)
|
||||
|
||||
print(
|
||||
f"Finished splitting {len(all_documents)} documents into {len(all_texts)} text chunks"
|
||||
)
|
||||
|
||||
# Create LEANN index directory
|
||||
print(f"--- Index directory not found, building new index ---")
|
||||
INDEX_DIR.mkdir(exist_ok=True)
|
||||
|
||||
print(f"--- Building new LEANN index ---")
|
||||
|
||||
print(f"\n[PHASE 1] Building Leann index...")
|
||||
|
||||
# Use HNSW backend for better macOS compatibility
|
||||
builder = LeannBuilder(
|
||||
backend_name="hnsw",
|
||||
embedding_model="Qwen/Qwen3-Embedding-0.6B",
|
||||
graph_degree=32,
|
||||
complexity=64,
|
||||
is_compact=True,
|
||||
is_recompute=True,
|
||||
num_threads=1, # Force single-threaded mode
|
||||
)
|
||||
|
||||
print(f"Adding {len(all_texts)} chat chunks to index...")
|
||||
for chunk_text in all_texts:
|
||||
builder.add_text(chunk_text)
|
||||
|
||||
builder.build_index(index_path)
|
||||
print(f"\nLEANN index built at {index_path}!")
|
||||
else:
|
||||
print(f"--- Using existing index at {INDEX_DIR} ---")
|
||||
|
||||
return index_path
|
||||
|
||||
|
||||
def create_leann_index(
|
||||
export_dir: str = None,
|
||||
index_path: str = "wechat_history_index.leann",
|
||||
max_count: int = 1000,
|
||||
):
|
||||
"""
|
||||
Create LEANN index from WeChat chat history data.
|
||||
|
||||
Args:
|
||||
export_dir: Path to the WeChat export directory (optional, uses default if None)
|
||||
index_path: Path to save the LEANN index
|
||||
max_count: Maximum number of chat entries to process
|
||||
"""
|
||||
print("Creating LEANN index from WeChat chat history data...")
|
||||
INDEX_DIR = Path(index_path).parent
|
||||
|
||||
if not INDEX_DIR.exists():
|
||||
print(f"--- Index directory not found, building new index ---")
|
||||
INDEX_DIR.mkdir(exist_ok=True)
|
||||
|
||||
print(f"--- Building new LEANN index ---")
|
||||
|
||||
print(f"\n[PHASE 1] Building Leann index...")
|
||||
|
||||
# Load documents using WeChatHistoryReader from history_data
|
||||
from history_data.wechat_history import WeChatHistoryReader
|
||||
|
||||
reader = WeChatHistoryReader()
|
||||
|
||||
documents = reader.load_data(
|
||||
wechat_export_dir=export_dir,
|
||||
max_count=max_count,
|
||||
concatenate_messages=False, # Disable concatenation - one message per document
|
||||
)
|
||||
|
||||
if not documents:
|
||||
print("No documents loaded. Exiting.")
|
||||
return None
|
||||
|
||||
print(f"Loaded {len(documents)} chat documents")
|
||||
|
||||
# Create text splitter with 256 chunk size
|
||||
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=25)
|
||||
|
||||
# Convert Documents to text strings and chunk them
|
||||
all_texts = []
|
||||
for doc in documents:
|
||||
# Split the document into chunks
|
||||
nodes = text_splitter.get_nodes_from_documents([doc])
|
||||
for node in nodes:
|
||||
all_texts.append(node.get_content())
|
||||
|
||||
print(f"Created {len(all_texts)} text chunks from {len(documents)} documents")
|
||||
|
||||
# Create LEANN index directory
|
||||
print(f"--- Index directory not found, building new index ---")
|
||||
INDEX_DIR.mkdir(exist_ok=True)
|
||||
|
||||
print(f"--- Building new LEANN index ---")
|
||||
|
||||
print(f"\n[PHASE 1] Building Leann index...")
|
||||
|
||||
# Use HNSW backend for better macOS compatibility
|
||||
builder = LeannBuilder(
|
||||
backend_name="hnsw",
|
||||
embedding_model="mlx-community/Qwen3-Embedding-0.6B-4bit-DWQ", # MLX-optimized model
|
||||
graph_degree=32,
|
||||
complexity=64,
|
||||
is_compact=True,
|
||||
is_recompute=True,
|
||||
num_threads=1, # Force single-threaded mode
|
||||
)
|
||||
|
||||
print(f"Adding {len(all_texts)} chat chunks to index...")
|
||||
for chunk_text in all_texts:
|
||||
builder.add_text(chunk_text)
|
||||
|
||||
builder.build_index(index_path)
|
||||
print(f"\nLEANN index built at {index_path}!")
|
||||
else:
|
||||
print(f"--- Using existing index at {INDEX_DIR} ---")
|
||||
|
||||
return index_path
|
||||
|
||||
|
||||
async def query_leann_index(index_path: str, query: str):
|
||||
"""
|
||||
Query the LEANN index.
|
||||
|
||||
Args:
|
||||
index_path: Path to the LEANN index
|
||||
query: The query string
|
||||
"""
|
||||
print(f"\n[PHASE 2] Starting Leann chat session...")
|
||||
chat = LeannChat(index_path=index_path)
|
||||
|
||||
print(f"You: {query}")
|
||||
chat_response = chat.ask(
|
||||
query,
|
||||
top_k=20,
|
||||
recompute_beighbor_embeddings=True,
|
||||
complexity=16,
|
||||
beam_width=1,
|
||||
llm_config={
|
||||
"type": "openai",
|
||||
"model": "gpt-4o",
|
||||
"api_key": os.getenv("OPENAI_API_KEY"),
|
||||
},
|
||||
llm_kwargs={"temperature": 0.0, "max_tokens": 1000},
|
||||
)
|
||||
print(f"Leann: {chat_response}")
|
||||
|
||||
|
||||
async def main():
|
||||
"""Main function with integrated WeChat export functionality."""
|
||||
|
||||
# Parse command line arguments
|
||||
parser = argparse.ArgumentParser(
|
||||
description="LEANN WeChat History Reader - Create and query WeChat chat history index"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--export-dir",
|
||||
type=str,
|
||||
default=DEFAULT_WECHAT_EXPORT_DIR,
|
||||
help=f"Directory to store WeChat exports (default: {DEFAULT_WECHAT_EXPORT_DIR})",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--index-dir",
|
||||
type=str,
|
||||
default="./wechat_history_magic_test_11Debug_new",
|
||||
help="Directory to store the LEANN index (default: ./wechat_history_index_leann_test)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-entries",
|
||||
type=int,
|
||||
default=50,
|
||||
help="Maximum number of chat entries to process (default: 5000)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--query",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Single query to run (default: runs example queries)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--force-export",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Force re-export of WeChat data even if exports exist",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
INDEX_DIR = Path(args.index_dir)
|
||||
INDEX_PATH = str(INDEX_DIR / "wechat_history.leann")
|
||||
|
||||
print(f"Using WeChat export directory: {args.export_dir}")
|
||||
print(f"Index directory: {INDEX_DIR}")
|
||||
print(f"Max entries: {args.max_entries}")
|
||||
|
||||
# Initialize WeChat reader with export capabilities
|
||||
from history_data.wechat_history import WeChatHistoryReader
|
||||
|
||||
reader = WeChatHistoryReader()
|
||||
|
||||
# Find existing exports or create new ones using the centralized method
|
||||
export_dirs = reader.find_or_export_wechat_data(args.export_dir)
|
||||
if not export_dirs:
|
||||
print("Failed to find or export WeChat data. Exiting.")
|
||||
return
|
||||
|
||||
# Create or load the LEANN index from all sources
|
||||
index_path = create_leann_index_from_multiple_wechat_exports(
|
||||
export_dirs, INDEX_PATH, max_count=args.max_entries
|
||||
)
|
||||
|
||||
if index_path:
|
||||
if args.query:
|
||||
# Run single query
|
||||
await query_leann_index(index_path, args.query)
|
||||
else:
|
||||
# Example queries
|
||||
queries = [
|
||||
"我想买魔术师约翰逊的球衣,给我一些对应聊天记录?",
|
||||
]
|
||||
|
||||
for query in queries:
|
||||
print("\n" + "=" * 60)
|
||||
await query_leann_index(index_path, query)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -1 +0,0 @@
|
||||
|
||||
|
||||
@@ -1,8 +0,0 @@
|
||||
# packages/leann-backend-diskann/CMakeLists.txt (simplified version)
|
||||
|
||||
cmake_minimum_required(VERSION 3.20)
|
||||
project(leann_backend_diskann_wrapper)
|
||||
|
||||
# Tell CMake to directly enter the DiskANN submodule and execute its own CMakeLists.txt
|
||||
# DiskANN will handle everything itself, including compiling Python bindings
|
||||
add_subdirectory(src/third_party/DiskANN)
|
||||
@@ -1 +1 @@
|
||||
# This file makes the directory a Python package
|
||||
# This file makes the directory a Python package
|
||||
|
||||
@@ -1 +1,7 @@
|
||||
from . import diskann_backend
|
||||
from . import diskann_backend as diskann_backend
|
||||
from . import graph_partition
|
||||
|
||||
# Export main classes and functions
|
||||
from .graph_partition import GraphPartitioner, partition_graph
|
||||
|
||||
__all__ = ["GraphPartitioner", "diskann_backend", "graph_partition", "partition_graph"]
|
||||
|
||||
@@ -1,20 +1,20 @@
|
||||
import numpy as np
|
||||
import contextlib
|
||||
import logging
|
||||
import os
|
||||
import struct
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, List, Literal, Optional
|
||||
import contextlib
|
||||
from typing import Any, Literal, Optional
|
||||
|
||||
import logging
|
||||
|
||||
from leann.searcher_base import BaseSearcher
|
||||
from leann.registry import register_backend
|
||||
import numpy as np
|
||||
import psutil
|
||||
from leann.interface import (
|
||||
LeannBackendFactoryInterface,
|
||||
LeannBackendBuilderInterface,
|
||||
LeannBackendFactoryInterface,
|
||||
LeannBackendSearcherInterface,
|
||||
)
|
||||
from leann.registry import register_backend
|
||||
from leann.searcher_base import BaseSearcher
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -22,6 +22,11 @@ logger = logging.getLogger(__name__)
|
||||
@contextlib.contextmanager
|
||||
def suppress_cpp_output_if_needed():
|
||||
"""Suppress C++ stdout/stderr based on LEANN_LOG_LEVEL"""
|
||||
# In CI we avoid fiddling with low-level file descriptors to prevent aborts
|
||||
if os.getenv("CI") == "true":
|
||||
yield
|
||||
return
|
||||
|
||||
log_level = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
|
||||
|
||||
# Only suppress if log level is WARNING or higher (ERROR, CRITICAL)
|
||||
@@ -85,6 +90,43 @@ def _write_vectors_to_bin(data: np.ndarray, file_path: Path):
|
||||
f.write(data.tobytes())
|
||||
|
||||
|
||||
def _calculate_smart_memory_config(data: np.ndarray) -> tuple[float, float]:
|
||||
"""
|
||||
Calculate smart memory configuration for DiskANN based on data size and system specs.
|
||||
|
||||
Args:
|
||||
data: The embedding data array
|
||||
|
||||
Returns:
|
||||
tuple: (search_memory_maximum, build_memory_maximum) in GB
|
||||
"""
|
||||
num_vectors, dim = data.shape
|
||||
|
||||
# Calculate embedding storage size
|
||||
embedding_size_bytes = num_vectors * dim * 4 # float32 = 4 bytes
|
||||
embedding_size_gb = embedding_size_bytes / (1024**3)
|
||||
|
||||
# search_memory_maximum: 1/10 of embedding size for optimal PQ compression
|
||||
# This controls Product Quantization size - smaller means more compression
|
||||
search_memory_gb = max(0.1, embedding_size_gb / 10) # At least 100MB
|
||||
|
||||
# build_memory_maximum: Based on available system RAM for sharding control
|
||||
# This controls how much memory DiskANN uses during index construction
|
||||
available_memory_gb = psutil.virtual_memory().available / (1024**3)
|
||||
total_memory_gb = psutil.virtual_memory().total / (1024**3)
|
||||
|
||||
# Use 50% of available memory, but at least 2GB and at most 75% of total
|
||||
build_memory_gb = max(2.0, min(available_memory_gb * 0.5, total_memory_gb * 0.75))
|
||||
|
||||
logger.info(
|
||||
f"Smart memory config - Data: {embedding_size_gb:.2f}GB, "
|
||||
f"Search mem: {search_memory_gb:.2f}GB (PQ control), "
|
||||
f"Build mem: {build_memory_gb:.2f}GB (sharding control)"
|
||||
)
|
||||
|
||||
return search_memory_gb, build_memory_gb
|
||||
|
||||
|
||||
@register_backend("diskann")
|
||||
class DiskannBackend(LeannBackendFactoryInterface):
|
||||
@staticmethod
|
||||
@@ -100,7 +142,72 @@ class DiskannBuilder(LeannBackendBuilderInterface):
|
||||
def __init__(self, **kwargs):
|
||||
self.build_params = kwargs
|
||||
|
||||
def build(self, data: np.ndarray, ids: List[str], index_path: str, **kwargs):
|
||||
def _safe_cleanup_after_partition(self, index_dir: Path, index_prefix: str):
|
||||
"""
|
||||
Safely cleanup files after partition.
|
||||
In partition mode, C++ doesn't read _disk.index content,
|
||||
so we can delete it if all derived files exist.
|
||||
"""
|
||||
disk_index_file = index_dir / f"{index_prefix}_disk.index"
|
||||
beam_search_file = index_dir / f"{index_prefix}_disk_beam_search.index"
|
||||
|
||||
# Required files that C++ partition mode needs
|
||||
# Note: C++ generates these with _disk.index suffix
|
||||
disk_suffix = "_disk.index"
|
||||
required_files = [
|
||||
f"{index_prefix}{disk_suffix}_medoids.bin", # Critical: assert fails if missing
|
||||
# Note: _centroids.bin is not created in single-shot build - C++ handles this automatically
|
||||
f"{index_prefix}_pq_pivots.bin", # PQ table
|
||||
f"{index_prefix}_pq_compressed.bin", # PQ compressed vectors
|
||||
]
|
||||
|
||||
# Check if all required files exist
|
||||
missing_files = []
|
||||
for filename in required_files:
|
||||
file_path = index_dir / filename
|
||||
if not file_path.exists():
|
||||
missing_files.append(filename)
|
||||
|
||||
if missing_files:
|
||||
logger.warning(
|
||||
f"Cannot safely delete _disk.index - missing required files: {missing_files}"
|
||||
)
|
||||
logger.info("Keeping all original files for safety")
|
||||
return
|
||||
|
||||
# Calculate space savings
|
||||
space_saved = 0
|
||||
files_to_delete = []
|
||||
|
||||
if disk_index_file.exists():
|
||||
space_saved += disk_index_file.stat().st_size
|
||||
files_to_delete.append(disk_index_file)
|
||||
|
||||
if beam_search_file.exists():
|
||||
space_saved += beam_search_file.stat().st_size
|
||||
files_to_delete.append(beam_search_file)
|
||||
|
||||
# Safe to delete!
|
||||
for file_to_delete in files_to_delete:
|
||||
try:
|
||||
os.remove(file_to_delete)
|
||||
logger.info(f"✅ Safely deleted: {file_to_delete.name}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to delete {file_to_delete.name}: {e}")
|
||||
|
||||
if space_saved > 0:
|
||||
space_saved_mb = space_saved / (1024 * 1024)
|
||||
logger.info(f"💾 Space saved: {space_saved_mb:.1f} MB")
|
||||
|
||||
# Show what files are kept
|
||||
logger.info("📁 Kept essential files for partition mode:")
|
||||
for filename in required_files:
|
||||
file_path = index_dir / filename
|
||||
if file_path.exists():
|
||||
size_mb = file_path.stat().st_size / (1024 * 1024)
|
||||
logger.info(f" - {filename} ({size_mb:.1f} MB)")
|
||||
|
||||
def build(self, data: np.ndarray, ids: list[str], index_path: str, **kwargs):
|
||||
path = Path(index_path)
|
||||
index_dir = path.parent
|
||||
index_prefix = path.stem
|
||||
@@ -114,6 +221,17 @@ class DiskannBuilder(LeannBackendBuilderInterface):
|
||||
_write_vectors_to_bin(data, index_dir / data_filename)
|
||||
|
||||
build_kwargs = {**self.build_params, **kwargs}
|
||||
|
||||
# Extract is_recompute from nested backend_kwargs if needed
|
||||
is_recompute = build_kwargs.get("is_recompute", False)
|
||||
if not is_recompute and "backend_kwargs" in build_kwargs:
|
||||
is_recompute = build_kwargs["backend_kwargs"].get("is_recompute", False)
|
||||
|
||||
# Flatten all backend_kwargs parameters to top level for compatibility
|
||||
if "backend_kwargs" in build_kwargs:
|
||||
nested_params = build_kwargs.pop("backend_kwargs")
|
||||
build_kwargs.update(nested_params)
|
||||
|
||||
metric_enum = _get_diskann_metrics().get(
|
||||
build_kwargs.get("distance_metric", "mips").lower()
|
||||
)
|
||||
@@ -122,6 +240,16 @@ class DiskannBuilder(LeannBackendBuilderInterface):
|
||||
f"Unsupported distance_metric '{build_kwargs.get('distance_metric', 'unknown')}'."
|
||||
)
|
||||
|
||||
# Calculate smart memory configuration if not explicitly provided
|
||||
if (
|
||||
"search_memory_maximum" not in build_kwargs
|
||||
or "build_memory_maximum" not in build_kwargs
|
||||
):
|
||||
smart_search_mem, smart_build_mem = _calculate_smart_memory_config(data)
|
||||
else:
|
||||
smart_search_mem = build_kwargs.get("search_memory_maximum", 4.0)
|
||||
smart_build_mem = build_kwargs.get("build_memory_maximum", 8.0)
|
||||
|
||||
try:
|
||||
from . import _diskannpy as diskannpy # type: ignore
|
||||
|
||||
@@ -132,12 +260,36 @@ class DiskannBuilder(LeannBackendBuilderInterface):
|
||||
index_prefix,
|
||||
build_kwargs.get("complexity", 64),
|
||||
build_kwargs.get("graph_degree", 32),
|
||||
build_kwargs.get("search_memory_maximum", 4.0),
|
||||
build_kwargs.get("build_memory_maximum", 8.0),
|
||||
build_kwargs.get("search_memory_maximum", smart_search_mem),
|
||||
build_kwargs.get("build_memory_maximum", smart_build_mem),
|
||||
build_kwargs.get("num_threads", 8),
|
||||
build_kwargs.get("pq_disk_bytes", 0),
|
||||
"",
|
||||
)
|
||||
|
||||
# Auto-partition if is_recompute is enabled
|
||||
if build_kwargs.get("is_recompute", False):
|
||||
logger.info("is_recompute=True, starting automatic graph partitioning...")
|
||||
from .graph_partition import partition_graph
|
||||
|
||||
# Partition the index using absolute paths
|
||||
# Convert to absolute paths to avoid issues with working directory changes
|
||||
absolute_index_dir = Path(index_dir).resolve()
|
||||
absolute_index_prefix_path = str(absolute_index_dir / index_prefix)
|
||||
disk_graph_path, partition_bin_path = partition_graph(
|
||||
index_prefix_path=absolute_index_prefix_path,
|
||||
output_dir=str(absolute_index_dir),
|
||||
partition_prefix=index_prefix,
|
||||
)
|
||||
|
||||
# Safe cleanup: In partition mode, C++ doesn't read _disk.index content
|
||||
# but still needs the derived files (_medoids.bin, _centroids.bin, etc.)
|
||||
self._safe_cleanup_after_partition(index_dir, index_prefix)
|
||||
|
||||
logger.info("✅ Graph partitioning completed successfully!")
|
||||
logger.info(f" - Disk graph: {disk_graph_path}")
|
||||
logger.info(f" - Partition file: {partition_bin_path}")
|
||||
|
||||
finally:
|
||||
temp_data_file = index_dir / data_filename
|
||||
if temp_data_file.exists():
|
||||
@@ -164,18 +316,69 @@ class DiskannSearcher(BaseSearcher):
|
||||
|
||||
self.num_threads = kwargs.get("num_threads", 8)
|
||||
|
||||
fake_zmq_port = 6666
|
||||
full_index_prefix = str(self.index_dir / self.index_path.stem)
|
||||
self._index = diskannpy.StaticDiskFloatIndex(
|
||||
metric_enum,
|
||||
full_index_prefix,
|
||||
self.num_threads,
|
||||
kwargs.get("num_nodes_to_cache", 0),
|
||||
1,
|
||||
fake_zmq_port, # Initial port, can be updated at runtime
|
||||
"",
|
||||
"",
|
||||
)
|
||||
# For DiskANN, we need to reinitialize the index when zmq_port changes
|
||||
# Store the initialization parameters for later use
|
||||
# Note: C++ load method expects the BASE path (without _disk.index suffix)
|
||||
# C++ internally constructs: index_prefix + "_disk.index"
|
||||
index_name = self.index_path.stem # "simple_test.leann" -> "simple_test"
|
||||
diskann_index_prefix = str(self.index_dir / index_name) # /path/to/simple_test
|
||||
full_index_prefix = diskann_index_prefix # /path/to/simple_test (base path)
|
||||
|
||||
# Auto-detect partition files and set partition_prefix
|
||||
partition_graph_file = self.index_dir / f"{index_name}_disk_graph.index"
|
||||
partition_bin_file = self.index_dir / f"{index_name}_partition.bin"
|
||||
|
||||
partition_prefix = ""
|
||||
if partition_graph_file.exists() and partition_bin_file.exists():
|
||||
# C++ expects full path prefix, not just filename
|
||||
partition_prefix = str(self.index_dir / index_name) # /path/to/simple_test
|
||||
logger.info(
|
||||
f"✅ Detected partition files, using partition_prefix='{partition_prefix}'"
|
||||
)
|
||||
else:
|
||||
logger.debug("No partition files detected, using standard index files")
|
||||
|
||||
self._init_params = {
|
||||
"metric_enum": metric_enum,
|
||||
"full_index_prefix": full_index_prefix,
|
||||
"num_threads": self.num_threads,
|
||||
"num_nodes_to_cache": kwargs.get("num_nodes_to_cache", 0),
|
||||
"cache_mechanism": 1,
|
||||
"pq_prefix": "",
|
||||
"partition_prefix": partition_prefix,
|
||||
}
|
||||
|
||||
# Log partition configuration for debugging
|
||||
if partition_prefix:
|
||||
logger.info(
|
||||
f"✅ Detected partition files, using partition_prefix='{partition_prefix}'"
|
||||
)
|
||||
self._diskannpy = diskannpy
|
||||
self._current_zmq_port = None
|
||||
self._index = None
|
||||
logger.debug("DiskANN searcher initialized (index will be loaded on first search)")
|
||||
|
||||
def _ensure_index_loaded(self, zmq_port: int):
|
||||
"""Ensure the index is loaded with the correct zmq_port."""
|
||||
if self._index is None or self._current_zmq_port != zmq_port:
|
||||
# Need to (re)load the index with the correct zmq_port
|
||||
with suppress_cpp_output_if_needed():
|
||||
if self._index is not None:
|
||||
logger.debug(f"Reloading DiskANN index with new zmq_port: {zmq_port}")
|
||||
else:
|
||||
logger.debug(f"Loading DiskANN index with zmq_port: {zmq_port}")
|
||||
|
||||
self._index = self._diskannpy.StaticDiskFloatIndex(
|
||||
self._init_params["metric_enum"],
|
||||
self._init_params["full_index_prefix"],
|
||||
self._init_params["num_threads"],
|
||||
self._init_params["num_nodes_to_cache"],
|
||||
self._init_params["cache_mechanism"],
|
||||
zmq_port,
|
||||
self._init_params["pq_prefix"],
|
||||
self._init_params["partition_prefix"],
|
||||
)
|
||||
self._current_zmq_port = zmq_port
|
||||
|
||||
def search(
|
||||
self,
|
||||
@@ -190,7 +393,7 @@ class DiskannSearcher(BaseSearcher):
|
||||
batch_recompute: bool = False,
|
||||
dedup_node_dis: bool = False,
|
||||
**kwargs,
|
||||
) -> Dict[str, Any]:
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Search for nearest neighbors using DiskANN index.
|
||||
|
||||
@@ -213,18 +416,15 @@ class DiskannSearcher(BaseSearcher):
|
||||
Returns:
|
||||
Dict with 'labels' (list of lists) and 'distances' (ndarray)
|
||||
"""
|
||||
# Handle zmq_port compatibility: DiskANN can now update port at runtime
|
||||
# Handle zmq_port compatibility: Ensure index is loaded with correct port
|
||||
if recompute_embeddings:
|
||||
if zmq_port is None:
|
||||
raise ValueError(
|
||||
"zmq_port must be provided if recompute_embeddings is True"
|
||||
)
|
||||
current_port = self._index.get_zmq_port()
|
||||
if zmq_port != current_port:
|
||||
logger.debug(
|
||||
f"Updating DiskANN zmq_port from {current_port} to {zmq_port}"
|
||||
)
|
||||
self._index.set_zmq_port(zmq_port)
|
||||
raise ValueError("zmq_port must be provided if recompute_embeddings is True")
|
||||
self._ensure_index_loaded(zmq_port)
|
||||
else:
|
||||
# If not recomputing, we still need an index, use a default port
|
||||
if self._index is None:
|
||||
self._ensure_index_loaded(6666) # Default port when not recomputing
|
||||
|
||||
# DiskANN doesn't support "proportional" strategy
|
||||
if pruning_strategy == "proportional":
|
||||
@@ -241,7 +441,14 @@ class DiskannSearcher(BaseSearcher):
|
||||
else: # "global"
|
||||
use_global_pruning = True
|
||||
|
||||
# Perform search with suppressed C++ output based on log level
|
||||
# Strategy:
|
||||
# - Traversal always uses PQ distances
|
||||
# - If recompute_embeddings=True, do a single final rerank via deferred fetch
|
||||
# (fetch embeddings for the final candidate set only)
|
||||
# - Do not recompute neighbor distances along the path
|
||||
use_deferred_fetch = True if recompute_embeddings else False
|
||||
recompute_neighors = False # Expected typo. For backward compatibility.
|
||||
|
||||
with suppress_cpp_output_if_needed():
|
||||
labels, distances = self._index.batch_search(
|
||||
query,
|
||||
@@ -250,17 +457,15 @@ class DiskannSearcher(BaseSearcher):
|
||||
complexity,
|
||||
beam_width,
|
||||
self.num_threads,
|
||||
kwargs.get("USE_DEFERRED_FETCH", False),
|
||||
use_deferred_fetch,
|
||||
kwargs.get("skip_search_reorder", False),
|
||||
recompute_embeddings,
|
||||
recompute_neighors,
|
||||
dedup_node_dis,
|
||||
prune_ratio,
|
||||
batch_recompute,
|
||||
use_global_pruning,
|
||||
)
|
||||
|
||||
string_labels = [
|
||||
[str(int_label) for int_label in batch_labels] for batch_labels in labels
|
||||
]
|
||||
string_labels = [[str(int_label) for int_label in batch_labels] for batch_labels in labels]
|
||||
|
||||
return {"labels": string_labels, "distances": distances}
|
||||
|
||||
@@ -3,16 +3,17 @@ DiskANN-specific embedding server
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
import os
|
||||
import zmq
|
||||
import numpy as np
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
import sys
|
||||
import logging
|
||||
|
||||
import numpy as np
|
||||
import zmq
|
||||
|
||||
# Set up logging based on environment variable
|
||||
LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
|
||||
@@ -36,6 +37,7 @@ def create_diskann_embedding_server(
|
||||
zmq_port: int = 5555,
|
||||
model_name: str = "sentence-transformers/all-mpnet-base-v2",
|
||||
embedding_mode: str = "sentence-transformers",
|
||||
distance_metric: str = "l2",
|
||||
):
|
||||
"""
|
||||
Create and start a ZMQ-based embedding server for DiskANN backend.
|
||||
@@ -50,8 +52,8 @@ def create_diskann_embedding_server(
|
||||
sys.path.insert(0, str(leann_core_path))
|
||||
|
||||
try:
|
||||
from leann.embedding_compute import compute_embeddings
|
||||
from leann.api import PassageManager
|
||||
from leann.embedding_compute import compute_embeddings
|
||||
|
||||
logger.info("Successfully imported unified embedding computation module")
|
||||
except ImportError as e:
|
||||
@@ -76,13 +78,12 @@ def create_diskann_embedding_server(
|
||||
raise ValueError("Only metadata files (.meta.json) are supported")
|
||||
|
||||
# Load metadata to get passage sources
|
||||
with open(passages_file, "r") as f:
|
||||
with open(passages_file) as f:
|
||||
meta = json.load(f)
|
||||
|
||||
passages = PassageManager(meta["passage_sources"])
|
||||
logger.info(
|
||||
f"Loaded PassageManager with {len(passages.global_offset_map)} passages from metadata"
|
||||
)
|
||||
logger.info(f"Loading PassageManager with metadata_file_path: {passages_file}")
|
||||
passages = PassageManager(meta["passage_sources"], metadata_file_path=passages_file)
|
||||
logger.info(f"Loaded PassageManager with {len(passages)} passages from metadata")
|
||||
|
||||
# Import protobuf after ensuring the path is correct
|
||||
try:
|
||||
@@ -100,8 +101,9 @@ def create_diskann_embedding_server(
|
||||
socket.bind(f"tcp://*:{zmq_port}")
|
||||
logger.info(f"DiskANN ZMQ REP server listening on port {zmq_port}")
|
||||
|
||||
socket.setsockopt(zmq.RCVTIMEO, 300000)
|
||||
socket.setsockopt(zmq.SNDTIMEO, 300000)
|
||||
socket.setsockopt(zmq.RCVTIMEO, 1000)
|
||||
socket.setsockopt(zmq.SNDTIMEO, 1000)
|
||||
socket.setsockopt(zmq.LINGER, 0)
|
||||
|
||||
while True:
|
||||
try:
|
||||
@@ -150,9 +152,7 @@ def create_diskann_embedding_server(
|
||||
):
|
||||
texts = request
|
||||
is_text_request = True
|
||||
logger.info(
|
||||
f"✅ MSGPACK: Direct text request for {len(texts)} texts"
|
||||
)
|
||||
logger.info(f"✅ MSGPACK: Direct text request for {len(texts)} texts")
|
||||
else:
|
||||
raise ValueError("Not a valid msgpack text request")
|
||||
except Exception as msgpack_error:
|
||||
@@ -167,9 +167,7 @@ def create_diskann_embedding_server(
|
||||
passage_data = passages.get_passage(str(nid))
|
||||
txt = passage_data["text"]
|
||||
if not txt:
|
||||
raise RuntimeError(
|
||||
f"FATAL: Empty text for passage ID {nid}"
|
||||
)
|
||||
raise RuntimeError(f"FATAL: Empty text for passage ID {nid}")
|
||||
texts.append(txt)
|
||||
except KeyError as e:
|
||||
logger.error(f"Passage ID {nid} not found: {e}")
|
||||
@@ -180,9 +178,7 @@ def create_diskann_embedding_server(
|
||||
|
||||
# Debug logging
|
||||
logger.debug(f"Processing {len(texts)} texts")
|
||||
logger.debug(
|
||||
f"Text lengths: {[len(t) for t in texts[:5]]}"
|
||||
) # Show first 5
|
||||
logger.debug(f"Text lengths: {[len(t) for t in texts[:5]]}") # Show first 5
|
||||
|
||||
# Process embeddings using unified computation
|
||||
embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
|
||||
@@ -199,9 +195,7 @@ def create_diskann_embedding_server(
|
||||
else:
|
||||
# For DiskANN C++ compatibility: return protobuf format
|
||||
resp_proto = embedding_pb2.NodeEmbeddingResponse()
|
||||
hidden_contiguous = np.ascontiguousarray(
|
||||
embeddings, dtype=np.float32
|
||||
)
|
||||
hidden_contiguous = np.ascontiguousarray(embeddings, dtype=np.float32)
|
||||
|
||||
# Serialize embeddings data
|
||||
resp_proto.embeddings_data = hidden_contiguous.tobytes()
|
||||
@@ -226,30 +220,217 @@ def create_diskann_embedding_server(
|
||||
traceback.print_exc()
|
||||
raise
|
||||
|
||||
zmq_thread = threading.Thread(target=zmq_server_thread, daemon=True)
|
||||
def zmq_server_thread_with_shutdown(shutdown_event):
|
||||
"""ZMQ server thread that respects shutdown signal.
|
||||
|
||||
This creates its own REP socket, binds to zmq_port, and periodically
|
||||
checks shutdown_event using recv timeouts to exit cleanly.
|
||||
"""
|
||||
logger.info("DiskANN ZMQ server thread started with shutdown support")
|
||||
|
||||
context = zmq.Context()
|
||||
rep_socket = context.socket(zmq.REP)
|
||||
rep_socket.bind(f"tcp://*:{zmq_port}")
|
||||
logger.info(f"DiskANN ZMQ REP server listening on port {zmq_port}")
|
||||
|
||||
# Set receive timeout so we can check shutdown_event periodically
|
||||
rep_socket.setsockopt(zmq.RCVTIMEO, 1000) # 1 second timeout
|
||||
rep_socket.setsockopt(zmq.SNDTIMEO, 1000)
|
||||
rep_socket.setsockopt(zmq.LINGER, 0)
|
||||
|
||||
try:
|
||||
while not shutdown_event.is_set():
|
||||
try:
|
||||
e2e_start = time.time()
|
||||
# REP socket receives single-part messages
|
||||
message = rep_socket.recv()
|
||||
|
||||
# Check for empty messages - REP socket requires response to every request
|
||||
if not message:
|
||||
logger.warning("Received empty message, sending empty response")
|
||||
rep_socket.send(b"")
|
||||
continue
|
||||
|
||||
# Try protobuf first (same logic as original)
|
||||
texts = []
|
||||
is_text_request = False
|
||||
|
||||
try:
|
||||
req_proto = embedding_pb2.NodeEmbeddingRequest()
|
||||
req_proto.ParseFromString(message)
|
||||
node_ids = list(req_proto.node_ids)
|
||||
|
||||
# Look up texts by node IDs
|
||||
for nid in node_ids:
|
||||
try:
|
||||
passage_data = passages.get_passage(str(nid))
|
||||
txt = passage_data["text"]
|
||||
if not txt:
|
||||
raise RuntimeError(f"FATAL: Empty text for passage ID {nid}")
|
||||
texts.append(txt)
|
||||
except KeyError:
|
||||
raise RuntimeError(f"FATAL: Passage with ID {nid} not found")
|
||||
|
||||
logger.info(f"ZMQ received protobuf request for {len(node_ids)} node IDs")
|
||||
except Exception:
|
||||
# Fallback to msgpack for text requests
|
||||
try:
|
||||
import msgpack
|
||||
|
||||
request = msgpack.unpackb(message)
|
||||
if isinstance(request, list) and all(
|
||||
isinstance(item, str) for item in request
|
||||
):
|
||||
texts = request
|
||||
is_text_request = True
|
||||
logger.info(
|
||||
f"ZMQ received msgpack text request for {len(texts)} texts"
|
||||
)
|
||||
else:
|
||||
raise ValueError("Not a valid msgpack text request")
|
||||
except Exception:
|
||||
logger.error("Both protobuf and msgpack parsing failed!")
|
||||
# Send error response
|
||||
resp_proto = embedding_pb2.NodeEmbeddingResponse()
|
||||
rep_socket.send(resp_proto.SerializeToString())
|
||||
continue
|
||||
|
||||
# Process the request
|
||||
embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
|
||||
logger.info(f"Computed embeddings shape: {embeddings.shape}")
|
||||
|
||||
# Validation
|
||||
if np.isnan(embeddings).any() or np.isinf(embeddings).any():
|
||||
logger.error("NaN or Inf detected in embeddings!")
|
||||
# Send error response
|
||||
if is_text_request:
|
||||
import msgpack
|
||||
|
||||
response_data = msgpack.packb([])
|
||||
else:
|
||||
resp_proto = embedding_pb2.NodeEmbeddingResponse()
|
||||
response_data = resp_proto.SerializeToString()
|
||||
rep_socket.send(response_data)
|
||||
continue
|
||||
|
||||
# Prepare response based on request type
|
||||
if is_text_request:
|
||||
# For direct text requests, return msgpack
|
||||
import msgpack
|
||||
|
||||
response_data = msgpack.packb(embeddings.tolist())
|
||||
else:
|
||||
# For protobuf requests, return protobuf
|
||||
resp_proto = embedding_pb2.NodeEmbeddingResponse()
|
||||
hidden_contiguous = np.ascontiguousarray(embeddings, dtype=np.float32)
|
||||
|
||||
resp_proto.embeddings_data = hidden_contiguous.tobytes()
|
||||
resp_proto.dimensions.append(hidden_contiguous.shape[0])
|
||||
resp_proto.dimensions.append(hidden_contiguous.shape[1])
|
||||
|
||||
response_data = resp_proto.SerializeToString()
|
||||
|
||||
# Send response back to the client
|
||||
rep_socket.send(response_data)
|
||||
|
||||
e2e_end = time.time()
|
||||
logger.info(f"⏱️ ZMQ E2E time: {e2e_end - e2e_start:.6f}s")
|
||||
|
||||
except zmq.Again:
|
||||
# Timeout - check shutdown_event and continue
|
||||
continue
|
||||
except Exception as e:
|
||||
if not shutdown_event.is_set():
|
||||
logger.error(f"Error in ZMQ server loop: {e}")
|
||||
try:
|
||||
# Send error response for REP socket
|
||||
resp_proto = embedding_pb2.NodeEmbeddingResponse()
|
||||
rep_socket.send(resp_proto.SerializeToString())
|
||||
except Exception:
|
||||
pass
|
||||
else:
|
||||
logger.info("Shutdown in progress, ignoring ZMQ error")
|
||||
break
|
||||
finally:
|
||||
try:
|
||||
rep_socket.close(0)
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
context.term()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
logger.info("DiskANN ZMQ server thread exiting gracefully")
|
||||
|
||||
# Add shutdown coordination
|
||||
shutdown_event = threading.Event()
|
||||
|
||||
def shutdown_zmq_server():
|
||||
"""Gracefully shutdown ZMQ server."""
|
||||
logger.info("Initiating graceful shutdown...")
|
||||
shutdown_event.set()
|
||||
|
||||
if zmq_thread.is_alive():
|
||||
logger.info("Waiting for ZMQ thread to finish...")
|
||||
zmq_thread.join(timeout=5)
|
||||
if zmq_thread.is_alive():
|
||||
logger.warning("ZMQ thread did not finish in time")
|
||||
|
||||
# Clean up ZMQ resources
|
||||
try:
|
||||
# Note: socket and context are cleaned up by thread exit
|
||||
logger.info("ZMQ resources cleaned up")
|
||||
except Exception as e:
|
||||
logger.warning(f"Error cleaning ZMQ resources: {e}")
|
||||
|
||||
# Clean up other resources
|
||||
try:
|
||||
import gc
|
||||
|
||||
gc.collect()
|
||||
logger.info("Additional resources cleaned up")
|
||||
except Exception as e:
|
||||
logger.warning(f"Error cleaning additional resources: {e}")
|
||||
|
||||
logger.info("Graceful shutdown completed")
|
||||
sys.exit(0)
|
||||
|
||||
# Register signal handlers within this function scope
|
||||
import signal
|
||||
|
||||
def signal_handler(sig, frame):
|
||||
logger.info(f"Received signal {sig}, shutting down gracefully...")
|
||||
shutdown_zmq_server()
|
||||
|
||||
signal.signal(signal.SIGTERM, signal_handler)
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
|
||||
# Start ZMQ thread (NOT daemon!)
|
||||
zmq_thread = threading.Thread(
|
||||
target=lambda: zmq_server_thread_with_shutdown(shutdown_event),
|
||||
daemon=False, # Not daemon - we want to wait for it
|
||||
)
|
||||
zmq_thread.start()
|
||||
logger.info(f"Started DiskANN ZMQ server thread on port {zmq_port}")
|
||||
|
||||
# Keep the main thread alive
|
||||
try:
|
||||
while True:
|
||||
time.sleep(1)
|
||||
while not shutdown_event.is_set():
|
||||
time.sleep(0.1) # Check shutdown more frequently
|
||||
except KeyboardInterrupt:
|
||||
logger.info("DiskANN Server shutting down...")
|
||||
shutdown_zmq_server()
|
||||
return
|
||||
|
||||
# If we reach here, shutdown was triggered by signal
|
||||
logger.info("Main loop exited, process should be shutting down")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import signal
|
||||
import sys
|
||||
|
||||
def signal_handler(sig, frame):
|
||||
logger.info(f"Received signal {sig}, shutting down gracefully...")
|
||||
sys.exit(0)
|
||||
|
||||
# Register signal handlers for graceful shutdown
|
||||
signal.signal(signal.SIGTERM, signal_handler)
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
# Signal handlers are now registered within create_diskann_embedding_server
|
||||
|
||||
parser = argparse.ArgumentParser(description="DiskANN Embedding service")
|
||||
parser.add_argument("--zmq-port", type=int, default=5555, help="ZMQ port to run on")
|
||||
@@ -268,9 +449,16 @@ if __name__ == "__main__":
|
||||
"--embedding-mode",
|
||||
type=str,
|
||||
default="sentence-transformers",
|
||||
choices=["sentence-transformers", "openai", "mlx"],
|
||||
choices=["sentence-transformers", "openai", "mlx", "ollama"],
|
||||
help="Embedding backend mode",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--distance-metric",
|
||||
type=str,
|
||||
default="l2",
|
||||
choices=["l2", "mips", "cosine"],
|
||||
help="Distance metric for similarity computation",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -280,4 +468,5 @@ if __name__ == "__main__":
|
||||
zmq_port=args.zmq_port,
|
||||
model_name=args.model_name,
|
||||
embedding_mode=args.embedding_mode,
|
||||
distance_metric=args.distance_metric,
|
||||
)
|
||||
|
||||
@@ -1,27 +1,28 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
||||
# source: embedding.proto
|
||||
# ruff: noqa
|
||||
"""Generated protocol buffer code."""
|
||||
from google.protobuf.internal import builder as _builder
|
||||
|
||||
from google.protobuf import descriptor as _descriptor
|
||||
from google.protobuf import descriptor_pool as _descriptor_pool
|
||||
from google.protobuf import symbol_database as _symbol_database
|
||||
from google.protobuf.internal import builder as _builder
|
||||
|
||||
# @@protoc_insertion_point(imports)
|
||||
|
||||
_sym_db = _symbol_database.Default()
|
||||
|
||||
|
||||
|
||||
|
||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0f\x65mbedding.proto\x12\x0eprotoembedding\"(\n\x14NodeEmbeddingRequest\x12\x10\n\x08node_ids\x18\x01 \x03(\r\"Y\n\x15NodeEmbeddingResponse\x12\x17\n\x0f\x65mbeddings_data\x18\x01 \x01(\x0c\x12\x12\n\ndimensions\x18\x02 \x03(\x05\x12\x13\n\x0bmissing_ids\x18\x03 \x03(\rb\x06proto3')
|
||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
|
||||
b'\n\x0f\x65mbedding.proto\x12\x0eprotoembedding"(\n\x14NodeEmbeddingRequest\x12\x10\n\x08node_ids\x18\x01 \x03(\r"Y\n\x15NodeEmbeddingResponse\x12\x17\n\x0f\x65mbeddings_data\x18\x01 \x01(\x0c\x12\x12\n\ndimensions\x18\x02 \x03(\x05\x12\x13\n\x0bmissing_ids\x18\x03 \x03(\rb\x06proto3'
|
||||
)
|
||||
|
||||
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
|
||||
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'embedding_pb2', globals())
|
||||
if _descriptor._USE_C_DESCRIPTORS == False:
|
||||
|
||||
DESCRIPTOR._options = None
|
||||
_NODEEMBEDDINGREQUEST._serialized_start=35
|
||||
_NODEEMBEDDINGREQUEST._serialized_end=75
|
||||
_NODEEMBEDDINGRESPONSE._serialized_start=77
|
||||
_NODEEMBEDDINGRESPONSE._serialized_end=166
|
||||
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "embedding_pb2", globals())
|
||||
if not _descriptor._USE_C_DESCRIPTORS:
|
||||
DESCRIPTOR._options = None
|
||||
_NODEEMBEDDINGREQUEST._serialized_start = 35
|
||||
_NODEEMBEDDINGREQUEST._serialized_end = 75
|
||||
_NODEEMBEDDINGRESPONSE._serialized_start = 77
|
||||
_NODEEMBEDDINGRESPONSE._serialized_end = 166
|
||||
# @@protoc_insertion_point(module_scope)
|
||||
|
||||
@@ -0,0 +1,299 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Graph Partition Module for LEANN DiskANN Backend
|
||||
|
||||
This module provides Python bindings for the graph partition functionality
|
||||
of DiskANN, allowing users to partition disk-based indices for better
|
||||
performance.
|
||||
"""
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class GraphPartitioner:
|
||||
"""
|
||||
A Python interface for DiskANN's graph partition functionality.
|
||||
|
||||
This class provides methods to partition disk-based indices for improved
|
||||
search performance and memory efficiency.
|
||||
"""
|
||||
|
||||
def __init__(self, build_type: str = "release"):
|
||||
"""
|
||||
Initialize the GraphPartitioner.
|
||||
|
||||
Args:
|
||||
build_type: Build type for the executables ("debug" or "release")
|
||||
"""
|
||||
self.build_type = build_type
|
||||
self._ensure_executables()
|
||||
|
||||
def _get_executable_path(self, name: str) -> str:
|
||||
"""Get the path to a graph partition executable."""
|
||||
# Get the directory where this Python module is located
|
||||
module_dir = Path(__file__).parent
|
||||
# Navigate to the graph_partition directory
|
||||
graph_partition_dir = module_dir.parent / "third_party" / "DiskANN" / "graph_partition"
|
||||
executable_path = graph_partition_dir / "build" / self.build_type / "graph_partition" / name
|
||||
|
||||
if not executable_path.exists():
|
||||
raise FileNotFoundError(f"Executable {name} not found at {executable_path}")
|
||||
|
||||
return str(executable_path)
|
||||
|
||||
def _ensure_executables(self):
|
||||
"""Ensure that the required executables are built."""
|
||||
try:
|
||||
self._get_executable_path("partitioner")
|
||||
self._get_executable_path("index_relayout")
|
||||
except FileNotFoundError:
|
||||
# Try to build the executables automatically
|
||||
print("Executables not found, attempting to build them...")
|
||||
self._build_executables()
|
||||
|
||||
def _build_executables(self):
|
||||
"""Build the required executables."""
|
||||
graph_partition_dir = (
|
||||
Path(__file__).parent.parent / "third_party" / "DiskANN" / "graph_partition"
|
||||
)
|
||||
original_dir = os.getcwd()
|
||||
|
||||
try:
|
||||
os.chdir(graph_partition_dir)
|
||||
|
||||
# Clean any existing build
|
||||
if (graph_partition_dir / "build").exists():
|
||||
shutil.rmtree(graph_partition_dir / "build")
|
||||
|
||||
# Run the build script
|
||||
cmd = ["./build.sh", self.build_type, "split_graph", "/tmp/dummy"]
|
||||
subprocess.run(cmd, capture_output=True, text=True, cwd=graph_partition_dir)
|
||||
|
||||
# Check if executables were created
|
||||
partitioner_path = self._get_executable_path("partitioner")
|
||||
relayout_path = self._get_executable_path("index_relayout")
|
||||
|
||||
print(f"✅ Built partitioner: {partitioner_path}")
|
||||
print(f"✅ Built index_relayout: {relayout_path}")
|
||||
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to build executables: {e}")
|
||||
finally:
|
||||
os.chdir(original_dir)
|
||||
|
||||
def partition_graph(
|
||||
self,
|
||||
index_prefix_path: str,
|
||||
output_dir: Optional[str] = None,
|
||||
partition_prefix: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> tuple[str, str]:
|
||||
"""
|
||||
Partition a disk-based index for improved performance.
|
||||
|
||||
Args:
|
||||
index_prefix_path: Path to the index prefix (e.g., "/path/to/index")
|
||||
output_dir: Output directory for results (defaults to parent of index_prefix_path)
|
||||
partition_prefix: Prefix for output files (defaults to basename of index_prefix_path)
|
||||
**kwargs: Additional parameters for graph partitioning:
|
||||
- gp_times: Number of LDG partition iterations (default: 10)
|
||||
- lock_nums: Number of lock nodes (default: 10)
|
||||
- cut: Cut adjacency list degree (default: 100)
|
||||
- scale_factor: Scale factor (default: 1)
|
||||
- data_type: Data type (default: "float")
|
||||
- thread_nums: Number of threads (default: 10)
|
||||
|
||||
Returns:
|
||||
Tuple of (disk_graph_index_path, partition_bin_path)
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the partitioning process fails
|
||||
"""
|
||||
# Set default parameters
|
||||
params = {
|
||||
"gp_times": 10,
|
||||
"lock_nums": 10,
|
||||
"cut": 100,
|
||||
"scale_factor": 1,
|
||||
"data_type": "float",
|
||||
"thread_nums": 10,
|
||||
**kwargs,
|
||||
}
|
||||
|
||||
# Determine output directory
|
||||
if output_dir is None:
|
||||
output_dir = str(Path(index_prefix_path).parent)
|
||||
|
||||
# Create output directory if it doesn't exist
|
||||
Path(output_dir).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Determine partition prefix
|
||||
if partition_prefix is None:
|
||||
partition_prefix = Path(index_prefix_path).name
|
||||
|
||||
# Get executable paths
|
||||
partitioner_path = self._get_executable_path("partitioner")
|
||||
relayout_path = self._get_executable_path("index_relayout")
|
||||
|
||||
# Create temporary directory for processing
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
# Change to the graph_partition directory for temporary files
|
||||
graph_partition_dir = (
|
||||
Path(__file__).parent.parent / "third_party" / "DiskANN" / "graph_partition"
|
||||
)
|
||||
original_dir = os.getcwd()
|
||||
|
||||
try:
|
||||
os.chdir(graph_partition_dir)
|
||||
|
||||
# Create temporary data directory
|
||||
temp_data_dir = Path(temp_dir) / "data"
|
||||
temp_data_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Set up paths for temporary files
|
||||
graph_path = temp_data_dir / "starling" / "_M_R_L_B" / "GRAPH"
|
||||
graph_gp_path = (
|
||||
graph_path
|
||||
/ f"GP_TIMES_{params['gp_times']}_LOCK_{params['lock_nums']}_GP_USE_FREQ0_CUT{params['cut']}_SCALE{params['scale_factor']}"
|
||||
)
|
||||
graph_gp_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Find input index file
|
||||
old_index_file = f"{index_prefix_path}_disk_beam_search.index"
|
||||
if not os.path.exists(old_index_file):
|
||||
old_index_file = f"{index_prefix_path}_disk.index"
|
||||
|
||||
if not os.path.exists(old_index_file):
|
||||
raise RuntimeError(f"Index file not found: {old_index_file}")
|
||||
|
||||
# Run partitioner
|
||||
gp_file_path = graph_gp_path / "_part.bin"
|
||||
partitioner_cmd = [
|
||||
partitioner_path,
|
||||
"--index_file",
|
||||
old_index_file,
|
||||
"--data_type",
|
||||
params["data_type"],
|
||||
"--gp_file",
|
||||
str(gp_file_path),
|
||||
"-T",
|
||||
str(params["thread_nums"]),
|
||||
"--ldg_times",
|
||||
str(params["gp_times"]),
|
||||
"--scale",
|
||||
str(params["scale_factor"]),
|
||||
"--mode",
|
||||
"1",
|
||||
]
|
||||
|
||||
print(f"Running partitioner: {' '.join(partitioner_cmd)}")
|
||||
result = subprocess.run(
|
||||
partitioner_cmd, capture_output=True, text=True, cwd=graph_partition_dir
|
||||
)
|
||||
|
||||
if result.returncode != 0:
|
||||
raise RuntimeError(
|
||||
f"Partitioner failed with return code {result.returncode}.\n"
|
||||
f"stdout: {result.stdout}\n"
|
||||
f"stderr: {result.stderr}"
|
||||
)
|
||||
|
||||
# Run relayout
|
||||
part_tmp_index = graph_gp_path / "_part_tmp.index"
|
||||
relayout_cmd = [
|
||||
relayout_path,
|
||||
old_index_file,
|
||||
str(gp_file_path),
|
||||
params["data_type"],
|
||||
"1",
|
||||
]
|
||||
|
||||
print(f"Running relayout: {' '.join(relayout_cmd)}")
|
||||
result = subprocess.run(
|
||||
relayout_cmd, capture_output=True, text=True, cwd=graph_partition_dir
|
||||
)
|
||||
|
||||
if result.returncode != 0:
|
||||
raise RuntimeError(
|
||||
f"Relayout failed with return code {result.returncode}.\n"
|
||||
f"stdout: {result.stdout}\n"
|
||||
f"stderr: {result.stderr}"
|
||||
)
|
||||
|
||||
# Copy results to output directory
|
||||
disk_graph_path = Path(output_dir) / f"{partition_prefix}_disk_graph.index"
|
||||
partition_bin_path = Path(output_dir) / f"{partition_prefix}_partition.bin"
|
||||
|
||||
shutil.copy2(part_tmp_index, disk_graph_path)
|
||||
shutil.copy2(gp_file_path, partition_bin_path)
|
||||
|
||||
print(f"Results copied to: {output_dir}")
|
||||
return str(disk_graph_path), str(partition_bin_path)
|
||||
|
||||
finally:
|
||||
os.chdir(original_dir)
|
||||
|
||||
def get_partition_info(self, partition_bin_path: str) -> dict:
|
||||
"""
|
||||
Get information about a partition file.
|
||||
|
||||
Args:
|
||||
partition_bin_path: Path to the partition binary file
|
||||
|
||||
Returns:
|
||||
Dictionary containing partition information
|
||||
"""
|
||||
if not os.path.exists(partition_bin_path):
|
||||
raise FileNotFoundError(f"Partition file not found: {partition_bin_path}")
|
||||
|
||||
# For now, return basic file information
|
||||
# In the future, this could parse the binary file for detailed info
|
||||
stat = os.stat(partition_bin_path)
|
||||
return {
|
||||
"file_size": stat.st_size,
|
||||
"file_path": partition_bin_path,
|
||||
"modified_time": stat.st_mtime,
|
||||
}
|
||||
|
||||
|
||||
def partition_graph(
|
||||
index_prefix_path: str,
|
||||
output_dir: Optional[str] = None,
|
||||
partition_prefix: Optional[str] = None,
|
||||
build_type: str = "release",
|
||||
**kwargs,
|
||||
) -> tuple[str, str]:
|
||||
"""
|
||||
Convenience function to partition a graph index.
|
||||
|
||||
Args:
|
||||
index_prefix_path: Path to the index prefix
|
||||
output_dir: Output directory (defaults to parent of index_prefix_path)
|
||||
partition_prefix: Prefix for output files (defaults to basename of index_prefix_path)
|
||||
build_type: Build type for executables ("debug" or "release")
|
||||
**kwargs: Additional parameters for graph partitioning
|
||||
|
||||
Returns:
|
||||
Tuple of (disk_graph_index_path, partition_bin_path)
|
||||
"""
|
||||
partitioner = GraphPartitioner(build_type=build_type)
|
||||
return partitioner.partition_graph(index_prefix_path, output_dir, partition_prefix, **kwargs)
|
||||
|
||||
|
||||
# Example usage:
|
||||
if __name__ == "__main__":
|
||||
# Example: partition an index
|
||||
try:
|
||||
disk_graph_path, partition_bin_path = partition_graph(
|
||||
"/path/to/your/index_prefix", gp_times=10, lock_nums=10, cut=100
|
||||
)
|
||||
print("Partitioning completed successfully!")
|
||||
print(f"Disk graph index: {disk_graph_path}")
|
||||
print(f"Partition binary: {partition_bin_path}")
|
||||
except Exception as e:
|
||||
print(f"Partitioning failed: {e}")
|
||||
@@ -4,8 +4,8 @@ build-backend = "scikit_build_core.build"
|
||||
|
||||
[project]
|
||||
name = "leann-backend-diskann"
|
||||
version = "0.1.0"
|
||||
dependencies = ["leann-core==0.1.0", "numpy"]
|
||||
version = "0.3.2"
|
||||
dependencies = ["leann-core==0.3.2", "numpy", "protobuf>=3.19.0"]
|
||||
|
||||
[tool.scikit-build]
|
||||
# Key: simplified CMake path
|
||||
@@ -16,4 +16,6 @@ wheel.packages = ["leann_backend_diskann"]
|
||||
editable.mode = "redirect"
|
||||
cmake.build-type = "Release"
|
||||
build.verbose = true
|
||||
build.tool-args = ["-j8"]
|
||||
build.tool-args = ["-j8"]
|
||||
# Let CMake find packages via Homebrew prefix
|
||||
cmake.define = {CMAKE_PREFIX_PATH = {env = "CMAKE_PREFIX_PATH"}, OpenMP_ROOT = {env = "OpenMP_ROOT"}}
|
||||
|
||||
Submodule packages/leann-backend-diskann/third_party/DiskANN updated: 25339b0341...c593831474
@@ -2,12 +2,12 @@ syntax = "proto3";
|
||||
|
||||
package protoembedding;
|
||||
|
||||
message NodeEmbeddingRequest {
|
||||
repeated uint32 node_ids = 1;
|
||||
message NodeEmbeddingRequest {
|
||||
repeated uint32 node_ids = 1;
|
||||
}
|
||||
|
||||
message NodeEmbeddingResponse {
|
||||
bytes embeddings_data = 1; // All embedded binary datas
|
||||
repeated int32 dimensions = 2; // Shape [batch_size, embedding_dim]
|
||||
repeated uint32 missing_ids = 3; // Missing node ids
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,11 +5,28 @@ set(CMAKE_CXX_COMPILER_WORKS 1)
|
||||
|
||||
# Set OpenMP path for macOS
|
||||
if(APPLE)
|
||||
set(OpenMP_C_FLAGS "-Xpreprocessor -fopenmp -I/opt/homebrew/opt/libomp/include")
|
||||
set(OpenMP_CXX_FLAGS "-Xpreprocessor -fopenmp -I/opt/homebrew/opt/libomp/include")
|
||||
# Detect Homebrew installation path (Apple Silicon vs Intel)
|
||||
if(EXISTS "/opt/homebrew/opt/libomp")
|
||||
set(HOMEBREW_PREFIX "/opt/homebrew")
|
||||
elseif(EXISTS "/usr/local/opt/libomp")
|
||||
set(HOMEBREW_PREFIX "/usr/local")
|
||||
else()
|
||||
message(FATAL_ERROR "Could not find libomp installation. Please install with: brew install libomp")
|
||||
endif()
|
||||
|
||||
set(OpenMP_C_FLAGS "-Xpreprocessor -fopenmp -I${HOMEBREW_PREFIX}/opt/libomp/include")
|
||||
set(OpenMP_CXX_FLAGS "-Xpreprocessor -fopenmp -I${HOMEBREW_PREFIX}/opt/libomp/include")
|
||||
set(OpenMP_C_LIB_NAMES "omp")
|
||||
set(OpenMP_CXX_LIB_NAMES "omp")
|
||||
set(OpenMP_omp_LIBRARY "/opt/homebrew/opt/libomp/lib/libomp.dylib")
|
||||
set(OpenMP_omp_LIBRARY "${HOMEBREW_PREFIX}/opt/libomp/lib/libomp.dylib")
|
||||
|
||||
# Force use of system libc++ to avoid version mismatch
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -stdlib=libc++")
|
||||
set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -stdlib=libc++")
|
||||
set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} -stdlib=libc++")
|
||||
|
||||
# Set minimum macOS version for better compatibility
|
||||
set(CMAKE_OSX_DEPLOYMENT_TARGET "11.0" CACHE STRING "Minimum macOS version")
|
||||
endif()
|
||||
|
||||
# Use system ZeroMQ instead of building from source
|
||||
@@ -52,4 +69,4 @@ set(FAISS_BUILD_AVX512 OFF CACHE BOOL "" FORCE)
|
||||
# IMPORTANT: Disable building AVX versions to speed up compilation
|
||||
set(FAISS_BUILD_AVX_VERSIONS OFF CACHE BOOL "" FORCE)
|
||||
|
||||
add_subdirectory(third_party/faiss)
|
||||
add_subdirectory(third_party/faiss)
|
||||
|
||||
@@ -1 +1 @@
|
||||
from . import hnsw_backend
|
||||
from . import hnsw_backend as hnsw_backend
|
||||
|
||||
@@ -1,87 +1,122 @@
|
||||
import argparse
|
||||
import gc # Import garbage collector interface
|
||||
import logging
|
||||
import os
|
||||
import struct
|
||||
import sys
|
||||
import numpy as np
|
||||
import os
|
||||
import argparse
|
||||
import gc # Import garbage collector interface
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
|
||||
# Set up logging to avoid print buffer issues
|
||||
logger = logging.getLogger(__name__)
|
||||
LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
|
||||
log_level = getattr(logging, LOG_LEVEL, logging.WARNING)
|
||||
logger.setLevel(log_level)
|
||||
|
||||
# --- FourCCs (add more if needed) ---
|
||||
INDEX_HNSW_FLAT_FOURCC = int.from_bytes(b'IHNf', 'little')
|
||||
INDEX_HNSW_FLAT_FOURCC = int.from_bytes(b"IHNf", "little")
|
||||
# Add other HNSW fourccs if you expect different storage types inside HNSW
|
||||
# INDEX_HNSW_PQ_FOURCC = int.from_bytes(b'IHNp', 'little')
|
||||
# INDEX_HNSW_SQ_FOURCC = int.from_bytes(b'IHNs', 'little')
|
||||
# INDEX_HNSW_CAGRA_FOURCC = int.from_bytes(b'IHNc', 'little') # Example
|
||||
|
||||
EXPECTED_HNSW_FOURCCS = {INDEX_HNSW_FLAT_FOURCC} # Modify if needed
|
||||
NULL_INDEX_FOURCC = int.from_bytes(b'null', 'little')
|
||||
EXPECTED_HNSW_FOURCCS = {INDEX_HNSW_FLAT_FOURCC} # Modify if needed
|
||||
NULL_INDEX_FOURCC = int.from_bytes(b"null", "little")
|
||||
|
||||
# --- Helper functions for reading/writing binary data ---
|
||||
|
||||
|
||||
def read_struct(f, fmt):
|
||||
"""Reads data according to the struct format."""
|
||||
size = struct.calcsize(fmt)
|
||||
data = f.read(size)
|
||||
if len(data) != size:
|
||||
raise EOFError(f"File ended unexpectedly reading struct fmt '{fmt}'. Expected {size} bytes, got {len(data)}.")
|
||||
raise EOFError(
|
||||
f"File ended unexpectedly reading struct fmt '{fmt}'. Expected {size} bytes, got {len(data)}."
|
||||
)
|
||||
return struct.unpack(fmt, data)[0]
|
||||
|
||||
|
||||
def read_vector_raw(f, element_fmt_char):
|
||||
"""Reads a vector (size followed by data), returns count and raw bytes."""
|
||||
count = -1 # Initialize count
|
||||
total_bytes = -1 # Initialize total_bytes
|
||||
count = -1 # Initialize count
|
||||
total_bytes = -1 # Initialize total_bytes
|
||||
try:
|
||||
count = read_struct(f, '<Q') # size_t usually 64-bit unsigned
|
||||
count = read_struct(f, "<Q") # size_t usually 64-bit unsigned
|
||||
element_size = struct.calcsize(element_fmt_char)
|
||||
# --- FIX for MemoryError: Check for unreasonably large count ---
|
||||
max_reasonable_count = 10 * (10**9) # ~10 billion elements limit
|
||||
max_reasonable_count = 10 * (10**9) # ~10 billion elements limit
|
||||
if count > max_reasonable_count or count < 0:
|
||||
raise MemoryError(f"Vector count {count} seems unreasonably large, possibly due to file corruption or incorrect format read.")
|
||||
raise MemoryError(
|
||||
f"Vector count {count} seems unreasonably large, possibly due to file corruption or incorrect format read."
|
||||
)
|
||||
|
||||
total_bytes = count * element_size
|
||||
# --- FIX for MemoryError: Check for huge byte size before allocation ---
|
||||
max_reasonable_bytes = 50 * (1024**3) # ~50 GB limit
|
||||
if total_bytes > max_reasonable_bytes or total_bytes < 0: # Check for overflow
|
||||
raise MemoryError(f"Attempting to read {total_bytes} bytes ({count} elements * {element_size} bytes/element), which exceeds the safety limit. File might be corrupted or format mismatch.")
|
||||
max_reasonable_bytes = 50 * (1024**3) # ~50 GB limit
|
||||
if total_bytes > max_reasonable_bytes or total_bytes < 0: # Check for overflow
|
||||
raise MemoryError(
|
||||
f"Attempting to read {total_bytes} bytes ({count} elements * {element_size} bytes/element), which exceeds the safety limit. File might be corrupted or format mismatch."
|
||||
)
|
||||
|
||||
data_bytes = f.read(total_bytes)
|
||||
|
||||
if len(data_bytes) != total_bytes:
|
||||
raise EOFError(f"File ended unexpectedly reading vector data. Expected {total_bytes} bytes, got {len(data_bytes)}.")
|
||||
raise EOFError(
|
||||
f"File ended unexpectedly reading vector data. Expected {total_bytes} bytes, got {len(data_bytes)}."
|
||||
)
|
||||
return count, data_bytes
|
||||
except (MemoryError, OverflowError) as e:
|
||||
# Add context to the error message
|
||||
print(f"\nError during raw vector read (element_fmt='{element_fmt_char}', count={count}, total_bytes={total_bytes}): {e}", file=sys.stderr)
|
||||
raise e # Re-raise the original error type
|
||||
# Add context to the error message
|
||||
print(
|
||||
f"\nError during raw vector read (element_fmt='{element_fmt_char}', count={count}, total_bytes={total_bytes}): {e}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
raise e # Re-raise the original error type
|
||||
|
||||
|
||||
def read_numpy_vector(f, np_dtype, struct_fmt_char):
|
||||
"""Reads a vector into a NumPy array."""
|
||||
count = -1 # Initialize count for robust error handling
|
||||
print(f" Reading vector (dtype={np_dtype}, fmt='{struct_fmt_char}')... ", end='', flush=True)
|
||||
count = -1 # Initialize count for robust error handling
|
||||
print(
|
||||
f" Reading vector (dtype={np_dtype}, fmt='{struct_fmt_char}')... ",
|
||||
end="",
|
||||
flush=True,
|
||||
)
|
||||
try:
|
||||
count, data_bytes = read_vector_raw(f, struct_fmt_char)
|
||||
print(f"Count={count}, Bytes={len(data_bytes)}")
|
||||
if count > 0 and len(data_bytes) > 0:
|
||||
arr = np.frombuffer(data_bytes, dtype=np_dtype)
|
||||
if arr.size != count:
|
||||
raise ValueError(f"Inconsistent array size after reading. Expected {count}, got {arr.size}")
|
||||
raise ValueError(
|
||||
f"Inconsistent array size after reading. Expected {count}, got {arr.size}"
|
||||
)
|
||||
return arr
|
||||
elif count == 0:
|
||||
return np.array([], dtype=np_dtype)
|
||||
return np.array([], dtype=np_dtype)
|
||||
else:
|
||||
raise ValueError("Read zero bytes but count > 0.")
|
||||
raise ValueError("Read zero bytes but count > 0.")
|
||||
except MemoryError as e:
|
||||
# Now count should be defined (or -1 if error was in read_struct)
|
||||
print(f"\nMemoryError creating NumPy array (dtype={np_dtype}, count={count}). {e}", file=sys.stderr)
|
||||
print(
|
||||
f"\nMemoryError creating NumPy array (dtype={np_dtype}, count={count}). {e}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
raise e
|
||||
except Exception as e: # Catch other potential errors like ValueError
|
||||
print(f"\nError reading numpy vector (dtype={np_dtype}, fmt='{struct_fmt_char}', count={count}): {e}", file=sys.stderr)
|
||||
except Exception as e: # Catch other potential errors like ValueError
|
||||
print(
|
||||
f"\nError reading numpy vector (dtype={np_dtype}, fmt='{struct_fmt_char}', count={count}): {e}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
raise e
|
||||
|
||||
|
||||
def write_numpy_vector(f, arr, struct_fmt_char):
|
||||
"""Writes a NumPy array as a vector (size followed by data)."""
|
||||
count = arr.size
|
||||
f.write(struct.pack('<Q', count))
|
||||
f.write(struct.pack("<Q", count))
|
||||
try:
|
||||
expected_dtype = np.dtype(struct_fmt_char)
|
||||
if arr.dtype != expected_dtype:
|
||||
@@ -89,23 +124,30 @@ def write_numpy_vector(f, arr, struct_fmt_char):
|
||||
else:
|
||||
data_to_write = arr.tobytes()
|
||||
f.write(data_to_write)
|
||||
del data_to_write # Hint GC
|
||||
del data_to_write # Hint GC
|
||||
except MemoryError as e:
|
||||
print(f"\nMemoryError converting NumPy array to bytes for writing (size={count}, dtype={arr.dtype}). {e}", file=sys.stderr)
|
||||
raise e
|
||||
print(
|
||||
f"\nMemoryError converting NumPy array to bytes for writing (size={count}, dtype={arr.dtype}). {e}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
raise e
|
||||
|
||||
|
||||
def write_list_vector(f, lst, struct_fmt_char):
|
||||
"""Writes a Python list as a vector iteratively."""
|
||||
count = len(lst)
|
||||
f.write(struct.pack('<Q', count))
|
||||
fmt = '<' + struct_fmt_char
|
||||
f.write(struct.pack("<Q", count))
|
||||
fmt = "<" + struct_fmt_char
|
||||
chunk_size = 1024 * 1024
|
||||
element_size = struct.calcsize(fmt)
|
||||
# Allocate buffer outside the loop if possible, or handle MemoryError during allocation
|
||||
try:
|
||||
buffer = bytearray(chunk_size * element_size)
|
||||
except MemoryError:
|
||||
print(f"MemoryError: Cannot allocate buffer for writing list vector chunk (size {chunk_size * element_size} bytes).", file=sys.stderr)
|
||||
print(
|
||||
f"MemoryError: Cannot allocate buffer for writing list vector chunk (size {chunk_size * element_size} bytes).",
|
||||
file=sys.stderr,
|
||||
)
|
||||
raise
|
||||
buffer_count = 0
|
||||
|
||||
@@ -116,66 +158,80 @@ def write_list_vector(f, lst, struct_fmt_char):
|
||||
buffer_count += 1
|
||||
|
||||
if buffer_count == chunk_size or i == count - 1:
|
||||
f.write(buffer[:buffer_count * element_size])
|
||||
f.write(buffer[: buffer_count * element_size])
|
||||
buffer_count = 0
|
||||
|
||||
except struct.error as e:
|
||||
print(f"\nStruct packing error for item {item} at index {i} with format '{fmt}'. {e}", file=sys.stderr)
|
||||
print(
|
||||
f"\nStruct packing error for item {item} at index {i} with format '{fmt}'. {e}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
raise e
|
||||
|
||||
|
||||
def get_cum_neighbors(cum_nneighbor_per_level_np, level):
|
||||
"""Helper to get cumulative neighbors count, matching C++ logic."""
|
||||
if level < 0: return 0
|
||||
if level < 0:
|
||||
return 0
|
||||
if level < len(cum_nneighbor_per_level_np):
|
||||
return cum_nneighbor_per_level_np[level]
|
||||
else:
|
||||
return cum_nneighbor_per_level_np[-1] if len(cum_nneighbor_per_level_np) > 0 else 0
|
||||
|
||||
def write_compact_format(f_out, original_hnsw_data, assign_probas_np, cum_nneighbor_per_level_np,
|
||||
levels_np, compact_level_ptr, compact_node_offsets_np,
|
||||
compact_neighbors_data, storage_fourcc, storage_data):
|
||||
|
||||
def write_compact_format(
|
||||
f_out,
|
||||
original_hnsw_data,
|
||||
assign_probas_np,
|
||||
cum_nneighbor_per_level_np,
|
||||
levels_np,
|
||||
compact_level_ptr,
|
||||
compact_node_offsets_np,
|
||||
compact_neighbors_data,
|
||||
storage_fourcc,
|
||||
storage_data,
|
||||
):
|
||||
"""Write HNSW data in compact format following C++ read order exactly."""
|
||||
# Write IndexHNSW Header
|
||||
f_out.write(struct.pack('<I', original_hnsw_data['index_fourcc']))
|
||||
f_out.write(struct.pack('<i', original_hnsw_data['d']))
|
||||
f_out.write(struct.pack('<q', original_hnsw_data['ntotal']))
|
||||
f_out.write(struct.pack('<q', original_hnsw_data['dummy1']))
|
||||
f_out.write(struct.pack('<q', original_hnsw_data['dummy2']))
|
||||
f_out.write(struct.pack('<?', original_hnsw_data['is_trained']))
|
||||
f_out.write(struct.pack('<i', original_hnsw_data['metric_type']))
|
||||
if original_hnsw_data['metric_type'] > 1:
|
||||
f_out.write(struct.pack('<f', original_hnsw_data['metric_arg']))
|
||||
f_out.write(struct.pack("<I", original_hnsw_data["index_fourcc"]))
|
||||
f_out.write(struct.pack("<i", original_hnsw_data["d"]))
|
||||
f_out.write(struct.pack("<q", original_hnsw_data["ntotal"]))
|
||||
f_out.write(struct.pack("<q", original_hnsw_data["dummy1"]))
|
||||
f_out.write(struct.pack("<q", original_hnsw_data["dummy2"]))
|
||||
f_out.write(struct.pack("<?", original_hnsw_data["is_trained"]))
|
||||
f_out.write(struct.pack("<i", original_hnsw_data["metric_type"]))
|
||||
if original_hnsw_data["metric_type"] > 1:
|
||||
f_out.write(struct.pack("<f", original_hnsw_data["metric_arg"]))
|
||||
|
||||
# Write HNSW struct parts (standard order)
|
||||
write_numpy_vector(f_out, assign_probas_np, 'd')
|
||||
write_numpy_vector(f_out, cum_nneighbor_per_level_np, 'i')
|
||||
write_numpy_vector(f_out, levels_np, 'i')
|
||||
write_numpy_vector(f_out, assign_probas_np, "d")
|
||||
write_numpy_vector(f_out, cum_nneighbor_per_level_np, "i")
|
||||
write_numpy_vector(f_out, levels_np, "i")
|
||||
|
||||
# Write compact format flag
|
||||
f_out.write(struct.pack('<?', True)) # storage_is_compact = True
|
||||
f_out.write(struct.pack("<?", True)) # storage_is_compact = True
|
||||
|
||||
# Write compact data in CORRECT C++ read order: level_ptr, node_offsets FIRST
|
||||
if isinstance(compact_level_ptr, np.ndarray):
|
||||
write_numpy_vector(f_out, compact_level_ptr, 'Q')
|
||||
write_numpy_vector(f_out, compact_level_ptr, "Q")
|
||||
else:
|
||||
write_list_vector(f_out, compact_level_ptr, 'Q')
|
||||
|
||||
write_numpy_vector(f_out, compact_node_offsets_np, 'Q')
|
||||
write_list_vector(f_out, compact_level_ptr, "Q")
|
||||
|
||||
write_numpy_vector(f_out, compact_node_offsets_np, "Q")
|
||||
|
||||
# Write HNSW scalar parameters
|
||||
f_out.write(struct.pack('<i', original_hnsw_data['entry_point']))
|
||||
f_out.write(struct.pack('<i', original_hnsw_data['max_level']))
|
||||
f_out.write(struct.pack('<i', original_hnsw_data['efConstruction']))
|
||||
f_out.write(struct.pack('<i', original_hnsw_data['efSearch']))
|
||||
f_out.write(struct.pack('<i', original_hnsw_data['dummy_upper_beam']))
|
||||
f_out.write(struct.pack("<i", original_hnsw_data["entry_point"]))
|
||||
f_out.write(struct.pack("<i", original_hnsw_data["max_level"]))
|
||||
f_out.write(struct.pack("<i", original_hnsw_data["efConstruction"]))
|
||||
f_out.write(struct.pack("<i", original_hnsw_data["efSearch"]))
|
||||
f_out.write(struct.pack("<i", original_hnsw_data["dummy_upper_beam"]))
|
||||
|
||||
# Write storage fourcc (this determines how to read what follows)
|
||||
f_out.write(struct.pack('<I', storage_fourcc))
|
||||
|
||||
f_out.write(struct.pack("<I", storage_fourcc))
|
||||
|
||||
# Write compact neighbors data AFTER storage fourcc
|
||||
write_list_vector(f_out, compact_neighbors_data, 'i')
|
||||
|
||||
write_list_vector(f_out, compact_neighbors_data, "i")
|
||||
|
||||
# Write storage data if not NULL (only after neighbors)
|
||||
if storage_fourcc != NULL_INDEX_FOURCC and storage_data:
|
||||
f_out.write(storage_data)
|
||||
@@ -183,185 +239,244 @@ def write_compact_format(f_out, original_hnsw_data, assign_probas_np, cum_nneigh
|
||||
|
||||
# --- Main Conversion Logic ---
|
||||
|
||||
|
||||
def convert_hnsw_graph_to_csr(input_filename, output_filename, prune_embeddings=True):
|
||||
"""
|
||||
Converts an HNSW graph file to the CSR format.
|
||||
Supports both original and already-compact formats (backward compatibility).
|
||||
|
||||
|
||||
Args:
|
||||
input_filename: Input HNSW index file
|
||||
output_filename: Output CSR index file
|
||||
prune_embeddings: Whether to prune embedding storage (write NULL storage marker)
|
||||
"""
|
||||
# Keep prints simple; rely on CI runner to flush output as needed
|
||||
|
||||
print(f"Starting conversion: {input_filename} -> {output_filename}")
|
||||
start_time = time.time()
|
||||
original_hnsw_data = {}
|
||||
neighbors_np = None # Initialize to allow check in finally block
|
||||
neighbors_np = None # Initialize to allow check in finally block
|
||||
try:
|
||||
with open(input_filename, 'rb') as f_in, open(output_filename, 'wb') as f_out:
|
||||
|
||||
with open(input_filename, "rb") as f_in, open(output_filename, "wb") as f_out:
|
||||
# --- Read IndexHNSW FourCC and Header ---
|
||||
print(f"[{time.time() - start_time:.2f}s] Reading Index HNSW header...")
|
||||
# ... (Keep the header reading logic as before) ...
|
||||
hnsw_index_fourcc = read_struct(f_in, '<I')
|
||||
hnsw_index_fourcc = read_struct(f_in, "<I")
|
||||
if hnsw_index_fourcc not in EXPECTED_HNSW_FOURCCS:
|
||||
print(f"Error: Expected HNSW Index FourCC ({list(EXPECTED_HNSW_FOURCCS)}), got {hnsw_index_fourcc:08x}.", file=sys.stderr)
|
||||
return False
|
||||
original_hnsw_data['index_fourcc'] = hnsw_index_fourcc
|
||||
original_hnsw_data['d'] = read_struct(f_in, '<i')
|
||||
original_hnsw_data['ntotal'] = read_struct(f_in, '<q')
|
||||
original_hnsw_data['dummy1'] = read_struct(f_in, '<q')
|
||||
original_hnsw_data['dummy2'] = read_struct(f_in, '<q')
|
||||
original_hnsw_data['is_trained'] = read_struct(f_in, '?')
|
||||
original_hnsw_data['metric_type'] = read_struct(f_in, '<i')
|
||||
original_hnsw_data['metric_arg'] = 0.0
|
||||
if original_hnsw_data['metric_type'] > 1:
|
||||
original_hnsw_data['metric_arg'] = read_struct(f_in, '<f')
|
||||
print(f"[{time.time() - start_time:.2f}s] Header read: d={original_hnsw_data['d']}, ntotal={original_hnsw_data['ntotal']}")
|
||||
|
||||
print(
|
||||
f"Error: Expected HNSW Index FourCC ({list(EXPECTED_HNSW_FOURCCS)}), got {hnsw_index_fourcc:08x}.",
|
||||
file=sys.stderr,
|
||||
)
|
||||
return False
|
||||
original_hnsw_data["index_fourcc"] = hnsw_index_fourcc
|
||||
original_hnsw_data["d"] = read_struct(f_in, "<i")
|
||||
original_hnsw_data["ntotal"] = read_struct(f_in, "<q")
|
||||
original_hnsw_data["dummy1"] = read_struct(f_in, "<q")
|
||||
original_hnsw_data["dummy2"] = read_struct(f_in, "<q")
|
||||
original_hnsw_data["is_trained"] = read_struct(f_in, "?")
|
||||
original_hnsw_data["metric_type"] = read_struct(f_in, "<i")
|
||||
original_hnsw_data["metric_arg"] = 0.0
|
||||
if original_hnsw_data["metric_type"] > 1:
|
||||
original_hnsw_data["metric_arg"] = read_struct(f_in, "<f")
|
||||
print(
|
||||
f"[{time.time() - start_time:.2f}s] Header read: d={original_hnsw_data['d']}, ntotal={original_hnsw_data['ntotal']}"
|
||||
)
|
||||
|
||||
# --- Read original HNSW struct data ---
|
||||
print(f"[{time.time() - start_time:.2f}s] Reading HNSW struct vectors...")
|
||||
assign_probas_np = read_numpy_vector(f_in, np.float64, 'd')
|
||||
print(f"[{time.time() - start_time:.2f}s] Read assign_probas ({assign_probas_np.size})")
|
||||
assign_probas_np = read_numpy_vector(f_in, np.float64, "d")
|
||||
print(
|
||||
f"[{time.time() - start_time:.2f}s] Read assign_probas ({assign_probas_np.size})"
|
||||
)
|
||||
gc.collect()
|
||||
|
||||
cum_nneighbor_per_level_np = read_numpy_vector(f_in, np.int32, 'i')
|
||||
print(f"[{time.time() - start_time:.2f}s] Read cum_nneighbor_per_level ({cum_nneighbor_per_level_np.size})")
|
||||
cum_nneighbor_per_level_np = read_numpy_vector(f_in, np.int32, "i")
|
||||
print(
|
||||
f"[{time.time() - start_time:.2f}s] Read cum_nneighbor_per_level ({cum_nneighbor_per_level_np.size})"
|
||||
)
|
||||
gc.collect()
|
||||
|
||||
levels_np = read_numpy_vector(f_in, np.int32, 'i')
|
||||
levels_np = read_numpy_vector(f_in, np.int32, "i")
|
||||
print(f"[{time.time() - start_time:.2f}s] Read levels ({levels_np.size})")
|
||||
gc.collect()
|
||||
|
||||
ntotal = len(levels_np)
|
||||
if ntotal != original_hnsw_data['ntotal']:
|
||||
print(f"Warning: ntotal mismatch! Header says {original_hnsw_data['ntotal']}, levels vector size is {ntotal}. Using levels vector size.", file=sys.stderr)
|
||||
original_hnsw_data['ntotal'] = ntotal
|
||||
if ntotal != original_hnsw_data["ntotal"]:
|
||||
print(
|
||||
f"Warning: ntotal mismatch! Header says {original_hnsw_data['ntotal']}, levels vector size is {ntotal}. Using levels vector size.",
|
||||
file=sys.stderr,
|
||||
)
|
||||
original_hnsw_data["ntotal"] = ntotal
|
||||
|
||||
# --- Check for compact format flag ---
|
||||
print(f"[{time.time() - start_time:.2f}s] Probing for compact storage flag...")
|
||||
pos_before_compact = f_in.tell()
|
||||
try:
|
||||
is_compact_flag = read_struct(f_in, '<?')
|
||||
is_compact_flag = read_struct(f_in, "<?")
|
||||
print(f"[{time.time() - start_time:.2f}s] Found compact flag: {is_compact_flag}")
|
||||
|
||||
|
||||
if is_compact_flag:
|
||||
# Input is already in compact format - read compact data
|
||||
print(f"[{time.time() - start_time:.2f}s] Input is already in compact format, reading compact data...")
|
||||
|
||||
compact_level_ptr = read_numpy_vector(f_in, np.uint64, 'Q')
|
||||
print(f"[{time.time() - start_time:.2f}s] Read compact_level_ptr ({compact_level_ptr.size})")
|
||||
|
||||
compact_node_offsets_np = read_numpy_vector(f_in, np.uint64, 'Q')
|
||||
print(f"[{time.time() - start_time:.2f}s] Read compact_node_offsets ({compact_node_offsets_np.size})")
|
||||
|
||||
print(
|
||||
f"[{time.time() - start_time:.2f}s] Input is already in compact format, reading compact data..."
|
||||
)
|
||||
|
||||
compact_level_ptr = read_numpy_vector(f_in, np.uint64, "Q")
|
||||
print(
|
||||
f"[{time.time() - start_time:.2f}s] Read compact_level_ptr ({compact_level_ptr.size})"
|
||||
)
|
||||
|
||||
compact_node_offsets_np = read_numpy_vector(f_in, np.uint64, "Q")
|
||||
print(
|
||||
f"[{time.time() - start_time:.2f}s] Read compact_node_offsets ({compact_node_offsets_np.size})"
|
||||
)
|
||||
|
||||
# Read scalar parameters
|
||||
original_hnsw_data['entry_point'] = read_struct(f_in, '<i')
|
||||
original_hnsw_data['max_level'] = read_struct(f_in, '<i')
|
||||
original_hnsw_data['efConstruction'] = read_struct(f_in, '<i')
|
||||
original_hnsw_data['efSearch'] = read_struct(f_in, '<i')
|
||||
original_hnsw_data['dummy_upper_beam'] = read_struct(f_in, '<i')
|
||||
print(f"[{time.time() - start_time:.2f}s] Read scalar params (ep={original_hnsw_data['entry_point']}, max_lvl={original_hnsw_data['max_level']})")
|
||||
original_hnsw_data["entry_point"] = read_struct(f_in, "<i")
|
||||
original_hnsw_data["max_level"] = read_struct(f_in, "<i")
|
||||
original_hnsw_data["efConstruction"] = read_struct(f_in, "<i")
|
||||
original_hnsw_data["efSearch"] = read_struct(f_in, "<i")
|
||||
original_hnsw_data["dummy_upper_beam"] = read_struct(f_in, "<i")
|
||||
print(
|
||||
f"[{time.time() - start_time:.2f}s] Read scalar params (ep={original_hnsw_data['entry_point']}, max_lvl={original_hnsw_data['max_level']})"
|
||||
)
|
||||
|
||||
# Read storage fourcc
|
||||
storage_fourcc = read_struct(f_in, '<I')
|
||||
print(f"[{time.time() - start_time:.2f}s] Found storage fourcc: {storage_fourcc:08x}")
|
||||
|
||||
storage_fourcc = read_struct(f_in, "<I")
|
||||
print(
|
||||
f"[{time.time() - start_time:.2f}s] Found storage fourcc: {storage_fourcc:08x}"
|
||||
)
|
||||
|
||||
if prune_embeddings and storage_fourcc != NULL_INDEX_FOURCC:
|
||||
# Read compact neighbors data
|
||||
compact_neighbors_data_np = read_numpy_vector(f_in, np.int32, 'i')
|
||||
print(f"[{time.time() - start_time:.2f}s] Read compact neighbors data ({compact_neighbors_data_np.size})")
|
||||
compact_neighbors_data_np = read_numpy_vector(f_in, np.int32, "i")
|
||||
print(
|
||||
f"[{time.time() - start_time:.2f}s] Read compact neighbors data ({compact_neighbors_data_np.size})"
|
||||
)
|
||||
compact_neighbors_data = compact_neighbors_data_np.tolist()
|
||||
del compact_neighbors_data_np
|
||||
|
||||
|
||||
# Skip storage data and write with NULL marker
|
||||
print(f"[{time.time() - start_time:.2f}s] Pruning embeddings: Writing NULL storage marker.")
|
||||
print(
|
||||
f"[{time.time() - start_time:.2f}s] Pruning embeddings: Writing NULL storage marker."
|
||||
)
|
||||
storage_fourcc = NULL_INDEX_FOURCC
|
||||
elif not prune_embeddings:
|
||||
# Read and preserve compact neighbors and storage
|
||||
compact_neighbors_data_np = read_numpy_vector(f_in, np.int32, 'i')
|
||||
compact_neighbors_data_np = read_numpy_vector(f_in, np.int32, "i")
|
||||
compact_neighbors_data = compact_neighbors_data_np.tolist()
|
||||
del compact_neighbors_data_np
|
||||
|
||||
|
||||
# Read remaining storage data
|
||||
storage_data = f_in.read()
|
||||
else:
|
||||
# Already pruned (NULL storage)
|
||||
compact_neighbors_data_np = read_numpy_vector(f_in, np.int32, 'i')
|
||||
compact_neighbors_data_np = read_numpy_vector(f_in, np.int32, "i")
|
||||
compact_neighbors_data = compact_neighbors_data_np.tolist()
|
||||
del compact_neighbors_data_np
|
||||
storage_data = b''
|
||||
|
||||
storage_data = b""
|
||||
|
||||
# Write the updated compact format
|
||||
print(f"[{time.time() - start_time:.2f}s] Writing updated compact format...")
|
||||
write_compact_format(f_out, original_hnsw_data, assign_probas_np, cum_nneighbor_per_level_np,
|
||||
levels_np, compact_level_ptr, compact_node_offsets_np,
|
||||
compact_neighbors_data, storage_fourcc, storage_data if not prune_embeddings else b'')
|
||||
|
||||
write_compact_format(
|
||||
f_out,
|
||||
original_hnsw_data,
|
||||
assign_probas_np,
|
||||
cum_nneighbor_per_level_np,
|
||||
levels_np,
|
||||
compact_level_ptr,
|
||||
compact_node_offsets_np,
|
||||
compact_neighbors_data,
|
||||
storage_fourcc,
|
||||
storage_data if not prune_embeddings else b"",
|
||||
)
|
||||
|
||||
print(f"[{time.time() - start_time:.2f}s] Conversion complete.")
|
||||
return True
|
||||
|
||||
|
||||
else:
|
||||
# is_compact=False, rewind and read original format
|
||||
f_in.seek(pos_before_compact)
|
||||
print(f"[{time.time() - start_time:.2f}s] Compact flag is False, reading original format...")
|
||||
|
||||
print(
|
||||
f"[{time.time() - start_time:.2f}s] Compact flag is False, reading original format..."
|
||||
)
|
||||
|
||||
except EOFError:
|
||||
# No compact flag found, assume original format
|
||||
f_in.seek(pos_before_compact)
|
||||
print(f"[{time.time() - start_time:.2f}s] No compact flag found, assuming original format...")
|
||||
print(
|
||||
f"[{time.time() - start_time:.2f}s] No compact flag found, assuming original format..."
|
||||
)
|
||||
|
||||
# --- Handle potential extra byte in original format (like C++ code) ---
|
||||
print(f"[{time.time() - start_time:.2f}s] Probing for potential extra byte before non-compact offsets...")
|
||||
print(
|
||||
f"[{time.time() - start_time:.2f}s] Probing for potential extra byte before non-compact offsets..."
|
||||
)
|
||||
pos_before_probe = f_in.tell()
|
||||
try:
|
||||
suspected_flag = read_struct(f_in, '<B') # Read 1 byte
|
||||
suspected_flag = read_struct(f_in, "<B") # Read 1 byte
|
||||
if suspected_flag == 0x00:
|
||||
print(f"[{time.time() - start_time:.2f}s] Found and consumed an unexpected 0x00 byte.")
|
||||
print(
|
||||
f"[{time.time() - start_time:.2f}s] Found and consumed an unexpected 0x00 byte."
|
||||
)
|
||||
elif suspected_flag == 0x01:
|
||||
print(f"[{time.time() - start_time:.2f}s] ERROR: Found 0x01 but is_compact should be False")
|
||||
print(
|
||||
f"[{time.time() - start_time:.2f}s] ERROR: Found 0x01 but is_compact should be False"
|
||||
)
|
||||
raise ValueError("Inconsistent compact flag state")
|
||||
else:
|
||||
# Rewind - this byte is part of offsets data
|
||||
f_in.seek(pos_before_probe)
|
||||
print(f"[{time.time() - start_time:.2f}s] Rewound to original position (byte was 0x{suspected_flag:02x})")
|
||||
print(
|
||||
f"[{time.time() - start_time:.2f}s] Rewound to original position (byte was 0x{suspected_flag:02x})"
|
||||
)
|
||||
except EOFError:
|
||||
f_in.seek(pos_before_probe)
|
||||
print(f"[{time.time() - start_time:.2f}s] No extra byte found (EOF), proceeding with offsets read")
|
||||
print(
|
||||
f"[{time.time() - start_time:.2f}s] No extra byte found (EOF), proceeding with offsets read"
|
||||
)
|
||||
|
||||
# --- Read original format data ---
|
||||
offsets_np = read_numpy_vector(f_in, np.uint64, 'Q')
|
||||
offsets_np = read_numpy_vector(f_in, np.uint64, "Q")
|
||||
print(f"[{time.time() - start_time:.2f}s] Read offsets ({offsets_np.size})")
|
||||
if len(offsets_np) != ntotal + 1:
|
||||
raise ValueError(f"Inconsistent offsets size: len(levels)={ntotal} but len(offsets)={len(offsets_np)}")
|
||||
raise ValueError(
|
||||
f"Inconsistent offsets size: len(levels)={ntotal} but len(offsets)={len(offsets_np)}"
|
||||
)
|
||||
gc.collect()
|
||||
|
||||
print(f"[{time.time() - start_time:.2f}s] Attempting to read neighbors vector...")
|
||||
neighbors_np = read_numpy_vector(f_in, np.int32, 'i')
|
||||
neighbors_np = read_numpy_vector(f_in, np.int32, "i")
|
||||
print(f"[{time.time() - start_time:.2f}s] Read neighbors ({neighbors_np.size})")
|
||||
expected_neighbors_size = offsets_np[-1] if ntotal > 0 else 0
|
||||
if neighbors_np.size != expected_neighbors_size:
|
||||
print(f"Warning: neighbors vector size mismatch. Expected {expected_neighbors_size} based on offsets, got {neighbors_np.size}.")
|
||||
print(
|
||||
f"Warning: neighbors vector size mismatch. Expected {expected_neighbors_size} based on offsets, got {neighbors_np.size}."
|
||||
)
|
||||
gc.collect()
|
||||
|
||||
original_hnsw_data['entry_point'] = read_struct(f_in, '<i')
|
||||
original_hnsw_data['max_level'] = read_struct(f_in, '<i')
|
||||
original_hnsw_data['efConstruction'] = read_struct(f_in, '<i')
|
||||
original_hnsw_data['efSearch'] = read_struct(f_in, '<i')
|
||||
original_hnsw_data['dummy_upper_beam'] = read_struct(f_in, '<i')
|
||||
print(f"[{time.time() - start_time:.2f}s] Read scalar params (ep={original_hnsw_data['entry_point']}, max_lvl={original_hnsw_data['max_level']})")
|
||||
original_hnsw_data["entry_point"] = read_struct(f_in, "<i")
|
||||
original_hnsw_data["max_level"] = read_struct(f_in, "<i")
|
||||
original_hnsw_data["efConstruction"] = read_struct(f_in, "<i")
|
||||
original_hnsw_data["efSearch"] = read_struct(f_in, "<i")
|
||||
original_hnsw_data["dummy_upper_beam"] = read_struct(f_in, "<i")
|
||||
print(
|
||||
f"[{time.time() - start_time:.2f}s] Read scalar params (ep={original_hnsw_data['entry_point']}, max_lvl={original_hnsw_data['max_level']})"
|
||||
)
|
||||
|
||||
print(f"[{time.time() - start_time:.2f}s] Checking for storage data...")
|
||||
storage_fourcc = None
|
||||
try:
|
||||
storage_fourcc = read_struct(f_in, '<I')
|
||||
print(f"[{time.time() - start_time:.2f}s] Found storage fourcc: {storage_fourcc:08x}.")
|
||||
storage_fourcc = read_struct(f_in, "<I")
|
||||
print(
|
||||
f"[{time.time() - start_time:.2f}s] Found storage fourcc: {storage_fourcc:08x}."
|
||||
)
|
||||
except EOFError:
|
||||
print(f"[{time.time() - start_time:.2f}s] No storage data found (EOF).")
|
||||
print(f"[{time.time() - start_time:.2f}s] No storage data found (EOF).")
|
||||
except Exception as e:
|
||||
print(f"[{time.time() - start_time:.2f}s] Error reading potential storage data: {e}")
|
||||
|
||||
print(
|
||||
f"[{time.time() - start_time:.2f}s] Error reading potential storage data: {e}"
|
||||
)
|
||||
|
||||
# --- Perform Conversion ---
|
||||
print(f"[{time.time() - start_time:.2f}s] Converting to CSR format...")
|
||||
@@ -373,17 +488,21 @@ def convert_hnsw_graph_to_csr(input_filename, output_filename, prune_embeddings=
|
||||
|
||||
current_level_ptr_idx = 0
|
||||
current_data_idx = 0
|
||||
total_valid_neighbors_counted = 0 # For validation
|
||||
total_valid_neighbors_counted = 0 # For validation
|
||||
|
||||
# Optimize calculation by getting slices once per node if possible
|
||||
for i in range(ntotal):
|
||||
if i > 0 and i % (ntotal // 100 or 1) == 0: # Log progress roughly every 1%
|
||||
if i > 0 and i % (ntotal // 100 or 1) == 0: # Log progress roughly every 1%
|
||||
progress = (i / ntotal) * 100
|
||||
elapsed = time.time() - start_time
|
||||
print(f"\r[{elapsed:.2f}s] Converting node {i}/{ntotal} ({progress:.1f}%)...", end="")
|
||||
print(
|
||||
f"\r[{elapsed:.2f}s] Converting node {i}/{ntotal} ({progress:.1f}%)...",
|
||||
end="",
|
||||
)
|
||||
|
||||
node_max_level = levels_np[i] - 1
|
||||
if node_max_level < -1: node_max_level = -1
|
||||
if node_max_level < -1:
|
||||
node_max_level = -1
|
||||
|
||||
node_ptr_start_index = current_level_ptr_idx
|
||||
compact_node_offsets_np[i] = node_ptr_start_index
|
||||
@@ -394,13 +513,17 @@ def convert_hnsw_graph_to_csr(input_filename, output_filename, prune_embeddings=
|
||||
for level in range(node_max_level + 1):
|
||||
compact_level_ptr.append(current_data_idx)
|
||||
|
||||
begin_orig_np = original_offset_start + get_cum_neighbors(cum_nneighbor_per_level_np, level)
|
||||
end_orig_np = original_offset_start + get_cum_neighbors(cum_nneighbor_per_level_np, level + 1)
|
||||
begin_orig_np = original_offset_start + get_cum_neighbors(
|
||||
cum_nneighbor_per_level_np, level
|
||||
)
|
||||
end_orig_np = original_offset_start + get_cum_neighbors(
|
||||
cum_nneighbor_per_level_np, level + 1
|
||||
)
|
||||
|
||||
begin_orig = int(begin_orig_np)
|
||||
end_orig = int(end_orig_np)
|
||||
|
||||
neighbors_len = len(neighbors_np) # Cache length
|
||||
neighbors_len = len(neighbors_np) # Cache length
|
||||
begin_orig = min(max(0, begin_orig), neighbors_len)
|
||||
end_orig = min(max(begin_orig, end_orig), neighbors_len)
|
||||
|
||||
@@ -413,83 +536,117 @@ def convert_hnsw_graph_to_csr(input_filename, output_filename, prune_embeddings=
|
||||
|
||||
if num_valid > 0:
|
||||
# Append valid neighbors
|
||||
compact_neighbors_data.extend(level_neighbors_slice[valid_neighbors_mask])
|
||||
compact_neighbors_data.extend(
|
||||
level_neighbors_slice[valid_neighbors_mask]
|
||||
)
|
||||
current_data_idx += num_valid
|
||||
total_valid_neighbors_counted += num_valid
|
||||
|
||||
|
||||
compact_level_ptr.append(current_data_idx)
|
||||
current_level_ptr_idx += num_pointers_expected
|
||||
|
||||
compact_node_offsets_np[ntotal] = current_level_ptr_idx
|
||||
print(f"\r[{time.time() - start_time:.2f}s] Conversion loop finished. ") # Clear progress line
|
||||
print(
|
||||
f"\r[{time.time() - start_time:.2f}s] Conversion loop finished. "
|
||||
) # Clear progress line
|
||||
|
||||
# --- Validation Checks ---
|
||||
print(f"[{time.time() - start_time:.2f}s] Running validation checks...")
|
||||
valid_check_passed = True
|
||||
# Check 1: Total valid neighbors count
|
||||
print(f" Checking total valid neighbor count...")
|
||||
print(" Checking total valid neighbor count...")
|
||||
expected_valid_count = np.sum(neighbors_np >= 0)
|
||||
if total_valid_neighbors_counted != len(compact_neighbors_data):
|
||||
print(f"Error: Mismatch between counted valid neighbors ({total_valid_neighbors_counted}) and final compact_data size ({len(compact_neighbors_data)})!", file=sys.stderr)
|
||||
valid_check_passed = False
|
||||
print(
|
||||
f"Error: Mismatch between counted valid neighbors ({total_valid_neighbors_counted}) and final compact_data size ({len(compact_neighbors_data)})!",
|
||||
file=sys.stderr,
|
||||
)
|
||||
valid_check_passed = False
|
||||
if expected_valid_count != len(compact_neighbors_data):
|
||||
print(f"Error: Mismatch between NumPy count of valid neighbors ({expected_valid_count}) and final compact_data size ({len(compact_neighbors_data)})!", file=sys.stderr)
|
||||
valid_check_passed = False
|
||||
print(
|
||||
f"Error: Mismatch between NumPy count of valid neighbors ({expected_valid_count}) and final compact_data size ({len(compact_neighbors_data)})!",
|
||||
file=sys.stderr,
|
||||
)
|
||||
valid_check_passed = False
|
||||
else:
|
||||
print(f" OK: Total valid neighbors = {len(compact_neighbors_data)}")
|
||||
print(f" OK: Total valid neighbors = {len(compact_neighbors_data)}")
|
||||
|
||||
# Check 2: Final pointer indices consistency
|
||||
print(f" Checking final pointer indices...")
|
||||
print(" Checking final pointer indices...")
|
||||
if compact_node_offsets_np[ntotal] != len(compact_level_ptr):
|
||||
print(f"Error: Final node offset ({compact_node_offsets_np[ntotal]}) doesn't match level_ptr size ({len(compact_level_ptr)})!", file=sys.stderr)
|
||||
valid_check_passed = False
|
||||
if (len(compact_level_ptr) > 0 and compact_level_ptr[-1] != len(compact_neighbors_data)) or \
|
||||
(len(compact_level_ptr) == 0 and len(compact_neighbors_data) != 0):
|
||||
last_ptr = compact_level_ptr[-1] if len(compact_level_ptr) > 0 else -1
|
||||
print(f"Error: Last level pointer ({last_ptr}) doesn't match compact_data size ({len(compact_neighbors_data)})!", file=sys.stderr)
|
||||
valid_check_passed = False
|
||||
print(
|
||||
f"Error: Final node offset ({compact_node_offsets_np[ntotal]}) doesn't match level_ptr size ({len(compact_level_ptr)})!",
|
||||
file=sys.stderr,
|
||||
)
|
||||
valid_check_passed = False
|
||||
if (
|
||||
len(compact_level_ptr) > 0 and compact_level_ptr[-1] != len(compact_neighbors_data)
|
||||
) or (len(compact_level_ptr) == 0 and len(compact_neighbors_data) != 0):
|
||||
last_ptr = compact_level_ptr[-1] if len(compact_level_ptr) > 0 else -1
|
||||
print(
|
||||
f"Error: Last level pointer ({last_ptr}) doesn't match compact_data size ({len(compact_neighbors_data)})!",
|
||||
file=sys.stderr,
|
||||
)
|
||||
valid_check_passed = False
|
||||
else:
|
||||
print(f" OK: Final pointers match data size.")
|
||||
print(" OK: Final pointers match data size.")
|
||||
|
||||
if not valid_check_passed:
|
||||
print("Error: Validation checks failed. Output file might be incorrect.", file=sys.stderr)
|
||||
print(
|
||||
"Error: Validation checks failed. Output file might be incorrect.",
|
||||
file=sys.stderr,
|
||||
)
|
||||
# Optional: Exit here if validation fails
|
||||
# return False
|
||||
|
||||
# --- Explicitly delete large intermediate arrays ---
|
||||
print(f"[{time.time() - start_time:.2f}s] Deleting original neighbors and offsets arrays...")
|
||||
print(
|
||||
f"[{time.time() - start_time:.2f}s] Deleting original neighbors and offsets arrays..."
|
||||
)
|
||||
del neighbors_np
|
||||
del offsets_np
|
||||
gc.collect()
|
||||
|
||||
print(f" CSR Stats: |data|={len(compact_neighbors_data)}, |level_ptr|={len(compact_level_ptr)}")
|
||||
print(
|
||||
f" CSR Stats: |data|={len(compact_neighbors_data)}, |level_ptr|={len(compact_level_ptr)}"
|
||||
)
|
||||
|
||||
# --- Write CSR HNSW graph data using unified function ---
|
||||
print(f"[{time.time() - start_time:.2f}s] Writing CSR HNSW graph data in FAISS-compatible order...")
|
||||
|
||||
print(
|
||||
f"[{time.time() - start_time:.2f}s] Writing CSR HNSW graph data in FAISS-compatible order..."
|
||||
)
|
||||
|
||||
# Determine storage fourcc and data based on prune_embeddings
|
||||
if prune_embeddings:
|
||||
print(f" Pruning embeddings: Writing NULL storage marker.")
|
||||
print(" Pruning embeddings: Writing NULL storage marker.")
|
||||
output_storage_fourcc = NULL_INDEX_FOURCC
|
||||
storage_data = b''
|
||||
storage_data = b""
|
||||
else:
|
||||
# Keep embeddings - read and preserve original storage data
|
||||
if storage_fourcc and storage_fourcc != NULL_INDEX_FOURCC:
|
||||
print(f" Preserving embeddings: Reading original storage data...")
|
||||
print(" Preserving embeddings: Reading original storage data...")
|
||||
storage_data = f_in.read() # Read remaining storage data
|
||||
output_storage_fourcc = storage_fourcc
|
||||
print(f" Read {len(storage_data)} bytes of storage data")
|
||||
else:
|
||||
print(f" No embeddings found in original file (NULL storage)")
|
||||
print(" No embeddings found in original file (NULL storage)")
|
||||
output_storage_fourcc = NULL_INDEX_FOURCC
|
||||
storage_data = b''
|
||||
|
||||
storage_data = b""
|
||||
|
||||
# Use the unified write function
|
||||
write_compact_format(f_out, original_hnsw_data, assign_probas_np, cum_nneighbor_per_level_np,
|
||||
levels_np, compact_level_ptr, compact_node_offsets_np,
|
||||
compact_neighbors_data, output_storage_fourcc, storage_data)
|
||||
|
||||
write_compact_format(
|
||||
f_out,
|
||||
original_hnsw_data,
|
||||
assign_probas_np,
|
||||
cum_nneighbor_per_level_np,
|
||||
levels_np,
|
||||
compact_level_ptr,
|
||||
compact_node_offsets_np,
|
||||
compact_neighbors_data,
|
||||
output_storage_fourcc,
|
||||
storage_data,
|
||||
)
|
||||
|
||||
# Clean up memory
|
||||
del assign_probas_np, cum_nneighbor_per_level_np, levels_np
|
||||
del compact_neighbors_data, compact_level_ptr, compact_node_offsets_np
|
||||
@@ -503,40 +660,66 @@ def convert_hnsw_graph_to_csr(input_filename, output_filename, prune_embeddings=
|
||||
print(f"Error: Input file not found: {input_filename}", file=sys.stderr)
|
||||
return False
|
||||
except MemoryError as e:
|
||||
print(f"\nFatal MemoryError during conversion: {e}. Insufficient RAM.", file=sys.stderr)
|
||||
# Clean up potentially partially written output file?
|
||||
try: os.remove(output_filename)
|
||||
except OSError: pass
|
||||
return False
|
||||
print(
|
||||
f"\nFatal MemoryError during conversion: {e}. Insufficient RAM.",
|
||||
file=sys.stderr,
|
||||
)
|
||||
# Clean up potentially partially written output file?
|
||||
try:
|
||||
os.remove(output_filename)
|
||||
except OSError:
|
||||
pass
|
||||
return False
|
||||
except EOFError as e:
|
||||
print(f"Error: Reached end of file unexpectedly reading {input_filename}. {e}", file=sys.stderr)
|
||||
try: os.remove(output_filename)
|
||||
except OSError: pass
|
||||
print(
|
||||
f"Error: Reached end of file unexpectedly reading {input_filename}. {e}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
try:
|
||||
os.remove(output_filename)
|
||||
except OSError:
|
||||
pass
|
||||
return False
|
||||
except Exception as e:
|
||||
print(f"An unexpected error occurred during conversion: {e}", file=sys.stderr)
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
try:
|
||||
os.remove(output_filename)
|
||||
except OSError: pass
|
||||
except OSError:
|
||||
pass
|
||||
return False
|
||||
# Ensure neighbors_np is deleted even if an error occurs after its allocation
|
||||
finally:
|
||||
if 'neighbors_np' in locals() and neighbors_np is not None:
|
||||
del neighbors_np
|
||||
gc.collect()
|
||||
try:
|
||||
if "neighbors_np" in locals() and neighbors_np is not None:
|
||||
del neighbors_np
|
||||
gc.collect()
|
||||
except NameError:
|
||||
pass
|
||||
|
||||
|
||||
# --- Script Execution ---
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Convert a Faiss IndexHNSWFlat file to a CSR-based HNSW graph file.")
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Convert a Faiss IndexHNSWFlat file to a CSR-based HNSW graph file."
|
||||
)
|
||||
parser.add_argument("input_index_file", help="Path to the input IndexHNSWFlat file")
|
||||
parser.add_argument("output_csr_graph_file", help="Path to write the output CSR HNSW graph file")
|
||||
parser.add_argument("--prune-embeddings", action="store_true", default=True,
|
||||
help="Prune embedding storage (write NULL storage marker)")
|
||||
parser.add_argument("--keep-embeddings", action="store_true",
|
||||
help="Keep embedding storage (overrides --prune-embeddings)")
|
||||
parser.add_argument(
|
||||
"output_csr_graph_file", help="Path to write the output CSR HNSW graph file"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--prune-embeddings",
|
||||
action="store_true",
|
||||
default=True,
|
||||
help="Prune embedding storage (write NULL storage marker)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--keep-embeddings",
|
||||
action="store_true",
|
||||
help="Keep embedding storage (overrides --prune-embeddings)",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -545,10 +728,12 @@ if __name__ == "__main__":
|
||||
sys.exit(1)
|
||||
|
||||
if os.path.abspath(args.input_index_file) == os.path.abspath(args.output_csr_graph_file):
|
||||
print(f"Error: Input and output filenames cannot be the same.", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
print("Error: Input and output filenames cannot be the same.", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
prune_embeddings = args.prune_embeddings and not args.keep_embeddings
|
||||
success = convert_hnsw_graph_to_csr(args.input_index_file, args.output_csr_graph_file, prune_embeddings)
|
||||
success = convert_hnsw_graph_to_csr(
|
||||
args.input_index_file, args.output_csr_graph_file, prune_embeddings
|
||||
)
|
||||
if not success:
|
||||
sys.exit(1)
|
||||
sys.exit(1)
|
||||
|
||||
@@ -1,19 +1,20 @@
|
||||
import numpy as np
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, List, Literal, Optional
|
||||
import shutil
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Literal, Optional
|
||||
|
||||
from leann.searcher_base import BaseSearcher
|
||||
from .convert_to_csr import convert_hnsw_graph_to_csr
|
||||
|
||||
from leann.registry import register_backend
|
||||
import numpy as np
|
||||
from leann.interface import (
|
||||
LeannBackendFactoryInterface,
|
||||
LeannBackendBuilderInterface,
|
||||
LeannBackendFactoryInterface,
|
||||
LeannBackendSearcherInterface,
|
||||
)
|
||||
from leann.registry import register_backend
|
||||
from leann.searcher_base import BaseSearcher
|
||||
|
||||
from .convert_to_csr import convert_hnsw_graph_to_csr
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -28,6 +29,12 @@ def get_metric_map():
|
||||
}
|
||||
|
||||
|
||||
def normalize_l2(data: np.ndarray) -> np.ndarray:
|
||||
norms = np.linalg.norm(data, axis=1, keepdims=True)
|
||||
norms[norms == 0] = 1 # Avoid division by zero
|
||||
return data / norms
|
||||
|
||||
|
||||
@register_backend("hnsw")
|
||||
class HNSWBackend(LeannBackendFactoryInterface):
|
||||
@staticmethod
|
||||
@@ -48,8 +55,15 @@ class HNSWBuilder(LeannBackendBuilderInterface):
|
||||
self.efConstruction = self.build_params.setdefault("efConstruction", 200)
|
||||
self.distance_metric = self.build_params.setdefault("distance_metric", "mips")
|
||||
self.dimensions = self.build_params.get("dimensions")
|
||||
if not self.is_recompute and self.is_compact:
|
||||
# Auto-correct: non-recompute requires non-compact storage for HNSW
|
||||
logger.warning(
|
||||
"is_recompute=False requires non-compact HNSW. Forcing is_compact=False."
|
||||
)
|
||||
self.is_compact = False
|
||||
self.build_params["is_compact"] = False
|
||||
|
||||
def build(self, data: np.ndarray, ids: List[str], index_path: str, **kwargs):
|
||||
def build(self, data: np.ndarray, ids: list[str], index_path: str, **kwargs):
|
||||
from . import faiss # type: ignore
|
||||
|
||||
path = Path(index_path)
|
||||
@@ -70,7 +84,7 @@ class HNSWBuilder(LeannBackendBuilderInterface):
|
||||
index.hnsw.efConstruction = self.efConstruction
|
||||
|
||||
if self.distance_metric.lower() == "cosine":
|
||||
faiss.normalize_L2(data)
|
||||
data = normalize_l2(data)
|
||||
|
||||
index.add(data.shape[0], faiss.swig_ptr(data))
|
||||
index_file = index_dir / f"{index_prefix}.index"
|
||||
@@ -92,19 +106,15 @@ class HNSWBuilder(LeannBackendBuilderInterface):
|
||||
|
||||
if success:
|
||||
logger.info("✅ CSR conversion successful.")
|
||||
index_file_old = index_file.with_suffix(".old")
|
||||
shutil.move(str(index_file), str(index_file_old))
|
||||
# index_file_old = index_file.with_suffix(".old")
|
||||
# shutil.move(str(index_file), str(index_file_old))
|
||||
shutil.move(str(csr_temp_file), str(index_file))
|
||||
logger.info(
|
||||
f"INFO: Replaced original index with {mode_str} version at '{index_file}'"
|
||||
)
|
||||
logger.info(f"INFO: Replaced original index with {mode_str} version at '{index_file}'")
|
||||
else:
|
||||
# Clean up and fail fast
|
||||
if csr_temp_file.exists():
|
||||
os.remove(csr_temp_file)
|
||||
raise RuntimeError(
|
||||
"CSR conversion failed - cannot proceed with compact format"
|
||||
)
|
||||
raise RuntimeError("CSR conversion failed - cannot proceed with compact format")
|
||||
|
||||
|
||||
class HNSWSearcher(BaseSearcher):
|
||||
@@ -116,7 +126,9 @@ class HNSWSearcher(BaseSearcher):
|
||||
)
|
||||
from . import faiss # type: ignore
|
||||
|
||||
self.distance_metric = self.meta.get("distance_metric", "mips").lower()
|
||||
self.distance_metric = (
|
||||
self.meta.get("backend_kwargs", {}).get("distance_metric", "mips").lower()
|
||||
)
|
||||
metric_enum = get_metric_map().get(self.distance_metric)
|
||||
if metric_enum is None:
|
||||
raise ValueError(f"Unsupported distance_metric '{self.distance_metric}'.")
|
||||
@@ -150,7 +162,7 @@ class HNSWSearcher(BaseSearcher):
|
||||
pruning_strategy: Literal["global", "local", "proportional"] = "global",
|
||||
batch_size: int = 0,
|
||||
**kwargs,
|
||||
) -> Dict[str, Any]:
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Search for nearest neighbors using HNSW index.
|
||||
|
||||
@@ -174,28 +186,36 @@ class HNSWSearcher(BaseSearcher):
|
||||
"""
|
||||
from . import faiss # type: ignore
|
||||
|
||||
if not recompute_embeddings:
|
||||
if self.is_pruned:
|
||||
raise RuntimeError("Recompute is required for pruned index.")
|
||||
if not recompute_embeddings and self.is_pruned:
|
||||
raise RuntimeError(
|
||||
"Recompute is required for pruned/compact HNSW index. "
|
||||
"Re-run search with --recompute, or rebuild with --no-recompute and --no-compact."
|
||||
)
|
||||
if recompute_embeddings:
|
||||
if zmq_port is None:
|
||||
raise ValueError(
|
||||
"zmq_port must be provided if recompute_embeddings is True"
|
||||
)
|
||||
raise ValueError("zmq_port must be provided if recompute_embeddings is True")
|
||||
|
||||
if query.dtype != np.float32:
|
||||
query = query.astype(np.float32)
|
||||
if self.distance_metric == "cosine":
|
||||
faiss.normalize_L2(query)
|
||||
query = normalize_l2(query)
|
||||
|
||||
params = faiss.SearchParametersHNSW()
|
||||
if zmq_port is not None:
|
||||
params.zmq_port = (
|
||||
zmq_port # C++ code won't use this if recompute_embeddings is False
|
||||
)
|
||||
params.zmq_port = zmq_port # C++ code won't use this if recompute_embeddings is False
|
||||
params.efSearch = complexity
|
||||
params.beam_size = beam_width
|
||||
|
||||
# For OpenAI embeddings with cosine distance, disable relative distance check
|
||||
# This prevents early termination when all scores are in a narrow range
|
||||
embedding_model = self.meta.get("embedding_model", "").lower()
|
||||
if self.distance_metric == "cosine" and any(
|
||||
openai_model in embedding_model for openai_model in ["text-embedding", "openai"]
|
||||
):
|
||||
params.check_relative_distance = False
|
||||
else:
|
||||
params.check_relative_distance = True
|
||||
|
||||
# PQ pruning: direct mapping to HNSW's pq_pruning_ratio
|
||||
params.pq_pruning_ratio = prune_ratio
|
||||
|
||||
@@ -205,9 +225,7 @@ class HNSWSearcher(BaseSearcher):
|
||||
params.send_neigh_times_ratio = 0.0
|
||||
elif pruning_strategy == "proportional":
|
||||
params.local_prune = False
|
||||
params.send_neigh_times_ratio = (
|
||||
1.0 # Any value > 1e-6 triggers proportional mode
|
||||
)
|
||||
params.send_neigh_times_ratio = 1.0 # Any value > 1e-6 triggers proportional mode
|
||||
else: # "global"
|
||||
params.local_prune = False
|
||||
params.send_neigh_times_ratio = 0.0
|
||||
@@ -219,6 +237,7 @@ class HNSWSearcher(BaseSearcher):
|
||||
distances = np.empty((batch_size_query, top_k), dtype=np.float32)
|
||||
labels = np.empty((batch_size_query, top_k), dtype=np.int64)
|
||||
|
||||
search_time = time.time()
|
||||
self._index.search(
|
||||
query.shape[0],
|
||||
faiss.swig_ptr(query),
|
||||
@@ -227,9 +246,8 @@ class HNSWSearcher(BaseSearcher):
|
||||
faiss.swig_ptr(labels),
|
||||
params,
|
||||
)
|
||||
|
||||
string_labels = [
|
||||
[str(int_label) for int_label in batch_labels] for batch_labels in labels
|
||||
]
|
||||
search_time = time.time() - search_time
|
||||
logger.info(f" Search time in HNSWSearcher.search() backend: {search_time} seconds")
|
||||
string_labels = [[str(int_label) for int_label in batch_labels] for batch_labels in labels]
|
||||
|
||||
return {"labels": string_labels, "distances": distances}
|
||||
|
||||
@@ -3,17 +3,18 @@ HNSW-specific embedding server
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
import os
|
||||
import zmq
|
||||
import numpy as np
|
||||
import msgpack
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
import sys
|
||||
import logging
|
||||
|
||||
import msgpack
|
||||
import numpy as np
|
||||
import zmq
|
||||
|
||||
# Set up logging based on environment variable
|
||||
LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
|
||||
@@ -52,8 +53,8 @@ def create_hnsw_embedding_server(
|
||||
sys.path.insert(0, str(leann_core_path))
|
||||
|
||||
try:
|
||||
from leann.embedding_compute import compute_embeddings
|
||||
from leann.api import PassageManager
|
||||
from leann.embedding_compute import compute_embeddings
|
||||
|
||||
logger.info("Successfully imported unified embedding computation module")
|
||||
except ImportError as e:
|
||||
@@ -78,207 +79,319 @@ def create_hnsw_embedding_server(
|
||||
raise ValueError("Only metadata files (.meta.json) are supported")
|
||||
|
||||
# Load metadata to get passage sources
|
||||
with open(passages_file, "r") as f:
|
||||
with open(passages_file) as f:
|
||||
meta = json.load(f)
|
||||
|
||||
passages = PassageManager(meta["passage_sources"])
|
||||
logger.info(
|
||||
f"Loaded PassageManager with {len(passages.global_offset_map)} passages from metadata"
|
||||
)
|
||||
# Let PassageManager handle path resolution uniformly. It supports fallback order:
|
||||
# 1) path/index_path; 2) *_relative; 3) standard siblings next to meta
|
||||
passages = PassageManager(meta["passage_sources"], metadata_file_path=passages_file)
|
||||
# Dimension from metadata for shaping responses
|
||||
try:
|
||||
embedding_dim: int = int(meta.get("dimensions", 0))
|
||||
except Exception:
|
||||
embedding_dim = 0
|
||||
logger.info(f"Loaded PassageManager with {len(passages)} passages from metadata")
|
||||
|
||||
# (legacy ZMQ thread removed; using shutdown-capable server only)
|
||||
|
||||
def zmq_server_thread_with_shutdown(shutdown_event):
|
||||
"""ZMQ server thread that respects shutdown signal.
|
||||
|
||||
Creates its own REP socket bound to zmq_port and polls with timeouts
|
||||
to allow graceful shutdown.
|
||||
"""
|
||||
logger.info("ZMQ server thread started with shutdown support")
|
||||
|
||||
def zmq_server_thread():
|
||||
"""ZMQ server thread"""
|
||||
context = zmq.Context()
|
||||
socket = context.socket(zmq.REP)
|
||||
socket.bind(f"tcp://*:{zmq_port}")
|
||||
logger.info(f"HNSW ZMQ server listening on port {zmq_port}")
|
||||
rep_socket = context.socket(zmq.REP)
|
||||
rep_socket.bind(f"tcp://*:{zmq_port}")
|
||||
logger.info(f"HNSW ZMQ REP server listening on port {zmq_port}")
|
||||
rep_socket.setsockopt(zmq.RCVTIMEO, 1000)
|
||||
# Keep sends from blocking during shutdown; fail fast and drop on close
|
||||
rep_socket.setsockopt(zmq.SNDTIMEO, 1000)
|
||||
rep_socket.setsockopt(zmq.LINGER, 0)
|
||||
|
||||
socket.setsockopt(zmq.RCVTIMEO, 300000)
|
||||
socket.setsockopt(zmq.SNDTIMEO, 300000)
|
||||
# Track last request type/length for shape-correct fallbacks
|
||||
last_request_type = "unknown" # 'text' | 'distance' | 'embedding' | 'unknown'
|
||||
last_request_length = 0
|
||||
|
||||
while True:
|
||||
try:
|
||||
message_bytes = socket.recv()
|
||||
logger.debug(f"Received ZMQ request of size {len(message_bytes)} bytes")
|
||||
try:
|
||||
while not shutdown_event.is_set():
|
||||
try:
|
||||
e2e_start = time.time()
|
||||
logger.debug("🔍 Waiting for ZMQ message...")
|
||||
request_bytes = rep_socket.recv()
|
||||
|
||||
e2e_start = time.time()
|
||||
request_payload = msgpack.unpackb(message_bytes)
|
||||
# Rest of the processing logic (same as original)
|
||||
request = msgpack.unpackb(request_bytes)
|
||||
|
||||
# Handle direct text embedding request
|
||||
if isinstance(request_payload, list) and len(request_payload) > 0:
|
||||
# Check if this is a direct text request (list of strings)
|
||||
if all(isinstance(item, str) for item in request_payload):
|
||||
logger.info(
|
||||
f"Processing direct text embedding request for {len(request_payload)} texts in {embedding_mode} mode"
|
||||
)
|
||||
|
||||
# Use unified embedding computation (now with model caching)
|
||||
embeddings = compute_embeddings(
|
||||
request_payload, model_name, mode=embedding_mode
|
||||
)
|
||||
|
||||
response = embeddings.tolist()
|
||||
socket.send(msgpack.packb(response))
|
||||
e2e_end = time.time()
|
||||
logger.info(
|
||||
f"⏱️ Text embedding E2E time: {e2e_end - e2e_start:.6f}s"
|
||||
)
|
||||
if len(request) == 1 and request[0] == "__QUERY_MODEL__":
|
||||
response_bytes = msgpack.packb([model_name])
|
||||
rep_socket.send(response_bytes)
|
||||
continue
|
||||
|
||||
# Handle distance calculation requests
|
||||
if (
|
||||
isinstance(request_payload, list)
|
||||
and len(request_payload) == 2
|
||||
and isinstance(request_payload[0], list)
|
||||
and isinstance(request_payload[1], list)
|
||||
):
|
||||
node_ids = request_payload[0]
|
||||
query_vector = np.array(request_payload[1], dtype=np.float32)
|
||||
# Handle direct text embedding request
|
||||
if (
|
||||
isinstance(request, list)
|
||||
and request
|
||||
and all(isinstance(item, str) for item in request)
|
||||
):
|
||||
last_request_type = "text"
|
||||
last_request_length = len(request)
|
||||
embeddings = compute_embeddings(request, model_name, mode=embedding_mode)
|
||||
rep_socket.send(msgpack.packb(embeddings.tolist()))
|
||||
e2e_end = time.time()
|
||||
logger.info(f"⏱️ Text embedding E2E time: {e2e_end - e2e_start:.6f}s")
|
||||
continue
|
||||
|
||||
logger.debug("Distance calculation request received")
|
||||
logger.debug(f" Node IDs: {node_ids}")
|
||||
logger.debug(f" Query vector dim: {len(query_vector)}")
|
||||
# Handle distance calculation request: [[ids], [query_vector]]
|
||||
if (
|
||||
isinstance(request, list)
|
||||
and len(request) == 2
|
||||
and isinstance(request[0], list)
|
||||
and isinstance(request[1], list)
|
||||
):
|
||||
node_ids = request[0]
|
||||
# Handle nested [[ids]] shape defensively
|
||||
if len(node_ids) == 1 and isinstance(node_ids[0], list):
|
||||
node_ids = node_ids[0]
|
||||
query_vector = np.array(request[1], dtype=np.float32)
|
||||
last_request_type = "distance"
|
||||
last_request_length = len(node_ids)
|
||||
|
||||
# Get embeddings for node IDs
|
||||
texts = []
|
||||
for nid in node_ids:
|
||||
logger.debug("Distance calculation request received")
|
||||
logger.debug(f" Node IDs: {node_ids}")
|
||||
logger.debug(f" Query vector dim: {len(query_vector)}")
|
||||
|
||||
# Gather texts for found ids
|
||||
texts: list[str] = []
|
||||
found_indices: list[int] = []
|
||||
for idx, nid in enumerate(node_ids):
|
||||
try:
|
||||
passage_data = passages.get_passage(str(nid))
|
||||
txt = passage_data.get("text", "")
|
||||
if isinstance(txt, str) and len(txt) > 0:
|
||||
texts.append(txt)
|
||||
found_indices.append(idx)
|
||||
else:
|
||||
logger.error(f"Empty text for passage ID {nid}")
|
||||
except KeyError:
|
||||
logger.error(f"Passage ID {nid} not found")
|
||||
except Exception as e:
|
||||
logger.error(f"Exception looking up passage ID {nid}: {e}")
|
||||
|
||||
# Prepare full-length response with large sentinel values
|
||||
large_distance = 1e9
|
||||
response_distances = [large_distance] * len(node_ids)
|
||||
|
||||
if texts:
|
||||
try:
|
||||
embeddings = compute_embeddings(
|
||||
texts, model_name, mode=embedding_mode
|
||||
)
|
||||
logger.info(
|
||||
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
|
||||
)
|
||||
if distance_metric == "l2":
|
||||
partial = np.sum(
|
||||
np.square(embeddings - query_vector.reshape(1, -1)), axis=1
|
||||
)
|
||||
else: # mips or cosine
|
||||
partial = -np.dot(embeddings, query_vector)
|
||||
|
||||
for pos, dval in zip(found_indices, partial.flatten().tolist()):
|
||||
response_distances[pos] = float(dval)
|
||||
except Exception as e:
|
||||
logger.error(f"Distance computation error, using sentinels: {e}")
|
||||
|
||||
# Send response in expected shape [[distances]]
|
||||
rep_socket.send(msgpack.packb([response_distances], use_single_float=True))
|
||||
e2e_end = time.time()
|
||||
logger.info(f"⏱️ Distance calculation E2E time: {e2e_end - e2e_start:.6f}s")
|
||||
continue
|
||||
|
||||
# Fallback: treat as embedding-by-id request
|
||||
if (
|
||||
isinstance(request, list)
|
||||
and len(request) == 1
|
||||
and isinstance(request[0], list)
|
||||
):
|
||||
node_ids = request[0]
|
||||
elif isinstance(request, list):
|
||||
node_ids = request
|
||||
else:
|
||||
node_ids = []
|
||||
last_request_type = "embedding"
|
||||
last_request_length = len(node_ids)
|
||||
logger.info(f"ZMQ received {len(node_ids)} node IDs for embedding fetch")
|
||||
|
||||
# Preallocate zero-filled flat data for robustness
|
||||
if embedding_dim <= 0:
|
||||
dims = [0, 0]
|
||||
flat_data: list[float] = []
|
||||
else:
|
||||
dims = [len(node_ids), embedding_dim]
|
||||
flat_data = [0.0] * (dims[0] * dims[1])
|
||||
|
||||
# Collect texts for found ids
|
||||
texts: list[str] = []
|
||||
found_indices: list[int] = []
|
||||
for idx, nid in enumerate(node_ids):
|
||||
try:
|
||||
passage_data = passages.get_passage(str(nid))
|
||||
txt = passage_data["text"]
|
||||
texts.append(txt)
|
||||
txt = passage_data.get("text", "")
|
||||
if isinstance(txt, str) and len(txt) > 0:
|
||||
texts.append(txt)
|
||||
found_indices.append(idx)
|
||||
else:
|
||||
logger.error(f"Empty text for passage ID {nid}")
|
||||
except KeyError:
|
||||
logger.error(f"Passage ID {nid} not found")
|
||||
raise RuntimeError(
|
||||
f"FATAL: Passage with ID {nid} not found"
|
||||
)
|
||||
logger.error(f"Passage with ID {nid} not found")
|
||||
except Exception as e:
|
||||
logger.error(f"Exception looking up passage ID {nid}: {e}")
|
||||
raise
|
||||
|
||||
# Process embeddings
|
||||
embeddings = compute_embeddings(
|
||||
texts, model_name, mode=embedding_mode
|
||||
)
|
||||
logger.info(
|
||||
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
|
||||
)
|
||||
|
||||
# Calculate distances
|
||||
if distance_metric == "l2":
|
||||
distances = np.sum(
|
||||
np.square(embeddings - query_vector.reshape(1, -1)), axis=1
|
||||
)
|
||||
else: # mips or cosine
|
||||
distances = -np.dot(embeddings, query_vector)
|
||||
|
||||
response_payload = distances.flatten().tolist()
|
||||
response_bytes = msgpack.packb(
|
||||
[response_payload], use_single_float=True
|
||||
)
|
||||
logger.debug(
|
||||
f"Sending distance response with {len(distances)} distances"
|
||||
)
|
||||
|
||||
socket.send(response_bytes)
|
||||
e2e_end = time.time()
|
||||
logger.info(
|
||||
f"⏱️ Distance calculation E2E time: {e2e_end - e2e_start:.6f}s"
|
||||
)
|
||||
continue
|
||||
|
||||
# Standard embedding request (passage ID lookup)
|
||||
if (
|
||||
not isinstance(request_payload, list)
|
||||
or len(request_payload) != 1
|
||||
or not isinstance(request_payload[0], list)
|
||||
):
|
||||
logger.error(
|
||||
f"Invalid MessagePack request format. Expected [[ids...]] or [texts...], got: {type(request_payload)}"
|
||||
)
|
||||
socket.send(msgpack.packb([[], []]))
|
||||
continue
|
||||
|
||||
node_ids = request_payload[0]
|
||||
logger.debug(f"Request for {len(node_ids)} node embeddings")
|
||||
|
||||
# Look up texts by node IDs
|
||||
texts = []
|
||||
for nid in node_ids:
|
||||
try:
|
||||
passage_data = passages.get_passage(str(nid))
|
||||
txt = passage_data["text"]
|
||||
if not txt:
|
||||
raise RuntimeError(
|
||||
f"FATAL: Empty text for passage ID {nid}"
|
||||
if texts:
|
||||
try:
|
||||
embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
|
||||
logger.info(
|
||||
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
|
||||
)
|
||||
texts.append(txt)
|
||||
except KeyError:
|
||||
raise RuntimeError(f"FATAL: Passage with ID {nid} not found")
|
||||
except Exception as e:
|
||||
logger.error(f"Exception looking up passage ID {nid}: {e}")
|
||||
raise
|
||||
|
||||
# Process embeddings
|
||||
embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
|
||||
logger.info(
|
||||
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
|
||||
)
|
||||
if np.isnan(embeddings).any() or np.isinf(embeddings).any():
|
||||
logger.error(
|
||||
f"NaN or Inf detected in embeddings! Requested IDs: {node_ids[:5]}..."
|
||||
)
|
||||
dims = [0, embedding_dim]
|
||||
flat_data = []
|
||||
else:
|
||||
emb_f32 = np.ascontiguousarray(embeddings, dtype=np.float32)
|
||||
flat = emb_f32.flatten().tolist()
|
||||
for j, pos in enumerate(found_indices):
|
||||
start = pos * embedding_dim
|
||||
end = start + embedding_dim
|
||||
if end <= len(flat_data):
|
||||
flat_data[start:end] = flat[
|
||||
j * embedding_dim : (j + 1) * embedding_dim
|
||||
]
|
||||
except Exception as e:
|
||||
logger.error(f"Embedding computation error, returning zeros: {e}")
|
||||
|
||||
# Serialization and response
|
||||
if np.isnan(embeddings).any() or np.isinf(embeddings).any():
|
||||
logger.error(
|
||||
f"NaN or Inf detected in embeddings! Requested IDs: {node_ids[:5]}..."
|
||||
)
|
||||
assert False
|
||||
response_payload = [dims, flat_data]
|
||||
response_bytes = msgpack.packb(response_payload, use_single_float=True)
|
||||
|
||||
hidden_contiguous_f32 = np.ascontiguousarray(
|
||||
embeddings, dtype=np.float32
|
||||
)
|
||||
response_payload = [
|
||||
list(hidden_contiguous_f32.shape),
|
||||
hidden_contiguous_f32.flatten().tolist(),
|
||||
]
|
||||
response_bytes = msgpack.packb(response_payload, use_single_float=True)
|
||||
rep_socket.send(response_bytes)
|
||||
e2e_end = time.time()
|
||||
logger.info(f"⏱️ ZMQ E2E time: {e2e_end - e2e_start:.6f}s")
|
||||
|
||||
socket.send(response_bytes)
|
||||
e2e_end = time.time()
|
||||
logger.info(f"⏱️ ZMQ E2E time: {e2e_end - e2e_start:.6f}s")
|
||||
except zmq.Again:
|
||||
# Timeout - check shutdown_event and continue
|
||||
continue
|
||||
except Exception as e:
|
||||
if not shutdown_event.is_set():
|
||||
logger.error(f"Error in ZMQ server loop: {e}")
|
||||
# Shape-correct fallback
|
||||
try:
|
||||
if last_request_type == "distance":
|
||||
large_distance = 1e9
|
||||
fallback_len = max(0, int(last_request_length))
|
||||
safe = [[large_distance] * fallback_len]
|
||||
elif last_request_type == "embedding":
|
||||
bsz = max(0, int(last_request_length))
|
||||
dim = max(0, int(embedding_dim))
|
||||
safe = (
|
||||
[[bsz, dim], [0.0] * (bsz * dim)] if dim > 0 else [[0, 0], []]
|
||||
)
|
||||
elif last_request_type == "text":
|
||||
safe = [] # direct text embeddings expectation is a flat list
|
||||
else:
|
||||
safe = [[0, int(embedding_dim) if embedding_dim > 0 else 0], []]
|
||||
rep_socket.send(msgpack.packb(safe, use_single_float=True))
|
||||
except Exception:
|
||||
pass
|
||||
else:
|
||||
logger.info("Shutdown in progress, ignoring ZMQ error")
|
||||
break
|
||||
finally:
|
||||
try:
|
||||
rep_socket.close(0)
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
context.term()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
except zmq.Again:
|
||||
logger.debug("ZMQ socket timeout, continuing to listen")
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.error(f"Error in ZMQ server loop: {e}")
|
||||
import traceback
|
||||
logger.info("ZMQ server thread exiting gracefully")
|
||||
|
||||
traceback.print_exc()
|
||||
socket.send(msgpack.packb([[], []]))
|
||||
# Add shutdown coordination
|
||||
shutdown_event = threading.Event()
|
||||
|
||||
zmq_thread = threading.Thread(target=zmq_server_thread, daemon=True)
|
||||
def shutdown_zmq_server():
|
||||
"""Gracefully shutdown ZMQ server."""
|
||||
logger.info("Initiating graceful shutdown...")
|
||||
shutdown_event.set()
|
||||
|
||||
if zmq_thread.is_alive():
|
||||
logger.info("Waiting for ZMQ thread to finish...")
|
||||
zmq_thread.join(timeout=5)
|
||||
if zmq_thread.is_alive():
|
||||
logger.warning("ZMQ thread did not finish in time")
|
||||
|
||||
# Clean up ZMQ resources
|
||||
try:
|
||||
# Note: socket and context are cleaned up by thread exit
|
||||
logger.info("ZMQ resources cleaned up")
|
||||
except Exception as e:
|
||||
logger.warning(f"Error cleaning ZMQ resources: {e}")
|
||||
|
||||
# Clean up other resources
|
||||
try:
|
||||
import gc
|
||||
|
||||
gc.collect()
|
||||
logger.info("Additional resources cleaned up")
|
||||
except Exception as e:
|
||||
logger.warning(f"Error cleaning additional resources: {e}")
|
||||
|
||||
logger.info("Graceful shutdown completed")
|
||||
sys.exit(0)
|
||||
|
||||
# Register signal handlers within this function scope
|
||||
import signal
|
||||
|
||||
def signal_handler(sig, frame):
|
||||
logger.info(f"Received signal {sig}, shutting down gracefully...")
|
||||
shutdown_zmq_server()
|
||||
|
||||
signal.signal(signal.SIGTERM, signal_handler)
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
|
||||
# Pass shutdown_event to ZMQ thread
|
||||
zmq_thread = threading.Thread(
|
||||
target=lambda: zmq_server_thread_with_shutdown(shutdown_event),
|
||||
daemon=False, # Not daemon - we want to wait for it
|
||||
)
|
||||
zmq_thread.start()
|
||||
logger.info(f"Started HNSW ZMQ server thread on port {zmq_port}")
|
||||
|
||||
# Keep the main thread alive
|
||||
try:
|
||||
while True:
|
||||
time.sleep(1)
|
||||
while not shutdown_event.is_set():
|
||||
time.sleep(0.1) # Check shutdown more frequently
|
||||
except KeyboardInterrupt:
|
||||
logger.info("HNSW Server shutting down...")
|
||||
shutdown_zmq_server()
|
||||
return
|
||||
|
||||
# If we reach here, shutdown was triggered by signal
|
||||
logger.info("Main loop exited, process should be shutting down")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import signal
|
||||
import sys
|
||||
|
||||
def signal_handler(sig, frame):
|
||||
logger.info(f"Received signal {sig}, shutting down gracefully...")
|
||||
sys.exit(0)
|
||||
|
||||
# Register signal handlers for graceful shutdown
|
||||
signal.signal(signal.SIGTERM, signal_handler)
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
|
||||
|
||||
# Signal handlers are now registered within create_hnsw_embedding_server
|
||||
|
||||
parser = argparse.ArgumentParser(description="HNSW Embedding service")
|
||||
parser.add_argument("--zmq-port", type=int, default=5555, help="ZMQ port to run on")
|
||||
parser.add_argument(
|
||||
@@ -299,7 +412,7 @@ if __name__ == "__main__":
|
||||
"--embedding-mode",
|
||||
type=str,
|
||||
default="sentence-transformers",
|
||||
choices=["sentence-transformers", "openai", "mlx"],
|
||||
choices=["sentence-transformers", "openai", "mlx", "ollama"],
|
||||
help="Embedding backend mode",
|
||||
)
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user