refactor: embedding server manager
This commit is contained in:
@@ -0,0 +1 @@
|
|||||||
|
from . import diskann_backend
|
||||||
@@ -12,6 +12,7 @@ import socket
|
|||||||
import subprocess
|
import subprocess
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
|
from leann.embedding_server_manager import EmbeddingServerManager
|
||||||
from leann.registry import register_backend
|
from leann.registry import register_backend
|
||||||
from leann.interface import (
|
from leann.interface import (
|
||||||
LeannBackendFactoryInterface,
|
LeannBackendFactoryInterface,
|
||||||
@@ -42,96 +43,6 @@ def _write_vectors_to_bin(data: np.ndarray, file_path: str):
|
|||||||
f.write(struct.pack('I', dim))
|
f.write(struct.pack('I', dim))
|
||||||
f.write(data.tobytes())
|
f.write(data.tobytes())
|
||||||
|
|
||||||
def _check_port(port: int) -> bool:
|
|
||||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
|
||||||
return s.connect_ex(('localhost', port)) == 0
|
|
||||||
|
|
||||||
class EmbeddingServerManager:
|
|
||||||
def __init__(self):
|
|
||||||
self.server_process = None
|
|
||||||
self.server_port = None
|
|
||||||
atexit.register(self.stop_server)
|
|
||||||
|
|
||||||
def start_server(self, port=5555, model_name="sentence-transformers/all-mpnet-base-v2"):
|
|
||||||
if self.server_process and self.server_process.poll() is None:
|
|
||||||
print(f"INFO: Reusing existing server process for this session (PID {self.server_process.pid})")
|
|
||||||
return True
|
|
||||||
|
|
||||||
# 检查端口是否已被其他无关进程占用
|
|
||||||
if _check_port(port):
|
|
||||||
print(f"WARNING: Port {port} is already in use. Assuming an external server is running and connecting to it.")
|
|
||||||
return True
|
|
||||||
|
|
||||||
print(f"INFO: Starting session-level embedding server as a background process...")
|
|
||||||
|
|
||||||
try:
|
|
||||||
command = [
|
|
||||||
sys.executable,
|
|
||||||
"-m", "packages.leann-backend-diskann.leann_backend_diskann.embedding_server",
|
|
||||||
"--zmq-port", str(port),
|
|
||||||
"--model-name", model_name
|
|
||||||
]
|
|
||||||
project_root = Path(__file__).parent.parent.parent.parent
|
|
||||||
print(f"INFO: Running command from project root: {project_root}")
|
|
||||||
self.server_process = subprocess.Popen(
|
|
||||||
command,
|
|
||||||
cwd=project_root,
|
|
||||||
# stdout=subprocess.PIPE,
|
|
||||||
# stderr=subprocess.PIPE,
|
|
||||||
text=True,
|
|
||||||
encoding='utf-8'
|
|
||||||
)
|
|
||||||
self.server_port = port
|
|
||||||
print(f"INFO: Server process started with PID: {self.server_process.pid}")
|
|
||||||
|
|
||||||
max_wait, wait_interval = 30, 0.5
|
|
||||||
for _ in range(int(max_wait / wait_interval)):
|
|
||||||
if _check_port(port):
|
|
||||||
print(f"✅ Embedding server is up and ready for this session.")
|
|
||||||
log_thread = threading.Thread(target=self._log_monitor, daemon=True)
|
|
||||||
log_thread.start()
|
|
||||||
return True
|
|
||||||
if self.server_process.poll() is not None:
|
|
||||||
print("❌ ERROR: Server process terminated unexpectedly during startup.")
|
|
||||||
self._log_monitor()
|
|
||||||
return False
|
|
||||||
time.sleep(wait_interval)
|
|
||||||
|
|
||||||
print(f"❌ ERROR: Server process failed to start listening within {max_wait} seconds.")
|
|
||||||
self.stop_server()
|
|
||||||
return False
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"❌ ERROR: Failed to start embedding server process: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
def _log_monitor(self):
|
|
||||||
if not self.server_process:
|
|
||||||
return
|
|
||||||
try:
|
|
||||||
if self.server_process.stdout:
|
|
||||||
for line in iter(self.server_process.stdout.readline, ''):
|
|
||||||
print(f"[EmbeddingServer LOG]: {line.strip()}")
|
|
||||||
self.server_process.stdout.close()
|
|
||||||
if self.server_process.stderr:
|
|
||||||
for line in iter(self.server_process.stderr.readline, ''):
|
|
||||||
print(f"[EmbeddingServer ERROR]: {line.strip()}")
|
|
||||||
self.server_process.stderr.close()
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Log monitor error: {e}")
|
|
||||||
|
|
||||||
def stop_server(self):
|
|
||||||
if self.server_process and self.server_process.poll() is None:
|
|
||||||
print(f"INFO: Terminating session server process (PID: {self.server_process.pid})...")
|
|
||||||
self.server_process.terminate()
|
|
||||||
try:
|
|
||||||
self.server_process.wait(timeout=5)
|
|
||||||
print("INFO: Server process terminated.")
|
|
||||||
except subprocess.TimeoutExpired:
|
|
||||||
print("WARNING: Server process did not terminate gracefully, killing it.")
|
|
||||||
self.server_process.kill()
|
|
||||||
self.server_process = None
|
|
||||||
|
|
||||||
@register_backend("diskann")
|
@register_backend("diskann")
|
||||||
class DiskannBackend(LeannBackendFactoryInterface):
|
class DiskannBackend(LeannBackendFactoryInterface):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -143,16 +54,13 @@ class DiskannBackend(LeannBackendFactoryInterface):
|
|||||||
path = Path(index_path)
|
path = Path(index_path)
|
||||||
meta_path = path.parent / f"{path.name}.meta.json"
|
meta_path = path.parent / f"{path.name}.meta.json"
|
||||||
if not meta_path.exists():
|
if not meta_path.exists():
|
||||||
raise FileNotFoundError(f"Leann metadata file not found at {meta_path}. Cannot infer vector dimension for searcher.")
|
raise FileNotFoundError(f"Leann metadata file not found at {meta_path}.")
|
||||||
|
|
||||||
with open(meta_path, 'r') as f:
|
with open(meta_path, 'r') as f:
|
||||||
meta = json.load(f)
|
meta = json.load(f)
|
||||||
|
|
||||||
dimensions = meta.get("dimensions")
|
# Pass essential metadata to the searcher
|
||||||
if not dimensions:
|
kwargs['meta'] = meta
|
||||||
raise ValueError("Dimensions not found in Leann metadata. Please rebuild the index with a newer version of Leann.")
|
|
||||||
|
|
||||||
kwargs['dimensions'] = dimensions
|
|
||||||
return DiskannSearcher(index_path, **kwargs)
|
return DiskannSearcher(index_path, **kwargs)
|
||||||
|
|
||||||
class DiskannBuilder(LeannBackendBuilderInterface):
|
class DiskannBuilder(LeannBackendBuilderInterface):
|
||||||
@@ -215,19 +123,29 @@ class DiskannBuilder(LeannBackendBuilderInterface):
|
|||||||
|
|
||||||
class DiskannSearcher(LeannBackendSearcherInterface):
|
class DiskannSearcher(LeannBackendSearcherInterface):
|
||||||
def __init__(self, index_path: str, **kwargs):
|
def __init__(self, index_path: str, **kwargs):
|
||||||
|
self.meta = kwargs.get("meta", {})
|
||||||
|
if not self.meta:
|
||||||
|
raise ValueError("DiskannSearcher requires metadata from .meta.json.")
|
||||||
|
|
||||||
|
dimensions = self.meta.get("dimensions")
|
||||||
|
if not dimensions:
|
||||||
|
raise ValueError("Dimensions not found in Leann metadata.")
|
||||||
|
|
||||||
|
self.distance_metric = self.meta.get("distance_metric", "mips").lower()
|
||||||
|
metric_enum = METRIC_MAP.get(self.distance_metric)
|
||||||
|
if metric_enum is None:
|
||||||
|
raise ValueError(f"Unsupported distance_metric '{self.distance_metric}'.")
|
||||||
|
|
||||||
|
self.embedding_model = self.meta.get("embedding_model")
|
||||||
|
if not self.embedding_model:
|
||||||
|
print("WARNING: embedding_model not found in meta.json. Recompute will fail if attempted.")
|
||||||
|
|
||||||
path = Path(index_path)
|
path = Path(index_path)
|
||||||
index_dir = path.parent
|
index_dir = path.parent
|
||||||
index_prefix = path.stem
|
index_prefix = path.stem
|
||||||
metric_str = kwargs.get("distance_metric", "mips").lower()
|
|
||||||
metric_enum = METRIC_MAP.get(metric_str)
|
|
||||||
if metric_enum is None:
|
|
||||||
raise ValueError(f"Unsupported distance_metric '{metric_str}'.")
|
|
||||||
|
|
||||||
num_threads = kwargs.get("num_threads", 8)
|
num_threads = kwargs.get("num_threads", 8)
|
||||||
num_nodes_to_cache = kwargs.get("num_nodes_to_cache", 0)
|
num_nodes_to_cache = kwargs.get("num_nodes_to_cache", 0)
|
||||||
dimensions = kwargs.get("dimensions")
|
|
||||||
if not dimensions:
|
|
||||||
raise ValueError("Vector dimension not provided to DiskannSearcher.")
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
full_index_prefix = str(index_dir / index_prefix)
|
full_index_prefix = str(index_dir / index_prefix)
|
||||||
@@ -235,7 +153,9 @@ class DiskannSearcher(LeannBackendSearcherInterface):
|
|||||||
metric_enum, full_index_prefix, num_threads, num_nodes_to_cache, 1, "", ""
|
metric_enum, full_index_prefix, num_threads, num_nodes_to_cache, 1, "", ""
|
||||||
)
|
)
|
||||||
self.num_threads = num_threads
|
self.num_threads = num_threads
|
||||||
self.embedding_server_manager = EmbeddingServerManager()
|
self.embedding_server_manager = EmbeddingServerManager(
|
||||||
|
backend_module_name="leann_backend_diskann.embedding_server"
|
||||||
|
)
|
||||||
print("✅ DiskANN index loaded successfully.")
|
print("✅ DiskANN index loaded successfully.")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"💥 ERROR: Failed to load DiskANN index. Exception: {e}")
|
print(f"💥 ERROR: Failed to load DiskANN index. Exception: {e}")
|
||||||
@@ -255,12 +175,20 @@ class DiskannSearcher(LeannBackendSearcherInterface):
|
|||||||
|
|
||||||
if recompute_beighbor_embeddings:
|
if recompute_beighbor_embeddings:
|
||||||
print(f"INFO: DiskANN ZMQ mode enabled - ensuring embedding server is running")
|
print(f"INFO: DiskANN ZMQ mode enabled - ensuring embedding server is running")
|
||||||
zmq_port = kwargs.get("zmq_port", 6666)
|
if not self.embedding_model:
|
||||||
embedding_model = kwargs.get("embedding_model", "sentence-transformers/all-mpnet-base-v2")
|
raise ValueError("Cannot use recompute_beighbor_embeddings without 'embedding_model' in meta.json.")
|
||||||
|
|
||||||
if not self.embedding_server_manager.start_server(zmq_port, embedding_model):
|
zmq_port = kwargs.get("zmq_port", 6666)
|
||||||
|
|
||||||
|
server_started = self.embedding_server_manager.start_server(
|
||||||
|
port=zmq_port,
|
||||||
|
model_name=self.embedding_model,
|
||||||
|
distance_metric=self.distance_metric
|
||||||
|
)
|
||||||
|
|
||||||
|
if not server_started:
|
||||||
print(f"WARNING: Failed to start embedding server, falling back to PQ computation")
|
print(f"WARNING: Failed to start embedding server, falling back to PQ computation")
|
||||||
kwargs['recompute_beighbor_embeddings'] = False
|
recompute_beighbor_embeddings = False
|
||||||
|
|
||||||
if query.dtype != np.float32:
|
if query.dtype != np.float32:
|
||||||
query = query.astype(np.float32)
|
query = query.astype(np.float32)
|
||||||
@@ -292,4 +220,4 @@ class DiskannSearcher(LeannBackendSearcherInterface):
|
|||||||
|
|
||||||
def __del__(self):
|
def __del__(self):
|
||||||
if hasattr(self, 'embedding_server_manager'):
|
if hasattr(self, 'embedding_server_manager'):
|
||||||
self.embedding_server_manager.stop_server()
|
self.embedding_server_manager.stop_server()
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import socket
|
|||||||
import subprocess
|
import subprocess
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
|
from leann.embedding_server_manager import EmbeddingServerManager
|
||||||
from .convert_to_csr import convert_hnsw_graph_to_csr
|
from .convert_to_csr import convert_hnsw_graph_to_csr
|
||||||
|
|
||||||
from leann.registry import register_backend
|
from leann.registry import register_backend
|
||||||
@@ -29,118 +30,6 @@ def get_metric_map():
|
|||||||
"cosine": faiss.METRIC_INNER_PRODUCT,
|
"cosine": faiss.METRIC_INNER_PRODUCT,
|
||||||
}
|
}
|
||||||
|
|
||||||
def _check_port(port: int) -> bool:
|
|
||||||
"""Check if a port is in use"""
|
|
||||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
|
||||||
return s.connect_ex(('localhost', port)) == 0
|
|
||||||
|
|
||||||
class HNSWEmbeddingServerManager:
|
|
||||||
"""
|
|
||||||
HNSW-specific embedding server manager that handles the lifecycle of the embedding server process.
|
|
||||||
Mirrors the DiskANN EmbeddingServerManager architecture.
|
|
||||||
"""
|
|
||||||
def __init__(self):
|
|
||||||
self.server_process = None
|
|
||||||
self.server_port = None
|
|
||||||
atexit.register(self.stop_server)
|
|
||||||
|
|
||||||
def start_server(self, port=5556, model_name="sentence-transformers/all-mpnet-base-v2", passages_file=None, distance_metric="mips"):
|
|
||||||
"""
|
|
||||||
Start the HNSW embedding server process.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
port: ZMQ port for the server
|
|
||||||
model_name: Name of the embedding model to use
|
|
||||||
passages_file: Optional path to passages JSON file
|
|
||||||
distance_metric: The distance metric to use
|
|
||||||
"""
|
|
||||||
if self.server_process and self.server_process.poll() is None:
|
|
||||||
print(f"INFO: Reusing existing HNSW server process for this session (PID {self.server_process.pid})")
|
|
||||||
return True
|
|
||||||
|
|
||||||
# Check if port is already in use
|
|
||||||
if _check_port(port):
|
|
||||||
print(f"WARNING: Port {port} is already in use. Assuming an external HNSW server is running and connecting to it.")
|
|
||||||
return True
|
|
||||||
|
|
||||||
print(f"INFO: Starting session-level HNSW embedding server as a background process...")
|
|
||||||
|
|
||||||
try:
|
|
||||||
command = [
|
|
||||||
sys.executable,
|
|
||||||
"-m", "leann_backend_hnsw.hnsw_embedding_server",
|
|
||||||
"--zmq-port", str(port),
|
|
||||||
"--model-name", model_name,
|
|
||||||
"--distance-metric", distance_metric
|
|
||||||
]
|
|
||||||
|
|
||||||
if passages_file:
|
|
||||||
command.extend(["--passages-file", str(passages_file)])
|
|
||||||
|
|
||||||
project_root = Path(__file__).parent.parent.parent.parent
|
|
||||||
print(f"INFO: Running HNSW command from project root: {project_root}")
|
|
||||||
|
|
||||||
self.server_process = subprocess.Popen(
|
|
||||||
command,
|
|
||||||
cwd=project_root,
|
|
||||||
stdout=subprocess.PIPE,
|
|
||||||
stderr=subprocess.PIPE,
|
|
||||||
text=True,
|
|
||||||
encoding='utf-8'
|
|
||||||
)
|
|
||||||
self.server_port = port
|
|
||||||
print(f"INFO: HNSW server process started with PID: {self.server_process.pid}")
|
|
||||||
|
|
||||||
max_wait, wait_interval = 30, 0.5
|
|
||||||
for _ in range(int(max_wait / wait_interval)):
|
|
||||||
if _check_port(port):
|
|
||||||
print(f"✅ HNSW embedding server is up and ready for this session.")
|
|
||||||
log_thread = threading.Thread(target=self._log_monitor, daemon=True)
|
|
||||||
log_thread.start()
|
|
||||||
return True
|
|
||||||
if self.server_process.poll() is not None:
|
|
||||||
print("❌ ERROR: HNSW server process terminated unexpectedly during startup.")
|
|
||||||
self._log_monitor()
|
|
||||||
return False
|
|
||||||
time.sleep(wait_interval)
|
|
||||||
|
|
||||||
print(f"❌ ERROR: HNSW server process failed to start listening within {max_wait} seconds.")
|
|
||||||
self.stop_server()
|
|
||||||
return False
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"❌ ERROR: Failed to start HNSW embedding server process: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
def _log_monitor(self):
|
|
||||||
"""Monitor server logs"""
|
|
||||||
if not self.server_process:
|
|
||||||
return
|
|
||||||
try:
|
|
||||||
if self.server_process.stdout:
|
|
||||||
for line in iter(self.server_process.stdout.readline, ''):
|
|
||||||
print(f"[HNSWEmbeddingServer LOG]: {line.strip()}")
|
|
||||||
self.server_process.stdout.close()
|
|
||||||
if self.server_process.stderr:
|
|
||||||
for line in iter(self.server_process.stderr.readline, ''):
|
|
||||||
print(f"[HNSWEmbeddingServer ERROR]: {line.strip()}")
|
|
||||||
self.server_process.stderr.close()
|
|
||||||
except Exception as e:
|
|
||||||
print(f"HNSW Log monitor error: {e}")
|
|
||||||
|
|
||||||
def stop_server(self):
|
|
||||||
"""Stop the HNSW embedding server process"""
|
|
||||||
if self.server_process and self.server_process.poll() is None:
|
|
||||||
print(f"INFO: Terminating HNSW session server process (PID: {self.server_process.pid})...")
|
|
||||||
self.server_process.terminate()
|
|
||||||
try:
|
|
||||||
self.server_process.wait(timeout=5)
|
|
||||||
print("INFO: HNSW server process terminated.")
|
|
||||||
except subprocess.TimeoutExpired:
|
|
||||||
print("WARNING: HNSW server process did not terminate gracefully, killing it.")
|
|
||||||
self.server_process.kill()
|
|
||||||
self.server_process = None
|
|
||||||
|
|
||||||
@register_backend("hnsw")
|
@register_backend("hnsw")
|
||||||
class HNSWBackend(LeannBackendFactoryInterface):
|
class HNSWBackend(LeannBackendFactoryInterface):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -152,16 +41,12 @@ class HNSWBackend(LeannBackendFactoryInterface):
|
|||||||
path = Path(index_path)
|
path = Path(index_path)
|
||||||
meta_path = path.parent / f"{path.name}.meta.json"
|
meta_path = path.parent / f"{path.name}.meta.json"
|
||||||
if not meta_path.exists():
|
if not meta_path.exists():
|
||||||
raise FileNotFoundError(f"Leann metadata file not found at {meta_path}. Cannot infer vector dimension for searcher.")
|
raise FileNotFoundError(f"Leann metadata file not found at {meta_path}.")
|
||||||
|
|
||||||
with open(meta_path, 'r') as f:
|
with open(meta_path, 'r') as f:
|
||||||
meta = json.load(f)
|
meta = json.load(f)
|
||||||
|
|
||||||
dimensions = meta.get("dimensions")
|
kwargs['meta'] = meta
|
||||||
if not dimensions:
|
|
||||||
raise ValueError("Dimensions not found in Leann metadata. Please rebuild the index with a newer version of Leann.")
|
|
||||||
|
|
||||||
kwargs['dimensions'] = dimensions
|
|
||||||
return HNSWSearcher(index_path, **kwargs)
|
return HNSWSearcher(index_path, **kwargs)
|
||||||
|
|
||||||
class HNSWBuilder(LeannBackendBuilderInterface):
|
class HNSWBuilder(LeannBackendBuilderInterface):
|
||||||
@@ -376,47 +261,49 @@ class HNSWSearcher(LeannBackendSearcherInterface):
|
|||||||
|
|
||||||
def __init__(self, index_path: str, **kwargs):
|
def __init__(self, index_path: str, **kwargs):
|
||||||
from . import faiss
|
from . import faiss
|
||||||
path = Path(index_path)
|
self.meta = kwargs.get("meta", {})
|
||||||
index_dir = path.parent
|
if not self.meta:
|
||||||
index_prefix = path.stem
|
raise ValueError("HNSWSearcher requires metadata from .meta.json.")
|
||||||
|
|
||||||
# Store configuration and paths for later use
|
self.dimensions = self.meta.get("dimensions")
|
||||||
self.config = kwargs.copy()
|
if not self.dimensions:
|
||||||
self.config["index_path"] = index_path
|
raise ValueError("Dimensions not found in Leann metadata.")
|
||||||
self.index_dir = index_dir
|
|
||||||
self.index_prefix = index_prefix
|
self.distance_metric = self.meta.get("distance_metric", "mips").lower()
|
||||||
|
metric_enum = get_metric_map().get(self.distance_metric)
|
||||||
metric_str = self.config.get("distance_metric", "mips").lower()
|
|
||||||
metric_enum = get_metric_map().get(metric_str)
|
|
||||||
if metric_enum is None:
|
if metric_enum is None:
|
||||||
raise ValueError(f"Unsupported distance_metric '{metric_str}'.")
|
raise ValueError(f"Unsupported distance_metric '{self.distance_metric}'.")
|
||||||
|
|
||||||
|
self.embedding_model = self.meta.get("embedding_model")
|
||||||
|
if not self.embedding_model:
|
||||||
|
print("WARNING: embedding_model not found in meta.json. Recompute will fail if attempted.")
|
||||||
|
|
||||||
|
path = Path(index_path)
|
||||||
|
self.index_dir = path.parent
|
||||||
|
self.index_prefix = path.stem
|
||||||
|
|
||||||
dimensions = self.config.get("dimensions")
|
index_file = self.index_dir / f"{self.index_prefix}.index"
|
||||||
if not dimensions:
|
|
||||||
raise ValueError("Vector dimension not provided to HNSWSearcher.")
|
|
||||||
|
|
||||||
index_file = index_dir / f"{index_prefix}.index"
|
|
||||||
if not index_file.exists():
|
if not index_file.exists():
|
||||||
raise FileNotFoundError(f"HNSW index file not found at {index_file}")
|
raise FileNotFoundError(f"HNSW index file not found at {index_file}")
|
||||||
|
|
||||||
self.is_compact, self.is_pruned = self._get_index_storage_status(index_file)
|
self.is_compact, self.is_pruned = self._get_index_storage_status(index_file)
|
||||||
|
|
||||||
# Validate configuration constraints
|
# Validate configuration constraints
|
||||||
if not self.is_compact and self.config.get("is_skip_neighbors", False):
|
if not self.is_compact and kwargs.get("is_skip_neighbors", False):
|
||||||
raise ValueError("is_skip_neighbors can only be used with is_compact=True")
|
raise ValueError("is_skip_neighbors can only be used with is_compact=True")
|
||||||
|
|
||||||
if self.config.get("is_recompute", False) and self.config.get("external_storage_path"):
|
if kwargs.get("is_recompute", False) and kwargs.get("external_storage_path"):
|
||||||
raise ValueError("Cannot use both is_recompute and external_storage_path simultaneously")
|
raise ValueError("Cannot use both is_recompute and external_storage_path simultaneously")
|
||||||
|
|
||||||
hnsw_config = faiss.HNSWIndexConfig()
|
hnsw_config = faiss.HNSWIndexConfig()
|
||||||
hnsw_config.is_compact = self.is_compact
|
hnsw_config.is_compact = self.is_compact
|
||||||
|
|
||||||
# Apply additional configuration options with strict validation
|
# Apply additional configuration options with strict validation
|
||||||
hnsw_config.is_skip_neighbors = self.config.get("is_skip_neighbors", False)
|
hnsw_config.is_skip_neighbors = kwargs.get("is_skip_neighbors", False)
|
||||||
hnsw_config.is_recompute = self.is_pruned or self.config.get("is_recompute", False)
|
hnsw_config.is_recompute = self.is_pruned or kwargs.get("is_recompute", False)
|
||||||
hnsw_config.disk_cache_ratio = self.config.get("disk_cache_ratio", 0.0)
|
hnsw_config.disk_cache_ratio = kwargs.get("disk_cache_ratio", 0.0)
|
||||||
hnsw_config.external_storage_path = self.config.get("external_storage_path")
|
hnsw_config.external_storage_path = kwargs.get("external_storage_path")
|
||||||
hnsw_config.zmq_port = self.config.get("zmq_port", 5557)
|
hnsw_config.zmq_port = kwargs.get("zmq_port", 5557)
|
||||||
|
|
||||||
if self.is_pruned and not hnsw_config.is_recompute:
|
if self.is_pruned and not hnsw_config.is_recompute:
|
||||||
raise RuntimeError("Index is pruned (embeddings removed) but recompute is disabled. This is impossible - recompute must be enabled for pruned indices.")
|
raise RuntimeError("Index is pruned (embeddings removed) but recompute is disabled. This is impossible - recompute must be enabled for pruned indices.")
|
||||||
@@ -431,82 +318,55 @@ class HNSWSearcher(LeannBackendSearcherInterface):
|
|||||||
else:
|
else:
|
||||||
print("✅ Standard HNSW index loaded successfully.")
|
print("✅ Standard HNSW index loaded successfully.")
|
||||||
|
|
||||||
self.metric_str = metric_str
|
self.embedding_server_manager = EmbeddingServerManager(
|
||||||
self.embedding_server_manager = HNSWEmbeddingServerManager()
|
backend_module_name="leann_backend_hnsw.hnsw_embedding_server"
|
||||||
|
)
|
||||||
def _get_index_file(self, index_dir: Path, index_prefix: str) -> Path:
|
|
||||||
"""Get the appropriate index file path based on format"""
|
|
||||||
# We always use the same filename now, format is detected internally
|
|
||||||
return index_dir / f"{index_prefix}.index"
|
|
||||||
|
|
||||||
def search(self, query: np.ndarray, top_k: int, **kwargs) -> Dict[str, Any]:
|
def search(self, query: np.ndarray, top_k: int, **kwargs) -> Dict[str, Any]:
|
||||||
"""Search using HNSW index with optional recompute functionality"""
|
"""Search using HNSW index with optional recompute functionality"""
|
||||||
from . import faiss
|
from . import faiss
|
||||||
# Merge config with search-time kwargs
|
|
||||||
search_config = self.config.copy()
|
|
||||||
search_config.update(kwargs)
|
|
||||||
|
|
||||||
ef = search_config.get("ef", 200) # Size of the dynamic candidate list for search
|
ef = kwargs.get("ef", 200)
|
||||||
|
|
||||||
# Recompute parameters
|
|
||||||
zmq_port = search_config.get("zmq_port", 5557)
|
|
||||||
embedding_model = search_config.get("embedding_model", "sentence-transformers/all-mpnet-base-v2")
|
|
||||||
passages_file = search_config.get("passages_file", None)
|
|
||||||
|
|
||||||
# For recompute mode, try to find the passages file automatically
|
|
||||||
if self.is_pruned and not passages_file:
|
|
||||||
potential_passages_file = self.index_dir / f"{self.index_prefix}.passages.json"
|
|
||||||
print(f"DEBUG: Checking for passages file at: {potential_passages_file}")
|
|
||||||
if potential_passages_file.exists():
|
|
||||||
passages_file = str(potential_passages_file)
|
|
||||||
print(f"INFO: Found passages file for recompute mode: {passages_file}")
|
|
||||||
else:
|
|
||||||
print(f"WARNING: No passages file found for recompute mode at {potential_passages_file}")
|
|
||||||
|
|
||||||
# If index is pruned (embeddings removed), we MUST start embedding server for recompute
|
|
||||||
if self.is_pruned:
|
if self.is_pruned:
|
||||||
print(f"INFO: Index is pruned - starting embedding server for recompute")
|
print(f"INFO: Index is pruned - ensuring embedding server is running for recompute.")
|
||||||
|
if not self.embedding_model:
|
||||||
# CRITICAL: Check passages file exists - fail fast if not
|
raise ValueError("Cannot use recompute mode without 'embedding_model' in meta.json.")
|
||||||
|
|
||||||
|
passages_file = kwargs.get("passages_file")
|
||||||
if not passages_file:
|
if not passages_file:
|
||||||
raise RuntimeError(f"FATAL: Index is pruned but no passages file found. Cannot proceed with recompute mode.")
|
potential_passages_file = self.index_dir / f"{self.index_prefix}.passages.json"
|
||||||
|
if potential_passages_file.exists():
|
||||||
# Check if server is already running first
|
passages_file = str(potential_passages_file)
|
||||||
if _check_port(zmq_port):
|
print(f"INFO: Automatically found passages file: {passages_file}")
|
||||||
print(f"INFO: Embedding server already running on port {zmq_port}")
|
else:
|
||||||
else:
|
raise RuntimeError(f"FATAL: Index is pruned but no passages file found.")
|
||||||
if not self.embedding_server_manager.start_server(zmq_port, embedding_model, passages_file, self.metric_str):
|
|
||||||
raise RuntimeError(f"Failed to start HNSW embedding server on port {zmq_port}")
|
zmq_port = kwargs.get("zmq_port", 5557)
|
||||||
|
server_started = self.embedding_server_manager.start_server(
|
||||||
# Give server extra time to fully initialize
|
port=zmq_port,
|
||||||
print(f"INFO: Waiting for embedding server to fully initialize...")
|
model_name=self.embedding_model,
|
||||||
time.sleep(3)
|
passages_file=passages_file,
|
||||||
|
distance_metric=self.distance_metric
|
||||||
# Final verification
|
)
|
||||||
if not _check_port(zmq_port):
|
if not server_started:
|
||||||
raise RuntimeError(f"Embedding server failed to start listening on port {zmq_port}")
|
raise RuntimeError(f"Failed to start HNSW embedding server on port {zmq_port}")
|
||||||
else:
|
|
||||||
print(f"INFO: Index has embeddings stored - no recompute needed")
|
|
||||||
|
|
||||||
if query.dtype != np.float32:
|
if query.dtype != np.float32:
|
||||||
query = query.astype(np.float32)
|
query = query.astype(np.float32)
|
||||||
if query.ndim == 1:
|
if query.ndim == 1:
|
||||||
query = np.expand_dims(query, axis=0)
|
query = np.expand_dims(query, axis=0)
|
||||||
|
|
||||||
# Normalize query if using cosine similarity
|
if self.distance_metric == "cosine":
|
||||||
if self.metric_str == "cosine":
|
|
||||||
faiss.normalize_L2(query)
|
faiss.normalize_L2(query)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Set search parameter
|
|
||||||
self._index.hnsw.efSearch = ef
|
self._index.hnsw.efSearch = ef
|
||||||
|
|
||||||
# Prepare output arrays for the older FAISS SWIG API
|
|
||||||
batch_size = query.shape[0]
|
batch_size = query.shape[0]
|
||||||
distances = np.empty((batch_size, top_k), dtype=np.float32)
|
distances = np.empty((batch_size, top_k), dtype=np.float32)
|
||||||
labels = np.empty((batch_size, top_k), dtype=np.int64)
|
labels = np.empty((batch_size, top_k), dtype=np.int64)
|
||||||
|
|
||||||
# Use standard FAISS search - recompute is handled internally by FAISS
|
|
||||||
self._index.search(query.shape[0], faiss.swig_ptr(query), top_k, faiss.swig_ptr(distances), faiss.swig_ptr(labels))
|
self._index.search(query.shape[0], faiss.swig_ptr(query), top_k, faiss.swig_ptr(distances), faiss.swig_ptr(labels))
|
||||||
|
|
||||||
return {"labels": labels, "distances": distances}
|
return {"labels": labels, "distances": distances}
|
||||||
@@ -517,4 +377,4 @@ class HNSWSearcher(LeannBackendSearcherInterface):
|
|||||||
|
|
||||||
def __del__(self):
|
def __del__(self):
|
||||||
if hasattr(self, 'embedding_server_manager'):
|
if hasattr(self, 'embedding_server_manager'):
|
||||||
self.embedding_server_manager.stop_server()
|
self.embedding_server_manager.stop_server()
|
||||||
|
|||||||
@@ -1,17 +0,0 @@
|
|||||||
# This file makes the 'leann' directory a Python package.
|
|
||||||
|
|
||||||
from .api import LeannBuilder, LeannSearcher, LeannChat, SearchResult
|
|
||||||
|
|
||||||
# Import backends to ensure they are registered
|
|
||||||
try:
|
|
||||||
import leann_backend_hnsw
|
|
||||||
except ImportError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
try:
|
|
||||||
import leann_backend_diskann
|
|
||||||
except ImportError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
__all__ = ['LeannBuilder', 'LeannSearcher', 'LeannChat', 'SearchResult']
|
|
||||||
|
|||||||
132
packages/leann-core/src/leann/embedding_server_manager.py
Normal file
132
packages/leann-core/src/leann/embedding_server_manager.py
Normal file
@@ -0,0 +1,132 @@
|
|||||||
|
|
||||||
|
import os
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
import atexit
|
||||||
|
import socket
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
def _check_port(port: int) -> bool:
|
||||||
|
"""Check if a port is in use"""
|
||||||
|
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||||
|
return s.connect_ex(('localhost', port)) == 0
|
||||||
|
|
||||||
|
class EmbeddingServerManager:
|
||||||
|
"""
|
||||||
|
A generic manager for handling the lifecycle of a backend-specific embedding server process.
|
||||||
|
"""
|
||||||
|
def __init__(self, backend_module_name: str):
|
||||||
|
"""
|
||||||
|
Initializes the manager for a specific backend.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
backend_module_name (str): The full module name of the backend's server script.
|
||||||
|
e.g., "leann_backend_diskann.embedding_server"
|
||||||
|
"""
|
||||||
|
self.backend_module_name = backend_module_name
|
||||||
|
self.server_process: Optional[subprocess.Popen] = None
|
||||||
|
self.server_port: Optional[int] = None
|
||||||
|
atexit.register(self.stop_server)
|
||||||
|
|
||||||
|
def start_server(self, port: int, model_name: str, **kwargs) -> bool:
|
||||||
|
"""
|
||||||
|
Starts the embedding server process.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
port (int): The ZMQ port for the server.
|
||||||
|
model_name (str): The name of the embedding model to use.
|
||||||
|
**kwargs: Additional arguments for the server (e.g., passages_file, distance_metric).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if the server is started successfully or already running, False otherwise.
|
||||||
|
"""
|
||||||
|
if self.server_process and self.server_process.poll() is None:
|
||||||
|
print(f"INFO: Reusing existing server process for this session (PID {self.server_process.pid})")
|
||||||
|
return True
|
||||||
|
|
||||||
|
if _check_port(port):
|
||||||
|
print(f"WARNING: Port {port} is already in use. Assuming an external server is running.")
|
||||||
|
return True
|
||||||
|
|
||||||
|
print(f"INFO: Starting session-level embedding server for '{self.backend_module_name}'...")
|
||||||
|
|
||||||
|
try:
|
||||||
|
command = [
|
||||||
|
sys.executable,
|
||||||
|
"-m", self.backend_module_name,
|
||||||
|
"--zmq-port", str(port),
|
||||||
|
"--model-name", model_name
|
||||||
|
]
|
||||||
|
|
||||||
|
# Add extra arguments for specific backends
|
||||||
|
if "passages_file" in kwargs and kwargs["passages_file"]:
|
||||||
|
command.extend(["--passages-file", str(kwargs["passages_file"])])
|
||||||
|
# if "distance_metric" in kwargs and kwargs["distance_metric"]:
|
||||||
|
# command.extend(["--distance-metric", kwargs["distance_metric"]])
|
||||||
|
|
||||||
|
project_root = Path(__file__).parent.parent.parent.parent.parent
|
||||||
|
print(f"INFO: Running command from project root: {project_root}")
|
||||||
|
|
||||||
|
self.server_process = subprocess.Popen(
|
||||||
|
command,
|
||||||
|
cwd=project_root,
|
||||||
|
# stdout=subprocess.PIPE,
|
||||||
|
# stderr=subprocess.PIPE,
|
||||||
|
text=True,
|
||||||
|
encoding='utf-8'
|
||||||
|
)
|
||||||
|
self.server_port = port
|
||||||
|
print(f"INFO: Server process started with PID: {self.server_process.pid}")
|
||||||
|
|
||||||
|
max_wait, wait_interval = 30, 0.5
|
||||||
|
for _ in range(int(max_wait / wait_interval)):
|
||||||
|
if _check_port(port):
|
||||||
|
print(f"✅ Embedding server is up and ready for this session.")
|
||||||
|
log_thread = threading.Thread(target=self._log_monitor, daemon=True)
|
||||||
|
log_thread.start()
|
||||||
|
return True
|
||||||
|
if self.server_process.poll() is not None:
|
||||||
|
print("❌ ERROR: Server process terminated unexpectedly during startup.")
|
||||||
|
self._log_monitor()
|
||||||
|
return False
|
||||||
|
time.sleep(wait_interval)
|
||||||
|
|
||||||
|
print(f"❌ ERROR: Server process failed to start listening within {max_wait} seconds.")
|
||||||
|
self.stop_server()
|
||||||
|
return False
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ ERROR: Failed to start embedding server process: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _log_monitor(self):
|
||||||
|
"""Monitors and prints the server's stdout and stderr."""
|
||||||
|
if not self.server_process:
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
if self.server_process.stdout:
|
||||||
|
for line in iter(self.server_process.stdout.readline, ''):
|
||||||
|
print(f"[{self.backend_module_name} LOG]: {line.strip()}")
|
||||||
|
self.server_process.stdout.close()
|
||||||
|
if self.server_process.stderr:
|
||||||
|
for line in iter(self.server_process.stderr.readline, ''):
|
||||||
|
print(f"[{self.backend_module_name} ERROR]: {line.strip()}")
|
||||||
|
self.server_process.stderr.close()
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Log monitor error: {e}")
|
||||||
|
|
||||||
|
def stop_server(self):
|
||||||
|
"""Stops the embedding server process if it's running."""
|
||||||
|
if self.server_process and self.server_process.poll() is None:
|
||||||
|
print(f"INFO: Terminating session server process (PID: {self.server_process.pid})...")
|
||||||
|
self.server_process.terminate()
|
||||||
|
try:
|
||||||
|
self.server_process.wait(timeout=5)
|
||||||
|
print("INFO: Server process terminated.")
|
||||||
|
except subprocess.TimeoutExpired:
|
||||||
|
print("WARNING: Server process did not terminate gracefully, killing it.")
|
||||||
|
self.server_process.kill()
|
||||||
|
self.server_process = None
|
||||||
Reference in New Issue
Block a user