#!/usr/bin/env python3
"""Measure generation latency of a HuggingFace/OpenAI-compatible model over prompt files."""
import argparse
import contextlib
import io
import json
import logging
import time
from pathlib import Path

from leann.chat import get_llm

PROMPT_PREFIX = "PROMPT #"

# Quiet the chat backend's logger so its output does not pollute timing runs.
logging.getLogger("leann.chat").setLevel(logging.ERROR)
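

# load_prompts() expects dump files of the following form (illustrative sketch,
# not a verbatim excerpt of a real dump file); anything before the first
# "PROMPT #" marker is ignored:
#
#   PROMPT #1
#   <prompt text, possibly spanning several lines>
#   PROMPT #2
#   ...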
def load_prompts(path: Path) -> list[str]:
    """Split a prompt-dump file into prompts delimited by "PROMPT #" marker lines."""
prompts: list[str] = []
buffer: list[str] = []
collecting = False
with path.open("r", encoding="utf-8") as handle:
for line in handle:
if line.startswith(PROMPT_PREFIX):
if buffer:
prompts.append("".join(buffer).strip())
buffer.clear()
collecting = True
continue
if collecting:
buffer.append(line)
if buffer:
prompts.append("".join(buffer).strip())
return prompts


def measure_generation_times(
    prompts: list[str],
    llm,
    generation_kwargs: dict[str, object],
    allow_truncation: bool,
    enable_qwen_thinking: bool,
    verbose: bool,
    per_call_timeout: int | None,
):
    """Run llm.ask over each prompt and return per-prompt wall-clock latencies (seconds)."""
    timings: list[float] = []
    tokenizer = getattr(llm, "tokenizer", None)
    max_positions = None
    if hasattr(llm, "model") and hasattr(llm.model, "config"):
        max_positions = getattr(llm.model.config, "max_position_embeddings", None)
    # Work out how many new tokens were requested so the prompt can be budgeted
    # against the model's context window.
    requested_new_tokens = None
    if max_positions is not None:
        if "max_new_tokens" in generation_kwargs:
            requested_new_tokens = generation_kwargs["max_new_tokens"]
        elif "max_tokens" in generation_kwargs:
            requested_new_tokens = generation_kwargs["max_tokens"]
    context_max_length = max_positions
    if max_positions is not None and requested_new_tokens is not None:
        if requested_new_tokens >= max_positions:
            # Clamp: the generation budget alone cannot exceed the window.
            requested_new_tokens = max_positions - 1
        context_max_length = max(max_positions - requested_new_tokens, 1)
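    # Worked example with illustrative numbers: max_position_embeddings=2048
    # and max_new_tokens=512 leave 2048 - 512 = 1536 tokens of budget for the
    # prompt, so prompt plus generation fits in the model's window.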
    # Buffer used to swallow model output when not running verbose/streaming.
    suppress_buffer = io.StringIO()

    # Log base config once before the timing loop.
if verbose:
device = getattr(llm, "device", None)
try:
dtype = getattr(getattr(llm, "model", None), "dtype", None)
except Exception:
dtype = None
print(
f"[dbg] device={device} dtype={dtype} max_positions={max_positions} requested_new_tokens={requested_new_tokens} context_max_length={context_max_length}"
)
total = len(prompts)
for idx, prompt in enumerate(prompts, start=1):
prompt_for_llm = prompt
if (
enable_qwen_thinking
and "/think" not in prompt_for_llm
and "/no_think" not in prompt_for_llm
):
prompt_for_llm = f"{prompt_for_llm}\n/think"
        # Optionally truncate the prompt to the remaining context budget; the
        # round-trip through decode may drop special tokens from the original text.
        if allow_truncation and tokenizer is not None and max_positions is not None:
tokenized = tokenizer(
prompt_for_llm,
truncation=True,
max_length=context_max_length,
return_tensors="pt",
)
prompt_for_llm = tokenizer.decode(tokenized["input_ids"][0], skip_special_tokens=True)
per_call_kwargs = dict(generation_kwargs)
if requested_new_tokens is not None:
per_call_kwargs["max_new_tokens"] = requested_new_tokens
        # --verbose (or --stream, which main folds into verbose) makes the HF
        # backend stream tokens to stdout.
        if verbose:
            per_call_kwargs["stream"] = True
# Extra debug info about token lengths
if verbose and tokenizer is not None:
try:
toks = tokenizer(prompt_for_llm, return_tensors=None, truncation=False)
in_len = (
len(toks["input_ids"])
if isinstance(toks["input_ids"], list)
else len(toks["input_ids"][0])
)
except Exception:
in_len = None
print(f"[dbg] prompt {idx}/{total} tokens={in_len}")
print(
f"[dbg] gen_cfg={{max_new_tokens:{per_call_kwargs.get('max_new_tokens')}, temp:{per_call_kwargs.get('temperature')}, top_p:{per_call_kwargs.get('top_p')}}}"
)
start = time.perf_counter()
        # Optional per-call timeout via SIGALRM; effective only on Unix in the
        # main thread (SIGALRM does not exist on Windows).
timeout_handler_installed = False
if per_call_timeout is not None:
import signal
def _timeout_handler(signum, frame):
raise TimeoutError("generation timed out")
old_handler = signal.signal(signal.SIGALRM, _timeout_handler)
signal.alarm(int(per_call_timeout))
timeout_handler_installed = True
try:
if verbose:
print("[dbg] generation_start")
llm.ask(prompt_for_llm, **per_call_kwargs)
print("[dbg] generation_done")
else:
with contextlib.redirect_stdout(suppress_buffer):
llm.ask(prompt_for_llm, **per_call_kwargs)
except TimeoutError:
if verbose:
print("[dbg] generation_timeout")
finally:
if timeout_handler_installed:
import signal
signal.alarm(0)
signal.signal(signal.SIGALRM, old_handler)
end = time.perf_counter()
timings.append(end - start)
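        # Reset the capture buffer so suppressed output does not accumulate
        # across prompts.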
suppress_buffer.seek(0)
suppress_buffer.truncate(0)
return timings
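

# Note: a prompt that hits --per-call-timeout still contributes its elapsed
# time (approximately the timeout) to the timings returned above.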


def parse_args():
parser = argparse.ArgumentParser(description="Measure generation timing for prompt files")
parser.add_argument(
"--max-prompts",
type=int,
default=None,
help="Optional limit on number of prompts to evaluate per file",
)
parser.add_argument(
"--allow-truncation",
action="store_true",
help="Allow truncating prompt context to respect model's max context",
)
parser.add_argument(
"--model",
type=str,
default="sshleifer/tiny-gpt2",
help="LLM model identifier (default: sshleifer/tiny-gpt2)",
)
parser.add_argument(
"--llm-type",
type=str,
default="hf",
choices=["hf", "openai", "ollama", "gemini", "simulated"],
help="LLM backend type (default: hf)",
)
parser.add_argument(
"--device",
type=str,
default="cpu",
choices=["cpu", "auto"],
help="Device override for HF models (default: cpu)",
)
parser.add_argument(
"--max-new-tokens",
type=int,
default=16,
help="Max new tokens per generation (default: 16)",
)
parser.add_argument(
"--temperature",
type=float,
default=0.2,
help="Sampling temperature (default: 0.2)",
)
parser.add_argument(
"--top-p",
type=float,
default=0.8,
help="Nucleus sampling top-p (default: 0.8)",
)
parser.add_argument(
"--qwen-thinking",
action="store_true",
help="Append /think to prompts for Qwen models",
)
parser.add_argument(
"--no-max-new-tokens",
action="store_true",
help="Do not set max_new_tokens in generation kwargs",
)
parser.add_argument(
"--per-call-timeout",
type=int,
default=None,
help="Optional timeout (seconds) per generation call; if hit, moves to next prompt",
)
parser.add_argument(
"--stream",
action="store_true",
help="Stream generated text to stdout during generation",
)
parser.add_argument(
"--datasets",
type=str,
default=None,
        help=(
            "Comma-separated subset of dataset keys to run; must match keys of "
            "dataset_map in main() (currently gpqa_bm25,gpqa_diskann). Default: all"
        ),
)
parser.add_argument(
"--verbose",
action="store_true",
help="Enable debug logging and show generation progress",
)
return parser.parse_args()
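

# Example (hypothetical invocation using the flags above): restrict the run to
# a single dump and cap the prompt count:
#   python measure_generation_times.py --datasets gpqa_bm25 --max-prompts 5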


def main():
args = parse_args()
    # NOTE: keys double as result labels and --datasets options; the gpqa_*
    # keys are currently repurposed to point at the hotpot prompt dumps, with
    # alternative dumps kept commented out for quick switching.
    dataset_map = {
        # "gpqa_bm25": Path("prompt_dump_gpqa_bm25.txt"),
        # "gpqa_diskann": Path("prompt_dump_gpqa_diskann.txt"),
        # "gpqa_hnsw": Path("prompt_dump_gpqa_hnsw.txt"),
        # "nq_bm25": Path("prompt_dump_nq_bm25.txt"),
        # "nq_diskann": Path("prompt_dump_nq_diskann.txt"),
        # "nq_hnsw": Path("prompt_dump_nq_hnsw.txt"),
        "gpqa_bm25": Path("prompt_dump_hotpot_bm25.txt"),
        "gpqa_diskann": Path("prompt_dump_hotpot_diskann.txt"),
        # "gpqa_hnsw": Path("prompt_dump_hotpot_hnsw.txt"),
        # "gpqa_bm25": Path("prompt_dump_trivia_bm25.txt"),
        # "gpqa_diskann": Path("prompt_dump_trivia_diskann.txt"),
    }
if args.datasets:
selected = [k.strip() for k in args.datasets.split(",") if k.strip()]
invalid = [k for k in selected if k not in dataset_map]
if invalid:
raise SystemExit(f"Invalid dataset names: {invalid}. Valid: {list(dataset_map)}")
dataset_files = [dataset_map[k] for k in selected]
else:
dataset_files = list(dataset_map.values())
generation_kwargs = {
"temperature": args.temperature,
"top_p": args.top_p,
}
if not args.no_max_new_tokens:
generation_kwargs["max_new_tokens"] = args.max_new_tokens
results: dict[str, dict[str, float | int]] = {}
llm_config = {"type": args.llm_type, "model": args.model}
try:
llm = get_llm(llm_config)
except Exception as exc:
print(f"Failed to initialize LLM: {exc}")
raise SystemExit(1) from exc
    # Only the "cpu" override is applied after init; "auto" leaves device
    # placement to the backend.
    if args.llm_type == "hf" and hasattr(llm, "model") and args.device == "cpu":
        llm.model = llm.model.to("cpu")
        if hasattr(llm, "device"):
            llm.device = "cpu"
for dataset_path in dataset_files:
print(f"Processing {dataset_path.name}...")
prompts = load_prompts(dataset_path)
if args.max_prompts is not None:
prompts = prompts[: args.max_prompts]
if args.verbose:
print(f"[dbg] loaded_prompts={len(prompts)} (showing up to --max-prompts)")
timings = measure_generation_times(
prompts,
llm,
generation_kwargs,
args.allow_truncation,
args.qwen_thinking,
args.verbose or args.stream,
args.per_call_timeout,
)
total_time = sum(timings)
count = len(timings)
average_time = total_time / count if count else 0.0
        results[dataset_path.name] = {
"total_prompts": count,
"total_time_seconds": total_time,
"average_time_seconds": average_time,
}
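    # The JSON printed below maps each dump filename to prompt count, total
    # time, and mean per-prompt latency, all times in seconds.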
print(json.dumps(results, indent=2))


if __name__ == "__main__":
    main()