Compare commits
4 Commits
feature/op
...
fix/chunki
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
64b92a04a7 | ||
|
|
a85d0ad4a7 | ||
|
|
dbb5f4d352 | ||
|
|
f180b83589 |
3
.gitignore
vendored
3
.gitignore
vendored
@@ -105,3 +105,6 @@ apps/multimodal/vision-based-pdf-multi-vector/multi-vector-colpali-native-weavia
|
|||||||
# The following line used to force-add a large demo PDF; remove it to satisfy pre-commit:
|
# The following line used to force-add a large demo PDF; remove it to satisfy pre-commit:
|
||||||
# !apps/multimodal/vision-based-pdf-multi-vector/pdfs/2004.12832v2.pdf
|
# !apps/multimodal/vision-based-pdf-multi-vector/pdfs/2004.12832v2.pdf
|
||||||
!apps/multimodal/vision-based-pdf-multi-vector/fig/*
|
!apps/multimodal/vision-based-pdf-multi-vector/fig/*
|
||||||
|
|
||||||
|
# AUR build directory (Arch Linux)
|
||||||
|
paru-bin/
|
||||||
|
|||||||
@@ -1213,3 +1213,7 @@ This work is done at [**Berkeley Sky Computing Lab**](https://sky.cs.berkeley.ed
|
|||||||
<p align="center">
|
<p align="center">
|
||||||
Made with ❤️ by the Leann team
|
Made with ❤️ by the Leann team
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
|
## 🤖 Explore LEANN with AI
|
||||||
|
|
||||||
|
LEANN is indexed on [DeepWiki](https://deepwiki.com/yichuan-w/LEANN), so you can ask questions to LLMs using Deep Research to explore the codebase and get help to add new features.
|
||||||
|
|||||||
@@ -180,14 +180,14 @@ class BaseRAGExample(ABC):
|
|||||||
ast_group.add_argument(
|
ast_group.add_argument(
|
||||||
"--ast-chunk-size",
|
"--ast-chunk-size",
|
||||||
type=int,
|
type=int,
|
||||||
default=512,
|
default=300,
|
||||||
help="Maximum characters per AST chunk (default: 512)",
|
help="Maximum CHARACTERS per AST chunk (default: 300). Final chunks may be larger due to overlap. For 512 token models: recommended 300 chars",
|
||||||
)
|
)
|
||||||
ast_group.add_argument(
|
ast_group.add_argument(
|
||||||
"--ast-chunk-overlap",
|
"--ast-chunk-overlap",
|
||||||
type=int,
|
type=int,
|
||||||
default=64,
|
default=64,
|
||||||
help="Overlap between AST chunks (default: 64)",
|
help="Overlap between AST chunks in CHARACTERS (default: 64). Added to chunk size, not included in it",
|
||||||
)
|
)
|
||||||
ast_group.add_argument(
|
ast_group.add_argument(
|
||||||
"--code-file-extensions",
|
"--code-file-extensions",
|
||||||
|
|||||||
@@ -29,12 +29,25 @@ if(APPLE)
|
|||||||
set(CMAKE_OSX_DEPLOYMENT_TARGET "11.0" CACHE STRING "Minimum macOS version")
|
set(CMAKE_OSX_DEPLOYMENT_TARGET "11.0" CACHE STRING "Minimum macOS version")
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
# Use system ZeroMQ instead of building from source
|
# Find ZMQ using pkg-config with IMPORTED_TARGET for automatic target creation
|
||||||
find_package(PkgConfig REQUIRED)
|
find_package(PkgConfig REQUIRED)
|
||||||
pkg_check_modules(ZMQ REQUIRED libzmq)
|
|
||||||
|
# On ARM64 macOS, ensure pkg-config finds ARM64 Homebrew packages first
|
||||||
|
if(APPLE AND CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64|arm64")
|
||||||
|
set(ENV{PKG_CONFIG_PATH} "/opt/homebrew/lib/pkgconfig:/opt/homebrew/share/pkgconfig:$ENV{PKG_CONFIG_PATH}")
|
||||||
|
endif()
|
||||||
|
|
||||||
|
pkg_check_modules(ZMQ REQUIRED IMPORTED_TARGET libzmq)
|
||||||
|
|
||||||
|
# This creates PkgConfig::ZMQ target automatically with correct properties
|
||||||
|
if(TARGET PkgConfig::ZMQ)
|
||||||
|
message(STATUS "Found and configured ZMQ target: PkgConfig::ZMQ")
|
||||||
|
else()
|
||||||
|
message(FATAL_ERROR "pkg_check_modules did not create IMPORTED target for ZMQ.")
|
||||||
|
endif()
|
||||||
|
|
||||||
# Add cppzmq headers
|
# Add cppzmq headers
|
||||||
include_directories(third_party/cppzmq)
|
include_directories(SYSTEM third_party/cppzmq)
|
||||||
|
|
||||||
# Configure msgpack-c - disable boost dependency
|
# Configure msgpack-c - disable boost dependency
|
||||||
set(MSGPACK_USE_BOOST OFF CACHE BOOL "" FORCE)
|
set(MSGPACK_USE_BOOST OFF CACHE BOOL "" FORCE)
|
||||||
|
|||||||
Submodule packages/leann-backend-hnsw/third_party/faiss updated: 5952745237...c69511a99c
@@ -11,6 +11,119 @@ from llama_index.core.node_parser import SentenceSplitter
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def estimate_token_count(text: str) -> int:
|
||||||
|
"""
|
||||||
|
Estimate token count for a text string.
|
||||||
|
Uses conservative estimation: ~4 characters per token for natural text,
|
||||||
|
~1.2 tokens per character for code (worse tokenization).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: Input text to estimate tokens for
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Estimated token count
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
import tiktoken
|
||||||
|
|
||||||
|
encoder = tiktoken.get_encoding("cl100k_base")
|
||||||
|
return len(encoder.encode(text))
|
||||||
|
except ImportError:
|
||||||
|
# Fallback: Conservative character-based estimation
|
||||||
|
# Assume worst case for code: 1.2 tokens per character
|
||||||
|
return int(len(text) * 1.2)
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_safe_chunk_size(
|
||||||
|
model_token_limit: int,
|
||||||
|
overlap_tokens: int,
|
||||||
|
chunking_mode: str = "traditional",
|
||||||
|
safety_factor: float = 0.9,
|
||||||
|
) -> int:
|
||||||
|
"""
|
||||||
|
Calculate safe chunk size accounting for overlap and safety margin.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_token_limit: Maximum tokens supported by embedding model
|
||||||
|
overlap_tokens: Overlap size (tokens for traditional, chars for AST)
|
||||||
|
chunking_mode: "traditional" (tokens) or "ast" (characters)
|
||||||
|
safety_factor: Safety margin (0.9 = 10% safety margin)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Safe chunk size: tokens for traditional, characters for AST
|
||||||
|
"""
|
||||||
|
safe_limit = int(model_token_limit * safety_factor)
|
||||||
|
|
||||||
|
if chunking_mode == "traditional":
|
||||||
|
# Traditional chunking uses tokens
|
||||||
|
# Max chunk = chunk_size + overlap, so chunk_size = limit - overlap
|
||||||
|
return max(1, safe_limit - overlap_tokens)
|
||||||
|
else: # AST chunking
|
||||||
|
# AST uses characters, need to convert
|
||||||
|
# Conservative estimate: 1.2 tokens per char for code
|
||||||
|
overlap_chars = int(overlap_tokens * 3) # ~3 chars per token for code
|
||||||
|
safe_chars = int(safe_limit / 1.2)
|
||||||
|
return max(1, safe_chars - overlap_chars)
|
||||||
|
|
||||||
|
|
||||||
|
def validate_chunk_token_limits(chunks: list[str], max_tokens: int = 512) -> tuple[list[str], int]:
|
||||||
|
"""
|
||||||
|
Validate that chunks don't exceed token limits and truncate if necessary.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
chunks: List of text chunks to validate
|
||||||
|
max_tokens: Maximum tokens allowed per chunk
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (validated_chunks, num_truncated)
|
||||||
|
"""
|
||||||
|
validated_chunks = []
|
||||||
|
num_truncated = 0
|
||||||
|
|
||||||
|
for i, chunk in enumerate(chunks):
|
||||||
|
estimated_tokens = estimate_token_count(chunk)
|
||||||
|
|
||||||
|
if estimated_tokens > max_tokens:
|
||||||
|
# Truncate chunk to fit token limit
|
||||||
|
try:
|
||||||
|
import tiktoken
|
||||||
|
|
||||||
|
encoder = tiktoken.get_encoding("cl100k_base")
|
||||||
|
tokens = encoder.encode(chunk)
|
||||||
|
if len(tokens) > max_tokens:
|
||||||
|
truncated_tokens = tokens[:max_tokens]
|
||||||
|
truncated_chunk = encoder.decode(truncated_tokens)
|
||||||
|
validated_chunks.append(truncated_chunk)
|
||||||
|
num_truncated += 1
|
||||||
|
logger.warning(
|
||||||
|
f"Truncated chunk {i} from {len(tokens)} to {max_tokens} tokens "
|
||||||
|
f"(from {len(chunk)} to {len(truncated_chunk)} characters)"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
validated_chunks.append(chunk)
|
||||||
|
except ImportError:
|
||||||
|
# Fallback: Conservative character truncation
|
||||||
|
char_limit = int(max_tokens / 1.2) # Conservative for code
|
||||||
|
if len(chunk) > char_limit:
|
||||||
|
truncated_chunk = chunk[:char_limit]
|
||||||
|
validated_chunks.append(truncated_chunk)
|
||||||
|
num_truncated += 1
|
||||||
|
logger.warning(
|
||||||
|
f"Truncated chunk {i} from {len(chunk)} to {char_limit} characters "
|
||||||
|
f"(conservative estimate for {max_tokens} tokens)"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
validated_chunks.append(chunk)
|
||||||
|
else:
|
||||||
|
validated_chunks.append(chunk)
|
||||||
|
|
||||||
|
if num_truncated > 0:
|
||||||
|
logger.warning(f"Truncated {num_truncated}/{len(chunks)} chunks to fit token limits")
|
||||||
|
|
||||||
|
return validated_chunks, num_truncated
|
||||||
|
|
||||||
|
|
||||||
# Code file extensions supported by astchunk
|
# Code file extensions supported by astchunk
|
||||||
CODE_EXTENSIONS = {
|
CODE_EXTENSIONS = {
|
||||||
".py": "python",
|
".py": "python",
|
||||||
@@ -82,6 +195,17 @@ def create_ast_chunks(
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
# Warn if AST chunk size + overlap might exceed common token limits
|
||||||
|
estimated_max_tokens = int(
|
||||||
|
(max_chunk_size + chunk_overlap) * 1.2
|
||||||
|
) # Conservative estimate
|
||||||
|
if estimated_max_tokens > 512:
|
||||||
|
logger.warning(
|
||||||
|
f"AST chunk size ({max_chunk_size}) + overlap ({chunk_overlap}) = {max_chunk_size + chunk_overlap} chars "
|
||||||
|
f"may exceed 512 token limit (~{estimated_max_tokens} tokens estimated). "
|
||||||
|
f"Consider reducing --ast-chunk-size to {int(400 / 1.2)} or --ast-chunk-overlap to {int(50 / 1.2)}"
|
||||||
|
)
|
||||||
|
|
||||||
configs = {
|
configs = {
|
||||||
"max_chunk_size": max_chunk_size,
|
"max_chunk_size": max_chunk_size,
|
||||||
"language": language,
|
"language": language,
|
||||||
@@ -217,4 +341,14 @@ def create_text_chunks(
|
|||||||
all_chunks = create_traditional_chunks(documents, chunk_size, chunk_overlap)
|
all_chunks = create_traditional_chunks(documents, chunk_size, chunk_overlap)
|
||||||
|
|
||||||
logger.info(f"Total chunks created: {len(all_chunks)}")
|
logger.info(f"Total chunks created: {len(all_chunks)}")
|
||||||
return all_chunks
|
|
||||||
|
# Validate chunk token limits (default to 512 for safety)
|
||||||
|
# This provides a safety net for embedding models with token limits
|
||||||
|
validated_chunks, num_truncated = validate_chunk_token_limits(all_chunks, max_tokens=512)
|
||||||
|
|
||||||
|
if num_truncated > 0:
|
||||||
|
logger.info(
|
||||||
|
f"Post-chunking validation: {num_truncated} chunks were truncated to fit 512 token limit"
|
||||||
|
)
|
||||||
|
|
||||||
|
return validated_chunks
|
||||||
|
|||||||
@@ -181,25 +181,25 @@ Examples:
|
|||||||
"--doc-chunk-size",
|
"--doc-chunk-size",
|
||||||
type=int,
|
type=int,
|
||||||
default=256,
|
default=256,
|
||||||
help="Document chunk size in tokens/characters (default: 256)",
|
help="Document chunk size in TOKENS (default: 256). Final chunks may be larger due to overlap. For 512 token models: recommended 350 tokens (350 + 128 overlap = 478 max)",
|
||||||
)
|
)
|
||||||
build_parser.add_argument(
|
build_parser.add_argument(
|
||||||
"--doc-chunk-overlap",
|
"--doc-chunk-overlap",
|
||||||
type=int,
|
type=int,
|
||||||
default=128,
|
default=128,
|
||||||
help="Document chunk overlap (default: 128)",
|
help="Document chunk overlap in TOKENS (default: 128). Added to chunk size, not included in it",
|
||||||
)
|
)
|
||||||
build_parser.add_argument(
|
build_parser.add_argument(
|
||||||
"--code-chunk-size",
|
"--code-chunk-size",
|
||||||
type=int,
|
type=int,
|
||||||
default=512,
|
default=512,
|
||||||
help="Code chunk size in tokens/lines (default: 512)",
|
help="Code chunk size in TOKENS (default: 512). Final chunks may be larger due to overlap. For 512 token models: recommended 400 tokens (400 + 50 overlap = 450 max)",
|
||||||
)
|
)
|
||||||
build_parser.add_argument(
|
build_parser.add_argument(
|
||||||
"--code-chunk-overlap",
|
"--code-chunk-overlap",
|
||||||
type=int,
|
type=int,
|
||||||
default=50,
|
default=50,
|
||||||
help="Code chunk overlap (default: 50)",
|
help="Code chunk overlap in TOKENS (default: 50). Added to chunk size, not included in it",
|
||||||
)
|
)
|
||||||
build_parser.add_argument(
|
build_parser.add_argument(
|
||||||
"--use-ast-chunking",
|
"--use-ast-chunking",
|
||||||
@@ -209,14 +209,14 @@ Examples:
|
|||||||
build_parser.add_argument(
|
build_parser.add_argument(
|
||||||
"--ast-chunk-size",
|
"--ast-chunk-size",
|
||||||
type=int,
|
type=int,
|
||||||
default=768,
|
default=300,
|
||||||
help="AST chunk size in characters (default: 768)",
|
help="AST chunk size in CHARACTERS (non-whitespace) (default: 300). Final chunks may be larger due to overlap and expansion. For 512 token models: recommended 300 chars (300 + 64 overlap ~= 480 tokens)",
|
||||||
)
|
)
|
||||||
build_parser.add_argument(
|
build_parser.add_argument(
|
||||||
"--ast-chunk-overlap",
|
"--ast-chunk-overlap",
|
||||||
type=int,
|
type=int,
|
||||||
default=96,
|
default=64,
|
||||||
help="AST chunk overlap in characters (default: 96)",
|
help="AST chunk overlap in CHARACTERS (default: 64). Added to chunk size, not included in it. ~1.2 tokens per character for code",
|
||||||
)
|
)
|
||||||
build_parser.add_argument(
|
build_parser.add_argument(
|
||||||
"--ast-fallback-traditional",
|
"--ast-fallback-traditional",
|
||||||
@@ -255,6 +255,11 @@ Examples:
|
|||||||
action="store_true",
|
action="store_true",
|
||||||
help="Non-interactive mode: automatically select index without prompting",
|
help="Non-interactive mode: automatically select index without prompting",
|
||||||
)
|
)
|
||||||
|
search_parser.add_argument(
|
||||||
|
"--show-metadata",
|
||||||
|
action="store_true",
|
||||||
|
help="Display file paths and metadata in search results",
|
||||||
|
)
|
||||||
|
|
||||||
# Ask command
|
# Ask command
|
||||||
ask_parser = subparsers.add_parser("ask", help="Ask questions")
|
ask_parser = subparsers.add_parser("ask", help="Ask questions")
|
||||||
@@ -1263,7 +1268,7 @@ Examples:
|
|||||||
from .chunking_utils import create_text_chunks
|
from .chunking_utils import create_text_chunks
|
||||||
|
|
||||||
# Use enhanced chunking with AST support
|
# Use enhanced chunking with AST support
|
||||||
all_texts = create_text_chunks(
|
chunk_texts = create_text_chunks(
|
||||||
documents,
|
documents,
|
||||||
chunk_size=self.node_parser.chunk_size,
|
chunk_size=self.node_parser.chunk_size,
|
||||||
chunk_overlap=self.node_parser.chunk_overlap,
|
chunk_overlap=self.node_parser.chunk_overlap,
|
||||||
@@ -1274,6 +1279,14 @@ Examples:
|
|||||||
ast_fallback_traditional=getattr(args, "ast_fallback_traditional", True),
|
ast_fallback_traditional=getattr(args, "ast_fallback_traditional", True),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Note: AST chunking currently returns plain text chunks without metadata
|
||||||
|
# We preserve basic file info by associating chunks with their source documents
|
||||||
|
# For better metadata preservation, documents list order should be maintained
|
||||||
|
for chunk_text in chunk_texts:
|
||||||
|
# TODO: Enhance create_text_chunks to return metadata alongside text
|
||||||
|
# For now, we store chunks with empty metadata
|
||||||
|
all_texts.append({"text": chunk_text, "metadata": {}})
|
||||||
|
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
print(
|
print(
|
||||||
f"⚠️ AST chunking utilities not available in package ({e}), falling back to traditional chunking"
|
f"⚠️ AST chunking utilities not available in package ({e}), falling back to traditional chunking"
|
||||||
@@ -1285,17 +1298,27 @@ Examples:
|
|||||||
for doc in tqdm(documents, desc="Chunking documents", unit="doc"):
|
for doc in tqdm(documents, desc="Chunking documents", unit="doc"):
|
||||||
# Check if this is a code file based on source path
|
# Check if this is a code file based on source path
|
||||||
source_path = doc.metadata.get("source", "")
|
source_path = doc.metadata.get("source", "")
|
||||||
|
file_path = doc.metadata.get("file_path", "")
|
||||||
is_code_file = any(source_path.endswith(ext) for ext in code_file_exts)
|
is_code_file = any(source_path.endswith(ext) for ext in code_file_exts)
|
||||||
|
|
||||||
|
# Extract metadata to preserve with chunks
|
||||||
|
chunk_metadata = {
|
||||||
|
"file_path": file_path or source_path,
|
||||||
|
"file_name": doc.metadata.get("file_name", ""),
|
||||||
|
}
|
||||||
|
|
||||||
|
# Add optional metadata if available
|
||||||
|
if "creation_date" in doc.metadata:
|
||||||
|
chunk_metadata["creation_date"] = doc.metadata["creation_date"]
|
||||||
|
if "last_modified_date" in doc.metadata:
|
||||||
|
chunk_metadata["last_modified_date"] = doc.metadata["last_modified_date"]
|
||||||
|
|
||||||
# Use appropriate parser based on file type
|
# Use appropriate parser based on file type
|
||||||
parser = self.code_parser if is_code_file else self.node_parser
|
parser = self.code_parser if is_code_file else self.node_parser
|
||||||
nodes = parser.get_nodes_from_documents([doc])
|
nodes = parser.get_nodes_from_documents([doc])
|
||||||
|
|
||||||
for node in nodes:
|
for node in nodes:
|
||||||
text_with_source = (
|
all_texts.append({"text": node.get_content(), "metadata": chunk_metadata})
|
||||||
"Chunk source:" + source_path + "\n" + node.get_content().replace("\n", " ")
|
|
||||||
)
|
|
||||||
all_texts.append(text_with_source)
|
|
||||||
|
|
||||||
print(f"Loaded {len(documents)} documents, {len(all_texts)} chunks")
|
print(f"Loaded {len(documents)} documents, {len(all_texts)} chunks")
|
||||||
return all_texts
|
return all_texts
|
||||||
@@ -1370,7 +1393,7 @@ Examples:
|
|||||||
|
|
||||||
index_dir.mkdir(parents=True, exist_ok=True)
|
index_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
print(f"Building index '{index_name}' with {args.backend} backend...")
|
print(f"Building index '{index_name}' with {args.backend_name} backend...")
|
||||||
|
|
||||||
embedding_options: dict[str, Any] = {}
|
embedding_options: dict[str, Any] = {}
|
||||||
if args.embedding_mode == "ollama":
|
if args.embedding_mode == "ollama":
|
||||||
@@ -1382,7 +1405,7 @@ Examples:
|
|||||||
embedding_options["api_key"] = resolved_embedding_key
|
embedding_options["api_key"] = resolved_embedding_key
|
||||||
|
|
||||||
builder = LeannBuilder(
|
builder = LeannBuilder(
|
||||||
backend_name=args.backend,
|
backend_name=args.backend_name,
|
||||||
embedding_model=args.embedding_model,
|
embedding_model=args.embedding_model,
|
||||||
embedding_mode=args.embedding_mode,
|
embedding_mode=args.embedding_mode,
|
||||||
embedding_options=embedding_options or None,
|
embedding_options=embedding_options or None,
|
||||||
@@ -1393,10 +1416,8 @@ Examples:
|
|||||||
num_threads=args.num_threads,
|
num_threads=args.num_threads,
|
||||||
)
|
)
|
||||||
|
|
||||||
for chunk_text_with_source in all_texts:
|
for chunk in all_texts:
|
||||||
chunk_source = chunk_text_with_source.split("\n")[0].split(":")[1]
|
builder.add_text(chunk["text"], metadata=chunk["metadata"])
|
||||||
chunk_text = chunk_text_with_source.split("\n")[1]
|
|
||||||
builder.add_text(chunk_text, {"source": chunk_source})
|
|
||||||
|
|
||||||
builder.build_index(index_path)
|
builder.build_index(index_path)
|
||||||
print(f"Index built at {index_path}")
|
print(f"Index built at {index_path}")
|
||||||
@@ -1517,6 +1538,23 @@ Examples:
|
|||||||
print(f"Search results for '{query}' (top {len(results)}):")
|
print(f"Search results for '{query}' (top {len(results)}):")
|
||||||
for i, result in enumerate(results, 1):
|
for i, result in enumerate(results, 1):
|
||||||
print(f"{i}. Score: {result.score:.3f}")
|
print(f"{i}. Score: {result.score:.3f}")
|
||||||
|
|
||||||
|
# Display metadata if flag is set
|
||||||
|
if args.show_metadata and result.metadata:
|
||||||
|
file_path = result.metadata.get("file_path", "")
|
||||||
|
if file_path:
|
||||||
|
print(f" 📄 File: {file_path}")
|
||||||
|
|
||||||
|
file_name = result.metadata.get("file_name", "")
|
||||||
|
if file_name and file_name != file_path:
|
||||||
|
print(f" 📝 Name: {file_name}")
|
||||||
|
|
||||||
|
# Show timestamps if available
|
||||||
|
if "creation_date" in result.metadata:
|
||||||
|
print(f" 🕐 Created: {result.metadata['creation_date']}")
|
||||||
|
if "last_modified_date" in result.metadata:
|
||||||
|
print(f" 🕑 Modified: {result.metadata['last_modified_date']}")
|
||||||
|
|
||||||
print(f" {result.text[:200]}...")
|
print(f" {result.text[:200]}...")
|
||||||
print(f" Source: {result.metadata.get('source', '')}")
|
print(f" Source: {result.metadata.get('source', '')}")
|
||||||
print()
|
print()
|
||||||
|
|||||||
@@ -14,6 +14,89 @@ import torch
|
|||||||
|
|
||||||
from .settings import resolve_ollama_host, resolve_openai_api_key, resolve_openai_base_url
|
from .settings import resolve_ollama_host, resolve_openai_api_key, resolve_openai_base_url
|
||||||
|
|
||||||
|
|
||||||
|
def truncate_to_token_limit(texts: list[str], max_tokens: int = 512) -> list[str]:
|
||||||
|
"""
|
||||||
|
Truncate texts to token limit using tiktoken or conservative character truncation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
texts: List of texts to truncate
|
||||||
|
max_tokens: Maximum tokens allowed per text
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of truncated texts that should fit within token limit
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
import tiktoken
|
||||||
|
|
||||||
|
encoder = tiktoken.get_encoding("cl100k_base")
|
||||||
|
truncated = []
|
||||||
|
|
||||||
|
for text in texts:
|
||||||
|
tokens = encoder.encode(text)
|
||||||
|
if len(tokens) > max_tokens:
|
||||||
|
# Truncate to max_tokens and decode back to text
|
||||||
|
truncated_tokens = tokens[:max_tokens]
|
||||||
|
truncated_text = encoder.decode(truncated_tokens)
|
||||||
|
truncated.append(truncated_text)
|
||||||
|
logger.warning(
|
||||||
|
f"Truncated text from {len(tokens)} to {max_tokens} tokens "
|
||||||
|
f"(from {len(text)} to {len(truncated_text)} characters)"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
truncated.append(text)
|
||||||
|
return truncated
|
||||||
|
|
||||||
|
except ImportError:
|
||||||
|
# Fallback: Conservative character truncation
|
||||||
|
# Assume worst case: 1.5 tokens per character for code content
|
||||||
|
char_limit = int(max_tokens / 1.5)
|
||||||
|
truncated = []
|
||||||
|
|
||||||
|
for text in texts:
|
||||||
|
if len(text) > char_limit:
|
||||||
|
truncated_text = text[:char_limit]
|
||||||
|
truncated.append(truncated_text)
|
||||||
|
logger.warning(
|
||||||
|
f"Truncated text from {len(text)} to {char_limit} characters "
|
||||||
|
f"(conservative estimate for {max_tokens} tokens)"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
truncated.append(text)
|
||||||
|
return truncated
|
||||||
|
|
||||||
|
|
||||||
|
def get_model_token_limit(model_name: str) -> int:
|
||||||
|
"""
|
||||||
|
Get token limit for a given embedding model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_name: Name of the embedding model
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Token limit for the model, defaults to 512 if unknown
|
||||||
|
"""
|
||||||
|
# Handle versioned model names (e.g., "nomic-embed-text:latest" -> "nomic-embed-text")
|
||||||
|
base_model_name = model_name.split(":")[0]
|
||||||
|
|
||||||
|
# Check exact match first
|
||||||
|
if model_name in EMBEDDING_MODEL_LIMITS:
|
||||||
|
return EMBEDDING_MODEL_LIMITS[model_name]
|
||||||
|
|
||||||
|
# Check base name match
|
||||||
|
if base_model_name in EMBEDDING_MODEL_LIMITS:
|
||||||
|
return EMBEDDING_MODEL_LIMITS[base_model_name]
|
||||||
|
|
||||||
|
# Check partial matches for common patterns
|
||||||
|
for known_model, limit in EMBEDDING_MODEL_LIMITS.items():
|
||||||
|
if known_model in base_model_name or base_model_name in known_model:
|
||||||
|
return limit
|
||||||
|
|
||||||
|
# Default to conservative 512 token limit
|
||||||
|
logger.warning(f"Unknown model '{model_name}', using default 512 token limit")
|
||||||
|
return 512
|
||||||
|
|
||||||
|
|
||||||
# Set up logger with proper level
|
# Set up logger with proper level
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
|
LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
|
||||||
@@ -23,6 +106,17 @@ logger.setLevel(log_level)
|
|||||||
# Global model cache to avoid repeated loading
|
# Global model cache to avoid repeated loading
|
||||||
_model_cache: dict[str, Any] = {}
|
_model_cache: dict[str, Any] = {}
|
||||||
|
|
||||||
|
# Known embedding model token limits
|
||||||
|
EMBEDDING_MODEL_LIMITS = {
|
||||||
|
"nomic-embed-text": 512,
|
||||||
|
"nomic-embed-text-v2": 512,
|
||||||
|
"mxbai-embed-large": 512,
|
||||||
|
"all-minilm": 512,
|
||||||
|
"bge-m3": 8192,
|
||||||
|
"snowflake-arctic-embed": 512,
|
||||||
|
# Add more models as needed
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def compute_embeddings(
|
def compute_embeddings(
|
||||||
texts: list[str],
|
texts: list[str],
|
||||||
@@ -574,9 +668,10 @@ def compute_embeddings_ollama(
|
|||||||
host: Optional[str] = None,
|
host: Optional[str] = None,
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
"""
|
"""
|
||||||
Compute embeddings using Ollama API with simplified batch processing.
|
Compute embeddings using Ollama API with true batch processing.
|
||||||
|
|
||||||
Uses batch size of 32 for MPS/CPU and 128 for CUDA to optimize performance.
|
Uses the /api/embed endpoint which supports batch inputs.
|
||||||
|
Batch size: 32 for MPS/CPU, 128 for CUDA to optimize performance.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
texts: List of texts to compute embeddings for
|
texts: List of texts to compute embeddings for
|
||||||
@@ -681,11 +776,11 @@ def compute_embeddings_ollama(
|
|||||||
logger.info(f"Resolved model name '{model_name}' to '{resolved_model_name}'")
|
logger.info(f"Resolved model name '{model_name}' to '{resolved_model_name}'")
|
||||||
model_name = resolved_model_name
|
model_name = resolved_model_name
|
||||||
|
|
||||||
# Verify the model supports embeddings by testing it
|
# Verify the model supports embeddings by testing it with /api/embed
|
||||||
try:
|
try:
|
||||||
test_response = requests.post(
|
test_response = requests.post(
|
||||||
f"{resolved_host}/api/embeddings",
|
f"{resolved_host}/api/embed",
|
||||||
json={"model": model_name, "prompt": "test"},
|
json={"model": model_name, "input": "test"},
|
||||||
timeout=10,
|
timeout=10,
|
||||||
)
|
)
|
||||||
if test_response.status_code != 200:
|
if test_response.status_code != 200:
|
||||||
@@ -717,63 +812,80 @@ def compute_embeddings_ollama(
|
|||||||
# If torch is not available, use conservative batch size
|
# If torch is not available, use conservative batch size
|
||||||
batch_size = 32
|
batch_size = 32
|
||||||
|
|
||||||
logger.info(f"Using batch size: {batch_size}")
|
logger.info(f"Using batch size: {batch_size} for true batch processing")
|
||||||
|
|
||||||
|
# Get model token limit and apply truncation
|
||||||
|
token_limit = get_model_token_limit(model_name)
|
||||||
|
logger.info(f"Model '{model_name}' token limit: {token_limit}")
|
||||||
|
|
||||||
|
# Apply token-aware truncation to all texts
|
||||||
|
truncated_texts = truncate_to_token_limit(texts, token_limit)
|
||||||
|
if len(truncated_texts) != len(texts):
|
||||||
|
logger.error("Truncation failed - text count mismatch")
|
||||||
|
truncated_texts = texts # Fallback to original texts
|
||||||
|
|
||||||
def get_batch_embeddings(batch_texts):
|
def get_batch_embeddings(batch_texts):
|
||||||
"""Get embeddings for a batch of texts."""
|
"""Get embeddings for a batch of texts using /api/embed endpoint."""
|
||||||
all_embeddings = []
|
max_retries = 3
|
||||||
failed_indices = []
|
retry_count = 0
|
||||||
|
|
||||||
for i, text in enumerate(batch_texts):
|
# Texts are already truncated to token limit by the outer function
|
||||||
max_retries = 3
|
while retry_count < max_retries:
|
||||||
retry_count = 0
|
try:
|
||||||
|
# Use /api/embed endpoint with "input" parameter for batch processing
|
||||||
|
response = requests.post(
|
||||||
|
f"{resolved_host}/api/embed",
|
||||||
|
json={"model": model_name, "input": batch_texts},
|
||||||
|
timeout=60, # Increased timeout for batch processing
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
# Truncate very long texts to avoid API issues
|
result = response.json()
|
||||||
truncated_text = text[:8000] if len(text) > 8000 else text
|
batch_embeddings = result.get("embeddings")
|
||||||
while retry_count < max_retries:
|
|
||||||
try:
|
if batch_embeddings is None:
|
||||||
response = requests.post(
|
raise ValueError("No embeddings returned from API")
|
||||||
f"{resolved_host}/api/embeddings",
|
|
||||||
json={"model": model_name, "prompt": truncated_text},
|
if not isinstance(batch_embeddings, list):
|
||||||
timeout=30,
|
raise ValueError(f"Invalid embeddings format: {type(batch_embeddings)}")
|
||||||
|
|
||||||
|
if len(batch_embeddings) != len(batch_texts):
|
||||||
|
raise ValueError(
|
||||||
|
f"Mismatch: requested {len(batch_texts)} embeddings, got {len(batch_embeddings)}"
|
||||||
)
|
)
|
||||||
response.raise_for_status()
|
|
||||||
|
|
||||||
result = response.json()
|
return batch_embeddings, []
|
||||||
embedding = result.get("embedding")
|
|
||||||
|
|
||||||
if embedding is None:
|
except requests.exceptions.Timeout:
|
||||||
raise ValueError(f"No embedding returned for text {i}")
|
retry_count += 1
|
||||||
|
if retry_count >= max_retries:
|
||||||
|
logger.warning(f"Timeout for batch after {max_retries} retries")
|
||||||
|
return None, list(range(len(batch_texts)))
|
||||||
|
|
||||||
if not isinstance(embedding, list) or len(embedding) == 0:
|
except Exception as e:
|
||||||
raise ValueError(f"Invalid embedding format for text {i}")
|
retry_count += 1
|
||||||
|
if retry_count >= max_retries:
|
||||||
|
# Enhanced error detection for token limit violations
|
||||||
|
error_msg = str(e).lower()
|
||||||
|
if "token" in error_msg and (
|
||||||
|
"limit" in error_msg or "exceed" in error_msg or "length" in error_msg
|
||||||
|
):
|
||||||
|
logger.error(
|
||||||
|
f"Token limit exceeded for batch. Error: {e}. "
|
||||||
|
f"Consider reducing chunk sizes or check token truncation."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.error(f"Failed to get embeddings for batch: {e}")
|
||||||
|
return None, list(range(len(batch_texts)))
|
||||||
|
|
||||||
all_embeddings.append(embedding)
|
return None, list(range(len(batch_texts)))
|
||||||
break
|
|
||||||
|
|
||||||
except requests.exceptions.Timeout:
|
# Process truncated texts in batches
|
||||||
retry_count += 1
|
|
||||||
if retry_count >= max_retries:
|
|
||||||
logger.warning(f"Timeout for text {i} after {max_retries} retries")
|
|
||||||
failed_indices.append(i)
|
|
||||||
all_embeddings.append(None)
|
|
||||||
break
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
retry_count += 1
|
|
||||||
if retry_count >= max_retries:
|
|
||||||
logger.error(f"Failed to get embedding for text {i}: {e}")
|
|
||||||
failed_indices.append(i)
|
|
||||||
all_embeddings.append(None)
|
|
||||||
break
|
|
||||||
return all_embeddings, failed_indices
|
|
||||||
|
|
||||||
# Process texts in batches
|
|
||||||
all_embeddings = []
|
all_embeddings = []
|
||||||
all_failed_indices = []
|
all_failed_indices = []
|
||||||
|
|
||||||
# Setup progress bar if needed
|
# Setup progress bar if needed
|
||||||
show_progress = is_build or len(texts) > 10
|
show_progress = is_build or len(truncated_texts) > 10
|
||||||
try:
|
try:
|
||||||
if show_progress:
|
if show_progress:
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
@@ -781,32 +893,36 @@ def compute_embeddings_ollama(
|
|||||||
show_progress = False
|
show_progress = False
|
||||||
|
|
||||||
# Process batches
|
# Process batches
|
||||||
num_batches = (len(texts) + batch_size - 1) // batch_size
|
num_batches = (len(truncated_texts) + batch_size - 1) // batch_size
|
||||||
|
|
||||||
if show_progress:
|
if show_progress:
|
||||||
batch_iterator = tqdm(range(num_batches), desc="Computing Ollama embeddings")
|
batch_iterator = tqdm(range(num_batches), desc="Computing Ollama embeddings (batched)")
|
||||||
else:
|
else:
|
||||||
batch_iterator = range(num_batches)
|
batch_iterator = range(num_batches)
|
||||||
|
|
||||||
for batch_idx in batch_iterator:
|
for batch_idx in batch_iterator:
|
||||||
start_idx = batch_idx * batch_size
|
start_idx = batch_idx * batch_size
|
||||||
end_idx = min(start_idx + batch_size, len(texts))
|
end_idx = min(start_idx + batch_size, len(truncated_texts))
|
||||||
batch_texts = texts[start_idx:end_idx]
|
batch_texts = truncated_texts[start_idx:end_idx]
|
||||||
|
|
||||||
batch_embeddings, batch_failed = get_batch_embeddings(batch_texts)
|
batch_embeddings, batch_failed = get_batch_embeddings(batch_texts)
|
||||||
|
|
||||||
# Adjust failed indices to global indices
|
if batch_embeddings is not None:
|
||||||
global_failed = [start_idx + idx for idx in batch_failed]
|
all_embeddings.extend(batch_embeddings)
|
||||||
all_failed_indices.extend(global_failed)
|
else:
|
||||||
all_embeddings.extend(batch_embeddings)
|
# Entire batch failed, add None placeholders
|
||||||
|
all_embeddings.extend([None] * len(batch_texts))
|
||||||
|
# Adjust failed indices to global indices
|
||||||
|
global_failed = [start_idx + idx for idx in batch_failed]
|
||||||
|
all_failed_indices.extend(global_failed)
|
||||||
|
|
||||||
# Handle failed embeddings
|
# Handle failed embeddings
|
||||||
if all_failed_indices:
|
if all_failed_indices:
|
||||||
if len(all_failed_indices) == len(texts):
|
if len(all_failed_indices) == len(truncated_texts):
|
||||||
raise RuntimeError("Failed to compute any embeddings")
|
raise RuntimeError("Failed to compute any embeddings")
|
||||||
|
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Failed to compute embeddings for {len(all_failed_indices)}/{len(texts)} texts"
|
f"Failed to compute embeddings for {len(all_failed_indices)}/{len(truncated_texts)} texts"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Use zero embeddings as fallback for failed ones
|
# Use zero embeddings as fallback for failed ones
|
||||||
|
|||||||
@@ -60,6 +60,11 @@ def handle_request(request):
|
|||||||
"maximum": 128,
|
"maximum": 128,
|
||||||
"description": "Search complexity level. Use 16-32 for fast searches (recommended), 64+ for higher precision when needed.",
|
"description": "Search complexity level. Use 16-32 for fast searches (recommended), 64+ for higher precision when needed.",
|
||||||
},
|
},
|
||||||
|
"show_metadata": {
|
||||||
|
"type": "boolean",
|
||||||
|
"default": False,
|
||||||
|
"description": "Include file paths and metadata in search results. Useful for understanding which files contain the results.",
|
||||||
|
},
|
||||||
},
|
},
|
||||||
"required": ["index_name", "query"],
|
"required": ["index_name", "query"],
|
||||||
},
|
},
|
||||||
@@ -104,6 +109,8 @@ def handle_request(request):
|
|||||||
f"--complexity={args.get('complexity', 32)}",
|
f"--complexity={args.get('complexity', 32)}",
|
||||||
"--non-interactive",
|
"--non-interactive",
|
||||||
]
|
]
|
||||||
|
if args.get("show_metadata", False):
|
||||||
|
cmd.append("--show-metadata")
|
||||||
result = subprocess.run(cmd, capture_output=True, text=True)
|
result = subprocess.run(cmd, capture_output=True, text=True)
|
||||||
|
|
||||||
elif tool_name == "leann_list":
|
elif tool_name == "leann_list":
|
||||||
|
|||||||
Reference in New Issue
Block a user