fix: Restore embedding-mode parameter to all examples

- All examples now accept the --embedding-mode parameter (a benefit of the unified interface)
- Default is 'sentence-transformers' (consistent with original behavior)
- Users can now use OpenAI or MLX embeddings with any data source
- Default behavior remains functionally equivalent to the original scripts
This commit is contained in:
Andy Lee
2025-07-29 13:33:40 -07:00
parent ff1b622bdd
commit ddc789b231
7 changed files with 25 additions and 36 deletions

View File

@@ -23,12 +23,10 @@ class BaseRAGExample(ABC):
name: str, name: str,
description: str, description: str,
default_index_name: str, default_index_name: str,
include_embedding_mode: bool = True,
): ):
self.name = name self.name = name
self.description = description self.description = description
self.default_index_name = default_index_name self.default_index_name = default_index_name
self.include_embedding_mode = include_embedding_mode
self.parser = self._create_parser() self.parser = self._create_parser()
def _create_parser(self) -> argparse.ArgumentParser: def _create_parser(self) -> argparse.ArgumentParser:
@@ -73,14 +71,13 @@ class BaseRAGExample(ABC):
default=embedding_model_default, default=embedding_model_default,
help=f"Embedding model to use (default: {embedding_model_default})", help=f"Embedding model to use (default: {embedding_model_default})",
) )
if self.include_embedding_mode: embedding_group.add_argument(
embedding_group.add_argument( "--embedding-mode",
"--embedding-mode", type=str,
type=str, default="sentence-transformers",
default="sentence-transformers", choices=["sentence-transformers", "openai", "mlx"],
choices=["sentence-transformers", "openai", "mlx"], help="Embedding backend mode (default: sentence-transformers)",
help="Embedding backend mode (default: sentence-transformers)", )
)
# LLM parameters # LLM parameters
llm_group = parser.add_argument_group("LLM Parameters") llm_group = parser.add_argument_group("LLM Parameters")
@@ -152,22 +149,16 @@ class BaseRAGExample(ABC):
print(f"\n[Building Index] Creating {self.name} index...") print(f"\n[Building Index] Creating {self.name} index...")
print(f"Total text chunks: {len(texts)}") print(f"Total text chunks: {len(texts)}")
# Build kwargs for LeannBuilder builder = LeannBuilder(
builder_kwargs = { backend_name="hnsw",
"backend_name": "hnsw", embedding_model=args.embedding_model,
"embedding_model": args.embedding_model, embedding_mode=args.embedding_mode,
"graph_degree": 32, graph_degree=32,
"complexity": 64, complexity=64,
"is_compact": True, is_compact=True,
"is_recompute": True, is_recompute=True,
"num_threads": 1, # Force single-threaded mode num_threads=1, # Force single-threaded mode
} )
# Only add embedding_mode if it's not suppressed (for compatibility)
if hasattr(args, "embedding_mode") and args.embedding_mode is not None:
builder_kwargs["embedding_mode"] = args.embedding_mode
builder = LeannBuilder(**builder_kwargs)
# Add texts in batches for better progress tracking # Add texts in batches for better progress tracking
batch_size = 1000 batch_size = 1000

View File

@@ -21,7 +21,7 @@ class BrowserRAG(BaseRAGExample):
super().__init__( super().__init__(
name="Browser History", name="Browser History",
description="Process and query Chrome browser history with LEANN", description="Process and query Chrome browser history with LEANN",
default_index_name="google_history_index", # Match original: "./google_history_index", default_index_name="google_history_index",
) )
def _add_specific_arguments(self, parser): def _add_specific_arguments(self, parser):

View File

@@ -20,7 +20,7 @@ class DocumentRAG(BaseRAGExample):
super().__init__( super().__init__(
name="Document", name="Document",
description="Process and query documents (PDF, TXT, MD, etc.) with LEANN", description="Process and query documents (PDF, TXT, MD, etc.) with LEANN",
default_index_name="test_doc_files", # Match original main_cli_example.py default default_index_name="test_doc_files",
) )
def _add_specific_arguments(self, parser): def _add_specific_arguments(self, parser):

View File

@@ -20,8 +20,7 @@ class EmailRAG(BaseRAGExample):
super().__init__( super().__init__(
name="Email", name="Email",
description="Process and query Apple Mail emails with LEANN", description="Process and query Apple Mail emails with LEANN",
default_index_name="mail_index", # Match original: "./mail_index" default_index_name="mail_index",
include_embedding_mode=False, # Original mail_reader_leann.py doesn't have embedding_mode
) )
def _add_specific_arguments(self, parser): def _add_specific_arguments(self, parser):

View File

@@ -25,8 +25,7 @@ class WeChatRAG(BaseRAGExample):
super().__init__( super().__init__(
name="WeChat History", name="WeChat History",
description="Process and query WeChat chat history with LEANN", description="Process and query WeChat chat history with LEANN",
default_index_name="wechat_history_magic_test_11Debug_new", # Match original default default_index_name="wechat_history_magic_test_11Debug_new",
include_embedding_mode=False, # Original wechat_history_reader_leann.py doesn't have embedding_mode
) )
def _add_specific_arguments(self, parser): def _add_specific_arguments(self, parser):

View File

@@ -1,6 +1,6 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
""" """
Debug script to test ZMQ communication with the exact same setup as main_cli_example.py Debug script to test ZMQ communication with embedding models.
""" """
import sys import sys
@@ -13,9 +13,9 @@ from leann_backend_diskann import embedding_pb2
def test_zmq_with_same_model(): def test_zmq_with_same_model():
print("=== Testing ZMQ with same model as main_cli_example.py ===") print("=== Testing ZMQ with embedding model ===")
# Test the exact same model that main_cli_example.py uses # Test with a common embedding model
model_name = "sentence-transformers/all-mpnet-base-v2" model_name = "sentence-transformers/all-mpnet-base-v2"
# Start server with the same model # Start server with the same model

View File

@@ -1,5 +1,5 @@
""" """
Test document_rag (formerly main_cli_example) functionality using pytest. Test document_rag functionality using pytest.
""" """
import os import os