metadata reveal for ast-chunking; smart detection of seq length in ollama; auto adjust chunk length for ast to prevent silent truncation (#157)
* feat: enhance token limits with dynamic discovery + AST metadata Improves upon upstream PR #154 with two major enhancements: 1. **Hybrid Token Limit Discovery** - Dynamic: Query Ollama /api/show for context limits - Fallback: Registry for LM Studio/OpenAI - Zero maintenance for Ollama users - Respects custom num_ctx settings 2. **AST Metadata Preservation** - create_ast_chunks() returns dict format with metadata - Preserves file_path, file_name, timestamps - Includes astchunk metadata (line numbers, node counts) - Fixes content extraction bug (checks "content" key) - Enables --show-metadata flag 3. **Better Token Limits** - nomic-embed-text: 2048 tokens (vs 512) - nomic-embed-text-v1.5: 2048 tokens - Added OpenAI models: 8192 tokens 4. **Comprehensive Tests** - 11 tests for token truncation - 545 new lines in test_astchunk_integration.py - All metadata preservation tests passing * fix: merge EMBEDDING_MODEL_LIMITS and remove redundant validation - Merged upstream's model list with our corrected token limits - Kept our corrected nomic-embed-text: 2048 (not 512) - Removed post-chunking validation (redundant with embedding-time truncation) - All tests passing except 2 pre-existing integration test failures * style: apply ruff formatting and restore PR #154 version handling - Remove duplicate truncate_to_token_limit and get_model_token_limit functions - Restore version handling logic (model:latest -> model) from PR #154 - Restore partial matching fallback for model name variations - Apply ruff formatting to all modified files - All 11 token truncation tests passing * style: sort imports alphabetically (pre-commit auto-fix) * fix: show AST token limit warning only once per session - Add module-level flag to track if warning shown - Prevents spam when processing multiple files - Add clarifying note that auto-truncation happens at embedding time - Addresses issue where warning appeared for every code file * enhance: add detailed logging for token truncation - Track and report truncation statistics (count, tokens removed, max length) - Show first 3 individual truncations with exact token counts - Provide comprehensive summary when truncation occurs - Use WARNING level for data loss visibility - Silent (DEBUG level only) when no truncation needed Replaces misleading "truncated where necessary" message that appeared even when nothing was truncated.
This commit is contained in:
@@ -12,6 +12,7 @@ from pathlib import Path
|
||||
try:
|
||||
from leann.chunking_utils import (
|
||||
CODE_EXTENSIONS,
|
||||
_traditional_chunks_as_dicts,
|
||||
create_ast_chunks,
|
||||
create_text_chunks,
|
||||
create_traditional_chunks,
|
||||
@@ -25,6 +26,7 @@ except Exception: # pragma: no cover - best-effort fallback for dev environment
|
||||
sys.path.insert(0, str(leann_src))
|
||||
from leann.chunking_utils import (
|
||||
CODE_EXTENSIONS,
|
||||
_traditional_chunks_as_dicts,
|
||||
create_ast_chunks,
|
||||
create_text_chunks,
|
||||
create_traditional_chunks,
|
||||
@@ -36,6 +38,7 @@ except Exception: # pragma: no cover - best-effort fallback for dev environment
|
||||
|
||||
__all__ = [
|
||||
"CODE_EXTENSIONS",
|
||||
"_traditional_chunks_as_dicts",
|
||||
"create_ast_chunks",
|
||||
"create_text_chunks",
|
||||
"create_traditional_chunks",
|
||||
|
||||
@@ -5,12 +5,15 @@ Packaged within leann-core so installed wheels can import it reliably.
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
from llama_index.core.node_parser import SentenceSplitter
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Flag to ensure AST token warning only shown once per session
|
||||
_ast_token_warning_shown = False
|
||||
|
||||
|
||||
def estimate_token_count(text: str) -> int:
|
||||
"""
|
||||
@@ -174,37 +177,44 @@ def create_ast_chunks(
|
||||
max_chunk_size: int = 512,
|
||||
chunk_overlap: int = 64,
|
||||
metadata_template: str = "default",
|
||||
) -> list[str]:
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Create AST-aware chunks from code documents using astchunk.
|
||||
|
||||
Falls back to traditional chunking if astchunk is unavailable.
|
||||
|
||||
Returns:
|
||||
List of dicts with {"text": str, "metadata": dict}
|
||||
"""
|
||||
try:
|
||||
from astchunk import ASTChunkBuilder # optional dependency
|
||||
except ImportError as e:
|
||||
logger.error(f"astchunk not available: {e}")
|
||||
logger.info("Falling back to traditional chunking for code files")
|
||||
return create_traditional_chunks(documents, max_chunk_size, chunk_overlap)
|
||||
return _traditional_chunks_as_dicts(documents, max_chunk_size, chunk_overlap)
|
||||
|
||||
all_chunks = []
|
||||
for doc in documents:
|
||||
language = doc.metadata.get("language")
|
||||
if not language:
|
||||
logger.warning("No language detected; falling back to traditional chunking")
|
||||
all_chunks.extend(create_traditional_chunks([doc], max_chunk_size, chunk_overlap))
|
||||
all_chunks.extend(_traditional_chunks_as_dicts([doc], max_chunk_size, chunk_overlap))
|
||||
continue
|
||||
|
||||
try:
|
||||
# Warn if AST chunk size + overlap might exceed common token limits
|
||||
# Warn once if AST chunk size + overlap might exceed common token limits
|
||||
# Note: Actual truncation happens at embedding time with dynamic model limits
|
||||
global _ast_token_warning_shown
|
||||
estimated_max_tokens = int(
|
||||
(max_chunk_size + chunk_overlap) * 1.2
|
||||
) # Conservative estimate
|
||||
if estimated_max_tokens > 512:
|
||||
if estimated_max_tokens > 512 and not _ast_token_warning_shown:
|
||||
logger.warning(
|
||||
f"AST chunk size ({max_chunk_size}) + overlap ({chunk_overlap}) = {max_chunk_size + chunk_overlap} chars "
|
||||
f"may exceed 512 token limit (~{estimated_max_tokens} tokens estimated). "
|
||||
f"Consider reducing --ast-chunk-size to {int(400 / 1.2)} or --ast-chunk-overlap to {int(50 / 1.2)}"
|
||||
f"Consider reducing --ast-chunk-size to {int(400 / 1.2)} or --ast-chunk-overlap to {int(50 / 1.2)}. "
|
||||
f"Note: Chunks will be auto-truncated at embedding time based on your model's actual token limit."
|
||||
)
|
||||
_ast_token_warning_shown = True
|
||||
|
||||
configs = {
|
||||
"max_chunk_size": max_chunk_size,
|
||||
@@ -229,17 +239,40 @@ def create_ast_chunks(
|
||||
|
||||
chunks = chunk_builder.chunkify(code_content)
|
||||
for chunk in chunks:
|
||||
chunk_text = None
|
||||
astchunk_metadata = {}
|
||||
|
||||
if hasattr(chunk, "text"):
|
||||
chunk_text = chunk.text
|
||||
elif isinstance(chunk, dict) and "text" in chunk:
|
||||
chunk_text = chunk["text"]
|
||||
elif isinstance(chunk, str):
|
||||
chunk_text = chunk
|
||||
elif isinstance(chunk, dict):
|
||||
# Handle astchunk format: {"content": "...", "metadata": {...}}
|
||||
if "content" in chunk:
|
||||
chunk_text = chunk["content"]
|
||||
astchunk_metadata = chunk.get("metadata", {})
|
||||
elif "text" in chunk:
|
||||
chunk_text = chunk["text"]
|
||||
else:
|
||||
chunk_text = str(chunk) # Last resort
|
||||
else:
|
||||
chunk_text = str(chunk)
|
||||
|
||||
if chunk_text and chunk_text.strip():
|
||||
all_chunks.append(chunk_text.strip())
|
||||
# Extract document-level metadata
|
||||
doc_metadata = {
|
||||
"file_path": doc.metadata.get("file_path", ""),
|
||||
"file_name": doc.metadata.get("file_name", ""),
|
||||
}
|
||||
if "creation_date" in doc.metadata:
|
||||
doc_metadata["creation_date"] = doc.metadata["creation_date"]
|
||||
if "last_modified_date" in doc.metadata:
|
||||
doc_metadata["last_modified_date"] = doc.metadata["last_modified_date"]
|
||||
|
||||
# Merge document metadata + astchunk metadata
|
||||
combined_metadata = {**doc_metadata, **astchunk_metadata}
|
||||
|
||||
all_chunks.append({"text": chunk_text.strip(), "metadata": combined_metadata})
|
||||
|
||||
logger.info(
|
||||
f"Created {len(chunks)} AST chunks from {language} file: {doc.metadata.get('file_name', 'unknown')}"
|
||||
@@ -247,15 +280,19 @@ def create_ast_chunks(
|
||||
except Exception as e:
|
||||
logger.warning(f"AST chunking failed for {language} file: {e}")
|
||||
logger.info("Falling back to traditional chunking")
|
||||
all_chunks.extend(create_traditional_chunks([doc], max_chunk_size, chunk_overlap))
|
||||
all_chunks.extend(_traditional_chunks_as_dicts([doc], max_chunk_size, chunk_overlap))
|
||||
|
||||
return all_chunks
|
||||
|
||||
|
||||
def create_traditional_chunks(
|
||||
documents, chunk_size: int = 256, chunk_overlap: int = 128
|
||||
) -> list[str]:
|
||||
"""Create traditional text chunks using LlamaIndex SentenceSplitter."""
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Create traditional text chunks using LlamaIndex SentenceSplitter.
|
||||
|
||||
Returns:
|
||||
List of dicts with {"text": str, "metadata": dict}
|
||||
"""
|
||||
if chunk_size <= 0:
|
||||
logger.warning(f"Invalid chunk_size={chunk_size}, using default value of 256")
|
||||
chunk_size = 256
|
||||
@@ -271,19 +308,40 @@ def create_traditional_chunks(
|
||||
paragraph_separator="\n\n",
|
||||
)
|
||||
|
||||
all_texts = []
|
||||
result = []
|
||||
for doc in documents:
|
||||
# Extract document-level metadata
|
||||
doc_metadata = {
|
||||
"file_path": doc.metadata.get("file_path", ""),
|
||||
"file_name": doc.metadata.get("file_name", ""),
|
||||
}
|
||||
if "creation_date" in doc.metadata:
|
||||
doc_metadata["creation_date"] = doc.metadata["creation_date"]
|
||||
if "last_modified_date" in doc.metadata:
|
||||
doc_metadata["last_modified_date"] = doc.metadata["last_modified_date"]
|
||||
|
||||
try:
|
||||
nodes = node_parser.get_nodes_from_documents([doc])
|
||||
if nodes:
|
||||
all_texts.extend(node.get_content() for node in nodes)
|
||||
for node in nodes:
|
||||
result.append({"text": node.get_content(), "metadata": doc_metadata})
|
||||
except Exception as e:
|
||||
logger.error(f"Traditional chunking failed for document: {e}")
|
||||
content = doc.get_content()
|
||||
if content and content.strip():
|
||||
all_texts.append(content.strip())
|
||||
result.append({"text": content.strip(), "metadata": doc_metadata})
|
||||
|
||||
return all_texts
|
||||
return result
|
||||
|
||||
|
||||
def _traditional_chunks_as_dicts(
|
||||
documents, chunk_size: int = 256, chunk_overlap: int = 128
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Helper: Traditional chunking that returns dict format for consistency.
|
||||
|
||||
This is now just an alias for create_traditional_chunks for backwards compatibility.
|
||||
"""
|
||||
return create_traditional_chunks(documents, chunk_size, chunk_overlap)
|
||||
|
||||
|
||||
def create_text_chunks(
|
||||
@@ -295,8 +353,12 @@ def create_text_chunks(
|
||||
ast_chunk_overlap: int = 64,
|
||||
code_file_extensions: Optional[list[str]] = None,
|
||||
ast_fallback_traditional: bool = True,
|
||||
) -> list[str]:
|
||||
"""Create text chunks from documents with optional AST support for code files."""
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Create text chunks from documents with optional AST support for code files.
|
||||
|
||||
Returns:
|
||||
List of dicts with {"text": str, "metadata": dict}
|
||||
"""
|
||||
if not documents:
|
||||
logger.warning("No documents provided for chunking")
|
||||
return []
|
||||
@@ -331,24 +393,17 @@ def create_text_chunks(
|
||||
logger.error(f"AST chunking failed: {e}")
|
||||
if ast_fallback_traditional:
|
||||
all_chunks.extend(
|
||||
create_traditional_chunks(code_docs, chunk_size, chunk_overlap)
|
||||
_traditional_chunks_as_dicts(code_docs, chunk_size, chunk_overlap)
|
||||
)
|
||||
else:
|
||||
raise
|
||||
if text_docs:
|
||||
all_chunks.extend(create_traditional_chunks(text_docs, chunk_size, chunk_overlap))
|
||||
all_chunks.extend(_traditional_chunks_as_dicts(text_docs, chunk_size, chunk_overlap))
|
||||
else:
|
||||
all_chunks = create_traditional_chunks(documents, chunk_size, chunk_overlap)
|
||||
all_chunks = _traditional_chunks_as_dicts(documents, chunk_size, chunk_overlap)
|
||||
|
||||
logger.info(f"Total chunks created: {len(all_chunks)}")
|
||||
|
||||
# Validate chunk token limits (default to 512 for safety)
|
||||
# This provides a safety net for embedding models with token limits
|
||||
validated_chunks, num_truncated = validate_chunk_token_limits(all_chunks, max_tokens=512)
|
||||
|
||||
if num_truncated > 0:
|
||||
logger.info(
|
||||
f"Post-chunking validation: {num_truncated} chunks were truncated to fit 512 token limit"
|
||||
)
|
||||
|
||||
return validated_chunks
|
||||
# Note: Token truncation is now handled at embedding time with dynamic model limits
|
||||
# See get_model_token_limit() and truncate_to_token_limit() in embedding_compute.py
|
||||
return all_chunks
|
||||
|
||||
@@ -1279,13 +1279,8 @@ Examples:
|
||||
ast_fallback_traditional=getattr(args, "ast_fallback_traditional", True),
|
||||
)
|
||||
|
||||
# Note: AST chunking currently returns plain text chunks without metadata
|
||||
# We preserve basic file info by associating chunks with their source documents
|
||||
# For better metadata preservation, documents list order should be maintained
|
||||
for chunk_text in chunk_texts:
|
||||
# TODO: Enhance create_text_chunks to return metadata alongside text
|
||||
# For now, we store chunks with empty metadata
|
||||
all_texts.append({"text": chunk_text, "metadata": {}})
|
||||
# create_text_chunks now returns list[dict] with metadata preserved
|
||||
all_texts.extend(chunk_texts)
|
||||
|
||||
except ImportError as e:
|
||||
print(
|
||||
|
||||
@@ -10,72 +10,63 @@ import time
|
||||
from typing import Any, Optional
|
||||
|
||||
import numpy as np
|
||||
import tiktoken
|
||||
import torch
|
||||
|
||||
from .settings import resolve_ollama_host, resolve_openai_api_key, resolve_openai_base_url
|
||||
|
||||
# Set up logger with proper level
|
||||
logger = logging.getLogger(__name__)
|
||||
LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
|
||||
log_level = getattr(logging, LOG_LEVEL, logging.WARNING)
|
||||
logger.setLevel(log_level)
|
||||
|
||||
def truncate_to_token_limit(texts: list[str], max_tokens: int = 512) -> list[str]:
|
||||
"""
|
||||
Truncate texts to token limit using tiktoken or conservative character truncation.
|
||||
|
||||
Args:
|
||||
texts: List of texts to truncate
|
||||
max_tokens: Maximum tokens allowed per text
|
||||
|
||||
Returns:
|
||||
List of truncated texts that should fit within token limit
|
||||
"""
|
||||
try:
|
||||
import tiktoken
|
||||
|
||||
encoder = tiktoken.get_encoding("cl100k_base")
|
||||
truncated = []
|
||||
|
||||
for text in texts:
|
||||
tokens = encoder.encode(text)
|
||||
if len(tokens) > max_tokens:
|
||||
# Truncate to max_tokens and decode back to text
|
||||
truncated_tokens = tokens[:max_tokens]
|
||||
truncated_text = encoder.decode(truncated_tokens)
|
||||
truncated.append(truncated_text)
|
||||
logger.warning(
|
||||
f"Truncated text from {len(tokens)} to {max_tokens} tokens "
|
||||
f"(from {len(text)} to {len(truncated_text)} characters)"
|
||||
)
|
||||
else:
|
||||
truncated.append(text)
|
||||
return truncated
|
||||
|
||||
except ImportError:
|
||||
# Fallback: Conservative character truncation
|
||||
# Assume worst case: 1.5 tokens per character for code content
|
||||
char_limit = int(max_tokens / 1.5)
|
||||
truncated = []
|
||||
|
||||
for text in texts:
|
||||
if len(text) > char_limit:
|
||||
truncated_text = text[:char_limit]
|
||||
truncated.append(truncated_text)
|
||||
logger.warning(
|
||||
f"Truncated text from {len(text)} to {char_limit} characters "
|
||||
f"(conservative estimate for {max_tokens} tokens)"
|
||||
)
|
||||
else:
|
||||
truncated.append(text)
|
||||
return truncated
|
||||
# Token limit registry for embedding models
|
||||
# Used as fallback when dynamic discovery fails (e.g., LM Studio, OpenAI)
|
||||
# Ollama models use dynamic discovery via /api/show
|
||||
EMBEDDING_MODEL_LIMITS = {
|
||||
# Nomic models (common across servers)
|
||||
"nomic-embed-text": 2048, # Corrected from 512 - verified via /api/show
|
||||
"nomic-embed-text-v1.5": 2048,
|
||||
"nomic-embed-text-v2": 512,
|
||||
# Other embedding models
|
||||
"mxbai-embed-large": 512,
|
||||
"all-minilm": 512,
|
||||
"bge-m3": 8192,
|
||||
"snowflake-arctic-embed": 512,
|
||||
# OpenAI models
|
||||
"text-embedding-3-small": 8192,
|
||||
"text-embedding-3-large": 8192,
|
||||
"text-embedding-ada-002": 8192,
|
||||
}
|
||||
|
||||
|
||||
def get_model_token_limit(model_name: str) -> int:
|
||||
def get_model_token_limit(
|
||||
model_name: str,
|
||||
base_url: Optional[str] = None,
|
||||
default: int = 2048,
|
||||
) -> int:
|
||||
"""
|
||||
Get token limit for a given embedding model.
|
||||
Uses hybrid approach: dynamic discovery for Ollama, registry fallback for others.
|
||||
|
||||
Args:
|
||||
model_name: Name of the embedding model
|
||||
base_url: Base URL of the embedding server (for dynamic discovery)
|
||||
default: Default token limit if model not found
|
||||
|
||||
Returns:
|
||||
Token limit for the model, defaults to 512 if unknown
|
||||
Token limit for the model in tokens
|
||||
"""
|
||||
# Try Ollama dynamic discovery if base_url provided
|
||||
if base_url:
|
||||
# Detect Ollama servers by port or "ollama" in URL
|
||||
if "11434" in base_url or "ollama" in base_url.lower():
|
||||
limit = _query_ollama_context_limit(model_name, base_url)
|
||||
if limit:
|
||||
return limit
|
||||
|
||||
# Fallback to known model registry with version handling (from PR #154)
|
||||
# Handle versioned model names (e.g., "nomic-embed-text:latest" -> "nomic-embed-text")
|
||||
base_model_name = model_name.split(":")[0]
|
||||
|
||||
@@ -92,31 +83,111 @@ def get_model_token_limit(model_name: str) -> int:
|
||||
if known_model in base_model_name or base_model_name in known_model:
|
||||
return limit
|
||||
|
||||
# Default to conservative 512 token limit
|
||||
logger.warning(f"Unknown model '{model_name}', using default 512 token limit")
|
||||
return 512
|
||||
# Default fallback
|
||||
logger.warning(f"Unknown model '{model_name}', using default {default} token limit")
|
||||
return default
|
||||
|
||||
|
||||
# Set up logger with proper level
|
||||
logger = logging.getLogger(__name__)
|
||||
LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
|
||||
log_level = getattr(logging, LOG_LEVEL, logging.WARNING)
|
||||
logger.setLevel(log_level)
|
||||
def truncate_to_token_limit(texts: list[str], token_limit: int) -> list[str]:
|
||||
"""
|
||||
Truncate texts to fit within token limit using tiktoken.
|
||||
|
||||
Args:
|
||||
texts: List of text strings to truncate
|
||||
token_limit: Maximum number of tokens allowed
|
||||
|
||||
Returns:
|
||||
List of truncated texts (same length as input)
|
||||
"""
|
||||
if not texts:
|
||||
return []
|
||||
|
||||
# Use tiktoken with cl100k_base encoding
|
||||
enc = tiktoken.get_encoding("cl100k_base")
|
||||
|
||||
truncated_texts = []
|
||||
truncation_count = 0
|
||||
total_tokens_removed = 0
|
||||
max_original_length = 0
|
||||
|
||||
for i, text in enumerate(texts):
|
||||
tokens = enc.encode(text)
|
||||
original_length = len(tokens)
|
||||
|
||||
if original_length <= token_limit:
|
||||
# Text is within limit, keep as is
|
||||
truncated_texts.append(text)
|
||||
else:
|
||||
# Truncate to token_limit
|
||||
truncated_tokens = tokens[:token_limit]
|
||||
truncated_text = enc.decode(truncated_tokens)
|
||||
truncated_texts.append(truncated_text)
|
||||
|
||||
# Track truncation statistics
|
||||
truncation_count += 1
|
||||
tokens_removed = original_length - token_limit
|
||||
total_tokens_removed += tokens_removed
|
||||
max_original_length = max(max_original_length, original_length)
|
||||
|
||||
# Log individual truncation at WARNING level (first few only)
|
||||
if truncation_count <= 3:
|
||||
logger.warning(
|
||||
f"Text {i + 1} truncated: {original_length} → {token_limit} tokens "
|
||||
f"({tokens_removed} tokens removed)"
|
||||
)
|
||||
elif truncation_count == 4:
|
||||
logger.warning("Further truncation warnings suppressed...")
|
||||
|
||||
# Log summary at INFO level
|
||||
if truncation_count > 0:
|
||||
logger.warning(
|
||||
f"Truncation summary: {truncation_count}/{len(texts)} texts truncated "
|
||||
f"(removed {total_tokens_removed} tokens total, longest was {max_original_length} tokens)"
|
||||
)
|
||||
else:
|
||||
logger.debug(
|
||||
f"No truncation needed - all {len(texts)} texts within {token_limit} token limit"
|
||||
)
|
||||
|
||||
return truncated_texts
|
||||
|
||||
|
||||
def _query_ollama_context_limit(model_name: str, base_url: str) -> Optional[int]:
|
||||
"""
|
||||
Query Ollama /api/show for model context limit.
|
||||
|
||||
Args:
|
||||
model_name: Name of the Ollama model
|
||||
base_url: Base URL of the Ollama server
|
||||
|
||||
Returns:
|
||||
Context limit in tokens if found, None otherwise
|
||||
"""
|
||||
try:
|
||||
import requests
|
||||
|
||||
response = requests.post(
|
||||
f"{base_url}/api/show",
|
||||
json={"name": model_name},
|
||||
timeout=5,
|
||||
)
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
if "model_info" in data:
|
||||
# Look for *.context_length in model_info
|
||||
for key, value in data["model_info"].items():
|
||||
if "context_length" in key and isinstance(value, int):
|
||||
logger.info(f"Detected {model_name} context limit: {value} tokens")
|
||||
return value
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to query Ollama context limit: {e}")
|
||||
|
||||
return None
|
||||
|
||||
|
||||
# Global model cache to avoid repeated loading
|
||||
_model_cache: dict[str, Any] = {}
|
||||
|
||||
# Known embedding model token limits
|
||||
EMBEDDING_MODEL_LIMITS = {
|
||||
"nomic-embed-text": 512,
|
||||
"nomic-embed-text-v2": 512,
|
||||
"mxbai-embed-large": 512,
|
||||
"all-minilm": 512,
|
||||
"bge-m3": 8192,
|
||||
"snowflake-arctic-embed": 512,
|
||||
# Add more models as needed
|
||||
}
|
||||
|
||||
|
||||
def compute_embeddings(
|
||||
texts: list[str],
|
||||
@@ -814,15 +885,13 @@ def compute_embeddings_ollama(
|
||||
|
||||
logger.info(f"Using batch size: {batch_size} for true batch processing")
|
||||
|
||||
# Get model token limit and apply truncation
|
||||
token_limit = get_model_token_limit(model_name)
|
||||
# Get model token limit and apply truncation before batching
|
||||
token_limit = get_model_token_limit(model_name, base_url=resolved_host)
|
||||
logger.info(f"Model '{model_name}' token limit: {token_limit}")
|
||||
|
||||
# Apply token-aware truncation to all texts
|
||||
truncated_texts = truncate_to_token_limit(texts, token_limit)
|
||||
if len(truncated_texts) != len(texts):
|
||||
logger.error("Truncation failed - text count mismatch")
|
||||
truncated_texts = texts # Fallback to original texts
|
||||
# Apply truncation to all texts before batch processing
|
||||
# Function logs truncation details internally
|
||||
texts = truncate_to_token_limit(texts, token_limit)
|
||||
|
||||
def get_batch_embeddings(batch_texts):
|
||||
"""Get embeddings for a batch of texts using /api/embed endpoint."""
|
||||
@@ -880,12 +949,12 @@ def compute_embeddings_ollama(
|
||||
|
||||
return None, list(range(len(batch_texts)))
|
||||
|
||||
# Process truncated texts in batches
|
||||
# Process texts in batches
|
||||
all_embeddings = []
|
||||
all_failed_indices = []
|
||||
|
||||
# Setup progress bar if needed
|
||||
show_progress = is_build or len(truncated_texts) > 10
|
||||
show_progress = is_build or len(texts) > 10
|
||||
try:
|
||||
if show_progress:
|
||||
from tqdm import tqdm
|
||||
@@ -893,7 +962,7 @@ def compute_embeddings_ollama(
|
||||
show_progress = False
|
||||
|
||||
# Process batches
|
||||
num_batches = (len(truncated_texts) + batch_size - 1) // batch_size
|
||||
num_batches = (len(texts) + batch_size - 1) // batch_size
|
||||
|
||||
if show_progress:
|
||||
batch_iterator = tqdm(range(num_batches), desc="Computing Ollama embeddings (batched)")
|
||||
@@ -902,8 +971,8 @@ def compute_embeddings_ollama(
|
||||
|
||||
for batch_idx in batch_iterator:
|
||||
start_idx = batch_idx * batch_size
|
||||
end_idx = min(start_idx + batch_size, len(truncated_texts))
|
||||
batch_texts = truncated_texts[start_idx:end_idx]
|
||||
end_idx = min(start_idx + batch_size, len(texts))
|
||||
batch_texts = texts[start_idx:end_idx]
|
||||
|
||||
batch_embeddings, batch_failed = get_batch_embeddings(batch_texts)
|
||||
|
||||
@@ -918,11 +987,11 @@ def compute_embeddings_ollama(
|
||||
|
||||
# Handle failed embeddings
|
||||
if all_failed_indices:
|
||||
if len(all_failed_indices) == len(truncated_texts):
|
||||
if len(all_failed_indices) == len(texts):
|
||||
raise RuntimeError("Failed to compute any embeddings")
|
||||
|
||||
logger.warning(
|
||||
f"Failed to compute embeddings for {len(all_failed_indices)}/{len(truncated_texts)} texts"
|
||||
f"Failed to compute embeddings for {len(all_failed_indices)}/{len(texts)} texts"
|
||||
)
|
||||
|
||||
# Use zero embeddings as fallback for failed ones
|
||||
|
||||
@@ -8,7 +8,7 @@ import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -116,8 +116,10 @@ class TestChunkingFunctions:
|
||||
chunks = create_traditional_chunks(docs, chunk_size=50, chunk_overlap=10)
|
||||
|
||||
assert len(chunks) > 0
|
||||
assert all(isinstance(chunk, str) for chunk in chunks)
|
||||
assert all(len(chunk.strip()) > 0 for chunk in chunks)
|
||||
# Traditional chunks now return dict format for consistency
|
||||
assert all(isinstance(chunk, dict) for chunk in chunks)
|
||||
assert all("text" in chunk and "metadata" in chunk for chunk in chunks)
|
||||
assert all(len(chunk["text"].strip()) > 0 for chunk in chunks)
|
||||
|
||||
def test_create_traditional_chunks_empty_docs(self):
|
||||
"""Test traditional chunking with empty documents."""
|
||||
@@ -158,11 +160,22 @@ class Calculator:
|
||||
|
||||
# Should have multiple chunks due to different functions/classes
|
||||
assert len(chunks) > 0
|
||||
assert all(isinstance(chunk, str) for chunk in chunks)
|
||||
assert all(len(chunk.strip()) > 0 for chunk in chunks)
|
||||
# R3: Expect dict format with "text" and "metadata" keys
|
||||
assert all(isinstance(chunk, dict) for chunk in chunks), "All chunks should be dicts"
|
||||
assert all("text" in chunk and "metadata" in chunk for chunk in chunks), (
|
||||
"Each chunk should have 'text' and 'metadata' keys"
|
||||
)
|
||||
assert all(len(chunk["text"].strip()) > 0 for chunk in chunks), (
|
||||
"Each chunk text should be non-empty"
|
||||
)
|
||||
|
||||
# Check metadata is present
|
||||
assert all("file_path" in chunk["metadata"] for chunk in chunks), (
|
||||
"Each chunk should have file_path metadata"
|
||||
)
|
||||
|
||||
# Check that code structure is somewhat preserved
|
||||
combined_content = " ".join(chunks)
|
||||
combined_content = " ".join([c["text"] for c in chunks])
|
||||
assert "def hello_world" in combined_content
|
||||
assert "class Calculator" in combined_content
|
||||
|
||||
@@ -194,7 +207,11 @@ class Calculator:
|
||||
chunks = create_text_chunks(docs, use_ast_chunking=False, chunk_size=50, chunk_overlap=10)
|
||||
|
||||
assert len(chunks) > 0
|
||||
assert all(isinstance(chunk, str) for chunk in chunks)
|
||||
# R3: Traditional chunking should also return dict format for consistency
|
||||
assert all(isinstance(chunk, dict) for chunk in chunks), "All chunks should be dicts"
|
||||
assert all("text" in chunk and "metadata" in chunk for chunk in chunks), (
|
||||
"Each chunk should have 'text' and 'metadata' keys"
|
||||
)
|
||||
|
||||
def test_create_text_chunks_ast_mode(self):
|
||||
"""Test text chunking in AST mode."""
|
||||
@@ -213,7 +230,11 @@ class Calculator:
|
||||
)
|
||||
|
||||
assert len(chunks) > 0
|
||||
assert all(isinstance(chunk, str) for chunk in chunks)
|
||||
# R3: AST mode should also return dict format
|
||||
assert all(isinstance(chunk, dict) for chunk in chunks), "All chunks should be dicts"
|
||||
assert all("text" in chunk and "metadata" in chunk for chunk in chunks), (
|
||||
"Each chunk should have 'text' and 'metadata' keys"
|
||||
)
|
||||
|
||||
def test_create_text_chunks_custom_extensions(self):
|
||||
"""Test text chunking with custom code file extensions."""
|
||||
@@ -353,6 +374,552 @@ class MathUtils:
|
||||
pytest.skip("Test timed out - likely due to model download in CI")
|
||||
|
||||
|
||||
class TestASTContentExtraction:
|
||||
"""Test AST content extraction bug fix.
|
||||
|
||||
These tests verify that astchunk's dict format with 'content' key is handled correctly,
|
||||
and that the extraction logic doesn't fall through to stringifying entire dicts.
|
||||
"""
|
||||
|
||||
def test_extract_content_from_astchunk_dict(self):
|
||||
"""Test that astchunk dict format with 'content' key is handled correctly.
|
||||
|
||||
Bug: Current code checks for chunk["text"] but astchunk returns chunk["content"].
|
||||
This causes fallthrough to str(chunk), stringifying the entire dict.
|
||||
|
||||
This test will FAIL until the bug is fixed because:
|
||||
- Current code will stringify the dict: "{'content': '...', 'metadata': {...}}"
|
||||
- Fixed code should extract just the content value
|
||||
"""
|
||||
# Mock the ASTChunkBuilder class
|
||||
mock_builder = Mock()
|
||||
|
||||
# Astchunk returns this format
|
||||
astchunk_format_chunk = {
|
||||
"content": "def hello():\n print('world')",
|
||||
"metadata": {
|
||||
"filepath": "test.py",
|
||||
"line_count": 2,
|
||||
"start_line_no": 0,
|
||||
"end_line_no": 1,
|
||||
"node_count": 1,
|
||||
},
|
||||
}
|
||||
mock_builder.chunkify.return_value = [astchunk_format_chunk]
|
||||
|
||||
# Create mock document
|
||||
doc = MockDocument(
|
||||
"def hello():\n print('world')", "/test/test.py", {"language": "python"}
|
||||
)
|
||||
|
||||
# Mock the astchunk module and its ASTChunkBuilder class
|
||||
mock_astchunk = Mock()
|
||||
mock_astchunk.ASTChunkBuilder = Mock(return_value=mock_builder)
|
||||
|
||||
# Patch sys.modules to inject our mock before the import
|
||||
with patch.dict("sys.modules", {"astchunk": mock_astchunk}):
|
||||
# Call create_ast_chunks
|
||||
chunks = create_ast_chunks([doc])
|
||||
|
||||
# R3: Should return dict format with proper metadata
|
||||
assert len(chunks) > 0, "Should return at least one chunk"
|
||||
|
||||
# R3: Each chunk should be a dict
|
||||
chunk = chunks[0]
|
||||
assert isinstance(chunk, dict), "Chunk should be a dict"
|
||||
assert "text" in chunk, "Chunk should have 'text' key"
|
||||
assert "metadata" in chunk, "Chunk should have 'metadata' key"
|
||||
|
||||
chunk_text = chunk["text"]
|
||||
|
||||
# CRITICAL: Should NOT contain stringified dict markers in the text field
|
||||
# These assertions will FAIL with current buggy code
|
||||
assert "'content':" not in chunk_text, (
|
||||
f"Chunk text contains stringified dict - extraction failed! Got: {chunk_text[:100]}..."
|
||||
)
|
||||
assert "'metadata':" not in chunk_text, (
|
||||
"Chunk text contains stringified metadata - extraction failed! "
|
||||
f"Got: {chunk_text[:100]}..."
|
||||
)
|
||||
assert "{" not in chunk_text or "def hello" in chunk_text.split("{")[0], (
|
||||
"Chunk text appears to be a stringified dict"
|
||||
)
|
||||
|
||||
# Should contain actual content
|
||||
assert "def hello()" in chunk_text, "Should extract actual code content"
|
||||
assert "print('world')" in chunk_text, "Should extract complete code content"
|
||||
|
||||
# R3: Should preserve astchunk metadata
|
||||
assert "filepath" in chunk["metadata"] or "file_path" in chunk["metadata"], (
|
||||
"Should preserve file path metadata"
|
||||
)
|
||||
|
||||
def test_extract_text_key_fallback(self):
|
||||
"""Test that 'text' key still works for backward compatibility.
|
||||
|
||||
Some chunks might use 'text' instead of 'content' - ensure backward compatibility.
|
||||
This test should PASS even with current code.
|
||||
"""
|
||||
mock_builder = Mock()
|
||||
|
||||
# Some chunks might use "text" key
|
||||
text_key_chunk = {"text": "def legacy_function():\n return True"}
|
||||
mock_builder.chunkify.return_value = [text_key_chunk]
|
||||
|
||||
# Create mock document
|
||||
doc = MockDocument(
|
||||
"def legacy_function():\n return True", "/test/legacy.py", {"language": "python"}
|
||||
)
|
||||
|
||||
# Mock the astchunk module
|
||||
mock_astchunk = Mock()
|
||||
mock_astchunk.ASTChunkBuilder = Mock(return_value=mock_builder)
|
||||
|
||||
with patch.dict("sys.modules", {"astchunk": mock_astchunk}):
|
||||
# Call create_ast_chunks
|
||||
chunks = create_ast_chunks([doc])
|
||||
|
||||
# R3: Should extract text correctly as dict format
|
||||
assert len(chunks) > 0
|
||||
chunk = chunks[0]
|
||||
assert isinstance(chunk, dict), "Chunk should be a dict"
|
||||
assert "text" in chunk, "Chunk should have 'text' key"
|
||||
|
||||
chunk_text = chunk["text"]
|
||||
|
||||
# Should NOT be stringified
|
||||
assert "'text':" not in chunk_text, "Should not stringify dict with 'text' key"
|
||||
|
||||
# Should contain actual content
|
||||
assert "def legacy_function()" in chunk_text
|
||||
assert "return True" in chunk_text
|
||||
|
||||
def test_handles_string_chunks(self):
|
||||
"""Test that plain string chunks still work.
|
||||
|
||||
Some chunkers might return plain strings - verify these are preserved.
|
||||
This test should PASS with current code.
|
||||
"""
|
||||
mock_builder = Mock()
|
||||
|
||||
# Plain string chunk
|
||||
plain_string_chunk = "def simple_function():\n pass"
|
||||
mock_builder.chunkify.return_value = [plain_string_chunk]
|
||||
|
||||
# Create mock document
|
||||
doc = MockDocument(
|
||||
"def simple_function():\n pass", "/test/simple.py", {"language": "python"}
|
||||
)
|
||||
|
||||
# Mock the astchunk module
|
||||
mock_astchunk = Mock()
|
||||
mock_astchunk.ASTChunkBuilder = Mock(return_value=mock_builder)
|
||||
|
||||
with patch.dict("sys.modules", {"astchunk": mock_astchunk}):
|
||||
# Call create_ast_chunks
|
||||
chunks = create_ast_chunks([doc])
|
||||
|
||||
# R3: Should wrap string in dict format
|
||||
assert len(chunks) > 0
|
||||
chunk = chunks[0]
|
||||
assert isinstance(chunk, dict), "Even string chunks should be wrapped in dict"
|
||||
assert "text" in chunk, "Chunk should have 'text' key"
|
||||
|
||||
chunk_text = chunk["text"]
|
||||
|
||||
assert chunk_text == plain_string_chunk.strip(), (
|
||||
"Should preserve plain string chunk content"
|
||||
)
|
||||
assert "def simple_function()" in chunk_text
|
||||
assert "pass" in chunk_text
|
||||
|
||||
def test_multiple_chunks_with_mixed_formats(self):
|
||||
"""Test handling of multiple chunks with different formats.
|
||||
|
||||
Real-world scenario: astchunk might return a mix of formats.
|
||||
This test will FAIL if any chunk with 'content' key gets stringified.
|
||||
"""
|
||||
mock_builder = Mock()
|
||||
|
||||
# Mix of formats
|
||||
mixed_chunks = [
|
||||
{"content": "def first():\n return 1", "metadata": {"line_count": 2}},
|
||||
"def second():\n return 2", # Plain string
|
||||
{"text": "def third():\n return 3"}, # Old format
|
||||
{"content": "class MyClass:\n pass", "metadata": {"node_count": 1}},
|
||||
]
|
||||
mock_builder.chunkify.return_value = mixed_chunks
|
||||
|
||||
# Create mock document
|
||||
code = "def first():\n return 1\n\ndef second():\n return 2\n\ndef third():\n return 3\n\nclass MyClass:\n pass"
|
||||
doc = MockDocument(code, "/test/mixed.py", {"language": "python"})
|
||||
|
||||
# Mock the astchunk module
|
||||
mock_astchunk = Mock()
|
||||
mock_astchunk.ASTChunkBuilder = Mock(return_value=mock_builder)
|
||||
|
||||
with patch.dict("sys.modules", {"astchunk": mock_astchunk}):
|
||||
# Call create_ast_chunks
|
||||
chunks = create_ast_chunks([doc])
|
||||
|
||||
# R3: Should extract all chunks correctly as dicts
|
||||
assert len(chunks) == 4, "Should extract all 4 chunks"
|
||||
|
||||
# Check each chunk
|
||||
for i, chunk in enumerate(chunks):
|
||||
assert isinstance(chunk, dict), f"Chunk {i} should be a dict"
|
||||
assert "text" in chunk, f"Chunk {i} should have 'text' key"
|
||||
assert "metadata" in chunk, f"Chunk {i} should have 'metadata' key"
|
||||
|
||||
chunk_text = chunk["text"]
|
||||
# None should be stringified dicts
|
||||
assert "'content':" not in chunk_text, f"Chunk {i} text is stringified (has 'content':)"
|
||||
assert "'metadata':" not in chunk_text, (
|
||||
f"Chunk {i} text is stringified (has 'metadata':)"
|
||||
)
|
||||
assert "'text':" not in chunk_text, f"Chunk {i} text is stringified (has 'text':)"
|
||||
|
||||
# Verify actual content is present
|
||||
combined = "\n".join([c["text"] for c in chunks])
|
||||
assert "def first()" in combined
|
||||
assert "def second()" in combined
|
||||
assert "def third()" in combined
|
||||
assert "class MyClass:" in combined
|
||||
|
||||
def test_empty_content_value_handling(self):
|
||||
"""Test handling of chunks with empty content values.
|
||||
|
||||
Edge case: chunk has 'content' key but value is empty.
|
||||
Should skip these chunks, not stringify them.
|
||||
"""
|
||||
mock_builder = Mock()
|
||||
|
||||
chunks_with_empty = [
|
||||
{"content": "", "metadata": {"line_count": 0}}, # Empty content
|
||||
{"content": " ", "metadata": {"line_count": 1}}, # Whitespace only
|
||||
{"content": "def valid():\n return True", "metadata": {"line_count": 2}}, # Valid
|
||||
]
|
||||
mock_builder.chunkify.return_value = chunks_with_empty
|
||||
|
||||
doc = MockDocument(
|
||||
"def valid():\n return True", "/test/empty.py", {"language": "python"}
|
||||
)
|
||||
|
||||
# Mock the astchunk module
|
||||
mock_astchunk = Mock()
|
||||
mock_astchunk.ASTChunkBuilder = Mock(return_value=mock_builder)
|
||||
|
||||
with patch.dict("sys.modules", {"astchunk": mock_astchunk}):
|
||||
chunks = create_ast_chunks([doc])
|
||||
|
||||
# R3: Should only have the valid chunk (empty ones filtered out)
|
||||
assert len(chunks) == 1, "Should filter out empty content chunks"
|
||||
|
||||
chunk = chunks[0]
|
||||
assert isinstance(chunk, dict), "Chunk should be a dict"
|
||||
assert "text" in chunk, "Chunk should have 'text' key"
|
||||
assert "def valid()" in chunk["text"]
|
||||
|
||||
# Should not have stringified the empty dict
|
||||
assert "'content': ''" not in chunk["text"]
|
||||
|
||||
|
||||
class TestASTMetadataPreservation:
|
||||
"""Test metadata preservation in AST chunk dictionaries.
|
||||
|
||||
R3: These tests define the contract for metadata preservation when returning
|
||||
chunk dictionaries instead of plain strings. Each chunk dict should have:
|
||||
- "text": str - the actual chunk content
|
||||
- "metadata": dict - all metadata from document AND astchunk
|
||||
|
||||
These tests will FAIL until G3 implementation changes return type to list[dict].
|
||||
"""
|
||||
|
||||
def test_ast_chunks_preserve_file_metadata(self):
|
||||
"""Test that document metadata is preserved in chunk metadata.
|
||||
|
||||
This test verifies that all document-level metadata (file_path, file_name,
|
||||
creation_date, last_modified_date) is included in each chunk's metadata dict.
|
||||
|
||||
This will FAIL because current code returns list[str], not list[dict].
|
||||
"""
|
||||
# Create mock document with rich metadata
|
||||
python_code = '''
|
||||
def calculate_sum(numbers):
|
||||
"""Calculate sum of numbers."""
|
||||
return sum(numbers)
|
||||
|
||||
class DataProcessor:
|
||||
"""Process data records."""
|
||||
|
||||
def process(self, data):
|
||||
return [x * 2 for x in data]
|
||||
'''
|
||||
doc = MockDocument(
|
||||
python_code,
|
||||
file_path="/project/src/utils.py",
|
||||
metadata={
|
||||
"language": "python",
|
||||
"file_path": "/project/src/utils.py",
|
||||
"file_name": "utils.py",
|
||||
"creation_date": "2024-01-15T10:30:00",
|
||||
"last_modified_date": "2024-10-31T15:45:00",
|
||||
},
|
||||
)
|
||||
|
||||
# Mock astchunk to return chunks with metadata
|
||||
mock_builder = Mock()
|
||||
astchunk_chunks = [
|
||||
{
|
||||
"content": "def calculate_sum(numbers):\n return sum(numbers)",
|
||||
"metadata": {
|
||||
"filepath": "/project/src/utils.py",
|
||||
"line_count": 2,
|
||||
"start_line_no": 1,
|
||||
"end_line_no": 2,
|
||||
"node_count": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
"content": "class DataProcessor:\n def process(self, data):\n return [x * 2 for x in data]",
|
||||
"metadata": {
|
||||
"filepath": "/project/src/utils.py",
|
||||
"line_count": 3,
|
||||
"start_line_no": 5,
|
||||
"end_line_no": 7,
|
||||
"node_count": 2,
|
||||
},
|
||||
},
|
||||
]
|
||||
mock_builder.chunkify.return_value = astchunk_chunks
|
||||
|
||||
mock_astchunk = Mock()
|
||||
mock_astchunk.ASTChunkBuilder = Mock(return_value=mock_builder)
|
||||
|
||||
with patch.dict("sys.modules", {"astchunk": mock_astchunk}):
|
||||
chunks = create_ast_chunks([doc])
|
||||
|
||||
# CRITICAL: These assertions will FAIL with current list[str] return type
|
||||
assert len(chunks) == 2, "Should return 2 chunks"
|
||||
|
||||
for i, chunk in enumerate(chunks):
|
||||
# Structure assertions - WILL FAIL: current code returns strings
|
||||
assert isinstance(chunk, dict), f"Chunk {i} should be dict, got {type(chunk)}"
|
||||
assert "text" in chunk, f"Chunk {i} must have 'text' key"
|
||||
assert "metadata" in chunk, f"Chunk {i} must have 'metadata' key"
|
||||
assert isinstance(chunk["metadata"], dict), f"Chunk {i} metadata should be dict"
|
||||
|
||||
# Document metadata preservation - WILL FAIL
|
||||
metadata = chunk["metadata"]
|
||||
assert "file_path" in metadata, f"Chunk {i} should preserve file_path"
|
||||
assert metadata["file_path"] == "/project/src/utils.py", (
|
||||
f"Chunk {i} file_path incorrect"
|
||||
)
|
||||
|
||||
assert "file_name" in metadata, f"Chunk {i} should preserve file_name"
|
||||
assert metadata["file_name"] == "utils.py", f"Chunk {i} file_name incorrect"
|
||||
|
||||
assert "creation_date" in metadata, f"Chunk {i} should preserve creation_date"
|
||||
assert metadata["creation_date"] == "2024-01-15T10:30:00", (
|
||||
f"Chunk {i} creation_date incorrect"
|
||||
)
|
||||
|
||||
assert "last_modified_date" in metadata, f"Chunk {i} should preserve last_modified_date"
|
||||
assert metadata["last_modified_date"] == "2024-10-31T15:45:00", (
|
||||
f"Chunk {i} last_modified_date incorrect"
|
||||
)
|
||||
|
||||
# Verify metadata is consistent across chunks from same document
|
||||
assert chunks[0]["metadata"]["file_path"] == chunks[1]["metadata"]["file_path"], (
|
||||
"All chunks from same document should have same file_path"
|
||||
)
|
||||
|
||||
# Verify text content is present and not stringified
|
||||
assert "def calculate_sum" in chunks[0]["text"]
|
||||
assert "class DataProcessor" in chunks[1]["text"]
|
||||
|
||||
def test_ast_chunks_include_astchunk_metadata(self):
|
||||
"""Test that astchunk-specific metadata is merged into chunk metadata.
|
||||
|
||||
This test verifies that astchunk's metadata (line_count, start_line_no,
|
||||
end_line_no, node_count) is merged with document metadata.
|
||||
|
||||
This will FAIL because current code returns list[str], not list[dict].
|
||||
"""
|
||||
python_code = '''
|
||||
def function_one():
|
||||
"""First function."""
|
||||
x = 1
|
||||
y = 2
|
||||
return x + y
|
||||
|
||||
def function_two():
|
||||
"""Second function."""
|
||||
return 42
|
||||
'''
|
||||
doc = MockDocument(
|
||||
python_code,
|
||||
file_path="/test/code.py",
|
||||
metadata={
|
||||
"language": "python",
|
||||
"file_path": "/test/code.py",
|
||||
"file_name": "code.py",
|
||||
},
|
||||
)
|
||||
|
||||
# Mock astchunk with detailed metadata
|
||||
mock_builder = Mock()
|
||||
astchunk_chunks = [
|
||||
{
|
||||
"content": "def function_one():\n x = 1\n y = 2\n return x + y",
|
||||
"metadata": {
|
||||
"filepath": "/test/code.py",
|
||||
"line_count": 4,
|
||||
"start_line_no": 1,
|
||||
"end_line_no": 4,
|
||||
"node_count": 5, # function, assignments, return
|
||||
},
|
||||
},
|
||||
{
|
||||
"content": "def function_two():\n return 42",
|
||||
"metadata": {
|
||||
"filepath": "/test/code.py",
|
||||
"line_count": 2,
|
||||
"start_line_no": 7,
|
||||
"end_line_no": 8,
|
||||
"node_count": 2, # function, return
|
||||
},
|
||||
},
|
||||
]
|
||||
mock_builder.chunkify.return_value = astchunk_chunks
|
||||
|
||||
mock_astchunk = Mock()
|
||||
mock_astchunk.ASTChunkBuilder = Mock(return_value=mock_builder)
|
||||
|
||||
with patch.dict("sys.modules", {"astchunk": mock_astchunk}):
|
||||
chunks = create_ast_chunks([doc])
|
||||
|
||||
# CRITICAL: These will FAIL with current list[str] return
|
||||
assert len(chunks) == 2
|
||||
|
||||
# First chunk - function_one
|
||||
chunk1 = chunks[0]
|
||||
assert isinstance(chunk1, dict), "Chunk should be dict"
|
||||
assert "metadata" in chunk1
|
||||
|
||||
metadata1 = chunk1["metadata"]
|
||||
|
||||
# Check astchunk metadata is present
|
||||
assert "line_count" in metadata1, "Should include astchunk line_count"
|
||||
assert metadata1["line_count"] == 4, "line_count should be 4"
|
||||
|
||||
assert "start_line_no" in metadata1, "Should include astchunk start_line_no"
|
||||
assert metadata1["start_line_no"] == 1, "start_line_no should be 1"
|
||||
|
||||
assert "end_line_no" in metadata1, "Should include astchunk end_line_no"
|
||||
assert metadata1["end_line_no"] == 4, "end_line_no should be 4"
|
||||
|
||||
assert "node_count" in metadata1, "Should include astchunk node_count"
|
||||
assert metadata1["node_count"] == 5, "node_count should be 5"
|
||||
|
||||
# Second chunk - function_two
|
||||
chunk2 = chunks[1]
|
||||
metadata2 = chunk2["metadata"]
|
||||
|
||||
assert metadata2["line_count"] == 2, "line_count should be 2"
|
||||
assert metadata2["start_line_no"] == 7, "start_line_no should be 7"
|
||||
assert metadata2["end_line_no"] == 8, "end_line_no should be 8"
|
||||
assert metadata2["node_count"] == 2, "node_count should be 2"
|
||||
|
||||
# Verify document metadata is ALSO present (merged, not replaced)
|
||||
assert metadata1["file_path"] == "/test/code.py"
|
||||
assert metadata1["file_name"] == "code.py"
|
||||
assert metadata2["file_path"] == "/test/code.py"
|
||||
assert metadata2["file_name"] == "code.py"
|
||||
|
||||
# Verify text content is correct
|
||||
assert "def function_one" in chunk1["text"]
|
||||
assert "def function_two" in chunk2["text"]
|
||||
|
||||
def test_traditional_chunks_as_dicts_helper(self):
|
||||
"""Test the helper function that wraps traditional chunks as dicts.
|
||||
|
||||
This test verifies that when create_traditional_chunks is called,
|
||||
its plain string chunks are wrapped into dict format with metadata.
|
||||
|
||||
This will FAIL because the helper function _traditional_chunks_as_dicts()
|
||||
doesn't exist yet, and create_traditional_chunks returns list[str].
|
||||
"""
|
||||
# Create documents with various metadata
|
||||
docs = [
|
||||
MockDocument(
|
||||
"This is the first paragraph of text. It contains multiple sentences. "
|
||||
"This should be split into chunks based on size.",
|
||||
file_path="/docs/readme.txt",
|
||||
metadata={
|
||||
"file_path": "/docs/readme.txt",
|
||||
"file_name": "readme.txt",
|
||||
"creation_date": "2024-01-01",
|
||||
},
|
||||
),
|
||||
MockDocument(
|
||||
"Second document with different metadata. It also has content that needs chunking.",
|
||||
file_path="/docs/guide.md",
|
||||
metadata={
|
||||
"file_path": "/docs/guide.md",
|
||||
"file_name": "guide.md",
|
||||
"last_modified_date": "2024-10-31",
|
||||
},
|
||||
),
|
||||
]
|
||||
|
||||
# Call create_traditional_chunks (which should now return list[dict])
|
||||
chunks = create_traditional_chunks(docs, chunk_size=50, chunk_overlap=10)
|
||||
|
||||
# CRITICAL: Will FAIL - current code returns list[str]
|
||||
assert len(chunks) > 0, "Should return chunks"
|
||||
|
||||
for i, chunk in enumerate(chunks):
|
||||
# Structure assertions - WILL FAIL
|
||||
assert isinstance(chunk, dict), f"Chunk {i} should be dict, got {type(chunk)}"
|
||||
assert "text" in chunk, f"Chunk {i} must have 'text' key"
|
||||
assert "metadata" in chunk, f"Chunk {i} must have 'metadata' key"
|
||||
|
||||
# Text should be non-empty
|
||||
assert len(chunk["text"].strip()) > 0, f"Chunk {i} text should be non-empty"
|
||||
|
||||
# Metadata should include document info
|
||||
metadata = chunk["metadata"]
|
||||
assert "file_path" in metadata, f"Chunk {i} should have file_path in metadata"
|
||||
assert "file_name" in metadata, f"Chunk {i} should have file_name in metadata"
|
||||
|
||||
# Verify metadata tracking works correctly
|
||||
# At least one chunk should be from readme.txt
|
||||
readme_chunks = [c for c in chunks if "readme.txt" in c["metadata"]["file_name"]]
|
||||
assert len(readme_chunks) > 0, "Should have chunks from readme.txt"
|
||||
|
||||
# At least one chunk should be from guide.md
|
||||
guide_chunks = [c for c in chunks if "guide.md" in c["metadata"]["file_name"]]
|
||||
assert len(guide_chunks) > 0, "Should have chunks from guide.md"
|
||||
|
||||
# Verify creation_date is preserved for readme chunks
|
||||
for chunk in readme_chunks:
|
||||
assert chunk["metadata"].get("creation_date") == "2024-01-01", (
|
||||
"readme.txt chunks should preserve creation_date"
|
||||
)
|
||||
|
||||
# Verify last_modified_date is preserved for guide chunks
|
||||
for chunk in guide_chunks:
|
||||
assert chunk["metadata"].get("last_modified_date") == "2024-10-31", (
|
||||
"guide.md chunks should preserve last_modified_date"
|
||||
)
|
||||
|
||||
# Verify text content is present
|
||||
all_text = " ".join([c["text"] for c in chunks])
|
||||
assert "first paragraph" in all_text
|
||||
assert "Second document" in all_text
|
||||
|
||||
|
||||
class TestErrorHandling:
|
||||
"""Test error handling and edge cases."""
|
||||
|
||||
|
||||
268
tests/test_token_truncation.py
Normal file
268
tests/test_token_truncation.py
Normal file
@@ -0,0 +1,268 @@
|
||||
"""Unit tests for token-aware truncation functionality.
|
||||
|
||||
This test suite defines the contract for token truncation functions that prevent
|
||||
500 errors from Ollama when text exceeds model token limits. These tests verify:
|
||||
|
||||
1. Model token limit retrieval (known and unknown models)
|
||||
2. Text truncation behavior for single and multiple texts
|
||||
3. Token counting and truncation accuracy using tiktoken
|
||||
|
||||
All tests are written in Red Phase - they should FAIL initially because the
|
||||
implementation does not exist yet.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import tiktoken
|
||||
from leann.embedding_compute import (
|
||||
EMBEDDING_MODEL_LIMITS,
|
||||
get_model_token_limit,
|
||||
truncate_to_token_limit,
|
||||
)
|
||||
|
||||
|
||||
class TestModelTokenLimits:
|
||||
"""Tests for retrieving model-specific token limits."""
|
||||
|
||||
def test_get_model_token_limit_known_model(self):
|
||||
"""Verify correct token limit is returned for known models.
|
||||
|
||||
Known models should return their specific token limits from
|
||||
EMBEDDING_MODEL_LIMITS dictionary.
|
||||
"""
|
||||
# Test nomic-embed-text (2048 tokens)
|
||||
limit = get_model_token_limit("nomic-embed-text")
|
||||
assert limit == 2048, "nomic-embed-text should have 2048 token limit"
|
||||
|
||||
# Test nomic-embed-text-v1.5 (2048 tokens)
|
||||
limit = get_model_token_limit("nomic-embed-text-v1.5")
|
||||
assert limit == 2048, "nomic-embed-text-v1.5 should have 2048 token limit"
|
||||
|
||||
# Test nomic-embed-text-v2 (512 tokens)
|
||||
limit = get_model_token_limit("nomic-embed-text-v2")
|
||||
assert limit == 512, "nomic-embed-text-v2 should have 512 token limit"
|
||||
|
||||
# Test OpenAI models (8192 tokens)
|
||||
limit = get_model_token_limit("text-embedding-3-small")
|
||||
assert limit == 8192, "text-embedding-3-small should have 8192 token limit"
|
||||
|
||||
def test_get_model_token_limit_unknown_model(self):
|
||||
"""Verify default token limit is returned for unknown models.
|
||||
|
||||
Unknown models should return the default limit (2048) to allow
|
||||
operation with reasonable safety margin.
|
||||
"""
|
||||
# Test with completely unknown model
|
||||
limit = get_model_token_limit("unknown-model-xyz")
|
||||
assert limit == 2048, "Unknown models should return default 2048"
|
||||
|
||||
# Test with empty string
|
||||
limit = get_model_token_limit("")
|
||||
assert limit == 2048, "Empty model name should return default 2048"
|
||||
|
||||
def test_get_model_token_limit_custom_default(self):
|
||||
"""Verify custom default can be specified for unknown models.
|
||||
|
||||
Allow callers to specify their own default token limit when
|
||||
model is not in the known models dictionary.
|
||||
"""
|
||||
limit = get_model_token_limit("unknown-model", default=4096)
|
||||
assert limit == 4096, "Should return custom default for unknown models"
|
||||
|
||||
# Known model should ignore custom default
|
||||
limit = get_model_token_limit("nomic-embed-text", default=4096)
|
||||
assert limit == 2048, "Known model should ignore custom default"
|
||||
|
||||
def test_embedding_model_limits_dictionary_exists(self):
|
||||
"""Verify EMBEDDING_MODEL_LIMITS dictionary contains expected models.
|
||||
|
||||
The dictionary should be importable and contain at least the
|
||||
known nomic models with correct token limits.
|
||||
"""
|
||||
assert isinstance(EMBEDDING_MODEL_LIMITS, dict), "Should be a dictionary"
|
||||
assert "nomic-embed-text" in EMBEDDING_MODEL_LIMITS, "Should contain nomic-embed-text"
|
||||
assert "nomic-embed-text-v1.5" in EMBEDDING_MODEL_LIMITS, (
|
||||
"Should contain nomic-embed-text-v1.5"
|
||||
)
|
||||
assert EMBEDDING_MODEL_LIMITS["nomic-embed-text"] == 2048
|
||||
assert EMBEDDING_MODEL_LIMITS["nomic-embed-text-v1.5"] == 2048
|
||||
assert EMBEDDING_MODEL_LIMITS["nomic-embed-text-v2"] == 512
|
||||
# OpenAI models
|
||||
assert EMBEDDING_MODEL_LIMITS["text-embedding-3-small"] == 8192
|
||||
|
||||
|
||||
class TestTokenTruncation:
|
||||
"""Tests for truncating texts to token limits."""
|
||||
|
||||
@pytest.fixture
|
||||
def tokenizer(self):
|
||||
"""Provide tiktoken tokenizer for token counting verification."""
|
||||
return tiktoken.get_encoding("cl100k_base")
|
||||
|
||||
def test_truncate_single_text_under_limit(self, tokenizer):
|
||||
"""Verify text under token limit remains unchanged.
|
||||
|
||||
When text is already within the token limit, it should be
|
||||
returned unchanged with no truncation.
|
||||
"""
|
||||
text = "This is a short text that is well under the token limit."
|
||||
token_count = len(tokenizer.encode(text))
|
||||
assert token_count < 100, f"Test setup: text should be short (has {token_count} tokens)"
|
||||
|
||||
# Truncate with generous limit
|
||||
result = truncate_to_token_limit([text], token_limit=512)
|
||||
|
||||
assert len(result) == 1, "Should return same number of texts"
|
||||
assert result[0] == text, "Text under limit should be unchanged"
|
||||
|
||||
def test_truncate_single_text_over_limit(self, tokenizer):
|
||||
"""Verify text over token limit is truncated correctly.
|
||||
|
||||
When text exceeds the token limit, it should be truncated to
|
||||
fit within the limit while maintaining valid token boundaries.
|
||||
"""
|
||||
# Create a text that definitely exceeds limit
|
||||
text = "word " * 200 # ~200 tokens (each "word " is typically 1-2 tokens)
|
||||
original_token_count = len(tokenizer.encode(text))
|
||||
assert original_token_count > 50, (
|
||||
f"Test setup: text should be long (has {original_token_count} tokens)"
|
||||
)
|
||||
|
||||
# Truncate to 50 tokens
|
||||
result = truncate_to_token_limit([text], token_limit=50)
|
||||
|
||||
assert len(result) == 1, "Should return same number of texts"
|
||||
assert result[0] != text, "Text over limit should be truncated"
|
||||
assert len(result[0]) < len(text), "Truncated text should be shorter"
|
||||
|
||||
# Verify truncated text is within token limit
|
||||
truncated_token_count = len(tokenizer.encode(result[0]))
|
||||
assert truncated_token_count <= 50, (
|
||||
f"Truncated text should be ≤50 tokens, got {truncated_token_count}"
|
||||
)
|
||||
|
||||
def test_truncate_multiple_texts_mixed_lengths(self, tokenizer):
|
||||
"""Verify multiple texts with mixed lengths are handled correctly.
|
||||
|
||||
When processing multiple texts:
|
||||
- Texts under limit should remain unchanged
|
||||
- Texts over limit should be truncated independently
|
||||
- Output list should maintain same order and length
|
||||
"""
|
||||
texts = [
|
||||
"Short text.", # Under limit
|
||||
"word " * 200, # Over limit
|
||||
"Another short one.", # Under limit
|
||||
"token " * 150, # Over limit
|
||||
]
|
||||
|
||||
# Verify test setup
|
||||
for i, text in enumerate(texts):
|
||||
token_count = len(tokenizer.encode(text))
|
||||
if i in [1, 3]:
|
||||
assert token_count > 50, f"Text {i} should be over limit (has {token_count} tokens)"
|
||||
else:
|
||||
assert token_count < 50, (
|
||||
f"Text {i} should be under limit (has {token_count} tokens)"
|
||||
)
|
||||
|
||||
# Truncate with 50 token limit
|
||||
result = truncate_to_token_limit(texts, token_limit=50)
|
||||
|
||||
assert len(result) == len(texts), "Should return same number of texts"
|
||||
|
||||
# Verify each text individually
|
||||
for i, (original, truncated) in enumerate(zip(texts, result)):
|
||||
token_count = len(tokenizer.encode(truncated))
|
||||
assert token_count <= 50, f"Text {i} should be ≤50 tokens, got {token_count}"
|
||||
|
||||
# Short texts should be unchanged
|
||||
if i in [0, 2]:
|
||||
assert truncated == original, f"Short text {i} should be unchanged"
|
||||
# Long texts should be truncated
|
||||
else:
|
||||
assert len(truncated) < len(original), f"Long text {i} should be truncated"
|
||||
|
||||
def test_truncate_empty_list(self):
|
||||
"""Verify empty input list returns empty output list.
|
||||
|
||||
Edge case: empty list should return empty list without errors.
|
||||
"""
|
||||
result = truncate_to_token_limit([], token_limit=512)
|
||||
assert result == [], "Empty input should return empty output"
|
||||
|
||||
def test_truncate_preserves_order(self, tokenizer):
|
||||
"""Verify truncation preserves original text order.
|
||||
|
||||
Output list should maintain the same order as input list,
|
||||
regardless of which texts were truncated.
|
||||
"""
|
||||
texts = [
|
||||
"First text " * 50, # Will be truncated
|
||||
"Second text.", # Won't be truncated
|
||||
"Third text " * 50, # Will be truncated
|
||||
]
|
||||
|
||||
result = truncate_to_token_limit(texts, token_limit=20)
|
||||
|
||||
assert len(result) == 3, "Should preserve list length"
|
||||
# Check that order is maintained by looking for distinctive words
|
||||
assert "First" in result[0], "First text should remain in first position"
|
||||
assert "Second" in result[1], "Second text should remain in second position"
|
||||
assert "Third" in result[2], "Third text should remain in third position"
|
||||
|
||||
def test_truncate_extremely_long_text(self, tokenizer):
|
||||
"""Verify extremely long texts are truncated efficiently.
|
||||
|
||||
Test with text that far exceeds token limit to ensure
|
||||
truncation handles extreme cases without performance issues.
|
||||
"""
|
||||
# Create very long text (simulate real-world scenario)
|
||||
text = "token " * 5000 # ~5000+ tokens
|
||||
original_token_count = len(tokenizer.encode(text))
|
||||
assert original_token_count > 1000, "Test setup: text should be very long"
|
||||
|
||||
# Truncate to small limit
|
||||
result = truncate_to_token_limit([text], token_limit=100)
|
||||
|
||||
assert len(result) == 1
|
||||
truncated_token_count = len(tokenizer.encode(result[0]))
|
||||
assert truncated_token_count <= 100, (
|
||||
f"Should truncate to ≤100 tokens, got {truncated_token_count}"
|
||||
)
|
||||
assert len(result[0]) < len(text) // 10, "Should significantly reduce text length"
|
||||
|
||||
def test_truncate_exact_token_limit(self, tokenizer):
|
||||
"""Verify text at exactly token limit is handled correctly.
|
||||
|
||||
Edge case: text with exactly the token limit should either
|
||||
remain unchanged or be safely truncated by 1 token.
|
||||
"""
|
||||
# Create text with approximately 50 tokens
|
||||
# We'll adjust to get exactly 50
|
||||
target_tokens = 50
|
||||
text = "word " * 50
|
||||
tokens = tokenizer.encode(text)
|
||||
|
||||
# Adjust to get exactly target_tokens
|
||||
if len(tokens) > target_tokens:
|
||||
tokens = tokens[:target_tokens]
|
||||
text = tokenizer.decode(tokens)
|
||||
elif len(tokens) < target_tokens:
|
||||
# Add more words
|
||||
while len(tokenizer.encode(text)) < target_tokens:
|
||||
text += "word "
|
||||
tokens = tokenizer.encode(text)[:target_tokens]
|
||||
text = tokenizer.decode(tokens)
|
||||
|
||||
# Verify we have exactly target_tokens
|
||||
assert len(tokenizer.encode(text)) == target_tokens, (
|
||||
"Test setup: should have exactly 50 tokens"
|
||||
)
|
||||
|
||||
result = truncate_to_token_limit([text], token_limit=target_tokens)
|
||||
|
||||
assert len(result) == 1
|
||||
result_tokens = len(tokenizer.encode(result[0]))
|
||||
assert result_tokens <= target_tokens, (
|
||||
f"Should be ≤{target_tokens} tokens, got {result_tokens}"
|
||||
)
|
||||
Reference in New Issue
Block a user