- Add a --distance-metric parameter to diskann_embedding_server.py for consistency with the other backends.
- Remove the pytest.skip and pytest.xfail markers for the macOS C++ ABI issues, as those issues have been fixed.
- Fix test assertions to handle SearchResult objects correctly.
- All tests now pass on macOS with the C++ ABI compatibility fixes.
285 lines
10 KiB
Python
"""
|
|
DiskANN-specific embedding server
|
|
"""
|
|
|
|
import argparse
|
|
import json
|
|
import logging
|
|
import os
|
|
import sys
|
|
import threading
|
|
import time
|
|
from pathlib import Path
|
|
|
|
import numpy as np
|
|
import zmq
|
|
|
|
# Configure module logging from the LEANN_LOG_LEVEL environment variable
# (defaults to WARNING when unset or unrecognized).
LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
logger = logging.getLogger(__name__)

# Resolve the level name explicitly rather than relying on basicConfig,
# which may never run when this module is launched as a subprocess.
log_level = getattr(logging, LOG_LEVEL, logging.WARNING)
logger.setLevel(log_level)

# Attach a stream handler only once; keep output out of the root logger.
if not logger.handlers:
    handler = logging.StreamHandler()
    handler.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(message)s"))
    logger.addHandler(handler)
    logger.propagate = False
