Files
LEANN/benchmarks/update/bench_hnsw_rng_recompute.py
Andy Lee d4f5f2896f Faster Update (#148)
* stash

* stash

* add std err in add and trace progress

* fix.

* docs

* style: format

* docs

* better figs

* better figs

* update results

* fotmat

---------

Co-authored-by: yichuan-w <yichuan-w@users.noreply.github.com>
2025-11-05 13:37:47 -08:00

805 lines
28 KiB
Python

"""Benchmark incremental HNSW add() under different RNG pruning modes with real
embedding recomputation.
This script clones the structure of ``examples/dynamic_update_no_recompute.py``
so that we build a non-compact ``is_recompute=True`` index, spin up the
standard HNSW embedding server, and measure how long incremental ``add`` takes
when RNG pruning is fully enabled vs. partially/fully disabled.
Example usage (run from the repo root; downloads the model on first run)::
uv run -m benchmarks.update.bench_hnsw_rng_recompute \
--index-path .leann/bench/leann-demo.leann \
--runs 1
You can tweak the input documents with ``--initial-files`` / ``--update-files``
if you want a larger or different workload, and change the embedding model via
``--model-name``.
"""
import argparse
import json
import logging
import os
import pickle
import re
import sys
import time
from pathlib import Path
from typing import Any
import msgpack
import numpy as np
import zmq
from leann.api import LeannBuilder
# Opt-in CPU-only mode: when LEANN_FORCE_CPU is "1", "true" or "yes"
# (case-insensitive), default CUDA_VISIBLE_DEVICES to the empty string.
# setdefault() leaves an already-exported CUDA_VISIBLE_DEVICES untouched.
if os.environ.get("LEANN_FORCE_CPU", "").lower() in ("1", "true", "yes"):
    os.environ.setdefault("CUDA_VISIBLE_DEVICES", "")
from leann.embedding_compute import compute_embeddings
from leann.embedding_server_manager import EmbeddingServerManager
from leann.registry import register_project_directory
from leann_backend_hnsw import faiss # type: ignore
from leann_backend_hnsw.convert_to_csr import prune_hnsw_embeddings_inplace
# Module-level logger for this benchmark.
logger = logging.getLogger(__name__)
# Only configure the root logger (INFO level) when it has no handlers yet.
if not logging.getLogger().handlers:
    logging.basicConfig(level=logging.INFO)
def _find_repo_root() -> Path:
    """Return the nearest ancestor directory of this file that holds ``pyproject.toml``.

    When no ancestor contains a ``pyproject.toml``, the third ancestor of the
    resolved file path is returned instead.
    """
    script_path = Path(__file__).resolve()
    marker_dirs = (
        ancestor for ancestor in script_path.parents if (ancestor / "pyproject.toml").exists()
    )
    nearest = next(marker_dirs, None)
    if nearest is not None:
        return nearest
    # Fallback: assume repo is two levels up (../..)
    return script_path.parents[2]
# Resolve the repository root once at import time and make sure it is on
# sys.path before the repo-level ``apps`` import below (hence the E402 waiver).
REPO_ROOT = _find_repo_root()
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))
from apps.chunking import create_text_chunks  # noqa: E402

# Default documents used to build the initial index (--initial-files).
DEFAULT_INITIAL_FILES = [
    REPO_ROOT / "data" / "2501.14312v1 (1).pdf",
    REPO_ROOT / "data" / "huawei_pangu.md",
]
# Default documents appended during the benchmark (--update-files).
DEFAULT_UPDATE_FILES = [REPO_ROOT / "data" / "2506.08276v1.pdf"]
# Embedding-server log location used when LEANN_HNSW_LOG_PATH is not set;
# relative, so it resolves against the current working directory.
DEFAULT_HNSW_LOG = Path(".leann/bench/hnsw_server.log")
def load_chunks_from_files(paths: list[Path], limit: int | None = None) -> list[str]:
    """Load documents from ``paths`` and split them into non-empty text chunks.

    Each path is expanded and resolved; a directory is read non-recursively,
    any other path is read as a single input file.  The loaded documents are
    chunked with ``create_text_chunks`` (chunk size 512, overlap 128, AST
    chunking off), chunks that are not strings or are blank are dropped, and
    at most ``limit`` chunks are returned when ``limit`` is not ``None``.

    Raises:
        FileNotFoundError: If one of the resolved paths does not exist.
    """
    from llama_index.core import SimpleDirectoryReader

    loaded = []
    for raw_path in paths:
        resolved = raw_path.expanduser().resolve()
        if not resolved.exists():
            raise FileNotFoundError(f"Input path not found: {resolved}")
        reader = (
            SimpleDirectoryReader(str(resolved), recursive=False)
            if resolved.is_dir()
            else SimpleDirectoryReader(input_files=[str(resolved)])
        )
        loaded.extend(reader.load_data(show_progress=True))

    # Nothing was loaded: skip chunking entirely.
    if not loaded:
        return []

    pieces = create_text_chunks(
        loaded,
        chunk_size=512,
        chunk_overlap=128,
        use_ast_chunking=False,
    )
    usable = [piece for piece in pieces if isinstance(piece, str) and piece.strip()]
    return usable if limit is None else usable[:limit]
