Faster Update (#148)
* stash * stash * add std err in add and trace progress * fix. * docs * style: format * docs * better figs * better figs * update results * fotmat --------- Co-authored-by: yichuan-w <yichuan-w@users.noreply.github.com>
This commit is contained in:
143
benchmarks/update/README.md
Normal file
143
benchmarks/update/README.md
Normal file
@@ -0,0 +1,143 @@
|
|||||||
|
# Update Benchmarks
|
||||||
|
|
||||||
|
This directory hosts two benchmark suites that exercise LEANN’s HNSW “update +
|
||||||
|
search” pipeline under different assumptions:
|
||||||
|
|
||||||
|
1. **RNG recompute latency** – measure how random-neighbour pruning and cache
|
||||||
|
settings influence incremental `add()` latency when embeddings are fetched
|
||||||
|
over the ZMQ embedding server.
|
||||||
|
2. **Update strategy comparison** – compare a fully sequential update pipeline
|
||||||
|
against an offline approach that keeps the graph static and fuses results.
|
||||||
|
|
||||||
|
Both suites build a non-compact, `is_recompute=True` index so that new
|
||||||
|
embeddings are pulled from the embedding server. Benchmark outputs are written
|
||||||
|
under `.leann/bench/` by default and appended to CSV files for later plotting.
|
||||||
|
|
||||||
|
## Benchmarks
|
||||||
|
|
||||||
|
### 1. HNSW RNG Recompute Benchmark
|
||||||
|
|
||||||
|
`bench_hnsw_rng_recompute.py` evaluates incremental update latency under four
|
||||||
|
random-neighbour (RNG) configurations. Each scenario uses the same dataset but
|
||||||
|
changes the forward / reverse RNG pruning flags and whether the embedding cache
|
||||||
|
is enabled:
|
||||||
|
|
||||||
|
| Scenario name | Forward RNG | Reverse RNG | ZMQ embedding cache |
|
||||||
|
| ---------------------------------- | ----------- | ----------- | ------------------- |
|
||||||
|
| `baseline` | Enabled | Enabled | Enabled |
|
||||||
|
| `no_cache_baseline` | Enabled | Enabled | **Disabled** |
|
||||||
|
| `disable_forward_rng` | **Disabled**| Enabled | Enabled |
|
||||||
|
| `disable_forward_and_reverse_rng` | **Disabled**| **Disabled**| Enabled |
|
||||||
|
|
||||||
|
For each scenario the script:
|
||||||
|
1. (Re)builds a `is_recompute=True` index and writes it to `.leann/bench/`.
|
||||||
|
2. Starts `leann_backend_hnsw.hnsw_embedding_server` for remote embeddings.
|
||||||
|
3. Appends the requested updates using the scenario’s RNG flags.
|
||||||
|
4. Records total time, latency per passage, ZMQ fetch counts, and stage-level
|
||||||
|
timings before appending a row to the CSV output.
|
||||||
|
|
||||||
|
**Run:**
|
||||||
|
```bash
|
||||||
|
LEANN_HNSW_LOG_PATH=.leann/bench/hnsw_server.log \
|
||||||
|
LEANN_LOG_LEVEL=INFO \
|
||||||
|
uv run -m benchmarks.update.bench_hnsw_rng_recompute \
|
||||||
|
--runs 1 \
|
||||||
|
--index-path .leann/bench/test.leann \
|
||||||
|
--initial-files data/PrideandPrejudice.txt \
|
||||||
|
--update-files data/huawei_pangu.md \
|
||||||
|
--max-initial 300 \
|
||||||
|
--max-updates 1 \
|
||||||
|
--add-timeout 120
|
||||||
|
```
|
||||||
|
|
||||||
|
**Output:**
|
||||||
|
- `benchmarks/update/bench_results.csv` – per-scenario timing statistics
|
||||||
|
(including ms/passage) for each run.
|
||||||
|
- `.leann/bench/hnsw_server.log` – detailed ZMQ/server logs (path controlled by
|
||||||
|
`LEANN_HNSW_LOG_PATH`).
|
||||||
|
_The reference CSVs checked into this branch were generated on a workstation with an NVIDIA RTX 4090 GPU; throughput numbers will differ on other hardware._
|
||||||
|
|
||||||
|
### 2. Sequential vs. Offline Update Benchmark
|
||||||
|
|
||||||
|
`bench_update_vs_offline_search.py` compares two end-to-end strategies on the
|
||||||
|
same dataset:
|
||||||
|
|
||||||
|
- **Scenario A – Sequential Update**
|
||||||
|
- Start an embedding server.
|
||||||
|
- Sequentially call `index.add()`; each call fetches embeddings via ZMQ and
|
||||||
|
mutates the HNSW graph.
|
||||||
|
- After all inserts, run a search on the updated graph.
|
||||||
|
- Metrics recorded: update time (`add_total_s`), post-update search time
|
||||||
|
(`search_time_s`), combined total (`total_time_s`), and per-passage
|
||||||
|
latency.
|
||||||
|
|
||||||
|
- **Scenario B – Offline Embedding + Concurrent Search**
|
||||||
|
- Stop Scenario A’s server and start a fresh embedding server.
|
||||||
|
- Spawn two threads: one generates embeddings for the new passages offline
|
||||||
|
(graph unchanged); the other computes the query embedding and searches the
|
||||||
|
existing graph.
|
||||||
|
- Merge offline similarities with the graph search results to emulate late
|
||||||
|
fusion, then report the merged top‑k preview.
|
||||||
|
- Metrics recorded: embedding time (`emb_time_s`), search time
|
||||||
|
(`search_time_s`), concurrent makespan (`makespan_s`), and scenario total.
|
||||||
|
|
||||||
|
**Run (both scenarios):**
|
||||||
|
```bash
|
||||||
|
uv run -m benchmarks.update.bench_update_vs_offline_search \
|
||||||
|
--index-path .leann/bench/offline_vs_update.leann \
|
||||||
|
--max-initial 300 \
|
||||||
|
--num-updates 1
|
||||||
|
```
|
||||||
|
|
||||||
|
You can pass `--only A` or `--only B` to run a single scenario. The script will
|
||||||
|
print timing summaries to stdout and append the results to CSV.
|
||||||
|
|
||||||
|
**Output:**
|
||||||
|
- `benchmarks/update/offline_vs_update.csv` – per-scenario timing statistics for
|
||||||
|
Scenario A and B.
|
||||||
|
- Console output includes Scenario B’s merged top‑k preview for quick sanity
|
||||||
|
checks.
|
||||||
|
_The sample results committed here come from runs on an RTX 4090-equipped machine; expect variations if you benchmark on different GPUs._
|
||||||
|
|
||||||
|
### 3. Visualisation
|
||||||
|
|
||||||
|
`plot_bench_results.py` combines the RNG benchmark and the update strategy
|
||||||
|
benchmark into a single two-panel plot.
|
||||||
|
|
||||||
|
**Run:**
|
||||||
|
```bash
|
||||||
|
uv run -m benchmarks.update.plot_bench_results \
|
||||||
|
--csv benchmarks/update/bench_results.csv \
|
||||||
|
--csv-right benchmarks/update/offline_vs_update.csv \
|
||||||
|
--out benchmarks/update/bench_latency_from_csv.png
|
||||||
|
```
|
||||||
|
|
||||||
|
**Options:**
|
||||||
|
- `--broken-y` – Enable a broken Y-axis (default: true when appropriate).
|
||||||
|
- `--csv` – RNG benchmark results CSV (left panel).
|
||||||
|
- `--csv-right` – Update strategy results CSV (right panel).
|
||||||
|
- `--out` – Output image path (PNG/PDF supported).
|
||||||
|
|
||||||
|
**Output:**
|
||||||
|
- `benchmarks/update/bench_latency_from_csv.png` – visual comparison of the two
|
||||||
|
suites.
|
||||||
|
- `benchmarks/update/bench_latency_from_csv.pdf` – PDF version, suitable for
|
||||||
|
slides/papers.
|
||||||
|
|
||||||
|
## Parameters & Environment
|
||||||
|
|
||||||
|
### Common CLI Flags
|
||||||
|
- `--max-initial` – Number of initial passages used to seed the index.
|
||||||
|
- `--max-updates` / `--num-updates` – Number of passages to treat as updates.
|
||||||
|
- `--index-path` – Base path (without extension) where the LEANN index is stored.
|
||||||
|
- `--runs` – Number of repetitions (RNG benchmark only).
|
||||||
|
|
||||||
|
### Environment Variables
|
||||||
|
- `LEANN_HNSW_LOG_PATH` – File to receive embedding-server logs (optional).
|
||||||
|
- `LEANN_LOG_LEVEL` – Logging verbosity (DEBUG/INFO/WARNING/ERROR).
|
||||||
|
- `CUDA_VISIBLE_DEVICES` – Set to empty string if you want to force CPU
|
||||||
|
execution of the embedding model.
|
||||||
|
|
||||||
|
With these scripts you can easily replicate LEANN’s update benchmarks, compare
|
||||||
|
multiple RNG strategies, and evaluate whether sequential updates or offline
|
||||||
|
fusion better match your latency/accuracy trade-offs.
|
||||||
16
benchmarks/update/__init__.py
Normal file
16
benchmarks/update/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
"""Benchmarks for LEANN update workflows."""
|
||||||
|
|
||||||
|
# Expose helper to locate repository root for other modules that need it.
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
def find_repo_root() -> Path:
|
||||||
|
"""Return the project root containing pyproject.toml."""
|
||||||
|
current = Path(__file__).resolve()
|
||||||
|
for parent in current.parents:
|
||||||
|
if (parent / "pyproject.toml").exists():
|
||||||
|
return parent
|
||||||
|
return current.parents[1]
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["find_repo_root"]
|
||||||
804
benchmarks/update/bench_hnsw_rng_recompute.py
Normal file
804
benchmarks/update/bench_hnsw_rng_recompute.py
Normal file
@@ -0,0 +1,804 @@
|
|||||||
|
"""Benchmark incremental HNSW add() under different RNG pruning modes with real
|
||||||
|
embedding recomputation.
|
||||||
|
|
||||||
|
This script clones the structure of ``examples/dynamic_update_no_recompute.py``
|
||||||
|
so that we build a non-compact ``is_recompute=True`` index, spin up the
|
||||||
|
standard HNSW embedding server, and measure how long incremental ``add`` takes
|
||||||
|
when RNG pruning is fully enabled vs. partially/fully disabled.
|
||||||
|
|
||||||
|
Example usage (run from the repo root; downloads the model on first run)::
|
||||||
|
|
||||||
|
uv run -m benchmarks.update.bench_hnsw_rng_recompute \
|
||||||
|
--index-path .leann/bench/leann-demo.leann \
|
||||||
|
--runs 1
|
||||||
|
|
||||||
|
You can tweak the input documents with ``--initial-files`` / ``--update-files``
|
||||||
|
if you want a larger or different workload, and change the embedding model via
|
||||||
|
``--model-name``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import pickle
|
||||||
|
import re
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import msgpack
|
||||||
|
import numpy as np
|
||||||
|
import zmq
|
||||||
|
from leann.api import LeannBuilder
|
||||||
|
|
||||||
|
if os.environ.get("LEANN_FORCE_CPU", "").lower() in ("1", "true", "yes"):
|
||||||
|
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "")
|
||||||
|
|
||||||
|
from leann.embedding_compute import compute_embeddings
|
||||||
|
from leann.embedding_server_manager import EmbeddingServerManager
|
||||||
|
from leann.registry import register_project_directory
|
||||||
|
from leann_backend_hnsw import faiss # type: ignore
|
||||||
|
from leann_backend_hnsw.convert_to_csr import prune_hnsw_embeddings_inplace
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
if not logging.getLogger().handlers:
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
|
||||||
|
|
||||||
|
def _find_repo_root() -> Path:
|
||||||
|
"""Locate project root by walking up until pyproject.toml is found."""
|
||||||
|
current = Path(__file__).resolve()
|
||||||
|
for parent in current.parents:
|
||||||
|
if (parent / "pyproject.toml").exists():
|
||||||
|
return parent
|
||||||
|
# Fallback: assume repo is two levels up (../..)
|
||||||
|
return current.parents[2]
|
||||||
|
|
||||||
|
|
||||||
|
REPO_ROOT = _find_repo_root()
|
||||||
|
if str(REPO_ROOT) not in sys.path:
|
||||||
|
sys.path.insert(0, str(REPO_ROOT))
|
||||||
|
|
||||||
|
from apps.chunking import create_text_chunks # noqa: E402
|
||||||
|
|
||||||
|
DEFAULT_INITIAL_FILES = [
|
||||||
|
REPO_ROOT / "data" / "2501.14312v1 (1).pdf",
|
||||||
|
REPO_ROOT / "data" / "huawei_pangu.md",
|
||||||
|
]
|
||||||
|
DEFAULT_UPDATE_FILES = [REPO_ROOT / "data" / "2506.08276v1.pdf"]
|
||||||
|
|
||||||
|
DEFAULT_HNSW_LOG = Path(".leann/bench/hnsw_server.log")
|
||||||
|
|
||||||
|
|
||||||
|
def load_chunks_from_files(paths: list[Path], limit: int | None = None) -> list[str]:
|
||||||
|
from llama_index.core import SimpleDirectoryReader
|
||||||
|
|
||||||
|
documents = []
|
||||||
|
for path in paths:
|
||||||
|
p = path.expanduser().resolve()
|
||||||
|
if not p.exists():
|
||||||
|
raise FileNotFoundError(f"Input path not found: {p}")
|
||||||
|
if p.is_dir():
|
||||||
|
reader = SimpleDirectoryReader(str(p), recursive=False)
|
||||||
|
documents.extend(reader.load_data(show_progress=True))
|
||||||
|
else:
|
||||||
|
reader = SimpleDirectoryReader(input_files=[str(p)])
|
||||||
|
documents.extend(reader.load_data(show_progress=True))
|
||||||
|
|
||||||
|
if not documents:
|
||||||
|
return []
|
||||||
|
|
||||||
|
chunks = create_text_chunks(
|
||||||
|
documents,
|
||||||
|
chunk_size=512,
|
||||||
|
chunk_overlap=128,
|
||||||
|
use_ast_chunking=False,
|
||||||
|
)
|
||||||
|
cleaned = [c for c in chunks if isinstance(c, str) and c.strip()]
|
||||||
|
if limit is not None:
|
||||||
|
cleaned = cleaned[:limit]
|
||||||
|
return cleaned
|
||||||
|
|
||||||
|
|
||||||
|
def ensure_index_dir(index_path: Path) -> None:
|
||||||
|
index_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
|
||||||
|
def cleanup_index_files(index_path: Path) -> None:
|
||||||
|
parent = index_path.parent
|
||||||
|
if not parent.exists():
|
||||||
|
return
|
||||||
|
stem = index_path.stem
|
||||||
|
for file in parent.glob(f"{stem}*"):
|
||||||
|
if file.is_file():
|
||||||
|
file.unlink()
|
||||||
|
|
||||||
|
|
||||||
|
def build_initial_index(
|
||||||
|
index_path: Path,
|
||||||
|
paragraphs: list[str],
|
||||||
|
model_name: str,
|
||||||
|
embedding_mode: str,
|
||||||
|
distance_metric: str,
|
||||||
|
ef_construction: int,
|
||||||
|
) -> None:
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name="hnsw",
|
||||||
|
embedding_model=model_name,
|
||||||
|
embedding_mode=embedding_mode,
|
||||||
|
is_compact=False,
|
||||||
|
is_recompute=True,
|
||||||
|
distance_metric=distance_metric,
|
||||||
|
backend_kwargs={
|
||||||
|
"distance_metric": distance_metric,
|
||||||
|
"is_compact": False,
|
||||||
|
"is_recompute": True,
|
||||||
|
"efConstruction": ef_construction,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
for idx, passage in enumerate(paragraphs):
|
||||||
|
builder.add_text(passage, metadata={"id": str(idx)})
|
||||||
|
builder.build_index(str(index_path))
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_new_chunks(paragraphs: list[str]) -> list[dict[str, Any]]:
|
||||||
|
return [{"text": text, "metadata": {}} for text in paragraphs]
|
||||||
|
|
||||||
|
|
||||||
|
def benchmark_update_with_mode(
|
||||||
|
index_path: Path,
|
||||||
|
new_chunks: list[dict[str, Any]],
|
||||||
|
model_name: str,
|
||||||
|
embedding_mode: str,
|
||||||
|
distance_metric: str,
|
||||||
|
disable_forward_rng: bool,
|
||||||
|
disable_reverse_rng: bool,
|
||||||
|
server_port: int,
|
||||||
|
add_timeout: int,
|
||||||
|
ef_construction: int,
|
||||||
|
) -> tuple[float, float]:
|
||||||
|
meta_path = index_path.parent / f"{index_path.name}.meta.json"
|
||||||
|
passages_file = index_path.parent / f"{index_path.name}.passages.jsonl"
|
||||||
|
offset_file = index_path.parent / f"{index_path.name}.passages.idx"
|
||||||
|
index_file = index_path.parent / f"{index_path.stem}.index"
|
||||||
|
|
||||||
|
with open(meta_path, encoding="utf-8") as f:
|
||||||
|
meta = json.load(f)
|
||||||
|
|
||||||
|
with open(offset_file, "rb") as f:
|
||||||
|
offset_map: dict[str, int] = pickle.load(f)
|
||||||
|
existing_ids = set(offset_map.keys())
|
||||||
|
|
||||||
|
valid_chunks: list[dict[str, Any]] = []
|
||||||
|
for chunk in new_chunks:
|
||||||
|
text = chunk.get("text", "")
|
||||||
|
if not isinstance(text, str) or not text.strip():
|
||||||
|
continue
|
||||||
|
metadata = chunk.setdefault("metadata", {})
|
||||||
|
passage_id = chunk.get("id") or metadata.get("id")
|
||||||
|
if passage_id and passage_id in existing_ids:
|
||||||
|
raise ValueError(f"Passage ID '{passage_id}' already exists in the index.")
|
||||||
|
valid_chunks.append(chunk)
|
||||||
|
|
||||||
|
if not valid_chunks:
|
||||||
|
raise ValueError("No valid chunks to append.")
|
||||||
|
|
||||||
|
texts_to_embed = [chunk["text"] for chunk in valid_chunks]
|
||||||
|
embeddings = compute_embeddings(
|
||||||
|
texts_to_embed,
|
||||||
|
model_name,
|
||||||
|
mode=embedding_mode,
|
||||||
|
is_build=False,
|
||||||
|
batch_size=16,
|
||||||
|
)
|
||||||
|
|
||||||
|
embeddings = np.ascontiguousarray(embeddings, dtype=np.float32)
|
||||||
|
if distance_metric == "cosine":
|
||||||
|
norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
|
||||||
|
norms[norms == 0] = 1
|
||||||
|
embeddings = embeddings / norms
|
||||||
|
|
||||||
|
index = faiss.read_index(str(index_file))
|
||||||
|
index.is_recompute = True
|
||||||
|
if getattr(index, "storage", None) is None:
|
||||||
|
if index.metric_type == faiss.METRIC_INNER_PRODUCT:
|
||||||
|
storage_index = faiss.IndexFlatIP(index.d)
|
||||||
|
else:
|
||||||
|
storage_index = faiss.IndexFlatL2(index.d)
|
||||||
|
index.storage = storage_index
|
||||||
|
index.own_fields = True
|
||||||
|
try:
|
||||||
|
storage_index.ntotal = index.ntotal
|
||||||
|
except AttributeError:
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
index.hnsw.set_disable_rng_during_add(disable_forward_rng)
|
||||||
|
index.hnsw.set_disable_reverse_prune(disable_reverse_rng)
|
||||||
|
if ef_construction is not None:
|
||||||
|
index.hnsw.efConstruction = ef_construction
|
||||||
|
except AttributeError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
applied_forward = getattr(index.hnsw, "disable_rng_during_add", None)
|
||||||
|
applied_reverse = getattr(index.hnsw, "disable_reverse_prune", None)
|
||||||
|
logger.info(
|
||||||
|
"HNSW RNG config -> requested forward=%s, reverse=%s | applied forward=%s, reverse=%s",
|
||||||
|
disable_forward_rng,
|
||||||
|
disable_reverse_rng,
|
||||||
|
applied_forward,
|
||||||
|
applied_reverse,
|
||||||
|
)
|
||||||
|
|
||||||
|
base_id = index.ntotal
|
||||||
|
for offset, chunk in enumerate(valid_chunks):
|
||||||
|
new_id = str(base_id + offset)
|
||||||
|
chunk.setdefault("metadata", {})["id"] = new_id
|
||||||
|
chunk["id"] = new_id
|
||||||
|
|
||||||
|
rollback_size = passages_file.stat().st_size if passages_file.exists() else 0
|
||||||
|
offset_map_backup = offset_map.copy()
|
||||||
|
|
||||||
|
try:
|
||||||
|
with open(passages_file, "a", encoding="utf-8") as f:
|
||||||
|
for chunk in valid_chunks:
|
||||||
|
offset = f.tell()
|
||||||
|
json.dump(
|
||||||
|
{
|
||||||
|
"id": chunk["id"],
|
||||||
|
"text": chunk["text"],
|
||||||
|
"metadata": chunk.get("metadata", {}),
|
||||||
|
},
|
||||||
|
f,
|
||||||
|
ensure_ascii=False,
|
||||||
|
)
|
||||||
|
f.write("\n")
|
||||||
|
offset_map[chunk["id"]] = offset
|
||||||
|
|
||||||
|
with open(offset_file, "wb") as f:
|
||||||
|
pickle.dump(offset_map, f)
|
||||||
|
|
||||||
|
server_manager = EmbeddingServerManager(
|
||||||
|
backend_module_name="leann_backend_hnsw.hnsw_embedding_server"
|
||||||
|
)
|
||||||
|
server_started, actual_port = server_manager.start_server(
|
||||||
|
port=server_port,
|
||||||
|
model_name=model_name,
|
||||||
|
embedding_mode=embedding_mode,
|
||||||
|
passages_file=str(meta_path),
|
||||||
|
distance_metric=distance_metric,
|
||||||
|
)
|
||||||
|
if not server_started:
|
||||||
|
raise RuntimeError("Failed to start embedding server.")
|
||||||
|
|
||||||
|
if hasattr(index.hnsw, "set_zmq_port"):
|
||||||
|
index.hnsw.set_zmq_port(actual_port)
|
||||||
|
elif hasattr(index, "set_zmq_port"):
|
||||||
|
index.set_zmq_port(actual_port)
|
||||||
|
|
||||||
|
_warmup_embedding_server(actual_port)
|
||||||
|
|
||||||
|
total_start = time.time()
|
||||||
|
add_elapsed = 0.0
|
||||||
|
|
||||||
|
try:
|
||||||
|
import signal
|
||||||
|
|
||||||
|
def _timeout_handler(signum, frame):
|
||||||
|
raise TimeoutError("incremental add timed out")
|
||||||
|
|
||||||
|
if add_timeout > 0:
|
||||||
|
signal.signal(signal.SIGALRM, _timeout_handler)
|
||||||
|
signal.alarm(add_timeout)
|
||||||
|
|
||||||
|
add_start = time.time()
|
||||||
|
for i in range(embeddings.shape[0]):
|
||||||
|
index.add(1, faiss.swig_ptr(embeddings[i : i + 1]))
|
||||||
|
add_elapsed = time.time() - add_start
|
||||||
|
if add_timeout > 0:
|
||||||
|
signal.alarm(0)
|
||||||
|
faiss.write_index(index, str(index_file))
|
||||||
|
finally:
|
||||||
|
server_manager.stop_server()
|
||||||
|
|
||||||
|
except TimeoutError:
|
||||||
|
raise
|
||||||
|
except Exception:
|
||||||
|
if passages_file.exists():
|
||||||
|
with open(passages_file, "rb+") as f:
|
||||||
|
f.truncate(rollback_size)
|
||||||
|
with open(offset_file, "wb") as f:
|
||||||
|
pickle.dump(offset_map_backup, f)
|
||||||
|
raise
|
||||||
|
|
||||||
|
prune_hnsw_embeddings_inplace(str(index_file))
|
||||||
|
|
||||||
|
meta["total_passages"] = len(offset_map)
|
||||||
|
with open(meta_path, "w", encoding="utf-8") as f:
|
||||||
|
json.dump(meta, f, indent=2)
|
||||||
|
|
||||||
|
# Reset toggles so the index on disk returns to baseline behaviour.
|
||||||
|
try:
|
||||||
|
index.hnsw.set_disable_rng_during_add(False)
|
||||||
|
index.hnsw.set_disable_reverse_prune(False)
|
||||||
|
except AttributeError:
|
||||||
|
pass
|
||||||
|
faiss.write_index(index, str(index_file))
|
||||||
|
|
||||||
|
total_elapsed = time.time() - total_start
|
||||||
|
|
||||||
|
return total_elapsed, add_elapsed
|
||||||
|
|
||||||
|
|
||||||
|
def _total_zmq_nodes(log_path: Path) -> int:
|
||||||
|
if not log_path.exists():
|
||||||
|
return 0
|
||||||
|
with log_path.open("r", encoding="utf-8") as log_file:
|
||||||
|
text = log_file.read()
|
||||||
|
return sum(int(match) for match in re.findall(r"ZMQ received (\d+) node IDs", text))
|
||||||
|
|
||||||
|
|
||||||
|
def _warmup_embedding_server(port: int) -> None:
|
||||||
|
"""Send a dummy REQ so the embedding server loads its model."""
|
||||||
|
ctx = zmq.Context()
|
||||||
|
try:
|
||||||
|
sock = ctx.socket(zmq.REQ)
|
||||||
|
sock.setsockopt(zmq.LINGER, 0)
|
||||||
|
sock.setsockopt(zmq.RCVTIMEO, 5000)
|
||||||
|
sock.setsockopt(zmq.SNDTIMEO, 5000)
|
||||||
|
sock.connect(f"tcp://127.0.0.1:{port}")
|
||||||
|
payload = msgpack.packb(["__WARMUP__"], use_bin_type=True)
|
||||||
|
sock.send(payload)
|
||||||
|
try:
|
||||||
|
sock.recv()
|
||||||
|
except zmq.error.Again:
|
||||||
|
pass
|
||||||
|
finally:
|
||||||
|
sock.close()
|
||||||
|
ctx.term()
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
parser = argparse.ArgumentParser(description=__doc__)
|
||||||
|
parser.add_argument(
|
||||||
|
"--index-path",
|
||||||
|
type=Path,
|
||||||
|
default=Path(".leann/bench/leann-demo.leann"),
|
||||||
|
help="Output index base path (without extension).",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--initial-files",
|
||||||
|
nargs="*",
|
||||||
|
type=Path,
|
||||||
|
default=DEFAULT_INITIAL_FILES,
|
||||||
|
help="Files used to build the initial index.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--update-files",
|
||||||
|
nargs="*",
|
||||||
|
type=Path,
|
||||||
|
default=DEFAULT_UPDATE_FILES,
|
||||||
|
help="Files appended during the benchmark.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--runs", type=int, default=1, help="How many times to repeat each scenario."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--model-name",
|
||||||
|
default="sentence-transformers/all-MiniLM-L6-v2",
|
||||||
|
help="Embedding model used for build/update.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--embedding-mode",
|
||||||
|
default="sentence-transformers",
|
||||||
|
help="Embedding mode passed to LeannBuilder/embedding server.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--distance-metric",
|
||||||
|
default="mips",
|
||||||
|
choices=["mips", "l2", "cosine"],
|
||||||
|
help="Distance metric for HNSW backend.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--ef-construction",
|
||||||
|
type=int,
|
||||||
|
default=200,
|
||||||
|
help="efConstruction setting for initial build.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--server-port",
|
||||||
|
type=int,
|
||||||
|
default=5557,
|
||||||
|
help="Port for the real embedding server.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-initial",
|
||||||
|
type=int,
|
||||||
|
default=300,
|
||||||
|
help="Optional cap on initial passages (after chunking).",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-updates",
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help="Optional cap on update passages (after chunking).",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--add-timeout",
|
||||||
|
type=int,
|
||||||
|
default=900,
|
||||||
|
help="Timeout in seconds for the incremental add loop (0 = no timeout).",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--plot-path",
|
||||||
|
type=Path,
|
||||||
|
default=Path("bench_latency.png"),
|
||||||
|
help="Where to save the latency bar plot.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--cap-y",
|
||||||
|
type=float,
|
||||||
|
default=None,
|
||||||
|
help="Cap Y-axis (ms). Bars above are hatched and annotated.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--broken-y",
|
||||||
|
action="store_true",
|
||||||
|
help="Use broken Y-axis (two stacked axes with gap). Overrides --cap-y unless both provided.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--lower-cap-y",
|
||||||
|
type=float,
|
||||||
|
default=None,
|
||||||
|
help="Lower axes upper bound for broken Y (ms). Default=1.1x second-highest.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--upper-start-y",
|
||||||
|
type=float,
|
||||||
|
default=None,
|
||||||
|
help="Upper axes lower bound for broken Y (ms). Default=1.2x second-highest.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--csv-path",
|
||||||
|
type=Path,
|
||||||
|
default=Path("benchmarks/update/bench_results.csv"),
|
||||||
|
help="Where to append per-scenario results as CSV.",
|
||||||
|
)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
register_project_directory(REPO_ROOT)
|
||||||
|
|
||||||
|
initial_paragraphs = load_chunks_from_files(args.initial_files, args.max_initial)
|
||||||
|
update_paragraphs = load_chunks_from_files(args.update_files, args.max_updates)
|
||||||
|
if not update_paragraphs:
|
||||||
|
raise ValueError("No update passages found; please provide --update-files with content.")
|
||||||
|
|
||||||
|
update_chunks = prepare_new_chunks(update_paragraphs)
|
||||||
|
ensure_index_dir(args.index_path)
|
||||||
|
|
||||||
|
scenarios = [
|
||||||
|
("baseline", False, False, True),
|
||||||
|
("no_cache_baseline", False, False, False),
|
||||||
|
("disable_forward_rng", True, False, True),
|
||||||
|
("disable_forward_and_reverse_rng", True, True, True),
|
||||||
|
]
|
||||||
|
|
||||||
|
log_path = Path(os.environ.get("LEANN_HNSW_LOG_PATH", DEFAULT_HNSW_LOG))
|
||||||
|
log_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
os.environ["LEANN_HNSW_LOG_PATH"] = str(log_path.resolve())
|
||||||
|
os.environ.setdefault("LEANN_LOG_LEVEL", "INFO")
|
||||||
|
|
||||||
|
results_total: dict[str, list[float]] = {name: [] for name, *_ in scenarios}
|
||||||
|
results_add: dict[str, list[float]] = {name: [] for name, *_ in scenarios}
|
||||||
|
results_zmq: dict[str, list[int]] = {name: [] for name, *_ in scenarios}
|
||||||
|
results_stageA: dict[str, list[float]] = {name: [] for name, *_ in scenarios}
|
||||||
|
results_stageBC: dict[str, list[float]] = {name: [] for name, *_ in scenarios}
|
||||||
|
results_ms_per_passage: dict[str, list[float]] = {name: [] for name, *_ in scenarios}
|
||||||
|
|
||||||
|
# CSV setup
|
||||||
|
import csv
|
||||||
|
|
||||||
|
run_id = time.strftime("%Y%m%d-%H%M%S")
|
||||||
|
csv_fields = [
|
||||||
|
"run_id",
|
||||||
|
"scenario",
|
||||||
|
"cache_enabled",
|
||||||
|
"ef_construction",
|
||||||
|
"max_initial",
|
||||||
|
"max_updates",
|
||||||
|
"total_time_s",
|
||||||
|
"add_only_s",
|
||||||
|
"latency_ms_per_passage",
|
||||||
|
"zmq_nodes",
|
||||||
|
"stageA_time_s",
|
||||||
|
"stageBC_time_s",
|
||||||
|
"model_name",
|
||||||
|
"embedding_mode",
|
||||||
|
"distance_metric",
|
||||||
|
]
|
||||||
|
# Create CSV with header if missing
|
||||||
|
if args.csv_path:
|
||||||
|
args.csv_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
if not args.csv_path.exists() or args.csv_path.stat().st_size == 0:
|
||||||
|
with args.csv_path.open("w", newline="", encoding="utf-8") as f:
|
||||||
|
writer = csv.DictWriter(f, fieldnames=csv_fields)
|
||||||
|
writer.writeheader()
|
||||||
|
|
||||||
|
for run in range(args.runs):
|
||||||
|
print(f"\n=== Benchmark run {run + 1}/{args.runs} ===")
|
||||||
|
for name, disable_forward, disable_reverse, cache_enabled in scenarios:
|
||||||
|
print(f"\nScenario: {name}")
|
||||||
|
cleanup_index_files(args.index_path)
|
||||||
|
if log_path.exists():
|
||||||
|
try:
|
||||||
|
log_path.unlink()
|
||||||
|
except OSError:
|
||||||
|
pass
|
||||||
|
os.environ["LEANN_ZMQ_EMBED_CACHE"] = "1" if cache_enabled else "0"
|
||||||
|
build_initial_index(
|
||||||
|
args.index_path,
|
||||||
|
initial_paragraphs,
|
||||||
|
args.model_name,
|
||||||
|
args.embedding_mode,
|
||||||
|
args.distance_metric,
|
||||||
|
args.ef_construction,
|
||||||
|
)
|
||||||
|
|
||||||
|
prev_size = log_path.stat().st_size if log_path.exists() else 0
|
||||||
|
|
||||||
|
try:
|
||||||
|
total_elapsed, add_elapsed = benchmark_update_with_mode(
|
||||||
|
args.index_path,
|
||||||
|
update_chunks,
|
||||||
|
args.model_name,
|
||||||
|
args.embedding_mode,
|
||||||
|
args.distance_metric,
|
||||||
|
disable_forward,
|
||||||
|
disable_reverse,
|
||||||
|
args.server_port,
|
||||||
|
args.add_timeout,
|
||||||
|
args.ef_construction,
|
||||||
|
)
|
||||||
|
except TimeoutError as exc:
|
||||||
|
print(f"Scenario {name} timed out: {exc}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
curr_size = log_path.stat().st_size if log_path.exists() else 0
|
||||||
|
if curr_size < prev_size:
|
||||||
|
prev_size = 0
|
||||||
|
zmq_count = 0
|
||||||
|
if log_path.exists():
|
||||||
|
with log_path.open("r", encoding="utf-8") as log_file:
|
||||||
|
log_file.seek(prev_size)
|
||||||
|
new_entries = log_file.read()
|
||||||
|
zmq_count = sum(
|
||||||
|
int(match) for match in re.findall(r"ZMQ received (\d+) node IDs", new_entries)
|
||||||
|
)
|
||||||
|
stageA = sum(
|
||||||
|
float(x)
|
||||||
|
for x in re.findall(r"Distance calculation E2E time: ([0-9.]+)s", new_entries)
|
||||||
|
)
|
||||||
|
stageBC = sum(
|
||||||
|
float(x) for x in re.findall(r"ZMQ E2E time: ([0-9.]+)s", new_entries)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
stageA = 0.0
|
||||||
|
stageBC = 0.0
|
||||||
|
|
||||||
|
per_chunk = add_elapsed / len(update_chunks)
|
||||||
|
print(
|
||||||
|
f"Total time: {total_elapsed:.3f} s | add-only: {add_elapsed:.3f} s "
|
||||||
|
f"for {len(update_chunks)} passages => {per_chunk * 1e3:.3f} ms/passage"
|
||||||
|
)
|
||||||
|
print(f"ZMQ node fetch total: {zmq_count}")
|
||||||
|
results_total[name].append(total_elapsed)
|
||||||
|
results_add[name].append(add_elapsed)
|
||||||
|
results_zmq[name].append(zmq_count)
|
||||||
|
results_ms_per_passage[name].append(per_chunk * 1e3)
|
||||||
|
results_stageA[name].append(stageA)
|
||||||
|
results_stageBC[name].append(stageBC)
|
||||||
|
|
||||||
|
# Append row to CSV
|
||||||
|
if args.csv_path:
|
||||||
|
row = {
|
||||||
|
"run_id": run_id,
|
||||||
|
"scenario": name,
|
||||||
|
"cache_enabled": 1 if cache_enabled else 0,
|
||||||
|
"ef_construction": args.ef_construction,
|
||||||
|
"max_initial": args.max_initial,
|
||||||
|
"max_updates": args.max_updates,
|
||||||
|
"total_time_s": round(total_elapsed, 6),
|
||||||
|
"add_only_s": round(add_elapsed, 6),
|
||||||
|
"latency_ms_per_passage": round(per_chunk * 1e3, 6),
|
||||||
|
"zmq_nodes": int(zmq_count),
|
||||||
|
"stageA_time_s": round(stageA, 6),
|
||||||
|
"stageBC_time_s": round(stageBC, 6),
|
||||||
|
"model_name": args.model_name,
|
||||||
|
"embedding_mode": args.embedding_mode,
|
||||||
|
"distance_metric": args.distance_metric,
|
||||||
|
}
|
||||||
|
with args.csv_path.open("a", newline="", encoding="utf-8") as f:
|
||||||
|
writer = csv.DictWriter(f, fieldnames=csv_fields)
|
||||||
|
writer.writerow(row)
|
||||||
|
|
||||||
|
print("\n=== Summary ===")
|
||||||
|
for name in results_add:
|
||||||
|
add_values = results_add[name]
|
||||||
|
total_values = results_total[name]
|
||||||
|
zmq_values = results_zmq[name]
|
||||||
|
latency_values = results_ms_per_passage[name]
|
||||||
|
if not add_values:
|
||||||
|
print(f"{name}: no successful runs")
|
||||||
|
continue
|
||||||
|
avg_add = sum(add_values) / len(add_values)
|
||||||
|
avg_total = sum(total_values) / len(total_values)
|
||||||
|
avg_zmq = sum(zmq_values) / len(zmq_values) if zmq_values else 0.0
|
||||||
|
avg_latency = sum(latency_values) / len(latency_values) if latency_values else 0.0
|
||||||
|
runs = len(add_values)
|
||||||
|
print(
|
||||||
|
f"{name}: add-only avg {avg_add:.3f} s | total avg {avg_total:.3f} s "
|
||||||
|
f"| ZMQ avg {avg_zmq:.1f} node fetches | latency {avg_latency:.2f} ms/passage over {runs} run(s)"
|
||||||
|
)
|
||||||
|
|
||||||
|
if args.plot_path:
|
||||||
|
try:
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
labels = [name for name, *_ in scenarios]
|
||||||
|
values = [
|
||||||
|
sum(results_ms_per_passage[name]) / len(results_ms_per_passage[name])
|
||||||
|
if results_ms_per_passage[name]
|
||||||
|
else 0.0
|
||||||
|
for name in labels
|
||||||
|
]
|
||||||
|
|
||||||
|
def _auto_cap(vals: list[float]) -> float | None:
|
||||||
|
s = sorted(vals, reverse=True)
|
||||||
|
if len(s) < 2:
|
||||||
|
return None
|
||||||
|
if s[1] > 0 and s[0] >= 2.5 * s[1]:
|
||||||
|
return s[1] * 1.1
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _fmt_ms(v: float) -> str:
|
||||||
|
return f"{v / 1000:.1f}k" if v >= 1000 else f"{v:.1f}"
|
||||||
|
|
||||||
|
colors = ["#4e79a7", "#f28e2c", "#e15759", "#76b7b2"]
|
||||||
|
|
||||||
|
if args.broken_y:
|
||||||
|
s = sorted(values, reverse=True)
|
||||||
|
second = s[1] if len(s) >= 2 else (s[0] if s else 0.0)
|
||||||
|
lower_cap = args.lower_cap_y if args.lower_cap_y is not None else second * 1.1
|
||||||
|
upper_start = (
|
||||||
|
args.upper_start_y
|
||||||
|
if args.upper_start_y is not None
|
||||||
|
else max(second * 1.2, lower_cap * 1.02)
|
||||||
|
)
|
||||||
|
ymax = max(values) * 1.10 if values else 1.0
|
||||||
|
fig, (ax_top, ax_bottom) = plt.subplots(
|
||||||
|
2,
|
||||||
|
1,
|
||||||
|
sharex=True,
|
||||||
|
figsize=(7.4, 5.0),
|
||||||
|
gridspec_kw={"height_ratios": [1, 3], "hspace": 0.05},
|
||||||
|
)
|
||||||
|
x = list(range(len(labels)))
|
||||||
|
ax_bottom.bar(x, values, color=colors[: len(labels)], width=0.8)
|
||||||
|
ax_top.bar(x, values, color=colors[: len(labels)], width=0.8)
|
||||||
|
ax_bottom.set_ylim(0, lower_cap)
|
||||||
|
ax_top.set_ylim(upper_start, ymax)
|
||||||
|
for i, v in enumerate(values):
|
||||||
|
if v <= lower_cap:
|
||||||
|
ax_bottom.text(
|
||||||
|
i,
|
||||||
|
v + lower_cap * 0.02,
|
||||||
|
_fmt_ms(v),
|
||||||
|
ha="center",
|
||||||
|
va="bottom",
|
||||||
|
fontsize=9,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
ax_top.text(i, v, _fmt_ms(v), ha="center", va="bottom", fontsize=9)
|
||||||
|
ax_top.spines["bottom"].set_visible(False)
|
||||||
|
ax_bottom.spines["top"].set_visible(False)
|
||||||
|
ax_top.tick_params(labeltop=False)
|
||||||
|
ax_bottom.xaxis.tick_bottom()
|
||||||
|
d = 0.015
|
||||||
|
kwargs = {"transform": ax_top.transAxes, "color": "k", "clip_on": False}
|
||||||
|
ax_top.plot((-d, +d), (-d, +d), **kwargs)
|
||||||
|
ax_top.plot((1 - d, 1 + d), (-d, +d), **kwargs)
|
||||||
|
kwargs.update({"transform": ax_bottom.transAxes})
|
||||||
|
ax_bottom.plot((-d, +d), (1 - d, 1 + d), **kwargs)
|
||||||
|
ax_bottom.plot((1 - d, 1 + d), (1 - d, 1 + d), **kwargs)
|
||||||
|
ax_bottom.set_xticks(range(len(labels)))
|
||||||
|
ax_bottom.set_xticklabels(labels)
|
||||||
|
ax = ax_bottom
|
||||||
|
else:
|
||||||
|
cap = args.cap_y or _auto_cap(values)
|
||||||
|
plt.figure(figsize=(7.2, 4.2))
|
||||||
|
ax = plt.gca()
|
||||||
|
if cap is not None:
|
||||||
|
show_vals = [min(v, cap) for v in values]
|
||||||
|
bars = []
|
||||||
|
for i, (v, show) in enumerate(zip(values, show_vals)):
|
||||||
|
b = ax.bar(i, show, color=colors[i], width=0.8)
|
||||||
|
bars.append(b[0])
|
||||||
|
if v > cap:
|
||||||
|
bars[-1].set_hatch("//")
|
||||||
|
ax.text(i, cap * 1.02, _fmt_ms(v), ha="center", va="bottom", fontsize=9)
|
||||||
|
else:
|
||||||
|
ax.text(
|
||||||
|
i,
|
||||||
|
show + max(1.0, 0.01 * (cap or show)),
|
||||||
|
_fmt_ms(v),
|
||||||
|
ha="center",
|
||||||
|
va="bottom",
|
||||||
|
fontsize=9,
|
||||||
|
)
|
||||||
|
ax.set_ylim(0, cap * 1.10)
|
||||||
|
ax.plot(
|
||||||
|
[0.02 - 0.02, 0.02 + 0.02],
|
||||||
|
[0.98 + 0.02, 0.98 - 0.02],
|
||||||
|
transform=ax.transAxes,
|
||||||
|
color="k",
|
||||||
|
lw=1,
|
||||||
|
)
|
||||||
|
ax.plot(
|
||||||
|
[0.98 - 0.02, 0.98 + 0.02],
|
||||||
|
[0.98 + 0.02, 0.98 - 0.02],
|
||||||
|
transform=ax.transAxes,
|
||||||
|
color="k",
|
||||||
|
lw=1,
|
||||||
|
)
|
||||||
|
if any(v > cap for v in values):
|
||||||
|
ax.legend(
|
||||||
|
[bars[0]], ["capped"], fontsize=8, frameon=False, loc="upper right"
|
||||||
|
)
|
||||||
|
ax.set_xticks(range(len(labels)))
|
||||||
|
ax.set_xticklabels(labels)
|
||||||
|
else:
|
||||||
|
ax.bar(labels, values, color=colors[: len(labels)])
|
||||||
|
for idx, val in enumerate(values):
|
||||||
|
ax.text(idx, val + 1.0, f"{val:.1f}", ha="center", va="bottom")
|
||||||
|
|
||||||
|
plt.ylabel("Average add latency (ms per passage)")
|
||||||
|
plt.title(f"Initial passages {args.max_initial}, updates {args.max_updates}")
|
||||||
|
plt.tight_layout()
|
||||||
|
plt.savefig(args.plot_path)
|
||||||
|
print(f"Saved latency bar plot to {args.plot_path}")
|
||||||
|
# ZMQ time split (Stage A vs B/C)
|
||||||
|
try:
|
||||||
|
plt.figure(figsize=(6, 4))
|
||||||
|
a_vals = [sum(results_stageA[n]) / max(1, len(results_stageA[n])) for n in labels]
|
||||||
|
bc_vals = [
|
||||||
|
sum(results_stageBC[n]) / max(1, len(results_stageBC[n])) for n in labels
|
||||||
|
]
|
||||||
|
ind = range(len(labels))
|
||||||
|
plt.bar(ind, a_vals, color="#4e79a7", label="Stage A distance (s)")
|
||||||
|
plt.bar(
|
||||||
|
ind, bc_vals, bottom=a_vals, color="#e15759", label="Stage B/C embed-by-id (s)"
|
||||||
|
)
|
||||||
|
plt.xticks(list(ind), labels, rotation=10)
|
||||||
|
plt.ylabel("Server ZMQ time (s)")
|
||||||
|
plt.title(
|
||||||
|
f"ZMQ time split (initial {args.max_initial}, updates {args.max_updates})"
|
||||||
|
)
|
||||||
|
plt.legend()
|
||||||
|
out2 = args.plot_path.with_name(
|
||||||
|
args.plot_path.stem + "_zmq_split" + args.plot_path.suffix
|
||||||
|
)
|
||||||
|
plt.tight_layout()
|
||||||
|
plt.savefig(out2)
|
||||||
|
print(f"Saved ZMQ time split plot to {out2}")
|
||||||
|
except Exception as e:
|
||||||
|
print("Failed to plot ZMQ split:", e)
|
||||||
|
except ImportError:
|
||||||
|
print("matplotlib not available; skipping plot generation")
|
||||||
|
|
||||||
|
# leave the last build on disk for inspection
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
5
benchmarks/update/bench_results.csv
Normal file
5
benchmarks/update/bench_results.csv
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
run_id,scenario,cache_enabled,ef_construction,max_initial,max_updates,total_time_s,add_only_s,latency_ms_per_passage,zmq_nodes,stageA_time_s,stageBC_time_s,model_name,embedding_mode,distance_metric
|
||||||
|
20251024-133101,baseline,1,200,300,1,3.391856,1.120359,1120.359421,126,0.507821,0.601608,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips
|
||||||
|
20251024-133101,no_cache_baseline,0,200,300,1,34.941514,32.91376,32913.760185,4033,0.506933,32.159928,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips
|
||||||
|
20251024-133101,disable_forward_rng,1,200,300,1,2.746756,0.8202,820.200443,66,0.474354,0.338454,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips
|
||||||
|
20251024-133101,disable_forward_and_reverse_rng,1,200,300,1,2.396566,0.521478,521.478415,1,0.508973,0.006938,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips
|
||||||
|
704
benchmarks/update/bench_update_vs_offline_search.py
Normal file
704
benchmarks/update/bench_update_vs_offline_search.py
Normal file
@@ -0,0 +1,704 @@
|
|||||||
|
"""
|
||||||
|
Compare two latency models for small incremental updates vs. search:
|
||||||
|
|
||||||
|
Scenario A (sequential update then search):
|
||||||
|
- Build initial HNSW (is_recompute=True)
|
||||||
|
- Start embedding server (ZMQ) for recompute
|
||||||
|
- Add N passages one-by-one (each triggers recompute over ZMQ)
|
||||||
|
- Then run a search query on the updated index
|
||||||
|
- Report total time = sum(add_i) + search_time, with breakdowns
|
||||||
|
|
||||||
|
Scenario B (offline embeds + concurrent search; no graph updates):
|
||||||
|
- Do NOT insert the N passages into the graph
|
||||||
|
- In parallel: (1) compute embeddings for the N passages; (2) compute query
|
||||||
|
embedding and run a search on the existing index
|
||||||
|
- After both finish, compute similarity between the query embedding and the N
|
||||||
|
new passage embeddings, merge with the index search results by score, and
|
||||||
|
report time = max(embed_time, search_time) (i.e., no blocking on updates)
|
||||||
|
|
||||||
|
This script reuses the model/data loading conventions of
|
||||||
|
examples/bench_hnsw_rng_recompute.py but focuses on end-to-end latency
|
||||||
|
comparison for the two execution strategies above.
|
||||||
|
|
||||||
|
Example (from the repository root):
|
||||||
|
uv run -m benchmarks.update.bench_update_vs_offline_search \
|
||||||
|
--index-path .leann/bench/offline_vs_update.leann \
|
||||||
|
--max-initial 300 --num-updates 5 --k 10
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import csv
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import pickle
|
||||||
|
import sys
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import psutil # type: ignore
|
||||||
|
from leann.api import LeannBuilder
|
||||||
|
|
||||||
|
if os.environ.get("LEANN_FORCE_CPU", "").lower() in ("1", "true", "yes"):
|
||||||
|
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "")
|
||||||
|
|
||||||
|
from leann.embedding_compute import compute_embeddings
|
||||||
|
from leann.embedding_server_manager import EmbeddingServerManager
|
||||||
|
from leann.registry import register_project_directory
|
||||||
|
from leann_backend_hnsw import faiss # type: ignore
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
if not logging.getLogger().handlers:
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
|
||||||
|
|
||||||
|
def _find_repo_root() -> Path:
|
||||||
|
"""Locate project root by walking up until pyproject.toml is found."""
|
||||||
|
current = Path(__file__).resolve()
|
||||||
|
for parent in current.parents:
|
||||||
|
if (parent / "pyproject.toml").exists():
|
||||||
|
return parent
|
||||||
|
# Fallback: assume repo is two levels up (../..)
|
||||||
|
return current.parents[2]
|
||||||
|
|
||||||
|
|
||||||
|
REPO_ROOT = _find_repo_root()
|
||||||
|
if str(REPO_ROOT) not in sys.path:
|
||||||
|
sys.path.insert(0, str(REPO_ROOT))
|
||||||
|
|
||||||
|
from apps.chunking import create_text_chunks # noqa: E402
|
||||||
|
|
||||||
|
DEFAULT_INITIAL_FILES = [
|
||||||
|
REPO_ROOT / "data" / "2501.14312v1 (1).pdf",
|
||||||
|
REPO_ROOT / "data" / "huawei_pangu.md",
|
||||||
|
]
|
||||||
|
DEFAULT_UPDATE_FILES = [REPO_ROOT / "data" / "2506.08276v1.pdf"]
|
||||||
|
|
||||||
|
|
||||||
|
def load_chunks_from_files(paths: list[Path], limit: int | None = None) -> list[str]:
|
||||||
|
from llama_index.core import SimpleDirectoryReader
|
||||||
|
|
||||||
|
documents = []
|
||||||
|
for path in paths:
|
||||||
|
p = path.expanduser().resolve()
|
||||||
|
if not p.exists():
|
||||||
|
raise FileNotFoundError(f"Input path not found: {p}")
|
||||||
|
if p.is_dir():
|
||||||
|
reader = SimpleDirectoryReader(str(p), recursive=False)
|
||||||
|
documents.extend(reader.load_data(show_progress=True))
|
||||||
|
else:
|
||||||
|
reader = SimpleDirectoryReader(input_files=[str(p)])
|
||||||
|
documents.extend(reader.load_data(show_progress=True))
|
||||||
|
|
||||||
|
if not documents:
|
||||||
|
return []
|
||||||
|
|
||||||
|
chunks = create_text_chunks(
|
||||||
|
documents,
|
||||||
|
chunk_size=512,
|
||||||
|
chunk_overlap=128,
|
||||||
|
use_ast_chunking=False,
|
||||||
|
)
|
||||||
|
cleaned = [c for c in chunks if isinstance(c, str) and c.strip()]
|
||||||
|
if limit is not None:
|
||||||
|
cleaned = cleaned[:limit]
|
||||||
|
return cleaned
|
||||||
|
|
||||||
|
|
||||||
|
def ensure_index_dir(index_path: Path) -> None:
|
||||||
|
index_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
|
||||||
|
def cleanup_index_files(index_path: Path) -> None:
|
||||||
|
parent = index_path.parent
|
||||||
|
if not parent.exists():
|
||||||
|
return
|
||||||
|
stem = index_path.stem
|
||||||
|
for file in parent.glob(f"{stem}*"):
|
||||||
|
if file.is_file():
|
||||||
|
file.unlink()
|
||||||
|
|
||||||
|
|
||||||
|
def build_initial_index(
|
||||||
|
index_path: Path,
|
||||||
|
paragraphs: list[str],
|
||||||
|
model_name: str,
|
||||||
|
embedding_mode: str,
|
||||||
|
distance_metric: str,
|
||||||
|
ef_construction: int,
|
||||||
|
) -> None:
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name="hnsw",
|
||||||
|
embedding_model=model_name,
|
||||||
|
embedding_mode=embedding_mode,
|
||||||
|
is_compact=False,
|
||||||
|
is_recompute=True,
|
||||||
|
distance_metric=distance_metric,
|
||||||
|
backend_kwargs={
|
||||||
|
"distance_metric": distance_metric,
|
||||||
|
"is_compact": False,
|
||||||
|
"is_recompute": True,
|
||||||
|
"efConstruction": ef_construction,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
for idx, passage in enumerate(paragraphs):
|
||||||
|
builder.add_text(passage, metadata={"id": str(idx)})
|
||||||
|
builder.build_index(str(index_path))
|
||||||
|
|
||||||
|
|
||||||
|
def _maybe_norm_cosine(vecs: np.ndarray, metric: str) -> np.ndarray:
|
||||||
|
if metric == "cosine":
|
||||||
|
vecs = np.ascontiguousarray(vecs, dtype=np.float32)
|
||||||
|
norms = np.linalg.norm(vecs, axis=1, keepdims=True)
|
||||||
|
norms[norms == 0] = 1
|
||||||
|
vecs = vecs / norms
|
||||||
|
return vecs
|
||||||
|
|
||||||
|
|
||||||
|
def _read_index_for_search(index_path: Path) -> Any:
|
||||||
|
index_file = index_path.parent / f"{index_path.stem}.index"
|
||||||
|
# Force-disable experimental disk cache when loading the index so that
|
||||||
|
# incremental benchmarks don't pick up stale top-degree bitmaps.
|
||||||
|
cfg = faiss.HNSWIndexConfig()
|
||||||
|
cfg.is_recompute = True
|
||||||
|
if hasattr(cfg, "disk_cache_ratio"):
|
||||||
|
cfg.disk_cache_ratio = 0.0
|
||||||
|
if hasattr(cfg, "external_storage_path"):
|
||||||
|
cfg.external_storage_path = None
|
||||||
|
io_flags = getattr(faiss, "IO_FLAG_MMAP", 0)
|
||||||
|
index = faiss.read_index(str(index_file), io_flags, cfg)
|
||||||
|
# ensure recompute mode persists after reload
|
||||||
|
try:
|
||||||
|
index.is_recompute = True
|
||||||
|
except AttributeError:
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
actual_ntotal = index.hnsw.levels.size()
|
||||||
|
except AttributeError:
|
||||||
|
actual_ntotal = index.ntotal
|
||||||
|
if actual_ntotal != index.ntotal:
|
||||||
|
print(
|
||||||
|
f"[bench_update_vs_offline_search] Correcting ntotal from {index.ntotal} to {actual_ntotal}",
|
||||||
|
flush=True,
|
||||||
|
)
|
||||||
|
index.ntotal = actual_ntotal
|
||||||
|
if getattr(index, "storage", None) is None:
|
||||||
|
if index.metric_type == faiss.METRIC_INNER_PRODUCT:
|
||||||
|
storage_index = faiss.IndexFlatIP(index.d)
|
||||||
|
else:
|
||||||
|
storage_index = faiss.IndexFlatL2(index.d)
|
||||||
|
index.storage = storage_index
|
||||||
|
index.own_fields = True
|
||||||
|
return index
|
||||||
|
|
||||||
|
|
||||||
|
def _append_passages_for_updates(
|
||||||
|
meta_path: Path,
|
||||||
|
start_id: int,
|
||||||
|
texts: list[str],
|
||||||
|
) -> list[str]:
|
||||||
|
"""Append update passages so the embedding server can serve recompute fetches."""
|
||||||
|
|
||||||
|
if not texts:
|
||||||
|
return []
|
||||||
|
|
||||||
|
index_dir = meta_path.parent
|
||||||
|
meta_name = meta_path.name
|
||||||
|
if not meta_name.endswith(".meta.json"):
|
||||||
|
raise ValueError(f"Unexpected meta filename: {meta_path}")
|
||||||
|
index_base = meta_name[: -len(".meta.json")]
|
||||||
|
|
||||||
|
passages_file = index_dir / f"{index_base}.passages.jsonl"
|
||||||
|
offsets_file = index_dir / f"{index_base}.passages.idx"
|
||||||
|
|
||||||
|
if not passages_file.exists() or not offsets_file.exists():
|
||||||
|
raise FileNotFoundError(
|
||||||
|
"Passage store missing; cannot register update passages for recompute mode."
|
||||||
|
)
|
||||||
|
|
||||||
|
with open(offsets_file, "rb") as f:
|
||||||
|
offset_map: dict[str, int] = pickle.load(f)
|
||||||
|
|
||||||
|
assigned_ids: list[str] = []
|
||||||
|
with open(passages_file, "a", encoding="utf-8") as f:
|
||||||
|
for i, text in enumerate(texts):
|
||||||
|
passage_id = str(start_id + i)
|
||||||
|
offset = f.tell()
|
||||||
|
json.dump({"id": passage_id, "text": text, "metadata": {}}, f, ensure_ascii=False)
|
||||||
|
f.write("\n")
|
||||||
|
offset_map[passage_id] = offset
|
||||||
|
assigned_ids.append(passage_id)
|
||||||
|
|
||||||
|
with open(offsets_file, "wb") as f:
|
||||||
|
pickle.dump(offset_map, f)
|
||||||
|
|
||||||
|
try:
|
||||||
|
with open(meta_path, encoding="utf-8") as f:
|
||||||
|
meta = json.load(f)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
meta = {}
|
||||||
|
meta["total_passages"] = len(offset_map)
|
||||||
|
with open(meta_path, "w", encoding="utf-8") as f:
|
||||||
|
json.dump(meta, f, indent=2)
|
||||||
|
|
||||||
|
return assigned_ids
|
||||||
|
|
||||||
|
|
||||||
|
def _search(index: Any, q: np.ndarray, k: int) -> tuple[np.ndarray, np.ndarray]:
|
||||||
|
q = np.ascontiguousarray(q, dtype=np.float32)
|
||||||
|
distances = np.zeros((1, k), dtype=np.float32)
|
||||||
|
indices = np.zeros((1, k), dtype=np.int64)
|
||||||
|
index.search(
|
||||||
|
1,
|
||||||
|
faiss.swig_ptr(q),
|
||||||
|
k,
|
||||||
|
faiss.swig_ptr(distances),
|
||||||
|
faiss.swig_ptr(indices),
|
||||||
|
)
|
||||||
|
return distances[0], indices[0]
|
||||||
|
|
||||||
|
|
||||||
|
def _score_for_metric(dist: float, metric: str) -> float:
|
||||||
|
# Convert FAISS distance to a "higher is better" score
|
||||||
|
if metric in ("mips", "cosine"):
|
||||||
|
return float(dist)
|
||||||
|
# l2 distance (smaller better) -> negative distance as score
|
||||||
|
return -float(dist)
|
||||||
|
|
||||||
|
|
||||||
|
def _merge_results(
|
||||||
|
index_results: tuple[np.ndarray, np.ndarray],
|
||||||
|
offline_scores: list[tuple[int, float]],
|
||||||
|
k: int,
|
||||||
|
metric: str,
|
||||||
|
) -> list[tuple[str, float]]:
|
||||||
|
distances, indices = index_results
|
||||||
|
merged: list[tuple[str, float]] = []
|
||||||
|
for distance, idx in zip(distances.tolist(), indices.tolist()):
|
||||||
|
merged.append((f"idx:{idx}", _score_for_metric(distance, metric)))
|
||||||
|
for j, s in offline_scores:
|
||||||
|
merged.append((f"offline:{j}", s))
|
||||||
|
merged.sort(key=lambda x: x[1], reverse=True)
|
||||||
|
return merged[:k]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ScenarioResult:
|
||||||
|
name: str
|
||||||
|
update_total_s: float
|
||||||
|
search_s: float
|
||||||
|
overall_s: float
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
parser = argparse.ArgumentParser(description=__doc__)
|
||||||
|
parser.add_argument(
|
||||||
|
"--index-path",
|
||||||
|
type=Path,
|
||||||
|
default=Path(".leann/bench/offline-vs-update.leann"),
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--initial-files",
|
||||||
|
nargs="*",
|
||||||
|
type=Path,
|
||||||
|
default=DEFAULT_INITIAL_FILES,
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--update-files",
|
||||||
|
nargs="*",
|
||||||
|
type=Path,
|
||||||
|
default=DEFAULT_UPDATE_FILES,
|
||||||
|
)
|
||||||
|
parser.add_argument("--max-initial", type=int, default=300)
|
||||||
|
parser.add_argument("--num-updates", type=int, default=5)
|
||||||
|
parser.add_argument("--k", type=int, default=10, help="Top-k for search/merge")
|
||||||
|
parser.add_argument(
|
||||||
|
"--query",
|
||||||
|
type=str,
|
||||||
|
default="neural network",
|
||||||
|
help="Query text used for the search benchmark.",
|
||||||
|
)
|
||||||
|
parser.add_argument("--server-port", type=int, default=5557)
|
||||||
|
parser.add_argument("--add-timeout", type=int, default=600)
|
||||||
|
parser.add_argument("--model-name", default="sentence-transformers/all-MiniLM-L6-v2")
|
||||||
|
parser.add_argument("--embedding-mode", default="sentence-transformers")
|
||||||
|
parser.add_argument(
|
||||||
|
"--distance-metric",
|
||||||
|
default="mips",
|
||||||
|
choices=["mips", "l2", "cosine"],
|
||||||
|
)
|
||||||
|
parser.add_argument("--ef-construction", type=int, default=200)
|
||||||
|
parser.add_argument(
|
||||||
|
"--only",
|
||||||
|
choices=["A", "B", "both"],
|
||||||
|
default="both",
|
||||||
|
help="Run only Scenario A, Scenario B, or both",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--csv-path",
|
||||||
|
type=Path,
|
||||||
|
default=Path("benchmarks/update/offline_vs_update.csv"),
|
||||||
|
help="Where to append results (CSV).",
|
||||||
|
)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
register_project_directory(REPO_ROOT)
|
||||||
|
|
||||||
|
# Load data
|
||||||
|
initial_paragraphs = load_chunks_from_files(args.initial_files, args.max_initial)
|
||||||
|
update_paragraphs = load_chunks_from_files(args.update_files, None)
|
||||||
|
if not update_paragraphs:
|
||||||
|
raise ValueError("No update passages loaded from --update-files")
|
||||||
|
update_paragraphs = update_paragraphs[: args.num_updates]
|
||||||
|
if len(update_paragraphs) < args.num_updates:
|
||||||
|
raise ValueError(
|
||||||
|
f"Not enough update passages ({len(update_paragraphs)}) for --num-updates={args.num_updates}"
|
||||||
|
)
|
||||||
|
|
||||||
|
ensure_index_dir(args.index_path)
|
||||||
|
cleanup_index_files(args.index_path)
|
||||||
|
|
||||||
|
# Build initial index
|
||||||
|
build_initial_index(
|
||||||
|
args.index_path,
|
||||||
|
initial_paragraphs,
|
||||||
|
args.model_name,
|
||||||
|
args.embedding_mode,
|
||||||
|
args.distance_metric,
|
||||||
|
args.ef_construction,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Prepare index object and meta
|
||||||
|
meta_path = args.index_path.parent / f"{args.index_path.name}.meta.json"
|
||||||
|
index = _read_index_for_search(args.index_path)
|
||||||
|
|
||||||
|
# CSV setup
|
||||||
|
run_id = time.strftime("%Y%m%d-%H%M%S")
|
||||||
|
if args.csv_path:
|
||||||
|
args.csv_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
csv_fields = [
|
||||||
|
"run_id",
|
||||||
|
"scenario",
|
||||||
|
"max_initial",
|
||||||
|
"num_updates",
|
||||||
|
"k",
|
||||||
|
"total_time_s",
|
||||||
|
"add_total_s",
|
||||||
|
"search_time_s",
|
||||||
|
"emb_time_s",
|
||||||
|
"makespan_s",
|
||||||
|
"model_name",
|
||||||
|
"embedding_mode",
|
||||||
|
"distance_metric",
|
||||||
|
]
|
||||||
|
if not args.csv_path.exists() or args.csv_path.stat().st_size == 0:
|
||||||
|
with args.csv_path.open("w", newline="", encoding="utf-8") as f:
|
||||||
|
writer = csv.DictWriter(f, fieldnames=csv_fields)
|
||||||
|
writer.writeheader()
|
||||||
|
|
||||||
|
# Debug: list existing HNSW server PIDs before starting
|
||||||
|
try:
|
||||||
|
existing = [
|
||||||
|
p
|
||||||
|
for p in psutil.process_iter(attrs=["pid", "cmdline"])
|
||||||
|
if any(
|
||||||
|
isinstance(arg, str) and "leann_backend_hnsw.hnsw_embedding_server" in arg
|
||||||
|
for arg in (p.info.get("cmdline") or [])
|
||||||
|
)
|
||||||
|
]
|
||||||
|
if existing:
|
||||||
|
print("[debug] Found existing hnsw_embedding_server processes before run:")
|
||||||
|
for p in existing:
|
||||||
|
print(f"[debug] PID={p.info['pid']} cmd={' '.join(p.info.get('cmdline') or [])}")
|
||||||
|
except Exception as _e:
|
||||||
|
pass
|
||||||
|
|
||||||
|
add_total = 0.0
|
||||||
|
search_after_add = 0.0
|
||||||
|
total_seq = 0.0
|
||||||
|
port_a = None
|
||||||
|
if args.only in ("A", "both"):
|
||||||
|
# Scenario A: sequential update then search
|
||||||
|
start_id = index.ntotal
|
||||||
|
assigned_ids = _append_passages_for_updates(meta_path, start_id, update_paragraphs)
|
||||||
|
if assigned_ids:
|
||||||
|
logger.debug(
|
||||||
|
"Registered %d update passages starting at id %s",
|
||||||
|
len(assigned_ids),
|
||||||
|
assigned_ids[0],
|
||||||
|
)
|
||||||
|
server_manager = EmbeddingServerManager(
|
||||||
|
backend_module_name="leann_backend_hnsw.hnsw_embedding_server"
|
||||||
|
)
|
||||||
|
ok, port = server_manager.start_server(
|
||||||
|
port=args.server_port,
|
||||||
|
model_name=args.model_name,
|
||||||
|
embedding_mode=args.embedding_mode,
|
||||||
|
passages_file=str(meta_path),
|
||||||
|
distance_metric=args.distance_metric,
|
||||||
|
)
|
||||||
|
if not ok:
|
||||||
|
raise RuntimeError("Failed to start embedding server")
|
||||||
|
try:
|
||||||
|
# Set ZMQ port for recompute mode
|
||||||
|
if hasattr(index.hnsw, "set_zmq_port"):
|
||||||
|
index.hnsw.set_zmq_port(port)
|
||||||
|
elif hasattr(index, "set_zmq_port"):
|
||||||
|
index.set_zmq_port(port)
|
||||||
|
|
||||||
|
# Start A overall timer BEFORE computing update embeddings
|
||||||
|
t0 = time.time()
|
||||||
|
|
||||||
|
# Compute embeddings for updates (counted into A's overall)
|
||||||
|
t_emb0 = time.time()
|
||||||
|
upd_embs = compute_embeddings(
|
||||||
|
update_paragraphs,
|
||||||
|
args.model_name,
|
||||||
|
mode=args.embedding_mode,
|
||||||
|
is_build=False,
|
||||||
|
batch_size=16,
|
||||||
|
)
|
||||||
|
emb_time_updates = time.time() - t_emb0
|
||||||
|
upd_embs = np.asarray(upd_embs, dtype=np.float32)
|
||||||
|
upd_embs = _maybe_norm_cosine(upd_embs, args.distance_metric)
|
||||||
|
|
||||||
|
# Perform sequential adds
|
||||||
|
for i in range(upd_embs.shape[0]):
|
||||||
|
t_add0 = time.time()
|
||||||
|
index.add(1, faiss.swig_ptr(upd_embs[i : i + 1]))
|
||||||
|
add_total += time.time() - t_add0
|
||||||
|
# Don't persist index after adds to avoid contaminating Scenario B
|
||||||
|
# index_file = args.index_path.parent / f"{args.index_path.stem}.index"
|
||||||
|
# faiss.write_index(index, str(index_file))
|
||||||
|
|
||||||
|
# Search after updates
|
||||||
|
q_emb = compute_embeddings(
|
||||||
|
[args.query], args.model_name, mode=args.embedding_mode, is_build=False
|
||||||
|
)
|
||||||
|
q_emb = np.asarray(q_emb, dtype=np.float32)
|
||||||
|
q_emb = _maybe_norm_cosine(q_emb, args.distance_metric)
|
||||||
|
|
||||||
|
# Warm up search with a dummy query first
|
||||||
|
print("[DEBUG] Warming up search...")
|
||||||
|
_ = _search(index, q_emb, 1)
|
||||||
|
|
||||||
|
t_s0 = time.time()
|
||||||
|
D_upd, I_upd = _search(index, q_emb, args.k)
|
||||||
|
search_after_add = time.time() - t_s0
|
||||||
|
total_seq = time.time() - t0
|
||||||
|
finally:
|
||||||
|
server_manager.stop_server()
|
||||||
|
port_a = port

    print("\n=== Scenario A: update->search (sequential) ===")
    # emb_time_updates is defined only when A runs
    try:
        _emb_a = emb_time_updates
    except NameError:
        _emb_a = 0.0
    print(
        f"Adds: {args.num_updates} passages; embeds={_emb_a:.3f}s; add_total={add_total:.3f}s; "
        f"search={search_after_add:.3f}s; overall={total_seq:.3f}s"
    )
    # CSV row for A
    if args.csv_path:
        row_a = {
            "run_id": run_id,
            "scenario": "A",
            "max_initial": args.max_initial,
            "num_updates": args.num_updates,
            "k": args.k,
            "total_time_s": round(total_seq, 6),
            "add_total_s": round(add_total, 6),
            "search_time_s": round(search_after_add, 6),
            "emb_time_s": round(_emb_a, 6),
            "makespan_s": 0.0,
            "model_name": args.model_name,
            "embedding_mode": args.embedding_mode,
            "distance_metric": args.distance_metric,
        }
        with args.csv_path.open("a", newline="", encoding="utf-8") as f:
            writer = csv.DictWriter(f, fieldnames=csv_fields)
            writer.writerow(row_a)

    # Verify server cleanup
    try:
        # short sleep to allow signal handling to finish
        time.sleep(0.5)
        leftovers = [
            p
            for p in psutil.process_iter(attrs=["pid", "cmdline"])
            if any(
                isinstance(arg, str) and "leann_backend_hnsw.hnsw_embedding_server" in arg
                for arg in (p.info.get("cmdline") or [])
            )
        ]
        if leftovers:
            print("[warn] hnsw_embedding_server process(es) still alive after A-stop:")
            for p in leftovers:
                print(
                    f"[warn] PID={p.info['pid']} cmd={' '.join(p.info.get('cmdline') or [])}"
                )
        else:
            print("[debug] server cleanup confirmed: no hnsw_embedding_server found")
    except Exception:
        pass

    # Scenario B: offline embeds + concurrent search (no graph updates)
    if args.only in ("B", "both"):
        # ensure a server is available for recompute search
        server_manager_b = EmbeddingServerManager(
            backend_module_name="leann_backend_hnsw.hnsw_embedding_server"
        )
        requested_port = args.server_port if port_a is None else port_a
        ok_b, port_b = server_manager_b.start_server(
            port=requested_port,
            model_name=args.model_name,
            embedding_mode=args.embedding_mode,
            passages_file=str(meta_path),
            distance_metric=args.distance_metric,
        )
        if not ok_b:
            raise RuntimeError("Failed to start embedding server for Scenario B")

        # Wait for server to fully initialize
        print("[DEBUG] Waiting 2s for embedding server to fully initialize...")
        time.sleep(2)

        try:
            # Read the index first
            index_no_update = _read_index_for_search(args.index_path)  # unchanged index

            # Then configure ZMQ port on the correct index object
            if hasattr(index_no_update.hnsw, "set_zmq_port"):
                index_no_update.hnsw.set_zmq_port(port_b)
            elif hasattr(index_no_update, "set_zmq_port"):
                index_no_update.set_zmq_port(port_b)

            # Warmup the embedding model before benchmarking (do this for both --only B and --only both)
            # This ensures fair comparison as Scenario A has warmed up the model during update embeddings
            logger.info("Warming up embedding model for Scenario B...")
            _ = compute_embeddings(
                ["warmup text"], args.model_name, mode=args.embedding_mode, is_build=False
            )

            # Prepare worker A: compute embeddings for the same N passages
            emb_time = 0.0
            updates_embs_offline: np.ndarray | None = None

            def _worker_emb():
                nonlocal emb_time, updates_embs_offline
                t = time.time()
                updates_embs_offline = compute_embeddings(
                    update_paragraphs,
                    args.model_name,
                    mode=args.embedding_mode,
                    is_build=False,
                    batch_size=16,
                )
                emb_time = time.time() - t

            # Pre-compute query embedding and warm up search outside of timed section.
            q_vec = compute_embeddings(
                [args.query], args.model_name, mode=args.embedding_mode, is_build=False
            )
            q_vec = np.asarray(q_vec, dtype=np.float32)
            q_vec = _maybe_norm_cosine(q_vec, args.distance_metric)
            print("[DEBUG B] Warming up search...")
            _ = _search(index_no_update, q_vec, 1)

            # Worker B: timed search on the warmed index
            search_time = 0.0
            offline_elapsed = 0.0
            index_results: tuple[np.ndarray, np.ndarray] | None = None

            def _worker_search():
                nonlocal search_time, index_results
                t = time.time()
                distances, indices = _search(index_no_update, q_vec, args.k)
                search_time = time.time() - t
                index_results = (distances, indices)

            # Run two workers concurrently
            t0 = time.time()
            th1 = threading.Thread(target=_worker_emb)
            th2 = threading.Thread(target=_worker_search)
            th1.start()
            th2.start()
            th1.join()
            th2.join()
            offline_elapsed = time.time() - t0

            # For mixing: compute query vs. offline update similarities (pure client-side)
            offline_scores: list[tuple[int, float]] = []
            if updates_embs_offline is not None:
                upd2 = np.asarray(updates_embs_offline, dtype=np.float32)
                upd2 = _maybe_norm_cosine(upd2, args.distance_metric)
                # For mips/cosine, score = dot; for l2, score = -||x-y||^2
                for j in range(upd2.shape[0]):
                    if args.distance_metric in ("mips", "cosine"):
                        s = float(np.dot(q_vec[0], upd2[j]))
                    else:
                        diff = q_vec[0] - upd2[j]
                        s = -float(np.dot(diff, diff))
                    offline_scores.append((j, s))

            merged_topk = (
                _merge_results(index_results, offline_scores, args.k, args.distance_metric)
                if index_results
                else []
            )

            print("\n=== Scenario B: offline embeds + concurrent search (no add) ===")
            print(
                f"embeddings({args.num_updates})={emb_time:.3f}s; search={search_time:.3f}s; makespan≈{offline_elapsed:.3f}s (≈max)"
            )
            if merged_topk:
                preview = ", ".join([f"{lab}:{score:.3f}" for lab, score in merged_topk[:5]])
                print(f"Merged top-5 preview: {preview}")
            # CSV row for B
            if args.csv_path:
                row_b = {
                    "run_id": run_id,
                    "scenario": "B",
                    "max_initial": args.max_initial,
                    "num_updates": args.num_updates,
                    "k": args.k,
                    "total_time_s": 0.0,
                    "add_total_s": 0.0,
                    "search_time_s": round(search_time, 6),
                    "emb_time_s": round(emb_time, 6),
                    "makespan_s": round(offline_elapsed, 6),
                    "model_name": args.model_name,
                    "embedding_mode": args.embedding_mode,
                    "distance_metric": args.distance_metric,
                }
                with args.csv_path.open("a", newline="", encoding="utf-8") as f:
                    writer = csv.DictWriter(f, fieldnames=csv_fields)
                    writer.writerow(row_b)

        finally:
            server_manager_b.stop_server()

    # Summary
    print("\n=== Summary ===")
    msg_a = (
        f"A: seq-add+search overall={total_seq:.3f}s (adds={add_total:.3f}s, search={search_after_add:.3f}s)"
        if args.only in ("A", "both")
        else "A: skipped"
    )
    msg_b = (
        f"B: offline+concurrent overall≈{offline_elapsed:.3f}s (emb={emb_time:.3f}s, search={search_time:.3f}s)"
        if args.only in ("B", "both")
        else "B: skipped"
    )
    print(msg_a + "\n" + msg_b)


if __name__ == "__main__":
    main()
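The fusion step in Scenario B relies on `_merge_results`, which is defined earlier in this script and not visible in the hunk above. As a rough illustration of the idea — keep the graph static, then combine the index's top-k with client-side scores for the not-yet-inserted passages — a minimal, self-contained sketch might look like the following (the function and label names here are illustrative, not the script's actual helpers):

```python
import numpy as np

def merge_offline_results(index_dists, index_ids, offline_scores, k, metric="mips"):
    """Fuse index hits with client-side scores for passages not yet in the graph.

    index_dists / index_ids: arrays returned by the index search (shape [1, k]).
    offline_scores: list of (local_update_id, score) computed against the query;
    higher is assumed better (for l2, pass negated squared distances).
    """
    candidates = []
    for dist, idx in zip(index_dists[0], index_ids[0]):
        # For MIPS/cosine the returned value is already a similarity; for L2 negate it.
        score = float(dist) if metric in ("mips", "cosine") else -float(dist)
        candidates.append((f"index:{int(idx)}", score))
    for local_id, score in offline_scores:
        candidates.append((f"update:{local_id}", float(score)))
    candidates.sort(key=lambda item: item[1], reverse=True)
    return candidates[:k]
```

This mirrors why Scenario B reports a makespan rather than an add latency: new passages never enter the index during the measured window, so their scores are computed and merged purely on the client while the untouched graph is searched concurrently.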
5
benchmarks/update/offline_vs_update.csv
Normal file
@@ -0,0 +1,5 @@
run_id,scenario,max_initial,num_updates,k,total_time_s,add_total_s,search_time_s,emb_time_s,makespan_s,model_name,embedding_mode,distance_metric
20251024-141607,A,300,1,10,3.273957,3.050168,0.097825,0.017339,0.0,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips
20251024-141607,B,300,1,10,0.0,0.0,0.111892,0.007869,0.112635,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips
20251025-160652,A,300,5,10,5.061945,4.805962,0.123271,0.015008,0.0,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips
20251025-160652,B,300,5,10,0.0,0.0,0.101809,0.008817,0.102447,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips
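For a quick sanity check of rows like the ones above, the two scenarios can be compared per run with a few lines of standard-library Python; the column names match the header row, and this helper is not part of the committed scripts:

```python
import csv
from collections import defaultdict

def summarize(csv_path="benchmarks/update/offline_vs_update.csv"):
    per_run = defaultdict(dict)
    with open(csv_path, newline="", encoding="utf-8") as f:
        for row in csv.DictReader(f):
            per_run[row["run_id"]][row["scenario"]] = row
    for run_id, rows in per_run.items():
        a = float(rows.get("A", {}).get("total_time_s", 0.0))
        b = float(rows.get("B", {}).get("makespan_s", 0.0))
        if a and b:
            print(f"{run_id}: sequential A={a:.3f}s, offline B={b:.3f}s, speedup≈{a / b:.1f}x")

summarize()
```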
645
benchmarks/update/plot_bench_results.py
Normal file
@@ -0,0 +1,645 @@
#!/usr/bin/env python3
"""
Plot latency bars from the benchmark CSV produced by
benchmarks/update/bench_hnsw_rng_recompute.py.

If you also provide an offline_vs_update.csv via --csv-right
(from benchmarks/update/bench_update_vs_offline_search.py), this script will
output a side-by-side figure:
- Left: ms/passage bars (four RNG scenarios).
- Right: seconds bars (Scenario A seq add+search vs Scenario B offline+search).

Usage:
    uv run python benchmarks/update/plot_bench_results.py \
        --csv benchmarks/update/bench_results.csv \
        --out benchmarks/update/bench_latency_from_csv.png

The script selects the latest run_id in the CSV and plots four bars for
the default scenarios:
- baseline
- no_cache_baseline
- disable_forward_rng
- disable_forward_and_reverse_rng

If multiple rows exist per scenario for that run_id, the script averages
their latency_ms_per_passage values.
"""

import argparse
import csv
from collections import defaultdict
from pathlib import Path

DEFAULT_SCENARIOS = [
    "no_cache_baseline",
    "baseline",
    "disable_forward_rng",
    "disable_forward_and_reverse_rng",
]

SCENARIO_LABELS = {
    "baseline": "+ Cache",
    "no_cache_baseline": "Naive \n Recompute",
    "disable_forward_rng": "+ w/o \n Fwd RNG",
    "disable_forward_and_reverse_rng": "+ w/o \n Bwd RNG",
}

# Paper-style colors and hatches for scenarios
SCENARIO_STYLES = {
    "no_cache_baseline": {"edgecolor": "dimgrey", "hatch": "/////"},
    "baseline": {"edgecolor": "#63B8B6", "hatch": "xxxxx"},
    "disable_forward_rng": {"edgecolor": "green", "hatch": "....."},
    "disable_forward_and_reverse_rng": {"edgecolor": "tomato", "hatch": "\\\\\\\\\\"},
}

def load_latest_run(csv_path: Path):
    rows = []
    with csv_path.open("r", encoding="utf-8") as f:
        reader = csv.DictReader(f)
        for row in reader:
            rows.append(row)
    if not rows:
        raise SystemExit("CSV is empty: no rows to plot")
    # Choose latest run_id lexicographically (YYYYMMDD-HHMMSS)
    run_ids = [r.get("run_id", "") for r in rows]
    latest = max(run_ids)
    latest_rows = [r for r in rows if r.get("run_id", "") == latest]
    if not latest_rows:
        # Fallback: take last 4 rows
        latest_rows = rows[-4:]
        latest = latest_rows[-1].get("run_id", "unknown")
    return latest, latest_rows


def aggregate_latency(rows):
    acc = defaultdict(list)
    for r in rows:
        sc = r.get("scenario", "")
        try:
            val = float(r.get("latency_ms_per_passage", "nan"))
        except ValueError:
            continue
        acc[sc].append(val)
    avg = {k: (sum(v) / len(v) if v else 0.0) for k, v in acc.items()}
    return avg


def _auto_cap(values: list[float]) -> float | None:
    if not values:
        return None
    sorted_vals = sorted(values, reverse=True)
    if len(sorted_vals) < 2:
        return None
    max_v, second = sorted_vals[0], sorted_vals[1]
    if second <= 0:
        return None
    # If the tallest bar dwarfs the second by 2.5x+, cap near the second
    if max_v >= 2.5 * second:
        return second * 1.1
    return None


def _add_break_marker(ax, y, rel_x0=0.02, rel_x1=0.98, size=0.02):
    # Draw small diagonal ticks near left/right to signal cap
    x0, x1 = rel_x0, rel_x1
    ax.plot([x0 - size, x0 + size], [y + size, y - size], transform=ax.transAxes, color="k", lw=1)
    ax.plot([x1 - size, x1 + size], [y + size, y - size], transform=ax.transAxes, color="k", lw=1)


def _fmt_ms(v: float) -> str:
    if v >= 1000:
        return f"{v / 1000:.1f}k"
    return f"{v:.1f}"

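These helpers are easy to exercise in isolation. For example, given two rows for the same run, `aggregate_latency` averages their per-passage latencies and `_fmt_ms` shortens large values; the rows below are made up purely to illustrate the expected shapes, they are not real benchmark output:

```python
rows = [
    {"run_id": "20250101-000000", "scenario": "baseline", "latency_ms_per_passage": "120.0"},
    {"run_id": "20250101-000000", "scenario": "baseline", "latency_ms_per_passage": "100.0"},
    {"run_id": "20250101-000000", "scenario": "no_cache_baseline", "latency_ms_per_passage": "2400.0"},
]
print(aggregate_latency(rows))  # {'baseline': 110.0, 'no_cache_baseline': 2400.0}
print(_fmt_ms(2400.0))          # '2.4k'
```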
def main():
|
||||||
|
# Set LaTeX style for paper figures (matching paper_fig.py)
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
plt.rcParams["font.family"] = "Helvetica"
|
||||||
|
plt.rcParams["ytick.direction"] = "in"
|
||||||
|
plt.rcParams["hatch.linewidth"] = 1.5
|
||||||
|
plt.rcParams["font.weight"] = "bold"
|
||||||
|
plt.rcParams["axes.labelweight"] = "bold"
|
||||||
|
plt.rcParams["text.usetex"] = True
|
||||||
|
|
||||||
|
ap = argparse.ArgumentParser(description=__doc__)
|
||||||
|
ap.add_argument(
|
||||||
|
"--csv",
|
||||||
|
type=Path,
|
||||||
|
default=Path("benchmarks/update/bench_results.csv"),
|
||||||
|
help="Path to results CSV (defaults to bench_results.csv)",
|
||||||
|
)
|
||||||
|
ap.add_argument(
|
||||||
|
"--out",
|
||||||
|
type=Path,
|
||||||
|
default=Path("add_ablation.pdf"),
|
||||||
|
help="Output image path",
|
||||||
|
)
|
||||||
|
ap.add_argument(
|
||||||
|
"--csv-right",
|
||||||
|
type=Path,
|
||||||
|
default=Path("benchmarks/update/offline_vs_update.csv"),
|
||||||
|
help="Optional: offline_vs_update.csv to render right subplot (A vs B)",
|
||||||
|
)
|
||||||
|
ap.add_argument(
|
||||||
|
"--cap-y",
|
||||||
|
type=float,
|
||||||
|
default=None,
|
||||||
|
help="Cap Y-axis at this ms value; bars above are hatched and annotated.",
|
||||||
|
)
|
||||||
|
ap.add_argument(
|
||||||
|
"--no-auto-cap",
|
||||||
|
action="store_true",
|
||||||
|
help="Disable auto-cap heuristic when --cap-y is not provided.",
|
||||||
|
)
|
||||||
|
ap.add_argument(
|
||||||
|
"--broken-y",
|
||||||
|
action="store_true",
|
||||||
|
default=True,
|
||||||
|
help="Use a broken Y-axis (two stacked axes with a gap). Overrides --cap-y unless both provided.",
|
||||||
|
)
|
||||||
|
ap.add_argument(
|
||||||
|
"--lower-cap-y",
|
||||||
|
type=float,
|
||||||
|
default=None,
|
||||||
|
help="Lower axes upper bound for broken Y (ms). Default = 1.1x second-highest.",
|
||||||
|
)
|
||||||
|
ap.add_argument(
|
||||||
|
"--upper-start-y",
|
||||||
|
type=float,
|
||||||
|
default=None,
|
||||||
|
help="Upper axes lower bound for broken Y (ms). Default = 1.2x second-highest.",
|
||||||
|
)
|
||||||
|
args = ap.parse_args()
|
||||||
|
|
||||||
|
latest_run, latest_rows = load_latest_run(args.csv)
|
||||||
|
avg = aggregate_latency(latest_rows)
|
||||||
|
|
||||||
|
try:
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
except Exception as e:
|
||||||
|
raise SystemExit(f"matplotlib not available: {e}")
|
||||||
|
|
||||||
|
scenarios = DEFAULT_SCENARIOS
|
||||||
|
values = [avg.get(name, 0.0) for name in scenarios]
|
||||||
|
labels = [SCENARIO_LABELS.get(name, name) for name in scenarios]
|
||||||
|
colors = ["#4e79a7", "#f28e2c", "#e15759", "#76b7b2"]
|
||||||
|
|
||||||
|
# If right CSV is provided, build side-by-side figure
|
||||||
|
if args.csv_right is not None:
|
||||||
|
try:
|
||||||
|
right_rows_all = []
|
||||||
|
with args.csv_right.open("r", encoding="utf-8") as f:
|
||||||
|
rreader = csv.DictReader(f)
|
||||||
|
right_rows_all = list(rreader)
|
||||||
|
if right_rows_all:
|
||||||
|
r_latest = max(r.get("run_id", "") for r in right_rows_all)
|
||||||
|
right_rows = [r for r in right_rows_all if r.get("run_id", "") == r_latest]
|
||||||
|
else:
|
||||||
|
r_latest = None
|
||||||
|
right_rows = []
|
||||||
|
except Exception:
|
||||||
|
r_latest = None
|
||||||
|
right_rows = []
|
||||||
|
|
||||||
|
a_total = 0.0
|
||||||
|
b_makespan = 0.0
|
||||||
|
for r in right_rows:
|
||||||
|
sc = (r.get("scenario", "") or "").strip().upper()
|
||||||
|
if sc == "A":
|
||||||
|
try:
|
||||||
|
a_total = float(r.get("total_time_s", 0.0))
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
elif sc == "B":
|
||||||
|
try:
|
||||||
|
b_makespan = float(r.get("makespan_s", 0.0))
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
from matplotlib import gridspec
|
||||||
|
|
||||||
|
# Left subplot (reuse current style, with optional cap)
|
||||||
|
cap = args.cap_y
|
||||||
|
if cap is None and not args.no_auto_cap:
|
||||||
|
cap = _auto_cap(values)
|
||||||
|
x = list(range(len(labels)))
|
||||||
|
|
||||||
|
if args.broken_y:
|
||||||
|
# Use broken axis for left subplot
|
||||||
|
# Auto-adjust width ratios: left has 4 bars, right has 2 bars
|
||||||
|
fig = plt.figure(figsize=(4.8, 1.8)) # Scaled down to 80%
|
||||||
|
gs = gridspec.GridSpec(
|
||||||
|
2, 2, height_ratios=[1, 3], width_ratios=[1.5, 1], hspace=0.08, wspace=0.35
|
||||||
|
)
|
||||||
|
ax_left_top = fig.add_subplot(gs[0, 0])
|
||||||
|
ax_left_bottom = fig.add_subplot(gs[1, 0], sharex=ax_left_top)
|
||||||
|
ax_right = fig.add_subplot(gs[:, 1])
|
||||||
|
|
||||||
|
# Determine break points
|
||||||
|
s = sorted(values, reverse=True)
|
||||||
|
second = s[1] if len(s) >= 2 else (s[0] if s else 0.0)
|
||||||
|
lower_cap = (
|
||||||
|
args.lower_cap_y if args.lower_cap_y is not None else second * 1.4
|
||||||
|
) # Increased to show more range
|
||||||
|
upper_start = (
|
||||||
|
args.upper_start_y
|
||||||
|
if args.upper_start_y is not None
|
||||||
|
else max(second * 1.5, lower_cap * 1.02)
|
||||||
|
)
|
||||||
|
ymax = (
|
||||||
|
max(values) * 1.90 if values else 1.0
|
||||||
|
) # Increase headroom to 1.90 for text label and tick range
|
||||||
|
|
||||||
|
# Draw bars on both axes
|
||||||
|
ax_left_bottom.bar(x, values, color=colors[: len(labels)], width=0.8)
|
||||||
|
ax_left_top.bar(x, values, color=colors[: len(labels)], width=0.8)
|
||||||
|
|
||||||
|
# Set limits
|
||||||
|
ax_left_bottom.set_ylim(0, lower_cap)
|
||||||
|
ax_left_top.set_ylim(upper_start, ymax)
|
||||||
|
|
||||||
|
# Annotate values (convert ms to s)
|
||||||
|
values_s = [v / 1000.0 for v in values]
|
||||||
|
lower_cap_s = lower_cap / 1000.0
|
||||||
|
upper_start_s = upper_start / 1000.0
|
||||||
|
ymax_s = ymax / 1000.0
|
||||||
|
|
||||||
|
ax_left_bottom.set_ylim(0, lower_cap_s)
|
||||||
|
ax_left_top.set_ylim(upper_start_s, ymax_s)
|
||||||
|
|
||||||
|
# Redraw bars with s values (paper style: white fill + colored edge + hatch)
|
||||||
|
ax_left_bottom.clear()
|
||||||
|
ax_left_top.clear()
|
||||||
|
bar_width = 0.50 # Reduced for wider spacing between bars
|
||||||
|
for i, (scenario_name, v) in enumerate(zip(scenarios, values_s)):
|
||||||
|
style = SCENARIO_STYLES.get(scenario_name, {"edgecolor": "black", "hatch": ""})
|
||||||
|
# Draw in bottom axis for all bars
|
||||||
|
ax_left_bottom.bar(
|
||||||
|
i,
|
||||||
|
v,
|
||||||
|
width=bar_width,
|
||||||
|
color="white",
|
||||||
|
edgecolor=style["edgecolor"],
|
||||||
|
hatch=style["hatch"],
|
||||||
|
linewidth=1.2,
|
||||||
|
)
|
||||||
|
# Only draw in top axis if the bar is tall enough to reach the upper range
|
||||||
|
if v > upper_start_s:
|
||||||
|
ax_left_top.bar(
|
||||||
|
i,
|
||||||
|
v,
|
||||||
|
width=bar_width,
|
||||||
|
color="white",
|
||||||
|
edgecolor=style["edgecolor"],
|
||||||
|
hatch=style["hatch"],
|
||||||
|
linewidth=1.2,
|
||||||
|
)
|
||||||
|
ax_left_bottom.set_ylim(0, lower_cap_s)
|
||||||
|
ax_left_top.set_ylim(upper_start_s, ymax_s)
|
||||||
|
|
||||||
|
for i, v in enumerate(values_s):
|
||||||
|
if v <= lower_cap_s:
|
||||||
|
ax_left_bottom.text(
|
||||||
|
i,
|
||||||
|
v + lower_cap_s * 0.02,
|
||||||
|
f"{v:.2f}",
|
||||||
|
ha="center",
|
||||||
|
va="bottom",
|
||||||
|
fontsize=8,
|
||||||
|
fontweight="bold",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
ax_left_top.text(
|
||||||
|
i,
|
||||||
|
v + (ymax_s - upper_start_s) * 0.02,
|
||||||
|
f"{v:.2f}",
|
||||||
|
ha="center",
|
||||||
|
va="bottom",
|
||||||
|
fontsize=8,
|
||||||
|
fontweight="bold",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Hide spines between axes
|
||||||
|
ax_left_top.spines["bottom"].set_visible(False)
|
||||||
|
ax_left_bottom.spines["top"].set_visible(False)
|
||||||
|
ax_left_top.tick_params(
|
||||||
|
labeltop=False, labelbottom=False, bottom=False
|
||||||
|
) # Hide tick marks
|
||||||
|
ax_left_bottom.xaxis.tick_bottom()
|
||||||
|
ax_left_bottom.tick_params(top=False) # Hide top tick marks
|
||||||
|
|
||||||
|
# Draw break marks (matching paper_fig.py style)
|
||||||
|
d = 0.015
|
||||||
|
kwargs = {
|
||||||
|
"transform": ax_left_top.transAxes,
|
||||||
|
"color": "k",
|
||||||
|
"clip_on": False,
|
||||||
|
"linewidth": 0.8,
|
||||||
|
"zorder": 10,
|
||||||
|
}
|
||||||
|
ax_left_top.plot((-d, +d), (-d, +d), **kwargs)
|
||||||
|
ax_left_top.plot((1 - d, 1 + d), (-d, +d), **kwargs)
|
||||||
|
kwargs.update({"transform": ax_left_bottom.transAxes})
|
||||||
|
ax_left_bottom.plot((-d, +d), (1 - d, 1 + d), **kwargs)
|
||||||
|
ax_left_bottom.plot((1 - d, 1 + d), (1 - d, 1 + d), **kwargs)
|
||||||
|
|
||||||
|
ax_left_bottom.set_xticks(x)
|
||||||
|
ax_left_bottom.set_xticklabels(labels, rotation=0, fontsize=7)
|
||||||
|
# Don't set ylabel here - will use fig.text for alignment
|
||||||
|
ax_left_bottom.tick_params(axis="y", labelsize=10)
|
||||||
|
ax_left_top.tick_params(axis="y", labelsize=10)
|
||||||
|
# Add subtle grid for better readability
|
||||||
|
ax_left_bottom.grid(axis="y", alpha=0.3, linestyle="--", linewidth=0.5)
|
||||||
|
ax_left_top.grid(axis="y", alpha=0.3, linestyle="--", linewidth=0.5)
|
||||||
|
ax_left_top.set_title("Single Add Operation", fontsize=11, pad=10, fontweight="bold")
|
||||||
|
|
||||||
|
# Set x-axis limits to match bar width with right subplot
|
||||||
|
ax_left_bottom.set_xlim(-0.6, 3.6)
|
||||||
|
ax_left_top.set_xlim(-0.6, 3.6)
|
||||||
|
|
||||||
|
ax_left = ax_left_bottom # for compatibility
|
||||||
|
else:
|
||||||
|
# Regular side-by-side layout
|
||||||
|
fig, (ax_left, ax_right) = plt.subplots(1, 2, figsize=(8.4, 3.15))
|
||||||
|
|
||||||
|
if cap is not None:
|
||||||
|
show_vals = [min(v, cap) for v in values]
|
||||||
|
bars = ax_left.bar(x, show_vals, color=colors[: len(labels)], width=0.8)
|
||||||
|
for i, (val, show) in enumerate(zip(values, show_vals)):
|
||||||
|
if val > cap:
|
||||||
|
bars[i].set_hatch("//")
|
||||||
|
ax_left.text(
|
||||||
|
i, cap * 1.02, _fmt_ms(val), ha="center", va="bottom", fontsize=9
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
ax_left.text(
|
||||||
|
i,
|
||||||
|
show + max(1.0, 0.01 * (cap or show)),
|
||||||
|
_fmt_ms(val),
|
||||||
|
ha="center",
|
||||||
|
va="bottom",
|
||||||
|
fontsize=9,
|
||||||
|
)
|
||||||
|
ax_left.set_ylim(0, cap * 1.10)
|
||||||
|
_add_break_marker(ax_left, y=0.98)
|
||||||
|
ax_left.set_xticks(x)
|
||||||
|
ax_left.set_xticklabels(labels, rotation=0, fontsize=10)
|
||||||
|
else:
|
||||||
|
ax_left.bar(x, values, color=colors[: len(labels)], width=0.8)
|
||||||
|
for i, v in enumerate(values):
|
||||||
|
ax_left.text(i, v + 1.0, _fmt_ms(v), ha="center", va="bottom", fontsize=9)
|
||||||
|
ax_left.set_xticks(x)
|
||||||
|
ax_left.set_xticklabels(labels, rotation=0, fontsize=10)
|
||||||
|
ax_left.set_ylabel("Latency (ms per passage)")
|
||||||
|
max_initial = latest_rows[0].get("max_initial", "?")
|
||||||
|
max_updates = latest_rows[0].get("max_updates", "?")
|
||||||
|
ax_left.set_title(
|
||||||
|
f"HNSW RNG (run {latest_run}) | init={max_initial}, upd={max_updates}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Right subplot (A vs B, seconds) - paper style
|
||||||
|
r_labels = ["Sequential", "Delayed \n Add+Search"]
|
||||||
|
r_values = [a_total or 0.0, b_makespan or 0.0]
|
||||||
|
r_styles = [
|
||||||
|
{"edgecolor": "#59a14f", "hatch": "xxxxx"},
|
||||||
|
{"edgecolor": "#edc948", "hatch": "/////"},
|
||||||
|
]
|
||||||
|
# 2 bars, centered with proper spacing
|
||||||
|
xr = [0, 1]
|
||||||
|
bar_width = 0.50 # Reduced for wider spacing between bars
|
||||||
|
for i, (v, style) in enumerate(zip(r_values, r_styles)):
|
||||||
|
ax_right.bar(
|
||||||
|
xr[i],
|
||||||
|
v,
|
||||||
|
width=bar_width,
|
||||||
|
color="white",
|
||||||
|
edgecolor=style["edgecolor"],
|
||||||
|
hatch=style["hatch"],
|
||||||
|
linewidth=1.2,
|
||||||
|
)
|
||||||
|
for i, v in enumerate(r_values):
|
||||||
|
max_v = max(r_values) if r_values else 1.0
|
||||||
|
offset = max(0.0002, 0.02 * max_v)
|
||||||
|
ax_right.text(
|
||||||
|
xr[i],
|
||||||
|
v + offset,
|
||||||
|
f"{v:.2f}",
|
||||||
|
ha="center",
|
||||||
|
va="bottom",
|
||||||
|
fontsize=8,
|
||||||
|
fontweight="bold",
|
||||||
|
)
|
||||||
|
ax_right.set_xticks(xr)
|
||||||
|
ax_right.set_xticklabels(r_labels, rotation=0, fontsize=7)
|
||||||
|
# Don't set ylabel here - will use fig.text for alignment
|
||||||
|
ax_right.tick_params(axis="y", labelsize=10)
|
||||||
|
# Add subtle grid for better readability
|
||||||
|
ax_right.grid(axis="y", alpha=0.3, linestyle="--", linewidth=0.5)
|
||||||
|
ax_right.set_title("Batched Add Operation", fontsize=11, pad=10, fontweight="bold")
|
||||||
|
|
||||||
|
# Set x-axis limits to match left subplot's bar width visually
|
||||||
|
# Accounting for width_ratios=[1.5, 1]:
|
||||||
|
# Left: 4 bars, xlim(-0.6, 3.6), range=4.2, physical_width=1.5*unit
|
||||||
|
# bar_width_visual = 0.72 * (1.5*unit / 4.2)
|
||||||
|
# Right: 2 bars, need same visual width
|
||||||
|
# 0.72 * (1.0*unit / range_right) = 0.72 * (1.5*unit / 4.2)
|
||||||
|
# range_right = 4.2 / 1.5 = 2.8
|
||||||
|
# For bars at 0, 1: padding = (2.8 - 1) / 2 = 0.9
|
||||||
|
ax_right.set_xlim(-0.9, 1.9)
|
||||||
|
|
||||||
|
# Set y-axis limit with headroom for text labels
|
||||||
|
if r_values:
|
||||||
|
max_v = max(r_values)
|
||||||
|
ax_right.set_ylim(0, max_v * 1.15)
|
||||||
|
|
||||||
|
# Format y-axis to avoid scientific notation
|
||||||
|
ax_right.ticklabel_format(style="plain", axis="y")
|
||||||
|
|
||||||
|
plt.tight_layout()
|
||||||
|
|
||||||
|
# Add aligned ylabels using fig.text (after tight_layout)
|
||||||
|
# Get the vertical center of the entire figure
|
||||||
|
fig_center_y = 0.5
|
||||||
|
# Left ylabel - closer to left plot
|
||||||
|
left_x = 0.05
|
||||||
|
fig.text(
|
||||||
|
left_x,
|
||||||
|
fig_center_y,
|
||||||
|
"Latency (s)",
|
||||||
|
va="center",
|
||||||
|
rotation="vertical",
|
||||||
|
fontsize=11,
|
||||||
|
fontweight="bold",
|
||||||
|
)
|
||||||
|
# Right ylabel - closer to right plot
|
||||||
|
right_bbox = ax_right.get_position()
|
||||||
|
right_x = right_bbox.x0 - 0.07
|
||||||
|
fig.text(
|
||||||
|
right_x,
|
||||||
|
fig_center_y,
|
||||||
|
"Latency (s)",
|
||||||
|
va="center",
|
||||||
|
rotation="vertical",
|
||||||
|
fontsize=11,
|
||||||
|
fontweight="bold",
|
||||||
|
)
|
||||||
|
|
||||||
|
plt.savefig(args.out, bbox_inches="tight", pad_inches=0.05)
|
||||||
|
# Also save PDF for paper
|
||||||
|
pdf_out = args.out.with_suffix(".pdf")
|
||||||
|
plt.savefig(pdf_out, bbox_inches="tight", pad_inches=0.05)
|
||||||
|
print(f"Saved: {args.out}")
|
||||||
|
print(f"Saved: {pdf_out}")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Broken-Y mode
|
||||||
|
if args.broken_y:
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
fig, (ax_top, ax_bottom) = plt.subplots(
|
||||||
|
2,
|
||||||
|
1,
|
||||||
|
sharex=True,
|
||||||
|
figsize=(7.5, 6.75),
|
||||||
|
gridspec_kw={"height_ratios": [1, 3], "hspace": 0.08},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Determine default breaks from second-highest
|
||||||
|
s = sorted(values, reverse=True)
|
||||||
|
second = s[1] if len(s) >= 2 else (s[0] if s else 0.0)
|
||||||
|
lower_cap = args.lower_cap_y if args.lower_cap_y is not None else second * 1.1
|
||||||
|
upper_start = (
|
||||||
|
args.upper_start_y
|
||||||
|
if args.upper_start_y is not None
|
||||||
|
else max(second * 1.2, lower_cap * 1.02)
|
||||||
|
)
|
||||||
|
ymax = max(values) * 1.10 if values else 1.0
|
||||||
|
|
||||||
|
x = list(range(len(labels)))
|
||||||
|
ax_bottom.bar(x, values, color=colors[: len(labels)], width=0.8)
|
||||||
|
ax_top.bar(x, values, color=colors[: len(labels)], width=0.8)
|
||||||
|
|
||||||
|
# Limits
|
||||||
|
ax_bottom.set_ylim(0, lower_cap)
|
||||||
|
ax_top.set_ylim(upper_start, ymax)
|
||||||
|
|
||||||
|
# Annotate values
|
||||||
|
for i, v in enumerate(values):
|
||||||
|
if v <= lower_cap:
|
||||||
|
ax_bottom.text(
|
||||||
|
i, v + lower_cap * 0.02, _fmt_ms(v), ha="center", va="bottom", fontsize=9
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
ax_top.text(i, v, _fmt_ms(v), ha="center", va="bottom", fontsize=9)
|
||||||
|
|
||||||
|
# Hide spines between axes and draw diagonal break marks
|
||||||
|
ax_top.spines["bottom"].set_visible(False)
|
||||||
|
ax_bottom.spines["top"].set_visible(False)
|
||||||
|
ax_top.tick_params(labeltop=False) # don't put tick labels at the top
|
||||||
|
ax_bottom.xaxis.tick_bottom()
|
||||||
|
|
||||||
|
# Diagonal lines at the break (matching paper_fig.py style)
|
||||||
|
d = 0.015
|
||||||
|
kwargs = {
|
||||||
|
"transform": ax_top.transAxes,
|
||||||
|
"color": "k",
|
||||||
|
"clip_on": False,
|
||||||
|
"linewidth": 0.8,
|
||||||
|
"zorder": 10,
|
||||||
|
}
|
||||||
|
ax_top.plot((-d, +d), (-d, +d), **kwargs) # top-left diagonal
|
||||||
|
ax_top.plot((1 - d, 1 + d), (-d, +d), **kwargs) # top-right diagonal
|
||||||
|
kwargs.update({"transform": ax_bottom.transAxes})
|
||||||
|
ax_bottom.plot((-d, +d), (1 - d, 1 + d), **kwargs) # bottom-left diagonal
|
||||||
|
ax_bottom.plot((1 - d, 1 + d), (1 - d, 1 + d), **kwargs) # bottom-right diagonal
|
||||||
|
|
||||||
|
ax_bottom.set_xticks(x)
|
||||||
|
ax_bottom.set_xticklabels(labels, rotation=0, fontsize=10)
|
||||||
|
ax = ax_bottom # for labeling below
|
||||||
|
else:
|
||||||
|
cap = args.cap_y
|
||||||
|
if cap is None and not args.no_auto_cap:
|
||||||
|
cap = _auto_cap(values)
|
||||||
|
|
||||||
|
plt.figure(figsize=(5.4, 3.15))
|
||||||
|
ax = plt.gca()
|
||||||
|
|
||||||
|
if cap is not None:
|
||||||
|
show_vals = [min(v, cap) for v in values]
|
||||||
|
bars = []
|
||||||
|
for i, (_label, val, show) in enumerate(zip(labels, values, show_vals)):
|
||||||
|
bar = ax.bar(i, show, color=colors[i], width=0.8)
|
||||||
|
bars.append(bar[0])
|
||||||
|
# Hatch and annotate when capped
|
||||||
|
if val > cap:
|
||||||
|
bars[-1].set_hatch("//")
|
||||||
|
ax.text(i, cap * 1.02, f"{_fmt_ms(val)}", ha="center", va="bottom", fontsize=9)
|
||||||
|
else:
|
||||||
|
ax.text(
|
||||||
|
i,
|
||||||
|
show + max(1.0, 0.01 * (cap or show)),
|
||||||
|
f"{_fmt_ms(val)}",
|
||||||
|
ha="center",
|
||||||
|
va="bottom",
|
||||||
|
fontsize=9,
|
||||||
|
)
|
||||||
|
ax.set_ylim(0, cap * 1.10)
|
||||||
|
_add_break_marker(ax, y=0.98)
|
||||||
|
ax.legend([bars[1]], ["capped"], fontsize=8, frameon=False, loc="upper right") if any(
|
||||||
|
v > cap for v in values
|
||||||
|
) else None
|
||||||
|
ax.set_xticks(range(len(labels)))
|
||||||
|
ax.set_xticklabels(labels, fontsize=11, fontweight="bold")
|
||||||
|
else:
|
||||||
|
ax.bar(labels, values, color=colors[: len(labels)])
|
||||||
|
for idx, val in enumerate(values):
|
||||||
|
ax.text(
|
||||||
|
idx,
|
||||||
|
val + 1.0,
|
||||||
|
f"{_fmt_ms(val)}",
|
||||||
|
ha="center",
|
||||||
|
va="bottom",
|
||||||
|
fontsize=10,
|
||||||
|
fontweight="bold",
|
||||||
|
)
|
||||||
|
ax.set_xticklabels(labels, fontsize=11, fontweight="bold")
|
||||||
|
# Try to extract some context for title
|
||||||
|
max_initial = latest_rows[0].get("max_initial", "?")
|
||||||
|
max_updates = latest_rows[0].get("max_updates", "?")
|
||||||
|
|
||||||
|
if args.broken_y:
|
||||||
|
fig.text(
|
||||||
|
0.02,
|
||||||
|
0.5,
|
||||||
|
"Latency (s)",
|
||||||
|
va="center",
|
||||||
|
rotation="vertical",
|
||||||
|
fontsize=11,
|
||||||
|
fontweight="bold",
|
||||||
|
)
|
||||||
|
fig.suptitle(
|
||||||
|
"Add Operation Latency",
|
||||||
|
fontsize=11,
|
||||||
|
y=0.98,
|
||||||
|
fontweight="bold",
|
||||||
|
)
|
||||||
|
plt.tight_layout(rect=(0.03, 0.04, 1, 0.96))
|
||||||
|
else:
|
||||||
|
plt.ylabel("Latency (s)", fontsize=11, fontweight="bold")
|
||||||
|
plt.title("Add Operation Latency", fontsize=11, fontweight="bold")
|
||||||
|
plt.tight_layout()
|
||||||
|
|
||||||
|
plt.savefig(args.out, bbox_inches="tight", pad_inches=0.05)
|
||||||
|
# Also save PDF for paper
|
||||||
|
pdf_out = args.out.with_suffix(".pdf")
|
||||||
|
plt.savefig(pdf_out, bbox_inches="tight", pad_inches=0.05)
|
||||||
|
print(f"Saved: {args.out}")
|
||||||
|
print(f"Saved: {pdf_out}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -215,6 +215,8 @@ class HNSWSearcher(BaseSearcher):
         if recompute_embeddings:
             if zmq_port is None:
                 raise ValueError("zmq_port must be provided if recompute_embeddings is True")
+            if hasattr(self._index, "set_zmq_port"):
+                self._index.set_zmq_port(zmq_port)
 
         if query.dtype != np.float32:
             query = query.astype(np.float32)
@@ -820,10 +820,10 @@ class LeannBuilder:
                 actual_port,
                 requested_zmq_port,
             )
-            try:
-                index.hnsw.zmq_port = actual_port
-            except AttributeError:
-                pass
+            if hasattr(index.hnsw, "set_zmq_port"):
+                index.hnsw.set_zmq_port(actual_port)
+            elif hasattr(index, "set_zmq_port"):
+                index.set_zmq_port(actual_port)
 
         if needs_recompute:
             for i in range(embeddings.shape[0]):
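Both hunks above switch from attribute assignment (or silently swallowing the failure) to an explicit capability check before wiring the ZMQ port. Outside the diff context, the pattern reduces to a small helper like the following; the `index` object here is hypothetical and only the hasattr/set_zmq_port dispatch is the point:

```python
def configure_zmq_port(index, port: int) -> bool:
    """Attach the embedding-server port wherever the backend exposes the setter."""
    hnsw = getattr(index, "hnsw", None)
    if hnsw is not None and hasattr(hnsw, "set_zmq_port"):
        hnsw.set_zmq_port(port)   # wrapper index with an inner HNSW object
        return True
    if hasattr(index, "set_zmq_port"):
        index.set_zmq_port(port)  # backend exposes the setter directly
        return True
    return False                  # recompute mode not supported by this index
```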