Based on excellent analysis from user, implemented comprehensive fixes: 1. ZMQ Socket Cleanup: - Set LINGER=0 on all ZMQ sockets (client and server) - Use try-finally blocks to ensure socket.close() and context.term() - Prevents blocking on exit when ZMQ contexts have pending operations 2. Global Test Cleanup: - Added tests/conftest.py with session-scoped cleanup fixture - Cleans up leftover ZMQ contexts and child processes after all tests - Lists remaining threads for debugging 3. CI Improvements: - Apply timeout to ALL Python versions on Linux (not just 3.13) - Increased timeout to 180s for better reliability - Added process cleanup (pkill) on timeout 4. Dependencies: - Added psutil>=5.9.0 to test dependencies for process management Root cause: Python 3.9/3.13 are more sensitive to cleanup timing during interpreter shutdown. ZMQ's default LINGER=-1 was blocking exit, and atexit handlers were unreliable for cleanup. This should resolve the 'all tests pass but CI hangs' issue.
288 lines
11 KiB
Python
288 lines
11 KiB
Python
"""DiskANN-specific embedding server."""
|
|
|
|
import argparse
|
|
import json
|
|
import logging
|
|
import os
|
|
import sys
|
|
import threading
|
|
import time
|
|
from pathlib import Path
|
|
from typing import Optional
|
|
|
|
import numpy as np
|
|
import zmq
|
|
|
|
# Logging configuration driven by the LEANN_LOG_LEVEL environment variable.
LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
logger = logging.getLogger(__name__)

# Set the level on the logger directly: basicConfig is unreliable when this
# module runs inside a subprocess, so we do not depend on it.
log_level = getattr(logging, LOG_LEVEL, logging.WARNING)
logger.setLevel(log_level)

# Attach a stream handler only when the logger has none yet, and stop
# propagation so records are not duplicated by ancestor loggers.
if not logger.handlers:
    formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
    handler = logging.StreamHandler()
    handler.setFormatter(formatter)
    logger.addHandler(handler)
    logger.propagate = False
