fix: cache the loaded model
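
Load the embedding model once and reuse it across requests instead of re-loading it per batch: compute_embeddings now caches the loaded model. While touching these paths, the HNSW backend also switches its print calls to module-level loggers, renames the search-time zmq_port parameter to expected_zmq_port, defaults recompute_embeddings to True (failing fast when a pruned index is searched without recompute), and drops the unused passages_data parameter and embedding-mode auto-detection from the server entry point.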

Andy Lee
2025-07-21 21:20:53 -07:00
parent 727724990e
commit b3970793cf
9 changed files with 163 additions and 146 deletions

View File

@@ -1,10 +1,9 @@
 import numpy as np
 import os
 from pathlib import Path
-from typing import Dict, Any, List, Literal
-import pickle
+from typing import Dict, Any, List, Literal, Optional
 import shutil
 import time
 import logging
 from leann.searcher_base import BaseSearcher
 from .convert_to_csr import convert_hnsw_graph_to_csr
@@ -16,6 +15,8 @@ from leann.interface import (
     LeannBackendSearcherInterface,
 )
+
+logger = logging.getLogger(__name__)
 
 def get_metric_map():
     from . import faiss  # type: ignore
@@ -57,9 +58,9 @@ class HNSWBuilder(LeannBackendBuilderInterface):
         index_dir.mkdir(parents=True, exist_ok=True)
 
         if data.dtype != np.float32:
             logger.warning(f"Converting data to float32, shape: {data.shape}")
             data = data.astype(np.float32)
 
         metric_enum = get_metric_map().get(self.distance_metric.lower())
         if metric_enum is None:
             raise ValueError(f"Unsupported distance_metric '{self.distance_metric}'.")
@@ -81,7 +82,7 @@ class HNSWBuilder(LeannBackendBuilderInterface):
     def _convert_to_csr(self, index_file: Path):
         """Convert built index to CSR format"""
         mode_str = "CSR-pruned" if self.is_recompute else "CSR-standard"
-        print(f"INFO: Converting HNSW index to {mode_str} format...")
+        logger.info(f"Converting HNSW index to {mode_str} format...")
 
         csr_temp_file = index_file.with_suffix(".csr.tmp")
@@ -90,11 +91,11 @@ class HNSWBuilder(LeannBackendBuilderInterface):
         )
 
         if success:
-            print("✅ CSR conversion successful.")
+            logger.info("✅ CSR conversion successful.")
             index_file_old = index_file.with_suffix(".old")
             shutil.move(str(index_file), str(index_file_old))
             shutil.move(str(csr_temp_file), str(index_file))
-            print(
+            logger.info(
                 f"Replaced original index with {mode_str} version at '{index_file}'"
             )
         else:
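
Since these messages now go through `logging` rather than `print`, they are silent unless the host application configures logging. A minimal setup on the caller's side (an assumption about the embedding application, not part of this commit):

```python
import logging

# Enable INFO-level output from the leann backends; the logger names
# follow from the logging.getLogger(__name__) calls added above.
logging.basicConfig(level=logging.INFO, format="%(levelname)s %(name)s: %(message)s")
```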
@@ -131,13 +132,11 @@ class HNSWSearcher(BaseSearcher):
         hnsw_config = faiss.HNSWIndexConfig()
         hnsw_config.is_compact = self.is_compact
-        hnsw_config.is_recompute = self.is_pruned or kwargs.get("is_recompute", False)
-        if self.is_pruned and not hnsw_config.is_recompute:
-            raise RuntimeError("Index is pruned but recompute is disabled.")
+        hnsw_config.is_recompute = (
+            self.is_pruned
+        )  # The C++ side calls this is_recompute, but as far as I understand it only affects loading.
         self._index = faiss.read_index(str(index_file), faiss.IO_FLAG_MMAP, hnsw_config)
 
     def search(
         self,
@@ -146,9 +145,9 @@ class HNSWSearcher(BaseSearcher):
         complexity: int = 64,
         beam_width: int = 1,
         prune_ratio: float = 0.0,
-        recompute_embeddings: bool = False,
+        recompute_embeddings: bool = True,
         pruning_strategy: Literal["global", "local", "proportional"] = "global",
-        zmq_port: int = 5557,
+        expected_zmq_port: Optional[int] = None,
         batch_size: int = 0,
         **kwargs,
     ) -> Dict[str, Any]:
@@ -166,7 +165,7 @@ class HNSWSearcher(BaseSearcher):
                 - "global": Use global PQ queue size for selection (default)
                 - "local": Local pruning, sort and select best candidates
                 - "proportional": Base selection on new neighbor count ratio
-            zmq_port: ZMQ port for embedding server
+            expected_zmq_port: ZMQ port for embedding server
             batch_size: Neighbor processing batch size, 0=disabled (HNSW-specific)
             **kwargs: Additional HNSW-specific parameters (for legacy compatibility)
@@ -175,15 +174,9 @@ class HNSWSearcher(BaseSearcher):
"""
from . import faiss # type: ignore
# Use recompute_embeddings parameter
use_recompute = recompute_embeddings or self.is_pruned
if use_recompute:
meta_file_path = self.index_dir / f"{self.index_path.name}.meta.json"
if not meta_file_path.exists():
raise RuntimeError(
f"FATAL: Recompute enabled but metadata file not found: {meta_file_path}"
)
self._ensure_server_running(str(meta_file_path), port=zmq_port, **kwargs)
if not recompute_embeddings:
if self.is_pruned:
raise RuntimeError("Recompute is required for pruned index.")
if query.dtype != np.float32:
query = query.astype(np.float32)
@@ -191,7 +184,7 @@ class HNSWSearcher(BaseSearcher):
             faiss.normalize_L2(query)
 
         params = faiss.SearchParametersHNSW()
-        params.zmq_port = zmq_port
+        params.zmq_port = expected_zmq_port
         params.efSearch = complexity
         params.beam_size = beam_width
@@ -228,8 +221,7 @@ class HNSWSearcher(BaseSearcher):
         )
 
         string_labels = [
-            [str(int_label) for int_label in batch_labels]
-            for batch_labels in labels
+            [str(int_label) for int_label in batch_labels] for batch_labels in labels
         ]
         return {"labels": string_labels, "distances": distances}
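
Taken together, the searcher-side changes shift responsibility to the caller: recomputation is on by default, the embedding server is no longer auto-started here, and the port parameter is the optional expected_zmq_port. A call-site sketch against the new signature (the module path, constructor argument, port, query shape, and top_k name are illustrative assumptions, not taken from this commit):

```python
import numpy as np
from leann_backend_hnsw.hnsw_backend import HNSWSearcher  # hypothetical module path

# Assumed constructor: the diff only shows the index file being read in __init__.
searcher = HNSWSearcher("indexes/docs.index")

query = np.random.rand(1, 768).astype(np.float32)  # dims depend on the model
results = searcher.search(
    query,
    top_k=10,                     # assumed parameter name, elided from this hunk
    complexity=64,
    recompute_embeddings=True,    # now the default
    expected_zmq_port=5557,       # caller supplies the already-running server's port
)
print(results["labels"], results["distances"])
```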

View File

@@ -29,7 +29,6 @@ logger = logging.getLogger(__name__)
 def create_hnsw_embedding_server(
     passages_file: Optional[str] = None,
-    passages_data: Optional[Dict[str, str]] = None,
     zmq_port: int = 5555,
     model_name: str = "sentence-transformers/all-mpnet-base-v2",
     distance_metric: str = "mips",
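
With passages_data gone, the server is driven entirely by the metadata file. A launch sketch using only the keyword names visible in this signature (the meta-file path is illustrative):

```python
# Hypothetical invocation; parameter names are those shown in the diff.
create_hnsw_embedding_server(
    passages_file="indexes/docs.index.meta.json",  # must end with .meta.json
    zmq_port=5555,
    model_name="sentence-transformers/all-mpnet-base-v2",
    distance_metric="mips",
)
```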
@@ -39,12 +38,6 @@ def create_hnsw_embedding_server(
     Create and start a ZMQ-based embedding server for HNSW backend.
     Simplified version using unified embedding computation module.
     """
-    # Auto-detect mode based on model name if not explicitly set
-    if embedding_mode == "sentence-transformers" and model_name.startswith(
-        "text-embedding-"
-    ):
-        embedding_mode = "openai"
-
     print(f"Starting HNSW server on port {zmq_port} with model {model_name}")
     print(f"Using embedding mode: {embedding_mode}")
@@ -64,6 +57,7 @@ def create_hnsw_embedding_server(
finally:
sys.path.pop(0)
# Check port availability
import socket
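
The new comment labels an existing socket-based check; the check itself falls outside the visible hunk. A typical version of such a probe (a sketch, not the project's exact code):

```python
import socket

def port_is_free(port: int, host: str = "127.0.0.1") -> bool:
    # Try to bind; failure means something is already listening on the port.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        try:
            s.bind((host, port))
            return True
        except OSError:
            return False
```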
@@ -78,13 +72,15 @@ def create_hnsw_embedding_server(
     # Only support metadata file, fail fast for everything else
     if not passages_file or not passages_file.endswith(".meta.json"):
         raise ValueError("Only metadata files (.meta.json) are supported")
 
     # Load metadata to get passage sources
     with open(passages_file, "r") as f:
         meta = json.load(f)
     passages = PassageManager(meta["passage_sources"])
-    print(f"Loaded PassageManager with {len(passages.global_offset_map)} passages from metadata")
+    print(
+        f"Loaded PassageManager with {len(passages.global_offset_map)} passages from metadata"
+    )
 
     def zmq_server_thread():
         """ZMQ server thread"""
@@ -112,7 +108,7 @@ def create_hnsw_embedding_server(
                         f"Processing direct text embedding request for {len(request_payload)} texts in {embedding_mode} mode"
                     )
 
-                    # Use unified embedding computation
+                    # Use unified embedding computation (now with model caching)
                     embeddings = compute_embeddings(
                         request_payload, model_name, mode=embedding_mode
                     )
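
The caching that gives this commit its title happens inside compute_embeddings, which is defined elsewhere in the changeset and not shown in this view. A minimal sketch of the idea — a process-level cache keyed by model name so repeated requests reuse the loaded model (names and structure are assumptions, not the project's actual code):

```python
from functools import lru_cache

from sentence_transformers import SentenceTransformer

@lru_cache(maxsize=4)
def _load_model(model_name: str) -> SentenceTransformer:
    # Loading is expensive; cache so each model is built once per process.
    return SentenceTransformer(model_name)

def compute_embeddings_sketch(texts, model_name, mode="sentence-transformers"):
    # Sketch covers only the sentence-transformers path; the real function
    # also dispatches on mode (e.g. "openai").
    model = _load_model(model_name)
    return model.encode(texts, convert_to_numpy=True)
```

lru_cache keeps the mapping from model name to loaded model for the lifetime of the process, which is exactly what a long-running embedding server wants.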
@@ -148,15 +144,15 @@ def create_hnsw_embedding_server(
                             texts.append(txt)
                         except KeyError:
                             print(f"ERROR: Passage ID {nid} not found")
-                            raise RuntimeError(f"FATAL: Passage with ID {nid} not found")
+                            raise RuntimeError(
+                                f"FATAL: Passage with ID {nid} not found"
+                            )
                         except Exception as e:
                             print(f"ERROR: Exception looking up passage ID {nid}: {e}")
                             raise
 
                     # Process embeddings
-                    embeddings = compute_embeddings(
-                        texts, model_name, mode=embedding_mode
-                    )
+                    embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
                     print(
                         f"INFO: Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
                     )
@@ -204,7 +200,9 @@ def create_hnsw_embedding_server(
                                 passage_data = passages.get_passage(str(nid))
                                 txt = passage_data["text"]
                                 if not txt:
-                                    raise RuntimeError(f"FATAL: Empty text for passage ID {nid}")
+                                    raise RuntimeError(
+                                        f"FATAL: Empty text for passage ID {nid}"
+                                    )
                                 texts.append(txt)
                             except KeyError:
                                 raise RuntimeError(f"FATAL: Passage with ID {nid} not found")