Compare commits

..

17 Commits

Author SHA1 Message Date
Andy Lee
253680043a fix: recompute args in searcher 2025-11-24 07:58:20 +00:00
Andy Lee
36c44b8806 fix: faster embed 2025-11-24 05:30:11 +00:00
Andy Lee
66c6aad3e4 refactor: embedding server 2025-11-19 06:54:10 +00:00
Andy Lee
29ef3c95dc refactor: embedding server 2025-11-19 06:50:39 +00:00
yichuan-w
469dce0045 add test 2025-11-12 08:46:05 +00:00
CalebZ9909
0ac676f9cb Add reproduction test script for Issue #159
- Test script to reproduce slow search performance issue
- Generates ~90K chunks (~180MB) similar to user's dataset
- Tests search performance with different complexity values (8, 16, 32, 64)
- Demonstrates that complexity=16-32 achieves ~2s search time
- Validates the performance analysis findings
2025-11-12 08:08:34 +00:00
CalebZ9909
97c9f39704 Add performance analysis and reproduction script for Issue #159
- Reproduced the slow search performance issue (15-30s vs expected ~2s)
- Identified root cause: default complexity=64 is too high for fast search
- Created test script demonstrating performance with different complexity values
- Test results show complexity=16-32 achieves ~2s search time (matching paper)
- Added comprehensive analysis document with solutions and recommendations

Key findings:
- Default complexity=64 results in ~36s search time
- Reducing complexity to 16-32 achieves ~2s search time
- beam_width parameter is mainly for DiskANN, not HNSW
- Paper likely used smaller embedding model (~100M) and lower complexity

