refactor: embedding server manager
This commit is contained in:
@@ -0,0 +1 @@
|
||||
from . import diskann_backend
|
||||
@@ -12,6 +12,7 @@ import socket
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
from leann.embedding_server_manager import EmbeddingServerManager
|
||||
from leann.registry import register_backend
|
||||
from leann.interface import (
|
||||
LeannBackendFactoryInterface,
|
||||
@@ -42,96 +43,6 @@ def _write_vectors_to_bin(data: np.ndarray, file_path: str):
|
||||
f.write(struct.pack('I', dim))
|
||||
f.write(data.tobytes())
|
||||
|
||||
def _check_port(port: int) -> bool:
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
return s.connect_ex(('localhost', port)) == 0
|
||||
|
||||
class EmbeddingServerManager:
|
||||
def __init__(self):
|
||||
self.server_process = None
|
||||
self.server_port = None
|
||||
atexit.register(self.stop_server)
|
||||
|
||||
def start_server(self, port=5555, model_name="sentence-transformers/all-mpnet-base-v2"):
|
||||
if self.server_process and self.server_process.poll() is None:
|
||||
print(f"INFO: Reusing existing server process for this session (PID {self.server_process.pid})")
|
||||
return True
|
||||
|
||||
# 检查端口是否已被其他无关进程占用
|
||||
if _check_port(port):
|
||||
print(f"WARNING: Port {port} is already in use. Assuming an external server is running and connecting to it.")
|
||||
return True
|
||||
|
||||
print(f"INFO: Starting session-level embedding server as a background process...")
|
||||
|
||||
try:
|
||||
command = [
|
||||
sys.executable,
|
||||
"-m", "packages.leann-backend-diskann.leann_backend_diskann.embedding_server",
|
||||
"--zmq-port", str(port),
|
||||
"--model-name", model_name
|
||||
]
|
||||
project_root = Path(__file__).parent.parent.parent.parent
|
||||
print(f"INFO: Running command from project root: {project_root}")
|
||||
self.server_process = subprocess.Popen(
|
||||
command,
|
||||
cwd=project_root,
|
||||
# stdout=subprocess.PIPE,
|
||||
# stderr=subprocess.PIPE,
|
||||
text=True,
|
||||
encoding='utf-8'
|
||||
)
|
||||
self.server_port = port
|
||||
print(f"INFO: Server process started with PID: {self.server_process.pid}")
|
||||
|
||||
max_wait, wait_interval = 30, 0.5
|
||||
for _ in range(int(max_wait / wait_interval)):
|
||||
if _check_port(port):
|
||||
print(f"✅ Embedding server is up and ready for this session.")
|
||||
log_thread = threading.Thread(target=self._log_monitor, daemon=True)
|
||||
log_thread.start()
|
||||
return True
|
||||
if self.server_process.poll() is not None:
|
||||
print("❌ ERROR: Server process terminated unexpectedly during startup.")
|
||||
self._log_monitor()
|
||||
return False
|
||||
time.sleep(wait_interval)
|
||||
|
||||
print(f"❌ ERROR: Server process failed to start listening within {max_wait} seconds.")
|
||||
self.stop_server()
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ ERROR: Failed to start embedding server process: {e}")
|
||||
return False
|
||||
|
||||
def _log_monitor(self):
|
||||
if not self.server_process:
|
||||
return
|
||||
try:
|
||||
if self.server_process.stdout:
|
||||
for line in iter(self.server_process.stdout.readline, ''):
|
||||
print(f"[EmbeddingServer LOG]: {line.strip()}")
|
||||
self.server_process.stdout.close()
|
||||
if self.server_process.stderr:
|
||||
for line in iter(self.server_process.stderr.readline, ''):
|
||||
print(f"[EmbeddingServer ERROR]: {line.strip()}")
|
||||
self.server_process.stderr.close()
|
||||
except Exception as e:
|
||||
print(f"Log monitor error: {e}")
|
||||
|
||||
def stop_server(self):
|
||||
if self.server_process and self.server_process.poll() is None:
|
||||
print(f"INFO: Terminating session server process (PID: {self.server_process.pid})...")
|
||||
self.server_process.terminate()
|
||||
try:
|
||||
self.server_process.wait(timeout=5)
|
||||
print("INFO: Server process terminated.")
|
||||
except subprocess.TimeoutExpired:
|
||||
print("WARNING: Server process did not terminate gracefully, killing it.")
|
||||
self.server_process.kill()
|
||||
self.server_process = None
|
||||
|
||||
@register_backend("diskann")
|
||||
class DiskannBackend(LeannBackendFactoryInterface):
|
||||
@staticmethod
|
||||
@@ -143,16 +54,13 @@ class DiskannBackend(LeannBackendFactoryInterface):
|
||||
path = Path(index_path)
|
||||
meta_path = path.parent / f"{path.name}.meta.json"
|
||||
if not meta_path.exists():
|
||||
raise FileNotFoundError(f"Leann metadata file not found at {meta_path}. Cannot infer vector dimension for searcher.")
|
||||
|
||||
raise FileNotFoundError(f"Leann metadata file not found at {meta_path}.")
|
||||
|
||||
with open(meta_path, 'r') as f:
|
||||
meta = json.load(f)
|
||||
|
||||
dimensions = meta.get("dimensions")
|
||||
if not dimensions:
|
||||
raise ValueError("Dimensions not found in Leann metadata. Please rebuild the index with a newer version of Leann.")
|
||||
|
||||
kwargs['dimensions'] = dimensions
|
||||
# Pass essential metadata to the searcher
|
||||
kwargs['meta'] = meta
|
||||
return DiskannSearcher(index_path, **kwargs)
|
||||
|
||||
class DiskannBuilder(LeannBackendBuilderInterface):
|
||||
@@ -215,19 +123,29 @@ class DiskannBuilder(LeannBackendBuilderInterface):
|
||||
|
||||
class DiskannSearcher(LeannBackendSearcherInterface):
|
||||
def __init__(self, index_path: str, **kwargs):
|
||||
self.meta = kwargs.get("meta", {})
|
||||
if not self.meta:
|
||||
raise ValueError("DiskannSearcher requires metadata from .meta.json.")
|
||||
|
||||
dimensions = self.meta.get("dimensions")
|
||||
if not dimensions:
|
||||
raise ValueError("Dimensions not found in Leann metadata.")
|
||||
|
||||
self.distance_metric = self.meta.get("distance_metric", "mips").lower()
|
||||
metric_enum = METRIC_MAP.get(self.distance_metric)
|
||||
if metric_enum is None:
|
||||
raise ValueError(f"Unsupported distance_metric '{self.distance_metric}'.")
|
||||
|
||||
self.embedding_model = self.meta.get("embedding_model")
|
||||
if not self.embedding_model:
|
||||
print("WARNING: embedding_model not found in meta.json. Recompute will fail if attempted.")
|
||||
|
||||
path = Path(index_path)
|
||||
index_dir = path.parent
|
||||
index_prefix = path.stem
|
||||
metric_str = kwargs.get("distance_metric", "mips").lower()
|
||||
metric_enum = METRIC_MAP.get(metric_str)
|
||||
if metric_enum is None:
|
||||
raise ValueError(f"Unsupported distance_metric '{metric_str}'.")
|
||||
|
||||
num_threads = kwargs.get("num_threads", 8)
|
||||
num_nodes_to_cache = kwargs.get("num_nodes_to_cache", 0)
|
||||
dimensions = kwargs.get("dimensions")
|
||||
if not dimensions:
|
||||
raise ValueError("Vector dimension not provided to DiskannSearcher.")
|
||||
|
||||
try:
|
||||
full_index_prefix = str(index_dir / index_prefix)
|
||||
@@ -235,7 +153,9 @@ class DiskannSearcher(LeannBackendSearcherInterface):
|
||||
metric_enum, full_index_prefix, num_threads, num_nodes_to_cache, 1, "", ""
|
||||
)
|
||||
self.num_threads = num_threads
|
||||
self.embedding_server_manager = EmbeddingServerManager()
|
||||
self.embedding_server_manager = EmbeddingServerManager(
|
||||
backend_module_name="leann_backend_diskann.embedding_server"
|
||||
)
|
||||
print("✅ DiskANN index loaded successfully.")
|
||||
except Exception as e:
|
||||
print(f"💥 ERROR: Failed to load DiskANN index. Exception: {e}")
|
||||
@@ -255,12 +175,20 @@ class DiskannSearcher(LeannBackendSearcherInterface):
|
||||
|
||||
if recompute_beighbor_embeddings:
|
||||
print(f"INFO: DiskANN ZMQ mode enabled - ensuring embedding server is running")
|
||||
zmq_port = kwargs.get("zmq_port", 6666)
|
||||
embedding_model = kwargs.get("embedding_model", "sentence-transformers/all-mpnet-base-v2")
|
||||
if not self.embedding_model:
|
||||
raise ValueError("Cannot use recompute_beighbor_embeddings without 'embedding_model' in meta.json.")
|
||||
|
||||
if not self.embedding_server_manager.start_server(zmq_port, embedding_model):
|
||||
zmq_port = kwargs.get("zmq_port", 6666)
|
||||
|
||||
server_started = self.embedding_server_manager.start_server(
|
||||
port=zmq_port,
|
||||
model_name=self.embedding_model,
|
||||
distance_metric=self.distance_metric
|
||||
)
|
||||
|
||||
if not server_started:
|
||||
print(f"WARNING: Failed to start embedding server, falling back to PQ computation")
|
||||
kwargs['recompute_beighbor_embeddings'] = False
|
||||
recompute_beighbor_embeddings = False
|
||||
|
||||
if query.dtype != np.float32:
|
||||
query = query.astype(np.float32)
|
||||
@@ -292,4 +220,4 @@ class DiskannSearcher(LeannBackendSearcherInterface):
|
||||
|
||||
def __del__(self):
|
||||
if hasattr(self, 'embedding_server_manager'):
|
||||
self.embedding_server_manager.stop_server()
|
||||
self.embedding_server_manager.stop_server()
|
||||
|
||||
@@ -12,6 +12,7 @@ import socket
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
from leann.embedding_server_manager import EmbeddingServerManager
|
||||
from .convert_to_csr import convert_hnsw_graph_to_csr
|
||||
|
||||
from leann.registry import register_backend
|
||||
@@ -29,118 +30,6 @@ def get_metric_map():
|
||||
"cosine": faiss.METRIC_INNER_PRODUCT,
|
||||
}
|
||||
|
||||
def _check_port(port: int) -> bool:
|
||||
"""Check if a port is in use"""
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
return s.connect_ex(('localhost', port)) == 0
|
||||
|
||||
class HNSWEmbeddingServerManager:
|
||||
"""
|
||||
HNSW-specific embedding server manager that handles the lifecycle of the embedding server process.
|
||||
Mirrors the DiskANN EmbeddingServerManager architecture.
|
||||
"""
|
||||
def __init__(self):
|
||||
self.server_process = None
|
||||
self.server_port = None
|
||||
atexit.register(self.stop_server)
|
||||
|
||||
def start_server(self, port=5556, model_name="sentence-transformers/all-mpnet-base-v2", passages_file=None, distance_metric="mips"):
|
||||
"""
|
||||
Start the HNSW embedding server process.
|
||||
|
||||
Args:
|
||||
port: ZMQ port for the server
|
||||
model_name: Name of the embedding model to use
|
||||
passages_file: Optional path to passages JSON file
|
||||
distance_metric: The distance metric to use
|
||||
"""
|
||||
if self.server_process and self.server_process.poll() is None:
|
||||
print(f"INFO: Reusing existing HNSW server process for this session (PID {self.server_process.pid})")
|
||||
return True
|
||||
|
||||
# Check if port is already in use
|
||||
if _check_port(port):
|
||||
print(f"WARNING: Port {port} is already in use. Assuming an external HNSW server is running and connecting to it.")
|
||||
return True
|
||||
|
||||
print(f"INFO: Starting session-level HNSW embedding server as a background process...")
|
||||
|
||||
try:
|
||||
command = [
|
||||
sys.executable,
|
||||
"-m", "leann_backend_hnsw.hnsw_embedding_server",
|
||||
"--zmq-port", str(port),
|
||||
"--model-name", model_name,
|
||||
"--distance-metric", distance_metric
|
||||
]
|
||||
|
||||
if passages_file:
|
||||
command.extend(["--passages-file", str(passages_file)])
|
||||
|
||||
project_root = Path(__file__).parent.parent.parent.parent
|
||||
print(f"INFO: Running HNSW command from project root: {project_root}")
|
||||
|
||||
self.server_process = subprocess.Popen(
|
||||
command,
|
||||
cwd=project_root,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
text=True,
|
||||
encoding='utf-8'
|
||||
)
|
||||
self.server_port = port
|
||||
print(f"INFO: HNSW server process started with PID: {self.server_process.pid}")
|
||||
|
||||
max_wait, wait_interval = 30, 0.5
|
||||
for _ in range(int(max_wait / wait_interval)):
|
||||
if _check_port(port):
|
||||
print(f"✅ HNSW embedding server is up and ready for this session.")
|
||||
log_thread = threading.Thread(target=self._log_monitor, daemon=True)
|
||||
log_thread.start()
|
||||
return True
|
||||
if self.server_process.poll() is not None:
|
||||
print("❌ ERROR: HNSW server process terminated unexpectedly during startup.")
|
||||
self._log_monitor()
|
||||
return False
|
||||
time.sleep(wait_interval)
|
||||
|
||||
print(f"❌ ERROR: HNSW server process failed to start listening within {max_wait} seconds.")
|
||||
self.stop_server()
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ ERROR: Failed to start HNSW embedding server process: {e}")
|
||||
return False
|
||||
|
||||
def _log_monitor(self):
|
||||
"""Monitor server logs"""
|
||||
if not self.server_process:
|
||||
return
|
||||
try:
|
||||
if self.server_process.stdout:
|
||||
for line in iter(self.server_process.stdout.readline, ''):
|
||||
print(f"[HNSWEmbeddingServer LOG]: {line.strip()}")
|
||||
self.server_process.stdout.close()
|
||||
if self.server_process.stderr:
|
||||
for line in iter(self.server_process.stderr.readline, ''):
|
||||
print(f"[HNSWEmbeddingServer ERROR]: {line.strip()}")
|
||||
self.server_process.stderr.close()
|
||||
except Exception as e:
|
||||
print(f"HNSW Log monitor error: {e}")
|
||||
|
||||
def stop_server(self):
|
||||
"""Stop the HNSW embedding server process"""
|
||||
if self.server_process and self.server_process.poll() is None:
|
||||
print(f"INFO: Terminating HNSW session server process (PID: {self.server_process.pid})...")
|
||||
self.server_process.terminate()
|
||||
try:
|
||||
self.server_process.wait(timeout=5)
|
||||
print("INFO: HNSW server process terminated.")
|
||||
except subprocess.TimeoutExpired:
|
||||
print("WARNING: HNSW server process did not terminate gracefully, killing it.")
|
||||
self.server_process.kill()
|
||||
self.server_process = None
|
||||
|
||||
@register_backend("hnsw")
|
||||
class HNSWBackend(LeannBackendFactoryInterface):
|
||||
@staticmethod
|
||||
@@ -152,16 +41,12 @@ class HNSWBackend(LeannBackendFactoryInterface):
|
||||
path = Path(index_path)
|
||||
meta_path = path.parent / f"{path.name}.meta.json"
|
||||
if not meta_path.exists():
|
||||
raise FileNotFoundError(f"Leann metadata file not found at {meta_path}. Cannot infer vector dimension for searcher.")
|
||||
raise FileNotFoundError(f"Leann metadata file not found at {meta_path}.")
|
||||
|
||||
with open(meta_path, 'r') as f:
|
||||
meta = json.load(f)
|
||||
|
||||
dimensions = meta.get("dimensions")
|
||||
if not dimensions:
|
||||
raise ValueError("Dimensions not found in Leann metadata. Please rebuild the index with a newer version of Leann.")
|
||||
|
||||
kwargs['dimensions'] = dimensions
|
||||
kwargs['meta'] = meta
|
||||
return HNSWSearcher(index_path, **kwargs)
|
||||
|
||||
class HNSWBuilder(LeannBackendBuilderInterface):
|
||||
@@ -376,47 +261,49 @@ class HNSWSearcher(LeannBackendSearcherInterface):
|
||||
|
||||
def __init__(self, index_path: str, **kwargs):
|
||||
from . import faiss
|
||||
path = Path(index_path)
|
||||
index_dir = path.parent
|
||||
index_prefix = path.stem
|
||||
|
||||
# Store configuration and paths for later use
|
||||
self.config = kwargs.copy()
|
||||
self.config["index_path"] = index_path
|
||||
self.index_dir = index_dir
|
||||
self.index_prefix = index_prefix
|
||||
|
||||
metric_str = self.config.get("distance_metric", "mips").lower()
|
||||
metric_enum = get_metric_map().get(metric_str)
|
||||
self.meta = kwargs.get("meta", {})
|
||||
if not self.meta:
|
||||
raise ValueError("HNSWSearcher requires metadata from .meta.json.")
|
||||
|
||||
self.dimensions = self.meta.get("dimensions")
|
||||
if not self.dimensions:
|
||||
raise ValueError("Dimensions not found in Leann metadata.")
|
||||
|
||||
self.distance_metric = self.meta.get("distance_metric", "mips").lower()
|
||||
metric_enum = get_metric_map().get(self.distance_metric)
|
||||
if metric_enum is None:
|
||||
raise ValueError(f"Unsupported distance_metric '{metric_str}'.")
|
||||
raise ValueError(f"Unsupported distance_metric '{self.distance_metric}'.")
|
||||
|
||||
self.embedding_model = self.meta.get("embedding_model")
|
||||
if not self.embedding_model:
|
||||
print("WARNING: embedding_model not found in meta.json. Recompute will fail if attempted.")
|
||||
|
||||
path = Path(index_path)
|
||||
self.index_dir = path.parent
|
||||
self.index_prefix = path.stem
|
||||
|
||||
dimensions = self.config.get("dimensions")
|
||||
if not dimensions:
|
||||
raise ValueError("Vector dimension not provided to HNSWSearcher.")
|
||||
|
||||
index_file = index_dir / f"{index_prefix}.index"
|
||||
index_file = self.index_dir / f"{self.index_prefix}.index"
|
||||
if not index_file.exists():
|
||||
raise FileNotFoundError(f"HNSW index file not found at {index_file}")
|
||||
|
||||
self.is_compact, self.is_pruned = self._get_index_storage_status(index_file)
|
||||
|
||||
# Validate configuration constraints
|
||||
if not self.is_compact and self.config.get("is_skip_neighbors", False):
|
||||
if not self.is_compact and kwargs.get("is_skip_neighbors", False):
|
||||
raise ValueError("is_skip_neighbors can only be used with is_compact=True")
|
||||
|
||||
if self.config.get("is_recompute", False) and self.config.get("external_storage_path"):
|
||||
if kwargs.get("is_recompute", False) and kwargs.get("external_storage_path"):
|
||||
raise ValueError("Cannot use both is_recompute and external_storage_path simultaneously")
|
||||
|
||||
hnsw_config = faiss.HNSWIndexConfig()
|
||||
hnsw_config.is_compact = self.is_compact
|
||||
|
||||
# Apply additional configuration options with strict validation
|
||||
hnsw_config.is_skip_neighbors = self.config.get("is_skip_neighbors", False)
|
||||
hnsw_config.is_recompute = self.is_pruned or self.config.get("is_recompute", False)
|
||||
hnsw_config.disk_cache_ratio = self.config.get("disk_cache_ratio", 0.0)
|
||||
hnsw_config.external_storage_path = self.config.get("external_storage_path")
|
||||
hnsw_config.zmq_port = self.config.get("zmq_port", 5557)
|
||||
hnsw_config.is_skip_neighbors = kwargs.get("is_skip_neighbors", False)
|
||||
hnsw_config.is_recompute = self.is_pruned or kwargs.get("is_recompute", False)
|
||||
hnsw_config.disk_cache_ratio = kwargs.get("disk_cache_ratio", 0.0)
|
||||
hnsw_config.external_storage_path = kwargs.get("external_storage_path")
|
||||
hnsw_config.zmq_port = kwargs.get("zmq_port", 5557)
|
||||
|
||||
if self.is_pruned and not hnsw_config.is_recompute:
|
||||
raise RuntimeError("Index is pruned (embeddings removed) but recompute is disabled. This is impossible - recompute must be enabled for pruned indices.")
|
||||
@@ -431,82 +318,55 @@ class HNSWSearcher(LeannBackendSearcherInterface):
|
||||
else:
|
||||
print("✅ Standard HNSW index loaded successfully.")
|
||||
|
||||
self.metric_str = metric_str
|
||||
self.embedding_server_manager = HNSWEmbeddingServerManager()
|
||||
|
||||
def _get_index_file(self, index_dir: Path, index_prefix: str) -> Path:
|
||||
"""Get the appropriate index file path based on format"""
|
||||
# We always use the same filename now, format is detected internally
|
||||
return index_dir / f"{index_prefix}.index"
|
||||
self.embedding_server_manager = EmbeddingServerManager(
|
||||
backend_module_name="leann_backend_hnsw.hnsw_embedding_server"
|
||||
)
|
||||
|
||||
def search(self, query: np.ndarray, top_k: int, **kwargs) -> Dict[str, Any]:
|
||||
"""Search using HNSW index with optional recompute functionality"""
|
||||
from . import faiss
|
||||
# Merge config with search-time kwargs
|
||||
search_config = self.config.copy()
|
||||
search_config.update(kwargs)
|
||||
|
||||
ef = search_config.get("ef", 200) # Size of the dynamic candidate list for search
|
||||
ef = kwargs.get("ef", 200)
|
||||
|
||||
# Recompute parameters
|
||||
zmq_port = search_config.get("zmq_port", 5557)
|
||||
embedding_model = search_config.get("embedding_model", "sentence-transformers/all-mpnet-base-v2")
|
||||
passages_file = search_config.get("passages_file", None)
|
||||
|
||||
# For recompute mode, try to find the passages file automatically
|
||||
if self.is_pruned and not passages_file:
|
||||
potential_passages_file = self.index_dir / f"{self.index_prefix}.passages.json"
|
||||
print(f"DEBUG: Checking for passages file at: {potential_passages_file}")
|
||||
if potential_passages_file.exists():
|
||||
passages_file = str(potential_passages_file)
|
||||
print(f"INFO: Found passages file for recompute mode: {passages_file}")
|
||||
else:
|
||||
print(f"WARNING: No passages file found for recompute mode at {potential_passages_file}")
|
||||
|
||||
# If index is pruned (embeddings removed), we MUST start embedding server for recompute
|
||||
if self.is_pruned:
|
||||
print(f"INFO: Index is pruned - starting embedding server for recompute")
|
||||
|
||||
# CRITICAL: Check passages file exists - fail fast if not
|
||||
print(f"INFO: Index is pruned - ensuring embedding server is running for recompute.")
|
||||
if not self.embedding_model:
|
||||
raise ValueError("Cannot use recompute mode without 'embedding_model' in meta.json.")
|
||||
|
||||
passages_file = kwargs.get("passages_file")
|
||||
if not passages_file:
|
||||
raise RuntimeError(f"FATAL: Index is pruned but no passages file found. Cannot proceed with recompute mode.")
|
||||
|
||||
# Check if server is already running first
|
||||
if _check_port(zmq_port):
|
||||
print(f"INFO: Embedding server already running on port {zmq_port}")
|
||||
else:
|
||||
if not self.embedding_server_manager.start_server(zmq_port, embedding_model, passages_file, self.metric_str):
|
||||
raise RuntimeError(f"Failed to start HNSW embedding server on port {zmq_port}")
|
||||
|
||||
# Give server extra time to fully initialize
|
||||
print(f"INFO: Waiting for embedding server to fully initialize...")
|
||||
time.sleep(3)
|
||||
|
||||
# Final verification
|
||||
if not _check_port(zmq_port):
|
||||
raise RuntimeError(f"Embedding server failed to start listening on port {zmq_port}")
|
||||
else:
|
||||
print(f"INFO: Index has embeddings stored - no recompute needed")
|
||||
potential_passages_file = self.index_dir / f"{self.index_prefix}.passages.json"
|
||||
if potential_passages_file.exists():
|
||||
passages_file = str(potential_passages_file)
|
||||
print(f"INFO: Automatically found passages file: {passages_file}")
|
||||
else:
|
||||
raise RuntimeError(f"FATAL: Index is pruned but no passages file found.")
|
||||
|
||||
zmq_port = kwargs.get("zmq_port", 5557)
|
||||
server_started = self.embedding_server_manager.start_server(
|
||||
port=zmq_port,
|
||||
model_name=self.embedding_model,
|
||||
passages_file=passages_file,
|
||||
distance_metric=self.distance_metric
|
||||
)
|
||||
if not server_started:
|
||||
raise RuntimeError(f"Failed to start HNSW embedding server on port {zmq_port}")
|
||||
|
||||
if query.dtype != np.float32:
|
||||
query = query.astype(np.float32)
|
||||
if query.ndim == 1:
|
||||
query = np.expand_dims(query, axis=0)
|
||||
|
||||
# Normalize query if using cosine similarity
|
||||
if self.metric_str == "cosine":
|
||||
if self.distance_metric == "cosine":
|
||||
faiss.normalize_L2(query)
|
||||
|
||||
try:
|
||||
# Set search parameter
|
||||
self._index.hnsw.efSearch = ef
|
||||
|
||||
# Prepare output arrays for the older FAISS SWIG API
|
||||
batch_size = query.shape[0]
|
||||
distances = np.empty((batch_size, top_k), dtype=np.float32)
|
||||
labels = np.empty((batch_size, top_k), dtype=np.int64)
|
||||
|
||||
# Use standard FAISS search - recompute is handled internally by FAISS
|
||||
self._index.search(query.shape[0], faiss.swig_ptr(query), top_k, faiss.swig_ptr(distances), faiss.swig_ptr(labels))
|
||||
|
||||
return {"labels": labels, "distances": distances}
|
||||
@@ -517,4 +377,4 @@ class HNSWSearcher(LeannBackendSearcherInterface):
|
||||
|
||||
def __del__(self):
|
||||
if hasattr(self, 'embedding_server_manager'):
|
||||
self.embedding_server_manager.stop_server()
|
||||
self.embedding_server_manager.stop_server()
|
||||
|
||||
@@ -1,17 +0,0 @@
|
||||
# This file makes the 'leann' directory a Python package.
|
||||
|
||||
from .api import LeannBuilder, LeannSearcher, LeannChat, SearchResult
|
||||
|
||||
# Import backends to ensure they are registered
|
||||
try:
|
||||
import leann_backend_hnsw
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
try:
|
||||
import leann_backend_diskann
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
__all__ = ['LeannBuilder', 'LeannSearcher', 'LeannChat', 'SearchResult']
|
||||
|
||||
132
packages/leann-core/src/leann/embedding_server_manager.py
Normal file
132
packages/leann-core/src/leann/embedding_server_manager.py
Normal file
@@ -0,0 +1,132 @@
|
||||
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
import atexit
|
||||
import socket
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
def _check_port(port: int) -> bool:
|
||||
"""Check if a port is in use"""
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
return s.connect_ex(('localhost', port)) == 0
|
||||
|
||||
class EmbeddingServerManager:
|
||||
"""
|
||||
A generic manager for handling the lifecycle of a backend-specific embedding server process.
|
||||
"""
|
||||
def __init__(self, backend_module_name: str):
|
||||
"""
|
||||
Initializes the manager for a specific backend.
|
||||
|
||||
Args:
|
||||
backend_module_name (str): The full module name of the backend's server script.
|
||||
e.g., "leann_backend_diskann.embedding_server"
|
||||
"""
|
||||
self.backend_module_name = backend_module_name
|
||||
self.server_process: Optional[subprocess.Popen] = None
|
||||
self.server_port: Optional[int] = None
|
||||
atexit.register(self.stop_server)
|
||||
|
||||
def start_server(self, port: int, model_name: str, **kwargs) -> bool:
|
||||
"""
|
||||
Starts the embedding server process.
|
||||
|
||||
Args:
|
||||
port (int): The ZMQ port for the server.
|
||||
model_name (str): The name of the embedding model to use.
|
||||
**kwargs: Additional arguments for the server (e.g., passages_file, distance_metric).
|
||||
|
||||
Returns:
|
||||
bool: True if the server is started successfully or already running, False otherwise.
|
||||
"""
|
||||
if self.server_process and self.server_process.poll() is None:
|
||||
print(f"INFO: Reusing existing server process for this session (PID {self.server_process.pid})")
|
||||
return True
|
||||
|
||||
if _check_port(port):
|
||||
print(f"WARNING: Port {port} is already in use. Assuming an external server is running.")
|
||||
return True
|
||||
|
||||
print(f"INFO: Starting session-level embedding server for '{self.backend_module_name}'...")
|
||||
|
||||
try:
|
||||
command = [
|
||||
sys.executable,
|
||||
"-m", self.backend_module_name,
|
||||
"--zmq-port", str(port),
|
||||
"--model-name", model_name
|
||||
]
|
||||
|
||||
# Add extra arguments for specific backends
|
||||
if "passages_file" in kwargs and kwargs["passages_file"]:
|
||||
command.extend(["--passages-file", str(kwargs["passages_file"])])
|
||||
# if "distance_metric" in kwargs and kwargs["distance_metric"]:
|
||||
# command.extend(["--distance-metric", kwargs["distance_metric"]])
|
||||
|
||||
project_root = Path(__file__).parent.parent.parent.parent.parent
|
||||
print(f"INFO: Running command from project root: {project_root}")
|
||||
|
||||
self.server_process = subprocess.Popen(
|
||||
command,
|
||||
cwd=project_root,
|
||||
# stdout=subprocess.PIPE,
|
||||
# stderr=subprocess.PIPE,
|
||||
text=True,
|
||||
encoding='utf-8'
|
||||
)
|
||||
self.server_port = port
|
||||
print(f"INFO: Server process started with PID: {self.server_process.pid}")
|
||||
|
||||
max_wait, wait_interval = 30, 0.5
|
||||
for _ in range(int(max_wait / wait_interval)):
|
||||
if _check_port(port):
|
||||
print(f"✅ Embedding server is up and ready for this session.")
|
||||
log_thread = threading.Thread(target=self._log_monitor, daemon=True)
|
||||
log_thread.start()
|
||||
return True
|
||||
if self.server_process.poll() is not None:
|
||||
print("❌ ERROR: Server process terminated unexpectedly during startup.")
|
||||
self._log_monitor()
|
||||
return False
|
||||
time.sleep(wait_interval)
|
||||
|
||||
print(f"❌ ERROR: Server process failed to start listening within {max_wait} seconds.")
|
||||
self.stop_server()
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ ERROR: Failed to start embedding server process: {e}")
|
||||
return False
|
||||
|
||||
def _log_monitor(self):
|
||||
"""Monitors and prints the server's stdout and stderr."""
|
||||
if not self.server_process:
|
||||
return
|
||||
try:
|
||||
if self.server_process.stdout:
|
||||
for line in iter(self.server_process.stdout.readline, ''):
|
||||
print(f"[{self.backend_module_name} LOG]: {line.strip()}")
|
||||
self.server_process.stdout.close()
|
||||
if self.server_process.stderr:
|
||||
for line in iter(self.server_process.stderr.readline, ''):
|
||||
print(f"[{self.backend_module_name} ERROR]: {line.strip()}")
|
||||
self.server_process.stderr.close()
|
||||
except Exception as e:
|
||||
print(f"Log monitor error: {e}")
|
||||
|
||||
def stop_server(self):
|
||||
"""Stops the embedding server process if it's running."""
|
||||
if self.server_process and self.server_process.poll() is None:
|
||||
print(f"INFO: Terminating session server process (PID: {self.server_process.pid})...")
|
||||
self.server_process.terminate()
|
||||
try:
|
||||
self.server_process.wait(timeout=5)
|
||||
print("INFO: Server process terminated.")
|
||||
except subprocess.TimeoutExpired:
|
||||
print("WARNING: Server process did not terminate gracefully, killing it.")
|
||||
self.server_process.kill()
|
||||
self.server_process = None
|
||||
Reference in New Issue
Block a user