diff --git a/packages/leann-backend-hnsw/third_party/faiss b/packages/leann-backend-hnsw/third_party/faiss
index 1d51f0c..c69511a 160000
--- a/packages/leann-backend-hnsw/third_party/faiss
+++ b/packages/leann-backend-hnsw/third_party/faiss
@@ -1 +1 @@
-Subproject commit 1d51f0c07420808a18f85a4db6636fd25e4a1daa
+Subproject commit c69511a99cb78edf094a4ea304ab3db71cb54327
diff --git a/packages/leann-core/src/leann/chat.py b/packages/leann-core/src/leann/chat.py
index 8135daf..12960b9 100644
--- a/packages/leann-core/src/leann/chat.py
+++ b/packages/leann-core/src/leann/chat.py
@@ -591,9 +591,24 @@ class HFChat(LLMInterface):
         self.tokenizer = AutoTokenizer.from_pretrained(model_name)

         logger.info(f"Loading model {model_name}...")
+        # Choose a numerically stable dtype per device
+        if self.device == "cuda":
+            # Prefer bfloat16 when available; otherwise fall back to float16
+            try:
+                bf16_ok = torch.cuda.is_bf16_supported()
+            except Exception:
+                bf16_ok = False
+            load_dtype = torch.bfloat16 if bf16_ok else torch.float16
+        elif self.device == "mps":
+            # On Apple MPS, float16 often causes NaNs/INFs during sampling.
+            # Use float32 for stability, even if it increases memory.
+            load_dtype = torch.float32
+        else:
+            load_dtype = torch.float32
+
         self.model = AutoModelForCausalLM.from_pretrained(
             model_name,
-            torch_dtype=torch.float16 if self.device != "cpu" else torch.float32,
+            torch_dtype=load_dtype,
             device_map="auto" if self.device != "cpu" else None,
             trust_remote_code=True,
         )
@@ -650,7 +665,8 @@ class HFChat(LLMInterface):
             return_tensors="pt",
             padding=True,
             truncation=True,
-            max_length=2048,
+            # Respect model context length when available
+            max_length=getattr(getattr(self.model, "config", None), "max_position_embeddings", 2048),
         )

         # Move inputs to device
@@ -665,6 +681,8 @@ class HFChat(LLMInterface):
             "do_sample": kwargs.get("temperature", 0.7) > 0,
             "pad_token_id": self.tokenizer.eos_token_id,
             "eos_token_id": self.tokenizer.eos_token_id,
+            # Helps avoid numerical issues in sampling when logits processors are used
+            "renormalize_logits": True,
         }

         # Handle temperature=0 for greedy decoding
diff --git a/scripts/measure_generation_times.py b/scripts/measure_generation_times.py
index 6307685..39753e6 100755
--- a/scripts/measure_generation_times.py
+++ b/scripts/measure_generation_times.py
@@ -44,6 +44,8 @@ def measure_generation_times(
     generation_kwargs: dict[str, object],
     allow_truncation: bool,
     enable_qwen_thinking: bool,
+    verbose: bool,
+    per_call_timeout: int | None,
 ):
     timings: list[float] = []
     tokenizer = getattr(llm, "tokenizer", None)
@@ -65,7 +67,18 @@
     context_max_length = max(max_positions - requested_new_tokens, 1)

     suppress_buffer = io.StringIO()
-    for prompt in prompts:
+    # Log base config
+    if verbose:
+        device = getattr(llm, "device", None)
+        try:
+            dtype = getattr(getattr(llm, "model", None), "dtype", None)
+        except Exception:
+            dtype = None
+        print(
+            f"[dbg] device={device} dtype={dtype} max_positions={max_positions} requested_new_tokens={requested_new_tokens} context_max_length={context_max_length}"
+        )
+    total = len(prompts)
+    for idx, prompt in enumerate(prompts, start=1):
         prompt_for_llm = prompt
         if (
             enable_qwen_thinking
@@ -86,11 +99,58 @@
         per_call_kwargs = dict(generation_kwargs)
         if requested_new_tokens is not None:
             per_call_kwargs["max_new_tokens"] = requested_new_tokens
+        # Enable streaming if requested (HF backend will print tokens)
+        if verbose:
+            # When verbose (or --stream propagated), enable streaming in HF backend
+            per_call_kwargs["stream"] = True
+
+        # Extra debug info about token lengths
+        if verbose and tokenizer is not None:
+            try:
+                toks = tokenizer(prompt_for_llm, return_tensors=None, truncation=False)
+                in_len = (
+                    len(toks["input_ids"])
+                    if isinstance(toks["input_ids"], list)
+                    else len(toks["input_ids"][0])
+                )
+            except Exception:
+                in_len = None
+            print(f"[dbg] prompt {idx}/{total} tokens={in_len}")
+            print(
+                f"[dbg] gen_cfg={{max_new_tokens:{per_call_kwargs.get('max_new_tokens')}, temp:{per_call_kwargs.get('temperature')}, top_p:{per_call_kwargs.get('top_p')}}}"
+            )

         start = time.perf_counter()
-        with contextlib.redirect_stdout(suppress_buffer):
-            llm.ask(prompt_for_llm, **per_call_kwargs)
-        end = time.perf_counter()
+        # Optional per-call timeout using signal alarm
+        timeout_handler_installed = False
+        if per_call_timeout is not None:
+            import signal
+
+            def _timeout_handler(signum, frame):
+                raise TimeoutError("generation timed out")
+
+            old_handler = signal.signal(signal.SIGALRM, _timeout_handler)
+            signal.alarm(int(per_call_timeout))
+            timeout_handler_installed = True
+
+        try:
+            if verbose:
+                print("[dbg] generation_start")
+                llm.ask(prompt_for_llm, **per_call_kwargs)
+                print("[dbg] generation_done")
+            else:
+                with contextlib.redirect_stdout(suppress_buffer):
+                    llm.ask(prompt_for_llm, **per_call_kwargs)
+        except TimeoutError:
+            if verbose:
+                print("[dbg] generation_timeout")
+        finally:
+            if timeout_handler_installed:
+                import signal
+
+                signal.alarm(0)
+                signal.signal(signal.SIGALRM, old_handler)
+        end = time.perf_counter()
         timings.append(end - start)
         suppress_buffer.seek(0)
         suppress_buffer.truncate(0)
@@ -154,23 +214,69 @@
         action="store_true",
         help="Append /think to prompts for Qwen models",
     )
