fix: resolve all ruff linting errors and add lint CI check
- Fix ambiguous fullwidth characters (commas, parentheses) in strings and comments - Replace Chinese comments with English equivalents - Fix unused imports with proper noqa annotations for intentional imports - Fix bare except clauses with specific exception types - Fix redefined variables and undefined names - Add ruff noqa annotations for generated protobuf files - Add lint and format check to GitHub Actions CI pipeline
This commit is contained in:
@@ -1 +1 @@
|
||||
# This file makes the directory a Python package
|
||||
# This file makes the directory a Python package
|
||||
|
||||
@@ -1 +1 @@
|
||||
from . import diskann_backend
|
||||
from . import diskann_backend as diskann_backend
|
||||
|
||||
@@ -1,20 +1,19 @@
|
||||
import numpy as np
|
||||
import contextlib
|
||||
import logging
|
||||
import os
|
||||
import struct
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, List, Literal, Optional
|
||||
import contextlib
|
||||
from typing import Any, Literal
|
||||
|
||||
import logging
|
||||
|
||||
from leann.searcher_base import BaseSearcher
|
||||
from leann.registry import register_backend
|
||||
import numpy as np
|
||||
from leann.interface import (
|
||||
LeannBackendFactoryInterface,
|
||||
LeannBackendBuilderInterface,
|
||||
LeannBackendFactoryInterface,
|
||||
LeannBackendSearcherInterface,
|
||||
)
|
||||
from leann.registry import register_backend
|
||||
from leann.searcher_base import BaseSearcher
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -100,7 +99,7 @@ class DiskannBuilder(LeannBackendBuilderInterface):
|
||||
def __init__(self, **kwargs):
|
||||
self.build_params = kwargs
|
||||
|
||||
def build(self, data: np.ndarray, ids: List[str], index_path: str, **kwargs):
|
||||
def build(self, data: np.ndarray, ids: list[str], index_path: str, **kwargs):
|
||||
path = Path(index_path)
|
||||
index_dir = path.parent
|
||||
index_prefix = path.stem
|
||||
@@ -186,11 +185,11 @@ class DiskannSearcher(BaseSearcher):
|
||||
prune_ratio: float = 0.0,
|
||||
recompute_embeddings: bool = False,
|
||||
pruning_strategy: Literal["global", "local", "proportional"] = "global",
|
||||
zmq_port: Optional[int] = None,
|
||||
zmq_port: int | None = None,
|
||||
batch_recompute: bool = False,
|
||||
dedup_node_dis: bool = False,
|
||||
**kwargs,
|
||||
) -> Dict[str, Any]:
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Search for nearest neighbors using DiskANN index.
|
||||
|
||||
@@ -216,14 +215,10 @@ class DiskannSearcher(BaseSearcher):
|
||||
# Handle zmq_port compatibility: DiskANN can now update port at runtime
|
||||
if recompute_embeddings:
|
||||
if zmq_port is None:
|
||||
raise ValueError(
|
||||
"zmq_port must be provided if recompute_embeddings is True"
|
||||
)
|
||||
raise ValueError("zmq_port must be provided if recompute_embeddings is True")
|
||||
current_port = self._index.get_zmq_port()
|
||||
if zmq_port != current_port:
|
||||
logger.debug(
|
||||
f"Updating DiskANN zmq_port from {current_port} to {zmq_port}"
|
||||
)
|
||||
logger.debug(f"Updating DiskANN zmq_port from {current_port} to {zmq_port}")
|
||||
self._index.set_zmq_port(zmq_port)
|
||||
|
||||
# DiskANN doesn't support "proportional" strategy
|
||||
@@ -259,8 +254,6 @@ class DiskannSearcher(BaseSearcher):
|
||||
use_global_pruning,
|
||||
)
|
||||
|
||||
string_labels = [
|
||||
[str(int_label) for int_label in batch_labels] for batch_labels in labels
|
||||
]
|
||||
string_labels = [[str(int_label) for int_label in batch_labels] for batch_labels in labels]
|
||||
|
||||
return {"labels": string_labels, "distances": distances}
|
||||
|
||||
@@ -3,16 +3,16 @@ DiskANN-specific embedding server
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
import os
|
||||
import zmq
|
||||
import numpy as np
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
import sys
|
||||
import logging
|
||||
|
||||
import numpy as np
|
||||
import zmq
|
||||
|
||||
# Set up logging based on environment variable
|
||||
LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
|
||||
@@ -32,7 +32,7 @@ if not logger.handlers:
|
||||
|
||||
|
||||
def create_diskann_embedding_server(
|
||||
passages_file: Optional[str] = None,
|
||||
passages_file: str | None = None,
|
||||
zmq_port: int = 5555,
|
||||
model_name: str = "sentence-transformers/all-mpnet-base-v2",
|
||||
embedding_mode: str = "sentence-transformers",
|
||||
@@ -50,8 +50,8 @@ def create_diskann_embedding_server(
|
||||
sys.path.insert(0, str(leann_core_path))
|
||||
|
||||
try:
|
||||
from leann.embedding_compute import compute_embeddings
|
||||
from leann.api import PassageManager
|
||||
from leann.embedding_compute import compute_embeddings
|
||||
|
||||
logger.info("Successfully imported unified embedding computation module")
|
||||
except ImportError as e:
|
||||
@@ -76,7 +76,7 @@ def create_diskann_embedding_server(
|
||||
raise ValueError("Only metadata files (.meta.json) are supported")
|
||||
|
||||
# Load metadata to get passage sources
|
||||
with open(passages_file, "r") as f:
|
||||
with open(passages_file) as f:
|
||||
meta = json.load(f)
|
||||
|
||||
passages = PassageManager(meta["passage_sources"])
|
||||
@@ -150,9 +150,7 @@ def create_diskann_embedding_server(
|
||||
):
|
||||
texts = request
|
||||
is_text_request = True
|
||||
logger.info(
|
||||
f"✅ MSGPACK: Direct text request for {len(texts)} texts"
|
||||
)
|
||||
logger.info(f"✅ MSGPACK: Direct text request for {len(texts)} texts")
|
||||
else:
|
||||
raise ValueError("Not a valid msgpack text request")
|
||||
except Exception as msgpack_error:
|
||||
@@ -167,9 +165,7 @@ def create_diskann_embedding_server(
|
||||
passage_data = passages.get_passage(str(nid))
|
||||
txt = passage_data["text"]
|
||||
if not txt:
|
||||
raise RuntimeError(
|
||||
f"FATAL: Empty text for passage ID {nid}"
|
||||
)
|
||||
raise RuntimeError(f"FATAL: Empty text for passage ID {nid}")
|
||||
texts.append(txt)
|
||||
except KeyError as e:
|
||||
logger.error(f"Passage ID {nid} not found: {e}")
|
||||
@@ -180,9 +176,7 @@ def create_diskann_embedding_server(
|
||||
|
||||
# Debug logging
|
||||
logger.debug(f"Processing {len(texts)} texts")
|
||||
logger.debug(
|
||||
f"Text lengths: {[len(t) for t in texts[:5]]}"
|
||||
) # Show first 5
|
||||
logger.debug(f"Text lengths: {[len(t) for t in texts[:5]]}") # Show first 5
|
||||
|
||||
# Process embeddings using unified computation
|
||||
embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
|
||||
@@ -199,9 +193,7 @@ def create_diskann_embedding_server(
|
||||
else:
|
||||
# For DiskANN C++ compatibility: return protobuf format
|
||||
resp_proto = embedding_pb2.NodeEmbeddingResponse()
|
||||
hidden_contiguous = np.ascontiguousarray(
|
||||
embeddings, dtype=np.float32
|
||||
)
|
||||
hidden_contiguous = np.ascontiguousarray(embeddings, dtype=np.float32)
|
||||
|
||||
# Serialize embeddings data
|
||||
resp_proto.embeddings_data = hidden_contiguous.tobytes()
|
||||
|
||||
@@ -1,27 +1,28 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
||||
# source: embedding.proto
|
||||
# ruff: noqa
|
||||
"""Generated protocol buffer code."""
|
||||
from google.protobuf.internal import builder as _builder
|
||||
|
||||
from google.protobuf import descriptor as _descriptor
|
||||
from google.protobuf import descriptor_pool as _descriptor_pool
|
||||
from google.protobuf import symbol_database as _symbol_database
|
||||
from google.protobuf.internal import builder as _builder
|
||||
|
||||
# @@protoc_insertion_point(imports)
|
||||
|
||||
_sym_db = _symbol_database.Default()
|
||||
|
||||
|
||||
|
||||
|
||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0f\x65mbedding.proto\x12\x0eprotoembedding\"(\n\x14NodeEmbeddingRequest\x12\x10\n\x08node_ids\x18\x01 \x03(\r\"Y\n\x15NodeEmbeddingResponse\x12\x17\n\x0f\x65mbeddings_data\x18\x01 \x01(\x0c\x12\x12\n\ndimensions\x18\x02 \x03(\x05\x12\x13\n\x0bmissing_ids\x18\x03 \x03(\rb\x06proto3')
|
||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
|
||||
b'\n\x0f\x65mbedding.proto\x12\x0eprotoembedding"(\n\x14NodeEmbeddingRequest\x12\x10\n\x08node_ids\x18\x01 \x03(\r"Y\n\x15NodeEmbeddingResponse\x12\x17\n\x0f\x65mbeddings_data\x18\x01 \x01(\x0c\x12\x12\n\ndimensions\x18\x02 \x03(\x05\x12\x13\n\x0bmissing_ids\x18\x03 \x03(\rb\x06proto3'
|
||||
)
|
||||
|
||||
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
|
||||
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'embedding_pb2', globals())
|
||||
if _descriptor._USE_C_DESCRIPTORS == False:
|
||||
|
||||
DESCRIPTOR._options = None
|
||||
_NODEEMBEDDINGREQUEST._serialized_start=35
|
||||
_NODEEMBEDDINGREQUEST._serialized_end=75
|
||||
_NODEEMBEDDINGRESPONSE._serialized_start=77
|
||||
_NODEEMBEDDINGRESPONSE._serialized_end=166
|
||||
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "embedding_pb2", globals())
|
||||
if not _descriptor._USE_C_DESCRIPTORS:
|
||||
DESCRIPTOR._options = None
|
||||
_NODEEMBEDDINGREQUEST._serialized_start = 35
|
||||
_NODEEMBEDDINGREQUEST._serialized_end = 75
|
||||
_NODEEMBEDDINGRESPONSE._serialized_start = 77
|
||||
_NODEEMBEDDINGRESPONSE._serialized_end = 166
|
||||
# @@protoc_insertion_point(module_scope)
|
||||
|
||||
Reference in New Issue
Block a user