refactor: chat and base searcher

This commit is contained in:
Andy Lee
2025-07-11 16:34:12 +00:00
parent 8bffb1e5b8
commit 0da08fbe38
5 changed files with 353 additions and 428 deletions

View File

@@ -0,0 +1,136 @@
#!/usr/bin/env python3
"""
This file contains the chat generation logic for the LEANN project,
supporting different backends like Ollama, Hugging Face Transformers, and a simulation mode.
"""
from abc import ABC, abstractmethod
from typing import Dict, Any, Optional
import logging
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class LLMInterface(ABC):
"""Abstract base class for a generic Language Model (LLM) interface."""
@abstractmethod
def ask(self, prompt: str, **kwargs) -> str:
"""
Sends a prompt to the LLM and returns the generated text.
Args:
prompt: The input prompt for the LLM.
**kwargs: Additional keyword arguments for the LLM backend.
Returns:
The response string from the LLM.
"""
pass
class OllamaChat(LLMInterface):
"""LLM interface for Ollama models."""
def __init__(self, model: str = "llama3:8b", host: str = "http://localhost:11434"):
self.model = model
self.host = host
logger.info(f"Initializing OllamaChat with model='{model}' and host='{host}'")
try:
import requests
# Check if the Ollama server is responsive
if host:
requests.get(host)
except ImportError:
raise ImportError("The 'requests' library is required for Ollama. Please install it with 'pip install requests'.")
except requests.exceptions.ConnectionError:
logger.error(f"Could not connect to Ollama at {host}. Please ensure Ollama is running.")
raise ConnectionError(f"Could not connect to Ollama at {host}. Please ensure Ollama is running.")
def ask(self, prompt: str, **kwargs) -> str:
import requests
import json
full_url = f"{self.host}/api/generate"
payload = {
"model": self.model,
"prompt": prompt,
"stream": False, # Keep it simple for now
"options": kwargs
}
logger.info(f"Sending request to Ollama: {payload}")
try:
response = requests.post(full_url, data=json.dumps(payload))
response.raise_for_status()
# The response from Ollama can be a stream of JSON objects, handle this
response_parts = response.text.strip().split('\n')
full_response = ""
for part in response_parts:
if part:
json_part = json.loads(part)
full_response += json_part.get("response", "")
if json_part.get("done"):
break
return full_response
except requests.exceptions.RequestException as e:
logger.error(f"Error communicating with Ollama: {e}")
return f"Error: Could not get a response from Ollama. Details: {e}"
class HFChat(LLMInterface):
"""LLM interface for local Hugging Face Transformers models."""
def __init__(self, model_name: str = "deepseek-ai/deepseek-llm-7b-chat"):
logger.info(f"Initializing HFChat with model='{model_name}'")
try:
from transformers import pipeline
except ImportError:
raise ImportError("The 'transformers' library is required for Hugging Face models. Please install it with 'pip install transformers'.")
self.pipeline = pipeline("text-generation", model=model_name)
def ask(self, prompt: str, **kwargs) -> str:
# Sensible defaults for text generation
params = {
"max_length": 500,
"num_return_sequences": 1,
**kwargs
}
logger.info(f"Generating text with Hugging Face model with params: {params}")
results = self.pipeline(prompt, **params)
return results[0]['generated_text']
class SimulatedChat(LLMInterface):
"""A simple simulated chat for testing and development."""
def ask(self, prompt: str, **kwargs) -> str:
logger.info("Simulating LLM call...")
print("Prompt sent to LLM (simulation):", prompt[:500] + "...")
return "This is a simulated answer from the LLM based on the retrieved context."
def get_llm(llm_config: Optional[Dict[str, Any]] = None) -> LLMInterface:
"""
Factory function to get an LLM interface based on configuration.
Args:
llm_config: A dictionary specifying the LLM type and its parameters.
Example: {"type": "ollama", "model": "llama3"}
{"type": "hf", "model": "distilgpt2"}
None (for simulation mode)
Returns:
An instance of an LLMInterface subclass.
"""
if llm_config is None:
logger.info("No LLM config provided, defaulting to simulated chat.")
return SimulatedChat()
llm_type = llm_config.get("type", "simulated")
model = llm_config.get("model")
logger.info(f"Attempting to create LLM of type='{llm_type}' with model='{model}'")
if llm_type == "ollama":
return OllamaChat(model=model, host=llm_config.get("host"))
elif llm_type == "hf":
return HFChat(model_name=model)
elif llm_type == "simulated":
return SimulatedChat()
else:
raise ValueError(f"Unknown LLM type: '{llm_type}'")

