site-packages back

commit bc6c53edf0 (parent abc12d5069)
Author: Andy Lee
Date: 2025-10-30 22:00:34 +00:00


@@ -12,8 +12,6 @@ from typing import Any, Optional
 import torch
-from .settings import resolve_ollama_host, resolve_openai_api_key, resolve_openai_base_url
 # Configure logging
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
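Dropping this import removes the environment-based resolution of the Ollama host and the OpenAI credentials in favour of the hardcoded defaults further down. For reference, a minimal sketch of what a resolver like the removed resolve_ollama_host typically does; this is an assumption about its behaviour, not the project's actual implementation:

from typing import Optional
import os

def resolve_ollama_host(host: Optional[str] = None) -> str:
    # Hypothetical resolver: an explicit argument wins, then the standard
    # OLLAMA_HOST environment variable, then the same hardcoded default
    # the reverted code falls back to.
    return host or os.getenv("OLLAMA_HOST") or "http://localhost:11434"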
@@ -312,12 +310,11 @@ def search_hf_models(query: str, limit: int = 10) -> list[str]:
 def validate_model_and_suggest(
-    model_name: str, llm_type: str, host: Optional[str] = None
+    model_name: str, llm_type: str, host: str = "http://localhost:11434"
 ) -> Optional[str]:
     """Validate model name and provide suggestions if invalid"""
     if llm_type == "ollama":
-        resolved_host = resolve_ollama_host(host)
-        available_models = check_ollama_models(resolved_host)
+        available_models = check_ollama_models(host)
         if available_models and model_name not in available_models:
             error_msg = f"Model '{model_name}' not found in your local Ollama installation."
@@ -460,19 +457,19 @@ class LLMInterface(ABC):
 class OllamaChat(LLMInterface):
     """LLM interface for Ollama models."""
-    def __init__(self, model: str = "llama3:8b", host: Optional[str] = None):
+    def __init__(self, model: str = "llama3:8b", host: str = "http://localhost:11434"):
         self.model = model
-        self.host = resolve_ollama_host(host)
-        logger.info(f"Initializing OllamaChat with model='{model}' and host='{self.host}'")
+        self.host = host
+        logger.info(f"Initializing OllamaChat with model='{model}' and host='{host}'")
         try:
             import requests
             # Check if the Ollama server is responsive
-            if self.host:
-                requests.get(self.host)
+            if host:
+                requests.get(host)
             # Pre-check model availability with helpful suggestions
-            model_error = validate_model_and_suggest(model, "ollama", self.host)
+            model_error = validate_model_and_suggest(model, "ollama", host)
             if model_error:
                 raise ValueError(model_error)
@@ -481,11 +478,9 @@ class OllamaChat(LLMInterface):
                 "The 'requests' library is required for Ollama. Please install it with 'pip install requests'."
             )
         except requests.exceptions.ConnectionError:
-            logger.error(
-                f"Could not connect to Ollama at {self.host}. Please ensure Ollama is running."
-            )
+            logger.error(f"Could not connect to Ollama at {host}. Please ensure Ollama is running.")
             raise ConnectionError(
-                f"Could not connect to Ollama at {self.host}. Please ensure Ollama is running."
+                f"Could not connect to Ollama at {host}. Please ensure Ollama is running."
             )
     def ask(self, prompt: str, **kwargs) -> str:
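Taken together, the OllamaChat changes make the constructor talk to the supplied host directly instead of resolving it from settings. A usage sketch, assuming the module is importable and an Ollama server with the model already pulled (e.g. `ollama pull llama3:8b`) is reachable at the default address:

chat = OllamaChat(model="llama3:8b", host="http://localhost:11434")
print(chat.ask("Reply with the single word: pong"))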
@@ -582,9 +577,9 @@ class HFChat(LLMInterface):
         def timeout_handler(signum, frame):
             raise TimeoutError("Model download/loading timed out")
-        # Set timeout for model loading (60 seconds)
+        # Set timeout for model loading (increase to 300s for large models)
         old_handler = signal.signal(signal.SIGALRM, timeout_handler)
-        signal.alarm(60)
+        signal.alarm(300)
         try:
             logger.info(f"Loading tokenizer for {model_name}...")
@@ -626,8 +621,12 @@ class HFChat(LLMInterface):
             logger.error(f"Failed to load model {model_name}: {e}")
             raise
-        # Move model to device if not using device_map
-        if self.device != "cpu" and "device_map" not in str(self.model):
+        # Move model to device only if not managed by accelerate (no device_map)
+        try:
+            has_device_map = getattr(self.model, "hf_device_map", None) is not None
+        except Exception:
+            has_device_map = False
+        if self.device != "cpu" and not has_device_map:
             self.model = self.model.to(self.device)
         # Set pad token if not present
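The hf_device_map check replaces a fragile string test: models dispatched by accelerate (for example via device_map="auto") carry an hf_device_map attribute and must not be moved again with .to(). A small standalone illustration; the model name is only an example chosen for its size:

import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")  # tiny example model, no device_map
has_device_map = getattr(model, "hf_device_map", None) is not None
device = "cuda" if torch.cuda.is_available() else "cpu"
if device != "cpu" and not has_device_map:
    model = model.to(device)  # safe: this model is not managed by accelerate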
@@ -659,14 +658,15 @@ class HFChat(LLMInterface):
             # Fallback for models without chat template
             formatted_prompt = prompt
-        # Tokenize input
+        # Tokenize input (respect model context length when available)
         inputs = self.tokenizer(
             formatted_prompt,
             return_tensors="pt",
             padding=True,
             truncation=True,
-            # Respect model context length when available
-            max_length=getattr(getattr(self.model, "config", None), "max_position_embeddings", 2048),
+            max_length=getattr(
+                getattr(self.model, "config", None), "max_position_embeddings", 2048
+            ),
         )
         # Move inputs to device
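Splitting the getattr chain across lines is purely cosmetic; the truncation behaviour is unchanged. A standalone sketch of the same tokenization step with an explicit context length (gpt2 is just an example model, whose context window is 1024):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
if tok.pad_token is None:
    tok.pad_token = tok.eos_token  # mirrors the "set pad token" step elsewhere in the file

inputs = tok(
    "Hello, world",
    return_tensors="pt",
    padding=True,
    truncation=True,
    max_length=1024,  # the diff derives this from model.config.max_position_embeddings
)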
@@ -692,11 +692,39 @@ class HFChat(LLMInterface):
         logger.info(f"Generating with HuggingFace model, config: {generation_config}")
-        # Generate
+        # Streaming support (optional)
+        stream = bool(kwargs.get("stream", False))
+        if stream:
+            try:
+                from threading import Thread
+                from transformers import TextIteratorStreamer
+                streamer = TextIteratorStreamer(
+                    self.tokenizer, skip_prompt=True, skip_special_tokens=True
+                )
+                def _gen():
+                    with torch.no_grad():
+                        self.model.generate(**inputs, **generation_config, streamer=streamer)
+                t = Thread(target=_gen)
+                t.start()
+                pieces = []
+                for new_text in streamer:
+                    print(new_text, end="", flush=True)
+                    pieces.append(new_text)
+                t.join()
+                print("")  # newline after streaming
+                return ("".join(pieces)).strip()
+            except Exception as e:
+                logger.warning(f"Streaming failed, falling back to non-streaming: {e}")
+        # Non-streaming path
         with torch.no_grad():
             outputs = self.model.generate(**inputs, **generation_config)
-        # Decode response
         generated_tokens = outputs[0][inputs["input_ids"].shape[1] :]
         response = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
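TextIteratorStreamer runs generate() on a background thread and yields decoded text as it is produced, so the main thread can print tokens incrementally while still assembling the full reply. A usage sketch, assuming an already constructed HFChat instance named llm:

# Prints tokens as they arrive and still returns the complete response string.
text = llm.ask("Explain beam search in two sentences.", stream=True)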
@@ -760,31 +788,21 @@ class GeminiChat(LLMInterface):
 class OpenAIChat(LLMInterface):
     """LLM interface for OpenAI models."""
-    def __init__(
-        self,
-        model: str = "gpt-4o",
-        api_key: Optional[str] = None,
-        base_url: Optional[str] = None,
-    ):
+    def __init__(self, model: str = "gpt-4o", api_key: Optional[str] = None):
         self.model = model
-        self.base_url = resolve_openai_base_url(base_url)
-        self.api_key = resolve_openai_api_key(api_key)
+        self.api_key = api_key or os.getenv("OPENAI_API_KEY")
         if not self.api_key:
             raise ValueError(
                 "OpenAI API key is required. Set OPENAI_API_KEY environment variable or pass api_key parameter."
             )
-        logger.info(
-            "Initializing OpenAI Chat with model='%s' and base_url='%s'",
-            model,
-            self.base_url,
-        )
+        logger.info(f"Initializing OpenAI Chat with model='{model}'")
         try:
             import openai
-            self.client = openai.OpenAI(api_key=self.api_key, base_url=self.base_url)
+            self.client = openai.OpenAI(api_key=self.api_key)
         except ImportError:
             raise ImportError(
                 "The 'openai' library is required for OpenAI models. Please install it with 'pip install openai'."
@@ -874,16 +892,12 @@ def get_llm(llm_config: Optional[dict[str, Any]] = None) -> LLMInterface:
     if llm_type == "ollama":
         return OllamaChat(
             model=model or "llama3:8b",
-            host=llm_config.get("host"),
+            host=llm_config.get("host", "http://localhost:11434"),
         )
     elif llm_type == "hf":
         return HFChat(model_name=model or "deepseek-ai/deepseek-llm-7b-chat")
     elif llm_type == "openai":
-        return OpenAIChat(
-            model=model or "gpt-4o",
-            api_key=llm_config.get("api_key"),
-            base_url=llm_config.get("base_url"),
-        )
+        return OpenAIChat(model=model or "gpt-4o", api_key=llm_config.get("api_key"))
     elif llm_type == "gemini":
         return GeminiChat(model=model or "gemini-2.5-flash", api_key=llm_config.get("api_key"))
     elif llm_type == "simulated":
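The factory keeps its shape; only the argument plumbing changed. A usage sketch follows; the exact keys for the backend and model name are not visible in this hunk, so "llm_type" and "model" are assumptions, while "host" and "api_key" are confirmed by the diff:

# Hypothetical config keys except for "host".
llm = get_llm({"llm_type": "ollama", "model": "llama3:8b", "host": "http://localhost:11434"})
print(llm.ask("ping"))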