feat: disable warmup by default
This commit is contained in:
@@ -163,6 +163,7 @@ def create_embedding_server_thread(
|
||||
max_batch_size=128,
|
||||
passages_file: Optional[str] = None,
|
||||
use_mlx: bool = False,
|
||||
enable_warmup: bool = False,
|
||||
):
|
||||
"""
|
||||
在当前线程中创建并运行 embedding server
|
||||
@@ -238,6 +239,62 @@ def create_embedding_server_thread(
|
||||
|
||||
print(f"INFO: Loaded {len(passages)} passages.")
|
||||
|
||||
def client_warmup(zmq_port):
|
||||
"""Perform client-side warmup for DiskANN server"""
|
||||
time.sleep(2)
|
||||
print(f"Performing client-side warmup with model {model_name}...")
|
||||
|
||||
# Get actual passage IDs from the loaded passages
|
||||
sample_ids = []
|
||||
if hasattr(passages, 'keys') and len(passages) > 0:
|
||||
available_ids = list(passages.keys())
|
||||
# Take up to 5 actual IDs, but at least 1
|
||||
sample_ids = available_ids[:min(5, len(available_ids))]
|
||||
print(f"Using actual passage IDs for warmup: {sample_ids}")
|
||||
else:
|
||||
print("No passages available for warmup, skipping warmup...")
|
||||
return
|
||||
|
||||
try:
|
||||
context = zmq.Context()
|
||||
socket = context.socket(zmq.REQ)
|
||||
socket.connect(f"tcp://localhost:{zmq_port}")
|
||||
socket.setsockopt(zmq.RCVTIMEO, 30000)
|
||||
socket.setsockopt(zmq.SNDTIMEO, 30000)
|
||||
|
||||
try:
|
||||
ids_to_send = [int(x) for x in sample_ids]
|
||||
except ValueError:
|
||||
print("Warning: Could not convert sample IDs to integers, skipping warmup")
|
||||
return
|
||||
|
||||
if not ids_to_send:
|
||||
print("Skipping warmup send.")
|
||||
return
|
||||
|
||||
# Use protobuf format for warmup
|
||||
from . import embedding_pb2
|
||||
req_proto = embedding_pb2.NodeEmbeddingRequest()
|
||||
req_proto.node_ids.extend(ids_to_send)
|
||||
request_bytes = req_proto.SerializeToString()
|
||||
|
||||
for i in range(3):
|
||||
print(f"Sending warmup request {i + 1}/3 via ZMQ (Protobuf)...")
|
||||
socket.send(request_bytes)
|
||||
response_bytes = socket.recv()
|
||||
|
||||
resp_proto = embedding_pb2.NodeEmbeddingResponse()
|
||||
resp_proto.ParseFromString(response_bytes)
|
||||
embeddings_count = resp_proto.dimensions[0] if resp_proto.dimensions else 0
|
||||
print(f"Warmup request {i + 1}/3 successful, received {embeddings_count} embeddings")
|
||||
time.sleep(0.1)
|
||||
|
||||
print("Client-side Protobuf ZMQ warmup complete")
|
||||
socket.close()
|
||||
context.term()
|
||||
except Exception as e:
|
||||
print(f"Error during Protobuf ZMQ warmup: {e}")
|
||||
|
||||
class DeviceTimer:
|
||||
"""设备计时器"""
|
||||
def __init__(self, name="", device=device):
|
||||
@@ -343,6 +400,16 @@ def create_embedding_server_thread(
|
||||
|
||||
print(f"INFO: Embedding server ready to serve requests")
|
||||
|
||||
# Start warmup thread if enabled
|
||||
if enable_warmup and len(passages) > 0:
|
||||
import threading
|
||||
print(f"Warmup enabled: starting warmup thread")
|
||||
warmup_thread = threading.Thread(target=client_warmup, args=(zmq_port,))
|
||||
warmup_thread.daemon = True
|
||||
warmup_thread.start()
|
||||
else:
|
||||
print(f"Warmup disabled or no passages available (enable_warmup={enable_warmup}, passages={len(passages)})")
|
||||
|
||||
while True:
|
||||
try:
|
||||
parts = socket.recv_multipart()
|
||||
@@ -587,12 +654,13 @@ def create_embedding_server(
|
||||
model_name="sentence-transformers/all-mpnet-base-v2",
|
||||
passages_file: Optional[str] = None,
|
||||
use_mlx: bool = False,
|
||||
enable_warmup: bool = False,
|
||||
):
|
||||
"""
|
||||
原有的 create_embedding_server 函数保持不变
|
||||
这个是阻塞版本,用于直接运行
|
||||
"""
|
||||
create_embedding_server_thread(zmq_port, model_name, max_batch_size, passages_file, use_mlx)
|
||||
create_embedding_server_thread(zmq_port, model_name, max_batch_size, passages_file, use_mlx, enable_warmup)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@@ -610,6 +678,7 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--model-name", type=str, default="sentence-transformers/all-mpnet-base-v2",
|
||||
help="Embedding model name")
|
||||
parser.add_argument("--use-mlx", action="store_true", default=False, help="Use MLX backend for embeddings")
|
||||
parser.add_argument("--disable-warmup", action="store_true", default=False, help="Disable warmup requests on server start")
|
||||
args = parser.parse_args()
|
||||
|
||||
create_embedding_server(
|
||||
@@ -625,4 +694,5 @@ if __name__ == "__main__":
|
||||
model_name=args.model_name,
|
||||
passages_file=args.passages_file,
|
||||
use_mlx=args.use_mlx,
|
||||
enable_warmup=not args.disable_warmup,
|
||||
)
|
||||
|
||||
@@ -151,6 +151,7 @@ def create_hnsw_embedding_server(
|
||||
custom_max_length_param: Optional[int] = None,
|
||||
distance_metric: str = "mips",
|
||||
use_mlx: bool = False,
|
||||
enable_warmup: bool = False,
|
||||
):
|
||||
"""
|
||||
Create and start a ZMQ-based embedding server for HNSW backend.
|
||||
@@ -167,6 +168,7 @@ def create_hnsw_embedding_server(
|
||||
model_name: Transformer model name
|
||||
custom_max_length_param: Custom max sequence length
|
||||
distance_metric: The distance metric to use
|
||||
enable_warmup: Whether to perform warmup requests on server start
|
||||
"""
|
||||
if not use_mlx:
|
||||
print(f"Loading tokenizer for {model_name}...")
|
||||
@@ -465,7 +467,17 @@ def create_hnsw_embedding_server(
|
||||
"""Perform client-side warmup"""
|
||||
time.sleep(2)
|
||||
print(f"Performing client-side warmup with model {model_name}...")
|
||||
sample_ids = ["0", "1", "2", "3", "4"]
|
||||
|
||||
# Get actual passage IDs from the loaded passages
|
||||
sample_ids = []
|
||||
if hasattr(passages, 'keys') and len(passages) > 0:
|
||||
available_ids = list(passages.keys())
|
||||
# Take up to 5 actual IDs, but at least 1
|
||||
sample_ids = available_ids[:min(5, len(available_ids))]
|
||||
print(f"Using actual passage IDs for warmup: {sample_ids}")
|
||||
else:
|
||||
print("No passages available for warmup, skipping warmup...")
|
||||
return
|
||||
|
||||
try:
|
||||
context = zmq.Context()
|
||||
@@ -477,7 +489,8 @@ def create_hnsw_embedding_server(
|
||||
try:
|
||||
ids_to_send = [int(x) for x in sample_ids]
|
||||
except ValueError:
|
||||
ids_to_send = []
|
||||
print("Warning: Could not convert sample IDs to integers, skipping warmup")
|
||||
return
|
||||
|
||||
if not ids_to_send:
|
||||
print("Skipping warmup send.")
|
||||
@@ -915,10 +928,13 @@ def create_hnsw_embedding_server(
|
||||
pass
|
||||
|
||||
# Start warmup and server threads
|
||||
if len(passages) > 0:
|
||||
if enable_warmup and len(passages) > 0:
|
||||
print(f"Warmup enabled: starting warmup thread")
|
||||
warmup_thread = threading.Thread(target=client_warmup, args=(zmq_port,))
|
||||
warmup_thread.daemon = True
|
||||
warmup_thread.start()
|
||||
else:
|
||||
print(f"Warmup disabled or no passages available (enable_warmup={enable_warmup}, passages={len(passages)})")
|
||||
|
||||
zmq_thread = threading.Thread(target=zmq_server_thread, daemon=True)
|
||||
zmq_thread.start()
|
||||
@@ -976,6 +992,12 @@ if __name__ == "__main__":
|
||||
default=False,
|
||||
help="Use MLX for model inference",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--disable-warmup",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Disable warmup requests on server start",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -992,4 +1014,5 @@ if __name__ == "__main__":
|
||||
custom_max_length_param=args.custom_max_length,
|
||||
distance_metric=args.distance_metric,
|
||||
use_mlx=args.use_mlx,
|
||||
enable_warmup=not args.disable_warmup,
|
||||
)
|
||||
|
||||
@@ -237,7 +237,7 @@ class LeannBuilder:
|
||||
|
||||
|
||||
class LeannSearcher:
|
||||
def __init__(self, index_path: str, **backend_kwargs):
|
||||
def __init__(self, index_path: str, enable_warmup: bool = False, **backend_kwargs):
|
||||
meta_path_str = f"{index_path}.meta.json"
|
||||
if not Path(meta_path_str).exists():
|
||||
raise FileNotFoundError(f"Leann metadata file not found at {meta_path_str}")
|
||||
@@ -251,6 +251,7 @@ class LeannSearcher:
|
||||
if backend_factory is None:
|
||||
raise ValueError(f"Backend '{backend_name}' not found.")
|
||||
final_kwargs = {**self.meta_data.get("backend_kwargs", {}), **backend_kwargs}
|
||||
final_kwargs["enable_warmup"] = enable_warmup
|
||||
self.backend_impl = backend_factory.searcher(index_path, **final_kwargs)
|
||||
|
||||
def search(self, query: str, top_k: int = 5, **search_kwargs) -> List[SearchResult]:
|
||||
@@ -306,9 +307,9 @@ from .chat import get_llm
|
||||
|
||||
class LeannChat:
|
||||
def __init__(
|
||||
self, index_path: str, llm_config: Optional[Dict[str, Any]] = None, **kwargs
|
||||
self, index_path: str, llm_config: Optional[Dict[str, Any]] = None, enable_warmup: bool = False, **kwargs
|
||||
):
|
||||
self.searcher = LeannSearcher(index_path, **kwargs)
|
||||
self.searcher = LeannSearcher(index_path, enable_warmup=enable_warmup, **kwargs)
|
||||
self.llm = get_llm(llm_config)
|
||||
|
||||
def ask(self, question: str, top_k=5, **kwargs):
|
||||
|
||||
@@ -184,7 +184,7 @@ class EmbeddingServerManager:
|
||||
Args:
|
||||
port (int): The ZMQ port for the server.
|
||||
model_name (str): The name of the embedding model to use.
|
||||
**kwargs: Additional arguments for the server (e.g., passages_file, distance_metric).
|
||||
**kwargs: Additional arguments for the server (e.g., passages_file, distance_metric, enable_warmup).
|
||||
|
||||
Returns:
|
||||
bool: True if the server is started successfully or already running, False otherwise.
|
||||
@@ -312,6 +312,8 @@ class EmbeddingServerManager:
|
||||
# command.extend(["--distance-metric", kwargs["distance_metric"]])
|
||||
if "use_mlx" in kwargs and kwargs["use_mlx"]:
|
||||
command.extend(["--use-mlx"])
|
||||
if "enable_warmup" in kwargs and not kwargs["enable_warmup"]:
|
||||
command.extend(["--disable-warmup"])
|
||||
|
||||
project_root = Path(__file__).parent.parent.parent.parent.parent
|
||||
print(f"INFO: Running command from project root: {project_root}")
|
||||
|
||||
@@ -79,6 +79,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
|
||||
passages_file=passages_source_file,
|
||||
distance_metric=kwargs.get("distance_metric"),
|
||||
use_mlx=kwargs.get("use_mlx", False),
|
||||
enable_warmup=kwargs.get("enable_warmup", False),
|
||||
)
|
||||
if not server_started:
|
||||
raise RuntimeError(f"Failed to start embedding server on port {port}")
|
||||
|
||||
Reference in New Issue
Block a user