fix: do not reuse emb_server and close it properly

Andy Lee
2025-07-20 18:07:51 -07:00
parent f4998bb316
commit 7e226a51c9
7 changed files with 58 additions and 130 deletions

View File

@@ -70,10 +70,6 @@ class DiskannBuilder(LeannBackendBuilderInterface):
         data_filename = f"{index_prefix}_data.bin"
         _write_vectors_to_bin(data, index_dir / data_filename)
-        label_map = {i: str_id for i, str_id in enumerate(ids)}
-        label_map_file = index_dir / "leann.labels.map"
-        with open(label_map_file, "wb") as f:
-            pickle.dump(label_map, f)
         build_kwargs = {**self.build_params, **kwargs}
         metric_enum = _get_diskann_metrics().get(
@@ -211,10 +207,7 @@ class DiskannSearcher(BaseSearcher):
         )
         string_labels = [
-            [
-                self.label_map.get(int_label, f"unknown_{int_label}")
-                for int_label in batch_labels
-            ]
+            [str(int_label) for int_label in batch_labels]
             for batch_labels in labels
         ]
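
Note on the change above: the searcher no longer consults a pickled label map; it assumes each passage's string ID is simply the stringified integer label assigned at build time. A minimal sketch of that assumption (the label values below are made up for illustration):

    # Hypothetical ANN results: one list of integer labels per query batch.
    labels = [[3, 17, 42], [0, 5, 9]]

    # New behavior: convert labels to string passage IDs directly, no map lookup.
    string_labels = [
        [str(int_label) for int_label in batch_labels]
        for batch_labels in labels
    ]
    print(string_labels)  # [['3', '17', '42'], ['0', '5', '9']]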

View File

@@ -76,24 +76,11 @@ def load_passages_from_metadata(meta_file: str) -> SimplePassageLoader:
     finally:
         sys.path.pop(0)

-    # Load label map
-    passages_dir = Path(meta_file).parent
-    label_map_file = passages_dir / "leann.labels.map"
-    if label_map_file.exists():
-        import pickle
-        with open(label_map_file, 'rb') as f:
-            label_map = pickle.load(f)
-        print(f"Loaded label map with {len(label_map)} entries")
-    else:
-        raise FileNotFoundError(f"Label map file not found: {label_map_file}")
-    print(f"Initialized lazy passage loading for {len(label_map)} passages")
+    print(f"Initialized lazy passage loading for {len(passage_manager.global_offset_map)} passages")

     class LazyPassageLoader(SimplePassageLoader):
-        def __init__(self, passage_manager, label_map):
+        def __init__(self, passage_manager):
             self.passage_manager = passage_manager
-            self.label_map = label_map
             # Initialize parent with empty data
             super().__init__({})
@@ -101,25 +88,22 @@ def load_passages_from_metadata(meta_file: str) -> SimplePassageLoader:
"""Get passage by ID with lazy loading""" """Get passage by ID with lazy loading"""
try: try:
int_id = int(passage_id) int_id = int(passage_id)
if int_id in self.label_map: string_id = str(int_id)
string_id = self.label_map[int_id] passage_data = self.passage_manager.get_passage(string_id)
passage_data = self.passage_manager.get_passage(string_id) if passage_data and passage_data.get("text"):
if passage_data and passage_data.get("text"): return {"text": passage_data["text"]}
return {"text": passage_data["text"]}
else:
raise RuntimeError(f"FATAL: Empty text for ID {int_id} -> {string_id}")
else: else:
raise RuntimeError(f"FATAL: ID {int_id} not found in label_map") raise RuntimeError(f"FATAL: Empty text for ID {int_id} -> {string_id}")
except Exception as e: except Exception as e:
raise RuntimeError(f"FATAL: Exception getting passage {passage_id}: {e}") raise RuntimeError(f"FATAL: Exception getting passage {passage_id}: {e}")
def __len__(self) -> int: def __len__(self) -> int:
return len(self.label_map) return len(self.passage_manager.global_offset_map)
def keys(self): def keys(self):
return self.label_map.keys() return self.passage_manager.global_offset_map.keys()
loader = LazyPassageLoader(passage_manager, label_map) loader = LazyPassageLoader(passage_manager)
loader._meta_path = meta_file loader._meta_path = meta_file
return loader return loader
@@ -135,35 +119,15 @@ def load_passages_from_file(passages_file: str) -> SimplePassageLoader:
     if not passages_file.endswith('.jsonl'):
         raise ValueError(f"Expected .jsonl file format, got: {passages_file}")

-    # Load label map (int -> string_id)
-    passages_dir = Path(passages_file).parent
-    label_map_file = passages_dir / "leann.labels.map"
-    label_map = {}
-    if label_map_file.exists():
-        with open(label_map_file, 'rb') as f:
-            label_map = pickle.load(f)
-        print(f"Loaded label map with {len(label_map)} entries")
-    else:
-        raise FileNotFoundError(f"Label map file not found: {label_map_file}")
-
-    # Load passages by string ID
-    string_id_passages = {}
+    # Load passages directly by their sequential IDs
+    passages_data = {}
     with open(passages_file, 'r', encoding='utf-8') as f:
         for line in f:
             if line.strip():
                 passage = json.loads(line)
-                string_id_passages[passage['id']] = passage['text']
+                passages_data[passage['id']] = passage['text']

-    # Create int ID -> text mapping using label map
-    passages_data = {}
-    for int_id, string_id in label_map.items():
-        if string_id in string_id_passages:
-            passages_data[str(int_id)] = string_id_passages[string_id]
-        else:
-            print(f"WARNING: String ID {string_id} from label map not found in passages")
-
-    print(f"Loaded {len(passages_data)} passages from JSONL file {passages_file} using label map")
+    print(f"Loaded {len(passages_data)} passages from JSONL file {passages_file}")
     return SimplePassageLoader(passages_data)

 def create_embedding_server_thread(
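
For reference, a small sketch of the .jsonl layout this simplified loader appears to assume: one JSON object per line, with sequential string IDs that match the integer labels in the index. The file name and texts here are hypothetical:

    import json

    # Write a tiny hypothetical passages file with sequential string IDs ("0", "1", ...).
    with open("passages.jsonl", "w", encoding="utf-8") as f:
        for i, text in enumerate(["first chunk", "second chunk"]):
            f.write(json.dumps({"id": str(i), "text": text}) + "\n")

    # Read it back the same way load_passages_from_file now does.
    passages_data = {}
    with open("passages.jsonl", "r", encoding="utf-8") as f:
        for line in f:
            if line.strip():
                passage = json.loads(line)
                passages_data[passage["id"]] = passage["text"]
    print(passages_data)  # {'0': 'first chunk', '1': 'second chunk'}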

View File

@@ -59,10 +59,6 @@ class HNSWBuilder(LeannBackendBuilderInterface):
         if data.dtype != np.float32:
             data = data.astype(np.float32)
-        label_map = {i: str_id for i, str_id in enumerate(ids)}
-        label_map_file = index_dir / "leann.labels.map"
-        with open(label_map_file, "wb") as f:
-            pickle.dump(label_map, f)
         metric_enum = get_metric_map().get(self.distance_metric.lower())
         if metric_enum is None:
@@ -142,13 +138,6 @@ class HNSWSearcher(BaseSearcher):
         self._index = faiss.read_index(str(index_file), faiss.IO_FLAG_MMAP, hnsw_config)

-        # Load label mapping
-        label_map_file = self.index_dir / "leann.labels.map"
-        if not label_map_file.exists():
-            raise FileNotFoundError(f"Label map file not found at {label_map_file}")
-        with open(label_map_file, "rb") as f:
-            self.label_map = pickle.load(f)

     def search(
         self,
@@ -239,10 +228,7 @@ class HNSWSearcher(BaseSearcher):
         )
         string_labels = [
-            [
-                self.label_map.get(int_label, f"unknown_{int_label}")
-                for int_label in batch_labels
-            ]
+            [str(int_label) for int_label in batch_labels]
             for batch_labels in labels
         ]

View File

@@ -114,25 +114,11 @@ def load_passages_from_metadata(meta_file: str) -> SimplePassageLoader:
     finally:
         sys.path.pop(0)

-    # Load label map
-    passages_dir = Path(meta_file).parent
-    label_map_file = passages_dir / "leann.labels.map"
-    if label_map_file.exists():
-        import pickle
-        with open(label_map_file, "rb") as f:
-            label_map = pickle.load(f)
-        print(f"Loaded label map with {len(label_map)} entries")
-    else:
-        raise FileNotFoundError(f"Label map file not found: {label_map_file}")
-    print(f"Initialized lazy passage loading for {len(label_map)} passages")
+    print(f"Initialized lazy passage loading for {len(passage_manager.global_offset_map)} passages")

     class LazyPassageLoader(SimplePassageLoader):
-        def __init__(self, passage_manager, label_map):
+        def __init__(self, passage_manager):
             self.passage_manager = passage_manager
-            self.label_map = label_map
             # Initialize parent with empty data
             super().__init__({})
@@ -140,28 +126,24 @@ def load_passages_from_metadata(meta_file: str) -> SimplePassageLoader:
"""Get passage by ID with lazy loading""" """Get passage by ID with lazy loading"""
try: try:
int_id = int(passage_id) int_id = int(passage_id)
if int_id in self.label_map: string_id = str(int_id)
string_id = self.label_map[int_id] passage_data = self.passage_manager.get_passage(string_id)
passage_data = self.passage_manager.get_passage(string_id) if passage_data and passage_data.get("text"):
if passage_data and passage_data.get("text"): return {"text": passage_data["text"]}
return {"text": passage_data["text"]}
else:
logger.debug(f"Empty text for ID {int_id} -> {string_id}")
return {"text": ""}
else: else:
logger.debug(f"ID {int_id} not found in label_map") logger.debug(f"Empty text for ID {int_id} -> {string_id}")
return {"text": ""} return {"text": ""}
except Exception as e: except Exception as e:
logger.debug(f"Exception getting passage {passage_id}: {e}") logger.debug(f"Exception getting passage {passage_id}: {e}")
return {"text": ""} return {"text": ""}
def __len__(self) -> int: def __len__(self) -> int:
return len(self.label_map) return len(self.passage_manager.global_offset_map)
def keys(self): def keys(self):
return self.label_map.keys() return self.passage_manager.global_offset_map.keys()
return LazyPassageLoader(passage_manager, label_map) return LazyPassageLoader(passage_manager)
def create_hnsw_embedding_server( def create_hnsw_embedding_server(

View File

@@ -354,7 +354,7 @@ class LeannBuilder:
     def add_text(self, text: str, metadata: Optional[Dict[str, Any]] = None):
         if metadata is None:
             metadata = {}
-        passage_id = metadata.get("id", str(uuid.uuid4()))
+        passage_id = metadata.get("id", str(len(self.chunks)))
         chunk_data = {"id": passage_id, "text": text, "metadata": metadata}
         self.chunks.append(chunk_data)
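
The switch from uuid4 to str(len(self.chunks)) is what makes the label map removable: when metadata carries no explicit "id", passages receive 0-based sequential IDs in insertion order, so the i-th vector written at build time corresponds to passage ID str(i). A minimal sketch of the assumed behavior (not code from this repo):

    # Simulate add_text's default ID assignment on an empty builder.
    chunks = []
    for text in ["alpha", "beta", "gamma"]:
        passage_id = str(len(chunks))  # "0", then "1", then "2"
        chunks.append({"id": passage_id, "text": text, "metadata": {}})

    assert [c["id"] for c in chunks] == ["0", "1", "2"]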

View File

@@ -7,7 +7,7 @@ import sys
 import zmq
 import msgpack
 from pathlib import Path
-from typing import Optional
+from typing import Optional, Dict
 import select
 import psutil
@@ -156,7 +156,7 @@ class EmbeddingServerManager:
         self.backend_module_name = backend_module_name
         self.server_process: Optional[subprocess.Popen] = None
         self.server_port: Optional[int] = None
-        atexit.register(self.stop_server)
+        self._atexit_registered = False

     def start_server(
         self,
@@ -258,6 +258,12 @@ class EmbeddingServerManager:
         )
         self.server_port = port
         print(f"INFO: Server process started with PID: {self.server_process.pid}")
+
+        # Register atexit callback only when we actually start a process
+        if not self._atexit_registered:
+            # Use a lambda to avoid issues with bound methods
+            atexit.register(lambda: self.stop_server() if self.server_process else None)
+            self._atexit_registered = True

     def _wait_for_server_ready(self, port: int) -> tuple[bool, int]:
         """Wait for the server to be ready."""
@@ -309,17 +315,22 @@ class EmbeddingServerManager:
     def stop_server(self):
         """Stops the embedding server process if it's running."""
-        if self.server_process and self.server_process.poll() is None:
-            print(
-                f"INFO: Terminating session server process (PID: {self.server_process.pid})..."
-            )
-            self.server_process.terminate()
-            try:
-                self.server_process.wait(timeout=5)
-                print("INFO: Server process terminated.")
-            except subprocess.TimeoutExpired:
-                print(
-                    "WARNING: Server process did not terminate gracefully, killing it."
-                )
-                self.server_process.kill()
+        if not self.server_process:
+            return
+
+        if self.server_process.poll() is not None:
+            # Process already terminated
+            self.server_process = None
+            return
+
+        print(f"INFO: Terminating server process (PID: {self.server_process.pid}) for backend {self.backend_module_name}...")
+        self.server_process.terminate()
+        try:
+            self.server_process.wait(timeout=5)
+            print(f"INFO: Server process {self.server_process.pid} terminated.")
+        except subprocess.TimeoutExpired:
+            print(f"WARNING: Server process {self.server_process.pid} did not terminate gracefully, killing it.")
+            self.server_process.kill()
         self.server_process = None

View File

@@ -43,7 +43,6 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
"WARNING: embedding_model not found in meta.json. Recompute will fail." "WARNING: embedding_model not found in meta.json. Recompute will fail."
) )
self.label_map = self._load_label_map()
self.embedding_server_manager = EmbeddingServerManager( self.embedding_server_manager = EmbeddingServerManager(
backend_module_name=backend_module_name backend_module_name=backend_module_name
@@ -58,13 +57,6 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
         with open(meta_path, "r", encoding="utf-8") as f:
             return json.load(f)

-    def _load_label_map(self) -> Dict[int, str]:
-        """Loads the mapping from integer IDs to string IDs."""
-        label_map_file = self.index_dir / "leann.labels.map"
-        if not label_map_file.exists():
-            raise FileNotFoundError(f"Label map file not found: {label_map_file}")
-        with open(label_map_file, "rb") as f:
-            return pickle.load(f)

     def _ensure_server_running(
         self, passages_source_file: str, port: int, **kwargs
@@ -110,12 +102,12 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
             Query embedding as numpy array
         """
         # Try to use embedding server if available and requested
-        if (
-            use_server_if_available
-            and self.embedding_server_manager
-            and self.embedding_server_manager.server_process
-        ):
+        if use_server_if_available:
             try:
+                # Ensure we have a server with passages_file for compatibility
+                passages_source_file = self.index_dir / f"{self.index_path.name}.meta.json"
+                self._ensure_server_running(str(passages_source_file), zmq_port)
                 return self._compute_embedding_via_server([query], zmq_port)[
                     0:1
                 ]  # Return (1, D) shape
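
A condensed, hypothetical view of the new query path: rather than only using an embedding server that some other component happened to start, the searcher now ensures its own server is running on demand. Method and attribute names follow the diff; the local fallback helper is invented for illustration:

    def compute_query_embedding(self, query, zmq_port, use_server_if_available=True):
        # Try the embedding server first, starting this searcher's own server if needed.
        if use_server_if_available:
            try:
                meta_file = self.index_dir / f"{self.index_path.name}.meta.json"
                self._ensure_server_running(str(meta_file), zmq_port)
                return self._compute_embedding_via_server([query], zmq_port)[0:1]
            except Exception:
                pass  # fall through to local computation
        return self._embed_locally(query)  # hypothetical fallback, not from the diff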