"""Benchmark incremental HNSW add() under different RNG pruning modes with real
|
|
embedding recomputation.
|
|
|
|
This script clones the structure of ``examples/dynamic_update_no_recompute.py``
|
|
so that we build a non-compact ``is_recompute=True`` index, spin up the
|
|
standard HNSW embedding server, and measure how long incremental ``add`` takes
|
|
when RNG pruning is fully enabled vs. partially/fully disabled.
|
|
|
|
Example usage (run from the repo root; downloads the model on first run)::
|
|
|
|
uv run -m benchmarks.update.bench_hnsw_rng_recompute \
|
|
--index-path .leann/bench/leann-demo.leann \
|
|
--runs 1
|
|
|
|
You can tweak the input documents with ``--initial-files`` / ``--update-files``
|
|
if you want a larger or different workload, and change the embedding model via
|
|
``--model-name``.
|
|
"""

import argparse
import json
import logging
import os
import pickle
import re
import sys
import time
from pathlib import Path
from typing import Any

import msgpack
import numpy as np
import zmq
from leann.api import LeannBuilder

if os.environ.get("LEANN_FORCE_CPU", "").lower() in ("1", "true", "yes"):
    os.environ.setdefault("CUDA_VISIBLE_DEVICES", "")

from leann.embedding_compute import compute_embeddings
from leann.embedding_server_manager import EmbeddingServerManager
from leann.registry import register_project_directory
from leann_backend_hnsw import faiss  # type: ignore
from leann_backend_hnsw.convert_to_csr import prune_hnsw_embeddings_inplace

logger = logging.getLogger(__name__)
if not logging.getLogger().handlers:
    logging.basicConfig(level=logging.INFO)


def _find_repo_root() -> Path:
    """Locate project root by walking up until pyproject.toml is found."""
    current = Path(__file__).resolve()
    for parent in current.parents:
        if (parent / "pyproject.toml").exists():
            return parent
    # Fallback: assume repo is two levels up (../..)
    return current.parents[2]


REPO_ROOT = _find_repo_root()
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))

from apps.chunking import create_text_chunks  # noqa: E402

DEFAULT_INITIAL_FILES = [
    REPO_ROOT / "data" / "2501.14312v1 (1).pdf",
    REPO_ROOT / "data" / "huawei_pangu.md",
]
DEFAULT_UPDATE_FILES = [REPO_ROOT / "data" / "2506.08276v1.pdf"]

DEFAULT_HNSW_LOG = Path(".leann/bench/hnsw_server.log")


def load_chunks_from_files(paths: list[Path], limit: int | None = None) -> list[str]:
    from llama_index.core import SimpleDirectoryReader

    documents = []
    for path in paths:
        p = path.expanduser().resolve()
        if not p.exists():
            raise FileNotFoundError(f"Input path not found: {p}")
        if p.is_dir():
            reader = SimpleDirectoryReader(str(p), recursive=False)
            documents.extend(reader.load_data(show_progress=True))
        else:
            reader = SimpleDirectoryReader(input_files=[str(p)])
            documents.extend(reader.load_data(show_progress=True))

    if not documents:
        return []

    chunks = create_text_chunks(
        documents,
        chunk_size=512,
        chunk_overlap=128,
        use_ast_chunking=False,
    )
    cleaned = [c for c in chunks if isinstance(c, str) and c.strip()]
    if limit is not None:
        cleaned = cleaned[:limit]
    return cleaned


def ensure_index_dir(index_path: Path) -> None:
    index_path.parent.mkdir(parents=True, exist_ok=True)


def cleanup_index_files(index_path: Path) -> None:
    parent = index_path.parent
    if not parent.exists():
        return
    stem = index_path.stem
    for file in parent.glob(f"{stem}*"):
        if file.is_file():
            file.unlink()


def build_initial_index(
    index_path: Path,
    paragraphs: list[str],
    model_name: str,
    embedding_mode: str,
    distance_metric: str,
    ef_construction: int,
) -> None:
    builder = LeannBuilder(
        backend_name="hnsw",
        embedding_model=model_name,
        embedding_mode=embedding_mode,
        is_compact=False,
        is_recompute=True,
        distance_metric=distance_metric,
        backend_kwargs={
            "distance_metric": distance_metric,
            "is_compact": False,
            "is_recompute": True,
            "efConstruction": ef_construction,
        },
    )
    for idx, passage in enumerate(paragraphs):
        builder.add_text(passage, metadata={"id": str(idx)})
    builder.build_index(str(index_path))


def prepare_new_chunks(paragraphs: list[str]) -> list[dict[str, Any]]:
    return [{"text": text, "metadata": {}} for text in paragraphs]


def benchmark_update_with_mode(
    index_path: Path,
    new_chunks: list[dict[str, Any]],
    model_name: str,
    embedding_mode: str,
    distance_metric: str,
    disable_forward_rng: bool,
    disable_reverse_rng: bool,
    server_port: int,
    add_timeout: int,
    ef_construction: int,
) -> tuple[float, float]:
    meta_path = index_path.parent / f"{index_path.name}.meta.json"
    passages_file = index_path.parent / f"{index_path.name}.passages.jsonl"
    offset_file = index_path.parent / f"{index_path.name}.passages.idx"
    index_file = index_path.parent / f"{index_path.stem}.index"

    with open(meta_path, encoding="utf-8") as f:
        meta = json.load(f)

    with open(offset_file, "rb") as f:
        offset_map: dict[str, int] = pickle.load(f)
    existing_ids = set(offset_map.keys())

    valid_chunks: list[dict[str, Any]] = []
    for chunk in new_chunks:
        text = chunk.get("text", "")
        if not isinstance(text, str) or not text.strip():
            continue
        metadata = chunk.setdefault("metadata", {})
        passage_id = chunk.get("id") or metadata.get("id")
        if passage_id and passage_id in existing_ids:
            raise ValueError(f"Passage ID '{passage_id}' already exists in the index.")
        valid_chunks.append(chunk)

    if not valid_chunks:
        raise ValueError("No valid chunks to append.")

    texts_to_embed = [chunk["text"] for chunk in valid_chunks]
    embeddings = compute_embeddings(
        texts_to_embed,
        model_name,
        mode=embedding_mode,
        is_build=False,
        batch_size=16,
    )
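
    # For the cosine metric the vectors are L2-normalised below so that an inner-product
    # comparison is equivalent to cosine similarity; zero-norm rows keep a divisor of 1
    # to avoid division by zero.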
    embeddings = np.ascontiguousarray(embeddings, dtype=np.float32)
    if distance_metric == "cosine":
        norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
        norms[norms == 0] = 1
        embeddings = embeddings / norms

    index = faiss.read_index(str(index_file))
    index.is_recompute = True
    if getattr(index, "storage", None) is None:
        if index.metric_type == faiss.METRIC_INNER_PRODUCT:
            storage_index = faiss.IndexFlatIP(index.d)
        else:
            storage_index = faiss.IndexFlatL2(index.d)
        index.storage = storage_index
        index.own_fields = True
        try:
            storage_index.ntotal = index.ntotal
        except AttributeError:
            pass
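
    # Best-effort configuration: set_disable_rng_during_add / set_disable_reverse_prune are
    # extensions expected on the faiss build shipped with leann_backend_hnsw. Judging by the
    # names, the first skips RNG-style pruning while linking the newly added node and the
    # second skips pruning of the reverse edges back from existing neighbors; on a stock
    # faiss build the AttributeError fallback leaves default pruning in place.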
    try:
        index.hnsw.set_disable_rng_during_add(disable_forward_rng)
        index.hnsw.set_disable_reverse_prune(disable_reverse_rng)
        if ef_construction is not None:
            index.hnsw.efConstruction = ef_construction
    except AttributeError:
        pass

    applied_forward = getattr(index.hnsw, "disable_rng_during_add", None)
    applied_reverse = getattr(index.hnsw, "disable_reverse_prune", None)
    logger.info(
        "HNSW RNG config -> requested forward=%s, reverse=%s | applied forward=%s, reverse=%s",
        disable_forward_rng,
        disable_reverse_rng,
        applied_forward,
        applied_reverse,
    )

    base_id = index.ntotal
    for offset, chunk in enumerate(valid_chunks):
        new_id = str(base_id + offset)
        chunk.setdefault("metadata", {})["id"] = new_id
        chunk["id"] = new_id

    rollback_size = passages_file.stat().st_size if passages_file.exists() else 0
    offset_map_backup = offset_map.copy()

    try:
        with open(passages_file, "a", encoding="utf-8") as f:
            for chunk in valid_chunks:
                offset = f.tell()
                json.dump(
                    {
                        "id": chunk["id"],
                        "text": chunk["text"],
                        "metadata": chunk.get("metadata", {}),
                    },
                    f,
                    ensure_ascii=False,
                )
                f.write("\n")
                offset_map[chunk["id"]] = offset

        with open(offset_file, "wb") as f:
            pickle.dump(offset_map, f)

        server_manager = EmbeddingServerManager(
            backend_module_name="leann_backend_hnsw.hnsw_embedding_server"
        )
        server_started, actual_port = server_manager.start_server(
            port=server_port,
            model_name=model_name,
            embedding_mode=embedding_mode,
            passages_file=str(meta_path),
            distance_metric=distance_metric,
        )
        if not server_started:
            raise RuntimeError("Failed to start embedding server.")

        if hasattr(index.hnsw, "set_zmq_port"):
            index.hnsw.set_zmq_port(actual_port)
        elif hasattr(index, "set_zmq_port"):
            index.set_zmq_port(actual_port)

        _warmup_embedding_server(actual_port)

        total_start = time.time()
        add_elapsed = 0.0
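
        # The loop below adds passages one vector at a time through the raw SWIG
        # index.add(n, ptr) entry point and times only those calls; embeddings were
        # computed above, so add_elapsed mostly reflects graph linking plus any
        # recompute traffic to the embedding server. signal.alarm/SIGALRM is
        # POSIX-only, so the --add-timeout guard assumes a Unix-like host.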
        try:
            import signal

            def _timeout_handler(signum, frame):
                raise TimeoutError("incremental add timed out")

            if add_timeout > 0:
                signal.signal(signal.SIGALRM, _timeout_handler)
                signal.alarm(add_timeout)

            add_start = time.time()
            for i in range(embeddings.shape[0]):
                index.add(1, faiss.swig_ptr(embeddings[i : i + 1]))
            add_elapsed = time.time() - add_start
            if add_timeout > 0:
                signal.alarm(0)
            faiss.write_index(index, str(index_file))
        finally:
            server_manager.stop_server()
    except TimeoutError:
        raise
    except Exception:
        if passages_file.exists():
            with open(passages_file, "rb+") as f:
                f.truncate(rollback_size)
        with open(offset_file, "wb") as f:
            pickle.dump(offset_map_backup, f)
        raise

    prune_hnsw_embeddings_inplace(str(index_file))

    meta["total_passages"] = len(offset_map)
    with open(meta_path, "w", encoding="utf-8") as f:
        json.dump(meta, f, indent=2)

    # Reset toggles so the index on disk returns to baseline behaviour.
    try:
        index.hnsw.set_disable_rng_during_add(False)
        index.hnsw.set_disable_reverse_prune(False)
    except AttributeError:
        pass
    faiss.write_index(index, str(index_file))

    total_elapsed = time.time() - total_start

    return total_elapsed, add_elapsed


def _total_zmq_nodes(log_path: Path) -> int:
    if not log_path.exists():
        return 0
    with log_path.open("r", encoding="utf-8") as log_file:
        text = log_file.read()
    return sum(int(match) for match in re.findall(r"ZMQ received (\d+) node IDs", text))


def _warmup_embedding_server(port: int) -> None:
    """Send a dummy REQ so the embedding server loads its model."""
    ctx = zmq.Context()
    try:
        sock = ctx.socket(zmq.REQ)
        sock.setsockopt(zmq.LINGER, 0)
        sock.setsockopt(zmq.RCVTIMEO, 5000)
        sock.setsockopt(zmq.SNDTIMEO, 5000)
        sock.connect(f"tcp://127.0.0.1:{port}")
        payload = msgpack.packb(["__WARMUP__"], use_bin_type=True)
        sock.send(payload)
        try:
            sock.recv()
        except zmq.error.Again:
            pass
    finally:
        sock.close()
        ctx.term()


def main() -> None:
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--index-path",
        type=Path,
        default=Path(".leann/bench/leann-demo.leann"),
        help="Output index base path (without extension).",
    )
    parser.add_argument(
        "--initial-files",
        nargs="*",
        type=Path,
        default=DEFAULT_INITIAL_FILES,
        help="Files used to build the initial index.",
    )
    parser.add_argument(
        "--update-files",
        nargs="*",
        type=Path,
        default=DEFAULT_UPDATE_FILES,
        help="Files appended during the benchmark.",
    )
    parser.add_argument(
        "--runs", type=int, default=1, help="How many times to repeat each scenario."
    )
    parser.add_argument(
        "--model-name",
        default="sentence-transformers/all-MiniLM-L6-v2",
        help="Embedding model used for build/update.",
    )
    parser.add_argument(
        "--embedding-mode",
        default="sentence-transformers",
        help="Embedding mode passed to LeannBuilder/embedding server.",
    )
    parser.add_argument(
        "--distance-metric",
        default="mips",
        choices=["mips", "l2", "cosine"],
        help="Distance metric for HNSW backend.",
    )
    parser.add_argument(
        "--ef-construction",
        type=int,
        default=200,
        help="efConstruction setting for initial build.",
    )
    parser.add_argument(
        "--server-port",
        type=int,
        default=5557,
        help="Port for the real embedding server.",
    )
    parser.add_argument(
        "--max-initial",
        type=int,
        default=300,
        help="Optional cap on initial passages (after chunking).",
    )
    parser.add_argument(
        "--max-updates",
        type=int,
        default=1,
        help="Optional cap on update passages (after chunking).",
    )
    parser.add_argument(
        "--add-timeout",
        type=int,
        default=900,
        help="Timeout in seconds for the incremental add loop (0 = no timeout).",
    )
    parser.add_argument(
        "--plot-path",
        type=Path,
        default=Path("bench_latency.png"),
        help="Where to save the latency bar plot.",
    )
    parser.add_argument(
        "--cap-y",
        type=float,
        default=None,
        help="Cap Y-axis (ms). Bars above are hatched and annotated.",
    )
    parser.add_argument(
        "--broken-y",
        action="store_true",
        help="Use broken Y-axis (two stacked axes with gap). Overrides --cap-y unless both provided.",
    )
    parser.add_argument(
        "--lower-cap-y",
        type=float,
        default=None,
        help="Lower axes upper bound for broken Y (ms). Default=1.1x second-highest.",
    )
    parser.add_argument(
        "--upper-start-y",
        type=float,
        default=None,
        help="Upper axes lower bound for broken Y (ms). Default=1.2x second-highest.",
    )
    parser.add_argument(
        "--csv-path",
        type=Path,
        default=Path("benchmarks/update/bench_results.csv"),
        help="Where to append per-scenario results as CSV.",
    )

    args = parser.parse_args()

    register_project_directory(REPO_ROOT)

    initial_paragraphs = load_chunks_from_files(args.initial_files, args.max_initial)
    update_paragraphs = load_chunks_from_files(args.update_files, args.max_updates)
    if not update_paragraphs:
        raise ValueError("No update passages found; please provide --update-files with content.")

    update_chunks = prepare_new_chunks(update_paragraphs)
    ensure_index_dir(args.index_path)
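
    # Each scenario tuple is (name, disable_forward_rng, disable_reverse_rng, cache_enabled),
    # matching the unpacking in the run loop below: the first three are forwarded to
    # benchmark_update_with_mode(), the last one drives LEANN_ZMQ_EMBED_CACHE.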
    scenarios = [
        ("baseline", False, False, True),
        ("no_cache_baseline", False, False, False),
        ("disable_forward_rng", True, False, True),
        ("disable_forward_and_reverse_rng", True, True, True),
    ]

    log_path = Path(os.environ.get("LEANN_HNSW_LOG_PATH", DEFAULT_HNSW_LOG))
    log_path.parent.mkdir(parents=True, exist_ok=True)
    os.environ["LEANN_HNSW_LOG_PATH"] = str(log_path.resolve())
    os.environ.setdefault("LEANN_LOG_LEVEL", "INFO")

    results_total: dict[str, list[float]] = {name: [] for name, *_ in scenarios}
    results_add: dict[str, list[float]] = {name: [] for name, *_ in scenarios}
    results_zmq: dict[str, list[int]] = {name: [] for name, *_ in scenarios}
    results_stageA: dict[str, list[float]] = {name: [] for name, *_ in scenarios}
    results_stageBC: dict[str, list[float]] = {name: [] for name, *_ in scenarios}
    results_ms_per_passage: dict[str, list[float]] = {name: [] for name, *_ in scenarios}

    # CSV setup
    import csv

    run_id = time.strftime("%Y%m%d-%H%M%S")
    csv_fields = [
        "run_id",
        "scenario",
        "cache_enabled",
        "ef_construction",
        "max_initial",
        "max_updates",
        "total_time_s",
        "add_only_s",
        "latency_ms_per_passage",
        "zmq_nodes",
        "stageA_time_s",
        "stageBC_time_s",
        "model_name",
        "embedding_mode",
        "distance_metric",
    ]
    # Create CSV with header if missing
    if args.csv_path:
        args.csv_path.parent.mkdir(parents=True, exist_ok=True)
        if not args.csv_path.exists() or args.csv_path.stat().st_size == 0:
            with args.csv_path.open("w", newline="", encoding="utf-8") as f:
                writer = csv.DictWriter(f, fieldnames=csv_fields)
                writer.writeheader()

    for run in range(args.runs):
        print(f"\n=== Benchmark run {run + 1}/{args.runs} ===")
        for name, disable_forward, disable_reverse, cache_enabled in scenarios:
            print(f"\nScenario: {name}")
            cleanup_index_files(args.index_path)
            if log_path.exists():
                try:
                    log_path.unlink()
                except OSError:
                    pass
            os.environ["LEANN_ZMQ_EMBED_CACHE"] = "1" if cache_enabled else "0"
            build_initial_index(
                args.index_path,
                initial_paragraphs,
                args.model_name,
                args.embedding_mode,
                args.distance_metric,
                args.ef_construction,
            )

            prev_size = log_path.stat().st_size if log_path.exists() else 0

            try:
                total_elapsed, add_elapsed = benchmark_update_with_mode(
                    args.index_path,
                    update_chunks,
                    args.model_name,
                    args.embedding_mode,
                    args.distance_metric,
                    disable_forward,
                    disable_reverse,
                    args.server_port,
                    args.add_timeout,
                    args.ef_construction,
                )
            except TimeoutError as exc:
                print(f"Scenario {name} timed out: {exc}")
                continue
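
            # Parse only the log lines written by this scenario (reading from prev_size
            # onward). Stage A sums the server's "Distance calculation E2E time" entries,
            # while Stage B/C sums "ZMQ E2E time", i.e. the embed-by-id round trips; the
            # labels match the ones used in the ZMQ-split plot further down.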
            curr_size = log_path.stat().st_size if log_path.exists() else 0
            if curr_size < prev_size:
                prev_size = 0
            zmq_count = 0
            if log_path.exists():
                with log_path.open("r", encoding="utf-8") as log_file:
                    log_file.seek(prev_size)
                    new_entries = log_file.read()
                zmq_count = sum(
                    int(match) for match in re.findall(r"ZMQ received (\d+) node IDs", new_entries)
                )
                stageA = sum(
                    float(x)
                    for x in re.findall(r"Distance calculation E2E time: ([0-9.]+)s", new_entries)
                )
                stageBC = sum(
                    float(x) for x in re.findall(r"ZMQ E2E time: ([0-9.]+)s", new_entries)
                )
            else:
                stageA = 0.0
                stageBC = 0.0

            per_chunk = add_elapsed / len(update_chunks)
            print(
                f"Total time: {total_elapsed:.3f} s | add-only: {add_elapsed:.3f} s "
                f"for {len(update_chunks)} passages => {per_chunk * 1e3:.3f} ms/passage"
            )
            print(f"ZMQ node fetch total: {zmq_count}")
            results_total[name].append(total_elapsed)
            results_add[name].append(add_elapsed)
            results_zmq[name].append(zmq_count)
            results_ms_per_passage[name].append(per_chunk * 1e3)
            results_stageA[name].append(stageA)
            results_stageBC[name].append(stageBC)

            # Append row to CSV
            if args.csv_path:
                row = {
                    "run_id": run_id,
                    "scenario": name,
                    "cache_enabled": 1 if cache_enabled else 0,
                    "ef_construction": args.ef_construction,
                    "max_initial": args.max_initial,
                    "max_updates": args.max_updates,
                    "total_time_s": round(total_elapsed, 6),
                    "add_only_s": round(add_elapsed, 6),
                    "latency_ms_per_passage": round(per_chunk * 1e3, 6),
                    "zmq_nodes": int(zmq_count),
                    "stageA_time_s": round(stageA, 6),
                    "stageBC_time_s": round(stageBC, 6),
                    "model_name": args.model_name,
                    "embedding_mode": args.embedding_mode,
                    "distance_metric": args.distance_metric,
                }
                with args.csv_path.open("a", newline="", encoding="utf-8") as f:
                    writer = csv.DictWriter(f, fieldnames=csv_fields)
                    writer.writerow(row)
print("\n=== Summary ===")
|
|
for name in results_add:
|
|
add_values = results_add[name]
|
|
total_values = results_total[name]
|
|
zmq_values = results_zmq[name]
|
|
latency_values = results_ms_per_passage[name]
|
|
if not add_values:
|
|
print(f"{name}: no successful runs")
|
|
continue
|
|
avg_add = sum(add_values) / len(add_values)
|
|
avg_total = sum(total_values) / len(total_values)
|
|
avg_zmq = sum(zmq_values) / len(zmq_values) if zmq_values else 0.0
|
|
avg_latency = sum(latency_values) / len(latency_values) if latency_values else 0.0
|
|
runs = len(add_values)
|
|
print(
|
|
f"{name}: add-only avg {avg_add:.3f} s | total avg {avg_total:.3f} s "
|
|
f"| ZMQ avg {avg_zmq:.1f} node fetches | latency {avg_latency:.2f} ms/passage over {runs} run(s)"
|
|
)
|
|
|
|

    if args.plot_path:
        try:
            import matplotlib.pyplot as plt

            labels = [name for name, *_ in scenarios]
            values = [
                sum(results_ms_per_passage[name]) / len(results_ms_per_passage[name])
                if results_ms_per_passage[name]
                else 0.0
                for name in labels
            ]

            def _auto_cap(vals: list[float]) -> float | None:
                s = sorted(vals, reverse=True)
                if len(s) < 2:
                    return None
                if s[1] > 0 and s[0] >= 2.5 * s[1]:
                    return s[1] * 1.1
                return None

            def _fmt_ms(v: float) -> str:
                return f"{v / 1000:.1f}k" if v >= 1000 else f"{v:.1f}"

            colors = ["#4e79a7", "#f28e2c", "#e15759", "#76b7b2"]
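
            # Two presentation modes for the latency bars: --broken-y draws two stacked
            # axes with a gap so one outlier bar does not flatten the rest, while the
            # default single-axis path optionally caps the Y range (explicit --cap-y, or
            # _auto_cap() when the tallest bar is at least 2.5x the runner-up) and hatches
            # capped bars with their true value annotated above.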
            if args.broken_y:
                s = sorted(values, reverse=True)
                second = s[1] if len(s) >= 2 else (s[0] if s else 0.0)
                lower_cap = args.lower_cap_y if args.lower_cap_y is not None else second * 1.1
                upper_start = (
                    args.upper_start_y
                    if args.upper_start_y is not None
                    else max(second * 1.2, lower_cap * 1.02)
                )
                ymax = max(values) * 1.10 if values else 1.0
                fig, (ax_top, ax_bottom) = plt.subplots(
                    2,
                    1,
                    sharex=True,
                    figsize=(7.4, 5.0),
                    gridspec_kw={"height_ratios": [1, 3], "hspace": 0.05},
                )
                x = list(range(len(labels)))
                ax_bottom.bar(x, values, color=colors[: len(labels)], width=0.8)
                ax_top.bar(x, values, color=colors[: len(labels)], width=0.8)
                ax_bottom.set_ylim(0, lower_cap)
                ax_top.set_ylim(upper_start, ymax)
                for i, v in enumerate(values):
                    if v <= lower_cap:
                        ax_bottom.text(
                            i,
                            v + lower_cap * 0.02,
                            _fmt_ms(v),
                            ha="center",
                            va="bottom",
                            fontsize=9,
                        )
                    else:
                        ax_top.text(i, v, _fmt_ms(v), ha="center", va="bottom", fontsize=9)
                ax_top.spines["bottom"].set_visible(False)
                ax_bottom.spines["top"].set_visible(False)
                ax_top.tick_params(labeltop=False)
                ax_bottom.xaxis.tick_bottom()
                d = 0.015
                kwargs = {"transform": ax_top.transAxes, "color": "k", "clip_on": False}
                ax_top.plot((-d, +d), (-d, +d), **kwargs)
                ax_top.plot((1 - d, 1 + d), (-d, +d), **kwargs)
                kwargs.update({"transform": ax_bottom.transAxes})
                ax_bottom.plot((-d, +d), (1 - d, 1 + d), **kwargs)
                ax_bottom.plot((1 - d, 1 + d), (1 - d, 1 + d), **kwargs)
                ax_bottom.set_xticks(range(len(labels)))
                ax_bottom.set_xticklabels(labels)
                ax = ax_bottom
            else:
                cap = args.cap_y or _auto_cap(values)
                plt.figure(figsize=(7.2, 4.2))
                ax = plt.gca()
                if cap is not None:
                    show_vals = [min(v, cap) for v in values]
                    bars = []
                    for i, (v, show) in enumerate(zip(values, show_vals)):
                        b = ax.bar(i, show, color=colors[i], width=0.8)
                        bars.append(b[0])
                        if v > cap:
                            bars[-1].set_hatch("//")
                            ax.text(i, cap * 1.02, _fmt_ms(v), ha="center", va="bottom", fontsize=9)
                        else:
                            ax.text(
                                i,
                                show + max(1.0, 0.01 * (cap or show)),
                                _fmt_ms(v),
                                ha="center",
                                va="bottom",
                                fontsize=9,
                            )
                    ax.set_ylim(0, cap * 1.10)
                    ax.plot(
                        [0.02 - 0.02, 0.02 + 0.02],
                        [0.98 + 0.02, 0.98 - 0.02],
                        transform=ax.transAxes,
                        color="k",
                        lw=1,
                    )
                    ax.plot(
                        [0.98 - 0.02, 0.98 + 0.02],
                        [0.98 + 0.02, 0.98 - 0.02],
                        transform=ax.transAxes,
                        color="k",
                        lw=1,
                    )
                    if any(v > cap for v in values):
                        ax.legend(
                            [bars[0]], ["capped"], fontsize=8, frameon=False, loc="upper right"
                        )
                    ax.set_xticks(range(len(labels)))
                    ax.set_xticklabels(labels)
                else:
                    ax.bar(labels, values, color=colors[: len(labels)])
                    for idx, val in enumerate(values):
                        ax.text(idx, val + 1.0, f"{val:.1f}", ha="center", va="bottom")

            plt.ylabel("Average add latency (ms per passage)")
            plt.title(f"Initial passages {args.max_initial}, updates {args.max_updates}")
            plt.tight_layout()
            plt.savefig(args.plot_path)
            print(f"Saved latency bar plot to {args.plot_path}")
            # ZMQ time split (Stage A vs B/C)
            try:
                plt.figure(figsize=(6, 4))
                a_vals = [sum(results_stageA[n]) / max(1, len(results_stageA[n])) for n in labels]
                bc_vals = [
                    sum(results_stageBC[n]) / max(1, len(results_stageBC[n])) for n in labels
                ]
                ind = range(len(labels))
                plt.bar(ind, a_vals, color="#4e79a7", label="Stage A distance (s)")
                plt.bar(
                    ind, bc_vals, bottom=a_vals, color="#e15759", label="Stage B/C embed-by-id (s)"
                )
                plt.xticks(list(ind), labels, rotation=10)
                plt.ylabel("Server ZMQ time (s)")
                plt.title(
                    f"ZMQ time split (initial {args.max_initial}, updates {args.max_updates})"
                )
                plt.legend()
                out2 = args.plot_path.with_name(
                    args.plot_path.stem + "_zmq_split" + args.plot_path.suffix
                )
                plt.tight_layout()
                plt.savefig(out2)
                print(f"Saved ZMQ time split plot to {out2}")
            except Exception as e:
                print("Failed to plot ZMQ split:", e)
        except ImportError:
            print("matplotlib not available; skipping plot generation")

    # leave the last build on disk for inspection


if __name__ == "__main__":
    main()