diff --git a/examples/main_cli_example.py b/examples/main_cli_example.py
index 9fd7a92..5af7436 100644
--- a/examples/main_cli_example.py
+++ b/examples/main_cli_example.py
@@ -1,7 +1,3 @@
-import faulthandler
-
-faulthandler.enable()
-
 import argparse
 from llama_index.core import SimpleDirectoryReader, Settings
 from llama_index.core.node_parser import SentenceSplitter
@@ -62,7 +58,7 @@ async def main(args):
 
     print(f"\n[PHASE 2] Starting Leann chat session...")
 
-    llm_config = {"type": "hf", "model": "Qwen/Qwen3-8B"}
+    llm_config = {"type": "hf", "model": "Qwen/Qwen3-4B"}
 
     chat = LeannChat(index_path=INDEX_PATH, llm_config=llm_config)
 
diff --git a/packages/leann-core/src/leann/api.py b/packages/leann-core/src/leann/api.py
index 7a60b8b..6a067e7 100644
--- a/packages/leann-core/src/leann/api.py
+++ b/packages/leann-core/src/leann/api.py
@@ -7,7 +7,7 @@ import json
 import pickle
 import numpy as np
 from pathlib import Path
-from typing import List, Dict, Any, Optional
+from typing import List, Dict, Any, Optional, Literal
 from dataclasses import dataclass, field
 import uuid
 import torch
@@ -250,22 +250,41 @@ class LeannSearcher:
         final_kwargs["enable_warmup"] = enable_warmup
         self.backend_impl = backend_factory.searcher(index_path, **final_kwargs)
 
-    def search(self, query: str, top_k: int = 5, **search_kwargs) -> List[SearchResult]:
+    def search(
+        self,
+        query: str,
+        top_k: int = 5,
+        complexity: int = 64,
+        beam_width: int = 1,
+        prune_ratio: float = 0.0,
+        recompute_embeddings: bool = False,
+        pruning_strategy: Literal["global", "local", "proportional"] = "global",
+        zmq_port: int = 5557,
+        **kwargs,
+    ) -> List[SearchResult]:
         print("🔍 DEBUG LeannSearcher.search() called:")
         print(f" Query: '{query}'")
         print(f" Top_k: {top_k}")
-        print(f" Search kwargs: {search_kwargs}")
+        print(f" Additional kwargs: {kwargs}")
 
-        query_embedding = compute_embeddings(
-            [query], self.embedding_model, self.use_mlx
-        )
+        # Use backend's compute_query_embedding method
+        # This will automatically use embedding server if available and needed
+        query_embedding = self.backend_impl.compute_query_embedding(query, zmq_port)
         print(f" Generated embedding shape: {query_embedding.shape}")
         print(f"🔍 DEBUG Query embedding first 10 values: {query_embedding[0][:10]}")
         print(f"🔍 DEBUG Query embedding norm: {np.linalg.norm(query_embedding[0])}")
 
-        # Add use_mlx to search kwargs
-        search_kwargs["use_mlx"] = self.use_mlx
-        results = self.backend_impl.search(query_embedding, top_k, **search_kwargs)
+        results = self.backend_impl.search(
+            query_embedding,
+            top_k,
+            complexity=complexity,
+            beam_width=beam_width,
+            prune_ratio=prune_ratio,
+            recompute_embeddings=recompute_embeddings,
+            pruning_strategy=pruning_strategy,
+            zmq_port=zmq_port,
+            **kwargs,
+        )
         print(
             f" Backend returned: labels={len(results.get('labels', [[]])[0])} results"
         )
@@ -309,8 +328,33 @@ class LeannChat:
         self.searcher = LeannSearcher(index_path, enable_warmup=enable_warmup, **kwargs)
         self.llm = get_llm(llm_config)
 
-    def ask(self, question: str, top_k=5, **kwargs):
-        results = self.searcher.search(question, top_k=top_k, **kwargs)
+    def ask(
+        self,
+        question: str,
+        top_k: int = 5,
+        complexity: int = 64,
+        beam_width: int = 1,
+        prune_ratio: float = 0.0,
+        recompute_embeddings: bool = False,
+        pruning_strategy: Literal["global", "local", "proportional"] = "global",
+        zmq_port: int = 5557,
+        llm_kwargs: Optional[Dict[str, Any]] = None,
+        **search_kwargs,
+    ):
+        if llm_kwargs is None:
+            llm_kwargs = {}
+
+        results = self.searcher.search(
+            question,
+            top_k=top_k,
+            complexity=complexity,
+            beam_width=beam_width,
+            prune_ratio=prune_ratio,
+            recompute_embeddings=recompute_embeddings,
+            pruning_strategy=pruning_strategy,
+            zmq_port=zmq_port,
+            **search_kwargs,
+        )
         context = "\n\n".join([r.text for r in results])
         prompt = (
             "Here is some retrieved context that might help answer your question:\n\n"
@@ -318,7 +362,7 @@ class LeannChat:
             f"Question: {question}\n\n"
             "Please provide the best answer you can based on this context and your knowledge."
         )
-        return self.llm.ask(prompt, **kwargs.get("llm_kwargs", {}))
+        return self.llm.ask(prompt, **llm_kwargs)
 
     def start_interactive(self):
         print("\nLeann Chat started (type 'quit' to exit)")
diff --git a/packages/leann-core/src/leann/interface.py b/packages/leann-core/src/leann/interface.py
index 43b76fa..fc7c71b 100644
--- a/packages/leann-core/src/leann/interface.py
+++ b/packages/leann-core/src/leann/interface.py
@@ -2,13 +2,16 @@ from abc import ABC, abstractmethod
 import numpy as np
 from typing import Dict, Any, List, Literal
 
+
 class LeannBackendBuilderInterface(ABC):
     """Backend interface for building indexes"""
-    
-    @abstractmethod
-    def build(self, data: np.ndarray, ids: List[str], index_path: str, **kwargs) -> None:
+
+    @abstractmethod
+    def build(
+        self, data: np.ndarray, ids: List[str], index_path: str, **kwargs
+    ) -> None:
         """Build index
-        
+
         Args:
             data: Vector data (N, D)
             ids: List of string IDs for each vector
@@ -17,30 +20,35 @@ class LeannBackendBuilderInterface(ABC):
         """
         pass
 
+
 class LeannBackendSearcherInterface(ABC):
     """Backend interface for searching"""
-    
+
     @abstractmethod
     def __init__(self, index_path: str, **kwargs):
         """Initialize searcher
-        
+
         Args:
             index_path: Path to index file
             **kwargs: Backend-specific loading parameters
         """
         pass
-    
+
     @abstractmethod
-    def search(self, query: np.ndarray, top_k: int,
-               complexity: int = 64,
-               beam_width: int = 1,
-               prune_ratio: float = 0.0,
-               recompute_embeddings: bool = False,
-               pruning_strategy: Literal["global", "local", "proportional"] = "global",
-               zmq_port: int = 5557,
-               **kwargs) -> Dict[str, Any]:
+    def search(
+        self,
+        query: np.ndarray,
+        top_k: int,
+        complexity: int = 64,
+        beam_width: int = 1,
+        prune_ratio: float = 0.0,
+        recompute_embeddings: bool = False,
+        pruning_strategy: Literal["global", "local", "proportional"] = "global",
+        zmq_port: int = 5557,
+        **kwargs,
+    ) -> Dict[str, Any]:
         """Search for nearest neighbors
-        
+
         Args:
             query: Query vectors (B, D) where B is batch size, D is dimension
             top_k: Number of nearest neighbors to return
@@ -51,23 +59,40 @@ class LeannBackendSearcherInterface(ABC):
             pruning_strategy: PQ candidate selection strategy - "global" (default), "local", or "proportional"
             zmq_port: ZMQ port for embedding server communication
             **kwargs: Backend-specific parameters
-        
+
         Returns:
             {"labels": [...], "distances": [...]}
         """
         pass
 
+    @abstractmethod
+    def compute_query_embedding(
+        self, query: str, zmq_port: int = 5557, use_server_if_available: bool = True
+    ) -> np.ndarray:
+        """Compute embedding for a query string
+
+        Args:
+            query: The query string to embed
+            zmq_port: ZMQ port for embedding server
+            use_server_if_available: Whether to try using embedding server first
+
+        Returns:
+            Query embedding as numpy array with shape (1, D)
+        """
+        pass
+
+
 class LeannBackendFactoryInterface(ABC):
     """Backend factory interface"""
-    
+
     @staticmethod
     @abstractmethod
     def builder(**kwargs) -> LeannBackendBuilderInterface:
         """Create Builder instance"""
         pass
-    
+
     @staticmethod
-    @abstractmethod 
+    @abstractmethod
     def searcher(index_path: str,
                  **kwargs) -> LeannBackendSearcherInterface:
         """Create Searcher instance"""
-        pass
\ No newline at end of file
+        pass
diff --git a/packages/leann-core/src/leann/searcher_base.py b/packages/leann-core/src/leann/searcher_base.py
index 82069e3..55c9843 100644
--- a/packages/leann-core/src/leann/searcher_base.py
+++ b/packages/leann-core/src/leann/searcher_base.py
@@ -89,6 +89,72 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
         if not server_started:
             raise RuntimeError(f"Failed to start embedding server on port {port}")
 
+    def compute_query_embedding(
+        self, query: str, zmq_port: int = 5557, use_server_if_available: bool = True
+    ) -> np.ndarray:
+        """
+        Compute embedding for a query string.
+
+        Args:
+            query: The query string to embed
+            zmq_port: ZMQ port for embedding server
+            use_server_if_available: Whether to try using embedding server first
+
+        Returns:
+            Query embedding as numpy array
+        """
+        # Try to use embedding server if available and requested
+        if (
+            use_server_if_available
+            and self.embedding_server_manager
+            and self.embedding_server_manager.server_process
+        ):
+            try:
+                return self._compute_embedding_via_server([query], zmq_port)[
+                    0:1
+                ]  # Return (1, D) shape
+            except Exception as e:
+                print(f"⚠️ Embedding server failed: {e}")
+                print("⏭️ Falling back to direct model loading...")
+
+        # Fallback to direct computation
+        from .api import compute_embeddings
+
+        use_mlx = self.meta.get("use_mlx", False)
+        return compute_embeddings([query], self.embedding_model, use_mlx)
+
+    def _compute_embedding_via_server(self, chunks: list, zmq_port: int) -> np.ndarray:
+        """Compute embeddings using the ZMQ embedding server."""
+        import zmq
+        import msgpack
+
+        try:
+            context = zmq.Context()
+            socket = context.socket(zmq.REQ)
+            socket.setsockopt(zmq.RCVTIMEO, 30000)  # 30 second timeout
+            socket.connect(f"tcp://localhost:{zmq_port}")
+
+            # Send embedding request
+            request = chunks
+            request_bytes = msgpack.packb(request)
+            socket.send(request_bytes)
+
+            # Wait for response
+            response_bytes = socket.recv()
+            response = msgpack.unpackb(response_bytes)
+
+            socket.close()
+            context.term()
+
+            # Convert response to numpy array
+            if isinstance(response, list) and len(response) > 0:
+                return np.array(response, dtype=np.float32)
+            else:
+                raise RuntimeError("Invalid response from embedding server")
+
+        except Exception as e:
+            raise RuntimeError(f"Failed to compute embeddings via server: {e}")
+
     @abstractmethod
     def search(
         self,
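
Usage sketch (illustrative, not part of the patch): the api.py hunks replace the old **search_kwargs plumbing with explicit keyword arguments on LeannSearcher.search() and LeannChat.ask(), and route query embedding through the backend's compute_query_embedding(). A minimal sketch of calling the new signatures follows; the import path, index path, query text, and llm_kwargs contents are assumptions for illustration, and the keyword values shown are simply the defaults this diff introduces.

from leann.api import LeannChat, LeannSearcher  # import path assumed from the package layout

searcher = LeannSearcher("indexes/demo.leann")  # hypothetical index path
results = searcher.search(
    "how does LEANN prune the graph?",  # hypothetical query
    top_k=5,
    complexity=64,
    beam_width=1,
    prune_ratio=0.0,
    recompute_embeddings=False,
    pruning_strategy="global",
    zmq_port=5557,
)
for r in results:
    print(r.text)

chat = LeannChat(
    index_path="indexes/demo.leann",
    llm_config={"type": "hf", "model": "Qwen/Qwen3-4B"},
)
# llm_kwargs is now an explicit parameter instead of being pulled out of **kwargs;
# the generation options below are illustrative.
answer = chat.ask(
    "Summarize what the retrieved chunks say about pruning.",
    top_k=5,
    recompute_embeddings=True,
    llm_kwargs={"max_new_tokens": 256},
)
print(answer)

Making the search knobs explicit keyword arguments gives callers defaults and type hints from the signature itself, while **kwargs (and **search_kwargs in ask()) still forwards backend-specific options.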
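
The searcher_base.py hunk is the embedding path that search() now uses: try the msgpack-over-ZMQ embedding server first, then fall back to loading the model in-process via compute_embeddings. Continuing the sketch above, a rough illustration of calling it directly on a backend searcher; backend_impl access and the query string are illustrative, and the (1, D) shape follows the interface docstring.

backend = searcher.backend_impl  # concrete LeannBackendSearcherInterface implementation
query_embedding = backend.compute_query_embedding(
    "example query",
    zmq_port=5557,                 # port of the msgpack-over-ZMQ embedding server
    use_server_if_available=True,  # set False to always compute in-process
)
print(query_embedding.shape)  # (1, D)

If the server is unreachable, the base class catches the error and computes the embedding in-process, so search still works without a running embedding server.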