fix: resolve all ruff linting errors and add lint CI check
- Fix ambiguous fullwidth characters (commas, parentheses) in strings and comments - Replace Chinese comments with English equivalents - Fix unused imports with proper noqa annotations for intentional imports - Fix bare except clauses with specific exception types - Fix redefined variables and undefined names - Add ruff noqa annotations for generated protobuf files - Add lint and format check to GitHub Actions CI pipeline
This commit is contained in:
@@ -4,11 +4,12 @@ This file contains the chat generation logic for the LEANN project,
|
||||
supporting different backends like Ollama, Hugging Face Transformers, and a simulation mode.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, Any, Optional, List
|
||||
import difflib
|
||||
import logging
|
||||
import os
|
||||
import difflib
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
# Configure logging
|
||||
@@ -16,10 +17,11 @@ logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def check_ollama_models() -> List[str]:
|
||||
def check_ollama_models() -> list[str]:
|
||||
"""Check available Ollama models and return a list"""
|
||||
try:
|
||||
import requests
|
||||
|
||||
response = requests.get("http://localhost:11434/api/tags", timeout=5)
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
@@ -31,51 +33,52 @@ def check_ollama_models() -> List[str]:
|
||||
|
||||
def check_ollama_model_exists_remotely(model_name: str) -> tuple[bool, list[str]]:
|
||||
"""Check if a model exists in Ollama's remote library and return available tags
|
||||
|
||||
|
||||
Returns:
|
||||
(model_exists, available_tags): bool and list of matching tags
|
||||
"""
|
||||
try:
|
||||
import requests
|
||||
import re
|
||||
|
||||
|
||||
import requests
|
||||
|
||||
# Split model name and tag
|
||||
if ':' in model_name:
|
||||
base_model, requested_tag = model_name.split(':', 1)
|
||||
if ":" in model_name:
|
||||
base_model, requested_tag = model_name.split(":", 1)
|
||||
else:
|
||||
base_model, requested_tag = model_name, None
|
||||
|
||||
|
||||
# First check if base model exists in library
|
||||
library_response = requests.get("https://ollama.com/library", timeout=8)
|
||||
if library_response.status_code != 200:
|
||||
return True, [] # Assume exists if can't check
|
||||
|
||||
|
||||
# Extract model names from library page
|
||||
models_in_library = re.findall(r'href="/library/([^"]+)"', library_response.text)
|
||||
|
||||
|
||||
if base_model not in models_in_library:
|
||||
return False, [] # Base model doesn't exist
|
||||
|
||||
|
||||
# If base model exists, get available tags
|
||||
tags_response = requests.get(f"https://ollama.com/library/{base_model}/tags", timeout=8)
|
||||
if tags_response.status_code != 200:
|
||||
return True, [] # Base model exists but can't get tags
|
||||
|
||||
|
||||
# Extract tags for this model - be more specific to avoid HTML artifacts
|
||||
tag_pattern = rf'{re.escape(base_model)}:[a-zA-Z0-9\.\-_]+'
|
||||
tag_pattern = rf"{re.escape(base_model)}:[a-zA-Z0-9\.\-_]+"
|
||||
raw_tags = re.findall(tag_pattern, tags_response.text)
|
||||
|
||||
|
||||
# Clean up tags - remove HTML artifacts and duplicates
|
||||
available_tags = []
|
||||
seen = set()
|
||||
for tag in raw_tags:
|
||||
# Skip if it looks like HTML (contains < or >)
|
||||
if '<' in tag or '>' in tag:
|
||||
if "<" in tag or ">" in tag:
|
||||
continue
|
||||
if tag not in seen:
|
||||
seen.add(tag)
|
||||
available_tags.append(tag)
|
||||
|
||||
|
||||
# Check if exact model exists
|
||||
if requested_tag is None:
|
||||
# User just requested base model, suggest tags
|
||||
@@ -83,76 +86,80 @@ def check_ollama_model_exists_remotely(model_name: str) -> tuple[bool, list[str]
|
||||
else:
|
||||
exact_match = model_name in available_tags
|
||||
return exact_match, available_tags[:10]
|
||||
|
||||
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
# If scraping fails, assume model might exist (don't block user)
|
||||
return True, []
|
||||
|
||||
|
||||
def search_ollama_models_fuzzy(query: str, available_models: List[str]) -> List[str]:
|
||||
def search_ollama_models_fuzzy(query: str, available_models: list[str]) -> list[str]:
|
||||
"""Use intelligent fuzzy search for Ollama models"""
|
||||
if not available_models:
|
||||
return []
|
||||
|
||||
|
||||
query_lower = query.lower()
|
||||
suggestions = []
|
||||
|
||||
|
||||
# 1. Exact matches first
|
||||
exact_matches = [m for m in available_models if query_lower == m.lower()]
|
||||
suggestions.extend(exact_matches)
|
||||
|
||||
|
||||
# 2. Starts with query
|
||||
starts_with = [m for m in available_models if m.lower().startswith(query_lower) and m not in suggestions]
|
||||
starts_with = [
|
||||
m for m in available_models if m.lower().startswith(query_lower) and m not in suggestions
|
||||
]
|
||||
suggestions.extend(starts_with)
|
||||
|
||||
|
||||
# 3. Contains query
|
||||
contains = [m for m in available_models if query_lower in m.lower() and m not in suggestions]
|
||||
suggestions.extend(contains)
|
||||
|
||||
|
||||
# 4. Base model name matching (remove version numbers)
|
||||
def get_base_name(model_name: str) -> str:
|
||||
"""Extract base name without version (e.g., 'llama3:8b' -> 'llama3')"""
|
||||
return model_name.split(':')[0].split('-')[0]
|
||||
|
||||
return model_name.split(":")[0].split("-")[0]
|
||||
|
||||
query_base = get_base_name(query_lower)
|
||||
base_matches = [
|
||||
m for m in available_models
|
||||
m
|
||||
for m in available_models
|
||||
if get_base_name(m.lower()) == query_base and m not in suggestions
|
||||
]
|
||||
suggestions.extend(base_matches)
|
||||
|
||||
|
||||
# 5. Family/variant matching
|
||||
model_families = {
|
||||
'llama': ['llama2', 'llama3', 'alpaca', 'vicuna', 'codellama'],
|
||||
'qwen': ['qwen', 'qwen2', 'qwen3'],
|
||||
'gemma': ['gemma', 'gemma2'],
|
||||
'phi': ['phi', 'phi2', 'phi3'],
|
||||
'mistral': ['mistral', 'mixtral', 'openhermes'],
|
||||
'dolphin': ['dolphin', 'openchat'],
|
||||
'deepseek': ['deepseek', 'deepseek-coder']
|
||||
"llama": ["llama2", "llama3", "alpaca", "vicuna", "codellama"],
|
||||
"qwen": ["qwen", "qwen2", "qwen3"],
|
||||
"gemma": ["gemma", "gemma2"],
|
||||
"phi": ["phi", "phi2", "phi3"],
|
||||
"mistral": ["mistral", "mixtral", "openhermes"],
|
||||
"dolphin": ["dolphin", "openchat"],
|
||||
"deepseek": ["deepseek", "deepseek-coder"],
|
||||
}
|
||||
|
||||
|
||||
query_family = None
|
||||
for family, variants in model_families.items():
|
||||
if any(variant in query_lower for variant in variants):
|
||||
query_family = family
|
||||
break
|
||||
|
||||
|
||||
if query_family:
|
||||
family_variants = model_families[query_family]
|
||||
family_matches = [
|
||||
m for m in available_models
|
||||
m
|
||||
for m in available_models
|
||||
if any(variant in m.lower() for variant in family_variants) and m not in suggestions
|
||||
]
|
||||
suggestions.extend(family_matches)
|
||||
|
||||
|
||||
# 6. Use difflib for remaining fuzzy matches
|
||||
remaining_models = [m for m in available_models if m not in suggestions]
|
||||
difflib_matches = difflib.get_close_matches(query_lower, remaining_models, n=3, cutoff=0.4)
|
||||
suggestions.extend(difflib_matches)
|
||||
|
||||
|
||||
return suggestions[:8] # Return top 8 suggestions
|
||||
|
||||
|
||||
@@ -162,15 +169,13 @@ def search_ollama_models_fuzzy(query: str, available_models: List[str]) -> List[
|
||||
# Remove this too - no need for fallback
|
||||
|
||||
|
||||
def suggest_similar_models(invalid_model: str, available_models: List[str]) -> List[str]:
|
||||
def suggest_similar_models(invalid_model: str, available_models: list[str]) -> list[str]:
|
||||
"""Use difflib to find similar model names"""
|
||||
if not available_models:
|
||||
return []
|
||||
|
||||
|
||||
# Get close matches using fuzzy matching
|
||||
suggestions = difflib.get_close_matches(
|
||||
invalid_model, available_models, n=3, cutoff=0.3
|
||||
)
|
||||
suggestions = difflib.get_close_matches(invalid_model, available_models, n=3, cutoff=0.3)
|
||||
return suggestions
|
||||
|
||||
|
||||
@@ -178,49 +183,50 @@ def check_hf_model_exists(model_name: str) -> bool:
|
||||
"""Quick check if HuggingFace model exists without downloading"""
|
||||
try:
|
||||
from huggingface_hub import model_info
|
||||
|
||||
model_info(model_name)
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def get_popular_hf_models() -> List[str]:
|
||||
def get_popular_hf_models() -> list[str]:
|
||||
"""Return a list of popular HuggingFace models for suggestions"""
|
||||
try:
|
||||
from huggingface_hub import list_models
|
||||
|
||||
|
||||
# Get popular text-generation models, sorted by downloads
|
||||
models = list_models(
|
||||
filter="text-generation",
|
||||
sort="downloads",
|
||||
direction=-1,
|
||||
limit=20 # Get top 20 most downloaded
|
||||
limit=20, # Get top 20 most downloaded
|
||||
)
|
||||
|
||||
|
||||
# Extract model names and filter for chat/conversation models
|
||||
model_names = []
|
||||
chat_keywords = ['chat', 'instruct', 'dialog', 'conversation', 'assistant']
|
||||
|
||||
chat_keywords = ["chat", "instruct", "dialog", "conversation", "assistant"]
|
||||
|
||||
for model in models:
|
||||
model_name = model.id if hasattr(model, 'id') else str(model)
|
||||
model_name = model.id if hasattr(model, "id") else str(model)
|
||||
# Prioritize models with chat-related keywords
|
||||
if any(keyword in model_name.lower() for keyword in chat_keywords):
|
||||
model_names.append(model_name)
|
||||
elif len(model_names) < 10: # Fill up with other popular models
|
||||
model_names.append(model_name)
|
||||
|
||||
|
||||
return model_names[:10] if model_names else _get_fallback_hf_models()
|
||||
|
||||
|
||||
except Exception:
|
||||
# Fallback to static list if API call fails
|
||||
return _get_fallback_hf_models()
|
||||
|
||||
|
||||
def _get_fallback_hf_models() -> List[str]:
|
||||
def _get_fallback_hf_models() -> list[str]:
|
||||
"""Fallback list of popular HuggingFace models"""
|
||||
return [
|
||||
"microsoft/DialoGPT-medium",
|
||||
"microsoft/DialoGPT-large",
|
||||
"microsoft/DialoGPT-large",
|
||||
"facebook/blenderbot-400M-distill",
|
||||
"microsoft/phi-2",
|
||||
"deepseek-ai/deepseek-llm-7b-chat",
|
||||
@@ -228,44 +234,40 @@ def _get_fallback_hf_models() -> List[str]:
|
||||
"facebook/blenderbot_small-90M",
|
||||
"microsoft/phi-1_5",
|
||||
"facebook/opt-350m",
|
||||
"EleutherAI/gpt-neo-1.3B"
|
||||
"EleutherAI/gpt-neo-1.3B",
|
||||
]
|
||||
|
||||
|
||||
def search_hf_models_fuzzy(query: str, limit: int = 10) -> List[str]:
|
||||
def search_hf_models_fuzzy(query: str, limit: int = 10) -> list[str]:
|
||||
"""Use HuggingFace Hub's native fuzzy search for model suggestions"""
|
||||
try:
|
||||
from huggingface_hub import list_models
|
||||
|
||||
|
||||
# HF Hub's search is already fuzzy! It handles typos and partial matches
|
||||
models = list_models(
|
||||
search=query,
|
||||
filter="text-generation",
|
||||
sort="downloads",
|
||||
direction=-1,
|
||||
limit=limit
|
||||
search=query, filter="text-generation", sort="downloads", direction=-1, limit=limit
|
||||
)
|
||||
|
||||
model_names = [model.id if hasattr(model, 'id') else str(model) for model in models]
|
||||
|
||||
|
||||
model_names = [model.id if hasattr(model, "id") else str(model) for model in models]
|
||||
|
||||
# If direct search doesn't return enough results, try some variations
|
||||
if len(model_names) < 3:
|
||||
# Try searching for partial matches or common variations
|
||||
variations = []
|
||||
|
||||
|
||||
# Extract base name (e.g., "gpt3" from "gpt-3.5")
|
||||
base_query = query.lower().replace('-', '').replace('.', '').replace('_', '')
|
||||
base_query = query.lower().replace("-", "").replace(".", "").replace("_", "")
|
||||
if base_query != query.lower():
|
||||
variations.append(base_query)
|
||||
|
||||
|
||||
# Try common model name patterns
|
||||
if 'gpt' in query.lower():
|
||||
variations.extend(['gpt2', 'gpt-neo', 'gpt-j', 'dialoGPT'])
|
||||
elif 'llama' in query.lower():
|
||||
variations.extend(['llama2', 'alpaca', 'vicuna'])
|
||||
elif 'bert' in query.lower():
|
||||
variations.extend(['roberta', 'distilbert', 'albert'])
|
||||
|
||||
if "gpt" in query.lower():
|
||||
variations.extend(["gpt2", "gpt-neo", "gpt-j", "dialoGPT"])
|
||||
elif "llama" in query.lower():
|
||||
variations.extend(["llama2", "alpaca", "vicuna"])
|
||||
elif "bert" in query.lower():
|
||||
variations.extend(["roberta", "distilbert", "albert"])
|
||||
|
||||
# Search with variations
|
||||
for var in variations[:2]: # Limit to 2 variations to avoid too many API calls
|
||||
try:
|
||||
@@ -274,13 +276,15 @@ def search_hf_models_fuzzy(query: str, limit: int = 10) -> List[str]:
|
||||
filter="text-generation",
|
||||
sort="downloads",
|
||||
direction=-1,
|
||||
limit=3
|
||||
limit=3,
|
||||
)
|
||||
var_names = [model.id if hasattr(model, 'id') else str(model) for model in var_models]
|
||||
var_names = [
|
||||
model.id if hasattr(model, "id") else str(model) for model in var_models
|
||||
]
|
||||
model_names.extend(var_names)
|
||||
except:
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
|
||||
# Remove duplicates while preserving order
|
||||
seen = set()
|
||||
unique_models = []
|
||||
@@ -288,65 +292,67 @@ def search_hf_models_fuzzy(query: str, limit: int = 10) -> List[str]:
|
||||
if model not in seen:
|
||||
seen.add(model)
|
||||
unique_models.append(model)
|
||||
|
||||
|
||||
return unique_models[:limit]
|
||||
|
||||
|
||||
except Exception:
|
||||
# If search fails, return empty list
|
||||
return []
|
||||
|
||||
|
||||
def search_hf_models(query: str, limit: int = 10) -> List[str]:
|
||||
def search_hf_models(query: str, limit: int = 10) -> list[str]:
|
||||
"""Simple search for HuggingFace models based on query (kept for backward compatibility)"""
|
||||
return search_hf_models_fuzzy(query, limit)
|
||||
|
||||
|
||||
def validate_model_and_suggest(model_name: str, llm_type: str) -> Optional[str]:
|
||||
def validate_model_and_suggest(model_name: str, llm_type: str) -> str | None:
|
||||
"""Validate model name and provide suggestions if invalid"""
|
||||
if llm_type == "ollama":
|
||||
available_models = check_ollama_models()
|
||||
if available_models and model_name not in available_models:
|
||||
error_msg = f"Model '{model_name}' not found in your local Ollama installation."
|
||||
|
||||
|
||||
# Check if the model exists remotely and get available tags
|
||||
model_exists_remotely, available_tags = check_ollama_model_exists_remotely(model_name)
|
||||
|
||||
|
||||
if model_exists_remotely and model_name in available_tags:
|
||||
# Exact model exists remotely - suggest pulling it
|
||||
error_msg += f"\n\nTo install the requested model:\n"
|
||||
error_msg += "\n\nTo install the requested model:\n"
|
||||
error_msg += f" ollama pull {model_name}\n"
|
||||
|
||||
|
||||
# Show local alternatives
|
||||
suggestions = search_ollama_models_fuzzy(model_name, available_models)
|
||||
if suggestions:
|
||||
error_msg += "\nOr use one of these similar installed models:\n"
|
||||
for i, suggestion in enumerate(suggestions, 1):
|
||||
error_msg += f" {i}. {suggestion}\n"
|
||||
|
||||
|
||||
elif model_exists_remotely and available_tags:
|
||||
# Base model exists but requested tag doesn't - suggest correct tags
|
||||
base_model = model_name.split(':')[0]
|
||||
requested_tag = model_name.split(':', 1)[1] if ':' in model_name else None
|
||||
|
||||
error_msg += f"\n\nModel '{base_model}' exists, but tag '{requested_tag}' is not available."
|
||||
base_model = model_name.split(":")[0]
|
||||
requested_tag = model_name.split(":", 1)[1] if ":" in model_name else None
|
||||
|
||||
error_msg += (
|
||||
f"\n\nModel '{base_model}' exists, but tag '{requested_tag}' is not available."
|
||||
)
|
||||
error_msg += f"\n\nAvailable {base_model} models you can install:\n"
|
||||
for i, tag in enumerate(available_tags[:8], 1):
|
||||
error_msg += f" {i}. ollama pull {tag}\n"
|
||||
if len(available_tags) > 8:
|
||||
error_msg += f" ... and {len(available_tags) - 8} more variants\n"
|
||||
|
||||
|
||||
# Also show local alternatives
|
||||
suggestions = search_ollama_models_fuzzy(model_name, available_models)
|
||||
if suggestions:
|
||||
error_msg += "\nOr use one of these similar installed models:\n"
|
||||
for i, suggestion in enumerate(suggestions, 1):
|
||||
error_msg += f" {i}. {suggestion}\n"
|
||||
|
||||
|
||||
else:
|
||||
# Model doesn't exist remotely - show fuzzy suggestions
|
||||
suggestions = search_ollama_models_fuzzy(model_name, available_models)
|
||||
error_msg += f"\n\nModel '{model_name}' was not found in Ollama's library."
|
||||
|
||||
|
||||
if suggestions:
|
||||
error_msg += "\n\nDid you mean one of these installed models?\n"
|
||||
for i, suggestion in enumerate(suggestions, 1):
|
||||
@@ -357,23 +363,25 @@ def validate_model_and_suggest(model_name: str, llm_type: str) -> Optional[str]:
|
||||
error_msg += f" {i}. {model}\n"
|
||||
if len(available_models) > 8:
|
||||
error_msg += f" ... and {len(available_models) - 8} more\n"
|
||||
|
||||
|
||||
error_msg += "\n\nCommands:"
|
||||
error_msg += "\n ollama list # List installed models"
|
||||
if model_exists_remotely and available_tags:
|
||||
if model_name in available_tags:
|
||||
error_msg += f"\n ollama pull {model_name} # Install requested model"
|
||||
else:
|
||||
error_msg += f"\n ollama pull {available_tags[0]} # Install recommended variant"
|
||||
error_msg += (
|
||||
f"\n ollama pull {available_tags[0]} # Install recommended variant"
|
||||
)
|
||||
error_msg += "\n https://ollama.com/library # Browse available models"
|
||||
return error_msg
|
||||
|
||||
|
||||
elif llm_type == "hf":
|
||||
# For HF models, we can do a quick existence check
|
||||
if not check_hf_model_exists(model_name):
|
||||
# Use HF Hub's native fuzzy search directly
|
||||
search_suggestions = search_hf_models_fuzzy(model_name, limit=8)
|
||||
|
||||
|
||||
error_msg = f"Model '{model_name}' not found on HuggingFace Hub."
|
||||
if search_suggestions:
|
||||
error_msg += "\n\nDid you mean one of these?\n"
|
||||
@@ -385,10 +393,10 @@ def validate_model_and_suggest(model_name: str, llm_type: str) -> Optional[str]:
|
||||
error_msg += "\n\nPopular chat models:\n"
|
||||
for i, model in enumerate(popular_models[:5], 1):
|
||||
error_msg += f" {i}. {model}\n"
|
||||
|
||||
|
||||
error_msg += f"\nSearch more: https://huggingface.co/models?search={model_name}&pipeline_tag=text-generation"
|
||||
return error_msg
|
||||
|
||||
|
||||
return None # Model is valid or we can't check
|
||||
|
||||
|
||||
@@ -451,28 +459,27 @@ class OllamaChat(LLMInterface):
|
||||
# Check if the Ollama server is responsive
|
||||
if host:
|
||||
requests.get(host)
|
||||
|
||||
|
||||
# Pre-check model availability with helpful suggestions
|
||||
model_error = validate_model_and_suggest(model, "ollama")
|
||||
if model_error:
|
||||
raise ValueError(model_error)
|
||||
|
||||
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"The 'requests' library is required for Ollama. Please install it with 'pip install requests'."
|
||||
)
|
||||
except requests.exceptions.ConnectionError:
|
||||
logger.error(
|
||||
f"Could not connect to Ollama at {host}. Please ensure Ollama is running."
|
||||
)
|
||||
logger.error(f"Could not connect to Ollama at {host}. Please ensure Ollama is running.")
|
||||
raise ConnectionError(
|
||||
f"Could not connect to Ollama at {host}. Please ensure Ollama is running."
|
||||
)
|
||||
|
||||
def ask(self, prompt: str, **kwargs) -> str:
|
||||
import requests
|
||||
import json
|
||||
|
||||
import requests
|
||||
|
||||
full_url = f"{self.host}/api/generate"
|
||||
payload = {
|
||||
"model": self.model,
|
||||
@@ -482,7 +489,7 @@ class OllamaChat(LLMInterface):
|
||||
}
|
||||
logger.debug(f"Sending request to Ollama: {payload}")
|
||||
try:
|
||||
logger.info(f"Sending request to Ollama and waiting for response...")
|
||||
logger.info("Sending request to Ollama and waiting for response...")
|
||||
response = requests.post(full_url, data=json.dumps(payload))
|
||||
response.raise_for_status()
|
||||
|
||||
@@ -506,15 +513,15 @@ class HFChat(LLMInterface):
|
||||
|
||||
def __init__(self, model_name: str = "deepseek-ai/deepseek-llm-7b-chat"):
|
||||
logger.info(f"Initializing HFChat with model='{model_name}'")
|
||||
|
||||
|
||||
# Pre-check model availability with helpful suggestions
|
||||
model_error = validate_model_and_suggest(model_name, "hf")
|
||||
if model_error:
|
||||
raise ValueError(model_error)
|
||||
|
||||
|
||||
try:
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"The 'transformers' and 'torch' libraries are required for Hugging Face models. Please install them with 'pip install transformers torch'."
|
||||
@@ -537,36 +544,34 @@ class HFChat(LLMInterface):
|
||||
model_name,
|
||||
torch_dtype=torch.float16 if self.device != "cpu" else torch.float32,
|
||||
device_map="auto" if self.device != "cpu" else None,
|
||||
trust_remote_code=True
|
||||
trust_remote_code=True,
|
||||
)
|
||||
|
||||
|
||||
# Move model to device if not using device_map
|
||||
if self.device != "cpu" and "device_map" not in str(self.model):
|
||||
self.model = self.model.to(self.device)
|
||||
|
||||
|
||||
# Set pad token if not present
|
||||
if self.tokenizer.pad_token is None:
|
||||
self.tokenizer.pad_token = self.tokenizer.eos_token
|
||||
|
||||
def ask(self, prompt: str, **kwargs) -> str:
|
||||
print('kwargs in HF: ', kwargs)
|
||||
print("kwargs in HF: ", kwargs)
|
||||
# Check if this is a Qwen model and add /no_think by default
|
||||
is_qwen_model = "qwen" in self.model.config._name_or_path.lower()
|
||||
|
||||
|
||||
# For Qwen models, automatically add /no_think to the prompt
|
||||
if is_qwen_model and "/no_think" not in prompt and "/think" not in prompt:
|
||||
prompt = prompt + " /no_think"
|
||||
|
||||
|
||||
# Prepare chat template
|
||||
messages = [{"role": "user", "content": prompt}]
|
||||
|
||||
|
||||
# Apply chat template if available
|
||||
if hasattr(self.tokenizer, "apply_chat_template"):
|
||||
try:
|
||||
formatted_prompt = self.tokenizer.apply_chat_template(
|
||||
messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True
|
||||
messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Chat template failed, using raw prompt: {e}")
|
||||
@@ -577,13 +582,9 @@ class HFChat(LLMInterface):
|
||||
|
||||
# Tokenize input
|
||||
inputs = self.tokenizer(
|
||||
formatted_prompt,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
truncation=True,
|
||||
max_length=2048
|
||||
formatted_prompt, return_tensors="pt", padding=True, truncation=True, max_length=2048
|
||||
)
|
||||
|
||||
|
||||
# Move inputs to device
|
||||
if self.device != "cpu":
|
||||
inputs = {k: v.to(self.device) for k, v in inputs.items()}
|
||||
@@ -597,32 +598,29 @@ class HFChat(LLMInterface):
|
||||
"pad_token_id": self.tokenizer.eos_token_id,
|
||||
"eos_token_id": self.tokenizer.eos_token_id,
|
||||
}
|
||||
|
||||
|
||||
# Handle temperature=0 for greedy decoding
|
||||
if generation_config["temperature"] == 0.0:
|
||||
generation_config["do_sample"] = False
|
||||
generation_config.pop("temperature")
|
||||
|
||||
logger.info(f"Generating with HuggingFace model, config: {generation_config}")
|
||||
|
||||
|
||||
# Generate
|
||||
with torch.no_grad():
|
||||
outputs = self.model.generate(
|
||||
**inputs,
|
||||
**generation_config
|
||||
)
|
||||
outputs = self.model.generate(**inputs, **generation_config)
|
||||
|
||||
# Decode response
|
||||
generated_tokens = outputs[0][inputs["input_ids"].shape[1]:]
|
||||
generated_tokens = outputs[0][inputs["input_ids"].shape[1] :]
|
||||
response = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
|
||||
|
||||
|
||||
return response.strip()
|
||||
|
||||
|
||||
class OpenAIChat(LLMInterface):
|
||||
"""LLM interface for OpenAI models."""
|
||||
|
||||
def __init__(self, model: str = "gpt-4o", api_key: Optional[str] = None):
|
||||
def __init__(self, model: str = "gpt-4o", api_key: str | None = None):
|
||||
self.model = model
|
||||
self.api_key = api_key or os.getenv("OPENAI_API_KEY")
|
||||
|
||||
@@ -649,11 +647,7 @@ class OpenAIChat(LLMInterface):
|
||||
"messages": [{"role": "user", "content": prompt}],
|
||||
"max_tokens": kwargs.get("max_tokens", 1000),
|
||||
"temperature": kwargs.get("temperature", 0.7),
|
||||
**{
|
||||
k: v
|
||||
for k, v in kwargs.items()
|
||||
if k not in ["max_tokens", "temperature"]
|
||||
},
|
||||
**{k: v for k, v in kwargs.items() if k not in ["max_tokens", "temperature"]},
|
||||
}
|
||||
|
||||
logger.info(f"Sending request to OpenAI with model {self.model}")
|
||||
@@ -675,7 +669,7 @@ class SimulatedChat(LLMInterface):
|
||||
return "This is a simulated answer from the LLM based on the retrieved context."
|
||||
|
||||
|
||||
def get_llm(llm_config: Optional[Dict[str, Any]] = None) -> LLMInterface:
|
||||
def get_llm(llm_config: dict[str, Any] | None = None) -> LLMInterface:
|
||||
"""
|
||||
Factory function to get an LLM interface based on configuration.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user