Metadata filtering feature (#75)
* Metadata filtering initial version * Metadata filtering initial version * Fixes linter issues * Cleanup code * Clean up and readme * Fix after review * Use UV in example * Merge main into feature/metadata-filtering
This commit is contained in:
@@ -10,7 +10,7 @@ import time
|
||||
import warnings
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Literal, Optional
|
||||
from typing import Any, Literal, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -18,6 +18,7 @@ from leann.interface import LeannBackendSearcherInterface
|
||||
|
||||
from .chat import get_llm
|
||||
from .interface import LeannBackendFactoryInterface
|
||||
from .metadata_filter import MetadataFilterEngine
|
||||
from .registry import BACKEND_REGISTRY
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -125,6 +126,7 @@ class PassageManager:
|
||||
# footprint on very large corpora (e.g., 60M+ passages). Instead, keep
|
||||
# per-shard maps and do a lightweight per-shard lookup on demand.
|
||||
self._total_count: int = 0
|
||||
self.filter_engine = MetadataFilterEngine() # Initialize filter engine
|
||||
|
||||
# Derive index base name for standard sibling fallbacks, e.g., <index_name>.passages.*
|
||||
index_name_base = None
|
||||
@@ -212,6 +214,56 @@ class PassageManager:
|
||||
continue
|
||||
raise KeyError(f"Passage ID not found: {passage_id}")
|
||||
|
||||
def filter_search_results(
|
||||
self,
|
||||
search_results: list[SearchResult],
|
||||
metadata_filters: Optional[dict[str, dict[str, Union[str, int, float, bool, list]]]],
|
||||
) -> list[SearchResult]:
|
||||
"""
|
||||
Apply metadata filters to search results.
|
||||
|
||||
Args:
|
||||
search_results: List of SearchResult objects
|
||||
metadata_filters: Filter specifications to apply
|
||||
|
||||
Returns:
|
||||
Filtered list of SearchResult objects
|
||||
"""
|
||||
if not metadata_filters:
|
||||
return search_results
|
||||
|
||||
logger.debug(f"Applying metadata filters to {len(search_results)} results")
|
||||
|
||||
# Convert SearchResult objects to dictionaries for the filter engine
|
||||
result_dicts = []
|
||||
for result in search_results:
|
||||
result_dicts.append(
|
||||
{
|
||||
"id": result.id,
|
||||
"score": result.score,
|
||||
"text": result.text,
|
||||
"metadata": result.metadata,
|
||||
}
|
||||
)
|
||||
|
||||
# Apply filters using the filter engine
|
||||
filtered_dicts = self.filter_engine.apply_filters(result_dicts, metadata_filters)
|
||||
|
||||
# Convert back to SearchResult objects
|
||||
filtered_results = []
|
||||
for result_dict in filtered_dicts:
|
||||
filtered_results.append(
|
||||
SearchResult(
|
||||
id=result_dict["id"],
|
||||
score=result_dict["score"],
|
||||
text=result_dict["text"],
|
||||
metadata=result_dict["metadata"],
|
||||
)
|
||||
)
|
||||
|
||||
logger.debug(f"Filtered results: {len(filtered_results)} remaining")
|
||||
return filtered_results
|
||||
|
||||
def __len__(self) -> int:
|
||||
return self._total_count
|
||||
|
||||
@@ -599,12 +651,38 @@ class LeannSearcher:
|
||||
recompute_embeddings: bool = True,
|
||||
pruning_strategy: Literal["global", "local", "proportional"] = "global",
|
||||
expected_zmq_port: int = 5557,
|
||||
metadata_filters: Optional[dict[str, dict[str, Union[str, int, float, bool, list]]]] = None,
|
||||
batch_size: int = 0,
|
||||
**kwargs,
|
||||
) -> list[SearchResult]:
|
||||
"""
|
||||
Search for nearest neighbors with optional metadata filtering.
|
||||
|
||||
Args:
|
||||
query: Text query to search for
|
||||
top_k: Number of nearest neighbors to return
|
||||
complexity: Search complexity/candidate list size, higher = more accurate but slower
|
||||
beam_width: Number of parallel search paths/IO requests per iteration
|
||||
prune_ratio: Ratio of neighbors to prune via approximate distance (0.0-1.0)
|
||||
recompute_embeddings: Whether to fetch fresh embeddings from server vs use stored codes
|
||||
pruning_strategy: Candidate selection strategy - "global" (default), "local", or "proportional"
|
||||
expected_zmq_port: ZMQ port for embedding server communication
|
||||
metadata_filters: Optional filters to apply to search results based on metadata.
|
||||
Format: {"field_name": {"operator": value}}
|
||||
Supported operators:
|
||||
- Comparison: "==", "!=", "<", "<=", ">", ">="
|
||||
- Membership: "in", "not_in"
|
||||
- String: "contains", "starts_with", "ends_with"
|
||||
Example: {"chapter": {"<=": 5}, "tags": {"in": ["fiction", "drama"]}}
|
||||
**kwargs: Backend-specific parameters
|
||||
|
||||
Returns:
|
||||
List of SearchResult objects with text, metadata, and similarity scores
|
||||
"""
|
||||
logger.info("🔍 LeannSearcher.search() called:")
|
||||
logger.info(f" Query: '{query}'")
|
||||
logger.info(f" Top_k: {top_k}")
|
||||
logger.info(f" Metadata filters: {metadata_filters}")
|
||||
logger.info(f" Additional kwargs: {kwargs}")
|
||||
|
||||
# Smart top_k detection and adjustment
|
||||
@@ -704,6 +782,13 @@ class LeannSearcher:
|
||||
f" {RED}✗{RESET} [{i + 1:2d}] ID: '{string_id}' -> {RED}ERROR: Passage not found!{RESET}"
|
||||
)
|
||||
|
||||
# Apply metadata filters if specified
|
||||
if metadata_filters:
|
||||
logger.info(f" 🔍 Applying metadata filters: {metadata_filters}")
|
||||
enriched_results = self.passage_manager.filter_search_results(
|
||||
enriched_results, metadata_filters
|
||||
)
|
||||
|
||||
# Define color codes outside the loop for final message
|
||||
GREEN = "\033[92m"
|
||||
RESET = "\033[0m"
|
||||
@@ -766,6 +851,7 @@ class LeannChat:
|
||||
pruning_strategy: Literal["global", "local", "proportional"] = "global",
|
||||
llm_kwargs: Optional[dict[str, Any]] = None,
|
||||
expected_zmq_port: int = 5557,
|
||||
metadata_filters: Optional[dict[str, dict[str, Union[str, int, float, bool, list]]]] = None,
|
||||
batch_size: int = 0,
|
||||
**search_kwargs,
|
||||
):
|
||||
@@ -781,6 +867,7 @@ class LeannChat:
|
||||
recompute_embeddings=recompute_embeddings,
|
||||
pruning_strategy=pruning_strategy,
|
||||
expected_zmq_port=expected_zmq_port,
|
||||
metadata_filters=metadata_filters,
|
||||
batch_size=batch_size,
|
||||
**search_kwargs,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user