fix: no longer do embedding server reuse

This commit is contained in:
Andy Lee
2025-07-20 12:15:17 -07:00
parent 7522de1d41
commit f4998bb316
5 changed files with 232 additions and 311 deletions

View File

@@ -96,7 +96,7 @@ def compute_embeddings_sentence_transformers(
backend_module_name="leann_backend_hnsw.hnsw_embedding_server" backend_module_name="leann_backend_hnsw.hnsw_embedding_server"
) )
server_started = server_manager.start_server( server_started, actual_port = server_manager.start_server(
port=port, port=port,
model_name=model_name, model_name=model_name,
embedding_mode="sentence-transformers", embedding_mode="sentence-transformers",
@@ -104,7 +104,10 @@ def compute_embeddings_sentence_transformers(
) )
if not server_started: if not server_started:
raise RuntimeError(f"Failed to start embedding server on port {port}") raise RuntimeError(f"Failed to start embedding server on port {actual_port}")
# Use the actual port for connection
port = actual_port
# Connect to embedding server # Connect to embedding server
context = zmq.Context() context = zmq.Context()

View File

@@ -9,6 +9,7 @@ import msgpack
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional
import select import select
import psutil
def _check_port(port: int) -> bool: def _check_port(port: int) -> bool:
@@ -17,151 +18,131 @@ def _check_port(port: int) -> bool:
return s.connect_ex(("localhost", port)) == 0 return s.connect_ex(("localhost", port)) == 0
def _check_process_matches_config(
    port: int, expected_model: str, expected_passages_file: Optional[str] = None
) -> bool:
    """
    Check if the process using the port matches our expected model and passages file.

    Args:
        port: TCP port the candidate server should be listening on.
        expected_model: model name the server must have been launched with.
        expected_passages_file: optional passages file the server must serve.

    Returns:
        True if a listening process with a matching command line is found,
        False otherwise.
    """
    try:
        for proc in psutil.process_iter(["pid", "cmdline"]):
            if not _is_process_listening_on_port(proc, port):
                continue

            cmdline = proc.info["cmdline"]
            if not cmdline:
                # Listener with an unreadable/empty command line — keep
                # scanning in case another candidate matches.
                continue

            # First listener with a readable cmdline decides the outcome.
            return _check_cmdline_matches_config(
                cmdline, port, expected_model, expected_passages_file
            )
        print(f"DEBUG: No process found listening on port {port}")
        return False
    except Exception as e:
        # Process-table access can fail (permissions, races); treat as no match.
        print(f"WARNING: Could not check process on port {port}: {e}")
        return False
def _is_process_listening_on_port(proc, port: int) -> bool:
    """Check if a process is listening on the given port."""
    try:
        # A process matches when any of its sockets is in LISTEN state on
        # the requested local port.
        return any(
            conn.laddr.port == port and conn.status == psutil.CONN_LISTEN
            for conn in proc.net_connections()
        )
    except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess):
        # Process vanished or is inaccessible — by definition not our server.
        return False
def _check_cmdline_matches_config(
    cmdline: list, port: int, expected_model: str, expected_passages_file: Optional[str] = None
) -> bool:
    """Check whether a process command line matches our expected server configuration.

    Args:
        cmdline: argv list of the process listening on ``port``.
        port: port the process listens on (used only for log messages).
        expected_model: model name the server must have been launched with.
        expected_passages_file: optional passages file the server must serve.

    Returns:
        True when the command line belongs to one of our embedding servers and
        its ``--model-name`` (and, when given, ``--passages-file``) match.
    """
    cmdline_str = " ".join(cmdline)
    print(f"DEBUG: Found process on port {port}: {cmdline_str}")

    # Recognize any of our known embedding-server entry points.
    is_embedding_server = any(
        server_type in cmdline_str
        for server_type in (
            "embedding_server",
            "leann_backend_diskann.embedding_server",
            "leann_backend_hnsw.hnsw_embedding_server",
        )
    )
    if not is_embedding_server:
        print(f"DEBUG: Process on port {port} is not our embedding server")
        return False

    # Check model name
    model_matches = _check_model_in_cmdline(cmdline, expected_model)
    # Check passages file if provided
    passages_matches = _check_passages_in_cmdline(cmdline, expected_passages_file)

    result = model_matches and passages_matches
    print(f"DEBUG: model_matches: {model_matches}, passages_matches: {passages_matches}, overall: {result}")
    return result