def ensure_index_dir(index_path: Path) -> None:
    """Create the directory that will contain ``index_path``, including missing ancestors."""
    target_dir = index_path.parent
    target_dir.mkdir(parents=True, exist_ok=True)
def cleanup_index_files(index_path: Path) -> None:
    """Delete every regular file next to ``index_path`` whose name starts with its stem.

    Directories are never removed, and a missing parent directory is a no-op.
    The stem is matched literally: glob metacharacters in the index name
    (``[``, ``?``, ``*``) are escaped so that only files that really share
    the prefix are deleted.

    Args:
        index_path: Base path of the index; only its parent and stem are used.
    """
    from glob import escape  # function-scope import keeps the block self-contained

    parent = index_path.parent
    if not parent.exists():
        return
    stem = index_path.stem
    if not stem:
        # An empty stem would turn the pattern into "*" and wipe the directory.
        return
    for candidate in parent.glob(f"{escape(stem)}*"):
        if candidate.is_file():
            candidate.unlink()
def build_initial_index(
    index_path: Path,
    paragraphs: list[str],
    model_name: str,
    embedding_mode: str,
    distance_metric: str,
    ef_construction: int,
) -> None:
    """Build a non-compact, recompute-enabled HNSW index from ``paragraphs``.

    Every paragraph is added with its position in the list (as a string) as
    the ``id`` metadata entry, then the index is written to ``index_path``.
    """
    hnsw_options = {
        "distance_metric": distance_metric,
        "is_compact": False,
        "is_recompute": True,
        "efConstruction": ef_construction,
    }
    index_builder = LeannBuilder(
        backend_name="hnsw",
        embedding_model=model_name,
        embedding_mode=embedding_mode,
        is_compact=False,
        is_recompute=True,
        distance_metric=distance_metric,
        backend_kwargs=hnsw_options,
    )
    for position, body in enumerate(paragraphs):
        index_builder.add_text(body, metadata={"id": str(position)})
    index_builder.build_index(str(index_path))
def prepare_new_chunks(paragraphs: list[str]) -> list[dict[str, Any]]:
    """Wrap each paragraph in a chunk dict with ``text`` and a fresh empty ``metadata``."""
    wrapped: list[dict[str, Any]] = []
    for body in paragraphs:
        wrapped.append({"text": body, "metadata": {}})
    return wrapped
