Fix chunking so that texts stay within the token limits of embedding models
This commit is contained in:
@@ -14,6 +14,89 @@ import torch
|
||||
|
||||
from .settings import resolve_ollama_host, resolve_openai_api_key, resolve_openai_base_url
|
||||
|
||||
|
||||
def _load_tiktoken_encoder():
    """Return tiktoken's ``cl100k_base`` encoder, or ``None`` when it is unusable.

    ``None`` is returned both when tiktoken is not installed and when the
    encoding cannot be loaded, so the caller can fall back to character-based
    truncation instead of failing.
    """
    try:
        import tiktoken
    except ImportError:
        return None

    try:
        return tiktoken.get_encoding("cl100k_base")
    except Exception as exc:
        # Loading an encoding can fail for reasons unrelated to the texts being
        # truncated (for example its data file is unavailable); that must not
        # abort embedding computation when a fallback exists.
        logger.warning(
            f"Could not load tiktoken encoding 'cl100k_base' ({exc}); "
            f"falling back to character-based truncation"
        )
        return None


def truncate_to_token_limit(texts: list[str], max_tokens: int = 512) -> list[str]:
    """
    Truncate texts to token limit using tiktoken or conservative character truncation.

    Tokens are counted with tiktoken's ``cl100k_base`` encoding when it is
    available. NOTE(review): this is presumably only an approximation of the
    target embedding model's own tokenizer - confirm the margin is sufficient.
    Without a usable encoder, each text is cut to ``int(max_tokens / 1.5)``
    characters instead.

    Args:
        texts: List of texts to truncate
        max_tokens: Maximum tokens allowed per text

    Returns:
        List of truncated texts that should fit within token limit. The result
        has the same length and order as ``texts``; texts already within the
        limit are returned unchanged.

    Raises:
        ValueError: If ``max_tokens`` is negative.
    """
    if max_tokens < 0:
        # A negative limit would slice from the end (tokens[:-n]) and silently
        # return nonsense, so reject it explicitly.
        raise ValueError(f"max_tokens must be non-negative, got {max_tokens}")

    truncated = []
    encoder = _load_tiktoken_encoder()

    if encoder is not None:
        for text in texts:
            # disallowed_special=() makes special-token strings such as
            # "<|endoftext|>" encode as ordinary text; the default raises
            # ValueError when they appear in the input.
            tokens = encoder.encode(text, disallowed_special=())
            if len(tokens) <= max_tokens:
                truncated.append(text)
                continue

            # Truncate to max_tokens and decode back to text
            truncated_text = encoder.decode(tokens[:max_tokens])
            truncated.append(truncated_text)
            logger.warning(
                f"Truncated text from {len(tokens)} to {max_tokens} tokens "
                f"(from {len(text)} to {len(truncated_text)} characters)"
            )
        return truncated

    # Fallback: Conservative character truncation
    # Assume worst case: 1.5 tokens per character for code content
    char_limit = int(max_tokens / 1.5)

    for text in texts:
        if len(text) <= char_limit:
            truncated.append(text)
            continue

        truncated.append(text[:char_limit])
        logger.warning(
            f"Truncated text from {len(text)} to {char_limit} characters "
            f"(conservative estimate for {max_tokens} tokens)"
        )
    return truncated
|
||||
|
||||
|
||||
def get_model_token_limit(model_name: str) -> int:
    """
    Get token limit for a given embedding model.

    Lookup order against ``EMBEDDING_MODEL_LIMITS``:

    1. The name exactly as given.
    2. The name with everything from the first ``:`` removed, first as given
       and then lowercased and stripped of surrounding whitespace.
    3. Partial match: a known model name contained in the requested base name
       (the longest such name wins), otherwise known model names that contain
       the requested base name (the smallest of their limits wins).

    Args:
        model_name: Name of the embedding model

    Returns:
        Token limit for the model, defaults to 512 if unknown
    """
    default_limit = 512

    # Handle versioned model names (e.g., "nomic-embed-text:latest" -> "nomic-embed-text")
    raw_base_name = model_name.split(":")[0]
    # Case-insensitive form so that e.g. "BGE-M3" resolves like "bge-m3".
    base_model_name = raw_base_name.strip().lower()

    # Exact matches first, from most to least literal spelling.
    for candidate in (model_name, raw_base_name, base_model_name):
        if candidate in EMBEDDING_MODEL_LIMITS:
            return EMBEDDING_MODEL_LIMITS[candidate]

    # An empty name is a substring of every key, so it must never reach the
    # partial matching below; it falls through to the "unknown model" default.
    if base_model_name:
        known_limits = {known.lower(): limit for known, limit in EMBEDDING_MODEL_LIMITS.items()}

        # Requested name extends a known one (e.g. a namespace prefix or a
        # variant suffix): the longest contained name is the most specific.
        contained = [known for known in known_limits if known and known in base_model_name]
        if contained:
            return known_limits[max(contained, key=len)]

        # Requested name abbreviates one or more known names: several entries
        # may match, so take the smallest limit to avoid over-long inputs.
        containing = [limit for known, limit in known_limits.items() if base_model_name in known]
        if containing:
            return min(containing)

    # Default to conservative 512 token limit
    logger.warning(f"Unknown model '{model_name}', using default 512 token limit")
    return default_limit
