Compare commits

..

11 Commits

Author SHA1 Message Date
Andy Lee
fd5c052bd8 Update faiss for batch distances calc & caching when updating 2025-09-30 12:40:05 -07:00
Andy Lee
2f77d0185c Merge remote-tracking branch 'origin/main' into fix-update 2025-09-30 00:56:27 -07:00
Andy Lee
82d536b2ae fix: launch embedding server before adding 2025-09-30 00:53:22 -07:00
yichuan520030910320
e2b37914ce add dynamic add test 2025-09-30 00:48:46 -07:00
Andy Lee
e588100674 fix: set ntotal for storage as well (#129) 2025-09-29 20:43:16 -07:00
Andy Lee
f42e086383 fix: set ntotal for storage as well 2025-09-29 19:10:09 -07:00
Andy Lee
fecee94af1 Experiments (#68)
* feat: finance bench

* docs: results

* chore: ignroe data README

* feat: fix financebench

* feat: laion, also required idmaps support

* style: format

* style: format

* fix: resolve ruff linting errors

- Remove unused variables in benchmark scripts
- Rename unused loop variables to follow convention

* feat: enron email bench

* experiments for running DiskANN & BM25 on Arch 4090

* style: format

* chore(ci): remove paru-bin submodule and config to fix checkout --recurse-submodules

* docs: data

* docs: data updated

* fix: as package

* fix(ci): only run pre-commit

* chore: use http url of astchunk; use group for some dev deps

* fix(ci): should checkout modules as well since `uv sync` checks

* fix(ci): run with lint only

* fix: find links to install wheels available

* CI: force local wheels in uv install step

* CI: install local wheels via file paths

* CI: pick wheels matching current Python tag

* CI: handle python tag mismatches for local wheels

* CI: use matrix python venv and set macOS deployment target

* CI: revert install step to match main

* CI: use uv group install with local wheel selection

* CI: rely on setup-uv for Python and tighten group install

* CI: install build deps with uv python interpreter

* CI: use temporary uv venv for build deps

* CI: add build venv scripts path for wheel repair
2025-09-24 11:19:04 -07:00
yichuan520030910320
01475c10a0 add img 2025-09-23 23:25:05 -07:00
yichuan520030910320
c8aa063f48 merge main 2025-09-23 23:21:53 -07:00
yichuan520030910320
576beb13db add doc about multimodal 2025-09-23 23:21:03 -07:00
Andy Lee
63c7b0c8a3 Fix restart embedding server when passages change (#117)
* fix: restart embedding server when passages change

* fix: restore python 3.9 typing compatibility
2025-09-23 22:28:36 -07:00
13 changed files with 593 additions and 504 deletions

6
.gitignore vendored
View File

@@ -99,3 +99,9 @@ benchmarks/data/
## multi vector
apps/multimodal/vision-based-pdf-multi-vector/multi-vector-colpali-native-weaviate.py
# Ignore all PDFs (keep data exceptions above) and do not track demo PDFs
# If you need to commit a specific demo PDF, remove this negation locally.
# The following line used to force-add a large demo PDF; remove it to satisfy pre-commit:
# !apps/multimodal/vision-based-pdf-multi-vector/pdfs/2004.12832v2.pdf
!apps/multimodal/vision-based-pdf-multi-vector/fig/*

View File

@@ -0,0 +1,113 @@
## Vision-based PDF Multi-Vector Demos (macOS/MPS)
This folder contains two demos to index PDF pages as images and run multi-vector retrieval with ColPali/ColQwen2, plus optional similarity map visualization and answer generation.
### What youll run
- `multi-vector-leann-paper-example.py`: local PDF → pages → embed → build HNSW index → search.
- `multi-vector-leann-similarity-map.py`: HF dataset (default) or local pages → embed → index → retrieve → similarity maps → optional Qwen-VL answer.
## Prerequisites (macOS)
### 1) Homebrew poppler (for pdf2image)
```bash
brew install poppler
which pdfinfo && pdfinfo -v
```
### 2) Python environment
Use uv (recommended) or pip. Python 3.9+.
Using uv:
```bash
uv pip install \
colpali_engine \
pdf2image \
pillow \
matplotlib qwen_vl_utils \
einops \
seaborn
```
Notes:
- On first run, models download from Hugging Face. Login/config if needed.
- The scripts auto-select device: CUDA > MPS > CPU. Verify MPS:
```bash
python -c "import torch; print('MPS available:', bool(getattr(torch.backends, 'mps', None) and torch.backends.mps.is_available()))"
```
## Run the demos
### A) Local PDF example
Converts a local PDF into page images, embeds them, builds an index, and searches.
```bash
cd apps/multimodal/vision-based-pdf-multi-vector
# If you don't have the sample PDF locally, download it (ignored by Git)
mkdir -p pdfs
curl -L -o pdfs/2004.12832v2.pdf https://arxiv.org/pdf/2004.12832.pdf
ls pdfs/2004.12832v2.pdf
# Ensure output dir exists
mkdir -p pages
python multi-vector-leann-paper-example.py
```
Expected:
- Page images in `pages/`.
- Console prints like `Using device=mps, dtype=...` and retrieved file paths for queries.
To use your own PDF: edit `pdf_path` near the top of the script.
### B) Similarity map + answer demo
Uses HF dataset `weaviate/arXiv-AI-papers-multi-vector` by default; can switch to local pages.
```bash
cd apps/multimodal/vision-based-pdf-multi-vector
python multi-vector-leann-similarity-map.py
```
Artifacts (when enabled):
- Retrieved pages: `./figures/retrieved_page_rank{K}.png`
- Similarity maps: `./figures/similarity_map_rank{K}.png`
Key knobs in the script (top of file):
- `QUERY`: your question
- `MODEL`: `"colqwen2"` or `"colpali"`
- `USE_HF_DATASET`: set `False` to use local pages
- `PDF`, `PAGES_DIR`: for local mode
- `INDEX_PATH`, `TOPK`, `FIRST_STAGE_K`, `REBUILD_INDEX`
- `SIMILARITY_MAP`, `SIM_TOKEN_IDX`, `SIM_OUTPUT`
- `ANSWER`, `MAX_NEW_TOKENS` (Qwen-VL)
## Troubleshooting
- pdf2image errors on macOS: ensure `brew install poppler` and `pdfinfo` works in terminal.
- Slow or OOM on MPS: reduce dataset size (e.g., set `MAX_DOCS`) or switch to CPU.
- NaNs on MPS: keep fp32 on MPS (default in similarity-map script); avoid fp16 there.
- First-run model downloads can be large; ensure network access (HF mirrors if needed).
## Notes
- Index files are under `./indexes/`. Delete or set `REBUILD_INDEX=True` to rebuild.
- For local PDFs, page images go to `./pages/`.
### Retrieval and Visualization Example
Example settings in `multi-vector-leann-similarity-map.py`:
- `QUERY = "How does DeepSeek-V2 compare against the LLaMA family of LLMs?"`
- `SIMILARITY_MAP = True` (to generate heatmaps)
- `TOPK = 1` (save the top retrieved page and its similarity map)
Run:
```bash
cd apps/multimodal/vision-based-pdf-multi-vector
python multi-vector-leann-similarity-map.py
```
Outputs (by default):
- Retrieved page: `./figures/retrieved_page_rank1.png`
- Similarity map: `./figures/similarity_map_rank1.png`
Sample visualization (example result, and the query is "QUERY = "How does Vim model performance and efficiency compared to other models?"
"):
![Similarity map example](fig/image.png)
Notes:
- Set `SIM_TOKEN_IDX` to visualize a specific token index; set `-1` to auto-select the most salient token.
- If you change `SIM_OUTPUT` to a file path (e.g., `./figures/my_map.png`), multiple ranks are saved as `my_map_rank{K}.png`.

View File

Binary file not shown.

After

Width:  |  Height:  |  Size: 166 KiB

View File

@@ -4,39 +4,24 @@
# pip install tqdm
# pip install pillow
# %%
from pdf2image import convert_from_path
pdf_path = "pdfs/2004.12832v2.pdf"
images = convert_from_path(pdf_path)
for i, image in enumerate(images):
image.save(f"pages/page_{i + 1}.png", "PNG")
# %%
import os
import re
import sys
from pathlib import Path
from typing import cast
# Make local leann packages importable without installing
from PIL import Image
from tqdm import tqdm
# Ensure local leann packages are importable before importing them
_repo_root = Path(__file__).resolve().parents[3]
_leann_core_src = _repo_root / "packages" / "leann-core" / "src"
_leann_hnsw_pkg = _repo_root / "packages" / "leann-backend-hnsw"
import sys
if str(_leann_core_src) not in sys.path:
sys.path.append(str(_leann_core_src))
if str(_leann_hnsw_pkg) not in sys.path:
sys.path.append(str(_leann_hnsw_pkg))
from leann_multi_vector import LeannMultiVector
class LeannRetriever(LeannMultiVector):
pass
# %%
from typing import cast
import torch
from colpali_engine.models import ColPali
@@ -88,13 +73,6 @@ for batch_query in dataloader:
qs.extend(list(torch.unbind(embeddings_query.to("cpu"))))
print(qs[0].shape)
# %%
import re
from PIL import Image
from tqdm import tqdm
page_filenames = sorted(os.listdir("./pages"), key=lambda n: int(re.search(r"\d+", n).group()))
images = [Image.open(os.path.join("./pages", name)) for name in page_filenames]

View File

@@ -169,7 +169,7 @@ def _embed_images(model, processor, images: list[Image.Image]) -> list[Any]:
)
doc_vecs: list[Any] = []
for batch_doc in dataloader:
for batch_doc in tqdm(dataloader, desc="Embedding images"):
with torch.no_grad():
batch_doc = {k: v.to(model.device) for k, v in batch_doc.items()}
# autocast on CUDA for bf16/fp16; on CPU/MPS stay in fp32
@@ -200,7 +200,7 @@ def _embed_queries(model, processor, queries: list[str]) -> list[Any]:
)
q_vecs: list[Any] = []
for batch_query in dataloader:
for batch_query in tqdm(dataloader, desc="Embedding queries"):
with torch.no_grad():
batch_query = {k: v.to(model.device) for k, v in batch_query.items()}
if model.device.type == "cuda":
@@ -362,7 +362,7 @@ if USE_HF_DATASET:
N = len(dataset) if MAX_DOCS is None else min(MAX_DOCS, len(dataset))
filepaths: list[str] = []
images: list[Image.Image] = []
for i in tqdm(range(N), desc="Loading dataset"):
for i in tqdm(range(N), desc="Loading dataset", total=N ):
p = dataset[i]
# Compose a descriptive identifier for printing later
identifier = f"arXiv:{p['paper_arxiv_id']}|title:{p['paper_title']}|page:{int(p['page_number'])}|id:{p['page_id']}"

View File

@@ -43,7 +43,11 @@ from apps.chunking import create_text_chunks
REPO_ROOT = Path(__file__).resolve().parents[1]
DEFAULT_QUERY = "What's LEANN?"
DEFAULT_INITIAL_FILES = [REPO_ROOT / "data" / "2501.14312v1 (1).pdf"]
DEFAULT_INITIAL_FILES = [
REPO_ROOT / "data" / "2501.14312v1 (1).pdf",
REPO_ROOT / "data" / "huawei_pangu.md",
REPO_ROOT / "data" / "PrideandPrejudice.txt",
]
DEFAULT_UPDATE_FILES = [REPO_ROOT / "data" / "2506.08276v1.pdf"]
@@ -182,6 +186,7 @@ def run_workflow(
is_recompute: bool,
query: str,
top_k: int,
skip_search: bool,
) -> dict[str, Any]:
prefix = f"[{label}] " if label else ""
@@ -198,12 +203,15 @@ def run_workflow(
)
initial_size = index_file_size(index_path)
before_results = run_search(
index_path,
query,
top_k,
recompute_embeddings=is_recompute,
)
if not skip_search:
before_results = run_search(
index_path,
query,
top_k,
recompute_embeddings=is_recompute,
)
else:
before_results = None
print(f"\n{prefix}Updating index with additional passages...")
update_index(
@@ -215,20 +223,23 @@ def run_workflow(
is_recompute=is_recompute,
)
after_results = run_search(
index_path,
query,
top_k,
recompute_embeddings=is_recompute,
)
if not skip_search:
after_results = run_search(
index_path,
query,
top_k,
recompute_embeddings=is_recompute,
)
else:
after_results = None
updated_size = index_file_size(index_path)
return {
"initial_size": initial_size,
"updated_size": updated_size,
"delta": updated_size - initial_size,
"before_results": before_results,
"after_results": after_results,
"before_results": before_results if not skip_search else None,
"after_results": after_results if not skip_search else None,
"metadata": load_metadata_snapshot(index_path),
}
@@ -314,6 +325,12 @@ def main() -> None:
action="store_false",
help="Skip building the no-recompute baseline.",
)
parser.add_argument(
"--skip-search",
dest="skip_search",
action="store_true",
help="Skip the search step.",
)
parser.set_defaults(compare_no_recompute=True)
args = parser.parse_args()
@@ -350,10 +367,13 @@ def main() -> None:
is_recompute=True,
query=args.query,
top_k=args.top_k,
skip_search=args.skip_search,
)
print_results("initial search", recompute_stats["before_results"])
print_results("after update", recompute_stats["after_results"])
if not args.skip_search:
print_results("initial search", recompute_stats["before_results"])
if not args.skip_search:
print_results("after update", recompute_stats["after_results"])
print(
f"\n[recompute] Index file size change: {recompute_stats['initial_size']} -> {recompute_stats['updated_size']} bytes"
f"{recompute_stats['delta']})"
@@ -378,6 +398,7 @@ def main() -> None:
is_recompute=False,
query=args.query,
top_k=args.top_k,
skip_search=args.skip_search,
)
print(
@@ -385,8 +406,12 @@ def main() -> None:
f"{baseline_stats['delta']})"
)
after_texts = [res.text for res in recompute_stats["after_results"]]
baseline_after_texts = [res.text for res in baseline_stats["after_results"]]
after_texts = (
[res.text for res in recompute_stats["after_results"]] if not args.skip_search else None
)
baseline_after_texts = (
[res.text for res in baseline_stats["after_results"]] if not args.skip_search else None
)
if after_texts == baseline_after_texts:
print(
"[no-recompute] Search results match recompute baseline; see above for the shared output."

View File

@@ -5,6 +5,7 @@ with the correct, original embedding logic from the user's reference code.
import json
import logging
import os
import pickle
import re
import subprocess
@@ -20,6 +21,7 @@ from leann_backend_hnsw.convert_to_csr import prune_hnsw_embeddings_inplace
from leann.interface import LeannBackendSearcherInterface
from .chat import get_llm
from .embedding_server_manager import EmbeddingServerManager
from .interface import LeannBackendFactoryInterface
from .metadata_filter import MetadataFilterEngine
from .registry import BACKEND_REGISTRY
@@ -728,6 +730,7 @@ class LeannBuilder:
index = faiss.read_index(str(index_file))
if hasattr(index, "is_recompute"):
index.is_recompute = needs_recompute
print(f"index.is_recompute: {index.is_recompute}")
if getattr(index, "storage", None) is None:
if index.metric_type == faiss.METRIC_INNER_PRODUCT:
storage_index = faiss.IndexFlatIP(index.d)
@@ -735,37 +738,107 @@ class LeannBuilder:
storage_index = faiss.IndexFlatL2(index.d)
index.storage = storage_index
index.own_fields = True
# Faiss expects storage.ntotal to reflect the existing graph's
# population (even if the vectors themselves were pruned from disk
# for recompute mode). When we attach a fresh IndexFlat here its
# ntotal starts at zero, which later causes IndexHNSW::add to
# believe new "preset" levels were provided and trips the
# `n0 + n == levels.size()` assertion. Seed the temporary storage
# with the current ntotal so Faiss maintains the proper offset for
# incoming vectors.
try:
storage_index.ntotal = index.ntotal
except AttributeError:
# Older Faiss builds may not expose ntotal as a writable
# attribute; in that case we fall back to the default behaviour.
pass
if index.d != embedding_dim:
raise ValueError(
f"Existing index dimension ({index.d}) does not match new embeddings ({embedding_dim})."
)
passage_meta_mode = meta.get("embedding_mode", self.embedding_mode)
passage_provider_options = meta.get("embedding_options", self.embedding_options)
base_id = index.ntotal
for offset, chunk in enumerate(valid_chunks):
new_id = str(base_id + offset)
chunk.setdefault("metadata", {})["id"] = new_id
chunk["id"] = new_id
index.add(embeddings.shape[0], faiss.swig_ptr(embeddings))
faiss.write_index(index, str(index_file))
# Append passages/offsets before we attempt index.add so the ZMQ server
# can resolve newly assigned IDs during recompute. Keep rollback hooks
# so we can restore files if the update fails mid-way.
rollback_passages_size = passages_file.stat().st_size if passages_file.exists() else 0
offset_map_backup = offset_map.copy()
with open(passages_file, "a", encoding="utf-8") as f:
for chunk in valid_chunks:
offset = f.tell()
json.dump(
{
"id": chunk["id"],
"text": chunk["text"],
"metadata": chunk.get("metadata", {}),
},
f,
ensure_ascii=False,
)
f.write("\n")
offset_map[chunk["id"]] = offset
try:
with open(passages_file, "a", encoding="utf-8") as f:
for chunk in valid_chunks:
offset = f.tell()
json.dump(
{
"id": chunk["id"],
"text": chunk["text"],
"metadata": chunk.get("metadata", {}),
},
f,
ensure_ascii=False,
)
f.write("\n")
offset_map[chunk["id"]] = offset
with open(offset_file, "wb") as f:
pickle.dump(offset_map, f)
with open(offset_file, "wb") as f:
pickle.dump(offset_map, f)
server_manager: Optional[EmbeddingServerManager] = None
server_started = False
requested_zmq_port = int(os.getenv("LEANN_UPDATE_ZMQ_PORT", "5557"))
try:
if needs_recompute:
server_manager = EmbeddingServerManager(
backend_module_name="leann_backend_hnsw.hnsw_embedding_server"
)
server_started, actual_port = server_manager.start_server(
port=requested_zmq_port,
model_name=self.embedding_model,
embedding_mode=passage_meta_mode,
passages_file=str(meta_path),
distance_metric=distance_metric,
provider_options=passage_provider_options,
)
if not server_started:
raise RuntimeError(
"Failed to start HNSW embedding server for recompute update."
)
if actual_port != requested_zmq_port:
server_manager.stop_server()
raise RuntimeError(
"Embedding server started on unexpected port "
f"{actual_port}; expected {requested_zmq_port}. Make sure the desired ZMQ port is free."
)
if needs_recompute:
for i in range(embeddings.shape[0]):
print(f"add {i} embeddings")
index.add(1, faiss.swig_ptr(embeddings[i : i + 1]))
else:
index.add(embeddings.shape[0], faiss.swig_ptr(embeddings))
faiss.write_index(index, str(index_file))
finally:
if server_started and server_manager is not None:
server_manager.stop_server()
except Exception:
# Roll back appended passages/offset map to keep files consistent.
if passages_file.exists():
with open(passages_file, "rb+") as f:
f.truncate(rollback_passages_size)
offset_map = offset_map_backup
with open(offset_file, "wb") as f:
pickle.dump(offset_map, f)
raise
meta["total_passages"] = len(offset_map)
with open(meta_path, "w", encoding="utf-8") as f:

View File

@@ -12,6 +12,8 @@ from typing import Any, Optional
import torch
from .settings import resolve_ollama_host, resolve_openai_api_key, resolve_openai_base_url
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
@@ -310,11 +312,12 @@ def search_hf_models(query: str, limit: int = 10) -> list[str]:
def validate_model_and_suggest(
model_name: str, llm_type: str, host: str = "http://localhost:11434"
model_name: str, llm_type: str, host: Optional[str] = None
) -> Optional[str]:
"""Validate model name and provide suggestions if invalid"""
if llm_type == "ollama":
available_models = check_ollama_models(host)
resolved_host = resolve_ollama_host(host)
available_models = check_ollama_models(resolved_host)
if available_models and model_name not in available_models:
error_msg = f"Model '{model_name}' not found in your local Ollama installation."
@@ -457,19 +460,19 @@ class LLMInterface(ABC):
class OllamaChat(LLMInterface):
"""LLM interface for Ollama models."""
def __init__(self, model: str = "llama3:8b", host: str = "http://localhost:11434"):
def __init__(self, model: str = "llama3:8b", host: Optional[str] = None):
self.model = model
self.host = host
logger.info(f"Initializing OllamaChat with model='{model}' and host='{host}'")
self.host = resolve_ollama_host(host)
logger.info(f"Initializing OllamaChat with model='{model}' and host='{self.host}'")
try:
import requests
# Check if the Ollama server is responsive
if host:
requests.get(host)
if self.host:
requests.get(self.host)
# Pre-check model availability with helpful suggestions
model_error = validate_model_and_suggest(model, "ollama", host)
model_error = validate_model_and_suggest(model, "ollama", self.host)
if model_error:
raise ValueError(model_error)
@@ -478,9 +481,11 @@ class OllamaChat(LLMInterface):
"The 'requests' library is required for Ollama. Please install it with 'pip install requests'."
)
except requests.exceptions.ConnectionError:
logger.error(f"Could not connect to Ollama at {host}. Please ensure Ollama is running.")
logger.error(
f"Could not connect to Ollama at {self.host}. Please ensure Ollama is running."
)
raise ConnectionError(
f"Could not connect to Ollama at {host}. Please ensure Ollama is running."
f"Could not connect to Ollama at {self.host}. Please ensure Ollama is running."
)
def ask(self, prompt: str, **kwargs) -> str:
@@ -577,33 +582,18 @@ class HFChat(LLMInterface):
def timeout_handler(signum, frame):
raise TimeoutError("Model download/loading timed out")
# Set timeout for model loading (increase to 300s for large models)
# Set timeout for model loading (60 seconds)
old_handler = signal.signal(signal.SIGALRM, timeout_handler)
signal.alarm(300)
signal.alarm(60)
try:
logger.info(f"Loading tokenizer for {model_name}...")
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
logger.info(f"Loading model {model_name}...")
# Choose a numerically stable dtype per device
if self.device == "cuda":
# Prefer bfloat16 when available; otherwise fall back to float16
try:
bf16_ok = torch.cuda.is_bf16_supported()
except Exception:
bf16_ok = False
load_dtype = torch.bfloat16 if bf16_ok else torch.float16
elif self.device == "mps":
# On Apple MPS, float16 often causes NaNs/INFs during sampling.
# Use float32 for stability, even if it increases memory.
load_dtype = torch.float32
else:
load_dtype = torch.float32
self.model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=load_dtype,
torch_dtype=torch.float16 if self.device != "cpu" else torch.float32,
device_map="auto" if self.device != "cpu" else None,
trust_remote_code=True,
)
@@ -621,12 +611,8 @@ class HFChat(LLMInterface):
logger.error(f"Failed to load model {model_name}: {e}")
raise
# Move model to device only if not managed by accelerate (no device_map)
try:
has_device_map = getattr(self.model, "hf_device_map", None) is not None
except Exception:
has_device_map = False
if self.device != "cpu" and not has_device_map:
# Move model to device if not using device_map
if self.device != "cpu" and "device_map" not in str(self.model):
self.model = self.model.to(self.device)
# Set pad token if not present
@@ -658,15 +644,13 @@ class HFChat(LLMInterface):
# Fallback for models without chat template
formatted_prompt = prompt
# Tokenize input (respect model context length when available)
# Tokenize input
inputs = self.tokenizer(
formatted_prompt,
return_tensors="pt",
padding=True,
truncation=True,
max_length=getattr(
getattr(self.model, "config", None), "max_position_embeddings", 2048
),
max_length=2048,
)
# Move inputs to device
@@ -681,8 +665,6 @@ class HFChat(LLMInterface):
"do_sample": kwargs.get("temperature", 0.7) > 0,
"pad_token_id": self.tokenizer.eos_token_id,
"eos_token_id": self.tokenizer.eos_token_id,
# Helps avoid numerical issues in sampling when logits processors are used
"renormalize_logits": True,
}
# Handle temperature=0 for greedy decoding
@@ -692,39 +674,11 @@ class HFChat(LLMInterface):
logger.info(f"Generating with HuggingFace model, config: {generation_config}")
# Streaming support (optional)
stream = bool(kwargs.get("stream", False))
if stream:
try:
from threading import Thread
from transformers import TextIteratorStreamer
streamer = TextIteratorStreamer(
self.tokenizer, skip_prompt=True, skip_special_tokens=True
)
def _gen():
with torch.no_grad():
self.model.generate(**inputs, **generation_config, streamer=streamer)
t = Thread(target=_gen)
t.start()
pieces = []
for new_text in streamer:
print(new_text, end="", flush=True)
pieces.append(new_text)
t.join()
print("") # newline after streaming
return ("".join(pieces)).strip()
except Exception as e:
logger.warning(f"Streaming failed, falling back to non-streaming: {e}")
# Non-streaming path
# Generate
with torch.no_grad():
outputs = self.model.generate(**inputs, **generation_config)
# Decode response
generated_tokens = outputs[0][inputs["input_ids"].shape[1] :]
response = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
@@ -788,21 +742,31 @@ class GeminiChat(LLMInterface):
class OpenAIChat(LLMInterface):
"""LLM interface for OpenAI models."""
def __init__(self, model: str = "gpt-4o", api_key: Optional[str] = None):
def __init__(
self,
model: str = "gpt-4o",
api_key: Optional[str] = None,
base_url: Optional[str] = None,
):
self.model = model
self.api_key = api_key or os.getenv("OPENAI_API_KEY")
self.base_url = resolve_openai_base_url(base_url)
self.api_key = resolve_openai_api_key(api_key)
if not self.api_key:
raise ValueError(
"OpenAI API key is required. Set OPENAI_API_KEY environment variable or pass api_key parameter."
)
logger.info(f"Initializing OpenAI Chat with model='{model}'")
logger.info(
"Initializing OpenAI Chat with model='%s' and base_url='%s'",
model,
self.base_url,
)
try:
import openai
self.client = openai.OpenAI(api_key=self.api_key)
self.client = openai.OpenAI(api_key=self.api_key, base_url=self.base_url)
except ImportError:
raise ImportError(
"The 'openai' library is required for OpenAI models. Please install it with 'pip install openai'."
@@ -892,12 +856,16 @@ def get_llm(llm_config: Optional[dict[str, Any]] = None) -> LLMInterface:
if llm_type == "ollama":
return OllamaChat(
model=model or "llama3:8b",
host=llm_config.get("host", "http://localhost:11434"),
host=llm_config.get("host"),
)
elif llm_type == "hf":
return HFChat(model_name=model or "deepseek-ai/deepseek-llm-7b-chat")
elif llm_type == "openai":
return OpenAIChat(model=model or "gpt-4o", api_key=llm_config.get("api_key"))
return OpenAIChat(
model=model or "gpt-4o",
api_key=llm_config.get("api_key"),
base_url=llm_config.get("base_url"),
)
elif llm_type == "gemini":
return GeminiChat(model=model or "gemini-2.5-flash", api_key=llm_config.get("api_key"))
elif llm_type == "simulated":

View File

@@ -1,4 +1,5 @@
import atexit
import json
import logging
import os
import socket
@@ -48,6 +49,85 @@ def _check_port(port: int) -> bool:
# Note: All cross-process scanning helpers removed for simplicity
def _safe_resolve(path: Path) -> str:
"""Resolve paths safely even if the target does not yet exist."""
try:
return str(path.resolve(strict=False))
except Exception:
return str(path)
def _safe_stat_signature(path: Path) -> dict:
"""Return a lightweight signature describing the current state of a path."""
signature: dict[str, object] = {"path": _safe_resolve(path)}
try:
stat = path.stat()
except FileNotFoundError:
signature["missing"] = True
except Exception as exc: # pragma: no cover - unexpected filesystem errors
signature["error"] = str(exc)
else:
signature["mtime_ns"] = stat.st_mtime_ns
signature["size"] = stat.st_size
return signature
def _build_passages_signature(passages_file: Optional[str]) -> Optional[dict]:
"""Collect modification signatures for metadata and referenced passage files."""
if not passages_file:
return None
meta_path = Path(passages_file)
signature: dict[str, object] = {"meta": _safe_stat_signature(meta_path)}
try:
with meta_path.open(encoding="utf-8") as fh:
meta = json.load(fh)
except FileNotFoundError:
signature["meta_missing"] = True
signature["sources"] = []
return signature
except json.JSONDecodeError as exc:
signature["meta_error"] = f"json_error:{exc}"
signature["sources"] = []
return signature
except Exception as exc: # pragma: no cover - unexpected errors
signature["meta_error"] = str(exc)
signature["sources"] = []
return signature
base_dir = meta_path.parent
seen_paths: set[str] = set()
source_signatures: list[dict[str, object]] = []
for source in meta.get("passage_sources", []):
for key, kind in (
("path", "passages"),
("path_relative", "passages"),
("index_path", "index"),
("index_path_relative", "index"),
):
raw_path = source.get(key)
if not raw_path:
continue
candidate = Path(raw_path)
if not candidate.is_absolute():
candidate = base_dir / candidate
resolved = _safe_resolve(candidate)
if resolved in seen_paths:
continue
seen_paths.add(resolved)
sig = _safe_stat_signature(candidate)
sig["kind"] = kind
source_signatures.append(sig)
signature["sources"] = source_signatures
return signature
# Note: All cross-process scanning helpers removed for simplicity
class EmbeddingServerManager:
"""
A simplified manager for embedding server processes that avoids complex update mechanisms.
@@ -85,13 +165,14 @@ class EmbeddingServerManager:
"""Start the embedding server."""
# passages_file may be present in kwargs for server CLI, but we don't need it here
provider_options = kwargs.pop("provider_options", None)
passages_file = kwargs.get("passages_file", "")
config_signature = {
"model_name": model_name,
"passages_file": kwargs.get("passages_file", ""),
"embedding_mode": embedding_mode,
"provider_options": provider_options or {},
}
config_signature = self._build_config_signature(
model_name=model_name,
embedding_mode=embedding_mode,
provider_options=provider_options,
passages_file=passages_file,
)
# If this manager already has a live server, just reuse it
if (
@@ -115,6 +196,7 @@ class EmbeddingServerManager:
port,
model_name,
embedding_mode,
config_signature=config_signature,
provider_options=provider_options,
**kwargs,
)
@@ -136,11 +218,30 @@ class EmbeddingServerManager:
**kwargs,
)
def _build_config_signature(
self,
*,
model_name: str,
embedding_mode: str,
provider_options: Optional[dict],
passages_file: Optional[str],
) -> dict:
"""Create a signature describing the current server configuration."""
return {
"model_name": model_name,
"passages_file": passages_file or "",
"embedding_mode": embedding_mode,
"provider_options": provider_options or {},
"passages_signature": _build_passages_signature(passages_file),
}
def _start_server_colab(
self,
port: int,
model_name: str,
embedding_mode: str = "sentence-transformers",
*,
config_signature: Optional[dict] = None,
provider_options: Optional[dict] = None,
**kwargs,
) -> tuple[bool, int]:
@@ -163,10 +264,11 @@ class EmbeddingServerManager:
command,
actual_port,
provider_options=provider_options,
config_signature=config_signature,
)
started, ready_port = self._wait_for_server_ready_colab(actual_port)
if started:
self._server_config = {
self._server_config = config_signature or {
"model_name": model_name,
"passages_file": kwargs.get("passages_file", ""),
"embedding_mode": embedding_mode,
@@ -198,6 +300,7 @@ class EmbeddingServerManager:
command,
port,
provider_options=provider_options,
config_signature=config_signature,
)
started, ready_port = self._wait_for_server_ready(port)
if started:
@@ -241,7 +344,9 @@ class EmbeddingServerManager:
self,
command: list,
port: int,
*,
provider_options: Optional[dict] = None,
config_signature: Optional[dict] = None,
) -> None:
"""Launch the server process."""
project_root = Path(__file__).parent.parent.parent.parent.parent
@@ -276,26 +381,29 @@ class EmbeddingServerManager:
)
self.server_port = port
# Record config for in-process reuse (best effort; refined later when ready)
try:
self._server_config = {
"model_name": command[command.index("--model-name") + 1]
if "--model-name" in command
else "",
"passages_file": command[command.index("--passages-file") + 1]
if "--passages-file" in command
else "",
"embedding_mode": command[command.index("--embedding-mode") + 1]
if "--embedding-mode" in command
else "sentence-transformers",
"provider_options": provider_options or {},
}
except Exception:
self._server_config = {
"model_name": "",
"passages_file": "",
"embedding_mode": "sentence-transformers",
"provider_options": provider_options or {},
}
if config_signature is not None:
self._server_config = config_signature
else: # Fallback for unexpected code paths
try:
self._server_config = {
"model_name": command[command.index("--model-name") + 1]
if "--model-name" in command
else "",
"passages_file": command[command.index("--passages-file") + 1]
if "--passages-file" in command
else "",
"embedding_mode": command[command.index("--embedding-mode") + 1]
if "--embedding-mode" in command
else "sentence-transformers",
"provider_options": provider_options or {},
}
except Exception:
self._server_config = {
"model_name": "",
"passages_file": "",
"embedding_mode": "sentence-transformers",
"provider_options": provider_options or {},
}
logger.info(f"Server process started with PID: {self.server_process.pid}")
# Register atexit callback only when we actually start a process
@@ -403,7 +511,9 @@ class EmbeddingServerManager:
self,
command: list,
port: int,
*,
provider_options: Optional[dict] = None,
config_signature: Optional[dict] = None,
) -> None:
"""Launch the server process with Colab-specific settings."""
logger.info(f"Colab Command: {' '.join(command)}")
@@ -429,12 +539,15 @@ class EmbeddingServerManager:
atexit.register(self._finalize_process)
self._atexit_registered = True
# Record config for in-process reuse is best-effort in Colab mode
self._server_config = {
"model_name": "",
"passages_file": "",
"embedding_mode": "sentence-transformers",
"provider_options": provider_options or {},
}
if config_signature is not None:
self._server_config = config_signature
else:
self._server_config = {
"model_name": "",
"passages_file": "",
"embedding_mode": "sentence-transformers",
"provider_options": provider_options or {},
}
def _wait_for_server_ready_colab(self, port: int) -> tuple[bool, int]:
"""Wait for the server to be ready with Colab-specific timeout."""

View File

@@ -111,7 +111,7 @@ target-version = "py39"
line-length = 100
extend-exclude = [
"third_party",
"apps/multimodal/vision-based-pdf-multi-vector/multi-vector-leann.py",
"apps/multimodal/vision-based-pdf-multi-vector/multi-vector-leann-paper-example.py",
"apps/multimodal/vision-based-pdf-multi-vector/multi-vector-leann-similarity-map.py"
]

View File

@@ -1,324 +0,0 @@
#!/usr/bin/env python3
"""Measure generation latency of a HuggingFace/OpenAI-compatible model over prompt files."""
import argparse
import contextlib
import io
import json
import logging
import time
from pathlib import Path
from leann.chat import get_llm
PROMPT_PREFIX = "PROMPT #"
logging.getLogger("leann.chat").setLevel(logging.ERROR)
def load_prompts(path: Path) -> list[str]:
prompts: list[str] = []
buffer: list[str] = []
collecting = False
with path.open("r", encoding="utf-8") as handle:
for line in handle:
if line.startswith(PROMPT_PREFIX):
if buffer:
prompts.append("".join(buffer).strip())
buffer.clear()
collecting = True
continue
if collecting:
buffer.append(line)
if buffer:
prompts.append("".join(buffer).strip())
return prompts
def measure_generation_times(
prompts: list[str],
llm,
generation_kwargs: dict[str, object],
allow_truncation: bool,
enable_qwen_thinking: bool,
verbose: bool,
per_call_timeout: int | None,
):
timings: list[float] = []
tokenizer = getattr(llm, "tokenizer", None)
max_positions = None
if hasattr(llm, "model") and hasattr(llm.model, "config"):
max_positions = getattr(llm.model.config, "max_position_embeddings", None)
requested_new_tokens = None
if max_positions is not None:
if "max_new_tokens" in generation_kwargs:
requested_new_tokens = generation_kwargs["max_new_tokens"]
elif "max_tokens" in generation_kwargs:
requested_new_tokens = generation_kwargs["max_tokens"]
context_max_length = max_positions
if max_positions is not None and requested_new_tokens is not None:
if requested_new_tokens >= max_positions:
requested_new_tokens = max_positions - 1
context_max_length = max(max_positions - requested_new_tokens, 1)
suppress_buffer = io.StringIO()
# Log base config
if verbose:
device = getattr(llm, "device", None)
try:
dtype = getattr(getattr(llm, "model", None), "dtype", None)
except Exception:
dtype = None
print(
f"[dbg] device={device} dtype={dtype} max_positions={max_positions} requested_new_tokens={requested_new_tokens} context_max_length={context_max_length}"
)
total = len(prompts)
for idx, prompt in enumerate(prompts, start=1):
prompt_for_llm = prompt
if (
enable_qwen_thinking
and "/think" not in prompt_for_llm
and "/no_think" not in prompt_for_llm
):
prompt_for_llm = f"{prompt_for_llm}\n/think"
if allow_truncation and tokenizer is not None and max_positions is not None:
tokenized = tokenizer(
prompt_for_llm,
truncation=True,
max_length=context_max_length,
return_tensors="pt",
)
prompt_for_llm = tokenizer.decode(tokenized["input_ids"][0], skip_special_tokens=True)
per_call_kwargs = dict(generation_kwargs)
if requested_new_tokens is not None:
per_call_kwargs["max_new_tokens"] = requested_new_tokens
# Enable streaming if requested (HF backend will print tokens)
if verbose:
# When verbose (or --stream propagated), enable streaming in HF backend
per_call_kwargs["stream"] = True
# Extra debug info about token lengths
if verbose and tokenizer is not None:
try:
toks = tokenizer(prompt_for_llm, return_tensors=None, truncation=False)
in_len = (
len(toks["input_ids"])
if isinstance(toks["input_ids"], list)
else len(toks["input_ids"][0])
)
except Exception:
in_len = None
print(f"[dbg] prompt {idx}/{total} tokens={in_len}")
print(
f"[dbg] gen_cfg={{max_new_tokens:{per_call_kwargs.get('max_new_tokens')}, temp:{per_call_kwargs.get('temperature')}, top_p:{per_call_kwargs.get('top_p')}}}"
)
start = time.perf_counter()
# Optional per-call timeout using signal alarm
timeout_handler_installed = False
if per_call_timeout is not None:
import signal
def _timeout_handler(signum, frame):
raise TimeoutError("generation timed out")
old_handler = signal.signal(signal.SIGALRM, _timeout_handler)
signal.alarm(int(per_call_timeout))
timeout_handler_installed = True
try:
if verbose:
print("[dbg] generation_start")
llm.ask(prompt_for_llm, **per_call_kwargs)
print("[dbg] generation_done")
else:
with contextlib.redirect_stdout(suppress_buffer):
llm.ask(prompt_for_llm, **per_call_kwargs)
except TimeoutError:
if verbose:
print("[dbg] generation_timeout")
finally:
if timeout_handler_installed:
import signal
signal.alarm(0)
signal.signal(signal.SIGALRM, old_handler)
end = time.perf_counter()
timings.append(end - start)
suppress_buffer.seek(0)
suppress_buffer.truncate(0)
return timings
def parse_args():
parser = argparse.ArgumentParser(description="Measure generation timing for prompt files")
parser.add_argument(
"--max-prompts",
type=int,
default=None,
help="Optional limit on number of prompts to evaluate per file",
)
parser.add_argument(
"--allow-truncation",
action="store_true",
help="Allow truncating prompt context to respect model's max context",
)
parser.add_argument(
"--model",
type=str,
default="sshleifer/tiny-gpt2",
help="LLM model identifier (default: sshleifer/tiny-gpt2)",
)
parser.add_argument(
"--llm-type",
type=str,
default="hf",
choices=["hf", "openai", "ollama", "gemini", "simulated"],
help="LLM backend type (default: hf)",
)
parser.add_argument(
"--device",
type=str,
default="cpu",
choices=["cpu", "auto"],
help="Device override for HF models (default: cpu)",
)
parser.add_argument(
"--max-new-tokens",
type=int,
default=16,
help="Max new tokens per generation (default: 16)",
)
parser.add_argument(
"--temperature",
type=float,
default=0.2,
help="Sampling temperature (default: 0.2)",
)
parser.add_argument(
"--top-p",
type=float,
default=0.8,
help="Nucleus sampling top-p (default: 0.8)",
)
parser.add_argument(
"--qwen-thinking",
action="store_true",
help="Append /think to prompts for Qwen models",
)
parser.add_argument(
"--no-max-new-tokens",
action="store_true",
help="Do not set max_new_tokens in generation kwargs",
)
parser.add_argument(
"--per-call-timeout",
type=int,
default=None,
help="Optional timeout (seconds) per generation call; if hit, moves to next prompt",
)
parser.add_argument(
"--stream",
action="store_true",
help="Stream generated text to stdout during generation",
)
parser.add_argument(
"--datasets",
type=str,
default=None,
help=(
"Comma-separated subset of datasets to run. Options: gpqa_bm25,gpqa_diskann,gpqa_hnsw. "
"Default: all"
),
)
parser.add_argument(
"--verbose",
action="store_true",
help="Enable debug logging and show generation progress",
)
return parser.parse_args()
def main():
args = parse_args()
dataset_map = {
# "gpqa_bm25": Path("prompt_dump_gpqa_bm25.txt"),
# "gpqa_diskann": Path("prompt_dump_gpqa_diskann.txt"),
# "gpqa_hnsw": Path("prompt_dump_gpqa_hnsw.txt"),
# "nq_bm25": Path("prompt_dump_nq_bm25.txt"),
# # "nq_diskann": Path("prompt_dump_nq_diskann.txt"),
# "nq_hnsw": Path("prompt_dump_nq_hnsw.txt"),
"gpqa_bm25": Path("prompt_dump_hotpot_bm25.txt"),
"gpqa_diskann": Path("prompt_dump_hotpot_diskann.txt"),
# "gpqa_hnsw": Path("prompt_dump_hotpot_hnsw.txt"),
# "gpqa_bm25": Path("prompt_dump_trivia_bm25.txt"),
# "gpqa_diskann": Path("prompt_dump_trivia_diskann.txt"),
}
if args.datasets:
selected = [k.strip() for k in args.datasets.split(",") if k.strip()]
invalid = [k for k in selected if k not in dataset_map]
if invalid:
raise SystemExit(f"Invalid dataset names: {invalid}. Valid: {list(dataset_map)}")
dataset_files = [dataset_map[k] for k in selected]
else:
dataset_files = list(dataset_map.values())
generation_kwargs = {
"temperature": args.temperature,
"top_p": args.top_p,
}
if not args.no_max_new_tokens:
generation_kwargs["max_new_tokens"] = args.max_new_tokens
results: dict[str, dict[str, float | int]] = {}
llm_config = {"type": args.llm_type, "model": args.model}
try:
llm = get_llm(llm_config)
except Exception as exc:
print(f"Failed to initialize LLM: {exc}")
raise SystemExit(1) from exc
if args.llm_type == "hf" and hasattr(llm, "model") and args.device == "cpu":
llm.model = llm.model.to("cpu")
if hasattr(llm, "device"):
llm.device = "cpu"
for dataset_path in dataset_files:
print(f"Processing {dataset_path.name}...")
prompts = load_prompts(dataset_path)
if args.max_prompts is not None:
prompts = prompts[: args.max_prompts]
if args.verbose:
print(f"[dbg] loaded_prompts={len(prompts)} (showing up to --max-prompts)")
timings = measure_generation_times(
prompts,
llm,
generation_kwargs,
args.allow_truncation,
args.qwen_thinking,
args.verbose or args.stream,
args.per_call_timeout,
)
total_time = sum(timings)
count = len(timings)
average_time = total_time / count if count else 0.0
results[str(dataset_path.name)] = {
"total_prompts": count,
"total_time_seconds": total_time,
"average_time_seconds": average_time,
}
print(json.dumps(results, indent=2))
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,137 @@
import json
import time
import pytest
from leann.embedding_server_manager import EmbeddingServerManager
class DummyProcess:
def __init__(self):
self.pid = 12345
self._terminated = False
def poll(self):
return 0 if self._terminated else None
def terminate(self):
self._terminated = True
def kill(self):
self._terminated = True
def wait(self, timeout=None):
self._terminated = True
return 0
@pytest.fixture
def embedding_manager(monkeypatch):
manager = EmbeddingServerManager("leann_backend_hnsw.hnsw_embedding_server")
def fake_get_available_port(start_port):
return start_port
monkeypatch.setattr(
"leann.embedding_server_manager._get_available_port",
fake_get_available_port,
)
start_calls = []
def fake_start_new_server(self, port, model_name, embedding_mode, **kwargs):
config_signature = kwargs.get("config_signature")
start_calls.append(config_signature)
self.server_process = DummyProcess()
self.server_port = port
self._server_config = config_signature
return True, port
monkeypatch.setattr(
EmbeddingServerManager,
"_start_new_server",
fake_start_new_server,
)
# Ensure stop_server doesn't try to operate on real subprocesses
def fake_stop_server(self):
self.server_process = None
self.server_port = None
self._server_config = None
monkeypatch.setattr(EmbeddingServerManager, "stop_server", fake_stop_server)
return manager, start_calls
def _write_meta(meta_path, passages_name, index_name, total):
meta_path.write_text(
json.dumps(
{
"backend_name": "hnsw",
"embedding_model": "test-model",
"embedding_mode": "sentence-transformers",
"dimensions": 3,
"backend_kwargs": {},
"passage_sources": [
{
"type": "jsonl",
"path": passages_name,
"index_path": index_name,
}
],
"total_passages": total,
}
),
encoding="utf-8",
)
def test_server_restarts_when_metadata_changes(tmp_path, embedding_manager):
manager, start_calls = embedding_manager
meta_path = tmp_path / "example.meta.json"
passages_path = tmp_path / "example.passages.jsonl"
index_path = tmp_path / "example.passages.idx"
passages_path.write_text("first\n", encoding="utf-8")
index_path.write_bytes(b"index")
_write_meta(meta_path, passages_path.name, index_path.name, total=1)
# Initial start populates signature
ok, port = manager.start_server(
port=6000,
model_name="test-model",
passages_file=str(meta_path),
)
assert ok
assert port == 6000
assert len(start_calls) == 1
initial_signature = start_calls[0]["passages_signature"]
# No metadata change => reuse existing server
ok, port_again = manager.start_server(
port=6000,
model_name="test-model",
passages_file=str(meta_path),
)
assert ok
assert port_again == 6000
assert len(start_calls) == 1
# Modify passage data and metadata to force signature change
time.sleep(0.01) # Ensure filesystem timestamps move forward
passages_path.write_text("second\n", encoding="utf-8")
_write_meta(meta_path, passages_path.name, index_path.name, total=2)
ok, port_third = manager.start_server(
port=6000,
model_name="test-model",
passages_file=str(meta_path),
)
assert ok
assert port_third == 6000
assert len(start_calls) == 2
updated_signature = start_calls[1]["passages_signature"]
assert updated_signature != initial_signature