def benchmark_update_with_mode(
    index_path: Path,
    new_chunks: list[dict[str, Any]],
    model_name: str,
    embedding_mode: str,
    distance_metric: str,
    disable_forward_rng: bool,
    disable_reverse_rng: bool,
    server_port: int,
    add_timeout: int,
    ef_construction: int,
) -> tuple[float, float]:
    """Append ``new_chunks`` to an existing index and time the incremental add.

    The function embeds the chunk texts, loads the on-disk HNSW index, applies
    the requested RNG-pruning toggles, appends the passages to the passages
    file / offset map, starts the embedding server and then adds the vectors
    one at a time while timing the loop.

    Side effects: every accepted chunk dict in ``new_chunks`` is mutated in
    place (``id`` and ``metadata["id"]`` are assigned), the passages file,
    offset map, index file and meta JSON next to ``index_path`` are rewritten.

    Args:
        index_path: Base path of the index built by ``build_initial_index``.
        new_chunks: Chunk dicts with a ``text`` entry; blank or non-string
            texts are skipped.
        model_name: Embedding model passed to ``compute_embeddings`` and the
            embedding server.
        embedding_mode: Embedding mode passed alongside ``model_name``.
        distance_metric: ``"cosine"`` additionally L2-normalises the new
            embeddings; the value is also forwarded to the embedding server.
        disable_forward_rng: Value passed to ``set_disable_rng_during_add``.
        disable_reverse_rng: Value passed to ``set_disable_reverse_prune``.
        server_port: Requested embedding-server port (the port actually
            returned by ``start_server`` is the one used).
        add_timeout: Seconds allowed for the add loop via ``SIGALRM``;
            ``0`` or negative disables the timeout.
        ef_construction: ``efConstruction`` applied to the loaded index when
            not ``None``.

    Returns:
        ``(total_elapsed, add_elapsed)`` in seconds.  ``total_elapsed`` runs
        from just after server warm-up until the final index write;
        ``add_elapsed`` covers only the ``index.add`` loop.

    Raises:
        ValueError: If a chunk carries an ID already present in the offset
            map, or no chunk has usable text.
        RuntimeError: If the embedding server fails to start.
        TimeoutError: If the add loop exceeds ``add_timeout``.
    """
    import signal

    meta_path = index_path.parent / f"{index_path.name}.meta.json"
    passages_file = index_path.parent / f"{index_path.name}.passages.jsonl"
    offset_file = index_path.parent / f"{index_path.name}.passages.idx"
    index_file = index_path.parent / f"{index_path.stem}.index"

    with open(meta_path, encoding="utf-8") as f:
        meta = json.load(f)
    # NOTE: pickle.load executes arbitrary code from the file; only use this on
    # offset maps produced by this tooling, never on untrusted index folders.
    with open(offset_file, "rb") as f:
        offset_map: dict[str, int] = pickle.load(f)
    existing_ids = set(offset_map.keys())

    # Keep only chunks with non-blank string text and reject duplicate IDs.
    valid_chunks: list[dict[str, Any]] = []
    for chunk in new_chunks:
        text = chunk.get("text", "")
        if not isinstance(text, str) or not text.strip():
            continue
        metadata = chunk.setdefault("metadata", {})
        passage_id = chunk.get("id") or metadata.get("id")
        if passage_id and passage_id in existing_ids:
            raise ValueError(f"Passage ID '{passage_id}' already exists in the index.")
        valid_chunks.append(chunk)
    if not valid_chunks:
        raise ValueError("No valid chunks to append.")

    texts_to_embed = [chunk["text"] for chunk in valid_chunks]
    embeddings = compute_embeddings(
        texts_to_embed,
        model_name,
        mode=embedding_mode,
        is_build=False,
        batch_size=16,
    )
    embeddings = np.ascontiguousarray(embeddings, dtype=np.float32)
    if distance_metric == "cosine":
        norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
        norms[norms == 0] = 1  # avoid dividing zero vectors by zero
        embeddings = embeddings / norms

    index = faiss.read_index(str(index_file))
    index.is_recompute = True
    # Attach an empty flat storage matching the metric when the loaded index
    # has none, and make its ntotal agree with the index where possible.
    if getattr(index, "storage", None) is None:
        if index.metric_type == faiss.METRIC_INNER_PRODUCT:
            storage_index = faiss.IndexFlatIP(index.d)
        else:
            storage_index = faiss.IndexFlatL2(index.d)
        index.storage = storage_index
        index.own_fields = True
        try:
            storage_index.ntotal = index.ntotal
        except AttributeError:
            pass

    # Best effort: builds without the RNG toggles simply skip them; the log
    # line below reports what was actually applied.
    try:
        index.hnsw.set_disable_rng_during_add(disable_forward_rng)
        index.hnsw.set_disable_reverse_prune(disable_reverse_rng)
        if ef_construction is not None:
            index.hnsw.efConstruction = ef_construction
    except AttributeError:
        pass
    applied_forward = getattr(index.hnsw, "disable_rng_during_add", None)
    applied_reverse = getattr(index.hnsw, "disable_reverse_prune", None)
    logger.info(
        "HNSW RNG config -> requested forward=%s, reverse=%s | applied forward=%s, reverse=%s",
        disable_forward_rng,
        disable_reverse_rng,
        applied_forward,
        applied_reverse,
    )

    # New passages get consecutive IDs continuing from the current index size.
    base_id = index.ntotal
    for offset, chunk in enumerate(valid_chunks):
        new_id = str(base_id + offset)
        chunk.setdefault("metadata", {})["id"] = new_id
        chunk["id"] = new_id

    # Snapshot state so a failed update can be undone.
    rollback_size = passages_file.stat().st_size if passages_file.exists() else 0
    offset_map_backup = offset_map.copy()

    def _timeout_handler(signum, frame):
        raise TimeoutError("incremental add timed out")

    try:
        with open(passages_file, "a", encoding="utf-8") as f:
            for chunk in valid_chunks:
                offset = f.tell()
                json.dump(
                    {
                        "id": chunk["id"],
                        "text": chunk["text"],
                        "metadata": chunk.get("metadata", {}),
                    },
                    f,
                    ensure_ascii=False,
                )
                f.write("\n")
                offset_map[chunk["id"]] = offset
        with open(offset_file, "wb") as f:
            pickle.dump(offset_map, f)

        server_manager = EmbeddingServerManager(
            backend_module_name="leann_backend_hnsw.hnsw_embedding_server"
        )
        server_started, actual_port = server_manager.start_server(
            port=server_port,
            model_name=model_name,
            embedding_mode=embedding_mode,
            passages_file=str(meta_path),
            distance_metric=distance_metric,
        )
        if not server_started:
            raise RuntimeError("Failed to start embedding server.")

        previous_handler = None
        alarm_armed = False
        # Everything after a successful start runs under try/finally so the
        # server is stopped even when port wiring or warm-up fails.
        try:
            if hasattr(index.hnsw, "set_zmq_port"):
                index.hnsw.set_zmq_port(actual_port)
            elif hasattr(index, "set_zmq_port"):
                index.set_zmq_port(actual_port)
            _warmup_embedding_server(actual_port)

            total_start = time.time()
            add_elapsed = 0.0
            if add_timeout > 0:
                previous_handler = signal.signal(signal.SIGALRM, _timeout_handler)
                alarm_armed = True
                signal.alarm(add_timeout)
            add_start = time.time()
            for i in range(embeddings.shape[0]):
                index.add(1, faiss.swig_ptr(embeddings[i : i + 1]))
            add_elapsed = time.time() - add_start
            if add_timeout > 0:
                signal.alarm(0)
            faiss.write_index(index, str(index_file))
        finally:
            try:
                if alarm_armed:
                    # Cancel a still-pending alarm on every exit path so it
                    # cannot fire later in unrelated code, then put back the
                    # handler that was installed before this call.
                    signal.alarm(0)
                    if previous_handler is not None:
                        signal.signal(signal.SIGALRM, previous_handler)
            finally:
                server_manager.stop_server()
    except TimeoutError:
        # Timeouts propagate without rollback; the caller decides what to do.
        raise
    except Exception:
        # Any other failure: restore passages file and offset map.
        if passages_file.exists():
            with open(passages_file, "rb+") as f:
                f.truncate(rollback_size)
        with open(offset_file, "wb") as f:
            pickle.dump(offset_map_backup, f)
        raise

    prune_hnsw_embeddings_inplace(str(index_file))
    meta["total_passages"] = len(offset_map)
    with open(meta_path, "w", encoding="utf-8") as f:
        json.dump(meta, f, indent=2)

    # Reset toggles so the index on disk returns to baseline behaviour.
    try:
        index.hnsw.set_disable_rng_during_add(False)
        index.hnsw.set_disable_reverse_prune(False)
    except AttributeError:
        pass
    # NOTE(review): this rewrites the index file from the in-memory object
    # after prune_hnsw_embeddings_inplace touched the file - confirm that the
    # pruning is not undone by this write.
    faiss.write_index(index, str(index_file))
    total_elapsed = time.time() - total_start
    return total_elapsed, add_elapsed
