Compare commits

..

7 Commits

Author SHA1 Message Date
yichuan520030910320
6c8801480d fall back to original faiss as i merge the PR 2025-10-30 16:36:14 -07:00
ww2283
d226f72bc0 feat: implement true batch processing for Ollama embeddings
Migrate from deprecated /api/embeddings to modern /api/embed endpoint
which supports batch inputs. This reduces HTTP overhead by sending
32 texts per request instead of making individual API calls.

Changes:
- Update endpoint from /api/embeddings to /api/embed
- Change parameter from 'prompt' (single) to 'input' (array)
- Update response parsing for batch embeddings array
- Increase timeout to 60s for batch processing
- Improve error handling for batch requests

Performance:
- Reduces API calls by 32x (batch size)
- Eliminates HTTP connection overhead per text
- Note: Ollama still processes batch items sequentially internally

Related: #151
2025-10-25 10:58:15 -04:00
ww2283
45b87ce128 Merge upstream/main into feature/add-metadata-output
Resolved conflicts in cli.py by keeping structured metadata approach over
inline text concatenation from PR #149.

Our approach uses separate metadata dictionary which is cleaner and more
maintainable than parsing embedded strings.
2025-10-25 10:53:19 -04:00
ww2283
585ef7785d chore: update faiss submodule to use ww2283 fork
Use ww2283/faiss fork with fix/zmq-linking branch to resolve CI checkout
failures. The ZMQ linking fixes are not yet merged upstream.
2025-10-25 10:44:48 -04:00
ww2283
5073f312b6 style: apply ruff formatting 2025-10-22 20:13:25 -04:00
ww2283
76e16338ca fix: resolve ZMQ linking issues in Python extension
- Use pkg_check_modules IMPORTED_TARGET to create PkgConfig::ZMQ
- Set PKG_CONFIG_PATH to prioritize ARM64 Homebrew on Apple Silicon
- Override macOS -undefined dynamic_lookup to force proper symbol resolution
- Use PUBLIC linkage for ZMQ in faiss library for transitive linking
- Mark cppzmq includes as SYSTEM to suppress warnings

Fixes editable install ZMQ symbol errors while maintaining compatibility
across Linux, macOS Intel, and macOS ARM64 platforms.
2025-10-22 18:53:13 -04:00
ww2283
d6a3c2821c feat: add metadata output to search results
- Add --show-metadata flag to display file paths in search results
- Preserve document metadata (file_path, file_name, timestamps) during chunking
- Update MCP tool schema to support show_metadata parameter
- Enhance CLI search output to display metadata when requested
- Fix pre-existing bug: args.backend -> args.backend_name

Resolves yichuan-w/LEANN#144
2025-10-22 14:10:47 -04:00
7 changed files with 25 additions and 278 deletions

3
.gitignore vendored
View File

