"""
|
|
Enhanced chunking utilities with AST-aware code chunking support.
|
|
Packaged within leann-core so installed wheels can import it reliably.
|
|
"""
|
|
|
|
import logging
|
|
from pathlib import Path
|
|
from typing import Optional
|
|
|
|
from llama_index.core.node_parser import SentenceSplitter
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def estimate_token_count(text: str) -> int:
    """
    Estimate token count for a text string.

    Prefers exact counts from tiktoken's cl100k_base encoding. If tiktoken is not
    installed, falls back to a conservative character-based estimate of ~1.2 tokens
    per character, a worst case for code, which tokenizes less efficiently than the
    ~4 characters per token typical of natural text.

    Args:
        text: Input text to estimate tokens for

    Returns:
        Estimated token count
    """
    try:
        import tiktoken

        encoder = tiktoken.get_encoding("cl100k_base")
        return len(encoder.encode(text))
    except ImportError:
        # Fallback: conservative character-based estimation.
        # Assume the worst case for code: 1.2 tokens per character.
        return int(len(text) * 1.2)

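
# Illustrative behaviour of estimate_token_count (a sketch, not a guarantee: exact
# counts depend on the installed tiktoken version, and the ImportError fallback is
# deliberately pessimistic so later size checks err on the safe side):
#
#     estimate_token_count("hello world")
#     # with tiktoken installed: the exact cl100k_base count (2 for this string)
#     # without tiktoken: int(len("hello world") * 1.2) == 13
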
def calculate_safe_chunk_size(
    model_token_limit: int,
    overlap_tokens: int,
    chunking_mode: str = "traditional",
    safety_factor: float = 0.9,
) -> int:
    """
    Calculate a safe chunk size accounting for overlap and a safety margin.

    Args:
        model_token_limit: Maximum tokens supported by the embedding model
        overlap_tokens: Overlap size in tokens (converted to an approximate
            character count for AST mode)
        chunking_mode: "traditional" (token-based sizes) or "ast" (character-based sizes)
        safety_factor: Safety margin (0.9 = 10% safety margin)

    Returns:
        Safe chunk size: tokens for traditional chunking, characters for AST chunking
    """
    safe_limit = int(model_token_limit * safety_factor)

    if chunking_mode == "traditional":
        # Traditional chunking is measured in tokens.
        # Max chunk = chunk_size + overlap, so chunk_size = limit - overlap.
        return max(1, safe_limit - overlap_tokens)
    else:  # AST chunking
        # AST chunk sizes are measured in characters, so convert the token budget.
        # Conservative estimates: ~1.2 tokens per character for code overall,
        # ~3 characters per token for the overlap.
        overlap_chars = int(overlap_tokens * 3)  # ~3 chars per token for code
        safe_chars = int(safe_limit / 1.2)
        return max(1, safe_chars - overlap_chars)

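
# Worked example (the 512-token limit and 64-token overlap are assumed values for
# illustration, e.g. a typical small embedding model):
#
#     calculate_safe_chunk_size(512, 64, "traditional")
#     # int(512 * 0.9) - 64 == 460 - 64 == 396 tokens
#
#     calculate_safe_chunk_size(512, 64, "ast")
#     # int(int(512 * 0.9) / 1.2) - int(64 * 3) == 383 - 192 == 191 characters
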
def validate_chunk_token_limits(chunks: list[str], max_tokens: int = 512) -> tuple[list[str], int]:
    """
    Validate that chunks don't exceed token limits and truncate if necessary.

    Args:
        chunks: List of text chunks to validate
        max_tokens: Maximum tokens allowed per chunk

    Returns:
        Tuple of (validated_chunks, num_truncated)
    """
    validated_chunks = []
    num_truncated = 0

    for i, chunk in enumerate(chunks):
        estimated_tokens = estimate_token_count(chunk)

        if estimated_tokens > max_tokens:
            # Truncate chunk to fit token limit
            try:
                import tiktoken

                encoder = tiktoken.get_encoding("cl100k_base")
                tokens = encoder.encode(chunk)
                if len(tokens) > max_tokens:
                    truncated_tokens = tokens[:max_tokens]
                    truncated_chunk = encoder.decode(truncated_tokens)
                    validated_chunks.append(truncated_chunk)
                    num_truncated += 1
                    logger.warning(
                        f"Truncated chunk {i} from {len(tokens)} to {max_tokens} tokens "
                        f"(from {len(chunk)} to {len(truncated_chunk)} characters)"
                    )
                else:
                    validated_chunks.append(chunk)
            except ImportError:
                # Fallback: conservative character-based truncation
                char_limit = int(max_tokens / 1.2)  # Conservative for code
                if len(chunk) > char_limit:
                    truncated_chunk = chunk[:char_limit]
                    validated_chunks.append(truncated_chunk)
                    num_truncated += 1
                    logger.warning(
                        f"Truncated chunk {i} from {len(chunk)} to {char_limit} characters "
                        f"(conservative estimate for {max_tokens} tokens)"
                    )
                else:
                    validated_chunks.append(chunk)
        else:
            validated_chunks.append(chunk)

    if num_truncated > 0:
        logger.warning(f"Truncated {num_truncated}/{len(chunks)} chunks to fit token limits")

    return validated_chunks, num_truncated

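
# Minimal usage sketch: oversized chunks come back truncated, everything else passes
# through untouched, and the second return value counts how many were cut.
#
#     chunks, n_cut = validate_chunk_token_limits(["short chunk", "word " * 2000], max_tokens=512)
#     # n_cut == 1; chunks[0] is unchanged, chunks[1] has been truncated to fit 512 tokens
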
# Code file extensions supported by astchunk
CODE_EXTENSIONS = {
    ".py": "python",
    ".java": "java",
    ".cs": "csharp",
    ".ts": "typescript",
    ".tsx": "typescript",
    ".js": "typescript",
    ".jsx": "typescript",
}


def detect_code_files(documents, code_extensions=None) -> tuple[list, list]:
    """Separate documents into code files and regular text files."""
    if code_extensions is None:
        code_extensions = CODE_EXTENSIONS

    code_docs = []
    text_docs = []

    for doc in documents:
        file_path = doc.metadata.get("file_path", "") or doc.metadata.get("file_name", "")
        if file_path:
            file_ext = Path(file_path).suffix.lower()
            if file_ext in code_extensions:
                doc.metadata["language"] = code_extensions[file_ext]
                doc.metadata["is_code"] = True
                code_docs.append(doc)
            else:
                doc.metadata["is_code"] = False
                text_docs.append(doc)
        else:
            doc.metadata["is_code"] = False
            text_docs.append(doc)

    logger.info(f"Detected {len(code_docs)} code files and {len(text_docs)} text files")
    return code_docs, text_docs

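
# Usage sketch (assumes LlamaIndex document objects; the Document class and metadata
# keys below follow the usual llama_index.core conventions and are shown purely for
# illustration):
#
#     from llama_index.core import Document
#     docs = [
#         Document(text="def f(): ...", metadata={"file_path": "src/utils.py"}),
#         Document(text="# Notes", metadata={"file_path": "README.md"}),
#     ]
#     code_docs, text_docs = detect_code_files(docs)
#     # code_docs[0].metadata["language"] == "python"; README.md lands in text_docs
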
def get_language_from_extension(file_path: str) -> Optional[str]:
    """Return language string from a filename/extension using CODE_EXTENSIONS."""
    ext = Path(file_path).suffix.lower()
    return CODE_EXTENSIONS.get(ext)

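
# For example:
#
#     get_language_from_extension("src/components/App.tsx")  # -> "typescript"
#     get_language_from_extension("notes.md")  # -> None
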
def create_ast_chunks(
    documents,
    max_chunk_size: int = 512,
    chunk_overlap: int = 64,
    metadata_template: str = "default",
) -> list[str]:
    """Create AST-aware chunks from code documents using astchunk.

    Falls back to traditional chunking if astchunk is unavailable.
    """
    try:
        from astchunk import ASTChunkBuilder  # optional dependency
    except ImportError as e:
        logger.error(f"astchunk not available: {e}")
        logger.info("Falling back to traditional chunking for code files")
        return create_traditional_chunks(documents, max_chunk_size, chunk_overlap)

    all_chunks = []
    for doc in documents:
        language = doc.metadata.get("language")
        if not language:
            logger.warning("No language detected; falling back to traditional chunking")
            all_chunks.extend(create_traditional_chunks([doc], max_chunk_size, chunk_overlap))
            continue

        try:
            # Warn if AST chunk size + overlap might exceed common token limits
            estimated_max_tokens = int((max_chunk_size + chunk_overlap) * 1.2)  # conservative estimate
            if estimated_max_tokens > 512:
                logger.warning(
                    f"AST chunk size ({max_chunk_size}) + overlap ({chunk_overlap}) = {max_chunk_size + chunk_overlap} chars "
                    f"may exceed 512 token limit (~{estimated_max_tokens} tokens estimated). "
                    f"Consider reducing --ast-chunk-size to {int(400 / 1.2)} or --ast-chunk-overlap to {int(50 / 1.2)}"
                )

            configs = {
                "max_chunk_size": max_chunk_size,
                "language": language,
                "metadata_template": metadata_template,
                "chunk_overlap": chunk_overlap if chunk_overlap > 0 else 0,
            }

            repo_metadata = {
                "file_path": doc.metadata.get("file_path", ""),
                "file_name": doc.metadata.get("file_name", ""),
                "creation_date": doc.metadata.get("creation_date", ""),
                "last_modified_date": doc.metadata.get("last_modified_date", ""),
            }
            configs["repo_level_metadata"] = repo_metadata

            chunk_builder = ASTChunkBuilder(**configs)
            code_content = doc.get_content()
            if not code_content or not code_content.strip():
                logger.warning("Empty code content, skipping")
                continue

            chunks = chunk_builder.chunkify(code_content)
            for chunk in chunks:
                if hasattr(chunk, "text"):
                    chunk_text = chunk.text
                elif isinstance(chunk, dict) and "text" in chunk:
                    chunk_text = chunk["text"]
                elif isinstance(chunk, str):
                    chunk_text = chunk
                else:
                    chunk_text = str(chunk)

                if chunk_text and chunk_text.strip():
                    all_chunks.append(chunk_text.strip())

            logger.info(
                f"Created {len(chunks)} AST chunks from {language} file: {doc.metadata.get('file_name', 'unknown')}"
            )
        except Exception as e:
            logger.warning(f"AST chunking failed for {language} file: {e}")
            logger.info("Falling back to traditional chunking")
            all_chunks.extend(create_traditional_chunks([doc], max_chunk_size, chunk_overlap))

    return all_chunks

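
# Usage sketch (requires the optional astchunk dependency; the document objects are
# the same kind used by detect_code_files above, and sizes are character counts in
# AST mode):
#
#     chunks = create_ast_chunks(code_docs, max_chunk_size=384, chunk_overlap=48)
#     # returns plain strings, one per AST-derived chunk; falls back to
#     # create_traditional_chunks when astchunk is missing or a file fails to parse
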
def create_traditional_chunks(
    documents, chunk_size: int = 256, chunk_overlap: int = 128
) -> list[str]:
    """Create traditional text chunks using LlamaIndex SentenceSplitter."""
    if chunk_size <= 0:
        logger.warning(f"Invalid chunk_size={chunk_size}, using default value of 256")
        chunk_size = 256
    if chunk_overlap < 0:
        chunk_overlap = 0
    if chunk_overlap >= chunk_size:
        chunk_overlap = chunk_size // 2

    node_parser = SentenceSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        separator=" ",
        paragraph_separator="\n\n",
    )

    all_texts = []
    for doc in documents:
        try:
            nodes = node_parser.get_nodes_from_documents([doc])
            if nodes:
                all_texts.extend(node.get_content() for node in nodes)
        except Exception as e:
            logger.error(f"Traditional chunking failed for document: {e}")
            content = doc.get_content()
            if content and content.strip():
                all_texts.append(content.strip())

    return all_texts

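
# Usage sketch (token-based sizes; the splitter is LlamaIndex's SentenceSplitter):
#
#     texts = create_traditional_chunks(text_docs, chunk_size=256, chunk_overlap=128)
#     # invalid sizes are coerced before splitting: chunk_size <= 0 becomes 256, a
#     # negative overlap becomes 0, and an overlap >= chunk_size becomes chunk_size // 2
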
def create_text_chunks(
    documents,
    chunk_size: int = 256,
    chunk_overlap: int = 128,
    use_ast_chunking: bool = False,
    ast_chunk_size: int = 512,
    ast_chunk_overlap: int = 64,
    code_file_extensions: Optional[list[str]] = None,
    ast_fallback_traditional: bool = True,
) -> list[str]:
    """Create text chunks from documents with optional AST support for code files."""
    if not documents:
        logger.warning("No documents provided for chunking")
        return []

    local_code_extensions = CODE_EXTENSIONS.copy()
    if code_file_extensions:
        ext_mapping = {
            ".py": "python",
            ".java": "java",
            ".cs": "csharp",
            ".ts": "typescript",
            ".tsx": "typescript",
        }
        for ext in code_file_extensions:
            if ext.lower() not in local_code_extensions:
                if ext.lower() in ext_mapping:
                    local_code_extensions[ext.lower()] = ext_mapping[ext.lower()]
                else:
                    logger.warning(f"Unsupported extension {ext}, will use traditional chunking")

    all_chunks = []
    if use_ast_chunking:
        code_docs, text_docs = detect_code_files(documents, local_code_extensions)
        if code_docs:
            try:
                all_chunks.extend(
                    create_ast_chunks(
                        code_docs, max_chunk_size=ast_chunk_size, chunk_overlap=ast_chunk_overlap
                    )
                )
            except Exception as e:
                logger.error(f"AST chunking failed: {e}")
                if ast_fallback_traditional:
                    all_chunks.extend(
                        create_traditional_chunks(code_docs, chunk_size, chunk_overlap)
                    )
                else:
                    raise
        if text_docs:
            all_chunks.extend(create_traditional_chunks(text_docs, chunk_size, chunk_overlap))
    else:
        all_chunks = create_traditional_chunks(documents, chunk_size, chunk_overlap)

    logger.info(f"Total chunks created: {len(all_chunks)}")

    # Validate chunk token limits (default to 512 for safety).
    # This provides a safety net for embedding models with token limits.
    validated_chunks, num_truncated = validate_chunk_token_limits(all_chunks, max_tokens=512)

    if num_truncated > 0:
        logger.info(
            f"Post-chunking validation: {num_truncated} chunks were truncated to fit 512 token limit"
        )

    return validated_chunks
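

if __name__ == "__main__":
    # Illustrative smoke test only, not part of the public API: builds two in-memory
    # documents and runs the end-to-end chunking pipeline. Assumes llama_index is
    # installed; AST chunking falls back to traditional chunking if astchunk is not.
    from llama_index.core import Document

    logging.basicConfig(level=logging.INFO)
    demo_docs = [
        Document(
            text="def add(a, b):\n    return a + b\n",
            metadata={"file_path": "demo/math_utils.py"},
        ),
        Document(
            text="LEANN stores vector indexes compactly. " * 20,
            metadata={"file_path": "demo/notes.txt"},
        ),
    ]
    for chunk in create_text_chunks(demo_docs, use_ast_chunking=True):
        print(repr(chunk[:80]))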