fix: Fix pre-commit issues and update tests
- Fix import sorting and unused imports - Update type annotations to use built-in types (list, dict) instead of typing.List/Dict - Fix trailing whitespace and end-of-file issues - Fix Chinese fullwidth comma to regular comma - Update test_main_cli.py to test_document_rag.py - Add backward compatibility test for main_cli_example.py - Pass all pre-commit hooks (ruff, ruff-format, etc.)
This commit is contained in:
@@ -195,7 +195,7 @@ python ./examples/document_rag.py --query "What are the main techniques LEANN ex
|
||||
--embedding-model MODEL # e.g., facebook/contriever, text-embedding-3-small
|
||||
--embedding-mode MODE # sentence-transformers, openai, or mlx
|
||||
|
||||
# LLM Parameters
|
||||
# LLM Parameters
|
||||
--llm TYPE # openai, ollama, or hf
|
||||
--llm-model MODEL # e.g., gpt-4o, llama3.2:1b
|
||||
--top-k N # Number of results to retrieve (default: 20)
|
||||
|
||||
@@ -61,4 +61,4 @@ This document ensures that the new unified interface maintains exact parameter c
|
||||
|
||||
5. **Special Cases**:
|
||||
- WeChat uses a specific Chinese embedding model
|
||||
- Email reader includes HTML processing option
|
||||
- Email reader includes HTML processing option
|
||||
|
||||
@@ -4,14 +4,12 @@ Provides common parameters and functionality for all RAG examples.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Optional, List, Dict, Any
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import dotenv
|
||||
from leann.api import LeannBuilder, LeannSearcher, LeannChat
|
||||
from leann.api import LeannBuilder, LeannChat
|
||||
from llama_index.core.node_parser import SentenceSplitter
|
||||
|
||||
dotenv.load_dotenv()
|
||||
@@ -129,11 +127,11 @@ class BaseRAGExample(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def load_data(self, args) -> List[str]:
|
||||
async def load_data(self, args) -> list[str]:
|
||||
"""Load data from the source. Returns list of text chunks."""
|
||||
pass
|
||||
|
||||
def get_llm_config(self, args) -> Dict[str, Any]:
|
||||
def get_llm_config(self, args) -> dict[str, Any]:
|
||||
"""Get LLM configuration based on arguments."""
|
||||
config = {"type": args.llm}
|
||||
|
||||
@@ -147,7 +145,7 @@ class BaseRAGExample(ABC):
|
||||
|
||||
return config
|
||||
|
||||
async def build_index(self, args, texts: List[str]) -> str:
|
||||
async def build_index(self, args, texts: list[str]) -> str:
|
||||
"""Build LEANN index from texts."""
|
||||
index_path = str(Path(args.index_dir) / f"{self.default_index_name}.leann")
|
||||
|
||||
@@ -256,7 +254,7 @@ class BaseRAGExample(ABC):
|
||||
await self.run_interactive_chat(args, index_path)
|
||||
|
||||
|
||||
def create_text_chunks(documents, chunk_size=256, chunk_overlap=25) -> List[str]:
|
||||
def create_text_chunks(documents, chunk_size=256, chunk_overlap=25) -> list[str]:
|
||||
"""Helper function to create text chunks from documents."""
|
||||
node_parser = SentenceSplitter(
|
||||
chunk_size=chunk_size,
|
||||
|
||||
@@ -6,7 +6,6 @@ Supports Chrome browser history.
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
# Add parent directory to path for imports
|
||||
sys.path.insert(0, str(Path(__file__).parent))
|
||||
@@ -52,7 +51,7 @@ class BrowserRAG(BaseRAGExample):
|
||||
else:
|
||||
raise ValueError(f"Unsupported platform: {sys.platform}")
|
||||
|
||||
def _find_chrome_profiles(self) -> List[Path]:
|
||||
def _find_chrome_profiles(self) -> list[Path]:
|
||||
"""Auto-detect all Chrome profiles."""
|
||||
base_path = self._get_chrome_base_path()
|
||||
if not base_path.exists():
|
||||
@@ -73,7 +72,7 @@ class BrowserRAG(BaseRAGExample):
|
||||
|
||||
return profiles
|
||||
|
||||
async def load_data(self, args) -> List[str]:
|
||||
async def load_data(self, args) -> list[str]:
|
||||
"""Load browser history and convert to text chunks."""
|
||||
# Determine Chrome profiles
|
||||
if args.chrome_profile and not args.auto_find_profiles:
|
||||
|
||||
@@ -5,7 +5,6 @@ Supports PDF, TXT, MD, and other document formats.
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
# Add parent directory to path for imports
|
||||
sys.path.insert(0, str(Path(__file__).parent))
|
||||
@@ -16,52 +15,46 @@ from llama_index.core import SimpleDirectoryReader
|
||||
|
||||
class DocumentRAG(BaseRAGExample):
|
||||
"""RAG example for document processing (PDF, TXT, MD, etc.)."""
|
||||
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
name="Document",
|
||||
description="Process and query documents (PDF, TXT, MD, etc.) with LEANN",
|
||||
default_index_name="test_doc_files" # Match original main_cli_example.py default
|
||||
default_index_name="test_doc_files", # Match original main_cli_example.py default
|
||||
)
|
||||
|
||||
|
||||
def _add_specific_arguments(self, parser):
|
||||
"""Add document-specific arguments."""
|
||||
doc_group = parser.add_argument_group('Document Parameters')
|
||||
doc_group = parser.add_argument_group("Document Parameters")
|
||||
doc_group.add_argument(
|
||||
"--data-dir",
|
||||
type=str,
|
||||
default="examples/data",
|
||||
help="Directory containing documents to index (default: examples/data)"
|
||||
help="Directory containing documents to index (default: examples/data)",
|
||||
)
|
||||
doc_group.add_argument(
|
||||
"--file-types",
|
||||
nargs="+",
|
||||
default=[".pdf", ".txt", ".md"],
|
||||
help="File types to process (default: .pdf .txt .md)"
|
||||
help="File types to process (default: .pdf .txt .md)",
|
||||
)
|
||||
doc_group.add_argument(
|
||||
"--chunk-size",
|
||||
type=int,
|
||||
default=256,
|
||||
help="Text chunk size (default: 256)"
|
||||
"--chunk-size", type=int, default=256, help="Text chunk size (default: 256)"
|
||||
)
|
||||
doc_group.add_argument(
|
||||
"--chunk-overlap",
|
||||
type=int,
|
||||
default=128,
|
||||
help="Text chunk overlap (default: 128)"
|
||||
"--chunk-overlap", type=int, default=128, help="Text chunk overlap (default: 128)"
|
||||
)
|
||||
|
||||
async def load_data(self, args) -> List[str]:
|
||||
|
||||
async def load_data(self, args) -> list[str]:
|
||||
"""Load documents and convert to text chunks."""
|
||||
print(f"Loading documents from: {args.data_dir}")
|
||||
print(f"File types: {args.file_types}")
|
||||
|
||||
|
||||
# Check if data directory exists
|
||||
data_path = Path(args.data_dir)
|
||||
if not data_path.exists():
|
||||
raise ValueError(f"Data directory not found: {args.data_dir}")
|
||||
|
||||
|
||||
# Load documents
|
||||
documents = SimpleDirectoryReader(
|
||||
args.data_dir,
|
||||
@@ -69,31 +62,29 @@ class DocumentRAG(BaseRAGExample):
|
||||
encoding="utf-8",
|
||||
required_exts=args.file_types,
|
||||
).load_data(show_progress=True)
|
||||
|
||||
|
||||
if not documents:
|
||||
print(f"No documents found in {args.data_dir} with extensions {args.file_types}")
|
||||
return []
|
||||
|
||||
|
||||
print(f"Loaded {len(documents)} documents")
|
||||
|
||||
|
||||
# Convert to text chunks
|
||||
all_texts = create_text_chunks(
|
||||
documents,
|
||||
chunk_size=args.chunk_size,
|
||||
chunk_overlap=args.chunk_overlap
|
||||
documents, chunk_size=args.chunk_size, chunk_overlap=args.chunk_overlap
|
||||
)
|
||||
|
||||
|
||||
# Apply max_items limit if specified
|
||||
if args.max_items > 0 and len(all_texts) > args.max_items:
|
||||
print(f"Limiting to {args.max_items} chunks (from {len(all_texts)})")
|
||||
all_texts = all_texts[:args.max_items]
|
||||
|
||||
all_texts = all_texts[: args.max_items]
|
||||
|
||||
return all_texts
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
||||
|
||||
# Example queries for document RAG
|
||||
print("\n📄 Document RAG Example")
|
||||
print("=" * 50)
|
||||
@@ -102,6 +93,6 @@ if __name__ == "__main__":
|
||||
print("- 'Summarize the key findings in these papers'")
|
||||
print("- 'What is the storage reduction achieved by LEANN?'")
|
||||
print("\nOr run without --query for interactive mode\n")
|
||||
|
||||
|
||||
rag = DocumentRAG()
|
||||
asyncio.run(rag.run())
|
||||
asyncio.run(rag.run())
|
||||
|
||||
@@ -3,10 +3,8 @@ Email RAG example using the unified interface.
|
||||
Supports Apple Mail on macOS.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
# Add parent directory to path for imports
|
||||
sys.path.insert(0, str(Path(__file__).parent))
|
||||
@@ -39,7 +37,7 @@ class EmailRAG(BaseRAGExample):
|
||||
"--include-html", action="store_true", help="Include HTML content in email processing"
|
||||
)
|
||||
|
||||
def _find_mail_directories(self) -> List[Path]:
|
||||
def _find_mail_directories(self) -> list[Path]:
|
||||
"""Auto-detect all Apple Mail directories."""
|
||||
mail_base = Path.home() / "Library" / "Mail"
|
||||
if not mail_base.exists():
|
||||
@@ -53,7 +51,7 @@ class EmailRAG(BaseRAGExample):
|
||||
|
||||
return messages_dirs
|
||||
|
||||
async def load_data(self, args) -> List[str]:
|
||||
async def load_data(self, args) -> list[str]:
|
||||
"""Load emails and convert to text chunks."""
|
||||
# Determine mail directories
|
||||
if args.mail_path:
|
||||
|
||||
@@ -5,7 +5,6 @@ This file is kept for backward compatibility.
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
|
||||
print("=" * 70)
|
||||
print("NOTICE: This script has been replaced!")
|
||||
|
||||
@@ -6,7 +6,6 @@ Supports WeChat chat history export and search.
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
# Add parent directory to path for imports
|
||||
sys.path.insert(0, str(Path(__file__).parent))
|
||||
@@ -84,7 +83,7 @@ class WeChatRAG(BaseRAGExample):
|
||||
print(f"Export error: {e}")
|
||||
return False
|
||||
|
||||
async def load_data(self, args) -> List[str]:
|
||||
async def load_data(self, args) -> list[str]:
|
||||
"""Load WeChat history and convert to text chunks."""
|
||||
export_path = Path(args.export_dir)
|
||||
|
||||
@@ -145,7 +144,7 @@ if __name__ == "__main__":
|
||||
print("\nExample queries you can try:")
|
||||
print("- 'Show me conversations about travel plans'")
|
||||
print("- 'Find group chats about weekend activities'")
|
||||
print("- '我想买魔术师约翰逊的球衣,给我一些对应聊天记录?'")
|
||||
print("- '我想买魔术师约翰逊的球衣,给我一些对应聊天记录?'")
|
||||
print("- 'What did we discuss about the project last month?'")
|
||||
print("\nNote: WeChat must be running for export to work\n")
|
||||
|
||||
|
||||
Submodule packages/leann-backend-diskann/third_party/DiskANN updated: af2a26481e...25339b0341
@@ -53,7 +53,7 @@ def test_document_rag_simulated(test_data_dir):
|
||||
|
||||
# Verify output
|
||||
output = result.stdout + result.stderr
|
||||
assert "Leann index built at" in output or "Using existing index" in output
|
||||
assert "Index saved to" in output or "Using existing index" in output
|
||||
assert "This is a simulated answer" in output
|
||||
|
||||
|
||||
@@ -117,4 +117,16 @@ def test_document_rag_error_handling(test_data_dir):
|
||||
|
||||
# Should fail with invalid LLM type
|
||||
assert result.returncode != 0
|
||||
assert "Unknown LLM type" in result.stderr or "invalid_llm_type" in result.stderr
|
||||
assert "invalid choice" in result.stderr or "invalid_llm_type" in result.stderr
|
||||
|
||||
|
||||
def test_main_cli_backward_compatibility():
|
||||
"""Test that main_cli_example.py shows migration message."""
|
||||
cmd = [sys.executable, "examples/main_cli_example.py", "--help"]
|
||||
|
||||
result = subprocess.run(cmd, capture_output=True, text=True, timeout=60)
|
||||
|
||||
# Should exit with error code and show migration message
|
||||
assert result.returncode != 0
|
||||
assert "This script has been replaced" in result.stdout
|
||||
assert "document_rag.py" in result.stdout
|
||||
|
||||
Reference in New Issue
Block a user