@@ -105,6 +105,3 @@ apps/multimodal/vision-based-pdf-multi-vector/multi-vector-colpali-native-weavia
# The following line used to force-add a large demo PDF; remove it to satisfy pre-commit:
# !apps/multimodal/vision-based-pdf-multi-vector/pdfs/2004.12832v2.pdf
!apps/multimodal/vision-based-pdf-multi-vector/fig/*
# AUR build directory (Arch Linux)
paru-bin/

View File

@@ -1213,7 +1213,3 @@ This work is done at [**Berkeley Sky Computing Lab**](https://sky.cs.berkeley.ed
<p align="center">
Made with ❤️ by the Leann team
</p>
## 🤖 Explore LEANN with AI
LEANN is indexed on [DeepWiki](https://deepwiki.com/yichuan-w/LEANN), so you can ask questions to LLMs using Deep Research to explore the codebase and get help to add new features.

View File

@@ -180,14 +180,14 @@ class BaseRAGExample(ABC):
ast_group.add_argument(
"--ast-chunk-size",
type=int,
default=300,
help="Maximum CHARACTERS per AST chunk (default: 300). Final chunks may be larger due to overlap. For 512 token models: recommended 300 chars",
default=512,
help="Maximum characters per AST chunk (default: 512)",
)
ast_group.add_argument(
"--ast-chunk-overlap",
type=int,
default=64,
help="Overlap between AST chunks in CHARACTERS (default: 64). Added to chunk size, not included in it",
help="Overlap between AST chunks (default: 64)",
)
ast_group.add_argument(
"--code-file-extensions",

View File

@@ -11,119 +11,6 @@ from llama_index.core.node_parser import SentenceSplitter
logger = logging.getLogger(__name__)
def estimate_token_count(text: str) -> int:
"""
Estimate token count for a text string.
Uses conservative estimation: ~4 characters per token for natural text,
~1.2 tokens per character for code (worse tokenization).
Args:
text: Input text to estimate tokens for
Returns:
Estimated token count
"""
try:
import tiktoken
encoder = tiktoken.get_encoding("cl100k_base")
return len(encoder.encode(text))
except ImportError:
# Fallback: Conservative character-based estimation
# Assume worst case for code: 1.2 tokens per character
return int(len(text) * 1.2)
def calculate_safe_chunk_size(
model_token_limit: int,
overlap_tokens: int,
chunking_mode: str = "traditional",
safety_factor: float = 0.9,
) -> int:
"""
Calculate safe chunk size accounting for overlap and safety margin.
Args:
model_token_limit: Maximum tokens supported by embedding model
overlap_tokens: Overlap size (tokens for traditional, chars for AST)
chunking_mode: "traditional" (tokens) or "ast" (characters)
safety_factor: Safety margin (0.9 = 10% safety margin)
Returns:
Safe chunk size: tokens for traditional, characters for AST
"""
safe_limit = int(model_token_limit * safety_factor)
if chunking_mode == "traditional":
# Traditional chunking uses tokens
# Max chunk = chunk_size + overlap, so chunk_size = limit - overlap
return max(1, safe_limit - overlap_tokens)
else: # AST chunking
# AST uses characters, need to convert
# Conservative estimate: 1.2 tokens per char for code
overlap_chars = int(overlap_tokens * 3) # ~3 chars per token for code
safe_chars = int(safe_limit / 1.2)
return max(1, safe_chars - overlap_chars)
def validate_chunk_token_limits(chunks: list[str], max_tokens: int = 512) -> tuple[list[str], int]:
"""
Validate that chunks don't exceed token limits and truncate if necessary.
Args:
chunks: List of text chunks to validate
max_tokens: Maximum tokens allowed per chunk
Returns:
Tuple of (validated_chunks, num_truncated)
"""
validated_chunks = []
num_truncated = 0
for i, chunk in enumerate(chunks):
estimated_tokens = estimate_token_count(chunk)
if estimated_tokens > max_tokens:
# Truncate chunk to fit token limit
try:
import tiktoken
encoder = tiktoken.get_encoding("cl100k_base")
tokens = encoder.encode(chunk)
if len(tokens) > max_tokens:
truncated_tokens = tokens[:max_tokens]
truncated_chunk = encoder.decode(truncated_tokens)
validated_chunks.append(truncated_chunk)
num_truncated += 1
logger.warning(
f"Truncated chunk {i} from {len(tokens)} to {max_tokens} tokens "
f"(from {len(chunk)} to {len(truncated_chunk)} characters)"
)
else:
validated_chunks.append(chunk)
except ImportError:
# Fallback: Conservative character truncation
char_limit = int(max_tokens / 1.2) # Conservative for code
if len(chunk) > char_limit:
truncated_chunk = chunk[:char_limit]
validated_chunks.append(truncated_chunk)
num_truncated += 1
logger.warning(
f"Truncated chunk {i} from {len(chunk)} to {char_limit} characters "
f"(conservative estimate for {max_tokens} tokens)"
)
else:
validated_chunks.append(chunk)
else:
validated_chunks.append(chunk)
if num_truncated > 0:
logger.warning(f"Truncated {num_truncated}/{len(chunks)} chunks to fit token limits")
return validated_chunks, num_truncated
# Code file extensions supported by astchunk
CODE_EXTENSIONS = {
".py": "python",
@@ -195,17 +82,6 @@ def create_ast_chunks(
continue
try:
# Warn if AST chunk size + overlap might exceed common token limits
estimated_max_tokens = int(
(max_chunk_size + chunk_overlap) * 1.2
) # Conservative estimate
if estimated_max_tokens > 512:
logger.warning(
f"AST chunk size ({max_chunk_size}) + overlap ({chunk_overlap}) = {max_chunk_size + chunk_overlap} chars "
f"may exceed 512 token limit (~{estimated_max_tokens} tokens estimated). "
f"Consider reducing --ast-chunk-size to {int(400 / 1.2)} or --ast-chunk-overlap to {int(50 / 1.2)}"
)
configs = {
"max_chunk_size": max_chunk_size,
"language": language,
@@ -341,14 +217,4 @@ def create_text_chunks(
all_chunks = create_traditional_chunks(documents, chunk_size, chunk_overlap)
logger.info(f"Total chunks created: {len(all_chunks)}")
# Validate chunk token limits (default to 512 for safety)
# This provides a safety net for embedding models with token limits
validated_chunks, num_truncated = validate_chunk_token_limits(all_chunks, max_tokens=512)
if num_truncated > 0:
logger.info(
f"Post-chunking validation: {num_truncated} chunks were truncated to fit 512 token limit"
)
return validated_chunks
return all_chunks

View File

@@ -181,25 +181,25 @@ Examples:
"--doc-chunk-size",
type=int,
default=256,
help="Document chunk size in TOKENS (default: 256). Final chunks may be larger due to overlap. For 512 token models: recommended 350 tokens (350 + 128 overlap = 478 max)",
help="Document chunk size in tokens/characters (default: 256)",
)
build_parser.add_argument(
"--doc-chunk-overlap",
type=int,
default=128,
help="Document chunk overlap in TOKENS (default: 128). Added to chunk size, not included in it",
help="Document chunk overlap (default: 128)",
)
build_parser.add_argument(
"--code-chunk-size",
type=int,
default=512,
help="Code chunk size in TOKENS (default: 512). Final chunks may be larger due to overlap. For 512 token models: recommended 400 tokens (400 + 50 overlap = 450 max)",
help="Code chunk size in tokens/lines (default: 512)",
)
build_parser.add_argument(
"--code-chunk-overlap",
type=int,
default=50,
help="Code chunk overlap in TOKENS (default: 50). Added to chunk size, not included in it",
help="Code chunk overlap (default: 50)",
)
build_parser.add_argument(
"--use-ast-chunking",
@@ -209,14 +209,14 @@ Examples:
build_parser.add_argument(
"--ast-chunk-size",
type=int,
default=300,
help="AST chunk size in CHARACTERS (non-whitespace) (default: 300). Final chunks may be larger due to overlap and expansion. For 512 token models: recommended 300 chars (300 + 64 overlap ~= 480 tokens)",
default=768,
help="AST chunk size in characters (default: 768)",
)
build_parser.add_argument(
"--ast-chunk-overlap",
type=int,
default=64,
help="AST chunk overlap in CHARACTERS (default: 64). Added to chunk size, not included in it. ~1.2 tokens per character for code",
default=96,
help="AST chunk overlap in characters (default: 96)",
)
build_parser.add_argument(
"--ast-fallback-traditional",

View File

@@ -14,89 +14,6 @@ import torch
from .settings import resolve_ollama_host, resolve_openai_api_key, resolve_openai_base_url
def truncate_to_token_limit(texts: list[str], max_tokens: int = 512) -> list[str]:
"""
Truncate texts to token limit using tiktoken or conservative character truncation.
Args:
texts: List of texts to truncate
max_tokens: Maximum tokens allowed per text
Returns:
List of truncated texts that should fit within token limit
"""
try:
import tiktoken
encoder = tiktoken.get_encoding("cl100k_base")
truncated = []
for text in texts:
tokens = encoder.encode(text)
if len(tokens) > max_tokens:
# Truncate to max_tokens and decode back to text
truncated_tokens = tokens[:max_tokens]
truncated_text = encoder.decode(truncated_tokens)
truncated.append(truncated_text)
logger.warning(
f"Truncated text from {len(tokens)} to {max_tokens} tokens "
f"(from {len(text)} to {len(truncated_text)} characters)"
)
else:
truncated.append(text)
return truncated
except ImportError:
# Fallback: Conservative character truncation
# Assume worst case: 1.5 tokens per character for code content
char_limit = int(max_tokens / 1.5)
truncated = []
for text in texts:
if len(text) > char_limit:
truncated_text = text[:char_limit]
truncated.append(truncated_text)
logger.warning(
f"Truncated text from {len(text)} to {char_limit} characters "
f"(conservative estimate for {max_tokens} tokens)"
)
else:
truncated.append(text)
return truncated
def get_model_token_limit(model_name: str) -> int:
"""
Get token limit for a given embedding model.
Args:
model_name: Name of the embedding model
Returns:
Token limit for the model, defaults to 512 if unknown
"""
# Handle versioned model names (e.g., "nomic-embed-text:latest" -> "nomic-embed-text")
base_model_name = model_name.split(":")[0]
# Check exact match first
if model_name in EMBEDDING_MODEL_LIMITS:
return EMBEDDING_MODEL_LIMITS[model_name]
# Check base name match
if base_model_name in EMBEDDING_MODEL_LIMITS:
return EMBEDDING_MODEL_LIMITS[base_model_name]
# Check partial matches for common patterns
for known_model, limit in EMBEDDING_MODEL_LIMITS.items():
if known_model in base_model_name or base_model_name in known_model:
return limit
# Default to conservative 512 token limit
logger.warning(f"Unknown model '{model_name}', using default 512 token limit")
return 512
# Set up logger with proper level
logger = logging.getLogger(__name__)
LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
@@ -106,17 +23,6 @@ logger.setLevel(log_level)
# Global model cache to avoid repeated loading
_model_cache: dict[str, Any] = {}
# Known embedding model token limits
EMBEDDING_MODEL_LIMITS = {
"nomic-embed-text": 512,
"nomic-embed-text-v2": 512,
"mxbai-embed-large": 512,
"all-minilm": 512,
"bge-m3": 8192,
"snowflake-arctic-embed": 512,
# Add more models as needed
}
def compute_embeddings(
texts: list[str],
@@ -814,28 +720,20 @@ def compute_embeddings_ollama(
logger.info(f"Using batch size: {batch_size} for true batch processing")
# Get model token limit and apply truncation
token_limit = get_model_token_limit(model_name)
logger.info(f"Model '{model_name}' token limit: {token_limit}")
# Apply token-aware truncation to all texts
truncated_texts = truncate_to_token_limit(texts, token_limit)
if len(truncated_texts) != len(texts):
logger.error("Truncation failed - text count mismatch")
truncated_texts = texts # Fallback to original texts
def get_batch_embeddings(batch_texts):
"""Get embeddings for a batch of texts using /api/embed endpoint."""
max_retries = 3
retry_count = 0
# Texts are already truncated to token limit by the outer function
# Truncate very long texts to avoid API issues
truncated_texts = [text[:8000] if len(text) > 8000 else text for text in batch_texts]
while retry_count < max_retries:
try:
# Use /api/embed endpoint with "input" parameter for batch processing
response = requests.post(
f"{resolved_host}/api/embed",
json={"model": model_name, "input": batch_texts},
json={"model": model_name, "input": truncated_texts},
timeout=60, # Increased timeout for batch processing
)
response.raise_for_status()
@@ -865,27 +763,17 @@ def compute_embeddings_ollama(
except Exception as e:
retry_count += 1
if retry_count >= max_retries:
# Enhanced error detection for token limit violations
error_msg = str(e).lower()
if "token" in error_msg and (
"limit" in error_msg or "exceed" in error_msg or "length" in error_msg
):
logger.error(
f"Token limit exceeded for batch. Error: {e}. "
f"Consider reducing chunk sizes or check token truncation."
)
else:
logger.error(f"Failed to get embeddings for batch: {e}")
logger.error(f"Failed to get embeddings for batch: {e}")
return None, list(range(len(batch_texts)))
return None, list(range(len(batch_texts)))
# Process truncated texts in batches
# Process texts in batches
all_embeddings = []
all_failed_indices = []
# Setup progress bar if needed
show_progress = is_build or len(truncated_texts) > 10
show_progress = is_build or len(texts) > 10
try:
if show_progress:
from tqdm import tqdm
@@ -893,7 +781,7 @@ def compute_embeddings_ollama(
show_progress = False
# Process batches
num_batches = (len(truncated_texts) + batch_size - 1) // batch_size
num_batches = (len(texts) + batch_size - 1) // batch_size
if show_progress:
batch_iterator = tqdm(range(num_batches), desc="Computing Ollama embeddings (batched)")
@@ -902,8 +790,8 @@ def compute_embeddings_ollama(
for batch_idx in batch_iterator:
start_idx = batch_idx * batch_size
end_idx = min(start_idx + batch_size, len(truncated_texts))
batch_texts = truncated_texts[start_idx:end_idx]
end_idx = min(start_idx + batch_size, len(texts))
batch_texts = texts[start_idx:end_idx]
batch_embeddings, batch_failed = get_batch_embeddings(batch_texts)
@@ -918,11 +806,11 @@ def compute_embeddings_ollama(
# Handle failed embeddings
if all_failed_indices:
if len(all_failed_indices) == len(truncated_texts):
if len(all_failed_indices) == len(texts):
raise RuntimeError("Failed to compute any embeddings")
logger.warning(
f"Failed to compute embeddings for {len(all_failed_indices)}/{len(truncated_texts)} texts"
f"Failed to compute embeddings for {len(all_failed_indices)}/{len(texts)} texts"
)
# Use zero embeddings as fallback for failed ones