fix: cache the loaded model

Andy Lee
2025-07-21 21:20:53 -07:00
parent 727724990e
commit b3970793cf
9 changed files with 163 additions and 146 deletions


@@ -5,7 +5,9 @@ with the correct, original embedding logic from the user's reference code.
 import json
 import pickle
+from leann.interface import LeannBackendSearcherInterface
 import numpy as np
+import time
 from pathlib import Path
 from typing import List, Dict, Any, Optional, Literal
 from dataclasses import dataclass, field
@@ -126,6 +128,7 @@ class PassageManager:
     def get_passage(self, passage_id: str) -> Dict[str, Any]:
         if passage_id in self.global_offset_map:
             passage_file, offset = self.global_offset_map[passage_id]
+            # Lazy file opening - only open when needed
             with open(passage_file, "r", encoding="utf-8") as f:
                 f.seek(offset)
                 return json.loads(f.readline())
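For context, get_passage() relies on an offset map built at index time: each passage ID maps to a (file, byte offset) pair, so a lookup seeks directly to one line instead of holding files open or loading every passage into memory. A minimal sketch of how such a map could be built, assuming passages are stored one JSON object per line with an "id" field; build_offset_map is illustrative, not part of this diff:

import json
from typing import Dict, Tuple

def build_offset_map(passage_file: str) -> Dict[str, Tuple[str, int]]:
    # Record the byte offset of each line so get_passage() can seek to it
    # lazily later; mirrors the lookup pattern in the hunk above.
    offset_map: Dict[str, Tuple[str, int]] = {}
    with open(passage_file, "r", encoding="utf-8") as f:
        while True:
            offset = f.tell()
            line = f.readline()
            if not line:
                break
            passage = json.loads(line)
            offset_map[passage["id"]] = (passage_file, offset)
    return offset_map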
@@ -373,10 +376,12 @@ class LeannBuilder:
 class LeannSearcher:
     def __init__(self, index_path: str, enable_warmup: bool = False, **backend_kwargs):
-        meta_path_str = f"{index_path}.meta.json"
-        if not Path(meta_path_str).exists():
-            raise FileNotFoundError(f"Leann metadata file not found at {meta_path_str}")
-        with open(meta_path_str, "r", encoding="utf-8") as f:
+        self.meta_path_str = f"{index_path}.meta.json"
+        if not Path(self.meta_path_str).exists():
+            raise FileNotFoundError(
+                f"Leann metadata file not found at {self.meta_path_str}"
+            )
+        with open(self.meta_path_str, "r", encoding="utf-8") as f:
             self.meta_data = json.load(f)
         backend_name = self.meta_data["backend_name"]
         self.embedding_model = self.meta_data["embedding_model"]
@@ -390,7 +395,9 @@ class LeannSearcher:
raise ValueError(f"Backend '{backend_name}' not found.")
final_kwargs = {**self.meta_data.get("backend_kwargs", {}), **backend_kwargs}
final_kwargs["enable_warmup"] = enable_warmup
self.backend_impl = backend_factory.searcher(index_path, **final_kwargs)
self.backend_impl: LeannBackendSearcherInterface = backend_factory.searcher(
index_path, **final_kwargs
)
def search(
self,
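The new annotation changes nothing at runtime, but it documents the contract the factory's searcher must satisfy. A hedged sketch of what that interface could look like as a typing.Protocol; only the two methods this diff actually calls are listed, and their exact signatures in leann.interface may differ:

from typing import Any, Optional, Protocol
import numpy as np

class LeannBackendSearcherInterface(Protocol):
    def _ensure_server_running(self, meta_path: str, port: int = 5557, **kwargs: Any) -> int:
        # Start the embedding server if needed; return the ZMQ port it is bound to.
        ...

    def compute_query_embedding(self, query: str, zmq_port: Optional[int] = None) -> np.ndarray:
        # Embed the query, routing through the server when zmq_port is set.
        ...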
@@ -399,9 +406,9 @@
         complexity: int = 64,
         beam_width: int = 1,
         prune_ratio: float = 0.0,
-        recompute_embeddings: bool = False,
+        recompute_embeddings: bool = True,
         pruning_strategy: Literal["global", "local", "proportional"] = "global",
-        expected_zmq_port: Optional[int] = None,
+        expected_zmq_port: int = 5557,
         **kwargs,
     ) -> List[SearchResult]:
         print("🔍 DEBUG LeannSearcher.search() called:")
@@ -409,16 +416,21 @@
         print(f" Top_k: {top_k}")
         print(f" Additional kwargs: {kwargs}")
-        # Use backend's compute_query_embedding method
-        # This will automatically use embedding server if available and needed
-        import time
         start_time = time.time()
+        zmq_port = None
+        if recompute_embeddings:
+            zmq_port = self.backend_impl._ensure_server_running(
+                self.meta_path_str,
+                port=expected_zmq_port,
+                **kwargs,
+            )
+        del expected_zmq_port
         query_embedding = self.backend_impl.compute_query_embedding(
             query,
-            expected_zmq_port,
-            use_server_if_available=recompute_embeddings,
+            zmq_port=zmq_port,
         )
         print(f" Generated embedding shape: {query_embedding.shape}")
         embedding_time = time.time() - start_time
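This handoff is the substance of the fix: _ensure_server_running() is expected to start the embedding server (and load the model) at most once, then return the port of the already-running instance, so repeated searches stop paying the model-load cost. The diff does not show that method's body; a minimal sketch of the pattern, where everything except the method name is an assumption:

import socket
from typing import Optional

class _CachedServerSketch:
    _server_port: Optional[int] = None  # cached once the server is up

    def _ensure_server_running(self, meta_path: str, port: int = 5557, **kwargs) -> int:
        if self._server_port is not None:
            return self._server_port  # model already loaded; reuse the running server
        if self._port_in_use(port):
            self._server_port = port  # a server is already listening there
            return port
        self._server_port = self._start_embedding_server(meta_path, port, **kwargs)
        return self._server_port

    @staticmethod
    def _port_in_use(port: int) -> bool:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            return s.connect_ex(("127.0.0.1", port)) == 0

    def _start_embedding_server(self, meta_path: str, port: int, **kwargs) -> int:
        # Placeholder: the real code would spawn the ZMQ embedding server
        # process, wait for it to load the model, and return its port.
        raise NotImplementedError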
@@ -433,7 +445,7 @@
             prune_ratio=prune_ratio,
             recompute_embeddings=recompute_embeddings,
             pruning_strategy=pruning_strategy,
-            expected_zmq_port=expected_zmq_port,
+            expected_zmq_port=zmq_port,
             **kwargs,
         )
         search_time = time.time() - start_time
@@ -488,10 +500,10 @@ class LeannChat:
         complexity: int = 64,
         beam_width: int = 1,
         prune_ratio: float = 0.0,
-        recompute_embeddings: bool = False,
+        recompute_embeddings: bool = True,
         pruning_strategy: Literal["global", "local", "proportional"] = "global",
-        expected_zmq_port: Optional[int] = None,
         llm_kwargs: Optional[Dict[str, Any]] = None,
+        expected_zmq_port: int = 5557,
         **search_kwargs,
     ):
         if llm_kwargs is None:
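The if llm_kwargs is None guard is the standard way to avoid a shared mutable default argument: each call gets its own dict rather than one created at definition time. A generic illustration (names are not from the codebase):

from typing import Any, Dict, Optional

def with_defaults(llm_kwargs: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    if llm_kwargs is None:
        llm_kwargs = {}  # fresh dict per call, never shared between calls
    llm_kwargs.setdefault("temperature", 0.0)
    return llm_kwargs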