From 198044d033011ffded29d9afa0b42f5fa92cf432 Mon Sep 17 00:00:00 2001
From: Andy Lee
Date: Wed, 24 Dec 2025 23:58:06 -0800
Subject: [PATCH] Add ty type checker to CI and fix type errors (fixes bug from PR #157) (#192)

* Add ty type checker to CI and fix type errors

- Add ty (Astral's fast Python type checker) to GitHub CI workflow
- Fix type annotations across all RAG apps:
  - Update load_data return types from list[str] to list[dict[str, Any]]
  - Fix base_rag_example.py to properly handle dict format from create_text_chunks
- Fix type errors in leann-core:
  - chunking_utils.py: Add explicit type annotations
  - cli.py: Fix return type annotations for PDF extraction functions
  - interactive_utils.py: Fix readline import type handling
- Fix type errors in apps:
  - wechat_history.py: Fix return type annotations
  - document_rag.py, code_rag.py: Replace **kwargs with explicit arguments
- Add ty configuration to pyproject.toml

This resolves the bug introduced in PR #157 where create_text_chunks()
changed to return list[dict] but callers were not updated.

* Fix remaining ty type errors

- Fix slack_mcp_reader.py channel parameter can be None
- Fix embedding_compute.py ContextProp type issue
- Fix searcher_base.py method override signatures
- Fix chunking_utils.py chunk_text assignment
- Fix slack_rag.py and twitter_rag.py return types
- Fix email.py and image_rag.py method overrides

* Fix multimodal benchmark scripts type errors

- Fix undefined LeannRetriever -> LeannMultiVector
- Add proper type casts for HuggingFace Dataset iteration
- Cast task config values to correct types
- Add type annotations for dataset row dicts

* Enable ty check for multimodal scripts in CI

All type errors in multimodal scripts have been fixed, so we can now
include them in the CI type checking.
* Fix all test type errors and enable ty check on tests

- Fix test_basic.py: search() takes str not list
- Fix test_cli_prompt_template.py: add type: ignore for Mock assignments
- Fix test_prompt_template_persistence.py: match BaseSearcher.search signature
- Fix test_prompt_template_e2e.py: add type narrowing asserts after skip
- Fix test_readme_examples.py: use explicit kwargs instead of **model_args
- Fix metadata_filter.py: allow Optional[MetadataFilters]
- Update CI to run ty check on tests

* Format code with ruff

* Format searcher_base.py
---
 .github/workflows/build-reusable.yml          | 23 ++++++-
 apps/base_rag_example.py                      | 10 +--
 apps/browser_rag.py                           |  3 +-
 apps/chatgpt_rag.py                           |  3 +-
 apps/claude_rag.py                            |  3 +-
 apps/code_rag.py                              | 18 ++---
 apps/document_rag.py                          | 20 +++---
 apps/email_data/email.py                      |  3 +-
 apps/email_rag.py                             |  3 +-
 apps/history_data/wechat_history.py           |  6 +-
 apps/image_rag.py                             |  5 +-
 apps/imessage_rag.py                          |  3 +-
 .../multi-vector-leann-paper-example.py       |  9 +--
 .../multi-vector-leann-similarity-map.py      |  8 +--
 .../vidore_v1_benchmark.py                    | 53 ++++++++-------
 .../vidore_v2_benchmark.py                    | 60 +++++++++--------
 apps/slack_data/slack_mcp_reader.py           |  9 ++-
 apps/slack_rag.py                             |  6 +-
 apps/twitter_rag.py                           |  6 +-
 apps/wechat_rag.py                            |  3 +-
 .../leann-core/src/leann/chunking_utils.py    |  6 +-
 packages/leann-core/src/leann/cli.py          |  4 +-
 .../leann-core/src/leann/embedding_compute.py |  3 +-
 .../leann-core/src/leann/interactive_utils.py |  7 +-
 .../leann-core/src/leann/metadata_filter.py   |  4 +-
 .../leann-core/src/leann/searcher_base.py     |  8 ++-
 pyproject.toml                                | 13 ++++
 tests/test_basic.py                           |  4 +-
 tests/test_cli_prompt_template.py             | 12 ++--
 tests/test_prompt_template_e2e.py             |  5 +-
 tests/test_prompt_template_persistence.py     | 65 +++++++++++++++++--
 tests/test_readme_examples.py                 | 20 +++---
 32 files changed, 261 insertions(+), 144 deletions(-)

diff --git a/.github/workflows/build-reusable.yml b/.github/workflows/build-reusable.yml
index f68662f..e094e78 100644
--- a/.github/workflows/build-reusable.yml
+++ b/.github/workflows/build-reusable.yml
@@ -28,9 +28,30 @@ jobs:
         run: |
           uv run --only-group lint pre-commit run --all-files --show-diff-on-failure
 
+  type-check:
+    name: Type Check with ty
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v4
+        with:
+          ref: ${{ inputs.ref }}
+          submodules: recursive
+
+      - name: Install uv and Python
+        uses: astral-sh/setup-uv@v6
+        with:
+          python-version: '3.11'
+
+      - name: Install ty
+        run: uv tool install ty
+
+      - name: Run ty type checker
+        run: |
+          # Run ty on core packages, apps, and tests
+          ty check packages/leann-core/src apps tests
+
   build:
-    needs: lint
+    needs: [lint, type-check]
     name: Build ${{ matrix.os }} Python ${{ matrix.python }}
     strategy:
       matrix:
diff --git a/apps/base_rag_example.py b/apps/base_rag_example.py
index f695610..1517191 100644
--- a/apps/base_rag_example.py
+++ b/apps/base_rag_example.py
@@ -6,7 +6,7 @@ Provides common parameters and functionality for all RAG examples.
 import argparse
 from abc import ABC, abstractmethod
 from pathlib import Path
-from typing import Any, Union
+from typing import Any
 
 import dotenv
 from leann.api import LeannBuilder, LeannChat
@@ -257,8 +257,8 @@ class BaseRAGExample(ABC):
         pass
 
     @abstractmethod
-    async def load_data(self, args) -> list[Union[str, dict[str, Any]]]:
-        """Load data from the source. Returns list of text chunks (strings or dicts with 'text' key)."""
+    async def load_data(self, args) -> list[dict[str, Any]]:
+        """Load data from the source.
Returns list of text chunks as dicts with 'text' and 'metadata' keys.""" pass def get_llm_config(self, args) -> dict[str, Any]: @@ -282,8 +282,8 @@ class BaseRAGExample(ABC): return config - async def build_index(self, args, texts: list[Union[str, dict[str, Any]]]) -> str: - """Build LEANN index from texts (accepts strings or dicts with 'text' key).""" + async def build_index(self, args, texts: list[dict[str, Any]]) -> str: + """Build LEANN index from text chunks (dicts with 'text' and 'metadata' keys).""" index_path = str(Path(args.index_dir) / f"{self.default_index_name}.leann") print(f"\n[Building Index] Creating {self.name} index...") diff --git a/apps/browser_rag.py b/apps/browser_rag.py index 6d21964..00bb3f5 100644 --- a/apps/browser_rag.py +++ b/apps/browser_rag.py @@ -6,6 +6,7 @@ Supports Chrome browser history. import os import sys from pathlib import Path +from typing import Any # Add parent directory to path for imports sys.path.insert(0, str(Path(__file__).parent)) @@ -85,7 +86,7 @@ class BrowserRAG(BaseRAGExample): return profiles - async def load_data(self, args) -> list[str]: + async def load_data(self, args) -> list[dict[str, Any]]: """Load browser history and convert to text chunks.""" # Determine Chrome profiles if args.chrome_profile and not args.auto_find_profiles: diff --git a/apps/chatgpt_rag.py b/apps/chatgpt_rag.py index 3c92d04..c97d2cd 100644 --- a/apps/chatgpt_rag.py +++ b/apps/chatgpt_rag.py @@ -5,6 +5,7 @@ Supports ChatGPT export data from chat.html files. import sys from pathlib import Path +from typing import Any # Add parent directory to path for imports sys.path.insert(0, str(Path(__file__).parent)) @@ -80,7 +81,7 @@ class ChatGPTRAG(BaseRAGExample): return export_files - async def load_data(self, args) -> list[str]: + async def load_data(self, args) -> list[dict[str, Any]]: """Load ChatGPT export data and convert to text chunks.""" export_path = Path(args.export_path) diff --git a/apps/claude_rag.py b/apps/claude_rag.py index 43b499e..2cc80dd 100644 --- a/apps/claude_rag.py +++ b/apps/claude_rag.py @@ -5,6 +5,7 @@ Supports Claude export data from JSON files. import sys from pathlib import Path +from typing import Any # Add parent directory to path for imports sys.path.insert(0, str(Path(__file__).parent)) @@ -80,7 +81,7 @@ class ClaudeRAG(BaseRAGExample): return export_files - async def load_data(self, args) -> list[str]: + async def load_data(self, args) -> list[dict[str, Any]]: """Load Claude export data and convert to text chunks.""" export_path = Path(args.export_path) diff --git a/apps/code_rag.py b/apps/code_rag.py index 7518bb9..452e0a6 100644 --- a/apps/code_rag.py +++ b/apps/code_rag.py @@ -6,6 +6,7 @@ optimized chunking parameters. 
import sys from pathlib import Path +from typing import Any # Add parent directory to path for imports sys.path.insert(0, str(Path(__file__).parent)) @@ -77,7 +78,7 @@ class CodeRAG(BaseRAGExample): help="Try to preserve import statements in chunks (default: True)", ) - async def load_data(self, args) -> list[str]: + async def load_data(self, args) -> list[dict[str, Any]]: """Load code files and convert to AST-aware chunks.""" print(f"🔍 Scanning code repository: {args.repo_dir}") print(f"📁 Including extensions: {args.include_extensions}") @@ -88,14 +89,6 @@ class CodeRAG(BaseRAGExample): if not repo_path.exists(): raise ValueError(f"Repository directory not found: {args.repo_dir}") - # Load code files with filtering - reader_kwargs = { - "recursive": True, - "encoding": "utf-8", - "required_exts": args.include_extensions, - "exclude_hidden": True, - } - # Create exclusion filter def file_filter(file_path: str) -> bool: """Filter out unwanted files and directories.""" @@ -120,8 +113,11 @@ class CodeRAG(BaseRAGExample): # Load documents with file filtering documents = SimpleDirectoryReader( args.repo_dir, - file_extractor=None, # Use default extractors - **reader_kwargs, + file_extractor=None, + recursive=True, + encoding="utf-8", + required_exts=args.include_extensions, + exclude_hidden=True, ).load_data(show_progress=True) # Apply custom filtering diff --git a/apps/document_rag.py b/apps/document_rag.py index 280d0fb..f8e0c66 100644 --- a/apps/document_rag.py +++ b/apps/document_rag.py @@ -5,7 +5,7 @@ Supports PDF, TXT, MD, and other document formats. import sys from pathlib import Path -from typing import Any, Union +from typing import Any # Add parent directory to path for imports sys.path.insert(0, str(Path(__file__).parent)) @@ -52,7 +52,7 @@ class DocumentRAG(BaseRAGExample): help="Enable AST-aware chunking for code files in the data directory", ) - async def load_data(self, args) -> list[Union[str, dict[str, Any]]]: + async def load_data(self, args) -> list[dict[str, Any]]: """Load documents and convert to text chunks.""" print(f"Loading documents from: {args.data_dir}") if args.file_types: @@ -66,16 +66,12 @@ class DocumentRAG(BaseRAGExample): raise ValueError(f"Data directory not found: {args.data_dir}") # Load documents - reader_kwargs = { - "recursive": True, - "encoding": "utf-8", - } - if args.file_types: - reader_kwargs["required_exts"] = args.file_types - - documents = SimpleDirectoryReader(args.data_dir, **reader_kwargs).load_data( - show_progress=True - ) + documents = SimpleDirectoryReader( + args.data_dir, + recursive=True, + encoding="utf-8", + required_exts=args.file_types if args.file_types else None, + ).load_data(show_progress=True) if not documents: print(f"No documents found in {args.data_dir} with extensions {args.file_types}") diff --git a/apps/email_data/email.py b/apps/email_data/email.py index cad4062..0bb003a 100644 --- a/apps/email_data/email.py +++ b/apps/email_data/email.py @@ -127,11 +127,12 @@ class EmlxMboxReader(MboxReader): def load_data( self, - directory: Path, + file: Path, # Note: for EmlxMboxReader, this is actually a directory extra_info: dict | None = None, fs: AbstractFileSystem | None = None, ) -> list[Document]: """Parse .emlx files from directory into strings using MboxReader logic.""" + directory = file # Rename for clarity - this is a directory of .emlx files import os import tempfile diff --git a/apps/email_rag.py b/apps/email_rag.py index ec87bb1..0558678 100644 --- a/apps/email_rag.py +++ b/apps/email_rag.py @@ -5,6 +5,7 @@ Supports 
Apple Mail on macOS. import sys from pathlib import Path +from typing import Any # Add parent directory to path for imports sys.path.insert(0, str(Path(__file__).parent)) @@ -64,7 +65,7 @@ class EmailRAG(BaseRAGExample): return messages_dirs - async def load_data(self, args) -> list[str]: + async def load_data(self, args) -> list[dict[str, Any]]: """Load emails and convert to text chunks.""" # Determine mail directories if args.mail_path: diff --git a/apps/history_data/wechat_history.py b/apps/history_data/wechat_history.py index e985bd4..9c99f77 100644 --- a/apps/history_data/wechat_history.py +++ b/apps/history_data/wechat_history.py @@ -86,7 +86,7 @@ class WeChatHistoryReader(BaseReader): text=True, timeout=5, ) - return result.returncode == 0 and result.stdout.strip() + return result.returncode == 0 and bool(result.stdout.strip()) except Exception: return False @@ -314,7 +314,9 @@ class WeChatHistoryReader(BaseReader): return concatenated_groups - def _create_concatenated_content(self, message_group: dict, contact_name: str) -> str: + def _create_concatenated_content( + self, message_group: dict, contact_name: str + ) -> tuple[str, str]: """ Create concatenated content from a group of messages. diff --git a/apps/image_rag.py b/apps/image_rag.py index 4c33b69..8dcd62b 100644 --- a/apps/image_rag.py +++ b/apps/image_rag.py @@ -14,6 +14,7 @@ import argparse import pickle import tempfile from pathlib import Path +from typing import Any import numpy as np from PIL import Image @@ -65,7 +66,7 @@ class ImageRAG(BaseRAGExample): help="Batch size for CLIP embedding generation (default: 32)", ) - async def load_data(self, args) -> list[str]: + async def load_data(self, args) -> list[dict[str, Any]]: """Load images, generate CLIP embeddings, and return text descriptions.""" self._image_data = self._load_images_and_embeddings(args) return [entry["text"] for entry in self._image_data] @@ -168,7 +169,7 @@ class ImageRAG(BaseRAGExample): print(f"✅ Processed {len(image_data)} images") return image_data - async def build_index(self, args, texts: list[str]) -> str: + async def build_index(self, args, texts: list[dict[str, Any]]) -> str: """Build index using pre-computed CLIP embeddings.""" from leann.api import LeannBuilder diff --git a/apps/imessage_rag.py b/apps/imessage_rag.py index 50032ec..bd4ab68 100644 --- a/apps/imessage_rag.py +++ b/apps/imessage_rag.py @@ -6,6 +6,7 @@ This example demonstrates how to build a RAG system on your iMessage conversatio import asyncio from pathlib import Path +from typing import Any from leann.chunking_utils import create_text_chunks @@ -56,7 +57,7 @@ class IMessageRAG(BaseRAGExample): help="Overlap between text chunks (default: 200)", ) - async def load_data(self, args) -> list[str]: + async def load_data(self, args) -> list[dict[str, Any]]: """Load iMessage history and convert to text chunks.""" print("Loading iMessage conversation history...") diff --git a/apps/multimodal/vision-based-pdf-multi-vector/multi-vector-leann-paper-example.py b/apps/multimodal/vision-based-pdf-multi-vector/multi-vector-leann-paper-example.py index 22102d3..16107ca 100644 --- a/apps/multimodal/vision-based-pdf-multi-vector/multi-vector-leann-paper-example.py +++ b/apps/multimodal/vision-based-pdf-multi-vector/multi-vector-leann-paper-example.py @@ -18,10 +18,11 @@ _repo_root = Path(__file__).resolve().parents[3] _leann_core_src = _repo_root / "packages" / "leann-core" / "src" _leann_hnsw_pkg = _repo_root / "packages" / "leann-backend-hnsw" if str(_leann_core_src) not in sys.path: - 
sys.path.append(str(_leann_core_src)) + sys.path.insert(0, str(_leann_core_src)) if str(_leann_hnsw_pkg) not in sys.path: - sys.path.append(str(_leann_hnsw_pkg)) + sys.path.insert(0, str(_leann_hnsw_pkg)) +from leann_multi_vector import LeannMultiVector import torch from colpali_engine.models import ColPali @@ -93,9 +94,9 @@ for batch_doc in tqdm(dataloader): print(ds[0].shape) # %% -# Build HNSW index via LeannRetriever primitives and run search +# Build HNSW index via LeannMultiVector primitives and run search index_path = "./indexes/colpali.leann" -retriever = LeannRetriever(index_path=index_path, dim=int(ds[0].shape[-1])) +retriever = LeannMultiVector(index_path=index_path, dim=int(ds[0].shape[-1])) retriever.create_collection() filepaths = [os.path.join("./pages", name) for name in page_filenames] for i in range(len(filepaths)): diff --git a/apps/multimodal/vision-based-pdf-multi-vector/multi-vector-leann-similarity-map.py b/apps/multimodal/vision-based-pdf-multi-vector/multi-vector-leann-similarity-map.py index fcde09f..f1be682 100644 --- a/apps/multimodal/vision-based-pdf-multi-vector/multi-vector-leann-similarity-map.py +++ b/apps/multimodal/vision-based-pdf-multi-vector/multi-vector-leann-similarity-map.py @@ -5,7 +5,7 @@ import argparse import faulthandler import os import time -from typing import Any, Optional +from typing import Any, Optional, cast import numpy as np from PIL import Image @@ -223,7 +223,7 @@ if need_to_build_index: # Use filenames as identifiers instead of full paths for cleaner metadata filepaths = [os.path.basename(fp) for fp in filepaths] elif USE_HF_DATASET: - from datasets import load_dataset, concatenate_datasets, DatasetDict + from datasets import Dataset, DatasetDict, concatenate_datasets, load_dataset # Determine which datasets to load if DATASET_NAMES is not None: @@ -281,12 +281,12 @@ if need_to_build_index: splits_to_load = DATASET_SPLITS # Load and concatenate multiple splits for this dataset - datasets_to_concat = [] + datasets_to_concat: list[Dataset] = [] for split in splits_to_load: if split not in dataset_dict: print(f" Warning: Split '{split}' not found in dataset. 
Available splits: {list(dataset_dict.keys())}") continue - split_dataset = dataset_dict[split] + split_dataset = cast(Dataset, dataset_dict[split]) print(f" Loaded split '{split}': {len(split_dataset)} pages") datasets_to_concat.append(split_dataset) diff --git a/apps/multimodal/vision-based-pdf-multi-vector/vidore_v1_benchmark.py b/apps/multimodal/vision-based-pdf-multi-vector/vidore_v1_benchmark.py index 79472df..3b2d7df 100644 --- a/apps/multimodal/vision-based-pdf-multi-vector/vidore_v1_benchmark.py +++ b/apps/multimodal/vision-based-pdf-multi-vector/vidore_v1_benchmark.py @@ -25,9 +25,9 @@ Usage: import argparse import json import os -from typing import Optional +from typing import Any, Optional, cast -from datasets import load_dataset +from datasets import Dataset, load_dataset from leann_multi_vector import ( ViDoReBenchmarkEvaluator, _ensure_repo_paths_importable, @@ -151,40 +151,43 @@ def load_vidore_v1_data( """ print(f"Loading dataset: {dataset_path} (split={split})") - # Load queries - query_ds = load_dataset(dataset_path, "queries", split=split, revision=revision) + # Load queries - cast to Dataset since we know split returns Dataset not DatasetDict + query_ds = cast(Dataset, load_dataset(dataset_path, "queries", split=split, revision=revision)) - queries = {} + queries: dict[str, str] = {} for row in query_ds: - query_id = f"query-{split}-{row['query-id']}" - queries[query_id] = row["query"] + row_dict = cast(dict[str, Any], row) + query_id = f"query-{split}-{row_dict['query-id']}" + queries[query_id] = row_dict["query"] - # Load corpus (images) - corpus_ds = load_dataset(dataset_path, "corpus", split=split, revision=revision) + # Load corpus (images) - cast to Dataset + corpus_ds = cast(Dataset, load_dataset(dataset_path, "corpus", split=split, revision=revision)) - corpus = {} + corpus: dict[str, Any] = {} for row in corpus_ds: - corpus_id = f"corpus-{split}-{row['corpus-id']}" + row_dict = cast(dict[str, Any], row) + corpus_id = f"corpus-{split}-{row_dict['corpus-id']}" # Extract image from the dataset row - if "image" in row: - corpus[corpus_id] = row["image"] - elif "page_image" in row: - corpus[corpus_id] = row["page_image"] + if "image" in row_dict: + corpus[corpus_id] = row_dict["image"] + elif "page_image" in row_dict: + corpus[corpus_id] = row_dict["page_image"] else: raise ValueError( - f"No image field found in corpus. Available fields: {list(row.keys())}" + f"No image field found in corpus. Available fields: {list(row_dict.keys())}" ) - # Load qrels (relevance judgments) - qrels_ds = load_dataset(dataset_path, "qrels", split=split, revision=revision) + # Load qrels (relevance judgments) - cast to Dataset + qrels_ds = cast(Dataset, load_dataset(dataset_path, "qrels", split=split, revision=revision)) - qrels = {} + qrels: dict[str, dict[str, int]] = {} for row in qrels_ds: - query_id = f"query-{split}-{row['query-id']}" - corpus_id = f"corpus-{split}-{row['corpus-id']}" + row_dict = cast(dict[str, Any], row) + query_id = f"query-{split}-{row_dict['query-id']}" + corpus_id = f"corpus-{split}-{row_dict['corpus-id']}" if query_id not in qrels: qrels[query_id] = {} - qrels[query_id][corpus_id] = int(row["score"]) + qrels[query_id][corpus_id] = int(row_dict["score"]) print( f"Loaded {len(queries)} queries, {len(corpus)} corpus items, {len(qrels)} query-relevance mappings" @@ -234,8 +237,8 @@ def evaluate_task( raise ValueError(f"Unknown task: {task_name}. 
Available: {list(VIDORE_V1_TASKS.keys())}") task_config = VIDORE_V1_TASKS[task_name] - dataset_path = task_config["dataset_path"] - revision = task_config["revision"] + dataset_path = str(task_config["dataset_path"]) + revision = str(task_config["revision"]) # Load data corpus, queries, qrels = load_vidore_v1_data( @@ -286,7 +289,7 @@ def evaluate_task( ) # Search queries - task_prompt = task_config.get("prompt") + task_prompt = cast(Optional[dict[str, str]], task_config.get("prompt")) results = evaluator.search_queries( queries=queries, corpus_ids=corpus_ids_ordered, diff --git a/apps/multimodal/vision-based-pdf-multi-vector/vidore_v2_benchmark.py b/apps/multimodal/vision-based-pdf-multi-vector/vidore_v2_benchmark.py index 8a34e69..be4eb4f 100644 --- a/apps/multimodal/vision-based-pdf-multi-vector/vidore_v2_benchmark.py +++ b/apps/multimodal/vision-based-pdf-multi-vector/vidore_v2_benchmark.py @@ -25,9 +25,9 @@ Usage: import argparse import json import os -from typing import Optional +from typing import Any, Optional, cast -from datasets import load_dataset +from datasets import Dataset, load_dataset from leann_multi_vector import ( ViDoReBenchmarkEvaluator, _ensure_repo_paths_importable, @@ -91,8 +91,8 @@ def load_vidore_v2_data( """ print(f"Loading dataset: {dataset_path} (split={split}, language={language})") - # Load queries - query_ds = load_dataset(dataset_path, "queries", split=split, revision=revision) + # Load queries - cast to Dataset since we know split returns Dataset not DatasetDict + query_ds = cast(Dataset, load_dataset(dataset_path, "queries", split=split, revision=revision)) # Check if dataset has language field before filtering has_language_field = len(query_ds) > 0 and "language" in query_ds.column_names @@ -112,8 +112,9 @@ def load_vidore_v2_data( if len(query_ds_filtered) == 0: # Try to get a sample to see actual language values try: - sample_ds = load_dataset( - dataset_path, "queries", split=split, revision=revision + sample_ds = cast( + Dataset, + load_dataset(dataset_path, "queries", split=split, revision=revision), ) if len(sample_ds) > 0 and "language" in sample_ds.column_names: sample_langs = set(sample_ds["language"]) @@ -126,37 +127,40 @@ def load_vidore_v2_data( ) query_ds = query_ds_filtered - queries = {} + queries: dict[str, str] = {} for row in query_ds: - query_id = f"query-{split}-{row['query-id']}" - queries[query_id] = row["query"] + row_dict = cast(dict[str, Any], row) + query_id = f"query-{split}-{row_dict['query-id']}" + queries[query_id] = row_dict["query"] - # Load corpus (images) - corpus_ds = load_dataset(dataset_path, "corpus", split=split, revision=revision) + # Load corpus (images) - cast to Dataset + corpus_ds = cast(Dataset, load_dataset(dataset_path, "corpus", split=split, revision=revision)) - corpus = {} + corpus: dict[str, Any] = {} for row in corpus_ds: - corpus_id = f"corpus-{split}-{row['corpus-id']}" + row_dict = cast(dict[str, Any], row) + corpus_id = f"corpus-{split}-{row_dict['corpus-id']}" # Extract image from the dataset row - if "image" in row: - corpus[corpus_id] = row["image"] - elif "page_image" in row: - corpus[corpus_id] = row["page_image"] + if "image" in row_dict: + corpus[corpus_id] = row_dict["image"] + elif "page_image" in row_dict: + corpus[corpus_id] = row_dict["page_image"] else: raise ValueError( - f"No image field found in corpus. Available fields: {list(row.keys())}" + f"No image field found in corpus. 
Available fields: {list(row_dict.keys())}" ) - # Load qrels (relevance judgments) - qrels_ds = load_dataset(dataset_path, "qrels", split=split, revision=revision) + # Load qrels (relevance judgments) - cast to Dataset + qrels_ds = cast(Dataset, load_dataset(dataset_path, "qrels", split=split, revision=revision)) - qrels = {} + qrels: dict[str, dict[str, int]] = {} for row in qrels_ds: - query_id = f"query-{split}-{row['query-id']}" - corpus_id = f"corpus-{split}-{row['corpus-id']}" + row_dict = cast(dict[str, Any], row) + query_id = f"query-{split}-{row_dict['query-id']}" + corpus_id = f"corpus-{split}-{row_dict['corpus-id']}" if query_id not in qrels: qrels[query_id] = {} - qrels[query_id][corpus_id] = int(row["score"]) + qrels[query_id][corpus_id] = int(row_dict["score"]) print( f"Loaded {len(queries)} queries, {len(corpus)} corpus items, {len(qrels)} query-relevance mappings" @@ -204,13 +208,13 @@ def evaluate_task( raise ValueError(f"Unknown task: {task_name}. Available: {list(VIDORE_V2_TASKS.keys())}") task_config = VIDORE_V2_TASKS[task_name] - dataset_path = task_config["dataset_path"] - revision = task_config["revision"] + dataset_path = str(task_config["dataset_path"]) + revision = str(task_config["revision"]) # Determine language if language is None: # Use first language if multiple available - languages = task_config.get("languages") + languages = cast(Optional[list[str]], task_config.get("languages")) if languages is None: # Task doesn't support language filtering (e.g., Vidore2ESGReportsHLRetrieval) language = None @@ -269,7 +273,7 @@ def evaluate_task( ) # Search queries - task_prompt = task_config.get("prompt") + task_prompt = cast(Optional[dict[str, str]], task_config.get("prompt")) results = evaluator.search_queries( queries=queries, corpus_ids=corpus_ids_ordered, diff --git a/apps/slack_data/slack_mcp_reader.py b/apps/slack_data/slack_mcp_reader.py index 8f24e1d..f1aaf41 100644 --- a/apps/slack_data/slack_mcp_reader.py +++ b/apps/slack_data/slack_mcp_reader.py @@ -177,7 +177,9 @@ class SlackMCPReader: break # If we get here, all retries failed or it's not a retryable error - raise last_exception + if last_exception is not None: + raise last_exception + raise RuntimeError("Unexpected error: no exception captured during retry loop") async def fetch_slack_messages( self, channel: Optional[str] = None, limit: int = 100 @@ -267,7 +269,10 @@ class SlackMCPReader: messages = json.loads(content["text"]) except json.JSONDecodeError: # If not JSON, try to parse as CSV format (Slack MCP server format) - messages = self._parse_csv_messages(content["text"], channel) + text_content = content.get("text", "") + messages = self._parse_csv_messages( + text_content if text_content else "", channel or "unknown" + ) else: messages = result["content"] else: diff --git a/apps/slack_rag.py b/apps/slack_rag.py index 1135a59..8980457 100644 --- a/apps/slack_rag.py +++ b/apps/slack_rag.py @@ -11,6 +11,7 @@ Usage: import argparse import asyncio +from typing import Any from apps.base_rag_example import BaseRAGExample from apps.slack_data.slack_mcp_reader import SlackMCPReader @@ -139,7 +140,7 @@ class SlackMCPRAG(BaseRAGExample): print("4. 
Try running the MCP server command directly to test it") return False - async def load_data(self, args) -> list[str]: + async def load_data(self, args) -> list[dict[str, Any]]: """Load Slack messages via MCP server.""" print(f"Connecting to Slack MCP server: {args.mcp_server}") @@ -188,7 +189,8 @@ class SlackMCPRAG(BaseRAGExample): print(sample_text) print("-" * 40) - return texts + # Convert strings to dict format expected by base class + return [{"text": text, "metadata": {"source": "slack"}} for text in texts] except Exception as e: print(f"Error loading Slack data: {e}") diff --git a/apps/twitter_rag.py b/apps/twitter_rag.py index a7fd3a4..5446a5a 100644 --- a/apps/twitter_rag.py +++ b/apps/twitter_rag.py @@ -11,6 +11,7 @@ Usage: import argparse import asyncio +from typing import Any from apps.base_rag_example import BaseRAGExample from apps.twitter_data.twitter_mcp_reader import TwitterMCPReader @@ -116,7 +117,7 @@ class TwitterMCPRAG(BaseRAGExample): print("5. Try running the MCP server command directly to test it") return False - async def load_data(self, args) -> list[str]: + async def load_data(self, args) -> list[dict[str, Any]]: """Load Twitter bookmarks via MCP server.""" print(f"Connecting to Twitter MCP server: {args.mcp_server}") @@ -156,7 +157,8 @@ class TwitterMCPRAG(BaseRAGExample): print(sample_text) print("-" * 50) - return texts + # Convert strings to dict format expected by base class + return [{"text": text, "metadata": {"source": "twitter"}} for text in texts] except Exception as e: print(f"❌ Error loading Twitter bookmarks: {e}") diff --git a/apps/wechat_rag.py b/apps/wechat_rag.py index 7355c6f..1e5dd31 100644 --- a/apps/wechat_rag.py +++ b/apps/wechat_rag.py @@ -6,6 +6,7 @@ Supports WeChat chat history export and search. import subprocess import sys from pathlib import Path +from typing import Any # Add parent directory to path for imports sys.path.insert(0, str(Path(__file__).parent)) @@ -91,7 +92,7 @@ class WeChatRAG(BaseRAGExample): print(f"Export error: {e}") return False - async def load_data(self, args) -> list[str]: + async def load_data(self, args) -> list[dict[str, Any]]: """Load WeChat history and convert to text chunks.""" # Initialize WeChat reader with export capabilities reader = WeChatHistoryReader() diff --git a/packages/leann-core/src/leann/chunking_utils.py b/packages/leann-core/src/leann/chunking_utils.py index 34e0779..aae8761 100644 --- a/packages/leann-core/src/leann/chunking_utils.py +++ b/packages/leann-core/src/leann/chunking_utils.py @@ -239,11 +239,11 @@ def create_ast_chunks( chunks = chunk_builder.chunkify(code_content) for chunk in chunks: - chunk_text = None - astchunk_metadata = {} + chunk_text: str | None = None + astchunk_metadata: dict[str, Any] = {} if hasattr(chunk, "text"): - chunk_text = chunk.text + chunk_text = str(chunk.text) if chunk.text else None elif isinstance(chunk, str): chunk_text = chunk elif isinstance(chunk, dict): diff --git a/packages/leann-core/src/leann/cli.py b/packages/leann-core/src/leann/cli.py index 708892a..ce51637 100644 --- a/packages/leann-core/src/leann/cli.py +++ b/packages/leann-core/src/leann/cli.py @@ -19,7 +19,7 @@ from .settings import ( ) -def extract_pdf_text_with_pymupdf(file_path: str) -> str: +def extract_pdf_text_with_pymupdf(file_path: str) -> str | None: """Extract text from PDF using PyMuPDF for better quality.""" try: import fitz # PyMuPDF @@ -35,7 +35,7 @@ def extract_pdf_text_with_pymupdf(file_path: str) -> str: return None -def extract_pdf_text_with_pdfplumber(file_path: str) -> 
str: +def extract_pdf_text_with_pdfplumber(file_path: str) -> str | None: """Extract text from PDF using pdfplumber for better quality.""" try: import pdfplumber diff --git a/packages/leann-core/src/leann/embedding_compute.py b/packages/leann-core/src/leann/embedding_compute.py index 093a710..eb2a1be 100644 --- a/packages/leann-core/src/leann/embedding_compute.py +++ b/packages/leann-core/src/leann/embedding_compute.py @@ -451,7 +451,8 @@ def compute_embeddings_sentence_transformers( # TODO: Haven't tested this yet torch.set_num_threads(min(8, os.cpu_count() or 4)) try: - torch.backends.mkldnn.enabled = True + # PyTorch's ContextProp type is complex; cast for type checker + torch.backends.mkldnn.enabled = True # type: ignore[assignment] except AttributeError: pass diff --git a/packages/leann-core/src/leann/interactive_utils.py b/packages/leann-core/src/leann/interactive_utils.py index 56f7731..ac803d2 100644 --- a/packages/leann-core/src/leann/interactive_utils.py +++ b/packages/leann-core/src/leann/interactive_utils.py @@ -11,14 +11,15 @@ from pathlib import Path from typing import Callable, Optional # Try to import readline with fallback for Windows +HAS_READLINE = False +readline = None # type: ignore[assignment] try: - import readline + import readline # type: ignore[no-redef] HAS_READLINE = True except ImportError: # Windows doesn't have readline by default - HAS_READLINE = False - readline = None + pass class InteractiveSession: diff --git a/packages/leann-core/src/leann/metadata_filter.py b/packages/leann-core/src/leann/metadata_filter.py index 1bf4ac1..5a8ffbd 100644 --- a/packages/leann-core/src/leann/metadata_filter.py +++ b/packages/leann-core/src/leann/metadata_filter.py @@ -7,7 +7,7 @@ operators for different data types including numbers, strings, booleans, and lis """ import logging -from typing import Any, Union +from typing import Any, Optional, Union logger = logging.getLogger(__name__) @@ -47,7 +47,7 @@ class MetadataFilterEngine: } def apply_filters( - self, search_results: list[dict[str, Any]], metadata_filters: MetadataFilters + self, search_results: list[dict[str, Any]], metadata_filters: Optional[MetadataFilters] ) -> list[dict[str, Any]]: """ Apply metadata filters to a list of search results. diff --git a/packages/leann-core/src/leann/searcher_base.py b/packages/leann-core/src/leann/searcher_base.py index f8ab71c..1def0ae 100644 --- a/packages/leann-core/src/leann/searcher_base.py +++ b/packages/leann-core/src/leann/searcher_base.py @@ -56,7 +56,9 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC): with open(meta_path, encoding="utf-8") as f: return json.load(f) - def _ensure_server_running(self, passages_source_file: str, port: int, **kwargs) -> int: + def _ensure_server_running( + self, passages_source_file: str, port: Optional[int], **kwargs + ) -> int: """ Ensures the embedding server is running if recompute is needed. This is a helper for subclasses. 
@@ -81,7 +83,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC): } server_started, actual_port = self.embedding_server_manager.start_server( - port=port, + port=port if port is not None else 5557, model_name=self.embedding_model, embedding_mode=self.embedding_mode, passages_file=passages_source_file, @@ -98,7 +100,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC): self, query: str, use_server_if_available: bool = True, - zmq_port: int = 5557, + zmq_port: Optional[int] = None, query_template: Optional[str] = None, ) -> np.ndarray: """ diff --git a/pyproject.toml b/pyproject.toml index c19bcf0..dc53b0f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -157,6 +157,19 @@ exclude = ["localhost", "127.0.0.1", "example.com"] exclude_path = [".git/", ".venv/", "__pycache__/", "third_party/"] scheme = ["https", "http"] +[tool.ty] +# Type checking with ty (Astral's fast Python type checker) +# ty is 10-100x faster than mypy. See: https://docs.astral.sh/ty/ + +[tool.ty.environment] +python-version = "3.11" +extra-paths = ["apps", "packages/leann-core/src"] + +[tool.ty.rules] +# Disable some noisy rules that have many false positives +possibly-missing-attribute = "ignore" +unresolved-import = "ignore" # Many optional dependencies + [tool.pytest.ini_options] testpaths = ["tests"] python_files = ["test_*.py"] diff --git a/tests/test_basic.py b/tests/test_basic.py index 651111f..0268a70 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -91,7 +91,7 @@ def test_large_index(): builder.build_index(index_path) searcher = LeannSearcher(index_path) - results = searcher.search(["word10 word20"], top_k=10) - assert len(results[0]) == 10 + results = searcher.search("word10 word20", top_k=10) + assert len(results) == 10 # Cleanup searcher.cleanup() diff --git a/tests/test_cli_prompt_template.py b/tests/test_cli_prompt_template.py index 981bb78..774e29f 100644 --- a/tests/test_cli_prompt_template.py +++ b/tests/test_cli_prompt_template.py @@ -123,7 +123,7 @@ class TestPromptTemplateStoredInEmbeddingOptions: cli = LeannCLI() # Mock load_documents to return a document so builder is created - cli.load_documents = Mock(return_value=[{"text": "test content", "metadata": {}}]) + cli.load_documents = Mock(return_value=[{"text": "test content", "metadata": {}}]) # type: ignore[assignment] parser = cli.create_parser() @@ -175,7 +175,7 @@ class TestPromptTemplateStoredInEmbeddingOptions: cli = LeannCLI() # Mock load_documents to return a document so builder is created - cli.load_documents = Mock(return_value=[{"text": "test content", "metadata": {}}]) + cli.load_documents = Mock(return_value=[{"text": "test content", "metadata": {}}]) # type: ignore[assignment] parser = cli.create_parser() @@ -230,7 +230,7 @@ class TestPromptTemplateStoredInEmbeddingOptions: cli = LeannCLI() # Mock load_documents to return a document so builder is created - cli.load_documents = Mock(return_value=[{"text": "test content", "metadata": {}}]) + cli.load_documents = Mock(return_value=[{"text": "test content", "metadata": {}}]) # type: ignore[assignment] parser = cli.create_parser() @@ -307,7 +307,7 @@ class TestPromptTemplateStoredInEmbeddingOptions: cli = LeannCLI() # Mock load_documents to return a document so builder is created - cli.load_documents = Mock(return_value=[{"text": "test content", "metadata": {}}]) + cli.load_documents = Mock(return_value=[{"text": "test content", "metadata": {}}]) # type: ignore[assignment] parser = cli.create_parser() @@ -376,7 +376,7 @@ class 
TestPromptTemplateStoredInEmbeddingOptions: cli = LeannCLI() # Mock load_documents to return a document so builder is created - cli.load_documents = Mock(return_value=[{"text": "test content", "metadata": {}}]) + cli.load_documents = Mock(return_value=[{"text": "test content", "metadata": {}}]) # type: ignore[assignment] parser = cli.create_parser() @@ -432,7 +432,7 @@ class TestPromptTemplateFlowsToComputeEmbeddings: cli = LeannCLI() # Mock load_documents to return a simple document - cli.load_documents = Mock(return_value=[{"text": "test content", "metadata": {}}]) + cli.load_documents = Mock(return_value=[{"text": "test content", "metadata": {}}]) # type: ignore[assignment] parser = cli.create_parser() diff --git a/tests/test_prompt_template_e2e.py b/tests/test_prompt_template_e2e.py index 80c9cce..6cb40b1 100644 --- a/tests/test_prompt_template_e2e.py +++ b/tests/test_prompt_template_e2e.py @@ -67,7 +67,7 @@ def check_lmstudio_available() -> bool: return False -def get_lmstudio_first_model() -> str: +def get_lmstudio_first_model() -> str | None: """Get the first available model from LM Studio.""" try: response = requests.get("http://localhost:1234/v1/models", timeout=5.0) @@ -91,6 +91,7 @@ class TestPromptTemplateOpenAI: model_name = get_lmstudio_first_model() if not model_name: pytest.skip("No models loaded in LM Studio") + assert model_name is not None # Type narrowing for type checker texts = ["artificial intelligence", "machine learning"] prompt_template = "search_query: " @@ -120,6 +121,7 @@ class TestPromptTemplateOpenAI: model_name = get_lmstudio_first_model() if not model_name: pytest.skip("No models loaded in LM Studio") + assert model_name is not None # Type narrowing for type checker text = "machine learning" base_url = "http://localhost:1234/v1" @@ -271,6 +273,7 @@ class TestLMStudioSDK: model_name = get_lmstudio_first_model() if not model_name: pytest.skip("No models loaded in LM Studio") + assert model_name is not None # Type narrowing for type checker try: from leann.embedding_compute import _query_lmstudio_context_limit diff --git a/tests/test_prompt_template_persistence.py b/tests/test_prompt_template_persistence.py index eefda04..56391ff 100644 --- a/tests/test_prompt_template_persistence.py +++ b/tests/test_prompt_template_persistence.py @@ -581,7 +581,18 @@ class TestQueryTemplateApplicationInComputeEmbedding: # Create a concrete implementation for testing class TestSearcher(BaseSearcher): - def search(self, query_vectors, top_k, complexity, beam_width=1, **kwargs): + def search( + self, + query, + top_k, + complexity=64, + beam_width=1, + prune_ratio=0.0, + recompute_embeddings=False, + pruning_strategy="global", + zmq_port=None, + **kwargs, + ): return {"labels": [], "distances": []} searcher = object.__new__(TestSearcher) @@ -625,7 +636,18 @@ class TestQueryTemplateApplicationInComputeEmbedding: # Create a concrete implementation for testing class TestSearcher(BaseSearcher): - def search(self, query_vectors, top_k, complexity, beam_width=1, **kwargs): + def search( + self, + query, + top_k, + complexity=64, + beam_width=1, + prune_ratio=0.0, + recompute_embeddings=False, + pruning_strategy="global", + zmq_port=None, + **kwargs, + ): return {"labels": [], "distances": []} searcher = object.__new__(TestSearcher) @@ -671,7 +693,18 @@ class TestQueryTemplateApplicationInComputeEmbedding: from leann.searcher_base import BaseSearcher class TestSearcher(BaseSearcher): - def search(self, query_vectors, top_k, complexity, beam_width=1, **kwargs): + def search( + self, + 
query, + top_k, + complexity=64, + beam_width=1, + prune_ratio=0.0, + recompute_embeddings=False, + pruning_strategy="global", + zmq_port=None, + **kwargs, + ): return {"labels": [], "distances": []} searcher = object.__new__(TestSearcher) @@ -710,7 +743,18 @@ class TestQueryTemplateApplicationInComputeEmbedding: from leann.searcher_base import BaseSearcher class TestSearcher(BaseSearcher): - def search(self, query_vectors, top_k, complexity, beam_width=1, **kwargs): + def search( + self, + query, + top_k, + complexity=64, + beam_width=1, + prune_ratio=0.0, + recompute_embeddings=False, + pruning_strategy="global", + zmq_port=None, + **kwargs, + ): return {"labels": [], "distances": []} searcher = object.__new__(TestSearcher) @@ -774,7 +818,18 @@ class TestQueryTemplateApplicationInComputeEmbedding: from leann.searcher_base import BaseSearcher class TestSearcher(BaseSearcher): - def search(self, query_vectors, top_k, complexity, beam_width=1, **kwargs): + def search( + self, + query, + top_k, + complexity=64, + beam_width=1, + prune_ratio=0.0, + recompute_embeddings=False, + pruning_strategy="global", + zmq_port=None, + **kwargs, + ): return {"labels": [], "distances": []} searcher = object.__new__(TestSearcher) diff --git a/tests/test_readme_examples.py b/tests/test_readme_examples.py index a562a02..3ff0829 100644 --- a/tests/test_readme_examples.py +++ b/tests/test_readme_examples.py @@ -97,17 +97,17 @@ def test_backend_options(): with tempfile.TemporaryDirectory() as temp_dir: # Use smaller model in CI to avoid memory issues - if os.environ.get("CI") == "true": - model_args = { - "embedding_model": "sentence-transformers/all-MiniLM-L6-v2", - "dimensions": 384, - } - else: - model_args = {} + is_ci = os.environ.get("CI") == "true" + embedding_model = ( + "sentence-transformers/all-MiniLM-L6-v2" if is_ci else "facebook/contriever" + ) + dimensions = 384 if is_ci else None # Test HNSW backend (as shown in README) hnsw_path = str(Path(temp_dir) / "test_hnsw.leann") - builder_hnsw = LeannBuilder(backend_name="hnsw", **model_args) + builder_hnsw = LeannBuilder( + backend_name="hnsw", embedding_model=embedding_model, dimensions=dimensions + ) builder_hnsw.add_text("Test document for HNSW backend") builder_hnsw.build_index(hnsw_path) assert Path(hnsw_path).parent.exists() @@ -115,7 +115,9 @@ def test_backend_options(): # Test DiskANN backend (mentioned as available option) diskann_path = str(Path(temp_dir) / "test_diskann.leann") - builder_diskann = LeannBuilder(backend_name="diskann", **model_args) + builder_diskann = LeannBuilder( + backend_name="diskann", embedding_model=embedding_model, dimensions=dimensions + ) builder_diskann.add_text("Test document for DiskANN backend") builder_diskann.build_index(diskann_path) assert Path(diskann_path).parent.exists()
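The contract behind most of these signature changes is the chunk shape produced by create_text_chunks(): the base class and the apps now annotate load_data() as returning list[dict[str, Any]] instead of list[str]. A minimal sketch of that shape, assuming only the "text" and "metadata" keys visible in the diffs above (make_chunks is an illustrative helper, not part of this patch):

    from typing import Any

    def make_chunks(texts: list[str], source: str) -> list[dict[str, Any]]:
        # Each chunk pairs the text to embed with a metadata dict, mirroring the
        # [{"text": t, "metadata": {"source": "slack"}} for t in texts] lines above.
        return [{"text": t, "metadata": {"source": source}} for t in texts]

    print(make_chunks(["example chunk one", "example chunk two"], source="example"))

The same check the new CI job performs can be reproduced locally with "uv tool install ty" followed by "ty check packages/leann-core/src apps tests".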