From ab9c6bd69e51764c35952f4fa3b1b01c9641cad0 Mon Sep 17 00:00:00 2001
From: Andy Lee
Date: Tue, 30 Sep 2025 00:58:17 -0700
Subject: [PATCH] Fix update. Should launch embedding server first (#130)

* fix: set ntotal for storage as well

* fix: launch embedding server before adding
---
 examples/dynamic_update_no_recompute.py |  57 +++++++++-----
 packages/leann-core/src/leann/api.py    | 100 ++++++++++++++++++------
 2 files changed, 114 insertions(+), 43 deletions(-)

diff --git a/examples/dynamic_update_no_recompute.py b/examples/dynamic_update_no_recompute.py
index 84362dc..761375d 100644
--- a/examples/dynamic_update_no_recompute.py
+++ b/examples/dynamic_update_no_recompute.py
@@ -186,6 +186,7 @@ def run_workflow(
     is_recompute: bool,
     query: str,
     top_k: int,
+    skip_search: bool,
 ) -> dict[str, Any]:
     prefix = f"[{label}] " if label else ""
 
@@ -202,12 +203,15 @@ def run_workflow(
     )
     initial_size = index_file_size(index_path)
 
-    before_results = run_search(
-        index_path,
-        query,
-        top_k,
-        recompute_embeddings=is_recompute,
-    )
+    if not skip_search:
+        before_results = run_search(
+            index_path,
+            query,
+            top_k,
+            recompute_embeddings=is_recompute,
+        )
+    else:
+        before_results = None
 
     print(f"\n{prefix}Updating index with additional passages...")
     update_index(
@@ -219,20 +223,23 @@ def run_workflow(
         is_recompute=is_recompute,
     )
 
-    after_results = run_search(
-        index_path,
-        query,
-        top_k,
-        recompute_embeddings=is_recompute,
-    )
+    if not skip_search:
+        after_results = run_search(
+            index_path,
+            query,
+            top_k,
+            recompute_embeddings=is_recompute,
+        )
+    else:
+        after_results = None
 
     updated_size = index_file_size(index_path)
     return {
         "initial_size": initial_size,
         "updated_size": updated_size,
         "delta": updated_size - initial_size,
-        "before_results": before_results,
-        "after_results": after_results,
+        "before_results": before_results if not skip_search else None,
+        "after_results": after_results if not skip_search else None,
         "metadata": load_metadata_snapshot(index_path),
     }
 
@@ -318,6 +325,12 @@ def main() -> None:
         action="store_false",
         help="Skip building the no-recompute baseline.",
     )
+    parser.add_argument(
+        "--skip-search",
+        dest="skip_search",
+        action="store_true",
+        help="Skip the before/after search steps.",
+    )
     parser.set_defaults(compare_no_recompute=True)
 
     args = parser.parse_args()
@@ -354,10 +367,13 @@ def main() -> None:
         is_recompute=True,
         query=args.query,
         top_k=args.top_k,
+        skip_search=args.skip_search,
     )
 
-    print_results("initial search", recompute_stats["before_results"])
-    print_results("after update", recompute_stats["after_results"])
+    if not args.skip_search:
+        print_results("initial search", recompute_stats["before_results"])
+    if not args.skip_search:
+        print_results("after update", recompute_stats["after_results"])
     print(
         f"\n[recompute] Index file size change: {recompute_stats['initial_size']} -> {recompute_stats['updated_size']} bytes"
         f" (Δ {recompute_stats['delta']})"
@@ -382,6 +398,7 @@ def main() -> None:
         is_recompute=False,
         query=args.query,
         top_k=args.top_k,
+        skip_search=args.skip_search,
     )
 
     print(
@@ -389,8 +406,12 @@ def main() -> None:
         f" (Δ {baseline_stats['delta']})"
     )
 
-    after_texts = [res.text for res in recompute_stats["after_results"]]
-    baseline_after_texts = [res.text for res in baseline_stats["after_results"]]
+    after_texts = (
+        [res.text for res in recompute_stats["after_results"]] if not args.skip_search else None
+    )
+    baseline_after_texts = (
+        [res.text for res in baseline_stats["after_results"]] if not args.skip_search else None
+    )
     if after_texts == baseline_after_texts:
         print(
             "[no-recompute] Search results match recompute baseline; see above for the shared output."
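The example-side change above is plumbing: both searches and all result printing are now gated on the new flag (e.g. `python examples/dynamic_update_no_recompute.py --skip-search`), and `run_workflow` returns `None` for both result lists when searches are skipped. A minimal sketch of a consumer of the returned stats dict; `report` is a hypothetical helper, but the dict keys and the `res.text` attribute are the ones used in the hunks above:

    # Hypothetical consumer of the dict returned by run_workflow.
    def report(stats: dict) -> None:
        print(
            f"Index file size change: {stats['initial_size']} -> "
            f"{stats['updated_size']} bytes (Δ {stats['delta']})"
        )
        if stats["after_results"] is None:
            # --skip-search was set; there is nothing to print.
            print("search skipped")
        else:
            for rank, res in enumerate(stats["after_results"], 1):
                print(f"{rank}. {res.text}")
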
diff --git a/packages/leann-core/src/leann/api.py b/packages/leann-core/src/leann/api.py
index 0c18526..bf25666 100644
--- a/packages/leann-core/src/leann/api.py
+++ b/packages/leann-core/src/leann/api.py
@@ -5,6 +5,7 @@ with the correct, original embedding logic from the user's reference code.
 
 import json
 import logging
+import os
 import pickle
 import re
 import subprocess
@@ -20,6 +21,7 @@ from leann_backend_hnsw.convert_to_csr import prune_hnsw_embeddings_inplace
 from leann.interface import LeannBackendSearcherInterface
 
 from .chat import get_llm
+from .embedding_server_manager import EmbeddingServerManager
 from .interface import LeannBackendFactoryInterface
 from .metadata_filter import MetadataFilterEngine
 from .registry import BACKEND_REGISTRY
@@ -755,40 +757,88 @@ class LeannBuilder:
                 f"Existing index dimension ({index.d}) does not match new embeddings ({embedding_dim})."
             )
 
+        passage_meta_mode = meta.get("embedding_mode", self.embedding_mode)
+        passage_provider_options = meta.get("embedding_options", self.embedding_options)
+
         base_id = index.ntotal
         for offset, chunk in enumerate(valid_chunks):
             new_id = str(base_id + offset)
             chunk.setdefault("metadata", {})["id"] = new_id
             chunk["id"] = new_id
 
-        if needs_recompute:
-            # sequengtially add embeddings
-            for i in range(embeddings.shape[0]):
-                print(f"add {i} embeddings")
-                index.add(1, faiss.swig_ptr(embeddings[i : i + 1]))
-        else:
-            index.add(embeddings.shape[0], faiss.swig_ptr(embeddings))
+        # Append passages/offsets before we attempt index.add so the ZMQ server
+        # can resolve newly assigned IDs during recompute. Keep rollback hooks
+        # so we can restore files if the update fails mid-way.
+        rollback_passages_size = passages_file.stat().st_size if passages_file.exists() else 0
+        offset_map_backup = offset_map.copy()
 
-        # index.add(embeddings.shape[0], faiss.swig_ptr(embeddings))
-        faiss.write_index(index, str(index_file))
+        try:
+            with open(passages_file, "a", encoding="utf-8") as f:
+                for chunk in valid_chunks:
+                    offset = f.tell()
+                    json.dump(
+                        {
+                            "id": chunk["id"],
+                            "text": chunk["text"],
+                            "metadata": chunk.get("metadata", {}),
+                        },
+                        f,
+                        ensure_ascii=False,
+                    )
+                    f.write("\n")
+                    offset_map[chunk["id"]] = offset
 
-        with open(passages_file, "a", encoding="utf-8") as f:
-            for chunk in valid_chunks:
-                offset = f.tell()
-                json.dump(
-                    {
-                        "id": chunk["id"],
-                        "text": chunk["text"],
-                        "metadata": chunk.get("metadata", {}),
-                    },
-                    f,
-                    ensure_ascii=False,
-                )
-                f.write("\n")
-                offset_map[chunk["id"]] = offset
+            with open(offset_file, "wb") as f:
+                pickle.dump(offset_map, f)
 
-        with open(offset_file, "wb") as f:
-            pickle.dump(offset_map, f)
+            server_manager: Optional[EmbeddingServerManager] = None
+            server_started = False
+            requested_zmq_port = int(os.getenv("LEANN_UPDATE_ZMQ_PORT", "5557"))
+
+            try:
+                if needs_recompute:
+                    server_manager = EmbeddingServerManager(
+                        backend_module_name="leann_backend_hnsw.hnsw_embedding_server"
+                    )
+                    server_started, actual_port = server_manager.start_server(
+                        port=requested_zmq_port,
+                        model_name=self.embedding_model,
+                        embedding_mode=passage_meta_mode,
+                        passages_file=str(meta_path),
+                        distance_metric=distance_metric,
+                        provider_options=passage_provider_options,
+                    )
+                    if not server_started:
+                        raise RuntimeError(
+                            "Failed to start HNSW embedding server for recompute update."
+                        )
+                    if actual_port != requested_zmq_port:
+                        server_manager.stop_server()
+                        raise RuntimeError(
+                            "Embedding server started on unexpected port "
+                            f"{actual_port}; expected {requested_zmq_port}. Make sure the desired ZMQ port is free."
+                        )
+
+                if needs_recompute:
+                    for i in range(embeddings.shape[0]):
+                        print(f"add {i} embeddings")
+                        index.add(1, faiss.swig_ptr(embeddings[i : i + 1]))
+                else:
+                    index.add(embeddings.shape[0], faiss.swig_ptr(embeddings))
+
+                faiss.write_index(index, str(index_file))
+            finally:
+                if server_started and server_manager is not None:
+                    server_manager.stop_server()
+
+        except Exception:
+            # Roll back appended passages/offset map to keep files consistent.
+            if passages_file.exists():
+                with open(passages_file, "rb+") as f:
+                    f.truncate(rollback_passages_size)
+            offset_map = offset_map_backup
+            with open(offset_file, "wb") as f:
+                pickle.dump(offset_map, f)
+            raise
 
         meta["total_passages"] = len(offset_map)
         with open(meta_path, "w", encoding="utf-8") as f:
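Why the server must come up first: with recompute enabled, each `index.add` call makes the HNSW backend fetch passage embeddings over ZMQ, and the server resolves passage IDs from the files on disk, so the passages and offset map must be written, and the server started, before the first add. The corrected sequence, condensed from the hunk above into a standalone function; this is a sketch, assuming `EmbeddingServerManager` is importable from `leann.embedding_server_manager` and that `faiss` is the same module `api.py` already uses:

    import os

    import faiss  # assumed: the faiss module api.py uses
    from leann.embedding_server_manager import EmbeddingServerManager  # import path assumed

    def add_with_live_server(index, embeddings, index_file, meta_path,
                             model_name, embedding_mode, distance_metric,
                             provider_options) -> None:
        # Passages/offsets are already on disk at this point, so the server
        # can resolve the newly assigned IDs; only then do we call index.add.
        server = EmbeddingServerManager(
            backend_module_name="leann_backend_hnsw.hnsw_embedding_server"
        )
        started, actual_port = server.start_server(
            port=int(os.getenv("LEANN_UPDATE_ZMQ_PORT", "5557")),
            model_name=model_name,
            embedding_mode=embedding_mode,
            passages_file=str(meta_path),
            distance_metric=distance_metric,
            provider_options=provider_options,
        )
        try:
            if not started:
                raise RuntimeError("Failed to start HNSW embedding server.")
            # Recompute path: one vector per add() so each insertion can
            # fetch the embeddings it needs from the live server.
            for i in range(embeddings.shape[0]):
                index.add(1, faiss.swig_ptr(embeddings[i : i + 1]))
            faiss.write_index(index, str(index_file))
        finally:
            if started:
                server.stop_server()

Stopping the server in `finally` mirrors the patch: the update must not leak a ZMQ process even when `index.add` raises.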
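The rollback bookkeeping is independent of LEANN and worth isolating. A self-contained sketch of the same append-then-restore pattern using only the standard library; `append_passages_with_rollback` and its signature are illustrative, not part of the LEANN API:

    import json
    import pickle
    from pathlib import Path

    def append_passages_with_rollback(passages_file: Path, offset_file: Path,
                                      offset_map: dict, chunks: list[dict]) -> None:
        # Record the current passages size and offset map so a failure can
        # undo the append, even after a partial write.
        rollback_size = passages_file.stat().st_size if passages_file.exists() else 0
        backup = dict(offset_map)
        try:
            with open(passages_file, "a", encoding="utf-8") as f:
                for chunk in chunks:
                    offset_map[chunk["id"]] = f.tell()
                    json.dump(chunk, f, ensure_ascii=False)
                    f.write("\n")
            with open(offset_file, "wb") as f:
                pickle.dump(offset_map, f)
        except Exception:
            # Truncate the appended tail and restore the old offset map.
            if passages_file.exists():
                with open(passages_file, "rb+") as f:
                    f.truncate(rollback_size)
            offset_map.clear()
            offset_map.update(backup)
            with open(offset_file, "wb") as f:
                pickle.dump(backup, f)
            raise

Truncating back to the recorded size undoes the JSONL append no matter how far it got, which is why the patch records the file size before opening the file in append mode.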