Based on an excellent analysis from a user, implemented comprehensive fixes:

1. ZMQ Socket Cleanup:
   - Set LINGER=0 on all ZMQ sockets (client and server)
   - Use try-finally blocks to ensure socket.close() and context.term() are always called
   - Prevents blocking on exit when ZMQ contexts have pending operations
2. Global Test Cleanup:
   - Added tests/conftest.py with a session-scoped cleanup fixture
   - Cleans up leftover ZMQ contexts and child processes after all tests
   - Lists remaining threads for debugging
3. CI Improvements:
   - Apply the timeout to ALL Python versions on Linux (not just 3.13)
   - Increased the timeout to 180s for better reliability
   - Added process cleanup (pkill) on timeout
4. Dependencies:
   - Added psutil>=5.9.0 to the test dependencies for process management

Root cause: Python 3.9/3.13 are more sensitive to cleanup timing during interpreter shutdown. ZMQ's default LINGER=-1 was blocking exit, and atexit handlers were unreliable for cleanup. This should resolve the "all tests pass but CI hangs" issue.
303 lines
11 KiB
Python
303 lines
11 KiB
Python
"""
|
|
HNSW-specific embedding server
|
|
"""
|
|
|
|
import argparse
|
|
import json
|
|
import logging
|
|
import os
|
|
import sys
|
|
import threading
|
|
import time
|
|
from pathlib import Path
|
|
from typing import Optional
|
|
|
|
import msgpack
|
|
import numpy as np
|
|
import zmq
|
|
|
|
# Logging configuration.
#
# Verbosity comes from the LEANN_LOG_LEVEL environment variable
# (case-insensitive, default "WARNING"). The level is forced directly onto
# this module's logger instead of relying on logging.basicConfig, so it also
# applies when this file is launched as a subprocess.
logger = logging.getLogger(__name__)
LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
# Names that are not attributes of the logging module fall back to WARNING.
log_level = getattr(logging, LOG_LEVEL, logging.WARNING)
logger.setLevel(log_level)

# Give the logger its own stream handler (stderr by default) only when it has
# none yet, and disable propagation so ancestor loggers' handlers do not
# emit the same records again.
if not logger.handlers:
    handler, formatter = (
        logging.StreamHandler(),
        logging.Formatter("%(asctime)s - %(levelname)s - %(message)s"),
    )
    handler.setFormatter(formatter)
    logger.addHandler(handler)
    logger.propagate = False
