fix: do not reuse emb_server and close it properly

This commit is contained in:
Andy Lee
2025-07-20 18:07:51 -07:00
parent f4998bb316
commit 7e226a51c9
7 changed files with 58 additions and 130 deletions

View File

@@ -70,10 +70,6 @@ class DiskannBuilder(LeannBackendBuilderInterface):
data_filename = f"{index_prefix}_data.bin"
_write_vectors_to_bin(data, index_dir / data_filename)
label_map = {i: str_id for i, str_id in enumerate(ids)}
label_map_file = index_dir / "leann.labels.map"
with open(label_map_file, "wb") as f:
pickle.dump(label_map, f)
build_kwargs = {**self.build_params, **kwargs}
metric_enum = _get_diskann_metrics().get(
@@ -211,10 +207,7 @@ class DiskannSearcher(BaseSearcher):
)
string_labels = [
[
self.label_map.get(int_label, f"unknown_{int_label}")
for int_label in batch_labels
]
[str(int_label) for int_label in batch_labels]
for batch_labels in labels
]

View File

@@ -76,24 +76,11 @@ def load_passages_from_metadata(meta_file: str) -> SimplePassageLoader:
finally:
sys.path.pop(0)
# Load label map
passages_dir = Path(meta_file).parent
label_map_file = passages_dir / "leann.labels.map"
if label_map_file.exists():
import pickle
with open(label_map_file, 'rb') as f:
label_map = pickle.load(f)
print(f"Loaded label map with {len(label_map)} entries")
else:
raise FileNotFoundError(f"Label map file not found: {label_map_file}")
print(f"Initialized lazy passage loading for {len(label_map)} passages")
print(f"Initialized lazy passage loading for {len(passage_manager.global_offset_map)} passages")
class LazyPassageLoader(SimplePassageLoader):
def __init__(self, passage_manager, label_map):
def __init__(self, passage_manager):
self.passage_manager = passage_manager
self.label_map = label_map
# Initialize parent with empty data
super().__init__({})
@@ -101,25 +88,22 @@ def load_passages_from_metadata(meta_file: str) -> SimplePassageLoader:
"""Get passage by ID with lazy loading"""
try:
int_id = int(passage_id)
if int_id in self.label_map:
string_id = self.label_map[int_id]
passage_data = self.passage_manager.get_passage(string_id)
if passage_data and passage_data.get("text"):
return {"text": passage_data["text"]}
else:
raise RuntimeError(f"FATAL: Empty text for ID {int_id} -> {string_id}")
string_id = str(int_id)
passage_data = self.passage_manager.get_passage(string_id)
if passage_data and passage_data.get("text"):
return {"text": passage_data["text"]}
else:
raise RuntimeError(f"FATAL: ID {int_id} not found in label_map")
raise RuntimeError(f"FATAL: Empty text for ID {int_id} -> {string_id}")
except Exception as e:
raise RuntimeError(f"FATAL: Exception getting passage {passage_id}: {e}")
def __len__(self) -> int:
return len(self.label_map)
return len(self.passage_manager.global_offset_map)
def keys(self):
return self.label_map.keys()
return self.passage_manager.global_offset_map.keys()
loader = LazyPassageLoader(passage_manager, label_map)
loader = LazyPassageLoader(passage_manager)
loader._meta_path = meta_file
return loader
@@ -135,35 +119,15 @@ def load_passages_from_file(passages_file: str) -> SimplePassageLoader:
if not passages_file.endswith('.jsonl'):
raise ValueError(f"Expected .jsonl file format, got: {passages_file}")
# Load label map (int -> string_id)
passages_dir = Path(passages_file).parent
label_map_file = passages_dir / "leann.labels.map"
label_map = {}
if label_map_file.exists():
with open(label_map_file, 'rb') as f:
label_map = pickle.load(f)
print(f"Loaded label map with {len(label_map)} entries")
else:
raise FileNotFoundError(f"Label map file not found: {label_map_file}")
# Load passages by string ID
string_id_passages = {}
# Load passages directly by their sequential IDs
passages_data = {}
with open(passages_file, 'r', encoding='utf-8') as f:
for line in f:
if line.strip():
passage = json.loads(line)
string_id_passages[passage['id']] = passage['text']
passages_data[passage['id']] = passage['text']
# Create int ID -> text mapping using label map
passages_data = {}
for int_id, string_id in label_map.items():
if string_id in string_id_passages:
passages_data[str(int_id)] = string_id_passages[string_id]
else:
print(f"WARNING: String ID {string_id} from label map not found in passages")
print(f"Loaded {len(passages_data)} passages from JSONL file {passages_file} using label map")
print(f"Loaded {len(passages_data)} passages from JSONL file {passages_file}")
return SimplePassageLoader(passages_data)
def create_embedding_server_thread(