feat: disable warmup by default
This commit is contained in:
@@ -163,6 +163,7 @@ def create_embedding_server_thread(
|
||||
max_batch_size=128,
|
||||
passages_file: Optional[str] = None,
|
||||
use_mlx: bool = False,
|
||||
enable_warmup: bool = False,
|
||||
):
|
||||
"""
|
||||
在当前线程中创建并运行 embedding server
|
||||
@@ -238,6 +239,62 @@ def create_embedding_server_thread(
|
||||
|
||||
print(f"INFO: Loaded {len(passages)} passages.")
|
||||
|
||||
def client_warmup(zmq_port):
|
||||
"""Perform client-side warmup for DiskANN server"""
|
||||
time.sleep(2)
|
||||
print(f"Performing client-side warmup with model {model_name}...")
|
||||
|
||||
# Get actual passage IDs from the loaded passages
|
||||
sample_ids = []
|
||||
if hasattr(passages, 'keys') and len(passages) > 0:
|
||||
available_ids = list(passages.keys())
|
||||
# Take up to 5 actual IDs, but at least 1
|
||||
sample_ids = available_ids[:min(5, len(available_ids))]
|
||||
print(f"Using actual passage IDs for warmup: {sample_ids}")
|
||||
else:
|
||||
print("No passages available for warmup, skipping warmup...")
|
||||
return
|
||||
|
||||
try:
|
||||
context = zmq.Context()
|
||||
socket = context.socket(zmq.REQ)
|
||||
socket.connect(f"tcp://localhost:{zmq_port}")
|
||||
socket.setsockopt(zmq.RCVTIMEO, 30000)
|
||||
socket.setsockopt(zmq.SNDTIMEO, 30000)
|
||||
|
||||
try:
|
||||
ids_to_send = [int(x) for x in sample_ids]
|
||||
except ValueError:
|
||||
print("Warning: Could not convert sample IDs to integers, skipping warmup")
|
||||
return
|
||||
|
||||
if not ids_to_send:
|
||||
print("Skipping warmup send.")
|
||||
return
|
||||
|
||||
# Use protobuf format for warmup
|
||||
from . import embedding_pb2
|
||||
req_proto = embedding_pb2.NodeEmbeddingRequest()
|
||||
req_proto.node_ids.extend(ids_to_send)
|
||||
request_bytes = req_proto.SerializeToString()
|
||||
|
||||
for i in range(3):
|
||||
print(f"Sending warmup request {i + 1}/3 via ZMQ (Protobuf)...")
|
||||
socket.send(request_bytes)
|
||||
response_bytes = socket.recv()
|
||||
|
||||
resp_proto = embedding_pb2.NodeEmbeddingResponse()
|
||||
resp_proto.ParseFromString(response_bytes)
|
||||
embeddings_count = resp_proto.dimensions[0] if resp_proto.dimensions else 0
|
||||
print(f"Warmup request {i + 1}/3 successful, received {embeddings_count} embeddings")
|
||||
time.sleep(0.1)
|
||||
|
||||
print("Client-side Protobuf ZMQ warmup complete")
|
||||
socket.close()
|
||||
context.term()
|
||||
except Exception as e:
|
||||
print(f"Error during Protobuf ZMQ warmup: {e}")
|
||||
|
||||
class DeviceTimer:
|
||||
"""设备计时器"""
|
||||
def __init__(self, name="", device=device):
|
||||
@@ -343,6 +400,16 @@ def create_embedding_server_thread(
|
||||
|
||||
print(f"INFO: Embedding server ready to serve requests")
|
||||
|
||||
# Start warmup thread if enabled
|
||||
if enable_warmup and len(passages) > 0:
|
||||
import threading
|
||||
print(f"Warmup enabled: starting warmup thread")
|
||||
warmup_thread = threading.Thread(target=client_warmup, args=(zmq_port,))
|
||||
warmup_thread.daemon = True
|
||||
warmup_thread.start()
|
||||
else:
|
||||
print(f"Warmup disabled or no passages available (enable_warmup={enable_warmup}, passages={len(passages)})")
|
||||
|
||||
while True:
|
||||
try:
|
||||
parts = socket.recv_multipart()
|
||||
@@ -587,12 +654,13 @@ def create_embedding_server(
|
||||
model_name="sentence-transformers/all-mpnet-base-v2",
|
||||
passages_file: Optional[str] = None,
|
||||
use_mlx: bool = False,
|
||||
enable_warmup: bool = False,
|
||||
):
|
||||
"""
|
||||
原有的 create_embedding_server 函数保持不变
|
||||
这个是阻塞版本,用于直接运行
|
||||
"""
|
||||
create_embedding_server_thread(zmq_port, model_name, max_batch_size, passages_file, use_mlx)
|
||||
create_embedding_server_thread(zmq_port, model_name, max_batch_size, passages_file, use_mlx, enable_warmup)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@@ -610,6 +678,7 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--model-name", type=str, default="sentence-transformers/all-mpnet-base-v2",
|
||||
help="Embedding model name")
|
||||
parser.add_argument("--use-mlx", action="store_true", default=False, help="Use MLX backend for embeddings")
|
||||
parser.add_argument("--disable-warmup", action="store_true", default=False, help="Disable warmup requests on server start")
|
||||
args = parser.parse_args()
|
||||
|
||||
create_embedding_server(
|
||||
@@ -625,4 +694,5 @@ if __name__ == "__main__":
|
||||
model_name=args.model_name,
|
||||
passages_file=args.passages_file,
|
||||
use_mlx=args.use_mlx,
|
||||
enable_warmup=not args.disable_warmup,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user