fix: Python 3.9 compatibility - replace union types and builtin generics
- Replace 'str | None' with 'Optional[str]'
- Replace 'list[str]' with 'List[str]'
- Replace 'dict[' with 'Dict['
- Replace 'tuple[' with 'Tuple['
- Add missing typing imports (List, Dict, Tuple)

Fixes TypeError: unsupported operand type(s) for |: 'type' and 'NoneType'

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
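For context, a minimal sketch of the failure mode and the fix (function names below are illustrative, not from this repo): on Python 3.9, a PEP 604 union such as `str | None` in an annotation is evaluated when the `def` statement runs and raises exactly the TypeError quoted above, because `X | Y` on types requires Python 3.10+. The `typing` aliases are valid on 3.9 and earlier:

```python
from typing import Dict, List, Optional, Tuple

# Broken on Python 3.9 -- the annotation is evaluated at definition time:
#
#     def load(path: str, encoding: str | None = None) -> list[str]: ...
#     TypeError: unsupported operand type(s) for |: 'type' and 'NoneType'
#
# Python 3.9-compatible spellings using typing aliases:

def load(path: str, encoding: Optional[str] = None) -> List[str]:
    """Read a text file and return its lines (illustrative example)."""
    with open(path, encoding=encoding) as f:
        return f.read().splitlines()

def memory_budget() -> Tuple[float, float]:
    """Return an illustrative (search_gb, build_gb) pair."""
    return 4.0, 8.0

# Module-level variable annotations are evaluated eagerly too, so
# dict[...]-style generics are rewritten to typing aliases as well:
_cache: Dict[str, object] = {}
```

Note that `from __future__ import annotations` (PEP 563) would also defer annotation evaluation, but it must be added per file and does not cover types used as runtime values, so rewriting to `typing` aliases is the more conservative fix.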
@@ -4,7 +4,7 @@ import os
 import struct
 import sys
 from pathlib import Path
-from typing import Any, Literal, Optional
+from typing import Any, Dict, List, Literal, Optional, Tuple

 import numpy as np
 import psutil
@@ -85,7 +85,7 @@ def _write_vectors_to_bin(data: np.ndarray, file_path: Path):
        f.write(data.tobytes())


-def _calculate_smart_memory_config(data: np.ndarray) -> tuple[float, float]:
+def _calculate_smart_memory_config(data: np.ndarray) -> Tuple[float, float]:
     """
     Calculate smart memory configuration for DiskANN based on data size and system specs.

@@ -202,7 +202,7 @@ class DiskannBuilder(LeannBackendBuilderInterface):
            size_mb = file_path.stat().st_size / (1024 * 1024)
            logger.info(f" - {filename} ({size_mb:.1f} MB)")

-    def build(self, data: np.ndarray, ids: list[str], index_path: str, **kwargs):
+    def build(self, data: np.ndarray, ids: List[str], index_path: str, **kwargs):
        path = Path(index_path)
        index_dir = path.parent
        index_prefix = path.stem
@@ -388,7 +388,7 @@ class DiskannSearcher(BaseSearcher):
        batch_recompute: bool = False,
        dedup_node_dis: bool = False,
        **kwargs,
-    ) -> dict[str, Any]:
+    ) -> Dict[str, Any]:
        """
        Search for nearest neighbors using DiskANN index.

@@ -12,7 +12,7 @@ import shutil
 import subprocess
 import tempfile
 from pathlib import Path
-from typing import Optional
+from typing import Optional, Tuple


 class GraphPartitioner:
@@ -92,7 +92,7 @@ class GraphPartitioner:
        output_dir: Optional[str] = None,
        partition_prefix: Optional[str] = None,
        **kwargs,
-    ) -> tuple[str, str]:
+    ) -> Tuple[str, str]:
        """
        Partition a disk-based index for improved performance.

@@ -263,11 +263,11 @@ class GraphPartitioner:

 def partition_graph(
     index_prefix_path: str,
-    output_dir: str | None = None,
-    partition_prefix: str | None = None,
+    output_dir: Optional[str] = None,
+    partition_prefix: Optional[str] = None,
     build_type: str = "release",
     **kwargs,
-) -> tuple[str, str]:
+) -> Tuple[str, str]:
    """
    Convenience function to partition a graph index.

@@ -15,7 +15,7 @@ from typing import Optional

 def partition_graph_simple(
     index_prefix_path: str, output_dir: Optional[str] = None, **kwargs
-) -> tuple[str, str]:
+) -> Tuple[str, str]:
    """
    Simple function to partition a graph index.

@@ -61,7 +61,7 @@ class HNSWBuilder(LeannBackendBuilderInterface):
                "is_recompute is False, but is_compact is True. This is not compatible now. change is compact to False and you can use the original HNSW index."
            )

-    def build(self, data: np.ndarray, ids: list[str], index_path: str, **kwargs):
+    def build(self, data: np.ndarray, ids: List[str], index_path: str, **kwargs):
        from . import faiss  # type: ignore

        path = Path(index_path)
@@ -160,7 +160,7 @@ class HNSWSearcher(BaseSearcher):
        pruning_strategy: Literal["global", "local", "proportional"] = "global",
        batch_size: int = 0,
        **kwargs,
-    ) -> dict[str, Any]:
+    ) -> Dict[str, Any]:
        """
        Search for nearest neighbors using HNSW index.

@@ -23,13 +23,13 @@ from .registry import BACKEND_REGISTRY
 logger = logging.getLogger(__name__)


-def get_registered_backends() -> list[str]:
+def get_registered_backends() -> List[str]:
     """Get list of registered backend names."""
     return list(BACKEND_REGISTRY.keys())


 def compute_embeddings(
-    chunks: list[str],
+    chunks: List[str],
     model_name: str,
     mode: str = "sentence-transformers",
     use_server: bool = True,
@@ -70,7 +70,7 @@ def compute_embeddings(
    )


-def compute_embeddings_via_server(chunks: list[str], model_name: str, port: int) -> np.ndarray:
+def compute_embeddings_via_server(chunks: List[str], model_name: str, port: int) -> np.ndarray:
    """Computes embeddings using sentence-transformers.

    Args:
@@ -113,12 +113,12 @@ class SearchResult:
     id: str
     score: float
     text: str
-    metadata: dict[str, Any] = field(default_factory=dict)
+    metadata: Dict[str, Any] = field(default_factory=dict)


 class PassageManager:
     def __init__(
-        self, passage_sources: list[dict[str, Any]], metadata_file_path: Optional[str] = None
+        self, passage_sources: list[Dict[str, Any]], metadata_file_path: Optional[str] = None
     ):
         self.offset_maps = {}
         self.passage_files = {}
@@ -162,7 +162,7 @@ class PassageManager:
            for passage_id, offset in offset_map.items():
                self.global_offset_map[passage_id] = (passage_file, offset)

-    def get_passage(self, passage_id: str) -> dict[str, Any]:
+    def get_passage(self, passage_id: str) -> Dict[str, Any]:
        if passage_id in self.global_offset_map:
            passage_file, offset = self.global_offset_map[passage_id]
            # Lazy file opening - only open when needed
@@ -260,9 +260,9 @@ class LeannBuilder:
            )

        self.backend_kwargs = backend_kwargs
-        self.chunks: list[dict[str, Any]] = []
+        self.chunks: list[Dict[str, Any]] = []

-    def add_text(self, text: str, metadata: Optional[dict[str, Any]] = None):
+    def add_text(self, text: str, metadata: Optional[Dict[str, Any]] = None):
        if metadata is None:
            metadata = {}
        passage_id = metadata.get("id", str(len(self.chunks)))
@@ -618,7 +618,7 @@ class LeannChat:
    def __init__(
        self,
        index_path: str,
-        llm_config: Optional[dict[str, Any]] = None,
+        llm_config: Optional[Dict[str, Any]] = None,
        enable_warmup: bool = False,
        **kwargs,
    ):
@@ -634,7 +634,7 @@ class LeannChat:
        prune_ratio: float = 0.0,
        recompute_embeddings: bool = True,
        pruning_strategy: Literal["global", "local", "proportional"] = "global",
-        llm_kwargs: Optional[dict[str, Any]] = None,
+        llm_kwargs: Optional[Dict[str, Any]] = None,
        expected_zmq_port: int = 5557,
        **search_kwargs,
    ):
@@ -8,7 +8,7 @@ import difflib
 import logging
 import os
 from abc import ABC, abstractmethod
-from typing import Any, Optional
+from typing import Any, List, Optional, Tuple

 import torch

@@ -17,7 +17,7 @@ logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)


-def check_ollama_models(host: str) -> list[str]:
+def check_ollama_models(host: str) -> List[str]:
    """Check available Ollama models and return a list"""
    try:
        import requests
@@ -31,7 +31,7 @@ def check_ollama_models(host: str) -> list[str]:
        return []


-def check_ollama_model_exists_remotely(model_name: str) -> tuple[bool, list[str]]:
+def check_ollama_model_exists_remotely(model_name: str) -> Tuple[bool, List[str]]:
    """Check if a model exists in Ollama's remote library and return available tags

    Returns:
@@ -94,7 +94,7 @@ def check_ollama_model_exists_remotely(model_name: str) -> tuple[bool, list[str
    return True, []


-def search_ollama_models_fuzzy(query: str, available_models: list[str]) -> list[str]:
+def search_ollama_models_fuzzy(query: str, available_models: List[str]) -> List[str]:
    """Use intelligent fuzzy search for Ollama models"""
    if not available_models:
        return []
@@ -169,7 +169,7 @@ def search_ollama_models_fuzzy(query: str, available_models: list[str]) -> list[
    # Remove this too - no need for fallback


-def suggest_similar_models(invalid_model: str, available_models: list[str]) -> list[str]:
+def suggest_similar_models(invalid_model: str, available_models: List[str]) -> List[str]:
    """Use difflib to find similar model names"""
    if not available_models:
        return []
@@ -190,7 +190,7 @@ def check_hf_model_exists(model_name: str) -> bool:
        return False


-def get_popular_hf_models() -> list[str]:
+def get_popular_hf_models() -> List[str]:
    """Return a list of popular HuggingFace models for suggestions"""
    try:
        from huggingface_hub import list_models
@@ -222,7 +222,7 @@ def get_popular_hf_models() -> list[str]:
        return _get_fallback_hf_models()


-def _get_fallback_hf_models() -> list[str]:
+def _get_fallback_hf_models() -> List[str]:
    """Fallback list of popular HuggingFace models"""
    return [
        "microsoft/DialoGPT-medium",
@@ -238,7 +238,7 @@ def _get_fallback_hf_models() -> list[str]:
    ]


-def search_hf_models_fuzzy(query: str, limit: int = 10) -> list[str]:
+def search_hf_models_fuzzy(query: str, limit: int = 10) -> List[str]:
    """Use HuggingFace Hub's native fuzzy search for model suggestions"""
    try:
        from huggingface_hub import list_models
@@ -304,14 +304,14 @@ def search_hf_models_fuzzy(query: str, limit: int = 10) -> list[str]:
    return []


-def search_hf_models(query: str, limit: int = 10) -> list[str]:
+def search_hf_models(query: str, limit: int = 10) -> List[str]:
    """Simple search for HuggingFace models based on query (kept for backward compatibility)"""
    return search_hf_models_fuzzy(query, limit)


 def validate_model_and_suggest(
     model_name: str, llm_type: str, host: str = "http://localhost:11434"
-) -> str | None:
+) -> Optional[str]:
    """Validate model name and provide suggestions if invalid"""
    if llm_type == "ollama":
        available_models = check_ollama_models(host)
@@ -761,7 +761,7 @@ class SimulatedChat(LLMInterface):
        return "This is a simulated answer from the LLM based on the retrieved context."


-def get_llm(llm_config: Optional[dict[str, Any]] = None) -> LLMInterface:
+def get_llm(llm_config: Optional[Dict[str, Any]] = None) -> LLMInterface:
    """
    Factory function to get an LLM interface based on configuration.

@@ -1,6 +1,7 @@
 import argparse
 import asyncio
 from pathlib import Path
+from typing import Optional

 from llama_index.core import SimpleDirectoryReader
 from llama_index.core.node_parser import SentenceSplitter
@@ -277,7 +278,7 @@ Examples:
        print(f' leann search {example_name} "your query"')
        print(f" leann ask {example_name} --interactive")

-    def load_documents(self, docs_dir: str, custom_file_types: str | None = None):
+    def load_documents(self, docs_dir: str, custom_file_types: Optional[str] = None):
        print(f"Loading documents from {docs_dir}...")
        if custom_file_types:
            print(f"Using custom file types: {custom_file_types}")
@@ -7,7 +7,7 @@ Preserves all optimization parameters to ensure performance
 import logging
 import os
 from concurrent.futures import ThreadPoolExecutor, as_completed
-from typing import Any
+from typing import Any, Dict, List

 import numpy as np
 import torch
@@ -19,11 +19,11 @@ log_level = getattr(logging, LOG_LEVEL, logging.WARNING)
 logger.setLevel(log_level)

 # Global model cache to avoid repeated loading
-_model_cache: dict[str, Any] = {}
+_model_cache: Dict[str, Any] = {}


 def compute_embeddings(
-    texts: list[str],
+    texts: List[str],
     model_name: str,
     mode: str = "sentence-transformers",
     is_build: bool = False,
@@ -63,7 +63,7 @@ def compute_embeddings(


 def compute_embeddings_sentence_transformers(
-    texts: list[str],
+    texts: List[str],
     model_name: str,
     use_fp16: bool = True,
     device: str = "auto",
@@ -235,7 +235,7 @@ def compute_embeddings_sentence_transformers(
    return embeddings


-def compute_embeddings_openai(texts: list[str], model_name: str) -> np.ndarray:
+def compute_embeddings_openai(texts: List[str], model_name: str) -> np.ndarray:
    # TODO: @yichuan-w add progress bar only in build mode
    """Compute embeddings using OpenAI API"""
    try:
@@ -296,7 +296,7 @@ def compute_embeddings_openai(texts: list[str], model_name: str) -> np.ndarray:
    return embeddings


-def compute_embeddings_mlx(chunks: list[str], model_name: str, batch_size: int = 16) -> np.ndarray:
+def compute_embeddings_mlx(chunks: List[str], model_name: str, batch_size: int = 16) -> np.ndarray:
    # TODO: @yichuan-w add progress bar only in build mode
    """Computes embeddings using an MLX model."""
    try:
@@ -371,7 +371,7 @@ def compute_embeddings_mlx(chunks: list[str], model_name: str, batch_size: int =


 def compute_embeddings_ollama(
-    texts: list[str], model_name: str, is_build: bool = False, host: str = "http://localhost:11434"
+    texts: List[str], model_name: str, is_build: bool = False, host: str = "http://localhost:11434"
 ) -> np.ndarray:
    """
    Compute embeddings using Ollama API.

@@ -6,7 +6,7 @@ import subprocess
 import sys
 import time
 from pathlib import Path
-from typing import Optional
+from typing import Optional, Tuple

 import psutil

@@ -147,7 +147,7 @@ def _check_passages_in_cmdline(cmdline: list, expected_passages_file: str) -> bo

 def _find_compatible_port_or_next_available(
     start_port: int, model_name: str, passages_file: str, max_attempts: int = 100
-) -> tuple[int, bool]:
+) -> Tuple[int, bool]:
    """
    Find a port that either has a compatible server or is available.
    Returns (port, is_compatible) where is_compatible indicates if we found a matching server.
@@ -193,7 +193,7 @@ class EmbeddingServerManager:
        model_name: str,
        embedding_mode: str = "sentence-transformers",
        **kwargs,
-    ) -> tuple[bool, int]:
+    ) -> Tuple[bool, int]:
        """Start the embedding server."""
        passages_file = kwargs.get("passages_file")

@@ -225,7 +225,7 @@ class EmbeddingServerManager:
        model_name: str,
        embedding_mode: str = "sentence-transformers",
        **kwargs,
-    ) -> tuple[bool, int]:
+    ) -> Tuple[bool, int]:
        """Start server with Colab-specific configuration."""
        # Try to find an available port
        try:
@@ -261,7 +261,7 @@ class EmbeddingServerManager:

    def _start_new_server(
        self, port: int, model_name: str, embedding_mode: str, **kwargs
-    ) -> tuple[bool, int]:
+    ) -> Tuple[bool, int]:
        """Start a new embedding server on the given port."""
        logger.info(f"Starting embedding server on port {port}...")

@@ -321,7 +321,7 @@ class EmbeddingServerManager:
            atexit.register(lambda: self.stop_server() if self.server_process else None)
            self._atexit_registered = True

-    def _wait_for_server_ready(self, port: int) -> tuple[bool, int]:
+    def _wait_for_server_ready(self, port: int) -> Tuple[bool, int]:
        """Wait for the server to be ready."""
        max_wait, wait_interval = 120, 0.5
        for _ in range(int(max_wait / wait_interval)):
@@ -403,7 +403,7 @@ class EmbeddingServerManager:
            atexit.register(lambda: self.stop_server() if self.server_process else None)
            self._atexit_registered = True

-    def _wait_for_server_ready_colab(self, port: int) -> tuple[bool, int]:
+    def _wait_for_server_ready_colab(self, port: int) -> Tuple[bool, int]:
        """Wait for the server to be ready with Colab-specific timeout."""
        max_wait, wait_interval = 30, 0.5  # Shorter timeout for Colab

@@ -8,7 +8,7 @@ class LeannBackendBuilderInterface(ABC):
    """Backend interface for building indexes"""

    @abstractmethod
-    def build(self, data: np.ndarray, ids: list[str], index_path: str, **kwargs) -> None:
+    def build(self, data: np.ndarray, ids: List[str], index_path: str, **kwargs) -> None:
        """Build index

        Args:
@@ -52,7 +52,7 @@ class LeannBackendSearcherInterface(ABC):
        pruning_strategy: Literal["global", "local", "proportional"] = "global",
        zmq_port: Optional[int] = None,
        **kwargs,
-    ) -> dict[str, Any]:
+    ) -> Dict[str, Any]:
        """Search for nearest neighbors

        Args:
@@ -7,7 +7,7 @@ from typing import TYPE_CHECKING
 if TYPE_CHECKING:
     from leann.interface import LeannBackendFactoryInterface

-BACKEND_REGISTRY: dict[str, "LeannBackendFactoryInterface"] = {}
+BACKEND_REGISTRY: Dict[str, "LeannBackendFactoryInterface"] = {}


 def register_backend(name: str):
@@ -46,7 +46,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
            backend_module_name=backend_module_name,
        )

-    def _load_meta(self) -> dict[str, Any]:
+    def _load_meta(self) -> Dict[str, Any]:
        """Loads the metadata file associated with the index."""
        # This is the corrected logic for finding the meta file.
        meta_path = self.index_dir / f"{self.index_path.name}.meta.json"
@@ -174,7 +174,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
        pruning_strategy: Literal["global", "local", "proportional"] = "global",
        zmq_port: Optional[int] = None,
        **kwargs,
-    ) -> dict[str, Any]:
+    ) -> Dict[str, Any]:
        """
        Search for the top_k nearest neighbors of the query vector.

uv.lock (generated, 10 diff lines)
@@ -2155,7 +2155,7 @@ wheels = [

 [[package]]
 name = "leann-backend-diskann"
-version = "0.2.1"
+version = "0.2.5"
 source = { editable = "packages/leann-backend-diskann" }
 dependencies = [
     { name = "leann-core" },
@@ -2167,14 +2167,14 @@ dependencies = [

 [package.metadata]
 requires-dist = [
-    { name = "leann-core", specifier = "==0.2.1" },
+    { name = "leann-core", specifier = "==0.2.5" },
     { name = "numpy" },
     { name = "protobuf", specifier = ">=3.19.0" },
 ]

 [[package]]
 name = "leann-backend-hnsw"
-version = "0.2.1"
+version = "0.2.5"
 source = { editable = "packages/leann-backend-hnsw" }
 dependencies = [
     { name = "leann-core" },
@@ -2187,7 +2187,7 @@ dependencies = [

 [package.metadata]
 requires-dist = [
-    { name = "leann-core", specifier = "==0.2.1" },
+    { name = "leann-core", specifier = "==0.2.5" },
     { name = "msgpack", specifier = ">=1.0.0" },
     { name = "numpy" },
     { name = "pyzmq", specifier = ">=23.0.0" },
@@ -2195,7 +2195,7 @@ requires-dist = [

 [[package]]
 name = "leann-core"
-version = "0.2.1"
+version = "0.2.5"
 source = { editable = "packages/leann-core" }
 dependencies = [
     { name = "accelerate" },