Compare commits
17 Commits: fix/chunki...embed-laun

SHA1:
ed15776564
8d202b8b0e
9ac9eab48d
cd1d853a46
253680043a
36c44b8806
66c6aad3e4
29ef3c95dc
469dce0045
0ac676f9cb
97c9f39704
3766ad1fd2
c3aceed1e0
dc6c9f696e
2406c41eef
d4f5f2896f
366984e92e
@@ -12,6 +12,7 @@ from pathlib import Path
 try:
     from leann.chunking_utils import (
         CODE_EXTENSIONS,
+        _traditional_chunks_as_dicts,
         create_ast_chunks,
         create_text_chunks,
         create_traditional_chunks,
@@ -25,6 +26,7 @@ except Exception:  # pragma: no cover - best-effort fallback for dev environment
     sys.path.insert(0, str(leann_src))
     from leann.chunking_utils import (
         CODE_EXTENSIONS,
+        _traditional_chunks_as_dicts,
         create_ast_chunks,
         create_text_chunks,
         create_traditional_chunks,
@@ -36,6 +38,7 @@ except Exception:  # pragma: no cover - best-effort fallback for dev environment
 
 __all__ = [
     "CODE_EXTENSIONS",
+    "_traditional_chunks_as_dicts",
    "create_ast_chunks",
    "create_text_chunks",
    "create_traditional_chunks",
@@ -1,12 +1,18 @@
-from __future__ import annotations
+import concurrent.futures
+import json
+import os
+import re
 import sys
 from pathlib import Path
+from typing import Any, Optional, cast
 
 import numpy as np
+from PIL import Image
+from tqdm import tqdm
 
 
 def _ensure_repo_paths_importable(current_file: str) -> None:
+    """Make local leann packages importable without installing (mirrors multi-vector-leann.py)."""
     _repo_root = Path(current_file).resolve().parents[3]
     _leann_core_src = _repo_root / "packages" / "leann-core" / "src"
     _leann_hnsw_pkg = _repo_root / "packages" / "leann-backend-hnsw"
@@ -16,6 +22,380 @@ def _ensure_repo_paths_importable(current_file: str) -> None:
         sys.path.append(str(_leann_hnsw_pkg))
 
 
+def _find_backend_module_file() -> Optional[Path]:
+    """Best-effort locate the backend leann_multi_vector.py file, avoiding this file."""
+    this_file = Path(__file__).resolve()
+    candidates: list[Path] = []
+
+    # Common in-repo location
+    repo_root = this_file.parents[3]
+    candidates.append(repo_root / "packages" / "leann-backend-hnsw" / "leann_multi_vector.py")
+    candidates.append(
+        repo_root / "packages" / "leann-backend-hnsw" / "src" / "leann_multi_vector.py"
+    )
+
+    for cand in candidates:
+        try:
+            if cand.exists() and cand.resolve() != this_file:
+                return cand.resolve()
+        except Exception:
+            pass
+
+    # Fallback: scan sys.path for another leann_multi_vector.py different from this file
+    for p in list(sys.path):
+        try:
+            cand = Path(p) / "leann_multi_vector.py"
+            if cand.exists() and cand.resolve() != this_file:
+                return cand.resolve()
+        except Exception:
+            continue
+    return None
+
+
+_BACKEND_LEANN_CLASS: Optional[type] = None
+
+
+def _get_backend_leann_multi_vector() -> type:
+    """Load backend LeannMultiVector class even if this file shadows its module name."""
+    global _BACKEND_LEANN_CLASS
+    if _BACKEND_LEANN_CLASS is not None:
+        return _BACKEND_LEANN_CLASS
+
+    backend_path = _find_backend_module_file()
+    if backend_path is None:
+        # Fallback to local implementation in this module
+        try:
+            cls = LeannMultiVector  # type: ignore[name-defined]
+            _BACKEND_LEANN_CLASS = cls
+            return cls
+        except Exception as e:
+            raise ImportError(
+                "Could not locate backend 'leann_multi_vector.py' and no local implementation found. "
+                "Ensure the leann backend is available under packages/leann-backend-hnsw or installed."
+            ) from e
+
+    import importlib.util
+
+    module_name = "leann_hnsw_backend_module"
+    spec = importlib.util.spec_from_file_location(module_name, str(backend_path))
+    if spec is None or spec.loader is None:
+        raise ImportError(f"Failed to create spec for backend module at {backend_path}")
+    backend_module = importlib.util.module_from_spec(spec)
+    sys.modules[module_name] = backend_module
+    spec.loader.exec_module(backend_module)  # type: ignore[assignment]
+
+    if not hasattr(backend_module, "LeannMultiVector"):
+        raise ImportError(f"'LeannMultiVector' not found in backend module at {backend_path}")
+    _BACKEND_LEANN_CLASS = backend_module.LeannMultiVector
+    return _BACKEND_LEANN_CLASS
+
+
+def _natural_sort_key(name: str) -> int:
+    m = re.search(r"\d+", name)
+    return int(m.group()) if m else 0
+
+
+def _load_images_from_dir(pages_dir: str) -> tuple[list[str], list[Image.Image]]:
+    filenames = [n for n in os.listdir(pages_dir) if n.lower().endswith((".png", ".jpg", ".jpeg"))]
+    filenames = sorted(filenames, key=_natural_sort_key)
+    filepaths = [os.path.join(pages_dir, n) for n in filenames]
+    images = [Image.open(p) for p in filepaths]
+    return filepaths, images
+
+
+def _maybe_convert_pdf_to_images(pdf_path: Optional[str], pages_dir: str, dpi: int = 200) -> None:
+    if not pdf_path:
+        return
+    os.makedirs(pages_dir, exist_ok=True)
+    try:
+        from pdf2image import convert_from_path
+    except Exception as e:
+        raise RuntimeError(
+            "pdf2image is required to convert PDF to images. Install via pip install pdf2image"
+        ) from e
+    images = convert_from_path(pdf_path, dpi=dpi)
+    for i, image in enumerate(images):
+        image.save(os.path.join(pages_dir, f"page_{i + 1}.png"), "PNG")
+
+
+def _select_device_and_dtype():
+    import torch
+    from colpali_engine.utils.torch_utils import get_torch_device
+
+    device_str = (
+        "cuda"
+        if torch.cuda.is_available()
+        else (
+            "mps"
+            if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available()
+            else "cpu"
+        )
+    )
+    device = get_torch_device(device_str)
+    # Stable dtype selection to avoid NaNs:
+    # - CUDA: prefer bfloat16 if supported, else float16
+    # - MPS: use float32 (fp16 on MPS can produce NaNs in some ops)
+    # - CPU: float32
+    if device_str == "cuda":
+        dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
+        try:
+            torch.backends.cuda.matmul.allow_tf32 = True  # Better stability/perf on Ampere+
+        except Exception:
+            pass
+    elif device_str == "mps":
+        dtype = torch.float32
+    else:
+        dtype = torch.float32
+    return device_str, device, dtype
+
+
+def _load_colvision(model_choice: str):
+    import torch
+    from colpali_engine.models import ColPali, ColQwen2, ColQwen2Processor
+    from colpali_engine.models.paligemma.colpali.processing_colpali import ColPaliProcessor
+    from transformers.utils.import_utils import is_flash_attn_2_available
+
+    device_str, device, dtype = _select_device_and_dtype()
+
+    if model_choice == "colqwen2":
+        model_name = "vidore/colqwen2-v1.0"
+        # On CPU/MPS we must avoid flash-attn and stay eager; on CUDA prefer flash-attn if available
+        attn_implementation = (
+            "flash_attention_2"
+            if (device_str == "cuda" and is_flash_attn_2_available())
+            else "eager"
+        )
+        model = ColQwen2.from_pretrained(
+            model_name,
+            torch_dtype=torch.bfloat16,
+            device_map=device,
+            attn_implementation=attn_implementation,
+        ).eval()
+        processor = ColQwen2Processor.from_pretrained(model_name)
+    else:
+        model_name = "vidore/colpali-v1.2"
+        model = ColPali.from_pretrained(
+            model_name,
+            torch_dtype=torch.bfloat16,
+            device_map=device,
+        ).eval()
+        processor = cast(ColPaliProcessor, ColPaliProcessor.from_pretrained(model_name))
+
+    return model_name, model, processor, device_str, device, dtype
+
+
+def _embed_images(model, processor, images: list[Image.Image]) -> list[Any]:
+    import torch
+    from colpali_engine.utils.torch_utils import ListDataset
+    from torch.utils.data import DataLoader
+
+    # Ensure deterministic eval and autocast for stability
+    model.eval()
+
+    dataloader = DataLoader(
+        dataset=ListDataset[Image.Image](images),
+        batch_size=1,
+        shuffle=False,
+        collate_fn=lambda x: processor.process_images(x),
+    )
+
+    doc_vecs: list[Any] = []
+    for batch_doc in tqdm(dataloader, desc="Embedding images"):
+        with torch.no_grad():
+            batch_doc = {k: v.to(model.device) for k, v in batch_doc.items()}
+            # autocast on CUDA for bf16/fp16; on CPU/MPS stay in fp32
+            if model.device.type == "cuda":
+                with torch.autocast(
+                    device_type="cuda",
+                    dtype=model.dtype if model.dtype.is_floating_point else torch.bfloat16,
+                ):
+                    embeddings_doc = model(**batch_doc)
+            else:
+                embeddings_doc = model(**batch_doc)
+        doc_vecs.extend(list(torch.unbind(embeddings_doc.to("cpu"))))
+    return doc_vecs
+
+
+def _embed_queries(model, processor, queries: list[str]) -> list[Any]:
+    import torch
+    from colpali_engine.utils.torch_utils import ListDataset
+    from torch.utils.data import DataLoader
+
+    model.eval()
+
+    dataloader = DataLoader(
+        dataset=ListDataset[str](queries),
+        batch_size=1,
+        shuffle=False,
+        collate_fn=lambda x: processor.process_queries(x),
+    )
+
+    q_vecs: list[Any] = []
+    for batch_query in tqdm(dataloader, desc="Embedding queries"):
+        with torch.no_grad():
+            batch_query = {k: v.to(model.device) for k, v in batch_query.items()}
+            if model.device.type == "cuda":
+                with torch.autocast(
+                    device_type="cuda",
+                    dtype=model.dtype if model.dtype.is_floating_point else torch.bfloat16,
+                ):
+                    embeddings_query = model(**batch_query)
+            else:
+                embeddings_query = model(**batch_query)
+        q_vecs.extend(list(torch.unbind(embeddings_query.to("cpu"))))
+    return q_vecs
+
+
+def _build_index(
+    index_path: str, doc_vecs: list[Any], filepaths: list[str], images: list[Image.Image]
+) -> Any:
+    LeannMultiVector = _get_backend_leann_multi_vector()
+    dim = int(doc_vecs[0].shape[-1])
+    retriever = LeannMultiVector(index_path=index_path, dim=dim)
+    retriever.create_collection()
+    for i, vec in enumerate(doc_vecs):
+        data = {
+            "colbert_vecs": vec.float().numpy(),
+            "doc_id": i,
+            "filepath": filepaths[i],
+            "image": images[i],  # Include the original image
+        }
+        retriever.insert(data)
+    retriever.create_index()
+    return retriever
+
+
+def _load_retriever_if_index_exists(index_path: str) -> Optional[Any]:
+    LeannMultiVector = _get_backend_leann_multi_vector()
+    index_base = Path(index_path)
+    # Check for the actual HNSW index file written by the backend + our sidecar files
+    index_file = index_base.parent / f"{index_base.stem}.index"
+    meta = index_base.parent / f"{index_base.name}.meta.json"
+    labels = index_base.parent / f"{index_base.name}.labels.json"
+    if index_file.exists() and meta.exists() and labels.exists():
+        try:
+            with open(meta, encoding="utf-8") as f:
+                meta_json = json.load(f)
+            dim = int(meta_json.get("dimensions", 128))
+        except Exception:
+            dim = 128
+        return LeannMultiVector(index_path=index_path, dim=dim)
+    return None
+
+
+def _generate_similarity_map(
+    model,
+    processor,
+    image: Image.Image,
+    query: str,
+    token_idx: Optional[int] = None,
+    output_path: Optional[str] = None,
+) -> tuple[int, float]:
+    import torch
+    from colpali_engine.interpretability import (
+        get_similarity_maps_from_embeddings,
+        plot_similarity_map,
+    )
+
+    batch_images = processor.process_images([image]).to(model.device)
+    batch_queries = processor.process_queries([query]).to(model.device)
+
+    with torch.no_grad():
+        image_embeddings = model.forward(**batch_images)
+        query_embeddings = model.forward(**batch_queries)
+
+    n_patches = processor.get_n_patches(
+        image_size=image.size,
+        spatial_merge_size=getattr(model, "spatial_merge_size", None),
+    )
+    image_mask = processor.get_image_mask(batch_images)
+
+    batched_similarity_maps = get_similarity_maps_from_embeddings(
+        image_embeddings=image_embeddings,
+        query_embeddings=query_embeddings,
+        n_patches=n_patches,
+        image_mask=image_mask,
+    )
+
+    similarity_maps = batched_similarity_maps[0]
+
+    # Determine token index if not provided: choose the token with highest max score
+    if token_idx is None:
+        per_token_max = similarity_maps.view(similarity_maps.shape[0], -1).max(dim=1).values
+        token_idx = int(per_token_max.argmax().item())
+
+    max_sim_score = similarity_maps[token_idx, :, :].max().item()
+
+    if output_path:
+        import matplotlib.pyplot as plt
+
+        fig, ax = plot_similarity_map(
+            image=image,
+            similarity_map=similarity_maps[token_idx],
+            figsize=(14, 14),
+            show_colorbar=False,
+        )
+        ax.set_title(f"Token #{token_idx}. MaxSim score: {max_sim_score:.2f}", fontsize=12)
+        os.makedirs(os.path.dirname(output_path), exist_ok=True)
+        plt.savefig(output_path, bbox_inches="tight")
+        plt.close(fig)
+
+    return token_idx, float(max_sim_score)
+
+
+class QwenVL:
+    def __init__(self, device: str):
+        from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
+        from transformers.utils.import_utils import is_flash_attn_2_available
+
+        attn_implementation = "flash_attention_2" if is_flash_attn_2_available() else "eager"
+        self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+            "Qwen/Qwen2.5-VL-3B-Instruct",
+            torch_dtype="auto",
+            device_map=device,
+            attn_implementation=attn_implementation,
+        )
+
+        min_pixels = 256 * 28 * 28
+        max_pixels = 1280 * 28 * 28
+        self.processor = AutoProcessor.from_pretrained(
+            "Qwen/Qwen2.5-VL-3B-Instruct", min_pixels=min_pixels, max_pixels=max_pixels
+        )
+
+    def answer(self, query: str, images: list[Image.Image], max_new_tokens: int = 128) -> str:
+        import base64
+        from io import BytesIO
+
+        from qwen_vl_utils import process_vision_info
+
+        content = []
+        for img in images:
+            buffer = BytesIO()
+            img.save(buffer, format="jpeg")
+            img_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
+            content.append({"type": "image", "image": f"data:image;base64,{img_base64}"})
+        content.append({"type": "text", "text": query})
+        messages = [{"role": "user", "content": content}]
+
+        text = self.processor.apply_chat_template(
+            messages, tokenize=False, add_generation_prompt=True
+        )
+        image_inputs, video_inputs = process_vision_info(messages)
+        inputs = self.processor(
+            text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt"
+        )
+        inputs = inputs.to(self.model.device)
+
+        generated_ids = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
+        generated_ids_trimmed = [
+            out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
+        ]
+        return self.processor.batch_decode(
+            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
+        )[0]
+
+
+# Ensure repo paths are importable for dynamic backend loading
 _ensure_repo_paths_importable(__file__)
 
 from leann_backend_hnsw.hnsw_backend import HNSWBuilder, HNSWSearcher  # noqa: E402
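
`_get_backend_leann_multi_vector` exists because the example script and the backend can both be named `leann_multi_vector.py`, so a plain `import` may resolve to the wrong module. The loading step it uses is the standard `importlib.util` recipe for importing a module from an explicit file path; here is a small stand-alone sketch of that recipe (the helper name and its arguments are illustrative, not part of this repository):

```python
import importlib.util
import sys
from pathlib import Path


def load_module_from_path(name: str, path: Path):
    """Import a module from an explicit file path under a chosen module name."""
    # Build an import spec for the file, create a module object from it, then execute it.
    spec = importlib.util.spec_from_file_location(name, str(path))
    if spec is None or spec.loader is None:
        raise ImportError(f"cannot create an import spec for {path}")
    module = importlib.util.module_from_spec(spec)
    sys.modules[name] = module  # register before exec so the module can reference itself
    spec.loader.exec_module(module)
    return module
```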
@@ -45,6 +425,7 @@ class LeannMultiVector:
             "is_recompute": is_recompute,
         }
         self._labels_meta: list[dict] = []
+        self._docid_to_indices: dict[int, list[int]] | None = None
 
     def _meta_dict(self) -> dict:
         return {
@@ -69,6 +450,7 @@ class LeannMultiVector:
                 "doc_id": int(data["doc_id"]),
                 "filepath": data.get("filepath", ""),
                 "colbert_vecs": [np.asarray(v, dtype=np.float32) for v in data["colbert_vecs"]],
+                "image": data.get("image"),  # PIL Image object (optional)
             }
         )
 
@@ -80,6 +462,15 @@ class LeannMultiVector:
         index_path_obj = Path(self.index_path)
         return index_path_obj.parent / f"{index_path_obj.name}.meta.json"
 
+    def _embeddings_path(self) -> Path:
+        index_path_obj = Path(self.index_path)
+        return index_path_obj.parent / f"{index_path_obj.name}.emb.npy"
+
+    def _images_dir_path(self) -> Path:
+        """Directory where original images are stored."""
+        index_path_obj = Path(self.index_path)
+        return index_path_obj.parent / f"{index_path_obj.name}.images"
+
     def create_index(self) -> None:
         if not self._pending_items:
             return
@@ -87,10 +478,23 @@ class LeannMultiVector:
         embeddings: list[np.ndarray] = []
         labels_meta: list[dict] = []
+
+        # Create images directory if needed
+        images_dir = self._images_dir_path()
+        images_dir.mkdir(parents=True, exist_ok=True)
+
         for item in self._pending_items:
             doc_id = int(item["doc_id"])
             filepath = item.get("filepath", "")
             colbert_vecs = item["colbert_vecs"]
+            image = item.get("image")
+
+            # Save image if provided
+            image_path = ""
+            if image is not None and isinstance(image, Image.Image):
+                image_filename = f"doc_{doc_id}.png"
+                image_path = str(images_dir / image_filename)
+                image.save(image_path, "PNG")
+
             for seq_id, vec in enumerate(colbert_vecs):
                 vec_np = np.asarray(vec, dtype=np.float32)
                 embeddings.append(vec_np)
@@ -100,6 +504,7 @@ class LeannMultiVector:
                         "doc_id": doc_id,
                         "seq_id": int(seq_id),
                         "filepath": filepath,
+                        "image_path": image_path,  # Store the path to the saved image
                     }
                 )
 
@@ -107,7 +512,6 @@ class LeannMultiVector:
             return
 
         embeddings_np = np.vstack(embeddings).astype(np.float32)
-        # print shape of embeddings_np
         print(embeddings_np.shape)
 
         builder = HNSWBuilder(**{**self._backend_kwargs, "dimensions": self.dim})
@@ -121,6 +525,9 @@ class LeannMultiVector:
         with open(self._labels_path(), "w", encoding="utf-8") as f:
             _json.dump(labels_meta, f)
 
+        # Persist embeddings for exact reranking
+        np.save(self._embeddings_path(), embeddings_np)
+
         self._labels_meta = labels_meta
 
     def _load_labels_meta_if_needed(self) -> None:
@@ -133,6 +540,19 @@ class LeannMultiVector:
         with open(labels_path, encoding="utf-8") as f:
             self._labels_meta = _json.load(f)
 
+    def _build_docid_to_indices_if_needed(self) -> None:
+        if self._docid_to_indices is not None:
+            return
+        self._load_labels_meta_if_needed()
+        mapping: dict[int, list[int]] = {}
+        for idx, meta in enumerate(self._labels_meta):
+            try:
+                doc_id = int(meta["doc_id"])  # type: ignore[index]
+            except Exception:
+                continue
+            mapping.setdefault(doc_id, []).append(idx)
+        self._docid_to_indices = mapping
+
     def search(
         self, data: np.ndarray, topk: int, first_stage_k: int = 50
     ) -> list[tuple[float, int]]:
@@ -180,3 +600,181 @@ class LeannMultiVector:
 
         scores = sorted(((v, k) for k, v in doc_scores.items()), key=lambda x: x[0], reverse=True)
         return scores[:topk] if len(scores) >= topk else scores
+
+    def search_exact(
+        self,
+        data: np.ndarray,
+        topk: int,
+        *,
+        first_stage_k: int = 200,
+        max_workers: int = 32,
+    ) -> list[tuple[float, int]]:
+        """
+        High-precision MaxSim reranking over candidate documents.
+
+        Steps:
+        1) Run a first-stage ANN to collect candidate doc_ids (using seq-level neighbors).
+        2) For each candidate doc, load all its token embeddings and compute
+           MaxSim(query_tokens, doc_tokens) exactly: sum(max(dot(q_i, d_j))).
+
+        Returns top-k list of (score, doc_id).
+        """
+        # Normalize inputs
+        if data.ndim == 1:
+            data = data.reshape(1, -1)
+        if data.dtype != np.float32:
+            data = data.astype(np.float32)
+
+        self._load_labels_meta_if_needed()
+        self._build_docid_to_indices_if_needed()
+
+        emb_path = self._embeddings_path()
+        if not emb_path.exists():
+            # Fallback to approximate if we don't have persisted embeddings
+            return self.search(data, topk, first_stage_k=first_stage_k)
+
+        # Memory-map embeddings to avoid loading all into RAM
+        all_embeddings = np.load(emb_path, mmap_mode="r")
+        if all_embeddings.dtype != np.float32:
+            all_embeddings = all_embeddings.astype(np.float32)
+
+        # First-stage ANN to collect candidate doc_ids
+        searcher = HNSWSearcher(self.index_path, meta=self._meta_dict())
+        raw = searcher.search(
+            data,
+            first_stage_k,
+            recompute_embeddings=False,
+            complexity=128,
+            beam_width=1,
+            prune_ratio=0.0,
+            batch_size=0,
+        )
+        labels = raw.get("labels")
+        if labels is None:
+            return []
+        candidate_doc_ids: set[int] = set()
+        for batch in labels:
+            for sid in batch:
+                try:
+                    idx = int(sid)
+                except Exception:
+                    continue
+                if 0 <= idx < len(self._labels_meta):
+                    candidate_doc_ids.add(int(self._labels_meta[idx]["doc_id"]))  # type: ignore[index]
+
+        # Exact scoring per doc (parallelized)
+        assert self._docid_to_indices is not None
+
+        def _score_one(doc_id: int) -> tuple[float, int]:
+            token_indices = self._docid_to_indices.get(doc_id, [])
+            if not token_indices:
+                return (0.0, doc_id)
+            doc_vecs = np.asarray(all_embeddings[token_indices], dtype=np.float32)
+            # (Q, D) x (P, D)^T -> (Q, P) then MaxSim over P, sum over Q
+            sim = np.dot(data, doc_vecs.T)
+            # nan-safe
+            sim = np.nan_to_num(sim, nan=-1e30, posinf=1e30, neginf=-1e30)
+            score = sim.max(axis=2).sum(axis=1) if sim.ndim == 3 else sim.max(axis=1).sum()
+            return (float(score), doc_id)
+
+        scores: list[tuple[float, int]] = []
+        with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as ex:
+            futures = [ex.submit(_score_one, doc_id) for doc_id in candidate_doc_ids]
+            for fut in concurrent.futures.as_completed(futures):
+                scores.append(fut.result())
+
+        scores.sort(key=lambda x: x[0], reverse=True)
+        return scores[:topk] if len(scores) >= topk else scores
+
+    def search_exact_all(
+        self,
+        data: np.ndarray,
+        topk: int,
+        *,
+        max_workers: int = 32,
+    ) -> list[tuple[float, int]]:
+        """
+        Exact MaxSim over ALL documents (no ANN pre-filtering).
+
+        This computes, for each document, sum_i max_j dot(q_i, d_j).
+        It memory-maps the persisted token-embedding matrix for scalability.
+        """
+        if data.ndim == 1:
+            data = data.reshape(1, -1)
+        if data.dtype != np.float32:
+            data = data.astype(np.float32)
+
+        self._load_labels_meta_if_needed()
+        self._build_docid_to_indices_if_needed()
+
+        emb_path = self._embeddings_path()
+        if not emb_path.exists():
+            return self.search(data, topk)
+
+        all_embeddings = np.load(emb_path, mmap_mode="r")
+        if all_embeddings.dtype != np.float32:
+            all_embeddings = all_embeddings.astype(np.float32)
+
+        assert self._docid_to_indices is not None
+        candidate_doc_ids = list(self._docid_to_indices.keys())
+
+        def _score_one(doc_id: int) -> tuple[float, int]:
+            token_indices = self._docid_to_indices.get(doc_id, [])
+            if not token_indices:
+                return (0.0, doc_id)
+            doc_vecs = np.asarray(all_embeddings[token_indices], dtype=np.float32)
+            sim = np.dot(data, doc_vecs.T)
+            sim = np.nan_to_num(sim, nan=-1e30, posinf=1e30, neginf=-1e30)
+            score = sim.max(axis=2).sum(axis=1) if sim.ndim == 3 else sim.max(axis=1).sum()
+            return (float(score), doc_id)
+
+        scores: list[tuple[float, int]] = []
+        with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as ex:
+            futures = [ex.submit(_score_one, d) for d in candidate_doc_ids]
+            for fut in concurrent.futures.as_completed(futures):
+                scores.append(fut.result())
+
+        scores.sort(key=lambda x: x[0], reverse=True)
+        return scores[:topk] if len(scores) >= topk else scores
+
+    def get_image(self, doc_id: int) -> Optional[Image.Image]:
+        """
+        Retrieve the original image for a given doc_id from the index.
+
+        Args:
+            doc_id: The document ID
+
+        Returns:
+            PIL Image object if found, None otherwise
+        """
+        self._load_labels_meta_if_needed()
+
+        # Find the image_path for this doc_id (all seq_ids for same doc share the same image_path)
+        for meta in self._labels_meta:
+            if meta.get("doc_id") == doc_id:
+                image_path = meta.get("image_path", "")
+                if image_path and Path(image_path).exists():
+                    return Image.open(image_path)
+                break
+        return None
+
+    def get_metadata(self, doc_id: int) -> Optional[dict]:
+        """
+        Retrieve metadata for a given doc_id.
+
+        Args:
+            doc_id: The document ID
+
+        Returns:
+            Dictionary with metadata (filepath, image_path, etc.) if found, None otherwise
+        """
+        self._load_labels_meta_if_needed()
+
+        for meta in self._labels_meta:
+            if meta.get("doc_id") == doc_id:
+                return {
+                    "doc_id": doc_id,
+                    "filepath": meta.get("filepath", ""),
+                    "image_path": meta.get("image_path", ""),
+                }
+        return None

@@ -2,34 +2,31 @@
 # %%
 # uv pip install matplotlib qwen_vl_utils
 import os
-import re
-import sys
-from pathlib import Path
-from typing import Any, Optional, cast
+from typing import Any, Optional
 
 from PIL import Image
 from tqdm import tqdm
 
 
-def _ensure_repo_paths_importable(current_file: str) -> None:
-    """Make local leann packages importable without installing (mirrors multi-vector-leann.py)."""
-    _repo_root = Path(current_file).resolve().parents[3]
-    _leann_core_src = _repo_root / "packages" / "leann-core" / "src"
-    _leann_hnsw_pkg = _repo_root / "packages" / "leann-backend-hnsw"
-    if str(_leann_core_src) not in sys.path:
-        sys.path.append(str(_leann_core_src))
-    if str(_leann_hnsw_pkg) not in sys.path:
-        sys.path.append(str(_leann_hnsw_pkg))
+from leann_multi_vector import (  # utility functions/classes
+    _ensure_repo_paths_importable,
+    _load_images_from_dir,
+    _maybe_convert_pdf_to_images,
+    _load_colvision,
+    _embed_images,
+    _embed_queries,
+    _build_index,
+    _load_retriever_if_index_exists,
+    _generate_similarity_map,
+    QwenVL,
+)
 
 _ensure_repo_paths_importable(__file__)
 
-from leann_multi_vector import LeannMultiVector  # noqa: E402
-
 # %%
 # Config
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
-QUERY = "How does DeepSeek-V2 compare against the LLaMA family of LLMs?"
+QUERY = "The paper talk about the latent video generative model and data curation in the related work part?"
 MODEL: str = "colqwen2"  # "colpali" or "colqwen2"
 
 # Data source: set to True to use the Hugging Face dataset example (recommended)
@@ -44,7 +41,7 @@ PAGES_DIR: str = "./pages"
 
 # Index + retrieval settings
 INDEX_PATH: str = "./indexes/colvision.leann"
-TOPK: int = 1
+TOPK: int = 3
 FIRST_STAGE_K: int = 500
 REBUILD_INDEX: bool = False
 
@@ -54,332 +51,57 @@ SIMILARITY_MAP: bool = True
 SIM_TOKEN_IDX: int = 13  # -1 means auto-select the most salient token
 SIM_OUTPUT: str = "./figures/similarity_map.png"
 ANSWER: bool = True
-MAX_NEW_TOKENS: int = 128
+MAX_NEW_TOKENS: int = 1024
 
 
-def _build_index(index_path: str, doc_vecs: list[Any], filepaths: list[str]) -> LeannMultiVector:
-    dim = int(doc_vecs[0].shape[-1])
-    retriever = LeannMultiVector(index_path=index_path, dim=dim)
-    retriever.create_collection()
-    for i, vec in enumerate(doc_vecs):
-        data = {
-            "colbert_vecs": vec.float().numpy(),
-            "doc_id": i,
-            "filepath": filepaths[i],
-        }
-        retriever.insert(data)
-    retriever.create_index()
-    return retriever
-
-
-def _load_retriever_if_index_exists(index_path: str, dim: int) -> Optional[LeannMultiVector]:
-    index_base = Path(index_path)
-    # Rough heuristic: index dir exists AND meta+labels files exist
-    meta = index_base.parent / f"{index_base.name}.meta.json"
-    labels = index_base.parent / f"{index_base.name}.labels.json"
-    if index_base.exists() and meta.exists() and labels.exists():
-        return LeannMultiVector(index_path=index_path, dim=dim)
-    return None
-
-
 # %%
 
-# Step 1: Prepare data
-if USE_HF_DATASET:
-    from datasets import load_dataset
-
-    dataset = load_dataset(DATASET_NAME, split=DATASET_SPLIT)
-    N = len(dataset) if MAX_DOCS is None else min(MAX_DOCS, len(dataset))
-    filepaths: list[str] = []
-    images: list[Image.Image] = []
-    for i in tqdm(range(N), desc="Loading dataset", total=N ):
-        p = dataset[i]
-        # Compose a descriptive identifier for printing later
-        identifier = f"arXiv:{p['paper_arxiv_id']}|title:{p['paper_title']}|page:{int(p['page_number'])}|id:{p['page_id']}"
-        print(identifier)
-        filepaths.append(identifier)
-        images.append(p["page_image"])  # PIL Image
+# Step 1: Check if we can skip data loading (index already exists)
+retriever: Optional[Any] = None
+need_to_build_index = REBUILD_INDEX
+
+if not REBUILD_INDEX:
+    retriever = _load_retriever_if_index_exists(INDEX_PATH)
+    if retriever is not None:
+        print(f"✓ Index loaded from {INDEX_PATH}")
+        print(f"✓ Images available at: {retriever._images_dir_path()}")
+        need_to_build_index = False
+    else:
+        print(f"Index not found, will build new index")
+        need_to_build_index = True
+
+# Step 2: Load data only if we need to build the index
+if need_to_build_index:
+    print("Loading dataset...")
+    if USE_HF_DATASET:
+        from datasets import load_dataset
+
+        dataset = load_dataset(DATASET_NAME, split=DATASET_SPLIT)
+        N = len(dataset) if MAX_DOCS is None else min(MAX_DOCS, len(dataset))
+        filepaths: list[str] = []
+        images: list[Image.Image] = []
+        for i in tqdm(range(N), desc="Loading dataset", total=N):
+            p = dataset[i]
+            # Compose a descriptive identifier for printing later
+            identifier = f"arXiv:{p['paper_arxiv_id']}|title:{p['paper_title']}|page:{int(p['page_number'])}|id:{p['page_id']}"
+            filepaths.append(identifier)
+            images.append(p["page_image"])  # PIL Image
+    else:
+        _maybe_convert_pdf_to_images(PDF, PAGES_DIR)
+        filepaths, images = _load_images_from_dir(PAGES_DIR)
+        if not images:
+            raise RuntimeError(
+                f"No images found in {PAGES_DIR}. Provide PDF path in PDF variable or ensure images exist."
+            )
+    print(f"Loaded {len(images)} images")
 else:
-    _maybe_convert_pdf_to_images(PDF, PAGES_DIR)
-    filepaths, images = _load_images_from_dir(PAGES_DIR)
-    if not images:
-        raise RuntimeError(
-            f"No images found in {PAGES_DIR}. Provide PDF path in PDF variable or ensure images exist."
-        )
+    print("Skipping dataset loading (using existing index)")
+    filepaths = []  # Not needed when using existing index
+    images = []  # Not needed when using existing index
 
 
 # %%
-# Step 2: Load model and processor
+# Step 3: Load model and processor (only if we need to build index or perform search)
 model_name, model, processor, device_str, device, dtype = _load_colvision(MODEL)
 print(f"Using model={model_name}, device={device_str}, dtype={dtype}")
 
@@ -387,34 +109,39 @@ print(f"Using model={model_name}, device={device_str}, dtype={dtype}")
 # %%
 
 # %%
-# Step 3: Build or load index
-retriever: Optional[LeannMultiVector] = None
-if not REBUILD_INDEX:
-    try:
-        one_vec = _embed_images(model, processor, [images[0]])[0]
-        retriever = _load_retriever_if_index_exists(INDEX_PATH, dim=int(one_vec.shape[-1]))
-    except Exception:
-        retriever = None
-
-if retriever is None:
+# Step 4: Build index if needed
+if need_to_build_index and retriever is None:
+    print("Building index...")
     doc_vecs = _embed_images(model, processor, images)
-    retriever = _build_index(INDEX_PATH, doc_vecs, filepaths)
+    retriever = _build_index(INDEX_PATH, doc_vecs, filepaths, images)
+    print(f"✓ Index built and images saved to: {retriever._images_dir_path()}")
+    # Clear memory
+    del images, filepaths, doc_vecs
+
+# Note: Images are now stored in the index, retriever will load them on-demand from disk
+
 
 # %%
-# Step 4: Embed query and search
+# Step 5: Embed query and search
 q_vec = _embed_queries(model, processor, [QUERY])[0]
-results = retriever.search(q_vec.float().numpy(), topk=TOPK, first_stage_k=FIRST_STAGE_K)
+results = retriever.search(q_vec.float().numpy(), topk=TOPK)
 if not results:
     print("No results found.")
 else:
     print(f'Top {len(results)} results for query: "{QUERY}"')
     top_images: list[Image.Image] = []
     for rank, (score, doc_id) in enumerate(results, start=1):
-        path = filepaths[doc_id]
+        # Retrieve image from index instead of memory
+        image = retriever.get_image(doc_id)
+        if image is None:
+            print(f"Warning: Could not retrieve image for doc_id {doc_id}")
+            continue
+
+        metadata = retriever.get_metadata(doc_id)
+        path = metadata.get("filepath", "unknown") if metadata else "unknown"
         # For HF dataset, path is a descriptive identifier, not a real file path
         print(f"{rank}) MaxSim: {score:.4f}, Page: {path}")
-        top_images.append(images[doc_id])
+        top_images.append(image)
 
     if SAVE_TOP_IMAGE:
         from pathlib import Path as _Path
@@ -427,12 +154,17 @@ else:
             else:
                 out_path = base / f"retrieved_page_rank{rank}.png"
             img.save(str(out_path))
-            print(f"Saved retrieved page (rank {rank}) to: {out_path}")
+            # Print the retrieval score (document-level MaxSim) alongside the saved path
+            try:
+                score, _doc_id = results[rank - 1]
+                print(f"Saved retrieved page (rank {rank}) [MaxSim={score:.4f}] to: {out_path}")
+            except Exception:
+                print(f"Saved retrieved page (rank {rank}) to: {out_path}")
 
 ## TODO stange results of second page of DeepSeek-V2 rather than the first page
 
 # %%
-# Step 5: Similarity maps for top-K results
+# Step 6: Similarity maps for top-K results
 if results and SIMILARITY_MAP:
     token_idx = None if SIM_TOKEN_IDX < 0 else int(SIM_TOKEN_IDX)
     from pathlib import Path as _Path
@@ -469,7 +201,7 @@ if results and SIMILARITY_MAP:
 
 
 # %%
-# Step 6: Optional answer generation
+# Step 7: Optional answer generation
 if results and ANSWER:
     qwen = QwenVL(device=device_str)
     response = qwen.answer(QUERY, top_images[:TOPK], max_new_tokens=MAX_NEW_TOKENS)
98  benchmarks/issue_159.py  Normal file
@@ -0,0 +1,98 @@
#!/usr/bin/env python3
"""
Test script to reproduce issue #159: Slow search performance

Configuration:
- GPU: A10
- embedding_model: BAAI/bge-large-zh-v1.5
- data size: 180M text (~90K chunks)
- backend: hnsw
"""

import os
import time
from pathlib import Path

from leann.api import LeannBuilder, LeannSearcher

os.environ["LEANN_LOG_LEVEL"] = "DEBUG"

# Configuration matching the issue
INDEX_PATH = "./test_issue_159.leann"
EMBEDDING_MODEL = "BAAI/bge-large-zh-v1.5"
BACKEND_NAME = "hnsw"


def generate_test_data(num_chunks=90000, chunk_size=2000):
    """Generate test data similar to 180MB text (~90K chunks)"""
    # Each chunk is approximately 2000 characters
    # 90K chunks * 2000 chars ≈ 180MB
    chunks = []
    base_text = (
        "这是一个测试文档。LEANN是一个创新的向量数据库, 通过图基选择性重计算实现97%的存储节省。"
    )

    for i in range(num_chunks):
        chunk = f"{base_text} 文档编号: {i}. " * (chunk_size // len(base_text) + 1)
        chunks.append(chunk[:chunk_size])

    return chunks


def test_search_performance():
    """Test search performance with different configurations"""
    print("=" * 80)
    print("Testing LEANN Search Performance (Issue #159)")
    print("=" * 80)

    meta_path = Path(f"{INDEX_PATH}.meta.json")
    if meta_path.exists():
        print(f"\n✓ Index already exists at {INDEX_PATH}")
        print("  Skipping build phase. Delete the index to rebuild.")
    else:
        print("\n📦 Building index...")
        print(f"  Backend: {BACKEND_NAME}")
        print(f"  Embedding Model: {EMBEDDING_MODEL}")
        print("  Generating test data (~90K chunks, ~180MB)...")

        chunks = generate_test_data(num_chunks=90000)
        print(f"  Generated {len(chunks)} chunks")
        print(f"  Total text size: {sum(len(c) for c in chunks) / (1024 * 1024):.2f} MB")

        builder = LeannBuilder(
            backend_name=BACKEND_NAME,
            embedding_model=EMBEDDING_MODEL,
        )

        print("  Adding chunks to builder...")
        start_time = time.time()
        for i, chunk in enumerate(chunks):
            builder.add_text(chunk)
            if (i + 1) % 10000 == 0:
                print(f"    Added {i + 1}/{len(chunks)} chunks...")

        print("  Building index...")
        build_start = time.time()
        builder.build_index(INDEX_PATH)
        build_time = time.time() - build_start
        print(f"  ✓ Index built in {build_time:.2f} seconds")

    # Test search with different complexity values
    print("\n🔍 Testing search performance...")
    searcher = LeannSearcher(INDEX_PATH)

    test_query = "LEANN向量数据库存储优化"

    # Test with minimal complexity (8)
    print("\n  Test 4: Minimal complexity (8)")
    print(f"  Query: '{test_query}'")
    start_time = time.time()
    results = searcher.search(test_query, top_k=10, complexity=8)
    search_time = time.time() - start_time
    print(f"  ✓ Search completed in {search_time:.2f} seconds")
    print(f"  Results: {len(results)} items")

    print("\n" + "=" * 80)


if __name__ == "__main__":
    test_search_performance()
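If you want to extend this reproduction, a small sketch like the following (not part of the committed script; it assumes the index built above already exists) sweeps a few complexity values to see how search latency scales:

    import time
    from leann.api import LeannSearcher

    searcher = LeannSearcher("./test_issue_159.leann")
    query = "LEANN向量数据库存储优化"
    for complexity in (8, 16, 32, 64):
        start = time.time()
        results = searcher.search(query, top_k=10, complexity=complexity)
        print(f"complexity={complexity}: {time.time() - start:.2f}s, {len(results)} results")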
143  benchmarks/update/README.md  Normal file
@@ -0,0 +1,143 @@
# Update Benchmarks

This directory hosts two benchmark suites that exercise LEANN's HNSW "update +
search" pipeline under different assumptions:

1. **RNG recompute latency** – measure how random-neighbour pruning and cache
   settings influence incremental `add()` latency when embeddings are fetched
   over the ZMQ embedding server.
2. **Update strategy comparison** – compare a fully sequential update pipeline
   against an offline approach that keeps the graph static and fuses results.

Both suites build a non-compact, `is_recompute=True` index so that new
embeddings are pulled from the embedding server. Benchmark outputs are written
under `.leann/bench/` by default and appended to CSV files for later plotting.

## Benchmarks

### 1. HNSW RNG Recompute Benchmark

`bench_hnsw_rng_recompute.py` evaluates incremental update latency under four
random-neighbour (RNG) configurations. Each scenario uses the same dataset but
changes the forward / reverse RNG pruning flags and whether the embedding cache
is enabled:

| Scenario name                     | Forward RNG  | Reverse RNG  | ZMQ embedding cache |
| --------------------------------- | ------------ | ------------ | ------------------- |
| `baseline`                        | Enabled      | Enabled      | Enabled             |
| `no_cache_baseline`               | Enabled      | Enabled      | **Disabled**        |
| `disable_forward_rng`             | **Disabled** | Enabled      | Enabled             |
| `disable_forward_and_reverse_rng` | **Disabled** | **Disabled** | Enabled             |

For each scenario the script:
1. (Re)builds an `is_recompute=True` index and writes it to `.leann/bench/`.
2. Starts `leann_backend_hnsw.hnsw_embedding_server` for remote embeddings.
3. Appends the requested updates using the scenario's RNG flags.
4. Records total time, latency per passage, ZMQ fetch counts, and stage-level
   timings before appending a row to the CSV output.

**Run:**
```bash
LEANN_HNSW_LOG_PATH=.leann/bench/hnsw_server.log \
LEANN_LOG_LEVEL=INFO \
uv run -m benchmarks.update.bench_hnsw_rng_recompute \
  --runs 1 \
  --index-path .leann/bench/test.leann \
  --initial-files data/PrideandPrejudice.txt \
  --update-files data/huawei_pangu.md \
  --max-initial 300 \
  --max-updates 1 \
  --add-timeout 120
```

**Output:**
- `benchmarks/update/bench_results.csv` – per-scenario timing statistics
  (including ms/passage) for each run.
- `.leann/bench/hnsw_server.log` – detailed ZMQ/server logs (path controlled by
  `LEANN_HNSW_LOG_PATH`).

_The reference CSVs checked into this branch were generated on a workstation with an NVIDIA RTX 4090 GPU; throughput numbers will differ on other hardware._
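If you only want to aggregate the CSV without the plotting script, a minimal sketch (standard library only; column names match what the benchmark writes) looks like this:

```python
import csv
from collections import defaultdict

latencies = defaultdict(list)
with open("benchmarks/update/bench_results.csv", newline="", encoding="utf-8") as f:
    for row in csv.DictReader(f):
        latencies[row["scenario"]].append(float(row["latency_ms_per_passage"]))

for scenario, values in latencies.items():
    print(f"{scenario}: {sum(values) / len(values):.1f} ms/passage over {len(values)} run(s)")
```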
### 2. Sequential vs. Offline Update Benchmark

`bench_update_vs_offline_search.py` compares two end-to-end strategies on the
same dataset:

- **Scenario A – Sequential Update**
  - Start an embedding server.
  - Sequentially call `index.add()`; each call fetches embeddings via ZMQ and
    mutates the HNSW graph.
  - After all inserts, run a search on the updated graph.
  - Metrics recorded: update time (`add_total_s`), post-update search time
    (`search_time_s`), combined total (`total_time_s`), and per-passage
    latency.

- **Scenario B – Offline Embedding + Concurrent Search**
  - Stop Scenario A's server and start a fresh embedding server.
  - Spawn two threads: one generates embeddings for the new passages offline
    (graph unchanged); the other computes the query embedding and searches the
    existing graph.
  - Merge offline similarities with the graph search results to emulate late
    fusion, then report the merged top‑k preview (see the sketch after this
    list).
  - Metrics recorded: embedding time (`emb_time_s`), search time
    (`search_time_s`), concurrent makespan (`makespan_s`), and scenario total.
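The late-fusion step in Scenario B boils down to scoring the not-yet-inserted passages against the query embedding and merging by score with the graph results. A rough sketch (names are illustrative, not the script's API; inner-product scores assumed):

```python
import numpy as np

def merge_topk(graph_hits, offline_embs, query_emb, k):
    """graph_hits: [(label, score)] from the existing index; offline_embs: (n, d); query_emb: (d,)."""
    offline = [(f"offline:{i}", float(s)) for i, s in enumerate(offline_embs @ query_emb)]
    merged = sorted(list(graph_hits) + offline, key=lambda item: item[1], reverse=True)
    return merged[:k]
```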
**Run (both scenarios):**
```bash
uv run -m benchmarks.update.bench_update_vs_offline_search \
  --index-path .leann/bench/offline_vs_update.leann \
  --max-initial 300 \
  --num-updates 1
```

You can pass `--only A` or `--only B` to run a single scenario. The script will
print timing summaries to stdout and append the results to CSV.

**Output:**
- `benchmarks/update/offline_vs_update.csv` – per-scenario timing statistics for
  Scenario A and B.
- Console output includes Scenario B's merged top‑k preview for quick sanity
  checks.

_The sample results committed here come from runs on an RTX 4090-equipped machine; expect variations if you benchmark on different GPUs._
### 3. Visualisation

`plot_bench_results.py` combines the RNG benchmark and the update strategy
benchmark into a single two-panel plot.

**Run:**
```bash
uv run -m benchmarks.update.plot_bench_results \
  --csv benchmarks/update/bench_results.csv \
  --csv-right benchmarks/update/offline_vs_update.csv \
  --out benchmarks/update/bench_latency_from_csv.png
```

**Options:**
- `--broken-y` – Enable a broken Y-axis (default: true when appropriate).
- `--csv` – RNG benchmark results CSV (left panel).
- `--csv-right` – Update strategy results CSV (right panel).
- `--out` – Output image path (PNG/PDF supported).

**Output:**
- `benchmarks/update/bench_latency_from_csv.png` – visual comparison of the two
  suites.
- `benchmarks/update/bench_latency_from_csv.pdf` – PDF version, suitable for
  slides/papers.

## Parameters & Environment

### Common CLI Flags
- `--max-initial` – Number of initial passages used to seed the index.
- `--max-updates` / `--num-updates` – Number of passages to treat as updates.
- `--index-path` – Base path (without extension) where the LEANN index is stored.
- `--runs` – Number of repetitions (RNG benchmark only).

### Environment Variables
- `LEANN_HNSW_LOG_PATH` – File to receive embedding-server logs (optional).
- `LEANN_LOG_LEVEL` – Logging verbosity (DEBUG/INFO/WARNING/ERROR).
- `CUDA_VISIBLE_DEVICES` – Set to an empty string if you want to force CPU
  execution of the embedding model.

With these scripts you can easily replicate LEANN's update benchmarks, compare
multiple RNG strategies, and evaluate whether sequential updates or offline
fusion better match your latency/accuracy trade-offs.
16  benchmarks/update/__init__.py  Normal file
@@ -0,0 +1,16 @@
"""Benchmarks for LEANN update workflows."""

# Expose helper to locate repository root for other modules that need it.
from pathlib import Path


def find_repo_root() -> Path:
    """Return the project root containing pyproject.toml."""
    current = Path(__file__).resolve()
    for parent in current.parents:
        if (parent / "pyproject.toml").exists():
            return parent
    return current.parents[1]


__all__ = ["find_repo_root"]
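For reference, a trivial usage sketch of the exported helper (assuming the package is importable from the repo root):

    from benchmarks.update import find_repo_root

    repo_root = find_repo_root()
    print(repo_root / "data")  # e.g. resolve benchmark inputs relative to the repository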
804  benchmarks/update/bench_hnsw_rng_recompute.py  Normal file
@@ -0,0 +1,804 @@
"""Benchmark incremental HNSW add() under different RNG pruning modes with real
embedding recomputation.

This script clones the structure of ``examples/dynamic_update_no_recompute.py``
so that we build a non-compact ``is_recompute=True`` index, spin up the
standard HNSW embedding server, and measure how long incremental ``add`` takes
when RNG pruning is fully enabled vs. partially/fully disabled.

Example usage (run from the repo root; downloads the model on first run)::

    uv run -m benchmarks.update.bench_hnsw_rng_recompute \
        --index-path .leann/bench/leann-demo.leann \
        --runs 1

You can tweak the input documents with ``--initial-files`` / ``--update-files``
if you want a larger or different workload, and change the embedding model via
``--model-name``.
"""

import argparse
import json
import logging
import os
import pickle
import re
import sys
import time
from pathlib import Path
from typing import Any

import msgpack
import numpy as np
import zmq
from leann.api import LeannBuilder

if os.environ.get("LEANN_FORCE_CPU", "").lower() in ("1", "true", "yes"):
    os.environ.setdefault("CUDA_VISIBLE_DEVICES", "")

from leann.embedding_compute import compute_embeddings
from leann.embedding_server_manager import EmbeddingServerManager
from leann.registry import register_project_directory
from leann_backend_hnsw import faiss  # type: ignore
from leann_backend_hnsw.convert_to_csr import prune_hnsw_embeddings_inplace

logger = logging.getLogger(__name__)
if not logging.getLogger().handlers:
    logging.basicConfig(level=logging.INFO)


def _find_repo_root() -> Path:
    """Locate project root by walking up until pyproject.toml is found."""
    current = Path(__file__).resolve()
    for parent in current.parents:
        if (parent / "pyproject.toml").exists():
            return parent
    # Fallback: assume repo is two levels up (../..)
    return current.parents[2]


REPO_ROOT = _find_repo_root()
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))

from apps.chunking import create_text_chunks  # noqa: E402

DEFAULT_INITIAL_FILES = [
    REPO_ROOT / "data" / "2501.14312v1 (1).pdf",
    REPO_ROOT / "data" / "huawei_pangu.md",
]
DEFAULT_UPDATE_FILES = [REPO_ROOT / "data" / "2506.08276v1.pdf"]

DEFAULT_HNSW_LOG = Path(".leann/bench/hnsw_server.log")


def load_chunks_from_files(paths: list[Path], limit: int | None = None) -> list[str]:
    from llama_index.core import SimpleDirectoryReader

    documents = []
    for path in paths:
        p = path.expanduser().resolve()
        if not p.exists():
            raise FileNotFoundError(f"Input path not found: {p}")
        if p.is_dir():
            reader = SimpleDirectoryReader(str(p), recursive=False)
            documents.extend(reader.load_data(show_progress=True))
        else:
            reader = SimpleDirectoryReader(input_files=[str(p)])
            documents.extend(reader.load_data(show_progress=True))

    if not documents:
        return []

    chunks = create_text_chunks(
        documents,
        chunk_size=512,
        chunk_overlap=128,
        use_ast_chunking=False,
    )
    cleaned = [c for c in chunks if isinstance(c, str) and c.strip()]
    if limit is not None:
        cleaned = cleaned[:limit]
    return cleaned


def ensure_index_dir(index_path: Path) -> None:
    index_path.parent.mkdir(parents=True, exist_ok=True)


def cleanup_index_files(index_path: Path) -> None:
    parent = index_path.parent
    if not parent.exists():
        return
    stem = index_path.stem
    for file in parent.glob(f"{stem}*"):
        if file.is_file():
            file.unlink()


def build_initial_index(
    index_path: Path,
    paragraphs: list[str],
    model_name: str,
    embedding_mode: str,
    distance_metric: str,
    ef_construction: int,
) -> None:
    builder = LeannBuilder(
        backend_name="hnsw",
        embedding_model=model_name,
        embedding_mode=embedding_mode,
        is_compact=False,
        is_recompute=True,
        distance_metric=distance_metric,
        backend_kwargs={
            "distance_metric": distance_metric,
            "is_compact": False,
            "is_recompute": True,
            "efConstruction": ef_construction,
        },
    )
    for idx, passage in enumerate(paragraphs):
        builder.add_text(passage, metadata={"id": str(idx)})
    builder.build_index(str(index_path))


def prepare_new_chunks(paragraphs: list[str]) -> list[dict[str, Any]]:
    return [{"text": text, "metadata": {}} for text in paragraphs]


def benchmark_update_with_mode(
    index_path: Path,
    new_chunks: list[dict[str, Any]],
    model_name: str,
    embedding_mode: str,
    distance_metric: str,
    disable_forward_rng: bool,
    disable_reverse_rng: bool,
    server_port: int,
    add_timeout: int,
    ef_construction: int,
) -> tuple[float, float]:
    meta_path = index_path.parent / f"{index_path.name}.meta.json"
    passages_file = index_path.parent / f"{index_path.name}.passages.jsonl"
    offset_file = index_path.parent / f"{index_path.name}.passages.idx"
    index_file = index_path.parent / f"{index_path.stem}.index"

    with open(meta_path, encoding="utf-8") as f:
        meta = json.load(f)

    with open(offset_file, "rb") as f:
        offset_map: dict[str, int] = pickle.load(f)
    existing_ids = set(offset_map.keys())

    valid_chunks: list[dict[str, Any]] = []
    for chunk in new_chunks:
        text = chunk.get("text", "")
        if not isinstance(text, str) or not text.strip():
            continue
        metadata = chunk.setdefault("metadata", {})
        passage_id = chunk.get("id") or metadata.get("id")
        if passage_id and passage_id in existing_ids:
            raise ValueError(f"Passage ID '{passage_id}' already exists in the index.")
        valid_chunks.append(chunk)

    if not valid_chunks:
        raise ValueError("No valid chunks to append.")

    texts_to_embed = [chunk["text"] for chunk in valid_chunks]
    embeddings = compute_embeddings(
        texts_to_embed,
        model_name,
        mode=embedding_mode,
        is_build=False,
        batch_size=16,
    )

    embeddings = np.ascontiguousarray(embeddings, dtype=np.float32)
    if distance_metric == "cosine":
        norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
        norms[norms == 0] = 1
        embeddings = embeddings / norms

    index = faiss.read_index(str(index_file))
    index.is_recompute = True
    if getattr(index, "storage", None) is None:
        if index.metric_type == faiss.METRIC_INNER_PRODUCT:
            storage_index = faiss.IndexFlatIP(index.d)
        else:
            storage_index = faiss.IndexFlatL2(index.d)
        index.storage = storage_index
        index.own_fields = True
        try:
            storage_index.ntotal = index.ntotal
        except AttributeError:
            pass
    try:
        index.hnsw.set_disable_rng_during_add(disable_forward_rng)
        index.hnsw.set_disable_reverse_prune(disable_reverse_rng)
        if ef_construction is not None:
            index.hnsw.efConstruction = ef_construction
    except AttributeError:
        pass

    applied_forward = getattr(index.hnsw, "disable_rng_during_add", None)
    applied_reverse = getattr(index.hnsw, "disable_reverse_prune", None)
    logger.info(
        "HNSW RNG config -> requested forward=%s, reverse=%s | applied forward=%s, reverse=%s",
        disable_forward_rng,
        disable_reverse_rng,
        applied_forward,
        applied_reverse,
    )

    base_id = index.ntotal
    for offset, chunk in enumerate(valid_chunks):
        new_id = str(base_id + offset)
        chunk.setdefault("metadata", {})["id"] = new_id
        chunk["id"] = new_id

    rollback_size = passages_file.stat().st_size if passages_file.exists() else 0
    offset_map_backup = offset_map.copy()

    try:
        with open(passages_file, "a", encoding="utf-8") as f:
            for chunk in valid_chunks:
                offset = f.tell()
                json.dump(
                    {
                        "id": chunk["id"],
                        "text": chunk["text"],
                        "metadata": chunk.get("metadata", {}),
                    },
                    f,
                    ensure_ascii=False,
                )
                f.write("\n")
                offset_map[chunk["id"]] = offset

        with open(offset_file, "wb") as f:
            pickle.dump(offset_map, f)

        server_manager = EmbeddingServerManager(
            backend_module_name="leann_backend_hnsw.hnsw_embedding_server"
        )
        server_started, actual_port = server_manager.start_server(
            port=server_port,
            model_name=model_name,
            embedding_mode=embedding_mode,
            passages_file=str(meta_path),
            distance_metric=distance_metric,
        )
        if not server_started:
            raise RuntimeError("Failed to start embedding server.")

        if hasattr(index.hnsw, "set_zmq_port"):
            index.hnsw.set_zmq_port(actual_port)
        elif hasattr(index, "set_zmq_port"):
            index.set_zmq_port(actual_port)

        _warmup_embedding_server(actual_port)

        total_start = time.time()
        add_elapsed = 0.0

        try:
            import signal

            def _timeout_handler(signum, frame):
                raise TimeoutError("incremental add timed out")

            if add_timeout > 0:
                signal.signal(signal.SIGALRM, _timeout_handler)
                signal.alarm(add_timeout)

            add_start = time.time()
            for i in range(embeddings.shape[0]):
                index.add(1, faiss.swig_ptr(embeddings[i : i + 1]))
            add_elapsed = time.time() - add_start
            if add_timeout > 0:
                signal.alarm(0)
            faiss.write_index(index, str(index_file))
        finally:
            server_manager.stop_server()

    except TimeoutError:
        raise
    except Exception:
        if passages_file.exists():
            with open(passages_file, "rb+") as f:
                f.truncate(rollback_size)
        with open(offset_file, "wb") as f:
            pickle.dump(offset_map_backup, f)
        raise

    prune_hnsw_embeddings_inplace(str(index_file))

    meta["total_passages"] = len(offset_map)
    with open(meta_path, "w", encoding="utf-8") as f:
        json.dump(meta, f, indent=2)

    # Reset toggles so the index on disk returns to baseline behaviour.
    try:
        index.hnsw.set_disable_rng_during_add(False)
        index.hnsw.set_disable_reverse_prune(False)
    except AttributeError:
        pass
    faiss.write_index(index, str(index_file))

    total_elapsed = time.time() - total_start

    return total_elapsed, add_elapsed


def _total_zmq_nodes(log_path: Path) -> int:
    if not log_path.exists():
        return 0
    with log_path.open("r", encoding="utf-8") as log_file:
        text = log_file.read()
    return sum(int(match) for match in re.findall(r"ZMQ received (\d+) node IDs", text))


def _warmup_embedding_server(port: int) -> None:
    """Send a dummy REQ so the embedding server loads its model."""
    ctx = zmq.Context()
    try:
        sock = ctx.socket(zmq.REQ)
        sock.setsockopt(zmq.LINGER, 0)
        sock.setsockopt(zmq.RCVTIMEO, 5000)
        sock.setsockopt(zmq.SNDTIMEO, 5000)
        sock.connect(f"tcp://127.0.0.1:{port}")
        payload = msgpack.packb(["__WARMUP__"], use_bin_type=True)
        sock.send(payload)
        try:
            sock.recv()
        except zmq.error.Again:
            pass
    finally:
        sock.close()
        ctx.term()


def main() -> None:
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--index-path",
        type=Path,
        default=Path(".leann/bench/leann-demo.leann"),
        help="Output index base path (without extension).",
    )
    parser.add_argument(
        "--initial-files",
        nargs="*",
        type=Path,
        default=DEFAULT_INITIAL_FILES,
        help="Files used to build the initial index.",
    )
    parser.add_argument(
        "--update-files",
        nargs="*",
        type=Path,
        default=DEFAULT_UPDATE_FILES,
        help="Files appended during the benchmark.",
    )
    parser.add_argument(
        "--runs", type=int, default=1, help="How many times to repeat each scenario."
    )
    parser.add_argument(
        "--model-name",
        default="sentence-transformers/all-MiniLM-L6-v2",
        help="Embedding model used for build/update.",
    )
    parser.add_argument(
        "--embedding-mode",
        default="sentence-transformers",
        help="Embedding mode passed to LeannBuilder/embedding server.",
    )
    parser.add_argument(
        "--distance-metric",
        default="mips",
        choices=["mips", "l2", "cosine"],
        help="Distance metric for HNSW backend.",
    )
    parser.add_argument(
        "--ef-construction",
        type=int,
        default=200,
        help="efConstruction setting for initial build.",
    )
    parser.add_argument(
        "--server-port",
        type=int,
        default=5557,
        help="Port for the real embedding server.",
    )
    parser.add_argument(
        "--max-initial",
        type=int,
        default=300,
        help="Optional cap on initial passages (after chunking).",
    )
    parser.add_argument(
        "--max-updates",
        type=int,
        default=1,
        help="Optional cap on update passages (after chunking).",
    )
    parser.add_argument(
        "--add-timeout",
        type=int,
        default=900,
        help="Timeout in seconds for the incremental add loop (0 = no timeout).",
    )
    parser.add_argument(
        "--plot-path",
        type=Path,
        default=Path("bench_latency.png"),
        help="Where to save the latency bar plot.",
    )
    parser.add_argument(
        "--cap-y",
        type=float,
        default=None,
        help="Cap Y-axis (ms). Bars above are hatched and annotated.",
    )
    parser.add_argument(
        "--broken-y",
        action="store_true",
        help="Use broken Y-axis (two stacked axes with gap). Overrides --cap-y unless both provided.",
    )
    parser.add_argument(
        "--lower-cap-y",
        type=float,
        default=None,
        help="Lower axes upper bound for broken Y (ms). Default=1.1x second-highest.",
    )
    parser.add_argument(
        "--upper-start-y",
        type=float,
        default=None,
        help="Upper axes lower bound for broken Y (ms). Default=1.2x second-highest.",
    )
    parser.add_argument(
        "--csv-path",
        type=Path,
        default=Path("benchmarks/update/bench_results.csv"),
        help="Where to append per-scenario results as CSV.",
    )

    args = parser.parse_args()

    register_project_directory(REPO_ROOT)

    initial_paragraphs = load_chunks_from_files(args.initial_files, args.max_initial)
    update_paragraphs = load_chunks_from_files(args.update_files, args.max_updates)
    if not update_paragraphs:
        raise ValueError("No update passages found; please provide --update-files with content.")

    update_chunks = prepare_new_chunks(update_paragraphs)
    ensure_index_dir(args.index_path)

    scenarios = [
        ("baseline", False, False, True),
        ("no_cache_baseline", False, False, False),
        ("disable_forward_rng", True, False, True),
        ("disable_forward_and_reverse_rng", True, True, True),
    ]

    log_path = Path(os.environ.get("LEANN_HNSW_LOG_PATH", DEFAULT_HNSW_LOG))
    log_path.parent.mkdir(parents=True, exist_ok=True)
    os.environ["LEANN_HNSW_LOG_PATH"] = str(log_path.resolve())
    os.environ.setdefault("LEANN_LOG_LEVEL", "INFO")

    results_total: dict[str, list[float]] = {name: [] for name, *_ in scenarios}
    results_add: dict[str, list[float]] = {name: [] for name, *_ in scenarios}
    results_zmq: dict[str, list[int]] = {name: [] for name, *_ in scenarios}
    results_stageA: dict[str, list[float]] = {name: [] for name, *_ in scenarios}
    results_stageBC: dict[str, list[float]] = {name: [] for name, *_ in scenarios}
    results_ms_per_passage: dict[str, list[float]] = {name: [] for name, *_ in scenarios}

    # CSV setup
    import csv

    run_id = time.strftime("%Y%m%d-%H%M%S")
    csv_fields = [
        "run_id",
        "scenario",
        "cache_enabled",
        "ef_construction",
        "max_initial",
        "max_updates",
        "total_time_s",
        "add_only_s",
        "latency_ms_per_passage",
        "zmq_nodes",
        "stageA_time_s",
        "stageBC_time_s",
        "model_name",
        "embedding_mode",
        "distance_metric",
    ]
    # Create CSV with header if missing
    if args.csv_path:
        args.csv_path.parent.mkdir(parents=True, exist_ok=True)
        if not args.csv_path.exists() or args.csv_path.stat().st_size == 0:
            with args.csv_path.open("w", newline="", encoding="utf-8") as f:
                writer = csv.DictWriter(f, fieldnames=csv_fields)
                writer.writeheader()

    for run in range(args.runs):
        print(f"\n=== Benchmark run {run + 1}/{args.runs} ===")
        for name, disable_forward, disable_reverse, cache_enabled in scenarios:
            print(f"\nScenario: {name}")
            cleanup_index_files(args.index_path)
            if log_path.exists():
                try:
                    log_path.unlink()
                except OSError:
                    pass
            os.environ["LEANN_ZMQ_EMBED_CACHE"] = "1" if cache_enabled else "0"
            build_initial_index(
                args.index_path,
                initial_paragraphs,
                args.model_name,
                args.embedding_mode,
                args.distance_metric,
                args.ef_construction,
            )

            prev_size = log_path.stat().st_size if log_path.exists() else 0

            try:
                total_elapsed, add_elapsed = benchmark_update_with_mode(
                    args.index_path,
                    update_chunks,
                    args.model_name,
                    args.embedding_mode,
                    args.distance_metric,
                    disable_forward,
                    disable_reverse,
                    args.server_port,
                    args.add_timeout,
                    args.ef_construction,
                )
            except TimeoutError as exc:
                print(f"Scenario {name} timed out: {exc}")
                continue

            curr_size = log_path.stat().st_size if log_path.exists() else 0
            if curr_size < prev_size:
                prev_size = 0
            zmq_count = 0
            if log_path.exists():
                with log_path.open("r", encoding="utf-8") as log_file:
                    log_file.seek(prev_size)
                    new_entries = log_file.read()
                zmq_count = sum(
                    int(match) for match in re.findall(r"ZMQ received (\d+) node IDs", new_entries)
                )
                stageA = sum(
                    float(x)
                    for x in re.findall(r"Distance calculation E2E time: ([0-9.]+)s", new_entries)
                )
                stageBC = sum(
                    float(x) for x in re.findall(r"ZMQ E2E time: ([0-9.]+)s", new_entries)
                )
            else:
                stageA = 0.0
                stageBC = 0.0

            per_chunk = add_elapsed / len(update_chunks)
            print(
                f"Total time: {total_elapsed:.3f} s | add-only: {add_elapsed:.3f} s "
                f"for {len(update_chunks)} passages => {per_chunk * 1e3:.3f} ms/passage"
            )
            print(f"ZMQ node fetch total: {zmq_count}")
            results_total[name].append(total_elapsed)
            results_add[name].append(add_elapsed)
            results_zmq[name].append(zmq_count)
            results_ms_per_passage[name].append(per_chunk * 1e3)
            results_stageA[name].append(stageA)
            results_stageBC[name].append(stageBC)

            # Append row to CSV
            if args.csv_path:
                row = {
                    "run_id": run_id,
                    "scenario": name,
                    "cache_enabled": 1 if cache_enabled else 0,
                    "ef_construction": args.ef_construction,
                    "max_initial": args.max_initial,
                    "max_updates": args.max_updates,
                    "total_time_s": round(total_elapsed, 6),
                    "add_only_s": round(add_elapsed, 6),
                    "latency_ms_per_passage": round(per_chunk * 1e3, 6),
                    "zmq_nodes": int(zmq_count),
                    "stageA_time_s": round(stageA, 6),
                    "stageBC_time_s": round(stageBC, 6),
                    "model_name": args.model_name,
                    "embedding_mode": args.embedding_mode,
                    "distance_metric": args.distance_metric,
                }
                with args.csv_path.open("a", newline="", encoding="utf-8") as f:
                    writer = csv.DictWriter(f, fieldnames=csv_fields)
                    writer.writerow(row)

    print("\n=== Summary ===")
    for name in results_add:
        add_values = results_add[name]
        total_values = results_total[name]
        zmq_values = results_zmq[name]
        latency_values = results_ms_per_passage[name]
        if not add_values:
            print(f"{name}: no successful runs")
            continue
        avg_add = sum(add_values) / len(add_values)
        avg_total = sum(total_values) / len(total_values)
        avg_zmq = sum(zmq_values) / len(zmq_values) if zmq_values else 0.0
        avg_latency = sum(latency_values) / len(latency_values) if latency_values else 0.0
        runs = len(add_values)
        print(
            f"{name}: add-only avg {avg_add:.3f} s | total avg {avg_total:.3f} s "
            f"| ZMQ avg {avg_zmq:.1f} node fetches | latency {avg_latency:.2f} ms/passage over {runs} run(s)"
        )

    if args.plot_path:
        try:
            import matplotlib.pyplot as plt

            labels = [name for name, *_ in scenarios]
            values = [
                sum(results_ms_per_passage[name]) / len(results_ms_per_passage[name])
                if results_ms_per_passage[name]
                else 0.0
                for name in labels
            ]

            def _auto_cap(vals: list[float]) -> float | None:
                s = sorted(vals, reverse=True)
                if len(s) < 2:
                    return None
                if s[1] > 0 and s[0] >= 2.5 * s[1]:
                    return s[1] * 1.1
                return None

            def _fmt_ms(v: float) -> str:
                return f"{v / 1000:.1f}k" if v >= 1000 else f"{v:.1f}"

            colors = ["#4e79a7", "#f28e2c", "#e15759", "#76b7b2"]

            if args.broken_y:
                s = sorted(values, reverse=True)
                second = s[1] if len(s) >= 2 else (s[0] if s else 0.0)
                lower_cap = args.lower_cap_y if args.lower_cap_y is not None else second * 1.1
                upper_start = (
                    args.upper_start_y
                    if args.upper_start_y is not None
                    else max(second * 1.2, lower_cap * 1.02)
                )
                ymax = max(values) * 1.10 if values else 1.0
                fig, (ax_top, ax_bottom) = plt.subplots(
                    2,
                    1,
                    sharex=True,
                    figsize=(7.4, 5.0),
                    gridspec_kw={"height_ratios": [1, 3], "hspace": 0.05},
                )
                x = list(range(len(labels)))
                ax_bottom.bar(x, values, color=colors[: len(labels)], width=0.8)
                ax_top.bar(x, values, color=colors[: len(labels)], width=0.8)
                ax_bottom.set_ylim(0, lower_cap)
                ax_top.set_ylim(upper_start, ymax)
                for i, v in enumerate(values):
                    if v <= lower_cap:
                        ax_bottom.text(
                            i,
                            v + lower_cap * 0.02,
                            _fmt_ms(v),
                            ha="center",
                            va="bottom",
                            fontsize=9,
                        )
                    else:
                        ax_top.text(i, v, _fmt_ms(v), ha="center", va="bottom", fontsize=9)
                ax_top.spines["bottom"].set_visible(False)
                ax_bottom.spines["top"].set_visible(False)
                ax_top.tick_params(labeltop=False)
                ax_bottom.xaxis.tick_bottom()
                d = 0.015
                kwargs = {"transform": ax_top.transAxes, "color": "k", "clip_on": False}
                ax_top.plot((-d, +d), (-d, +d), **kwargs)
                ax_top.plot((1 - d, 1 + d), (-d, +d), **kwargs)
                kwargs.update({"transform": ax_bottom.transAxes})
                ax_bottom.plot((-d, +d), (1 - d, 1 + d), **kwargs)
                ax_bottom.plot((1 - d, 1 + d), (1 - d, 1 + d), **kwargs)
                ax_bottom.set_xticks(range(len(labels)))
                ax_bottom.set_xticklabels(labels)
                ax = ax_bottom
            else:
                cap = args.cap_y or _auto_cap(values)
                plt.figure(figsize=(7.2, 4.2))
                ax = plt.gca()
                if cap is not None:
                    show_vals = [min(v, cap) for v in values]
                    bars = []
                    for i, (v, show) in enumerate(zip(values, show_vals)):
                        b = ax.bar(i, show, color=colors[i], width=0.8)
                        bars.append(b[0])
                        if v > cap:
                            bars[-1].set_hatch("//")
                            ax.text(i, cap * 1.02, _fmt_ms(v), ha="center", va="bottom", fontsize=9)
                        else:
                            ax.text(
                                i,
                                show + max(1.0, 0.01 * (cap or show)),
                                _fmt_ms(v),
                                ha="center",
                                va="bottom",
                                fontsize=9,
                            )
                    ax.set_ylim(0, cap * 1.10)
                    ax.plot(
                        [0.02 - 0.02, 0.02 + 0.02],
                        [0.98 + 0.02, 0.98 - 0.02],
                        transform=ax.transAxes,
                        color="k",
                        lw=1,
                    )
                    ax.plot(
                        [0.98 - 0.02, 0.98 + 0.02],
                        [0.98 + 0.02, 0.98 - 0.02],
                        transform=ax.transAxes,
                        color="k",
                        lw=1,
                    )
                    if any(v > cap for v in values):
                        ax.legend(
                            [bars[0]], ["capped"], fontsize=8, frameon=False, loc="upper right"
                        )
                    ax.set_xticks(range(len(labels)))
                    ax.set_xticklabels(labels)
                else:
                    ax.bar(labels, values, color=colors[: len(labels)])
                    for idx, val in enumerate(values):
                        ax.text(idx, val + 1.0, f"{val:.1f}", ha="center", va="bottom")

            plt.ylabel("Average add latency (ms per passage)")
            plt.title(f"Initial passages {args.max_initial}, updates {args.max_updates}")
            plt.tight_layout()
            plt.savefig(args.plot_path)
            print(f"Saved latency bar plot to {args.plot_path}")
            # ZMQ time split (Stage A vs B/C)
            try:
                plt.figure(figsize=(6, 4))
                a_vals = [sum(results_stageA[n]) / max(1, len(results_stageA[n])) for n in labels]
                bc_vals = [
                    sum(results_stageBC[n]) / max(1, len(results_stageBC[n])) for n in labels
                ]
                ind = range(len(labels))
                plt.bar(ind, a_vals, color="#4e79a7", label="Stage A distance (s)")
                plt.bar(
                    ind, bc_vals, bottom=a_vals, color="#e15759", label="Stage B/C embed-by-id (s)"
                )
                plt.xticks(list(ind), labels, rotation=10)
                plt.ylabel("Server ZMQ time (s)")
                plt.title(
                    f"ZMQ time split (initial {args.max_initial}, updates {args.max_updates})"
                )
                plt.legend()
                out2 = args.plot_path.with_name(
                    args.plot_path.stem + "_zmq_split" + args.plot_path.suffix
                )
                plt.tight_layout()
                plt.savefig(out2)
                print(f"Saved ZMQ time split plot to {out2}")
            except Exception as e:
                print("Failed to plot ZMQ split:", e)
        except ImportError:
            print("matplotlib not available; skipping plot generation")

    # leave the last build on disk for inspection


if __name__ == "__main__":
    main()
5  benchmarks/update/bench_results.csv  Normal file
@@ -0,0 +1,5 @@
run_id,scenario,cache_enabled,ef_construction,max_initial,max_updates,total_time_s,add_only_s,latency_ms_per_passage,zmq_nodes,stageA_time_s,stageBC_time_s,model_name,embedding_mode,distance_metric
20251024-133101,baseline,1,200,300,1,3.391856,1.120359,1120.359421,126,0.507821,0.601608,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips
20251024-133101,no_cache_baseline,0,200,300,1,34.941514,32.91376,32913.760185,4033,0.506933,32.159928,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips
20251024-133101,disable_forward_rng,1,200,300,1,2.746756,0.8202,820.200443,66,0.474354,0.338454,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips
20251024-133101,disable_forward_and_reverse_rng,1,200,300,1,2.396566,0.521478,521.478415,1,0.508973,0.006938,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips
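Reading the checked-in sample run above: with the ZMQ embedding cache disabled the single-passage add takes about 32.9 s versus about 1.12 s for the cached baseline (roughly a 29x gap), and essentially all of the difference shows up in the embed-by-id stage (stageBC_time_s ≈ 32.2 s vs ≈ 0.60 s) together with the jump in fetched nodes (4033 vs 126), while the distance-calculation stage stays near 0.5 s in every scenario.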
704  benchmarks/update/bench_update_vs_offline_search.py  Normal file
@@ -0,0 +1,704 @@
"""
Compare two latency models for small incremental updates vs. search:

Scenario A (sequential update then search):
- Build initial HNSW (is_recompute=True)
- Start embedding server (ZMQ) for recompute
- Add N passages one-by-one (each triggers recompute over ZMQ)
- Then run a search query on the updated index
- Report total time = sum(add_i) + search_time, with breakdowns

Scenario B (offline embeds + concurrent search; no graph updates):
- Do NOT insert the N passages into the graph
- In parallel: (1) compute embeddings for the N passages; (2) compute query
  embedding and run a search on the existing index
- After both finish, compute similarity between the query embedding and the N
  new passage embeddings, merge with the index search results by score, and
  report time = max(embed_time, search_time) (i.e., no blocking on updates)

This script reuses the model/data loading conventions of
examples/bench_hnsw_rng_recompute.py but focuses on end-to-end latency
comparison for the two execution strategies above.

Example (from the repository root):
    uv run -m benchmarks.update.bench_update_vs_offline_search \
        --index-path .leann/bench/offline_vs_update.leann \
        --max-initial 300 --num-updates 5 --k 10
"""
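# Worked example (hypothetical numbers, for intuition only, not measured output):
# if each of 5 adds costs 0.6 s and the post-update search costs 0.05 s,
# Scenario A totals 5 * 0.6 + 0.05 = 3.05 s.  If embedding the same 5 passages
# offline costs 0.4 s while the concurrent search costs 0.05 s, Scenario B's
# makespan is max(0.4, 0.05) = 0.4 s plus a cheap score merge.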
import argparse
import csv
import json
import logging
import os
import pickle
import sys
import threading
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Any

import numpy as np
import psutil  # type: ignore
from leann.api import LeannBuilder

if os.environ.get("LEANN_FORCE_CPU", "").lower() in ("1", "true", "yes"):
    os.environ.setdefault("CUDA_VISIBLE_DEVICES", "")

from leann.embedding_compute import compute_embeddings
from leann.embedding_server_manager import EmbeddingServerManager
from leann.registry import register_project_directory
from leann_backend_hnsw import faiss  # type: ignore

logger = logging.getLogger(__name__)
if not logging.getLogger().handlers:
    logging.basicConfig(level=logging.INFO)


def _find_repo_root() -> Path:
    """Locate project root by walking up until pyproject.toml is found."""
    current = Path(__file__).resolve()
    for parent in current.parents:
        if (parent / "pyproject.toml").exists():
            return parent
    # Fallback: assume repo is two levels up (../..)
    return current.parents[2]


REPO_ROOT = _find_repo_root()
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))

from apps.chunking import create_text_chunks  # noqa: E402

DEFAULT_INITIAL_FILES = [
    REPO_ROOT / "data" / "2501.14312v1 (1).pdf",
    REPO_ROOT / "data" / "huawei_pangu.md",
]
DEFAULT_UPDATE_FILES = [REPO_ROOT / "data" / "2506.08276v1.pdf"]


def load_chunks_from_files(paths: list[Path], limit: int | None = None) -> list[str]:
    from llama_index.core import SimpleDirectoryReader

    documents = []
    for path in paths:
        p = path.expanduser().resolve()
        if not p.exists():
            raise FileNotFoundError(f"Input path not found: {p}")
        if p.is_dir():
            reader = SimpleDirectoryReader(str(p), recursive=False)
            documents.extend(reader.load_data(show_progress=True))
        else:
            reader = SimpleDirectoryReader(input_files=[str(p)])
            documents.extend(reader.load_data(show_progress=True))

    if not documents:
        return []

    chunks = create_text_chunks(
        documents,
        chunk_size=512,
        chunk_overlap=128,
        use_ast_chunking=False,
    )
    cleaned = [c for c in chunks if isinstance(c, str) and c.strip()]
    if limit is not None:
        cleaned = cleaned[:limit]
    return cleaned


def ensure_index_dir(index_path: Path) -> None:
    index_path.parent.mkdir(parents=True, exist_ok=True)


def cleanup_index_files(index_path: Path) -> None:
    parent = index_path.parent
    if not parent.exists():
        return
    stem = index_path.stem
    for file in parent.glob(f"{stem}*"):
        if file.is_file():
            file.unlink()


def build_initial_index(
    index_path: Path,
    paragraphs: list[str],
    model_name: str,
    embedding_mode: str,
    distance_metric: str,
    ef_construction: int,
) -> None:
    builder = LeannBuilder(
        backend_name="hnsw",
        embedding_model=model_name,
        embedding_mode=embedding_mode,
        is_compact=False,
        is_recompute=True,
        distance_metric=distance_metric,
        backend_kwargs={
            "distance_metric": distance_metric,
            "is_compact": False,
            "is_recompute": True,
            "efConstruction": ef_construction,
        },
    )
    for idx, passage in enumerate(paragraphs):
        builder.add_text(passage, metadata={"id": str(idx)})
    builder.build_index(str(index_path))


def _maybe_norm_cosine(vecs: np.ndarray, metric: str) -> np.ndarray:
    if metric == "cosine":
        vecs = np.ascontiguousarray(vecs, dtype=np.float32)
        norms = np.linalg.norm(vecs, axis=1, keepdims=True)
        norms[norms == 0] = 1
        vecs = vecs / norms
    return vecs


def _read_index_for_search(index_path: Path) -> Any:
    index_file = index_path.parent / f"{index_path.stem}.index"
    # Force-disable experimental disk cache when loading the index so that
    # incremental benchmarks don't pick up stale top-degree bitmaps.
    cfg = faiss.HNSWIndexConfig()
    cfg.is_recompute = True
    if hasattr(cfg, "disk_cache_ratio"):
        cfg.disk_cache_ratio = 0.0
    if hasattr(cfg, "external_storage_path"):
        cfg.external_storage_path = None
    io_flags = getattr(faiss, "IO_FLAG_MMAP", 0)
    index = faiss.read_index(str(index_file), io_flags, cfg)
    # ensure recompute mode persists after reload
    try:
        index.is_recompute = True
    except AttributeError:
        pass
    try:
        actual_ntotal = index.hnsw.levels.size()
    except AttributeError:
        actual_ntotal = index.ntotal
    if actual_ntotal != index.ntotal:
        print(
            f"[bench_update_vs_offline_search] Correcting ntotal from {index.ntotal} to {actual_ntotal}",
            flush=True,
        )
        index.ntotal = actual_ntotal
    if getattr(index, "storage", None) is None:
        if index.metric_type == faiss.METRIC_INNER_PRODUCT:
            storage_index = faiss.IndexFlatIP(index.d)
        else:
            storage_index = faiss.IndexFlatL2(index.d)
        index.storage = storage_index
        index.own_fields = True
    return index


def _append_passages_for_updates(
    meta_path: Path,
    start_id: int,
    texts: list[str],
) -> list[str]:
    """Append update passages so the embedding server can serve recompute fetches."""

    if not texts:
        return []

    index_dir = meta_path.parent
    meta_name = meta_path.name
    if not meta_name.endswith(".meta.json"):
        raise ValueError(f"Unexpected meta filename: {meta_path}")
    index_base = meta_name[: -len(".meta.json")]

    passages_file = index_dir / f"{index_base}.passages.jsonl"
    offsets_file = index_dir / f"{index_base}.passages.idx"

    if not passages_file.exists() or not offsets_file.exists():
        raise FileNotFoundError(
            "Passage store missing; cannot register update passages for recompute mode."
        )

    with open(offsets_file, "rb") as f:
        offset_map: dict[str, int] = pickle.load(f)

    assigned_ids: list[str] = []
    with open(passages_file, "a", encoding="utf-8") as f:
        for i, text in enumerate(texts):
            passage_id = str(start_id + i)
            offset = f.tell()
            json.dump({"id": passage_id, "text": text, "metadata": {}}, f, ensure_ascii=False)
            f.write("\n")
            offset_map[passage_id] = offset
            assigned_ids.append(passage_id)

    with open(offsets_file, "wb") as f:
        pickle.dump(offset_map, f)

    try:
        with open(meta_path, encoding="utf-8") as f:
            meta = json.load(f)
    except json.JSONDecodeError:
        meta = {}
    meta["total_passages"] = len(offset_map)
    with open(meta_path, "w", encoding="utf-8") as f:
        json.dump(meta, f, indent=2)

    return assigned_ids


def _search(index: Any, q: np.ndarray, k: int) -> tuple[np.ndarray, np.ndarray]:
    q = np.ascontiguousarray(q, dtype=np.float32)
    distances = np.zeros((1, k), dtype=np.float32)
    indices = np.zeros((1, k), dtype=np.int64)
    index.search(
        1,
        faiss.swig_ptr(q),
        k,
        faiss.swig_ptr(distances),
        faiss.swig_ptr(indices),
    )
    return distances[0], indices[0]


def _score_for_metric(dist: float, metric: str) -> float:
    # Convert FAISS distance to a "higher is better" score
    if metric in ("mips", "cosine"):
        return float(dist)
    # l2 distance (smaller better) -> negative distance as score
    return -float(dist)


def _merge_results(
    index_results: tuple[np.ndarray, np.ndarray],
    offline_scores: list[tuple[int, float]],
    k: int,
    metric: str,
) -> list[tuple[str, float]]:
    distances, indices = index_results
    merged: list[tuple[str, float]] = []
    for distance, idx in zip(distances.tolist(), indices.tolist()):
        merged.append((f"idx:{idx}", _score_for_metric(distance, metric)))
    for j, s in offline_scores:
        merged.append((f"offline:{j}", s))
    merged.sort(key=lambda x: x[1], reverse=True)
    return merged[:k]


@dataclass
class ScenarioResult:
    name: str
    update_total_s: float
    search_s: float
    overall_s: float


def main() -> None:
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--index-path",
        type=Path,
        default=Path(".leann/bench/offline-vs-update.leann"),
    )
    parser.add_argument(
        "--initial-files",
        nargs="*",
        type=Path,
        default=DEFAULT_INITIAL_FILES,
    )
    parser.add_argument(
        "--update-files",
        nargs="*",
        type=Path,
        default=DEFAULT_UPDATE_FILES,
    )
    parser.add_argument("--max-initial", type=int, default=300)
    parser.add_argument("--num-updates", type=int, default=5)
    parser.add_argument("--k", type=int, default=10, help="Top-k for search/merge")
    parser.add_argument(
        "--query",
        type=str,
        default="neural network",
        help="Query text used for the search benchmark.",
    )
    parser.add_argument("--server-port", type=int, default=5557)
    parser.add_argument("--add-timeout", type=int, default=600)
    parser.add_argument("--model-name", default="sentence-transformers/all-MiniLM-L6-v2")
    parser.add_argument("--embedding-mode", default="sentence-transformers")
    parser.add_argument(
        "--distance-metric",
        default="mips",
        choices=["mips", "l2", "cosine"],
    )
    parser.add_argument("--ef-construction", type=int, default=200)
    parser.add_argument(
        "--only",
        choices=["A", "B", "both"],
        default="both",
        help="Run only Scenario A, Scenario B, or both",
    )
    parser.add_argument(
        "--csv-path",
        type=Path,
        default=Path("benchmarks/update/offline_vs_update.csv"),
        help="Where to append results (CSV).",
    )

    args = parser.parse_args()

    register_project_directory(REPO_ROOT)

    # Load data
    initial_paragraphs = load_chunks_from_files(args.initial_files, args.max_initial)
    update_paragraphs = load_chunks_from_files(args.update_files, None)
    if not update_paragraphs:
        raise ValueError("No update passages loaded from --update-files")
    update_paragraphs = update_paragraphs[: args.num_updates]
    if len(update_paragraphs) < args.num_updates:
        raise ValueError(
            f"Not enough update passages ({len(update_paragraphs)}) for --num-updates={args.num_updates}"
        )

    ensure_index_dir(args.index_path)
    cleanup_index_files(args.index_path)

    # Build initial index
    build_initial_index(
        args.index_path,
        initial_paragraphs,
        args.model_name,
        args.embedding_mode,
        args.distance_metric,
        args.ef_construction,
    )

    # Prepare index object and meta
    meta_path = args.index_path.parent / f"{args.index_path.name}.meta.json"
    index = _read_index_for_search(args.index_path)

    # CSV setup
    run_id = time.strftime("%Y%m%d-%H%M%S")
    if args.csv_path:
        args.csv_path.parent.mkdir(parents=True, exist_ok=True)
        csv_fields = [
            "run_id",
            "scenario",
            "max_initial",
            "num_updates",
            "k",
            "total_time_s",
            "add_total_s",
            "search_time_s",
            "emb_time_s",
            "makespan_s",
            "model_name",
            "embedding_mode",
            "distance_metric",
        ]
        if not args.csv_path.exists() or args.csv_path.stat().st_size == 0:
            with args.csv_path.open("w", newline="", encoding="utf-8") as f:
                writer = csv.DictWriter(f, fieldnames=csv_fields)
                writer.writeheader()

    # Debug: list existing HNSW server PIDs before starting
    try:
        existing = [
            p
            for p in psutil.process_iter(attrs=["pid", "cmdline"])
            if any(
                isinstance(arg, str) and "leann_backend_hnsw.hnsw_embedding_server" in arg
                for arg in (p.info.get("cmdline") or [])
            )
        ]
        if existing:
            print("[debug] Found existing hnsw_embedding_server processes before run:")
            for p in existing:
                print(f"[debug] PID={p.info['pid']} cmd={' '.join(p.info.get('cmdline') or [])}")
    except Exception as _e:
        pass

    add_total = 0.0
    search_after_add = 0.0
    total_seq = 0.0
    port_a = None
    if args.only in ("A", "both"):
        # Scenario A: sequential update then search
        start_id = index.ntotal
        assigned_ids = _append_passages_for_updates(meta_path, start_id, update_paragraphs)
        if assigned_ids:
            logger.debug(
                "Registered %d update passages starting at id %s",
                len(assigned_ids),
                assigned_ids[0],
            )
        server_manager = EmbeddingServerManager(
            backend_module_name="leann_backend_hnsw.hnsw_embedding_server"
        )
        ok, port = server_manager.start_server(
            port=args.server_port,
            model_name=args.model_name,
            embedding_mode=args.embedding_mode,
            passages_file=str(meta_path),
            distance_metric=args.distance_metric,
        )
        if not ok:
            raise RuntimeError("Failed to start embedding server")
        try:
            # Set ZMQ port for recompute mode
            if hasattr(index.hnsw, "set_zmq_port"):
                index.hnsw.set_zmq_port(port)
            elif hasattr(index, "set_zmq_port"):
                index.set_zmq_port(port)

            # Start A overall timer BEFORE computing update embeddings
            t0 = time.time()

            # Compute embeddings for updates (counted into A's overall)
            t_emb0 = time.time()
            upd_embs = compute_embeddings(
                update_paragraphs,
                args.model_name,
                mode=args.embedding_mode,
                is_build=False,
                batch_size=16,
            )
            emb_time_updates = time.time() - t_emb0
            upd_embs = np.asarray(upd_embs, dtype=np.float32)
            upd_embs = _maybe_norm_cosine(upd_embs, args.distance_metric)
|
||||||
|
|
||||||
|
# Perform sequential adds
|
||||||
|
for i in range(upd_embs.shape[0]):
|
||||||
|
t_add0 = time.time()
|
||||||
|
index.add(1, faiss.swig_ptr(upd_embs[i : i + 1]))
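# faiss.swig_ptr exposes the raw buffer of the one-row slice to the low-level C++
# add(n, x) overload, so each call inserts exactly one vector and the loop measures
# true per-add latency for Scenario A.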
|
||||||
|
add_total += time.time() - t_add0
|
||||||
|
# Don't persist index after adds to avoid contaminating Scenario B
|
||||||
|
# index_file = args.index_path.parent / f"{args.index_path.stem}.index"
|
||||||
|
# faiss.write_index(index, str(index_file))
|
||||||
|
|
||||||
|
# Search after updates
|
||||||
|
q_emb = compute_embeddings(
|
||||||
|
[args.query], args.model_name, mode=args.embedding_mode, is_build=False
|
||||||
|
)
|
||||||
|
q_emb = np.asarray(q_emb, dtype=np.float32)
|
||||||
|
q_emb = _maybe_norm_cosine(q_emb, args.distance_metric)
|
||||||
|
|
||||||
|
# Warm up search with a dummy query first
|
||||||
|
print("[DEBUG] Warming up search...")
|
||||||
|
_ = _search(index, q_emb, 1)
|
||||||
|
|
||||||
|
t_s0 = time.time()
|
||||||
|
D_upd, I_upd = _search(index, q_emb, args.k)
|
||||||
|
search_after_add = time.time() - t_s0
|
||||||
|
total_seq = time.time() - t0
|
||||||
|
finally:
|
||||||
|
server_manager.stop_server()
|
||||||
|
port_a = port
|
||||||
|
|
||||||
|
print("\n=== Scenario A: update->search (sequential) ===")
|
||||||
|
# emb_time_updates is defined only when A runs
|
||||||
|
try:
|
||||||
|
_emb_a = emb_time_updates
|
||||||
|
except NameError:
|
||||||
|
_emb_a = 0.0
|
||||||
|
print(
|
||||||
|
f"Adds: {args.num_updates} passages; embeds={_emb_a:.3f}s; add_total={add_total:.3f}s; "
|
||||||
|
f"search={search_after_add:.3f}s; overall={total_seq:.3f}s"
|
||||||
|
)
|
||||||
|
# CSV row for A
|
||||||
|
if args.csv_path:
|
||||||
|
row_a = {
|
||||||
|
"run_id": run_id,
|
||||||
|
"scenario": "A",
|
||||||
|
"max_initial": args.max_initial,
|
||||||
|
"num_updates": args.num_updates,
|
||||||
|
"k": args.k,
|
||||||
|
"total_time_s": round(total_seq, 6),
|
||||||
|
"add_total_s": round(add_total, 6),
|
||||||
|
"search_time_s": round(search_after_add, 6),
|
||||||
|
"emb_time_s": round(_emb_a, 6),
|
||||||
|
"makespan_s": 0.0,
|
||||||
|
"model_name": args.model_name,
|
||||||
|
"embedding_mode": args.embedding_mode,
|
||||||
|
"distance_metric": args.distance_metric,
|
||||||
|
}
|
||||||
|
with args.csv_path.open("a", newline="", encoding="utf-8") as f:
|
||||||
|
writer = csv.DictWriter(f, fieldnames=csv_fields)
|
||||||
|
writer.writerow(row_a)
|
||||||
|
|
||||||
|
# Verify server cleanup
|
||||||
|
try:
|
||||||
|
# short sleep to allow signal handling to finish
|
||||||
|
time.sleep(0.5)
|
||||||
|
leftovers = [
|
||||||
|
p
|
||||||
|
for p in psutil.process_iter(attrs=["pid", "cmdline"])
|
||||||
|
if any(
|
||||||
|
isinstance(arg, str) and "leann_backend_hnsw.hnsw_embedding_server" in arg
|
||||||
|
for arg in (p.info.get("cmdline") or [])
|
||||||
|
)
|
||||||
|
]
|
||||||
|
if leftovers:
|
||||||
|
print("[warn] hnsw_embedding_server process(es) still alive after A-stop:")
|
||||||
|
for p in leftovers:
|
||||||
|
print(
|
||||||
|
f"[warn] PID={p.info['pid']} cmd={' '.join(p.info.get('cmdline') or [])}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
print("[debug] server cleanup confirmed: no hnsw_embedding_server found")
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Scenario B: offline embeds + concurrent search (no graph updates)
|
||||||
|
if args.only in ("B", "both"):
|
||||||
|
# ensure a server is available for recompute search
|
||||||
|
server_manager_b = EmbeddingServerManager(
|
||||||
|
backend_module_name="leann_backend_hnsw.hnsw_embedding_server"
|
||||||
|
)
|
||||||
|
requested_port = args.server_port if port_a is None else port_a
|
||||||
|
ok_b, port_b = server_manager_b.start_server(
|
||||||
|
port=requested_port,
|
||||||
|
model_name=args.model_name,
|
||||||
|
embedding_mode=args.embedding_mode,
|
||||||
|
passages_file=str(meta_path),
|
||||||
|
distance_metric=args.distance_metric,
|
||||||
|
)
|
||||||
|
if not ok_b:
|
||||||
|
raise RuntimeError("Failed to start embedding server for Scenario B")
|
||||||
|
|
||||||
|
# Wait for server to fully initialize
|
||||||
|
print("[DEBUG] Waiting 2s for embedding server to fully initialize...")
|
||||||
|
time.sleep(2)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Read the index first
|
||||||
|
index_no_update = _read_index_for_search(args.index_path) # unchanged index
|
||||||
|
|
||||||
|
# Then configure ZMQ port on the correct index object
|
||||||
|
if hasattr(index_no_update.hnsw, "set_zmq_port"):
|
||||||
|
index_no_update.hnsw.set_zmq_port(port_b)
|
||||||
|
elif hasattr(index_no_update, "set_zmq_port"):
|
||||||
|
index_no_update.set_zmq_port(port_b)
|
||||||
|
|
||||||
|
# Warmup the embedding model before benchmarking (do this for both --only B and --only both)
|
||||||
|
# This ensures fair comparison as Scenario A has warmed up the model during update embeddings
|
||||||
|
logger.info("Warming up embedding model for Scenario B...")
|
||||||
|
_ = compute_embeddings(
|
||||||
|
["warmup text"], args.model_name, mode=args.embedding_mode, is_build=False
|
||||||
|
)
|
||||||
|
|
||||||
|
# Prepare worker A: compute embeddings for the same N passages
|
||||||
|
emb_time = 0.0
|
||||||
|
updates_embs_offline: np.ndarray | None = None
|
||||||
|
|
||||||
|
def _worker_emb():
|
||||||
|
nonlocal emb_time, updates_embs_offline
|
||||||
|
t = time.time()
|
||||||
|
updates_embs_offline = compute_embeddings(
|
||||||
|
update_paragraphs,
|
||||||
|
args.model_name,
|
||||||
|
mode=args.embedding_mode,
|
||||||
|
is_build=False,
|
||||||
|
batch_size=16,
|
||||||
|
)
|
||||||
|
emb_time = time.time() - t
|
||||||
|
|
||||||
|
# Pre-compute query embedding and warm up search outside of timed section.
|
||||||
|
q_vec = compute_embeddings(
|
||||||
|
[args.query], args.model_name, mode=args.embedding_mode, is_build=False
|
||||||
|
)
|
||||||
|
q_vec = np.asarray(q_vec, dtype=np.float32)
|
||||||
|
q_vec = _maybe_norm_cosine(q_vec, args.distance_metric)
|
||||||
|
print("[DEBUG B] Warming up search...")
|
||||||
|
_ = _search(index_no_update, q_vec, 1)
|
||||||
|
|
||||||
|
# Worker B: timed search on the warmed index
|
||||||
|
search_time = 0.0
|
||||||
|
offline_elapsed = 0.0
|
||||||
|
index_results: tuple[np.ndarray, np.ndarray] | None = None
|
||||||
|
|
||||||
|
def _worker_search():
|
||||||
|
nonlocal search_time, index_results
|
||||||
|
t = time.time()
|
||||||
|
distances, indices = _search(index_no_update, q_vec, args.k)
|
||||||
|
search_time = time.time() - t
|
||||||
|
index_results = (distances, indices)
|
||||||
|
|
||||||
|
# Run two workers concurrently
|
||||||
|
t0 = time.time()
|
||||||
|
th1 = threading.Thread(target=_worker_emb)
|
||||||
|
th2 = threading.Thread(target=_worker_search)
|
||||||
|
th1.start()
|
||||||
|
th2.start()
|
||||||
|
th1.join()
|
||||||
|
th2.join()
|
||||||
|
offline_elapsed = time.time() - t0
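# The two workers overlap fully, so the wall-clock makespan recorded here is roughly
# max(emb_time, search_time) rather than their sum; this is the value Scenario B
# reports as "makespan_s" in the CSV.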
|
||||||
|
|
||||||
|
# For mixing: compute query vs. offline update similarities (pure client-side)
|
||||||
|
offline_scores: list[tuple[int, float]] = []
|
||||||
|
if updates_embs_offline is not None:
|
||||||
|
upd2 = np.asarray(updates_embs_offline, dtype=np.float32)
|
||||||
|
upd2 = _maybe_norm_cosine(upd2, args.distance_metric)
|
||||||
|
# For mips/cosine, score = dot; for l2, score = -||x-y||^2
|
||||||
|
for j in range(upd2.shape[0]):
|
||||||
|
if args.distance_metric in ("mips", "cosine"):
|
||||||
|
s = float(np.dot(q_vec[0], upd2[j]))
|
||||||
|
else:
|
||||||
|
diff = q_vec[0] - upd2[j]
|
||||||
|
s = -float(np.dot(diff, diff))
|
||||||
|
offline_scores.append((j, s))
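# Negating the squared L2 distance keeps "larger is better" consistent with the
# mips/cosine dot-product scores, so _merge_results can sort both candidate sources
# on a single scale.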
|
||||||
|
|
||||||
|
merged_topk = (
|
||||||
|
_merge_results(index_results, offline_scores, args.k, args.distance_metric)
|
||||||
|
if index_results
|
||||||
|
else []
|
||||||
|
)
|
||||||
|
|
||||||
|
print("\n=== Scenario B: offline embeds + concurrent search (no add) ===")
|
||||||
|
print(
|
||||||
|
f"embeddings({args.num_updates})={emb_time:.3f}s; search={search_time:.3f}s; makespan≈{offline_elapsed:.3f}s (≈max)"
|
||||||
|
)
|
||||||
|
if merged_topk:
|
||||||
|
preview = ", ".join([f"{lab}:{score:.3f}" for lab, score in merged_topk[:5]])
|
||||||
|
print(f"Merged top-5 preview: {preview}")
|
||||||
|
# CSV row for B
|
||||||
|
if args.csv_path:
|
||||||
|
row_b = {
|
||||||
|
"run_id": run_id,
|
||||||
|
"scenario": "B",
|
||||||
|
"max_initial": args.max_initial,
|
||||||
|
"num_updates": args.num_updates,
|
||||||
|
"k": args.k,
|
||||||
|
"total_time_s": 0.0,
|
||||||
|
"add_total_s": 0.0,
|
||||||
|
"search_time_s": round(search_time, 6),
|
||||||
|
"emb_time_s": round(emb_time, 6),
|
||||||
|
"makespan_s": round(offline_elapsed, 6),
|
||||||
|
"model_name": args.model_name,
|
||||||
|
"embedding_mode": args.embedding_mode,
|
||||||
|
"distance_metric": args.distance_metric,
|
||||||
|
}
|
||||||
|
with args.csv_path.open("a", newline="", encoding="utf-8") as f:
|
||||||
|
writer = csv.DictWriter(f, fieldnames=csv_fields)
|
||||||
|
writer.writerow(row_b)
|
||||||
|
|
||||||
|
finally:
|
||||||
|
server_manager_b.stop_server()
|
||||||
|
|
||||||
|
# Summary
|
||||||
|
print("\n=== Summary ===")
|
||||||
|
msg_a = (
|
||||||
|
f"A: seq-add+search overall={total_seq:.3f}s (adds={add_total:.3f}s, search={search_after_add:.3f}s)"
|
||||||
|
if args.only in ("A", "both")
|
||||||
|
else "A: skipped"
|
||||||
|
)
|
||||||
|
msg_b = (
|
||||||
|
f"B: offline+concurrent overall≈{offline_elapsed:.3f}s (emb={emb_time:.3f}s, search={search_time:.3f}s)"
|
||||||
|
if args.only in ("B", "both")
|
||||||
|
else "B: skipped"
|
||||||
|
)
|
||||||
|
print(msg_a + "\n" + msg_b)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
benchmarks/update/offline_vs_update.csv (new file, 5 lines)
@@ -0,0 +1,5 @@
run_id,scenario,max_initial,num_updates,k,total_time_s,add_total_s,search_time_s,emb_time_s,makespan_s,model_name,embedding_mode,distance_metric
20251024-141607,A,300,1,10,3.273957,3.050168,0.097825,0.017339,0.0,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips
20251024-141607,B,300,1,10,0.0,0.0,0.111892,0.007869,0.112635,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips
20251025-160652,A,300,5,10,5.061945,4.805962,0.123271,0.015008,0.0,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips
20251025-160652,B,300,5,10,0.0,0.0,0.101809,0.008817,0.102447,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips
|
benchmarks/update/plot_bench_results.py (new file, 645 lines)
@@ -0,0 +1,645 @@
#!/usr/bin/env python3
"""
Plot latency bars from the benchmark CSV produced by
benchmarks/update/bench_hnsw_rng_recompute.py.

If you also provide an offline_vs_update.csv via --csv-right
(from benchmarks/update/bench_update_vs_offline_search.py), this script will
output a side-by-side figure:
- Left: ms/passage bars (four RNG scenarios).
- Right: seconds bars (Scenario A seq add+search vs Scenario B offline+search).

Usage:
    uv run python benchmarks/update/plot_bench_results.py \
        --csv benchmarks/update/bench_results.csv \
        --out benchmarks/update/bench_latency_from_csv.png

The script selects the latest run_id in the CSV and plots four bars for
the default scenarios:
- baseline
- no_cache_baseline
- disable_forward_rng
- disable_forward_and_reverse_rng

If multiple rows exist per scenario for that run_id, the script averages
their latency_ms_per_passage values.
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import csv
|
||||||
|
from collections import defaultdict
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
DEFAULT_SCENARIOS = [
|
||||||
|
"no_cache_baseline",
|
||||||
|
"baseline",
|
||||||
|
"disable_forward_rng",
|
||||||
|
"disable_forward_and_reverse_rng",
|
||||||
|
]
|
||||||
|
|
||||||
|
SCENARIO_LABELS = {
|
||||||
|
"baseline": "+ Cache",
|
||||||
|
"no_cache_baseline": "Naive \n Recompute",
|
||||||
|
"disable_forward_rng": "+ w/o \n Fwd RNG",
|
||||||
|
"disable_forward_and_reverse_rng": "+ w/o \n Bwd RNG",
|
||||||
|
}
|
||||||
|
|
||||||
|
# Paper-style colors and hatches for scenarios
|
||||||
|
SCENARIO_STYLES = {
|
||||||
|
"no_cache_baseline": {"edgecolor": "dimgrey", "hatch": "/////"},
|
||||||
|
"baseline": {"edgecolor": "#63B8B6", "hatch": "xxxxx"},
|
||||||
|
"disable_forward_rng": {"edgecolor": "green", "hatch": "....."},
|
||||||
|
"disable_forward_and_reverse_rng": {"edgecolor": "tomato", "hatch": "\\\\\\\\\\"},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def load_latest_run(csv_path: Path):
|
||||||
|
rows = []
|
||||||
|
with csv_path.open("r", encoding="utf-8") as f:
|
||||||
|
reader = csv.DictReader(f)
|
||||||
|
for row in reader:
|
||||||
|
rows.append(row)
|
||||||
|
if not rows:
|
||||||
|
raise SystemExit("CSV is empty: no rows to plot")
|
||||||
|
# Choose latest run_id lexicographically (YYYYMMDD-HHMMSS)
|
||||||
|
run_ids = [r.get("run_id", "") for r in rows]
|
||||||
|
latest = max(run_ids)
|
||||||
|
latest_rows = [r for r in rows if r.get("run_id", "") == latest]
|
||||||
|
if not latest_rows:
|
||||||
|
# Fallback: take last 4 rows
|
||||||
|
latest_rows = rows[-4:]
|
||||||
|
latest = latest_rows[-1].get("run_id", "unknown")
|
||||||
|
return latest, latest_rows
|
||||||
|
|
||||||
|
|
||||||
|
def aggregate_latency(rows):
|
||||||
|
acc = defaultdict(list)
|
||||||
|
for r in rows:
|
||||||
|
sc = r.get("scenario", "")
|
||||||
|
try:
|
||||||
|
val = float(r.get("latency_ms_per_passage", "nan"))
|
||||||
|
except ValueError:
|
||||||
|
continue
|
||||||
|
acc[sc].append(val)
|
||||||
|
avg = {k: (sum(v) / len(v) if v else 0.0) for k, v in acc.items()}
|
||||||
|
return avg
|
||||||
|
|
||||||
|
|
||||||
|
def _auto_cap(values: list[float]) -> float | None:
    if not values:
        return None
    sorted_vals = sorted(values, reverse=True)
    if len(sorted_vals) < 2:
        return None
    max_v, second = sorted_vals[0], sorted_vals[1]
    if second <= 0:
        return None
    # If the tallest bar dwarfs the second by 2.5x+, cap near the second
    if max_v >= 2.5 * second:
        return second * 1.1
    return None
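# Illustrative: _auto_cap([900.0, 40.0, 35.0, 30.0]) returns about 44.0 (1.1 x 40.0),
# so the 900 ms outlier bar gets capped, while _auto_cap([90.0, 40.0]) returns None
# because 90 < 2.5 * 40.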
|
||||||
|
|
||||||
|
|
||||||
|
def _add_break_marker(ax, y, rel_x0=0.02, rel_x1=0.98, size=0.02):
    # Draw small diagonal ticks near left/right to signal cap
    x0, x1 = rel_x0, rel_x1
    ax.plot([x0 - size, x0 + size], [y + size, y - size], transform=ax.transAxes, color="k", lw=1)
    ax.plot([x1 - size, x1 + size], [y + size, y - size], transform=ax.transAxes, color="k", lw=1)


def _fmt_ms(v: float) -> str:
    if v >= 1000:
        return f"{v / 1000:.1f}k"
    return f"{v:.1f}"
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
# Set LaTeX style for paper figures (matching paper_fig.py)
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
plt.rcParams["font.family"] = "Helvetica"
|
||||||
|
plt.rcParams["ytick.direction"] = "in"
|
||||||
|
plt.rcParams["hatch.linewidth"] = 1.5
|
||||||
|
plt.rcParams["font.weight"] = "bold"
|
||||||
|
plt.rcParams["axes.labelweight"] = "bold"
|
||||||
|
plt.rcParams["text.usetex"] = True
|
||||||
|
|
||||||
|
ap = argparse.ArgumentParser(description=__doc__)
|
||||||
|
ap.add_argument(
|
||||||
|
"--csv",
|
||||||
|
type=Path,
|
||||||
|
default=Path("benchmarks/update/bench_results.csv"),
|
||||||
|
help="Path to results CSV (defaults to bench_results.csv)",
|
||||||
|
)
|
||||||
|
ap.add_argument(
|
||||||
|
"--out",
|
||||||
|
type=Path,
|
||||||
|
default=Path("add_ablation.pdf"),
|
||||||
|
help="Output image path",
|
||||||
|
)
|
||||||
|
ap.add_argument(
|
||||||
|
"--csv-right",
|
||||||
|
type=Path,
|
||||||
|
default=Path("benchmarks/update/offline_vs_update.csv"),
|
||||||
|
help="Optional: offline_vs_update.csv to render right subplot (A vs B)",
|
||||||
|
)
|
||||||
|
ap.add_argument(
|
||||||
|
"--cap-y",
|
||||||
|
type=float,
|
||||||
|
default=None,
|
||||||
|
help="Cap Y-axis at this ms value; bars above are hatched and annotated.",
|
||||||
|
)
|
||||||
|
ap.add_argument(
|
||||||
|
"--no-auto-cap",
|
||||||
|
action="store_true",
|
||||||
|
help="Disable auto-cap heuristic when --cap-y is not provided.",
|
||||||
|
)
|
||||||
|
ap.add_argument(
|
||||||
|
"--broken-y",
|
||||||
|
action="store_true",
|
||||||
|
default=True,
|
||||||
|
help="Use a broken Y-axis (two stacked axes with a gap). Overrides --cap-y unless both provided.",
|
||||||
|
)
|
||||||
|
ap.add_argument(
|
||||||
|
"--lower-cap-y",
|
||||||
|
type=float,
|
||||||
|
default=None,
|
||||||
|
help="Lower axes upper bound for broken Y (ms). Default = 1.1x second-highest.",
|
||||||
|
)
|
||||||
|
ap.add_argument(
|
||||||
|
"--upper-start-y",
|
||||||
|
type=float,
|
||||||
|
default=None,
|
||||||
|
help="Upper axes lower bound for broken Y (ms). Default = 1.2x second-highest.",
|
||||||
|
)
|
||||||
|
args = ap.parse_args()
|
||||||
|
|
||||||
|
latest_run, latest_rows = load_latest_run(args.csv)
|
||||||
|
avg = aggregate_latency(latest_rows)
|
||||||
|
|
||||||
|
try:
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
except Exception as e:
|
||||||
|
raise SystemExit(f"matplotlib not available: {e}")
|
||||||
|
|
||||||
|
scenarios = DEFAULT_SCENARIOS
|
||||||
|
values = [avg.get(name, 0.0) for name in scenarios]
|
||||||
|
labels = [SCENARIO_LABELS.get(name, name) for name in scenarios]
|
||||||
|
colors = ["#4e79a7", "#f28e2c", "#e15759", "#76b7b2"]
|
||||||
|
|
||||||
|
# If right CSV is provided, build side-by-side figure
|
||||||
|
if args.csv_right is not None:
|
||||||
|
try:
|
||||||
|
right_rows_all = []
|
||||||
|
with args.csv_right.open("r", encoding="utf-8") as f:
|
||||||
|
rreader = csv.DictReader(f)
|
||||||
|
right_rows_all = list(rreader)
|
||||||
|
if right_rows_all:
|
||||||
|
r_latest = max(r.get("run_id", "") for r in right_rows_all)
|
||||||
|
right_rows = [r for r in right_rows_all if r.get("run_id", "") == r_latest]
|
||||||
|
else:
|
||||||
|
r_latest = None
|
||||||
|
right_rows = []
|
||||||
|
except Exception:
|
||||||
|
r_latest = None
|
||||||
|
right_rows = []
|
||||||
|
|
||||||
|
a_total = 0.0
|
||||||
|
b_makespan = 0.0
|
||||||
|
for r in right_rows:
|
||||||
|
sc = (r.get("scenario", "") or "").strip().upper()
|
||||||
|
if sc == "A":
|
||||||
|
try:
|
||||||
|
a_total = float(r.get("total_time_s", 0.0))
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
elif sc == "B":
|
||||||
|
try:
|
||||||
|
b_makespan = float(r.get("makespan_s", 0.0))
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
from matplotlib import gridspec
|
||||||
|
|
||||||
|
# Left subplot (reuse current style, with optional cap)
|
||||||
|
cap = args.cap_y
|
||||||
|
if cap is None and not args.no_auto_cap:
|
||||||
|
cap = _auto_cap(values)
|
||||||
|
x = list(range(len(labels)))
|
||||||
|
|
||||||
|
if args.broken_y:
|
||||||
|
# Use broken axis for left subplot
|
||||||
|
# Auto-adjust width ratios: left has 4 bars, right has 2 bars
|
||||||
|
fig = plt.figure(figsize=(4.8, 1.8)) # Scaled down to 80%
|
||||||
|
gs = gridspec.GridSpec(
|
||||||
|
2, 2, height_ratios=[1, 3], width_ratios=[1.5, 1], hspace=0.08, wspace=0.35
|
||||||
|
)
|
||||||
|
ax_left_top = fig.add_subplot(gs[0, 0])
|
||||||
|
ax_left_bottom = fig.add_subplot(gs[1, 0], sharex=ax_left_top)
|
||||||
|
ax_right = fig.add_subplot(gs[:, 1])
|
||||||
|
|
||||||
|
# Determine break points
|
||||||
|
s = sorted(values, reverse=True)
|
||||||
|
second = s[1] if len(s) >= 2 else (s[0] if s else 0.0)
|
||||||
|
lower_cap = (
|
||||||
|
args.lower_cap_y if args.lower_cap_y is not None else second * 1.4
|
||||||
|
) # Increased to show more range
|
||||||
|
upper_start = (
|
||||||
|
args.upper_start_y
|
||||||
|
if args.upper_start_y is not None
|
||||||
|
else max(second * 1.5, lower_cap * 1.02)
|
||||||
|
)
|
||||||
|
ymax = (
|
||||||
|
max(values) * 1.90 if values else 1.0
|
||||||
|
) # Increase headroom to 1.90 for text label and tick range
|
||||||
|
|
||||||
|
# Draw bars on both axes
|
||||||
|
ax_left_bottom.bar(x, values, color=colors[: len(labels)], width=0.8)
|
||||||
|
ax_left_top.bar(x, values, color=colors[: len(labels)], width=0.8)
|
||||||
|
|
||||||
|
# Set limits
|
||||||
|
ax_left_bottom.set_ylim(0, lower_cap)
|
||||||
|
ax_left_top.set_ylim(upper_start, ymax)
|
||||||
|
|
||||||
|
# Annotate values (convert ms to s)
|
||||||
|
values_s = [v / 1000.0 for v in values]
|
||||||
|
lower_cap_s = lower_cap / 1000.0
|
||||||
|
upper_start_s = upper_start / 1000.0
|
||||||
|
ymax_s = ymax / 1000.0
|
||||||
|
|
||||||
|
ax_left_bottom.set_ylim(0, lower_cap_s)
|
||||||
|
ax_left_top.set_ylim(upper_start_s, ymax_s)
|
||||||
|
|
||||||
|
# Redraw bars with s values (paper style: white fill + colored edge + hatch)
|
||||||
|
ax_left_bottom.clear()
|
||||||
|
ax_left_top.clear()
|
||||||
|
bar_width = 0.50 # Reduced for wider spacing between bars
|
||||||
|
for i, (scenario_name, v) in enumerate(zip(scenarios, values_s)):
|
||||||
|
style = SCENARIO_STYLES.get(scenario_name, {"edgecolor": "black", "hatch": ""})
|
||||||
|
# Draw in bottom axis for all bars
|
||||||
|
ax_left_bottom.bar(
|
||||||
|
i,
|
||||||
|
v,
|
||||||
|
width=bar_width,
|
||||||
|
color="white",
|
||||||
|
edgecolor=style["edgecolor"],
|
||||||
|
hatch=style["hatch"],
|
||||||
|
linewidth=1.2,
|
||||||
|
)
|
||||||
|
# Only draw in top axis if the bar is tall enough to reach the upper range
|
||||||
|
if v > upper_start_s:
|
||||||
|
ax_left_top.bar(
|
||||||
|
i,
|
||||||
|
v,
|
||||||
|
width=bar_width,
|
||||||
|
color="white",
|
||||||
|
edgecolor=style["edgecolor"],
|
||||||
|
hatch=style["hatch"],
|
||||||
|
linewidth=1.2,
|
||||||
|
)
|
||||||
|
ax_left_bottom.set_ylim(0, lower_cap_s)
|
||||||
|
ax_left_top.set_ylim(upper_start_s, ymax_s)
|
||||||
|
|
||||||
|
for i, v in enumerate(values_s):
|
||||||
|
if v <= lower_cap_s:
|
||||||
|
ax_left_bottom.text(
|
||||||
|
i,
|
||||||
|
v + lower_cap_s * 0.02,
|
||||||
|
f"{v:.2f}",
|
||||||
|
ha="center",
|
||||||
|
va="bottom",
|
||||||
|
fontsize=8,
|
||||||
|
fontweight="bold",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
ax_left_top.text(
|
||||||
|
i,
|
||||||
|
v + (ymax_s - upper_start_s) * 0.02,
|
||||||
|
f"{v:.2f}",
|
||||||
|
ha="center",
|
||||||
|
va="bottom",
|
||||||
|
fontsize=8,
|
||||||
|
fontweight="bold",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Hide spines between axes
|
||||||
|
ax_left_top.spines["bottom"].set_visible(False)
|
||||||
|
ax_left_bottom.spines["top"].set_visible(False)
|
||||||
|
ax_left_top.tick_params(
|
||||||
|
labeltop=False, labelbottom=False, bottom=False
|
||||||
|
) # Hide tick marks
|
||||||
|
ax_left_bottom.xaxis.tick_bottom()
|
||||||
|
ax_left_bottom.tick_params(top=False) # Hide top tick marks
|
||||||
|
|
||||||
|
# Draw break marks (matching paper_fig.py style)
|
||||||
|
d = 0.015
|
||||||
|
kwargs = {
|
||||||
|
"transform": ax_left_top.transAxes,
|
||||||
|
"color": "k",
|
||||||
|
"clip_on": False,
|
||||||
|
"linewidth": 0.8,
|
||||||
|
"zorder": 10,
|
||||||
|
}
|
||||||
|
ax_left_top.plot((-d, +d), (-d, +d), **kwargs)
|
||||||
|
ax_left_top.plot((1 - d, 1 + d), (-d, +d), **kwargs)
|
||||||
|
kwargs.update({"transform": ax_left_bottom.transAxes})
|
||||||
|
ax_left_bottom.plot((-d, +d), (1 - d, 1 + d), **kwargs)
|
||||||
|
ax_left_bottom.plot((1 - d, 1 + d), (1 - d, 1 + d), **kwargs)
|
||||||
|
|
||||||
|
ax_left_bottom.set_xticks(x)
|
||||||
|
ax_left_bottom.set_xticklabels(labels, rotation=0, fontsize=7)
|
||||||
|
# Don't set ylabel here - will use fig.text for alignment
|
||||||
|
ax_left_bottom.tick_params(axis="y", labelsize=10)
|
||||||
|
ax_left_top.tick_params(axis="y", labelsize=10)
|
||||||
|
# Add subtle grid for better readability
|
||||||
|
ax_left_bottom.grid(axis="y", alpha=0.3, linestyle="--", linewidth=0.5)
|
||||||
|
ax_left_top.grid(axis="y", alpha=0.3, linestyle="--", linewidth=0.5)
|
||||||
|
ax_left_top.set_title("Single Add Operation", fontsize=11, pad=10, fontweight="bold")
|
||||||
|
|
||||||
|
# Set x-axis limits to match bar width with right subplot
|
||||||
|
ax_left_bottom.set_xlim(-0.6, 3.6)
|
||||||
|
ax_left_top.set_xlim(-0.6, 3.6)
|
||||||
|
|
||||||
|
ax_left = ax_left_bottom # for compatibility
|
||||||
|
else:
|
||||||
|
# Regular side-by-side layout
|
||||||
|
fig, (ax_left, ax_right) = plt.subplots(1, 2, figsize=(8.4, 3.15))
|
||||||
|
|
||||||
|
if cap is not None:
|
||||||
|
show_vals = [min(v, cap) for v in values]
|
||||||
|
bars = ax_left.bar(x, show_vals, color=colors[: len(labels)], width=0.8)
|
||||||
|
for i, (val, show) in enumerate(zip(values, show_vals)):
|
||||||
|
if val > cap:
|
||||||
|
bars[i].set_hatch("//")
|
||||||
|
ax_left.text(
|
||||||
|
i, cap * 1.02, _fmt_ms(val), ha="center", va="bottom", fontsize=9
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
ax_left.text(
|
||||||
|
i,
|
||||||
|
show + max(1.0, 0.01 * (cap or show)),
|
||||||
|
_fmt_ms(val),
|
||||||
|
ha="center",
|
||||||
|
va="bottom",
|
||||||
|
fontsize=9,
|
||||||
|
)
|
||||||
|
ax_left.set_ylim(0, cap * 1.10)
|
||||||
|
_add_break_marker(ax_left, y=0.98)
|
||||||
|
ax_left.set_xticks(x)
|
||||||
|
ax_left.set_xticklabels(labels, rotation=0, fontsize=10)
|
||||||
|
else:
|
||||||
|
ax_left.bar(x, values, color=colors[: len(labels)], width=0.8)
|
||||||
|
for i, v in enumerate(values):
|
||||||
|
ax_left.text(i, v + 1.0, _fmt_ms(v), ha="center", va="bottom", fontsize=9)
|
||||||
|
ax_left.set_xticks(x)
|
||||||
|
ax_left.set_xticklabels(labels, rotation=0, fontsize=10)
|
||||||
|
ax_left.set_ylabel("Latency (ms per passage)")
|
||||||
|
max_initial = latest_rows[0].get("max_initial", "?")
|
||||||
|
max_updates = latest_rows[0].get("max_updates", "?")
|
||||||
|
ax_left.set_title(
|
||||||
|
f"HNSW RNG (run {latest_run}) | init={max_initial}, upd={max_updates}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Right subplot (A vs B, seconds) - paper style
|
||||||
|
r_labels = ["Sequential", "Delayed \n Add+Search"]
|
||||||
|
r_values = [a_total or 0.0, b_makespan or 0.0]
|
||||||
|
r_styles = [
|
||||||
|
{"edgecolor": "#59a14f", "hatch": "xxxxx"},
|
||||||
|
{"edgecolor": "#edc948", "hatch": "/////"},
|
||||||
|
]
|
||||||
|
# 2 bars, centered with proper spacing
|
||||||
|
xr = [0, 1]
|
||||||
|
bar_width = 0.50 # Reduced for wider spacing between bars
|
||||||
|
for i, (v, style) in enumerate(zip(r_values, r_styles)):
|
||||||
|
ax_right.bar(
|
||||||
|
xr[i],
|
||||||
|
v,
|
||||||
|
width=bar_width,
|
||||||
|
color="white",
|
||||||
|
edgecolor=style["edgecolor"],
|
||||||
|
hatch=style["hatch"],
|
||||||
|
linewidth=1.2,
|
||||||
|
)
|
||||||
|
for i, v in enumerate(r_values):
|
||||||
|
max_v = max(r_values) if r_values else 1.0
|
||||||
|
offset = max(0.0002, 0.02 * max_v)
|
||||||
|
ax_right.text(
|
||||||
|
xr[i],
|
||||||
|
v + offset,
|
||||||
|
f"{v:.2f}",
|
||||||
|
ha="center",
|
||||||
|
va="bottom",
|
||||||
|
fontsize=8,
|
||||||
|
fontweight="bold",
|
||||||
|
)
|
||||||
|
ax_right.set_xticks(xr)
|
||||||
|
ax_right.set_xticklabels(r_labels, rotation=0, fontsize=7)
|
||||||
|
# Don't set ylabel here - will use fig.text for alignment
|
||||||
|
ax_right.tick_params(axis="y", labelsize=10)
|
||||||
|
# Add subtle grid for better readability
|
||||||
|
ax_right.grid(axis="y", alpha=0.3, linestyle="--", linewidth=0.5)
|
||||||
|
ax_right.set_title("Batched Add Operation", fontsize=11, pad=10, fontweight="bold")
|
||||||
|
|
||||||
|
# Set x-axis limits to match left subplot's bar width visually
|
||||||
|
# Accounting for width_ratios=[1.5, 1]:
|
||||||
|
# Left: 4 bars, xlim(-0.6, 3.6), range=4.2, physical_width=1.5*unit
|
||||||
|
# bar_width_visual = 0.72 * (1.5*unit / 4.2)
|
||||||
|
# Right: 2 bars, need same visual width
|
||||||
|
# 0.72 * (1.0*unit / range_right) = 0.72 * (1.5*unit / 4.2)
|
||||||
|
# range_right = 4.2 / 1.5 = 2.8
|
||||||
|
# For bars at 0, 1: padding = (2.8 - 1) / 2 = 0.9
|
||||||
|
ax_right.set_xlim(-0.9, 1.9)
|
||||||
|
|
||||||
|
# Set y-axis limit with headroom for text labels
|
||||||
|
if r_values:
|
||||||
|
max_v = max(r_values)
|
||||||
|
ax_right.set_ylim(0, max_v * 1.15)
|
||||||
|
|
||||||
|
# Format y-axis to avoid scientific notation
|
||||||
|
ax_right.ticklabel_format(style="plain", axis="y")
|
||||||
|
|
||||||
|
plt.tight_layout()
|
||||||
|
|
||||||
|
# Add aligned ylabels using fig.text (after tight_layout)
|
||||||
|
# Get the vertical center of the entire figure
|
||||||
|
fig_center_y = 0.5
|
||||||
|
# Left ylabel - closer to left plot
|
||||||
|
left_x = 0.05
|
||||||
|
fig.text(
|
||||||
|
left_x,
|
||||||
|
fig_center_y,
|
||||||
|
"Latency (s)",
|
||||||
|
va="center",
|
||||||
|
rotation="vertical",
|
||||||
|
fontsize=11,
|
||||||
|
fontweight="bold",
|
||||||
|
)
|
||||||
|
# Right ylabel - closer to right plot
|
||||||
|
right_bbox = ax_right.get_position()
|
||||||
|
right_x = right_bbox.x0 - 0.07
|
||||||
|
fig.text(
|
||||||
|
right_x,
|
||||||
|
fig_center_y,
|
||||||
|
"Latency (s)",
|
||||||
|
va="center",
|
||||||
|
rotation="vertical",
|
||||||
|
fontsize=11,
|
||||||
|
fontweight="bold",
|
||||||
|
)
|
||||||
|
|
||||||
|
plt.savefig(args.out, bbox_inches="tight", pad_inches=0.05)
|
||||||
|
# Also save PDF for paper
|
||||||
|
pdf_out = args.out.with_suffix(".pdf")
|
||||||
|
plt.savefig(pdf_out, bbox_inches="tight", pad_inches=0.05)
|
||||||
|
print(f"Saved: {args.out}")
|
||||||
|
print(f"Saved: {pdf_out}")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Broken-Y mode
|
||||||
|
if args.broken_y:
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
fig, (ax_top, ax_bottom) = plt.subplots(
|
||||||
|
2,
|
||||||
|
1,
|
||||||
|
sharex=True,
|
||||||
|
figsize=(7.5, 6.75),
|
||||||
|
gridspec_kw={"height_ratios": [1, 3], "hspace": 0.08},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Determine default breaks from second-highest
|
||||||
|
s = sorted(values, reverse=True)
|
||||||
|
second = s[1] if len(s) >= 2 else (s[0] if s else 0.0)
|
||||||
|
lower_cap = args.lower_cap_y if args.lower_cap_y is not None else second * 1.1
|
||||||
|
upper_start = (
|
||||||
|
args.upper_start_y
|
||||||
|
if args.upper_start_y is not None
|
||||||
|
else max(second * 1.2, lower_cap * 1.02)
|
||||||
|
)
|
||||||
|
ymax = max(values) * 1.10 if values else 1.0
|
||||||
|
|
||||||
|
x = list(range(len(labels)))
|
||||||
|
ax_bottom.bar(x, values, color=colors[: len(labels)], width=0.8)
|
||||||
|
ax_top.bar(x, values, color=colors[: len(labels)], width=0.8)
|
||||||
|
|
||||||
|
# Limits
|
||||||
|
ax_bottom.set_ylim(0, lower_cap)
|
||||||
|
ax_top.set_ylim(upper_start, ymax)
|
||||||
|
|
||||||
|
# Annotate values
|
||||||
|
for i, v in enumerate(values):
|
||||||
|
if v <= lower_cap:
|
||||||
|
ax_bottom.text(
|
||||||
|
i, v + lower_cap * 0.02, _fmt_ms(v), ha="center", va="bottom", fontsize=9
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
ax_top.text(i, v, _fmt_ms(v), ha="center", va="bottom", fontsize=9)
|
||||||
|
|
||||||
|
# Hide spines between axes and draw diagonal break marks
|
||||||
|
ax_top.spines["bottom"].set_visible(False)
|
||||||
|
ax_bottom.spines["top"].set_visible(False)
|
||||||
|
ax_top.tick_params(labeltop=False) # don't put tick labels at the top
|
||||||
|
ax_bottom.xaxis.tick_bottom()
|
||||||
|
|
||||||
|
# Diagonal lines at the break (matching paper_fig.py style)
|
||||||
|
d = 0.015
|
||||||
|
kwargs = {
|
||||||
|
"transform": ax_top.transAxes,
|
||||||
|
"color": "k",
|
||||||
|
"clip_on": False,
|
||||||
|
"linewidth": 0.8,
|
||||||
|
"zorder": 10,
|
||||||
|
}
|
||||||
|
ax_top.plot((-d, +d), (-d, +d), **kwargs) # top-left diagonal
|
||||||
|
ax_top.plot((1 - d, 1 + d), (-d, +d), **kwargs) # top-right diagonal
|
||||||
|
kwargs.update({"transform": ax_bottom.transAxes})
|
||||||
|
ax_bottom.plot((-d, +d), (1 - d, 1 + d), **kwargs) # bottom-left diagonal
|
||||||
|
ax_bottom.plot((1 - d, 1 + d), (1 - d, 1 + d), **kwargs) # bottom-right diagonal
|
||||||
|
|
||||||
|
ax_bottom.set_xticks(x)
|
||||||
|
ax_bottom.set_xticklabels(labels, rotation=0, fontsize=10)
|
||||||
|
ax = ax_bottom # for labeling below
|
||||||
|
else:
|
||||||
|
cap = args.cap_y
|
||||||
|
if cap is None and not args.no_auto_cap:
|
||||||
|
cap = _auto_cap(values)
|
||||||
|
|
||||||
|
plt.figure(figsize=(5.4, 3.15))
|
||||||
|
ax = plt.gca()
|
||||||
|
|
||||||
|
if cap is not None:
|
||||||
|
show_vals = [min(v, cap) for v in values]
|
||||||
|
bars = []
|
||||||
|
for i, (_label, val, show) in enumerate(zip(labels, values, show_vals)):
|
||||||
|
bar = ax.bar(i, show, color=colors[i], width=0.8)
|
||||||
|
bars.append(bar[0])
|
||||||
|
# Hatch and annotate when capped
|
||||||
|
if val > cap:
|
||||||
|
bars[-1].set_hatch("//")
|
||||||
|
ax.text(i, cap * 1.02, f"{_fmt_ms(val)}", ha="center", va="bottom", fontsize=9)
|
||||||
|
else:
|
||||||
|
ax.text(
|
||||||
|
i,
|
||||||
|
show + max(1.0, 0.01 * (cap or show)),
|
||||||
|
f"{_fmt_ms(val)}",
|
||||||
|
ha="center",
|
||||||
|
va="bottom",
|
||||||
|
fontsize=9,
|
||||||
|
)
|
||||||
|
ax.set_ylim(0, cap * 1.10)
|
||||||
|
_add_break_marker(ax, y=0.98)
|
||||||
|
ax.legend([bars[1]], ["capped"], fontsize=8, frameon=False, loc="upper right") if any(
|
||||||
|
v > cap for v in values
|
||||||
|
) else None
|
||||||
|
ax.set_xticks(range(len(labels)))
|
||||||
|
ax.set_xticklabels(labels, fontsize=11, fontweight="bold")
|
||||||
|
else:
|
||||||
|
ax.bar(labels, values, color=colors[: len(labels)])
|
||||||
|
for idx, val in enumerate(values):
|
||||||
|
ax.text(
|
||||||
|
idx,
|
||||||
|
val + 1.0,
|
||||||
|
f"{_fmt_ms(val)}",
|
||||||
|
ha="center",
|
||||||
|
va="bottom",
|
||||||
|
fontsize=10,
|
||||||
|
fontweight="bold",
|
||||||
|
)
|
||||||
|
ax.set_xticklabels(labels, fontsize=11, fontweight="bold")
|
||||||
|
# Try to extract some context for title
|
||||||
|
max_initial = latest_rows[0].get("max_initial", "?")
|
||||||
|
max_updates = latest_rows[0].get("max_updates", "?")
|
||||||
|
|
||||||
|
if args.broken_y:
|
||||||
|
fig.text(
|
||||||
|
0.02,
|
||||||
|
0.5,
|
||||||
|
"Latency (s)",
|
||||||
|
va="center",
|
||||||
|
rotation="vertical",
|
||||||
|
fontsize=11,
|
||||||
|
fontweight="bold",
|
||||||
|
)
|
||||||
|
fig.suptitle(
|
||||||
|
"Add Operation Latency",
|
||||||
|
fontsize=11,
|
||||||
|
y=0.98,
|
||||||
|
fontweight="bold",
|
||||||
|
)
|
||||||
|
plt.tight_layout(rect=(0.03, 0.04, 1, 0.96))
|
||||||
|
else:
|
||||||
|
plt.ylabel("Latency (s)", fontsize=11, fontweight="bold")
|
||||||
|
plt.title("Add Operation Latency", fontsize=11, fontweight="bold")
|
||||||
|
plt.tight_layout()
|
||||||
|
|
||||||
|
plt.savefig(args.out, bbox_inches="tight", pad_inches=0.05)
|
||||||
|
# Also save PDF for paper
|
||||||
|
pdf_out = args.out.with_suffix(".pdf")
|
||||||
|
plt.savefig(pdf_out, bbox_inches="tight", pad_inches=0.05)
|
||||||
|
print(f"Saved: {args.out}")
|
||||||
|
print(f"Saved: {pdf_out}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -215,6 +215,8 @@ class HNSWSearcher(BaseSearcher):
         if recompute_embeddings:
             if zmq_port is None:
                 raise ValueError("zmq_port must be provided if recompute_embeddings is True")
+            if hasattr(self._index, "set_zmq_port"):
+                self._index.set_zmq_port(zmq_port)

         if query.dtype != np.float32:
             query = query.astype(np.float32)
|
||||||
|
|||||||
@@ -143,8 +143,6 @@ def create_hnsw_embedding_server(
                 pass
         return str(nid)

-    # (legacy ZMQ thread removed; using shutdown-capable server only)
-
     def zmq_server_thread_with_shutdown(shutdown_event):
         """ZMQ server thread that respects shutdown signal.
|
||||||
|
|
||||||
@@ -158,225 +156,238 @@ def create_hnsw_embedding_server(
|
|||||||
rep_socket.bind(f"tcp://*:{zmq_port}")
|
rep_socket.bind(f"tcp://*:{zmq_port}")
|
||||||
logger.info(f"HNSW ZMQ REP server listening on port {zmq_port}")
|
logger.info(f"HNSW ZMQ REP server listening on port {zmq_port}")
|
||||||
rep_socket.setsockopt(zmq.RCVTIMEO, 1000)
|
rep_socket.setsockopt(zmq.RCVTIMEO, 1000)
|
||||||
# Keep sends from blocking during shutdown; fail fast and drop on close
|
|
||||||
rep_socket.setsockopt(zmq.SNDTIMEO, 1000)
|
rep_socket.setsockopt(zmq.SNDTIMEO, 1000)
|
||||||
rep_socket.setsockopt(zmq.LINGER, 0)
|
rep_socket.setsockopt(zmq.LINGER, 0)
|
||||||
|
|
||||||
# Track last request type/length for shape-correct fallbacks
|
last_request_type = "unknown"
|
||||||
last_request_type = "unknown" # 'text' | 'distance' | 'embedding' | 'unknown'
|
|
||||||
last_request_length = 0
|
last_request_length = 0
|
||||||
|
|
||||||
|
        def _build_safe_fallback():
            if last_request_type == "distance":
                large_distance = 1e9
                fallback_len = max(0, int(last_request_length))
                return [[large_distance] * fallback_len]
            if last_request_type == "embedding":
                bsz = max(0, int(last_request_length))
                dim = max(0, int(embedding_dim))
                if dim > 0:
                    return [[bsz, dim], [0.0] * (bsz * dim)]
                return [[0, 0], []]
            if last_request_type == "text":
                return []
            return [[0, int(embedding_dim) if embedding_dim > 0 else 0], []]
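        # Shape reference for these fallbacks (matching the handlers below): "distance"
        # returns [[1e9, ..., 1e9]] with one sentinel per requested id, "embedding" returns
        # [[batch, dim], flat_zeros], and "text" returns an empty list of embeddings.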
|
||||||
|
|
||||||
|
def _handle_text_embedding(request: list[str]) -> None:
|
||||||
|
nonlocal last_request_type, last_request_length
|
||||||
|
|
||||||
|
e2e_start = time.time()
|
||||||
|
last_request_type = "text"
|
||||||
|
last_request_length = len(request)
|
||||||
|
embeddings = compute_embeddings(
|
||||||
|
request,
|
||||||
|
model_name,
|
||||||
|
mode=embedding_mode,
|
||||||
|
provider_options=PROVIDER_OPTIONS,
|
||||||
|
)
|
||||||
|
rep_socket.send(msgpack.packb(embeddings.tolist()))
|
||||||
|
e2e_end = time.time()
|
||||||
|
logger.info(f"⏱️ Direct text embedding E2E time: {e2e_end - e2e_start:.6f}s")
|
||||||
|
|
||||||
|
def _handle_distance_request(request: list[Any]) -> None:
|
||||||
|
nonlocal last_request_type, last_request_length
|
||||||
|
|
||||||
|
e2e_start = time.time()
|
||||||
|
node_ids = request[0]
|
||||||
|
if len(node_ids) == 1 and isinstance(node_ids[0], list):
|
||||||
|
node_ids = node_ids[0]
|
||||||
|
query_vector = np.array(request[1], dtype=np.float32)
|
||||||
|
last_request_type = "distance"
|
||||||
|
last_request_length = len(node_ids)
|
||||||
|
|
||||||
|
logger.debug("Distance calculation request received")
|
||||||
|
logger.debug(f" Node IDs: {node_ids}")
|
||||||
|
logger.debug(f" Query vector dim: {len(query_vector)}")
|
||||||
|
|
||||||
|
texts: list[str] = []
|
||||||
|
found_indices: list[int] = []
|
||||||
|
for idx, nid in enumerate(node_ids):
|
||||||
|
try:
|
||||||
|
passage_id = _map_node_id(nid)
|
||||||
|
passage_data = passages.get_passage(passage_id)
|
||||||
|
txt = passage_data.get("text", "")
|
||||||
|
if isinstance(txt, str) and len(txt) > 0:
|
||||||
|
texts.append(txt)
|
||||||
|
found_indices.append(idx)
|
||||||
|
else:
|
||||||
|
logger.error(f"Empty text for passage ID {passage_id}")
|
||||||
|
except KeyError:
|
||||||
|
logger.error(f"Passage ID {nid} not found")
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error(f"Exception looking up passage ID {nid}: {exc}")
|
||||||
|
|
||||||
|
large_distance = 1e9
|
||||||
|
response_distances = [large_distance] * len(node_ids)
|
||||||
|
|
||||||
|
if texts:
|
||||||
|
try:
|
||||||
|
embeddings = compute_embeddings(
|
||||||
|
texts,
|
||||||
|
model_name,
|
||||||
|
mode=embedding_mode,
|
||||||
|
provider_options=PROVIDER_OPTIONS,
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
|
||||||
|
)
|
||||||
|
if distance_metric == "l2":
|
||||||
|
partial = np.sum(
|
||||||
|
np.square(embeddings - query_vector.reshape(1, -1)), axis=1
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
partial = -np.dot(embeddings, query_vector)
|
||||||
|
|
||||||
|
for pos, dval in zip(found_indices, partial.flatten().tolist()):
|
||||||
|
response_distances[pos] = float(dval)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error(f"Distance computation error, using sentinels: {exc}")
|
||||||
|
|
||||||
|
rep_socket.send(msgpack.packb([response_distances], use_single_float=True))
|
||||||
|
e2e_end = time.time()
|
||||||
|
logger.info(f"⏱️ Distance calculation E2E time: {e2e_end - e2e_start:.6f}s")
|
||||||
|
|
||||||
|
def _handle_embedding_by_id(request: Any) -> None:
|
||||||
|
nonlocal last_request_type, last_request_length
|
||||||
|
|
||||||
|
if isinstance(request, list) and len(request) == 1 and isinstance(request[0], list):
|
||||||
|
node_ids = request[0]
|
||||||
|
elif isinstance(request, list):
|
||||||
|
node_ids = request
|
||||||
|
else:
|
||||||
|
node_ids = []
|
||||||
|
|
||||||
|
e2e_start = time.time()
|
||||||
|
last_request_type = "embedding"
|
||||||
|
last_request_length = len(node_ids)
|
||||||
|
logger.info(f"ZMQ received {len(node_ids)} node IDs for embedding fetch")
|
||||||
|
|
||||||
|
if embedding_dim <= 0:
|
||||||
|
dims = [0, 0]
|
||||||
|
flat_data: list[float] = []
|
||||||
|
else:
|
||||||
|
dims = [len(node_ids), embedding_dim]
|
||||||
|
flat_data = [0.0] * (dims[0] * dims[1])
|
||||||
|
|
||||||
|
texts: list[str] = []
|
||||||
|
found_indices: list[int] = []
|
||||||
|
for idx, nid in enumerate(node_ids):
|
||||||
|
try:
|
||||||
|
passage_id = _map_node_id(nid)
|
||||||
|
passage_data = passages.get_passage(passage_id)
|
||||||
|
txt = passage_data.get("text", "")
|
||||||
|
if isinstance(txt, str) and len(txt) > 0:
|
||||||
|
texts.append(txt)
|
||||||
|
found_indices.append(idx)
|
||||||
|
else:
|
||||||
|
logger.error(f"Empty text for passage ID {passage_id}")
|
||||||
|
except KeyError:
|
||||||
|
logger.error(f"Passage with ID {nid} not found")
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error(f"Exception looking up passage ID {nid}: {exc}")
|
||||||
|
|
||||||
|
if texts:
|
||||||
|
try:
|
||||||
|
embeddings = compute_embeddings(
|
||||||
|
texts,
|
||||||
|
model_name,
|
||||||
|
mode=embedding_mode,
|
||||||
|
provider_options=PROVIDER_OPTIONS,
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if np.isnan(embeddings).any() or np.isinf(embeddings).any():
|
||||||
|
logger.error(
|
||||||
|
f"NaN or Inf detected in embeddings! Requested IDs: {node_ids[:5]}..."
|
||||||
|
)
|
||||||
|
dims = [0, embedding_dim]
|
||||||
|
flat_data = []
|
||||||
|
else:
|
||||||
|
emb_f32 = np.ascontiguousarray(embeddings, dtype=np.float32)
|
||||||
|
flat = emb_f32.flatten().tolist()
|
||||||
|
for j, pos in enumerate(found_indices):
|
||||||
|
start = pos * embedding_dim
|
||||||
|
end = start + embedding_dim
|
||||||
|
if end <= len(flat_data):
|
||||||
|
flat_data[start:end] = flat[
|
||||||
|
j * embedding_dim : (j + 1) * embedding_dim
|
||||||
|
]
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error(f"Embedding computation error, returning zeros: {exc}")
|
||||||
|
|
||||||
|
response_payload = [dims, flat_data]
|
||||||
|
rep_socket.send(msgpack.packb(response_payload, use_single_float=True))
|
||||||
|
e2e_end = time.time()
|
||||||
|
logger.info(f"⏱️ Fallback Embed by Id E2E time: {e2e_end - e2e_start:.6f}s")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
while not shutdown_event.is_set():
|
while not shutdown_event.is_set():
|
||||||
try:
|
try:
|
||||||
e2e_start = time.time()
|
|
||||||
logger.debug("🔍 Waiting for ZMQ message...")
|
logger.debug("🔍 Waiting for ZMQ message...")
|
||||||
request_bytes = rep_socket.recv()
|
request_bytes = rep_socket.recv()
|
||||||
|
except zmq.Again:
|
||||||
|
continue
|
||||||
|
|
||||||
# Rest of the processing logic (same as original)
|
try:
|
||||||
request = msgpack.unpackb(request_bytes)
|
request = msgpack.unpackb(request_bytes)
|
||||||
|
except Exception as exc:
|
||||||
|
if shutdown_event.is_set():
|
||||||
|
logger.info("Shutdown in progress, ignoring ZMQ error")
|
||||||
|
break
|
||||||
|
logger.error(f"Error unpacking ZMQ message: {exc}")
|
||||||
|
try:
|
||||||
|
safe = _build_safe_fallback()
|
||||||
|
rep_socket.send(msgpack.packb(safe, use_single_float=True))
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
continue
|
||||||
|
|
||||||
if len(request) == 1 and request[0] == "__QUERY_MODEL__":
|
try:
|
||||||
response_bytes = msgpack.packb([model_name])
|
# Model query
|
||||||
rep_socket.send(response_bytes)
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Handle direct text embedding request
|
|
||||||
if (
|
if (
|
||||||
|
isinstance(request, list)
|
||||||
|
and len(request) == 1
|
||||||
|
and request[0] == "__QUERY_MODEL__"
|
||||||
|
):
|
||||||
|
rep_socket.send(msgpack.packb([model_name]))
|
||||||
|
# Direct text embedding
|
||||||
|
elif (
|
||||||
isinstance(request, list)
|
isinstance(request, list)
|
||||||
and request
|
and request
|
||||||
and all(isinstance(item, str) for item in request)
|
and all(isinstance(item, str) for item in request)
|
||||||
):
|
):
|
||||||
last_request_type = "text"
|
_handle_text_embedding(request)
|
||||||
last_request_length = len(request)
|
# Distance calculation: [[ids], [query_vector]]
|
||||||
embeddings = compute_embeddings(
|
elif (
|
||||||
request,
|
|
||||||
model_name,
|
|
||||||
mode=embedding_mode,
|
|
||||||
provider_options=PROVIDER_OPTIONS,
|
|
||||||
)
|
|
||||||
rep_socket.send(msgpack.packb(embeddings.tolist()))
|
|
||||||
e2e_end = time.time()
|
|
||||||
logger.info(f"⏱️ Text embedding E2E time: {e2e_end - e2e_start:.6f}s")
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Handle distance calculation request: [[ids], [query_vector]]
|
|
||||||
if (
|
|
||||||
isinstance(request, list)
|
isinstance(request, list)
|
||||||
and len(request) == 2
|
and len(request) == 2
|
||||||
and isinstance(request[0], list)
|
and isinstance(request[0], list)
|
||||||
and isinstance(request[1], list)
|
and isinstance(request[1], list)
|
||||||
):
|
):
|
||||||
node_ids = request[0]
|
_handle_distance_request(request)
|
||||||
# Handle nested [[ids]] shape defensively
|
# Embedding-by-id fallback
|
||||||
if len(node_ids) == 1 and isinstance(node_ids[0], list):
|
|
||||||
node_ids = node_ids[0]
|
|
||||||
query_vector = np.array(request[1], dtype=np.float32)
|
|
||||||
last_request_type = "distance"
|
|
||||||
last_request_length = len(node_ids)
|
|
||||||
|
|
||||||
logger.debug("Distance calculation request received")
|
|
||||||
logger.debug(f" Node IDs: {node_ids}")
|
|
||||||
logger.debug(f" Query vector dim: {len(query_vector)}")
|
|
||||||
|
|
||||||
# Gather texts for found ids
|
|
||||||
texts: list[str] = []
|
|
||||||
found_indices: list[int] = []
|
|
||||||
for idx, nid in enumerate(node_ids):
|
|
||||||
try:
|
|
||||||
passage_id = _map_node_id(nid)
|
|
||||||
passage_data = passages.get_passage(passage_id)
|
|
||||||
txt = passage_data.get("text", "")
|
|
||||||
if isinstance(txt, str) and len(txt) > 0:
|
|
||||||
texts.append(txt)
|
|
||||||
found_indices.append(idx)
|
|
||||||
else:
|
|
||||||
logger.error(f"Empty text for passage ID {passage_id}")
|
|
||||||
except KeyError:
|
|
||||||
logger.error(f"Passage ID {nid} not found")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Exception looking up passage ID {nid}: {e}")
|
|
||||||
|
|
||||||
# Prepare full-length response with large sentinel values
|
|
||||||
large_distance = 1e9
|
|
||||||
response_distances = [large_distance] * len(node_ids)
|
|
||||||
|
|
||||||
if texts:
|
|
||||||
try:
|
|
||||||
embeddings = compute_embeddings(
|
|
||||||
texts,
|
|
||||||
model_name,
|
|
||||||
mode=embedding_mode,
|
|
||||||
provider_options=PROVIDER_OPTIONS,
|
|
||||||
)
|
|
||||||
logger.info(
|
|
||||||
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
|
|
||||||
)
|
|
||||||
if distance_metric == "l2":
|
|
||||||
partial = np.sum(
|
|
||||||
np.square(embeddings - query_vector.reshape(1, -1)), axis=1
|
|
||||||
)
|
|
||||||
else: # mips or cosine
|
|
||||||
partial = -np.dot(embeddings, query_vector)
|
|
||||||
|
|
||||||
for pos, dval in zip(found_indices, partial.flatten().tolist()):
|
|
||||||
response_distances[pos] = float(dval)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Distance computation error, using sentinels: {e}")
|
|
||||||
|
|
||||||
# Send response in expected shape [[distances]]
|
|
||||||
rep_socket.send(msgpack.packb([response_distances], use_single_float=True))
|
|
||||||
e2e_end = time.time()
|
|
||||||
logger.info(f"⏱️ Distance calculation E2E time: {e2e_end - e2e_start:.6f}s")
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Fallback: treat as embedding-by-id request
|
|
||||||
if (
|
|
||||||
isinstance(request, list)
|
|
||||||
and len(request) == 1
|
|
||||||
and isinstance(request[0], list)
|
|
||||||
):
|
|
||||||
node_ids = request[0]
|
|
||||||
elif isinstance(request, list):
|
|
||||||
-                    node_ids = request
-                else:
-                    node_ids = []
-                last_request_type = "embedding"
-                last_request_length = len(node_ids)
-                logger.info(f"ZMQ received {len(node_ids)} node IDs for embedding fetch")
-
-                # Preallocate zero-filled flat data for robustness
-                if embedding_dim <= 0:
-                    dims = [0, 0]
-                    flat_data: list[float] = []
-                else:
-                    dims = [len(node_ids), embedding_dim]
-                    flat_data = [0.0] * (dims[0] * dims[1])
-
-                # Collect texts for found ids
-                texts: list[str] = []
-                found_indices: list[int] = []
-                for idx, nid in enumerate(node_ids):
-                    try:
-                        passage_id = _map_node_id(nid)
-                        passage_data = passages.get_passage(passage_id)
-                        txt = passage_data.get("text", "")
-                        if isinstance(txt, str) and len(txt) > 0:
-                            texts.append(txt)
-                            found_indices.append(idx)
-                        else:
-                            logger.error(f"Empty text for passage ID {passage_id}")
-                    except KeyError:
-                        logger.error(f"Passage with ID {nid} not found")
-                    except Exception as e:
-                        logger.error(f"Exception looking up passage ID {nid}: {e}")
-
-                if texts:
-                    try:
-                        embeddings = compute_embeddings(
-                            texts,
-                            model_name,
-                            mode=embedding_mode,
-                            provider_options=PROVIDER_OPTIONS,
-                        )
-                        logger.info(
-                            f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
-                        )
-
-                        if np.isnan(embeddings).any() or np.isinf(embeddings).any():
-                            logger.error(
-                                f"NaN or Inf detected in embeddings! Requested IDs: {node_ids[:5]}..."
-                            )
-                            dims = [0, embedding_dim]
-                            flat_data = []
-                        else:
-                            emb_f32 = np.ascontiguousarray(embeddings, dtype=np.float32)
-                            flat = emb_f32.flatten().tolist()
-                            for j, pos in enumerate(found_indices):
-                                start = pos * embedding_dim
-                                end = start + embedding_dim
-                                if end <= len(flat_data):
-                                    flat_data[start:end] = flat[
-                                        j * embedding_dim : (j + 1) * embedding_dim
-                                    ]
-                    except Exception as e:
-                        logger.error(f"Embedding computation error, returning zeros: {e}")
-
-                response_payload = [dims, flat_data]
-                response_bytes = msgpack.packb(response_payload, use_single_float=True)
-
-                rep_socket.send(response_bytes)
-                e2e_end = time.time()
-                logger.info(f"⏱️ ZMQ E2E time: {e2e_end - e2e_start:.6f}s")
-
-        except zmq.Again:
-            # Timeout - check shutdown_event and continue
-            continue
-        except Exception as e:
-            if not shutdown_event.is_set():
-                logger.error(f"Error in ZMQ server loop: {e}")
-                # Shape-correct fallback
-                try:
-                    if last_request_type == "distance":
-                        large_distance = 1e9
-                        fallback_len = max(0, int(last_request_length))
-                        safe = [[large_distance] * fallback_len]
-                    elif last_request_type == "embedding":
-                        bsz = max(0, int(last_request_length))
-                        dim = max(0, int(embedding_dim))
-                        safe = (
-                            [[bsz, dim], [0.0] * (bsz * dim)] if dim > 0 else [[0, 0], []]
-                        )
-                    elif last_request_type == "text":
-                        safe = []  # direct text embeddings expectation is a flat list
-                    else:
-                        safe = [[0, int(embedding_dim) if embedding_dim > 0 else 0], []]
-                    rep_socket.send(msgpack.packb(safe, use_single_float=True))
-                except Exception:
-                    pass
             else:
+                _handle_embedding_by_id(request)
+        except Exception as exc:
+            if shutdown_event.is_set():
                 logger.info("Shutdown in progress, ignoring ZMQ error")
                 break
+            logger.error(f"Error in ZMQ server loop: {exc}")
+            try:
+                safe = _build_safe_fallback()
+                rep_socket.send(msgpack.packb(safe, use_single_float=True))
+            except Exception:
+                pass
     finally:
         try:
             rep_socket.close(0)
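The inline fallback above is replaced by a single call to `_build_safe_fallback()`, whose body is not part of this hunk. As a reading aid only, here is a minimal sketch of a shape-correct fallback consistent with the removed branches (large distances for distance requests, a zero-filled `[dims, flat_data]` payload for embedding requests, an empty flat list for text requests); the real helper's name, signature, and shapes may differ.

```python
# Sketch only - the actual _build_safe_fallback() is not shown in this diff.
def _build_safe_fallback_sketch(last_request_type: str, last_request_length: int, embedding_dim: int):
    if last_request_type == "distance":
        # One row of "very far" distances so the caller can keep searching.
        return [[1e9] * max(0, int(last_request_length))]
    if last_request_type == "embedding":
        # Zero-filled payload with the [dims, flat_data] layout used above.
        bsz, dim = max(0, int(last_request_length)), max(0, int(embedding_dim))
        return [[bsz, dim], [0.0] * (bsz * dim)] if dim > 0 else [[0, 0], []]
    if last_request_type == "text":
        # Direct text embeddings are expected as a flat list.
        return []
    return [[0, int(embedding_dim) if embedding_dim > 0 else 0], []]
```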
Submodule packages/leann-backend-hnsw/third_party/faiss updated: c69511a99c...301bf24f14

@@ -820,10 +820,10 @@ class LeannBuilder:
                 actual_port,
                 requested_zmq_port,
             )
-            try:
-                index.hnsw.zmq_port = actual_port
-            except AttributeError:
-                pass
+            if hasattr(index.hnsw, "set_zmq_port"):
+                index.hnsw.set_zmq_port(actual_port)
+            elif hasattr(index, "set_zmq_port"):
+                index.set_zmq_port(actual_port)

             if needs_recompute:
                 for i in range(embeddings.shape[0]):
@@ -864,7 +864,13 @@ class LeannBuilder:


 class LeannSearcher:
-    def __init__(self, index_path: str, enable_warmup: bool = False, **backend_kwargs):
+    def __init__(
+        self,
+        index_path: str,
+        enable_warmup: bool = True,
+        recompute_embeddings: bool = True,
+        **backend_kwargs,
+    ):
         # Fix path resolution for Colab and other environments
         if not Path(index_path).is_absolute():
             index_path = str(Path(index_path).resolve())
@@ -895,14 +901,32 @@ class LeannSearcher:
         backend_factory = BACKEND_REGISTRY.get(backend_name)
         if backend_factory is None:
             raise ValueError(f"Backend '{backend_name}' not found.")

+        # Global recompute flag for this searcher (explicit knob, default True)
+        self.recompute_embeddings: bool = bool(recompute_embeddings)
+
+        # Warmup flag: keep using the existing enable_warmup parameter,
+        # but default it to True so cold-start happens earlier.
+        self._warmup: bool = bool(enable_warmup)
+
         final_kwargs = {**self.meta_data.get("backend_kwargs", {}), **backend_kwargs}
-        final_kwargs["enable_warmup"] = enable_warmup
+        final_kwargs["enable_warmup"] = self._warmup
         if self.embedding_options:
             final_kwargs.setdefault("embedding_options", self.embedding_options)
         self.backend_impl: LeannBackendSearcherInterface = backend_factory.searcher(
             index_path, **final_kwargs
         )

+        # Optional one-shot warmup at construction time to hide cold-start latency.
+        if self._warmup:
+            try:
+                _ = self.backend_impl.compute_query_embedding(
+                    "__LEANN_WARMUP__",
+                    use_server_if_available=self.recompute_embeddings,
+                )
+            except Exception as exc:
+                logger.warning(f"Warmup embedding failed (ignored): {exc}")
+
     def search(
         self,
         query: str,
@@ -910,7 +934,7 @@ class LeannSearcher:
         complexity: int = 64,
         beam_width: int = 1,
         prune_ratio: float = 0.0,
-        recompute_embeddings: bool = True,
+        recompute_embeddings: Optional[bool] = None,
         pruning_strategy: Literal["global", "local", "proportional"] = "global",
         expected_zmq_port: int = 5557,
         metadata_filters: Optional[dict[str, dict[str, Union[str, int, float, bool, list]]]] = None,
@@ -927,7 +951,8 @@ class LeannSearcher:
             complexity: Search complexity/candidate list size, higher = more accurate but slower
             beam_width: Number of parallel search paths/IO requests per iteration
             prune_ratio: Ratio of neighbors to prune via approximate distance (0.0-1.0)
-            recompute_embeddings: Whether to fetch fresh embeddings from server vs use stored codes
+            recompute_embeddings: (Deprecated) Per-call override for recompute mode.
+                Configure this at LeannSearcher(..., recompute_embeddings=...) instead.
             pruning_strategy: Candidate selection strategy - "global" (default), "local", or "proportional"
             expected_zmq_port: ZMQ port for embedding server communication
             metadata_filters: Optional filters to apply to search results based on metadata.
@@ -966,8 +991,19 @@ class LeannSearcher:

         zmq_port = None

+        # Resolve effective recompute flag for this search.
+        if recompute_embeddings is not None:
+            logger.warning(
+                "LeannSearcher.search(..., recompute_embeddings=...) is deprecated and "
+                "will be removed in a future version. Configure recompute at "
+                "LeannSearcher(..., recompute_embeddings=...) instead."
+            )
+            effective_recompute = bool(recompute_embeddings)
+        else:
+            effective_recompute = self.recompute_embeddings
+
         start_time = time.time()
-        if recompute_embeddings:
+        if effective_recompute:
             zmq_port = self.backend_impl._ensure_server_running(
                 self.meta_path_str,
                 port=expected_zmq_port,
@@ -981,7 +1017,7 @@ class LeannSearcher:

             query_embedding = self.backend_impl.compute_query_embedding(
                 query,
-                use_server_if_available=recompute_embeddings,
+                use_server_if_available=effective_recompute,
                 zmq_port=zmq_port,
             )
             logger.info(f" Generated embedding shape: {query_embedding.shape}")
@@ -993,7 +1029,7 @@ class LeannSearcher:
             "complexity": complexity,
             "beam_width": beam_width,
             "prune_ratio": prune_ratio,
-            "recompute_embeddings": recompute_embeddings,
+            "recompute_embeddings": effective_recompute,
             "pruning_strategy": pruning_strategy,
             "zmq_port": zmq_port,
         }
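For context, the searcher-level knobs added above change how callers opt in to recompute and warmup. A minimal usage sketch, assuming `LeannSearcher` is importable from `leann.api` as the hunk headers suggest (the index path and query string are placeholders):

```python
from leann.api import LeannSearcher  # import path assumed from the module shown in the hunks

searcher = LeannSearcher(
    "indexes/my-docs.leann",     # hypothetical index path
    enable_warmup=True,          # now defaults to True: one warmup embedding at construction
    recompute_embeddings=True,   # searcher-level knob; replaces the per-call flag
)

# The per-call recompute_embeddings argument is deprecated; leaving it unset
# makes search() fall back to the searcher-level setting above.
results = searcher.search("how does chunking work?", complexity=64, beam_width=1)
```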
@@ -5,12 +5,15 @@ Packaged within leann-core so installed wheels can import it reliably.

 import logging
 from pathlib import Path
-from typing import Optional
+from typing import Any, Optional

 from llama_index.core.node_parser import SentenceSplitter

 logger = logging.getLogger(__name__)

+# Flag to ensure AST token warning only shown once per session
+_ast_token_warning_shown = False
+

 def estimate_token_count(text: str) -> int:
     """
@@ -174,37 +177,44 @@ def create_ast_chunks(
     max_chunk_size: int = 512,
     chunk_overlap: int = 64,
     metadata_template: str = "default",
-) -> list[str]:
+) -> list[dict[str, Any]]:
     """Create AST-aware chunks from code documents using astchunk.

     Falls back to traditional chunking if astchunk is unavailable.

+    Returns:
+        List of dicts with {"text": str, "metadata": dict}
     """
     try:
         from astchunk import ASTChunkBuilder  # optional dependency
     except ImportError as e:
         logger.error(f"astchunk not available: {e}")
         logger.info("Falling back to traditional chunking for code files")
-        return create_traditional_chunks(documents, max_chunk_size, chunk_overlap)
+        return _traditional_chunks_as_dicts(documents, max_chunk_size, chunk_overlap)

     all_chunks = []
     for doc in documents:
         language = doc.metadata.get("language")
         if not language:
             logger.warning("No language detected; falling back to traditional chunking")
-            all_chunks.extend(create_traditional_chunks([doc], max_chunk_size, chunk_overlap))
+            all_chunks.extend(_traditional_chunks_as_dicts([doc], max_chunk_size, chunk_overlap))
             continue

         try:
-            # Warn if AST chunk size + overlap might exceed common token limits
+            # Warn once if AST chunk size + overlap might exceed common token limits
+            # Note: Actual truncation happens at embedding time with dynamic model limits
+            global _ast_token_warning_shown
             estimated_max_tokens = int(
                 (max_chunk_size + chunk_overlap) * 1.2
             )  # Conservative estimate
-            if estimated_max_tokens > 512:
+            if estimated_max_tokens > 512 and not _ast_token_warning_shown:
                 logger.warning(
                     f"AST chunk size ({max_chunk_size}) + overlap ({chunk_overlap}) = {max_chunk_size + chunk_overlap} chars "
                     f"may exceed 512 token limit (~{estimated_max_tokens} tokens estimated). "
-                    f"Consider reducing --ast-chunk-size to {int(400 / 1.2)} or --ast-chunk-overlap to {int(50 / 1.2)}"
+                    f"Consider reducing --ast-chunk-size to {int(400 / 1.2)} or --ast-chunk-overlap to {int(50 / 1.2)}. "
+                    f"Note: Chunks will be auto-truncated at embedding time based on your model's actual token limit."
                 )
+                _ast_token_warning_shown = True

             configs = {
                 "max_chunk_size": max_chunk_size,
@@ -229,17 +239,40 @@ def create_ast_chunks(

             chunks = chunk_builder.chunkify(code_content)
             for chunk in chunks:
+                chunk_text = None
+                astchunk_metadata = {}
+
                 if hasattr(chunk, "text"):
                     chunk_text = chunk.text
-                elif isinstance(chunk, dict) and "text" in chunk:
-                    chunk_text = chunk["text"]
                 elif isinstance(chunk, str):
                     chunk_text = chunk
+                elif isinstance(chunk, dict):
+                    # Handle astchunk format: {"content": "...", "metadata": {...}}
+                    if "content" in chunk:
+                        chunk_text = chunk["content"]
+                        astchunk_metadata = chunk.get("metadata", {})
+                    elif "text" in chunk:
+                        chunk_text = chunk["text"]
+                    else:
+                        chunk_text = str(chunk)  # Last resort
                 else:
                     chunk_text = str(chunk)

                 if chunk_text and chunk_text.strip():
-                    all_chunks.append(chunk_text.strip())
+                    # Extract document-level metadata
+                    doc_metadata = {
+                        "file_path": doc.metadata.get("file_path", ""),
+                        "file_name": doc.metadata.get("file_name", ""),
+                    }
+                    if "creation_date" in doc.metadata:
+                        doc_metadata["creation_date"] = doc.metadata["creation_date"]
+                    if "last_modified_date" in doc.metadata:
+                        doc_metadata["last_modified_date"] = doc.metadata["last_modified_date"]
+
+                    # Merge document metadata + astchunk metadata
+                    combined_metadata = {**doc_metadata, **astchunk_metadata}
+
+                    all_chunks.append({"text": chunk_text.strip(), "metadata": combined_metadata})

             logger.info(
                 f"Created {len(chunks)} AST chunks from {language} file: {doc.metadata.get('file_name', 'unknown')}"
@@ -247,15 +280,19 @@ def create_ast_chunks(
         except Exception as e:
             logger.warning(f"AST chunking failed for {language} file: {e}")
             logger.info("Falling back to traditional chunking")
-            all_chunks.extend(create_traditional_chunks([doc], max_chunk_size, chunk_overlap))
+            all_chunks.extend(_traditional_chunks_as_dicts([doc], max_chunk_size, chunk_overlap))

     return all_chunks


 def create_traditional_chunks(
     documents, chunk_size: int = 256, chunk_overlap: int = 128
-) -> list[str]:
-    """Create traditional text chunks using LlamaIndex SentenceSplitter."""
+) -> list[dict[str, Any]]:
+    """Create traditional text chunks using LlamaIndex SentenceSplitter.
+
+    Returns:
+        List of dicts with {"text": str, "metadata": dict}
+    """
     if chunk_size <= 0:
         logger.warning(f"Invalid chunk_size={chunk_size}, using default value of 256")
         chunk_size = 256
@@ -271,19 +308,40 @@ def create_traditional_chunks(
         paragraph_separator="\n\n",
     )

-    all_texts = []
+    result = []
     for doc in documents:
+        # Extract document-level metadata
+        doc_metadata = {
+            "file_path": doc.metadata.get("file_path", ""),
+            "file_name": doc.metadata.get("file_name", ""),
+        }
+        if "creation_date" in doc.metadata:
+            doc_metadata["creation_date"] = doc.metadata["creation_date"]
+        if "last_modified_date" in doc.metadata:
+            doc_metadata["last_modified_date"] = doc.metadata["last_modified_date"]
+
         try:
             nodes = node_parser.get_nodes_from_documents([doc])
             if nodes:
-                all_texts.extend(node.get_content() for node in nodes)
+                for node in nodes:
+                    result.append({"text": node.get_content(), "metadata": doc_metadata})
         except Exception as e:
             logger.error(f"Traditional chunking failed for document: {e}")
             content = doc.get_content()
             if content and content.strip():
-                all_texts.append(content.strip())
+                result.append({"text": content.strip(), "metadata": doc_metadata})

-    return all_texts
+    return result


+def _traditional_chunks_as_dicts(
+    documents, chunk_size: int = 256, chunk_overlap: int = 128
+) -> list[dict[str, Any]]:
+    """Helper: Traditional chunking that returns dict format for consistency.
+
+    This is now just an alias for create_traditional_chunks for backwards compatibility.
+    """
+    return create_traditional_chunks(documents, chunk_size, chunk_overlap)
+
+
 def create_text_chunks(
@@ -295,8 +353,12 @@ def create_text_chunks(
     ast_chunk_overlap: int = 64,
     code_file_extensions: Optional[list[str]] = None,
     ast_fallback_traditional: bool = True,
-) -> list[str]:
-    """Create text chunks from documents with optional AST support for code files."""
+) -> list[dict[str, Any]]:
+    """Create text chunks from documents with optional AST support for code files.
+
+    Returns:
+        List of dicts with {"text": str, "metadata": dict}
+    """
     if not documents:
         logger.warning("No documents provided for chunking")
         return []
@@ -331,24 +393,17 @@ def create_text_chunks(
             logger.error(f"AST chunking failed: {e}")
             if ast_fallback_traditional:
                 all_chunks.extend(
-                    create_traditional_chunks(code_docs, chunk_size, chunk_overlap)
+                    _traditional_chunks_as_dicts(code_docs, chunk_size, chunk_overlap)
                 )
             else:
                 raise
         if text_docs:
-            all_chunks.extend(create_traditional_chunks(text_docs, chunk_size, chunk_overlap))
+            all_chunks.extend(_traditional_chunks_as_dicts(text_docs, chunk_size, chunk_overlap))
     else:
-        all_chunks = create_traditional_chunks(documents, chunk_size, chunk_overlap)
+        all_chunks = _traditional_chunks_as_dicts(documents, chunk_size, chunk_overlap)

     logger.info(f"Total chunks created: {len(all_chunks)}")

-    # Validate chunk token limits (default to 512 for safety)
-    # This provides a safety net for embedding models with token limits
-    validated_chunks, num_truncated = validate_chunk_token_limits(all_chunks, max_tokens=512)
-
-    if num_truncated > 0:
-        logger.info(
-            f"Post-chunking validation: {num_truncated} chunks were truncated to fit 512 token limit"
-        )
-
-    return validated_chunks
+    # Note: Token truncation is now handled at embedding time with dynamic model limits
+    # See get_model_token_limit() and truncate_to_token_limit() in embedding_compute.py
+    return all_chunks
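Since `create_text_chunks` and its helpers now return `{"text", "metadata"}` dicts instead of bare strings, downstream callers index into the dict. A small consumption sketch, assuming documents are loaded with llama_index's `SimpleDirectoryReader` (any loader producing `Document` objects with a `.metadata` dict works the same way; the input directory is a placeholder):

```python
from llama_index.core import SimpleDirectoryReader

from leann.chunking_utils import create_text_chunks

docs = SimpleDirectoryReader("./my_project").load_data()  # hypothetical input directory
chunks = create_text_chunks(docs, use_ast_chunking=True)

for chunk in chunks[:3]:
    # Each chunk is now {"text": str, "metadata": dict} rather than a plain string.
    meta = chunk["metadata"]
    print(meta.get("file_name", "unknown"), len(chunk["text"]))
```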
@@ -1279,13 +1279,8 @@ Examples:
                     ast_fallback_traditional=getattr(args, "ast_fallback_traditional", True),
                 )

-                # Note: AST chunking currently returns plain text chunks without metadata
-                # We preserve basic file info by associating chunks with their source documents
-                # For better metadata preservation, documents list order should be maintained
-                for chunk_text in chunk_texts:
-                    # TODO: Enhance create_text_chunks to return metadata alongside text
-                    # For now, we store chunks with empty metadata
-                    all_texts.append({"text": chunk_text, "metadata": {}})
+                # create_text_chunks now returns list[dict] with metadata preserved
+                all_texts.extend(chunk_texts)

             except ImportError as e:
                 print(
@@ -10,72 +10,63 @@ import time
 from typing import Any, Optional

 import numpy as np
+import tiktoken
 import torch

 from .settings import resolve_ollama_host, resolve_openai_api_key, resolve_openai_base_url

+# Set up logger with proper level
+logger = logging.getLogger(__name__)
+LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
+log_level = getattr(logging, LOG_LEVEL, logging.WARNING)
+logger.setLevel(log_level)

-def truncate_to_token_limit(texts: list[str], max_tokens: int = 512) -> list[str]:
-    """
-    Truncate texts to token limit using tiktoken or conservative character truncation.
-
-    Args:
-        texts: List of texts to truncate
-        max_tokens: Maximum tokens allowed per text
-
-    Returns:
-        List of truncated texts that should fit within token limit
-    """
-    try:
-        import tiktoken
-
-        encoder = tiktoken.get_encoding("cl100k_base")
-        truncated = []
-
-        for text in texts:
-            tokens = encoder.encode(text)
-            if len(tokens) > max_tokens:
-                # Truncate to max_tokens and decode back to text
-                truncated_tokens = tokens[:max_tokens]
-                truncated_text = encoder.decode(truncated_tokens)
-                truncated.append(truncated_text)
-                logger.warning(
-                    f"Truncated text from {len(tokens)} to {max_tokens} tokens "
-                    f"(from {len(text)} to {len(truncated_text)} characters)"
-                )
-            else:
-                truncated.append(text)
-        return truncated
-
-    except ImportError:
-        # Fallback: Conservative character truncation
-        # Assume worst case: 1.5 tokens per character for code content
-        char_limit = int(max_tokens / 1.5)
-        truncated = []
-
-        for text in texts:
-            if len(text) > char_limit:
-                truncated_text = text[:char_limit]
-                truncated.append(truncated_text)
-                logger.warning(
-                    f"Truncated text from {len(text)} to {char_limit} characters "
-                    f"(conservative estimate for {max_tokens} tokens)"
-                )
-            else:
-                truncated.append(text)
-        return truncated
+# Token limit registry for embedding models
+# Used as fallback when dynamic discovery fails (e.g., LM Studio, OpenAI)
+# Ollama models use dynamic discovery via /api/show
+EMBEDDING_MODEL_LIMITS = {
+    # Nomic models (common across servers)
+    "nomic-embed-text": 2048,  # Corrected from 512 - verified via /api/show
+    "nomic-embed-text-v1.5": 2048,
+    "nomic-embed-text-v2": 512,
+    # Other embedding models
+    "mxbai-embed-large": 512,
+    "all-minilm": 512,
+    "bge-m3": 8192,
+    "snowflake-arctic-embed": 512,
+    # OpenAI models
+    "text-embedding-3-small": 8192,
+    "text-embedding-3-large": 8192,
+    "text-embedding-ada-002": 8192,
+}


-def get_model_token_limit(model_name: str) -> int:
+def get_model_token_limit(
+    model_name: str,
+    base_url: Optional[str] = None,
+    default: int = 2048,
+) -> int:
     """
     Get token limit for a given embedding model.
+    Uses hybrid approach: dynamic discovery for Ollama, registry fallback for others.

     Args:
         model_name: Name of the embedding model
+        base_url: Base URL of the embedding server (for dynamic discovery)
+        default: Default token limit if model not found

     Returns:
-        Token limit for the model, defaults to 512 if unknown
+        Token limit for the model in tokens
     """
+    # Try Ollama dynamic discovery if base_url provided
+    if base_url:
+        # Detect Ollama servers by port or "ollama" in URL
+        if "11434" in base_url or "ollama" in base_url.lower():
+            limit = _query_ollama_context_limit(model_name, base_url)
+            if limit:
+                return limit
+
+    # Fallback to known model registry with version handling (from PR #154)
     # Handle versioned model names (e.g., "nomic-embed-text:latest" -> "nomic-embed-text")
     base_model_name = model_name.split(":")[0]

@@ -92,31 +83,111 @@ def get_model_token_limit(model_name: str) -> int:
         if known_model in base_model_name or base_model_name in known_model:
             return limit

-    # Default to conservative 512 token limit
-    logger.warning(f"Unknown model '{model_name}', using default 512 token limit")
-    return 512
+    # Default fallback
+    logger.warning(f"Unknown model '{model_name}', using default {default} token limit")
+    return default


-# Set up logger with proper level
-logger = logging.getLogger(__name__)
-LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
-log_level = getattr(logging, LOG_LEVEL, logging.WARNING)
-logger.setLevel(log_level)
+def truncate_to_token_limit(texts: list[str], token_limit: int) -> list[str]:
+    """
+    Truncate texts to fit within token limit using tiktoken.
+
+    Args:
+        texts: List of text strings to truncate
+        token_limit: Maximum number of tokens allowed
+
+    Returns:
+        List of truncated texts (same length as input)
+    """
+    if not texts:
+        return []
+
+    # Use tiktoken with cl100k_base encoding
+    enc = tiktoken.get_encoding("cl100k_base")
+
+    truncated_texts = []
+    truncation_count = 0
+    total_tokens_removed = 0
+    max_original_length = 0
+
+    for i, text in enumerate(texts):
+        tokens = enc.encode(text)
+        original_length = len(tokens)
+
+        if original_length <= token_limit:
+            # Text is within limit, keep as is
+            truncated_texts.append(text)
+        else:
+            # Truncate to token_limit
+            truncated_tokens = tokens[:token_limit]
+            truncated_text = enc.decode(truncated_tokens)
+            truncated_texts.append(truncated_text)
+
+            # Track truncation statistics
+            truncation_count += 1
+            tokens_removed = original_length - token_limit
+            total_tokens_removed += tokens_removed
+            max_original_length = max(max_original_length, original_length)
+
+            # Log individual truncation at WARNING level (first few only)
+            if truncation_count <= 3:
+                logger.warning(
+                    f"Text {i + 1} truncated: {original_length} → {token_limit} tokens "
+                    f"({tokens_removed} tokens removed)"
+                )
+            elif truncation_count == 4:
+                logger.warning("Further truncation warnings suppressed...")

+    # Log summary at INFO level
+    if truncation_count > 0:
+        logger.warning(
+            f"Truncation summary: {truncation_count}/{len(texts)} texts truncated "
+            f"(removed {total_tokens_removed} tokens total, longest was {max_original_length} tokens)"
+        )
+    else:
+        logger.debug(
+            f"No truncation needed - all {len(texts)} texts within {token_limit} token limit"
+        )
+
+    return truncated_texts
+
+
+def _query_ollama_context_limit(model_name: str, base_url: str) -> Optional[int]:
+    """
+    Query Ollama /api/show for model context limit.
+
+    Args:
+        model_name: Name of the Ollama model
+        base_url: Base URL of the Ollama server
+
+    Returns:
+        Context limit in tokens if found, None otherwise
+    """
+    try:
+        import requests
+
+        response = requests.post(
+            f"{base_url}/api/show",
+            json={"name": model_name},
+            timeout=5,
+        )
+        if response.status_code == 200:
+            data = response.json()
+            if "model_info" in data:
+                # Look for *.context_length in model_info
+                for key, value in data["model_info"].items():
+                    if "context_length" in key and isinstance(value, int):
+                        logger.info(f"Detected {model_name} context limit: {value} tokens")
+                        return value
+    except Exception as e:
+        logger.debug(f"Failed to query Ollama context limit: {e}")
+
+    return None


 # Global model cache to avoid repeated loading
 _model_cache: dict[str, Any] = {}

-# Known embedding model token limits
-EMBEDDING_MODEL_LIMITS = {
-    "nomic-embed-text": 512,
-    "nomic-embed-text-v2": 512,
-    "mxbai-embed-large": 512,
-    "all-minilm": 512,
-    "bge-m3": 8192,
-    "snowflake-arctic-embed": 512,
-    # Add more models as needed
-}
-

 def compute_embeddings(
     texts: list[str],
@@ -144,9 +215,14 @@ def compute_embeddings(
         Normalized embeddings array, shape: (len(texts), embedding_dim)
     """
     provider_options = provider_options or {}
+    wrapper_start_time = time.time()
+    logger.debug(
+        f"[compute_embeddings] entry: mode={mode}, model='{model_name}', text_count={len(texts)}"
+    )
+
     if mode == "sentence-transformers":
-        return compute_embeddings_sentence_transformers(
+        inner_start_time = time.time()
+        result = compute_embeddings_sentence_transformers(
             texts,
             model_name,
             is_build=is_build,
@@ -155,6 +231,14 @@ def compute_embeddings(
             manual_tokenize=manual_tokenize,
             max_length=max_length,
         )
+        inner_end_time = time.time()
+        wrapper_end_time = time.time()
+        logger.debug(
+            "[compute_embeddings] sentence-transformers timings: "
+            f"inner={inner_end_time - inner_start_time:.6f}s, "
+            f"wrapper_total={wrapper_end_time - wrapper_start_time:.6f}s"
+        )
+        return result
     elif mode == "openai":
         return compute_embeddings_openai(
             texts,
@@ -200,6 +284,7 @@ def compute_embeddings_sentence_transformers(
         is_build: Whether this is a build operation (shows progress bar)
         adaptive_optimization: Whether to use adaptive optimization based on batch size
     """
+    outer_start_time = time.time()
     # Handle empty input
     if not texts:
         raise ValueError("Cannot compute embeddings for empty text list")
@@ -230,7 +315,14 @@ def compute_embeddings_sentence_transformers(
     # Create cache key
     cache_key = f"sentence_transformers_{model_name}_{device}_{use_fp16}_optimized"

+    pre_model_init_end_time = time.time()
+    logger.debug(
+        "compute_embeddings_sentence_transformers pre-model-init time "
+        f"(device/batch selection etc.): {pre_model_init_end_time - outer_start_time:.6f}s"
+    )
+
     # Check if model is already cached
+    start_time = time.time()
     if cache_key in _model_cache:
         logger.info(f"Using cached optimized model: {model_name}")
         model = _model_cache[cache_key]
@@ -370,10 +462,13 @@ def compute_embeddings_sentence_transformers(
         _model_cache[cache_key] = model
         logger.info(f"Model cached: {cache_key}")

+    end_time = time.time()
+
     # Compute embeddings with optimized inference mode
     logger.info(
         f"Starting embedding computation... (batch_size: {batch_size}, manual_tokenize={manual_tokenize})"
     )
+    logger.info(f"start sentence transformers {model} takes {end_time - start_time}")
+
     start_time = time.time()
     if not manual_tokenize:
@@ -394,32 +489,46 @@ def compute_embeddings_sentence_transformers(
             except Exception:
                 pass
     else:
-        # Manual tokenization + forward pass using HF AutoTokenizer/AutoModel
+        # Manual tokenization + forward pass using HF AutoTokenizer/AutoModel.
+        # This path is reserved for an aggressively optimized FP pipeline
+        # (no quantization), mainly for experimentation.
         try:
             from transformers import AutoModel, AutoTokenizer  # type: ignore
         except Exception as e:
             raise ImportError(f"transformers is required for manual_tokenize=True: {e}")

-        # Cache tokenizer and model
         tok_cache_key = f"hf_tokenizer_{model_name}"
-        mdl_cache_key = f"hf_model_{model_name}_{device}_{use_fp16}"
+        mdl_cache_key = f"hf_model_{model_name}_{device}_{use_fp16}_fp"

         if tok_cache_key in _model_cache and mdl_cache_key in _model_cache:
             hf_tokenizer = _model_cache[tok_cache_key]
             hf_model = _model_cache[mdl_cache_key]
-            logger.info("Using cached HF tokenizer/model for manual path")
+            logger.info("Using cached HF tokenizer/model for manual FP path")
         else:
-            logger.info("Loading HF tokenizer/model for manual tokenization path")
+            logger.info("Loading HF tokenizer/model for manual FP path")
             hf_tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)

             torch_dtype = torch.float16 if (use_fp16 and device == "cuda") else torch.float32
-            hf_model = AutoModel.from_pretrained(model_name, torch_dtype=torch_dtype)
+            hf_model = AutoModel.from_pretrained(
+                model_name,
+                torch_dtype=torch_dtype,
+            )
             hf_model.to(device)
             hf_model.eval()
             # Optional compile on supported devices
             if device in ["cuda", "mps"]:
                 try:
-                    hf_model = torch.compile(hf_model, mode="reduce-overhead", dynamic=True)  # type: ignore
-                except Exception:
-                    pass
+                    hf_model = torch.compile(  # type: ignore
+                        hf_model, mode="reduce-overhead", dynamic=True
+                    )
+                    logger.info(
+                        f"Applied torch.compile to HF model for {model_name} "
+                        f"(device={device}, dtype={torch_dtype})"
+                    )
+                except Exception as exc:
+                    logger.warning(f"torch.compile optimization failed: {exc}")

             _model_cache[tok_cache_key] = hf_tokenizer
             _model_cache[mdl_cache_key] = hf_model

@@ -445,7 +554,6 @@ def compute_embeddings_sentence_transformers(
         for start_index in batch_iter:
             end_index = min(start_index + batch_size, len(texts))
             batch_texts = texts[start_index:end_index]
-            tokenize_start_time = time.time()
             inputs = hf_tokenizer(
                 batch_texts,
                 padding=True,
@@ -453,34 +561,17 @@ def compute_embeddings_sentence_transformers(
                 max_length=max_length,
                 return_tensors="pt",
             )
-            tokenize_end_time = time.time()
-            logger.info(
-                f"Tokenize time taken: {tokenize_end_time - tokenize_start_time} seconds"
-            )
-            # Print shapes of all input tensors for debugging
-            for k, v in inputs.items():
-                print(f"inputs[{k!r}] shape: {getattr(v, 'shape', type(v))}")
-            to_device_start_time = time.time()
             inputs = {k: v.to(device) for k, v in inputs.items()}
-            to_device_end_time = time.time()
-            logger.info(
-                f"To device time taken: {to_device_end_time - to_device_start_time} seconds"
-            )
-            forward_start_time = time.time()
             outputs = hf_model(**inputs)
-            forward_end_time = time.time()
-            logger.info(f"Forward time taken: {forward_end_time - forward_start_time} seconds")
             last_hidden_state = outputs.last_hidden_state  # (B, L, H)
             attention_mask = inputs.get("attention_mask")
             if attention_mask is None:
-                # Fallback: assume all tokens are valid
                 pooled = last_hidden_state.mean(dim=1)
             else:
                 mask = attention_mask.unsqueeze(-1).to(last_hidden_state.dtype)
                 masked = last_hidden_state * mask
                 lengths = mask.sum(dim=1).clamp(min=1)
                 pooled = masked.sum(dim=1) / lengths
-            # Move to CPU float32
             batch_embeddings = pooled.detach().to("cpu").float().numpy()
             all_embeddings.append(batch_embeddings)

@@ -500,6 +591,12 @@ def compute_embeddings_sentence_transformers(
     if np.isnan(embeddings).any() or np.isinf(embeddings).any():
         raise RuntimeError(f"Detected NaN or Inf values in embeddings, model: {model_name}")

+    outer_end_time = time.time()
+    logger.debug(
+        "compute_embeddings_sentence_transformers total time "
+        f"(function entry -> return): {outer_end_time - outer_start_time:.6f}s"
+    )
+
     return embeddings


@@ -814,15 +911,13 @@ def compute_embeddings_ollama(

     logger.info(f"Using batch size: {batch_size} for true batch processing")

-    # Get model token limit and apply truncation
-    token_limit = get_model_token_limit(model_name)
+    # Get model token limit and apply truncation before batching
+    token_limit = get_model_token_limit(model_name, base_url=resolved_host)
     logger.info(f"Model '{model_name}' token limit: {token_limit}")

-    # Apply token-aware truncation to all texts
-    truncated_texts = truncate_to_token_limit(texts, token_limit)
-    if len(truncated_texts) != len(texts):
-        logger.error("Truncation failed - text count mismatch")
-        truncated_texts = texts  # Fallback to original texts
+    # Apply truncation to all texts before batch processing
+    # Function logs truncation details internally
+    texts = truncate_to_token_limit(texts, token_limit)

     def get_batch_embeddings(batch_texts):
         """Get embeddings for a batch of texts using /api/embed endpoint."""
@@ -880,12 +975,12 @@ def compute_embeddings_ollama(

             return None, list(range(len(batch_texts)))

-    # Process truncated texts in batches
+    # Process texts in batches
     all_embeddings = []
     all_failed_indices = []

     # Setup progress bar if needed
-    show_progress = is_build or len(truncated_texts) > 10
+    show_progress = is_build or len(texts) > 10
     try:
         if show_progress:
             from tqdm import tqdm
@@ -893,7 +988,7 @@ def compute_embeddings_ollama(
             show_progress = False

     # Process batches
-    num_batches = (len(truncated_texts) + batch_size - 1) // batch_size
+    num_batches = (len(texts) + batch_size - 1) // batch_size

     if show_progress:
         batch_iterator = tqdm(range(num_batches), desc="Computing Ollama embeddings (batched)")
@@ -902,8 +997,8 @@ def compute_embeddings_ollama(

     for batch_idx in batch_iterator:
         start_idx = batch_idx * batch_size
-        end_idx = min(start_idx + batch_size, len(truncated_texts))
-        batch_texts = truncated_texts[start_idx:end_idx]
+        end_idx = min(start_idx + batch_size, len(texts))
+        batch_texts = texts[start_idx:end_idx]

         batch_embeddings, batch_failed = get_batch_embeddings(batch_texts)

@@ -918,11 +1013,11 @@ def compute_embeddings_ollama(

     # Handle failed embeddings
     if all_failed_indices:
-        if len(all_failed_indices) == len(truncated_texts):
+        if len(all_failed_indices) == len(texts):
             raise RuntimeError("Failed to compute any embeddings")

         logger.warning(
-            f"Failed to compute embeddings for {len(all_failed_indices)}/{len(truncated_texts)} texts"
+            f"Failed to compute embeddings for {len(all_failed_indices)}/{len(texts)} texts"
         )

         # Use zero embeddings as fallback for failed ones
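With these changes, truncation responsibility moves from chunking time to embedding time: `get_model_token_limit` first tries Ollama's `/api/show` when the base URL looks like an Ollama server, then falls back to the registry or the default, and `truncate_to_token_limit` trims once per batch with internal logging. A small sketch of that flow, assuming the module is importable as `leann.embedding_compute` and an Ollama server on the default port (both assumptions, not shown in this diff):

```python
from leann.embedding_compute import get_model_token_limit, truncate_to_token_limit

texts = ["short text", "a very long passage " * 2000]

# Dynamic discovery via /api/show for Ollama-looking URLs, registry/default otherwise.
limit = get_model_token_limit("nomic-embed-text", base_url="http://localhost:11434")

# One pass of truncation; per-text warnings and a summary are logged internally.
safe_texts = truncate_to_token_limit(texts, limit)
assert len(safe_texts) == len(texts)
```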
@@ -57,6 +57,8 @@ dependencies = [
     "tree-sitter-c-sharp>=0.20.0",
     "tree-sitter-typescript>=0.20.0",
     "torchvision>=0.23.0",
+    "einops",
+    "seaborn",
 ]

 [project.optional-dependencies]
@@ -8,7 +8,7 @@ import subprocess
|
|||||||
import sys
|
import sys
|
||||||
import tempfile
|
import tempfile
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from unittest.mock import patch
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
@@ -116,8 +116,10 @@ class TestChunkingFunctions:
|
|||||||
chunks = create_traditional_chunks(docs, chunk_size=50, chunk_overlap=10)
|
chunks = create_traditional_chunks(docs, chunk_size=50, chunk_overlap=10)
|
||||||
|
|
||||||
assert len(chunks) > 0
|
assert len(chunks) > 0
|
||||||
assert all(isinstance(chunk, str) for chunk in chunks)
|
# Traditional chunks now return dict format for consistency
|
||||||
assert all(len(chunk.strip()) > 0 for chunk in chunks)
|
assert all(isinstance(chunk, dict) for chunk in chunks)
|
||||||
|
assert all("text" in chunk and "metadata" in chunk for chunk in chunks)
|
||||||
|
assert all(len(chunk["text"].strip()) > 0 for chunk in chunks)
|
||||||
|
|
||||||
def test_create_traditional_chunks_empty_docs(self):
|
def test_create_traditional_chunks_empty_docs(self):
|
||||||
"""Test traditional chunking with empty documents."""
|
"""Test traditional chunking with empty documents."""
|
||||||
@@ -158,11 +160,22 @@ class Calculator:
|
|||||||
|
|
||||||
# Should have multiple chunks due to different functions/classes
|
# Should have multiple chunks due to different functions/classes
|
||||||
assert len(chunks) > 0
|
assert len(chunks) > 0
|
||||||
assert all(isinstance(chunk, str) for chunk in chunks)
|
# R3: Expect dict format with "text" and "metadata" keys
|
||||||
assert all(len(chunk.strip()) > 0 for chunk in chunks)
|
assert all(isinstance(chunk, dict) for chunk in chunks), "All chunks should be dicts"
|
||||||
|
assert all("text" in chunk and "metadata" in chunk for chunk in chunks), (
|
||||||
|
"Each chunk should have 'text' and 'metadata' keys"
|
||||||
|
)
|
||||||
|
assert all(len(chunk["text"].strip()) > 0 for chunk in chunks), (
|
||||||
|
"Each chunk text should be non-empty"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check metadata is present
|
||||||
|
assert all("file_path" in chunk["metadata"] for chunk in chunks), (
|
||||||
|
"Each chunk should have file_path metadata"
|
||||||
|
)
|
||||||
|
|
||||||
# Check that code structure is somewhat preserved
|
# Check that code structure is somewhat preserved
|
||||||
combined_content = " ".join(chunks)
|
combined_content = " ".join([c["text"] for c in chunks])
|
||||||
assert "def hello_world" in combined_content
|
assert "def hello_world" in combined_content
|
||||||
assert "class Calculator" in combined_content
|
assert "class Calculator" in combined_content
|
||||||
|
|
||||||
@@ -194,7 +207,11 @@ class Calculator:
|
|||||||
chunks = create_text_chunks(docs, use_ast_chunking=False, chunk_size=50, chunk_overlap=10)
|
chunks = create_text_chunks(docs, use_ast_chunking=False, chunk_size=50, chunk_overlap=10)
|
||||||
|
|
||||||
assert len(chunks) > 0
|
assert len(chunks) > 0
|
||||||
assert all(isinstance(chunk, str) for chunk in chunks)
|
# R3: Traditional chunking should also return dict format for consistency
|
||||||
|
assert all(isinstance(chunk, dict) for chunk in chunks), "All chunks should be dicts"
|
||||||
|
assert all("text" in chunk and "metadata" in chunk for chunk in chunks), (
|
||||||
|
"Each chunk should have 'text' and 'metadata' keys"
|
||||||
|
)
|
||||||
|
|
||||||
def test_create_text_chunks_ast_mode(self):
|
def test_create_text_chunks_ast_mode(self):
|
||||||
"""Test text chunking in AST mode."""
|
"""Test text chunking in AST mode."""
|
||||||
@@ -213,7 +230,11 @@ class Calculator:
|
|||||||
)
|
)
|
||||||
|
|
||||||
assert len(chunks) > 0
|
assert len(chunks) > 0
|
||||||
assert all(isinstance(chunk, str) for chunk in chunks)
|
# R3: AST mode should also return dict format
|
||||||
|
assert all(isinstance(chunk, dict) for chunk in chunks), "All chunks should be dicts"
|
||||||
|
assert all("text" in chunk and "metadata" in chunk for chunk in chunks), (
|
||||||
|
"Each chunk should have 'text' and 'metadata' keys"
|
||||||
|
)
|
||||||
|
|
||||||
def test_create_text_chunks_custom_extensions(self):
|
def test_create_text_chunks_custom_extensions(self):
|
||||||
"""Test text chunking with custom code file extensions."""
|
"""Test text chunking with custom code file extensions."""
|
||||||
@@ -353,6 +374,552 @@ class MathUtils:
|
|||||||
pytest.skip("Test timed out - likely due to model download in CI")
|
pytest.skip("Test timed out - likely due to model download in CI")
|
||||||
|
|
||||||
|
|
||||||
|
class TestASTContentExtraction:
|
||||||
|
"""Test AST content extraction bug fix.
|
||||||
|
|
||||||
|
These tests verify that astchunk's dict format with 'content' key is handled correctly,
|
||||||
|
and that the extraction logic doesn't fall through to stringifying entire dicts.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def test_extract_content_from_astchunk_dict(self):
|
||||||
|
"""Test that astchunk dict format with 'content' key is handled correctly.
|
||||||
|
|
||||||
|
Bug: Current code checks for chunk["text"] but astchunk returns chunk["content"].
|
||||||
|
This causes fallthrough to str(chunk), stringifying the entire dict.
|
||||||
|
|
||||||
|
This test will FAIL until the bug is fixed because:
|
||||||
|
- Current code will stringify the dict: "{'content': '...', 'metadata': {...}}"
|
||||||
|
- Fixed code should extract just the content value
|
||||||
|
"""
|
||||||
|
# Mock the ASTChunkBuilder class
|
||||||
|
mock_builder = Mock()
|
||||||
|
|
||||||
|
# Astchunk returns this format
|
||||||
|
astchunk_format_chunk = {
|
||||||
|
"content": "def hello():\n print('world')",
|
||||||
|
"metadata": {
|
||||||
|
"filepath": "test.py",
|
||||||
|
"line_count": 2,
|
||||||
|
"start_line_no": 0,
|
||||||
|
"end_line_no": 1,
|
||||||
|
"node_count": 1,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
mock_builder.chunkify.return_value = [astchunk_format_chunk]
|
||||||
|
|
||||||
|
# Create mock document
|
||||||
|
doc = MockDocument(
|
||||||
|
"def hello():\n print('world')", "/test/test.py", {"language": "python"}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Mock the astchunk module and its ASTChunkBuilder class
|
||||||
|
mock_astchunk = Mock()
|
||||||
|
mock_astchunk.ASTChunkBuilder = Mock(return_value=mock_builder)
|
||||||
|
|
||||||
|
# Patch sys.modules to inject our mock before the import
|
||||||
|
with patch.dict("sys.modules", {"astchunk": mock_astchunk}):
|
||||||
|
# Call create_ast_chunks
|
||||||
|
chunks = create_ast_chunks([doc])
|
||||||
|
|
||||||
|
# R3: Should return dict format with proper metadata
|
||||||
|
assert len(chunks) > 0, "Should return at least one chunk"
|
||||||
|
|
||||||
|
# R3: Each chunk should be a dict
|
||||||
|
chunk = chunks[0]
|
||||||
|
assert isinstance(chunk, dict), "Chunk should be a dict"
|
||||||
|
assert "text" in chunk, "Chunk should have 'text' key"
|
||||||
|
assert "metadata" in chunk, "Chunk should have 'metadata' key"
|
||||||
|
|
||||||
|
chunk_text = chunk["text"]
|
||||||
|
|
||||||
|
# CRITICAL: Should NOT contain stringified dict markers in the text field
|
||||||
|
# These assertions will FAIL with current buggy code
|
||||||
|
assert "'content':" not in chunk_text, (
|
||||||
|
f"Chunk text contains stringified dict - extraction failed! Got: {chunk_text[:100]}..."
|
||||||
|
)
|
||||||
|
assert "'metadata':" not in chunk_text, (
|
||||||
|
"Chunk text contains stringified metadata - extraction failed! "
|
||||||
|
f"Got: {chunk_text[:100]}..."
|
||||||
|
)
|
||||||
|
assert "{" not in chunk_text or "def hello" in chunk_text.split("{")[0], (
|
||||||
|
"Chunk text appears to be a stringified dict"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should contain actual content
|
||||||
|
assert "def hello()" in chunk_text, "Should extract actual code content"
|
||||||
|
assert "print('world')" in chunk_text, "Should extract complete code content"
|
||||||
|
|
||||||
|
# R3: Should preserve astchunk metadata
|
||||||
|
assert "filepath" in chunk["metadata"] or "file_path" in chunk["metadata"], (
|
||||||
|
"Should preserve file path metadata"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_extract_text_key_fallback(self):
|
||||||
|
"""Test that 'text' key still works for backward compatibility.
|
||||||
|
|
||||||
|
Some chunks might use 'text' instead of 'content' - ensure backward compatibility.
|
||||||
|
This test should PASS even with current code.
|
||||||
|
"""
|
||||||
|
mock_builder = Mock()
|
||||||
|
|
||||||
|
# Some chunks might use "text" key
|
||||||
|
text_key_chunk = {"text": "def legacy_function():\n return True"}
|
||||||
|
mock_builder.chunkify.return_value = [text_key_chunk]
|
||||||
|
|
||||||
|
# Create mock document
|
||||||
|
doc = MockDocument(
|
||||||
|
"def legacy_function():\n return True", "/test/legacy.py", {"language": "python"}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Mock the astchunk module
|
||||||
|
mock_astchunk = Mock()
|
||||||
|
mock_astchunk.ASTChunkBuilder = Mock(return_value=mock_builder)
|
||||||
|
|
||||||
|
with patch.dict("sys.modules", {"astchunk": mock_astchunk}):
|
||||||
|
# Call create_ast_chunks
|
||||||
|
chunks = create_ast_chunks([doc])
|
||||||
|
|
||||||
|
# R3: Should extract text correctly as dict format
|
||||||
|
assert len(chunks) > 0
|
||||||
|
chunk = chunks[0]
|
||||||
|
assert isinstance(chunk, dict), "Chunk should be a dict"
|
||||||
|
assert "text" in chunk, "Chunk should have 'text' key"
|
||||||
|
|
||||||
|
chunk_text = chunk["text"]
|
||||||
|
|
||||||
|
# Should NOT be stringified
|
||||||
|
assert "'text':" not in chunk_text, "Should not stringify dict with 'text' key"
|
||||||
|
|
||||||
|
# Should contain actual content
|
||||||
|
assert "def legacy_function()" in chunk_text
|
||||||
|
assert "return True" in chunk_text
|
||||||
|
|
||||||
|
    def test_handles_string_chunks(self):
        """Test that plain string chunks still work.

        Some chunkers might return plain strings - verify these are preserved.
        This test should PASS with current code.
        """
        mock_builder = Mock()

        # Plain string chunk
        plain_string_chunk = "def simple_function():\n pass"
        mock_builder.chunkify.return_value = [plain_string_chunk]

        # Create mock document
        doc = MockDocument(
            "def simple_function():\n pass", "/test/simple.py", {"language": "python"}
        )

        # Mock the astchunk module
        mock_astchunk = Mock()
        mock_astchunk.ASTChunkBuilder = Mock(return_value=mock_builder)

        with patch.dict("sys.modules", {"astchunk": mock_astchunk}):
            # Call create_ast_chunks
            chunks = create_ast_chunks([doc])

        # R3: Should wrap string in dict format
        assert len(chunks) > 0
        chunk = chunks[0]
        assert isinstance(chunk, dict), "Even string chunks should be wrapped in dict"
        assert "text" in chunk, "Chunk should have 'text' key"

        chunk_text = chunk["text"]

        assert chunk_text == plain_string_chunk.strip(), (
            "Should preserve plain string chunk content"
        )
        assert "def simple_function()" in chunk_text
        assert "pass" in chunk_text

    def test_multiple_chunks_with_mixed_formats(self):
        """Test handling of multiple chunks with different formats.

        Real-world scenario: astchunk might return a mix of formats.
        This test will FAIL if any chunk with 'content' key gets stringified.
        """
        mock_builder = Mock()

        # Mix of formats
        mixed_chunks = [
            {"content": "def first():\n return 1", "metadata": {"line_count": 2}},
            "def second():\n return 2",  # Plain string
            {"text": "def third():\n return 3"},  # Old format
            {"content": "class MyClass:\n pass", "metadata": {"node_count": 1}},
        ]
        mock_builder.chunkify.return_value = mixed_chunks

        # Create mock document
        code = "def first():\n return 1\n\ndef second():\n return 2\n\ndef third():\n return 3\n\nclass MyClass:\n pass"
        doc = MockDocument(code, "/test/mixed.py", {"language": "python"})

        # Mock the astchunk module
        mock_astchunk = Mock()
        mock_astchunk.ASTChunkBuilder = Mock(return_value=mock_builder)

        with patch.dict("sys.modules", {"astchunk": mock_astchunk}):
            # Call create_ast_chunks
            chunks = create_ast_chunks([doc])

        # R3: Should extract all chunks correctly as dicts
        assert len(chunks) == 4, "Should extract all 4 chunks"

        # Check each chunk
        for i, chunk in enumerate(chunks):
            assert isinstance(chunk, dict), f"Chunk {i} should be a dict"
            assert "text" in chunk, f"Chunk {i} should have 'text' key"
            assert "metadata" in chunk, f"Chunk {i} should have 'metadata' key"

            chunk_text = chunk["text"]
            # None should be stringified dicts
            assert "'content':" not in chunk_text, f"Chunk {i} text is stringified (has 'content':)"
            assert "'metadata':" not in chunk_text, (
                f"Chunk {i} text is stringified (has 'metadata':)"
            )
            assert "'text':" not in chunk_text, f"Chunk {i} text is stringified (has 'text':)"

        # Verify actual content is present
        combined = "\n".join([c["text"] for c in chunks])
        assert "def first()" in combined
        assert "def second()" in combined
        assert "def third()" in combined
        assert "class MyClass:" in combined

    def test_empty_content_value_handling(self):
        """Test handling of chunks with empty content values.

        Edge case: chunk has 'content' key but value is empty.
        Should skip these chunks, not stringify them.
        """
        mock_builder = Mock()

        chunks_with_empty = [
            {"content": "", "metadata": {"line_count": 0}},  # Empty content
            {"content": " ", "metadata": {"line_count": 1}},  # Whitespace only
            {"content": "def valid():\n return True", "metadata": {"line_count": 2}},  # Valid
        ]
        mock_builder.chunkify.return_value = chunks_with_empty

        doc = MockDocument(
            "def valid():\n return True", "/test/empty.py", {"language": "python"}
        )

        # Mock the astchunk module
        mock_astchunk = Mock()
        mock_astchunk.ASTChunkBuilder = Mock(return_value=mock_builder)

        with patch.dict("sys.modules", {"astchunk": mock_astchunk}):
            chunks = create_ast_chunks([doc])

        # R3: Should only have the valid chunk (empty ones filtered out)
        assert len(chunks) == 1, "Should filter out empty content chunks"

        chunk = chunks[0]
        assert isinstance(chunk, dict), "Chunk should be a dict"
        assert "text" in chunk, "Chunk should have 'text' key"
        assert "def valid()" in chunk["text"]

        # Should not have stringified the empty dict
        assert "'content': ''" not in chunk["text"]

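
# The extraction behaviour pinned down by the tests above can be summarised in a
# small helper. This is an illustrative sketch only (the name _chunk_to_dict_sketch
# is hypothetical and not part of leann.chunking_utils): it shows one way to turn
# astchunk's mixed output ("content" dicts, "text" dicts, plain strings) into the
# {"text", "metadata"} shape asserted above, skipping empty or whitespace-only chunks.
def _chunk_to_dict_sketch(raw_chunk, base_metadata):
    if isinstance(raw_chunk, str):
        text, extra_metadata = raw_chunk, {}
    elif isinstance(raw_chunk, dict):
        # Prefer the "content" key, fall back to "text" for older formats
        text = raw_chunk.get("content") or raw_chunk.get("text") or ""
        extra_metadata = raw_chunk.get("metadata") or {}
    else:
        return None
    text = text.strip()
    if not text:
        return None  # drop empty chunks instead of stringifying the dict
    return {"text": text, "metadata": {**base_metadata, **extra_metadata}}
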
class TestASTMetadataPreservation:
    """Test metadata preservation in AST chunk dictionaries.

    R3: These tests define the contract for metadata preservation when returning
    chunk dictionaries instead of plain strings. Each chunk dict should have:
    - "text": str - the actual chunk content
    - "metadata": dict - all metadata from document AND astchunk

    These tests will FAIL until G3 implementation changes return type to list[dict].
    """

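    # Illustrative (hypothetical) shape of a single chunk dict under this contract,
    # with made-up values - not output captured from the real implementation:
    #
    # {
    #     "text": "def calculate_sum(numbers):\n    return sum(numbers)",
    #     "metadata": {
    #         "file_path": "/project/src/utils.py",  # from the document
    #         "file_name": "utils.py",               # from the document
    #         "line_count": 2,                       # from astchunk
    #         "start_line_no": 1,                    # from astchunk
    #     },
    # }
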
    def test_ast_chunks_preserve_file_metadata(self):
        """Test that document metadata is preserved in chunk metadata.

        This test verifies that all document-level metadata (file_path, file_name,
        creation_date, last_modified_date) is included in each chunk's metadata dict.

        This will FAIL because current code returns list[str], not list[dict].
        """
        # Create mock document with rich metadata
        python_code = '''
def calculate_sum(numbers):
    """Calculate sum of numbers."""
    return sum(numbers)

class DataProcessor:
    """Process data records."""

    def process(self, data):
        return [x * 2 for x in data]
'''
        doc = MockDocument(
            python_code,
            file_path="/project/src/utils.py",
            metadata={
                "language": "python",
                "file_path": "/project/src/utils.py",
                "file_name": "utils.py",
                "creation_date": "2024-01-15T10:30:00",
                "last_modified_date": "2024-10-31T15:45:00",
            },
        )

        # Mock astchunk to return chunks with metadata
        mock_builder = Mock()
        astchunk_chunks = [
            {
                "content": "def calculate_sum(numbers):\n return sum(numbers)",
                "metadata": {
                    "filepath": "/project/src/utils.py",
                    "line_count": 2,
                    "start_line_no": 1,
                    "end_line_no": 2,
                    "node_count": 1,
                },
            },
            {
                "content": "class DataProcessor:\n def process(self, data):\n return [x * 2 for x in data]",
                "metadata": {
                    "filepath": "/project/src/utils.py",
                    "line_count": 3,
                    "start_line_no": 5,
                    "end_line_no": 7,
                    "node_count": 2,
                },
            },
        ]
        mock_builder.chunkify.return_value = astchunk_chunks

        mock_astchunk = Mock()
        mock_astchunk.ASTChunkBuilder = Mock(return_value=mock_builder)

        with patch.dict("sys.modules", {"astchunk": mock_astchunk}):
            chunks = create_ast_chunks([doc])

        # CRITICAL: These assertions will FAIL with current list[str] return type
        assert len(chunks) == 2, "Should return 2 chunks"

        for i, chunk in enumerate(chunks):
            # Structure assertions - WILL FAIL: current code returns strings
            assert isinstance(chunk, dict), f"Chunk {i} should be dict, got {type(chunk)}"
            assert "text" in chunk, f"Chunk {i} must have 'text' key"
            assert "metadata" in chunk, f"Chunk {i} must have 'metadata' key"
            assert isinstance(chunk["metadata"], dict), f"Chunk {i} metadata should be dict"

            # Document metadata preservation - WILL FAIL
            metadata = chunk["metadata"]
            assert "file_path" in metadata, f"Chunk {i} should preserve file_path"
            assert metadata["file_path"] == "/project/src/utils.py", (
                f"Chunk {i} file_path incorrect"
            )

            assert "file_name" in metadata, f"Chunk {i} should preserve file_name"
            assert metadata["file_name"] == "utils.py", f"Chunk {i} file_name incorrect"

            assert "creation_date" in metadata, f"Chunk {i} should preserve creation_date"
            assert metadata["creation_date"] == "2024-01-15T10:30:00", (
                f"Chunk {i} creation_date incorrect"
            )

            assert "last_modified_date" in metadata, f"Chunk {i} should preserve last_modified_date"
            assert metadata["last_modified_date"] == "2024-10-31T15:45:00", (
                f"Chunk {i} last_modified_date incorrect"
            )

        # Verify metadata is consistent across chunks from same document
        assert chunks[0]["metadata"]["file_path"] == chunks[1]["metadata"]["file_path"], (
            "All chunks from same document should have same file_path"
        )

        # Verify text content is present and not stringified
        assert "def calculate_sum" in chunks[0]["text"]
        assert "class DataProcessor" in chunks[1]["text"]

    def test_ast_chunks_include_astchunk_metadata(self):
        """Test that astchunk-specific metadata is merged into chunk metadata.

        This test verifies that astchunk's metadata (line_count, start_line_no,
        end_line_no, node_count) is merged with document metadata.

        This will FAIL because current code returns list[str], not list[dict].
        """
        python_code = '''
def function_one():
    """First function."""
    x = 1
    y = 2
    return x + y

def function_two():
    """Second function."""
    return 42
'''
        doc = MockDocument(
            python_code,
            file_path="/test/code.py",
            metadata={
                "language": "python",
                "file_path": "/test/code.py",
                "file_name": "code.py",
            },
        )

        # Mock astchunk with detailed metadata
        mock_builder = Mock()
        astchunk_chunks = [
            {
                "content": "def function_one():\n x = 1\n y = 2\n return x + y",
                "metadata": {
                    "filepath": "/test/code.py",
                    "line_count": 4,
                    "start_line_no": 1,
                    "end_line_no": 4,
                    "node_count": 5,  # function, assignments, return
                },
            },
            {
                "content": "def function_two():\n return 42",
                "metadata": {
                    "filepath": "/test/code.py",
                    "line_count": 2,
                    "start_line_no": 7,
                    "end_line_no": 8,
                    "node_count": 2,  # function, return
                },
            },
        ]
        mock_builder.chunkify.return_value = astchunk_chunks

        mock_astchunk = Mock()
        mock_astchunk.ASTChunkBuilder = Mock(return_value=mock_builder)

        with patch.dict("sys.modules", {"astchunk": mock_astchunk}):
            chunks = create_ast_chunks([doc])

        # CRITICAL: These will FAIL with current list[str] return
        assert len(chunks) == 2

        # First chunk - function_one
        chunk1 = chunks[0]
        assert isinstance(chunk1, dict), "Chunk should be dict"
        assert "metadata" in chunk1

        metadata1 = chunk1["metadata"]

        # Check astchunk metadata is present
        assert "line_count" in metadata1, "Should include astchunk line_count"
        assert metadata1["line_count"] == 4, "line_count should be 4"

        assert "start_line_no" in metadata1, "Should include astchunk start_line_no"
        assert metadata1["start_line_no"] == 1, "start_line_no should be 1"

        assert "end_line_no" in metadata1, "Should include astchunk end_line_no"
        assert metadata1["end_line_no"] == 4, "end_line_no should be 4"

        assert "node_count" in metadata1, "Should include astchunk node_count"
        assert metadata1["node_count"] == 5, "node_count should be 5"

        # Second chunk - function_two
        chunk2 = chunks[1]
        metadata2 = chunk2["metadata"]

        assert metadata2["line_count"] == 2, "line_count should be 2"
        assert metadata2["start_line_no"] == 7, "start_line_no should be 7"
        assert metadata2["end_line_no"] == 8, "end_line_no should be 8"
        assert metadata2["node_count"] == 2, "node_count should be 2"

        # Verify document metadata is ALSO present (merged, not replaced)
        assert metadata1["file_path"] == "/test/code.py"
        assert metadata1["file_name"] == "code.py"
        assert metadata2["file_path"] == "/test/code.py"
        assert metadata2["file_name"] == "code.py"

        # Verify text content is correct
        assert "def function_one" in chunk1["text"]
        assert "def function_two" in chunk2["text"]

    def test_traditional_chunks_as_dicts_helper(self):
        """Test the helper function that wraps traditional chunks as dicts.

        This test verifies that when create_traditional_chunks is called,
        its plain string chunks are wrapped into dict format with metadata.

        This will FAIL because the helper function _traditional_chunks_as_dicts()
        doesn't exist yet, and create_traditional_chunks returns list[str].
        """
        # Create documents with various metadata
        docs = [
            MockDocument(
                "This is the first paragraph of text. It contains multiple sentences. "
                "This should be split into chunks based on size.",
                file_path="/docs/readme.txt",
                metadata={
                    "file_path": "/docs/readme.txt",
                    "file_name": "readme.txt",
                    "creation_date": "2024-01-01",
                },
            ),
            MockDocument(
                "Second document with different metadata. It also has content that needs chunking.",
                file_path="/docs/guide.md",
                metadata={
                    "file_path": "/docs/guide.md",
                    "file_name": "guide.md",
                    "last_modified_date": "2024-10-31",
                },
            ),
        ]

        # Call create_traditional_chunks (which should now return list[dict])
        chunks = create_traditional_chunks(docs, chunk_size=50, chunk_overlap=10)

        # CRITICAL: Will FAIL - current code returns list[str]
        assert len(chunks) > 0, "Should return chunks"

        for i, chunk in enumerate(chunks):
            # Structure assertions - WILL FAIL
            assert isinstance(chunk, dict), f"Chunk {i} should be dict, got {type(chunk)}"
            assert "text" in chunk, f"Chunk {i} must have 'text' key"
            assert "metadata" in chunk, f"Chunk {i} must have 'metadata' key"

            # Text should be non-empty
            assert len(chunk["text"].strip()) > 0, f"Chunk {i} text should be non-empty"

            # Metadata should include document info
            metadata = chunk["metadata"]
            assert "file_path" in metadata, f"Chunk {i} should have file_path in metadata"
            assert "file_name" in metadata, f"Chunk {i} should have file_name in metadata"

        # Verify metadata tracking works correctly
        # At least one chunk should be from readme.txt
        readme_chunks = [c for c in chunks if "readme.txt" in c["metadata"]["file_name"]]
        assert len(readme_chunks) > 0, "Should have chunks from readme.txt"

        # At least one chunk should be from guide.md
        guide_chunks = [c for c in chunks if "guide.md" in c["metadata"]["file_name"]]
        assert len(guide_chunks) > 0, "Should have chunks from guide.md"

        # Verify creation_date is preserved for readme chunks
        for chunk in readme_chunks:
            assert chunk["metadata"].get("creation_date") == "2024-01-01", (
                "readme.txt chunks should preserve creation_date"
            )

        # Verify last_modified_date is preserved for guide chunks
        for chunk in guide_chunks:
            assert chunk["metadata"].get("last_modified_date") == "2024-10-31", (
                "guide.md chunks should preserve last_modified_date"
            )

        # Verify text content is present
        all_text = " ".join([c["text"] for c in chunks])
        assert "first paragraph" in all_text
        assert "Second document" in all_text

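
# For reference, the helper this test targets could look roughly like the sketch
# below. It is an assumption about the eventual _traditional_chunks_as_dicts
# implementation, not the actual leann.chunking_utils code: each plain-string
# chunk produced for a document is wrapped into the {"text", "metadata"} contract
# asserted above, with the owning document's metadata copied onto every chunk.
def _traditional_chunks_as_dicts_sketch(doc_metadata, text_chunks):
    wrapped = []
    for text in text_chunks:
        stripped = text.strip()
        if not stripped:
            continue  # keep only non-empty chunks
        wrapped.append({"text": stripped, "metadata": dict(doc_metadata)})
    return wrapped
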
class TestErrorHandling:
    """Test error handling and edge cases."""

tests/test_token_truncation.py (new file, 268 lines)
@@ -0,0 +1,268 @@
"""Unit tests for token-aware truncation functionality.

This test suite defines the contract for token truncation functions that prevent
500 errors from Ollama when text exceeds model token limits. These tests verify:

1. Model token limit retrieval (known and unknown models)
2. Text truncation behavior for single and multiple texts
3. Token counting and truncation accuracy using tiktoken

All tests are written in Red Phase - they should FAIL initially because the
implementation does not exist yet.
"""

import pytest
import tiktoken
from leann.embedding_compute import (
    EMBEDDING_MODEL_LIMITS,
    get_model_token_limit,
    truncate_to_token_limit,
)


class TestModelTokenLimits:
    """Tests for retrieving model-specific token limits."""

    def test_get_model_token_limit_known_model(self):
        """Verify correct token limit is returned for known models.

        Known models should return their specific token limits from
        EMBEDDING_MODEL_LIMITS dictionary.
        """
        # Test nomic-embed-text (2048 tokens)
        limit = get_model_token_limit("nomic-embed-text")
        assert limit == 2048, "nomic-embed-text should have 2048 token limit"

        # Test nomic-embed-text-v1.5 (2048 tokens)
        limit = get_model_token_limit("nomic-embed-text-v1.5")
        assert limit == 2048, "nomic-embed-text-v1.5 should have 2048 token limit"

        # Test nomic-embed-text-v2 (512 tokens)
        limit = get_model_token_limit("nomic-embed-text-v2")
        assert limit == 512, "nomic-embed-text-v2 should have 512 token limit"

        # Test OpenAI models (8192 tokens)
        limit = get_model_token_limit("text-embedding-3-small")
        assert limit == 8192, "text-embedding-3-small should have 8192 token limit"

    def test_get_model_token_limit_unknown_model(self):
        """Verify default token limit is returned for unknown models.

        Unknown models should return the default limit (2048) to allow
        operation with reasonable safety margin.
        """
        # Test with completely unknown model
        limit = get_model_token_limit("unknown-model-xyz")
        assert limit == 2048, "Unknown models should return default 2048"

        # Test with empty string
        limit = get_model_token_limit("")
        assert limit == 2048, "Empty model name should return default 2048"

    def test_get_model_token_limit_custom_default(self):
        """Verify custom default can be specified for unknown models.

        Allow callers to specify their own default token limit when
        model is not in the known models dictionary.
        """
        limit = get_model_token_limit("unknown-model", default=4096)
        assert limit == 4096, "Should return custom default for unknown models"

        # Known model should ignore custom default
        limit = get_model_token_limit("nomic-embed-text", default=4096)
        assert limit == 2048, "Known model should ignore custom default"

    def test_embedding_model_limits_dictionary_exists(self):
        """Verify EMBEDDING_MODEL_LIMITS dictionary contains expected models.

        The dictionary should be importable and contain at least the
        known nomic models with correct token limits.
        """
        assert isinstance(EMBEDDING_MODEL_LIMITS, dict), "Should be a dictionary"
        assert "nomic-embed-text" in EMBEDDING_MODEL_LIMITS, "Should contain nomic-embed-text"
        assert "nomic-embed-text-v1.5" in EMBEDDING_MODEL_LIMITS, (
            "Should contain nomic-embed-text-v1.5"
        )
        assert EMBEDDING_MODEL_LIMITS["nomic-embed-text"] == 2048
        assert EMBEDDING_MODEL_LIMITS["nomic-embed-text-v1.5"] == 2048
        assert EMBEDDING_MODEL_LIMITS["nomic-embed-text-v2"] == 512
        # OpenAI models
        assert EMBEDDING_MODEL_LIMITS["text-embedding-3-small"] == 8192

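
# The lookup contract above amounts to a dictionary get with a default. A minimal
# sketch under that assumption (sketch-only names, not the real
# leann.embedding_compute module):
EMBEDDING_MODEL_LIMITS_SKETCH = {
    "nomic-embed-text": 2048,
    "nomic-embed-text-v1.5": 2048,
    "nomic-embed-text-v2": 512,
    "text-embedding-3-small": 8192,
}


def _get_model_token_limit_sketch(model_name, default=2048):
    # Unknown or empty model names fall back to the supplied default
    return EMBEDDING_MODEL_LIMITS_SKETCH.get(model_name, default)
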
class TestTokenTruncation:
    """Tests for truncating texts to token limits."""

    @pytest.fixture
    def tokenizer(self):
        """Provide tiktoken tokenizer for token counting verification."""
        return tiktoken.get_encoding("cl100k_base")

    def test_truncate_single_text_under_limit(self, tokenizer):
        """Verify text under token limit remains unchanged.

        When text is already within the token limit, it should be
        returned unchanged with no truncation.
        """
        text = "This is a short text that is well under the token limit."
        token_count = len(tokenizer.encode(text))
        assert token_count < 100, f"Test setup: text should be short (has {token_count} tokens)"

        # Truncate with generous limit
        result = truncate_to_token_limit([text], token_limit=512)

        assert len(result) == 1, "Should return same number of texts"
        assert result[0] == text, "Text under limit should be unchanged"

    def test_truncate_single_text_over_limit(self, tokenizer):
        """Verify text over token limit is truncated correctly.

        When text exceeds the token limit, it should be truncated to
        fit within the limit while maintaining valid token boundaries.
        """
        # Create a text that definitely exceeds limit
        text = "word " * 200  # ~200 tokens (each "word " is typically 1-2 tokens)
        original_token_count = len(tokenizer.encode(text))
        assert original_token_count > 50, (
            f"Test setup: text should be long (has {original_token_count} tokens)"
        )

        # Truncate to 50 tokens
        result = truncate_to_token_limit([text], token_limit=50)

        assert len(result) == 1, "Should return same number of texts"
        assert result[0] != text, "Text over limit should be truncated"
        assert len(result[0]) < len(text), "Truncated text should be shorter"

        # Verify truncated text is within token limit
        truncated_token_count = len(tokenizer.encode(result[0]))
        assert truncated_token_count <= 50, (
            f"Truncated text should be ≤50 tokens, got {truncated_token_count}"
        )

    def test_truncate_multiple_texts_mixed_lengths(self, tokenizer):
        """Verify multiple texts with mixed lengths are handled correctly.

        When processing multiple texts:
        - Texts under limit should remain unchanged
        - Texts over limit should be truncated independently
        - Output list should maintain same order and length
        """
        texts = [
            "Short text.",  # Under limit
            "word " * 200,  # Over limit
            "Another short one.",  # Under limit
            "token " * 150,  # Over limit
        ]

        # Verify test setup
        for i, text in enumerate(texts):
            token_count = len(tokenizer.encode(text))
            if i in [1, 3]:
                assert token_count > 50, f"Text {i} should be over limit (has {token_count} tokens)"
            else:
                assert token_count < 50, (
                    f"Text {i} should be under limit (has {token_count} tokens)"
                )

        # Truncate with 50 token limit
        result = truncate_to_token_limit(texts, token_limit=50)

        assert len(result) == len(texts), "Should return same number of texts"

        # Verify each text individually
        for i, (original, truncated) in enumerate(zip(texts, result)):
            token_count = len(tokenizer.encode(truncated))
            assert token_count <= 50, f"Text {i} should be ≤50 tokens, got {token_count}"

            # Short texts should be unchanged
            if i in [0, 2]:
                assert truncated == original, f"Short text {i} should be unchanged"
            # Long texts should be truncated
            else:
                assert len(truncated) < len(original), f"Long text {i} should be truncated"

    def test_truncate_empty_list(self):
        """Verify empty input list returns empty output list.

        Edge case: empty list should return empty list without errors.
        """
        result = truncate_to_token_limit([], token_limit=512)
        assert result == [], "Empty input should return empty output"

    def test_truncate_preserves_order(self, tokenizer):
        """Verify truncation preserves original text order.

        Output list should maintain the same order as input list,
        regardless of which texts were truncated.
        """
        texts = [
            "First text " * 50,  # Will be truncated
            "Second text.",  # Won't be truncated
            "Third text " * 50,  # Will be truncated
        ]

        result = truncate_to_token_limit(texts, token_limit=20)

        assert len(result) == 3, "Should preserve list length"
        # Check that order is maintained by looking for distinctive words
        assert "First" in result[0], "First text should remain in first position"
        assert "Second" in result[1], "Second text should remain in second position"
        assert "Third" in result[2], "Third text should remain in third position"

    def test_truncate_extremely_long_text(self, tokenizer):
        """Verify extremely long texts are truncated efficiently.

        Test with text that far exceeds token limit to ensure
        truncation handles extreme cases without performance issues.
        """
        # Create very long text (simulate real-world scenario)
        text = "token " * 5000  # ~5000+ tokens
        original_token_count = len(tokenizer.encode(text))
        assert original_token_count > 1000, "Test setup: text should be very long"

        # Truncate to small limit
        result = truncate_to_token_limit([text], token_limit=100)

        assert len(result) == 1
        truncated_token_count = len(tokenizer.encode(result[0]))
        assert truncated_token_count <= 100, (
            f"Should truncate to ≤100 tokens, got {truncated_token_count}"
        )
        assert len(result[0]) < len(text) // 10, "Should significantly reduce text length"

    def test_truncate_exact_token_limit(self, tokenizer):
        """Verify text at exactly token limit is handled correctly.

        Edge case: text with exactly the token limit should either
        remain unchanged or be safely truncated by 1 token.
        """
        # Create text with approximately 50 tokens
        # We'll adjust to get exactly 50
        target_tokens = 50
        text = "word " * 50
        tokens = tokenizer.encode(text)

        # Adjust to get exactly target_tokens
        if len(tokens) > target_tokens:
            tokens = tokens[:target_tokens]
            text = tokenizer.decode(tokens)
        elif len(tokens) < target_tokens:
            # Add more words
            while len(tokenizer.encode(text)) < target_tokens:
                text += "word "
            tokens = tokenizer.encode(text)[:target_tokens]
            text = tokenizer.decode(tokens)

        # Verify we have exactly target_tokens
        assert len(tokenizer.encode(text)) == target_tokens, (
            "Test setup: should have exactly 50 tokens"
        )

        result = truncate_to_token_limit([text], token_limit=target_tokens)

        assert len(result) == 1
        result_tokens = len(tokenizer.encode(result[0]))
        assert result_tokens <= target_tokens, (
            f"Should be ≤{target_tokens} tokens, got {result_tokens}"
        )

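
# A minimal sketch of the truncation behaviour these tests pin down, using
# tiktoken's cl100k_base encoding. This is an illustrative assumption about how
# leann.embedding_compute could satisfy the contract, not its actual code.
def _truncate_to_token_limit_sketch(texts, token_limit):
    encoding = tiktoken.get_encoding("cl100k_base")
    truncated = []
    for text in texts:
        tokens = encoding.encode(text)
        if len(tokens) <= token_limit:
            truncated.append(text)  # already within the limit: return unchanged
        else:
            truncated.append(encoding.decode(tokens[:token_limit]))  # cut on a token boundary
    return truncated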