diff --git a/apps/base_rag_example.py b/apps/base_rag_example.py
index d797865..e67ee56 100644
--- a/apps/base_rag_example.py
+++ b/apps/base_rag_example.py
@@ -180,14 +180,14 @@ class BaseRAGExample(ABC):
         ast_group.add_argument(
             "--ast-chunk-size",
             type=int,
-            default=512,
-            help="Maximum characters per AST chunk (default: 512)",
+            default=300,
+            help="Maximum CHARACTERS per AST chunk (default: 300). Final chunks may be larger due to overlap. For 512 token models: recommended 300 chars",
         )
         ast_group.add_argument(
             "--ast-chunk-overlap",
             type=int,
             default=64,
-            help="Overlap between AST chunks (default: 64)",
+            help="Overlap between AST chunks in CHARACTERS (default: 64). Added to chunk size, not included in it",
         )
         ast_group.add_argument(
             "--code-file-extensions",
diff --git a/packages/leann-core/src/leann/chunking_utils.py b/packages/leann-core/src/leann/chunking_utils.py
index db80a39..0140930 100644
--- a/packages/leann-core/src/leann/chunking_utils.py
+++ b/packages/leann-core/src/leann/chunking_utils.py
@@ -11,6 +11,119 @@ from llama_index.core.node_parser import SentenceSplitter
 
 logger = logging.getLogger(__name__)
 
+
+def estimate_token_count(text: str) -> int:
+    """
+    Estimate token count for a text string.
+    Uses conservative estimation: ~4 characters per token for natural text,
+    ~1.2 tokens per character for code (worse tokenization).
+
+    Args:
+        text: Input text to estimate tokens for
+
+    Returns:
+        Estimated token count
+    """
+    try:
+        import tiktoken
+
+        encoder = tiktoken.get_encoding("cl100k_base")
+        return len(encoder.encode(text))
+    except ImportError:
+        # Fallback: Conservative character-based estimation
+        # Assume worst case for code: 1.2 tokens per character
+        return int(len(text) * 1.2)
+
+
+def calculate_safe_chunk_size(
+    model_token_limit: int,
+    overlap_tokens: int,
+    chunking_mode: str = "traditional",
+    safety_factor: float = 0.9,
+) -> int:
+    """
+    Calculate safe chunk size accounting for overlap and safety margin.
+
+    Args:
+        model_token_limit: Maximum tokens supported by embedding model
+        overlap_tokens: Overlap size (tokens for traditional, chars for AST)
+        chunking_mode: "traditional" (tokens) or "ast" (characters)
+        safety_factor: Safety margin (0.9 = 10% safety margin)
+
+    Returns:
+        Safe chunk size: tokens for traditional, characters for AST
+    """
+    safe_limit = int(model_token_limit * safety_factor)
+
+    if chunking_mode == "traditional":
+        # Traditional chunking uses tokens
+        # Max chunk = chunk_size + overlap, so chunk_size = limit - overlap
+        return max(1, safe_limit - overlap_tokens)
+    else:  # AST chunking
+        # AST uses characters, need to convert
+        # Conservative estimate: 1.2 tokens per char for code
+        overlap_chars = int(overlap_tokens * 3)  # ~3 chars per token for code
+        safe_chars = int(safe_limit / 1.2)
+        return max(1, safe_chars - overlap_chars)
+
+
+def validate_chunk_token_limits(chunks: list[str], max_tokens: int = 512) -> tuple[list[str], int]:
+    """
+    Validate that chunks don't exceed token limits and truncate if necessary.
+
+    Args:
+        chunks: List of text chunks to validate
+        max_tokens: Maximum tokens allowed per chunk
+
+    Returns:
+        Tuple of (validated_chunks, num_truncated)
+    """
+    validated_chunks = []
+    num_truncated = 0
+
+    for i, chunk in enumerate(chunks):
+        estimated_tokens = estimate_token_count(chunk)
+
+        if estimated_tokens > max_tokens:
+            # Truncate chunk to fit token limit
+            try:
+                import tiktoken
+
+                encoder = tiktoken.get_encoding("cl100k_base")
+                tokens = encoder.encode(chunk)
+                if len(tokens) > max_tokens:
+                    truncated_tokens = tokens[:max_tokens]
+                    truncated_chunk = encoder.decode(truncated_tokens)
+                    validated_chunks.append(truncated_chunk)
+                    num_truncated += 1
+                    logger.warning(
+                        f"Truncated chunk {i} from {len(tokens)} to {max_tokens} tokens "
+                        f"(from {len(chunk)} to {len(truncated_chunk)} characters)"
+                    )
+                else:
+                    validated_chunks.append(chunk)
+            except ImportError:
+                # Fallback: Conservative character truncation
+                char_limit = int(max_tokens / 1.2)  # Conservative for code
+                if len(chunk) > char_limit:
+                    truncated_chunk = chunk[:char_limit]
+                    validated_chunks.append(truncated_chunk)
+                    num_truncated += 1
+                    logger.warning(
+                        f"Truncated chunk {i} from {len(chunk)} to {char_limit} characters "
+                        f"(conservative estimate for {max_tokens} tokens)"
+                    )
+                else:
+                    validated_chunks.append(chunk)
+        else:
+            validated_chunks.append(chunk)
+
+    if num_truncated > 0:
+        logger.warning(f"Truncated {num_truncated}/{len(chunks)} chunks to fit token limits")
+
+    return validated_chunks, num_truncated
+
+
 # Code file extensions supported by astchunk
 CODE_EXTENSIONS = {
     ".py": "python",
@@ -82,6 +195,17 @@ def create_ast_chunks(
                 continue
 
         try:
+            # Warn if AST chunk size + overlap might exceed common token limits
+            estimated_max_tokens = int(
+                (max_chunk_size + chunk_overlap) * 1.2
+            )  # Conservative estimate
+            if estimated_max_tokens > 512:
+                logger.warning(
+                    f"AST chunk size ({max_chunk_size}) + overlap ({chunk_overlap}) = {max_chunk_size + chunk_overlap} chars "
+                    f"may exceed 512 token limit (~{estimated_max_tokens} tokens estimated). "
+                    f"Consider reducing --ast-chunk-size to {int(400 / 1.2)} or --ast-chunk-overlap to {int(50 / 1.2)}"
+                )
+
             configs = {
                 "max_chunk_size": max_chunk_size,
                 "language": language,
@@ -217,4 +341,14 @@ def create_text_chunks(
         all_chunks = create_traditional_chunks(documents, chunk_size, chunk_overlap)
 
     logger.info(f"Total chunks created: {len(all_chunks)}")
-    return all_chunks
+
+    # Validate chunk token limits (default to 512 for safety)
+    # This provides a safety net for embedding models with token limits
+    validated_chunks, num_truncated = validate_chunk_token_limits(all_chunks, max_tokens=512)
+
+    if num_truncated > 0:
+        logger.info(
+            f"Post-chunking validation: {num_truncated} chunks were truncated to fit 512 token limit"
+        )
+
+    return validated_chunks
diff --git a/packages/leann-core/src/leann/cli.py b/packages/leann-core/src/leann/cli.py
index da2fd7d..db011eb 100644
--- a/packages/leann-core/src/leann/cli.py
+++ b/packages/leann-core/src/leann/cli.py
@@ -181,25 +181,25 @@ Examples:
         "--doc-chunk-size",
         type=int,
         default=256,
-        help="Document chunk size in tokens/characters (default: 256)",
+        help="Document chunk size in TOKENS (default: 256). Final chunks may be larger due to overlap. For 512 token models: recommended 350 tokens (350 + 128 overlap = 478 max)",
     )
     build_parser.add_argument(
         "--doc-chunk-overlap",
         type=int,
         default=128,
-        help="Document chunk overlap (default: 128)",
+        help="Document chunk overlap in TOKENS (default: 128). Added to chunk size, not included in it",
     )
     build_parser.add_argument(
         "--code-chunk-size",
         type=int,
         default=512,
-        help="Code chunk size in tokens/lines (default: 512)",
+        help="Code chunk size in TOKENS (default: 512). Final chunks may be larger due to overlap. For 512 token models: recommended 400 tokens (400 + 50 overlap = 450 max)",
     )
     build_parser.add_argument(
         "--code-chunk-overlap",
         type=int,
         default=50,
-        help="Code chunk overlap (default: 50)",
+        help="Code chunk overlap in TOKENS (default: 50). Added to chunk size, not included in it",
     )
     build_parser.add_argument(
         "--use-ast-chunking",
@@ -209,14 +209,14 @@
     build_parser.add_argument(
         "--ast-chunk-size",
         type=int,
-        default=768,
-        help="AST chunk size in characters (default: 768)",
+        default=300,
+        help="AST chunk size in CHARACTERS (non-whitespace) (default: 300). Final chunks may be larger due to overlap and expansion. For 512 token models: recommended 300 chars (300 + 64 overlap ~= 480 tokens)",
     )
     build_parser.add_argument(
         "--ast-chunk-overlap",
         type=int,
-        default=96,
-        help="AST chunk overlap in characters (default: 96)",
+        default=64,
+        help="AST chunk overlap in CHARACTERS (default: 64). Added to chunk size, not included in it. ~1.2 tokens per character for code",
     )
     build_parser.add_argument(
         "--ast-fallback-traditional",
diff --git a/packages/leann-core/src/leann/embedding_compute.py b/packages/leann-core/src/leann/embedding_compute.py
index e3d9f86..a79a39e 100644
--- a/packages/leann-core/src/leann/embedding_compute.py
+++ b/packages/leann-core/src/leann/embedding_compute.py
@@ -14,6 +14,89 @@ import torch
 
 from .settings import resolve_ollama_host, resolve_openai_api_key, resolve_openai_base_url
 
+
+def truncate_to_token_limit(texts: list[str], max_tokens: int = 512) -> list[str]:
+    """
+    Truncate texts to token limit using tiktoken or conservative character truncation.
+
+    Args:
+        texts: List of texts to truncate
+        max_tokens: Maximum tokens allowed per text
+
+    Returns:
+        List of truncated texts that should fit within token limit
+    """
+    try:
+        import tiktoken
+
+        encoder = tiktoken.get_encoding("cl100k_base")
+        truncated = []
+
+        for text in texts:
+            tokens = encoder.encode(text)
+            if len(tokens) > max_tokens:
+                # Truncate to max_tokens and decode back to text
+                truncated_tokens = tokens[:max_tokens]
+                truncated_text = encoder.decode(truncated_tokens)
+                truncated.append(truncated_text)
+                logger.warning(
+                    f"Truncated text from {len(tokens)} to {max_tokens} tokens "
+                    f"(from {len(text)} to {len(truncated_text)} characters)"
+                )
+            else:
+                truncated.append(text)
+        return truncated
+
+    except ImportError:
+        # Fallback: Conservative character truncation
+        # Assume worst case: 1.5 tokens per character for code content
+        char_limit = int(max_tokens / 1.5)
+        truncated = []
+
+        for text in texts:
+            if len(text) > char_limit:
+                truncated_text = text[:char_limit]
+                truncated.append(truncated_text)
+                logger.warning(
+                    f"Truncated text from {len(text)} to {char_limit} characters "
+                    f"(conservative estimate for {max_tokens} tokens)"
+                )
+            else:
+                truncated.append(text)
+        return truncated
+
+
+def get_model_token_limit(model_name: str) -> int:
+    """
+    Get token limit for a given embedding model.
+
+    Args:
+        model_name: Name of the embedding model
+
+    Returns:
+        Token limit for the model, defaults to 512 if unknown
+    """
+    # Handle versioned model names (e.g., "nomic-embed-text:latest" -> "nomic-embed-text")
+    base_model_name = model_name.split(":")[0]
+
+    # Check exact match first
+    if model_name in EMBEDDING_MODEL_LIMITS:
+        return EMBEDDING_MODEL_LIMITS[model_name]
+
+    # Check base name match
+    if base_model_name in EMBEDDING_MODEL_LIMITS:
+        return EMBEDDING_MODEL_LIMITS[base_model_name]
+
+    # Check partial matches for common patterns
+    for known_model, limit in EMBEDDING_MODEL_LIMITS.items():
+        if known_model in base_model_name or base_model_name in known_model:
+            return limit
+
+    # Default to conservative 512 token limit
+    logger.warning(f"Unknown model '{model_name}', using default 512 token limit")
+    return 512
+
+
 # Set up logger with proper level
 logger = logging.getLogger(__name__)
 LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
@@ -23,6 +106,17 @@ logger.setLevel(log_level)
 
 # Global model cache to avoid repeated loading
 _model_cache: dict[str, Any] = {}
+# Known embedding model token limits
+EMBEDDING_MODEL_LIMITS = {
+    "nomic-embed-text": 512,
+    "nomic-embed-text-v2": 512,
+    "mxbai-embed-large": 512,
+    "all-minilm": 512,
+    "bge-m3": 8192,
+    "snowflake-arctic-embed": 512,
+    # Add more models as needed
+}
+
 
 def compute_embeddings(
     texts: list[str],
@@ -720,20 +814,28 @@ def compute_embeddings_ollama(
     logger.info(f"Using batch size: {batch_size} for true batch processing")
 
+    # Get model token limit and apply truncation
+    token_limit = get_model_token_limit(model_name)
+    logger.info(f"Model '{model_name}' token limit: {token_limit}")
+
+    # Apply token-aware truncation to all texts
+    truncated_texts = truncate_to_token_limit(texts, token_limit)
+    if len(truncated_texts) != len(texts):
+        logger.error("Truncation failed - text count mismatch")
+        truncated_texts = texts  # Fallback to original texts
+
     def get_batch_embeddings(batch_texts):
         """Get embeddings for a batch of texts using /api/embed endpoint."""
         max_retries = 3
         retry_count = 0
 
-        # Truncate very long texts to avoid API issues
-        truncated_texts = [text[:8000] if len(text) > 8000 else text for text in batch_texts]
-
+        # Texts are already truncated to token limit by the outer function
         while retry_count < max_retries:
            try:
                # Use /api/embed endpoint with "input" parameter for batch processing
                response = requests.post(
                    f"{resolved_host}/api/embed",
-                    json={"model": model_name, "input": truncated_texts},
+                    json={"model": model_name, "input": batch_texts},
                    timeout=60,  # Increased timeout for batch processing
                )
                response.raise_for_status()
@@ -763,17 +865,27 @@
             except Exception as e:
                 retry_count += 1
                 if retry_count >= max_retries:
-                    logger.error(f"Failed to get embeddings for batch: {e}")
+                    # Enhanced error detection for token limit violations
+                    error_msg = str(e).lower()
+                    if "token" in error_msg and (
+                        "limit" in error_msg or "exceed" in error_msg or "length" in error_msg
+                    ):
+                        logger.error(
+                            f"Token limit exceeded for batch. Error: {e}. "
+                            f"Consider reducing chunk sizes or check token truncation."
+                        )
+                    else:
+                        logger.error(f"Failed to get embeddings for batch: {e}")
                     return None, list(range(len(batch_texts)))
 
         return None, list(range(len(batch_texts)))
 
-    # Process texts in batches
+    # Process truncated texts in batches
     all_embeddings = []
     all_failed_indices = []
 
     # Setup progress bar if needed
-    show_progress = is_build or len(texts) > 10
+    show_progress = is_build or len(truncated_texts) > 10
     try:
         if show_progress:
             from tqdm import tqdm
@@ -781,7 +893,7 @@
         show_progress = False
 
     # Process batches
-    num_batches = (len(texts) + batch_size - 1) // batch_size
+    num_batches = (len(truncated_texts) + batch_size - 1) // batch_size
 
     if show_progress:
         batch_iterator = tqdm(range(num_batches), desc="Computing Ollama embeddings (batched)")
@@ -790,8 +902,8 @@
     for batch_idx in batch_iterator:
         start_idx = batch_idx * batch_size
-        end_idx = min(start_idx + batch_size, len(texts))
-        batch_texts = texts[start_idx:end_idx]
+        end_idx = min(start_idx + batch_size, len(truncated_texts))
+        batch_texts = truncated_texts[start_idx:end_idx]
 
         batch_embeddings, batch_failed = get_batch_embeddings(batch_texts)
@@ -806,11 +918,11 @@
     # Handle failed embeddings
     if all_failed_indices:
-        if len(all_failed_indices) == len(texts):
+        if len(all_failed_indices) == len(truncated_texts):
             raise RuntimeError("Failed to compute any embeddings")
 
         logger.warning(
-            f"Failed to compute embeddings for {len(all_failed_indices)}/{len(texts)} texts"
+            f"Failed to compute embeddings for {len(all_failed_indices)}/{len(truncated_texts)} texts"
         )
 
         # Use zero embeddings as fallback for failed ones
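
Usage sketch (not part of the patch): a minimal illustration of how the helpers introduced above are intended to compose when targeting a 512-token embedding model. The leann.* import paths and the sample inputs are assumptions for illustration only; the function names and signatures come from the diff.

    from leann.chunking_utils import calculate_safe_chunk_size, validate_chunk_token_limits
    from leann.embedding_compute import get_model_token_limit

    # Look up the model's token limit; unknown models fall back to a conservative 512.
    limit = get_model_token_limit("nomic-embed-text:latest")  # base-name match -> 512

    # Choose a chunk size that leaves room for the overlap plus the 10% safety margin:
    # int(512 * 0.9) - 128 = 332 tokens per chunk before overlap is added.
    chunk_size = calculate_safe_chunk_size(limit, overlap_tokens=128, chunking_mode="traditional")

    # After chunking, the validation pass truncates any chunk that still exceeds the limit.
    chunks = ["def f():\n    return 42\n" * 200, "short chunk"]
    safe_chunks, num_truncated = validate_chunk_token_limits(chunks, max_tokens=limit)
    print(chunk_size, num_truncated)  # e.g. 332 1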