diff --git a/apps/chunking/__init__.py b/apps/chunking/__init__.py index 2f323e3..17a7e4a 100644 --- a/apps/chunking/__init__.py +++ b/apps/chunking/__init__.py @@ -12,6 +12,7 @@ from pathlib import Path try: from leann.chunking_utils import ( CODE_EXTENSIONS, + _traditional_chunks_as_dicts, create_ast_chunks, create_text_chunks, create_traditional_chunks, @@ -25,6 +26,7 @@ except Exception: # pragma: no cover - best-effort fallback for dev environment sys.path.insert(0, str(leann_src)) from leann.chunking_utils import ( CODE_EXTENSIONS, + _traditional_chunks_as_dicts, create_ast_chunks, create_text_chunks, create_traditional_chunks, @@ -36,6 +38,7 @@ except Exception: # pragma: no cover - best-effort fallback for dev environment __all__ = [ "CODE_EXTENSIONS", + "_traditional_chunks_as_dicts", "create_ast_chunks", "create_text_chunks", "create_traditional_chunks", diff --git a/packages/leann-core/src/leann/chunking_utils.py b/packages/leann-core/src/leann/chunking_utils.py index 0140930..34e0779 100644 --- a/packages/leann-core/src/leann/chunking_utils.py +++ b/packages/leann-core/src/leann/chunking_utils.py @@ -5,12 +5,15 @@ Packaged within leann-core so installed wheels can import it reliably. import logging from pathlib import Path -from typing import Optional +from typing import Any, Optional from llama_index.core.node_parser import SentenceSplitter logger = logging.getLogger(__name__) +# Flag to ensure AST token warning only shown once per session +_ast_token_warning_shown = False + def estimate_token_count(text: str) -> int: """ @@ -174,37 +177,44 @@ def create_ast_chunks( max_chunk_size: int = 512, chunk_overlap: int = 64, metadata_template: str = "default", -) -> list[str]: +) -> list[dict[str, Any]]: """Create AST-aware chunks from code documents using astchunk. Falls back to traditional chunking if astchunk is unavailable. + + Returns: + List of dicts with {"text": str, "metadata": dict} """ try: from astchunk import ASTChunkBuilder # optional dependency except ImportError as e: logger.error(f"astchunk not available: {e}") logger.info("Falling back to traditional chunking for code files") - return create_traditional_chunks(documents, max_chunk_size, chunk_overlap) + return _traditional_chunks_as_dicts(documents, max_chunk_size, chunk_overlap) all_chunks = [] for doc in documents: language = doc.metadata.get("language") if not language: logger.warning("No language detected; falling back to traditional chunking") - all_chunks.extend(create_traditional_chunks([doc], max_chunk_size, chunk_overlap)) + all_chunks.extend(_traditional_chunks_as_dicts([doc], max_chunk_size, chunk_overlap)) continue try: - # Warn if AST chunk size + overlap might exceed common token limits + # Warn once if AST chunk size + overlap might exceed common token limits + # Note: Actual truncation happens at embedding time with dynamic model limits + global _ast_token_warning_shown estimated_max_tokens = int( (max_chunk_size + chunk_overlap) * 1.2 ) # Conservative estimate - if estimated_max_tokens > 512: + if estimated_max_tokens > 512 and not _ast_token_warning_shown: logger.warning( f"AST chunk size ({max_chunk_size}) + overlap ({chunk_overlap}) = {max_chunk_size + chunk_overlap} chars " f"may exceed 512 token limit (~{estimated_max_tokens} tokens estimated). " - f"Consider reducing --ast-chunk-size to {int(400 / 1.2)} or --ast-chunk-overlap to {int(50 / 1.2)}" + f"Consider reducing --ast-chunk-size to {int(400 / 1.2)} or --ast-chunk-overlap to {int(50 / 1.2)}. " + f"Note: Chunks will be auto-truncated at embedding time based on your model's actual token limit." ) + _ast_token_warning_shown = True configs = { "max_chunk_size": max_chunk_size, @@ -229,17 +239,40 @@ def create_ast_chunks( chunks = chunk_builder.chunkify(code_content) for chunk in chunks: + chunk_text = None + astchunk_metadata = {} + if hasattr(chunk, "text"): chunk_text = chunk.text - elif isinstance(chunk, dict) and "text" in chunk: - chunk_text = chunk["text"] elif isinstance(chunk, str): chunk_text = chunk + elif isinstance(chunk, dict): + # Handle astchunk format: {"content": "...", "metadata": {...}} + if "content" in chunk: + chunk_text = chunk["content"] + astchunk_metadata = chunk.get("metadata", {}) + elif "text" in chunk: + chunk_text = chunk["text"] + else: + chunk_text = str(chunk) # Last resort else: chunk_text = str(chunk) if chunk_text and chunk_text.strip(): - all_chunks.append(chunk_text.strip()) + # Extract document-level metadata + doc_metadata = { + "file_path": doc.metadata.get("file_path", ""), + "file_name": doc.metadata.get("file_name", ""), + } + if "creation_date" in doc.metadata: + doc_metadata["creation_date"] = doc.metadata["creation_date"] + if "last_modified_date" in doc.metadata: + doc_metadata["last_modified_date"] = doc.metadata["last_modified_date"] + + # Merge document metadata + astchunk metadata + combined_metadata = {**doc_metadata, **astchunk_metadata} + + all_chunks.append({"text": chunk_text.strip(), "metadata": combined_metadata}) logger.info( f"Created {len(chunks)} AST chunks from {language} file: {doc.metadata.get('file_name', 'unknown')}" @@ -247,15 +280,19 @@ def create_ast_chunks( except Exception as e: logger.warning(f"AST chunking failed for {language} file: {e}") logger.info("Falling back to traditional chunking") - all_chunks.extend(create_traditional_chunks([doc], max_chunk_size, chunk_overlap)) + all_chunks.extend(_traditional_chunks_as_dicts([doc], max_chunk_size, chunk_overlap)) return all_chunks def create_traditional_chunks( documents, chunk_size: int = 256, chunk_overlap: int = 128 -) -> list[str]: - """Create traditional text chunks using LlamaIndex SentenceSplitter.""" +) -> list[dict[str, Any]]: + """Create traditional text chunks using LlamaIndex SentenceSplitter. + + Returns: + List of dicts with {"text": str, "metadata": dict} + """ if chunk_size <= 0: logger.warning(f"Invalid chunk_size={chunk_size}, using default value of 256") chunk_size = 256 @@ -271,19 +308,40 @@ def create_traditional_chunks( paragraph_separator="\n\n", ) - all_texts = [] + result = [] for doc in documents: + # Extract document-level metadata + doc_metadata = { + "file_path": doc.metadata.get("file_path", ""), + "file_name": doc.metadata.get("file_name", ""), + } + if "creation_date" in doc.metadata: + doc_metadata["creation_date"] = doc.metadata["creation_date"] + if "last_modified_date" in doc.metadata: + doc_metadata["last_modified_date"] = doc.metadata["last_modified_date"] + try: nodes = node_parser.get_nodes_from_documents([doc]) if nodes: - all_texts.extend(node.get_content() for node in nodes) + for node in nodes: + result.append({"text": node.get_content(), "metadata": doc_metadata}) except Exception as e: logger.error(f"Traditional chunking failed for document: {e}") content = doc.get_content() if content and content.strip(): - all_texts.append(content.strip()) + result.append({"text": content.strip(), "metadata": doc_metadata}) - return all_texts + return result + + +def _traditional_chunks_as_dicts( + documents, chunk_size: int = 256, chunk_overlap: int = 128 +) -> list[dict[str, Any]]: + """Helper: Traditional chunking that returns dict format for consistency. + + This is now just an alias for create_traditional_chunks for backwards compatibility. + """ + return create_traditional_chunks(documents, chunk_size, chunk_overlap) def create_text_chunks( @@ -295,8 +353,12 @@ def create_text_chunks( ast_chunk_overlap: int = 64, code_file_extensions: Optional[list[str]] = None, ast_fallback_traditional: bool = True, -) -> list[str]: - """Create text chunks from documents with optional AST support for code files.""" +) -> list[dict[str, Any]]: + """Create text chunks from documents with optional AST support for code files. + + Returns: + List of dicts with {"text": str, "metadata": dict} + """ if not documents: logger.warning("No documents provided for chunking") return [] @@ -331,24 +393,17 @@ def create_text_chunks( logger.error(f"AST chunking failed: {e}") if ast_fallback_traditional: all_chunks.extend( - create_traditional_chunks(code_docs, chunk_size, chunk_overlap) + _traditional_chunks_as_dicts(code_docs, chunk_size, chunk_overlap) ) else: raise if text_docs: - all_chunks.extend(create_traditional_chunks(text_docs, chunk_size, chunk_overlap)) + all_chunks.extend(_traditional_chunks_as_dicts(text_docs, chunk_size, chunk_overlap)) else: - all_chunks = create_traditional_chunks(documents, chunk_size, chunk_overlap) + all_chunks = _traditional_chunks_as_dicts(documents, chunk_size, chunk_overlap) logger.info(f"Total chunks created: {len(all_chunks)}") - # Validate chunk token limits (default to 512 for safety) - # This provides a safety net for embedding models with token limits - validated_chunks, num_truncated = validate_chunk_token_limits(all_chunks, max_tokens=512) - - if num_truncated > 0: - logger.info( - f"Post-chunking validation: {num_truncated} chunks were truncated to fit 512 token limit" - ) - - return validated_chunks + # Note: Token truncation is now handled at embedding time with dynamic model limits + # See get_model_token_limit() and truncate_to_token_limit() in embedding_compute.py + return all_chunks diff --git a/packages/leann-core/src/leann/cli.py b/packages/leann-core/src/leann/cli.py index db011eb..7dbd5af 100644 --- a/packages/leann-core/src/leann/cli.py +++ b/packages/leann-core/src/leann/cli.py @@ -1279,13 +1279,8 @@ Examples: ast_fallback_traditional=getattr(args, "ast_fallback_traditional", True), ) - # Note: AST chunking currently returns plain text chunks without metadata - # We preserve basic file info by associating chunks with their source documents - # For better metadata preservation, documents list order should be maintained - for chunk_text in chunk_texts: - # TODO: Enhance create_text_chunks to return metadata alongside text - # For now, we store chunks with empty metadata - all_texts.append({"text": chunk_text, "metadata": {}}) + # create_text_chunks now returns list[dict] with metadata preserved + all_texts.extend(chunk_texts) except ImportError as e: print( diff --git a/packages/leann-core/src/leann/embedding_compute.py b/packages/leann-core/src/leann/embedding_compute.py index a79a39e..2af4469 100644 --- a/packages/leann-core/src/leann/embedding_compute.py +++ b/packages/leann-core/src/leann/embedding_compute.py @@ -10,72 +10,63 @@ import time from typing import Any, Optional import numpy as np +import tiktoken import torch from .settings import resolve_ollama_host, resolve_openai_api_key, resolve_openai_base_url +# Set up logger with proper level +logger = logging.getLogger(__name__) +LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper() +log_level = getattr(logging, LOG_LEVEL, logging.WARNING) +logger.setLevel(log_level) -def truncate_to_token_limit(texts: list[str], max_tokens: int = 512) -> list[str]: - """ - Truncate texts to token limit using tiktoken or conservative character truncation. - - Args: - texts: List of texts to truncate - max_tokens: Maximum tokens allowed per text - - Returns: - List of truncated texts that should fit within token limit - """ - try: - import tiktoken - - encoder = tiktoken.get_encoding("cl100k_base") - truncated = [] - - for text in texts: - tokens = encoder.encode(text) - if len(tokens) > max_tokens: - # Truncate to max_tokens and decode back to text - truncated_tokens = tokens[:max_tokens] - truncated_text = encoder.decode(truncated_tokens) - truncated.append(truncated_text) - logger.warning( - f"Truncated text from {len(tokens)} to {max_tokens} tokens " - f"(from {len(text)} to {len(truncated_text)} characters)" - ) - else: - truncated.append(text) - return truncated - - except ImportError: - # Fallback: Conservative character truncation - # Assume worst case: 1.5 tokens per character for code content - char_limit = int(max_tokens / 1.5) - truncated = [] - - for text in texts: - if len(text) > char_limit: - truncated_text = text[:char_limit] - truncated.append(truncated_text) - logger.warning( - f"Truncated text from {len(text)} to {char_limit} characters " - f"(conservative estimate for {max_tokens} tokens)" - ) - else: - truncated.append(text) - return truncated +# Token limit registry for embedding models +# Used as fallback when dynamic discovery fails (e.g., LM Studio, OpenAI) +# Ollama models use dynamic discovery via /api/show +EMBEDDING_MODEL_LIMITS = { + # Nomic models (common across servers) + "nomic-embed-text": 2048, # Corrected from 512 - verified via /api/show + "nomic-embed-text-v1.5": 2048, + "nomic-embed-text-v2": 512, + # Other embedding models + "mxbai-embed-large": 512, + "all-minilm": 512, + "bge-m3": 8192, + "snowflake-arctic-embed": 512, + # OpenAI models + "text-embedding-3-small": 8192, + "text-embedding-3-large": 8192, + "text-embedding-ada-002": 8192, +} -def get_model_token_limit(model_name: str) -> int: +def get_model_token_limit( + model_name: str, + base_url: Optional[str] = None, + default: int = 2048, +) -> int: """ Get token limit for a given embedding model. + Uses hybrid approach: dynamic discovery for Ollama, registry fallback for others. Args: model_name: Name of the embedding model + base_url: Base URL of the embedding server (for dynamic discovery) + default: Default token limit if model not found Returns: - Token limit for the model, defaults to 512 if unknown + Token limit for the model in tokens """ + # Try Ollama dynamic discovery if base_url provided + if base_url: + # Detect Ollama servers by port or "ollama" in URL + if "11434" in base_url or "ollama" in base_url.lower(): + limit = _query_ollama_context_limit(model_name, base_url) + if limit: + return limit + + # Fallback to known model registry with version handling (from PR #154) # Handle versioned model names (e.g., "nomic-embed-text:latest" -> "nomic-embed-text") base_model_name = model_name.split(":")[0] @@ -92,31 +83,111 @@ def get_model_token_limit(model_name: str) -> int: if known_model in base_model_name or base_model_name in known_model: return limit - # Default to conservative 512 token limit - logger.warning(f"Unknown model '{model_name}', using default 512 token limit") - return 512 + # Default fallback + logger.warning(f"Unknown model '{model_name}', using default {default} token limit") + return default -# Set up logger with proper level -logger = logging.getLogger(__name__) -LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper() -log_level = getattr(logging, LOG_LEVEL, logging.WARNING) -logger.setLevel(log_level) +def truncate_to_token_limit(texts: list[str], token_limit: int) -> list[str]: + """ + Truncate texts to fit within token limit using tiktoken. + + Args: + texts: List of text strings to truncate + token_limit: Maximum number of tokens allowed + + Returns: + List of truncated texts (same length as input) + """ + if not texts: + return [] + + # Use tiktoken with cl100k_base encoding + enc = tiktoken.get_encoding("cl100k_base") + + truncated_texts = [] + truncation_count = 0 + total_tokens_removed = 0 + max_original_length = 0 + + for i, text in enumerate(texts): + tokens = enc.encode(text) + original_length = len(tokens) + + if original_length <= token_limit: + # Text is within limit, keep as is + truncated_texts.append(text) + else: + # Truncate to token_limit + truncated_tokens = tokens[:token_limit] + truncated_text = enc.decode(truncated_tokens) + truncated_texts.append(truncated_text) + + # Track truncation statistics + truncation_count += 1 + tokens_removed = original_length - token_limit + total_tokens_removed += tokens_removed + max_original_length = max(max_original_length, original_length) + + # Log individual truncation at WARNING level (first few only) + if truncation_count <= 3: + logger.warning( + f"Text {i + 1} truncated: {original_length} → {token_limit} tokens " + f"({tokens_removed} tokens removed)" + ) + elif truncation_count == 4: + logger.warning("Further truncation warnings suppressed...") + + # Log summary at INFO level + if truncation_count > 0: + logger.warning( + f"Truncation summary: {truncation_count}/{len(texts)} texts truncated " + f"(removed {total_tokens_removed} tokens total, longest was {max_original_length} tokens)" + ) + else: + logger.debug( + f"No truncation needed - all {len(texts)} texts within {token_limit} token limit" + ) + + return truncated_texts + + +def _query_ollama_context_limit(model_name: str, base_url: str) -> Optional[int]: + """ + Query Ollama /api/show for model context limit. + + Args: + model_name: Name of the Ollama model + base_url: Base URL of the Ollama server + + Returns: + Context limit in tokens if found, None otherwise + """ + try: + import requests + + response = requests.post( + f"{base_url}/api/show", + json={"name": model_name}, + timeout=5, + ) + if response.status_code == 200: + data = response.json() + if "model_info" in data: + # Look for *.context_length in model_info + for key, value in data["model_info"].items(): + if "context_length" in key and isinstance(value, int): + logger.info(f"Detected {model_name} context limit: {value} tokens") + return value + except Exception as e: + logger.debug(f"Failed to query Ollama context limit: {e}") + + return None + # Global model cache to avoid repeated loading _model_cache: dict[str, Any] = {} -# Known embedding model token limits -EMBEDDING_MODEL_LIMITS = { - "nomic-embed-text": 512, - "nomic-embed-text-v2": 512, - "mxbai-embed-large": 512, - "all-minilm": 512, - "bge-m3": 8192, - "snowflake-arctic-embed": 512, - # Add more models as needed -} - def compute_embeddings( texts: list[str], @@ -814,15 +885,13 @@ def compute_embeddings_ollama( logger.info(f"Using batch size: {batch_size} for true batch processing") - # Get model token limit and apply truncation - token_limit = get_model_token_limit(model_name) + # Get model token limit and apply truncation before batching + token_limit = get_model_token_limit(model_name, base_url=resolved_host) logger.info(f"Model '{model_name}' token limit: {token_limit}") - # Apply token-aware truncation to all texts - truncated_texts = truncate_to_token_limit(texts, token_limit) - if len(truncated_texts) != len(texts): - logger.error("Truncation failed - text count mismatch") - truncated_texts = texts # Fallback to original texts + # Apply truncation to all texts before batch processing + # Function logs truncation details internally + texts = truncate_to_token_limit(texts, token_limit) def get_batch_embeddings(batch_texts): """Get embeddings for a batch of texts using /api/embed endpoint.""" @@ -880,12 +949,12 @@ def compute_embeddings_ollama( return None, list(range(len(batch_texts))) - # Process truncated texts in batches + # Process texts in batches all_embeddings = [] all_failed_indices = [] # Setup progress bar if needed - show_progress = is_build or len(truncated_texts) > 10 + show_progress = is_build or len(texts) > 10 try: if show_progress: from tqdm import tqdm @@ -893,7 +962,7 @@ def compute_embeddings_ollama( show_progress = False # Process batches - num_batches = (len(truncated_texts) + batch_size - 1) // batch_size + num_batches = (len(texts) + batch_size - 1) // batch_size if show_progress: batch_iterator = tqdm(range(num_batches), desc="Computing Ollama embeddings (batched)") @@ -902,8 +971,8 @@ def compute_embeddings_ollama( for batch_idx in batch_iterator: start_idx = batch_idx * batch_size - end_idx = min(start_idx + batch_size, len(truncated_texts)) - batch_texts = truncated_texts[start_idx:end_idx] + end_idx = min(start_idx + batch_size, len(texts)) + batch_texts = texts[start_idx:end_idx] batch_embeddings, batch_failed = get_batch_embeddings(batch_texts) @@ -918,11 +987,11 @@ def compute_embeddings_ollama( # Handle failed embeddings if all_failed_indices: - if len(all_failed_indices) == len(truncated_texts): + if len(all_failed_indices) == len(texts): raise RuntimeError("Failed to compute any embeddings") logger.warning( - f"Failed to compute embeddings for {len(all_failed_indices)}/{len(truncated_texts)} texts" + f"Failed to compute embeddings for {len(all_failed_indices)}/{len(texts)} texts" ) # Use zero embeddings as fallback for failed ones diff --git a/tests/test_astchunk_integration.py b/tests/test_astchunk_integration.py index df34521..ab68e65 100644 --- a/tests/test_astchunk_integration.py +++ b/tests/test_astchunk_integration.py @@ -8,7 +8,7 @@ import subprocess import sys import tempfile from pathlib import Path -from unittest.mock import patch +from unittest.mock import Mock, patch import pytest @@ -116,8 +116,10 @@ class TestChunkingFunctions: chunks = create_traditional_chunks(docs, chunk_size=50, chunk_overlap=10) assert len(chunks) > 0 - assert all(isinstance(chunk, str) for chunk in chunks) - assert all(len(chunk.strip()) > 0 for chunk in chunks) + # Traditional chunks now return dict format for consistency + assert all(isinstance(chunk, dict) for chunk in chunks) + assert all("text" in chunk and "metadata" in chunk for chunk in chunks) + assert all(len(chunk["text"].strip()) > 0 for chunk in chunks) def test_create_traditional_chunks_empty_docs(self): """Test traditional chunking with empty documents.""" @@ -158,11 +160,22 @@ class Calculator: # Should have multiple chunks due to different functions/classes assert len(chunks) > 0 - assert all(isinstance(chunk, str) for chunk in chunks) - assert all(len(chunk.strip()) > 0 for chunk in chunks) + # R3: Expect dict format with "text" and "metadata" keys + assert all(isinstance(chunk, dict) for chunk in chunks), "All chunks should be dicts" + assert all("text" in chunk and "metadata" in chunk for chunk in chunks), ( + "Each chunk should have 'text' and 'metadata' keys" + ) + assert all(len(chunk["text"].strip()) > 0 for chunk in chunks), ( + "Each chunk text should be non-empty" + ) + + # Check metadata is present + assert all("file_path" in chunk["metadata"] for chunk in chunks), ( + "Each chunk should have file_path metadata" + ) # Check that code structure is somewhat preserved - combined_content = " ".join(chunks) + combined_content = " ".join([c["text"] for c in chunks]) assert "def hello_world" in combined_content assert "class Calculator" in combined_content @@ -194,7 +207,11 @@ class Calculator: chunks = create_text_chunks(docs, use_ast_chunking=False, chunk_size=50, chunk_overlap=10) assert len(chunks) > 0 - assert all(isinstance(chunk, str) for chunk in chunks) + # R3: Traditional chunking should also return dict format for consistency + assert all(isinstance(chunk, dict) for chunk in chunks), "All chunks should be dicts" + assert all("text" in chunk and "metadata" in chunk for chunk in chunks), ( + "Each chunk should have 'text' and 'metadata' keys" + ) def test_create_text_chunks_ast_mode(self): """Test text chunking in AST mode.""" @@ -213,7 +230,11 @@ class Calculator: ) assert len(chunks) > 0 - assert all(isinstance(chunk, str) for chunk in chunks) + # R3: AST mode should also return dict format + assert all(isinstance(chunk, dict) for chunk in chunks), "All chunks should be dicts" + assert all("text" in chunk and "metadata" in chunk for chunk in chunks), ( + "Each chunk should have 'text' and 'metadata' keys" + ) def test_create_text_chunks_custom_extensions(self): """Test text chunking with custom code file extensions.""" @@ -353,6 +374,552 @@ class MathUtils: pytest.skip("Test timed out - likely due to model download in CI") +class TestASTContentExtraction: + """Test AST content extraction bug fix. + + These tests verify that astchunk's dict format with 'content' key is handled correctly, + and that the extraction logic doesn't fall through to stringifying entire dicts. + """ + + def test_extract_content_from_astchunk_dict(self): + """Test that astchunk dict format with 'content' key is handled correctly. + + Bug: Current code checks for chunk["text"] but astchunk returns chunk["content"]. + This causes fallthrough to str(chunk), stringifying the entire dict. + + This test will FAIL until the bug is fixed because: + - Current code will stringify the dict: "{'content': '...', 'metadata': {...}}" + - Fixed code should extract just the content value + """ + # Mock the ASTChunkBuilder class + mock_builder = Mock() + + # Astchunk returns this format + astchunk_format_chunk = { + "content": "def hello():\n print('world')", + "metadata": { + "filepath": "test.py", + "line_count": 2, + "start_line_no": 0, + "end_line_no": 1, + "node_count": 1, + }, + } + mock_builder.chunkify.return_value = [astchunk_format_chunk] + + # Create mock document + doc = MockDocument( + "def hello():\n print('world')", "/test/test.py", {"language": "python"} + ) + + # Mock the astchunk module and its ASTChunkBuilder class + mock_astchunk = Mock() + mock_astchunk.ASTChunkBuilder = Mock(return_value=mock_builder) + + # Patch sys.modules to inject our mock before the import + with patch.dict("sys.modules", {"astchunk": mock_astchunk}): + # Call create_ast_chunks + chunks = create_ast_chunks([doc]) + + # R3: Should return dict format with proper metadata + assert len(chunks) > 0, "Should return at least one chunk" + + # R3: Each chunk should be a dict + chunk = chunks[0] + assert isinstance(chunk, dict), "Chunk should be a dict" + assert "text" in chunk, "Chunk should have 'text' key" + assert "metadata" in chunk, "Chunk should have 'metadata' key" + + chunk_text = chunk["text"] + + # CRITICAL: Should NOT contain stringified dict markers in the text field + # These assertions will FAIL with current buggy code + assert "'content':" not in chunk_text, ( + f"Chunk text contains stringified dict - extraction failed! Got: {chunk_text[:100]}..." + ) + assert "'metadata':" not in chunk_text, ( + "Chunk text contains stringified metadata - extraction failed! " + f"Got: {chunk_text[:100]}..." + ) + assert "{" not in chunk_text or "def hello" in chunk_text.split("{")[0], ( + "Chunk text appears to be a stringified dict" + ) + + # Should contain actual content + assert "def hello()" in chunk_text, "Should extract actual code content" + assert "print('world')" in chunk_text, "Should extract complete code content" + + # R3: Should preserve astchunk metadata + assert "filepath" in chunk["metadata"] or "file_path" in chunk["metadata"], ( + "Should preserve file path metadata" + ) + + def test_extract_text_key_fallback(self): + """Test that 'text' key still works for backward compatibility. + + Some chunks might use 'text' instead of 'content' - ensure backward compatibility. + This test should PASS even with current code. + """ + mock_builder = Mock() + + # Some chunks might use "text" key + text_key_chunk = {"text": "def legacy_function():\n return True"} + mock_builder.chunkify.return_value = [text_key_chunk] + + # Create mock document + doc = MockDocument( + "def legacy_function():\n return True", "/test/legacy.py", {"language": "python"} + ) + + # Mock the astchunk module + mock_astchunk = Mock() + mock_astchunk.ASTChunkBuilder = Mock(return_value=mock_builder) + + with patch.dict("sys.modules", {"astchunk": mock_astchunk}): + # Call create_ast_chunks + chunks = create_ast_chunks([doc]) + + # R3: Should extract text correctly as dict format + assert len(chunks) > 0 + chunk = chunks[0] + assert isinstance(chunk, dict), "Chunk should be a dict" + assert "text" in chunk, "Chunk should have 'text' key" + + chunk_text = chunk["text"] + + # Should NOT be stringified + assert "'text':" not in chunk_text, "Should not stringify dict with 'text' key" + + # Should contain actual content + assert "def legacy_function()" in chunk_text + assert "return True" in chunk_text + + def test_handles_string_chunks(self): + """Test that plain string chunks still work. + + Some chunkers might return plain strings - verify these are preserved. + This test should PASS with current code. + """ + mock_builder = Mock() + + # Plain string chunk + plain_string_chunk = "def simple_function():\n pass" + mock_builder.chunkify.return_value = [plain_string_chunk] + + # Create mock document + doc = MockDocument( + "def simple_function():\n pass", "/test/simple.py", {"language": "python"} + ) + + # Mock the astchunk module + mock_astchunk = Mock() + mock_astchunk.ASTChunkBuilder = Mock(return_value=mock_builder) + + with patch.dict("sys.modules", {"astchunk": mock_astchunk}): + # Call create_ast_chunks + chunks = create_ast_chunks([doc]) + + # R3: Should wrap string in dict format + assert len(chunks) > 0 + chunk = chunks[0] + assert isinstance(chunk, dict), "Even string chunks should be wrapped in dict" + assert "text" in chunk, "Chunk should have 'text' key" + + chunk_text = chunk["text"] + + assert chunk_text == plain_string_chunk.strip(), ( + "Should preserve plain string chunk content" + ) + assert "def simple_function()" in chunk_text + assert "pass" in chunk_text + + def test_multiple_chunks_with_mixed_formats(self): + """Test handling of multiple chunks with different formats. + + Real-world scenario: astchunk might return a mix of formats. + This test will FAIL if any chunk with 'content' key gets stringified. + """ + mock_builder = Mock() + + # Mix of formats + mixed_chunks = [ + {"content": "def first():\n return 1", "metadata": {"line_count": 2}}, + "def second():\n return 2", # Plain string + {"text": "def third():\n return 3"}, # Old format + {"content": "class MyClass:\n pass", "metadata": {"node_count": 1}}, + ] + mock_builder.chunkify.return_value = mixed_chunks + + # Create mock document + code = "def first():\n return 1\n\ndef second():\n return 2\n\ndef third():\n return 3\n\nclass MyClass:\n pass" + doc = MockDocument(code, "/test/mixed.py", {"language": "python"}) + + # Mock the astchunk module + mock_astchunk = Mock() + mock_astchunk.ASTChunkBuilder = Mock(return_value=mock_builder) + + with patch.dict("sys.modules", {"astchunk": mock_astchunk}): + # Call create_ast_chunks + chunks = create_ast_chunks([doc]) + + # R3: Should extract all chunks correctly as dicts + assert len(chunks) == 4, "Should extract all 4 chunks" + + # Check each chunk + for i, chunk in enumerate(chunks): + assert isinstance(chunk, dict), f"Chunk {i} should be a dict" + assert "text" in chunk, f"Chunk {i} should have 'text' key" + assert "metadata" in chunk, f"Chunk {i} should have 'metadata' key" + + chunk_text = chunk["text"] + # None should be stringified dicts + assert "'content':" not in chunk_text, f"Chunk {i} text is stringified (has 'content':)" + assert "'metadata':" not in chunk_text, ( + f"Chunk {i} text is stringified (has 'metadata':)" + ) + assert "'text':" not in chunk_text, f"Chunk {i} text is stringified (has 'text':)" + + # Verify actual content is present + combined = "\n".join([c["text"] for c in chunks]) + assert "def first()" in combined + assert "def second()" in combined + assert "def third()" in combined + assert "class MyClass:" in combined + + def test_empty_content_value_handling(self): + """Test handling of chunks with empty content values. + + Edge case: chunk has 'content' key but value is empty. + Should skip these chunks, not stringify them. + """ + mock_builder = Mock() + + chunks_with_empty = [ + {"content": "", "metadata": {"line_count": 0}}, # Empty content + {"content": " ", "metadata": {"line_count": 1}}, # Whitespace only + {"content": "def valid():\n return True", "metadata": {"line_count": 2}}, # Valid + ] + mock_builder.chunkify.return_value = chunks_with_empty + + doc = MockDocument( + "def valid():\n return True", "/test/empty.py", {"language": "python"} + ) + + # Mock the astchunk module + mock_astchunk = Mock() + mock_astchunk.ASTChunkBuilder = Mock(return_value=mock_builder) + + with patch.dict("sys.modules", {"astchunk": mock_astchunk}): + chunks = create_ast_chunks([doc]) + + # R3: Should only have the valid chunk (empty ones filtered out) + assert len(chunks) == 1, "Should filter out empty content chunks" + + chunk = chunks[0] + assert isinstance(chunk, dict), "Chunk should be a dict" + assert "text" in chunk, "Chunk should have 'text' key" + assert "def valid()" in chunk["text"] + + # Should not have stringified the empty dict + assert "'content': ''" not in chunk["text"] + + +class TestASTMetadataPreservation: + """Test metadata preservation in AST chunk dictionaries. + + R3: These tests define the contract for metadata preservation when returning + chunk dictionaries instead of plain strings. Each chunk dict should have: + - "text": str - the actual chunk content + - "metadata": dict - all metadata from document AND astchunk + + These tests will FAIL until G3 implementation changes return type to list[dict]. + """ + + def test_ast_chunks_preserve_file_metadata(self): + """Test that document metadata is preserved in chunk metadata. + + This test verifies that all document-level metadata (file_path, file_name, + creation_date, last_modified_date) is included in each chunk's metadata dict. + + This will FAIL because current code returns list[str], not list[dict]. + """ + # Create mock document with rich metadata + python_code = ''' +def calculate_sum(numbers): + """Calculate sum of numbers.""" + return sum(numbers) + +class DataProcessor: + """Process data records.""" + + def process(self, data): + return [x * 2 for x in data] +''' + doc = MockDocument( + python_code, + file_path="/project/src/utils.py", + metadata={ + "language": "python", + "file_path": "/project/src/utils.py", + "file_name": "utils.py", + "creation_date": "2024-01-15T10:30:00", + "last_modified_date": "2024-10-31T15:45:00", + }, + ) + + # Mock astchunk to return chunks with metadata + mock_builder = Mock() + astchunk_chunks = [ + { + "content": "def calculate_sum(numbers):\n return sum(numbers)", + "metadata": { + "filepath": "/project/src/utils.py", + "line_count": 2, + "start_line_no": 1, + "end_line_no": 2, + "node_count": 1, + }, + }, + { + "content": "class DataProcessor:\n def process(self, data):\n return [x * 2 for x in data]", + "metadata": { + "filepath": "/project/src/utils.py", + "line_count": 3, + "start_line_no": 5, + "end_line_no": 7, + "node_count": 2, + }, + }, + ] + mock_builder.chunkify.return_value = astchunk_chunks + + mock_astchunk = Mock() + mock_astchunk.ASTChunkBuilder = Mock(return_value=mock_builder) + + with patch.dict("sys.modules", {"astchunk": mock_astchunk}): + chunks = create_ast_chunks([doc]) + + # CRITICAL: These assertions will FAIL with current list[str] return type + assert len(chunks) == 2, "Should return 2 chunks" + + for i, chunk in enumerate(chunks): + # Structure assertions - WILL FAIL: current code returns strings + assert isinstance(chunk, dict), f"Chunk {i} should be dict, got {type(chunk)}" + assert "text" in chunk, f"Chunk {i} must have 'text' key" + assert "metadata" in chunk, f"Chunk {i} must have 'metadata' key" + assert isinstance(chunk["metadata"], dict), f"Chunk {i} metadata should be dict" + + # Document metadata preservation - WILL FAIL + metadata = chunk["metadata"] + assert "file_path" in metadata, f"Chunk {i} should preserve file_path" + assert metadata["file_path"] == "/project/src/utils.py", ( + f"Chunk {i} file_path incorrect" + ) + + assert "file_name" in metadata, f"Chunk {i} should preserve file_name" + assert metadata["file_name"] == "utils.py", f"Chunk {i} file_name incorrect" + + assert "creation_date" in metadata, f"Chunk {i} should preserve creation_date" + assert metadata["creation_date"] == "2024-01-15T10:30:00", ( + f"Chunk {i} creation_date incorrect" + ) + + assert "last_modified_date" in metadata, f"Chunk {i} should preserve last_modified_date" + assert metadata["last_modified_date"] == "2024-10-31T15:45:00", ( + f"Chunk {i} last_modified_date incorrect" + ) + + # Verify metadata is consistent across chunks from same document + assert chunks[0]["metadata"]["file_path"] == chunks[1]["metadata"]["file_path"], ( + "All chunks from same document should have same file_path" + ) + + # Verify text content is present and not stringified + assert "def calculate_sum" in chunks[0]["text"] + assert "class DataProcessor" in chunks[1]["text"] + + def test_ast_chunks_include_astchunk_metadata(self): + """Test that astchunk-specific metadata is merged into chunk metadata. + + This test verifies that astchunk's metadata (line_count, start_line_no, + end_line_no, node_count) is merged with document metadata. + + This will FAIL because current code returns list[str], not list[dict]. + """ + python_code = ''' +def function_one(): + """First function.""" + x = 1 + y = 2 + return x + y + +def function_two(): + """Second function.""" + return 42 +''' + doc = MockDocument( + python_code, + file_path="/test/code.py", + metadata={ + "language": "python", + "file_path": "/test/code.py", + "file_name": "code.py", + }, + ) + + # Mock astchunk with detailed metadata + mock_builder = Mock() + astchunk_chunks = [ + { + "content": "def function_one():\n x = 1\n y = 2\n return x + y", + "metadata": { + "filepath": "/test/code.py", + "line_count": 4, + "start_line_no": 1, + "end_line_no": 4, + "node_count": 5, # function, assignments, return + }, + }, + { + "content": "def function_two():\n return 42", + "metadata": { + "filepath": "/test/code.py", + "line_count": 2, + "start_line_no": 7, + "end_line_no": 8, + "node_count": 2, # function, return + }, + }, + ] + mock_builder.chunkify.return_value = astchunk_chunks + + mock_astchunk = Mock() + mock_astchunk.ASTChunkBuilder = Mock(return_value=mock_builder) + + with patch.dict("sys.modules", {"astchunk": mock_astchunk}): + chunks = create_ast_chunks([doc]) + + # CRITICAL: These will FAIL with current list[str] return + assert len(chunks) == 2 + + # First chunk - function_one + chunk1 = chunks[0] + assert isinstance(chunk1, dict), "Chunk should be dict" + assert "metadata" in chunk1 + + metadata1 = chunk1["metadata"] + + # Check astchunk metadata is present + assert "line_count" in metadata1, "Should include astchunk line_count" + assert metadata1["line_count"] == 4, "line_count should be 4" + + assert "start_line_no" in metadata1, "Should include astchunk start_line_no" + assert metadata1["start_line_no"] == 1, "start_line_no should be 1" + + assert "end_line_no" in metadata1, "Should include astchunk end_line_no" + assert metadata1["end_line_no"] == 4, "end_line_no should be 4" + + assert "node_count" in metadata1, "Should include astchunk node_count" + assert metadata1["node_count"] == 5, "node_count should be 5" + + # Second chunk - function_two + chunk2 = chunks[1] + metadata2 = chunk2["metadata"] + + assert metadata2["line_count"] == 2, "line_count should be 2" + assert metadata2["start_line_no"] == 7, "start_line_no should be 7" + assert metadata2["end_line_no"] == 8, "end_line_no should be 8" + assert metadata2["node_count"] == 2, "node_count should be 2" + + # Verify document metadata is ALSO present (merged, not replaced) + assert metadata1["file_path"] == "/test/code.py" + assert metadata1["file_name"] == "code.py" + assert metadata2["file_path"] == "/test/code.py" + assert metadata2["file_name"] == "code.py" + + # Verify text content is correct + assert "def function_one" in chunk1["text"] + assert "def function_two" in chunk2["text"] + + def test_traditional_chunks_as_dicts_helper(self): + """Test the helper function that wraps traditional chunks as dicts. + + This test verifies that when create_traditional_chunks is called, + its plain string chunks are wrapped into dict format with metadata. + + This will FAIL because the helper function _traditional_chunks_as_dicts() + doesn't exist yet, and create_traditional_chunks returns list[str]. + """ + # Create documents with various metadata + docs = [ + MockDocument( + "This is the first paragraph of text. It contains multiple sentences. " + "This should be split into chunks based on size.", + file_path="/docs/readme.txt", + metadata={ + "file_path": "/docs/readme.txt", + "file_name": "readme.txt", + "creation_date": "2024-01-01", + }, + ), + MockDocument( + "Second document with different metadata. It also has content that needs chunking.", + file_path="/docs/guide.md", + metadata={ + "file_path": "/docs/guide.md", + "file_name": "guide.md", + "last_modified_date": "2024-10-31", + }, + ), + ] + + # Call create_traditional_chunks (which should now return list[dict]) + chunks = create_traditional_chunks(docs, chunk_size=50, chunk_overlap=10) + + # CRITICAL: Will FAIL - current code returns list[str] + assert len(chunks) > 0, "Should return chunks" + + for i, chunk in enumerate(chunks): + # Structure assertions - WILL FAIL + assert isinstance(chunk, dict), f"Chunk {i} should be dict, got {type(chunk)}" + assert "text" in chunk, f"Chunk {i} must have 'text' key" + assert "metadata" in chunk, f"Chunk {i} must have 'metadata' key" + + # Text should be non-empty + assert len(chunk["text"].strip()) > 0, f"Chunk {i} text should be non-empty" + + # Metadata should include document info + metadata = chunk["metadata"] + assert "file_path" in metadata, f"Chunk {i} should have file_path in metadata" + assert "file_name" in metadata, f"Chunk {i} should have file_name in metadata" + + # Verify metadata tracking works correctly + # At least one chunk should be from readme.txt + readme_chunks = [c for c in chunks if "readme.txt" in c["metadata"]["file_name"]] + assert len(readme_chunks) > 0, "Should have chunks from readme.txt" + + # At least one chunk should be from guide.md + guide_chunks = [c for c in chunks if "guide.md" in c["metadata"]["file_name"]] + assert len(guide_chunks) > 0, "Should have chunks from guide.md" + + # Verify creation_date is preserved for readme chunks + for chunk in readme_chunks: + assert chunk["metadata"].get("creation_date") == "2024-01-01", ( + "readme.txt chunks should preserve creation_date" + ) + + # Verify last_modified_date is preserved for guide chunks + for chunk in guide_chunks: + assert chunk["metadata"].get("last_modified_date") == "2024-10-31", ( + "guide.md chunks should preserve last_modified_date" + ) + + # Verify text content is present + all_text = " ".join([c["text"] for c in chunks]) + assert "first paragraph" in all_text + assert "Second document" in all_text + + class TestErrorHandling: """Test error handling and edge cases.""" diff --git a/tests/test_token_truncation.py b/tests/test_token_truncation.py new file mode 100644 index 0000000..ad00e3a --- /dev/null +++ b/tests/test_token_truncation.py @@ -0,0 +1,268 @@ +"""Unit tests for token-aware truncation functionality. + +This test suite defines the contract for token truncation functions that prevent +500 errors from Ollama when text exceeds model token limits. These tests verify: + +1. Model token limit retrieval (known and unknown models) +2. Text truncation behavior for single and multiple texts +3. Token counting and truncation accuracy using tiktoken + +All tests are written in Red Phase - they should FAIL initially because the +implementation does not exist yet. +""" + +import pytest +import tiktoken +from leann.embedding_compute import ( + EMBEDDING_MODEL_LIMITS, + get_model_token_limit, + truncate_to_token_limit, +) + + +class TestModelTokenLimits: + """Tests for retrieving model-specific token limits.""" + + def test_get_model_token_limit_known_model(self): + """Verify correct token limit is returned for known models. + + Known models should return their specific token limits from + EMBEDDING_MODEL_LIMITS dictionary. + """ + # Test nomic-embed-text (2048 tokens) + limit = get_model_token_limit("nomic-embed-text") + assert limit == 2048, "nomic-embed-text should have 2048 token limit" + + # Test nomic-embed-text-v1.5 (2048 tokens) + limit = get_model_token_limit("nomic-embed-text-v1.5") + assert limit == 2048, "nomic-embed-text-v1.5 should have 2048 token limit" + + # Test nomic-embed-text-v2 (512 tokens) + limit = get_model_token_limit("nomic-embed-text-v2") + assert limit == 512, "nomic-embed-text-v2 should have 512 token limit" + + # Test OpenAI models (8192 tokens) + limit = get_model_token_limit("text-embedding-3-small") + assert limit == 8192, "text-embedding-3-small should have 8192 token limit" + + def test_get_model_token_limit_unknown_model(self): + """Verify default token limit is returned for unknown models. + + Unknown models should return the default limit (2048) to allow + operation with reasonable safety margin. + """ + # Test with completely unknown model + limit = get_model_token_limit("unknown-model-xyz") + assert limit == 2048, "Unknown models should return default 2048" + + # Test with empty string + limit = get_model_token_limit("") + assert limit == 2048, "Empty model name should return default 2048" + + def test_get_model_token_limit_custom_default(self): + """Verify custom default can be specified for unknown models. + + Allow callers to specify their own default token limit when + model is not in the known models dictionary. + """ + limit = get_model_token_limit("unknown-model", default=4096) + assert limit == 4096, "Should return custom default for unknown models" + + # Known model should ignore custom default + limit = get_model_token_limit("nomic-embed-text", default=4096) + assert limit == 2048, "Known model should ignore custom default" + + def test_embedding_model_limits_dictionary_exists(self): + """Verify EMBEDDING_MODEL_LIMITS dictionary contains expected models. + + The dictionary should be importable and contain at least the + known nomic models with correct token limits. + """ + assert isinstance(EMBEDDING_MODEL_LIMITS, dict), "Should be a dictionary" + assert "nomic-embed-text" in EMBEDDING_MODEL_LIMITS, "Should contain nomic-embed-text" + assert "nomic-embed-text-v1.5" in EMBEDDING_MODEL_LIMITS, ( + "Should contain nomic-embed-text-v1.5" + ) + assert EMBEDDING_MODEL_LIMITS["nomic-embed-text"] == 2048 + assert EMBEDDING_MODEL_LIMITS["nomic-embed-text-v1.5"] == 2048 + assert EMBEDDING_MODEL_LIMITS["nomic-embed-text-v2"] == 512 + # OpenAI models + assert EMBEDDING_MODEL_LIMITS["text-embedding-3-small"] == 8192 + + +class TestTokenTruncation: + """Tests for truncating texts to token limits.""" + + @pytest.fixture + def tokenizer(self): + """Provide tiktoken tokenizer for token counting verification.""" + return tiktoken.get_encoding("cl100k_base") + + def test_truncate_single_text_under_limit(self, tokenizer): + """Verify text under token limit remains unchanged. + + When text is already within the token limit, it should be + returned unchanged with no truncation. + """ + text = "This is a short text that is well under the token limit." + token_count = len(tokenizer.encode(text)) + assert token_count < 100, f"Test setup: text should be short (has {token_count} tokens)" + + # Truncate with generous limit + result = truncate_to_token_limit([text], token_limit=512) + + assert len(result) == 1, "Should return same number of texts" + assert result[0] == text, "Text under limit should be unchanged" + + def test_truncate_single_text_over_limit(self, tokenizer): + """Verify text over token limit is truncated correctly. + + When text exceeds the token limit, it should be truncated to + fit within the limit while maintaining valid token boundaries. + """ + # Create a text that definitely exceeds limit + text = "word " * 200 # ~200 tokens (each "word " is typically 1-2 tokens) + original_token_count = len(tokenizer.encode(text)) + assert original_token_count > 50, ( + f"Test setup: text should be long (has {original_token_count} tokens)" + ) + + # Truncate to 50 tokens + result = truncate_to_token_limit([text], token_limit=50) + + assert len(result) == 1, "Should return same number of texts" + assert result[0] != text, "Text over limit should be truncated" + assert len(result[0]) < len(text), "Truncated text should be shorter" + + # Verify truncated text is within token limit + truncated_token_count = len(tokenizer.encode(result[0])) + assert truncated_token_count <= 50, ( + f"Truncated text should be ≤50 tokens, got {truncated_token_count}" + ) + + def test_truncate_multiple_texts_mixed_lengths(self, tokenizer): + """Verify multiple texts with mixed lengths are handled correctly. + + When processing multiple texts: + - Texts under limit should remain unchanged + - Texts over limit should be truncated independently + - Output list should maintain same order and length + """ + texts = [ + "Short text.", # Under limit + "word " * 200, # Over limit + "Another short one.", # Under limit + "token " * 150, # Over limit + ] + + # Verify test setup + for i, text in enumerate(texts): + token_count = len(tokenizer.encode(text)) + if i in [1, 3]: + assert token_count > 50, f"Text {i} should be over limit (has {token_count} tokens)" + else: + assert token_count < 50, ( + f"Text {i} should be under limit (has {token_count} tokens)" + ) + + # Truncate with 50 token limit + result = truncate_to_token_limit(texts, token_limit=50) + + assert len(result) == len(texts), "Should return same number of texts" + + # Verify each text individually + for i, (original, truncated) in enumerate(zip(texts, result)): + token_count = len(tokenizer.encode(truncated)) + assert token_count <= 50, f"Text {i} should be ≤50 tokens, got {token_count}" + + # Short texts should be unchanged + if i in [0, 2]: + assert truncated == original, f"Short text {i} should be unchanged" + # Long texts should be truncated + else: + assert len(truncated) < len(original), f"Long text {i} should be truncated" + + def test_truncate_empty_list(self): + """Verify empty input list returns empty output list. + + Edge case: empty list should return empty list without errors. + """ + result = truncate_to_token_limit([], token_limit=512) + assert result == [], "Empty input should return empty output" + + def test_truncate_preserves_order(self, tokenizer): + """Verify truncation preserves original text order. + + Output list should maintain the same order as input list, + regardless of which texts were truncated. + """ + texts = [ + "First text " * 50, # Will be truncated + "Second text.", # Won't be truncated + "Third text " * 50, # Will be truncated + ] + + result = truncate_to_token_limit(texts, token_limit=20) + + assert len(result) == 3, "Should preserve list length" + # Check that order is maintained by looking for distinctive words + assert "First" in result[0], "First text should remain in first position" + assert "Second" in result[1], "Second text should remain in second position" + assert "Third" in result[2], "Third text should remain in third position" + + def test_truncate_extremely_long_text(self, tokenizer): + """Verify extremely long texts are truncated efficiently. + + Test with text that far exceeds token limit to ensure + truncation handles extreme cases without performance issues. + """ + # Create very long text (simulate real-world scenario) + text = "token " * 5000 # ~5000+ tokens + original_token_count = len(tokenizer.encode(text)) + assert original_token_count > 1000, "Test setup: text should be very long" + + # Truncate to small limit + result = truncate_to_token_limit([text], token_limit=100) + + assert len(result) == 1 + truncated_token_count = len(tokenizer.encode(result[0])) + assert truncated_token_count <= 100, ( + f"Should truncate to ≤100 tokens, got {truncated_token_count}" + ) + assert len(result[0]) < len(text) // 10, "Should significantly reduce text length" + + def test_truncate_exact_token_limit(self, tokenizer): + """Verify text at exactly token limit is handled correctly. + + Edge case: text with exactly the token limit should either + remain unchanged or be safely truncated by 1 token. + """ + # Create text with approximately 50 tokens + # We'll adjust to get exactly 50 + target_tokens = 50 + text = "word " * 50 + tokens = tokenizer.encode(text) + + # Adjust to get exactly target_tokens + if len(tokens) > target_tokens: + tokens = tokens[:target_tokens] + text = tokenizer.decode(tokens) + elif len(tokens) < target_tokens: + # Add more words + while len(tokenizer.encode(text)) < target_tokens: + text += "word " + tokens = tokenizer.encode(text)[:target_tokens] + text = tokenizer.decode(tokens) + + # Verify we have exactly target_tokens + assert len(tokenizer.encode(text)) == target_tokens, ( + "Test setup: should have exactly 50 tokens" + ) + + result = truncate_to_token_limit([text], token_limit=target_tokens) + + assert len(result) == 1 + result_tokens = len(tokenizer.encode(result[0])) + assert result_tokens <= target_tokens, ( + f"Should be ≤{target_tokens} tokens, got {result_tokens}" + )