Format code with ruff

This commit is contained in:
Andy Lee
2025-12-24 00:23:25 +00:00
parent b754474c44
commit a0e53ef8f1
4 changed files with 68 additions and 8 deletions

View File

@@ -314,7 +314,9 @@ class WeChatHistoryReader(BaseReader):
return concatenated_groups
-def _create_concatenated_content(self, message_group: dict, contact_name: str) -> tuple[str, str]:
+def _create_concatenated_content(
+    self, message_group: dict, contact_name: str
+) -> tuple[str, str]:
"""
Create concatenated content from a group of messages.

View File

@@ -113,7 +113,8 @@ def load_vidore_v2_data(
# Try to get a sample to see actual language values
try:
sample_ds = cast(
-    Dataset, load_dataset(dataset_path, "queries", split=split, revision=revision)
+    Dataset,
+    load_dataset(dataset_path, "queries", split=split, revision=revision),
)
if len(sample_ds) > 0 and "language" in sample_ds.column_names:
sample_langs = set(sample_ds["language"])

View File

@@ -581,7 +581,18 @@ class TestQueryTemplateApplicationInComputeEmbedding:
# Create a concrete implementation for testing
class TestSearcher(BaseSearcher):
-def search(self, query, top_k, complexity=64, beam_width=1, prune_ratio=0.0, recompute_embeddings=False, pruning_strategy="global", zmq_port=None, **kwargs):
+def search(
+    self,
+    query,
+    top_k,
+    complexity=64,
+    beam_width=1,
+    prune_ratio=0.0,
+    recompute_embeddings=False,
+    pruning_strategy="global",
+    zmq_port=None,
+    **kwargs,
+):
return {"labels": [], "distances": []}
searcher = object.__new__(TestSearcher)
@@ -625,7 +636,18 @@ class TestQueryTemplateApplicationInComputeEmbedding:
# Create a concrete implementation for testing
class TestSearcher(BaseSearcher):
-def search(self, query, top_k, complexity=64, beam_width=1, prune_ratio=0.0, recompute_embeddings=False, pruning_strategy="global", zmq_port=None, **kwargs):
+def search(
+    self,
+    query,
+    top_k,
+    complexity=64,
+    beam_width=1,
+    prune_ratio=0.0,
+    recompute_embeddings=False,
+    pruning_strategy="global",
+    zmq_port=None,
+    **kwargs,
+):
return {"labels": [], "distances": []}
searcher = object.__new__(TestSearcher)
@@ -671,7 +693,18 @@ class TestQueryTemplateApplicationInComputeEmbedding:
from leann.searcher_base import BaseSearcher
class TestSearcher(BaseSearcher):
-def search(self, query, top_k, complexity=64, beam_width=1, prune_ratio=0.0, recompute_embeddings=False, pruning_strategy="global", zmq_port=None, **kwargs):
+def search(
+    self,
+    query,
+    top_k,
+    complexity=64,
+    beam_width=1,
+    prune_ratio=0.0,
+    recompute_embeddings=False,
+    pruning_strategy="global",
+    zmq_port=None,
+    **kwargs,
+):
return {"labels": [], "distances": []}
searcher = object.__new__(TestSearcher)
@@ -710,7 +743,18 @@ class TestQueryTemplateApplicationInComputeEmbedding:
from leann.searcher_base import BaseSearcher
class TestSearcher(BaseSearcher):
-def search(self, query, top_k, complexity=64, beam_width=1, prune_ratio=0.0, recompute_embeddings=False, pruning_strategy="global", zmq_port=None, **kwargs):
+def search(
+    self,
+    query,
+    top_k,
+    complexity=64,
+    beam_width=1,
+    prune_ratio=0.0,
+    recompute_embeddings=False,
+    pruning_strategy="global",
+    zmq_port=None,
+    **kwargs,
+):
return {"labels": [], "distances": []}
searcher = object.__new__(TestSearcher)
@@ -774,7 +818,18 @@ class TestQueryTemplateApplicationInComputeEmbedding:
from leann.searcher_base import BaseSearcher
class TestSearcher(BaseSearcher):
-def search(self, query, top_k, complexity=64, beam_width=1, prune_ratio=0.0, recompute_embeddings=False, pruning_strategy="global", zmq_port=None, **kwargs):
+def search(
+    self,
+    query,
+    top_k,
+    complexity=64,
+    beam_width=1,
+    prune_ratio=0.0,
+    recompute_embeddings=False,
+    pruning_strategy="global",
+    zmq_port=None,
+    **kwargs,
+):
return {"labels": [], "distances": []}
searcher = object.__new__(TestSearcher)

View File

@@ -98,7 +98,9 @@ def test_backend_options():
with tempfile.TemporaryDirectory() as temp_dir:
# Use smaller model in CI to avoid memory issues
is_ci = os.environ.get("CI") == "true"
-embedding_model = "sentence-transformers/all-MiniLM-L6-v2" if is_ci else "facebook/contriever"
+embedding_model = (
+    "sentence-transformers/all-MiniLM-L6-v2" if is_ci else "facebook/contriever"
+)
dimensions = 384 if is_ci else None
# Test HNSW backend (as shown in README)