def _total_zmq_nodes(log_path: Path) -> int:
    """Sum the counts from every ``ZMQ received <N> node IDs`` entry in ``log_path``.

    Returns 0 when the log file does not exist.
    """
    if not log_path.exists():
        return 0
    contents = log_path.read_text(encoding="utf-8")
    total = 0
    for count in re.findall(r"ZMQ received (\d+) node IDs", contents):
        total += int(count)
    return total
def _warmup_embedding_server(port: int) -> None:
    """Send a dummy REQ so the embedding server loads its model.

    A REQ socket with zero linger and 5000 ms send/receive timeouts connects
    to ``tcp://127.0.0.1:<port>``, sends a msgpack-encoded ``["__WARMUP__"]``
    payload and waits for one reply.  A receive timeout is ignored; the
    socket and context are always released.

    Args:
        port: TCP port of the local embedding server.
    """
    ctx = zmq.Context()
    try:
        sock = ctx.socket(zmq.REQ)
        # The socket gets its own try/finally: if socket creation itself
        # fails there is no socket to close, but the context is still
        # terminated by the outer finally.
        try:
            sock.setsockopt(zmq.LINGER, 0)
            sock.setsockopt(zmq.RCVTIMEO, 5000)
            sock.setsockopt(zmq.SNDTIMEO, 5000)
            sock.connect(f"tcp://127.0.0.1:{port}")
            payload = msgpack.packb(["__WARMUP__"], use_bin_type=True)
            sock.send(payload)
            try:
                sock.recv()
            except zmq.error.Again:
                # No reply within the receive timeout: warm-up is best effort.
                pass
        finally:
            sock.close()
    finally:
        ctx.term()
