feat: chat on mps
@@ -147,8 +147,8 @@ class HNSWSearcher(BaseSearcher):
         params = faiss.SearchParametersHNSW()
         params.zmq_port = kwargs.get("zmq_port", 5557)
-        params.efSearch = kwargs.get("ef", 128)
-        params.beam_size = 2
+        params.efSearch = kwargs.get("complexity", 32)
+        params.beam_size = kwargs.get("beam_width", 1)

         batch_size = query.shape[0]
         distances = np.empty((batch_size, top_k), dtype=np.float32)
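For context, callers of HNSWSearcher now pass the renamed knobs through search-time kwargs. A minimal sketch of the new calling convention, assuming a constructed `searcher` and a `search(query, top_k, **kwargs)` method (both outside this diff, so the exact signature is an assumption):

    import numpy as np

    # Hypothetical batch of 4 query embeddings (dimension 768 is an example).
    query = np.random.rand(4, 768).astype(np.float32)
    distances, ids = searcher.search(
        query,
        top_k=10,
        complexity=64,   # forwarded to params.efSearch (default 32)
        beam_width=2,    # forwarded to params.beam_size (default 1)
    )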
@@ -18,15 +18,43 @@ class LLMInterface(ABC):
     @abstractmethod
     def ask(self, prompt: str, **kwargs) -> str:
         """
         Sends a prompt to the LLM and returns the generated text.

         Args:
             prompt: The input prompt for the LLM.
             **kwargs: Additional keyword arguments for the LLM backend.

         Returns:
             The response string from the LLM.
+
+        Additional keyword arguments (kwargs) allow advanced search customization. Example usage:
+            chat.ask(
+                "What is ANN?",
+                top_k=10,
+                complexity=64,
+                beam_width=8,
+                USE_DEFERRED_FETCH=True,
+                skip_search_reorder=True,
+                recompute_beighbor_embeddings=True,
+                dedup_node_dis=True,
+                prune_ratio=0.1,
+                batch_recompute=True,
+                global_pruning=True,
+            )
+
+        Supported kwargs:
+            - complexity (int): Search complexity parameter (default: 32)
+            - beam_width (int): Beam width for search (default: 4)
+            - USE_DEFERRED_FETCH (bool): Enable deferred fetch mode (default: False)
+            - skip_search_reorder (bool): Skip the search reorder step (default: False)
+            - recompute_beighbor_embeddings (bool): Enable the ZMQ embedding server for neighbor recomputation (default: False)
+            - dedup_node_dis (bool): Deduplicate nodes by distance (default: False)
+            - prune_ratio (float): Pruning ratio for search (default: 0.0)
+            - batch_recompute (bool): Enable batch recomputation (default: False)
+            - global_pruning (bool): Enable global pruning (default: False)
         """
         pass


 class OllamaChat(LLMInterface):
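To see how a concrete backend could honor this contract, here is a minimal sketch of a subclass that routes the documented search kwargs to the index and everything else to the generator. The `self.searcher` and `self._generate` helpers are hypothetical and not part of this diff:

    class ExampleChat(LLMInterface):
        # Hypothetical subclass: splits documented search knobs from
        # whatever the LLM backend accepts.
        SEARCH_KEYS = {
            "top_k", "complexity", "beam_width", "USE_DEFERRED_FETCH",
            "skip_search_reorder", "recompute_beighbor_embeddings",
            "dedup_node_dis", "prune_ratio", "batch_recompute", "global_pruning",
        }

        def ask(self, prompt: str, **kwargs) -> str:
            search_kwargs = {k: v for k, v in kwargs.items() if k in self.SEARCH_KEYS}
            llm_kwargs = {k: v for k, v in kwargs.items() if k not in self.SEARCH_KEYS}
            context = self.searcher.search(prompt, **search_kwargs)  # hypothetical retrieval call
            return self._generate(prompt, context, **llm_kwargs)     # hypothetical generation call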
@@ -82,10 +110,22 @@ class HFChat(LLMInterface):
         logger.info(f"Initializing HFChat with model='{model_name}'")
         try:
             from transformers import pipeline
+            import torch
         except ImportError:
-            raise ImportError("The 'transformers' library is required for Hugging Face models. Please install it with 'pip install transformers'.")
-
-        self.pipeline = pipeline("text-generation", model=model_name)
+            raise ImportError("The 'transformers' and 'torch' libraries are required for Hugging Face models. Please install them with 'pip install transformers torch'.")
+
+        # Auto-detect device
+        if torch.cuda.is_available():
+            device = "cuda"
+            logger.info("CUDA is available. Using GPU.")
+        elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
+            device = "mps"
+            logger.info("MPS is available. Using Apple Silicon GPU.")
+        else:
+            device = "cpu"
+            logger.info("No GPU detected. Using CPU.")
+
+        self.pipeline = pipeline("text-generation", model=model_name, device=device)

     def ask(self, prompt: str, **kwargs) -> str:
         # Sensible defaults for text generation
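Call sites do not change; the device is chosen once at construction time. A minimal usage sketch, assuming the module exposes HFChat as shown in the hunk above (the model name is an example, not from this diff):

    chat = HFChat("gpt2")  # hypothetical model name
    # On an Apple Silicon machine this logs "MPS is available. Using Apple Silicon GPU."
    print(chat.ask("What is ANN?"))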