stash
This commit is contained in:
Submodule packages/leann-backend-hnsw/third_party/faiss updated: 1d51f0c074...c69511a99c
@@ -591,9 +591,24 @@ class HFChat(LLMInterface):
|
|||||||
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||||
|
|
||||||
logger.info(f"Loading model {model_name}...")
|
logger.info(f"Loading model {model_name}...")
|
||||||
|
# Choose a numerically stable dtype per device
|
||||||
|
if self.device == "cuda":
|
||||||
|
# Prefer bfloat16 when available; otherwise fall back to float16
|
||||||
|
try:
|
||||||
|
bf16_ok = torch.cuda.is_bf16_supported()
|
||||||
|
except Exception:
|
||||||
|
bf16_ok = False
|
||||||
|
load_dtype = torch.bfloat16 if bf16_ok else torch.float16
|
||||||
|
elif self.device == "mps":
|
||||||
|
# On Apple MPS, float16 often causes NaNs/INFs during sampling.
|
||||||
|
# Use float32 for stability, even if it increases memory.
|
||||||
|
load_dtype = torch.float32
|
||||||
|
else:
|
||||||
|
load_dtype = torch.float32
|
||||||
|
|
||||||
self.model = AutoModelForCausalLM.from_pretrained(
|
self.model = AutoModelForCausalLM.from_pretrained(
|
||||||
model_name,
|
model_name,
|
||||||
torch_dtype=torch.float16 if self.device != "cpu" else torch.float32,
|
torch_dtype=load_dtype,
|
||||||
device_map="auto" if self.device != "cpu" else None,
|
device_map="auto" if self.device != "cpu" else None,
|
||||||
trust_remote_code=True,
|
trust_remote_code=True,
|
||||||
)
|
)
|
||||||
@@ -650,7 +665,8 @@ class HFChat(LLMInterface):
|
|||||||
return_tensors="pt",
|
return_tensors="pt",
|
||||||
padding=True,
|
padding=True,
|
||||||
truncation=True,
|
truncation=True,
|
||||||
max_length=2048,
|
# Respect model context length when available
|
||||||
|
max_length=getattr(getattr(self.model, "config", None), "max_position_embeddings", 2048),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Move inputs to device
|
# Move inputs to device
|
||||||
@@ -665,6 +681,8 @@ class HFChat(LLMInterface):
|
|||||||
"do_sample": kwargs.get("temperature", 0.7) > 0,
|
"do_sample": kwargs.get("temperature", 0.7) > 0,
|
||||||
"pad_token_id": self.tokenizer.eos_token_id,
|
"pad_token_id": self.tokenizer.eos_token_id,
|
||||||
"eos_token_id": self.tokenizer.eos_token_id,
|
"eos_token_id": self.tokenizer.eos_token_id,
|
||||||
|
# Helps avoid numerical issues in sampling when logits processors are used
|
||||||
|
"renormalize_logits": True,
|
||||||
}
|
}
|
||||||
|
|
||||||
# Handle temperature=0 for greedy decoding
|
# Handle temperature=0 for greedy decoding
|
||||||
|
|||||||
@@ -44,6 +44,8 @@ def measure_generation_times(
|
|||||||
generation_kwargs: dict[str, object],
|
generation_kwargs: dict[str, object],
|
||||||
allow_truncation: bool,
|
allow_truncation: bool,
|
||||||
enable_qwen_thinking: bool,
|
enable_qwen_thinking: bool,
|
||||||
|
verbose: bool,
|
||||||
|
per_call_timeout: int | None,
|
||||||
):
|
):
|
||||||
timings: list[float] = []
|
timings: list[float] = []
|
||||||
tokenizer = getattr(llm, "tokenizer", None)
|
tokenizer = getattr(llm, "tokenizer", None)
|
||||||
@@ -65,7 +67,18 @@ def measure_generation_times(
|
|||||||
context_max_length = max(max_positions - requested_new_tokens, 1)
|
context_max_length = max(max_positions - requested_new_tokens, 1)
|
||||||
|
|
||||||
suppress_buffer = io.StringIO()
|
suppress_buffer = io.StringIO()
|
||||||
for prompt in prompts:
|
# Log base config
|
||||||
|
if verbose:
|
||||||
|
device = getattr(llm, "device", None)
|
||||||
|
try:
|
||||||
|
dtype = getattr(getattr(llm, "model", None), "dtype", None)
|
||||||
|
except Exception:
|
||||||
|
dtype = None
|
||||||
|
print(
|
||||||
|
f"[dbg] device={device} dtype={dtype} max_positions={max_positions} requested_new_tokens={requested_new_tokens} context_max_length={context_max_length}"
|
||||||
|
)
|
||||||
|
total = len(prompts)
|
||||||
|
for idx, prompt in enumerate(prompts, start=1):
|
||||||
prompt_for_llm = prompt
|
prompt_for_llm = prompt
|
||||||
if (
|
if (
|
||||||
enable_qwen_thinking
|
enable_qwen_thinking
|
||||||
@@ -86,11 +99,58 @@ def measure_generation_times(
|
|||||||
per_call_kwargs = dict(generation_kwargs)
|
per_call_kwargs = dict(generation_kwargs)
|
||||||
if requested_new_tokens is not None:
|
if requested_new_tokens is not None:
|
||||||
per_call_kwargs["max_new_tokens"] = requested_new_tokens
|
per_call_kwargs["max_new_tokens"] = requested_new_tokens
|
||||||
|
# Enable streaming if requested (HF backend will print tokens)
|
||||||
|
if verbose:
|
||||||
|
# When verbose (or --stream propagated), enable streaming in HF backend
|
||||||
|
per_call_kwargs["stream"] = True
|
||||||
|
|
||||||
|
# Extra debug info about token lengths
|
||||||
|
if verbose and tokenizer is not None:
|
||||||
|
try:
|
||||||
|
toks = tokenizer(prompt_for_llm, return_tensors=None, truncation=False)
|
||||||
|
in_len = (
|
||||||
|
len(toks["input_ids"])
|
||||||
|
if isinstance(toks["input_ids"], list)
|
||||||
|
else len(toks["input_ids"][0])
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
in_len = None
|
||||||
|
print(f"[dbg] prompt {idx}/{total} tokens={in_len}")
|
||||||
|
print(
|
||||||
|
f"[dbg] gen_cfg={{max_new_tokens:{per_call_kwargs.get('max_new_tokens')}, temp:{per_call_kwargs.get('temperature')}, top_p:{per_call_kwargs.get('top_p')}}}"
|
||||||
|
)
|
||||||
|
|
||||||
start = time.perf_counter()
|
start = time.perf_counter()
|
||||||
with contextlib.redirect_stdout(suppress_buffer):
|
# Optional per-call timeout using signal alarm
|
||||||
llm.ask(prompt_for_llm, **per_call_kwargs)
|
timeout_handler_installed = False
|
||||||
end = time.perf_counter()
|
if per_call_timeout is not None:
|
||||||
|
import signal
|
||||||
|
|
||||||
|
def _timeout_handler(signum, frame):
|
||||||
|
raise TimeoutError("generation timed out")
|
||||||
|
|
||||||
|
old_handler = signal.signal(signal.SIGALRM, _timeout_handler)
|
||||||
|
signal.alarm(int(per_call_timeout))
|
||||||
|
timeout_handler_installed = True
|
||||||
|
|
||||||
|
try:
|
||||||
|
if verbose:
|
||||||
|
print("[dbg] generation_start")
|
||||||
|
llm.ask(prompt_for_llm, **per_call_kwargs)
|
||||||
|
print("[dbg] generation_done")
|
||||||
|
else:
|
||||||
|
with contextlib.redirect_stdout(suppress_buffer):
|
||||||
|
llm.ask(prompt_for_llm, **per_call_kwargs)
|
||||||
|
except TimeoutError:
|
||||||
|
if verbose:
|
||||||
|
print("[dbg] generation_timeout")
|
||||||
|
finally:
|
||||||
|
if timeout_handler_installed:
|
||||||
|
import signal
|
||||||
|
|
||||||
|
signal.alarm(0)
|
||||||
|
signal.signal(signal.SIGALRM, old_handler)
|
||||||
|
end = time.perf_counter()
|
||||||
timings.append(end - start)
|
timings.append(end - start)
|
||||||
suppress_buffer.seek(0)
|
suppress_buffer.seek(0)
|
||||||
suppress_buffer.truncate(0)
|
suppress_buffer.truncate(0)
|
||||||
@@ -154,23 +214,69 @@ def parse_args():
|
|||||||
action="store_true",
|
action="store_true",
|
||||||
help="Append /think to prompts for Qwen models",
|
help="Append /think to prompts for Qwen models",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--no-max-new-tokens",
|
||||||
|
action="store_true",
|
||||||
|
help="Do not set max_new_tokens in generation kwargs",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--per-call-timeout",
|
||||||
|
type=int,
|
||||||
|
default=None,
|
||||||
|
help="Optional timeout (seconds) per generation call; if hit, moves to next prompt",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--stream",
|
||||||
|
action="store_true",
|
||||||
|
help="Stream generated text to stdout during generation",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--datasets",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help=(
|
||||||
|
"Comma-separated subset of datasets to run. Options: gpqa_bm25,gpqa_diskann,gpqa_hnsw. "
|
||||||
|
"Default: all"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--verbose",
|
||||||
|
action="store_true",
|
||||||
|
help="Enable debug logging and show generation progress",
|
||||||
|
)
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
args = parse_args()
|
args = parse_args()
|
||||||
dataset_files = [
|
dataset_map = {
|
||||||
Path("prompt_all_nq_bm25.txt"),
|
# "gpqa_bm25": Path("prompt_dump_gpqa_bm25.txt"),
|
||||||
Path("prompt_all_nq_diskann_full.txt"),
|
# "gpqa_diskann": Path("prompt_dump_gpqa_diskann.txt"),
|
||||||
Path("prompt_all_nq_diskann_pq5.txt"),
|
# "gpqa_hnsw": Path("prompt_dump_gpqa_hnsw.txt"),
|
||||||
Path("prompt_all_nq_hnsw.txt"),
|
# "nq_bm25": Path("prompt_dump_nq_bm25.txt"),
|
||||||
]
|
# # "nq_diskann": Path("prompt_dump_nq_diskann.txt"),
|
||||||
|
# "nq_hnsw": Path("prompt_dump_nq_hnsw.txt"),
|
||||||
|
"gpqa_bm25": Path("prompt_dump_hotpot_bm25.txt"),
|
||||||
|
"gpqa_diskann": Path("prompt_dump_hotpot_diskann.txt"),
|
||||||
|
# "gpqa_hnsw": Path("prompt_dump_hotpot_hnsw.txt"),
|
||||||
|
# "gpqa_bm25": Path("prompt_dump_trivia_bm25.txt"),
|
||||||
|
# "gpqa_diskann": Path("prompt_dump_trivia_diskann.txt"),
|
||||||
|
}
|
||||||
|
if args.datasets:
|
||||||
|
selected = [k.strip() for k in args.datasets.split(",") if k.strip()]
|
||||||
|
invalid = [k for k in selected if k not in dataset_map]
|
||||||
|
if invalid:
|
||||||
|
raise SystemExit(f"Invalid dataset names: {invalid}. Valid: {list(dataset_map)}")
|
||||||
|
dataset_files = [dataset_map[k] for k in selected]
|
||||||
|
else:
|
||||||
|
dataset_files = list(dataset_map.values())
|
||||||
|
|
||||||
generation_kwargs = {
|
generation_kwargs = {
|
||||||
"max_new_tokens": args.max_new_tokens,
|
|
||||||
"temperature": args.temperature,
|
"temperature": args.temperature,
|
||||||
"top_p": args.top_p,
|
"top_p": args.top_p,
|
||||||
}
|
}
|
||||||
|
if not args.no_max_new_tokens:
|
||||||
|
generation_kwargs["max_new_tokens"] = args.max_new_tokens
|
||||||
|
|
||||||
results: dict[str, dict[str, float | int]] = {}
|
results: dict[str, dict[str, float | int]] = {}
|
||||||
|
|
||||||
@@ -191,12 +297,16 @@ def main():
|
|||||||
prompts = load_prompts(dataset_path)
|
prompts = load_prompts(dataset_path)
|
||||||
if args.max_prompts is not None:
|
if args.max_prompts is not None:
|
||||||
prompts = prompts[: args.max_prompts]
|
prompts = prompts[: args.max_prompts]
|
||||||
|
if args.verbose:
|
||||||
|
print(f"[dbg] loaded_prompts={len(prompts)} (showing up to --max-prompts)")
|
||||||
timings = measure_generation_times(
|
timings = measure_generation_times(
|
||||||
prompts,
|
prompts,
|
||||||
llm,
|
llm,
|
||||||
generation_kwargs,
|
generation_kwargs,
|
||||||
args.allow_truncation,
|
args.allow_truncation,
|
||||||
args.qwen_thinking,
|
args.qwen_thinking,
|
||||||
|
args.verbose or args.stream,
|
||||||
|
args.per_call_timeout,
|
||||||
)
|
)
|
||||||
total_time = sum(timings)
|
total_time = sum(timings)
|
||||||
count = len(timings)
|
count = len(timings)
|
||||||
|
|||||||
Reference in New Issue
Block a user