diff --git a/packages/leann-backend-diskann/leann_backend_diskann/embedding_server.py b/packages/leann-backend-diskann/leann_backend_diskann/embedding_server.py index f7a59cd..d517c0f 100644 --- a/packages/leann-backend-diskann/leann_backend_diskann/embedding_server.py +++ b/packages/leann-backend-diskann/leann_backend_diskann/embedding_server.py @@ -163,6 +163,7 @@ def create_embedding_server_thread( max_batch_size=128, passages_file: Optional[str] = None, use_mlx: bool = False, + enable_warmup: bool = False, ): """ 在当前线程中创建并运行 embedding server @@ -238,6 +239,62 @@ def create_embedding_server_thread( print(f"INFO: Loaded {len(passages)} passages.") + def client_warmup(zmq_port): + """Perform client-side warmup for DiskANN server""" + time.sleep(2) + print(f"Performing client-side warmup with model {model_name}...") + + # Get actual passage IDs from the loaded passages + sample_ids = [] + if hasattr(passages, 'keys') and len(passages) > 0: + available_ids = list(passages.keys()) + # Take up to 5 actual IDs, but at least 1 + sample_ids = available_ids[:min(5, len(available_ids))] + print(f"Using actual passage IDs for warmup: {sample_ids}") + else: + print("No passages available for warmup, skipping warmup...") + return + + try: + context = zmq.Context() + socket = context.socket(zmq.REQ) + socket.connect(f"tcp://localhost:{zmq_port}") + socket.setsockopt(zmq.RCVTIMEO, 30000) + socket.setsockopt(zmq.SNDTIMEO, 30000) + + try: + ids_to_send = [int(x) for x in sample_ids] + except ValueError: + print("Warning: Could not convert sample IDs to integers, skipping warmup") + return + + if not ids_to_send: + print("Skipping warmup send.") + return + + # Use protobuf format for warmup + from . import embedding_pb2 + req_proto = embedding_pb2.NodeEmbeddingRequest() + req_proto.node_ids.extend(ids_to_send) + request_bytes = req_proto.SerializeToString() + + for i in range(3): + print(f"Sending warmup request {i + 1}/3 via ZMQ (Protobuf)...") + socket.send(request_bytes) + response_bytes = socket.recv() + + resp_proto = embedding_pb2.NodeEmbeddingResponse() + resp_proto.ParseFromString(response_bytes) + embeddings_count = resp_proto.dimensions[0] if resp_proto.dimensions else 0 + print(f"Warmup request {i + 1}/3 successful, received {embeddings_count} embeddings") + time.sleep(0.1) + + print("Client-side Protobuf ZMQ warmup complete") + socket.close() + context.term() + except Exception as e: + print(f"Error during Protobuf ZMQ warmup: {e}") + class DeviceTimer: """设备计时器""" def __init__(self, name="", device=device): @@ -343,6 +400,16 @@ def create_embedding_server_thread( print(f"INFO: Embedding server ready to serve requests") + # Start warmup thread if enabled + if enable_warmup and len(passages) > 0: + import threading + print(f"Warmup enabled: starting warmup thread") + warmup_thread = threading.Thread(target=client_warmup, args=(zmq_port,)) + warmup_thread.daemon = True + warmup_thread.start() + else: + print(f"Warmup disabled or no passages available (enable_warmup={enable_warmup}, passages={len(passages)})") + while True: try: parts = socket.recv_multipart() @@ -587,12 +654,13 @@ def create_embedding_server( model_name="sentence-transformers/all-mpnet-base-v2", passages_file: Optional[str] = None, use_mlx: bool = False, + enable_warmup: bool = False, ): """ 原有的 create_embedding_server 函数保持不变 这个是阻塞版本,用于直接运行 """ - create_embedding_server_thread(zmq_port, model_name, max_batch_size, passages_file, use_mlx) + create_embedding_server_thread(zmq_port, model_name, max_batch_size, passages_file, use_mlx, enable_warmup) if __name__ == "__main__": @@ -610,6 +678,7 @@ if __name__ == "__main__": parser.add_argument("--model-name", type=str, default="sentence-transformers/all-mpnet-base-v2", help="Embedding model name") parser.add_argument("--use-mlx", action="store_true", default=False, help="Use MLX backend for embeddings") + parser.add_argument("--disable-warmup", action="store_true", default=False, help="Disable warmup requests on server start") args = parser.parse_args() create_embedding_server( @@ -625,4 +694,5 @@ if __name__ == "__main__": model_name=args.model_name, passages_file=args.passages_file, use_mlx=args.use_mlx, + enable_warmup=not args.disable_warmup, ) diff --git a/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_embedding_server.py b/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_embedding_server.py index 624b8d3..c19f581 100644 --- a/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_embedding_server.py +++ b/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_embedding_server.py @@ -151,6 +151,7 @@ def create_hnsw_embedding_server( custom_max_length_param: Optional[int] = None, distance_metric: str = "mips", use_mlx: bool = False, + enable_warmup: bool = False, ): """ Create and start a ZMQ-based embedding server for HNSW backend. @@ -167,6 +168,7 @@ def create_hnsw_embedding_server( model_name: Transformer model name custom_max_length_param: Custom max sequence length distance_metric: The distance metric to use + enable_warmup: Whether to perform warmup requests on server start """ if not use_mlx: print(f"Loading tokenizer for {model_name}...") @@ -465,7 +467,17 @@ def create_hnsw_embedding_server( """Perform client-side warmup""" time.sleep(2) print(f"Performing client-side warmup with model {model_name}...") - sample_ids = ["0", "1", "2", "3", "4"] + + # Get actual passage IDs from the loaded passages + sample_ids = [] + if hasattr(passages, 'keys') and len(passages) > 0: + available_ids = list(passages.keys()) + # Take up to 5 actual IDs, but at least 1 + sample_ids = available_ids[:min(5, len(available_ids))] + print(f"Using actual passage IDs for warmup: {sample_ids}") + else: + print("No passages available for warmup, skipping warmup...") + return try: context = zmq.Context() @@ -477,7 +489,8 @@ def create_hnsw_embedding_server( try: ids_to_send = [int(x) for x in sample_ids] except ValueError: - ids_to_send = [] + print("Warning: Could not convert sample IDs to integers, skipping warmup") + return if not ids_to_send: print("Skipping warmup send.") @@ -915,10 +928,13 @@ def create_hnsw_embedding_server( pass # Start warmup and server threads - if len(passages) > 0: + if enable_warmup and len(passages) > 0: + print(f"Warmup enabled: starting warmup thread") warmup_thread = threading.Thread(target=client_warmup, args=(zmq_port,)) warmup_thread.daemon = True warmup_thread.start() + else: + print(f"Warmup disabled or no passages available (enable_warmup={enable_warmup}, passages={len(passages)})") zmq_thread = threading.Thread(target=zmq_server_thread, daemon=True) zmq_thread.start() @@ -976,6 +992,12 @@ if __name__ == "__main__": default=False, help="Use MLX for model inference", ) + parser.add_argument( + "--disable-warmup", + action="store_true", + default=False, + help="Disable warmup requests on server start", + ) args = parser.parse_args() @@ -992,4 +1014,5 @@ if __name__ == "__main__": custom_max_length_param=args.custom_max_length, distance_metric=args.distance_metric, use_mlx=args.use_mlx, + enable_warmup=not args.disable_warmup, ) diff --git a/packages/leann-core/src/leann/api.py b/packages/leann-core/src/leann/api.py index e8dca42..c112b28 100644 --- a/packages/leann-core/src/leann/api.py +++ b/packages/leann-core/src/leann/api.py @@ -237,7 +237,7 @@ class LeannBuilder: class LeannSearcher: - def __init__(self, index_path: str, **backend_kwargs): + def __init__(self, index_path: str, enable_warmup: bool = False, **backend_kwargs): meta_path_str = f"{index_path}.meta.json" if not Path(meta_path_str).exists(): raise FileNotFoundError(f"Leann metadata file not found at {meta_path_str}") @@ -251,6 +251,7 @@ class LeannSearcher: if backend_factory is None: raise ValueError(f"Backend '{backend_name}' not found.") final_kwargs = {**self.meta_data.get("backend_kwargs", {}), **backend_kwargs} + final_kwargs["enable_warmup"] = enable_warmup self.backend_impl = backend_factory.searcher(index_path, **final_kwargs) def search(self, query: str, top_k: int = 5, **search_kwargs) -> List[SearchResult]: @@ -306,9 +307,9 @@ from .chat import get_llm class LeannChat: def __init__( - self, index_path: str, llm_config: Optional[Dict[str, Any]] = None, **kwargs + self, index_path: str, llm_config: Optional[Dict[str, Any]] = None, enable_warmup: bool = False, **kwargs ): - self.searcher = LeannSearcher(index_path, **kwargs) + self.searcher = LeannSearcher(index_path, enable_warmup=enable_warmup, **kwargs) self.llm = get_llm(llm_config) def ask(self, question: str, top_k=5, **kwargs): diff --git a/packages/leann-core/src/leann/embedding_server_manager.py b/packages/leann-core/src/leann/embedding_server_manager.py index 1e3e849..9b0bf53 100644 --- a/packages/leann-core/src/leann/embedding_server_manager.py +++ b/packages/leann-core/src/leann/embedding_server_manager.py @@ -184,7 +184,7 @@ class EmbeddingServerManager: Args: port (int): The ZMQ port for the server. model_name (str): The name of the embedding model to use. - **kwargs: Additional arguments for the server (e.g., passages_file, distance_metric). + **kwargs: Additional arguments for the server (e.g., passages_file, distance_metric, enable_warmup). Returns: bool: True if the server is started successfully or already running, False otherwise. @@ -312,6 +312,8 @@ class EmbeddingServerManager: # command.extend(["--distance-metric", kwargs["distance_metric"]]) if "use_mlx" in kwargs and kwargs["use_mlx"]: command.extend(["--use-mlx"]) + if "enable_warmup" in kwargs and not kwargs["enable_warmup"]: + command.extend(["--disable-warmup"]) project_root = Path(__file__).parent.parent.parent.parent.parent print(f"INFO: Running command from project root: {project_root}") diff --git a/packages/leann-core/src/leann/searcher_base.py b/packages/leann-core/src/leann/searcher_base.py index c505ef3..b0f5ad3 100644 --- a/packages/leann-core/src/leann/searcher_base.py +++ b/packages/leann-core/src/leann/searcher_base.py @@ -79,6 +79,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC): passages_file=passages_source_file, distance_metric=kwargs.get("distance_metric"), use_mlx=kwargs.get("use_mlx", False), + enable_warmup=kwargs.get("enable_warmup", False), ) if not server_started: raise RuntimeError(f"Failed to start embedding server on port {port}")