fix: cache the loaded model
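
Embedding computation for both direct-text and passage-ID requests goes through the unified compute_embeddings helper, which now caches the loaded model between calls instead of reloading it for every request. The rest of the diff drops the local embedding-mode auto-detection block and reflows a few long calls.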
@@ -29,7 +29,6 @@ logger = logging.getLogger(__name__)
 def create_hnsw_embedding_server(
     passages_file: Optional[str] = None,
     passages_data: Optional[Dict[str, str]] = None,
     zmq_port: int = 5555,
     model_name: str = "sentence-transformers/all-mpnet-base-v2",
     distance_metric: str = "mips",
@@ -39,12 +38,6 @@ def create_hnsw_embedding_server(
     Create and start a ZMQ-based embedding server for HNSW backend.
     Simplified version using unified embedding computation module.
     """
-    # Auto-detect mode based on model name if not explicitly set
-    if embedding_mode == "sentence-transformers" and model_name.startswith(
-        "text-embedding-"
-    ):
-        embedding_mode = "openai"
-
     print(f"Starting HNSW server on port {zmq_port} with model {model_name}")
     print(f"Using embedding mode: {embedding_mode}")
 
@@ -64,6 +57,7 @@ def create_hnsw_embedding_server(
     finally:
         sys.path.pop(0)
 
 
     # Check port availability
     import socket
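
The availability check that follows import socket is not captured in this hunk; a bind probe is the usual approach. A minimal sketch under that assumption (the helper name is_port_free is illustrative, not from this commit):

    import socket

    def is_port_free(port: int, host: str = "127.0.0.1") -> bool:
        """Probe a TCP port by binding it; binding fails if the port is in use."""
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            try:
                s.bind((host, port))
                return True
            except OSError:
                return False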
@@ -78,13 +72,15 @@ def create_hnsw_embedding_server(
     # Only support metadata file, fail fast for everything else
     if not passages_file or not passages_file.endswith(".meta.json"):
         raise ValueError("Only metadata files (.meta.json) are supported")
 
     # Load metadata to get passage sources
     with open(passages_file, "r") as f:
         meta = json.load(f)
 
     passages = PassageManager(meta["passage_sources"])
-    print(f"Loaded PassageManager with {len(passages.global_offset_map)} passages from metadata")
+    print(
+        f"Loaded PassageManager with {len(passages.global_offset_map)} passages from metadata"
+    )
 
     def zmq_server_thread():
         """ZMQ server thread"""
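
For context, zmq_server_thread serves embedding requests over a ZMQ socket bound to zmq_port, presumably in a request/reply pattern. A minimal sketch of that shape using pyzmq (the payload handling is illustrative; the real handler dispatches to compute_embeddings):

    import threading
    import zmq

    def zmq_server_thread(zmq_port: int = 5555) -> None:
        """Reply to embedding requests on a REP socket, one at a time."""
        ctx = zmq.Context.instance()
        sock = ctx.socket(zmq.REP)
        sock.bind(f"tcp://*:{zmq_port}")
        while True:
            request_payload = sock.recv_json()  # e.g. a list of texts or passage IDs
            # ... look up passages / compute embeddings here ...
            sock.send_json({"count": len(request_payload)})

    threading.Thread(target=zmq_server_thread, daemon=True).start()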
@@ -112,7 +108,7 @@ def create_hnsw_embedding_server(
                 f"Processing direct text embedding request for {len(request_payload)} texts in {embedding_mode} mode"
             )
 
-            # Use unified embedding computation
+            # Use unified embedding computation (now with model caching)
             embeddings = compute_embeddings(
                 request_payload, model_name, mode=embedding_mode
             )
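
The "(now with model caching)" comment refers to the fix in the commit title: compute_embeddings keeps the loaded model around instead of re-instantiating it per request. That function lives in the unified embedding module and is not part of this diff; a typical shape for such a cache, assuming the sentence-transformers path (sketch only, names illustrative):

    from sentence_transformers import SentenceTransformer

    # Module-level cache: model name -> loaded model. Loading the model once
    # and reusing it keeps per-request cost down to the encode call itself.
    _MODEL_CACHE: dict = {}

    def _get_model(model_name: str) -> SentenceTransformer:
        if model_name not in _MODEL_CACHE:
            _MODEL_CACHE[model_name] = SentenceTransformer(model_name)
        return _MODEL_CACHE[model_name]

    def compute_embeddings(texts, model_name, mode="sentence-transformers"):
        # Sketch of the sentence-transformers branch only.
        model = _get_model(model_name)
        return model.encode(texts)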
@@ -148,15 +144,15 @@ def create_hnsw_embedding_server(
                     texts.append(txt)
                 except KeyError:
                     print(f"ERROR: Passage ID {nid} not found")
-                    raise RuntimeError(f"FATAL: Passage with ID {nid} not found")
+                    raise RuntimeError(
+                        f"FATAL: Passage with ID {nid} not found"
+                    )
                 except Exception as e:
                     print(f"ERROR: Exception looking up passage ID {nid}: {e}")
                     raise
 
             # Process embeddings
-            embeddings = compute_embeddings(
-                texts, model_name, mode=embedding_mode
-            )
+            embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
             print(
                 f"INFO: Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
             )
@@ -204,7 +200,9 @@ def create_hnsw_embedding_server(
                     passage_data = passages.get_passage(str(nid))
                     txt = passage_data["text"]
                     if not txt:
-                        raise RuntimeError(f"FATAL: Empty text for passage ID {nid}")
+                        raise RuntimeError(
+                            f"FATAL: Empty text for passage ID {nid}"
+                        )
                     texts.append(txt)
                 except KeyError:
                     raise RuntimeError(f"FATAL: Passage with ID {nid} not found")