def _check_model_in_cmdline(cmdline: list, expected_model: str) -> bool:
"""Check if the command line contains the expected model."""
if "--model-name" not in cmdline:
return False
model_idx = cmdline.index("--model-name")
if model_idx + 1 >= len(cmdline):
return False
actual_model = cmdline[model_idx + 1]
return actual_model == expected_model
def _check_passages_in_cmdline(cmdline: list, expected_passages_file: str = None) -> bool:
"""Check if the command line contains the expected passages file."""
if not expected_passages_file:
return True # No passages file expected
if "--passages-file" not in cmdline:
return False # Expected but not found
passages_idx = cmdline.index("--passages-file")
if passages_idx + 1 >= len(cmdline):
return False
actual_passages = cmdline[passages_idx + 1]
expected_path = Path(expected_passages_file).resolve()
actual_path = Path(actual_passages).resolve()
return actual_path == expected_path
def _find_compatible_port_or_next_available(
    start_port: int, model_name: str, passages_file: Optional[str] = None, max_attempts: int = 100
) -> tuple[int, bool]:
    """
    Find a port that either has a compatible server or is available.

    Args:
        start_port: first port to try.
        model_name: model a resident server must match to be reused.
        passages_file: optional passages file a resident server must match.
        max_attempts: number of consecutive ports to scan.

    Returns:
        (port, is_compatible) where is_compatible indicates if we found a matching server.

    Raises:
        RuntimeError: when no free or compatible port exists in the scanned range.
    """
    for port in range(start_port, start_port + max_attempts):
        if not _check_port(port):
            # Port is available
            return port, False

        # Port is in use, check if it's compatible
        if _check_process_matches_config(port, model_name, passages_file):
            print(f"✅ Found compatible server on port {port}")
            return port, True
        # Occupied by something incompatible — keep scanning upward.
        print(f"⚠️ Port {port} has incompatible server, trying next port...")

    raise RuntimeError(
        f"Could not find compatible or available port in range {start_port}-{start_port + max_attempts}"
    )
class EmbeddingServerManager: class EmbeddingServerManager:
""" """
A generic manager for handling the lifecycle of a backend-specific embedding server process. A simplified manager for embedding server processes that avoids complex update mechanisms.
""" """
def __init__(self, backend_module_name: str): def __init__(self, backend_module_name: str):
@@ -177,208 +158,132 @@ class EmbeddingServerManager:
self.server_port: Optional[int] = None self.server_port: Optional[int] = None
atexit.register(self.stop_server) atexit.register(self.stop_server)
def start_server(
    self,
    port: int,
    model_name: str,
    embedding_mode: str = "sentence-transformers",
    **kwargs,
) -> tuple[bool, int]:
    """Ensure an embedding server matching this configuration is reachable.

    Args:
        port (int): The preferred ZMQ port for the server; a nearby port is
            chosen when this one hosts an incompatible server.
        model_name (str): The name of the embedding model to use.
        **kwargs: Additional arguments for the server (e.g. ``passages_file``,
            ``enable_warmup``).

    Returns:
        tuple[bool, int]: (success, actual_port_used)
    """
    passages_file = kwargs.get("passages_file")

    # Fast path: a server this manager already owns is alive and compatible.
    if self._has_compatible_running_server(model_name, passages_file):
        return True, self.server_port

    # Scan for a port that is either free or hosts a compatible server.
    try:
        chosen_port, reuse_existing = _find_compatible_port_or_next_available(
            port, model_name, passages_file
        )
    except RuntimeError as e:
        print(f"{e}")
        return False, port

    if reuse_existing:
        print(f"✅ Using existing compatible server on port {chosen_port}")
        self.server_port = chosen_port
        self.server_process = None  # We don't own this process
        return True, chosen_port

    if chosen_port != port:
        print(f"⚠️ Using port {chosen_port} instead of {port}")

    # No reusable server found: launch our own on the chosen port.
    return self._start_new_server(chosen_port, model_name, embedding_mode, **kwargs)
def _has_compatible_running_server(self, model_name: str, passages_file: str) -> bool:
    """Check if we have a compatible running server."""
    owns_live_server = (
        self.server_process
        and self.server_process.poll() is None
        and self.server_port
    )
    if not owns_live_server:
        return False

    if _check_process_matches_config(self.server_port, model_name, passages_file):
        print(f"✅ Existing server process (PID {self.server_process.pid}) is compatible")
        return True

    # Owned server no longer matches the requested config — retire it so a
    # fresh one can be started.
    print("⚠️ Existing server process is incompatible. Stopping it...")
    self.stop_server()
    return False
