Merge remote-tracking branch 'origin/main' into financebench

This commit is contained in:
Andy Lee
2025-08-22 13:39:08 -07:00
30 changed files with 4245 additions and 1308 deletions

View File

@@ -83,9 +83,7 @@ def create_diskann_embedding_server(
logger.info(f"Loading PassageManager with metadata_file_path: {passages_file}")
passages = PassageManager(meta["passage_sources"], metadata_file_path=passages_file)
logger.info(
f"Loaded PassageManager with {len(passages.global_offset_map)} passages from metadata"
)
logger.info(f"Loaded PassageManager with {len(passages)} passages from metadata")
# Import protobuf after ensuring the path is correct
try:

View File

@@ -1,6 +1,7 @@
import logging
import os
import shutil
import time
from pathlib import Path
from typing import Any, Literal, Optional
@@ -255,6 +256,7 @@ class HNSWSearcher(BaseSearcher):
distances = np.empty((batch_size_query, top_k), dtype=np.float32)
labels = np.empty((batch_size_query, top_k), dtype=np.int64)
search_time = time.time()
self._index.search(
query.shape[0],
faiss.swig_ptr(query),
@@ -263,6 +265,8 @@ class HNSWSearcher(BaseSearcher):
faiss.swig_ptr(labels),
params,
)
search_time = time.time() - search_time
logger.info(f" Search time in HNSWSearcher.search() backend: {search_time} seconds")
if self._id_map:
def map_label(x: int) -> str:

View File

@@ -90,9 +90,7 @@ def create_hnsw_embedding_server(
embedding_dim: int = int(meta.get("dimensions", 0))
except Exception:
embedding_dim = 0
logger.info(
f"Loaded PassageManager with {len(passages.global_offset_map)} passages from metadata"
)
logger.info(f"Loaded PassageManager with {len(passages)} passages from metadata")
# Attempt to load ID map (maps FAISS integer labels -> passage IDs)
id_map: list[str] = []

View File

@@ -10,7 +10,7 @@ import time
import warnings
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Literal, Optional
from typing import Any, Literal, Optional, Union
import numpy as np
@@ -18,6 +18,7 @@ from leann.interface import LeannBackendSearcherInterface
from .chat import get_llm
from .interface import LeannBackendFactoryInterface
from .metadata_filter import MetadataFilterEngine
from .registry import BACKEND_REGISTRY
logger = logging.getLogger(__name__)
@@ -119,9 +120,13 @@ class PassageManager:
def __init__(
self, passage_sources: list[dict[str, Any]], metadata_file_path: Optional[str] = None
):
self.offset_maps = {}
self.passage_files = {}
self.global_offset_map = {} # Combined map for fast lookup
self.offset_maps: dict[str, dict[str, int]] = {}
self.passage_files: dict[str, str] = {}
# Avoid materializing a single gigantic global map to reduce memory
# footprint on very large corpora (e.g., 60M+ passages). Instead, keep
# per-shard maps and do a lightweight per-shard lookup on demand.
self._total_count: int = 0
self.filter_engine = MetadataFilterEngine() # Initialize filter engine
# Derive index base name for standard sibling fallbacks, e.g., <index_name>.passages.*
index_name_base = None
@@ -142,12 +147,25 @@ class PassageManager:
default_name: Optional[str],
source_dict: dict[str, Any],
) -> list[Path]:
"""
Build an ordered list of candidate paths. For relative paths specified in
metadata, prefer resolution relative to the metadata file directory first,
then fall back to CWD-based resolution, and finally to conventional
sibling defaults (e.g., <index_base>.passages.idx / .jsonl).
"""
candidates: list[Path] = []
# 1) Primary as-is (absolute or relative)
# 1) Primary path
if primary:
p = Path(primary)
candidates.append(p if p.is_absolute() else (Path.cwd() / p))
# 2) metadata-relative explicit relative key
if p.is_absolute():
candidates.append(p)
else:
# Prefer metadata-relative resolution for relative paths
if metadata_file_path:
candidates.append(Path(metadata_file_path).parent / p)
# Also consider CWD-relative as a fallback for legacy layouts
candidates.append(Path.cwd() / p)
# 2) metadata-relative explicit relative key (if present)
if metadata_file_path and source_dict.get(relative_key):
candidates.append(Path(metadata_file_path).parent / source_dict[relative_key])
# 3) metadata-relative standard sibling filename
@@ -177,23 +195,78 @@ class PassageManager:
raise FileNotFoundError(f"Passage index file not found: {index_file}")
with open(index_file, "rb") as f:
offset_map = pickle.load(f)
offset_map: dict[str, int] = pickle.load(f)
self.offset_maps[passage_file] = offset_map
self.passage_files[passage_file] = passage_file
# Build global map for O(1) lookup
for passage_id, offset in offset_map.items():
self.global_offset_map[passage_id] = (passage_file, offset)
self._total_count += len(offset_map)
def get_passage(self, passage_id: str) -> dict[str, Any]:
if passage_id in self.global_offset_map:
passage_file, offset = self.global_offset_map[passage_id]
# Lazy file opening - only open when needed
with open(passage_file, encoding="utf-8") as f:
f.seek(offset)
return json.loads(f.readline())
# Fast path: check each shard map (there are typically few shards).
# This avoids building a massive combined dict while keeping lookups
# bounded by the number of shards.
for passage_file, offset_map in self.offset_maps.items():
try:
offset = offset_map[passage_id]
with open(passage_file, encoding="utf-8") as f:
f.seek(offset)
return json.loads(f.readline())
except KeyError:
continue
raise KeyError(f"Passage ID not found: {passage_id}")
def filter_search_results(
self,
search_results: list[SearchResult],
metadata_filters: Optional[dict[str, dict[str, Union[str, int, float, bool, list]]]],
) -> list[SearchResult]:
"""
Apply metadata filters to search results.
Args:
search_results: List of SearchResult objects
metadata_filters: Filter specifications to apply
Returns:
Filtered list of SearchResult objects
"""
if not metadata_filters:
return search_results
logger.debug(f"Applying metadata filters to {len(search_results)} results")
# Convert SearchResult objects to dictionaries for the filter engine
result_dicts = []
for result in search_results:
result_dicts.append(
{
"id": result.id,
"score": result.score,
"text": result.text,
"metadata": result.metadata,
}
)
# Apply filters using the filter engine
filtered_dicts = self.filter_engine.apply_filters(result_dicts, metadata_filters)
# Convert back to SearchResult objects
filtered_results = []
for result_dict in filtered_dicts:
filtered_results.append(
SearchResult(
id=result_dict["id"],
score=result_dict["score"],
text=result_dict["text"],
metadata=result_dict["metadata"],
)
)
logger.debug(f"Filtered results: {len(filtered_results)} remaining")
return filtered_results
def __len__(self) -> int:
return self._total_count
class LeannBuilder:
def __init__(
@@ -573,6 +646,8 @@ class LeannSearcher:
self.passage_manager = PassageManager(
self.meta_data.get("passage_sources", []), metadata_file_path=self.meta_path_str
)
# Preserve backend name for conditional parameter forwarding
self.backend_name = backend_name
backend_factory = BACKEND_REGISTRY.get(backend_name)
if backend_factory is None:
raise ValueError(f"Backend '{backend_name}' not found.")
@@ -592,15 +667,44 @@ class LeannSearcher:
recompute_embeddings: bool = True,
pruning_strategy: Literal["global", "local", "proportional"] = "global",
expected_zmq_port: int = 5557,
metadata_filters: Optional[dict[str, dict[str, Union[str, int, float, bool, list]]]] = None,
batch_size: int = 0,
**kwargs,
) -> list[SearchResult]:
"""
Search for nearest neighbors with optional metadata filtering.
Args:
query: Text query to search for
top_k: Number of nearest neighbors to return
complexity: Search complexity/candidate list size, higher = more accurate but slower
beam_width: Number of parallel search paths/IO requests per iteration
prune_ratio: Ratio of neighbors to prune via approximate distance (0.0-1.0)
recompute_embeddings: Whether to fetch fresh embeddings from server vs use stored codes
pruning_strategy: Candidate selection strategy - "global" (default), "local", or "proportional"
expected_zmq_port: ZMQ port for embedding server communication
metadata_filters: Optional filters to apply to search results based on metadata.
Format: {"field_name": {"operator": value}}
Supported operators:
- Comparison: "==", "!=", "<", "<=", ">", ">="
- Membership: "in", "not_in"
- String: "contains", "starts_with", "ends_with"
Example: {"chapter": {"<=": 5}, "tags": {"in": ["fiction", "drama"]}}
**kwargs: Backend-specific parameters
Returns:
List of SearchResult objects with text, metadata, and similarity scores
"""
logger.info("🔍 LeannSearcher.search() called:")
logger.info(f" Query: '{query}'")
logger.info(f" Top_k: {top_k}")
logger.info(f" Metadata filters: {metadata_filters}")
logger.info(f" Additional kwargs: {kwargs}")
# Smart top_k detection and adjustment
total_docs = len(self.passage_manager.global_offset_map)
# Use PassageManager length (sum of shard sizes) to avoid
# depending on a massive combined map
total_docs = len(self.passage_manager)
original_top_k = top_k
if top_k > total_docs:
top_k = total_docs
@@ -629,23 +733,33 @@ class LeannSearcher:
use_server_if_available=recompute_embeddings,
zmq_port=zmq_port,
)
# logger.info(f" Generated embedding shape: {query_embedding.shape}")
# time.time() - start_time
# logger.info(f" Embedding time: {embedding_time} seconds")
logger.info(f" Generated embedding shape: {query_embedding.shape}")
embedding_time = time.time() - start_time
logger.info(f" Embedding time: {embedding_time} seconds")
start_time = time.time()
backend_search_kwargs: dict[str, Any] = {
"complexity": complexity,
"beam_width": beam_width,
"prune_ratio": prune_ratio,
"recompute_embeddings": recompute_embeddings,
"pruning_strategy": pruning_strategy,
"zmq_port": zmq_port,
}
# Only HNSW supports batching; forward conditionally
if self.backend_name == "hnsw":
backend_search_kwargs["batch_size"] = batch_size
# Merge any extra kwargs last
backend_search_kwargs.update(kwargs)
results = self.backend_impl.search(
query_embedding,
top_k,
complexity=complexity,
beam_width=beam_width,
prune_ratio=prune_ratio,
recompute_embeddings=recompute_embeddings,
pruning_strategy=pruning_strategy,
zmq_port=zmq_port,
**kwargs,
**backend_search_kwargs,
)
# logger.info(f" Search time: {search_time} seconds")
search_time = time.time() - start_time
logger.info(f" Search time in search() LEANN searcher: {search_time} seconds")
logger.info(f" Backend returned: labels={len(results.get('labels', [[]])[0])} results")
enriched_results = []
@@ -684,6 +798,13 @@ class LeannSearcher:
f" {RED}{RESET} [{i + 1:2d}] ID: '{string_id}' -> {RED}ERROR: Passage not found!{RESET}"
)
# Apply metadata filters if specified
if metadata_filters:
logger.info(f" 🔍 Applying metadata filters: {metadata_filters}")
enriched_results = self.passage_manager.filter_search_results(
enriched_results, metadata_filters
)
# Define color codes outside the loop for final message
GREEN = "\033[92m"
RESET = "\033[0m"
@@ -724,9 +845,15 @@ class LeannChat:
index_path: str,
llm_config: Optional[dict[str, Any]] = None,
enable_warmup: bool = False,
searcher: Optional[LeannSearcher] = None,
**kwargs,
):
self.searcher = LeannSearcher(index_path, enable_warmup=enable_warmup, **kwargs)
if searcher is None:
self.searcher = LeannSearcher(index_path, enable_warmup=enable_warmup, **kwargs)
self._owns_searcher = True
else:
self.searcher = searcher
self._owns_searcher = False
self.llm = get_llm(llm_config)
def ask(
@@ -740,6 +867,8 @@ class LeannChat:
pruning_strategy: Literal["global", "local", "proportional"] = "global",
llm_kwargs: Optional[dict[str, Any]] = None,
expected_zmq_port: int = 5557,
metadata_filters: Optional[dict[str, dict[str, Union[str, int, float, bool, list]]]] = None,
batch_size: int = 0,
**search_kwargs,
):
if llm_kwargs is None:
@@ -754,10 +883,12 @@ class LeannChat:
recompute_embeddings=recompute_embeddings,
pruning_strategy=pruning_strategy,
expected_zmq_port=expected_zmq_port,
metadata_filters=metadata_filters,
batch_size=batch_size,
**search_kwargs,
)
search_time = time.time() - search_time
# logger.info(f" Search time: {search_time} seconds")
logger.info(f" Search time: {search_time} seconds")
context = "\n\n".join([r.text for r in results])
prompt = (
"Here is some retrieved context that might help answer your question:\n\n"
@@ -793,7 +924,9 @@ class LeannChat:
This method should be called after you're done using the chat interface,
especially in test environments or batch processing scenarios.
"""
if hasattr(self.searcher, "cleanup"):
# Only stop the embedding server if this LeannChat instance created the searcher.
# When a shared searcher is passed in, avoid shutting down the server to enable reuse.
if getattr(self, "_owns_searcher", False) and hasattr(self.searcher, "cleanup"):
self.searcher.cleanup()
# Enable automatic cleanup patterns

View File

@@ -1,7 +1,8 @@
import argparse
import asyncio
import sys
from pathlib import Path
from typing import Optional, Union
from typing import Any, Optional, Union
from llama_index.core import SimpleDirectoryReader
from llama_index.core.node_parser import SentenceSplitter
@@ -180,6 +181,29 @@ Examples:
default=50,
help="Code chunk overlap (default: 50)",
)
build_parser.add_argument(
"--use-ast-chunking",
action="store_true",
help="Enable AST-aware chunking for code files (requires astchunk)",
)
build_parser.add_argument(
"--ast-chunk-size",
type=int,
default=768,
help="AST chunk size in characters (default: 768)",
)
build_parser.add_argument(
"--ast-chunk-overlap",
type=int,
default=96,
help="AST chunk overlap in characters (default: 96)",
)
build_parser.add_argument(
"--ast-fallback-traditional",
action="store_true",
default=True,
help="Fall back to traditional chunking if AST chunking fails (default: True)",
)
# Search command
search_parser = subparsers.add_parser("search", help="Search documents")
@@ -833,6 +857,7 @@ Examples:
docs_paths: Union[str, list],
custom_file_types: Union[str, None] = None,
include_hidden: bool = False,
args: Optional[dict[str, Any]] = None,
):
# Handle both single path (string) and multiple paths (list) for backward compatibility
if isinstance(docs_paths, str):
@@ -1138,18 +1163,50 @@ Examples:
}
print("start chunking documents")
# Add progress bar for document chunking
for doc in tqdm(documents, desc="Chunking documents", unit="doc"):
# Check if this is a code file based on source path
source_path = doc.metadata.get("source", "")
is_code_file = any(source_path.endswith(ext) for ext in code_file_exts)
# Use appropriate parser based on file type
parser = self.code_parser if is_code_file else self.node_parser
nodes = parser.get_nodes_from_documents([doc])
# Check if AST chunking is requested
use_ast = getattr(args, "use_ast_chunking", False)
for node in nodes:
all_texts.append(node.get_content())
if use_ast:
print("🧠 Using AST-aware chunking for code files")
try:
# Import enhanced chunking utilities
# Add apps directory to path to import chunking utilities
apps_dir = Path(__file__).parent.parent.parent.parent.parent / "apps"
if apps_dir.exists():
sys.path.insert(0, str(apps_dir))
from chunking import create_text_chunks
# Use enhanced chunking with AST support
all_texts = create_text_chunks(
documents,
chunk_size=self.node_parser.chunk_size,
chunk_overlap=self.node_parser.chunk_overlap,
use_ast_chunking=True,
ast_chunk_size=getattr(args, "ast_chunk_size", 768),
ast_chunk_overlap=getattr(args, "ast_chunk_overlap", 96),
code_file_extensions=None, # Use defaults
ast_fallback_traditional=getattr(args, "ast_fallback_traditional", True),
)
except ImportError as e:
print(f"⚠️ AST chunking not available ({e}), falling back to traditional chunking")
use_ast = False
if not use_ast:
# Use traditional chunking logic
for doc in tqdm(documents, desc="Chunking documents", unit="doc"):
# Check if this is a code file based on source path
source_path = doc.metadata.get("source", "")
is_code_file = any(source_path.endswith(ext) for ext in code_file_exts)
# Use appropriate parser based on file type
parser = self.code_parser if is_code_file else self.node_parser
nodes = parser.get_nodes_from_documents([doc])
for node in nodes:
all_texts.append(node.get_content())
print(f"Loaded {len(documents)} documents, {len(all_texts)} chunks")
return all_texts
@@ -1216,7 +1273,7 @@ Examples:
)
all_texts = self.load_documents(
docs_paths, args.file_types, include_hidden=args.include_hidden
docs_paths, args.file_types, include_hidden=args.include_hidden, args=args
)
if not all_texts:
print("No documents found")

View File

@@ -6,6 +6,7 @@ Preserves all optimization parameters to ensure performance
import logging
import os
import time
from typing import Any
import numpy as np
@@ -28,6 +29,8 @@ def compute_embeddings(
is_build: bool = False,
batch_size: int = 32,
adaptive_optimization: bool = True,
manual_tokenize: bool = False,
max_length: int = 512,
) -> np.ndarray:
"""
Unified embedding computation entry point
@@ -50,6 +53,8 @@ def compute_embeddings(
is_build=is_build,
batch_size=batch_size,
adaptive_optimization=adaptive_optimization,
manual_tokenize=manual_tokenize,
max_length=max_length,
)
elif mode == "openai":
return compute_embeddings_openai(texts, model_name)
@@ -71,6 +76,8 @@ def compute_embeddings_sentence_transformers(
batch_size: int = 32,
is_build: bool = False,
adaptive_optimization: bool = True,
manual_tokenize: bool = False,
max_length: int = 512,
) -> np.ndarray:
"""
Compute embeddings using SentenceTransformer with model caching and adaptive optimization
@@ -214,20 +221,130 @@ def compute_embeddings_sentence_transformers(
logger.info(f"Model cached: {cache_key}")
# Compute embeddings with optimized inference mode
logger.info(f"Starting embedding computation... (batch_size: {batch_size})")
logger.info(
f"Starting embedding computation... (batch_size: {batch_size}, manual_tokenize={manual_tokenize})"
)
# Use torch.inference_mode for optimal performance
with torch.inference_mode():
embeddings = model.encode(
texts,
batch_size=batch_size,
show_progress_bar=is_build, # Don't show progress bar in server environment
convert_to_numpy=True,
normalize_embeddings=False,
device=device,
)
start_time = time.time()
if not manual_tokenize:
# Use SentenceTransformer's optimized encode path (default)
with torch.inference_mode():
embeddings = model.encode(
texts,
batch_size=batch_size,
show_progress_bar=is_build, # Don't show progress bar in server environment
convert_to_numpy=True,
normalize_embeddings=False,
device=device,
)
# Synchronize if CUDA to measure accurate wall time
try:
if torch.cuda.is_available():
torch.cuda.synchronize()
except Exception:
pass
else:
# Manual tokenization + forward pass using HF AutoTokenizer/AutoModel
try:
from transformers import AutoModel, AutoTokenizer # type: ignore
except Exception as e:
raise ImportError(f"transformers is required for manual_tokenize=True: {e}")
# Cache tokenizer and model
tok_cache_key = f"hf_tokenizer_{model_name}"
mdl_cache_key = f"hf_model_{model_name}_{device}_{use_fp16}"
if tok_cache_key in _model_cache and mdl_cache_key in _model_cache:
hf_tokenizer = _model_cache[tok_cache_key]
hf_model = _model_cache[mdl_cache_key]
logger.info("Using cached HF tokenizer/model for manual path")
else:
logger.info("Loading HF tokenizer/model for manual tokenization path")
hf_tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
torch_dtype = torch.float16 if (use_fp16 and device == "cuda") else torch.float32
hf_model = AutoModel.from_pretrained(model_name, torch_dtype=torch_dtype)
hf_model.to(device)
hf_model.eval()
# Optional compile on supported devices
if device in ["cuda", "mps"]:
try:
hf_model = torch.compile(hf_model, mode="reduce-overhead", dynamic=True) # type: ignore
except Exception:
pass
_model_cache[tok_cache_key] = hf_tokenizer
_model_cache[mdl_cache_key] = hf_model
all_embeddings: list[np.ndarray] = []
# Progress bar when building or for large inputs
show_progress = is_build or len(texts) > 32
try:
if show_progress:
from tqdm import tqdm # type: ignore
batch_iter = tqdm(
range(0, len(texts), batch_size),
desc="Embedding (manual)",
unit="batch",
)
else:
batch_iter = range(0, len(texts), batch_size)
except Exception:
batch_iter = range(0, len(texts), batch_size)
start_time_manual = time.time()
with torch.inference_mode():
for start_index in batch_iter:
end_index = min(start_index + batch_size, len(texts))
batch_texts = texts[start_index:end_index]
tokenize_start_time = time.time()
inputs = hf_tokenizer(
batch_texts,
padding=True,
truncation=True,
max_length=max_length,
return_tensors="pt",
)
tokenize_end_time = time.time()
logger.info(
f"Tokenize time taken: {tokenize_end_time - tokenize_start_time} seconds"
)
# Print shapes of all input tensors for debugging
for k, v in inputs.items():
print(f"inputs[{k!r}] shape: {getattr(v, 'shape', type(v))}")
to_device_start_time = time.time()
inputs = {k: v.to(device) for k, v in inputs.items()}
to_device_end_time = time.time()
logger.info(
f"To device time taken: {to_device_end_time - to_device_start_time} seconds"
)
forward_start_time = time.time()
outputs = hf_model(**inputs)
forward_end_time = time.time()
logger.info(f"Forward time taken: {forward_end_time - forward_start_time} seconds")
last_hidden_state = outputs.last_hidden_state # (B, L, H)
attention_mask = inputs.get("attention_mask")
if attention_mask is None:
# Fallback: assume all tokens are valid
pooled = last_hidden_state.mean(dim=1)
else:
mask = attention_mask.unsqueeze(-1).to(last_hidden_state.dtype)
masked = last_hidden_state * mask
lengths = mask.sum(dim=1).clamp(min=1)
pooled = masked.sum(dim=1) / lengths
# Move to CPU float32
batch_embeddings = pooled.detach().to("cpu").float().numpy()
all_embeddings.append(batch_embeddings)
embeddings = np.vstack(all_embeddings).astype(np.float32, copy=False)
try:
if torch.cuda.is_available():
torch.cuda.synchronize()
except Exception:
pass
end_time = time.time()
logger.info(f"Manual tokenize time taken: {end_time - start_time_manual} seconds")
end_time = time.time()
logger.info(f"Generated {len(embeddings)} embeddings, dimension: {embeddings.shape[1]}")
logger.info(f"Time taken: {end_time - start_time} seconds")
# Validate results
if np.isnan(embeddings).any() or np.isinf(embeddings).any():

View File

@@ -192,6 +192,7 @@ class EmbeddingServerManager:
stderr_target = None # Direct to console for visible logs
# Start embedding server subprocess
logger.info(f"Starting server process with command: {' '.join(command)}")
self.server_process = subprocess.Popen(
command,
cwd=project_root,

View File

@@ -0,0 +1,240 @@
"""
Metadata filtering engine for LEANN search results.
This module provides generic metadata filtering capabilities that can be applied
to search results from any LEANN backend. The filtering supports various
operators for different data types including numbers, strings, booleans, and lists.
"""
import logging
from typing import Any, Union
logger = logging.getLogger(__name__)
# Type alias for filter specifications
FilterValue = Union[str, int, float, bool, list]
FilterSpec = dict[str, FilterValue]
MetadataFilters = dict[str, FilterSpec]
class MetadataFilterEngine:
"""
Engine for evaluating metadata filters against search results.
Supports various operators for filtering based on metadata fields:
- Comparison: ==, !=, <, <=, >, >=
- Membership: in, not_in
- String operations: contains, starts_with, ends_with
- Boolean operations: is_true, is_false
"""
def __init__(self):
"""Initialize the filter engine with supported operators."""
self.operators = {
"==": self._equals,
"!=": self._not_equals,
"<": self._less_than,
"<=": self._less_than_or_equal,
">": self._greater_than,
">=": self._greater_than_or_equal,
"in": self._in,
"not_in": self._not_in,
"contains": self._contains,
"starts_with": self._starts_with,
"ends_with": self._ends_with,
"is_true": self._is_true,
"is_false": self._is_false,
}
def apply_filters(
self, search_results: list[dict[str, Any]], metadata_filters: MetadataFilters
) -> list[dict[str, Any]]:
"""
Apply metadata filters to a list of search results.
Args:
search_results: List of result dictionaries, each containing 'metadata' field
metadata_filters: Dictionary of filter specifications
Format: {"field_name": {"operator": value}}
Returns:
Filtered list of search results
"""
if not metadata_filters:
return search_results
logger.debug(f"Applying filters: {metadata_filters}")
logger.debug(f"Input results count: {len(search_results)}")
filtered_results = []
for result in search_results:
if self._evaluate_filters(result, metadata_filters):
filtered_results.append(result)
logger.debug(f"Filtered results count: {len(filtered_results)}")
return filtered_results
def _evaluate_filters(self, result: dict[str, Any], filters: MetadataFilters) -> bool:
"""
Evaluate all filters against a single search result.
All filters must pass (AND logic) for the result to be included.
Args:
result: Full search result dictionary (including metadata, text, etc.)
filters: Filter specifications to evaluate
Returns:
True if all filters pass, False otherwise
"""
for field_name, filter_spec in filters.items():
if not self._evaluate_field_filter(result, field_name, filter_spec):
return False
return True
def _evaluate_field_filter(
self, result: dict[str, Any], field_name: str, filter_spec: FilterSpec
) -> bool:
"""
Evaluate a single field filter against a search result.
Args:
result: Full search result dictionary
field_name: Name of the field to filter on
filter_spec: Filter specification for this field
Returns:
True if the filter passes, False otherwise
"""
# First check top-level fields, then check metadata
field_value = result.get(field_name)
if field_value is None:
# Try to get from metadata if not found at top level
metadata = result.get("metadata", {})
field_value = metadata.get(field_name)
# Handle missing fields - they fail all filters except existence checks
if field_value is None:
logger.debug(f"Field '{field_name}' not found in result or metadata")
return False
# Evaluate each operator in the filter spec
for operator, expected_value in filter_spec.items():
if operator not in self.operators:
logger.warning(f"Unsupported operator: {operator}")
return False
try:
if not self.operators[operator](field_value, expected_value):
logger.debug(
f"Filter failed: {field_name} {operator} {expected_value} "
f"(actual: {field_value})"
)
return False
except Exception as e:
logger.warning(
f"Error evaluating filter {field_name} {operator} {expected_value}: {e}"
)
return False
return True
# Comparison operators
def _equals(self, field_value: Any, expected_value: Any) -> bool:
"""Check if field value equals expected value."""
return field_value == expected_value
def _not_equals(self, field_value: Any, expected_value: Any) -> bool:
"""Check if field value does not equal expected value."""
return field_value != expected_value
def _less_than(self, field_value: Any, expected_value: Any) -> bool:
"""Check if field value is less than expected value."""
return self._numeric_compare(field_value, expected_value, lambda a, b: a < b)
def _less_than_or_equal(self, field_value: Any, expected_value: Any) -> bool:
"""Check if field value is less than or equal to expected value."""
return self._numeric_compare(field_value, expected_value, lambda a, b: a <= b)
def _greater_than(self, field_value: Any, expected_value: Any) -> bool:
"""Check if field value is greater than expected value."""
return self._numeric_compare(field_value, expected_value, lambda a, b: a > b)
def _greater_than_or_equal(self, field_value: Any, expected_value: Any) -> bool:
"""Check if field value is greater than or equal to expected value."""
return self._numeric_compare(field_value, expected_value, lambda a, b: a >= b)
# Membership operators
def _in(self, field_value: Any, expected_value: Any) -> bool:
"""Check if field value is in the expected list/collection."""
if not isinstance(expected_value, (list, tuple, set)):
raise ValueError("'in' operator requires a list, tuple, or set")
return field_value in expected_value
def _not_in(self, field_value: Any, expected_value: Any) -> bool:
"""Check if field value is not in the expected list/collection."""
if not isinstance(expected_value, (list, tuple, set)):
raise ValueError("'not_in' operator requires a list, tuple, or set")
return field_value not in expected_value
# String operators
def _contains(self, field_value: Any, expected_value: Any) -> bool:
"""Check if field value contains the expected substring."""
field_str = str(field_value)
expected_str = str(expected_value)
return expected_str in field_str
def _starts_with(self, field_value: Any, expected_value: Any) -> bool:
"""Check if field value starts with the expected prefix."""
field_str = str(field_value)
expected_str = str(expected_value)
return field_str.startswith(expected_str)
def _ends_with(self, field_value: Any, expected_value: Any) -> bool:
"""Check if field value ends with the expected suffix."""
field_str = str(field_value)
expected_str = str(expected_value)
return field_str.endswith(expected_str)
# Boolean operators
def _is_true(self, field_value: Any, expected_value: Any) -> bool:
"""Check if field value is truthy."""
return bool(field_value)
def _is_false(self, field_value: Any, expected_value: Any) -> bool:
"""Check if field value is falsy."""
return not bool(field_value)
# Helper methods
def _numeric_compare(self, field_value: Any, expected_value: Any, compare_func) -> bool:
"""
Helper for numeric comparisons with type coercion.
Args:
field_value: Value from metadata
expected_value: Value to compare against
compare_func: Comparison function to apply
Returns:
Result of comparison
"""
try:
# Try to convert both values to numbers for comparison
if isinstance(field_value, str) and isinstance(expected_value, str):
# String comparison if both are strings
return compare_func(field_value, expected_value)
# Numeric comparison - attempt to convert to float
field_num = (
float(field_value) if not isinstance(field_value, (int, float)) else field_value
)
expected_num = (
float(expected_value)
if not isinstance(expected_value, (int, float))
else expected_value
)
return compare_func(field_num, expected_num)
except (ValueError, TypeError):
# Fall back to string comparison if numeric conversion fails
return compare_func(str(field_value), str(expected_value))