fix: no longer do embedding server reuse

This commit is contained in:
Andy Lee
2025-07-20 12:15:17 -07:00
parent 7522de1d41
commit f4998bb316
5 changed files with 232 additions and 311 deletions

View File

@@ -96,7 +96,7 @@ def compute_embeddings_sentence_transformers(
backend_module_name="leann_backend_hnsw.hnsw_embedding_server"
)
server_started = server_manager.start_server(
server_started, actual_port = server_manager.start_server(
port=port,
model_name=model_name,
embedding_mode="sentence-transformers",
@@ -104,7 +104,10 @@ def compute_embeddings_sentence_transformers(
)
if not server_started:
raise RuntimeError(f"Failed to start embedding server on port {port}")
raise RuntimeError(f"Failed to start embedding server on port {actual_port}")
# Use the actual port for connection
port = actual_port
# Connect to embedding server
context = zmq.Context()

View File

@@ -9,6 +9,7 @@ import msgpack
from pathlib import Path
from typing import Optional
import select
import psutil
def _check_port(port: int) -> bool:
@@ -17,151 +18,131 @@ def _check_port(port: int) -> bool:
return s.connect_ex(("localhost", port)) == 0
def _check_server_meta_path(port: int, expected_meta_path: str) -> bool:
def _check_process_matches_config(
port: int, expected_model: str, expected_passages_file: str = None
) -> bool:
"""
Check if the existing server on the port is using the correct meta file.
Returns True if the server has the right meta path, False otherwise.
Check if the process using the port matches our expected model and passages file.
Returns True if matches, False otherwise.
"""
try:
context = zmq.Context()
socket = context.socket(zmq.REQ)
socket.setsockopt(zmq.RCVTIMEO, 3000) # 3 second timeout
socket.connect(f"tcp://localhost:{port}")
# Send a special control message to query the server's meta path
control_request = ["__QUERY_META_PATH__"]
request_bytes = msgpack.packb(control_request)
socket.send(request_bytes)
# Wait for response
response_bytes = socket.recv()
response = msgpack.unpackb(response_bytes)
socket.close()
context.term()
# Check if the response contains the meta path and if it matches
if isinstance(response, list) and len(response) > 0:
server_meta_path = response[0]
# Normalize paths for comparison
expected_path = Path(expected_meta_path).resolve()
server_path = Path(server_meta_path).resolve() if server_meta_path else None
return server_path == expected_path
for proc in psutil.process_iter(["pid", "cmdline"]):
if not _is_process_listening_on_port(proc, port):
continue
cmdline = proc.info["cmdline"]
if not cmdline:
continue
return _check_cmdline_matches_config(cmdline, port, expected_model, expected_passages_file)
print(f"DEBUG: No process found listening on port {port}")
return False
except Exception as e:
print(f"WARNING: Could not query server meta path on port {port}: {e}")
print(f"WARNING: Could not check process on port {port}: {e}")
return False
def _update_server_meta_path(port: int, new_meta_path: str) -> bool:
"""
Send a control message to update the server's meta path.
Returns True if successful, False otherwise.
"""
def _is_process_listening_on_port(proc, port: int) -> bool:
"""Check if a process is listening on the given port."""
try:
context = zmq.Context()
socket = context.socket(zmq.REQ)
socket.setsockopt(zmq.RCVTIMEO, 5000) # 5 second timeout
socket.connect(f"tcp://localhost:{port}")
# Send a control message to update the meta path
control_request = ["__UPDATE_META_PATH__", new_meta_path]
request_bytes = msgpack.packb(control_request)
socket.send(request_bytes)
# Wait for response
response_bytes = socket.recv()
response = msgpack.unpackb(response_bytes)
socket.close()
context.term()
# Check if the update was successful
if isinstance(response, list) and len(response) > 0:
return response[0] == "SUCCESS"
connections = proc.net_connections()
for conn in connections:
if conn.laddr.port == port and conn.status == psutil.CONN_LISTEN:
return True
return False
except Exception as e:
print(f"ERROR: Could not update server meta path on port {port}: {e}")
except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess):
return False
def _check_server_model(port: int, expected_model: str) -> bool:
def _check_cmdline_matches_config(
cmdline: list, port: int, expected_model: str, expected_passages_file: str = None
) -> bool:
"""Check if command line matches our expected configuration."""
cmdline_str = " ".join(cmdline)
print(f"DEBUG: Found process on port {port}: {cmdline_str}")
# Check if it's our embedding server
is_embedding_server = any(server_type in cmdline_str for server_type in [
"embedding_server",
"leann_backend_diskann.embedding_server",
"leann_backend_hnsw.hnsw_embedding_server"
])
if not is_embedding_server:
print(f"DEBUG: Process on port {port} is not our embedding server")
return False
# Check model name
model_matches = _check_model_in_cmdline(cmdline, expected_model)
# Check passages file if provided
passages_matches = _check_passages_in_cmdline(cmdline, expected_passages_file)
result = model_matches and passages_matches
print(f"DEBUG: model_matches: {model_matches}, passages_matches: {passages_matches}, overall: {result}")
return result
def _check_model_in_cmdline(cmdline: list, expected_model: str) -> bool:
"""Check if the command line contains the expected model."""
if "--model-name" not in cmdline:
return False
model_idx = cmdline.index("--model-name")
if model_idx + 1 >= len(cmdline):
return False
actual_model = cmdline[model_idx + 1]
return actual_model == expected_model
def _check_passages_in_cmdline(cmdline: list, expected_passages_file: str = None) -> bool:
"""Check if the command line contains the expected passages file."""
if not expected_passages_file:
return True # No passages file expected
if "--passages-file" not in cmdline:
return False # Expected but not found
passages_idx = cmdline.index("--passages-file")
if passages_idx + 1 >= len(cmdline):
return False
actual_passages = cmdline[passages_idx + 1]
expected_path = Path(expected_passages_file).resolve()
actual_path = Path(actual_passages).resolve()
return actual_path == expected_path
def _find_compatible_port_or_next_available(
start_port: int, model_name: str, passages_file: str = None, max_attempts: int = 100
) -> tuple[int, bool]:
"""
Check if the existing server on the port is using the correct embedding model.
Returns True if the server has the right model, False otherwise.
Find a port that either has a compatible server or is available.
Returns (port, is_compatible) where is_compatible indicates if we found a matching server.
"""
try:
context = zmq.Context()
socket = context.socket(zmq.REQ)
socket.setsockopt(zmq.RCVTIMEO, 3000) # 3 second timeout
socket.connect(f"tcp://localhost:{port}")
for port in range(start_port, start_port + max_attempts):
if not _check_port(port):
# Port is available
return port, False
# Send a special control message to query the server's model
control_request = ["__QUERY_MODEL__"]
request_bytes = msgpack.packb(control_request)
socket.send(request_bytes)
# Port is in use, check if it's compatible
if _check_process_matches_config(port, model_name, passages_file):
print(f"✅ Found compatible server on port {port}")
return port, True
else:
print(f"⚠️ Port {port} has incompatible server, trying next port...")
# Wait for response
response_bytes = socket.recv()
response = msgpack.unpackb(response_bytes)
socket.close()
context.term()
# Check if the response contains the model name and if it matches
if isinstance(response, list) and len(response) > 0:
server_model = response[0]
return server_model == expected_model
return False
except Exception as e:
print(f"WARNING: Could not query server model on port {port}: {e}")
return False
def _update_server_model(port: int, new_model: str) -> bool:
"""
Send a control message to update the server's embedding model.
Returns True if successful, False otherwise.
"""
try:
context = zmq.Context()
socket = context.socket(zmq.REQ)
socket.setsockopt(zmq.RCVTIMEO, 30000) # 30 second timeout for model loading
socket.setsockopt(zmq.SNDTIMEO, 5000) # 5 second timeout for sending
socket.connect(f"tcp://localhost:{port}")
# Send a control message to update the model
control_request = ["__UPDATE_MODEL__", new_model]
request_bytes = msgpack.packb(control_request)
socket.send(request_bytes)
# Wait for response
response_bytes = socket.recv()
response = msgpack.unpackb(response_bytes)
socket.close()
context.term()
# Check if the update was successful
if isinstance(response, list) and len(response) > 0:
return response[0] == "SUCCESS"
return False
except Exception as e:
print(f"ERROR: Could not update server model on port {port}: {e}")
return False
raise RuntimeError(
f"Could not find compatible or available port in range {start_port}-{start_port + max_attempts}"
)
class EmbeddingServerManager:
"""
A generic manager for handling the lifecycle of a backend-specific embedding server process.
A simplified manager for embedding server processes that avoids complex update mechanisms.
"""
def __init__(self, backend_module_name: str):
@@ -177,208 +158,132 @@ class EmbeddingServerManager:
self.server_port: Optional[int] = None
atexit.register(self.stop_server)
def start_server(self, port: int, model_name: str, embedding_mode: str = "sentence-transformers", **kwargs) -> bool:
def start_server(
self,
port: int,
model_name: str,
embedding_mode: str = "sentence-transformers",
**kwargs,
) -> tuple[bool, int]:
"""
Starts the embedding server process.
Args:
port (int): The ZMQ port for the server.
port (int): The preferred ZMQ port for the server.
model_name (str): The name of the embedding model to use.
**kwargs: Additional arguments for the server (e.g., passages_file, distance_metric, enable_warmup).
**kwargs: Additional arguments for the server.
Returns:
bool: True if the server is started successfully or already running, False otherwise.
tuple[bool, int]: (success, actual_port_used)
"""
if self.server_process and self.server_process.poll() is None:
# Even if we have a running process, check if model/meta path match
if self.server_port is not None:
port_in_use = _check_port(self.server_port)
if port_in_use:
print(
f"INFO: Checking compatibility of existing server process (PID {self.server_process.pid})"
)
passages_file = kwargs.get("passages_file")
# Check model compatibility
model_matches = _check_server_model(self.server_port, model_name)
if model_matches:
print(
f"✅ Existing server already using correct model: {model_name}"
)
# Still check meta path if provided
passages_file = kwargs.get("passages_file")
if passages_file and str(passages_file).endswith(
".meta.json"
):
meta_matches = _check_server_meta_path(
self.server_port, str(passages_file)
)
if not meta_matches:
print("⚠️ Updating meta path to: {passages_file}")
_update_server_meta_path(
self.server_port, str(passages_file)
)
return True
else:
print(
f"⚠️ Existing server has different model. Attempting to update to: {model_name}"
)
if not _update_server_model(self.server_port, model_name):
print(
"❌ Failed to update existing server model. Restarting server..."
)
self.stop_server()
# Continue to start new server below
else:
print(
f"✅ Successfully updated existing server model to: {model_name}"
)
# Also check meta path if provided
passages_file = kwargs.get("passages_file")
if passages_file and str(passages_file).endswith(
".meta.json"
):
meta_matches = _check_server_meta_path(
self.server_port, str(passages_file)
)
if not meta_matches:
print("⚠️ Updating meta path to: {passages_file}")
_update_server_meta_path(
self.server_port, str(passages_file)
)
return True
else:
# Server process exists but port not responding - restart
print("⚠️ Server process exists but not responding. Restarting...")
self.stop_server()
# Continue to start new server below
else:
# No port stored - restart
print("⚠️ No port information stored. Restarting server...")
self.stop_server()
# Continue to start new server below
if _check_port(port):
# Port is in use, check if it's using the correct meta file and model
passages_file = kwargs.get("passages_file")
print(f"INFO: Port {port} is in use. Checking server compatibility...")
# Check model compatibility first
model_matches = _check_server_model(port, model_name)
if model_matches:
print(
f"✅ Existing server on port {port} is using correct model: {model_name}"
)
else:
print(
f"⚠️ Existing server on port {port} has different model. Attempting to update to: {model_name}"
)
if not _update_server_model(port, model_name):
raise RuntimeError(
f"❌ Failed to update server model to {model_name}. Consider using a different port."
)
print(f"✅ Successfully updated server model to: {model_name}")
# Check meta path compatibility if provided
if passages_file and str(passages_file).endswith(".meta.json"):
meta_matches = _check_server_meta_path(port, str(passages_file))
if not meta_matches:
print(
f"⚠️ Existing server on port {port} has different meta path. Attempting to update..."
)
if not _update_server_meta_path(port, str(passages_file)):
raise RuntimeError(
"❌ Failed to update server meta path. This may cause data synchronization issues."
)
print(
f"✅ Successfully updated server meta path to: {passages_file}"
)
else:
print(
f"✅ Existing server on port {port} is using correct meta path: {passages_file}"
)
print(f"✅ Server on port {port} is compatible and ready to use.")
return True
print(
f"INFO: Starting session-level embedding server for '{self.backend_module_name}'..."
)
# Check if we have a compatible running server
if self._has_compatible_running_server(model_name, passages_file):
return True, self.server_port
# Find available port (compatible or free)
try:
command = [
sys.executable,
"-m",
self.backend_module_name,
"--zmq-port",
str(port),
"--model-name",
model_name,
]
# Add extra arguments for specific backends
if "passages_file" in kwargs and kwargs["passages_file"]:
command.extend(["--passages-file", str(kwargs["passages_file"])])
# if "distance_metric" in kwargs and kwargs["distance_metric"]:
# command.extend(["--distance-metric", kwargs["distance_metric"]])
if embedding_mode != "sentence-transformers":
command.extend(["--embedding-mode", embedding_mode])
if "enable_warmup" in kwargs and not kwargs["enable_warmup"]:
command.extend(["--disable-warmup"])
project_root = Path(__file__).parent.parent.parent.parent.parent
print(f"INFO: Running command from project root: {project_root}")
print(f"INFO: Command: {' '.join(command)}") # Debug: show actual command
self.server_process = subprocess.Popen(
command,
cwd=project_root,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT, # Merge stderr into stdout for easier monitoring
text=True,
encoding="utf-8",
bufsize=1, # Line buffered
universal_newlines=True,
actual_port, is_compatible = _find_compatible_port_or_next_available(
port, model_name, passages_file
)
self.server_port = port
print(f"INFO: Server process started with PID: {self.server_process.pid}")
except RuntimeError as e:
print(f"{e}")
return False, port
max_wait, wait_interval = 120, 0.5
for _ in range(int(max_wait / wait_interval)):
if _check_port(port):
print("✅ Embedding server is up and ready for this session.")
log_thread = threading.Thread(target=self._log_monitor, daemon=True)
log_thread.start()
return True
if self.server_process.poll() is not None:
print(
"❌ ERROR: Server process terminated unexpectedly during startup."
)
self._print_recent_output()
return False
time.sleep(wait_interval)
if is_compatible:
print(f"✅ Using existing compatible server on port {actual_port}")
self.server_port = actual_port
self.server_process = None # We don't own this process
return True, actual_port
print(
f"❌ ERROR: Server process failed to start listening within {max_wait} seconds."
)
self.stop_server()
if actual_port != port:
print(f"⚠️ Using port {actual_port} instead of {port}")
# Start new server
return self._start_new_server(actual_port, model_name, embedding_mode, **kwargs)
def _has_compatible_running_server(self, model_name: str, passages_file: str) -> bool:
"""Check if we have a compatible running server."""
if not (self.server_process and self.server_process.poll() is None and self.server_port):
return False
if _check_process_matches_config(self.server_port, model_name, passages_file):
print(f"✅ Existing server process (PID {self.server_process.pid}) is compatible")
return True
print("⚠️ Existing server process is incompatible. Stopping it...")
self.stop_server()
return False
def _start_new_server(self, port: int, model_name: str, embedding_mode: str, **kwargs) -> tuple[bool, int]:
"""Start a new embedding server on the given port."""
print(f"INFO: Starting embedding server on port {port}...")
command = self._build_server_command(port, model_name, embedding_mode, **kwargs)
try:
self._launch_server_process(command, port)
return self._wait_for_server_ready(port)
except Exception as e:
print(f"❌ ERROR: Failed to start embedding server process: {e}")
return False
print(f"❌ ERROR: Failed to start embedding server: {e}")
return False, port
def _build_server_command(self, port: int, model_name: str, embedding_mode: str, **kwargs) -> list:
"""Build the command to start the embedding server."""
command = [
sys.executable, "-m", self.backend_module_name,
"--zmq-port", str(port),
"--model-name", model_name,
]
if kwargs.get("passages_file"):
command.extend(["--passages-file", str(kwargs["passages_file"])])
if embedding_mode != "sentence-transformers":
command.extend(["--embedding-mode", embedding_mode])
if kwargs.get("enable_warmup") is False:
command.extend(["--disable-warmup"])
return command
def _launch_server_process(self, command: list, port: int) -> None:
    """Spawn the server subprocess and record it on the manager.

    The child runs with its working directory set five levels above this
    file, with stderr merged into a line-buffered UTF-8 text stdout pipe.
    On return ``self.server_process`` holds the ``Popen`` handle and
    ``self.server_port`` holds ``port``.
    """
    # Walk five directories up from this file to reach the project root.
    project_root = Path(__file__)
    for _ in range(5):
        project_root = project_root.parent

    print(f"INFO: Command: {' '.join(command)}")

    popen_options = {
        "cwd": project_root,
        "stdout": subprocess.PIPE,
        "stderr": subprocess.STDOUT,  # merge stderr into stdout
        "text": True,
        "encoding": "utf-8",
        "bufsize": 1,  # line buffered
        "universal_newlines": True,
    }
    self.server_process = subprocess.Popen(command, **popen_options)
    self.server_port = port
    print(f"INFO: Server process started with PID: {self.server_process.pid}")
def _wait_for_server_ready(self, port: int) -> tuple[bool, int]:
    """Poll until the server port is reachable, the process dies, or time runs out.

    The port is probed every 0.5 seconds for up to 120 seconds. On success a
    daemon thread running ``self._log_monitor`` is started. If the process
    exits first, its recent output is printed. On timeout the server is
    stopped.

    Returns:
        ``(True, port)`` once the port is reachable, ``(False, port)`` otherwise.
    """
    max_wait, wait_interval = 120, 0.5
    polls_left = int(max_wait / wait_interval)

    while polls_left > 0:
        polls_left -= 1

        if _check_port(port):
            print("✅ Embedding server is ready!")
            monitor = threading.Thread(target=self._log_monitor, daemon=True)
            monitor.start()
            return True, port

        exited = self.server_process.poll() is not None
        if exited:
            print("❌ ERROR: Server terminated during startup.")
            self._print_recent_output()
            return False, port

        time.sleep(wait_interval)

    print(f"❌ ERROR: Server failed to start within {max_wait} seconds.")
    self.stop_server()
    return False, port
def _print_recent_output(self):
"""Print any recent output from the server process."""
if not self.server_process or not self.server_process.stdout:
return
try:
# Read any available output
if select.select([self.server_process.stdout], [], [], 0)[0]:
output = self.server_process.stdout.read()
if output:

View File

@@ -80,7 +80,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
embedding_mode = self.meta.get("embedding_mode", "sentence-transformers")
server_started = self.embedding_server_manager.start_server(
server_started, actual_port = self.embedding_server_manager.start_server(
port=port,
model_name=self.embedding_model,
passages_file=passages_source_file,
@@ -89,7 +89,11 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
enable_warmup=kwargs.get("enable_warmup", False),
)
if not server_started:
raise RuntimeError(f"Failed to start embedding server on port {port}")
raise RuntimeError(f"Failed to start embedding server on port {actual_port}")
# Update the port information for future use
if hasattr(self, '_actual_server_port'):
self._actual_server_port = actual_port
def compute_query_embedding(
self, query: str, zmq_port: int = 5557, use_server_if_available: bool = True

View File

@@ -35,6 +35,7 @@ dependencies = [
"llama-index-embeddings-huggingface>=0.5.5",
"mlx>=0.26.3",
"mlx-lm>=0.26.0",
"psutil>=5.8.0",
]
[project.optional-dependencies]

16
uv.lock generated
View File

@@ -1834,10 +1834,14 @@ source = { editable = "packages/leann-core" }
dependencies = [
{ name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" },
{ name = "numpy", version = "2.3.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" },
{ name = "tqdm" },
]
[package.metadata]
requires-dist = [{ name = "numpy", specifier = ">=1.20.0" }]
requires-dist = [
{ name = "numpy", specifier = ">=1.20.0" },
{ name = "tqdm", specifier = ">=4.60.0" },
]
[[package]]
name = "leann-workspace"
@@ -1851,7 +1855,6 @@ dependencies = [
{ name = "flask" },
{ name = "flask-compress" },
{ name = "ipykernel" },
{ name = "leann-backend-diskann" },
{ name = "leann-backend-hnsw" },
{ name = "leann-core" },
{ name = "llama-index" },
@@ -1867,6 +1870,7 @@ dependencies = [
{ name = "ollama" },
{ name = "openai" },
{ name = "protobuf" },
{ name = "psutil" },
{ name = "pypdf2" },
{ name = "requests" },
{ name = "sentence-transformers" },
@@ -1884,6 +1888,9 @@ dev = [
{ name = "pytest-cov" },
{ name = "ruff" },
]
diskann = [
{ name = "leann-backend-diskann" },
]
[package.metadata]
requires-dist = [
@@ -1896,7 +1903,7 @@ requires-dist = [
{ name = "flask-compress" },
{ name = "huggingface-hub", marker = "extra == 'dev'", specifier = ">=0.20.0" },
{ name = "ipykernel", specifier = "==6.29.5" },
{ name = "leann-backend-diskann", editable = "packages/leann-backend-diskann" },
{ name = "leann-backend-diskann", marker = "extra == 'diskann'", editable = "packages/leann-backend-diskann" },
{ name = "leann-backend-hnsw", editable = "packages/leann-backend-hnsw" },
{ name = "leann-core", editable = "packages/leann-core" },
{ name = "llama-index", specifier = ">=0.12.44" },
@@ -1912,6 +1919,7 @@ requires-dist = [
{ name = "ollama" },
{ name = "openai", specifier = ">=1.0.0" },
{ name = "protobuf", specifier = "==4.25.3" },
{ name = "psutil", specifier = ">=5.8.0" },
{ name = "pypdf2", specifier = ">=3.0.0" },
{ name = "pytest", marker = "extra == 'dev'", specifier = ">=7.0" },
{ name = "pytest-cov", marker = "extra == 'dev'", specifier = ">=4.0" },
@@ -1922,7 +1930,7 @@ requires-dist = [
{ name = "torch" },
{ name = "tqdm" },
]
provides-extras = ["dev"]
provides-extras = ["dev", "diskann"]
[[package]]
name = "llama-cloud"