|
|
|
|
|
|
def create_diskann_embedding_server(
    passages_file: str | None = None,
    zmq_port: int = 5555,
    model_name: str = "sentence-transformers/all-mpnet-base-v2",
    embedding_mode: str = "sentence-transformers",
    distance_metric: str = "l2",
):
    """
    Create and start a ZMQ-based embedding server for the DiskANN backend.

    Binds a REP socket on ``zmq_port`` and serves two request formats:
    protobuf ``NodeEmbeddingRequest`` messages carrying node IDs (the DiskANN
    C++ client), and msgpack lists of raw text (BaseSearcher compatibility).
    Blocks the calling thread until interrupted.

    Args:
        passages_file: Path to a ``.meta.json`` metadata file listing passage
            sources. Required; any other value raises ``ValueError``.
        zmq_port: TCP port to bind the ZMQ REP socket on.
        model_name: Embedding model passed to ``compute_embeddings``.
        embedding_mode: Embedding backend mode passed to ``compute_embeddings``.
        distance_metric: Accepted for CLI consistency with other backends;
            NOTE(review): not referenced anywhere in this function body.

    Raises:
        ValueError: If ``passages_file`` is missing or not a ``.meta.json`` file.
    """
    logger.info(f"Starting DiskANN server on port {zmq_port} with model {model_name}")
    logger.info(f"Using embedding mode: {embedding_mode}")

    # Add leann-core to path for unified embedding computation
    current_dir = Path(__file__).parent
    leann_core_path = current_dir.parent.parent / "leann-core" / "src"
    sys.path.insert(0, str(leann_core_path))

    try:
        from leann.api import PassageManager
        from leann.embedding_compute import compute_embeddings

        logger.info("Successfully imported unified embedding computation module")
    except ImportError as e:
        logger.error(f"Failed to import embedding computation module: {e}")
        return
    finally:
        # Restore sys.path regardless of import outcome; assumes nothing else
        # modified sys.path[0] between insert and here.
        sys.path.pop(0)

    # Check port availability before binding so we can fail with a clear error.
    import socket

    def check_port(port: int) -> bool:
        # connect_ex returns 0 when something already accepts on the port.
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            return s.connect_ex(("localhost", port)) == 0

    if check_port(zmq_port):
        logger.error(f"Port {zmq_port} is already in use")
        return

    # Only support metadata file, fail fast for everything else
    if not passages_file or not passages_file.endswith(".meta.json"):
        raise ValueError("Only metadata files (.meta.json) are supported")

    # Load metadata to get passage sources
    with open(passages_file) as f:
        meta = json.load(f)

    passages = PassageManager(meta["passage_sources"])
    logger.info(
        f"Loaded PassageManager with {len(passages.global_offset_map)} passages from metadata"
    )

    # Import protobuf after ensuring the path is correct
    try:
        from . import embedding_pb2
    except ImportError as e:
        logger.error(f"Failed to import protobuf module: {e}")
        return

    def zmq_server_thread():
        """ZMQ server thread using a REP socket for universal compatibility.

        Loops forever: receive one request, decode it (protobuf first, then
        msgpack fallback), compute embeddings, and reply in the matching
        wire format. Any unrecoverable error re-raises and kills the thread.
        """
        context = zmq.Context()
        # `socket` here shadows the imported socket module inside this
        # function; the module is not needed in this scope.
        socket = context.socket(
            zmq.REP
        )  # REP socket for both BaseSearcher and DiskANN C++ REQ clients
        socket.bind(f"tcp://*:{zmq_port}")
        logger.info(f"DiskANN ZMQ REP server listening on port {zmq_port}")

        # 5-minute send/receive timeouts; a timeout surfaces as zmq.Again below.
        socket.setsockopt(zmq.RCVTIMEO, 300000)
        socket.setsockopt(zmq.SNDTIMEO, 300000)

        while True:
            try:
                # REP socket receives single-part messages
                message = socket.recv()

                # Check for empty messages - REP socket requires response to every request
                if len(message) == 0:
                    logger.debug("Received empty message, sending empty response")
                    socket.send(b"")  # REP socket must respond to every request
                    continue

                logger.debug(f"Received ZMQ request of size {len(message)} bytes")
                logger.debug(f"Message preview: {message[:50]}")  # Show first 50 bytes

                # End-to-end latency measurement for this request.
                e2e_start = time.time()

                # Try protobuf first (for DiskANN C++ node_ids requests - primary use case)
                texts = []
                node_ids = []
                is_text_request = False

                try:
                    req_proto = embedding_pb2.NodeEmbeddingRequest()
                    req_proto.ParseFromString(message)
                    node_ids = list(req_proto.node_ids)

                    # An empty node_ids list means the message was not a real
                    # protobuf request; raising drops us into the msgpack path.
                    if not node_ids:
                        raise RuntimeError(
                            f"PROTOBUF: Received empty node_ids! Message size: {len(message)}"
                        )

                    logger.info(
                        f"✅ PROTOBUF: Node ID request for {len(node_ids)} node embeddings: {node_ids[:10]}"
                    )
                except Exception as protobuf_error:
                    logger.debug(f"Protobuf parsing failed: {protobuf_error}")
                    # Fallback to msgpack (for BaseSearcher direct text requests)
                    try:
                        import msgpack

                        request = msgpack.unpackb(message)
                        # For BaseSearcher compatibility, request is a list of texts directly
                        if isinstance(request, list) and all(
                            isinstance(item, str) for item in request
                        ):
                            texts = request
                            is_text_request = True
                            logger.info(f"✅ MSGPACK: Direct text request for {len(texts)} texts")
                        else:
                            raise ValueError("Not a valid msgpack text request")
                    except Exception as msgpack_error:
                        # Neither format matched: re-raise and let the outer
                        # handler log and terminate the loop.
                        raise RuntimeError(
                            f"Both protobuf and msgpack parsing failed! Protobuf: {protobuf_error}, Msgpack: {msgpack_error}"
                        )

                # Look up texts by node IDs (only if not direct text request)
                if not is_text_request:
                    for nid in node_ids:
                        try:
                            # PassageManager keys are strings; node IDs arrive as ints.
                            passage_data = passages.get_passage(str(nid))
                            txt = passage_data["text"]
                            if not txt:
                                raise RuntimeError(f"FATAL: Empty text for passage ID {nid}")
                            texts.append(txt)
                        except KeyError as e:
                            logger.error(f"Passage ID {nid} not found: {e}")
                            raise e
                        except Exception as e:
                            logger.error(f"Exception looking up passage ID {nid}: {e}")
                            raise

                # Debug logging
                logger.debug(f"Processing {len(texts)} texts")
                logger.debug(f"Text lengths: {[len(t) for t in texts[:5]]}")  # Show first 5

                # Process embeddings using unified computation
                embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
                logger.info(
                    f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
                )

                # Prepare response based on request type
                if is_text_request:
                    # For BaseSearcher compatibility: return msgpack format
                    import msgpack

                    response_data = msgpack.packb(embeddings.tolist())
                else:
                    # For DiskANN C++ compatibility: return protobuf format
                    resp_proto = embedding_pb2.NodeEmbeddingResponse()
                    # C++ side expects a contiguous float32 buffer; assumes
                    # embeddings is 2-D (num_texts, dim) — TODO confirm.
                    hidden_contiguous = np.ascontiguousarray(embeddings, dtype=np.float32)

                    # Serialize embeddings data
                    resp_proto.embeddings_data = hidden_contiguous.tobytes()
                    resp_proto.dimensions.append(hidden_contiguous.shape[0])
                    resp_proto.dimensions.append(hidden_contiguous.shape[1])

                    response_data = resp_proto.SerializeToString()

                # Send response back to the client
                socket.send(response_data)

                e2e_end = time.time()
                logger.info(f"⏱️ ZMQ E2E time: {e2e_end - e2e_start:.6f}s")

            except zmq.Again:
                # Receive/send timeout expired with no traffic; keep waiting.
                logger.debug("ZMQ socket timeout, continuing to listen")
                continue
            except Exception as e:
                logger.error(f"Error in ZMQ server loop: {e}")
                import traceback

                traceback.print_exc()
                # Re-raising terminates this daemon thread; the main thread's
                # sleep loop below keeps the process alive regardless.
                raise

    zmq_thread = threading.Thread(target=zmq_server_thread, daemon=True)
    zmq_thread.start()
    logger.info(f"Started DiskANN ZMQ server thread on port {zmq_port}")

    # Keep the main thread alive
    try:
        while True:
            time.sleep(1)
    except KeyboardInterrupt:
        logger.info("DiskANN Server shutting down...")
        return