+    parser.add_argument(
+        "--no-max-new-tokens",
+        action="store_true",
+        help="Do not set max_new_tokens in generation kwargs",
+    )
+    parser.add_argument(
+        "--per-call-timeout",
+        type=int,
+        default=None,
+        help="Optional timeout (seconds) per generation call; if hit, moves to next prompt",
+    )
+    parser.add_argument(
+        "--stream",
+        action="store_true",
+        help="Stream generated text to stdout during generation",
+    )
+    parser.add_argument(
+        "--datasets",
+        type=str,
+        default=None,
+        help=(
+            "Comma-separated subset of datasets to run. Options: gpqa_bm25,gpqa_diskann,gpqa_hnsw. "
+            "Default: all"
+        ),
+    )
+    parser.add_argument(
+        "--verbose",
+        action="store_true",
+        help="Enable debug logging and show generation progress",
+    )
     return parser.parse_args()


 def main():
     args = parse_args()
-    dataset_files = [
-        Path("prompt_all_nq_bm25.txt"),
-        Path("prompt_all_nq_diskann_full.txt"),
-        Path("prompt_all_nq_diskann_pq5.txt"),
-        Path("prompt_all_nq_hnsw.txt"),
-    ]
+    dataset_map = {
+        # "gpqa_bm25": Path("prompt_dump_gpqa_bm25.txt"),
+        # "gpqa_diskann": Path("prompt_dump_gpqa_diskann.txt"),
+        # "gpqa_hnsw": Path("prompt_dump_gpqa_hnsw.txt"),
+        # "nq_bm25": Path("prompt_dump_nq_bm25.txt"),
+        # # "nq_diskann": Path("prompt_dump_nq_diskann.txt"),
+        # "nq_hnsw": Path("prompt_dump_nq_hnsw.txt"),
+        "gpqa_bm25": Path("prompt_dump_hotpot_bm25.txt"),
+        "gpqa_diskann": Path("prompt_dump_hotpot_diskann.txt"),
+        # "gpqa_hnsw": Path("prompt_dump_hotpot_hnsw.txt"),
+        # "gpqa_bm25": Path("prompt_dump_trivia_bm25.txt"),
+        # "gpqa_diskann": Path("prompt_dump_trivia_diskann.txt"),
+    }
+    if args.datasets:
+        selected = [k.strip() for k in args.datasets.split(",") if k.strip()]
+        invalid = [k for k in selected if k not in dataset_map]
+        if invalid:
+            raise SystemExit(f"Invalid dataset names: {invalid}. Valid: {list(dataset_map)}")
+        dataset_files = [dataset_map[k] for k in selected]
+    else:
+        dataset_files = list(dataset_map.values())

     generation_kwargs = {
-        "max_new_tokens": args.max_new_tokens,
         "temperature": args.temperature,
         "top_p": args.top_p,
     }
+    if not args.no_max_new_tokens:
+        generation_kwargs["max_new_tokens"] = args.max_new_tokens

     results: dict[str, dict[str, float | int]] = {}

@@ -191,12 +297,16 @@ def main():
         prompts = load_prompts(dataset_path)
         if args.max_prompts is not None:
             prompts = prompts[: args.max_prompts]
+        if args.verbose:
+            print(f"[dbg] loaded_prompts={len(prompts)} (showing up to --max-prompts)")
         timings = measure_generation_times(
             prompts,
             llm,
             generation_kwargs,
             args.allow_truncation,
             args.qwen_thinking,
+            args.verbose or args.stream,
+            args.per_call_timeout,
         )
         total_time = sum(timings)
         count = len(timings)