Compare commits

..

17 Commits

Author SHA1 Message Date
Andy Lee
fd5c052bd8 Update faiss for batch distances calc & caching when updating 2025-09-30 12:40:05 -07:00
Andy Lee
2f77d0185c Merge remote-tracking branch 'origin/main' into fix-update 2025-09-30 00:56:27 -07:00
Andy Lee
82d536b2ae fix: launch embedding server before adding 2025-09-30 00:53:22 -07:00
yichuan520030910320
e2b37914ce add dynamic add test 2025-09-30 00:48:46 -07:00
Andy Lee
e588100674 fix: set ntotal for storage as well (#129) 2025-09-29 20:43:16 -07:00
Andy Lee
f42e086383 fix: set ntotal for storage as well 2025-09-29 19:10:09 -07:00
Andy Lee
fecee94af1 Experiments (#68)
* feat: finance bench

* docs: results

* chore: ignroe data README

* feat: fix financebench

* feat: laion, also required idmaps support

* style: format

* style: format

* fix: resolve ruff linting errors

- Remove unused variables in benchmark scripts
- Rename unused loop variables to follow convention

* feat: enron email bench

* experiments for running DiskANN & BM25 on Arch 4090

* style: format

* chore(ci): remove paru-bin submodule and config to fix checkout --recurse-submodules

* docs: data

* docs: data updated

* fix: as package

* fix(ci): only run pre-commit

* chore: use http url of astchunk; use group for some dev deps

* fix(ci): should checkout modules as well since `uv sync` checks

* fix(ci): run with lint only

* fix: find links to install wheels available

* CI: force local wheels in uv install step

* CI: install local wheels via file paths

* CI: pick wheels matching current Python tag

* CI: handle python tag mismatches for local wheels

* CI: use matrix python venv and set macOS deployment target

* CI: revert install step to match main

* CI: use uv group install with local wheel selection

* CI: rely on setup-uv for Python and tighten group install

* CI: install build deps with uv python interpreter

* CI: use temporary uv venv for build deps

* CI: add build venv scripts path for wheel repair
2025-09-24 11:19:04 -07:00
yichuan520030910320
01475c10a0 add img 2025-09-23 23:25:05 -07:00
yichuan520030910320
c8aa063f48 merge main 2025-09-23 23:21:53 -07:00
yichuan520030910320
576beb13db add doc about multimodal 2025-09-23 23:21:03 -07:00
Andy Lee
63c7b0c8a3 Fix restart embedding server when passages change (#117)
* fix: restart embedding server when passages change

* fix: restore python 3.9 typing compatibility
2025-09-23 22:28:36 -07:00
Andy Lee
ec889f7ef4 Allow 'leann ask' to accept a positional question (#116) 2025-09-23 21:18:57 -07:00
Yi-Ting Chiu
322e5c162d docs: open ai api compatibility (#118) 2025-09-23 21:17:50 -07:00
Yichuan Wang
edde0cdeb2 [Feat] ColQwen intergration (#111)
* add colqwen stuff

* add colqwen stuff and pass ruff

* remove ipynb
2025-09-23 17:51:29 -07:00
Andy Lee
db7ba27ff6 feat: Add support for configurable local LLM endpoints (#115)
* feat: support configurable local llm endpoints

* docs
2025-09-23 15:12:13 -07:00
Andy Lee
5f7806e16f Introducing dynamic index update (#108)
* feat: Add GitHub PR and issue templates for better contributor experience

* simplify: Make templates more concise and user-friendly

* fix: enable is_compact=False, is_recompute=True

* feat: update when recompute

* test

* fix: real recompute

* refactor

* fix: compare with no-recompute

* fix: test
2025-09-21 22:56:27 -07:00
yichuan-w
d034e2195b fix build from source in diskann 2025-09-20 19:52:29 +00:00
53 changed files with 12113 additions and 5049 deletions

View File

@@ -17,26 +17,17 @@ jobs:
- uses: actions/checkout@v4
with:
ref: ${{ inputs.ref }}
submodules: recursive
- name: Setup Python
uses: actions/setup-python@v5
- name: Install uv and Python
uses: astral-sh/setup-uv@v6
with:
python-version: '3.11'
- name: Install uv
uses: astral-sh/setup-uv@v4
- name: Install ruff
- name: Run pre-commit with only lint group (no project deps)
run: |
uv tool install ruff
uv run --only-group lint pre-commit run --all-files --show-diff-on-failure
- name: Run ruff check
run: |
ruff check .
- name: Run ruff format check
run: |
ruff format --check .
build:
needs: lint
@@ -103,14 +94,11 @@ jobs:
ref: ${{ inputs.ref }}
submodules: recursive
- name: Setup Python
uses: actions/setup-python@v5
- name: Install uv and Python
uses: astral-sh/setup-uv@v6
with:
python-version: ${{ matrix.python }}
- name: Install uv
uses: astral-sh/setup-uv@v6
- name: Install system dependencies (Ubuntu)
if: runner.os == 'Linux'
run: |
@@ -168,11 +156,24 @@ jobs:
- name: Install build dependencies
run: |
uv pip install --system scikit-build-core numpy swig Cython pybind11
if [[ "$RUNNER_OS" == "Linux" ]]; then
uv pip install --system auditwheel
uv python install ${{ matrix.python }}
uv venv --python ${{ matrix.python }} .uv-build
if [[ "$RUNNER_OS" == "Windows" ]]; then
BUILD_PY=".uv-build\\Scripts\\python.exe"
else
uv pip install --system delocate
BUILD_PY=".uv-build/bin/python"
fi
uv pip install --python "$BUILD_PY" scikit-build-core numpy swig Cython pybind11
if [[ "$RUNNER_OS" == "Linux" ]]; then
uv pip install --python "$BUILD_PY" auditwheel
else
uv pip install --python "$BUILD_PY" delocate
fi
if [[ "$RUNNER_OS" == "Windows" ]]; then
echo "$(pwd)\\.uv-build\\Scripts" >> $GITHUB_PATH
else
echo "$(pwd)/.uv-build/bin" >> $GITHUB_PATH
fi
- name: Set macOS environment variables
@@ -308,18 +309,66 @@ jobs:
- name: Install built packages for testing
run: |
# Create a virtual environment with the correct Python version
# Create uv-managed virtual environment with the requested interpreter
uv python install ${{ matrix.python }}
uv venv --python ${{ matrix.python }}
source .venv/bin/activate || source .venv/Scripts/activate
# Install packages using --find-links to prioritize local builds
uv pip install --find-links packages/leann-core/dist --find-links packages/leann-backend-hnsw/dist --find-links packages/leann-backend-diskann/dist packages/leann-core/dist/*.whl || uv pip install --find-links packages/leann-core/dist packages/leann-core/dist/*.tar.gz
uv pip install --find-links packages/leann-core/dist packages/leann-backend-hnsw/dist/*.whl
uv pip install --find-links packages/leann-core/dist packages/leann-backend-diskann/dist/*.whl
uv pip install packages/leann/dist/*.whl || uv pip install packages/leann/dist/*.tar.gz
if [[ "$RUNNER_OS" == "Windows" ]]; then
UV_PY=".venv\\Scripts\\python.exe"
else
UV_PY=".venv/bin/python"
fi
# Install test dependencies using extras
uv pip install -e ".[test]"
# Install test dependency group only (avoids reinstalling project package)
uv pip install --python "$UV_PY" --group test
# Install core wheel built in this job
CORE_WHL=$(find packages/leann-core/dist -maxdepth 1 -name "*.whl" -print -quit)
if [[ -n "$CORE_WHL" ]]; then
uv pip install --python "$UV_PY" "$CORE_WHL"
else
uv pip install --python "$UV_PY" packages/leann-core/dist/*.tar.gz
fi
PY_TAG=$($UV_PY -c "import sys; print(f'cp{sys.version_info[0]}{sys.version_info[1]}')")
if [[ "$RUNNER_OS" == "macOS" ]]; then
if [[ "${{ matrix.os }}" == "macos-13" ]]; then
export MACOSX_DEPLOYMENT_TARGET=13.3
elif [[ "${{ matrix.os }}" == "macos-14" ]]; then
export MACOSX_DEPLOYMENT_TARGET=14.0
elif [[ "${{ matrix.os }}" == "macos-15" ]]; then
export MACOSX_DEPLOYMENT_TARGET=15.0
fi
fi
HNSW_WHL=$(find packages/leann-backend-hnsw/dist -maxdepth 1 -name "*-${PY_TAG}-*.whl" -print -quit)
if [[ -z "$HNSW_WHL" ]]; then
HNSW_WHL=$(find packages/leann-backend-hnsw/dist -maxdepth 1 -name "*-py3-*.whl" -print -quit)
fi
if [[ -n "$HNSW_WHL" ]]; then
uv pip install --python "$UV_PY" "$HNSW_WHL"
else
uv pip install --python "$UV_PY" ./packages/leann-backend-hnsw
fi
DISKANN_WHL=$(find packages/leann-backend-diskann/dist -maxdepth 1 -name "*-${PY_TAG}-*.whl" -print -quit)
if [[ -z "$DISKANN_WHL" ]]; then
DISKANN_WHL=$(find packages/leann-backend-diskann/dist -maxdepth 1 -name "*-py3-*.whl" -print -quit)
fi
if [[ -n "$DISKANN_WHL" ]]; then
uv pip install --python "$UV_PY" "$DISKANN_WHL"
else
uv pip install --python "$UV_PY" ./packages/leann-backend-diskann
fi
LEANN_WHL=$(find packages/leann/dist -maxdepth 1 -name "*.whl" -print -quit)
if [[ -n "$LEANN_WHL" ]]; then
uv pip install --python "$UV_PY" "$LEANN_WHL"
else
uv pip install --python "$UV_PY" packages/leann/dist/*.tar.gz
fi
- name: Run tests with pytest
env:

17
.gitignore vendored
View File

@@ -18,6 +18,7 @@ demo/experiment_results/**/*.json
*.eml
*.emlx
*.json
*.png
!.vscode/*.json
*.sh
*.txt
@@ -94,11 +95,13 @@ packages/leann-backend-diskann/third_party/DiskANN/_deps/
batchtest.py
tests/__pytest_cache__/
tests/__pycache__/
paru-bin/
CLAUDE.md
CLAUDE.local.md
.claude/*.local.*
.claude/local/*
benchmarks/data/
test_add/*
## multi vector
apps/multimodal/vision-based-pdf-multi-vector/multi-vector-colpali-native-weaviate.py
# Ignore all PDFs (keep data exceptions above) and do not track demo PDFs
# If you need to commit a specific demo PDF, remove this negation locally.
# The following line used to force-add a large demo PDF; remove it to satisfy pre-commit:
# !apps/multimodal/vision-based-pdf-multi-vector/pdfs/2004.12832v2.pdf
!apps/multimodal/vision-based-pdf-multi-vector/fig/*

View File

@@ -182,7 +182,10 @@ LEANN supports RAG on various data sources including documents (`.pdf`, `.txt`,
### Generation Model Setup
LEANN supports multiple LLM providers for text generation (OpenAI API, HuggingFace, Ollama).
#### LLM Backend
LEANN supports many LLM providers for text generation (HuggingFace, Ollama, and Any OpenAI compatible API).
<details>
<summary><strong>🔑 OpenAI API Setup (Default)</strong></summary>
@@ -193,6 +196,68 @@ Set your OpenAI API key as an environment variable:
export OPENAI_API_KEY="your-api-key-here"
```
Make sure to use `--llm openai` flag when using the CLI.
You can also specify the model name with `--llm-model <model-name>` flag.
</details>
<details>
<summary><strong>🛠️ Supported LLM & Embedding Providers (via OpenAI Compatibility)</strong></summary>
Thanks to the widespread adoption of the OpenAI API format, LEANN is compatible out-of-the-box with a vast array of LLM and embedding providers. Simply set the `OPENAI_BASE_URL` and `OPENAI_API_KEY` environment variables to connect to your preferred service.
```sh
export OPENAI_API_KEY="xxx"
export OPENAI_BASE_URL="http://localhost:1234/v1" # base url of the provider
```
To use OpenAI compatible endpoint with the CLI interface:
If you are using it for text generation, make sure to use `--llm openai` flag and specify the model name with `--llm-model <model-name>` flag.
If you are using it for embedding, set the `--embedding-mode openai` flag and specify the model name with `--embedding-model <MODEL>`.
-----
Below is a list of base URLs for common providers to get you started.
### 🖥️ Local Inference Engines (Recommended for full privacy)
| Provider | Sample Base URL |
| ---------------- | --------------------------- |
| **Ollama** | `http://localhost:11434/v1` |
| **LM Studio** | `http://localhost:1234/v1` |
| **vLLM** | `http://localhost:8000/v1` |
| **llama.cpp** | `http://localhost:8080/v1` |
| **SGLang** | `http://localhost:30000/v1` |
| **LiteLLM** | `http://localhost:4000` |
-----
### ☁️ Cloud Providers
> **🚨 A Note on Privacy:** Before choosing a cloud provider, carefully review their privacy and data retention policies. Depending on their terms, your data may be used for their own purposes, including but not limited to human reviews and model training, which can lead to serious consequences if not handled properly.
| Provider | Base URL |
| ---------------- | ---------------------------------------------------------- |
| **OpenAI** | `https://api.openai.com/v1` |
| **OpenRouter** | `https://openrouter.ai/api/v1` |
| **Gemini** | `https://generativelanguage.googleapis.com/v1beta/openai/` |
| **x.AI (Grok)** | `https://api.x.ai/v1` |
| **Groq AI** | `https://api.groq.com/openai/v1` |
| **DeepSeek** | `https://api.deepseek.com/v1` |
| **SiliconFlow** | `https://api.siliconflow.cn/v1` |
| **Zhipu (BigModel)** | `https://open.bigmodel.cn/api/paas/v4/` |
| **Mistral AI** | `https://api.mistral.ai/v1` |
If your provider isn't on this list, don't worry! Check their documentation for an OpenAI-compatible endpoint—chances are, it's OpenAI Compatible too!
</details>
<details>
@@ -546,6 +611,9 @@ leann search my-docs "machine learning concepts"
# Interactive chat with your documents
leann ask my-docs --interactive
# Ask a single question (non-interactive)
leann ask my-docs "Where are prompts configured?"
# List all your indexes
leann list
@@ -706,9 +774,8 @@ results = searcher.search("bananacrocodile", use_grep=True, top_k=1)
## Reproduce Our Results
```bash
uv pip install -e ".[dev]" # Install dev dependencies
python benchmarks/run_evaluation.py # Will auto-download evaluation data and run benchmarks
python benchmarks/run_evaluation.py benchmarks/data/indices/rpj_wiki/rpj_wiki --num-queries 2000 # After downloading data, you can run the benchmark with our biggest index
uv run benchmarks/run_evaluation.py # Will auto-download evaluation data and run benchmarks
uv run benchmarks/run_evaluation.py benchmarks/data/indices/rpj_wiki/rpj_wiki --num-queries 2000 # After downloading data, you can run the benchmark with our biggest index
```
The evaluation script downloads data automatically on first run. The last three results were tested with partial personal data, and you can reproduce them with your own data!

View File

@@ -11,6 +11,7 @@ from typing import Any
import dotenv
from leann.api import LeannBuilder, LeannChat
from leann.registry import register_project_directory
from leann.settings import resolve_ollama_host, resolve_openai_api_key, resolve_openai_base_url
dotenv.load_dotenv()
@@ -78,6 +79,24 @@ class BaseRAGExample(ABC):
choices=["sentence-transformers", "openai", "mlx", "ollama"],
help="Embedding backend mode (default: sentence-transformers), we provide sentence-transformers, openai, mlx, or ollama",
)
embedding_group.add_argument(
"--embedding-host",
type=str,
default=None,
help="Override Ollama-compatible embedding host",
)
embedding_group.add_argument(
"--embedding-api-base",
type=str,
default=None,
help="Base URL for OpenAI-compatible embedding services",
)
embedding_group.add_argument(
"--embedding-api-key",
type=str,
default=None,
help="API key for embedding service (defaults to OPENAI_API_KEY)",
)
# LLM parameters
llm_group = parser.add_argument_group("LLM Parameters")
@@ -97,8 +116,8 @@ class BaseRAGExample(ABC):
llm_group.add_argument(
"--llm-host",
type=str,
default="http://localhost:11434",
help="Host for Ollama API (default: http://localhost:11434)",
default=None,
help="Host for Ollama-compatible APIs (defaults to LEANN_OLLAMA_HOST/OLLAMA_HOST)",
)
llm_group.add_argument(
"--thinking-budget",
@@ -107,6 +126,18 @@ class BaseRAGExample(ABC):
default=None,
help="Thinking budget for reasoning models (low/medium/high). Supported by GPT-Oss:20b and other reasoning models.",
)
llm_group.add_argument(
"--llm-api-base",
type=str,
default=None,
help="Base URL for OpenAI-compatible APIs",
)
llm_group.add_argument(
"--llm-api-key",
type=str,
default=None,
help="API key for OpenAI-compatible APIs (defaults to OPENAI_API_KEY)",
)
# AST Chunking parameters
ast_group = parser.add_argument_group("AST Chunking Parameters")
@@ -205,9 +236,13 @@ class BaseRAGExample(ABC):
if args.llm == "openai":
config["model"] = args.llm_model or "gpt-4o"
config["base_url"] = resolve_openai_base_url(args.llm_api_base)
resolved_key = resolve_openai_api_key(args.llm_api_key)
if resolved_key:
config["api_key"] = resolved_key
elif args.llm == "ollama":
config["model"] = args.llm_model or "llama3.2:1b"
config["host"] = args.llm_host
config["host"] = resolve_ollama_host(args.llm_host)
elif args.llm == "hf":
config["model"] = args.llm_model or "Qwen/Qwen2.5-1.5B-Instruct"
elif args.llm == "simulated":
@@ -223,10 +258,20 @@ class BaseRAGExample(ABC):
print(f"\n[Building Index] Creating {self.name} index...")
print(f"Total text chunks: {len(texts)}")
embedding_options: dict[str, Any] = {}
if args.embedding_mode == "ollama":
embedding_options["host"] = resolve_ollama_host(args.embedding_host)
elif args.embedding_mode == "openai":
embedding_options["base_url"] = resolve_openai_base_url(args.embedding_api_base)
resolved_embedding_key = resolve_openai_api_key(args.embedding_api_key)
if resolved_embedding_key:
embedding_options["api_key"] = resolved_embedding_key
builder = LeannBuilder(
backend_name=args.backend_name,
embedding_model=args.embedding_model,
embedding_mode=args.embedding_mode,
embedding_options=embedding_options or None,
graph_degree=args.graph_degree,
complexity=args.build_complexity,
is_compact=not args.no_compact,

View File

@@ -0,0 +1,113 @@
## Vision-based PDF Multi-Vector Demos (macOS/MPS)
This folder contains two demos to index PDF pages as images and run multi-vector retrieval with ColPali/ColQwen2, plus optional similarity map visualization and answer generation.
### What youll run
- `multi-vector-leann-paper-example.py`: local PDF → pages → embed → build HNSW index → search.
- `multi-vector-leann-similarity-map.py`: HF dataset (default) or local pages → embed → index → retrieve → similarity maps → optional Qwen-VL answer.
## Prerequisites (macOS)
### 1) Homebrew poppler (for pdf2image)
```bash
brew install poppler
which pdfinfo && pdfinfo -v
```
### 2) Python environment
Use uv (recommended) or pip. Python 3.9+.
Using uv:
```bash
uv pip install \
colpali_engine \
pdf2image \
pillow \
matplotlib qwen_vl_utils \
einops \
seaborn
```
Notes:
- On first run, models download from Hugging Face. Login/config if needed.
- The scripts auto-select device: CUDA > MPS > CPU. Verify MPS:
```bash
python -c "import torch; print('MPS available:', bool(getattr(torch.backends, 'mps', None) and torch.backends.mps.is_available()))"
```
## Run the demos
### A) Local PDF example
Converts a local PDF into page images, embeds them, builds an index, and searches.
```bash
cd apps/multimodal/vision-based-pdf-multi-vector
# If you don't have the sample PDF locally, download it (ignored by Git)
mkdir -p pdfs
curl -L -o pdfs/2004.12832v2.pdf https://arxiv.org/pdf/2004.12832.pdf
ls pdfs/2004.12832v2.pdf
# Ensure output dir exists
mkdir -p pages
python multi-vector-leann-paper-example.py
```
Expected:
- Page images in `pages/`.
- Console prints like `Using device=mps, dtype=...` and retrieved file paths for queries.
To use your own PDF: edit `pdf_path` near the top of the script.
### B) Similarity map + answer demo
Uses HF dataset `weaviate/arXiv-AI-papers-multi-vector` by default; can switch to local pages.
```bash
cd apps/multimodal/vision-based-pdf-multi-vector
python multi-vector-leann-similarity-map.py
```
Artifacts (when enabled):
- Retrieved pages: `./figures/retrieved_page_rank{K}.png`
- Similarity maps: `./figures/similarity_map_rank{K}.png`
Key knobs in the script (top of file):
- `QUERY`: your question
- `MODEL`: `"colqwen2"` or `"colpali"`
- `USE_HF_DATASET`: set `False` to use local pages
- `PDF`, `PAGES_DIR`: for local mode
- `INDEX_PATH`, `TOPK`, `FIRST_STAGE_K`, `REBUILD_INDEX`
- `SIMILARITY_MAP`, `SIM_TOKEN_IDX`, `SIM_OUTPUT`
- `ANSWER`, `MAX_NEW_TOKENS` (Qwen-VL)
## Troubleshooting
- pdf2image errors on macOS: ensure `brew install poppler` and `pdfinfo` works in terminal.
- Slow or OOM on MPS: reduce dataset size (e.g., set `MAX_DOCS`) or switch to CPU.
- NaNs on MPS: keep fp32 on MPS (default in similarity-map script); avoid fp16 there.
- First-run model downloads can be large; ensure network access (HF mirrors if needed).
## Notes
- Index files are under `./indexes/`. Delete or set `REBUILD_INDEX=True` to rebuild.
- For local PDFs, page images go to `./pages/`.
### Retrieval and Visualization Example
Example settings in `multi-vector-leann-similarity-map.py`:
- `QUERY = "How does DeepSeek-V2 compare against the LLaMA family of LLMs?"`
- `SIMILARITY_MAP = True` (to generate heatmaps)
- `TOPK = 1` (save the top retrieved page and its similarity map)
Run:
```bash
cd apps/multimodal/vision-based-pdf-multi-vector
python multi-vector-leann-similarity-map.py
```
Outputs (by default):
- Retrieved page: `./figures/retrieved_page_rank1.png`
- Similarity map: `./figures/similarity_map_rank1.png`
Sample visualization (example result, and the query is "QUERY = "How does Vim model performance and efficiency compared to other models?"
"):
![Similarity map example](fig/image.png)
Notes:
- Set `SIM_TOKEN_IDX` to visualize a specific token index; set `-1` to auto-select the most salient token.
- If you change `SIM_OUTPUT` to a file path (e.g., `./figures/my_map.png`), multiple ranks are saved as `my_map_rank{K}.png`.

View File

Binary file not shown.

After

Width:  |  Height:  |  Size: 166 KiB

View File

@@ -0,0 +1,182 @@
from __future__ import annotations
import sys
from pathlib import Path
import numpy as np
def _ensure_repo_paths_importable(current_file: str) -> None:
_repo_root = Path(current_file).resolve().parents[3]
_leann_core_src = _repo_root / "packages" / "leann-core" / "src"
_leann_hnsw_pkg = _repo_root / "packages" / "leann-backend-hnsw"
if str(_leann_core_src) not in sys.path:
sys.path.append(str(_leann_core_src))
if str(_leann_hnsw_pkg) not in sys.path:
sys.path.append(str(_leann_hnsw_pkg))
_ensure_repo_paths_importable(__file__)
from leann_backend_hnsw.hnsw_backend import HNSWBuilder, HNSWSearcher # noqa: E402
class LeannMultiVector:
def __init__(
self,
index_path: str,
dim: int = 128,
distance_metric: str = "mips",
m: int = 16,
ef_construction: int = 500,
is_compact: bool = False,
is_recompute: bool = False,
embedding_model_name: str = "colvision",
) -> None:
self.index_path = index_path
self.dim = dim
self.embedding_model_name = embedding_model_name
self._pending_items: list[dict] = []
self._backend_kwargs = {
"distance_metric": distance_metric,
"M": m,
"efConstruction": ef_construction,
"is_compact": is_compact,
"is_recompute": is_recompute,
}
self._labels_meta: list[dict] = []
def _meta_dict(self) -> dict:
return {
"version": "1.0",
"backend_name": "hnsw",
"embedding_model": self.embedding_model_name,
"embedding_mode": "custom",
"dimensions": self.dim,
"backend_kwargs": self._backend_kwargs,
"is_compact": self._backend_kwargs.get("is_compact", True),
"is_pruned": self._backend_kwargs.get("is_compact", True)
and self._backend_kwargs.get("is_recompute", True),
}
def create_collection(self) -> None:
path = Path(self.index_path)
path.parent.mkdir(parents=True, exist_ok=True)
def insert(self, data: dict) -> None:
self._pending_items.append(
{
"doc_id": int(data["doc_id"]),
"filepath": data.get("filepath", ""),
"colbert_vecs": [np.asarray(v, dtype=np.float32) for v in data["colbert_vecs"]],
}
)
def _labels_path(self) -> Path:
index_path_obj = Path(self.index_path)
return index_path_obj.parent / f"{index_path_obj.name}.labels.json"
def _meta_path(self) -> Path:
index_path_obj = Path(self.index_path)
return index_path_obj.parent / f"{index_path_obj.name}.meta.json"
def create_index(self) -> None:
if not self._pending_items:
return
embeddings: list[np.ndarray] = []
labels_meta: list[dict] = []
for item in self._pending_items:
doc_id = int(item["doc_id"])
filepath = item.get("filepath", "")
colbert_vecs = item["colbert_vecs"]
for seq_id, vec in enumerate(colbert_vecs):
vec_np = np.asarray(vec, dtype=np.float32)
embeddings.append(vec_np)
labels_meta.append(
{
"id": f"{doc_id}:{seq_id}",
"doc_id": doc_id,
"seq_id": int(seq_id),
"filepath": filepath,
}
)
if not embeddings:
return
embeddings_np = np.vstack(embeddings).astype(np.float32)
# print shape of embeddings_np
print(embeddings_np.shape)
builder = HNSWBuilder(**{**self._backend_kwargs, "dimensions": self.dim})
ids = [str(i) for i in range(embeddings_np.shape[0])]
builder.build(embeddings_np, ids, self.index_path)
import json as _json
with open(self._meta_path(), "w", encoding="utf-8") as f:
_json.dump(self._meta_dict(), f, indent=2)
with open(self._labels_path(), "w", encoding="utf-8") as f:
_json.dump(labels_meta, f)
self._labels_meta = labels_meta
def _load_labels_meta_if_needed(self) -> None:
if self._labels_meta:
return
labels_path = self._labels_path()
if labels_path.exists():
import json as _json
with open(labels_path, encoding="utf-8") as f:
self._labels_meta = _json.load(f)
def search(
self, data: np.ndarray, topk: int, first_stage_k: int = 50
) -> list[tuple[float, int]]:
if data.ndim == 1:
data = data.reshape(1, -1)
if data.dtype != np.float32:
data = data.astype(np.float32)
self._load_labels_meta_if_needed()
searcher = HNSWSearcher(self.index_path, meta=self._meta_dict())
raw = searcher.search(
data,
first_stage_k,
recompute_embeddings=False,
complexity=128,
beam_width=1,
prune_ratio=0.0,
batch_size=0,
)
labels = raw.get("labels")
distances = raw.get("distances")
if labels is None or distances is None:
return []
doc_scores: dict[int, float] = {}
B = len(labels)
for b in range(B):
per_doc_best: dict[int, float] = {}
for k, sid in enumerate(labels[b]):
try:
idx = int(sid)
except Exception:
continue
if 0 <= idx < len(self._labels_meta):
doc_id = int(self._labels_meta[idx]["doc_id"]) # type: ignore[index]
else:
continue
score = float(distances[b][k])
if (doc_id not in per_doc_best) or (score > per_doc_best[doc_id]):
per_doc_best[doc_id] = score
for doc_id, best_score in per_doc_best.items():
doc_scores[doc_id] = doc_scores.get(doc_id, 0.0) + best_score
scores = sorted(((v, k) for k, v in doc_scores.items()), key=lambda x: x[0], reverse=True)
return scores[:topk] if len(scores) >= topk else scores

View File

@@ -0,0 +1,112 @@
# pip install pdf2image
# pip install pymilvus
# pip install colpali_engine
# pip install tqdm
# pip install pillow
import os
import re
import sys
from pathlib import Path
from typing import cast
from PIL import Image
from tqdm import tqdm
# Ensure local leann packages are importable before importing them
_repo_root = Path(__file__).resolve().parents[3]
_leann_core_src = _repo_root / "packages" / "leann-core" / "src"
_leann_hnsw_pkg = _repo_root / "packages" / "leann-backend-hnsw"
if str(_leann_core_src) not in sys.path:
sys.path.append(str(_leann_core_src))
if str(_leann_hnsw_pkg) not in sys.path:
sys.path.append(str(_leann_hnsw_pkg))
import torch
from colpali_engine.models import ColPali
from colpali_engine.models.paligemma.colpali.processing_colpali import ColPaliProcessor
from colpali_engine.utils.torch_utils import ListDataset, get_torch_device
from torch.utils.data import DataLoader
# Auto-select device: CUDA > MPS (mac) > CPU
_device_str = (
"cuda"
if torch.cuda.is_available()
else (
"mps"
if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available()
else "cpu"
)
)
device = get_torch_device(_device_str)
# Prefer fp16 on GPU/MPS, bfloat16 on CPU
_dtype = torch.float16 if _device_str in ("cuda", "mps") else torch.bfloat16
model_name = "vidore/colpali-v1.2"
model = ColPali.from_pretrained(
model_name,
torch_dtype=_dtype,
device_map=device,
).eval()
print(f"Using device={_device_str}, dtype={_dtype}")
queries = [
"How to end-to-end retrieval with ColBert",
"Where is ColBERT performance Table, including text representation results?",
]
processor = cast(ColPaliProcessor, ColPaliProcessor.from_pretrained(model_name))
dataloader = DataLoader(
dataset=ListDataset[str](queries),
batch_size=1,
shuffle=False,
collate_fn=lambda x: processor.process_queries(x),
)
qs: list[torch.Tensor] = []
for batch_query in dataloader:
with torch.no_grad():
batch_query = {k: v.to(model.device) for k, v in batch_query.items()}
embeddings_query = model(**batch_query)
qs.extend(list(torch.unbind(embeddings_query.to("cpu"))))
print(qs[0].shape)
# %%
page_filenames = sorted(os.listdir("./pages"), key=lambda n: int(re.search(r"\d+", n).group()))
images = [Image.open(os.path.join("./pages", name)) for name in page_filenames]
dataloader = DataLoader(
dataset=ListDataset[str](images),
batch_size=1,
shuffle=False,
collate_fn=lambda x: processor.process_images(x),
)
ds: list[torch.Tensor] = []
for batch_doc in tqdm(dataloader):
with torch.no_grad():
batch_doc = {k: v.to(model.device) for k, v in batch_doc.items()}
embeddings_doc = model(**batch_doc)
ds.extend(list(torch.unbind(embeddings_doc.to("cpu"))))
print(ds[0].shape)
# %%
# Build HNSW index via LeannRetriever primitives and run search
index_path = "./indexes/colpali.leann"
retriever = LeannRetriever(index_path=index_path, dim=int(ds[0].shape[-1]))
retriever.create_collection()
filepaths = [os.path.join("./pages", name) for name in page_filenames]
for i in range(len(filepaths)):
data = {
"colbert_vecs": ds[i].float().numpy(),
"doc_id": i,
"filepath": filepaths[i],
}
retriever.insert(data)
retriever.create_index()
for query in qs:
query_np = query.float().numpy()
result = retriever.search(query_np, topk=1)
print(filepaths[result[0][1]])

View File

@@ -0,0 +1,477 @@
## Jupyter-style notebook script
# %%
# uv pip install matplotlib qwen_vl_utils
import os
import re
import sys
from pathlib import Path
from typing import Any, Optional, cast
from PIL import Image
from tqdm import tqdm
def _ensure_repo_paths_importable(current_file: str) -> None:
"""Make local leann packages importable without installing (mirrors multi-vector-leann.py)."""
_repo_root = Path(current_file).resolve().parents[3]
_leann_core_src = _repo_root / "packages" / "leann-core" / "src"
_leann_hnsw_pkg = _repo_root / "packages" / "leann-backend-hnsw"
if str(_leann_core_src) not in sys.path:
sys.path.append(str(_leann_core_src))
if str(_leann_hnsw_pkg) not in sys.path:
sys.path.append(str(_leann_hnsw_pkg))
_ensure_repo_paths_importable(__file__)
from leann_multi_vector import LeannMultiVector # noqa: E402
# %%
# Config
os.environ["TOKENIZERS_PARALLELISM"] = "false"
QUERY = "How does DeepSeek-V2 compare against the LLaMA family of LLMs?"
MODEL: str = "colqwen2" # "colpali" or "colqwen2"
# Data source: set to True to use the Hugging Face dataset example (recommended)
USE_HF_DATASET: bool = True
DATASET_NAME: str = "weaviate/arXiv-AI-papers-multi-vector"
DATASET_SPLIT: str = "train"
MAX_DOCS: Optional[int] = None # limit number of pages to index; None = all
# Local pages (used when USE_HF_DATASET == False)
PDF: Optional[str] = None # e.g., "./pdfs/2004.12832v2.pdf"
PAGES_DIR: str = "./pages"
# Index + retrieval settings
INDEX_PATH: str = "./indexes/colvision.leann"
TOPK: int = 1
FIRST_STAGE_K: int = 500
REBUILD_INDEX: bool = False
# Artifacts
SAVE_TOP_IMAGE: Optional[str] = "./figures/retrieved_page.png"
SIMILARITY_MAP: bool = True
SIM_TOKEN_IDX: int = 13 # -1 means auto-select the most salient token
SIM_OUTPUT: str = "./figures/similarity_map.png"
ANSWER: bool = True
MAX_NEW_TOKENS: int = 128
# %%
# Helpers
def _natural_sort_key(name: str) -> int:
m = re.search(r"\d+", name)
return int(m.group()) if m else 0
def _load_images_from_dir(pages_dir: str) -> tuple[list[str], list[Image.Image]]:
filenames = [n for n in os.listdir(pages_dir) if n.lower().endswith((".png", ".jpg", ".jpeg"))]
filenames = sorted(filenames, key=_natural_sort_key)
filepaths = [os.path.join(pages_dir, n) for n in filenames]
images = [Image.open(p) for p in filepaths]
return filepaths, images
def _maybe_convert_pdf_to_images(pdf_path: Optional[str], pages_dir: str, dpi: int = 200) -> None:
if not pdf_path:
return
os.makedirs(pages_dir, exist_ok=True)
try:
from pdf2image import convert_from_path
except Exception as e:
raise RuntimeError(
"pdf2image is required to convert PDF to images. Install via pip install pdf2image"
) from e
images = convert_from_path(pdf_path, dpi=dpi)
for i, image in enumerate(images):
image.save(os.path.join(pages_dir, f"page_{i + 1}.png"), "PNG")
def _select_device_and_dtype():
import torch
from colpali_engine.utils.torch_utils import get_torch_device
device_str = (
"cuda"
if torch.cuda.is_available()
else (
"mps"
if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available()
else "cpu"
)
)
device = get_torch_device(device_str)
# Stable dtype selection to avoid NaNs:
# - CUDA: prefer bfloat16 if supported, else float16
# - MPS: use float32 (fp16 on MPS can produce NaNs in some ops)
# - CPU: float32
if device_str == "cuda":
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
try:
torch.backends.cuda.matmul.allow_tf32 = True # Better stability/perf on Ampere+
except Exception:
pass
elif device_str == "mps":
dtype = torch.float32
else:
dtype = torch.float32
return device_str, device, dtype
def _load_colvision(model_choice: str):
import torch
from colpali_engine.models import ColPali, ColQwen2, ColQwen2Processor
from colpali_engine.models.paligemma.colpali.processing_colpali import ColPaliProcessor
from transformers.utils.import_utils import is_flash_attn_2_available
device_str, device, dtype = _select_device_and_dtype()
if model_choice == "colqwen2":
model_name = "vidore/colqwen2-v1.0"
# On CPU/MPS we must avoid flash-attn and stay eager; on CUDA prefer flash-attn if available
attn_implementation = (
"flash_attention_2"
if (device_str == "cuda" and is_flash_attn_2_available())
else "eager"
)
model = ColQwen2.from_pretrained(
model_name,
torch_dtype=torch.bfloat16,
device_map=device,
attn_implementation=attn_implementation,
).eval()
processor = ColQwen2Processor.from_pretrained(model_name)
else:
model_name = "vidore/colpali-v1.2"
model = ColPali.from_pretrained(
model_name,
torch_dtype=torch.bfloat16,
device_map=device,
).eval()
processor = cast(ColPaliProcessor, ColPaliProcessor.from_pretrained(model_name))
return model_name, model, processor, device_str, device, dtype
def _embed_images(model, processor, images: list[Image.Image]) -> list[Any]:
import torch
from colpali_engine.utils.torch_utils import ListDataset
from torch.utils.data import DataLoader
# Ensure deterministic eval and autocast for stability
model.eval()
dataloader = DataLoader(
dataset=ListDataset[Image.Image](images),
batch_size=1,
shuffle=False,
collate_fn=lambda x: processor.process_images(x),
)
doc_vecs: list[Any] = []
for batch_doc in tqdm(dataloader, desc="Embedding images"):
with torch.no_grad():
batch_doc = {k: v.to(model.device) for k, v in batch_doc.items()}
# autocast on CUDA for bf16/fp16; on CPU/MPS stay in fp32
if model.device.type == "cuda":
with torch.autocast(
device_type="cuda",
dtype=model.dtype if model.dtype.is_floating_point else torch.bfloat16,
):
embeddings_doc = model(**batch_doc)
else:
embeddings_doc = model(**batch_doc)
doc_vecs.extend(list(torch.unbind(embeddings_doc.to("cpu"))))
return doc_vecs
def _embed_queries(model, processor, queries: list[str]) -> list[Any]:
import torch
from colpali_engine.utils.torch_utils import ListDataset
from torch.utils.data import DataLoader
model.eval()
dataloader = DataLoader(
dataset=ListDataset[str](queries),
batch_size=1,
shuffle=False,
collate_fn=lambda x: processor.process_queries(x),
)
q_vecs: list[Any] = []
for batch_query in tqdm(dataloader, desc="Embedding queries"):
with torch.no_grad():
batch_query = {k: v.to(model.device) for k, v in batch_query.items()}
if model.device.type == "cuda":
with torch.autocast(
device_type="cuda",
dtype=model.dtype if model.dtype.is_floating_point else torch.bfloat16,
):
embeddings_query = model(**batch_query)
else:
embeddings_query = model(**batch_query)
q_vecs.extend(list(torch.unbind(embeddings_query.to("cpu"))))
return q_vecs
def _build_index(index_path: str, doc_vecs: list[Any], filepaths: list[str]) -> LeannMultiVector:
dim = int(doc_vecs[0].shape[-1])
retriever = LeannMultiVector(index_path=index_path, dim=dim)
retriever.create_collection()
for i, vec in enumerate(doc_vecs):
data = {
"colbert_vecs": vec.float().numpy(),
"doc_id": i,
"filepath": filepaths[i],
}
retriever.insert(data)
retriever.create_index()
return retriever
def _load_retriever_if_index_exists(index_path: str, dim: int) -> Optional[LeannMultiVector]:
index_base = Path(index_path)
# Rough heuristic: index dir exists AND meta+labels files exist
meta = index_base.parent / f"{index_base.name}.meta.json"
labels = index_base.parent / f"{index_base.name}.labels.json"
if index_base.exists() and meta.exists() and labels.exists():
return LeannMultiVector(index_path=index_path, dim=dim)
return None
def _generate_similarity_map(
model,
processor,
image: Image.Image,
query: str,
token_idx: Optional[int] = None,
output_path: Optional[str] = None,
) -> tuple[int, float]:
import torch
from colpali_engine.interpretability import (
get_similarity_maps_from_embeddings,
plot_similarity_map,
)
batch_images = processor.process_images([image]).to(model.device)
batch_queries = processor.process_queries([query]).to(model.device)
with torch.no_grad():
image_embeddings = model.forward(**batch_images)
query_embeddings = model.forward(**batch_queries)
n_patches = processor.get_n_patches(
image_size=image.size,
spatial_merge_size=getattr(model, "spatial_merge_size", None),
)
image_mask = processor.get_image_mask(batch_images)
batched_similarity_maps = get_similarity_maps_from_embeddings(
image_embeddings=image_embeddings,
query_embeddings=query_embeddings,
n_patches=n_patches,
image_mask=image_mask,
)
similarity_maps = batched_similarity_maps[0]
# Determine token index if not provided: choose the token with highest max score
if token_idx is None:
per_token_max = similarity_maps.view(similarity_maps.shape[0], -1).max(dim=1).values
token_idx = int(per_token_max.argmax().item())
max_sim_score = similarity_maps[token_idx, :, :].max().item()
if output_path:
import matplotlib.pyplot as plt
fig, ax = plot_similarity_map(
image=image,
similarity_map=similarity_maps[token_idx],
figsize=(14, 14),
show_colorbar=False,
)
ax.set_title(f"Token #{token_idx}. MaxSim score: {max_sim_score:.2f}", fontsize=12)
os.makedirs(os.path.dirname(output_path), exist_ok=True)
plt.savefig(output_path, bbox_inches="tight")
plt.close(fig)
return token_idx, float(max_sim_score)
class QwenVL:
def __init__(self, device: str):
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
from transformers.utils.import_utils import is_flash_attn_2_available
attn_implementation = "flash_attention_2" if is_flash_attn_2_available() else "eager"
self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
"Qwen/Qwen2.5-VL-3B-Instruct",
torch_dtype="auto",
device_map=device,
attn_implementation=attn_implementation,
)
min_pixels = 256 * 28 * 28
max_pixels = 1280 * 28 * 28
self.processor = AutoProcessor.from_pretrained(
"Qwen/Qwen2.5-VL-3B-Instruct", min_pixels=min_pixels, max_pixels=max_pixels
)
def answer(self, query: str, images: list[Image.Image], max_new_tokens: int = 128) -> str:
import base64
from io import BytesIO
from qwen_vl_utils import process_vision_info
content = []
for img in images:
buffer = BytesIO()
img.save(buffer, format="jpeg")
img_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
content.append({"type": "image", "image": f"data:image;base64,{img_base64}"})
content.append({"type": "text", "text": query})
messages = [{"role": "user", "content": content}]
text = self.processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
image_inputs, video_inputs = process_vision_info(messages)
inputs = self.processor(
text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt"
)
inputs = inputs.to(self.model.device)
generated_ids = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
generated_ids_trimmed = [
out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
return self.processor.batch_decode(
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
)[0]
# %%
# Step 1: Prepare data
if USE_HF_DATASET:
from datasets import load_dataset
dataset = load_dataset(DATASET_NAME, split=DATASET_SPLIT)
N = len(dataset) if MAX_DOCS is None else min(MAX_DOCS, len(dataset))
filepaths: list[str] = []
images: list[Image.Image] = []
for i in tqdm(range(N), desc="Loading dataset", total=N ):
p = dataset[i]
# Compose a descriptive identifier for printing later
identifier = f"arXiv:{p['paper_arxiv_id']}|title:{p['paper_title']}|page:{int(p['page_number'])}|id:{p['page_id']}"
print(identifier)
filepaths.append(identifier)
images.append(p["page_image"]) # PIL Image
else:
_maybe_convert_pdf_to_images(PDF, PAGES_DIR)
filepaths, images = _load_images_from_dir(PAGES_DIR)
if not images:
raise RuntimeError(
f"No images found in {PAGES_DIR}. Provide PDF path in PDF variable or ensure images exist."
)
# %%
# Step 2: Load model and processor
model_name, model, processor, device_str, device, dtype = _load_colvision(MODEL)
print(f"Using model={model_name}, device={device_str}, dtype={dtype}")
# %%
# %%
# Step 3: Build or load index
retriever: Optional[LeannMultiVector] = None
if not REBUILD_INDEX:
try:
one_vec = _embed_images(model, processor, [images[0]])[0]
retriever = _load_retriever_if_index_exists(INDEX_PATH, dim=int(one_vec.shape[-1]))
except Exception:
retriever = None
if retriever is None:
doc_vecs = _embed_images(model, processor, images)
retriever = _build_index(INDEX_PATH, doc_vecs, filepaths)
# %%
# Step 4: Embed query and search
q_vec = _embed_queries(model, processor, [QUERY])[0]
results = retriever.search(q_vec.float().numpy(), topk=TOPK, first_stage_k=FIRST_STAGE_K)
if not results:
print("No results found.")
else:
print(f'Top {len(results)} results for query: "{QUERY}"')
top_images: list[Image.Image] = []
for rank, (score, doc_id) in enumerate(results, start=1):
path = filepaths[doc_id]
# For HF dataset, path is a descriptive identifier, not a real file path
print(f"{rank}) MaxSim: {score:.4f}, Page: {path}")
top_images.append(images[doc_id])
if SAVE_TOP_IMAGE:
from pathlib import Path as _Path
base = _Path(SAVE_TOP_IMAGE)
base.parent.mkdir(parents=True, exist_ok=True)
for rank, img in enumerate(top_images[:TOPK], start=1):
if base.suffix:
out_path = base.parent / f"{base.stem}_rank{rank}{base.suffix}"
else:
out_path = base / f"retrieved_page_rank{rank}.png"
img.save(str(out_path))
print(f"Saved retrieved page (rank {rank}) to: {out_path}")
## TODO stange results of second page of DeepSeek-V2 rather than the first page
# %%
# Step 5: Similarity maps for top-K results
if results and SIMILARITY_MAP:
token_idx = None if SIM_TOKEN_IDX < 0 else int(SIM_TOKEN_IDX)
from pathlib import Path as _Path
output_base = _Path(SIM_OUTPUT) if SIM_OUTPUT else None
for rank, img in enumerate(top_images[:TOPK], start=1):
if output_base:
if output_base.suffix:
out_dir = output_base.parent
out_name = f"{output_base.stem}_rank{rank}{output_base.suffix}"
out_path = str(out_dir / out_name)
else:
out_dir = output_base
out_dir.mkdir(parents=True, exist_ok=True)
out_path = str(out_dir / f"similarity_map_rank{rank}.png")
else:
out_path = None
chosen_idx, max_sim = _generate_similarity_map(
model=model,
processor=processor,
image=img,
query=QUERY,
token_idx=token_idx,
output_path=out_path,
)
if out_path:
print(
f"Saved similarity map for rank {rank}, token #{chosen_idx} (max={max_sim:.2f}) to: {out_path}"
)
else:
print(
f"Computed similarity map for rank {rank}, token #{chosen_idx} (max={max_sim:.2f})"
)
# %%
# Step 6: Optional answer generation
if results and ANSWER:
qwen = QwenVL(device=device_str)
response = qwen.answer(QUERY, top_images[:TOPK], max_new_tokens=MAX_NEW_TOKENS)
print("\nAnswer:")
print(response)

0
benchmarks/__init__.py Normal file
View File

View File

@@ -0,0 +1,23 @@
BM25 vs DiskANN Baselines
```bash
aws s3 sync s3://powerrag-diskann-rpj-wiki-20250824-224037-194d640c/bm25_rpj_wiki/index_en_only/ benchmarks/data/indices/bm25_index/
aws s3 sync s3://powerrag-diskann-rpj-wiki-20250824-224037-194d640c/diskann_rpj_wiki/ benchmarks/data/indices/diskann_rpj_wiki/
```
- Dataset: `benchmarks/data/queries/nq_open.jsonl` (Natural Questions)
- Machine-specific; results measured locally with the current repo.
DiskANN (NQ queries, search-only)
- Command: `uv run --script benchmarks/bm25_diskann_baselines/run_diskann.py`
- Settings: `recompute_embeddings=False`, embeddings precomputed (excluded from timing), batching off, caching off (`cache_mechanism=2`, `num_nodes_to_cache=0`)
- Result: avg 0.011093 s/query, QPS 90.15 (p50 0.010731 s, p95 0.015000 s)
BM25
- Command: `uv run --script benchmarks/bm25_diskann_baselines/run_bm25.py`
- Settings: `k=10`, `k1=0.9`, `b=0.4`, queries=100
- Result: avg 0.028589 s/query, QPS 34.97 (p50 0.026060 s, p90 0.043695 s, p95 0.053260 s, p99 0.055257 s)
Notes
- DiskANN measures search-only latency on real NQ queries (embeddings computed beforehand and excluded from timing).
- Use `benchmarks/bm25_diskann_baselines/run_diskann.py` for DiskANN; `benchmarks/bm25_diskann_baselines/run_bm25.py` for BM25.

After

Width:  |  Height:  |  Size: 1.3 KiB

View File

@@ -0,0 +1,183 @@
# /// script
# dependencies = [
# "pyserini"
# ]
# ///
# sudo pacman -S jdk21-openjdk
# export JAVA_HOME=/usr/lib/jvm/java-21-openjdk
# sudo archlinux-java status
# sudo archlinux-java set java-21-openjdk
# set -Ux JAVA_HOME /usr/lib/jvm/java-21-openjdk
# fish_add_path --global $JAVA_HOME/bin
# set -Ux LD_LIBRARY_PATH $JAVA_HOME/lib/server $LD_LIBRARY_PATH
# which javac # Should be /usr/lib/jvm/java-21-openjdk/bin/javac
import argparse
import json
import os
import sys
import time
from statistics import mean
def load_queries(path: str, limit: int | None) -> list[str]:
queries: list[str] = []
# Try JSONL with a 'query' or 'text' field; fallback to plain text (one query per line)
_, ext = os.path.splitext(path)
if ext.lower() in {".jsonl", ".json"}:
with open(path, encoding="utf-8") as f:
for line in f:
line = line.strip()
if not line:
continue
try:
obj = json.loads(line)
except json.JSONDecodeError:
# Not strict JSONL? treat the whole line as the query
queries.append(line)
continue
q = obj.get("query") or obj.get("text") or obj.get("question")
if q:
queries.append(str(q))
else:
with open(path, encoding="utf-8") as f:
for line in f:
s = line.strip()
if s:
queries.append(s)
if limit is not None and limit > 0:
queries = queries[:limit]
return queries
def percentile(values: list[float], p: float) -> float:
if not values:
return 0.0
s = sorted(values)
k = (len(s) - 1) * (p / 100.0)
f = int(k)
c = min(f + 1, len(s) - 1)
if f == c:
return s[f]
return s[f] + (s[c] - s[f]) * (k - f)
def main():
ap = argparse.ArgumentParser(description="Standalone BM25 latency benchmark (Pyserini)")
ap.add_argument(
"--bm25-index",
default="benchmarks/data/indices/bm25_index",
help="Path to Pyserini Lucene index directory",
)
ap.add_argument(
"--queries",
default="benchmarks/data/queries/nq_open.jsonl",
help="Path to queries file (JSONL with 'query'/'text' or plain txt one-per-line)",
)
ap.add_argument("--k", type=int, default=10, help="Top-k to retrieve (default: 10)")
ap.add_argument("--k1", type=float, default=0.9, help="BM25 k1 (default: 0.9)")
ap.add_argument("--b", type=float, default=0.4, help="BM25 b (default: 0.4)")
ap.add_argument("--limit", type=int, default=100, help="Max queries to run (default: 100)")
ap.add_argument(
"--warmup", type=int, default=5, help="Warmup queries not counted in latency (default: 5)"
)
ap.add_argument(
"--fetch-docs", action="store_true", help="Also fetch doc contents (slower; default: off)"
)
ap.add_argument("--report", type=str, default=None, help="Optional JSON report path")
args = ap.parse_args()
try:
from pyserini.search.lucene import LuceneSearcher
except Exception:
print("Pyserini not found. Install with: pip install pyserini", file=sys.stderr)
raise
if not os.path.isdir(args.bm25_index):
print(f"Index directory not found: {args.bm25_index}", file=sys.stderr)
sys.exit(1)
queries = load_queries(args.queries, args.limit)
if not queries:
print("No queries loaded.", file=sys.stderr)
sys.exit(1)
print(f"Loaded {len(queries)} queries from {args.queries}")
print(f"Opening BM25 index: {args.bm25_index}")
searcher = LuceneSearcher(args.bm25_index)
# Some builds of pyserini require explicit set_bm25; others ignore
try:
searcher.set_bm25(k1=args.k1, b=args.b)
except Exception:
pass
latencies: list[float] = []
total_searches = 0
# Warmup
for i in range(min(args.warmup, len(queries))):
_ = searcher.search(queries[i], k=args.k)
t0 = time.time()
for i, q in enumerate(queries):
t1 = time.time()
hits = searcher.search(q, k=args.k)
t2 = time.time()
latencies.append(t2 - t1)
total_searches += 1
if args.fetch_docs:
# Optional doc fetch to include I/O time
for h in hits:
try:
_ = searcher.doc(h.docid)
except Exception:
pass
if (i + 1) % 50 == 0:
print(f"Processed {i + 1}/{len(queries)} queries")
t1 = time.time()
total_time = t1 - t0
if latencies:
avg = mean(latencies)
p50 = percentile(latencies, 50)
p90 = percentile(latencies, 90)
p95 = percentile(latencies, 95)
p99 = percentile(latencies, 99)
qps = total_searches / total_time if total_time > 0 else 0.0
else:
avg = p50 = p90 = p95 = p99 = qps = 0.0
print("BM25 Latency Report")
print(f" queries: {total_searches}")
print(f" k: {args.k}, k1: {args.k1}, b: {args.b}")
print(f" avg per query: {avg:.6f} s")
print(f" p50/p90/p95/p99: {p50:.6f}/{p90:.6f}/{p95:.6f}/{p99:.6f} s")
print(f" total time: {total_time:.3f} s, qps: {qps:.2f}")
if args.report:
payload = {
"queries": total_searches,
"k": args.k,
"k1": args.k1,
"b": args.b,
"avg_s": avg,
"p50_s": p50,
"p90_s": p90,
"p95_s": p95,
"p99_s": p99,
"total_time_s": total_time,
"qps": qps,
"index_dir": os.path.abspath(args.bm25_index),
"fetch_docs": bool(args.fetch_docs),
}
with open(args.report, "w", encoding="utf-8") as f:
json.dump(payload, f, indent=2)
print(f"Saved report to {args.report}")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,124 @@
# /// script
# dependencies = [
# "leann-backend-diskann"
# ]
# ///
import argparse
import json
import time
from pathlib import Path
import numpy as np
def load_queries(path: Path, limit: int | None) -> list[str]:
out: list[str] = []
with open(path, encoding="utf-8") as f:
for line in f:
obj = json.loads(line)
out.append(obj["query"])
if limit and len(out) >= limit:
break
return out
def main() -> None:
ap = argparse.ArgumentParser(
description="DiskANN baseline on real NQ queries (search-only timing)"
)
ap.add_argument(
"--index-dir",
default="benchmarks/data/indices/diskann_rpj_wiki",
help="Directory containing DiskANN files",
)
ap.add_argument("--index-prefix", default="ann")
ap.add_argument("--queries-file", default="benchmarks/data/queries/nq_open.jsonl")
ap.add_argument("--num-queries", type=int, default=200)
ap.add_argument("--top-k", type=int, default=10)
ap.add_argument("--complexity", type=int, default=62)
ap.add_argument("--threads", type=int, default=1)
ap.add_argument("--beam-width", type=int, default=1)
ap.add_argument("--cache-mechanism", type=int, default=2)
ap.add_argument("--num-nodes-to-cache", type=int, default=0)
args = ap.parse_args()
index_dir = Path(args.index_dir).resolve()
if not index_dir.is_dir():
raise SystemExit(f"Index dir not found: {index_dir}")
qpath = Path(args.queries_file).resolve()
if not qpath.exists():
raise SystemExit(f"Queries file not found: {qpath}")
queries = load_queries(qpath, args.num_queries)
print(f"Loaded {len(queries)} queries from {qpath}")
# Compute embeddings once (exclude from timing)
from leann.api import compute_embeddings as _compute
embs = _compute(
queries,
model_name="facebook/contriever-msmarco",
mode="sentence-transformers",
use_server=False,
).astype(np.float32)
if embs.ndim != 2:
raise SystemExit("Embedding compute failed or returned wrong shape")
# Build searcher
from leann_backend_diskann.diskann_backend import DiskannSearcher as _DiskannSearcher
index_prefix_path = str(index_dir / args.index_prefix)
searcher = _DiskannSearcher(
index_prefix_path,
num_threads=int(args.threads),
cache_mechanism=int(args.cache_mechanism),
num_nodes_to_cache=int(args.num_nodes_to_cache),
)
# Warmup (not timed)
_ = searcher.search(
embs[0:1],
top_k=args.top_k,
complexity=args.complexity,
beam_width=args.beam_width,
prune_ratio=0.0,
recompute_embeddings=False,
batch_recompute=False,
dedup_node_dis=False,
)
# Timed loop
times: list[float] = []
for i in range(embs.shape[0]):
t0 = time.time()
_ = searcher.search(
embs[i : i + 1],
top_k=args.top_k,
complexity=args.complexity,
beam_width=args.beam_width,
prune_ratio=0.0,
recompute_embeddings=False,
batch_recompute=False,
dedup_node_dis=False,
)
times.append(time.time() - t0)
times_sorted = sorted(times)
avg = float(sum(times) / len(times))
p50 = times_sorted[len(times) // 2]
p95 = times_sorted[max(0, int(len(times) * 0.95) - 1)]
print("\nDiskANN (NQ, search-only) Report")
print(f" queries: {len(times)}")
print(
f" k: {args.top_k}, complexity: {args.complexity}, beam_width: {args.beam_width}, threads: {args.threads}"
)
print(f" avg per query: {avg:.6f} s")
print(f" p50/p95: {p50:.6f}/{p95:.6f} s")
print(f" QPS: {1.0 / avg:.2f}")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,141 @@
# Enron Emails Benchmark
A comprehensive RAG benchmark for evaluating LEANN search and generation on the Enron email corpus. It mirrors the structure and CLI of the existing FinanceBench and LAION benches, using stage-based evaluation with Recall@3 and generation timing.
- Dataset: Enron email CSV (e.g., Kaggle wcukierski/enron-email-dataset) for passages
- Queries: corbt/enron_emails_sample_questions (filtered for realistic questions)
- Metrics: Recall@3 vs FAISS Flat baseline + Generation evaluation with Qwen3-8B
## Layout
benchmarks/enron_emails/
- setup_enron_emails.py: Prepare passages, build LEANN index, build FAISS baseline
- evaluate_enron_emails.py: Evaluate retrieval recall (Stages 2-5) + generation with Qwen3-8B
- data/: Generated passages, queries, embeddings-related files
- baseline/: FAISS Flat baseline files
- llm_utils.py: LLM utilities for Qwen3-8B generation (in parent directory)
## Quickstart
1) Prepare the data and index
cd benchmarks/enron_emails
python setup_enron_emails.py --data-dir data
Notes:
- If `--emails-csv` is omitted, the script attempts to download from Kaggle dataset `wcukierski/enron-email-dataset` using Kaggle API (requires `KAGGLE_USERNAME` and `KAGGLE_KEY`).
Alternatively, pass a local path to `--emails-csv`.
Notes:
- The script parses emails, chunks header/body into passages, builds a compact LEANN index, and then builds a FAISS Flat baseline from the same passages and embedding model.
- Optionally, it will also create evaluation queries from HuggingFace dataset `corbt/enron_emails_sample_questions`.
2) Run recall evaluation (Stage 2)
python evaluate_enron_emails.py --index data/enron_index_hnsw.leann --stage 2
3) Complexity sweep (Stage 3)
python evaluate_enron_emails.py --index data/enron_index_hnsw.leann --stage 3 --target-recall 0.90 --max-queries 200
Stage 3 uses binary search over complexity to find the minimal value achieving the target Recall@3 (assumes recall is non-decreasing with complexity). The search expands the upper bound as needed and snaps complexity to multiples of 8.
4) Index comparison (Stage 4)
python evaluate_enron_emails.py --index data/enron_index_hnsw.leann --stage 4 --complexity 88 --max-queries 100 --output results.json
5) Generation evaluation (Stage 5)
python evaluate_enron_emails.py --index data/enron_index_hnsw.leann --stage 5 --complexity 88 --llm-backend hf --model-name Qwen/Qwen3-8B
6) Combined index + generation evaluation (Stages 4+5, recommended)
python evaluate_enron_emails.py --index data/enron_index_hnsw.leann --stage 45 --complexity 88 --llm-backend hf
Notes:
- Minimal CLI: you can run from repo root with only `--index`, defaults match financebench/laion patterns:
- `--stage` defaults to `all` (runs 2, 3, 4, 5)
- `--baseline-dir` defaults to `baseline`
- `--queries` defaults to `data/evaluation_queries.jsonl` (or falls back to the index directory)
- `--llm-backend` defaults to `hf` (HuggingFace), can use `vllm`
- `--model-name` defaults to `Qwen/Qwen3-8B`
- Fail-fast behavior: no silent fallbacks. If compact index cannot run with recompute, it errors out.
- Stage 5 requires Stage 4 retrieval results. Use `--stage 45` to run both efficiently.
Optional flags:
- --queries data/evaluation_queries.jsonl (custom queries file)
- --baseline-dir baseline (where FAISS baseline lives)
- --complexity 88 (LEANN complexity parameter, optimal for 90% recall)
- --llm-backend hf|vllm (LLM backend for generation)
- --model-name Qwen/Qwen3-8B (LLM model for generation)
- --max-queries 1000 (limit number of queries for evaluation)
## Files Produced
- data/enron_passages_preview.jsonl: Small preview of passages used (for inspection)
- data/enron_index_hnsw.leann.*: LEANN index files
- baseline/faiss_flat.index + baseline/metadata.pkl: FAISS baseline with passage IDs
- data/evaluation_queries.jsonl: Query file (id + query; includes GT IDs for reference)
## Notes
- Evaluates both retrieval Recall@3 and generation timing with Qwen3-8B thinking model.
- The emails CSV must contain a column named "message" (raw RFC822 email) and a column named "file" for source identifier. Message-ID headers are parsed as canonical message IDs when present.
- Qwen3-8B requires special handling for thinking models with chat templates and <think></think> tag processing.
## Stages Summary
- Stage 2 (Recall@3):
- Compares LEANN vs FAISS Flat baseline on Recall@3.
- Compact index runs with `recompute_embeddings=True`.
- Stage 3 (Binary Search for Complexity):
- Builds a non-compact index (`<index>_noncompact.leann`) and runs binary search with `recompute_embeddings=False` to find the minimal complexity achieving target Recall@3 (default 90%).
- Stage 4 (Index Comparison):
- Reports .index-only sizes for compact vs non-compact.
- Measures timings on queries by default: non-compact (no recompute) vs compact (with recompute).
- Stores retrieval results for Stage 5 generation evaluation.
- Fails fast if compact recompute cannot run.
- If `--complexity` is not provided, the script tries to use the best complexity from Stage 3:
- First from the current run (when running `--stage all`), otherwise
- From `enron_stage3_results.json` saved next to the index during the last Stage 3 run.
- If neither exists, Stage 4 will error and ask you to run Stage 3 or pass `--complexity`.
- Stage 5 (Generation Evaluation):
- Uses Qwen3-8B thinking model for RAG generation on retrieved documents from Stage 4.
- Supports HuggingFace (`hf`) and vLLM (`vllm`) backends.
- Measures generation timing separately from search timing.
- Requires Stage 4 results (no additional searching performed).
## Example Results
These are sample results obtained on Enron data using all-mpnet-base-v2 and Qwen3-8B.
- Stage 3 (Binary Search):
- Minimal complexity achieving 90% Recall@3: 88
- Sampled points:
- C=8 → 59.9% Recall@3
- C=72 → 89.4% Recall@3
- C=88 → 90.2% Recall@3
- C=96 → 90.7% Recall@3
- C=112 → 91.1% Recall@3
- C=136 → 91.3% Recall@3
- C=256 → 92.0% Recall@3
- Stage 4 (Index Sizes, .index only):
- Compact: ~2.2 MB
- Non-compact: ~82.0 MB
- Storage saving by compact: ~97.3%
- Stage 4 (Search Timing, 988 queries, complexity=88):
- Non-compact (no recompute): ~0.0075 s avg per query
- Compact (with recompute): ~1.981 s avg per query
- Speed ratio (non-compact/compact): ~0.0038x
- Stage 5 (RAG Generation, 988 queries, Qwen3-8B):
- Average generation time: ~22.302 s per query
- Total queries processed: 988
- LLM backend: HuggingFace transformers
- Model: Qwen/Qwen3-8B (thinking model with <think></think> processing)
Full JSON output is saved by the script (see `--output`), e.g.:
`benchmarks/enron_emails/results_enron_stage45.json`.

View File

@@ -0,0 +1 @@
downloads/

View File

@@ -0,0 +1,614 @@
"""
Enron Emails Benchmark Evaluation - Retrieval Recall@3 (Stages 2/3/4)
Follows the style of FinanceBench/LAION: Stage 2 recall vs FAISS baseline,
Stage 3 complexity sweep to target recall, Stage 4 index comparison.
On errors, fail fast without fallbacks.
"""
import argparse
import json
import logging
import os
import pickle
from pathlib import Path
import numpy as np
from leann import LeannBuilder, LeannSearcher
from leann_backend_hnsw import faiss
from ..llm_utils import generate_hf, generate_vllm, load_hf_model, load_vllm_model
# Setup logging to reduce verbose output
logging.basicConfig(level=logging.WARNING)
logging.getLogger("leann.api").setLevel(logging.WARNING)
logging.getLogger("leann_backend_hnsw").setLevel(logging.WARNING)
class RecallEvaluator:
"""Stage 2: Evaluate Recall@3 (LEANN vs FAISS)"""
def __init__(self, index_path: str, baseline_dir: str):
self.index_path = index_path
self.baseline_dir = baseline_dir
self.searcher = LeannSearcher(index_path)
baseline_index_path = os.path.join(baseline_dir, "faiss_flat.index")
metadata_path = os.path.join(baseline_dir, "metadata.pkl")
self.faiss_index = faiss.read_index(baseline_index_path)
with open(metadata_path, "rb") as f:
self.passage_ids = pickle.load(f)
print(f"📚 Loaded FAISS flat baseline with {self.faiss_index.ntotal} vectors")
# No fallbacks here; if embedding server is needed but fails, the caller will see the error.
def evaluate_recall_at_3(
self, queries: list[str], complexity: int = 64, recompute_embeddings: bool = True
) -> float:
"""Evaluate recall@3 using FAISS Flat as ground truth"""
from leann.api import compute_embeddings
recompute_str = "with recompute" if recompute_embeddings else "no recompute"
print(f"🔍 Evaluating recall@3 with complexity={complexity} ({recompute_str})...")
total_recall = 0.0
for i, query in enumerate(queries):
# Compute query embedding with the same model/mode as the index
q_emb = compute_embeddings(
[query],
self.searcher.embedding_model,
mode=self.searcher.embedding_mode,
use_server=False,
).astype(np.float32)
# Search FAISS Flat ground truth
n = q_emb.shape[0]
k = 3
distances = np.zeros((n, k), dtype=np.float32)
labels = np.zeros((n, k), dtype=np.int64)
self.faiss_index.search(
n,
faiss.swig_ptr(q_emb),
k,
faiss.swig_ptr(distances),
faiss.swig_ptr(labels),
)
baseline_ids = {self.passage_ids[idx] for idx in labels[0]}
# Search with LEANN (may require embedding server depending on index configuration)
results = self.searcher.search(
query,
top_k=3,
complexity=complexity,
recompute_embeddings=recompute_embeddings,
)
test_ids = {r.id for r in results}
intersection = test_ids.intersection(baseline_ids)
recall = len(intersection) / 3.0
total_recall += recall
if i < 3:
print(f" Q{i + 1}: '{query[:60]}...' -> Recall@3: {recall:.3f}")
print(f" FAISS: {list(baseline_ids)}")
print(f" LEANN: {list(test_ids)}")
print(f" ∩: {list(intersection)}")
avg = total_recall / max(1, len(queries))
print(f"📊 Average Recall@3: {avg:.3f} ({avg * 100:.1f}%)")
return avg
def cleanup(self):
if hasattr(self, "searcher"):
self.searcher.cleanup()
class EnronEvaluator:
def __init__(self, index_path: str):
self.index_path = index_path
self.searcher = LeannSearcher(index_path)
def load_queries(self, queries_file: str) -> list[str]:
queries: list[str] = []
with open(queries_file, encoding="utf-8") as f:
for line in f:
if not line.strip():
continue
data = json.loads(line)
if "query" in data:
queries.append(data["query"])
print(f"📊 Loaded {len(queries)} queries from {queries_file}")
return queries
def cleanup(self):
if self.searcher:
self.searcher.cleanup()
def analyze_index_sizes(self) -> dict:
"""Analyze index sizes (.index only), similar to LAION bench."""
print("📏 Analyzing index sizes (.index only)...")
index_path = Path(self.index_path)
index_dir = index_path.parent
index_name = index_path.stem
sizes: dict[str, float] = {}
index_file = index_dir / f"{index_name}.index"
meta_file = index_dir / f"{index_path.name}.meta.json"
passages_file = index_dir / f"{index_path.name}.passages.jsonl"
passages_idx_file = index_dir / f"{index_path.name}.passages.idx"
sizes["index_only_mb"] = (
index_file.stat().st_size / (1024 * 1024) if index_file.exists() else 0.0
)
sizes["metadata_mb"] = (
meta_file.stat().st_size / (1024 * 1024) if meta_file.exists() else 0.0
)
sizes["passages_text_mb"] = (
passages_file.stat().st_size / (1024 * 1024) if passages_file.exists() else 0.0
)
sizes["passages_index_mb"] = (
passages_idx_file.stat().st_size / (1024 * 1024) if passages_idx_file.exists() else 0.0
)
print(f" 📁 .index size: {sizes['index_only_mb']:.1f} MB")
return sizes
def create_non_compact_index_for_comparison(self, non_compact_index_path: str) -> dict:
"""Create a non-compact index for comparison using current passages and embeddings."""
current_index_path = Path(self.index_path)
current_index_dir = current_index_path.parent
current_index_name = current_index_path.name
# Read metadata to get passage source and embedding model
meta_path = current_index_dir / f"{current_index_name}.meta.json"
with open(meta_path, encoding="utf-8") as f:
meta = json.load(f)
passage_source = meta["passage_sources"][0]
passage_file = passage_source["path"]
# Convert relative path to absolute
if not Path(passage_file).is_absolute():
passage_file = current_index_dir / Path(passage_file).name
# Load all passages and ids
ids: list[str] = []
texts: list[str] = []
with open(passage_file, encoding="utf-8") as f:
for line in f:
if line.strip():
data = json.loads(line)
ids.append(str(data["id"]))
texts.append(data["text"])
# Compute embeddings using the same method as LEANN
from leann.api import compute_embeddings
embeddings = compute_embeddings(
texts,
meta["embedding_model"],
mode=meta.get("embedding_mode", "sentence-transformers"),
use_server=False,
).astype(np.float32)
# Build non-compact index with same passages and embeddings
builder = LeannBuilder(
backend_name="hnsw",
embedding_model=meta["embedding_model"],
embedding_mode=meta.get("embedding_mode", "sentence-transformers"),
is_recompute=False,
is_compact=False,
**{
k: v
for k, v in meta.get("backend_kwargs", {}).items()
if k not in ["is_recompute", "is_compact"]
},
)
# Persist a pickle for build_index_from_embeddings
pkl_path = current_index_dir / f"{Path(non_compact_index_path).stem}_embeddings.pkl"
with open(pkl_path, "wb") as pf:
pickle.dump((ids, embeddings), pf)
print(
f"🔨 Building non-compact index at {non_compact_index_path} from precomputed embeddings..."
)
builder.build_index_from_embeddings(non_compact_index_path, str(pkl_path))
# Analyze the non-compact index size
temp_evaluator = EnronEvaluator(non_compact_index_path)
non_compact_sizes = temp_evaluator.analyze_index_sizes()
non_compact_sizes["index_type"] = "non_compact"
return non_compact_sizes
def compare_index_performance(
self, non_compact_path: str, compact_path: str, test_queries: list[str], complexity: int
) -> dict:
"""Compare search speed for non-compact vs compact indexes."""
import time
results: dict = {
"non_compact": {"search_times": []},
"compact": {"search_times": []},
"avg_search_times": {},
"speed_ratio": 0.0,
"retrieval_results": [], # Store retrieval results for Stage 5
}
print("⚡ Comparing search performance between indexes...")
# Non-compact (no recompute)
print(" 🔍 Testing non-compact index (no recompute)...")
non_compact_searcher = LeannSearcher(non_compact_path)
for q in test_queries:
t0 = time.time()
_ = non_compact_searcher.search(
q, top_k=3, complexity=complexity, recompute_embeddings=False
)
results["non_compact"]["search_times"].append(time.time() - t0)
# Compact (with recompute). Fail fast if it cannot run.
print(" 🔍 Testing compact index (with recompute)...")
compact_searcher = LeannSearcher(compact_path)
for q in test_queries:
t0 = time.time()
docs = compact_searcher.search(
q, top_k=3, complexity=complexity, recompute_embeddings=True
)
results["compact"]["search_times"].append(time.time() - t0)
# Store retrieval results for Stage 5
results["retrieval_results"].append(
{"query": q, "retrieved_docs": [{"id": doc.id, "text": doc.text} for doc in docs]}
)
compact_searcher.cleanup()
if results["non_compact"]["search_times"]:
results["avg_search_times"]["non_compact"] = sum(
results["non_compact"]["search_times"]
) / len(results["non_compact"]["search_times"])
if results["compact"]["search_times"]:
results["avg_search_times"]["compact"] = sum(results["compact"]["search_times"]) / len(
results["compact"]["search_times"]
)
if results["avg_search_times"].get("compact", 0) > 0:
results["speed_ratio"] = (
results["avg_search_times"]["non_compact"] / results["avg_search_times"]["compact"]
)
else:
results["speed_ratio"] = 0.0
non_compact_searcher.cleanup()
return results
def evaluate_complexity(
self,
recall_eval: "RecallEvaluator",
queries: list[str],
target: float = 0.90,
c_min: int = 8,
c_max: int = 256,
max_iters: int = 10,
recompute: bool = False,
) -> dict:
"""Binary search minimal complexity achieving target recall (monotonic assumption)."""
def round_c(x: int) -> int:
# snap to multiple of 8 like other benches typically do
return max(1, int((x + 7) // 8) * 8)
metrics: list[dict] = []
lo = round_c(c_min)
hi = round_c(c_max)
print(
f"🧪 Binary search complexity in [{lo}, {hi}] for target Recall@3>={int(target * 100)}%..."
)
# Ensure upper bound can reach target; expand if needed (up to a cap)
r_lo = recall_eval.evaluate_recall_at_3(
queries, complexity=lo, recompute_embeddings=recompute
)
metrics.append({"complexity": lo, "recall_at_3": r_lo})
r_hi = recall_eval.evaluate_recall_at_3(
queries, complexity=hi, recompute_embeddings=recompute
)
metrics.append({"complexity": hi, "recall_at_3": r_hi})
cap = 1024
while r_hi < target and hi < cap:
lo = hi
r_lo = r_hi
hi = round_c(hi * 2)
r_hi = recall_eval.evaluate_recall_at_3(
queries, complexity=hi, recompute_embeddings=recompute
)
metrics.append({"complexity": hi, "recall_at_3": r_hi})
if r_hi < target:
print(f"⚠️ Max complexity {hi} did not reach target recall {target:.2f}.")
print("📈 Observations:")
for m in metrics:
print(f" C={m['complexity']:>4} -> Recall@3={m['recall_at_3'] * 100:.1f}%")
return {"metrics": metrics, "best_complexity": None, "target_recall": target}
# Binary search within [lo, hi]
best = hi
iters = 0
while lo < hi and iters < max_iters:
mid = round_c((lo + hi) // 2)
r_mid = recall_eval.evaluate_recall_at_3(
queries, complexity=mid, recompute_embeddings=recompute
)
metrics.append({"complexity": mid, "recall_at_3": r_mid})
if r_mid >= target:
best = mid
hi = mid
else:
lo = mid + 8 # move past mid, respecting multiple-of-8 step
iters += 1
print("📈 Binary search results (sampled points):")
# Print unique complexity entries ordered by complexity
for m in sorted(
{m["complexity"]: m for m in metrics}.values(), key=lambda x: x["complexity"]
):
print(f" C={m['complexity']:>4} -> Recall@3={m['recall_at_3'] * 100:.1f}%")
print(f"✅ Minimal complexity achieving {int(target * 100)}% recall: {best}")
return {"metrics": metrics, "best_complexity": best, "target_recall": target}
def main():
parser = argparse.ArgumentParser(description="Enron Emails Benchmark Evaluation")
parser.add_argument("--index", required=True, help="Path to LEANN index")
parser.add_argument(
"--queries", default="data/evaluation_queries.jsonl", help="Path to evaluation queries"
)
parser.add_argument(
"--stage",
choices=["2", "3", "4", "5", "all", "45"],
default="all",
help="Which stage to run (2=recall, 3=complexity, 4=index comparison, 5=generation)",
)
parser.add_argument("--complexity", type=int, default=None, help="LEANN search complexity")
parser.add_argument("--baseline-dir", default="baseline", help="Baseline output directory")
parser.add_argument(
"--max-queries", type=int, help="Limit number of queries to evaluate", default=1000
)
parser.add_argument(
"--target-recall", type=float, default=0.90, help="Target Recall@3 for Stage 3"
)
parser.add_argument("--output", help="Save results to JSON file")
parser.add_argument("--llm-backend", choices=["hf", "vllm"], default="hf", help="LLM backend")
parser.add_argument("--model-name", default="Qwen/Qwen3-8B", help="Model name")
args = parser.parse_args()
# Resolve queries file: if default path not found, fall back to index's directory
if not os.path.exists(args.queries):
from pathlib import Path
idx_dir = Path(args.index).parent
fallback_q = idx_dir / "evaluation_queries.jsonl"
if fallback_q.exists():
args.queries = str(fallback_q)
baseline_index_path = os.path.join(args.baseline_dir, "faiss_flat.index")
if not os.path.exists(baseline_index_path):
print(f"❌ FAISS baseline not found at {baseline_index_path}")
print("💡 Please run setup_enron_emails.py first to build the baseline")
raise SystemExit(1)
results_out: dict = {}
if args.stage in ("2", "all"):
print("🚀 Starting Stage 2: Recall@3 evaluation")
evaluator = RecallEvaluator(args.index, args.baseline_dir)
enron_eval = EnronEvaluator(args.index)
queries = enron_eval.load_queries(args.queries)
queries = queries[:10]
print(f"🧪 Using first {len(queries)} queries")
complexity = args.complexity or 64
r = evaluator.evaluate_recall_at_3(queries, complexity)
results_out["stage2"] = {"complexity": complexity, "recall_at_3": r}
evaluator.cleanup()
enron_eval.cleanup()
print("✅ Stage 2 completed!\n")
if args.stage in ("3", "all"):
print("🚀 Starting Stage 3: Binary search for target recall (no recompute)")
enron_eval = EnronEvaluator(args.index)
queries = enron_eval.load_queries(args.queries)
queries = queries[: args.max_queries]
print(f"🧪 Using first {len(queries)} queries")
# Build non-compact index for fast binary search (recompute_embeddings=False)
from pathlib import Path
index_path = Path(args.index)
non_compact_index_path = str(index_path.parent / f"{index_path.stem}_noncompact.leann")
enron_eval.create_non_compact_index_for_comparison(non_compact_index_path)
# Use non-compact evaluator for binary search with recompute=False
evaluator_nc = RecallEvaluator(non_compact_index_path, args.baseline_dir)
sweep = enron_eval.evaluate_complexity(
evaluator_nc, queries, target=args.target_recall, recompute=False
)
results_out["stage3"] = sweep
# Persist default stage 3 results near the index for Stage 4 auto-pickup
from pathlib import Path
default_stage3_path = Path(args.index).parent / "enron_stage3_results.json"
with open(default_stage3_path, "w", encoding="utf-8") as f:
json.dump({"stage3": sweep}, f, indent=2)
print(f"📝 Saved Stage 3 summary to {default_stage3_path}")
evaluator_nc.cleanup()
enron_eval.cleanup()
print("✅ Stage 3 completed!\n")
if args.stage in ("4", "all", "45"):
print("🚀 Starting Stage 4: Index size + performance comparison")
evaluator = RecallEvaluator(args.index, args.baseline_dir)
enron_eval = EnronEvaluator(args.index)
queries = enron_eval.load_queries(args.queries)
test_q = queries[: min(args.max_queries, len(queries))]
current_sizes = enron_eval.analyze_index_sizes()
# Build non-compact index for comparison (no fallback)
from pathlib import Path
index_path = Path(args.index)
non_compact_path = str(index_path.parent / f"{index_path.stem}_noncompact.leann")
non_compact_sizes = enron_eval.create_non_compact_index_for_comparison(non_compact_path)
nc_eval = EnronEvaluator(non_compact_path)
if (
current_sizes.get("index_only_mb", 0) > 0
and non_compact_sizes.get("index_only_mb", 0) > 0
):
storage_saving_percent = max(
0.0,
100.0 * (1.0 - current_sizes["index_only_mb"] / non_compact_sizes["index_only_mb"]),
)
else:
storage_saving_percent = 0.0
if args.complexity is None:
# Prefer in-session Stage 3 result
if "stage3" in results_out and results_out["stage3"].get("best_complexity") is not None:
complexity = results_out["stage3"]["best_complexity"]
print(f"📥 Using best complexity from Stage 3 in-session: {complexity}")
else:
# Try to load last saved Stage 3 result near index
default_stage3_path = Path(args.index).parent / "enron_stage3_results.json"
if default_stage3_path.exists():
with open(default_stage3_path, encoding="utf-8") as f:
prev = json.load(f)
complexity = prev.get("stage3", {}).get("best_complexity")
if complexity is None:
raise SystemExit(
"❌ Stage 4: No --complexity and no best_complexity found in saved Stage 3 results"
)
print(f"📥 Using best complexity from saved Stage 3: {complexity}")
else:
raise SystemExit(
"❌ Stage 4 requires --complexity if Stage 3 hasn't been run. Run stage 3 first or pass --complexity."
)
else:
complexity = args.complexity
comp = enron_eval.compare_index_performance(
non_compact_path, args.index, test_q, complexity=complexity
)
results_out["stage4"] = {
"current_index": current_sizes,
"non_compact_index": non_compact_sizes,
"storage_saving_percent": storage_saving_percent,
"performance_comparison": comp,
}
nc_eval.cleanup()
evaluator.cleanup()
enron_eval.cleanup()
print("✅ Stage 4 completed!\n")
if args.stage in ("5", "all"):
print("🚀 Starting Stage 5: Generation evaluation with Qwen3-8B")
# Check if Stage 4 results exist
if "stage4" not in results_out or "performance_comparison" not in results_out["stage4"]:
print("❌ Stage 5 requires Stage 4 retrieval results")
print("💡 Run Stage 4 first or use --stage all")
raise SystemExit(1)
retrieval_results = results_out["stage4"]["performance_comparison"]["retrieval_results"]
if not retrieval_results:
print("❌ No retrieval results found from Stage 4")
raise SystemExit(1)
print(f"📁 Using {len(retrieval_results)} retrieval results from Stage 4")
# Load LLM
try:
if args.llm_backend == "hf":
tokenizer, model = load_hf_model(args.model_name)
def llm_func(prompt):
return generate_hf(tokenizer, model, prompt)
else: # vllm
llm, sampling_params = load_vllm_model(args.model_name)
def llm_func(prompt):
return generate_vllm(llm, sampling_params, prompt)
# Run generation using stored retrieval results
import time
from llm_utils import create_prompt
generation_times = []
responses = []
print("🤖 Running generation on pre-retrieved results...")
for i, item in enumerate(retrieval_results):
query = item["query"]
retrieved_docs = item["retrieved_docs"]
# Prepare context from retrieved docs
context = "\n\n".join([doc["text"] for doc in retrieved_docs])
prompt = create_prompt(context, query, "emails")
# Time generation only
gen_start = time.time()
response = llm_func(prompt)
gen_time = time.time() - gen_start
generation_times.append(gen_time)
responses.append(response)
if i < 3:
print(f" Q{i + 1}: Gen={gen_time:.3f}s")
avg_gen_time = sum(generation_times) / len(generation_times)
print("\n📊 Generation Results:")
print(f" Total Queries: {len(retrieval_results)}")
print(f" Avg Generation Time: {avg_gen_time:.3f}s")
print(" (Search time from Stage 4)")
results_out["stage5"] = {
"total_queries": len(retrieval_results),
"avg_generation_time": avg_gen_time,
"generation_times": generation_times,
"responses": responses,
}
# Show sample results
print("\n📝 Sample Results:")
for i in range(min(3, len(retrieval_results))):
query = retrieval_results[i]["query"]
response = responses[i]
print(f" Q{i + 1}: {query[:60]}...")
print(f" A{i + 1}: {response[:100]}...")
print()
except Exception as e:
print(f"❌ Generation evaluation failed: {e}")
print("💡 Make sure transformers/vllm is installed and model is available")
print("✅ Stage 5 completed!\n")
if args.output and results_out:
with open(args.output, "w", encoding="utf-8") as f:
json.dump(results_out, f, indent=2)
print(f"📝 Saved results to {args.output}")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,359 @@
"""
Enron Emails Benchmark Setup Script
Prepares passages from emails.csv, builds LEANN index, and FAISS Flat baseline
"""
import argparse
import csv
import json
import os
import re
from collections.abc import Iterable
from email import message_from_string
from email.policy import default
from pathlib import Path
from typing import Optional
from leann import LeannBuilder
class EnronSetup:
def __init__(self, data_dir: str = "data"):
self.data_dir = Path(data_dir)
self.data_dir.mkdir(parents=True, exist_ok=True)
self.passages_preview = self.data_dir / "enron_passages_preview.jsonl"
self.index_path = self.data_dir / "enron_index_hnsw.leann"
self.queries_file = self.data_dir / "evaluation_queries.jsonl"
self.downloads_dir = self.data_dir / "downloads"
self.downloads_dir.mkdir(parents=True, exist_ok=True)
# ----------------------------
# Dataset acquisition
# ----------------------------
def ensure_emails_csv(self, emails_csv: Optional[str]) -> str:
"""Return a path to emails.csv, downloading from Kaggle if needed."""
if emails_csv:
p = Path(emails_csv)
if not p.exists():
raise FileNotFoundError(f"emails.csv not found: {emails_csv}")
return str(p)
print(
"📥 Trying to download Enron emails.csv from Kaggle (wcukierski/enron-email-dataset)..."
)
try:
from kaggle.api.kaggle_api_extended import KaggleApi
api = KaggleApi()
api.authenticate()
api.dataset_download_files(
"wcukierski/enron-email-dataset", path=str(self.downloads_dir), unzip=True
)
candidate = self.downloads_dir / "emails.csv"
if candidate.exists():
print(f"✅ Downloaded emails.csv: {candidate}")
return str(candidate)
else:
raise FileNotFoundError(
f"emails.csv was not found in {self.downloads_dir} after Kaggle download"
)
except Exception as e:
print(
"❌ Could not download via Kaggle automatically. Provide --emails-csv or configure Kaggle API."
)
print(
" Set KAGGLE_USERNAME and KAGGLE_KEY env vars, or place emails.csv locally and pass --emails-csv."
)
raise e
# ----------------------------
# Data preparation
# ----------------------------
@staticmethod
def _extract_message_id(raw_email: str) -> str:
msg = message_from_string(raw_email, policy=default)
val = msg.get("Message-ID", "")
if val.startswith("<") and val.endswith(">"):
val = val[1:-1]
return val or ""
@staticmethod
def _split_header_body(raw_email: str) -> tuple[str, str]:
parts = raw_email.split("\n\n", 1)
if len(parts) == 2:
return parts[0].strip(), parts[1].strip()
# Heuristic fallback
first_lines = raw_email.splitlines()
if first_lines and ":" in first_lines[0]:
return raw_email.strip(), ""
return "", raw_email.strip()
@staticmethod
def _split_fixed_words(text: str, chunk_words: int, keep_last: bool) -> list[str]:
text = (text or "").strip()
if not text:
return []
if chunk_words <= 0:
return [text]
words = text.split()
if not words:
return []
limit = len(words)
if not keep_last:
limit = (len(words) // chunk_words) * chunk_words
if limit == 0:
return []
chunks = [" ".join(words[i : i + chunk_words]) for i in range(0, limit, chunk_words)]
return [c for c in (s.strip() for s in chunks) if c]
def _iter_passages_from_csv(
self,
emails_csv: Path,
chunk_words: int = 256,
keep_last_header: bool = True,
keep_last_body: bool = True,
max_emails: int | None = None,
) -> Iterable[dict]:
with open(emails_csv, encoding="utf-8") as f:
reader = csv.DictReader(f)
count = 0
for i, row in enumerate(reader):
if max_emails is not None and count >= max_emails:
break
raw_message = row.get("message", "")
email_file_id = row.get("file", "")
if not raw_message.strip():
continue
message_id = self._extract_message_id(raw_message)
if not message_id:
# Fallback ID based on CSV position and file path
safe_file = re.sub(r"[^A-Za-z0-9_.-]", "_", email_file_id)
message_id = f"enron_{i}_{safe_file}"
header, body = self._split_header_body(raw_message)
# Header chunks
for chunk in self._split_fixed_words(header, chunk_words, keep_last_header):
yield {
"text": chunk,
"metadata": {
"message_id": message_id,
"is_header": True,
"email_file_id": email_file_id,
},
}
# Body chunks
for chunk in self._split_fixed_words(body, chunk_words, keep_last_body):
yield {
"text": chunk,
"metadata": {
"message_id": message_id,
"is_header": False,
"email_file_id": email_file_id,
},
}
count += 1
# ----------------------------
# Build LEANN index and FAISS baseline
# ----------------------------
def build_leann_index(
self,
emails_csv: Optional[str],
backend: str = "hnsw",
embedding_model: str = "sentence-transformers/all-mpnet-base-v2",
chunk_words: int = 256,
max_emails: int | None = None,
) -> str:
emails_csv_path = self.ensure_emails_csv(emails_csv)
print(f"🏗️ Building LEANN index from {emails_csv_path}...")
builder = LeannBuilder(
backend_name=backend,
embedding_model=embedding_model,
embedding_mode="sentence-transformers",
graph_degree=32,
complexity=64,
is_recompute=True,
is_compact=True,
num_threads=4,
)
# Stream passages and add to builder
preview_written = 0
with open(self.passages_preview, "w", encoding="utf-8") as preview_out:
for p in self._iter_passages_from_csv(
Path(emails_csv_path), chunk_words=chunk_words, max_emails=max_emails
):
builder.add_text(p["text"], metadata=p["metadata"])
if preview_written < 200:
preview_out.write(json.dumps({"text": p["text"][:200], **p["metadata"]}) + "\n")
preview_written += 1
print(f"🔨 Building index at {self.index_path}...")
builder.build_index(str(self.index_path))
print("✅ LEANN index built!")
return str(self.index_path)
def build_faiss_flat_baseline(self, index_path: str, output_dir: str = "baseline") -> str:
print("🔨 Building FAISS Flat baseline from LEANN passages...")
import pickle
import numpy as np
from leann.api import compute_embeddings
from leann_backend_hnsw import faiss
os.makedirs(output_dir, exist_ok=True)
baseline_path = os.path.join(output_dir, "faiss_flat.index")
metadata_path = os.path.join(output_dir, "metadata.pkl")
if os.path.exists(baseline_path) and os.path.exists(metadata_path):
print(f"✅ Baseline already exists at {baseline_path}")
return baseline_path
# Read meta for passage source and embedding model
meta_path = f"{index_path}.meta.json"
with open(meta_path, encoding="utf-8") as f:
meta = json.load(f)
embedding_model = meta["embedding_model"]
passage_source = meta["passage_sources"][0]
passage_file = passage_source["path"]
if not os.path.isabs(passage_file):
index_dir = os.path.dirname(index_path)
passage_file = os.path.join(index_dir, os.path.basename(passage_file))
# Load passages from builder output so IDs match LEANN
passages: list[str] = []
passage_ids: list[str] = []
with open(passage_file, encoding="utf-8") as f:
for line in f:
if not line.strip():
continue
data = json.loads(line)
passages.append(data["text"])
passage_ids.append(data["id"]) # builder-assigned ID
print(f"📄 Loaded {len(passages)} passages for baseline")
print(f"🤖 Embedding model: {embedding_model}")
embeddings = compute_embeddings(
passages,
embedding_model,
mode="sentence-transformers",
use_server=False,
)
# Build FAISS IndexFlatIP
dim = embeddings.shape[1]
index = faiss.IndexFlatIP(dim)
emb_f32 = embeddings.astype(np.float32)
index.add(emb_f32.shape[0], faiss.swig_ptr(emb_f32))
faiss.write_index(index, baseline_path)
with open(metadata_path, "wb") as pf:
pickle.dump(passage_ids, pf)
print(f"✅ FAISS baseline saved: {baseline_path}")
print(f"✅ Metadata saved: {metadata_path}")
print(f"📊 Total vectors: {index.ntotal}")
return baseline_path
# ----------------------------
# Queries (optional): prepare evaluation queries file
# ----------------------------
def prepare_queries(self, min_realism: float = 0.85) -> Path:
print(
"📝 Preparing evaluation queries from HuggingFace dataset corbt/enron_emails_sample_questions ..."
)
try:
from datasets import load_dataset
ds = load_dataset("corbt/enron_emails_sample_questions", split="train")
except Exception as e:
print(f"⚠️ Failed to load dataset: {e}")
return self.queries_file
kept = 0
with open(self.queries_file, "w", encoding="utf-8") as out:
for i, item in enumerate(ds):
how_realistic = float(item.get("how_realistic", 0.0))
if how_realistic < min_realism:
continue
qid = str(item.get("id", f"enron_q_{i}"))
query = item.get("question", "")
if not query:
continue
record = {
"id": qid,
"query": query,
# For reference only, not used in recall metric below
"gt_message_ids": item.get("message_ids", []),
}
out.write(json.dumps(record) + "\n")
kept += 1
print(f"✅ Wrote {kept} queries to {self.queries_file}")
return self.queries_file
def main():
parser = argparse.ArgumentParser(description="Setup Enron Emails Benchmark")
parser.add_argument(
"--emails-csv",
help="Path to emails.csv (Enron dataset). If omitted, attempt Kaggle download.",
)
parser.add_argument("--data-dir", default="data", help="Data directory")
parser.add_argument("--backend", choices=["hnsw", "diskann"], default="hnsw")
parser.add_argument(
"--embedding-model",
default="sentence-transformers/all-mpnet-base-v2",
help="Embedding model for LEANN",
)
parser.add_argument("--chunk-words", type=int, default=256, help="Fixed word chunk size")
parser.add_argument("--max-emails", type=int, help="Limit number of emails to process")
parser.add_argument("--skip-queries", action="store_true", help="Skip creating queries file")
parser.add_argument("--skip-build", action="store_true", help="Skip building LEANN index")
args = parser.parse_args()
setup = EnronSetup(args.data_dir)
# Build index
if not args.skip_build:
index_path = setup.build_leann_index(
emails_csv=args.emails_csv,
backend=args.backend,
embedding_model=args.embedding_model,
chunk_words=args.chunk_words,
max_emails=args.max_emails,
)
# Build FAISS baseline from the same passages & embeddings
setup.build_faiss_flat_baseline(index_path)
else:
print("⏭️ Skipping LEANN index build and baseline")
# Queries file (optional)
if not args.skip_queries:
setup.prepare_queries()
else:
print("⏭️ Skipping query preparation")
print("\n🎉 Enron Emails setup completed!")
print(f"📁 Data directory: {setup.data_dir.absolute()}")
print("Next steps:")
print(
"1) Evaluate recall: python evaluate_enron_emails.py --index data/enron_index_hnsw.leann --stage 2"
)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,115 @@
# FinanceBench Benchmark for LEANN-RAG
FinanceBench is a benchmark for evaluating retrieval-augmented generation (RAG) systems on financial document question-answering tasks.
## Dataset
- **Source**: [PatronusAI/financebench](https://huggingface.co/datasets/PatronusAI/financebench)
- **Questions**: 150 financial Q&A examples
- **Documents**: 368 PDF files (10-K, 10-Q, 8-K, earnings reports)
- **Companies**: Major public companies (3M, Apple, Microsoft, Amazon, etc.)
- **Paper**: [FinanceBench: A New Benchmark for Financial Question Answering](https://arxiv.org/abs/2311.11944)
## Structure
```
benchmarks/financebench/
├── setup_financebench.py # Downloads PDFs and builds index
├── evaluate_financebench.py # Intelligent evaluation script
├── data/
│ ├── financebench_merged.jsonl # Q&A dataset
│ ├── pdfs/ # Downloaded financial documents
│ └── index/ # LEANN indexes
│ └── financebench_full_hnsw.leann
└── README.md
```
## Usage
### 1. Setup (Download & Build Index)
```bash
cd benchmarks/financebench
python setup_financebench.py
```
This will:
- Download the 150 Q&A examples
- Download all 368 PDF documents (parallel processing)
- Build a LEANN index from 53K+ text chunks
- Verify setup with test query
### 2. Evaluation
```bash
# Basic retrieval evaluation
python evaluate_financebench.py --index data/index/financebench_full_hnsw.leann
# RAG generation evaluation with Qwen3-8B
python evaluate_financebench.py --index data/index/financebench_full_hnsw.leann --stage 4 --complexity 64 --llm-backend hf --model-name Qwen/Qwen3-8B --output results_qwen3.json
```
## Evaluation Methods
### Retrieval Evaluation
Uses intelligent matching with three strategies:
1. **Exact text overlap** - Direct substring matches
2. **Number matching** - Key financial figures ($1,577, 1.2B, etc.)
3. **Semantic similarity** - Word overlap with 20% threshold
### QA Evaluation
LLM-based answer evaluation using GPT-4o:
- Handles numerical rounding and equivalent representations
- Considers fractions, percentages, and decimal equivalents
- Evaluates semantic meaning rather than exact text match
## Benchmark Results
### LEANN-RAG Performance (sentence-transformers/all-mpnet-base-v2)
**Retrieval Metrics:**
- **Question Coverage**: 100.0% (all questions retrieve relevant docs)
- **Exact Match Rate**: 0.7% (substring overlap with evidence)
- **Number Match Rate**: 120.7% (key financial figures matched)*
- **Semantic Match Rate**: 4.7% (word overlap ≥20%)
- **Average Search Time**: 0.097s
**QA Metrics:**
- **Accuracy**: 42.7% (LLM-evaluated answer correctness)
- **Average QA Time**: 4.71s (end-to-end response time)
**System Performance:**
- **Index Size**: 53,985 chunks from 368 PDFs
- **Build Time**: ~5-10 minutes with sentence-transformers/all-mpnet-base-v2
*Note: Number match rate >100% indicates multiple retrieved documents contain the same financial figures, which is expected behavior for financial data appearing across multiple document sections.
### LEANN-RAG Generation Performance (Qwen3-8B)
- **Stage 4 (Index Comparison):**
- Compact Index: 5.0 MB
- Non-compact Index: 172.2 MB
- **Storage Saving**: 97.1%
- **Search Performance**:
- Non-compact (no recompute): 0.009s avg per query
- Compact (with recompute): 2.203s avg per query
- Speed ratio: 0.004x
**Generation Evaluation (20 queries, complexity=64):**
- **Average Search Time**: 1.638s per query
- **Average Generation Time**: 45.957s per query
- **LLM Backend**: HuggingFace transformers
- **Model**: Qwen/Qwen3-8B (thinking model with <think></think> processing)
- **Total Questions Processed**: 20
## Options
```bash
# Use different backends
python setup_financebench.py --backend diskann
python evaluate_financebench.py --index data/index/financebench_full_diskann.leann
# Use different embedding models
python setup_financebench.py --embedding-model facebook/contriever
```

View File

@@ -0,0 +1,923 @@
"""
FinanceBench Evaluation Script - Modular Recall-based Evaluation
"""
import argparse
import json
import logging
import os
import pickle
import time
from pathlib import Path
from typing import Optional
import numpy as np
import openai
from leann import LeannChat, LeannSearcher
from leann_backend_hnsw import faiss
from ..llm_utils import evaluate_rag, generate_hf, generate_vllm, load_hf_model, load_vllm_model
# Setup logging to reduce verbose output
logging.basicConfig(level=logging.WARNING)
logging.getLogger("leann.api").setLevel(logging.WARNING)
logging.getLogger("leann_backend_hnsw").setLevel(logging.WARNING)
class RecallEvaluator:
"""Stage 2: Evaluate Recall@3 (searcher vs baseline)"""
def __init__(self, index_path: str, baseline_dir: str):
self.index_path = index_path
self.baseline_dir = baseline_dir
self.searcher = LeannSearcher(index_path)
# Load FAISS flat baseline
baseline_index_path = os.path.join(baseline_dir, "faiss_flat.index")
metadata_path = os.path.join(baseline_dir, "metadata.pkl")
self.faiss_index = faiss.read_index(baseline_index_path)
with open(metadata_path, "rb") as f:
self.passage_ids = pickle.load(f)
print(f"📚 Loaded FAISS flat baseline with {self.faiss_index.ntotal} vectors")
def evaluate_recall_at_3(
self, queries: list[str], complexity: int = 64, recompute_embeddings: bool = True
) -> float:
"""Evaluate recall@3 for given queries at specified complexity"""
recompute_str = "with recompute" if recompute_embeddings else "no recompute"
print(f"🔍 Evaluating recall@3 with complexity={complexity} ({recompute_str})...")
total_recall = 0.0
num_queries = len(queries)
for i, query in enumerate(queries):
# Get ground truth: search with FAISS flat
from leann.api import compute_embeddings
query_embedding = compute_embeddings(
[query],
self.searcher.embedding_model,
mode=self.searcher.embedding_mode,
use_server=False,
).astype(np.float32)
# Search FAISS flat for ground truth using LEANN's modified faiss API
n = query_embedding.shape[0] # Number of queries
k = 3 # Number of nearest neighbors
distances = np.zeros((n, k), dtype=np.float32)
labels = np.zeros((n, k), dtype=np.int64)
self.faiss_index.search(
n,
faiss.swig_ptr(query_embedding),
k,
faiss.swig_ptr(distances),
faiss.swig_ptr(labels),
)
# Extract the results
baseline_ids = {self.passage_ids[idx] for idx in labels[0]}
# Search with LEANN at specified complexity
test_results = self.searcher.search(
query,
top_k=3,
complexity=complexity,
recompute_embeddings=recompute_embeddings,
)
test_ids = {result.id for result in test_results}
# Calculate recall@3 = |intersection| / |ground_truth|
intersection = test_ids.intersection(baseline_ids)
recall = len(intersection) / 3.0 # Ground truth size is 3
total_recall += recall
if i < 3: # Show first few examples
print(f" Query {i + 1}: '{query[:50]}...' -> Recall@3: {recall:.3f}")
print(f" FAISS ground truth: {list(baseline_ids)}")
print(f" LEANN results (C={complexity}, {recompute_str}): {list(test_ids)}")
print(f" Intersection: {list(intersection)}")
avg_recall = total_recall / num_queries
print(f"📊 Average Recall@3: {avg_recall:.3f} ({avg_recall * 100:.1f}%)")
return avg_recall
def cleanup(self):
"""Cleanup resources"""
if hasattr(self, "searcher"):
self.searcher.cleanup()
class FinanceBenchEvaluator:
def __init__(self, index_path: str, openai_api_key: Optional[str] = None):
self.index_path = index_path
self.openai_client = openai.OpenAI(api_key=openai_api_key) if openai_api_key else None
self.searcher = LeannSearcher(index_path)
self.chat = LeannChat(index_path) if openai_api_key else None
def load_dataset(self, dataset_path: str = "data/financebench_merged.jsonl"):
"""Load FinanceBench dataset"""
data = []
with open(dataset_path, encoding="utf-8") as f:
for line in f:
if line.strip():
data.append(json.loads(line))
print(f"📊 Loaded {len(data)} FinanceBench examples")
return data
def analyze_index_sizes(self) -> dict:
"""Analyze index sizes with and without embeddings"""
print("📏 Analyzing index sizes...")
# Get all index-related files
index_path = Path(self.index_path)
index_dir = index_path.parent
index_name = index_path.stem # Remove .leann extension
sizes = {}
total_with_embeddings = 0
# Core index files
index_file = index_dir / f"{index_name}.index"
meta_file = index_dir / f"{index_path.name}.meta.json" # Keep .leann for meta file
passages_file = index_dir / f"{index_path.name}.passages.jsonl" # Keep .leann for passages
passages_idx_file = index_dir / f"{index_path.name}.passages.idx" # Keep .leann for idx
for file_path, name in [
(index_file, "index"),
(meta_file, "metadata"),
(passages_file, "passages_text"),
(passages_idx_file, "passages_index"),
]:
if file_path.exists():
size_mb = file_path.stat().st_size / (1024 * 1024)
sizes[name] = size_mb
total_with_embeddings += size_mb
else:
sizes[name] = 0
sizes["total_with_embeddings"] = total_with_embeddings
sizes["index_only_mb"] = sizes["index"] # Just the .index file for fair comparison
print(f" 📁 Total index size: {total_with_embeddings:.1f} MB")
print(f" 📁 Index file only: {sizes['index']:.1f} MB")
return sizes
def create_compact_index_for_comparison(self, compact_index_path: str) -> dict:
"""Create a compact index for comparison purposes"""
print("🏗️ Building compact index from existing passages...")
# Load existing passages from current index
from leann import LeannBuilder
current_index_path = Path(self.index_path)
current_index_dir = current_index_path.parent
current_index_name = current_index_path.name
# Read metadata to get passage source
meta_path = current_index_dir / f"{current_index_name}.meta.json"
with open(meta_path) as f:
import json
meta = json.load(f)
passage_source = meta["passage_sources"][0]
passage_file = passage_source["path"]
# Convert relative path to absolute
if not Path(passage_file).is_absolute():
passage_file = current_index_dir / Path(passage_file).name
print(f"📄 Loading passages from {passage_file}...")
# Build compact index with same passages
builder = LeannBuilder(
backend_name="hnsw",
embedding_model=meta["embedding_model"],
embedding_mode=meta.get("embedding_mode", "sentence-transformers"),
is_recompute=True, # Enable recompute (no stored embeddings)
is_compact=True, # Enable compact storage
**meta.get("backend_kwargs", {}),
)
# Load all passages
with open(passage_file, encoding="utf-8") as f:
for line in f:
if line.strip():
data = json.loads(line)
builder.add_text(data["text"], metadata=data.get("metadata", {}))
print(f"🔨 Building compact index at {compact_index_path}...")
builder.build_index(compact_index_path)
# Analyze the compact index size
temp_evaluator = FinanceBenchEvaluator(compact_index_path)
compact_sizes = temp_evaluator.analyze_index_sizes()
compact_sizes["index_type"] = "compact"
return compact_sizes
def create_non_compact_index_for_comparison(self, non_compact_index_path: str) -> dict:
"""Create a non-compact index for comparison purposes"""
print("🏗️ Building non-compact index from existing passages...")
# Load existing passages from current index
from leann import LeannBuilder
current_index_path = Path(self.index_path)
current_index_dir = current_index_path.parent
current_index_name = current_index_path.name
# Read metadata to get passage source
meta_path = current_index_dir / f"{current_index_name}.meta.json"
with open(meta_path) as f:
import json
meta = json.load(f)
passage_source = meta["passage_sources"][0]
passage_file = passage_source["path"]
# Convert relative path to absolute
if not Path(passage_file).is_absolute():
passage_file = current_index_dir / Path(passage_file).name
print(f"📄 Loading passages from {passage_file}...")
# Build non-compact index with same passages
builder = LeannBuilder(
backend_name="hnsw",
embedding_model=meta["embedding_model"],
embedding_mode=meta.get("embedding_mode", "sentence-transformers"),
is_recompute=False, # Disable recompute (store embeddings)
is_compact=False, # Disable compact storage
**{
k: v
for k, v in meta.get("backend_kwargs", {}).items()
if k not in ["is_recompute", "is_compact"]
},
)
# Load all passages
with open(passage_file, encoding="utf-8") as f:
for line in f:
if line.strip():
data = json.loads(line)
builder.add_text(data["text"], metadata=data.get("metadata", {}))
print(f"🔨 Building non-compact index at {non_compact_index_path}...")
builder.build_index(non_compact_index_path)
# Analyze the non-compact index size
temp_evaluator = FinanceBenchEvaluator(non_compact_index_path)
non_compact_sizes = temp_evaluator.analyze_index_sizes()
non_compact_sizes["index_type"] = "non_compact"
return non_compact_sizes
def compare_index_performance(
self, non_compact_path: str, compact_path: str, test_data: list, complexity: int
) -> dict:
"""Compare performance between non-compact and compact indexes"""
print("⚡ Comparing search performance between indexes...")
import time
from leann import LeannSearcher
# Test queries
test_queries = [item["question"] for item in test_data[:5]]
results = {
"non_compact": {"search_times": []},
"compact": {"search_times": []},
"avg_search_times": {},
"speed_ratio": 0.0,
}
# Test non-compact index (no recompute)
print(" 🔍 Testing non-compact index (no recompute)...")
non_compact_searcher = LeannSearcher(non_compact_path)
for query in test_queries:
start_time = time.time()
_ = non_compact_searcher.search(
query, top_k=3, complexity=complexity, recompute_embeddings=False
)
search_time = time.time() - start_time
results["non_compact"]["search_times"].append(search_time)
# Test compact index (with recompute)
print(" 🔍 Testing compact index (with recompute)...")
compact_searcher = LeannSearcher(compact_path)
for query in test_queries:
start_time = time.time()
_ = compact_searcher.search(
query, top_k=3, complexity=complexity, recompute_embeddings=True
)
search_time = time.time() - start_time
results["compact"]["search_times"].append(search_time)
# Calculate averages
results["avg_search_times"]["non_compact"] = sum(
results["non_compact"]["search_times"]
) / len(results["non_compact"]["search_times"])
results["avg_search_times"]["compact"] = sum(results["compact"]["search_times"]) / len(
results["compact"]["search_times"]
)
# Performance ratio
if results["avg_search_times"]["compact"] > 0:
results["speed_ratio"] = (
results["avg_search_times"]["non_compact"] / results["avg_search_times"]["compact"]
)
else:
results["speed_ratio"] = float("inf")
print(
f" Non-compact (no recompute): {results['avg_search_times']['non_compact']:.3f}s avg"
)
print(f" Compact (with recompute): {results['avg_search_times']['compact']:.3f}s avg")
print(f" Speed ratio: {results['speed_ratio']:.2f}x")
# Cleanup
non_compact_searcher.cleanup()
compact_searcher.cleanup()
return results
def evaluate_timing_breakdown(
self, data: list[dict], max_samples: Optional[int] = None
) -> dict:
"""Evaluate timing breakdown and accuracy by hacking LeannChat.ask() for separated timing"""
if not self.chat or not self.openai_client:
print("⚠️ Skipping timing evaluation (no OpenAI API key provided)")
return {
"total_questions": 0,
"avg_search_time": 0.0,
"avg_generation_time": 0.0,
"avg_total_time": 0.0,
"accuracy": 0.0,
}
print("🔍🤖 Evaluating timing breakdown and accuracy (search + generation)...")
if max_samples:
data = data[:max_samples]
print(f"📝 Using first {max_samples} samples for timing evaluation")
search_times = []
generation_times = []
total_times = []
correct_answers = 0
for i, item in enumerate(data):
question = item["question"]
ground_truth = item["answer"]
try:
# Hack: Monkey-patch the ask method to capture internal timing
original_ask = self.chat.ask
captured_search_time = None
captured_generation_time = None
def patched_ask(*args, **kwargs):
nonlocal captured_search_time, captured_generation_time
# Time the search part
search_start = time.time()
results = self.chat.searcher.search(args[0], top_k=3, complexity=64)
captured_search_time = time.time() - search_start
# Time the generation part
context = "\n\n".join([r.text for r in results])
prompt = (
"Here is some retrieved context that might help answer your question:\n\n"
f"{context}\n\n"
f"Question: {args[0]}\n\n"
"Please provide the best answer you can based on this context and your knowledge."
)
generation_start = time.time()
answer = self.chat.llm.ask(prompt)
captured_generation_time = time.time() - generation_start
return answer
# Apply the patch
self.chat.ask = patched_ask
# Time the total QA
total_start = time.time()
generated_answer = self.chat.ask(question)
total_time = time.time() - total_start
# Restore original method
self.chat.ask = original_ask
# Store the timings
search_times.append(captured_search_time)
generation_times.append(captured_generation_time)
total_times.append(total_time)
# Check accuracy using LLM as judge
is_correct = self._check_answer_accuracy(generated_answer, ground_truth, question)
if is_correct:
correct_answers += 1
status = "" if is_correct else ""
print(
f"Question {i + 1}/{len(data)}: {status} Search={captured_search_time:.3f}s, Gen={captured_generation_time:.3f}s, Total={total_time:.3f}s"
)
print(f" GT: {ground_truth}")
print(f" Gen: {generated_answer[:100]}...")
except Exception as e:
print(f" ❌ Error: {e}")
search_times.append(0.0)
generation_times.append(0.0)
total_times.append(0.0)
accuracy = correct_answers / len(data) if data else 0.0
metrics = {
"total_questions": len(data),
"avg_search_time": sum(search_times) / len(search_times) if search_times else 0.0,
"avg_generation_time": sum(generation_times) / len(generation_times)
if generation_times
else 0.0,
"avg_total_time": sum(total_times) / len(total_times) if total_times else 0.0,
"accuracy": accuracy,
"correct_answers": correct_answers,
"search_times": search_times,
"generation_times": generation_times,
"total_times": total_times,
}
return metrics
def _check_answer_accuracy(
self, generated_answer: str, ground_truth: str, question: str
) -> bool:
"""Check if generated answer matches ground truth using LLM as judge"""
judge_prompt = f"""You are an expert judge evaluating financial question answering.
Question: {question}
Ground Truth Answer: {ground_truth}
Generated Answer: {generated_answer}
Task: Determine if the generated answer is factually correct compared to the ground truth. Focus on:
1. Numerical accuracy (exact values, units, currency)
2. Key financial concepts and terminology
3. Overall factual correctness
For financial data, small formatting differences are OK (e.g., "$1,577" vs "1577 million" vs "$1.577 billion"), but the core numerical value must match.
Respond with exactly one word: "CORRECT" if the generated answer is factually accurate, or "INCORRECT" if it's wrong or significantly different."""
try:
judge_response = self.openai_client.chat.completions.create(
model="gpt-4o-mini",
messages=[{"role": "user", "content": judge_prompt}],
max_tokens=10,
temperature=0,
)
judgment = judge_response.choices[0].message.content.strip().upper()
return judgment == "CORRECT"
except Exception as e:
print(f" ⚠️ Judge error: {e}, falling back to string matching")
# Fallback to simple string matching
gen_clean = generated_answer.strip().lower().replace("$", "").replace(",", "")
gt_clean = ground_truth.strip().lower().replace("$", "").replace(",", "")
return gt_clean in gen_clean
def _print_results(self, timing_metrics: dict):
"""Print evaluation results"""
print("\n🎯 EVALUATION RESULTS")
print("=" * 50)
# Index comparison analysis
if "current_index" in timing_metrics and "non_compact_index" in timing_metrics:
print("\n📏 Index Comparison Analysis:")
current = timing_metrics["current_index"]
non_compact = timing_metrics["non_compact_index"]
print(f" Compact index (current): {current.get('total_with_embeddings', 0):.1f} MB")
print(
f" Non-compact index (with embeddings): {non_compact.get('total_with_embeddings', 0):.1f} MB"
)
print(
f" Storage saving by compact: {timing_metrics.get('storage_saving_percent', 0):.1f}%"
)
print(" Component breakdown (non-compact):")
print(f" - Main index: {non_compact.get('index', 0):.1f} MB")
print(f" - Passages text: {non_compact.get('passages_text', 0):.1f} MB")
print(f" - Passages index: {non_compact.get('passages_index', 0):.1f} MB")
print(f" - Metadata: {non_compact.get('metadata', 0):.1f} MB")
# Performance comparison
if "performance_comparison" in timing_metrics:
perf = timing_metrics["performance_comparison"]
print("\n⚡ Performance Comparison:")
print(
f" Non-compact (no recompute): {perf.get('avg_search_times', {}).get('non_compact', 0):.3f}s avg"
)
print(
f" Compact (with recompute): {perf.get('avg_search_times', {}).get('compact', 0):.3f}s avg"
)
print(f" Speed ratio: {perf.get('speed_ratio', 0):.2f}x")
# Legacy single index analysis (fallback)
if "total_with_embeddings" in timing_metrics and "current_index" not in timing_metrics:
print("\n📏 Index Size Analysis:")
print(f" Total index size: {timing_metrics.get('total_with_embeddings', 0):.1f} MB")
print("\n📊 Accuracy:")
print(f" Accuracy: {timing_metrics.get('accuracy', 0) * 100:.1f}%")
print(
f" Correct Answers: {timing_metrics.get('correct_answers', 0)}/{timing_metrics.get('total_questions', 0)}"
)
print("\n📊 Timing Breakdown:")
print(f" Total Questions: {timing_metrics.get('total_questions', 0)}")
print(f" Avg Search Time: {timing_metrics.get('avg_search_time', 0):.3f}s")
print(f" Avg Generation Time: {timing_metrics.get('avg_generation_time', 0):.3f}s")
print(f" Avg Total Time: {timing_metrics.get('avg_total_time', 0):.3f}s")
if timing_metrics.get("avg_total_time", 0) > 0:
search_pct = (
timing_metrics.get("avg_search_time", 0)
/ timing_metrics.get("avg_total_time", 1)
* 100
)
gen_pct = (
timing_metrics.get("avg_generation_time", 0)
/ timing_metrics.get("avg_total_time", 1)
* 100
)
print("\n📈 Time Distribution:")
print(f" Search: {search_pct:.1f}%")
print(f" Generation: {gen_pct:.1f}%")
def cleanup(self):
"""Cleanup resources"""
if self.searcher:
self.searcher.cleanup()
def main():
parser = argparse.ArgumentParser(description="Modular FinanceBench Evaluation")
parser.add_argument("--index", required=True, help="Path to LEANN index")
parser.add_argument("--dataset", default="data/financebench_merged.jsonl", help="Dataset path")
parser.add_argument(
"--stage",
choices=["2", "3", "4", "all"],
default="all",
help="Which stage to run (2=recall, 3=complexity, 4=generation)",
)
parser.add_argument("--complexity", type=int, default=None, help="Complexity for search")
parser.add_argument("--baseline-dir", default="baseline", help="Baseline output directory")
parser.add_argument("--openai-api-key", help="OpenAI API key for generation evaluation")
parser.add_argument("--output", help="Save results to JSON file")
parser.add_argument(
"--llm-backend", choices=["openai", "hf", "vllm"], default="openai", help="LLM backend"
)
parser.add_argument("--model-name", default="Qwen3-8B", help="Model name for HF/vLLM")
args = parser.parse_args()
try:
# Check if baseline exists
baseline_index_path = os.path.join(args.baseline_dir, "faiss_flat.index")
if not os.path.exists(baseline_index_path):
print(f"❌ FAISS baseline not found at {baseline_index_path}")
print("💡 Please run setup_financebench.py first to build the baseline")
exit(1)
if args.stage == "2" or args.stage == "all":
# Stage 2: Recall@3 evaluation
print("🚀 Starting Stage 2: Recall@3 evaluation")
evaluator = RecallEvaluator(args.index, args.baseline_dir)
# Load FinanceBench queries for testing
print("📖 Loading FinanceBench dataset...")
queries = []
with open(args.dataset, encoding="utf-8") as f:
for line in f:
if line.strip():
data = json.loads(line)
queries.append(data["question"])
# Test with more queries for robust measurement
test_queries = queries[:2000]
print(f"🧪 Testing with {len(test_queries)} queries")
# Test with complexity 64
complexity = 64
recall = evaluator.evaluate_recall_at_3(test_queries, complexity)
print(f"📈 Recall@3 at complexity {complexity}: {recall * 100:.1f}%")
evaluator.cleanup()
print("✅ Stage 2 completed!\n")
# Shared non-compact index path for Stage 3 and 4
non_compact_index_path = args.index.replace(".leann", "_noncompact.leann")
complexity = args.complexity
if args.stage == "3" or args.stage == "all":
# Stage 3: Binary search for 90% recall complexity (using non-compact index for speed)
print("🚀 Starting Stage 3: Binary search for 90% recall complexity")
print(
"💡 Creating non-compact index for fast binary search with recompute_embeddings=False"
)
# Create non-compact index for binary search (will be reused in Stage 4)
print("🏗️ Creating non-compact index for binary search...")
evaluator = FinanceBenchEvaluator(args.index)
evaluator.create_non_compact_index_for_comparison(non_compact_index_path)
# Use non-compact index for binary search
binary_search_evaluator = RecallEvaluator(non_compact_index_path, args.baseline_dir)
# Load queries for testing
print("📖 Loading FinanceBench dataset...")
queries = []
with open(args.dataset, encoding="utf-8") as f:
for line in f:
if line.strip():
data = json.loads(line)
queries.append(data["question"])
# Use more queries for robust measurement
test_queries = queries[:200]
print(f"🧪 Testing with {len(test_queries)} queries")
# Binary search for 90% recall complexity (without recompute for speed)
target_recall = 0.9
min_complexity, max_complexity = 1, 32
print(f"🔍 Binary search for {target_recall * 100}% recall complexity...")
print(f"Search range: {min_complexity} to {max_complexity}")
best_complexity = None
best_recall = 0.0
while min_complexity <= max_complexity:
mid_complexity = (min_complexity + max_complexity) // 2
print(
f"\n🧪 Testing complexity {mid_complexity} (no recompute, non-compact index)..."
)
# Use recompute_embeddings=False on non-compact index for fast binary search
recall = binary_search_evaluator.evaluate_recall_at_3(
test_queries, mid_complexity, recompute_embeddings=False
)
print(
f" Complexity {mid_complexity}: Recall@3 = {recall:.3f} ({recall * 100:.1f}%)"
)
if recall >= target_recall:
best_complexity = mid_complexity
best_recall = recall
max_complexity = mid_complexity - 1
print(" ✅ Target reached! Searching for lower complexity...")
else:
min_complexity = mid_complexity + 1
print(" ❌ Below target. Searching for higher complexity...")
if best_complexity is not None:
print("\n🎯 Optimal complexity found!")
print(f" Complexity: {best_complexity}")
print(f" Recall@3: {best_recall:.3f} ({best_recall * 100:.1f}%)")
# Test a few complexities around the optimal one for verification
print("\n🔬 Verification test around optimal complexity:")
verification_complexities = [
max(1, best_complexity - 2),
max(1, best_complexity - 1),
best_complexity,
best_complexity + 1,
best_complexity + 2,
]
for complexity in verification_complexities:
if complexity <= 512: # reasonable upper bound
recall = binary_search_evaluator.evaluate_recall_at_3(
test_queries, complexity, recompute_embeddings=False
)
status = "" if recall >= target_recall else ""
print(f" {status} Complexity {complexity:3d}: {recall * 100:5.1f}%")
# Now test the optimal complexity with compact index and recompute for comparison
print(
f"\n🔄 Testing optimal complexity {best_complexity} on compact index WITH recompute..."
)
compact_evaluator = RecallEvaluator(args.index, args.baseline_dir)
recall_with_recompute = compact_evaluator.evaluate_recall_at_3(
test_queries[:10], best_complexity, recompute_embeddings=True
)
print(
f" ✅ Complexity {best_complexity} (compact index with recompute): {recall_with_recompute * 100:.1f}%"
)
complexity = best_complexity
print(
f" 📊 Recall difference: {abs(best_recall - recall_with_recompute) * 100:.2f}%"
)
compact_evaluator.cleanup()
else:
print(f"\n❌ Could not find complexity achieving {target_recall * 100}% recall")
print("All tested complexities were below target.")
# Cleanup evaluators (keep non-compact index for Stage 4)
binary_search_evaluator.cleanup()
evaluator.cleanup()
print("✅ Stage 3 completed! Non-compact index saved for Stage 4.\n")
if args.stage == "4" or args.stage == "all":
# Stage 4: Comprehensive evaluation with dual index comparison
print("🚀 Starting Stage 4: Comprehensive evaluation with dual index comparison")
# Use FinanceBench evaluator for QA evaluation
evaluator = FinanceBenchEvaluator(
args.index, args.openai_api_key if args.llm_backend == "openai" else None
)
print("📖 Loading FinanceBench dataset...")
data = evaluator.load_dataset(args.dataset)
# Step 1: Analyze current (compact) index
print("\n📏 Analyzing current index (compact, pruned)...")
compact_size_metrics = evaluator.analyze_index_sizes()
compact_size_metrics["index_type"] = "compact"
# Step 2: Use existing non-compact index or create if needed
from pathlib import Path
if Path(non_compact_index_path).exists():
print(
f"\n📁 Using existing non-compact index from Stage 3: {non_compact_index_path}"
)
temp_evaluator = FinanceBenchEvaluator(non_compact_index_path)
non_compact_size_metrics = temp_evaluator.analyze_index_sizes()
non_compact_size_metrics["index_type"] = "non_compact"
else:
print("\n🏗️ Creating non-compact index (with embeddings) for comparison...")
non_compact_size_metrics = evaluator.create_non_compact_index_for_comparison(
non_compact_index_path
)
# Step 3: Compare index sizes
print("\n📊 Index size comparison:")
print(
f" Compact index (current): {compact_size_metrics['total_with_embeddings']:.1f} MB"
)
print(
f" Non-compact index: {non_compact_size_metrics['total_with_embeddings']:.1f} MB"
)
print("\n📊 Index-only size comparison (.index file only):")
print(f" Compact index: {compact_size_metrics['index_only_mb']:.1f} MB")
print(f" Non-compact index: {non_compact_size_metrics['index_only_mb']:.1f} MB")
# Use index-only size for fair comparison (same as Enron emails)
storage_saving = (
(non_compact_size_metrics["index_only_mb"] - compact_size_metrics["index_only_mb"])
/ non_compact_size_metrics["index_only_mb"]
* 100
)
print(f" Storage saving by compact: {storage_saving:.1f}%")
# Step 4: Performance comparison between the two indexes
if complexity is None:
raise ValueError("Complexity is required for performance comparison")
print("\n⚡ Performance comparison between indexes...")
performance_metrics = evaluator.compare_index_performance(
non_compact_index_path, args.index, data[:10], complexity=complexity
)
# Step 5: Generation evaluation
test_samples = 20
print(f"\n🧪 Testing with first {test_samples} samples for generation analysis")
if args.llm_backend == "openai" and args.openai_api_key:
print("🔍🤖 Running OpenAI-based generation evaluation...")
evaluation_start = time.time()
timing_metrics = evaluator.evaluate_timing_breakdown(data[:test_samples])
evaluation_time = time.time() - evaluation_start
else:
print(
f"🔍🤖 Running {args.llm_backend} generation evaluation with {args.model_name}..."
)
try:
# Load LLM
if args.llm_backend == "hf":
tokenizer, model = load_hf_model(args.model_name)
def llm_func(prompt):
return generate_hf(tokenizer, model, prompt)
else: # vllm
llm, sampling_params = load_vllm_model(args.model_name)
def llm_func(prompt):
return generate_vllm(llm, sampling_params, prompt)
# Simple generation evaluation
queries = [item["question"] for item in data[:test_samples]]
gen_results = evaluate_rag(
evaluator.searcher,
llm_func,
queries,
domain="finance",
complexity=complexity,
)
timing_metrics = {
"total_questions": len(queries),
"avg_search_time": gen_results["avg_search_time"],
"avg_generation_time": gen_results["avg_generation_time"],
"results": gen_results["results"],
}
evaluation_time = time.time()
except Exception as e:
print(f"❌ Generation evaluation failed: {e}")
timing_metrics = {
"total_questions": 0,
"avg_search_time": 0,
"avg_generation_time": 0,
}
evaluation_time = 0
# Combine all metrics
combined_metrics = {
**timing_metrics,
"total_evaluation_time": evaluation_time,
"current_index": compact_size_metrics,
"non_compact_index": non_compact_size_metrics,
"performance_comparison": performance_metrics,
"storage_saving_percent": storage_saving,
}
# Print results
print("\n📊 Generation Results:")
print(f" Total Questions: {timing_metrics.get('total_questions', 0)}")
print(f" Avg Search Time: {timing_metrics.get('avg_search_time', 0):.3f}s")
print(f" Avg Generation Time: {timing_metrics.get('avg_generation_time', 0):.3f}s")
# Save results if requested
if args.output:
print(f"\n💾 Saving results to {args.output}...")
with open(args.output, "w") as f:
json.dump(combined_metrics, f, indent=2, default=str)
print(f"✅ Results saved to {args.output}")
evaluator.cleanup()
print("✅ Stage 4 completed!\n")
if args.stage == "all":
print("🎉 All evaluation stages completed successfully!")
print("\n📋 Summary:")
print(" Stage 2: ✅ Recall@3 evaluation completed")
print(" Stage 3: ✅ Optimal complexity found")
print(" Stage 4: ✅ Generation accuracy & timing evaluation completed")
print("\n🔧 Recommended next steps:")
print(" - Use optimal complexity for best speed/accuracy balance")
print(" - Review accuracy and timing breakdown for performance optimization")
print(" - Run full evaluation on complete dataset if needed")
# Clean up non-compact index after all stages complete
print("\n🧹 Cleaning up temporary non-compact index...")
from pathlib import Path
if Path(non_compact_index_path).exists():
temp_index_dir = Path(non_compact_index_path).parent
temp_index_name = Path(non_compact_index_path).name
for temp_file in temp_index_dir.glob(f"{temp_index_name}*"):
temp_file.unlink()
print(f"✅ Cleaned up {non_compact_index_path}")
else:
print("📝 No temporary index to clean up")
except KeyboardInterrupt:
print("\n⚠️ Evaluation interrupted by user")
exit(1)
except Exception as e:
print(f"\n❌ Stage {args.stage} failed: {e}")
exit(1)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,462 @@
#!/usr/bin/env python3
"""
FinanceBench Complete Setup Script
Downloads all PDFs and builds full LEANN datastore
"""
import argparse
import os
import re
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from threading import Lock
import pymupdf
import requests
from leann import LeannBuilder, LeannSearcher
from tqdm import tqdm
class FinanceBenchSetup:
def __init__(self, data_dir: str = "data"):
self.base_dir = Path(__file__).parent # benchmarks/financebench/
self.data_dir = self.base_dir / data_dir
self.pdf_dir = self.data_dir / "pdfs"
self.dataset_file = self.data_dir / "financebench_merged.jsonl"
self.index_dir = self.data_dir / "index"
self.download_lock = Lock()
def download_dataset(self):
"""Download the main FinanceBench dataset"""
print("📊 Downloading FinanceBench dataset...")
self.data_dir.mkdir(parents=True, exist_ok=True)
if self.dataset_file.exists():
print(f"✅ Dataset already exists: {self.dataset_file}")
return
url = "https://huggingface.co/datasets/PatronusAI/financebench/raw/main/financebench_merged.jsonl"
response = requests.get(url, stream=True)
response.raise_for_status()
with open(self.dataset_file, "wb") as f:
for chunk in response.iter_content(chunk_size=8192):
f.write(chunk)
print(f"✅ Dataset downloaded: {self.dataset_file}")
def get_pdf_list(self):
"""Get list of all PDF files from GitHub"""
print("📋 Fetching PDF list from GitHub...")
response = requests.get(
"https://api.github.com/repos/patronus-ai/financebench/contents/pdfs"
)
response.raise_for_status()
pdf_files = response.json()
print(f"Found {len(pdf_files)} PDF files")
return pdf_files
def download_single_pdf(self, pdf_info, position):
"""Download a single PDF file"""
pdf_name = pdf_info["name"]
pdf_path = self.pdf_dir / pdf_name
# Skip if already downloaded
if pdf_path.exists() and pdf_path.stat().st_size > 0:
return f"{pdf_name} (cached)"
try:
# Download PDF
response = requests.get(pdf_info["download_url"], timeout=60)
response.raise_for_status()
# Write to file
with self.download_lock:
with open(pdf_path, "wb") as f:
f.write(response.content)
return f"{pdf_name} ({len(response.content) // 1024}KB)"
except Exception as e:
return f"{pdf_name}: {e!s}"
def download_all_pdfs(self, max_workers: int = 5):
"""Download all PDF files with parallel processing"""
self.pdf_dir.mkdir(parents=True, exist_ok=True)
pdf_files = self.get_pdf_list()
print(f"📥 Downloading {len(pdf_files)} PDFs with {max_workers} workers...")
with ThreadPoolExecutor(max_workers=max_workers) as executor:
# Submit all download tasks
future_to_pdf = {
executor.submit(self.download_single_pdf, pdf_info, i): pdf_info["name"]
for i, pdf_info in enumerate(pdf_files)
}
# Process completed downloads with progress bar
with tqdm(total=len(pdf_files), desc="Downloading PDFs") as pbar:
for future in as_completed(future_to_pdf):
result = future.result()
pbar.set_postfix_str(result.split()[-1] if "" in result else "Error")
pbar.update(1)
# Verify downloads
downloaded_pdfs = list(self.pdf_dir.glob("*.pdf"))
print(f"✅ Successfully downloaded {len(downloaded_pdfs)}/{len(pdf_files)} PDFs")
# Show any failures
missing_pdfs = []
for pdf_info in pdf_files:
pdf_path = self.pdf_dir / pdf_info["name"]
if not pdf_path.exists() or pdf_path.stat().st_size == 0:
missing_pdfs.append(pdf_info["name"])
if missing_pdfs:
print(f"⚠️ Failed to download {len(missing_pdfs)} PDFs:")
for pdf in missing_pdfs[:5]: # Show first 5
print(f" - {pdf}")
if len(missing_pdfs) > 5:
print(f" ... and {len(missing_pdfs) - 5} more")
def build_leann_index(
self,
backend: str = "hnsw",
embedding_model: str = "sentence-transformers/all-mpnet-base-v2",
):
"""Build LEANN index from all PDFs"""
print(f"🏗️ Building LEANN index with {backend} backend...")
# Check if we have PDFs
pdf_files = list(self.pdf_dir.glob("*.pdf"))
if not pdf_files:
raise RuntimeError("No PDF files found! Run download first.")
print(f"Found {len(pdf_files)} PDF files to process")
start_time = time.time()
# Initialize builder with standard compact configuration
builder = LeannBuilder(
backend_name=backend,
embedding_model=embedding_model,
embedding_mode="sentence-transformers",
graph_degree=32,
complexity=64,
is_recompute=True, # Enable recompute (no stored embeddings)
is_compact=True, # Enable compact storage (pruned)
num_threads=4,
)
# Process PDFs and extract text
total_chunks = 0
failed_pdfs = []
for pdf_path in tqdm(pdf_files, desc="Processing PDFs"):
try:
chunks = self.extract_pdf_text(pdf_path)
for chunk in chunks:
builder.add_text(chunk["text"], metadata=chunk["metadata"])
total_chunks += 1
except Exception as e:
print(f"❌ Failed to process {pdf_path.name}: {e}")
failed_pdfs.append(pdf_path.name)
continue
# Build index in index directory
self.index_dir.mkdir(parents=True, exist_ok=True)
index_path = self.index_dir / f"financebench_full_{backend}.leann"
print(f"🔨 Building index: {index_path}")
builder.build_index(str(index_path))
build_time = time.time() - start_time
print("✅ Index built successfully!")
print(f" 📁 Index path: {index_path}")
print(f" 📊 Total chunks: {total_chunks:,}")
print(f" 📄 Processed PDFs: {len(pdf_files) - len(failed_pdfs)}/{len(pdf_files)}")
print(f" ⏱️ Build time: {build_time:.1f}s")
if failed_pdfs:
print(f" ⚠️ Failed PDFs: {failed_pdfs}")
return str(index_path)
def build_faiss_flat_baseline(self, index_path: str, output_dir: str = "baseline"):
"""Build FAISS flat baseline using the same embeddings as LEANN index"""
print("🔨 Building FAISS Flat baseline...")
import os
import pickle
import numpy as np
from leann.api import compute_embeddings
from leann_backend_hnsw import faiss
os.makedirs(output_dir, exist_ok=True)
baseline_path = os.path.join(output_dir, "faiss_flat.index")
metadata_path = os.path.join(output_dir, "metadata.pkl")
if os.path.exists(baseline_path) and os.path.exists(metadata_path):
print(f"✅ Baseline already exists at {baseline_path}")
return baseline_path
# Read metadata from the built index
meta_path = f"{index_path}.meta.json"
with open(meta_path) as f:
import json
meta = json.loads(f.read())
embedding_model = meta["embedding_model"]
passage_source = meta["passage_sources"][0]
passage_file = passage_source["path"]
# Convert relative path to absolute
if not os.path.isabs(passage_file):
index_dir = os.path.dirname(index_path)
passage_file = os.path.join(index_dir, os.path.basename(passage_file))
print(f"📊 Loading passages from {passage_file}...")
print(f"🤖 Using embedding model: {embedding_model}")
# Load all passages for baseline
passages = []
passage_ids = []
with open(passage_file, encoding="utf-8") as f:
for line in f:
if line.strip():
data = json.loads(line)
passages.append(data["text"])
passage_ids.append(data["id"])
print(f"📄 Loaded {len(passages)} passages")
# Compute embeddings using the same method as LEANN
print("🧮 Computing embeddings...")
embeddings = compute_embeddings(
passages,
embedding_model,
mode="sentence-transformers",
use_server=False,
)
print(f"📐 Embedding shape: {embeddings.shape}")
# Build FAISS flat index
print("🏗️ Building FAISS IndexFlatIP...")
dimension = embeddings.shape[1]
index = faiss.IndexFlatIP(dimension)
# Add embeddings to flat index
embeddings_f32 = embeddings.astype(np.float32)
index.add(embeddings_f32.shape[0], faiss.swig_ptr(embeddings_f32))
# Save index and metadata
faiss.write_index(index, baseline_path)
with open(metadata_path, "wb") as f:
pickle.dump(passage_ids, f)
print(f"✅ FAISS baseline saved to {baseline_path}")
print(f"✅ Metadata saved to {metadata_path}")
print(f"📊 Total vectors: {index.ntotal}")
return baseline_path
def extract_pdf_text(self, pdf_path: Path) -> list[dict]:
"""Extract and chunk text from a PDF file"""
chunks = []
doc = pymupdf.open(pdf_path)
for page_num in range(len(doc)):
page = doc[page_num]
text = page.get_text() # type: ignore
if not text.strip():
continue
# Create metadata
metadata = {
"source_file": pdf_path.name,
"page_number": page_num + 1,
"document_type": "10K" if "10K" in pdf_path.name else "10Q",
"company": pdf_path.name.split("_")[0],
"doc_period": self.extract_year_from_filename(pdf_path.name),
}
# Use recursive character splitting like LangChain
if len(text.split()) > 500:
# Split by double newlines (paragraphs)
paragraphs = [p.strip() for p in text.split("\n\n") if p.strip()]
current_chunk = ""
for para in paragraphs:
# If adding this paragraph would make chunk too long, save current chunk
if current_chunk and len((current_chunk + " " + para).split()) > 300:
if current_chunk.strip():
chunks.append(
{
"text": current_chunk.strip(),
"metadata": {
**metadata,
"chunk_id": f"page_{page_num + 1}_chunk_{len(chunks)}",
},
}
)
current_chunk = para
else:
current_chunk = (current_chunk + " " + para).strip()
# Add the last chunk
if current_chunk.strip():
chunks.append(
{
"text": current_chunk.strip(),
"metadata": {
**metadata,
"chunk_id": f"page_{page_num + 1}_chunk_{len(chunks)}",
},
}
)
else:
# Page is short enough, use as single chunk
chunks.append(
{
"text": text.strip(),
"metadata": {**metadata, "chunk_id": f"page_{page_num + 1}"},
}
)
doc.close()
return chunks
def extract_year_from_filename(self, filename: str) -> str:
"""Extract year from PDF filename"""
# Try to find 4-digit year in filename
match = re.search(r"(\d{4})", filename)
return match.group(1) if match else "unknown"
def verify_setup(self, index_path: str):
"""Verify the setup by testing a simple query"""
print("🧪 Verifying setup with test query...")
try:
searcher = LeannSearcher(index_path)
# Test query
test_query = "What is the capital expenditure for 3M in 2018?"
results = searcher.search(test_query, top_k=3)
print(f"✅ Test query successful! Found {len(results)} results:")
for i, result in enumerate(results, 1):
company = result.metadata.get("company", "Unknown")
year = result.metadata.get("doc_period", "Unknown")
page = result.metadata.get("page_number", "Unknown")
print(f" {i}. {company} {year} (page {page}) - Score: {result.score:.3f}")
print(f" {result.text[:100]}...")
searcher.cleanup()
print("✅ Setup verification completed successfully!")
except Exception as e:
print(f"❌ Setup verification failed: {e}")
raise
def main():
parser = argparse.ArgumentParser(description="Setup FinanceBench with full PDF datastore")
parser.add_argument("--data-dir", default="data", help="Data directory")
parser.add_argument(
"--backend", choices=["hnsw", "diskann"], default="hnsw", help="LEANN backend"
)
parser.add_argument(
"--embedding-model",
default="sentence-transformers/all-mpnet-base-v2",
help="Embedding model",
)
parser.add_argument("--max-workers", type=int, default=5, help="Parallel download workers")
parser.add_argument("--skip-download", action="store_true", help="Skip PDF download")
parser.add_argument("--skip-build", action="store_true", help="Skip index building")
parser.add_argument(
"--build-baseline-only",
action="store_true",
help="Only build FAISS baseline from existing index",
)
args = parser.parse_args()
print("🏦 FinanceBench Complete Setup")
print("=" * 50)
setup = FinanceBenchSetup(args.data_dir)
try:
if args.build_baseline_only:
# Only build baseline from existing index
index_path = setup.index_dir / f"financebench_full_{args.backend}"
index_file = f"{index_path}.index"
meta_file = f"{index_path}.leann.meta.json"
if not os.path.exists(index_file) or not os.path.exists(meta_file):
print("❌ Index files not found:")
print(f" Index: {index_file}")
print(f" Meta: {meta_file}")
print("💡 Run without --build-baseline-only to build the index first")
exit(1)
print(f"🔨 Building baseline from existing index: {index_path}")
baseline_path = setup.build_faiss_flat_baseline(str(index_path))
print(f"✅ Baseline built at {baseline_path}")
return
# Step 1: Download dataset
setup.download_dataset()
# Step 2: Download PDFs
if not args.skip_download:
setup.download_all_pdfs(max_workers=args.max_workers)
else:
print("⏭️ Skipping PDF download")
# Step 3: Build LEANN index
if not args.skip_build:
index_path = setup.build_leann_index(
backend=args.backend, embedding_model=args.embedding_model
)
# Step 4: Build FAISS flat baseline
print("\n🔨 Building FAISS flat baseline...")
baseline_path = setup.build_faiss_flat_baseline(index_path)
print(f"✅ Baseline built at {baseline_path}")
# Step 5: Verify setup
setup.verify_setup(index_path)
else:
print("⏭️ Skipping index building")
print("\n🎉 FinanceBench setup completed!")
print(f"📁 Data directory: {setup.data_dir.absolute()}")
print("\nNext steps:")
print(
"1. Run evaluation: python evaluate_financebench.py --index data/index/financebench_full_hnsw.leann"
)
print(
"2. Or test manually: python -c \"from leann import LeannSearcher; s = LeannSearcher('data/index/financebench_full_hnsw.leann'); print(s.search('3M capital expenditure 2018'))\""
)
except KeyboardInterrupt:
print("\n⚠️ Setup interrupted by user")
exit(1)
except Exception as e:
print(f"\n❌ Setup failed: {e}")
exit(1)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,214 @@
#!/usr/bin/env python3
# /// script
# requires-python = ">=3.9"
# dependencies = [
# "faiss-cpu",
# "numpy",
# "sentence-transformers",
# "torch",
# "tqdm",
# ]
# ///
"""
Independent recall verification script using standard FAISS.
Creates two indexes (HNSW and Flat) and compares recall@3 at different complexities.
"""
import json
import time
from pathlib import Path
import faiss
import numpy as np
from sentence_transformers import SentenceTransformer
from tqdm import tqdm
def compute_embeddings_direct(chunks: list[str], model_name: str) -> np.ndarray:
"""
Direct embedding computation using sentence-transformers.
Copied logic to avoid dependency issues.
"""
print(f"Loading model: {model_name}")
model = SentenceTransformer(model_name)
print(f"Computing embeddings for {len(chunks)} chunks...")
embeddings = model.encode(
chunks,
show_progress_bar=True,
batch_size=32,
convert_to_numpy=True,
normalize_embeddings=False,
)
return embeddings.astype(np.float32)
def load_financebench_queries(dataset_path: str, max_queries: int = 200) -> list[str]:
"""Load FinanceBench queries from dataset"""
queries = []
with open(dataset_path, encoding="utf-8") as f:
for line in f:
if line.strip():
data = json.loads(line)
queries.append(data["question"])
if len(queries) >= max_queries:
break
return queries
def load_passages_from_leann_index(index_path: str) -> tuple[list[str], list[str]]:
"""Load passages from LEANN index structure"""
meta_path = f"{index_path}.meta.json"
with open(meta_path) as f:
meta = json.load(f)
passage_source = meta["passage_sources"][0]
passage_file = passage_source["path"]
# Convert relative path to absolute
if not Path(passage_file).is_absolute():
index_dir = Path(index_path).parent
passage_file = index_dir / Path(passage_file).name
print(f"Loading passages from {passage_file}")
passages = []
passage_ids = []
with open(passage_file, encoding="utf-8") as f:
for line in tqdm(f, desc="Loading passages"):
if line.strip():
data = json.loads(line)
passages.append(data["text"])
passage_ids.append(data["id"])
print(f"Loaded {len(passages)} passages")
return passages, passage_ids
def build_faiss_indexes(embeddings: np.ndarray) -> tuple[faiss.Index, faiss.Index]:
"""Build FAISS indexes: Flat (ground truth) and HNSW"""
dimension = embeddings.shape[1]
# Build Flat index (ground truth)
print("Building FAISS IndexFlatIP (ground truth)...")
flat_index = faiss.IndexFlatIP(dimension)
flat_index.add(embeddings)
# Build HNSW index
print("Building FAISS IndexHNSWFlat...")
M = 32 # Same as LEANN default
hnsw_index = faiss.IndexHNSWFlat(dimension, M, faiss.METRIC_INNER_PRODUCT)
hnsw_index.hnsw.efConstruction = 200 # Same as LEANN default
hnsw_index.add(embeddings)
print(f"Built indexes with {flat_index.ntotal} vectors, dimension {dimension}")
return flat_index, hnsw_index
def evaluate_recall_at_k(
query_embeddings: np.ndarray,
flat_index: faiss.Index,
hnsw_index: faiss.Index,
passage_ids: list[str],
k: int = 3,
ef_search: int = 64,
) -> float:
"""Evaluate recall@k comparing HNSW vs Flat"""
# Set search parameters for HNSW
hnsw_index.hnsw.efSearch = ef_search
total_recall = 0.0
num_queries = query_embeddings.shape[0]
for i in range(num_queries):
query = query_embeddings[i : i + 1] # Keep 2D shape
# Get ground truth from Flat index (standard FAISS API)
flat_distances, flat_indices = flat_index.search(query, k)
ground_truth_ids = {passage_ids[idx] for idx in flat_indices[0]}
# Get results from HNSW index (standard FAISS API)
hnsw_distances, hnsw_indices = hnsw_index.search(query, k)
hnsw_ids = {passage_ids[idx] for idx in hnsw_indices[0]}
# Calculate recall
intersection = ground_truth_ids.intersection(hnsw_ids)
recall = len(intersection) / k
total_recall += recall
if i < 3: # Show first few examples
print(f" Query {i + 1}: Recall@{k} = {recall:.3f}")
print(f" Flat: {list(ground_truth_ids)}")
print(f" HNSW: {list(hnsw_ids)}")
print(f" Intersection: {list(intersection)}")
avg_recall = total_recall / num_queries
return avg_recall
def main():
# Configuration
dataset_path = "data/financebench_merged.jsonl"
index_path = "data/index/financebench_full_hnsw.leann"
embedding_model = "sentence-transformers/all-mpnet-base-v2"
print("🔍 FAISS Recall Verification")
print("=" * 50)
# Check if files exist
if not Path(dataset_path).exists():
print(f"❌ Dataset not found: {dataset_path}")
return
if not Path(f"{index_path}.meta.json").exists():
print(f"❌ Index metadata not found: {index_path}.meta.json")
return
# Load data
print("📖 Loading FinanceBench queries...")
queries = load_financebench_queries(dataset_path, max_queries=50)
print(f"Loaded {len(queries)} queries")
print("📄 Loading passages from LEANN index...")
passages, passage_ids = load_passages_from_leann_index(index_path)
# Compute embeddings
print("🧮 Computing passage embeddings...")
passage_embeddings = compute_embeddings_direct(passages, embedding_model)
print("🧮 Computing query embeddings...")
query_embeddings = compute_embeddings_direct(queries, embedding_model)
# Build FAISS indexes
print("🏗️ Building FAISS indexes...")
flat_index, hnsw_index = build_faiss_indexes(passage_embeddings)
# Test different efSearch values (equivalent to LEANN complexity)
print("\n📊 Evaluating Recall@3 at different efSearch values...")
ef_search_values = [16, 32, 64, 128, 256]
for ef_search in ef_search_values:
print(f"\n🧪 Testing efSearch = {ef_search}")
start_time = time.time()
recall = evaluate_recall_at_k(
query_embeddings, flat_index, hnsw_index, passage_ids, k=3, ef_search=ef_search
)
elapsed = time.time() - start_time
print(
f"📈 efSearch {ef_search}: Recall@3 = {recall:.3f} ({recall * 100:.1f}%) in {elapsed:.2f}s"
)
print("\n✅ Verification completed!")
print("\n📋 Summary:")
print(" - Built independent FAISS Flat and HNSW indexes")
print(" - Compared recall@3 at different efSearch values")
print(" - Used same embedding model as LEANN")
print(" - This validates LEANN's recall measurements")
if __name__ == "__main__":
main()

1
benchmarks/laion/.gitignore vendored Normal file
View File

@@ -0,0 +1 @@
data/

199
benchmarks/laion/README.md Normal file
View File

@@ -0,0 +1,199 @@
# LAION Multimodal Benchmark
A multimodal benchmark for evaluating image retrieval and generation performance using LEANN with CLIP embeddings and Qwen2.5-VL for multimodal generation on LAION dataset subset.
## Overview
This benchmark evaluates:
- **Image retrieval timing** using caption-based queries
- **Recall@K performance** for image search
- **Complexity analysis** across different search parameters
- **Index size and storage efficiency**
- **Multimodal generation** with Qwen2.5-VL for image understanding and description
## Dataset Configuration
- **Dataset**: LAION-400M subset (10,000 images)
- **Embeddings**: Pre-computed CLIP ViT-B/32 (512 dimensions)
- **Queries**: 200 random captions from the dataset
- **Ground Truth**: Self-recall (query caption → original image)
## Quick Start
### 1. Setup the benchmark
```bash
cd benchmarks/laion
python setup_laion.py --num-samples 10000 --num-queries 200
```
This will:
- Create dummy LAION data (10K samples)
- Generate CLIP embeddings (512-dim)
- Build LEANN index with HNSW backend
- Create 200 evaluation queries
### 2. Run evaluation
```bash
# Run all evaluation stages
python evaluate_laion.py --index data/laion_index.leann
# Run specific stages
python evaluate_laion.py --index data/laion_index.leann --stage 2 # Recall evaluation
python evaluate_laion.py --index data/laion_index.leann --stage 3 # Complexity analysis
python evaluate_laion.py --index data/laion_index.leann --stage 4 # Index comparison
python evaluate_laion.py --index data/laion_index.leann --stage 5 # Multimodal generation
# Multimodal generation with Qwen2.5-VL
python evaluate_laion.py --index data/laion_index.leann --stage 5 --model-name Qwen/Qwen2.5-VL-7B-Instruct
```
### 3. Save results
```bash
python evaluate_laion.py --index data/laion_index.leann --output results.json
```
## Configuration Options
### Setup Options
```bash
python setup_laion.py \
--num-samples 10000 \
--num-queries 200 \
--index-path data/laion_index.leann \
--backend hnsw
```
### Evaluation Options
```bash
python evaluate_laion.py \
--index data/laion_index.leann \
--queries data/evaluation_queries.jsonl \
--complexity 64 \
--top-k 3 \
--num-samples 100 \
--stage all
```
## Evaluation Stages
### Stage 2: Recall Evaluation
- Evaluates Recall@3 for multimodal retrieval
- Compares LEANN vs FAISS baseline performance
- Self-recall: query caption should retrieve original image
### Stage 3: Complexity Analysis
- Binary search for optimal complexity (90% recall target)
- Tests performance across different complexity levels
- Analyzes speed vs. accuracy tradeoffs
### Stage 4: Index Comparison
- Compares compact vs non-compact index sizes
- Measures search performance differences
- Reports storage efficiency and speed ratios
### Stage 5: Multimodal Generation
- Uses Qwen2.5-VL for image understanding and description
- Retrieval-Augmented Generation (RAG) with multimodal context
- Measures both search and generation timing
## Output Metrics
### Timing Metrics
- Average/median/min/max search time
- Standard deviation
- Searches per second
- Latency in milliseconds
### Recall Metrics
- Recall@3 percentage for image retrieval
- Number of queries with ground truth
### Index Metrics
- Total index size (MB)
- Component breakdown (index, passages, metadata)
- Storage savings (compact vs non-compact)
- Backend and embedding model info
### Generation Metrics (Stage 5)
- Average search time per query
- Average generation time per query
- Time distribution (search vs generation)
- Sample multimodal responses
- Model: Qwen2.5-VL performance
## Benchmark Results
### LEANN-RAG Performance (CLIP ViT-L/14 + Qwen2.5-VL)
**Stage 3: Optimal Complexity Analysis**
- **Optimal Complexity**: 85 (achieving 90% Recall@3)
- **Binary Search Range**: 1-128
- **Target Recall**: 90%
- **Index Type**: Non-compact (for fast binary search)
**Stage 5: Multimodal Generation Performance (Qwen2.5-VL)**
- **Total Queries**: 20
- **Average Search Time**: 1.200s per query
- **Average Generation Time**: 6.558s per query
- **Time Distribution**: Search 15.5%, Generation 84.5%
- **LLM Backend**: HuggingFace transformers
- **Model**: Qwen/Qwen2.5-VL-7B-Instruct
- **Optimal Complexity**: 85
**System Performance:**
- **Index Size**: ~10,000 image embeddings from LAION subset
- **Embedding Model**: CLIP ViT-L/14 (768 dimensions)
- **Backend**: HNSW with cosine distance
### Example Results
```
🎯 LAION MULTIMODAL BENCHMARK RESULTS
============================================================
📊 Multimodal Generation Results:
Total Queries: 20
Avg Search Time: 1.200s
Avg Generation Time: 6.558s
Time Distribution: Search 15.5%, Generation 84.5%
LLM Backend: HuggingFace transformers
Model: Qwen/Qwen2.5-VL-7B-Instruct
⚙️ Optimal Complexity Analysis:
Target Recall: 90%
Optimal Complexity: 85
Binary Search Range: 1-128
Non-compact Index (fast search, no recompute)
🚀 Performance Summary:
Multimodal RAG: 7.758s total per query
Search: 15.5% of total time
Generation: 84.5% of total time
```
## Directory Structure
```
benchmarks/laion/
├── setup_laion.py # Setup script
├── evaluate_laion.py # Evaluation script
├── README.md # This file
└── data/ # Generated data
├── laion_images/ # Image files (placeholder)
├── laion_metadata.jsonl # Image metadata
├── laion_passages.jsonl # LEANN passages
├── laion_embeddings.npy # CLIP embeddings
├── evaluation_queries.jsonl # Evaluation queries
└── laion_index.leann/ # LEANN index files
```
## Notes
- Current implementation uses dummy data for demonstration
- For real LAION data, implement actual download logic in `setup_laion.py`
- CLIP embeddings are randomly generated - replace with real CLIP model for production
- Adjust `num_samples` and `num_queries` based on available resources
- Consider using `--num-samples` during evaluation for faster testing

View File

@@ -0,0 +1,725 @@
"""
LAION Multimodal Benchmark Evaluation Script - Modular Recall-based Evaluation
"""
import argparse
import json
import logging
import os
import pickle
import time
from pathlib import Path
import numpy as np
from leann import LeannSearcher
from leann_backend_hnsw import faiss
from sentence_transformers import SentenceTransformer
from ..llm_utils import evaluate_multimodal_rag, load_qwen_vl_model
# Setup logging to reduce verbose output
logging.basicConfig(level=logging.WARNING)
logging.getLogger("leann.api").setLevel(logging.WARNING)
logging.getLogger("leann_backend_hnsw").setLevel(logging.WARNING)
class RecallEvaluator:
"""Stage 2: Evaluate Recall@3 (LEANN vs FAISS baseline for multimodal retrieval)"""
def __init__(self, index_path: str, baseline_dir: str):
self.index_path = index_path
self.baseline_dir = baseline_dir
self.searcher = LeannSearcher(index_path)
# Load FAISS flat baseline (image embeddings)
baseline_index_path = os.path.join(baseline_dir, "faiss_flat.index")
metadata_path = os.path.join(baseline_dir, "metadata.pkl")
self.faiss_index = faiss.read_index(baseline_index_path)
with open(metadata_path, "rb") as f:
self.image_ids = pickle.load(f)
print(f"📚 Loaded FAISS flat baseline with {self.faiss_index.ntotal} image vectors")
# Load sentence-transformers CLIP for text embedding (ViT-L/14)
self.st_clip = SentenceTransformer("clip-ViT-L-14")
def evaluate_recall_at_3(
self, captions: list[str], complexity: int = 64, recompute_embeddings: bool = True
) -> float:
"""Evaluate recall@3 for multimodal retrieval: caption queries -> image results"""
recompute_str = "with recompute" if recompute_embeddings else "no recompute"
print(f"🔍 Evaluating recall@3 with complexity={complexity} ({recompute_str})...")
total_recall = 0.0
num_queries = len(captions)
for i, caption in enumerate(captions):
# Get ground truth: search with FAISS flat using caption text embedding
# Generate CLIP text embedding for caption via sentence-transformers (normalized)
query_embedding = self.st_clip.encode(
[caption], convert_to_numpy=True, normalize_embeddings=True, show_progress_bar=False
).astype(np.float32)
# Search FAISS flat for ground truth using LEANN's modified faiss API
n = query_embedding.shape[0] # Number of queries
k = 3 # Number of nearest neighbors
distances = np.zeros((n, k), dtype=np.float32)
labels = np.zeros((n, k), dtype=np.int64)
self.faiss_index.search(
n,
faiss.swig_ptr(query_embedding),
k,
faiss.swig_ptr(distances),
faiss.swig_ptr(labels),
)
# Extract the results (image IDs from FAISS)
baseline_ids = {self.image_ids[idx] for idx in labels[0]}
# Search with LEANN at specified complexity (using caption as text query)
test_results = self.searcher.search(
caption,
top_k=3,
complexity=complexity,
recompute_embeddings=recompute_embeddings,
)
test_ids = {result.id for result in test_results}
# Calculate recall@3 = |intersection| / |ground_truth|
intersection = test_ids.intersection(baseline_ids)
recall = len(intersection) / 3.0 # Ground truth size is 3
total_recall += recall
if i < 3: # Show first few examples
print(f" Query {i + 1}: '{caption[:50]}...' -> Recall@3: {recall:.3f}")
print(f" FAISS ground truth: {list(baseline_ids)}")
print(f" LEANN results (C={complexity}, {recompute_str}): {list(test_ids)}")
print(f" Intersection: {list(intersection)}")
avg_recall = total_recall / num_queries
print(f"📊 Average Recall@3: {avg_recall:.3f} ({avg_recall * 100:.1f}%)")
return avg_recall
def cleanup(self):
"""Cleanup resources"""
if hasattr(self, "searcher"):
self.searcher.cleanup()
class LAIONEvaluator:
def __init__(self, index_path: str):
self.index_path = index_path
self.searcher = LeannSearcher(index_path)
def load_queries(self, queries_file: str) -> list[str]:
"""Load caption queries from evaluation file"""
captions = []
with open(queries_file, encoding="utf-8") as f:
for line in f:
if line.strip():
query_data = json.loads(line)
captions.append(query_data["query"])
print(f"📊 Loaded {len(captions)} caption queries")
return captions
def analyze_index_sizes(self) -> dict:
"""Analyze index sizes, emphasizing .index only (exclude passages)."""
print("📏 Analyzing index sizes (.index only)...")
# Get all index-related files
index_path = Path(self.index_path)
index_dir = index_path.parent
index_name = index_path.stem # Remove .leann extension
sizes: dict[str, float] = {}
# Core index files
index_file = index_dir / f"{index_name}.index"
meta_file = index_dir / f"{index_path.name}.meta.json" # Keep .leann for meta file
passages_file = index_dir / f"{index_path.name}.passages.jsonl" # Keep .leann for passages
passages_idx_file = index_dir / f"{index_path.name}.passages.idx" # Keep .leann for idx
# Core index size (.index only)
index_mb = index_file.stat().st_size / (1024 * 1024) if index_file.exists() else 0.0
sizes["index_only_mb"] = index_mb
# Other files for reference (not counted in index_only_mb)
sizes["metadata_mb"] = (
meta_file.stat().st_size / (1024 * 1024) if meta_file.exists() else 0.0
)
sizes["passages_text_mb"] = (
passages_file.stat().st_size / (1024 * 1024) if passages_file.exists() else 0.0
)
sizes["passages_index_mb"] = (
passages_idx_file.stat().st_size / (1024 * 1024) if passages_idx_file.exists() else 0.0
)
print(f" 📁 .index size: {index_mb:.1f} MB")
if sizes["metadata_mb"]:
print(f" 🧾 metadata: {sizes['metadata_mb']:.3f} MB")
if sizes["passages_text_mb"] or sizes["passages_index_mb"]:
print(
f" (passages excluded) text: {sizes['passages_text_mb']:.1f} MB, idx: {sizes['passages_index_mb']:.1f} MB"
)
return sizes
def create_non_compact_index_for_comparison(self, non_compact_index_path: str) -> dict:
"""Create a non-compact index for comparison purposes"""
print("🏗️ Building non-compact index from existing passages...")
# Load existing passages from current index
from leann import LeannBuilder
current_index_path = Path(self.index_path)
current_index_dir = current_index_path.parent
current_index_name = current_index_path.name
# Read metadata to get passage source
meta_path = current_index_dir / f"{current_index_name}.meta.json"
with open(meta_path) as f:
meta = json.load(f)
passage_source = meta["passage_sources"][0]
passage_file = passage_source["path"]
# Convert relative path to absolute
if not Path(passage_file).is_absolute():
passage_file = current_index_dir / Path(passage_file).name
print(f"📄 Loading passages from {passage_file}...")
# Load CLIP embeddings
embeddings_file = current_index_dir / "clip_image_embeddings.npy"
embeddings = np.load(embeddings_file)
print(f"📐 Loaded embeddings shape: {embeddings.shape}")
# Build non-compact index with same passages and embeddings
builder = LeannBuilder(
backend_name="hnsw",
# Use CLIP text encoder (ViT-L/14) to match image embeddings (768-dim)
embedding_model="clip-ViT-L-14",
embedding_mode="sentence-transformers",
is_recompute=False, # Disable recompute (store embeddings)
is_compact=False, # Disable compact storage
distance_metric="cosine",
**{
k: v
for k, v in meta.get("backend_kwargs", {}).items()
if k not in ["is_recompute", "is_compact", "distance_metric"]
},
)
# Prepare ids and add passages
ids: list[str] = []
with open(passage_file, encoding="utf-8") as f:
for line in f:
if line.strip():
data = json.loads(line)
ids.append(str(data["id"]))
# Ensure metadata contains the id used by the vector index
metadata = {**data.get("metadata", {}), "id": data["id"]}
builder.add_text(text=data["text"], metadata=metadata)
if len(ids) != embeddings.shape[0]:
raise ValueError(
f"IDs count ({len(ids)}) does not match embeddings ({embeddings.shape[0]})."
)
# Persist a pickle for build_index_from_embeddings
pkl_path = current_index_dir / "clip_image_embeddings.pkl"
with open(pkl_path, "wb") as pf:
pickle.dump((ids, embeddings.astype(np.float32)), pf)
print(
f"🔨 Building non-compact index at {non_compact_index_path} from precomputed embeddings..."
)
builder.build_index_from_embeddings(non_compact_index_path, str(pkl_path))
# Analyze the non-compact index size
temp_evaluator = LAIONEvaluator(non_compact_index_path)
non_compact_sizes = temp_evaluator.analyze_index_sizes()
non_compact_sizes["index_type"] = "non_compact"
return non_compact_sizes
def compare_index_performance(
self, non_compact_path: str, compact_path: str, test_captions: list, complexity: int
) -> dict:
"""Compare performance between non-compact and compact indexes"""
print("⚡ Comparing search performance between indexes...")
# Test queries
test_queries = test_captions[:5]
results = {
"non_compact": {"search_times": []},
"compact": {"search_times": []},
"avg_search_times": {},
"speed_ratio": 0.0,
}
# Test non-compact index (no recompute)
print(" 🔍 Testing non-compact index (no recompute)...")
non_compact_searcher = LeannSearcher(non_compact_path)
for caption in test_queries:
start_time = time.time()
_ = non_compact_searcher.search(
caption, top_k=3, complexity=complexity, recompute_embeddings=False
)
search_time = time.time() - start_time
results["non_compact"]["search_times"].append(search_time)
# Test compact index (with recompute)
print(" 🔍 Testing compact index (with recompute)...")
compact_searcher = LeannSearcher(compact_path)
for caption in test_queries:
start_time = time.time()
_ = compact_searcher.search(
caption, top_k=3, complexity=complexity, recompute_embeddings=True
)
search_time = time.time() - start_time
results["compact"]["search_times"].append(search_time)
# Calculate averages
results["avg_search_times"]["non_compact"] = sum(
results["non_compact"]["search_times"]
) / len(results["non_compact"]["search_times"])
results["avg_search_times"]["compact"] = sum(results["compact"]["search_times"]) / len(
results["compact"]["search_times"]
)
# Performance ratio
if results["avg_search_times"]["compact"] > 0:
results["speed_ratio"] = (
results["avg_search_times"]["non_compact"] / results["avg_search_times"]["compact"]
)
else:
results["speed_ratio"] = float("inf")
print(
f" Non-compact (no recompute): {results['avg_search_times']['non_compact']:.3f}s avg"
)
print(f" Compact (with recompute): {results['avg_search_times']['compact']:.3f}s avg")
print(f" Speed ratio: {results['speed_ratio']:.2f}x")
# Cleanup
non_compact_searcher.cleanup()
compact_searcher.cleanup()
return results
def _print_results(self, timing_metrics: dict):
"""Print evaluation results"""
print("\n🎯 LAION MULTIMODAL BENCHMARK RESULTS")
print("=" * 60)
# Index comparison analysis (prefer .index-only view if present)
if "current_index" in timing_metrics and "non_compact_index" in timing_metrics:
current = timing_metrics["current_index"]
non_compact = timing_metrics["non_compact_index"]
if "index_only_mb" in current and "index_only_mb" in non_compact:
print("\n📏 Index Comparison Analysis (.index only):")
print(f" Compact index (current): {current.get('index_only_mb', 0):.1f} MB")
print(f" Non-compact index: {non_compact.get('index_only_mb', 0):.1f} MB")
print(
f" Storage saving by compact: {timing_metrics.get('storage_saving_percent', 0):.1f}%"
)
# Show excluded components for reference if available
if any(
k in non_compact
for k in ("passages_text_mb", "passages_index_mb", "metadata_mb")
):
print(" (passages excluded in totals, shown for reference):")
print(
f" - Passages text: {non_compact.get('passages_text_mb', 0):.1f} MB, "
f"Passages index: {non_compact.get('passages_index_mb', 0):.1f} MB, "
f"Metadata: {non_compact.get('metadata_mb', 0):.3f} MB"
)
else:
# Fallback to legacy totals if running with older metrics
print("\n📏 Index Comparison Analysis:")
print(
f" Compact index (current): {current.get('total_with_embeddings', 0):.1f} MB"
)
print(
f" Non-compact index (with embeddings): {non_compact.get('total_with_embeddings', 0):.1f} MB"
)
print(
f" Storage saving by compact: {timing_metrics.get('storage_saving_percent', 0):.1f}%"
)
print(" Component breakdown (non-compact):")
print(f" - Main index: {non_compact.get('index', 0):.1f} MB")
print(f" - Passages text: {non_compact.get('passages_text', 0):.1f} MB")
print(f" - Passages index: {non_compact.get('passages_index', 0):.1f} MB")
print(f" - Metadata: {non_compact.get('metadata', 0):.1f} MB")
# Performance comparison
if "performance_comparison" in timing_metrics:
perf = timing_metrics["performance_comparison"]
print("\n⚡ Performance Comparison:")
print(
f" Non-compact (no recompute): {perf.get('avg_search_times', {}).get('non_compact', 0):.3f}s avg"
)
print(
f" Compact (with recompute): {perf.get('avg_search_times', {}).get('compact', 0):.3f}s avg"
)
print(f" Speed ratio: {perf.get('speed_ratio', 0):.2f}x")
# Legacy single index analysis (fallback)
if "total_with_embeddings" in timing_metrics and "current_index" not in timing_metrics:
print("\n📏 Index Size Analysis:")
print(
f" Index with embeddings: {timing_metrics.get('total_with_embeddings', 0):.1f} MB"
)
print(
f" Estimated pruned index: {timing_metrics.get('total_without_embeddings', 0):.1f} MB"
)
print(f" Compression ratio: {timing_metrics.get('compression_ratio', 0):.2f}x")
def cleanup(self):
"""Cleanup resources"""
if self.searcher:
self.searcher.cleanup()
def main():
parser = argparse.ArgumentParser(description="LAION Multimodal Benchmark Evaluation")
parser.add_argument("--index", required=True, help="Path to LEANN index")
parser.add_argument(
"--queries", default="data/evaluation_queries.jsonl", help="Path to evaluation queries"
)
parser.add_argument(
"--stage",
choices=["2", "3", "4", "5", "all"],
default="all",
help="Which stage to run (2=recall, 3=complexity, 4=index comparison, 5=generation)",
)
parser.add_argument("--complexity", type=int, default=None, help="Complexity for search")
parser.add_argument("--baseline-dir", default="baseline", help="Baseline output directory")
parser.add_argument("--output", help="Save results to JSON file")
parser.add_argument(
"--llm-backend",
choices=["hf"],
default="hf",
help="LLM backend (Qwen2.5-VL only supports HF)",
)
parser.add_argument(
"--model-name", default="Qwen/Qwen2.5-VL-7B-Instruct", help="Multimodal model name"
)
args = parser.parse_args()
try:
# Check if baseline exists
baseline_index_path = os.path.join(args.baseline_dir, "faiss_flat.index")
if not os.path.exists(baseline_index_path):
print(f"❌ FAISS baseline not found at {baseline_index_path}")
print("💡 Please run setup_laion.py first to build the baseline")
exit(1)
if args.stage == "2" or args.stage == "all":
# Stage 2: Recall@3 evaluation
print("🚀 Starting Stage 2: Recall@3 evaluation for multimodal retrieval")
evaluator = RecallEvaluator(args.index, args.baseline_dir)
# Load caption queries for testing
laion_evaluator = LAIONEvaluator(args.index)
captions = laion_evaluator.load_queries(args.queries)
# Test with queries for robust measurement
test_captions = captions[:100] # Use subset for speed
print(f"🧪 Testing with {len(test_captions)} caption queries")
# Test with complexity 64
complexity = 64
recall = evaluator.evaluate_recall_at_3(test_captions, complexity)
print(f"📈 Recall@3 at complexity {complexity}: {recall * 100:.1f}%")
evaluator.cleanup()
print("✅ Stage 2 completed!\n")
# Shared non-compact index path for Stage 3 and 4
non_compact_index_path = args.index.replace(".leann", "_noncompact.leann")
complexity = args.complexity
if args.stage == "3" or args.stage == "all":
# Stage 3: Binary search for 90% recall complexity
print("🚀 Starting Stage 3: Binary search for 90% recall complexity")
print(
"💡 Creating non-compact index for fast binary search with recompute_embeddings=False"
)
# Create non-compact index for binary search
print("🏗️ Creating non-compact index for binary search...")
evaluator = LAIONEvaluator(args.index)
evaluator.create_non_compact_index_for_comparison(non_compact_index_path)
# Use non-compact index for binary search
binary_search_evaluator = RecallEvaluator(non_compact_index_path, args.baseline_dir)
# Load caption queries for testing
captions = evaluator.load_queries(args.queries)
# Use subset for robust measurement
test_captions = captions[:50] # Smaller subset for binary search speed
print(f"🧪 Testing with {len(test_captions)} caption queries")
# Binary search for 90% recall complexity
target_recall = 0.9
min_complexity, max_complexity = 1, 128
print(f"🔍 Binary search for {target_recall * 100}% recall complexity...")
print(f"Search range: {min_complexity} to {max_complexity}")
best_complexity = None
best_recall = 0.0
while min_complexity <= max_complexity:
mid_complexity = (min_complexity + max_complexity) // 2
print(
f"\n🧪 Testing complexity {mid_complexity} (no recompute, non-compact index)..."
)
# Use recompute_embeddings=False on non-compact index for fast binary search
recall = binary_search_evaluator.evaluate_recall_at_3(
test_captions, mid_complexity, recompute_embeddings=False
)
print(
f" Complexity {mid_complexity}: Recall@3 = {recall:.3f} ({recall * 100:.1f}%)"
)
if recall >= target_recall:
best_complexity = mid_complexity
best_recall = recall
max_complexity = mid_complexity - 1
print(" ✅ Target reached! Searching for lower complexity...")
else:
min_complexity = mid_complexity + 1
print(" ❌ Below target. Searching for higher complexity...")
if best_complexity is not None:
print("\n🎯 Optimal complexity found!")
print(f" Complexity: {best_complexity}")
print(f" Recall@3: {best_recall:.3f} ({best_recall * 100:.1f}%)")
# Test a few complexities around the optimal one for verification
print("\n🔬 Verification test around optimal complexity:")
verification_complexities = [
max(1, best_complexity - 2),
max(1, best_complexity - 1),
best_complexity,
best_complexity + 1,
best_complexity + 2,
]
for complexity in verification_complexities:
if complexity <= 512: # reasonable upper bound
recall = binary_search_evaluator.evaluate_recall_at_3(
test_captions, complexity, recompute_embeddings=False
)
status = "" if recall >= target_recall else ""
print(f" {status} Complexity {complexity:3d}: {recall * 100:5.1f}%")
# Now test the optimal complexity with compact index and recompute for comparison
print(
f"\n🔄 Testing optimal complexity {best_complexity} on compact index WITH recompute..."
)
compact_evaluator = RecallEvaluator(args.index, args.baseline_dir)
recall_with_recompute = compact_evaluator.evaluate_recall_at_3(
test_captions[:10], best_complexity, recompute_embeddings=True
)
print(
f" ✅ Complexity {best_complexity} (compact index with recompute): {recall_with_recompute * 100:.1f}%"
)
complexity = best_complexity
print(
f" 📊 Recall difference: {abs(best_recall - recall_with_recompute) * 100:.2f}%"
)
compact_evaluator.cleanup()
else:
print(f"\n❌ Could not find complexity achieving {target_recall * 100}% recall")
print("All tested complexities were below target.")
# Cleanup evaluators (keep non-compact index for Stage 4)
binary_search_evaluator.cleanup()
evaluator.cleanup()
print("✅ Stage 3 completed! Non-compact index saved for Stage 4.\n")
if args.stage == "4" or args.stage == "all":
# Stage 4: Index comparison (without LLM generation)
print("🚀 Starting Stage 4: Index comparison analysis")
# Use LAION evaluator for index comparison
evaluator = LAIONEvaluator(args.index)
# Load caption queries
captions = evaluator.load_queries(args.queries)
# Step 1: Analyze current (compact) index
print("\n📏 Analyzing current index (compact, pruned)...")
compact_size_metrics = evaluator.analyze_index_sizes()
compact_size_metrics["index_type"] = "compact"
# Step 2: Use existing non-compact index or create if needed
if Path(non_compact_index_path).exists():
print(
f"\n📁 Using existing non-compact index from Stage 3: {non_compact_index_path}"
)
temp_evaluator = LAIONEvaluator(non_compact_index_path)
non_compact_size_metrics = temp_evaluator.analyze_index_sizes()
non_compact_size_metrics["index_type"] = "non_compact"
else:
print("\n🏗️ Creating non-compact index (with embeddings) for comparison...")
non_compact_size_metrics = evaluator.create_non_compact_index_for_comparison(
non_compact_index_path
)
# Step 3: Compare index sizes (.index only)
print("\n📊 Index size comparison (.index only):")
print(
f" Compact index (current): {compact_size_metrics.get('index_only_mb', 0):.1f} MB"
)
print(f" Non-compact index: {non_compact_size_metrics.get('index_only_mb', 0):.1f} MB")
storage_saving = 0.0
if non_compact_size_metrics.get("index_only_mb", 0) > 0:
storage_saving = (
(
non_compact_size_metrics.get("index_only_mb", 0)
- compact_size_metrics.get("index_only_mb", 0)
)
/ non_compact_size_metrics.get("index_only_mb", 1)
* 100
)
print(f" Storage saving by compact: {storage_saving:.1f}%")
# Step 4: Performance comparison between the two indexes
if complexity is None:
raise ValueError("Complexity is required for index comparison")
print("\n⚡ Performance comparison between indexes...")
performance_metrics = evaluator.compare_index_performance(
non_compact_index_path, args.index, captions[:10], complexity=complexity
)
# Combine all metrics
combined_metrics = {
"current_index": compact_size_metrics,
"non_compact_index": non_compact_size_metrics,
"performance_comparison": performance_metrics,
"storage_saving_percent": storage_saving,
}
# Print comprehensive results
evaluator._print_results(combined_metrics)
# Save results if requested
if args.output:
print(f"\n💾 Saving results to {args.output}...")
with open(args.output, "w") as f:
json.dump(combined_metrics, f, indent=2, default=str)
print(f"✅ Results saved to {args.output}")
evaluator.cleanup()
print("✅ Stage 4 completed!\n")
if args.stage in ("5", "all"):
print("🚀 Starting Stage 5: Multimodal generation with Qwen2.5-VL")
evaluator = LAIONEvaluator(args.index)
captions = evaluator.load_queries(args.queries)
test_captions = captions[: min(20, len(captions))] # Use subset for generation
print(f"🧪 Testing multimodal generation with {len(test_captions)} queries")
# Load Qwen2.5-VL model
try:
print("Loading Qwen2.5-VL model...")
processor, model = load_qwen_vl_model(args.model_name)
# Run multimodal generation evaluation
complexity = args.complexity or 64
gen_results = evaluate_multimodal_rag(
evaluator.searcher,
test_captions,
processor=processor,
model=model,
complexity=complexity,
)
print("\n📊 Multimodal Generation Results:")
print(f" Total Queries: {len(test_captions)}")
print(f" Avg Search Time: {gen_results['avg_search_time']:.3f}s")
print(f" Avg Generation Time: {gen_results['avg_generation_time']:.3f}s")
total_time = gen_results["avg_search_time"] + gen_results["avg_generation_time"]
search_pct = (gen_results["avg_search_time"] / total_time) * 100
gen_pct = (gen_results["avg_generation_time"] / total_time) * 100
print(f" Time Distribution: Search {search_pct:.1f}%, Generation {gen_pct:.1f}%")
print(" LLM Backend: HuggingFace transformers")
print(f" Model: {args.model_name}")
# Show sample results
print("\n📝 Sample Multimodal Generations:")
for i, response in enumerate(gen_results["results"][:3]):
# Handle both string and dict formats for captions
if isinstance(test_captions[i], dict):
caption_text = test_captions[i].get("query", str(test_captions[i]))
else:
caption_text = str(test_captions[i])
print(f" Query {i + 1}: {caption_text[:60]}...")
print(f" Response {i + 1}: {response[:100]}...")
print()
except Exception as e:
print(f"❌ Multimodal generation evaluation failed: {e}")
print("💡 Make sure transformers and Qwen2.5-VL are installed")
import traceback
traceback.print_exc()
evaluator.cleanup()
print("✅ Stage 5 completed!\n")
if args.stage == "all":
print("🎉 All evaluation stages completed successfully!")
print("\n📋 Summary:")
print(" Stage 2: ✅ Multimodal Recall@3 evaluation completed")
print(" Stage 3: ✅ Optimal complexity found")
print(" Stage 4: ✅ Index comparison analysis completed")
print(" Stage 5: ✅ Multimodal generation evaluation completed")
print("\n🔧 Recommended next steps:")
print(" - Use optimal complexity for best speed/accuracy balance")
print(" - Review index comparison for storage vs performance tradeoffs")
# Clean up non-compact index after all stages complete
print("\n🧹 Cleaning up temporary non-compact index...")
if Path(non_compact_index_path).exists():
temp_index_dir = Path(non_compact_index_path).parent
temp_index_name = Path(non_compact_index_path).name
for temp_file in temp_index_dir.glob(f"{temp_index_name}*"):
temp_file.unlink()
print(f"✅ Cleaned up {non_compact_index_path}")
else:
print("📝 No temporary index to clean up")
except KeyboardInterrupt:
print("\n⚠️ Evaluation interrupted by user")
exit(1)
except Exception as e:
print(f"\n❌ Stage {args.stage} failed: {e}")
import traceback
traceback.print_exc()
exit(1)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,576 @@
"""
LAION Multimodal Benchmark Setup Script
Downloads LAION subset and builds LEANN index with sentence embeddings
"""
import argparse
import asyncio
import io
import json
import os
import pickle
import time
from pathlib import Path
import aiohttp
import numpy as np
from datasets import load_dataset
from leann import LeannBuilder
from PIL import Image
from sentence_transformers import SentenceTransformer
from tqdm import tqdm
class LAIONSetup:
def __init__(self, data_dir: str = "data"):
self.data_dir = Path(data_dir)
self.images_dir = self.data_dir / "laion_images"
self.metadata_file = self.data_dir / "laion_metadata.jsonl"
# Create directories
self.data_dir.mkdir(exist_ok=True)
self.images_dir.mkdir(exist_ok=True)
async def download_single_image(self, session, sample_data, semaphore, progress_bar):
"""Download a single image asynchronously"""
async with semaphore: # Limit concurrent downloads
try:
image_url = sample_data["url"]
image_path = sample_data["image_path"]
# Skip if already exists
if os.path.exists(image_path):
progress_bar.update(1)
return sample_data
async with session.get(image_url, timeout=10) as response:
if response.status == 200:
content = await response.read()
# Verify it's a valid image
try:
img = Image.open(io.BytesIO(content))
img = img.convert("RGB")
img.save(image_path, "JPEG")
progress_bar.update(1)
return sample_data
except Exception:
progress_bar.update(1)
return None # Skip invalid images
else:
progress_bar.update(1)
return None
except Exception:
progress_bar.update(1)
return None
def download_laion_subset(self, num_samples: int = 1000):
"""Download LAION subset from HuggingFace datasets with async parallel downloading"""
print(f"📥 Downloading LAION subset ({num_samples} samples)...")
# Load LAION-400M subset from HuggingFace
print("🤗 Loading from HuggingFace datasets...")
dataset = load_dataset("laion/laion400m", split="train", streaming=True)
# Collect sample metadata first (fast)
print("📋 Collecting sample metadata...")
candidates = []
for sample in dataset:
if len(candidates) >= num_samples * 3: # Get 3x more candidates in case some fail
break
image_url = sample.get("url", "")
caption = sample.get("caption", "")
if not image_url or not caption:
continue
image_filename = f"laion_{len(candidates):06d}.jpg"
image_path = self.images_dir / image_filename
candidate = {
"id": f"laion_{len(candidates):06d}",
"url": image_url,
"caption": caption,
"image_path": str(image_path),
"width": sample.get("original_width", 512),
"height": sample.get("original_height", 512),
"similarity": sample.get("similarity", 0.0),
}
candidates.append(candidate)
print(
f"📊 Collected {len(candidates)} candidates, downloading {num_samples} in parallel..."
)
# Download images in parallel
async def download_batch():
semaphore = asyncio.Semaphore(20) # Limit to 20 concurrent downloads
connector = aiohttp.TCPConnector(limit=100, limit_per_host=20)
timeout = aiohttp.ClientTimeout(total=30)
progress_bar = tqdm(total=len(candidates[: num_samples * 2]), desc="Downloading images")
async with aiohttp.ClientSession(connector=connector, timeout=timeout) as session:
tasks = []
for candidate in candidates[: num_samples * 2]: # Try 2x more than needed
task = self.download_single_image(session, candidate, semaphore, progress_bar)
tasks.append(task)
# Wait for all downloads
results = await asyncio.gather(*tasks, return_exceptions=True)
progress_bar.close()
# Filter successful downloads
successful = [r for r in results if r is not None and not isinstance(r, Exception)]
return successful[:num_samples]
# Run async download
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
samples = loop.run_until_complete(download_batch())
finally:
loop.close()
# Save metadata
with open(self.metadata_file, "w", encoding="utf-8") as f:
for sample in samples:
f.write(json.dumps(sample) + "\n")
print(f"✅ Downloaded {len(samples)} real LAION samples with async parallel downloading")
return samples
def generate_clip_image_embeddings(self, samples: list[dict]):
"""Generate CLIP image embeddings for downloaded images"""
print("🔍 Generating CLIP image embeddings...")
# Load sentence-transformers CLIP (ViT-L/14, 768-dim) for image embeddings
# This single model can encode both images and text.
model = SentenceTransformer("clip-ViT-L-14")
embeddings = []
valid_samples = []
for sample in tqdm(samples, desc="Processing images"):
try:
# Load image
image_path = sample["image_path"]
image = Image.open(image_path).convert("RGB")
# Encode image to 768-dim embedding via sentence-transformers (normalized)
vec = model.encode(
[image],
convert_to_numpy=True,
normalize_embeddings=True,
batch_size=1,
show_progress_bar=False,
)[0]
embeddings.append(vec.astype(np.float32))
valid_samples.append(sample)
except Exception as e:
print(f" ⚠️ Failed to process {sample['id']}: {e}")
# Skip invalid images
embeddings = np.array(embeddings, dtype=np.float32)
# Save embeddings
embeddings_file = self.data_dir / "clip_image_embeddings.npy"
np.save(embeddings_file, embeddings)
print(f"✅ Generated {len(embeddings)} image embeddings, shape: {embeddings.shape}")
return embeddings, valid_samples
def build_faiss_baseline(
self, embeddings: np.ndarray, samples: list[dict], output_dir: str = "baseline"
):
"""Build FAISS flat baseline using CLIP image embeddings"""
print("🔨 Building FAISS Flat baseline...")
from leann_backend_hnsw import faiss
os.makedirs(output_dir, exist_ok=True)
baseline_path = os.path.join(output_dir, "faiss_flat.index")
metadata_path = os.path.join(output_dir, "metadata.pkl")
if os.path.exists(baseline_path) and os.path.exists(metadata_path):
print(f"✅ Baseline already exists at {baseline_path}")
return baseline_path
# Extract image IDs (must be present)
if not samples or "id" not in samples[0]:
raise KeyError("samples missing 'id' field for FAISS baseline")
image_ids: list[str] = [str(sample["id"]) for sample in samples]
print(f"📐 Embedding shape: {embeddings.shape}")
print(f"📄 Processing {len(image_ids)} images")
# Build FAISS flat index
print("🏗️ Building FAISS IndexFlatIP...")
dimension = embeddings.shape[1]
index = faiss.IndexFlatIP(dimension)
# Add embeddings to flat index
embeddings_f32 = embeddings.astype(np.float32)
index.add(embeddings_f32.shape[0], faiss.swig_ptr(embeddings_f32))
# Save index and metadata
faiss.write_index(index, baseline_path)
with open(metadata_path, "wb") as f:
pickle.dump(image_ids, f)
print(f"✅ FAISS baseline saved to {baseline_path}")
print(f"✅ Metadata saved to {metadata_path}")
print(f"📊 Total vectors: {index.ntotal}")
return baseline_path
def create_leann_passages(self, samples: list[dict]):
"""Create LEANN-compatible passages from LAION data"""
print("📝 Creating LEANN passages...")
passages_file = self.data_dir / "laion_passages.jsonl"
with open(passages_file, "w", encoding="utf-8") as f:
for i, sample in enumerate(samples):
passage = {
"id": sample["id"],
"text": sample["caption"], # Use caption as searchable text
"metadata": {
"image_url": sample["url"],
"image_path": sample.get("image_path", ""),
"width": sample["width"],
"height": sample["height"],
"similarity": sample["similarity"],
"image_index": i, # Index for embedding lookup
},
}
f.write(json.dumps(passage) + "\n")
print(f"✅ Created {len(samples)} passages")
return passages_file
def build_compact_index(
self, passages_file: Path, embeddings: np.ndarray, index_path: str, backend: str = "hnsw"
):
"""Build compact LEANN index with CLIP embeddings (recompute=True, compact=True)"""
print(f"🏗️ Building compact LEANN index with {backend} backend...")
start_time = time.time()
# Save CLIP embeddings (npy) and also a pickle with (ids, embeddings)
npy_path = self.data_dir / "clip_image_embeddings.npy"
np.save(npy_path, embeddings)
print(f"💾 Saved CLIP embeddings to {npy_path}")
# Prepare ids in the same order as passages_file (matches embeddings order)
ids: list[str] = []
with open(passages_file, encoding="utf-8") as f:
for line in f:
if line.strip():
rec = json.loads(line)
ids.append(str(rec["id"]))
if len(ids) != embeddings.shape[0]:
raise ValueError(
f"IDs count ({len(ids)}) does not match embeddings ({embeddings.shape[0]})."
)
pkl_path = self.data_dir / "clip_image_embeddings.pkl"
with open(pkl_path, "wb") as pf:
pickle.dump((ids, embeddings.astype(np.float32)), pf)
print(f"💾 Saved (ids, embeddings) pickle to {pkl_path}")
# Initialize builder - compact with recompute
# Note: For multimodal case, we need to handle embeddings differently
# Let's try using sentence-transformers mode but with custom embeddings
builder = LeannBuilder(
backend_name=backend,
# Use CLIP text encoder (ViT-L/14) to match image space (768-dim)
embedding_model="clip-ViT-L-14",
embedding_mode="sentence-transformers",
# HNSW params (or forwarded to chosen backend)
graph_degree=32,
complexity=64,
# Compact/pruned with recompute at query time
is_recompute=True,
is_compact=True,
distance_metric="cosine", # CLIP uses normalized vectors; cosine is appropriate
num_threads=4,
)
# Add passages (text + metadata)
print("📚 Adding passages...")
self._add_passages_with_embeddings(builder, passages_file, embeddings)
print(f"🔨 Building compact index at {index_path} from precomputed embeddings...")
builder.build_index_from_embeddings(index_path, str(pkl_path))
build_time = time.time() - start_time
print(f"✅ Compact index built in {build_time:.2f}s")
# Analyze index size
self._analyze_index_size(index_path)
return index_path
def build_non_compact_index(
self, passages_file: Path, embeddings: np.ndarray, index_path: str, backend: str = "hnsw"
):
"""Build non-compact LEANN index with CLIP embeddings (recompute=False, compact=False)"""
print(f"🏗️ Building non-compact LEANN index with {backend} backend...")
start_time = time.time()
# Ensure embeddings are saved (npy + pickle)
npy_path = self.data_dir / "clip_image_embeddings.npy"
if not npy_path.exists():
np.save(npy_path, embeddings)
print(f"💾 Saved CLIP embeddings to {npy_path}")
# Prepare ids in same order as passages_file
ids: list[str] = []
with open(passages_file, encoding="utf-8") as f:
for line in f:
if line.strip():
rec = json.loads(line)
ids.append(str(rec["id"]))
if len(ids) != embeddings.shape[0]:
raise ValueError(
f"IDs count ({len(ids)}) does not match embeddings ({embeddings.shape[0]})."
)
pkl_path = self.data_dir / "clip_image_embeddings.pkl"
if not pkl_path.exists():
with open(pkl_path, "wb") as pf:
pickle.dump((ids, embeddings.astype(np.float32)), pf)
print(f"💾 Saved (ids, embeddings) pickle to {pkl_path}")
# Initialize builder - non-compact without recompute
builder = LeannBuilder(
backend_name=backend,
embedding_model="clip-ViT-L-14",
embedding_mode="sentence-transformers",
graph_degree=32,
complexity=64,
is_recompute=False, # Store embeddings (no recompute needed)
is_compact=False, # Store full index (not pruned)
distance_metric="cosine",
num_threads=4,
)
# Add passages - embeddings will be loaded from file
print("📚 Adding passages...")
self._add_passages_with_embeddings(builder, passages_file, embeddings)
print(f"🔨 Building non-compact index at {index_path} from precomputed embeddings...")
builder.build_index_from_embeddings(index_path, str(pkl_path))
build_time = time.time() - start_time
print(f"✅ Non-compact index built in {build_time:.2f}s")
# Analyze index size
self._analyze_index_size(index_path)
return index_path
def _add_passages_with_embeddings(self, builder, passages_file: Path, embeddings: np.ndarray):
"""Helper to add passages with pre-computed CLIP embeddings"""
with open(passages_file, encoding="utf-8") as f:
for line in tqdm(f, desc="Adding passages"):
if line.strip():
passage = json.loads(line)
# Add image metadata - LEANN will handle embeddings separately
# Note: We store image metadata and caption text for searchability
# Important: ensure passage ID in metadata matches vector ID
builder.add_text(
text=passage["text"], # Image caption for searchability
metadata={**passage["metadata"], "id": passage["id"]},
)
def _analyze_index_size(self, index_path: str):
"""Analyze index file sizes"""
print("📏 Analyzing index sizes...")
index_path = Path(index_path)
index_dir = index_path.parent
index_name = index_path.name # e.g., laion_index.leann
index_prefix = index_path.stem # e.g., laion_index
files = [
(f"{index_prefix}.index", ".index", "core"),
(f"{index_name}.meta.json", ".meta.json", "core"),
(f"{index_name}.ids.txt", ".ids.txt", "core"),
(f"{index_name}.passages.jsonl", ".passages.jsonl", "passages"),
(f"{index_name}.passages.idx", ".passages.idx", "passages"),
]
def _fmt_size(bytes_val: int) -> str:
if bytes_val < 1024:
return f"{bytes_val} B"
kb = bytes_val / 1024
if kb < 1024:
return f"{kb:.1f} KB"
mb = kb / 1024
if mb < 1024:
return f"{mb:.2f} MB"
gb = mb / 1024
return f"{gb:.2f} GB"
total_index_only_mb = 0.0
total_all_mb = 0.0
for filename, label, group in files:
file_path = index_dir / filename
if file_path.exists():
size_bytes = file_path.stat().st_size
print(f" {label}: {_fmt_size(size_bytes)}")
size_mb = size_bytes / (1024 * 1024)
total_all_mb += size_mb
if group == "core":
total_index_only_mb += size_mb
else:
print(f" {label}: (missing)")
print(f" Total (index only, exclude passages): {total_index_only_mb:.2f} MB")
print(f" Total (including passages): {total_all_mb:.2f} MB")
def create_evaluation_queries(self, samples: list[dict], num_queries: int = 200):
"""Create evaluation queries from captions"""
print(f"📝 Creating {num_queries} evaluation queries...")
# Sample random captions as queries
import random
random.seed(42) # For reproducibility
query_samples = random.sample(samples, min(num_queries, len(samples)))
queries_file = self.data_dir / "evaluation_queries.jsonl"
with open(queries_file, "w", encoding="utf-8") as f:
for sample in query_samples:
query = {
"id": sample["id"],
"query": sample["caption"],
"ground_truth_id": sample["id"], # For potential recall evaluation
}
f.write(json.dumps(query) + "\n")
print(f"✅ Created {len(query_samples)} evaluation queries")
return queries_file
def main():
parser = argparse.ArgumentParser(description="Setup LAION Multimodal Benchmark")
parser.add_argument("--data-dir", default="data", help="Data directory")
parser.add_argument("--num-samples", type=int, default=1000, help="Number of LAION samples")
parser.add_argument("--num-queries", type=int, default=50, help="Number of evaluation queries")
parser.add_argument("--index-path", default="data/laion_index.leann", help="Output index path")
parser.add_argument(
"--backend", default="hnsw", choices=["hnsw", "diskann"], help="LEANN backend"
)
parser.add_argument("--skip-download", action="store_true", help="Skip LAION dataset download")
parser.add_argument("--skip-build", action="store_true", help="Skip index building")
args = parser.parse_args()
print("🚀 Setting up LAION Multimodal Benchmark")
print("=" * 50)
try:
# Initialize setup
setup = LAIONSetup(args.data_dir)
# Step 1: Download LAION subset
if not args.skip_download:
print("\n📦 Step 1: Download LAION subset")
samples = setup.download_laion_subset(args.num_samples)
# Step 2: Generate CLIP image embeddings
print("\n🔍 Step 2: Generate CLIP image embeddings")
embeddings, valid_samples = setup.generate_clip_image_embeddings(samples)
# Step 3: Create LEANN passages (image metadata with embeddings)
print("\n📝 Step 3: Create LEANN passages")
passages_file = setup.create_leann_passages(valid_samples)
else:
print("⏭️ Skipping LAION dataset download")
# Load existing data
passages_file = setup.data_dir / "laion_passages.jsonl"
embeddings_file = setup.data_dir / "clip_image_embeddings.npy"
if not passages_file.exists() or not embeddings_file.exists():
raise FileNotFoundError(
"Passages or embeddings file not found. Run without --skip-download first."
)
embeddings = np.load(embeddings_file)
print(f"📊 Loaded {len(embeddings)} embeddings from {embeddings_file}")
# Step 4: Build LEANN indexes (both compact and non-compact)
if not args.skip_build:
print("\n🏗️ Step 4: Build LEANN indexes with CLIP image embeddings")
# Build compact index (production mode - small, recompute required)
compact_index_path = args.index_path
print(f"Building compact index: {compact_index_path}")
setup.build_compact_index(passages_file, embeddings, compact_index_path, args.backend)
# Build non-compact index (comparison mode - large, fast search)
non_compact_index_path = args.index_path.replace(".leann", "_noncompact.leann")
print(f"Building non-compact index: {non_compact_index_path}")
setup.build_non_compact_index(
passages_file, embeddings, non_compact_index_path, args.backend
)
# Step 5: Build FAISS flat baseline
print("\n🔨 Step 5: Build FAISS flat baseline")
if not args.skip_download:
baseline_path = setup.build_faiss_baseline(embeddings, valid_samples)
else:
# Load valid_samples from passages file for FAISS baseline
valid_samples = []
with open(passages_file, encoding="utf-8") as f:
for line in f:
if line.strip():
passage = json.loads(line)
valid_samples.append({"id": passage["id"], "caption": passage["text"]})
baseline_path = setup.build_faiss_baseline(embeddings, valid_samples)
# Step 6: Create evaluation queries
print("\n📝 Step 6: Create evaluation queries")
queries_file = setup.create_evaluation_queries(valid_samples, args.num_queries)
else:
print("⏭️ Skipping index building")
baseline_path = "data/baseline/faiss_index.bin"
queries_file = setup.data_dir / "evaluation_queries.jsonl"
print("\n🎉 Setup completed successfully!")
print("📊 Summary:")
if not args.skip_download:
print(f" Downloaded samples: {len(samples)}")
print(f" Valid samples with embeddings: {len(valid_samples)}")
else:
print(f" Loaded {len(embeddings)} embeddings")
if not args.skip_build:
print(f" Compact index: {compact_index_path}")
print(f" Non-compact index: {non_compact_index_path}")
print(f" FAISS baseline: {baseline_path}")
print(f" Queries: {queries_file}")
print("\n🔧 Next steps:")
print(f" Run evaluation: python evaluate_laion.py --index {compact_index_path}")
print(f" Or compare with: python evaluate_laion.py --index {non_compact_index_path}")
else:
print(" Skipped building indexes")
except KeyboardInterrupt:
print("\n⚠️ Setup interrupted by user")
exit(1)
except Exception as e:
print(f"\n❌ Setup failed: {e}")
exit(1)
if __name__ == "__main__":
main()

301
benchmarks/llm_utils.py Normal file
View File

@@ -0,0 +1,301 @@
"""
LLM utils for RAG benchmarks with Qwen3-8B and Qwen2.5-VL (multimodal)
"""
import time
try:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
HF_AVAILABLE = True
except ImportError:
HF_AVAILABLE = False
try:
from vllm import LLM, SamplingParams
VLLM_AVAILABLE = True
except ImportError:
VLLM_AVAILABLE = False
def is_qwen3_model(model_name):
"""Check if model is Qwen3"""
return "Qwen3" in model_name or "qwen3" in model_name.lower()
def is_qwen_vl_model(model_name):
"""Check if model is Qwen2.5-VL"""
return "Qwen2.5-VL" in model_name or "qwen2.5-vl" in model_name.lower()
def apply_qwen3_chat_template(tokenizer, prompt):
"""Apply Qwen3 chat template with thinking enabled"""
messages = [{"role": "user", "content": prompt}]
return tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
enable_thinking=True,
)
def extract_thinking_answer(response):
"""Extract final answer from Qwen3 thinking model response"""
if "<think>" in response and "</think>" in response:
try:
think_end = response.index("</think>") + len("</think>")
final_answer = response[think_end:].strip()
return final_answer
except (ValueError, IndexError):
pass
return response.strip()
def load_hf_model(model_name="Qwen/Qwen3-8B"):
"""Load HuggingFace model"""
if not HF_AVAILABLE:
raise ImportError("transformers not available")
print(f"Loading HF: {model_name}")
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
device_map="auto",
trust_remote_code=True,
)
return tokenizer, model
def load_vllm_model(model_name="Qwen/Qwen3-8B"):
"""Load vLLM model"""
if not VLLM_AVAILABLE:
raise ImportError("vllm not available")
print(f"Loading vLLM: {model_name}")
llm = LLM(model=model_name, trust_remote_code=True)
# Qwen3 specific config
if is_qwen3_model(model_name):
stop_tokens = ["<|im_end|>", "<|end_of_text|>"]
max_tokens = 2048
else:
stop_tokens = None
max_tokens = 1024
sampling_params = SamplingParams(temperature=0.7, max_tokens=max_tokens, stop=stop_tokens)
return llm, sampling_params
def generate_hf(tokenizer, model, prompt, max_tokens=None):
"""Generate with HF - supports Qwen3 thinking models"""
model_name = getattr(model, "name_or_path", "unknown")
is_qwen3 = is_qwen3_model(model_name)
# Apply chat template for Qwen3
if is_qwen3:
prompt = apply_qwen3_chat_template(tokenizer, prompt)
max_tokens = max_tokens or 2048
else:
max_tokens = max_tokens or 1024
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
with torch.no_grad():
outputs = model.generate(
**inputs,
max_new_tokens=max_tokens,
temperature=0.7,
do_sample=True,
pad_token_id=tokenizer.eos_token_id,
)
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
response = response[len(prompt) :].strip()
# Extract final answer for thinking models
if is_qwen3:
return extract_thinking_answer(response)
return response
def generate_vllm(llm, sampling_params, prompt):
"""Generate with vLLM - supports Qwen3 thinking models"""
outputs = llm.generate([prompt], sampling_params)
response = outputs[0].outputs[0].text.strip()
# Extract final answer for Qwen3 thinking models
model_name = str(llm.llm_engine.model_config.model)
if is_qwen3_model(model_name):
return extract_thinking_answer(response)
return response
def create_prompt(context, query, domain="default"):
"""Create RAG prompt"""
if domain == "emails":
return f"Email content:\n{context}\n\nQuestion: {query}\n\nAnswer:"
elif domain == "finance":
return f"Financial content:\n{context}\n\nQuestion: {query}\n\nAnswer:"
elif domain == "multimodal":
return f"Image context:\n{context}\n\nQuestion: {query}\n\nAnswer:"
else:
return f"Context: {context}\n\nQuestion: {query}\n\nAnswer:"
def evaluate_rag(searcher, llm_func, queries, domain="default", top_k=3, complexity=64):
"""Simple RAG evaluation with timing"""
search_times = []
gen_times = []
results = []
for i, query in enumerate(queries):
# Search
start = time.time()
docs = searcher.search(query, top_k=top_k, complexity=complexity)
search_time = time.time() - start
# Generate
context = "\n\n".join([doc.text for doc in docs])
prompt = create_prompt(context, query, domain)
start = time.time()
response = llm_func(prompt)
gen_time = time.time() - start
search_times.append(search_time)
gen_times.append(gen_time)
results.append(response)
if i < 3:
print(f"Q{i + 1}: Search={search_time:.3f}s, Gen={gen_time:.3f}s")
return {
"avg_search_time": sum(search_times) / len(search_times),
"avg_generation_time": sum(gen_times) / len(gen_times),
"results": results,
}
def load_qwen_vl_model(model_name="Qwen/Qwen2.5-VL-7B-Instruct"):
"""Load Qwen2.5-VL multimodal model"""
if not HF_AVAILABLE:
raise ImportError("transformers not available")
print(f"Loading Qwen2.5-VL: {model_name}")
try:
from transformers import AutoModelForVision2Seq, AutoProcessor
processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForVision2Seq.from_pretrained(
model_name, torch_dtype=torch.bfloat16, device_map="auto", trust_remote_code=True
)
return processor, model
except Exception as e:
print(f"Failed to load with AutoModelForVision2Seq, trying specific class: {e}")
# Fallback to specific class
try:
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
model = Qwen2VLForConditionalGeneration.from_pretrained(
model_name, torch_dtype=torch.bfloat16, device_map="auto", trust_remote_code=True
)
return processor, model
except Exception as e2:
raise ImportError(f"Failed to load Qwen2.5-VL model: {e2}")
def generate_qwen_vl(processor, model, prompt, image_path=None, max_tokens=512):
"""Generate with Qwen2.5-VL multimodal model"""
from PIL import Image
# Prepare inputs
if image_path:
image = Image.open(image_path)
inputs = processor(text=prompt, images=image, return_tensors="pt").to(model.device)
else:
inputs = processor(text=prompt, return_tensors="pt").to(model.device)
# Generate
with torch.no_grad():
generated_ids = model.generate(
**inputs, max_new_tokens=max_tokens, do_sample=False, temperature=0.1
)
# Decode response
generated_ids = generated_ids[:, inputs["input_ids"].shape[1] :]
response = processor.decode(generated_ids[0], skip_special_tokens=True)
return response
def create_multimodal_prompt(context, query, image_descriptions, task_type="images"):
"""Create prompt for multimodal RAG"""
if task_type == "images":
return f"""Based on the retrieved images and their descriptions, answer the following question.
Retrieved Image Descriptions:
{context}
Question: {query}
Provide a detailed answer based on the visual content described above."""
return f"Context: {context}\nQuestion: {query}\nAnswer:"
def evaluate_multimodal_rag(searcher, queries, processor=None, model=None, complexity=64):
"""Evaluate multimodal RAG with Qwen2.5-VL"""
search_times = []
gen_times = []
results = []
for i, query_item in enumerate(queries):
# Handle both string and dict formats for queries
if isinstance(query_item, dict):
query = query_item.get("query", "")
image_path = query_item.get("image_path") # Optional reference image
else:
query = str(query_item)
image_path = None
# Search
start_time = time.time()
search_results = searcher.search(query, top_k=3, complexity=complexity)
search_time = time.time() - start_time
search_times.append(search_time)
# Prepare context from search results
context_parts = []
for result in search_results:
context_parts.append(f"- {result.text}")
context = "\n".join(context_parts)
# Generate with multimodal model
start_time = time.time()
if processor and model:
prompt = create_multimodal_prompt(context, query, context_parts)
response = generate_qwen_vl(processor, model, prompt, image_path)
else:
response = f"Context: {context}"
gen_time = time.time() - start_time
gen_times.append(gen_time)
results.append(response)
if i < 3:
print(f"Q{i + 1}: Search={search_time:.3f}s, Gen={gen_time:.3f}s")
return {
"avg_search_time": sum(search_times) / len(search_times),
"avg_generation_time": sum(gen_times) / len(gen_times),
"results": results,
}

View File

@@ -53,7 +53,7 @@ def download_data_if_needed(data_root: Path, download_embeddings: bool = False):
print(
"Error: huggingface_hub is not installed. Please install it to download the data:"
)
print("uv pip install -e '.[dev]'")
print("uv sync --only-group dev")
sys.exit(1)
except Exception as e:
print(f"An error occurred during data download: {e}")

View File

@@ -53,9 +53,9 @@ We use pre-commit hooks to ensure code quality and consistency. This runs automa
### Setup Pre-commit
1. **Install pre-commit** (already included when you run `uv sync`):
1. **Install pre-commit tools**:
```bash
uv pip install pre-commit
uv sync lint
```
2. **Install the git hooks**:
@@ -65,7 +65,7 @@ We use pre-commit hooks to ensure code quality and consistency. This runs automa
3. **Run pre-commit manually** (optional):
```bash
pre-commit run --all-files
uv run pre-commit run --all-files
```
### Pre-commit Checks
@@ -85,6 +85,9 @@ Our pre-commit configuration includes:
### Running Tests
```bash
# Install test tools only (no project runtime)
uv sync --group test
# Run all tests
uv run pytest

View File

@@ -83,6 +83,81 @@ ollama pull nomic-embed-text
</details>
## Local & Remote Inference Endpoints
> Applies to both LLMs (`leann ask`) and embeddings (`leann build`).
LEANN now treats Ollama, LM Studio, and other OpenAI-compatible runtimes as first-class providers. You can point LEANN at any compatible endpoint either on the same machine or across the network with a couple of flags or environment variables.
### One-Time Environment Setup
```bash
# Works for OpenAI-compatible runtimes such as LM Studio, vLLM, SGLang, llamafile, etc.
export OPENAI_API_KEY="your-key" # or leave unset for local servers that do not check keys
export OPENAI_BASE_URL="http://localhost:1234/v1"
# Ollama-compatible runtimes (Ollama, Ollama on another host, llamacpp-server, etc.)
export LEANN_OLLAMA_HOST="http://localhost:11434" # falls back to OLLAMA_HOST or LOCAL_LLM_ENDPOINT
```
LEANN also recognises `LEANN_LOCAL_LLM_HOST` (highest priority), `LEANN_OPENAI_BASE_URL`, and `LOCAL_OPENAI_BASE_URL`, so existing scripts continue to work.
### Passing Hosts Per Command
```bash
# Build an index with a remote embedding server
leann build my-notes \
--docs ./notes \
--embedding-mode openai \
--embedding-model text-embedding-qwen3-embedding-0.6b \
--embedding-api-base http://192.168.1.50:1234/v1 \
--embedding-api-key local-dev-key
# Query using a local LM Studio instance via OpenAI-compatible API
leann ask my-notes \
--llm openai \
--llm-model qwen3-8b \
--api-base http://localhost:1234/v1 \
--api-key local-dev-key
# Query an Ollama instance running on another box
leann ask my-notes \
--llm ollama \
--llm-model qwen3:14b \
--host http://192.168.1.101:11434
```
⚠️ **Make sure the endpoint is reachable**: when your inference server runs on a home/workstation and the index/search job runs in the cloud, the server must be able to reach the host you configured. Typical options include:
- Expose a public IP (and open the relevant port) on the machine that hosts LM Studio/Ollama.
- Configure router or cloud provider port forwarding.
- Tunnel traffic through tools like `tailscale`, `cloudflared`, or `ssh -R`.
When you set these options while building an index, LEANN stores them in `meta.json`. Any subsequent `leann ask` or searcher process automatically reuses the same provider settings even when we spawn background embedding servers. This makes the “server without GPU talking to my local workstation” workflow from [issue #80](https://github.com/yichuan-w/LEANN/issues/80#issuecomment-2287230548) work out-of-the-box.
**Tip:** If your runtime does not require an API key (many local stacks dont), leave `--api-key` unset. LEANN will skip injecting credentials.
### Python API Usage
You can pass the same configuration from Python:
```python
from leann.api import LeannBuilder
builder = LeannBuilder(
backend_name="hnsw",
embedding_mode="openai",
embedding_model="text-embedding-qwen3-embedding-0.6b",
embedding_options={
"base_url": "http://192.168.1.50:1234/v1",
"api_key": "local-dev-key",
},
)
builder.build_index("./indexes/my-notes", chunks)
```
`embedding_options` is persisted to the index `meta.json`, so subsequent `LeannSearcher` or `LeannChat` sessions automatically reuse the same provider settings (the embedding server manager forwards them to the provider for you).
## Index Selection: Matching Your Scale
### HNSW (Hierarchical Navigable Small World)

0
examples/__init__.py Normal file
View File

View File

@@ -1,380 +0,0 @@
"""
Dynamic add example for LEANN using HNSW backend without recompute.
- Builds a base index from a directory of documents
- Incrementally adds new documents without recomputing stored embeddings
Defaults:
- Base data: /Users/yichuan/Desktop/code/LEANN/leann/data
- Incremental data: /Users/yichuan/Desktop/code/LEANN/leann/test_add
- Index path: <index_dir>/documents.leann
Usage examples:
uv run python examples/dynamic_add_leann_no_recompute.py --build-base \
--base-dir /Users/yichuan/Desktop/code/LEANN/leann/data \
--index-dir ./test_doc_files
uv run python examples/dynamic_add_leann_no_recompute.py --add-incremental \
--add-dir /Users/yichuan/Desktop/code/LEANN/leann/test_add \
--index-dir ./test_doc_files
Quick recompute test (both true):
# Recompute build
uv run python examples/dynamic_add_leann_no_recompute.py --build-base \
--recompute-build --ef-construction 200 \
--base-dir /Users/yichuan/Desktop/code/LEANN/leann/data \
--index-dir ./test_doc_files --index-name documents.leann
# Recompute add
uv run python examples/dynamic_add_leann_no_recompute.py --add-incremental \
--recompute-add --ef-construction 32 \
--add-dir /Users/yichuan/Desktop/code/LEANN/leann/test_add \
--index-dir ./test_doc_files --index-name documents.leann
"""
import argparse
import json
import pickle
import sys
from pathlib import Path
from typing import Any, Optional
# Ensure we can import from the local packages and apps folders
ROOT = Path(__file__).resolve().parents[1]
CORE_SRC = ROOT / "packages" / "leann-core" / "src"
HNSW_PKG_DIR = ROOT / "packages" / "leann-backend-hnsw"
APPS_DIR = ROOT / "apps"
# Prefer the installed backend if available (it contains the compiled extension)
def _prefer_installed(pkg_name: str) -> bool:
try:
import importlib
import importlib.util
spec = importlib.util.find_spec(pkg_name)
if spec and spec.origin and "site-packages" in spec.origin:
# ensure the faiss shim/extension is importable from the installed package
importlib.import_module(f"{pkg_name}.faiss")
return True
except Exception:
pass
return False
# Prepend paths, but only add the repo backend if the installed one is not present
paths_to_prepend = [CORE_SRC, APPS_DIR]
if not _prefer_installed("leann_backend_hnsw"):
paths_to_prepend.insert(1, HNSW_PKG_DIR)
for p in paths_to_prepend:
p_str = str(p)
if p_str not in sys.path:
sys.path.insert(0, p_str)
# Defer non-stdlib imports until after sys.path setup within functions (avoid E402)
def _load_documents(data_dir: str, required_exts: Optional[list[str]] = None) -> list[Any]:
from llama_index.core import SimpleDirectoryReader # type: ignore
reader_kwargs: dict[str, Any] = {"recursive": True, "encoding": "utf-8"}
if required_exts:
reader_kwargs["required_exts"] = required_exts
documents = SimpleDirectoryReader(data_dir, **reader_kwargs).load_data(show_progress=True)
return documents
def _ensure_index_dir(index_dir: Path) -> None:
index_dir.mkdir(parents=True, exist_ok=True)
def _index_files(index_path: Path) -> tuple[Path, Path, Path]:
"""Return (passages.jsonl, passages.idx, index.index) paths for a given index base path.
Note: HNSWBackend writes the FAISS index using the stem (without .leann),
i.e., for base 'documents.leann' the file is 'documents.index'. We prefer the
existing file among candidates.
"""
passages_file = index_path.parent / f"{index_path.name}.passages.jsonl"
offsets_file = index_path.parent / f"{index_path.name}.passages.idx"
candidate_name_index = index_path.parent / f"{index_path.name}.index"
candidate_stem_index = index_path.parent / f"{index_path.stem}.index"
index_file = candidate_stem_index if candidate_stem_index.exists() else candidate_name_index
return passages_file, offsets_file, index_file
def _read_meta(index_path: Path) -> dict[str, Any]:
meta_path = index_path.parent / f"{index_path.name}.meta.json"
if not meta_path.exists():
raise FileNotFoundError(f"Metadata file not found: {meta_path}")
with open(meta_path, encoding="utf-8") as f:
return json.load(f)
def _autodetect_index_base(index_dir: Path) -> Optional[Path]:
"""If exactly one *.leann.meta.json exists, return its base path (without .meta.json)."""
candidates = list(index_dir.glob("*.leann.meta.json"))
if len(candidates) == 1:
meta = candidates[0]
base = meta.with_suffix("") # remove .json
base = base.with_suffix("") # remove .meta
return base
return None
def _load_offset_map(offsets_file: Path) -> dict[str, int]:
if not offsets_file.exists():
return {}
with open(offsets_file, "rb") as f:
return pickle.load(f)
def _next_numeric_id(existing_ids: list[str]) -> int:
numeric_ids = [int(x) for x in existing_ids if x.isdigit()]
if not numeric_ids:
return 0
return max(numeric_ids) + 1
def build_base_index(
base_dir: str,
index_dir: str,
index_name: str,
embedding_model: str,
embedding_mode: str,
chunk_size: int,
chunk_overlap: int,
file_types: Optional[list[str]] = None,
max_items: int = -1,
ef_construction: Optional[int] = None,
recompute_build: bool = False,
) -> str:
print(f"Building base index from: {base_dir}")
documents = _load_documents(base_dir, required_exts=file_types)
if not documents:
raise ValueError(f"No documents found in base_dir: {base_dir}")
from chunking import create_text_chunks
texts = create_text_chunks(
documents,
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
use_ast_chunking=False,
)
if max_items > 0 and len(texts) > max_items:
texts = texts[:max_items]
print(f"Limiting to {max_items} chunks")
index_dir_path = Path(index_dir)
_ensure_index_dir(index_dir_path)
index_path = index_dir_path / index_name
print("Creating HNSW index (non-compact)...")
from leann.api import LeannBuilder
from leann.registry import register_project_directory
builder = LeannBuilder(
backend_name="hnsw",
embedding_model=embedding_model,
embedding_mode=embedding_mode,
is_recompute=recompute_build,
is_compact=False,
efConstruction=(ef_construction if ef_construction is not None else 200),
)
for t in texts:
builder.add_text(t)
builder.build_index(str(index_path))
# Register for discovery
register_project_directory(Path.cwd())
print(f"Base index built at: {index_path}")
return str(index_path)
def add_incremental(
add_dir: str,
index_dir: str,
index_name: Optional[str] = None,
embedding_model: Optional[str] = None,
embedding_mode: Optional[str] = None,
chunk_size: int = 256,
chunk_overlap: int = 128,
file_types: Optional[list[str]] = None,
max_items: int = -1,
ef_construction: Optional[int] = None,
recompute_add: bool = False,
) -> str:
print(f"Adding incremental data from: {add_dir}")
index_dir_path = Path(index_dir)
index_path = index_dir_path / (index_name or "documents.leann")
# If specified base doesn't exist, try to auto-detect an existing base
try:
_read_meta(index_path)
except FileNotFoundError:
auto_base = _autodetect_index_base(index_dir_path)
if auto_base is not None:
print(f"Auto-detected index base: {auto_base.name}")
index_path = auto_base
_read_meta(index_path)
else:
raise FileNotFoundError(
f"No index metadata found for base '{index_path.name}'. Build base first with --build-base "
f"or provide --index-name to match an existing index (e.g., 'test_doc_files.leann')."
)
# Prepare validated context from core (checks backend/no-recompute and resolves embedding defaults)
from leann.api import create_incremental_add_context, incremental_add_texts_with_context
ctx = create_incremental_add_context(
str(index_path),
embedding_model=embedding_model,
embedding_mode=embedding_mode,
data_dir=add_dir,
required_exts=file_types,
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
max_items=max_items,
)
# Use prepared texts from context to perform the add
prepared_texts = ctx.prepared_texts or []
if not prepared_texts:
print("No new chunks to add.")
return str(index_path)
added = incremental_add_texts_with_context(
ctx,
prepared_texts,
ef_construction=ef_construction,
recompute=recompute_add,
)
print(f"Incremental add completed. Added {added} chunks. Index: {index_path}")
return str(index_path)
def main():
parser = argparse.ArgumentParser(
description="Dynamic add to LEANN HNSW index without recompute",
formatter_class=argparse.RawDescriptionHelpFormatter,
)
parser.add_argument("--build-base", action="store_true", help="Build base index")
parser.add_argument("--add-incremental", action="store_true", help="Add incremental data")
parser.add_argument(
"--base-dir",
type=str,
default="/Users/yichuan/Desktop/code/LEANN/leann/data",
help="Base data directory",
)
parser.add_argument(
"--add-dir",
type=str,
default="/Users/yichuan/Desktop/code/LEANN/leann/test_add",
help="Incremental data directory",
)
parser.add_argument(
"--index-dir",
type=str,
default="./test_doc_files",
help="Directory containing the index",
)
parser.add_argument(
"--index-name",
type=str,
default="documents.leann",
help=(
"Index base file name. If you built via document_rag.py, use 'test_doc_files.leann'. "
"Default: documents.leann"
),
)
parser.add_argument(
"--embedding-model",
type=str,
default="facebook/contriever",
help="Embedding model name",
)
parser.add_argument(
"--embedding-mode",
type=str,
default="sentence-transformers",
choices=["sentence-transformers", "openai", "mlx", "ollama"],
help="Embedding backend mode",
)
parser.add_argument("--chunk-size", type=int, default=256)
parser.add_argument("--chunk-overlap", type=int, default=128)
parser.add_argument("--file-types", nargs="+", default=None)
parser.add_argument("--max-items", type=int, default=-1)
parser.add_argument("--ef-construction", type=int, default=32)
parser.add_argument(
"--recompute-add", action="store_true", help="Enable recompute-mode add (non-compact only)"
)
parser.add_argument(
"--recompute-build",
action="store_true",
help="Enable recompute-mode base build (non-compact only)",
)
args = parser.parse_args()
if not args.build_base and not args.add_incremental:
print("Nothing to do. Use --build-base and/or --add-incremental.")
return
index_path_str: Optional[str] = None
if args.build_base:
index_path_str = build_base_index(
base_dir=args.base_dir,
index_dir=args.index_dir,
index_name=args.index_name,
embedding_model=args.embedding_model,
embedding_mode=args.embedding_mode,
chunk_size=args.chunk_size,
chunk_overlap=args.chunk_overlap,
file_types=args.file_types,
max_items=args.max_items,
ef_construction=args.ef_construction,
recompute_build=args.recompute_build,
)
if args.add_incremental:
index_path_str = add_incremental(
add_dir=args.add_dir,
index_dir=args.index_dir,
index_name=args.index_name,
embedding_model=args.embedding_model,
embedding_mode=args.embedding_mode,
chunk_size=args.chunk_size,
chunk_overlap=args.chunk_overlap,
file_types=args.file_types,
max_items=args.max_items,
ef_construction=args.ef_construction,
recompute_add=args.recompute_add,
)
# Optional: quick test query using searcher
if index_path_str:
try:
from leann.api import LeannSearcher
searcher = LeannSearcher(index_path_str)
query = "what is LEANN?"
if args.add_incremental:
query = "what is the multi vector search and how it works?"
results = searcher.search(query, top_k=5)
if results:
print(f"Sample result: {results[0].text[:80]}...")
except Exception:
pass
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,429 @@
"""Dynamic HNSW update demo without compact storage.
This script reproduces the minimal scenario we used while debugging on-the-fly
recompute:
1. Build a non-compact HNSW index from the first few paragraphs of a text file.
2. Print the top results with `recompute_embeddings=True`.
3. Append additional paragraphs with :meth:`LeannBuilder.update_index`.
4. Run the same query again to show the newly inserted passages.
Run it with ``uv`` (optionally pointing LEANN_HNSW_LOG_PATH at a file to inspect
ZMQ activity)::
LEANN_HNSW_LOG_PATH=embedding_fetch.log \
uv run -m examples.dynamic_update_no_recompute \
--index-path .leann/examples/leann-demo.leann
By default the script builds an index from ``data/2501.14312v1 (1).pdf`` and
then updates it with LEANN-related material from ``data/2506.08276v1.pdf``.
It issues the query "What's LEANN?" before and after the update to show how the
new passages become immediately searchable. The script uses the
``sentence-transformers/all-MiniLM-L6-v2`` model with ``is_recompute=True`` so
Faiss pulls existing vectors on demand via the ZMQ embedding server, while
freshly added passages are embedded locally just like the initial build.
To make storage comparisons easy, the script can also build a matching
``is_recompute=False`` baseline (enabled by default) and report the index size
delta after the update. Disable the baseline run with
``--skip-compare-no-recompute`` if you only need the recompute flow.
"""
import argparse
import json
from collections.abc import Iterable
from pathlib import Path
from typing import Any
from leann.api import LeannBuilder, LeannSearcher
from leann.registry import register_project_directory
from apps.chunking import create_text_chunks
REPO_ROOT = Path(__file__).resolve().parents[1]
DEFAULT_QUERY = "What's LEANN?"
DEFAULT_INITIAL_FILES = [
REPO_ROOT / "data" / "2501.14312v1 (1).pdf",
REPO_ROOT / "data" / "huawei_pangu.md",
REPO_ROOT / "data" / "PrideandPrejudice.txt",
]
DEFAULT_UPDATE_FILES = [REPO_ROOT / "data" / "2506.08276v1.pdf"]
def load_chunks_from_files(paths: list[Path]) -> list[str]:
from llama_index.core import SimpleDirectoryReader
documents = []
for path in paths:
p = path.expanduser().resolve()
if not p.exists():
raise FileNotFoundError(f"Input path not found: {p}")
if p.is_dir():
reader = SimpleDirectoryReader(str(p), recursive=False)
documents.extend(reader.load_data(show_progress=True))
else:
reader = SimpleDirectoryReader(input_files=[str(p)])
documents.extend(reader.load_data(show_progress=True))
if not documents:
return []
chunks = create_text_chunks(
documents,
chunk_size=512,
chunk_overlap=128,
use_ast_chunking=False,
)
return [c for c in chunks if isinstance(c, str) and c.strip()]
def run_search(index_path: Path, query: str, top_k: int, *, recompute_embeddings: bool) -> list:
searcher = LeannSearcher(str(index_path))
try:
return searcher.search(
query=query,
top_k=top_k,
recompute_embeddings=recompute_embeddings,
batch_size=16,
)
finally:
searcher.cleanup()
def print_results(title: str, results: Iterable) -> None:
print(f"\n=== {title} ===")
res_list = list(results)
print(f"results count: {len(res_list)}")
print("passages:")
if not res_list:
print(" (no passages returned)")
for res in res_list:
snippet = res.text.replace("\n", " ")[:120]
print(f" - {res.id}: {snippet}... (score={res.score:.4f})")
def build_initial_index(
index_path: Path,
paragraphs: list[str],
model_name: str,
embedding_mode: str,
is_recompute: bool,
) -> None:
builder = LeannBuilder(
backend_name="hnsw",
embedding_model=model_name,
embedding_mode=embedding_mode,
is_compact=False,
is_recompute=is_recompute,
)
for idx, passage in enumerate(paragraphs):
builder.add_text(passage, metadata={"id": str(idx)})
builder.build_index(str(index_path))
def update_index(
index_path: Path,
start_id: int,
paragraphs: list[str],
model_name: str,
embedding_mode: str,
is_recompute: bool,
) -> None:
updater = LeannBuilder(
backend_name="hnsw",
embedding_model=model_name,
embedding_mode=embedding_mode,
is_compact=False,
is_recompute=is_recompute,
)
for offset, passage in enumerate(paragraphs, start=start_id):
updater.add_text(passage, metadata={"id": str(offset)})
updater.update_index(str(index_path))
def ensure_index_dir(index_path: Path) -> None:
index_path.parent.mkdir(parents=True, exist_ok=True)
def cleanup_index_files(index_path: Path) -> None:
"""Remove leftover index artifacts for a clean rebuild."""
parent = index_path.parent
if not parent.exists():
return
stem = index_path.stem
for file in parent.glob(f"{stem}*"):
if file.is_file():
file.unlink()
def index_file_size(index_path: Path) -> int:
"""Return the size of the primary .index file for the given index path."""
index_file = index_path.parent / f"{index_path.stem}.index"
return index_file.stat().st_size if index_file.exists() else 0
def load_metadata_snapshot(index_path: Path) -> dict[str, Any] | None:
meta_path = index_path.parent / f"{index_path.name}.meta.json"
if not meta_path.exists():
return None
try:
return json.loads(meta_path.read_text())
except json.JSONDecodeError:
return None
def run_workflow(
*,
label: str,
index_path: Path,
initial_paragraphs: list[str],
update_paragraphs: list[str],
model_name: str,
embedding_mode: str,
is_recompute: bool,
query: str,
top_k: int,
skip_search: bool,
) -> dict[str, Any]:
prefix = f"[{label}] " if label else ""
ensure_index_dir(index_path)
cleanup_index_files(index_path)
print(f"{prefix}Building initial index...")
build_initial_index(
index_path,
initial_paragraphs,
model_name,
embedding_mode,
is_recompute=is_recompute,
)
initial_size = index_file_size(index_path)
if not skip_search:
before_results = run_search(
index_path,
query,
top_k,
recompute_embeddings=is_recompute,
)
else:
before_results = None
print(f"\n{prefix}Updating index with additional passages...")
update_index(
index_path,
start_id=len(initial_paragraphs),
paragraphs=update_paragraphs,
model_name=model_name,
embedding_mode=embedding_mode,
is_recompute=is_recompute,
)
if not skip_search:
after_results = run_search(
index_path,
query,
top_k,
recompute_embeddings=is_recompute,
)
else:
after_results = None
updated_size = index_file_size(index_path)
return {
"initial_size": initial_size,
"updated_size": updated_size,
"delta": updated_size - initial_size,
"before_results": before_results if not skip_search else None,
"after_results": after_results if not skip_search else None,
"metadata": load_metadata_snapshot(index_path),
}
def main() -> None:
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--initial-files",
type=Path,
nargs="+",
default=DEFAULT_INITIAL_FILES,
help="Initial document files (PDF/TXT) used to build the base index",
)
parser.add_argument(
"--index-path",
type=Path,
default=Path(".leann/examples/leann-demo.leann"),
help="Destination index path (default: .leann/examples/leann-demo.leann)",
)
parser.add_argument(
"--initial-count",
type=int,
default=8,
help="Number of chunks to use from the initial documents (default: 8)",
)
parser.add_argument(
"--update-files",
type=Path,
nargs="*",
default=DEFAULT_UPDATE_FILES,
help="Additional documents to add during update (PDF/TXT)",
)
parser.add_argument(
"--update-count",
type=int,
default=4,
help="Number of chunks to append from update documents (default: 4)",
)
parser.add_argument(
"--update-text",
type=str,
default=(
"LEANN (Lightweight Embedding ANN) is an indexing toolkit focused on "
"recompute-aware HNSW graphs, allowing embeddings to be regenerated "
"on demand to keep disk usage minimal."
),
help="Fallback text to append if --update-files is omitted",
)
parser.add_argument(
"--top-k",
type=int,
default=4,
help="Number of results to show for each search (default: 4)",
)
parser.add_argument(
"--query",
type=str,
default=DEFAULT_QUERY,
help="Query to run before/after the update",
)
parser.add_argument(
"--embedding-model",
type=str,
default="sentence-transformers/all-MiniLM-L6-v2",
help="Embedding model name",
)
parser.add_argument(
"--embedding-mode",
type=str,
default="sentence-transformers",
choices=["sentence-transformers", "openai", "mlx", "ollama"],
help="Embedding backend mode",
)
parser.add_argument(
"--compare-no-recompute",
dest="compare_no_recompute",
action="store_true",
help="Also run a baseline with is_recompute=False and report its index growth.",
)
parser.add_argument(
"--skip-compare-no-recompute",
dest="compare_no_recompute",
action="store_false",
help="Skip building the no-recompute baseline.",
)
parser.add_argument(
"--skip-search",
dest="skip_search",
action="store_true",
help="Skip the search step.",
)
parser.set_defaults(compare_no_recompute=True)
args = parser.parse_args()
ensure_index_dir(args.index_path)
register_project_directory(REPO_ROOT)
initial_chunks = load_chunks_from_files(list(args.initial_files))
if not initial_chunks:
raise ValueError("No text chunks extracted from the initial files.")
initial = initial_chunks[: args.initial_count]
if not initial:
raise ValueError("Initial chunk set is empty after applying --initial-count.")
if args.update_files:
update_chunks = load_chunks_from_files(list(args.update_files))
if not update_chunks:
raise ValueError("No text chunks extracted from the update files.")
to_add = update_chunks[: args.update_count]
else:
if not args.update_text:
raise ValueError("Provide --update-files or --update-text for the update step.")
to_add = [args.update_text]
if not to_add:
raise ValueError("Update chunk set is empty after applying --update-count.")
recompute_stats = run_workflow(
label="recompute",
index_path=args.index_path,
initial_paragraphs=initial,
update_paragraphs=to_add,
model_name=args.embedding_model,
embedding_mode=args.embedding_mode,
is_recompute=True,
query=args.query,
top_k=args.top_k,
skip_search=args.skip_search,
)
if not args.skip_search:
print_results("initial search", recompute_stats["before_results"])
if not args.skip_search:
print_results("after update", recompute_stats["after_results"])
print(
f"\n[recompute] Index file size change: {recompute_stats['initial_size']} -> {recompute_stats['updated_size']} bytes"
f"{recompute_stats['delta']})"
)
if recompute_stats["metadata"]:
meta_view = {k: recompute_stats["metadata"].get(k) for k in ("is_compact", "is_pruned")}
print("[recompute] metadata snapshot:")
print(json.dumps(meta_view, indent=2))
if args.compare_no_recompute:
baseline_path = (
args.index_path.parent / f"{args.index_path.stem}-norecompute{args.index_path.suffix}"
)
baseline_stats = run_workflow(
label="no-recompute",
index_path=baseline_path,
initial_paragraphs=initial,
update_paragraphs=to_add,
model_name=args.embedding_model,
embedding_mode=args.embedding_mode,
is_recompute=False,
query=args.query,
top_k=args.top_k,
skip_search=args.skip_search,
)
print(
f"\n[no-recompute] Index file size change: {baseline_stats['initial_size']} -> {baseline_stats['updated_size']} bytes"
f"{baseline_stats['delta']})"
)
after_texts = (
[res.text for res in recompute_stats["after_results"]] if not args.skip_search else None
)
baseline_after_texts = (
[res.text for res in baseline_stats["after_results"]] if not args.skip_search else None
)
if after_texts == baseline_after_texts:
print(
"[no-recompute] Search results match recompute baseline; see above for the shared output."
)
else:
print("[no-recompute] WARNING: search results differ from recompute baseline.")
if baseline_stats["metadata"]:
meta_view = {k: baseline_stats["metadata"].get(k) for k in ("is_compact", "is_pruned")}
print("[no-recompute] metadata snapshot:")
print(json.dumps(meta_view, indent=2))
if __name__ == "__main__":
main()

View File

@@ -343,7 +343,8 @@ class DiskannSearcher(BaseSearcher):
"full_index_prefix": full_index_prefix,
"num_threads": self.num_threads,
"num_nodes_to_cache": kwargs.get("num_nodes_to_cache", 0),
"cache_mechanism": 1,
# 1 -> initialize cache using sample_data; 2 -> ready cache without init; others disable cache
"cache_mechanism": kwargs.get("cache_mechanism", 1),
"pq_prefix": "",
"partition_prefix": partition_prefix,
}

View File

@@ -10,7 +10,7 @@ import sys
import threading
import time
from pathlib import Path
from typing import Optional
from typing import Any, Optional
import numpy as np
import zmq
@@ -32,6 +32,16 @@ if not logger.handlers:
logger.propagate = False
_RAW_PROVIDER_OPTIONS = os.getenv("LEANN_EMBEDDING_OPTIONS")
try:
PROVIDER_OPTIONS: dict[str, Any] = (
json.loads(_RAW_PROVIDER_OPTIONS) if _RAW_PROVIDER_OPTIONS else {}
)
except json.JSONDecodeError:
logger.warning("Failed to parse LEANN_EMBEDDING_OPTIONS; ignoring provider options")
PROVIDER_OPTIONS = {}
def create_diskann_embedding_server(
passages_file: Optional[str] = None,
zmq_port: int = 5555,
@@ -181,7 +191,12 @@ def create_diskann_embedding_server(
logger.debug(f"Text lengths: {[len(t) for t in texts[:5]]}") # Show first 5
# Process embeddings using unified computation
embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
embeddings = compute_embeddings(
texts,
model_name,
mode=embedding_mode,
provider_options=PROVIDER_OPTIONS,
)
logger.info(
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
)
@@ -296,7 +311,12 @@ def create_diskann_embedding_server(
continue
# Process the request
embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
embeddings = compute_embeddings(
texts,
model_name,
mode=embedding_mode,
provider_options=PROVIDER_OPTIONS,
)
logger.info(f"Computed embeddings shape: {embeddings.shape}")
# Validation

View File

@@ -1,5 +1,5 @@
[build-system]
requires = ["scikit-build-core>=0.10", "pybind11>=2.12.0", "numpy"]
requires = ["scikit-build-core>=0.10", "pybind11>=2.12.0", "numpy", "cmake>=3.30"]
build-backend = "scikit_build_core.build"
[project]

View File

@@ -5,6 +5,8 @@ import os
import struct
import sys
import time
from dataclasses import dataclass
from typing import Any, Optional
import numpy as np
@@ -237,6 +239,288 @@ def write_compact_format(
f_out.write(storage_data)
@dataclass
class HNSWComponents:
original_hnsw_data: dict[str, Any]
assign_probas_np: np.ndarray
cum_nneighbor_per_level_np: np.ndarray
levels_np: np.ndarray
is_compact: bool
compact_level_ptr: Optional[np.ndarray] = None
compact_node_offsets_np: Optional[np.ndarray] = None
compact_neighbors_data: Optional[list[int]] = None
offsets_np: Optional[np.ndarray] = None
neighbors_np: Optional[np.ndarray] = None
storage_fourcc: int = NULL_INDEX_FOURCC
storage_data: bytes = b""
def _read_hnsw_structure(f) -> HNSWComponents:
original_hnsw_data: dict[str, Any] = {}
hnsw_index_fourcc = read_struct(f, "<I")
if hnsw_index_fourcc not in EXPECTED_HNSW_FOURCCS:
raise ValueError(
f"Unexpected HNSW FourCC: {hnsw_index_fourcc:08x}. Expected one of {EXPECTED_HNSW_FOURCCS}."
)
original_hnsw_data["index_fourcc"] = hnsw_index_fourcc
original_hnsw_data["d"] = read_struct(f, "<i")
original_hnsw_data["ntotal"] = read_struct(f, "<q")
original_hnsw_data["dummy1"] = read_struct(f, "<q")
original_hnsw_data["dummy2"] = read_struct(f, "<q")
original_hnsw_data["is_trained"] = read_struct(f, "?")
original_hnsw_data["metric_type"] = read_struct(f, "<i")
original_hnsw_data["metric_arg"] = 0.0
if original_hnsw_data["metric_type"] > 1:
original_hnsw_data["metric_arg"] = read_struct(f, "<f")
assign_probas_np = read_numpy_vector(f, np.float64, "d")
cum_nneighbor_per_level_np = read_numpy_vector(f, np.int32, "i")
levels_np = read_numpy_vector(f, np.int32, "i")
ntotal = len(levels_np)
if ntotal != original_hnsw_data["ntotal"]:
original_hnsw_data["ntotal"] = ntotal
pos_before_compact = f.tell()
is_compact_flag = None
try:
is_compact_flag = read_struct(f, "<?")
except EOFError:
is_compact_flag = None
if is_compact_flag:
compact_level_ptr = read_numpy_vector(f, np.uint64, "Q")
compact_node_offsets_np = read_numpy_vector(f, np.uint64, "Q")
original_hnsw_data["entry_point"] = read_struct(f, "<i")
original_hnsw_data["max_level"] = read_struct(f, "<i")
original_hnsw_data["efConstruction"] = read_struct(f, "<i")
original_hnsw_data["efSearch"] = read_struct(f, "<i")
original_hnsw_data["dummy_upper_beam"] = read_struct(f, "<i")
storage_fourcc = read_struct(f, "<I")
compact_neighbors_data_np = read_numpy_vector(f, np.int32, "i")
compact_neighbors_data = compact_neighbors_data_np.tolist()
storage_data = f.read()
return HNSWComponents(
original_hnsw_data=original_hnsw_data,
assign_probas_np=assign_probas_np,
cum_nneighbor_per_level_np=cum_nneighbor_per_level_np,
levels_np=levels_np,
is_compact=True,
compact_level_ptr=compact_level_ptr,
compact_node_offsets_np=compact_node_offsets_np,
compact_neighbors_data=compact_neighbors_data,
storage_fourcc=storage_fourcc,
storage_data=storage_data,
)
# Non-compact case
f.seek(pos_before_compact)
pos_before_probe = f.tell()
try:
suspected_flag = read_struct(f, "<B")
if suspected_flag != 0x00:
f.seek(pos_before_probe)
except EOFError:
f.seek(pos_before_probe)
offsets_np = read_numpy_vector(f, np.uint64, "Q")
neighbors_np = read_numpy_vector(f, np.int32, "i")
original_hnsw_data["entry_point"] = read_struct(f, "<i")
original_hnsw_data["max_level"] = read_struct(f, "<i")
original_hnsw_data["efConstruction"] = read_struct(f, "<i")
original_hnsw_data["efSearch"] = read_struct(f, "<i")
original_hnsw_data["dummy_upper_beam"] = read_struct(f, "<i")
storage_fourcc = NULL_INDEX_FOURCC
storage_data = b""
try:
storage_fourcc = read_struct(f, "<I")
storage_data = f.read()
except EOFError:
storage_fourcc = NULL_INDEX_FOURCC
return HNSWComponents(
original_hnsw_data=original_hnsw_data,
assign_probas_np=assign_probas_np,
cum_nneighbor_per_level_np=cum_nneighbor_per_level_np,
levels_np=levels_np,
is_compact=False,
offsets_np=offsets_np,
neighbors_np=neighbors_np,
storage_fourcc=storage_fourcc,
storage_data=storage_data,
)
def _read_hnsw_structure_from_file(path: str) -> HNSWComponents:
with open(path, "rb") as f:
return _read_hnsw_structure(f)
def write_original_format(
f_out,
original_hnsw_data,
assign_probas_np,
cum_nneighbor_per_level_np,
levels_np,
offsets_np,
neighbors_np,
storage_fourcc,
storage_data,
):
"""Write non-compact HNSW data in original FAISS order."""
f_out.write(struct.pack("<I", original_hnsw_data["index_fourcc"]))
f_out.write(struct.pack("<i", original_hnsw_data["d"]))
f_out.write(struct.pack("<q", original_hnsw_data["ntotal"]))
f_out.write(struct.pack("<q", original_hnsw_data["dummy1"]))
f_out.write(struct.pack("<q", original_hnsw_data["dummy2"]))
f_out.write(struct.pack("<?", original_hnsw_data["is_trained"]))
f_out.write(struct.pack("<i", original_hnsw_data["metric_type"]))
if original_hnsw_data["metric_type"] > 1:
f_out.write(struct.pack("<f", original_hnsw_data["metric_arg"]))
write_numpy_vector(f_out, assign_probas_np, "d")
write_numpy_vector(f_out, cum_nneighbor_per_level_np, "i")
write_numpy_vector(f_out, levels_np, "i")
write_numpy_vector(f_out, offsets_np, "Q")
write_numpy_vector(f_out, neighbors_np, "i")
f_out.write(struct.pack("<i", original_hnsw_data["entry_point"]))
f_out.write(struct.pack("<i", original_hnsw_data["max_level"]))
f_out.write(struct.pack("<i", original_hnsw_data["efConstruction"]))
f_out.write(struct.pack("<i", original_hnsw_data["efSearch"]))
f_out.write(struct.pack("<i", original_hnsw_data["dummy_upper_beam"]))
f_out.write(struct.pack("<I", storage_fourcc))
if storage_fourcc != NULL_INDEX_FOURCC and storage_data:
f_out.write(storage_data)
def prune_hnsw_embeddings(input_filename: str, output_filename: str) -> bool:
"""Rewrite an HNSW index while dropping the embedded storage section."""
start_time = time.time()
try:
with open(input_filename, "rb") as f_in, open(output_filename, "wb") as f_out:
original_hnsw_data: dict[str, Any] = {}
hnsw_index_fourcc = read_struct(f_in, "<I")
if hnsw_index_fourcc not in EXPECTED_HNSW_FOURCCS:
print(
f"Error: Expected HNSW Index FourCC ({list(EXPECTED_HNSW_FOURCCS)}), got {hnsw_index_fourcc:08x}.",
file=sys.stderr,
)
return False
original_hnsw_data["index_fourcc"] = hnsw_index_fourcc
original_hnsw_data["d"] = read_struct(f_in, "<i")
original_hnsw_data["ntotal"] = read_struct(f_in, "<q")
original_hnsw_data["dummy1"] = read_struct(f_in, "<q")
original_hnsw_data["dummy2"] = read_struct(f_in, "<q")
original_hnsw_data["is_trained"] = read_struct(f_in, "?")
original_hnsw_data["metric_type"] = read_struct(f_in, "<i")
original_hnsw_data["metric_arg"] = 0.0
if original_hnsw_data["metric_type"] > 1:
original_hnsw_data["metric_arg"] = read_struct(f_in, "<f")
assign_probas_np = read_numpy_vector(f_in, np.float64, "d")
cum_nneighbor_per_level_np = read_numpy_vector(f_in, np.int32, "i")
levels_np = read_numpy_vector(f_in, np.int32, "i")
ntotal = len(levels_np)
if ntotal != original_hnsw_data["ntotal"]:
original_hnsw_data["ntotal"] = ntotal
pos_before_compact = f_in.tell()
is_compact_flag = None
try:
is_compact_flag = read_struct(f_in, "<?")
except EOFError:
is_compact_flag = None
if is_compact_flag:
compact_level_ptr = read_numpy_vector(f_in, np.uint64, "Q")
compact_node_offsets_np = read_numpy_vector(f_in, np.uint64, "Q")
original_hnsw_data["entry_point"] = read_struct(f_in, "<i")
original_hnsw_data["max_level"] = read_struct(f_in, "<i")
original_hnsw_data["efConstruction"] = read_struct(f_in, "<i")
original_hnsw_data["efSearch"] = read_struct(f_in, "<i")
original_hnsw_data["dummy_upper_beam"] = read_struct(f_in, "<i")
_storage_fourcc = read_struct(f_in, "<I")
compact_neighbors_data_np = read_numpy_vector(f_in, np.int32, "i")
compact_neighbors_data = compact_neighbors_data_np.tolist()
_storage_data = f_in.read()
write_compact_format(
f_out,
original_hnsw_data,
assign_probas_np,
cum_nneighbor_per_level_np,
levels_np,
compact_level_ptr,
compact_node_offsets_np,
compact_neighbors_data,
NULL_INDEX_FOURCC,
b"",
)
else:
f_in.seek(pos_before_compact)
pos_before_probe = f_in.tell()
try:
suspected_flag = read_struct(f_in, "<B")
if suspected_flag != 0x00:
f_in.seek(pos_before_probe)
except EOFError:
f_in.seek(pos_before_probe)
offsets_np = read_numpy_vector(f_in, np.uint64, "Q")
neighbors_np = read_numpy_vector(f_in, np.int32, "i")
original_hnsw_data["entry_point"] = read_struct(f_in, "<i")
original_hnsw_data["max_level"] = read_struct(f_in, "<i")
original_hnsw_data["efConstruction"] = read_struct(f_in, "<i")
original_hnsw_data["efSearch"] = read_struct(f_in, "<i")
original_hnsw_data["dummy_upper_beam"] = read_struct(f_in, "<i")
_storage_fourcc = None
_storage_data = b""
try:
_storage_fourcc = read_struct(f_in, "<I")
_storage_data = f_in.read()
except EOFError:
_storage_fourcc = NULL_INDEX_FOURCC
write_original_format(
f_out,
original_hnsw_data,
assign_probas_np,
cum_nneighbor_per_level_np,
levels_np,
offsets_np,
neighbors_np,
NULL_INDEX_FOURCC,
b"",
)
print(f"[{time.time() - start_time:.2f}s] Pruned embeddings from {input_filename}")
return True
except Exception as exc:
print(f"Failed to prune embeddings: {exc}", file=sys.stderr)
return False
# --- Main Conversion Logic ---
@@ -700,6 +984,29 @@ def convert_hnsw_graph_to_csr(input_filename, output_filename, prune_embeddings=
pass
def prune_hnsw_embeddings_inplace(index_filename: str) -> bool:
"""Convenience wrapper to prune embeddings in-place."""
temp_path = f"{index_filename}.prune.tmp"
success = prune_hnsw_embeddings(index_filename, temp_path)
if success:
try:
os.replace(temp_path, index_filename)
except Exception as exc: # pragma: no cover - defensive
logger.error(f"Failed to replace original index with pruned version: {exc}")
try:
os.remove(temp_path)
except OSError:
pass
return False
else:
try:
os.remove(temp_path)
except OSError:
pass
return success
# --- Script Execution ---
if __name__ == "__main__":
parser = argparse.ArgumentParser(

View File

@@ -14,8 +14,7 @@ from leann.interface import (
from leann.registry import register_backend
from leann.searcher_base import BaseSearcher
from .convert_to_csr import convert_hnsw_graph_to_csr
from .prune_index import prune_embeddings_preserve_graph_inplace
from .convert_to_csr import convert_hnsw_graph_to_csr, prune_hnsw_embeddings_inplace
logger = logging.getLogger(__name__)
@@ -91,16 +90,19 @@ class HNSWBuilder(LeannBackendBuilderInterface):
index_file = index_dir / f"{index_prefix}.index"
faiss.write_index(index, str(index_file))
if self.is_recompute:
if self.is_compact:
self._convert_to_csr(index_file)
else:
# Non-compact format: prune only embeddings, keep original graph
ok = prune_embeddings_preserve_graph_inplace(str(index_file))
if not ok:
raise RuntimeError(
"Pruning embeddings while preserving graph failed for non-compact index"
)
# Persist ID map so searcher can map FAISS integer labels back to passage IDs
try:
idmap_file = index_dir / f"{index_prefix}.ids.txt"
with open(idmap_file, "w", encoding="utf-8") as f:
for id_str in ids:
f.write(str(id_str) + "\n")
except Exception as e:
logger.warning(f"Failed to write ID map: {e}")
if self.is_compact:
self._convert_to_csr(index_file)
elif self.is_recompute:
prune_hnsw_embeddings_inplace(str(index_file))
def _convert_to_csr(self, index_file: Path):
"""Convert built index to CSR format"""
@@ -142,10 +144,10 @@ class HNSWSearcher(BaseSearcher):
if metric_enum is None:
raise ValueError(f"Unsupported distance_metric '{self.distance_metric}'.")
self.is_compact, self.is_pruned = (
self.meta.get("is_compact", True),
self.meta.get("is_pruned", True),
)
backend_meta_kwargs = self.meta.get("backend_kwargs", {})
self.is_compact = self.meta.get("is_compact", backend_meta_kwargs.get("is_compact", True))
default_pruned = backend_meta_kwargs.get("is_recompute", self.is_compact)
self.is_pruned = bool(self.meta.get("is_pruned", default_pruned))
index_file = self.index_dir / f"{self.index_path.stem}.index"
if not index_file.exists():
@@ -157,13 +159,17 @@ class HNSWSearcher(BaseSearcher):
self.is_pruned
) # In C++ code, it's called is_recompute, but it's only for loading IIUC.
# If pruned (recompute mode), explicitly skip storage to avoid reading
# the pruned section. Still allow MMAP for graph.
io_flags = faiss.IO_FLAG_MMAP
if self.is_pruned:
io_flags |= faiss.IO_FLAG_SKIP_STORAGE
self._index = faiss.read_index(str(index_file), faiss.IO_FLAG_MMAP, hnsw_config)
self._index = faiss.read_index(str(index_file), io_flags, hnsw_config)
# Load ID map if available
self._id_map: list[str] = []
try:
idmap_file = self.index_dir / f"{self.index_path.stem}.ids.txt"
if idmap_file.exists():
with open(idmap_file, encoding="utf-8") as f:
self._id_map = [line.rstrip("\n") for line in f]
except Exception as e:
logger.warning(f"Failed to load ID map: {e}")
def search(
self,
@@ -263,58 +269,19 @@ class HNSWSearcher(BaseSearcher):
)
search_time = time.time() - search_time
logger.info(f" Search time in HNSWSearcher.search() backend: {search_time} seconds")
string_labels = [[str(int_label) for int_label in batch_labels] for batch_labels in labels]
if self._id_map:
def map_label(x: int) -> str:
if 0 <= x < len(self._id_map):
return self._id_map[x]
return str(x)
string_labels = [
[map_label(int(label)) for label in batch_labels] for batch_labels in labels
]
else:
string_labels = [
[str(int_label) for int_label in batch_labels] for batch_labels in labels
]
return {"labels": string_labels, "distances": distances}
# ---------- Helper API for incremental add (Python-level) ----------
def add_vectors(
index_file_path: str,
embeddings: np.ndarray,
*,
ef_construction: Optional[int] = None,
recompute: bool = False,
) -> None:
"""Append vectors to an existing non-compact HNSW index.
Args:
index_file_path: Path to the HNSW .index file
embeddings: float32 numpy array (N, D)
ef_construction: Optional override for efConstruction during insertion
recompute: Reserved for future use to control insertion-time recompute behaviors
"""
from . import faiss # type: ignore
if embeddings.dtype != np.float32:
embeddings = embeddings.astype(np.float32)
if not embeddings.flags.c_contiguous:
embeddings = np.ascontiguousarray(embeddings, dtype=np.float32)
# Load index normally to ensure storage is present; toggle is_recompute on the object
index = faiss.read_index(str(index_file_path), faiss.IO_FLAG_MMAP)
# Best-effort: explicitly set flag on the object if the binding exposes it
try:
index.is_recompute = bool(recompute)
except Exception:
pass
try:
if ef_construction is not None:
index.hnsw.efConstruction = int(ef_construction)
except Exception:
# Best-effort; ignore if backend doesn't expose setter
pass
# For non-compact HNSW, calling add directly is sufficient. When is_recompute is set
# (via config or attribute), FAISS will run the insertion/search path accordingly.
# To strictly follow per-point insert semantics in recompute mode, add one-by-one.
if recompute:
# Insert row by row
n = embeddings.shape[0]
for i in range(n):
row = embeddings[i : i + 1]
index.add(1, faiss.swig_ptr(row))
else:
index.add(embeddings.shape[0], faiss.swig_ptr(embeddings))
faiss.write_index(index, str(index_file_path))

View File

@@ -10,7 +10,7 @@ import sys
import threading
import time
from pathlib import Path
from typing import Optional
from typing import Any, Optional
import msgpack
import numpy as np
@@ -24,13 +24,35 @@ logger = logging.getLogger(__name__)
log_level = getattr(logging, LOG_LEVEL, logging.WARNING)
logger.setLevel(log_level)
# Ensure we have a handler if none exists
# Ensure we have handlers if none exist
if not logger.handlers:
handler = logging.StreamHandler()
stream_handler = logging.StreamHandler()
formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
handler.setFormatter(formatter)
logger.addHandler(handler)
logger.propagate = False
stream_handler.setFormatter(formatter)
logger.addHandler(stream_handler)
log_path = os.getenv("LEANN_HNSW_LOG_PATH")
if log_path:
try:
file_handler = logging.FileHandler(log_path, mode="a", encoding="utf-8")
file_formatter = logging.Formatter(
"%(asctime)s - %(levelname)s - [pid=%(process)d] %(message)s"
)
file_handler.setFormatter(file_formatter)
logger.addHandler(file_handler)
except Exception as exc: # pragma: no cover - best effort logging
logger.warning(f"Failed to attach file handler for log path {log_path}: {exc}")
logger.propagate = False
_RAW_PROVIDER_OPTIONS = os.getenv("LEANN_EMBEDDING_OPTIONS")
try:
PROVIDER_OPTIONS: dict[str, Any] = (
json.loads(_RAW_PROVIDER_OPTIONS) if _RAW_PROVIDER_OPTIONS else {}
)
except json.JSONDecodeError:
logger.warning("Failed to parse LEANN_EMBEDDING_OPTIONS; ignoring provider options")
PROVIDER_OPTIONS = {}
def create_hnsw_embedding_server(
@@ -92,6 +114,35 @@ def create_hnsw_embedding_server(
embedding_dim = 0
logger.info(f"Loaded PassageManager with {len(passages)} passages from metadata")
# Attempt to load ID map (maps FAISS integer labels -> passage IDs)
id_map: list[str] = []
try:
meta_path = Path(passages_file)
base = meta_path.name
if base.endswith(".meta.json"):
base = base[: -len(".meta.json")] # e.g., laion_index.leann
if base.endswith(".leann"):
base = base[: -len(".leann")] # e.g., laion_index
idmap_file = meta_path.parent / f"{base}.ids.txt"
if idmap_file.exists():
with open(idmap_file, encoding="utf-8") as f:
id_map = [line.rstrip("\n") for line in f]
logger.info(f"Loaded ID map with {len(id_map)} entries from {idmap_file}")
else:
logger.warning(f"ID map file not found at {idmap_file}; will use raw labels")
except Exception as e:
logger.warning(f"Failed to load ID map: {e}")
def _map_node_id(nid) -> str:
try:
if id_map is not None and len(id_map) > 0 and isinstance(nid, (int, np.integer)):
idx = int(nid)
if 0 <= idx < len(id_map):
return id_map[idx]
except Exception:
pass
return str(nid)
# (legacy ZMQ thread removed; using shutdown-capable server only)
def zmq_server_thread_with_shutdown(shutdown_event):
@@ -138,7 +189,12 @@ def create_hnsw_embedding_server(
):
last_request_type = "text"
last_request_length = len(request)
embeddings = compute_embeddings(request, model_name, mode=embedding_mode)
embeddings = compute_embeddings(
request,
model_name,
mode=embedding_mode,
provider_options=PROVIDER_OPTIONS,
)
rep_socket.send(msgpack.packb(embeddings.tolist()))
e2e_end = time.time()
logger.info(f"⏱️ Text embedding E2E time: {e2e_end - e2e_start:.6f}s")
@@ -168,13 +224,14 @@ def create_hnsw_embedding_server(
found_indices: list[int] = []
for idx, nid in enumerate(node_ids):
try:
passage_data = passages.get_passage(str(nid))
passage_id = _map_node_id(nid)
passage_data = passages.get_passage(passage_id)
txt = passage_data.get("text", "")
if isinstance(txt, str) and len(txt) > 0:
texts.append(txt)
found_indices.append(idx)
else:
logger.error(f"Empty text for passage ID {nid}")
logger.error(f"Empty text for passage ID {passage_id}")
except KeyError:
logger.error(f"Passage ID {nid} not found")
except Exception as e:
@@ -187,7 +244,10 @@ def create_hnsw_embedding_server(
if texts:
try:
embeddings = compute_embeddings(
texts, model_name, mode=embedding_mode
texts,
model_name,
mode=embedding_mode,
provider_options=PROVIDER_OPTIONS,
)
logger.info(
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
@@ -238,13 +298,14 @@ def create_hnsw_embedding_server(
found_indices: list[int] = []
for idx, nid in enumerate(node_ids):
try:
passage_data = passages.get_passage(str(nid))
passage_id = _map_node_id(nid)
passage_data = passages.get_passage(passage_id)
txt = passage_data.get("text", "")
if isinstance(txt, str) and len(txt) > 0:
texts.append(txt)
found_indices.append(idx)
else:
logger.error(f"Empty text for passage ID {nid}")
logger.error(f"Empty text for passage ID {passage_id}")
except KeyError:
logger.error(f"Passage with ID {nid} not found")
except Exception as e:
@@ -252,7 +313,12 @@ def create_hnsw_embedding_server(
if texts:
try:
embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
embeddings = compute_embeddings(
texts,
model_name,
mode=embedding_mode,
provider_options=PROVIDER_OPTIONS,
)
logger.info(
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
)

View File

@@ -1,149 +0,0 @@
import os
import struct
from pathlib import Path
from .convert_to_csr import (
EXPECTED_HNSW_FOURCCS,
NULL_INDEX_FOURCC,
read_struct,
read_vector_raw,
)
def _write_vector_raw(f_out, count: int, data_bytes: bytes) -> None:
"""Write a vector in the same binary layout as read_vector_raw reads: <Q count> + raw bytes."""
f_out.write(struct.pack("<Q", count))
if count > 0 and data_bytes:
f_out.write(data_bytes)
def prune_embeddings_preserve_graph(input_filename: str, output_filename: str) -> bool:
"""
Copy an original (non-compact) HNSW index file while pruning the trailing embedding storage.
Preserves the graph structure and metadata exactly; only writes a NULL storage marker instead of
the original storage fourcc and payload.
Returns True on success.
"""
print(f"Pruning embeddings from {input_filename} to {output_filename}")
print("--------------------------------")
# running in mode is-recompute=True and is-compact=False
in_path = Path(input_filename)
out_path = Path(output_filename)
try:
with open(in_path, "rb") as f_in, open(out_path, "wb") as f_out:
# Header
index_fourcc = read_struct(f_in, "<I")
if index_fourcc not in EXPECTED_HNSW_FOURCCS:
# Still proceed, but this is unexpected
pass
f_out.write(struct.pack("<I", index_fourcc))
d = read_struct(f_in, "<i")
ntotal_hdr = read_struct(f_in, "<q")
dummy1 = read_struct(f_in, "<q")
dummy2 = read_struct(f_in, "<q")
is_trained = read_struct(f_in, "?")
metric_type = read_struct(f_in, "<i")
f_out.write(struct.pack("<i", d))
f_out.write(struct.pack("<q", ntotal_hdr))
f_out.write(struct.pack("<q", dummy1))
f_out.write(struct.pack("<q", dummy2))
f_out.write(struct.pack("<?", is_trained))
f_out.write(struct.pack("<i", metric_type))
if metric_type > 1:
metric_arg = read_struct(f_in, "<f")
f_out.write(struct.pack("<f", metric_arg))
# Vectors: assign_probas (double), cum_nneighbor_per_level (int32), levels (int32)
cnt, data = read_vector_raw(f_in, "d")
_write_vector_raw(f_out, cnt, data)
cnt, data = read_vector_raw(f_in, "i")
_write_vector_raw(f_out, cnt, data)
cnt, data = read_vector_raw(f_in, "i")
_write_vector_raw(f_out, cnt, data)
# Probe potential extra alignment/flag byte present in some original formats
probe = f_in.read(1)
if probe:
if probe == b"\x00":
# Preserve this unexpected 0x00 byte
f_out.write(probe)
else:
# Likely part of the next vector; rewind
f_in.seek(-1, os.SEEK_CUR)
# Offsets (uint64) and neighbors (int32)
cnt, data = read_vector_raw(f_in, "Q")
_write_vector_raw(f_out, cnt, data)
cnt, data = read_vector_raw(f_in, "i")
_write_vector_raw(f_out, cnt, data)
# Scalar params
entry_point = read_struct(f_in, "<i")
max_level = read_struct(f_in, "<i")
ef_construction = read_struct(f_in, "<i")
ef_search = read_struct(f_in, "<i")
dummy_upper_beam = read_struct(f_in, "<i")
f_out.write(struct.pack("<i", entry_point))
f_out.write(struct.pack("<i", max_level))
f_out.write(struct.pack("<i", ef_construction))
f_out.write(struct.pack("<i", ef_search))
f_out.write(struct.pack("<i", dummy_upper_beam))
# Storage fourcc (if present) — write NULL marker and drop any remaining data
try:
read_struct(f_in, "<I")
# Regardless of original, write NULL
f_out.write(struct.pack("<I", NULL_INDEX_FOURCC))
# Discard the rest of the file (embedding payload)
# (Do not copy anything else)
except EOFError:
# No storage section; nothing else to write
pass
return True
except Exception:
# Best-effort cleanup
try:
if out_path.exists():
out_path.unlink()
except OSError:
pass
return False
def prune_embeddings_preserve_graph_inplace(index_file_path: str) -> bool:
"""
Convenience wrapper: write pruned file to a temporary path next to the
original, then atomically replace on success.
"""
print(f"Pruning embeddings from {index_file_path} to {index_file_path}")
print("--------------------------------")
# running in mode is-recompute=True and is-compact=False
src = Path(index_file_path)
tmp = src.with_suffix(".pruned.tmp")
ok = prune_embeddings_preserve_graph(str(src), str(tmp))
if not ok:
if tmp.exists():
try:
tmp.unlink()
except OSError:
pass
return False
try:
os.replace(str(tmp), str(src))
except Exception:
# Rollback on failure
try:
if tmp.exists():
tmp.unlink()
except OSError:
pass
return False
return True

View File

@@ -16,6 +16,7 @@ from pathlib import Path
from typing import Any, Literal, Optional, Union
import numpy as np
from leann_backend_hnsw.convert_to_csr import prune_hnsw_embeddings_inplace
from leann.interface import LeannBackendSearcherInterface
@@ -40,6 +41,7 @@ def compute_embeddings(
use_server: bool = True,
port: Optional[int] = None,
is_build=False,
provider_options: Optional[dict[str, Any]] = None,
) -> np.ndarray:
"""
Computes embeddings using different backends.
@@ -73,6 +75,7 @@ def compute_embeddings(
model_name,
mode=mode,
is_build=is_build,
provider_options=provider_options,
)
@@ -120,20 +123,6 @@ class SearchResult:
metadata: dict[str, Any] = field(default_factory=dict)
@dataclass
class IncrementalAddContext:
"""Prepared context for safe incremental add operations on an index."""
index_path: str
passages_file: Path
offsets_file: Path
vector_index_file: Path
embedding_model: str
embedding_mode: str
distance_metric: str
prepared_texts: Optional[list[str]] = None
class PassageManager:
def __init__(
self, passage_sources: list[dict[str, Any]], metadata_file_path: Optional[str] = None
@@ -293,6 +282,7 @@ class LeannBuilder:
embedding_model: str = "facebook/contriever",
dimensions: Optional[int] = None,
embedding_mode: str = "sentence-transformers",
embedding_options: Optional[dict[str, Any]] = None,
**backend_kwargs,
):
self.backend_name = backend_name
@@ -315,6 +305,7 @@ class LeannBuilder:
self.embedding_model = embedding_model
self.dimensions = dimensions
self.embedding_mode = embedding_mode
self.embedding_options = embedding_options or {}
# Check if we need to use cosine distance for normalized embeddings
normalized_embeddings_models = {
@@ -422,6 +413,7 @@ class LeannBuilder:
self.embedding_model,
self.embedding_mode,
use_server=False,
provider_options=self.embedding_options,
)[0]
)
path = Path(index_path)
@@ -461,8 +453,20 @@ class LeannBuilder:
self.embedding_mode,
use_server=False,
is_build=True,
provider_options=self.embedding_options,
)
string_ids = [chunk["id"] for chunk in self.chunks]
# Persist ID map alongside index so backends that return integer labels can remap to passage IDs
try:
idmap_file = (
index_dir
/ f"{index_name[: -len('.leann')] if index_name.endswith('.leann') else index_name}.ids.txt"
)
with open(idmap_file, "w", encoding="utf-8") as f:
for sid in string_ids:
f.write(str(sid) + "\n")
except Exception:
pass
current_backend_kwargs = {**self.backend_kwargs, "dimensions": self.dimensions}
builder_instance = self.backend_factory.builder(**current_backend_kwargs)
builder_instance.build(embeddings, string_ids, index_path, **current_backend_kwargs)
@@ -487,12 +491,15 @@ class LeannBuilder:
],
}
if self.embedding_options:
meta_data["embedding_options"] = self.embedding_options
# Add storage status flags for HNSW backend
if self.backend_name == "hnsw":
is_compact = self.backend_kwargs.get("is_compact", True)
is_recompute = self.backend_kwargs.get("is_recompute", True)
meta_data["is_compact"] = is_compact
meta_data["is_pruned"] = is_recompute # Pruned only if compact and recompute
meta_data["is_pruned"] = bool(is_recompute)
with open(leann_meta_path, "w", encoding="utf-8") as f:
json.dump(meta_data, f, indent=2)
@@ -579,6 +586,17 @@ class LeannBuilder:
# Build the vector index using precomputed embeddings
string_ids = [str(id_val) for id_val in ids]
# Persist ID map (order == embeddings order)
try:
idmap_file = (
index_dir
/ f"{index_name[: -len('.leann')] if index_name.endswith('.leann') else index_name}.ids.txt"
)
with open(idmap_file, "w", encoding="utf-8") as f:
for sid in string_ids:
f.write(str(sid) + "\n")
except Exception:
pass
current_backend_kwargs = {**self.backend_kwargs, "dimensions": self.dimensions}
builder_instance = self.backend_factory.builder(**current_backend_kwargs)
builder_instance.build(embeddings, string_ids, index_path)
@@ -607,18 +625,237 @@ class LeannBuilder:
"embeddings_source": str(embeddings_file),
}
if self.embedding_options:
meta_data["embedding_options"] = self.embedding_options
# Add storage status flags for HNSW backend
if self.backend_name == "hnsw":
is_compact = self.backend_kwargs.get("is_compact", True)
is_recompute = self.backend_kwargs.get("is_recompute", True)
meta_data["is_compact"] = is_compact
meta_data["is_pruned"] = is_compact and is_recompute
meta_data["is_pruned"] = bool(is_recompute)
with open(leann_meta_path, "w", encoding="utf-8") as f:
json.dump(meta_data, f, indent=2)
logger.info(f"Index built successfully from precomputed embeddings: {index_path}")
def update_index(self, index_path: str):
"""Append new passages and vectors to an existing HNSW index."""
if not self.chunks:
raise ValueError("No new chunks provided for update.")
path = Path(index_path)
index_dir = path.parent
index_name = path.name
index_prefix = path.stem
meta_path = index_dir / f"{index_name}.meta.json"
passages_file = index_dir / f"{index_name}.passages.jsonl"
offset_file = index_dir / f"{index_name}.passages.idx"
index_file = index_dir / f"{index_prefix}.index"
if not meta_path.exists() or not passages_file.exists() or not offset_file.exists():
raise FileNotFoundError("Index metadata or passage files are missing; cannot update.")
if not index_file.exists():
raise FileNotFoundError(f"HNSW index file not found: {index_file}")
with open(meta_path, encoding="utf-8") as f:
meta = json.load(f)
backend_name = meta.get("backend_name")
if backend_name != self.backend_name:
raise ValueError(
f"Index was built with backend '{backend_name}', cannot update with '{self.backend_name}'."
)
meta_backend_kwargs = meta.get("backend_kwargs", {})
index_is_compact = meta.get("is_compact", meta_backend_kwargs.get("is_compact", True))
if index_is_compact:
raise ValueError(
"Compact HNSW indices do not support in-place updates. Rebuild required."
)
distance_metric = meta_backend_kwargs.get(
"distance_metric", self.backend_kwargs.get("distance_metric", "mips")
).lower()
needs_recompute = bool(
meta.get("is_pruned")
or meta_backend_kwargs.get("is_recompute")
or self.backend_kwargs.get("is_recompute")
)
with open(offset_file, "rb") as f:
offset_map: dict[str, int] = pickle.load(f)
existing_ids = set(offset_map.keys())
valid_chunks: list[dict[str, Any]] = []
for chunk in self.chunks:
text = chunk.get("text", "")
if not isinstance(text, str) or not text.strip():
continue
metadata = chunk.setdefault("metadata", {})
passage_id = chunk.get("id") or metadata.get("id")
if passage_id and passage_id in existing_ids:
raise ValueError(f"Passage ID '{passage_id}' already exists in the index.")
valid_chunks.append(chunk)
if not valid_chunks:
raise ValueError("No valid chunks to append.")
texts_to_embed = [chunk["text"] for chunk in valid_chunks]
embeddings = compute_embeddings(
texts_to_embed,
self.embedding_model,
self.embedding_mode,
use_server=False,
is_build=True,
provider_options=self.embedding_options,
)
embedding_dim = embeddings.shape[1]
expected_dim = meta.get("dimensions")
if expected_dim is not None and expected_dim != embedding_dim:
raise ValueError(
f"Dimension mismatch during update: existing index uses {expected_dim}, got {embedding_dim}."
)
from leann_backend_hnsw import faiss # type: ignore
embeddings = np.ascontiguousarray(embeddings, dtype=np.float32)
if distance_metric == "cosine":
norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
norms[norms == 0] = 1
embeddings = embeddings / norms
index = faiss.read_index(str(index_file))
if hasattr(index, "is_recompute"):
index.is_recompute = needs_recompute
print(f"index.is_recompute: {index.is_recompute}")
if getattr(index, "storage", None) is None:
if index.metric_type == faiss.METRIC_INNER_PRODUCT:
storage_index = faiss.IndexFlatIP(index.d)
else:
storage_index = faiss.IndexFlatL2(index.d)
index.storage = storage_index
index.own_fields = True
# Faiss expects storage.ntotal to reflect the existing graph's
# population (even if the vectors themselves were pruned from disk
# for recompute mode). When we attach a fresh IndexFlat here its
# ntotal starts at zero, which later causes IndexHNSW::add to
# believe new "preset" levels were provided and trips the
# `n0 + n == levels.size()` assertion. Seed the temporary storage
# with the current ntotal so Faiss maintains the proper offset for
# incoming vectors.
try:
storage_index.ntotal = index.ntotal
except AttributeError:
# Older Faiss builds may not expose ntotal as a writable
# attribute; in that case we fall back to the default behaviour.
pass
if index.d != embedding_dim:
raise ValueError(
f"Existing index dimension ({index.d}) does not match new embeddings ({embedding_dim})."
)
passage_meta_mode = meta.get("embedding_mode", self.embedding_mode)
passage_provider_options = meta.get("embedding_options", self.embedding_options)
base_id = index.ntotal
for offset, chunk in enumerate(valid_chunks):
new_id = str(base_id + offset)
chunk.setdefault("metadata", {})["id"] = new_id
chunk["id"] = new_id
# Append passages/offsets before we attempt index.add so the ZMQ server
# can resolve newly assigned IDs during recompute. Keep rollback hooks
# so we can restore files if the update fails mid-way.
rollback_passages_size = passages_file.stat().st_size if passages_file.exists() else 0
offset_map_backup = offset_map.copy()
try:
with open(passages_file, "a", encoding="utf-8") as f:
for chunk in valid_chunks:
offset = f.tell()
json.dump(
{
"id": chunk["id"],
"text": chunk["text"],
"metadata": chunk.get("metadata", {}),
},
f,
ensure_ascii=False,
)
f.write("\n")
offset_map[chunk["id"]] = offset
with open(offset_file, "wb") as f:
pickle.dump(offset_map, f)
server_manager: Optional[EmbeddingServerManager] = None
server_started = False
requested_zmq_port = int(os.getenv("LEANN_UPDATE_ZMQ_PORT", "5557"))
try:
if needs_recompute:
server_manager = EmbeddingServerManager(
backend_module_name="leann_backend_hnsw.hnsw_embedding_server"
)
server_started, actual_port = server_manager.start_server(
port=requested_zmq_port,
model_name=self.embedding_model,
embedding_mode=passage_meta_mode,
passages_file=str(meta_path),
distance_metric=distance_metric,
provider_options=passage_provider_options,
)
if not server_started:
raise RuntimeError(
"Failed to start HNSW embedding server for recompute update."
)
if actual_port != requested_zmq_port:
server_manager.stop_server()
raise RuntimeError(
"Embedding server started on unexpected port "
f"{actual_port}; expected {requested_zmq_port}. Make sure the desired ZMQ port is free."
)
if needs_recompute:
for i in range(embeddings.shape[0]):
print(f"add {i} embeddings")
index.add(1, faiss.swig_ptr(embeddings[i : i + 1]))
else:
index.add(embeddings.shape[0], faiss.swig_ptr(embeddings))
faiss.write_index(index, str(index_file))
finally:
if server_started and server_manager is not None:
server_manager.stop_server()
except Exception:
# Roll back appended passages/offset map to keep files consistent.
if passages_file.exists():
with open(passages_file, "rb+") as f:
f.truncate(rollback_passages_size)
offset_map = offset_map_backup
with open(offset_file, "wb") as f:
pickle.dump(offset_map, f)
raise
meta["total_passages"] = len(offset_map)
with open(meta_path, "w", encoding="utf-8") as f:
json.dump(meta, f, indent=2)
logger.info(
"Appended %d passages to index '%s'. New total: %d",
len(valid_chunks),
index_path,
len(offset_map),
)
self.chunks.clear()
if needs_recompute:
prune_hnsw_embeddings_inplace(str(index_file))
class LeannSearcher:
def __init__(self, index_path: str, enable_warmup: bool = False, **backend_kwargs):
@@ -642,6 +879,7 @@ class LeannSearcher:
self.embedding_model = self.meta_data["embedding_model"]
# Support both old and new format
self.embedding_mode = self.meta_data.get("embedding_mode", "sentence-transformers")
self.embedding_options = self.meta_data.get("embedding_options", {})
# Delegate portability handling to PassageManager
self.passage_manager = PassageManager(
self.meta_data.get("passage_sources", []), metadata_file_path=self.meta_path_str
@@ -653,6 +891,8 @@ class LeannSearcher:
raise ValueError(f"Backend '{backend_name}' not found.")
final_kwargs = {**self.meta_data.get("backend_kwargs", {}), **backend_kwargs}
final_kwargs["enable_warmup"] = enable_warmup
if self.embedding_options:
final_kwargs.setdefault("embedding_options", self.embedding_options)
self.backend_impl: LeannBackendSearcherInterface = backend_factory.searcher(
index_path, **final_kwargs
)
@@ -1032,405 +1272,8 @@ class LeannChat:
except Exception:
pass
# ------------------------------
# Incremental Add Utilities (HNSW no-recompute only)
# ------------------------------
def _resolve_index_paths(index_path: str) -> tuple[Path, Path, Path]:
"""Given base index path (without extension), return (passages.jsonl, passages.idx, vector.index).
For HNSW, vector index file is typically <stem>.index (e.g., documents.index) even when base is
'documents.leann'. We prefer an existing <stem>.index, otherwise fall back to <name>.index.
"""
base = Path(index_path)
passages_file = base.parent / f"{base.name}.passages.jsonl"
offsets_file = base.parent / f"{base.name}.passages.idx"
candidate_name_index = base.parent / f"{base.name}.index"
candidate_stem_index = base.parent / f"{base.stem}.index"
vector_index_file = (
candidate_stem_index if candidate_stem_index.exists() else candidate_name_index
)
return passages_file, offsets_file, vector_index_file
def _read_meta_file(index_path: str) -> dict[str, Any]:
meta_path = Path(f"{index_path}.meta.json")
if not meta_path.exists():
raise FileNotFoundError(f"Leann metadata file not found: {meta_path}")
with open(meta_path, encoding="utf-8") as f:
return json.load(f)
def _load_offset_map_pickle(offsets_file: Path) -> dict[str, int]:
if not offsets_file.exists():
return {}
with open(offsets_file, "rb") as f:
return pickle.load(f)
def _append_passages_and_update_offsets(
passages_file: Path, offsets_file: Path, new_texts: list[str]
) -> list[str]:
"""Append new texts to passages file, update offset map, and return assigned string IDs.
IDs are assigned as incrementing integers based on existing keys in the offset map.
"""
offset_map = _load_offset_map_pickle(offsets_file)
# Compute next numeric id
numeric_ids = [int(x) for x in offset_map.keys() if str(x).isdigit()]
next_id_num = (max(numeric_ids) + 1) if numeric_ids else 0
assigned_ids: list[str] = []
with open(passages_file, "a", encoding="utf-8") as f:
for text in new_texts:
offset = f.tell()
str_id = str(next_id_num)
json.dump({"id": str_id, "text": text, "metadata": {}}, f, ensure_ascii=False)
f.write("\n")
offset_map[str_id] = offset
assigned_ids.append(str_id)
next_id_num += 1
with open(offsets_file, "wb") as f:
pickle.dump(offset_map, f)
return assigned_ids
def incremental_add_texts(
index_path: str,
texts: list[str],
*,
embedding_model: Optional[str] = None,
embedding_mode: Optional[str] = None,
ef_construction: Optional[int] = None,
recompute: bool = False,
) -> int:
"""Incrementally add text chunks to an existing HNSW index built with no-recompute.
- Validates backend is HNSW and index is non-compact (no-recompute path)
- Appends passages and offsets
- Computes embeddings and appends to the HNSW vector index
Returns number of added chunks.
"""
if not texts:
return 0
meta = _read_meta_file(index_path)
if meta.get("backend_name") != "hnsw":
raise RuntimeError("Incremental add is currently supported only for HNSW backend")
if meta.get("is_compact", True):
raise RuntimeError(
"Index is compact/pruned. Rebuild base with is_recompute=False and is_compact=False for incremental add."
)
passages_file, offsets_file, vector_index_file = _resolve_index_paths(index_path)
if not vector_index_file.exists():
raise FileNotFoundError(
f"Vector index file missing: {vector_index_file}. Build base first with LeannBuilder."
)
# Resolve embedding config from meta if not provided
model_name = embedding_model or meta.get("embedding_model", "facebook/contriever")
mode_name = embedding_mode or meta.get("embedding_mode", "sentence-transformers")
# Append passages and update offsets
assigned_ids = _append_passages_and_update_offsets(passages_file, offsets_file, texts)
# Compute embeddings
# Embedding computation path
esm = None
port = None
if recompute:
# Determine distance metric early for server config
distance_metric = meta.get("backend_kwargs", {}).get("distance_metric", "mips").lower()
# Start embedding server and compute via ZMQ for consistency with recompute semantics
passages_source_file = f"{index_path}.meta.json"
esm = EmbeddingServerManager(
backend_module_name="leann_backend_hnsw.hnsw_embedding_server",
)
started, port = esm.start_server(
port=5557,
model_name=model_name,
embedding_mode=mode_name,
passages_file=passages_source_file,
distance_metric=distance_metric,
enable_warmup=False,
)
if not started:
raise RuntimeError("Failed to start embedding server for recompute add")
embeddings = compute_embeddings_via_server(texts, model_name, port)
else:
embeddings = compute_embeddings(
texts,
model_name=model_name,
mode=mode_name,
use_server=False,
is_build=True,
)
# Normalize for cosine if needed
if "distance_metric" not in locals():
distance_metric = meta.get("backend_kwargs", {}).get("distance_metric", "mips").lower()
if distance_metric == "cosine":
norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
norms[norms == 0] = 1
embeddings = embeddings / norms
# Append via backend helper (supports ef_construction/recompute plumbing)
try:
from leann_backend_hnsw.hnsw_backend import add_vectors as hnsw_add_vectors # type: ignore
except Exception as e:
raise RuntimeError(
"Failed to import HNSW backend add helper. Ensure HNSW backend is installed."
) from e
# Propagate ZMQ port to FAISS add path when recompute is True
if recompute and port is not None:
os.environ["LEANN_ZMQ_PORT"] = str(port)
hnsw_add_vectors(
str(vector_index_file),
embeddings,
ef_construction=ef_construction,
recompute=recompute,
)
# Stop server after add when recompute path used
if esm is not None:
def __del__(self):
try:
esm.stop_server()
self.cleanup()
except Exception:
pass
# Sanity: ids length should match embeddings rows
if len(assigned_ids) != embeddings.shape[0]:
warnings.warn(
f"Assigned {len(assigned_ids)} IDs but computed {embeddings.shape[0]} embeddings.",
UserWarning,
stacklevel=2,
)
return len(assigned_ids)
def create_incremental_add_context(
index_path: str,
*,
# Optional embedding choices; if None will use meta
embedding_model: Optional[str] = None,
embedding_mode: Optional[str] = None,
# Optional data-to-text preparation in context
data_dir: Optional[str] = None,
required_exts: Optional[list[str]] = None,
chunk_size: int = 256,
chunk_overlap: int = 128,
max_items: int = -1,
) -> IncrementalAddContext:
"""Validate index and prepare context for repeated incremental adds.
Additionally, if data_dir is provided, this function will load documents,
chunk them to texts with the specified parameters, and store them in ctx.prepared_texts.
"""
meta = _read_meta_file(index_path)
if meta.get("backend_name") != "hnsw":
raise RuntimeError("Incremental add is currently supported only for HNSW backend")
if meta.get("is_compact", True):
raise RuntimeError(
"Index is compact/pruned. Rebuild base with is_recompute=False and is_compact=False for incremental add."
)
passages_file, offsets_file, vector_index_file = _resolve_index_paths(index_path)
if not vector_index_file.exists():
raise FileNotFoundError(
f"Vector index file missing: {vector_index_file}. Build base first with LeannBuilder."
)
model_name = embedding_model or meta.get("embedding_model", "facebook/contriever")
mode_name = embedding_mode or meta.get("embedding_mode", "sentence-transformers")
distance_metric = meta.get("backend_kwargs", {}).get("distance_metric", "mips").lower()
prepared_texts: Optional[list[str]] = None
if data_dir is not None:
try:
from llama_index.core import SimpleDirectoryReader # type: ignore
from llama_index.core.node_parser import SentenceSplitter # type: ignore
except Exception as e:
raise RuntimeError(
"llama-index-core is required when using data_dir in create_incremental_add_context"
) from e
reader_kwargs: dict[str, Any] = {"recursive": True, "encoding": "utf-8"}
if required_exts:
reader_kwargs["required_exts"] = required_exts
documents = SimpleDirectoryReader(data_dir, **reader_kwargs).load_data(show_progress=True)
if documents:
splitter = SentenceSplitter(
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
separator=" ",
paragraph_separator="\n\n",
)
prepared_texts = []
for doc in documents:
try:
nodes = splitter.get_nodes_from_documents([doc])
if nodes:
prepared_texts.extend([node.get_content() for node in nodes])
except Exception:
content = doc.get_content()
if content and content.strip():
prepared_texts.append(content.strip())
if max_items > 0 and len(prepared_texts) > max_items:
prepared_texts = prepared_texts[:max_items]
return IncrementalAddContext(
index_path=index_path,
passages_file=passages_file,
offsets_file=offsets_file,
vector_index_file=vector_index_file,
embedding_model=model_name,
embedding_mode=mode_name,
distance_metric=distance_metric,
prepared_texts=prepared_texts,
)
def incremental_add_texts_with_context(
ctx: IncrementalAddContext,
texts: list[str],
*,
ef_construction: Optional[int] = None,
recompute: bool = False,
) -> int:
"""Incrementally add texts using a prepared context (no repeated validation).
For non-compact HNSW, ef_construction (efConstruction) can be overridden during insertion.
"""
if not texts:
return 0
# Append passages & offsets
_append_passages_and_update_offsets(ctx.passages_file, ctx.offsets_file, texts)
# Compute embeddings
# Embedding computation path
esm = None
port = None
if recompute:
passages_source_file = f"{ctx.index_path}.meta.json"
esm = EmbeddingServerManager(
backend_module_name="leann_backend_hnsw.hnsw_embedding_server",
)
started, port = esm.start_server(
port=5557,
model_name=ctx.embedding_model,
embedding_mode=ctx.embedding_mode,
passages_file=passages_source_file,
distance_metric=ctx.distance_metric,
enable_warmup=False,
)
if not started:
raise RuntimeError("Failed to start embedding server for recompute add")
embeddings = compute_embeddings_via_server(texts, ctx.embedding_model, port)
else:
embeddings = compute_embeddings(
texts,
model_name=ctx.embedding_model,
mode=ctx.embedding_mode,
use_server=False,
is_build=True,
)
# Normalize for cosine if needed
if ctx.distance_metric == "cosine":
norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
norms[norms == 0] = 1
embeddings = embeddings / norms
# Append via backend helper (supports ef_construction/recompute plumbing)
try:
from leann_backend_hnsw.hnsw_backend import add_vectors as hnsw_add_vectors # type: ignore
except Exception as e:
raise RuntimeError(
"Failed to import HNSW backend add helper. Ensure HNSW backend is installed."
) from e
if recompute and port is not None:
os.environ["LEANN_ZMQ_PORT"] = str(port)
hnsw_add_vectors(
str(ctx.vector_index_file),
embeddings,
ef_construction=ef_construction,
recompute=recompute,
)
# Stop server after add when recompute path used
if esm is not None:
try:
esm.stop_server()
except Exception:
pass
return embeddings.shape[0]
def incremental_add_directory(
index_path: str,
data_dir: str,
*,
chunk_size: int = 256,
chunk_overlap: int = 128,
required_exts: Optional[list[str]] = None,
max_items: int = -1,
embedding_model: Optional[str] = None,
embedding_mode: Optional[str] = None,
) -> int:
"""Load documents from a directory, chunk them, and incrementally add to an index.
Chunking uses LlamaIndex SentenceSplitter for simplicity and avoids external app dependencies.
"""
try:
from llama_index.core import SimpleDirectoryReader # type: ignore
from llama_index.core.node_parser import SentenceSplitter # type: ignore
except Exception as e:
raise RuntimeError("llama-index-core is required for incremental_add_directory") from e
reader_kwargs: dict[str, Any] = {"recursive": True, "encoding": "utf-8"}
if required_exts:
reader_kwargs["required_exts"] = required_exts
documents = SimpleDirectoryReader(data_dir, **reader_kwargs).load_data(show_progress=True)
if not documents:
return 0
# Traditional text chunking
splitter = SentenceSplitter(
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
separator=" ",
paragraph_separator="\n\n",
)
all_texts: list[str] = []
for doc in documents:
try:
nodes = splitter.get_nodes_from_documents([doc])
if nodes:
all_texts.extend([node.get_content() for node in nodes])
except Exception:
content = doc.get_content()
if content and content.strip():
all_texts.append(content.strip())
if max_items > 0 and len(all_texts) > max_items:
all_texts = all_texts[:max_items]
return incremental_add_texts(
index_path,
all_texts,
embedding_model=embedding_model,
embedding_mode=embedding_mode,
)

View File

@@ -12,6 +12,8 @@ from typing import Any, Optional
import torch
from .settings import resolve_ollama_host, resolve_openai_api_key, resolve_openai_base_url
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
@@ -310,11 +312,12 @@ def search_hf_models(query: str, limit: int = 10) -> list[str]:
def validate_model_and_suggest(
model_name: str, llm_type: str, host: str = "http://localhost:11434"
model_name: str, llm_type: str, host: Optional[str] = None
) -> Optional[str]:
"""Validate model name and provide suggestions if invalid"""
if llm_type == "ollama":
available_models = check_ollama_models(host)
resolved_host = resolve_ollama_host(host)
available_models = check_ollama_models(resolved_host)
if available_models and model_name not in available_models:
error_msg = f"Model '{model_name}' not found in your local Ollama installation."
@@ -457,19 +460,19 @@ class LLMInterface(ABC):
class OllamaChat(LLMInterface):
"""LLM interface for Ollama models."""
def __init__(self, model: str = "llama3:8b", host: str = "http://localhost:11434"):
def __init__(self, model: str = "llama3:8b", host: Optional[str] = None):
self.model = model
self.host = host
logger.info(f"Initializing OllamaChat with model='{model}' and host='{host}'")
self.host = resolve_ollama_host(host)
logger.info(f"Initializing OllamaChat with model='{model}' and host='{self.host}'")
try:
import requests
# Check if the Ollama server is responsive
if host:
requests.get(host)
if self.host:
requests.get(self.host)
# Pre-check model availability with helpful suggestions
model_error = validate_model_and_suggest(model, "ollama", host)
model_error = validate_model_and_suggest(model, "ollama", self.host)
if model_error:
raise ValueError(model_error)
@@ -478,9 +481,11 @@ class OllamaChat(LLMInterface):
"The 'requests' library is required for Ollama. Please install it with 'pip install requests'."
)
except requests.exceptions.ConnectionError:
logger.error(f"Could not connect to Ollama at {host}. Please ensure Ollama is running.")
logger.error(
f"Could not connect to Ollama at {self.host}. Please ensure Ollama is running."
)
raise ConnectionError(
f"Could not connect to Ollama at {host}. Please ensure Ollama is running."
f"Could not connect to Ollama at {self.host}. Please ensure Ollama is running."
)
def ask(self, prompt: str, **kwargs) -> str:
@@ -737,21 +742,31 @@ class GeminiChat(LLMInterface):
class OpenAIChat(LLMInterface):
"""LLM interface for OpenAI models."""
def __init__(self, model: str = "gpt-4o", api_key: Optional[str] = None):
def __init__(
self,
model: str = "gpt-4o",
api_key: Optional[str] = None,
base_url: Optional[str] = None,
):
self.model = model
self.api_key = api_key or os.getenv("OPENAI_API_KEY")
self.base_url = resolve_openai_base_url(base_url)
self.api_key = resolve_openai_api_key(api_key)
if not self.api_key:
raise ValueError(
"OpenAI API key is required. Set OPENAI_API_KEY environment variable or pass api_key parameter."
)
logger.info(f"Initializing OpenAI Chat with model='{model}'")
logger.info(
"Initializing OpenAI Chat with model='%s' and base_url='%s'",
model,
self.base_url,
)
try:
import openai
self.client = openai.OpenAI(api_key=self.api_key)
self.client = openai.OpenAI(api_key=self.api_key, base_url=self.base_url)
except ImportError:
raise ImportError(
"The 'openai' library is required for OpenAI models. Please install it with 'pip install openai'."
@@ -841,12 +856,16 @@ def get_llm(llm_config: Optional[dict[str, Any]] = None) -> LLMInterface:
if llm_type == "ollama":
return OllamaChat(
model=model or "llama3:8b",
host=llm_config.get("host", "http://localhost:11434"),
host=llm_config.get("host"),
)
elif llm_type == "hf":
return HFChat(model_name=model or "deepseek-ai/deepseek-llm-7b-chat")
elif llm_type == "openai":
return OpenAIChat(model=model or "gpt-4o", api_key=llm_config.get("api_key"))
return OpenAIChat(
model=model or "gpt-4o",
api_key=llm_config.get("api_key"),
base_url=llm_config.get("base_url"),
)
elif llm_type == "gemini":
return GeminiChat(model=model or "gemini-2.5-flash", api_key=llm_config.get("api_key"))
elif llm_type == "simulated":

View File

@@ -9,6 +9,7 @@ from tqdm import tqdm
from .api import LeannBuilder, LeannChat, LeannSearcher
from .registry import register_project_directory
from .settings import resolve_ollama_host, resolve_openai_api_key, resolve_openai_base_url
def extract_pdf_text_with_pymupdf(file_path: str) -> str:
@@ -123,6 +124,24 @@ Examples:
choices=["sentence-transformers", "openai", "mlx", "ollama"],
help="Embedding backend mode (default: sentence-transformers)",
)
build_parser.add_argument(
"--embedding-host",
type=str,
default=None,
help="Override Ollama-compatible embedding host",
)
build_parser.add_argument(
"--embedding-api-base",
type=str,
default=None,
help="Base URL for OpenAI-compatible embedding services",
)
build_parser.add_argument(
"--embedding-api-key",
type=str,
default=None,
help="API key for embedding service (defaults to OPENAI_API_KEY)",
)
build_parser.add_argument(
"--force", "-f", action="store_true", help="Force rebuild existing index"
)
@@ -238,6 +257,11 @@ Examples:
# Ask command
ask_parser = subparsers.add_parser("ask", help="Ask questions")
ask_parser.add_argument("index_name", help="Index name")
ask_parser.add_argument(
"query",
nargs="?",
help="Question to ask (omit for prompt or when using --interactive)",
)
ask_parser.add_argument(
"--llm",
type=str,
@@ -248,7 +272,12 @@ Examples:
ask_parser.add_argument(
"--model", type=str, default="qwen3:8b", help="Model name (default: qwen3:8b)"
)
ask_parser.add_argument("--host", type=str, default="http://localhost:11434")
ask_parser.add_argument(
"--host",
type=str,
default=None,
help="Override Ollama-compatible host (defaults to LEANN_OLLAMA_HOST/OLLAMA_HOST)",
)
ask_parser.add_argument(
"--interactive", "-i", action="store_true", help="Interactive chat mode"
)
@@ -277,6 +306,18 @@ Examples:
default=None,
help="Thinking budget for reasoning models (low/medium/high). Supported by GPT-Oss:20b and other reasoning models.",
)
ask_parser.add_argument(
"--api-base",
type=str,
default=None,
help="Base URL for OpenAI-compatible APIs (e.g., http://localhost:10000/v1)",
)
ask_parser.add_argument(
"--api-key",
type=str,
default=None,
help="API key for OpenAI-compatible APIs (defaults to OPENAI_API_KEY)",
)
# List command
subparsers.add_parser("list", help="List all indexes")
@@ -1325,10 +1366,20 @@ Examples:
print(f"Building index '{index_name}' with {args.backend} backend...")
embedding_options: dict[str, Any] = {}
if args.embedding_mode == "ollama":
embedding_options["host"] = resolve_ollama_host(args.embedding_host)
elif args.embedding_mode == "openai":
embedding_options["base_url"] = resolve_openai_base_url(args.embedding_api_base)
resolved_embedding_key = resolve_openai_api_key(args.embedding_api_key)
if resolved_embedding_key:
embedding_options["api_key"] = resolved_embedding_key
builder = LeannBuilder(
backend_name=args.backend,
embedding_model=args.embedding_model,
embedding_mode=args.embedding_mode,
embedding_options=embedding_options or None,
graph_degree=args.graph_degree,
complexity=args.complexity,
is_compact=args.compact,
@@ -1476,11 +1527,38 @@ Examples:
llm_config = {"type": args.llm, "model": args.model}
if args.llm == "ollama":
llm_config["host"] = args.host
llm_config["host"] = resolve_ollama_host(args.host)
elif args.llm == "openai":
llm_config["base_url"] = resolve_openai_base_url(args.api_base)
resolved_api_key = resolve_openai_api_key(args.api_key)
if resolved_api_key:
llm_config["api_key"] = resolved_api_key
chat = LeannChat(index_path=index_path, llm_config=llm_config)
llm_kwargs: dict[str, Any] = {}
if args.thinking_budget:
llm_kwargs["thinking_budget"] = args.thinking_budget
def _ask_once(prompt: str) -> None:
response = chat.ask(
prompt,
top_k=args.top_k,
complexity=args.complexity,
beam_width=args.beam_width,
prune_ratio=args.prune_ratio,
recompute_embeddings=args.recompute_embeddings,
pruning_strategy=args.pruning_strategy,
llm_kwargs=llm_kwargs,
)
print(f"LEANN: {response}")
initial_query = (args.query or "").strip()
if args.interactive:
if initial_query:
_ask_once(initial_query)
print("LEANN Assistant ready! Type 'quit' to exit")
print("=" * 40)
@@ -1493,41 +1571,14 @@ Examples:
if not user_input:
continue
# Prepare LLM kwargs with thinking budget if specified
llm_kwargs = {}
if args.thinking_budget:
llm_kwargs["thinking_budget"] = args.thinking_budget
response = chat.ask(
user_input,
top_k=args.top_k,
complexity=args.complexity,
beam_width=args.beam_width,
prune_ratio=args.prune_ratio,
recompute_embeddings=args.recompute_embeddings,
pruning_strategy=args.pruning_strategy,
llm_kwargs=llm_kwargs,
)
print(f"LEANN: {response}")
_ask_once(user_input)
else:
query = input("Enter your question: ").strip()
if query:
# Prepare LLM kwargs with thinking budget if specified
llm_kwargs = {}
if args.thinking_budget:
llm_kwargs["thinking_budget"] = args.thinking_budget
query = initial_query or input("Enter your question: ").strip()
if not query:
print("No question provided. Exiting.")
return
response = chat.ask(
query,
top_k=args.top_k,
complexity=args.complexity,
beam_width=args.beam_width,
prune_ratio=args.prune_ratio,
recompute_embeddings=args.recompute_embeddings,
pruning_strategy=args.pruning_strategy,
llm_kwargs=llm_kwargs,
)
print(f"LEANN: {response}")
_ask_once(query)
async def run(self, args=None):
parser = self.create_parser()

View File

@@ -7,11 +7,13 @@ Preserves all optimization parameters to ensure performance
import logging
import os
import time
from typing import Any
from typing import Any, Optional
import numpy as np
import torch
from .settings import resolve_ollama_host, resolve_openai_api_key, resolve_openai_base_url
# Set up logger with proper level
logger = logging.getLogger(__name__)
LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
@@ -31,6 +33,7 @@ def compute_embeddings(
adaptive_optimization: bool = True,
manual_tokenize: bool = False,
max_length: int = 512,
provider_options: Optional[dict[str, Any]] = None,
) -> np.ndarray:
"""
Unified embedding computation entry point
@@ -46,6 +49,8 @@ def compute_embeddings(
Returns:
Normalized embeddings array, shape: (len(texts), embedding_dim)
"""
provider_options = provider_options or {}
if mode == "sentence-transformers":
return compute_embeddings_sentence_transformers(
texts,
@@ -57,11 +62,21 @@ def compute_embeddings(
max_length=max_length,
)
elif mode == "openai":
return compute_embeddings_openai(texts, model_name)
return compute_embeddings_openai(
texts,
model_name,
base_url=provider_options.get("base_url"),
api_key=provider_options.get("api_key"),
)
elif mode == "mlx":
return compute_embeddings_mlx(texts, model_name)
elif mode == "ollama":
return compute_embeddings_ollama(texts, model_name, is_build=is_build)
return compute_embeddings_ollama(
texts,
model_name,
is_build=is_build,
host=provider_options.get("host"),
)
elif mode == "gemini":
return compute_embeddings_gemini(texts, model_name, is_build=is_build)
else:
@@ -353,12 +368,15 @@ def compute_embeddings_sentence_transformers(
return embeddings
def compute_embeddings_openai(texts: list[str], model_name: str) -> np.ndarray:
def compute_embeddings_openai(
texts: list[str],
model_name: str,
base_url: Optional[str] = None,
api_key: Optional[str] = None,
) -> np.ndarray:
# TODO: @yichuan-w add progress bar only in build mode
"""Compute embeddings using OpenAI API"""
try:
import os
import openai
except ImportError as e:
raise ImportError(f"OpenAI package not installed: {e}")
@@ -373,16 +391,18 @@ def compute_embeddings_openai(texts: list[str], model_name: str) -> np.ndarray:
f"Found {invalid_count} empty/invalid text(s) in input. Upstream should filter before calling OpenAI."
)
api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
resolved_base_url = resolve_openai_base_url(base_url)
resolved_api_key = resolve_openai_api_key(api_key)
if not resolved_api_key:
raise RuntimeError("OPENAI_API_KEY environment variable not set")
# Cache OpenAI client
cache_key = "openai_client"
cache_key = f"openai_client::{resolved_base_url}"
if cache_key in _model_cache:
client = _model_cache[cache_key]
else:
client = openai.OpenAI(api_key=api_key)
client = openai.OpenAI(api_key=resolved_api_key, base_url=resolved_base_url)
_model_cache[cache_key] = client
logger.info("OpenAI client cached")
@@ -507,7 +527,10 @@ def compute_embeddings_mlx(chunks: list[str], model_name: str, batch_size: int =
def compute_embeddings_ollama(
texts: list[str], model_name: str, is_build: bool = False, host: str = "http://localhost:11434"
texts: list[str],
model_name: str,
is_build: bool = False,
host: Optional[str] = None,
) -> np.ndarray:
"""
Compute embeddings using Ollama API with simplified batch processing.
@@ -518,7 +541,7 @@ def compute_embeddings_ollama(
texts: List of texts to compute embeddings for
model_name: Ollama model name (e.g., "nomic-embed-text", "mxbai-embed-large")
is_build: Whether this is a build operation (shows progress bar)
host: Ollama host URL (default: http://localhost:11434)
host: Ollama host URL (defaults to environment or http://localhost:11434)
Returns:
Normalized embeddings array, shape: (len(texts), embedding_dim)
@@ -533,17 +556,19 @@ def compute_embeddings_ollama(
if not texts:
raise ValueError("Cannot compute embeddings for empty text list")
resolved_host = resolve_ollama_host(host)
logger.info(
f"Computing embeddings for {len(texts)} texts using Ollama API, model: '{model_name}'"
f"Computing embeddings for {len(texts)} texts using Ollama API, model: '{model_name}', host: '{resolved_host}'"
)
# Check if Ollama is running
try:
response = requests.get(f"{host}/api/version", timeout=5)
response = requests.get(f"{resolved_host}/api/version", timeout=5)
response.raise_for_status()
except requests.exceptions.ConnectionError:
error_msg = (
f"❌ Could not connect to Ollama at {host}.\n\n"
f"❌ Could not connect to Ollama at {resolved_host}.\n\n"
"Please ensure Ollama is running:\n"
" • macOS/Linux: ollama serve\n"
" • Windows: Make sure Ollama is running in the system tray\n\n"
@@ -555,7 +580,7 @@ def compute_embeddings_ollama(
# Check if model exists and provide helpful suggestions
try:
response = requests.get(f"{host}/api/tags", timeout=5)
response = requests.get(f"{resolved_host}/api/tags", timeout=5)
response.raise_for_status()
models = response.json()
model_names = [model["name"] for model in models.get("models", [])]
@@ -618,7 +643,9 @@ def compute_embeddings_ollama(
# Verify the model supports embeddings by testing it
try:
test_response = requests.post(
f"{host}/api/embeddings", json={"model": model_name, "prompt": "test"}, timeout=10
f"{resolved_host}/api/embeddings",
json={"model": model_name, "prompt": "test"},
timeout=10,
)
if test_response.status_code != 200:
error_msg = (
@@ -665,7 +692,7 @@ def compute_embeddings_ollama(
while retry_count < max_retries:
try:
response = requests.post(
f"{host}/api/embeddings",
f"{resolved_host}/api/embeddings",
json={"model": model_name, "prompt": truncated_text},
timeout=30,
)

View File

@@ -1,4 +1,5 @@
import atexit
import json
import logging
import os
import socket
@@ -8,6 +9,8 @@ import time
from pathlib import Path
from typing import Optional
from .settings import encode_provider_options
# Lightweight, self-contained server manager with no cross-process inspection
# Set up logging based on environment variable
@@ -46,6 +49,85 @@ def _check_port(port: int) -> bool:
# Note: All cross-process scanning helpers removed for simplicity
def _safe_resolve(path: Path) -> str:
"""Resolve paths safely even if the target does not yet exist."""
try:
return str(path.resolve(strict=False))
except Exception:
return str(path)
def _safe_stat_signature(path: Path) -> dict:
"""Return a lightweight signature describing the current state of a path."""
signature: dict[str, object] = {"path": _safe_resolve(path)}
try:
stat = path.stat()
except FileNotFoundError:
signature["missing"] = True
except Exception as exc: # pragma: no cover - unexpected filesystem errors
signature["error"] = str(exc)
else:
signature["mtime_ns"] = stat.st_mtime_ns
signature["size"] = stat.st_size
return signature
def _build_passages_signature(passages_file: Optional[str]) -> Optional[dict]:
"""Collect modification signatures for metadata and referenced passage files."""
if not passages_file:
return None
meta_path = Path(passages_file)
signature: dict[str, object] = {"meta": _safe_stat_signature(meta_path)}
try:
with meta_path.open(encoding="utf-8") as fh:
meta = json.load(fh)
except FileNotFoundError:
signature["meta_missing"] = True
signature["sources"] = []
return signature
except json.JSONDecodeError as exc:
signature["meta_error"] = f"json_error:{exc}"
signature["sources"] = []
return signature
except Exception as exc: # pragma: no cover - unexpected errors
signature["meta_error"] = str(exc)
signature["sources"] = []
return signature
base_dir = meta_path.parent
seen_paths: set[str] = set()
source_signatures: list[dict[str, object]] = []
for source in meta.get("passage_sources", []):
for key, kind in (
("path", "passages"),
("path_relative", "passages"),
("index_path", "index"),
("index_path_relative", "index"),
):
raw_path = source.get(key)
if not raw_path:
continue
candidate = Path(raw_path)
if not candidate.is_absolute():
candidate = base_dir / candidate
resolved = _safe_resolve(candidate)
if resolved in seen_paths:
continue
seen_paths.add(resolved)
sig = _safe_stat_signature(candidate)
sig["kind"] = kind
source_signatures.append(sig)
signature["sources"] = source_signatures
return signature
# Note: All cross-process scanning helpers removed for simplicity
class EmbeddingServerManager:
"""
A simplified manager for embedding server processes that avoids complex update mechanisms.
@@ -82,16 +164,42 @@ class EmbeddingServerManager:
) -> tuple[bool, int]:
"""Start the embedding server."""
# passages_file may be present in kwargs for server CLI, but we don't need it here
provider_options = kwargs.pop("provider_options", None)
passages_file = kwargs.get("passages_file", "")
config_signature = self._build_config_signature(
model_name=model_name,
embedding_mode=embedding_mode,
provider_options=provider_options,
passages_file=passages_file,
)
# If this manager already has a live server, just reuse it
if self.server_process and self.server_process.poll() is None and self.server_port:
if (
self.server_process
and self.server_process.poll() is None
and self.server_port
and self._server_config == config_signature
):
logger.info("Reusing in-process server")
return True, self.server_port
# Configuration changed, stop existing server before starting a new one
if self.server_process and self.server_process.poll() is None:
logger.info("Existing server configuration differs; restarting embedding server")
self.stop_server()
# For Colab environment, use a different strategy
if _is_colab_environment():
logger.info("Detected Colab environment, using alternative startup strategy")
return self._start_server_colab(port, model_name, embedding_mode, **kwargs)
return self._start_server_colab(
port,
model_name,
embedding_mode,
config_signature=config_signature,
provider_options=provider_options,
**kwargs,
)
# Always pick a fresh available port
try:
@@ -101,13 +209,40 @@ class EmbeddingServerManager:
return False, port
# Start a new server
return self._start_new_server(actual_port, model_name, embedding_mode, **kwargs)
return self._start_new_server(
actual_port,
model_name,
embedding_mode,
provider_options=provider_options,
config_signature=config_signature,
**kwargs,
)
def _build_config_signature(
self,
*,
model_name: str,
embedding_mode: str,
provider_options: Optional[dict],
passages_file: Optional[str],
) -> dict:
"""Create a signature describing the current server configuration."""
return {
"model_name": model_name,
"passages_file": passages_file or "",
"embedding_mode": embedding_mode,
"provider_options": provider_options or {},
"passages_signature": _build_passages_signature(passages_file),
}
def _start_server_colab(
self,
port: int,
model_name: str,
embedding_mode: str = "sentence-transformers",
*,
config_signature: Optional[dict] = None,
provider_options: Optional[dict] = None,
**kwargs,
) -> tuple[bool, int]:
"""Start server with Colab-specific configuration."""
@@ -125,8 +260,21 @@ class EmbeddingServerManager:
try:
# In Colab, we'll use a more direct approach
self._launch_server_process_colab(command, actual_port)
return self._wait_for_server_ready_colab(actual_port)
self._launch_server_process_colab(
command,
actual_port,
provider_options=provider_options,
config_signature=config_signature,
)
started, ready_port = self._wait_for_server_ready_colab(actual_port)
if started:
self._server_config = config_signature or {
"model_name": model_name,
"passages_file": kwargs.get("passages_file", ""),
"embedding_mode": embedding_mode,
"provider_options": provider_options or {},
}
return started, ready_port
except Exception as e:
logger.error(f"Failed to start embedding server in Colab: {e}")
return False, actual_port
@@ -134,7 +282,13 @@ class EmbeddingServerManager:
# Note: No compatibility check needed; manager is per-searcher and configs are stable per instance
def _start_new_server(
self, port: int, model_name: str, embedding_mode: str, **kwargs
self,
port: int,
model_name: str,
embedding_mode: str,
provider_options: Optional[dict] = None,
config_signature: Optional[dict] = None,
**kwargs,
) -> tuple[bool, int]:
"""Start a new embedding server on the given port."""
logger.info(f"Starting embedding server on port {port}...")
@@ -142,8 +296,21 @@ class EmbeddingServerManager:
command = self._build_server_command(port, model_name, embedding_mode, **kwargs)
try:
self._launch_server_process(command, port)
return self._wait_for_server_ready(port)
self._launch_server_process(
command,
port,
provider_options=provider_options,
config_signature=config_signature,
)
started, ready_port = self._wait_for_server_ready(port)
if started:
self._server_config = config_signature or {
"model_name": model_name,
"passages_file": kwargs.get("passages_file", ""),
"embedding_mode": embedding_mode,
"provider_options": provider_options or {},
}
return started, ready_port
except Exception as e:
logger.error(f"Failed to start embedding server: {e}")
return False, port
@@ -173,7 +340,14 @@ class EmbeddingServerManager:
return command
def _launch_server_process(self, command: list, port: int) -> None:
def _launch_server_process(
self,
command: list,
port: int,
*,
provider_options: Optional[dict] = None,
config_signature: Optional[dict] = None,
) -> None:
"""Launch the server process."""
project_root = Path(__file__).parent.parent.parent.parent.parent
logger.info(f"Command: {' '.join(command)}")
@@ -193,32 +367,43 @@ class EmbeddingServerManager:
# Start embedding server subprocess
logger.info(f"Starting server process with command: {' '.join(command)}")
env = os.environ.copy()
encoded_options = encode_provider_options(provider_options)
if encoded_options:
env["LEANN_EMBEDDING_OPTIONS"] = encoded_options
self.server_process = subprocess.Popen(
command,
cwd=project_root,
stdout=stdout_target,
stderr=stderr_target,
env=env,
)
self.server_port = port
# Record config for in-process reuse
try:
self._server_config = {
"model_name": command[command.index("--model-name") + 1]
if "--model-name" in command
else "",
"passages_file": command[command.index("--passages-file") + 1]
if "--passages-file" in command
else "",
"embedding_mode": command[command.index("--embedding-mode") + 1]
if "--embedding-mode" in command
else "sentence-transformers",
}
except Exception:
self._server_config = {
"model_name": "",
"passages_file": "",
"embedding_mode": "sentence-transformers",
}
# Record config for in-process reuse (best effort; refined later when ready)
if config_signature is not None:
self._server_config = config_signature
else: # Fallback for unexpected code paths
try:
self._server_config = {
"model_name": command[command.index("--model-name") + 1]
if "--model-name" in command
else "",
"passages_file": command[command.index("--passages-file") + 1]
if "--passages-file" in command
else "",
"embedding_mode": command[command.index("--embedding-mode") + 1]
if "--embedding-mode" in command
else "sentence-transformers",
"provider_options": provider_options or {},
}
except Exception:
self._server_config = {
"model_name": "",
"passages_file": "",
"embedding_mode": "sentence-transformers",
"provider_options": provider_options or {},
}
logger.info(f"Server process started with PID: {self.server_process.pid}")
# Register atexit callback only when we actually start a process
@@ -322,16 +507,29 @@ class EmbeddingServerManager:
# Removed: cross-process adoption no longer supported
return
def _launch_server_process_colab(self, command: list, port: int) -> None:
def _launch_server_process_colab(
self,
command: list,
port: int,
*,
provider_options: Optional[dict] = None,
config_signature: Optional[dict] = None,
) -> None:
"""Launch the server process with Colab-specific settings."""
logger.info(f"Colab Command: {' '.join(command)}")
# In Colab, we need to be more careful about process management
env = os.environ.copy()
encoded_options = encode_provider_options(provider_options)
if encoded_options:
env["LEANN_EMBEDDING_OPTIONS"] = encoded_options
self.server_process = subprocess.Popen(
command,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
env=env,
)
self.server_port = port
logger.info(f"Colab server process started with PID: {self.server_process.pid}")
@@ -341,11 +539,15 @@ class EmbeddingServerManager:
atexit.register(self._finalize_process)
self._atexit_registered = True
# Record config for in-process reuse is best-effort in Colab mode
self._server_config = {
"model_name": "",
"passages_file": "",
"embedding_mode": "sentence-transformers",
}
if config_signature is not None:
self._server_config = config_signature
else:
self._server_config = {
"model_name": "",
"passages_file": "",
"embedding_mode": "sentence-transformers",
"provider_options": provider_options or {},
}
def _wait_for_server_ready_colab(self, port: int) -> tuple[bool, int]:
"""Wait for the server to be ready with Colab-specific timeout."""

View File

@@ -41,6 +41,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
print("WARNING: embedding_model not found in meta.json. Recompute will fail.")
self.embedding_mode = self.meta.get("embedding_mode", "sentence-transformers")
self.embedding_options = self.meta.get("embedding_options", {})
self.embedding_server_manager = EmbeddingServerManager(
backend_module_name=backend_module_name,
@@ -77,6 +78,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
passages_file=passages_source_file,
distance_metric=distance_metric,
enable_warmup=kwargs.get("enable_warmup", False),
provider_options=self.embedding_options,
)
if not server_started:
raise RuntimeError(f"Failed to start embedding server on port {actual_port}")
@@ -125,7 +127,12 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
from .embedding_compute import compute_embeddings
embedding_mode = self.meta.get("embedding_mode", "sentence-transformers")
return compute_embeddings([query], self.embedding_model, embedding_mode)
return compute_embeddings(
[query],
self.embedding_model,
embedding_mode,
provider_options=self.embedding_options,
)
def _compute_embedding_via_server(self, chunks: list, zmq_port: int) -> np.ndarray:
"""Compute embeddings using the ZMQ embedding server."""

View File

@@ -0,0 +1,74 @@
"""Runtime configuration helpers for LEANN."""
from __future__ import annotations
import json
import os
from typing import Any
# Default fallbacks to preserve current behaviour while keeping them in one place.
_DEFAULT_OLLAMA_HOST = "http://localhost:11434"
_DEFAULT_OPENAI_BASE_URL = "https://api.openai.com/v1"
def _clean_url(value: str) -> str:
"""Normalize URL strings by stripping trailing slashes."""
return value.rstrip("/") if value else value
def resolve_ollama_host(explicit: str | None = None) -> str:
"""Resolve the Ollama-compatible endpoint to use."""
candidates = (
explicit,
os.getenv("LEANN_LOCAL_LLM_HOST"),
os.getenv("LEANN_OLLAMA_HOST"),
os.getenv("OLLAMA_HOST"),
os.getenv("LOCAL_LLM_ENDPOINT"),
)
for candidate in candidates:
if candidate:
return _clean_url(candidate)
return _clean_url(_DEFAULT_OLLAMA_HOST)
def resolve_openai_base_url(explicit: str | None = None) -> str:
"""Resolve the base URL for OpenAI-compatible services."""
candidates = (
explicit,
os.getenv("LEANN_OPENAI_BASE_URL"),
os.getenv("OPENAI_BASE_URL"),
os.getenv("LOCAL_OPENAI_BASE_URL"),
)
for candidate in candidates:
if candidate:
return _clean_url(candidate)
return _clean_url(_DEFAULT_OPENAI_BASE_URL)
def resolve_openai_api_key(explicit: str | None = None) -> str | None:
"""Resolve the API key for OpenAI-compatible services."""
if explicit:
return explicit
return os.getenv("OPENAI_API_KEY")
def encode_provider_options(options: dict[str, Any] | None) -> str | None:
"""Serialize provider options for child processes."""
if not options:
return None
try:
return json.dumps(options)
except (TypeError, ValueError):
# Fall back to empty payload if serialization fails
return None

View File

@@ -53,27 +53,10 @@ dependencies = [
"tree-sitter-java>=0.20.0",
"tree-sitter-c-sharp>=0.20.0",
"tree-sitter-typescript>=0.20.0",
"torchvision>=0.23.0",
]
[project.optional-dependencies]
dev = [
"pytest>=7.0",
"pytest-cov>=4.0",
"pytest-xdist>=3.0", # For parallel test execution
"black>=23.0",
"ruff==0.12.7", # Fixed version to ensure consistent formatting across all environments
"matplotlib",
"huggingface-hub>=0.20.0",
"pre-commit>=3.5.0",
]
test = [
"pytest>=7.0",
"pytest-timeout>=2.0",
"llama-index-core>=0.12.0",
"python-dotenv>=1.0.0",
]
diskann = [
"leann-backend-diskann",
]
@@ -101,10 +84,36 @@ leann-backend-diskann = { path = "packages/leann-backend-diskann", editable = tr
leann-backend-hnsw = { path = "packages/leann-backend-hnsw", editable = true }
astchunk = { path = "packages/astchunk-leann", editable = true }
[dependency-groups]
# Minimal lint toolchain for CI and local hooks
lint = [
"pre-commit>=3.5.0",
"ruff==0.12.7", # Fixed version to ensure consistent formatting across all environments
]
# Test toolchain (no heavy project runtime deps)
test = [
"pytest>=7.0",
"pytest-cov>=4.0",
"pytest-xdist>=3.0",
"pytest-timeout>=2.0",
"python-dotenv>=1.0.0",
]
# dependencies by apps/ should list here
dev = [
"matplotlib",
"huggingface-hub>=0.20.0",
]
[tool.ruff]
target-version = "py39"
line-length = 100
extend-exclude = ["third_party"]
extend-exclude = [
"third_party",
"apps/multimodal/vision-based-pdf-multi-vector/multi-vector-leann-paper-example.py",
"apps/multimodal/vision-based-pdf-multi-vector/multi-vector-leann-similarity-map.py"
]
[tool.ruff.lint]

121
scripts/hf_upload.py Normal file
View File

@@ -0,0 +1,121 @@
#!/usr/bin/env python3
"""
Upload local evaluation data to Hugging Face Hub, excluding diskann_rpj_wiki.
Defaults:
- repo_id: LEANN-RAG/leann-rag-evaluation-data (dataset)
- folder_path: benchmarks/data
- ignore_patterns: diskann_rpj_wiki/** and .cache/**
Requires authentication via `huggingface-cli login` or HF_TOKEN env var.
"""
from __future__ import annotations
import argparse
import os
try:
from huggingface_hub import HfApi
except Exception as e:
raise SystemExit(
"huggingface_hub is required. Install with: pip install huggingface_hub hf_transfer"
) from e
def _enable_transfer_accel_if_available() -> None:
"""Best-effort enabling of accelerated transfers across hub versions.
Tries the public util if present; otherwise, falls back to env flag when
hf_transfer is installed. Silently no-ops if unavailable.
"""
try:
# Newer huggingface_hub exposes this under utils
from huggingface_hub.utils import hf_hub_enable_hf_transfer # type: ignore
hf_hub_enable_hf_transfer()
return
except Exception:
pass
try:
# If hf_transfer is installed, set env flag recognized by the hub
import hf_transfer # noqa: F401
os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
except Exception:
# Acceleration not available; proceed without it
pass
def parse_args() -> argparse.Namespace:
p = argparse.ArgumentParser(description="Upload local data to HF, excluding diskann_rpj_wiki")
p.add_argument(
"--repo-id",
default="LEANN-RAG/leann-rag-evaluation-data",
help="Target dataset repo id (namespace/name)",
)
p.add_argument(
"--folder-path",
default="benchmarks/data",
help="Local folder to upload (default: benchmarks/data)",
)
p.add_argument(
"--ignore",
default=["diskann_rpj_wiki/**", ".cache/**"],
nargs="+",
help="Glob patterns to ignore (space-separated)",
)
p.add_argument(
"--allow",
default=["**"],
nargs="+",
help="Glob patterns to allow (space-separated). Defaults to everything.",
)
p.add_argument(
"--message",
default="sync local data (exclude diskann_rpj_wiki)",
help="Commit message",
)
p.add_argument(
"--no-transfer-accel",
action="store_true",
help="Disable hf_transfer accelerated uploads",
)
return p.parse_args()
def main() -> None:
args = parse_args()
if not args.no_transfer_accel:
_enable_transfer_accel_if_available()
if not os.path.isdir(args.folder_path):
raise SystemExit(f"Folder not found: {args.folder_path}")
print("Uploading to Hugging Face Hub:")
print(f" repo_id: {args.repo_id}")
print(" repo_type: dataset")
print(f" folder_path: {args.folder_path}")
print(f" allow_patterns: {args.allow}")
print(f" ignore_patterns:{args.ignore}")
api = HfApi()
# Perform upload. This skips unchanged files by content hash.
api.upload_folder(
repo_id=args.repo_id,
repo_type="dataset",
folder_path=args.folder_path,
path_in_repo=".",
allow_patterns=args.allow,
ignore_patterns=args.ignore,
commit_message=args.message,
)
print("Upload completed (unchanged files were skipped by the Hub).")
if __name__ == "__main__":
main()

View File

@@ -40,8 +40,8 @@ Tests DiskANN graph partitioning functionality:
### Install test dependencies:
```bash
# Using extras
uv pip install -e ".[test]"
# Using uv dependency groups (tools only)
uv sync --only-group test
```
### Run all tests:

14
tests/test_cli_ask.py Normal file
View File

@@ -0,0 +1,14 @@
from leann.cli import LeannCLI
def test_cli_ask_accepts_positional_query(tmp_path, monkeypatch):
monkeypatch.chdir(tmp_path)
cli = LeannCLI()
parser = cli.create_parser()
args = parser.parse_args(["ask", "my-docs", "Where are prompts configured?"])
assert args.command == "ask"
assert args.index_name == "my-docs"
assert args.query == "Where are prompts configured?"

View File

@@ -0,0 +1,137 @@
import json
import time
import pytest
from leann.embedding_server_manager import EmbeddingServerManager
class DummyProcess:
def __init__(self):
self.pid = 12345
self._terminated = False
def poll(self):
return 0 if self._terminated else None
def terminate(self):
self._terminated = True
def kill(self):
self._terminated = True
def wait(self, timeout=None):
self._terminated = True
return 0
@pytest.fixture
def embedding_manager(monkeypatch):
manager = EmbeddingServerManager("leann_backend_hnsw.hnsw_embedding_server")
def fake_get_available_port(start_port):
return start_port
monkeypatch.setattr(
"leann.embedding_server_manager._get_available_port",
fake_get_available_port,
)
start_calls = []
def fake_start_new_server(self, port, model_name, embedding_mode, **kwargs):
config_signature = kwargs.get("config_signature")
start_calls.append(config_signature)
self.server_process = DummyProcess()
self.server_port = port
self._server_config = config_signature
return True, port
monkeypatch.setattr(
EmbeddingServerManager,
"_start_new_server",
fake_start_new_server,
)
# Ensure stop_server doesn't try to operate on real subprocesses
def fake_stop_server(self):
self.server_process = None
self.server_port = None
self._server_config = None
monkeypatch.setattr(EmbeddingServerManager, "stop_server", fake_stop_server)
return manager, start_calls
def _write_meta(meta_path, passages_name, index_name, total):
meta_path.write_text(
json.dumps(
{
"backend_name": "hnsw",
"embedding_model": "test-model",
"embedding_mode": "sentence-transformers",
"dimensions": 3,
"backend_kwargs": {},
"passage_sources": [
{
"type": "jsonl",
"path": passages_name,
"index_path": index_name,
}
],
"total_passages": total,
}
),
encoding="utf-8",
)
def test_server_restarts_when_metadata_changes(tmp_path, embedding_manager):
manager, start_calls = embedding_manager
meta_path = tmp_path / "example.meta.json"
passages_path = tmp_path / "example.passages.jsonl"
index_path = tmp_path / "example.passages.idx"
passages_path.write_text("first\n", encoding="utf-8")
index_path.write_bytes(b"index")
_write_meta(meta_path, passages_path.name, index_path.name, total=1)
# Initial start populates signature
ok, port = manager.start_server(
port=6000,
model_name="test-model",
passages_file=str(meta_path),
)
assert ok
assert port == 6000
assert len(start_calls) == 1
initial_signature = start_calls[0]["passages_signature"]
# No metadata change => reuse existing server
ok, port_again = manager.start_server(
port=6000,
model_name="test-model",
passages_file=str(meta_path),
)
assert ok
assert port_again == 6000
assert len(start_calls) == 1
# Modify passage data and metadata to force signature change
time.sleep(0.01) # Ensure filesystem timestamps move forward
passages_path.write_text("second\n", encoding="utf-8")
_write_meta(meta_path, passages_path.name, index_path.name, total=2)
ok, port_third = manager.start_server(
port=6000,
model_name="test-model",
passages_file=str(meta_path),
)
assert ok
assert port_third == 6000
assert len(start_calls) == 2
updated_signature = start_calls[1]["passages_signature"]
assert updated_signature != initial_signature

7886
uv.lock generated
View File

File diff suppressed because it is too large Load Diff