|
|
|
|
|
|
def create_diskann_embedding_server(
    passages_file: Optional[str] = None,
    zmq_port: int = 5555,
    model_name: str = "sentence-transformers/all-mpnet-base-v2",
    embedding_mode: str = "sentence-transformers",
    distance_metric: str = "l2",
):
    """
    Create and start a ZMQ-based embedding server for DiskANN backend.

    Uses a REP socket and protobuf communication as required by the DiskANN
    C++ implementation, with a msgpack fallback for BaseSearcher clients.
    Blocks the calling thread until interrupted (KeyboardInterrupt).

    Args:
        passages_file: Path to a ``.meta.json`` metadata file listing passage
            sources. Required; any other file type raises ValueError.
        zmq_port: TCP port the ZMQ REP socket binds to.
        model_name: Embedding model identifier passed to compute_embeddings.
        embedding_mode: Embedding backend mode (e.g. "sentence-transformers").
        distance_metric: Accepted for CLI compatibility; not used directly in
            this function.

    Raises:
        ValueError: If passages_file is missing or is not a .meta.json file.
    """
    logger.info(f"Starting DiskANN server on port {zmq_port} with model {model_name}")
    logger.info(f"Using embedding mode: {embedding_mode}")

    # Add leann-core to path for unified embedding computation
    current_dir = Path(__file__).parent
    leann_core_path = current_dir.parent.parent / "leann-core" / "src"
    sys.path.insert(0, str(leann_core_path))

    try:
        from leann.api import PassageManager
        from leann.embedding_compute import compute_embeddings

        logger.info("Successfully imported unified embedding computation module")
    except ImportError as e:
        logger.error(f"Failed to import embedding computation module: {e}")
        return
    finally:
        # Restore sys.path whether or not the import succeeded.
        sys.path.pop(0)

    # Check port availability before binding so we can fail with a clear log.
    import socket

    def check_port(port):
        # connect_ex() == 0 means something already accepts connections there.
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            return s.connect_ex(("localhost", port)) == 0

    if check_port(zmq_port):
        logger.error(f"Port {zmq_port} is already in use")
        return

    # Only support metadata file, fail fast for everything else
    if not passages_file or not passages_file.endswith(".meta.json"):
        raise ValueError("Only metadata files (.meta.json) are supported")

    # Load metadata to get passage sources
    with open(passages_file) as f:
        meta = json.load(f)

    logger.info(f"Loading PassageManager with metadata_file_path: {passages_file}")
    passages = PassageManager(meta["passage_sources"], metadata_file_path=passages_file)
    logger.info(
        f"Loaded PassageManager with {len(passages.global_offset_map)} passages from metadata"
    )

    # Import protobuf after ensuring the path is correct
    try:
        from . import embedding_pb2
    except ImportError as e:
        logger.error(f"Failed to import protobuf module: {e}")
        return

    def zmq_server_thread():
        """ZMQ server thread using REP socket for universal compatibility"""
        context = zmq.Context()
        socket = context.socket(
            zmq.REP
        )  # REP socket for both BaseSearcher and DiskANN C++ REQ clients
        socket.setsockopt(zmq.LINGER, 0)  # Don't block on close
        # FIX: wrap the serve loop in try/finally so the socket is closed and
        # the context terminated on ANY exit path. Without this, a raised
        # error leaks a live ZMQ context, which can block interpreter
        # shutdown and cause "all tests pass but process hangs" symptoms.
        try:
            socket.bind(f"tcp://*:{zmq_port}")
            logger.info(f"DiskANN ZMQ REP server listening on port {zmq_port}")

            # 5-minute send/recv timeouts; zmq.Again fires on expiry so the
            # loop can wake up instead of blocking forever.
            socket.setsockopt(zmq.RCVTIMEO, 300000)
            socket.setsockopt(zmq.SNDTIMEO, 300000)

            while True:
                try:
                    # REP socket receives single-part messages
                    message = socket.recv()

                    # Check for empty messages - REP socket requires response to every request
                    if len(message) == 0:
                        logger.debug("Received empty message, sending empty response")
                        socket.send(b"")  # REP socket must respond to every request
                        continue

                    logger.debug(f"Received ZMQ request of size {len(message)} bytes")
                    logger.debug(f"Message preview: {message[:50]}")  # Show first 50 bytes

                    e2e_start = time.time()

                    # Try protobuf first (for DiskANN C++ node_ids requests - primary use case)
                    texts = []
                    node_ids = []
                    is_text_request = False

                    try:
                        req_proto = embedding_pb2.NodeEmbeddingRequest()
                        req_proto.ParseFromString(message)
                        node_ids = list(req_proto.node_ids)

                        if not node_ids:
                            raise RuntimeError(
                                f"PROTOBUF: Received empty node_ids! Message size: {len(message)}"
                            )

                        logger.info(
                            f"✅ PROTOBUF: Node ID request for {len(node_ids)} node embeddings: {node_ids[:10]}"
                        )
                    except Exception as protobuf_error:
                        logger.debug(f"Protobuf parsing failed: {protobuf_error}")
                        # Fallback to msgpack (for BaseSearcher direct text requests)
                        try:
                            import msgpack

                            request = msgpack.unpackb(message)
                            # For BaseSearcher compatibility, request is a list of texts directly
                            if isinstance(request, list) and all(
                                isinstance(item, str) for item in request
                            ):
                                texts = request
                                is_text_request = True
                                logger.info(f"✅ MSGPACK: Direct text request for {len(texts)} texts")
                            else:
                                raise ValueError("Not a valid msgpack text request")
                        except Exception as msgpack_error:
                            raise RuntimeError(
                                f"Both protobuf and msgpack parsing failed! Protobuf: {protobuf_error}, Msgpack: {msgpack_error}"
                            )

                    # Look up texts by node IDs (only if not direct text request)
                    if not is_text_request:
                        for nid in node_ids:
                            try:
                                passage_data = passages.get_passage(str(nid))
                                txt = passage_data["text"]
                                if not txt:
                                    raise RuntimeError(f"FATAL: Empty text for passage ID {nid}")
                                texts.append(txt)
                            except KeyError as e:
                                logger.error(f"Passage ID {nid} not found: {e}")
                                raise
                            except Exception as e:
                                logger.error(f"Exception looking up passage ID {nid}: {e}")
                                raise

                    # Debug logging
                    logger.debug(f"Processing {len(texts)} texts")
                    logger.debug(f"Text lengths: {[len(t) for t in texts[:5]]}")  # Show first 5

                    # Process embeddings using unified computation
                    embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
                    logger.info(
                        f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
                    )

                    # Prepare response based on request type
                    if is_text_request:
                        # For BaseSearcher compatibility: return msgpack format
                        import msgpack

                        response_data = msgpack.packb(embeddings.tolist())
                    else:
                        # For DiskANN C++ compatibility: return protobuf format
                        resp_proto = embedding_pb2.NodeEmbeddingResponse()
                        hidden_contiguous = np.ascontiguousarray(embeddings, dtype=np.float32)

                        # Serialize embeddings data
                        resp_proto.embeddings_data = hidden_contiguous.tobytes()
                        resp_proto.dimensions.append(hidden_contiguous.shape[0])
                        resp_proto.dimensions.append(hidden_contiguous.shape[1])

                        response_data = resp_proto.SerializeToString()

                    # Send response back to the client
                    socket.send(response_data)

                    e2e_end = time.time()
                    logger.info(f"⏱️ ZMQ E2E time: {e2e_end - e2e_start:.6f}s")

                except zmq.Again:
                    logger.debug("ZMQ socket timeout, continuing to listen")
                    continue
                except Exception as e:
                    logger.error(f"Error in ZMQ server loop: {e}")
                    import traceback

                    traceback.print_exc()
                    raise
        finally:
            # Always release ZMQ resources; with LINGER=0 close() returns
            # immediately and context.term() cannot hang on pending messages.
            socket.close()
            context.term()

    zmq_thread = threading.Thread(target=zmq_server_thread, daemon=True)
    zmq_thread.start()
    logger.info(f"Started DiskANN ZMQ server thread on port {zmq_port}")

    # Keep the main thread alive
    try:
        while True:
            time.sleep(1)
    except KeyboardInterrupt:
        logger.info("DiskANN Server shutting down...")
        return
