Compare commits
3 Commits
financeben
...
gen-time
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
bc6c53edf0 | ||
|
|
abc12d5069 | ||
|
|
9ba0ecac15 |
Submodule packages/leann-backend-hnsw/third_party/faiss updated: 1d51f0c074...c69511a99c
@@ -12,8 +12,6 @@ from typing import Any, Optional
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from .settings import resolve_ollama_host, resolve_openai_api_key, resolve_openai_base_url
|
|
||||||
|
|
||||||
# Configure logging
|
# Configure logging
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -312,12 +310,11 @@ def search_hf_models(query: str, limit: int = 10) -> list[str]:
|
|||||||
|
|
||||||
|
|
||||||
def validate_model_and_suggest(
|
def validate_model_and_suggest(
|
||||||
model_name: str, llm_type: str, host: Optional[str] = None
|
model_name: str, llm_type: str, host: str = "http://localhost:11434"
|
||||||
) -> Optional[str]:
|
) -> Optional[str]:
|
||||||
"""Validate model name and provide suggestions if invalid"""
|
"""Validate model name and provide suggestions if invalid"""
|
||||||
if llm_type == "ollama":
|
if llm_type == "ollama":
|
||||||
resolved_host = resolve_ollama_host(host)
|
available_models = check_ollama_models(host)
|
||||||
available_models = check_ollama_models(resolved_host)
|
|
||||||
if available_models and model_name not in available_models:
|
if available_models and model_name not in available_models:
|
||||||
error_msg = f"Model '{model_name}' not found in your local Ollama installation."
|
error_msg = f"Model '{model_name}' not found in your local Ollama installation."
|
||||||
|
|
||||||
@@ -460,19 +457,19 @@ class LLMInterface(ABC):
|
|||||||
class OllamaChat(LLMInterface):
|
class OllamaChat(LLMInterface):
|
||||||
"""LLM interface for Ollama models."""
|
"""LLM interface for Ollama models."""
|
||||||
|
|
||||||
def __init__(self, model: str = "llama3:8b", host: Optional[str] = None):
|
def __init__(self, model: str = "llama3:8b", host: str = "http://localhost:11434"):
|
||||||
self.model = model
|
self.model = model
|
||||||
self.host = resolve_ollama_host(host)
|
self.host = host
|
||||||
logger.info(f"Initializing OllamaChat with model='{model}' and host='{self.host}'")
|
logger.info(f"Initializing OllamaChat with model='{model}' and host='{host}'")
|
||||||
try:
|
try:
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
# Check if the Ollama server is responsive
|
# Check if the Ollama server is responsive
|
||||||
if self.host:
|
if host:
|
||||||
requests.get(self.host)
|
requests.get(host)
|
||||||
|
|
||||||
# Pre-check model availability with helpful suggestions
|
# Pre-check model availability with helpful suggestions
|
||||||
model_error = validate_model_and_suggest(model, "ollama", self.host)
|
model_error = validate_model_and_suggest(model, "ollama", host)
|
||||||
if model_error:
|
if model_error:
|
||||||
raise ValueError(model_error)
|
raise ValueError(model_error)
|
||||||
|
|
||||||
@@ -481,11 +478,9 @@ class OllamaChat(LLMInterface):
|
|||||||
"The 'requests' library is required for Ollama. Please install it with 'pip install requests'."
|
"The 'requests' library is required for Ollama. Please install it with 'pip install requests'."
|
||||||
)
|
)
|
||||||
except requests.exceptions.ConnectionError:
|
except requests.exceptions.ConnectionError:
|
||||||
logger.error(
|
logger.error(f"Could not connect to Ollama at {host}. Please ensure Ollama is running.")
|
||||||
f"Could not connect to Ollama at {self.host}. Please ensure Ollama is running."
|
|
||||||
)
|
|
||||||
raise ConnectionError(
|
raise ConnectionError(
|
||||||
f"Could not connect to Ollama at {self.host}. Please ensure Ollama is running."
|
f"Could not connect to Ollama at {host}. Please ensure Ollama is running."
|
||||||
)
|
)
|
||||||
|
|
||||||
def ask(self, prompt: str, **kwargs) -> str:
|
def ask(self, prompt: str, **kwargs) -> str:
|
||||||
@@ -582,18 +577,33 @@ class HFChat(LLMInterface):
|
|||||||
def timeout_handler(signum, frame):
|
def timeout_handler(signum, frame):
|
||||||
raise TimeoutError("Model download/loading timed out")
|
raise TimeoutError("Model download/loading timed out")
|
||||||
|
|
||||||
# Set timeout for model loading (60 seconds)
|
# Set timeout for model loading (increase to 300s for large models)
|
||||||
old_handler = signal.signal(signal.SIGALRM, timeout_handler)
|
old_handler = signal.signal(signal.SIGALRM, timeout_handler)
|
||||||
signal.alarm(60)
|
signal.alarm(300)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
logger.info(f"Loading tokenizer for {model_name}...")
|
logger.info(f"Loading tokenizer for {model_name}...")
|
||||||
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||||
|
|
||||||
logger.info(f"Loading model {model_name}...")
|
logger.info(f"Loading model {model_name}...")
|
||||||
|
# Choose a numerically stable dtype per device
|
||||||
|
if self.device == "cuda":
|
||||||
|
# Prefer bfloat16 when available; otherwise fall back to float16
|
||||||
|
try:
|
||||||
|
bf16_ok = torch.cuda.is_bf16_supported()
|
||||||
|
except Exception:
|
||||||
|
bf16_ok = False
|
||||||
|
load_dtype = torch.bfloat16 if bf16_ok else torch.float16
|
||||||
|
elif self.device == "mps":
|
||||||
|
# On Apple MPS, float16 often causes NaNs/INFs during sampling.
|
||||||
|
# Use float32 for stability, even if it increases memory.
|
||||||
|
load_dtype = torch.float32
|
||||||
|
else:
|
||||||
|
load_dtype = torch.float32
|
||||||
|
|
||||||
self.model = AutoModelForCausalLM.from_pretrained(
|
self.model = AutoModelForCausalLM.from_pretrained(
|
||||||
model_name,
|
model_name,
|
||||||
torch_dtype=torch.float16 if self.device != "cpu" else torch.float32,
|
torch_dtype=load_dtype,
|
||||||
device_map="auto" if self.device != "cpu" else None,
|
device_map="auto" if self.device != "cpu" else None,
|
||||||
trust_remote_code=True,
|
trust_remote_code=True,
|
||||||
)
|
)
|
||||||
@@ -611,8 +621,12 @@ class HFChat(LLMInterface):
|
|||||||
logger.error(f"Failed to load model {model_name}: {e}")
|
logger.error(f"Failed to load model {model_name}: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
# Move model to device if not using device_map
|
# Move model to device only if not managed by accelerate (no device_map)
|
||||||
if self.device != "cpu" and "device_map" not in str(self.model):
|
try:
|
||||||
|
has_device_map = getattr(self.model, "hf_device_map", None) is not None
|
||||||
|
except Exception:
|
||||||
|
has_device_map = False
|
||||||
|
if self.device != "cpu" and not has_device_map:
|
||||||
self.model = self.model.to(self.device)
|
self.model = self.model.to(self.device)
|
||||||
|
|
||||||
# Set pad token if not present
|
# Set pad token if not present
|
||||||
@@ -644,13 +658,15 @@ class HFChat(LLMInterface):
|
|||||||
# Fallback for models without chat template
|
# Fallback for models without chat template
|
||||||
formatted_prompt = prompt
|
formatted_prompt = prompt
|
||||||
|
|
||||||
# Tokenize input
|
# Tokenize input (respect model context length when available)
|
||||||
inputs = self.tokenizer(
|
inputs = self.tokenizer(
|
||||||
formatted_prompt,
|
formatted_prompt,
|
||||||
return_tensors="pt",
|
return_tensors="pt",
|
||||||
padding=True,
|
padding=True,
|
||||||
truncation=True,
|
truncation=True,
|
||||||
max_length=2048,
|
max_length=getattr(
|
||||||
|
getattr(self.model, "config", None), "max_position_embeddings", 2048
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Move inputs to device
|
# Move inputs to device
|
||||||
@@ -665,6 +681,8 @@ class HFChat(LLMInterface):
|
|||||||
"do_sample": kwargs.get("temperature", 0.7) > 0,
|
"do_sample": kwargs.get("temperature", 0.7) > 0,
|
||||||
"pad_token_id": self.tokenizer.eos_token_id,
|
"pad_token_id": self.tokenizer.eos_token_id,
|
||||||
"eos_token_id": self.tokenizer.eos_token_id,
|
"eos_token_id": self.tokenizer.eos_token_id,
|
||||||
|
# Helps avoid numerical issues in sampling when logits processors are used
|
||||||
|
"renormalize_logits": True,
|
||||||
}
|
}
|
||||||
|
|
||||||
# Handle temperature=0 for greedy decoding
|
# Handle temperature=0 for greedy decoding
|
||||||
@@ -674,11 +692,39 @@ class HFChat(LLMInterface):
|
|||||||
|
|
||||||
logger.info(f"Generating with HuggingFace model, config: {generation_config}")
|
logger.info(f"Generating with HuggingFace model, config: {generation_config}")
|
||||||
|
|
||||||
# Generate
|
# Streaming support (optional)
|
||||||
|
stream = bool(kwargs.get("stream", False))
|
||||||
|
if stream:
|
||||||
|
try:
|
||||||
|
from threading import Thread
|
||||||
|
|
||||||
|
from transformers import TextIteratorStreamer
|
||||||
|
|
||||||
|
streamer = TextIteratorStreamer(
|
||||||
|
self.tokenizer, skip_prompt=True, skip_special_tokens=True
|
||||||
|
)
|
||||||
|
|
||||||
|
def _gen():
|
||||||
|
with torch.no_grad():
|
||||||
|
self.model.generate(**inputs, **generation_config, streamer=streamer)
|
||||||
|
|
||||||
|
t = Thread(target=_gen)
|
||||||
|
t.start()
|
||||||
|
|
||||||
|
pieces = []
|
||||||
|
for new_text in streamer:
|
||||||
|
print(new_text, end="", flush=True)
|
||||||
|
pieces.append(new_text)
|
||||||
|
t.join()
|
||||||
|
print("") # newline after streaming
|
||||||
|
return ("".join(pieces)).strip()
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Streaming failed, falling back to non-streaming: {e}")
|
||||||
|
|
||||||
|
# Non-streaming path
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
outputs = self.model.generate(**inputs, **generation_config)
|
outputs = self.model.generate(**inputs, **generation_config)
|
||||||
|
|
||||||
# Decode response
|
|
||||||
generated_tokens = outputs[0][inputs["input_ids"].shape[1] :]
|
generated_tokens = outputs[0][inputs["input_ids"].shape[1] :]
|
||||||
response = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
|
response = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
|
||||||
|
|
||||||
@@ -742,31 +788,21 @@ class GeminiChat(LLMInterface):
|
|||||||
class OpenAIChat(LLMInterface):
|
class OpenAIChat(LLMInterface):
|
||||||
"""LLM interface for OpenAI models."""
|
"""LLM interface for OpenAI models."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(self, model: str = "gpt-4o", api_key: Optional[str] = None):
|
||||||
self,
|
|
||||||
model: str = "gpt-4o",
|
|
||||||
api_key: Optional[str] = None,
|
|
||||||
base_url: Optional[str] = None,
|
|
||||||
):
|
|
||||||
self.model = model
|
self.model = model
|
||||||
self.base_url = resolve_openai_base_url(base_url)
|
self.api_key = api_key or os.getenv("OPENAI_API_KEY")
|
||||||
self.api_key = resolve_openai_api_key(api_key)
|
|
||||||
|
|
||||||
if not self.api_key:
|
if not self.api_key:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"OpenAI API key is required. Set OPENAI_API_KEY environment variable or pass api_key parameter."
|
"OpenAI API key is required. Set OPENAI_API_KEY environment variable or pass api_key parameter."
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(
|
logger.info(f"Initializing OpenAI Chat with model='{model}'")
|
||||||
"Initializing OpenAI Chat with model='%s' and base_url='%s'",
|
|
||||||
model,
|
|
||||||
self.base_url,
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import openai
|
import openai
|
||||||
|
|
||||||
self.client = openai.OpenAI(api_key=self.api_key, base_url=self.base_url)
|
self.client = openai.OpenAI(api_key=self.api_key)
|
||||||
except ImportError:
|
except ImportError:
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
"The 'openai' library is required for OpenAI models. Please install it with 'pip install openai'."
|
"The 'openai' library is required for OpenAI models. Please install it with 'pip install openai'."
|
||||||
@@ -856,16 +892,12 @@ def get_llm(llm_config: Optional[dict[str, Any]] = None) -> LLMInterface:
|
|||||||
if llm_type == "ollama":
|
if llm_type == "ollama":
|
||||||
return OllamaChat(
|
return OllamaChat(
|
||||||
model=model or "llama3:8b",
|
model=model or "llama3:8b",
|
||||||
host=llm_config.get("host"),
|
host=llm_config.get("host", "http://localhost:11434"),
|
||||||
)
|
)
|
||||||
elif llm_type == "hf":
|
elif llm_type == "hf":
|
||||||
return HFChat(model_name=model or "deepseek-ai/deepseek-llm-7b-chat")
|
return HFChat(model_name=model or "deepseek-ai/deepseek-llm-7b-chat")
|
||||||
elif llm_type == "openai":
|
elif llm_type == "openai":
|
||||||
return OpenAIChat(
|
return OpenAIChat(model=model or "gpt-4o", api_key=llm_config.get("api_key"))
|
||||||
model=model or "gpt-4o",
|
|
||||||
api_key=llm_config.get("api_key"),
|
|
||||||
base_url=llm_config.get("base_url"),
|
|
||||||
)
|
|
||||||
elif llm_type == "gemini":
|
elif llm_type == "gemini":
|
||||||
return GeminiChat(model=model or "gemini-2.5-flash", api_key=llm_config.get("api_key"))
|
return GeminiChat(model=model or "gemini-2.5-flash", api_key=llm_config.get("api_key"))
|
||||||
elif llm_type == "simulated":
|
elif llm_type == "simulated":
|
||||||
|
|||||||
324
scripts/measure_generation_times.py
Executable file
324
scripts/measure_generation_times.py
Executable file
@@ -0,0 +1,324 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""Measure generation latency of a HuggingFace/OpenAI-compatible model over prompt files."""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import contextlib
|
||||||
|
import io
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from leann.chat import get_llm
|
||||||
|
|
||||||
|
PROMPT_PREFIX = "PROMPT #"
|
||||||
|
logging.getLogger("leann.chat").setLevel(logging.ERROR)
|
||||||
|
|
||||||
|
|
||||||
|
def load_prompts(path: Path) -> list[str]:
|
||||||
|
prompts: list[str] = []
|
||||||
|
buffer: list[str] = []
|
||||||
|
collecting = False
|
||||||
|
|
||||||
|
with path.open("r", encoding="utf-8") as handle:
|
||||||
|
for line in handle:
|
||||||
|
if line.startswith(PROMPT_PREFIX):
|
||||||
|
if buffer:
|
||||||
|
prompts.append("".join(buffer).strip())
|
||||||
|
buffer.clear()
|
||||||
|
collecting = True
|
||||||
|
continue
|
||||||
|
|
||||||
|
if collecting:
|
||||||
|
buffer.append(line)
|
||||||
|
|
||||||
|
if buffer:
|
||||||
|
prompts.append("".join(buffer).strip())
|
||||||
|
|
||||||
|
return prompts
|
||||||
|
|
||||||
|
|
||||||
|
def measure_generation_times(
|
||||||
|
prompts: list[str],
|
||||||
|
llm,
|
||||||
|
generation_kwargs: dict[str, object],
|
||||||
|
allow_truncation: bool,
|
||||||
|
enable_qwen_thinking: bool,
|
||||||
|
verbose: bool,
|
||||||
|
per_call_timeout: int | None,
|
||||||
|
):
|
||||||
|
timings: list[float] = []
|
||||||
|
tokenizer = getattr(llm, "tokenizer", None)
|
||||||
|
max_positions = None
|
||||||
|
if hasattr(llm, "model") and hasattr(llm.model, "config"):
|
||||||
|
max_positions = getattr(llm.model.config, "max_position_embeddings", None)
|
||||||
|
|
||||||
|
requested_new_tokens = None
|
||||||
|
if max_positions is not None:
|
||||||
|
if "max_new_tokens" in generation_kwargs:
|
||||||
|
requested_new_tokens = generation_kwargs["max_new_tokens"]
|
||||||
|
elif "max_tokens" in generation_kwargs:
|
||||||
|
requested_new_tokens = generation_kwargs["max_tokens"]
|
||||||
|
|
||||||
|
context_max_length = max_positions
|
||||||
|
if max_positions is not None and requested_new_tokens is not None:
|
||||||
|
if requested_new_tokens >= max_positions:
|
||||||
|
requested_new_tokens = max_positions - 1
|
||||||
|
context_max_length = max(max_positions - requested_new_tokens, 1)
|
||||||
|
|
||||||
|
suppress_buffer = io.StringIO()
|
||||||
|
# Log base config
|
||||||
|
if verbose:
|
||||||
|
device = getattr(llm, "device", None)
|
||||||
|
try:
|
||||||
|
dtype = getattr(getattr(llm, "model", None), "dtype", None)
|
||||||
|
except Exception:
|
||||||
|
dtype = None
|
||||||
|
print(
|
||||||
|
f"[dbg] device={device} dtype={dtype} max_positions={max_positions} requested_new_tokens={requested_new_tokens} context_max_length={context_max_length}"
|
||||||
|
)
|
||||||
|
total = len(prompts)
|
||||||
|
for idx, prompt in enumerate(prompts, start=1):
|
||||||
|
prompt_for_llm = prompt
|
||||||
|
if (
|
||||||
|
enable_qwen_thinking
|
||||||
|
and "/think" not in prompt_for_llm
|
||||||
|
and "/no_think" not in prompt_for_llm
|
||||||
|
):
|
||||||
|
prompt_for_llm = f"{prompt_for_llm}\n/think"
|
||||||
|
|
||||||
|
if allow_truncation and tokenizer is not None and max_positions is not None:
|
||||||
|
tokenized = tokenizer(
|
||||||
|
prompt_for_llm,
|
||||||
|
truncation=True,
|
||||||
|
max_length=context_max_length,
|
||||||
|
return_tensors="pt",
|
||||||
|
)
|
||||||
|
prompt_for_llm = tokenizer.decode(tokenized["input_ids"][0], skip_special_tokens=True)
|
||||||
|
|
||||||
|
per_call_kwargs = dict(generation_kwargs)
|
||||||
|
if requested_new_tokens is not None:
|
||||||
|
per_call_kwargs["max_new_tokens"] = requested_new_tokens
|
||||||
|
# Enable streaming if requested (HF backend will print tokens)
|
||||||
|
if verbose:
|
||||||
|
# When verbose (or --stream propagated), enable streaming in HF backend
|
||||||
|
per_call_kwargs["stream"] = True
|
||||||
|
|
||||||
|
# Extra debug info about token lengths
|
||||||
|
if verbose and tokenizer is not None:
|
||||||
|
try:
|
||||||
|
toks = tokenizer(prompt_for_llm, return_tensors=None, truncation=False)
|
||||||
|
in_len = (
|
||||||
|
len(toks["input_ids"])
|
||||||
|
if isinstance(toks["input_ids"], list)
|
||||||
|
else len(toks["input_ids"][0])
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
in_len = None
|
||||||
|
print(f"[dbg] prompt {idx}/{total} tokens={in_len}")
|
||||||
|
print(
|
||||||
|
f"[dbg] gen_cfg={{max_new_tokens:{per_call_kwargs.get('max_new_tokens')}, temp:{per_call_kwargs.get('temperature')}, top_p:{per_call_kwargs.get('top_p')}}}"
|
||||||
|
)
|
||||||
|
|
||||||
|
start = time.perf_counter()
|
||||||
|
# Optional per-call timeout using signal alarm
|
||||||
|
timeout_handler_installed = False
|
||||||
|
if per_call_timeout is not None:
|
||||||
|
import signal
|
||||||
|
|
||||||
|
def _timeout_handler(signum, frame):
|
||||||
|
raise TimeoutError("generation timed out")
|
||||||
|
|
||||||
|
old_handler = signal.signal(signal.SIGALRM, _timeout_handler)
|
||||||
|
signal.alarm(int(per_call_timeout))
|
||||||
|
timeout_handler_installed = True
|
||||||
|
|
||||||
|
try:
|
||||||
|
if verbose:
|
||||||
|
print("[dbg] generation_start")
|
||||||
|
llm.ask(prompt_for_llm, **per_call_kwargs)
|
||||||
|
print("[dbg] generation_done")
|
||||||
|
else:
|
||||||
|
with contextlib.redirect_stdout(suppress_buffer):
|
||||||
|
llm.ask(prompt_for_llm, **per_call_kwargs)
|
||||||
|
except TimeoutError:
|
||||||
|
if verbose:
|
||||||
|
print("[dbg] generation_timeout")
|
||||||
|
finally:
|
||||||
|
if timeout_handler_installed:
|
||||||
|
import signal
|
||||||
|
|
||||||
|
signal.alarm(0)
|
||||||
|
signal.signal(signal.SIGALRM, old_handler)
|
||||||
|
end = time.perf_counter()
|
||||||
|
timings.append(end - start)
|
||||||
|
suppress_buffer.seek(0)
|
||||||
|
suppress_buffer.truncate(0)
|
||||||
|
|
||||||
|
return timings
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args():
|
||||||
|
parser = argparse.ArgumentParser(description="Measure generation timing for prompt files")
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-prompts",
|
||||||
|
type=int,
|
||||||
|
default=None,
|
||||||
|
help="Optional limit on number of prompts to evaluate per file",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--allow-truncation",
|
||||||
|
action="store_true",
|
||||||
|
help="Allow truncating prompt context to respect model's max context",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--model",
|
||||||
|
type=str,
|
||||||
|
default="sshleifer/tiny-gpt2",
|
||||||
|
help="LLM model identifier (default: sshleifer/tiny-gpt2)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--llm-type",
|
||||||
|
type=str,
|
||||||
|
default="hf",
|
||||||
|
choices=["hf", "openai", "ollama", "gemini", "simulated"],
|
||||||
|
help="LLM backend type (default: hf)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--device",
|
||||||
|
type=str,
|
||||||
|
default="cpu",
|
||||||
|
choices=["cpu", "auto"],
|
||||||
|
help="Device override for HF models (default: cpu)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-new-tokens",
|
||||||
|
type=int,
|
||||||
|
default=16,
|
||||||
|
help="Max new tokens per generation (default: 16)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--temperature",
|
||||||
|
type=float,
|
||||||
|
default=0.2,
|
||||||
|
help="Sampling temperature (default: 0.2)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--top-p",
|
||||||
|
type=float,
|
||||||
|
default=0.8,
|
||||||
|
help="Nucleus sampling top-p (default: 0.8)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--qwen-thinking",
|
||||||
|
action="store_true",
|
||||||
|
help="Append /think to prompts for Qwen models",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--no-max-new-tokens",
|
||||||
|
action="store_true",
|
||||||
|
help="Do not set max_new_tokens in generation kwargs",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--per-call-timeout",
|
||||||
|
type=int,
|
||||||
|
default=None,
|
||||||
|
help="Optional timeout (seconds) per generation call; if hit, moves to next prompt",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--stream",
|
||||||
|
action="store_true",
|
||||||
|
help="Stream generated text to stdout during generation",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--datasets",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help=(
|
||||||
|
"Comma-separated subset of datasets to run. Options: gpqa_bm25,gpqa_diskann,gpqa_hnsw. "
|
||||||
|
"Default: all"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--verbose",
|
||||||
|
action="store_true",
|
||||||
|
help="Enable debug logging and show generation progress",
|
||||||
|
)
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
args = parse_args()
|
||||||
|
dataset_map = {
|
||||||
|
# "gpqa_bm25": Path("prompt_dump_gpqa_bm25.txt"),
|
||||||
|
# "gpqa_diskann": Path("prompt_dump_gpqa_diskann.txt"),
|
||||||
|
# "gpqa_hnsw": Path("prompt_dump_gpqa_hnsw.txt"),
|
||||||
|
# "nq_bm25": Path("prompt_dump_nq_bm25.txt"),
|
||||||
|
# # "nq_diskann": Path("prompt_dump_nq_diskann.txt"),
|
||||||
|
# "nq_hnsw": Path("prompt_dump_nq_hnsw.txt"),
|
||||||
|
"gpqa_bm25": Path("prompt_dump_hotpot_bm25.txt"),
|
||||||
|
"gpqa_diskann": Path("prompt_dump_hotpot_diskann.txt"),
|
||||||
|
# "gpqa_hnsw": Path("prompt_dump_hotpot_hnsw.txt"),
|
||||||
|
# "gpqa_bm25": Path("prompt_dump_trivia_bm25.txt"),
|
||||||
|
# "gpqa_diskann": Path("prompt_dump_trivia_diskann.txt"),
|
||||||
|
}
|
||||||
|
if args.datasets:
|
||||||
|
selected = [k.strip() for k in args.datasets.split(",") if k.strip()]
|
||||||
|
invalid = [k for k in selected if k not in dataset_map]
|
||||||
|
if invalid:
|
||||||
|
raise SystemExit(f"Invalid dataset names: {invalid}. Valid: {list(dataset_map)}")
|
||||||
|
dataset_files = [dataset_map[k] for k in selected]
|
||||||
|
else:
|
||||||
|
dataset_files = list(dataset_map.values())
|
||||||
|
|
||||||
|
generation_kwargs = {
|
||||||
|
"temperature": args.temperature,
|
||||||
|
"top_p": args.top_p,
|
||||||
|
}
|
||||||
|
if not args.no_max_new_tokens:
|
||||||
|
generation_kwargs["max_new_tokens"] = args.max_new_tokens
|
||||||
|
|
||||||
|
results: dict[str, dict[str, float | int]] = {}
|
||||||
|
|
||||||
|
llm_config = {"type": args.llm_type, "model": args.model}
|
||||||
|
try:
|
||||||
|
llm = get_llm(llm_config)
|
||||||
|
except Exception as exc:
|
||||||
|
print(f"Failed to initialize LLM: {exc}")
|
||||||
|
raise SystemExit(1) from exc
|
||||||
|
|
||||||
|
if args.llm_type == "hf" and hasattr(llm, "model") and args.device == "cpu":
|
||||||
|
llm.model = llm.model.to("cpu")
|
||||||
|
if hasattr(llm, "device"):
|
||||||
|
llm.device = "cpu"
|
||||||
|
|
||||||
|
for dataset_path in dataset_files:
|
||||||
|
print(f"Processing {dataset_path.name}...")
|
||||||
|
prompts = load_prompts(dataset_path)
|
||||||
|
if args.max_prompts is not None:
|
||||||
|
prompts = prompts[: args.max_prompts]
|
||||||
|
if args.verbose:
|
||||||
|
print(f"[dbg] loaded_prompts={len(prompts)} (showing up to --max-prompts)")
|
||||||
|
timings = measure_generation_times(
|
||||||
|
prompts,
|
||||||
|
llm,
|
||||||
|
generation_kwargs,
|
||||||
|
args.allow_truncation,
|
||||||
|
args.qwen_thinking,
|
||||||
|
args.verbose or args.stream,
|
||||||
|
args.per_call_timeout,
|
||||||
|
)
|
||||||
|
total_time = sum(timings)
|
||||||
|
count = len(timings)
|
||||||
|
average_time = total_time / count if count else 0.0
|
||||||
|
results[str(dataset_path.name)] = {
|
||||||
|
"total_prompts": count,
|
||||||
|
"total_time_seconds": total_time,
|
||||||
|
"average_time_seconds": average_time,
|
||||||
|
}
|
||||||
|
|
||||||
|
print(json.dumps(results, indent=2))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
Reference in New Issue
Block a user