fix readme

This commit is contained in:
yichuan-w
2025-10-08 21:38:55 +00:00
parent 3ec5e8d035
commit 5be0c144ad
72 changed files with 16608 additions and 4175 deletions

View File

@@ -5,6 +5,8 @@ import os
import struct
import sys
import time
from dataclasses import dataclass
from typing import Any, Optional
import numpy as np
@@ -237,6 +239,288 @@ def write_compact_format(
f_out.write(storage_data)
@dataclass
class HNSWComponents:
original_hnsw_data: dict[str, Any]
assign_probas_np: np.ndarray
cum_nneighbor_per_level_np: np.ndarray
levels_np: np.ndarray
is_compact: bool
compact_level_ptr: Optional[np.ndarray] = None
compact_node_offsets_np: Optional[np.ndarray] = None
compact_neighbors_data: Optional[list[int]] = None
offsets_np: Optional[np.ndarray] = None
neighbors_np: Optional[np.ndarray] = None
storage_fourcc: int = NULL_INDEX_FOURCC
storage_data: bytes = b""
def _read_hnsw_structure(f) -> HNSWComponents:
original_hnsw_data: dict[str, Any] = {}
hnsw_index_fourcc = read_struct(f, "<I")
if hnsw_index_fourcc not in EXPECTED_HNSW_FOURCCS:
raise ValueError(
f"Unexpected HNSW FourCC: {hnsw_index_fourcc:08x}. Expected one of {EXPECTED_HNSW_FOURCCS}."
)
original_hnsw_data["index_fourcc"] = hnsw_index_fourcc
original_hnsw_data["d"] = read_struct(f, "<i")
original_hnsw_data["ntotal"] = read_struct(f, "<q")
original_hnsw_data["dummy1"] = read_struct(f, "<q")
original_hnsw_data["dummy2"] = read_struct(f, "<q")
original_hnsw_data["is_trained"] = read_struct(f, "?")
original_hnsw_data["metric_type"] = read_struct(f, "<i")
original_hnsw_data["metric_arg"] = 0.0
if original_hnsw_data["metric_type"] > 1:
original_hnsw_data["metric_arg"] = read_struct(f, "<f")
assign_probas_np = read_numpy_vector(f, np.float64, "d")
cum_nneighbor_per_level_np = read_numpy_vector(f, np.int32, "i")
levels_np = read_numpy_vector(f, np.int32, "i")
ntotal = len(levels_np)
if ntotal != original_hnsw_data["ntotal"]:
original_hnsw_data["ntotal"] = ntotal
pos_before_compact = f.tell()
is_compact_flag = None
try:
is_compact_flag = read_struct(f, "<?")
except EOFError:
is_compact_flag = None
if is_compact_flag:
compact_level_ptr = read_numpy_vector(f, np.uint64, "Q")
compact_node_offsets_np = read_numpy_vector(f, np.uint64, "Q")
original_hnsw_data["entry_point"] = read_struct(f, "<i")
original_hnsw_data["max_level"] = read_struct(f, "<i")
original_hnsw_data["efConstruction"] = read_struct(f, "<i")
original_hnsw_data["efSearch"] = read_struct(f, "<i")
original_hnsw_data["dummy_upper_beam"] = read_struct(f, "<i")
storage_fourcc = read_struct(f, "<I")
compact_neighbors_data_np = read_numpy_vector(f, np.int32, "i")
compact_neighbors_data = compact_neighbors_data_np.tolist()
storage_data = f.read()
return HNSWComponents(
original_hnsw_data=original_hnsw_data,
assign_probas_np=assign_probas_np,
cum_nneighbor_per_level_np=cum_nneighbor_per_level_np,
levels_np=levels_np,
is_compact=True,
compact_level_ptr=compact_level_ptr,
compact_node_offsets_np=compact_node_offsets_np,
compact_neighbors_data=compact_neighbors_data,
storage_fourcc=storage_fourcc,
storage_data=storage_data,
)
# Non-compact case
f.seek(pos_before_compact)
pos_before_probe = f.tell()
try:
suspected_flag = read_struct(f, "<B")
if suspected_flag != 0x00:
f.seek(pos_before_probe)
except EOFError:
f.seek(pos_before_probe)
offsets_np = read_numpy_vector(f, np.uint64, "Q")
neighbors_np = read_numpy_vector(f, np.int32, "i")
original_hnsw_data["entry_point"] = read_struct(f, "<i")
original_hnsw_data["max_level"] = read_struct(f, "<i")
original_hnsw_data["efConstruction"] = read_struct(f, "<i")
original_hnsw_data["efSearch"] = read_struct(f, "<i")
original_hnsw_data["dummy_upper_beam"] = read_struct(f, "<i")
storage_fourcc = NULL_INDEX_FOURCC
storage_data = b""
try:
storage_fourcc = read_struct(f, "<I")
storage_data = f.read()
except EOFError:
storage_fourcc = NULL_INDEX_FOURCC
return HNSWComponents(
original_hnsw_data=original_hnsw_data,
assign_probas_np=assign_probas_np,
cum_nneighbor_per_level_np=cum_nneighbor_per_level_np,
levels_np=levels_np,
is_compact=False,
offsets_np=offsets_np,
neighbors_np=neighbors_np,
storage_fourcc=storage_fourcc,
storage_data=storage_data,
)
def _read_hnsw_structure_from_file(path: str) -> HNSWComponents:
with open(path, "rb") as f:
return _read_hnsw_structure(f)
def write_original_format(
f_out,
original_hnsw_data,
assign_probas_np,
cum_nneighbor_per_level_np,
levels_np,
offsets_np,
neighbors_np,
storage_fourcc,
storage_data,
):
"""Write non-compact HNSW data in original FAISS order."""
f_out.write(struct.pack("<I", original_hnsw_data["index_fourcc"]))
f_out.write(struct.pack("<i", original_hnsw_data["d"]))
f_out.write(struct.pack("<q", original_hnsw_data["ntotal"]))
f_out.write(struct.pack("<q", original_hnsw_data["dummy1"]))
f_out.write(struct.pack("<q", original_hnsw_data["dummy2"]))
f_out.write(struct.pack("<?", original_hnsw_data["is_trained"]))
f_out.write(struct.pack("<i", original_hnsw_data["metric_type"]))
if original_hnsw_data["metric_type"] > 1:
f_out.write(struct.pack("<f", original_hnsw_data["metric_arg"]))
write_numpy_vector(f_out, assign_probas_np, "d")
write_numpy_vector(f_out, cum_nneighbor_per_level_np, "i")
write_numpy_vector(f_out, levels_np, "i")
write_numpy_vector(f_out, offsets_np, "Q")
write_numpy_vector(f_out, neighbors_np, "i")
f_out.write(struct.pack("<i", original_hnsw_data["entry_point"]))
f_out.write(struct.pack("<i", original_hnsw_data["max_level"]))
f_out.write(struct.pack("<i", original_hnsw_data["efConstruction"]))
f_out.write(struct.pack("<i", original_hnsw_data["efSearch"]))
f_out.write(struct.pack("<i", original_hnsw_data["dummy_upper_beam"]))
f_out.write(struct.pack("<I", storage_fourcc))
if storage_fourcc != NULL_INDEX_FOURCC and storage_data:
f_out.write(storage_data)
def prune_hnsw_embeddings(input_filename: str, output_filename: str) -> bool:
"""Rewrite an HNSW index while dropping the embedded storage section."""
start_time = time.time()
try:
with open(input_filename, "rb") as f_in, open(output_filename, "wb") as f_out:
original_hnsw_data: dict[str, Any] = {}
hnsw_index_fourcc = read_struct(f_in, "<I")
if hnsw_index_fourcc not in EXPECTED_HNSW_FOURCCS:
print(
f"Error: Expected HNSW Index FourCC ({list(EXPECTED_HNSW_FOURCCS)}), got {hnsw_index_fourcc:08x}.",
file=sys.stderr,
)
return False
original_hnsw_data["index_fourcc"] = hnsw_index_fourcc
original_hnsw_data["d"] = read_struct(f_in, "<i")
original_hnsw_data["ntotal"] = read_struct(f_in, "<q")
original_hnsw_data["dummy1"] = read_struct(f_in, "<q")
original_hnsw_data["dummy2"] = read_struct(f_in, "<q")
original_hnsw_data["is_trained"] = read_struct(f_in, "?")
original_hnsw_data["metric_type"] = read_struct(f_in, "<i")
original_hnsw_data["metric_arg"] = 0.0
if original_hnsw_data["metric_type"] > 1:
original_hnsw_data["metric_arg"] = read_struct(f_in, "<f")
assign_probas_np = read_numpy_vector(f_in, np.float64, "d")
cum_nneighbor_per_level_np = read_numpy_vector(f_in, np.int32, "i")
levels_np = read_numpy_vector(f_in, np.int32, "i")
ntotal = len(levels_np)
if ntotal != original_hnsw_data["ntotal"]:
original_hnsw_data["ntotal"] = ntotal
pos_before_compact = f_in.tell()
is_compact_flag = None
try:
is_compact_flag = read_struct(f_in, "<?")
except EOFError:
is_compact_flag = None
if is_compact_flag:
compact_level_ptr = read_numpy_vector(f_in, np.uint64, "Q")
compact_node_offsets_np = read_numpy_vector(f_in, np.uint64, "Q")
original_hnsw_data["entry_point"] = read_struct(f_in, "<i")
original_hnsw_data["max_level"] = read_struct(f_in, "<i")
original_hnsw_data["efConstruction"] = read_struct(f_in, "<i")
original_hnsw_data["efSearch"] = read_struct(f_in, "<i")
original_hnsw_data["dummy_upper_beam"] = read_struct(f_in, "<i")
_storage_fourcc = read_struct(f_in, "<I")
compact_neighbors_data_np = read_numpy_vector(f_in, np.int32, "i")
compact_neighbors_data = compact_neighbors_data_np.tolist()
_storage_data = f_in.read()
write_compact_format(
f_out,
original_hnsw_data,
assign_probas_np,
cum_nneighbor_per_level_np,
levels_np,
compact_level_ptr,
compact_node_offsets_np,
compact_neighbors_data,
NULL_INDEX_FOURCC,
b"",
)
else:
f_in.seek(pos_before_compact)
pos_before_probe = f_in.tell()
try:
suspected_flag = read_struct(f_in, "<B")
if suspected_flag != 0x00:
f_in.seek(pos_before_probe)
except EOFError:
f_in.seek(pos_before_probe)
offsets_np = read_numpy_vector(f_in, np.uint64, "Q")
neighbors_np = read_numpy_vector(f_in, np.int32, "i")
original_hnsw_data["entry_point"] = read_struct(f_in, "<i")
original_hnsw_data["max_level"] = read_struct(f_in, "<i")
original_hnsw_data["efConstruction"] = read_struct(f_in, "<i")
original_hnsw_data["efSearch"] = read_struct(f_in, "<i")
original_hnsw_data["dummy_upper_beam"] = read_struct(f_in, "<i")
_storage_fourcc = None
_storage_data = b""
try:
_storage_fourcc = read_struct(f_in, "<I")
_storage_data = f_in.read()
except EOFError:
_storage_fourcc = NULL_INDEX_FOURCC
write_original_format(
f_out,
original_hnsw_data,
assign_probas_np,
cum_nneighbor_per_level_np,
levels_np,
offsets_np,
neighbors_np,
NULL_INDEX_FOURCC,
b"",
)
print(f"[{time.time() - start_time:.2f}s] Pruned embeddings from {input_filename}")
return True
except Exception as exc:
print(f"Failed to prune embeddings: {exc}", file=sys.stderr)
return False
# --- Main Conversion Logic ---
@@ -700,6 +984,29 @@ def convert_hnsw_graph_to_csr(input_filename, output_filename, prune_embeddings=
pass
def prune_hnsw_embeddings_inplace(index_filename: str) -> bool:
"""Convenience wrapper to prune embeddings in-place."""
temp_path = f"{index_filename}.prune.tmp"
success = prune_hnsw_embeddings(index_filename, temp_path)
if success:
try:
os.replace(temp_path, index_filename)
except Exception as exc: # pragma: no cover - defensive
logger.error(f"Failed to replace original index with pruned version: {exc}")
try:
os.remove(temp_path)
except OSError:
pass
return False
else:
try:
os.remove(temp_path)
except OSError:
pass
return success
# --- Script Execution ---
if __name__ == "__main__":
parser = argparse.ArgumentParser(

View File

@@ -14,7 +14,7 @@ from leann.interface import (
from leann.registry import register_backend
from leann.searcher_base import BaseSearcher
from .convert_to_csr import convert_hnsw_graph_to_csr
from .convert_to_csr import convert_hnsw_graph_to_csr, prune_hnsw_embeddings_inplace
logger = logging.getLogger(__name__)
@@ -90,8 +90,19 @@ class HNSWBuilder(LeannBackendBuilderInterface):
index_file = index_dir / f"{index_prefix}.index"
faiss.write_index(index, str(index_file))
# Persist ID map so searcher can map FAISS integer labels back to passage IDs
try:
idmap_file = index_dir / f"{index_prefix}.ids.txt"
with open(idmap_file, "w", encoding="utf-8") as f:
for id_str in ids:
f.write(str(id_str) + "\n")
except Exception as e:
logger.warning(f"Failed to write ID map: {e}")
if self.is_compact:
self._convert_to_csr(index_file)
elif self.is_recompute:
prune_hnsw_embeddings_inplace(str(index_file))
def _convert_to_csr(self, index_file: Path):
"""Convert built index to CSR format"""
@@ -133,10 +144,10 @@ class HNSWSearcher(BaseSearcher):
if metric_enum is None:
raise ValueError(f"Unsupported distance_metric '{self.distance_metric}'.")
self.is_compact, self.is_pruned = (
self.meta.get("is_compact", True),
self.meta.get("is_pruned", True),
)
backend_meta_kwargs = self.meta.get("backend_kwargs", {})
self.is_compact = self.meta.get("is_compact", backend_meta_kwargs.get("is_compact", True))
default_pruned = backend_meta_kwargs.get("is_recompute", self.is_compact)
self.is_pruned = bool(self.meta.get("is_pruned", default_pruned))
index_file = self.index_dir / f"{self.index_path.stem}.index"
if not index_file.exists():
@@ -150,6 +161,16 @@ class HNSWSearcher(BaseSearcher):
self._index = faiss.read_index(str(index_file), faiss.IO_FLAG_MMAP, hnsw_config)
# Load ID map if available
self._id_map: list[str] = []
try:
idmap_file = self.index_dir / f"{self.index_path.stem}.ids.txt"
if idmap_file.exists():
with open(idmap_file, encoding="utf-8") as f:
self._id_map = [line.rstrip("\n") for line in f]
except Exception as e:
logger.warning(f"Failed to load ID map: {e}")
def search(
self,
query: np.ndarray,
@@ -248,6 +269,19 @@ class HNSWSearcher(BaseSearcher):
)
search_time = time.time() - search_time
logger.info(f" Search time in HNSWSearcher.search() backend: {search_time} seconds")
string_labels = [[str(int_label) for int_label in batch_labels] for batch_labels in labels]
if self._id_map:
def map_label(x: int) -> str:
if 0 <= x < len(self._id_map):
return self._id_map[x]
return str(x)
string_labels = [
[map_label(int(label)) for label in batch_labels] for batch_labels in labels
]
else:
string_labels = [
[str(int_label) for int_label in batch_labels] for batch_labels in labels
]
return {"labels": string_labels, "distances": distances}

View File

@@ -10,7 +10,7 @@ import sys
import threading
import time
from pathlib import Path
from typing import Optional
from typing import Any, Optional
import msgpack
import numpy as np
@@ -24,13 +24,35 @@ logger = logging.getLogger(__name__)
log_level = getattr(logging, LOG_LEVEL, logging.WARNING)
logger.setLevel(log_level)
# Ensure we have a handler if none exists
# Ensure we have handlers if none exist
if not logger.handlers:
handler = logging.StreamHandler()
stream_handler = logging.StreamHandler()
formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
handler.setFormatter(formatter)
logger.addHandler(handler)
logger.propagate = False
stream_handler.setFormatter(formatter)
logger.addHandler(stream_handler)
log_path = os.getenv("LEANN_HNSW_LOG_PATH")
if log_path:
try:
file_handler = logging.FileHandler(log_path, mode="a", encoding="utf-8")
file_formatter = logging.Formatter(
"%(asctime)s - %(levelname)s - [pid=%(process)d] %(message)s"
)
file_handler.setFormatter(file_formatter)
logger.addHandler(file_handler)
except Exception as exc: # pragma: no cover - best effort logging
logger.warning(f"Failed to attach file handler for log path {log_path}: {exc}")
logger.propagate = False
_RAW_PROVIDER_OPTIONS = os.getenv("LEANN_EMBEDDING_OPTIONS")
try:
PROVIDER_OPTIONS: dict[str, Any] = (
json.loads(_RAW_PROVIDER_OPTIONS) if _RAW_PROVIDER_OPTIONS else {}
)
except json.JSONDecodeError:
logger.warning("Failed to parse LEANN_EMBEDDING_OPTIONS; ignoring provider options")
PROVIDER_OPTIONS = {}
def create_hnsw_embedding_server(
@@ -92,6 +114,35 @@ def create_hnsw_embedding_server(
embedding_dim = 0
logger.info(f"Loaded PassageManager with {len(passages)} passages from metadata")
# Attempt to load ID map (maps FAISS integer labels -> passage IDs)
id_map: list[str] = []
try:
meta_path = Path(passages_file)
base = meta_path.name
if base.endswith(".meta.json"):
base = base[: -len(".meta.json")] # e.g., laion_index.leann
if base.endswith(".leann"):
base = base[: -len(".leann")] # e.g., laion_index
idmap_file = meta_path.parent / f"{base}.ids.txt"
if idmap_file.exists():
with open(idmap_file, encoding="utf-8") as f:
id_map = [line.rstrip("\n") for line in f]
logger.info(f"Loaded ID map with {len(id_map)} entries from {idmap_file}")
else:
logger.warning(f"ID map file not found at {idmap_file}; will use raw labels")
except Exception as e:
logger.warning(f"Failed to load ID map: {e}")
def _map_node_id(nid) -> str:
try:
if id_map is not None and len(id_map) > 0 and isinstance(nid, (int, np.integer)):
idx = int(nid)
if 0 <= idx < len(id_map):
return id_map[idx]
except Exception:
pass
return str(nid)
# (legacy ZMQ thread removed; using shutdown-capable server only)
def zmq_server_thread_with_shutdown(shutdown_event):
@@ -138,7 +189,12 @@ def create_hnsw_embedding_server(
):
last_request_type = "text"
last_request_length = len(request)
embeddings = compute_embeddings(request, model_name, mode=embedding_mode)
embeddings = compute_embeddings(
request,
model_name,
mode=embedding_mode,
provider_options=PROVIDER_OPTIONS,
)
rep_socket.send(msgpack.packb(embeddings.tolist()))
e2e_end = time.time()
logger.info(f"⏱️ Text embedding E2E time: {e2e_end - e2e_start:.6f}s")
@@ -168,13 +224,14 @@ def create_hnsw_embedding_server(
found_indices: list[int] = []
for idx, nid in enumerate(node_ids):
try:
passage_data = passages.get_passage(str(nid))
passage_id = _map_node_id(nid)
passage_data = passages.get_passage(passage_id)
txt = passage_data.get("text", "")
if isinstance(txt, str) and len(txt) > 0:
texts.append(txt)
found_indices.append(idx)
else:
logger.error(f"Empty text for passage ID {nid}")
logger.error(f"Empty text for passage ID {passage_id}")
except KeyError:
logger.error(f"Passage ID {nid} not found")
except Exception as e:
@@ -187,7 +244,10 @@ def create_hnsw_embedding_server(
if texts:
try:
embeddings = compute_embeddings(
texts, model_name, mode=embedding_mode
texts,
model_name,
mode=embedding_mode,
provider_options=PROVIDER_OPTIONS,
)
logger.info(
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
@@ -238,13 +298,14 @@ def create_hnsw_embedding_server(
found_indices: list[int] = []
for idx, nid in enumerate(node_ids):
try:
passage_data = passages.get_passage(str(nid))
passage_id = _map_node_id(nid)
passage_data = passages.get_passage(passage_id)
txt = passage_data.get("text", "")
if isinstance(txt, str) and len(txt) > 0:
texts.append(txt)
found_indices.append(idx)
else:
logger.error(f"Empty text for passage ID {nid}")
logger.error(f"Empty text for passage ID {passage_id}")
except KeyError:
logger.error(f"Passage with ID {nid} not found")
except Exception as e:
@@ -252,7 +313,12 @@ def create_hnsw_embedding_server(
if texts:
try:
embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
embeddings = compute_embeddings(
texts,
model_name,
mode=embedding_mode,
provider_options=PROVIDER_OPTIONS,
)
logger.info(
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
)