Compare commits
36 Commits
fix-update
...
financeben
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4e9e2f3da0 | ||
|
|
ed167f43b0 | ||
|
|
f9746d3fe2 | ||
|
|
a090a3444a | ||
|
|
aaaba27a4f | ||
|
|
f40f539456 | ||
|
|
576a2dcb49 | ||
|
|
ad8ab84675 | ||
|
|
58b96b64d8 | ||
|
|
a76c3cdac4 | ||
|
|
520619deab | ||
|
|
dea08c95b4 | ||
|
|
3357d5765e | ||
|
|
9dbd0c64cc | ||
|
|
9c400acd7e | ||
|
|
ac560964f5 | ||
|
|
07e4f176e1 | ||
|
|
b1daf021e0 | ||
|
|
3578680cb6 | ||
|
|
a0d6857faa | ||
|
|
d7011bbea0 | ||
|
|
ef4c69d128 | ||
|
|
75c8aeee5f | ||
|
|
3d79741f9c | ||
|
|
df34c84bd3 | ||
|
|
8dfd2f015c | ||
|
|
ed72232bab | ||
|
|
26d961bfc5 | ||
|
|
722bda4ebb | ||
|
|
a7c7e8801d | ||
|
|
069bce558b | ||
|
|
772894012e | ||
|
|
5c163737c4 | ||
|
|
6d1d67ead7 | ||
|
|
ed27ea6990 | ||
|
|
baf2d76e0e |
6
.gitignore
vendored
6
.gitignore
vendored
@@ -99,9 +99,3 @@ benchmarks/data/
|
|||||||
|
|
||||||
## multi vector
|
## multi vector
|
||||||
apps/multimodal/vision-based-pdf-multi-vector/multi-vector-colpali-native-weaviate.py
|
apps/multimodal/vision-based-pdf-multi-vector/multi-vector-colpali-native-weaviate.py
|
||||||
|
|
||||||
# Ignore all PDFs (keep data exceptions above) and do not track demo PDFs
|
|
||||||
# If you need to commit a specific demo PDF, remove this negation locally.
|
|
||||||
# The following line used to force-add a large demo PDF; remove it to satisfy pre-commit:
|
|
||||||
# !apps/multimodal/vision-based-pdf-multi-vector/pdfs/2004.12832v2.pdf
|
|
||||||
!apps/multimodal/vision-based-pdf-multi-vector/fig/*
|
|
||||||
|
|||||||
@@ -1,113 +0,0 @@
|
|||||||
## Vision-based PDF Multi-Vector Demos (macOS/MPS)
|
|
||||||
|
|
||||||
This folder contains two demos to index PDF pages as images and run multi-vector retrieval with ColPali/ColQwen2, plus optional similarity map visualization and answer generation.
|
|
||||||
|
|
||||||
### What you’ll run
|
|
||||||
- `multi-vector-leann-paper-example.py`: local PDF → pages → embed → build HNSW index → search.
|
|
||||||
- `multi-vector-leann-similarity-map.py`: HF dataset (default) or local pages → embed → index → retrieve → similarity maps → optional Qwen-VL answer.
|
|
||||||
|
|
||||||
## Prerequisites (macOS)
|
|
||||||
|
|
||||||
### 1) Homebrew poppler (for pdf2image)
|
|
||||||
```bash
|
|
||||||
brew install poppler
|
|
||||||
which pdfinfo && pdfinfo -v
|
|
||||||
```
|
|
||||||
|
|
||||||
### 2) Python environment
|
|
||||||
Use uv (recommended) or pip. Python 3.9+.
|
|
||||||
|
|
||||||
Using uv:
|
|
||||||
```bash
|
|
||||||
uv pip install \
|
|
||||||
colpali_engine \
|
|
||||||
pdf2image \
|
|
||||||
pillow \
|
|
||||||
matplotlib qwen_vl_utils \
|
|
||||||
einops \
|
|
||||||
seaborn
|
|
||||||
```
|
|
||||||
|
|
||||||
Notes:
|
|
||||||
- On first run, models download from Hugging Face. Login/config if needed.
|
|
||||||
- The scripts auto-select device: CUDA > MPS > CPU. Verify MPS:
|
|
||||||
```bash
|
|
||||||
python -c "import torch; print('MPS available:', bool(getattr(torch.backends, 'mps', None) and torch.backends.mps.is_available()))"
|
|
||||||
```
|
|
||||||
|
|
||||||
## Run the demos
|
|
||||||
|
|
||||||
### A) Local PDF example
|
|
||||||
Converts a local PDF into page images, embeds them, builds an index, and searches.
|
|
||||||
|
|
||||||
```bash
|
|
||||||
cd apps/multimodal/vision-based-pdf-multi-vector
|
|
||||||
# If you don't have the sample PDF locally, download it (ignored by Git)
|
|
||||||
mkdir -p pdfs
|
|
||||||
curl -L -o pdfs/2004.12832v2.pdf https://arxiv.org/pdf/2004.12832.pdf
|
|
||||||
ls pdfs/2004.12832v2.pdf
|
|
||||||
# Ensure output dir exists
|
|
||||||
mkdir -p pages
|
|
||||||
python multi-vector-leann-paper-example.py
|
|
||||||
```
|
|
||||||
Expected:
|
|
||||||
- Page images in `pages/`.
|
|
||||||
- Console prints like `Using device=mps, dtype=...` and retrieved file paths for queries.
|
|
||||||
|
|
||||||
To use your own PDF: edit `pdf_path` near the top of the script.
|
|
||||||
|
|
||||||
### B) Similarity map + answer demo
|
|
||||||
Uses HF dataset `weaviate/arXiv-AI-papers-multi-vector` by default; can switch to local pages.
|
|
||||||
|
|
||||||
```bash
|
|
||||||
cd apps/multimodal/vision-based-pdf-multi-vector
|
|
||||||
python multi-vector-leann-similarity-map.py
|
|
||||||
```
|
|
||||||
Artifacts (when enabled):
|
|
||||||
- Retrieved pages: `./figures/retrieved_page_rank{K}.png`
|
|
||||||
- Similarity maps: `./figures/similarity_map_rank{K}.png`
|
|
||||||
|
|
||||||
Key knobs in the script (top of file):
|
|
||||||
- `QUERY`: your question
|
|
||||||
- `MODEL`: `"colqwen2"` or `"colpali"`
|
|
||||||
- `USE_HF_DATASET`: set `False` to use local pages
|
|
||||||
- `PDF`, `PAGES_DIR`: for local mode
|
|
||||||
- `INDEX_PATH`, `TOPK`, `FIRST_STAGE_K`, `REBUILD_INDEX`
|
|
||||||
- `SIMILARITY_MAP`, `SIM_TOKEN_IDX`, `SIM_OUTPUT`
|
|
||||||
- `ANSWER`, `MAX_NEW_TOKENS` (Qwen-VL)
|
|
||||||
|
|
||||||
## Troubleshooting
|
|
||||||
- pdf2image errors on macOS: ensure `brew install poppler` and `pdfinfo` works in terminal.
|
|
||||||
- Slow or OOM on MPS: reduce dataset size (e.g., set `MAX_DOCS`) or switch to CPU.
|
|
||||||
- NaNs on MPS: keep fp32 on MPS (default in similarity-map script); avoid fp16 there.
|
|
||||||
- First-run model downloads can be large; ensure network access (HF mirrors if needed).
|
|
||||||
|
|
||||||
## Notes
|
|
||||||
- Index files are under `./indexes/`. Delete or set `REBUILD_INDEX=True` to rebuild.
|
|
||||||
- For local PDFs, page images go to `./pages/`.
|
|
||||||
|
|
||||||
|
|
||||||
### Retrieval and Visualization Example
|
|
||||||
|
|
||||||
Example settings in `multi-vector-leann-similarity-map.py`:
|
|
||||||
- `QUERY = "How does DeepSeek-V2 compare against the LLaMA family of LLMs?"`
|
|
||||||
- `SIMILARITY_MAP = True` (to generate heatmaps)
|
|
||||||
- `TOPK = 1` (save the top retrieved page and its similarity map)
|
|
||||||
|
|
||||||
Run:
|
|
||||||
```bash
|
|
||||||
cd apps/multimodal/vision-based-pdf-multi-vector
|
|
||||||
python multi-vector-leann-similarity-map.py
|
|
||||||
```
|
|
||||||
|
|
||||||
Outputs (by default):
|
|
||||||
- Retrieved page: `./figures/retrieved_page_rank1.png`
|
|
||||||
- Similarity map: `./figures/similarity_map_rank1.png`
|
|
||||||
|
|
||||||
Sample visualization (example result, and the query is "QUERY = "How does Vim model performance and efficiency compared to other models?"
|
|
||||||
"):
|
|
||||||

|
|
||||||
|
|
||||||
Notes:
|
|
||||||
- Set `SIM_TOKEN_IDX` to visualize a specific token index; set `-1` to auto-select the most salient token.
|
|
||||||
- If you change `SIM_OUTPUT` to a file path (e.g., `./figures/my_map.png`), multiple ranks are saved as `my_map_rank{K}.png`.
|
|
||||||
Binary file not shown.
|
Before Width: | Height: | Size: 166 KiB |
@@ -169,7 +169,7 @@ def _embed_images(model, processor, images: list[Image.Image]) -> list[Any]:
|
|||||||
)
|
)
|
||||||
|
|
||||||
doc_vecs: list[Any] = []
|
doc_vecs: list[Any] = []
|
||||||
for batch_doc in tqdm(dataloader, desc="Embedding images"):
|
for batch_doc in dataloader:
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
batch_doc = {k: v.to(model.device) for k, v in batch_doc.items()}
|
batch_doc = {k: v.to(model.device) for k, v in batch_doc.items()}
|
||||||
# autocast on CUDA for bf16/fp16; on CPU/MPS stay in fp32
|
# autocast on CUDA for bf16/fp16; on CPU/MPS stay in fp32
|
||||||
@@ -200,7 +200,7 @@ def _embed_queries(model, processor, queries: list[str]) -> list[Any]:
|
|||||||
)
|
)
|
||||||
|
|
||||||
q_vecs: list[Any] = []
|
q_vecs: list[Any] = []
|
||||||
for batch_query in tqdm(dataloader, desc="Embedding queries"):
|
for batch_query in dataloader:
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
batch_query = {k: v.to(model.device) for k, v in batch_query.items()}
|
batch_query = {k: v.to(model.device) for k, v in batch_query.items()}
|
||||||
if model.device.type == "cuda":
|
if model.device.type == "cuda":
|
||||||
@@ -362,7 +362,7 @@ if USE_HF_DATASET:
|
|||||||
N = len(dataset) if MAX_DOCS is None else min(MAX_DOCS, len(dataset))
|
N = len(dataset) if MAX_DOCS is None else min(MAX_DOCS, len(dataset))
|
||||||
filepaths: list[str] = []
|
filepaths: list[str] = []
|
||||||
images: list[Image.Image] = []
|
images: list[Image.Image] = []
|
||||||
for i in tqdm(range(N), desc="Loading dataset", total=N ):
|
for i in tqdm(range(N), desc="Loading dataset"):
|
||||||
p = dataset[i]
|
p = dataset[i]
|
||||||
# Compose a descriptive identifier for printing later
|
# Compose a descriptive identifier for printing later
|
||||||
identifier = f"arXiv:{p['paper_arxiv_id']}|title:{p['paper_title']}|page:{int(p['page_number'])}|id:{p['page_id']}"
|
identifier = f"arXiv:{p['paper_arxiv_id']}|title:{p['paper_title']}|page:{int(p['page_number'])}|id:{p['page_id']}"
|
||||||
|
|||||||
@@ -4,24 +4,39 @@
|
|||||||
# pip install tqdm
|
# pip install tqdm
|
||||||
# pip install pillow
|
# pip install pillow
|
||||||
|
|
||||||
|
# %%
|
||||||
|
from pdf2image import convert_from_path
|
||||||
|
|
||||||
|
pdf_path = "pdfs/2004.12832v2.pdf"
|
||||||
|
images = convert_from_path(pdf_path)
|
||||||
|
|
||||||
|
for i, image in enumerate(images):
|
||||||
|
image.save(f"pages/page_{i + 1}.png", "PNG")
|
||||||
|
|
||||||
|
# %%
|
||||||
import os
|
import os
|
||||||
import re
|
|
||||||
import sys
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import cast
|
|
||||||
|
|
||||||
from PIL import Image
|
# Make local leann packages importable without installing
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
# Ensure local leann packages are importable before importing them
|
|
||||||
_repo_root = Path(__file__).resolve().parents[3]
|
_repo_root = Path(__file__).resolve().parents[3]
|
||||||
_leann_core_src = _repo_root / "packages" / "leann-core" / "src"
|
_leann_core_src = _repo_root / "packages" / "leann-core" / "src"
|
||||||
_leann_hnsw_pkg = _repo_root / "packages" / "leann-backend-hnsw"
|
_leann_hnsw_pkg = _repo_root / "packages" / "leann-backend-hnsw"
|
||||||
|
import sys
|
||||||
|
|
||||||
if str(_leann_core_src) not in sys.path:
|
if str(_leann_core_src) not in sys.path:
|
||||||
sys.path.append(str(_leann_core_src))
|
sys.path.append(str(_leann_core_src))
|
||||||
if str(_leann_hnsw_pkg) not in sys.path:
|
if str(_leann_hnsw_pkg) not in sys.path:
|
||||||
sys.path.append(str(_leann_hnsw_pkg))
|
sys.path.append(str(_leann_hnsw_pkg))
|
||||||
|
|
||||||
|
from leann_multi_vector import LeannMultiVector
|
||||||
|
|
||||||
|
|
||||||
|
class LeannRetriever(LeannMultiVector):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
# %%
|
||||||
|
from typing import cast
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from colpali_engine.models import ColPali
|
from colpali_engine.models import ColPali
|
||||||
@@ -73,6 +88,13 @@ for batch_query in dataloader:
|
|||||||
qs.extend(list(torch.unbind(embeddings_query.to("cpu"))))
|
qs.extend(list(torch.unbind(embeddings_query.to("cpu"))))
|
||||||
print(qs[0].shape)
|
print(qs[0].shape)
|
||||||
# %%
|
# %%
|
||||||
|
|
||||||
|
|
||||||
|
import re
|
||||||
|
|
||||||
|
from PIL import Image
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
page_filenames = sorted(os.listdir("./pages"), key=lambda n: int(re.search(r"\d+", n).group()))
|
page_filenames = sorted(os.listdir("./pages"), key=lambda n: int(re.search(r"\d+", n).group()))
|
||||||
images = [Image.open(os.path.join("./pages", name)) for name in page_filenames]
|
images = [Image.open(os.path.join("./pages", name)) for name in page_filenames]
|
||||||
|
|
||||||
@@ -43,11 +43,7 @@ from apps.chunking import create_text_chunks
|
|||||||
REPO_ROOT = Path(__file__).resolve().parents[1]
|
REPO_ROOT = Path(__file__).resolve().parents[1]
|
||||||
|
|
||||||
DEFAULT_QUERY = "What's LEANN?"
|
DEFAULT_QUERY = "What's LEANN?"
|
||||||
DEFAULT_INITIAL_FILES = [
|
DEFAULT_INITIAL_FILES = [REPO_ROOT / "data" / "2501.14312v1 (1).pdf"]
|
||||||
REPO_ROOT / "data" / "2501.14312v1 (1).pdf",
|
|
||||||
REPO_ROOT / "data" / "huawei_pangu.md",
|
|
||||||
REPO_ROOT / "data" / "PrideandPrejudice.txt",
|
|
||||||
]
|
|
||||||
DEFAULT_UPDATE_FILES = [REPO_ROOT / "data" / "2506.08276v1.pdf"]
|
DEFAULT_UPDATE_FILES = [REPO_ROOT / "data" / "2506.08276v1.pdf"]
|
||||||
|
|
||||||
|
|
||||||
@@ -186,7 +182,6 @@ def run_workflow(
|
|||||||
is_recompute: bool,
|
is_recompute: bool,
|
||||||
query: str,
|
query: str,
|
||||||
top_k: int,
|
top_k: int,
|
||||||
skip_search: bool,
|
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
prefix = f"[{label}] " if label else ""
|
prefix = f"[{label}] " if label else ""
|
||||||
|
|
||||||
@@ -203,15 +198,12 @@ def run_workflow(
|
|||||||
)
|
)
|
||||||
|
|
||||||
initial_size = index_file_size(index_path)
|
initial_size = index_file_size(index_path)
|
||||||
if not skip_search:
|
before_results = run_search(
|
||||||
before_results = run_search(
|
index_path,
|
||||||
index_path,
|
query,
|
||||||
query,
|
top_k,
|
||||||
top_k,
|
recompute_embeddings=is_recompute,
|
||||||
recompute_embeddings=is_recompute,
|
)
|
||||||
)
|
|
||||||
else:
|
|
||||||
before_results = None
|
|
||||||
|
|
||||||
print(f"\n{prefix}Updating index with additional passages...")
|
print(f"\n{prefix}Updating index with additional passages...")
|
||||||
update_index(
|
update_index(
|
||||||
@@ -223,23 +215,20 @@ def run_workflow(
|
|||||||
is_recompute=is_recompute,
|
is_recompute=is_recompute,
|
||||||
)
|
)
|
||||||
|
|
||||||
if not skip_search:
|
after_results = run_search(
|
||||||
after_results = run_search(
|
index_path,
|
||||||
index_path,
|
query,
|
||||||
query,
|
top_k,
|
||||||
top_k,
|
recompute_embeddings=is_recompute,
|
||||||
recompute_embeddings=is_recompute,
|
)
|
||||||
)
|
|
||||||
else:
|
|
||||||
after_results = None
|
|
||||||
updated_size = index_file_size(index_path)
|
updated_size = index_file_size(index_path)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"initial_size": initial_size,
|
"initial_size": initial_size,
|
||||||
"updated_size": updated_size,
|
"updated_size": updated_size,
|
||||||
"delta": updated_size - initial_size,
|
"delta": updated_size - initial_size,
|
||||||
"before_results": before_results if not skip_search else None,
|
"before_results": before_results,
|
||||||
"after_results": after_results if not skip_search else None,
|
"after_results": after_results,
|
||||||
"metadata": load_metadata_snapshot(index_path),
|
"metadata": load_metadata_snapshot(index_path),
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -325,12 +314,6 @@ def main() -> None:
|
|||||||
action="store_false",
|
action="store_false",
|
||||||
help="Skip building the no-recompute baseline.",
|
help="Skip building the no-recompute baseline.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
|
||||||
"--skip-search",
|
|
||||||
dest="skip_search",
|
|
||||||
action="store_true",
|
|
||||||
help="Skip the search step.",
|
|
||||||
)
|
|
||||||
parser.set_defaults(compare_no_recompute=True)
|
parser.set_defaults(compare_no_recompute=True)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
@@ -367,13 +350,10 @@ def main() -> None:
|
|||||||
is_recompute=True,
|
is_recompute=True,
|
||||||
query=args.query,
|
query=args.query,
|
||||||
top_k=args.top_k,
|
top_k=args.top_k,
|
||||||
skip_search=args.skip_search,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if not args.skip_search:
|
print_results("initial search", recompute_stats["before_results"])
|
||||||
print_results("initial search", recompute_stats["before_results"])
|
print_results("after update", recompute_stats["after_results"])
|
||||||
if not args.skip_search:
|
|
||||||
print_results("after update", recompute_stats["after_results"])
|
|
||||||
print(
|
print(
|
||||||
f"\n[recompute] Index file size change: {recompute_stats['initial_size']} -> {recompute_stats['updated_size']} bytes"
|
f"\n[recompute] Index file size change: {recompute_stats['initial_size']} -> {recompute_stats['updated_size']} bytes"
|
||||||
f" (Δ {recompute_stats['delta']})"
|
f" (Δ {recompute_stats['delta']})"
|
||||||
@@ -398,7 +378,6 @@ def main() -> None:
|
|||||||
is_recompute=False,
|
is_recompute=False,
|
||||||
query=args.query,
|
query=args.query,
|
||||||
top_k=args.top_k,
|
top_k=args.top_k,
|
||||||
skip_search=args.skip_search,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
print(
|
print(
|
||||||
@@ -406,12 +385,8 @@ def main() -> None:
|
|||||||
f" (Δ {baseline_stats['delta']})"
|
f" (Δ {baseline_stats['delta']})"
|
||||||
)
|
)
|
||||||
|
|
||||||
after_texts = (
|
after_texts = [res.text for res in recompute_stats["after_results"]]
|
||||||
[res.text for res in recompute_stats["after_results"]] if not args.skip_search else None
|
baseline_after_texts = [res.text for res in baseline_stats["after_results"]]
|
||||||
)
|
|
||||||
baseline_after_texts = (
|
|
||||||
[res.text for res in baseline_stats["after_results"]] if not args.skip_search else None
|
|
||||||
)
|
|
||||||
if after_texts == baseline_after_texts:
|
if after_texts == baseline_after_texts:
|
||||||
print(
|
print(
|
||||||
"[no-recompute] Search results match recompute baseline; see above for the shared output."
|
"[no-recompute] Search results match recompute baseline; see above for the shared output."
|
||||||
|
|||||||
Submodule packages/leann-backend-hnsw/third_party/faiss updated: 5952745237...1d51f0c074
@@ -5,7 +5,6 @@ with the correct, original embedding logic from the user's reference code.
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
|
||||||
import pickle
|
import pickle
|
||||||
import re
|
import re
|
||||||
import subprocess
|
import subprocess
|
||||||
@@ -21,7 +20,6 @@ from leann_backend_hnsw.convert_to_csr import prune_hnsw_embeddings_inplace
|
|||||||
from leann.interface import LeannBackendSearcherInterface
|
from leann.interface import LeannBackendSearcherInterface
|
||||||
|
|
||||||
from .chat import get_llm
|
from .chat import get_llm
|
||||||
from .embedding_server_manager import EmbeddingServerManager
|
|
||||||
from .interface import LeannBackendFactoryInterface
|
from .interface import LeannBackendFactoryInterface
|
||||||
from .metadata_filter import MetadataFilterEngine
|
from .metadata_filter import MetadataFilterEngine
|
||||||
from .registry import BACKEND_REGISTRY
|
from .registry import BACKEND_REGISTRY
|
||||||
@@ -730,7 +728,6 @@ class LeannBuilder:
|
|||||||
index = faiss.read_index(str(index_file))
|
index = faiss.read_index(str(index_file))
|
||||||
if hasattr(index, "is_recompute"):
|
if hasattr(index, "is_recompute"):
|
||||||
index.is_recompute = needs_recompute
|
index.is_recompute = needs_recompute
|
||||||
print(f"index.is_recompute: {index.is_recompute}")
|
|
||||||
if getattr(index, "storage", None) is None:
|
if getattr(index, "storage", None) is None:
|
||||||
if index.metric_type == faiss.METRIC_INNER_PRODUCT:
|
if index.metric_type == faiss.METRIC_INNER_PRODUCT:
|
||||||
storage_index = faiss.IndexFlatIP(index.d)
|
storage_index = faiss.IndexFlatIP(index.d)
|
||||||
@@ -738,107 +735,37 @@ class LeannBuilder:
|
|||||||
storage_index = faiss.IndexFlatL2(index.d)
|
storage_index = faiss.IndexFlatL2(index.d)
|
||||||
index.storage = storage_index
|
index.storage = storage_index
|
||||||
index.own_fields = True
|
index.own_fields = True
|
||||||
# Faiss expects storage.ntotal to reflect the existing graph's
|
|
||||||
# population (even if the vectors themselves were pruned from disk
|
|
||||||
# for recompute mode). When we attach a fresh IndexFlat here its
|
|
||||||
# ntotal starts at zero, which later causes IndexHNSW::add to
|
|
||||||
# believe new "preset" levels were provided and trips the
|
|
||||||
# `n0 + n == levels.size()` assertion. Seed the temporary storage
|
|
||||||
# with the current ntotal so Faiss maintains the proper offset for
|
|
||||||
# incoming vectors.
|
|
||||||
try:
|
|
||||||
storage_index.ntotal = index.ntotal
|
|
||||||
except AttributeError:
|
|
||||||
# Older Faiss builds may not expose ntotal as a writable
|
|
||||||
# attribute; in that case we fall back to the default behaviour.
|
|
||||||
pass
|
|
||||||
if index.d != embedding_dim:
|
if index.d != embedding_dim:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Existing index dimension ({index.d}) does not match new embeddings ({embedding_dim})."
|
f"Existing index dimension ({index.d}) does not match new embeddings ({embedding_dim})."
|
||||||
)
|
)
|
||||||
|
|
||||||
passage_meta_mode = meta.get("embedding_mode", self.embedding_mode)
|
|
||||||
passage_provider_options = meta.get("embedding_options", self.embedding_options)
|
|
||||||
|
|
||||||
base_id = index.ntotal
|
base_id = index.ntotal
|
||||||
for offset, chunk in enumerate(valid_chunks):
|
for offset, chunk in enumerate(valid_chunks):
|
||||||
new_id = str(base_id + offset)
|
new_id = str(base_id + offset)
|
||||||
chunk.setdefault("metadata", {})["id"] = new_id
|
chunk.setdefault("metadata", {})["id"] = new_id
|
||||||
chunk["id"] = new_id
|
chunk["id"] = new_id
|
||||||
|
|
||||||
# Append passages/offsets before we attempt index.add so the ZMQ server
|
index.add(embeddings.shape[0], faiss.swig_ptr(embeddings))
|
||||||
# can resolve newly assigned IDs during recompute. Keep rollback hooks
|
faiss.write_index(index, str(index_file))
|
||||||
# so we can restore files if the update fails mid-way.
|
|
||||||
rollback_passages_size = passages_file.stat().st_size if passages_file.exists() else 0
|
|
||||||
offset_map_backup = offset_map.copy()
|
|
||||||
|
|
||||||
try:
|
with open(passages_file, "a", encoding="utf-8") as f:
|
||||||
with open(passages_file, "a", encoding="utf-8") as f:
|
for chunk in valid_chunks:
|
||||||
for chunk in valid_chunks:
|
offset = f.tell()
|
||||||
offset = f.tell()
|
json.dump(
|
||||||
json.dump(
|
{
|
||||||
{
|
"id": chunk["id"],
|
||||||
"id": chunk["id"],
|
"text": chunk["text"],
|
||||||
"text": chunk["text"],
|
"metadata": chunk.get("metadata", {}),
|
||||||
"metadata": chunk.get("metadata", {}),
|
},
|
||||||
},
|
f,
|
||||||
f,
|
ensure_ascii=False,
|
||||||
ensure_ascii=False,
|
)
|
||||||
)
|
f.write("\n")
|
||||||
f.write("\n")
|
offset_map[chunk["id"]] = offset
|
||||||
offset_map[chunk["id"]] = offset
|
|
||||||
|
|
||||||
with open(offset_file, "wb") as f:
|
with open(offset_file, "wb") as f:
|
||||||
pickle.dump(offset_map, f)
|
pickle.dump(offset_map, f)
|
||||||
|
|
||||||
server_manager: Optional[EmbeddingServerManager] = None
|
|
||||||
server_started = False
|
|
||||||
requested_zmq_port = int(os.getenv("LEANN_UPDATE_ZMQ_PORT", "5557"))
|
|
||||||
|
|
||||||
try:
|
|
||||||
if needs_recompute:
|
|
||||||
server_manager = EmbeddingServerManager(
|
|
||||||
backend_module_name="leann_backend_hnsw.hnsw_embedding_server"
|
|
||||||
)
|
|
||||||
server_started, actual_port = server_manager.start_server(
|
|
||||||
port=requested_zmq_port,
|
|
||||||
model_name=self.embedding_model,
|
|
||||||
embedding_mode=passage_meta_mode,
|
|
||||||
passages_file=str(meta_path),
|
|
||||||
distance_metric=distance_metric,
|
|
||||||
provider_options=passage_provider_options,
|
|
||||||
)
|
|
||||||
if not server_started:
|
|
||||||
raise RuntimeError(
|
|
||||||
"Failed to start HNSW embedding server for recompute update."
|
|
||||||
)
|
|
||||||
if actual_port != requested_zmq_port:
|
|
||||||
server_manager.stop_server()
|
|
||||||
raise RuntimeError(
|
|
||||||
"Embedding server started on unexpected port "
|
|
||||||
f"{actual_port}; expected {requested_zmq_port}. Make sure the desired ZMQ port is free."
|
|
||||||
)
|
|
||||||
|
|
||||||
if needs_recompute:
|
|
||||||
for i in range(embeddings.shape[0]):
|
|
||||||
print(f"add {i} embeddings")
|
|
||||||
index.add(1, faiss.swig_ptr(embeddings[i : i + 1]))
|
|
||||||
else:
|
|
||||||
index.add(embeddings.shape[0], faiss.swig_ptr(embeddings))
|
|
||||||
faiss.write_index(index, str(index_file))
|
|
||||||
finally:
|
|
||||||
if server_started and server_manager is not None:
|
|
||||||
server_manager.stop_server()
|
|
||||||
|
|
||||||
except Exception:
|
|
||||||
# Roll back appended passages/offset map to keep files consistent.
|
|
||||||
if passages_file.exists():
|
|
||||||
with open(passages_file, "rb+") as f:
|
|
||||||
f.truncate(rollback_passages_size)
|
|
||||||
offset_map = offset_map_backup
|
|
||||||
with open(offset_file, "wb") as f:
|
|
||||||
pickle.dump(offset_map, f)
|
|
||||||
raise
|
|
||||||
|
|
||||||
meta["total_passages"] = len(offset_map)
|
meta["total_passages"] = len(offset_map)
|
||||||
with open(meta_path, "w", encoding="utf-8") as f:
|
with open(meta_path, "w", encoding="utf-8") as f:
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
import atexit
|
import atexit
|
||||||
import json
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import socket
|
import socket
|
||||||
@@ -49,85 +48,6 @@ def _check_port(port: int) -> bool:
|
|||||||
# Note: All cross-process scanning helpers removed for simplicity
|
# Note: All cross-process scanning helpers removed for simplicity
|
||||||
|
|
||||||
|
|
||||||
def _safe_resolve(path: Path) -> str:
|
|
||||||
"""Resolve paths safely even if the target does not yet exist."""
|
|
||||||
try:
|
|
||||||
return str(path.resolve(strict=False))
|
|
||||||
except Exception:
|
|
||||||
return str(path)
|
|
||||||
|
|
||||||
|
|
||||||
def _safe_stat_signature(path: Path) -> dict:
|
|
||||||
"""Return a lightweight signature describing the current state of a path."""
|
|
||||||
signature: dict[str, object] = {"path": _safe_resolve(path)}
|
|
||||||
try:
|
|
||||||
stat = path.stat()
|
|
||||||
except FileNotFoundError:
|
|
||||||
signature["missing"] = True
|
|
||||||
except Exception as exc: # pragma: no cover - unexpected filesystem errors
|
|
||||||
signature["error"] = str(exc)
|
|
||||||
else:
|
|
||||||
signature["mtime_ns"] = stat.st_mtime_ns
|
|
||||||
signature["size"] = stat.st_size
|
|
||||||
return signature
|
|
||||||
|
|
||||||
|
|
||||||
def _build_passages_signature(passages_file: Optional[str]) -> Optional[dict]:
|
|
||||||
"""Collect modification signatures for metadata and referenced passage files."""
|
|
||||||
if not passages_file:
|
|
||||||
return None
|
|
||||||
|
|
||||||
meta_path = Path(passages_file)
|
|
||||||
signature: dict[str, object] = {"meta": _safe_stat_signature(meta_path)}
|
|
||||||
|
|
||||||
try:
|
|
||||||
with meta_path.open(encoding="utf-8") as fh:
|
|
||||||
meta = json.load(fh)
|
|
||||||
except FileNotFoundError:
|
|
||||||
signature["meta_missing"] = True
|
|
||||||
signature["sources"] = []
|
|
||||||
return signature
|
|
||||||
except json.JSONDecodeError as exc:
|
|
||||||
signature["meta_error"] = f"json_error:{exc}"
|
|
||||||
signature["sources"] = []
|
|
||||||
return signature
|
|
||||||
except Exception as exc: # pragma: no cover - unexpected errors
|
|
||||||
signature["meta_error"] = str(exc)
|
|
||||||
signature["sources"] = []
|
|
||||||
return signature
|
|
||||||
|
|
||||||
base_dir = meta_path.parent
|
|
||||||
seen_paths: set[str] = set()
|
|
||||||
source_signatures: list[dict[str, object]] = []
|
|
||||||
|
|
||||||
for source in meta.get("passage_sources", []):
|
|
||||||
for key, kind in (
|
|
||||||
("path", "passages"),
|
|
||||||
("path_relative", "passages"),
|
|
||||||
("index_path", "index"),
|
|
||||||
("index_path_relative", "index"),
|
|
||||||
):
|
|
||||||
raw_path = source.get(key)
|
|
||||||
if not raw_path:
|
|
||||||
continue
|
|
||||||
candidate = Path(raw_path)
|
|
||||||
if not candidate.is_absolute():
|
|
||||||
candidate = base_dir / candidate
|
|
||||||
resolved = _safe_resolve(candidate)
|
|
||||||
if resolved in seen_paths:
|
|
||||||
continue
|
|
||||||
seen_paths.add(resolved)
|
|
||||||
sig = _safe_stat_signature(candidate)
|
|
||||||
sig["kind"] = kind
|
|
||||||
source_signatures.append(sig)
|
|
||||||
|
|
||||||
signature["sources"] = source_signatures
|
|
||||||
return signature
|
|
||||||
|
|
||||||
|
|
||||||
# Note: All cross-process scanning helpers removed for simplicity
|
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingServerManager:
|
class EmbeddingServerManager:
|
||||||
"""
|
"""
|
||||||
A simplified manager for embedding server processes that avoids complex update mechanisms.
|
A simplified manager for embedding server processes that avoids complex update mechanisms.
|
||||||
@@ -165,14 +85,13 @@ class EmbeddingServerManager:
|
|||||||
"""Start the embedding server."""
|
"""Start the embedding server."""
|
||||||
# passages_file may be present in kwargs for server CLI, but we don't need it here
|
# passages_file may be present in kwargs for server CLI, but we don't need it here
|
||||||
provider_options = kwargs.pop("provider_options", None)
|
provider_options = kwargs.pop("provider_options", None)
|
||||||
passages_file = kwargs.get("passages_file", "")
|
|
||||||
|
|
||||||
config_signature = self._build_config_signature(
|
config_signature = {
|
||||||
model_name=model_name,
|
"model_name": model_name,
|
||||||
embedding_mode=embedding_mode,
|
"passages_file": kwargs.get("passages_file", ""),
|
||||||
provider_options=provider_options,
|
"embedding_mode": embedding_mode,
|
||||||
passages_file=passages_file,
|
"provider_options": provider_options or {},
|
||||||
)
|
}
|
||||||
|
|
||||||
# If this manager already has a live server, just reuse it
|
# If this manager already has a live server, just reuse it
|
||||||
if (
|
if (
|
||||||
@@ -196,7 +115,6 @@ class EmbeddingServerManager:
|
|||||||
port,
|
port,
|
||||||
model_name,
|
model_name,
|
||||||
embedding_mode,
|
embedding_mode,
|
||||||
config_signature=config_signature,
|
|
||||||
provider_options=provider_options,
|
provider_options=provider_options,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
@@ -218,30 +136,11 @@ class EmbeddingServerManager:
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _build_config_signature(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
model_name: str,
|
|
||||||
embedding_mode: str,
|
|
||||||
provider_options: Optional[dict],
|
|
||||||
passages_file: Optional[str],
|
|
||||||
) -> dict:
|
|
||||||
"""Create a signature describing the current server configuration."""
|
|
||||||
return {
|
|
||||||
"model_name": model_name,
|
|
||||||
"passages_file": passages_file or "",
|
|
||||||
"embedding_mode": embedding_mode,
|
|
||||||
"provider_options": provider_options or {},
|
|
||||||
"passages_signature": _build_passages_signature(passages_file),
|
|
||||||
}
|
|
||||||
|
|
||||||
def _start_server_colab(
|
def _start_server_colab(
|
||||||
self,
|
self,
|
||||||
port: int,
|
port: int,
|
||||||
model_name: str,
|
model_name: str,
|
||||||
embedding_mode: str = "sentence-transformers",
|
embedding_mode: str = "sentence-transformers",
|
||||||
*,
|
|
||||||
config_signature: Optional[dict] = None,
|
|
||||||
provider_options: Optional[dict] = None,
|
provider_options: Optional[dict] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> tuple[bool, int]:
|
) -> tuple[bool, int]:
|
||||||
@@ -264,11 +163,10 @@ class EmbeddingServerManager:
|
|||||||
command,
|
command,
|
||||||
actual_port,
|
actual_port,
|
||||||
provider_options=provider_options,
|
provider_options=provider_options,
|
||||||
config_signature=config_signature,
|
|
||||||
)
|
)
|
||||||
started, ready_port = self._wait_for_server_ready_colab(actual_port)
|
started, ready_port = self._wait_for_server_ready_colab(actual_port)
|
||||||
if started:
|
if started:
|
||||||
self._server_config = config_signature or {
|
self._server_config = {
|
||||||
"model_name": model_name,
|
"model_name": model_name,
|
||||||
"passages_file": kwargs.get("passages_file", ""),
|
"passages_file": kwargs.get("passages_file", ""),
|
||||||
"embedding_mode": embedding_mode,
|
"embedding_mode": embedding_mode,
|
||||||
@@ -300,7 +198,6 @@ class EmbeddingServerManager:
|
|||||||
command,
|
command,
|
||||||
port,
|
port,
|
||||||
provider_options=provider_options,
|
provider_options=provider_options,
|
||||||
config_signature=config_signature,
|
|
||||||
)
|
)
|
||||||
started, ready_port = self._wait_for_server_ready(port)
|
started, ready_port = self._wait_for_server_ready(port)
|
||||||
if started:
|
if started:
|
||||||
@@ -344,9 +241,7 @@ class EmbeddingServerManager:
|
|||||||
self,
|
self,
|
||||||
command: list,
|
command: list,
|
||||||
port: int,
|
port: int,
|
||||||
*,
|
|
||||||
provider_options: Optional[dict] = None,
|
provider_options: Optional[dict] = None,
|
||||||
config_signature: Optional[dict] = None,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Launch the server process."""
|
"""Launch the server process."""
|
||||||
project_root = Path(__file__).parent.parent.parent.parent.parent
|
project_root = Path(__file__).parent.parent.parent.parent.parent
|
||||||
@@ -381,29 +276,26 @@ class EmbeddingServerManager:
|
|||||||
)
|
)
|
||||||
self.server_port = port
|
self.server_port = port
|
||||||
# Record config for in-process reuse (best effort; refined later when ready)
|
# Record config for in-process reuse (best effort; refined later when ready)
|
||||||
if config_signature is not None:
|
try:
|
||||||
self._server_config = config_signature
|
self._server_config = {
|
||||||
else: # Fallback for unexpected code paths
|
"model_name": command[command.index("--model-name") + 1]
|
||||||
try:
|
if "--model-name" in command
|
||||||
self._server_config = {
|
else "",
|
||||||
"model_name": command[command.index("--model-name") + 1]
|
"passages_file": command[command.index("--passages-file") + 1]
|
||||||
if "--model-name" in command
|
if "--passages-file" in command
|
||||||
else "",
|
else "",
|
||||||
"passages_file": command[command.index("--passages-file") + 1]
|
"embedding_mode": command[command.index("--embedding-mode") + 1]
|
||||||
if "--passages-file" in command
|
if "--embedding-mode" in command
|
||||||
else "",
|
else "sentence-transformers",
|
||||||
"embedding_mode": command[command.index("--embedding-mode") + 1]
|
"provider_options": provider_options or {},
|
||||||
if "--embedding-mode" in command
|
}
|
||||||
else "sentence-transformers",
|
except Exception:
|
||||||
"provider_options": provider_options or {},
|
self._server_config = {
|
||||||
}
|
"model_name": "",
|
||||||
except Exception:
|
"passages_file": "",
|
||||||
self._server_config = {
|
"embedding_mode": "sentence-transformers",
|
||||||
"model_name": "",
|
"provider_options": provider_options or {},
|
||||||
"passages_file": "",
|
}
|
||||||
"embedding_mode": "sentence-transformers",
|
|
||||||
"provider_options": provider_options or {},
|
|
||||||
}
|
|
||||||
logger.info(f"Server process started with PID: {self.server_process.pid}")
|
logger.info(f"Server process started with PID: {self.server_process.pid}")
|
||||||
|
|
||||||
# Register atexit callback only when we actually start a process
|
# Register atexit callback only when we actually start a process
|
||||||
@@ -511,9 +403,7 @@ class EmbeddingServerManager:
|
|||||||
self,
|
self,
|
||||||
command: list,
|
command: list,
|
||||||
port: int,
|
port: int,
|
||||||
*,
|
|
||||||
provider_options: Optional[dict] = None,
|
provider_options: Optional[dict] = None,
|
||||||
config_signature: Optional[dict] = None,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Launch the server process with Colab-specific settings."""
|
"""Launch the server process with Colab-specific settings."""
|
||||||
logger.info(f"Colab Command: {' '.join(command)}")
|
logger.info(f"Colab Command: {' '.join(command)}")
|
||||||
@@ -539,15 +429,12 @@ class EmbeddingServerManager:
|
|||||||
atexit.register(self._finalize_process)
|
atexit.register(self._finalize_process)
|
||||||
self._atexit_registered = True
|
self._atexit_registered = True
|
||||||
# Record config for in-process reuse is best-effort in Colab mode
|
# Record config for in-process reuse is best-effort in Colab mode
|
||||||
if config_signature is not None:
|
self._server_config = {
|
||||||
self._server_config = config_signature
|
"model_name": "",
|
||||||
else:
|
"passages_file": "",
|
||||||
self._server_config = {
|
"embedding_mode": "sentence-transformers",
|
||||||
"model_name": "",
|
"provider_options": provider_options or {},
|
||||||
"passages_file": "",
|
}
|
||||||
"embedding_mode": "sentence-transformers",
|
|
||||||
"provider_options": provider_options or {},
|
|
||||||
}
|
|
||||||
|
|
||||||
def _wait_for_server_ready_colab(self, port: int) -> tuple[bool, int]:
|
def _wait_for_server_ready_colab(self, port: int) -> tuple[bool, int]:
|
||||||
"""Wait for the server to be ready with Colab-specific timeout."""
|
"""Wait for the server to be ready with Colab-specific timeout."""
|
||||||
|
|||||||
@@ -111,7 +111,7 @@ target-version = "py39"
|
|||||||
line-length = 100
|
line-length = 100
|
||||||
extend-exclude = [
|
extend-exclude = [
|
||||||
"third_party",
|
"third_party",
|
||||||
"apps/multimodal/vision-based-pdf-multi-vector/multi-vector-leann-paper-example.py",
|
"apps/multimodal/vision-based-pdf-multi-vector/multi-vector-leann.py",
|
||||||
"apps/multimodal/vision-based-pdf-multi-vector/multi-vector-leann-similarity-map.py"
|
"apps/multimodal/vision-based-pdf-multi-vector/multi-vector-leann-similarity-map.py"
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
@@ -1,137 +0,0 @@
|
|||||||
import json
|
|
||||||
import time
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
from leann.embedding_server_manager import EmbeddingServerManager
|
|
||||||
|
|
||||||
|
|
||||||
class DummyProcess:
|
|
||||||
def __init__(self):
|
|
||||||
self.pid = 12345
|
|
||||||
self._terminated = False
|
|
||||||
|
|
||||||
def poll(self):
|
|
||||||
return 0 if self._terminated else None
|
|
||||||
|
|
||||||
def terminate(self):
|
|
||||||
self._terminated = True
|
|
||||||
|
|
||||||
def kill(self):
|
|
||||||
self._terminated = True
|
|
||||||
|
|
||||||
def wait(self, timeout=None):
|
|
||||||
self._terminated = True
|
|
||||||
return 0
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def embedding_manager(monkeypatch):
|
|
||||||
manager = EmbeddingServerManager("leann_backend_hnsw.hnsw_embedding_server")
|
|
||||||
|
|
||||||
def fake_get_available_port(start_port):
|
|
||||||
return start_port
|
|
||||||
|
|
||||||
monkeypatch.setattr(
|
|
||||||
"leann.embedding_server_manager._get_available_port",
|
|
||||||
fake_get_available_port,
|
|
||||||
)
|
|
||||||
|
|
||||||
start_calls = []
|
|
||||||
|
|
||||||
def fake_start_new_server(self, port, model_name, embedding_mode, **kwargs):
|
|
||||||
config_signature = kwargs.get("config_signature")
|
|
||||||
start_calls.append(config_signature)
|
|
||||||
self.server_process = DummyProcess()
|
|
||||||
self.server_port = port
|
|
||||||
self._server_config = config_signature
|
|
||||||
return True, port
|
|
||||||
|
|
||||||
monkeypatch.setattr(
|
|
||||||
EmbeddingServerManager,
|
|
||||||
"_start_new_server",
|
|
||||||
fake_start_new_server,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Ensure stop_server doesn't try to operate on real subprocesses
|
|
||||||
def fake_stop_server(self):
|
|
||||||
self.server_process = None
|
|
||||||
self.server_port = None
|
|
||||||
self._server_config = None
|
|
||||||
|
|
||||||
monkeypatch.setattr(EmbeddingServerManager, "stop_server", fake_stop_server)
|
|
||||||
|
|
||||||
return manager, start_calls
|
|
||||||
|
|
||||||
|
|
||||||
def _write_meta(meta_path, passages_name, index_name, total):
|
|
||||||
meta_path.write_text(
|
|
||||||
json.dumps(
|
|
||||||
{
|
|
||||||
"backend_name": "hnsw",
|
|
||||||
"embedding_model": "test-model",
|
|
||||||
"embedding_mode": "sentence-transformers",
|
|
||||||
"dimensions": 3,
|
|
||||||
"backend_kwargs": {},
|
|
||||||
"passage_sources": [
|
|
||||||
{
|
|
||||||
"type": "jsonl",
|
|
||||||
"path": passages_name,
|
|
||||||
"index_path": index_name,
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"total_passages": total,
|
|
||||||
}
|
|
||||||
),
|
|
||||||
encoding="utf-8",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_server_restarts_when_metadata_changes(tmp_path, embedding_manager):
|
|
||||||
manager, start_calls = embedding_manager
|
|
||||||
|
|
||||||
meta_path = tmp_path / "example.meta.json"
|
|
||||||
passages_path = tmp_path / "example.passages.jsonl"
|
|
||||||
index_path = tmp_path / "example.passages.idx"
|
|
||||||
|
|
||||||
passages_path.write_text("first\n", encoding="utf-8")
|
|
||||||
index_path.write_bytes(b"index")
|
|
||||||
_write_meta(meta_path, passages_path.name, index_path.name, total=1)
|
|
||||||
|
|
||||||
# Initial start populates signature
|
|
||||||
ok, port = manager.start_server(
|
|
||||||
port=6000,
|
|
||||||
model_name="test-model",
|
|
||||||
passages_file=str(meta_path),
|
|
||||||
)
|
|
||||||
assert ok
|
|
||||||
assert port == 6000
|
|
||||||
assert len(start_calls) == 1
|
|
||||||
|
|
||||||
initial_signature = start_calls[0]["passages_signature"]
|
|
||||||
|
|
||||||
# No metadata change => reuse existing server
|
|
||||||
ok, port_again = manager.start_server(
|
|
||||||
port=6000,
|
|
||||||
model_name="test-model",
|
|
||||||
passages_file=str(meta_path),
|
|
||||||
)
|
|
||||||
assert ok
|
|
||||||
assert port_again == 6000
|
|
||||||
assert len(start_calls) == 1
|
|
||||||
|
|
||||||
# Modify passage data and metadata to force signature change
|
|
||||||
time.sleep(0.01) # Ensure filesystem timestamps move forward
|
|
||||||
passages_path.write_text("second\n", encoding="utf-8")
|
|
||||||
_write_meta(meta_path, passages_path.name, index_path.name, total=2)
|
|
||||||
|
|
||||||
ok, port_third = manager.start_server(
|
|
||||||
port=6000,
|
|
||||||
model_name="test-model",
|
|
||||||
passages_file=str(meta_path),
|
|
||||||
)
|
|
||||||
assert ok
|
|
||||||
assert port_third == 6000
|
|
||||||
assert len(start_calls) == 2
|
|
||||||
|
|
||||||
updated_signature = start_calls[1]["passages_signature"]
|
|
||||||
assert updated_signature != initial_signature
|
|
||||||
Reference in New Issue
Block a user