stash
Submodule packages/leann-backend-hnsw/third_party/faiss updated: 1d51f0c074...c69511a99c
@@ -591,9 +591,24 @@ class HFChat(LLMInterface):
         self.tokenizer = AutoTokenizer.from_pretrained(model_name)

         logger.info(f"Loading model {model_name}...")
+        # Choose a numerically stable dtype per device
+        if self.device == "cuda":
+            # Prefer bfloat16 when available; otherwise fall back to float16
+            try:
+                bf16_ok = torch.cuda.is_bf16_supported()
+            except Exception:
+                bf16_ok = False
+            load_dtype = torch.bfloat16 if bf16_ok else torch.float16
+        elif self.device == "mps":
+            # On Apple MPS, float16 often causes NaNs/INFs during sampling.
+            # Use float32 for stability, even if it increases memory.
+            load_dtype = torch.float32
+        else:
+            load_dtype = torch.float32
+
         self.model = AutoModelForCausalLM.from_pretrained(
             model_name,
-            torch_dtype=torch.float16 if self.device != "cpu" else torch.float32,
+            torch_dtype=load_dtype,
             device_map="auto" if self.device != "cpu" else None,
             trust_remote_code=True,
         )
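For reference, a minimal standalone sketch of the same per-device dtype policy outside the class; the helper name pick_load_dtype is hypothetical and only for illustration:

import torch

def pick_load_dtype(device: str) -> torch.dtype:
    # Pick a numerically stable dtype for loading a causal LM on the given device.
    if device == "cuda":
        try:
            bf16_ok = torch.cuda.is_bf16_supported()
        except Exception:
            bf16_ok = False
        # bfloat16 keeps float32's exponent range, so it overflows less easily than float16
        return torch.bfloat16 if bf16_ok else torch.float16
    # MPS and CPU: float16 sampling can produce NaNs/INFs, so stay in float32
    return torch.float32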
@@ -650,7 +665,8 @@ class HFChat(LLMInterface):
             return_tensors="pt",
             padding=True,
             truncation=True,
-            max_length=2048,
+            # Respect model context length when available
+            max_length=getattr(getattr(self.model, "config", None), "max_position_embeddings", 2048),
         )

         # Move inputs to device
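A sketch of how the tokenizer call picks up the context window after this change, assuming model and tokenizer are already-loaded transformers objects and prompts is a list of chat-formatted strings:

max_len = getattr(getattr(model, "config", None), "max_position_embeddings", 2048)
inputs = tokenizer(
    prompts,              # list of chat-formatted strings (assumed to exist)
    return_tensors="pt",
    padding=True,
    truncation=True,
    max_length=max_len,   # falls back to 2048 when the config has no max_position_embeddings
)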
@@ -665,6 +681,8 @@ class HFChat(LLMInterface):
             "do_sample": kwargs.get("temperature", 0.7) > 0,
             "pad_token_id": self.tokenizer.eos_token_id,
             "eos_token_id": self.tokenizer.eos_token_id,
+            # Helps avoid numerical issues in sampling when logits processors are used
+            "renormalize_logits": True,
         }

         # Handle temperature=0 for greedy decoding
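A sketch of a matching generate() call using the new flag, assuming model, tokenizer, and inputs come from the calls above; the concrete values are placeholders:

import torch

generation_kwargs = {
    "max_new_tokens": 256,
    "temperature": 0.7,
    "do_sample": True,
    "pad_token_id": tokenizer.eos_token_id,
    "eos_token_id": tokenizer.eos_token_id,
    # renormalize_logits re-normalizes scores after logits processors
    # (temperature, top-p, ...) so sampling sees a proper probability distribution
    "renormalize_logits": True,
}
with torch.no_grad():
    output_ids = model.generate(**inputs, **generation_kwargs)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))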