Compare commits

..

36 Commits

Author SHA1 Message Date
Andy Lee  4e9e2f3da0  CI: add build venv scripts path for wheel repair  2025-09-24 01:39:35 -07:00
Andy Lee  ed167f43b0  CI: use temporary uv venv for build deps  2025-09-24 01:23:16 -07:00
Andy Lee  f9746d3fe2  CI: install build deps with uv python interpreter  2025-09-24 01:19:19 -07:00
Andy Lee  a090a3444a  CI: rely on setup-uv for Python and tighten group install  2025-09-24 01:14:54 -07:00
Andy Lee  aaaba27a4f  CI: use uv group install with local wheel selection  2025-09-24 01:10:16 -07:00
Andy Lee  f40f539456  CI: revert install step to match main  2025-09-24 00:50:27 -07:00
Andy Lee  576a2dcb49  CI: use matrix python venv and set macOS deployment target  2025-09-24 00:48:27 -07:00
Andy Lee  ad8ab84675  CI: handle python tag mismatches for local wheels  2025-09-23 23:24:02 -07:00
Andy Lee  58b96b64d8  CI: pick wheels matching current Python tag  2025-09-23 23:05:32 -07:00
Andy Lee  a76c3cdac4  CI: install local wheels via file paths  2025-09-23 22:53:44 -07:00
Andy Lee  520619deab  CI: force local wheels in uv install step  2025-09-23 22:27:31 -07:00
Andy Lee  dea08c95b4  Merge remote-tracking branch 'origin/main' into financebench  2025-09-23 21:52:14 -07:00
Andy Lee  3357d5765e  fix: find links to install wheels available  2025-09-15 22:22:38 -07:00
Andy Lee  9dbd0c64cc  fix(ci): run with lint only  2025-09-15 21:55:19 -07:00
Andy Lee  9c400acd7e  fix(ci): should checkout modules as well since uv sync checks  2025-09-15 21:40:35 -07:00
Andy Lee  ac560964f5  chore: use http url of astchunk; use group for some dev deps  2025-09-15 21:21:09 -07:00
Andy Lee  07e4f176e1  fix(ci): only run pre-commit  2025-09-15 19:57:56 -07:00
Andy Lee  b1daf021e0  Merge remote-tracking branch 'origin/main' into financebench  2025-09-15 19:52:37 -07:00
Andy Lee  3578680cb6  fix: as package  2025-09-15 19:50:45 -07:00
Andy Lee  a0d6857faa  docs: data updated  2025-09-15 19:50:02 -07:00
Andy Lee  d7011bbea0  docs: data  2025-08-25 16:25:59 -07:00
Andy Lee  ef4c69d128  chore(ci): remove paru-bin submodule and config to fix checkout --recurse-submodules  2025-08-25 16:08:16 -07:00
Andy Lee  75c8aeee5f  style: format  2025-08-25 15:48:04 -07:00
Andy Lee  3d79741f9c  experiments for running DiskANN & BM25 on Arch 4090  2025-08-25 15:46:48 -07:00
Andy Lee  df34c84bd3  feat: enron email bench  2025-08-24 23:06:57 -07:00
Andy Lee  8dfd2f015c  fix: resolve ruff linting errors  2025-08-22 13:53:30 -07:00
          - Remove unused variables in benchmark scripts
          - Rename unused loop variables to follow convention
Andy Lee  ed72232bab  style: format  2025-08-22 13:51:10 -07:00
Andy Lee  26d961bfc5  style: format  2025-08-22 13:44:26 -07:00
Andy Lee  722bda4ebb  Merge remote-tracking branch 'origin/main' into financebench  2025-08-22 13:39:08 -07:00
Andy Lee  a7c7e8801d  feat: laion, also required idmaps support  2025-08-22 13:32:33 -07:00
Andy Lee  069bce558b  feat: fix financebench  2025-08-22 13:32:23 -07:00
Andy Lee  772894012e  Merge branch 'main' into financebench  2025-08-20 20:40:27 -07:00
Andy Lee  5c163737c4  Merge remote-tracking branch 'origin/main' into financebench  2025-08-17 11:58:34 -07:00
Andy Lee  6d1d67ead7  chore: ignroe data README  2025-08-17 11:58:32 -07:00
Andy Lee  ed27ea6990  docs: results  2025-08-16 16:48:01 -07:00
Andy Lee  baf2d76e0e  feat: finance bench  2025-08-16 16:22:50 -07:00
11 changed files with 104 additions and 549 deletions

.gitignore (vendored): 6 lines changed

@@ -99,9 +99,3 @@ benchmarks/data/
 ## multi vector
 apps/multimodal/vision-based-pdf-multi-vector/multi-vector-colpali-native-weaviate.py
-# Ignore all PDFs (keep data exceptions above) and do not track demo PDFs
-# If you need to commit a specific demo PDF, remove this negation locally.
-# The following line used to force-add a large demo PDF; remove it to satisfy pre-commit:
-# !apps/multimodal/vision-based-pdf-multi-vector/pdfs/2004.12832v2.pdf
-!apps/multimodal/vision-based-pdf-multi-vector/fig/*

----------------------------------------

@@ -1,113 +0,0 @@
## Vision-based PDF Multi-Vector Demos (macOS/MPS)
This folder contains two demos to index PDF pages as images and run multi-vector retrieval with ColPali/ColQwen2, plus optional similarity map visualization and answer generation.
### What you'll run
- `multi-vector-leann-paper-example.py`: local PDF → pages → embed → build HNSW index → search.
- `multi-vector-leann-similarity-map.py`: HF dataset (default) or local pages → embed → index → retrieve → similarity maps → optional Qwen-VL answer.
## Prerequisites (macOS)
### 1) Homebrew poppler (for pdf2image)
```bash
brew install poppler
which pdfinfo && pdfinfo -v
```
### 2) Python environment
Use uv (recommended) or pip. Python 3.9+.
Using uv:
```bash
uv pip install \
  colpali_engine \
  pdf2image \
  pillow \
  matplotlib \
  qwen_vl_utils \
  einops \
  seaborn
```
Notes:
- On first run, models download from Hugging Face. Login/config if needed.
- The scripts auto-select device: CUDA > MPS > CPU. Verify MPS:
```bash
python -c "import torch; print('MPS available:', bool(getattr(torch.backends, 'mps', None) and torch.backends.mps.is_available()))"
```
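For reference, the auto-selection corresponds to this common pattern (a minimal sketch of the idea, not the scripts' exact code; names are illustrative):
```python
import torch

# Preference order used by the demos: CUDA > MPS > CPU.
if torch.cuda.is_available():
    device, dtype = "cuda", torch.bfloat16   # autocast-friendly on CUDA
elif getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
    device, dtype = "mps", torch.float32     # keep fp32 on MPS (see Troubleshooting)
else:
    device, dtype = "cpu", torch.float32
print(f"Using device={device}, dtype={dtype}")
```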
## Run the demos
### A) Local PDF example
Converts a local PDF into page images, embeds them, builds an index, and searches.
```bash
cd apps/multimodal/vision-based-pdf-multi-vector
# If you don't have the sample PDF locally, download it (ignored by Git)
mkdir -p pdfs
curl -L -o pdfs/2004.12832v2.pdf https://arxiv.org/pdf/2004.12832.pdf
ls pdfs/2004.12832v2.pdf
# Ensure output dir exists
mkdir -p pages
python multi-vector-leann-paper-example.py
```
Expected:
- Page images in `pages/`.
- Console prints like `Using device=mps, dtype=...` and retrieved file paths for queries.
To use your own PDF: edit `pdf_path` near the top of the script.
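For example (the filename is a placeholder):
```python
# Near the top of multi-vector-leann-paper-example.py
pdf_path = "pdfs/my-paper.pdf"  # placeholder; point this at your own PDF
```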
### B) Similarity map + answer demo
Uses HF dataset `weaviate/arXiv-AI-papers-multi-vector` by default; can switch to local pages.
```bash
cd apps/multimodal/vision-based-pdf-multi-vector
python multi-vector-leann-similarity-map.py
```
Artifacts (when enabled):
- Retrieved pages: `./figures/retrieved_page_rank{K}.png`
- Similarity maps: `./figures/similarity_map_rank{K}.png`
Key knobs in the script (top of file):
- `QUERY`: your question
- `MODEL`: `"colqwen2"` or `"colpali"`
- `USE_HF_DATASET`: set `False` to use local pages
- `PDF`, `PAGES_DIR`: for local mode
- `INDEX_PATH`, `TOPK`, `FIRST_STAGE_K`, `REBUILD_INDEX`
- `SIMILARITY_MAP`, `SIM_TOKEN_IDX`, `SIM_OUTPUT`
- `ANSWER`, `MAX_NEW_TOKENS` (Qwen-VL)
## Troubleshooting
- pdf2image errors on macOS: ensure `brew install poppler` and `pdfinfo` works in terminal.
- Slow or OOM on MPS: reduce dataset size (e.g., set `MAX_DOCS`) or switch to CPU.
- NaNs on MPS: keep fp32 on MPS (the default in the similarity-map script); avoid fp16 there (see the sketch below).
- First-run model downloads can be large; ensure network access (HF mirrors if needed).
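A minimal illustration of the fp32-on-MPS rule (assumes only that torch is installed; this is not code from the demos):
```python
import torch

device = "mps" if torch.backends.mps.is_available() else "cpu"
x = torch.randn(4, 4, device=device, dtype=torch.float32)  # fp32 is safe on MPS
# Avoid casting to fp16 on MPS (e.g., x.half()); it can yield NaNs in these demos.
print(x.mean())
```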
## Notes
- Index files are under `./indexes/`. Delete or set `REBUILD_INDEX=True` to rebuild.
- For local PDFs, page images go to `./pages/`.
### Retrieval and Visualization Example
Example settings in `multi-vector-leann-similarity-map.py`:
- `QUERY = "How does DeepSeek-V2 compare against the LLaMA family of LLMs?"`
- `SIMILARITY_MAP = True` (to generate heatmaps)
- `TOPK = 1` (save the top retrieved page and its similarity map)
Run:
```bash
cd apps/multimodal/vision-based-pdf-multi-vector
python multi-vector-leann-similarity-map.py
```
Outputs (by default):
- Retrieved page: `./figures/retrieved_page_rank1.png`
- Similarity map: `./figures/similarity_map_rank1.png`
Sample visualization (example result; the query here is "How does Vim model performance and efficiency compared to other models?"):
![Similarity map example](fig/image.png)
Notes:
- Set `SIM_TOKEN_IDX` to visualize a specific token index; set `-1` to auto-select the most salient token.
- If you change `SIM_OUTPUT` to a file path (e.g., `./figures/my_map.png`), multiple ranks are saved as `my_map_rank{K}.png`.
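Putting the example settings above together (values taken from this README; a starting point, not necessarily the script's defaults):
```python
# Top-of-file knobs in multi-vector-leann-similarity-map.py
QUERY = "How does DeepSeek-V2 compare against the LLaMA family of LLMs?"
MODEL = "colqwen2"        # or "colpali"
USE_HF_DATASET = True     # False => use local PDF / PAGES_DIR
TOPK = 1                  # save the top retrieved page and its similarity map
SIMILARITY_MAP = True     # write heatmaps under ./figures/
SIM_TOKEN_IDX = -1        # -1 auto-selects the most salient query token
```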

----------------------------------------

Binary file not shown (image; previous size 166 KiB).

----------------------------------------

@@ -169,7 +169,7 @@ def _embed_images(model, processor, images: list[Image.Image]) -> list[Any]:
     )
     doc_vecs: list[Any] = []
-    for batch_doc in tqdm(dataloader, desc="Embedding images"):
+    for batch_doc in dataloader:
         with torch.no_grad():
             batch_doc = {k: v.to(model.device) for k, v in batch_doc.items()}
             # autocast on CUDA for bf16/fp16; on CPU/MPS stay in fp32
@@ -200,7 +200,7 @@ def _embed_queries(model, processor, queries: list[str]) -> list[Any]:
     )
     q_vecs: list[Any] = []
-    for batch_query in tqdm(dataloader, desc="Embedding queries"):
+    for batch_query in dataloader:
         with torch.no_grad():
             batch_query = {k: v.to(model.device) for k, v in batch_query.items()}
             if model.device.type == "cuda":
@@ -362,7 +362,7 @@ if USE_HF_DATASET:
     N = len(dataset) if MAX_DOCS is None else min(MAX_DOCS, len(dataset))
     filepaths: list[str] = []
     images: list[Image.Image] = []
-    for i in tqdm(range(N), desc="Loading dataset", total=N):
+    for i in tqdm(range(N), desc="Loading dataset"):
         p = dataset[i]
         # Compose a descriptive identifier for printing later
         identifier = f"arXiv:{p['paper_arxiv_id']}|title:{p['paper_title']}|page:{int(p['page_number'])}|id:{p['page_id']}"

----------------------------------------

@@ -4,24 +4,39 @@
 # pip install tqdm
 # pip install pillow
+# %%
+from pdf2image import convert_from_path
+
+pdf_path = "pdfs/2004.12832v2.pdf"
+images = convert_from_path(pdf_path)
+for i, image in enumerate(images):
+    image.save(f"pages/page_{i + 1}.png", "PNG")
+
+# %%
 import os
+import re
+import sys
 from pathlib import Path
-from typing import cast
-from PIL import Image
-from tqdm import tqdm
-# Ensure local leann packages are importable before importing them
+
+# Make local leann packages importable without installing
 _repo_root = Path(__file__).resolve().parents[3]
 _leann_core_src = _repo_root / "packages" / "leann-core" / "src"
 _leann_hnsw_pkg = _repo_root / "packages" / "leann-backend-hnsw"
-import sys
 if str(_leann_core_src) not in sys.path:
     sys.path.append(str(_leann_core_src))
 if str(_leann_hnsw_pkg) not in sys.path:
     sys.path.append(str(_leann_hnsw_pkg))
+
+from leann_multi_vector import LeannMultiVector
+
+
+class LeannRetriever(LeannMultiVector):
+    pass
+
+
+# %%
+from typing import cast
+
 import torch
 from colpali_engine.models import ColPali
@@ -73,6 +88,13 @@ for batch_query in dataloader:
     qs.extend(list(torch.unbind(embeddings_query.to("cpu"))))
 print(qs[0].shape)
 # %%
+import re
+
+from PIL import Image
+from tqdm import tqdm
+
 page_filenames = sorted(os.listdir("./pages"), key=lambda n: int(re.search(r"\d+", n).group()))
 images = [Image.open(os.path.join("./pages", name)) for name in page_filenames]

----------------------------------------

@@ -43,11 +43,7 @@ from apps.chunking import create_text_chunks
 REPO_ROOT = Path(__file__).resolve().parents[1]

 DEFAULT_QUERY = "What's LEANN?"
-DEFAULT_INITIAL_FILES = [
-    REPO_ROOT / "data" / "2501.14312v1 (1).pdf",
-    REPO_ROOT / "data" / "huawei_pangu.md",
-    REPO_ROOT / "data" / "PrideandPrejudice.txt",
-]
+DEFAULT_INITIAL_FILES = [REPO_ROOT / "data" / "2501.14312v1 (1).pdf"]
 DEFAULT_UPDATE_FILES = [REPO_ROOT / "data" / "2506.08276v1.pdf"]
@@ -186,7 +182,6 @@ def run_workflow(
     is_recompute: bool,
     query: str,
     top_k: int,
-    skip_search: bool,
 ) -> dict[str, Any]:
     prefix = f"[{label}] " if label else ""
@@ -203,15 +198,12 @@
     )
     initial_size = index_file_size(index_path)

-    if not skip_search:
-        before_results = run_search(
-            index_path,
-            query,
-            top_k,
-            recompute_embeddings=is_recompute,
-        )
-    else:
-        before_results = None
+    before_results = run_search(
+        index_path,
+        query,
+        top_k,
+        recompute_embeddings=is_recompute,
+    )

     print(f"\n{prefix}Updating index with additional passages...")
     update_index(
@@ -223,23 +215,20 @@
         is_recompute=is_recompute,
     )

-    if not skip_search:
-        after_results = run_search(
-            index_path,
-            query,
-            top_k,
-            recompute_embeddings=is_recompute,
-        )
-    else:
-        after_results = None
+    after_results = run_search(
+        index_path,
+        query,
+        top_k,
+        recompute_embeddings=is_recompute,
+    )

     updated_size = index_file_size(index_path)

     return {
         "initial_size": initial_size,
         "updated_size": updated_size,
         "delta": updated_size - initial_size,
-        "before_results": before_results if not skip_search else None,
-        "after_results": after_results if not skip_search else None,
+        "before_results": before_results,
+        "after_results": after_results,
         "metadata": load_metadata_snapshot(index_path),
     }
@@ -325,12 +314,6 @@ def main() -> None:
         action="store_false",
         help="Skip building the no-recompute baseline.",
     )
-    parser.add_argument(
-        "--skip-search",
-        dest="skip_search",
-        action="store_true",
-        help="Skip the search step.",
-    )
     parser.set_defaults(compare_no_recompute=True)

     args = parser.parse_args()
@@ -367,13 +350,10 @@
         is_recompute=True,
         query=args.query,
         top_k=args.top_k,
-        skip_search=args.skip_search,
     )

-    if not args.skip_search:
-        print_results("initial search", recompute_stats["before_results"])
-    if not args.skip_search:
-        print_results("after update", recompute_stats["after_results"])
+    print_results("initial search", recompute_stats["before_results"])
+    print_results("after update", recompute_stats["after_results"])

     print(
         f"\n[recompute] Index file size change: {recompute_stats['initial_size']} -> {recompute_stats['updated_size']} bytes"
         f"{recompute_stats['delta']})"
@@ -398,7 +378,6 @@
         is_recompute=False,
         query=args.query,
         top_k=args.top_k,
-        skip_search=args.skip_search,
     )

     print(
@@ -406,12 +385,8 @@
         f"{baseline_stats['delta']})"
     )

-    after_texts = (
-        [res.text for res in recompute_stats["after_results"]] if not args.skip_search else None
-    )
-    baseline_after_texts = (
-        [res.text for res in baseline_stats["after_results"]] if not args.skip_search else None
-    )
+    after_texts = [res.text for res in recompute_stats["after_results"]]
+    baseline_after_texts = [res.text for res in baseline_stats["after_results"]]

     if after_texts == baseline_after_texts:
         print(
             "[no-recompute] Search results match recompute baseline; see above for the shared output."

----------------------------------------

@@ -5,7 +5,6 @@ with the correct, original embedding logic from the user's reference code.
 import json
 import logging
-import os
 import pickle
 import re
 import subprocess
@@ -21,7 +20,6 @@ from leann_backend_hnsw.convert_to_csr import prune_hnsw_embeddings_inplace
 from leann.interface import LeannBackendSearcherInterface

 from .chat import get_llm
-from .embedding_server_manager import EmbeddingServerManager
 from .interface import LeannBackendFactoryInterface
 from .metadata_filter import MetadataFilterEngine
 from .registry import BACKEND_REGISTRY
@@ -730,7 +728,6 @@ class LeannBuilder:
             index = faiss.read_index(str(index_file))
             if hasattr(index, "is_recompute"):
                 index.is_recompute = needs_recompute
-                print(f"index.is_recompute: {index.is_recompute}")
             if getattr(index, "storage", None) is None:
                 if index.metric_type == faiss.METRIC_INNER_PRODUCT:
                     storage_index = faiss.IndexFlatIP(index.d)
@@ -738,107 +735,37 @@ class LeannBuilder:
                 else:
                     storage_index = faiss.IndexFlatL2(index.d)
                 index.storage = storage_index
                 index.own_fields = True
-                # Faiss expects storage.ntotal to reflect the existing graph's
-                # population (even if the vectors themselves were pruned from disk
-                # for recompute mode). When we attach a fresh IndexFlat here its
-                # ntotal starts at zero, which later causes IndexHNSW::add to
-                # believe new "preset" levels were provided and trips the
-                # `n0 + n == levels.size()` assertion. Seed the temporary storage
-                # with the current ntotal so Faiss maintains the proper offset for
-                # incoming vectors.
-                try:
-                    storage_index.ntotal = index.ntotal
-                except AttributeError:
-                    # Older Faiss builds may not expose ntotal as a writable
-                    # attribute; in that case we fall back to the default behaviour.
-                    pass

             if index.d != embedding_dim:
                 raise ValueError(
                     f"Existing index dimension ({index.d}) does not match new embeddings ({embedding_dim})."
                 )

-            passage_meta_mode = meta.get("embedding_mode", self.embedding_mode)
-            passage_provider_options = meta.get("embedding_options", self.embedding_options)
-
             base_id = index.ntotal
             for offset, chunk in enumerate(valid_chunks):
                 new_id = str(base_id + offset)
                 chunk.setdefault("metadata", {})["id"] = new_id
                 chunk["id"] = new_id

-            # Append passages/offsets before we attempt index.add so the ZMQ server
-            # can resolve newly assigned IDs during recompute. Keep rollback hooks
-            # so we can restore files if the update fails mid-way.
-            rollback_passages_size = passages_file.stat().st_size if passages_file.exists() else 0
-            offset_map_backup = offset_map.copy()
-            try:
-                with open(passages_file, "a", encoding="utf-8") as f:
-                    for chunk in valid_chunks:
-                        offset = f.tell()
-                        json.dump(
-                            {
-                                "id": chunk["id"],
-                                "text": chunk["text"],
-                                "metadata": chunk.get("metadata", {}),
-                            },
-                            f,
-                            ensure_ascii=False,
-                        )
-                        f.write("\n")
-                        offset_map[chunk["id"]] = offset
+            index.add(embeddings.shape[0], faiss.swig_ptr(embeddings))
+            faiss.write_index(index, str(index_file))
+
+            with open(passages_file, "a", encoding="utf-8") as f:
+                for chunk in valid_chunks:
+                    offset = f.tell()
+                    json.dump(
+                        {
+                            "id": chunk["id"],
+                            "text": chunk["text"],
+                            "metadata": chunk.get("metadata", {}),
+                        },
+                        f,
+                        ensure_ascii=False,
+                    )
+                    f.write("\n")
+                    offset_map[chunk["id"]] = offset

             with open(offset_file, "wb") as f:
                 pickle.dump(offset_map, f)

-                server_manager: Optional[EmbeddingServerManager] = None
-                server_started = False
-                requested_zmq_port = int(os.getenv("LEANN_UPDATE_ZMQ_PORT", "5557"))
-                try:
-                    if needs_recompute:
-                        server_manager = EmbeddingServerManager(
-                            backend_module_name="leann_backend_hnsw.hnsw_embedding_server"
-                        )
-                        server_started, actual_port = server_manager.start_server(
-                            port=requested_zmq_port,
-                            model_name=self.embedding_model,
-                            embedding_mode=passage_meta_mode,
-                            passages_file=str(meta_path),
-                            distance_metric=distance_metric,
-                            provider_options=passage_provider_options,
-                        )
-                        if not server_started:
-                            raise RuntimeError(
-                                "Failed to start HNSW embedding server for recompute update."
-                            )
-                        if actual_port != requested_zmq_port:
-                            server_manager.stop_server()
-                            raise RuntimeError(
-                                "Embedding server started on unexpected port "
-                                f"{actual_port}; expected {requested_zmq_port}. Make sure the desired ZMQ port is free."
-                            )
-                    if needs_recompute:
-                        for i in range(embeddings.shape[0]):
-                            print(f"add {i} embeddings")
-                            index.add(1, faiss.swig_ptr(embeddings[i : i + 1]))
-                    else:
-                        index.add(embeddings.shape[0], faiss.swig_ptr(embeddings))
-                    faiss.write_index(index, str(index_file))
-                finally:
-                    if server_started and server_manager is not None:
-                        server_manager.stop_server()
-            except Exception:
-                # Roll back appended passages/offset map to keep files consistent.
-                if passages_file.exists():
-                    with open(passages_file, "rb+") as f:
-                        f.truncate(rollback_passages_size)
-                offset_map = offset_map_backup
-                with open(offset_file, "wb") as f:
-                    pickle.dump(offset_map, f)
-                raise

             meta["total_passages"] = len(offset_map)
             with open(meta_path, "w", encoding="utf-8") as f:
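Context for the simplified update path above: Faiss assigns appended vectors IDs continuing from `index.ntotal`, which is why the code records `base_id = index.ntotal` before adding. A minimal sketch with synthetic data (assumes the standard faiss-cpu Python wheel; this is not the project's update code):
```python
import faiss
import numpy as np

d = 8
index = faiss.IndexHNSWFlat(d, 32)  # HNSW graph over flat vector storage

base = np.random.rand(100, d).astype("float32")
index.add(base)                     # vectors get IDs 0..99

new = np.random.rand(10, d).astype("float32")
index.add(new)                      # IDs continue from index.ntotal: 100..109
print(index.ntotal)                 # -> 110
```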

----------------------------------------

@@ -1,5 +1,4 @@
 import atexit
-import json
 import logging
 import os
 import socket
@@ -49,85 +48,6 @@ def _check_port(port: int) -> bool:
 # Note: All cross-process scanning helpers removed for simplicity

-def _safe_resolve(path: Path) -> str:
-    """Resolve paths safely even if the target does not yet exist."""
-    try:
-        return str(path.resolve(strict=False))
-    except Exception:
-        return str(path)
-
-
-def _safe_stat_signature(path: Path) -> dict:
-    """Return a lightweight signature describing the current state of a path."""
-    signature: dict[str, object] = {"path": _safe_resolve(path)}
-    try:
-        stat = path.stat()
-    except FileNotFoundError:
-        signature["missing"] = True
-    except Exception as exc:  # pragma: no cover - unexpected filesystem errors
-        signature["error"] = str(exc)
-    else:
-        signature["mtime_ns"] = stat.st_mtime_ns
-        signature["size"] = stat.st_size
-    return signature
-
-
-def _build_passages_signature(passages_file: Optional[str]) -> Optional[dict]:
-    """Collect modification signatures for metadata and referenced passage files."""
-    if not passages_file:
-        return None
-    meta_path = Path(passages_file)
-    signature: dict[str, object] = {"meta": _safe_stat_signature(meta_path)}
-    try:
-        with meta_path.open(encoding="utf-8") as fh:
-            meta = json.load(fh)
-    except FileNotFoundError:
-        signature["meta_missing"] = True
-        signature["sources"] = []
-        return signature
-    except json.JSONDecodeError as exc:
-        signature["meta_error"] = f"json_error:{exc}"
-        signature["sources"] = []
-        return signature
-    except Exception as exc:  # pragma: no cover - unexpected errors
-        signature["meta_error"] = str(exc)
-        signature["sources"] = []
-        return signature
-    base_dir = meta_path.parent
-    seen_paths: set[str] = set()
-    source_signatures: list[dict[str, object]] = []
-    for source in meta.get("passage_sources", []):
-        for key, kind in (
-            ("path", "passages"),
-            ("path_relative", "passages"),
-            ("index_path", "index"),
-            ("index_path_relative", "index"),
-        ):
-            raw_path = source.get(key)
-            if not raw_path:
-                continue
-            candidate = Path(raw_path)
-            if not candidate.is_absolute():
-                candidate = base_dir / candidate
-            resolved = _safe_resolve(candidate)
-            if resolved in seen_paths:
-                continue
-            seen_paths.add(resolved)
-            sig = _safe_stat_signature(candidate)
-            sig["kind"] = kind
-            source_signatures.append(sig)
-    signature["sources"] = source_signatures
-    return signature
-
-
-# Note: All cross-process scanning helpers removed for simplicity

 class EmbeddingServerManager:
     """
     A simplified manager for embedding server processes that avoids complex update mechanisms.
@@ -165,14 +85,13 @@
         """Start the embedding server."""
         # passages_file may be present in kwargs for server CLI, but we don't need it here
         provider_options = kwargs.pop("provider_options", None)
-        passages_file = kwargs.get("passages_file", "")

-        config_signature = self._build_config_signature(
-            model_name=model_name,
-            embedding_mode=embedding_mode,
-            provider_options=provider_options,
-            passages_file=passages_file,
-        )
+        config_signature = {
+            "model_name": model_name,
+            "passages_file": kwargs.get("passages_file", ""),
+            "embedding_mode": embedding_mode,
+            "provider_options": provider_options or {},
+        }

         # If this manager already has a live server, just reuse it
         if (
@@ -196,7 +115,6 @@
                 port,
                 model_name,
                 embedding_mode,
-                config_signature=config_signature,
                 provider_options=provider_options,
                 **kwargs,
             )
@@ -218,30 +136,11 @@
             **kwargs,
         )

-    def _build_config_signature(
-        self,
-        *,
-        model_name: str,
-        embedding_mode: str,
-        provider_options: Optional[dict],
-        passages_file: Optional[str],
-    ) -> dict:
-        """Create a signature describing the current server configuration."""
-        return {
-            "model_name": model_name,
-            "passages_file": passages_file or "",
-            "embedding_mode": embedding_mode,
-            "provider_options": provider_options or {},
-            "passages_signature": _build_passages_signature(passages_file),
-        }
-
     def _start_server_colab(
         self,
         port: int,
         model_name: str,
         embedding_mode: str = "sentence-transformers",
-        *,
-        config_signature: Optional[dict] = None,
         provider_options: Optional[dict] = None,
         **kwargs,
     ) -> tuple[bool, int]:
@@ -264,11 +163,10 @@
             command,
             actual_port,
             provider_options=provider_options,
-            config_signature=config_signature,
         )
         started, ready_port = self._wait_for_server_ready_colab(actual_port)
         if started:
-            self._server_config = config_signature or {
+            self._server_config = {
                 "model_name": model_name,
                 "passages_file": kwargs.get("passages_file", ""),
                 "embedding_mode": embedding_mode,
@@ -300,7 +198,6 @@
             command,
             port,
             provider_options=provider_options,
-            config_signature=config_signature,
         )
         started, ready_port = self._wait_for_server_ready(port)
         if started:
@@ -344,9 +241,7 @@
         self,
         command: list,
         port: int,
-        *,
         provider_options: Optional[dict] = None,
-        config_signature: Optional[dict] = None,
     ) -> None:
         """Launch the server process."""
         project_root = Path(__file__).parent.parent.parent.parent.parent
@@ -381,29 +276,26 @@
         )
         self.server_port = port

         # Record config for in-process reuse (best effort; refined later when ready)
-        if config_signature is not None:
-            self._server_config = config_signature
-        else:  # Fallback for unexpected code paths
-            try:
-                self._server_config = {
-                    "model_name": command[command.index("--model-name") + 1]
-                    if "--model-name" in command
-                    else "",
-                    "passages_file": command[command.index("--passages-file") + 1]
-                    if "--passages-file" in command
-                    else "",
-                    "embedding_mode": command[command.index("--embedding-mode") + 1]
-                    if "--embedding-mode" in command
-                    else "sentence-transformers",
-                    "provider_options": provider_options or {},
-                }
-            except Exception:
-                self._server_config = {
-                    "model_name": "",
-                    "passages_file": "",
-                    "embedding_mode": "sentence-transformers",
-                    "provider_options": provider_options or {},
-                }
+        try:
+            self._server_config = {
+                "model_name": command[command.index("--model-name") + 1]
+                if "--model-name" in command
+                else "",
+                "passages_file": command[command.index("--passages-file") + 1]
+                if "--passages-file" in command
+                else "",
+                "embedding_mode": command[command.index("--embedding-mode") + 1]
+                if "--embedding-mode" in command
+                else "sentence-transformers",
+                "provider_options": provider_options or {},
+            }
+        except Exception:
+            self._server_config = {
+                "model_name": "",
+                "passages_file": "",
+                "embedding_mode": "sentence-transformers",
+                "provider_options": provider_options or {},
+            }

         logger.info(f"Server process started with PID: {self.server_process.pid}")

         # Register atexit callback only when we actually start a process
@@ -511,9 +403,7 @@
         self,
         command: list,
         port: int,
-        *,
         provider_options: Optional[dict] = None,
-        config_signature: Optional[dict] = None,
     ) -> None:
         """Launch the server process with Colab-specific settings."""
         logger.info(f"Colab Command: {' '.join(command)}")
@@ -539,15 +429,12 @@
             atexit.register(self._finalize_process)
             self._atexit_registered = True

         # Record config for in-process reuse is best-effort in Colab mode
-        if config_signature is not None:
-            self._server_config = config_signature
-        else:
-            self._server_config = {
-                "model_name": "",
-                "passages_file": "",
-                "embedding_mode": "sentence-transformers",
-                "provider_options": provider_options or {},
-            }
+        self._server_config = {
+            "model_name": "",
+            "passages_file": "",
+            "embedding_mode": "sentence-transformers",
+            "provider_options": provider_options or {},
+        }

     def _wait_for_server_ready_colab(self, port: int) -> tuple[bool, int]:
         """Wait for the server to be ready with Colab-specific timeout."""

----------------------------------------

@@ -111,7 +111,7 @@ target-version = "py39"
 line-length = 100
 extend-exclude = [
     "third_party",
-    "apps/multimodal/vision-based-pdf-multi-vector/multi-vector-leann-paper-example.py",
+    "apps/multimodal/vision-based-pdf-multi-vector/multi-vector-leann.py",
     "apps/multimodal/vision-based-pdf-multi-vector/multi-vector-leann-similarity-map.py"
 ]

----------------------------------------

@@ -1,137 +0,0 @@
import json
import time

import pytest

from leann.embedding_server_manager import EmbeddingServerManager


class DummyProcess:
    def __init__(self):
        self.pid = 12345
        self._terminated = False

    def poll(self):
        return 0 if self._terminated else None

    def terminate(self):
        self._terminated = True

    def kill(self):
        self._terminated = True

    def wait(self, timeout=None):
        self._terminated = True
        return 0


@pytest.fixture
def embedding_manager(monkeypatch):
    manager = EmbeddingServerManager("leann_backend_hnsw.hnsw_embedding_server")

    def fake_get_available_port(start_port):
        return start_port

    monkeypatch.setattr(
        "leann.embedding_server_manager._get_available_port",
        fake_get_available_port,
    )

    start_calls = []

    def fake_start_new_server(self, port, model_name, embedding_mode, **kwargs):
        config_signature = kwargs.get("config_signature")
        start_calls.append(config_signature)
        self.server_process = DummyProcess()
        self.server_port = port
        self._server_config = config_signature
        return True, port

    monkeypatch.setattr(
        EmbeddingServerManager,
        "_start_new_server",
        fake_start_new_server,
    )

    # Ensure stop_server doesn't try to operate on real subprocesses
    def fake_stop_server(self):
        self.server_process = None
        self.server_port = None
        self._server_config = None

    monkeypatch.setattr(EmbeddingServerManager, "stop_server", fake_stop_server)

    return manager, start_calls


def _write_meta(meta_path, passages_name, index_name, total):
    meta_path.write_text(
        json.dumps(
            {
                "backend_name": "hnsw",
                "embedding_model": "test-model",
                "embedding_mode": "sentence-transformers",
                "dimensions": 3,
                "backend_kwargs": {},
                "passage_sources": [
                    {
                        "type": "jsonl",
                        "path": passages_name,
                        "index_path": index_name,
                    }
                ],
                "total_passages": total,
            }
        ),
        encoding="utf-8",
    )


def test_server_restarts_when_metadata_changes(tmp_path, embedding_manager):
    manager, start_calls = embedding_manager

    meta_path = tmp_path / "example.meta.json"
    passages_path = tmp_path / "example.passages.jsonl"
    index_path = tmp_path / "example.passages.idx"

    passages_path.write_text("first\n", encoding="utf-8")
    index_path.write_bytes(b"index")
    _write_meta(meta_path, passages_path.name, index_path.name, total=1)

    # Initial start populates signature
    ok, port = manager.start_server(
        port=6000,
        model_name="test-model",
        passages_file=str(meta_path),
    )
    assert ok
    assert port == 6000
    assert len(start_calls) == 1
    initial_signature = start_calls[0]["passages_signature"]

    # No metadata change => reuse existing server
    ok, port_again = manager.start_server(
        port=6000,
        model_name="test-model",
        passages_file=str(meta_path),
    )
    assert ok
    assert port_again == 6000
    assert len(start_calls) == 1

    # Modify passage data and metadata to force signature change
    time.sleep(0.01)  # Ensure filesystem timestamps move forward
    passages_path.write_text("second\n", encoding="utf-8")
    _write_meta(meta_path, passages_path.name, index_path.name, total=2)

    ok, port_third = manager.start_server(
        port=6000,
        model_name="test-model",
        passages_file=str(meta_path),
    )
    assert ok
    assert port_third == 6000
    assert len(start_calls) == 2
    updated_signature = start_calls[1]["passages_signature"]
    assert updated_signature != initial_signature