Compare commits

...

47 Commits

Author SHA1 Message Date
Andy Lee
fcbcde1ea8 feat: implement smart memory configuration for DiskANN
- Add intelligent memory calculation based on data size and system specs
- search_memory_maximum: 1/10 of embedding size (controls PQ compression)
- build_memory_maximum: 50% of available RAM (controls sharding)
- Provides optimal balance between performance and memory usage
- Automatic fallback to default values if parameters are explicitly provided
2025-08-03 22:54:08 -07:00
Andy Lee
54df6310c5 fix: diskann build and prevent termination from hanging
- Fix OpenMP library linking in DiskANN CMake configuration
- Add timeout protection for HuggingFace model loading to prevent hangs
- Improve embedding server process termination with better timeouts
- Make DiskANN backend default enabled alongside HNSW
- Update documentation to reflect both backends included by default
2025-08-03 21:16:52 -07:00
yichuan520030910320
19bcc07814 change readme discription 2025-07-28 20:52:45 -07:00
yichuan520030910320
8356e3c668 changr to openai main cli 2025-07-28 17:39:14 -07:00
GitHub Actions
08eac5c821 chore: release v0.1.16 2025-07-29 00:15:18 +00:00
Andy Lee
4671ed9b36 Fix macos ABI by using system default clang (#11)
* fix: auto-detect normalized embeddings and use cosine distance

- Add automatic detection for normalized embedding models (OpenAI, Voyage AI, Cohere)
- Automatically set distance_metric='cosine' for normalized embeddings
- Add warnings when using non-optimal distance metrics
- Implement manual L2 normalization in HNSW backend (custom Faiss build lacks normalize_L2)
- Fix DiskANN zmq_port compatibility with lazy loading strategy
- Add documentation for normalized embeddings feature

This fixes the low accuracy issue when using OpenAI text-embedding-3-small model with default MIPS metric.

* style: format

* feat: add OpenAI embeddings support to google_history_reader_leann.py

- Add --embedding-model and --embedding-mode arguments
- Support automatic detection of normalized embeddings
- Works correctly with cosine distance for OpenAI embeddings

* feat: add --use-existing-index option to google_history_reader_leann.py

- Allow using existing index without rebuilding
- Useful for testing pre-built indices

* fix: Improve OpenAI embeddings handling in HNSW backend

* fix: improve macOS C++ compatibility and add CI tests

* refactor: improve test structure and fix main_cli example

- Move pytest configuration from pytest.ini to pyproject.toml
- Remove unnecessary run_tests.py script (use test extras instead)
- Fix main_cli_example.py to properly use command line arguments for LLM config
- Add test_readme_examples.py to test code examples from README
- Refactor tests to use pytest fixtures and parametrization
- Update test documentation to reflect new structure
- Set proper environment variables in CI for test execution

* fix: add --distance-metric support to DiskANN embedding server and remove obsolete macOS ABI test markers

- Add --distance-metric parameter to diskann_embedding_server.py for consistency with other backends
- Remove pytest.skip and pytest.xfail markers for macOS C++ ABI issues as they have been fixed
- Fix test assertions to handle SearchResult objects correctly
- All tests now pass on macOS with the C++ ABI compatibility fixes

* chore: update lock file with test dependencies

* docs: remove obsolete C++ ABI compatibility warnings

- Remove outdated macOS C++ compatibility warnings from README
- Simplify CI workflow by removing macOS-specific failure handling
- All tests now pass consistently on macOS after ABI fixes

* fix: update macOS deployment target for DiskANN to 13.3

- DiskANN uses sgesdd_ LAPACK function which is only available on macOS 13.3+
- Update MACOSX_DEPLOYMENT_TARGET from 11.0 to 13.3 for DiskANN builds
- This fixes the compilation error on GitHub Actions macOS runners

* fix: align Python version requirements to 3.9

- Update root project to support Python 3.9, matching subpackages
- Restore macOS Python 3.9 support in CI
- This fixes the CI failure for Python 3.9 environments

* fix: handle MPS memory issues in CI tests

- Use smaller MiniLM-L6-v2 model (384 dimensions) for README tests in CI
- Skip other memory-intensive tests in CI environment
- Add minimal CI tests that don't require model loading
- Set CI environment variable and disable MPS fallback
- Ensure README examples always run correctly in CI

* fix: remove Python 3.10+ dependencies for compatibility

- Comment out llama-index-readers-docling and llama-index-node-parser-docling
- These packages require Python >= 3.10 and were causing CI failures on Python 3.9
- Regenerate uv.lock file to resolve dependency conflicts

* fix: use virtual environment in CI instead of system packages

- uv-managed Python environments don't allow --system installs
- Create and activate virtual environment before installing packages
- Update all CI steps to use the virtual environment

* add some env in ci

* fix: use --find-links to install platform-specific wheels

- Let uv automatically select the correct wheel for the current platform
- Fixes error when trying to install macOS wheels on Linux
- Simplifies the installation logic

* fix: disable OpenMP parallelism in CI to avoid libomp crashes

- Set OMP_NUM_THREADS=1 to avoid OpenMP thread synchronization issues
- Set MKL_NUM_THREADS=1 for single-threaded MKL operations
- This prevents segfaults in LayerNorm on macOS CI runners
- Addresses the libomp compatibility issues with PyTorch on Apple Silicon

* skip several macos test because strange issue on ci

---------

Co-authored-by: yichuan520030910320 <yichuan_wang@berkeley.edu>
2025-07-28 17:14:42 -07:00
yichuan520030910320
055c086398 add ablation of embedding model compare 2025-07-28 14:43:42 -07:00
Andy Lee
d505dcc5e3 Fix/OpenAI embeddings cosine distance (#10)
* fix: auto-detect normalized embeddings and use cosine distance

- Add automatic detection for normalized embedding models (OpenAI, Voyage AI, Cohere)
- Automatically set distance_metric='cosine' for normalized embeddings
- Add warnings when using non-optimal distance metrics
- Implement manual L2 normalization in HNSW backend (custom Faiss build lacks normalize_L2)
- Fix DiskANN zmq_port compatibility with lazy loading strategy
- Add documentation for normalized embeddings feature

This fixes the low accuracy issue when using OpenAI text-embedding-3-small model with default MIPS metric.

* style: format

* feat: add OpenAI embeddings support to google_history_reader_leann.py

- Add --embedding-model and --embedding-mode arguments
- Support automatic detection of normalized embeddings
- Works correctly with cosine distance for OpenAI embeddings

* feat: add --use-existing-index option to google_history_reader_leann.py

- Allow using existing index without rebuilding
- Useful for testing pre-built indices

* fix: Improve OpenAI embeddings handling in HNSW backend
2025-07-28 14:35:49 -07:00
Andy Lee
261006c36a docs: revert 2025-07-27 22:07:36 -07:00
GitHub Actions
b2eba23e21 chore: release v0.1.15 2025-07-28 05:05:30 +00:00
yichuan520030910320
e9ee687472 nit: fix readme 2025-07-27 21:56:05 -07:00
yichuan520030910320
6f5d5e4a77 fix some readme 2025-07-27 21:50:09 -07:00
Andy Lee
5c8921673a fix: auto-detect normalized embeddings and use cosine distance (#8)
* fix: auto-detect normalized embeddings and use cosine distance

- Add automatic detection for normalized embedding models (OpenAI, Voyage AI, Cohere)
- Automatically set distance_metric='cosine' for normalized embeddings
- Add warnings when using non-optimal distance metrics
- Implement manual L2 normalization in HNSW backend (custom Faiss build lacks normalize_L2)
- Fix DiskANN zmq_port compatibility with lazy loading strategy
- Add documentation for normalized embeddings feature

This fixes the low accuracy issue when using OpenAI text-embedding-3-small model with default MIPS metric.

* style: format
2025-07-27 21:19:29 -07:00
yichuan520030910320
e9d2d420bd fix some readme 2025-07-27 20:48:23 -07:00
yichuan520030910320
ebabfad066 Merge branch 'main' of https://github.com/yichuan-w/LEANN 2025-07-27 20:44:36 -07:00
yichuan520030910320
e6f612b5e8 fix install and readme 2025-07-27 20:44:28 -07:00
Andy Lee
51c41acd82 docs: add comprehensive CONTRIBUTING.md guide with pre-commit setup 2025-07-27 20:40:42 -07:00
yichuan520030910320
455f93fb7c fix emaple and add pypi example 2025-07-27 18:20:13 -07:00
yichuan520030910320
48207c3b69 add pypi example 2025-07-27 17:08:49 -07:00
yichuan520030910320
4de1caa40f fix redame install method 2025-07-27 17:00:28 -07:00
yichuan520030910320
60eaa8165c fix precommit and fix redame install method 2025-07-27 16:36:30 -07:00
yichuan520030910320
c1a5d0c624 fix readme 2025-07-27 02:24:28 -07:00
yichuan520030910320
af1790395a fix ruff errors and formatting 2025-07-27 02:22:54 -07:00
yichuan520030910320
383c6d8d7e add clear instructions 2025-07-27 02:19:27 -07:00
yichuan520030910320
bc0d839693 Merge branch 'main' of https://github.com/yichuan-w/LEANN 2025-07-27 02:07:41 -07:00
yichuan520030910320
8596562de5 add pip install option to README 2025-07-27 02:06:40 -07:00
GitHub Actions
5d09586853 chore: release v0.1.14 2025-07-27 08:50:56 +00:00
Andy Lee
a7cba078dd chore: consolidate essential fixes and add pre-commit hooks
- Add pre-commit configuration with ruff and black
- Fix lint CI job to use uv tool install instead of sync
- Add essential LlamaIndex dependencies to leann-core

Co-Authored-By: Yichuan Wang <73766326+yichuan-w@users.noreply.github.com>
2025-07-27 01:24:24 -07:00
Andy Lee
b3e9ee96fa fix: resolve all ruff linting errors and add lint CI check
- Fix ambiguous fullwidth characters (commas, parentheses) in strings and comments
- Replace Chinese comments with English equivalents
- Fix unused imports with proper noqa annotations for intentional imports
- Fix bare except clauses with specific exception types
- Fix redefined variables and undefined names
- Add ruff noqa annotations for generated protobuf files
- Add lint and format check to GitHub Actions CI pipeline
2025-07-26 22:38:13 -07:00
yichuan520030910320
8537a6b17e default args change 2025-07-26 21:51:14 -07:00
yichuan520030910320
7c8d7dc5c2 tones down 2025-07-26 21:47:55 -07:00
yichuan520030910320
8e23d663e6 Merge branch 'main' of https://github.com/yichuan-w/LEANN 2025-07-26 21:46:02 -07:00
yichuan520030910320
8a3994bf80 update colab now it works perfect 2025-07-26 21:45:56 -07:00
GitHub Actions
8375f601ba chore: release v0.1.13 2025-07-27 01:08:17 +00:00
yichuan520030910320
c87c0fe662 update colab install & fix colab path 2025-07-26 18:07:31 -07:00
yichuan520030910320
73927b68ef Merge branch 'main' of https://github.com/yichuan-w/LEANN 2025-07-26 17:09:55 -07:00
yichuan520030910320
cc1a62e5aa update pytoml version again 2025-07-26 17:09:45 -07:00
GitHub Actions
802020cb41 chore: release v0.1.12 2025-07-26 23:35:28 +00:00
yichuan520030910320
cdb92f7cf4 update pytoml version && fix colab env && fix pdf extract in pip 2025-07-26 16:33:13 -07:00
yichuan520030910320
dc69bdec00 Merge branch 'main' of https://github.com/yichuan-w/LEANN 2025-07-25 17:54:43 -07:00
yichuan520030910320
98073e9868 update missing pkg 2025-07-25 17:54:21 -07:00
GitHub Actions
cf2ef48967 chore: release v0.1.11 2025-07-26 00:12:37 +00:00
yichuan520030910320
0692bbf7a2 change workflow 2025-07-25 17:11:56 -07:00
GitHub Actions
52584a171f chore: release v0.1.10 2025-07-25 23:12:16 +00:00
Andy Lee
efd6b5324b fix: add protobuf as a dependency for DiskANN backend
- Fixes 'No module named google' error when starting DiskANN embedding server
- Prevents users from having to manually install protobuf
2025-07-25 16:10:25 -07:00
Andy Lee
2baaa4549b fix: handle relative paths in HNSW embedding server metadata
- Convert relative paths to absolute paths based on metadata file location
- Fixes FileNotFoundError when starting embedding server
- Resolves issue with passages file not found in different working directories
2025-07-25 16:09:53 -07:00
Andy Lee
35310ddd52 fix: pure Python packages not building due to ubuntu-latest check
The build workflow was checking for matrix.os == 'ubuntu-latest',
but we changed the matrix to use 'ubuntu-22.04', causing the
pure Python packages (leann-core and leann) to never be built.

Changed to use pattern matching [[ == ubuntu-* ]] to match any
Ubuntu version.

This explains why v0.1.9 only published the C++ backend packages
but not the pure Python packages.
2025-07-25 15:14:21 -07:00
85 changed files with 6325 additions and 4059 deletions

View File

@@ -10,7 +10,36 @@ on:
default: '' default: ''
jobs: jobs:
lint:
name: Lint and Format Check
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
with:
ref: ${{ inputs.ref }}
- name: Setup Python
uses: actions/setup-python@v5
with:
python-version: '3.11'
- name: Install uv
uses: astral-sh/setup-uv@v4
- name: Install ruff
run: |
uv tool install ruff
- name: Run ruff check
run: |
ruff check .
- name: Run ruff format check
run: |
ruff format --check .
build: build:
needs: lint
name: Build ${{ matrix.os }} Python ${{ matrix.python }} name: Build ${{ matrix.os }} Python ${{ matrix.python }}
strategy: strategy:
matrix: matrix:
@@ -68,7 +97,8 @@ jobs:
- name: Install system dependencies (macOS) - name: Install system dependencies (macOS)
if: runner.os == 'macOS' if: runner.os == 'macOS'
run: | run: |
brew install llvm libomp boost protobuf zeromq # Don't install LLVM, use system clang for better compatibility
brew install libomp boost protobuf zeromq
- name: Install build dependencies - name: Install build dependencies
run: | run: |
@@ -82,7 +112,7 @@ jobs:
- name: Build packages - name: Build packages
run: | run: |
# Build core (platform independent) # Build core (platform independent)
if [ "${{ matrix.os }}" == "ubuntu-latest" ]; then if [[ "${{ matrix.os }}" == ubuntu-* ]]; then
cd packages/leann-core cd packages/leann-core
uv build uv build
cd ../.. cd ../..
@@ -91,7 +121,11 @@ jobs:
# Build HNSW backend # Build HNSW backend
cd packages/leann-backend-hnsw cd packages/leann-backend-hnsw
if [ "${{ matrix.os }}" == "macos-latest" ]; then if [ "${{ matrix.os }}" == "macos-latest" ]; then
CC=$(brew --prefix llvm)/bin/clang CXX=$(brew --prefix llvm)/bin/clang++ uv build --wheel --python python # Use system clang instead of homebrew LLVM for better compatibility
export CC=clang
export CXX=clang++
export MACOSX_DEPLOYMENT_TARGET=11.0
uv build --wheel --python python
else else
uv build --wheel --python python uv build --wheel --python python
fi fi
@@ -100,14 +134,19 @@ jobs:
# Build DiskANN backend # Build DiskANN backend
cd packages/leann-backend-diskann cd packages/leann-backend-diskann
if [ "${{ matrix.os }}" == "macos-latest" ]; then if [ "${{ matrix.os }}" == "macos-latest" ]; then
CC=$(brew --prefix llvm)/bin/clang CXX=$(brew --prefix llvm)/bin/clang++ uv build --wheel --python python # Use system clang instead of homebrew LLVM for better compatibility
export CC=clang
export CXX=clang++
# DiskANN requires macOS 13.3+ for sgesdd_ LAPACK function
export MACOSX_DEPLOYMENT_TARGET=13.3
uv build --wheel --python python
else else
uv build --wheel --python python uv build --wheel --python python
fi fi
cd ../.. cd ../..
# Build meta package (platform independent) # Build meta package (platform independent)
if [ "${{ matrix.os }}" == "ubuntu-latest" ]; then if [[ "${{ matrix.os }}" == ubuntu-* ]]; then
cd packages/leann cd packages/leann
uv build uv build
cd ../.. cd ../..
@@ -160,6 +199,51 @@ jobs:
echo "📦 Built packages:" echo "📦 Built packages:"
find packages/*/dist -name "*.whl" -o -name "*.tar.gz" | sort find packages/*/dist -name "*.whl" -o -name "*.tar.gz" | sort
- name: Install built packages for testing
run: |
# Create a virtual environment
uv venv
source .venv/bin/activate || source .venv/Scripts/activate
# Install the built wheels
# Use --find-links to let uv choose the correct wheel for the platform
if [[ "${{ matrix.os }}" == ubuntu-* ]]; then
uv pip install leann-core --find-links packages/leann-core/dist
uv pip install leann --find-links packages/leann/dist
fi
uv pip install leann-backend-hnsw --find-links packages/leann-backend-hnsw/dist
uv pip install leann-backend-diskann --find-links packages/leann-backend-diskann/dist
# Install test dependencies using extras
uv pip install -e ".[test]"
- name: Run tests with pytest
env:
CI: true # Mark as CI environment to skip memory-intensive tests
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
HF_HUB_DISABLE_SYMLINKS: 1
TOKENIZERS_PARALLELISM: false
PYTORCH_ENABLE_MPS_FALLBACK: 0 # Disable MPS on macOS CI to avoid memory issues
OMP_NUM_THREADS: 1 # Disable OpenMP parallelism to avoid libomp crashes
MKL_NUM_THREADS: 1 # Single thread for MKL operations
run: |
# Activate virtual environment
source .venv/bin/activate || source .venv/Scripts/activate
# Run all tests
pytest tests/
- name: Run sanity checks (optional)
run: |
# Activate virtual environment
source .venv/bin/activate || source .venv/Scripts/activate
# Run distance function tests if available
if [ -f test/sanity_checks/test_distance_functions.py ]; then
echo "Running distance function sanity checks..."
python test/sanity_checks/test_distance_functions.py || echo "⚠️ Distance function test failed, continuing..."
fi
- name: Upload artifacts - name: Upload artifacts
uses: actions/upload-artifact@v4 uses: actions/upload-artifact@v4
with: with:

View File

@@ -22,11 +22,14 @@ jobs:
- name: Validate version - name: Validate version
run: | run: |
if ! [[ "${{ inputs.version }}" =~ ^[0-9]+\.[0-9]+\.[0-9]+$ ]]; then # Remove 'v' prefix if present for validation
echo "❌ Invalid version format" VERSION_CLEAN="${{ inputs.version }}"
VERSION_CLEAN="${VERSION_CLEAN#v}"
if ! [[ "$VERSION_CLEAN" =~ ^[0-9]+\.[0-9]+\.[0-9]+$ ]]; then
echo "❌ Invalid version format. Expected format: X.Y.Z or vX.Y.Z"
exit 1 exit 1
fi fi
echo "✅ Version format valid" echo "✅ Version format valid: ${{ inputs.version }}"
- name: Update versions and push - name: Update versions and push
id: push id: push
@@ -57,7 +60,7 @@ jobs:
needs: update-version needs: update-version
uses: ./.github/workflows/build-reusable.yml uses: ./.github/workflows/build-reusable.yml
with: with:
ref: ${{ needs.update-version.outputs.commit-sha }} ref: 'main'
publish: publish:
name: Publish and Release name: Publish and Release
@@ -70,7 +73,7 @@ jobs:
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
with: with:
ref: ${{ needs.update-version.outputs.commit-sha }} ref: 'main'
- name: Download all artifacts - name: Download all artifacts
uses: actions/download-artifact@v4 uses: actions/download-artifact@v4

2
.gitignore vendored
View File

@@ -86,3 +86,5 @@ packages/leann-backend-diskann/third_party/DiskANN/_deps/
*.passages.json *.passages.json
batchtest.py batchtest.py
tests/__pytest_cache__/
tests/__pycache__/

16
.pre-commit-config.yaml Normal file
View File

@@ -0,0 +1,16 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.5.0
hooks:
- id: trailing-whitespace
- id: end-of-file-fixer
- id: check-yaml
- id: check-added-large-files
- id: check-merge-conflict
- id: debug-statements
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.2.1
hooks:
- id: ruff
- id: ruff-format

175
README.md
View File

@@ -12,7 +12,7 @@
The smallest vector index in the world. RAG Everything with LEANN! The smallest vector index in the world. RAG Everything with LEANN!
</h2> </h2>
LEANN is a revolutionary vector database that democratizes personal AI. Transform your laptop into a powerful RAG system that can index and search through millions of documents while using **97% less storage** than traditional solutions **without accuracy loss**. LEANN is an innovative vector database that democratizes personal AI. Transform your laptop into a powerful RAG system that can index and search through millions of documents while using **97% less storage** than traditional solutions **without accuracy loss**.
LEANN achieves this through *graph-based selective recomputation* with *high-degree preserving pruning*, computing embeddings on-demand instead of storing them all. [Illustration Fig →](#-architecture--how-it-works) | [Paper →](https://arxiv.org/abs/2506.08276) LEANN achieves this through *graph-based selective recomputation* with *high-degree preserving pruning*, computing embeddings on-demand instead of storing them all. [Illustration Fig →](#-architecture--how-it-works) | [Paper →](https://arxiv.org/abs/2506.08276)
@@ -33,12 +33,46 @@ LEANN achieves this through *graph-based selective recomputation* with *high-deg
🪶 **Lightweight:** Graph-based recomputation eliminates heavy embedding storage, while smart graph pruning and CSR format minimize graph storage overhead. Always less storage, less memory usage! 🪶 **Lightweight:** Graph-based recomputation eliminates heavy embedding storage, while smart graph pruning and CSR format minimize graph storage overhead. Always less storage, less memory usage!
📦 **Portable:** Transfer your entire knowledge base between devices (even with others) with minimal cost - your personal AI memory travels with you.
📈 **Scalability:** Handle messy personal data that would crash traditional vector DBs, easily managing your growing personalized data and agent generated memory! 📈 **Scalability:** Handle messy personal data that would crash traditional vector DBs, easily managing your growing personalized data and agent generated memory!
**No Accuracy Loss:** Maintain the same search quality as heavyweight solutions while using 97% less storage. **No Accuracy Loss:** Maintain the same search quality as heavyweight solutions while using 97% less storage.
## Installation ## Installation
> `pip leann` coming soon!
<details>
<summary><strong>📦 Prerequisites: Install uv (if you don't have it)</strong></summary>
Install uv first if you don't have it:
```bash
curl -LsSf https://astral.sh/uv/install.sh | sh
```
📖 [Detailed uv installation methods →](https://docs.astral.sh/uv/getting-started/installation/#installation-methods)
</details>
LEANN provides two installation methods: **pip install** (quick and easy) and **build from source** (recommended for development).
### 🚀 Quick Install (Recommended for most users)
Clone the repository to access all examples and install LEANN from [PyPI](https://pypi.org/project/leann/) to run them immediately:
```bash
git clone git@github.com:yichuan-w/LEANN.git leann
cd leann
uv venv
source .venv/bin/activate
uv pip install leann
```
### 🔧 Build from Source (Recommended for development)
```bash ```bash
git clone git@github.com:yichuan-w/LEANN.git leann git clone git@github.com:yichuan-w/LEANN.git leann
cd leann cd leann
@@ -48,27 +82,65 @@ git submodule update --init --recursive
**macOS:** **macOS:**
```bash ```bash
brew install llvm libomp boost protobuf zeromq pkgconf brew install llvm libomp boost protobuf zeromq pkgconf
# Install with HNSW backend (default, recommended for most users)
# Install uv first if you don't have it:
# curl -LsSf https://astral.sh/uv/install.sh | sh
# See: https://docs.astral.sh/uv/getting-started/installation/#installation-methods
CC=$(brew --prefix llvm)/bin/clang CXX=$(brew --prefix llvm)/bin/clang++ uv sync CC=$(brew --prefix llvm)/bin/clang CXX=$(brew --prefix llvm)/bin/clang++ uv sync
``` ```
**Linux:** **Linux:**
```bash ```bash
sudo apt-get install libomp-dev libboost-all-dev protobuf-compiler libabsl-dev libmkl-full-dev libaio-dev libzmq3-dev sudo apt-get install libomp-dev libboost-all-dev protobuf-compiler libabsl-dev libmkl-full-dev libaio-dev libzmq3-dev
# Install with HNSW backend (default, recommended for most users)
uv sync uv sync
``` ```
**Ollama Setup (Recommended for full privacy):**
> *You can skip this installation if you only want to use OpenAI API for generation.*
## Quick Start
Our declarative API makes RAG as easy as writing a config file.
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/yichuan-w/LEANN/blob/main/demo.ipynb) [Try in this ipynb file →](demo.ipynb)
```python
from leann import LeannBuilder, LeannSearcher, LeannChat
from pathlib import Path
INDEX_PATH = str(Path("./").resolve() / "demo.leann")
# Build an index
builder = LeannBuilder(backend_name="hnsw")
builder.add_text("LEANN saves 97% storage compared to traditional vector databases.")
builder.add_text("Tung Tung Tung Sahur called—they need their bananacrocodile hybrid back")
builder.build_index(INDEX_PATH)
# Search
searcher = LeannSearcher(INDEX_PATH)
results = searcher.search("fantastical AI-generated creatures", top_k=1)
# Chat with your data
chat = LeannChat(INDEX_PATH, llm_config={"type": "hf", "model": "Qwen/Qwen3-0.6B"})
response = chat.ask("How much storage does LEANN save?", top_k=1)
```
## RAG on Everything!
LEANN supports RAG on various data sources including documents (.pdf, .txt, .md), Apple Mail, Google Search History, WeChat, and more.
> **Generation Model Setup**
> LEANN supports multiple LLM providers for text generation (OpenAI API, HuggingFace, Ollama).
<details>
<summary><strong>🔑 OpenAI API Setup (Default)</strong></summary>
Set your OpenAI API key as an environment variable:
```bash
export OPENAI_API_KEY="your-api-key-here"
```
</details>
<details>
<summary><strong>🔧 Ollama Setup (Recommended for full privacy)</strong></summary>
**macOS:** **macOS:**
@@ -80,6 +152,7 @@ ollama pull llama3.2:1b
``` ```
**Linux:** **Linux:**
```bash ```bash
# Install Ollama # Install Ollama
curl -fsSL https://ollama.ai/install.sh | sh curl -fsSL https://ollama.ai/install.sh | sh
@@ -91,43 +164,7 @@ ollama serve &
ollama pull llama3.2:1b ollama pull llama3.2:1b
``` ```
## Quick Start in 30s </details>
Our declarative API makes RAG as easy as writing a config file.
[Try in this ipynb file →](demo.ipynb) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/yichuan-w/LEANN/blob/main/demo.ipynb)
```python
from leann.api import LeannBuilder, LeannSearcher, LeannChat
# 1. Build the index (no embeddings stored!)
builder = LeannBuilder(backend_name="hnsw")
builder.add_text("C# is a powerful programming language")
builder.add_text("Python is a powerful programming language and it is very popular")
builder.add_text("Machine learning transforms industries")
builder.add_text("Neural networks process complex data")
builder.add_text("Leann is a great storage saving engine for RAG on your MacBook")
builder.build_index("knowledge.leann")
# 2. Search with real-time embeddings
searcher = LeannSearcher("knowledge.leann")
results = searcher.search("programming languages", top_k=2)
# 3. Chat with LEANN using retrieved results
llm_config = {
"type": "ollama",
"model": "llama3.2:1b"
}
chat = LeannChat(index_path="knowledge.leann", llm_config=llm_config)
response = chat.ask(
"Compare the two retrieved programming languages and say which one is more popular today.",
top_k=2,
)
```
## RAG on Everything!
LEANN supports RAG on various data sources including documents (.pdf, .txt, .md), Apple Mail, Google Search History, WeChat, and more.
### 📄 Personal Data Manager: Process Any Documents (.pdf, .txt, .md)! ### 📄 Personal Data Manager: Process Any Documents (.pdf, .txt, .md)!
@@ -137,35 +174,46 @@ Ask questions directly about your personal PDFs, documents, and any directory co
<img src="videos/paper_clear.gif" alt="LEANN Document Search Demo" width="600"> <img src="videos/paper_clear.gif" alt="LEANN Document Search Demo" width="600">
</p> </p>
The example below asks a question about summarizing two papers (uses default data in `examples/data`): The example below asks a question about summarizing two papers (uses default data in `examples/data`) and this is the easiest example to run here:
```bash ```bash
# Drop your PDFs, .txt, .md files into examples/data/
uv run ./examples/main_cli_example.py
```
```
# Or use python directly
source .venv/bin/activate source .venv/bin/activate
python ./examples/main_cli_example.py python ./examples/main_cli_example.py
``` ```
<details>
<summary><strong>📋 Click to expand: User Configurable Arguments</strong></summary>
```bash
# Use custom index directory
python examples/main_cli_example.py --index-dir "./my_custom_index"
# Use custom data directory
python examples/main_cli_example.py --data-dir "./my_documents"
# Ask a specific question
python examples/main_cli_example.py --query "What are the main findings in these papers?"
```
</details>
### 📧 Your Personal Email Secretary: RAG on Apple Mail! ### 📧 Your Personal Email Secretary: RAG on Apple Mail!
> **Note:** The examples below currently support macOS only. Windows support coming soon.
<p align="center"> <p align="center">
<img src="videos/mail_clear.gif" alt="LEANN Email Search Demo" width="600"> <img src="videos/mail_clear.gif" alt="LEANN Email Search Demo" width="600">
</p> </p>
**Note:** You need to grant full disk access to your terminal/VS Code in System Preferences → Privacy & Security → Full Disk Access. **Note:** You need to grant full disk access to your terminal/VS Code in System Preferences → Privacy & Security → Full Disk Access.
```bash ```bash
python examples/mail_reader_leann.py --query "What's the food I ordered by doordash or Uber eat mostly?" python examples/mail_reader_leann.py --query "What's the food I ordered by DoorDash or Uber Eats mostly?"
``` ```
**780K email chunks → 78MB storage** Finally, search your email like you search Google. **780K email chunks → 78MB storage.** Finally, search your email like you search Google.
<details> <details>
<summary><strong>📋 Click to expand: Command Examples</strong></summary> <summary><strong>📋 Click to expand: User Configurable Arguments</strong></summary>
```bash ```bash
# Use default mail path (works for most macOS setups) # Use default mail path (works for most macOS setups)
@@ -195,7 +243,7 @@ Once the index is built, you can ask questions like:
- "Show me emails about travel expenses" - "Show me emails about travel expenses"
</details> </details>
### 🔍 Time Machine for the Web: RAG Your Entire Google Browser History! ### 🔍 Time Machine for the Web: RAG Your Entire Chrome Browser History!
<p align="center"> <p align="center">
<img src="videos/google_clear.gif" alt="LEANN Browser History Search Demo" width="600"> <img src="videos/google_clear.gif" alt="LEANN Browser History Search Demo" width="600">
@@ -207,7 +255,7 @@ python examples/google_history_reader_leann.py --query "Tell me my browser histo
**38K browser entries → 6MB storage.** Your browser history becomes your personal search engine. **38K browser entries → 6MB storage.** Your browser history becomes your personal search engine.
<details> <details>
<summary><strong>📋 Click to expand: Command Examples</strong></summary> <summary><strong>📋 Click to expand: User Configurable Arguments</strong></summary>
```bash ```bash
# Use default Chrome profile (auto-finds all profiles) # Use default Chrome profile (auto-finds all profiles)
@@ -284,7 +332,7 @@ Failed to find or export WeChat data. Exiting.
</details> </details>
<details> <details>
<summary><strong>📋 Click to expand: Command Examples</strong></summary> <summary><strong>📋 Click to expand: User Configurable Arguments</strong></summary>
```bash ```bash
# Use default settings (recommended for first run) # Use default settings (recommended for first run)
@@ -441,10 +489,10 @@ If you find Leann useful, please cite:
## ✨ [Detailed Features →](docs/features.md) ## ✨ [Detailed Features →](docs/features.md)
## 🤝 [Contributing →](docs/contributing.md) ## 🤝 [CONTRIBUTING →](docs/CONTRIBUTING.md)
## [FAQ →](docs/faq.md) ## [FAQ →](docs/faq.md)
## 📈 [Roadmap →](docs/roadmap.md) ## 📈 [Roadmap →](docs/roadmap.md)
@@ -465,4 +513,3 @@ This work is done at [**Berkeley Sky Computing Lab**](https://sky.cs.berkeley.e
<p align="center"> <p align="center">
Made with ❤️ by the Leann team Made with ❤️ by the Leann team
</p> </p>

View File

@@ -4,7 +4,11 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"# Quick Start in 30s" "# Quick Start \n",
"\n",
"**Home GitHub Repository:** [LEANN on GitHub](https://github.com/yichuan-w/LEANN)\n",
"\n",
"**Important for Colab users:** Set your runtime type to T4 GPU for optimal performance. Go to Runtime → Change runtime type → Hardware accelerator → T4 GPU."
] ]
}, },
{ {
@@ -13,8 +17,25 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"# install this if you areusing colab\n", "# install this if you are using colab\n",
"! pip install leann" "! uv pip install leann-core leann-backend-hnsw --no-deps\n",
"! uv pip install leann --no-deps\n",
"# For Colab environment, we need to set some environment variables\n",
"import os\n",
"\n",
"os.environ[\"LEANN_LOG_LEVEL\"] = \"INFO\" # Enable more detailed logging"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from pathlib import Path\n",
"\n",
"INDEX_DIR = Path(\"./\").resolve()\n",
"INDEX_PATH = str(INDEX_DIR / \"demo.leann\")"
] ]
}, },
{ {
@@ -26,91 +47,21 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 1, "execution_count": null,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [],
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO: Registering backend 'hnsw'\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/yichuan/Desktop/code/LEANN/leann/.venv/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n",
"INFO:sentence_transformers.SentenceTransformer:Load pretrained SentenceTransformer: facebook/contriever\n",
"WARNING:sentence_transformers.SentenceTransformer:No sentence-transformers model found with name facebook/contriever. Creating a new one with mean pooling.\n",
"Writing passages: 100%|██████████| 5/5 [00:00<00:00, 27887.66chunk/s]\n",
"Batches: 100%|██████████| 1/1 [00:00<00:00, 13.51it/s]\n",
"WARNING:leann_backend_hnsw.hnsw_backend:Converting data to float32, shape: (5, 768)\n",
"INFO:leann_backend_hnsw.hnsw_backend:INFO: Converting HNSW index to CSR-pruned format...\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"M: 64 for level: 0\n",
"Starting conversion: knowledge.index -> knowledge.csr.tmp\n",
"[0.00s] Reading Index HNSW header...\n",
"[0.00s] Header read: d=768, ntotal=5\n",
"[0.00s] Reading HNSW struct vectors...\n",
" Reading vector (dtype=<class 'numpy.float64'>, fmt='d')... Count=6, Bytes=48\n",
"[0.00s] Read assign_probas (6)\n",
" Reading vector (dtype=<class 'numpy.int32'>, fmt='i')... Count=7, Bytes=28\n",
"[0.11s] Read cum_nneighbor_per_level (7)\n",
" Reading vector (dtype=<class 'numpy.int32'>, fmt='i')... Count=5, Bytes=20\n",
"[0.21s] Read levels (5)\n",
"[0.30s] Probing for compact storage flag...\n",
"[0.30s] Found compact flag: False\n",
"[0.30s] Compact flag is False, reading original format...\n",
"[0.30s] Probing for potential extra byte before non-compact offsets...\n",
"[0.30s] Found and consumed an unexpected 0x00 byte.\n",
" Reading vector (dtype=<class 'numpy.uint64'>, fmt='Q')... Count=6, Bytes=48\n",
"[0.30s] Read offsets (6)\n",
"[0.40s] Attempting to read neighbors vector...\n",
" Reading vector (dtype=<class 'numpy.int32'>, fmt='i')... Count=320, Bytes=1280\n",
"[0.40s] Read neighbors (320)\n",
"[0.50s] Read scalar params (ep=4, max_lvl=0)\n",
"[0.50s] Checking for storage data...\n",
"[0.50s] Found storage fourcc: 49467849.\n",
"[0.50s] Converting to CSR format...\n",
"[0.50s] Conversion loop finished. \n",
"[0.50s] Running validation checks...\n",
" Checking total valid neighbor count...\n",
" OK: Total valid neighbors = 20\n",
" Checking final pointer indices...\n",
" OK: Final pointers match data size.\n",
"[0.50s] Deleting original neighbors and offsets arrays...\n",
" CSR Stats: |data|=20, |level_ptr|=10\n",
"[0.59s] Writing CSR HNSW graph data in FAISS-compatible order...\n",
" Pruning embeddings: Writing NULL storage marker.\n",
"[0.69s] Conversion complete.\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:leann_backend_hnsw.hnsw_backend:✅ CSR conversion successful.\n",
"INFO:leann_backend_hnsw.hnsw_backend:INFO: Replaced original index with CSR-pruned version at 'knowledge.index'\n"
]
}
],
"source": [ "source": [
"from leann.api import LeannBuilder\n", "from leann.api import LeannBuilder\n",
"\n", "\n",
"builder = LeannBuilder(backend_name=\"hnsw\")\n", "builder = LeannBuilder(backend_name=\"hnsw\")\n",
"builder.add_text(\"C# is a powerful programming language and it is good at game development\")\n", "builder.add_text(\"C# is a powerful programming language and it is good at game development\")\n",
"builder.add_text(\"Python is a powerful programming language and it is good at machine learning tasks\")\n", "builder.add_text(\n",
" \"Python is a powerful programming language and it is good at machine learning tasks\"\n",
")\n",
"builder.add_text(\"Machine learning transforms industries\")\n", "builder.add_text(\"Machine learning transforms industries\")\n",
"builder.add_text(\"Neural networks process complex data\")\n", "builder.add_text(\"Neural networks process complex data\")\n",
"builder.add_text(\"Leann is a great storage saving engine for RAG on your MacBook\")\n", "builder.add_text(\"Leann is a great storage saving engine for RAG on your MacBook\")\n",
"builder.build_index(\"knowledge.leann\")" "builder.build_index(INDEX_PATH)"
] ]
}, },
{ {
@@ -122,97 +73,13 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 2, "execution_count": null,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [],
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:leann.api:🔍 LeannSearcher.search() called:\n",
"INFO:leann.api: Query: 'programming languages'\n",
"INFO:leann.api: Top_k: 2\n",
"INFO:leann.api: Additional kwargs: {}\n",
"INFO:leann.embedding_server_manager:Port 5557 has incompatible server, trying next port...\n",
"INFO:leann.embedding_server_manager:Port 5558 has incompatible server, trying next port...\n",
"INFO:leann.embedding_server_manager:Port 5559 has incompatible server, trying next port...\n",
"INFO:leann.embedding_server_manager:Using port 5560 instead of 5557\n",
"INFO:leann.embedding_server_manager:Starting embedding server on port 5560...\n",
"INFO:leann.embedding_server_manager:Command: /Users/yichuan/Desktop/code/LEANN/leann/.venv/bin/python -m leann_backend_hnsw.hnsw_embedding_server --zmq-port 5560 --model-name facebook/contriever --passages-file knowledge.leann.meta.json\n",
"huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
"To disable this warning, you can either:\n",
"\t- Avoid using `tokenizers` before the fork if possible\n",
"\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
"INFO:leann.embedding_server_manager:Server process started with PID: 4574\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[read_HNSW - CSR NL v4] Reading metadata & CSR indices (manual offset)...\n",
"[read_HNSW NL v4] Read levels vector, size: 5\n",
"[read_HNSW NL v4] Reading Compact Storage format indices...\n",
"[read_HNSW NL v4] Read compact_level_ptr, size: 10\n",
"[read_HNSW NL v4] Read compact_node_offsets, size: 6\n",
"[read_HNSW NL v4] Read entry_point: 4, max_level: 0\n",
"[read_HNSW NL v4] Read storage fourcc: 0x6c6c756e\n",
"[read_HNSW NL v4 FIX] Detected FileIOReader. Neighbors size field offset: 326\n",
"[read_HNSW NL v4] Reading neighbors data into memory.\n",
"[read_HNSW NL v4] Read neighbors data, size: 20\n",
"[read_HNSW NL v4] Finished reading metadata and CSR indices.\n",
"INFO: Skipping external storage loading, since is_recompute is true.\n",
"INFO: Registering backend 'hnsw'\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:leann.embedding_server_manager:Embedding server is ready!\n",
"INFO:leann.api: Launching server time: 1.078078269958496 seconds\n",
"INFO:leann.embedding_server_manager:Existing server process (PID 4574) is compatible\n",
"INFO:sentence_transformers.SentenceTransformer:Load pretrained SentenceTransformer: facebook/contriever\n",
"WARNING:sentence_transformers.SentenceTransformer:No sentence-transformers model found with name facebook/contriever. Creating a new one with mean pooling.\n",
"INFO:leann.api: Generated embedding shape: (1, 768)\n",
"INFO:leann.api: Embedding time: 2.9307072162628174 seconds\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"ZmqDistanceComputer initialized: d=768, metric=0\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:leann.api: Search time: 0.27327895164489746 seconds\n",
"INFO:leann.api: Backend returned: labels=2 results\n",
"INFO:leann.api: Processing 2 passage IDs:\n",
"INFO:leann.api: 1. passage_id='0' -> SUCCESS: C# is a powerful programming language and it is good at game development...\n",
"INFO:leann.api: 2. passage_id='1' -> SUCCESS: Python is a powerful programming language and it is good at machine learning tasks...\n",
"INFO:leann.api: Final enriched results: 2 passages\n"
]
},
{
"data": {
"text/plain": [
"[SearchResult(id='0', score=np.float32(0.9874103), text='C# is a powerful programming language and it is good at game development', metadata={}),\n",
" SearchResult(id='1', score=np.float32(0.8922168), text='Python is a powerful programming language and it is good at machine learning tasks', metadata={})]"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [ "source": [
"from leann.api import LeannSearcher\n", "from leann.api import LeannSearcher\n",
"\n", "\n",
"searcher = LeannSearcher(\"knowledge.leann\")\n", "searcher = LeannSearcher(INDEX_PATH)\n",
"results = searcher.search(\"programming languages\", top_k=2)\n", "results = searcher.search(\"programming languages\", top_k=2)\n",
"results" "results"
] ]
@@ -228,79 +95,7 @@
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [],
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:leann.chat:Attempting to create LLM of type='hf' with model='Qwen/Qwen3-0.6B'\n",
"INFO:leann.chat:Initializing HFChat with model='Qwen/Qwen3-0.6B'\n",
"INFO:leann.chat:MPS is available. Using Apple Silicon GPU.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[read_HNSW - CSR NL v4] Reading metadata & CSR indices (manual offset)...\n",
"[read_HNSW NL v4] Read levels vector, size: 5\n",
"[read_HNSW NL v4] Reading Compact Storage format indices...\n",
"[read_HNSW NL v4] Read compact_level_ptr, size: 10\n",
"[read_HNSW NL v4] Read compact_node_offsets, size: 6\n",
"[read_HNSW NL v4] Read entry_point: 4, max_level: 0\n",
"[read_HNSW NL v4] Read storage fourcc: 0x6c6c756e\n",
"[read_HNSW NL v4 FIX] Detected FileIOReader. Neighbors size field offset: 326\n",
"[read_HNSW NL v4] Reading neighbors data into memory.\n",
"[read_HNSW NL v4] Read neighbors data, size: 20\n",
"[read_HNSW NL v4] Finished reading metadata and CSR indices.\n",
"INFO: Skipping external storage loading, since is_recompute is true.\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:leann.api:🔍 LeannSearcher.search() called:\n",
"INFO:leann.api: Query: 'Compare the two retrieved programming languages and tell me their advantages.'\n",
"INFO:leann.api: Top_k: 2\n",
"INFO:leann.api: Additional kwargs: {}\n",
"INFO:leann.embedding_server_manager:Port 5557 has incompatible server, trying next port...\n",
"INFO:leann.embedding_server_manager:Port 5558 has incompatible server, trying next port...\n",
"INFO:leann.embedding_server_manager:Port 5559 has incompatible server, trying next port...\n",
"INFO:leann.embedding_server_manager:Found compatible server on port 5560\n",
"INFO:leann.embedding_server_manager:Using existing compatible server on port 5560\n",
"INFO:leann.api: Launching server time: 0.04932403564453125 seconds\n",
"INFO:leann.embedding_server_manager:Found compatible server on port 5560\n",
"INFO:leann.embedding_server_manager:Using existing compatible server on port 5560\n",
"INFO:leann.api: Generated embedding shape: (1, 768)\n",
"INFO:leann.api: Embedding time: 0.06902289390563965 seconds\n",
"INFO:leann.api: Search time: 0.026793241500854492 seconds\n",
"INFO:leann.api: Backend returned: labels=2 results\n",
"INFO:leann.api: Processing 2 passage IDs:\n",
"INFO:leann.api: 1. passage_id='0' -> SUCCESS: C# is a powerful programming language and it is good at game development...\n",
"INFO:leann.api: 2. passage_id='1' -> SUCCESS: Python is a powerful programming language and it is good at machine learning tasks...\n",
"INFO:leann.api: Final enriched results: 2 passages\n",
"INFO:leann.chat:Generating with HuggingFace model, config: {'max_new_tokens': 128, 'temperature': 0.7, 'top_p': 0.9, 'do_sample': True, 'pad_token_id': 151645, 'eos_token_id': 151645}\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"ZmqDistanceComputer initialized: d=768, metric=0\n"
]
},
{
"data": {
"text/plain": [
"\"<think>\\n\\n</think>\\n\\nBased on the context provided, here's a comparison of the two retrieved programming languages:\\n\\n**C#** is known for being a powerful programming language and is well-suited for game development. It is often used in game development and is popular among developers working on Windows applications.\\n\\n**Python**, on the other hand, is also a powerful language and is well-suited for machine learning tasks. It is widely used for data analysis, scientific computing, and other applications that require handling large datasets or performing complex calculations.\\n\\n**Advantages**:\\n- C#: Strong for game development and cross-platform compatibility.\\n- Python: Strong for\""
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [ "source": [
"from leann.api import LeannChat\n", "from leann.api import LeannChat\n",
"\n", "\n",
@@ -309,11 +104,11 @@
" \"model\": \"Qwen/Qwen3-0.6B\",\n", " \"model\": \"Qwen/Qwen3-0.6B\",\n",
"}\n", "}\n",
"\n", "\n",
"chat = LeannChat(index_path=\"knowledge.leann\", llm_config=llm_config)\n", "chat = LeannChat(index_path=INDEX_PATH, llm_config=llm_config)\n",
"response = chat.ask(\n", "response = chat.ask(\n",
" \"Compare the two retrieved programming languages and tell me their advantages.\",\n", " \"Compare the two retrieved programming languages and tell me their advantages.\",\n",
" top_k=2,\n", " top_k=2,\n",
" llm_kwargs={\"max_tokens\": 128}\n", " llm_kwargs={\"max_tokens\": 128},\n",
")\n", ")\n",
"response" "response"
] ]

220
docs/CONTRIBUTING.md Normal file
View File

@@ -0,0 +1,220 @@
# 🤝 Contributing
We welcome contributions! Leann is built by the community, for the community.
## Ways to Contribute
- 🐛 **Bug Reports**: Found an issue? Let us know!
- 💡 **Feature Requests**: Have an idea? We'd love to hear it!
- 🔧 **Code Contributions**: PRs welcome for all skill levels
- 📖 **Documentation**: Help make Leann more accessible
- 🧪 **Benchmarks**: Share your performance results
## 🚀 Development Setup
### Prerequisites
1. **Install uv** (fast Python package installer):
```bash
curl -LsSf https://astral.sh/uv/install.sh | sh
```
2. **Clone the repository**:
```bash
git clone https://github.com/LEANN-RAG/LEANN-RAG.git
cd LEANN-RAG
```
3. **Install system dependencies**:
**macOS:**
```bash
brew install llvm libomp boost protobuf zeromq pkgconf
```
**Ubuntu/Debian:**
```bash
sudo apt-get install libomp-dev libboost-all-dev protobuf-compiler \
libabsl-dev libmkl-full-dev libaio-dev libzmq3-dev
```
4. **Build from source**:
```bash
# macOS
CC=$(brew --prefix llvm)/bin/clang CXX=$(brew --prefix llvm)/bin/clang++ uv sync
# Ubuntu/Debian
uv sync
```
## 🔨 Pre-commit Hooks
We use pre-commit hooks to ensure code quality and consistency. This runs automatically before each commit.
### Setup Pre-commit
1. **Install pre-commit** (already included when you run `uv sync`):
```bash
uv pip install pre-commit
```
2. **Install the git hooks**:
```bash
pre-commit install
```
3. **Run pre-commit manually** (optional):
```bash
pre-commit run --all-files
```
### Pre-commit Checks
Our pre-commit configuration includes:
- **Trailing whitespace removal**
- **End-of-file fixing**
- **YAML validation**
- **Large file prevention**
- **Merge conflict detection**
- **Debug statement detection**
- **Code formatting with ruff**
- **Code linting with ruff**
## 🧪 Testing
### Running Tests
```bash
# Run all tests
uv run pytest
# Run specific test file
uv run pytest test/test_filename.py
# Run with coverage
uv run pytest --cov=leann
```
### Writing Tests
- Place tests in the `test/` directory
- Follow the naming convention `test_*.py`
- Use descriptive test names that explain what's being tested
- Include both positive and negative test cases
## 📝 Code Style
We use `ruff` for both linting and formatting to ensure consistent code style.
### Format Your Code
```bash
# Format all files
ruff format
# Check formatting without changing files
ruff format --check
```
### Lint Your Code
```bash
# Run linter with auto-fix
ruff check --fix
# Just check without fixing
ruff check
```
### Style Guidelines
- Follow PEP 8 conventions
- Use descriptive variable names
- Add type hints where appropriate
- Write docstrings for all public functions and classes
- Keep functions focused and single-purpose
## 🚦 CI/CD
Our CI pipeline runs automatically on all pull requests. It includes:
1. **Linting and Formatting**: Ensures code follows our style guidelines
2. **Multi-platform builds**: Tests on Ubuntu and macOS
3. **Python version matrix**: Tests on Python 3.9-3.13
4. **Wheel building**: Ensures packages can be built and distributed
### CI Commands
The CI uses the same commands as pre-commit to ensure consistency:
```bash
# Linting
ruff check .
# Format checking
ruff format --check .
```
Make sure your code passes these checks locally before pushing!
## 🔄 Pull Request Process
1. **Fork the repository** and create your branch from `main`:
```bash
git checkout -b feature/your-feature-name
```
2. **Make your changes**:
- Write clean, documented code
- Add tests for new functionality
- Update documentation as needed
3. **Run pre-commit checks**:
```bash
pre-commit run --all-files
```
4. **Test your changes**:
```bash
uv run pytest
```
5. **Commit with descriptive messages**:
```bash
git commit -m "feat: add new search algorithm"
```
Follow [Conventional Commits](https://www.conventionalcommits.org/):
- `feat:` for new features
- `fix:` for bug fixes
- `docs:` for documentation changes
- `test:` for test additions/changes
- `refactor:` for code refactoring
- `perf:` for performance improvements
6. **Push and create a pull request**:
- Provide a clear description of your changes
- Reference any related issues
- Include examples or screenshots if applicable
## 📚 Documentation
When adding new features or making significant changes:
1. Update relevant documentation in `/docs`
2. Add docstrings to new functions/classes
3. Update README.md if needed
4. Include usage examples
## 🤔 Getting Help
- **Discord**: Join our community for discussions
- **Issues**: Check existing issues or create a new one
- **Discussions**: For general questions and ideas
## 📄 License
By contributing, you agree that your contributions will be licensed under the same license as the project (MIT).
---
Thank you for contributing to LEANN! Every contribution, no matter how small, helps make the project better for everyone. 🌟

View File

@@ -0,0 +1,98 @@
"""
Comparison between Sentence Transformers and OpenAI embeddings
This example shows how different embedding models handle complex queries
and demonstrates the differences between local and API-based embeddings.
"""
import numpy as np
from leann.embedding_compute import compute_embeddings
# OpenAI API key should be set as environment variable
# export OPENAI_API_KEY="your-api-key-here"
# Test data
conference_text = "[Title]: COLING 2025 Conference\n[URL]: https://coling2025.org/"
browser_text = "[Title]: Browser Use Tool\n[URL]: https://github.com/browser-use"
# Two queries with same intent but different wording
query1 = "Tell me my browser history about some conference i often visit"
query2 = "browser history about conference I often visit"
texts = [query1, query2, conference_text, browser_text]
def cosine_similarity(a, b):
return np.dot(a, b) # Already normalized
def analyze_embeddings(embeddings, model_name):
print(f"\n=== {model_name} Results ===")
# Results for Query 1
sim1_conf = cosine_similarity(embeddings[0], embeddings[2])
sim1_browser = cosine_similarity(embeddings[0], embeddings[3])
print(f"Query 1: '{query1}'")
print(f" → Conference similarity: {sim1_conf:.4f} {'' if sim1_conf > sim1_browser else ''}")
print(
f" → Browser similarity: {sim1_browser:.4f} {'' if sim1_browser > sim1_conf else ''}"
)
print(f" Winner: {'Conference' if sim1_conf > sim1_browser else 'Browser'}")
# Results for Query 2
sim2_conf = cosine_similarity(embeddings[1], embeddings[2])
sim2_browser = cosine_similarity(embeddings[1], embeddings[3])
print(f"\nQuery 2: '{query2}'")
print(f" → Conference similarity: {sim2_conf:.4f} {'' if sim2_conf > sim2_browser else ''}")
print(
f" → Browser similarity: {sim2_browser:.4f} {'' if sim2_browser > sim2_conf else ''}"
)
print(f" Winner: {'Conference' if sim2_conf > sim2_browser else 'Browser'}")
# Show the impact
print("\n=== Impact Analysis ===")
print(f"Conference similarity change: {sim2_conf - sim1_conf:+.4f}")
print(f"Browser similarity change: {sim2_browser - sim1_browser:+.4f}")
if sim1_conf > sim1_browser and sim2_browser > sim2_conf:
print("❌ FLIP: Adding 'browser history' flips winner from Conference to Browser!")
elif sim1_conf > sim1_browser and sim2_conf > sim2_browser:
print("✅ STABLE: Conference remains winner in both queries")
elif sim1_browser > sim1_conf and sim2_browser > sim2_conf:
print("✅ STABLE: Browser remains winner in both queries")
else:
print("🔄 MIXED: Results vary between queries")
return {
"query1_conf": sim1_conf,
"query1_browser": sim1_browser,
"query2_conf": sim2_conf,
"query2_browser": sim2_browser,
}
# Test Sentence Transformers
print("Testing Sentence Transformers (facebook/contriever)...")
try:
st_embeddings = compute_embeddings(texts, "facebook/contriever", mode="sentence-transformers")
st_results = analyze_embeddings(st_embeddings, "Sentence Transformers (facebook/contriever)")
except Exception as e:
print(f"❌ Sentence Transformers failed: {e}")
st_results = None
# Test OpenAI
print("\n" + "=" * 60)
print("Testing OpenAI (text-embedding-3-small)...")
try:
openai_embeddings = compute_embeddings(texts, "text-embedding-3-small", mode="openai")
openai_results = analyze_embeddings(openai_embeddings, "OpenAI (text-embedding-3-small)")
except Exception as e:
print(f"❌ OpenAI failed: {e}")
openai_results = None
# Compare results
if st_results and openai_results:
print("\n" + "=" * 60)
print("=== COMPARISON SUMMARY ===")

View File

@@ -1,11 +0,0 @@
# 🤝 Contributing
We welcome contributions! Leann is built by the community, for the community.
## Ways to Contribute
- 🐛 **Bug Reports**: Found an issue? Let us know!
- 💡 **Feature Requests**: Have an idea? We'd love to hear it!
- 🔧 **Code Contributions**: PRs welcome for all skill levels
- 📖 **Documentation**: Help make Leann more accessible
- 🧪 **Benchmarks**: Share your performance results

View File

@@ -0,0 +1,75 @@
# Normalized Embeddings Support in LEANN
LEANN now automatically detects normalized embedding models and sets the appropriate distance metric for optimal performance.
## What are Normalized Embeddings?
Normalized embeddings are vectors with L2 norm = 1 (unit vectors). These embeddings are optimized for cosine similarity rather than Maximum Inner Product Search (MIPS).
## Automatic Detection
When you create a `LeannBuilder` instance with a normalized embedding model, LEANN will:
1. **Automatically set `distance_metric="cosine"`** if not specified
2. **Show a warning** if you manually specify a different distance metric
3. **Provide optimal search performance** with the correct metric
## Supported Normalized Embedding Models
### OpenAI
All OpenAI text embedding models are normalized:
- `text-embedding-ada-002`
- `text-embedding-3-small`
- `text-embedding-3-large`
### Voyage AI
All Voyage AI embedding models are normalized:
- `voyage-2`
- `voyage-3`
- `voyage-large-2`
- `voyage-multilingual-2`
- `voyage-code-2`
### Cohere
All Cohere embedding models are normalized:
- `embed-english-v3.0`
- `embed-multilingual-v3.0`
- `embed-english-light-v3.0`
- `embed-multilingual-light-v3.0`
## Example Usage
```python
from leann.api import LeannBuilder
# Automatic detection - will use cosine distance
builder = LeannBuilder(
backend_name="hnsw",
embedding_model="text-embedding-3-small",
embedding_mode="openai"
)
# Warning: Detected normalized embeddings model 'text-embedding-3-small'...
# Automatically setting distance_metric='cosine'
# Manual override (not recommended)
builder = LeannBuilder(
backend_name="hnsw",
embedding_model="text-embedding-3-small",
embedding_mode="openai",
distance_metric="mips" # Will show warning
)
# Warning: Using 'mips' distance metric with normalized embeddings...
```
## Non-Normalized Embeddings
Models like `facebook/contriever` and other sentence-transformers models that are not normalized will continue to use MIPS by default, which is optimal for them.
## Why This Matters
Using the wrong distance metric with normalized embeddings can lead to:
- **Poor search quality** due to HNSW's early termination with narrow score ranges
- **Incorrect ranking** of search results
- **Suboptimal performance** compared to using the correct metric
For more details on why this happens, see our analysis of [OpenAI embeddings with MIPS](../examples/main_cli_example.py).

View File

@@ -3,14 +3,15 @@
Memory comparison between Faiss HNSW and LEANN HNSW backend Memory comparison between Faiss HNSW and LEANN HNSW backend
""" """
import gc
import logging import logging
import os import os
import subprocess
import sys import sys
import time import time
import psutil
import gc
import subprocess
from pathlib import Path from pathlib import Path
import psutil
from llama_index.core.node_parser import SentenceSplitter from llama_index.core.node_parser import SentenceSplitter
# Setup logging # Setup logging
@@ -83,9 +84,7 @@ def test_faiss_hnsw():
for line in lines: for line in lines:
if "Peak Memory:" in line: if "Peak Memory:" in line:
peak_memory = float( peak_memory = float(line.split("Peak Memory:")[1].split("MB")[0].strip())
line.split("Peak Memory:")[1].split("MB")[0].strip()
)
return {"peak_memory": peak_memory} return {"peak_memory": peak_memory}
@@ -111,9 +110,8 @@ def test_leann_hnsw():
tracker.checkpoint("After imports") tracker.checkpoint("After imports")
from leann.api import LeannBuilder
from llama_index.core import SimpleDirectoryReader from llama_index.core import SimpleDirectoryReader
from leann.api import LeannBuilder, LeannSearcher
# Load and parse documents # Load and parse documents
documents = SimpleDirectoryReader( documents = SimpleDirectoryReader(
@@ -202,11 +200,9 @@ def test_leann_hnsw():
searcher = LeannSearcher(index_path) searcher = LeannSearcher(index_path)
tracker.checkpoint("After searcher loading") tracker.checkpoint("After searcher loading")
print("Running search queries...") print("Running search queries...")
queries = [ queries = [
"什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面任务令一般在什么城市颁发", "什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面,任务令一般在什么城市颁发",
"What is LEANN and how does it work?", "What is LEANN and how does it work?",
"华为诺亚方舟实验室的主要研究内容", "华为诺亚方舟实验室的主要研究内容",
] ]
@@ -304,21 +300,15 @@ def main():
print("\nLEANN vs Faiss Performance:") print("\nLEANN vs Faiss Performance:")
memory_saving = faiss_results["peak_memory"] - leann_results["peak_memory"] memory_saving = faiss_results["peak_memory"] - leann_results["peak_memory"]
print( print(f" Search Memory: {memory_ratio:.1f}x less ({memory_saving:.1f} MB saved)")
f" Search Memory: {memory_ratio:.1f}x less ({memory_saving:.1f} MB saved)"
)
# Storage comparison # Storage comparison
if leann_storage_size > faiss_storage_size: if leann_storage_size > faiss_storage_size:
storage_ratio = leann_storage_size / faiss_storage_size storage_ratio = leann_storage_size / faiss_storage_size
print( print(f" Storage Size: {storage_ratio:.1f}x larger (LEANN uses more storage)")
f" Storage Size: {storage_ratio:.1f}x larger (LEANN uses more storage)"
)
elif faiss_storage_size > leann_storage_size: elif faiss_storage_size > leann_storage_size:
storage_ratio = faiss_storage_size / leann_storage_size storage_ratio = faiss_storage_size / leann_storage_size
print( print(f" Storage Size: {storage_ratio:.1f}x smaller (LEANN uses less storage)")
f" Storage Size: {storage_ratio:.1f}x smaller (LEANN uses less storage)"
)
else: else:
print(" Storage Size: similar") print(" Storage Size: similar")
else: else:

View File

@@ -14903,5 +14903,3 @@ This website includes information about Project Gutenberg™,
including how to make donations to the Project Gutenberg Literary including how to make donations to the Project Gutenberg Literary
Archive Foundation, how to help produce our new eBooks, and how to Archive Foundation, how to help produce our new eBooks, and how to
subscribe to our email newsletter to hear about new eBooks. subscribe to our email newsletter to hear about new eBooks.

View File

@@ -3,32 +3,42 @@
Document search demo with recompute mode Document search demo with recompute mode
""" """
import os
from pathlib import Path
import shutil import shutil
import time import time
from pathlib import Path
# Import backend packages to trigger plugin registration # Import backend packages to trigger plugin registration
try: try:
import leann_backend_diskann import leann_backend_diskann # noqa: F401
import leann_backend_hnsw import leann_backend_hnsw # noqa: F401
print("INFO: Backend packages imported successfully.") print("INFO: Backend packages imported successfully.")
except ImportError as e: except ImportError as e:
print(f"WARNING: Could not import backend packages. Error: {e}") print(f"WARNING: Could not import backend packages. Error: {e}")
# Import upper-level API from leann-core # Import upper-level API from leann-core
from leann.api import LeannBuilder, LeannSearcher, LeannChat from leann.api import LeannBuilder, LeannChat, LeannSearcher
def load_sample_documents(): def load_sample_documents():
"""Create sample documents for demonstration""" """Create sample documents for demonstration"""
docs = [ docs = [
{"title": "Intro to Python", "content": "Python is a high-level, interpreted language known for simplicity."}, {
{"title": "ML Basics", "content": "Machine learning builds systems that learn from data."}, "title": "Intro to Python",
{"title": "Data Structures", "content": "Data structures like arrays, lists, and graphs organize data."}, "content": "Python is a high-level, interpreted language known for simplicity.",
},
{
"title": "ML Basics",
"content": "Machine learning builds systems that learn from data.",
},
{
"title": "Data Structures",
"content": "Data structures like arrays, lists, and graphs organize data.",
},
] ]
return docs return docs
def main(): def main():
print("==========================================================") print("==========================================================")
print("=== Leann Document Search Demo (DiskANN + Recompute) ===") print("=== Leann Document Search Demo (DiskANN + Recompute) ===")
@@ -45,11 +55,7 @@ def main():
# --- 1. Build index --- # --- 1. Build index ---
print(f"\n[PHASE 1] Building index using '{BACKEND_TO_TEST}' backend...") print(f"\n[PHASE 1] Building index using '{BACKEND_TO_TEST}' backend...")
builder = LeannBuilder( builder = LeannBuilder(backend_name=BACKEND_TO_TEST, graph_degree=32, complexity=64)
backend_name=BACKEND_TO_TEST,
graph_degree=32,
complexity=64
)
documents = load_sample_documents() documents = load_sample_documents()
print(f"Loaded {len(documents)} sample documents.") print(f"Loaded {len(documents)} sample documents.")
@@ -57,7 +63,7 @@ def main():
builder.add_text(doc["content"], metadata={"title": doc["title"]}) builder.add_text(doc["content"], metadata={"title": doc["title"]})
builder.build_index(INDEX_PATH) builder.build_index(INDEX_PATH)
print(f"\nIndex built!") print("\nIndex built!")
# --- 2. Basic search demo --- # --- 2. Basic search demo ---
print(f"\n[PHASE 2] Basic search using '{BACKEND_TO_TEST}' backend...") print(f"\n[PHASE 2] Basic search using '{BACKEND_TO_TEST}' backend...")
@@ -74,31 +80,33 @@ def main():
print(f"⏱️ Basic search time: {basic_time:.3f} seconds") print(f"⏱️ Basic search time: {basic_time:.3f} seconds")
print(">>> Basic search results <<<") print(">>> Basic search results <<<")
for i, res in enumerate(results, 1): for i, res in enumerate(results, 1):
print(f" {i}. ID: {res.id}, Score: {res.score:.4f}, Text: '{res.text}', Metadata: {res.metadata}") print(
f" {i}. ID: {res.id}, Score: {res.score:.4f}, Text: '{res.text}', Metadata: {res.metadata}"
)
# --- 3. Recompute search demo --- # --- 3. Recompute search demo ---
print(f"\n[PHASE 3] Recompute search using embedding server...") print("\n[PHASE 3] Recompute search using embedding server...")
print("\n--- Recompute search mode (get real embeddings via network) ---") print("\n--- Recompute search mode (get real embeddings via network) ---")
# Configure recompute parameters # Configure recompute parameters
recompute_params = { recompute_params = {
"recompute_beighbor_embeddings": True, # Enable network recomputation "recompute_beighbor_embeddings": True, # Enable network recomputation
"USE_DEFERRED_FETCH": False, # Don't use deferred fetch "USE_DEFERRED_FETCH": False, # Don't use deferred fetch
"skip_search_reorder": True, # Skip search reordering "skip_search_reorder": True, # Skip search reordering
"dedup_node_dis": True, # Enable node distance deduplication "dedup_node_dis": True, # Enable node distance deduplication
"prune_ratio": 0.1, # Pruning ratio 10% "prune_ratio": 0.1, # Pruning ratio 10%
"batch_recompute": False, # Don't use batch recomputation "batch_recompute": False, # Don't use batch recomputation
"global_pruning": False, # Don't use global pruning "global_pruning": False, # Don't use global pruning
"zmq_port": 5555, # ZMQ port "zmq_port": 5555, # ZMQ port
"embedding_model": "sentence-transformers/all-mpnet-base-v2" "embedding_model": "sentence-transformers/all-mpnet-base-v2",
} }
print("Recompute parameter configuration:") print("Recompute parameter configuration:")
for key, value in recompute_params.items(): for key, value in recompute_params.items():
print(f" {key}: {value}") print(f" {key}: {value}")
print(f"\n🔄 Executing Recompute search...") print("\n🔄 Executing Recompute search...")
try: try:
start_time = time.time() start_time = time.time()
recompute_results = searcher.search(query, top_k=2, **recompute_params) recompute_results = searcher.search(query, top_k=2, **recompute_params)
@@ -107,10 +115,12 @@ def main():
print(f"⏱️ Recompute search time: {recompute_time:.3f} seconds") print(f"⏱️ Recompute search time: {recompute_time:.3f} seconds")
print(">>> Recompute search results <<<") print(">>> Recompute search results <<<")
for i, res in enumerate(recompute_results, 1): for i, res in enumerate(recompute_results, 1):
print(f" {i}. ID: {res.id}, Score: {res.score:.4f}, Text: '{res.text}', Metadata: {res.metadata}") print(
f" {i}. ID: {res.id}, Score: {res.score:.4f}, Text: '{res.text}', Metadata: {res.metadata}"
)
# Compare results # Compare results
print(f"\n--- Result comparison ---") print("\n--- Result comparison ---")
print(f"Basic search time: {basic_time:.3f} seconds") print(f"Basic search time: {basic_time:.3f} seconds")
print(f"Recompute time: {recompute_time:.3f} seconds") print(f"Recompute time: {recompute_time:.3f} seconds")
@@ -119,19 +129,21 @@ def main():
basic_score = results[i].score basic_score = results[i].score
recompute_score = recompute_results[i].score recompute_score = recompute_results[i].score
score_diff = abs(basic_score - recompute_score) score_diff = abs(basic_score - recompute_score)
print(f" Position {i+1}: PQ={basic_score:.4f}, Recompute={recompute_score:.4f}, Difference={score_diff:.4f}") print(
f" Position {i + 1}: PQ={basic_score:.4f}, Recompute={recompute_score:.4f}, Difference={score_diff:.4f}"
)
if recompute_time > basic_time: if recompute_time > basic_time:
print(f"✅ Recompute mode working correctly (more accurate but slower)") print("✅ Recompute mode working correctly (more accurate but slower)")
else: else:
print(f" Recompute time is unusually fast, network recomputation may not be enabled") print("i Recompute time is unusually fast, network recomputation may not be enabled")
except Exception as e: except Exception as e:
print(f"❌ Recompute search failed: {e}") print(f"❌ Recompute search failed: {e}")
print("This usually indicates an embedding server connection issue") print("This usually indicates an embedding server connection issue")
# --- 4. Chat demo --- # --- 4. Chat demo ---
print(f"\n[PHASE 4] Starting chat session...") print("\n[PHASE 4] Starting chat session...")
chat = LeannChat(index_path=INDEX_PATH) chat = LeannChat(index_path=INDEX_PATH)
chat_response = chat.ask(query) chat_response = chat.ask(query)
print(f"You: {query}") print(f"You: {query}")

View File

@@ -1,11 +1,13 @@
import os
import email import email
import os
from pathlib import Path from pathlib import Path
from typing import List, Any from typing import Any
from llama_index.core import Document from llama_index.core import Document
from llama_index.core.readers.base import BaseReader from llama_index.core.readers.base import BaseReader
def find_all_messages_directories(root: str = None) -> List[Path]:
def find_all_messages_directories(root: str | None = None) -> list[Path]:
""" """
Recursively find all 'Messages' directories under the given root. Recursively find all 'Messages' directories under the given root.
Returns a list of Path objects. Returns a list of Path objects.
@@ -16,11 +18,12 @@ def find_all_messages_directories(root: str = None) -> List[Path]:
root = os.path.join(home_dir, "Library", "Mail") root = os.path.join(home_dir, "Library", "Mail")
messages_dirs = [] messages_dirs = []
for dirpath, dirnames, filenames in os.walk(root): for dirpath, _dirnames, _filenames in os.walk(root):
if os.path.basename(dirpath) == "Messages": if os.path.basename(dirpath) == "Messages":
messages_dirs.append(Path(dirpath)) messages_dirs.append(Path(dirpath))
return messages_dirs return messages_dirs
class EmlxReader(BaseReader): class EmlxReader(BaseReader):
""" """
Apple Mail .emlx file reader with embedded metadata. Apple Mail .emlx file reader with embedded metadata.
@@ -37,7 +40,7 @@ class EmlxReader(BaseReader):
""" """
self.include_html = include_html self.include_html = include_html
def load_data(self, input_dir: str, **load_kwargs: Any) -> List[Document]: def load_data(self, input_dir: str, **load_kwargs: Any) -> list[Document]:
""" """
Load data from the input directory containing .emlx files. Load data from the input directory containing .emlx files.
@@ -46,8 +49,8 @@ class EmlxReader(BaseReader):
**load_kwargs: **load_kwargs:
max_count (int): Maximum amount of messages to read. max_count (int): Maximum amount of messages to read.
""" """
docs: List[Document] = [] docs: list[Document] = []
max_count = load_kwargs.get('max_count', 1000) max_count = load_kwargs.get("max_count", 1000)
count = 0 count = 0
# Walk through the directory recursively # Walk through the directory recursively
@@ -63,12 +66,12 @@ class EmlxReader(BaseReader):
filepath = os.path.join(dirpath, filename) filepath = os.path.join(dirpath, filename)
try: try:
# Read the .emlx file # Read the .emlx file
with open(filepath, 'r', encoding='utf-8', errors='ignore') as f: with open(filepath, encoding="utf-8", errors="ignore") as f:
content = f.read() content = f.read()
# .emlx files have a length prefix followed by the email content # .emlx files have a length prefix followed by the email content
# The first line contains the length, followed by the email # The first line contains the length, followed by the email
lines = content.split('\n', 1) lines = content.split("\n", 1)
if len(lines) >= 2: if len(lines) >= 2:
email_content = lines[1] email_content = lines[1]
@@ -77,22 +80,32 @@ class EmlxReader(BaseReader):
msg = email.message_from_string(email_content) msg = email.message_from_string(email_content)
# Extract email metadata # Extract email metadata
subject = msg.get('Subject', 'No Subject') subject = msg.get("Subject", "No Subject")
from_addr = msg.get('From', 'Unknown') from_addr = msg.get("From", "Unknown")
to_addr = msg.get('To', 'Unknown') to_addr = msg.get("To", "Unknown")
date = msg.get('Date', 'Unknown') date = msg.get("Date", "Unknown")
# Extract email body # Extract email body
body = "" body = ""
if msg.is_multipart(): if msg.is_multipart():
for part in msg.walk(): for part in msg.walk():
if part.get_content_type() == "text/plain" or part.get_content_type() == "text/html": if (
if part.get_content_type() == "text/html" and not self.include_html: part.get_content_type() == "text/plain"
or part.get_content_type() == "text/html"
):
if (
part.get_content_type() == "text/html"
and not self.include_html
):
continue continue
body += part.get_payload(decode=True).decode('utf-8', errors='ignore') body += part.get_payload(decode=True).decode(
"utf-8", errors="ignore"
)
# break # break
else: else:
body = msg.get_payload(decode=True).decode('utf-8', errors='ignore') body = msg.get_payload(decode=True).decode(
"utf-8", errors="ignore"
)
# Create document content with metadata embedded in text # Create document content with metadata embedded in text
doc_content = f""" doc_content = f"""

View File

@@ -7,9 +7,9 @@ Contains simple parser for mbox files.
import logging import logging
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Optional from typing import Any
from fsspec import AbstractFileSystem
from fsspec import AbstractFileSystem
from llama_index.core.readers.base import BaseReader from llama_index.core.readers.base import BaseReader
from llama_index.core.schema import Document from llama_index.core.schema import Document
@@ -27,11 +27,7 @@ class MboxReader(BaseReader):
""" """
DEFAULT_MESSAGE_FORMAT: str = ( DEFAULT_MESSAGE_FORMAT: str = (
"Date: {_date}\n" "Date: {_date}\nFrom: {_from}\nTo: {_to}\nSubject: {_subject}\nContent: {_content}"
"From: {_from}\n"
"To: {_to}\n"
"Subject: {_subject}\n"
"Content: {_content}"
) )
def __init__( def __init__(
@@ -45,9 +41,7 @@ class MboxReader(BaseReader):
try: try:
from bs4 import BeautifulSoup # noqa from bs4 import BeautifulSoup # noqa
except ImportError: except ImportError:
raise ImportError( raise ImportError("`beautifulsoup4` package not found: `pip install beautifulsoup4`")
"`beautifulsoup4` package not found: `pip install beautifulsoup4`"
)
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.max_count = max_count self.max_count = max_count
@@ -56,9 +50,9 @@ class MboxReader(BaseReader):
def load_data( def load_data(
self, self,
file: Path, file: Path,
extra_info: Optional[Dict] = None, extra_info: dict | None = None,
fs: Optional[AbstractFileSystem] = None, fs: AbstractFileSystem | None = None,
) -> List[Document]: ) -> list[Document]:
"""Parse file into string.""" """Parse file into string."""
# Import required libraries # Import required libraries
import mailbox import mailbox
@@ -74,7 +68,7 @@ class MboxReader(BaseReader):
) )
i = 0 i = 0
results: List[str] = [] results: list[str] = []
# Load file using mailbox # Load file using mailbox
bytes_parser = BytesParser(policy=default).parse bytes_parser = BytesParser(policy=default).parse
mbox = mailbox.mbox(file, factory=bytes_parser) # type: ignore mbox = mailbox.mbox(file, factory=bytes_parser) # type: ignore
@@ -134,12 +128,12 @@ class EmlxMboxReader(MboxReader):
def load_data( def load_data(
self, self,
directory: Path, directory: Path,
extra_info: Optional[Dict] = None, extra_info: dict | None = None,
fs: Optional[AbstractFileSystem] = None, fs: AbstractFileSystem | None = None,
) -> List[Document]: ) -> list[Document]:
"""Parse .emlx files from directory into strings using MboxReader logic.""" """Parse .emlx files from directory into strings using MboxReader logic."""
import tempfile
import os import os
import tempfile
if fs: if fs:
logger.warning( logger.warning(
@@ -156,18 +150,18 @@ class EmlxMboxReader(MboxReader):
return [] return []
# Create a temporary mbox file # Create a temporary mbox file
with tempfile.NamedTemporaryFile(mode='w', suffix='.mbox', delete=False) as temp_mbox: with tempfile.NamedTemporaryFile(mode="w", suffix=".mbox", delete=False) as temp_mbox:
temp_mbox_path = temp_mbox.name temp_mbox_path = temp_mbox.name
# Convert .emlx files to mbox format # Convert .emlx files to mbox format
for emlx_file in emlx_files: for emlx_file in emlx_files:
try: try:
# Read the .emlx file # Read the .emlx file
with open(emlx_file, 'r', encoding='utf-8', errors='ignore') as f: with open(emlx_file, encoding="utf-8", errors="ignore") as f:
content = f.read() content = f.read()
# .emlx format: first line is length, rest is email content # .emlx format: first line is length, rest is email content
lines = content.split('\n', 1) lines = content.split("\n", 1)
if len(lines) >= 2: if len(lines) >= 2:
email_content = lines[1] # Skip the length line email_content = lines[1] # Skip the length line
@@ -188,5 +182,5 @@ class EmlxMboxReader(MboxReader):
# Clean up temporary file # Clean up temporary file
try: try:
os.unlink(temp_mbox_path) os.unlink(temp_mbox_path)
except: except OSError:
pass pass

View File

@@ -1,11 +1,11 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
"""Test only Faiss HNSW""" """Test only Faiss HNSW"""
import os
import sys import sys
import time import time
import psutil import psutil
import gc
import os
def get_memory_usage(): def get_memory_usage():
@@ -37,20 +37,20 @@ def main():
import faiss import faiss
except ImportError: except ImportError:
print("Faiss is not installed.") print("Faiss is not installed.")
print("Please install it with `uv pip install faiss-cpu` and you can then run this script again") print(
"Please install it with `uv pip install faiss-cpu` and you can then run this script again"
)
sys.exit(1) sys.exit(1)
from llama_index.core import ( from llama_index.core import (
SimpleDirectoryReader,
VectorStoreIndex,
StorageContext,
Settings, Settings,
node_parser, SimpleDirectoryReader,
Document, StorageContext,
VectorStoreIndex,
) )
from llama_index.core.node_parser import SentenceSplitter from llama_index.core.node_parser import SentenceSplitter
from llama_index.vector_stores.faiss import FaissVectorStore
from llama_index.embeddings.huggingface import HuggingFaceEmbedding from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.vector_stores.faiss import FaissVectorStore
tracker = MemoryTracker("Faiss HNSW") tracker = MemoryTracker("Faiss HNSW")
tracker.checkpoint("Initial") tracker.checkpoint("Initial")
@@ -90,8 +90,9 @@ def main():
vector_store=vector_store, persist_dir="./storage_faiss" vector_store=vector_store, persist_dir="./storage_faiss"
) )
from llama_index.core import load_index_from_storage from llama_index.core import load_index_from_storage
index = load_index_from_storage(storage_context=storage_context) index = load_index_from_storage(storage_context=storage_context)
print(f"Index loaded from ./storage_faiss") print("Index loaded from ./storage_faiss")
tracker.checkpoint("After loading existing index") tracker.checkpoint("After loading existing index")
index_loaded = True index_loaded = True
except Exception as e: except Exception as e:
@@ -99,6 +100,7 @@ def main():
print("Cleaning up corrupted index and building new one...") print("Cleaning up corrupted index and building new one...")
# Clean up corrupted index # Clean up corrupted index
import shutil import shutil
if os.path.exists("./storage_faiss"): if os.path.exists("./storage_faiss"):
shutil.rmtree("./storage_faiss") shutil.rmtree("./storage_faiss")
@@ -109,9 +111,7 @@ def main():
vector_store = FaissVectorStore(faiss_index=faiss_index) vector_store = FaissVectorStore(faiss_index=faiss_index)
storage_context = StorageContext.from_defaults(vector_store=vector_store) storage_context = StorageContext.from_defaults(vector_store=vector_store)
index = VectorStoreIndex.from_documents( index = VectorStoreIndex.from_documents(
documents, documents, storage_context=storage_context, transformations=[node_parser]
storage_context=storage_context,
transformations=[node_parser]
) )
tracker.checkpoint("After index building") tracker.checkpoint("After index building")
@@ -127,7 +127,7 @@ def main():
query_engine = index.as_query_engine(similarity_top_k=20) query_engine = index.as_query_engine(similarity_top_k=20)
queries = [ queries = [
"什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面任务令一般在什么城市颁发", "什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面,任务令一般在什么城市颁发",
"What is LEANN and how does it work?", "What is LEANN and how does it work?",
"华为诺亚方舟实验室的主要研究内容", "华为诺亚方舟实验室的主要研究内容",
] ]

View File

@@ -1,15 +1,17 @@
import os
import asyncio
import argparse import argparse
import asyncio
import os
try: try:
import dotenv import dotenv
dotenv.load_dotenv() dotenv.load_dotenv()
except ModuleNotFoundError: except ModuleNotFoundError:
# python-dotenv is not installed; skip loading environment variables # python-dotenv is not installed; skip loading environment variables
dotenv = None dotenv = None
from pathlib import Path from pathlib import Path
from typing import List, Any
from leann.api import LeannBuilder, LeannSearcher, LeannChat from leann.api import LeannBuilder, LeannChat
from llama_index.core.node_parser import SentenceSplitter from llama_index.core.node_parser import SentenceSplitter
# dotenv.load_dotenv() # handled above if python-dotenv is available # dotenv.load_dotenv() # handled above if python-dotenv is available
@@ -17,7 +19,14 @@ from llama_index.core.node_parser import SentenceSplitter
# Default Chrome profile path # Default Chrome profile path
DEFAULT_CHROME_PROFILE = os.path.expanduser("~/Library/Application Support/Google/Chrome/Default") DEFAULT_CHROME_PROFILE = os.path.expanduser("~/Library/Application Support/Google/Chrome/Default")
def create_leann_index_from_multiple_chrome_profiles(profile_dirs: List[Path], index_path: str = "chrome_history_index.leann", max_count: int = -1):
def create_leann_index_from_multiple_chrome_profiles(
profile_dirs: list[Path],
index_path: str = "chrome_history_index.leann",
max_count: int = -1,
embedding_model: str = "facebook/contriever",
embedding_mode: str = "sentence-transformers",
):
""" """
Create LEANN index from multiple Chrome profile data sources. Create LEANN index from multiple Chrome profile data sources.
@@ -25,28 +34,30 @@ def create_leann_index_from_multiple_chrome_profiles(profile_dirs: List[Path], i
profile_dirs: List of Path objects pointing to Chrome profile directories profile_dirs: List of Path objects pointing to Chrome profile directories
index_path: Path to save the LEANN index index_path: Path to save the LEANN index
max_count: Maximum number of history entries to process per profile max_count: Maximum number of history entries to process per profile
embedding_model: The embedding model to use
embedding_mode: The embedding backend mode
""" """
print("Creating LEANN index from multiple Chrome profile data sources...") print("Creating LEANN index from multiple Chrome profile data sources...")
# Load documents using ChromeHistoryReader from history_data # Load documents using ChromeHistoryReader from history_data
from history_data.history import ChromeHistoryReader from history_data.history import ChromeHistoryReader
reader = ChromeHistoryReader() reader = ChromeHistoryReader()
INDEX_DIR = Path(index_path).parent INDEX_DIR = Path(index_path).parent
if not INDEX_DIR.exists(): if not INDEX_DIR.exists():
print(f"--- Index directory not found, building new index ---") print("--- Index directory not found, building new index ---")
all_documents = [] all_documents = []
total_processed = 0 total_processed = 0
# Process each Chrome profile directory # Process each Chrome profile directory
for i, profile_dir in enumerate(profile_dirs): for i, profile_dir in enumerate(profile_dirs):
print(f"\nProcessing Chrome profile {i+1}/{len(profile_dirs)}: {profile_dir}") print(f"\nProcessing Chrome profile {i + 1}/{len(profile_dirs)}: {profile_dir}")
try: try:
documents = reader.load_data( documents = reader.load_data(
chrome_profile_path=str(profile_dir), chrome_profile_path=str(profile_dir), max_count=max_count
max_count=max_count
) )
if documents: if documents:
print(f"Loaded {len(documents)} history documents from {profile_dir}") print(f"Loaded {len(documents)} history documents from {profile_dir}")
@@ -66,10 +77,14 @@ def create_leann_index_from_multiple_chrome_profiles(profile_dirs: List[Path], i
if not all_documents: if not all_documents:
print("No documents loaded from any source. Exiting.") print("No documents loaded from any source. Exiting.")
# highlight info that you need to close all chrome browser before running this script and high light the instruction!! # highlight info that you need to close all chrome browser before running this script and high light the instruction!!
print("\033[91mYou need to close or quit all chrome browser before running this script\033[0m") print(
"\033[91mYou need to close or quit all chrome browser before running this script\033[0m"
)
return None return None
print(f"\nTotal loaded {len(all_documents)} history documents from {len(profile_dirs)} profiles") print(
f"\nTotal loaded {len(all_documents)} history documents from {len(profile_dirs)} profiles"
)
# Create text splitter with 256 chunk size # Create text splitter with 256 chunk size
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=128) text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=128)
@@ -87,22 +102,24 @@ def create_leann_index_from_multiple_chrome_profiles(profile_dirs: List[Path], i
print(f"Created {len(all_texts)} text chunks from {len(all_documents)} documents") print(f"Created {len(all_texts)} text chunks from {len(all_documents)} documents")
# Create LEANN index directory # Create LEANN index directory
print(f"--- Index directory not found, building new index ---") print("--- Index directory not found, building new index ---")
INDEX_DIR.mkdir(exist_ok=True) INDEX_DIR.mkdir(exist_ok=True)
print(f"--- Building new LEANN index ---") print("--- Building new LEANN index ---")
print(f"\n[PHASE 1] Building Leann index...") print("\n[PHASE 1] Building Leann index...")
# Use HNSW backend for better macOS compatibility # Use HNSW backend for better macOS compatibility
# LeannBuilder will automatically detect normalized embeddings and set appropriate distance metric
builder = LeannBuilder( builder = LeannBuilder(
backend_name="hnsw", backend_name="hnsw",
embedding_model="facebook/contriever", embedding_model=embedding_model,
embedding_mode=embedding_mode,
graph_degree=32, graph_degree=32,
complexity=64, complexity=64,
is_compact=True, is_compact=True,
is_recompute=True, is_recompute=True,
num_threads=1 # Force single-threaded mode num_threads=1, # Force single-threaded mode
) )
print(f"Adding {len(all_texts)} history chunks to index...") print(f"Adding {len(all_texts)} history chunks to index...")
@@ -116,7 +133,14 @@ def create_leann_index_from_multiple_chrome_profiles(profile_dirs: List[Path], i
return index_path return index_path
def create_leann_index(profile_path: str = None, index_path: str = "chrome_history_index.leann", max_count: int = 1000):
def create_leann_index(
profile_path: str | None = None,
index_path: str = "chrome_history_index.leann",
max_count: int = 1000,
embedding_model: str = "facebook/contriever",
embedding_mode: str = "sentence-transformers",
):
""" """
Create LEANN index from Chrome history data. Create LEANN index from Chrome history data.
@@ -124,26 +148,26 @@ def create_leann_index(profile_path: str = None, index_path: str = "chrome_histo
profile_path: Path to the Chrome profile directory (optional, uses default if None) profile_path: Path to the Chrome profile directory (optional, uses default if None)
index_path: Path to save the LEANN index index_path: Path to save the LEANN index
max_count: Maximum number of history entries to process max_count: Maximum number of history entries to process
embedding_model: The embedding model to use
embedding_mode: The embedding backend mode
""" """
print("Creating LEANN index from Chrome history data...") print("Creating LEANN index from Chrome history data...")
INDEX_DIR = Path(index_path).parent INDEX_DIR = Path(index_path).parent
if not INDEX_DIR.exists(): if not INDEX_DIR.exists():
print(f"--- Index directory not found, building new index ---") print("--- Index directory not found, building new index ---")
INDEX_DIR.mkdir(exist_ok=True) INDEX_DIR.mkdir(exist_ok=True)
print(f"--- Building new LEANN index ---") print("--- Building new LEANN index ---")
print(f"\n[PHASE 1] Building Leann index...") print("\n[PHASE 1] Building Leann index...")
# Load documents using ChromeHistoryReader from history_data # Load documents using ChromeHistoryReader from history_data
from history_data.history import ChromeHistoryReader from history_data.history import ChromeHistoryReader
reader = ChromeHistoryReader() reader = ChromeHistoryReader()
documents = reader.load_data( documents = reader.load_data(chrome_profile_path=profile_path, max_count=max_count)
chrome_profile_path=profile_path,
max_count=max_count
)
if not documents: if not documents:
print("No documents loaded. Exiting.") print("No documents loaded. Exiting.")
@@ -165,22 +189,24 @@ def create_leann_index(profile_path: str = None, index_path: str = "chrome_histo
print(f"Created {len(all_texts)} text chunks from {len(documents)} documents") print(f"Created {len(all_texts)} text chunks from {len(documents)} documents")
# Create LEANN index directory # Create LEANN index directory
print(f"--- Index directory not found, building new index ---") print("--- Index directory not found, building new index ---")
INDEX_DIR.mkdir(exist_ok=True) INDEX_DIR.mkdir(exist_ok=True)
print(f"--- Building new LEANN index ---") print("--- Building new LEANN index ---")
print(f"\n[PHASE 1] Building Leann index...") print("\n[PHASE 1] Building Leann index...")
# Use HNSW backend for better macOS compatibility # Use HNSW backend for better macOS compatibility
# LeannBuilder will automatically detect normalized embeddings and set appropriate distance metric
builder = LeannBuilder( builder = LeannBuilder(
backend_name="hnsw", backend_name="hnsw",
embedding_model="facebook/contriever", embedding_model=embedding_model,
embedding_mode=embedding_mode,
graph_degree=32, graph_degree=32,
complexity=64, complexity=64,
is_compact=True, is_compact=True,
is_recompute=True, is_recompute=True,
num_threads=1 # Force single-threaded mode num_threads=1, # Force single-threaded mode
) )
print(f"Adding {len(all_texts)} history chunks to index...") print(f"Adding {len(all_texts)} history chunks to index...")
@@ -194,6 +220,7 @@ def create_leann_index(profile_path: str = None, index_path: str = "chrome_histo
return index_path return index_path
async def query_leann_index(index_path: str, query: str): async def query_leann_index(index_path: str, query: str):
""" """
Query the LEANN index. Query the LEANN index.
@@ -202,7 +229,7 @@ async def query_leann_index(index_path: str, query: str):
index_path: Path to the LEANN index index_path: Path to the LEANN index
query: The query string query: The query string
""" """
print(f"\n[PHASE 2] Starting Leann chat session...") print("\n[PHASE 2] Starting Leann chat session...")
chat = LeannChat(index_path=index_path) chat = LeannChat(index_path=index_path)
print(f"You: {query}") print(f"You: {query}")
@@ -217,27 +244,65 @@ async def query_leann_index(index_path: str, query: str):
"model": "gpt-4o", "model": "gpt-4o",
"api_key": os.getenv("OPENAI_API_KEY"), "api_key": os.getenv("OPENAI_API_KEY"),
}, },
llm_kwargs={ llm_kwargs={"temperature": 0.0, "max_tokens": 1000},
"temperature": 0.0,
"max_tokens": 1000
}
) )
print(f"Leann chat response: \033[36m{chat_response}\033[0m") print(f"Leann chat response: \033[36m{chat_response}\033[0m")
async def main(): async def main():
# Parse command line arguments # Parse command line arguments
parser = argparse.ArgumentParser(description='LEANN Chrome History Reader - Create and query browser history index') parser = argparse.ArgumentParser(
parser.add_argument('--chrome-profile', type=str, default=DEFAULT_CHROME_PROFILE, description="LEANN Chrome History Reader - Create and query browser history index"
help=f'Path to Chrome profile directory (default: {DEFAULT_CHROME_PROFILE}), usually you dont need to change this') )
parser.add_argument('--index-dir', type=str, default="./google_history_index", parser.add_argument(
help='Directory to store the LEANN index (default: ./chrome_history_index_leann_test)') "--chrome-profile",
parser.add_argument('--max-entries', type=int, default=1000, type=str,
help='Maximum number of history entries to process (default: 1000)') default=DEFAULT_CHROME_PROFILE,
parser.add_argument('--query', type=str, default=None, help=f"Path to Chrome profile directory (default: {DEFAULT_CHROME_PROFILE}), usually you dont need to change this",
help='Single query to run (default: runs example queries)') )
parser.add_argument('--auto-find-profiles', action='store_true', default=True, parser.add_argument(
help='Automatically find all Chrome profiles (default: True)') "--index-dir",
type=str,
default="./google_history_index",
help="Directory to store the LEANN index (default: ./chrome_history_index_leann_test)",
)
parser.add_argument(
"--max-entries",
type=int,
default=1000,
help="Maximum number of history entries to process (default: 1000)",
)
parser.add_argument(
"--query",
type=str,
default=None,
help="Single query to run (default: runs example queries)",
)
parser.add_argument(
"--auto-find-profiles",
action="store_true",
default=True,
help="Automatically find all Chrome profiles (default: True)",
)
parser.add_argument(
"--embedding-model",
type=str,
default="facebook/contriever",
help="The embedding model to use (e.g., 'facebook/contriever', 'text-embedding-3-small')",
)
parser.add_argument(
"--embedding-mode",
type=str,
default="sentence-transformers",
choices=["sentence-transformers", "openai", "mlx"],
help="The embedding backend mode",
)
parser.add_argument(
"--use-existing-index",
action="store_true",
help="Use existing index without rebuilding",
)
args = parser.parse_args() args = parser.parse_args()
@@ -248,24 +313,34 @@ async def main():
print(f"Index directory: {INDEX_DIR}") print(f"Index directory: {INDEX_DIR}")
print(f"Max entries: {args.max_entries}") print(f"Max entries: {args.max_entries}")
# Find Chrome profile directories if args.use_existing_index:
from history_data.history import ChromeHistoryReader # Use existing index without rebuilding
if not Path(INDEX_PATH).exists():
if args.auto_find_profiles: print(f"Error: Index file not found at {INDEX_PATH}")
profile_dirs = ChromeHistoryReader.find_chrome_profiles()
if not profile_dirs:
print("No Chrome profiles found automatically. Exiting.")
return return
print(f"Using existing index at {INDEX_PATH}")
index_path = INDEX_PATH
else: else:
# Use single specified profile # Find Chrome profile directories
profile_path = Path(args.chrome_profile) from history_data.history import ChromeHistoryReader
if not profile_path.exists():
print(f"Chrome profile not found: {profile_path}")
return
profile_dirs = [profile_path]
# Create or load the LEANN index from all sources if args.auto_find_profiles:
index_path = create_leann_index_from_multiple_chrome_profiles(profile_dirs, INDEX_PATH, args.max_entries) profile_dirs = ChromeHistoryReader.find_chrome_profiles()
if not profile_dirs:
print("No Chrome profiles found automatically. Exiting.")
return
else:
# Use single specified profile
profile_path = Path(args.chrome_profile)
if not profile_path.exists():
print(f"Chrome profile not found: {profile_path}")
return
profile_dirs = [profile_path]
# Create or load the LEANN index from all sources
index_path = create_leann_index_from_multiple_chrome_profiles(
profile_dirs, INDEX_PATH, args.max_entries, args.embedding_model, args.embedding_mode
)
if index_path: if index_path:
if args.query: if args.query:
@@ -275,12 +350,13 @@ async def main():
# Example queries # Example queries
queries = [ queries = [
"What websites did I visit about machine learning?", "What websites did I visit about machine learning?",
"Find my search history about programming" "Find my search history about programming",
] ]
for query in queries: for query in queries:
print("\n" + "="*60) print("\n" + "=" * 60)
await query_leann_index(index_path, query) await query_leann_index(index_path, query)
if __name__ == "__main__": if __name__ == "__main__":
asyncio.run(main()) asyncio.run(main())

View File

@@ -1,3 +1,3 @@
from .history import ChromeHistoryReader from .history import ChromeHistoryReader
__all__ = ['ChromeHistoryReader'] __all__ = ["ChromeHistoryReader"]

View File

@@ -1,10 +1,12 @@
import sqlite3
import os import os
import sqlite3
from pathlib import Path from pathlib import Path
from typing import List, Any from typing import Any
from llama_index.core import Document from llama_index.core import Document
from llama_index.core.readers.base import BaseReader from llama_index.core.readers.base import BaseReader
class ChromeHistoryReader(BaseReader): class ChromeHistoryReader(BaseReader):
""" """
Chrome browser history reader that extracts browsing data from SQLite database. Chrome browser history reader that extracts browsing data from SQLite database.
@@ -17,7 +19,7 @@ class ChromeHistoryReader(BaseReader):
"""Initialize.""" """Initialize."""
pass pass
def load_data(self, input_dir: str = None, **load_kwargs: Any) -> List[Document]: def load_data(self, input_dir: str | None = None, **load_kwargs: Any) -> list[Document]:
""" """
Load Chrome history data from the default Chrome profile location. Load Chrome history data from the default Chrome profile location.
@@ -27,13 +29,15 @@ class ChromeHistoryReader(BaseReader):
max_count (int): Maximum amount of history entries to read. max_count (int): Maximum amount of history entries to read.
chrome_profile_path (str): Custom path to Chrome profile directory. chrome_profile_path (str): Custom path to Chrome profile directory.
""" """
docs: List[Document] = [] docs: list[Document] = []
max_count = load_kwargs.get('max_count', 1000) max_count = load_kwargs.get("max_count", 1000)
chrome_profile_path = load_kwargs.get('chrome_profile_path', None) chrome_profile_path = load_kwargs.get("chrome_profile_path", None)
# Default Chrome profile path on macOS # Default Chrome profile path on macOS
if chrome_profile_path is None: if chrome_profile_path is None:
chrome_profile_path = os.path.expanduser("~/Library/Application Support/Google/Chrome/Default") chrome_profile_path = os.path.expanduser(
"~/Library/Application Support/Google/Chrome/Default"
)
history_db_path = os.path.join(chrome_profile_path, "History") history_db_path = os.path.join(chrome_profile_path, "History")
@@ -82,7 +86,7 @@ class ChromeHistoryReader(BaseReader):
""" """
# Create document with embedded metadata # Create document with embedded metadata
doc = Document(text=doc_content, metadata={ "title": title[0:150]}) doc = Document(text=doc_content, metadata={"title": title[0:150]})
# if len(title) > 150: # if len(title) > 150:
# print(f"Title is too long: {title}") # print(f"Title is too long: {title}")
docs.append(doc) docs.append(doc)
@@ -98,7 +102,7 @@ class ChromeHistoryReader(BaseReader):
return docs return docs
@staticmethod @staticmethod
def find_chrome_profiles() -> List[Path]: def find_chrome_profiles() -> list[Path]:
""" """
Find all Chrome profile directories. Find all Chrome profile directories.
@@ -124,7 +128,9 @@ class ChromeHistoryReader(BaseReader):
return profile_dirs return profile_dirs
@staticmethod @staticmethod
def export_history_to_file(output_file: str = "chrome_history_export.txt", max_count: int = 1000): def export_history_to_file(
output_file: str = "chrome_history_export.txt", max_count: int = 1000
):
""" """
Export Chrome history to a text file using the same SQL query format. Export Chrome history to a text file using the same SQL query format.
@@ -132,7 +138,9 @@ class ChromeHistoryReader(BaseReader):
output_file: Path to the output file output_file: Path to the output file
max_count: Maximum number of entries to export max_count: Maximum number of entries to export
""" """
chrome_profile_path = os.path.expanduser("~/Library/Application Support/Google/Chrome/Default") chrome_profile_path = os.path.expanduser(
"~/Library/Application Support/Google/Chrome/Default"
)
history_db_path = os.path.join(chrome_profile_path, "History") history_db_path = os.path.join(chrome_profile_path, "History")
if not os.path.exists(history_db_path): if not os.path.exists(history_db_path):
@@ -159,10 +167,12 @@ class ChromeHistoryReader(BaseReader):
cursor.execute(query, (max_count,)) cursor.execute(query, (max_count,))
rows = cursor.fetchall() rows = cursor.fetchall()
with open(output_file, 'w', encoding='utf-8') as f: with open(output_file, "w", encoding="utf-8") as f:
for row in rows: for row in rows:
last_visit, url, title, visit_count, typed_count, hidden = row last_visit, url, title, visit_count, typed_count, hidden = row
f.write(f"{last_visit}\t{url}\t{title}\t{visit_count}\t{typed_count}\t{hidden}\n") f.write(
f"{last_visit}\t{url}\t{title}\t{visit_count}\t{typed_count}\t{hidden}\n"
)
conn.close() conn.close()
print(f"Exported {len(rows)} history entries to {output_file}") print(f"Exported {len(rows)} history entries to {output_file}")

View File

@@ -2,13 +2,14 @@ import json
import os import os
import re import re
import subprocess import subprocess
import sys
import time import time
from datetime import datetime
from pathlib import Path from pathlib import Path
from typing import List, Any, Dict, Optional from typing import Any
from llama_index.core import Document from llama_index.core import Document
from llama_index.core.readers.base import BaseReader from llama_index.core.readers.base import BaseReader
from datetime import datetime
class WeChatHistoryReader(BaseReader): class WeChatHistoryReader(BaseReader):
""" """
@@ -43,10 +44,16 @@ class WeChatHistoryReader(BaseReader):
wechattweak_path = self.wechat_exporter_dir / "wechattweak-cli" wechattweak_path = self.wechat_exporter_dir / "wechattweak-cli"
if not wechattweak_path.exists(): if not wechattweak_path.exists():
print("Downloading WeChatTweak CLI...") print("Downloading WeChatTweak CLI...")
subprocess.run([ subprocess.run(
"curl", "-L", "-o", str(wechattweak_path), [
"https://github.com/JettChenT/WeChatTweak-CLI/releases/latest/download/wechattweak-cli" "curl",
], check=True) "-L",
"-o",
str(wechattweak_path),
"https://github.com/JettChenT/WeChatTweak-CLI/releases/latest/download/wechattweak-cli",
],
check=True,
)
# Make executable # Make executable
wechattweak_path.chmod(0o755) wechattweak_path.chmod(0o755)
@@ -73,16 +80,16 @@ class WeChatHistoryReader(BaseReader):
def check_api_available(self) -> bool: def check_api_available(self) -> bool:
"""Check if WeChatTweak API is available.""" """Check if WeChatTweak API is available."""
try: try:
result = subprocess.run([ result = subprocess.run(
"curl", "-s", "http://localhost:48065/wechat/allcontacts" ["curl", "-s", "http://localhost:48065/wechat/allcontacts"],
], capture_output=True, text=True, timeout=5) capture_output=True,
text=True,
timeout=5,
)
return result.returncode == 0 and result.stdout.strip() return result.returncode == 0 and result.stdout.strip()
except Exception: except Exception:
return False return False
def _extract_readable_text(self, content: str) -> str: def _extract_readable_text(self, content: str) -> str:
""" """
Extract readable text from message content, removing XML and system messages. Extract readable text from message content, removing XML and system messages.
@@ -100,14 +107,14 @@ class WeChatHistoryReader(BaseReader):
if isinstance(content, dict): if isinstance(content, dict):
# Extract text from dictionary structure # Extract text from dictionary structure
text_parts = [] text_parts = []
if 'title' in content: if "title" in content:
text_parts.append(str(content['title'])) text_parts.append(str(content["title"]))
if 'quoted' in content: if "quoted" in content:
text_parts.append(str(content['quoted'])) text_parts.append(str(content["quoted"]))
if 'content' in content: if "content" in content:
text_parts.append(str(content['content'])) text_parts.append(str(content["content"]))
if 'text' in content: if "text" in content:
text_parts.append(str(content['text'])) text_parts.append(str(content["text"]))
if text_parts: if text_parts:
return " | ".join(text_parts) return " | ".join(text_parts)
@@ -120,11 +127,11 @@ class WeChatHistoryReader(BaseReader):
return "" return ""
# Remove common prefixes like "wxid_xxx:\n" # Remove common prefixes like "wxid_xxx:\n"
clean_content = re.sub(r'^wxid_[^:]+:\s*', '', content) clean_content = re.sub(r"^wxid_[^:]+:\s*", "", content)
clean_content = re.sub(r'^[^:]+:\s*', '', clean_content) clean_content = re.sub(r"^[^:]+:\s*", "", clean_content)
# If it's just XML or system message, return empty # If it's just XML or system message, return empty
if clean_content.strip().startswith('<') or 'recalled a message' in clean_content: if clean_content.strip().startswith("<") or "recalled a message" in clean_content:
return "" return ""
return clean_content.strip() return clean_content.strip()
@@ -145,9 +152,9 @@ class WeChatHistoryReader(BaseReader):
# Handle dictionary content # Handle dictionary content
if isinstance(content, dict): if isinstance(content, dict):
# Check if dict has any readable text fields # Check if dict has any readable text fields
text_fields = ['title', 'quoted', 'content', 'text'] text_fields = ["title", "quoted", "content", "text"]
for field in text_fields: for field in text_fields:
if field in content and content[field]: if content.get(field):
return True return True
return False return False
@@ -156,42 +163,47 @@ class WeChatHistoryReader(BaseReader):
return False return False
# Skip image messages (contain XML with img tags) # Skip image messages (contain XML with img tags)
if '<img' in content and 'cdnurl' in content: if "<img" in content and "cdnurl" in content:
return False return False
# Skip emoji messages (contain emoji XML tags) # Skip emoji messages (contain emoji XML tags)
if '<emoji' in content and 'productid' in content: if "<emoji" in content and "productid" in content:
return False return False
# Skip voice messages # Skip voice messages
if '<voice' in content: if "<voice" in content:
return False return False
# Skip video messages # Skip video messages
if '<video' in content: if "<video" in content:
return False return False
# Skip file messages # Skip file messages
if '<appmsg' in content and 'appid' in content: if "<appmsg" in content and "appid" in content:
return False return False
# Skip system messages (like "recalled a message") # Skip system messages (like "recalled a message")
if 'recalled a message' in content: if "recalled a message" in content:
return False return False
# Check if there's actual readable text (not just XML or system messages) # Check if there's actual readable text (not just XML or system messages)
# Remove common prefixes like "wxid_xxx:\n" and check for actual content # Remove common prefixes like "wxid_xxx:\n" and check for actual content
clean_content = re.sub(r'^wxid_[^:]+:\s*', '', content) clean_content = re.sub(r"^wxid_[^:]+:\s*", "", content)
clean_content = re.sub(r'^[^:]+:\s*', '', clean_content) clean_content = re.sub(r"^[^:]+:\s*", "", clean_content)
# If after cleaning we have meaningful text, consider it readable # If after cleaning we have meaningful text, consider it readable
if len(clean_content.strip()) > 0 and not clean_content.strip().startswith('<'): if len(clean_content.strip()) > 0 and not clean_content.strip().startswith("<"):
return True return True
return False return False
def _concatenate_messages(self, messages: List[Dict], max_length: int = 128, def _concatenate_messages(
time_window_minutes: int = 30, overlap_messages: int = 0) -> List[Dict]: self,
messages: list[dict],
max_length: int = 128,
time_window_minutes: int = 30,
overlap_messages: int = 0,
) -> list[dict]:
""" """
Concatenate messages based on length and time rules. Concatenate messages based on length and time rules.
@@ -214,12 +226,12 @@ class WeChatHistoryReader(BaseReader):
for message in messages: for message in messages:
# Extract message info # Extract message info
content = message.get('content', '') content = message.get("content", "")
message_text = message.get('message', '') message_text = message.get("message", "")
create_time = message.get('createTime', 0) create_time = message.get("createTime", 0)
from_user = message.get('fromUser', '') message.get("fromUser", "")
to_user = message.get('toUser', '') message.get("toUser", "")
is_sent_from_self = message.get('isSentFromSelf', False) message.get("isSentFromSelf", False)
# Extract readable text # Extract readable text
readable_text = self._extract_readable_text(content) readable_text = self._extract_readable_text(content)
@@ -236,16 +248,24 @@ class WeChatHistoryReader(BaseReader):
if time_diff_minutes > time_window_minutes: if time_diff_minutes > time_window_minutes:
# Time gap too large, start new group # Time gap too large, start new group
if current_group: if current_group:
concatenated_groups.append({ concatenated_groups.append(
'messages': current_group, {
'total_length': current_length, "messages": current_group,
'start_time': current_group[0].get('createTime', 0), "total_length": current_length,
'end_time': current_group[-1].get('createTime', 0) "start_time": current_group[0].get("createTime", 0),
}) "end_time": current_group[-1].get("createTime", 0),
}
)
# Keep last few messages for overlap # Keep last few messages for overlap
if overlap_messages > 0 and len(current_group) > overlap_messages: if overlap_messages > 0 and len(current_group) > overlap_messages:
current_group = current_group[-overlap_messages:] current_group = current_group[-overlap_messages:]
current_length = sum(len(self._extract_readable_text(msg.get('content', '')) or msg.get('message', '')) for msg in current_group) current_length = sum(
len(
self._extract_readable_text(msg.get("content", ""))
or msg.get("message", "")
)
for msg in current_group
)
else: else:
current_group = [] current_group = []
current_length = 0 current_length = 0
@@ -254,16 +274,24 @@ class WeChatHistoryReader(BaseReader):
message_length = len(readable_text) message_length = len(readable_text)
if max_length != -1 and current_length + message_length > max_length and current_group: if max_length != -1 and current_length + message_length > max_length and current_group:
# Current group would exceed max length, save it and start new # Current group would exceed max length, save it and start new
concatenated_groups.append({ concatenated_groups.append(
'messages': current_group, {
'total_length': current_length, "messages": current_group,
'start_time': current_group[0].get('createTime', 0), "total_length": current_length,
'end_time': current_group[-1].get('createTime', 0) "start_time": current_group[0].get("createTime", 0),
}) "end_time": current_group[-1].get("createTime", 0),
}
)
# Keep last few messages for overlap # Keep last few messages for overlap
if overlap_messages > 0 and len(current_group) > overlap_messages: if overlap_messages > 0 and len(current_group) > overlap_messages:
current_group = current_group[-overlap_messages:] current_group = current_group[-overlap_messages:]
current_length = sum(len(self._extract_readable_text(msg.get('content', '')) or msg.get('message', '')) for msg in current_group) current_length = sum(
len(
self._extract_readable_text(msg.get("content", ""))
or msg.get("message", "")
)
for msg in current_group
)
else: else:
current_group = [] current_group = []
current_length = 0 current_length = 0
@@ -275,16 +303,18 @@ class WeChatHistoryReader(BaseReader):
# Add the last group if it exists # Add the last group if it exists
if current_group: if current_group:
concatenated_groups.append({ concatenated_groups.append(
'messages': current_group, {
'total_length': current_length, "messages": current_group,
'start_time': current_group[0].get('createTime', 0), "total_length": current_length,
'end_time': current_group[-1].get('createTime', 0) "start_time": current_group[0].get("createTime", 0),
}) "end_time": current_group[-1].get("createTime", 0),
}
)
return concatenated_groups return concatenated_groups
def _create_concatenated_content(self, message_group: Dict, contact_name: str) -> str: def _create_concatenated_content(self, message_group: dict, contact_name: str) -> str:
""" """
Create concatenated content from a group of messages. Create concatenated content from a group of messages.
@@ -295,16 +325,16 @@ class WeChatHistoryReader(BaseReader):
Returns: Returns:
Formatted concatenated content Formatted concatenated content
""" """
messages = message_group['messages'] messages = message_group["messages"]
start_time = message_group['start_time'] start_time = message_group["start_time"]
end_time = message_group['end_time'] end_time = message_group["end_time"]
# Format timestamps # Format timestamps
if start_time: if start_time:
try: try:
start_timestamp = datetime.fromtimestamp(start_time) start_timestamp = datetime.fromtimestamp(start_time)
start_time_str = start_timestamp.strftime('%Y-%m-%d %H:%M:%S') start_time_str = start_timestamp.strftime("%Y-%m-%d %H:%M:%S")
except: except (ValueError, OSError):
start_time_str = str(start_time) start_time_str = str(start_time)
else: else:
start_time_str = "Unknown" start_time_str = "Unknown"
@@ -312,8 +342,8 @@ class WeChatHistoryReader(BaseReader):
if end_time: if end_time:
try: try:
end_timestamp = datetime.fromtimestamp(end_time) end_timestamp = datetime.fromtimestamp(end_time)
end_time_str = end_timestamp.strftime('%Y-%m-%d %H:%M:%S') end_time_str = end_timestamp.strftime("%Y-%m-%d %H:%M:%S")
except: except (ValueError, OSError):
end_time_str = str(end_time) end_time_str = str(end_time)
else: else:
end_time_str = "Unknown" end_time_str = "Unknown"
@@ -321,10 +351,10 @@ class WeChatHistoryReader(BaseReader):
# Build concatenated message content # Build concatenated message content
message_parts = [] message_parts = []
for message in messages: for message in messages:
content = message.get('content', '') content = message.get("content", "")
message_text = message.get('message', '') message_text = message.get("message", "")
create_time = message.get('createTime', 0) create_time = message.get("createTime", 0)
is_sent_from_self = message.get('isSentFromSelf', False) is_sent_from_self = message.get("isSentFromSelf", False)
# Extract readable text # Extract readable text
readable_text = self._extract_readable_text(content) readable_text = self._extract_readable_text(content)
@@ -336,8 +366,8 @@ class WeChatHistoryReader(BaseReader):
try: try:
timestamp = datetime.fromtimestamp(create_time) timestamp = datetime.fromtimestamp(create_time)
# change to YYYY-MM-DD HH:MM:SS # change to YYYY-MM-DD HH:MM:SS
time_str = timestamp.strftime('%Y-%m-%d %H:%M:%S') time_str = timestamp.strftime("%Y-%m-%d %H:%M:%S")
except: except (ValueError, OSError):
time_str = str(create_time) time_str = str(create_time)
else: else:
time_str = "Unknown" time_str = "Unknown"
@@ -351,7 +381,7 @@ class WeChatHistoryReader(BaseReader):
doc_content = f""" doc_content = f"""
Contact: {contact_name} Contact: {contact_name}
Time Range: {start_time_str} - {end_time_str} Time Range: {start_time_str} - {end_time_str}
Messages ({len(messages)} messages, {message_group['total_length']} chars): Messages ({len(messages)} messages, {message_group["total_length"]} chars):
{concatenated_text} {concatenated_text}
""" """
@@ -361,7 +391,7 @@ Messages ({len(messages)} messages, {message_group['total_length']} chars):
""" """
return doc_content, contact_name return doc_content, contact_name
def load_data(self, input_dir: str = None, **load_kwargs: Any) -> List[Document]: def load_data(self, input_dir: str | None = None, **load_kwargs: Any) -> list[Document]:
""" """
Load WeChat chat history data from exported JSON files. Load WeChat chat history data from exported JSON files.
@@ -376,13 +406,13 @@ Messages ({len(messages)} messages, {message_group['total_length']} chars):
time_window_minutes (int): Time window in minutes to group messages together (default: 30). time_window_minutes (int): Time window in minutes to group messages together (default: 30).
overlap_messages (int): Number of messages to overlap between consecutive groups (default: 2). overlap_messages (int): Number of messages to overlap between consecutive groups (default: 2).
""" """
docs: List[Document] = [] docs: list[Document] = []
max_count = load_kwargs.get('max_count', 1000) max_count = load_kwargs.get("max_count", 1000)
wechat_export_dir = load_kwargs.get('wechat_export_dir', None) wechat_export_dir = load_kwargs.get("wechat_export_dir", None)
include_non_text = load_kwargs.get('include_non_text', False) include_non_text = load_kwargs.get("include_non_text", False)
concatenate_messages = load_kwargs.get('concatenate_messages', False) concatenate_messages = load_kwargs.get("concatenate_messages", False)
max_length = load_kwargs.get('max_length', 1000) load_kwargs.get("max_length", 1000)
time_window_minutes = load_kwargs.get('time_window_minutes', 30) load_kwargs.get("time_window_minutes", 30)
# Default WeChat export path # Default WeChat export path
if wechat_export_dir is None: if wechat_export_dir is None:
@@ -403,7 +433,7 @@ Messages ({len(messages)} messages, {message_group['total_length']} chars):
break break
try: try:
with open(json_file, 'r', encoding='utf-8') as f: with open(json_file, encoding="utf-8") as f:
chat_data = json.load(f) chat_data = json.load(f)
# Extract contact name from filename # Extract contact name from filename
@@ -414,7 +444,7 @@ Messages ({len(messages)} messages, {message_group['total_length']} chars):
readable_messages = [] readable_messages = []
for message in chat_data: for message in chat_data:
try: try:
content = message.get('content', '') content = message.get("content", "")
if not include_non_text and not self._is_text_message(content): if not include_non_text and not self._is_text_message(content):
continue continue
@@ -432,7 +462,7 @@ Messages ({len(messages)} messages, {message_group['total_length']} chars):
readable_messages, readable_messages,
max_length=-1, max_length=-1,
time_window_minutes=-1, time_window_minutes=-1,
overlap_messages=0 # Keep 2 messages overlap between groups overlap_messages=0, # Keep 2 messages overlap between groups
) )
# Create documents from concatenated groups # Create documents from concatenated groups
@@ -440,12 +470,19 @@ Messages ({len(messages)} messages, {message_group['total_length']} chars):
if count >= max_count and max_count > 0: if count >= max_count and max_count > 0:
break break
doc_content, contact_name = self._create_concatenated_content(message_group, contact_name) doc_content, contact_name = self._create_concatenated_content(
doc = Document(text=doc_content, metadata={"contact_name": contact_name}) message_group, contact_name
)
doc = Document(
text=doc_content,
metadata={"contact_name": contact_name},
)
docs.append(doc) docs.append(doc)
count += 1 count += 1
print(f"Created {len(message_groups)} concatenated message groups for {contact_name}") print(
f"Created {len(message_groups)} concatenated message groups for {contact_name}"
)
else: else:
# Original single-message processing # Original single-message processing
@@ -454,12 +491,12 @@ Messages ({len(messages)} messages, {message_group['total_length']} chars):
break break
# Extract message information # Extract message information
from_user = message.get('fromUser', '') message.get("fromUser", "")
to_user = message.get('toUser', '') message.get("toUser", "")
content = message.get('content', '') content = message.get("content", "")
message_text = message.get('message', '') message_text = message.get("message", "")
create_time = message.get('createTime', 0) create_time = message.get("createTime", 0)
is_sent_from_self = message.get('isSentFromSelf', False) is_sent_from_self = message.get("isSentFromSelf", False)
# Handle content that might be dict or string # Handle content that might be dict or string
try: try:
@@ -480,8 +517,8 @@ Messages ({len(messages)} messages, {message_group['total_length']} chars):
if create_time: if create_time:
try: try:
timestamp = datetime.fromtimestamp(create_time) timestamp = datetime.fromtimestamp(create_time)
time_str = timestamp.strftime('%Y-%m-%d %H:%M:%S') time_str = timestamp.strftime("%Y-%m-%d %H:%M:%S")
except: except (ValueError, OSError):
time_str = str(create_time) time_str = str(create_time)
else: else:
time_str = "Unknown" time_str = "Unknown"
@@ -512,7 +549,7 @@ Message: {readable_text if readable_text else message_text}
return docs return docs
@staticmethod @staticmethod
def find_wechat_export_dirs() -> List[Path]: def find_wechat_export_dirs() -> list[Path]:
""" """
Find all WeChat export directories. Find all WeChat export directories.
@@ -526,7 +563,7 @@ Message: {readable_text if readable_text else message_text}
Path("./wechat_export_test"), Path("./wechat_export_test"),
Path("./wechat_export"), Path("./wechat_export"),
Path("./wechat_chat_history"), Path("./wechat_chat_history"),
Path("./chat_export") Path("./chat_export"),
] ]
for export_dir in possible_dirs: for export_dir in possible_dirs:
@@ -534,13 +571,20 @@ Message: {readable_text if readable_text else message_text}
json_files = list(export_dir.glob("*.json")) json_files = list(export_dir.glob("*.json"))
if json_files: if json_files:
export_dirs.append(export_dir) export_dirs.append(export_dir)
print(f"Found WeChat export directory: {export_dir} with {len(json_files)} files") print(
f"Found WeChat export directory: {export_dir} with {len(json_files)} files"
)
print(f"Found {len(export_dirs)} WeChat export directories") print(f"Found {len(export_dirs)} WeChat export directories")
return export_dirs return export_dirs
@staticmethod @staticmethod
def export_chat_to_file(output_file: str = "wechat_chat_export.txt", max_count: int = 1000, export_dir: str = None, include_non_text: bool = False): def export_chat_to_file(
output_file: str = "wechat_chat_export.txt",
max_count: int = 1000,
export_dir: str | None = None,
include_non_text: bool = False,
):
""" """
Export WeChat chat history to a text file. Export WeChat chat history to a text file.
@@ -560,14 +604,14 @@ Message: {readable_text if readable_text else message_text}
try: try:
json_files = list(Path(export_dir).glob("*.json")) json_files = list(Path(export_dir).glob("*.json"))
with open(output_file, 'w', encoding='utf-8') as f: with open(output_file, "w", encoding="utf-8") as f:
count = 0 count = 0
for json_file in json_files: for json_file in json_files:
if count >= max_count and max_count > 0: if count >= max_count and max_count > 0:
break break
try: try:
with open(json_file, 'r', encoding='utf-8') as json_f: with open(json_file, encoding="utf-8") as json_f:
chat_data = json.load(json_f) chat_data = json.load(json_f)
contact_name = json_file.stem contact_name = json_file.stem
@@ -577,10 +621,10 @@ Message: {readable_text if readable_text else message_text}
if count >= max_count and max_count > 0: if count >= max_count and max_count > 0:
break break
from_user = message.get('fromUser', '') from_user = message.get("fromUser", "")
content = message.get('content', '') content = message.get("content", "")
message_text = message.get('message', '') message_text = message.get("message", "")
create_time = message.get('createTime', 0) create_time = message.get("createTime", 0)
# Skip non-text messages unless requested # Skip non-text messages unless requested
if not include_non_text: if not include_non_text:
@@ -595,8 +639,8 @@ Message: {readable_text if readable_text else message_text}
if create_time: if create_time:
try: try:
timestamp = datetime.fromtimestamp(create_time) timestamp = datetime.fromtimestamp(create_time)
time_str = timestamp.strftime('%Y-%m-%d %H:%M:%S') time_str = timestamp.strftime("%Y-%m-%d %H:%M:%S")
except: except (ValueError, OSError):
time_str = str(create_time) time_str = str(create_time)
else: else:
time_str = "Unknown" time_str = "Unknown"
@@ -613,7 +657,7 @@ Message: {readable_text if readable_text else message_text}
except Exception as e: except Exception as e:
print(f"Error exporting WeChat chat history: {e}") print(f"Error exporting WeChat chat history: {e}")
def export_wechat_chat_history(self, export_dir: str = "./wechat_export_direct") -> Optional[Path]: def export_wechat_chat_history(self, export_dir: str = "./wechat_export_direct") -> Path | None:
""" """
Export WeChat chat history using wechat-exporter tool. Export WeChat chat history using wechat-exporter tool.
@@ -642,16 +686,21 @@ Message: {readable_text if readable_text else message_text}
requirements_file = self.wechat_exporter_dir / "requirements.txt" requirements_file = self.wechat_exporter_dir / "requirements.txt"
if requirements_file.exists(): if requirements_file.exists():
print("Installing wechat-exporter requirements...") print("Installing wechat-exporter requirements...")
subprocess.run([ subprocess.run(["uv", "pip", "install", "-r", str(requirements_file)], check=True)
"uv", "pip", "install", "-r", str(requirements_file)
], check=True)
# Run the export command # Run the export command
print("Running wechat-exporter...") print("Running wechat-exporter...")
result = subprocess.run([ result = subprocess.run(
sys.executable, str(self.wechat_exporter_dir / "main.py"), [
"export-all", str(export_path) sys.executable,
], capture_output=True, text=True, check=True) str(self.wechat_exporter_dir / "main.py"),
"export-all",
str(export_path),
],
capture_output=True,
text=True,
check=True,
)
print("Export command output:") print("Export command output:")
print(result.stdout) print(result.stdout)
@@ -662,7 +711,9 @@ Message: {readable_text if readable_text else message_text}
# Check if export was successful # Check if export was successful
if export_path.exists() and any(export_path.glob("*.json")): if export_path.exists() and any(export_path.glob("*.json")):
json_files = list(export_path.glob("*.json")) json_files = list(export_path.glob("*.json"))
print(f"Successfully exported {len(json_files)} chat history files to {export_path}") print(
f"Successfully exported {len(json_files)} chat history files to {export_path}"
)
return export_path return export_path
else: else:
print("Export completed but no JSON files found") print("Export completed but no JSON files found")
@@ -678,7 +729,7 @@ Message: {readable_text if readable_text else message_text}
print("Please ensure WeChat is running and WeChatTweak is installed.") print("Please ensure WeChat is running and WeChatTweak is installed.")
return None return None
def find_or_export_wechat_data(self, export_dir: str = "./wechat_export_direct") -> List[Path]: def find_or_export_wechat_data(self, export_dir: str = "./wechat_export_direct") -> list[Path]:
""" """
Find existing WeChat exports or create new ones. Find existing WeChat exports or create new ones.
@@ -697,7 +748,7 @@ Message: {readable_text if readable_text else message_text}
Path("./wechat_export"), Path("./wechat_export"),
Path("./wechat_export_direct"), Path("./wechat_export_direct"),
Path("./wechat_chat_history"), Path("./wechat_chat_history"),
Path("./chat_export") Path("./chat_export"),
] ]
for export_dir_path in possible_export_dirs: for export_dir_path in possible_export_dirs:
@@ -714,6 +765,8 @@ Message: {readable_text if readable_text else message_text}
if exported_path: if exported_path:
export_dirs = [exported_path] export_dirs = [exported_path]
else: else:
print("Failed to export WeChat data. Please ensure WeChat is running and WeChatTweak is installed.") print(
"Failed to export WeChat data. Please ensure WeChat is running and WeChatTweak is installed."
)
return export_dirs return export_dirs

View File

@@ -1,30 +1,39 @@
import argparse
import asyncio
import os import os
import sys import sys
import asyncio
import dotenv
import argparse
from pathlib import Path from pathlib import Path
from typing import List, Any
import dotenv
# Add the project root to Python path so we can import from examples # Add the project root to Python path so we can import from examples
project_root = Path(__file__).parent.parent project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root)) sys.path.insert(0, str(project_root))
from leann.api import LeannBuilder, LeannSearcher, LeannChat from leann.api import LeannBuilder, LeannChat
from llama_index.core.node_parser import SentenceSplitter from llama_index.core.node_parser import SentenceSplitter
dotenv.load_dotenv() dotenv.load_dotenv()
# Auto-detect user's mail path # Auto-detect user's mail path
def get_mail_path(): def get_mail_path():
"""Get the mail path for the current user""" """Get the mail path for the current user"""
home_dir = os.path.expanduser("~") home_dir = os.path.expanduser("~")
return os.path.join(home_dir, "Library", "Mail") return os.path.join(home_dir, "Library", "Mail")
# Default mail path for macOS # Default mail path for macOS
DEFAULT_MAIL_PATH = "/Users/yichuan/Library/Mail/V10/0FCA0879-FD8C-4B7E-83BF-FDDA930791C5/[Gmail].mbox/All Mail.mbox/78BA5BE1-8819-4F9A-9613-EB63772F1DD0/Data" DEFAULT_MAIL_PATH = "/Users/yichuan/Library/Mail/V10/0FCA0879-FD8C-4B7E-83BF-FDDA930791C5/[Gmail].mbox/All Mail.mbox/78BA5BE1-8819-4F9A-9613-EB63772F1DD0/Data"
def create_leann_index_from_multiple_sources(messages_dirs: List[Path], index_path: str = "mail_index.leann", max_count: int = -1, include_html: bool = False, embedding_model: str = "facebook/contriever"):
def create_leann_index_from_multiple_sources(
messages_dirs: list[Path],
index_path: str = "mail_index.leann",
max_count: int = -1,
include_html: bool = False,
embedding_model: str = "facebook/contriever",
):
""" """
Create LEANN index from multiple mail data sources. Create LEANN index from multiple mail data sources.
@@ -38,6 +47,7 @@ def create_leann_index_from_multiple_sources(messages_dirs: List[Path], index_pa
# Load documents using EmlxReader from LEANN_email_reader # Load documents using EmlxReader from LEANN_email_reader
from examples.email_data.LEANN_email_reader import EmlxReader from examples.email_data.LEANN_email_reader import EmlxReader
reader = EmlxReader(include_html=include_html) reader = EmlxReader(include_html=include_html)
# from email_data.email import EmlxMboxReader # from email_data.email import EmlxMboxReader
# from pathlib import Path # from pathlib import Path
@@ -45,13 +55,13 @@ def create_leann_index_from_multiple_sources(messages_dirs: List[Path], index_pa
INDEX_DIR = Path(index_path).parent INDEX_DIR = Path(index_path).parent
if not INDEX_DIR.exists(): if not INDEX_DIR.exists():
print(f"--- Index directory not found, building new index ---") print("--- Index directory not found, building new index ---")
all_documents = [] all_documents = []
total_processed = 0 total_processed = 0
# Process each Messages directory # Process each Messages directory
for i, messages_dir in enumerate(messages_dirs): for i, messages_dir in enumerate(messages_dirs):
print(f"\nProcessing Messages directory {i+1}/{len(messages_dirs)}: {messages_dir}") print(f"\nProcessing Messages directory {i + 1}/{len(messages_dirs)}: {messages_dir}")
try: try:
documents = reader.load_data(messages_dir) documents = reader.load_data(messages_dir)
@@ -74,7 +84,9 @@ def create_leann_index_from_multiple_sources(messages_dirs: List[Path], index_pa
print("No documents loaded from any source. Exiting.") print("No documents loaded from any source. Exiting.")
return None return None
print(f"\nTotal loaded {len(all_documents)} email documents from {len(messages_dirs)} directories and starting to split them into chunks") print(
f"\nTotal loaded {len(all_documents)} email documents from {len(messages_dirs)} directories and starting to split them into chunks"
)
# Create text splitter with 256 chunk size # Create text splitter with 256 chunk size
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=25) text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=25)
@@ -89,16 +101,18 @@ def create_leann_index_from_multiple_sources(messages_dirs: List[Path], index_pa
# text = '[subject] ' + doc.metadata["subject"] + '\n' + text # text = '[subject] ' + doc.metadata["subject"] + '\n' + text
all_texts.append(text) all_texts.append(text)
print(f"Finished splitting {len(all_documents)} documents into {len(all_texts)} text chunks") print(
f"Finished splitting {len(all_documents)} documents into {len(all_texts)} text chunks"
)
# Create LEANN index directory # Create LEANN index directory
print(f"--- Index directory not found, building new index ---") print("--- Index directory not found, building new index ---")
INDEX_DIR.mkdir(exist_ok=True) INDEX_DIR.mkdir(exist_ok=True)
print(f"--- Building new LEANN index ---") print("--- Building new LEANN index ---")
print(f"\n[PHASE 1] Building Leann index...") print("\n[PHASE 1] Building Leann index...")
# Use HNSW backend for better macOS compatibility # Use HNSW backend for better macOS compatibility
builder = LeannBuilder( builder = LeannBuilder(
@@ -108,7 +122,7 @@ def create_leann_index_from_multiple_sources(messages_dirs: List[Path], index_pa
complexity=64, complexity=64,
is_compact=True, is_compact=True,
is_recompute=True, is_recompute=True,
num_threads=1 # Force single-threaded mode num_threads=1, # Force single-threaded mode
) )
print(f"Adding {len(all_texts)} email chunks to index...") print(f"Adding {len(all_texts)} email chunks to index...")
@@ -122,7 +136,14 @@ def create_leann_index_from_multiple_sources(messages_dirs: List[Path], index_pa
return index_path return index_path
def create_leann_index(mail_path: str, index_path: str = "mail_index.leann", max_count: int = 1000, include_html: bool = False, embedding_model: str = "facebook/contriever"):
def create_leann_index(
mail_path: str,
index_path: str = "mail_index.leann",
max_count: int = 1000,
include_html: bool = False,
embedding_model: str = "facebook/contriever",
):
""" """
Create LEANN index from mail data. Create LEANN index from mail data.
@@ -136,15 +157,16 @@ def create_leann_index(mail_path: str, index_path: str = "mail_index.leann", max
INDEX_DIR = Path(index_path).parent INDEX_DIR = Path(index_path).parent
if not INDEX_DIR.exists(): if not INDEX_DIR.exists():
print(f"--- Index directory not found, building new index ---") print("--- Index directory not found, building new index ---")
INDEX_DIR.mkdir(exist_ok=True) INDEX_DIR.mkdir(exist_ok=True)
print(f"--- Building new LEANN index ---") print("--- Building new LEANN index ---")
print(f"\n[PHASE 1] Building Leann index...") print("\n[PHASE 1] Building Leann index...")
# Load documents using EmlxReader from LEANN_email_reader # Load documents using EmlxReader from LEANN_email_reader
from examples.email_data.LEANN_email_reader import EmlxReader from examples.email_data.LEANN_email_reader import EmlxReader
reader = EmlxReader(include_html=include_html) reader = EmlxReader(include_html=include_html)
# from email_data.email import EmlxMboxReader # from email_data.email import EmlxMboxReader
# from pathlib import Path # from pathlib import Path
@@ -172,12 +194,12 @@ def create_leann_index(mail_path: str, index_path: str = "mail_index.leann", max
# Create LEANN index directory # Create LEANN index directory
print(f"--- Index directory not found, building new index ---") print("--- Index directory not found, building new index ---")
INDEX_DIR.mkdir(exist_ok=True) INDEX_DIR.mkdir(exist_ok=True)
print(f"--- Building new LEANN index ---") print("--- Building new LEANN index ---")
print(f"\n[PHASE 1] Building Leann index...") print("\n[PHASE 1] Building Leann index...")
# Use HNSW backend for better macOS compatibility # Use HNSW backend for better macOS compatibility
builder = LeannBuilder( builder = LeannBuilder(
@@ -187,7 +209,7 @@ def create_leann_index(mail_path: str, index_path: str = "mail_index.leann", max
complexity=64, complexity=64,
is_compact=True, is_compact=True,
is_recompute=True, is_recompute=True,
num_threads=1 # Force single-threaded mode num_threads=1, # Force single-threaded mode
) )
print(f"Adding {len(all_texts)} email chunks to index...") print(f"Adding {len(all_texts)} email chunks to index...")
@@ -201,6 +223,7 @@ def create_leann_index(mail_path: str, index_path: str = "mail_index.leann", max
return index_path return index_path
async def query_leann_index(index_path: str, query: str): async def query_leann_index(index_path: str, query: str):
""" """
Query the LEANN index. Query the LEANN index.
@@ -209,13 +232,13 @@ async def query_leann_index(index_path: str, query: str):
index_path: Path to the LEANN index index_path: Path to the LEANN index
query: The query string query: The query string
""" """
print(f"\n[PHASE 2] Starting Leann chat session...") print("\n[PHASE 2] Starting Leann chat session...")
chat = LeannChat(index_path=index_path, chat = LeannChat(index_path=index_path, llm_config={"type": "openai", "model": "gpt-4o"})
llm_config={"type": "openai", "model": "gpt-4o"})
print(f"You: {query}") print(f"You: {query}")
import time import time
start_time = time.time()
time.time()
chat_response = chat.ask( chat_response = chat.ask(
query, query,
top_k=20, top_k=20,
@@ -223,26 +246,47 @@ async def query_leann_index(index_path: str, query: str):
complexity=32, complexity=32,
beam_width=1, beam_width=1,
) )
end_time = time.time() time.time()
# print(f"Time taken: {end_time - start_time} seconds") # print(f"Time taken: {end_time - start_time} seconds")
# highlight the answer # highlight the answer
print(f"Leann chat response: \033[36m{chat_response}\033[0m") print(f"Leann chat response: \033[36m{chat_response}\033[0m")
async def main(): async def main():
# Parse command line arguments # Parse command line arguments
parser = argparse.ArgumentParser(description='LEANN Mail Reader - Create and query email index') parser = argparse.ArgumentParser(description="LEANN Mail Reader - Create and query email index")
# Remove --mail-path argument and auto-detect all Messages directories # Remove --mail-path argument and auto-detect all Messages directories
# Remove DEFAULT_MAIL_PATH # Remove DEFAULT_MAIL_PATH
parser.add_argument('--index-dir', type=str, default="./mail_index", parser.add_argument(
help='Directory to store the LEANN index (default: ./mail_index_leann_raw_text_all_dicts)') "--index-dir",
parser.add_argument('--max-emails', type=int, default=1000, type=str,
help='Maximum number of emails to process (-1 means all)') default="./mail_index",
parser.add_argument('--query', type=str, default="Give me some funny advertisement about apple or other companies", help="Directory to store the LEANN index (default: ./mail_index_leann_raw_text_all_dicts)",
help='Single query to run (default: runs example queries)') )
parser.add_argument('--include-html', action='store_true', default=False, parser.add_argument(
help='Include HTML content in email processing (default: False)') "--max-emails",
parser.add_argument('--embedding-model', type=str, default="facebook/contriever", type=int,
help='Embedding model to use (default: facebook/contriever)') default=1000,
help="Maximum number of emails to process (-1 means all)",
)
parser.add_argument(
"--query",
type=str,
default="Give me some funny advertisement about apple or other companies",
help="Single query to run (default: runs example queries)",
)
parser.add_argument(
"--include-html",
action="store_true",
default=False,
help="Include HTML content in email processing (default: False)",
)
parser.add_argument(
"--embedding-model",
type=str,
default="facebook/contriever",
help="Embedding model to use (default: facebook/contriever)",
)
args = parser.parse_args() args = parser.parse_args()
@@ -250,6 +294,7 @@ async def main():
# Automatically find all Messages directories under the current user's Mail directory # Automatically find all Messages directories under the current user's Mail directory
from examples.email_data.LEANN_email_reader import find_all_messages_directories from examples.email_data.LEANN_email_reader import find_all_messages_directories
mail_path = get_mail_path() mail_path = get_mail_path()
print(f"Searching for email data in: {mail_path}") print(f"Searching for email data in: {mail_path}")
messages_dirs = find_all_messages_directories(mail_path) messages_dirs = find_all_messages_directories(mail_path)
@@ -257,8 +302,7 @@ async def main():
# messages_dirs = [DEFAULT_MAIL_PATH] # messages_dirs = [DEFAULT_MAIL_PATH]
# messages_dirs = messages_dirs[:1] # messages_dirs = messages_dirs[:1]
print('len(messages_dirs): ', len(messages_dirs)) print("len(messages_dirs): ", len(messages_dirs))
if not messages_dirs: if not messages_dirs:
print("No Messages directories found. Exiting.") print("No Messages directories found. Exiting.")
@@ -270,7 +314,13 @@ async def main():
print(f"Found {len(messages_dirs)} Messages directories.") print(f"Found {len(messages_dirs)} Messages directories.")
# Create or load the LEANN index from all sources # Create or load the LEANN index from all sources
index_path = create_leann_index_from_multiple_sources(messages_dirs, INDEX_PATH, args.max_emails, args.include_html, args.embedding_model) index_path = create_leann_index_from_multiple_sources(
messages_dirs,
INDEX_PATH,
args.max_emails,
args.include_html,
args.embedding_model,
)
if index_path: if index_path:
if args.query: if args.query:
@@ -281,11 +331,12 @@ async def main():
queries = [ queries = [
"Hows Berkeley Graduate Student Instructor", "Hows Berkeley Graduate Student Instructor",
"how's the icloud related advertisement saying", "how's the icloud related advertisement saying",
"Whats the number of class recommend to take per semester for incoming EECS students" "Whats the number of class recommend to take per semester for incoming EECS students",
] ]
for query in queries: for query in queries:
print("\n" + "="*60) print("\n" + "=" * 60)
await query_leann_index(index_path, query) await query_leann_index(index_path, query)
if __name__ == "__main__": if __name__ == "__main__":
asyncio.run(main()) asyncio.run(main())

View File

@@ -1,26 +1,30 @@
import argparse
import os import os
import sys import sys
import argparse
from pathlib import Path from pathlib import Path
from typing import List, Any
# Add the project root to Python path so we can import from examples # Add the project root to Python path so we can import from examples
project_root = Path(__file__).parent.parent project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root)) sys.path.insert(0, str(project_root))
from llama_index.core import VectorStoreIndex, StorageContext import torch
from llama_index.core import StorageContext, VectorStoreIndex
from llama_index.core.node_parser import SentenceSplitter from llama_index.core.node_parser import SentenceSplitter
# --- EMBEDDING MODEL --- # --- EMBEDDING MODEL ---
from llama_index.embeddings.huggingface import HuggingFaceEmbedding from llama_index.embeddings.huggingface import HuggingFaceEmbedding
import torch
# --- END EMBEDDING MODEL --- # --- END EMBEDDING MODEL ---
# Import EmlxReader from the new module # Import EmlxReader from the new module
from examples.email_data.LEANN_email_reader import EmlxReader from examples.email_data.LEANN_email_reader import EmlxReader
def create_and_save_index(mail_path: str, save_dir: str = "mail_index_embedded", max_count: int = 1000, include_html: bool = False):
def create_and_save_index(
mail_path: str,
save_dir: str = "mail_index_embedded",
max_count: int = 1000,
include_html: bool = False,
):
print("Creating index from mail data with embedded metadata...") print("Creating index from mail data with embedded metadata...")
documents = EmlxReader(include_html=include_html).load_data(mail_path, max_count=max_count) documents = EmlxReader(include_html=include_html).load_data(mail_path, max_count=max_count)
if not documents: if not documents:
@@ -30,7 +34,7 @@ def create_and_save_index(mail_path: str, save_dir: str = "mail_index_embedded",
# Use facebook/contriever as the embedder # Use facebook/contriever as the embedder
embed_model = HuggingFaceEmbedding(model_name="facebook/contriever") embed_model = HuggingFaceEmbedding(model_name="facebook/contriever")
# set on device # set on device
import torch
if torch.cuda.is_available(): if torch.cuda.is_available():
embed_model._model.to("cuda") embed_model._model.to("cuda")
# set mps # set mps
@@ -39,21 +43,19 @@ def create_and_save_index(mail_path: str, save_dir: str = "mail_index_embedded",
else: else:
embed_model._model.to("cpu") embed_model._model.to("cpu")
index = VectorStoreIndex.from_documents( index = VectorStoreIndex.from_documents(
documents, documents, transformations=[text_splitter], embed_model=embed_model
transformations=[text_splitter],
embed_model=embed_model
) )
os.makedirs(save_dir, exist_ok=True) os.makedirs(save_dir, exist_ok=True)
index.storage_context.persist(persist_dir=save_dir) index.storage_context.persist(persist_dir=save_dir)
print(f"Index saved to {save_dir}") print(f"Index saved to {save_dir}")
return index return index
def load_index(save_dir: str = "mail_index_embedded"): def load_index(save_dir: str = "mail_index_embedded"):
try: try:
storage_context = StorageContext.from_defaults(persist_dir=save_dir) storage_context = StorageContext.from_defaults(persist_dir=save_dir)
index = VectorStoreIndex.from_vector_store( index = VectorStoreIndex.from_vector_store(
storage_context.vector_store, storage_context.vector_store, storage_context=storage_context
storage_context=storage_context
) )
print(f"Index loaded from {save_dir}") print(f"Index loaded from {save_dir}")
return index return index
@@ -61,6 +63,7 @@ def load_index(save_dir: str = "mail_index_embedded"):
print(f"Error loading index: {e}") print(f"Error loading index: {e}")
return None return None
def query_index(index, query: str): def query_index(index, query: str):
if index is None: if index is None:
print("No index available for querying.") print("No index available for querying.")
@@ -70,18 +73,36 @@ def query_index(index, query: str):
print(f"Query: {query}") print(f"Query: {query}")
print(f"Response: {response}") print(f"Response: {response}")
def main(): def main():
# Parse command line arguments # Parse command line arguments
parser = argparse.ArgumentParser(description='LlamaIndex Mail Reader - Create and query email index') parser = argparse.ArgumentParser(
parser.add_argument('--mail-path', type=str, description="LlamaIndex Mail Reader - Create and query email index"
default="/Users/yichuan/Library/Mail/V10/0FCA0879-FD8C-4B7E-83BF-FDDA930791C5/[Gmail].mbox/All Mail.mbox/78BA5BE1-8819-4F9A-9613-EB63772F1DD0/Data/9/Messages", )
help='Path to mail data directory') parser.add_argument(
parser.add_argument('--save-dir', type=str, default="mail_index_embedded", "--mail-path",
help='Directory to store the index (default: mail_index_embedded)') type=str,
parser.add_argument('--max-emails', type=int, default=10000, default="/Users/yichuan/Library/Mail/V10/0FCA0879-FD8C-4B7E-83BF-FDDA930791C5/[Gmail].mbox/All Mail.mbox/78BA5BE1-8819-4F9A-9613-EB63772F1DD0/Data/9/Messages",
help='Maximum number of emails to process') help="Path to mail data directory",
parser.add_argument('--include-html', action='store_true', default=False, )
help='Include HTML content in email processing (default: False)') parser.add_argument(
"--save-dir",
type=str,
default="mail_index_embedded",
help="Directory to store the index (default: mail_index_embedded)",
)
parser.add_argument(
"--max-emails",
type=int,
default=10000,
help="Maximum number of emails to process",
)
parser.add_argument(
"--include-html",
action="store_true",
default=False,
help="Include HTML content in email processing (default: False)",
)
args = parser.parse_args() args = parser.parse_args()
@@ -93,16 +114,22 @@ def main():
index = load_index(save_dir) index = load_index(save_dir)
else: else:
print("Creating new index...") print("Creating new index...")
index = create_and_save_index(mail_path, save_dir, max_count=args.max_emails, include_html=args.include_html) index = create_and_save_index(
mail_path,
save_dir,
max_count=args.max_emails,
include_html=args.include_html,
)
if index: if index:
queries = [ queries = [
"Hows Berkeley Graduate Student Instructor", "Hows Berkeley Graduate Student Instructor",
"how's the icloud related advertisement saying", "how's the icloud related advertisement saying",
"Whats the number of class recommend to take per semester for incoming EECS students" "Whats the number of class recommend to take per semester for incoming EECS students",
] ]
for query in queries: for query in queries:
print("\n" + "="*50) print("\n" + "=" * 50)
query_index(index, query) query_index(index, query)
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View File

@@ -1,10 +1,11 @@
import argparse import argparse
from llama_index.core import SimpleDirectoryReader
from llama_index.core.node_parser import SentenceSplitter
import asyncio import asyncio
from pathlib import Path
import dotenv import dotenv
from leann.api import LeannBuilder, LeannChat from leann.api import LeannBuilder, LeannChat
from pathlib import Path from llama_index.core import SimpleDirectoryReader
from llama_index.core.node_parser import SentenceSplitter
dotenv.load_dotenv() dotenv.load_dotenv()
@@ -29,17 +30,22 @@ async def main(args):
all_texts = [] all_texts = []
for doc in documents: for doc in documents:
nodes = node_parser.get_nodes_from_documents([doc]) nodes = node_parser.get_nodes_from_documents([doc])
for node in nodes: if nodes:
all_texts.append(node.get_content()) all_texts.extend(node.get_content() for node in nodes)
print("--- Index directory not found, building new index ---") print("--- Index directory not found, building new index ---")
print("\n[PHASE 1] Building Leann index...") print("\n[PHASE 1] Building Leann index...")
# LeannBuilder now automatically detects normalized embeddings and sets appropriate distance metric
print(f"Using {args.embedding_model} with {args.embedding_mode} mode")
# Use HNSW backend for better macOS compatibility # Use HNSW backend for better macOS compatibility
builder = LeannBuilder( builder = LeannBuilder(
backend_name="hnsw", backend_name="hnsw",
embedding_model="facebook/contriever", embedding_model=args.embedding_model,
embedding_mode=args.embedding_mode,
# distance_metric is automatically set based on embedding model
graph_degree=32, graph_degree=32,
complexity=64, complexity=64,
is_compact=True, is_compact=True,
@@ -56,15 +62,25 @@ async def main(args):
else: else:
print(f"--- Using existing index at {INDEX_DIR} ---") print(f"--- Using existing index at {INDEX_DIR} ---")
print(f"\n[PHASE 2] Starting Leann chat session...") print("\n[PHASE 2] Starting Leann chat session...")
llm_config = {"type": "hf", "model": "Qwen/Qwen3-4B"} # Build llm_config based on command line arguments
llm_config = {"type": "ollama", "model": "qwen3:8b"} if args.llm == "simulated":
llm_config = {"type": "openai", "model": "gpt-4o"} llm_config = {"type": "simulated"}
elif args.llm == "ollama":
llm_config = {"type": "ollama", "model": args.model, "host": args.host}
elif args.llm == "hf":
llm_config = {"type": "hf", "model": args.model}
elif args.llm == "openai":
llm_config = {"type": "openai", "model": args.model}
else:
raise ValueError(f"Unknown LLM type: {args.llm}")
print(f"Using LLM: {args.llm} with model: {args.model if args.llm != 'simulated' else 'N/A'}")
chat = LeannChat(index_path=INDEX_PATH, llm_config=llm_config) chat = LeannChat(index_path=INDEX_PATH, llm_config=llm_config)
# query = ( # query = (
# "什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面任务令一般在什么城市颁发" # "什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面,任务令一般在什么城市颁发"
# ) # )
query = args.query query = args.query
@@ -74,22 +90,33 @@ async def main(args):
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(description="Run Leann Chat with various LLM backends.")
description="Run Leann Chat with various LLM backends."
)
parser.add_argument( parser.add_argument(
"--llm", "--llm",
type=str, type=str,
default="hf", default="openai",
choices=["simulated", "ollama", "hf", "openai"], choices=["simulated", "ollama", "hf", "openai"],
help="The LLM backend to use.", help="The LLM backend to use.",
) )
parser.add_argument( parser.add_argument(
"--model", "--model",
type=str, type=str,
default="Qwen/Qwen3-0.6B", default="gpt-4o",
help="The model name to use (e.g., 'llama3:8b' for ollama, 'deepseek-ai/deepseek-llm-7b-chat' for hf, 'gpt-4o' for openai).", help="The model name to use (e.g., 'llama3:8b' for ollama, 'deepseek-ai/deepseek-llm-7b-chat' for hf, 'gpt-4o' for openai).",
) )
parser.add_argument(
"--embedding-model",
type=str,
default="facebook/contriever",
help="The embedding model to use (e.g., 'facebook/contriever', 'text-embedding-3-small').",
)
parser.add_argument(
"--embedding-mode",
type=str,
default="sentence-transformers",
choices=["sentence-transformers", "openai", "mlx"],
help="The embedding backend mode.",
)
parser.add_argument( parser.add_argument(
"--host", "--host",
type=str, type=str,

View File

@@ -14,45 +14,52 @@ Key features:
- Document-level result consolidation - Document-level result consolidation
""" """
import numpy as np
from typing import List, Dict, Any, Tuple, Optional
from dataclasses import dataclass
from collections import defaultdict from collections import defaultdict
import json from dataclasses import dataclass
from typing import Any
import numpy as np
@dataclass @dataclass
class PatchResult: class PatchResult:
"""Represents a single patch search result.""" """Represents a single patch search result."""
patch_id: int patch_id: int
image_name: str image_name: str
image_path: str image_path: str
coordinates: Tuple[int, int, int, int] # (x1, y1, x2, y2) coordinates: tuple[int, int, int, int] # (x1, y1, x2, y2)
score: float score: float
attention_score: float attention_score: float
scale: float scale: float
metadata: Dict[str, Any] metadata: dict[str, Any]
@dataclass @dataclass
class AggregatedResult: class AggregatedResult:
"""Represents an aggregated document-level result.""" """Represents an aggregated document-level result."""
image_name: str image_name: str
image_path: str image_path: str
doc_score: float doc_score: float
patch_count: int patch_count: int
best_patch: PatchResult best_patch: PatchResult
all_patches: List[PatchResult] all_patches: list[PatchResult]
aggregation_method: str aggregation_method: str
spatial_clusters: Optional[List[List[PatchResult]]] = None spatial_clusters: list[list[PatchResult]] | None = None
class MultiVectorAggregator: class MultiVectorAggregator:
""" """
Aggregates multiple patch-level results into document-level results. Aggregates multiple patch-level results into document-level results.
""" """
def __init__(self, def __init__(
aggregation_method: str = "maxsim", self,
spatial_clustering: bool = True, aggregation_method: str = "maxsim",
cluster_distance_threshold: float = 100.0): spatial_clustering: bool = True,
cluster_distance_threshold: float = 100.0,
):
""" """
Initialize the aggregator. Initialize the aggregator.
@@ -65,9 +72,9 @@ class MultiVectorAggregator:
self.spatial_clustering = spatial_clustering self.spatial_clustering = spatial_clustering
self.cluster_distance_threshold = cluster_distance_threshold self.cluster_distance_threshold = cluster_distance_threshold
def aggregate_results(self, def aggregate_results(
search_results: List[Dict[str, Any]], self, search_results: list[dict[str, Any]], top_k: int = 10
top_k: int = 10) -> List[AggregatedResult]: ) -> list[AggregatedResult]:
""" """
Aggregate patch-level search results into document-level results. Aggregate patch-level search results into document-level results.
@@ -92,7 +99,7 @@ class MultiVectorAggregator:
score=result.score, score=result.score,
attention_score=metadata.get("attention_score", 0.0), attention_score=metadata.get("attention_score", 0.0),
scale=metadata.get("scale", 1.0), scale=metadata.get("scale", 1.0),
metadata=metadata metadata=metadata,
) )
image_groups[metadata["image_name"]].append(patch_result) image_groups[metadata["image_name"]].append(patch_result)
@@ -109,7 +116,9 @@ class MultiVectorAggregator:
aggregated_results.sort(key=lambda x: x.doc_score, reverse=True) aggregated_results.sort(key=lambda x: x.doc_score, reverse=True)
return aggregated_results[:top_k] return aggregated_results[:top_k]
def _aggregate_image_patches(self, image_name: str, patches: List[PatchResult]) -> AggregatedResult: def _aggregate_image_patches(
self, image_name: str, patches: list[PatchResult]
) -> AggregatedResult:
"""Aggregate patches for a single image.""" """Aggregate patches for a single image."""
if self.aggregation_method == "maxsim": if self.aggregation_method == "maxsim":
@@ -149,10 +158,10 @@ class MultiVectorAggregator:
best_patch=best_patch, best_patch=best_patch,
all_patches=sorted(patches, key=lambda p: p.score, reverse=True), all_patches=sorted(patches, key=lambda p: p.score, reverse=True),
aggregation_method=self.aggregation_method, aggregation_method=self.aggregation_method,
spatial_clusters=spatial_clusters spatial_clusters=spatial_clusters,
) )
def _cluster_patches_spatially(self, patches: List[PatchResult]) -> List[List[PatchResult]]: def _cluster_patches_spatially(self, patches: list[PatchResult]) -> list[list[PatchResult]]:
"""Cluster patches that are spatially close to each other.""" """Cluster patches that are spatially close to each other."""
if len(patches) <= 1: if len(patches) <= 1:
return [patches] return [patches]
@@ -180,50 +189,61 @@ class MultiVectorAggregator:
return sorted(clusters, key=lambda cluster: max(p.score for p in cluster), reverse=True) return sorted(clusters, key=lambda cluster: max(p.score for p in cluster), reverse=True)
def _is_patch_nearby(self, patch: PatchResult, cluster: List[PatchResult]) -> bool: def _is_patch_nearby(self, patch: PatchResult, cluster: list[PatchResult]) -> bool:
"""Check if a patch is spatially close to any patch in the cluster.""" """Check if a patch is spatially close to any patch in the cluster."""
patch_center = self._get_patch_center(patch.coordinates) patch_center = self._get_patch_center(patch.coordinates)
for cluster_patch in cluster: for cluster_patch in cluster:
cluster_center = self._get_patch_center(cluster_patch.coordinates) cluster_center = self._get_patch_center(cluster_patch.coordinates)
distance = np.sqrt((patch_center[0] - cluster_center[0])**2 + distance = np.sqrt(
(patch_center[1] - cluster_center[1])**2) (patch_center[0] - cluster_center[0]) ** 2
+ (patch_center[1] - cluster_center[1]) ** 2
)
if distance <= self.cluster_distance_threshold: if distance <= self.cluster_distance_threshold:
return True return True
return False return False
def _get_patch_center(self, coordinates: Tuple[int, int, int, int]) -> Tuple[float, float]: def _get_patch_center(self, coordinates: tuple[int, int, int, int]) -> tuple[float, float]:
"""Get center point of a patch.""" """Get center point of a patch."""
x1, y1, x2, y2 = coordinates x1, y1, x2, y2 = coordinates
return ((x1 + x2) / 2, (y1 + y2) / 2) return ((x1 + x2) / 2, (y1 + y2) / 2)
def print_aggregated_results(self, results: List[AggregatedResult], max_patches_per_doc: int = 3): def print_aggregated_results(
self, results: list[AggregatedResult], max_patches_per_doc: int = 3
):
"""Pretty print aggregated results.""" """Pretty print aggregated results."""
print(f"\n🔍 Aggregated Results (method: {self.aggregation_method})") print(f"\n🔍 Aggregated Results (method: {self.aggregation_method})")
print("=" * 80) print("=" * 80)
for i, result in enumerate(results): for i, result in enumerate(results):
print(f"\n{i+1}. {result.image_name}") print(f"\n{i + 1}. {result.image_name}")
print(f" Doc Score: {result.doc_score:.4f} | Patches: {result.patch_count}") print(f" Doc Score: {result.doc_score:.4f} | Patches: {result.patch_count}")
print(f" Path: {result.image_path}") print(f" Path: {result.image_path}")
# Show best patch # Show best patch
best = result.best_patch best = result.best_patch
print(f" 🌟 Best Patch: #{best.patch_id} at {best.coordinates} (score: {best.score:.4f})") print(
f" 🌟 Best Patch: #{best.patch_id} at {best.coordinates} (score: {best.score:.4f})"
)
# Show top patches # Show top patches
print(f" 📍 Top Patches:") print(" 📍 Top Patches:")
for j, patch in enumerate(result.all_patches[:max_patches_per_doc]): for j, patch in enumerate(result.all_patches[:max_patches_per_doc]):
print(f" {j+1}. Patch #{patch.patch_id}: {patch.score:.4f} at {patch.coordinates}") print(
f" {j + 1}. Patch #{patch.patch_id}: {patch.score:.4f} at {patch.coordinates}"
)
# Show spatial clusters if available # Show spatial clusters if available
if result.spatial_clusters and len(result.spatial_clusters) > 1: if result.spatial_clusters and len(result.spatial_clusters) > 1:
print(f" 🗂️ Spatial Clusters: {len(result.spatial_clusters)}") print(f" 🗂️ Spatial Clusters: {len(result.spatial_clusters)}")
for j, cluster in enumerate(result.spatial_clusters[:2]): # Show top 2 clusters for j, cluster in enumerate(result.spatial_clusters[:2]): # Show top 2 clusters
cluster_score = max(p.score for p in cluster) cluster_score = max(p.score for p in cluster)
print(f" Cluster {j+1}: {len(cluster)} patches (best: {cluster_score:.4f})") print(
f" Cluster {j + 1}: {len(cluster)} patches (best: {cluster_score:.4f})"
)
def demo_aggregation(): def demo_aggregation():
"""Demonstrate the multi-vector aggregation functionality.""" """Demonstrate the multi-vector aggregation functionality."""
@@ -240,80 +260,101 @@ def demo_aggregation():
# Simulate results for 2 images with multiple patches each # Simulate results for 2 images with multiple patches each
mock_results = [ mock_results = [
# Image 1: cats_and_kitchen.jpg - 4 patches # Image 1: cats_and_kitchen.jpg - 4 patches
MockResult(0.85, { MockResult(
"image_name": "cats_and_kitchen.jpg", 0.85,
"image_path": "/path/to/cats_and_kitchen.jpg", {
"patch_id": 3, "image_name": "cats_and_kitchen.jpg",
"coordinates": [100, 50, 224, 174], # Kitchen area "image_path": "/path/to/cats_and_kitchen.jpg",
"attention_score": 0.92, "patch_id": 3,
"scale": 1.0 "coordinates": [100, 50, 224, 174], # Kitchen area
}), "attention_score": 0.92,
MockResult(0.78, { "scale": 1.0,
"image_name": "cats_and_kitchen.jpg", },
"image_path": "/path/to/cats_and_kitchen.jpg", ),
"patch_id": 7, MockResult(
"coordinates": [200, 300, 324, 424], # Cat area 0.78,
"attention_score": 0.88, {
"scale": 1.0 "image_name": "cats_and_kitchen.jpg",
}), "image_path": "/path/to/cats_and_kitchen.jpg",
MockResult(0.72, { "patch_id": 7,
"image_name": "cats_and_kitchen.jpg", "coordinates": [200, 300, 324, 424], # Cat area
"image_path": "/path/to/cats_and_kitchen.jpg", "attention_score": 0.88,
"patch_id": 12, "scale": 1.0,
"coordinates": [150, 100, 274, 224], # Appliances },
"attention_score": 0.75, ),
"scale": 1.0 MockResult(
}), 0.72,
MockResult(0.65, { {
"image_name": "cats_and_kitchen.jpg", "image_name": "cats_and_kitchen.jpg",
"image_path": "/path/to/cats_and_kitchen.jpg", "image_path": "/path/to/cats_and_kitchen.jpg",
"patch_id": 15, "patch_id": 12,
"coordinates": [50, 250, 174, 374], # Furniture "coordinates": [150, 100, 274, 224], # Appliances
"attention_score": 0.70, "attention_score": 0.75,
"scale": 1.0 "scale": 1.0,
}), },
),
MockResult(
0.65,
{
"image_name": "cats_and_kitchen.jpg",
"image_path": "/path/to/cats_and_kitchen.jpg",
"patch_id": 15,
"coordinates": [50, 250, 174, 374], # Furniture
"attention_score": 0.70,
"scale": 1.0,
},
),
# Image 2: city_street.jpg - 3 patches # Image 2: city_street.jpg - 3 patches
MockResult(0.68, { MockResult(
"image_name": "city_street.jpg", 0.68,
"image_path": "/path/to/city_street.jpg", {
"patch_id": 2, "image_name": "city_street.jpg",
"coordinates": [300, 100, 424, 224], # Buildings "image_path": "/path/to/city_street.jpg",
"attention_score": 0.80, "patch_id": 2,
"scale": 1.0 "coordinates": [300, 100, 424, 224], # Buildings
}), "attention_score": 0.80,
MockResult(0.62, { "scale": 1.0,
"image_name": "city_street.jpg", },
"image_path": "/path/to/city_street.jpg", ),
"patch_id": 8, MockResult(
"coordinates": [100, 350, 224, 474], # Street level 0.62,
"attention_score": 0.75, {
"scale": 1.0 "image_name": "city_street.jpg",
}), "image_path": "/path/to/city_street.jpg",
MockResult(0.55, { "patch_id": 8,
"image_name": "city_street.jpg", "coordinates": [100, 350, 224, 474], # Street level
"image_path": "/path/to/city_street.jpg", "attention_score": 0.75,
"patch_id": 11, "scale": 1.0,
"coordinates": [400, 200, 524, 324], # Sky area },
"attention_score": 0.60, ),
"scale": 1.0 MockResult(
}), 0.55,
{
"image_name": "city_street.jpg",
"image_path": "/path/to/city_street.jpg",
"patch_id": 11,
"coordinates": [400, 200, 524, 324], # Sky area
"attention_score": 0.60,
"scale": 1.0,
},
),
] ]
# Test different aggregation methods # Test different aggregation methods
methods = ["maxsim", "voting", "weighted", "mean"] methods = ["maxsim", "voting", "weighted", "mean"]
for method in methods: for method in methods:
print(f"\n{'='*20} {method.upper()} AGGREGATION {'='*20}") print(f"\n{'=' * 20} {method.upper()} AGGREGATION {'=' * 20}")
aggregator = MultiVectorAggregator( aggregator = MultiVectorAggregator(
aggregation_method=method, aggregation_method=method,
spatial_clustering=True, spatial_clustering=True,
cluster_distance_threshold=100.0 cluster_distance_threshold=100.0,
) )
aggregated = aggregator.aggregate_results(mock_results, top_k=5) aggregated = aggregator.aggregate_results(mock_results, top_k=5)
aggregator.print_aggregated_results(aggregated) aggregator.print_aggregated_results(aggregated)
if __name__ == "__main__": if __name__ == "__main__":
demo_aggregation() demo_aggregation()

View File

@@ -6,13 +6,15 @@ Complete example showing how to build and search with OpenAI embeddings using HN
""" """
import os import os
import dotenv
from pathlib import Path from pathlib import Path
import dotenv
from leann.api import LeannBuilder, LeannSearcher from leann.api import LeannBuilder, LeannSearcher
# Load environment variables # Load environment variables
dotenv.load_dotenv() dotenv.load_dotenv()
def main(): def main():
# Check if OpenAI API key is available # Check if OpenAI API key is available
api_key = os.getenv("OPENAI_API_KEY") api_key = os.getenv("OPENAI_API_KEY")
@@ -33,13 +35,13 @@ def main():
"Artificial intelligence aims to create machines that can perform human-like tasks.", "Artificial intelligence aims to create machines that can perform human-like tasks.",
"Python is a popular programming language used extensively in data science and AI.", "Python is a popular programming language used extensively in data science and AI.",
"Neural networks are inspired by the structure and function of the human brain.", "Neural networks are inspired by the structure and function of the human brain.",
"Big data refers to extremely large datasets that require special tools to process." "Big data refers to extremely large datasets that require special tools to process.",
] ]
INDEX_DIR = Path("./simple_openai_test_index") INDEX_DIR = Path("./simple_openai_test_index")
INDEX_PATH = str(INDEX_DIR / "simple_test.leann") INDEX_PATH = str(INDEX_DIR / "simple_test.leann")
print(f"\n=== Building Index with OpenAI Embeddings ===") print("\n=== Building Index with OpenAI Embeddings ===")
print(f"Index path: {INDEX_PATH}") print(f"Index path: {INDEX_PATH}")
try: try:
@@ -49,10 +51,10 @@ def main():
embedding_model="text-embedding-3-small", embedding_model="text-embedding-3-small",
embedding_mode="openai", embedding_mode="openai",
# HNSW settings for OpenAI embeddings # HNSW settings for OpenAI embeddings
M=16, # Smaller graph degree M=16, # Smaller graph degree
efConstruction=64, # Smaller construction complexity efConstruction=64, # Smaller construction complexity
is_compact=True, # Enable compact storage for recompute is_compact=True, # Enable compact storage for recompute
is_recompute=True, # MUST enable for OpenAI embeddings is_recompute=True, # MUST enable for OpenAI embeddings
num_threads=1, num_threads=1,
) )
@@ -63,15 +65,16 @@ def main():
print("Building index...") print("Building index...")
builder.build_index(INDEX_PATH) builder.build_index(INDEX_PATH)
print(f"✅ Index built successfully!") print("✅ Index built successfully!")
except Exception as e: except Exception as e:
print(f"❌ Error building index: {e}") print(f"❌ Error building index: {e}")
import traceback import traceback
traceback.print_exc() traceback.print_exc()
return False return False
print(f"\n=== Testing Search ===") print("\n=== Testing Search ===")
try: try:
searcher = LeannSearcher(INDEX_PATH) searcher = LeannSearcher(INDEX_PATH)
@@ -79,7 +82,7 @@ def main():
test_queries = [ test_queries = [
"What is machine learning?", "What is machine learning?",
"How do neural networks work?", "How do neural networks work?",
"Programming languages for data science" "Programming languages for data science",
] ]
for query in test_queries: for query in test_queries:
@@ -88,21 +91,23 @@ def main():
print(f" Found {len(results)} results:") print(f" Found {len(results)} results:")
for i, result in enumerate(results): for i, result in enumerate(results):
print(f" {i+1}. Score: {result.score:.4f}") print(f" {i + 1}. Score: {result.score:.4f}")
print(f" Text: {result.text[:80]}...") print(f" Text: {result.text[:80]}...")
print(f"\n✅ Search test completed successfully!") print("\n✅ Search test completed successfully!")
return True return True
except Exception as e: except Exception as e:
print(f"❌ Error during search: {e}") print(f"❌ Error during search: {e}")
import traceback import traceback
traceback.print_exc() traceback.print_exc()
return False return False
if __name__ == "__main__": if __name__ == "__main__":
success = main() success = main()
if success: if success:
print(f"\n🎉 Simple OpenAI index test completed successfully!") print("\n🎉 Simple OpenAI index test completed successfully!")
else: else:
print(f"\n💥 Simple OpenAI index test failed!") print("\n💥 Simple OpenAI index test failed!")

View File

@@ -1,18 +1,23 @@
import asyncio import asyncio
from leann.api import LeannChat
from pathlib import Path from pathlib import Path
from leann.api import LeannChat
INDEX_DIR = Path("./test_pdf_index_huawei") INDEX_DIR = Path("./test_pdf_index_huawei")
INDEX_PATH = str(INDEX_DIR / "pdf_documents.leann") INDEX_PATH = str(INDEX_DIR / "pdf_documents.leann")
async def main(): async def main():
print(f"\n[PHASE 2] Starting Leann chat session...") print("\n[PHASE 2] Starting Leann chat session...")
chat = LeannChat(index_path=INDEX_PATH) chat = LeannChat(index_path=INDEX_PATH)
query = "What is the main idea of RL and give me 5 exapmle of classic RL algorithms?" query = "What is the main idea of RL and give me 5 exapmle of classic RL algorithms?"
query = "Based on the paper, what are the main techniques LEANN explores to reduce the storage overhead and DLPM explore to achieve Fairness and Efiiciency trade-off?" query = "Based on the paper, what are the main techniques LEANN explores to reduce the storage overhead and DLPM explore to achieve Fairness and Efiiciency trade-off?"
# query = "什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面任务令一般在什么城市颁发" # query = "什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面,任务令一般在什么城市颁发"
response = chat.ask(query,top_k=20,recompute_beighbor_embeddings=True,complexity=32,beam_width=1) response = chat.ask(
query, top_k=20, recompute_beighbor_embeddings=True, complexity=32, beam_width=1
)
print(f"\n[PHASE 2] Response: {response}") print(f"\n[PHASE 2] Response: {response}")
if __name__ == "__main__": if __name__ == "__main__":
asyncio.run(main()) asyncio.run(main())

View File

@@ -5,24 +5,21 @@ It correctly compares results by fetching the text content for both the new sear
results and the golden standard results, making the comparison robust to ID changes. results and the golden standard results, making the comparison robust to ID changes.
""" """
import json
import argparse import argparse
import json
import sys
import time import time
from pathlib import Path from pathlib import Path
import sys
import numpy as np
from typing import List
from leann.api import LeannSearcher, LeannBuilder import numpy as np
from leann.api import LeannBuilder, LeannSearcher
def download_data_if_needed(data_root: Path, download_embeddings: bool = False): def download_data_if_needed(data_root: Path, download_embeddings: bool = False):
"""Checks if the data directory exists, and if not, downloads it from HF Hub.""" """Checks if the data directory exists, and if not, downloads it from HF Hub."""
if not data_root.exists(): if not data_root.exists():
print(f"Data directory '{data_root}' not found.") print(f"Data directory '{data_root}' not found.")
print( print("Downloading evaluation data from Hugging Face Hub... (this may take a moment)")
"Downloading evaluation data from Hugging Face Hub... (this may take a moment)"
)
try: try:
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
@@ -63,7 +60,7 @@ def download_data_if_needed(data_root: Path, download_embeddings: bool = False):
sys.exit(1) sys.exit(1)
def download_embeddings_if_needed(data_root: Path, dataset_type: str = None): def download_embeddings_if_needed(data_root: Path, dataset_type: str | None = None):
"""Download embeddings files specifically.""" """Download embeddings files specifically."""
embeddings_dir = data_root / "embeddings" embeddings_dir = data_root / "embeddings"
@@ -101,7 +98,7 @@ def download_embeddings_if_needed(data_root: Path, dataset_type: str = None):
# --- Helper Function to get Golden Passages --- # --- Helper Function to get Golden Passages ---
def get_golden_texts(searcher: LeannSearcher, golden_ids: List[int]) -> set: def get_golden_texts(searcher: LeannSearcher, golden_ids: list[int]) -> set:
""" """
Retrieves the text for golden passage IDs directly from the LeannSearcher's Retrieves the text for golden passage IDs directly from the LeannSearcher's
passage manager. passage manager.
@@ -113,24 +110,20 @@ def get_golden_texts(searcher: LeannSearcher, golden_ids: List[int]) -> set:
passage_data = searcher.passage_manager.get_passage(str(gid)) passage_data = searcher.passage_manager.get_passage(str(gid))
golden_texts.add(passage_data["text"]) golden_texts.add(passage_data["text"])
except KeyError: except KeyError:
print( print(f"Warning: Golden passage ID '{gid}' not found in the index's passage data.")
f"Warning: Golden passage ID '{gid}' not found in the index's passage data."
)
return golden_texts return golden_texts
def load_queries(file_path: Path) -> List[str]: def load_queries(file_path: Path) -> list[str]:
queries = [] queries = []
with open(file_path, "r", encoding="utf-8") as f: with open(file_path, encoding="utf-8") as f:
for line in f: for line in f:
data = json.loads(line) data = json.loads(line)
queries.append(data["query"]) queries.append(data["query"])
return queries return queries
def build_index_from_embeddings( def build_index_from_embeddings(embeddings_file: str, output_path: str, backend: str = "hnsw"):
embeddings_file: str, output_path: str, backend: str = "hnsw"
):
""" """
Build a LEANN index from pre-computed embeddings. Build a LEANN index from pre-computed embeddings.
@@ -173,9 +166,7 @@ def build_index_from_embeddings(
def main(): def main():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(description="Run recall evaluation on a LEANN index.")
description="Run recall evaluation on a LEANN index."
)
parser.add_argument( parser.add_argument(
"index_path", "index_path",
type=str, type=str,
@@ -202,9 +193,7 @@ def main():
parser.add_argument( parser.add_argument(
"--num-queries", type=int, default=10, help="Number of queries to evaluate." "--num-queries", type=int, default=10, help="Number of queries to evaluate."
) )
parser.add_argument( parser.add_argument("--top-k", type=int, default=3, help="The 'k' value for recall@k.")
"--top-k", type=int, default=3, help="The 'k' value for recall@k."
)
parser.add_argument( parser.add_argument(
"--ef-search", type=int, default=120, help="The 'efSearch' parameter for HNSW." "--ef-search", type=int, default=120, help="The 'efSearch' parameter for HNSW."
) )
@@ -219,9 +208,7 @@ def main():
# Download data based on mode # Download data based on mode
if args.mode == "build": if args.mode == "build":
# For building mode, we need embeddings # For building mode, we need embeddings
download_data_if_needed( download_data_if_needed(data_root, download_embeddings=False) # Basic data first
data_root, download_embeddings=False
) # Basic data first
# Auto-detect dataset type and download embeddings # Auto-detect dataset type and download embeddings
if args.embeddings_file: if args.embeddings_file:
@@ -262,9 +249,7 @@ def main():
print(f"Index built successfully: {built_index_path}") print(f"Index built successfully: {built_index_path}")
# Ask if user wants to run evaluation # Ask if user wants to run evaluation
eval_response = ( eval_response = input("Run evaluation on the built index? (y/n): ").strip().lower()
input("Run evaluation on the built index? (y/n): ").strip().lower()
)
if eval_response != "y": if eval_response != "y":
print("Index building complete. Exiting.") print("Index building complete. Exiting.")
return return
@@ -293,12 +278,8 @@ def main():
break break
if not args.index_path: if not args.index_path:
print( print("No indices found. The data download should have included pre-built indices.")
"No indices found. The data download should have included pre-built indices." print("Please check the data/indices/ directory or provide --index-path manually.")
)
print(
"Please check the data/indices/ directory or provide --index-path manually."
)
sys.exit(1) sys.exit(1)
# Detect dataset type from index path to select the correct ground truth # Detect dataset type from index path to select the correct ground truth
@@ -310,14 +291,10 @@ def main():
else: else:
# Fallback: try to infer from the index directory name # Fallback: try to infer from the index directory name
dataset_type = Path(args.index_path).name dataset_type = Path(args.index_path).name
print( print(f"WARNING: Could not detect dataset type from path, inferred '{dataset_type}'.")
f"WARNING: Could not detect dataset type from path, inferred '{dataset_type}'."
)
queries_file = data_root / "queries" / "nq_open.jsonl" queries_file = data_root / "queries" / "nq_open.jsonl"
golden_results_file = ( golden_results_file = data_root / "ground_truth" / dataset_type / "flat_results_nq_k3.json"
data_root / "ground_truth" / dataset_type / "flat_results_nq_k3.json"
)
print(f"INFO: Detected dataset type: {dataset_type}") print(f"INFO: Detected dataset type: {dataset_type}")
print(f"INFO: Using queries file: {queries_file}") print(f"INFO: Using queries file: {queries_file}")
@@ -327,7 +304,7 @@ def main():
searcher = LeannSearcher(args.index_path) searcher = LeannSearcher(args.index_path)
queries = load_queries(queries_file) queries = load_queries(queries_file)
with open(golden_results_file, "r") as f: with open(golden_results_file) as f:
golden_results_data = json.load(f) golden_results_data = json.load(f)
num_eval_queries = min(args.num_queries, len(queries)) num_eval_queries = min(args.num_queries, len(queries))
@@ -339,9 +316,7 @@ def main():
for i in range(num_eval_queries): for i in range(num_eval_queries):
start_time = time.time() start_time = time.time()
new_results = searcher.search( new_results = searcher.search(queries[i], top_k=args.top_k, ef=args.ef_search)
queries[i], top_k=args.top_k, ef=args.ef_search
)
search_times.append(time.time() - start_time) search_times.append(time.time() - start_time)
# Correct Recall Calculation: Based on TEXT content # Correct Recall Calculation: Based on TEXT content

View File

@@ -4,13 +4,20 @@ Run: uv run python examples/simple_demo.py
""" """
import argparse import argparse
from leann import LeannBuilder, LeannSearcher, LeannChat
from leann import LeannBuilder, LeannChat, LeannSearcher
def main(): def main():
parser = argparse.ArgumentParser(description="Simple demo of Leann with selectable embedding models.") parser = argparse.ArgumentParser(
parser.add_argument("--embedding_model", type=str, default="sentence-transformers/all-mpnet-base-v2", description="Simple demo of Leann with selectable embedding models."
help="The embedding model to use, e.g., 'sentence-transformers/all-mpnet-base-v2' or 'text-embedding-ada-002'.") )
parser.add_argument(
"--embedding_model",
type=str,
default="sentence-transformers/all-mpnet-base-v2",
help="The embedding model to use, e.g., 'sentence-transformers/all-mpnet-base-v2' or 'text-embedding-ada-002'.",
)
args = parser.parse_args() args = parser.parse_args()
print(f"=== Leann Simple Demo with {args.embedding_model} ===") print(f"=== Leann Simple Demo with {args.embedding_model} ===")

View File

@@ -1,13 +1,11 @@
import os
import asyncio
import dotenv
import argparse import argparse
import asyncio
import os
from pathlib import Path from pathlib import Path
from typing import List, Any, Optional
from leann.api import LeannBuilder, LeannSearcher, LeannChat import dotenv
from leann.api import LeannBuilder, LeannChat
from llama_index.core.node_parser import SentenceSplitter from llama_index.core.node_parser import SentenceSplitter
import requests
import time
dotenv.load_dotenv() dotenv.load_dotenv()
@@ -16,7 +14,7 @@ DEFAULT_WECHAT_EXPORT_DIR = "./wechat_export_direct"
def create_leann_index_from_multiple_wechat_exports( def create_leann_index_from_multiple_wechat_exports(
export_dirs: List[Path], export_dirs: list[Path],
index_path: str = "wechat_history_index.leann", index_path: str = "wechat_history_index.leann",
max_count: int = -1, max_count: int = -1,
): ):
@@ -38,15 +36,13 @@ def create_leann_index_from_multiple_wechat_exports(
INDEX_DIR = Path(index_path).parent INDEX_DIR = Path(index_path).parent
if not INDEX_DIR.exists(): if not INDEX_DIR.exists():
print(f"--- Index directory not found, building new index ---") print("--- Index directory not found, building new index ---")
all_documents = [] all_documents = []
total_processed = 0 total_processed = 0
# Process each WeChat export directory # Process each WeChat export directory
for i, export_dir in enumerate(export_dirs): for i, export_dir in enumerate(export_dirs):
print( print(f"\nProcessing WeChat export {i + 1}/{len(export_dirs)}: {export_dir}")
f"\nProcessing WeChat export {i + 1}/{len(export_dirs)}: {export_dir}"
)
try: try:
documents = reader.load_data( documents = reader.load_data(
@@ -86,7 +82,12 @@ def create_leann_index_from_multiple_wechat_exports(
# Split the document into chunks # Split the document into chunks
nodes = text_splitter.get_nodes_from_documents([doc]) nodes = text_splitter.get_nodes_from_documents([doc])
for node in nodes: for node in nodes:
text = '[Contact] means the message is from: ' + doc.metadata["contact_name"] + '\n' + node.get_content() text = (
"[Contact] means the message is from: "
+ doc.metadata["contact_name"]
+ "\n"
+ node.get_content()
)
all_texts.append(text) all_texts.append(text)
print( print(
@@ -94,12 +95,12 @@ def create_leann_index_from_multiple_wechat_exports(
) )
# Create LEANN index directory # Create LEANN index directory
print(f"--- Index directory not found, building new index ---") print("--- Index directory not found, building new index ---")
INDEX_DIR.mkdir(exist_ok=True) INDEX_DIR.mkdir(exist_ok=True)
print(f"--- Building new LEANN index ---") print("--- Building new LEANN index ---")
print(f"\n[PHASE 1] Building Leann index...") print("\n[PHASE 1] Building Leann index...")
# Use HNSW backend for better macOS compatibility # Use HNSW backend for better macOS compatibility
builder = LeannBuilder( builder = LeannBuilder(
@@ -125,7 +126,7 @@ def create_leann_index_from_multiple_wechat_exports(
def create_leann_index( def create_leann_index(
export_dir: str = None, export_dir: str | None = None,
index_path: str = "wechat_history_index.leann", index_path: str = "wechat_history_index.leann",
max_count: int = 1000, max_count: int = 1000,
): ):
@@ -141,12 +142,12 @@ def create_leann_index(
INDEX_DIR = Path(index_path).parent INDEX_DIR = Path(index_path).parent
if not INDEX_DIR.exists(): if not INDEX_DIR.exists():
print(f"--- Index directory not found, building new index ---") print("--- Index directory not found, building new index ---")
INDEX_DIR.mkdir(exist_ok=True) INDEX_DIR.mkdir(exist_ok=True)
print(f"--- Building new LEANN index ---") print("--- Building new LEANN index ---")
print(f"\n[PHASE 1] Building Leann index...") print("\n[PHASE 1] Building Leann index...")
# Load documents using WeChatHistoryReader from history_data # Load documents using WeChatHistoryReader from history_data
from history_data.wechat_history import WeChatHistoryReader from history_data.wechat_history import WeChatHistoryReader
@@ -179,12 +180,12 @@ def create_leann_index(
print(f"Created {len(all_texts)} text chunks from {len(documents)} documents") print(f"Created {len(all_texts)} text chunks from {len(documents)} documents")
# Create LEANN index directory # Create LEANN index directory
print(f"--- Index directory not found, building new index ---") print("--- Index directory not found, building new index ---")
INDEX_DIR.mkdir(exist_ok=True) INDEX_DIR.mkdir(exist_ok=True)
print(f"--- Building new LEANN index ---") print("--- Building new LEANN index ---")
print(f"\n[PHASE 1] Building Leann index...") print("\n[PHASE 1] Building Leann index...")
# Use HNSW backend for better macOS compatibility # Use HNSW backend for better macOS compatibility
builder = LeannBuilder( builder = LeannBuilder(
@@ -217,7 +218,7 @@ async def query_leann_index(index_path: str, query: str):
index_path: Path to the LEANN index index_path: Path to the LEANN index
query: The query string query: The query string
""" """
print(f"\n[PHASE 2] Starting Leann chat session...") print("\n[PHASE 2] Starting Leann chat session...")
chat = LeannChat(index_path=index_path) chat = LeannChat(index_path=index_path)
print(f"You: {query}") print(f"You: {query}")
@@ -307,7 +308,7 @@ async def main():
else: else:
# Example queries # Example queries
queries = [ queries = [
"我想买魔术师约翰逊的球衣给我一些对应聊天记录?", "我想买魔术师约翰逊的球衣,给我一些对应聊天记录?",
] ]
for query in queries: for query in queries:

View File

@@ -1 +0,0 @@

View File

@@ -1 +1 @@
from . import diskann_backend from . import diskann_backend as diskann_backend

View File

@@ -1,20 +1,20 @@
import numpy as np import contextlib
import logging
import os import os
import struct import struct
import sys import sys
from pathlib import Path from pathlib import Path
from typing import Dict, Any, List, Literal, Optional from typing import Any, Literal
import contextlib
import logging import numpy as np
import psutil
from leann.searcher_base import BaseSearcher
from leann.registry import register_backend
from leann.interface import ( from leann.interface import (
LeannBackendFactoryInterface,
LeannBackendBuilderInterface, LeannBackendBuilderInterface,
LeannBackendFactoryInterface,
LeannBackendSearcherInterface, LeannBackendSearcherInterface,
) )
from leann.registry import register_backend
from leann.searcher_base import BaseSearcher
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -85,6 +85,43 @@ def _write_vectors_to_bin(data: np.ndarray, file_path: Path):
f.write(data.tobytes()) f.write(data.tobytes())
def _calculate_smart_memory_config(data: np.ndarray) -> tuple[float, float]:
"""
Calculate smart memory configuration for DiskANN based on data size and system specs.
Args:
data: The embedding data array
Returns:
tuple: (search_memory_maximum, build_memory_maximum) in GB
"""
num_vectors, dim = data.shape
# Calculate embedding storage size
embedding_size_bytes = num_vectors * dim * 4 # float32 = 4 bytes
embedding_size_gb = embedding_size_bytes / (1024**3)
# search_memory_maximum: 1/10 of embedding size for optimal PQ compression
# This controls Product Quantization size - smaller means more compression
search_memory_gb = max(0.1, embedding_size_gb / 10) # At least 100MB
# build_memory_maximum: Based on available system RAM for sharding control
# This controls how much memory DiskANN uses during index construction
available_memory_gb = psutil.virtual_memory().available / (1024**3)
total_memory_gb = psutil.virtual_memory().total / (1024**3)
# Use 50% of available memory, but at least 2GB and at most 75% of total
build_memory_gb = max(2.0, min(available_memory_gb * 0.5, total_memory_gb * 0.75))
logger.info(
f"Smart memory config - Data: {embedding_size_gb:.2f}GB, "
f"Search mem: {search_memory_gb:.2f}GB (PQ control), "
f"Build mem: {build_memory_gb:.2f}GB (sharding control)"
)
return search_memory_gb, build_memory_gb
@register_backend("diskann") @register_backend("diskann")
class DiskannBackend(LeannBackendFactoryInterface): class DiskannBackend(LeannBackendFactoryInterface):
@staticmethod @staticmethod
@@ -100,7 +137,7 @@ class DiskannBuilder(LeannBackendBuilderInterface):
def __init__(self, **kwargs): def __init__(self, **kwargs):
self.build_params = kwargs self.build_params = kwargs
def build(self, data: np.ndarray, ids: List[str], index_path: str, **kwargs): def build(self, data: np.ndarray, ids: list[str], index_path: str, **kwargs):
path = Path(index_path) path = Path(index_path)
index_dir = path.parent index_dir = path.parent
index_prefix = path.stem index_prefix = path.stem
@@ -122,6 +159,16 @@ class DiskannBuilder(LeannBackendBuilderInterface):
f"Unsupported distance_metric '{build_kwargs.get('distance_metric', 'unknown')}'." f"Unsupported distance_metric '{build_kwargs.get('distance_metric', 'unknown')}'."
) )
# Calculate smart memory configuration if not explicitly provided
if (
"search_memory_maximum" not in build_kwargs
or "build_memory_maximum" not in build_kwargs
):
smart_search_mem, smart_build_mem = _calculate_smart_memory_config(data)
else:
smart_search_mem = build_kwargs.get("search_memory_maximum", 4.0)
smart_build_mem = build_kwargs.get("build_memory_maximum", 8.0)
try: try:
from . import _diskannpy as diskannpy # type: ignore from . import _diskannpy as diskannpy # type: ignore
@@ -132,8 +179,8 @@ class DiskannBuilder(LeannBackendBuilderInterface):
index_prefix, index_prefix,
build_kwargs.get("complexity", 64), build_kwargs.get("complexity", 64),
build_kwargs.get("graph_degree", 32), build_kwargs.get("graph_degree", 32),
build_kwargs.get("search_memory_maximum", 4.0), build_kwargs.get("search_memory_maximum", smart_search_mem),
build_kwargs.get("build_memory_maximum", 8.0), build_kwargs.get("build_memory_maximum", smart_build_mem),
build_kwargs.get("num_threads", 8), build_kwargs.get("num_threads", 8),
build_kwargs.get("pq_disk_bytes", 0), build_kwargs.get("pq_disk_bytes", 0),
"", "",
@@ -164,18 +211,44 @@ class DiskannSearcher(BaseSearcher):
self.num_threads = kwargs.get("num_threads", 8) self.num_threads = kwargs.get("num_threads", 8)
fake_zmq_port = 6666 # For DiskANN, we need to reinitialize the index when zmq_port changes
# Store the initialization parameters for later use
full_index_prefix = str(self.index_dir / self.index_path.stem) full_index_prefix = str(self.index_dir / self.index_path.stem)
self._index = diskannpy.StaticDiskFloatIndex( self._init_params = {
metric_enum, "metric_enum": metric_enum,
full_index_prefix, "full_index_prefix": full_index_prefix,
self.num_threads, "num_threads": self.num_threads,
kwargs.get("num_nodes_to_cache", 0), "num_nodes_to_cache": kwargs.get("num_nodes_to_cache", 0),
1, "cache_mechanism": 1,
fake_zmq_port, # Initial port, can be updated at runtime "pq_prefix": "",
"", "partition_prefix": "",
"", }
) self._diskannpy = diskannpy
self._current_zmq_port = None
self._index = None
logger.debug("DiskANN searcher initialized (index will be loaded on first search)")
def _ensure_index_loaded(self, zmq_port: int):
"""Ensure the index is loaded with the correct zmq_port."""
if self._index is None or self._current_zmq_port != zmq_port:
# Need to (re)load the index with the correct zmq_port
with suppress_cpp_output_if_needed():
if self._index is not None:
logger.debug(f"Reloading DiskANN index with new zmq_port: {zmq_port}")
else:
logger.debug(f"Loading DiskANN index with zmq_port: {zmq_port}")
self._index = self._diskannpy.StaticDiskFloatIndex(
self._init_params["metric_enum"],
self._init_params["full_index_prefix"],
self._init_params["num_threads"],
self._init_params["num_nodes_to_cache"],
self._init_params["cache_mechanism"],
zmq_port,
self._init_params["pq_prefix"],
self._init_params["partition_prefix"],
)
self._current_zmq_port = zmq_port
def search( def search(
self, self,
@@ -186,11 +259,11 @@ class DiskannSearcher(BaseSearcher):
prune_ratio: float = 0.0, prune_ratio: float = 0.0,
recompute_embeddings: bool = False, recompute_embeddings: bool = False,
pruning_strategy: Literal["global", "local", "proportional"] = "global", pruning_strategy: Literal["global", "local", "proportional"] = "global",
zmq_port: Optional[int] = None, zmq_port: int | None = None,
batch_recompute: bool = False, batch_recompute: bool = False,
dedup_node_dis: bool = False, dedup_node_dis: bool = False,
**kwargs, **kwargs,
) -> Dict[str, Any]: ) -> dict[str, Any]:
""" """
Search for nearest neighbors using DiskANN index. Search for nearest neighbors using DiskANN index.
@@ -213,18 +286,15 @@ class DiskannSearcher(BaseSearcher):
Returns: Returns:
Dict with 'labels' (list of lists) and 'distances' (ndarray) Dict with 'labels' (list of lists) and 'distances' (ndarray)
""" """
# Handle zmq_port compatibility: DiskANN can now update port at runtime # Handle zmq_port compatibility: Ensure index is loaded with correct port
if recompute_embeddings: if recompute_embeddings:
if zmq_port is None: if zmq_port is None:
raise ValueError( raise ValueError("zmq_port must be provided if recompute_embeddings is True")
"zmq_port must be provided if recompute_embeddings is True" self._ensure_index_loaded(zmq_port)
) else:
current_port = self._index.get_zmq_port() # If not recomputing, we still need an index, use a default port
if zmq_port != current_port: if self._index is None:
logger.debug( self._ensure_index_loaded(6666) # Default port when not recomputing
f"Updating DiskANN zmq_port from {current_port} to {zmq_port}"
)
self._index.set_zmq_port(zmq_port)
# DiskANN doesn't support "proportional" strategy # DiskANN doesn't support "proportional" strategy
if pruning_strategy == "proportional": if pruning_strategy == "proportional":
@@ -259,8 +329,6 @@ class DiskannSearcher(BaseSearcher):
use_global_pruning, use_global_pruning,
) )
string_labels = [ string_labels = [[str(int_label) for int_label in batch_labels] for batch_labels in labels]
[str(int_label) for int_label in batch_labels] for batch_labels in labels
]
return {"labels": string_labels, "distances": distances} return {"labels": string_labels, "distances": distances}

View File

@@ -3,16 +3,16 @@ DiskANN-specific embedding server
""" """
import argparse import argparse
import json
import logging
import os
import sys
import threading import threading
import time import time
import os
import zmq
import numpy as np
import json
from pathlib import Path from pathlib import Path
from typing import Optional
import sys import numpy as np
import logging import zmq
# Set up logging based on environment variable # Set up logging based on environment variable
LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper() LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
@@ -32,10 +32,11 @@ if not logger.handlers:
def create_diskann_embedding_server( def create_diskann_embedding_server(
passages_file: Optional[str] = None, passages_file: str | None = None,
zmq_port: int = 5555, zmq_port: int = 5555,
model_name: str = "sentence-transformers/all-mpnet-base-v2", model_name: str = "sentence-transformers/all-mpnet-base-v2",
embedding_mode: str = "sentence-transformers", embedding_mode: str = "sentence-transformers",
distance_metric: str = "l2",
): ):
""" """
Create and start a ZMQ-based embedding server for DiskANN backend. Create and start a ZMQ-based embedding server for DiskANN backend.
@@ -50,8 +51,8 @@ def create_diskann_embedding_server(
sys.path.insert(0, str(leann_core_path)) sys.path.insert(0, str(leann_core_path))
try: try:
from leann.embedding_compute import compute_embeddings
from leann.api import PassageManager from leann.api import PassageManager
from leann.embedding_compute import compute_embeddings
logger.info("Successfully imported unified embedding computation module") logger.info("Successfully imported unified embedding computation module")
except ImportError as e: except ImportError as e:
@@ -76,7 +77,7 @@ def create_diskann_embedding_server(
raise ValueError("Only metadata files (.meta.json) are supported") raise ValueError("Only metadata files (.meta.json) are supported")
# Load metadata to get passage sources # Load metadata to get passage sources
with open(passages_file, "r") as f: with open(passages_file) as f:
meta = json.load(f) meta = json.load(f)
passages = PassageManager(meta["passage_sources"]) passages = PassageManager(meta["passage_sources"])
@@ -150,9 +151,7 @@ def create_diskann_embedding_server(
): ):
texts = request texts = request
is_text_request = True is_text_request = True
logger.info( logger.info(f"✅ MSGPACK: Direct text request for {len(texts)} texts")
f"✅ MSGPACK: Direct text request for {len(texts)} texts"
)
else: else:
raise ValueError("Not a valid msgpack text request") raise ValueError("Not a valid msgpack text request")
except Exception as msgpack_error: except Exception as msgpack_error:
@@ -167,9 +166,7 @@ def create_diskann_embedding_server(
passage_data = passages.get_passage(str(nid)) passage_data = passages.get_passage(str(nid))
txt = passage_data["text"] txt = passage_data["text"]
if not txt: if not txt:
raise RuntimeError( raise RuntimeError(f"FATAL: Empty text for passage ID {nid}")
f"FATAL: Empty text for passage ID {nid}"
)
texts.append(txt) texts.append(txt)
except KeyError as e: except KeyError as e:
logger.error(f"Passage ID {nid} not found: {e}") logger.error(f"Passage ID {nid} not found: {e}")
@@ -180,9 +177,7 @@ def create_diskann_embedding_server(
# Debug logging # Debug logging
logger.debug(f"Processing {len(texts)} texts") logger.debug(f"Processing {len(texts)} texts")
logger.debug( logger.debug(f"Text lengths: {[len(t) for t in texts[:5]]}") # Show first 5
f"Text lengths: {[len(t) for t in texts[:5]]}"
) # Show first 5
# Process embeddings using unified computation # Process embeddings using unified computation
embeddings = compute_embeddings(texts, model_name, mode=embedding_mode) embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
@@ -199,9 +194,7 @@ def create_diskann_embedding_server(
else: else:
# For DiskANN C++ compatibility: return protobuf format # For DiskANN C++ compatibility: return protobuf format
resp_proto = embedding_pb2.NodeEmbeddingResponse() resp_proto = embedding_pb2.NodeEmbeddingResponse()
hidden_contiguous = np.ascontiguousarray( hidden_contiguous = np.ascontiguousarray(embeddings, dtype=np.float32)
embeddings, dtype=np.float32
)
# Serialize embeddings data # Serialize embeddings data
resp_proto.embeddings_data = hidden_contiguous.tobytes() resp_proto.embeddings_data = hidden_contiguous.tobytes()
@@ -271,6 +264,13 @@ if __name__ == "__main__":
choices=["sentence-transformers", "openai", "mlx"], choices=["sentence-transformers", "openai", "mlx"],
help="Embedding backend mode", help="Embedding backend mode",
) )
parser.add_argument(
"--distance-metric",
type=str,
default="l2",
choices=["l2", "mips", "cosine"],
help="Distance metric for similarity computation",
)
args = parser.parse_args() args = parser.parse_args()
@@ -280,4 +280,5 @@ if __name__ == "__main__":
zmq_port=args.zmq_port, zmq_port=args.zmq_port,
model_name=args.model_name, model_name=args.model_name,
embedding_mode=args.embedding_mode, embedding_mode=args.embedding_mode,
distance_metric=args.distance_metric,
) )

View File

@@ -1,27 +1,28 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT! # Generated by the protocol buffer compiler. DO NOT EDIT!
# source: embedding.proto # source: embedding.proto
# ruff: noqa
"""Generated protocol buffer code.""" """Generated protocol buffer code."""
from google.protobuf.internal import builder as _builder
from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import symbol_database as _symbol_database from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
# @@protoc_insertion_point(imports) # @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default() _sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
b'\n\x0f\x65mbedding.proto\x12\x0eprotoembedding"(\n\x14NodeEmbeddingRequest\x12\x10\n\x08node_ids\x18\x01 \x03(\r"Y\n\x15NodeEmbeddingResponse\x12\x17\n\x0f\x65mbeddings_data\x18\x01 \x01(\x0c\x12\x12\n\ndimensions\x18\x02 \x03(\x05\x12\x13\n\x0bmissing_ids\x18\x03 \x03(\rb\x06proto3'
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0f\x65mbedding.proto\x12\x0eprotoembedding\"(\n\x14NodeEmbeddingRequest\x12\x10\n\x08node_ids\x18\x01 \x03(\r\"Y\n\x15NodeEmbeddingResponse\x12\x17\n\x0f\x65mbeddings_data\x18\x01 \x01(\x0c\x12\x12\n\ndimensions\x18\x02 \x03(\x05\x12\x13\n\x0bmissing_ids\x18\x03 \x03(\rb\x06proto3') )
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'embedding_pb2', globals()) _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "embedding_pb2", globals())
if _descriptor._USE_C_DESCRIPTORS == False: if not _descriptor._USE_C_DESCRIPTORS:
DESCRIPTOR._options = None
DESCRIPTOR._options = None _NODEEMBEDDINGREQUEST._serialized_start = 35
_NODEEMBEDDINGREQUEST._serialized_start=35 _NODEEMBEDDINGREQUEST._serialized_end = 75
_NODEEMBEDDINGREQUEST._serialized_end=75 _NODEEMBEDDINGRESPONSE._serialized_start = 77
_NODEEMBEDDINGRESPONSE._serialized_start=77 _NODEEMBEDDINGRESPONSE._serialized_end = 166
_NODEEMBEDDINGRESPONSE._serialized_end=166
# @@protoc_insertion_point(module_scope) # @@protoc_insertion_point(module_scope)

View File

@@ -4,8 +4,8 @@ build-backend = "scikit_build_core.build"
[project] [project]
name = "leann-backend-diskann" name = "leann-backend-diskann"
version = "0.1.9" version = "0.1.16"
dependencies = ["leann-core==0.1.9", "numpy"] dependencies = ["leann-core==0.1.16", "numpy", "protobuf>=3.19.0"]
[tool.scikit-build] [tool.scikit-build]
# Key: simplified CMake path # Key: simplified CMake path

View File

@@ -10,6 +10,14 @@ if(APPLE)
set(OpenMP_C_LIB_NAMES "omp") set(OpenMP_C_LIB_NAMES "omp")
set(OpenMP_CXX_LIB_NAMES "omp") set(OpenMP_CXX_LIB_NAMES "omp")
set(OpenMP_omp_LIBRARY "/opt/homebrew/opt/libomp/lib/libomp.dylib") set(OpenMP_omp_LIBRARY "/opt/homebrew/opt/libomp/lib/libomp.dylib")
# Force use of system libc++ to avoid version mismatch
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -stdlib=libc++")
set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -stdlib=libc++")
set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} -stdlib=libc++")
# Set minimum macOS version for better compatibility
set(CMAKE_OSX_DEPLOYMENT_TARGET "11.0" CACHE STRING "Minimum macOS version")
endif() endif()
# Use system ZeroMQ instead of building from source # Use system ZeroMQ instead of building from source

View File

@@ -1 +1 @@
from . import hnsw_backend from . import hnsw_backend as hnsw_backend

View File

@@ -1,87 +1,115 @@
import argparse
import gc # Import garbage collector interface
import os
import struct import struct
import sys import sys
import numpy as np
import os
import argparse
import gc # Import garbage collector interface
import time import time
import numpy as np
# --- FourCCs (add more if needed) --- # --- FourCCs (add more if needed) ---
INDEX_HNSW_FLAT_FOURCC = int.from_bytes(b'IHNf', 'little') INDEX_HNSW_FLAT_FOURCC = int.from_bytes(b"IHNf", "little")
# Add other HNSW fourccs if you expect different storage types inside HNSW # Add other HNSW fourccs if you expect different storage types inside HNSW
# INDEX_HNSW_PQ_FOURCC = int.from_bytes(b'IHNp', 'little') # INDEX_HNSW_PQ_FOURCC = int.from_bytes(b'IHNp', 'little')
# INDEX_HNSW_SQ_FOURCC = int.from_bytes(b'IHNs', 'little') # INDEX_HNSW_SQ_FOURCC = int.from_bytes(b'IHNs', 'little')
# INDEX_HNSW_CAGRA_FOURCC = int.from_bytes(b'IHNc', 'little') # Example # INDEX_HNSW_CAGRA_FOURCC = int.from_bytes(b'IHNc', 'little') # Example
EXPECTED_HNSW_FOURCCS = {INDEX_HNSW_FLAT_FOURCC} # Modify if needed EXPECTED_HNSW_FOURCCS = {INDEX_HNSW_FLAT_FOURCC} # Modify if needed
NULL_INDEX_FOURCC = int.from_bytes(b'null', 'little') NULL_INDEX_FOURCC = int.from_bytes(b"null", "little")
# --- Helper functions for reading/writing binary data --- # --- Helper functions for reading/writing binary data ---
def read_struct(f, fmt): def read_struct(f, fmt):
"""Reads data according to the struct format.""" """Reads data according to the struct format."""
size = struct.calcsize(fmt) size = struct.calcsize(fmt)
data = f.read(size) data = f.read(size)
if len(data) != size: if len(data) != size:
raise EOFError(f"File ended unexpectedly reading struct fmt '{fmt}'. Expected {size} bytes, got {len(data)}.") raise EOFError(
f"File ended unexpectedly reading struct fmt '{fmt}'. Expected {size} bytes, got {len(data)}."
)
return struct.unpack(fmt, data)[0] return struct.unpack(fmt, data)[0]
def read_vector_raw(f, element_fmt_char): def read_vector_raw(f, element_fmt_char):
"""Reads a vector (size followed by data), returns count and raw bytes.""" """Reads a vector (size followed by data), returns count and raw bytes."""
count = -1 # Initialize count count = -1 # Initialize count
total_bytes = -1 # Initialize total_bytes total_bytes = -1 # Initialize total_bytes
try: try:
count = read_struct(f, '<Q') # size_t usually 64-bit unsigned count = read_struct(f, "<Q") # size_t usually 64-bit unsigned
element_size = struct.calcsize(element_fmt_char) element_size = struct.calcsize(element_fmt_char)
# --- FIX for MemoryError: Check for unreasonably large count --- # --- FIX for MemoryError: Check for unreasonably large count ---
max_reasonable_count = 10 * (10**9) # ~10 billion elements limit max_reasonable_count = 10 * (10**9) # ~10 billion elements limit
if count > max_reasonable_count or count < 0: if count > max_reasonable_count or count < 0:
raise MemoryError(f"Vector count {count} seems unreasonably large, possibly due to file corruption or incorrect format read.") raise MemoryError(
f"Vector count {count} seems unreasonably large, possibly due to file corruption or incorrect format read."
)
total_bytes = count * element_size total_bytes = count * element_size
# --- FIX for MemoryError: Check for huge byte size before allocation --- # --- FIX for MemoryError: Check for huge byte size before allocation ---
max_reasonable_bytes = 50 * (1024**3) # ~50 GB limit max_reasonable_bytes = 50 * (1024**3) # ~50 GB limit
if total_bytes > max_reasonable_bytes or total_bytes < 0: # Check for overflow if total_bytes > max_reasonable_bytes or total_bytes < 0: # Check for overflow
raise MemoryError(f"Attempting to read {total_bytes} bytes ({count} elements * {element_size} bytes/element), which exceeds the safety limit. File might be corrupted or format mismatch.") raise MemoryError(
f"Attempting to read {total_bytes} bytes ({count} elements * {element_size} bytes/element), which exceeds the safety limit. File might be corrupted or format mismatch."
)
data_bytes = f.read(total_bytes) data_bytes = f.read(total_bytes)
if len(data_bytes) != total_bytes: if len(data_bytes) != total_bytes:
raise EOFError(f"File ended unexpectedly reading vector data. Expected {total_bytes} bytes, got {len(data_bytes)}.") raise EOFError(
f"File ended unexpectedly reading vector data. Expected {total_bytes} bytes, got {len(data_bytes)}."
)
return count, data_bytes return count, data_bytes
except (MemoryError, OverflowError) as e: except (MemoryError, OverflowError) as e:
# Add context to the error message # Add context to the error message
print(f"\nError during raw vector read (element_fmt='{element_fmt_char}', count={count}, total_bytes={total_bytes}): {e}", file=sys.stderr) print(
raise e # Re-raise the original error type f"\nError during raw vector read (element_fmt='{element_fmt_char}', count={count}, total_bytes={total_bytes}): {e}",
file=sys.stderr,
)
raise e # Re-raise the original error type
def read_numpy_vector(f, np_dtype, struct_fmt_char): def read_numpy_vector(f, np_dtype, struct_fmt_char):
"""Reads a vector into a NumPy array.""" """Reads a vector into a NumPy array."""
count = -1 # Initialize count for robust error handling count = -1 # Initialize count for robust error handling
print(f" Reading vector (dtype={np_dtype}, fmt='{struct_fmt_char}')... ", end='', flush=True) print(
f" Reading vector (dtype={np_dtype}, fmt='{struct_fmt_char}')... ",
end="",
flush=True,
)
try: try:
count, data_bytes = read_vector_raw(f, struct_fmt_char) count, data_bytes = read_vector_raw(f, struct_fmt_char)
print(f"Count={count}, Bytes={len(data_bytes)}") print(f"Count={count}, Bytes={len(data_bytes)}")
if count > 0 and len(data_bytes) > 0: if count > 0 and len(data_bytes) > 0:
arr = np.frombuffer(data_bytes, dtype=np_dtype) arr = np.frombuffer(data_bytes, dtype=np_dtype)
if arr.size != count: if arr.size != count:
raise ValueError(f"Inconsistent array size after reading. Expected {count}, got {arr.size}") raise ValueError(
f"Inconsistent array size after reading. Expected {count}, got {arr.size}"
)
return arr return arr
elif count == 0: elif count == 0:
return np.array([], dtype=np_dtype) return np.array([], dtype=np_dtype)
else: else:
raise ValueError("Read zero bytes but count > 0.") raise ValueError("Read zero bytes but count > 0.")
except MemoryError as e: except MemoryError as e:
# Now count should be defined (or -1 if error was in read_struct) # Now count should be defined (or -1 if error was in read_struct)
print(f"\nMemoryError creating NumPy array (dtype={np_dtype}, count={count}). {e}", file=sys.stderr) print(
f"\nMemoryError creating NumPy array (dtype={np_dtype}, count={count}). {e}",
file=sys.stderr,
)
raise e raise e
except Exception as e: # Catch other potential errors like ValueError except Exception as e: # Catch other potential errors like ValueError
print(f"\nError reading numpy vector (dtype={np_dtype}, fmt='{struct_fmt_char}', count={count}): {e}", file=sys.stderr) print(
f"\nError reading numpy vector (dtype={np_dtype}, fmt='{struct_fmt_char}', count={count}): {e}",
file=sys.stderr,
)
raise e raise e
def write_numpy_vector(f, arr, struct_fmt_char): def write_numpy_vector(f, arr, struct_fmt_char):
"""Writes a NumPy array as a vector (size followed by data).""" """Writes a NumPy array as a vector (size followed by data)."""
count = arr.size count = arr.size
f.write(struct.pack('<Q', count)) f.write(struct.pack("<Q", count))
try: try:
expected_dtype = np.dtype(struct_fmt_char) expected_dtype = np.dtype(struct_fmt_char)
if arr.dtype != expected_dtype: if arr.dtype != expected_dtype:
@@ -89,23 +117,30 @@ def write_numpy_vector(f, arr, struct_fmt_char):
else: else:
data_to_write = arr.tobytes() data_to_write = arr.tobytes()
f.write(data_to_write) f.write(data_to_write)
del data_to_write # Hint GC del data_to_write # Hint GC
except MemoryError as e: except MemoryError as e:
print(f"\nMemoryError converting NumPy array to bytes for writing (size={count}, dtype={arr.dtype}). {e}", file=sys.stderr) print(
raise e f"\nMemoryError converting NumPy array to bytes for writing (size={count}, dtype={arr.dtype}). {e}",
file=sys.stderr,
)
raise e
def write_list_vector(f, lst, struct_fmt_char): def write_list_vector(f, lst, struct_fmt_char):
"""Writes a Python list as a vector iteratively.""" """Writes a Python list as a vector iteratively."""
count = len(lst) count = len(lst)
f.write(struct.pack('<Q', count)) f.write(struct.pack("<Q", count))
fmt = '<' + struct_fmt_char fmt = "<" + struct_fmt_char
chunk_size = 1024 * 1024 chunk_size = 1024 * 1024
element_size = struct.calcsize(fmt) element_size = struct.calcsize(fmt)
# Allocate buffer outside the loop if possible, or handle MemoryError during allocation # Allocate buffer outside the loop if possible, or handle MemoryError during allocation
try: try:
buffer = bytearray(chunk_size * element_size) buffer = bytearray(chunk_size * element_size)
except MemoryError: except MemoryError:
print(f"MemoryError: Cannot allocate buffer for writing list vector chunk (size {chunk_size * element_size} bytes).", file=sys.stderr) print(
f"MemoryError: Cannot allocate buffer for writing list vector chunk (size {chunk_size * element_size} bytes).",
file=sys.stderr,
)
raise raise
buffer_count = 0 buffer_count = 0
@@ -116,65 +151,79 @@ def write_list_vector(f, lst, struct_fmt_char):
buffer_count += 1 buffer_count += 1
if buffer_count == chunk_size or i == count - 1: if buffer_count == chunk_size or i == count - 1:
f.write(buffer[:buffer_count * element_size]) f.write(buffer[: buffer_count * element_size])
buffer_count = 0 buffer_count = 0
except struct.error as e: except struct.error as e:
print(f"\nStruct packing error for item {item} at index {i} with format '{fmt}'. {e}", file=sys.stderr) print(
f"\nStruct packing error for item {item} at index {i} with format '{fmt}'. {e}",
file=sys.stderr,
)
raise e raise e
def get_cum_neighbors(cum_nneighbor_per_level_np, level): def get_cum_neighbors(cum_nneighbor_per_level_np, level):
"""Helper to get cumulative neighbors count, matching C++ logic.""" """Helper to get cumulative neighbors count, matching C++ logic."""
if level < 0: return 0 if level < 0:
return 0
if level < len(cum_nneighbor_per_level_np): if level < len(cum_nneighbor_per_level_np):
return cum_nneighbor_per_level_np[level] return cum_nneighbor_per_level_np[level]
else: else:
return cum_nneighbor_per_level_np[-1] if len(cum_nneighbor_per_level_np) > 0 else 0 return cum_nneighbor_per_level_np[-1] if len(cum_nneighbor_per_level_np) > 0 else 0
def write_compact_format(f_out, original_hnsw_data, assign_probas_np, cum_nneighbor_per_level_np,
levels_np, compact_level_ptr, compact_node_offsets_np, def write_compact_format(
compact_neighbors_data, storage_fourcc, storage_data): f_out,
original_hnsw_data,
assign_probas_np,
cum_nneighbor_per_level_np,
levels_np,
compact_level_ptr,
compact_node_offsets_np,
compact_neighbors_data,
storage_fourcc,
storage_data,
):
"""Write HNSW data in compact format following C++ read order exactly.""" """Write HNSW data in compact format following C++ read order exactly."""
# Write IndexHNSW Header # Write IndexHNSW Header
f_out.write(struct.pack('<I', original_hnsw_data['index_fourcc'])) f_out.write(struct.pack("<I", original_hnsw_data["index_fourcc"]))
f_out.write(struct.pack('<i', original_hnsw_data['d'])) f_out.write(struct.pack("<i", original_hnsw_data["d"]))
f_out.write(struct.pack('<q', original_hnsw_data['ntotal'])) f_out.write(struct.pack("<q", original_hnsw_data["ntotal"]))
f_out.write(struct.pack('<q', original_hnsw_data['dummy1'])) f_out.write(struct.pack("<q", original_hnsw_data["dummy1"]))
f_out.write(struct.pack('<q', original_hnsw_data['dummy2'])) f_out.write(struct.pack("<q", original_hnsw_data["dummy2"]))
f_out.write(struct.pack('<?', original_hnsw_data['is_trained'])) f_out.write(struct.pack("<?", original_hnsw_data["is_trained"]))
f_out.write(struct.pack('<i', original_hnsw_data['metric_type'])) f_out.write(struct.pack("<i", original_hnsw_data["metric_type"]))
if original_hnsw_data['metric_type'] > 1: if original_hnsw_data["metric_type"] > 1:
f_out.write(struct.pack('<f', original_hnsw_data['metric_arg'])) f_out.write(struct.pack("<f", original_hnsw_data["metric_arg"]))
# Write HNSW struct parts (standard order) # Write HNSW struct parts (standard order)
write_numpy_vector(f_out, assign_probas_np, 'd') write_numpy_vector(f_out, assign_probas_np, "d")
write_numpy_vector(f_out, cum_nneighbor_per_level_np, 'i') write_numpy_vector(f_out, cum_nneighbor_per_level_np, "i")
write_numpy_vector(f_out, levels_np, 'i') write_numpy_vector(f_out, levels_np, "i")
# Write compact format flag # Write compact format flag
f_out.write(struct.pack('<?', True)) # storage_is_compact = True f_out.write(struct.pack("<?", True)) # storage_is_compact = True
# Write compact data in CORRECT C++ read order: level_ptr, node_offsets FIRST # Write compact data in CORRECT C++ read order: level_ptr, node_offsets FIRST
if isinstance(compact_level_ptr, np.ndarray): if isinstance(compact_level_ptr, np.ndarray):
write_numpy_vector(f_out, compact_level_ptr, 'Q') write_numpy_vector(f_out, compact_level_ptr, "Q")
else: else:
write_list_vector(f_out, compact_level_ptr, 'Q') write_list_vector(f_out, compact_level_ptr, "Q")
write_numpy_vector(f_out, compact_node_offsets_np, 'Q') write_numpy_vector(f_out, compact_node_offsets_np, "Q")
# Write HNSW scalar parameters # Write HNSW scalar parameters
f_out.write(struct.pack('<i', original_hnsw_data['entry_point'])) f_out.write(struct.pack("<i", original_hnsw_data["entry_point"]))
f_out.write(struct.pack('<i', original_hnsw_data['max_level'])) f_out.write(struct.pack("<i", original_hnsw_data["max_level"]))
f_out.write(struct.pack('<i', original_hnsw_data['efConstruction'])) f_out.write(struct.pack("<i", original_hnsw_data["efConstruction"]))
f_out.write(struct.pack('<i', original_hnsw_data['efSearch'])) f_out.write(struct.pack("<i", original_hnsw_data["efSearch"]))
f_out.write(struct.pack('<i', original_hnsw_data['dummy_upper_beam'])) f_out.write(struct.pack("<i", original_hnsw_data["dummy_upper_beam"]))
# Write storage fourcc (this determines how to read what follows) # Write storage fourcc (this determines how to read what follows)
f_out.write(struct.pack('<I', storage_fourcc)) f_out.write(struct.pack("<I", storage_fourcc))
# Write compact neighbors data AFTER storage fourcc # Write compact neighbors data AFTER storage fourcc
write_list_vector(f_out, compact_neighbors_data, 'i') write_list_vector(f_out, compact_neighbors_data, "i")
# Write storage data if not NULL (only after neighbors) # Write storage data if not NULL (only after neighbors)
if storage_fourcc != NULL_INDEX_FOURCC and storage_data: if storage_fourcc != NULL_INDEX_FOURCC and storage_data:
@@ -183,6 +232,7 @@ def write_compact_format(f_out, original_hnsw_data, assign_probas_np, cum_nneigh
# --- Main Conversion Logic --- # --- Main Conversion Logic ---
def convert_hnsw_graph_to_csr(input_filename, output_filename, prune_embeddings=True): def convert_hnsw_graph_to_csr(input_filename, output_filename, prune_embeddings=True):
""" """
Converts an HNSW graph file to the CSR format. Converts an HNSW graph file to the CSR format.
@@ -196,91 +246,115 @@ def convert_hnsw_graph_to_csr(input_filename, output_filename, prune_embeddings=
print(f"Starting conversion: {input_filename} -> {output_filename}") print(f"Starting conversion: {input_filename} -> {output_filename}")
start_time = time.time() start_time = time.time()
original_hnsw_data = {} original_hnsw_data = {}
neighbors_np = None # Initialize to allow check in finally block neighbors_np = None # Initialize to allow check in finally block
try: try:
with open(input_filename, 'rb') as f_in, open(output_filename, 'wb') as f_out: with open(input_filename, "rb") as f_in, open(output_filename, "wb") as f_out:
# --- Read IndexHNSW FourCC and Header --- # --- Read IndexHNSW FourCC and Header ---
print(f"[{time.time() - start_time:.2f}s] Reading Index HNSW header...") print(f"[{time.time() - start_time:.2f}s] Reading Index HNSW header...")
# ... (Keep the header reading logic as before) ... # ... (Keep the header reading logic as before) ...
hnsw_index_fourcc = read_struct(f_in, '<I') hnsw_index_fourcc = read_struct(f_in, "<I")
if hnsw_index_fourcc not in EXPECTED_HNSW_FOURCCS: if hnsw_index_fourcc not in EXPECTED_HNSW_FOURCCS:
print(f"Error: Expected HNSW Index FourCC ({list(EXPECTED_HNSW_FOURCCS)}), got {hnsw_index_fourcc:08x}.", file=sys.stderr) print(
return False f"Error: Expected HNSW Index FourCC ({list(EXPECTED_HNSW_FOURCCS)}), got {hnsw_index_fourcc:08x}.",
original_hnsw_data['index_fourcc'] = hnsw_index_fourcc file=sys.stderr,
original_hnsw_data['d'] = read_struct(f_in, '<i') )
original_hnsw_data['ntotal'] = read_struct(f_in, '<q') return False
original_hnsw_data['dummy1'] = read_struct(f_in, '<q') original_hnsw_data["index_fourcc"] = hnsw_index_fourcc
original_hnsw_data['dummy2'] = read_struct(f_in, '<q') original_hnsw_data["d"] = read_struct(f_in, "<i")
original_hnsw_data['is_trained'] = read_struct(f_in, '?') original_hnsw_data["ntotal"] = read_struct(f_in, "<q")
original_hnsw_data['metric_type'] = read_struct(f_in, '<i') original_hnsw_data["dummy1"] = read_struct(f_in, "<q")
original_hnsw_data['metric_arg'] = 0.0 original_hnsw_data["dummy2"] = read_struct(f_in, "<q")
if original_hnsw_data['metric_type'] > 1: original_hnsw_data["is_trained"] = read_struct(f_in, "?")
original_hnsw_data['metric_arg'] = read_struct(f_in, '<f') original_hnsw_data["metric_type"] = read_struct(f_in, "<i")
print(f"[{time.time() - start_time:.2f}s] Header read: d={original_hnsw_data['d']}, ntotal={original_hnsw_data['ntotal']}") original_hnsw_data["metric_arg"] = 0.0
if original_hnsw_data["metric_type"] > 1:
original_hnsw_data["metric_arg"] = read_struct(f_in, "<f")
print(
f"[{time.time() - start_time:.2f}s] Header read: d={original_hnsw_data['d']}, ntotal={original_hnsw_data['ntotal']}"
)
# --- Read original HNSW struct data --- # --- Read original HNSW struct data ---
print(f"[{time.time() - start_time:.2f}s] Reading HNSW struct vectors...") print(f"[{time.time() - start_time:.2f}s] Reading HNSW struct vectors...")
assign_probas_np = read_numpy_vector(f_in, np.float64, 'd') assign_probas_np = read_numpy_vector(f_in, np.float64, "d")
print(f"[{time.time() - start_time:.2f}s] Read assign_probas ({assign_probas_np.size})") print(
f"[{time.time() - start_time:.2f}s] Read assign_probas ({assign_probas_np.size})"
)
gc.collect() gc.collect()
cum_nneighbor_per_level_np = read_numpy_vector(f_in, np.int32, 'i') cum_nneighbor_per_level_np = read_numpy_vector(f_in, np.int32, "i")
print(f"[{time.time() - start_time:.2f}s] Read cum_nneighbor_per_level ({cum_nneighbor_per_level_np.size})") print(
f"[{time.time() - start_time:.2f}s] Read cum_nneighbor_per_level ({cum_nneighbor_per_level_np.size})"
)
gc.collect() gc.collect()
levels_np = read_numpy_vector(f_in, np.int32, 'i') levels_np = read_numpy_vector(f_in, np.int32, "i")
print(f"[{time.time() - start_time:.2f}s] Read levels ({levels_np.size})") print(f"[{time.time() - start_time:.2f}s] Read levels ({levels_np.size})")
gc.collect() gc.collect()
ntotal = len(levels_np) ntotal = len(levels_np)
if ntotal != original_hnsw_data['ntotal']: if ntotal != original_hnsw_data["ntotal"]:
print(f"Warning: ntotal mismatch! Header says {original_hnsw_data['ntotal']}, levels vector size is {ntotal}. Using levels vector size.", file=sys.stderr) print(
original_hnsw_data['ntotal'] = ntotal f"Warning: ntotal mismatch! Header says {original_hnsw_data['ntotal']}, levels vector size is {ntotal}. Using levels vector size.",
file=sys.stderr,
)
original_hnsw_data["ntotal"] = ntotal
# --- Check for compact format flag --- # --- Check for compact format flag ---
print(f"[{time.time() - start_time:.2f}s] Probing for compact storage flag...") print(f"[{time.time() - start_time:.2f}s] Probing for compact storage flag...")
pos_before_compact = f_in.tell() pos_before_compact = f_in.tell()
try: try:
is_compact_flag = read_struct(f_in, '<?') is_compact_flag = read_struct(f_in, "<?")
print(f"[{time.time() - start_time:.2f}s] Found compact flag: {is_compact_flag}") print(f"[{time.time() - start_time:.2f}s] Found compact flag: {is_compact_flag}")
if is_compact_flag: if is_compact_flag:
# Input is already in compact format - read compact data # Input is already in compact format - read compact data
print(f"[{time.time() - start_time:.2f}s] Input is already in compact format, reading compact data...") print(
f"[{time.time() - start_time:.2f}s] Input is already in compact format, reading compact data..."
)
compact_level_ptr = read_numpy_vector(f_in, np.uint64, 'Q') compact_level_ptr = read_numpy_vector(f_in, np.uint64, "Q")
print(f"[{time.time() - start_time:.2f}s] Read compact_level_ptr ({compact_level_ptr.size})") print(
f"[{time.time() - start_time:.2f}s] Read compact_level_ptr ({compact_level_ptr.size})"
)
compact_node_offsets_np = read_numpy_vector(f_in, np.uint64, 'Q') compact_node_offsets_np = read_numpy_vector(f_in, np.uint64, "Q")
print(f"[{time.time() - start_time:.2f}s] Read compact_node_offsets ({compact_node_offsets_np.size})") print(
f"[{time.time() - start_time:.2f}s] Read compact_node_offsets ({compact_node_offsets_np.size})"
)
# Read scalar parameters # Read scalar parameters
original_hnsw_data['entry_point'] = read_struct(f_in, '<i') original_hnsw_data["entry_point"] = read_struct(f_in, "<i")
original_hnsw_data['max_level'] = read_struct(f_in, '<i') original_hnsw_data["max_level"] = read_struct(f_in, "<i")
original_hnsw_data['efConstruction'] = read_struct(f_in, '<i') original_hnsw_data["efConstruction"] = read_struct(f_in, "<i")
original_hnsw_data['efSearch'] = read_struct(f_in, '<i') original_hnsw_data["efSearch"] = read_struct(f_in, "<i")
original_hnsw_data['dummy_upper_beam'] = read_struct(f_in, '<i') original_hnsw_data["dummy_upper_beam"] = read_struct(f_in, "<i")
print(f"[{time.time() - start_time:.2f}s] Read scalar params (ep={original_hnsw_data['entry_point']}, max_lvl={original_hnsw_data['max_level']})") print(
f"[{time.time() - start_time:.2f}s] Read scalar params (ep={original_hnsw_data['entry_point']}, max_lvl={original_hnsw_data['max_level']})"
)
# Read storage fourcc # Read storage fourcc
storage_fourcc = read_struct(f_in, '<I') storage_fourcc = read_struct(f_in, "<I")
print(f"[{time.time() - start_time:.2f}s] Found storage fourcc: {storage_fourcc:08x}") print(
f"[{time.time() - start_time:.2f}s] Found storage fourcc: {storage_fourcc:08x}"
)
if prune_embeddings and storage_fourcc != NULL_INDEX_FOURCC: if prune_embeddings and storage_fourcc != NULL_INDEX_FOURCC:
# Read compact neighbors data # Read compact neighbors data
compact_neighbors_data_np = read_numpy_vector(f_in, np.int32, 'i') compact_neighbors_data_np = read_numpy_vector(f_in, np.int32, "i")
print(f"[{time.time() - start_time:.2f}s] Read compact neighbors data ({compact_neighbors_data_np.size})") print(
f"[{time.time() - start_time:.2f}s] Read compact neighbors data ({compact_neighbors_data_np.size})"
)
compact_neighbors_data = compact_neighbors_data_np.tolist() compact_neighbors_data = compact_neighbors_data_np.tolist()
del compact_neighbors_data_np del compact_neighbors_data_np
# Skip storage data and write with NULL marker # Skip storage data and write with NULL marker
print(f"[{time.time() - start_time:.2f}s] Pruning embeddings: Writing NULL storage marker.") print(
f"[{time.time() - start_time:.2f}s] Pruning embeddings: Writing NULL storage marker."
)
storage_fourcc = NULL_INDEX_FOURCC storage_fourcc = NULL_INDEX_FOURCC
elif not prune_embeddings: elif not prune_embeddings:
# Read and preserve compact neighbors and storage # Read and preserve compact neighbors and storage
compact_neighbors_data_np = read_numpy_vector(f_in, np.int32, 'i') compact_neighbors_data_np = read_numpy_vector(f_in, np.int32, "i")
compact_neighbors_data = compact_neighbors_data_np.tolist() compact_neighbors_data = compact_neighbors_data_np.tolist()
del compact_neighbors_data_np del compact_neighbors_data_np
@@ -288,16 +362,25 @@ def convert_hnsw_graph_to_csr(input_filename, output_filename, prune_embeddings=
storage_data = f_in.read() storage_data = f_in.read()
else: else:
# Already pruned (NULL storage) # Already pruned (NULL storage)
compact_neighbors_data_np = read_numpy_vector(f_in, np.int32, 'i') compact_neighbors_data_np = read_numpy_vector(f_in, np.int32, "i")
compact_neighbors_data = compact_neighbors_data_np.tolist() compact_neighbors_data = compact_neighbors_data_np.tolist()
del compact_neighbors_data_np del compact_neighbors_data_np
storage_data = b'' storage_data = b""
# Write the updated compact format # Write the updated compact format
print(f"[{time.time() - start_time:.2f}s] Writing updated compact format...") print(f"[{time.time() - start_time:.2f}s] Writing updated compact format...")
write_compact_format(f_out, original_hnsw_data, assign_probas_np, cum_nneighbor_per_level_np, write_compact_format(
levels_np, compact_level_ptr, compact_node_offsets_np, f_out,
compact_neighbors_data, storage_fourcc, storage_data if not prune_embeddings else b'') original_hnsw_data,
assign_probas_np,
cum_nneighbor_per_level_np,
levels_np,
compact_level_ptr,
compact_node_offsets_np,
compact_neighbors_data,
storage_fourcc,
storage_data if not prune_embeddings else b"",
)
print(f"[{time.time() - start_time:.2f}s] Conversion complete.") print(f"[{time.time() - start_time:.2f}s] Conversion complete.")
return True return True
@@ -305,63 +388,86 @@ def convert_hnsw_graph_to_csr(input_filename, output_filename, prune_embeddings=
else: else:
# is_compact=False, rewind and read original format # is_compact=False, rewind and read original format
f_in.seek(pos_before_compact) f_in.seek(pos_before_compact)
print(f"[{time.time() - start_time:.2f}s] Compact flag is False, reading original format...") print(
f"[{time.time() - start_time:.2f}s] Compact flag is False, reading original format..."
)
except EOFError: except EOFError:
# No compact flag found, assume original format # No compact flag found, assume original format
f_in.seek(pos_before_compact) f_in.seek(pos_before_compact)
print(f"[{time.time() - start_time:.2f}s] No compact flag found, assuming original format...") print(
f"[{time.time() - start_time:.2f}s] No compact flag found, assuming original format..."
)
# --- Handle potential extra byte in original format (like C++ code) --- # --- Handle potential extra byte in original format (like C++ code) ---
print(f"[{time.time() - start_time:.2f}s] Probing for potential extra byte before non-compact offsets...") print(
f"[{time.time() - start_time:.2f}s] Probing for potential extra byte before non-compact offsets..."
)
pos_before_probe = f_in.tell() pos_before_probe = f_in.tell()
try: try:
suspected_flag = read_struct(f_in, '<B') # Read 1 byte suspected_flag = read_struct(f_in, "<B") # Read 1 byte
if suspected_flag == 0x00: if suspected_flag == 0x00:
print(f"[{time.time() - start_time:.2f}s] Found and consumed an unexpected 0x00 byte.") print(
f"[{time.time() - start_time:.2f}s] Found and consumed an unexpected 0x00 byte."
)
elif suspected_flag == 0x01: elif suspected_flag == 0x01:
print(f"[{time.time() - start_time:.2f}s] ERROR: Found 0x01 but is_compact should be False") print(
f"[{time.time() - start_time:.2f}s] ERROR: Found 0x01 but is_compact should be False"
)
raise ValueError("Inconsistent compact flag state") raise ValueError("Inconsistent compact flag state")
else: else:
# Rewind - this byte is part of offsets data # Rewind - this byte is part of offsets data
f_in.seek(pos_before_probe) f_in.seek(pos_before_probe)
print(f"[{time.time() - start_time:.2f}s] Rewound to original position (byte was 0x{suspected_flag:02x})") print(
f"[{time.time() - start_time:.2f}s] Rewound to original position (byte was 0x{suspected_flag:02x})"
)
except EOFError: except EOFError:
f_in.seek(pos_before_probe) f_in.seek(pos_before_probe)
print(f"[{time.time() - start_time:.2f}s] No extra byte found (EOF), proceeding with offsets read") print(
f"[{time.time() - start_time:.2f}s] No extra byte found (EOF), proceeding with offsets read"
)
# --- Read original format data --- # --- Read original format data ---
offsets_np = read_numpy_vector(f_in, np.uint64, 'Q') offsets_np = read_numpy_vector(f_in, np.uint64, "Q")
print(f"[{time.time() - start_time:.2f}s] Read offsets ({offsets_np.size})") print(f"[{time.time() - start_time:.2f}s] Read offsets ({offsets_np.size})")
if len(offsets_np) != ntotal + 1: if len(offsets_np) != ntotal + 1:
raise ValueError(f"Inconsistent offsets size: len(levels)={ntotal} but len(offsets)={len(offsets_np)}") raise ValueError(
f"Inconsistent offsets size: len(levels)={ntotal} but len(offsets)={len(offsets_np)}"
)
gc.collect() gc.collect()
print(f"[{time.time() - start_time:.2f}s] Attempting to read neighbors vector...") print(f"[{time.time() - start_time:.2f}s] Attempting to read neighbors vector...")
neighbors_np = read_numpy_vector(f_in, np.int32, 'i') neighbors_np = read_numpy_vector(f_in, np.int32, "i")
print(f"[{time.time() - start_time:.2f}s] Read neighbors ({neighbors_np.size})") print(f"[{time.time() - start_time:.2f}s] Read neighbors ({neighbors_np.size})")
expected_neighbors_size = offsets_np[-1] if ntotal > 0 else 0 expected_neighbors_size = offsets_np[-1] if ntotal > 0 else 0
if neighbors_np.size != expected_neighbors_size: if neighbors_np.size != expected_neighbors_size:
print(f"Warning: neighbors vector size mismatch. Expected {expected_neighbors_size} based on offsets, got {neighbors_np.size}.") print(
f"Warning: neighbors vector size mismatch. Expected {expected_neighbors_size} based on offsets, got {neighbors_np.size}."
)
gc.collect() gc.collect()
original_hnsw_data['entry_point'] = read_struct(f_in, '<i') original_hnsw_data["entry_point"] = read_struct(f_in, "<i")
original_hnsw_data['max_level'] = read_struct(f_in, '<i') original_hnsw_data["max_level"] = read_struct(f_in, "<i")
original_hnsw_data['efConstruction'] = read_struct(f_in, '<i') original_hnsw_data["efConstruction"] = read_struct(f_in, "<i")
original_hnsw_data['efSearch'] = read_struct(f_in, '<i') original_hnsw_data["efSearch"] = read_struct(f_in, "<i")
original_hnsw_data['dummy_upper_beam'] = read_struct(f_in, '<i') original_hnsw_data["dummy_upper_beam"] = read_struct(f_in, "<i")
print(f"[{time.time() - start_time:.2f}s] Read scalar params (ep={original_hnsw_data['entry_point']}, max_lvl={original_hnsw_data['max_level']})") print(
f"[{time.time() - start_time:.2f}s] Read scalar params (ep={original_hnsw_data['entry_point']}, max_lvl={original_hnsw_data['max_level']})"
)
print(f"[{time.time() - start_time:.2f}s] Checking for storage data...") print(f"[{time.time() - start_time:.2f}s] Checking for storage data...")
storage_fourcc = None storage_fourcc = None
try: try:
storage_fourcc = read_struct(f_in, '<I') storage_fourcc = read_struct(f_in, "<I")
print(f"[{time.time() - start_time:.2f}s] Found storage fourcc: {storage_fourcc:08x}.") print(
f"[{time.time() - start_time:.2f}s] Found storage fourcc: {storage_fourcc:08x}."
)
except EOFError: except EOFError:
print(f"[{time.time() - start_time:.2f}s] No storage data found (EOF).") print(f"[{time.time() - start_time:.2f}s] No storage data found (EOF).")
except Exception as e: except Exception as e:
print(f"[{time.time() - start_time:.2f}s] Error reading potential storage data: {e}") print(
f"[{time.time() - start_time:.2f}s] Error reading potential storage data: {e}"
)
# --- Perform Conversion --- # --- Perform Conversion ---
print(f"[{time.time() - start_time:.2f}s] Converting to CSR format...") print(f"[{time.time() - start_time:.2f}s] Converting to CSR format...")
@@ -373,17 +479,21 @@ def convert_hnsw_graph_to_csr(input_filename, output_filename, prune_embeddings=
current_level_ptr_idx = 0 current_level_ptr_idx = 0
current_data_idx = 0 current_data_idx = 0
total_valid_neighbors_counted = 0 # For validation total_valid_neighbors_counted = 0 # For validation
# Optimize calculation by getting slices once per node if possible # Optimize calculation by getting slices once per node if possible
for i in range(ntotal): for i in range(ntotal):
if i > 0 and i % (ntotal // 100 or 1) == 0: # Log progress roughly every 1% if i > 0 and i % (ntotal // 100 or 1) == 0: # Log progress roughly every 1%
progress = (i / ntotal) * 100 progress = (i / ntotal) * 100
elapsed = time.time() - start_time elapsed = time.time() - start_time
print(f"\r[{elapsed:.2f}s] Converting node {i}/{ntotal} ({progress:.1f}%)...", end="") print(
f"\r[{elapsed:.2f}s] Converting node {i}/{ntotal} ({progress:.1f}%)...",
end="",
)
node_max_level = levels_np[i] - 1 node_max_level = levels_np[i] - 1
if node_max_level < -1: node_max_level = -1 if node_max_level < -1:
node_max_level = -1
node_ptr_start_index = current_level_ptr_idx node_ptr_start_index = current_level_ptr_idx
compact_node_offsets_np[i] = node_ptr_start_index compact_node_offsets_np[i] = node_ptr_start_index
@@ -394,13 +504,17 @@ def convert_hnsw_graph_to_csr(input_filename, output_filename, prune_embeddings=
for level in range(node_max_level + 1): for level in range(node_max_level + 1):
compact_level_ptr.append(current_data_idx) compact_level_ptr.append(current_data_idx)
begin_orig_np = original_offset_start + get_cum_neighbors(cum_nneighbor_per_level_np, level) begin_orig_np = original_offset_start + get_cum_neighbors(
end_orig_np = original_offset_start + get_cum_neighbors(cum_nneighbor_per_level_np, level + 1) cum_nneighbor_per_level_np, level
)
end_orig_np = original_offset_start + get_cum_neighbors(
cum_nneighbor_per_level_np, level + 1
)
begin_orig = int(begin_orig_np) begin_orig = int(begin_orig_np)
end_orig = int(end_orig_np) end_orig = int(end_orig_np)
neighbors_len = len(neighbors_np) # Cache length neighbors_len = len(neighbors_np) # Cache length
begin_orig = min(max(0, begin_orig), neighbors_len) begin_orig = min(max(0, begin_orig), neighbors_len)
end_orig = min(max(begin_orig, end_orig), neighbors_len) end_orig = min(max(begin_orig, end_orig), neighbors_len)
@@ -413,82 +527,116 @@ def convert_hnsw_graph_to_csr(input_filename, output_filename, prune_embeddings=
if num_valid > 0: if num_valid > 0:
# Append valid neighbors # Append valid neighbors
compact_neighbors_data.extend(level_neighbors_slice[valid_neighbors_mask]) compact_neighbors_data.extend(
level_neighbors_slice[valid_neighbors_mask]
)
current_data_idx += num_valid current_data_idx += num_valid
total_valid_neighbors_counted += num_valid total_valid_neighbors_counted += num_valid
compact_level_ptr.append(current_data_idx) compact_level_ptr.append(current_data_idx)
current_level_ptr_idx += num_pointers_expected current_level_ptr_idx += num_pointers_expected
compact_node_offsets_np[ntotal] = current_level_ptr_idx compact_node_offsets_np[ntotal] = current_level_ptr_idx
print(f"\r[{time.time() - start_time:.2f}s] Conversion loop finished. ") # Clear progress line print(
f"\r[{time.time() - start_time:.2f}s] Conversion loop finished. "
) # Clear progress line
# --- Validation Checks --- # --- Validation Checks ---
print(f"[{time.time() - start_time:.2f}s] Running validation checks...") print(f"[{time.time() - start_time:.2f}s] Running validation checks...")
valid_check_passed = True valid_check_passed = True
# Check 1: Total valid neighbors count # Check 1: Total valid neighbors count
print(f" Checking total valid neighbor count...") print(" Checking total valid neighbor count...")
expected_valid_count = np.sum(neighbors_np >= 0) expected_valid_count = np.sum(neighbors_np >= 0)
if total_valid_neighbors_counted != len(compact_neighbors_data): if total_valid_neighbors_counted != len(compact_neighbors_data):
print(f"Error: Mismatch between counted valid neighbors ({total_valid_neighbors_counted}) and final compact_data size ({len(compact_neighbors_data)})!", file=sys.stderr) print(
valid_check_passed = False f"Error: Mismatch between counted valid neighbors ({total_valid_neighbors_counted}) and final compact_data size ({len(compact_neighbors_data)})!",
file=sys.stderr,
)
valid_check_passed = False
if expected_valid_count != len(compact_neighbors_data): if expected_valid_count != len(compact_neighbors_data):
print(f"Error: Mismatch between NumPy count of valid neighbors ({expected_valid_count}) and final compact_data size ({len(compact_neighbors_data)})!", file=sys.stderr) print(
valid_check_passed = False f"Error: Mismatch between NumPy count of valid neighbors ({expected_valid_count}) and final compact_data size ({len(compact_neighbors_data)})!",
file=sys.stderr,
)
valid_check_passed = False
else: else:
print(f" OK: Total valid neighbors = {len(compact_neighbors_data)}") print(f" OK: Total valid neighbors = {len(compact_neighbors_data)}")
# Check 2: Final pointer indices consistency # Check 2: Final pointer indices consistency
print(f" Checking final pointer indices...") print(" Checking final pointer indices...")
if compact_node_offsets_np[ntotal] != len(compact_level_ptr): if compact_node_offsets_np[ntotal] != len(compact_level_ptr):
print(f"Error: Final node offset ({compact_node_offsets_np[ntotal]}) doesn't match level_ptr size ({len(compact_level_ptr)})!", file=sys.stderr) print(
valid_check_passed = False f"Error: Final node offset ({compact_node_offsets_np[ntotal]}) doesn't match level_ptr size ({len(compact_level_ptr)})!",
if (len(compact_level_ptr) > 0 and compact_level_ptr[-1] != len(compact_neighbors_data)) or \ file=sys.stderr,
(len(compact_level_ptr) == 0 and len(compact_neighbors_data) != 0): )
last_ptr = compact_level_ptr[-1] if len(compact_level_ptr) > 0 else -1 valid_check_passed = False
print(f"Error: Last level pointer ({last_ptr}) doesn't match compact_data size ({len(compact_neighbors_data)})!", file=sys.stderr) if (
valid_check_passed = False len(compact_level_ptr) > 0 and compact_level_ptr[-1] != len(compact_neighbors_data)
) or (len(compact_level_ptr) == 0 and len(compact_neighbors_data) != 0):
last_ptr = compact_level_ptr[-1] if len(compact_level_ptr) > 0 else -1
print(
f"Error: Last level pointer ({last_ptr}) doesn't match compact_data size ({len(compact_neighbors_data)})!",
file=sys.stderr,
)
valid_check_passed = False
else: else:
print(f" OK: Final pointers match data size.") print(" OK: Final pointers match data size.")
if not valid_check_passed: if not valid_check_passed:
print("Error: Validation checks failed. Output file might be incorrect.", file=sys.stderr) print(
"Error: Validation checks failed. Output file might be incorrect.",
file=sys.stderr,
)
# Optional: Exit here if validation fails # Optional: Exit here if validation fails
# return False # return False
# --- Explicitly delete large intermediate arrays --- # --- Explicitly delete large intermediate arrays ---
print(f"[{time.time() - start_time:.2f}s] Deleting original neighbors and offsets arrays...") print(
f"[{time.time() - start_time:.2f}s] Deleting original neighbors and offsets arrays..."
)
del neighbors_np del neighbors_np
del offsets_np del offsets_np
gc.collect() gc.collect()
print(f" CSR Stats: |data|={len(compact_neighbors_data)}, |level_ptr|={len(compact_level_ptr)}") print(
f" CSR Stats: |data|={len(compact_neighbors_data)}, |level_ptr|={len(compact_level_ptr)}"
)
# --- Write CSR HNSW graph data using unified function --- # --- Write CSR HNSW graph data using unified function ---
print(f"[{time.time() - start_time:.2f}s] Writing CSR HNSW graph data in FAISS-compatible order...") print(
f"[{time.time() - start_time:.2f}s] Writing CSR HNSW graph data in FAISS-compatible order..."
)
# Determine storage fourcc and data based on prune_embeddings # Determine storage fourcc and data based on prune_embeddings
if prune_embeddings: if prune_embeddings:
print(f" Pruning embeddings: Writing NULL storage marker.") print(" Pruning embeddings: Writing NULL storage marker.")
output_storage_fourcc = NULL_INDEX_FOURCC output_storage_fourcc = NULL_INDEX_FOURCC
storage_data = b'' storage_data = b""
else: else:
# Keep embeddings - read and preserve original storage data # Keep embeddings - read and preserve original storage data
if storage_fourcc and storage_fourcc != NULL_INDEX_FOURCC: if storage_fourcc and storage_fourcc != NULL_INDEX_FOURCC:
print(f" Preserving embeddings: Reading original storage data...") print(" Preserving embeddings: Reading original storage data...")
storage_data = f_in.read() # Read remaining storage data storage_data = f_in.read() # Read remaining storage data
output_storage_fourcc = storage_fourcc output_storage_fourcc = storage_fourcc
print(f" Read {len(storage_data)} bytes of storage data") print(f" Read {len(storage_data)} bytes of storage data")
else: else:
print(f" No embeddings found in original file (NULL storage)") print(" No embeddings found in original file (NULL storage)")
output_storage_fourcc = NULL_INDEX_FOURCC output_storage_fourcc = NULL_INDEX_FOURCC
storage_data = b'' storage_data = b""
# Use the unified write function # Use the unified write function
write_compact_format(f_out, original_hnsw_data, assign_probas_np, cum_nneighbor_per_level_np, write_compact_format(
levels_np, compact_level_ptr, compact_node_offsets_np, f_out,
compact_neighbors_data, output_storage_fourcc, storage_data) original_hnsw_data,
assign_probas_np,
cum_nneighbor_per_level_np,
levels_np,
compact_level_ptr,
compact_node_offsets_np,
compact_neighbors_data,
output_storage_fourcc,
storage_data,
)
# Clean up memory # Clean up memory
del assign_probas_np, cum_nneighbor_per_level_np, levels_np del assign_probas_np, cum_nneighbor_per_level_np, levels_np
@@ -503,40 +651,66 @@ def convert_hnsw_graph_to_csr(input_filename, output_filename, prune_embeddings=
print(f"Error: Input file not found: {input_filename}", file=sys.stderr) print(f"Error: Input file not found: {input_filename}", file=sys.stderr)
return False return False
except MemoryError as e: except MemoryError as e:
print(f"\nFatal MemoryError during conversion: {e}. Insufficient RAM.", file=sys.stderr) print(
# Clean up potentially partially written output file? f"\nFatal MemoryError during conversion: {e}. Insufficient RAM.",
try: os.remove(output_filename) file=sys.stderr,
except OSError: pass )
return False # Clean up potentially partially written output file?
try:
os.remove(output_filename)
except OSError:
pass
return False
except EOFError as e: except EOFError as e:
print(f"Error: Reached end of file unexpectedly reading {input_filename}. {e}", file=sys.stderr) print(
try: os.remove(output_filename) f"Error: Reached end of file unexpectedly reading {input_filename}. {e}",
except OSError: pass file=sys.stderr,
)
try:
os.remove(output_filename)
except OSError:
pass
return False return False
except Exception as e: except Exception as e:
print(f"An unexpected error occurred during conversion: {e}", file=sys.stderr) print(f"An unexpected error occurred during conversion: {e}", file=sys.stderr)
import traceback import traceback
traceback.print_exc() traceback.print_exc()
try: try:
os.remove(output_filename) os.remove(output_filename)
except OSError: pass except OSError:
pass
return False return False
# Ensure neighbors_np is deleted even if an error occurs after its allocation # Ensure neighbors_np is deleted even if an error occurs after its allocation
finally: finally:
if 'neighbors_np' in locals() and neighbors_np is not None: try:
del neighbors_np if "neighbors_np" in locals() and neighbors_np is not None:
gc.collect() del neighbors_np
gc.collect()
except NameError:
pass
# --- Script Execution --- # --- Script Execution ---
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Convert a Faiss IndexHNSWFlat file to a CSR-based HNSW graph file.") parser = argparse.ArgumentParser(
description="Convert a Faiss IndexHNSWFlat file to a CSR-based HNSW graph file."
)
parser.add_argument("input_index_file", help="Path to the input IndexHNSWFlat file") parser.add_argument("input_index_file", help="Path to the input IndexHNSWFlat file")
parser.add_argument("output_csr_graph_file", help="Path to write the output CSR HNSW graph file") parser.add_argument(
parser.add_argument("--prune-embeddings", action="store_true", default=True, "output_csr_graph_file", help="Path to write the output CSR HNSW graph file"
help="Prune embedding storage (write NULL storage marker)") )
parser.add_argument("--keep-embeddings", action="store_true", parser.add_argument(
help="Keep embedding storage (overrides --prune-embeddings)") "--prune-embeddings",
action="store_true",
default=True,
help="Prune embedding storage (write NULL storage marker)",
)
parser.add_argument(
"--keep-embeddings",
action="store_true",
help="Keep embedding storage (overrides --prune-embeddings)",
)
args = parser.parse_args() args = parser.parse_args()
@@ -545,10 +719,12 @@ if __name__ == "__main__":
sys.exit(1) sys.exit(1)
if os.path.abspath(args.input_index_file) == os.path.abspath(args.output_csr_graph_file): if os.path.abspath(args.input_index_file) == os.path.abspath(args.output_csr_graph_file):
print(f"Error: Input and output filenames cannot be the same.", file=sys.stderr) print("Error: Input and output filenames cannot be the same.", file=sys.stderr)
sys.exit(1) sys.exit(1)
prune_embeddings = args.prune_embeddings and not args.keep_embeddings prune_embeddings = args.prune_embeddings and not args.keep_embeddings
success = convert_hnsw_graph_to_csr(args.input_index_file, args.output_csr_graph_file, prune_embeddings) success = convert_hnsw_graph_to_csr(
args.input_index_file, args.output_csr_graph_file, prune_embeddings
)
if not success: if not success:
sys.exit(1) sys.exit(1)

View File

@@ -1,19 +1,19 @@
import numpy as np
import os
from pathlib import Path
from typing import Dict, Any, List, Literal, Optional
import shutil
import logging import logging
import os
import shutil
from pathlib import Path
from typing import Any, Literal
from leann.searcher_base import BaseSearcher import numpy as np
from .convert_to_csr import convert_hnsw_graph_to_csr
from leann.registry import register_backend
from leann.interface import ( from leann.interface import (
LeannBackendFactoryInterface,
LeannBackendBuilderInterface, LeannBackendBuilderInterface,
LeannBackendFactoryInterface,
LeannBackendSearcherInterface, LeannBackendSearcherInterface,
) )
from leann.registry import register_backend
from leann.searcher_base import BaseSearcher
from .convert_to_csr import convert_hnsw_graph_to_csr
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -28,6 +28,12 @@ def get_metric_map():
} }
def normalize_l2(data: np.ndarray) -> np.ndarray:
norms = np.linalg.norm(data, axis=1, keepdims=True)
norms[norms == 0] = 1 # Avoid division by zero
return data / norms
@register_backend("hnsw") @register_backend("hnsw")
class HNSWBackend(LeannBackendFactoryInterface): class HNSWBackend(LeannBackendFactoryInterface):
@staticmethod @staticmethod
@@ -48,8 +54,14 @@ class HNSWBuilder(LeannBackendBuilderInterface):
self.efConstruction = self.build_params.setdefault("efConstruction", 200) self.efConstruction = self.build_params.setdefault("efConstruction", 200)
self.distance_metric = self.build_params.setdefault("distance_metric", "mips") self.distance_metric = self.build_params.setdefault("distance_metric", "mips")
self.dimensions = self.build_params.get("dimensions") self.dimensions = self.build_params.get("dimensions")
if not self.is_recompute:
if self.is_compact:
# TODO: support this case @andy
raise ValueError(
"is_recompute is False, but is_compact is True. This is not compatible now. change is compact to False and you can use the original HNSW index."
)
def build(self, data: np.ndarray, ids: List[str], index_path: str, **kwargs): def build(self, data: np.ndarray, ids: list[str], index_path: str, **kwargs):
from . import faiss # type: ignore from . import faiss # type: ignore
path = Path(index_path) path = Path(index_path)
@@ -70,7 +82,7 @@ class HNSWBuilder(LeannBackendBuilderInterface):
index.hnsw.efConstruction = self.efConstruction index.hnsw.efConstruction = self.efConstruction
if self.distance_metric.lower() == "cosine": if self.distance_metric.lower() == "cosine":
faiss.normalize_L2(data) data = normalize_l2(data)
index.add(data.shape[0], faiss.swig_ptr(data)) index.add(data.shape[0], faiss.swig_ptr(data))
index_file = index_dir / f"{index_prefix}.index" index_file = index_dir / f"{index_prefix}.index"
@@ -95,16 +107,12 @@ class HNSWBuilder(LeannBackendBuilderInterface):
# index_file_old = index_file.with_suffix(".old") # index_file_old = index_file.with_suffix(".old")
# shutil.move(str(index_file), str(index_file_old)) # shutil.move(str(index_file), str(index_file_old))
shutil.move(str(csr_temp_file), str(index_file)) shutil.move(str(csr_temp_file), str(index_file))
logger.info( logger.info(f"INFO: Replaced original index with {mode_str} version at '{index_file}'")
f"INFO: Replaced original index with {mode_str} version at '{index_file}'"
)
else: else:
# Clean up and fail fast # Clean up and fail fast
if csr_temp_file.exists(): if csr_temp_file.exists():
os.remove(csr_temp_file) os.remove(csr_temp_file)
raise RuntimeError( raise RuntimeError("CSR conversion failed - cannot proceed with compact format")
"CSR conversion failed - cannot proceed with compact format"
)
class HNSWSearcher(BaseSearcher): class HNSWSearcher(BaseSearcher):
@@ -116,7 +124,9 @@ class HNSWSearcher(BaseSearcher):
) )
from . import faiss # type: ignore from . import faiss # type: ignore
self.distance_metric = self.meta.get("distance_metric", "mips").lower() self.distance_metric = (
self.meta.get("backend_kwargs", {}).get("distance_metric", "mips").lower()
)
metric_enum = get_metric_map().get(self.distance_metric) metric_enum = get_metric_map().get(self.distance_metric)
if metric_enum is None: if metric_enum is None:
raise ValueError(f"Unsupported distance_metric '{self.distance_metric}'.") raise ValueError(f"Unsupported distance_metric '{self.distance_metric}'.")
@@ -142,7 +152,7 @@ class HNSWSearcher(BaseSearcher):
self, self,
query: np.ndarray, query: np.ndarray,
top_k: int, top_k: int,
zmq_port: Optional[int] = None, zmq_port: int | None = None,
complexity: int = 64, complexity: int = 64,
beam_width: int = 1, beam_width: int = 1,
prune_ratio: float = 0.0, prune_ratio: float = 0.0,
@@ -150,7 +160,7 @@ class HNSWSearcher(BaseSearcher):
pruning_strategy: Literal["global", "local", "proportional"] = "global", pruning_strategy: Literal["global", "local", "proportional"] = "global",
batch_size: int = 0, batch_size: int = 0,
**kwargs, **kwargs,
) -> Dict[str, Any]: ) -> dict[str, Any]:
""" """
Search for nearest neighbors using HNSW index. Search for nearest neighbors using HNSW index.
@@ -179,23 +189,29 @@ class HNSWSearcher(BaseSearcher):
raise RuntimeError("Recompute is required for pruned index.") raise RuntimeError("Recompute is required for pruned index.")
if recompute_embeddings: if recompute_embeddings:
if zmq_port is None: if zmq_port is None:
raise ValueError( raise ValueError("zmq_port must be provided if recompute_embeddings is True")
"zmq_port must be provided if recompute_embeddings is True"
)
if query.dtype != np.float32: if query.dtype != np.float32:
query = query.astype(np.float32) query = query.astype(np.float32)
if self.distance_metric == "cosine": if self.distance_metric == "cosine":
faiss.normalize_L2(query) query = normalize_l2(query)
params = faiss.SearchParametersHNSW() params = faiss.SearchParametersHNSW()
if zmq_port is not None: if zmq_port is not None:
params.zmq_port = ( params.zmq_port = zmq_port # C++ code won't use this if recompute_embeddings is False
zmq_port # C++ code won't use this if recompute_embeddings is False
)
params.efSearch = complexity params.efSearch = complexity
params.beam_size = beam_width params.beam_size = beam_width
# For OpenAI embeddings with cosine distance, disable relative distance check
# This prevents early termination when all scores are in a narrow range
embedding_model = self.meta.get("embedding_model", "").lower()
if self.distance_metric == "cosine" and any(
openai_model in embedding_model for openai_model in ["text-embedding", "openai"]
):
params.check_relative_distance = False
else:
params.check_relative_distance = True
# PQ pruning: direct mapping to HNSW's pq_pruning_ratio # PQ pruning: direct mapping to HNSW's pq_pruning_ratio
params.pq_pruning_ratio = prune_ratio params.pq_pruning_ratio = prune_ratio
@@ -205,9 +221,7 @@ class HNSWSearcher(BaseSearcher):
params.send_neigh_times_ratio = 0.0 params.send_neigh_times_ratio = 0.0
elif pruning_strategy == "proportional": elif pruning_strategy == "proportional":
params.local_prune = False params.local_prune = False
params.send_neigh_times_ratio = ( params.send_neigh_times_ratio = 1.0 # Any value > 1e-6 triggers proportional mode
1.0 # Any value > 1e-6 triggers proportional mode
)
else: # "global" else: # "global"
params.local_prune = False params.local_prune = False
params.send_neigh_times_ratio = 0.0 params.send_neigh_times_ratio = 0.0
@@ -228,8 +242,6 @@ class HNSWSearcher(BaseSearcher):
params, params,
) )
string_labels = [ string_labels = [[str(int_label) for int_label in batch_labels] for batch_labels in labels]
[str(int_label) for int_label in batch_labels] for batch_labels in labels
]
return {"labels": string_labels, "distances": distances} return {"labels": string_labels, "distances": distances}

View File

@@ -3,17 +3,17 @@ HNSW-specific embedding server
""" """
import argparse import argparse
import json
import logging
import os
import sys
import threading import threading
import time import time
import os
import zmq
import numpy as np
import msgpack
import json
from pathlib import Path from pathlib import Path
from typing import Optional
import sys import msgpack
import logging import numpy as np
import zmq
# Set up logging based on environment variable # Set up logging based on environment variable
LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper() LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
@@ -33,7 +33,7 @@ if not logger.handlers:
def create_hnsw_embedding_server( def create_hnsw_embedding_server(
passages_file: Optional[str] = None, passages_file: str | None = None,
zmq_port: int = 5555, zmq_port: int = 5555,
model_name: str = "sentence-transformers/all-mpnet-base-v2", model_name: str = "sentence-transformers/all-mpnet-base-v2",
distance_metric: str = "mips", distance_metric: str = "mips",
@@ -52,8 +52,8 @@ def create_hnsw_embedding_server(
sys.path.insert(0, str(leann_core_path)) sys.path.insert(0, str(leann_core_path))
try: try:
from leann.embedding_compute import compute_embeddings
from leann.api import PassageManager from leann.api import PassageManager
from leann.embedding_compute import compute_embeddings
logger.info("Successfully imported unified embedding computation module") logger.info("Successfully imported unified embedding computation module")
except ImportError as e: except ImportError as e:
@@ -78,10 +78,22 @@ def create_hnsw_embedding_server(
raise ValueError("Only metadata files (.meta.json) are supported") raise ValueError("Only metadata files (.meta.json) are supported")
# Load metadata to get passage sources # Load metadata to get passage sources
with open(passages_file, "r") as f: with open(passages_file) as f:
meta = json.load(f) meta = json.load(f)
passages = PassageManager(meta["passage_sources"]) # Convert relative paths to absolute paths based on metadata file location
metadata_dir = Path(passages_file).parent.parent # Go up one level from the metadata file
passage_sources = []
for source in meta["passage_sources"]:
source_copy = source.copy()
# Convert relative paths to absolute paths
if not Path(source_copy["path"]).is_absolute():
source_copy["path"] = str(metadata_dir / source_copy["path"])
if not Path(source_copy["index_path"]).is_absolute():
source_copy["index_path"] = str(metadata_dir / source_copy["index_path"])
passage_sources.append(source_copy)
passages = PassageManager(passage_sources)
logger.info( logger.info(
f"Loaded PassageManager with {len(passages.global_offset_map)} passages from metadata" f"Loaded PassageManager with {len(passages.global_offset_map)} passages from metadata"
) )
@@ -120,9 +132,7 @@ def create_hnsw_embedding_server(
response = embeddings.tolist() response = embeddings.tolist()
socket.send(msgpack.packb(response)) socket.send(msgpack.packb(response))
e2e_end = time.time() e2e_end = time.time()
logger.info( logger.info(f"⏱️ Text embedding E2E time: {e2e_end - e2e_start:.6f}s")
f"⏱️ Text embedding E2E time: {e2e_end - e2e_start:.6f}s"
)
continue continue
# Handle distance calculation requests # Handle distance calculation requests
@@ -148,17 +158,13 @@ def create_hnsw_embedding_server(
texts.append(txt) texts.append(txt)
except KeyError: except KeyError:
logger.error(f"Passage ID {nid} not found") logger.error(f"Passage ID {nid} not found")
raise RuntimeError( raise RuntimeError(f"FATAL: Passage with ID {nid} not found")
f"FATAL: Passage with ID {nid} not found"
)
except Exception as e: except Exception as e:
logger.error(f"Exception looking up passage ID {nid}: {e}") logger.error(f"Exception looking up passage ID {nid}: {e}")
raise raise
# Process embeddings # Process embeddings
embeddings = compute_embeddings( embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
texts, model_name, mode=embedding_mode
)
logger.info( logger.info(
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}" f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
) )
@@ -172,18 +178,12 @@ def create_hnsw_embedding_server(
distances = -np.dot(embeddings, query_vector) distances = -np.dot(embeddings, query_vector)
response_payload = distances.flatten().tolist() response_payload = distances.flatten().tolist()
response_bytes = msgpack.packb( response_bytes = msgpack.packb([response_payload], use_single_float=True)
[response_payload], use_single_float=True logger.debug(f"Sending distance response with {len(distances)} distances")
)
logger.debug(
f"Sending distance response with {len(distances)} distances"
)
socket.send(response_bytes) socket.send(response_bytes)
e2e_end = time.time() e2e_end = time.time()
logger.info( logger.info(f"⏱️ Distance calculation E2E time: {e2e_end - e2e_start:.6f}s")
f"⏱️ Distance calculation E2E time: {e2e_end - e2e_start:.6f}s"
)
continue continue
# Standard embedding request (passage ID lookup) # Standard embedding request (passage ID lookup)
@@ -208,9 +208,7 @@ def create_hnsw_embedding_server(
passage_data = passages.get_passage(str(nid)) passage_data = passages.get_passage(str(nid))
txt = passage_data["text"] txt = passage_data["text"]
if not txt: if not txt:
raise RuntimeError( raise RuntimeError(f"FATAL: Empty text for passage ID {nid}")
f"FATAL: Empty text for passage ID {nid}"
)
texts.append(txt) texts.append(txt)
except KeyError: except KeyError:
raise RuntimeError(f"FATAL: Passage with ID {nid} not found") raise RuntimeError(f"FATAL: Passage with ID {nid} not found")
@@ -229,11 +227,9 @@ def create_hnsw_embedding_server(
logger.error( logger.error(
f"NaN or Inf detected in embeddings! Requested IDs: {node_ids[:5]}..." f"NaN or Inf detected in embeddings! Requested IDs: {node_ids[:5]}..."
) )
assert False raise AssertionError()
hidden_contiguous_f32 = np.ascontiguousarray( hidden_contiguous_f32 = np.ascontiguousarray(embeddings, dtype=np.float32)
embeddings, dtype=np.float32
)
response_payload = [ response_payload = [
list(hidden_contiguous_f32.shape), list(hidden_contiguous_f32.shape),
hidden_contiguous_f32.flatten().tolist(), hidden_contiguous_f32.flatten().tolist(),

View File

@@ -6,10 +6,10 @@ build-backend = "scikit_build_core.build"
[project] [project]
name = "leann-backend-hnsw" name = "leann-backend-hnsw"
version = "0.1.9" version = "0.1.16"
description = "Custom-built HNSW (Faiss) backend for the Leann toolkit." description = "Custom-built HNSW (Faiss) backend for the Leann toolkit."
dependencies = [ dependencies = [
"leann-core==0.1.9", "leann-core==0.1.16",
"numpy", "numpy",
"pyzmq>=23.0.0", "pyzmq>=23.0.0",
"msgpack>=1.0.0", "msgpack>=1.0.0",

View File

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project] [project]
name = "leann-core" name = "leann-core"
version = "0.1.9" version = "0.1.16"
description = "Core API and plugin system for LEANN" description = "Core API and plugin system for LEANN"
readme = "README.md" readme = "README.md"
requires-python = ">=3.9" requires-python = ">=3.9"
@@ -20,7 +20,26 @@ dependencies = [
"torch>=2.0.0", "torch>=2.0.0",
"sentence-transformers>=2.2.0", "sentence-transformers>=2.2.0",
"llama-index-core>=0.12.0", "llama-index-core>=0.12.0",
"llama-index-readers-file>=0.4.0", # Essential for document reading
"llama-index-embeddings-huggingface>=0.5.5", # For embeddings
"python-dotenv>=1.0.0", "python-dotenv>=1.0.0",
"openai>=1.0.0",
"huggingface-hub>=0.20.0",
"transformers>=4.30.0",
"requests>=2.25.0",
"accelerate>=0.20.0",
"PyPDF2>=3.0.0",
"pymupdf>=1.23.0",
"pdfplumber>=0.10.0",
"mlx>=0.26.3; sys_platform == 'darwin'",
"mlx-lm>=0.26.0; sys_platform == 'darwin'",
]
[project.optional-dependencies]
colab = [
"torch>=2.0.0,<3.0.0", # Limit torch version to avoid conflicts
"transformers>=4.30.0,<5.0.0", # Limit transformers version
"accelerate>=0.20.0,<1.0.0", # Limit accelerate version
] ]
[project.scripts] [project.scripts]

View File

@@ -8,10 +8,14 @@ if platform.system() == "Darwin":
os.environ["MKL_NUM_THREADS"] = "1" os.environ["MKL_NUM_THREADS"] = "1"
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
os.environ["KMP_BLOCKTIME"] = "0" os.environ["KMP_BLOCKTIME"] = "0"
# Additional fixes for PyTorch/sentence-transformers on macOS ARM64 only in CI
if os.environ.get("CI") == "true":
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "0"
os.environ["TOKENIZERS_PARALLELISM"] = "false"
from .api import LeannBuilder, LeannChat, LeannSearcher from .api import LeannBuilder, LeannChat, LeannSearcher
from .registry import BACKEND_REGISTRY, autodiscover_backends from .registry import BACKEND_REGISTRY, autodiscover_backends
autodiscover_backends() autodiscover_backends()
__all__ = ["LeannBuilder", "LeannSearcher", "LeannChat", "BACKEND_REGISTRY"] __all__ = ["BACKEND_REGISTRY", "LeannBuilder", "LeannChat", "LeannSearcher"]

View File

@@ -4,27 +4,36 @@ with the correct, original embedding logic from the user's reference code.
""" """
import json import json
import pickle
from leann.interface import LeannBackendSearcherInterface
import numpy as np
import time
from pathlib import Path
from typing import List, Dict, Any, Optional, Literal
from dataclasses import dataclass, field
from .registry import BACKEND_REGISTRY
from .interface import LeannBackendFactoryInterface
from .chat import get_llm
import logging import logging
import pickle
import time
import warnings
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Literal
import numpy as np
from leann.interface import LeannBackendSearcherInterface
from .chat import get_llm
from .interface import LeannBackendFactoryInterface
from .registry import BACKEND_REGISTRY
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def get_registered_backends() -> list[str]:
"""Get list of registered backend names."""
return list(BACKEND_REGISTRY.keys())
def compute_embeddings( def compute_embeddings(
chunks: List[str], chunks: list[str],
model_name: str, model_name: str,
mode: str = "sentence-transformers", mode: str = "sentence-transformers",
use_server: bool = True, use_server: bool = True,
port: Optional[int] = None, port: int | None = None,
is_build=False, is_build=False,
) -> np.ndarray: ) -> np.ndarray:
""" """
@@ -61,9 +70,7 @@ def compute_embeddings(
) )
def compute_embeddings_via_server( def compute_embeddings_via_server(chunks: list[str], model_name: str, port: int) -> np.ndarray:
chunks: List[str], model_name: str, port: int
) -> np.ndarray:
"""Computes embeddings using sentence-transformers. """Computes embeddings using sentence-transformers.
Args: Args:
@@ -73,9 +80,9 @@ def compute_embeddings_via_server(
logger.info( logger.info(
f"Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}' (via embedding server)..." f"Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}' (via embedding server)..."
) )
import zmq
import msgpack import msgpack
import numpy as np import numpy as np
import zmq
# Connect to embedding server # Connect to embedding server
context = zmq.Context() context = zmq.Context()
@@ -104,11 +111,11 @@ class SearchResult:
id: str id: str
score: float score: float
text: str text: str
metadata: Dict[str, Any] = field(default_factory=dict) metadata: dict[str, Any] = field(default_factory=dict)
class PassageManager: class PassageManager:
def __init__(self, passage_sources: List[Dict[str, Any]]): def __init__(self, passage_sources: list[dict[str, Any]]):
self.offset_maps = {} self.offset_maps = {}
self.passage_files = {} self.passage_files = {}
self.global_offset_map = {} # Combined map for fast lookup self.global_offset_map = {} # Combined map for fast lookup
@@ -117,8 +124,15 @@ class PassageManager:
assert source["type"] == "jsonl", "only jsonl is supported" assert source["type"] == "jsonl", "only jsonl is supported"
passage_file = source["path"] passage_file = source["path"]
index_file = source["index_path"] # .idx file index_file = source["index_path"] # .idx file
# Fix path resolution for Colab and other environments
if not Path(index_file).is_absolute():
# If relative path, try to resolve it properly
index_file = str(Path(index_file).resolve())
if not Path(index_file).exists(): if not Path(index_file).exists():
raise FileNotFoundError(f"Passage index file not found: {index_file}") raise FileNotFoundError(f"Passage index file not found: {index_file}")
with open(index_file, "rb") as f: with open(index_file, "rb") as f:
offset_map = pickle.load(f) offset_map = pickle.load(f)
self.offset_maps[passage_file] = offset_map self.offset_maps[passage_file] = offset_map
@@ -128,11 +142,11 @@ class PassageManager:
for passage_id, offset in offset_map.items(): for passage_id, offset in offset_map.items():
self.global_offset_map[passage_id] = (passage_file, offset) self.global_offset_map[passage_id] = (passage_file, offset)
def get_passage(self, passage_id: str) -> Dict[str, Any]: def get_passage(self, passage_id: str) -> dict[str, Any]:
if passage_id in self.global_offset_map: if passage_id in self.global_offset_map:
passage_file, offset = self.global_offset_map[passage_id] passage_file, offset = self.global_offset_map[passage_id]
# Lazy file opening - only open when needed # Lazy file opening - only open when needed
with open(passage_file, "r", encoding="utf-8") as f: with open(passage_file, encoding="utf-8") as f:
f.seek(offset) f.seek(offset)
return json.loads(f.readline()) return json.loads(f.readline())
raise KeyError(f"Passage ID not found: {passage_id}") raise KeyError(f"Passage ID not found: {passage_id}")
@@ -143,24 +157,92 @@ class LeannBuilder:
self, self,
backend_name: str, backend_name: str,
embedding_model: str = "facebook/contriever", embedding_model: str = "facebook/contriever",
dimensions: Optional[int] = None, dimensions: int | None = None,
embedding_mode: str = "sentence-transformers", embedding_mode: str = "sentence-transformers",
**backend_kwargs, **backend_kwargs,
): ):
self.backend_name = backend_name self.backend_name = backend_name
backend_factory: LeannBackendFactoryInterface | None = BACKEND_REGISTRY.get( backend_factory: LeannBackendFactoryInterface | None = BACKEND_REGISTRY.get(backend_name)
backend_name
)
if backend_factory is None: if backend_factory is None:
raise ValueError(f"Backend '{backend_name}' not found or not registered.") raise ValueError(f"Backend '{backend_name}' not found or not registered.")
self.backend_factory = backend_factory self.backend_factory = backend_factory
self.embedding_model = embedding_model self.embedding_model = embedding_model
self.dimensions = dimensions self.dimensions = dimensions
self.embedding_mode = embedding_mode self.embedding_mode = embedding_mode
self.backend_kwargs = backend_kwargs
self.chunks: List[Dict[str, Any]] = []
def add_text(self, text: str, metadata: Optional[Dict[str, Any]] = None): # Check if we need to use cosine distance for normalized embeddings
normalized_embeddings_models = {
# OpenAI models
("openai", "text-embedding-ada-002"),
("openai", "text-embedding-3-small"),
("openai", "text-embedding-3-large"),
# Voyage AI models
("voyage", "voyage-2"),
("voyage", "voyage-3"),
("voyage", "voyage-large-2"),
("voyage", "voyage-multilingual-2"),
("voyage", "voyage-code-2"),
# Cohere models
("cohere", "embed-english-v3.0"),
("cohere", "embed-multilingual-v3.0"),
("cohere", "embed-english-light-v3.0"),
("cohere", "embed-multilingual-light-v3.0"),
}
# Also check for patterns in model names
is_normalized = False
current_model_lower = embedding_model.lower()
current_mode_lower = embedding_mode.lower()
# Check exact matches
for mode, model in normalized_embeddings_models:
if (current_mode_lower == mode and current_model_lower == model) or (
mode in current_mode_lower and model in current_model_lower
):
is_normalized = True
break
# Check patterns
if not is_normalized:
# OpenAI patterns
if "openai" in current_mode_lower or "openai" in current_model_lower:
if any(
pattern in current_model_lower
for pattern in ["text-embedding", "ada", "3-small", "3-large"]
):
is_normalized = True
# Voyage patterns
elif "voyage" in current_mode_lower or "voyage" in current_model_lower:
is_normalized = True
# Cohere patterns
elif "cohere" in current_mode_lower or "cohere" in current_model_lower:
if "embed" in current_model_lower:
is_normalized = True
# Handle distance metric
if is_normalized and "distance_metric" not in backend_kwargs:
backend_kwargs["distance_metric"] = "cosine"
warnings.warn(
f"Detected normalized embeddings model '{embedding_model}' with mode '{embedding_mode}'. "
f"Automatically setting distance_metric='cosine' for optimal performance. "
f"Normalized embeddings (L2 norm = 1) should use cosine similarity instead of MIPS.",
UserWarning,
stacklevel=2,
)
elif is_normalized and backend_kwargs.get("distance_metric", "").lower() != "cosine":
current_metric = backend_kwargs.get("distance_metric", "mips")
warnings.warn(
f"Warning: Using '{current_metric}' distance metric with normalized embeddings model "
f"'{embedding_model}' may lead to suboptimal search results. "
f"Consider using 'cosine' distance metric for better performance.",
UserWarning,
stacklevel=2,
)
self.backend_kwargs = backend_kwargs
self.chunks: list[dict[str, Any]] = []
def add_text(self, text: str, metadata: dict[str, Any] | None = None):
if metadata is None: if metadata is None:
metadata = {} metadata = {}
passage_id = metadata.get("id", str(len(self.chunks))) passage_id = metadata.get("id", str(len(self.chunks)))
@@ -190,9 +272,7 @@ class LeannBuilder:
try: try:
from tqdm import tqdm from tqdm import tqdm
chunk_iterator = tqdm( chunk_iterator = tqdm(self.chunks, desc="Writing passages", unit="chunk")
self.chunks, desc="Writing passages", unit="chunk"
)
except ImportError: except ImportError:
chunk_iterator = self.chunks chunk_iterator = self.chunks
@@ -222,9 +302,7 @@ class LeannBuilder:
string_ids = [chunk["id"] for chunk in self.chunks] string_ids = [chunk["id"] for chunk in self.chunks]
current_backend_kwargs = {**self.backend_kwargs, "dimensions": self.dimensions} current_backend_kwargs = {**self.backend_kwargs, "dimensions": self.dimensions}
builder_instance = self.backend_factory.builder(**current_backend_kwargs) builder_instance = self.backend_factory.builder(**current_backend_kwargs)
builder_instance.build( builder_instance.build(embeddings, string_ids, index_path, **current_backend_kwargs)
embeddings, string_ids, index_path, **current_backend_kwargs
)
leann_meta_path = index_dir / f"{index_name}.meta.json" leann_meta_path = index_dir / f"{index_name}.meta.json"
meta_data = { meta_data = {
"version": "1.0", "version": "1.0",
@@ -273,9 +351,7 @@ class LeannBuilder:
ids, embeddings = data ids, embeddings = data
if not isinstance(embeddings, np.ndarray): if not isinstance(embeddings, np.ndarray):
raise ValueError( raise ValueError(f"Expected embeddings to be numpy array, got {type(embeddings)}")
f"Expected embeddings to be numpy array, got {type(embeddings)}"
)
if len(ids) != embeddings.shape[0]: if len(ids) != embeddings.shape[0]:
raise ValueError( raise ValueError(
@@ -287,9 +363,7 @@ class LeannBuilder:
if self.dimensions is None: if self.dimensions is None:
self.dimensions = embedding_dim self.dimensions = embedding_dim
elif self.dimensions != embedding_dim: elif self.dimensions != embedding_dim:
raise ValueError( raise ValueError(f"Dimension mismatch: expected {self.dimensions}, got {embedding_dim}")
f"Dimension mismatch: expected {self.dimensions}, got {embedding_dim}"
)
logger.info( logger.info(
f"Building index from precomputed embeddings: {len(ids)} items, {embedding_dim} dimensions" f"Building index from precomputed embeddings: {len(ids)} items, {embedding_dim} dimensions"
@@ -374,26 +448,24 @@ class LeannBuilder:
with open(leann_meta_path, "w", encoding="utf-8") as f: with open(leann_meta_path, "w", encoding="utf-8") as f:
json.dump(meta_data, f, indent=2) json.dump(meta_data, f, indent=2)
logger.info( logger.info(f"Index built successfully from precomputed embeddings: {index_path}")
f"Index built successfully from precomputed embeddings: {index_path}"
)
class LeannSearcher: class LeannSearcher:
def __init__(self, index_path: str, enable_warmup: bool = False, **backend_kwargs): def __init__(self, index_path: str, enable_warmup: bool = False, **backend_kwargs):
# Fix path resolution for Colab and other environments
if not Path(index_path).is_absolute():
index_path = str(Path(index_path).resolve())
self.meta_path_str = f"{index_path}.meta.json" self.meta_path_str = f"{index_path}.meta.json"
if not Path(self.meta_path_str).exists(): if not Path(self.meta_path_str).exists():
raise FileNotFoundError( raise FileNotFoundError(f"Leann metadata file not found at {self.meta_path_str}")
f"Leann metadata file not found at {self.meta_path_str}" with open(self.meta_path_str, encoding="utf-8") as f:
)
with open(self.meta_path_str, "r", encoding="utf-8") as f:
self.meta_data = json.load(f) self.meta_data = json.load(f)
backend_name = self.meta_data["backend_name"] backend_name = self.meta_data["backend_name"]
self.embedding_model = self.meta_data["embedding_model"] self.embedding_model = self.meta_data["embedding_model"]
# Support both old and new format # Support both old and new format
self.embedding_mode = self.meta_data.get( self.embedding_mode = self.meta_data.get("embedding_mode", "sentence-transformers")
"embedding_mode", "sentence-transformers"
)
self.passage_manager = PassageManager(self.meta_data.get("passage_sources", [])) self.passage_manager = PassageManager(self.meta_data.get("passage_sources", []))
backend_factory = BACKEND_REGISTRY.get(backend_name) backend_factory = BACKEND_REGISTRY.get(backend_name)
if backend_factory is None: if backend_factory is None:
@@ -415,7 +487,7 @@ class LeannSearcher:
pruning_strategy: Literal["global", "local", "proportional"] = "global", pruning_strategy: Literal["global", "local", "proportional"] = "global",
expected_zmq_port: int = 5557, expected_zmq_port: int = 5557,
**kwargs, **kwargs,
) -> List[SearchResult]: ) -> list[SearchResult]:
logger.info("🔍 LeannSearcher.search() called:") logger.info("🔍 LeannSearcher.search() called:")
logger.info(f" Query: '{query}'") logger.info(f" Query: '{query}'")
logger.info(f" Top_k: {top_k}") logger.info(f" Top_k: {top_k}")
@@ -442,7 +514,7 @@ class LeannSearcher:
zmq_port=zmq_port, zmq_port=zmq_port,
) )
# logger.info(f" Generated embedding shape: {query_embedding.shape}") # logger.info(f" Generated embedding shape: {query_embedding.shape}")
embedding_time = time.time() - start_time time.time() - start_time
# logger.info(f" Embedding time: {embedding_time} seconds") # logger.info(f" Embedding time: {embedding_time} seconds")
start_time = time.time() start_time = time.time()
@@ -457,17 +529,15 @@ class LeannSearcher:
zmq_port=zmq_port, zmq_port=zmq_port,
**kwargs, **kwargs,
) )
search_time = time.time() - start_time time.time() - start_time
# logger.info(f" Search time: {search_time} seconds") # logger.info(f" Search time: {search_time} seconds")
logger.info( logger.info(f" Backend returned: labels={len(results.get('labels', [[]])[0])} results")
f" Backend returned: labels={len(results.get('labels', [[]])[0])} results"
)
enriched_results = [] enriched_results = []
if "labels" in results and "distances" in results: if "labels" in results and "distances" in results:
logger.info(f" Processing {len(results['labels'][0])} passage IDs:") logger.info(f" Processing {len(results['labels'][0])} passage IDs:")
for i, (string_id, dist) in enumerate( for i, (string_id, dist) in enumerate(
zip(results["labels"][0], results["distances"][0]) zip(results["labels"][0], results["distances"][0], strict=False)
): ):
try: try:
passage_data = self.passage_manager.get_passage(string_id) passage_data = self.passage_manager.get_passage(string_id)
@@ -487,7 +557,7 @@ class LeannSearcher:
RESET = "\033[0m" RESET = "\033[0m"
# Truncate text for display (first 100 chars) # Truncate text for display (first 100 chars)
display_text = passage_data['text'] display_text = passage_data["text"]
logger.info( logger.info(
f" {GREEN}{RESET} {BLUE}[{i + 1:2d}]{RESET} {YELLOW}ID:{RESET} '{string_id}' {YELLOW}Score:{RESET} {dist:.4f} {YELLOW}Text:{RESET} {display_text}" f" {GREEN}{RESET} {BLUE}[{i + 1:2d}]{RESET} {YELLOW}ID:{RESET} '{string_id}' {YELLOW}Score:{RESET} {dist:.4f} {YELLOW}Text:{RESET} {display_text}"
) )
@@ -505,7 +575,7 @@ class LeannChat:
def __init__( def __init__(
self, self,
index_path: str, index_path: str,
llm_config: Optional[Dict[str, Any]] = None, llm_config: dict[str, Any] | None = None,
enable_warmup: bool = False, enable_warmup: bool = False,
**kwargs, **kwargs,
): ):
@@ -521,7 +591,7 @@ class LeannChat:
prune_ratio: float = 0.0, prune_ratio: float = 0.0,
recompute_embeddings: bool = True, recompute_embeddings: bool = True,
pruning_strategy: Literal["global", "local", "proportional"] = "global", pruning_strategy: Literal["global", "local", "proportional"] = "global",
llm_kwargs: Optional[Dict[str, Any]] = None, llm_kwargs: dict[str, Any] | None = None,
expected_zmq_port: int = 5557, expected_zmq_port: int = 5557,
**search_kwargs, **search_kwargs,
): ):

View File

@@ -4,11 +4,12 @@ This file contains the chat generation logic for the LEANN project,
supporting different backends like Ollama, Hugging Face Transformers, and a simulation mode. supporting different backends like Ollama, Hugging Face Transformers, and a simulation mode.
""" """
from abc import ABC, abstractmethod import difflib
from typing import Dict, Any, Optional, List
import logging import logging
import os import os
import difflib from abc import ABC, abstractmethod
from typing import Any
import torch import torch
# Configure logging # Configure logging
@@ -16,10 +17,11 @@ logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def check_ollama_models() -> List[str]: def check_ollama_models() -> list[str]:
"""Check available Ollama models and return a list""" """Check available Ollama models and return a list"""
try: try:
import requests import requests
response = requests.get("http://localhost:11434/api/tags", timeout=5) response = requests.get("http://localhost:11434/api/tags", timeout=5)
if response.status_code == 200: if response.status_code == 200:
data = response.json() data = response.json()
@@ -36,12 +38,13 @@ def check_ollama_model_exists_remotely(model_name: str) -> tuple[bool, list[str]
(model_exists, available_tags): bool and list of matching tags (model_exists, available_tags): bool and list of matching tags
""" """
try: try:
import requests
import re import re
import requests
# Split model name and tag # Split model name and tag
if ':' in model_name: if ":" in model_name:
base_model, requested_tag = model_name.split(':', 1) base_model, requested_tag = model_name.split(":", 1)
else: else:
base_model, requested_tag = model_name, None base_model, requested_tag = model_name, None
@@ -62,7 +65,7 @@ def check_ollama_model_exists_remotely(model_name: str) -> tuple[bool, list[str]
return True, [] # Base model exists but can't get tags return True, [] # Base model exists but can't get tags
# Extract tags for this model - be more specific to avoid HTML artifacts # Extract tags for this model - be more specific to avoid HTML artifacts
tag_pattern = rf'{re.escape(base_model)}:[a-zA-Z0-9\.\-_]+' tag_pattern = rf"{re.escape(base_model)}:[a-zA-Z0-9\.\-_]+"
raw_tags = re.findall(tag_pattern, tags_response.text) raw_tags = re.findall(tag_pattern, tags_response.text)
# Clean up tags - remove HTML artifacts and duplicates # Clean up tags - remove HTML artifacts and duplicates
@@ -70,7 +73,7 @@ def check_ollama_model_exists_remotely(model_name: str) -> tuple[bool, list[str]
seen = set() seen = set()
for tag in raw_tags: for tag in raw_tags:
# Skip if it looks like HTML (contains < or >) # Skip if it looks like HTML (contains < or >)
if '<' in tag or '>' in tag: if "<" in tag or ">" in tag:
continue continue
if tag not in seen: if tag not in seen:
seen.add(tag) seen.add(tag)
@@ -91,7 +94,7 @@ def check_ollama_model_exists_remotely(model_name: str) -> tuple[bool, list[str]
return True, [] return True, []
def search_ollama_models_fuzzy(query: str, available_models: List[str]) -> List[str]: def search_ollama_models_fuzzy(query: str, available_models: list[str]) -> list[str]:
"""Use intelligent fuzzy search for Ollama models""" """Use intelligent fuzzy search for Ollama models"""
if not available_models: if not available_models:
return [] return []
@@ -104,7 +107,9 @@ def search_ollama_models_fuzzy(query: str, available_models: List[str]) -> List[
suggestions.extend(exact_matches) suggestions.extend(exact_matches)
# 2. Starts with query # 2. Starts with query
starts_with = [m for m in available_models if m.lower().startswith(query_lower) and m not in suggestions] starts_with = [
m for m in available_models if m.lower().startswith(query_lower) and m not in suggestions
]
suggestions.extend(starts_with) suggestions.extend(starts_with)
# 3. Contains query # 3. Contains query
@@ -114,24 +119,25 @@ def search_ollama_models_fuzzy(query: str, available_models: List[str]) -> List[
# 4. Base model name matching (remove version numbers) # 4. Base model name matching (remove version numbers)
def get_base_name(model_name: str) -> str: def get_base_name(model_name: str) -> str:
"""Extract base name without version (e.g., 'llama3:8b' -> 'llama3')""" """Extract base name without version (e.g., 'llama3:8b' -> 'llama3')"""
return model_name.split(':')[0].split('-')[0] return model_name.split(":")[0].split("-")[0]
query_base = get_base_name(query_lower) query_base = get_base_name(query_lower)
base_matches = [ base_matches = [
m for m in available_models m
for m in available_models
if get_base_name(m.lower()) == query_base and m not in suggestions if get_base_name(m.lower()) == query_base and m not in suggestions
] ]
suggestions.extend(base_matches) suggestions.extend(base_matches)
# 5. Family/variant matching # 5. Family/variant matching
model_families = { model_families = {
'llama': ['llama2', 'llama3', 'alpaca', 'vicuna', 'codellama'], "llama": ["llama2", "llama3", "alpaca", "vicuna", "codellama"],
'qwen': ['qwen', 'qwen2', 'qwen3'], "qwen": ["qwen", "qwen2", "qwen3"],
'gemma': ['gemma', 'gemma2'], "gemma": ["gemma", "gemma2"],
'phi': ['phi', 'phi2', 'phi3'], "phi": ["phi", "phi2", "phi3"],
'mistral': ['mistral', 'mixtral', 'openhermes'], "mistral": ["mistral", "mixtral", "openhermes"],
'dolphin': ['dolphin', 'openchat'], "dolphin": ["dolphin", "openchat"],
'deepseek': ['deepseek', 'deepseek-coder'] "deepseek": ["deepseek", "deepseek-coder"],
} }
query_family = None query_family = None
@@ -143,7 +149,8 @@ def search_ollama_models_fuzzy(query: str, available_models: List[str]) -> List[
if query_family: if query_family:
family_variants = model_families[query_family] family_variants = model_families[query_family]
family_matches = [ family_matches = [
m for m in available_models m
for m in available_models
if any(variant in m.lower() for variant in family_variants) and m not in suggestions if any(variant in m.lower() for variant in family_variants) and m not in suggestions
] ]
suggestions.extend(family_matches) suggestions.extend(family_matches)
@@ -162,15 +169,13 @@ def search_ollama_models_fuzzy(query: str, available_models: List[str]) -> List[
# Remove this too - no need for fallback # Remove this too - no need for fallback
def suggest_similar_models(invalid_model: str, available_models: List[str]) -> List[str]: def suggest_similar_models(invalid_model: str, available_models: list[str]) -> list[str]:
"""Use difflib to find similar model names""" """Use difflib to find similar model names"""
if not available_models: if not available_models:
return [] return []
# Get close matches using fuzzy matching # Get close matches using fuzzy matching
suggestions = difflib.get_close_matches( suggestions = difflib.get_close_matches(invalid_model, available_models, n=3, cutoff=0.3)
invalid_model, available_models, n=3, cutoff=0.3
)
return suggestions return suggestions
@@ -178,13 +183,14 @@ def check_hf_model_exists(model_name: str) -> bool:
"""Quick check if HuggingFace model exists without downloading""" """Quick check if HuggingFace model exists without downloading"""
try: try:
from huggingface_hub import model_info from huggingface_hub import model_info
model_info(model_name) model_info(model_name)
return True return True
except Exception: except Exception:
return False return False
def get_popular_hf_models() -> List[str]: def get_popular_hf_models() -> list[str]:
"""Return a list of popular HuggingFace models for suggestions""" """Return a list of popular HuggingFace models for suggestions"""
try: try:
from huggingface_hub import list_models from huggingface_hub import list_models
@@ -194,15 +200,15 @@ def get_popular_hf_models() -> List[str]:
filter="text-generation", filter="text-generation",
sort="downloads", sort="downloads",
direction=-1, direction=-1,
limit=20 # Get top 20 most downloaded limit=20, # Get top 20 most downloaded
) )
# Extract model names and filter for chat/conversation models # Extract model names and filter for chat/conversation models
model_names = [] model_names = []
chat_keywords = ['chat', 'instruct', 'dialog', 'conversation', 'assistant'] chat_keywords = ["chat", "instruct", "dialog", "conversation", "assistant"]
for model in models: for model in models:
model_name = model.id if hasattr(model, 'id') else str(model) model_name = model.id if hasattr(model, "id") else str(model)
# Prioritize models with chat-related keywords # Prioritize models with chat-related keywords
if any(keyword in model_name.lower() for keyword in chat_keywords): if any(keyword in model_name.lower() for keyword in chat_keywords):
model_names.append(model_name) model_names.append(model_name)
@@ -216,7 +222,7 @@ def get_popular_hf_models() -> List[str]:
return _get_fallback_hf_models() return _get_fallback_hf_models()
def _get_fallback_hf_models() -> List[str]: def _get_fallback_hf_models() -> list[str]:
"""Fallback list of popular HuggingFace models""" """Fallback list of popular HuggingFace models"""
return [ return [
"microsoft/DialoGPT-medium", "microsoft/DialoGPT-medium",
@@ -228,11 +234,11 @@ def _get_fallback_hf_models() -> List[str]:
"facebook/blenderbot_small-90M", "facebook/blenderbot_small-90M",
"microsoft/phi-1_5", "microsoft/phi-1_5",
"facebook/opt-350m", "facebook/opt-350m",
"EleutherAI/gpt-neo-1.3B" "EleutherAI/gpt-neo-1.3B",
] ]
def search_hf_models_fuzzy(query: str, limit: int = 10) -> List[str]: def search_hf_models_fuzzy(query: str, limit: int = 10) -> list[str]:
"""Use HuggingFace Hub's native fuzzy search for model suggestions""" """Use HuggingFace Hub's native fuzzy search for model suggestions"""
try: try:
from huggingface_hub import list_models from huggingface_hub import list_models
@@ -243,10 +249,10 @@ def search_hf_models_fuzzy(query: str, limit: int = 10) -> List[str]:
filter="text-generation", filter="text-generation",
sort="downloads", sort="downloads",
direction=-1, direction=-1,
limit=limit limit=limit,
) )
model_names = [model.id if hasattr(model, 'id') else str(model) for model in models] model_names = [model.id if hasattr(model, "id") else str(model) for model in models]
# If direct search doesn't return enough results, try some variations # If direct search doesn't return enough results, try some variations
if len(model_names) < 3: if len(model_names) < 3:
@@ -254,17 +260,17 @@ def search_hf_models_fuzzy(query: str, limit: int = 10) -> List[str]:
variations = [] variations = []
# Extract base name (e.g., "gpt3" from "gpt-3.5") # Extract base name (e.g., "gpt3" from "gpt-3.5")
base_query = query.lower().replace('-', '').replace('.', '').replace('_', '') base_query = query.lower().replace("-", "").replace(".", "").replace("_", "")
if base_query != query.lower(): if base_query != query.lower():
variations.append(base_query) variations.append(base_query)
# Try common model name patterns # Try common model name patterns
if 'gpt' in query.lower(): if "gpt" in query.lower():
variations.extend(['gpt2', 'gpt-neo', 'gpt-j', 'dialoGPT']) variations.extend(["gpt2", "gpt-neo", "gpt-j", "dialoGPT"])
elif 'llama' in query.lower(): elif "llama" in query.lower():
variations.extend(['llama2', 'alpaca', 'vicuna']) variations.extend(["llama2", "alpaca", "vicuna"])
elif 'bert' in query.lower(): elif "bert" in query.lower():
variations.extend(['roberta', 'distilbert', 'albert']) variations.extend(["roberta", "distilbert", "albert"])
# Search with variations # Search with variations
for var in variations[:2]: # Limit to 2 variations to avoid too many API calls for var in variations[:2]: # Limit to 2 variations to avoid too many API calls
@@ -274,11 +280,13 @@ def search_hf_models_fuzzy(query: str, limit: int = 10) -> List[str]:
filter="text-generation", filter="text-generation",
sort="downloads", sort="downloads",
direction=-1, direction=-1,
limit=3 limit=3,
) )
var_names = [model.id if hasattr(model, 'id') else str(model) for model in var_models] var_names = [
model.id if hasattr(model, "id") else str(model) for model in var_models
]
model_names.extend(var_names) model_names.extend(var_names)
except: except Exception:
continue continue
# Remove duplicates while preserving order # Remove duplicates while preserving order
@@ -296,12 +304,12 @@ def search_hf_models_fuzzy(query: str, limit: int = 10) -> List[str]:
return [] return []
def search_hf_models(query: str, limit: int = 10) -> List[str]: def search_hf_models(query: str, limit: int = 10) -> list[str]:
"""Simple search for HuggingFace models based on query (kept for backward compatibility)""" """Simple search for HuggingFace models based on query (kept for backward compatibility)"""
return search_hf_models_fuzzy(query, limit) return search_hf_models_fuzzy(query, limit)
def validate_model_and_suggest(model_name: str, llm_type: str) -> Optional[str]: def validate_model_and_suggest(model_name: str, llm_type: str) -> str | None:
"""Validate model name and provide suggestions if invalid""" """Validate model name and provide suggestions if invalid"""
if llm_type == "ollama": if llm_type == "ollama":
available_models = check_ollama_models() available_models = check_ollama_models()
@@ -313,7 +321,7 @@ def validate_model_and_suggest(model_name: str, llm_type: str) -> Optional[str]:
if model_exists_remotely and model_name in available_tags: if model_exists_remotely and model_name in available_tags:
# Exact model exists remotely - suggest pulling it # Exact model exists remotely - suggest pulling it
error_msg += f"\n\nTo install the requested model:\n" error_msg += "\n\nTo install the requested model:\n"
error_msg += f" ollama pull {model_name}\n" error_msg += f" ollama pull {model_name}\n"
# Show local alternatives # Show local alternatives
@@ -325,10 +333,12 @@ def validate_model_and_suggest(model_name: str, llm_type: str) -> Optional[str]:
elif model_exists_remotely and available_tags: elif model_exists_remotely and available_tags:
# Base model exists but requested tag doesn't - suggest correct tags # Base model exists but requested tag doesn't - suggest correct tags
base_model = model_name.split(':')[0] base_model = model_name.split(":")[0]
requested_tag = model_name.split(':', 1)[1] if ':' in model_name else None requested_tag = model_name.split(":", 1)[1] if ":" in model_name else None
error_msg += f"\n\nModel '{base_model}' exists, but tag '{requested_tag}' is not available." error_msg += (
f"\n\nModel '{base_model}' exists, but tag '{requested_tag}' is not available."
)
error_msg += f"\n\nAvailable {base_model} models you can install:\n" error_msg += f"\n\nAvailable {base_model} models you can install:\n"
for i, tag in enumerate(available_tags[:8], 1): for i, tag in enumerate(available_tags[:8], 1):
error_msg += f" {i}. ollama pull {tag}\n" error_msg += f" {i}. ollama pull {tag}\n"
@@ -364,7 +374,9 @@ def validate_model_and_suggest(model_name: str, llm_type: str) -> Optional[str]:
if model_name in available_tags: if model_name in available_tags:
error_msg += f"\n ollama pull {model_name} # Install requested model" error_msg += f"\n ollama pull {model_name} # Install requested model"
else: else:
error_msg += f"\n ollama pull {available_tags[0]} # Install recommended variant" error_msg += (
f"\n ollama pull {available_tags[0]} # Install recommended variant"
)
error_msg += "\n https://ollama.com/library # Browse available models" error_msg += "\n https://ollama.com/library # Browse available models"
return error_msg return error_msg
@@ -462,17 +474,16 @@ class OllamaChat(LLMInterface):
"The 'requests' library is required for Ollama. Please install it with 'pip install requests'." "The 'requests' library is required for Ollama. Please install it with 'pip install requests'."
) )
except requests.exceptions.ConnectionError: except requests.exceptions.ConnectionError:
logger.error( logger.error(f"Could not connect to Ollama at {host}. Please ensure Ollama is running.")
f"Could not connect to Ollama at {host}. Please ensure Ollama is running."
)
raise ConnectionError( raise ConnectionError(
f"Could not connect to Ollama at {host}. Please ensure Ollama is running." f"Could not connect to Ollama at {host}. Please ensure Ollama is running."
) )
def ask(self, prompt: str, **kwargs) -> str: def ask(self, prompt: str, **kwargs) -> str:
import requests
import json import json
import requests
full_url = f"{self.host}/api/generate" full_url = f"{self.host}/api/generate"
payload = { payload = {
"model": self.model, "model": self.model,
@@ -482,7 +493,7 @@ class OllamaChat(LLMInterface):
} }
logger.debug(f"Sending request to Ollama: {payload}") logger.debug(f"Sending request to Ollama: {payload}")
try: try:
logger.info(f"Sending request to Ollama and waiting for response...") logger.info("Sending request to Ollama and waiting for response...")
response = requests.post(full_url, data=json.dumps(payload)) response = requests.post(full_url, data=json.dumps(payload))
response.raise_for_status() response.raise_for_status()
@@ -513,8 +524,8 @@ class HFChat(LLMInterface):
raise ValueError(model_error) raise ValueError(model_error)
try: try:
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
except ImportError: except ImportError:
raise ImportError( raise ImportError(
"The 'transformers' and 'torch' libraries are required for Hugging Face models. Please install them with 'pip install transformers torch'." "The 'transformers' and 'torch' libraries are required for Hugging Face models. Please install them with 'pip install transformers torch'."
@@ -531,14 +542,41 @@ class HFChat(LLMInterface):
self.device = "cpu" self.device = "cpu"
logger.info("No GPU detected. Using CPU.") logger.info("No GPU detected. Using CPU.")
# Load tokenizer and model # Load tokenizer and model with timeout protection
self.tokenizer = AutoTokenizer.from_pretrained(model_name) try:
self.model = AutoModelForCausalLM.from_pretrained( import signal
model_name,
torch_dtype=torch.float16 if self.device != "cpu" else torch.float32, def timeout_handler(signum, frame):
device_map="auto" if self.device != "cpu" else None, raise TimeoutError("Model download/loading timed out")
trust_remote_code=True
) # Set timeout for model loading (60 seconds)
old_handler = signal.signal(signal.SIGALRM, timeout_handler)
signal.alarm(60)
try:
logger.info(f"Loading tokenizer for {model_name}...")
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
logger.info(f"Loading model {model_name}...")
self.model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.float16 if self.device != "cpu" else torch.float32,
device_map="auto" if self.device != "cpu" else None,
trust_remote_code=True,
)
logger.info(f"Successfully loaded {model_name}")
finally:
signal.alarm(0) # Cancel the alarm
signal.signal(signal.SIGALRM, old_handler) # Restore old handler
except TimeoutError:
logger.error(f"Model loading timed out for {model_name}")
raise RuntimeError(
f"Model loading timed out for {model_name}. Please check your internet connection or try a smaller model."
)
except Exception as e:
logger.error(f"Failed to load model {model_name}: {e}")
raise
# Move model to device if not using device_map # Move model to device if not using device_map
if self.device != "cpu" and "device_map" not in str(self.model): if self.device != "cpu" and "device_map" not in str(self.model):
@@ -549,7 +587,7 @@ class HFChat(LLMInterface):
self.tokenizer.pad_token = self.tokenizer.eos_token self.tokenizer.pad_token = self.tokenizer.eos_token
def ask(self, prompt: str, **kwargs) -> str: def ask(self, prompt: str, **kwargs) -> str:
print('kwargs in HF: ', kwargs) print("kwargs in HF: ", kwargs)
# Check if this is a Qwen model and add /no_think by default # Check if this is a Qwen model and add /no_think by default
is_qwen_model = "qwen" in self.model.config._name_or_path.lower() is_qwen_model = "qwen" in self.model.config._name_or_path.lower()
@@ -564,9 +602,7 @@ class HFChat(LLMInterface):
if hasattr(self.tokenizer, "apply_chat_template"): if hasattr(self.tokenizer, "apply_chat_template"):
try: try:
formatted_prompt = self.tokenizer.apply_chat_template( formatted_prompt = self.tokenizer.apply_chat_template(
messages, messages, tokenize=False, add_generation_prompt=True
tokenize=False,
add_generation_prompt=True
) )
except Exception as e: except Exception as e:
logger.warning(f"Chat template failed, using raw prompt: {e}") logger.warning(f"Chat template failed, using raw prompt: {e}")
@@ -581,7 +617,7 @@ class HFChat(LLMInterface):
return_tensors="pt", return_tensors="pt",
padding=True, padding=True,
truncation=True, truncation=True,
max_length=2048 max_length=2048,
) )
# Move inputs to device # Move inputs to device
@@ -607,13 +643,10 @@ class HFChat(LLMInterface):
# Generate # Generate
with torch.no_grad(): with torch.no_grad():
outputs = self.model.generate( outputs = self.model.generate(**inputs, **generation_config)
**inputs,
**generation_config
)
# Decode response # Decode response
generated_tokens = outputs[0][inputs["input_ids"].shape[1]:] generated_tokens = outputs[0][inputs["input_ids"].shape[1] :]
response = self.tokenizer.decode(generated_tokens, skip_special_tokens=True) response = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
return response.strip() return response.strip()
@@ -622,7 +655,7 @@ class HFChat(LLMInterface):
class OpenAIChat(LLMInterface): class OpenAIChat(LLMInterface):
"""LLM interface for OpenAI models.""" """LLM interface for OpenAI models."""
def __init__(self, model: str = "gpt-4o", api_key: Optional[str] = None): def __init__(self, model: str = "gpt-4o", api_key: str | None = None):
self.model = model self.model = model
self.api_key = api_key or os.getenv("OPENAI_API_KEY") self.api_key = api_key or os.getenv("OPENAI_API_KEY")
@@ -649,11 +682,7 @@ class OpenAIChat(LLMInterface):
"messages": [{"role": "user", "content": prompt}], "messages": [{"role": "user", "content": prompt}],
"max_tokens": kwargs.get("max_tokens", 1000), "max_tokens": kwargs.get("max_tokens", 1000),
"temperature": kwargs.get("temperature", 0.7), "temperature": kwargs.get("temperature", 0.7),
**{ **{k: v for k, v in kwargs.items() if k not in ["max_tokens", "temperature"]},
k: v
for k, v in kwargs.items()
if k not in ["max_tokens", "temperature"]
},
} }
logger.info(f"Sending request to OpenAI with model {self.model}") logger.info(f"Sending request to OpenAI with model {self.model}")
@@ -675,7 +704,7 @@ class SimulatedChat(LLMInterface):
return "This is a simulated answer from the LLM based on the retrieved context." return "This is a simulated answer from the LLM based on the retrieved context."
def get_llm(llm_config: Optional[Dict[str, Any]] = None) -> LLMInterface: def get_llm(llm_config: dict[str, Any] | None = None) -> LLMInterface:
""" """
Factory function to get an LLM interface based on configuration. Factory function to get an LLM interface based on configuration.

View File

@@ -5,7 +5,38 @@ from pathlib import Path
from llama_index.core import SimpleDirectoryReader from llama_index.core import SimpleDirectoryReader
from llama_index.core.node_parser import SentenceSplitter from llama_index.core.node_parser import SentenceSplitter
from .api import LeannBuilder, LeannSearcher, LeannChat from .api import LeannBuilder, LeannChat, LeannSearcher
def extract_pdf_text_with_pymupdf(file_path: str) -> str:
"""Extract text from PDF using PyMuPDF for better quality."""
try:
import fitz # PyMuPDF
doc = fitz.open(file_path)
text = ""
for page in doc:
text += page.get_text()
doc.close()
return text
except ImportError:
# Fallback to default reader
return None
def extract_pdf_text_with_pdfplumber(file_path: str) -> str:
"""Extract text from PDF using pdfplumber for better quality."""
try:
import pdfplumber
text = ""
with pdfplumber.open(file_path) as pdf:
for page in pdf.pages:
text += page.extract_text() or ""
return text
except ImportError:
# Fallback to default reader
return None
class LeannCLI: class LeannCLI:
@@ -45,18 +76,12 @@ Examples:
# Build command # Build command
build_parser = subparsers.add_parser("build", help="Build document index") build_parser = subparsers.add_parser("build", help="Build document index")
build_parser.add_argument("index_name", help="Index name") build_parser.add_argument("index_name", help="Index name")
build_parser.add_argument( build_parser.add_argument("--docs", type=str, required=True, help="Documents directory")
"--docs", type=str, required=True, help="Documents directory"
)
build_parser.add_argument( build_parser.add_argument(
"--backend", type=str, default="hnsw", choices=["hnsw", "diskann"] "--backend", type=str, default="hnsw", choices=["hnsw", "diskann"]
) )
build_parser.add_argument( build_parser.add_argument("--embedding-model", type=str, default="facebook/contriever")
"--embedding-model", type=str, default="facebook/contriever" build_parser.add_argument("--force", "-f", action="store_true", help="Force rebuild")
)
build_parser.add_argument(
"--force", "-f", action="store_true", help="Force rebuild"
)
build_parser.add_argument("--graph-degree", type=int, default=32) build_parser.add_argument("--graph-degree", type=int, default=32)
build_parser.add_argument("--complexity", type=int, default=64) build_parser.add_argument("--complexity", type=int, default=64)
build_parser.add_argument("--num-threads", type=int, default=1) build_parser.add_argument("--num-threads", type=int, default=1)
@@ -102,7 +127,7 @@ Examples:
) )
# List command # List command
list_parser = subparsers.add_parser("list", help="List all indexes") subparsers.add_parser("list", help="List all indexes")
return parser return parser
@@ -110,17 +135,13 @@ Examples:
print("Stored LEANN indexes:") print("Stored LEANN indexes:")
if not self.indexes_dir.exists(): if not self.indexes_dir.exists():
print( print("No indexes found. Use 'leann build <name> --docs <dir>' to create one.")
"No indexes found. Use 'leann build <name> --docs <dir>' to create one."
)
return return
index_dirs = [d for d in self.indexes_dir.iterdir() if d.is_dir()] index_dirs = [d for d in self.indexes_dir.iterdir() if d.is_dir()]
if not index_dirs: if not index_dirs:
print( print("No indexes found. Use 'leann build <name> --docs <dir>' to create one.")
"No indexes found. Use 'leann build <name> --docs <dir>' to create one."
)
return return
print(f"Found {len(index_dirs)} indexes:") print(f"Found {len(index_dirs)} indexes:")
@@ -130,27 +151,58 @@ Examples:
print(f" {i}. {index_name} [{status}]") print(f" {i}. {index_name} [{status}]")
if self.index_exists(index_name): if self.index_exists(index_name):
meta_file = index_dir / "documents.leann.meta.json" index_dir / "documents.leann.meta.json"
size_mb = sum( size_mb = sum(f.stat().st_size for f in index_dir.iterdir() if f.is_file()) / (
f.stat().st_size for f in index_dir.iterdir() if f.is_file() 1024 * 1024
) / (1024 * 1024) )
print(f" Size: {size_mb:.1f} MB") print(f" Size: {size_mb:.1f} MB")
if index_dirs: if index_dirs:
example_name = index_dirs[0].name example_name = index_dirs[0].name
print(f"\nUsage:") print("\nUsage:")
print(f' leann search {example_name} "your query"') print(f' leann search {example_name} "your query"')
print(f" leann ask {example_name} --interactive") print(f" leann ask {example_name} --interactive")
def load_documents(self, docs_dir: str): def load_documents(self, docs_dir: str):
print(f"Loading documents from {docs_dir}...") print(f"Loading documents from {docs_dir}...")
documents = SimpleDirectoryReader( # Try to use better PDF parsers first
documents = []
docs_path = Path(docs_dir)
for file_path in docs_path.rglob("*.pdf"):
print(f"Processing PDF: {file_path}")
# Try PyMuPDF first (best quality)
text = extract_pdf_text_with_pymupdf(str(file_path))
if text is None:
# Try pdfplumber
text = extract_pdf_text_with_pdfplumber(str(file_path))
if text:
# Create a simple document structure
from llama_index.core import Document
doc = Document(text=text, metadata={"source": str(file_path)})
documents.append(doc)
else:
# Fallback to default reader
print(f"Using default reader for {file_path}")
default_docs = SimpleDirectoryReader(
str(file_path.parent),
filename_as_id=True,
required_exts=[file_path.suffix],
).load_data()
documents.extend(default_docs)
# Load other file types with default reader
other_docs = SimpleDirectoryReader(
docs_dir, docs_dir,
recursive=True, recursive=True,
encoding="utf-8", encoding="utf-8",
required_exts=[".pdf", ".txt", ".md", ".docx"], required_exts=[".txt", ".md", ".docx"],
).load_data(show_progress=True) ).load_data(show_progress=True)
documents.extend(other_docs)
all_texts = [] all_texts = []
for doc in documents: for doc in documents:

View File

@@ -4,11 +4,12 @@ Consolidates all embedding computation logic using SentenceTransformer
Preserves all optimization parameters to ensure performance Preserves all optimization parameters to ensure performance
""" """
import numpy as np
import torch
from typing import List, Dict, Any
import logging import logging
import os import os
from typing import Any
import numpy as np
import torch
# Set up logger with proper level # Set up logger with proper level
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -17,11 +18,11 @@ log_level = getattr(logging, LOG_LEVEL, logging.WARNING)
logger.setLevel(log_level) logger.setLevel(log_level)
# Global model cache to avoid repeated loading # Global model cache to avoid repeated loading
_model_cache: Dict[str, Any] = {} _model_cache: dict[str, Any] = {}
def compute_embeddings( def compute_embeddings(
texts: List[str], texts: list[str],
model_name: str, model_name: str,
mode: str = "sentence-transformers", mode: str = "sentence-transformers",
is_build: bool = False, is_build: bool = False,
@@ -59,7 +60,7 @@ def compute_embeddings(
def compute_embeddings_sentence_transformers( def compute_embeddings_sentence_transformers(
texts: List[str], texts: list[str],
model_name: str, model_name: str,
use_fp16: bool = True, use_fp16: bool = True,
device: str = "auto", device: str = "auto",
@@ -114,9 +115,7 @@ def compute_embeddings_sentence_transformers(
logger.info(f"Using cached optimized model: {model_name}") logger.info(f"Using cached optimized model: {model_name}")
model = _model_cache[cache_key] model = _model_cache[cache_key]
else: else:
logger.info( logger.info(f"Loading and caching optimized SentenceTransformer model: {model_name}")
f"Loading and caching optimized SentenceTransformer model: {model_name}"
)
from sentence_transformers import SentenceTransformer from sentence_transformers import SentenceTransformer
logger.info(f"Using device: {device}") logger.info(f"Using device: {device}")
@@ -134,9 +133,7 @@ def compute_embeddings_sentence_transformers(
if hasattr(torch.mps, "set_per_process_memory_fraction"): if hasattr(torch.mps, "set_per_process_memory_fraction"):
torch.mps.set_per_process_memory_fraction(0.9) torch.mps.set_per_process_memory_fraction(0.9)
except AttributeError: except AttributeError:
logger.warning( logger.warning("Some MPS optimizations not available in this PyTorch version")
"Some MPS optimizations not available in this PyTorch version"
)
elif device == "cpu": elif device == "cpu":
# TODO: Haven't tested this yet # TODO: Haven't tested this yet
torch.set_num_threads(min(8, os.cpu_count() or 4)) torch.set_num_threads(min(8, os.cpu_count() or 4))
@@ -226,25 +223,22 @@ def compute_embeddings_sentence_transformers(
device=device, device=device,
) )
logger.info( logger.info(f"Generated {len(embeddings)} embeddings, dimension: {embeddings.shape[1]}")
f"Generated {len(embeddings)} embeddings, dimension: {embeddings.shape[1]}"
)
# Validate results # Validate results
if np.isnan(embeddings).any() or np.isinf(embeddings).any(): if np.isnan(embeddings).any() or np.isinf(embeddings).any():
raise RuntimeError( raise RuntimeError(f"Detected NaN or Inf values in embeddings, model: {model_name}")
f"Detected NaN or Inf values in embeddings, model: {model_name}"
)
return embeddings return embeddings
def compute_embeddings_openai(texts: List[str], model_name: str) -> np.ndarray: def compute_embeddings_openai(texts: list[str], model_name: str) -> np.ndarray:
# TODO: @yichuan-w add progress bar only in build mode # TODO: @yichuan-w add progress bar only in build mode
"""Compute embeddings using OpenAI API""" """Compute embeddings using OpenAI API"""
try: try:
import openai
import os import os
import openai
except ImportError as e: except ImportError as e:
raise ImportError(f"OpenAI package not installed: {e}") raise ImportError(f"OpenAI package not installed: {e}")
@@ -264,9 +258,10 @@ def compute_embeddings_openai(texts: List[str], model_name: str) -> np.ndarray:
logger.info( logger.info(
f"Computing embeddings for {len(texts)} texts using OpenAI API, model: '{model_name}'" f"Computing embeddings for {len(texts)} texts using OpenAI API, model: '{model_name}'"
) )
print(f"len of texts: {len(texts)}")
# OpenAI has limits on batch size and input length # OpenAI has limits on batch size and input length
max_batch_size = 100 # Conservative batch size max_batch_size = 1000 # Conservative batch size
all_embeddings = [] all_embeddings = []
try: try:
@@ -293,15 +288,12 @@ def compute_embeddings_openai(texts: List[str], model_name: str) -> np.ndarray:
raise raise
embeddings = np.array(all_embeddings, dtype=np.float32) embeddings = np.array(all_embeddings, dtype=np.float32)
logger.info( logger.info(f"Generated {len(embeddings)} embeddings, dimension: {embeddings.shape[1]}")
f"Generated {len(embeddings)} embeddings, dimension: {embeddings.shape[1]}" print(f"len of embeddings: {len(embeddings)}")
)
return embeddings return embeddings
def compute_embeddings_mlx( def compute_embeddings_mlx(chunks: list[str], model_name: str, batch_size: int = 16) -> np.ndarray:
chunks: List[str], model_name: str, batch_size: int = 16
) -> np.ndarray:
# TODO: @yichuan-w add progress bar only in build mode # TODO: @yichuan-w add progress bar only in build mode
"""Computes embeddings using an MLX model.""" """Computes embeddings using an MLX model."""
try: try:

View File

@@ -1,12 +1,12 @@
import time
import atexit import atexit
import logging
import os
import socket import socket
import subprocess import subprocess
import sys import sys
import os import time
import logging
from pathlib import Path from pathlib import Path
from typing import Optional
import psutil import psutil
# Set up logging based on environment variable # Set up logging based on environment variable
@@ -18,6 +18,24 @@ logging.basicConfig(
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def _is_colab_environment() -> bool:
"""Check if we're running in Google Colab environment."""
return "COLAB_GPU" in os.environ or "COLAB_TPU" in os.environ
def _get_available_port(start_port: int = 5557) -> int:
"""Get an available port starting from start_port."""
port = start_port
while port < start_port + 100: # Try up to 100 ports
try:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("localhost", port))
return port
except OSError:
port += 1
raise RuntimeError(f"No available ports found in range {start_port}-{start_port + 100}")
def _check_port(port: int) -> bool: def _check_port(port: int) -> bool:
"""Check if a port is in use""" """Check if a port is in use"""
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
@@ -164,8 +182,8 @@ class EmbeddingServerManager:
e.g., "leann_backend_diskann.embedding_server" e.g., "leann_backend_diskann.embedding_server"
""" """
self.backend_module_name = backend_module_name self.backend_module_name = backend_module_name
self.server_process: Optional[subprocess.Popen] = None self.server_process: subprocess.Popen | None = None
self.server_port: Optional[int] = None self.server_port: int | None = None
self._atexit_registered = False self._atexit_registered = False
def start_server( def start_server(
@@ -175,68 +193,69 @@ class EmbeddingServerManager:
embedding_mode: str = "sentence-transformers", embedding_mode: str = "sentence-transformers",
**kwargs, **kwargs,
) -> tuple[bool, int]: ) -> tuple[bool, int]:
""" """Start the embedding server."""
Starts the embedding server process.
Args:
port (int): The preferred ZMQ port for the server.
model_name (str): The name of the embedding model to use.
**kwargs: Additional arguments for the server.
Returns:
tuple[bool, int]: (success, actual_port_used)
"""
passages_file = kwargs.get("passages_file") passages_file = kwargs.get("passages_file")
assert isinstance(passages_file, str), "passages_file must be a string"
# Check if we have a compatible running server # Check if we have a compatible server already running
if self._has_compatible_running_server(model_name, passages_file): if self._has_compatible_running_server(model_name, passages_file):
assert self.server_port is not None, ( logger.info("Found compatible running server!")
"a compatible running server should set server_port" return True, port
)
return True, self.server_port
# Find available port (compatible or free) # For Colab environment, use a different strategy
try: if _is_colab_environment():
actual_port, is_compatible = _find_compatible_port_or_next_available( logger.info("Detected Colab environment, using alternative startup strategy")
port, model_name, passages_file return self._start_server_colab(port, model_name, embedding_mode, **kwargs)
)
except RuntimeError as e: # Find a compatible port or next available
logger.error(str(e)) actual_port, is_compatible = _find_compatible_port_or_next_available(
return False, port port, model_name, passages_file
)
if is_compatible: if is_compatible:
logger.info(f"Using existing compatible server on port {actual_port}") logger.info(f"Found compatible server on port {actual_port}")
self.server_port = actual_port
self.server_process = None # We don't own this process
return True, actual_port return True, actual_port
if actual_port != port: # Start a new server
logger.info(f"Using port {actual_port} instead of {port}")
# Start new server
return self._start_new_server(actual_port, model_name, embedding_mode, **kwargs) return self._start_new_server(actual_port, model_name, embedding_mode, **kwargs)
def _has_compatible_running_server( def _start_server_colab(
self, model_name: str, passages_file: str self,
) -> bool: port: int,
model_name: str,
embedding_mode: str = "sentence-transformers",
**kwargs,
) -> tuple[bool, int]:
"""Start server with Colab-specific configuration."""
# Try to find an available port
try:
actual_port = _get_available_port(port)
except RuntimeError:
logger.error("No available ports found")
return False, port
logger.info(f"Starting server on port {actual_port} for Colab environment")
# Use a simpler startup strategy for Colab
command = self._build_server_command(actual_port, model_name, embedding_mode, **kwargs)
try:
# In Colab, we'll use a more direct approach
self._launch_server_process_colab(command, actual_port)
return self._wait_for_server_ready_colab(actual_port)
except Exception as e:
logger.error(f"Failed to start embedding server in Colab: {e}")
return False, actual_port
def _has_compatible_running_server(self, model_name: str, passages_file: str) -> bool:
"""Check if we have a compatible running server.""" """Check if we have a compatible running server."""
if not ( if not (self.server_process and self.server_process.poll() is None and self.server_port):
self.server_process
and self.server_process.poll() is None
and self.server_port
):
return False return False
if _check_process_matches_config(self.server_port, model_name, passages_file): if _check_process_matches_config(self.server_port, model_name, passages_file):
logger.info( logger.info(f"Existing server process (PID {self.server_process.pid}) is compatible")
f"Existing server process (PID {self.server_process.pid}) is compatible"
)
return True return True
logger.info( logger.info("Existing server process is incompatible. Should start a new server.")
"Existing server process is incompatible. Should start a new server."
)
return False return False
def _start_new_server( def _start_new_server(
@@ -274,6 +293,8 @@ class EmbeddingServerManager:
command.extend(["--passages-file", str(passages_file)]) command.extend(["--passages-file", str(passages_file)])
if embedding_mode != "sentence-transformers": if embedding_mode != "sentence-transformers":
command.extend(["--embedding-mode", embedding_mode]) command.extend(["--embedding-mode", embedding_mode])
if kwargs.get("distance_metric"):
command.extend(["--distance-metric", kwargs["distance_metric"]])
return command return command
@@ -333,13 +354,21 @@ class EmbeddingServerManager:
self.server_process.terminate() self.server_process.terminate()
try: try:
self.server_process.wait(timeout=5) self.server_process.wait(timeout=3)
logger.info(f"Server process {self.server_process.pid} terminated.") logger.info(f"Server process {self.server_process.pid} terminated.")
except subprocess.TimeoutExpired: except subprocess.TimeoutExpired:
logger.warning( logger.warning(
f"Server process {self.server_process.pid} did not terminate gracefully, killing it." f"Server process {self.server_process.pid} did not terminate gracefully within 3 seconds, killing it."
) )
self.server_process.kill() self.server_process.kill()
try:
self.server_process.wait(timeout=2)
logger.info(f"Server process {self.server_process.pid} killed successfully.")
except subprocess.TimeoutExpired:
logger.error(
f"Failed to kill server process {self.server_process.pid} - it may be hung"
)
# Don't hang indefinitely
# Clean up process resources to prevent resource tracker warnings # Clean up process resources to prevent resource tracker warnings
try: try:
@@ -348,3 +377,45 @@ class EmbeddingServerManager:
pass pass
self.server_process = None self.server_process = None
def _launch_server_process_colab(self, command: list, port: int) -> None:
"""Launch the server process with Colab-specific settings."""
logger.info(f"Colab Command: {' '.join(command)}")
# In Colab, we need to be more careful about process management
self.server_process = subprocess.Popen(
command,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
)
self.server_port = port
logger.info(f"Colab server process started with PID: {self.server_process.pid}")
# Register atexit callback
if not self._atexit_registered:
atexit.register(lambda: self.stop_server() if self.server_process else None)
self._atexit_registered = True
def _wait_for_server_ready_colab(self, port: int) -> tuple[bool, int]:
"""Wait for the server to be ready with Colab-specific timeout."""
max_wait, wait_interval = 30, 0.5 # Shorter timeout for Colab
for _ in range(int(max_wait / wait_interval)):
if _check_port(port):
logger.info("Colab embedding server is ready!")
return True, port
if self.server_process and self.server_process.poll() is not None:
# Check for error output
stdout, stderr = self.server_process.communicate()
logger.error("Colab server terminated during startup.")
logger.error(f"stdout: {stdout}")
logger.error(f"stderr: {stderr}")
return False, port
time.sleep(wait_interval)
logger.error(f"Colab server failed to start within {max_wait} seconds.")
self.stop_server()
return False, port

View File

@@ -1,15 +1,14 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, Literal
import numpy as np import numpy as np
from typing import Dict, Any, List, Literal, Optional
class LeannBackendBuilderInterface(ABC): class LeannBackendBuilderInterface(ABC):
"""Backend interface for building indexes""" """Backend interface for building indexes"""
@abstractmethod @abstractmethod
def build( def build(self, data: np.ndarray, ids: list[str], index_path: str, **kwargs) -> None:
self, data: np.ndarray, ids: List[str], index_path: str, **kwargs
) -> None:
"""Build index """Build index
Args: Args:
@@ -35,9 +34,7 @@ class LeannBackendSearcherInterface(ABC):
pass pass
@abstractmethod @abstractmethod
def _ensure_server_running( def _ensure_server_running(self, passages_source_file: str, port: int | None, **kwargs) -> int:
self, passages_source_file: str, port: Optional[int], **kwargs
) -> int:
"""Ensure server is running""" """Ensure server is running"""
pass pass
@@ -51,9 +48,9 @@ class LeannBackendSearcherInterface(ABC):
prune_ratio: float = 0.0, prune_ratio: float = 0.0,
recompute_embeddings: bool = False, recompute_embeddings: bool = False,
pruning_strategy: Literal["global", "local", "proportional"] = "global", pruning_strategy: Literal["global", "local", "proportional"] = "global",
zmq_port: Optional[int] = None, zmq_port: int | None = None,
**kwargs, **kwargs,
) -> Dict[str, Any]: ) -> dict[str, Any]:
"""Search for nearest neighbors """Search for nearest neighbors
Args: Args:
@@ -77,7 +74,7 @@ class LeannBackendSearcherInterface(ABC):
self, self,
query: str, query: str,
use_server_if_available: bool = True, use_server_if_available: bool = True,
zmq_port: Optional[int] = None, zmq_port: int | None = None,
) -> np.ndarray: ) -> np.ndarray:
"""Compute embedding for a query string """Compute embedding for a query string

View File

@@ -1,13 +1,13 @@
# packages/leann-core/src/leann/registry.py # packages/leann-core/src/leann/registry.py
from typing import Dict, TYPE_CHECKING
import importlib import importlib
import importlib.metadata import importlib.metadata
from typing import TYPE_CHECKING
if TYPE_CHECKING: if TYPE_CHECKING:
from leann.interface import LeannBackendFactoryInterface from leann.interface import LeannBackendFactoryInterface
BACKEND_REGISTRY: Dict[str, "LeannBackendFactoryInterface"] = {} BACKEND_REGISTRY: dict[str, "LeannBackendFactoryInterface"] = {}
def register_backend(name: str): def register_backend(name: str):
@@ -31,13 +31,11 @@ def autodiscover_backends():
backend_module_name = dist_name.replace("-", "_") backend_module_name = dist_name.replace("-", "_")
discovered_backends.append(backend_module_name) discovered_backends.append(backend_module_name)
for backend_module_name in sorted( for backend_module_name in sorted(discovered_backends): # sort for deterministic loading
discovered_backends
): # sort for deterministic loading
try: try:
importlib.import_module(backend_module_name) importlib.import_module(backend_module_name)
# Registration message is printed by the decorator # Registration message is printed by the decorator
except ImportError as e: except ImportError:
# print(f"WARN: Could not import backend module '{backend_module_name}': {e}") # print(f"WARN: Could not import backend module '{backend_module_name}': {e}")
pass pass
# print("INFO: Backend auto-discovery finished.") # print("INFO: Backend auto-discovery finished.")

View File

@@ -1,7 +1,7 @@
import json import json
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from pathlib import Path from pathlib import Path
from typing import Dict, Any, Literal, Optional from typing import Any, Literal
import numpy as np import numpy as np
@@ -38,9 +38,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
self.embedding_model = self.meta.get("embedding_model") self.embedding_model = self.meta.get("embedding_model")
if not self.embedding_model: if not self.embedding_model:
print( print("WARNING: embedding_model not found in meta.json. Recompute will fail.")
"WARNING: embedding_model not found in meta.json. Recompute will fail."
)
self.embedding_mode = self.meta.get("embedding_mode", "sentence-transformers") self.embedding_mode = self.meta.get("embedding_mode", "sentence-transformers")
@@ -48,39 +46,40 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
backend_module_name=backend_module_name, backend_module_name=backend_module_name,
) )
def _load_meta(self) -> Dict[str, Any]: def _load_meta(self) -> dict[str, Any]:
"""Loads the metadata file associated with the index.""" """Loads the metadata file associated with the index."""
# This is the corrected logic for finding the meta file. # This is the corrected logic for finding the meta file.
meta_path = self.index_dir / f"{self.index_path.name}.meta.json" meta_path = self.index_dir / f"{self.index_path.name}.meta.json"
if not meta_path.exists(): if not meta_path.exists():
raise FileNotFoundError(f"Leann metadata file not found at {meta_path}") raise FileNotFoundError(f"Leann metadata file not found at {meta_path}")
with open(meta_path, "r", encoding="utf-8") as f: with open(meta_path, encoding="utf-8") as f:
return json.load(f) return json.load(f)
def _ensure_server_running( def _ensure_server_running(self, passages_source_file: str, port: int, **kwargs) -> int:
self, passages_source_file: str, port: int, **kwargs
) -> int:
""" """
Ensures the embedding server is running if recompute is needed. Ensures the embedding server is running if recompute is needed.
This is a helper for subclasses. This is a helper for subclasses.
""" """
if not self.embedding_model: if not self.embedding_model:
raise ValueError( raise ValueError("Cannot use recompute mode without 'embedding_model' in meta.json.")
"Cannot use recompute mode without 'embedding_model' in meta.json."
) # Get distance_metric from meta if not provided in kwargs
distance_metric = (
kwargs.get("distance_metric")
or self.meta.get("backend_kwargs", {}).get("distance_metric")
or "mips"
)
server_started, actual_port = self.embedding_server_manager.start_server( server_started, actual_port = self.embedding_server_manager.start_server(
port=port, port=port,
model_name=self.embedding_model, model_name=self.embedding_model,
embedding_mode=self.embedding_mode, embedding_mode=self.embedding_mode,
passages_file=passages_source_file, passages_file=passages_source_file,
distance_metric=kwargs.get("distance_metric"), distance_metric=distance_metric,
enable_warmup=kwargs.get("enable_warmup", False), enable_warmup=kwargs.get("enable_warmup", False),
) )
if not server_started: if not server_started:
raise RuntimeError( raise RuntimeError(f"Failed to start embedding server on port {actual_port}")
f"Failed to start embedding server on port {actual_port}"
)
return actual_port return actual_port
@@ -109,9 +108,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
# on that port? # on that port?
# Ensure we have a server with passages_file for compatibility # Ensure we have a server with passages_file for compatibility
passages_source_file = ( passages_source_file = self.index_dir / f"{self.index_path.name}.meta.json"
self.index_dir / f"{self.index_path.name}.meta.json"
)
# Convert to absolute path to ensure server can find it # Convert to absolute path to ensure server can find it
zmq_port = self._ensure_server_running( zmq_port = self._ensure_server_running(
str(passages_source_file.resolve()), zmq_port str(passages_source_file.resolve()), zmq_port
@@ -132,8 +129,8 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
def _compute_embedding_via_server(self, chunks: list, zmq_port: int) -> np.ndarray: def _compute_embedding_via_server(self, chunks: list, zmq_port: int) -> np.ndarray:
"""Compute embeddings using the ZMQ embedding server.""" """Compute embeddings using the ZMQ embedding server."""
import zmq
import msgpack import msgpack
import zmq
try: try:
context = zmq.Context() context = zmq.Context()
@@ -172,9 +169,9 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
prune_ratio: float = 0.0, prune_ratio: float = 0.0,
recompute_embeddings: bool = False, recompute_embeddings: bool = False,
pruning_strategy: Literal["global", "local", "proportional"] = "global", pruning_strategy: Literal["global", "local", "proportional"] = "global",
zmq_port: Optional[int] = None, zmq_port: int | None = None,
**kwargs, **kwargs,
) -> Dict[str, Any]: ) -> dict[str, Any]:
""" """
Search for the top_k nearest neighbors of the query vector. Search for the top_k nearest neighbors of the query vector.

View File

@@ -5,36 +5,32 @@ LEANN is a revolutionary vector database that democratizes personal AI. Transfor
## Installation ## Installation
```bash ```bash
# Default installation (HNSW backend, recommended) # Default installation (includes both HNSW and DiskANN backends)
uv pip install leann uv pip install leann
# With DiskANN backend (for large-scale deployments)
uv pip install leann[diskann]
``` ```
## Quick Start ## Quick Start
```python ```python
from leann import LeannBuilder, LeannSearcher, LeannChat from leann import LeannBuilder, LeannSearcher, LeannChat
from pathlib import Path
INDEX_PATH = str(Path("./").resolve() / "demo.leann")
# Build an index # Build an index (choose backend: "hnsw" or "diskann")
builder = LeannBuilder(backend_name="hnsw") builder = LeannBuilder(backend_name="hnsw") # or "diskann" for large-scale deployments
builder.add_text("LEANN saves 97% storage compared to traditional vector databases.") builder.add_text("LEANN saves 97% storage compared to traditional vector databases.")
builder.build_index("my_index.leann") builder.add_text("Tung Tung Tung Sahur called—they need their bananacrocodile hybrid back")
builder.build_index(INDEX_PATH)
# Search # Search
searcher = LeannSearcher("my_index.leann") searcher = LeannSearcher(INDEX_PATH)
results = searcher.search("storage savings", top_k=3) results = searcher.search("fantastical AI-generated creatures", top_k=1)
# Chat with your data # Chat with your data
chat = LeannChat("my_index.leann", llm_config={"type": "ollama", "model": "llama3.2:1b"}) chat = LeannChat(INDEX_PATH, llm_config={"type": "hf", "model": "Qwen/Qwen3-0.6B"})
response = chat.ask("How much storage does LEANN save?") response = chat.ask("How much storage does LEANN save?", top_k=1)
``` ```
## Documentation
For full documentation, visit [https://leann.readthedocs.io](https://leann.readthedocs.io)
## License ## License
MIT License MIT License

View File

@@ -7,6 +7,6 @@ A revolutionary vector database that democratizes personal AI.
__version__ = "0.1.0" __version__ = "0.1.0"
# Re-export main API from leann-core # Re-export main API from leann-core
from leann_core import LeannBuilder, LeannSearcher, LeannChat from leann_core import LeannBuilder, LeannChat, LeannSearcher
__all__ = ["LeannBuilder", "LeannSearcher", "LeannChat"] __all__ = ["LeannBuilder", "LeannChat", "LeannSearcher"]

View File

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project] [project]
name = "leann" name = "leann"
version = "0.1.9" version = "0.1.16"
description = "LEANN - The smallest vector index in the world. RAG Everything with LEANN!" description = "LEANN - The smallest vector index in the world. RAG Everything with LEANN!"
readme = "README.md" readme = "README.md"
requires-python = ">=3.9" requires-python = ">=3.9"
@@ -24,19 +24,16 @@ classifiers = [
"Programming Language :: Python :: 3.12", "Programming Language :: Python :: 3.12",
] ]
# Default installation: core + hnsw # Default installation: core + hnsw + diskann
dependencies = [ dependencies = [
"leann-core>=0.1.0", "leann-core>=0.1.0",
"leann-backend-hnsw>=0.1.0", "leann-backend-hnsw>=0.1.0",
]
[project.optional-dependencies]
diskann = [
"leann-backend-diskann>=0.1.0", "leann-backend-diskann>=0.1.0",
] ]
[project.optional-dependencies]
# All backends now included by default
[project.urls] [project.urls]
Homepage = "https://github.com/yourusername/leann" Repository = "https://github.com/yichuan-w/LEANN"
Documentation = "https://leann.readthedocs.io" Issues = "https://github.com/yichuan-w/LEANN/issues"
Repository = "https://github.com/yourusername/leann"
Issues = "https://github.com/yourusername/leann/issues"

View File

@@ -1,22 +1,23 @@
import json import json
import typer
from pathlib import Path
import requests
from tqdm import tqdm
import xml.etree.ElementTree as ET
from typing_extensions import Annotated
import sqlite3 import sqlite3
import xml.etree.ElementTree as ElementTree
from pathlib import Path
from typing import Annotated
import requests
import typer
from tqdm import tqdm
app = typer.Typer() app = typer.Typer()
def get_safe_path(s: str) -> str: def get_safe_path(s: str) -> str:
""" """
Remove invalid characters to sanitize a path. Remove invalid characters to sanitize a path.
:param s: str to sanitize :param s: str to sanitize
:returns: sanitized str :returns: sanitized str
""" """
ban_chars = "\\ / : * ? \" ' < > | $ \r \n".replace( ban_chars = "\\ / : * ? \" ' < > | $ \r \n".replace(" ", "")
' ', '')
for i in ban_chars: for i in ban_chars:
s = s.replace(i, "") s = s.replace(i, "")
return s return s
@@ -25,36 +26,40 @@ def get_safe_path(s: str) -> str:
def process_history(history: str): def process_history(history: str):
if history.startswith("<?xml") or history.startswith("<msg>"): if history.startswith("<?xml") or history.startswith("<msg>"):
try: try:
root = ET.fromstring(history) root = ElementTree.fromstring(history)
title = root.find('.//title').text if root.find('.//title') is not None else None title = root.find(".//title").text if root.find(".//title") is not None else None
quoted = root.find('.//refermsg/content').text if root.find('.//refermsg/content') is not None else None quoted = (
root.find(".//refermsg/content").text
if root.find(".//refermsg/content") is not None
else None
)
if title and quoted: if title and quoted:
return { return {"title": title, "quoted": process_history(quoted)}
"title": title,
"quoted": process_history(quoted)
}
if title: if title:
return title return title
except Exception: except Exception:
return history return history
return history return history
def get_message(history: dict | str): def get_message(history: dict | str):
if isinstance(history, dict): if isinstance(history, dict):
if 'title' in history: if "title" in history:
return history['title'] return history["title"]
else: else:
return history return history
def export_chathistory(user_id: str): def export_chathistory(user_id: str):
res = requests.get("http://localhost:48065/wechat/chatlog", params={ res = requests.get(
"userId": user_id, "http://localhost:48065/wechat/chatlog",
"count": 100000 params={"userId": user_id, "count": 100000},
}).json() ).json()
for i in range(len(res['chatLogs'])): for i in range(len(res["chatLogs"])):
res['chatLogs'][i]['content'] = process_history(res['chatLogs'][i]['content']) res["chatLogs"][i]["content"] = process_history(res["chatLogs"][i]["content"])
res['chatLogs'][i]['message'] = get_message(res['chatLogs'][i]['content']) res["chatLogs"][i]["message"] = get_message(res["chatLogs"][i]["content"])
return res['chatLogs'] return res["chatLogs"]
@app.command() @app.command()
def export_all(dest: Annotated[Path, typer.Argument(help="Destination path to export to.")]): def export_all(dest: Annotated[Path, typer.Argument(help="Destination path to export to.")]):
@@ -64,7 +69,7 @@ def export_all(dest: Annotated[Path, typer.Argument(help="Destination path to ex
if not dest.is_dir(): if not dest.is_dir():
if not dest.exists(): if not dest.exists():
inp = typer.prompt("Destination path does not exist, create it? (y/n)") inp = typer.prompt("Destination path does not exist, create it? (y/n)")
if inp.lower() == 'y': if inp.lower() == "y":
dest.mkdir(parents=True) dest.mkdir(parents=True)
else: else:
typer.echo("Aborted.", err=True) typer.echo("Aborted.", err=True)
@@ -77,12 +82,12 @@ def export_all(dest: Annotated[Path, typer.Argument(help="Destination path to ex
exported_count = 0 exported_count = 0
for user in tqdm(all_users): for user in tqdm(all_users):
try: try:
usr_chatlog = export_chathistory(user['arg']) usr_chatlog = export_chathistory(user["arg"])
# Only write file if there are messages # Only write file if there are messages
if len(usr_chatlog) > 0: if len(usr_chatlog) > 0:
out_path = dest/get_safe_path((user['title'] or "")+"-"+user['arg']+'.json') out_path = dest / get_safe_path((user["title"] or "") + "-" + user["arg"] + ".json")
with open(out_path, 'w', encoding='utf-8') as f: with open(out_path, "w", encoding="utf-8") as f:
json.dump(usr_chatlog, f, ensure_ascii=False, indent=2) json.dump(usr_chatlog, f, ensure_ascii=False, indent=2)
exported_count += 1 exported_count += 1
except Exception as e: except Exception as e:
@@ -91,23 +96,43 @@ def export_all(dest: Annotated[Path, typer.Argument(help="Destination path to ex
print(f"Exported {exported_count} users' chat history to {dest} in json.") print(f"Exported {exported_count} users' chat history to {dest} in json.")
@app.command() @app.command()
def export_sqlite(dest: Annotated[Path, typer.Argument(help="Destination path to export to.")] = Path("chatlog.db")): def export_sqlite(
dest: Annotated[Path, typer.Argument(help="Destination path to export to.")] = Path(
"chatlog.db"
),
):
""" """
Export all users' chat history to a sqlite database. Export all users' chat history to a sqlite database.
""" """
connection = sqlite3.connect(dest) connection = sqlite3.connect(dest)
cursor = connection.cursor() cursor = connection.cursor()
cursor.execute("CREATE TABLE IF NOT EXISTS chatlog (id INTEGER PRIMARY KEY AUTOINCREMENT, with_id TEXT, from_user TEXT, to_user TEXT, message TEXT, timest DATETIME, auxiliary TEXT)") cursor.execute(
"CREATE TABLE IF NOT EXISTS chatlog (id INTEGER PRIMARY KEY AUTOINCREMENT, with_id TEXT, from_user TEXT, to_user TEXT, message TEXT, timest DATETIME, auxiliary TEXT)"
)
cursor.execute("CREATE INDEX IF NOT EXISTS chatlog_with_id_index ON chatlog (with_id)") cursor.execute("CREATE INDEX IF NOT EXISTS chatlog_with_id_index ON chatlog (with_id)")
cursor.execute("CREATE TABLE iF NOT EXISTS users (id TEXT PRIMARY KEY, name TEXT)") cursor.execute("CREATE TABLE iF NOT EXISTS users (id TEXT PRIMARY KEY, name TEXT)")
all_users = requests.get("http://localhost:48065/wechat/allcontacts").json() all_users = requests.get("http://localhost:48065/wechat/allcontacts").json()
for user in tqdm(all_users): for user in tqdm(all_users):
cursor.execute("INSERT OR IGNORE INTO users (id, name) VALUES (?, ?)", (user['arg'], user['title'])) cursor.execute(
usr_chatlog = export_chathistory(user['arg']) "INSERT OR IGNORE INTO users (id, name) VALUES (?, ?)",
(user["arg"], user["title"]),
)
usr_chatlog = export_chathistory(user["arg"])
for msg in usr_chatlog: for msg in usr_chatlog:
cursor.execute("INSERT INTO chatlog (with_id, from_user, to_user, message, timest, auxiliary) VALUES (?, ?, ?, ?, ?, ?)", (user['arg'], msg['fromUser'], msg['toUser'], msg['message'], msg['createTime'], str(msg['content']))) cursor.execute(
"INSERT INTO chatlog (with_id, from_user, to_user, message, timest, auxiliary) VALUES (?, ?, ?, ?, ?, ?)",
(
user["arg"],
msg["fromUser"],
msg["toUser"],
msg["message"],
msg["createTime"],
str(msg["content"]),
),
)
connection.commit() connection.commit()

View File

@@ -5,7 +5,7 @@ build-backend = "setuptools.build_meta"
[project] [project]
name = "leann-workspace" name = "leann-workspace"
version = "0.1.0" version = "0.1.0"
requires-python = ">=3.10" requires-python = ">=3.9"
dependencies = [ dependencies = [
"leann-core", "leann-core",
@@ -25,14 +25,21 @@ dependencies = [
"requests>=2.25.0", "requests>=2.25.0",
"sentence-transformers>=2.2.0", "sentence-transformers>=2.2.0",
"openai>=1.0.0", "openai>=1.0.0",
# PDF parsing dependencies - essential for document processing
"PyPDF2>=3.0.0", "PyPDF2>=3.0.0",
"pdfplumber>=0.11.0",
"pymupdf>=1.26.0",
"pypdfium2>=4.30.0",
# LlamaIndex core and readers - updated versions
"llama-index>=0.12.44", "llama-index>=0.12.44",
"llama-index-readers-docling", "llama-index-readers-file>=0.4.0", # Essential for PDF parsing
"llama-index-node-parser-docling", # "llama-index-readers-docling", # Requires Python >= 3.10
"ipykernel==6.29.5", # "llama-index-node-parser-docling", # Requires Python >= 3.10
"msgpack>=1.1.1",
"llama-index-vector-stores-faiss>=0.4.0", "llama-index-vector-stores-faiss>=0.4.0",
"llama-index-embeddings-huggingface>=0.5.5", "llama-index-embeddings-huggingface>=0.5.5",
# Other dependencies
"ipykernel==6.29.5",
"msgpack>=1.1.1",
"mlx>=0.26.3; sys_platform == 'darwin'", "mlx>=0.26.3; sys_platform == 'darwin'",
"mlx-lm>=0.26.0; sys_platform == 'darwin'", "mlx-lm>=0.26.0; sys_platform == 'darwin'",
"psutil>=5.8.0", "psutil>=5.8.0",
@@ -42,16 +49,35 @@ dependencies = [
dev = [ dev = [
"pytest>=7.0", "pytest>=7.0",
"pytest-cov>=4.0", "pytest-cov>=4.0",
"pytest-xdist>=3.0", # For parallel test execution
"black>=23.0", "black>=23.0",
"ruff>=0.1.0", "ruff>=0.1.0",
"matplotlib", "matplotlib",
"huggingface-hub>=0.20.0", "huggingface-hub>=0.20.0",
"pre-commit>=3.5.0",
]
test = [
"pytest>=7.0",
"pytest-timeout>=2.0",
"llama-index-core>=0.12.0",
"llama-index-readers-file>=0.4.0",
"python-dotenv>=1.0.0",
"sentence-transformers>=2.2.0",
] ]
diskann = [ diskann = [
"leann-backend-diskann", "leann-backend-diskann",
] ]
# Add a new optional dependency group for document processing
documents = [
"beautifulsoup4>=4.13.0", # For HTML parsing
"python-docx>=0.8.11", # For Word documents
"openpyxl>=3.1.0", # For Excel files
"pandas>=2.2.0", # For data processing
]
[tool.setuptools] [tool.setuptools]
py-modules = [] py-modules = []
@@ -60,3 +86,71 @@ py-modules = []
leann-core = { path = "packages/leann-core", editable = true } leann-core = { path = "packages/leann-core", editable = true }
leann-backend-diskann = { path = "packages/leann-backend-diskann", editable = true } leann-backend-diskann = { path = "packages/leann-backend-diskann", editable = true }
leann-backend-hnsw = { path = "packages/leann-backend-hnsw", editable = true } leann-backend-hnsw = { path = "packages/leann-backend-hnsw", editable = true }
[tool.ruff]
target-version = "py310"
line-length = 100
extend-exclude = [
"third_party",
"*.egg-info",
"__pycache__",
".git",
".venv",
]
[tool.ruff.lint]
select = [
"E", # pycodestyle errors
"W", # pycodestyle warnings
"F", # pyflakes
"I", # isort
"B", # flake8-bugbear
"C4", # flake8-comprehensions
"UP", # pyupgrade
"N", # pep8-naming
"RUF", # ruff-specific rules
]
ignore = [
"E501", # line too long (handled by formatter)
"B008", # do not perform function calls in argument defaults
"B904", # raise without from
"N812", # lowercase imported as non-lowercase
"N806", # variable in function should be lowercase
"RUF012", # mutable class attributes should be annotated with typing.ClassVar
]
[tool.ruff.lint.per-file-ignores]
"test/**/*.py" = ["E402"] # module level import not at top of file (common in tests)
"examples/**/*.py" = ["E402"] # module level import not at top of file (common in examples)
[tool.ruff.format]
quote-style = "double"
indent-style = "space"
skip-magic-trailing-comma = false
line-ending = "auto"
[dependency-groups]
dev = [
"ruff>=0.12.4",
]
[tool.pytest.ini_options]
testpaths = ["tests"]
python_files = ["test_*.py"]
python_classes = ["Test*"]
python_functions = ["test_*"]
markers = [
"slow: marks tests as slow (deselect with '-m \"not slow\"')",
"openai: marks tests that require OpenAI API key",
]
timeout = 600
addopts = [
"-v",
"--tb=short",
"--strict-markers",
"--disable-warnings",
]
env = [
"HF_HUB_DISABLE_SYMLINKS=1",
"TOKENIZERS_PARALLELISM=false",
]

View File

@@ -1,5 +1,6 @@
import os import os
from leann.api import LeannBuilder, LeannSearcher, LeannChat
from leann.api import LeannBuilder, LeannChat
# Define the path for our new MLX-based index # Define the path for our new MLX-based index
INDEX_PATH = "./mlx_diskann_index/leann" INDEX_PATH = "./mlx_diskann_index/leann"
@@ -38,7 +39,5 @@ chat = LeannChat(index_path=INDEX_PATH)
# add query # add query
query = "MLX is an array framework for machine learning on Apple silicon." query = "MLX is an array framework for machine learning on Apple silicon."
print(f"Query: {query}") print(f"Query: {query}")
response = chat.ask( response = chat.ask(query, top_k=3, recompute_beighbor_embeddings=True, complexity=3, beam_width=1)
query, top_k=3, recompute_beighbor_embeddings=True, complexity=3, beam_width=1
)
print(f"Response: {response}") print(f"Response: {response}")

View File

@@ -1,10 +1,11 @@
import os
import email import email
from pathlib import Path import os
from typing import List, Any from typing import Any
from llama_index.core import VectorStoreIndex, Document
from llama_index.core import Document, VectorStoreIndex
from llama_index.core.readers.base import BaseReader from llama_index.core.readers.base import BaseReader
class EmlxReader(BaseReader): class EmlxReader(BaseReader):
""" """
Apple Mail .emlx file reader. Apple Mail .emlx file reader.
@@ -16,7 +17,7 @@ class EmlxReader(BaseReader):
"""Initialize.""" """Initialize."""
pass pass
def load_data(self, input_dir: str, **load_kwargs: Any) -> List[Document]: def load_data(self, input_dir: str, **load_kwargs: Any) -> list[Document]:
""" """
Load data from the input directory containing .emlx files. Load data from the input directory containing .emlx files.
@@ -25,8 +26,8 @@ class EmlxReader(BaseReader):
**load_kwargs: **load_kwargs:
max_count (int): Maximum amount of messages to read. max_count (int): Maximum amount of messages to read.
""" """
docs: List[Document] = [] docs: list[Document] = []
max_count = load_kwargs.get('max_count', 1000) max_count = load_kwargs.get("max_count", 1000)
count = 0 count = 0
# Walk through the directory recursively # Walk through the directory recursively
@@ -42,12 +43,12 @@ class EmlxReader(BaseReader):
filepath = os.path.join(dirpath, filename) filepath = os.path.join(dirpath, filename)
try: try:
# Read the .emlx file # Read the .emlx file
with open(filepath, 'r', encoding='utf-8', errors='ignore') as f: with open(filepath, encoding="utf-8", errors="ignore") as f:
content = f.read() content = f.read()
# .emlx files have a length prefix followed by the email content # .emlx files have a length prefix followed by the email content
# The first line contains the length, followed by the email # The first line contains the length, followed by the email
lines = content.split('\n', 1) lines = content.split("\n", 1)
if len(lines) >= 2: if len(lines) >= 2:
email_content = lines[1] email_content = lines[1]
@@ -56,20 +57,27 @@ class EmlxReader(BaseReader):
msg = email.message_from_string(email_content) msg = email.message_from_string(email_content)
# Extract email metadata # Extract email metadata
subject = msg.get('Subject', 'No Subject') subject = msg.get("Subject", "No Subject")
from_addr = msg.get('From', 'Unknown') from_addr = msg.get("From", "Unknown")
to_addr = msg.get('To', 'Unknown') to_addr = msg.get("To", "Unknown")
date = msg.get('Date', 'Unknown') date = msg.get("Date", "Unknown")
# Extract email body # Extract email body
body = "" body = ""
if msg.is_multipart(): if msg.is_multipart():
for part in msg.walk(): for part in msg.walk():
if part.get_content_type() == "text/plain" or part.get_content_type() == "text/html": if (
body += part.get_payload(decode=True).decode('utf-8', errors='ignore') part.get_content_type() == "text/plain"
or part.get_content_type() == "text/html"
):
body += part.get_payload(decode=True).decode(
"utf-8", errors="ignore"
)
# break # break
else: else:
body = msg.get_payload(decode=True).decode('utf-8', errors='ignore') body = msg.get_payload(decode=True).decode(
"utf-8", errors="ignore"
)
# Create document content # Create document content
doc_content = f""" doc_content = f"""
@@ -83,29 +91,35 @@ Date: {date}
# Create metadata # Create metadata
metadata = { metadata = {
'file_path': filepath, "file_path": filepath,
'subject': subject, "subject": subject,
'from': from_addr, "from": from_addr,
'to': to_addr, "to": to_addr,
'date': date, "date": date,
'filename': filename "filename": filename,
} }
if count == 0: if count == 0:
print("--------------------------------") print("--------------------------------")
print('dir path', dirpath) print("dir path", dirpath)
print(metadata) print(metadata)
print(doc_content) print(doc_content)
print("--------------------------------") print("--------------------------------")
body=[] body = []
if msg.is_multipart(): if msg.is_multipart():
for part in msg.walk(): for part in msg.walk():
print("-------------------------------- get content type -------------------------------") print(
"-------------------------------- get content type -------------------------------"
)
print(part.get_content_type()) print(part.get_content_type())
print(part) print(part)
# body.append(part.get_payload(decode=True).decode('utf-8', errors='ignore')) # body.append(part.get_payload(decode=True).decode('utf-8', errors='ignore'))
print("-------------------------------- get content type -------------------------------") print(
"-------------------------------- get content type -------------------------------"
)
else: else:
body = msg.get_payload(decode=True).decode('utf-8', errors='ignore') body = msg.get_payload(decode=True).decode(
"utf-8", errors="ignore"
)
print(body) print(body)
print(body) print(body)
@@ -125,10 +139,11 @@ Date: {date}
print(f"Loaded {len(docs)} email documents") print(f"Loaded {len(docs)} email documents")
return docs return docs
# Use the custom EmlxReader instead of MboxReader # Use the custom EmlxReader instead of MboxReader
documents = EmlxReader().load_data( documents = EmlxReader().load_data(
"/Users/yichuan/Library/Mail/V10/0FCA0879-FD8C-4B7E-83BF-FDDA930791C5/[Gmail].mbox/All Mail.mbox/78BA5BE1-8819-4F9A-9613-EB63772F1DD0/Data/9/Messages", "/Users/yichuan/Library/Mail/V10/0FCA0879-FD8C-4B7E-83BF-FDDA930791C5/[Gmail].mbox/All Mail.mbox/78BA5BE1-8819-4F9A-9613-EB63772F1DD0/Data/9/Messages",
max_count=1000 max_count=1000,
) # Returns list of documents ) # Returns list of documents
# Configure the index with larger chunk size to handle long metadata # Configure the index with larger chunk size to handle long metadata
@@ -138,8 +153,7 @@ from llama_index.core.node_parser import SentenceSplitter
text_splitter = SentenceSplitter(chunk_size=2048, chunk_overlap=200) text_splitter = SentenceSplitter(chunk_size=2048, chunk_overlap=200)
index = VectorStoreIndex.from_documents( index = VectorStoreIndex.from_documents(
documents, documents, transformations=[text_splitter]
transformations=[text_splitter]
) # Initialize index with documents ) # Initialize index with documents
query_engine = index.as_query_engine() query_engine = index.as_query_engine()

View File

@@ -1,10 +1,11 @@
import os
import email import email
from pathlib import Path import os
from typing import List, Any from typing import Any
from llama_index.core import VectorStoreIndex, Document, StorageContext
from llama_index.core.readers.base import BaseReader from llama_index.core import Document, StorageContext, VectorStoreIndex
from llama_index.core.node_parser import SentenceSplitter from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.readers.base import BaseReader
class EmlxReader(BaseReader): class EmlxReader(BaseReader):
""" """
@@ -17,7 +18,7 @@ class EmlxReader(BaseReader):
"""Initialize.""" """Initialize."""
pass pass
def load_data(self, input_dir: str, **load_kwargs: Any) -> List[Document]: def load_data(self, input_dir: str, **load_kwargs: Any) -> list[Document]:
""" """
Load data from the input directory containing .emlx files. Load data from the input directory containing .emlx files.
@@ -26,8 +27,8 @@ class EmlxReader(BaseReader):
**load_kwargs: **load_kwargs:
max_count (int): Maximum amount of messages to read. max_count (int): Maximum amount of messages to read.
""" """
docs: List[Document] = [] docs: list[Document] = []
max_count = load_kwargs.get('max_count', 1000) max_count = load_kwargs.get("max_count", 1000)
count = 0 count = 0
# Walk through the directory recursively # Walk through the directory recursively
@@ -43,12 +44,12 @@ class EmlxReader(BaseReader):
filepath = os.path.join(dirpath, filename) filepath = os.path.join(dirpath, filename)
try: try:
# Read the .emlx file # Read the .emlx file
with open(filepath, 'r', encoding='utf-8', errors='ignore') as f: with open(filepath, encoding="utf-8", errors="ignore") as f:
content = f.read() content = f.read()
# .emlx files have a length prefix followed by the email content # .emlx files have a length prefix followed by the email content
# The first line contains the length, followed by the email # The first line contains the length, followed by the email
lines = content.split('\n', 1) lines = content.split("\n", 1)
if len(lines) >= 2: if len(lines) >= 2:
email_content = lines[1] email_content = lines[1]
@@ -57,20 +58,24 @@ class EmlxReader(BaseReader):
msg = email.message_from_string(email_content) msg = email.message_from_string(email_content)
# Extract email metadata # Extract email metadata
subject = msg.get('Subject', 'No Subject') subject = msg.get("Subject", "No Subject")
from_addr = msg.get('From', 'Unknown') from_addr = msg.get("From", "Unknown")
to_addr = msg.get('To', 'Unknown') to_addr = msg.get("To", "Unknown")
date = msg.get('Date', 'Unknown') date = msg.get("Date", "Unknown")
# Extract email body # Extract email body
body = "" body = ""
if msg.is_multipart(): if msg.is_multipart():
for part in msg.walk(): for part in msg.walk():
if part.get_content_type() == "text/plain": if part.get_content_type() == "text/plain":
body = part.get_payload(decode=True).decode('utf-8', errors='ignore') body = part.get_payload(decode=True).decode(
"utf-8", errors="ignore"
)
break break
else: else:
body = msg.get_payload(decode=True).decode('utf-8', errors='ignore') body = msg.get_payload(decode=True).decode(
"utf-8", errors="ignore"
)
# Create document content # Create document content
doc_content = f""" doc_content = f"""
@@ -84,12 +89,12 @@ Date: {date}
# Create metadata # Create metadata
metadata = { metadata = {
'file_path': filepath, "file_path": filepath,
'subject': subject, "subject": subject,
'from': from_addr, "from": from_addr,
'to': to_addr, "to": to_addr,
'date': date, "date": date,
'filename': filename "filename": filename,
} }
doc = Document(text=doc_content, metadata=metadata) doc = Document(text=doc_content, metadata=metadata)
@@ -107,6 +112,7 @@ Date: {date}
print(f"Loaded {len(docs)} email documents") print(f"Loaded {len(docs)} email documents")
return docs return docs
def create_and_save_index(mail_path: str, save_dir: str = "mail_index", max_count: int = 1000): def create_and_save_index(mail_path: str, save_dir: str = "mail_index", max_count: int = 1000):
""" """
Create the index from mail data and save it to disk. Create the index from mail data and save it to disk.
@@ -129,10 +135,7 @@ def create_and_save_index(mail_path: str, save_dir: str = "mail_index", max_coun
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=0) text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=0)
# Create index # Create index
index = VectorStoreIndex.from_documents( index = VectorStoreIndex.from_documents(documents, transformations=[text_splitter])
documents,
transformations=[text_splitter]
)
# Save the index # Save the index
os.makedirs(save_dir, exist_ok=True) os.makedirs(save_dir, exist_ok=True)
@@ -141,6 +144,7 @@ def create_and_save_index(mail_path: str, save_dir: str = "mail_index", max_coun
return index return index
def load_index(save_dir: str = "mail_index"): def load_index(save_dir: str = "mail_index"):
""" """
Load the saved index from disk. Load the saved index from disk.
@@ -157,8 +161,7 @@ def load_index(save_dir: str = "mail_index"):
# Load index # Load index
index = VectorStoreIndex.from_vector_store( index = VectorStoreIndex.from_vector_store(
storage_context.vector_store, storage_context.vector_store, storage_context=storage_context
storage_context=storage_context
) )
print(f"Index loaded from {save_dir}") print(f"Index loaded from {save_dir}")
@@ -168,6 +171,7 @@ def load_index(save_dir: str = "mail_index"):
print(f"Error loading index: {e}") print(f"Error loading index: {e}")
return None return None
def query_index(index, query: str): def query_index(index, query: str):
""" """
Query the loaded index. Query the loaded index.
@@ -185,6 +189,7 @@ def query_index(index, query: str):
print(f"Query: {query}") print(f"Query: {query}")
print(f"Response: {response}") print(f"Response: {response}")
def main(): def main():
mail_path = "/Users/yichuan/Library/Mail/V10/0FCA0879-FD8C-4B7E-83BF-FDDA930791C5/[Gmail].mbox/All Mail.mbox/78BA5BE1-8819-4F9A-9613-EB63772F1DD0/Data/9/Messages" mail_path = "/Users/yichuan/Library/Mail/V10/0FCA0879-FD8C-4B7E-83BF-FDDA930791C5/[Gmail].mbox/All Mail.mbox/78BA5BE1-8819-4F9A-9613-EB63772F1DD0/Data/9/Messages"
save_dir = "mail_index" save_dir = "mail_index"
@@ -202,12 +207,13 @@ def main():
queries = [ queries = [
"Hows Berkeley Graduate Student Instructor", "Hows Berkeley Graduate Student Instructor",
"What emails mention GSR appointments?", "What emails mention GSR appointments?",
"Find emails about deadlines" "Find emails about deadlines",
] ]
for query in queries: for query in queries:
print("\n" + "="*50) print("\n" + "=" * 50)
query_index(index, query) query_index(index, query)
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View File

@@ -1,10 +1,11 @@
import os
import email import email
from pathlib import Path import os
from typing import List, Any from typing import Any
from llama_index.core import VectorStoreIndex, Document, StorageContext
from llama_index.core.readers.base import BaseReader from llama_index.core import Document, StorageContext, VectorStoreIndex
from llama_index.core.node_parser import SentenceSplitter from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.readers.base import BaseReader
class EmlxReader(BaseReader): class EmlxReader(BaseReader):
""" """
@@ -17,7 +18,7 @@ class EmlxReader(BaseReader):
"""Initialize.""" """Initialize."""
pass pass
def load_data(self, input_dir: str, **load_kwargs: Any) -> List[Document]: def load_data(self, input_dir: str, **load_kwargs: Any) -> list[Document]:
""" """
Load data from the input directory containing .emlx files. Load data from the input directory containing .emlx files.
@@ -26,8 +27,8 @@ class EmlxReader(BaseReader):
**load_kwargs: **load_kwargs:
max_count (int): Maximum amount of messages to read. max_count (int): Maximum amount of messages to read.
""" """
docs: List[Document] = [] docs: list[Document] = []
max_count = load_kwargs.get('max_count', 1000) max_count = load_kwargs.get("max_count", 1000)
count = 0 count = 0
# Walk through the directory recursively # Walk through the directory recursively
@@ -43,12 +44,12 @@ class EmlxReader(BaseReader):
filepath = os.path.join(dirpath, filename) filepath = os.path.join(dirpath, filename)
try: try:
# Read the .emlx file # Read the .emlx file
with open(filepath, 'r', encoding='utf-8', errors='ignore') as f: with open(filepath, encoding="utf-8", errors="ignore") as f:
content = f.read() content = f.read()
# .emlx files have a length prefix followed by the email content # .emlx files have a length prefix followed by the email content
# The first line contains the length, followed by the email # The first line contains the length, followed by the email
lines = content.split('\n', 1) lines = content.split("\n", 1)
if len(lines) >= 2: if len(lines) >= 2:
email_content = lines[1] email_content = lines[1]
@@ -57,20 +58,24 @@ class EmlxReader(BaseReader):
msg = email.message_from_string(email_content) msg = email.message_from_string(email_content)
# Extract email metadata # Extract email metadata
subject = msg.get('Subject', 'No Subject') subject = msg.get("Subject", "No Subject")
from_addr = msg.get('From', 'Unknown') from_addr = msg.get("From", "Unknown")
to_addr = msg.get('To', 'Unknown') to_addr = msg.get("To", "Unknown")
date = msg.get('Date', 'Unknown') date = msg.get("Date", "Unknown")
# Extract email body # Extract email body
body = "" body = ""
if msg.is_multipart(): if msg.is_multipart():
for part in msg.walk(): for part in msg.walk():
if part.get_content_type() == "text/plain": if part.get_content_type() == "text/plain":
body = part.get_payload(decode=True).decode('utf-8', errors='ignore') body = part.get_payload(decode=True).decode(
"utf-8", errors="ignore"
)
break break
else: else:
body = msg.get_payload(decode=True).decode('utf-8', errors='ignore') body = msg.get_payload(decode=True).decode(
"utf-8", errors="ignore"
)
# Create document content with metadata embedded in text # Create document content with metadata embedded in text
doc_content = f""" doc_content = f"""
@@ -84,10 +89,10 @@ Date: {date}
# Create minimal metadata (only essential info) # Create minimal metadata (only essential info)
metadata = { metadata = {
'subject': subject[:50], # Truncate subject "subject": subject[:50], # Truncate subject
'from': from_addr[:30], # Truncate from "from": from_addr[:30], # Truncate from
'date': date[:20], # Truncate date "date": date[:20], # Truncate date
'filename': filename # Keep filename "filename": filename, # Keep filename
} }
doc = Document(text=doc_content, metadata=metadata) doc = Document(text=doc_content, metadata=metadata)
@@ -105,7 +110,10 @@ Date: {date}
print(f"Loaded {len(docs)} email documents") print(f"Loaded {len(docs)} email documents")
return docs return docs
def create_and_save_index(mail_path: str, save_dir: str = "mail_index_small", max_count: int = 1000):
def create_and_save_index(
mail_path: str, save_dir: str = "mail_index_small", max_count: int = 1000
):
""" """
Create the index from mail data and save it to disk. Create the index from mail data and save it to disk.
@@ -127,10 +135,7 @@ def create_and_save_index(mail_path: str, save_dir: str = "mail_index_small", ma
text_splitter = SentenceSplitter(chunk_size=512, chunk_overlap=50) text_splitter = SentenceSplitter(chunk_size=512, chunk_overlap=50)
# Create index # Create index
index = VectorStoreIndex.from_documents( index = VectorStoreIndex.from_documents(documents, transformations=[text_splitter])
documents,
transformations=[text_splitter]
)
# Save the index # Save the index
os.makedirs(save_dir, exist_ok=True) os.makedirs(save_dir, exist_ok=True)
@@ -139,6 +144,7 @@ def create_and_save_index(mail_path: str, save_dir: str = "mail_index_small", ma
return index return index
def load_index(save_dir: str = "mail_index_small"): def load_index(save_dir: str = "mail_index_small"):
""" """
Load the saved index from disk. Load the saved index from disk.
@@ -155,8 +161,7 @@ def load_index(save_dir: str = "mail_index_small"):
# Load index # Load index
index = VectorStoreIndex.from_vector_store( index = VectorStoreIndex.from_vector_store(
storage_context.vector_store, storage_context.vector_store, storage_context=storage_context
storage_context=storage_context
) )
print(f"Index loaded from {save_dir}") print(f"Index loaded from {save_dir}")
@@ -166,6 +171,7 @@ def load_index(save_dir: str = "mail_index_small"):
print(f"Error loading index: {e}") print(f"Error loading index: {e}")
return None return None
def query_index(index, query: str): def query_index(index, query: str):
""" """
Query the loaded index. Query the loaded index.
@@ -183,6 +189,7 @@ def query_index(index, query: str):
print(f"Query: {query}") print(f"Query: {query}")
print(f"Response: {response}") print(f"Response: {response}")
def main(): def main():
mail_path = "/Users/yichuan/Library/Mail/V10/0FCA0879-FD8C-4B7E-83BF-FDDA930791C5/[Gmail].mbox/All Mail.mbox/78BA5BE1-8819-4F9A-9613-EB63772F1DD0/Data/9/Messages" mail_path = "/Users/yichuan/Library/Mail/V10/0FCA0879-FD8C-4B7E-83BF-FDDA930791C5/[Gmail].mbox/All Mail.mbox/78BA5BE1-8819-4F9A-9613-EB63772F1DD0/Data/9/Messages"
save_dir = "mail_index_small" save_dir = "mail_index_small"
@@ -200,12 +207,13 @@ def main():
queries = [ queries = [
"Hows Berkeley Graduate Student Instructor", "Hows Berkeley Graduate Student Instructor",
"What emails mention GSR appointments?", "What emails mention GSR appointments?",
"Find emails about deadlines" "Find emails about deadlines",
] ]
for query in queries: for query in queries:
print("\n" + "="*50) print("\n" + "=" * 50)
query_index(index, query) query_index(index, query)
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View File

@@ -1,10 +1,11 @@
import os
import email import email
from pathlib import Path import os
from typing import List, Any from typing import Any
from llama_index.core import VectorStoreIndex, Document
from llama_index.core import Document, VectorStoreIndex
from llama_index.core.readers.base import BaseReader from llama_index.core.readers.base import BaseReader
class EmlxReader(BaseReader): class EmlxReader(BaseReader):
""" """
Apple Mail .emlx file reader. Apple Mail .emlx file reader.
@@ -16,7 +17,7 @@ class EmlxReader(BaseReader):
"""Initialize.""" """Initialize."""
pass pass
def load_data(self, input_dir: str, **load_kwargs: Any) -> List[Document]: def load_data(self, input_dir: str, **load_kwargs: Any) -> list[Document]:
""" """
Load data from the input directory containing .emlx files. Load data from the input directory containing .emlx files.
@@ -25,8 +26,8 @@ class EmlxReader(BaseReader):
**load_kwargs: **load_kwargs:
max_count (int): Maximum amount of messages to read. max_count (int): Maximum amount of messages to read.
""" """
docs: List[Document] = [] docs: list[Document] = []
max_count = load_kwargs.get('max_count', 1000) max_count = load_kwargs.get("max_count", 1000)
count = 0 count = 0
# Check if directory exists and is accessible # Check if directory exists and is accessible
@@ -55,12 +56,12 @@ class EmlxReader(BaseReader):
print(f"Found .emlx file: {filepath}") print(f"Found .emlx file: {filepath}")
try: try:
# Read the .emlx file # Read the .emlx file
with open(filepath, 'r', encoding='utf-8', errors='ignore') as f: with open(filepath, encoding="utf-8", errors="ignore") as f:
content = f.read() content = f.read()
# .emlx files have a length prefix followed by the email content # .emlx files have a length prefix followed by the email content
# The first line contains the length, followed by the email # The first line contains the length, followed by the email
lines = content.split('\n', 1) lines = content.split("\n", 1)
if len(lines) >= 2: if len(lines) >= 2:
email_content = lines[1] email_content = lines[1]
@@ -69,20 +70,24 @@ class EmlxReader(BaseReader):
msg = email.message_from_string(email_content) msg = email.message_from_string(email_content)
# Extract email metadata # Extract email metadata
subject = msg.get('Subject', 'No Subject') subject = msg.get("Subject", "No Subject")
from_addr = msg.get('From', 'Unknown') from_addr = msg.get("From", "Unknown")
to_addr = msg.get('To', 'Unknown') to_addr = msg.get("To", "Unknown")
date = msg.get('Date', 'Unknown') date = msg.get("Date", "Unknown")
# Extract email body # Extract email body
body = "" body = ""
if msg.is_multipart(): if msg.is_multipart():
for part in msg.walk(): for part in msg.walk():
if part.get_content_type() == "text/plain": if part.get_content_type() == "text/plain":
body = part.get_payload(decode=True).decode('utf-8', errors='ignore') body = part.get_payload(decode=True).decode(
"utf-8", errors="ignore"
)
break break
else: else:
body = msg.get_payload(decode=True).decode('utf-8', errors='ignore') body = msg.get_payload(decode=True).decode(
"utf-8", errors="ignore"
)
# Create document content # Create document content
doc_content = f""" doc_content = f"""
@@ -96,12 +101,12 @@ Date: {date}
# Create metadata # Create metadata
metadata = { metadata = {
'file_path': filepath, "file_path": filepath,
'subject': subject, "subject": subject,
'from': from_addr, "from": from_addr,
'to': to_addr, "to": to_addr,
'date': date, "date": date,
'filename': filename "filename": filename,
} }
doc = Document(text=doc_content, metadata=metadata) doc = Document(text=doc_content, metadata=metadata)
@@ -119,6 +124,7 @@ Date: {date}
print(f"Loaded {len(docs)} email documents") print(f"Loaded {len(docs)} email documents")
return docs return docs
def main(): def main():
# Use the current directory where the sample.emlx file is located # Use the current directory where the sample.emlx file is located
current_dir = os.path.dirname(os.path.abspath(__file__)) current_dir = os.path.dirname(os.path.abspath(__file__))
@@ -143,5 +149,6 @@ def main():
res = query_engine.query("Hows Berkeley Graduate Student Instructor") res = query_engine.query("Hows Berkeley Graduate Student Instructor")
print(f"Response: {res}") print(f"Response: {res}")
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View File

@@ -2,20 +2,20 @@
import argparse import argparse
import time import time
from contextlib import contextmanager
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
import numpy as np import numpy as np
import torch import torch
from torch import nn from torch import nn
from transformers import AutoModel, BitsAndBytesConfig
from tqdm import tqdm from tqdm import tqdm
from contextlib import contextmanager from transformers import AutoModel, BitsAndBytesConfig
@dataclass @dataclass
class BenchmarkConfig: class BenchmarkConfig:
model_path: str model_path: str
batch_sizes: List[int] batch_sizes: list[int]
seq_length: int seq_length: int
num_runs: int num_runs: int
use_fp16: bool = True use_fp16: bool = True
@@ -32,13 +32,11 @@ class GraphContainer:
def __init__(self, model: nn.Module, seq_length: int): def __init__(self, model: nn.Module, seq_length: int):
self.model = model self.model = model
self.seq_length = seq_length self.seq_length = seq_length
self.graphs: Dict[int, 'GraphWrapper'] = {} self.graphs: dict[int, GraphWrapper] = {}
def get_or_create(self, batch_size: int) -> 'GraphWrapper': def get_or_create(self, batch_size: int) -> "GraphWrapper":
if batch_size not in self.graphs: if batch_size not in self.graphs:
self.graphs[batch_size] = GraphWrapper( self.graphs[batch_size] = GraphWrapper(self.model, batch_size, self.seq_length)
self.model, batch_size, self.seq_length
)
return self.graphs[batch_size] return self.graphs[batch_size]
@@ -55,13 +53,13 @@ class GraphWrapper:
self._warmup() self._warmup()
# Only use CUDA graphs on NVIDIA GPUs # Only use CUDA graphs on NVIDIA GPUs
if torch.cuda.is_available() and hasattr(torch.cuda, 'CUDAGraph'): if torch.cuda.is_available() and hasattr(torch.cuda, "CUDAGraph"):
# Capture graph # Capture graph
self.graph = torch.cuda.CUDAGraph() self.graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(self.graph): with torch.cuda.graph(self.graph):
self.static_output = self.model( self.static_output = self.model(
input_ids=self.static_input, input_ids=self.static_input,
attention_mask=self.static_attention_mask attention_mask=self.static_attention_mask,
) )
self.use_cuda_graph = True self.use_cuda_graph = True
else: else:
@@ -79,9 +77,7 @@ class GraphWrapper:
def _create_random_batch(self, batch_size: int, seq_length: int) -> torch.Tensor: def _create_random_batch(self, batch_size: int, seq_length: int) -> torch.Tensor:
return torch.randint( return torch.randint(
0, 1000, (batch_size, seq_length), 0, 1000, (batch_size, seq_length), device=self.device, dtype=torch.long
device=self.device,
dtype=torch.long
) )
def _warmup(self, num_warmup: int = 3): def _warmup(self, num_warmup: int = 3):
@@ -89,7 +85,7 @@ class GraphWrapper:
for _ in range(num_warmup): for _ in range(num_warmup):
self.model( self.model(
input_ids=self.static_input, input_ids=self.static_input,
attention_mask=self.static_attention_mask attention_mask=self.static_attention_mask,
) )
def __call__(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor: def __call__(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
@@ -133,8 +129,12 @@ class ModelOptimizer:
print("- Using FP16 precision") print("- Using FP16 precision")
# Check if using SDPA (only on CUDA) # Check if using SDPA (only on CUDA)
if torch.cuda.is_available() and torch.version.cuda and float(torch.version.cuda[:3]) >= 11.6: if (
if hasattr(torch.nn.functional, 'scaled_dot_product_attention'): torch.cuda.is_available()
and torch.version.cuda
and float(torch.version.cuda[:3]) >= 11.6
):
if hasattr(torch.nn.functional, "scaled_dot_product_attention"):
print("- Using PyTorch SDPA (scaled_dot_product_attention)") print("- Using PyTorch SDPA (scaled_dot_product_attention)")
else: else:
print("- PyTorch SDPA not available") print("- PyTorch SDPA not available")
@@ -142,7 +142,8 @@ class ModelOptimizer:
# Flash Attention (only on CUDA) # Flash Attention (only on CUDA)
if config.use_flash_attention and torch.cuda.is_available(): if config.use_flash_attention and torch.cuda.is_available():
try: try:
from flash_attn.flash_attention import FlashAttention from flash_attn.flash_attention import FlashAttention # noqa: F401
print("- Flash Attention 2 available") print("- Flash Attention 2 available")
if hasattr(model.config, "attention_mode"): if hasattr(model.config, "attention_mode"):
model.config.attention_mode = "flash_attention_2" model.config.attention_mode = "flash_attention_2"
@@ -153,8 +154,9 @@ class ModelOptimizer:
# Memory efficient attention (only on CUDA) # Memory efficient attention (only on CUDA)
if torch.cuda.is_available(): if torch.cuda.is_available():
try: try:
from xformers.ops import memory_efficient_attention from xformers.ops import memory_efficient_attention # noqa: F401
if hasattr(model, 'enable_xformers_memory_efficient_attention'):
if hasattr(model, "enable_xformers_memory_efficient_attention"):
model.enable_xformers_memory_efficient_attention() model.enable_xformers_memory_efficient_attention()
print("- Enabled xformers memory efficient attention") print("- Enabled xformers memory efficient attention")
else: else:
@@ -220,7 +222,7 @@ class Benchmark:
self.graphs = None self.graphs = None
self.timer = Timer() self.timer = Timer()
except Exception as e: except Exception as e:
print(f"ERROR in benchmark initialization: {str(e)}") print(f"ERROR in benchmark initialization: {e!s}")
raise raise
def _load_model(self) -> nn.Module: def _load_model(self) -> nn.Module:
@@ -230,15 +232,17 @@ class Benchmark:
# Int4 quantization using HuggingFace integration # Int4 quantization using HuggingFace integration
if self.config.use_int4: if self.config.use_int4:
import bitsandbytes as bnb import bitsandbytes as bnb
print(f"- bitsandbytes version: {bnb.__version__}") print(f"- bitsandbytes version: {bnb.__version__}")
# 检查是否使用自定义的8bit量化 # Check if using custom 8bit quantization
if hasattr(self.config, 'use_linear8bitlt') and self.config.use_linear8bitlt: if hasattr(self.config, "use_linear8bitlt") and self.config.use_linear8bitlt:
print("- Using custom Linear8bitLt replacement for all linear layers") print("- Using custom Linear8bitLt replacement for all linear layers")
# 加载原始模型(不使用量化配置) # Load original model (without quantization config)
import bitsandbytes as bnb import bitsandbytes as bnb
import torch import torch
# set default to half # set default to half
torch.set_default_dtype(torch.float16) torch.set_default_dtype(torch.float16)
compute_dtype = torch.float16 if self.config.use_fp16 else torch.float32 compute_dtype = torch.float16 if self.config.use_fp16 else torch.float32
@@ -247,52 +251,58 @@ class Benchmark:
torch_dtype=compute_dtype, torch_dtype=compute_dtype,
) )
# 定义替换函数 # Define replacement function
def replace_linear_with_linear8bitlt(model): def replace_linear_with_linear8bitlt(model):
"""递归地将模型中的所有nn.Linear层替换为Linear8bitLt""" """Recursively replace all nn.Linear layers with Linear8bitLt"""
for name, module in list(model.named_children()): for name, module in list(model.named_children()):
if isinstance(module, nn.Linear): if isinstance(module, nn.Linear):
# 获取原始线性层的参数 # Get original linear layer parameters
in_features = module.in_features in_features = module.in_features
out_features = module.out_features out_features = module.out_features
bias = module.bias is not None bias = module.bias is not None
# 创建8bit线性层 # Create 8bit linear layer
# print size # print size
print(f"in_features: {in_features}, out_features: {out_features}") print(f"in_features: {in_features}, out_features: {out_features}")
new_module = bnb.nn.Linear8bitLt( new_module = bnb.nn.Linear8bitLt(
in_features, in_features,
out_features, out_features,
bias=bias, bias=bias,
has_fp16_weights=False has_fp16_weights=False,
) )
# 复制权重和偏置 # Copy weights and bias
new_module.weight.data = module.weight.data new_module.weight.data = module.weight.data
if bias: if bias:
new_module.bias.data = module.bias.data new_module.bias.data = module.bias.data
# 替换模块 # Replace module
setattr(model, name, new_module) setattr(model, name, new_module)
else: else:
# 递归处理子模块 # Process child modules recursively
replace_linear_with_linear8bitlt(module) replace_linear_with_linear8bitlt(module)
return model return model
# 替换所有线性层 # Replace all linear layers
model = replace_linear_with_linear8bitlt(model) model = replace_linear_with_linear8bitlt(model)
# add torch compile # add torch compile
model = torch.compile(model) model = torch.compile(model)
# 将模型移到GPU量化发生在这里 # Move model to GPU (quantization happens here)
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" device = (
"cuda"
if torch.cuda.is_available()
else "mps"
if torch.backends.mps.is_available()
else "cpu"
)
model = model.to(device) model = model.to(device)
print("- All linear layers replaced with Linear8bitLt") print("- All linear layers replaced with Linear8bitLt")
else: else:
# 使用原来的Int4量化方法 # Use original Int4 quantization method
print("- Using bitsandbytes for Int4 quantization") print("- Using bitsandbytes for Int4 quantization")
# Create quantization config # Create quantization config
@@ -302,7 +312,7 @@ class Benchmark:
load_in_4bit=True, load_in_4bit=True,
bnb_4bit_compute_dtype=compute_dtype, bnb_4bit_compute_dtype=compute_dtype,
bnb_4bit_use_double_quant=True, bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4" bnb_4bit_quant_type="nf4",
) )
print("- Quantization config:", quantization_config) print("- Quantization config:", quantization_config)
@@ -312,7 +322,7 @@ class Benchmark:
self.config.model_path, self.config.model_path,
quantization_config=quantization_config, quantization_config=quantization_config,
torch_dtype=compute_dtype, torch_dtype=compute_dtype,
device_map="auto" # Let HF decide on device mapping device_map="auto", # Let HF decide on device mapping
) )
# Check if model loaded successfully # Check if model loaded successfully
@@ -324,7 +334,7 @@ class Benchmark:
# Apply optimizations directly here # Apply optimizations directly here
print("\nApplying model optimizations:") print("\nApplying model optimizations:")
if hasattr(self.config, 'use_linear8bitlt') and self.config.use_linear8bitlt: if hasattr(self.config, "use_linear8bitlt") and self.config.use_linear8bitlt:
print("- Model moved to GPU with Linear8bitLt quantization") print("- Model moved to GPU with Linear8bitLt quantization")
else: else:
# Skip moving to GPU since device_map="auto" already did that # Skip moving to GPU since device_map="auto" already did that
@@ -334,8 +344,12 @@ class Benchmark:
print(f"- Using {compute_dtype} for compute dtype") print(f"- Using {compute_dtype} for compute dtype")
# Check CUDA and SDPA # Check CUDA and SDPA
if torch.cuda.is_available() and torch.version.cuda and float(torch.version.cuda[:3]) >= 11.6: if (
if hasattr(torch.nn.functional, 'scaled_dot_product_attention'): torch.cuda.is_available()
and torch.version.cuda
and float(torch.version.cuda[:3]) >= 11.6
):
if hasattr(torch.nn.functional, "scaled_dot_product_attention"):
print("- Using PyTorch SDPA (scaled_dot_product_attention)") print("- Using PyTorch SDPA (scaled_dot_product_attention)")
else: else:
print("- PyTorch SDPA not available") print("- PyTorch SDPA not available")
@@ -343,8 +357,7 @@ class Benchmark:
# Try xformers if available (only on CUDA) # Try xformers if available (only on CUDA)
if torch.cuda.is_available(): if torch.cuda.is_available():
try: try:
from xformers.ops import memory_efficient_attention if hasattr(model, "enable_xformers_memory_efficient_attention"):
if hasattr(model, 'enable_xformers_memory_efficient_attention'):
model.enable_xformers_memory_efficient_attention() model.enable_xformers_memory_efficient_attention()
print("- Enabled xformers memory efficient attention") print("- Enabled xformers memory efficient attention")
else: else:
@@ -370,7 +383,7 @@ class Benchmark:
self.config.model_path, self.config.model_path,
quantization_config=quantization_config, quantization_config=quantization_config,
torch_dtype=compute_dtype, torch_dtype=compute_dtype,
device_map="auto" device_map="auto",
) )
if model is None: if model is None:
@@ -389,6 +402,7 @@ class Benchmark:
# Apply standard optimizations # Apply standard optimizations
# set default to half # set default to half
import torch import torch
torch.set_default_dtype(torch.bfloat16) torch.set_default_dtype(torch.bfloat16)
model = ModelOptimizer.optimize(model, self.config) model = ModelOptimizer.optimize(model, self.config)
model = model.half() model = model.half()
@@ -403,25 +417,31 @@ class Benchmark:
return model return model
except Exception as e: except Exception as e:
print(f"ERROR loading model: {str(e)}") print(f"ERROR loading model: {e!s}")
import traceback import traceback
traceback.print_exc() traceback.print_exc()
raise raise
def _create_random_batch(self, batch_size: int) -> torch.Tensor: def _create_random_batch(self, batch_size: int) -> torch.Tensor:
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" device = (
"cuda"
if torch.cuda.is_available()
else "mps"
if torch.backends.mps.is_available()
else "cpu"
)
return torch.randint( return torch.randint(
0, 1000, 0,
1000,
(batch_size, self.config.seq_length), (batch_size, self.config.seq_length),
device=device, device=device,
dtype=torch.long dtype=torch.long,
) )
def _run_inference( def _run_inference(
self, self, input_ids: torch.Tensor, graph_wrapper: GraphWrapper | None = None
input_ids: torch.Tensor, ) -> tuple[float, torch.Tensor]:
graph_wrapper: Optional[GraphWrapper] = None
) -> Tuple[float, torch.Tensor]:
attention_mask = torch.ones_like(input_ids) attention_mask = torch.ones_like(input_ids)
with torch.no_grad(), self.timer.timing(): with torch.no_grad(), self.timer.timing():
@@ -432,7 +452,7 @@ class Benchmark:
return self.timer.elapsed_time(), output return self.timer.elapsed_time(), output
def run(self) -> Dict[int, Dict[str, float]]: def run(self) -> dict[int, dict[str, float]]:
results = {} results = {}
# Reset peak memory stats # Reset peak memory stats
@@ -450,9 +470,7 @@ class Benchmark:
# Get or create graph for this batch size # Get or create graph for this batch size
graph_wrapper = ( graph_wrapper = (
self.graphs.get_or_create(batch_size) self.graphs.get_or_create(batch_size) if self.graphs is not None else None
if self.graphs is not None
else None
) )
# Pre-allocate input tensor # Pre-allocate input tensor
@@ -490,7 +508,7 @@ class Benchmark:
# Log memory usage # Log memory usage
if torch.cuda.is_available(): if torch.cuda.is_available():
peak_memory_gb = torch.cuda.max_memory_allocated() / (1024 ** 3) peak_memory_gb = torch.cuda.max_memory_allocated() / (1024**3)
elif torch.backends.mps.is_available(): elif torch.backends.mps.is_available():
# MPS doesn't have max_memory_allocated, use 0 # MPS doesn't have max_memory_allocated, use 0
peak_memory_gb = 0.0 peak_memory_gb = 0.0
@@ -604,7 +622,15 @@ def main():
os.makedirs("results", exist_ok=True) os.makedirs("results", exist_ok=True)
# Generate filename based on configuration # Generate filename based on configuration
precision_type = "int4" if config.use_int4 else "int8" if config.use_int8 else "fp16" if config.use_fp16 else "fp32" precision_type = (
"int4"
if config.use_int4
else "int8"
if config.use_int8
else "fp16"
if config.use_fp16
else "fp32"
)
model_name = os.path.basename(config.model_path) model_name = os.path.basename(config.model_path)
output_file = f"results/benchmark_{model_name}_{precision_type}.json" output_file = f"results/benchmark_{model_name}_{precision_type}.json"
@@ -612,17 +638,20 @@ def main():
with open(output_file, "w") as f: with open(output_file, "w") as f:
json.dump( json.dump(
{ {
"config": {k: str(v) if isinstance(v, list) else v for k, v in vars(config).items()}, "config": {
"results": {str(k): v for k, v in results.items()} k: str(v) if isinstance(v, list) else v for k, v in vars(config).items()
},
"results": {str(k): v for k, v in results.items()},
}, },
f, f,
indent=2 indent=2,
) )
print(f"Results saved to {output_file}") print(f"Results saved to {output_file}")
except Exception as e: except Exception as e:
print(f"Benchmark failed: {e}") print(f"Benchmark failed: {e}")
import traceback import traceback
traceback.print_exc() traceback.print_exc()

View File

@@ -1,5 +1,7 @@
import os import os
from llama_index.core import VectorStoreIndex, StorageContext
from llama_index.core import StorageContext, VectorStoreIndex
def load_index(save_dir: str = "mail_index"): def load_index(save_dir: str = "mail_index"):
""" """
@@ -17,8 +19,7 @@ def load_index(save_dir: str = "mail_index"):
# Load index # Load index
index = VectorStoreIndex.from_vector_store( index = VectorStoreIndex.from_vector_store(
storage_context.vector_store, storage_context.vector_store, storage_context=storage_context
storage_context=storage_context
) )
print(f"Index loaded from {save_dir}") print(f"Index loaded from {save_dir}")
@@ -28,6 +29,7 @@ def load_index(save_dir: str = "mail_index"):
print(f"Error loading index: {e}") print(f"Error loading index: {e}")
return None return None
def query_index(index, query: str): def query_index(index, query: str):
""" """
Query the loaded index. Query the loaded index.
@@ -45,11 +47,14 @@ def query_index(index, query: str):
print(f"\nQuery: {query}") print(f"\nQuery: {query}")
print(f"Response: {response}") print(f"Response: {response}")
def main(): def main():
save_dir = "mail_index" save_dir = "mail_index"
# Check if index exists # Check if index exists
if not os.path.exists(save_dir) or not os.path.exists(os.path.join(save_dir, "vector_store.json")): if not os.path.exists(save_dir) or not os.path.exists(
os.path.join(save_dir, "vector_store.json")
):
print(f"Index not found in {save_dir}") print(f"Index not found in {save_dir}")
print("Please run mail_reader_save_load.py first to create the index.") print("Please run mail_reader_save_load.py first to create the index.")
return return
@@ -61,22 +66,22 @@ def main():
print("Failed to load index.") print("Failed to load index.")
return return
print("\n" + "="*60) print("\n" + "=" * 60)
print("Email Query Interface") print("Email Query Interface")
print("="*60) print("=" * 60)
print("Type 'quit' to exit") print("Type 'quit' to exit")
print("Type 'help' for example queries") print("Type 'help' for example queries")
print("="*60) print("=" * 60)
# Interactive query loop # Interactive query loop
while True: while True:
try: try:
query = input("\nEnter your query: ").strip() query = input("\nEnter your query: ").strip()
if query.lower() == 'quit': if query.lower() == "quit":
print("Goodbye!") print("Goodbye!")
break break
elif query.lower() == 'help': elif query.lower() == "help":
print("\nExample queries:") print("\nExample queries:")
print("- Hows Berkeley Graduate Student Instructor") print("- Hows Berkeley Graduate Student Instructor")
print("- What emails mention GSR appointments?") print("- What emails mention GSR appointments?")
@@ -95,5 +100,6 @@ def main():
except Exception as e: except Exception as e:
print(f"Error processing query: {e}") print(f"Error processing query: {e}")
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View File

@@ -1,29 +1,32 @@
import time import time
import numpy as np
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import torch
from sentence_transformers import SentenceTransformer
import mlx.core as mx import mlx.core as mx
import numpy as np
import torch
from mlx_lm import load from mlx_lm import load
from sentence_transformers import SentenceTransformer
# --- Configuration --- # --- Configuration ---
MODEL_NAME_TORCH = "Qwen/Qwen3-Embedding-0.6B" MODEL_NAME_TORCH = "Qwen/Qwen3-Embedding-0.6B"
MODEL_NAME_MLX = "mlx-community/Qwen3-Embedding-0.6B-4bit-DWQ" MODEL_NAME_MLX = "mlx-community/Qwen3-Embedding-0.6B-4bit-DWQ"
BATCH_SIZES = [1, 8, 16, 32, 64, 128] BATCH_SIZES = [1, 8, 16, 32, 64, 128]
NUM_RUNS = 10 # Number of runs to average for each batch size NUM_RUNS = 10 # Number of runs to average for each batch size
WARMUP_RUNS = 2 # Number of warm-up runs WARMUP_RUNS = 2 # Number of warm-up runs
# --- Generate Dummy Data --- # --- Generate Dummy Data ---
DUMMY_SENTENCES = ["This is a test sentence for benchmarking." * 5] * max(BATCH_SIZES) DUMMY_SENTENCES = ["This is a test sentence for benchmarking." * 5] * max(BATCH_SIZES)
# --- Benchmark Functions ---b # --- Benchmark Functions ---b
def benchmark_torch(model, sentences): def benchmark_torch(model, sentences):
start_time = time.time() start_time = time.time()
model.encode(sentences, convert_to_numpy=True) model.encode(sentences, convert_to_numpy=True)
end_time = time.time() end_time = time.time()
return (end_time - start_time) * 1000 # Return time in ms return (end_time - start_time) * 1000 # Return time in ms
def benchmark_mlx(model, tokenizer, sentences): def benchmark_mlx(model, tokenizer, sentences):
start_time = time.time() start_time = time.time()
@@ -63,6 +66,7 @@ def benchmark_mlx(model, tokenizer, sentences):
end_time = time.time() end_time = time.time()
return (end_time - start_time) * 1000 # Return time in ms return (end_time - start_time) * 1000 # Return time in ms
# --- Main Execution --- # --- Main Execution ---
def main(): def main():
print("--- Initializing Models ---") print("--- Initializing Models ---")
@@ -98,7 +102,9 @@ def main():
results_torch.append(np.mean(torch_times)) results_torch.append(np.mean(torch_times))
# Benchmark MLX # Benchmark MLX
mlx_times = [benchmark_mlx(model_mlx, tokenizer_mlx, sentences_batch) for _ in range(NUM_RUNS)] mlx_times = [
benchmark_mlx(model_mlx, tokenizer_mlx, sentences_batch) for _ in range(NUM_RUNS)
]
results_mlx.append(np.mean(mlx_times)) results_mlx.append(np.mean(mlx_times))
print("\n--- Benchmark Results (Average time per batch in ms) ---") print("\n--- Benchmark Results (Average time per batch in ms) ---")
@@ -109,10 +115,16 @@ def main():
# --- Plotting --- # --- Plotting ---
print("\n--- Generating Plot ---") print("\n--- Generating Plot ---")
plt.figure(figsize=(10, 6)) plt.figure(figsize=(10, 6))
plt.plot(BATCH_SIZES, results_torch, marker='o', linestyle='-', label=f'PyTorch ({device})') plt.plot(
plt.plot(BATCH_SIZES, results_mlx, marker='s', linestyle='-', label='MLX') BATCH_SIZES,
results_torch,
marker="o",
linestyle="-",
label=f"PyTorch ({device})",
)
plt.plot(BATCH_SIZES, results_mlx, marker="s", linestyle="-", label="MLX")
plt.title(f'Embedding Performance: MLX vs PyTorch\nModel: {MODEL_NAME_TORCH}') plt.title(f"Embedding Performance: MLX vs PyTorch\nModel: {MODEL_NAME_TORCH}")
plt.xlabel("Batch Size") plt.xlabel("Batch Size")
plt.ylabel("Average Time per Batch (ms)") plt.ylabel("Average Time per Batch (ms)")
plt.xticks(BATCH_SIZES) plt.xticks(BATCH_SIZES)
@@ -124,5 +136,6 @@ def main():
plt.savefig(output_filename) plt.savefig(output_filename)
print(f"Plot saved to {output_filename}") print(f"Plot saved to {output_filename}")
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View File

@@ -3,13 +3,15 @@
Debug script to test ZMQ communication with the exact same setup as main_cli_example.py Debug script to test ZMQ communication with the exact same setup as main_cli_example.py
""" """
import zmq
import time
import threading
import sys import sys
sys.path.append('packages/leann-backend-diskann') import time
import zmq
sys.path.append("packages/leann-backend-diskann")
from leann_backend_diskann import embedding_pb2 from leann_backend_diskann import embedding_pb2
def test_zmq_with_same_model(): def test_zmq_with_same_model():
print("=== Testing ZMQ with same model as main_cli_example.py ===") print("=== Testing ZMQ with same model as main_cli_example.py ===")
@@ -18,19 +20,20 @@ def test_zmq_with_same_model():
# Start server with the same model # Start server with the same model
import subprocess import subprocess
server_cmd = [ server_cmd = [
sys.executable, "-m", sys.executable,
"-m",
"packages.leann-backend-diskann.leann_backend_diskann.embedding_server", "packages.leann-backend-diskann.leann_backend_diskann.embedding_server",
"--zmq-port", "5556", # Use different port to avoid conflicts "--zmq-port",
"--model-name", model_name "5556", # Use different port to avoid conflicts
"--model-name",
model_name,
] ]
print(f"Starting server with command: {' '.join(server_cmd)}") print(f"Starting server with command: {' '.join(server_cmd)}")
server_process = subprocess.Popen( server_process = subprocess.Popen(
server_cmd, server_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True
) )
# Wait for server to start # Wait for server to start
@@ -105,6 +108,7 @@ def test_zmq_with_same_model():
server_process.wait() server_process.wait()
print("Server terminated") print("Server terminated")
if __name__ == "__main__": if __name__ == "__main__":
success = test_zmq_with_same_model() success = test_zmq_with_same_model()
if success: if success:

View File

@@ -1,26 +1,27 @@
import time import time
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, List
import numpy as np import numpy as np
import torch import torch
from torch import nn from torch import nn
from transformers import AutoModel, BitsAndBytesConfig
from tqdm import tqdm from tqdm import tqdm
from transformers import AutoModel
# Add MLX imports # Add MLX imports
try: try:
import mlx.core as mx import mlx.core as mx
from mlx_lm.utils import load from mlx_lm.utils import load
MLX_AVAILABLE = True MLX_AVAILABLE = True
except ImportError as e: except ImportError:
print("MLX not available. Install with: uv pip install mlx mlx-lm") print("MLX not available. Install with: uv pip install mlx mlx-lm")
MLX_AVAILABLE = False MLX_AVAILABLE = False
@dataclass @dataclass
class BenchmarkConfig: class BenchmarkConfig:
model_path: str = "facebook/contriever" model_path: str = "facebook/contriever"
batch_sizes: List[int] = None batch_sizes: list[int] = None
seq_length: int = 256 seq_length: int = 256
num_runs: int = 5 num_runs: int = 5
use_fp16: bool = True use_fp16: bool = True
@@ -35,6 +36,7 @@ class BenchmarkConfig:
if self.batch_sizes is None: if self.batch_sizes is None:
self.batch_sizes = [1, 2, 4, 8, 16, 32, 64] self.batch_sizes = [1, 2, 4, 8, 16, 32, 64]
class MLXBenchmark: class MLXBenchmark:
"""MLX-specific benchmark for embedding models""" """MLX-specific benchmark for embedding models"""
@@ -55,11 +57,7 @@ class MLXBenchmark:
def _create_random_batch(self, batch_size: int): def _create_random_batch(self, batch_size: int):
"""Create random input batches for MLX testing - same as PyTorch""" """Create random input batches for MLX testing - same as PyTorch"""
return torch.randint( return torch.randint(0, 1000, (batch_size, self.config.seq_length), dtype=torch.long)
0, 1000,
(batch_size, self.config.seq_length),
dtype=torch.long
)
def _run_inference(self, input_ids: torch.Tensor) -> float: def _run_inference(self, input_ids: torch.Tensor) -> float:
"""Run MLX inference with same input as PyTorch""" """Run MLX inference with same input as PyTorch"""
@@ -82,12 +80,12 @@ class MLXBenchmark:
except Exception as e: except Exception as e:
print(f"MLX inference error: {e}") print(f"MLX inference error: {e}")
return float('inf') return float("inf")
end_time = time.time() end_time = time.time()
return end_time - start_time return end_time - start_time
def run(self) -> Dict[int, Dict[str, float]]: def run(self) -> dict[int, dict[str, float]]:
"""Run the MLX benchmark across all batch sizes""" """Run the MLX benchmark across all batch sizes"""
results = {} results = {}
@@ -111,10 +109,10 @@ class MLXBenchmark:
break break
# Run benchmark # Run benchmark
for i in tqdm(range(self.config.num_runs), desc=f"MLX Batch size {batch_size}"): for _i in tqdm(range(self.config.num_runs), desc=f"MLX Batch size {batch_size}"):
try: try:
elapsed_time = self._run_inference(input_ids) elapsed_time = self._run_inference(input_ids)
if elapsed_time != float('inf'): if elapsed_time != float("inf"):
times.append(elapsed_time) times.append(elapsed_time)
except Exception as e: except Exception as e:
print(f"Error during MLX inference: {e}") print(f"Error during MLX inference: {e}")
@@ -145,16 +143,22 @@ class MLXBenchmark:
return results return results
class Benchmark: class Benchmark:
def __init__(self, config: BenchmarkConfig): def __init__(self, config: BenchmarkConfig):
self.config = config self.config = config
self.device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" self.device = (
"cuda"
if torch.cuda.is_available()
else "mps"
if torch.backends.mps.is_available()
else "cpu"
)
self.model = self._load_model() self.model = self._load_model()
def _load_model(self) -> nn.Module: def _load_model(self) -> nn.Module:
print(f"Loading model from {self.config.model_path}...") print(f"Loading model from {self.config.model_path}...")
model = AutoModel.from_pretrained(self.config.model_path) model = AutoModel.from_pretrained(self.config.model_path)
if self.config.use_fp16: if self.config.use_fp16:
model = model.half() model = model.half()
@@ -166,10 +170,11 @@ class Benchmark:
def _create_random_batch(self, batch_size: int) -> torch.Tensor: def _create_random_batch(self, batch_size: int) -> torch.Tensor:
return torch.randint( return torch.randint(
0, 1000, 0,
1000,
(batch_size, self.config.seq_length), (batch_size, self.config.seq_length),
device=self.device, device=self.device,
dtype=torch.long dtype=torch.long,
) )
def _run_inference(self, input_ids: torch.Tensor) -> float: def _run_inference(self, input_ids: torch.Tensor) -> float:
@@ -177,12 +182,12 @@ class Benchmark:
start_time = time.time() start_time = time.time()
with torch.no_grad(): with torch.no_grad():
output = self.model(input_ids=input_ids, attention_mask=attention_mask) self.model(input_ids=input_ids, attention_mask=attention_mask)
end_time = time.time() end_time = time.time()
return end_time - start_time return end_time - start_time
def run(self) -> Dict[int, Dict[str, float]]: def run(self) -> dict[int, dict[str, float]]:
results = {} results = {}
if torch.cuda.is_available(): if torch.cuda.is_available():
@@ -194,7 +199,7 @@ class Benchmark:
input_ids = self._create_random_batch(batch_size) input_ids = self._create_random_batch(batch_size)
for i in tqdm(range(self.config.num_runs), desc=f"Batch size {batch_size}"): for _i in tqdm(range(self.config.num_runs), desc=f"Batch size {batch_size}"):
try: try:
elapsed_time = self._run_inference(input_ids) elapsed_time = self._run_inference(input_ids)
times.append(elapsed_time) times.append(elapsed_time)
@@ -219,7 +224,7 @@ class Benchmark:
print(f"Throughput: {throughput:.2f} sequences/second") print(f"Throughput: {throughput:.2f} sequences/second")
if torch.cuda.is_available(): if torch.cuda.is_available():
peak_memory_gb = torch.cuda.max_memory_allocated() / (1024 ** 3) peak_memory_gb = torch.cuda.max_memory_allocated() / (1024**3)
else: else:
peak_memory_gb = 0.0 peak_memory_gb = 0.0
@@ -228,6 +233,7 @@ class Benchmark:
return results return results
def run_benchmark(): def run_benchmark():
"""Main function to run the benchmark with optimized parameters.""" """Main function to run the benchmark with optimized parameters."""
config = BenchmarkConfig() config = BenchmarkConfig()
@@ -242,16 +248,13 @@ def run_benchmark():
return { return {
"max_throughput": max_throughput, "max_throughput": max_throughput,
"avg_throughput": avg_throughput, "avg_throughput": avg_throughput,
"results": results "results": results,
} }
except Exception as e: except Exception as e:
print(f"Benchmark failed: {e}") print(f"Benchmark failed: {e}")
return { return {"max_throughput": 0.0, "avg_throughput": 0.0, "error": str(e)}
"max_throughput": 0.0,
"avg_throughput": 0.0,
"error": str(e)
}
def run_mlx_benchmark(): def run_mlx_benchmark():
"""Run MLX-specific benchmark""" """Run MLX-specific benchmark"""
@@ -260,13 +263,10 @@ def run_mlx_benchmark():
return { return {
"max_throughput": 0.0, "max_throughput": 0.0,
"avg_throughput": 0.0, "avg_throughput": 0.0,
"error": "MLX not available" "error": "MLX not available",
} }
config = BenchmarkConfig( config = BenchmarkConfig(model_path="mlx-community/all-MiniLM-L6-v2-4bit", use_mlx=True)
model_path="mlx-community/all-MiniLM-L6-v2-4bit",
use_mlx=True
)
try: try:
benchmark = MLXBenchmark(config) benchmark = MLXBenchmark(config)
@@ -276,7 +276,7 @@ def run_mlx_benchmark():
return { return {
"max_throughput": 0.0, "max_throughput": 0.0,
"avg_throughput": 0.0, "avg_throughput": 0.0,
"error": "No valid results" "error": "No valid results",
} }
max_throughput = max(results[batch_size]["throughput"] for batch_size in results) max_throughput = max(results[batch_size]["throughput"] for batch_size in results)
@@ -285,16 +285,13 @@ def run_mlx_benchmark():
return { return {
"max_throughput": max_throughput, "max_throughput": max_throughput,
"avg_throughput": avg_throughput, "avg_throughput": avg_throughput,
"results": results "results": results,
} }
except Exception as e: except Exception as e:
print(f"MLX benchmark failed: {e}") print(f"MLX benchmark failed: {e}")
return { return {"max_throughput": 0.0, "avg_throughput": 0.0, "error": str(e)}
"max_throughput": 0.0,
"avg_throughput": 0.0,
"error": str(e)
}
if __name__ == "__main__": if __name__ == "__main__":
print("=== PyTorch Benchmark ===") print("=== PyTorch Benchmark ===")
@@ -308,7 +305,7 @@ if __name__ == "__main__":
print(f"MLX Average throughput: {mlx_result['avg_throughput']:.2f} sequences/second") print(f"MLX Average throughput: {mlx_result['avg_throughput']:.2f} sequences/second")
# Compare results # Compare results
if pytorch_result['max_throughput'] > 0 and mlx_result['max_throughput'] > 0: if pytorch_result["max_throughput"] > 0 and mlx_result["max_throughput"] > 0:
speedup = mlx_result['max_throughput'] / pytorch_result['max_throughput'] speedup = mlx_result["max_throughput"] / pytorch_result["max_throughput"]
print(f"\n=== Comparison ===") print("\n=== Comparison ===")
print(f"MLX is {speedup:.2f}x {'faster' if speedup > 1 else 'slower'} than PyTorch") print(f"MLX is {speedup:.2f}x {'faster' if speedup > 1 else 'slower'} than PyTorch")

87
tests/README.md Normal file
View File

@@ -0,0 +1,87 @@
# LEANN Tests
This directory contains automated tests for the LEANN project using pytest.
## Test Files
### `test_readme_examples.py`
Tests the examples shown in README.md:
- The basic example code that users see first
- Import statements work correctly
- Different backend options (HNSW, DiskANN)
- Different LLM configuration options
### `test_basic.py`
Basic functionality tests that verify:
- All packages can be imported correctly
- C++ extensions (FAISS, DiskANN) load properly
- Basic index building and searching works for both HNSW and DiskANN backends
- Uses parametrized tests to test both backends
### `test_main_cli.py`
Tests the main CLI example functionality:
- Tests with facebook/contriever embeddings
- Tests with OpenAI embeddings (if API key is available)
- Tests error handling with invalid parameters
- Verifies that normalized embeddings are detected and cosine distance is used
## Running Tests
### Install test dependencies:
```bash
# Using extras
uv pip install -e ".[test]"
```
### Run all tests:
```bash
pytest tests/
# Or with coverage
pytest tests/ --cov=leann --cov-report=html
# Run in parallel (faster)
pytest tests/ -n auto
```
### Run specific tests:
```bash
# Only basic tests
pytest tests/test_basic.py
# Only tests that don't require OpenAI
pytest tests/ -m "not openai"
# Skip slow tests
pytest tests/ -m "not slow"
```
### Run with specific backend:
```bash
# Test only HNSW backend
pytest tests/test_basic.py::test_backend_basic[hnsw]
# Test only DiskANN backend
pytest tests/test_basic.py::test_backend_basic[diskann]
```
## CI/CD Integration
Tests are automatically run in GitHub Actions:
1. After building wheel packages
2. On multiple Python versions (3.9 - 3.13)
3. On both Ubuntu and macOS
4. Using pytest with appropriate markers and flags
### pytest.ini Configuration
The `pytest.ini` file configures:
- Test discovery paths
- Default timeout (600 seconds)
- Environment variables (HF_HUB_DISABLE_SYMLINKS, TOKENIZERS_PARALLELISM)
- Custom markers for slow and OpenAI tests
- Verbose output with short tracebacks
### Known Issues
- OpenAI tests are automatically skipped if no API key is provided

92
tests/test_basic.py Normal file
View File

@@ -0,0 +1,92 @@
"""
Basic functionality tests for CI pipeline using pytest.
"""
import os
import tempfile
from pathlib import Path
import pytest
def test_imports():
"""Test that all packages can be imported."""
# Test C++ extensions
@pytest.mark.skipif(
os.environ.get("CI") == "true", reason="Skip model tests in CI to avoid MPS memory issues"
)
@pytest.mark.parametrize("backend_name", ["hnsw", "diskann"])
def test_backend_basic(backend_name):
"""Test basic functionality for each backend."""
from leann.api import LeannBuilder, LeannSearcher, SearchResult
# Create temporary directory for index
with tempfile.TemporaryDirectory() as temp_dir:
index_path = str(Path(temp_dir) / f"test.{backend_name}")
# Test with small data
texts = [f"This is document {i} about topic {i % 5}" for i in range(100)]
# Configure builder based on backend
if backend_name == "hnsw":
builder = LeannBuilder(
backend_name="hnsw",
embedding_model="facebook/contriever",
embedding_mode="sentence-transformers",
M=16,
efConstruction=200,
)
else: # diskann
builder = LeannBuilder(
backend_name="diskann",
embedding_model="facebook/contriever",
embedding_mode="sentence-transformers",
num_neighbors=32,
search_list_size=50,
)
# Add texts
for text in texts:
builder.add_text(text)
# Build index
builder.build_index(index_path)
# Test search
searcher = LeannSearcher(index_path)
results = searcher.search("document about topic 2", top_k=5)
# Verify results
assert len(results) > 0
assert isinstance(results[0], SearchResult)
assert "topic 2" in results[0].text or "document" in results[0].text
@pytest.mark.skipif(
os.environ.get("CI") == "true", reason="Skip model tests in CI to avoid MPS memory issues"
)
def test_large_index():
"""Test with larger dataset."""
from leann.api import LeannBuilder, LeannSearcher
with tempfile.TemporaryDirectory() as temp_dir:
index_path = str(Path(temp_dir) / "test_large.hnsw")
texts = [f"Document {i}: {' '.join([f'word{j}' for j in range(50)])}" for i in range(1000)]
builder = LeannBuilder(
backend_name="hnsw",
embedding_model="facebook/contriever",
embedding_mode="sentence-transformers",
)
for text in texts:
builder.add_text(text)
builder.build_index(index_path)
searcher = LeannSearcher(index_path)
results = searcher.search(["word10 word20"], top_k=10)
assert len(results[0]) == 10

49
tests/test_ci_minimal.py Normal file
View File

@@ -0,0 +1,49 @@
"""
Minimal tests for CI that don't require model loading or significant memory.
"""
import subprocess
import sys
def test_package_imports():
"""Test that all core packages can be imported."""
# Core package
# Backend packages
# Core modules
assert True # If we get here, imports worked
def test_cli_help():
"""Test that CLI example shows help."""
result = subprocess.run(
[sys.executable, "examples/main_cli_example.py", "--help"], capture_output=True, text=True
)
assert result.returncode == 0
assert "usage:" in result.stdout.lower() or "usage:" in result.stderr.lower()
assert "--llm" in result.stdout or "--llm" in result.stderr
def test_backend_registration():
"""Test that backends are properly registered."""
from leann.api import get_registered_backends
backends = get_registered_backends()
assert "hnsw" in backends
assert "diskann" in backends
def test_version_info():
"""Test that packages have version information."""
import leann
import leann_backend_diskann
import leann_backend_hnsw
# Check that packages have __version__ or can be imported
assert hasattr(leann, "__version__") or True
assert hasattr(leann_backend_hnsw, "__version__") or True
assert hasattr(leann_backend_diskann, "__version__") or True

120
tests/test_main_cli.py Normal file
View File

@@ -0,0 +1,120 @@
"""
Test main_cli_example functionality using pytest.
"""
import os
import subprocess
import sys
import tempfile
from pathlib import Path
import pytest
@pytest.fixture
def test_data_dir():
"""Return the path to test data directory."""
return Path("examples/data")
@pytest.mark.skipif(
os.environ.get("CI") == "true", reason="Skip model tests in CI to avoid MPS memory issues"
)
def test_main_cli_simulated(test_data_dir):
"""Test main_cli with simulated LLM."""
with tempfile.TemporaryDirectory() as temp_dir:
# Use a subdirectory that doesn't exist yet to force index creation
index_dir = Path(temp_dir) / "test_index"
cmd = [
sys.executable,
"examples/main_cli_example.py",
"--llm",
"simulated",
"--embedding-model",
"facebook/contriever",
"--embedding-mode",
"sentence-transformers",
"--index-dir",
str(index_dir),
"--data-dir",
str(test_data_dir),
"--query",
"What is Pride and Prejudice about?",
]
env = os.environ.copy()
env["HF_HUB_DISABLE_SYMLINKS"] = "1"
env["TOKENIZERS_PARALLELISM"] = "false"
result = subprocess.run(cmd, capture_output=True, text=True, timeout=600, env=env)
# Check return code
assert result.returncode == 0, f"Command failed: {result.stderr}"
# Verify output
output = result.stdout + result.stderr
assert "Leann index built at" in output or "Using existing index" in output
assert "This is a simulated answer" in output
@pytest.mark.skipif(not os.environ.get("OPENAI_API_KEY"), reason="OpenAI API key not available")
def test_main_cli_openai(test_data_dir):
"""Test main_cli with OpenAI embeddings."""
with tempfile.TemporaryDirectory() as temp_dir:
# Use a subdirectory that doesn't exist yet to force index creation
index_dir = Path(temp_dir) / "test_index_openai"
cmd = [
sys.executable,
"examples/main_cli_example.py",
"--llm",
"simulated", # Use simulated LLM to avoid GPT-4 costs
"--embedding-model",
"text-embedding-3-small",
"--embedding-mode",
"openai",
"--index-dir",
str(index_dir),
"--data-dir",
str(test_data_dir),
"--query",
"What is Pride and Prejudice about?",
]
env = os.environ.copy()
env["TOKENIZERS_PARALLELISM"] = "false"
result = subprocess.run(cmd, capture_output=True, text=True, timeout=600, env=env)
assert result.returncode == 0, f"Command failed: {result.stderr}"
# Verify cosine distance was used
output = result.stdout + result.stderr
assert any(
msg in output
for msg in [
"distance_metric='cosine'",
"Automatically setting distance_metric='cosine'",
"Using cosine distance",
]
)
def test_main_cli_error_handling(test_data_dir):
"""Test main_cli with invalid parameters."""
with tempfile.TemporaryDirectory() as temp_dir:
cmd = [
sys.executable,
"examples/main_cli_example.py",
"--llm",
"invalid_llm_type",
"--index-dir",
temp_dir,
"--data-dir",
str(test_data_dir),
]
result = subprocess.run(cmd, capture_output=True, text=True, timeout=60)
# Should fail with invalid LLM type
assert result.returncode != 0
assert "Unknown LLM type" in result.stderr or "invalid_llm_type" in result.stderr

View File

@@ -0,0 +1,165 @@
"""
Test examples from README.md to ensure documentation is accurate.
"""
import os
import platform
import tempfile
from pathlib import Path
import pytest
def test_readme_basic_example():
"""Test the basic example from README.md."""
# Skip on macOS CI due to MPS environment issues with all-MiniLM-L6-v2
if os.environ.get("CI") == "true" and platform.system() == "Darwin":
pytest.skip("Skipping on macOS CI due to MPS environment issues with all-MiniLM-L6-v2")
# This is the exact code from README (with smaller model for CI)
from leann import LeannBuilder, LeannChat, LeannSearcher
from leann.api import SearchResult
with tempfile.TemporaryDirectory() as temp_dir:
INDEX_PATH = str(Path(temp_dir) / "demo.leann")
# Build an index
# In CI, use a smaller model to avoid memory issues
if os.environ.get("CI") == "true":
builder = LeannBuilder(
backend_name="hnsw",
embedding_model="sentence-transformers/all-MiniLM-L6-v2", # Smaller model
dimensions=384, # Smaller dimensions
)
else:
builder = LeannBuilder(backend_name="hnsw")
builder.add_text("LEANN saves 97% storage compared to traditional vector databases.")
builder.add_text("Tung Tung Tung Sahur called—they need their banana-crocodile hybrid back")
builder.build_index(INDEX_PATH)
# Verify index was created
# The index path should be a directory containing index files
index_dir = Path(INDEX_PATH).parent
assert index_dir.exists()
# Check that index files were created
index_files = list(index_dir.glob(f"{Path(INDEX_PATH).stem}.*"))
assert len(index_files) > 0
# Search
searcher = LeannSearcher(INDEX_PATH)
results = searcher.search("fantastical AI-generated creatures", top_k=1)
# Verify search results
assert len(results) > 0
assert isinstance(results[0], SearchResult)
# The second text about banana-crocodile should be more relevant
assert "banana" in results[0].text or "crocodile" in results[0].text
# Chat with your data (using simulated LLM to avoid external dependencies)
chat = LeannChat(INDEX_PATH, llm_config={"type": "simulated"})
response = chat.ask("How much storage does LEANN save?", top_k=1)
# Verify chat works
assert isinstance(response, str)
assert len(response) > 0
def test_readme_imports():
"""Test that the imports shown in README work correctly."""
# These are the imports shown in README
from leann import LeannBuilder, LeannChat, LeannSearcher
# Verify they are the correct types
assert callable(LeannBuilder)
assert callable(LeannSearcher)
assert callable(LeannChat)
def test_backend_options():
"""Test different backend options mentioned in documentation."""
# Skip on macOS CI due to MPS environment issues with all-MiniLM-L6-v2
if os.environ.get("CI") == "true" and platform.system() == "Darwin":
pytest.skip("Skipping on macOS CI due to MPS environment issues with all-MiniLM-L6-v2")
from leann import LeannBuilder
with tempfile.TemporaryDirectory() as temp_dir:
# Use smaller model in CI to avoid memory issues
if os.environ.get("CI") == "true":
model_args = {
"embedding_model": "sentence-transformers/all-MiniLM-L6-v2",
"dimensions": 384,
}
else:
model_args = {}
# Test HNSW backend (as shown in README)
hnsw_path = str(Path(temp_dir) / "test_hnsw.leann")
builder_hnsw = LeannBuilder(backend_name="hnsw", **model_args)
builder_hnsw.add_text("Test document for HNSW backend")
builder_hnsw.build_index(hnsw_path)
assert Path(hnsw_path).parent.exists()
assert len(list(Path(hnsw_path).parent.glob(f"{Path(hnsw_path).stem}.*"))) > 0
# Test DiskANN backend (mentioned as available option)
diskann_path = str(Path(temp_dir) / "test_diskann.leann")
builder_diskann = LeannBuilder(backend_name="diskann", **model_args)
builder_diskann.add_text("Test document for DiskANN backend")
builder_diskann.build_index(diskann_path)
assert Path(diskann_path).parent.exists()
assert len(list(Path(diskann_path).parent.glob(f"{Path(diskann_path).stem}.*"))) > 0
def test_llm_config_simulated():
"""Test simulated LLM configuration option."""
# Skip on macOS CI due to MPS environment issues with all-MiniLM-L6-v2
if os.environ.get("CI") == "true" and platform.system() == "Darwin":
pytest.skip("Skipping on macOS CI due to MPS environment issues with all-MiniLM-L6-v2")
from leann import LeannBuilder, LeannChat
with tempfile.TemporaryDirectory() as temp_dir:
# Build a simple index
index_path = str(Path(temp_dir) / "test.leann")
# Use smaller model in CI to avoid memory issues
if os.environ.get("CI") == "true":
builder = LeannBuilder(
backend_name="hnsw",
embedding_model="sentence-transformers/all-MiniLM-L6-v2",
dimensions=384,
)
else:
builder = LeannBuilder(backend_name="hnsw")
builder.add_text("Test document for LLM testing")
builder.build_index(index_path)
# Test simulated LLM config
llm_config = {"type": "simulated"}
chat = LeannChat(index_path, llm_config=llm_config)
response = chat.ask("What is this document about?", top_k=1)
assert isinstance(response, str)
assert len(response) > 0
@pytest.mark.skip(reason="Requires HF model download and may timeout")
def test_llm_config_hf():
"""Test HuggingFace LLM configuration option."""
from leann import LeannBuilder, LeannChat
pytest.importorskip("transformers") # Skip if transformers not installed
with tempfile.TemporaryDirectory() as temp_dir:
# Build a simple index
index_path = str(Path(temp_dir) / "test.leann")
builder = LeannBuilder(backend_name="hnsw")
builder.add_text("Test document for LLM testing")
builder.build_index(index_path)
# Test HF LLM config
llm_config = {"type": "hf", "model": "Qwen/Qwen3-0.6B"}
chat = LeannChat(index_path, llm_config=llm_config)
response = chat.ask("What is this document about?", top_k=1)
assert isinstance(response, str)
assert len(response) > 0

3212
uv.lock generated
View File

File diff suppressed because it is too large Load Diff