From ddc789b2316d441261d3db119a6ec750dae4542a Mon Sep 17 00:00:00 2001 From: Andy Lee Date: Tue, 29 Jul 2025 13:33:40 -0700 Subject: [PATCH] fix: Restore embedding-mode parameter to all examples - All examples now have --embedding-mode parameter (unified interface benefit) - Default is 'sentence-transformers' (consistent with original behavior) - Users can now use OpenAI or MLX embeddings with any data source - Maintains functional equivalence with original scripts --- examples/base_rag_example.py | 43 +++++++++++---------------- examples/browser_rag.py | 2 +- examples/document_rag.py | 2 +- examples/email_rag.py | 3 +- examples/wechat_rag.py | 3 +- test/sanity_checks/debug_zmq_issue.py | 6 ++-- tests/test_document_rag.py | 2 +- 7 files changed, 25 insertions(+), 36 deletions(-) diff --git a/examples/base_rag_example.py b/examples/base_rag_example.py index a307193..84cb957 100644 --- a/examples/base_rag_example.py +++ b/examples/base_rag_example.py @@ -23,12 +23,10 @@ class BaseRAGExample(ABC): name: str, description: str, default_index_name: str, - include_embedding_mode: bool = True, ): self.name = name self.description = description self.default_index_name = default_index_name - self.include_embedding_mode = include_embedding_mode self.parser = self._create_parser() def _create_parser(self) -> argparse.ArgumentParser: @@ -73,14 +71,13 @@ class BaseRAGExample(ABC): default=embedding_model_default, help=f"Embedding model to use (default: {embedding_model_default})", ) - if self.include_embedding_mode: - embedding_group.add_argument( - "--embedding-mode", - type=str, - default="sentence-transformers", - choices=["sentence-transformers", "openai", "mlx"], - help="Embedding backend mode (default: sentence-transformers)", - ) + embedding_group.add_argument( + "--embedding-mode", + type=str, + default="sentence-transformers", + choices=["sentence-transformers", "openai", "mlx"], + help="Embedding backend mode (default: sentence-transformers)", + ) # LLM parameters llm_group = parser.add_argument_group("LLM Parameters") @@ -152,22 +149,16 @@ class BaseRAGExample(ABC): print(f"\n[Building Index] Creating {self.name} index...") print(f"Total text chunks: {len(texts)}") - # Build kwargs for LeannBuilder - builder_kwargs = { - "backend_name": "hnsw", - "embedding_model": args.embedding_model, - "graph_degree": 32, - "complexity": 64, - "is_compact": True, - "is_recompute": True, - "num_threads": 1, # Force single-threaded mode - } - - # Only add embedding_mode if it's not suppressed (for compatibility) - if hasattr(args, "embedding_mode") and args.embedding_mode is not None: - builder_kwargs["embedding_mode"] = args.embedding_mode - - builder = LeannBuilder(**builder_kwargs) + builder = LeannBuilder( + backend_name="hnsw", + embedding_model=args.embedding_model, + embedding_mode=args.embedding_mode, + graph_degree=32, + complexity=64, + is_compact=True, + is_recompute=True, + num_threads=1, # Force single-threaded mode + ) # Add texts in batches for better progress tracking batch_size = 1000 diff --git a/examples/browser_rag.py b/examples/browser_rag.py index fde0367..5697d49 100644 --- a/examples/browser_rag.py +++ b/examples/browser_rag.py @@ -21,7 +21,7 @@ class BrowserRAG(BaseRAGExample): super().__init__( name="Browser History", description="Process and query Chrome browser history with LEANN", - default_index_name="google_history_index", # Match original: "./google_history_index", + default_index_name="google_history_index", ) def _add_specific_arguments(self, parser): diff --git a/examples/document_rag.py b/examples/document_rag.py index 7cff9b9..3497698 100644 --- a/examples/document_rag.py +++ b/examples/document_rag.py @@ -20,7 +20,7 @@ class DocumentRAG(BaseRAGExample): super().__init__( name="Document", description="Process and query documents (PDF, TXT, MD, etc.) with LEANN", - default_index_name="test_doc_files", # Match original main_cli_example.py default + default_index_name="test_doc_files", ) def _add_specific_arguments(self, parser): diff --git a/examples/email_rag.py b/examples/email_rag.py index 450d73c..10ec202 100644 --- a/examples/email_rag.py +++ b/examples/email_rag.py @@ -20,8 +20,7 @@ class EmailRAG(BaseRAGExample): super().__init__( name="Email", description="Process and query Apple Mail emails with LEANN", - default_index_name="mail_index", # Match original: "./mail_index" - include_embedding_mode=False, # Original mail_reader_leann.py doesn't have embedding_mode + default_index_name="mail_index", ) def _add_specific_arguments(self, parser): diff --git a/examples/wechat_rag.py b/examples/wechat_rag.py index b554929..590c61a 100644 --- a/examples/wechat_rag.py +++ b/examples/wechat_rag.py @@ -25,8 +25,7 @@ class WeChatRAG(BaseRAGExample): super().__init__( name="WeChat History", description="Process and query WeChat chat history with LEANN", - default_index_name="wechat_history_magic_test_11Debug_new", # Match original default - include_embedding_mode=False, # Original wechat_history_reader_leann.py doesn't have embedding_mode + default_index_name="wechat_history_magic_test_11Debug_new", ) def _add_specific_arguments(self, parser): diff --git a/test/sanity_checks/debug_zmq_issue.py b/test/sanity_checks/debug_zmq_issue.py index d1bd156..9ce8917 100644 --- a/test/sanity_checks/debug_zmq_issue.py +++ b/test/sanity_checks/debug_zmq_issue.py @@ -1,6 +1,6 @@ #!/usr/bin/env python3 """ -Debug script to test ZMQ communication with the exact same setup as main_cli_example.py +Debug script to test ZMQ communication with embedding models. """ import sys @@ -13,9 +13,9 @@ from leann_backend_diskann import embedding_pb2 def test_zmq_with_same_model(): - print("=== Testing ZMQ with same model as main_cli_example.py ===") + print("=== Testing ZMQ with embedding model ===") - # Test the exact same model that main_cli_example.py uses + # Test with a common embedding model model_name = "sentence-transformers/all-mpnet-base-v2" # Start server with the same model diff --git a/tests/test_document_rag.py b/tests/test_document_rag.py index 26007a8..f9c793d 100644 --- a/tests/test_document_rag.py +++ b/tests/test_document_rag.py @@ -1,5 +1,5 @@ """ -Test document_rag (formerly main_cli_example) functionality using pytest. +Test document_rag functionality using pytest. """ import os