site-packages back
@@ -12,8 +12,6 @@ from typing import Any, Optional
 import torch
 
-from .settings import resolve_ollama_host, resolve_openai_api_key, resolve_openai_base_url
-
 # Configure logging
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
@@ -312,12 +310,11 @@ def search_hf_models(query: str, limit: int = 10) -> list[str]:
 
 
 def validate_model_and_suggest(
-    model_name: str, llm_type: str, host: Optional[str] = None
+    model_name: str, llm_type: str, host: str = "http://localhost:11434"
 ) -> Optional[str]:
     """Validate model name and provide suggestions if invalid"""
     if llm_type == "ollama":
-        resolved_host = resolve_ollama_host(host)
-        available_models = check_ollama_models(resolved_host)
+        available_models = check_ollama_models(host)
         if available_models and model_name not in available_models:
             error_msg = f"Model '{model_name}' not found in your local Ollama installation."
@@ -460,19 +457,19 @@ class LLMInterface(ABC):
 
 class OllamaChat(LLMInterface):
     """LLM interface for Ollama models."""
 
-    def __init__(self, model: str = "llama3:8b", host: Optional[str] = None):
+    def __init__(self, model: str = "llama3:8b", host: str = "http://localhost:11434"):
         self.model = model
-        self.host = resolve_ollama_host(host)
-        logger.info(f"Initializing OllamaChat with model='{model}' and host='{self.host}'")
+        self.host = host
+        logger.info(f"Initializing OllamaChat with model='{model}' and host='{host}'")
         try:
             import requests
 
             # Check if the Ollama server is responsive
-            if self.host:
-                requests.get(self.host)
+            if host:
+                requests.get(host)
 
             # Pre-check model availability with helpful suggestions
-            model_error = validate_model_and_suggest(model, "ollama", self.host)
+            model_error = validate_model_and_suggest(model, "ollama", host)
             if model_error:
                 raise ValueError(model_error)
@@ -481,11 +478,9 @@ class OllamaChat(LLMInterface):
                 "The 'requests' library is required for Ollama. Please install it with 'pip install requests'."
             )
         except requests.exceptions.ConnectionError:
-            logger.error(
-                f"Could not connect to Ollama at {self.host}. Please ensure Ollama is running."
-            )
+            logger.error(f"Could not connect to Ollama at {host}. Please ensure Ollama is running.")
             raise ConnectionError(
-                f"Could not connect to Ollama at {self.host}. Please ensure Ollama is running."
+                f"Could not connect to Ollama at {host}. Please ensure Ollama is running."
             )
 
     def ask(self, prompt: str, **kwargs) -> str:
@@ -582,9 +577,9 @@ class HFChat(LLMInterface):
         def timeout_handler(signum, frame):
             raise TimeoutError("Model download/loading timed out")
 
-        # Set timeout for model loading (60 seconds)
+        # Set timeout for model loading (increase to 300s for large models)
         old_handler = signal.signal(signal.SIGALRM, timeout_handler)
-        signal.alarm(60)
+        signal.alarm(300)
 
         try:
             logger.info(f"Loading tokenizer for {model_name}...")
@@ -626,8 +621,12 @@ class HFChat(LLMInterface):
             logger.error(f"Failed to load model {model_name}: {e}")
             raise
 
-        # Move model to device if not using device_map
-        if self.device != "cpu" and "device_map" not in str(self.model):
+        # Move model to device only if not managed by accelerate (no device_map)
+        try:
+            has_device_map = getattr(self.model, "hf_device_map", None) is not None
+        except Exception:
+            has_device_map = False
+        if self.device != "cpu" and not has_device_map:
             self.model = self.model.to(self.device)
 
         # Set pad token if not present
@@ -659,14 +658,15 @@ class HFChat(LLMInterface):
             # Fallback for models without chat template
             formatted_prompt = prompt
 
-        # Tokenize input
+        # Tokenize input (respect model context length when available)
         inputs = self.tokenizer(
             formatted_prompt,
             return_tensors="pt",
             padding=True,
             truncation=True,
-            # Respect model context length when available
-            max_length=getattr(getattr(self.model, "config", None), "max_position_embeddings", 2048),
+            max_length=getattr(
+                getattr(self.model, "config", None), "max_position_embeddings", 2048
+            ),
         )
 
         # Move inputs to device
@@ -692,11 +692,39 @@ class HFChat(LLMInterface):
 
         logger.info(f"Generating with HuggingFace model, config: {generation_config}")
 
-        # Generate
+        # Streaming support (optional)
+        stream = bool(kwargs.get("stream", False))
+        if stream:
+            try:
+                from threading import Thread
+
+                from transformers import TextIteratorStreamer
+
+                streamer = TextIteratorStreamer(
+                    self.tokenizer, skip_prompt=True, skip_special_tokens=True
+                )
+
+                def _gen():
+                    with torch.no_grad():
+                        self.model.generate(**inputs, **generation_config, streamer=streamer)
+
+                t = Thread(target=_gen)
+                t.start()
+
+                pieces = []
+                for new_text in streamer:
+                    print(new_text, end="", flush=True)
+                    pieces.append(new_text)
+                t.join()
+                print("")  # newline after streaming
+                return ("".join(pieces)).strip()
+            except Exception as e:
+                logger.warning(f"Streaming failed, falling back to non-streaming: {e}")
+
+        # Non-streaming path
         with torch.no_grad():
             outputs = self.model.generate(**inputs, **generation_config)
 
         # Decode response
         generated_tokens = outputs[0][inputs["input_ids"].shape[1] :]
         response = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
@@ -760,31 +788,21 @@ class GeminiChat(LLMInterface):
 class OpenAIChat(LLMInterface):
     """LLM interface for OpenAI models."""
 
-    def __init__(
-        self,
-        model: str = "gpt-4o",
-        api_key: Optional[str] = None,
-        base_url: Optional[str] = None,
-    ):
+    def __init__(self, model: str = "gpt-4o", api_key: Optional[str] = None):
         self.model = model
-        self.base_url = resolve_openai_base_url(base_url)
-        self.api_key = resolve_openai_api_key(api_key)
+        self.api_key = api_key or os.getenv("OPENAI_API_KEY")
 
         if not self.api_key:
             raise ValueError(
                 "OpenAI API key is required. Set OPENAI_API_KEY environment variable or pass api_key parameter."
             )
 
-        logger.info(
-            "Initializing OpenAI Chat with model='%s' and base_url='%s'",
-            model,
-            self.base_url,
-        )
+        logger.info(f"Initializing OpenAI Chat with model='{model}'")
 
         try:
             import openai
 
-            self.client = openai.OpenAI(api_key=self.api_key, base_url=self.base_url)
+            self.client = openai.OpenAI(api_key=self.api_key)
         except ImportError:
             raise ImportError(
                 "The 'openai' library is required for OpenAI models. Please install it with 'pip install openai'."
@@ -874,16 +892,12 @@ def get_llm(llm_config: Optional[dict[str, Any]] = None) -> LLMInterface:
     if llm_type == "ollama":
         return OllamaChat(
             model=model or "llama3:8b",
-            host=llm_config.get("host"),
+            host=llm_config.get("host", "http://localhost:11434"),
         )
     elif llm_type == "hf":
         return HFChat(model_name=model or "deepseek-ai/deepseek-llm-7b-chat")
     elif llm_type == "openai":
-        return OpenAIChat(
-            model=model or "gpt-4o",
-            api_key=llm_config.get("api_key"),
-            base_url=llm_config.get("base_url"),
-        )
+        return OpenAIChat(model=model or "gpt-4o", api_key=llm_config.get("api_key"))
     elif llm_type == "gemini":
         return GeminiChat(model=model or "gemini-2.5-flash", api_key=llm_config.get("api_key"))
     elif llm_type == "simulated":
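
For reference, a minimal usage sketch of the get_llm dispatch after this change. It is not part of the commit; the "llm_type" and "model" config keys are assumed from the surrounding code, while "host" and "api_key" mirror the hunk above.

# Hypothetical usage sketch; config keys assumed from the diff above.
llm = get_llm({"llm_type": "ollama", "model": "llama3:8b", "host": "http://localhost:11434"})
print(llm.ask("Hello"))

# OpenAI path after this commit: only api_key is read; base_url is no longer a parameter.
llm = get_llm({"llm_type": "openai", "model": "gpt-4o", "api_key": "sk-placeholder"})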