def main() -> None:
    """Run the RNG-pruning update benchmark from the command line.

    For each of ``--runs`` repetitions and each scenario (a name, the
    forward/reverse RNG-disable flags and whether ``LEANN_ZMQ_EMBED_CACHE`` is
    set to ``"1"``), the index files are deleted and rebuilt from the initial
    passages, the update passages are added via ``benchmark_update_with_mode``
    and the part of the embedding-server log written after the initial build
    is scanned for ZMQ node counts and stage timings.  Each successful
    scenario appends one row to ``--csv-path``; a scenario that raises
    ``TimeoutError`` is reported and skipped.  A per-scenario summary is
    printed at the end and, when matplotlib is importable, a latency bar chart
    and a ZMQ time-split chart are saved (the latter next to ``--plot-path``
    with a ``_zmq_split`` suffix).

    Raises:
        ValueError: If the update files yield no passages.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--index-path",
        type=Path,
        default=Path(".leann/bench/leann-demo.leann"),
        help="Output index base path (without extension).",
    )
    parser.add_argument(
        "--initial-files",
        nargs="*",
        type=Path,
        default=DEFAULT_INITIAL_FILES,
        help="Files used to build the initial index.",
    )
    parser.add_argument(
        "--update-files",
        nargs="*",
        type=Path,
        default=DEFAULT_UPDATE_FILES,
        help="Files appended during the benchmark.",
    )
    parser.add_argument(
        "--runs", type=int, default=1, help="How many times to repeat each scenario."
    )
    parser.add_argument(
        "--model-name",
        default="sentence-transformers/all-MiniLM-L6-v2",
        help="Embedding model used for build/update.",
    )
    parser.add_argument(
        "--embedding-mode",
        default="sentence-transformers",
        help="Embedding mode passed to LeannBuilder/embedding server.",
    )
    parser.add_argument(
        "--distance-metric",
        default="mips",
        choices=["mips", "l2", "cosine"],
        help="Distance metric for HNSW backend.",
    )
    parser.add_argument(
        "--ef-construction",
        type=int,
        default=200,
        help="efConstruction setting for initial build.",
    )
    parser.add_argument(
        "--server-port",
        type=int,
        default=5557,
        help="Port for the real embedding server.",
    )
    parser.add_argument(
        "--max-initial",
        type=int,
        default=300,
        help="Optional cap on initial passages (after chunking).",
    )
    parser.add_argument(
        "--max-updates",
        type=int,
        default=1,
        help="Optional cap on update passages (after chunking).",
    )
    parser.add_argument(
        "--add-timeout",
        type=int,
        default=900,
        help="Timeout in seconds for the incremental add loop (0 = no timeout).",
    )
    parser.add_argument(
        "--plot-path",
        type=Path,
        default=Path("bench_latency.png"),
        help="Where to save the latency bar plot.",
    )
    parser.add_argument(
        "--cap-y",
        type=float,
        default=None,
        help="Cap Y-axis (ms). Bars above are hatched and annotated.",
    )
    parser.add_argument(
        "--broken-y",
        action="store_true",
        help="Use broken Y-axis (two stacked axes with gap). Overrides --cap-y unless both provided.",
    )
    parser.add_argument(
        "--lower-cap-y",
        type=float,
        default=None,
        help="Lower axes upper bound for broken Y (ms). Default=1.1x second-highest.",
    )
    parser.add_argument(
        "--upper-start-y",
        type=float,
        default=None,
        help="Upper axes lower bound for broken Y (ms). Default=1.2x second-highest.",
    )
    parser.add_argument(
        "--csv-path",
        type=Path,
        default=Path("benchmarks/update/bench_results.csv"),
        help="Where to append per-scenario results as CSV.",
    )
    args = parser.parse_args()

    register_project_directory(REPO_ROOT)
    initial_paragraphs = load_chunks_from_files(args.initial_files, args.max_initial)
    update_paragraphs = load_chunks_from_files(args.update_files, args.max_updates)
    if not update_paragraphs:
        raise ValueError("No update passages found; please provide --update-files with content.")
    update_chunks = prepare_new_chunks(update_paragraphs)
    ensure_index_dir(args.index_path)

    # (name, disable_forward_rng, disable_reverse_rng, cache_enabled)
    scenarios = [
        ("baseline", False, False, True),
        ("no_cache_baseline", False, False, False),
        ("disable_forward_rng", True, False, True),
        ("disable_forward_and_reverse_rng", True, True, True),
    ]

    # Export the absolute log path (and a default log level) through the
    # environment; the log file is parsed after each scenario below.
    log_path = Path(os.environ.get("LEANN_HNSW_LOG_PATH", DEFAULT_HNSW_LOG))
    log_path.parent.mkdir(parents=True, exist_ok=True)
    os.environ["LEANN_HNSW_LOG_PATH"] = str(log_path.resolve())
    os.environ.setdefault("LEANN_LOG_LEVEL", "INFO")

    # Per-scenario result lists, one entry per successful run.
    results_total: dict[str, list[float]] = {name: [] for name, *_ in scenarios}
    results_add: dict[str, list[float]] = {name: [] for name, *_ in scenarios}
    results_zmq: dict[str, list[int]] = {name: [] for name, *_ in scenarios}
    results_stageA: dict[str, list[float]] = {name: [] for name, *_ in scenarios}
    results_stageBC: dict[str, list[float]] = {name: [] for name, *_ in scenarios}
    results_ms_per_passage: dict[str, list[float]] = {name: [] for name, *_ in scenarios}

    # CSV setup
    import csv

    run_id = time.strftime("%Y%m%d-%H%M%S")
    csv_fields = [
        "run_id",
        "scenario",
        "cache_enabled",
        "ef_construction",
        "max_initial",
        "max_updates",
        "total_time_s",
        "add_only_s",
        "latency_ms_per_passage",
        "zmq_nodes",
        "stageA_time_s",
        "stageBC_time_s",
        "model_name",
        "embedding_mode",
        "distance_metric",
    ]
    # Create CSV with header if missing
    if args.csv_path:
        args.csv_path.parent.mkdir(parents=True, exist_ok=True)
        if not args.csv_path.exists() or args.csv_path.stat().st_size == 0:
            with args.csv_path.open("w", newline="", encoding="utf-8") as f:
                writer = csv.DictWriter(f, fieldnames=csv_fields)
                writer.writeheader()

    for run in range(args.runs):
        print(f"\n=== Benchmark run {run + 1}/{args.runs} ===")
        for name, disable_forward, disable_reverse, cache_enabled in scenarios:
            print(f"\nScenario: {name}")
            # Start every scenario from a freshly built index and an empty log.
            cleanup_index_files(args.index_path)
            if log_path.exists():
                try:
                    log_path.unlink()
                except OSError:
                    pass
            os.environ["LEANN_ZMQ_EMBED_CACHE"] = "1" if cache_enabled else "0"
            build_initial_index(
                args.index_path,
                initial_paragraphs,
                args.model_name,
                args.embedding_mode,
                args.distance_metric,
                args.ef_construction,
            )
            # Remember where the log ended after the build so only entries
            # written during the update are counted.
            prev_size = log_path.stat().st_size if log_path.exists() else 0
            try:
                total_elapsed, add_elapsed = benchmark_update_with_mode(
                    args.index_path,
                    update_chunks,
                    args.model_name,
                    args.embedding_mode,
                    args.distance_metric,
                    disable_forward,
                    disable_reverse,
                    args.server_port,
                    args.add_timeout,
                    args.ef_construction,
                )
            except TimeoutError as exc:
                print(f"Scenario {name} timed out: {exc}")
                continue
            curr_size = log_path.stat().st_size if log_path.exists() else 0
            if curr_size < prev_size:
                # The log shrank (e.g. was recreated): read it from the start.
                prev_size = 0
            zmq_count = 0
            if log_path.exists():
                with log_path.open("r", encoding="utf-8") as log_file:
                    log_file.seek(prev_size)
                    new_entries = log_file.read()
                zmq_count = sum(
                    int(match) for match in re.findall(r"ZMQ received (\d+) node IDs", new_entries)
                )
                stageA = sum(
                    float(x)
                    for x in re.findall(r"Distance calculation E2E time: ([0-9.]+)s", new_entries)
                )
                stageBC = sum(
                    float(x) for x in re.findall(r"ZMQ E2E time: ([0-9.]+)s", new_entries)
                )
            else:
                stageA = 0.0
                stageBC = 0.0
            per_chunk = add_elapsed / len(update_chunks)
            print(
                f"Total time: {total_elapsed:.3f} s | add-only: {add_elapsed:.3f} s "
                f"for {len(update_chunks)} passages => {per_chunk * 1e3:.3f} ms/passage"
            )
            print(f"ZMQ node fetch total: {zmq_count}")
            results_total[name].append(total_elapsed)
            results_add[name].append(add_elapsed)
            results_zmq[name].append(zmq_count)
            results_ms_per_passage[name].append(per_chunk * 1e3)
            results_stageA[name].append(stageA)
            results_stageBC[name].append(stageBC)
            # Append row to CSV
            if args.csv_path:
                row = {
                    "run_id": run_id,
                    "scenario": name,
                    "cache_enabled": 1 if cache_enabled else 0,
                    "ef_construction": args.ef_construction,
                    "max_initial": args.max_initial,
                    "max_updates": args.max_updates,
                    "total_time_s": round(total_elapsed, 6),
                    "add_only_s": round(add_elapsed, 6),
                    "latency_ms_per_passage": round(per_chunk * 1e3, 6),
                    "zmq_nodes": int(zmq_count),
                    "stageA_time_s": round(stageA, 6),
                    "stageBC_time_s": round(stageBC, 6),
                    "model_name": args.model_name,
                    "embedding_mode": args.embedding_mode,
                    "distance_metric": args.distance_metric,
                }
                with args.csv_path.open("a", newline="", encoding="utf-8") as f:
                    writer = csv.DictWriter(f, fieldnames=csv_fields)
                    writer.writerow(row)

    print("\n=== Summary ===")
    for name in results_add:
        add_values = results_add[name]
        total_values = results_total[name]
        zmq_values = results_zmq[name]
        latency_values = results_ms_per_passage[name]
        if not add_values:
            print(f"{name}: no successful runs")
            continue
        avg_add = sum(add_values) / len(add_values)
        avg_total = sum(total_values) / len(total_values)
        avg_zmq = sum(zmq_values) / len(zmq_values) if zmq_values else 0.0
        avg_latency = sum(latency_values) / len(latency_values) if latency_values else 0.0
        runs = len(add_values)
        print(
            f"{name}: add-only avg {avg_add:.3f} s | total avg {avg_total:.3f} s "
            f"| ZMQ avg {avg_zmq:.1f} node fetches | latency {avg_latency:.2f} ms/passage over {runs} run(s)"
        )

    if args.plot_path:
        try:
            import matplotlib.pyplot as plt

            labels = [name for name, *_ in scenarios]
            # Mean latency per scenario; scenarios without a successful run plot as 0.
            values = [
                sum(results_ms_per_passage[name]) / len(results_ms_per_passage[name])
                if results_ms_per_passage[name]
                else 0.0
                for name in labels
            ]

            def _auto_cap(vals: list[float]) -> float | None:
                # Cap at 1.1x the second-highest value when the highest is at
                # least 2.5x larger; otherwise no cap.
                s = sorted(vals, reverse=True)
                if len(s) < 2:
                    return None
                if s[1] > 0 and s[0] >= 2.5 * s[1]:
                    return s[1] * 1.1
                return None

            def _fmt_ms(v: float) -> str:
                # Values of 1000 ms and above are shown in thousands ("1.2k").
                return f"{v / 1000:.1f}k" if v >= 1000 else f"{v:.1f}"

            colors = ["#4e79a7", "#f28e2c", "#e15759", "#76b7b2"]
            if args.broken_y:
                # Broken Y-axis: the same bars are drawn on two stacked axes
                # with different Y ranges.
                s = sorted(values, reverse=True)
                second = s[1] if len(s) >= 2 else (s[0] if s else 0.0)
                lower_cap = args.lower_cap_y if args.lower_cap_y is not None else second * 1.1
                upper_start = (
                    args.upper_start_y
                    if args.upper_start_y is not None
                    else max(second * 1.2, lower_cap * 1.02)
                )
                ymax = max(values) * 1.10 if values else 1.0
                fig, (ax_top, ax_bottom) = plt.subplots(
                    2,
                    1,
                    sharex=True,
                    figsize=(7.4, 5.0),
                    gridspec_kw={"height_ratios": [1, 3], "hspace": 0.05},
                )
                x = list(range(len(labels)))
                ax_bottom.bar(x, values, color=colors[: len(labels)], width=0.8)
                ax_top.bar(x, values, color=colors[: len(labels)], width=0.8)
                ax_bottom.set_ylim(0, lower_cap)
                ax_top.set_ylim(upper_start, ymax)
                for i, v in enumerate(values):
                    if v <= lower_cap:
                        ax_bottom.text(
                            i,
                            v + lower_cap * 0.02,
                            _fmt_ms(v),
                            ha="center",
                            va="bottom",
                            fontsize=9,
                        )
                    else:
                        ax_top.text(i, v, _fmt_ms(v), ha="center", va="bottom", fontsize=9)
                ax_top.spines["bottom"].set_visible(False)
                ax_bottom.spines["top"].set_visible(False)
                ax_top.tick_params(labeltop=False)
                ax_bottom.xaxis.tick_bottom()
                # Short diagonal marks at the facing edges of the two axes.
                d = 0.015
                kwargs = {"transform": ax_top.transAxes, "color": "k", "clip_on": False}
                ax_top.plot((-d, +d), (-d, +d), **kwargs)
                ax_top.plot((1 - d, 1 + d), (-d, +d), **kwargs)
                kwargs.update({"transform": ax_bottom.transAxes})
                ax_bottom.plot((-d, +d), (1 - d, 1 + d), **kwargs)
                ax_bottom.plot((1 - d, 1 + d), (1 - d, 1 + d), **kwargs)
                ax_bottom.set_xticks(range(len(labels)))
                ax_bottom.set_xticklabels(labels)
                ax = ax_bottom
            else:
                cap = args.cap_y or _auto_cap(values)
                plt.figure(figsize=(7.2, 4.2))
                ax = plt.gca()
                if cap is not None:
                    # Capped mode: bars above the cap are clipped to it,
                    # hatched and annotated with their real value.
                    show_vals = [min(v, cap) for v in values]
                    bars = []
                    for i, (v, show) in enumerate(zip(values, show_vals)):
                        b = ax.bar(i, show, color=colors[i], width=0.8)
                        bars.append(b[0])
                        if v > cap:
                            bars[-1].set_hatch("//")
                            ax.text(i, cap * 1.02, _fmt_ms(v), ha="center", va="bottom", fontsize=9)
                        else:
                            ax.text(
                                i,
                                show + max(1.0, 0.01 * (cap or show)),
                                _fmt_ms(v),
                                ha="center",
                                va="bottom",
                                fontsize=9,
                            )
                    ax.set_ylim(0, cap * 1.10)
                    ax.plot(
                        [0.02 - 0.02, 0.02 + 0.02],
                        [0.98 + 0.02, 0.98 - 0.02],
                        transform=ax.transAxes,
                        color="k",
                        lw=1,
                    )
                    ax.plot(
                        [0.98 - 0.02, 0.98 + 0.02],
                        [0.98 + 0.02, 0.98 - 0.02],
                        transform=ax.transAxes,
                        color="k",
                        lw=1,
                    )
                    # Use a bar that really is capped as the legend handle so
                    # the "capped" swatch shows the hatch, instead of always
                    # pointing at the first bar.
                    capped_bars = [bar for bar, v in zip(bars, values) if v > cap]
                    if capped_bars:
                        ax.legend(
                            [capped_bars[0]],
                            ["capped"],
                            fontsize=8,
                            frameon=False,
                            loc="upper right",
                        )
                    ax.set_xticks(range(len(labels)))
                    ax.set_xticklabels(labels)
                else:
                    ax.bar(labels, values, color=colors[: len(labels)])
                    for idx, val in enumerate(values):
                        ax.text(idx, val + 1.0, f"{val:.1f}", ha="center", va="bottom")
            plt.ylabel("Average add latency (ms per passage)")
            plt.title(f"Initial passages {args.max_initial}, updates {args.max_updates}")
            plt.tight_layout()
            plt.savefig(args.plot_path)
            print(f"Saved latency bar plot to {args.plot_path}")
            # ZMQ time split (Stage A vs B/C)
            try:
                plt.figure(figsize=(6, 4))
                a_vals = [sum(results_stageA[n]) / max(1, len(results_stageA[n])) for n in labels]
                bc_vals = [
                    sum(results_stageBC[n]) / max(1, len(results_stageBC[n])) for n in labels
                ]
                ind = range(len(labels))
                plt.bar(ind, a_vals, color="#4e79a7", label="Stage A distance (s)")
                plt.bar(
                    ind, bc_vals, bottom=a_vals, color="#e15759", label="Stage B/C embed-by-id (s)"
                )
                plt.xticks(list(ind), labels, rotation=10)
                plt.ylabel("Server ZMQ time (s)")
                plt.title(
                    f"ZMQ time split (initial {args.max_initial}, updates {args.max_updates})"
                )
                plt.legend()
                out2 = args.plot_path.with_name(
                    args.plot_path.stem + "_zmq_split" + args.plot_path.suffix
                )
                plt.tight_layout()
                plt.savefig(out2)
                print(f"Saved ZMQ time split plot to {out2}")
            except Exception as e:
                # The secondary plot is optional; report and carry on.
                print("Failed to plot ZMQ split:", e)
        except ImportError:
            print("matplotlib not available; skipping plot generation")
    # leave the last build on disk for inspection
if __name__ == "__main__":
main()