def _start_new_server(self, port: int, model_name: str, embedding_mode: str, **kwargs) -> tuple[bool, int]:
    """Start a new embedding server on the given port."""
    print(f"INFO: Starting embedding server on port {port}...")
    command = self._build_server_command(port, model_name, embedding_mode, **kwargs)
    try:
        self._launch_server_process(command, port)
        # Block until the server accepts connections (or fails to come up).
        return self._wait_for_server_ready(port)
    except Exception as exc:
        print(f"❌ ERROR: Failed to start embedding server: {exc}")
        return False, port
def _build_server_command(self, port: int, model_name: str, embedding_mode: str, **kwargs) -> list:
"""Build the command to start the embedding server."""
command = [
sys.executable, "-m", self.backend_module_name,
"--zmq-port", str(port),
"--model-name", model_name,
]
if kwargs.get("passages_file"):
command.extend(["--passages-file", str(kwargs["passages_file"])])
if embedding_mode != "sentence-transformers":
command.extend(["--embedding-mode", embedding_mode])
if kwargs.get("enable_warmup") is False:
command.extend(["--disable-warmup"])
return command
def _launch_server_process(self, command: list, port: int) -> None:
    """Spawn the embedding-server subprocess and record its port."""
    # Five levels up from this file — presumably the repository root, so the
    # server module resolves consistently; TODO confirm against repo layout.
    repo_root = Path(__file__).parent.parent.parent.parent.parent
    print(f"INFO: Command: {' '.join(command)}")
    proc = subprocess.Popen(
        command,
        cwd=repo_root,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,  # merge stderr into one log stream
        text=True,
        encoding="utf-8",
        bufsize=1,  # line-buffered so the log monitor sees output promptly
        universal_newlines=True,
    )
    self.server_process = proc
    self.server_port = port
    print(f"INFO: Server process started with PID: {self.server_process.pid}")
def _wait_for_server_ready(self, port: int) -> tuple[bool, int]:
    """Wait for the server to be ready.

    Polls the ZMQ port for up to ``max_wait`` seconds, returning
    (success, port) — the port is echoed back so callers can propagate it.
    """
    # 120 s total budget, checked every 0.5 s.
    max_wait, wait_interval = 120, 0.5
    for _ in range(int(max_wait / wait_interval)):
        if _check_port(port):
            print("✅ Embedding server is ready!")
            # Stream the child's stdout in the background for the session.
            threading.Thread(target=self._log_monitor, daemon=True).start()
            return True, port
        if self.server_process.poll() is not None:
            # Child died before ever listening; dump its output for debugging.
            print("❌ ERROR: Server terminated during startup.")
            self._print_recent_output()
            return False, port
        time.sleep(wait_interval)
    print(f"❌ ERROR: Server failed to start within {max_wait} seconds.")
    # Timed out: reap the half-started process so it doesn't leak.
    self.stop_server()
    return False, port
def _print_recent_output(self): def _print_recent_output(self):
"""Print any recent output from the server process.""" """Print any recent output from the server process."""
if not self.server_process or not self.server_process.stdout: if not self.server_process or not self.server_process.stdout:
return return
try: try:
# Read any available output
if select.select([self.server_process.stdout], [], [], 0)[0]: if select.select([self.server_process.stdout], [], [], 0)[0]:
output = self.server_process.stdout.read() output = self.server_process.stdout.read()
if output: if output:

View File

