feat: disable warmup by default

This commit is contained in:
Andy Lee
2025-07-15 22:16:02 -07:00
parent 125c1f6f25
commit 6a1dc895fb
5 changed files with 105 additions and 8 deletions

View File

@@ -151,6 +151,7 @@ def create_hnsw_embedding_server(
custom_max_length_param: Optional[int] = None,
distance_metric: str = "mips",
use_mlx: bool = False,
enable_warmup: bool = False,
):
"""
Create and start a ZMQ-based embedding server for HNSW backend.
@@ -167,6 +168,7 @@ def create_hnsw_embedding_server(
model_name: Transformer model name
custom_max_length_param: Custom max sequence length
distance_metric: The distance metric to use
enable_warmup: Whether to perform warmup requests on server start
"""
if not use_mlx:
print(f"Loading tokenizer for {model_name}...")
@@ -465,7 +467,17 @@ def create_hnsw_embedding_server(
"""Perform client-side warmup"""
time.sleep(2)
print(f"Performing client-side warmup with model {model_name}...")
sample_ids = ["0", "1", "2", "3", "4"]
# Get actual passage IDs from the loaded passages
sample_ids = []
if hasattr(passages, 'keys') and len(passages) > 0:
available_ids = list(passages.keys())
# Take up to 5 actual IDs, but at least 1
sample_ids = available_ids[:min(5, len(available_ids))]
print(f"Using actual passage IDs for warmup: {sample_ids}")
else:
print("No passages available for warmup, skipping warmup...")
return
try:
context = zmq.Context()
@@ -477,7 +489,8 @@ def create_hnsw_embedding_server(
try:
ids_to_send = [int(x) for x in sample_ids]
except ValueError:
ids_to_send = []
print("Warning: Could not convert sample IDs to integers, skipping warmup")
return
if not ids_to_send:
print("Skipping warmup send.")
@@ -915,10 +928,13 @@ def create_hnsw_embedding_server(
pass
# Start warmup and server threads
if len(passages) > 0:
if enable_warmup and len(passages) > 0:
print(f"Warmup enabled: starting warmup thread")
warmup_thread = threading.Thread(target=client_warmup, args=(zmq_port,))
warmup_thread.daemon = True
warmup_thread.start()
else:
print(f"Warmup disabled or no passages available (enable_warmup={enable_warmup}, passages={len(passages)})")
zmq_thread = threading.Thread(target=zmq_server_thread, daemon=True)
zmq_thread.start()
@@ -976,6 +992,12 @@ if __name__ == "__main__":
default=False,
help="Use MLX for model inference",
)
parser.add_argument(
"--disable-warmup",
action="store_true",
default=False,
help="Disable warmup requests on server start",
)
args = parser.parse_args()
@@ -992,4 +1014,5 @@ if __name__ == "__main__":
custom_max_length_param=args.custom_max_length,
distance_metric=args.distance_metric,
use_mlx=args.use_mlx,
enable_warmup=not args.disable_warmup,
)