From bc6c53edf04e744be10773e3299d736876a43dfd Mon Sep 17 00:00:00 2001
From: Andy Lee
Date: Thu, 30 Oct 2025 22:00:34 +0000
Subject: [PATCH] site-packages back

---
 packages/leann-core/src/leann/chat.py | 102 +++++++++++++++-----------
 1 file changed, 58 insertions(+), 44 deletions(-)

diff --git a/packages/leann-core/src/leann/chat.py b/packages/leann-core/src/leann/chat.py
index 12960b9..5e1796c 100644
--- a/packages/leann-core/src/leann/chat.py
+++ b/packages/leann-core/src/leann/chat.py
@@ -12,8 +12,6 @@ from typing import Any, Optional
 
 import torch
 
-from .settings import resolve_ollama_host, resolve_openai_api_key, resolve_openai_base_url
-
 # Configure logging
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
@@ -312,12 +310,11 @@ def search_hf_models(query: str, limit: int = 10) -> list[str]:
 
 
 def validate_model_and_suggest(
-    model_name: str, llm_type: str, host: Optional[str] = None
+    model_name: str, llm_type: str, host: str = "http://localhost:11434"
 ) -> Optional[str]:
     """Validate model name and provide suggestions if invalid"""
     if llm_type == "ollama":
-        resolved_host = resolve_ollama_host(host)
-        available_models = check_ollama_models(resolved_host)
+        available_models = check_ollama_models(host)
 
         if available_models and model_name not in available_models:
             error_msg = f"Model '{model_name}' not found in your local Ollama installation."
@@ -460,19 +457,19 @@ class LLMInterface(ABC):
 class OllamaChat(LLMInterface):
     """LLM interface for Ollama models."""
 
-    def __init__(self, model: str = "llama3:8b", host: Optional[str] = None):
+    def __init__(self, model: str = "llama3:8b", host: str = "http://localhost:11434"):
         self.model = model
-        self.host = resolve_ollama_host(host)
-        logger.info(f"Initializing OllamaChat with model='{model}' and host='{self.host}'")
+        self.host = host
+        logger.info(f"Initializing OllamaChat with model='{model}' and host='{host}'")
         try:
             import requests
 
             # Check if the Ollama server is responsive
-            if self.host:
-                requests.get(self.host)
+            if host:
+                requests.get(host)
 
             # Pre-check model availability with helpful suggestions
-            model_error = validate_model_and_suggest(model, "ollama", self.host)
+            model_error = validate_model_and_suggest(model, "ollama", host)
             if model_error:
                 raise ValueError(model_error)
 
@@ -481,11 +478,9 @@ class OllamaChat(LLMInterface):
                 "The 'requests' library is required for Ollama. Please install it with 'pip install requests'."
             )
         except requests.exceptions.ConnectionError:
-            logger.error(
-                f"Could not connect to Ollama at {self.host}. Please ensure Ollama is running."
-            )
+            logger.error(f"Could not connect to Ollama at {host}. Please ensure Ollama is running.")
             raise ConnectionError(
-                f"Could not connect to Ollama at {self.host}. Please ensure Ollama is running."
+                f"Could not connect to Ollama at {host}. Please ensure Ollama is running."
             )
 
     def ask(self, prompt: str, **kwargs) -> str:
@@ -582,9 +577,9 @@ class HFChat(LLMInterface):
         def timeout_handler(signum, frame):
             raise TimeoutError("Model download/loading timed out")
 
-        # Set timeout for model loading (60 seconds)
+        # Set timeout for model loading (increase to 300s for large models)
         old_handler = signal.signal(signal.SIGALRM, timeout_handler)
-        signal.alarm(60)
+        signal.alarm(300)
 
         try:
             logger.info(f"Loading tokenizer for {model_name}...")
@@ -626,8 +621,12 @@ class HFChat(LLMInterface):
             logger.error(f"Failed to load model {model_name}: {e}")
             raise
 
-        # Move model to device if not using device_map
-        if self.device != "cpu" and "device_map" not in str(self.model):
+        # Move model to device only if not managed by accelerate (no device_map)
+        try:
+            has_device_map = getattr(self.model, "hf_device_map", None) is not None
+        except Exception:
+            has_device_map = False
+        if self.device != "cpu" and not has_device_map:
             self.model = self.model.to(self.device)
 
         # Set pad token if not present
@@ -659,14 +658,15 @@ class HFChat(LLMInterface):
             # Fallback for models without chat template
             formatted_prompt = prompt
 
-        # Tokenize input
+        # Tokenize input (respect model context length when available)
         inputs = self.tokenizer(
             formatted_prompt,
             return_tensors="pt",
             padding=True,
             truncation=True,
-            # Respect model context length when available
-            max_length=getattr(getattr(self.model, "config", None), "max_position_embeddings", 2048),
+            max_length=getattr(
+                getattr(self.model, "config", None), "max_position_embeddings", 2048
+            ),
         )
 
         # Move inputs to device
@@ -692,11 +692,39 @@ class HFChat(LLMInterface):
 
         logger.info(f"Generating with HuggingFace model, config: {generation_config}")
 
-        # Generate
+        # Streaming support (optional)
+        stream = bool(kwargs.get("stream", False))
+        if stream:
+            try:
+                from threading import Thread
+
+                from transformers import TextIteratorStreamer
+
+                streamer = TextIteratorStreamer(
+                    self.tokenizer, skip_prompt=True, skip_special_tokens=True
+                )
+
+                def _gen():
+                    with torch.no_grad():
+                        self.model.generate(**inputs, **generation_config, streamer=streamer)
+
+                t = Thread(target=_gen)
+                t.start()
+
+                pieces = []
+                for new_text in streamer:
+                    print(new_text, end="", flush=True)
+                    pieces.append(new_text)
+                t.join()
+                print("")  # newline after streaming
+                return ("".join(pieces)).strip()
+            except Exception as e:
+                logger.warning(f"Streaming failed, falling back to non-streaming: {e}")
+
+        # Non-streaming path
         with torch.no_grad():
             outputs = self.model.generate(**inputs, **generation_config)
 
-        # Decode response
         generated_tokens = outputs[0][inputs["input_ids"].shape[1] :]
         response = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
 
@@ -760,31 +788,21 @@ class GeminiChat(LLMInterface):
 
 class OpenAIChat(LLMInterface):
     """LLM interface for OpenAI models."""
 
-    def __init__(
-        self,
-        model: str = "gpt-4o",
-        api_key: Optional[str] = None,
-        base_url: Optional[str] = None,
-    ):
+    def __init__(self, model: str = "gpt-4o", api_key: Optional[str] = None):
         self.model = model
-        self.base_url = resolve_openai_base_url(base_url)
-        self.api_key = resolve_openai_api_key(api_key)
+        self.api_key = api_key or os.getenv("OPENAI_API_KEY")
         if not self.api_key:
             raise ValueError(
                 "OpenAI API key is required. Set OPENAI_API_KEY environment variable or pass api_key parameter."
             )
-        logger.info(
-            "Initializing OpenAI Chat with model='%s' and base_url='%s'",
-            model,
-            self.base_url,
-        )
+        logger.info(f"Initializing OpenAI Chat with model='{model}'")
 
         try:
            import openai
 
-            self.client = openai.OpenAI(api_key=self.api_key, base_url=self.base_url)
+            self.client = openai.OpenAI(api_key=self.api_key)
         except ImportError:
             raise ImportError(
                 "The 'openai' library is required for OpenAI models. Please install it with 'pip install openai'."
             )
@@ -874,16 +892,12 @@ def get_llm(llm_config: Optional[dict[str, Any]] = None) -> LLMInterface:
     if llm_type == "ollama":
         return OllamaChat(
             model=model or "llama3:8b",
-            host=llm_config.get("host"),
+            host=llm_config.get("host", "http://localhost:11434"),
         )
     elif llm_type == "hf":
         return HFChat(model_name=model or "deepseek-ai/deepseek-llm-7b-chat")
     elif llm_type == "openai":
-        return OpenAIChat(
-            model=model or "gpt-4o",
-            api_key=llm_config.get("api_key"),
-            base_url=llm_config.get("base_url"),
-        )
+        return OpenAIChat(model=model or "gpt-4o", api_key=llm_config.get("api_key"))
     elif llm_type == "gemini":
         return GeminiChat(model=model or "gemini-2.5-flash", api_key=llm_config.get("api_key"))
     elif llm_type == "simulated":
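
A minimal usage sketch of the factory after this patch, not part of the diff itself. It assumes the module is importable as leann.chat, that an Ollama server is listening on the hard-coded default http://localhost:11434, and that llm_config uses "type" and "model" keys (the diff only shows the "host" and "api_key" lookups); the placeholder API key is illustrative.

# Usage sketch (assumptions noted above; not part of the patch).
from leann.chat import get_llm

# Ollama: "host" now falls back to the literal default instead of resolve_ollama_host().
ollama_llm = get_llm({"type": "ollama", "model": "llama3:8b"})
print(ollama_llm.ask("Say hello in one short sentence."))

# OpenAI: base_url is gone; the key comes from the config or the OPENAI_API_KEY env var.
openai_llm = get_llm({"type": "openai", "model": "gpt-4o", "api_key": "sk-..."})
print(openai_llm.ask("Say hello in one short sentence."))

# HFChat.ask() now accepts stream=True and prints tokens as they are generated,
# falling back to the non-streaming path if the streamer setup fails.
# hf_llm = get_llm({"type": "hf", "model": "deepseek-ai/deepseek-llm-7b-chat"})
# hf_llm.ask("Say hello in one short sentence.", stream=True)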