|
||||
|
||||
|
||||
# Set up logger with proper level
|
||||
logger = logging.getLogger(__name__)
|
||||
LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
|
||||
@@ -23,6 +106,17 @@ logger.setLevel(log_level)
|
||||
# Global model cache to avoid repeated loading
|
||||
_model_cache: dict[str, Any] = {}
|
||||
|
||||
# Known embedding model token limits.
# Maps a model name (written without a ":tag" suffix) to the maximum number of
# tokens allowed per input text. get_model_token_limit() consults this table
# and uses 512 for models that are not listed.
EMBEDDING_MODEL_LIMITS: dict[str, int] = {
    "nomic-embed-text": 512,
    "nomic-embed-text-v2": 512,
    "mxbai-embed-large": 512,
    "all-minilm": 512,
    "bge-m3": 8192,
    "snowflake-arctic-embed": 512,
    # Add more models as needed
}
|
||||
|
||||
|
||||
def compute_embeddings(
|
||||
texts: list[str],
|
||||
@@ -720,20 +814,28 @@ def compute_embeddings_ollama(
|
||||
|
||||
logger.info(f"Using batch size: {batch_size} for true batch processing")
|
||||
|
||||
# Get model token limit and apply truncation
|
||||
token_limit = get_model_token_limit(model_name)
|
||||
logger.info(f"Model '{model_name}' token limit: {token_limit}")
|
||||
|
||||
# Apply token-aware truncation to all texts
|
||||
truncated_texts = truncate_to_token_limit(texts, token_limit)
|
||||
if len(truncated_texts) != len(texts):
|
||||
logger.error("Truncation failed - text count mismatch")
|
||||
truncated_texts = texts # Fallback to original texts
|
||||
|
||||
def get_batch_embeddings(batch_texts):
|
||||
"""Get embeddings for a batch of texts using /api/embed endpoint."""
|
||||
max_retries = 3
|
||||
retry_count = 0
|
||||
|
||||
# Truncate very long texts to avoid API issues
|
||||
truncated_texts = [text[:8000] if len(text) > 8000 else text for text in batch_texts]
|
||||
|
||||
# Texts are already truncated to token limit by the outer function
|
||||
while retry_count < max_retries:
|
||||
try:
|
||||
# Use /api/embed endpoint with "input" parameter for batch processing
|
||||
response = requests.post(
|
||||
f"{resolved_host}/api/embed",
|
||||
json={"model": model_name, "input": truncated_texts},
|
||||
json={"model": model_name, "input": batch_texts},
|
||||
timeout=60, # Increased timeout for batch processing
|
||||
)
|
||||
response.raise_for_status()
|
||||
@@ -763,17 +865,27 @@ def compute_embeddings_ollama(
|
||||
except Exception as e:
|
||||
retry_count += 1
|
||||
if retry_count >= max_retries:
|
||||
logger.error(f"Failed to get embeddings for batch: {e}")
|
||||
# Enhanced error detection for token limit violations
|
||||
error_msg = str(e).lower()
|
||||
if "token" in error_msg and (
|
||||
"limit" in error_msg or "exceed" in error_msg or "length" in error_msg
|
||||
):
|
||||
logger.error(
|
||||
f"Token limit exceeded for batch. Error: {e}. "
|
||||
f"Consider reducing chunk sizes or check token truncation."
|
||||
)
|
||||
else:
|
||||
logger.error(f"Failed to get embeddings for batch: {e}")
|
||||
return None, list(range(len(batch_texts)))
|
||||
|
||||
return None, list(range(len(batch_texts)))
|
||||
|
||||
# Process texts in batches
|
||||
# Process truncated texts in batches
|
||||
all_embeddings = []
|
||||
all_failed_indices = []
|
||||
|
||||
# Setup progress bar if needed
|
||||
show_progress = is_build or len(texts) > 10
|
||||
show_progress = is_build or len(truncated_texts) > 10
|
||||
try:
|
||||
if show_progress:
|
||||
from tqdm import tqdm
|
||||
@@ -781,7 +893,7 @@ def compute_embeddings_ollama(
|
||||
show_progress = False
|
||||
|
||||
# Process batches
|
||||
num_batches = (len(texts) + batch_size - 1) // batch_size
|
||||
num_batches = (len(truncated_texts) + batch_size - 1) // batch_size
|
||||
|
||||
if show_progress:
|
||||
batch_iterator = tqdm(range(num_batches), desc="Computing Ollama embeddings (batched)")
|
||||
@@ -790,8 +902,8 @@ def compute_embeddings_ollama(
|
||||
|
||||
for batch_idx in batch_iterator:
|
||||
start_idx = batch_idx * batch_size
|
||||
end_idx = min(start_idx + batch_size, len(texts))
|
||||
batch_texts = texts[start_idx:end_idx]
|
||||
end_idx = min(start_idx + batch_size, len(truncated_texts))
|
||||
batch_texts = truncated_texts[start_idx:end_idx]
|
||||
|
||||
batch_embeddings, batch_failed = get_batch_embeddings(batch_texts)
|
||||
|
||||
@@ -806,11 +918,11 @@ def compute_embeddings_ollama(
|
||||
|
||||
# Handle failed embeddings
|
||||
if all_failed_indices:
|
||||
if len(all_failed_indices) == len(texts):
|
||||
if len(all_failed_indices) == len(truncated_texts):
|
||||
raise RuntimeError("Failed to compute any embeddings")
|
||||
|
||||
logger.warning(
|
||||
f"Failed to compute embeddings for {len(all_failed_indices)}/{len(texts)} texts"
|
||||
f"Failed to compute embeddings for {len(all_failed_indices)}/{len(truncated_texts)} texts"
|
||||
)
|
||||
|
||||
# Use zero embeddings as fallback for failed ones
|
||||
|
||||
Reference in New Issue
Block a user