feat: disable warmup by default

This commit is contained in:
Andy Lee
2025-07-15 22:16:02 -07:00
parent 125c1f6f25
commit 6a1dc895fb
5 changed files with 105 additions and 8 deletions

View File

@@ -163,6 +163,7 @@ def create_embedding_server_thread(
max_batch_size=128,
passages_file: Optional[str] = None,
use_mlx: bool = False,
enable_warmup: bool = False,
):
"""
在当前线程中创建并运行 embedding server
@@ -238,6 +239,62 @@ def create_embedding_server_thread(
print(f"INFO: Loaded {len(passages)} passages.")
def client_warmup(zmq_port):
"""Perform client-side warmup for DiskANN server"""
time.sleep(2)
print(f"Performing client-side warmup with model {model_name}...")
# Get actual passage IDs from the loaded passages
sample_ids = []
if hasattr(passages, 'keys') and len(passages) > 0:
available_ids = list(passages.keys())
# Take up to 5 actual IDs, but at least 1
sample_ids = available_ids[:min(5, len(available_ids))]
print(f"Using actual passage IDs for warmup: {sample_ids}")
else:
print("No passages available for warmup, skipping warmup...")
return
try:
context = zmq.Context()
socket = context.socket(zmq.REQ)
socket.connect(f"tcp://localhost:{zmq_port}")
socket.setsockopt(zmq.RCVTIMEO, 30000)
socket.setsockopt(zmq.SNDTIMEO, 30000)
try:
ids_to_send = [int(x) for x in sample_ids]
except ValueError:
print("Warning: Could not convert sample IDs to integers, skipping warmup")
return
if not ids_to_send:
print("Skipping warmup send.")
return
# Use protobuf format for warmup
from . import embedding_pb2
req_proto = embedding_pb2.NodeEmbeddingRequest()
req_proto.node_ids.extend(ids_to_send)
request_bytes = req_proto.SerializeToString()
for i in range(3):
print(f"Sending warmup request {i + 1}/3 via ZMQ (Protobuf)...")
socket.send(request_bytes)
response_bytes = socket.recv()
resp_proto = embedding_pb2.NodeEmbeddingResponse()
resp_proto.ParseFromString(response_bytes)
embeddings_count = resp_proto.dimensions[0] if resp_proto.dimensions else 0
print(f"Warmup request {i + 1}/3 successful, received {embeddings_count} embeddings")
time.sleep(0.1)
print("Client-side Protobuf ZMQ warmup complete")
socket.close()
context.term()
except Exception as e:
print(f"Error during Protobuf ZMQ warmup: {e}")
class DeviceTimer:
"""设备计时器"""
def __init__(self, name="", device=device):
@@ -343,6 +400,16 @@ def create_embedding_server_thread(
print(f"INFO: Embedding server ready to serve requests")
# Start warmup thread if enabled
if enable_warmup and len(passages) > 0:
import threading
print(f"Warmup enabled: starting warmup thread")
warmup_thread = threading.Thread(target=client_warmup, args=(zmq_port,))
warmup_thread.daemon = True
warmup_thread.start()
else:
print(f"Warmup disabled or no passages available (enable_warmup={enable_warmup}, passages={len(passages)})")
while True:
try:
parts = socket.recv_multipart()
@@ -587,12 +654,13 @@ def create_embedding_server(
model_name="sentence-transformers/all-mpnet-base-v2",
passages_file: Optional[str] = None,
use_mlx: bool = False,
enable_warmup: bool = False,
):
"""
原有的 create_embedding_server 函数保持不变
这个是阻塞版本,用于直接运行
"""
create_embedding_server_thread(zmq_port, model_name, max_batch_size, passages_file, use_mlx)
create_embedding_server_thread(zmq_port, model_name, max_batch_size, passages_file, use_mlx, enable_warmup)
if __name__ == "__main__":
@@ -610,6 +678,7 @@ if __name__ == "__main__":
parser.add_argument("--model-name", type=str, default="sentence-transformers/all-mpnet-base-v2",
help="Embedding model name")
parser.add_argument("--use-mlx", action="store_true", default=False, help="Use MLX backend for embeddings")
parser.add_argument("--disable-warmup", action="store_true", default=False, help="Disable warmup requests on server start")
args = parser.parse_args()
create_embedding_server(
@@ -625,4 +694,5 @@ if __name__ == "__main__":
model_name=args.model_name,
passages_file=args.passages_file,
use_mlx=args.use_mlx,
enable_warmup=not args.disable_warmup,
)