fix: resolve all ruff linting errors and add lint CI check

- Fix ambiguous fullwidth characters (commas, parentheses) in strings and comments
- Replace Chinese comments with English equivalents
- Remove unused imports, and add noqa annotations where an import is intentional
- Replace bare except clauses with specific exception types
- Fix redefined variables and undefined names
- Add ruff noqa annotations for generated protobuf files
- Add lint and format checks to the GitHub Actions CI pipeline

(The fix categories above are illustrated in the sketch below.)
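For reference, the categories of fix listed above look roughly like the following. This is a minimal illustrative sketch with stand-in code; the identifiers and strings are hypothetical, not lines taken from this diff.

# Illustrative sketch only -- stand-in code, not lines from this commit.

# 1. Ambiguous fullwidth punctuation (ruff RUF001-RUF003) replaced with ASCII:
#    before: print("done，see results（above）")
print("done, see results (above)")

# 2. Intentional imports keep an explicit noqa annotation (ruff F401):
import os  # noqa: F401

# 3. Bare except clauses (ruff E722) narrowed to a specific exception type:
try:
    value = float("not a number")
except ValueError:  # previously a bare "except:"
    value = 0.0

# 4. Generated protobuf files are skipped wholesale via a file-level
#    "# ruff: noqa" directive on their first line.

The CI addition in the last bullet presumably runs `ruff check` and `ruff format --check`; the workflow file itself is outside the portion of the diff shown here.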
commit b3e9ee96fa
parent 8537a6b17e
Author: Andy Lee
Date: 2025-07-26 22:35:12 -07:00

53 changed files with 5655 additions and 5220 deletions


@@ -4,11 +4,12 @@ This file contains the chat generation logic for the LEANN project,
supporting different backends like Ollama, Hugging Face Transformers, and a simulation mode.
"""
from abc import ABC, abstractmethod
from typing import Dict, Any, Optional, List
import difflib
import logging
import os
import difflib
from abc import ABC, abstractmethod
from typing import Any
import torch
# Configure logging
@@ -16,10 +17,11 @@ logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def check_ollama_models() -> List[str]:
def check_ollama_models() -> list[str]:
"""Check available Ollama models and return a list"""
try:
import requests
response = requests.get("http://localhost:11434/api/tags", timeout=5)
if response.status_code == 200:
data = response.json()
@@ -31,51 +33,52 @@ def check_ollama_models() -> List[str]:
def check_ollama_model_exists_remotely(model_name: str) -> tuple[bool, list[str]]:
"""Check if a model exists in Ollama's remote library and return available tags
Returns:
(model_exists, available_tags): bool and list of matching tags
"""
try:
import requests
import re
import requests
# Split model name and tag
if ':' in model_name:
base_model, requested_tag = model_name.split(':', 1)
if ":" in model_name:
base_model, requested_tag = model_name.split(":", 1)
else:
base_model, requested_tag = model_name, None
# First check if base model exists in library
library_response = requests.get("https://ollama.com/library", timeout=8)
if library_response.status_code != 200:
return True, [] # Assume exists if can't check
# Extract model names from library page
models_in_library = re.findall(r'href="/library/([^"]+)"', library_response.text)
if base_model not in models_in_library:
return False, [] # Base model doesn't exist
# If base model exists, get available tags
tags_response = requests.get(f"https://ollama.com/library/{base_model}/tags", timeout=8)
if tags_response.status_code != 200:
return True, [] # Base model exists but can't get tags
# Extract tags for this model - be more specific to avoid HTML artifacts
tag_pattern = rf'{re.escape(base_model)}:[a-zA-Z0-9\.\-_]+'
tag_pattern = rf"{re.escape(base_model)}:[a-zA-Z0-9\.\-_]+"
raw_tags = re.findall(tag_pattern, tags_response.text)
# Clean up tags - remove HTML artifacts and duplicates
available_tags = []
seen = set()
for tag in raw_tags:
# Skip if it looks like HTML (contains < or >)
if '<' in tag or '>' in tag:
if "<" in tag or ">" in tag:
continue
if tag not in seen:
seen.add(tag)
available_tags.append(tag)
# Check if exact model exists
if requested_tag is None:
# User just requested base model, suggest tags
@@ -83,76 +86,80 @@ def check_ollama_model_exists_remotely(model_name: str) -> tuple[bool, list[str]
else:
exact_match = model_name in available_tags
return exact_match, available_tags[:10]
except Exception:
pass
# If scraping fails, assume model might exist (don't block user)
return True, []
def search_ollama_models_fuzzy(query: str, available_models: List[str]) -> List[str]:
def search_ollama_models_fuzzy(query: str, available_models: list[str]) -> list[str]:
"""Use intelligent fuzzy search for Ollama models"""
if not available_models:
return []
query_lower = query.lower()
suggestions = []
# 1. Exact matches first
exact_matches = [m for m in available_models if query_lower == m.lower()]
suggestions.extend(exact_matches)
# 2. Starts with query
starts_with = [m for m in available_models if m.lower().startswith(query_lower) and m not in suggestions]
starts_with = [
m for m in available_models if m.lower().startswith(query_lower) and m not in suggestions
]
suggestions.extend(starts_with)
# 3. Contains query
contains = [m for m in available_models if query_lower in m.lower() and m not in suggestions]
suggestions.extend(contains)
# 4. Base model name matching (remove version numbers)
def get_base_name(model_name: str) -> str:
"""Extract base name without version (e.g., 'llama3:8b' -> 'llama3')"""
return model_name.split(':')[0].split('-')[0]
return model_name.split(":")[0].split("-")[0]
query_base = get_base_name(query_lower)
base_matches = [
m for m in available_models
m
for m in available_models
if get_base_name(m.lower()) == query_base and m not in suggestions
]
suggestions.extend(base_matches)
# 5. Family/variant matching
model_families = {
'llama': ['llama2', 'llama3', 'alpaca', 'vicuna', 'codellama'],
'qwen': ['qwen', 'qwen2', 'qwen3'],
'gemma': ['gemma', 'gemma2'],
'phi': ['phi', 'phi2', 'phi3'],
'mistral': ['mistral', 'mixtral', 'openhermes'],
'dolphin': ['dolphin', 'openchat'],
'deepseek': ['deepseek', 'deepseek-coder']
"llama": ["llama2", "llama3", "alpaca", "vicuna", "codellama"],
"qwen": ["qwen", "qwen2", "qwen3"],
"gemma": ["gemma", "gemma2"],
"phi": ["phi", "phi2", "phi3"],
"mistral": ["mistral", "mixtral", "openhermes"],
"dolphin": ["dolphin", "openchat"],
"deepseek": ["deepseek", "deepseek-coder"],
}
query_family = None
for family, variants in model_families.items():
if any(variant in query_lower for variant in variants):
query_family = family
break
if query_family:
family_variants = model_families[query_family]
family_matches = [
m for m in available_models
m
for m in available_models
if any(variant in m.lower() for variant in family_variants) and m not in suggestions
]
suggestions.extend(family_matches)
# 6. Use difflib for remaining fuzzy matches
remaining_models = [m for m in available_models if m not in suggestions]
difflib_matches = difflib.get_close_matches(query_lower, remaining_models, n=3, cutoff=0.4)
suggestions.extend(difflib_matches)
return suggestions[:8] # Return top 8 suggestions
@@ -162,15 +169,13 @@ def search_ollama_models_fuzzy(query: str, available_models: List[str]) -> List[
# Remove this too - no need for fallback
def suggest_similar_models(invalid_model: str, available_models: List[str]) -> List[str]:
def suggest_similar_models(invalid_model: str, available_models: list[str]) -> list[str]:
"""Use difflib to find similar model names"""
if not available_models:
return []
# Get close matches using fuzzy matching
suggestions = difflib.get_close_matches(
invalid_model, available_models, n=3, cutoff=0.3
)
suggestions = difflib.get_close_matches(invalid_model, available_models, n=3, cutoff=0.3)
return suggestions
@@ -178,49 +183,50 @@ def check_hf_model_exists(model_name: str) -> bool:
"""Quick check if HuggingFace model exists without downloading"""
try:
from huggingface_hub import model_info
model_info(model_name)
return True
except Exception:
return False
def get_popular_hf_models() -> List[str]:
def get_popular_hf_models() -> list[str]:
"""Return a list of popular HuggingFace models for suggestions"""
try:
from huggingface_hub import list_models
# Get popular text-generation models, sorted by downloads
models = list_models(
filter="text-generation",
sort="downloads",
direction=-1,
limit=20 # Get top 20 most downloaded
limit=20, # Get top 20 most downloaded
)
# Extract model names and filter for chat/conversation models
model_names = []
chat_keywords = ['chat', 'instruct', 'dialog', 'conversation', 'assistant']
chat_keywords = ["chat", "instruct", "dialog", "conversation", "assistant"]
for model in models:
model_name = model.id if hasattr(model, 'id') else str(model)
model_name = model.id if hasattr(model, "id") else str(model)
# Prioritize models with chat-related keywords
if any(keyword in model_name.lower() for keyword in chat_keywords):
model_names.append(model_name)
elif len(model_names) < 10: # Fill up with other popular models
model_names.append(model_name)
return model_names[:10] if model_names else _get_fallback_hf_models()
except Exception:
# Fallback to static list if API call fails
return _get_fallback_hf_models()
def _get_fallback_hf_models() -> List[str]:
def _get_fallback_hf_models() -> list[str]:
"""Fallback list of popular HuggingFace models"""
return [
"microsoft/DialoGPT-medium",
"microsoft/DialoGPT-large",
"microsoft/DialoGPT-large",
"facebook/blenderbot-400M-distill",
"microsoft/phi-2",
"deepseek-ai/deepseek-llm-7b-chat",
@@ -228,44 +234,40 @@ def _get_fallback_hf_models() -> List[str]:
"facebook/blenderbot_small-90M",
"microsoft/phi-1_5",
"facebook/opt-350m",
"EleutherAI/gpt-neo-1.3B"
"EleutherAI/gpt-neo-1.3B",
]
def search_hf_models_fuzzy(query: str, limit: int = 10) -> List[str]:
def search_hf_models_fuzzy(query: str, limit: int = 10) -> list[str]:
"""Use HuggingFace Hub's native fuzzy search for model suggestions"""
try:
from huggingface_hub import list_models
# HF Hub's search is already fuzzy! It handles typos and partial matches
models = list_models(
search=query,
filter="text-generation",
sort="downloads",
direction=-1,
limit=limit
search=query, filter="text-generation", sort="downloads", direction=-1, limit=limit
)
model_names = [model.id if hasattr(model, 'id') else str(model) for model in models]
model_names = [model.id if hasattr(model, "id") else str(model) for model in models]
# If direct search doesn't return enough results, try some variations
if len(model_names) < 3:
# Try searching for partial matches or common variations
variations = []
# Extract base name (e.g., "gpt3" from "gpt-3.5")
base_query = query.lower().replace('-', '').replace('.', '').replace('_', '')
base_query = query.lower().replace("-", "").replace(".", "").replace("_", "")
if base_query != query.lower():
variations.append(base_query)
# Try common model name patterns
if 'gpt' in query.lower():
variations.extend(['gpt2', 'gpt-neo', 'gpt-j', 'dialoGPT'])
elif 'llama' in query.lower():
variations.extend(['llama2', 'alpaca', 'vicuna'])
elif 'bert' in query.lower():
variations.extend(['roberta', 'distilbert', 'albert'])
if "gpt" in query.lower():
variations.extend(["gpt2", "gpt-neo", "gpt-j", "dialoGPT"])
elif "llama" in query.lower():
variations.extend(["llama2", "alpaca", "vicuna"])
elif "bert" in query.lower():
variations.extend(["roberta", "distilbert", "albert"])
# Search with variations
for var in variations[:2]: # Limit to 2 variations to avoid too many API calls
try:
@@ -274,13 +276,15 @@ def search_hf_models_fuzzy(query: str, limit: int = 10) -> List[str]:
filter="text-generation",
sort="downloads",
direction=-1,
limit=3
limit=3,
)
var_names = [model.id if hasattr(model, 'id') else str(model) for model in var_models]
var_names = [
model.id if hasattr(model, "id") else str(model) for model in var_models
]
model_names.extend(var_names)
except:
except Exception:
continue
# Remove duplicates while preserving order
seen = set()
unique_models = []
@@ -288,65 +292,67 @@ def search_hf_models_fuzzy(query: str, limit: int = 10) -> List[str]:
if model not in seen:
seen.add(model)
unique_models.append(model)
return unique_models[:limit]
except Exception:
# If search fails, return empty list
return []
def search_hf_models(query: str, limit: int = 10) -> List[str]:
def search_hf_models(query: str, limit: int = 10) -> list[str]:
"""Simple search for HuggingFace models based on query (kept for backward compatibility)"""
return search_hf_models_fuzzy(query, limit)
def validate_model_and_suggest(model_name: str, llm_type: str) -> Optional[str]:
def validate_model_and_suggest(model_name: str, llm_type: str) -> str | None:
"""Validate model name and provide suggestions if invalid"""
if llm_type == "ollama":
available_models = check_ollama_models()
if available_models and model_name not in available_models:
error_msg = f"Model '{model_name}' not found in your local Ollama installation."
# Check if the model exists remotely and get available tags
model_exists_remotely, available_tags = check_ollama_model_exists_remotely(model_name)
if model_exists_remotely and model_name in available_tags:
# Exact model exists remotely - suggest pulling it
error_msg += f"\n\nTo install the requested model:\n"
error_msg += "\n\nTo install the requested model:\n"
error_msg += f" ollama pull {model_name}\n"
# Show local alternatives
suggestions = search_ollama_models_fuzzy(model_name, available_models)
if suggestions:
error_msg += "\nOr use one of these similar installed models:\n"
for i, suggestion in enumerate(suggestions, 1):
error_msg += f" {i}. {suggestion}\n"
elif model_exists_remotely and available_tags:
# Base model exists but requested tag doesn't - suggest correct tags
base_model = model_name.split(':')[0]
requested_tag = model_name.split(':', 1)[1] if ':' in model_name else None
error_msg += f"\n\nModel '{base_model}' exists, but tag '{requested_tag}' is not available."
base_model = model_name.split(":")[0]
requested_tag = model_name.split(":", 1)[1] if ":" in model_name else None
error_msg += (
f"\n\nModel '{base_model}' exists, but tag '{requested_tag}' is not available."
)
error_msg += f"\n\nAvailable {base_model} models you can install:\n"
for i, tag in enumerate(available_tags[:8], 1):
error_msg += f" {i}. ollama pull {tag}\n"
if len(available_tags) > 8:
error_msg += f" ... and {len(available_tags) - 8} more variants\n"
# Also show local alternatives
suggestions = search_ollama_models_fuzzy(model_name, available_models)
if suggestions:
error_msg += "\nOr use one of these similar installed models:\n"
for i, suggestion in enumerate(suggestions, 1):
error_msg += f" {i}. {suggestion}\n"
else:
# Model doesn't exist remotely - show fuzzy suggestions
suggestions = search_ollama_models_fuzzy(model_name, available_models)
error_msg += f"\n\nModel '{model_name}' was not found in Ollama's library."
if suggestions:
error_msg += "\n\nDid you mean one of these installed models?\n"
for i, suggestion in enumerate(suggestions, 1):
@@ -357,23 +363,25 @@ def validate_model_and_suggest(model_name: str, llm_type: str) -> Optional[str]:
error_msg += f" {i}. {model}\n"
if len(available_models) > 8:
error_msg += f" ... and {len(available_models) - 8} more\n"
error_msg += "\n\nCommands:"
error_msg += "\n ollama list # List installed models"
if model_exists_remotely and available_tags:
if model_name in available_tags:
error_msg += f"\n ollama pull {model_name} # Install requested model"
else:
error_msg += f"\n ollama pull {available_tags[0]} # Install recommended variant"
error_msg += (
f"\n ollama pull {available_tags[0]} # Install recommended variant"
)
error_msg += "\n https://ollama.com/library # Browse available models"
return error_msg
elif llm_type == "hf":
# For HF models, we can do a quick existence check
if not check_hf_model_exists(model_name):
# Use HF Hub's native fuzzy search directly
search_suggestions = search_hf_models_fuzzy(model_name, limit=8)
error_msg = f"Model '{model_name}' not found on HuggingFace Hub."
if search_suggestions:
error_msg += "\n\nDid you mean one of these?\n"
@@ -385,10 +393,10 @@ def validate_model_and_suggest(model_name: str, llm_type: str) -> Optional[str]:
error_msg += "\n\nPopular chat models:\n"
for i, model in enumerate(popular_models[:5], 1):
error_msg += f" {i}. {model}\n"
error_msg += f"\nSearch more: https://huggingface.co/models?search={model_name}&pipeline_tag=text-generation"
return error_msg
return None # Model is valid or we can't check
@@ -451,28 +459,27 @@ class OllamaChat(LLMInterface):
# Check if the Ollama server is responsive
if host:
requests.get(host)
# Pre-check model availability with helpful suggestions
model_error = validate_model_and_suggest(model, "ollama")
if model_error:
raise ValueError(model_error)
except ImportError:
raise ImportError(
"The 'requests' library is required for Ollama. Please install it with 'pip install requests'."
)
except requests.exceptions.ConnectionError:
logger.error(
f"Could not connect to Ollama at {host}. Please ensure Ollama is running."
)
logger.error(f"Could not connect to Ollama at {host}. Please ensure Ollama is running.")
raise ConnectionError(
f"Could not connect to Ollama at {host}. Please ensure Ollama is running."
)
def ask(self, prompt: str, **kwargs) -> str:
import requests
import json
import requests
full_url = f"{self.host}/api/generate"
payload = {
"model": self.model,
@@ -482,7 +489,7 @@ class OllamaChat(LLMInterface):
}
logger.debug(f"Sending request to Ollama: {payload}")
try:
logger.info(f"Sending request to Ollama and waiting for response...")
logger.info("Sending request to Ollama and waiting for response...")
response = requests.post(full_url, data=json.dumps(payload))
response.raise_for_status()
@@ -506,15 +513,15 @@ class HFChat(LLMInterface):
def __init__(self, model_name: str = "deepseek-ai/deepseek-llm-7b-chat"):
logger.info(f"Initializing HFChat with model='{model_name}'")
# Pre-check model availability with helpful suggestions
model_error = validate_model_and_suggest(model_name, "hf")
if model_error:
raise ValueError(model_error)
try:
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
except ImportError:
raise ImportError(
"The 'transformers' and 'torch' libraries are required for Hugging Face models. Please install them with 'pip install transformers torch'."
@@ -537,36 +544,34 @@ class HFChat(LLMInterface):
model_name,
torch_dtype=torch.float16 if self.device != "cpu" else torch.float32,
device_map="auto" if self.device != "cpu" else None,
trust_remote_code=True
trust_remote_code=True,
)
# Move model to device if not using device_map
if self.device != "cpu" and "device_map" not in str(self.model):
self.model = self.model.to(self.device)
# Set pad token if not present
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
def ask(self, prompt: str, **kwargs) -> str:
print('kwargs in HF: ', kwargs)
print("kwargs in HF: ", kwargs)
# Check if this is a Qwen model and add /no_think by default
is_qwen_model = "qwen" in self.model.config._name_or_path.lower()
# For Qwen models, automatically add /no_think to the prompt
if is_qwen_model and "/no_think" not in prompt and "/think" not in prompt:
prompt = prompt + " /no_think"
# Prepare chat template
messages = [{"role": "user", "content": prompt}]
# Apply chat template if available
if hasattr(self.tokenizer, "apply_chat_template"):
try:
formatted_prompt = self.tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True
messages, tokenize=False, add_generation_prompt=True
)
except Exception as e:
logger.warning(f"Chat template failed, using raw prompt: {e}")
@@ -577,13 +582,9 @@ class HFChat(LLMInterface):
# Tokenize input
inputs = self.tokenizer(
formatted_prompt,
return_tensors="pt",
padding=True,
truncation=True,
max_length=2048
formatted_prompt, return_tensors="pt", padding=True, truncation=True, max_length=2048
)
# Move inputs to device
if self.device != "cpu":
inputs = {k: v.to(self.device) for k, v in inputs.items()}
@@ -597,32 +598,29 @@ class HFChat(LLMInterface):
"pad_token_id": self.tokenizer.eos_token_id,
"eos_token_id": self.tokenizer.eos_token_id,
}
# Handle temperature=0 for greedy decoding
if generation_config["temperature"] == 0.0:
generation_config["do_sample"] = False
generation_config.pop("temperature")
logger.info(f"Generating with HuggingFace model, config: {generation_config}")
# Generate
with torch.no_grad():
outputs = self.model.generate(
**inputs,
**generation_config
)
outputs = self.model.generate(**inputs, **generation_config)
# Decode response
generated_tokens = outputs[0][inputs["input_ids"].shape[1]:]
generated_tokens = outputs[0][inputs["input_ids"].shape[1] :]
response = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
return response.strip()
class OpenAIChat(LLMInterface):
"""LLM interface for OpenAI models."""
def __init__(self, model: str = "gpt-4o", api_key: Optional[str] = None):
def __init__(self, model: str = "gpt-4o", api_key: str | None = None):
self.model = model
self.api_key = api_key or os.getenv("OPENAI_API_KEY")
@@ -649,11 +647,7 @@ class OpenAIChat(LLMInterface):
"messages": [{"role": "user", "content": prompt}],
"max_tokens": kwargs.get("max_tokens", 1000),
"temperature": kwargs.get("temperature", 0.7),
**{
k: v
for k, v in kwargs.items()
if k not in ["max_tokens", "temperature"]
},
**{k: v for k, v in kwargs.items() if k not in ["max_tokens", "temperature"]},
}
logger.info(f"Sending request to OpenAI with model {self.model}")
@@ -675,7 +669,7 @@ class SimulatedChat(LLMInterface):
return "This is a simulated answer from the LLM based on the retrieved context."
def get_llm(llm_config: Optional[Dict[str, Any]] = None) -> LLMInterface:
def get_llm(llm_config: dict[str, Any] | None = None) -> LLMInterface:
"""
Factory function to get an LLM interface based on configuration.