diff --git a/benchmarks/update/README.md b/benchmarks/update/README.md new file mode 100644 index 0000000..585342b --- /dev/null +++ b/benchmarks/update/README.md @@ -0,0 +1,143 @@ +# Update Benchmarks + +This directory hosts two benchmark suites that exercise LEANN’s HNSW “update + +search” pipeline under different assumptions: + +1. **RNG recompute latency** – measure how random-neighbour pruning and cache + settings influence incremental `add()` latency when embeddings are fetched + over the ZMQ embedding server. +2. **Update strategy comparison** – compare a fully sequential update pipeline + against an offline approach that keeps the graph static and fuses results. + +Both suites build a non-compact, `is_recompute=True` index so that new +embeddings are pulled from the embedding server. Benchmark outputs are written +under `.leann/bench/` by default and appended to CSV files for later plotting. + +## Benchmarks + +### 1. HNSW RNG Recompute Benchmark + +`bench_hnsw_rng_recompute.py` evaluates incremental update latency under four +random-neighbour (RNG) configurations. Each scenario uses the same dataset but +changes the forward / reverse RNG pruning flags and whether the embedding cache +is enabled: + +| Scenario name | Forward RNG | Reverse RNG | ZMQ embedding cache | +| ---------------------------------- | ----------- | ----------- | ------------------- | +| `baseline` | Enabled | Enabled | Enabled | +| `no_cache_baseline` | Enabled | Enabled | **Disabled** | +| `disable_forward_rng` | **Disabled**| Enabled | Enabled | +| `disable_forward_and_reverse_rng` | **Disabled**| **Disabled**| Enabled | + +For each scenario the script: +1. (Re)builds a `is_recompute=True` index and writes it to `.leann/bench/`. +2. Starts `leann_backend_hnsw.hnsw_embedding_server` for remote embeddings. +3. Appends the requested updates using the scenario’s RNG flags. +4. Records total time, latency per passage, ZMQ fetch counts, and stage-level + timings before appending a row to the CSV output. + +**Run:** +```bash +LEANN_HNSW_LOG_PATH=.leann/bench/hnsw_server.log \ +LEANN_LOG_LEVEL=INFO \ +uv run -m benchmarks.update.bench_hnsw_rng_recompute \ + --runs 1 \ + --index-path .leann/bench/test.leann \ + --initial-files data/PrideandPrejudice.txt \ + --update-files data/huawei_pangu.md \ + --max-initial 300 \ + --max-updates 1 \ + --add-timeout 120 +``` + +**Output:** +- `benchmarks/update/bench_results.csv` – per-scenario timing statistics + (including ms/passage) for each run. +- `.leann/bench/hnsw_server.log` – detailed ZMQ/server logs (path controlled by + `LEANN_HNSW_LOG_PATH`). + _The reference CSVs checked into this branch were generated on a workstation with an NVIDIA RTX 4090 GPU; throughput numbers will differ on other hardware._ + +### 2. Sequential vs. Offline Update Benchmark + +`bench_update_vs_offline_search.py` compares two end-to-end strategies on the +same dataset: + +- **Scenario A – Sequential Update** + - Start an embedding server. + - Sequentially call `index.add()`; each call fetches embeddings via ZMQ and + mutates the HNSW graph. + - After all inserts, run a search on the updated graph. + - Metrics recorded: update time (`add_total_s`), post-update search time + (`search_time_s`), combined total (`total_time_s`), and per-passage + latency. + +- **Scenario B – Offline Embedding + Concurrent Search** + - Stop Scenario A’s server and start a fresh embedding server. 
+ - Spawn two threads: one generates embeddings for the new passages offline + (graph unchanged); the other computes the query embedding and searches the + existing graph. + - Merge offline similarities with the graph search results to emulate late + fusion, then report the merged top‑k preview. + - Metrics recorded: embedding time (`emb_time_s`), search time + (`search_time_s`), concurrent makespan (`makespan_s`), and scenario total. + +**Run (both scenarios):** +```bash +uv run -m benchmarks.update.bench_update_vs_offline_search \ + --index-path .leann/bench/offline_vs_update.leann \ + --max-initial 300 \ + --num-updates 1 +``` + +You can pass `--only A` or `--only B` to run a single scenario. The script will +print timing summaries to stdout and append the results to CSV. + +**Output:** +- `benchmarks/update/offline_vs_update.csv` – per-scenario timing statistics for + Scenario A and B. +- Console output includes Scenario B’s merged top‑k preview for quick sanity + checks. + _The sample results committed here come from runs on an RTX 4090-equipped machine; expect variations if you benchmark on different GPUs._ + +### 3. Visualisation + +`plot_bench_results.py` combines the RNG benchmark and the update strategy +benchmark into a single two-panel plot. + +**Run:** +```bash +uv run -m benchmarks.update.plot_bench_results \ + --csv benchmarks/update/bench_results.csv \ + --csv-right benchmarks/update/offline_vs_update.csv \ + --out benchmarks/update/bench_latency_from_csv.png +``` + +**Options:** +- `--broken-y` – Enable a broken Y-axis (default: true when appropriate). +- `--csv` – RNG benchmark results CSV (left panel). +- `--csv-right` – Update strategy results CSV (right panel). +- `--out` – Output image path (PNG/PDF supported). + +**Output:** +- `benchmarks/update/bench_latency_from_csv.png` – visual comparison of the two + suites. +- `benchmarks/update/bench_latency_from_csv.pdf` – PDF version, suitable for + slides/papers. + +## Parameters & Environment + +### Common CLI Flags +- `--max-initial` – Number of initial passages used to seed the index. +- `--max-updates` / `--num-updates` – Number of passages to treat as updates. +- `--index-path` – Base path (without extension) where the LEANN index is stored. +- `--runs` – Number of repetitions (RNG benchmark only). + +### Environment Variables +- `LEANN_HNSW_LOG_PATH` – File to receive embedding-server logs (optional). +- `LEANN_LOG_LEVEL` – Logging verbosity (DEBUG/INFO/WARNING/ERROR). +- `CUDA_VISIBLE_DEVICES` – Set to empty string if you want to force CPU + execution of the embedding model. + +With these scripts you can easily replicate LEANN’s update benchmarks, compare +multiple RNG strategies, and evaluate whether sequential updates or offline +fusion better match your latency/accuracy trade-offs. diff --git a/benchmarks/update/__init__.py b/benchmarks/update/__init__.py new file mode 100644 index 0000000..4970eba --- /dev/null +++ b/benchmarks/update/__init__.py @@ -0,0 +1,16 @@ +"""Benchmarks for LEANN update workflows.""" + +# Expose helper to locate repository root for other modules that need it. 
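+# Illustrative usage from a sibling benchmark module (the "data" subdirectory
+# below is only an example of a path resolved against the returned root):
+#     from benchmarks.update import find_repo_root
+#     data_dir = find_repo_root() / "data"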
+from pathlib import Path + + +def find_repo_root() -> Path: + """Return the project root containing pyproject.toml.""" + current = Path(__file__).resolve() + for parent in current.parents: + if (parent / "pyproject.toml").exists(): + return parent + return current.parents[1] + + +__all__ = ["find_repo_root"] diff --git a/benchmarks/update/bench_hnsw_rng_recompute.py b/benchmarks/update/bench_hnsw_rng_recompute.py new file mode 100644 index 0000000..81272ae --- /dev/null +++ b/benchmarks/update/bench_hnsw_rng_recompute.py @@ -0,0 +1,804 @@ +"""Benchmark incremental HNSW add() under different RNG pruning modes with real +embedding recomputation. + +This script clones the structure of ``examples/dynamic_update_no_recompute.py`` +so that we build a non-compact ``is_recompute=True`` index, spin up the +standard HNSW embedding server, and measure how long incremental ``add`` takes +when RNG pruning is fully enabled vs. partially/fully disabled. + +Example usage (run from the repo root; downloads the model on first run):: + + uv run -m benchmarks.update.bench_hnsw_rng_recompute \ + --index-path .leann/bench/leann-demo.leann \ + --runs 1 + +You can tweak the input documents with ``--initial-files`` / ``--update-files`` +if you want a larger or different workload, and change the embedding model via +``--model-name``. +""" + +import argparse +import json +import logging +import os +import pickle +import re +import sys +import time +from pathlib import Path +from typing import Any + +import msgpack +import numpy as np +import zmq +from leann.api import LeannBuilder + +if os.environ.get("LEANN_FORCE_CPU", "").lower() in ("1", "true", "yes"): + os.environ.setdefault("CUDA_VISIBLE_DEVICES", "") + +from leann.embedding_compute import compute_embeddings +from leann.embedding_server_manager import EmbeddingServerManager +from leann.registry import register_project_directory +from leann_backend_hnsw import faiss # type: ignore +from leann_backend_hnsw.convert_to_csr import prune_hnsw_embeddings_inplace + +logger = logging.getLogger(__name__) +if not logging.getLogger().handlers: + logging.basicConfig(level=logging.INFO) + + +def _find_repo_root() -> Path: + """Locate project root by walking up until pyproject.toml is found.""" + current = Path(__file__).resolve() + for parent in current.parents: + if (parent / "pyproject.toml").exists(): + return parent + # Fallback: assume repo is two levels up (../..) 
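+    # e.g. <repo>/benchmarks/update/bench_hnsw_rng_recompute.py -> parents[2] is <repo>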
+ return current.parents[2] + + +REPO_ROOT = _find_repo_root() +if str(REPO_ROOT) not in sys.path: + sys.path.insert(0, str(REPO_ROOT)) + +from apps.chunking import create_text_chunks # noqa: E402 + +DEFAULT_INITIAL_FILES = [ + REPO_ROOT / "data" / "2501.14312v1 (1).pdf", + REPO_ROOT / "data" / "huawei_pangu.md", +] +DEFAULT_UPDATE_FILES = [REPO_ROOT / "data" / "2506.08276v1.pdf"] + +DEFAULT_HNSW_LOG = Path(".leann/bench/hnsw_server.log") + + +def load_chunks_from_files(paths: list[Path], limit: int | None = None) -> list[str]: + from llama_index.core import SimpleDirectoryReader + + documents = [] + for path in paths: + p = path.expanduser().resolve() + if not p.exists(): + raise FileNotFoundError(f"Input path not found: {p}") + if p.is_dir(): + reader = SimpleDirectoryReader(str(p), recursive=False) + documents.extend(reader.load_data(show_progress=True)) + else: + reader = SimpleDirectoryReader(input_files=[str(p)]) + documents.extend(reader.load_data(show_progress=True)) + + if not documents: + return [] + + chunks = create_text_chunks( + documents, + chunk_size=512, + chunk_overlap=128, + use_ast_chunking=False, + ) + cleaned = [c for c in chunks if isinstance(c, str) and c.strip()] + if limit is not None: + cleaned = cleaned[:limit] + return cleaned + + +def ensure_index_dir(index_path: Path) -> None: + index_path.parent.mkdir(parents=True, exist_ok=True) + + +def cleanup_index_files(index_path: Path) -> None: + parent = index_path.parent + if not parent.exists(): + return + stem = index_path.stem + for file in parent.glob(f"{stem}*"): + if file.is_file(): + file.unlink() + + +def build_initial_index( + index_path: Path, + paragraphs: list[str], + model_name: str, + embedding_mode: str, + distance_metric: str, + ef_construction: int, +) -> None: + builder = LeannBuilder( + backend_name="hnsw", + embedding_model=model_name, + embedding_mode=embedding_mode, + is_compact=False, + is_recompute=True, + distance_metric=distance_metric, + backend_kwargs={ + "distance_metric": distance_metric, + "is_compact": False, + "is_recompute": True, + "efConstruction": ef_construction, + }, + ) + for idx, passage in enumerate(paragraphs): + builder.add_text(passage, metadata={"id": str(idx)}) + builder.build_index(str(index_path)) + + +def prepare_new_chunks(paragraphs: list[str]) -> list[dict[str, Any]]: + return [{"text": text, "metadata": {}} for text in paragraphs] + + +def benchmark_update_with_mode( + index_path: Path, + new_chunks: list[dict[str, Any]], + model_name: str, + embedding_mode: str, + distance_metric: str, + disable_forward_rng: bool, + disable_reverse_rng: bool, + server_port: int, + add_timeout: int, + ef_construction: int, +) -> tuple[float, float]: + meta_path = index_path.parent / f"{index_path.name}.meta.json" + passages_file = index_path.parent / f"{index_path.name}.passages.jsonl" + offset_file = index_path.parent / f"{index_path.name}.passages.idx" + index_file = index_path.parent / f"{index_path.stem}.index" + + with open(meta_path, encoding="utf-8") as f: + meta = json.load(f) + + with open(offset_file, "rb") as f: + offset_map: dict[str, int] = pickle.load(f) + existing_ids = set(offset_map.keys()) + + valid_chunks: list[dict[str, Any]] = [] + for chunk in new_chunks: + text = chunk.get("text", "") + if not isinstance(text, str) or not text.strip(): + continue + metadata = chunk.setdefault("metadata", {}) + passage_id = chunk.get("id") or metadata.get("id") + if passage_id and passage_id in existing_ids: + raise ValueError(f"Passage ID '{passage_id}' already exists in 
the index.") + valid_chunks.append(chunk) + + if not valid_chunks: + raise ValueError("No valid chunks to append.") + + texts_to_embed = [chunk["text"] for chunk in valid_chunks] + embeddings = compute_embeddings( + texts_to_embed, + model_name, + mode=embedding_mode, + is_build=False, + batch_size=16, + ) + + embeddings = np.ascontiguousarray(embeddings, dtype=np.float32) + if distance_metric == "cosine": + norms = np.linalg.norm(embeddings, axis=1, keepdims=True) + norms[norms == 0] = 1 + embeddings = embeddings / norms + + index = faiss.read_index(str(index_file)) + index.is_recompute = True + if getattr(index, "storage", None) is None: + if index.metric_type == faiss.METRIC_INNER_PRODUCT: + storage_index = faiss.IndexFlatIP(index.d) + else: + storage_index = faiss.IndexFlatL2(index.d) + index.storage = storage_index + index.own_fields = True + try: + storage_index.ntotal = index.ntotal + except AttributeError: + pass + try: + index.hnsw.set_disable_rng_during_add(disable_forward_rng) + index.hnsw.set_disable_reverse_prune(disable_reverse_rng) + if ef_construction is not None: + index.hnsw.efConstruction = ef_construction + except AttributeError: + pass + + applied_forward = getattr(index.hnsw, "disable_rng_during_add", None) + applied_reverse = getattr(index.hnsw, "disable_reverse_prune", None) + logger.info( + "HNSW RNG config -> requested forward=%s, reverse=%s | applied forward=%s, reverse=%s", + disable_forward_rng, + disable_reverse_rng, + applied_forward, + applied_reverse, + ) + + base_id = index.ntotal + for offset, chunk in enumerate(valid_chunks): + new_id = str(base_id + offset) + chunk.setdefault("metadata", {})["id"] = new_id + chunk["id"] = new_id + + rollback_size = passages_file.stat().st_size if passages_file.exists() else 0 + offset_map_backup = offset_map.copy() + + try: + with open(passages_file, "a", encoding="utf-8") as f: + for chunk in valid_chunks: + offset = f.tell() + json.dump( + { + "id": chunk["id"], + "text": chunk["text"], + "metadata": chunk.get("metadata", {}), + }, + f, + ensure_ascii=False, + ) + f.write("\n") + offset_map[chunk["id"]] = offset + + with open(offset_file, "wb") as f: + pickle.dump(offset_map, f) + + server_manager = EmbeddingServerManager( + backend_module_name="leann_backend_hnsw.hnsw_embedding_server" + ) + server_started, actual_port = server_manager.start_server( + port=server_port, + model_name=model_name, + embedding_mode=embedding_mode, + passages_file=str(meta_path), + distance_metric=distance_metric, + ) + if not server_started: + raise RuntimeError("Failed to start embedding server.") + + if hasattr(index.hnsw, "set_zmq_port"): + index.hnsw.set_zmq_port(actual_port) + elif hasattr(index, "set_zmq_port"): + index.set_zmq_port(actual_port) + + _warmup_embedding_server(actual_port) + + total_start = time.time() + add_elapsed = 0.0 + + try: + import signal + + def _timeout_handler(signum, frame): + raise TimeoutError("incremental add timed out") + + if add_timeout > 0: + signal.signal(signal.SIGALRM, _timeout_handler) + signal.alarm(add_timeout) + + add_start = time.time() + for i in range(embeddings.shape[0]): + index.add(1, faiss.swig_ptr(embeddings[i : i + 1])) + add_elapsed = time.time() - add_start + if add_timeout > 0: + signal.alarm(0) + faiss.write_index(index, str(index_file)) + finally: + server_manager.stop_server() + + except TimeoutError: + raise + except Exception: + if passages_file.exists(): + with open(passages_file, "rb+") as f: + f.truncate(rollback_size) + with open(offset_file, "wb") as f: + 
pickle.dump(offset_map_backup, f) + raise + + prune_hnsw_embeddings_inplace(str(index_file)) + + meta["total_passages"] = len(offset_map) + with open(meta_path, "w", encoding="utf-8") as f: + json.dump(meta, f, indent=2) + + # Reset toggles so the index on disk returns to baseline behaviour. + try: + index.hnsw.set_disable_rng_during_add(False) + index.hnsw.set_disable_reverse_prune(False) + except AttributeError: + pass + faiss.write_index(index, str(index_file)) + + total_elapsed = time.time() - total_start + + return total_elapsed, add_elapsed + + +def _total_zmq_nodes(log_path: Path) -> int: + if not log_path.exists(): + return 0 + with log_path.open("r", encoding="utf-8") as log_file: + text = log_file.read() + return sum(int(match) for match in re.findall(r"ZMQ received (\d+) node IDs", text)) + + +def _warmup_embedding_server(port: int) -> None: + """Send a dummy REQ so the embedding server loads its model.""" + ctx = zmq.Context() + try: + sock = ctx.socket(zmq.REQ) + sock.setsockopt(zmq.LINGER, 0) + sock.setsockopt(zmq.RCVTIMEO, 5000) + sock.setsockopt(zmq.SNDTIMEO, 5000) + sock.connect(f"tcp://127.0.0.1:{port}") + payload = msgpack.packb(["__WARMUP__"], use_bin_type=True) + sock.send(payload) + try: + sock.recv() + except zmq.error.Again: + pass + finally: + sock.close() + ctx.term() + + +def main() -> None: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--index-path", + type=Path, + default=Path(".leann/bench/leann-demo.leann"), + help="Output index base path (without extension).", + ) + parser.add_argument( + "--initial-files", + nargs="*", + type=Path, + default=DEFAULT_INITIAL_FILES, + help="Files used to build the initial index.", + ) + parser.add_argument( + "--update-files", + nargs="*", + type=Path, + default=DEFAULT_UPDATE_FILES, + help="Files appended during the benchmark.", + ) + parser.add_argument( + "--runs", type=int, default=1, help="How many times to repeat each scenario." + ) + parser.add_argument( + "--model-name", + default="sentence-transformers/all-MiniLM-L6-v2", + help="Embedding model used for build/update.", + ) + parser.add_argument( + "--embedding-mode", + default="sentence-transformers", + help="Embedding mode passed to LeannBuilder/embedding server.", + ) + parser.add_argument( + "--distance-metric", + default="mips", + choices=["mips", "l2", "cosine"], + help="Distance metric for HNSW backend.", + ) + parser.add_argument( + "--ef-construction", + type=int, + default=200, + help="efConstruction setting for initial build.", + ) + parser.add_argument( + "--server-port", + type=int, + default=5557, + help="Port for the real embedding server.", + ) + parser.add_argument( + "--max-initial", + type=int, + default=300, + help="Optional cap on initial passages (after chunking).", + ) + parser.add_argument( + "--max-updates", + type=int, + default=1, + help="Optional cap on update passages (after chunking).", + ) + parser.add_argument( + "--add-timeout", + type=int, + default=900, + help="Timeout in seconds for the incremental add loop (0 = no timeout).", + ) + parser.add_argument( + "--plot-path", + type=Path, + default=Path("bench_latency.png"), + help="Where to save the latency bar plot.", + ) + parser.add_argument( + "--cap-y", + type=float, + default=None, + help="Cap Y-axis (ms). Bars above are hatched and annotated.", + ) + parser.add_argument( + "--broken-y", + action="store_true", + help="Use broken Y-axis (two stacked axes with gap). 
Overrides --cap-y unless both provided.", + ) + parser.add_argument( + "--lower-cap-y", + type=float, + default=None, + help="Lower axes upper bound for broken Y (ms). Default=1.1x second-highest.", + ) + parser.add_argument( + "--upper-start-y", + type=float, + default=None, + help="Upper axes lower bound for broken Y (ms). Default=1.2x second-highest.", + ) + parser.add_argument( + "--csv-path", + type=Path, + default=Path("benchmarks/update/bench_results.csv"), + help="Where to append per-scenario results as CSV.", + ) + + args = parser.parse_args() + + register_project_directory(REPO_ROOT) + + initial_paragraphs = load_chunks_from_files(args.initial_files, args.max_initial) + update_paragraphs = load_chunks_from_files(args.update_files, args.max_updates) + if not update_paragraphs: + raise ValueError("No update passages found; please provide --update-files with content.") + + update_chunks = prepare_new_chunks(update_paragraphs) + ensure_index_dir(args.index_path) + + scenarios = [ + ("baseline", False, False, True), + ("no_cache_baseline", False, False, False), + ("disable_forward_rng", True, False, True), + ("disable_forward_and_reverse_rng", True, True, True), + ] + + log_path = Path(os.environ.get("LEANN_HNSW_LOG_PATH", DEFAULT_HNSW_LOG)) + log_path.parent.mkdir(parents=True, exist_ok=True) + os.environ["LEANN_HNSW_LOG_PATH"] = str(log_path.resolve()) + os.environ.setdefault("LEANN_LOG_LEVEL", "INFO") + + results_total: dict[str, list[float]] = {name: [] for name, *_ in scenarios} + results_add: dict[str, list[float]] = {name: [] for name, *_ in scenarios} + results_zmq: dict[str, list[int]] = {name: [] for name, *_ in scenarios} + results_stageA: dict[str, list[float]] = {name: [] for name, *_ in scenarios} + results_stageBC: dict[str, list[float]] = {name: [] for name, *_ in scenarios} + results_ms_per_passage: dict[str, list[float]] = {name: [] for name, *_ in scenarios} + + # CSV setup + import csv + + run_id = time.strftime("%Y%m%d-%H%M%S") + csv_fields = [ + "run_id", + "scenario", + "cache_enabled", + "ef_construction", + "max_initial", + "max_updates", + "total_time_s", + "add_only_s", + "latency_ms_per_passage", + "zmq_nodes", + "stageA_time_s", + "stageBC_time_s", + "model_name", + "embedding_mode", + "distance_metric", + ] + # Create CSV with header if missing + if args.csv_path: + args.csv_path.parent.mkdir(parents=True, exist_ok=True) + if not args.csv_path.exists() or args.csv_path.stat().st_size == 0: + with args.csv_path.open("w", newline="", encoding="utf-8") as f: + writer = csv.DictWriter(f, fieldnames=csv_fields) + writer.writeheader() + + for run in range(args.runs): + print(f"\n=== Benchmark run {run + 1}/{args.runs} ===") + for name, disable_forward, disable_reverse, cache_enabled in scenarios: + print(f"\nScenario: {name}") + cleanup_index_files(args.index_path) + if log_path.exists(): + try: + log_path.unlink() + except OSError: + pass + os.environ["LEANN_ZMQ_EMBED_CACHE"] = "1" if cache_enabled else "0" + build_initial_index( + args.index_path, + initial_paragraphs, + args.model_name, + args.embedding_mode, + args.distance_metric, + args.ef_construction, + ) + + prev_size = log_path.stat().st_size if log_path.exists() else 0 + + try: + total_elapsed, add_elapsed = benchmark_update_with_mode( + args.index_path, + update_chunks, + args.model_name, + args.embedding_mode, + args.distance_metric, + disable_forward, + disable_reverse, + args.server_port, + args.add_timeout, + args.ef_construction, + ) + except TimeoutError as exc: + print(f"Scenario {name} timed 
out: {exc}") + continue + + curr_size = log_path.stat().st_size if log_path.exists() else 0 + if curr_size < prev_size: + prev_size = 0 + zmq_count = 0 + if log_path.exists(): + with log_path.open("r", encoding="utf-8") as log_file: + log_file.seek(prev_size) + new_entries = log_file.read() + zmq_count = sum( + int(match) for match in re.findall(r"ZMQ received (\d+) node IDs", new_entries) + ) + stageA = sum( + float(x) + for x in re.findall(r"Distance calculation E2E time: ([0-9.]+)s", new_entries) + ) + stageBC = sum( + float(x) for x in re.findall(r"ZMQ E2E time: ([0-9.]+)s", new_entries) + ) + else: + stageA = 0.0 + stageBC = 0.0 + + per_chunk = add_elapsed / len(update_chunks) + print( + f"Total time: {total_elapsed:.3f} s | add-only: {add_elapsed:.3f} s " + f"for {len(update_chunks)} passages => {per_chunk * 1e3:.3f} ms/passage" + ) + print(f"ZMQ node fetch total: {zmq_count}") + results_total[name].append(total_elapsed) + results_add[name].append(add_elapsed) + results_zmq[name].append(zmq_count) + results_ms_per_passage[name].append(per_chunk * 1e3) + results_stageA[name].append(stageA) + results_stageBC[name].append(stageBC) + + # Append row to CSV + if args.csv_path: + row = { + "run_id": run_id, + "scenario": name, + "cache_enabled": 1 if cache_enabled else 0, + "ef_construction": args.ef_construction, + "max_initial": args.max_initial, + "max_updates": args.max_updates, + "total_time_s": round(total_elapsed, 6), + "add_only_s": round(add_elapsed, 6), + "latency_ms_per_passage": round(per_chunk * 1e3, 6), + "zmq_nodes": int(zmq_count), + "stageA_time_s": round(stageA, 6), + "stageBC_time_s": round(stageBC, 6), + "model_name": args.model_name, + "embedding_mode": args.embedding_mode, + "distance_metric": args.distance_metric, + } + with args.csv_path.open("a", newline="", encoding="utf-8") as f: + writer = csv.DictWriter(f, fieldnames=csv_fields) + writer.writerow(row) + + print("\n=== Summary ===") + for name in results_add: + add_values = results_add[name] + total_values = results_total[name] + zmq_values = results_zmq[name] + latency_values = results_ms_per_passage[name] + if not add_values: + print(f"{name}: no successful runs") + continue + avg_add = sum(add_values) / len(add_values) + avg_total = sum(total_values) / len(total_values) + avg_zmq = sum(zmq_values) / len(zmq_values) if zmq_values else 0.0 + avg_latency = sum(latency_values) / len(latency_values) if latency_values else 0.0 + runs = len(add_values) + print( + f"{name}: add-only avg {avg_add:.3f} s | total avg {avg_total:.3f} s " + f"| ZMQ avg {avg_zmq:.1f} node fetches | latency {avg_latency:.2f} ms/passage over {runs} run(s)" + ) + + if args.plot_path: + try: + import matplotlib.pyplot as plt + + labels = [name for name, *_ in scenarios] + values = [ + sum(results_ms_per_passage[name]) / len(results_ms_per_passage[name]) + if results_ms_per_passage[name] + else 0.0 + for name in labels + ] + + def _auto_cap(vals: list[float]) -> float | None: + s = sorted(vals, reverse=True) + if len(s) < 2: + return None + if s[1] > 0 and s[0] >= 2.5 * s[1]: + return s[1] * 1.1 + return None + + def _fmt_ms(v: float) -> str: + return f"{v / 1000:.1f}k" if v >= 1000 else f"{v:.1f}" + + colors = ["#4e79a7", "#f28e2c", "#e15759", "#76b7b2"] + + if args.broken_y: + s = sorted(values, reverse=True) + second = s[1] if len(s) >= 2 else (s[0] if s else 0.0) + lower_cap = args.lower_cap_y if args.lower_cap_y is not None else second * 1.1 + upper_start = ( + args.upper_start_y + if args.upper_start_y is not None + else max(second * 1.2, 
lower_cap * 1.02) + ) + ymax = max(values) * 1.10 if values else 1.0 + fig, (ax_top, ax_bottom) = plt.subplots( + 2, + 1, + sharex=True, + figsize=(7.4, 5.0), + gridspec_kw={"height_ratios": [1, 3], "hspace": 0.05}, + ) + x = list(range(len(labels))) + ax_bottom.bar(x, values, color=colors[: len(labels)], width=0.8) + ax_top.bar(x, values, color=colors[: len(labels)], width=0.8) + ax_bottom.set_ylim(0, lower_cap) + ax_top.set_ylim(upper_start, ymax) + for i, v in enumerate(values): + if v <= lower_cap: + ax_bottom.text( + i, + v + lower_cap * 0.02, + _fmt_ms(v), + ha="center", + va="bottom", + fontsize=9, + ) + else: + ax_top.text(i, v, _fmt_ms(v), ha="center", va="bottom", fontsize=9) + ax_top.spines["bottom"].set_visible(False) + ax_bottom.spines["top"].set_visible(False) + ax_top.tick_params(labeltop=False) + ax_bottom.xaxis.tick_bottom() + d = 0.015 + kwargs = {"transform": ax_top.transAxes, "color": "k", "clip_on": False} + ax_top.plot((-d, +d), (-d, +d), **kwargs) + ax_top.plot((1 - d, 1 + d), (-d, +d), **kwargs) + kwargs.update({"transform": ax_bottom.transAxes}) + ax_bottom.plot((-d, +d), (1 - d, 1 + d), **kwargs) + ax_bottom.plot((1 - d, 1 + d), (1 - d, 1 + d), **kwargs) + ax_bottom.set_xticks(range(len(labels))) + ax_bottom.set_xticklabels(labels) + ax = ax_bottom + else: + cap = args.cap_y or _auto_cap(values) + plt.figure(figsize=(7.2, 4.2)) + ax = plt.gca() + if cap is not None: + show_vals = [min(v, cap) for v in values] + bars = [] + for i, (v, show) in enumerate(zip(values, show_vals)): + b = ax.bar(i, show, color=colors[i], width=0.8) + bars.append(b[0]) + if v > cap: + bars[-1].set_hatch("//") + ax.text(i, cap * 1.02, _fmt_ms(v), ha="center", va="bottom", fontsize=9) + else: + ax.text( + i, + show + max(1.0, 0.01 * (cap or show)), + _fmt_ms(v), + ha="center", + va="bottom", + fontsize=9, + ) + ax.set_ylim(0, cap * 1.10) + ax.plot( + [0.02 - 0.02, 0.02 + 0.02], + [0.98 + 0.02, 0.98 - 0.02], + transform=ax.transAxes, + color="k", + lw=1, + ) + ax.plot( + [0.98 - 0.02, 0.98 + 0.02], + [0.98 + 0.02, 0.98 - 0.02], + transform=ax.transAxes, + color="k", + lw=1, + ) + if any(v > cap for v in values): + ax.legend( + [bars[0]], ["capped"], fontsize=8, frameon=False, loc="upper right" + ) + ax.set_xticks(range(len(labels))) + ax.set_xticklabels(labels) + else: + ax.bar(labels, values, color=colors[: len(labels)]) + for idx, val in enumerate(values): + ax.text(idx, val + 1.0, f"{val:.1f}", ha="center", va="bottom") + + plt.ylabel("Average add latency (ms per passage)") + plt.title(f"Initial passages {args.max_initial}, updates {args.max_updates}") + plt.tight_layout() + plt.savefig(args.plot_path) + print(f"Saved latency bar plot to {args.plot_path}") + # ZMQ time split (Stage A vs B/C) + try: + plt.figure(figsize=(6, 4)) + a_vals = [sum(results_stageA[n]) / max(1, len(results_stageA[n])) for n in labels] + bc_vals = [ + sum(results_stageBC[n]) / max(1, len(results_stageBC[n])) for n in labels + ] + ind = range(len(labels)) + plt.bar(ind, a_vals, color="#4e79a7", label="Stage A distance (s)") + plt.bar( + ind, bc_vals, bottom=a_vals, color="#e15759", label="Stage B/C embed-by-id (s)" + ) + plt.xticks(list(ind), labels, rotation=10) + plt.ylabel("Server ZMQ time (s)") + plt.title( + f"ZMQ time split (initial {args.max_initial}, updates {args.max_updates})" + ) + plt.legend() + out2 = args.plot_path.with_name( + args.plot_path.stem + "_zmq_split" + args.plot_path.suffix + ) + plt.tight_layout() + plt.savefig(out2) + print(f"Saved ZMQ time split plot to {out2}") + except Exception as 
e: + print("Failed to plot ZMQ split:", e) + except ImportError: + print("matplotlib not available; skipping plot generation") + + # leave the last build on disk for inspection + + +if __name__ == "__main__": + main() diff --git a/benchmarks/update/bench_results.csv b/benchmarks/update/bench_results.csv new file mode 100644 index 0000000..767a8ef --- /dev/null +++ b/benchmarks/update/bench_results.csv @@ -0,0 +1,5 @@ +run_id,scenario,cache_enabled,ef_construction,max_initial,max_updates,total_time_s,add_only_s,latency_ms_per_passage,zmq_nodes,stageA_time_s,stageBC_time_s,model_name,embedding_mode,distance_metric +20251024-133101,baseline,1,200,300,1,3.391856,1.120359,1120.359421,126,0.507821,0.601608,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips +20251024-133101,no_cache_baseline,0,200,300,1,34.941514,32.91376,32913.760185,4033,0.506933,32.159928,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips +20251024-133101,disable_forward_rng,1,200,300,1,2.746756,0.8202,820.200443,66,0.474354,0.338454,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips +20251024-133101,disable_forward_and_reverse_rng,1,200,300,1,2.396566,0.521478,521.478415,1,0.508973,0.006938,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips diff --git a/benchmarks/update/bench_update_vs_offline_search.py b/benchmarks/update/bench_update_vs_offline_search.py new file mode 100644 index 0000000..250bd19 --- /dev/null +++ b/benchmarks/update/bench_update_vs_offline_search.py @@ -0,0 +1,704 @@ +""" +Compare two latency models for small incremental updates vs. search: + +Scenario A (sequential update then search): + - Build initial HNSW (is_recompute=True) + - Start embedding server (ZMQ) for recompute + - Add N passages one-by-one (each triggers recompute over ZMQ) + - Then run a search query on the updated index + - Report total time = sum(add_i) + search_time, with breakdowns + +Scenario B (offline embeds + concurrent search; no graph updates): + - Do NOT insert the N passages into the graph + - In parallel: (1) compute embeddings for the N passages; (2) compute query + embedding and run a search on the existing index + - After both finish, compute similarity between the query embedding and the N + new passage embeddings, merge with the index search results by score, and + report time = max(embed_time, search_time) (i.e., no blocking on updates) + +This script reuses the model/data loading conventions of +examples/bench_hnsw_rng_recompute.py but focuses on end-to-end latency +comparison for the two execution strategies above. 
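+
+A rough sketch of the Scenario B fusion step (names here are illustrative, not
+the script's variables; for mips/cosine the offline score is a dot product, for
+l2 the negated squared distance, so that higher scores always rank first):
+
+    graph_hits = [(f"idx:{i}", s) for i, s in zip(graph_ids, graph_scores)]
+    offline_hits = [(f"offline:{j}", float(q @ e)) for j, e in enumerate(new_embs)]
+    fused_top_k = sorted(graph_hits + offline_hits, key=lambda t: t[1], reverse=True)[:k]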
+ +Example (from the repository root): + uv run -m benchmarks.update.bench_update_vs_offline_search \ + --index-path .leann/bench/offline_vs_update.leann \ + --max-initial 300 --num-updates 5 --k 10 +""" + +import argparse +import csv +import json +import logging +import os +import pickle +import sys +import threading +import time +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +import numpy as np +import psutil # type: ignore +from leann.api import LeannBuilder + +if os.environ.get("LEANN_FORCE_CPU", "").lower() in ("1", "true", "yes"): + os.environ.setdefault("CUDA_VISIBLE_DEVICES", "") + +from leann.embedding_compute import compute_embeddings +from leann.embedding_server_manager import EmbeddingServerManager +from leann.registry import register_project_directory +from leann_backend_hnsw import faiss # type: ignore + +logger = logging.getLogger(__name__) +if not logging.getLogger().handlers: + logging.basicConfig(level=logging.INFO) + + +def _find_repo_root() -> Path: + """Locate project root by walking up until pyproject.toml is found.""" + current = Path(__file__).resolve() + for parent in current.parents: + if (parent / "pyproject.toml").exists(): + return parent + # Fallback: assume repo is two levels up (../..) + return current.parents[2] + + +REPO_ROOT = _find_repo_root() +if str(REPO_ROOT) not in sys.path: + sys.path.insert(0, str(REPO_ROOT)) + +from apps.chunking import create_text_chunks # noqa: E402 + +DEFAULT_INITIAL_FILES = [ + REPO_ROOT / "data" / "2501.14312v1 (1).pdf", + REPO_ROOT / "data" / "huawei_pangu.md", +] +DEFAULT_UPDATE_FILES = [REPO_ROOT / "data" / "2506.08276v1.pdf"] + + +def load_chunks_from_files(paths: list[Path], limit: int | None = None) -> list[str]: + from llama_index.core import SimpleDirectoryReader + + documents = [] + for path in paths: + p = path.expanduser().resolve() + if not p.exists(): + raise FileNotFoundError(f"Input path not found: {p}") + if p.is_dir(): + reader = SimpleDirectoryReader(str(p), recursive=False) + documents.extend(reader.load_data(show_progress=True)) + else: + reader = SimpleDirectoryReader(input_files=[str(p)]) + documents.extend(reader.load_data(show_progress=True)) + + if not documents: + return [] + + chunks = create_text_chunks( + documents, + chunk_size=512, + chunk_overlap=128, + use_ast_chunking=False, + ) + cleaned = [c for c in chunks if isinstance(c, str) and c.strip()] + if limit is not None: + cleaned = cleaned[:limit] + return cleaned + + +def ensure_index_dir(index_path: Path) -> None: + index_path.parent.mkdir(parents=True, exist_ok=True) + + +def cleanup_index_files(index_path: Path) -> None: + parent = index_path.parent + if not parent.exists(): + return + stem = index_path.stem + for file in parent.glob(f"{stem}*"): + if file.is_file(): + file.unlink() + + +def build_initial_index( + index_path: Path, + paragraphs: list[str], + model_name: str, + embedding_mode: str, + distance_metric: str, + ef_construction: int, +) -> None: + builder = LeannBuilder( + backend_name="hnsw", + embedding_model=model_name, + embedding_mode=embedding_mode, + is_compact=False, + is_recompute=True, + distance_metric=distance_metric, + backend_kwargs={ + "distance_metric": distance_metric, + "is_compact": False, + "is_recompute": True, + "efConstruction": ef_construction, + }, + ) + for idx, passage in enumerate(paragraphs): + builder.add_text(passage, metadata={"id": str(idx)}) + builder.build_index(str(index_path)) + + +def _maybe_norm_cosine(vecs: np.ndarray, metric: str) -> np.ndarray: + 
if metric == "cosine": + vecs = np.ascontiguousarray(vecs, dtype=np.float32) + norms = np.linalg.norm(vecs, axis=1, keepdims=True) + norms[norms == 0] = 1 + vecs = vecs / norms + return vecs + + +def _read_index_for_search(index_path: Path) -> Any: + index_file = index_path.parent / f"{index_path.stem}.index" + # Force-disable experimental disk cache when loading the index so that + # incremental benchmarks don't pick up stale top-degree bitmaps. + cfg = faiss.HNSWIndexConfig() + cfg.is_recompute = True + if hasattr(cfg, "disk_cache_ratio"): + cfg.disk_cache_ratio = 0.0 + if hasattr(cfg, "external_storage_path"): + cfg.external_storage_path = None + io_flags = getattr(faiss, "IO_FLAG_MMAP", 0) + index = faiss.read_index(str(index_file), io_flags, cfg) + # ensure recompute mode persists after reload + try: + index.is_recompute = True + except AttributeError: + pass + try: + actual_ntotal = index.hnsw.levels.size() + except AttributeError: + actual_ntotal = index.ntotal + if actual_ntotal != index.ntotal: + print( + f"[bench_update_vs_offline_search] Correcting ntotal from {index.ntotal} to {actual_ntotal}", + flush=True, + ) + index.ntotal = actual_ntotal + if getattr(index, "storage", None) is None: + if index.metric_type == faiss.METRIC_INNER_PRODUCT: + storage_index = faiss.IndexFlatIP(index.d) + else: + storage_index = faiss.IndexFlatL2(index.d) + index.storage = storage_index + index.own_fields = True + return index + + +def _append_passages_for_updates( + meta_path: Path, + start_id: int, + texts: list[str], +) -> list[str]: + """Append update passages so the embedding server can serve recompute fetches.""" + + if not texts: + return [] + + index_dir = meta_path.parent + meta_name = meta_path.name + if not meta_name.endswith(".meta.json"): + raise ValueError(f"Unexpected meta filename: {meta_path}") + index_base = meta_name[: -len(".meta.json")] + + passages_file = index_dir / f"{index_base}.passages.jsonl" + offsets_file = index_dir / f"{index_base}.passages.idx" + + if not passages_file.exists() or not offsets_file.exists(): + raise FileNotFoundError( + "Passage store missing; cannot register update passages for recompute mode." 
+ ) + + with open(offsets_file, "rb") as f: + offset_map: dict[str, int] = pickle.load(f) + + assigned_ids: list[str] = [] + with open(passages_file, "a", encoding="utf-8") as f: + for i, text in enumerate(texts): + passage_id = str(start_id + i) + offset = f.tell() + json.dump({"id": passage_id, "text": text, "metadata": {}}, f, ensure_ascii=False) + f.write("\n") + offset_map[passage_id] = offset + assigned_ids.append(passage_id) + + with open(offsets_file, "wb") as f: + pickle.dump(offset_map, f) + + try: + with open(meta_path, encoding="utf-8") as f: + meta = json.load(f) + except json.JSONDecodeError: + meta = {} + meta["total_passages"] = len(offset_map) + with open(meta_path, "w", encoding="utf-8") as f: + json.dump(meta, f, indent=2) + + return assigned_ids + + +def _search(index: Any, q: np.ndarray, k: int) -> tuple[np.ndarray, np.ndarray]: + q = np.ascontiguousarray(q, dtype=np.float32) + distances = np.zeros((1, k), dtype=np.float32) + indices = np.zeros((1, k), dtype=np.int64) + index.search( + 1, + faiss.swig_ptr(q), + k, + faiss.swig_ptr(distances), + faiss.swig_ptr(indices), + ) + return distances[0], indices[0] + + +def _score_for_metric(dist: float, metric: str) -> float: + # Convert FAISS distance to a "higher is better" score + if metric in ("mips", "cosine"): + return float(dist) + # l2 distance (smaller better) -> negative distance as score + return -float(dist) + + +def _merge_results( + index_results: tuple[np.ndarray, np.ndarray], + offline_scores: list[tuple[int, float]], + k: int, + metric: str, +) -> list[tuple[str, float]]: + distances, indices = index_results + merged: list[tuple[str, float]] = [] + for distance, idx in zip(distances.tolist(), indices.tolist()): + merged.append((f"idx:{idx}", _score_for_metric(distance, metric))) + for j, s in offline_scores: + merged.append((f"offline:{j}", s)) + merged.sort(key=lambda x: x[1], reverse=True) + return merged[:k] + + +@dataclass +class ScenarioResult: + name: str + update_total_s: float + search_s: float + overall_s: float + + +def main() -> None: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--index-path", + type=Path, + default=Path(".leann/bench/offline-vs-update.leann"), + ) + parser.add_argument( + "--initial-files", + nargs="*", + type=Path, + default=DEFAULT_INITIAL_FILES, + ) + parser.add_argument( + "--update-files", + nargs="*", + type=Path, + default=DEFAULT_UPDATE_FILES, + ) + parser.add_argument("--max-initial", type=int, default=300) + parser.add_argument("--num-updates", type=int, default=5) + parser.add_argument("--k", type=int, default=10, help="Top-k for search/merge") + parser.add_argument( + "--query", + type=str, + default="neural network", + help="Query text used for the search benchmark.", + ) + parser.add_argument("--server-port", type=int, default=5557) + parser.add_argument("--add-timeout", type=int, default=600) + parser.add_argument("--model-name", default="sentence-transformers/all-MiniLM-L6-v2") + parser.add_argument("--embedding-mode", default="sentence-transformers") + parser.add_argument( + "--distance-metric", + default="mips", + choices=["mips", "l2", "cosine"], + ) + parser.add_argument("--ef-construction", type=int, default=200) + parser.add_argument( + "--only", + choices=["A", "B", "both"], + default="both", + help="Run only Scenario A, Scenario B, or both", + ) + parser.add_argument( + "--csv-path", + type=Path, + default=Path("benchmarks/update/offline_vs_update.csv"), + help="Where to append results (CSV).", + ) + + args = 
parser.parse_args() + + register_project_directory(REPO_ROOT) + + # Load data + initial_paragraphs = load_chunks_from_files(args.initial_files, args.max_initial) + update_paragraphs = load_chunks_from_files(args.update_files, None) + if not update_paragraphs: + raise ValueError("No update passages loaded from --update-files") + update_paragraphs = update_paragraphs[: args.num_updates] + if len(update_paragraphs) < args.num_updates: + raise ValueError( + f"Not enough update passages ({len(update_paragraphs)}) for --num-updates={args.num_updates}" + ) + + ensure_index_dir(args.index_path) + cleanup_index_files(args.index_path) + + # Build initial index + build_initial_index( + args.index_path, + initial_paragraphs, + args.model_name, + args.embedding_mode, + args.distance_metric, + args.ef_construction, + ) + + # Prepare index object and meta + meta_path = args.index_path.parent / f"{args.index_path.name}.meta.json" + index = _read_index_for_search(args.index_path) + + # CSV setup + run_id = time.strftime("%Y%m%d-%H%M%S") + if args.csv_path: + args.csv_path.parent.mkdir(parents=True, exist_ok=True) + csv_fields = [ + "run_id", + "scenario", + "max_initial", + "num_updates", + "k", + "total_time_s", + "add_total_s", + "search_time_s", + "emb_time_s", + "makespan_s", + "model_name", + "embedding_mode", + "distance_metric", + ] + if not args.csv_path.exists() or args.csv_path.stat().st_size == 0: + with args.csv_path.open("w", newline="", encoding="utf-8") as f: + writer = csv.DictWriter(f, fieldnames=csv_fields) + writer.writeheader() + + # Debug: list existing HNSW server PIDs before starting + try: + existing = [ + p + for p in psutil.process_iter(attrs=["pid", "cmdline"]) + if any( + isinstance(arg, str) and "leann_backend_hnsw.hnsw_embedding_server" in arg + for arg in (p.info.get("cmdline") or []) + ) + ] + if existing: + print("[debug] Found existing hnsw_embedding_server processes before run:") + for p in existing: + print(f"[debug] PID={p.info['pid']} cmd={' '.join(p.info.get('cmdline') or [])}") + except Exception as _e: + pass + + add_total = 0.0 + search_after_add = 0.0 + total_seq = 0.0 + port_a = None + if args.only in ("A", "both"): + # Scenario A: sequential update then search + start_id = index.ntotal + assigned_ids = _append_passages_for_updates(meta_path, start_id, update_paragraphs) + if assigned_ids: + logger.debug( + "Registered %d update passages starting at id %s", + len(assigned_ids), + assigned_ids[0], + ) + server_manager = EmbeddingServerManager( + backend_module_name="leann_backend_hnsw.hnsw_embedding_server" + ) + ok, port = server_manager.start_server( + port=args.server_port, + model_name=args.model_name, + embedding_mode=args.embedding_mode, + passages_file=str(meta_path), + distance_metric=args.distance_metric, + ) + if not ok: + raise RuntimeError("Failed to start embedding server") + try: + # Set ZMQ port for recompute mode + if hasattr(index.hnsw, "set_zmq_port"): + index.hnsw.set_zmq_port(port) + elif hasattr(index, "set_zmq_port"): + index.set_zmq_port(port) + + # Start A overall timer BEFORE computing update embeddings + t0 = time.time() + + # Compute embeddings for updates (counted into A's overall) + t_emb0 = time.time() + upd_embs = compute_embeddings( + update_paragraphs, + args.model_name, + mode=args.embedding_mode, + is_build=False, + batch_size=16, + ) + emb_time_updates = time.time() - t_emb0 + upd_embs = np.asarray(upd_embs, dtype=np.float32) + upd_embs = _maybe_norm_cosine(upd_embs, args.distance_metric) + + # Perform sequential adds + for i in 
range(upd_embs.shape[0]): + t_add0 = time.time() + index.add(1, faiss.swig_ptr(upd_embs[i : i + 1])) + add_total += time.time() - t_add0 + # Don't persist index after adds to avoid contaminating Scenario B + # index_file = args.index_path.parent / f"{args.index_path.stem}.index" + # faiss.write_index(index, str(index_file)) + + # Search after updates + q_emb = compute_embeddings( + [args.query], args.model_name, mode=args.embedding_mode, is_build=False + ) + q_emb = np.asarray(q_emb, dtype=np.float32) + q_emb = _maybe_norm_cosine(q_emb, args.distance_metric) + + # Warm up search with a dummy query first + print("[DEBUG] Warming up search...") + _ = _search(index, q_emb, 1) + + t_s0 = time.time() + D_upd, I_upd = _search(index, q_emb, args.k) + search_after_add = time.time() - t_s0 + total_seq = time.time() - t0 + finally: + server_manager.stop_server() + port_a = port + + print("\n=== Scenario A: update->search (sequential) ===") + # emb_time_updates is defined only when A runs + try: + _emb_a = emb_time_updates + except NameError: + _emb_a = 0.0 + print( + f"Adds: {args.num_updates} passages; embeds={_emb_a:.3f}s; add_total={add_total:.3f}s; " + f"search={search_after_add:.3f}s; overall={total_seq:.3f}s" + ) + # CSV row for A + if args.csv_path: + row_a = { + "run_id": run_id, + "scenario": "A", + "max_initial": args.max_initial, + "num_updates": args.num_updates, + "k": args.k, + "total_time_s": round(total_seq, 6), + "add_total_s": round(add_total, 6), + "search_time_s": round(search_after_add, 6), + "emb_time_s": round(_emb_a, 6), + "makespan_s": 0.0, + "model_name": args.model_name, + "embedding_mode": args.embedding_mode, + "distance_metric": args.distance_metric, + } + with args.csv_path.open("a", newline="", encoding="utf-8") as f: + writer = csv.DictWriter(f, fieldnames=csv_fields) + writer.writerow(row_a) + + # Verify server cleanup + try: + # short sleep to allow signal handling to finish + time.sleep(0.5) + leftovers = [ + p + for p in psutil.process_iter(attrs=["pid", "cmdline"]) + if any( + isinstance(arg, str) and "leann_backend_hnsw.hnsw_embedding_server" in arg + for arg in (p.info.get("cmdline") or []) + ) + ] + if leftovers: + print("[warn] hnsw_embedding_server process(es) still alive after A-stop:") + for p in leftovers: + print( + f"[warn] PID={p.info['pid']} cmd={' '.join(p.info.get('cmdline') or [])}" + ) + else: + print("[debug] server cleanup confirmed: no hnsw_embedding_server found") + except Exception: + pass + + # Scenario B: offline embeds + concurrent search (no graph updates) + if args.only in ("B", "both"): + # ensure a server is available for recompute search + server_manager_b = EmbeddingServerManager( + backend_module_name="leann_backend_hnsw.hnsw_embedding_server" + ) + requested_port = args.server_port if port_a is None else port_a + ok_b, port_b = server_manager_b.start_server( + port=requested_port, + model_name=args.model_name, + embedding_mode=args.embedding_mode, + passages_file=str(meta_path), + distance_metric=args.distance_metric, + ) + if not ok_b: + raise RuntimeError("Failed to start embedding server for Scenario B") + + # Wait for server to fully initialize + print("[DEBUG] Waiting 2s for embedding server to fully initialize...") + time.sleep(2) + + try: + # Read the index first + index_no_update = _read_index_for_search(args.index_path) # unchanged index + + # Then configure ZMQ port on the correct index object + if hasattr(index_no_update.hnsw, "set_zmq_port"): + index_no_update.hnsw.set_zmq_port(port_b) + elif hasattr(index_no_update, 
"set_zmq_port"): + index_no_update.set_zmq_port(port_b) + + # Warmup the embedding model before benchmarking (do this for both --only B and --only both) + # This ensures fair comparison as Scenario A has warmed up the model during update embeddings + logger.info("Warming up embedding model for Scenario B...") + _ = compute_embeddings( + ["warmup text"], args.model_name, mode=args.embedding_mode, is_build=False + ) + + # Prepare worker A: compute embeddings for the same N passages + emb_time = 0.0 + updates_embs_offline: np.ndarray | None = None + + def _worker_emb(): + nonlocal emb_time, updates_embs_offline + t = time.time() + updates_embs_offline = compute_embeddings( + update_paragraphs, + args.model_name, + mode=args.embedding_mode, + is_build=False, + batch_size=16, + ) + emb_time = time.time() - t + + # Pre-compute query embedding and warm up search outside of timed section. + q_vec = compute_embeddings( + [args.query], args.model_name, mode=args.embedding_mode, is_build=False + ) + q_vec = np.asarray(q_vec, dtype=np.float32) + q_vec = _maybe_norm_cosine(q_vec, args.distance_metric) + print("[DEBUG B] Warming up search...") + _ = _search(index_no_update, q_vec, 1) + + # Worker B: timed search on the warmed index + search_time = 0.0 + offline_elapsed = 0.0 + index_results: tuple[np.ndarray, np.ndarray] | None = None + + def _worker_search(): + nonlocal search_time, index_results + t = time.time() + distances, indices = _search(index_no_update, q_vec, args.k) + search_time = time.time() - t + index_results = (distances, indices) + + # Run two workers concurrently + t0 = time.time() + th1 = threading.Thread(target=_worker_emb) + th2 = threading.Thread(target=_worker_search) + th1.start() + th2.start() + th1.join() + th2.join() + offline_elapsed = time.time() - t0 + + # For mixing: compute query vs. 
offline update similarities (pure client-side) + offline_scores: list[tuple[int, float]] = [] + if updates_embs_offline is not None: + upd2 = np.asarray(updates_embs_offline, dtype=np.float32) + upd2 = _maybe_norm_cosine(upd2, args.distance_metric) + # For mips/cosine, score = dot; for l2, score = -||x-y||^2 + for j in range(upd2.shape[0]): + if args.distance_metric in ("mips", "cosine"): + s = float(np.dot(q_vec[0], upd2[j])) + else: + diff = q_vec[0] - upd2[j] + s = -float(np.dot(diff, diff)) + offline_scores.append((j, s)) + + merged_topk = ( + _merge_results(index_results, offline_scores, args.k, args.distance_metric) + if index_results + else [] + ) + + print("\n=== Scenario B: offline embeds + concurrent search (no add) ===") + print( + f"embeddings({args.num_updates})={emb_time:.3f}s; search={search_time:.3f}s; makespan≈{offline_elapsed:.3f}s (≈max)" + ) + if merged_topk: + preview = ", ".join([f"{lab}:{score:.3f}" for lab, score in merged_topk[:5]]) + print(f"Merged top-5 preview: {preview}") + # CSV row for B + if args.csv_path: + row_b = { + "run_id": run_id, + "scenario": "B", + "max_initial": args.max_initial, + "num_updates": args.num_updates, + "k": args.k, + "total_time_s": 0.0, + "add_total_s": 0.0, + "search_time_s": round(search_time, 6), + "emb_time_s": round(emb_time, 6), + "makespan_s": round(offline_elapsed, 6), + "model_name": args.model_name, + "embedding_mode": args.embedding_mode, + "distance_metric": args.distance_metric, + } + with args.csv_path.open("a", newline="", encoding="utf-8") as f: + writer = csv.DictWriter(f, fieldnames=csv_fields) + writer.writerow(row_b) + + finally: + server_manager_b.stop_server() + + # Summary + print("\n=== Summary ===") + msg_a = ( + f"A: seq-add+search overall={total_seq:.3f}s (adds={add_total:.3f}s, search={search_after_add:.3f}s)" + if args.only in ("A", "both") + else "A: skipped" + ) + msg_b = ( + f"B: offline+concurrent overall≈{offline_elapsed:.3f}s (emb={emb_time:.3f}s, search={search_time:.3f}s)" + if args.only in ("B", "both") + else "B: skipped" + ) + print(msg_a + "\n" + msg_b) + + +if __name__ == "__main__": + main() diff --git a/benchmarks/update/offline_vs_update.csv b/benchmarks/update/offline_vs_update.csv new file mode 100644 index 0000000..aa979bb --- /dev/null +++ b/benchmarks/update/offline_vs_update.csv @@ -0,0 +1,5 @@ +run_id,scenario,max_initial,num_updates,k,total_time_s,add_total_s,search_time_s,emb_time_s,makespan_s,model_name,embedding_mode,distance_metric +20251024-141607,A,300,1,10,3.273957,3.050168,0.097825,0.017339,0.0,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips +20251024-141607,B,300,1,10,0.0,0.0,0.111892,0.007869,0.112635,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips +20251025-160652,A,300,5,10,5.061945,4.805962,0.123271,0.015008,0.0,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips +20251025-160652,B,300,5,10,0.0,0.0,0.101809,0.008817,0.102447,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips diff --git a/benchmarks/update/plot_bench_results.py b/benchmarks/update/plot_bench_results.py new file mode 100644 index 0000000..3c9e56f --- /dev/null +++ b/benchmarks/update/plot_bench_results.py @@ -0,0 +1,645 @@ +#!/usr/bin/env python3 +""" +Plot latency bars from the benchmark CSV produced by +benchmarks/update/bench_hnsw_rng_recompute.py. 
+ +If you also provide an offline_vs_update.csv via --csv-right +(from benchmarks/update/bench_update_vs_offline_search.py), this script will +output a side-by-side figure: +- Left: ms/passage bars (four RNG scenarios). +- Right: seconds bars (Scenario A seq add+search vs Scenario B offline+search). + +Usage: + uv run python benchmarks/update/plot_bench_results.py \ + --csv benchmarks/update/bench_results.csv \ + --out benchmarks/update/bench_latency_from_csv.png + +The script selects the latest run_id in the CSV and plots four bars for +the default scenarios: + - baseline + - no_cache_baseline + - disable_forward_rng + - disable_forward_and_reverse_rng + +If multiple rows exist per scenario for that run_id, the script averages +their latency_ms_per_passage values. +""" + +import argparse +import csv +from collections import defaultdict +from pathlib import Path + +DEFAULT_SCENARIOS = [ + "no_cache_baseline", + "baseline", + "disable_forward_rng", + "disable_forward_and_reverse_rng", +] + +SCENARIO_LABELS = { + "baseline": "+ Cache", + "no_cache_baseline": "Naive \n Recompute", + "disable_forward_rng": "+ w/o \n Fwd RNG", + "disable_forward_and_reverse_rng": "+ w/o \n Bwd RNG", +} + +# Paper-style colors and hatches for scenarios +SCENARIO_STYLES = { + "no_cache_baseline": {"edgecolor": "dimgrey", "hatch": "/////"}, + "baseline": {"edgecolor": "#63B8B6", "hatch": "xxxxx"}, + "disable_forward_rng": {"edgecolor": "green", "hatch": "....."}, + "disable_forward_and_reverse_rng": {"edgecolor": "tomato", "hatch": "\\\\\\\\\\"}, +} + + +def load_latest_run(csv_path: Path): + rows = [] + with csv_path.open("r", encoding="utf-8") as f: + reader = csv.DictReader(f) + for row in reader: + rows.append(row) + if not rows: + raise SystemExit("CSV is empty: no rows to plot") + # Choose latest run_id lexicographically (YYYYMMDD-HHMMSS) + run_ids = [r.get("run_id", "") for r in rows] + latest = max(run_ids) + latest_rows = [r for r in rows if r.get("run_id", "") == latest] + if not latest_rows: + # Fallback: take last 4 rows + latest_rows = rows[-4:] + latest = latest_rows[-1].get("run_id", "unknown") + return latest, latest_rows + + +def aggregate_latency(rows): + acc = defaultdict(list) + for r in rows: + sc = r.get("scenario", "") + try: + val = float(r.get("latency_ms_per_passage", "nan")) + except ValueError: + continue + acc[sc].append(val) + avg = {k: (sum(v) / len(v) if v else 0.0) for k, v in acc.items()} + return avg + + +def _auto_cap(values: list[float]) -> float | None: + if not values: + return None + sorted_vals = sorted(values, reverse=True) + if len(sorted_vals) < 2: + return None + max_v, second = sorted_vals[0], sorted_vals[1] + if second <= 0: + return None + # If the tallest bar dwarfs the second by 2.5x+, cap near the second + if max_v >= 2.5 * second: + return second * 1.1 + return None + + +def _add_break_marker(ax, y, rel_x0=0.02, rel_x1=0.98, size=0.02): + # Draw small diagonal ticks near left/right to signal cap + x0, x1 = rel_x0, rel_x1 + ax.plot([x0 - size, x0 + size], [y + size, y - size], transform=ax.transAxes, color="k", lw=1) + ax.plot([x1 - size, x1 + size], [y + size, y - size], transform=ax.transAxes, color="k", lw=1) + + +def _fmt_ms(v: float) -> str: + if v >= 1000: + return f"{v / 1000:.1f}k" + return f"{v:.1f}" + + +def main(): + # Set LaTeX style for paper figures (matching paper_fig.py) + import matplotlib.pyplot as plt + + plt.rcParams["font.family"] = "Helvetica" + plt.rcParams["ytick.direction"] = "in" + plt.rcParams["hatch.linewidth"] = 1.5 + 
plt.rcParams["font.weight"] = "bold" + plt.rcParams["axes.labelweight"] = "bold" + plt.rcParams["text.usetex"] = True + + ap = argparse.ArgumentParser(description=__doc__) + ap.add_argument( + "--csv", + type=Path, + default=Path("benchmarks/update/bench_results.csv"), + help="Path to results CSV (defaults to bench_results.csv)", + ) + ap.add_argument( + "--out", + type=Path, + default=Path("add_ablation.pdf"), + help="Output image path", + ) + ap.add_argument( + "--csv-right", + type=Path, + default=Path("benchmarks/update/offline_vs_update.csv"), + help="Optional: offline_vs_update.csv to render right subplot (A vs B)", + ) + ap.add_argument( + "--cap-y", + type=float, + default=None, + help="Cap Y-axis at this ms value; bars above are hatched and annotated.", + ) + ap.add_argument( + "--no-auto-cap", + action="store_true", + help="Disable auto-cap heuristic when --cap-y is not provided.", + ) + ap.add_argument( + "--broken-y", + action="store_true", + default=True, + help="Use a broken Y-axis (two stacked axes with a gap). Overrides --cap-y unless both provided.", + ) + ap.add_argument( + "--lower-cap-y", + type=float, + default=None, + help="Lower axes upper bound for broken Y (ms). Default = 1.1x second-highest.", + ) + ap.add_argument( + "--upper-start-y", + type=float, + default=None, + help="Upper axes lower bound for broken Y (ms). Default = 1.2x second-highest.", + ) + args = ap.parse_args() + + latest_run, latest_rows = load_latest_run(args.csv) + avg = aggregate_latency(latest_rows) + + try: + import matplotlib.pyplot as plt + except Exception as e: + raise SystemExit(f"matplotlib not available: {e}") + + scenarios = DEFAULT_SCENARIOS + values = [avg.get(name, 0.0) for name in scenarios] + labels = [SCENARIO_LABELS.get(name, name) for name in scenarios] + colors = ["#4e79a7", "#f28e2c", "#e15759", "#76b7b2"] + + # If right CSV is provided, build side-by-side figure + if args.csv_right is not None: + try: + right_rows_all = [] + with args.csv_right.open("r", encoding="utf-8") as f: + rreader = csv.DictReader(f) + right_rows_all = list(rreader) + if right_rows_all: + r_latest = max(r.get("run_id", "") for r in right_rows_all) + right_rows = [r for r in right_rows_all if r.get("run_id", "") == r_latest] + else: + r_latest = None + right_rows = [] + except Exception: + r_latest = None + right_rows = [] + + a_total = 0.0 + b_makespan = 0.0 + for r in right_rows: + sc = (r.get("scenario", "") or "").strip().upper() + if sc == "A": + try: + a_total = float(r.get("total_time_s", 0.0)) + except Exception: + pass + elif sc == "B": + try: + b_makespan = float(r.get("makespan_s", 0.0)) + except Exception: + pass + + import matplotlib.pyplot as plt + from matplotlib import gridspec + + # Left subplot (reuse current style, with optional cap) + cap = args.cap_y + if cap is None and not args.no_auto_cap: + cap = _auto_cap(values) + x = list(range(len(labels))) + + if args.broken_y: + # Use broken axis for left subplot + # Auto-adjust width ratios: left has 4 bars, right has 2 bars + fig = plt.figure(figsize=(4.8, 1.8)) # Scaled down to 80% + gs = gridspec.GridSpec( + 2, 2, height_ratios=[1, 3], width_ratios=[1.5, 1], hspace=0.08, wspace=0.35 + ) + ax_left_top = fig.add_subplot(gs[0, 0]) + ax_left_bottom = fig.add_subplot(gs[1, 0], sharex=ax_left_top) + ax_right = fig.add_subplot(gs[:, 1]) + + # Determine break points + s = sorted(values, reverse=True) + second = s[1] if len(s) >= 2 else (s[0] if s else 0.0) + lower_cap = ( + args.lower_cap_y if args.lower_cap_y is not None else second * 1.4 + 
) # Increased to show more range + upper_start = ( + args.upper_start_y + if args.upper_start_y is not None + else max(second * 1.5, lower_cap * 1.02) + ) + ymax = ( + max(values) * 1.90 if values else 1.0 + ) # Increase headroom to 1.90 for text label and tick range + + # Draw bars on both axes + ax_left_bottom.bar(x, values, color=colors[: len(labels)], width=0.8) + ax_left_top.bar(x, values, color=colors[: len(labels)], width=0.8) + + # Set limits + ax_left_bottom.set_ylim(0, lower_cap) + ax_left_top.set_ylim(upper_start, ymax) + + # Annotate values (convert ms to s) + values_s = [v / 1000.0 for v in values] + lower_cap_s = lower_cap / 1000.0 + upper_start_s = upper_start / 1000.0 + ymax_s = ymax / 1000.0 + + ax_left_bottom.set_ylim(0, lower_cap_s) + ax_left_top.set_ylim(upper_start_s, ymax_s) + + # Redraw bars with s values (paper style: white fill + colored edge + hatch) + ax_left_bottom.clear() + ax_left_top.clear() + bar_width = 0.50 # Reduced for wider spacing between bars + for i, (scenario_name, v) in enumerate(zip(scenarios, values_s)): + style = SCENARIO_STYLES.get(scenario_name, {"edgecolor": "black", "hatch": ""}) + # Draw in bottom axis for all bars + ax_left_bottom.bar( + i, + v, + width=bar_width, + color="white", + edgecolor=style["edgecolor"], + hatch=style["hatch"], + linewidth=1.2, + ) + # Only draw in top axis if the bar is tall enough to reach the upper range + if v > upper_start_s: + ax_left_top.bar( + i, + v, + width=bar_width, + color="white", + edgecolor=style["edgecolor"], + hatch=style["hatch"], + linewidth=1.2, + ) + ax_left_bottom.set_ylim(0, lower_cap_s) + ax_left_top.set_ylim(upper_start_s, ymax_s) + + for i, v in enumerate(values_s): + if v <= lower_cap_s: + ax_left_bottom.text( + i, + v + lower_cap_s * 0.02, + f"{v:.2f}", + ha="center", + va="bottom", + fontsize=8, + fontweight="bold", + ) + else: + ax_left_top.text( + i, + v + (ymax_s - upper_start_s) * 0.02, + f"{v:.2f}", + ha="center", + va="bottom", + fontsize=8, + fontweight="bold", + ) + + # Hide spines between axes + ax_left_top.spines["bottom"].set_visible(False) + ax_left_bottom.spines["top"].set_visible(False) + ax_left_top.tick_params( + labeltop=False, labelbottom=False, bottom=False + ) # Hide tick marks + ax_left_bottom.xaxis.tick_bottom() + ax_left_bottom.tick_params(top=False) # Hide top tick marks + + # Draw break marks (matching paper_fig.py style) + d = 0.015 + kwargs = { + "transform": ax_left_top.transAxes, + "color": "k", + "clip_on": False, + "linewidth": 0.8, + "zorder": 10, + } + ax_left_top.plot((-d, +d), (-d, +d), **kwargs) + ax_left_top.plot((1 - d, 1 + d), (-d, +d), **kwargs) + kwargs.update({"transform": ax_left_bottom.transAxes}) + ax_left_bottom.plot((-d, +d), (1 - d, 1 + d), **kwargs) + ax_left_bottom.plot((1 - d, 1 + d), (1 - d, 1 + d), **kwargs) + + ax_left_bottom.set_xticks(x) + ax_left_bottom.set_xticklabels(labels, rotation=0, fontsize=7) + # Don't set ylabel here - will use fig.text for alignment + ax_left_bottom.tick_params(axis="y", labelsize=10) + ax_left_top.tick_params(axis="y", labelsize=10) + # Add subtle grid for better readability + ax_left_bottom.grid(axis="y", alpha=0.3, linestyle="--", linewidth=0.5) + ax_left_top.grid(axis="y", alpha=0.3, linestyle="--", linewidth=0.5) + ax_left_top.set_title("Single Add Operation", fontsize=11, pad=10, fontweight="bold") + + # Set x-axis limits to match bar width with right subplot + ax_left_bottom.set_xlim(-0.6, 3.6) + ax_left_top.set_xlim(-0.6, 3.6) + + ax_left = ax_left_bottom # for compatibility + else: + # Regular 
side-by-side layout + fig, (ax_left, ax_right) = plt.subplots(1, 2, figsize=(8.4, 3.15)) + + if cap is not None: + show_vals = [min(v, cap) for v in values] + bars = ax_left.bar(x, show_vals, color=colors[: len(labels)], width=0.8) + for i, (val, show) in enumerate(zip(values, show_vals)): + if val > cap: + bars[i].set_hatch("//") + ax_left.text( + i, cap * 1.02, _fmt_ms(val), ha="center", va="bottom", fontsize=9 + ) + else: + ax_left.text( + i, + show + max(1.0, 0.01 * (cap or show)), + _fmt_ms(val), + ha="center", + va="bottom", + fontsize=9, + ) + ax_left.set_ylim(0, cap * 1.10) + _add_break_marker(ax_left, y=0.98) + ax_left.set_xticks(x) + ax_left.set_xticklabels(labels, rotation=0, fontsize=10) + else: + ax_left.bar(x, values, color=colors[: len(labels)], width=0.8) + for i, v in enumerate(values): + ax_left.text(i, v + 1.0, _fmt_ms(v), ha="center", va="bottom", fontsize=9) + ax_left.set_xticks(x) + ax_left.set_xticklabels(labels, rotation=0, fontsize=10) + ax_left.set_ylabel("Latency (ms per passage)") + max_initial = latest_rows[0].get("max_initial", "?") + max_updates = latest_rows[0].get("max_updates", "?") + ax_left.set_title( + f"HNSW RNG (run {latest_run}) | init={max_initial}, upd={max_updates}" + ) + + # Right subplot (A vs B, seconds) - paper style + r_labels = ["Sequential", "Delayed \n Add+Search"] + r_values = [a_total or 0.0, b_makespan or 0.0] + r_styles = [ + {"edgecolor": "#59a14f", "hatch": "xxxxx"}, + {"edgecolor": "#edc948", "hatch": "/////"}, + ] + # 2 bars, centered with proper spacing + xr = [0, 1] + bar_width = 0.50 # Reduced for wider spacing between bars + for i, (v, style) in enumerate(zip(r_values, r_styles)): + ax_right.bar( + xr[i], + v, + width=bar_width, + color="white", + edgecolor=style["edgecolor"], + hatch=style["hatch"], + linewidth=1.2, + ) + for i, v in enumerate(r_values): + max_v = max(r_values) if r_values else 1.0 + offset = max(0.0002, 0.02 * max_v) + ax_right.text( + xr[i], + v + offset, + f"{v:.2f}", + ha="center", + va="bottom", + fontsize=8, + fontweight="bold", + ) + ax_right.set_xticks(xr) + ax_right.set_xticklabels(r_labels, rotation=0, fontsize=7) + # Don't set ylabel here - will use fig.text for alignment + ax_right.tick_params(axis="y", labelsize=10) + # Add subtle grid for better readability + ax_right.grid(axis="y", alpha=0.3, linestyle="--", linewidth=0.5) + ax_right.set_title("Batched Add Operation", fontsize=11, pad=10, fontweight="bold") + + # Set x-axis limits to match left subplot's bar width visually + # Accounting for width_ratios=[1.5, 1]: + # Left: 4 bars, xlim(-0.6, 3.6), range=4.2, physical_width=1.5*unit + # bar_width_visual = 0.72 * (1.5*unit / 4.2) + # Right: 2 bars, need same visual width + # 0.72 * (1.0*unit / range_right) = 0.72 * (1.5*unit / 4.2) + # range_right = 4.2 / 1.5 = 2.8 + # For bars at 0, 1: padding = (2.8 - 1) / 2 = 0.9 + ax_right.set_xlim(-0.9, 1.9) + + # Set y-axis limit with headroom for text labels + if r_values: + max_v = max(r_values) + ax_right.set_ylim(0, max_v * 1.15) + + # Format y-axis to avoid scientific notation + ax_right.ticklabel_format(style="plain", axis="y") + + plt.tight_layout() + + # Add aligned ylabels using fig.text (after tight_layout) + # Get the vertical center of the entire figure + fig_center_y = 0.5 + # Left ylabel - closer to left plot + left_x = 0.05 + fig.text( + left_x, + fig_center_y, + "Latency (s)", + va="center", + rotation="vertical", + fontsize=11, + fontweight="bold", + ) + # Right ylabel - closer to right plot + right_bbox = ax_right.get_position() + right_x 
= right_bbox.x0 - 0.07 + fig.text( + right_x, + fig_center_y, + "Latency (s)", + va="center", + rotation="vertical", + fontsize=11, + fontweight="bold", + ) + + plt.savefig(args.out, bbox_inches="tight", pad_inches=0.05) + # Also save PDF for paper + pdf_out = args.out.with_suffix(".pdf") + plt.savefig(pdf_out, bbox_inches="tight", pad_inches=0.05) + print(f"Saved: {args.out}") + print(f"Saved: {pdf_out}") + return + + # Broken-Y mode + if args.broken_y: + import matplotlib.pyplot as plt + + fig, (ax_top, ax_bottom) = plt.subplots( + 2, + 1, + sharex=True, + figsize=(7.5, 6.75), + gridspec_kw={"height_ratios": [1, 3], "hspace": 0.08}, + ) + + # Determine default breaks from second-highest + s = sorted(values, reverse=True) + second = s[1] if len(s) >= 2 else (s[0] if s else 0.0) + lower_cap = args.lower_cap_y if args.lower_cap_y is not None else second * 1.1 + upper_start = ( + args.upper_start_y + if args.upper_start_y is not None + else max(second * 1.2, lower_cap * 1.02) + ) + ymax = max(values) * 1.10 if values else 1.0 + + x = list(range(len(labels))) + ax_bottom.bar(x, values, color=colors[: len(labels)], width=0.8) + ax_top.bar(x, values, color=colors[: len(labels)], width=0.8) + + # Limits + ax_bottom.set_ylim(0, lower_cap) + ax_top.set_ylim(upper_start, ymax) + + # Annotate values + for i, v in enumerate(values): + if v <= lower_cap: + ax_bottom.text( + i, v + lower_cap * 0.02, _fmt_ms(v), ha="center", va="bottom", fontsize=9 + ) + else: + ax_top.text(i, v, _fmt_ms(v), ha="center", va="bottom", fontsize=9) + + # Hide spines between axes and draw diagonal break marks + ax_top.spines["bottom"].set_visible(False) + ax_bottom.spines["top"].set_visible(False) + ax_top.tick_params(labeltop=False) # don't put tick labels at the top + ax_bottom.xaxis.tick_bottom() + + # Diagonal lines at the break (matching paper_fig.py style) + d = 0.015 + kwargs = { + "transform": ax_top.transAxes, + "color": "k", + "clip_on": False, + "linewidth": 0.8, + "zorder": 10, + } + ax_top.plot((-d, +d), (-d, +d), **kwargs) # top-left diagonal + ax_top.plot((1 - d, 1 + d), (-d, +d), **kwargs) # top-right diagonal + kwargs.update({"transform": ax_bottom.transAxes}) + ax_bottom.plot((-d, +d), (1 - d, 1 + d), **kwargs) # bottom-left diagonal + ax_bottom.plot((1 - d, 1 + d), (1 - d, 1 + d), **kwargs) # bottom-right diagonal + + ax_bottom.set_xticks(x) + ax_bottom.set_xticklabels(labels, rotation=0, fontsize=10) + ax = ax_bottom # for labeling below + else: + cap = args.cap_y + if cap is None and not args.no_auto_cap: + cap = _auto_cap(values) + + plt.figure(figsize=(5.4, 3.15)) + ax = plt.gca() + + if cap is not None: + show_vals = [min(v, cap) for v in values] + bars = [] + for i, (_label, val, show) in enumerate(zip(labels, values, show_vals)): + bar = ax.bar(i, show, color=colors[i], width=0.8) + bars.append(bar[0]) + # Hatch and annotate when capped + if val > cap: + bars[-1].set_hatch("//") + ax.text(i, cap * 1.02, f"{_fmt_ms(val)}", ha="center", va="bottom", fontsize=9) + else: + ax.text( + i, + show + max(1.0, 0.01 * (cap or show)), + f"{_fmt_ms(val)}", + ha="center", + va="bottom", + fontsize=9, + ) + ax.set_ylim(0, cap * 1.10) + _add_break_marker(ax, y=0.98) + ax.legend([bars[1]], ["capped"], fontsize=8, frameon=False, loc="upper right") if any( + v > cap for v in values + ) else None + ax.set_xticks(range(len(labels))) + ax.set_xticklabels(labels, fontsize=11, fontweight="bold") + else: + ax.bar(labels, values, color=colors[: len(labels)]) + for idx, val in enumerate(values): + ax.text( + idx, + val + 
1.0, + f"{_fmt_ms(val)}", + ha="center", + va="bottom", + fontsize=10, + fontweight="bold", + ) + ax.set_xticklabels(labels, fontsize=11, fontweight="bold") + # Try to extract some context for title + max_initial = latest_rows[0].get("max_initial", "?") + max_updates = latest_rows[0].get("max_updates", "?") + + if args.broken_y: + fig.text( + 0.02, + 0.5, + "Latency (s)", + va="center", + rotation="vertical", + fontsize=11, + fontweight="bold", + ) + fig.suptitle( + "Add Operation Latency", + fontsize=11, + y=0.98, + fontweight="bold", + ) + plt.tight_layout(rect=(0.03, 0.04, 1, 0.96)) + else: + plt.ylabel("Latency (s)", fontsize=11, fontweight="bold") + plt.title("Add Operation Latency", fontsize=11, fontweight="bold") + plt.tight_layout() + + plt.savefig(args.out, bbox_inches="tight", pad_inches=0.05) + # Also save PDF for paper + pdf_out = args.out.with_suffix(".pdf") + plt.savefig(pdf_out, bbox_inches="tight", pad_inches=0.05) + print(f"Saved: {args.out}") + print(f"Saved: {pdf_out}") + + +if __name__ == "__main__": + main() diff --git a/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_backend.py b/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_backend.py index 6a831a5..7022009 100644 --- a/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_backend.py +++ b/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_backend.py @@ -215,6 +215,8 @@ class HNSWSearcher(BaseSearcher): if recompute_embeddings: if zmq_port is None: raise ValueError("zmq_port must be provided if recompute_embeddings is True") + if hasattr(self._index, "set_zmq_port"): + self._index.set_zmq_port(zmq_port) if query.dtype != np.float32: query = query.astype(np.float32) diff --git a/packages/leann-core/src/leann/api.py b/packages/leann-core/src/leann/api.py index c2f1d93..bbcc8a3 100644 --- a/packages/leann-core/src/leann/api.py +++ b/packages/leann-core/src/leann/api.py @@ -820,10 +820,10 @@ class LeannBuilder: actual_port, requested_zmq_port, ) - try: - index.hnsw.zmq_port = actual_port - except AttributeError: - pass + if hasattr(index.hnsw, "set_zmq_port"): + index.hnsw.set_zmq_port(actual_port) + elif hasattr(index, "set_zmq_port"): + index.set_zmq_port(actual_port) if needs_recompute: for i in range(embeddings.shape[0]):
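The two hunks above make the ZMQ port plumbing explicit: both the HNSW searcher and `LeannBuilder` now forward the embedding server's actual port through a `set_zmq_port()` setter when the underlying index exposes one, instead of assigning a `zmq_port` attribute inside a `try/except AttributeError`. The sketch below shows the same feature-detection pattern in isolation; the helper name and surrounding plumbing are illustrative only and not part of this patch.

```python
# Illustrative sketch only (not part of the patch): forward the port the
# embedding server actually bound (which can differ from the requested port,
# e.g. when the requested one was already in use) to the underlying HNSW index.
def propagate_zmq_port(index, actual_port: int) -> None:
    hnsw = getattr(index, "hnsw", None)
    if hnsw is not None and hasattr(hnsw, "set_zmq_port"):
        # Bindings that expose the setter on the wrapped HNSW object.
        hnsw.set_zmq_port(actual_port)
    elif hasattr(index, "set_zmq_port"):
        # Bindings that expose the setter on the top-level index instead.
        index.set_zmq_port(actual_port)
    # Older bindings without the setter keep their build-time port; searches
    # that recompute embeddings must then receive the port explicitly.
```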