Compare commits
2 Commits: fix/drop-p ... main

| Author | SHA1 | Date |
|---|---|---|
|  | 198044d033 |  |
|  | a2e5f5294b |  |

.github/workflows/build-reusable.yml (vendored, 118 changes)
@@ -28,15 +28,36 @@ jobs:
       run: |
         uv run --only-group lint pre-commit run --all-files --show-diff-on-failure

+  type-check:
+    name: Type Check with ty
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v4
+        with:
+          ref: ${{ inputs.ref }}
+          submodules: recursive
+
+      - name: Install uv and Python
+        uses: astral-sh/setup-uv@v6
+        with:
+          python-version: '3.11'
+
+      - name: Install ty
+        run: uv tool install ty
+
+      - name: Run ty type checker
+        run: |
+          # Run ty on core packages, apps, and tests
+          ty check packages/leann-core/src apps tests
+
   build:
-    needs: lint
+    needs: [lint, type-check]
     name: Build ${{ matrix.os }} Python ${{ matrix.python }}
     strategy:
       matrix:
         include:
-          - os: ubuntu-22.04
-            python: '3.9'
+          # Note: Python 3.9 dropped - uses PEP 604 union syntax (str | None)
+          # which requires Python 3.10+
           - os: ubuntu-22.04
             python: '3.10'
           - os: ubuntu-22.04
@@ -46,8 +67,6 @@ jobs:
           - os: ubuntu-22.04
             python: '3.13'
           # ARM64 Linux builds
-          - os: ubuntu-24.04-arm
-            python: '3.9'
           - os: ubuntu-24.04-arm
             python: '3.10'
           - os: ubuntu-24.04-arm
@@ -56,8 +75,6 @@ jobs:
             python: '3.12'
           - os: ubuntu-24.04-arm
             python: '3.13'
-          - os: macos-14
-            python: '3.9'
           - os: macos-14
             python: '3.10'
           - os: macos-14
@@ -66,8 +83,6 @@ jobs:
             python: '3.12'
           - os: macos-14
             python: '3.13'
-          - os: macos-15
-            python: '3.9'
           - os: macos-15
             python: '3.10'
           - os: macos-15
@@ -76,16 +91,24 @@ jobs:
             python: '3.12'
          - os: macos-15
             python: '3.13'
-          - os: macos-13
-            python: '3.9'
-          - os: macos-13
+          # Intel Mac builds (x86_64) - replaces deprecated macos-13
+          # Note: Python 3.13 excluded - PyTorch has no wheels for macOS x86_64 + Python 3.13
+          # (PyTorch <=2.4.1 lacks cp313, PyTorch >=2.5.0 dropped Intel Mac support)
+          - os: macos-15-intel
             python: '3.10'
-          - os: macos-13
+          - os: macos-15-intel
             python: '3.11'
-          - os: macos-13
+          - os: macos-15-intel
             python: '3.12'
-          # Note: macos-13 + Python 3.13 excluded due to PyTorch compatibility
-          # (PyTorch 2.5+ supports Python 3.13 but not Intel Mac x86_64)
+          # macOS 26 (beta) - arm64
+          - os: macos-26
+            python: '3.10'
+          - os: macos-26
+            python: '3.11'
+          - os: macos-26
+            python: '3.12'
+          - os: macos-26
+            python: '3.13'
     runs-on: ${{ matrix.os }}

     steps:
@@ -204,13 +227,16 @@ jobs:
           # Use system clang for better compatibility
           export CC=clang
           export CXX=clang++
-          # Homebrew libraries on each macOS version require matching minimum version
-          if [[ "${{ matrix.os }}" == "macos-13" ]]; then
-            export MACOSX_DEPLOYMENT_TARGET=13.0
-          elif [[ "${{ matrix.os }}" == "macos-14" ]]; then
-            export MACOSX_DEPLOYMENT_TARGET=14.0
-          elif [[ "${{ matrix.os }}" == "macos-15" ]]; then
+          # Set deployment target based on runner
+          # macos-15-intel runs macOS 15, so target 15.0 (system libraries require it)
+          if [[ "${{ matrix.os }}" == "macos-15-intel" ]]; then
             export MACOSX_DEPLOYMENT_TARGET=15.0
+          elif [[ "${{ matrix.os }}" == macos-14* ]]; then
+            export MACOSX_DEPLOYMENT_TARGET=14.0
+          elif [[ "${{ matrix.os }}" == macos-15* ]]; then
+            export MACOSX_DEPLOYMENT_TARGET=15.0
+          elif [[ "${{ matrix.os }}" == macos-26* ]]; then
+            export MACOSX_DEPLOYMENT_TARGET=26.0
           fi
           uv build --wheel --python ${{ matrix.python }} --find-links ${GITHUB_WORKSPACE}/packages/leann-core/dist
         else
@@ -224,14 +250,16 @@ jobs:
           # Use system clang for better compatibility
           export CC=clang
           export CXX=clang++
-          # DiskANN requires macOS 13.3+ for sgesdd_ LAPACK function
-          # But Homebrew libraries on each macOS version require matching minimum version
-          if [[ "${{ matrix.os }}" == "macos-13" ]]; then
-            export MACOSX_DEPLOYMENT_TARGET=13.3
-          elif [[ "${{ matrix.os }}" == "macos-14" ]]; then
-            export MACOSX_DEPLOYMENT_TARGET=14.0
-          elif [[ "${{ matrix.os }}" == "macos-15" ]]; then
+          # Set deployment target based on runner
+          # macos-15-intel runs macOS 15, so target 15.0 (system libraries require it)
+          if [[ "${{ matrix.os }}" == "macos-15-intel" ]]; then
             export MACOSX_DEPLOYMENT_TARGET=15.0
+          elif [[ "${{ matrix.os }}" == macos-14* ]]; then
+            export MACOSX_DEPLOYMENT_TARGET=14.0
+          elif [[ "${{ matrix.os }}" == macos-15* ]]; then
+            export MACOSX_DEPLOYMENT_TARGET=15.0
+          elif [[ "${{ matrix.os }}" == macos-26* ]]; then
+            export MACOSX_DEPLOYMENT_TARGET=26.0
           fi
           uv build --wheel --python ${{ matrix.python }} --find-links ${GITHUB_WORKSPACE}/packages/leann-core/dist
         else
@@ -269,16 +297,19 @@ jobs:
         if: runner.os == 'macOS'
         run: |
           # Determine deployment target based on runner OS
-          # Must match the Homebrew libraries for each macOS version
-          if [[ "${{ matrix.os }}" == "macos-13" ]]; then
-            HNSW_TARGET="13.0"
-            DISKANN_TARGET="13.3"
-          elif [[ "${{ matrix.os }}" == "macos-14" ]]; then
-            HNSW_TARGET="14.0"
-            DISKANN_TARGET="14.0"
-          elif [[ "${{ matrix.os }}" == "macos-15" ]]; then
+          # macos-15-intel runs macOS 15, so target 15.0 (system libraries require it)
+          if [[ "${{ matrix.os }}" == "macos-15-intel" ]]; then
             HNSW_TARGET="15.0"
             DISKANN_TARGET="15.0"
+          elif [[ "${{ matrix.os }}" == macos-14* ]]; then
+            HNSW_TARGET="14.0"
+            DISKANN_TARGET="14.0"
+          elif [[ "${{ matrix.os }}" == macos-15* ]]; then
+            HNSW_TARGET="15.0"
+            DISKANN_TARGET="15.0"
+          elif [[ "${{ matrix.os }}" == macos-26* ]]; then
+            HNSW_TARGET="26.0"
+            DISKANN_TARGET="26.0"
           fi

           # Repair HNSW wheel
@@ -334,12 +365,15 @@ jobs:
           PY_TAG=$($UV_PY -c "import sys; print(f'cp{sys.version_info[0]}{sys.version_info[1]}')")

           if [[ "$RUNNER_OS" == "macOS" ]]; then
-            if [[ "${{ matrix.os }}" == "macos-13" ]]; then
-              export MACOSX_DEPLOYMENT_TARGET=13.3
-            elif [[ "${{ matrix.os }}" == "macos-14" ]]; then
-              export MACOSX_DEPLOYMENT_TARGET=14.0
-            elif [[ "${{ matrix.os }}" == "macos-15" ]]; then
+            # macos-15-intel runs macOS 15, so target 15.0 (system libraries require it)
+            if [[ "${{ matrix.os }}" == "macos-15-intel" ]]; then
               export MACOSX_DEPLOYMENT_TARGET=15.0
+            elif [[ "${{ matrix.os }}" == macos-14* ]]; then
+              export MACOSX_DEPLOYMENT_TARGET=14.0
+            elif [[ "${{ matrix.os }}" == macos-15* ]]; then
+              export MACOSX_DEPLOYMENT_TARGET=15.0
+            elif [[ "${{ matrix.os }}" == macos-26* ]]; then
+              export MACOSX_DEPLOYMENT_TARGET=26.0
             fi
           fi

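The matrix comment above attributes the dropped Python 3.9 builds to PEP 604 union syntax (`str | None`), which is only usable in annotations evaluated on Python 3.10+; on 3.9 the same annotation raises a TypeError when evaluated unless `from __future__ import annotations` is in effect. A minimal illustration of the two spellings (my own sketch, not code from this repository):

```python
from typing import Optional

# PEP 604 (Python 3.10+): unions written with the | operator.
def first_line(text: str) -> str | None:
    return text.splitlines()[0] if text else None

# Python 3.9-compatible spelling of the same annotation.
def first_line_legacy(text: str) -> Optional[str]:
    return text.splitlines()[0] if text else None
```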
@@ -6,7 +6,7 @@ Provides common parameters and functionality for all RAG examples.
 import argparse
 from abc import ABC, abstractmethod
 from pathlib import Path
-from typing import Any, Union
+from typing import Any

 import dotenv
 from leann.api import LeannBuilder, LeannChat
@@ -257,8 +257,8 @@ class BaseRAGExample(ABC):
         pass

     @abstractmethod
-    async def load_data(self, args) -> list[Union[str, dict[str, Any]]]:
-        """Load data from the source. Returns list of text chunks (strings or dicts with 'text' key)."""
+    async def load_data(self, args) -> list[dict[str, Any]]:
+        """Load data from the source. Returns list of text chunks as dicts with 'text' and 'metadata' keys."""
         pass

     def get_llm_config(self, args) -> dict[str, Any]:
@@ -282,8 +282,8 @@ class BaseRAGExample(ABC):

         return config

-    async def build_index(self, args, texts: list[Union[str, dict[str, Any]]]) -> str:
-        """Build LEANN index from texts (accepts strings or dicts with 'text' key)."""
+    async def build_index(self, args, texts: list[dict[str, Any]]) -> str:
+        """Build LEANN index from text chunks (dicts with 'text' and 'metadata' keys)."""
         index_path = str(Path(args.index_dir) / f"{self.default_index_name}.leann")

         print(f"\n[Building Index] Creating {self.name} index...")
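With this change, `load_data` implementations are expected to return dict chunks rather than bare strings, matching the updated docstring. A minimal sketch of the new contract, using a hypothetical subclass and made-up chunk values:

```python
from typing import Any

class ExampleRAG(BaseRAGExample):  # hypothetical subclass for illustration
    async def load_data(self, args) -> list[dict[str, Any]]:
        # Each chunk is a dict with "text" and "metadata" keys, as the
        # BaseRAGExample.load_data docstring now describes.
        return [
            {"text": "first chunk of source text", "metadata": {"source": "example"}},
            {"text": "second chunk of source text", "metadata": {"source": "example"}},
        ]
```

The Slack and Twitter apps later in this compare convert their string lists into exactly this shape before returning.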
@@ -6,6 +6,7 @@ Supports Chrome browser history.
 import os
 import sys
 from pathlib import Path
+from typing import Any

 # Add parent directory to path for imports
 sys.path.insert(0, str(Path(__file__).parent))
@@ -85,7 +86,7 @@ class BrowserRAG(BaseRAGExample):

         return profiles

-    async def load_data(self, args) -> list[str]:
+    async def load_data(self, args) -> list[dict[str, Any]]:
         """Load browser history and convert to text chunks."""
         # Determine Chrome profiles
         if args.chrome_profile and not args.auto_find_profiles:
@@ -5,6 +5,7 @@ Supports ChatGPT export data from chat.html files.

 import sys
 from pathlib import Path
+from typing import Any

 # Add parent directory to path for imports
 sys.path.insert(0, str(Path(__file__).parent))
@@ -80,7 +81,7 @@ class ChatGPTRAG(BaseRAGExample):

         return export_files

-    async def load_data(self, args) -> list[str]:
+    async def load_data(self, args) -> list[dict[str, Any]]:
         """Load ChatGPT export data and convert to text chunks."""
         export_path = Path(args.export_path)

@@ -5,6 +5,7 @@ Supports Claude export data from JSON files.

 import sys
 from pathlib import Path
+from typing import Any

 # Add parent directory to path for imports
 sys.path.insert(0, str(Path(__file__).parent))
@@ -80,7 +81,7 @@ class ClaudeRAG(BaseRAGExample):

         return export_files

-    async def load_data(self, args) -> list[str]:
+    async def load_data(self, args) -> list[dict[str, Any]]:
         """Load Claude export data and convert to text chunks."""
         export_path = Path(args.export_path)

@@ -6,6 +6,7 @@ optimized chunking parameters.

 import sys
 from pathlib import Path
+from typing import Any

 # Add parent directory to path for imports
 sys.path.insert(0, str(Path(__file__).parent))
@@ -77,7 +78,7 @@ class CodeRAG(BaseRAGExample):
             help="Try to preserve import statements in chunks (default: True)",
         )

-    async def load_data(self, args) -> list[str]:
+    async def load_data(self, args) -> list[dict[str, Any]]:
         """Load code files and convert to AST-aware chunks."""
         print(f"🔍 Scanning code repository: {args.repo_dir}")
         print(f"📁 Including extensions: {args.include_extensions}")
@@ -88,14 +89,6 @@ class CodeRAG(BaseRAGExample):
         if not repo_path.exists():
             raise ValueError(f"Repository directory not found: {args.repo_dir}")

-        # Load code files with filtering
-        reader_kwargs = {
-            "recursive": True,
-            "encoding": "utf-8",
-            "required_exts": args.include_extensions,
-            "exclude_hidden": True,
-        }
-
         # Create exclusion filter
         def file_filter(file_path: str) -> bool:
             """Filter out unwanted files and directories."""
@@ -120,8 +113,11 @@ class CodeRAG(BaseRAGExample):
         # Load documents with file filtering
         documents = SimpleDirectoryReader(
             args.repo_dir,
-            file_extractor=None,  # Use default extractors
-            **reader_kwargs,
+            file_extractor=None,
+            recursive=True,
+            encoding="utf-8",
+            required_exts=args.include_extensions,
+            exclude_hidden=True,
         ).load_data(show_progress=True)

         # Apply custom filtering
@@ -5,7 +5,7 @@ Supports PDF, TXT, MD, and other document formats.

 import sys
 from pathlib import Path
-from typing import Any, Union
+from typing import Any

 # Add parent directory to path for imports
 sys.path.insert(0, str(Path(__file__).parent))
@@ -52,7 +52,7 @@ class DocumentRAG(BaseRAGExample):
             help="Enable AST-aware chunking for code files in the data directory",
         )

-    async def load_data(self, args) -> list[Union[str, dict[str, Any]]]:
+    async def load_data(self, args) -> list[dict[str, Any]]:
         """Load documents and convert to text chunks."""
         print(f"Loading documents from: {args.data_dir}")
         if args.file_types:
@@ -66,16 +66,12 @@ class DocumentRAG(BaseRAGExample):
             raise ValueError(f"Data directory not found: {args.data_dir}")

         # Load documents
-        reader_kwargs = {
-            "recursive": True,
-            "encoding": "utf-8",
-        }
-        if args.file_types:
-            reader_kwargs["required_exts"] = args.file_types
-
-        documents = SimpleDirectoryReader(args.data_dir, **reader_kwargs).load_data(
-            show_progress=True
-        )
+        documents = SimpleDirectoryReader(
+            args.data_dir,
+            recursive=True,
+            encoding="utf-8",
+            required_exts=args.file_types if args.file_types else None,
+        ).load_data(show_progress=True)

         if not documents:
             print(f"No documents found in {args.data_dir} with extensions {args.file_types}")
@@ -127,11 +127,12 @@ class EmlxMboxReader(MboxReader):

     def load_data(
         self,
-        directory: Path,
+        file: Path,  # Note: for EmlxMboxReader, this is actually a directory
         extra_info: dict | None = None,
         fs: AbstractFileSystem | None = None,
     ) -> list[Document]:
         """Parse .emlx files from directory into strings using MboxReader logic."""
+        directory = file  # Rename for clarity - this is a directory of .emlx files
         import os
         import tempfile

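The parameter rename appears to keep the override signature compatible with the parent `MboxReader.load_data`, which takes `file`, so a type checker does not flag the override; the added comment and the `directory = file` rebinding preserve the original intent. A schematic of the pattern with hypothetical classes (not the actual reader API):

```python
from pathlib import Path

class BaseReaderSketch:
    def load_data(self, file: Path) -> list[str]:
        raise NotImplementedError

class EmlxReaderSketch(BaseReaderSketch):
    # Keep the parent's parameter name so the override matches its signature;
    # rebind locally because for this reader the argument is really a directory.
    def load_data(self, file: Path) -> list[str]:
        directory = file
        return [p.name for p in directory.glob("*.emlx")]
```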
@@ -5,6 +5,7 @@ Supports Apple Mail on macOS.

 import sys
 from pathlib import Path
+from typing import Any

 # Add parent directory to path for imports
 sys.path.insert(0, str(Path(__file__).parent))
@@ -64,7 +65,7 @@ class EmailRAG(BaseRAGExample):

         return messages_dirs

-    async def load_data(self, args) -> list[str]:
+    async def load_data(self, args) -> list[dict[str, Any]]:
         """Load emails and convert to text chunks."""
         # Determine mail directories
         if args.mail_path:
@@ -86,7 +86,7 @@ class WeChatHistoryReader(BaseReader):
                 text=True,
                 timeout=5,
             )
-            return result.returncode == 0 and result.stdout.strip()
+            return result.returncode == 0 and bool(result.stdout.strip())
         except Exception:
             return False

@@ -314,7 +314,9 @@ class WeChatHistoryReader(BaseReader):

         return concatenated_groups

-    def _create_concatenated_content(self, message_group: dict, contact_name: str) -> str:
+    def _create_concatenated_content(
+        self, message_group: dict, contact_name: str
+    ) -> tuple[str, str]:
         """
         Create concatenated content from a group of messages.

@@ -14,6 +14,7 @@ import argparse
 import pickle
 import tempfile
 from pathlib import Path
+from typing import Any

 import numpy as np
 from PIL import Image
@@ -65,7 +66,7 @@ class ImageRAG(BaseRAGExample):
             help="Batch size for CLIP embedding generation (default: 32)",
         )

-    async def load_data(self, args) -> list[str]:
+    async def load_data(self, args) -> list[dict[str, Any]]:
         """Load images, generate CLIP embeddings, and return text descriptions."""
         self._image_data = self._load_images_and_embeddings(args)
         return [entry["text"] for entry in self._image_data]
@@ -168,7 +169,7 @@ class ImageRAG(BaseRAGExample):
         print(f"✅ Processed {len(image_data)} images")
         return image_data

-    async def build_index(self, args, texts: list[str]) -> str:
+    async def build_index(self, args, texts: list[dict[str, Any]]) -> str:
         """Build index using pre-computed CLIP embeddings."""
         from leann.api import LeannBuilder

@@ -6,6 +6,7 @@ This example demonstrates how to build a RAG system on your iMessage conversatio

 import asyncio
 from pathlib import Path
+from typing import Any

 from leann.chunking_utils import create_text_chunks

@@ -56,7 +57,7 @@ class IMessageRAG(BaseRAGExample):
             help="Overlap between text chunks (default: 200)",
         )

-    async def load_data(self, args) -> list[str]:
+    async def load_data(self, args) -> list[dict[str, Any]]:
         """Load iMessage history and convert to text chunks."""
         print("Loading iMessage conversation history...")

@@ -18,10 +18,11 @@ _repo_root = Path(__file__).resolve().parents[3]
 _leann_core_src = _repo_root / "packages" / "leann-core" / "src"
 _leann_hnsw_pkg = _repo_root / "packages" / "leann-backend-hnsw"
 if str(_leann_core_src) not in sys.path:
-    sys.path.append(str(_leann_core_src))
+    sys.path.insert(0, str(_leann_core_src))
 if str(_leann_hnsw_pkg) not in sys.path:
-    sys.path.append(str(_leann_hnsw_pkg))
+    sys.path.insert(0, str(_leann_hnsw_pkg))

+from leann_multi_vector import LeannMultiVector

 import torch
 from colpali_engine.models import ColPali
@@ -93,9 +94,9 @@ for batch_doc in tqdm(dataloader):
 print(ds[0].shape)

 # %%
-# Build HNSW index via LeannRetriever primitives and run search
+# Build HNSW index via LeannMultiVector primitives and run search
 index_path = "./indexes/colpali.leann"
-retriever = LeannRetriever(index_path=index_path, dim=int(ds[0].shape[-1]))
+retriever = LeannMultiVector(index_path=index_path, dim=int(ds[0].shape[-1]))
 retriever.create_collection()
 filepaths = [os.path.join("./pages", name) for name in page_filenames]
 for i in range(len(filepaths)):
@@ -5,7 +5,7 @@ import argparse
 import faulthandler
 import os
 import time
-from typing import Any, Optional
+from typing import Any, Optional, cast

 import numpy as np
 from PIL import Image
@@ -223,7 +223,7 @@ if need_to_build_index:
     # Use filenames as identifiers instead of full paths for cleaner metadata
     filepaths = [os.path.basename(fp) for fp in filepaths]
 elif USE_HF_DATASET:
-    from datasets import load_dataset, concatenate_datasets, DatasetDict
+    from datasets import Dataset, DatasetDict, concatenate_datasets, load_dataset

     # Determine which datasets to load
     if DATASET_NAMES is not None:
@@ -281,12 +281,12 @@ if need_to_build_index:
             splits_to_load = DATASET_SPLITS

         # Load and concatenate multiple splits for this dataset
-        datasets_to_concat = []
+        datasets_to_concat: list[Dataset] = []
         for split in splits_to_load:
             if split not in dataset_dict:
                 print(f" Warning: Split '{split}' not found in dataset. Available splits: {list(dataset_dict.keys())}")
                 continue
-            split_dataset = dataset_dict[split]
+            split_dataset = cast(Dataset, dataset_dict[split])
             print(f" Loaded split '{split}': {len(split_dataset)} pages")
             datasets_to_concat.append(split_dataset)

@@ -25,9 +25,9 @@ Usage:
 import argparse
 import json
 import os
-from typing import Optional
+from typing import Any, Optional, cast

-from datasets import load_dataset
+from datasets import Dataset, load_dataset
 from leann_multi_vector import (
     ViDoReBenchmarkEvaluator,
     _ensure_repo_paths_importable,
@@ -151,40 +151,43 @@ def load_vidore_v1_data(
     """
     print(f"Loading dataset: {dataset_path} (split={split})")

-    # Load queries
-    query_ds = load_dataset(dataset_path, "queries", split=split, revision=revision)
+    # Load queries - cast to Dataset since we know split returns Dataset not DatasetDict
+    query_ds = cast(Dataset, load_dataset(dataset_path, "queries", split=split, revision=revision))

-    queries = {}
+    queries: dict[str, str] = {}
     for row in query_ds:
-        query_id = f"query-{split}-{row['query-id']}"
-        queries[query_id] = row["query"]
+        row_dict = cast(dict[str, Any], row)
+        query_id = f"query-{split}-{row_dict['query-id']}"
+        queries[query_id] = row_dict["query"]

-    # Load corpus (images)
-    corpus_ds = load_dataset(dataset_path, "corpus", split=split, revision=revision)
+    # Load corpus (images) - cast to Dataset
+    corpus_ds = cast(Dataset, load_dataset(dataset_path, "corpus", split=split, revision=revision))

-    corpus = {}
+    corpus: dict[str, Any] = {}
     for row in corpus_ds:
-        corpus_id = f"corpus-{split}-{row['corpus-id']}"
+        row_dict = cast(dict[str, Any], row)
+        corpus_id = f"corpus-{split}-{row_dict['corpus-id']}"
         # Extract image from the dataset row
-        if "image" in row:
-            corpus[corpus_id] = row["image"]
-        elif "page_image" in row:
-            corpus[corpus_id] = row["page_image"]
+        if "image" in row_dict:
+            corpus[corpus_id] = row_dict["image"]
+        elif "page_image" in row_dict:
+            corpus[corpus_id] = row_dict["page_image"]
         else:
             raise ValueError(
-                f"No image field found in corpus. Available fields: {list(row.keys())}"
+                f"No image field found in corpus. Available fields: {list(row_dict.keys())}"
             )

-    # Load qrels (relevance judgments)
-    qrels_ds = load_dataset(dataset_path, "qrels", split=split, revision=revision)
+    # Load qrels (relevance judgments) - cast to Dataset
+    qrels_ds = cast(Dataset, load_dataset(dataset_path, "qrels", split=split, revision=revision))

-    qrels = {}
+    qrels: dict[str, dict[str, int]] = {}
     for row in qrels_ds:
-        query_id = f"query-{split}-{row['query-id']}"
-        corpus_id = f"corpus-{split}-{row['corpus-id']}"
+        row_dict = cast(dict[str, Any], row)
+        query_id = f"query-{split}-{row_dict['query-id']}"
+        corpus_id = f"corpus-{split}-{row_dict['corpus-id']}"
         if query_id not in qrels:
             qrels[query_id] = {}
-        qrels[query_id][corpus_id] = int(row["score"])
+        qrels[query_id][corpus_id] = int(row_dict["score"])

     print(
         f"Loaded {len(queries)} queries, {len(corpus)} corpus items, {len(qrels)} query-relevance mappings"
@@ -234,8 +237,8 @@ def evaluate_task(
         raise ValueError(f"Unknown task: {task_name}. Available: {list(VIDORE_V1_TASKS.keys())}")

     task_config = VIDORE_V1_TASKS[task_name]
-    dataset_path = task_config["dataset_path"]
-    revision = task_config["revision"]
+    dataset_path = str(task_config["dataset_path"])
+    revision = str(task_config["revision"])

     # Load data
     corpus, queries, qrels = load_vidore_v1_data(
@@ -286,7 +289,7 @@ def evaluate_task(
     )

     # Search queries
-    task_prompt = task_config.get("prompt")
+    task_prompt = cast(Optional[dict[str, str]], task_config.get("prompt"))
     results = evaluator.search_queries(
         queries=queries,
         corpus_ids=corpus_ids_ordered,
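The `cast(Dataset, load_dataset(...))` calls address the fact that `datasets.load_dataset` is typed to return one of several container types (for example `Dataset` or `DatasetDict`), so code that requests a single split narrows the type for the checker with `typing.cast`, which has no runtime effect. The same idea in a toy sketch (hypothetical dataset id, not the benchmark code):

```python
from typing import Any, cast

from datasets import Dataset, load_dataset

# Requesting a concrete split yields a Dataset at runtime, but the static
# return type is a union, so cast() records the expectation for the checker.
ds = cast(Dataset, load_dataset("user/some-dataset", split="test"))

for row in ds:
    row_dict = cast(dict[str, Any], row)  # rows are also typed loosely
    print(sorted(row_dict.keys()))
```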
@@ -25,9 +25,9 @@ Usage:
 import argparse
 import json
 import os
-from typing import Optional
+from typing import Any, Optional, cast

-from datasets import load_dataset
+from datasets import Dataset, load_dataset
 from leann_multi_vector import (
     ViDoReBenchmarkEvaluator,
     _ensure_repo_paths_importable,
@@ -91,8 +91,8 @@ def load_vidore_v2_data(
     """
     print(f"Loading dataset: {dataset_path} (split={split}, language={language})")

-    # Load queries
-    query_ds = load_dataset(dataset_path, "queries", split=split, revision=revision)
+    # Load queries - cast to Dataset since we know split returns Dataset not DatasetDict
+    query_ds = cast(Dataset, load_dataset(dataset_path, "queries", split=split, revision=revision))

     # Check if dataset has language field before filtering
     has_language_field = len(query_ds) > 0 and "language" in query_ds.column_names
@@ -112,8 +112,9 @@ def load_vidore_v2_data(
         if len(query_ds_filtered) == 0:
             # Try to get a sample to see actual language values
             try:
-                sample_ds = load_dataset(
-                    dataset_path, "queries", split=split, revision=revision
+                sample_ds = cast(
+                    Dataset,
+                    load_dataset(dataset_path, "queries", split=split, revision=revision),
                 )
                 if len(sample_ds) > 0 and "language" in sample_ds.column_names:
                     sample_langs = set(sample_ds["language"])
@@ -126,37 +127,40 @@ def load_vidore_v2_data(
                 )
         query_ds = query_ds_filtered

-    queries = {}
+    queries: dict[str, str] = {}
     for row in query_ds:
-        query_id = f"query-{split}-{row['query-id']}"
-        queries[query_id] = row["query"]
+        row_dict = cast(dict[str, Any], row)
+        query_id = f"query-{split}-{row_dict['query-id']}"
+        queries[query_id] = row_dict["query"]

-    # Load corpus (images)
-    corpus_ds = load_dataset(dataset_path, "corpus", split=split, revision=revision)
+    # Load corpus (images) - cast to Dataset
+    corpus_ds = cast(Dataset, load_dataset(dataset_path, "corpus", split=split, revision=revision))

-    corpus = {}
+    corpus: dict[str, Any] = {}
     for row in corpus_ds:
-        corpus_id = f"corpus-{split}-{row['corpus-id']}"
+        row_dict = cast(dict[str, Any], row)
+        corpus_id = f"corpus-{split}-{row_dict['corpus-id']}"
         # Extract image from the dataset row
-        if "image" in row:
-            corpus[corpus_id] = row["image"]
-        elif "page_image" in row:
-            corpus[corpus_id] = row["page_image"]
+        if "image" in row_dict:
+            corpus[corpus_id] = row_dict["image"]
+        elif "page_image" in row_dict:
+            corpus[corpus_id] = row_dict["page_image"]
         else:
             raise ValueError(
-                f"No image field found in corpus. Available fields: {list(row.keys())}"
+                f"No image field found in corpus. Available fields: {list(row_dict.keys())}"
             )

-    # Load qrels (relevance judgments)
-    qrels_ds = load_dataset(dataset_path, "qrels", split=split, revision=revision)
+    # Load qrels (relevance judgments) - cast to Dataset
+    qrels_ds = cast(Dataset, load_dataset(dataset_path, "qrels", split=split, revision=revision))

-    qrels = {}
+    qrels: dict[str, dict[str, int]] = {}
     for row in qrels_ds:
-        query_id = f"query-{split}-{row['query-id']}"
-        corpus_id = f"corpus-{split}-{row['corpus-id']}"
+        row_dict = cast(dict[str, Any], row)
+        query_id = f"query-{split}-{row_dict['query-id']}"
+        corpus_id = f"corpus-{split}-{row_dict['corpus-id']}"
         if query_id not in qrels:
             qrels[query_id] = {}
-        qrels[query_id][corpus_id] = int(row["score"])
+        qrels[query_id][corpus_id] = int(row_dict["score"])

     print(
         f"Loaded {len(queries)} queries, {len(corpus)} corpus items, {len(qrels)} query-relevance mappings"
@@ -204,13 +208,13 @@ def evaluate_task(
         raise ValueError(f"Unknown task: {task_name}. Available: {list(VIDORE_V2_TASKS.keys())}")

     task_config = VIDORE_V2_TASKS[task_name]
-    dataset_path = task_config["dataset_path"]
-    revision = task_config["revision"]
+    dataset_path = str(task_config["dataset_path"])
+    revision = str(task_config["revision"])

     # Determine language
     if language is None:
         # Use first language if multiple available
-        languages = task_config.get("languages")
+        languages = cast(Optional[list[str]], task_config.get("languages"))
         if languages is None:
             # Task doesn't support language filtering (e.g., Vidore2ESGReportsHLRetrieval)
             language = None
@@ -269,7 +273,7 @@ def evaluate_task(
     )

     # Search queries
-    task_prompt = task_config.get("prompt")
+    task_prompt = cast(Optional[dict[str, str]], task_config.get("prompt"))
     results = evaluator.search_queries(
         queries=queries,
         corpus_ids=corpus_ids_ordered,
@@ -177,7 +177,9 @@ class SlackMCPReader:
                 break

         # If we get here, all retries failed or it's not a retryable error
-        raise last_exception
+        if last_exception is not None:
+            raise last_exception
+        raise RuntimeError("Unexpected error: no exception captured during retry loop")

     async def fetch_slack_messages(
         self, channel: Optional[str] = None, limit: int = 100
@@ -267,7 +269,10 @@ class SlackMCPReader:
                         messages = json.loads(content["text"])
                     except json.JSONDecodeError:
                         # If not JSON, try to parse as CSV format (Slack MCP server format)
-                        messages = self._parse_csv_messages(content["text"], channel)
+                        text_content = content.get("text", "")
+                        messages = self._parse_csv_messages(
+                            text_content if text_content else "", channel or "unknown"
+                        )
                 else:
                     messages = result["content"]
             else:
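The added guard addresses a common pattern where `last_exception` starts out as `None` inside a retry loop, so a bare `raise last_exception` can be flagged as possibly raising `None`. A minimal sketch of the shape of such a loop, using the same messages as the change above but not the reader's actual logic:

```python
from typing import Callable, Optional


def call_with_retries(func: Callable[[], object], attempts: int = 3) -> object:
    last_exception: Optional[Exception] = None
    for _ in range(attempts):
        try:
            return func()
        except ConnectionError as exc:  # retryable error for this sketch
            last_exception = exc
    # If we get here, all retries failed or it's not a retryable error
    if last_exception is not None:
        raise last_exception
    raise RuntimeError("Unexpected error: no exception captured during retry loop")
```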
@@ -11,6 +11,7 @@ Usage:

 import argparse
 import asyncio
+from typing import Any

 from apps.base_rag_example import BaseRAGExample
 from apps.slack_data.slack_mcp_reader import SlackMCPReader
@@ -139,7 +140,7 @@ class SlackMCPRAG(BaseRAGExample):
             print("4. Try running the MCP server command directly to test it")
             return False

-    async def load_data(self, args) -> list[str]:
+    async def load_data(self, args) -> list[dict[str, Any]]:
         """Load Slack messages via MCP server."""
         print(f"Connecting to Slack MCP server: {args.mcp_server}")

@@ -188,7 +189,8 @@ class SlackMCPRAG(BaseRAGExample):
                 print(sample_text)
                 print("-" * 40)

-            return texts
+            # Convert strings to dict format expected by base class
+            return [{"text": text, "metadata": {"source": "slack"}} for text in texts]

         except Exception as e:
             print(f"Error loading Slack data: {e}")
@@ -11,6 +11,7 @@ Usage:

 import argparse
 import asyncio
+from typing import Any

 from apps.base_rag_example import BaseRAGExample
 from apps.twitter_data.twitter_mcp_reader import TwitterMCPReader
@@ -116,7 +117,7 @@ class TwitterMCPRAG(BaseRAGExample):
             print("5. Try running the MCP server command directly to test it")
             return False

-    async def load_data(self, args) -> list[str]:
+    async def load_data(self, args) -> list[dict[str, Any]]:
         """Load Twitter bookmarks via MCP server."""
         print(f"Connecting to Twitter MCP server: {args.mcp_server}")

@@ -156,7 +157,8 @@ class TwitterMCPRAG(BaseRAGExample):
                 print(sample_text)
                 print("-" * 50)

-            return texts
+            # Convert strings to dict format expected by base class
+            return [{"text": text, "metadata": {"source": "twitter"}} for text in texts]

         except Exception as e:
             print(f"❌ Error loading Twitter bookmarks: {e}")
@@ -6,6 +6,7 @@ Supports WeChat chat history export and search.
 import subprocess
 import sys
 from pathlib import Path
+from typing import Any

 # Add parent directory to path for imports
 sys.path.insert(0, str(Path(__file__).parent))
@@ -91,7 +92,7 @@ class WeChatRAG(BaseRAGExample):
             print(f"Export error: {e}")
             return False

-    async def load_data(self, args) -> list[str]:
+    async def load_data(self, args) -> list[dict[str, Any]]:
         """Load WeChat history and convert to text chunks."""
         # Initialize WeChat reader with export capabilities
         reader = WeChatHistoryReader()
@@ -7,7 +7,7 @@ name = "leann-core"
 version = "0.3.5"
 description = "Core API and plugin system for LEANN"
 readme = "README.md"
-requires-python = ">=3.9"
+requires-python = ">=3.10"
 license = { text = "MIT" }

 # All required dependencies included
@@ -239,11 +239,11 @@ def create_ast_chunks(

     chunks = chunk_builder.chunkify(code_content)
     for chunk in chunks:
-        chunk_text = None
-        astchunk_metadata = {}
+        chunk_text: str | None = None
+        astchunk_metadata: dict[str, Any] = {}

         if hasattr(chunk, "text"):
-            chunk_text = chunk.text
+            chunk_text = str(chunk.text) if chunk.text else None
         elif isinstance(chunk, str):
             chunk_text = chunk
         elif isinstance(chunk, dict):
@@ -19,7 +19,7 @@ from .settings import (
 )


-def extract_pdf_text_with_pymupdf(file_path: str) -> str:
+def extract_pdf_text_with_pymupdf(file_path: str) -> str | None:
     """Extract text from PDF using PyMuPDF for better quality."""
     try:
         import fitz  # PyMuPDF
@@ -35,7 +35,7 @@ def extract_pdf_text_with_pymupdf(file_path: str) -> str:
         return None


-def extract_pdf_text_with_pdfplumber(file_path: str) -> str:
+def extract_pdf_text_with_pdfplumber(file_path: str) -> str | None:
     """Extract text from PDF using pdfplumber for better quality."""
     try:
         import pdfplumber
@@ -451,7 +451,8 @@ def compute_embeddings_sentence_transformers(
         # TODO: Haven't tested this yet
         torch.set_num_threads(min(8, os.cpu_count() or 4))
         try:
-            torch.backends.mkldnn.enabled = True
+            # PyTorch's ContextProp type is complex; cast for type checker
+            torch.backends.mkldnn.enabled = True  # type: ignore[assignment]
         except AttributeError:
             pass

@@ -11,14 +11,15 @@ from pathlib import Path
 from typing import Callable, Optional

 # Try to import readline with fallback for Windows
+HAS_READLINE = False
+readline = None  # type: ignore[assignment]
 try:
-    import readline
+    import readline  # type: ignore[no-redef]

     HAS_READLINE = True
 except ImportError:
     # Windows doesn't have readline by default
-    HAS_READLINE = False
-    readline = None
+    pass


 class InteractiveSession:
@@ -7,7 +7,7 @@ operators for different data types including numbers, strings, booleans, and lis
 """

 import logging
-from typing import Any, Union
+from typing import Any, Optional, Union

 logger = logging.getLogger(__name__)

@@ -47,7 +47,7 @@ class MetadataFilterEngine:
     }

     def apply_filters(
-        self, search_results: list[dict[str, Any]], metadata_filters: MetadataFilters
+        self, search_results: list[dict[str, Any]], metadata_filters: Optional[MetadataFilters]
     ) -> list[dict[str, Any]]:
         """
         Apply metadata filters to a list of search results.
@@ -56,7 +56,9 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
         with open(meta_path, encoding="utf-8") as f:
             return json.load(f)

-    def _ensure_server_running(self, passages_source_file: str, port: int, **kwargs) -> int:
+    def _ensure_server_running(
+        self, passages_source_file: str, port: Optional[int], **kwargs
+    ) -> int:
         """
         Ensures the embedding server is running if recompute is needed.
         This is a helper for subclasses.
@@ -81,7 +83,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
         }

         server_started, actual_port = self.embedding_server_manager.start_server(
-            port=port,
+            port=port if port is not None else 5557,
             model_name=self.embedding_model,
             embedding_mode=self.embedding_mode,
             passages_file=passages_source_file,
@@ -98,7 +100,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
         self,
         query: str,
         use_server_if_available: bool = True,
-        zmq_port: int = 5557,
+        zmq_port: Optional[int] = None,
         query_template: Optional[str] = None,
     ) -> np.ndarray:
         """
@@ -7,7 +7,7 @@ name = "leann"
 version = "0.3.5"
 description = "LEANN - The smallest vector index in the world. RAG Everything with LEANN!"
 readme = "README.md"
-requires-python = ">=3.9"
+requires-python = ">=3.10"
 license = { text = "MIT" }
 authors = [
     { name = "LEANN Team" }
@@ -18,10 +18,10 @@ classifiers = [
     "Intended Audience :: Developers",
     "License :: OSI Approved :: MIT License",
     "Programming Language :: Python :: 3",
-    "Programming Language :: Python :: 3.9",
     "Programming Language :: Python :: 3.10",
     "Programming Language :: Python :: 3.11",
     "Programming Language :: Python :: 3.12",
+    "Programming Language :: Python :: 3.13",
 ]

 # Default installation: core + hnsw + diskann
@@ -5,7 +5,7 @@ build-backend = "setuptools.build_meta"
 [project]
 name = "leann-workspace"
 version = "0.1.0"
-requires-python = ">=3.9"
+requires-python = ">=3.10"

 dependencies = [
     "leann-core",
@@ -157,6 +157,19 @@ exclude = ["localhost", "127.0.0.1", "example.com"]
 exclude_path = [".git/", ".venv/", "__pycache__/", "third_party/"]
 scheme = ["https", "http"]

+[tool.ty]
+# Type checking with ty (Astral's fast Python type checker)
+# ty is 10-100x faster than mypy. See: https://docs.astral.sh/ty/
+
+[tool.ty.environment]
+python-version = "3.11"
+extra-paths = ["apps", "packages/leann-core/src"]
+
+[tool.ty.rules]
+# Disable some noisy rules that have many false positives
+possibly-missing-attribute = "ignore"
+unresolved-import = "ignore"  # Many optional dependencies
+
 [tool.pytest.ini_options]
 testpaths = ["tests"]
 python_files = ["test_*.py"]
@@ -91,7 +91,7 @@ def test_large_index():
     builder.build_index(index_path)

     searcher = LeannSearcher(index_path)
-    results = searcher.search(["word10 word20"], top_k=10)
-    assert len(results[0]) == 10
+    results = searcher.search("word10 word20", top_k=10)
+    assert len(results) == 10
     # Cleanup
     searcher.cleanup()
@@ -123,7 +123,7 @@ class TestPromptTemplateStoredInEmbeddingOptions:
|
|||||||
cli = LeannCLI()
|
cli = LeannCLI()
|
||||||
|
|
||||||
# Mock load_documents to return a document so builder is created
|
# Mock load_documents to return a document so builder is created
|
||||||
cli.load_documents = Mock(return_value=[{"text": "test content", "metadata": {}}])
|
cli.load_documents = Mock(return_value=[{"text": "test content", "metadata": {}}]) # type: ignore[assignment]
|
||||||
|
|
||||||
parser = cli.create_parser()
|
parser = cli.create_parser()
|
||||||
|
|
||||||
@@ -175,7 +175,7 @@ class TestPromptTemplateStoredInEmbeddingOptions:
|
|||||||
cli = LeannCLI()
|
cli = LeannCLI()
|
||||||
|
|
||||||
# Mock load_documents to return a document so builder is created
|
# Mock load_documents to return a document so builder is created
|
||||||
cli.load_documents = Mock(return_value=[{"text": "test content", "metadata": {}}])
|
cli.load_documents = Mock(return_value=[{"text": "test content", "metadata": {}}]) # type: ignore[assignment]
|
||||||
|
|
||||||
parser = cli.create_parser()
|
parser = cli.create_parser()
|
||||||
|
|
||||||
@@ -230,7 +230,7 @@ class TestPromptTemplateStoredInEmbeddingOptions:
|
|||||||
cli = LeannCLI()
|
cli = LeannCLI()
|
||||||
|
|
||||||
# Mock load_documents to return a document so builder is created
|
# Mock load_documents to return a document so builder is created
|
||||||
cli.load_documents = Mock(return_value=[{"text": "test content", "metadata": {}}])
|
cli.load_documents = Mock(return_value=[{"text": "test content", "metadata": {}}]) # type: ignore[assignment]
|
||||||
|
|
||||||
parser = cli.create_parser()
|
parser = cli.create_parser()
|
||||||
|
|
||||||
@@ -307,7 +307,7 @@ class TestPromptTemplateStoredInEmbeddingOptions:
|
|||||||
cli = LeannCLI()
|
cli = LeannCLI()
|
||||||
|
|
||||||
# Mock load_documents to return a document so builder is created
|
# Mock load_documents to return a document so builder is created
|
||||||
cli.load_documents = Mock(return_value=[{"text": "test content", "metadata": {}}])
|
cli.load_documents = Mock(return_value=[{"text": "test content", "metadata": {}}]) # type: ignore[assignment]
|
||||||
|
|
||||||
parser = cli.create_parser()
|
parser = cli.create_parser()
|
||||||
|
|
||||||
@@ -376,7 +376,7 @@ class TestPromptTemplateStoredInEmbeddingOptions:
         cli = LeannCLI()
 
         # Mock load_documents to return a document so builder is created
-        cli.load_documents = Mock(return_value=[{"text": "test content", "metadata": {}}])
+        cli.load_documents = Mock(return_value=[{"text": "test content", "metadata": {}}])  # type: ignore[assignment]
 
         parser = cli.create_parser()
 
@@ -432,7 +432,7 @@ class TestPromptTemplateFlowsToComputeEmbeddings:
         cli = LeannCLI()
 
         # Mock load_documents to return a simple document
-        cli.load_documents = Mock(return_value=[{"text": "test content", "metadata": {}}])
+        cli.load_documents = Mock(return_value=[{"text": "test content", "metadata": {}}])  # type: ignore[assignment]
 
         parser = cli.create_parser()
 
@@ -67,7 +67,7 @@ def check_lmstudio_available() -> bool:
         return False
 
 
-def get_lmstudio_first_model() -> str:
+def get_lmstudio_first_model() -> str | None:
     """Get the first available model from LM Studio."""
     try:
         response = requests.get("http://localhost:1234/v1/models", timeout=5.0)
@@ -91,6 +91,7 @@ class TestPromptTemplateOpenAI:
         model_name = get_lmstudio_first_model()
         if not model_name:
             pytest.skip("No models loaded in LM Studio")
+        assert model_name is not None  # Type narrowing for type checker
 
         texts = ["artificial intelligence", "machine learning"]
         prompt_template = "search_query: "
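With the return type widened to `str | None`, the added `assert model_name is not None` makes the narrowing explicit for the checker even though `pytest.skip` already aborts the test. A small sketch of the pattern, with `run_embedding` as a hypothetical stand-in for the real call:

import pytest


def get_lmstudio_first_model() -> str | None:
    """Return the first loaded model id, or None if the server reports none."""
    return None  # stub body for this sketch


def run_embedding(model: str) -> None:  # hypothetical callee expecting a plain str
    ...


def test_embedding_roundtrip() -> None:
    model_name = get_lmstudio_first_model()
    if not model_name:
        pytest.skip("No models loaded in LM Studio")
    assert model_name is not None  # narrows str | None to str for the type checker
    run_embedding(model_name)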
@@ -120,6 +121,7 @@ class TestPromptTemplateOpenAI:
         model_name = get_lmstudio_first_model()
         if not model_name:
             pytest.skip("No models loaded in LM Studio")
+        assert model_name is not None  # Type narrowing for type checker
 
         text = "machine learning"
         base_url = "http://localhost:1234/v1"
@@ -271,6 +273,7 @@ class TestLMStudioSDK:
         model_name = get_lmstudio_first_model()
         if not model_name:
             pytest.skip("No models loaded in LM Studio")
+        assert model_name is not None  # Type narrowing for type checker
 
         try:
             from leann.embedding_compute import _query_lmstudio_context_limit
@@ -581,7 +581,18 @@ class TestQueryTemplateApplicationInComputeEmbedding:
 
         # Create a concrete implementation for testing
         class TestSearcher(BaseSearcher):
-            def search(self, query_vectors, top_k, complexity, beam_width=1, **kwargs):
+            def search(
+                self,
+                query,
+                top_k,
+                complexity=64,
+                beam_width=1,
+                prune_ratio=0.0,
+                recompute_embeddings=False,
+                pruning_strategy="global",
+                zmq_port=None,
+                **kwargs,
+            ):
                 return {"labels": [], "distances": []}
 
         searcher = object.__new__(TestSearcher)
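A plausible reading of this and the following four hunks (the diff itself does not say): the test stub previously declared a narrower `search` than `BaseSearcher`, which type checkers report as an incompatible override, so the override now spells out the full parameter list. A compact sketch of the idea, with an assumed abstract signature:

from abc import ABC, abstractmethod
from typing import Any


class BaseSearcher(ABC):  # assumed shape, abridged from the real class
    @abstractmethod
    def search(self, query, top_k, complexity=64, beam_width=1, prune_ratio=0.0,
               recompute_embeddings=False, pruning_strategy="global",
               zmq_port=None, **kwargs) -> dict[str, Any]: ...


class NarrowSearcher(BaseSearcher):
    # Dropping or renaming base parameters makes this an incompatible override
    # from the checker's point of view, even though the tests never pass them.
    def search(self, query_vectors, top_k, complexity, beam_width=1, **kwargs):
        return {"labels": [], "distances": []}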
@@ -625,7 +636,18 @@ class TestQueryTemplateApplicationInComputeEmbedding:
 
         # Create a concrete implementation for testing
         class TestSearcher(BaseSearcher):
-            def search(self, query_vectors, top_k, complexity, beam_width=1, **kwargs):
+            def search(
+                self,
+                query,
+                top_k,
+                complexity=64,
+                beam_width=1,
+                prune_ratio=0.0,
+                recompute_embeddings=False,
+                pruning_strategy="global",
+                zmq_port=None,
+                **kwargs,
+            ):
                 return {"labels": [], "distances": []}
 
         searcher = object.__new__(TestSearcher)
@@ -671,7 +693,18 @@ class TestQueryTemplateApplicationInComputeEmbedding:
         from leann.searcher_base import BaseSearcher
 
         class TestSearcher(BaseSearcher):
-            def search(self, query_vectors, top_k, complexity, beam_width=1, **kwargs):
+            def search(
+                self,
+                query,
+                top_k,
+                complexity=64,
+                beam_width=1,
+                prune_ratio=0.0,
+                recompute_embeddings=False,
+                pruning_strategy="global",
+                zmq_port=None,
+                **kwargs,
+            ):
                 return {"labels": [], "distances": []}
 
         searcher = object.__new__(TestSearcher)
@@ -710,7 +743,18 @@ class TestQueryTemplateApplicationInComputeEmbedding:
         from leann.searcher_base import BaseSearcher
 
         class TestSearcher(BaseSearcher):
-            def search(self, query_vectors, top_k, complexity, beam_width=1, **kwargs):
+            def search(
+                self,
+                query,
+                top_k,
+                complexity=64,
+                beam_width=1,
+                prune_ratio=0.0,
+                recompute_embeddings=False,
+                pruning_strategy="global",
+                zmq_port=None,
+                **kwargs,
+            ):
                 return {"labels": [], "distances": []}
 
         searcher = object.__new__(TestSearcher)
@@ -774,7 +818,18 @@ class TestQueryTemplateApplicationInComputeEmbedding:
         from leann.searcher_base import BaseSearcher
 
         class TestSearcher(BaseSearcher):
-            def search(self, query_vectors, top_k, complexity, beam_width=1, **kwargs):
+            def search(
+                self,
+                query,
+                top_k,
+                complexity=64,
+                beam_width=1,
+                prune_ratio=0.0,
+                recompute_embeddings=False,
+                pruning_strategy="global",
+                zmq_port=None,
+                **kwargs,
+            ):
                 return {"labels": [], "distances": []}
 
         searcher = object.__new__(TestSearcher)
@@ -97,17 +97,17 @@ def test_backend_options():
 
     with tempfile.TemporaryDirectory() as temp_dir:
         # Use smaller model in CI to avoid memory issues
-        if os.environ.get("CI") == "true":
-            model_args = {
-                "embedding_model": "sentence-transformers/all-MiniLM-L6-v2",
-                "dimensions": 384,
-            }
-        else:
-            model_args = {}
+        is_ci = os.environ.get("CI") == "true"
+        embedding_model = (
+            "sentence-transformers/all-MiniLM-L6-v2" if is_ci else "facebook/contriever"
+        )
+        dimensions = 384 if is_ci else None
 
         # Test HNSW backend (as shown in README)
         hnsw_path = str(Path(temp_dir) / "test_hnsw.leann")
-        builder_hnsw = LeannBuilder(backend_name="hnsw", **model_args)
+        builder_hnsw = LeannBuilder(
+            backend_name="hnsw", embedding_model=embedding_model, dimensions=dimensions
+        )
         builder_hnsw.add_text("Test document for HNSW backend")
         builder_hnsw.build_index(hnsw_path)
         assert Path(hnsw_path).parent.exists()
@@ -115,7 +115,9 @@ def test_backend_options():
 
         # Test DiskANN backend (mentioned as available option)
         diskann_path = str(Path(temp_dir) / "test_diskann.leann")
-        builder_diskann = LeannBuilder(backend_name="diskann", **model_args)
+        builder_diskann = LeannBuilder(
+            backend_name="diskann", embedding_model=embedding_model, dimensions=dimensions
+        )
         builder_diskann.add_text("Test document for DiskANN backend")
         builder_diskann.build_index(diskann_path)
         assert Path(diskann_path).parent.exists()
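A likely motivation for replacing `**model_args` with explicit keywords (assumed, not stated in the diff): unpacking a heterogeneous dict erases the per-keyword types, so the checker cannot match them against `LeannBuilder`'s parameters. A small sketch of the contrast (import path assumed):

from leann import LeannBuilder

# Before: the dict values are inferred as str | int, so every unpacked keyword
# looks potentially wrong to the type checker.
model_args = {
    "embedding_model": "sentence-transformers/all-MiniLM-L6-v2",
    "dimensions": 384,
}
builder = LeannBuilder(backend_name="hnsw", **model_args)

# After: each argument keeps its own type and can be checked individually.
embedding_model = "sentence-transformers/all-MiniLM-L6-v2"
dimensions = 384
builder = LeannBuilder(
    backend_name="hnsw", embedding_model=embedding_model, dimensions=dimensions
)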