|
|
|
|
|
|
if __name__ == "__main__":
|
|
import signal
|
|
import sys
|
|
|
|
def signal_handler(sig, frame):
|
|
logger.info(f"Received signal {sig}, shutting down gracefully...")
|
|
sys.exit(0)
|
|
|
|
# Register signal handlers for graceful shutdown
|
|
signal.signal(signal.SIGTERM, signal_handler)
|
|
signal.signal(signal.SIGINT, signal_handler)
|
|
|
|
parser = argparse.ArgumentParser(description="DiskANN Embedding service")
|
|
parser.add_argument("--zmq-port", type=int, default=5555, help="ZMQ port to run on")
|
|
parser.add_argument(
|
|
"--passages-file",
|
|
type=str,
|
|
help="Metadata JSON file containing passage sources",
|
|
)
|
|
parser.add_argument(
|
|
"--model-name",
|
|
type=str,
|
|
default="sentence-transformers/all-mpnet-base-v2",
|
|
help="Embedding model name",
|
|
)
|
|
parser.add_argument(
|
|
"--embedding-mode",
|
|
type=str,
|
|
default="sentence-transformers",
|
|
choices=["sentence-transformers", "openai", "mlx"],
|
|
help="Embedding backend mode",
|
|
)
|
|
parser.add_argument(
|
|
"--distance-metric",
|
|
type=str,
|
|
default="l2",
|
|
choices=["l2", "mips", "cosine"],
|
|
help="Distance metric for similarity computation",
|
|
)
|
|
|
|
args = parser.parse_args()
|
|
|
|
# Create and start the DiskANN embedding server
|
|
create_diskann_embedding_server(
|
|
passages_file=args.passages_file,
|
|
zmq_port=args.zmq_port,
|
|
model_name=args.model_name,
|
|
embedding_mode=args.embedding_mode,
|
|
distance_metric=args.distance_metric,
|
|
)
|