fix: cache the loaded model

commit b3970793cf (parent 727724990e)
Author: Andy Lee
Date: 2025-07-21 21:20:53 -07:00

9 changed files with 163 additions and 146 deletions


@@ -29,7 +29,6 @@ logger = logging.getLogger(__name__)
 def create_hnsw_embedding_server(
     passages_file: Optional[str] = None,
-    passages_data: Optional[Dict[str, str]] = None,
     zmq_port: int = 5555,
     model_name: str = "sentence-transformers/all-mpnet-base-v2",
     distance_metric: str = "mips",
@@ -39,12 +38,6 @@ def create_hnsw_embedding_server(
     Create and start a ZMQ-based embedding server for HNSW backend.
     Simplified version using unified embedding computation module.
     """
-    # Auto-detect mode based on model name if not explicitly set
-    if embedding_mode == "sentence-transformers" and model_name.startswith(
-        "text-embedding-"
-    ):
-        embedding_mode = "openai"
     print(f"Starting HNSW server on port {zmq_port} with model {model_name}")
     print(f"Using embedding mode: {embedding_mode}")
@@ -64,6 +57,7 @@ def create_hnsw_embedding_server(
     finally:
         sys.path.pop(0)

     # Check port availability
     import socket
@@ -78,13 +72,15 @@ def create_hnsw_embedding_server(
     # Only support metadata file, fail fast for everything else
     if not passages_file or not passages_file.endswith(".meta.json"):
         raise ValueError("Only metadata files (.meta.json) are supported")

     # Load metadata to get passage sources
     with open(passages_file, "r") as f:
         meta = json.load(f)

     passages = PassageManager(meta["passage_sources"])
-    print(f"Loaded PassageManager with {len(passages.global_offset_map)} passages from metadata")
+    print(
+        f"Loaded PassageManager with {len(passages.global_offset_map)} passages from metadata"
+    )

     def zmq_server_thread():
         """ZMQ server thread"""
@@ -112,7 +108,7 @@ def create_hnsw_embedding_server(
                     f"Processing direct text embedding request for {len(request_payload)} texts in {embedding_mode} mode"
                 )

-                # Use unified embedding computation
+                # Use unified embedding computation (now with model caching)
                 embeddings = compute_embeddings(
                     request_payload, model_name, mode=embedding_mode
                 )
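The caching referenced in the updated comment lives in the unified embedding computation module, which this hunk only calls into. A minimal sketch of that idea, assuming compute_embeddings wraps sentence-transformers; nothing beyond the names compute_embeddings, model_name, and mode is taken from the real module:

from functools import lru_cache
from typing import List

import numpy as np
from sentence_transformers import SentenceTransformer


@lru_cache(maxsize=4)
def _get_model(model_name: str) -> SentenceTransformer:
    # Loading a SentenceTransformer is expensive; cache one instance per model name
    # so repeated requests reuse the already-loaded model.
    return SentenceTransformer(model_name)


def compute_embeddings(texts: List[str], model_name: str, mode: str = "sentence-transformers") -> np.ndarray:
    # Sketch only: covers just the local sentence-transformers path.
    if mode != "sentence-transformers":
        raise NotImplementedError("sketch does not cover other embedding modes")
    model = _get_model(model_name)  # cache hit after the first call
    return model.encode(texts, convert_to_numpy=True)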
@@ -148,15 +144,15 @@ def create_hnsw_embedding_server(
                             texts.append(txt)
                         except KeyError:
                             print(f"ERROR: Passage ID {nid} not found")
-                            raise RuntimeError(f"FATAL: Passage with ID {nid} not found")
+                            raise RuntimeError(
+                                f"FATAL: Passage with ID {nid} not found"
+                            )
                         except Exception as e:
                             print(f"ERROR: Exception looking up passage ID {nid}: {e}")
                             raise

                     # Process embeddings
-                    embeddings = compute_embeddings(
-                        texts, model_name, mode=embedding_mode
-                    )
+                    embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
                     print(
                         f"INFO: Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
                     )
@@ -204,7 +200,9 @@ def create_hnsw_embedding_server(
                             passage_data = passages.get_passage(str(nid))
                             txt = passage_data["text"]
                             if not txt:
-                                raise RuntimeError(f"FATAL: Empty text for passage ID {nid}")
+                                raise RuntimeError(
+                                    f"FATAL: Empty text for passage ID {nid}"
+                                )
                             texts.append(txt)
                         except KeyError:
                             raise RuntimeError(f"FATAL: Passage with ID {nid} not found")