fix: do not reuse emb_server and close it properly
This commit is contained in:
@@ -70,10 +70,6 @@ class DiskannBuilder(LeannBackendBuilderInterface):
|
|||||||
data_filename = f"{index_prefix}_data.bin"
|
data_filename = f"{index_prefix}_data.bin"
|
||||||
_write_vectors_to_bin(data, index_dir / data_filename)
|
_write_vectors_to_bin(data, index_dir / data_filename)
|
||||||
|
|
||||||
label_map = {i: str_id for i, str_id in enumerate(ids)}
|
|
||||||
label_map_file = index_dir / "leann.labels.map"
|
|
||||||
with open(label_map_file, "wb") as f:
|
|
||||||
pickle.dump(label_map, f)
|
|
||||||
|
|
||||||
build_kwargs = {**self.build_params, **kwargs}
|
build_kwargs = {**self.build_params, **kwargs}
|
||||||
metric_enum = _get_diskann_metrics().get(
|
metric_enum = _get_diskann_metrics().get(
|
||||||
@@ -211,10 +207,7 @@ class DiskannSearcher(BaseSearcher):
|
|||||||
)
|
)
|
||||||
|
|
||||||
string_labels = [
|
string_labels = [
|
||||||
[
|
[str(int_label) for int_label in batch_labels]
|
||||||
self.label_map.get(int_label, f"unknown_{int_label}")
|
|
||||||
for int_label in batch_labels
|
|
||||||
]
|
|
||||||
for batch_labels in labels
|
for batch_labels in labels
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
@@ -76,24 +76,11 @@ def load_passages_from_metadata(meta_file: str) -> SimplePassageLoader:
|
|||||||
finally:
|
finally:
|
||||||
sys.path.pop(0)
|
sys.path.pop(0)
|
||||||
|
|
||||||
# Load label map
|
print(f"Initialized lazy passage loading for {len(passage_manager.global_offset_map)} passages")
|
||||||
passages_dir = Path(meta_file).parent
|
|
||||||
label_map_file = passages_dir / "leann.labels.map"
|
|
||||||
|
|
||||||
if label_map_file.exists():
|
|
||||||
import pickle
|
|
||||||
with open(label_map_file, 'rb') as f:
|
|
||||||
label_map = pickle.load(f)
|
|
||||||
print(f"Loaded label map with {len(label_map)} entries")
|
|
||||||
else:
|
|
||||||
raise FileNotFoundError(f"Label map file not found: {label_map_file}")
|
|
||||||
|
|
||||||
print(f"Initialized lazy passage loading for {len(label_map)} passages")
|
|
||||||
|
|
||||||
class LazyPassageLoader(SimplePassageLoader):
|
class LazyPassageLoader(SimplePassageLoader):
|
||||||
def __init__(self, passage_manager, label_map):
|
def __init__(self, passage_manager):
|
||||||
self.passage_manager = passage_manager
|
self.passage_manager = passage_manager
|
||||||
self.label_map = label_map
|
|
||||||
# Initialize parent with empty data
|
# Initialize parent with empty data
|
||||||
super().__init__({})
|
super().__init__({})
|
||||||
|
|
||||||
@@ -101,25 +88,22 @@ def load_passages_from_metadata(meta_file: str) -> SimplePassageLoader:
|
|||||||
"""Get passage by ID with lazy loading"""
|
"""Get passage by ID with lazy loading"""
|
||||||
try:
|
try:
|
||||||
int_id = int(passage_id)
|
int_id = int(passage_id)
|
||||||
if int_id in self.label_map:
|
string_id = str(int_id)
|
||||||
string_id = self.label_map[int_id]
|
passage_data = self.passage_manager.get_passage(string_id)
|
||||||
passage_data = self.passage_manager.get_passage(string_id)
|
if passage_data and passage_data.get("text"):
|
||||||
if passage_data and passage_data.get("text"):
|
return {"text": passage_data["text"]}
|
||||||
return {"text": passage_data["text"]}
|
|
||||||
else:
|
|
||||||
raise RuntimeError(f"FATAL: Empty text for ID {int_id} -> {string_id}")
|
|
||||||
else:
|
else:
|
||||||
raise RuntimeError(f"FATAL: ID {int_id} not found in label_map")
|
raise RuntimeError(f"FATAL: Empty text for ID {int_id} -> {string_id}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise RuntimeError(f"FATAL: Exception getting passage {passage_id}: {e}")
|
raise RuntimeError(f"FATAL: Exception getting passage {passage_id}: {e}")
|
||||||
|
|
||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
return len(self.label_map)
|
return len(self.passage_manager.global_offset_map)
|
||||||
|
|
||||||
def keys(self):
|
def keys(self):
|
||||||
return self.label_map.keys()
|
return self.passage_manager.global_offset_map.keys()
|
||||||
|
|
||||||
loader = LazyPassageLoader(passage_manager, label_map)
|
loader = LazyPassageLoader(passage_manager)
|
||||||
loader._meta_path = meta_file
|
loader._meta_path = meta_file
|
||||||
return loader
|
return loader
|
||||||
|
|
||||||
@@ -135,35 +119,15 @@ def load_passages_from_file(passages_file: str) -> SimplePassageLoader:
|
|||||||
if not passages_file.endswith('.jsonl'):
|
if not passages_file.endswith('.jsonl'):
|
||||||
raise ValueError(f"Expected .jsonl file format, got: {passages_file}")
|
raise ValueError(f"Expected .jsonl file format, got: {passages_file}")
|
||||||
|
|
||||||
# Load label map (int -> string_id)
|
# Load passages directly by their sequential IDs
|
||||||
passages_dir = Path(passages_file).parent
|
passages_data = {}
|
||||||
label_map_file = passages_dir / "leann.labels.map"
|
|
||||||
|
|
||||||
label_map = {}
|
|
||||||
if label_map_file.exists():
|
|
||||||
with open(label_map_file, 'rb') as f:
|
|
||||||
label_map = pickle.load(f)
|
|
||||||
print(f"Loaded label map with {len(label_map)} entries")
|
|
||||||
else:
|
|
||||||
raise FileNotFoundError(f"Label map file not found: {label_map_file}")
|
|
||||||
|
|
||||||
# Load passages by string ID
|
|
||||||
string_id_passages = {}
|
|
||||||
with open(passages_file, 'r', encoding='utf-8') as f:
|
with open(passages_file, 'r', encoding='utf-8') as f:
|
||||||
for line in f:
|
for line in f:
|
||||||
if line.strip():
|
if line.strip():
|
||||||
passage = json.loads(line)
|
passage = json.loads(line)
|
||||||
string_id_passages[passage['id']] = passage['text']
|
passages_data[passage['id']] = passage['text']
|
||||||
|
|
||||||
# Create int ID -> text mapping using label map
|
print(f"Loaded {len(passages_data)} passages from JSONL file {passages_file}")
|
||||||
passages_data = {}
|
|
||||||
for int_id, string_id in label_map.items():
|
|
||||||
if string_id in string_id_passages:
|
|
||||||
passages_data[str(int_id)] = string_id_passages[string_id]
|
|
||||||
else:
|
|
||||||
print(f"WARNING: String ID {string_id} from label map not found in passages")
|
|
||||||
|
|
||||||
print(f"Loaded {len(passages_data)} passages from JSONL file {passages_file} using label map")
|
|
||||||
return SimplePassageLoader(passages_data)
|
return SimplePassageLoader(passages_data)
|
||||||
|
|
||||||
def create_embedding_server_thread(
|
def create_embedding_server_thread(
|
||||||
|
|||||||
@@ -59,10 +59,6 @@ class HNSWBuilder(LeannBackendBuilderInterface):
|
|||||||
if data.dtype != np.float32:
|
if data.dtype != np.float32:
|
||||||
data = data.astype(np.float32)
|
data = data.astype(np.float32)
|
||||||
|
|
||||||
label_map = {i: str_id for i, str_id in enumerate(ids)}
|
|
||||||
label_map_file = index_dir / "leann.labels.map"
|
|
||||||
with open(label_map_file, "wb") as f:
|
|
||||||
pickle.dump(label_map, f)
|
|
||||||
|
|
||||||
metric_enum = get_metric_map().get(self.distance_metric.lower())
|
metric_enum = get_metric_map().get(self.distance_metric.lower())
|
||||||
if metric_enum is None:
|
if metric_enum is None:
|
||||||
@@ -142,13 +138,6 @@ class HNSWSearcher(BaseSearcher):
|
|||||||
|
|
||||||
self._index = faiss.read_index(str(index_file), faiss.IO_FLAG_MMAP, hnsw_config)
|
self._index = faiss.read_index(str(index_file), faiss.IO_FLAG_MMAP, hnsw_config)
|
||||||
|
|
||||||
# Load label mapping
|
|
||||||
label_map_file = self.index_dir / "leann.labels.map"
|
|
||||||
if not label_map_file.exists():
|
|
||||||
raise FileNotFoundError(f"Label map file not found at {label_map_file}")
|
|
||||||
|
|
||||||
with open(label_map_file, "rb") as f:
|
|
||||||
self.label_map = pickle.load(f)
|
|
||||||
|
|
||||||
def search(
|
def search(
|
||||||
self,
|
self,
|
||||||
@@ -239,10 +228,7 @@ class HNSWSearcher(BaseSearcher):
|
|||||||
)
|
)
|
||||||
|
|
||||||
string_labels = [
|
string_labels = [
|
||||||
[
|
[str(int_label) for int_label in batch_labels]
|
||||||
self.label_map.get(int_label, f"unknown_{int_label}")
|
|
||||||
for int_label in batch_labels
|
|
||||||
]
|
|
||||||
for batch_labels in labels
|
for batch_labels in labels
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
@@ -114,25 +114,11 @@ def load_passages_from_metadata(meta_file: str) -> SimplePassageLoader:
|
|||||||
finally:
|
finally:
|
||||||
sys.path.pop(0)
|
sys.path.pop(0)
|
||||||
|
|
||||||
# Load label map
|
print(f"Initialized lazy passage loading for {len(passage_manager.global_offset_map)} passages")
|
||||||
passages_dir = Path(meta_file).parent
|
|
||||||
label_map_file = passages_dir / "leann.labels.map"
|
|
||||||
|
|
||||||
if label_map_file.exists():
|
|
||||||
import pickle
|
|
||||||
|
|
||||||
with open(label_map_file, "rb") as f:
|
|
||||||
label_map = pickle.load(f)
|
|
||||||
print(f"Loaded label map with {len(label_map)} entries")
|
|
||||||
else:
|
|
||||||
raise FileNotFoundError(f"Label map file not found: {label_map_file}")
|
|
||||||
|
|
||||||
print(f"Initialized lazy passage loading for {len(label_map)} passages")
|
|
||||||
|
|
||||||
class LazyPassageLoader(SimplePassageLoader):
|
class LazyPassageLoader(SimplePassageLoader):
|
||||||
def __init__(self, passage_manager, label_map):
|
def __init__(self, passage_manager):
|
||||||
self.passage_manager = passage_manager
|
self.passage_manager = passage_manager
|
||||||
self.label_map = label_map
|
|
||||||
# Initialize parent with empty data
|
# Initialize parent with empty data
|
||||||
super().__init__({})
|
super().__init__({})
|
||||||
|
|
||||||
@@ -140,28 +126,24 @@ def load_passages_from_metadata(meta_file: str) -> SimplePassageLoader:
|
|||||||
"""Get passage by ID with lazy loading"""
|
"""Get passage by ID with lazy loading"""
|
||||||
try:
|
try:
|
||||||
int_id = int(passage_id)
|
int_id = int(passage_id)
|
||||||
if int_id in self.label_map:
|
string_id = str(int_id)
|
||||||
string_id = self.label_map[int_id]
|
passage_data = self.passage_manager.get_passage(string_id)
|
||||||
passage_data = self.passage_manager.get_passage(string_id)
|
if passage_data and passage_data.get("text"):
|
||||||
if passage_data and passage_data.get("text"):
|
return {"text": passage_data["text"]}
|
||||||
return {"text": passage_data["text"]}
|
|
||||||
else:
|
|
||||||
logger.debug(f"Empty text for ID {int_id} -> {string_id}")
|
|
||||||
return {"text": ""}
|
|
||||||
else:
|
else:
|
||||||
logger.debug(f"ID {int_id} not found in label_map")
|
logger.debug(f"Empty text for ID {int_id} -> {string_id}")
|
||||||
return {"text": ""}
|
return {"text": ""}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug(f"Exception getting passage {passage_id}: {e}")
|
logger.debug(f"Exception getting passage {passage_id}: {e}")
|
||||||
return {"text": ""}
|
return {"text": ""}
|
||||||
|
|
||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
return len(self.label_map)
|
return len(self.passage_manager.global_offset_map)
|
||||||
|
|
||||||
def keys(self):
|
def keys(self):
|
||||||
return self.label_map.keys()
|
return self.passage_manager.global_offset_map.keys()
|
||||||
|
|
||||||
return LazyPassageLoader(passage_manager, label_map)
|
return LazyPassageLoader(passage_manager)
|
||||||
|
|
||||||
|
|
||||||
def create_hnsw_embedding_server(
|
def create_hnsw_embedding_server(
|
||||||
|
|||||||
@@ -354,7 +354,7 @@ class LeannBuilder:
|
|||||||
def add_text(self, text: str, metadata: Optional[Dict[str, Any]] = None):
|
def add_text(self, text: str, metadata: Optional[Dict[str, Any]] = None):
|
||||||
if metadata is None:
|
if metadata is None:
|
||||||
metadata = {}
|
metadata = {}
|
||||||
passage_id = metadata.get("id", str(uuid.uuid4()))
|
passage_id = metadata.get("id", str(len(self.chunks)))
|
||||||
chunk_data = {"id": passage_id, "text": text, "metadata": metadata}
|
chunk_data = {"id": passage_id, "text": text, "metadata": metadata}
|
||||||
self.chunks.append(chunk_data)
|
self.chunks.append(chunk_data)
|
||||||
|
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ import sys
|
|||||||
import zmq
|
import zmq
|
||||||
import msgpack
|
import msgpack
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional, Dict
|
||||||
import select
|
import select
|
||||||
import psutil
|
import psutil
|
||||||
|
|
||||||
@@ -156,7 +156,7 @@ class EmbeddingServerManager:
|
|||||||
self.backend_module_name = backend_module_name
|
self.backend_module_name = backend_module_name
|
||||||
self.server_process: Optional[subprocess.Popen] = None
|
self.server_process: Optional[subprocess.Popen] = None
|
||||||
self.server_port: Optional[int] = None
|
self.server_port: Optional[int] = None
|
||||||
atexit.register(self.stop_server)
|
self._atexit_registered = False
|
||||||
|
|
||||||
def start_server(
|
def start_server(
|
||||||
self,
|
self,
|
||||||
@@ -259,6 +259,12 @@ class EmbeddingServerManager:
|
|||||||
self.server_port = port
|
self.server_port = port
|
||||||
print(f"INFO: Server process started with PID: {self.server_process.pid}")
|
print(f"INFO: Server process started with PID: {self.server_process.pid}")
|
||||||
|
|
||||||
|
# Register atexit callback only when we actually start a process
|
||||||
|
if not self._atexit_registered:
|
||||||
|
# Use a lambda to avoid issues with bound methods
|
||||||
|
atexit.register(lambda: self.stop_server() if self.server_process else None)
|
||||||
|
self._atexit_registered = True
|
||||||
|
|
||||||
def _wait_for_server_ready(self, port: int) -> tuple[bool, int]:
|
def _wait_for_server_ready(self, port: int) -> tuple[bool, int]:
|
||||||
"""Wait for the server to be ready."""
|
"""Wait for the server to be ready."""
|
||||||
max_wait, wait_interval = 120, 0.5
|
max_wait, wait_interval = 120, 0.5
|
||||||
@@ -309,17 +315,22 @@ class EmbeddingServerManager:
|
|||||||
|
|
||||||
def stop_server(self):
|
def stop_server(self):
|
||||||
"""Stops the embedding server process if it's running."""
|
"""Stops the embedding server process if it's running."""
|
||||||
if self.server_process and self.server_process.poll() is None:
|
if not self.server_process:
|
||||||
print(
|
return
|
||||||
f"INFO: Terminating session server process (PID: {self.server_process.pid})..."
|
|
||||||
)
|
if self.server_process.poll() is not None:
|
||||||
self.server_process.terminate()
|
# Process already terminated
|
||||||
try:
|
self.server_process = None
|
||||||
self.server_process.wait(timeout=5)
|
return
|
||||||
print("INFO: Server process terminated.")
|
|
||||||
except subprocess.TimeoutExpired:
|
print(f"INFO: Terminating server process (PID: {self.server_process.pid}) for backend {self.backend_module_name}...")
|
||||||
print(
|
self.server_process.terminate()
|
||||||
"WARNING: Server process did not terminate gracefully, killing it."
|
|
||||||
)
|
try:
|
||||||
self.server_process.kill()
|
self.server_process.wait(timeout=5)
|
||||||
|
print(f"INFO: Server process {self.server_process.pid} terminated.")
|
||||||
|
except subprocess.TimeoutExpired:
|
||||||
|
print(f"WARNING: Server process {self.server_process.pid} did not terminate gracefully, killing it.")
|
||||||
|
self.server_process.kill()
|
||||||
|
|
||||||
self.server_process = None
|
self.server_process = None
|
||||||
|
|||||||
@@ -43,7 +43,6 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
|
|||||||
"WARNING: embedding_model not found in meta.json. Recompute will fail."
|
"WARNING: embedding_model not found in meta.json. Recompute will fail."
|
||||||
)
|
)
|
||||||
|
|
||||||
self.label_map = self._load_label_map()
|
|
||||||
|
|
||||||
self.embedding_server_manager = EmbeddingServerManager(
|
self.embedding_server_manager = EmbeddingServerManager(
|
||||||
backend_module_name=backend_module_name
|
backend_module_name=backend_module_name
|
||||||
@@ -58,13 +57,6 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
|
|||||||
with open(meta_path, "r", encoding="utf-8") as f:
|
with open(meta_path, "r", encoding="utf-8") as f:
|
||||||
return json.load(f)
|
return json.load(f)
|
||||||
|
|
||||||
def _load_label_map(self) -> Dict[int, str]:
|
|
||||||
"""Loads the mapping from integer IDs to string IDs."""
|
|
||||||
label_map_file = self.index_dir / "leann.labels.map"
|
|
||||||
if not label_map_file.exists():
|
|
||||||
raise FileNotFoundError(f"Label map file not found: {label_map_file}")
|
|
||||||
with open(label_map_file, "rb") as f:
|
|
||||||
return pickle.load(f)
|
|
||||||
|
|
||||||
def _ensure_server_running(
|
def _ensure_server_running(
|
||||||
self, passages_source_file: str, port: int, **kwargs
|
self, passages_source_file: str, port: int, **kwargs
|
||||||
@@ -110,12 +102,12 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
|
|||||||
Query embedding as numpy array
|
Query embedding as numpy array
|
||||||
"""
|
"""
|
||||||
# Try to use embedding server if available and requested
|
# Try to use embedding server if available and requested
|
||||||
if (
|
if use_server_if_available:
|
||||||
use_server_if_available
|
|
||||||
and self.embedding_server_manager
|
|
||||||
and self.embedding_server_manager.server_process
|
|
||||||
):
|
|
||||||
try:
|
try:
|
||||||
|
# Ensure we have a server with passages_file for compatibility
|
||||||
|
passages_source_file = self.index_dir / f"{self.index_path.name}.meta.json"
|
||||||
|
self._ensure_server_running(str(passages_source_file), zmq_port)
|
||||||
|
|
||||||
return self._compute_embedding_via_server([query], zmq_port)[
|
return self._compute_embedding_via_server([query], zmq_port)[
|
||||||
0:1
|
0:1
|
||||||
] # Return (1, D) shape
|
] # Return (1, D) shape
|
||||||
|
|||||||
Reference in New Issue
Block a user