|
|
|
|
|
|
if __name__ == "__main__":
    import signal
    import sys

    def _handle_shutdown(sig, frame):
        # Exit cleanly on SIGTERM/SIGINT so resources are released via
        # normal interpreter teardown instead of a hard kill.
        logger.info(f"Received signal {sig}, shutting down gracefully...")
        sys.exit(0)

    # Register signal handlers for graceful shutdown
    signal.signal(signal.SIGTERM, _handle_shutdown)
    signal.signal(signal.SIGINT, _handle_shutdown)

    # Command-line interface for the embedding service.
    arg_parser = argparse.ArgumentParser(description="DiskANN Embedding service")
    arg_parser.add_argument("--zmq-port", type=int, default=5555, help="ZMQ port to run on")
    arg_parser.add_argument(
        "--passages-file",
        type=str,
        help="Metadata JSON file containing passage sources",
    )
    arg_parser.add_argument(
        "--model-name",
        type=str,
        default="sentence-transformers/all-mpnet-base-v2",
        help="Embedding model name",
    )
    arg_parser.add_argument(
        "--embedding-mode",
        type=str,
        default="sentence-transformers",
        choices=["sentence-transformers", "openai", "mlx"],
        help="Embedding backend mode",
    )
    arg_parser.add_argument(
        "--distance-metric",
        type=str,
        default="l2",
        choices=["l2", "mips", "cosine"],
        help="Distance metric for similarity computation",
    )
    cli_args = arg_parser.parse_args()

    # Start the blocking DiskANN embedding server with the parsed options.
    create_diskann_embedding_server(
        passages_file=cli_args.passages_file,
        zmq_port=cli_args.zmq_port,
        model_name=cli_args.model_name,
        embedding_mode=cli_args.embedding_mode,
        distance_metric=cli_args.distance_metric,
    )
|