feat: chat on mps

Author: Andy Lee
Date: 2025-07-12 06:07:43 +00:00
parent d288946173
commit ec5e9ac33b
5 changed files with 54 additions and 238 deletions


@@ -147,8 +147,8 @@ class HNSWSearcher(BaseSearcher):
         params = faiss.SearchParametersHNSW()
         params.zmq_port = kwargs.get("zmq_port", 5557)
-        params.efSearch = kwargs.get("ef", 128)
-        params.beam_size = 2
+        params.efSearch = kwargs.get("complexity", 32)
+        params.beam_size = kwargs.get("beam_width", 1)
         batch_size = query.shape[0]
         distances = np.empty((batch_size, top_k), dtype=np.float32)
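
For reference, the renamed kwargs map directly onto FAISS HNSW search parameters. A minimal sketch of that mapping as a standalone helper, assuming this project's patched FAISS build (stock FAISS exposes efSearch on SearchParametersHNSW, while zmq_port and beam_size are attributes specific to this fork):

import faiss

def build_hnsw_search_params(**kwargs):
    # Translate user-facing kwargs into FAISS HNSW search parameters,
    # mirroring the defaults used in HNSWSearcher above.
    params = faiss.SearchParametersHNSW()
    params.zmq_port = kwargs.get("zmq_port", 5557)     # embedding-server port (fork-specific attribute)
    params.efSearch = kwargs.get("complexity", 32)     # HNSW search queue size
    params.beam_size = kwargs.get("beam_width", 1)     # beam width per hop (fork-specific attribute)
    return params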


@@ -18,15 +18,43 @@ class LLMInterface(ABC):
     @abstractmethod
     def ask(self, prompt: str, **kwargs) -> str:
         """
-        Sends a prompt to the LLM and returns the generated text.
-        Args:
-            prompt: The input prompt for the LLM.
-            **kwargs: Additional keyword arguments for the LLM backend.
-        Returns:
-            The response string from the LLM.
+        Additional keyword arguments (kwargs) for advanced search customization. Example usage:
+        chat.ask(
+            "What is ANN?",
+            top_k=10,
+            complexity=64,
+            beam_width=8,
+            USE_DEFERRED_FETCH=True,
+            skip_search_reorder=True,
+            recompute_beighbor_embeddings=True,
+            dedup_node_dis=True,
+            prune_ratio=0.1,
+            batch_recompute=True,
+            global_pruning=True
+        )
+        Supported kwargs:
+        - complexity (int): Search complexity parameter (default: 32)
+        - beam_width (int): Beam width for search (default: 4)
+        - USE_DEFERRED_FETCH (bool): Enable deferred fetch mode (default: False)
+        - skip_search_reorder (bool): Skip search reorder step (default: False)
+        - recompute_beighbor_embeddings (bool): Enable ZMQ embedding server for neighbor recomputation (default: False)
+        - dedup_node_dis (bool): Deduplicate nodes by distance (default: False)
+        - prune_ratio (float): Pruning ratio for search (default: 0.0)
+        - batch_recompute (bool): Enable batch recomputation (default: False)
+        - global_pruning (bool): Enable global pruning (default: False)
         """
# """
# Sends a prompt to the LLM and returns the generated text.
# Args:
# prompt: The input prompt for the LLM.
# **kwargs: Additional keyword arguments for the LLM backend.
# Returns:
# The response string from the LLM.
# """
pass
class OllamaChat(LLMInterface):
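
A hypothetical sketch of how a concrete LLMInterface implementation might separate the search-related kwargs documented above from generation kwargs before dispatching; the helper name and the split itself are illustrative and not part of this commit, only the kwarg names come from the docstring:

SEARCH_KWARGS = {
    "top_k", "complexity", "beam_width", "USE_DEFERRED_FETCH",
    "skip_search_reorder", "recompute_beighbor_embeddings",
    "dedup_node_dis", "prune_ratio", "batch_recompute", "global_pruning",
}

def split_search_kwargs(kwargs: dict) -> tuple[dict, dict]:
    # Partition a mixed **kwargs dict into (search_kwargs, generation_kwargs)
    # using the kwarg names documented in LLMInterface.ask above.
    search = {k: v for k, v in kwargs.items() if k in SEARCH_KWARGS}
    generation = {k: v for k, v in kwargs.items() if k not in SEARCH_KWARGS}
    return search, generation
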
@@ -82,10 +110,22 @@ class HFChat(LLMInterface):
logger.info(f"Initializing HFChat with model='{model_name}'")
try:
from transformers import pipeline
import torch
except ImportError:
raise ImportError("The 'transformers' library is required for Hugging Face models. Please install it with 'pip install transformers'.")
self.pipeline = pipeline("text-generation", model=model_name)
raise ImportError("The 'transformers' and 'torch' libraries are required for Hugging Face models. Please install them with 'pip install transformers torch'.")
# Auto-detect device
if torch.cuda.is_available():
device = "cuda"
logger.info("CUDA is available. Using GPU.")
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
device = "mps"
logger.info("MPS is available. Using Apple Silicon GPU.")
else:
device = "cpu"
logger.info("No GPU detected. Using CPU.")
self.pipeline = pipeline("text-generation", model=model_name, device=device)
def ask(self, prompt: str, **kwargs) -> str:
# Sensible defaults for text generation
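
The device auto-detection added above can also be read as a small standalone helper. A minimal sketch, assuming only that torch is installed (torch.backends.mps is present on recent PyTorch builds for Apple Silicon):

import torch

def detect_device() -> str:
    # Pick the best available torch device, in the same priority order as
    # HFChat.__init__ above: CUDA first, then Apple MPS, then CPU.
    if torch.cuda.is_available():
        return "cuda"
    if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        return "mps"
    return "cpu"

# Example (illustrative): pass the detected device to a transformers pipeline.
# pipe = pipeline("text-generation", model=model_name, device=detect_device())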