fix: resolve all ruff linting errors and add lint CI check
- Fix ambiguous fullwidth characters (commas, parentheses) in strings and comments
- Replace Chinese comments with English equivalents
- Fix unused imports with proper noqa annotations for intentional imports
- Fix bare except clauses with specific exception types
- Fix redefined variables and undefined names
- Add ruff noqa annotations for generated protobuf files
- Add lint and format check to GitHub Actions CI pipeline
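Two of the fixes above in miniature (a hedged sketch; the file name, function, and exception types below are illustrative stand-ins, not the actual call sites changed):

    import json

    # Keeping an import only for re-export or side effects: silence F401 explicitly
    # instead of deleting it (hypothetical module name).
    from . import backends  # noqa: F401

    def load_config() -> dict:  # hypothetical stand-in for a real call site
        with open("config.json", encoding="utf-8") as f:
            return json.load(f)

    try:
        config = load_config()
    except (OSError, ValueError):  # was a bare "except:" before this commit
        config = {}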
@@ -4,27 +4,30 @@ with the correct, original embedding logic from the user's reference code.
 """

 import json
-import pickle
-from leann.interface import LeannBackendSearcherInterface
-import numpy as np
-import time
-from pathlib import Path
-from typing import List, Dict, Any, Optional, Literal
-from dataclasses import dataclass, field
-from .registry import BACKEND_REGISTRY
-from .interface import LeannBackendFactoryInterface
-from .chat import get_llm
+import logging
+import pickle
+import time
+from dataclasses import dataclass, field
+from pathlib import Path
+from typing import Any, Literal
+
+import numpy as np
+
+from leann.interface import LeannBackendSearcherInterface
+
+from .chat import get_llm
+from .interface import LeannBackendFactoryInterface
+from .registry import BACKEND_REGISTRY

 logger = logging.getLogger(__name__)


 def compute_embeddings(
-    chunks: List[str],
+    chunks: list[str],
     model_name: str,
     mode: str = "sentence-transformers",
     use_server: bool = True,
-    port: Optional[int] = None,
+    port: int | None = None,
     is_build=False,
 ) -> np.ndarray:
     """
@@ -61,9 +64,7 @@ def compute_embeddings(
     )


-def compute_embeddings_via_server(
-    chunks: List[str], model_name: str, port: int
-) -> np.ndarray:
+def compute_embeddings_via_server(chunks: list[str], model_name: str, port: int) -> np.ndarray:
     """Computes embeddings using sentence-transformers.

     Args:
@@ -73,9 +74,9 @@ def compute_embeddings_via_server(
     logger.info(
         f"Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}' (via embedding server)..."
     )
-    import zmq
     import msgpack
     import numpy as np
+    import zmq

     # Connect to embedding server
     context = zmq.Context()
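For context, a minimal sketch (not part of the diff) of the kind of round-trip compute_embeddings_via_server performs. The REQ/REP socket pattern and the msgpack wire format are assumptions inferred from the imports above, not the confirmed protocol:

    import msgpack
    import numpy as np
    import zmq

    def fetch_embeddings(chunks: list[str], port: int) -> np.ndarray:
        ctx = zmq.Context()
        sock = ctx.socket(zmq.REQ)        # request/reply socket (assumed pattern)
        sock.connect(f"tcp://localhost:{port}")
        sock.send(msgpack.packb(chunks))  # assumed wire format: msgpack-encoded list of strings
        reply = msgpack.unpackb(sock.recv())
        return np.asarray(reply, dtype=np.float32)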
@@ -104,11 +105,11 @@ class SearchResult:
     id: str
     score: float
     text: str
-    metadata: Dict[str, Any] = field(default_factory=dict)
+    metadata: dict[str, Any] = field(default_factory=dict)


 class PassageManager:
-    def __init__(self, passage_sources: List[Dict[str, Any]]):
+    def __init__(self, passage_sources: list[dict[str, Any]]):
         self.offset_maps = {}
         self.passage_files = {}
         self.global_offset_map = {}  # Combined map for fast lookup
@@ -117,15 +118,15 @@ class PassageManager:
             assert source["type"] == "jsonl", "only jsonl is supported"
             passage_file = source["path"]
             index_file = source["index_path"]  # .idx file

             # Fix path resolution for Colab and other environments
             if not Path(index_file).is_absolute():
                 # If relative path, try to resolve it properly
                 index_file = str(Path(index_file).resolve())

             if not Path(index_file).exists():
                 raise FileNotFoundError(f"Passage index file not found: {index_file}")

             with open(index_file, "rb") as f:
                 offset_map = pickle.load(f)
             self.offset_maps[passage_file] = offset_map
@@ -135,11 +136,11 @@ class PassageManager:
             for passage_id, offset in offset_map.items():
                 self.global_offset_map[passage_id] = (passage_file, offset)

-    def get_passage(self, passage_id: str) -> Dict[str, Any]:
+    def get_passage(self, passage_id: str) -> dict[str, Any]:
         if passage_id in self.global_offset_map:
             passage_file, offset = self.global_offset_map[passage_id]
             # Lazy file opening - only open when needed
-            with open(passage_file, "r", encoding="utf-8") as f:
+            with open(passage_file, encoding="utf-8") as f:
                 f.seek(offset)
                 return json.loads(f.readline())
         raise KeyError(f"Passage ID not found: {passage_id}")
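The offset-map scheme above is easy to reproduce end to end. A self-contained sketch (file names are illustrative) of building the pickled .idx sidecar next to a JSONL file and doing one seek-based lookup, exactly as get_passage does:

    import json
    import pickle

    passages = [{"id": "p0", "text": "first"}, {"id": "p1", "text": "second"}]

    # Write the JSONL file, recording the byte offset of each record.
    offsets = {}
    with open("passages.jsonl", "w", encoding="utf-8") as f:
        for p in passages:
            offsets[p["id"]] = f.tell()
            f.write(json.dumps(p) + "\n")

    # Persist the id -> offset map as the .idx sidecar.
    with open("passages.idx", "wb") as f:
        pickle.dump(offsets, f)

    # Random access: seek straight to one record instead of scanning the file.
    with open("passages.idx", "rb") as f:
        offset_map = pickle.load(f)
    with open("passages.jsonl", encoding="utf-8") as f:
        f.seek(offset_map["p1"])
        print(json.loads(f.readline())["text"])  # -> second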
@@ -150,14 +151,12 @@ class LeannBuilder:
         self,
         backend_name: str,
         embedding_model: str = "facebook/contriever",
-        dimensions: Optional[int] = None,
+        dimensions: int | None = None,
         embedding_mode: str = "sentence-transformers",
         **backend_kwargs,
     ):
         self.backend_name = backend_name
-        backend_factory: LeannBackendFactoryInterface | None = BACKEND_REGISTRY.get(
-            backend_name
-        )
+        backend_factory: LeannBackendFactoryInterface | None = BACKEND_REGISTRY.get(backend_name)
         if backend_factory is None:
             raise ValueError(f"Backend '{backend_name}' not found or not registered.")
         self.backend_factory = backend_factory
@@ -165,9 +164,9 @@ class LeannBuilder:
         self.dimensions = dimensions
         self.embedding_mode = embedding_mode
         self.backend_kwargs = backend_kwargs
-        self.chunks: List[Dict[str, Any]] = []
+        self.chunks: list[dict[str, Any]] = []

-    def add_text(self, text: str, metadata: Optional[Dict[str, Any]] = None):
+    def add_text(self, text: str, metadata: dict[str, Any] | None = None):
         if metadata is None:
             metadata = {}
         passage_id = metadata.get("id", str(len(self.chunks)))
@@ -197,9 +196,7 @@ class LeannBuilder:
         try:
             from tqdm import tqdm

-            chunk_iterator = tqdm(
-                self.chunks, desc="Writing passages", unit="chunk"
-            )
+            chunk_iterator = tqdm(self.chunks, desc="Writing passages", unit="chunk")
         except ImportError:
             chunk_iterator = self.chunks

@@ -229,9 +226,7 @@ class LeannBuilder:
         string_ids = [chunk["id"] for chunk in self.chunks]
         current_backend_kwargs = {**self.backend_kwargs, "dimensions": self.dimensions}
         builder_instance = self.backend_factory.builder(**current_backend_kwargs)
-        builder_instance.build(
-            embeddings, string_ids, index_path, **current_backend_kwargs
-        )
+        builder_instance.build(embeddings, string_ids, index_path, **current_backend_kwargs)
         leann_meta_path = index_dir / f"{index_name}.meta.json"
         meta_data = {
             "version": "1.0",
@@ -280,9 +275,7 @@ class LeannBuilder:
         ids, embeddings = data

         if not isinstance(embeddings, np.ndarray):
-            raise ValueError(
-                f"Expected embeddings to be numpy array, got {type(embeddings)}"
-            )
+            raise ValueError(f"Expected embeddings to be numpy array, got {type(embeddings)}")

         if len(ids) != embeddings.shape[0]:
             raise ValueError(
@@ -294,9 +287,7 @@ class LeannBuilder:
         if self.dimensions is None:
             self.dimensions = embedding_dim
         elif self.dimensions != embedding_dim:
-            raise ValueError(
-                f"Dimension mismatch: expected {self.dimensions}, got {embedding_dim}"
-            )
+            raise ValueError(f"Dimension mismatch: expected {self.dimensions}, got {embedding_dim}")

         logger.info(
             f"Building index from precomputed embeddings: {len(ids)} items, {embedding_dim} dimensions"
@@ -381,9 +372,7 @@ class LeannBuilder:
         with open(leann_meta_path, "w", encoding="utf-8") as f:
             json.dump(meta_data, f, indent=2)

-        logger.info(
-            f"Index built successfully from precomputed embeddings: {index_path}"
-        )
+        logger.info(f"Index built successfully from precomputed embeddings: {index_path}")


 class LeannSearcher:
@@ -391,20 +380,16 @@ class LeannSearcher:
         # Fix path resolution for Colab and other environments
         if not Path(index_path).is_absolute():
             index_path = str(Path(index_path).resolve())

         self.meta_path_str = f"{index_path}.meta.json"
         if not Path(self.meta_path_str).exists():
-            raise FileNotFoundError(
-                f"Leann metadata file not found at {self.meta_path_str}"
-            )
-        with open(self.meta_path_str, "r", encoding="utf-8") as f:
+            raise FileNotFoundError(f"Leann metadata file not found at {self.meta_path_str}")
+        with open(self.meta_path_str, encoding="utf-8") as f:
             self.meta_data = json.load(f)
         backend_name = self.meta_data["backend_name"]
         self.embedding_model = self.meta_data["embedding_model"]
         # Support both old and new format
-        self.embedding_mode = self.meta_data.get(
-            "embedding_mode", "sentence-transformers"
-        )
+        self.embedding_mode = self.meta_data.get("embedding_mode", "sentence-transformers")
         self.passage_manager = PassageManager(self.meta_data.get("passage_sources", []))
         backend_factory = BACKEND_REGISTRY.get(backend_name)
         if backend_factory is None:
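For reference, the metadata the searcher expects would look roughly like this. This is an illustrative sketch assembled from the keys read above and in PassageManager, not a dump of a real index (the backend and file names are made up):

    meta_data = {
        "version": "1.0",
        "backend_name": "hnsw",  # illustrative backend name
        "embedding_model": "facebook/contriever",
        "embedding_mode": "sentence-transformers",  # absent in older indexes, hence .get()
        "passage_sources": [
            {
                "type": "jsonl",  # the only supported type, per PassageManager
                "path": "my_index.passages.jsonl",      # illustrative file name
                "index_path": "my_index.passages.idx",  # pickled id -> byte-offset map
            }
        ],
    }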
@@ -426,7 +411,7 @@ class LeannSearcher:
         pruning_strategy: Literal["global", "local", "proportional"] = "global",
         expected_zmq_port: int = 5557,
         **kwargs,
-    ) -> List[SearchResult]:
+    ) -> list[SearchResult]:
         logger.info("🔍 LeannSearcher.search() called:")
         logger.info(f"   Query: '{query}'")
         logger.info(f"   Top_k: {top_k}")
@@ -453,7 +438,7 @@ class LeannSearcher:
             zmq_port=zmq_port,
         )
         # logger.info(f"   Generated embedding shape: {query_embedding.shape}")
-        embedding_time = time.time() - start_time
+        time.time() - start_time
         # logger.info(f"   Embedding time: {embedding_time} seconds")

         start_time = time.time()
@@ -468,17 +453,15 @@ class LeannSearcher:
             zmq_port=zmq_port,
             **kwargs,
         )
-        search_time = time.time() - start_time
+        time.time() - start_time
         # logger.info(f"   Search time: {search_time} seconds")
-        logger.info(
-            f"   Backend returned: labels={len(results.get('labels', [[]])[0])} results"
-        )
+        logger.info(f"   Backend returned: labels={len(results.get('labels', [[]])[0])} results")

         enriched_results = []
         if "labels" in results and "distances" in results:
             logger.info(f"   Processing {len(results['labels'][0])} passage IDs:")
             for i, (string_id, dist) in enumerate(
-                zip(results["labels"][0], results["distances"][0])
+                zip(results["labels"][0], results["distances"][0], strict=False)
             ):
                 try:
                     passage_data = self.passage_manager.get_passage(string_id)
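The strict=False added to zip above is ruff's B905 fix: since Python 3.10, zip accepts a strict flag, and spelling it out documents that silently truncating unequal-length inputs is intentional rather than an oversight. A two-line illustration:

    labels, distances = ["a", "b", "c"], [0.1, 0.2]
    list(zip(labels, distances, strict=False))   # [('a', 0.1), ('b', 0.2)] - extra label dropped
    # list(zip(labels, distances, strict=True))  # would raise ValueError instead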
@@ -490,15 +473,15 @@ class LeannSearcher:
                             metadata=passage_data.get("metadata", {}),
                         )
                     )

                     # Color codes for better logging
                     GREEN = "\033[92m"
                     BLUE = "\033[94m"
                     YELLOW = "\033[93m"
                     RESET = "\033[0m"

                     # Truncate text for display (first 100 chars)
-                    display_text = passage_data['text']
+                    display_text = passage_data["text"]
                     logger.info(
                         f"   {GREEN}✓{RESET} {BLUE}[{i + 1:2d}]{RESET} {YELLOW}ID:{RESET} '{string_id}' {YELLOW}Score:{RESET} {dist:.4f} {YELLOW}Text:{RESET} {display_text}"
                     )
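Taken together, the search path reads as: embed the query, ask the backend for labels and distances, then hydrate each label into a SearchResult through PassageManager. A hypothetical usage sketch, assuming an index has already been built at the given (illustrative) path:

    searcher = LeannSearcher("indexes/my_docs.leann")
    for r in searcher.search("how are passages stored?", top_k=3):
        print(f"{r.score:.4f}  {r.id}  {r.text[:80]}")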
@@ -516,7 +499,7 @@ class LeannChat:
     def __init__(
         self,
         index_path: str,
-        llm_config: Optional[Dict[str, Any]] = None,
+        llm_config: dict[str, Any] | None = None,
         enable_warmup: bool = False,
         **kwargs,
     ):
@@ -532,7 +515,7 @@ class LeannChat:
         prune_ratio: float = 0.0,
         recompute_embeddings: bool = True,
         pruning_strategy: Literal["global", "local", "proportional"] = "global",
-        llm_kwargs: Optional[Dict[str, Any]] = None,
+        llm_kwargs: dict[str, Any] | None = None,
         expected_zmq_port: int = 5557,
         **search_kwargs,
     ):