feat: laion, also required idmaps support

This commit is contained in:
Andy Lee
2025-08-22 13:32:33 -07:00
parent 069bce558b
commit a7c7e8801d
7 changed files with 1457 additions and 5 deletions

View File

@@ -89,6 +89,15 @@ class HNSWBuilder(LeannBackendBuilderInterface):
index_file = index_dir / f"{index_prefix}.index"
faiss.write_index(index, str(index_file))
# Persist ID map so searcher can map FAISS integer labels back to passage IDs
try:
idmap_file = index_dir / f"{index_prefix}.ids.txt"
with open(idmap_file, "w", encoding="utf-8") as f:
for id_str in ids:
f.write(str(id_str) + "\n")
except Exception as e:
logger.warning(f"Failed to write ID map: {e}")
if self.is_compact:
self._convert_to_csr(index_file)
@@ -149,6 +158,16 @@ class HNSWSearcher(BaseSearcher):
self._index = faiss.read_index(str(index_file), faiss.IO_FLAG_MMAP, hnsw_config)
# Load ID map if available
self._id_map: list[str] = []
try:
idmap_file = self.index_dir / f"{self.index_path.stem}.ids.txt"
if idmap_file.exists():
with open(idmap_file, encoding="utf-8") as f:
self._id_map = [line.rstrip("\n") for line in f]
except Exception as e:
logger.warning(f"Failed to load ID map: {e}")
def search(
self,
query: np.ndarray,
@@ -244,7 +263,17 @@ class HNSWSearcher(BaseSearcher):
faiss.swig_ptr(labels),
params,
)
if self._id_map:
string_labels = [[str(int_label) for int_label in batch_labels] for batch_labels in labels]
def map_label(x: int) -> str:
if 0 <= x < len(self._id_map):
return self._id_map[x]
return str(x)
string_labels = [[map_label(int(l)) for l in batch_labels] for batch_labels in labels]
else:
string_labels = [
[str(int_label) for int_label in batch_labels] for batch_labels in labels
]
return {"labels": string_labels, "distances": distances}

View File

@@ -94,6 +94,35 @@ def create_hnsw_embedding_server(
f"Loaded PassageManager with {len(passages.global_offset_map)} passages from metadata"
)
# Attempt to load ID map (maps FAISS integer labels -> passage IDs)
id_map: list[str] = []
try:
meta_path = Path(passages_file)
base = meta_path.name
if base.endswith(".meta.json"):
base = base[: -len(".meta.json")] # e.g., laion_index.leann
if base.endswith(".leann"):
base = base[: -len(".leann")] # e.g., laion_index
idmap_file = meta_path.parent / f"{base}.ids.txt"
if idmap_file.exists():
with open(idmap_file, encoding="utf-8") as f:
id_map = [line.rstrip("\n") for line in f]
logger.info(f"Loaded ID map with {len(id_map)} entries from {idmap_file}")
else:
logger.warning(f"ID map file not found at {idmap_file}; will use raw labels")
except Exception as e:
logger.warning(f"Failed to load ID map: {e}")
def _map_node_id(nid) -> str:
try:
if id_map is not None and len(id_map) > 0 and isinstance(nid, (int, np.integer)):
idx = int(nid)
if 0 <= idx < len(id_map):
return id_map[idx]
except Exception:
pass
return str(nid)
# (legacy ZMQ thread removed; using shutdown-capable server only)
def zmq_server_thread_with_shutdown(shutdown_event):
@@ -170,13 +199,14 @@ def create_hnsw_embedding_server(
found_indices: list[int] = []
for idx, nid in enumerate(node_ids):
try:
passage_data = passages.get_passage(str(nid))
passage_id = _map_node_id(nid)
passage_data = passages.get_passage(passage_id)
txt = passage_data.get("text", "")
if isinstance(txt, str) and len(txt) > 0:
texts.append(txt)
found_indices.append(idx)
else:
logger.error(f"Empty text for passage ID {nid}")
logger.error(f"Empty text for passage ID {passage_id}")
except KeyError:
logger.error(f"Passage ID {nid} not found")
except Exception as e:
@@ -240,13 +270,14 @@ def create_hnsw_embedding_server(
found_indices: list[int] = []
for idx, nid in enumerate(node_ids):
try:
passage_data = passages.get_passage(str(nid))
passage_id = _map_node_id(nid)
passage_data = passages.get_passage(passage_id)
txt = passage_data.get("text", "")
if isinstance(txt, str) and len(txt) > 0:
texts.append(txt)
found_indices.append(idx)
else:
logger.error(f"Empty text for passage ID {nid}")
logger.error(f"Empty text for passage ID {passage_id}")
except KeyError:
logger.error(f"Passage with ID {nid} not found")
except Exception as e:

View File

@@ -372,6 +372,14 @@ class LeannBuilder:
is_build=True,
)
string_ids = [chunk["id"] for chunk in self.chunks]
# Persist ID map alongside index so backends that return integer labels can remap to passage IDs
try:
idmap_file = index_dir / f"{index_name[: -len('.leann')] if index_name.endswith('.leann') else index_name}.ids.txt"
with open(idmap_file, "w", encoding="utf-8") as f:
for sid in string_ids:
f.write(str(sid) + "\n")
except Exception:
pass
current_backend_kwargs = {**self.backend_kwargs, "dimensions": self.dimensions}
builder_instance = self.backend_factory.builder(**current_backend_kwargs)
builder_instance.build(embeddings, string_ids, index_path, **current_backend_kwargs)
@@ -490,6 +498,14 @@ class LeannBuilder:
# Build the vector index using precomputed embeddings
string_ids = [str(id_val) for id_val in ids]
# Persist ID map (order == embeddings order)
try:
idmap_file = index_dir / f"{index_name[: -len('.leann')] if index_name.endswith('.leann') else index_name}.ids.txt"
with open(idmap_file, "w", encoding="utf-8") as f:
for sid in string_ids:
f.write(str(sid) + "\n")
except Exception:
pass
current_backend_kwargs = {**self.backend_kwargs, "dimensions": self.dimensions}
builder_instance = self.backend_factory.builder(**current_backend_kwargs)
builder_instance.build(embeddings, string_ids, index_path)