diff --git a/examples/main_cli_example.py b/examples/main_cli_example.py
index 1ce3edc..9fd7a92 100644
--- a/examples/main_cli_example.py
+++ b/examples/main_cli_example.py
@@ -1,4 +1,5 @@
 import faulthandler
+
 faulthandler.enable()
 
 import argparse
@@ -13,17 +14,14 @@ from pathlib import Path
 dotenv.load_dotenv()
 
 node_parser = SentenceSplitter(
-    chunk_size=256,
-    chunk_overlap=64,
-    separator=" ",
-    paragraph_separator="\n\n"
+    chunk_size=256, chunk_overlap=64, separator=" ", paragraph_separator="\n\n"
 )
 print("Loading documents...")
 documents = SimpleDirectoryReader(
-    "examples/data", 
+    "examples/data",
     recursive=True,
     encoding="utf-8",
-    required_exts=[".pdf", ".txt", ".md"]
+    required_exts=[".pdf", ".txt", ".md"],
 ).load_data(show_progress=True)
 print("Documents loaded.")
 all_texts = []
@@ -32,58 +30,86 @@ for doc in documents:
     for node in nodes:
         all_texts.append(node.get_content())
 
+
 async def main(args):
     INDEX_DIR = Path(args.index_dir)
     INDEX_PATH = str(INDEX_DIR / "pdf_documents.leann")
 
     if not INDEX_DIR.exists():
         print(f"--- Index directory not found, building new index ---")
-        
+
         print(f"\n[PHASE 1] Building Leann index...")
         # Use HNSW backend for better macOS compatibility
         builder = LeannBuilder(
             backend_name="hnsw",
             embedding_model="facebook/contriever",
-            graph_degree=32, 
+            graph_degree=32,
             complexity=64,
             is_compact=True,
             is_recompute=True,
-            num_threads=1 # Force single-threaded mode
+            num_threads=1,  # Force single-threaded mode
         )
         print(f"Loaded {len(all_texts)} text chunks from documents.")
         for chunk_text in all_texts:
             builder.add_text(chunk_text)
-        
+
         builder.build_index(INDEX_PATH)
         print(f"\nLeann index built at {INDEX_PATH}!")
     else:
         print(f"--- Using existing index at {INDEX_DIR} ---")
 
     print(f"\n[PHASE 2] Starting Leann chat session...")
-    
-    llm_config = {
-        "type": "ollama", "model": "qwen3:8b"
-    }
+    llm_config = {"type": "hf", "model": "Qwen/Qwen3-8B"}
 
     chat = LeannChat(index_path=INDEX_PATH, llm_config=llm_config)
-    
     query = "Based on the paper, what are the main techniques LEANN explores to reduce the storage overhead and DLPM explore to achieve Fairness and Efiiciency trade-off?"
-    query = "What is the main idea of RL and give me 5 exapmle of classic RL algorithms?"
-    query = "什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面,任务令一般在什么城市颁发"
+    query = (
+        "What is the main idea of RL and give me 5 exapmle of classic RL algorithms?"
+    )
+    query = (
+        "什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面,任务令一般在什么城市颁发"
+    )
     print(f"You: {query}")
-    chat_response = chat.ask(query, top_k=20, recompute_beighbor_embeddings=True, complexity=32)
+    chat_response = chat.ask(
+        query, top_k=20, recompute_beighbor_embeddings=True, complexity=32
+    )
    print(f"Leann: {chat_response}")
 
+
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Run Leann Chat with various LLM backends.")
-    parser.add_argument("--llm", type=str, default="hf", choices=["simulated", "ollama", "hf", "openai"], help="The LLM backend to use.")
-    parser.add_argument("--model", type=str, default='Qwen/Qwen3-0.6B', help="The model name to use (e.g., 'llama3:8b' for ollama, 'deepseek-ai/deepseek-llm-7b-chat' for hf, 'gpt-4o' for openai).")
-    parser.add_argument("--host", type=str, default="http://localhost:11434", help="The host for the Ollama API.")
-    parser.add_argument("--index-dir", type=str, default="./test_pdf_index_pangu_test", help="Directory where the Leann index will be stored.")
+    parser = argparse.ArgumentParser(
+        description="Run Leann Chat with various LLM backends."
+    )
+    parser.add_argument(
+        "--llm",
+        type=str,
+        default="hf",
+        choices=["simulated", "ollama", "hf", "openai"],
+        help="The LLM backend to use.",
+    )
+    parser.add_argument(
+        "--model",
+        type=str,
+        default="Qwen/Qwen3-0.6B",
+        help="The model name to use (e.g., 'llama3:8b' for ollama, 'deepseek-ai/deepseek-llm-7b-chat' for hf, 'gpt-4o' for openai).",
+    )
+    parser.add_argument(
+        "--host",
+        type=str,
+        default="http://localhost:11434",
+        help="The host for the Ollama API.",
+    )
+    parser.add_argument(
+        "--index-dir",
+        type=str,
+        default="./test_pdf_index_pangu_test",
+        help="Directory where the Leann index will be stored.",
+    )
     args = parser.parse_args()
-    asyncio.run(main(args))
\ No newline at end of file
+    asyncio.run(main(args))
diff --git a/packages/leann-core/src/leann/chat.py b/packages/leann-core/src/leann/chat.py
index 62d68a5..7ac0f32 100644
--- a/packages/leann-core/src/leann/chat.py
+++ b/packages/leann-core/src/leann/chat.py
@@ -5,16 +5,291 @@ supporting different backends like Ollama, Hugging Face Transformers, and a simu
 """
 
 from abc import ABC, abstractmethod
-from typing import Dict, Any, Optional
+from typing import Dict, Any, Optional, List
 import logging
 import os
+import difflib
 
 # Configure logging
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
+
+def check_ollama_models() -> List[str]:
+    """Check available Ollama models and return a list"""
+    try:
+        import requests
+        response = requests.get("http://localhost:11434/api/tags", timeout=5)
+        if response.status_code == 200:
+            data = response.json()
+            return [model["name"] for model in data.get("models", [])]
+        return []
+    except Exception:
+        return []
+
+
+def search_ollama_models_fuzzy(query: str, available_models: List[str]) -> List[str]:
+    """Use intelligent fuzzy search for Ollama models"""
+    if not available_models:
+        return []
+
+    query_lower = query.lower()
+    suggestions = []
+
+    # 1. Exact matches first
+    exact_matches = [m for m in available_models if query_lower == m.lower()]
+    suggestions.extend(exact_matches)
+
+    # 2. Starts with query
+    starts_with = [m for m in available_models if m.lower().startswith(query_lower) and m not in suggestions]
+    suggestions.extend(starts_with)
+
+    # 3. Contains query
+    contains = [m for m in available_models if query_lower in m.lower() and m not in suggestions]
+    suggestions.extend(contains)
+
+    # 4. Base model name matching (remove version numbers)
+    def get_base_name(model_name: str) -> str:
+        """Extract base name without version (e.g., 'llama3:8b' -> 'llama3')"""
+        return model_name.split(':')[0].split('-')[0]
+
+    query_base = get_base_name(query_lower)
+    base_matches = [
+        m for m in available_models
+        if get_base_name(m.lower()) == query_base and m not in suggestions
+    ]
+    suggestions.extend(base_matches)
+
+    # 5. Family/variant matching
+    model_families = {
+        'llama': ['llama2', 'llama3', 'alpaca', 'vicuna', 'codellama'],
+        'qwen': ['qwen', 'qwen2', 'qwen3'],
+        'gemma': ['gemma', 'gemma2'],
+        'phi': ['phi', 'phi2', 'phi3'],
+        'mistral': ['mistral', 'mixtral', 'openhermes'],
+        'dolphin': ['dolphin', 'openchat'],
+        'deepseek': ['deepseek', 'deepseek-coder']
+    }
+
+    query_family = None
+    for family, variants in model_families.items():
+        if any(variant in query_lower for variant in variants):
+            query_family = family
+            break
+
+    if query_family:
+        family_variants = model_families[query_family]
+        family_matches = [
+            m for m in available_models
+            if any(variant in m.lower() for variant in family_variants) and m not in suggestions
+        ]
+        suggestions.extend(family_matches)
+
+    # 6. Use difflib for remaining fuzzy matches
+    remaining_models = [m for m in available_models if m not in suggestions]
+    difflib_matches = difflib.get_close_matches(query_lower, remaining_models, n=3, cutoff=0.4)
+    suggestions.extend(difflib_matches)
+
+    return suggestions[:8]  # Return top 8 suggestions
+
+
+# Remove this function entirely - we don't need external API calls for Ollama
+
+
+# Remove this too - no need for fallback
+
+
+def suggest_similar_models(invalid_model: str, available_models: List[str]) -> List[str]:
+    """Use difflib to find similar model names"""
+    if not available_models:
+        return []
+
+    # Get close matches using fuzzy matching
+    suggestions = difflib.get_close_matches(
+        invalid_model, available_models, n=3, cutoff=0.3
+    )
+    return suggestions
+
+
+def check_hf_model_exists(model_name: str) -> bool:
+    """Quick check if HuggingFace model exists without downloading"""
+    try:
+        from huggingface_hub import model_info
+        model_info(model_name)
+        return True
+    except Exception:
+        return False
+
+
+def get_popular_hf_models() -> List[str]:
+    """Return a list of popular HuggingFace models for suggestions"""
+    try:
+        from huggingface_hub import list_models
+
+        # Get popular text-generation models, sorted by downloads
+        models = list_models(
+            filter="text-generation",
+            sort="downloads",
+            direction=-1,
+            limit=20  # Get top 20 most downloaded
+        )
+
+        # Extract model names and filter for chat/conversation models
+        model_names = []
+        chat_keywords = ['chat', 'instruct', 'dialog', 'conversation', 'assistant']
+
+        for model in models:
+            model_name = model.id if hasattr(model, 'id') else str(model)
+            # Prioritize models with chat-related keywords
+            if any(keyword in model_name.lower() for keyword in chat_keywords):
+                model_names.append(model_name)
+            elif len(model_names) < 10:  # Fill up with other popular models
+                model_names.append(model_name)
+
+        return model_names[:10] if model_names else _get_fallback_hf_models()
+
+    except Exception:
+        # Fallback to static list if API call fails
+        return _get_fallback_hf_models()
+
+
+def _get_fallback_hf_models() -> List[str]:
+    """Fallback list of popular HuggingFace models"""
+    return [
+        "microsoft/DialoGPT-medium",
+        "microsoft/DialoGPT-large",
+        "facebook/blenderbot-400M-distill",
+        "microsoft/phi-2",
+        "deepseek-ai/deepseek-llm-7b-chat",
+        "microsoft/DialoGPT-small",
+        "facebook/blenderbot_small-90M",
+        "microsoft/phi-1_5",
+        "facebook/opt-350m",
+        "EleutherAI/gpt-neo-1.3B"
+    ]
+
+
+def search_hf_models_fuzzy(query: str, limit: int = 10) -> List[str]:
+    """Use HuggingFace Hub's native fuzzy search for model suggestions"""
+    try:
+        from huggingface_hub import list_models
+
+        # HF Hub's search is already fuzzy! It handles typos and partial matches
+        models = list_models(
+            search=query,
+            filter="text-generation",
+            sort="downloads",
+            direction=-1,
+            limit=limit
+        )
+
+        model_names = [model.id if hasattr(model, 'id') else str(model) for model in models]
+
+        # If direct search doesn't return enough results, try some variations
+        if len(model_names) < 3:
+            # Try searching for partial matches or common variations
+            variations = []
+
+            # Extract base name (e.g., "gpt3" from "gpt-3.5")
+            base_query = query.lower().replace('-', '').replace('.', '').replace('_', '')
+            if base_query != query.lower():
+                variations.append(base_query)
+
+            # Try common model name patterns
+            if 'gpt' in query.lower():
+                variations.extend(['gpt2', 'gpt-neo', 'gpt-j', 'dialoGPT'])
+            elif 'llama' in query.lower():
+                variations.extend(['llama2', 'alpaca', 'vicuna'])
+            elif 'bert' in query.lower():
+                variations.extend(['roberta', 'distilbert', 'albert'])
+
+            # Search with variations
+            for var in variations[:2]:  # Limit to 2 variations to avoid too many API calls
+                try:
+                    var_models = list_models(
+                        search=var,
+                        filter="text-generation",
+                        sort="downloads",
+                        direction=-1,
+                        limit=3
+                    )
+                    var_names = [model.id if hasattr(model, 'id') else str(model) for model in var_models]
+                    model_names.extend(var_names)
+                except:
+                    continue
+
+        # Remove duplicates while preserving order
+        seen = set()
+        unique_models = []
+        for model in model_names:
+            if model not in seen:
+                seen.add(model)
+                unique_models.append(model)
+
+        return unique_models[:limit]
+
+    except Exception:
+        # If search fails, return empty list
+        return []
+
+
+def search_hf_models(query: str, limit: int = 10) -> List[str]:
+    """Simple search for HuggingFace models based on query (kept for backward compatibility)"""
+    return search_hf_models_fuzzy(query, limit)
+
+
+def validate_model_and_suggest(model_name: str, llm_type: str) -> Optional[str]:
+    """Validate model name and provide suggestions if invalid"""
+    if llm_type == "ollama":
+        available_models = check_ollama_models()
+        if available_models and model_name not in available_models:
+            # Use intelligent fuzzy search based on locally installed models
+            suggestions = search_ollama_models_fuzzy(model_name, available_models)
+
+            error_msg = f"Model '{model_name}' not found in your local Ollama installation."
+            if suggestions:
+                error_msg += "\n\nDid you mean one of these installed models?\n"
+                for i, suggestion in enumerate(suggestions, 1):
+                    error_msg += f"  {i}. {suggestion}\n"
+            else:
+                error_msg += "\n\nYour installed models:\n"
+                for i, model in enumerate(available_models[:8], 1):
+                    error_msg += f"  {i}. {model}\n"
+                if len(available_models) > 8:
+                    error_msg += f"  ... and {len(available_models) - 8} more\n"
+
+            error_msg += "\nTo list all models: ollama list"
+            error_msg += "\nTo download a new model: ollama pull <model_name>"
+            error_msg += "\nBrowse models: https://ollama.com/library"
+            return error_msg
+
+    elif llm_type == "hf":
+        # For HF models, we can do a quick existence check
+        if not check_hf_model_exists(model_name):
+            # Use HF Hub's native fuzzy search directly
+            search_suggestions = search_hf_models_fuzzy(model_name, limit=8)
+
+            error_msg = f"Model '{model_name}' not found on HuggingFace Hub."
+            if search_suggestions:
+                error_msg += "\n\nDid you mean one of these?\n"
+                for i, suggestion in enumerate(search_suggestions, 1):
+                    error_msg += f"  {i}. {suggestion}\n"
{suggestion}\n" + else: + # Fallback to popular models if search returns nothing + popular_models = get_popular_hf_models() + error_msg += "\n\nPopular chat models:\n" + for i, model in enumerate(popular_models[:5], 1): + error_msg += f" {i}. {model}\n" + + error_msg += f"\nSearch more: https://huggingface.co/models?search={model_name}&pipeline_tag=text-generation" + return error_msg + + return None # Model is valid or we can't check + + class LLMInterface(ABC): """Abstract base class for a generic Language Model (LLM) interface.""" + @abstractmethod def ask(self, prompt: str, **kwargs) -> str: """ @@ -32,7 +307,7 @@ class LLMInterface(ABC): batch_recompute=True, global_pruning=True ) - + Supported kwargs: - complexity (int): Search complexity parameter (default: 32) - beam_width (int): Beam width for search (default: 4) @@ -57,22 +332,37 @@ class LLMInterface(ABC): # """ pass + class OllamaChat(LLMInterface): """LLM interface for Ollama models.""" + def __init__(self, model: str = "llama3:8b", host: str = "http://localhost:11434"): self.model = model self.host = host logger.info(f"Initializing OllamaChat with model='{model}' and host='{host}'") try: import requests + # Check if the Ollama server is responsive if host: requests.get(host) + + # Pre-check model availability with helpful suggestions + model_error = validate_model_and_suggest(model, "ollama") + if model_error: + raise ValueError(model_error) + except ImportError: - raise ImportError("The 'requests' library is required for Ollama. Please install it with 'pip install requests'.") + raise ImportError( + "The 'requests' library is required for Ollama. Please install it with 'pip install requests'." + ) except requests.exceptions.ConnectionError: - logger.error(f"Could not connect to Ollama at {host}. Please ensure Ollama is running.") - raise ConnectionError(f"Could not connect to Ollama at {host}. Please ensure Ollama is running.") + logger.error( + f"Could not connect to Ollama at {host}. Please ensure Ollama is running." + ) + raise ConnectionError( + f"Could not connect to Ollama at {host}. Please ensure Ollama is running." + ) def ask(self, prompt: str, **kwargs) -> str: import requests @@ -83,15 +373,15 @@ class OllamaChat(LLMInterface): "model": self.model, "prompt": prompt, "stream": False, # Keep it simple for now - "options": kwargs + "options": kwargs, } logger.info(f"Sending request to Ollama: {payload}") try: response = requests.post(full_url, data=json.dumps(payload)) response.raise_for_status() - + # The response from Ollama can be a stream of JSON objects, handle this - response_parts = response.text.strip().split('\n') + response_parts = response.text.strip().split("\n") full_response = "" for part in response_parts: if part: @@ -104,15 +394,25 @@ class OllamaChat(LLMInterface): logger.error(f"Error communicating with Ollama: {e}") return f"Error: Could not get a response from Ollama. Details: {e}" + class HFChat(LLMInterface): """LLM interface for local Hugging Face Transformers models.""" + def __init__(self, model_name: str = "deepseek-ai/deepseek-llm-7b-chat"): logger.info(f"Initializing HFChat with model='{model_name}'") + + # Pre-check model availability with helpful suggestions + model_error = validate_model_and_suggest(model_name, "hf") + if model_error: + raise ValueError(model_error) + try: - from transformers import pipeline + from transformers.pipelines import pipeline import torch except ImportError: - raise ImportError("The 'transformers' and 'torch' libraries are required for Hugging Face models. 
+            raise ImportError(
+                "The 'transformers' and 'torch' libraries are required for Hugging Face models. Please install them with 'pip install transformers torch'."
+            )
 
         # Auto-detect device
         if torch.cuda.is_available():
@@ -140,47 +440,54 @@ class HFChat(LLMInterface):
             # Remove unsupported zero temperature and use deterministic generation
             kwargs.pop("temperature")
             kwargs.setdefault("do_sample", False)
-        
+
         # Sensible defaults for text generation
-        params = {
-            "max_length": 500,
-            "num_return_sequences": 1,
-            **kwargs
-        }
+        params = {"max_length": 500, "num_return_sequences": 1, **kwargs}
 
         logger.info(f"Generating text with Hugging Face model with params: {params}")
         results = self.pipeline(prompt, **params)
-        
+
         # Handle different response formats from transformers
         if isinstance(results, list) and len(results) > 0:
-            generated_text = results[0].get('generated_text', '') if isinstance(results[0], dict) else str(results[0])
+            generated_text = (
+                results[0].get("generated_text", "")
+                if isinstance(results[0], dict)
+                else str(results[0])
+            )
         else:
             generated_text = str(results)
-        
+
         # Extract only the newly generated portion by removing the original prompt
         if isinstance(generated_text, str) and generated_text.startswith(prompt):
-            response = generated_text[len(prompt):].strip()
+            response = generated_text[len(prompt) :].strip()
         else:
             # Fallback: return the full response if prompt removal fails
             response = str(generated_text)
-        
+
         return response
 
+
 class OpenAIChat(LLMInterface):
     """LLM interface for OpenAI models."""
+
     def __init__(self, model: str = "gpt-4o", api_key: Optional[str] = None):
         self.model = model
         self.api_key = api_key or os.getenv("OPENAI_API_KEY")
-        
+
         if not self.api_key:
-            raise ValueError("OpenAI API key is required. Set OPENAI_API_KEY environment variable or pass api_key parameter.")
-        
+            raise ValueError(
+                "OpenAI API key is required. Set OPENAI_API_KEY environment variable or pass api_key parameter."
+            )
+
         logger.info(f"Initializing OpenAI Chat with model='{model}'")
-        
+
         try:
             import openai
+
             self.client = openai.OpenAI(api_key=self.api_key)
         except ImportError:
-            raise ImportError("The 'openai' library is required for OpenAI models. Please install it with 'pip install openai'.")
+            raise ImportError(
+                "The 'openai' library is required for OpenAI models. Please install it with 'pip install openai'."
+            )
 
     def ask(self, prompt: str, **kwargs) -> str:
         # Default parameters for OpenAI
@@ -189,11 +496,15 @@ class OpenAIChat(LLMInterface):
             "messages": [{"role": "user", "content": prompt}],
             "max_tokens": kwargs.get("max_tokens", 1000),
             "temperature": kwargs.get("temperature", 0.7),
-            **{k: v for k, v in kwargs.items() if k not in ["max_tokens", "temperature"]}
+            **{
+                k: v
+                for k, v in kwargs.items()
+                if k not in ["max_tokens", "temperature"]
+            },
         }
-        
+
         logger.info(f"Sending request to OpenAI with model {self.model}")
-        
+
         try:
             response = self.client.chat.completions.create(**params)
             return response.choices[0].message.content.strip()
         except Exception as e:
@@ -201,13 +512,16 @@ class OpenAIChat(LLMInterface):
             logger.error(f"Error communicating with OpenAI: {e}")
             return f"Error: Could not get a response from OpenAI. Details: {e}"
 
+
 class SimulatedChat(LLMInterface):
     """A simple simulated chat for testing and development."""
+
     def ask(self, prompt: str, **kwargs) -> str:
         logger.info("Simulating LLM call...")
         print("Prompt sent to LLM (simulation):", prompt[:500] + "...")
         return "This is a simulated answer from the LLM based on the retrieved context."
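Aside (not part of the patch): a minimal sketch of how the model-validation helpers added earlier in this diff are meant to be exercised. It assumes the patched module is importable as leann.chat and that a local Ollama daemon may or may not be running; check_ollama_models() simply returns an empty list when the server is unreachable, and the misspelled model tag below is a hypothetical example.

# Sketch only: exercising the fuzzy-suggestion helpers introduced in this diff.
from leann.chat import (
    check_ollama_models,
    search_ollama_models_fuzzy,
    validate_model_and_suggest,
)

installed = check_ollama_models()  # [] when Ollama is not reachable
# A misspelled query still matches through the substring/base-name passes,
# e.g. "lama3" surfaces any installed "llama3:*" tags.
print(search_ollama_models_fuzzy("lama3", installed))

# Returns None when the name looks valid (or cannot be checked), otherwise a
# human-readable error string with numbered suggestions.
message = validate_model_and_suggest("qwen9:999b", "ollama")
if message:
    print(message)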
+
 def get_llm(llm_config: Optional[Dict[str, Any]] = None) -> LLMInterface:
     """
     Factory function to get an LLM interface based on configuration.
@@ -225,16 +539,19 @@ def get_llm(llm_config: Optional[Dict[str, Any]] = None) -> LLMInterface:
         llm_config = {
             "type": "openai",
             "model": "gpt-4o",
-            "api_key": os.getenv("OPENAI_API_KEY")
+            "api_key": os.getenv("OPENAI_API_KEY"),
         }
 
     llm_type = llm_config.get("type", "openai")
     model = llm_config.get("model")
-    
+
     logger.info(f"Attempting to create LLM of type='{llm_type}' with model='{model}'")
 
     if llm_type == "ollama":
-        return OllamaChat(model=model or "llama3:8b", host=llm_config.get("host", "http://localhost:11434"))
+        return OllamaChat(
+            model=model or "llama3:8b",
+            host=llm_config.get("host", "http://localhost:11434"),
+        )
     elif llm_type == "hf":
         return HFChat(model_name=model or "deepseek-ai/deepseek-llm-7b-chat")
     elif llm_type == "openai":
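Aside (not part of the patch): a usage sketch of the get_llm factory with the same config shape the CLI example builds. The import path leann.chat and a locally running Ollama server with a llama3:8b tag are assumptions; with this patch, a misspelled model name raises ValueError carrying the fuzzy suggestions shown above.

# Sketch only: driving the refactored factory from application code.
from leann.chat import get_llm

llm = get_llm(
    {
        "type": "ollama",  # or "hf" / "openai", matching the branches above
        "model": "llama3:8b",
        "host": "http://localhost:11434",
    }
)
# Extra kwargs are forwarded to the backend (Ollama "options" in this case).
print(llm.ask("Summarize what a vector index is.", temperature=0.2))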