fix readme
This commit is contained in:
@@ -14,7 +14,7 @@ from leann.interface import (
|
||||
from leann.registry import register_backend
|
||||
from leann.searcher_base import BaseSearcher
|
||||
|
||||
from .convert_to_csr import convert_hnsw_graph_to_csr
|
||||
from .convert_to_csr import convert_hnsw_graph_to_csr, prune_hnsw_embeddings_inplace
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -90,8 +90,19 @@ class HNSWBuilder(LeannBackendBuilderInterface):
|
||||
index_file = index_dir / f"{index_prefix}.index"
|
||||
faiss.write_index(index, str(index_file))
|
||||
|
||||
# Persist ID map so searcher can map FAISS integer labels back to passage IDs
|
||||
try:
|
||||
idmap_file = index_dir / f"{index_prefix}.ids.txt"
|
||||
with open(idmap_file, "w", encoding="utf-8") as f:
|
||||
for id_str in ids:
|
||||
f.write(str(id_str) + "\n")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to write ID map: {e}")
|
||||
|
||||
if self.is_compact:
|
||||
self._convert_to_csr(index_file)
|
||||
elif self.is_recompute:
|
||||
prune_hnsw_embeddings_inplace(str(index_file))
|
||||
|
||||
def _convert_to_csr(self, index_file: Path):
|
||||
"""Convert built index to CSR format"""
|
||||
@@ -133,10 +144,10 @@ class HNSWSearcher(BaseSearcher):
|
||||
if metric_enum is None:
|
||||
raise ValueError(f"Unsupported distance_metric '{self.distance_metric}'.")
|
||||
|
||||
self.is_compact, self.is_pruned = (
|
||||
self.meta.get("is_compact", True),
|
||||
self.meta.get("is_pruned", True),
|
||||
)
|
||||
backend_meta_kwargs = self.meta.get("backend_kwargs", {})
|
||||
self.is_compact = self.meta.get("is_compact", backend_meta_kwargs.get("is_compact", True))
|
||||
default_pruned = backend_meta_kwargs.get("is_recompute", self.is_compact)
|
||||
self.is_pruned = bool(self.meta.get("is_pruned", default_pruned))
|
||||
|
||||
index_file = self.index_dir / f"{self.index_path.stem}.index"
|
||||
if not index_file.exists():
|
||||
@@ -150,6 +161,16 @@ class HNSWSearcher(BaseSearcher):
|
||||
|
||||
self._index = faiss.read_index(str(index_file), faiss.IO_FLAG_MMAP, hnsw_config)
|
||||
|
||||
# Load ID map if available
|
||||
self._id_map: list[str] = []
|
||||
try:
|
||||
idmap_file = self.index_dir / f"{self.index_path.stem}.ids.txt"
|
||||
if idmap_file.exists():
|
||||
with open(idmap_file, encoding="utf-8") as f:
|
||||
self._id_map = [line.rstrip("\n") for line in f]
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load ID map: {e}")
|
||||
|
||||
def search(
|
||||
self,
|
||||
query: np.ndarray,
|
||||
@@ -248,6 +269,19 @@ class HNSWSearcher(BaseSearcher):
|
||||
)
|
||||
search_time = time.time() - search_time
|
||||
logger.info(f" Search time in HNSWSearcher.search() backend: {search_time} seconds")
|
||||
string_labels = [[str(int_label) for int_label in batch_labels] for batch_labels in labels]
|
||||
if self._id_map:
|
||||
|
||||
def map_label(x: int) -> str:
|
||||
if 0 <= x < len(self._id_map):
|
||||
return self._id_map[x]
|
||||
return str(x)
|
||||
|
||||
string_labels = [
|
||||
[map_label(int(label)) for label in batch_labels] for batch_labels in labels
|
||||
]
|
||||
else:
|
||||
string_labels = [
|
||||
[str(int_label) for int_label in batch_labels] for batch_labels in labels
|
||||
]
|
||||
|
||||
return {"labels": string_labels, "distances": distances}
|
||||
|
||||
Reference in New Issue
Block a user