feat: disable warmup by default
This commit is contained in:
@@ -237,7 +237,7 @@ class LeannBuilder:
|
||||
|
||||
|
||||
class LeannSearcher:
|
||||
def __init__(self, index_path: str, **backend_kwargs):
|
||||
def __init__(self, index_path: str, enable_warmup: bool = False, **backend_kwargs):
|
||||
meta_path_str = f"{index_path}.meta.json"
|
||||
if not Path(meta_path_str).exists():
|
||||
raise FileNotFoundError(f"Leann metadata file not found at {meta_path_str}")
|
||||
@@ -251,6 +251,7 @@ class LeannSearcher:
|
||||
if backend_factory is None:
|
||||
raise ValueError(f"Backend '{backend_name}' not found.")
|
||||
final_kwargs = {**self.meta_data.get("backend_kwargs", {}), **backend_kwargs}
|
||||
final_kwargs["enable_warmup"] = enable_warmup
|
||||
self.backend_impl = backend_factory.searcher(index_path, **final_kwargs)
|
||||
|
||||
def search(self, query: str, top_k: int = 5, **search_kwargs) -> List[SearchResult]:
|
||||
@@ -306,9 +307,9 @@ from .chat import get_llm
|
||||
|
||||
class LeannChat:
|
||||
def __init__(
|
||||
self, index_path: str, llm_config: Optional[Dict[str, Any]] = None, **kwargs
|
||||
self, index_path: str, llm_config: Optional[Dict[str, Any]] = None, enable_warmup: bool = False, **kwargs
|
||||
):
|
||||
self.searcher = LeannSearcher(index_path, **kwargs)
|
||||
self.searcher = LeannSearcher(index_path, enable_warmup=enable_warmup, **kwargs)
|
||||
self.llm = get_llm(llm_config)
|
||||
|
||||
def ask(self, question: str, top_k=5, **kwargs):
|
||||
|
||||
@@ -184,7 +184,7 @@ class EmbeddingServerManager:
|
||||
Args:
|
||||
port (int): The ZMQ port for the server.
|
||||
model_name (str): The name of the embedding model to use.
|
||||
**kwargs: Additional arguments for the server (e.g., passages_file, distance_metric).
|
||||
**kwargs: Additional arguments for the server (e.g., passages_file, distance_metric, enable_warmup).
|
||||
|
||||
Returns:
|
||||
bool: True if the server is started successfully or already running, False otherwise.
|
||||
@@ -312,6 +312,8 @@ class EmbeddingServerManager:
|
||||
# command.extend(["--distance-metric", kwargs["distance_metric"]])
|
||||
if "use_mlx" in kwargs and kwargs["use_mlx"]:
|
||||
command.extend(["--use-mlx"])
|
||||
if "enable_warmup" in kwargs and not kwargs["enable_warmup"]:
|
||||
command.extend(["--disable-warmup"])
|
||||
|
||||
project_root = Path(__file__).parent.parent.parent.parent.parent
|
||||
print(f"INFO: Running command from project root: {project_root}")
|
||||
|
||||
@@ -79,6 +79,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
|
||||
passages_file=passages_source_file,
|
||||
distance_metric=kwargs.get("distance_metric"),
|
||||
use_mlx=kwargs.get("use_mlx", False),
|
||||
enable_warmup=kwargs.get("enable_warmup", False),
|
||||
)
|
||||
if not server_started:
|
||||
raise RuntimeError(f"Failed to start embedding server on port {port}")
|
||||
|
||||
Reference in New Issue
Block a user