Compare commits
4 Commits
fix/securi ... fix-update
| Author | SHA1 | Date |
|---|---|---|
|  | fd5c052bd8 |  |
|  | 2f77d0185c |  |
|  | 82d536b2ae |  |
|  | f42e086383 |  |
```diff
@@ -186,6 +186,7 @@ def run_workflow(
     is_recompute: bool,
     query: str,
     top_k: int,
+    skip_search: bool,
 ) -> dict[str, Any]:
     prefix = f"[{label}] " if label else ""
 
```
```diff
@@ -202,12 +203,15 @@ def run_workflow(
     )
 
     initial_size = index_file_size(index_path)
-    before_results = run_search(
-        index_path,
-        query,
-        top_k,
-        recompute_embeddings=is_recompute,
-    )
+    if not skip_search:
+        before_results = run_search(
+            index_path,
+            query,
+            top_k,
+            recompute_embeddings=is_recompute,
+        )
+    else:
+        before_results = None
 
     print(f"\n{prefix}Updating index with additional passages...")
     update_index(
```
```diff
@@ -219,20 +223,23 @@ def run_workflow(
         is_recompute=is_recompute,
     )
 
-    after_results = run_search(
-        index_path,
-        query,
-        top_k,
-        recompute_embeddings=is_recompute,
-    )
+    if not skip_search:
+        after_results = run_search(
+            index_path,
+            query,
+            top_k,
+            recompute_embeddings=is_recompute,
+        )
+    else:
+        after_results = None
     updated_size = index_file_size(index_path)
 
     return {
         "initial_size": initial_size,
         "updated_size": updated_size,
         "delta": updated_size - initial_size,
-        "before_results": before_results,
-        "after_results": after_results,
+        "before_results": before_results if not skip_search else None,
+        "after_results": after_results if not skip_search else None,
         "metadata": load_metadata_snapshot(index_path),
     }
 
```
```diff
@@ -318,6 +325,12 @@ def main() -> None:
         action="store_false",
         help="Skip building the no-recompute baseline.",
     )
+    parser.add_argument(
+        "--skip-search",
+        dest="skip_search",
+        action="store_true",
+        help="Skip the search step.",
+    )
     parser.set_defaults(compare_no_recompute=True)
     args = parser.parse_args()
 
```
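For readers unfamiliar with `store_true` flags, here is a minimal standalone sketch of how the new option parses. The flag name, `dest`, and help text are taken from the hunk above; everything else is illustrative:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--skip-search",
    dest="skip_search",
    action="store_true",  # absent -> False, present -> True
    help="Skip the search step.",
)

print(parser.parse_args([]).skip_search)                 # False
print(parser.parse_args(["--skip-search"]).skip_search)  # True
```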
```diff
@@ -354,10 +367,13 @@ def main() -> None:
         is_recompute=True,
         query=args.query,
         top_k=args.top_k,
+        skip_search=args.skip_search,
     )
 
-    print_results("initial search", recompute_stats["before_results"])
-    print_results("after update", recompute_stats["after_results"])
+    if not args.skip_search:
+        print_results("initial search", recompute_stats["before_results"])
+    if not args.skip_search:
+        print_results("after update", recompute_stats["after_results"])
     print(
         f"\n[recompute] Index file size change: {recompute_stats['initial_size']} -> {recompute_stats['updated_size']} bytes"
         f" (Δ {recompute_stats['delta']})"
```
```diff
@@ -382,6 +398,7 @@ def main() -> None:
         is_recompute=False,
         query=args.query,
         top_k=args.top_k,
+        skip_search=args.skip_search,
     )
 
     print(
```
```diff
@@ -389,8 +406,12 @@ def main() -> None:
         f" (Δ {baseline_stats['delta']})"
     )
 
-    after_texts = [res.text for res in recompute_stats["after_results"]]
-    baseline_after_texts = [res.text for res in baseline_stats["after_results"]]
+    after_texts = (
+        [res.text for res in recompute_stats["after_results"]] if not args.skip_search else None
+    )
+    baseline_after_texts = (
+        [res.text for res in baseline_stats["after_results"]] if not args.skip_search else None
+    )
     if after_texts == baseline_after_texts:
         print(
             "[no-recompute] Search results match recompute baseline; see above for the shared output."
```
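One behavioural detail implied by the hunk above: when `--skip-search` is set, both variables are `None`, so the equality check still succeeds and the "results match" message is printed even though no searches were compared. A tiny standalone sketch of that path (not the script itself):

```python
skip_search = True

after_texts = ["result text"] if not skip_search else None
baseline_after_texts = ["result text"] if not skip_search else None

# With --skip-search both sides are None, and None == None is True,
# so this branch is taken without any search having run.
if after_texts == baseline_after_texts:
    print("[no-recompute] Search results match recompute baseline; see above for the shared output.")
```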
Submodule packages/leann-backend-hnsw/third_party/faiss updated: 1d51f0c074...5952745237
```diff
@@ -5,6 +5,7 @@ with the correct, original embedding logic from the user's reference code.
 
 import json
 import logging
+import os
 import pickle
 import re
 import subprocess
```
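The new `os` import supports an environment-variable override for the update-time ZMQ port used later in this diff; a minimal runnable sketch of that lookup (the variable name and default come from the hunk further down):

```python
import os

# Falls back to 5557 unless LEANN_UPDATE_ZMQ_PORT is set in the environment.
requested_zmq_port = int(os.getenv("LEANN_UPDATE_ZMQ_PORT", "5557"))
print(requested_zmq_port)
```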
```diff
@@ -20,6 +21,7 @@ from leann_backend_hnsw.convert_to_csr import prune_hnsw_embeddings_inplace
 from leann.interface import LeannBackendSearcherInterface
 
 from .chat import get_llm
+from .embedding_server_manager import EmbeddingServerManager
 from .interface import LeannBackendFactoryInterface
 from .metadata_filter import MetadataFilterEngine
 from .registry import BACKEND_REGISTRY
```
```diff
@@ -755,40 +757,88 @@ class LeannBuilder:
                 f"Existing index dimension ({index.d}) does not match new embeddings ({embedding_dim})."
             )
 
+        passage_meta_mode = meta.get("embedding_mode", self.embedding_mode)
+        passage_provider_options = meta.get("embedding_options", self.embedding_options)
+
         base_id = index.ntotal
         for offset, chunk in enumerate(valid_chunks):
             new_id = str(base_id + offset)
             chunk.setdefault("metadata", {})["id"] = new_id
             chunk["id"] = new_id
 
-        if needs_recompute:
-            # sequengtially add embeddings
-            for i in range(embeddings.shape[0]):
-                print(f"add {i} embeddings")
-                index.add(1, faiss.swig_ptr(embeddings[i : i + 1]))
-        else:
-            index.add(embeddings.shape[0], faiss.swig_ptr(embeddings))
-
-        # index.add(embeddings.shape[0], faiss.swig_ptr(embeddings))
-        faiss.write_index(index, str(index_file))
-
-        with open(passages_file, "a", encoding="utf-8") as f:
-            for chunk in valid_chunks:
-                offset = f.tell()
-                json.dump(
-                    {
-                        "id": chunk["id"],
-                        "text": chunk["text"],
-                        "metadata": chunk.get("metadata", {}),
-                    },
-                    f,
-                    ensure_ascii=False,
-                )
-                f.write("\n")
-                offset_map[chunk["id"]] = offset
-
-        with open(offset_file, "wb") as f:
-            pickle.dump(offset_map, f)
+        # Append passages/offsets before we attempt index.add so the ZMQ server
+        # can resolve newly assigned IDs during recompute. Keep rollback hooks
+        # so we can restore files if the update fails mid-way.
+        rollback_passages_size = passages_file.stat().st_size if passages_file.exists() else 0
+        offset_map_backup = offset_map.copy()
+
+        try:
+            with open(passages_file, "a", encoding="utf-8") as f:
+                for chunk in valid_chunks:
+                    offset = f.tell()
+                    json.dump(
+                        {
+                            "id": chunk["id"],
+                            "text": chunk["text"],
+                            "metadata": chunk.get("metadata", {}),
+                        },
+                        f,
+                        ensure_ascii=False,
+                    )
+                    f.write("\n")
+                    offset_map[chunk["id"]] = offset
+
+            with open(offset_file, "wb") as f:
+                pickle.dump(offset_map, f)
+
+            server_manager: Optional[EmbeddingServerManager] = None
+            server_started = False
+            requested_zmq_port = int(os.getenv("LEANN_UPDATE_ZMQ_PORT", "5557"))
+
+            try:
+                if needs_recompute:
+                    server_manager = EmbeddingServerManager(
+                        backend_module_name="leann_backend_hnsw.hnsw_embedding_server"
+                    )
+                    server_started, actual_port = server_manager.start_server(
+                        port=requested_zmq_port,
+                        model_name=self.embedding_model,
+                        embedding_mode=passage_meta_mode,
+                        passages_file=str(meta_path),
+                        distance_metric=distance_metric,
+                        provider_options=passage_provider_options,
+                    )
+                    if not server_started:
+                        raise RuntimeError(
+                            "Failed to start HNSW embedding server for recompute update."
+                        )
+                    if actual_port != requested_zmq_port:
+                        server_manager.stop_server()
+                        raise RuntimeError(
+                            "Embedding server started on unexpected port "
+                            f"{actual_port}; expected {requested_zmq_port}. Make sure the desired ZMQ port is free."
+                        )
+
+                if needs_recompute:
+                    for i in range(embeddings.shape[0]):
+                        print(f"add {i} embeddings")
+                        index.add(1, faiss.swig_ptr(embeddings[i : i + 1]))
+                else:
+                    index.add(embeddings.shape[0], faiss.swig_ptr(embeddings))
+                faiss.write_index(index, str(index_file))
+            finally:
+                if server_started and server_manager is not None:
+                    server_manager.stop_server()
+
+        except Exception:
+            # Roll back appended passages/offset map to keep files consistent.
+            if passages_file.exists():
+                with open(passages_file, "rb+") as f:
+                    f.truncate(rollback_passages_size)
+            offset_map = offset_map_backup
+            with open(offset_file, "wb") as f:
+                pickle.dump(offset_map, f)
+            raise
 
         meta["total_passages"] = len(offset_map)
         with open(meta_path, "w", encoding="utf-8") as f:
```
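The update path above appends passages as JSON lines, records each record's byte offset in a pickled map, and snapshots the previous file size and offset map so a failure later on (embedding-server startup, index.add) can be rolled back by truncating. A self-contained sketch of that append/rollback/seek pattern, using temporary files and made-up chunk data rather than the library's real state:

```python
import json
import pickle
import tempfile
from pathlib import Path

tmp = Path(tempfile.mkdtemp())
passages_file = tmp / "passages.jsonl"
offset_file = tmp / "offsets.pkl"
passages_file.write_text("", encoding="utf-8")
offset_map: dict[str, int] = {}

new_chunks = [{"id": "0", "text": "hello", "metadata": {}}]

# Snapshot rollback state before appending (same idea as the hunk above).
rollback_passages_size = passages_file.stat().st_size
offset_map_backup = offset_map.copy()

try:
    with open(passages_file, "a", encoding="utf-8") as f:
        for chunk in new_chunks:
            offset = f.tell()  # byte offset where this record starts
            json.dump(chunk, f, ensure_ascii=False)
            f.write("\n")
            offset_map[chunk["id"]] = offset
    with open(offset_file, "wb") as f:
        pickle.dump(offset_map, f)
    # ... index.add / embedding-server work would happen here ...
except Exception:
    # Roll back: truncate the appended bytes and restore the old offset map.
    with open(passages_file, "rb+") as f:
        f.truncate(rollback_passages_size)
    offset_map = offset_map_backup
    with open(offset_file, "wb") as f:
        pickle.dump(offset_map, f)
    raise

# The offset map allows seeking straight to any passage later:
with open(passages_file, "rb") as f:
    f.seek(offset_map["0"])
    print(json.loads(f.readline().decode("utf-8"))["text"])  # -> hello
```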