Compare commits


4 Commits

Author    SHA1        Message                                                         Date
Andy Lee  fd5c052bd8  Update faiss for batch distances calc & caching when updating  2025-09-30 12:40:05 -07:00
Andy Lee  2f77d0185c  Merge remote-tracking branch 'origin/main' into fix-update     2025-09-30 00:56:27 -07:00
Andy Lee  82d536b2ae  fix: launch embedding server before adding                     2025-09-30 00:53:22 -07:00
Andy Lee  f42e086383  fix: set ntotal for storage as well                            2025-09-29 19:10:09 -07:00
3 changed files with 115 additions and 44 deletions

View File

@@ -186,6 +186,7 @@ def run_workflow(
     is_recompute: bool,
     query: str,
     top_k: int,
+    skip_search: bool,
 ) -> dict[str, Any]:
     prefix = f"[{label}] " if label else ""
@@ -202,12 +203,15 @@ def run_workflow(
     )
     initial_size = index_file_size(index_path)

-    before_results = run_search(
-        index_path,
-        query,
-        top_k,
-        recompute_embeddings=is_recompute,
-    )
+    if not skip_search:
+        before_results = run_search(
+            index_path,
+            query,
+            top_k,
+            recompute_embeddings=is_recompute,
+        )
+    else:
+        before_results = None

     print(f"\n{prefix}Updating index with additional passages...")
     update_index(
@@ -219,20 +223,23 @@ def run_workflow(
         is_recompute=is_recompute,
     )

-    after_results = run_search(
-        index_path,
-        query,
-        top_k,
-        recompute_embeddings=is_recompute,
-    )
+    if not skip_search:
+        after_results = run_search(
+            index_path,
+            query,
+            top_k,
+            recompute_embeddings=is_recompute,
+        )
+    else:
+        after_results = None

     updated_size = index_file_size(index_path)

     return {
         "initial_size": initial_size,
         "updated_size": updated_size,
         "delta": updated_size - initial_size,
-        "before_results": before_results,
-        "after_results": after_results,
+        "before_results": before_results if not skip_search else None,
+        "after_results": after_results if not skip_search else None,
         "metadata": load_metadata_snapshot(index_path),
     }
@@ -318,6 +325,12 @@ def main() -> None:
         action="store_false",
         help="Skip building the no-recompute baseline.",
     )
+    parser.add_argument(
+        "--skip-search",
+        dest="skip_search",
+        action="store_true",
+        help="Skip the search step.",
+    )
     parser.set_defaults(compare_no_recompute=True)
     args = parser.parse_args()
@@ -354,10 +367,13 @@ def main() -> None:
         is_recompute=True,
         query=args.query,
         top_k=args.top_k,
+        skip_search=args.skip_search,
     )
-    print_results("initial search", recompute_stats["before_results"])
-    print_results("after update", recompute_stats["after_results"])
+    if not args.skip_search:
+        print_results("initial search", recompute_stats["before_results"])
+    if not args.skip_search:
+        print_results("after update", recompute_stats["after_results"])

     print(
         f"\n[recompute] Index file size change: {recompute_stats['initial_size']} -> {recompute_stats['updated_size']} bytes"
         f"{recompute_stats['delta']})"
@@ -382,6 +398,7 @@ def main() -> None:
         is_recompute=False,
         query=args.query,
         top_k=args.top_k,
+        skip_search=args.skip_search,
     )

     print(
@@ -389,8 +406,12 @@ def main() -> None:
         f"{baseline_stats['delta']})"
     )

-    after_texts = [res.text for res in recompute_stats["after_results"]]
-    baseline_after_texts = [res.text for res in baseline_stats["after_results"]]
+    after_texts = (
+        [res.text for res in recompute_stats["after_results"]] if not args.skip_search else None
+    )
+    baseline_after_texts = (
+        [res.text for res in baseline_stats["after_results"]] if not args.skip_search else None
+    )
     if after_texts == baseline_after_texts:
         print(
             "[no-recompute] Search results match recompute baseline; see above for the shared output."

View File

@@ -5,6 +5,7 @@ with the correct, original embedding logic from the user's reference code.
 import json
 import logging
+import os
 import pickle
 import re
 import subprocess
@@ -20,6 +21,7 @@ from leann_backend_hnsw.convert_to_csr import prune_hnsw_embeddings_inplace
 from leann.interface import LeannBackendSearcherInterface

 from .chat import get_llm
+from .embedding_server_manager import EmbeddingServerManager
 from .interface import LeannBackendFactoryInterface
 from .metadata_filter import MetadataFilterEngine
 from .registry import BACKEND_REGISTRY
@@ -755,40 +757,88 @@ class LeannBuilder:
                 f"Existing index dimension ({index.d}) does not match new embeddings ({embedding_dim})."
             )

+        passage_meta_mode = meta.get("embedding_mode", self.embedding_mode)
+        passage_provider_options = meta.get("embedding_options", self.embedding_options)
+
         base_id = index.ntotal
         for offset, chunk in enumerate(valid_chunks):
             new_id = str(base_id + offset)
             chunk.setdefault("metadata", {})["id"] = new_id
             chunk["id"] = new_id

-        if needs_recompute:
-            # sequengtially add embeddings
-            for i in range(embeddings.shape[0]):
-                print(f"add {i} embeddings")
-                index.add(1, faiss.swig_ptr(embeddings[i : i + 1]))
-        else:
-            index.add(embeddings.shape[0], faiss.swig_ptr(embeddings))
-        # index.add(embeddings.shape[0], faiss.swig_ptr(embeddings))
-        faiss.write_index(index, str(index_file))
-
-        with open(passages_file, "a", encoding="utf-8") as f:
-            for chunk in valid_chunks:
-                offset = f.tell()
-                json.dump(
-                    {
-                        "id": chunk["id"],
-                        "text": chunk["text"],
-                        "metadata": chunk.get("metadata", {}),
-                    },
-                    f,
-                    ensure_ascii=False,
-                )
-                f.write("\n")
-                offset_map[chunk["id"]] = offset
-
-        with open(offset_file, "wb") as f:
-            pickle.dump(offset_map, f)
+        # Append passages/offsets before we attempt index.add so the ZMQ server
+        # can resolve newly assigned IDs during recompute. Keep rollback hooks
+        # so we can restore files if the update fails mid-way.
+        rollback_passages_size = passages_file.stat().st_size if passages_file.exists() else 0
+        offset_map_backup = offset_map.copy()
+
+        try:
+            with open(passages_file, "a", encoding="utf-8") as f:
+                for chunk in valid_chunks:
+                    offset = f.tell()
+                    json.dump(
+                        {
+                            "id": chunk["id"],
+                            "text": chunk["text"],
+                            "metadata": chunk.get("metadata", {}),
+                        },
+                        f,
+                        ensure_ascii=False,
+                    )
+                    f.write("\n")
+                    offset_map[chunk["id"]] = offset
+
+            with open(offset_file, "wb") as f:
+                pickle.dump(offset_map, f)
+
+            server_manager: Optional[EmbeddingServerManager] = None
+            server_started = False
+            requested_zmq_port = int(os.getenv("LEANN_UPDATE_ZMQ_PORT", "5557"))
+            try:
+                if needs_recompute:
+                    server_manager = EmbeddingServerManager(
+                        backend_module_name="leann_backend_hnsw.hnsw_embedding_server"
+                    )
+                    server_started, actual_port = server_manager.start_server(
+                        port=requested_zmq_port,
+                        model_name=self.embedding_model,
+                        embedding_mode=passage_meta_mode,
+                        passages_file=str(meta_path),
+                        distance_metric=distance_metric,
+                        provider_options=passage_provider_options,
+                    )
+                    if not server_started:
+                        raise RuntimeError(
+                            "Failed to start HNSW embedding server for recompute update."
+                        )
+                    if actual_port != requested_zmq_port:
+                        server_manager.stop_server()
+                        raise RuntimeError(
+                            "Embedding server started on unexpected port "
+                            f"{actual_port}; expected {requested_zmq_port}. Make sure the desired ZMQ port is free."
+                        )
+
+                if needs_recompute:
+                    for i in range(embeddings.shape[0]):
+                        print(f"add {i} embeddings")
+                        index.add(1, faiss.swig_ptr(embeddings[i : i + 1]))
+                else:
+                    index.add(embeddings.shape[0], faiss.swig_ptr(embeddings))
+
+                faiss.write_index(index, str(index_file))
+            finally:
+                if server_started and server_manager is not None:
+                    server_manager.stop_server()
+        except Exception:
+            # Roll back appended passages/offset map to keep files consistent.
+            if passages_file.exists():
+                with open(passages_file, "rb+") as f:
+                    f.truncate(rollback_passages_size)
+            offset_map = offset_map_backup
+            with open(offset_file, "wb") as f:
+                pickle.dump(offset_map, f)
+            raise

         meta["total_passages"] = len(offset_map)
         with open(meta_path, "w", encoding="utf-8") as f:
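Note: the core change in this file is ordering and safety. Passages and the offset map are now written before index.add so the embedding server started for recompute can resolve the newly assigned IDs, and the whole sequence is wrapped so a mid-update failure cannot leave the passage files inconsistent with the index. A rough, generic sketch of that append-then-rollback pattern (append_with_rollback is an illustrative name, and the record shapes are made up rather than the real chunk dictionaries):

import pickle
from pathlib import Path


def append_with_rollback(
    passages_file: Path,
    offset_file: Path,
    offset_map: dict[str, int],
    new_records: dict[str, str],
) -> None:
    # Remember the passages file length and the old offset map so a failed
    # update can be undone.
    rollback_size = passages_file.stat().st_size if passages_file.exists() else 0
    backup_map = dict(offset_map)
    try:
        with open(passages_file, "a", encoding="utf-8") as f:
            for record_id, text in new_records.items():
                offset_map[record_id] = f.tell()
                f.write(text + "\n")
        with open(offset_file, "wb") as f:
            pickle.dump(offset_map, f)
        # ... index.add(...) and faiss.write_index(...) would run here; any
        # exception they raise triggers the rollback below ...
    except Exception:
        # Truncate the appended passages and restore the previous offset map
        # so the on-disk files stay consistent with the unchanged index.
        if passages_file.exists():
            with open(passages_file, "rb+") as f:
                f.truncate(rollback_size)
        with open(offset_file, "wb") as f:
            pickle.dump(backup_map, f)
        raise

The diff additionally holds the embedding server inside a nested try/finally so stop_server runs whether or not index.add succeeds, and it aborts if the server came up on a port other than the one the recompute path expects.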