feat: hint for users about wrong model name
@@ -1,4 +1,5 @@
 import faulthandler
+
 faulthandler.enable()

 import argparse
@@ -13,17 +14,14 @@ from pathlib import Path
 dotenv.load_dotenv()

 node_parser = SentenceSplitter(
-    chunk_size=256,
-    chunk_overlap=64,
-    separator=" ",
-    paragraph_separator="\n\n"
+    chunk_size=256, chunk_overlap=64, separator=" ", paragraph_separator="\n\n"
 )
 print("Loading documents...")
 documents = SimpleDirectoryReader(
     "examples/data",
     recursive=True,
     encoding="utf-8",
-    required_exts=[".pdf", ".txt", ".md"]
+    required_exts=[".pdf", ".txt", ".md"],
 ).load_data(show_progress=True)
 print("Documents loaded.")
 all_texts = []
@@ -32,58 +30,86 @@ for doc in documents:
     for node in nodes:
         all_texts.append(node.get_content())


 async def main(args):
     INDEX_DIR = Path(args.index_dir)
     INDEX_PATH = str(INDEX_DIR / "pdf_documents.leann")

     if not INDEX_DIR.exists():
         print(f"--- Index directory not found, building new index ---")

         print(f"\n[PHASE 1] Building Leann index...")

         # Use HNSW backend for better macOS compatibility
         builder = LeannBuilder(
             backend_name="hnsw",
             embedding_model="facebook/contriever",
             graph_degree=32,
             complexity=64,
             is_compact=True,
             is_recompute=True,
-            num_threads=1  # Force single-threaded mode
+            num_threads=1,  # Force single-threaded mode
         )

         print(f"Loaded {len(all_texts)} text chunks from documents.")
         for chunk_text in all_texts:
             builder.add_text(chunk_text)

         builder.build_index(INDEX_PATH)
         print(f"\nLeann index built at {INDEX_PATH}!")
     else:
         print(f"--- Using existing index at {INDEX_DIR} ---")

     print(f"\n[PHASE 2] Starting Leann chat session...")


-    llm_config = {
-        "type": "ollama", "model": "qwen3:8b"
-    }
+    llm_config = {"type": "hf", "model": "Qwen/Qwen3-8B"}

     chat = LeannChat(index_path=INDEX_PATH, llm_config=llm_config)

     query = "Based on the paper, what are the main techniques LEANN explores to reduce the storage overhead and DLPM explore to achieve Fairness and Efiiciency trade-off?"
-    query = "What is the main idea of RL and give me 5 exapmle of classic RL algorithms?"
-    query = "什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面,任务令一般在什么城市颁发"
+    query = (
+        "What is the main idea of RL and give me 5 exapmle of classic RL algorithms?"
+    )
+    query = (
+        "什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面,任务令一般在什么城市颁发"
+    )

     print(f"You: {query}")
-    chat_response = chat.ask(query, top_k=20, recompute_beighbor_embeddings=True, complexity=32)
+    chat_response = chat.ask(
+        query, top_k=20, recompute_beighbor_embeddings=True, complexity=32
+    )
     print(f"Leann: {chat_response}")


 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Run Leann Chat with various LLM backends.")
-    parser.add_argument("--llm", type=str, default="hf", choices=["simulated", "ollama", "hf", "openai"], help="The LLM backend to use.")
-    parser.add_argument("--model", type=str, default='Qwen/Qwen3-0.6B', help="The model name to use (e.g., 'llama3:8b' for ollama, 'deepseek-ai/deepseek-llm-7b-chat' for hf, 'gpt-4o' for openai).")
-    parser.add_argument("--host", type=str, default="http://localhost:11434", help="The host for the Ollama API.")
-    parser.add_argument("--index-dir", type=str, default="./test_pdf_index_pangu_test", help="Directory where the Leann index will be stored.")
+    parser = argparse.ArgumentParser(
+        description="Run Leann Chat with various LLM backends."
+    )
+    parser.add_argument(
+        "--llm",
+        type=str,
+        default="hf",
+        choices=["simulated", "ollama", "hf", "openai"],
+        help="The LLM backend to use.",
+    )
+    parser.add_argument(
+        "--model",
+        type=str,
+        default="Qwen/Qwen3-0.6B",
+        help="The model name to use (e.g., 'llama3:8b' for ollama, 'deepseek-ai/deepseek-llm-7b-chat' for hf, 'gpt-4o' for openai).",
+    )
+    parser.add_argument(
+        "--host",
+        type=str,
+        default="http://localhost:11434",
+        help="The host for the Ollama API.",
+    )
+    parser.add_argument(
+        "--index-dir",
+        type=str,
+        default="./test_pdf_index_pangu_test",
+        help="Directory where the Leann index will be stored.",
+    )
     args = parser.parse_args()

     asyncio.run(main(args))
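A minimal sketch (not part of the diff) of the llm_config dictionaries this example script can hand to LeannChat/get_llm, using only values that appear in this commit: the "hf" form is what the script now uses, the "ollama" form is the configuration it replaces, and the "openai" form mirrors the default shown later in get_llm.

    # Sketch only; values taken from this commit's diff.
    hf_config = {"type": "hf", "model": "Qwen/Qwen3-8B"}        # new default in the example script
    ollama_config = {"type": "ollama", "model": "qwen3:8b"}     # previous configuration
    openai_config = {"type": "openai", "model": "gpt-4o"}       # requires OPENAI_API_KEY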
@@ -5,16 +5,291 @@ supporting different backends like Ollama, Hugging Face Transformers, and a simu
 """

 from abc import ABC, abstractmethod
-from typing import Dict, Any, Optional
+from typing import Dict, Any, Optional, List
 import logging
 import os
+import difflib

 # Configure logging
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)


+def check_ollama_models() -> List[str]:
+    """Check available Ollama models and return a list"""
+    try:
+        import requests
+        response = requests.get("http://localhost:11434/api/tags", timeout=5)
+        if response.status_code == 200:
+            data = response.json()
+            return [model["name"] for model in data.get("models", [])]
+        return []
+    except Exception:
+        return []
+
+
+def search_ollama_models_fuzzy(query: str, available_models: List[str]) -> List[str]:
+    """Use intelligent fuzzy search for Ollama models"""
+    if not available_models:
+        return []
+
+    query_lower = query.lower()
+    suggestions = []
+
+    # 1. Exact matches first
+    exact_matches = [m for m in available_models if query_lower == m.lower()]
+    suggestions.extend(exact_matches)
+
+    # 2. Starts with query
+    starts_with = [m for m in available_models if m.lower().startswith(query_lower) and m not in suggestions]
+    suggestions.extend(starts_with)
+
+    # 3. Contains query
+    contains = [m for m in available_models if query_lower in m.lower() and m not in suggestions]
+    suggestions.extend(contains)
+
+    # 4. Base model name matching (remove version numbers)
+    def get_base_name(model_name: str) -> str:
+        """Extract base name without version (e.g., 'llama3:8b' -> 'llama3')"""
+        return model_name.split(':')[0].split('-')[0]
+
+    query_base = get_base_name(query_lower)
+    base_matches = [
+        m for m in available_models
+        if get_base_name(m.lower()) == query_base and m not in suggestions
+    ]
+    suggestions.extend(base_matches)
+
+    # 5. Family/variant matching
+    model_families = {
+        'llama': ['llama2', 'llama3', 'alpaca', 'vicuna', 'codellama'],
+        'qwen': ['qwen', 'qwen2', 'qwen3'],
+        'gemma': ['gemma', 'gemma2'],
+        'phi': ['phi', 'phi2', 'phi3'],
+        'mistral': ['mistral', 'mixtral', 'openhermes'],
+        'dolphin': ['dolphin', 'openchat'],
+        'deepseek': ['deepseek', 'deepseek-coder']
+    }
+
+    query_family = None
+    for family, variants in model_families.items():
+        if any(variant in query_lower for variant in variants):
+            query_family = family
+            break
+
+    if query_family:
+        family_variants = model_families[query_family]
+        family_matches = [
+            m for m in available_models
+            if any(variant in m.lower() for variant in family_variants) and m not in suggestions
+        ]
+        suggestions.extend(family_matches)
+
+    # 6. Use difflib for remaining fuzzy matches
+    remaining_models = [m for m in available_models if m not in suggestions]
+    difflib_matches = difflib.get_close_matches(query_lower, remaining_models, n=3, cutoff=0.4)
+    suggestions.extend(difflib_matches)
+
+    return suggestions[:8]  # Return top 8 suggestions
+
+
+# Remove this function entirely - we don't need external API calls for Ollama
+
+
+# Remove this too - no need for fallback
+
+
+def suggest_similar_models(invalid_model: str, available_models: List[str]) -> List[str]:
+    """Use difflib to find similar model names"""
+    if not available_models:
+        return []
+
+    # Get close matches using fuzzy matching
+    suggestions = difflib.get_close_matches(
+        invalid_model, available_models, n=3, cutoff=0.3
+    )
+    return suggestions
+
+
+def check_hf_model_exists(model_name: str) -> bool:
+    """Quick check if HuggingFace model exists without downloading"""
+    try:
+        from huggingface_hub import model_info
+        model_info(model_name)
+        return True
+    except Exception:
+        return False
+
+
+def get_popular_hf_models() -> List[str]:
+    """Return a list of popular HuggingFace models for suggestions"""
+    try:
+        from huggingface_hub import list_models
+
+        # Get popular text-generation models, sorted by downloads
+        models = list_models(
+            filter="text-generation",
+            sort="downloads",
+            direction=-1,
+            limit=20  # Get top 20 most downloaded
+        )
+
+        # Extract model names and filter for chat/conversation models
+        model_names = []
+        chat_keywords = ['chat', 'instruct', 'dialog', 'conversation', 'assistant']
+
+        for model in models:
+            model_name = model.id if hasattr(model, 'id') else str(model)
+            # Prioritize models with chat-related keywords
+            if any(keyword in model_name.lower() for keyword in chat_keywords):
+                model_names.append(model_name)
+            elif len(model_names) < 10:  # Fill up with other popular models
+                model_names.append(model_name)
+
+        return model_names[:10] if model_names else _get_fallback_hf_models()
+
+    except Exception:
+        # Fallback to static list if API call fails
+        return _get_fallback_hf_models()
+
+
+def _get_fallback_hf_models() -> List[str]:
+    """Fallback list of popular HuggingFace models"""
+    return [
+        "microsoft/DialoGPT-medium",
+        "microsoft/DialoGPT-large",
+        "facebook/blenderbot-400M-distill",
+        "microsoft/phi-2",
+        "deepseek-ai/deepseek-llm-7b-chat",
+        "microsoft/DialoGPT-small",
+        "facebook/blenderbot_small-90M",
+        "microsoft/phi-1_5",
+        "facebook/opt-350m",
+        "EleutherAI/gpt-neo-1.3B"
+    ]
+
+
+def search_hf_models_fuzzy(query: str, limit: int = 10) -> List[str]:
+    """Use HuggingFace Hub's native fuzzy search for model suggestions"""
+    try:
+        from huggingface_hub import list_models
+
+        # HF Hub's search is already fuzzy! It handles typos and partial matches
+        models = list_models(
+            search=query,
+            filter="text-generation",
+            sort="downloads",
+            direction=-1,
+            limit=limit
+        )
+
+        model_names = [model.id if hasattr(model, 'id') else str(model) for model in models]
+
+        # If direct search doesn't return enough results, try some variations
+        if len(model_names) < 3:
+            # Try searching for partial matches or common variations
+            variations = []
+
+            # Extract base name (e.g., "gpt3" from "gpt-3.5")
+            base_query = query.lower().replace('-', '').replace('.', '').replace('_', '')
+            if base_query != query.lower():
+                variations.append(base_query)
+
+            # Try common model name patterns
+            if 'gpt' in query.lower():
+                variations.extend(['gpt2', 'gpt-neo', 'gpt-j', 'dialoGPT'])
+            elif 'llama' in query.lower():
+                variations.extend(['llama2', 'alpaca', 'vicuna'])
+            elif 'bert' in query.lower():
+                variations.extend(['roberta', 'distilbert', 'albert'])
+
+            # Search with variations
+            for var in variations[:2]:  # Limit to 2 variations to avoid too many API calls
+                try:
+                    var_models = list_models(
+                        search=var,
+                        filter="text-generation",
+                        sort="downloads",
+                        direction=-1,
+                        limit=3
+                    )
+                    var_names = [model.id if hasattr(model, 'id') else str(model) for model in var_models]
+                    model_names.extend(var_names)
+                except:
+                    continue
+
+        # Remove duplicates while preserving order
+        seen = set()
+        unique_models = []
+        for model in model_names:
+            if model not in seen:
+                seen.add(model)
+                unique_models.append(model)
+
+        return unique_models[:limit]
+
+    except Exception:
+        # If search fails, return empty list
+        return []
+
+
+def search_hf_models(query: str, limit: int = 10) -> List[str]:
+    """Simple search for HuggingFace models based on query (kept for backward compatibility)"""
+    return search_hf_models_fuzzy(query, limit)
+
+
+def validate_model_and_suggest(model_name: str, llm_type: str) -> Optional[str]:
+    """Validate model name and provide suggestions if invalid"""
+    if llm_type == "ollama":
+        available_models = check_ollama_models()
+        if available_models and model_name not in available_models:
+            # Use intelligent fuzzy search based on locally installed models
+            suggestions = search_ollama_models_fuzzy(model_name, available_models)
+
+            error_msg = f"Model '{model_name}' not found in your local Ollama installation."
+            if suggestions:
+                error_msg += "\n\nDid you mean one of these installed models?\n"
+                for i, suggestion in enumerate(suggestions, 1):
+                    error_msg += f" {i}. {suggestion}\n"
+            else:
+                error_msg += "\n\nYour installed models:\n"
+                for i, model in enumerate(available_models[:8], 1):
+                    error_msg += f" {i}. {model}\n"
+                if len(available_models) > 8:
+                    error_msg += f" ... and {len(available_models) - 8} more\n"
+
+            error_msg += "\nTo list all models: ollama list"
+            error_msg += "\nTo download a new model: ollama pull <model_name>"
+            error_msg += "\nBrowse models: https://ollama.com/library"
+            return error_msg
+
+    elif llm_type == "hf":
+        # For HF models, we can do a quick existence check
+        if not check_hf_model_exists(model_name):
+            # Use HF Hub's native fuzzy search directly
+            search_suggestions = search_hf_models_fuzzy(model_name, limit=8)
+
+            error_msg = f"Model '{model_name}' not found on HuggingFace Hub."
+            if search_suggestions:
+                error_msg += "\n\nDid you mean one of these?\n"
+                for i, suggestion in enumerate(search_suggestions, 1):
+                    error_msg += f" {i}. {suggestion}\n"
+            else:
+                # Fallback to popular models if search returns nothing
+                popular_models = get_popular_hf_models()
+                error_msg += "\n\nPopular chat models:\n"
+                for i, model in enumerate(popular_models[:5], 1):
+                    error_msg += f" {i}. {model}\n"
+
+            error_msg += f"\nSearch more: https://huggingface.co/models?search={model_name}&pipeline_tag=text-generation"
+            return error_msg
+
+    return None  # Model is valid or we can't check
+
+
 class LLMInterface(ABC):
     """Abstract base class for a generic Language Model (LLM) interface."""

     @abstractmethod
     def ask(self, prompt: str, **kwargs) -> str:
         """
@@ -32,7 +307,7 @@ class LLMInterface(ABC):
             batch_recompute=True,
             global_pruning=True
         )

         Supported kwargs:
         - complexity (int): Search complexity parameter (default: 32)
         - beam_width (int): Beam width for search (default: 4)
@@ -57,22 +332,37 @@ class LLMInterface(ABC):
         # """
         pass


 class OllamaChat(LLMInterface):
     """LLM interface for Ollama models."""

     def __init__(self, model: str = "llama3:8b", host: str = "http://localhost:11434"):
         self.model = model
         self.host = host
         logger.info(f"Initializing OllamaChat with model='{model}' and host='{host}'")
         try:
             import requests

             # Check if the Ollama server is responsive
             if host:
                 requests.get(host)
+
+            # Pre-check model availability with helpful suggestions
+            model_error = validate_model_and_suggest(model, "ollama")
+            if model_error:
+                raise ValueError(model_error)
+
         except ImportError:
-            raise ImportError("The 'requests' library is required for Ollama. Please install it with 'pip install requests'.")
+            raise ImportError(
+                "The 'requests' library is required for Ollama. Please install it with 'pip install requests'."
+            )
         except requests.exceptions.ConnectionError:
-            logger.error(f"Could not connect to Ollama at {host}. Please ensure Ollama is running.")
-            raise ConnectionError(f"Could not connect to Ollama at {host}. Please ensure Ollama is running.")
+            logger.error(
+                f"Could not connect to Ollama at {host}. Please ensure Ollama is running."
+            )
+            raise ConnectionError(
+                f"Could not connect to Ollama at {host}. Please ensure Ollama is running."
+            )

     def ask(self, prompt: str, **kwargs) -> str:
         import requests
@@ -83,15 +373,15 @@ class OllamaChat(LLMInterface):
             "model": self.model,
             "prompt": prompt,
             "stream": False,  # Keep it simple for now
-            "options": kwargs
+            "options": kwargs,
         }
         logger.info(f"Sending request to Ollama: {payload}")
         try:
             response = requests.post(full_url, data=json.dumps(payload))
             response.raise_for_status()

             # The response from Ollama can be a stream of JSON objects, handle this
-            response_parts = response.text.strip().split('\n')
+            response_parts = response.text.strip().split("\n")
             full_response = ""
             for part in response_parts:
                 if part:
@@ -104,15 +394,25 @@ class OllamaChat(LLMInterface):
             logger.error(f"Error communicating with Ollama: {e}")
             return f"Error: Could not get a response from Ollama. Details: {e}"


 class HFChat(LLMInterface):
     """LLM interface for local Hugging Face Transformers models."""

     def __init__(self, model_name: str = "deepseek-ai/deepseek-llm-7b-chat"):
         logger.info(f"Initializing HFChat with model='{model_name}'")
+
+        # Pre-check model availability with helpful suggestions
+        model_error = validate_model_and_suggest(model_name, "hf")
+        if model_error:
+            raise ValueError(model_error)
+
         try:
-            from transformers import pipeline
+            from transformers.pipelines import pipeline
             import torch
         except ImportError:
-            raise ImportError("The 'transformers' and 'torch' libraries are required for Hugging Face models. Please install them with 'pip install transformers torch'.")
+            raise ImportError(
+                "The 'transformers' and 'torch' libraries are required for Hugging Face models. Please install them with 'pip install transformers torch'."
+            )

         # Auto-detect device
         if torch.cuda.is_available():
@@ -140,47 +440,54 @@ class HFChat(LLMInterface):
             # Remove unsupported zero temperature and use deterministic generation
             kwargs.pop("temperature")
             kwargs.setdefault("do_sample", False)

         # Sensible defaults for text generation
-        params = {
-            "max_length": 500,
-            "num_return_sequences": 1,
-            **kwargs
-        }
+        params = {"max_length": 500, "num_return_sequences": 1, **kwargs}
         logger.info(f"Generating text with Hugging Face model with params: {params}")
         results = self.pipeline(prompt, **params)

         # Handle different response formats from transformers
         if isinstance(results, list) and len(results) > 0:
-            generated_text = results[0].get('generated_text', '') if isinstance(results[0], dict) else str(results[0])
+            generated_text = (
+                results[0].get("generated_text", "")
+                if isinstance(results[0], dict)
+                else str(results[0])
+            )
         else:
             generated_text = str(results)

         # Extract only the newly generated portion by removing the original prompt
         if isinstance(generated_text, str) and generated_text.startswith(prompt):
-            response = generated_text[len(prompt):].strip()
+            response = generated_text[len(prompt) :].strip()
         else:
             # Fallback: return the full response if prompt removal fails
             response = str(generated_text)

         return response


 class OpenAIChat(LLMInterface):
     """LLM interface for OpenAI models."""

     def __init__(self, model: str = "gpt-4o", api_key: Optional[str] = None):
         self.model = model
         self.api_key = api_key or os.getenv("OPENAI_API_KEY")

         if not self.api_key:
-            raise ValueError("OpenAI API key is required. Set OPENAI_API_KEY environment variable or pass api_key parameter.")
+            raise ValueError(
+                "OpenAI API key is required. Set OPENAI_API_KEY environment variable or pass api_key parameter."
+            )

         logger.info(f"Initializing OpenAI Chat with model='{model}'")

         try:
             import openai

             self.client = openai.OpenAI(api_key=self.api_key)
         except ImportError:
-            raise ImportError("The 'openai' library is required for OpenAI models. Please install it with 'pip install openai'.")
+            raise ImportError(
+                "The 'openai' library is required for OpenAI models. Please install it with 'pip install openai'."
+            )

     def ask(self, prompt: str, **kwargs) -> str:
         # Default parameters for OpenAI
@@ -189,11 +496,15 @@ class OpenAIChat(LLMInterface):
             "messages": [{"role": "user", "content": prompt}],
             "max_tokens": kwargs.get("max_tokens", 1000),
             "temperature": kwargs.get("temperature", 0.7),
-            **{k: v for k, v in kwargs.items() if k not in ["max_tokens", "temperature"]}
+            **{
+                k: v
+                for k, v in kwargs.items()
+                if k not in ["max_tokens", "temperature"]
+            },
         }

         logger.info(f"Sending request to OpenAI with model {self.model}")

         try:
             response = self.client.chat.completions.create(**params)
             return response.choices[0].message.content.strip()
@@ -201,13 +512,16 @@ class OpenAIChat(LLMInterface):
             logger.error(f"Error communicating with OpenAI: {e}")
             return f"Error: Could not get a response from OpenAI. Details: {e}"


 class SimulatedChat(LLMInterface):
     """A simple simulated chat for testing and development."""

     def ask(self, prompt: str, **kwargs) -> str:
         logger.info("Simulating LLM call...")
         print("Prompt sent to LLM (simulation):", prompt[:500] + "...")
         return "This is a simulated answer from the LLM based on the retrieved context."


 def get_llm(llm_config: Optional[Dict[str, Any]] = None) -> LLMInterface:
     """
     Factory function to get an LLM interface based on configuration.
@@ -225,16 +539,19 @@ def get_llm(llm_config: Optional[Dict[str, Any]] = None) -> LLMInterface:
         llm_config = {
             "type": "openai",
             "model": "gpt-4o",
-            "api_key": os.getenv("OPENAI_API_KEY")
+            "api_key": os.getenv("OPENAI_API_KEY"),
         }

     llm_type = llm_config.get("type", "openai")
     model = llm_config.get("model")

     logger.info(f"Attempting to create LLM of type='{llm_type}' with model='{model}'")

     if llm_type == "ollama":
-        return OllamaChat(model=model or "llama3:8b", host=llm_config.get("host", "http://localhost:11434"))
+        return OllamaChat(
+            model=model or "llama3:8b",
+            host=llm_config.get("host", "http://localhost:11434"),
+        )
     elif llm_type == "hf":
         return HFChat(model_name=model or "deepseek-ai/deepseek-llm-7b-chat")
     elif llm_type == "openai":
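A minimal usage sketch (not part of the diff) of how the new hint surfaces, based on the validate_model_and_suggest() helper added above; the model name below is a hypothetical misspelling used only for illustration.

    # Sketch only: the helper returns None for a valid name, or a multi-line
    # message with "Did you mean ..." suggestions for an invalid one.
    error = validate_model_and_suggest("lama3:8b", "ollama")   # hypothetical typo of "llama3:8b"
    if error:
        print(error)   # OllamaChat and HFChat raise ValueError(error) in the same situation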