fix: do not reuse emb_server and close it properly

This commit is contained in:
Andy Lee
2025-07-20 18:07:51 -07:00
parent f4998bb316
commit 7e226a51c9
7 changed files with 58 additions and 130 deletions

View File

@@ -70,10 +70,6 @@ class DiskannBuilder(LeannBackendBuilderInterface):
data_filename = f"{index_prefix}_data.bin"
_write_vectors_to_bin(data, index_dir / data_filename)
- label_map = {i: str_id for i, str_id in enumerate(ids)}
- label_map_file = index_dir / "leann.labels.map"
- with open(label_map_file, "wb") as f:
-     pickle.dump(label_map, f)
build_kwargs = {**self.build_params, **kwargs}
metric_enum = _get_diskann_metrics().get(
@@ -211,10 +207,7 @@ class DiskannSearcher(BaseSearcher):
)
string_labels = [
- [
-     self.label_map.get(int_label, f"unknown_{int_label}")
-     for int_label in batch_labels
- ]
+ [str(int_label) for int_label in batch_labels]
for batch_labels in labels
]
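Note (not part of the diff): with leann.labels.map gone, the builder assigns integer labels by enumerate() and the searcher recovers the string passage ID with plain str(). A minimal sketch of that post-processing; the helper name and the sample label batches are illustrative:

# Hypothetical helper mirroring the new label handling above: the index returns
# batches of integer neighbor labels, and the passage ID is just their decimal form.
def labels_to_string_ids(labels):
    return [[str(int_label) for int_label in batch] for batch in labels]

print(labels_to_string_ids([[0, 7, 42], [3, 3, 1]]))
# -> [['0', '7', '42'], ['3', '3', '1']]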

View File

@@ -76,24 +76,11 @@ def load_passages_from_metadata(meta_file: str) -> SimplePassageLoader:
finally:
sys.path.pop(0)
- # Load label map
- passages_dir = Path(meta_file).parent
- label_map_file = passages_dir / "leann.labels.map"
- if label_map_file.exists():
-     import pickle
-     with open(label_map_file, 'rb') as f:
-         label_map = pickle.load(f)
-     print(f"Loaded label map with {len(label_map)} entries")
- else:
-     raise FileNotFoundError(f"Label map file not found: {label_map_file}")
- print(f"Initialized lazy passage loading for {len(label_map)} passages")
+ print(f"Initialized lazy passage loading for {len(passage_manager.global_offset_map)} passages")
class LazyPassageLoader(SimplePassageLoader):
- def __init__(self, passage_manager, label_map):
+ def __init__(self, passage_manager):
self.passage_manager = passage_manager
- self.label_map = label_map
# Initialize parent with empty data
super().__init__({})
@@ -101,25 +88,22 @@ def load_passages_from_metadata(meta_file: str) -> SimplePassageLoader:
"""Get passage by ID with lazy loading"""
try:
int_id = int(passage_id)
- if int_id in self.label_map:
-     string_id = self.label_map[int_id]
-     passage_data = self.passage_manager.get_passage(string_id)
-     if passage_data and passage_data.get("text"):
-         return {"text": passage_data["text"]}
-     else:
-         raise RuntimeError(f"FATAL: Empty text for ID {int_id} -> {string_id}")
+ string_id = str(int_id)
+ passage_data = self.passage_manager.get_passage(string_id)
+ if passage_data and passage_data.get("text"):
+     return {"text": passage_data["text"]}
else:
raise RuntimeError(f"FATAL: ID {int_id} not found in label_map")
raise RuntimeError(f"FATAL: Empty text for ID {int_id} -> {string_id}")
except Exception as e:
raise RuntimeError(f"FATAL: Exception getting passage {passage_id}: {e}")
def __len__(self) -> int:
- return len(self.label_map)
+ return len(self.passage_manager.global_offset_map)
def keys(self):
- return self.label_map.keys()
+ return self.passage_manager.global_offset_map.keys()
- loader = LazyPassageLoader(passage_manager, label_map)
+ loader = LazyPassageLoader(passage_manager)
loader._meta_path = meta_file
return loader
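Note (not part of the diff): stripped of diff noise, the lazy lookup now reduces to "treat the integer label as the string ID, ask the passage manager, fail fast on empty text". A compact sketch; fetch_text is a made-up name, while passage_manager.get_passage and the error message come from the hunk above:

def fetch_text(passage_manager, passage_id: str) -> dict:
    # Sequential integer IDs are stored as strings, so no label-map lookup is needed.
    int_id = int(passage_id)
    string_id = str(int_id)
    passage_data = passage_manager.get_passage(string_id)
    if passage_data and passage_data.get("text"):
        return {"text": passage_data["text"]}
    raise RuntimeError(f"FATAL: Empty text for ID {int_id} -> {string_id}")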
@@ -135,35 +119,15 @@ def load_passages_from_file(passages_file: str) -> SimplePassageLoader:
if not passages_file.endswith('.jsonl'):
raise ValueError(f"Expected .jsonl file format, got: {passages_file}")
- # Load label map (int -> string_id)
- passages_dir = Path(passages_file).parent
- label_map_file = passages_dir / "leann.labels.map"
- label_map = {}
- if label_map_file.exists():
-     with open(label_map_file, 'rb') as f:
-         label_map = pickle.load(f)
-     print(f"Loaded label map with {len(label_map)} entries")
- else:
-     raise FileNotFoundError(f"Label map file not found: {label_map_file}")
- # Load passages by string ID
- string_id_passages = {}
+ # Load passages directly by their sequential IDs
+ passages_data = {}
with open(passages_file, 'r', encoding='utf-8') as f:
for line in f:
if line.strip():
passage = json.loads(line)
- string_id_passages[passage['id']] = passage['text']
+ passages_data[passage['id']] = passage['text']
- # Create int ID -> text mapping using label map
- passages_data = {}
- for int_id, string_id in label_map.items():
-     if string_id in string_id_passages:
-         passages_data[str(int_id)] = string_id_passages[string_id]
-     else:
-         print(f"WARNING: String ID {string_id} from label map not found in passages")
- print(f"Loaded {len(passages_data)} passages from JSONL file {passages_file} using label map")
+ print(f"Loaded {len(passages_data)} passages from JSONL file {passages_file}")
return SimplePassageLoader(passages_data)
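Note (not part of the diff): because IDs are now assigned sequentially at build time, a passages .jsonl file is expected to carry those IDs directly, and the loader keys the dict on the id field as-is. A self-contained sketch with made-up sample lines:

import json

# Hypothetical passages.jsonl content; the "id" values are the sequential string IDs.
sample_lines = [
    '{"id": "0", "text": "first passage text"}',
    '{"id": "1", "text": "second passage text"}',
]

passages_data = {}
for line in sample_lines:
    if line.strip():
        passage = json.loads(line)
        passages_data[passage["id"]] = passage["text"]

assert passages_data["1"] == "second passage text"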
def create_embedding_server_thread(

View File

@@ -59,10 +59,6 @@ class HNSWBuilder(LeannBackendBuilderInterface):
if data.dtype != np.float32:
data = data.astype(np.float32)
- label_map = {i: str_id for i, str_id in enumerate(ids)}
- label_map_file = index_dir / "leann.labels.map"
- with open(label_map_file, "wb") as f:
-     pickle.dump(label_map, f)
metric_enum = get_metric_map().get(self.distance_metric.lower())
if metric_enum is None:
@@ -142,13 +138,6 @@ class HNSWSearcher(BaseSearcher):
self._index = faiss.read_index(str(index_file), faiss.IO_FLAG_MMAP, hnsw_config)
- # Load label mapping
- label_map_file = self.index_dir / "leann.labels.map"
- if not label_map_file.exists():
-     raise FileNotFoundError(f"Label map file not found at {label_map_file}")
- with open(label_map_file, "rb") as f:
-     self.label_map = pickle.load(f)
def search(
self,
@@ -239,10 +228,7 @@ class HNSWSearcher(BaseSearcher):
)
string_labels = [
- [
-     self.label_map.get(int_label, f"unknown_{int_label}")
-     for int_label in batch_labels
- ]
+ [str(int_label) for int_label in batch_labels]
for batch_labels in labels
]

View File

@@ -114,25 +114,11 @@ def load_passages_from_metadata(meta_file: str) -> SimplePassageLoader:
finally:
sys.path.pop(0)
- # Load label map
- passages_dir = Path(meta_file).parent
- label_map_file = passages_dir / "leann.labels.map"
- if label_map_file.exists():
-     import pickle
-     with open(label_map_file, "rb") as f:
-         label_map = pickle.load(f)
-     print(f"Loaded label map with {len(label_map)} entries")
- else:
-     raise FileNotFoundError(f"Label map file not found: {label_map_file}")
- print(f"Initialized lazy passage loading for {len(label_map)} passages")
+ print(f"Initialized lazy passage loading for {len(passage_manager.global_offset_map)} passages")
class LazyPassageLoader(SimplePassageLoader):
- def __init__(self, passage_manager, label_map):
+ def __init__(self, passage_manager):
self.passage_manager = passage_manager
- self.label_map = label_map
# Initialize parent with empty data
super().__init__({})
@@ -140,28 +126,24 @@ def load_passages_from_metadata(meta_file: str) -> SimplePassageLoader:
"""Get passage by ID with lazy loading"""
try:
int_id = int(passage_id)
- if int_id in self.label_map:
-     string_id = self.label_map[int_id]
-     passage_data = self.passage_manager.get_passage(string_id)
-     if passage_data and passage_data.get("text"):
-         return {"text": passage_data["text"]}
-     else:
-         logger.debug(f"Empty text for ID {int_id} -> {string_id}")
-         return {"text": ""}
+ string_id = str(int_id)
+ passage_data = self.passage_manager.get_passage(string_id)
+ if passage_data and passage_data.get("text"):
+     return {"text": passage_data["text"]}
else:
logger.debug(f"ID {int_id} not found in label_map")
logger.debug(f"Empty text for ID {int_id} -> {string_id}")
return {"text": ""}
except Exception as e:
logger.debug(f"Exception getting passage {passage_id}: {e}")
return {"text": ""}
def __len__(self) -> int:
- return len(self.label_map)
+ return len(self.passage_manager.global_offset_map)
def keys(self):
- return self.label_map.keys()
+ return self.passage_manager.global_offset_map.keys()
- return LazyPassageLoader(passage_manager, label_map)
+ return LazyPassageLoader(passage_manager)
def create_hnsw_embedding_server(

View File

@@ -354,7 +354,7 @@ class LeannBuilder:
def add_text(self, text: str, metadata: Optional[Dict[str, Any]] = None):
if metadata is None:
metadata = {}
passage_id = metadata.get("id", str(uuid.uuid4()))
passage_id = metadata.get("id", str(len(self.chunks)))
chunk_data = {"id": passage_id, "text": text, "metadata": metadata}
self.chunks.append(chunk_data)
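Note (not part of the diff): this is the change the backends rely on. When the caller supplies no id, passages get sequential string IDs ("0", "1", ...), which is exactly what lets the searchers recover a passage ID with str(int_label). A standalone sketch of the assignment rule; chunks stands in for LeannBuilder.chunks:

chunks = []

def add_text(text, metadata=None):
    metadata = metadata or {}
    passage_id = metadata.get("id", str(len(chunks)))  # "0", "1", ... in insertion order
    chunks.append({"id": passage_id, "text": text, "metadata": metadata})

add_text("first chunk")
add_text("second chunk")
assert [c["id"] for c in chunks] == ["0", "1"]

This only lines up with the str(int_label) conversion in the searchers when the default is used, or when caller-supplied IDs are themselves those same integers as strings.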

View File

@@ -7,7 +7,7 @@ import sys
import zmq
import msgpack
from pathlib import Path
- from typing import Optional
+ from typing import Optional, Dict
import select
import psutil
@@ -156,7 +156,7 @@ class EmbeddingServerManager:
self.backend_module_name = backend_module_name
self.server_process: Optional[subprocess.Popen] = None
self.server_port: Optional[int] = None
- atexit.register(self.stop_server)
+ self._atexit_registered = False
def start_server(
self,
@@ -258,6 +258,12 @@ class EmbeddingServerManager:
)
self.server_port = port
print(f"INFO: Server process started with PID: {self.server_process.pid}")
+ # Register atexit callback only when we actually start a process
+ if not self._atexit_registered:
+     # Use a lambda to avoid issues with bound methods
+     atexit.register(lambda: self.stop_server() if self.server_process else None)
+     self._atexit_registered = True
def _wait_for_server_ready(self, port: int) -> tuple[bool, int]:
"""Wait for the server to be ready."""
@@ -309,17 +315,22 @@ class EmbeddingServerManager:
def stop_server(self):
"""Stops the embedding server process if it's running."""
- if self.server_process and self.server_process.poll() is None:
-     print(
-         f"INFO: Terminating session server process (PID: {self.server_process.pid})..."
-     )
-     self.server_process.terminate()
-     try:
-         self.server_process.wait(timeout=5)
-         print("INFO: Server process terminated.")
-     except subprocess.TimeoutExpired:
-         print(
-             "WARNING: Server process did not terminate gracefully, killing it."
-         )
-         self.server_process.kill()
+ if not self.server_process:
+     return
+ if self.server_process.poll() is not None:
+     # Process already terminated
+     self.server_process = None
+     return
+ print(f"INFO: Terminating server process (PID: {self.server_process.pid}) for backend {self.backend_module_name}...")
+ self.server_process.terminate()
+ try:
+     self.server_process.wait(timeout=5)
+     print(f"INFO: Server process {self.server_process.pid} terminated.")
+ except subprocess.TimeoutExpired:
+     print(f"WARNING: Server process {self.server_process.pid} did not terminate gracefully, killing it.")
+     self.server_process.kill()
+ self.server_process = None
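Note (not part of the diff): the rewritten stop_server is idempotent. A missing handle returns immediately, an already-dead process just clears the handle, and the handle is always reset at the end, so the atexit hook and repeated manual calls can both invoke it safely. Usage sketch, reusing the hypothetical ServerManager from the note above:

manager = ServerManager()
manager.start()
manager.stop()   # terminates the child and clears the handle
manager.stop()   # harmless no-op: the handle is already None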

View File

@@ -43,7 +43,6 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
"WARNING: embedding_model not found in meta.json. Recompute will fail."
)
- self.label_map = self._load_label_map()
self.embedding_server_manager = EmbeddingServerManager(
backend_module_name=backend_module_name
@@ -58,13 +57,6 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
with open(meta_path, "r", encoding="utf-8") as f:
return json.load(f)
- def _load_label_map(self) -> Dict[int, str]:
-     """Loads the mapping from integer IDs to string IDs."""
-     label_map_file = self.index_dir / "leann.labels.map"
-     if not label_map_file.exists():
-         raise FileNotFoundError(f"Label map file not found: {label_map_file}")
-     with open(label_map_file, "rb") as f:
-         return pickle.load(f)
def _ensure_server_running(
self, passages_source_file: str, port: int, **kwargs
@@ -110,12 +102,12 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
Query embedding as numpy array
"""
# Try to use embedding server if available and requested
- if (
-     use_server_if_available
-     and self.embedding_server_manager
-     and self.embedding_server_manager.server_process
- ):
+ if use_server_if_available:
try:
+ # Ensure we have a server with passages_file for compatibility
+ passages_source_file = self.index_dir / f"{self.index_path.name}.meta.json"
+ self._ensure_server_running(str(passages_source_file), zmq_port)
return self._compute_embedding_via_server([query], zmq_port)[
0:1
] # Return (1, D) shape
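Note (not part of the diff): this is the other half of the fix named in the commit title. Instead of only talking to a server whose process handle happens to still be set (and may be stale), compute_query_embedding now asks the manager to ensure a healthy server for this index before every server-side call. A rough, self-contained sketch of that control flow; the three helpers are stand-ins, and whether or how the real method falls back after a failure is not shown in this hunk:

import numpy as np

def ensure_server_running() -> None:
    pass  # stand-in for EmbeddingServerManager.start_server(...) on this index

def compute_via_server(queries):
    return np.zeros((len(queries), 4), dtype=np.float32)  # stand-in for the ZMQ call

def compute_locally(queries):
    return np.ones((len(queries), 4), dtype=np.float32)   # purely illustrative fallback

def compute_query_embedding(query: str, use_server_if_available: bool = True) -> np.ndarray:
    if use_server_if_available:
        try:
            ensure_server_running()                  # never trust an old handle
            return compute_via_server([query])[0:1]  # shape (1, D)
        except Exception as exc:
            print(f"WARNING: embedding server unavailable ({exc}); falling back")
    return compute_locally([query])[0:1]

print(compute_query_embedding("what is leann?").shape)   # (1, 4)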