@@ -80,7 +80,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
embedding_mode = self.meta.get("embedding_mode", "sentence-transformers") embedding_mode = self.meta.get("embedding_mode", "sentence-transformers")
server_started = self.embedding_server_manager.start_server( server_started, actual_port = self.embedding_server_manager.start_server(
port=port, port=port,
model_name=self.embedding_model, model_name=self.embedding_model,
passages_file=passages_source_file, passages_file=passages_source_file,
@@ -89,7 +89,11 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
enable_warmup=kwargs.get("enable_warmup", False), enable_warmup=kwargs.get("enable_warmup", False),
) )
if not server_started: if not server_started:
raise RuntimeError(f"Failed to start embedding server on port {port}") raise RuntimeError(f"Failed to start embedding server on port {actual_port}")
# Update the port information for future use
if hasattr(self, '_actual_server_port'):
self._actual_server_port = actual_port
def compute_query_embedding( def compute_query_embedding(
self, query: str, zmq_port: int = 5557, use_server_if_available: bool = True self, query: str, zmq_port: int = 5557, use_server_if_available: bool = True

View File

@@ -35,6 +35,7 @@ dependencies = [
"llama-index-embeddings-huggingface>=0.5.5", "llama-index-embeddings-huggingface>=0.5.5",
"mlx>=0.26.3", "mlx>=0.26.3",
"mlx-lm>=0.26.0", "mlx-lm>=0.26.0",
"psutil>=5.8.0",
] ]
[project.optional-dependencies] [project.optional-dependencies]

16
uv.lock generated
View File

@@ -1834,10 +1834,14 @@ source = { editable = "packages/leann-core" }
dependencies = [ dependencies = [
{ name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" },
{ name = "numpy", version = "2.3.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, { name = "numpy", version = "2.3.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" },
{ name = "tqdm" },
] ]
[package.metadata] [package.metadata]
requires-dist = [{ name = "numpy", specifier = ">=1.20.0" }] requires-dist = [
{ name = "numpy", specifier = ">=1.20.0" },
{ name = "tqdm", specifier = ">=4.60.0" },
]
[[package]] [[package]]
name = "leann-workspace" name = "leann-workspace"
@@ -1851,7 +1855,6 @@ dependencies = [
{ name = "flask" }, { name = "flask" },
{ name = "flask-compress" }, { name = "flask-compress" },
{ name = "ipykernel" }, { name = "ipykernel" },
{ name = "leann-backend-diskann" },
{ name = "leann-backend-hnsw" }, { name = "leann-backend-hnsw" },
{ name = "leann-core" }, { name = "leann-core" },
{ name = "llama-index" }, { name = "llama-index" },
@@ -1867,6 +1870,7 @@ dependencies = [
{ name = "ollama" }, { name = "ollama" },
{ name = "openai" }, { name = "openai" },
{ name = "protobuf" }, { name = "protobuf" },
{ name = "psutil" },
{ name = "pypdf2" }, { name = "pypdf2" },
{ name = "requests" }, { name = "requests" },
{ name = "sentence-transformers" }, { name = "sentence-transformers" },
@@ -1884,6 +1888,9 @@ dev = [
{ name = "pytest-cov" }, { name = "pytest-cov" },
{ name = "ruff" }, { name = "ruff" },
] ]
diskann = [
{ name = "leann-backend-diskann" },
]
[package.metadata] [package.metadata]
requires-dist = [ requires-dist = [
@@ -1896,7 +1903,7 @@ requires-dist = [
{ name = "flask-compress" }, { name = "flask-compress" },
{ name = "huggingface-hub", marker = "extra == 'dev'", specifier = ">=0.20.0" }, { name = "huggingface-hub", marker = "extra == 'dev'", specifier = ">=0.20.0" },
{ name = "ipykernel", specifier = "==6.29.5" }, { name = "ipykernel", specifier = "==6.29.5" },
{ name = "leann-backend-diskann", editable = "packages/leann-backend-diskann" }, { name = "leann-backend-diskann", marker = "extra == 'diskann'", editable = "packages/leann-backend-diskann" },
{ name = "leann-backend-hnsw", editable = "packages/leann-backend-hnsw" }, { name = "leann-backend-hnsw", editable = "packages/leann-backend-hnsw" },
{ name = "leann-core", editable = "packages/leann-core" }, { name = "leann-core", editable = "packages/leann-core" },
{ name = "llama-index", specifier = ">=0.12.44" }, { name = "llama-index", specifier = ">=0.12.44" },
@@ -1912,6 +1919,7 @@ requires-dist = [
{ name = "ollama" }, { name = "ollama" },
{ name = "openai", specifier = ">=1.0.0" }, { name = "openai", specifier = ">=1.0.0" },
{ name = "protobuf", specifier = "==4.25.3" }, { name = "protobuf", specifier = "==4.25.3" },
{ name = "psutil", specifier = ">=5.8.0" },
{ name = "pypdf2", specifier = ">=3.0.0" }, { name = "pypdf2", specifier = ">=3.0.0" },
{ name = "pytest", marker = "extra == 'dev'", specifier = ">=7.0" }, { name = "pytest", marker = "extra == 'dev'", specifier = ">=7.0" },
{ name = "pytest-cov", marker = "extra == 'dev'", specifier = ">=4.0" }, { name = "pytest-cov", marker = "extra == 'dev'", specifier = ">=4.0" },
@@ -1922,7 +1930,7 @@ requires-dist = [
{ name = "torch" }, { name = "torch" },
{ name = "tqdm" }, { name = "tqdm" },
] ]
provides-extras = ["dev"] provides-extras = ["dev", "diskann"]
[[package]] [[package]]
name = "llama-cloud" name = "llama-cloud"