site-packages back
@@ -12,8 +12,6 @@ from typing import Any, Optional
 
 import torch
 
-from .settings import resolve_ollama_host, resolve_openai_api_key, resolve_openai_base_url
-
 # Configure logging
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
@@ -312,12 +310,11 @@ def search_hf_models(query: str, limit: int = 10) -> list[str]:
 
 
 def validate_model_and_suggest(
-    model_name: str, llm_type: str, host: Optional[str] = None
+    model_name: str, llm_type: str, host: str = "http://localhost:11434"
 ) -> Optional[str]:
     """Validate model name and provide suggestions if invalid"""
     if llm_type == "ollama":
-        resolved_host = resolve_ollama_host(host)
-        available_models = check_ollama_models(resolved_host)
+        available_models = check_ollama_models(host)
         if available_models and model_name not in available_models:
             error_msg = f"Model '{model_name}' not found in your local Ollama installation."
 
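Note: check_ollama_models is defined elsewhere in this module and is not shown in the hunk; the sketch below is an assumption of how a local model list can be fetched from the Ollama HTTP API (GET /api/tags), using the same default host the new signature hardcodes.

    # Sketch only: assumes the standard Ollama REST endpoint /api/tags;
    # the real check_ollama_models in this module may differ.
    import requests

    def list_local_ollama_models(host: str = "http://localhost:11434") -> list[str]:
        try:
            resp = requests.get(f"{host}/api/tags", timeout=5)
            resp.raise_for_status()
            return [m["name"] for m in resp.json().get("models", [])]
        except requests.RequestException:
            return []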
@@ -460,19 +457,19 @@ class LLMInterface(ABC):
 class OllamaChat(LLMInterface):
     """LLM interface for Ollama models."""
 
-    def __init__(self, model: str = "llama3:8b", host: Optional[str] = None):
+    def __init__(self, model: str = "llama3:8b", host: str = "http://localhost:11434"):
         self.model = model
-        self.host = resolve_ollama_host(host)
-        logger.info(f"Initializing OllamaChat with model='{model}' and host='{self.host}'")
+        self.host = host
+        logger.info(f"Initializing OllamaChat with model='{model}' and host='{host}'")
         try:
             import requests
 
             # Check if the Ollama server is responsive
-            if self.host:
-                requests.get(self.host)
+            if host:
+                requests.get(host)
 
             # Pre-check model availability with helpful suggestions
-            model_error = validate_model_and_suggest(model, "ollama", self.host)
+            model_error = validate_model_and_suggest(model, "ollama", host)
             if model_error:
                 raise ValueError(model_error)
 
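Note: with this change the host is no longer resolved from .settings; a minimal usage sketch under the new signature (the model name and host are illustrative and require a running Ollama server with the model pulled locally).

    # Illustrative usage only; assumes the module above is importable and Ollama is running.
    chat = OllamaChat(model="llama3:8b", host="http://localhost:11434")
    print(chat.ask("Hello"))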
@@ -481,11 +478,9 @@ class OllamaChat(LLMInterface):
                 "The 'requests' library is required for Ollama. Please install it with 'pip install requests'."
             )
         except requests.exceptions.ConnectionError:
-            logger.error(
-                f"Could not connect to Ollama at {self.host}. Please ensure Ollama is running."
-            )
+            logger.error(f"Could not connect to Ollama at {host}. Please ensure Ollama is running.")
             raise ConnectionError(
-                f"Could not connect to Ollama at {self.host}. Please ensure Ollama is running."
+                f"Could not connect to Ollama at {host}. Please ensure Ollama is running."
             )
 
     def ask(self, prompt: str, **kwargs) -> str:
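Note: the liveness probe above is a bare requests.get(host) with no timeout; a minimal sketch of the same reachability check with an explicit timeout added (the timeout value is an assumption, not part of this change).

    # Sketch of the reachability check; timeout=5 is illustrative and not in the diff.
    import requests

    def ollama_is_reachable(host: str = "http://localhost:11434") -> bool:
        try:
            requests.get(host, timeout=5)
            return True
        except requests.RequestException:
            return False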
@@ -582,9 +577,9 @@ class HFChat(LLMInterface):
         def timeout_handler(signum, frame):
             raise TimeoutError("Model download/loading timed out")
 
-        # Set timeout for model loading (60 seconds)
+        # Set timeout for model loading (increase to 300s for large models)
         old_handler = signal.signal(signal.SIGALRM, timeout_handler)
-        signal.alarm(60)
+        signal.alarm(300)
 
         try:
             logger.info(f"Loading tokenizer for {model_name}...")
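Note: the larger timeout still relies on SIGALRM, which only fires on Unix and only in the main thread; a minimal sketch of the pattern used here (the sleep stands in for slow model loading).

    # Sketch of the SIGALRM timeout pattern (Unix, main thread only).
    import signal
    import time

    def timeout_handler(signum, frame):
        raise TimeoutError("Model download/loading timed out")

    old_handler = signal.signal(signal.SIGALRM, timeout_handler)
    signal.alarm(300)  # raise TimeoutError if the block below takes longer than 300s
    try:
        time.sleep(1)  # stand-in for tokenizer/model loading
    finally:
        signal.alarm(0)  # cancel the pending alarm
        signal.signal(signal.SIGALRM, old_handler)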
@@ -626,8 +621,12 @@ class HFChat(LLMInterface):
             logger.error(f"Failed to load model {model_name}: {e}")
             raise
 
-        # Move model to device if not using device_map
-        if self.device != "cpu" and "device_map" not in str(self.model):
+        # Move model to device only if not managed by accelerate (no device_map)
+        try:
+            has_device_map = getattr(self.model, "hf_device_map", None) is not None
+        except Exception:
+            has_device_map = False
+        if self.device != "cpu" and not has_device_map:
             self.model = self.model.to(self.device)
 
         # Set pad token if not present
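Note: models loaded through accelerate (for example with device_map="auto") carry an hf_device_map attribute and should not be moved again with .to(); a small sketch of the check introduced here, with a dummy object standing in for the loaded model.

    # Sketch: hf_device_map is set by accelerate when a device_map is used;
    # DummyModel is illustrative only.
    class DummyModel:
        hf_device_map = {"": 0}  # a single-GPU placement as accelerate might record it

    def managed_by_accelerate(model) -> bool:
        try:
            return getattr(model, "hf_device_map", None) is not None
        except Exception:
            return False

    print(managed_by_accelerate(DummyModel()))  # True -> skip model.to(device)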
@@ -659,14 +658,15 @@ class HFChat(LLMInterface):
             # Fallback for models without chat template
             formatted_prompt = prompt
 
-        # Tokenize input
+        # Tokenize input (respect model context length when available)
         inputs = self.tokenizer(
             formatted_prompt,
             return_tensors="pt",
             padding=True,
             truncation=True,
-            # Respect model context length when available
-            max_length=getattr(getattr(self.model, "config", None), "max_position_embeddings", 2048),
+            max_length=getattr(
+                getattr(self.model, "config", None), "max_position_embeddings", 2048
+            ),
         )
 
         # Move inputs to device
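Note: the nested getattr falls back to 2048 tokens when the model config does not expose max_position_embeddings; a small illustration with stand-in objects.

    # Illustration of the max_length fallback; SimpleNamespace stands in for model/config.
    from types import SimpleNamespace

    with_cfg = SimpleNamespace(config=SimpleNamespace(max_position_embeddings=4096))
    without_cfg = SimpleNamespace(config=None)

    for m in (with_cfg, without_cfg):
        max_length = getattr(getattr(m, "config", None), "max_position_embeddings", 2048)
        print(max_length)  # 4096, then 2048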
@@ -692,11 +692,39 @@ class HFChat(LLMInterface):
 
         logger.info(f"Generating with HuggingFace model, config: {generation_config}")
 
-        # Generate
+        # Streaming support (optional)
+        stream = bool(kwargs.get("stream", False))
+        if stream:
+            try:
+                from threading import Thread
+
+                from transformers import TextIteratorStreamer
+
+                streamer = TextIteratorStreamer(
+                    self.tokenizer, skip_prompt=True, skip_special_tokens=True
+                )
+
+                def _gen():
+                    with torch.no_grad():
+                        self.model.generate(**inputs, **generation_config, streamer=streamer)
+
+                t = Thread(target=_gen)
+                t.start()
+
+                pieces = []
+                for new_text in streamer:
+                    print(new_text, end="", flush=True)
+                    pieces.append(new_text)
+                t.join()
+                print("")  # newline after streaming
+                return ("".join(pieces)).strip()
+            except Exception as e:
+                logger.warning(f"Streaming failed, falling back to non-streaming: {e}")
+
+        # Non-streaming path
         with torch.no_grad():
             outputs = self.model.generate(**inputs, **generation_config)
 
-        # Decode response
         generated_tokens = outputs[0][inputs["input_ids"].shape[1] :]
         response = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
 
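Note: the streaming path added here follows the standard transformers TextIteratorStreamer pattern: generate runs in a background thread while the main thread consumes decoded text. A standalone sketch follows (the tiny checkpoint name is illustrative, not the model this module targets).

    # Standalone sketch of the TextIteratorStreamer pattern.
    from threading import Thread

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

    tokenizer = AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2")
    model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")

    inputs = tokenizer("Hello", return_tensors="pt")
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    def _generate():
        with torch.no_grad():
            model.generate(**inputs, max_new_tokens=20, streamer=streamer)

    thread = Thread(target=_generate)
    thread.start()
    pieces = []
    for text in streamer:
        print(text, end="", flush=True)
        pieces.append(text)
    thread.join()
    print()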
@@ -760,31 +788,21 @@ class GeminiChat(LLMInterface):
 class OpenAIChat(LLMInterface):
     """LLM interface for OpenAI models."""
 
-    def __init__(
-        self,
-        model: str = "gpt-4o",
-        api_key: Optional[str] = None,
-        base_url: Optional[str] = None,
-    ):
+    def __init__(self, model: str = "gpt-4o", api_key: Optional[str] = None):
         self.model = model
-        self.base_url = resolve_openai_base_url(base_url)
-        self.api_key = resolve_openai_api_key(api_key)
+        self.api_key = api_key or os.getenv("OPENAI_API_KEY")
 
         if not self.api_key:
             raise ValueError(
                 "OpenAI API key is required. Set OPENAI_API_KEY environment variable or pass api_key parameter."
             )
 
-        logger.info(
-            "Initializing OpenAI Chat with model='%s' and base_url='%s'",
-            model,
-            self.base_url,
-        )
+        logger.info(f"Initializing OpenAI Chat with model='{model}'")
 
         try:
             import openai
 
-            self.client = openai.OpenAI(api_key=self.api_key, base_url=self.base_url)
+            self.client = openai.OpenAI(api_key=self.api_key)
         except ImportError:
             raise ImportError(
                 "The 'openai' library is required for OpenAI models. Please install it with 'pip install openai'."
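Note: with base_url support removed, the API key is the only external input; a minimal sketch of the fallback behavior this hunk restores (the env var name matches the error message above).

    # Sketch of the api_key fallback restored in this hunk.
    import os

    def resolve_key(api_key=None):
        key = api_key or os.getenv("OPENAI_API_KEY")
        if not key:
            raise ValueError(
                "OpenAI API key is required. Set OPENAI_API_KEY environment variable or pass api_key parameter."
            )
        return key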
@@ -874,16 +892,12 @@ def get_llm(llm_config: Optional[dict[str, Any]] = None) -> LLMInterface:
     if llm_type == "ollama":
         return OllamaChat(
             model=model or "llama3:8b",
-            host=llm_config.get("host"),
+            host=llm_config.get("host", "http://localhost:11434"),
         )
     elif llm_type == "hf":
         return HFChat(model_name=model or "deepseek-ai/deepseek-llm-7b-chat")
     elif llm_type == "openai":
-        return OpenAIChat(
-            model=model or "gpt-4o",
-            api_key=llm_config.get("api_key"),
-            base_url=llm_config.get("base_url"),
-        )
+        return OpenAIChat(model=model or "gpt-4o", api_key=llm_config.get("api_key"))
     elif llm_type == "gemini":
         return GeminiChat(model=model or "gemini-2.5-flash", api_key=llm_config.get("api_key"))
     elif llm_type == "simulated":
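Note: a hedged usage sketch of the factory after this change; only the "host" and "api_key" keys are visible in this hunk, so the "type" and "model" key names below are assumptions.

    # Assumed config shape; key names other than "host"/"api_key" are guesses.
    config = {"type": "ollama", "model": "llama3:8b", "host": "http://localhost:11434"}
    llm = get_llm(config)
    print(llm.ask("Say hello"))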