style: organize imports per ruff; finish py39 Optional changes

- Fix import ordering in embedding servers and graph_partition_simple
- Remove duplicate Optional import
- Complete Optional[...] replacements
Author: Andy Lee
Date:   2025-08-07 15:06:25 -07:00
Parent: 65bbff1d93
Commit: 575b354976
5 changed files with 14 additions and 10 deletions
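
For reference, the two changes this commit applies look like the sketch below: imports grouped per ruff's isort-style rules (standard-library imports first, then third-party after a blank line) and py39-compatible Optional[...] annotations in place of the PEP 604 `X | None` form. The function name and parameters here are illustrative, not taken from the repository.

from pathlib import Path   # standard-library group
from typing import Optional

import numpy as np         # third-party group, separated by one blank line


def example_embedding_server(passages_file: Optional[str] = None, dim: int = 768) -> np.ndarray:
    # Optional[str] instead of `str | None` keeps the annotation valid on Python 3.9.
    if passages_file is not None and not Path(passages_file).exists():
        raise FileNotFoundError(passages_file)
    return np.zeros(dim, dtype=np.float32)  # placeholder embedding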

View File

@@ -10,6 +10,7 @@ import sys
 import threading
 import time
 from pathlib import Path
+from typing import Optional
 
 import numpy as np
 import zmq
@@ -32,7 +33,7 @@ if not logger.handlers:
 
 
 def create_diskann_embedding_server(
-    passages_file: str | None = None,
+    passages_file: Optional[str] = None,
     zmq_port: int = 5555,
     model_name: str = "sentence-transformers/all-mpnet-base-v2",
     embedding_mode: str = "sentence-transformers",

View File

@@ -12,6 +12,7 @@ import shutil
 import subprocess
 import tempfile
 from pathlib import Path
+from typing import Optional
 
 
 class GraphPartitioner:
@@ -88,8 +89,8 @@ class GraphPartitioner:
     def partition_graph(
         self,
         index_prefix_path: str,
-        output_dir: str | None = None,
-        partition_prefix: str | None = None,
+        output_dir: Optional[str] = None,
+        partition_prefix: Optional[str] = None,
         **kwargs,
     ) -> tuple[str, str]:
         """

View File

@@ -10,10 +10,11 @@ import os
 import subprocess
 import tempfile
 from pathlib import Path
+from typing import Optional
 
 
 def partition_graph_simple(
-    index_prefix_path: str, output_dir: str | None = None, **kwargs
+    index_prefix_path: str, output_dir: Optional[str] = None, **kwargs
 ) -> tuple[str, str]:
     """
     Simple function to partition a graph index.
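
The signature changes above are what "py39 Optional changes" refers to: annotations are evaluated when the def statement runs, and the `X | None` union syntax (PEP 604) is only supported from Python 3.10, so without deferred annotations the old signatures fail at import time on 3.9. A minimal repro with a hypothetical function, not the repo's code:

from typing import Optional


# On Python 3.9 this definition raises at import time:
#   TypeError: unsupported operand type(s) for |: 'type' and 'NoneType'
# def partition(output_dir: str | None = None) -> str: ...

def partition(output_dir: Optional[str] = None) -> str:
    # Optional[str] evaluates fine on 3.9 and is equivalent to str | None.
    return output_dir if output_dir is not None else "./partitions"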

View File

@@ -10,6 +10,7 @@ import sys
 import threading
 import time
 from pathlib import Path
+from typing import Optional
 
 import msgpack
 import numpy as np
@@ -33,7 +34,7 @@ if not logger.handlers:
 
 
 def create_hnsw_embedding_server(
-    passages_file: str | None = None,
+    passages_file: Optional[str] = None,
     zmq_port: int = 5555,
     model_name: str = "sentence-transformers/all-mpnet-base-v2",
     distance_metric: str = "mips",

View File

@@ -116,7 +116,7 @@ class SearchResult:
 
 class PassageManager:
     def __init__(
-        self, passage_sources: list[dict[str, Any]], metadata_file_path: str | None = None
+        self, passage_sources: list[dict[str, Any]], metadata_file_path: Optional[str] = None
     ):
         self.offset_maps = {}
         self.passage_files = {}
@@ -180,7 +180,7 @@ class LeannBuilder:
         **backend_kwargs,
     ):
         self.backend_name = backend_name
-        backend_factory: LeannBackendFactoryInterface | None = BACKEND_REGISTRY.get(backend_name)
+        backend_factory: Optional[LeannBackendFactoryInterface] = BACKEND_REGISTRY.get(backend_name)
         if backend_factory is None:
             raise ValueError(f"Backend '{backend_name}' not found or not registered.")
         self.backend_factory = backend_factory
@@ -260,7 +260,7 @@ class LeannBuilder:
         self.backend_kwargs = backend_kwargs
         self.chunks: list[dict[str, Any]] = []
 
-    def add_text(self, text: str, metadata: dict[str, Any] | None = None):
+    def add_text(self, text: str, metadata: Optional[dict[str, Any]] = None):
         if metadata is None:
             metadata = {}
         passage_id = metadata.get("id", str(len(self.chunks)))
@@ -611,7 +611,7 @@ class LeannChat:
     def __init__(
         self,
         index_path: str,
-        llm_config: dict[str, Any] | None = None,
+        llm_config: Optional[dict[str, Any]] = None,
         enable_warmup: bool = False,
         **kwargs,
     ):
@@ -627,7 +627,7 @@ class LeannChat:
         prune_ratio: float = 0.0,
         recompute_embeddings: bool = True,
         pruning_strategy: Literal["global", "local", "proportional"] = "global",
-        llm_kwargs: dict[str, Any] | None = None,
+        llm_kwargs: Optional[dict[str, Any]] = None,
         expected_zmq_port: int = 5557,
         **search_kwargs,
     ):