fix: use Python 3.9 compatible builtin generics

- Convert `List[str]` to `list[str]`, `Dict[str, Any]` to `dict[str, Any]`, etc.
- Apply the annotation updates automatically with `ruff check --fix --unsafe-fixes`
- Remove deprecated `typing` imports (`List`, `Dict`, `Tuple`) where no longer needed
- Keep `Optional[str]` syntax (the `X | Y` union operator requires Python 3.10)

All type annotations now use the builtin generics available since Python 3.9; a minimal before/after sketch is shown below.
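
For illustration, here is a minimal before/after sketch of the conversion (the function and parameter names are hypothetical, not taken from the changed files):

```python
# Before: typing aliases (deprecated for this use since Python 3.9, PEP 585)
from typing import Dict, List, Optional

def load_chunks(paths: List[str]) -> Dict[str, str]: ...

# After: builtin generics; Optional stays because the `str | None` union
# syntax (PEP 604) is only valid at runtime on Python 3.10+
from typing import Optional

def load_chunks(paths: list[str]) -> dict[str, str]: ...

def read_meta(path: str) -> Optional[str]: ...
```

This is the kind of rewrite ruff's pyupgrade-derived rules (UP006/UP035) perform when run with `--unsafe-fixes`, assuming those rules are enabled in the project's ruff configuration.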

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
Author: Andy Lee
Date: 2025-08-10 00:38:33 +00:00
Parent: ffba435252
Commit: 6d1ac4a503
11 changed files with 49 additions and 49 deletions


@@ -23,13 +23,13 @@ from .registry import BACKEND_REGISTRY
 logger = logging.getLogger(__name__)
-def get_registered_backends() -> List[str]:
+def get_registered_backends() -> list[str]:
     """Get list of registered backend names."""
     return list(BACKEND_REGISTRY.keys())
 def compute_embeddings(
-    chunks: List[str],
+    chunks: list[str],
     model_name: str,
     mode: str = "sentence-transformers",
     use_server: bool = True,
@@ -70,7 +70,7 @@ def compute_embeddings(
     )
-def compute_embeddings_via_server(chunks: List[str], model_name: str, port: int) -> np.ndarray:
+def compute_embeddings_via_server(chunks: list[str], model_name: str, port: int) -> np.ndarray:
     """Computes embeddings using sentence-transformers.
     Args:
@@ -113,12 +113,12 @@ class SearchResult:
     id: str
     score: float
     text: str
-    metadata: Dict[str, Any] = field(default_factory=dict)
+    metadata: dict[str, Any] = field(default_factory=dict)
 class PassageManager:
     def __init__(
-        self, passage_sources: list[Dict[str, Any]], metadata_file_path: Optional[str] = None
+        self, passage_sources: list[dict[str, Any]], metadata_file_path: Optional[str] = None
     ):
         self.offset_maps = {}
         self.passage_files = {}
@@ -162,7 +162,7 @@ class PassageManager:
             for passage_id, offset in offset_map.items():
                 self.global_offset_map[passage_id] = (passage_file, offset)
-    def get_passage(self, passage_id: str) -> Dict[str, Any]:
+    def get_passage(self, passage_id: str) -> dict[str, Any]:
         if passage_id in self.global_offset_map:
             passage_file, offset = self.global_offset_map[passage_id]
             # Lazy file opening - only open when needed
@@ -260,9 +260,9 @@ class LeannBuilder:
         )
         self.backend_kwargs = backend_kwargs
-        self.chunks: list[Dict[str, Any]] = []
+        self.chunks: list[dict[str, Any]] = []
-    def add_text(self, text: str, metadata: Optional[Dict[str, Any]] = None):
+    def add_text(self, text: str, metadata: Optional[dict[str, Any]] = None):
         if metadata is None:
             metadata = {}
         passage_id = metadata.get("id", str(len(self.chunks)))
@@ -618,7 +618,7 @@ class LeannChat:
     def __init__(
         self,
         index_path: str,
-        llm_config: Optional[Dict[str, Any]] = None,
+        llm_config: Optional[dict[str, Any]] = None,
         enable_warmup: bool = False,
         **kwargs,
     ):
@@ -634,7 +634,7 @@ class LeannChat:
         prune_ratio: float = 0.0,
         recompute_embeddings: bool = True,
         pruning_strategy: Literal["global", "local", "proportional"] = "global",
-        llm_kwargs: Optional[Dict[str, Any]] = None,
+        llm_kwargs: Optional[dict[str, Any]] = None,
         expected_zmq_port: int = 5557,
         **search_kwargs,
     ):


@@ -8,7 +8,7 @@ import difflib
 import logging
 import os
 from abc import ABC, abstractmethod
-from typing import Any, List, Optional, Tuple
+from typing import Any, Optional
 import torch
@@ -17,7 +17,7 @@ logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
-def check_ollama_models(host: str) -> List[str]:
+def check_ollama_models(host: str) -> list[str]:
     """Check available Ollama models and return a list"""
     try:
         import requests
@@ -31,7 +31,7 @@ def check_ollama_models(host: str) -> List[str]:
         return []
-def check_ollama_model_exists_remotely(model_name: str) -> Tuple[bool, List[str]]:
+def check_ollama_model_exists_remotely(model_name: str) -> tuple[bool, list[str]]:
     """Check if a model exists in Ollama's remote library and return available tags
     Returns:
@@ -94,7 +94,7 @@ def check_ollama_model_exists_remotely(model_name: str) -> Tuple[bool, List[str]
     return True, []
-def search_ollama_models_fuzzy(query: str, available_models: List[str]) -> List[str]:
+def search_ollama_models_fuzzy(query: str, available_models: list[str]) -> list[str]:
     """Use intelligent fuzzy search for Ollama models"""
     if not available_models:
         return []
@@ -169,7 +169,7 @@ def search_ollama_models_fuzzy(query: str, available_models: List[str]) -> List[
     # Remove this too - no need for fallback
-def suggest_similar_models(invalid_model: str, available_models: List[str]) -> List[str]:
+def suggest_similar_models(invalid_model: str, available_models: list[str]) -> list[str]:
     """Use difflib to find similar model names"""
     if not available_models:
         return []
@@ -190,7 +190,7 @@ def check_hf_model_exists(model_name: str) -> bool:
         return False
-def get_popular_hf_models() -> List[str]:
+def get_popular_hf_models() -> list[str]:
     """Return a list of popular HuggingFace models for suggestions"""
     try:
         from huggingface_hub import list_models
@@ -222,7 +222,7 @@ def get_popular_hf_models() -> List[str]:
         return _get_fallback_hf_models()
-def _get_fallback_hf_models() -> List[str]:
+def _get_fallback_hf_models() -> list[str]:
     """Fallback list of popular HuggingFace models"""
     return [
         "microsoft/DialoGPT-medium",
@@ -238,7 +238,7 @@ def _get_fallback_hf_models() -> List[str]:
     ]
-def search_hf_models_fuzzy(query: str, limit: int = 10) -> List[str]:
+def search_hf_models_fuzzy(query: str, limit: int = 10) -> list[str]:
     """Use HuggingFace Hub's native fuzzy search for model suggestions"""
     try:
         from huggingface_hub import list_models
@@ -304,7 +304,7 @@ def search_hf_models_fuzzy(query: str, limit: int = 10) -> List[str]:
         return []
-def search_hf_models(query: str, limit: int = 10) -> List[str]:
+def search_hf_models(query: str, limit: int = 10) -> list[str]:
     """Simple search for HuggingFace models based on query (kept for backward compatibility)"""
     return search_hf_models_fuzzy(query, limit)
@@ -761,7 +761,7 @@ class SimulatedChat(LLMInterface):
         return "This is a simulated answer from the LLM based on the retrieved context."
-def get_llm(llm_config: Optional[Dict[str, Any]] = None) -> LLMInterface:
+def get_llm(llm_config: Optional[dict[str, Any]] = None) -> LLMInterface:
     """
     Factory function to get an LLM interface based on configuration.


@@ -7,7 +7,7 @@ Preserves all optimization parameters to ensure performance
 import logging
 import os
 from concurrent.futures import ThreadPoolExecutor, as_completed
-from typing import Any, Dict, List
+from typing import Any
 import numpy as np
 import torch
@@ -19,11 +19,11 @@ log_level = getattr(logging, LOG_LEVEL, logging.WARNING)
 logger.setLevel(log_level)
 # Global model cache to avoid repeated loading
-_model_cache: Dict[str, Any] = {}
+_model_cache: dict[str, Any] = {}
 def compute_embeddings(
-    texts: List[str],
+    texts: list[str],
     model_name: str,
     mode: str = "sentence-transformers",
     is_build: bool = False,
@@ -63,7 +63,7 @@ def compute_embeddings(
 def compute_embeddings_sentence_transformers(
-    texts: List[str],
+    texts: list[str],
     model_name: str,
     use_fp16: bool = True,
     device: str = "auto",
@@ -235,7 +235,7 @@ def compute_embeddings_sentence_transformers(
     return embeddings
-def compute_embeddings_openai(texts: List[str], model_name: str) -> np.ndarray:
+def compute_embeddings_openai(texts: list[str], model_name: str) -> np.ndarray:
     # TODO: @yichuan-w add progress bar only in build mode
     """Compute embeddings using OpenAI API"""
     try:
@@ -296,7 +296,7 @@ def compute_embeddings_openai(texts: List[str], model_name: str) -> np.ndarray:
     return embeddings
-def compute_embeddings_mlx(chunks: List[str], model_name: str, batch_size: int = 16) -> np.ndarray:
+def compute_embeddings_mlx(chunks: list[str], model_name: str, batch_size: int = 16) -> np.ndarray:
     # TODO: @yichuan-w add progress bar only in build mode
     """Computes embeddings using an MLX model."""
     try:
@@ -371,7 +371,7 @@ def compute_embeddings_mlx(chunks: List[str], model_name: str, batch_size: int =
 def compute_embeddings_ollama(
-    texts: List[str], model_name: str, is_build: bool = False, host: str = "http://localhost:11434"
+    texts: list[str], model_name: str, is_build: bool = False, host: str = "http://localhost:11434"
 ) -> np.ndarray:
     """
     Compute embeddings using Ollama API.


@@ -6,7 +6,7 @@ import subprocess
 import sys
 import time
 from pathlib import Path
-from typing import Optional, Tuple
+from typing import Optional
 import psutil
@@ -147,7 +147,7 @@ def _check_passages_in_cmdline(cmdline: list, expected_passages_file: str) -> bo
 def _find_compatible_port_or_next_available(
     start_port: int, model_name: str, passages_file: str, max_attempts: int = 100
-) -> Tuple[int, bool]:
+) -> tuple[int, bool]:
     """
     Find a port that either has a compatible server or is available.
     Returns (port, is_compatible) where is_compatible indicates if we found a matching server.
@@ -193,7 +193,7 @@ class EmbeddingServerManager:
         model_name: str,
         embedding_mode: str = "sentence-transformers",
         **kwargs,
-    ) -> Tuple[bool, int]:
+    ) -> tuple[bool, int]:
         """Start the embedding server."""
         passages_file = kwargs.get("passages_file")
@@ -225,7 +225,7 @@ class EmbeddingServerManager:
         model_name: str,
         embedding_mode: str = "sentence-transformers",
         **kwargs,
-    ) -> Tuple[bool, int]:
+    ) -> tuple[bool, int]:
         """Start server with Colab-specific configuration."""
         # Try to find an available port
         try:
@@ -261,7 +261,7 @@ class EmbeddingServerManager:
     def _start_new_server(
         self, port: int, model_name: str, embedding_mode: str, **kwargs
-    ) -> Tuple[bool, int]:
+    ) -> tuple[bool, int]:
         """Start a new embedding server on the given port."""
         logger.info(f"Starting embedding server on port {port}...")
@@ -321,7 +321,7 @@ class EmbeddingServerManager:
             atexit.register(lambda: self.stop_server() if self.server_process else None)
             self._atexit_registered = True
-    def _wait_for_server_ready(self, port: int) -> Tuple[bool, int]:
+    def _wait_for_server_ready(self, port: int) -> tuple[bool, int]:
         """Wait for the server to be ready."""
         max_wait, wait_interval = 120, 0.5
         for _ in range(int(max_wait / wait_interval)):
@@ -403,7 +403,7 @@ class EmbeddingServerManager:
             atexit.register(lambda: self.stop_server() if self.server_process else None)
             self._atexit_registered = True
-    def _wait_for_server_ready_colab(self, port: int) -> Tuple[bool, int]:
+    def _wait_for_server_ready_colab(self, port: int) -> tuple[bool, int]:
         """Wait for the server to be ready with Colab-specific timeout."""
         max_wait, wait_interval = 30, 0.5  # Shorter timeout for Colab


@@ -8,7 +8,7 @@ class LeannBackendBuilderInterface(ABC):
     """Backend interface for building indexes"""
     @abstractmethod
-    def build(self, data: np.ndarray, ids: List[str], index_path: str, **kwargs) -> None:
+    def build(self, data: np.ndarray, ids: list[str], index_path: str, **kwargs) -> None:
         """Build index
         Args:
@@ -52,7 +52,7 @@ class LeannBackendSearcherInterface(ABC):
         pruning_strategy: Literal["global", "local", "proportional"] = "global",
         zmq_port: Optional[int] = None,
         **kwargs,
-    ) -> Dict[str, Any]:
+    ) -> dict[str, Any]:
         """Search for nearest neighbors
         Args:


@@ -7,7 +7,7 @@ from typing import TYPE_CHECKING
 if TYPE_CHECKING:
     from leann.interface import LeannBackendFactoryInterface
-BACKEND_REGISTRY: Dict[str, "LeannBackendFactoryInterface"] = {}
+BACKEND_REGISTRY: dict[str, "LeannBackendFactoryInterface"] = {}
 def register_backend(name: str):


@@ -46,7 +46,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
             backend_module_name=backend_module_name,
         )
-    def _load_meta(self) -> Dict[str, Any]:
+    def _load_meta(self) -> dict[str, Any]:
         """Loads the metadata file associated with the index."""
         # This is the corrected logic for finding the meta file.
         meta_path = self.index_dir / f"{self.index_path.name}.meta.json"
@@ -174,7 +174,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
         pruning_strategy: Literal["global", "local", "proportional"] = "global",
         zmq_port: Optional[int] = None,
         **kwargs,
-    ) -> Dict[str, Any]:
+    ) -> dict[str, Any]:
         """
         Search for the top_k nearest neighbors of the query vector.