|
|
|
|
|
|
def create_hnsw_embedding_server(
    passages_file: Optional[str] = None,
    zmq_port: int = 5555,
    model_name: str = "sentence-transformers/all-mpnet-base-v2",
    distance_metric: str = "mips",
    embedding_mode: str = "sentence-transformers",
):
    """
    Create and start a ZMQ-based embedding server for HNSW backend.

    Simplified version using unified embedding computation module.

    A ZMQ REP socket is bound on ``tcp://*:<zmq_port>`` inside a daemon thread.
    Requests and replies are msgpack-encoded. Three request shapes are served:

    * ``[text, ...]`` (non-empty list of strings): the texts are embedded
      directly and the reply is ``embeddings.tolist()``.
    * ``[[node_id, ...], [q_0, ..., q_d]]``: the passages for the node IDs are
      embedded and compared to the query vector. The reply is
      ``[[distance, ...]]``: the squared L2 distance when
      ``distance_metric == "l2"``, otherwise the negated dot product.
    * ``[[node_id, ...]]``: the passages are embedded and the reply is
      ``[shape, flat_float32_values]``.

    Any request that fails or has an unrecognised shape is answered with
    ``[[], []]`` so the REP socket's recv/send cycle stays intact.

    Args:
        passages_file: Path to a ``.meta.json`` metadata file. Its
            ``"passage_sources"`` entry is passed to ``PassageManager``.
        zmq_port: TCP port to bind.
        model_name: Model name forwarded to ``compute_embeddings``.
        distance_metric: ``"l2"`` selects squared Euclidean distance. Any other
            value (e.g. ``"mips"``) uses the negated inner product.
        embedding_mode: Backend mode forwarded to ``compute_embeddings``.

    Returns:
        None. The function returns early, after logging an error, if the
        embedding module cannot be imported or the port is already in use.
        Otherwise it blocks until a KeyboardInterrupt or until the server
        thread exits (for example because the port could not be bound).

    Raises:
        ValueError: If ``passages_file`` is missing or is not a ``.meta.json``
            file.
    """
    logger.info(f"Starting HNSW server on port {zmq_port} with model {model_name}")
    logger.info(f"Using embedding mode: {embedding_mode}")

    # Add leann-core to path for unified embedding computation
    current_dir = Path(__file__).parent
    leann_core_path = str(current_dir.parent.parent / "leann-core" / "src")
    sys.path.insert(0, leann_core_path)

    try:
        from leann.api import PassageManager
        from leann.embedding_compute import compute_embeddings

        logger.info("Successfully imported unified embedding computation module")
    except ImportError as e:
        logger.error(f"Failed to import embedding computation module: {e}")
        return
    finally:
        # Remove the entry we added by value rather than with pop(0). The
        # import may itself have prepended to sys.path, in which case pop(0)
        # would drop an unrelated entry.
        try:
            sys.path.remove(leann_core_path)
        except ValueError:
            pass

    # Check port availability
    import socket

    def check_port(port):
        """Return True if something already accepts TCP connections on localhost:port."""
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            return s.connect_ex(("localhost", port)) == 0

    if check_port(zmq_port):
        logger.error(f"Port {zmq_port} is already in use")
        return

    # Only support metadata file, fail fast for everything else
    if not passages_file or not passages_file.endswith(".meta.json"):
        raise ValueError("Only metadata files (.meta.json) are supported")

    # Load metadata to get passage sources (JSON is UTF-8; don't depend on the
    # platform's default encoding).
    with open(passages_file, encoding="utf-8") as f:
        meta = json.load(f)

    # Let PassageManager handle path resolution uniformly
    passages = PassageManager(meta["passage_sources"], metadata_file_path=passages_file)
    logger.info(
        f"Loaded PassageManager with {len(passages.global_offset_map)} passages from metadata"
    )

    def _lookup_texts(node_ids, reject_empty):
        """Return the passage texts for ``node_ids`` (looked up as strings), in order.

        A missing ID (KeyError) becomes RuntimeError("FATAL: ... not found").
        When ``reject_empty`` is true, an empty text also raises RuntimeError.
        Any other lookup error is logged and re-raised.
        """
        texts = []
        for nid in node_ids:
            try:
                txt = passages.get_passage(str(nid))["text"]
                if reject_empty and not txt:
                    raise RuntimeError(f"FATAL: Empty text for passage ID {nid}")
                texts.append(txt)
            except KeyError as err:
                logger.error(f"Passage ID {nid} not found")
                raise RuntimeError(f"FATAL: Passage with ID {nid} not found") from err
            except Exception as e:
                logger.error(f"Exception looking up passage ID {nid}: {e}")
                raise
        return texts

    def zmq_server_thread():
        """ZMQ server thread: answer requests until the context is terminated.

        The socket and context are always released on exit (try/finally).
        LINGER=0 keeps close/term from blocking on undelivered messages.
        """
        context = zmq.Context()
        rep_socket = context.socket(zmq.REP)
        rep_socket.setsockopt(zmq.LINGER, 0)  # Don't block on close
        try:
            try:
                rep_socket.bind(f"tcp://*:{zmq_port}")
            except zmq.ZMQError as e:
                # Another process may have grabbed the port after check_port().
                logger.error(f"Failed to bind ZMQ socket on port {zmq_port}: {e}")
                return
            logger.info(f"HNSW ZMQ server listening on port {zmq_port}")

            # 300000 ms (5 min) timeouts. A recv timeout raises zmq.Again,
            # which just restarts the loop.
            rep_socket.setsockopt(zmq.RCVTIMEO, 300000)
            rep_socket.setsockopt(zmq.SNDTIMEO, 300000)

            while True:
                try:
                    message_bytes = rep_socket.recv()
                    logger.debug(f"Received ZMQ request of size {len(message_bytes)} bytes")

                    e2e_start = time.time()
                    request_payload = msgpack.unpackb(message_bytes)

                    # Handle direct text embedding request (non-empty list of strings)
                    if (
                        isinstance(request_payload, list)
                        and len(request_payload) > 0
                        and all(isinstance(item, str) for item in request_payload)
                    ):
                        logger.info(
                            f"Processing direct text embedding request for {len(request_payload)} texts in {embedding_mode} mode"
                        )

                        # Use unified embedding computation (now with model caching)
                        embeddings = compute_embeddings(
                            request_payload, model_name, mode=embedding_mode
                        )

                        response = embeddings.tolist()
                        rep_socket.send(msgpack.packb(response))
                        e2e_end = time.time()
                        logger.info(f"⏱️ Text embedding E2E time: {e2e_end - e2e_start:.6f}s")
                        continue

                    # Handle distance calculation requests: [[node_ids...], [query...]]
                    if (
                        isinstance(request_payload, list)
                        and len(request_payload) == 2
                        and isinstance(request_payload[0], list)
                        and isinstance(request_payload[1], list)
                    ):
                        node_ids = request_payload[0]
                        query_vector = np.array(request_payload[1], dtype=np.float32)

                        logger.debug("Distance calculation request received")
                        logger.debug(f" Node IDs: {node_ids}")
                        logger.debug(f" Query vector dim: {len(query_vector)}")

                        # Get embeddings for node IDs
                        texts = _lookup_texts(node_ids, reject_empty=False)
                        embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
                        logger.info(
                            f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
                        )

                        # Calculate distances (smaller is closer in both cases)
                        if distance_metric == "l2":
                            distances = np.sum(
                                np.square(embeddings - query_vector.reshape(1, -1)), axis=1
                            )
                        else:  # mips or cosine
                            distances = -np.dot(embeddings, query_vector)

                        response_payload = distances.flatten().tolist()
                        response_bytes = msgpack.packb([response_payload], use_single_float=True)
                        logger.debug(f"Sending distance response with {len(distances)} distances")

                        rep_socket.send(response_bytes)
                        e2e_end = time.time()
                        logger.info(f"⏱️ Distance calculation E2E time: {e2e_end - e2e_start:.6f}s")
                        continue

                    # Standard embedding request (passage ID lookup): [[node_ids...]]
                    if (
                        not isinstance(request_payload, list)
                        or len(request_payload) != 1
                        or not isinstance(request_payload[0], list)
                    ):
                        logger.error(
                            f"Invalid MessagePack request format. Expected [[ids...]] or [texts...], got: {type(request_payload)}"
                        )
                        rep_socket.send(msgpack.packb([[], []]))
                        continue

                    node_ids = request_payload[0]
                    logger.debug(f"Request for {len(node_ids)} node embeddings")

                    # Look up texts by node IDs (empty passages are an error here)
                    texts = _lookup_texts(node_ids, reject_empty=True)

                    # Process embeddings
                    embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
                    logger.info(
                        f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
                    )

                    # Never ship NaN/Inf to the client; the generic handler below
                    # answers with an empty result instead.
                    if not np.isfinite(embeddings).all():
                        logger.error(
                            f"NaN or Inf detected in embeddings! Requested IDs: {node_ids[:5]}..."
                        )
                        raise RuntimeError("NaN or Inf detected in computed embeddings")

                    # Serialization and response: [shape, flat float32 values]
                    hidden_contiguous_f32 = np.ascontiguousarray(embeddings, dtype=np.float32)
                    response_payload = [
                        list(hidden_contiguous_f32.shape),
                        hidden_contiguous_f32.flatten().tolist(),
                    ]
                    response_bytes = msgpack.packb(response_payload, use_single_float=True)

                    rep_socket.send(response_bytes)
                    e2e_end = time.time()
                    logger.info(f"⏱️ ZMQ E2E time: {e2e_end - e2e_start:.6f}s")

                except zmq.Again:
                    logger.debug("ZMQ socket timeout, continuing to listen")
                    continue
                except zmq.ContextTerminated:
                    # The context is gone; every further socket call would fail.
                    logger.info("ZMQ context terminated, stopping HNSW server thread")
                    break
                except Exception as e:
                    logger.exception(f"Error in ZMQ server loop: {e}")
                    # Reply with an empty result so the REP socket can accept the
                    # next request. If the reply itself cannot be sent (e.g. a reply
                    # already went out for this request), log instead of letting
                    # the error kill the server thread.
                    try:
                        rep_socket.send(msgpack.packb([[], []]))
                    except zmq.ZMQError as send_err:
                        logger.error(f"Failed to send error response: {send_err}")
        finally:
            rep_socket.close()
            context.term()

    zmq_thread = threading.Thread(target=zmq_server_thread, daemon=True)
    zmq_thread.start()
    logger.info(f"Started HNSW ZMQ server thread on port {zmq_port}")

    # Keep the main thread alive only while the server thread is running.
    # Otherwise a dead server thread (e.g. bind failure) would leave the
    # process hanging forever.
    try:
        while zmq_thread.is_alive():
            time.sleep(1)
    except KeyboardInterrupt:
        logger.info("HNSW Server shutting down...")
        return
    logger.error("HNSW ZMQ server thread exited; shutting down")
