feat: laion, also required idmaps support
This commit is contained in:
@@ -89,6 +89,15 @@ class HNSWBuilder(LeannBackendBuilderInterface):
|
||||
index_file = index_dir / f"{index_prefix}.index"
|
||||
faiss.write_index(index, str(index_file))
|
||||
|
||||
# Persist ID map so searcher can map FAISS integer labels back to passage IDs
|
||||
try:
|
||||
idmap_file = index_dir / f"{index_prefix}.ids.txt"
|
||||
with open(idmap_file, "w", encoding="utf-8") as f:
|
||||
for id_str in ids:
|
||||
f.write(str(id_str) + "\n")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to write ID map: {e}")
|
||||
|
||||
if self.is_compact:
|
||||
self._convert_to_csr(index_file)
|
||||
|
||||
@@ -149,6 +158,16 @@ class HNSWSearcher(BaseSearcher):
|
||||
|
||||
self._index = faiss.read_index(str(index_file), faiss.IO_FLAG_MMAP, hnsw_config)
|
||||
|
||||
# Load ID map if available
|
||||
self._id_map: list[str] = []
|
||||
try:
|
||||
idmap_file = self.index_dir / f"{self.index_path.stem}.ids.txt"
|
||||
if idmap_file.exists():
|
||||
with open(idmap_file, encoding="utf-8") as f:
|
||||
self._id_map = [line.rstrip("\n") for line in f]
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load ID map: {e}")
|
||||
|
||||
def search(
|
||||
self,
|
||||
query: np.ndarray,
|
||||
@@ -244,7 +263,17 @@ class HNSWSearcher(BaseSearcher):
|
||||
faiss.swig_ptr(labels),
|
||||
params,
|
||||
)
|
||||
if self._id_map:
|
||||
|
||||
string_labels = [[str(int_label) for int_label in batch_labels] for batch_labels in labels]
|
||||
def map_label(x: int) -> str:
|
||||
if 0 <= x < len(self._id_map):
|
||||
return self._id_map[x]
|
||||
return str(x)
|
||||
|
||||
string_labels = [[map_label(int(l)) for l in batch_labels] for batch_labels in labels]
|
||||
else:
|
||||
string_labels = [
|
||||
[str(int_label) for int_label in batch_labels] for batch_labels in labels
|
||||
]
|
||||
|
||||
return {"labels": string_labels, "distances": distances}
|
||||
|
||||
@@ -94,6 +94,35 @@ def create_hnsw_embedding_server(
|
||||
f"Loaded PassageManager with {len(passages.global_offset_map)} passages from metadata"
|
||||
)
|
||||
|
||||
# Attempt to load ID map (maps FAISS integer labels -> passage IDs)
|
||||
id_map: list[str] = []
|
||||
try:
|
||||
meta_path = Path(passages_file)
|
||||
base = meta_path.name
|
||||
if base.endswith(".meta.json"):
|
||||
base = base[: -len(".meta.json")] # e.g., laion_index.leann
|
||||
if base.endswith(".leann"):
|
||||
base = base[: -len(".leann")] # e.g., laion_index
|
||||
idmap_file = meta_path.parent / f"{base}.ids.txt"
|
||||
if idmap_file.exists():
|
||||
with open(idmap_file, encoding="utf-8") as f:
|
||||
id_map = [line.rstrip("\n") for line in f]
|
||||
logger.info(f"Loaded ID map with {len(id_map)} entries from {idmap_file}")
|
||||
else:
|
||||
logger.warning(f"ID map file not found at {idmap_file}; will use raw labels")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load ID map: {e}")
|
||||
|
||||
def _map_node_id(nid) -> str:
|
||||
try:
|
||||
if id_map is not None and len(id_map) > 0 and isinstance(nid, (int, np.integer)):
|
||||
idx = int(nid)
|
||||
if 0 <= idx < len(id_map):
|
||||
return id_map[idx]
|
||||
except Exception:
|
||||
pass
|
||||
return str(nid)
|
||||
|
||||
# (legacy ZMQ thread removed; using shutdown-capable server only)
|
||||
|
||||
def zmq_server_thread_with_shutdown(shutdown_event):
|
||||
@@ -170,13 +199,14 @@ def create_hnsw_embedding_server(
|
||||
found_indices: list[int] = []
|
||||
for idx, nid in enumerate(node_ids):
|
||||
try:
|
||||
passage_data = passages.get_passage(str(nid))
|
||||
passage_id = _map_node_id(nid)
|
||||
passage_data = passages.get_passage(passage_id)
|
||||
txt = passage_data.get("text", "")
|
||||
if isinstance(txt, str) and len(txt) > 0:
|
||||
texts.append(txt)
|
||||
found_indices.append(idx)
|
||||
else:
|
||||
logger.error(f"Empty text for passage ID {nid}")
|
||||
logger.error(f"Empty text for passage ID {passage_id}")
|
||||
except KeyError:
|
||||
logger.error(f"Passage ID {nid} not found")
|
||||
except Exception as e:
|
||||
@@ -240,13 +270,14 @@ def create_hnsw_embedding_server(
|
||||
found_indices: list[int] = []
|
||||
for idx, nid in enumerate(node_ids):
|
||||
try:
|
||||
passage_data = passages.get_passage(str(nid))
|
||||
passage_id = _map_node_id(nid)
|
||||
passage_data = passages.get_passage(passage_id)
|
||||
txt = passage_data.get("text", "")
|
||||
if isinstance(txt, str) and len(txt) > 0:
|
||||
texts.append(txt)
|
||||
found_indices.append(idx)
|
||||
else:
|
||||
logger.error(f"Empty text for passage ID {nid}")
|
||||
logger.error(f"Empty text for passage ID {passage_id}")
|
||||
except KeyError:
|
||||
logger.error(f"Passage with ID {nid} not found")
|
||||
except Exception as e:
|
||||
|
||||
@@ -372,6 +372,14 @@ class LeannBuilder:
|
||||
is_build=True,
|
||||
)
|
||||
string_ids = [chunk["id"] for chunk in self.chunks]
|
||||
# Persist ID map alongside index so backends that return integer labels can remap to passage IDs
|
||||
try:
|
||||
idmap_file = index_dir / f"{index_name[: -len('.leann')] if index_name.endswith('.leann') else index_name}.ids.txt"
|
||||
with open(idmap_file, "w", encoding="utf-8") as f:
|
||||
for sid in string_ids:
|
||||
f.write(str(sid) + "\n")
|
||||
except Exception:
|
||||
pass
|
||||
current_backend_kwargs = {**self.backend_kwargs, "dimensions": self.dimensions}
|
||||
builder_instance = self.backend_factory.builder(**current_backend_kwargs)
|
||||
builder_instance.build(embeddings, string_ids, index_path, **current_backend_kwargs)
|
||||
@@ -490,6 +498,14 @@ class LeannBuilder:
|
||||
|
||||
# Build the vector index using precomputed embeddings
|
||||
string_ids = [str(id_val) for id_val in ids]
|
||||
# Persist ID map (order == embeddings order)
|
||||
try:
|
||||
idmap_file = index_dir / f"{index_name[: -len('.leann')] if index_name.endswith('.leann') else index_name}.ids.txt"
|
||||
with open(idmap_file, "w", encoding="utf-8") as f:
|
||||
for sid in string_ids:
|
||||
f.write(str(sid) + "\n")
|
||||
except Exception:
|
||||
pass
|
||||
current_backend_kwargs = {**self.backend_kwargs, "dimensions": self.dimensions}
|
||||
builder_instance = self.backend_factory.builder(**current_backend_kwargs)
|
||||
builder_instance.build(embeddings, string_ids, index_path)
|
||||
|
||||
Reference in New Issue
Block a user