Solutions provided:
1. Reduce complexity parameter to 16-32 for faster search
2. Consider DiskANN backend for better performance on large datasets
3. Use smaller embedding model if speed is critical
2025-11-12 08:03:48 +00:00
yichuan-w
3766ad1fd2 robust multi-vector 2025-11-09 02:34:53 +00:00
ww26
c3aceed1e0 metadata reveal for ast-chunking; smart detection of seq length in ollama; auto adjust chunk length for ast to prevent silent truncation (#157)
* feat: enhance token limits with dynamic discovery + AST metadata

Improves upon upstream PR #154 with two major enhancements:

1. **Hybrid Token Limit Discovery**
   - Dynamic: Query Ollama /api/show for context limits
   - Fallback: Registry for LM Studio/OpenAI
   - Zero maintenance for Ollama users
   - Respects custom num_ctx settings

2. **AST Metadata Preservation**
   - create_ast_chunks() returns dict format with metadata
   - Preserves file_path, file_name, timestamps
   - Includes astchunk metadata (line numbers, node counts)
   - Fixes content extraction bug (checks "content" key)
   - Enables --show-metadata flag

3. **Better Token Limits**
   - nomic-embed-text: 2048 tokens (vs 512)
   - nomic-embed-text-v1.5: 2048 tokens
   - Added OpenAI models: 8192 tokens

4. **Comprehensive Tests**
   - 11 tests for token truncation
   - 545 new lines in test_astchunk_integration.py
   - All metadata preservation tests passing

* fix: merge EMBEDDING_MODEL_LIMITS and remove redundant validation

- Merged upstream's model list with our corrected token limits
- Kept our corrected nomic-embed-text: 2048 (not 512)
- Removed post-chunking validation (redundant with embedding-time truncation)
- All tests passing except 2 pre-existing integration test failures

* style: apply ruff formatting and restore PR #154 version handling

- Remove duplicate truncate_to_token_limit and get_model_token_limit functions
- Restore version handling logic (model:latest -> model) from PR #154
- Restore partial matching fallback for model name variations
- Apply ruff formatting to all modified files
- All 11 token truncation tests passing

* style: sort imports alphabetically (pre-commit auto-fix)

* fix: show AST token limit warning only once per session

- Add module-level flag to track if warning shown
- Prevents spam when processing multiple files
- Add clarifying note that auto-truncation happens at embedding time
- Addresses issue where warning appeared for every code file

* enhance: add detailed logging for token truncation

- Track and report truncation statistics (count, tokens removed, max length)
- Show first 3 individual truncations with exact token counts
- Provide comprehensive summary when truncation occurs
- Use WARNING level for data loss visibility
- Silent (DEBUG level only) when no truncation needed

Replaces misleading "truncated where necessary" message that appeared
even when nothing was truncated.
2025-11-08 17:37:31 -08:00
yichuan-w
dc6c9f696e update some search in copali 2025-11-08 08:53:03 +00:00
CalebZ9909
2406c41eef Update faiss submodule to latest commit
🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-11-08 00:47:21 +00:00
Andy Lee
d4f5f2896f Faster Update (#148)
* stash

* stash

* add std err in add and trace progress

* fix.

* docs

* style: format

* docs

* better figs

* better figs

* update results

* fotmat

---------

Co-authored-by: yichuan-w <yichuan-w@users.noreply.github.com>
2025-11-05 13:37:47 -08:00
Aakash Suresh
366984e92e Merge pull request #154 from yichuan-w/fix/chunking-token-limit-behavior
Fix/chunking token limit behavior
2025-11-02 21:37:47 -08:00
aakash
64b92a04a7 fixing chunking token issues within limit for embedding models 2025-10-31 17:15:00 -07:00
ww26
a85d0ad4a7 Feature/optimize ollama batching (#152)
* feat: add metadata output to search results

- Add --show-metadata flag to display file paths in search results
- Preserve document metadata (file_path, file_name, timestamps) during chunking
- Update MCP tool schema to support show_metadata parameter
- Enhance CLI search output to display metadata when requested
- Fix pre-existing bug: args.backend -> args.backend_name

Resolves yichuan-w/LEANN#144

* fix: resolve ZMQ linking issues in Python extension

- Use pkg_check_modules IMPORTED_TARGET to create PkgConfig::ZMQ
- Set PKG_CONFIG_PATH to prioritize ARM64 Homebrew on Apple Silicon
- Override macOS -undefined dynamic_lookup to force proper symbol resolution
- Use PUBLIC linkage for ZMQ in faiss library for transitive linking
- Mark cppzmq includes as SYSTEM to suppress warnings

Fixes editable install ZMQ symbol errors while maintaining compatibility
across Linux, macOS Intel, and macOS ARM64 platforms.

* style: apply ruff formatting

* chore: update faiss submodule to use ww2283 fork

Use ww2283/faiss fork with fix/zmq-linking branch to resolve CI checkout
failures. The ZMQ linking fixes are not yet merged upstream.

* feat: implement true batch processing for Ollama embeddings

Migrate from deprecated /api/embeddings to modern /api/embed endpoint
which supports batch inputs. This reduces HTTP overhead by sending
32 texts per request instead of making individual API calls.

Changes:
- Update endpoint from /api/embeddings to /api/embed
- Change parameter from 'prompt' (single) to 'input' (array)
- Update response parsing for batch embeddings array
- Increase timeout to 60s for batch processing
- Improve error handling for batch requests

Performance:
- Reduces API calls by 32x (batch size)
- Eliminates HTTP connection overhead per text
- Note: Ollama still processes batch items sequentially internally

Related: #151

* fall back to original faiss as i merge the PR

---------

Co-authored-by: yichuan520030910320 <yichuan_wang@berkeley.edu>
2025-10-30 16:39:14 -07:00
yichuan-w
dbb5f4d352 Fix CI failure by removing paru-bin submodule
Remove paru-bin directory that was incorrectly added as a git submodule.
This directory is an AUR build artifact and should not be tracked.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-10-25 14:51:06 -07:00
yichuan-w
f180b83589 add deep wiki 2025-10-25 14:46:17 -07:00
26 changed files with 8640 additions and 4401 deletions

3
.gitignore vendored
View File

@@ -105,3 +105,6 @@ apps/multimodal/vision-based-pdf-multi-vector/multi-vector-colpali-native-weavia
# The following line used to force-add a large demo PDF; remove it to satisfy pre-commit:
# !apps/multimodal/vision-based-pdf-multi-vector/pdfs/2004.12832v2.pdf
!apps/multimodal/vision-based-pdf-multi-vector/fig/*
# AUR build directory (Arch Linux)
paru-bin/

110
ISSUE_159_CONCLUSION.md Normal file
View File

@@ -0,0 +1,110 @@
# Issue #159 Performance Analysis - Conclusion
## Problem Summary
User reported search times of 15-30 seconds instead of the ~2 seconds mentioned in the paper.
**Configuration:**
- GPU: 4090×1
- Embedding Model: BAAI/bge-large-zh-v1.5 (~300M parameters)
- Data Size: 180MB text (~90K chunks)
- Backend: HNSW
- beam_width: 10
- Other parameters: Default values
## Root Cause Analysis
### 1. **Search Complexity Parameter**
The **default `complexity` parameter is 64**, which is too high for achieving ~2 second search times with this configuration.
**Test Results (Reproduced):**
- **Complexity 64 (default)**: **36.17 seconds**
- **Complexity 32**: **2.49 seconds**
- **Complexity 16**: **2.24 seconds** ✅ (Close to paper's ~2 seconds)
- **Complexity 8**: **1.67 seconds**
### 2. **beam_width Parameter**
The `beam_width` parameter is **mainly for DiskANN backend**, not HNSW. Setting it to 10 has minimal/no effect on HNSW search performance.
### 3. **Embedding Model Size**
The paper uses a smaller embedding model (~100M parameters), while the user is using `BAAI/bge-large-zh-v1.5` (~300M parameters). This contributes to slower embedding computation during search, but the main bottleneck is the search complexity parameter.
## Solution
### **Recommended Fix: Reduce Search Complexity**
To achieve search times close to ~2 seconds, use:
```python
from leann.api import LeannSearcher
searcher = LeannSearcher(INDEX_PATH)
results = searcher.search(
query="your query",
top_k=10,
complexity=16, # or complexity=32 for slightly better accuracy
# beam_width parameter doesn't affect HNSW, can be ignored
)
```
Or via CLI:
```bash
leann search your-index "your query" --complexity 16
```
### **Alternative Solutions**
1. **Use DiskANN Backend** (Recommended by maintainer)
- DiskANN is faster for large datasets
- Better performance scaling
- `beam_width` parameter is relevant here
```python
builder = LeannBuilder(backend_name="diskann")
```
2. **Use Smaller Embedding Model**
- Switch to a smaller model (~100M parameters) like the paper
- Faster embedding computation
- Example: `BAAI/bge-base-zh-v1.5` instead of `bge-large-zh-v1.5`
3. **Disable Recomputation** (Trade storage for speed)
- Use `--no-recompute` flag
- Stores all embeddings (much larger storage)
- Faster search (no embedding recomputation)
```bash
leann build your-index --no-recompute --no-compact
leann search your-index "query" --no-recompute
```
## Performance Comparison
| Complexity | Search Time | Accuracy | Recommendation |
|------------|-------------|----------|---------------|
| 64 (default) | ~36s | Highest | ❌ Too slow |
| 32 | ~2.5s | High | ✅ Good balance |
| 16 | ~2.2s | Good | ✅ **Recommended** (matches paper) |
| 8 | ~1.7s | Lower | ⚠️ May sacrifice accuracy |
## Key Takeaways
1. **The default `complexity=64` is optimized for accuracy, not speed**
2. **For ~2 second search times, use `complexity=16` or `complexity=32`**
3. **`beam_width` parameter is for DiskANN, not HNSW**
4. **The paper's ~2 second results likely used:**
- Smaller embedding model (~100M params)
- Lower complexity (16-32)
- Possibly DiskANN backend
## Verification
The issue has been reproduced and verified. The test script `test_issue_159.py` demonstrates:
- Default complexity (64) results in ~36 second search times
- Reducing complexity to 16-32 achieves ~2 second search times
- This matches the user's reported issue and provides a clear solution
## Next Steps
1. ✅ Issue reproduced and root cause identified
2. ✅ Solution provided (reduce complexity parameter)
3. ⏳ User should test with `complexity=16` or `complexity=32`
4. ⏳ Consider updating documentation to clarify complexity parameter trade-offs

View File

@@ -1213,3 +1213,7 @@ This work is done at [**Berkeley Sky Computing Lab**](https://sky.cs.berkeley.ed
<p align="center">
Made with ❤️ by the Leann team
</p>
## 🤖 Explore LEANN with AI
LEANN is indexed on [DeepWiki](https://deepwiki.com/yichuan-w/LEANN), so you can ask questions to LLMs using Deep Research to explore the codebase and get help to add new features.

View File

@@ -180,14 +180,14 @@ class BaseRAGExample(ABC):
ast_group.add_argument(
"--ast-chunk-size",
type=int,
default=512,
help="Maximum characters per AST chunk (default: 512)",
default=300,
help="Maximum CHARACTERS per AST chunk (default: 300). Final chunks may be larger due to overlap. For 512 token models: recommended 300 chars",
)
ast_group.add_argument(
"--ast-chunk-overlap",
type=int,
default=64,
help="Overlap between AST chunks (default: 64)",
help="Overlap between AST chunks in CHARACTERS (default: 64). Added to chunk size, not included in it",
)
ast_group.add_argument(
"--code-file-extensions",

View File

@@ -12,6 +12,7 @@ from pathlib import Path
try:
from leann.chunking_utils import (
CODE_EXTENSIONS,
_traditional_chunks_as_dicts,
create_ast_chunks,
create_text_chunks,
create_traditional_chunks,
@@ -25,6 +26,7 @@ except Exception: # pragma: no cover - best-effort fallback for dev environment
sys.path.insert(0, str(leann_src))
from leann.chunking_utils import (
CODE_EXTENSIONS,
_traditional_chunks_as_dicts,
create_ast_chunks,
create_text_chunks,
create_traditional_chunks,
@@ -36,6 +38,7 @@ except Exception: # pragma: no cover - best-effort fallback for dev environment
__all__ = [
"CODE_EXTENSIONS",
"_traditional_chunks_as_dicts",
"create_ast_chunks",
"create_text_chunks",
"create_traditional_chunks",

View File

@@ -1,12 +1,18 @@
from __future__ import annotations
import concurrent.futures
import json
import os
import re
import sys
from pathlib import Path
from typing import Any, Optional, cast
import numpy as np
from PIL import Image
from tqdm import tqdm
def _ensure_repo_paths_importable(current_file: str) -> None:
"""Make local leann packages importable without installing (mirrors multi-vector-leann.py)."""
_repo_root = Path(current_file).resolve().parents[3]
_leann_core_src = _repo_root / "packages" / "leann-core" / "src"
_leann_hnsw_pkg = _repo_root / "packages" / "leann-backend-hnsw"
@@ -16,6 +22,380 @@ def _ensure_repo_paths_importable(current_file: str) -> None:
sys.path.append(str(_leann_hnsw_pkg))
def _find_backend_module_file() -> Optional[Path]:
"""Best-effort locate the backend leann_multi_vector.py file, avoiding this file."""
this_file = Path(__file__).resolve()
candidates: list[Path] = []
# Common in-repo location
repo_root = this_file.parents[3]
candidates.append(repo_root / "packages" / "leann-backend-hnsw" / "leann_multi_vector.py")
candidates.append(
repo_root / "packages" / "leann-backend-hnsw" / "src" / "leann_multi_vector.py"
)
for cand in candidates:
try:
if cand.exists() and cand.resolve() != this_file:
return cand.resolve()
except Exception:
pass
# Fallback: scan sys.path for another leann_multi_vector.py different from this file
for p in list(sys.path):
try:
cand = Path(p) / "leann_multi_vector.py"
if cand.exists() and cand.resolve() != this_file:
return cand.resolve()
except Exception:
continue
return None
_BACKEND_LEANN_CLASS: Optional[type] = None
def _get_backend_leann_multi_vector() -> type:
"""Load backend LeannMultiVector class even if this file shadows its module name."""
global _BACKEND_LEANN_CLASS
if _BACKEND_LEANN_CLASS is not None:
return _BACKEND_LEANN_CLASS
backend_path = _find_backend_module_file()
if backend_path is None:
# Fallback to local implementation in this module
try:
cls = LeannMultiVector # type: ignore[name-defined]
_BACKEND_LEANN_CLASS = cls
return cls
except Exception as e:
raise ImportError(
"Could not locate backend 'leann_multi_vector.py' and no local implementation found. "
"Ensure the leann backend is available under packages/leann-backend-hnsw or installed."
) from e
import importlib.util
module_name = "leann_hnsw_backend_module"
spec = importlib.util.spec_from_file_location(module_name, str(backend_path))
if spec is None or spec.loader is None:
raise ImportError(f"Failed to create spec for backend module at {backend_path}")
backend_module = importlib.util.module_from_spec(spec)
sys.modules[module_name] = backend_module
spec.loader.exec_module(backend_module) # type: ignore[assignment]
if not hasattr(backend_module, "LeannMultiVector"):
raise ImportError(f"'LeannMultiVector' not found in backend module at {backend_path}")
_BACKEND_LEANN_CLASS = backend_module.LeannMultiVector
return _BACKEND_LEANN_CLASS
def _natural_sort_key(name: str) -> int:
m = re.search(r"\d+", name)
return int(m.group()) if m else 0
def _load_images_from_dir(pages_dir: str) -> tuple[list[str], list[Image.Image]]:
filenames = [n for n in os.listdir(pages_dir) if n.lower().endswith((".png", ".jpg", ".jpeg"))]
filenames = sorted(filenames, key=_natural_sort_key)
filepaths = [os.path.join(pages_dir, n) for n in filenames]
images = [Image.open(p) for p in filepaths]
return filepaths, images
def _maybe_convert_pdf_to_images(pdf_path: Optional[str], pages_dir: str, dpi: int = 200) -> None:
if not pdf_path:
return
os.makedirs(pages_dir, exist_ok=True)
try:
from pdf2image import convert_from_path
except Exception as e:
raise RuntimeError(
"pdf2image is required to convert PDF to images. Install via pip install pdf2image"
) from e
images = convert_from_path(pdf_path, dpi=dpi)
for i, image in enumerate(images):
image.save(os.path.join(pages_dir, f"page_{i + 1}.png"), "PNG")
def _select_device_and_dtype():
import torch
from colpali_engine.utils.torch_utils import get_torch_device
device_str = (
"cuda"
if torch.cuda.is_available()
else (
"mps"
if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available()
else "cpu"
)
)
device = get_torch_device(device_str)
# Stable dtype selection to avoid NaNs:
# - CUDA: prefer bfloat16 if supported, else float16
# - MPS: use float32 (fp16 on MPS can produce NaNs in some ops)
# - CPU: float32
if device_str == "cuda":
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
try:
torch.backends.cuda.matmul.allow_tf32 = True # Better stability/perf on Ampere+
except Exception:
pass
elif device_str == "mps":
dtype = torch.float32
else:
dtype = torch.float32
return device_str, device, dtype
def _load_colvision(model_choice: str):
import torch
from colpali_engine.models import ColPali, ColQwen2, ColQwen2Processor
from colpali_engine.models.paligemma.colpali.processing_colpali import ColPaliProcessor
from transformers.utils.import_utils import is_flash_attn_2_available
device_str, device, dtype = _select_device_and_dtype()
if model_choice == "colqwen2":
model_name = "vidore/colqwen2-v1.0"
# On CPU/MPS we must avoid flash-attn and stay eager; on CUDA prefer flash-attn if available
attn_implementation = (
"flash_attention_2"
if (device_str == "cuda" and is_flash_attn_2_available())
else "eager"
)
model = ColQwen2.from_pretrained(
model_name,
torch_dtype=torch.bfloat16,
device_map=device,
attn_implementation=attn_implementation,
).eval()
processor = ColQwen2Processor.from_pretrained(model_name)
else:
model_name = "vidore/colpali-v1.2"
model = ColPali.from_pretrained(
model_name,
torch_dtype=torch.bfloat16,
device_map=device,
).eval()
processor = cast(ColPaliProcessor, ColPaliProcessor.from_pretrained(model_name))
return model_name, model, processor, device_str, device, dtype
def _embed_images(model, processor, images: list[Image.Image]) -> list[Any]:
import torch
from colpali_engine.utils.torch_utils import ListDataset
from torch.utils.data import DataLoader
# Ensure deterministic eval and autocast for stability
model.eval()
dataloader = DataLoader(
dataset=ListDataset[Image.Image](images),
batch_size=1,
shuffle=False,
collate_fn=lambda x: processor.process_images(x),
)
doc_vecs: list[Any] = []
for batch_doc in tqdm(dataloader, desc="Embedding images"):
with torch.no_grad():
batch_doc = {k: v.to(model.device) for k, v in batch_doc.items()}
# autocast on CUDA for bf16/fp16; on CPU/MPS stay in fp32
if model.device.type == "cuda":
with torch.autocast(
device_type="cuda",
dtype=model.dtype if model.dtype.is_floating_point else torch.bfloat16,
):
embeddings_doc = model(**batch_doc)
else:
embeddings_doc = model(**batch_doc)
doc_vecs.extend(list(torch.unbind(embeddings_doc.to("cpu"))))
return doc_vecs
def _embed_queries(model, processor, queries: list[str]) -> list[Any]:
import torch
from colpali_engine.utils.torch_utils import ListDataset
from torch.utils.data import DataLoader
model.eval()
dataloader = DataLoader(
dataset=ListDataset[str](queries),
batch_size=1,
shuffle=False,
collate_fn=lambda x: processor.process_queries(x),
)
q_vecs: list[Any] = []
for batch_query in tqdm(dataloader, desc="Embedding queries"):
with torch.no_grad():
batch_query = {k: v.to(model.device) for k, v in batch_query.items()}
if model.device.type == "cuda":
with torch.autocast(
device_type="cuda",
dtype=model.dtype if model.dtype.is_floating_point else torch.bfloat16,
):
embeddings_query = model(**batch_query)
else:
embeddings_query = model(**batch_query)
q_vecs.extend(list(torch.unbind(embeddings_query.to("cpu"))))
return q_vecs
def _build_index(
index_path: str, doc_vecs: list[Any], filepaths: list[str], images: list[Image.Image]
) -> Any:
LeannMultiVector = _get_backend_leann_multi_vector()
dim = int(doc_vecs[0].shape[-1])
retriever = LeannMultiVector(index_path=index_path, dim=dim)
retriever.create_collection()
for i, vec in enumerate(doc_vecs):
data = {
"colbert_vecs": vec.float().numpy(),
"doc_id": i,
"filepath": filepaths[i],
"image": images[i], # Include the original image
}
retriever.insert(data)
retriever.create_index()
return retriever
def _load_retriever_if_index_exists(index_path: str) -> Optional[Any]:
LeannMultiVector = _get_backend_leann_multi_vector()
index_base = Path(index_path)
# Check for the actual HNSW index file written by the backend + our sidecar files
index_file = index_base.parent / f"{index_base.stem}.index"
meta = index_base.parent / f"{index_base.name}.meta.json"
labels = index_base.parent / f"{index_base.name}.labels.json"
if index_file.exists() and meta.exists() and labels.exists():
try:
with open(meta, encoding="utf-8") as f:
meta_json = json.load(f)
dim = int(meta_json.get("dimensions", 128))
except Exception:
dim = 128
return LeannMultiVector(index_path=index_path, dim=dim)
return None
def _generate_similarity_map(
model,
processor,
image: Image.Image,
query: str,
token_idx: Optional[int] = None,
output_path: Optional[str] = None,
) -> tuple[int, float]:
import torch
from colpali_engine.interpretability import (
get_similarity_maps_from_embeddings,
plot_similarity_map,
)
batch_images = processor.process_images([image]).to(model.device)
batch_queries = processor.process_queries([query]).to(model.device)
with torch.no_grad():
image_embeddings = model.forward(**batch_images)
query_embeddings = model.forward(**batch_queries)
n_patches = processor.get_n_patches(
image_size=image.size,
spatial_merge_size=getattr(model, "spatial_merge_size", None),
)
image_mask = processor.get_image_mask(batch_images)
batched_similarity_maps = get_similarity_maps_from_embeddings(
image_embeddings=image_embeddings,
query_embeddings=query_embeddings,
n_patches=n_patches,
image_mask=image_mask,
)
similarity_maps = batched_similarity_maps[0]
# Determine token index if not provided: choose the token with highest max score
if token_idx is None:
per_token_max = similarity_maps.view(similarity_maps.shape[0], -1).max(dim=1).values
token_idx = int(per_token_max.argmax().item())
max_sim_score = similarity_maps[token_idx, :, :].max().item()
if output_path:
import matplotlib.pyplot as plt
fig, ax = plot_similarity_map(
image=image,
similarity_map=similarity_maps[token_idx],
figsize=(14, 14),
show_colorbar=False,
)
ax.set_title(f"Token #{token_idx}. MaxSim score: {max_sim_score:.2f}", fontsize=12)
os.makedirs(os.path.dirname(output_path), exist_ok=True)
plt.savefig(output_path, bbox_inches="tight")
plt.close(fig)
return token_idx, float(max_sim_score)
class QwenVL:
def __init__(self, device: str):
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
from transformers.utils.import_utils import is_flash_attn_2_available
attn_implementation = "flash_attention_2" if is_flash_attn_2_available() else "eager"
self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
"Qwen/Qwen2.5-VL-3B-Instruct",
torch_dtype="auto",
device_map=device,
attn_implementation=attn_implementation,
)
min_pixels = 256 * 28 * 28
max_pixels = 1280 * 28 * 28
self.processor = AutoProcessor.from_pretrained(
"Qwen/Qwen2.5-VL-3B-Instruct", min_pixels=min_pixels, max_pixels=max_pixels
)
def answer(self, query: str, images: list[Image.Image], max_new_tokens: int = 128) -> str:
import base64
from io import BytesIO
from qwen_vl_utils import process_vision_info
content = []
for img in images:
buffer = BytesIO()
img.save(buffer, format="jpeg")
img_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
content.append({"type": "image", "image": f"data:image;base64,{img_base64}"})
content.append({"type": "text", "text": query})
messages = [{"role": "user", "content": content}]
text = self.processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
image_inputs, video_inputs = process_vision_info(messages)
inputs = self.processor(
text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt"
)
inputs = inputs.to(self.model.device)
generated_ids = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
generated_ids_trimmed = [
out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
return self.processor.batch_decode(
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
)[0]
# Ensure repo paths are importable for dynamic backend loading
_ensure_repo_paths_importable(__file__)
from leann_backend_hnsw.hnsw_backend import HNSWBuilder, HNSWSearcher # noqa: E402
@@ -45,6 +425,7 @@ class LeannMultiVector:
"is_recompute": is_recompute,
}
self._labels_meta: list[dict] = []
self._docid_to_indices: dict[int, list[int]] | None = None
def _meta_dict(self) -> dict:
return {
@@ -69,6 +450,7 @@ class LeannMultiVector:
"doc_id": int(data["doc_id"]),
"filepath": data.get("filepath", ""),
"colbert_vecs": [np.asarray(v, dtype=np.float32) for v in data["colbert_vecs"]],
"image": data.get("image"), # PIL Image object (optional)
}
)
@@ -80,6 +462,15 @@ class LeannMultiVector:
index_path_obj = Path(self.index_path)
return index_path_obj.parent / f"{index_path_obj.name}.meta.json"
def _embeddings_path(self) -> Path:
index_path_obj = Path(self.index_path)
return index_path_obj.parent / f"{index_path_obj.name}.emb.npy"
def _images_dir_path(self) -> Path:
"""Directory where original images are stored."""
index_path_obj = Path(self.index_path)
return index_path_obj.parent / f"{index_path_obj.name}.images"
def create_index(self) -> None:
if not self._pending_items:
return
@@ -87,10 +478,23 @@ class LeannMultiVector:
embeddings: list[np.ndarray] = []
labels_meta: list[dict] = []
# Create images directory if needed
images_dir = self._images_dir_path()
images_dir.mkdir(parents=True, exist_ok=True)
for item in self._pending_items:
doc_id = int(item["doc_id"])
filepath = item.get("filepath", "")
colbert_vecs = item["colbert_vecs"]
image = item.get("image")
# Save image if provided
image_path = ""
if image is not None and isinstance(image, Image.Image):
image_filename = f"doc_{doc_id}.png"
image_path = str(images_dir / image_filename)
image.save(image_path, "PNG")
for seq_id, vec in enumerate(colbert_vecs):
vec_np = np.asarray(vec, dtype=np.float32)
embeddings.append(vec_np)
@@ -100,6 +504,7 @@ class LeannMultiVector:
"doc_id": doc_id,
"seq_id": int(seq_id),
"filepath": filepath,
"image_path": image_path, # Store the path to the saved image
}
)
@@ -107,7 +512,6 @@ class LeannMultiVector:
return
embeddings_np = np.vstack(embeddings).astype(np.float32)
# print shape of embeddings_np
print(embeddings_np.shape)
builder = HNSWBuilder(**{**self._backend_kwargs, "dimensions": self.dim})
@@ -121,6 +525,9 @@ class LeannMultiVector:
with open(self._labels_path(), "w", encoding="utf-8") as f:
_json.dump(labels_meta, f)
# Persist embeddings for exact reranking
np.save(self._embeddings_path(), embeddings_np)
self._labels_meta = labels_meta
def _load_labels_meta_if_needed(self) -> None:
@@ -133,6 +540,19 @@ class LeannMultiVector:
with open(labels_path, encoding="utf-8") as f:
self._labels_meta = _json.load(f)
def _build_docid_to_indices_if_needed(self) -> None:
if self._docid_to_indices is not None:
return
self._load_labels_meta_if_needed()
mapping: dict[int, list[int]] = {}
for idx, meta in enumerate(self._labels_meta):
try:
doc_id = int(meta["doc_id"]) # type: ignore[index]
except Exception:
continue
mapping.setdefault(doc_id, []).append(idx)
self._docid_to_indices = mapping
def search(
self, data: np.ndarray, topk: int, first_stage_k: int = 50
) -> list[tuple[float, int]]:
@@ -180,3 +600,181 @@ class LeannMultiVector:
scores = sorted(((v, k) for k, v in doc_scores.items()), key=lambda x: x[0], reverse=True)
return scores[:topk] if len(scores) >= topk else scores
def search_exact(
self,
data: np.ndarray,
topk: int,
*,
first_stage_k: int = 200,
max_workers: int = 32,
) -> list[tuple[float, int]]:
"""
High-precision MaxSim reranking over candidate documents.
Steps:
1) Run a first-stage ANN to collect candidate doc_ids (using seq-level neighbors).
2) For each candidate doc, load all its token embeddings and compute
MaxSim(query_tokens, doc_tokens) exactly: sum(max(dot(q_i, d_j))).
Returns top-k list of (score, doc_id).
"""
# Normalize inputs
if data.ndim == 1:
data = data.reshape(1, -1)
if data.dtype != np.float32:
data = data.astype(np.float32)
self._load_labels_meta_if_needed()
self._build_docid_to_indices_if_needed()
emb_path = self._embeddings_path()
if not emb_path.exists():
# Fallback to approximate if we don't have persisted embeddings
return self.search(data, topk, first_stage_k=first_stage_k)
# Memory-map embeddings to avoid loading all into RAM
all_embeddings = np.load(emb_path, mmap_mode="r")
if all_embeddings.dtype != np.float32:
all_embeddings = all_embeddings.astype(np.float32)
# First-stage ANN to collect candidate doc_ids
searcher = HNSWSearcher(self.index_path, meta=self._meta_dict())
raw = searcher.search(
data,
first_stage_k,
recompute_embeddings=False,
complexity=128,
beam_width=1,
prune_ratio=0.0,
batch_size=0,
)
labels = raw.get("labels")
if labels is None:
return []
candidate_doc_ids: set[int] = set()
for batch in labels:
for sid in batch:
try:
idx = int(sid)
except Exception:
continue
if 0 <= idx < len(self._labels_meta):
candidate_doc_ids.add(int(self._labels_meta[idx]["doc_id"])) # type: ignore[index]
# Exact scoring per doc (parallelized)
assert self._docid_to_indices is not None
def _score_one(doc_id: int) -> tuple[float, int]:
token_indices = self._docid_to_indices.get(doc_id, [])
if not token_indices:
return (0.0, doc_id)
doc_vecs = np.asarray(all_embeddings[token_indices], dtype=np.float32)
# (Q, D) x (P, D)^T -> (Q, P) then MaxSim over P, sum over Q
sim = np.dot(data, doc_vecs.T)
# nan-safe
sim = np.nan_to_num(sim, nan=-1e30, posinf=1e30, neginf=-1e30)
score = sim.max(axis=2).sum(axis=1) if sim.ndim == 3 else sim.max(axis=1).sum()
return (float(score), doc_id)
scores: list[tuple[float, int]] = []
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as ex:
futures = [ex.submit(_score_one, doc_id) for doc_id in candidate_doc_ids]
for fut in concurrent.futures.as_completed(futures):
scores.append(fut.result())
scores.sort(key=lambda x: x[0], reverse=True)
return scores[:topk] if len(scores) >= topk else scores
def search_exact_all(
self,
data: np.ndarray,
topk: int,
*,
max_workers: int = 32,
) -> list[tuple[float, int]]:
"""
Exact MaxSim over ALL documents (no ANN pre-filtering).
This computes, for each document, sum_i max_j dot(q_i, d_j).
It memory-maps the persisted token-embedding matrix for scalability.
"""
if data.ndim == 1:
data = data.reshape(1, -1)
if data.dtype != np.float32:
data = data.astype(np.float32)
self._load_labels_meta_if_needed()
self._build_docid_to_indices_if_needed()
emb_path = self._embeddings_path()
if not emb_path.exists():
return self.search(data, topk)
all_embeddings = np.load(emb_path, mmap_mode="r")
if all_embeddings.dtype != np.float32:
all_embeddings = all_embeddings.astype(np.float32)
assert self._docid_to_indices is not None
candidate_doc_ids = list(self._docid_to_indices.keys())
def _score_one(doc_id: int) -> tuple[float, int]:
token_indices = self._docid_to_indices.get(doc_id, [])
if not token_indices:
return (0.0, doc_id)
doc_vecs = np.asarray(all_embeddings[token_indices], dtype=np.float32)
sim = np.dot(data, doc_vecs.T)
sim = np.nan_to_num(sim, nan=-1e30, posinf=1e30, neginf=-1e30)
score = sim.max(axis=2).sum(axis=1) if sim.ndim == 3 else sim.max(axis=1).sum()
return (float(score), doc_id)
scores: list[tuple[float, int]] = []
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as ex:
futures = [ex.submit(_score_one, d) for d in candidate_doc_ids]
for fut in concurrent.futures.as_completed(futures):
scores.append(fut.result())
scores.sort(key=lambda x: x[0], reverse=True)
return scores[:topk] if len(scores) >= topk else scores
def get_image(self, doc_id: int) -> Optional[Image.Image]:
"""
Retrieve the original image for a given doc_id from the index.
Args:
doc_id: The document ID
Returns:
PIL Image object if found, None otherwise
"""
self._load_labels_meta_if_needed()
# Find the image_path for this doc_id (all seq_ids for same doc share the same image_path)
for meta in self._labels_meta:
if meta.get("doc_id") == doc_id:
image_path = meta.get("image_path", "")
if image_path and Path(image_path).exists():
return Image.open(image_path)
break
return None
def get_metadata(self, doc_id: int) -> Optional[dict]:
"""
Retrieve metadata for a given doc_id.
Args:
doc_id: The document ID
Returns:
Dictionary with metadata (filepath, image_path, etc.) if found, None otherwise
"""
self._load_labels_meta_if_needed()
for meta in self._labels_meta:
if meta.get("doc_id") == doc_id:
return {
"doc_id": doc_id,
"filepath": meta.get("filepath", ""),
"image_path": meta.get("image_path", ""),
}
return None

View File

@@ -2,34 +2,31 @@
# %%
# uv pip install matplotlib qwen_vl_utils
import os
import re
import sys
from pathlib import Path
from typing import Any, Optional, cast
from typing import Any, Optional
from PIL import Image
from tqdm import tqdm
def _ensure_repo_paths_importable(current_file: str) -> None:
"""Make local leann packages importable without installing (mirrors multi-vector-leann.py)."""
_repo_root = Path(current_file).resolve().parents[3]
_leann_core_src = _repo_root / "packages" / "leann-core" / "src"
_leann_hnsw_pkg = _repo_root / "packages" / "leann-backend-hnsw"
if str(_leann_core_src) not in sys.path:
sys.path.append(str(_leann_core_src))
if str(_leann_hnsw_pkg) not in sys.path:
sys.path.append(str(_leann_hnsw_pkg))
from leann_multi_vector import ( # utility functions/classes
_ensure_repo_paths_importable,
_load_images_from_dir,
_maybe_convert_pdf_to_images,
_load_colvision,
_embed_images,
_embed_queries,
_build_index,
_load_retriever_if_index_exists,
_generate_similarity_map,
QwenVL,
)
_ensure_repo_paths_importable(__file__)
from leann_multi_vector import LeannMultiVector # noqa: E402
# %%
# Config
os.environ["TOKENIZERS_PARALLELISM"] = "false"
QUERY = "How does DeepSeek-V2 compare against the LLaMA family of LLMs?"
QUERY = "The paper talk about the latent video generative model and data curation in the related work part?"
MODEL: str = "colqwen2" # "colpali" or "colqwen2"
# Data source: set to True to use the Hugging Face dataset example (recommended)
@@ -44,7 +41,7 @@ PAGES_DIR: str = "./pages"
# Index + retrieval settings
INDEX_PATH: str = "./indexes/colvision.leann"
TOPK: int = 1
TOPK: int = 3
FIRST_STAGE_K: int = 500
REBUILD_INDEX: bool = False
@@ -54,332 +51,57 @@ SIMILARITY_MAP: bool = True
SIM_TOKEN_IDX: int = 13 # -1 means auto-select the most salient token
SIM_OUTPUT: str = "./figures/similarity_map.png"
ANSWER: bool = True
MAX_NEW_TOKENS: int = 128
# %%
# Helpers
def _natural_sort_key(name: str) -> int:
m = re.search(r"\d+", name)
return int(m.group()) if m else 0
def _load_images_from_dir(pages_dir: str) -> tuple[list[str], list[Image.Image]]:
filenames = [n for n in os.listdir(pages_dir) if n.lower().endswith((".png", ".jpg", ".jpeg"))]
filenames = sorted(filenames, key=_natural_sort_key)
filepaths = [os.path.join(pages_dir, n) for n in filenames]
images = [Image.open(p) for p in filepaths]
return filepaths, images
def _maybe_convert_pdf_to_images(pdf_path: Optional[str], pages_dir: str, dpi: int = 200) -> None:
if not pdf_path:
return
os.makedirs(pages_dir, exist_ok=True)
try:
from pdf2image import convert_from_path
except Exception as e:
raise RuntimeError(
"pdf2image is required to convert PDF to images. Install via pip install pdf2image"
) from e
images = convert_from_path(pdf_path, dpi=dpi)
for i, image in enumerate(images):
image.save(os.path.join(pages_dir, f"page_{i + 1}.png"), "PNG")
def _select_device_and_dtype():
import torch
from colpali_engine.utils.torch_utils import get_torch_device
device_str = (
"cuda"
if torch.cuda.is_available()
else (
"mps"
if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available()
else "cpu"
)
)
device = get_torch_device(device_str)
# Stable dtype selection to avoid NaNs:
# - CUDA: prefer bfloat16 if supported, else float16
# - MPS: use float32 (fp16 on MPS can produce NaNs in some ops)
# - CPU: float32
if device_str == "cuda":
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
try:
torch.backends.cuda.matmul.allow_tf32 = True # Better stability/perf on Ampere+
except Exception:
pass
elif device_str == "mps":
dtype = torch.float32
else:
dtype = torch.float32
return device_str, device, dtype
def _load_colvision(model_choice: str):
import torch
from colpali_engine.models import ColPali, ColQwen2, ColQwen2Processor
from colpali_engine.models.paligemma.colpali.processing_colpali import ColPaliProcessor
from transformers.utils.import_utils import is_flash_attn_2_available
device_str, device, dtype = _select_device_and_dtype()
if model_choice == "colqwen2":
model_name = "vidore/colqwen2-v1.0"
# On CPU/MPS we must avoid flash-attn and stay eager; on CUDA prefer flash-attn if available
attn_implementation = (
"flash_attention_2"
if (device_str == "cuda" and is_flash_attn_2_available())
else "eager"
)
model = ColQwen2.from_pretrained(
model_name,
torch_dtype=torch.bfloat16,
device_map=device,
attn_implementation=attn_implementation,
).eval()
processor = ColQwen2Processor.from_pretrained(model_name)
else:
model_name = "vidore/colpali-v1.2"
model = ColPali.from_pretrained(
model_name,
torch_dtype=torch.bfloat16,
device_map=device,
).eval()
processor = cast(ColPaliProcessor, ColPaliProcessor.from_pretrained(model_name))
return model_name, model, processor, device_str, device, dtype
def _embed_images(model, processor, images: list[Image.Image]) -> list[Any]:
import torch
from colpali_engine.utils.torch_utils import ListDataset
from torch.utils.data import DataLoader
# Ensure deterministic eval and autocast for stability
model.eval()
dataloader = DataLoader(
dataset=ListDataset[Image.Image](images),
batch_size=1,
shuffle=False,
collate_fn=lambda x: processor.process_images(x),
)
doc_vecs: list[Any] = []
for batch_doc in tqdm(dataloader, desc="Embedding images"):
with torch.no_grad():
batch_doc = {k: v.to(model.device) for k, v in batch_doc.items()}
# autocast on CUDA for bf16/fp16; on CPU/MPS stay in fp32
if model.device.type == "cuda":
with torch.autocast(
device_type="cuda",
dtype=model.dtype if model.dtype.is_floating_point else torch.bfloat16,
):
embeddings_doc = model(**batch_doc)
else:
embeddings_doc = model(**batch_doc)
doc_vecs.extend(list(torch.unbind(embeddings_doc.to("cpu"))))
return doc_vecs
def _embed_queries(model, processor, queries: list[str]) -> list[Any]:
import torch
from colpali_engine.utils.torch_utils import ListDataset
from torch.utils.data import DataLoader
model.eval()
dataloader = DataLoader(
dataset=ListDataset[str](queries),
batch_size=1,
shuffle=False,
collate_fn=lambda x: processor.process_queries(x),
)
q_vecs: list[Any] = []
for batch_query in tqdm(dataloader, desc="Embedding queries"):
with torch.no_grad():
batch_query = {k: v.to(model.device) for k, v in batch_query.items()}
if model.device.type == "cuda":
with torch.autocast(
device_type="cuda",
dtype=model.dtype if model.dtype.is_floating_point else torch.bfloat16,
):
embeddings_query = model(**batch_query)
else:
embeddings_query = model(**batch_query)
q_vecs.extend(list(torch.unbind(embeddings_query.to("cpu"))))
return q_vecs
def _build_index(index_path: str, doc_vecs: list[Any], filepaths: list[str]) -> LeannMultiVector:
dim = int(doc_vecs[0].shape[-1])
retriever = LeannMultiVector(index_path=index_path, dim=dim)
retriever.create_collection()
for i, vec in enumerate(doc_vecs):
data = {
"colbert_vecs": vec.float().numpy(),
"doc_id": i,
"filepath": filepaths[i],
}
retriever.insert(data)
retriever.create_index()
return retriever
def _load_retriever_if_index_exists(index_path: str, dim: int) -> Optional[LeannMultiVector]:
index_base = Path(index_path)
# Rough heuristic: index dir exists AND meta+labels files exist
meta = index_base.parent / f"{index_base.name}.meta.json"
labels = index_base.parent / f"{index_base.name}.labels.json"
if index_base.exists() and meta.exists() and labels.exists():
return LeannMultiVector(index_path=index_path, dim=dim)
return None
def _generate_similarity_map(
model,
processor,
image: Image.Image,
query: str,
token_idx: Optional[int] = None,
output_path: Optional[str] = None,
) -> tuple[int, float]:
import torch
from colpali_engine.interpretability import (
get_similarity_maps_from_embeddings,
plot_similarity_map,
)
batch_images = processor.process_images([image]).to(model.device)
batch_queries = processor.process_queries([query]).to(model.device)
with torch.no_grad():
image_embeddings = model.forward(**batch_images)
query_embeddings = model.forward(**batch_queries)
n_patches = processor.get_n_patches(
image_size=image.size,
spatial_merge_size=getattr(model, "spatial_merge_size", None),
)
image_mask = processor.get_image_mask(batch_images)
batched_similarity_maps = get_similarity_maps_from_embeddings(
image_embeddings=image_embeddings,
query_embeddings=query_embeddings,
n_patches=n_patches,
image_mask=image_mask,
)
similarity_maps = batched_similarity_maps[0]
# Determine token index if not provided: choose the token with highest max score
if token_idx is None:
per_token_max = similarity_maps.view(similarity_maps.shape[0], -1).max(dim=1).values
token_idx = int(per_token_max.argmax().item())
max_sim_score = similarity_maps[token_idx, :, :].max().item()
if output_path:
import matplotlib.pyplot as plt
fig, ax = plot_similarity_map(
image=image,
similarity_map=similarity_maps[token_idx],
figsize=(14, 14),
show_colorbar=False,
)
ax.set_title(f"Token #{token_idx}. MaxSim score: {max_sim_score:.2f}", fontsize=12)
os.makedirs(os.path.dirname(output_path), exist_ok=True)
plt.savefig(output_path, bbox_inches="tight")
plt.close(fig)
return token_idx, float(max_sim_score)
class QwenVL:
def __init__(self, device: str):
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
from transformers.utils.import_utils import is_flash_attn_2_available
attn_implementation = "flash_attention_2" if is_flash_attn_2_available() else "eager"
self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
"Qwen/Qwen2.5-VL-3B-Instruct",
torch_dtype="auto",
device_map=device,
attn_implementation=attn_implementation,
)
min_pixels = 256 * 28 * 28
max_pixels = 1280 * 28 * 28
self.processor = AutoProcessor.from_pretrained(
"Qwen/Qwen2.5-VL-3B-Instruct", min_pixels=min_pixels, max_pixels=max_pixels
)
def answer(self, query: str, images: list[Image.Image], max_new_tokens: int = 128) -> str:
import base64
from io import BytesIO
from qwen_vl_utils import process_vision_info
content = []
for img in images:
buffer = BytesIO()
img.save(buffer, format="jpeg")
img_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
content.append({"type": "image", "image": f"data:image;base64,{img_base64}"})
content.append({"type": "text", "text": query})
messages = [{"role": "user", "content": content}]
text = self.processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
image_inputs, video_inputs = process_vision_info(messages)
inputs = self.processor(
text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt"
)
inputs = inputs.to(self.model.device)
generated_ids = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
generated_ids_trimmed = [
out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
return self.processor.batch_decode(
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
)[0]
MAX_NEW_TOKENS: int = 1024
# %%
# Step 1: Prepare data
if USE_HF_DATASET:
from datasets import load_dataset
# Step 1: Check if we can skip data loading (index already exists)
retriever: Optional[Any] = None
need_to_build_index = REBUILD_INDEX
dataset = load_dataset(DATASET_NAME, split=DATASET_SPLIT)
N = len(dataset) if MAX_DOCS is None else min(MAX_DOCS, len(dataset))
filepaths: list[str] = []
images: list[Image.Image] = []
for i in tqdm(range(N), desc="Loading dataset", total=N ):
p = dataset[i]
# Compose a descriptive identifier for printing later
identifier = f"arXiv:{p['paper_arxiv_id']}|title:{p['paper_title']}|page:{int(p['page_number'])}|id:{p['page_id']}"
print(identifier)
filepaths.append(identifier)
images.append(p["page_image"]) # PIL Image
if not REBUILD_INDEX:
retriever = _load_retriever_if_index_exists(INDEX_PATH)
if retriever is not None:
print(f"✓ Index loaded from {INDEX_PATH}")
print(f"✓ Images available at: {retriever._images_dir_path()}")
need_to_build_index = False
else:
print(f"Index not found, will build new index")
need_to_build_index = True
# Step 2: Load data only if we need to build the index
if need_to_build_index:
print("Loading dataset...")
if USE_HF_DATASET:
from datasets import load_dataset
dataset = load_dataset(DATASET_NAME, split=DATASET_SPLIT)
N = len(dataset) if MAX_DOCS is None else min(MAX_DOCS, len(dataset))
filepaths: list[str] = []
images: list[Image.Image] = []
for i in tqdm(range(N), desc="Loading dataset", total=N):
p = dataset[i]
# Compose a descriptive identifier for printing later
identifier = f"arXiv:{p['paper_arxiv_id']}|title:{p['paper_title']}|page:{int(p['page_number'])}|id:{p['page_id']}"
filepaths.append(identifier)
images.append(p["page_image"]) # PIL Image
else:
_maybe_convert_pdf_to_images(PDF, PAGES_DIR)
filepaths, images = _load_images_from_dir(PAGES_DIR)
if not images:
raise RuntimeError(
f"No images found in {PAGES_DIR}. Provide PDF path in PDF variable or ensure images exist."
)
print(f"Loaded {len(images)} images")
else:
_maybe_convert_pdf_to_images(PDF, PAGES_DIR)
filepaths, images = _load_images_from_dir(PAGES_DIR)
if not images:
raise RuntimeError(
f"No images found in {PAGES_DIR}. Provide PDF path in PDF variable or ensure images exist."
)
print("Skipping dataset loading (using existing index)")
filepaths = [] # Not needed when using existing index
images = [] # Not needed when using existing index
# %%
# Step 2: Load model and processor
# Step 3: Load model and processor (only if we need to build index or perform search)
model_name, model, processor, device_str, device, dtype = _load_colvision(MODEL)
print(f"Using model={model_name}, device={device_str}, dtype={dtype}")
@@ -387,34 +109,39 @@ print(f"Using model={model_name}, device={device_str}, dtype={dtype}")
# %%
# %%
# Step 3: Build or load index
retriever: Optional[LeannMultiVector] = None
if not REBUILD_INDEX:
try:
one_vec = _embed_images(model, processor, [images[0]])[0]
retriever = _load_retriever_if_index_exists(INDEX_PATH, dim=int(one_vec.shape[-1]))
except Exception:
retriever = None
if retriever is None:
# Step 4: Build index if needed
if need_to_build_index and retriever is None:
print("Building index...")
doc_vecs = _embed_images(model, processor, images)
retriever = _build_index(INDEX_PATH, doc_vecs, filepaths)
retriever = _build_index(INDEX_PATH, doc_vecs, filepaths, images)
print(f"✓ Index built and images saved to: {retriever._images_dir_path()}")
# Clear memory
del images, filepaths, doc_vecs
# Note: Images are now stored in the index, retriever will load them on-demand from disk
# %%
# Step 4: Embed query and search
# Step 5: Embed query and search
q_vec = _embed_queries(model, processor, [QUERY])[0]
results = retriever.search(q_vec.float().numpy(), topk=TOPK, first_stage_k=FIRST_STAGE_K)
results = retriever.search(q_vec.float().numpy(), topk=TOPK)
if not results:
print("No results found.")
else:
print(f'Top {len(results)} results for query: "{QUERY}"')
top_images: list[Image.Image] = []
for rank, (score, doc_id) in enumerate(results, start=1):
path = filepaths[doc_id]
# Retrieve image from index instead of memory
image = retriever.get_image(doc_id)
if image is None:
print(f"Warning: Could not retrieve image for doc_id {doc_id}")
continue
metadata = retriever.get_metadata(doc_id)
path = metadata.get("filepath", "unknown") if metadata else "unknown"
# For HF dataset, path is a descriptive identifier, not a real file path
print(f"{rank}) MaxSim: {score:.4f}, Page: {path}")
top_images.append(images[doc_id])
top_images.append(image)
if SAVE_TOP_IMAGE:
from pathlib import Path as _Path
@@ -427,12 +154,17 @@ else:
else:
out_path = base / f"retrieved_page_rank{rank}.png"
img.save(str(out_path))
print(f"Saved retrieved page (rank {rank}) to: {out_path}")
# Print the retrieval score (document-level MaxSim) alongside the saved path
try:
score, _doc_id = results[rank - 1]
print(f"Saved retrieved page (rank {rank}) [MaxSim={score:.4f}] to: {out_path}")
except Exception:
print(f"Saved retrieved page (rank {rank}) to: {out_path}")
## TODO stange results of second page of DeepSeek-V2 rather than the first page
# %%
# Step 5: Similarity maps for top-K results
# Step 6: Similarity maps for top-K results
if results and SIMILARITY_MAP:
token_idx = None if SIM_TOKEN_IDX < 0 else int(SIM_TOKEN_IDX)
from pathlib import Path as _Path
@@ -469,7 +201,7 @@ if results and SIMILARITY_MAP:
# %%
# Step 6: Optional answer generation
# Step 7: Optional answer generation
if results and ANSWER:
qwen = QwenVL(device=device_str)
response = qwen.answer(QUERY, top_images[:TOPK], max_new_tokens=MAX_NEW_TOKENS)

143
benchmarks/update/README.md Normal file
View File

@@ -0,0 +1,143 @@
# Update Benchmarks
This directory hosts two benchmark suites that exercise LEANNs HNSW “update +
search” pipeline under different assumptions:
1. **RNG recompute latency** measure how random-neighbour pruning and cache
settings influence incremental `add()` latency when embeddings are fetched
over the ZMQ embedding server.
2. **Update strategy comparison** compare a fully sequential update pipeline
against an offline approach that keeps the graph static and fuses results.
Both suites build a non-compact, `is_recompute=True` index so that new
embeddings are pulled from the embedding server. Benchmark outputs are written
under `.leann/bench/` by default and appended to CSV files for later plotting.
## Benchmarks
### 1. HNSW RNG Recompute Benchmark
`bench_hnsw_rng_recompute.py` evaluates incremental update latency under four
random-neighbour (RNG) configurations. Each scenario uses the same dataset but
changes the forward / reverse RNG pruning flags and whether the embedding cache
is enabled:
| Scenario name | Forward RNG | Reverse RNG | ZMQ embedding cache |
| ---------------------------------- | ----------- | ----------- | ------------------- |
| `baseline` | Enabled | Enabled | Enabled |
| `no_cache_baseline` | Enabled | Enabled | **Disabled** |
| `disable_forward_rng` | **Disabled**| Enabled | Enabled |
| `disable_forward_and_reverse_rng` | **Disabled**| **Disabled**| Enabled |
For each scenario the script:
1. (Re)builds a `is_recompute=True` index and writes it to `.leann/bench/`.
2. Starts `leann_backend_hnsw.hnsw_embedding_server` for remote embeddings.
3. Appends the requested updates using the scenarios RNG flags.
4. Records total time, latency per passage, ZMQ fetch counts, and stage-level
timings before appending a row to the CSV output.
**Run:**
```bash
LEANN_HNSW_LOG_PATH=.leann/bench/hnsw_server.log \
LEANN_LOG_LEVEL=INFO \
uv run -m benchmarks.update.bench_hnsw_rng_recompute \
--runs 1 \
--index-path .leann/bench/test.leann \
--initial-files data/PrideandPrejudice.txt \
--update-files data/huawei_pangu.md \
--max-initial 300 \
--max-updates 1 \
--add-timeout 120
```
**Output:**
- `benchmarks/update/bench_results.csv` per-scenario timing statistics
(including ms/passage) for each run.
- `.leann/bench/hnsw_server.log` detailed ZMQ/server logs (path controlled by
`LEANN_HNSW_LOG_PATH`).
_The reference CSVs checked into this branch were generated on a workstation with an NVIDIA RTX 4090 GPU; throughput numbers will differ on other hardware._
### 2. Sequential vs. Offline Update Benchmark
`bench_update_vs_offline_search.py` compares two end-to-end strategies on the
same dataset:
- **Scenario A Sequential Update**
- Start an embedding server.
- Sequentially call `index.add()`; each call fetches embeddings via ZMQ and
mutates the HNSW graph.
- After all inserts, run a search on the updated graph.
- Metrics recorded: update time (`add_total_s`), post-update search time
(`search_time_s`), combined total (`total_time_s`), and per-passage
latency.
- **Scenario B Offline Embedding + Concurrent Search**
- Stop Scenario As server and start a fresh embedding server.
- Spawn two threads: one generates embeddings for the new passages offline
(graph unchanged); the other computes the query embedding and searches the
existing graph.
- Merge offline similarities with the graph search results to emulate late
fusion, then report the merged topk preview.
- Metrics recorded: embedding time (`emb_time_s`), search time
(`search_time_s`), concurrent makespan (`makespan_s`), and scenario total.
**Run (both scenarios):**
```bash
uv run -m benchmarks.update.bench_update_vs_offline_search \
--index-path .leann/bench/offline_vs_update.leann \
--max-initial 300 \
--num-updates 1
```
You can pass `--only A` or `--only B` to run a single scenario. The script will
print timing summaries to stdout and append the results to CSV.
**Output:**
- `benchmarks/update/offline_vs_update.csv` per-scenario timing statistics for
Scenario A and B.
- Console output includes Scenario Bs merged topk preview for quick sanity
checks.
_The sample results committed here come from runs on an RTX 4090-equipped machine; expect variations if you benchmark on different GPUs._
### 3. Visualisation
`plot_bench_results.py` combines the RNG benchmark and the update strategy
benchmark into a single two-panel plot.
**Run:**
```bash
uv run -m benchmarks.update.plot_bench_results \
--csv benchmarks/update/bench_results.csv \
--csv-right benchmarks/update/offline_vs_update.csv \
--out benchmarks/update/bench_latency_from_csv.png
```
**Options:**
- `--broken-y` Enable a broken Y-axis (default: true when appropriate).
- `--csv` RNG benchmark results CSV (left panel).
- `--csv-right` Update strategy results CSV (right panel).
- `--out` Output image path (PNG/PDF supported).
**Output:**
- `benchmarks/update/bench_latency_from_csv.png` visual comparison of the two
suites.
- `benchmarks/update/bench_latency_from_csv.pdf` PDF version, suitable for
slides/papers.
## Parameters & Environment
### Common CLI Flags
- `--max-initial` Number of initial passages used to seed the index.
- `--max-updates` / `--num-updates` Number of passages to treat as updates.
- `--index-path` Base path (without extension) where the LEANN index is stored.
- `--runs` Number of repetitions (RNG benchmark only).
### Environment Variables
- `LEANN_HNSW_LOG_PATH` File to receive embedding-server logs (optional).
- `LEANN_LOG_LEVEL` Logging verbosity (DEBUG/INFO/WARNING/ERROR).
- `CUDA_VISIBLE_DEVICES` Set to empty string if you want to force CPU
execution of the embedding model.
With these scripts you can easily replicate LEANNs update benchmarks, compare
multiple RNG strategies, and evaluate whether sequential updates or offline
fusion better match your latency/accuracy trade-offs.

View File

@@ -0,0 +1,16 @@
"""Benchmarks for LEANN update workflows."""
# Expose helper to locate repository root for other modules that need it.
from pathlib import Path
def find_repo_root() -> Path:
"""Return the project root containing pyproject.toml."""
current = Path(__file__).resolve()
for parent in current.parents:
if (parent / "pyproject.toml").exists():
return parent
return current.parents[1]
__all__ = ["find_repo_root"]

View File

@@ -0,0 +1,804 @@
"""Benchmark incremental HNSW add() under different RNG pruning modes with real
embedding recomputation.
This script clones the structure of ``examples/dynamic_update_no_recompute.py``
so that we build a non-compact ``is_recompute=True`` index, spin up the
standard HNSW embedding server, and measure how long incremental ``add`` takes
when RNG pruning is fully enabled vs. partially/fully disabled.
Example usage (run from the repo root; downloads the model on first run)::
uv run -m benchmarks.update.bench_hnsw_rng_recompute \
--index-path .leann/bench/leann-demo.leann \
--runs 1
You can tweak the input documents with ``--initial-files`` / ``--update-files``
if you want a larger or different workload, and change the embedding model via
``--model-name``.
"""
import argparse
import json
import logging
import os
import pickle
import re
import sys
import time
from pathlib import Path
from typing import Any
import msgpack
import numpy as np
import zmq
from leann.api import LeannBuilder
if os.environ.get("LEANN_FORCE_CPU", "").lower() in ("1", "true", "yes"):
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "")
from leann.embedding_compute import compute_embeddings
from leann.embedding_server_manager import EmbeddingServerManager
from leann.registry import register_project_directory
from leann_backend_hnsw import faiss # type: ignore
from leann_backend_hnsw.convert_to_csr import prune_hnsw_embeddings_inplace
logger = logging.getLogger(__name__)
if not logging.getLogger().handlers:
logging.basicConfig(level=logging.INFO)
def _find_repo_root() -> Path:
"""Locate project root by walking up until pyproject.toml is found."""
current = Path(__file__).resolve()
for parent in current.parents:
if (parent / "pyproject.toml").exists():
return parent
# Fallback: assume repo is two levels up (../..)
return current.parents[2]
REPO_ROOT = _find_repo_root()
if str(REPO_ROOT) not in sys.path:
sys.path.insert(0, str(REPO_ROOT))
from apps.chunking import create_text_chunks # noqa: E402
DEFAULT_INITIAL_FILES = [
REPO_ROOT / "data" / "2501.14312v1 (1).pdf",
REPO_ROOT / "data" / "huawei_pangu.md",
]
DEFAULT_UPDATE_FILES = [REPO_ROOT / "data" / "2506.08276v1.pdf"]
DEFAULT_HNSW_LOG = Path(".leann/bench/hnsw_server.log")
def load_chunks_from_files(paths: list[Path], limit: int | None = None) -> list[str]:
from llama_index.core import SimpleDirectoryReader
documents = []
for path in paths:
p = path.expanduser().resolve()
if not p.exists():
raise FileNotFoundError(f"Input path not found: {p}")
if p.is_dir():
reader = SimpleDirectoryReader(str(p), recursive=False)
documents.extend(reader.load_data(show_progress=True))
else:
reader = SimpleDirectoryReader(input_files=[str(p)])
documents.extend(reader.load_data(show_progress=True))
if not documents:
return []
chunks = create_text_chunks(
documents,
chunk_size=512,
chunk_overlap=128,
use_ast_chunking=False,
)
cleaned = [c for c in chunks if isinstance(c, str) and c.strip()]
if limit is not None:
cleaned = cleaned[:limit]
return cleaned
def ensure_index_dir(index_path: Path) -> None:
index_path.parent.mkdir(parents=True, exist_ok=True)
def cleanup_index_files(index_path: Path) -> None:
parent = index_path.parent
if not parent.exists():
return
stem = index_path.stem
for file in parent.glob(f"{stem}*"):
if file.is_file():
file.unlink()
def build_initial_index(
index_path: Path,
paragraphs: list[str],
model_name: str,
embedding_mode: str,
distance_metric: str,
ef_construction: int,
) -> None:
builder = LeannBuilder(
backend_name="hnsw",
embedding_model=model_name,
embedding_mode=embedding_mode,
is_compact=False,
is_recompute=True,
distance_metric=distance_metric,
backend_kwargs={
"distance_metric": distance_metric,
"is_compact": False,
"is_recompute": True,
"efConstruction": ef_construction,
},
)
for idx, passage in enumerate(paragraphs):
builder.add_text(passage, metadata={"id": str(idx)})
builder.build_index(str(index_path))
def prepare_new_chunks(paragraphs: list[str]) -> list[dict[str, Any]]:
return [{"text": text, "metadata": {}} for text in paragraphs]
def benchmark_update_with_mode(
index_path: Path,
new_chunks: list[dict[str, Any]],
model_name: str,
embedding_mode: str,
distance_metric: str,
disable_forward_rng: bool,
disable_reverse_rng: bool,
server_port: int,
add_timeout: int,
ef_construction: int,
) -> tuple[float, float]:
meta_path = index_path.parent / f"{index_path.name}.meta.json"
passages_file = index_path.parent / f"{index_path.name}.passages.jsonl"
offset_file = index_path.parent / f"{index_path.name}.passages.idx"
index_file = index_path.parent / f"{index_path.stem}.index"
with open(meta_path, encoding="utf-8") as f:
meta = json.load(f)
with open(offset_file, "rb") as f:
offset_map: dict[str, int] = pickle.load(f)
existing_ids = set(offset_map.keys())
valid_chunks: list[dict[str, Any]] = []
for chunk in new_chunks:
text = chunk.get("text", "")
if not isinstance(text, str) or not text.strip():
continue
metadata = chunk.setdefault("metadata", {})
passage_id = chunk.get("id") or metadata.get("id")
if passage_id and passage_id in existing_ids:
raise ValueError(f"Passage ID '{passage_id}' already exists in the index.")
valid_chunks.append(chunk)
if not valid_chunks:
raise ValueError("No valid chunks to append.")
texts_to_embed = [chunk["text"] for chunk in valid_chunks]
embeddings = compute_embeddings(
texts_to_embed,
model_name,
mode=embedding_mode,
is_build=False,
batch_size=16,
)
embeddings = np.ascontiguousarray(embeddings, dtype=np.float32)
if distance_metric == "cosine":
norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
norms[norms == 0] = 1
embeddings = embeddings / norms
index = faiss.read_index(str(index_file))
index.is_recompute = True
if getattr(index, "storage", None) is None:
if index.metric_type == faiss.METRIC_INNER_PRODUCT:
storage_index = faiss.IndexFlatIP(index.d)
else:
storage_index = faiss.IndexFlatL2(index.d)
index.storage = storage_index
index.own_fields = True
try:
storage_index.ntotal = index.ntotal
except AttributeError:
pass
try:
index.hnsw.set_disable_rng_during_add(disable_forward_rng)
index.hnsw.set_disable_reverse_prune(disable_reverse_rng)
if ef_construction is not None:
index.hnsw.efConstruction = ef_construction
except AttributeError:
pass
applied_forward = getattr(index.hnsw, "disable_rng_during_add", None)
applied_reverse = getattr(index.hnsw, "disable_reverse_prune", None)
logger.info(
"HNSW RNG config -> requested forward=%s, reverse=%s | applied forward=%s, reverse=%s",
disable_forward_rng,
disable_reverse_rng,
applied_forward,
applied_reverse,
)
base_id = index.ntotal
for offset, chunk in enumerate(valid_chunks):
new_id = str(base_id + offset)
chunk.setdefault("metadata", {})["id"] = new_id
chunk["id"] = new_id
rollback_size = passages_file.stat().st_size if passages_file.exists() else 0
offset_map_backup = offset_map.copy()
try:
with open(passages_file, "a", encoding="utf-8") as f:
for chunk in valid_chunks:
offset = f.tell()
json.dump(
{
"id": chunk["id"],
"text": chunk["text"],
"metadata": chunk.get("metadata", {}),
},
f,
ensure_ascii=False,
)
f.write("\n")
offset_map[chunk["id"]] = offset
with open(offset_file, "wb") as f:
pickle.dump(offset_map, f)
server_manager = EmbeddingServerManager(
backend_module_name="leann_backend_hnsw.hnsw_embedding_server"
)
server_started, actual_port = server_manager.start_server(
port=server_port,
model_name=model_name,
embedding_mode=embedding_mode,
passages_file=str(meta_path),
distance_metric=distance_metric,
)
if not server_started:
raise RuntimeError("Failed to start embedding server.")
if hasattr(index.hnsw, "set_zmq_port"):
index.hnsw.set_zmq_port(actual_port)
elif hasattr(index, "set_zmq_port"):
index.set_zmq_port(actual_port)
_warmup_embedding_server(actual_port)
total_start = time.time()
add_elapsed = 0.0
try:
import signal
def _timeout_handler(signum, frame):
raise TimeoutError("incremental add timed out")
if add_timeout > 0:
signal.signal(signal.SIGALRM, _timeout_handler)
signal.alarm(add_timeout)
add_start = time.time()
for i in range(embeddings.shape[0]):
index.add(1, faiss.swig_ptr(embeddings[i : i + 1]))
add_elapsed = time.time() - add_start
if add_timeout > 0:
signal.alarm(0)
faiss.write_index(index, str(index_file))
finally:
server_manager.stop_server()
except TimeoutError:
raise
except Exception:
if passages_file.exists():
with open(passages_file, "rb+") as f:
f.truncate(rollback_size)
with open(offset_file, "wb") as f:
pickle.dump(offset_map_backup, f)
raise
prune_hnsw_embeddings_inplace(str(index_file))
meta["total_passages"] = len(offset_map)
with open(meta_path, "w", encoding="utf-8") as f:
json.dump(meta, f, indent=2)
# Reset toggles so the index on disk returns to baseline behaviour.
try:
index.hnsw.set_disable_rng_during_add(False)
index.hnsw.set_disable_reverse_prune(False)
except AttributeError:
pass
faiss.write_index(index, str(index_file))
total_elapsed = time.time() - total_start
return total_elapsed, add_elapsed
def _total_zmq_nodes(log_path: Path) -> int:
if not log_path.exists():
return 0
with log_path.open("r", encoding="utf-8") as log_file:
text = log_file.read()
return sum(int(match) for match in re.findall(r"ZMQ received (\d+) node IDs", text))
def _warmup_embedding_server(port: int) -> None:
"""Send a dummy REQ so the embedding server loads its model."""
ctx = zmq.Context()
try:
sock = ctx.socket(zmq.REQ)
sock.setsockopt(zmq.LINGER, 0)
sock.setsockopt(zmq.RCVTIMEO, 5000)
sock.setsockopt(zmq.SNDTIMEO, 5000)
sock.connect(f"tcp://127.0.0.1:{port}")
payload = msgpack.packb(["__WARMUP__"], use_bin_type=True)
sock.send(payload)
try:
sock.recv()
except zmq.error.Again:
pass
finally:
sock.close()
ctx.term()
def main() -> None:
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--index-path",
type=Path,
default=Path(".leann/bench/leann-demo.leann"),
help="Output index base path (without extension).",
)
parser.add_argument(
"--initial-files",
nargs="*",
type=Path,
default=DEFAULT_INITIAL_FILES,
help="Files used to build the initial index.",
)
parser.add_argument(
"--update-files",
nargs="*",
type=Path,
default=DEFAULT_UPDATE_FILES,
help="Files appended during the benchmark.",
)
parser.add_argument(
"--runs", type=int, default=1, help="How many times to repeat each scenario."
)
parser.add_argument(
"--model-name",
default="sentence-transformers/all-MiniLM-L6-v2",
help="Embedding model used for build/update.",
)
parser.add_argument(
"--embedding-mode",
default="sentence-transformers",
help="Embedding mode passed to LeannBuilder/embedding server.",
)
parser.add_argument(
"--distance-metric",
default="mips",
choices=["mips", "l2", "cosine"],
help="Distance metric for HNSW backend.",
)
parser.add_argument(
"--ef-construction",
type=int,
default=200,
help="efConstruction setting for initial build.",
)
parser.add_argument(
"--server-port",
type=int,
default=5557,
help="Port for the real embedding server.",
)
parser.add_argument(
"--max-initial",
type=int,
default=300,
help="Optional cap on initial passages (after chunking).",
)
parser.add_argument(
"--max-updates",
type=int,
default=1,
help="Optional cap on update passages (after chunking).",
)
parser.add_argument(
"--add-timeout",
type=int,
default=900,
help="Timeout in seconds for the incremental add loop (0 = no timeout).",
)
parser.add_argument(
"--plot-path",
type=Path,
default=Path("bench_latency.png"),
help="Where to save the latency bar plot.",
)
parser.add_argument(
"--cap-y",
type=float,
default=None,
help="Cap Y-axis (ms). Bars above are hatched and annotated.",
)
parser.add_argument(
"--broken-y",
action="store_true",
help="Use broken Y-axis (two stacked axes with gap). Overrides --cap-y unless both provided.",
)
parser.add_argument(
"--lower-cap-y",
type=float,
default=None,
help="Lower axes upper bound for broken Y (ms). Default=1.1x second-highest.",
)
parser.add_argument(
"--upper-start-y",
type=float,
default=None,
help="Upper axes lower bound for broken Y (ms). Default=1.2x second-highest.",
)
parser.add_argument(
"--csv-path",
type=Path,
default=Path("benchmarks/update/bench_results.csv"),
help="Where to append per-scenario results as CSV.",
)
args = parser.parse_args()
register_project_directory(REPO_ROOT)
initial_paragraphs = load_chunks_from_files(args.initial_files, args.max_initial)
update_paragraphs = load_chunks_from_files(args.update_files, args.max_updates)
if not update_paragraphs:
raise ValueError("No update passages found; please provide --update-files with content.")
update_chunks = prepare_new_chunks(update_paragraphs)
ensure_index_dir(args.index_path)
scenarios = [
("baseline", False, False, True),
("no_cache_baseline", False, False, False),
("disable_forward_rng", True, False, True),
("disable_forward_and_reverse_rng", True, True, True),
]
log_path = Path(os.environ.get("LEANN_HNSW_LOG_PATH", DEFAULT_HNSW_LOG))
log_path.parent.mkdir(parents=True, exist_ok=True)
os.environ["LEANN_HNSW_LOG_PATH"] = str(log_path.resolve())
os.environ.setdefault("LEANN_LOG_LEVEL", "INFO")
results_total: dict[str, list[float]] = {name: [] for name, *_ in scenarios}
results_add: dict[str, list[float]] = {name: [] for name, *_ in scenarios}
results_zmq: dict[str, list[int]] = {name: [] for name, *_ in scenarios}
results_stageA: dict[str, list[float]] = {name: [] for name, *_ in scenarios}
results_stageBC: dict[str, list[float]] = {name: [] for name, *_ in scenarios}
results_ms_per_passage: dict[str, list[float]] = {name: [] for name, *_ in scenarios}
# CSV setup
import csv
run_id = time.strftime("%Y%m%d-%H%M%S")
csv_fields = [
"run_id",
"scenario",
"cache_enabled",
"ef_construction",
"max_initial",
"max_updates",
"total_time_s",
"add_only_s",
"latency_ms_per_passage",
"zmq_nodes",
"stageA_time_s",
"stageBC_time_s",
"model_name",
"embedding_mode",
"distance_metric",
]
# Create CSV with header if missing
if args.csv_path:
args.csv_path.parent.mkdir(parents=True, exist_ok=True)
if not args.csv_path.exists() or args.csv_path.stat().st_size == 0:
with args.csv_path.open("w", newline="", encoding="utf-8") as f:
writer = csv.DictWriter(f, fieldnames=csv_fields)
writer.writeheader()
for run in range(args.runs):
print(f"\n=== Benchmark run {run + 1}/{args.runs} ===")
for name, disable_forward, disable_reverse, cache_enabled in scenarios:
print(f"\nScenario: {name}")
cleanup_index_files(args.index_path)
if log_path.exists():
try:
log_path.unlink()
except OSError:
pass
os.environ["LEANN_ZMQ_EMBED_CACHE"] = "1" if cache_enabled else "0"
build_initial_index(
args.index_path,
initial_paragraphs,
args.model_name,
args.embedding_mode,
args.distance_metric,
args.ef_construction,
)
prev_size = log_path.stat().st_size if log_path.exists() else 0
try:
total_elapsed, add_elapsed = benchmark_update_with_mode(
args.index_path,
update_chunks,
args.model_name,
args.embedding_mode,
args.distance_metric,
disable_forward,
disable_reverse,
args.server_port,
args.add_timeout,
args.ef_construction,
)
except TimeoutError as exc:
print(f"Scenario {name} timed out: {exc}")
continue
curr_size = log_path.stat().st_size if log_path.exists() else 0
if curr_size < prev_size:
prev_size = 0
zmq_count = 0
if log_path.exists():
with log_path.open("r", encoding="utf-8") as log_file:
log_file.seek(prev_size)
new_entries = log_file.read()
zmq_count = sum(
int(match) for match in re.findall(r"ZMQ received (\d+) node IDs", new_entries)
)
stageA = sum(
float(x)
for x in re.findall(r"Distance calculation E2E time: ([0-9.]+)s", new_entries)
)
stageBC = sum(
float(x) for x in re.findall(r"ZMQ E2E time: ([0-9.]+)s", new_entries)
)
else:
stageA = 0.0
stageBC = 0.0
per_chunk = add_elapsed / len(update_chunks)
print(
f"Total time: {total_elapsed:.3f} s | add-only: {add_elapsed:.3f} s "
f"for {len(update_chunks)} passages => {per_chunk * 1e3:.3f} ms/passage"
)
print(f"ZMQ node fetch total: {zmq_count}")
results_total[name].append(total_elapsed)
results_add[name].append(add_elapsed)
results_zmq[name].append(zmq_count)
results_ms_per_passage[name].append(per_chunk * 1e3)
results_stageA[name].append(stageA)
results_stageBC[name].append(stageBC)
# Append row to CSV
if args.csv_path:
row = {
"run_id": run_id,
"scenario": name,
"cache_enabled": 1 if cache_enabled else 0,
"ef_construction": args.ef_construction,
"max_initial": args.max_initial,
"max_updates": args.max_updates,
"total_time_s": round(total_elapsed, 6),
"add_only_s": round(add_elapsed, 6),
"latency_ms_per_passage": round(per_chunk * 1e3, 6),
"zmq_nodes": int(zmq_count),
"stageA_time_s": round(stageA, 6),
"stageBC_time_s": round(stageBC, 6),
"model_name": args.model_name,
"embedding_mode": args.embedding_mode,
"distance_metric": args.distance_metric,
}
with args.csv_path.open("a", newline="", encoding="utf-8") as f:
writer = csv.DictWriter(f, fieldnames=csv_fields)
writer.writerow(row)
print("\n=== Summary ===")
for name in results_add:
add_values = results_add[name]
total_values = results_total[name]
zmq_values = results_zmq[name]
latency_values = results_ms_per_passage[name]
if not add_values:
print(f"{name}: no successful runs")
continue
avg_add = sum(add_values) / len(add_values)
avg_total = sum(total_values) / len(total_values)
avg_zmq = sum(zmq_values) / len(zmq_values) if zmq_values else 0.0
avg_latency = sum(latency_values) / len(latency_values) if latency_values else 0.0
runs = len(add_values)
print(
f"{name}: add-only avg {avg_add:.3f} s | total avg {avg_total:.3f} s "
f"| ZMQ avg {avg_zmq:.1f} node fetches | latency {avg_latency:.2f} ms/passage over {runs} run(s)"
)
if args.plot_path:
try:
import matplotlib.pyplot as plt
labels = [name for name, *_ in scenarios]
values = [
sum(results_ms_per_passage[name]) / len(results_ms_per_passage[name])
if results_ms_per_passage[name]
else 0.0
for name in labels
]
def _auto_cap(vals: list[float]) -> float | None:
s = sorted(vals, reverse=True)
if len(s) < 2:
return None
if s[1] > 0 and s[0] >= 2.5 * s[1]:
return s[1] * 1.1
return None
def _fmt_ms(v: float) -> str:
return f"{v / 1000:.1f}k" if v >= 1000 else f"{v:.1f}"
colors = ["#4e79a7", "#f28e2c", "#e15759", "#76b7b2"]
if args.broken_y:
s = sorted(values, reverse=True)
second = s[1] if len(s) >= 2 else (s[0] if s else 0.0)
lower_cap = args.lower_cap_y if args.lower_cap_y is not None else second * 1.1
upper_start = (
args.upper_start_y
if args.upper_start_y is not None
else max(second * 1.2, lower_cap * 1.02)
)
ymax = max(values) * 1.10 if values else 1.0
fig, (ax_top, ax_bottom) = plt.subplots(
2,
1,
sharex=True,
figsize=(7.4, 5.0),
gridspec_kw={"height_ratios": [1, 3], "hspace": 0.05},
)
x = list(range(len(labels)))
ax_bottom.bar(x, values, color=colors[: len(labels)], width=0.8)
ax_top.bar(x, values, color=colors[: len(labels)], width=0.8)
ax_bottom.set_ylim(0, lower_cap)
ax_top.set_ylim(upper_start, ymax)
for i, v in enumerate(values):
if v <= lower_cap:
ax_bottom.text(
i,
v + lower_cap * 0.02,
_fmt_ms(v),
ha="center",
va="bottom",
fontsize=9,
)
else:
ax_top.text(i, v, _fmt_ms(v), ha="center", va="bottom", fontsize=9)
ax_top.spines["bottom"].set_visible(False)
ax_bottom.spines["top"].set_visible(False)
ax_top.tick_params(labeltop=False)
ax_bottom.xaxis.tick_bottom()
d = 0.015
kwargs = {"transform": ax_top.transAxes, "color": "k", "clip_on": False}
ax_top.plot((-d, +d), (-d, +d), **kwargs)
ax_top.plot((1 - d, 1 + d), (-d, +d), **kwargs)
kwargs.update({"transform": ax_bottom.transAxes})
ax_bottom.plot((-d, +d), (1 - d, 1 + d), **kwargs)
ax_bottom.plot((1 - d, 1 + d), (1 - d, 1 + d), **kwargs)
ax_bottom.set_xticks(range(len(labels)))
ax_bottom.set_xticklabels(labels)
ax = ax_bottom
else:
cap = args.cap_y or _auto_cap(values)
plt.figure(figsize=(7.2, 4.2))
ax = plt.gca()
if cap is not None:
show_vals = [min(v, cap) for v in values]
bars = []
for i, (v, show) in enumerate(zip(values, show_vals)):
b = ax.bar(i, show, color=colors[i], width=0.8)
bars.append(b[0])
if v > cap:
bars[-1].set_hatch("//")
ax.text(i, cap * 1.02, _fmt_ms(v), ha="center", va="bottom", fontsize=9)
else:
ax.text(
i,
show + max(1.0, 0.01 * (cap or show)),
_fmt_ms(v),
ha="center",
va="bottom",
fontsize=9,
)
ax.set_ylim(0, cap * 1.10)
ax.plot(
[0.02 - 0.02, 0.02 + 0.02],
[0.98 + 0.02, 0.98 - 0.02],
transform=ax.transAxes,
color="k",
lw=1,
)
ax.plot(
[0.98 - 0.02, 0.98 + 0.02],
[0.98 + 0.02, 0.98 - 0.02],
transform=ax.transAxes,
color="k",
lw=1,
)
if any(v > cap for v in values):
ax.legend(
[bars[0]], ["capped"], fontsize=8, frameon=False, loc="upper right"
)
ax.set_xticks(range(len(labels)))
ax.set_xticklabels(labels)
else:
ax.bar(labels, values, color=colors[: len(labels)])
for idx, val in enumerate(values):
ax.text(idx, val + 1.0, f"{val:.1f}", ha="center", va="bottom")
plt.ylabel("Average add latency (ms per passage)")
plt.title(f"Initial passages {args.max_initial}, updates {args.max_updates}")
plt.tight_layout()
plt.savefig(args.plot_path)
print(f"Saved latency bar plot to {args.plot_path}")
# ZMQ time split (Stage A vs B/C)
try:
plt.figure(figsize=(6, 4))
a_vals = [sum(results_stageA[n]) / max(1, len(results_stageA[n])) for n in labels]
bc_vals = [
sum(results_stageBC[n]) / max(1, len(results_stageBC[n])) for n in labels
]
ind = range(len(labels))
plt.bar(ind, a_vals, color="#4e79a7", label="Stage A distance (s)")
plt.bar(
ind, bc_vals, bottom=a_vals, color="#e15759", label="Stage B/C embed-by-id (s)"
)
plt.xticks(list(ind), labels, rotation=10)
plt.ylabel("Server ZMQ time (s)")
plt.title(
f"ZMQ time split (initial {args.max_initial}, updates {args.max_updates})"
)
plt.legend()
out2 = args.plot_path.with_name(
args.plot_path.stem + "_zmq_split" + args.plot_path.suffix
)
plt.tight_layout()
plt.savefig(out2)
print(f"Saved ZMQ time split plot to {out2}")
except Exception as e:
print("Failed to plot ZMQ split:", e)
except ImportError:
print("matplotlib not available; skipping plot generation")
# leave the last build on disk for inspection
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,5 @@
run_id,scenario,cache_enabled,ef_construction,max_initial,max_updates,total_time_s,add_only_s,latency_ms_per_passage,zmq_nodes,stageA_time_s,stageBC_time_s,model_name,embedding_mode,distance_metric
20251024-133101,baseline,1,200,300,1,3.391856,1.120359,1120.359421,126,0.507821,0.601608,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips
20251024-133101,no_cache_baseline,0,200,300,1,34.941514,32.91376,32913.760185,4033,0.506933,32.159928,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips
20251024-133101,disable_forward_rng,1,200,300,1,2.746756,0.8202,820.200443,66,0.474354,0.338454,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips
20251024-133101,disable_forward_and_reverse_rng,1,200,300,1,2.396566,0.521478,521.478415,1,0.508973,0.006938,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips
1 run_id scenario cache_enabled ef_construction max_initial max_updates total_time_s add_only_s latency_ms_per_passage zmq_nodes stageA_time_s stageBC_time_s model_name embedding_mode distance_metric
2 20251024-133101 baseline 1 200 300 1 3.391856 1.120359 1120.359421 126 0.507821 0.601608 sentence-transformers/all-MiniLM-L6-v2 sentence-transformers mips
3 20251024-133101 no_cache_baseline 0 200 300 1 34.941514 32.91376 32913.760185 4033 0.506933 32.159928 sentence-transformers/all-MiniLM-L6-v2 sentence-transformers mips
4 20251024-133101 disable_forward_rng 1 200 300 1 2.746756 0.8202 820.200443 66 0.474354 0.338454 sentence-transformers/all-MiniLM-L6-v2 sentence-transformers mips
5 20251024-133101 disable_forward_and_reverse_rng 1 200 300 1 2.396566 0.521478 521.478415 1 0.508973 0.006938 sentence-transformers/all-MiniLM-L6-v2 sentence-transformers mips

View File

@@ -0,0 +1,704 @@
"""
Compare two latency models for small incremental updates vs. search:
Scenario A (sequential update then search):
- Build initial HNSW (is_recompute=True)
- Start embedding server (ZMQ) for recompute
- Add N passages one-by-one (each triggers recompute over ZMQ)
- Then run a search query on the updated index
- Report total time = sum(add_i) + search_time, with breakdowns
Scenario B (offline embeds + concurrent search; no graph updates):
- Do NOT insert the N passages into the graph
- In parallel: (1) compute embeddings for the N passages; (2) compute query
embedding and run a search on the existing index
- After both finish, compute similarity between the query embedding and the N
new passage embeddings, merge with the index search results by score, and
report time = max(embed_time, search_time) (i.e., no blocking on updates)
This script reuses the model/data loading conventions of
examples/bench_hnsw_rng_recompute.py but focuses on end-to-end latency
comparison for the two execution strategies above.
Example (from the repository root):
uv run -m benchmarks.update.bench_update_vs_offline_search \
--index-path .leann/bench/offline_vs_update.leann \
--max-initial 300 --num-updates 5 --k 10
"""
import argparse
import csv
import json
import logging
import os
import pickle
import sys
import threading
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Any
import numpy as np
import psutil # type: ignore
from leann.api import LeannBuilder
if os.environ.get("LEANN_FORCE_CPU", "").lower() in ("1", "true", "yes"):
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "")
from leann.embedding_compute import compute_embeddings
from leann.embedding_server_manager import EmbeddingServerManager
from leann.registry import register_project_directory
from leann_backend_hnsw import faiss # type: ignore
logger = logging.getLogger(__name__)
if not logging.getLogger().handlers:
logging.basicConfig(level=logging.INFO)
def _find_repo_root() -> Path:
"""Locate project root by walking up until pyproject.toml is found."""
current = Path(__file__).resolve()
for parent in current.parents:
if (parent / "pyproject.toml").exists():
return parent
# Fallback: assume repo is two levels up (../..)
return current.parents[2]
REPO_ROOT = _find_repo_root()
if str(REPO_ROOT) not in sys.path:
sys.path.insert(0, str(REPO_ROOT))
from apps.chunking import create_text_chunks # noqa: E402
DEFAULT_INITIAL_FILES = [
REPO_ROOT / "data" / "2501.14312v1 (1).pdf",
REPO_ROOT / "data" / "huawei_pangu.md",
]
DEFAULT_UPDATE_FILES = [REPO_ROOT / "data" / "2506.08276v1.pdf"]
def load_chunks_from_files(paths: list[Path], limit: int | None = None) -> list[str]:
from llama_index.core import SimpleDirectoryReader
documents = []
for path in paths:
p = path.expanduser().resolve()
if not p.exists():
raise FileNotFoundError(f"Input path not found: {p}")
if p.is_dir():
reader = SimpleDirectoryReader(str(p), recursive=False)
documents.extend(reader.load_data(show_progress=True))
else:
reader = SimpleDirectoryReader(input_files=[str(p)])
documents.extend(reader.load_data(show_progress=True))
if not documents:
return []
chunks = create_text_chunks(
documents,
chunk_size=512,
chunk_overlap=128,
use_ast_chunking=False,
)
cleaned = [c for c in chunks if isinstance(c, str) and c.strip()]
if limit is not None:
cleaned = cleaned[:limit]
return cleaned
def ensure_index_dir(index_path: Path) -> None:
index_path.parent.mkdir(parents=True, exist_ok=True)
def cleanup_index_files(index_path: Path) -> None:
parent = index_path.parent
if not parent.exists():
return
stem = index_path.stem
for file in parent.glob(f"{stem}*"):
if file.is_file():
file.unlink()
def build_initial_index(
index_path: Path,
paragraphs: list[str],
model_name: str,
embedding_mode: str,
distance_metric: str,
ef_construction: int,
) -> None:
builder = LeannBuilder(
backend_name="hnsw",
embedding_model=model_name,
embedding_mode=embedding_mode,
is_compact=False,
is_recompute=True,
distance_metric=distance_metric,
backend_kwargs={
"distance_metric": distance_metric,
"is_compact": False,
"is_recompute": True,
"efConstruction": ef_construction,
},
)
for idx, passage in enumerate(paragraphs):
builder.add_text(passage, metadata={"id": str(idx)})
builder.build_index(str(index_path))
def _maybe_norm_cosine(vecs: np.ndarray, metric: str) -> np.ndarray:
if metric == "cosine":
vecs = np.ascontiguousarray(vecs, dtype=np.float32)
norms = np.linalg.norm(vecs, axis=1, keepdims=True)
norms[norms == 0] = 1
vecs = vecs / norms
return vecs
def _read_index_for_search(index_path: Path) -> Any:
index_file = index_path.parent / f"{index_path.stem}.index"
# Force-disable experimental disk cache when loading the index so that
# incremental benchmarks don't pick up stale top-degree bitmaps.
cfg = faiss.HNSWIndexConfig()
cfg.is_recompute = True
if hasattr(cfg, "disk_cache_ratio"):
cfg.disk_cache_ratio = 0.0
if hasattr(cfg, "external_storage_path"):
cfg.external_storage_path = None
io_flags = getattr(faiss, "IO_FLAG_MMAP", 0)
index = faiss.read_index(str(index_file), io_flags, cfg)
# ensure recompute mode persists after reload
try:
index.is_recompute = True
except AttributeError:
pass
try:
actual_ntotal = index.hnsw.levels.size()
except AttributeError:
actual_ntotal = index.ntotal
if actual_ntotal != index.ntotal:
print(
f"[bench_update_vs_offline_search] Correcting ntotal from {index.ntotal} to {actual_ntotal}",
flush=True,
)
index.ntotal = actual_ntotal
if getattr(index, "storage", None) is None:
if index.metric_type == faiss.METRIC_INNER_PRODUCT:
storage_index = faiss.IndexFlatIP(index.d)
else:
storage_index = faiss.IndexFlatL2(index.d)
index.storage = storage_index
index.own_fields = True
return index
def _append_passages_for_updates(
meta_path: Path,
start_id: int,
texts: list[str],
) -> list[str]:
"""Append update passages so the embedding server can serve recompute fetches."""
if not texts:
return []
index_dir = meta_path.parent
meta_name = meta_path.name
if not meta_name.endswith(".meta.json"):
raise ValueError(f"Unexpected meta filename: {meta_path}")
index_base = meta_name[: -len(".meta.json")]
passages_file = index_dir / f"{index_base}.passages.jsonl"
offsets_file = index_dir / f"{index_base}.passages.idx"
if not passages_file.exists() or not offsets_file.exists():
raise FileNotFoundError(
"Passage store missing; cannot register update passages for recompute mode."
)
with open(offsets_file, "rb") as f:
offset_map: dict[str, int] = pickle.load(f)
assigned_ids: list[str] = []
with open(passages_file, "a", encoding="utf-8") as f:
for i, text in enumerate(texts):
passage_id = str(start_id + i)
offset = f.tell()
json.dump({"id": passage_id, "text": text, "metadata": {}}, f, ensure_ascii=False)
f.write("\n")
offset_map[passage_id] = offset
assigned_ids.append(passage_id)
with open(offsets_file, "wb") as f:
pickle.dump(offset_map, f)
try:
with open(meta_path, encoding="utf-8") as f:
meta = json.load(f)
except json.JSONDecodeError:
meta = {}
meta["total_passages"] = len(offset_map)
with open(meta_path, "w", encoding="utf-8") as f:
json.dump(meta, f, indent=2)
return assigned_ids
def _search(index: Any, q: np.ndarray, k: int) -> tuple[np.ndarray, np.ndarray]:
q = np.ascontiguousarray(q, dtype=np.float32)
distances = np.zeros((1, k), dtype=np.float32)
indices = np.zeros((1, k), dtype=np.int64)
index.search(
1,
faiss.swig_ptr(q),
k,
faiss.swig_ptr(distances),
faiss.swig_ptr(indices),
)
return distances[0], indices[0]
def _score_for_metric(dist: float, metric: str) -> float:
# Convert FAISS distance to a "higher is better" score
if metric in ("mips", "cosine"):
return float(dist)
# l2 distance (smaller better) -> negative distance as score
return -float(dist)
def _merge_results(
index_results: tuple[np.ndarray, np.ndarray],
offline_scores: list[tuple[int, float]],
k: int,
metric: str,
) -> list[tuple[str, float]]:
distances, indices = index_results
merged: list[tuple[str, float]] = []
for distance, idx in zip(distances.tolist(), indices.tolist()):
merged.append((f"idx:{idx}", _score_for_metric(distance, metric)))
for j, s in offline_scores:
merged.append((f"offline:{j}", s))
merged.sort(key=lambda x: x[1], reverse=True)
return merged[:k]
@dataclass
class ScenarioResult:
name: str
update_total_s: float
search_s: float
overall_s: float
def main() -> None:
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--index-path",
type=Path,
default=Path(".leann/bench/offline-vs-update.leann"),
)
parser.add_argument(
"--initial-files",
nargs="*",
type=Path,
default=DEFAULT_INITIAL_FILES,
)
parser.add_argument(
"--update-files",
nargs="*",
type=Path,
default=DEFAULT_UPDATE_FILES,
)
parser.add_argument("--max-initial", type=int, default=300)
parser.add_argument("--num-updates", type=int, default=5)
parser.add_argument("--k", type=int, default=10, help="Top-k for search/merge")
parser.add_argument(
"--query",
type=str,
default="neural network",
help="Query text used for the search benchmark.",
)
parser.add_argument("--server-port", type=int, default=5557)
parser.add_argument("--add-timeout", type=int, default=600)
parser.add_argument("--model-name", default="sentence-transformers/all-MiniLM-L6-v2")
parser.add_argument("--embedding-mode", default="sentence-transformers")
parser.add_argument(
"--distance-metric",
default="mips",
choices=["mips", "l2", "cosine"],
)
parser.add_argument("--ef-construction", type=int, default=200)
parser.add_argument(
"--only",
choices=["A", "B", "both"],
default="both",
help="Run only Scenario A, Scenario B, or both",
)
parser.add_argument(
"--csv-path",
type=Path,
default=Path("benchmarks/update/offline_vs_update.csv"),
help="Where to append results (CSV).",
)
args = parser.parse_args()
register_project_directory(REPO_ROOT)
# Load data
initial_paragraphs = load_chunks_from_files(args.initial_files, args.max_initial)
update_paragraphs = load_chunks_from_files(args.update_files, None)
if not update_paragraphs:
raise ValueError("No update passages loaded from --update-files")
update_paragraphs = update_paragraphs[: args.num_updates]
if len(update_paragraphs) < args.num_updates:
raise ValueError(
f"Not enough update passages ({len(update_paragraphs)}) for --num-updates={args.num_updates}"
)
ensure_index_dir(args.index_path)
cleanup_index_files(args.index_path)
# Build initial index
build_initial_index(
args.index_path,
initial_paragraphs,
args.model_name,
args.embedding_mode,
args.distance_metric,
args.ef_construction,
)
# Prepare index object and meta
meta_path = args.index_path.parent / f"{args.index_path.name}.meta.json"
index = _read_index_for_search(args.index_path)
# CSV setup
run_id = time.strftime("%Y%m%d-%H%M%S")
if args.csv_path:
args.csv_path.parent.mkdir(parents=True, exist_ok=True)
csv_fields = [
"run_id",
"scenario",
"max_initial",
"num_updates",
"k",
"total_time_s",
"add_total_s",
"search_time_s",
"emb_time_s",
"makespan_s",
"model_name",
"embedding_mode",
"distance_metric",
]
if not args.csv_path.exists() or args.csv_path.stat().st_size == 0:
with args.csv_path.open("w", newline="", encoding="utf-8") as f:
writer = csv.DictWriter(f, fieldnames=csv_fields)
writer.writeheader()
# Debug: list existing HNSW server PIDs before starting
try:
existing = [
p
for p in psutil.process_iter(attrs=["pid", "cmdline"])
if any(
isinstance(arg, str) and "leann_backend_hnsw.hnsw_embedding_server" in arg
for arg in (p.info.get("cmdline") or [])
)
]
if existing:
print("[debug] Found existing hnsw_embedding_server processes before run:")
for p in existing:
print(f"[debug] PID={p.info['pid']} cmd={' '.join(p.info.get('cmdline') or [])}")
except Exception as _e:
pass
add_total = 0.0
search_after_add = 0.0
total_seq = 0.0
port_a = None
if args.only in ("A", "both"):
# Scenario A: sequential update then search
start_id = index.ntotal
assigned_ids = _append_passages_for_updates(meta_path, start_id, update_paragraphs)
if assigned_ids:
logger.debug(
"Registered %d update passages starting at id %s",
len(assigned_ids),
assigned_ids[0],
)
server_manager = EmbeddingServerManager(
backend_module_name="leann_backend_hnsw.hnsw_embedding_server"
)
ok, port = server_manager.start_server(
port=args.server_port,
model_name=args.model_name,
embedding_mode=args.embedding_mode,
passages_file=str(meta_path),
distance_metric=args.distance_metric,
)
if not ok:
raise RuntimeError("Failed to start embedding server")
try:
# Set ZMQ port for recompute mode
if hasattr(index.hnsw, "set_zmq_port"):
index.hnsw.set_zmq_port(port)
elif hasattr(index, "set_zmq_port"):
index.set_zmq_port(port)
# Start A overall timer BEFORE computing update embeddings
t0 = time.time()
# Compute embeddings for updates (counted into A's overall)
t_emb0 = time.time()
upd_embs = compute_embeddings(
update_paragraphs,
args.model_name,
mode=args.embedding_mode,
is_build=False,
batch_size=16,
)
emb_time_updates = time.time() - t_emb0
upd_embs = np.asarray(upd_embs, dtype=np.float32)
upd_embs = _maybe_norm_cosine(upd_embs, args.distance_metric)
# Perform sequential adds
for i in range(upd_embs.shape[0]):
t_add0 = time.time()
index.add(1, faiss.swig_ptr(upd_embs[i : i + 1]))
add_total += time.time() - t_add0
# Don't persist index after adds to avoid contaminating Scenario B
# index_file = args.index_path.parent / f"{args.index_path.stem}.index"
# faiss.write_index(index, str(index_file))
# Search after updates
q_emb = compute_embeddings(
[args.query], args.model_name, mode=args.embedding_mode, is_build=False
)
q_emb = np.asarray(q_emb, dtype=np.float32)
q_emb = _maybe_norm_cosine(q_emb, args.distance_metric)
# Warm up search with a dummy query first
print("[DEBUG] Warming up search...")
_ = _search(index, q_emb, 1)
t_s0 = time.time()
D_upd, I_upd = _search(index, q_emb, args.k)
search_after_add = time.time() - t_s0
total_seq = time.time() - t0
finally:
server_manager.stop_server()
port_a = port
print("\n=== Scenario A: update->search (sequential) ===")
# emb_time_updates is defined only when A runs
try:
_emb_a = emb_time_updates
except NameError:
_emb_a = 0.0
print(
f"Adds: {args.num_updates} passages; embeds={_emb_a:.3f}s; add_total={add_total:.3f}s; "
f"search={search_after_add:.3f}s; overall={total_seq:.3f}s"
)
# CSV row for A
if args.csv_path:
row_a = {
"run_id": run_id,
"scenario": "A",
"max_initial": args.max_initial,
"num_updates": args.num_updates,
"k": args.k,
"total_time_s": round(total_seq, 6),
"add_total_s": round(add_total, 6),
"search_time_s": round(search_after_add, 6),
"emb_time_s": round(_emb_a, 6),
"makespan_s": 0.0,
"model_name": args.model_name,
"embedding_mode": args.embedding_mode,
"distance_metric": args.distance_metric,
}
with args.csv_path.open("a", newline="", encoding="utf-8") as f:
writer = csv.DictWriter(f, fieldnames=csv_fields)
writer.writerow(row_a)
# Verify server cleanup
try:
# short sleep to allow signal handling to finish
time.sleep(0.5)
leftovers = [
p
for p in psutil.process_iter(attrs=["pid", "cmdline"])
if any(
isinstance(arg, str) and "leann_backend_hnsw.hnsw_embedding_server" in arg
for arg in (p.info.get("cmdline") or [])
)
]
if leftovers:
print("[warn] hnsw_embedding_server process(es) still alive after A-stop:")
for p in leftovers:
print(
f"[warn] PID={p.info['pid']} cmd={' '.join(p.info.get('cmdline') or [])}"
)
else:
print("[debug] server cleanup confirmed: no hnsw_embedding_server found")
except Exception:
pass
# Scenario B: offline embeds + concurrent search (no graph updates)
if args.only in ("B", "both"):
# ensure a server is available for recompute search
server_manager_b = EmbeddingServerManager(
backend_module_name="leann_backend_hnsw.hnsw_embedding_server"
)
requested_port = args.server_port if port_a is None else port_a
ok_b, port_b = server_manager_b.start_server(
port=requested_port,
model_name=args.model_name,
embedding_mode=args.embedding_mode,
passages_file=str(meta_path),
distance_metric=args.distance_metric,
)
if not ok_b:
raise RuntimeError("Failed to start embedding server for Scenario B")
# Wait for server to fully initialize
print("[DEBUG] Waiting 2s for embedding server to fully initialize...")
time.sleep(2)
try:
# Read the index first
index_no_update = _read_index_for_search(args.index_path) # unchanged index
# Then configure ZMQ port on the correct index object
if hasattr(index_no_update.hnsw, "set_zmq_port"):
index_no_update.hnsw.set_zmq_port(port_b)
elif hasattr(index_no_update, "set_zmq_port"):
index_no_update.set_zmq_port(port_b)
# Warmup the embedding model before benchmarking (do this for both --only B and --only both)
# This ensures fair comparison as Scenario A has warmed up the model during update embeddings
logger.info("Warming up embedding model for Scenario B...")
_ = compute_embeddings(
["warmup text"], args.model_name, mode=args.embedding_mode, is_build=False
)
# Prepare worker A: compute embeddings for the same N passages
emb_time = 0.0
updates_embs_offline: np.ndarray | None = None
def _worker_emb():
nonlocal emb_time, updates_embs_offline
t = time.time()
updates_embs_offline = compute_embeddings(
update_paragraphs,
args.model_name,
mode=args.embedding_mode,
is_build=False,
batch_size=16,
)
emb_time = time.time() - t
# Pre-compute query embedding and warm up search outside of timed section.
q_vec = compute_embeddings(
[args.query], args.model_name, mode=args.embedding_mode, is_build=False
)
q_vec = np.asarray(q_vec, dtype=np.float32)
q_vec = _maybe_norm_cosine(q_vec, args.distance_metric)
print("[DEBUG B] Warming up search...")
_ = _search(index_no_update, q_vec, 1)
# Worker B: timed search on the warmed index
search_time = 0.0
offline_elapsed = 0.0
index_results: tuple[np.ndarray, np.ndarray] | None = None
def _worker_search():
nonlocal search_time, index_results
t = time.time()
distances, indices = _search(index_no_update, q_vec, args.k)
search_time = time.time() - t
index_results = (distances, indices)
# Run two workers concurrently
t0 = time.time()
th1 = threading.Thread(target=_worker_emb)
th2 = threading.Thread(target=_worker_search)
th1.start()
th2.start()
th1.join()
th2.join()
offline_elapsed = time.time() - t0
# For mixing: compute query vs. offline update similarities (pure client-side)
offline_scores: list[tuple[int, float]] = []
if updates_embs_offline is not None:
upd2 = np.asarray(updates_embs_offline, dtype=np.float32)
upd2 = _maybe_norm_cosine(upd2, args.distance_metric)
# For mips/cosine, score = dot; for l2, score = -||x-y||^2
for j in range(upd2.shape[0]):
if args.distance_metric in ("mips", "cosine"):
s = float(np.dot(q_vec[0], upd2[j]))
else:
diff = q_vec[0] - upd2[j]
s = -float(np.dot(diff, diff))
offline_scores.append((j, s))
merged_topk = (
_merge_results(index_results, offline_scores, args.k, args.distance_metric)
if index_results
else []
)
print("\n=== Scenario B: offline embeds + concurrent search (no add) ===")
print(
f"embeddings({args.num_updates})={emb_time:.3f}s; search={search_time:.3f}s; makespan≈{offline_elapsed:.3f}s (≈max)"
)
if merged_topk:
preview = ", ".join([f"{lab}:{score:.3f}" for lab, score in merged_topk[:5]])
print(f"Merged top-5 preview: {preview}")
# CSV row for B
if args.csv_path:
row_b = {
"run_id": run_id,
"scenario": "B",
"max_initial": args.max_initial,
"num_updates": args.num_updates,
"k": args.k,
"total_time_s": 0.0,
"add_total_s": 0.0,
"search_time_s": round(search_time, 6),
"emb_time_s": round(emb_time, 6),
"makespan_s": round(offline_elapsed, 6),
"model_name": args.model_name,
"embedding_mode": args.embedding_mode,
"distance_metric": args.distance_metric,
}
with args.csv_path.open("a", newline="", encoding="utf-8") as f:
writer = csv.DictWriter(f, fieldnames=csv_fields)
writer.writerow(row_b)
finally:
server_manager_b.stop_server()
# Summary
print("\n=== Summary ===")
msg_a = (
f"A: seq-add+search overall={total_seq:.3f}s (adds={add_total:.3f}s, search={search_after_add:.3f}s)"
if args.only in ("A", "both")
else "A: skipped"
)
msg_b = (
f"B: offline+concurrent overall≈{offline_elapsed:.3f}s (emb={emb_time:.3f}s, search={search_time:.3f}s)"
if args.only in ("B", "both")
else "B: skipped"
)
print(msg_a + "\n" + msg_b)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,5 @@
run_id,scenario,max_initial,num_updates,k,total_time_s,add_total_s,search_time_s,emb_time_s,makespan_s,model_name,embedding_mode,distance_metric
20251024-141607,A,300,1,10,3.273957,3.050168,0.097825,0.017339,0.0,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips
20251024-141607,B,300,1,10,0.0,0.0,0.111892,0.007869,0.112635,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips
20251025-160652,A,300,5,10,5.061945,4.805962,0.123271,0.015008,0.0,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips
20251025-160652,B,300,5,10,0.0,0.0,0.101809,0.008817,0.102447,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips
1 run_id scenario max_initial num_updates k total_time_s add_total_s search_time_s emb_time_s makespan_s model_name embedding_mode distance_metric
2 20251024-141607 A 300 1 10 3.273957 3.050168 0.097825 0.017339 0.0 sentence-transformers/all-MiniLM-L6-v2 sentence-transformers mips
3 20251024-141607 B 300 1 10 0.0 0.0 0.111892 0.007869 0.112635 sentence-transformers/all-MiniLM-L6-v2 sentence-transformers mips
4 20251025-160652 A 300 5 10 5.061945 4.805962 0.123271 0.015008 0.0 sentence-transformers/all-MiniLM-L6-v2 sentence-transformers mips
5 20251025-160652 B 300 5 10 0.0 0.0 0.101809 0.008817 0.102447 sentence-transformers/all-MiniLM-L6-v2 sentence-transformers mips

View File

@@ -0,0 +1,645 @@
#!/usr/bin/env python3
"""
Plot latency bars from the benchmark CSV produced by
benchmarks/update/bench_hnsw_rng_recompute.py.
If you also provide an offline_vs_update.csv via --csv-right
(from benchmarks/update/bench_update_vs_offline_search.py), this script will
output a side-by-side figure:
- Left: ms/passage bars (four RNG scenarios).
- Right: seconds bars (Scenario A seq add+search vs Scenario B offline+search).
Usage:
uv run python benchmarks/update/plot_bench_results.py \
--csv benchmarks/update/bench_results.csv \
--out benchmarks/update/bench_latency_from_csv.png
The script selects the latest run_id in the CSV and plots four bars for
the default scenarios:
- baseline
- no_cache_baseline
- disable_forward_rng
- disable_forward_and_reverse_rng
If multiple rows exist per scenario for that run_id, the script averages
their latency_ms_per_passage values.
"""
import argparse
import csv
from collections import defaultdict
from pathlib import Path
DEFAULT_SCENARIOS = [
"no_cache_baseline",
"baseline",
"disable_forward_rng",
"disable_forward_and_reverse_rng",
]
SCENARIO_LABELS = {
"baseline": "+ Cache",
"no_cache_baseline": "Naive \n Recompute",
"disable_forward_rng": "+ w/o \n Fwd RNG",
"disable_forward_and_reverse_rng": "+ w/o \n Bwd RNG",
}
# Paper-style colors and hatches for scenarios
SCENARIO_STYLES = {
"no_cache_baseline": {"edgecolor": "dimgrey", "hatch": "/////"},
"baseline": {"edgecolor": "#63B8B6", "hatch": "xxxxx"},
"disable_forward_rng": {"edgecolor": "green", "hatch": "....."},
"disable_forward_and_reverse_rng": {"edgecolor": "tomato", "hatch": "\\\\\\\\\\"},
}
def load_latest_run(csv_path: Path):
rows = []
with csv_path.open("r", encoding="utf-8") as f:
reader = csv.DictReader(f)
for row in reader:
rows.append(row)
if not rows:
raise SystemExit("CSV is empty: no rows to plot")
# Choose latest run_id lexicographically (YYYYMMDD-HHMMSS)
run_ids = [r.get("run_id", "") for r in rows]
latest = max(run_ids)
latest_rows = [r for r in rows if r.get("run_id", "") == latest]
if not latest_rows:
# Fallback: take last 4 rows
latest_rows = rows[-4:]
latest = latest_rows[-1].get("run_id", "unknown")
return latest, latest_rows
def aggregate_latency(rows):
acc = defaultdict(list)
for r in rows:
sc = r.get("scenario", "")
try:
val = float(r.get("latency_ms_per_passage", "nan"))
except ValueError:
continue
acc[sc].append(val)
avg = {k: (sum(v) / len(v) if v else 0.0) for k, v in acc.items()}
return avg
def _auto_cap(values: list[float]) -> float | None:
if not values:
return None
sorted_vals = sorted(values, reverse=True)
if len(sorted_vals) < 2:
return None
max_v, second = sorted_vals[0], sorted_vals[1]
if second <= 0:
return None
# If the tallest bar dwarfs the second by 2.5x+, cap near the second
if max_v >= 2.5 * second:
return second * 1.1
return None
def _add_break_marker(ax, y, rel_x0=0.02, rel_x1=0.98, size=0.02):
# Draw small diagonal ticks near left/right to signal cap
x0, x1 = rel_x0, rel_x1
ax.plot([x0 - size, x0 + size], [y + size, y - size], transform=ax.transAxes, color="k", lw=1)
ax.plot([x1 - size, x1 + size], [y + size, y - size], transform=ax.transAxes, color="k", lw=1)
def _fmt_ms(v: float) -> str:
if v >= 1000:
return f"{v / 1000:.1f}k"
return f"{v:.1f}"
def main():
# Set LaTeX style for paper figures (matching paper_fig.py)
import matplotlib.pyplot as plt
plt.rcParams["font.family"] = "Helvetica"
plt.rcParams["ytick.direction"] = "in"
plt.rcParams["hatch.linewidth"] = 1.5
plt.rcParams["font.weight"] = "bold"
plt.rcParams["axes.labelweight"] = "bold"
plt.rcParams["text.usetex"] = True
ap = argparse.ArgumentParser(description=__doc__)
ap.add_argument(
"--csv",
type=Path,
default=Path("benchmarks/update/bench_results.csv"),
help="Path to results CSV (defaults to bench_results.csv)",
)
ap.add_argument(
"--out",
type=Path,
default=Path("add_ablation.pdf"),
help="Output image path",
)
ap.add_argument(
"--csv-right",
type=Path,
default=Path("benchmarks/update/offline_vs_update.csv"),
help="Optional: offline_vs_update.csv to render right subplot (A vs B)",
)
ap.add_argument(
"--cap-y",
type=float,
default=None,
help="Cap Y-axis at this ms value; bars above are hatched and annotated.",
)
ap.add_argument(
"--no-auto-cap",
action="store_true",
help="Disable auto-cap heuristic when --cap-y is not provided.",
)
ap.add_argument(
"--broken-y",
action="store_true",
default=True,
help="Use a broken Y-axis (two stacked axes with a gap). Overrides --cap-y unless both provided.",
)
ap.add_argument(
"--lower-cap-y",
type=float,
default=None,
help="Lower axes upper bound for broken Y (ms). Default = 1.1x second-highest.",
)
ap.add_argument(
"--upper-start-y",
type=float,
default=None,
help="Upper axes lower bound for broken Y (ms). Default = 1.2x second-highest.",
)
args = ap.parse_args()
latest_run, latest_rows = load_latest_run(args.csv)
avg = aggregate_latency(latest_rows)
try:
import matplotlib.pyplot as plt
except Exception as e:
raise SystemExit(f"matplotlib not available: {e}")
scenarios = DEFAULT_SCENARIOS
values = [avg.get(name, 0.0) for name in scenarios]
labels = [SCENARIO_LABELS.get(name, name) for name in scenarios]
colors = ["#4e79a7", "#f28e2c", "#e15759", "#76b7b2"]
# If right CSV is provided, build side-by-side figure
if args.csv_right is not None:
try:
right_rows_all = []
with args.csv_right.open("r", encoding="utf-8") as f:
rreader = csv.DictReader(f)
right_rows_all = list(rreader)
if right_rows_all:
r_latest = max(r.get("run_id", "") for r in right_rows_all)
right_rows = [r for r in right_rows_all if r.get("run_id", "") == r_latest]
else:
r_latest = None
right_rows = []
except Exception:
r_latest = None
right_rows = []
a_total = 0.0
b_makespan = 0.0
for r in right_rows:
sc = (r.get("scenario", "") or "").strip().upper()
if sc == "A":
try:
a_total = float(r.get("total_time_s", 0.0))
except Exception:
pass
elif sc == "B":
try:
b_makespan = float(r.get("makespan_s", 0.0))
except Exception:
pass
import matplotlib.pyplot as plt
from matplotlib import gridspec
# Left subplot (reuse current style, with optional cap)
cap = args.cap_y
if cap is None and not args.no_auto_cap:
cap = _auto_cap(values)
x = list(range(len(labels)))
if args.broken_y:
# Use broken axis for left subplot
# Auto-adjust width ratios: left has 4 bars, right has 2 bars
fig = plt.figure(figsize=(4.8, 1.8)) # Scaled down to 80%
gs = gridspec.GridSpec(
2, 2, height_ratios=[1, 3], width_ratios=[1.5, 1], hspace=0.08, wspace=0.35
)
ax_left_top = fig.add_subplot(gs[0, 0])
ax_left_bottom = fig.add_subplot(gs[1, 0], sharex=ax_left_top)
ax_right = fig.add_subplot(gs[:, 1])
# Determine break points
s = sorted(values, reverse=True)
second = s[1] if len(s) >= 2 else (s[0] if s else 0.0)
lower_cap = (
args.lower_cap_y if args.lower_cap_y is not None else second * 1.4
) # Increased to show more range
upper_start = (
args.upper_start_y
if args.upper_start_y is not None
else max(second * 1.5, lower_cap * 1.02)
)
ymax = (
max(values) * 1.90 if values else 1.0
) # Increase headroom to 1.90 for text label and tick range
# Draw bars on both axes
ax_left_bottom.bar(x, values, color=colors[: len(labels)], width=0.8)
ax_left_top.bar(x, values, color=colors[: len(labels)], width=0.8)
# Set limits
ax_left_bottom.set_ylim(0, lower_cap)
ax_left_top.set_ylim(upper_start, ymax)
# Annotate values (convert ms to s)
values_s = [v / 1000.0 for v in values]
lower_cap_s = lower_cap / 1000.0
upper_start_s = upper_start / 1000.0
ymax_s = ymax / 1000.0
ax_left_bottom.set_ylim(0, lower_cap_s)
ax_left_top.set_ylim(upper_start_s, ymax_s)
# Redraw bars with s values (paper style: white fill + colored edge + hatch)
ax_left_bottom.clear()
ax_left_top.clear()
bar_width = 0.50 # Reduced for wider spacing between bars
for i, (scenario_name, v) in enumerate(zip(scenarios, values_s)):
style = SCENARIO_STYLES.get(scenario_name, {"edgecolor": "black", "hatch": ""})
# Draw in bottom axis for all bars
ax_left_bottom.bar(
i,
v,
width=bar_width,
color="white",
edgecolor=style["edgecolor"],
hatch=style["hatch"],
linewidth=1.2,
)
# Only draw in top axis if the bar is tall enough to reach the upper range
if v > upper_start_s:
ax_left_top.bar(
i,
v,
width=bar_width,
color="white",
edgecolor=style["edgecolor"],
hatch=style["hatch"],
linewidth=1.2,
)
ax_left_bottom.set_ylim(0, lower_cap_s)
ax_left_top.set_ylim(upper_start_s, ymax_s)
for i, v in enumerate(values_s):
if v <= lower_cap_s:
ax_left_bottom.text(
i,
v + lower_cap_s * 0.02,
f"{v:.2f}",
ha="center",
va="bottom",
fontsize=8,
fontweight="bold",
)
else:
ax_left_top.text(
i,
v + (ymax_s - upper_start_s) * 0.02,
f"{v:.2f}",
ha="center",
va="bottom",
fontsize=8,
fontweight="bold",
)
# Hide spines between axes
ax_left_top.spines["bottom"].set_visible(False)
ax_left_bottom.spines["top"].set_visible(False)
ax_left_top.tick_params(
labeltop=False, labelbottom=False, bottom=False
) # Hide tick marks
ax_left_bottom.xaxis.tick_bottom()
ax_left_bottom.tick_params(top=False) # Hide top tick marks
# Draw break marks (matching paper_fig.py style)
d = 0.015
kwargs = {
"transform": ax_left_top.transAxes,
"color": "k",
"clip_on": False,
"linewidth": 0.8,
"zorder": 10,
}
ax_left_top.plot((-d, +d), (-d, +d), **kwargs)
ax_left_top.plot((1 - d, 1 + d), (-d, +d), **kwargs)
kwargs.update({"transform": ax_left_bottom.transAxes})
ax_left_bottom.plot((-d, +d), (1 - d, 1 + d), **kwargs)
ax_left_bottom.plot((1 - d, 1 + d), (1 - d, 1 + d), **kwargs)
ax_left_bottom.set_xticks(x)
ax_left_bottom.set_xticklabels(labels, rotation=0, fontsize=7)
# Don't set ylabel here - will use fig.text for alignment
ax_left_bottom.tick_params(axis="y", labelsize=10)
ax_left_top.tick_params(axis="y", labelsize=10)
# Add subtle grid for better readability
ax_left_bottom.grid(axis="y", alpha=0.3, linestyle="--", linewidth=0.5)
ax_left_top.grid(axis="y", alpha=0.3, linestyle="--", linewidth=0.5)
ax_left_top.set_title("Single Add Operation", fontsize=11, pad=10, fontweight="bold")
# Set x-axis limits to match bar width with right subplot
ax_left_bottom.set_xlim(-0.6, 3.6)
ax_left_top.set_xlim(-0.6, 3.6)
ax_left = ax_left_bottom # for compatibility
else:
# Regular side-by-side layout
fig, (ax_left, ax_right) = plt.subplots(1, 2, figsize=(8.4, 3.15))
if cap is not None:
show_vals = [min(v, cap) for v in values]
bars = ax_left.bar(x, show_vals, color=colors[: len(labels)], width=0.8)
for i, (val, show) in enumerate(zip(values, show_vals)):
if val > cap:
bars[i].set_hatch("//")
ax_left.text(
i, cap * 1.02, _fmt_ms(val), ha="center", va="bottom", fontsize=9
)
else:
ax_left.text(
i,
show + max(1.0, 0.01 * (cap or show)),
_fmt_ms(val),
ha="center",
va="bottom",
fontsize=9,
)
ax_left.set_ylim(0, cap * 1.10)
_add_break_marker(ax_left, y=0.98)
ax_left.set_xticks(x)
ax_left.set_xticklabels(labels, rotation=0, fontsize=10)
else:
ax_left.bar(x, values, color=colors[: len(labels)], width=0.8)
for i, v in enumerate(values):
ax_left.text(i, v + 1.0, _fmt_ms(v), ha="center", va="bottom", fontsize=9)
ax_left.set_xticks(x)
ax_left.set_xticklabels(labels, rotation=0, fontsize=10)
ax_left.set_ylabel("Latency (ms per passage)")
max_initial = latest_rows[0].get("max_initial", "?")
max_updates = latest_rows[0].get("max_updates", "?")
ax_left.set_title(
f"HNSW RNG (run {latest_run}) | init={max_initial}, upd={max_updates}"
)
# Right subplot (A vs B, seconds) - paper style
r_labels = ["Sequential", "Delayed \n Add+Search"]
r_values = [a_total or 0.0, b_makespan or 0.0]
r_styles = [
{"edgecolor": "#59a14f", "hatch": "xxxxx"},
{"edgecolor": "#edc948", "hatch": "/////"},
]
# 2 bars, centered with proper spacing
xr = [0, 1]
bar_width = 0.50 # Reduced for wider spacing between bars
for i, (v, style) in enumerate(zip(r_values, r_styles)):
ax_right.bar(
xr[i],
v,
width=bar_width,
color="white",
edgecolor=style["edgecolor"],
hatch=style["hatch"],
linewidth=1.2,
)
for i, v in enumerate(r_values):
max_v = max(r_values) if r_values else 1.0
offset = max(0.0002, 0.02 * max_v)
ax_right.text(
xr[i],
v + offset,
f"{v:.2f}",
ha="center",
va="bottom",
fontsize=8,
fontweight="bold",
)
ax_right.set_xticks(xr)
ax_right.set_xticklabels(r_labels, rotation=0, fontsize=7)
# Don't set ylabel here - will use fig.text for alignment
ax_right.tick_params(axis="y", labelsize=10)
# Add subtle grid for better readability
ax_right.grid(axis="y", alpha=0.3, linestyle="--", linewidth=0.5)
ax_right.set_title("Batched Add Operation", fontsize=11, pad=10, fontweight="bold")
# Set x-axis limits to match left subplot's bar width visually
# Accounting for width_ratios=[1.5, 1]:
# Left: 4 bars, xlim(-0.6, 3.6), range=4.2, physical_width=1.5*unit
# bar_width_visual = 0.72 * (1.5*unit / 4.2)
# Right: 2 bars, need same visual width
# 0.72 * (1.0*unit / range_right) = 0.72 * (1.5*unit / 4.2)
# range_right = 4.2 / 1.5 = 2.8
# For bars at 0, 1: padding = (2.8 - 1) / 2 = 0.9
ax_right.set_xlim(-0.9, 1.9)
# Set y-axis limit with headroom for text labels
if r_values:
max_v = max(r_values)
ax_right.set_ylim(0, max_v * 1.15)
# Format y-axis to avoid scientific notation
ax_right.ticklabel_format(style="plain", axis="y")
plt.tight_layout()
# Add aligned ylabels using fig.text (after tight_layout)
# Get the vertical center of the entire figure
fig_center_y = 0.5
# Left ylabel - closer to left plot
left_x = 0.05
fig.text(
left_x,
fig_center_y,
"Latency (s)",
va="center",
rotation="vertical",
fontsize=11,
fontweight="bold",
)
# Right ylabel - closer to right plot
right_bbox = ax_right.get_position()
right_x = right_bbox.x0 - 0.07
fig.text(
right_x,
fig_center_y,
"Latency (s)",
va="center",
rotation="vertical",
fontsize=11,
fontweight="bold",
)
plt.savefig(args.out, bbox_inches="tight", pad_inches=0.05)
# Also save PDF for paper
pdf_out = args.out.with_suffix(".pdf")
plt.savefig(pdf_out, bbox_inches="tight", pad_inches=0.05)
print(f"Saved: {args.out}")
print(f"Saved: {pdf_out}")
return
# Broken-Y mode
if args.broken_y:
import matplotlib.pyplot as plt
fig, (ax_top, ax_bottom) = plt.subplots(
2,
1,
sharex=True,
figsize=(7.5, 6.75),
gridspec_kw={"height_ratios": [1, 3], "hspace": 0.08},
)
# Determine default breaks from second-highest
s = sorted(values, reverse=True)
second = s[1] if len(s) >= 2 else (s[0] if s else 0.0)
lower_cap = args.lower_cap_y if args.lower_cap_y is not None else second * 1.1
upper_start = (
args.upper_start_y
if args.upper_start_y is not None
else max(second * 1.2, lower_cap * 1.02)
)
ymax = max(values) * 1.10 if values else 1.0
x = list(range(len(labels)))
ax_bottom.bar(x, values, color=colors[: len(labels)], width=0.8)
ax_top.bar(x, values, color=colors[: len(labels)], width=0.8)
# Limits
ax_bottom.set_ylim(0, lower_cap)
ax_top.set_ylim(upper_start, ymax)
# Annotate values
for i, v in enumerate(values):
if v <= lower_cap:
ax_bottom.text(
i, v + lower_cap * 0.02, _fmt_ms(v), ha="center", va="bottom", fontsize=9
)
else:
ax_top.text(i, v, _fmt_ms(v), ha="center", va="bottom", fontsize=9)
# Hide spines between axes and draw diagonal break marks
ax_top.spines["bottom"].set_visible(False)
ax_bottom.spines["top"].set_visible(False)
ax_top.tick_params(labeltop=False) # don't put tick labels at the top
ax_bottom.xaxis.tick_bottom()
# Diagonal lines at the break (matching paper_fig.py style)
d = 0.015
kwargs = {
"transform": ax_top.transAxes,
"color": "k",
"clip_on": False,
"linewidth": 0.8,
"zorder": 10,
}
ax_top.plot((-d, +d), (-d, +d), **kwargs) # top-left diagonal
ax_top.plot((1 - d, 1 + d), (-d, +d), **kwargs) # top-right diagonal
kwargs.update({"transform": ax_bottom.transAxes})
ax_bottom.plot((-d, +d), (1 - d, 1 + d), **kwargs) # bottom-left diagonal
ax_bottom.plot((1 - d, 1 + d), (1 - d, 1 + d), **kwargs) # bottom-right diagonal
ax_bottom.set_xticks(x)
ax_bottom.set_xticklabels(labels, rotation=0, fontsize=10)
ax = ax_bottom # for labeling below
else:
cap = args.cap_y
if cap is None and not args.no_auto_cap:
cap = _auto_cap(values)
plt.figure(figsize=(5.4, 3.15))
ax = plt.gca()
if cap is not None:
show_vals = [min(v, cap) for v in values]
bars = []
for i, (_label, val, show) in enumerate(zip(labels, values, show_vals)):
bar = ax.bar(i, show, color=colors[i], width=0.8)
bars.append(bar[0])
# Hatch and annotate when capped
if val > cap:
bars[-1].set_hatch("//")
ax.text(i, cap * 1.02, f"{_fmt_ms(val)}", ha="center", va="bottom", fontsize=9)
else:
ax.text(
i,
show + max(1.0, 0.01 * (cap or show)),
f"{_fmt_ms(val)}",
ha="center",
va="bottom",
fontsize=9,
)
ax.set_ylim(0, cap * 1.10)
_add_break_marker(ax, y=0.98)
ax.legend([bars[1]], ["capped"], fontsize=8, frameon=False, loc="upper right") if any(
v > cap for v in values
) else None
ax.set_xticks(range(len(labels)))
ax.set_xticklabels(labels, fontsize=11, fontweight="bold")
else:
ax.bar(labels, values, color=colors[: len(labels)])
for idx, val in enumerate(values):
ax.text(
idx,
val + 1.0,
f"{_fmt_ms(val)}",
ha="center",
va="bottom",
fontsize=10,
fontweight="bold",
)
ax.set_xticklabels(labels, fontsize=11, fontweight="bold")
# Try to extract some context for title
max_initial = latest_rows[0].get("max_initial", "?")
max_updates = latest_rows[0].get("max_updates", "?")
if args.broken_y:
fig.text(
0.02,
0.5,
"Latency (s)",
va="center",
rotation="vertical",
fontsize=11,
fontweight="bold",
)
fig.suptitle(
"Add Operation Latency",
fontsize=11,
y=0.98,
fontweight="bold",
)
plt.tight_layout(rect=(0.03, 0.04, 1, 0.96))
else:
plt.ylabel("Latency (s)", fontsize=11, fontweight="bold")
plt.title("Add Operation Latency", fontsize=11, fontweight="bold")
plt.tight_layout()
plt.savefig(args.out, bbox_inches="tight", pad_inches=0.05)
# Also save PDF for paper
pdf_out = args.out.with_suffix(".pdf")
plt.savefig(pdf_out, bbox_inches="tight", pad_inches=0.05)
print(f"Saved: {args.out}")
print(f"Saved: {pdf_out}")
if __name__ == "__main__":
main()

149
issue_159.py Normal file
View File

@@ -0,0 +1,149 @@
#!/usr/bin/env python3
"""
Test script to reproduce issue #159: Slow search performance
Configuration:
- GPU: 4090×1
- embedding_model: BAAI/bge-large-zh-v1.5
- data size: 180M text (~90K chunks)
- beam_width: 10 (though this is mainly for DiskANN, not HNSW)
- backend: hnsw
"""
import os
import time
from pathlib import Path
from leann.api import LeannBuilder, LeannSearcher, SearchResult
os.environ["LEANN_LOG_LEVEL"] = "DEBUG"
# Configuration matching the issue
INDEX_PATH = "./test_issue_159.leann"
EMBEDDING_MODEL = "BAAI/bge-large-zh-v1.5"
BACKEND_NAME = "hnsw"
def generate_test_data(num_chunks=90000, chunk_size=2000):
"""Generate test data similar to 180MB text (~90K chunks)"""
# Each chunk is approximately 2000 characters
# 90K chunks * 2000 chars ≈ 180MB
chunks = []
base_text = (
"这是一个测试文档。LEANN是一个创新的向量数据库通过图基选择性重计算实现97%的存储节省。"
)
for i in range(num_chunks):
chunk = f"{base_text} 文档编号: {i}. " * (chunk_size // len(base_text) + 1)
chunks.append(chunk[:chunk_size])
return chunks
def test_search_performance():
"""Test search performance with different configurations"""
print("=" * 80)
print("Testing LEANN Search Performance (Issue #159)")
print("=" * 80)
meta_path = Path(f"{INDEX_PATH}.meta.json")
if meta_path.exists():
print(f"\n✓ Index already exists at {INDEX_PATH}")
print(" Skipping build phase. Delete the index to rebuild.")
else:
print("\n📦 Building index...")
print(f" Backend: {BACKEND_NAME}")
print(f" Embedding Model: {EMBEDDING_MODEL}")
print(" Generating test data (~90K chunks, ~180MB)...")
chunks = generate_test_data(num_chunks=90000)
print(f" Generated {len(chunks)} chunks")
print(f" Total text size: {sum(len(c) for c in chunks) / (1024 * 1024):.2f} MB")
builder = LeannBuilder(
backend_name=BACKEND_NAME,
embedding_model=EMBEDDING_MODEL,
)
print(" Adding chunks to builder...")
start_time = time.time()
for i, chunk in enumerate(chunks):
builder.add_text(chunk)
if (i + 1) % 10000 == 0:
print(f" Added {i + 1}/{len(chunks)} chunks...")
print(" Building index...")
build_start = time.time()
builder.build_index(INDEX_PATH)
build_time = time.time() - build_start
print(f" ✓ Index built in {build_time:.2f} seconds")
# Test search with different complexity values
print("\n🔍 Testing search performance...")
searcher = LeannSearcher(INDEX_PATH)
test_query = "LEANN向量数据库存储优化"
# Test with default complexity (64)
print("\n Test 1: Default complexity (64) `1 ")
print(f" Query: '{test_query}'")
start_time = time.time()
results: list[SearchResult] = searcher.search(test_query, top_k=10, complexity=64)
search_time = time.time() - start_time
print(f" ✓ Search completed in {search_time:.2f} seconds")
print(f" Results: {len(results)} items")
# Test with default complexity (64)
print("\n Test 1: Default complexity (64)")
print(f" Query: '{test_query}'")
start_time = time.time()
results = searcher.search(test_query, top_k=10, complexity=64)
search_time = time.time() - start_time
print(f" ✓ Search completed in {search_time:.2f} seconds")
print(f" Results: {len(results)} items")
# Test with lower complexity (32)
print("\n Test 2: Lower complexity (32)")
print(f" Query: '{test_query}'")
start_time = time.time()
results = searcher.search(test_query, top_k=10, complexity=32)
search_time = time.time() - start_time
print(f" ✓ Search completed in {search_time:.2f} seconds")
print(f" Results: {len(results)} items")
# Test with even lower complexity (16)
print("\n Test 3: Lower complexity (16)")
print(f" Query: '{test_query}'")
start_time = time.time()
results = searcher.search(test_query, top_k=10, complexity=16)
search_time = time.time() - start_time
print(f" ✓ Search completed in {search_time:.2f} seconds")
print(f" Results: {len(results)} items")
# Test with minimal complexity (8)
print("\n Test 4: Minimal complexity (8)")
print(f" Query: '{test_query}'")
start_time = time.time()
results = searcher.search(test_query, top_k=10, complexity=8)
search_time = time.time() - start_time
print(f" ✓ Search completed in {search_time:.2f} seconds")
print(f" Results: {len(results)} items")
print("\n" + "=" * 80)
print("Performance Analysis:")
print("=" * 80)
print("\nKey Findings:")
print("1. beam_width parameter is mainly for DiskANN backend, not HNSW")
print("2. For HNSW, the main parameter affecting search speed is 'complexity'")
print("3. Lower complexity values (16-32) should provide faster search")
print("4. The paper mentions ~2 seconds, which likely uses:")
print(" - Smaller embedding model (~100M params vs 300M for bge-large)")
print(" - Lower complexity (16-32)")
print(" - Possibly DiskANN backend for better performance")
print("\nRecommendations:")
print("- Try complexity=16 or complexity=32 for faster search")
print("- Consider using DiskANN backend for better performance on large datasets")
print("- Or use a smaller embedding model if speed is critical")
if __name__ == "__main__":
test_search_performance()

View File

@@ -215,6 +215,8 @@ class HNSWSearcher(BaseSearcher):
if recompute_embeddings:
if zmq_port is None:
raise ValueError("zmq_port must be provided if recompute_embeddings is True")
if hasattr(self._index, "set_zmq_port"):
self._index.set_zmq_port(zmq_port)
if query.dtype != np.float32:
query = query.astype(np.float32)

View File

@@ -143,8 +143,6 @@ def create_hnsw_embedding_server(
pass
return str(nid)
# (legacy ZMQ thread removed; using shutdown-capable server only)
def zmq_server_thread_with_shutdown(shutdown_event):
"""ZMQ server thread that respects shutdown signal.
@@ -158,225 +156,238 @@ def create_hnsw_embedding_server(
rep_socket.bind(f"tcp://*:{zmq_port}")
logger.info(f"HNSW ZMQ REP server listening on port {zmq_port}")
rep_socket.setsockopt(zmq.RCVTIMEO, 1000)
# Keep sends from blocking during shutdown; fail fast and drop on close
rep_socket.setsockopt(zmq.SNDTIMEO, 1000)
rep_socket.setsockopt(zmq.LINGER, 0)
# Track last request type/length for shape-correct fallbacks
last_request_type = "unknown" # 'text' | 'distance' | 'embedding' | 'unknown'
last_request_type = "unknown"
last_request_length = 0
def _build_safe_fallback():
if last_request_type == "distance":
large_distance = 1e9
fallback_len = max(0, int(last_request_length))
return [[large_distance] * fallback_len]
if last_request_type == "embedding":
bsz = max(0, int(last_request_length))
dim = max(0, int(embedding_dim))
if dim > 0:
return [[bsz, dim], [0.0] * (bsz * dim)]
return [[0, 0], []]
if last_request_type == "text":
return []
return [[0, int(embedding_dim) if embedding_dim > 0 else 0], []]
def _handle_text_embedding(request: list[str]) -> None:
nonlocal last_request_type, last_request_length
e2e_start = time.time()
last_request_type = "text"
last_request_length = len(request)
embeddings = compute_embeddings(
request,
model_name,
mode=embedding_mode,
provider_options=PROVIDER_OPTIONS,
)
rep_socket.send(msgpack.packb(embeddings.tolist()))
e2e_end = time.time()
logger.info(f"⏱️ Direct text embedding E2E time: {e2e_end - e2e_start:.6f}s")
def _handle_distance_request(request: list[Any]) -> None:
nonlocal last_request_type, last_request_length
e2e_start = time.time()
node_ids = request[0]
if len(node_ids) == 1 and isinstance(node_ids[0], list):
node_ids = node_ids[0]
query_vector = np.array(request[1], dtype=np.float32)
last_request_type = "distance"
last_request_length = len(node_ids)
logger.debug("Distance calculation request received")
logger.debug(f" Node IDs: {node_ids}")
logger.debug(f" Query vector dim: {len(query_vector)}")
texts: list[str] = []
found_indices: list[int] = []
for idx, nid in enumerate(node_ids):
try:
passage_id = _map_node_id(nid)
passage_data = passages.get_passage(passage_id)
txt = passage_data.get("text", "")
if isinstance(txt, str) and len(txt) > 0:
texts.append(txt)
found_indices.append(idx)
else:
logger.error(f"Empty text for passage ID {passage_id}")
except KeyError:
logger.error(f"Passage ID {nid} not found")
except Exception as exc:
logger.error(f"Exception looking up passage ID {nid}: {exc}")
large_distance = 1e9
response_distances = [large_distance] * len(node_ids)
if texts:
try:
embeddings = compute_embeddings(
texts,
model_name,
mode=embedding_mode,
provider_options=PROVIDER_OPTIONS,
)
logger.info(
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
)
if distance_metric == "l2":
partial = np.sum(
np.square(embeddings - query_vector.reshape(1, -1)), axis=1
)
else:
partial = -np.dot(embeddings, query_vector)
for pos, dval in zip(found_indices, partial.flatten().tolist()):
response_distances[pos] = float(dval)
except Exception as exc:
logger.error(f"Distance computation error, using sentinels: {exc}")
rep_socket.send(msgpack.packb([response_distances], use_single_float=True))
e2e_end = time.time()
logger.info(f"⏱️ Distance calculation E2E time: {e2e_end - e2e_start:.6f}s")
def _handle_embedding_by_id(request: Any) -> None:
nonlocal last_request_type, last_request_length
if isinstance(request, list) and len(request) == 1 and isinstance(request[0], list):
node_ids = request[0]
elif isinstance(request, list):
node_ids = request
else:
node_ids = []
e2e_start = time.time()
last_request_type = "embedding"
last_request_length = len(node_ids)
logger.info(f"ZMQ received {len(node_ids)} node IDs for embedding fetch")
if embedding_dim <= 0:
dims = [0, 0]
flat_data: list[float] = []
else:
dims = [len(node_ids), embedding_dim]
flat_data = [0.0] * (dims[0] * dims[1])
texts: list[str] = []
found_indices: list[int] = []
for idx, nid in enumerate(node_ids):
try:
passage_id = _map_node_id(nid)
passage_data = passages.get_passage(passage_id)
txt = passage_data.get("text", "")
if isinstance(txt, str) and len(txt) > 0:
texts.append(txt)
found_indices.append(idx)
else:
logger.error(f"Empty text for passage ID {passage_id}")
except KeyError:
logger.error(f"Passage with ID {nid} not found")
except Exception as exc:
logger.error(f"Exception looking up passage ID {nid}: {exc}")
if texts:
try:
embeddings = compute_embeddings(
texts,
model_name,
mode=embedding_mode,
provider_options=PROVIDER_OPTIONS,
)
logger.info(
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
)
if np.isnan(embeddings).any() or np.isinf(embeddings).any():
logger.error(
f"NaN or Inf detected in embeddings! Requested IDs: {node_ids[:5]}..."
)
dims = [0, embedding_dim]
flat_data = []
else:
emb_f32 = np.ascontiguousarray(embeddings, dtype=np.float32)
flat = emb_f32.flatten().tolist()
for j, pos in enumerate(found_indices):
start = pos * embedding_dim
end = start + embedding_dim
if end <= len(flat_data):
flat_data[start:end] = flat[
j * embedding_dim : (j + 1) * embedding_dim
]
except Exception as exc:
logger.error(f"Embedding computation error, returning zeros: {exc}")
response_payload = [dims, flat_data]
rep_socket.send(msgpack.packb(response_payload, use_single_float=True))
e2e_end = time.time()
logger.info(f"⏱️ Fallback Embed by Id E2E time: {e2e_end - e2e_start:.6f}s")
try:
while not shutdown_event.is_set():
try:
e2e_start = time.time()
logger.debug("🔍 Waiting for ZMQ message...")
request_bytes = rep_socket.recv()
except zmq.Again:
continue
# Rest of the processing logic (same as original)
try:
request = msgpack.unpackb(request_bytes)
except Exception as exc:
if shutdown_event.is_set():
logger.info("Shutdown in progress, ignoring ZMQ error")
break
logger.error(f"Error unpacking ZMQ message: {exc}")
try:
safe = _build_safe_fallback()
rep_socket.send(msgpack.packb(safe, use_single_float=True))
except Exception:
pass
continue
if len(request) == 1 and request[0] == "__QUERY_MODEL__":
response_bytes = msgpack.packb([model_name])
rep_socket.send(response_bytes)
continue
# Handle direct text embedding request
try:
# Model query
if (
isinstance(request, list)
and len(request) == 1
and request[0] == "__QUERY_MODEL__"
):
rep_socket.send(msgpack.packb([model_name]))
# Direct text embedding
elif (
isinstance(request, list)
and request
and all(isinstance(item, str) for item in request)
):
last_request_type = "text"
last_request_length = len(request)
embeddings = compute_embeddings(
request,
model_name,
mode=embedding_mode,
provider_options=PROVIDER_OPTIONS,
)
rep_socket.send(msgpack.packb(embeddings.tolist()))
e2e_end = time.time()
logger.info(f"⏱️ Text embedding E2E time: {e2e_end - e2e_start:.6f}s")
continue
# Handle distance calculation request: [[ids], [query_vector]]
if (
_handle_text_embedding(request)
# Distance calculation: [[ids], [query_vector]]
elif (
isinstance(request, list)
and len(request) == 2
and isinstance(request[0], list)
and isinstance(request[1], list)
):
node_ids = request[0]
# Handle nested [[ids]] shape defensively
if len(node_ids) == 1 and isinstance(node_ids[0], list):
node_ids = node_ids[0]
query_vector = np.array(request[1], dtype=np.float32)
last_request_type = "distance"
last_request_length = len(node_ids)
logger.debug("Distance calculation request received")
logger.debug(f" Node IDs: {node_ids}")
logger.debug(f" Query vector dim: {len(query_vector)}")
# Gather texts for found ids
texts: list[str] = []
found_indices: list[int] = []
for idx, nid in enumerate(node_ids):
try:
passage_id = _map_node_id(nid)
passage_data = passages.get_passage(passage_id)
txt = passage_data.get("text", "")
if isinstance(txt, str) and len(txt) > 0:
texts.append(txt)
found_indices.append(idx)
else:
logger.error(f"Empty text for passage ID {passage_id}")
except KeyError:
logger.error(f"Passage ID {nid} not found")
except Exception as e:
logger.error(f"Exception looking up passage ID {nid}: {e}")
# Prepare full-length response with large sentinel values
large_distance = 1e9
response_distances = [large_distance] * len(node_ids)
if texts:
try:
embeddings = compute_embeddings(
texts,
model_name,
mode=embedding_mode,
provider_options=PROVIDER_OPTIONS,
)
logger.info(
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
)
if distance_metric == "l2":
partial = np.sum(
np.square(embeddings - query_vector.reshape(1, -1)), axis=1
)
else: # mips or cosine
partial = -np.dot(embeddings, query_vector)
for pos, dval in zip(found_indices, partial.flatten().tolist()):
response_distances[pos] = float(dval)
except Exception as e:
logger.error(f"Distance computation error, using sentinels: {e}")
# Send response in expected shape [[distances]]
rep_socket.send(msgpack.packb([response_distances], use_single_float=True))
e2e_end = time.time()
logger.info(f"⏱️ Distance calculation E2E time: {e2e_end - e2e_start:.6f}s")
continue
# Fallback: treat as embedding-by-id request
if (
isinstance(request, list)
and len(request) == 1
and isinstance(request[0], list)
):
node_ids = request[0]
elif isinstance(request, list):
node_ids = request
else:
node_ids = []
last_request_type = "embedding"
last_request_length = len(node_ids)
logger.info(f"ZMQ received {len(node_ids)} node IDs for embedding fetch")
# Preallocate zero-filled flat data for robustness
if embedding_dim <= 0:
dims = [0, 0]
flat_data: list[float] = []
else:
dims = [len(node_ids), embedding_dim]
flat_data = [0.0] * (dims[0] * dims[1])
# Collect texts for found ids
texts: list[str] = []
found_indices: list[int] = []
for idx, nid in enumerate(node_ids):
try:
passage_id = _map_node_id(nid)
passage_data = passages.get_passage(passage_id)
txt = passage_data.get("text", "")
if isinstance(txt, str) and len(txt) > 0:
texts.append(txt)
found_indices.append(idx)
else:
logger.error(f"Empty text for passage ID {passage_id}")
except KeyError:
logger.error(f"Passage with ID {nid} not found")
except Exception as e:
logger.error(f"Exception looking up passage ID {nid}: {e}")
if texts:
try:
embeddings = compute_embeddings(
texts,
model_name,
mode=embedding_mode,
provider_options=PROVIDER_OPTIONS,
)
logger.info(
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
)
if np.isnan(embeddings).any() or np.isinf(embeddings).any():
logger.error(
f"NaN or Inf detected in embeddings! Requested IDs: {node_ids[:5]}..."
)
dims = [0, embedding_dim]
flat_data = []
else:
emb_f32 = np.ascontiguousarray(embeddings, dtype=np.float32)
flat = emb_f32.flatten().tolist()
for j, pos in enumerate(found_indices):
start = pos * embedding_dim
end = start + embedding_dim
if end <= len(flat_data):
flat_data[start:end] = flat[
j * embedding_dim : (j + 1) * embedding_dim
]
except Exception as e:
logger.error(f"Embedding computation error, returning zeros: {e}")
response_payload = [dims, flat_data]
response_bytes = msgpack.packb(response_payload, use_single_float=True)
rep_socket.send(response_bytes)
e2e_end = time.time()
logger.info(f"⏱️ ZMQ E2E time: {e2e_end - e2e_start:.6f}s")
except zmq.Again:
# Timeout - check shutdown_event and continue
continue
except Exception as e:
if not shutdown_event.is_set():
logger.error(f"Error in ZMQ server loop: {e}")
# Shape-correct fallback
try:
if last_request_type == "distance":
large_distance = 1e9
fallback_len = max(0, int(last_request_length))
safe = [[large_distance] * fallback_len]
elif last_request_type == "embedding":
bsz = max(0, int(last_request_length))
dim = max(0, int(embedding_dim))
safe = (
[[bsz, dim], [0.0] * (bsz * dim)] if dim > 0 else [[0, 0], []]
)
elif last_request_type == "text":
safe = [] # direct text embeddings expectation is a flat list
else:
safe = [[0, int(embedding_dim) if embedding_dim > 0 else 0], []]
rep_socket.send(msgpack.packb(safe, use_single_float=True))
except Exception:
pass
_handle_distance_request(request)
# Embedding-by-id fallback
else:
_handle_embedding_by_id(request)
except Exception as exc:
if shutdown_event.is_set():
logger.info("Shutdown in progress, ignoring ZMQ error")
break
logger.error(f"Error in ZMQ server loop: {exc}")
try:
safe = _build_safe_fallback()
rep_socket.send(msgpack.packb(safe, use_single_float=True))
except Exception:
pass
finally:
try:
rep_socket.close(0)

View File

@@ -820,10 +820,10 @@ class LeannBuilder:
actual_port,
requested_zmq_port,
)
try:
index.hnsw.zmq_port = actual_port
except AttributeError:
pass
if hasattr(index.hnsw, "set_zmq_port"):
index.hnsw.set_zmq_port(actual_port)
elif hasattr(index, "set_zmq_port"):
index.set_zmq_port(actual_port)
if needs_recompute:
for i in range(embeddings.shape[0]):
@@ -864,7 +864,13 @@ class LeannBuilder:
class LeannSearcher:
def __init__(self, index_path: str, enable_warmup: bool = False, **backend_kwargs):
def __init__(
self,
index_path: str,
enable_warmup: bool = True,
recompute_embeddings: bool = True,
**backend_kwargs,
):
# Fix path resolution for Colab and other environments
if not Path(index_path).is_absolute():
index_path = str(Path(index_path).resolve())
@@ -895,14 +901,32 @@ class LeannSearcher:
backend_factory = BACKEND_REGISTRY.get(backend_name)
if backend_factory is None:
raise ValueError(f"Backend '{backend_name}' not found.")
# Global recompute flag for this searcher (explicit knob, default True)
self.recompute_embeddings: bool = bool(recompute_embeddings)
# Warmup flag: keep using the existing enable_warmup parameter,
# but default it to True so cold-start happens earlier.
self._warmup: bool = bool(enable_warmup)
final_kwargs = {**self.meta_data.get("backend_kwargs", {}), **backend_kwargs}
final_kwargs["enable_warmup"] = enable_warmup
final_kwargs["enable_warmup"] = self._warmup
if self.embedding_options:
final_kwargs.setdefault("embedding_options", self.embedding_options)
self.backend_impl: LeannBackendSearcherInterface = backend_factory.searcher(
index_path, **final_kwargs
)
# Optional one-shot warmup at construction time to hide cold-start latency.
if self._warmup:
try:
_ = self.backend_impl.compute_query_embedding(
"__LEANN_WARMUP__",
use_server_if_available=self.recompute_embeddings,
)
except Exception as exc:
logger.warning(f"Warmup embedding failed (ignored): {exc}")
def search(
self,
query: str,
@@ -910,7 +934,7 @@ class LeannSearcher:
complexity: int = 64,
beam_width: int = 1,
prune_ratio: float = 0.0,
recompute_embeddings: bool = True,
recompute_embeddings: Optional[bool] = None,
pruning_strategy: Literal["global", "local", "proportional"] = "global",
expected_zmq_port: int = 5557,
metadata_filters: Optional[dict[str, dict[str, Union[str, int, float, bool, list]]]] = None,
@@ -927,7 +951,8 @@ class LeannSearcher:
complexity: Search complexity/candidate list size, higher = more accurate but slower
beam_width: Number of parallel search paths/IO requests per iteration
prune_ratio: Ratio of neighbors to prune via approximate distance (0.0-1.0)
recompute_embeddings: Whether to fetch fresh embeddings from server vs use stored codes
recompute_embeddings: (Deprecated) Per-call override for recompute mode.
Configure this at LeannSearcher(..., recompute_embeddings=...) instead.
pruning_strategy: Candidate selection strategy - "global" (default), "local", or "proportional"
expected_zmq_port: ZMQ port for embedding server communication
metadata_filters: Optional filters to apply to search results based on metadata.
@@ -966,8 +991,19 @@ class LeannSearcher:
zmq_port = None
# Resolve effective recompute flag for this search.
if recompute_embeddings is not None:
logger.warning(
"LeannSearcher.search(..., recompute_embeddings=...) is deprecated and "
"will be removed in a future version. Configure recompute at "
"LeannSearcher(..., recompute_embeddings=...) instead."
)
effective_recompute = bool(recompute_embeddings)
else:
effective_recompute = self.recompute_embeddings
start_time = time.time()
if recompute_embeddings:
if effective_recompute:
zmq_port = self.backend_impl._ensure_server_running(
self.meta_path_str,
port=expected_zmq_port,
@@ -981,7 +1017,7 @@ class LeannSearcher:
query_embedding = self.backend_impl.compute_query_embedding(
query,
use_server_if_available=recompute_embeddings,
use_server_if_available=effective_recompute,
zmq_port=zmq_port,
)
logger.info(f" Generated embedding shape: {query_embedding.shape}")
@@ -993,7 +1029,7 @@ class LeannSearcher:
"complexity": complexity,
"beam_width": beam_width,
"prune_ratio": prune_ratio,
"recompute_embeddings": recompute_embeddings,
"recompute_embeddings": effective_recompute,
"pruning_strategy": pruning_strategy,
"zmq_port": zmq_port,
}

View File

@@ -5,12 +5,128 @@ Packaged within leann-core so installed wheels can import it reliably.
import logging
from pathlib import Path
from typing import Optional
from typing import Any, Optional
from llama_index.core.node_parser import SentenceSplitter
logger = logging.getLogger(__name__)
# Flag to ensure AST token warning only shown once per session
_ast_token_warning_shown = False
def estimate_token_count(text: str) -> int:
"""
Estimate token count for a text string.
Uses conservative estimation: ~4 characters per token for natural text,
~1.2 tokens per character for code (worse tokenization).
Args:
text: Input text to estimate tokens for
Returns:
Estimated token count
"""
try:
import tiktoken
encoder = tiktoken.get_encoding("cl100k_base")
return len(encoder.encode(text))
except ImportError:
# Fallback: Conservative character-based estimation
# Assume worst case for code: 1.2 tokens per character
return int(len(text) * 1.2)
def calculate_safe_chunk_size(
model_token_limit: int,
overlap_tokens: int,
chunking_mode: str = "traditional",
safety_factor: float = 0.9,
) -> int:
"""
Calculate safe chunk size accounting for overlap and safety margin.
Args:
model_token_limit: Maximum tokens supported by embedding model
overlap_tokens: Overlap size (tokens for traditional, chars for AST)
chunking_mode: "traditional" (tokens) or "ast" (characters)
safety_factor: Safety margin (0.9 = 10% safety margin)
Returns:
Safe chunk size: tokens for traditional, characters for AST
"""
safe_limit = int(model_token_limit * safety_factor)
if chunking_mode == "traditional":
# Traditional chunking uses tokens
# Max chunk = chunk_size + overlap, so chunk_size = limit - overlap
return max(1, safe_limit - overlap_tokens)
else: # AST chunking
# AST uses characters, need to convert
# Conservative estimate: 1.2 tokens per char for code
overlap_chars = int(overlap_tokens * 3) # ~3 chars per token for code
safe_chars = int(safe_limit / 1.2)
return max(1, safe_chars - overlap_chars)
def validate_chunk_token_limits(chunks: list[str], max_tokens: int = 512) -> tuple[list[str], int]:
"""
Validate that chunks don't exceed token limits and truncate if necessary.
Args:
chunks: List of text chunks to validate
max_tokens: Maximum tokens allowed per chunk
Returns:
Tuple of (validated_chunks, num_truncated)
"""
validated_chunks = []
num_truncated = 0
for i, chunk in enumerate(chunks):
estimated_tokens = estimate_token_count(chunk)
if estimated_tokens > max_tokens:
# Truncate chunk to fit token limit
try:
import tiktoken
encoder = tiktoken.get_encoding("cl100k_base")
tokens = encoder.encode(chunk)
if len(tokens) > max_tokens:
truncated_tokens = tokens[:max_tokens]
truncated_chunk = encoder.decode(truncated_tokens)
validated_chunks.append(truncated_chunk)
num_truncated += 1
logger.warning(
f"Truncated chunk {i} from {len(tokens)} to {max_tokens} tokens "
f"(from {len(chunk)} to {len(truncated_chunk)} characters)"
)
else:
validated_chunks.append(chunk)
except ImportError:
# Fallback: Conservative character truncation
char_limit = int(max_tokens / 1.2) # Conservative for code
if len(chunk) > char_limit:
truncated_chunk = chunk[:char_limit]
validated_chunks.append(truncated_chunk)
num_truncated += 1
logger.warning(
f"Truncated chunk {i} from {len(chunk)} to {char_limit} characters "
f"(conservative estimate for {max_tokens} tokens)"
)
else:
validated_chunks.append(chunk)
else:
validated_chunks.append(chunk)
if num_truncated > 0:
logger.warning(f"Truncated {num_truncated}/{len(chunks)} chunks to fit token limits")
return validated_chunks, num_truncated
# Code file extensions supported by astchunk
CODE_EXTENSIONS = {
".py": "python",
@@ -61,27 +177,45 @@ def create_ast_chunks(
max_chunk_size: int = 512,
chunk_overlap: int = 64,
metadata_template: str = "default",
) -> list[str]:
) -> list[dict[str, Any]]:
"""Create AST-aware chunks from code documents using astchunk.
Falls back to traditional chunking if astchunk is unavailable.
Returns:
List of dicts with {"text": str, "metadata": dict}
"""
try:
from astchunk import ASTChunkBuilder # optional dependency
except ImportError as e:
logger.error(f"astchunk not available: {e}")
logger.info("Falling back to traditional chunking for code files")
return create_traditional_chunks(documents, max_chunk_size, chunk_overlap)
return _traditional_chunks_as_dicts(documents, max_chunk_size, chunk_overlap)
all_chunks = []
for doc in documents:
language = doc.metadata.get("language")
if not language:
logger.warning("No language detected; falling back to traditional chunking")
all_chunks.extend(create_traditional_chunks([doc], max_chunk_size, chunk_overlap))
all_chunks.extend(_traditional_chunks_as_dicts([doc], max_chunk_size, chunk_overlap))
continue
try:
# Warn once if AST chunk size + overlap might exceed common token limits
# Note: Actual truncation happens at embedding time with dynamic model limits
global _ast_token_warning_shown
estimated_max_tokens = int(
(max_chunk_size + chunk_overlap) * 1.2
) # Conservative estimate
if estimated_max_tokens > 512 and not _ast_token_warning_shown:
logger.warning(
f"AST chunk size ({max_chunk_size}) + overlap ({chunk_overlap}) = {max_chunk_size + chunk_overlap} chars "
f"may exceed 512 token limit (~{estimated_max_tokens} tokens estimated). "
f"Consider reducing --ast-chunk-size to {int(400 / 1.2)} or --ast-chunk-overlap to {int(50 / 1.2)}. "
f"Note: Chunks will be auto-truncated at embedding time based on your model's actual token limit."
)
_ast_token_warning_shown = True
configs = {
"max_chunk_size": max_chunk_size,
"language": language,
@@ -105,17 +239,40 @@ def create_ast_chunks(
chunks = chunk_builder.chunkify(code_content)
for chunk in chunks:
chunk_text = None
astchunk_metadata = {}
if hasattr(chunk, "text"):
chunk_text = chunk.text
elif isinstance(chunk, dict) and "text" in chunk:
chunk_text = chunk["text"]
elif isinstance(chunk, str):
chunk_text = chunk
elif isinstance(chunk, dict):
# Handle astchunk format: {"content": "...", "metadata": {...}}
if "content" in chunk:
chunk_text = chunk["content"]
astchunk_metadata = chunk.get("metadata", {})
elif "text" in chunk:
chunk_text = chunk["text"]
else:
chunk_text = str(chunk) # Last resort
else:
chunk_text = str(chunk)
if chunk_text and chunk_text.strip():
all_chunks.append(chunk_text.strip())
# Extract document-level metadata
doc_metadata = {
"file_path": doc.metadata.get("file_path", ""),
"file_name": doc.metadata.get("file_name", ""),
}
if "creation_date" in doc.metadata:
doc_metadata["creation_date"] = doc.metadata["creation_date"]
if "last_modified_date" in doc.metadata:
doc_metadata["last_modified_date"] = doc.metadata["last_modified_date"]
# Merge document metadata + astchunk metadata
combined_metadata = {**doc_metadata, **astchunk_metadata}
all_chunks.append({"text": chunk_text.strip(), "metadata": combined_metadata})
logger.info(
f"Created {len(chunks)} AST chunks from {language} file: {doc.metadata.get('file_name', 'unknown')}"
@@ -123,15 +280,19 @@ def create_ast_chunks(
except Exception as e:
logger.warning(f"AST chunking failed for {language} file: {e}")
logger.info("Falling back to traditional chunking")
all_chunks.extend(create_traditional_chunks([doc], max_chunk_size, chunk_overlap))
all_chunks.extend(_traditional_chunks_as_dicts([doc], max_chunk_size, chunk_overlap))
return all_chunks
def create_traditional_chunks(
documents, chunk_size: int = 256, chunk_overlap: int = 128
) -> list[str]:
"""Create traditional text chunks using LlamaIndex SentenceSplitter."""
) -> list[dict[str, Any]]:
"""Create traditional text chunks using LlamaIndex SentenceSplitter.
Returns:
List of dicts with {"text": str, "metadata": dict}
"""
if chunk_size <= 0:
logger.warning(f"Invalid chunk_size={chunk_size}, using default value of 256")
chunk_size = 256
@@ -147,19 +308,40 @@ def create_traditional_chunks(
paragraph_separator="\n\n",
)
all_texts = []
result = []
for doc in documents:
# Extract document-level metadata
doc_metadata = {
"file_path": doc.metadata.get("file_path", ""),
"file_name": doc.metadata.get("file_name", ""),
}
if "creation_date" in doc.metadata:
doc_metadata["creation_date"] = doc.metadata["creation_date"]
if "last_modified_date" in doc.metadata:
doc_metadata["last_modified_date"] = doc.metadata["last_modified_date"]
try:
nodes = node_parser.get_nodes_from_documents([doc])
if nodes:
all_texts.extend(node.get_content() for node in nodes)
for node in nodes:
result.append({"text": node.get_content(), "metadata": doc_metadata})
except Exception as e:
logger.error(f"Traditional chunking failed for document: {e}")
content = doc.get_content()
if content and content.strip():
all_texts.append(content.strip())
result.append({"text": content.strip(), "metadata": doc_metadata})
return all_texts
return result
def _traditional_chunks_as_dicts(
documents, chunk_size: int = 256, chunk_overlap: int = 128
) -> list[dict[str, Any]]:
"""Helper: Traditional chunking that returns dict format for consistency.
This is now just an alias for create_traditional_chunks for backwards compatibility.
"""
return create_traditional_chunks(documents, chunk_size, chunk_overlap)
def create_text_chunks(
@@ -171,8 +353,12 @@ def create_text_chunks(
ast_chunk_overlap: int = 64,
code_file_extensions: Optional[list[str]] = None,
ast_fallback_traditional: bool = True,
) -> list[str]:
"""Create text chunks from documents with optional AST support for code files."""
) -> list[dict[str, Any]]:
"""Create text chunks from documents with optional AST support for code files.
Returns:
List of dicts with {"text": str, "metadata": dict}
"""
if not documents:
logger.warning("No documents provided for chunking")
return []
@@ -207,14 +393,17 @@ def create_text_chunks(
logger.error(f"AST chunking failed: {e}")
if ast_fallback_traditional:
all_chunks.extend(
create_traditional_chunks(code_docs, chunk_size, chunk_overlap)
_traditional_chunks_as_dicts(code_docs, chunk_size, chunk_overlap)
)
else:
raise
if text_docs:
all_chunks.extend(create_traditional_chunks(text_docs, chunk_size, chunk_overlap))
all_chunks.extend(_traditional_chunks_as_dicts(text_docs, chunk_size, chunk_overlap))
else:
all_chunks = create_traditional_chunks(documents, chunk_size, chunk_overlap)
all_chunks = _traditional_chunks_as_dicts(documents, chunk_size, chunk_overlap)
logger.info(f"Total chunks created: {len(all_chunks)}")
# Note: Token truncation is now handled at embedding time with dynamic model limits
# See get_model_token_limit() and truncate_to_token_limit() in embedding_compute.py
return all_chunks

View File

@@ -181,25 +181,25 @@ Examples:
"--doc-chunk-size",
type=int,
default=256,
help="Document chunk size in tokens/characters (default: 256)",
help="Document chunk size in TOKENS (default: 256). Final chunks may be larger due to overlap. For 512 token models: recommended 350 tokens (350 + 128 overlap = 478 max)",
)
build_parser.add_argument(
"--doc-chunk-overlap",
type=int,
default=128,
help="Document chunk overlap (default: 128)",
help="Document chunk overlap in TOKENS (default: 128). Added to chunk size, not included in it",
)
build_parser.add_argument(
"--code-chunk-size",
type=int,
default=512,
help="Code chunk size in tokens/lines (default: 512)",
help="Code chunk size in TOKENS (default: 512). Final chunks may be larger due to overlap. For 512 token models: recommended 400 tokens (400 + 50 overlap = 450 max)",
)
build_parser.add_argument(
"--code-chunk-overlap",
type=int,
default=50,
help="Code chunk overlap (default: 50)",
help="Code chunk overlap in TOKENS (default: 50). Added to chunk size, not included in it",
)
build_parser.add_argument(
"--use-ast-chunking",
@@ -209,14 +209,14 @@ Examples:
build_parser.add_argument(
"--ast-chunk-size",
type=int,
default=768,
help="AST chunk size in characters (default: 768)",
default=300,
help="AST chunk size in CHARACTERS (non-whitespace) (default: 300). Final chunks may be larger due to overlap and expansion. For 512 token models: recommended 300 chars (300 + 64 overlap ~= 480 tokens)",
)
build_parser.add_argument(
"--ast-chunk-overlap",
type=int,
default=96,
help="AST chunk overlap in characters (default: 96)",
default=64,
help="AST chunk overlap in CHARACTERS (default: 64). Added to chunk size, not included in it. ~1.2 tokens per character for code",
)
build_parser.add_argument(
"--ast-fallback-traditional",
@@ -1279,13 +1279,8 @@ Examples:
ast_fallback_traditional=getattr(args, "ast_fallback_traditional", True),
)
# Note: AST chunking currently returns plain text chunks without metadata
# We preserve basic file info by associating chunks with their source documents
# For better metadata preservation, documents list order should be maintained
for chunk_text in chunk_texts:
# TODO: Enhance create_text_chunks to return metadata alongside text
# For now, we store chunks with empty metadata
all_texts.append({"text": chunk_text, "metadata": {}})
# create_text_chunks now returns list[dict] with metadata preserved
all_texts.extend(chunk_texts)
except ImportError as e:
print(

View File

@@ -10,6 +10,7 @@ import time
from typing import Any, Optional
import numpy as np
import tiktoken
import torch
from .settings import resolve_ollama_host, resolve_openai_api_key, resolve_openai_base_url
@@ -20,6 +21,170 @@ LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
log_level = getattr(logging, LOG_LEVEL, logging.WARNING)
logger.setLevel(log_level)
# Token limit registry for embedding models
# Used as fallback when dynamic discovery fails (e.g., LM Studio, OpenAI)
# Ollama models use dynamic discovery via /api/show
EMBEDDING_MODEL_LIMITS = {
# Nomic models (common across servers)
"nomic-embed-text": 2048, # Corrected from 512 - verified via /api/show
"nomic-embed-text-v1.5": 2048,
"nomic-embed-text-v2": 512,
# Other embedding models
"mxbai-embed-large": 512,
"all-minilm": 512,
"bge-m3": 8192,
"snowflake-arctic-embed": 512,
# OpenAI models
"text-embedding-3-small": 8192,
"text-embedding-3-large": 8192,
"text-embedding-ada-002": 8192,
}
def get_model_token_limit(
model_name: str,
base_url: Optional[str] = None,
default: int = 2048,
) -> int:
"""
Get token limit for a given embedding model.
Uses hybrid approach: dynamic discovery for Ollama, registry fallback for others.
Args:
model_name: Name of the embedding model
base_url: Base URL of the embedding server (for dynamic discovery)
default: Default token limit if model not found
Returns:
Token limit for the model in tokens
"""
# Try Ollama dynamic discovery if base_url provided
if base_url:
# Detect Ollama servers by port or "ollama" in URL
if "11434" in base_url or "ollama" in base_url.lower():
limit = _query_ollama_context_limit(model_name, base_url)
if limit:
return limit
# Fallback to known model registry with version handling (from PR #154)
# Handle versioned model names (e.g., "nomic-embed-text:latest" -> "nomic-embed-text")
base_model_name = model_name.split(":")[0]
# Check exact match first
if model_name in EMBEDDING_MODEL_LIMITS:
return EMBEDDING_MODEL_LIMITS[model_name]
# Check base name match
if base_model_name in EMBEDDING_MODEL_LIMITS:
return EMBEDDING_MODEL_LIMITS[base_model_name]
# Check partial matches for common patterns
for known_model, limit in EMBEDDING_MODEL_LIMITS.items():
if known_model in base_model_name or base_model_name in known_model:
return limit
# Default fallback
logger.warning(f"Unknown model '{model_name}', using default {default} token limit")
return default
def truncate_to_token_limit(texts: list[str], token_limit: int) -> list[str]:
"""
Truncate texts to fit within token limit using tiktoken.
Args:
texts: List of text strings to truncate
token_limit: Maximum number of tokens allowed
Returns:
List of truncated texts (same length as input)
"""
if not texts:
return []
# Use tiktoken with cl100k_base encoding
enc = tiktoken.get_encoding("cl100k_base")
truncated_texts = []
truncation_count = 0
total_tokens_removed = 0
max_original_length = 0
for i, text in enumerate(texts):
tokens = enc.encode(text)
original_length = len(tokens)
if original_length <= token_limit:
# Text is within limit, keep as is
truncated_texts.append(text)
else:
# Truncate to token_limit
truncated_tokens = tokens[:token_limit]
truncated_text = enc.decode(truncated_tokens)
truncated_texts.append(truncated_text)
# Track truncation statistics
truncation_count += 1
tokens_removed = original_length - token_limit
total_tokens_removed += tokens_removed
max_original_length = max(max_original_length, original_length)
# Log individual truncation at WARNING level (first few only)
if truncation_count <= 3:
logger.warning(
f"Text {i + 1} truncated: {original_length}{token_limit} tokens "
f"({tokens_removed} tokens removed)"
)
elif truncation_count == 4:
logger.warning("Further truncation warnings suppressed...")
# Log summary at INFO level
if truncation_count > 0:
logger.warning(
f"Truncation summary: {truncation_count}/{len(texts)} texts truncated "
f"(removed {total_tokens_removed} tokens total, longest was {max_original_length} tokens)"
)
else:
logger.debug(
f"No truncation needed - all {len(texts)} texts within {token_limit} token limit"
)
return truncated_texts
def _query_ollama_context_limit(model_name: str, base_url: str) -> Optional[int]:
"""
Query Ollama /api/show for model context limit.
Args:
model_name: Name of the Ollama model
base_url: Base URL of the Ollama server
Returns:
Context limit in tokens if found, None otherwise
"""
try:
import requests
response = requests.post(
f"{base_url}/api/show",
json={"name": model_name},
timeout=5,
)
if response.status_code == 200:
data = response.json()
if "model_info" in data:
# Look for *.context_length in model_info
for key, value in data["model_info"].items():
if "context_length" in key and isinstance(value, int):
logger.info(f"Detected {model_name} context limit: {value} tokens")
return value
except Exception as e:
logger.debug(f"Failed to query Ollama context limit: {e}")
return None
# Global model cache to avoid repeated loading
_model_cache: dict[str, Any] = {}
@@ -50,9 +215,14 @@ def compute_embeddings(
Normalized embeddings array, shape: (len(texts), embedding_dim)
"""
provider_options = provider_options or {}
wrapper_start_time = time.time()
logger.debug(
f"[compute_embeddings] entry: mode={mode}, model='{model_name}', text_count={len(texts)}"
)
if mode == "sentence-transformers":
return compute_embeddings_sentence_transformers(
inner_start_time = time.time()
result = compute_embeddings_sentence_transformers(
texts,
model_name,
is_build=is_build,
@@ -61,6 +231,14 @@ def compute_embeddings(
manual_tokenize=manual_tokenize,
max_length=max_length,
)
inner_end_time = time.time()
wrapper_end_time = time.time()
logger.debug(
"[compute_embeddings] sentence-transformers timings: "
f"inner={inner_end_time - inner_start_time:.6f}s, "
f"wrapper_total={wrapper_end_time - wrapper_start_time:.6f}s"
)
return result
elif mode == "openai":
return compute_embeddings_openai(
texts,
@@ -106,6 +284,7 @@ def compute_embeddings_sentence_transformers(
is_build: Whether this is a build operation (shows progress bar)
adaptive_optimization: Whether to use adaptive optimization based on batch size
"""
outer_start_time = time.time()
# Handle empty input
if not texts:
raise ValueError("Cannot compute embeddings for empty text list")
@@ -136,7 +315,14 @@ def compute_embeddings_sentence_transformers(
# Create cache key
cache_key = f"sentence_transformers_{model_name}_{device}_{use_fp16}_optimized"
pre_model_init_end_time = time.time()
logger.debug(
"compute_embeddings_sentence_transformers pre-model-init time "
f"(device/batch selection etc.): {pre_model_init_end_time - outer_start_time:.6f}s"
)
# Check if model is already cached
start_time = time.time()
if cache_key in _model_cache:
logger.info(f"Using cached optimized model: {model_name}")
model = _model_cache[cache_key]
@@ -276,10 +462,13 @@ def compute_embeddings_sentence_transformers(
_model_cache[cache_key] = model
logger.info(f"Model cached: {cache_key}")
# Compute embeddings with optimized inference mode
logger.info(
f"Starting embedding computation... (batch_size: {batch_size}, manual_tokenize={manual_tokenize})"
)
end_time = time.time()
# Compute embeddings with optimized inference mode
logger.info(
f"Starting embedding computation... (batch_size: {batch_size}, manual_tokenize={manual_tokenize})"
)
logger.info(f"start sentence transformers {model} takes {end_time - start_time}")
start_time = time.time()
if not manual_tokenize:
@@ -300,32 +489,46 @@ def compute_embeddings_sentence_transformers(
except Exception:
pass
else:
# Manual tokenization + forward pass using HF AutoTokenizer/AutoModel
# Manual tokenization + forward pass using HF AutoTokenizer/AutoModel.
# This path is reserved for an aggressively optimized FP pipeline
# (no quantization), mainly for experimentation.
try:
from transformers import AutoModel, AutoTokenizer # type: ignore
except Exception as e:
raise ImportError(f"transformers is required for manual_tokenize=True: {e}")
# Cache tokenizer and model
tok_cache_key = f"hf_tokenizer_{model_name}"
mdl_cache_key = f"hf_model_{model_name}_{device}_{use_fp16}"
mdl_cache_key = f"hf_model_{model_name}_{device}_{use_fp16}_fp"
if tok_cache_key in _model_cache and mdl_cache_key in _model_cache:
hf_tokenizer = _model_cache[tok_cache_key]
hf_model = _model_cache[mdl_cache_key]
logger.info("Using cached HF tokenizer/model for manual path")
logger.info("Using cached HF tokenizer/model for manual FP path")
else:
logger.info("Loading HF tokenizer/model for manual tokenization path")
logger.info("Loading HF tokenizer/model for manual FP path")
hf_tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
torch_dtype = torch.float16 if (use_fp16 and device == "cuda") else torch.float32
hf_model = AutoModel.from_pretrained(model_name, torch_dtype=torch_dtype)
hf_model = AutoModel.from_pretrained(
model_name,
torch_dtype=torch_dtype,
)
hf_model.to(device)
hf_model.eval()
# Optional compile on supported devices
if device in ["cuda", "mps"]:
try:
hf_model = torch.compile(hf_model, mode="reduce-overhead", dynamic=True) # type: ignore
except Exception:
pass
hf_model = torch.compile( # type: ignore
hf_model, mode="reduce-overhead", dynamic=True
)
logger.info(
f"Applied torch.compile to HF model for {model_name} "
f"(device={device}, dtype={torch_dtype})"
)
except Exception as exc:
logger.warning(f"torch.compile optimization failed: {exc}")
_model_cache[tok_cache_key] = hf_tokenizer
_model_cache[mdl_cache_key] = hf_model
@@ -351,7 +554,6 @@ def compute_embeddings_sentence_transformers(
for start_index in batch_iter:
end_index = min(start_index + batch_size, len(texts))
batch_texts = texts[start_index:end_index]
tokenize_start_time = time.time()
inputs = hf_tokenizer(
batch_texts,
padding=True,
@@ -359,34 +561,17 @@ def compute_embeddings_sentence_transformers(
max_length=max_length,
return_tensors="pt",
)
tokenize_end_time = time.time()
logger.info(
f"Tokenize time taken: {tokenize_end_time - tokenize_start_time} seconds"
)
# Print shapes of all input tensors for debugging
for k, v in inputs.items():
print(f"inputs[{k!r}] shape: {getattr(v, 'shape', type(v))}")
to_device_start_time = time.time()
inputs = {k: v.to(device) for k, v in inputs.items()}
to_device_end_time = time.time()
logger.info(
f"To device time taken: {to_device_end_time - to_device_start_time} seconds"
)
forward_start_time = time.time()
outputs = hf_model(**inputs)
forward_end_time = time.time()
logger.info(f"Forward time taken: {forward_end_time - forward_start_time} seconds")
last_hidden_state = outputs.last_hidden_state # (B, L, H)
attention_mask = inputs.get("attention_mask")
if attention_mask is None:
# Fallback: assume all tokens are valid
pooled = last_hidden_state.mean(dim=1)
else:
mask = attention_mask.unsqueeze(-1).to(last_hidden_state.dtype)
masked = last_hidden_state * mask
lengths = mask.sum(dim=1).clamp(min=1)
pooled = masked.sum(dim=1) / lengths
# Move to CPU float32
batch_embeddings = pooled.detach().to("cpu").float().numpy()
all_embeddings.append(batch_embeddings)
@@ -406,6 +591,12 @@ def compute_embeddings_sentence_transformers(
if np.isnan(embeddings).any() or np.isinf(embeddings).any():
raise RuntimeError(f"Detected NaN or Inf values in embeddings, model: {model_name}")
outer_end_time = time.time()
logger.debug(
"compute_embeddings_sentence_transformers total time "
f"(function entry -> return): {outer_end_time - outer_start_time:.6f}s"
)
return embeddings
@@ -720,20 +911,26 @@ def compute_embeddings_ollama(
logger.info(f"Using batch size: {batch_size} for true batch processing")
# Get model token limit and apply truncation before batching
token_limit = get_model_token_limit(model_name, base_url=resolved_host)
logger.info(f"Model '{model_name}' token limit: {token_limit}")
# Apply truncation to all texts before batch processing
# Function logs truncation details internally
texts = truncate_to_token_limit(texts, token_limit)
def get_batch_embeddings(batch_texts):
"""Get embeddings for a batch of texts using /api/embed endpoint."""
max_retries = 3
retry_count = 0
# Truncate very long texts to avoid API issues
truncated_texts = [text[:8000] if len(text) > 8000 else text for text in batch_texts]
# Texts are already truncated to token limit by the outer function
while retry_count < max_retries:
try:
# Use /api/embed endpoint with "input" parameter for batch processing
response = requests.post(
f"{resolved_host}/api/embed",
json={"model": model_name, "input": truncated_texts},
json={"model": model_name, "input": batch_texts},
timeout=60, # Increased timeout for batch processing
)
response.raise_for_status()
@@ -763,7 +960,17 @@ def compute_embeddings_ollama(
except Exception as e:
retry_count += 1
if retry_count >= max_retries:
logger.error(f"Failed to get embeddings for batch: {e}")
# Enhanced error detection for token limit violations
error_msg = str(e).lower()
if "token" in error_msg and (
"limit" in error_msg or "exceed" in error_msg or "length" in error_msg
):
logger.error(
f"Token limit exceeded for batch. Error: {e}. "
f"Consider reducing chunk sizes or check token truncation."
)
else:
logger.error(f"Failed to get embeddings for batch: {e}")
return None, list(range(len(batch_texts)))
return None, list(range(len(batch_texts)))

View File

@@ -57,6 +57,8 @@ dependencies = [
"tree-sitter-c-sharp>=0.20.0",
"tree-sitter-typescript>=0.20.0",
"torchvision>=0.23.0",
"einops",
"seaborn",
]
[project.optional-dependencies]

View File

@@ -8,7 +8,7 @@ import subprocess
import sys
import tempfile
from pathlib import Path
from unittest.mock import patch
from unittest.mock import Mock, patch
import pytest
@@ -116,8 +116,10 @@ class TestChunkingFunctions:
chunks = create_traditional_chunks(docs, chunk_size=50, chunk_overlap=10)
assert len(chunks) > 0
assert all(isinstance(chunk, str) for chunk in chunks)
assert all(len(chunk.strip()) > 0 for chunk in chunks)
# Traditional chunks now return dict format for consistency
assert all(isinstance(chunk, dict) for chunk in chunks)
assert all("text" in chunk and "metadata" in chunk for chunk in chunks)
assert all(len(chunk["text"].strip()) > 0 for chunk in chunks)
def test_create_traditional_chunks_empty_docs(self):
"""Test traditional chunking with empty documents."""
@@ -158,11 +160,22 @@ class Calculator:
# Should have multiple chunks due to different functions/classes
assert len(chunks) > 0
assert all(isinstance(chunk, str) for chunk in chunks)
assert all(len(chunk.strip()) > 0 for chunk in chunks)
# R3: Expect dict format with "text" and "metadata" keys
assert all(isinstance(chunk, dict) for chunk in chunks), "All chunks should be dicts"
assert all("text" in chunk and "metadata" in chunk for chunk in chunks), (
"Each chunk should have 'text' and 'metadata' keys"
)
assert all(len(chunk["text"].strip()) > 0 for chunk in chunks), (
"Each chunk text should be non-empty"
)
# Check metadata is present
assert all("file_path" in chunk["metadata"] for chunk in chunks), (
"Each chunk should have file_path metadata"
)
# Check that code structure is somewhat preserved
combined_content = " ".join(chunks)
combined_content = " ".join([c["text"] for c in chunks])
assert "def hello_world" in combined_content
assert "class Calculator" in combined_content
@@ -194,7 +207,11 @@ class Calculator:
chunks = create_text_chunks(docs, use_ast_chunking=False, chunk_size=50, chunk_overlap=10)
assert len(chunks) > 0
assert all(isinstance(chunk, str) for chunk in chunks)
# R3: Traditional chunking should also return dict format for consistency
assert all(isinstance(chunk, dict) for chunk in chunks), "All chunks should be dicts"
assert all("text" in chunk and "metadata" in chunk for chunk in chunks), (
"Each chunk should have 'text' and 'metadata' keys"
)
def test_create_text_chunks_ast_mode(self):
"""Test text chunking in AST mode."""
@@ -213,7 +230,11 @@ class Calculator:
)
assert len(chunks) > 0
assert all(isinstance(chunk, str) for chunk in chunks)
# R3: AST mode should also return dict format
assert all(isinstance(chunk, dict) for chunk in chunks), "All chunks should be dicts"
assert all("text" in chunk and "metadata" in chunk for chunk in chunks), (
"Each chunk should have 'text' and 'metadata' keys"
)
def test_create_text_chunks_custom_extensions(self):
"""Test text chunking with custom code file extensions."""
@@ -353,6 +374,552 @@ class MathUtils:
pytest.skip("Test timed out - likely due to model download in CI")
class TestASTContentExtraction:
"""Test AST content extraction bug fix.
These tests verify that astchunk's dict format with 'content' key is handled correctly,
and that the extraction logic doesn't fall through to stringifying entire dicts.
"""
def test_extract_content_from_astchunk_dict(self):
"""Test that astchunk dict format with 'content' key is handled correctly.
Bug: Current code checks for chunk["text"] but astchunk returns chunk["content"].
This causes fallthrough to str(chunk), stringifying the entire dict.
This test will FAIL until the bug is fixed because:
- Current code will stringify the dict: "{'content': '...', 'metadata': {...}}"
- Fixed code should extract just the content value
"""
# Mock the ASTChunkBuilder class
mock_builder = Mock()
# Astchunk returns this format
astchunk_format_chunk = {
"content": "def hello():\n print('world')",
"metadata": {
"filepath": "test.py",
"line_count": 2,
"start_line_no": 0,
"end_line_no": 1,
"node_count": 1,
},
}
mock_builder.chunkify.return_value = [astchunk_format_chunk]
# Create mock document
doc = MockDocument(
"def hello():\n print('world')", "/test/test.py", {"language": "python"}
)
# Mock the astchunk module and its ASTChunkBuilder class
mock_astchunk = Mock()
mock_astchunk.ASTChunkBuilder = Mock(return_value=mock_builder)
# Patch sys.modules to inject our mock before the import
with patch.dict("sys.modules", {"astchunk": mock_astchunk}):
# Call create_ast_chunks
chunks = create_ast_chunks([doc])
# R3: Should return dict format with proper metadata
assert len(chunks) > 0, "Should return at least one chunk"
# R3: Each chunk should be a dict
chunk = chunks[0]
assert isinstance(chunk, dict), "Chunk should be a dict"
assert "text" in chunk, "Chunk should have 'text' key"
assert "metadata" in chunk, "Chunk should have 'metadata' key"
chunk_text = chunk["text"]
# CRITICAL: Should NOT contain stringified dict markers in the text field
# These assertions will FAIL with current buggy code
assert "'content':" not in chunk_text, (
f"Chunk text contains stringified dict - extraction failed! Got: {chunk_text[:100]}..."
)
assert "'metadata':" not in chunk_text, (
"Chunk text contains stringified metadata - extraction failed! "
f"Got: {chunk_text[:100]}..."
)
assert "{" not in chunk_text or "def hello" in chunk_text.split("{")[0], (
"Chunk text appears to be a stringified dict"
)
# Should contain actual content
assert "def hello()" in chunk_text, "Should extract actual code content"
assert "print('world')" in chunk_text, "Should extract complete code content"
# R3: Should preserve astchunk metadata
assert "filepath" in chunk["metadata"] or "file_path" in chunk["metadata"], (
"Should preserve file path metadata"
)
def test_extract_text_key_fallback(self):
"""Test that 'text' key still works for backward compatibility.
Some chunks might use 'text' instead of 'content' - ensure backward compatibility.
This test should PASS even with current code.
"""
mock_builder = Mock()
# Some chunks might use "text" key
text_key_chunk = {"text": "def legacy_function():\n return True"}
mock_builder.chunkify.return_value = [text_key_chunk]
# Create mock document
doc = MockDocument(
"def legacy_function():\n return True", "/test/legacy.py", {"language": "python"}
)
# Mock the astchunk module
mock_astchunk = Mock()
mock_astchunk.ASTChunkBuilder = Mock(return_value=mock_builder)
with patch.dict("sys.modules", {"astchunk": mock_astchunk}):
# Call create_ast_chunks
chunks = create_ast_chunks([doc])
# R3: Should extract text correctly as dict format
assert len(chunks) > 0
chunk = chunks[0]
assert isinstance(chunk, dict), "Chunk should be a dict"
assert "text" in chunk, "Chunk should have 'text' key"
chunk_text = chunk["text"]
# Should NOT be stringified
assert "'text':" not in chunk_text, "Should not stringify dict with 'text' key"
# Should contain actual content
assert "def legacy_function()" in chunk_text
assert "return True" in chunk_text
def test_handles_string_chunks(self):
"""Test that plain string chunks still work.
Some chunkers might return plain strings - verify these are preserved.
This test should PASS with current code.
"""
mock_builder = Mock()
# Plain string chunk
plain_string_chunk = "def simple_function():\n pass"
mock_builder.chunkify.return_value = [plain_string_chunk]
# Create mock document
doc = MockDocument(
"def simple_function():\n pass", "/test/simple.py", {"language": "python"}
)
# Mock the astchunk module
mock_astchunk = Mock()
mock_astchunk.ASTChunkBuilder = Mock(return_value=mock_builder)
with patch.dict("sys.modules", {"astchunk": mock_astchunk}):
# Call create_ast_chunks
chunks = create_ast_chunks([doc])
# R3: Should wrap string in dict format
assert len(chunks) > 0
chunk = chunks[0]
assert isinstance(chunk, dict), "Even string chunks should be wrapped in dict"
assert "text" in chunk, "Chunk should have 'text' key"
chunk_text = chunk["text"]
assert chunk_text == plain_string_chunk.strip(), (
"Should preserve plain string chunk content"
)
assert "def simple_function()" in chunk_text
assert "pass" in chunk_text
def test_multiple_chunks_with_mixed_formats(self):
"""Test handling of multiple chunks with different formats.
Real-world scenario: astchunk might return a mix of formats.
This test will FAIL if any chunk with 'content' key gets stringified.
"""
mock_builder = Mock()
# Mix of formats
mixed_chunks = [
{"content": "def first():\n return 1", "metadata": {"line_count": 2}},
"def second():\n return 2", # Plain string
{"text": "def third():\n return 3"}, # Old format
{"content": "class MyClass:\n pass", "metadata": {"node_count": 1}},
]
mock_builder.chunkify.return_value = mixed_chunks
# Create mock document
code = "def first():\n return 1\n\ndef second():\n return 2\n\ndef third():\n return 3\n\nclass MyClass:\n pass"
doc = MockDocument(code, "/test/mixed.py", {"language": "python"})
# Mock the astchunk module
mock_astchunk = Mock()
mock_astchunk.ASTChunkBuilder = Mock(return_value=mock_builder)
with patch.dict("sys.modules", {"astchunk": mock_astchunk}):
# Call create_ast_chunks
chunks = create_ast_chunks([doc])
# R3: Should extract all chunks correctly as dicts
assert len(chunks) == 4, "Should extract all 4 chunks"
# Check each chunk
for i, chunk in enumerate(chunks):
assert isinstance(chunk, dict), f"Chunk {i} should be a dict"
assert "text" in chunk, f"Chunk {i} should have 'text' key"
assert "metadata" in chunk, f"Chunk {i} should have 'metadata' key"
chunk_text = chunk["text"]
# None should be stringified dicts
assert "'content':" not in chunk_text, f"Chunk {i} text is stringified (has 'content':)"
assert "'metadata':" not in chunk_text, (
f"Chunk {i} text is stringified (has 'metadata':)"
)
assert "'text':" not in chunk_text, f"Chunk {i} text is stringified (has 'text':)"
# Verify actual content is present
combined = "\n".join([c["text"] for c in chunks])
assert "def first()" in combined
assert "def second()" in combined
assert "def third()" in combined
assert "class MyClass:" in combined
def test_empty_content_value_handling(self):
"""Test handling of chunks with empty content values.
Edge case: chunk has 'content' key but value is empty.
Should skip these chunks, not stringify them.
"""
mock_builder = Mock()
chunks_with_empty = [
{"content": "", "metadata": {"line_count": 0}}, # Empty content
{"content": " ", "metadata": {"line_count": 1}}, # Whitespace only
{"content": "def valid():\n return True", "metadata": {"line_count": 2}}, # Valid
]
mock_builder.chunkify.return_value = chunks_with_empty
doc = MockDocument(
"def valid():\n return True", "/test/empty.py", {"language": "python"}
)
# Mock the astchunk module
mock_astchunk = Mock()
mock_astchunk.ASTChunkBuilder = Mock(return_value=mock_builder)
with patch.dict("sys.modules", {"astchunk": mock_astchunk}):
chunks = create_ast_chunks([doc])
# R3: Should only have the valid chunk (empty ones filtered out)
assert len(chunks) == 1, "Should filter out empty content chunks"
chunk = chunks[0]
assert isinstance(chunk, dict), "Chunk should be a dict"
assert "text" in chunk, "Chunk should have 'text' key"
assert "def valid()" in chunk["text"]
# Should not have stringified the empty dict
assert "'content': ''" not in chunk["text"]
class TestASTMetadataPreservation:
"""Test metadata preservation in AST chunk dictionaries.
R3: These tests define the contract for metadata preservation when returning
chunk dictionaries instead of plain strings. Each chunk dict should have:
- "text": str - the actual chunk content
- "metadata": dict - all metadata from document AND astchunk
These tests will FAIL until G3 implementation changes return type to list[dict].
"""
def test_ast_chunks_preserve_file_metadata(self):
"""Test that document metadata is preserved in chunk metadata.
This test verifies that all document-level metadata (file_path, file_name,
creation_date, last_modified_date) is included in each chunk's metadata dict.
This will FAIL because current code returns list[str], not list[dict].
"""
# Create mock document with rich metadata
python_code = '''
def calculate_sum(numbers):
"""Calculate sum of numbers."""
return sum(numbers)
class DataProcessor:
"""Process data records."""
def process(self, data):
return [x * 2 for x in data]
'''
doc = MockDocument(
python_code,
file_path="/project/src/utils.py",
metadata={
"language": "python",
"file_path": "/project/src/utils.py",
"file_name": "utils.py",
"creation_date": "2024-01-15T10:30:00",
"last_modified_date": "2024-10-31T15:45:00",
},
)
# Mock astchunk to return chunks with metadata
mock_builder = Mock()
astchunk_chunks = [
{
"content": "def calculate_sum(numbers):\n return sum(numbers)",
"metadata": {
"filepath": "/project/src/utils.py",
"line_count": 2,
"start_line_no": 1,
"end_line_no": 2,
"node_count": 1,
},
},
{
"content": "class DataProcessor:\n def process(self, data):\n return [x * 2 for x in data]",
"metadata": {
"filepath": "/project/src/utils.py",
"line_count": 3,
"start_line_no": 5,
"end_line_no": 7,
"node_count": 2,
},
},
]
mock_builder.chunkify.return_value = astchunk_chunks
mock_astchunk = Mock()
mock_astchunk.ASTChunkBuilder = Mock(return_value=mock_builder)
with patch.dict("sys.modules", {"astchunk": mock_astchunk}):
chunks = create_ast_chunks([doc])
# CRITICAL: These assertions will FAIL with current list[str] return type
assert len(chunks) == 2, "Should return 2 chunks"
for i, chunk in enumerate(chunks):
# Structure assertions - WILL FAIL: current code returns strings
assert isinstance(chunk, dict), f"Chunk {i} should be dict, got {type(chunk)}"
assert "text" in chunk, f"Chunk {i} must have 'text' key"
assert "metadata" in chunk, f"Chunk {i} must have 'metadata' key"
assert isinstance(chunk["metadata"], dict), f"Chunk {i} metadata should be dict"
# Document metadata preservation - WILL FAIL
metadata = chunk["metadata"]
assert "file_path" in metadata, f"Chunk {i} should preserve file_path"
assert metadata["file_path"] == "/project/src/utils.py", (
f"Chunk {i} file_path incorrect"
)
assert "file_name" in metadata, f"Chunk {i} should preserve file_name"
assert metadata["file_name"] == "utils.py", f"Chunk {i} file_name incorrect"
assert "creation_date" in metadata, f"Chunk {i} should preserve creation_date"
assert metadata["creation_date"] == "2024-01-15T10:30:00", (
f"Chunk {i} creation_date incorrect"
)
assert "last_modified_date" in metadata, f"Chunk {i} should preserve last_modified_date"
assert metadata["last_modified_date"] == "2024-10-31T15:45:00", (
f"Chunk {i} last_modified_date incorrect"
)
# Verify metadata is consistent across chunks from same document
assert chunks[0]["metadata"]["file_path"] == chunks[1]["metadata"]["file_path"], (
"All chunks from same document should have same file_path"
)
# Verify text content is present and not stringified
assert "def calculate_sum" in chunks[0]["text"]
assert "class DataProcessor" in chunks[1]["text"]
def test_ast_chunks_include_astchunk_metadata(self):
"""Test that astchunk-specific metadata is merged into chunk metadata.
This test verifies that astchunk's metadata (line_count, start_line_no,
end_line_no, node_count) is merged with document metadata.
This will FAIL because current code returns list[str], not list[dict].
"""
python_code = '''
def function_one():
"""First function."""
x = 1
y = 2
return x + y
def function_two():
"""Second function."""
return 42
'''
doc = MockDocument(
python_code,
file_path="/test/code.py",
metadata={
"language": "python",
"file_path": "/test/code.py",
"file_name": "code.py",
},
)
# Mock astchunk with detailed metadata
mock_builder = Mock()
astchunk_chunks = [
{
"content": "def function_one():\n x = 1\n y = 2\n return x + y",
"metadata": {
"filepath": "/test/code.py",
"line_count": 4,
"start_line_no": 1,
"end_line_no": 4,
"node_count": 5, # function, assignments, return
},
},
{
"content": "def function_two():\n return 42",
"metadata": {
"filepath": "/test/code.py",
"line_count": 2,
"start_line_no": 7,
"end_line_no": 8,
"node_count": 2, # function, return
},
},
]
mock_builder.chunkify.return_value = astchunk_chunks
mock_astchunk = Mock()
mock_astchunk.ASTChunkBuilder = Mock(return_value=mock_builder)
with patch.dict("sys.modules", {"astchunk": mock_astchunk}):
chunks = create_ast_chunks([doc])
# CRITICAL: These will FAIL with current list[str] return
assert len(chunks) == 2
# First chunk - function_one
chunk1 = chunks[0]
assert isinstance(chunk1, dict), "Chunk should be dict"
assert "metadata" in chunk1
metadata1 = chunk1["metadata"]
# Check astchunk metadata is present
assert "line_count" in metadata1, "Should include astchunk line_count"
assert metadata1["line_count"] == 4, "line_count should be 4"
assert "start_line_no" in metadata1, "Should include astchunk start_line_no"
assert metadata1["start_line_no"] == 1, "start_line_no should be 1"
assert "end_line_no" in metadata1, "Should include astchunk end_line_no"
assert metadata1["end_line_no"] == 4, "end_line_no should be 4"
assert "node_count" in metadata1, "Should include astchunk node_count"
assert metadata1["node_count"] == 5, "node_count should be 5"
# Second chunk - function_two
chunk2 = chunks[1]
metadata2 = chunk2["metadata"]
assert metadata2["line_count"] == 2, "line_count should be 2"
assert metadata2["start_line_no"] == 7, "start_line_no should be 7"
assert metadata2["end_line_no"] == 8, "end_line_no should be 8"
assert metadata2["node_count"] == 2, "node_count should be 2"
# Verify document metadata is ALSO present (merged, not replaced)
assert metadata1["file_path"] == "/test/code.py"
assert metadata1["file_name"] == "code.py"
assert metadata2["file_path"] == "/test/code.py"
assert metadata2["file_name"] == "code.py"
# Verify text content is correct
assert "def function_one" in chunk1["text"]
assert "def function_two" in chunk2["text"]
def test_traditional_chunks_as_dicts_helper(self):
"""Test the helper function that wraps traditional chunks as dicts.
This test verifies that when create_traditional_chunks is called,
its plain string chunks are wrapped into dict format with metadata.
This will FAIL because the helper function _traditional_chunks_as_dicts()
doesn't exist yet, and create_traditional_chunks returns list[str].
"""
# Create documents with various metadata
docs = [
MockDocument(
"This is the first paragraph of text. It contains multiple sentences. "
"This should be split into chunks based on size.",
file_path="/docs/readme.txt",
metadata={
"file_path": "/docs/readme.txt",
"file_name": "readme.txt",
"creation_date": "2024-01-01",
},
),
MockDocument(
"Second document with different metadata. It also has content that needs chunking.",
file_path="/docs/guide.md",
metadata={
"file_path": "/docs/guide.md",
"file_name": "guide.md",
"last_modified_date": "2024-10-31",
},
),
]
# Call create_traditional_chunks (which should now return list[dict])
chunks = create_traditional_chunks(docs, chunk_size=50, chunk_overlap=10)
# CRITICAL: Will FAIL - current code returns list[str]
assert len(chunks) > 0, "Should return chunks"
for i, chunk in enumerate(chunks):
# Structure assertions - WILL FAIL
assert isinstance(chunk, dict), f"Chunk {i} should be dict, got {type(chunk)}"
assert "text" in chunk, f"Chunk {i} must have 'text' key"
assert "metadata" in chunk, f"Chunk {i} must have 'metadata' key"
# Text should be non-empty
assert len(chunk["text"].strip()) > 0, f"Chunk {i} text should be non-empty"
# Metadata should include document info
metadata = chunk["metadata"]
assert "file_path" in metadata, f"Chunk {i} should have file_path in metadata"
assert "file_name" in metadata, f"Chunk {i} should have file_name in metadata"
# Verify metadata tracking works correctly
# At least one chunk should be from readme.txt
readme_chunks = [c for c in chunks if "readme.txt" in c["metadata"]["file_name"]]
assert len(readme_chunks) > 0, "Should have chunks from readme.txt"
# At least one chunk should be from guide.md
guide_chunks = [c for c in chunks if "guide.md" in c["metadata"]["file_name"]]
assert len(guide_chunks) > 0, "Should have chunks from guide.md"
# Verify creation_date is preserved for readme chunks
for chunk in readme_chunks:
assert chunk["metadata"].get("creation_date") == "2024-01-01", (
"readme.txt chunks should preserve creation_date"
)
# Verify last_modified_date is preserved for guide chunks
for chunk in guide_chunks:
assert chunk["metadata"].get("last_modified_date") == "2024-10-31", (
"guide.md chunks should preserve last_modified_date"
)
# Verify text content is present
all_text = " ".join([c["text"] for c in chunks])
assert "first paragraph" in all_text
assert "Second document" in all_text
class TestErrorHandling:
"""Test error handling and edge cases."""

View File

@@ -0,0 +1,268 @@
"""Unit tests for token-aware truncation functionality.
This test suite defines the contract for token truncation functions that prevent
500 errors from Ollama when text exceeds model token limits. These tests verify:
1. Model token limit retrieval (known and unknown models)
2. Text truncation behavior for single and multiple texts
3. Token counting and truncation accuracy using tiktoken
All tests are written in Red Phase - they should FAIL initially because the
implementation does not exist yet.
"""
import pytest
import tiktoken
from leann.embedding_compute import (
EMBEDDING_MODEL_LIMITS,
get_model_token_limit,
truncate_to_token_limit,
)
class TestModelTokenLimits:
"""Tests for retrieving model-specific token limits."""
def test_get_model_token_limit_known_model(self):
"""Verify correct token limit is returned for known models.
Known models should return their specific token limits from
EMBEDDING_MODEL_LIMITS dictionary.
"""
# Test nomic-embed-text (2048 tokens)
limit = get_model_token_limit("nomic-embed-text")
assert limit == 2048, "nomic-embed-text should have 2048 token limit"
# Test nomic-embed-text-v1.5 (2048 tokens)
limit = get_model_token_limit("nomic-embed-text-v1.5")
assert limit == 2048, "nomic-embed-text-v1.5 should have 2048 token limit"
# Test nomic-embed-text-v2 (512 tokens)
limit = get_model_token_limit("nomic-embed-text-v2")
assert limit == 512, "nomic-embed-text-v2 should have 512 token limit"
# Test OpenAI models (8192 tokens)
limit = get_model_token_limit("text-embedding-3-small")
assert limit == 8192, "text-embedding-3-small should have 8192 token limit"
def test_get_model_token_limit_unknown_model(self):
"""Verify default token limit is returned for unknown models.
Unknown models should return the default limit (2048) to allow
operation with reasonable safety margin.
"""
# Test with completely unknown model
limit = get_model_token_limit("unknown-model-xyz")
assert limit == 2048, "Unknown models should return default 2048"
# Test with empty string
limit = get_model_token_limit("")
assert limit == 2048, "Empty model name should return default 2048"
def test_get_model_token_limit_custom_default(self):
"""Verify custom default can be specified for unknown models.
Allow callers to specify their own default token limit when
model is not in the known models dictionary.
"""
limit = get_model_token_limit("unknown-model", default=4096)
assert limit == 4096, "Should return custom default for unknown models"
# Known model should ignore custom default
limit = get_model_token_limit("nomic-embed-text", default=4096)
assert limit == 2048, "Known model should ignore custom default"
def test_embedding_model_limits_dictionary_exists(self):
"""Verify EMBEDDING_MODEL_LIMITS dictionary contains expected models.
The dictionary should be importable and contain at least the
known nomic models with correct token limits.
"""
assert isinstance(EMBEDDING_MODEL_LIMITS, dict), "Should be a dictionary"
assert "nomic-embed-text" in EMBEDDING_MODEL_LIMITS, "Should contain nomic-embed-text"
assert "nomic-embed-text-v1.5" in EMBEDDING_MODEL_LIMITS, (
"Should contain nomic-embed-text-v1.5"
)
assert EMBEDDING_MODEL_LIMITS["nomic-embed-text"] == 2048
assert EMBEDDING_MODEL_LIMITS["nomic-embed-text-v1.5"] == 2048
assert EMBEDDING_MODEL_LIMITS["nomic-embed-text-v2"] == 512
# OpenAI models
assert EMBEDDING_MODEL_LIMITS["text-embedding-3-small"] == 8192
class TestTokenTruncation:
"""Tests for truncating texts to token limits."""
@pytest.fixture
def tokenizer(self):
"""Provide tiktoken tokenizer for token counting verification."""
return tiktoken.get_encoding("cl100k_base")
def test_truncate_single_text_under_limit(self, tokenizer):
"""Verify text under token limit remains unchanged.
When text is already within the token limit, it should be
returned unchanged with no truncation.
"""
text = "This is a short text that is well under the token limit."
token_count = len(tokenizer.encode(text))
assert token_count < 100, f"Test setup: text should be short (has {token_count} tokens)"
# Truncate with generous limit
result = truncate_to_token_limit([text], token_limit=512)
assert len(result) == 1, "Should return same number of texts"
assert result[0] == text, "Text under limit should be unchanged"
def test_truncate_single_text_over_limit(self, tokenizer):
"""Verify text over token limit is truncated correctly.
When text exceeds the token limit, it should be truncated to
fit within the limit while maintaining valid token boundaries.
"""
# Create a text that definitely exceeds limit
text = "word " * 200 # ~200 tokens (each "word " is typically 1-2 tokens)
original_token_count = len(tokenizer.encode(text))
assert original_token_count > 50, (
f"Test setup: text should be long (has {original_token_count} tokens)"
)
# Truncate to 50 tokens
result = truncate_to_token_limit([text], token_limit=50)
assert len(result) == 1, "Should return same number of texts"
assert result[0] != text, "Text over limit should be truncated"
assert len(result[0]) < len(text), "Truncated text should be shorter"
# Verify truncated text is within token limit
truncated_token_count = len(tokenizer.encode(result[0]))
assert truncated_token_count <= 50, (
f"Truncated text should be ≤50 tokens, got {truncated_token_count}"
)
def test_truncate_multiple_texts_mixed_lengths(self, tokenizer):
"""Verify multiple texts with mixed lengths are handled correctly.
When processing multiple texts:
- Texts under limit should remain unchanged
- Texts over limit should be truncated independently
- Output list should maintain same order and length
"""
texts = [
"Short text.", # Under limit
"word " * 200, # Over limit
"Another short one.", # Under limit
"token " * 150, # Over limit
]
# Verify test setup
for i, text in enumerate(texts):
token_count = len(tokenizer.encode(text))
if i in [1, 3]:
assert token_count > 50, f"Text {i} should be over limit (has {token_count} tokens)"
else:
assert token_count < 50, (
f"Text {i} should be under limit (has {token_count} tokens)"
)
# Truncate with 50 token limit
result = truncate_to_token_limit(texts, token_limit=50)
assert len(result) == len(texts), "Should return same number of texts"
# Verify each text individually
for i, (original, truncated) in enumerate(zip(texts, result)):
token_count = len(tokenizer.encode(truncated))
assert token_count <= 50, f"Text {i} should be ≤50 tokens, got {token_count}"
# Short texts should be unchanged
if i in [0, 2]:
assert truncated == original, f"Short text {i} should be unchanged"
# Long texts should be truncated
else:
assert len(truncated) < len(original), f"Long text {i} should be truncated"
def test_truncate_empty_list(self):
"""Verify empty input list returns empty output list.
Edge case: empty list should return empty list without errors.
"""
result = truncate_to_token_limit([], token_limit=512)
assert result == [], "Empty input should return empty output"
def test_truncate_preserves_order(self, tokenizer):
"""Verify truncation preserves original text order.
Output list should maintain the same order as input list,
regardless of which texts were truncated.
"""
texts = [
"First text " * 50, # Will be truncated
"Second text.", # Won't be truncated
"Third text " * 50, # Will be truncated
]
result = truncate_to_token_limit(texts, token_limit=20)
assert len(result) == 3, "Should preserve list length"
# Check that order is maintained by looking for distinctive words
assert "First" in result[0], "First text should remain in first position"
assert "Second" in result[1], "Second text should remain in second position"
assert "Third" in result[2], "Third text should remain in third position"
def test_truncate_extremely_long_text(self, tokenizer):
"""Verify extremely long texts are truncated efficiently.
Test with text that far exceeds token limit to ensure
truncation handles extreme cases without performance issues.
"""
# Create very long text (simulate real-world scenario)
text = "token " * 5000 # ~5000+ tokens
original_token_count = len(tokenizer.encode(text))
assert original_token_count > 1000, "Test setup: text should be very long"
# Truncate to small limit
result = truncate_to_token_limit([text], token_limit=100)
assert len(result) == 1
truncated_token_count = len(tokenizer.encode(result[0]))
assert truncated_token_count <= 100, (
f"Should truncate to ≤100 tokens, got {truncated_token_count}"
)
assert len(result[0]) < len(text) // 10, "Should significantly reduce text length"
def test_truncate_exact_token_limit(self, tokenizer):
"""Verify text at exactly token limit is handled correctly.
Edge case: text with exactly the token limit should either
remain unchanged or be safely truncated by 1 token.
"""
# Create text with approximately 50 tokens
# We'll adjust to get exactly 50
target_tokens = 50
text = "word " * 50
tokens = tokenizer.encode(text)
# Adjust to get exactly target_tokens
if len(tokens) > target_tokens:
tokens = tokens[:target_tokens]
text = tokenizer.decode(tokens)
elif len(tokens) < target_tokens:
# Add more words
while len(tokenizer.encode(text)) < target_tokens:
text += "word "
tokens = tokenizer.encode(text)[:target_tokens]
text = tokenizer.decode(tokens)
# Verify we have exactly target_tokens
assert len(tokenizer.encode(text)) == target_tokens, (
"Test setup: should have exactly 50 tokens"
)
result = truncate_to_token_limit([text], token_limit=target_tokens)
assert len(result) == 1
result_tokens = len(tokenizer.encode(result[0]))
assert result_tokens <= target_tokens, (
f"Should be ≤{target_tokens} tokens, got {result_tokens}"
)

7553
uv.lock generated
View File

File diff suppressed because it is too large Load Diff