|
|
|
|
|
|
if __name__ == "__main__":
    import signal

    def signal_handler(sig, frame):
        """Log the received signal and exit through SystemExit(0).

        Registered for SIGTERM and SIGINT, so either one ends the process
        through a normal interpreter exit.
        """
        logger.info(f"Received signal {sig}, shutting down gracefully...")
        sys.exit(0)

    # Register signal handlers for graceful shutdown. The SIGINT handler
    # replaces Python's default KeyboardInterrupt behaviour.
    signal.signal(signal.SIGTERM, signal_handler)
    signal.signal(signal.SIGINT, signal_handler)

    parser = argparse.ArgumentParser(description="HNSW Embedding service")
    parser.add_argument("--zmq-port", type=int, default=5555, help="ZMQ port to run on")
    parser.add_argument(
        "--passages-file",
        type=str,
        # The server only accepts metadata files (it raises ValueError otherwise).
        help="Passage metadata file (.meta.json) listing the passage sources",
    )
    parser.add_argument(
        "--model-name",
        type=str,
        default="sentence-transformers/all-mpnet-base-v2",
        help="Embedding model name",
    )
    parser.add_argument(
        "--distance-metric", type=str, default="mips", help="Distance metric to use"
    )
    parser.add_argument(
        "--embedding-mode",
        type=str,
        default="sentence-transformers",
        choices=["sentence-transformers", "openai", "mlx"],
        help="Embedding backend mode",
    )

    args = parser.parse_args()

    # Create and start the HNSW embedding server (blocks while it runs)
    create_hnsw_embedding_server(
        passages_file=args.passages_file,
        zmq_port=args.zmq_port,
        model_name=args.model_name,
        distance_metric=args.distance_metric,
        embedding_mode=args.embedding_mode,
    )
|