fix: Python 3.9 compatibility - replace union types and builtin generics

- Replace 'str | None' with 'Optional[str]'
- Replace 'list[str]' with 'List[str]'
- Replace 'dict[' with 'Dict['
- Replace 'tuple[' with 'Tuple['
- Add missing typing imports (List, Dict, Tuple)

Fixes TypeError: unsupported operand type(s) for |: 'type' and 'NoneType'
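A minimal sketch of the failure mode (hypothetical function, not taken from
this diff): PEP 604 unions in annotations are evaluated when the `def`
statement runs, so merely importing a module that uses them crashes on 3.9,
while the `typing` spellings work. Note that builtin-generic subscripts
(`list[str]`, `dict[str, Any]`) already run on 3.9 under PEP 585; the `|`
union is the hard failure, so swapping the generics is a consistency measure.

    from typing import Optional

    # Fails at definition time on Python 3.9, because `str | None` needs
    # the PEP 604 runtime support added in Python 3.10:
    #   def lookup(key: str) -> str | None: ...
    #   TypeError: unsupported operand type(s) for |: 'type' and 'NoneType'

    # Python 3.9-compatible spelling:
    def lookup(key: str) -> Optional[str]:
        ...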

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
Andy Lee
2025-08-10 00:29:46 +00:00
parent 728fa42ad5
commit ffba435252
13 changed files with 59 additions and 58 deletions
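A quick way to verify the fix on a 3.9 interpreter (a hedged sketch; the
package name `leann` is taken from the imports in this diff, and the walk
assumes it is installed as a regular package): importing every module
surfaces any remaining PEP 604 annotations, since they fail at import time.

    import importlib
    import pkgutil

    import leann  # package name as it appears in this diff's imports

    # Any module still carrying `X | Y` annotations raises TypeError here.
    for info in pkgutil.walk_packages(leann.__path__, leann.__name__ + "."):
        importlib.import_module(info.name)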

View File

@@ -4,7 +4,7 @@ import os
import struct
import sys
from pathlib import Path
-from typing import Any, Literal, Optional
+from typing import Any, Dict, List, Literal, Optional, Tuple
import numpy as np
import psutil
@@ -85,7 +85,7 @@ def _write_vectors_to_bin(data: np.ndarray, file_path: Path):
f.write(data.tobytes())
-def _calculate_smart_memory_config(data: np.ndarray) -> tuple[float, float]:
+def _calculate_smart_memory_config(data: np.ndarray) -> Tuple[float, float]:
"""
Calculate smart memory configuration for DiskANN based on data size and system specs.
@@ -202,7 +202,7 @@ class DiskannBuilder(LeannBackendBuilderInterface):
size_mb = file_path.stat().st_size / (1024 * 1024)
logger.info(f" - {filename} ({size_mb:.1f} MB)")
-def build(self, data: np.ndarray, ids: list[str], index_path: str, **kwargs):
+def build(self, data: np.ndarray, ids: List[str], index_path: str, **kwargs):
path = Path(index_path)
index_dir = path.parent
index_prefix = path.stem
@@ -388,7 +388,7 @@ class DiskannSearcher(BaseSearcher):
batch_recompute: bool = False,
dedup_node_dis: bool = False,
**kwargs,
-) -> dict[str, Any]:
+) -> Dict[str, Any]:
"""
Search for nearest neighbors using DiskANN index.

View File

@@ -12,7 +12,7 @@ import shutil
import subprocess
import tempfile
from pathlib import Path
-from typing import Optional
+from typing import Optional, Tuple
class GraphPartitioner:
@@ -92,7 +92,7 @@ class GraphPartitioner:
output_dir: Optional[str] = None,
partition_prefix: Optional[str] = None,
**kwargs,
-) -> tuple[str, str]:
+) -> Tuple[str, str]:
"""
Partition a disk-based index for improved performance.
@@ -263,11 +263,11 @@ class GraphPartitioner:
def partition_graph(
index_prefix_path: str,
-output_dir: str | None = None,
-partition_prefix: str | None = None,
+output_dir: Optional[str] = None,
+partition_prefix: Optional[str] = None,
build_type: str = "release",
**kwargs,
-) -> tuple[str, str]:
+) -> Tuple[str, str]:
"""
Convenience function to partition a graph index.

View File

@@ -15,7 +15,7 @@ from typing import Optional
def partition_graph_simple(
index_prefix_path: str, output_dir: Optional[str] = None, **kwargs
-) -> tuple[str, str]:
+) -> Tuple[str, str]:
"""
Simple function to partition a graph index.

View File

@@ -61,7 +61,7 @@ class HNSWBuilder(LeannBackendBuilderInterface):
"is_recompute is False, but is_compact is True. This is not compatible now. change is compact to False and you can use the original HNSW index."
)
-def build(self, data: np.ndarray, ids: list[str], index_path: str, **kwargs):
+def build(self, data: np.ndarray, ids: List[str], index_path: str, **kwargs):
from . import faiss # type: ignore
path = Path(index_path)
@@ -160,7 +160,7 @@ class HNSWSearcher(BaseSearcher):
pruning_strategy: Literal["global", "local", "proportional"] = "global",
batch_size: int = 0,
**kwargs,
-) -> dict[str, Any]:
+) -> Dict[str, Any]:
"""
Search for nearest neighbors using HNSW index.

View File

@@ -23,13 +23,13 @@ from .registry import BACKEND_REGISTRY
logger = logging.getLogger(__name__)
-def get_registered_backends() -> list[str]:
+def get_registered_backends() -> List[str]:
"""Get list of registered backend names."""
return list(BACKEND_REGISTRY.keys())
def compute_embeddings(
-chunks: list[str],
+chunks: List[str],
model_name: str,
mode: str = "sentence-transformers",
use_server: bool = True,
@@ -70,7 +70,7 @@ def compute_embeddings(
)
-def compute_embeddings_via_server(chunks: list[str], model_name: str, port: int) -> np.ndarray:
+def compute_embeddings_via_server(chunks: List[str], model_name: str, port: int) -> np.ndarray:
"""Computes embeddings using sentence-transformers.
Args:
@@ -113,12 +113,12 @@ class SearchResult:
id: str
score: float
text: str
-metadata: dict[str, Any] = field(default_factory=dict)
+metadata: Dict[str, Any] = field(default_factory=dict)
class PassageManager:
def __init__(
-self, passage_sources: list[dict[str, Any]], metadata_file_path: Optional[str] = None
+self, passage_sources: list[Dict[str, Any]], metadata_file_path: Optional[str] = None
):
self.offset_maps = {}
self.passage_files = {}
@@ -162,7 +162,7 @@ class PassageManager:
for passage_id, offset in offset_map.items():
self.global_offset_map[passage_id] = (passage_file, offset)
-def get_passage(self, passage_id: str) -> dict[str, Any]:
+def get_passage(self, passage_id: str) -> Dict[str, Any]:
if passage_id in self.global_offset_map:
passage_file, offset = self.global_offset_map[passage_id]
# Lazy file opening - only open when needed
@@ -260,9 +260,9 @@ class LeannBuilder:
)
self.backend_kwargs = backend_kwargs
-self.chunks: list[dict[str, Any]] = []
+self.chunks: list[Dict[str, Any]] = []
-def add_text(self, text: str, metadata: Optional[dict[str, Any]] = None):
+def add_text(self, text: str, metadata: Optional[Dict[str, Any]] = None):
if metadata is None:
metadata = {}
passage_id = metadata.get("id", str(len(self.chunks)))
@@ -618,7 +618,7 @@ class LeannChat:
def __init__(
self,
index_path: str,
-llm_config: Optional[dict[str, Any]] = None,
+llm_config: Optional[Dict[str, Any]] = None,
enable_warmup: bool = False,
**kwargs,
):
@@ -634,7 +634,7 @@ class LeannChat:
prune_ratio: float = 0.0,
recompute_embeddings: bool = True,
pruning_strategy: Literal["global", "local", "proportional"] = "global",
-llm_kwargs: Optional[dict[str, Any]] = None,
+llm_kwargs: Optional[Dict[str, Any]] = None,
expected_zmq_port: int = 5557,
**search_kwargs,
):

View File

@@ -8,7 +8,7 @@ import difflib
import logging
import os
from abc import ABC, abstractmethod
-from typing import Any, Optional
+from typing import Any, List, Optional, Tuple
import torch
@@ -17,7 +17,7 @@ logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
-def check_ollama_models(host: str) -> list[str]:
+def check_ollama_models(host: str) -> List[str]:
"""Check available Ollama models and return a list"""
try:
import requests
@@ -31,7 +31,7 @@ def check_ollama_models(host: str) -> list[str]:
return []
-def check_ollama_model_exists_remotely(model_name: str) -> tuple[bool, list[str]]:
+def check_ollama_model_exists_remotely(model_name: str) -> Tuple[bool, List[str]]:
"""Check if a model exists in Ollama's remote library and return available tags
Returns:
@@ -94,7 +94,7 @@ def check_ollama_model_exists_remotely(model_name: str) -> tuple[bool, list[str]
return True, []
-def search_ollama_models_fuzzy(query: str, available_models: list[str]) -> list[str]:
+def search_ollama_models_fuzzy(query: str, available_models: List[str]) -> List[str]:
"""Use intelligent fuzzy search for Ollama models"""
if not available_models:
return []
@@ -169,7 +169,7 @@ def search_ollama_models_fuzzy(query: str, available_models: list[str]) -> list[
# Remove this too - no need for fallback
-def suggest_similar_models(invalid_model: str, available_models: list[str]) -> list[str]:
+def suggest_similar_models(invalid_model: str, available_models: List[str]) -> List[str]:
"""Use difflib to find similar model names"""
if not available_models:
return []
@@ -190,7 +190,7 @@ def check_hf_model_exists(model_name: str) -> bool:
return False
-def get_popular_hf_models() -> list[str]:
+def get_popular_hf_models() -> List[str]:
"""Return a list of popular HuggingFace models for suggestions"""
try:
from huggingface_hub import list_models
@@ -222,7 +222,7 @@ def get_popular_hf_models() -> list[str]:
return _get_fallback_hf_models()
-def _get_fallback_hf_models() -> list[str]:
+def _get_fallback_hf_models() -> List[str]:
"""Fallback list of popular HuggingFace models"""
return [
"microsoft/DialoGPT-medium",
@@ -238,7 +238,7 @@ def _get_fallback_hf_models() -> list[str]:
]
-def search_hf_models_fuzzy(query: str, limit: int = 10) -> list[str]:
+def search_hf_models_fuzzy(query: str, limit: int = 10) -> List[str]:
"""Use HuggingFace Hub's native fuzzy search for model suggestions"""
try:
from huggingface_hub import list_models
@@ -304,14 +304,14 @@ def search_hf_models_fuzzy(query: str, limit: int = 10) -> list[str]:
return []
-def search_hf_models(query: str, limit: int = 10) -> list[str]:
+def search_hf_models(query: str, limit: int = 10) -> List[str]:
"""Simple search for HuggingFace models based on query (kept for backward compatibility)"""
return search_hf_models_fuzzy(query, limit)
def validate_model_and_suggest(
model_name: str, llm_type: str, host: str = "http://localhost:11434"
-) -> str | None:
+) -> Optional[str]:
"""Validate model name and provide suggestions if invalid"""
if llm_type == "ollama":
available_models = check_ollama_models(host)
@@ -761,7 +761,7 @@ class SimulatedChat(LLMInterface):
return "This is a simulated answer from the LLM based on the retrieved context."
-def get_llm(llm_config: Optional[dict[str, Any]] = None) -> LLMInterface:
+def get_llm(llm_config: Optional[Dict[str, Any]] = None) -> LLMInterface:
"""
Factory function to get an LLM interface based on configuration.

View File

@@ -1,6 +1,7 @@
import argparse
import asyncio
from pathlib import Path
+from typing import Optional
from llama_index.core import SimpleDirectoryReader
from llama_index.core.node_parser import SentenceSplitter
@@ -277,7 +278,7 @@ Examples:
print(f' leann search {example_name} "your query"')
print(f" leann ask {example_name} --interactive")
-def load_documents(self, docs_dir: str, custom_file_types: str | None = None):
+def load_documents(self, docs_dir: str, custom_file_types: Optional[str] = None):
print(f"Loading documents from {docs_dir}...")
if custom_file_types:
print(f"Using custom file types: {custom_file_types}")

View File

@@ -7,7 +7,7 @@ Preserves all optimization parameters to ensure performance
import logging
import os
from concurrent.futures import ThreadPoolExecutor, as_completed
-from typing import Any
+from typing import Any, Dict, List
import numpy as np
import torch
@@ -19,11 +19,11 @@ log_level = getattr(logging, LOG_LEVEL, logging.WARNING)
logger.setLevel(log_level)
# Global model cache to avoid repeated loading
-_model_cache: dict[str, Any] = {}
+_model_cache: Dict[str, Any] = {}
def compute_embeddings(
-texts: list[str],
+texts: List[str],
model_name: str,
mode: str = "sentence-transformers",
is_build: bool = False,
@@ -63,7 +63,7 @@ def compute_embeddings(
def compute_embeddings_sentence_transformers(
-texts: list[str],
+texts: List[str],
model_name: str,
use_fp16: bool = True,
device: str = "auto",
@@ -235,7 +235,7 @@ def compute_embeddings_sentence_transformers(
return embeddings
-def compute_embeddings_openai(texts: list[str], model_name: str) -> np.ndarray:
+def compute_embeddings_openai(texts: List[str], model_name: str) -> np.ndarray:
# TODO: @yichuan-w add progress bar only in build mode
"""Compute embeddings using OpenAI API"""
try:
@@ -296,7 +296,7 @@ def compute_embeddings_openai(texts: list[str], model_name: str) -> np.ndarray:
return embeddings
-def compute_embeddings_mlx(chunks: list[str], model_name: str, batch_size: int = 16) -> np.ndarray:
+def compute_embeddings_mlx(chunks: List[str], model_name: str, batch_size: int = 16) -> np.ndarray:
# TODO: @yichuan-w add progress bar only in build mode
"""Computes embeddings using an MLX model."""
try:
@@ -371,7 +371,7 @@ def compute_embeddings_mlx(chunks: list[str], model_name: str, batch_size: int =
def compute_embeddings_ollama(
-texts: list[str], model_name: str, is_build: bool = False, host: str = "http://localhost:11434"
+texts: List[str], model_name: str, is_build: bool = False, host: str = "http://localhost:11434"
) -> np.ndarray:
"""
Compute embeddings using Ollama API.

View File

@@ -6,7 +6,7 @@ import subprocess
import sys
import time
from pathlib import Path
-from typing import Optional
+from typing import Optional, Tuple
import psutil
@@ -147,7 +147,7 @@ def _check_passages_in_cmdline(cmdline: list, expected_passages_file: str) -> bo
def _find_compatible_port_or_next_available(
start_port: int, model_name: str, passages_file: str, max_attempts: int = 100
-) -> tuple[int, bool]:
+) -> Tuple[int, bool]:
"""
Find a port that either has a compatible server or is available.
Returns (port, is_compatible) where is_compatible indicates if we found a matching server.
@@ -193,7 +193,7 @@ class EmbeddingServerManager:
model_name: str,
embedding_mode: str = "sentence-transformers",
**kwargs,
-) -> tuple[bool, int]:
+) -> Tuple[bool, int]:
"""Start the embedding server."""
passages_file = kwargs.get("passages_file")
@@ -225,7 +225,7 @@ class EmbeddingServerManager:
model_name: str,
embedding_mode: str = "sentence-transformers",
**kwargs,
-) -> tuple[bool, int]:
+) -> Tuple[bool, int]:
"""Start server with Colab-specific configuration."""
# Try to find an available port
try:
@@ -261,7 +261,7 @@ class EmbeddingServerManager:
def _start_new_server(
self, port: int, model_name: str, embedding_mode: str, **kwargs
-) -> tuple[bool, int]:
+) -> Tuple[bool, int]:
"""Start a new embedding server on the given port."""
logger.info(f"Starting embedding server on port {port}...")
@@ -321,7 +321,7 @@ class EmbeddingServerManager:
atexit.register(lambda: self.stop_server() if self.server_process else None)
self._atexit_registered = True
-def _wait_for_server_ready(self, port: int) -> tuple[bool, int]:
+def _wait_for_server_ready(self, port: int) -> Tuple[bool, int]:
"""Wait for the server to be ready."""
max_wait, wait_interval = 120, 0.5
for _ in range(int(max_wait / wait_interval)):
@@ -403,7 +403,7 @@ class EmbeddingServerManager:
atexit.register(lambda: self.stop_server() if self.server_process else None)
self._atexit_registered = True
-def _wait_for_server_ready_colab(self, port: int) -> tuple[bool, int]:
+def _wait_for_server_ready_colab(self, port: int) -> Tuple[bool, int]:
"""Wait for the server to be ready with Colab-specific timeout."""
max_wait, wait_interval = 30, 0.5 # Shorter timeout for Colab

View File

@@ -8,7 +8,7 @@ class LeannBackendBuilderInterface(ABC):
"""Backend interface for building indexes"""
@abstractmethod
-def build(self, data: np.ndarray, ids: list[str], index_path: str, **kwargs) -> None:
+def build(self, data: np.ndarray, ids: List[str], index_path: str, **kwargs) -> None:
"""Build index
Args:
@@ -52,7 +52,7 @@ class LeannBackendSearcherInterface(ABC):
pruning_strategy: Literal["global", "local", "proportional"] = "global",
zmq_port: Optional[int] = None,
**kwargs,
-) -> dict[str, Any]:
+) -> Dict[str, Any]:
"""Search for nearest neighbors
Args:

View File

@@ -7,7 +7,7 @@ from typing import TYPE_CHECKING
if TYPE_CHECKING:
from leann.interface import LeannBackendFactoryInterface
-BACKEND_REGISTRY: dict[str, "LeannBackendFactoryInterface"] = {}
+BACKEND_REGISTRY: Dict[str, "LeannBackendFactoryInterface"] = {}
def register_backend(name: str):

View File

@@ -46,7 +46,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
backend_module_name=backend_module_name,
)
-def _load_meta(self) -> dict[str, Any]:
+def _load_meta(self) -> Dict[str, Any]:
"""Loads the metadata file associated with the index."""
# This is the corrected logic for finding the meta file.
meta_path = self.index_dir / f"{self.index_path.name}.meta.json"
@@ -174,7 +174,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
pruning_strategy: Literal["global", "local", "proportional"] = "global",
zmq_port: Optional[int] = None,
**kwargs,
-) -> dict[str, Any]:
+) -> Dict[str, Any]:
"""
Search for the top_k nearest neighbors of the query vector.

uv.lock (generated)
View File

@@ -2155,7 +2155,7 @@ wheels = [
[[package]]
name = "leann-backend-diskann"
version = "0.2.1"
version = "0.2.5"
source = { editable = "packages/leann-backend-diskann" }
dependencies = [
{ name = "leann-core" },
@@ -2167,14 +2167,14 @@ dependencies = [
[package.metadata]
requires-dist = [
{ name = "leann-core", specifier = "==0.2.1" },
{ name = "leann-core", specifier = "==0.2.5" },
{ name = "numpy" },
{ name = "protobuf", specifier = ">=3.19.0" },
]
[[package]]
name = "leann-backend-hnsw"
version = "0.2.1"
version = "0.2.5"
source = { editable = "packages/leann-backend-hnsw" }
dependencies = [
{ name = "leann-core" },
@@ -2187,7 +2187,7 @@ dependencies = [
[package.metadata]
requires-dist = [
{ name = "leann-core", specifier = "==0.2.1" },
{ name = "leann-core", specifier = "==0.2.5" },
{ name = "msgpack", specifier = ">=1.0.0" },
{ name = "numpy" },
{ name = "pyzmq", specifier = ">=23.0.0" },
@@ -2195,7 +2195,7 @@ requires-dist = [
[[package]]
name = "leann-core"
version = "0.2.1"
version = "0.2.5"
source = { editable = "packages/leann-core" }
dependencies = [
{ name = "accelerate" },