fix: cache the loaded model
@@ -5,7 +5,9 @@ with the correct, original embedding logic from the user's reference code.

 import json
 import pickle
+from leann.interface import LeannBackendSearcherInterface
 import numpy as np
+import time
 from pathlib import Path
 from typing import List, Dict, Any, Optional, Literal
 from dataclasses import dataclass, field
@@ -126,6 +128,7 @@ class PassageManager:
     def get_passage(self, passage_id: str) -> Dict[str, Any]:
         if passage_id in self.global_offset_map:
             passage_file, offset = self.global_offset_map[passage_id]
+            # Lazy file opening - only open when needed
             with open(passage_file, "r", encoding="utf-8") as f:
                 f.seek(offset)
                 return json.loads(f.readline())
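Note on the get_passage hunk above: the lookup pattern is an offset map into a JSONL passage file, so no file handle stays open between lookups. A minimal self-contained sketch of that pattern, where the build_offset_map helper and the "id" field are illustrative assumptions, not part of the codebase:

import json
from typing import Any, Dict, Tuple

def build_offset_map(passage_file: str) -> Dict[str, Tuple[str, int]]:
    # Record the position of each JSONL record, keyed by passage id.
    offsets: Dict[str, Tuple[str, int]] = {}
    with open(passage_file, "r", encoding="utf-8") as f:
        pos = f.tell()
        line = f.readline()
        while line:
            offsets[json.loads(line)["id"]] = (passage_file, pos)
            pos = f.tell()
            line = f.readline()
    return offsets

def get_passage(offset_map: Dict[str, Tuple[str, int]], passage_id: str) -> Dict[str, Any]:
    passage_file, offset = offset_map[passage_id]
    # Lazy file opening - only open when needed
    with open(passage_file, "r", encoding="utf-8") as f:
        f.seek(offset)
        return json.loads(f.readline())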
@@ -373,10 +376,12 @@ class LeannBuilder:

 class LeannSearcher:
     def __init__(self, index_path: str, enable_warmup: bool = False, **backend_kwargs):
-        meta_path_str = f"{index_path}.meta.json"
-        if not Path(meta_path_str).exists():
-            raise FileNotFoundError(f"Leann metadata file not found at {meta_path_str}")
-        with open(meta_path_str, "r", encoding="utf-8") as f:
+        self.meta_path_str = f"{index_path}.meta.json"
+        if not Path(self.meta_path_str).exists():
+            raise FileNotFoundError(
+                f"Leann metadata file not found at {self.meta_path_str}"
+            )
+        with open(self.meta_path_str, "r", encoding="utf-8") as f:
             self.meta_data = json.load(f)
         backend_name = self.meta_data["backend_name"]
         self.embedding_model = self.meta_data["embedding_model"]
@@ -390,7 +395,9 @@ class LeannSearcher:
             raise ValueError(f"Backend '{backend_name}' not found.")
         final_kwargs = {**self.meta_data.get("backend_kwargs", {}), **backend_kwargs}
         final_kwargs["enable_warmup"] = enable_warmup
-        self.backend_impl = backend_factory.searcher(index_path, **final_kwargs)
+        self.backend_impl: LeannBackendSearcherInterface = backend_factory.searcher(
+            index_path, **final_kwargs
+        )

     def search(
         self,
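The annotation added above pins self.backend_impl to the searcher interface, so type checkers can verify the calls search() makes on it. For readers without the repo at hand, a hedged sketch of the subset of that contract this diff actually exercises; the signatures are inferred from the call sites in the hunks below, not taken from leann.interface itself:

from typing import Any, Optional, Protocol
import numpy as np

class BackendSearcherLike(Protocol):
    # Starts (or reuses) the embedding server; returns the port actually bound.
    def _ensure_server_running(self, meta_path: str, port: int, **kwargs: Any) -> int: ...

    # Embeds the query, going through the ZMQ server when zmq_port is set.
    def compute_query_embedding(
        self, query: str, zmq_port: Optional[int] = None
    ) -> np.ndarray: ...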
@@ -399,9 +406,9 @@ class LeannSearcher:
         complexity: int = 64,
         beam_width: int = 1,
         prune_ratio: float = 0.0,
-        recompute_embeddings: bool = False,
+        recompute_embeddings: bool = True,
         pruning_strategy: Literal["global", "local", "proportional"] = "global",
-        expected_zmq_port: Optional[int] = None,
+        expected_zmq_port: int = 5557,
         **kwargs,
     ) -> List[SearchResult]:
         print("🔍 DEBUG LeannSearcher.search() called:")
@@ -409,16 +416,21 @@ class LeannSearcher:
         print(f" Top_k: {top_k}")
         print(f" Additional kwargs: {kwargs}")

         # Use backend's compute_query_embedding method
-        # This will automatically use embedding server if available and needed
-        import time
-
         start_time = time.time()

+        zmq_port = None
+        if recompute_embeddings:
+            zmq_port = self.backend_impl._ensure_server_running(
+                self.meta_path_str,
+                port=expected_zmq_port,
+                **kwargs,
+            )
+        del expected_zmq_port
+
         query_embedding = self.backend_impl.compute_query_embedding(
             query,
-            expected_zmq_port,
-            use_server_if_available=recompute_embeddings,
+            zmq_port=zmq_port,
         )
         print(f" Generated embedding shape: {query_embedding.shape}")
         embedding_time = time.time() - start_time
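The rewritten block treats expected_zmq_port as a request only: _ensure_server_running returns the port the server actually bound, and the del turns any later accidental use of the stale name into a NameError. A generic, illustrative sketch of why the returned port is the one to trust, using plain sockets rather than the real server:

import socket

def ensure_server_running(requested_port: int) -> int:
    # Bind the requested port if it is free, otherwise fall back to an
    # ephemeral port, and report the port actually bound.
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        sock.bind(("127.0.0.1", requested_port))
    except OSError:
        sock.bind(("127.0.0.1", 0))  # any free port
    actual_port = sock.getsockname()[1]
    sock.close()  # a real server would keep listening here
    return actual_port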
@@ -433,7 +445,7 @@ class LeannSearcher:
             prune_ratio=prune_ratio,
             recompute_embeddings=recompute_embeddings,
             pruning_strategy=pruning_strategy,
-            expected_zmq_port=expected_zmq_port,
+            expected_zmq_port=zmq_port,
             **kwargs,
         )
         search_time = time.time() - start_time
@@ -488,10 +500,10 @@ class LeannChat:
         complexity: int = 64,
         beam_width: int = 1,
         prune_ratio: float = 0.0,
-        recompute_embeddings: bool = False,
+        recompute_embeddings: bool = True,
         pruning_strategy: Literal["global", "local", "proportional"] = "global",
-        expected_zmq_port: Optional[int] = None,
         llm_kwargs: Optional[Dict[str, Any]] = None,
+        expected_zmq_port: int = 5557,
         **search_kwargs,
     ):
         if llm_kwargs is None:
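Taken together, the default-flipping hunks mean a bare search() call now recomputes the query embedding through the ZMQ server on port 5557 unless told otherwise. A usage sketch under the new defaults; the import path and index path are assumptions for illustration:

from leann.api import LeannSearcher  # module path assumed

searcher = LeannSearcher("./docs.index")

# New defaults: recompute_embeddings=True and expected_zmq_port=5557, so this
# call starts (or reuses) the embedding server and embeds the query fresh.
results = searcher.search("how is the model cached?", top_k=5)

# Opt out to search against stored embeddings only, with no server round-trip.
results = searcher.search("how is the model cached?", top_k=5,
                          recompute_embeddings=False)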