Andy Lee
2025-10-30 22:00:26 +00:00
parent 9ba0ecac15
commit abc12d5069
3 changed files with 142 additions and 14 deletions


@@ -591,9 +591,24 @@ class HFChat(LLMInterface):
         self.tokenizer = AutoTokenizer.from_pretrained(model_name)
         logger.info(f"Loading model {model_name}...")
+        # Choose a numerically stable dtype per device
+        if self.device == "cuda":
+            # Prefer bfloat16 when available; otherwise fall back to float16
+            try:
+                bf16_ok = torch.cuda.is_bf16_supported()
+            except Exception:
+                bf16_ok = False
+            load_dtype = torch.bfloat16 if bf16_ok else torch.float16
+        elif self.device == "mps":
+            # On Apple MPS, float16 often causes NaNs/INFs during sampling.
+            # Use float32 for stability, even if it increases memory.
+            load_dtype = torch.float32
+        else:
+            load_dtype = torch.float32
         self.model = AutoModelForCausalLM.from_pretrained(
             model_name,
-            torch_dtype=torch.float16 if self.device != "cpu" else torch.float32,
+            torch_dtype=load_dtype,
             device_map="auto" if self.device != "cpu" else None,
             trust_remote_code=True,
         )
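
A minimal sketch of the dtype selection introduced above (not part of the commit; the helper name is an assumption for illustration). bfloat16 keeps float32's exponent range, so it is far less prone to overflow than float16, while MPS and CPU stay on float32 for stable sampling:

import torch

def pick_load_dtype(device: str) -> torch.dtype:
    # Illustrative helper mirroring the branches in the hunk above.
    if device == "cuda":
        # Fall back to float16 on GPUs without bf16 support (e.g. pre-Ampere).
        try:
            bf16_ok = torch.cuda.is_bf16_supported()
        except Exception:
            bf16_ok = False
        return torch.bfloat16 if bf16_ok else torch.float16
    # Apple MPS and CPU: float32 avoids NaN/inf logits during sampling.
    return torch.float32

The result would feed torch_dtype= in from_pretrained, exactly as load_dtype does in the diff.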
@@ -650,7 +665,8 @@ class HFChat(LLMInterface):
             return_tensors="pt",
             padding=True,
             truncation=True,
-            max_length=2048,
+            # Respect model context length when available
+            max_length=getattr(getattr(self.model, "config", None), "max_position_embeddings", 2048),
         )
         # Move inputs to device
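
A sketch of the context-length lookup used for max_length above (not part of the commit; the function name is hypothetical). The nested getattr falls back cleanly because getattr(None, ...) with a default simply returns the default:

def context_length(model, default: int = 2048) -> int:
    # Returns `default` when the model has no config or the config lacks
    # max_position_embeddings (some configs use other names, e.g. n_positions).
    config = getattr(model, "config", None)
    return getattr(config, "max_position_embeddings", default)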
@@ -665,6 +681,8 @@ class HFChat(LLMInterface):
             "do_sample": kwargs.get("temperature", 0.7) > 0,
             "pad_token_id": self.tokenizer.eos_token_id,
             "eos_token_id": self.tokenizer.eos_token_id,
+            # Helps avoid numerical issues in sampling when logits processors are used
+            "renormalize_logits": True,
         }
         # Handle temperature=0 for greedy decoding
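
A sketch of how these generation kwargs fit together (not part of the commit; the helper name and tokenizer argument are assumptions). renormalize_logits asks transformers' generate() to re-normalize the distribution after logits processors run, which keeps sampling numerically well-behaved; temperature == 0 maps to greedy decoding via do_sample=False:

def build_generate_kwargs(tokenizer, temperature: float = 0.7) -> dict:
    gen_kwargs = {
        "do_sample": temperature > 0,            # temperature == 0 -> greedy decoding
        "pad_token_id": tokenizer.eos_token_id,
        "eos_token_id": tokenizer.eos_token_id,
        "renormalize_logits": True,              # re-normalize after logits processors
    }
    if temperature > 0:
        gen_kwargs["temperature"] = temperature  # only pass temperature when sampling
    return gen_kwargs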