feat: disable warmup by default
This commit is contained in:
@@ -151,6 +151,7 @@ def create_hnsw_embedding_server(
|
||||
custom_max_length_param: Optional[int] = None,
|
||||
distance_metric: str = "mips",
|
||||
use_mlx: bool = False,
|
||||
enable_warmup: bool = False,
|
||||
):
|
||||
"""
|
||||
Create and start a ZMQ-based embedding server for HNSW backend.
|
||||
@@ -167,6 +168,7 @@ def create_hnsw_embedding_server(
|
||||
model_name: Transformer model name
|
||||
custom_max_length_param: Custom max sequence length
|
||||
distance_metric: The distance metric to use
|
||||
enable_warmup: Whether to perform warmup requests on server start
|
||||
"""
|
||||
if not use_mlx:
|
||||
print(f"Loading tokenizer for {model_name}...")
|
||||
@@ -465,7 +467,17 @@ def create_hnsw_embedding_server(
|
||||
"""Perform client-side warmup"""
|
||||
time.sleep(2)
|
||||
print(f"Performing client-side warmup with model {model_name}...")
|
||||
sample_ids = ["0", "1", "2", "3", "4"]
|
||||
|
||||
# Get actual passage IDs from the loaded passages
|
||||
sample_ids = []
|
||||
if hasattr(passages, 'keys') and len(passages) > 0:
|
||||
available_ids = list(passages.keys())
|
||||
# Take up to 5 actual IDs, but at least 1
|
||||
sample_ids = available_ids[:min(5, len(available_ids))]
|
||||
print(f"Using actual passage IDs for warmup: {sample_ids}")
|
||||
else:
|
||||
print("No passages available for warmup, skipping warmup...")
|
||||
return
|
||||
|
||||
try:
|
||||
context = zmq.Context()
|
||||
@@ -477,7 +489,8 @@ def create_hnsw_embedding_server(
|
||||
try:
|
||||
ids_to_send = [int(x) for x in sample_ids]
|
||||
except ValueError:
|
||||
ids_to_send = []
|
||||
print("Warning: Could not convert sample IDs to integers, skipping warmup")
|
||||
return
|
||||
|
||||
if not ids_to_send:
|
||||
print("Skipping warmup send.")
|
||||
@@ -915,10 +928,13 @@ def create_hnsw_embedding_server(
|
||||
pass
|
||||
|
||||
# Start warmup and server threads
|
||||
if len(passages) > 0:
|
||||
if enable_warmup and len(passages) > 0:
|
||||
print(f"Warmup enabled: starting warmup thread")
|
||||
warmup_thread = threading.Thread(target=client_warmup, args=(zmq_port,))
|
||||
warmup_thread.daemon = True
|
||||
warmup_thread.start()
|
||||
else:
|
||||
print(f"Warmup disabled or no passages available (enable_warmup={enable_warmup}, passages={len(passages)})")
|
||||
|
||||
zmq_thread = threading.Thread(target=zmq_server_thread, daemon=True)
|
||||
zmq_thread.start()
|
||||
@@ -976,6 +992,12 @@ if __name__ == "__main__":
|
||||
default=False,
|
||||
help="Use MLX for model inference",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--disable-warmup",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Disable warmup requests on server start",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -992,4 +1014,5 @@ if __name__ == "__main__":
|
||||
custom_max_length_param=args.custom_max_length,
|
||||
distance_metric=args.distance_metric,
|
||||
use_mlx=args.use_mlx,
|
||||
enable_warmup=not args.disable_warmup,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user