Compare commits


11 Commits

Author SHA1 Message Date
Andy Lee
fd5c052bd8 Update faiss for batch distances calc & caching when updating 2025-09-30 12:40:05 -07:00
Andy Lee
2f77d0185c Merge remote-tracking branch 'origin/main' into fix-update 2025-09-30 00:56:27 -07:00
Andy Lee
82d536b2ae fix: launch embedding server before adding 2025-09-30 00:53:22 -07:00
yichuan520030910320
e2b37914ce add dynamic add test 2025-09-30 00:48:46 -07:00
Andy Lee
e588100674 fix: set ntotal for storage as well (#129) 2025-09-29 20:43:16 -07:00
Andy Lee
f42e086383 fix: set ntotal for storage as well 2025-09-29 19:10:09 -07:00
Andy Lee
fecee94af1 Experiments (#68)
* feat: finance bench

* docs: results

* chore: ignroe data README

* feat: fix financebench

* feat: laion, also required idmaps support

* style: format

* style: format

* fix: resolve ruff linting errors

- Remove unused variables in benchmark scripts
- Rename unused loop variables to follow convention

* feat: enron email bench

* experiments for running DiskANN & BM25 on Arch 4090

* style: format

* chore(ci): remove paru-bin submodule and config to fix checkout --recurse-submodules

* docs: data

* docs: data updated

* fix: as package

* fix(ci): only run pre-commit

* chore: use http url of astchunk; use group for some dev deps

* fix(ci): should checkout modules as well since `uv sync` checks

* fix(ci): run with lint only

* fix: find links to install wheels available

* CI: force local wheels in uv install step

* CI: install local wheels via file paths

* CI: pick wheels matching current Python tag

* CI: handle python tag mismatches for local wheels

* CI: use matrix python venv and set macOS deployment target

* CI: revert install step to match main

* CI: use uv group install with local wheel selection

* CI: rely on setup-uv for Python and tighten group install

* CI: install build deps with uv python interpreter

* CI: use temporary uv venv for build deps

* CI: add build venv scripts path for wheel repair
2025-09-24 11:19:04 -07:00
yichuan520030910320
01475c10a0 add img 2025-09-23 23:25:05 -07:00
yichuan520030910320
c8aa063f48 merge main 2025-09-23 23:21:53 -07:00
yichuan520030910320
576beb13db add doc about multimodal 2025-09-23 23:21:03 -07:00
Andy Lee
63c7b0c8a3 Fix restart embedding server when passages change (#117)
* fix: restart embedding server when passages change

* fix: restore python 3.9 typing compatibility
2025-09-23 22:28:36 -07:00
13 changed files with 593 additions and 504 deletions

.gitignore
View File

@@ -99,3 +99,9 @@ benchmarks/data/
## multi vector ## multi vector
apps/multimodal/vision-based-pdf-multi-vector/multi-vector-colpali-native-weaviate.py apps/multimodal/vision-based-pdf-multi-vector/multi-vector-colpali-native-weaviate.py
# Ignore all PDFs (keep data exceptions above) and do not track demo PDFs
# If you need to commit a specific demo PDF, remove this negation locally.
# The following line used to force-add a large demo PDF; remove it to satisfy pre-commit:
# !apps/multimodal/vision-based-pdf-multi-vector/pdfs/2004.12832v2.pdf
!apps/multimodal/vision-based-pdf-multi-vector/fig/*

View File

@@ -0,0 +1,113 @@
## Vision-based PDF Multi-Vector Demos (macOS/MPS)
This folder contains two demos to index PDF pages as images and run multi-vector retrieval with ColPali/ColQwen2, plus optional similarity map visualization and answer generation.
### What you'll run
- `multi-vector-leann-paper-example.py`: local PDF → pages → embed → build HNSW index → search.
- `multi-vector-leann-similarity-map.py`: HF dataset (default) or local pages → embed → index → retrieve → similarity maps → optional Qwen-VL answer.
## Prerequisites (macOS)
### 1) Homebrew poppler (for pdf2image)
```bash
brew install poppler
which pdfinfo && pdfinfo -v
```
### 2) Python environment
Use uv (recommended) or pip; a plain-pip equivalent is shown after the uv command below. Requires Python 3.9+.
Using uv:
```bash
uv pip install \
colpali_engine \
pdf2image \
pillow \
matplotlib qwen_vl_utils \
einops \
seaborn
```
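If you prefer plain pip, an equivalent install looks like this (same packages as above; creating a virtual environment first is an optional step assumed here):
```bash
# Optional: create and activate a virtual environment first
python -m venv .venv && source .venv/bin/activate
pip install colpali_engine pdf2image pillow matplotlib qwen_vl_utils einops seaborn
```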
Notes:
- On first run, models are downloaded from Hugging Face. Log in or configure your Hugging Face credentials if needed.
- The scripts auto-select the device in the order CUDA > MPS > CPU (a sketch of the selection logic follows the check below). Verify that MPS is available:
```bash
python -c "import torch; print('MPS available:', bool(getattr(torch.backends, 'mps', None) and torch.backends.mps.is_available()))"
```
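The selection order is roughly the following (a minimal sketch assuming standard PyTorch APIs, not the exact code from the scripts):
```python
import torch

# Pick the best available device in the order the demos use: CUDA, then MPS, then CPU.
def pick_device() -> str:
    if torch.cuda.is_available():
        return "cuda"
    mps = getattr(torch.backends, "mps", None)
    if mps is not None and mps.is_available():
        return "mps"
    return "cpu"

print(f"Using device={pick_device()}")
```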
## Run the demos
### A) Local PDF example
Converts a local PDF into page images, embeds them, builds an index, and searches.
```bash
cd apps/multimodal/vision-based-pdf-multi-vector
# If you don't have the sample PDF locally, download it (ignored by Git)
mkdir -p pdfs
curl -L -o pdfs/2004.12832v2.pdf https://arxiv.org/pdf/2004.12832.pdf
ls pdfs/2004.12832v2.pdf
# Ensure output dir exists
mkdir -p pages
python multi-vector-leann-paper-example.py
```
Expected:
- Page images in `pages/`.
- Console output such as `Using device=mps, dtype=...`, followed by the retrieved page paths for each query.
To use your own PDF, edit `pdf_path` near the top of the script.
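For example (the path below is illustrative):
```python
# In multi-vector-leann-paper-example.py, near the top; the filename is illustrative
pdf_path = "pdfs/my_paper.pdf"
```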
### B) Similarity map + answer demo
Uses HF dataset `weaviate/arXiv-AI-papers-multi-vector` by default; can switch to local pages.
```bash
cd apps/multimodal/vision-based-pdf-multi-vector
python multi-vector-leann-similarity-map.py
```
Artifacts (when enabled):
- Retrieved pages: `./figures/retrieved_page_rank{K}.png`
- Similarity maps: `./figures/similarity_map_rank{K}.png`
Key knobs at the top of the script (an illustrative configuration is sketched after this list):
- `QUERY`: your question
- `MODEL`: `"colqwen2"` or `"colpali"`
- `USE_HF_DATASET`: set `False` to use local pages
- `PDF`, `PAGES_DIR`: for local mode
- `INDEX_PATH`, `TOPK`, `FIRST_STAGE_K`, `REBUILD_INDEX`
- `SIMILARITY_MAP`, `SIM_TOKEN_IDX`, `SIM_OUTPUT`
- `ANSWER`, `MAX_NEW_TOKENS` (Qwen-VL)
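A configuration sketch with illustrative values (the knob names come from the list above; every value here is an example rather than the script's default):
```python
# Illustrative values only; the names match the script, the values are examples.
QUERY = "How does DeepSeek-V2 compare against the LLaMA family of LLMs?"
MODEL = "colqwen2"             # or "colpali"
USE_HF_DATASET = True          # False => use local PDF pages
PDF = "pdfs/2004.12832v2.pdf"  # local mode only
PAGES_DIR = "./pages"          # local mode only
INDEX_PATH = "./indexes/demo"
TOPK = 1
FIRST_STAGE_K = 20             # example value
REBUILD_INDEX = False
SIMILARITY_MAP = True
SIM_TOKEN_IDX = -1             # -1 auto-selects the most salient token
SIM_OUTPUT = "./figures/similarity_map.png"
ANSWER = False                 # set True to generate an answer with Qwen-VL
MAX_NEW_TOKENS = 128           # example value
```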
## Troubleshooting
- pdf2image errors on macOS: ensure `brew install poppler` and `pdfinfo` works in terminal.
- Slow or OOM on MPS: reduce dataset size (e.g., set `MAX_DOCS`) or switch to CPU.
- NaNs on MPS: keep fp32 on MPS (default in similarity-map script); avoid fp16 there.
- First-run model downloads can be large; ensure network access (HF mirrors if needed).
## Notes
- Index files are under `./indexes/`. Delete or set `REBUILD_INDEX=True` to rebuild.
- For local PDFs, page images go to `./pages/`.
### Retrieval and Visualization Example
Example settings in `multi-vector-leann-similarity-map.py`:
- `QUERY = "How does DeepSeek-V2 compare against the LLaMA family of LLMs?"`
- `SIMILARITY_MAP = True` (to generate heatmaps)
- `TOPK = 1` (save the top retrieved page and its similarity map)
Run:
```bash
cd apps/multimodal/vision-based-pdf-multi-vector
python multi-vector-leann-similarity-map.py
```
Outputs (by default):
- Retrieved page: `./figures/retrieved_page_rank1.png`
- Similarity map: `./figures/similarity_map_rank1.png`
Sample visualization (example result for the query "How does Vim model performance and efficiency compared to other models?"):
![Similarity map example](fig/image.png)
Notes:
- Set `SIM_TOKEN_IDX` to visualize a specific token index; set `-1` to auto-select the most salient token.
- If you change `SIM_OUTPUT` to a file path (e.g., `./figures/my_map.png`), multiple ranks are saved as `my_map_rank{K}.png`.

View File

Binary file not shown (new image, 166 KiB).

View File

@@ -4,39 +4,24 @@
# pip install tqdm # pip install tqdm
# pip install pillow # pip install pillow
# %%
from pdf2image import convert_from_path
pdf_path = "pdfs/2004.12832v2.pdf"
images = convert_from_path(pdf_path)
for i, image in enumerate(images):
image.save(f"pages/page_{i + 1}.png", "PNG")
# %%
import os import os
import re
import sys
from pathlib import Path from pathlib import Path
from typing import cast
# Make local leann packages importable without installing from PIL import Image
from tqdm import tqdm
# Ensure local leann packages are importable before importing them
_repo_root = Path(__file__).resolve().parents[3] _repo_root = Path(__file__).resolve().parents[3]
_leann_core_src = _repo_root / "packages" / "leann-core" / "src" _leann_core_src = _repo_root / "packages" / "leann-core" / "src"
_leann_hnsw_pkg = _repo_root / "packages" / "leann-backend-hnsw" _leann_hnsw_pkg = _repo_root / "packages" / "leann-backend-hnsw"
import sys
if str(_leann_core_src) not in sys.path: if str(_leann_core_src) not in sys.path:
sys.path.append(str(_leann_core_src)) sys.path.append(str(_leann_core_src))
if str(_leann_hnsw_pkg) not in sys.path: if str(_leann_hnsw_pkg) not in sys.path:
sys.path.append(str(_leann_hnsw_pkg)) sys.path.append(str(_leann_hnsw_pkg))
from leann_multi_vector import LeannMultiVector
class LeannRetriever(LeannMultiVector):
pass
# %%
from typing import cast
import torch import torch
from colpali_engine.models import ColPali from colpali_engine.models import ColPali
@@ -88,13 +73,6 @@ for batch_query in dataloader:
qs.extend(list(torch.unbind(embeddings_query.to("cpu")))) qs.extend(list(torch.unbind(embeddings_query.to("cpu"))))
print(qs[0].shape) print(qs[0].shape)
# %% # %%
import re
from PIL import Image
from tqdm import tqdm
page_filenames = sorted(os.listdir("./pages"), key=lambda n: int(re.search(r"\d+", n).group())) page_filenames = sorted(os.listdir("./pages"), key=lambda n: int(re.search(r"\d+", n).group()))
images = [Image.open(os.path.join("./pages", name)) for name in page_filenames] images = [Image.open(os.path.join("./pages", name)) for name in page_filenames]

View File

@@ -169,7 +169,7 @@ def _embed_images(model, processor, images: list[Image.Image]) -> list[Any]:
) )
doc_vecs: list[Any] = [] doc_vecs: list[Any] = []
for batch_doc in dataloader: for batch_doc in tqdm(dataloader, desc="Embedding images"):
with torch.no_grad(): with torch.no_grad():
batch_doc = {k: v.to(model.device) for k, v in batch_doc.items()} batch_doc = {k: v.to(model.device) for k, v in batch_doc.items()}
# autocast on CUDA for bf16/fp16; on CPU/MPS stay in fp32 # autocast on CUDA for bf16/fp16; on CPU/MPS stay in fp32
@@ -200,7 +200,7 @@ def _embed_queries(model, processor, queries: list[str]) -> list[Any]:
) )
q_vecs: list[Any] = [] q_vecs: list[Any] = []
for batch_query in dataloader: for batch_query in tqdm(dataloader, desc="Embedding queries"):
with torch.no_grad(): with torch.no_grad():
batch_query = {k: v.to(model.device) for k, v in batch_query.items()} batch_query = {k: v.to(model.device) for k, v in batch_query.items()}
if model.device.type == "cuda": if model.device.type == "cuda":
@@ -362,7 +362,7 @@ if USE_HF_DATASET:
N = len(dataset) if MAX_DOCS is None else min(MAX_DOCS, len(dataset)) N = len(dataset) if MAX_DOCS is None else min(MAX_DOCS, len(dataset))
filepaths: list[str] = [] filepaths: list[str] = []
images: list[Image.Image] = [] images: list[Image.Image] = []
for i in tqdm(range(N), desc="Loading dataset"): for i in tqdm(range(N), desc="Loading dataset", total=N ):
p = dataset[i] p = dataset[i]
# Compose a descriptive identifier for printing later # Compose a descriptive identifier for printing later
identifier = f"arXiv:{p['paper_arxiv_id']}|title:{p['paper_title']}|page:{int(p['page_number'])}|id:{p['page_id']}" identifier = f"arXiv:{p['paper_arxiv_id']}|title:{p['paper_title']}|page:{int(p['page_number'])}|id:{p['page_id']}"

View File

@@ -43,7 +43,11 @@ from apps.chunking import create_text_chunks
REPO_ROOT = Path(__file__).resolve().parents[1] REPO_ROOT = Path(__file__).resolve().parents[1]
DEFAULT_QUERY = "What's LEANN?" DEFAULT_QUERY = "What's LEANN?"
DEFAULT_INITIAL_FILES = [REPO_ROOT / "data" / "2501.14312v1 (1).pdf"] DEFAULT_INITIAL_FILES = [
REPO_ROOT / "data" / "2501.14312v1 (1).pdf",
REPO_ROOT / "data" / "huawei_pangu.md",
REPO_ROOT / "data" / "PrideandPrejudice.txt",
]
DEFAULT_UPDATE_FILES = [REPO_ROOT / "data" / "2506.08276v1.pdf"] DEFAULT_UPDATE_FILES = [REPO_ROOT / "data" / "2506.08276v1.pdf"]
@@ -182,6 +186,7 @@ def run_workflow(
is_recompute: bool, is_recompute: bool,
query: str, query: str,
top_k: int, top_k: int,
skip_search: bool,
) -> dict[str, Any]: ) -> dict[str, Any]:
prefix = f"[{label}] " if label else "" prefix = f"[{label}] " if label else ""
@@ -198,12 +203,15 @@ def run_workflow(
) )
initial_size = index_file_size(index_path) initial_size = index_file_size(index_path)
before_results = run_search( if not skip_search:
index_path, before_results = run_search(
query, index_path,
top_k, query,
recompute_embeddings=is_recompute, top_k,
) recompute_embeddings=is_recompute,
)
else:
before_results = None
print(f"\n{prefix}Updating index with additional passages...") print(f"\n{prefix}Updating index with additional passages...")
update_index( update_index(
@@ -215,20 +223,23 @@ def run_workflow(
is_recompute=is_recompute, is_recompute=is_recompute,
) )
after_results = run_search( if not skip_search:
index_path, after_results = run_search(
query, index_path,
top_k, query,
recompute_embeddings=is_recompute, top_k,
) recompute_embeddings=is_recompute,
)
else:
after_results = None
updated_size = index_file_size(index_path) updated_size = index_file_size(index_path)
return { return {
"initial_size": initial_size, "initial_size": initial_size,
"updated_size": updated_size, "updated_size": updated_size,
"delta": updated_size - initial_size, "delta": updated_size - initial_size,
"before_results": before_results, "before_results": before_results if not skip_search else None,
"after_results": after_results, "after_results": after_results if not skip_search else None,
"metadata": load_metadata_snapshot(index_path), "metadata": load_metadata_snapshot(index_path),
} }
@@ -314,6 +325,12 @@ def main() -> None:
action="store_false", action="store_false",
help="Skip building the no-recompute baseline.", help="Skip building the no-recompute baseline.",
) )
parser.add_argument(
"--skip-search",
dest="skip_search",
action="store_true",
help="Skip the search step.",
)
parser.set_defaults(compare_no_recompute=True) parser.set_defaults(compare_no_recompute=True)
args = parser.parse_args() args = parser.parse_args()
@@ -350,10 +367,13 @@ def main() -> None:
is_recompute=True, is_recompute=True,
query=args.query, query=args.query,
top_k=args.top_k, top_k=args.top_k,
skip_search=args.skip_search,
) )
print_results("initial search", recompute_stats["before_results"]) if not args.skip_search:
print_results("after update", recompute_stats["after_results"]) print_results("initial search", recompute_stats["before_results"])
if not args.skip_search:
print_results("after update", recompute_stats["after_results"])
print( print(
f"\n[recompute] Index file size change: {recompute_stats['initial_size']} -> {recompute_stats['updated_size']} bytes" f"\n[recompute] Index file size change: {recompute_stats['initial_size']} -> {recompute_stats['updated_size']} bytes"
f"{recompute_stats['delta']})" f"{recompute_stats['delta']})"
@@ -378,6 +398,7 @@ def main() -> None:
is_recompute=False, is_recompute=False,
query=args.query, query=args.query,
top_k=args.top_k, top_k=args.top_k,
skip_search=args.skip_search,
) )
print( print(
@@ -385,8 +406,12 @@ def main() -> None:
f"{baseline_stats['delta']})" f"{baseline_stats['delta']})"
) )
after_texts = [res.text for res in recompute_stats["after_results"]] after_texts = (
baseline_after_texts = [res.text for res in baseline_stats["after_results"]] [res.text for res in recompute_stats["after_results"]] if not args.skip_search else None
)
baseline_after_texts = (
[res.text for res in baseline_stats["after_results"]] if not args.skip_search else None
)
if after_texts == baseline_after_texts: if after_texts == baseline_after_texts:
print( print(
"[no-recompute] Search results match recompute baseline; see above for the shared output." "[no-recompute] Search results match recompute baseline; see above for the shared output."

View File

@@ -5,6 +5,7 @@ with the correct, original embedding logic from the user's reference code.
import json import json
import logging import logging
import os
import pickle import pickle
import re import re
import subprocess import subprocess
@@ -20,6 +21,7 @@ from leann_backend_hnsw.convert_to_csr import prune_hnsw_embeddings_inplace
from leann.interface import LeannBackendSearcherInterface from leann.interface import LeannBackendSearcherInterface
from .chat import get_llm from .chat import get_llm
from .embedding_server_manager import EmbeddingServerManager
from .interface import LeannBackendFactoryInterface from .interface import LeannBackendFactoryInterface
from .metadata_filter import MetadataFilterEngine from .metadata_filter import MetadataFilterEngine
from .registry import BACKEND_REGISTRY from .registry import BACKEND_REGISTRY
@@ -728,6 +730,7 @@ class LeannBuilder:
index = faiss.read_index(str(index_file)) index = faiss.read_index(str(index_file))
if hasattr(index, "is_recompute"): if hasattr(index, "is_recompute"):
index.is_recompute = needs_recompute index.is_recompute = needs_recompute
print(f"index.is_recompute: {index.is_recompute}")
if getattr(index, "storage", None) is None: if getattr(index, "storage", None) is None:
if index.metric_type == faiss.METRIC_INNER_PRODUCT: if index.metric_type == faiss.METRIC_INNER_PRODUCT:
storage_index = faiss.IndexFlatIP(index.d) storage_index = faiss.IndexFlatIP(index.d)
@@ -735,37 +738,107 @@ class LeannBuilder:
storage_index = faiss.IndexFlatL2(index.d) storage_index = faiss.IndexFlatL2(index.d)
index.storage = storage_index index.storage = storage_index
index.own_fields = True index.own_fields = True
# Faiss expects storage.ntotal to reflect the existing graph's
# population (even if the vectors themselves were pruned from disk
# for recompute mode). When we attach a fresh IndexFlat here its
# ntotal starts at zero, which later causes IndexHNSW::add to
# believe new "preset" levels were provided and trips the
# `n0 + n == levels.size()` assertion. Seed the temporary storage
# with the current ntotal so Faiss maintains the proper offset for
# incoming vectors.
try:
storage_index.ntotal = index.ntotal
except AttributeError:
# Older Faiss builds may not expose ntotal as a writable
# attribute; in that case we fall back to the default behaviour.
pass
if index.d != embedding_dim: if index.d != embedding_dim:
raise ValueError( raise ValueError(
f"Existing index dimension ({index.d}) does not match new embeddings ({embedding_dim})." f"Existing index dimension ({index.d}) does not match new embeddings ({embedding_dim})."
) )
passage_meta_mode = meta.get("embedding_mode", self.embedding_mode)
passage_provider_options = meta.get("embedding_options", self.embedding_options)
base_id = index.ntotal base_id = index.ntotal
for offset, chunk in enumerate(valid_chunks): for offset, chunk in enumerate(valid_chunks):
new_id = str(base_id + offset) new_id = str(base_id + offset)
chunk.setdefault("metadata", {})["id"] = new_id chunk.setdefault("metadata", {})["id"] = new_id
chunk["id"] = new_id chunk["id"] = new_id
index.add(embeddings.shape[0], faiss.swig_ptr(embeddings)) # Append passages/offsets before we attempt index.add so the ZMQ server
faiss.write_index(index, str(index_file)) # can resolve newly assigned IDs during recompute. Keep rollback hooks
# so we can restore files if the update fails mid-way.
rollback_passages_size = passages_file.stat().st_size if passages_file.exists() else 0
offset_map_backup = offset_map.copy()
with open(passages_file, "a", encoding="utf-8") as f: try:
for chunk in valid_chunks: with open(passages_file, "a", encoding="utf-8") as f:
offset = f.tell() for chunk in valid_chunks:
json.dump( offset = f.tell()
{ json.dump(
"id": chunk["id"], {
"text": chunk["text"], "id": chunk["id"],
"metadata": chunk.get("metadata", {}), "text": chunk["text"],
}, "metadata": chunk.get("metadata", {}),
f, },
ensure_ascii=False, f,
) ensure_ascii=False,
f.write("\n") )
offset_map[chunk["id"]] = offset f.write("\n")
offset_map[chunk["id"]] = offset
with open(offset_file, "wb") as f: with open(offset_file, "wb") as f:
pickle.dump(offset_map, f) pickle.dump(offset_map, f)
server_manager: Optional[EmbeddingServerManager] = None
server_started = False
requested_zmq_port = int(os.getenv("LEANN_UPDATE_ZMQ_PORT", "5557"))
try:
if needs_recompute:
server_manager = EmbeddingServerManager(
backend_module_name="leann_backend_hnsw.hnsw_embedding_server"
)
server_started, actual_port = server_manager.start_server(
port=requested_zmq_port,
model_name=self.embedding_model,
embedding_mode=passage_meta_mode,
passages_file=str(meta_path),
distance_metric=distance_metric,
provider_options=passage_provider_options,
)
if not server_started:
raise RuntimeError(
"Failed to start HNSW embedding server for recompute update."
)
if actual_port != requested_zmq_port:
server_manager.stop_server()
raise RuntimeError(
"Embedding server started on unexpected port "
f"{actual_port}; expected {requested_zmq_port}. Make sure the desired ZMQ port is free."
)
if needs_recompute:
for i in range(embeddings.shape[0]):
print(f"add {i} embeddings")
index.add(1, faiss.swig_ptr(embeddings[i : i + 1]))
else:
index.add(embeddings.shape[0], faiss.swig_ptr(embeddings))
faiss.write_index(index, str(index_file))
finally:
if server_started and server_manager is not None:
server_manager.stop_server()
except Exception:
# Roll back appended passages/offset map to keep files consistent.
if passages_file.exists():
with open(passages_file, "rb+") as f:
f.truncate(rollback_passages_size)
offset_map = offset_map_backup
with open(offset_file, "wb") as f:
pickle.dump(offset_map, f)
raise
meta["total_passages"] = len(offset_map) meta["total_passages"] = len(offset_map)
with open(meta_path, "w", encoding="utf-8") as f: with open(meta_path, "w", encoding="utf-8") as f:

View File

@@ -12,6 +12,8 @@ from typing import Any, Optional
import torch import torch
from .settings import resolve_ollama_host, resolve_openai_api_key, resolve_openai_base_url
# Configure logging # Configure logging
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -310,11 +312,12 @@ def search_hf_models(query: str, limit: int = 10) -> list[str]:
def validate_model_and_suggest( def validate_model_and_suggest(
model_name: str, llm_type: str, host: str = "http://localhost:11434" model_name: str, llm_type: str, host: Optional[str] = None
) -> Optional[str]: ) -> Optional[str]:
"""Validate model name and provide suggestions if invalid""" """Validate model name and provide suggestions if invalid"""
if llm_type == "ollama": if llm_type == "ollama":
available_models = check_ollama_models(host) resolved_host = resolve_ollama_host(host)
available_models = check_ollama_models(resolved_host)
if available_models and model_name not in available_models: if available_models and model_name not in available_models:
error_msg = f"Model '{model_name}' not found in your local Ollama installation." error_msg = f"Model '{model_name}' not found in your local Ollama installation."
@@ -457,19 +460,19 @@ class LLMInterface(ABC):
class OllamaChat(LLMInterface): class OllamaChat(LLMInterface):
"""LLM interface for Ollama models.""" """LLM interface for Ollama models."""
def __init__(self, model: str = "llama3:8b", host: str = "http://localhost:11434"): def __init__(self, model: str = "llama3:8b", host: Optional[str] = None):
self.model = model self.model = model
self.host = host self.host = resolve_ollama_host(host)
logger.info(f"Initializing OllamaChat with model='{model}' and host='{host}'") logger.info(f"Initializing OllamaChat with model='{model}' and host='{self.host}'")
try: try:
import requests import requests
# Check if the Ollama server is responsive # Check if the Ollama server is responsive
if host: if self.host:
requests.get(host) requests.get(self.host)
# Pre-check model availability with helpful suggestions # Pre-check model availability with helpful suggestions
model_error = validate_model_and_suggest(model, "ollama", host) model_error = validate_model_and_suggest(model, "ollama", self.host)
if model_error: if model_error:
raise ValueError(model_error) raise ValueError(model_error)
@@ -478,9 +481,11 @@ class OllamaChat(LLMInterface):
"The 'requests' library is required for Ollama. Please install it with 'pip install requests'." "The 'requests' library is required for Ollama. Please install it with 'pip install requests'."
) )
except requests.exceptions.ConnectionError: except requests.exceptions.ConnectionError:
logger.error(f"Could not connect to Ollama at {host}. Please ensure Ollama is running.") logger.error(
f"Could not connect to Ollama at {self.host}. Please ensure Ollama is running."
)
raise ConnectionError( raise ConnectionError(
f"Could not connect to Ollama at {host}. Please ensure Ollama is running." f"Could not connect to Ollama at {self.host}. Please ensure Ollama is running."
) )
def ask(self, prompt: str, **kwargs) -> str: def ask(self, prompt: str, **kwargs) -> str:
@@ -577,33 +582,18 @@ class HFChat(LLMInterface):
def timeout_handler(signum, frame): def timeout_handler(signum, frame):
raise TimeoutError("Model download/loading timed out") raise TimeoutError("Model download/loading timed out")
# Set timeout for model loading (increase to 300s for large models) # Set timeout for model loading (60 seconds)
old_handler = signal.signal(signal.SIGALRM, timeout_handler) old_handler = signal.signal(signal.SIGALRM, timeout_handler)
signal.alarm(300) signal.alarm(60)
try: try:
logger.info(f"Loading tokenizer for {model_name}...") logger.info(f"Loading tokenizer for {model_name}...")
self.tokenizer = AutoTokenizer.from_pretrained(model_name) self.tokenizer = AutoTokenizer.from_pretrained(model_name)
logger.info(f"Loading model {model_name}...") logger.info(f"Loading model {model_name}...")
# Choose a numerically stable dtype per device
if self.device == "cuda":
# Prefer bfloat16 when available; otherwise fall back to float16
try:
bf16_ok = torch.cuda.is_bf16_supported()
except Exception:
bf16_ok = False
load_dtype = torch.bfloat16 if bf16_ok else torch.float16
elif self.device == "mps":
# On Apple MPS, float16 often causes NaNs/INFs during sampling.
# Use float32 for stability, even if it increases memory.
load_dtype = torch.float32
else:
load_dtype = torch.float32
self.model = AutoModelForCausalLM.from_pretrained( self.model = AutoModelForCausalLM.from_pretrained(
model_name, model_name,
torch_dtype=load_dtype, torch_dtype=torch.float16 if self.device != "cpu" else torch.float32,
device_map="auto" if self.device != "cpu" else None, device_map="auto" if self.device != "cpu" else None,
trust_remote_code=True, trust_remote_code=True,
) )
@@ -621,12 +611,8 @@ class HFChat(LLMInterface):
logger.error(f"Failed to load model {model_name}: {e}") logger.error(f"Failed to load model {model_name}: {e}")
raise raise
# Move model to device only if not managed by accelerate (no device_map) # Move model to device if not using device_map
try: if self.device != "cpu" and "device_map" not in str(self.model):
has_device_map = getattr(self.model, "hf_device_map", None) is not None
except Exception:
has_device_map = False
if self.device != "cpu" and not has_device_map:
self.model = self.model.to(self.device) self.model = self.model.to(self.device)
# Set pad token if not present # Set pad token if not present
@@ -658,15 +644,13 @@ class HFChat(LLMInterface):
# Fallback for models without chat template # Fallback for models without chat template
formatted_prompt = prompt formatted_prompt = prompt
# Tokenize input (respect model context length when available) # Tokenize input
inputs = self.tokenizer( inputs = self.tokenizer(
formatted_prompt, formatted_prompt,
return_tensors="pt", return_tensors="pt",
padding=True, padding=True,
truncation=True, truncation=True,
max_length=getattr( max_length=2048,
getattr(self.model, "config", None), "max_position_embeddings", 2048
),
) )
# Move inputs to device # Move inputs to device
@@ -681,8 +665,6 @@ class HFChat(LLMInterface):
"do_sample": kwargs.get("temperature", 0.7) > 0, "do_sample": kwargs.get("temperature", 0.7) > 0,
"pad_token_id": self.tokenizer.eos_token_id, "pad_token_id": self.tokenizer.eos_token_id,
"eos_token_id": self.tokenizer.eos_token_id, "eos_token_id": self.tokenizer.eos_token_id,
# Helps avoid numerical issues in sampling when logits processors are used
"renormalize_logits": True,
} }
# Handle temperature=0 for greedy decoding # Handle temperature=0 for greedy decoding
@@ -692,39 +674,11 @@ class HFChat(LLMInterface):
logger.info(f"Generating with HuggingFace model, config: {generation_config}") logger.info(f"Generating with HuggingFace model, config: {generation_config}")
# Streaming support (optional) # Generate
stream = bool(kwargs.get("stream", False))
if stream:
try:
from threading import Thread
from transformers import TextIteratorStreamer
streamer = TextIteratorStreamer(
self.tokenizer, skip_prompt=True, skip_special_tokens=True
)
def _gen():
with torch.no_grad():
self.model.generate(**inputs, **generation_config, streamer=streamer)
t = Thread(target=_gen)
t.start()
pieces = []
for new_text in streamer:
print(new_text, end="", flush=True)
pieces.append(new_text)
t.join()
print("") # newline after streaming
return ("".join(pieces)).strip()
except Exception as e:
logger.warning(f"Streaming failed, falling back to non-streaming: {e}")
# Non-streaming path
with torch.no_grad(): with torch.no_grad():
outputs = self.model.generate(**inputs, **generation_config) outputs = self.model.generate(**inputs, **generation_config)
# Decode response
generated_tokens = outputs[0][inputs["input_ids"].shape[1] :] generated_tokens = outputs[0][inputs["input_ids"].shape[1] :]
response = self.tokenizer.decode(generated_tokens, skip_special_tokens=True) response = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
@@ -788,21 +742,31 @@ class GeminiChat(LLMInterface):
class OpenAIChat(LLMInterface): class OpenAIChat(LLMInterface):
"""LLM interface for OpenAI models.""" """LLM interface for OpenAI models."""
def __init__(self, model: str = "gpt-4o", api_key: Optional[str] = None): def __init__(
self,
model: str = "gpt-4o",
api_key: Optional[str] = None,
base_url: Optional[str] = None,
):
self.model = model self.model = model
self.api_key = api_key or os.getenv("OPENAI_API_KEY") self.base_url = resolve_openai_base_url(base_url)
self.api_key = resolve_openai_api_key(api_key)
if not self.api_key: if not self.api_key:
raise ValueError( raise ValueError(
"OpenAI API key is required. Set OPENAI_API_KEY environment variable or pass api_key parameter." "OpenAI API key is required. Set OPENAI_API_KEY environment variable or pass api_key parameter."
) )
logger.info(f"Initializing OpenAI Chat with model='{model}'") logger.info(
"Initializing OpenAI Chat with model='%s' and base_url='%s'",
model,
self.base_url,
)
try: try:
import openai import openai
self.client = openai.OpenAI(api_key=self.api_key) self.client = openai.OpenAI(api_key=self.api_key, base_url=self.base_url)
except ImportError: except ImportError:
raise ImportError( raise ImportError(
"The 'openai' library is required for OpenAI models. Please install it with 'pip install openai'." "The 'openai' library is required for OpenAI models. Please install it with 'pip install openai'."
@@ -892,12 +856,16 @@ def get_llm(llm_config: Optional[dict[str, Any]] = None) -> LLMInterface:
if llm_type == "ollama": if llm_type == "ollama":
return OllamaChat( return OllamaChat(
model=model or "llama3:8b", model=model or "llama3:8b",
host=llm_config.get("host", "http://localhost:11434"), host=llm_config.get("host"),
) )
elif llm_type == "hf": elif llm_type == "hf":
return HFChat(model_name=model or "deepseek-ai/deepseek-llm-7b-chat") return HFChat(model_name=model or "deepseek-ai/deepseek-llm-7b-chat")
elif llm_type == "openai": elif llm_type == "openai":
return OpenAIChat(model=model or "gpt-4o", api_key=llm_config.get("api_key")) return OpenAIChat(
model=model or "gpt-4o",
api_key=llm_config.get("api_key"),
base_url=llm_config.get("base_url"),
)
elif llm_type == "gemini": elif llm_type == "gemini":
return GeminiChat(model=model or "gemini-2.5-flash", api_key=llm_config.get("api_key")) return GeminiChat(model=model or "gemini-2.5-flash", api_key=llm_config.get("api_key"))
elif llm_type == "simulated": elif llm_type == "simulated":

View File

@@ -1,4 +1,5 @@
import atexit import atexit
import json
import logging import logging
import os import os
import socket import socket
@@ -48,6 +49,85 @@ def _check_port(port: int) -> bool:
# Note: All cross-process scanning helpers removed for simplicity # Note: All cross-process scanning helpers removed for simplicity
def _safe_resolve(path: Path) -> str:
"""Resolve paths safely even if the target does not yet exist."""
try:
return str(path.resolve(strict=False))
except Exception:
return str(path)
def _safe_stat_signature(path: Path) -> dict:
"""Return a lightweight signature describing the current state of a path."""
signature: dict[str, object] = {"path": _safe_resolve(path)}
try:
stat = path.stat()
except FileNotFoundError:
signature["missing"] = True
except Exception as exc: # pragma: no cover - unexpected filesystem errors
signature["error"] = str(exc)
else:
signature["mtime_ns"] = stat.st_mtime_ns
signature["size"] = stat.st_size
return signature
def _build_passages_signature(passages_file: Optional[str]) -> Optional[dict]:
"""Collect modification signatures for metadata and referenced passage files."""
if not passages_file:
return None
meta_path = Path(passages_file)
signature: dict[str, object] = {"meta": _safe_stat_signature(meta_path)}
try:
with meta_path.open(encoding="utf-8") as fh:
meta = json.load(fh)
except FileNotFoundError:
signature["meta_missing"] = True
signature["sources"] = []
return signature
except json.JSONDecodeError as exc:
signature["meta_error"] = f"json_error:{exc}"
signature["sources"] = []
return signature
except Exception as exc: # pragma: no cover - unexpected errors
signature["meta_error"] = str(exc)
signature["sources"] = []
return signature
base_dir = meta_path.parent
seen_paths: set[str] = set()
source_signatures: list[dict[str, object]] = []
for source in meta.get("passage_sources", []):
for key, kind in (
("path", "passages"),
("path_relative", "passages"),
("index_path", "index"),
("index_path_relative", "index"),
):
raw_path = source.get(key)
if not raw_path:
continue
candidate = Path(raw_path)
if not candidate.is_absolute():
candidate = base_dir / candidate
resolved = _safe_resolve(candidate)
if resolved in seen_paths:
continue
seen_paths.add(resolved)
sig = _safe_stat_signature(candidate)
sig["kind"] = kind
source_signatures.append(sig)
signature["sources"] = source_signatures
return signature
# Note: All cross-process scanning helpers removed for simplicity
class EmbeddingServerManager: class EmbeddingServerManager:
""" """
A simplified manager for embedding server processes that avoids complex update mechanisms. A simplified manager for embedding server processes that avoids complex update mechanisms.
@@ -85,13 +165,14 @@ class EmbeddingServerManager:
"""Start the embedding server.""" """Start the embedding server."""
# passages_file may be present in kwargs for server CLI, but we don't need it here # passages_file may be present in kwargs for server CLI, but we don't need it here
provider_options = kwargs.pop("provider_options", None) provider_options = kwargs.pop("provider_options", None)
passages_file = kwargs.get("passages_file", "")
config_signature = { config_signature = self._build_config_signature(
"model_name": model_name, model_name=model_name,
"passages_file": kwargs.get("passages_file", ""), embedding_mode=embedding_mode,
"embedding_mode": embedding_mode, provider_options=provider_options,
"provider_options": provider_options or {}, passages_file=passages_file,
} )
# If this manager already has a live server, just reuse it # If this manager already has a live server, just reuse it
if ( if (
@@ -115,6 +196,7 @@ class EmbeddingServerManager:
port, port,
model_name, model_name,
embedding_mode, embedding_mode,
config_signature=config_signature,
provider_options=provider_options, provider_options=provider_options,
**kwargs, **kwargs,
) )
@@ -136,11 +218,30 @@ class EmbeddingServerManager:
**kwargs, **kwargs,
) )
def _build_config_signature(
self,
*,
model_name: str,
embedding_mode: str,
provider_options: Optional[dict],
passages_file: Optional[str],
) -> dict:
"""Create a signature describing the current server configuration."""
return {
"model_name": model_name,
"passages_file": passages_file or "",
"embedding_mode": embedding_mode,
"provider_options": provider_options or {},
"passages_signature": _build_passages_signature(passages_file),
}
def _start_server_colab( def _start_server_colab(
self, self,
port: int, port: int,
model_name: str, model_name: str,
embedding_mode: str = "sentence-transformers", embedding_mode: str = "sentence-transformers",
*,
config_signature: Optional[dict] = None,
provider_options: Optional[dict] = None, provider_options: Optional[dict] = None,
**kwargs, **kwargs,
) -> tuple[bool, int]: ) -> tuple[bool, int]:
@@ -163,10 +264,11 @@ class EmbeddingServerManager:
command, command,
actual_port, actual_port,
provider_options=provider_options, provider_options=provider_options,
config_signature=config_signature,
) )
started, ready_port = self._wait_for_server_ready_colab(actual_port) started, ready_port = self._wait_for_server_ready_colab(actual_port)
if started: if started:
self._server_config = { self._server_config = config_signature or {
"model_name": model_name, "model_name": model_name,
"passages_file": kwargs.get("passages_file", ""), "passages_file": kwargs.get("passages_file", ""),
"embedding_mode": embedding_mode, "embedding_mode": embedding_mode,
@@ -198,6 +300,7 @@ class EmbeddingServerManager:
command, command,
port, port,
provider_options=provider_options, provider_options=provider_options,
config_signature=config_signature,
) )
started, ready_port = self._wait_for_server_ready(port) started, ready_port = self._wait_for_server_ready(port)
if started: if started:
@@ -241,7 +344,9 @@ class EmbeddingServerManager:
self, self,
command: list, command: list,
port: int, port: int,
*,
provider_options: Optional[dict] = None, provider_options: Optional[dict] = None,
config_signature: Optional[dict] = None,
) -> None: ) -> None:
"""Launch the server process.""" """Launch the server process."""
project_root = Path(__file__).parent.parent.parent.parent.parent project_root = Path(__file__).parent.parent.parent.parent.parent
@@ -276,26 +381,29 @@ class EmbeddingServerManager:
) )
self.server_port = port self.server_port = port
# Record config for in-process reuse (best effort; refined later when ready) # Record config for in-process reuse (best effort; refined later when ready)
try: if config_signature is not None:
self._server_config = { self._server_config = config_signature
"model_name": command[command.index("--model-name") + 1] else: # Fallback for unexpected code paths
if "--model-name" in command try:
else "", self._server_config = {
"passages_file": command[command.index("--passages-file") + 1] "model_name": command[command.index("--model-name") + 1]
if "--passages-file" in command if "--model-name" in command
else "", else "",
"embedding_mode": command[command.index("--embedding-mode") + 1] "passages_file": command[command.index("--passages-file") + 1]
if "--embedding-mode" in command if "--passages-file" in command
else "sentence-transformers", else "",
"provider_options": provider_options or {}, "embedding_mode": command[command.index("--embedding-mode") + 1]
} if "--embedding-mode" in command
except Exception: else "sentence-transformers",
self._server_config = { "provider_options": provider_options or {},
"model_name": "", }
"passages_file": "", except Exception:
"embedding_mode": "sentence-transformers", self._server_config = {
"provider_options": provider_options or {}, "model_name": "",
} "passages_file": "",
"embedding_mode": "sentence-transformers",
"provider_options": provider_options or {},
}
logger.info(f"Server process started with PID: {self.server_process.pid}") logger.info(f"Server process started with PID: {self.server_process.pid}")
# Register atexit callback only when we actually start a process # Register atexit callback only when we actually start a process
@@ -403,7 +511,9 @@ class EmbeddingServerManager:
self, self,
command: list, command: list,
port: int, port: int,
*,
provider_options: Optional[dict] = None, provider_options: Optional[dict] = None,
config_signature: Optional[dict] = None,
) -> None: ) -> None:
"""Launch the server process with Colab-specific settings.""" """Launch the server process with Colab-specific settings."""
logger.info(f"Colab Command: {' '.join(command)}") logger.info(f"Colab Command: {' '.join(command)}")
@@ -429,12 +539,15 @@ class EmbeddingServerManager:
atexit.register(self._finalize_process) atexit.register(self._finalize_process)
self._atexit_registered = True self._atexit_registered = True
# Record config for in-process reuse is best-effort in Colab mode # Record config for in-process reuse is best-effort in Colab mode
self._server_config = { if config_signature is not None:
"model_name": "", self._server_config = config_signature
"passages_file": "", else:
"embedding_mode": "sentence-transformers", self._server_config = {
"provider_options": provider_options or {}, "model_name": "",
} "passages_file": "",
"embedding_mode": "sentence-transformers",
"provider_options": provider_options or {},
}
def _wait_for_server_ready_colab(self, port: int) -> tuple[bool, int]: def _wait_for_server_ready_colab(self, port: int) -> tuple[bool, int]:
"""Wait for the server to be ready with Colab-specific timeout.""" """Wait for the server to be ready with Colab-specific timeout."""

View File

@@ -111,7 +111,7 @@ target-version = "py39"
line-length = 100 line-length = 100
extend-exclude = [ extend-exclude = [
"third_party", "third_party",
"apps/multimodal/vision-based-pdf-multi-vector/multi-vector-leann.py", "apps/multimodal/vision-based-pdf-multi-vector/multi-vector-leann-paper-example.py",
"apps/multimodal/vision-based-pdf-multi-vector/multi-vector-leann-similarity-map.py" "apps/multimodal/vision-based-pdf-multi-vector/multi-vector-leann-similarity-map.py"
] ]

View File

@@ -1,324 +0,0 @@
#!/usr/bin/env python3
"""Measure generation latency of a HuggingFace/OpenAI-compatible model over prompt files."""
import argparse
import contextlib
import io
import json
import logging
import time
from pathlib import Path
from leann.chat import get_llm
PROMPT_PREFIX = "PROMPT #"
logging.getLogger("leann.chat").setLevel(logging.ERROR)
def load_prompts(path: Path) -> list[str]:
prompts: list[str] = []
buffer: list[str] = []
collecting = False
with path.open("r", encoding="utf-8") as handle:
for line in handle:
if line.startswith(PROMPT_PREFIX):
if buffer:
prompts.append("".join(buffer).strip())
buffer.clear()
collecting = True
continue
if collecting:
buffer.append(line)
if buffer:
prompts.append("".join(buffer).strip())
return prompts
def measure_generation_times(
prompts: list[str],
llm,
generation_kwargs: dict[str, object],
allow_truncation: bool,
enable_qwen_thinking: bool,
verbose: bool,
per_call_timeout: int | None,
):
timings: list[float] = []
tokenizer = getattr(llm, "tokenizer", None)
max_positions = None
if hasattr(llm, "model") and hasattr(llm.model, "config"):
max_positions = getattr(llm.model.config, "max_position_embeddings", None)
requested_new_tokens = None
if max_positions is not None:
if "max_new_tokens" in generation_kwargs:
requested_new_tokens = generation_kwargs["max_new_tokens"]
elif "max_tokens" in generation_kwargs:
requested_new_tokens = generation_kwargs["max_tokens"]
context_max_length = max_positions
if max_positions is not None and requested_new_tokens is not None:
if requested_new_tokens >= max_positions:
requested_new_tokens = max_positions - 1
context_max_length = max(max_positions - requested_new_tokens, 1)
suppress_buffer = io.StringIO()
# Log base config
if verbose:
device = getattr(llm, "device", None)
try:
dtype = getattr(getattr(llm, "model", None), "dtype", None)
except Exception:
dtype = None
print(
f"[dbg] device={device} dtype={dtype} max_positions={max_positions} requested_new_tokens={requested_new_tokens} context_max_length={context_max_length}"
)
total = len(prompts)
for idx, prompt in enumerate(prompts, start=1):
prompt_for_llm = prompt
if (
enable_qwen_thinking
and "/think" not in prompt_for_llm
and "/no_think" not in prompt_for_llm
):
prompt_for_llm = f"{prompt_for_llm}\n/think"
if allow_truncation and tokenizer is not None and max_positions is not None:
tokenized = tokenizer(
prompt_for_llm,
truncation=True,
max_length=context_max_length,
return_tensors="pt",
)
prompt_for_llm = tokenizer.decode(tokenized["input_ids"][0], skip_special_tokens=True)
per_call_kwargs = dict(generation_kwargs)
if requested_new_tokens is not None:
per_call_kwargs["max_new_tokens"] = requested_new_tokens
# Enable streaming if requested (HF backend will print tokens)
if verbose:
# When verbose (or --stream propagated), enable streaming in HF backend
per_call_kwargs["stream"] = True
# Extra debug info about token lengths
if verbose and tokenizer is not None:
try:
toks = tokenizer(prompt_for_llm, return_tensors=None, truncation=False)
in_len = (
len(toks["input_ids"])
if isinstance(toks["input_ids"], list)
else len(toks["input_ids"][0])
)
except Exception:
in_len = None
print(f"[dbg] prompt {idx}/{total} tokens={in_len}")
print(
f"[dbg] gen_cfg={{max_new_tokens:{per_call_kwargs.get('max_new_tokens')}, temp:{per_call_kwargs.get('temperature')}, top_p:{per_call_kwargs.get('top_p')}}}"
)
start = time.perf_counter()
# Optional per-call timeout using signal alarm
timeout_handler_installed = False
if per_call_timeout is not None:
import signal
def _timeout_handler(signum, frame):
raise TimeoutError("generation timed out")
old_handler = signal.signal(signal.SIGALRM, _timeout_handler)
signal.alarm(int(per_call_timeout))
timeout_handler_installed = True
try:
if verbose:
print("[dbg] generation_start")
llm.ask(prompt_for_llm, **per_call_kwargs)
print("[dbg] generation_done")
else:
with contextlib.redirect_stdout(suppress_buffer):
llm.ask(prompt_for_llm, **per_call_kwargs)
except TimeoutError:
if verbose:
print("[dbg] generation_timeout")
finally:
if timeout_handler_installed:
import signal
signal.alarm(0)
signal.signal(signal.SIGALRM, old_handler)
end = time.perf_counter()
timings.append(end - start)
suppress_buffer.seek(0)
suppress_buffer.truncate(0)
return timings
def parse_args():
parser = argparse.ArgumentParser(description="Measure generation timing for prompt files")
parser.add_argument(
"--max-prompts",
type=int,
default=None,
help="Optional limit on number of prompts to evaluate per file",
)
parser.add_argument(
"--allow-truncation",
action="store_true",
help="Allow truncating prompt context to respect model's max context",
)
parser.add_argument(
"--model",
type=str,
default="sshleifer/tiny-gpt2",
help="LLM model identifier (default: sshleifer/tiny-gpt2)",
)
parser.add_argument(
"--llm-type",
type=str,
default="hf",
choices=["hf", "openai", "ollama", "gemini", "simulated"],
help="LLM backend type (default: hf)",
)
parser.add_argument(
"--device",
type=str,
default="cpu",
choices=["cpu", "auto"],
help="Device override for HF models (default: cpu)",
)
parser.add_argument(
"--max-new-tokens",
type=int,
default=16,
help="Max new tokens per generation (default: 16)",
)
parser.add_argument(
"--temperature",
type=float,
default=0.2,
help="Sampling temperature (default: 0.2)",
)
parser.add_argument(
"--top-p",
type=float,
default=0.8,
help="Nucleus sampling top-p (default: 0.8)",
)
parser.add_argument(
"--qwen-thinking",
action="store_true",
help="Append /think to prompts for Qwen models",
)
parser.add_argument(
"--no-max-new-tokens",
action="store_true",
help="Do not set max_new_tokens in generation kwargs",
)
parser.add_argument(
"--per-call-timeout",
type=int,
default=None,
help="Optional timeout (seconds) per generation call; if hit, moves to next prompt",
)
parser.add_argument(
"--stream",
action="store_true",
help="Stream generated text to stdout during generation",
)
parser.add_argument(
"--datasets",
type=str,
default=None,
help=(
"Comma-separated subset of datasets to run. Options: gpqa_bm25,gpqa_diskann,gpqa_hnsw. "
"Default: all"
),
)
parser.add_argument(
"--verbose",
action="store_true",
help="Enable debug logging and show generation progress",
)
return parser.parse_args()
def main():
args = parse_args()
dataset_map = {
# "gpqa_bm25": Path("prompt_dump_gpqa_bm25.txt"),
# "gpqa_diskann": Path("prompt_dump_gpqa_diskann.txt"),
# "gpqa_hnsw": Path("prompt_dump_gpqa_hnsw.txt"),
# "nq_bm25": Path("prompt_dump_nq_bm25.txt"),
# # "nq_diskann": Path("prompt_dump_nq_diskann.txt"),
# "nq_hnsw": Path("prompt_dump_nq_hnsw.txt"),
"gpqa_bm25": Path("prompt_dump_hotpot_bm25.txt"),
"gpqa_diskann": Path("prompt_dump_hotpot_diskann.txt"),
# "gpqa_hnsw": Path("prompt_dump_hotpot_hnsw.txt"),
# "gpqa_bm25": Path("prompt_dump_trivia_bm25.txt"),
# "gpqa_diskann": Path("prompt_dump_trivia_diskann.txt"),
}
if args.datasets:
selected = [k.strip() for k in args.datasets.split(",") if k.strip()]
invalid = [k for k in selected if k not in dataset_map]
if invalid:
raise SystemExit(f"Invalid dataset names: {invalid}. Valid: {list(dataset_map)}")
dataset_files = [dataset_map[k] for k in selected]
else:
dataset_files = list(dataset_map.values())
generation_kwargs = {
"temperature": args.temperature,
"top_p": args.top_p,
}
if not args.no_max_new_tokens:
generation_kwargs["max_new_tokens"] = args.max_new_tokens
results: dict[str, dict[str, float | int]] = {}
llm_config = {"type": args.llm_type, "model": args.model}
try:
llm = get_llm(llm_config)
except Exception as exc:
print(f"Failed to initialize LLM: {exc}")
raise SystemExit(1) from exc
if args.llm_type == "hf" and hasattr(llm, "model") and args.device == "cpu":
llm.model = llm.model.to("cpu")
if hasattr(llm, "device"):
llm.device = "cpu"
for dataset_path in dataset_files:
print(f"Processing {dataset_path.name}...")
prompts = load_prompts(dataset_path)
if args.max_prompts is not None:
prompts = prompts[: args.max_prompts]
if args.verbose:
print(f"[dbg] loaded_prompts={len(prompts)} (showing up to --max-prompts)")
timings = measure_generation_times(
prompts,
llm,
generation_kwargs,
args.allow_truncation,
args.qwen_thinking,
args.verbose or args.stream,
args.per_call_timeout,
)
total_time = sum(timings)
count = len(timings)
average_time = total_time / count if count else 0.0
results[str(dataset_path.name)] = {
"total_prompts": count,
"total_time_seconds": total_time,
"average_time_seconds": average_time,
}
print(json.dumps(results, indent=2))
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,137 @@
import json
import time

import pytest

from leann.embedding_server_manager import EmbeddingServerManager


class DummyProcess:
    def __init__(self):
        self.pid = 12345
        self._terminated = False

    def poll(self):
        return 0 if self._terminated else None

    def terminate(self):
        self._terminated = True

    def kill(self):
        self._terminated = True

    def wait(self, timeout=None):
        self._terminated = True
        return 0


@pytest.fixture
def embedding_manager(monkeypatch):
    manager = EmbeddingServerManager("leann_backend_hnsw.hnsw_embedding_server")

    def fake_get_available_port(start_port):
        return start_port

    monkeypatch.setattr(
        "leann.embedding_server_manager._get_available_port",
        fake_get_available_port,
    )

    start_calls = []

    def fake_start_new_server(self, port, model_name, embedding_mode, **kwargs):
        config_signature = kwargs.get("config_signature")
        start_calls.append(config_signature)
        self.server_process = DummyProcess()
        self.server_port = port
        self._server_config = config_signature
        return True, port

    monkeypatch.setattr(
        EmbeddingServerManager,
        "_start_new_server",
        fake_start_new_server,
    )

    # Ensure stop_server doesn't try to operate on real subprocesses
    def fake_stop_server(self):
        self.server_process = None
        self.server_port = None
        self._server_config = None

    monkeypatch.setattr(EmbeddingServerManager, "stop_server", fake_stop_server)

    return manager, start_calls


def _write_meta(meta_path, passages_name, index_name, total):
    meta_path.write_text(
        json.dumps(
            {
                "backend_name": "hnsw",
                "embedding_model": "test-model",
                "embedding_mode": "sentence-transformers",
                "dimensions": 3,
                "backend_kwargs": {},
                "passage_sources": [
                    {
                        "type": "jsonl",
                        "path": passages_name,
                        "index_path": index_name,
                    }
                ],
                "total_passages": total,
            }
        ),
        encoding="utf-8",
    )


def test_server_restarts_when_metadata_changes(tmp_path, embedding_manager):
    manager, start_calls = embedding_manager

    meta_path = tmp_path / "example.meta.json"
    passages_path = tmp_path / "example.passages.jsonl"
    index_path = tmp_path / "example.passages.idx"

    passages_path.write_text("first\n", encoding="utf-8")
    index_path.write_bytes(b"index")
    _write_meta(meta_path, passages_path.name, index_path.name, total=1)

    # Initial start populates signature
    ok, port = manager.start_server(
        port=6000,
        model_name="test-model",
        passages_file=str(meta_path),
    )
    assert ok
    assert port == 6000
    assert len(start_calls) == 1
    initial_signature = start_calls[0]["passages_signature"]

    # No metadata change => reuse existing server
    ok, port_again = manager.start_server(
        port=6000,
        model_name="test-model",
        passages_file=str(meta_path),
    )
    assert ok
    assert port_again == 6000
    assert len(start_calls) == 1

    # Modify passage data and metadata to force signature change
    time.sleep(0.01)  # Ensure filesystem timestamps move forward
    passages_path.write_text("second\n", encoding="utf-8")
    _write_meta(meta_path, passages_path.name, index_path.name, total=2)

    ok, port_third = manager.start_server(
        port=6000,
        model_name="test-model",
        passages_file=str(meta_path),
    )
    assert ok
    assert port_third == 6000
    assert len(start_calls) == 2
    updated_signature = start_calls[1]["passages_signature"]
    assert updated_signature != initial_signature