View File

@@ -0,0 +1,97 @@
import json
import pickle
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Dict, Any, List
import numpy as np
from .embedding_server_manager import EmbeddingServerManager
from .interface import LeannBackendSearcherInterface
class BaseSearcher(LeannBackendSearcherInterface, ABC):
"""
Abstract base class for Leann searchers, containing common logic for
loading metadata, managing embedding servers, and handling file paths.
"""
def __init__(self, index_path: str, backend_module_name: str, **kwargs):
"""
Initializes the BaseSearcher.
Args:
index_path: Path to the Leann index file (e.g., '.../my_index.leann').
backend_module_name: The specific embedding server module to use
(e.g., 'leann_backend_hnsw.hnsw_embedding_server').
**kwargs: Additional keyword arguments.
"""
self.index_path = Path(index_path)
self.index_dir = self.index_path.parent
self.meta = kwargs.get("meta", self._load_meta())
if not self.meta:
raise ValueError("Searcher requires metadata from .meta.json.")
self.dimensions = self.meta.get("dimensions")
if not self.dimensions:
raise ValueError("Dimensions not found in Leann metadata.")
self.embedding_model = self.meta.get("embedding_model")
if not self.embedding_model:
print("WARNING: embedding_model not found in meta.json. Recompute will fail.")
self.label_map = self._load_label_map()
self.embedding_server_manager = EmbeddingServerManager(
backend_module_name=backend_module_name
)
def _load_meta(self) -> Dict[str, Any]:
"""Loads the metadata file associated with the index."""
# This is the corrected logic for finding the meta file.
meta_path = self.index_dir / f"{self.index_path.name}.meta.json"
if not meta_path.exists():
raise FileNotFoundError(f"Leann metadata file not found at {meta_path}")
with open(meta_path, 'r', encoding='utf-8') as f:
return json.load(f)
def _load_label_map(self) -> Dict[int, str]:
"""Loads the mapping from integer IDs to string IDs."""
label_map_file = self.index_dir / "leann.labels.map"
if not label_map_file.exists():
raise FileNotFoundError(f"Label map file not found: {label_map_file}")
with open(label_map_file, 'rb') as f:
return pickle.load(f)
def _ensure_server_running(self, passages_source_file: str, port: int, **kwargs) -> None:
"""
Ensures the embedding server is running if recompute is needed.
This is a helper for subclasses.
"""
if not self.embedding_model:
raise ValueError("Cannot use recompute mode without 'embedding_model' in meta.json.")
server_started = self.embedding_server_manager.start_server(
port=port,
model_name=self.embedding_model,
passages_file=passages_source_file,
distance_metric=kwargs.get("distance_metric"),
)
if not server_started:
raise RuntimeError(f"Failed to start embedding server on port {kwargs.get('zmq_port')}")
@abstractmethod
def search(self, query: np.ndarray, top_k: int, **kwargs) -> Dict[str, Any]:
"""
Search for the top_k nearest neighbors of the query vector.
Must be implemented by subclasses.
"""
pass
def __del__(self):
"""Ensures the embedding server is stopped when the searcher is destroyed."""
if hasattr(self, 'embedding_server_manager'):
self.embedding_server_manager.stop_server()