fix: use Python 3.9 compatible builtin generics
- Convert List[str] to list[str], Dict[str, Any] to dict[str, Any], etc.
- Use ruff --unsafe-fixes to automatically apply all type annotation updates
- Remove deprecated typing imports (List, Dict, Tuple) where no longer needed
- Keep Optional[str] syntax (the | union operator is not supported on Python 3.9)

All type annotations are now Python 3.9 compatible while using modern builtin generics.

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
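For reference, a minimal sketch of the conversion this commit applies (function names here are illustrative, not from the repo). PEP 585 made the builtin containers subscriptable in annotations starting with Python 3.9, so the `typing` aliases become redundant. The exact rule codes and invocation are an assumption — the message only names the `--unsafe-fixes` flag — but this class of rewrite corresponds to ruff's pyupgrade rules (UP006/UP035), e.g. `ruff check --fix --unsafe-fixes`:

```python
# Before: container aliases imported from typing (pre-PEP 585 style)
from typing import Dict, List, Optional, Tuple

def lookup(keys: List[str]) -> Dict[str, Tuple[float, int]]:
    return {k: (0.0, 0) for k in keys}

# After: builtin generics, with typing kept only for what Python 3.9
# still requires (Optional, Any, Literal, ...)
from typing import Optional

def lookup_pep585(keys: list[str]) -> dict[str, tuple[float, int]]:
    return {k: (0.0, 0) for k in keys}

maybe_path: Optional[str] = None  # `str | None` (PEP 604) needs Python 3.10
```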
@@ -4,7 +4,7 @@ import os
 import struct
 import sys
 from pathlib import Path
-from typing import Any, Dict, List, Literal, Optional, Tuple
+from typing import Any, Literal, Optional

 import numpy as np
 import psutil
@@ -85,7 +85,7 @@ def _write_vectors_to_bin(data: np.ndarray, file_path: Path):
         f.write(data.tobytes())


-def _calculate_smart_memory_config(data: np.ndarray) -> Tuple[float, float]:
+def _calculate_smart_memory_config(data: np.ndarray) -> tuple[float, float]:
     """
     Calculate smart memory configuration for DiskANN based on data size and system specs.

@@ -202,7 +202,7 @@ class DiskannBuilder(LeannBackendBuilderInterface):
             size_mb = file_path.stat().st_size / (1024 * 1024)
             logger.info(f" - {filename} ({size_mb:.1f} MB)")

-    def build(self, data: np.ndarray, ids: List[str], index_path: str, **kwargs):
+    def build(self, data: np.ndarray, ids: list[str], index_path: str, **kwargs):
         path = Path(index_path)
         index_dir = path.parent
         index_prefix = path.stem
@@ -388,7 +388,7 @@ class DiskannSearcher(BaseSearcher):
         batch_recompute: bool = False,
         dedup_node_dis: bool = False,
         **kwargs,
-    ) -> Dict[str, Any]:
+    ) -> dict[str, Any]:
         """
         Search for nearest neighbors using DiskANN index.

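A note on why `Optional` survives in the hunks above while `Tuple` does not: builtin generics like `tuple[float, float]` evaluate at runtime on Python 3.9, but the `|` union operator in annotations (PEP 604) only works on 3.10+. A quick self-contained illustration, with hypothetical function names:

```python
from typing import Optional

# Fine on Python 3.9: PEP 585 builtin generics are real runtime subscripts.
def memory_budget() -> tuple[float, float]:
    return (0.25, 0.75)

# Also fine on 3.9: Optional still comes from typing.
def find_index(name: str) -> Optional[str]:
    return None

# The PEP 604 spelling would raise at import time on 3.9
# (TypeError: unsupported operand type(s) for |) unless the module
# adds `from __future__ import annotations`:
#
#     def find_index(name: str) -> str | None: ...
```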
@@ -12,7 +12,7 @@ import shutil
 import subprocess
 import tempfile
 from pathlib import Path
-from typing import Optional, Tuple
+from typing import Optional


 class GraphPartitioner:
@@ -92,7 +92,7 @@ class GraphPartitioner:
         output_dir: Optional[str] = None,
         partition_prefix: Optional[str] = None,
         **kwargs,
-    ) -> Tuple[str, str]:
+    ) -> tuple[str, str]:
         """
         Partition a disk-based index for improved performance.

@@ -267,7 +267,7 @@ def partition_graph(
     partition_prefix: Optional[str] = None,
     build_type: str = "release",
     **kwargs,
-) -> Tuple[str, str]:
+) -> tuple[str, str]:
     """
     Convenience function to partition a graph index.

@@ -15,7 +15,7 @@ from typing import Optional


 def partition_graph_simple(
     index_prefix_path: str, output_dir: Optional[str] = None, **kwargs
-) -> Tuple[str, str]:
+) -> tuple[str, str]:
     """
     Simple function to partition a graph index.
@@ -61,7 +61,7 @@ class HNSWBuilder(LeannBackendBuilderInterface):
                 "is_recompute is False, but is_compact is True. This is not compatible now. change is compact to False and you can use the original HNSW index."
             )

-    def build(self, data: np.ndarray, ids: List[str], index_path: str, **kwargs):
+    def build(self, data: np.ndarray, ids: list[str], index_path: str, **kwargs):
         from . import faiss  # type: ignore

         path = Path(index_path)
@@ -160,7 +160,7 @@ class HNSWSearcher(BaseSearcher):
         pruning_strategy: Literal["global", "local", "proportional"] = "global",
         batch_size: int = 0,
         **kwargs,
-    ) -> Dict[str, Any]:
+    ) -> dict[str, Any]:
         """
         Search for nearest neighbors using HNSW index.

@@ -23,13 +23,13 @@ from .registry import BACKEND_REGISTRY
 logger = logging.getLogger(__name__)


-def get_registered_backends() -> List[str]:
+def get_registered_backends() -> list[str]:
     """Get list of registered backend names."""
     return list(BACKEND_REGISTRY.keys())


 def compute_embeddings(
-    chunks: List[str],
+    chunks: list[str],
     model_name: str,
     mode: str = "sentence-transformers",
     use_server: bool = True,
@@ -70,7 +70,7 @@ def compute_embeddings(
     )


-def compute_embeddings_via_server(chunks: List[str], model_name: str, port: int) -> np.ndarray:
+def compute_embeddings_via_server(chunks: list[str], model_name: str, port: int) -> np.ndarray:
     """Computes embeddings using sentence-transformers.

     Args:
@@ -113,12 +113,12 @@ class SearchResult:
     id: str
     score: float
     text: str
-    metadata: Dict[str, Any] = field(default_factory=dict)
+    metadata: dict[str, Any] = field(default_factory=dict)


 class PassageManager:
     def __init__(
-        self, passage_sources: list[Dict[str, Any]], metadata_file_path: Optional[str] = None
+        self, passage_sources: list[dict[str, Any]], metadata_file_path: Optional[str] = None
     ):
         self.offset_maps = {}
         self.passage_files = {}
@@ -162,7 +162,7 @@ class PassageManager:
             for passage_id, offset in offset_map.items():
                 self.global_offset_map[passage_id] = (passage_file, offset)

-    def get_passage(self, passage_id: str) -> Dict[str, Any]:
+    def get_passage(self, passage_id: str) -> dict[str, Any]:
         if passage_id in self.global_offset_map:
             passage_file, offset = self.global_offset_map[passage_id]
             # Lazy file opening - only open when needed
@@ -260,9 +260,9 @@ class LeannBuilder:
             )

         self.backend_kwargs = backend_kwargs
-        self.chunks: list[Dict[str, Any]] = []
+        self.chunks: list[dict[str, Any]] = []

-    def add_text(self, text: str, metadata: Optional[Dict[str, Any]] = None):
+    def add_text(self, text: str, metadata: Optional[dict[str, Any]] = None):
         if metadata is None:
             metadata = {}
         passage_id = metadata.get("id", str(len(self.chunks)))
@@ -618,7 +618,7 @@ class LeannChat:
     def __init__(
         self,
         index_path: str,
-        llm_config: Optional[Dict[str, Any]] = None,
+        llm_config: Optional[dict[str, Any]] = None,
         enable_warmup: bool = False,
         **kwargs,
     ):
@@ -634,7 +634,7 @@ class LeannChat:
         prune_ratio: float = 0.0,
         recompute_embeddings: bool = True,
         pruning_strategy: Literal["global", "local", "proportional"] = "global",
-        llm_kwargs: Optional[Dict[str, Any]] = None,
+        llm_kwargs: Optional[dict[str, Any]] = None,
         expected_zmq_port: int = 5557,
         **search_kwargs,
     ):
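The `SearchResult` hunk above converts a dataclass field annotation; that is safe because dataclasses treat the annotation as metadata and the runtime default comes from `default_factory`. A minimal sketch reusing the field names from the hunk (the surrounding usage is otherwise invented):

```python
from dataclasses import dataclass, field
from typing import Any

@dataclass
class SearchResult:
    id: str
    score: float
    text: str
    # dict[str, Any] is only annotation metadata; the runtime default
    # is produced by default_factory, so PEP 585 is a drop-in change.
    metadata: dict[str, Any] = field(default_factory=dict)

r = SearchResult(id="0", score=0.9, text="hello")
assert r.metadata == {}
```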
@@ -8,7 +8,7 @@ import difflib
 import logging
 import os
 from abc import ABC, abstractmethod
-from typing import Any, List, Optional, Tuple
+from typing import Any, Optional

 import torch

@@ -17,7 +17,7 @@ logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)


-def check_ollama_models(host: str) -> List[str]:
+def check_ollama_models(host: str) -> list[str]:
     """Check available Ollama models and return a list"""
     try:
         import requests
@@ -31,7 +31,7 @@ def check_ollama_models(host: str) -> List[str]:
         return []


-def check_ollama_model_exists_remotely(model_name: str) -> Tuple[bool, List[str]]:
+def check_ollama_model_exists_remotely(model_name: str) -> tuple[bool, list[str]]:
     """Check if a model exists in Ollama's remote library and return available tags

     Returns:
@@ -94,7 +94,7 @@ def check_ollama_model_exists_remotely(model_name: str) -> Tuple[bool, List[str]
         return True, []


-def search_ollama_models_fuzzy(query: str, available_models: List[str]) -> List[str]:
+def search_ollama_models_fuzzy(query: str, available_models: list[str]) -> list[str]:
     """Use intelligent fuzzy search for Ollama models"""
     if not available_models:
         return []
@@ -169,7 +169,7 @@ def search_ollama_models_fuzzy(query: str, available_models: List[str]) -> List[
     # Remove this too - no need for fallback


-def suggest_similar_models(invalid_model: str, available_models: List[str]) -> List[str]:
+def suggest_similar_models(invalid_model: str, available_models: list[str]) -> list[str]:
     """Use difflib to find similar model names"""
     if not available_models:
         return []
@@ -190,7 +190,7 @@ def check_hf_model_exists(model_name: str) -> bool:
         return False


-def get_popular_hf_models() -> List[str]:
+def get_popular_hf_models() -> list[str]:
     """Return a list of popular HuggingFace models for suggestions"""
     try:
         from huggingface_hub import list_models
@@ -222,7 +222,7 @@ def get_popular_hf_models() -> List[str]:
         return _get_fallback_hf_models()


-def _get_fallback_hf_models() -> List[str]:
+def _get_fallback_hf_models() -> list[str]:
     """Fallback list of popular HuggingFace models"""
     return [
         "microsoft/DialoGPT-medium",
@@ -238,7 +238,7 @@ def _get_fallback_hf_models() -> List[str]:
     ]


-def search_hf_models_fuzzy(query: str, limit: int = 10) -> List[str]:
+def search_hf_models_fuzzy(query: str, limit: int = 10) -> list[str]:
     """Use HuggingFace Hub's native fuzzy search for model suggestions"""
     try:
         from huggingface_hub import list_models
@@ -304,7 +304,7 @@ def search_hf_models_fuzzy(query: str, limit: int = 10) -> List[str]:
         return []


-def search_hf_models(query: str, limit: int = 10) -> List[str]:
+def search_hf_models(query: str, limit: int = 10) -> list[str]:
     """Simple search for HuggingFace models based on query (kept for backward compatibility)"""
     return search_hf_models_fuzzy(query, limit)

@@ -761,7 +761,7 @@ class SimulatedChat(LLMInterface):
         return "This is a simulated answer from the LLM based on the retrieved context."


-def get_llm(llm_config: Optional[Dict[str, Any]] = None) -> LLMInterface:
+def get_llm(llm_config: Optional[dict[str, Any]] = None) -> LLMInterface:
     """
     Factory function to get an LLM interface based on configuration.

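`suggest_similar_models` above keeps its difflib-based docstring under the new signature; as a rough sketch of what a helper with that shape typically does (only the signature and docstring come from the diff, the body is a guess):

```python
import difflib

def suggest_similar_models(invalid_model: str, available_models: list[str]) -> list[str]:
    """Use difflib to find similar model names"""
    if not available_models:
        return []
    # get_close_matches ranks candidates by similarity ratio and keeps
    # at most n of them above the cutoff.
    return difflib.get_close_matches(invalid_model, available_models, n=5, cutoff=0.4)
```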
@@ -7,7 +7,7 @@ Preserves all optimization parameters to ensure performance
 import logging
 import os
 from concurrent.futures import ThreadPoolExecutor, as_completed
-from typing import Any, Dict, List
+from typing import Any

 import numpy as np
 import torch
@@ -19,11 +19,11 @@ log_level = getattr(logging, LOG_LEVEL, logging.WARNING)
 logger.setLevel(log_level)

 # Global model cache to avoid repeated loading
-_model_cache: Dict[str, Any] = {}
+_model_cache: dict[str, Any] = {}


 def compute_embeddings(
-    texts: List[str],
+    texts: list[str],
     model_name: str,
     mode: str = "sentence-transformers",
     is_build: bool = False,
@@ -63,7 +63,7 @@ def compute_embeddings(


 def compute_embeddings_sentence_transformers(
-    texts: List[str],
+    texts: list[str],
     model_name: str,
     use_fp16: bool = True,
     device: str = "auto",
@@ -235,7 +235,7 @@ def compute_embeddings_sentence_transformers(
     return embeddings


-def compute_embeddings_openai(texts: List[str], model_name: str) -> np.ndarray:
+def compute_embeddings_openai(texts: list[str], model_name: str) -> np.ndarray:
     # TODO: @yichuan-w add progress bar only in build mode
     """Compute embeddings using OpenAI API"""
     try:
@@ -296,7 +296,7 @@ def compute_embeddings_openai(texts: List[str], model_name: str) -> np.ndarray:
     return embeddings


-def compute_embeddings_mlx(chunks: List[str], model_name: str, batch_size: int = 16) -> np.ndarray:
+def compute_embeddings_mlx(chunks: list[str], model_name: str, batch_size: int = 16) -> np.ndarray:
     # TODO: @yichuan-w add progress bar only in build mode
     """Computes embeddings using an MLX model."""
     try:
@@ -371,7 +371,7 @@ def compute_embeddings_mlx(chunks: List[str], model_name: str, batch_size: int =


 def compute_embeddings_ollama(
-    texts: List[str], model_name: str, is_build: bool = False, host: str = "http://localhost:11434"
+    texts: list[str], model_name: str, is_build: bool = False, host: str = "http://localhost:11434"
 ) -> np.ndarray:
     """
     Compute embeddings using Ollama API.
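`_model_cache: dict[str, Any] = {}` shows the conversion on a module-level variable annotation; both spellings behave identically at runtime on 3.9. A runnable sketch of the caching pattern the adjacent comment describes (the loader is a stand-in, not the repo's code):

```python
from typing import Any

# Global model cache to avoid repeated loading
_model_cache: dict[str, Any] = {}

def _load_model(model_name: str) -> Any:
    # Stand-in for an expensive load such as SentenceTransformer(model_name).
    return object()

def get_model(model_name: str) -> Any:
    if model_name not in _model_cache:
        _model_cache[model_name] = _load_model(model_name)
    return _model_cache[model_name]

assert get_model("bert") is get_model("bert")  # second call hits the cache
```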
@@ -6,7 +6,7 @@ import subprocess
 import sys
 import time
 from pathlib import Path
-from typing import Optional, Tuple
+from typing import Optional

 import psutil

@@ -147,7 +147,7 @@ def _check_passages_in_cmdline(cmdline: list, expected_passages_file: str) -> bo

 def _find_compatible_port_or_next_available(
     start_port: int, model_name: str, passages_file: str, max_attempts: int = 100
-) -> Tuple[int, bool]:
+) -> tuple[int, bool]:
     """
     Find a port that either has a compatible server or is available.
     Returns (port, is_compatible) where is_compatible indicates if we found a matching server.
@@ -193,7 +193,7 @@ class EmbeddingServerManager:
         model_name: str,
         embedding_mode: str = "sentence-transformers",
         **kwargs,
-    ) -> Tuple[bool, int]:
+    ) -> tuple[bool, int]:
         """Start the embedding server."""
         passages_file = kwargs.get("passages_file")

@@ -225,7 +225,7 @@ class EmbeddingServerManager:
         model_name: str,
         embedding_mode: str = "sentence-transformers",
         **kwargs,
-    ) -> Tuple[bool, int]:
+    ) -> tuple[bool, int]:
         """Start server with Colab-specific configuration."""
         # Try to find an available port
         try:
@@ -261,7 +261,7 @@ class EmbeddingServerManager:

     def _start_new_server(
         self, port: int, model_name: str, embedding_mode: str, **kwargs
-    ) -> Tuple[bool, int]:
+    ) -> tuple[bool, int]:
         """Start a new embedding server on the given port."""
         logger.info(f"Starting embedding server on port {port}...")

@@ -321,7 +321,7 @@ class EmbeddingServerManager:
             atexit.register(lambda: self.stop_server() if self.server_process else None)
             self._atexit_registered = True

-    def _wait_for_server_ready(self, port: int) -> Tuple[bool, int]:
+    def _wait_for_server_ready(self, port: int) -> tuple[bool, int]:
         """Wait for the server to be ready."""
         max_wait, wait_interval = 120, 0.5
         for _ in range(int(max_wait / wait_interval)):
@@ -403,7 +403,7 @@ class EmbeddingServerManager:
             atexit.register(lambda: self.stop_server() if self.server_process else None)
             self._atexit_registered = True

-    def _wait_for_server_ready_colab(self, port: int) -> Tuple[bool, int]:
+    def _wait_for_server_ready_colab(self, port: int) -> tuple[bool, int]:
         """Wait for the server to be ready with Colab-specific timeout."""
         max_wait, wait_interval = 30, 0.5  # Shorter timeout for Colab

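Several signatures in this file now return bare `tuple[int, bool]` / `tuple[bool, int]` pairs for port negotiation. The diff only shows the signatures; as an illustration of the shape, a minimal free-port scan (the socket probe is an assumption, and the bool here just reports success rather than server compatibility):

```python
import socket

def _port_is_free(port: int, host: str = "127.0.0.1") -> bool:
    # connect_ex returns 0 when something is already listening on the port.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        return s.connect_ex((host, port)) != 0

def find_available_port(start_port: int, max_attempts: int = 100) -> tuple[int, bool]:
    for port in range(start_port, start_port + max_attempts):
        if _port_is_free(port):
            return port, True
    return start_port, False
```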
@@ -8,7 +8,7 @@ class LeannBackendBuilderInterface(ABC):
     """Backend interface for building indexes"""

     @abstractmethod
-    def build(self, data: np.ndarray, ids: List[str], index_path: str, **kwargs) -> None:
+    def build(self, data: np.ndarray, ids: list[str], index_path: str, **kwargs) -> None:
         """Build index

         Args:
@@ -52,7 +52,7 @@ class LeannBackendSearcherInterface(ABC):
         pruning_strategy: Literal["global", "local", "proportional"] = "global",
         zmq_port: Optional[int] = None,
         **kwargs,
-    ) -> Dict[str, Any]:
+    ) -> dict[str, Any]:
         """Search for nearest neighbors

         Args:
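The interface hunks pin down the abstract `build` signature that the DiskANN and HNSW builders above override. A toy conforming implementation, using a stand-in base class (the subclass is invented for illustration):

```python
from abc import ABC, abstractmethod

import numpy as np

class BuilderInterface(ABC):
    """Stand-in with the same abstract build() shape as the hunk above."""

    @abstractmethod
    def build(self, data: np.ndarray, ids: list[str], index_path: str, **kwargs) -> None: ...

class InMemoryBuilder(BuilderInterface):
    # Toy backend: keeps vectors keyed by id instead of writing an index file.
    def build(self, data: np.ndarray, ids: list[str], index_path: str, **kwargs) -> None:
        self.vectors = dict(zip(ids, data))

b = InMemoryBuilder()
b.build(np.zeros((2, 4), dtype=np.float32), ids=["a", "b"], index_path="unused")
assert set(b.vectors) == {"a", "b"}
```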
@@ -7,7 +7,7 @@ from typing import TYPE_CHECKING
 if TYPE_CHECKING:
     from leann.interface import LeannBackendFactoryInterface

-BACKEND_REGISTRY: Dict[str, "LeannBackendFactoryInterface"] = {}
+BACKEND_REGISTRY: dict[str, "LeannBackendFactoryInterface"] = {}


 def register_backend(name: str):
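The registry hunk mixes a runtime subscript with a type-check-only name: `dict[...]` is evaluated when the module loads, while the quoted value type stays a forward reference because its import lives under `TYPE_CHECKING`. The same pattern in isolation (module and class names are illustrative):

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Visible to type checkers only; never imported at runtime.
    from mypackage.interface import BackendFactory

# dict[str, "..."] subscripts fine at runtime on 3.9; only the type
# checker resolves the quoted forward reference.
REGISTRY: dict[str, "BackendFactory"] = {}

def register(name: str, factory: "BackendFactory") -> None:
    REGISTRY[name] = factory
```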
@@ -46,7 +46,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
             backend_module_name=backend_module_name,
         )

-    def _load_meta(self) -> Dict[str, Any]:
+    def _load_meta(self) -> dict[str, Any]:
         """Loads the metadata file associated with the index."""
         # This is the corrected logic for finding the meta file.
         meta_path = self.index_dir / f"{self.index_path.name}.meta.json"
@@ -174,7 +174,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
         pruning_strategy: Literal["global", "local", "proportional"] = "global",
         zmq_port: Optional[int] = None,
         **kwargs,
-    ) -> Dict[str, Any]:
+    ) -> dict[str, Any]:
         """
         Search for the top_k nearest neighbors of the query vector.

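`_load_meta` now advertises a plain `dict[str, Any]`. The meta-file naming comes from the hunk above; the body of this sketch is an assumption about what such a loader does:

```python
import json
from pathlib import Path
from typing import Any

def load_meta(index_dir: Path, index_name: str) -> dict[str, Any]:
    # e.g. <index_dir>/<index_name>.meta.json, as in the hunk above
    meta_path = index_dir / f"{index_name}.meta.json"
    with open(meta_path, encoding="utf-8") as f:
        return json.load(f)
```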