fix: launch embedding server before adding

Andy Lee
2025-09-30 00:53:22 -07:00
parent f42e086383
commit 82d536b2ae
2 changed files with 110 additions and 36 deletions
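Previously, the update path called index.add and wrote the index before the new passages and offset map were persisted, and it never launched an embedding server, so a recompute-mode update could not resolve the freshly assigned passage IDs over ZMQ. The change below appends the passages and offsets first, brings up the HNSW embedding server just for the duration of index.add, tears it down afterwards, and rolls the passage/offset files back if the update fails. In rough outline (a simplified sketch, not the literal diff; append_passages_and_offsets is an illustrative stand-in for the inline append loop shown below):

    # Simplified sketch of the new update flow. append_passages_and_offsets is a
    # hypothetical helper standing in for the inline append loop in the diff.
    append_passages_and_offsets(valid_chunks)  # passages + offset map hit disk first
    server, started = None, False
    if needs_recompute:
        server = EmbeddingServerManager(
            backend_module_name="leann_backend_hnsw.hnsw_embedding_server"
        )
        started, _port = server.start_server(port=requested_zmq_port, model_name=embedding_model)
    try:
        index.add(embeddings.shape[0], faiss.swig_ptr(embeddings))  # may query the ZMQ server during recompute
        faiss.write_index(index, str(index_file))
    finally:
        if started and server is not None:
            server.stop_server()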

View File

@@ -182,6 +182,7 @@ def run_workflow(
     is_recompute: bool,
     query: str,
     top_k: int,
+    skip_search: bool,
 ) -> dict[str, Any]:
     prefix = f"[{label}] " if label else ""
@@ -198,12 +199,15 @@ def run_workflow(
     )
     initial_size = index_file_size(index_path)

-    before_results = run_search(
-        index_path,
-        query,
-        top_k,
-        recompute_embeddings=is_recompute,
-    )
+    if not skip_search:
+        before_results = run_search(
+            index_path,
+            query,
+            top_k,
+            recompute_embeddings=is_recompute,
+        )
+    else:
+        before_results = None

     print(f"\n{prefix}Updating index with additional passages...")
     update_index(
@@ -215,20 +219,23 @@ def run_workflow(
         is_recompute=is_recompute,
     )

-    after_results = run_search(
-        index_path,
-        query,
-        top_k,
-        recompute_embeddings=is_recompute,
-    )
+    if not skip_search:
+        after_results = run_search(
+            index_path,
+            query,
+            top_k,
+            recompute_embeddings=is_recompute,
+        )
+    else:
+        after_results = None

     updated_size = index_file_size(index_path)
     return {
         "initial_size": initial_size,
         "updated_size": updated_size,
         "delta": updated_size - initial_size,
-        "before_results": before_results,
-        "after_results": after_results,
+        "before_results": before_results if not skip_search else None,
+        "after_results": after_results if not skip_search else None,
         "metadata": load_metadata_snapshot(index_path),
     }
@@ -314,6 +321,12 @@ def main() -> None:
action="store_false", action="store_false",
help="Skip building the no-recompute baseline.", help="Skip building the no-recompute baseline.",
) )
parser.add_argument(
"--skip-search",
dest="skip_search",
action="store_true",
help="Skip the search step.",
)
parser.set_defaults(compare_no_recompute=True) parser.set_defaults(compare_no_recompute=True)
args = parser.parse_args() args = parser.parse_args()
@@ -350,10 +363,13 @@ def main() -> None:
         is_recompute=True,
         query=args.query,
         top_k=args.top_k,
+        skip_search=args.skip_search,
     )

-    print_results("initial search", recompute_stats["before_results"])
-    print_results("after update", recompute_stats["after_results"])
+    if not args.skip_search:
+        print_results("initial search", recompute_stats["before_results"])
+    if not args.skip_search:
+        print_results("after update", recompute_stats["after_results"])
     print(
         f"\n[recompute] Index file size change: {recompute_stats['initial_size']} -> {recompute_stats['updated_size']} bytes"
         f"{recompute_stats['delta']})"
@@ -378,6 +394,7 @@ def main() -> None:
         is_recompute=False,
         query=args.query,
         top_k=args.top_k,
+        skip_search=args.skip_search,
     )

     print(
@@ -385,8 +402,12 @@ def main() -> None:
f"{baseline_stats['delta']})" f"{baseline_stats['delta']})"
) )
after_texts = [res.text for res in recompute_stats["after_results"]] after_texts = (
baseline_after_texts = [res.text for res in baseline_stats["after_results"]] [res.text for res in recompute_stats["after_results"]] if not args.skip_search else None
)
baseline_after_texts = (
[res.text for res in baseline_stats["after_results"]] if not args.skip_search else None
)
if after_texts == baseline_after_texts: if after_texts == baseline_after_texts:
print( print(
"[no-recompute] Search results match recompute baseline; see above for the shared output." "[no-recompute] Search results match recompute baseline; see above for the shared output."

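The new --skip-search flag makes the benchmark skip both the before-update and after-update searches, so it only reports index-size changes; run_workflow then returns None for both result sets. The script's file name is not shown in this view, so the invocation below is illustrative only:

    # Hypothetical script name; only the --skip-search flag itself comes from the diff above.
    python update_benchmark.py --skip-search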
View File

@@ -5,6 +5,7 @@ with the correct, original embedding logic from the user's reference code.
 import json
 import logging
+import os
 import pickle
 import re
 import subprocess
@@ -20,6 +21,7 @@ from leann_backend_hnsw.convert_to_csr import prune_hnsw_embeddings_inplace
 from leann.interface import LeannBackendSearcherInterface

 from .chat import get_llm
+from .embedding_server_manager import EmbeddingServerManager
 from .interface import LeannBackendFactoryInterface
 from .metadata_filter import MetadataFilterEngine
 from .registry import BACKEND_REGISTRY
@@ -754,32 +756,83 @@ class LeannBuilder:
f"Existing index dimension ({index.d}) does not match new embeddings ({embedding_dim})." f"Existing index dimension ({index.d}) does not match new embeddings ({embedding_dim})."
) )
passage_meta_mode = meta.get("embedding_mode", self.embedding_mode)
passage_provider_options = meta.get("embedding_options", self.embedding_options)
base_id = index.ntotal base_id = index.ntotal
for offset, chunk in enumerate(valid_chunks): for offset, chunk in enumerate(valid_chunks):
new_id = str(base_id + offset) new_id = str(base_id + offset)
chunk.setdefault("metadata", {})["id"] = new_id chunk.setdefault("metadata", {})["id"] = new_id
chunk["id"] = new_id chunk["id"] = new_id
index.add(embeddings.shape[0], faiss.swig_ptr(embeddings)) # Append passages/offsets before we attempt index.add so the ZMQ server
faiss.write_index(index, str(index_file)) # can resolve newly assigned IDs during recompute. Keep rollback hooks
# so we can restore files if the update fails mid-way.
rollback_passages_size = passages_file.stat().st_size if passages_file.exists() else 0
offset_map_backup = offset_map.copy()
with open(passages_file, "a", encoding="utf-8") as f: try:
for chunk in valid_chunks: with open(passages_file, "a", encoding="utf-8") as f:
offset = f.tell() for chunk in valid_chunks:
json.dump( offset = f.tell()
{ json.dump(
"id": chunk["id"], {
"text": chunk["text"], "id": chunk["id"],
"metadata": chunk.get("metadata", {}), "text": chunk["text"],
}, "metadata": chunk.get("metadata", {}),
f, },
ensure_ascii=False, f,
) ensure_ascii=False,
f.write("\n") )
offset_map[chunk["id"]] = offset f.write("\n")
offset_map[chunk["id"]] = offset
with open(offset_file, "wb") as f: with open(offset_file, "wb") as f:
pickle.dump(offset_map, f) pickle.dump(offset_map, f)
server_manager: Optional[EmbeddingServerManager] = None
server_started = False
requested_zmq_port = int(os.getenv("LEANN_UPDATE_ZMQ_PORT", "5557"))
try:
if needs_recompute:
server_manager = EmbeddingServerManager(
backend_module_name="leann_backend_hnsw.hnsw_embedding_server"
)
server_started, actual_port = server_manager.start_server(
port=requested_zmq_port,
model_name=self.embedding_model,
embedding_mode=passage_meta_mode,
passages_file=str(meta_path),
distance_metric=distance_metric,
provider_options=passage_provider_options,
)
if not server_started:
raise RuntimeError(
"Failed to start HNSW embedding server for recompute update."
)
if actual_port != requested_zmq_port:
server_manager.stop_server()
raise RuntimeError(
"Embedding server started on unexpected port "
f"{actual_port}; expected {requested_zmq_port}. Make sure the desired ZMQ port is free."
)
index.add(embeddings.shape[0], faiss.swig_ptr(embeddings))
faiss.write_index(index, str(index_file))
finally:
if server_started and server_manager is not None:
server_manager.stop_server()
except Exception:
# Roll back appended passages/offset map to keep files consistent.
if passages_file.exists():
with open(passages_file, "rb+") as f:
f.truncate(rollback_passages_size)
offset_map = offset_map_backup
with open(offset_file, "wb") as f:
pickle.dump(offset_map, f)
raise
meta["total_passages"] = len(offset_map) meta["total_passages"] = len(offset_map)
with open(meta_path, "w", encoding="utf-8") as f: with open(meta_path, "w", encoding="utf-8") as f: