Compare commits
17 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
df3350be43 | ||
|
|
94d9a203a2 | ||
|
|
72455bb269 | ||
|
|
d034e2195b | ||
|
|
43894ff605 | ||
|
|
10311cc611 | ||
|
|
ad0d2faabc | ||
|
|
e93c0dec6f | ||
|
|
c5a29f849a | ||
|
|
3b8dc6368e | ||
|
|
e309f292de | ||
|
|
0d9f92ea0f | ||
|
|
b0b353d279 | ||
|
|
4dffdfedbe | ||
|
|
d41e467df9 | ||
|
|
4ca0489cb1 | ||
|
|
e83a671918 |
50
.github/ISSUE_TEMPLATE/bug_report.yml
vendored
Normal file
50
.github/ISSUE_TEMPLATE/bug_report.yml
vendored
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
name: Bug Report
|
||||||
|
description: Report a bug in LEANN
|
||||||
|
labels: ["bug"]
|
||||||
|
|
||||||
|
body:
|
||||||
|
- type: textarea
|
||||||
|
id: description
|
||||||
|
attributes:
|
||||||
|
label: What happened?
|
||||||
|
description: A clear description of the bug
|
||||||
|
validations:
|
||||||
|
required: true
|
||||||
|
|
||||||
|
- type: textarea
|
||||||
|
id: reproduce
|
||||||
|
attributes:
|
||||||
|
label: How to reproduce
|
||||||
|
placeholder: |
|
||||||
|
1. Install with...
|
||||||
|
2. Run command...
|
||||||
|
3. See error
|
||||||
|
validations:
|
||||||
|
required: true
|
||||||
|
|
||||||
|
- type: textarea
|
||||||
|
id: error
|
||||||
|
attributes:
|
||||||
|
label: Error message
|
||||||
|
description: Paste any error messages
|
||||||
|
render: shell
|
||||||
|
|
||||||
|
- type: input
|
||||||
|
id: version
|
||||||
|
attributes:
|
||||||
|
label: LEANN Version
|
||||||
|
placeholder: "0.1.0"
|
||||||
|
validations:
|
||||||
|
required: true
|
||||||
|
|
||||||
|
- type: dropdown
|
||||||
|
id: os
|
||||||
|
attributes:
|
||||||
|
label: Operating System
|
||||||
|
options:
|
||||||
|
- macOS
|
||||||
|
- Linux
|
||||||
|
- Windows
|
||||||
|
- Docker
|
||||||
|
validations:
|
||||||
|
required: true
|
||||||
8
.github/ISSUE_TEMPLATE/config.yml
vendored
Normal file
8
.github/ISSUE_TEMPLATE/config.yml
vendored
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
blank_issues_enabled: true
|
||||||
|
contact_links:
|
||||||
|
- name: Documentation
|
||||||
|
url: https://github.com/LEANN-RAG/LEANN-RAG/tree/main/docs
|
||||||
|
about: Read the docs first
|
||||||
|
- name: Discussions
|
||||||
|
url: https://github.com/LEANN-RAG/LEANN-RAG/discussions
|
||||||
|
about: Ask questions and share ideas
|
||||||
27
.github/ISSUE_TEMPLATE/feature_request.yml
vendored
Normal file
27
.github/ISSUE_TEMPLATE/feature_request.yml
vendored
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
name: Feature Request
|
||||||
|
description: Suggest a new feature for LEANN
|
||||||
|
labels: ["enhancement"]
|
||||||
|
|
||||||
|
body:
|
||||||
|
- type: textarea
|
||||||
|
id: problem
|
||||||
|
attributes:
|
||||||
|
label: What problem does this solve?
|
||||||
|
description: Describe the problem or need
|
||||||
|
validations:
|
||||||
|
required: true
|
||||||
|
|
||||||
|
- type: textarea
|
||||||
|
id: solution
|
||||||
|
attributes:
|
||||||
|
label: Proposed solution
|
||||||
|
description: How would you like this to work?
|
||||||
|
validations:
|
||||||
|
required: true
|
||||||
|
|
||||||
|
- type: textarea
|
||||||
|
id: example
|
||||||
|
attributes:
|
||||||
|
label: Example usage
|
||||||
|
description: Show how the API might look
|
||||||
|
render: python
|
||||||
13
.github/pull_request_template.md
vendored
Normal file
13
.github/pull_request_template.md
vendored
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
## What does this PR do?
|
||||||
|
|
||||||
|
<!-- Brief description of your changes -->
|
||||||
|
|
||||||
|
## Related Issues
|
||||||
|
|
||||||
|
Fixes #
|
||||||
|
|
||||||
|
## Checklist
|
||||||
|
|
||||||
|
- [ ] Tests pass (`uv run pytest`)
|
||||||
|
- [ ] Code formatted (`ruff format` and `ruff check`)
|
||||||
|
- [ ] Pre-commit hooks pass (`pre-commit run --all-files`)
|
||||||
58
.github/workflows/build-reusable.yml
vendored
58
.github/workflows/build-reusable.yml
vendored
@@ -54,6 +54,17 @@ jobs:
|
|||||||
python: '3.12'
|
python: '3.12'
|
||||||
- os: ubuntu-22.04
|
- os: ubuntu-22.04
|
||||||
python: '3.13'
|
python: '3.13'
|
||||||
|
# ARM64 Linux builds
|
||||||
|
- os: ubuntu-24.04-arm
|
||||||
|
python: '3.9'
|
||||||
|
- os: ubuntu-24.04-arm
|
||||||
|
python: '3.10'
|
||||||
|
- os: ubuntu-24.04-arm
|
||||||
|
python: '3.11'
|
||||||
|
- os: ubuntu-24.04-arm
|
||||||
|
python: '3.12'
|
||||||
|
- os: ubuntu-24.04-arm
|
||||||
|
python: '3.13'
|
||||||
- os: macos-14
|
- os: macos-14
|
||||||
python: '3.9'
|
python: '3.9'
|
||||||
- os: macos-14
|
- os: macos-14
|
||||||
@@ -108,13 +119,46 @@ jobs:
|
|||||||
pkg-config libabsl-dev libaio-dev libprotobuf-dev \
|
pkg-config libabsl-dev libaio-dev libprotobuf-dev \
|
||||||
patchelf
|
patchelf
|
||||||
|
|
||||||
# Install Intel MKL for DiskANN
|
# Debug: Show system information
|
||||||
wget -q https://registrationcenter-download.intel.com/akdlm/IRC_NAS/79153e0f-74d7-45af-b8c2-258941adf58a/intel-onemkl-2025.0.0.940.sh
|
echo "🔍 System Information:"
|
||||||
sudo sh intel-onemkl-2025.0.0.940.sh -a --components intel.oneapi.lin.mkl.devel --action install --eula accept -s
|
echo "Architecture: $(uname -m)"
|
||||||
source /opt/intel/oneapi/setvars.sh
|
echo "OS: $(uname -a)"
|
||||||
echo "MKLROOT=/opt/intel/oneapi/mkl/latest" >> $GITHUB_ENV
|
echo "CPU info: $(lscpu | head -5)"
|
||||||
echo "LD_LIBRARY_PATH=/opt/intel/oneapi/compiler/latest/linux/compiler/lib/intel64_lin" >> $GITHUB_ENV
|
|
||||||
echo "LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/intel/oneapi/mkl/latest/lib/intel64" >> $GITHUB_ENV
|
# Install math library based on architecture
|
||||||
|
ARCH=$(uname -m)
|
||||||
|
echo "🔍 Setting up math library for architecture: $ARCH"
|
||||||
|
|
||||||
|
if [[ "$ARCH" == "x86_64" ]]; then
|
||||||
|
# Install Intel MKL for DiskANN on x86_64
|
||||||
|
echo "📦 Installing Intel MKL for x86_64..."
|
||||||
|
wget -q https://registrationcenter-download.intel.com/akdlm/IRC_NAS/79153e0f-74d7-45af-b8c2-258941adf58a/intel-onemkl-2025.0.0.940.sh
|
||||||
|
sudo sh intel-onemkl-2025.0.0.940.sh -a --components intel.oneapi.lin.mkl.devel --action install --eula accept -s
|
||||||
|
source /opt/intel/oneapi/setvars.sh
|
||||||
|
echo "MKLROOT=/opt/intel/oneapi/mkl/latest" >> $GITHUB_ENV
|
||||||
|
echo "LD_LIBRARY_PATH=/opt/intel/oneapi/compiler/latest/linux/compiler/lib/intel64_lin" >> $GITHUB_ENV
|
||||||
|
echo "LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/intel/oneapi/mkl/latest/lib/intel64" >> $GITHUB_ENV
|
||||||
|
echo "✅ Intel MKL installed for x86_64"
|
||||||
|
|
||||||
|
# Debug: Check MKL installation
|
||||||
|
echo "🔍 MKL Installation Check:"
|
||||||
|
ls -la /opt/intel/oneapi/mkl/latest/ || echo "MKL directory not found"
|
||||||
|
ls -la /opt/intel/oneapi/mkl/latest/lib/ || echo "MKL lib directory not found"
|
||||||
|
|
||||||
|
elif [[ "$ARCH" == "aarch64" ]]; then
|
||||||
|
# Use OpenBLAS for ARM64 (MKL installer not compatible with ARM64)
|
||||||
|
echo "📦 Installing OpenBLAS for ARM64..."
|
||||||
|
sudo apt-get install -y libopenblas-dev liblapack-dev liblapacke-dev
|
||||||
|
echo "✅ OpenBLAS installed for ARM64"
|
||||||
|
|
||||||
|
# Debug: Check OpenBLAS installation
|
||||||
|
echo "🔍 OpenBLAS Installation Check:"
|
||||||
|
dpkg -l | grep openblas || echo "OpenBLAS package not found"
|
||||||
|
ls -la /usr/lib/aarch64-linux-gnu/openblas/ || echo "OpenBLAS directory not found"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Debug: Show final library paths
|
||||||
|
echo "🔍 Final LD_LIBRARY_PATH: $LD_LIBRARY_PATH"
|
||||||
|
|
||||||
- name: Install system dependencies (macOS)
|
- name: Install system dependencies (macOS)
|
||||||
if: runner.os == 'macOS'
|
if: runner.os == 'macOS'
|
||||||
|
|||||||
5
.gitignore
vendored
5
.gitignore
vendored
@@ -18,10 +18,12 @@ demo/experiment_results/**/*.json
|
|||||||
*.eml
|
*.eml
|
||||||
*.emlx
|
*.emlx
|
||||||
*.json
|
*.json
|
||||||
|
*.png
|
||||||
!.vscode/*.json
|
!.vscode/*.json
|
||||||
*.sh
|
*.sh
|
||||||
*.txt
|
*.txt
|
||||||
!CMakeLists.txt
|
!CMakeLists.txt
|
||||||
|
!llms.txt
|
||||||
latency_breakdown*.json
|
latency_breakdown*.json
|
||||||
experiment_results/eval_results/diskann/*.json
|
experiment_results/eval_results/diskann/*.json
|
||||||
aws/
|
aws/
|
||||||
@@ -100,3 +102,6 @@ CLAUDE.local.md
|
|||||||
.claude/*.local.*
|
.claude/*.local.*
|
||||||
.claude/local/*
|
.claude/local/*
|
||||||
benchmarks/data/
|
benchmarks/data/
|
||||||
|
|
||||||
|
## multi vector
|
||||||
|
apps/multimodal/vision-based-pdf-multi-vector/multi-vector-colpali-native-weaviate.py
|
||||||
|
|||||||
3
.gitmodules
vendored
3
.gitmodules
vendored
@@ -14,3 +14,6 @@
|
|||||||
[submodule "packages/leann-backend-hnsw/third_party/libzmq"]
|
[submodule "packages/leann-backend-hnsw/third_party/libzmq"]
|
||||||
path = packages/leann-backend-hnsw/third_party/libzmq
|
path = packages/leann-backend-hnsw/third_party/libzmq
|
||||||
url = https://github.com/zeromq/libzmq.git
|
url = https://github.com/zeromq/libzmq.git
|
||||||
|
[submodule "packages/astchunk-leann"]
|
||||||
|
path = packages/astchunk-leann
|
||||||
|
url = https://github.com/yichuan-w/astchunk-leann.git
|
||||||
|
|||||||
13
README.md
13
README.md
@@ -656,6 +656,19 @@ results = searcher.search(
|
|||||||
|
|
||||||
📖 **[Complete Metadata filtering guide →](docs/metadata_filtering.md)**
|
📖 **[Complete Metadata filtering guide →](docs/metadata_filtering.md)**
|
||||||
|
|
||||||
|
### 🔍 Grep Search
|
||||||
|
|
||||||
|
For exact text matching instead of semantic search, use the `use_grep` parameter:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Exact text search
|
||||||
|
results = searcher.search("banana‑crocodile", use_grep=True, top_k=1)
|
||||||
|
```
|
||||||
|
|
||||||
|
**Use cases**: Finding specific code patterns, error messages, function names, or exact phrases where semantic similarity isn't needed.
|
||||||
|
|
||||||
|
📖 **[Complete grep search guide →](docs/grep_search.md)**
|
||||||
|
|
||||||
## 🏗️ Architecture & How It Works
|
## 🏗️ Architecture & How It Works
|
||||||
|
|
||||||
<p align="center">
|
<p align="center">
|
||||||
|
|||||||
@@ -1,16 +1,38 @@
|
|||||||
"""
|
"""Unified chunking utilities facade.
|
||||||
Chunking utilities for LEANN RAG applications.
|
|
||||||
Provides AST-aware and traditional text chunking functionality.
|
This module re-exports the packaged utilities from `leann.chunking_utils` so
|
||||||
|
that both repo apps (importing `chunking`) and installed wheels share one
|
||||||
|
single implementation. When running from the repo without installation, it
|
||||||
|
adds the `packages/leann-core/src` directory to `sys.path` as a fallback.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from .utils import (
|
import sys
|
||||||
CODE_EXTENSIONS,
|
from pathlib import Path
|
||||||
create_ast_chunks,
|
|
||||||
create_text_chunks,
|
try:
|
||||||
create_traditional_chunks,
|
from leann.chunking_utils import (
|
||||||
detect_code_files,
|
CODE_EXTENSIONS,
|
||||||
get_language_from_extension,
|
create_ast_chunks,
|
||||||
)
|
create_text_chunks,
|
||||||
|
create_traditional_chunks,
|
||||||
|
detect_code_files,
|
||||||
|
get_language_from_extension,
|
||||||
|
)
|
||||||
|
except Exception: # pragma: no cover - best-effort fallback for dev environment
|
||||||
|
repo_root = Path(__file__).resolve().parents[2]
|
||||||
|
leann_src = repo_root / "packages" / "leann-core" / "src"
|
||||||
|
if leann_src.exists():
|
||||||
|
sys.path.insert(0, str(leann_src))
|
||||||
|
from leann.chunking_utils import (
|
||||||
|
CODE_EXTENSIONS,
|
||||||
|
create_ast_chunks,
|
||||||
|
create_text_chunks,
|
||||||
|
create_traditional_chunks,
|
||||||
|
detect_code_files,
|
||||||
|
get_language_from_extension,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"CODE_EXTENSIONS",
|
"CODE_EXTENSIONS",
|
||||||
|
|||||||
@@ -74,7 +74,7 @@ class ChromeHistoryReader(BaseReader):
|
|||||||
if count >= max_count and max_count > 0:
|
if count >= max_count and max_count > 0:
|
||||||
break
|
break
|
||||||
|
|
||||||
last_visit, url, title, visit_count, typed_count, hidden = row
|
last_visit, url, title, visit_count, typed_count, _hidden = row
|
||||||
|
|
||||||
# Create document content with metadata embedded in text
|
# Create document content with metadata embedded in text
|
||||||
doc_content = f"""
|
doc_content = f"""
|
||||||
|
|||||||
@@ -0,0 +1,182 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
def _ensure_repo_paths_importable(current_file: str) -> None:
|
||||||
|
_repo_root = Path(current_file).resolve().parents[3]
|
||||||
|
_leann_core_src = _repo_root / "packages" / "leann-core" / "src"
|
||||||
|
_leann_hnsw_pkg = _repo_root / "packages" / "leann-backend-hnsw"
|
||||||
|
if str(_leann_core_src) not in sys.path:
|
||||||
|
sys.path.append(str(_leann_core_src))
|
||||||
|
if str(_leann_hnsw_pkg) not in sys.path:
|
||||||
|
sys.path.append(str(_leann_hnsw_pkg))
|
||||||
|
|
||||||
|
|
||||||
|
_ensure_repo_paths_importable(__file__)
|
||||||
|
|
||||||
|
from leann_backend_hnsw.hnsw_backend import HNSWBuilder, HNSWSearcher # noqa: E402
|
||||||
|
|
||||||
|
|
||||||
|
class LeannMultiVector:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
index_path: str,
|
||||||
|
dim: int = 128,
|
||||||
|
distance_metric: str = "mips",
|
||||||
|
m: int = 16,
|
||||||
|
ef_construction: int = 500,
|
||||||
|
is_compact: bool = False,
|
||||||
|
is_recompute: bool = False,
|
||||||
|
embedding_model_name: str = "colvision",
|
||||||
|
) -> None:
|
||||||
|
self.index_path = index_path
|
||||||
|
self.dim = dim
|
||||||
|
self.embedding_model_name = embedding_model_name
|
||||||
|
self._pending_items: list[dict] = []
|
||||||
|
self._backend_kwargs = {
|
||||||
|
"distance_metric": distance_metric,
|
||||||
|
"M": m,
|
||||||
|
"efConstruction": ef_construction,
|
||||||
|
"is_compact": is_compact,
|
||||||
|
"is_recompute": is_recompute,
|
||||||
|
}
|
||||||
|
self._labels_meta: list[dict] = []
|
||||||
|
|
||||||
|
def _meta_dict(self) -> dict:
|
||||||
|
return {
|
||||||
|
"version": "1.0",
|
||||||
|
"backend_name": "hnsw",
|
||||||
|
"embedding_model": self.embedding_model_name,
|
||||||
|
"embedding_mode": "custom",
|
||||||
|
"dimensions": self.dim,
|
||||||
|
"backend_kwargs": self._backend_kwargs,
|
||||||
|
"is_compact": self._backend_kwargs.get("is_compact", True),
|
||||||
|
"is_pruned": self._backend_kwargs.get("is_compact", True)
|
||||||
|
and self._backend_kwargs.get("is_recompute", True),
|
||||||
|
}
|
||||||
|
|
||||||
|
def create_collection(self) -> None:
|
||||||
|
path = Path(self.index_path)
|
||||||
|
path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
def insert(self, data: dict) -> None:
|
||||||
|
self._pending_items.append(
|
||||||
|
{
|
||||||
|
"doc_id": int(data["doc_id"]),
|
||||||
|
"filepath": data.get("filepath", ""),
|
||||||
|
"colbert_vecs": [np.asarray(v, dtype=np.float32) for v in data["colbert_vecs"]],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
def _labels_path(self) -> Path:
|
||||||
|
index_path_obj = Path(self.index_path)
|
||||||
|
return index_path_obj.parent / f"{index_path_obj.name}.labels.json"
|
||||||
|
|
||||||
|
def _meta_path(self) -> Path:
|
||||||
|
index_path_obj = Path(self.index_path)
|
||||||
|
return index_path_obj.parent / f"{index_path_obj.name}.meta.json"
|
||||||
|
|
||||||
|
def create_index(self) -> None:
|
||||||
|
if not self._pending_items:
|
||||||
|
return
|
||||||
|
|
||||||
|
embeddings: list[np.ndarray] = []
|
||||||
|
labels_meta: list[dict] = []
|
||||||
|
|
||||||
|
for item in self._pending_items:
|
||||||
|
doc_id = int(item["doc_id"])
|
||||||
|
filepath = item.get("filepath", "")
|
||||||
|
colbert_vecs = item["colbert_vecs"]
|
||||||
|
for seq_id, vec in enumerate(colbert_vecs):
|
||||||
|
vec_np = np.asarray(vec, dtype=np.float32)
|
||||||
|
embeddings.append(vec_np)
|
||||||
|
labels_meta.append(
|
||||||
|
{
|
||||||
|
"id": f"{doc_id}:{seq_id}",
|
||||||
|
"doc_id": doc_id,
|
||||||
|
"seq_id": int(seq_id),
|
||||||
|
"filepath": filepath,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
if not embeddings:
|
||||||
|
return
|
||||||
|
|
||||||
|
embeddings_np = np.vstack(embeddings).astype(np.float32)
|
||||||
|
# print shape of embeddings_np
|
||||||
|
print(embeddings_np.shape)
|
||||||
|
|
||||||
|
builder = HNSWBuilder(**{**self._backend_kwargs, "dimensions": self.dim})
|
||||||
|
ids = [str(i) for i in range(embeddings_np.shape[0])]
|
||||||
|
builder.build(embeddings_np, ids, self.index_path)
|
||||||
|
|
||||||
|
import json as _json
|
||||||
|
|
||||||
|
with open(self._meta_path(), "w", encoding="utf-8") as f:
|
||||||
|
_json.dump(self._meta_dict(), f, indent=2)
|
||||||
|
with open(self._labels_path(), "w", encoding="utf-8") as f:
|
||||||
|
_json.dump(labels_meta, f)
|
||||||
|
|
||||||
|
self._labels_meta = labels_meta
|
||||||
|
|
||||||
|
def _load_labels_meta_if_needed(self) -> None:
|
||||||
|
if self._labels_meta:
|
||||||
|
return
|
||||||
|
labels_path = self._labels_path()
|
||||||
|
if labels_path.exists():
|
||||||
|
import json as _json
|
||||||
|
|
||||||
|
with open(labels_path, encoding="utf-8") as f:
|
||||||
|
self._labels_meta = _json.load(f)
|
||||||
|
|
||||||
|
def search(
|
||||||
|
self, data: np.ndarray, topk: int, first_stage_k: int = 50
|
||||||
|
) -> list[tuple[float, int]]:
|
||||||
|
if data.ndim == 1:
|
||||||
|
data = data.reshape(1, -1)
|
||||||
|
if data.dtype != np.float32:
|
||||||
|
data = data.astype(np.float32)
|
||||||
|
|
||||||
|
self._load_labels_meta_if_needed()
|
||||||
|
|
||||||
|
searcher = HNSWSearcher(self.index_path, meta=self._meta_dict())
|
||||||
|
raw = searcher.search(
|
||||||
|
data,
|
||||||
|
first_stage_k,
|
||||||
|
recompute_embeddings=False,
|
||||||
|
complexity=128,
|
||||||
|
beam_width=1,
|
||||||
|
prune_ratio=0.0,
|
||||||
|
batch_size=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
labels = raw.get("labels")
|
||||||
|
distances = raw.get("distances")
|
||||||
|
if labels is None or distances is None:
|
||||||
|
return []
|
||||||
|
|
||||||
|
doc_scores: dict[int, float] = {}
|
||||||
|
B = len(labels)
|
||||||
|
for b in range(B):
|
||||||
|
per_doc_best: dict[int, float] = {}
|
||||||
|
for k, sid in enumerate(labels[b]):
|
||||||
|
try:
|
||||||
|
idx = int(sid)
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
if 0 <= idx < len(self._labels_meta):
|
||||||
|
doc_id = int(self._labels_meta[idx]["doc_id"]) # type: ignore[index]
|
||||||
|
else:
|
||||||
|
continue
|
||||||
|
score = float(distances[b][k])
|
||||||
|
if (doc_id not in per_doc_best) or (score > per_doc_best[doc_id]):
|
||||||
|
per_doc_best[doc_id] = score
|
||||||
|
for doc_id, best_score in per_doc_best.items():
|
||||||
|
doc_scores[doc_id] = doc_scores.get(doc_id, 0.0) + best_score
|
||||||
|
|
||||||
|
scores = sorted(((v, k) for k, v in doc_scores.items()), key=lambda x: x[0], reverse=True)
|
||||||
|
return scores[:topk] if len(scores) >= topk else scores
|
||||||
@@ -0,0 +1,477 @@
|
|||||||
|
## Jupyter-style notebook script
|
||||||
|
# %%
|
||||||
|
# uv pip install matplotlib qwen_vl_utils
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Optional, cast
|
||||||
|
|
||||||
|
from PIL import Image
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
|
||||||
|
def _ensure_repo_paths_importable(current_file: str) -> None:
|
||||||
|
"""Make local leann packages importable without installing (mirrors multi-vector-leann.py)."""
|
||||||
|
_repo_root = Path(current_file).resolve().parents[3]
|
||||||
|
_leann_core_src = _repo_root / "packages" / "leann-core" / "src"
|
||||||
|
_leann_hnsw_pkg = _repo_root / "packages" / "leann-backend-hnsw"
|
||||||
|
if str(_leann_core_src) not in sys.path:
|
||||||
|
sys.path.append(str(_leann_core_src))
|
||||||
|
if str(_leann_hnsw_pkg) not in sys.path:
|
||||||
|
sys.path.append(str(_leann_hnsw_pkg))
|
||||||
|
|
||||||
|
|
||||||
|
_ensure_repo_paths_importable(__file__)
|
||||||
|
|
||||||
|
from leann_multi_vector import LeannMultiVector # noqa: E402
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# Config
|
||||||
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||||
|
QUERY = "How does DeepSeek-V2 compare against the LLaMA family of LLMs?"
|
||||||
|
MODEL: str = "colqwen2" # "colpali" or "colqwen2"
|
||||||
|
|
||||||
|
# Data source: set to True to use the Hugging Face dataset example (recommended)
|
||||||
|
USE_HF_DATASET: bool = True
|
||||||
|
DATASET_NAME: str = "weaviate/arXiv-AI-papers-multi-vector"
|
||||||
|
DATASET_SPLIT: str = "train"
|
||||||
|
MAX_DOCS: Optional[int] = None # limit number of pages to index; None = all
|
||||||
|
|
||||||
|
# Local pages (used when USE_HF_DATASET == False)
|
||||||
|
PDF: Optional[str] = None # e.g., "./pdfs/2004.12832v2.pdf"
|
||||||
|
PAGES_DIR: str = "./pages"
|
||||||
|
|
||||||
|
# Index + retrieval settings
|
||||||
|
INDEX_PATH: str = "./indexes/colvision.leann"
|
||||||
|
TOPK: int = 1
|
||||||
|
FIRST_STAGE_K: int = 500
|
||||||
|
REBUILD_INDEX: bool = False
|
||||||
|
|
||||||
|
# Artifacts
|
||||||
|
SAVE_TOP_IMAGE: Optional[str] = "./figures/retrieved_page.png"
|
||||||
|
SIMILARITY_MAP: bool = True
|
||||||
|
SIM_TOKEN_IDX: int = 13 # -1 means auto-select the most salient token
|
||||||
|
SIM_OUTPUT: str = "./figures/similarity_map.png"
|
||||||
|
ANSWER: bool = True
|
||||||
|
MAX_NEW_TOKENS: int = 128
|
||||||
|
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# Helpers
|
||||||
|
def _natural_sort_key(name: str) -> int:
|
||||||
|
m = re.search(r"\d+", name)
|
||||||
|
return int(m.group()) if m else 0
|
||||||
|
|
||||||
|
|
||||||
|
def _load_images_from_dir(pages_dir: str) -> tuple[list[str], list[Image.Image]]:
|
||||||
|
filenames = [n for n in os.listdir(pages_dir) if n.lower().endswith((".png", ".jpg", ".jpeg"))]
|
||||||
|
filenames = sorted(filenames, key=_natural_sort_key)
|
||||||
|
filepaths = [os.path.join(pages_dir, n) for n in filenames]
|
||||||
|
images = [Image.open(p) for p in filepaths]
|
||||||
|
return filepaths, images
|
||||||
|
|
||||||
|
|
||||||
|
def _maybe_convert_pdf_to_images(pdf_path: Optional[str], pages_dir: str, dpi: int = 200) -> None:
|
||||||
|
if not pdf_path:
|
||||||
|
return
|
||||||
|
os.makedirs(pages_dir, exist_ok=True)
|
||||||
|
try:
|
||||||
|
from pdf2image import convert_from_path
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError(
|
||||||
|
"pdf2image is required to convert PDF to images. Install via pip install pdf2image"
|
||||||
|
) from e
|
||||||
|
images = convert_from_path(pdf_path, dpi=dpi)
|
||||||
|
for i, image in enumerate(images):
|
||||||
|
image.save(os.path.join(pages_dir, f"page_{i + 1}.png"), "PNG")
|
||||||
|
|
||||||
|
|
||||||
|
def _select_device_and_dtype():
|
||||||
|
import torch
|
||||||
|
from colpali_engine.utils.torch_utils import get_torch_device
|
||||||
|
|
||||||
|
device_str = (
|
||||||
|
"cuda"
|
||||||
|
if torch.cuda.is_available()
|
||||||
|
else (
|
||||||
|
"mps"
|
||||||
|
if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available()
|
||||||
|
else "cpu"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
device = get_torch_device(device_str)
|
||||||
|
# Stable dtype selection to avoid NaNs:
|
||||||
|
# - CUDA: prefer bfloat16 if supported, else float16
|
||||||
|
# - MPS: use float32 (fp16 on MPS can produce NaNs in some ops)
|
||||||
|
# - CPU: float32
|
||||||
|
if device_str == "cuda":
|
||||||
|
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
|
||||||
|
try:
|
||||||
|
torch.backends.cuda.matmul.allow_tf32 = True # Better stability/perf on Ampere+
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
elif device_str == "mps":
|
||||||
|
dtype = torch.float32
|
||||||
|
else:
|
||||||
|
dtype = torch.float32
|
||||||
|
return device_str, device, dtype
|
||||||
|
|
||||||
|
|
||||||
|
def _load_colvision(model_choice: str):
|
||||||
|
import torch
|
||||||
|
from colpali_engine.models import ColPali, ColQwen2, ColQwen2Processor
|
||||||
|
from colpali_engine.models.paligemma.colpali.processing_colpali import ColPaliProcessor
|
||||||
|
from transformers.utils.import_utils import is_flash_attn_2_available
|
||||||
|
|
||||||
|
device_str, device, dtype = _select_device_and_dtype()
|
||||||
|
|
||||||
|
if model_choice == "colqwen2":
|
||||||
|
model_name = "vidore/colqwen2-v1.0"
|
||||||
|
# On CPU/MPS we must avoid flash-attn and stay eager; on CUDA prefer flash-attn if available
|
||||||
|
attn_implementation = (
|
||||||
|
"flash_attention_2"
|
||||||
|
if (device_str == "cuda" and is_flash_attn_2_available())
|
||||||
|
else "eager"
|
||||||
|
)
|
||||||
|
model = ColQwen2.from_pretrained(
|
||||||
|
model_name,
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
device_map=device,
|
||||||
|
attn_implementation=attn_implementation,
|
||||||
|
).eval()
|
||||||
|
processor = ColQwen2Processor.from_pretrained(model_name)
|
||||||
|
else:
|
||||||
|
model_name = "vidore/colpali-v1.2"
|
||||||
|
model = ColPali.from_pretrained(
|
||||||
|
model_name,
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
device_map=device,
|
||||||
|
).eval()
|
||||||
|
processor = cast(ColPaliProcessor, ColPaliProcessor.from_pretrained(model_name))
|
||||||
|
|
||||||
|
return model_name, model, processor, device_str, device, dtype
|
||||||
|
|
||||||
|
|
||||||
|
def _embed_images(model, processor, images: list[Image.Image]) -> list[Any]:
|
||||||
|
import torch
|
||||||
|
from colpali_engine.utils.torch_utils import ListDataset
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
|
# Ensure deterministic eval and autocast for stability
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
dataloader = DataLoader(
|
||||||
|
dataset=ListDataset[Image.Image](images),
|
||||||
|
batch_size=1,
|
||||||
|
shuffle=False,
|
||||||
|
collate_fn=lambda x: processor.process_images(x),
|
||||||
|
)
|
||||||
|
|
||||||
|
doc_vecs: list[Any] = []
|
||||||
|
for batch_doc in dataloader:
|
||||||
|
with torch.no_grad():
|
||||||
|
batch_doc = {k: v.to(model.device) for k, v in batch_doc.items()}
|
||||||
|
# autocast on CUDA for bf16/fp16; on CPU/MPS stay in fp32
|
||||||
|
if model.device.type == "cuda":
|
||||||
|
with torch.autocast(
|
||||||
|
device_type="cuda",
|
||||||
|
dtype=model.dtype if model.dtype.is_floating_point else torch.bfloat16,
|
||||||
|
):
|
||||||
|
embeddings_doc = model(**batch_doc)
|
||||||
|
else:
|
||||||
|
embeddings_doc = model(**batch_doc)
|
||||||
|
doc_vecs.extend(list(torch.unbind(embeddings_doc.to("cpu"))))
|
||||||
|
return doc_vecs
|
||||||
|
|
||||||
|
|
||||||
|
def _embed_queries(model, processor, queries: list[str]) -> list[Any]:
|
||||||
|
import torch
|
||||||
|
from colpali_engine.utils.torch_utils import ListDataset
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
dataloader = DataLoader(
|
||||||
|
dataset=ListDataset[str](queries),
|
||||||
|
batch_size=1,
|
||||||
|
shuffle=False,
|
||||||
|
collate_fn=lambda x: processor.process_queries(x),
|
||||||
|
)
|
||||||
|
|
||||||
|
q_vecs: list[Any] = []
|
||||||
|
for batch_query in dataloader:
|
||||||
|
with torch.no_grad():
|
||||||
|
batch_query = {k: v.to(model.device) for k, v in batch_query.items()}
|
||||||
|
if model.device.type == "cuda":
|
||||||
|
with torch.autocast(
|
||||||
|
device_type="cuda",
|
||||||
|
dtype=model.dtype if model.dtype.is_floating_point else torch.bfloat16,
|
||||||
|
):
|
||||||
|
embeddings_query = model(**batch_query)
|
||||||
|
else:
|
||||||
|
embeddings_query = model(**batch_query)
|
||||||
|
q_vecs.extend(list(torch.unbind(embeddings_query.to("cpu"))))
|
||||||
|
return q_vecs
|
||||||
|
|
||||||
|
|
||||||
|
def _build_index(index_path: str, doc_vecs: list[Any], filepaths: list[str]) -> LeannMultiVector:
|
||||||
|
dim = int(doc_vecs[0].shape[-1])
|
||||||
|
retriever = LeannMultiVector(index_path=index_path, dim=dim)
|
||||||
|
retriever.create_collection()
|
||||||
|
for i, vec in enumerate(doc_vecs):
|
||||||
|
data = {
|
||||||
|
"colbert_vecs": vec.float().numpy(),
|
||||||
|
"doc_id": i,
|
||||||
|
"filepath": filepaths[i],
|
||||||
|
}
|
||||||
|
retriever.insert(data)
|
||||||
|
retriever.create_index()
|
||||||
|
return retriever
|
||||||
|
|
||||||
|
|
||||||
|
def _load_retriever_if_index_exists(index_path: str, dim: int) -> Optional[LeannMultiVector]:
|
||||||
|
index_base = Path(index_path)
|
||||||
|
# Rough heuristic: index dir exists AND meta+labels files exist
|
||||||
|
meta = index_base.parent / f"{index_base.name}.meta.json"
|
||||||
|
labels = index_base.parent / f"{index_base.name}.labels.json"
|
||||||
|
if index_base.exists() and meta.exists() and labels.exists():
|
||||||
|
return LeannMultiVector(index_path=index_path, dim=dim)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _generate_similarity_map(
|
||||||
|
model,
|
||||||
|
processor,
|
||||||
|
image: Image.Image,
|
||||||
|
query: str,
|
||||||
|
token_idx: Optional[int] = None,
|
||||||
|
output_path: Optional[str] = None,
|
||||||
|
) -> tuple[int, float]:
|
||||||
|
import torch
|
||||||
|
from colpali_engine.interpretability import (
|
||||||
|
get_similarity_maps_from_embeddings,
|
||||||
|
plot_similarity_map,
|
||||||
|
)
|
||||||
|
|
||||||
|
batch_images = processor.process_images([image]).to(model.device)
|
||||||
|
batch_queries = processor.process_queries([query]).to(model.device)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
image_embeddings = model.forward(**batch_images)
|
||||||
|
query_embeddings = model.forward(**batch_queries)
|
||||||
|
|
||||||
|
n_patches = processor.get_n_patches(
|
||||||
|
image_size=image.size,
|
||||||
|
spatial_merge_size=getattr(model, "spatial_merge_size", None),
|
||||||
|
)
|
||||||
|
image_mask = processor.get_image_mask(batch_images)
|
||||||
|
|
||||||
|
batched_similarity_maps = get_similarity_maps_from_embeddings(
|
||||||
|
image_embeddings=image_embeddings,
|
||||||
|
query_embeddings=query_embeddings,
|
||||||
|
n_patches=n_patches,
|
||||||
|
image_mask=image_mask,
|
||||||
|
)
|
||||||
|
|
||||||
|
similarity_maps = batched_similarity_maps[0]
|
||||||
|
|
||||||
|
# Determine token index if not provided: choose the token with highest max score
|
||||||
|
if token_idx is None:
|
||||||
|
per_token_max = similarity_maps.view(similarity_maps.shape[0], -1).max(dim=1).values
|
||||||
|
token_idx = int(per_token_max.argmax().item())
|
||||||
|
|
||||||
|
max_sim_score = similarity_maps[token_idx, :, :].max().item()
|
||||||
|
|
||||||
|
if output_path:
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
fig, ax = plot_similarity_map(
|
||||||
|
image=image,
|
||||||
|
similarity_map=similarity_maps[token_idx],
|
||||||
|
figsize=(14, 14),
|
||||||
|
show_colorbar=False,
|
||||||
|
)
|
||||||
|
ax.set_title(f"Token #{token_idx}. MaxSim score: {max_sim_score:.2f}", fontsize=12)
|
||||||
|
os.makedirs(os.path.dirname(output_path), exist_ok=True)
|
||||||
|
plt.savefig(output_path, bbox_inches="tight")
|
||||||
|
plt.close(fig)
|
||||||
|
|
||||||
|
return token_idx, float(max_sim_score)
|
||||||
|
|
||||||
|
|
||||||
|
class QwenVL:
|
||||||
|
def __init__(self, device: str):
|
||||||
|
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
|
||||||
|
from transformers.utils.import_utils import is_flash_attn_2_available
|
||||||
|
|
||||||
|
attn_implementation = "flash_attention_2" if is_flash_attn_2_available() else "eager"
|
||||||
|
self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
|
||||||
|
"Qwen/Qwen2.5-VL-3B-Instruct",
|
||||||
|
torch_dtype="auto",
|
||||||
|
device_map=device,
|
||||||
|
attn_implementation=attn_implementation,
|
||||||
|
)
|
||||||
|
|
||||||
|
min_pixels = 256 * 28 * 28
|
||||||
|
max_pixels = 1280 * 28 * 28
|
||||||
|
self.processor = AutoProcessor.from_pretrained(
|
||||||
|
"Qwen/Qwen2.5-VL-3B-Instruct", min_pixels=min_pixels, max_pixels=max_pixels
|
||||||
|
)
|
||||||
|
|
||||||
|
def answer(self, query: str, images: list[Image.Image], max_new_tokens: int = 128) -> str:
|
||||||
|
import base64
|
||||||
|
from io import BytesIO
|
||||||
|
|
||||||
|
from qwen_vl_utils import process_vision_info
|
||||||
|
|
||||||
|
content = []
|
||||||
|
for img in images:
|
||||||
|
buffer = BytesIO()
|
||||||
|
img.save(buffer, format="jpeg")
|
||||||
|
img_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
|
||||||
|
content.append({"type": "image", "image": f"data:image;base64,{img_base64}"})
|
||||||
|
content.append({"type": "text", "text": query})
|
||||||
|
messages = [{"role": "user", "content": content}]
|
||||||
|
|
||||||
|
text = self.processor.apply_chat_template(
|
||||||
|
messages, tokenize=False, add_generation_prompt=True
|
||||||
|
)
|
||||||
|
image_inputs, video_inputs = process_vision_info(messages)
|
||||||
|
inputs = self.processor(
|
||||||
|
text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt"
|
||||||
|
)
|
||||||
|
inputs = inputs.to(self.model.device)
|
||||||
|
|
||||||
|
generated_ids = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
|
||||||
|
generated_ids_trimmed = [
|
||||||
|
out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
|
||||||
|
]
|
||||||
|
return self.processor.batch_decode(
|
||||||
|
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
|
||||||
|
)[0]
|
||||||
|
|
||||||
|
|
||||||
|
# %%
|
||||||
|
|
||||||
|
# Step 1: Prepare data
|
||||||
|
if USE_HF_DATASET:
|
||||||
|
from datasets import load_dataset
|
||||||
|
|
||||||
|
dataset = load_dataset(DATASET_NAME, split=DATASET_SPLIT)
|
||||||
|
N = len(dataset) if MAX_DOCS is None else min(MAX_DOCS, len(dataset))
|
||||||
|
filepaths: list[str] = []
|
||||||
|
images: list[Image.Image] = []
|
||||||
|
for i in tqdm(range(N), desc="Loading dataset"):
|
||||||
|
p = dataset[i]
|
||||||
|
# Compose a descriptive identifier for printing later
|
||||||
|
identifier = f"arXiv:{p['paper_arxiv_id']}|title:{p['paper_title']}|page:{int(p['page_number'])}|id:{p['page_id']}"
|
||||||
|
print(identifier)
|
||||||
|
filepaths.append(identifier)
|
||||||
|
images.append(p["page_image"]) # PIL Image
|
||||||
|
else:
|
||||||
|
_maybe_convert_pdf_to_images(PDF, PAGES_DIR)
|
||||||
|
filepaths, images = _load_images_from_dir(PAGES_DIR)
|
||||||
|
if not images:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"No images found in {PAGES_DIR}. Provide PDF path in PDF variable or ensure images exist."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# Step 2: Load model and processor
|
||||||
|
model_name, model, processor, device_str, device, dtype = _load_colvision(MODEL)
|
||||||
|
print(f"Using model={model_name}, device={device_str}, dtype={dtype}")
|
||||||
|
|
||||||
|
|
||||||
|
# %%
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# Step 3: Build or load index
|
||||||
|
retriever: Optional[LeannMultiVector] = None
|
||||||
|
if not REBUILD_INDEX:
|
||||||
|
try:
|
||||||
|
one_vec = _embed_images(model, processor, [images[0]])[0]
|
||||||
|
retriever = _load_retriever_if_index_exists(INDEX_PATH, dim=int(one_vec.shape[-1]))
|
||||||
|
except Exception:
|
||||||
|
retriever = None
|
||||||
|
|
||||||
|
if retriever is None:
|
||||||
|
doc_vecs = _embed_images(model, processor, images)
|
||||||
|
retriever = _build_index(INDEX_PATH, doc_vecs, filepaths)
|
||||||
|
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# Step 4: Embed query and search
|
||||||
|
q_vec = _embed_queries(model, processor, [QUERY])[0]
|
||||||
|
results = retriever.search(q_vec.float().numpy(), topk=TOPK, first_stage_k=FIRST_STAGE_K)
|
||||||
|
if not results:
|
||||||
|
print("No results found.")
|
||||||
|
else:
|
||||||
|
print(f'Top {len(results)} results for query: "{QUERY}"')
|
||||||
|
top_images: list[Image.Image] = []
|
||||||
|
for rank, (score, doc_id) in enumerate(results, start=1):
|
||||||
|
path = filepaths[doc_id]
|
||||||
|
# For HF dataset, path is a descriptive identifier, not a real file path
|
||||||
|
print(f"{rank}) MaxSim: {score:.4f}, Page: {path}")
|
||||||
|
top_images.append(images[doc_id])
|
||||||
|
|
||||||
|
if SAVE_TOP_IMAGE:
|
||||||
|
from pathlib import Path as _Path
|
||||||
|
|
||||||
|
base = _Path(SAVE_TOP_IMAGE)
|
||||||
|
base.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
for rank, img in enumerate(top_images[:TOPK], start=1):
|
||||||
|
if base.suffix:
|
||||||
|
out_path = base.parent / f"{base.stem}_rank{rank}{base.suffix}"
|
||||||
|
else:
|
||||||
|
out_path = base / f"retrieved_page_rank{rank}.png"
|
||||||
|
img.save(str(out_path))
|
||||||
|
print(f"Saved retrieved page (rank {rank}) to: {out_path}")
|
||||||
|
|
||||||
|
## TODO stange results of second page of DeepSeek-V2 rather than the first page
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# Step 5: Similarity maps for top-K results
|
||||||
|
if results and SIMILARITY_MAP:
|
||||||
|
token_idx = None if SIM_TOKEN_IDX < 0 else int(SIM_TOKEN_IDX)
|
||||||
|
from pathlib import Path as _Path
|
||||||
|
|
||||||
|
output_base = _Path(SIM_OUTPUT) if SIM_OUTPUT else None
|
||||||
|
for rank, img in enumerate(top_images[:TOPK], start=1):
|
||||||
|
if output_base:
|
||||||
|
if output_base.suffix:
|
||||||
|
out_dir = output_base.parent
|
||||||
|
out_name = f"{output_base.stem}_rank{rank}{output_base.suffix}"
|
||||||
|
out_path = str(out_dir / out_name)
|
||||||
|
else:
|
||||||
|
out_dir = output_base
|
||||||
|
out_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
out_path = str(out_dir / f"similarity_map_rank{rank}.png")
|
||||||
|
else:
|
||||||
|
out_path = None
|
||||||
|
chosen_idx, max_sim = _generate_similarity_map(
|
||||||
|
model=model,
|
||||||
|
processor=processor,
|
||||||
|
image=img,
|
||||||
|
query=QUERY,
|
||||||
|
token_idx=token_idx,
|
||||||
|
output_path=out_path,
|
||||||
|
)
|
||||||
|
if out_path:
|
||||||
|
print(
|
||||||
|
f"Saved similarity map for rank {rank}, token #{chosen_idx} (max={max_sim:.2f}) to: {out_path}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
print(
|
||||||
|
f"Computed similarity map for rank {rank}, token #{chosen_idx} (max={max_sim:.2f})"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# Step 6: Optional answer generation
|
||||||
|
if results and ANSWER:
|
||||||
|
qwen = QwenVL(device=device_str)
|
||||||
|
response = qwen.answer(QUERY, top_images[:TOPK], max_new_tokens=MAX_NEW_TOKENS)
|
||||||
|
print("\nAnswer:")
|
||||||
|
print(response)
|
||||||
@@ -0,0 +1,134 @@
|
|||||||
|
# pip install pdf2image
|
||||||
|
# pip install pymilvus
|
||||||
|
# pip install colpali_engine
|
||||||
|
# pip install tqdm
|
||||||
|
# pip install pillow
|
||||||
|
|
||||||
|
# %%
|
||||||
|
from pdf2image import convert_from_path
|
||||||
|
|
||||||
|
pdf_path = "pdfs/2004.12832v2.pdf"
|
||||||
|
images = convert_from_path(pdf_path)
|
||||||
|
|
||||||
|
for i, image in enumerate(images):
|
||||||
|
image.save(f"pages/page_{i + 1}.png", "PNG")
|
||||||
|
|
||||||
|
# %%
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# Make local leann packages importable without installing
|
||||||
|
_repo_root = Path(__file__).resolve().parents[3]
|
||||||
|
_leann_core_src = _repo_root / "packages" / "leann-core" / "src"
|
||||||
|
_leann_hnsw_pkg = _repo_root / "packages" / "leann-backend-hnsw"
|
||||||
|
import sys
|
||||||
|
|
||||||
|
if str(_leann_core_src) not in sys.path:
|
||||||
|
sys.path.append(str(_leann_core_src))
|
||||||
|
if str(_leann_hnsw_pkg) not in sys.path:
|
||||||
|
sys.path.append(str(_leann_hnsw_pkg))
|
||||||
|
|
||||||
|
from leann_multi_vector import LeannMultiVector
|
||||||
|
|
||||||
|
|
||||||
|
class LeannRetriever(LeannMultiVector):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
# %%
|
||||||
|
from typing import cast
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from colpali_engine.models import ColPali
|
||||||
|
from colpali_engine.models.paligemma.colpali.processing_colpali import ColPaliProcessor
|
||||||
|
from colpali_engine.utils.torch_utils import ListDataset, get_torch_device
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
|
# Auto-select device: CUDA > MPS (mac) > CPU
|
||||||
|
_device_str = (
|
||||||
|
"cuda"
|
||||||
|
if torch.cuda.is_available()
|
||||||
|
else (
|
||||||
|
"mps"
|
||||||
|
if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available()
|
||||||
|
else "cpu"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
device = get_torch_device(_device_str)
|
||||||
|
# Prefer fp16 on GPU/MPS, bfloat16 on CPU
|
||||||
|
_dtype = torch.float16 if _device_str in ("cuda", "mps") else torch.bfloat16
|
||||||
|
model_name = "vidore/colpali-v1.2"
|
||||||
|
|
||||||
|
model = ColPali.from_pretrained(
|
||||||
|
model_name,
|
||||||
|
torch_dtype=_dtype,
|
||||||
|
device_map=device,
|
||||||
|
).eval()
|
||||||
|
print(f"Using device={_device_str}, dtype={_dtype}")
|
||||||
|
|
||||||
|
queries = [
|
||||||
|
"How to end-to-end retrieval with ColBert",
|
||||||
|
"Where is ColBERT performance Table, including text representation results?",
|
||||||
|
]
|
||||||
|
|
||||||
|
processor = cast(ColPaliProcessor, ColPaliProcessor.from_pretrained(model_name))
|
||||||
|
|
||||||
|
dataloader = DataLoader(
|
||||||
|
dataset=ListDataset[str](queries),
|
||||||
|
batch_size=1,
|
||||||
|
shuffle=False,
|
||||||
|
collate_fn=lambda x: processor.process_queries(x),
|
||||||
|
)
|
||||||
|
|
||||||
|
qs: list[torch.Tensor] = []
|
||||||
|
for batch_query in dataloader:
|
||||||
|
with torch.no_grad():
|
||||||
|
batch_query = {k: v.to(model.device) for k, v in batch_query.items()}
|
||||||
|
embeddings_query = model(**batch_query)
|
||||||
|
qs.extend(list(torch.unbind(embeddings_query.to("cpu"))))
|
||||||
|
print(qs[0].shape)
|
||||||
|
# %%
|
||||||
|
|
||||||
|
|
||||||
|
import re
|
||||||
|
|
||||||
|
from PIL import Image
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
page_filenames = sorted(os.listdir("./pages"), key=lambda n: int(re.search(r"\d+", n).group()))
|
||||||
|
images = [Image.open(os.path.join("./pages", name)) for name in page_filenames]
|
||||||
|
|
||||||
|
dataloader = DataLoader(
|
||||||
|
dataset=ListDataset[str](images),
|
||||||
|
batch_size=1,
|
||||||
|
shuffle=False,
|
||||||
|
collate_fn=lambda x: processor.process_images(x),
|
||||||
|
)
|
||||||
|
|
||||||
|
ds: list[torch.Tensor] = []
|
||||||
|
for batch_doc in tqdm(dataloader):
|
||||||
|
with torch.no_grad():
|
||||||
|
batch_doc = {k: v.to(model.device) for k, v in batch_doc.items()}
|
||||||
|
embeddings_doc = model(**batch_doc)
|
||||||
|
ds.extend(list(torch.unbind(embeddings_doc.to("cpu"))))
|
||||||
|
|
||||||
|
print(ds[0].shape)
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# Build HNSW index via LeannRetriever primitives and run search
|
||||||
|
index_path = "./indexes/colpali.leann"
|
||||||
|
retriever = LeannRetriever(index_path=index_path, dim=int(ds[0].shape[-1]))
|
||||||
|
retriever.create_collection()
|
||||||
|
filepaths = [os.path.join("./pages", name) for name in page_filenames]
|
||||||
|
for i in range(len(filepaths)):
|
||||||
|
data = {
|
||||||
|
"colbert_vecs": ds[i].float().numpy(),
|
||||||
|
"doc_id": i,
|
||||||
|
"filepath": filepaths[i],
|
||||||
|
}
|
||||||
|
retriever.insert(data)
|
||||||
|
retriever.create_index()
|
||||||
|
for query in qs:
|
||||||
|
query_np = query.float().numpy()
|
||||||
|
result = retriever.search(query_np, topk=1)
|
||||||
|
print(filepaths[result[0][1]])
|
||||||
@@ -26,6 +26,21 @@ leann build my-code-index --docs ./src --use-ast-chunking
|
|||||||
uv pip install -e "."
|
uv pip install -e "."
|
||||||
```
|
```
|
||||||
|
|
||||||
|
#### For normal users (PyPI install)
|
||||||
|
- Use `pip install leann` or `uv pip install leann`.
|
||||||
|
- `astchunk` is pulled automatically from PyPI as a dependency; no extra steps.
|
||||||
|
|
||||||
|
#### For developers (from source, editable)
|
||||||
|
```bash
|
||||||
|
git clone https://github.com/yichuan-w/LEANN.git leann
|
||||||
|
cd leann
|
||||||
|
git submodule update --init --recursive
|
||||||
|
uv sync
|
||||||
|
```
|
||||||
|
- This repo vendors `astchunk` as a git submodule at `packages/astchunk-leann` (our fork).
|
||||||
|
- `[tool.uv.sources]` maps the `astchunk` package to that path in editable mode.
|
||||||
|
- You can edit code under `packages/astchunk-leann` and Python will use your changes immediately (no separate `pip install astchunk` needed).
|
||||||
|
|
||||||
## Best Practices
|
## Best Practices
|
||||||
|
|
||||||
### When to Use AST Chunking
|
### When to Use AST Chunking
|
||||||
|
|||||||
149
docs/grep_search.md
Normal file
149
docs/grep_search.md
Normal file
@@ -0,0 +1,149 @@
|
|||||||
|
# LEANN Grep Search Usage Guide
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
LEANN's grep search functionality provides exact text matching for finding specific code patterns, error messages, function names, or exact phrases in your indexed documents.
|
||||||
|
|
||||||
|
## Basic Usage
|
||||||
|
|
||||||
|
### Simple Grep Search
|
||||||
|
|
||||||
|
```python
|
||||||
|
from leann.api import LeannSearcher
|
||||||
|
|
||||||
|
searcher = LeannSearcher("your_index_path")
|
||||||
|
|
||||||
|
# Exact text search
|
||||||
|
results = searcher.search("def authenticate_user", use_grep=True, top_k=5)
|
||||||
|
|
||||||
|
for result in results:
|
||||||
|
print(f"Score: {result.score}")
|
||||||
|
print(f"Text: {result.text[:100]}...")
|
||||||
|
print("-" * 40)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Comparison: Semantic vs Grep Search
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Semantic search - finds conceptually similar content
|
||||||
|
semantic_results = searcher.search("machine learning algorithms", top_k=3)
|
||||||
|
|
||||||
|
# Grep search - finds exact text matches
|
||||||
|
grep_results = searcher.search("def train_model", use_grep=True, top_k=3)
|
||||||
|
```
|
||||||
|
|
||||||
|
## When to Use Grep Search
|
||||||
|
|
||||||
|
### Use Cases
|
||||||
|
|
||||||
|
- **Code Search**: Finding specific function definitions, class names, or variable references
|
||||||
|
- **Error Debugging**: Locating exact error messages or stack traces
|
||||||
|
- **Documentation**: Finding specific API endpoints or exact terminology
|
||||||
|
|
||||||
|
### Examples
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Find function definitions
|
||||||
|
functions = searcher.search("def __init__", use_grep=True)
|
||||||
|
|
||||||
|
# Find import statements
|
||||||
|
imports = searcher.search("from sklearn import", use_grep=True)
|
||||||
|
|
||||||
|
# Find specific error types
|
||||||
|
errors = searcher.search("FileNotFoundError", use_grep=True)
|
||||||
|
|
||||||
|
# Find TODO comments
|
||||||
|
todos = searcher.search("TODO:", use_grep=True)
|
||||||
|
|
||||||
|
# Find configuration entries
|
||||||
|
configs = searcher.search("server_port=", use_grep=True)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Technical Details
|
||||||
|
|
||||||
|
### How It Works
|
||||||
|
|
||||||
|
1. **File Location**: Grep search operates on the raw text stored in `.jsonl` files
|
||||||
|
2. **Command Execution**: Uses the system `grep` command with case-insensitive search
|
||||||
|
3. **Result Processing**: Parses JSON lines and extracts text and metadata
|
||||||
|
4. **Scoring**: Simple frequency-based scoring based on query term occurrences
|
||||||
|
|
||||||
|
### Search Process
|
||||||
|
|
||||||
|
```
|
||||||
|
Query: "def train_model"
|
||||||
|
↓
|
||||||
|
grep -i -n "def train_model" documents.leann.passages.jsonl
|
||||||
|
↓
|
||||||
|
Parse matching JSON lines
|
||||||
|
↓
|
||||||
|
Calculate scores based on term frequency
|
||||||
|
↓
|
||||||
|
Return top_k results
|
||||||
|
```
|
||||||
|
|
||||||
|
### Scoring Algorithm
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Term frequency in document
|
||||||
|
score = text.lower().count(query.lower())
|
||||||
|
```
|
||||||
|
|
||||||
|
Results are ranked by score (highest first), with higher scores indicating more occurrences of the search term.
|
||||||
|
|
||||||
|
## Error Handling
|
||||||
|
|
||||||
|
### Common Issues
|
||||||
|
|
||||||
|
#### Grep Command Not Found
|
||||||
|
```
|
||||||
|
RuntimeError: grep command not found. Please install grep or use semantic search.
|
||||||
|
```
|
||||||
|
|
||||||
|
**Solution**: Install grep on your system:
|
||||||
|
- **Ubuntu/Debian**: `sudo apt-get install grep`
|
||||||
|
- **macOS**: grep is pre-installed
|
||||||
|
- **Windows**: Use WSL or install grep via Git Bash/MSYS2
|
||||||
|
|
||||||
|
#### No Results Found
|
||||||
|
```python
|
||||||
|
# Check if your query exists in the raw data
|
||||||
|
results = searcher.search("your_query", use_grep=True)
|
||||||
|
if not results:
|
||||||
|
print("No exact matches found. Try:")
|
||||||
|
print("1. Check spelling and case")
|
||||||
|
print("2. Use partial terms")
|
||||||
|
print("3. Switch to semantic search")
|
||||||
|
```
|
||||||
|
|
||||||
|
## Complete Example
|
||||||
|
|
||||||
|
```python
|
||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Grep Search Example
|
||||||
|
Demonstrates grep search for exact text matching.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from leann.api import LeannSearcher
|
||||||
|
|
||||||
|
def demonstrate_grep_search():
|
||||||
|
# Initialize searcher
|
||||||
|
searcher = LeannSearcher("my_index")
|
||||||
|
|
||||||
|
print("=== Function Search ===")
|
||||||
|
functions = searcher.search("def __init__", use_grep=True, top_k=5)
|
||||||
|
for i, result in enumerate(functions, 1):
|
||||||
|
print(f"{i}. Score: {result.score}")
|
||||||
|
print(f" Preview: {result.text[:60]}...")
|
||||||
|
print()
|
||||||
|
|
||||||
|
print("=== Error Search ===")
|
||||||
|
errors = searcher.search("FileNotFoundError", use_grep=True, top_k=3)
|
||||||
|
for result in errors:
|
||||||
|
print(f"Content: {result.text.strip()}")
|
||||||
|
print("-" * 40)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
demonstrate_grep_search()
|
||||||
|
```
|
||||||
35
examples/grep_search_example.py
Normal file
35
examples/grep_search_example.py
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
"""
|
||||||
|
Grep Search Example
|
||||||
|
|
||||||
|
Shows how to use grep-based text search instead of semantic search.
|
||||||
|
Useful when you need exact text matches rather than meaning-based results.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from leann import LeannSearcher
|
||||||
|
|
||||||
|
# Load your index
|
||||||
|
searcher = LeannSearcher("my-documents.leann")
|
||||||
|
|
||||||
|
# Regular semantic search
|
||||||
|
print("=== Semantic Search ===")
|
||||||
|
results = searcher.search("machine learning algorithms", top_k=3)
|
||||||
|
for result in results:
|
||||||
|
print(f"Score: {result.score:.3f}")
|
||||||
|
print(f"Text: {result.text[:80]}...")
|
||||||
|
print()
|
||||||
|
|
||||||
|
# Grep-based search for exact text matches
|
||||||
|
print("=== Grep Search ===")
|
||||||
|
results = searcher.search("def train_model", top_k=3, use_grep=True)
|
||||||
|
for result in results:
|
||||||
|
print(f"Score: {result.score}")
|
||||||
|
print(f"Text: {result.text[:80]}...")
|
||||||
|
print()
|
||||||
|
|
||||||
|
# Find specific error messages
|
||||||
|
error_results = searcher.search("FileNotFoundError", use_grep=True)
|
||||||
|
print(f"Found {len(error_results)} files mentioning FileNotFoundError")
|
||||||
|
|
||||||
|
# Search for function definitions
|
||||||
|
func_results = searcher.search("class SearchResult", use_grep=True, top_k=5)
|
||||||
|
print(f"Found {len(func_results)} class definitions")
|
||||||
28
llms.txt
Normal file
28
llms.txt
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
# llms.txt — LEANN MCP and Agent Integration
|
||||||
|
product: LEANN
|
||||||
|
homepage: https://github.com/yichuan-w/LEANN
|
||||||
|
contact: https://github.com/yichuan-w/LEANN/issues
|
||||||
|
|
||||||
|
# Installation
|
||||||
|
install: uv tool install leann-core --with leann
|
||||||
|
|
||||||
|
# MCP Server Entry Point
|
||||||
|
mcp.server: leann_mcp
|
||||||
|
mcp.protocol_version: 2024-11-05
|
||||||
|
|
||||||
|
# Tools
|
||||||
|
mcp.tools: leann_list, leann_search
|
||||||
|
|
||||||
|
mcp.tool.leann_list.description: List available LEANN indexes
|
||||||
|
mcp.tool.leann_list.input: {}
|
||||||
|
|
||||||
|
mcp.tool.leann_search.description: Semantic search across a named LEANN index
|
||||||
|
mcp.tool.leann_search.input.index_name: string, required
|
||||||
|
mcp.tool.leann_search.input.query: string, required
|
||||||
|
mcp.tool.leann_search.input.top_k: integer, optional, default=5, min=1, max=20
|
||||||
|
mcp.tool.leann_search.input.complexity: integer, optional, default=32, min=16, max=128
|
||||||
|
|
||||||
|
# Notes
|
||||||
|
note: Build indexes with `leann build <name> --docs <files...>` before searching.
|
||||||
|
example.add: claude mcp add --scope user leann-server -- leann_mcp
|
||||||
|
example.verify: claude mcp list | cat
|
||||||
1
packages/astchunk-leann
Submodule
1
packages/astchunk-leann
Submodule
Submodule packages/astchunk-leann added at ad9afa07b9
@@ -1,11 +1,11 @@
|
|||||||
[build-system]
|
[build-system]
|
||||||
requires = ["scikit-build-core>=0.10", "pybind11>=2.12.0", "numpy"]
|
requires = ["scikit-build-core>=0.10", "pybind11>=2.12.0", "numpy", "cmake>=3.30"]
|
||||||
build-backend = "scikit_build_core.build"
|
build-backend = "scikit_build_core.build"
|
||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "leann-backend-diskann"
|
name = "leann-backend-diskann"
|
||||||
version = "0.3.2"
|
version = "0.3.4"
|
||||||
dependencies = ["leann-core==0.3.2", "numpy", "protobuf>=3.19.0"]
|
dependencies = ["leann-core==0.3.4", "numpy", "protobuf>=3.19.0"]
|
||||||
|
|
||||||
[tool.scikit-build]
|
[tool.scikit-build]
|
||||||
# Key: simplified CMake path
|
# Key: simplified CMake path
|
||||||
|
|||||||
Submodule packages/leann-backend-diskann/third_party/DiskANN updated: c593831474...19f9603c72
@@ -49,9 +49,28 @@ set(BUILD_TESTING OFF CACHE BOOL "" FORCE)
|
|||||||
set(FAISS_ENABLE_C_API OFF CACHE BOOL "" FORCE)
|
set(FAISS_ENABLE_C_API OFF CACHE BOOL "" FORCE)
|
||||||
set(FAISS_OPT_LEVEL "generic" CACHE STRING "" FORCE)
|
set(FAISS_OPT_LEVEL "generic" CACHE STRING "" FORCE)
|
||||||
|
|
||||||
# Disable additional SIMD versions to speed up compilation
|
# Disable x86-specific SIMD optimizations (important for ARM64 compatibility)
|
||||||
set(FAISS_ENABLE_AVX2 OFF CACHE BOOL "" FORCE)
|
set(FAISS_ENABLE_AVX2 OFF CACHE BOOL "" FORCE)
|
||||||
set(FAISS_ENABLE_AVX512 OFF CACHE BOOL "" FORCE)
|
set(FAISS_ENABLE_AVX512 OFF CACHE BOOL "" FORCE)
|
||||||
|
set(FAISS_ENABLE_SSE4_1 OFF CACHE BOOL "" FORCE)
|
||||||
|
|
||||||
|
# ARM64-specific configuration
|
||||||
|
if(CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64|arm64")
|
||||||
|
message(STATUS "Configuring Faiss for ARM64 architecture")
|
||||||
|
|
||||||
|
if(CMAKE_SYSTEM_NAME STREQUAL "Linux")
|
||||||
|
# Use SVE optimization level for ARM64 Linux (as seen in Faiss conda build)
|
||||||
|
set(FAISS_OPT_LEVEL "sve" CACHE STRING "" FORCE)
|
||||||
|
message(STATUS "Setting FAISS_OPT_LEVEL to 'sve' for ARM64 Linux")
|
||||||
|
else()
|
||||||
|
# Use generic optimization for other ARM64 platforms (like macOS)
|
||||||
|
set(FAISS_OPT_LEVEL "generic" CACHE STRING "" FORCE)
|
||||||
|
message(STATUS "Setting FAISS_OPT_LEVEL to 'generic' for ARM64 ${CMAKE_SYSTEM_NAME}")
|
||||||
|
endif()
|
||||||
|
|
||||||
|
# ARM64 compatibility: Faiss submodule has been modified to fix x86 header inclusion
|
||||||
|
message(STATUS "Using ARM64-compatible Faiss submodule")
|
||||||
|
endif()
|
||||||
|
|
||||||
# Additional optimization options from INSTALL.md
|
# Additional optimization options from INSTALL.md
|
||||||
set(CMAKE_BUILD_TYPE "Release" CACHE STRING "" FORCE)
|
set(CMAKE_BUILD_TYPE "Release" CACHE STRING "" FORCE)
|
||||||
|
|||||||
@@ -6,10 +6,10 @@ build-backend = "scikit_build_core.build"
|
|||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "leann-backend-hnsw"
|
name = "leann-backend-hnsw"
|
||||||
version = "0.3.2"
|
version = "0.3.4"
|
||||||
description = "Custom-built HNSW (Faiss) backend for the Leann toolkit."
|
description = "Custom-built HNSW (Faiss) backend for the Leann toolkit."
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"leann-core==0.3.2",
|
"leann-core==0.3.4",
|
||||||
"numpy",
|
"numpy",
|
||||||
"pyzmq>=23.0.0",
|
"pyzmq>=23.0.0",
|
||||||
"msgpack>=1.0.0",
|
"msgpack>=1.0.0",
|
||||||
|
|||||||
Submodule packages/leann-backend-hnsw/third_party/faiss updated: a0361858fc...ed96ff7dba
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

 [project]
 name = "leann-core"
-version = "0.3.2"
+version = "0.3.4"
 description = "Core API and plugin system for LEANN"
 readme = "README.md"
 requires-python = ">=3.9"
@@ -6,6 +6,8 @@ with the correct, original embedding logic from the user's reference code.
 import json
 import logging
 import pickle
+import re
+import subprocess
 import time
 import warnings
 from dataclasses import dataclass, field
@@ -653,6 +655,7 @@ class LeannSearcher:
         expected_zmq_port: int = 5557,
         metadata_filters: Optional[dict[str, dict[str, Union[str, int, float, bool, list]]]] = None,
         batch_size: int = 0,
+        use_grep: bool = False,
         **kwargs,
     ) -> list[SearchResult]:
         """
@@ -679,6 +682,10 @@ class LeannSearcher:
         Returns:
             List of SearchResult objects with text, metadata, and similarity scores
         """
+        # Handle grep search
+        if use_grep:
+            return self._grep_search(query, top_k)
+
         logger.info("🔍 LeannSearcher.search() called:")
         logger.info(f"  Query: '{query}'")
         logger.info(f"  Top_k: {top_k}")
@@ -795,9 +802,96 @@ class LeannSearcher:
         logger.info(f"  {GREEN}✓ Final enriched results: {len(enriched_results)} passages{RESET}")
         return enriched_results

+    def _find_jsonl_file(self) -> Optional[str]:
+        """Find the .jsonl file containing raw passages for grep search"""
+        index_path = Path(self.meta_path_str).parent
+        potential_files = [
+            index_path / "documents.leann.passages.jsonl",
+            index_path.parent / "documents.leann.passages.jsonl",
+        ]
+
+        for file_path in potential_files:
+            if file_path.exists():
+                return str(file_path)
+        return None
+
+    def _grep_search(self, query: str, top_k: int = 5) -> list[SearchResult]:
+        """Perform grep-based search on raw passages"""
+        jsonl_file = self._find_jsonl_file()
+        if not jsonl_file:
+            raise FileNotFoundError("No .jsonl passages file found for grep search")
+
+        try:
+            cmd = ["grep", "-i", "-n", query, jsonl_file]
+            result = subprocess.run(cmd, capture_output=True, text=True, check=False)
+
+            if result.returncode == 1:
+                return []
+            elif result.returncode != 0:
+                raise RuntimeError(f"Grep failed: {result.stderr}")
+
+            matches = []
+            for line in result.stdout.strip().split("\n"):
+                if not line:
+                    continue
+                parts = line.split(":", 1)
+                if len(parts) != 2:
+                    continue
+
+                try:
+                    data = json.loads(parts[1])
+                    text = data.get("text", "")
+                    score = text.lower().count(query.lower())
+
+                    matches.append(
+                        SearchResult(
+                            id=data.get("id", parts[0]),
+                            text=text,
+                            metadata=data.get("metadata", {}),
+                            score=float(score),
+                        )
+                    )
+                except json.JSONDecodeError:
+                    continue
+
+            matches.sort(key=lambda x: x.score, reverse=True)
+            return matches[:top_k]
+
+        except FileNotFoundError:
+            raise RuntimeError(
+                "grep command not found. Please install grep or use semantic search."
+            )
+
+    def _python_regex_search(self, query: str, top_k: int = 5) -> list[SearchResult]:
+        """Fallback regex search"""
+        jsonl_file = self._find_jsonl_file()
+        if not jsonl_file:
+            raise FileNotFoundError("No .jsonl file found")
+
+        pattern = re.compile(re.escape(query), re.IGNORECASE)
+        matches = []
+
+        with open(jsonl_file, encoding="utf-8") as f:
+            for line_num, line in enumerate(f, 1):
+                if pattern.search(line):
+                    try:
+                        data = json.loads(line.strip())
+                        matches.append(
+                            SearchResult(
+                                id=data.get("id", str(line_num)),
+                                text=data.get("text", ""),
+                                metadata=data.get("metadata", {}),
+                                score=float(len(pattern.findall(data.get("text", "")))),
+                            )
+                        )
+                    except json.JSONDecodeError:
+                        continue
+
+        matches.sort(key=lambda x: x.score, reverse=True)
+        return matches[:top_k]
+
     def cleanup(self):
         """Explicitly cleanup embedding server resources.

         This method should be called after you're done using the searcher,
         especially in test environments or batch processing scenarios.
         """
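The `use_grep` path above trades semantic recall for exact, case-insensitive substring matching over the raw passages `.jsonl`, scoring each passage by how many times the query occurs in its text. A minimal usage sketch (the import path and index location are illustrative assumptions, not taken from this diff):

import logging

from leann import LeannSearcher  # assumed public import path

logging.basicConfig(level=logging.INFO)

searcher = LeannSearcher("./demo.leann/documents.leann.meta.json")  # hypothetical index location

# Default: semantic search via embeddings + the ANN index
semantic_hits = searcher.search("how are passages pruned?", top_k=5)

# New: grep search over the raw passages; score = occurrence count, not similarity
grep_hits = searcher.search("FAISS_OPT_LEVEL", top_k=5, use_grep=True)

for hit in grep_hits:
    print(hit.score, hit.text[:80])

Note that `_grep_search` shells out to a system `grep`; `_python_regex_search` is a pure-Python fallback, but this hunk does not wire it into `search()`.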
@@ -853,6 +947,7 @@ class LeannChat:
         expected_zmq_port: int = 5557,
         metadata_filters: Optional[dict[str, dict[str, Union[str, int, float, bool, list]]]] = None,
         batch_size: int = 0,
+        use_grep: bool = False,
         **search_kwargs,
     ):
         if llm_kwargs is None:
@@ -1,6 +1,6 @@
 """
 Enhanced chunking utilities with AST-aware code chunking support.
-Provides unified interface for both traditional and AST-based text chunking.
+Packaged within leann-core so installed wheels can import it reliably.
 """

 import logging
@@ -22,30 +22,9 @@ CODE_EXTENSIONS = {
     ".jsx": "typescript",
 }

-# Default chunk parameters for different content types
-DEFAULT_CHUNK_PARAMS = {
-    "code": {
-        "max_chunk_size": 512,
-        "chunk_overlap": 64,
-    },
-    "text": {
-        "chunk_size": 256,
-        "chunk_overlap": 128,
-    },
-}
-
-
 def detect_code_files(documents, code_extensions=None) -> tuple[list, list]:
-    """
-    Separate documents into code files and regular text files.
-
-    Args:
-        documents: List of LlamaIndex Document objects
-        code_extensions: Dict mapping file extensions to languages (defaults to CODE_EXTENSIONS)
-
-    Returns:
-        Tuple of (code_documents, text_documents)
-    """
+    """Separate documents into code files and regular text files."""
     if code_extensions is None:
         code_extensions = CODE_EXTENSIONS

@@ -53,16 +32,10 @@ def detect_code_files(documents, code_extensions=None) -> tuple[list, list]:
     text_docs = []

     for doc in documents:
-        # Get file path from metadata
-        file_path = doc.metadata.get("file_path", "")
-        if not file_path:
-            # Fallback to file_name
-            file_path = doc.metadata.get("file_name", "")
+        file_path = doc.metadata.get("file_path", "") or doc.metadata.get("file_name", "")

         if file_path:
             file_ext = Path(file_path).suffix.lower()
             if file_ext in code_extensions:
-                # Add language info to metadata
                 doc.metadata["language"] = code_extensions[file_ext]
                 doc.metadata["is_code"] = True
                 code_docs.append(doc)
@@ -70,7 +43,6 @@ def detect_code_files(documents, code_extensions=None) -> tuple[list, list]:
                 doc.metadata["is_code"] = False
                 text_docs.append(doc)
         else:
-            # If no file path, treat as text
             doc.metadata["is_code"] = False
             text_docs.append(doc)

@@ -79,7 +51,7 @@ def detect_code_files(documents, code_extensions=None) -> tuple[list, list]:


 def get_language_from_extension(file_path: str) -> Optional[str]:
-    """Get the programming language from file extension."""
+    """Return language string from a filename/extension using CODE_EXTENSIONS."""
     ext = Path(file_path).suffix.lower()
     return CODE_EXTENSIONS.get(ext)

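Together, `detect_code_files` and `get_language_from_extension` are small classifiers over document metadata: the first tags and partitions documents, the second is a bare extension lookup. A quick sketch of the expected behavior (the `Document` construction is illustrative; any object carrying a `metadata` dict shaped like LlamaIndex's would do):

from llama_index.core import Document  # assumed to match the LlamaIndex objects LEANN loads

docs = [
    Document(text="def f():\n    return 1\n", metadata={"file_path": "src/util.py"}),
    Document(text="Design notes for LEANN...", metadata={"file_name": "notes.md"}),
]

code_docs, text_docs = detect_code_files(docs)
assert code_docs[0].metadata["language"] == "python"  # tagged via CODE_EXTENSIONS
assert text_docs[0].metadata["is_code"] is False      # .md is not a code extension

# Pure extension lookup, independent of document metadata
assert get_language_from_extension("src/util.py") == "python"
assert get_language_from_extension("README") is None  # no extension -> no language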
@@ -90,40 +62,26 @@ def create_ast_chunks(
     chunk_overlap: int = 64,
     metadata_template: str = "default",
 ) -> list[str]:
-    """
-    Create AST-aware chunks from code documents using astchunk.
-
-    Args:
-        documents: List of code documents
-        max_chunk_size: Maximum characters per chunk
-        chunk_overlap: Number of AST nodes to overlap between chunks
-        metadata_template: Template for chunk metadata
-
-    Returns:
-        List of text chunks with preserved code structure
-    """
+    """Create AST-aware chunks from code documents using astchunk.
+
+    Falls back to traditional chunking if astchunk is unavailable.
+    """
     try:
-        from astchunk import ASTChunkBuilder
+        from astchunk import ASTChunkBuilder  # optional dependency
     except ImportError as e:
         logger.error(f"astchunk not available: {e}")
         logger.info("Falling back to traditional chunking for code files")
         return create_traditional_chunks(documents, max_chunk_size, chunk_overlap)

     all_chunks = []

     for doc in documents:
-        # Get language from metadata (set by detect_code_files)
         language = doc.metadata.get("language")
         if not language:
-            logger.warning(
-                "No language detected for document, falling back to traditional chunking"
-            )
-            traditional_chunks = create_traditional_chunks([doc], max_chunk_size, chunk_overlap)
-            all_chunks.extend(traditional_chunks)
+            logger.warning("No language detected; falling back to traditional chunking")
+            all_chunks.extend(create_traditional_chunks([doc], max_chunk_size, chunk_overlap))
             continue

         try:
-            # Configure astchunk
             configs = {
                 "max_chunk_size": max_chunk_size,
                 "language": language,
@@ -131,7 +89,6 @@ def create_ast_chunks(
                 "chunk_overlap": chunk_overlap if chunk_overlap > 0 else 0,
             }

-            # Add repository-level metadata if available
             repo_metadata = {
                 "file_path": doc.metadata.get("file_path", ""),
                 "file_name": doc.metadata.get("file_name", ""),
@@ -140,17 +97,13 @@ def create_ast_chunks(
             }
             configs["repo_level_metadata"] = repo_metadata

-            # Create chunk builder and process
             chunk_builder = ASTChunkBuilder(**configs)
             code_content = doc.get_content()

             if not code_content or not code_content.strip():
                 logger.warning("Empty code content, skipping")
                 continue

             chunks = chunk_builder.chunkify(code_content)

-            # Extract text content from chunks
             for chunk in chunks:
                 if hasattr(chunk, "text"):
                     chunk_text = chunk.text
@@ -159,7 +112,6 @@ def create_ast_chunks(
                 elif isinstance(chunk, str):
                     chunk_text = chunk
                 else:
-                    # Try to convert to string
                     chunk_text = str(chunk)

                 if chunk_text and chunk_text.strip():
@@ -168,12 +120,10 @@ def create_ast_chunks(
             logger.info(
                 f"Created {len(chunks)} AST chunks from {language} file: {doc.metadata.get('file_name', 'unknown')}"
             )

         except Exception as e:
             logger.warning(f"AST chunking failed for {language} file: {e}")
             logger.info("Falling back to traditional chunking")
-            traditional_chunks = create_traditional_chunks([doc], max_chunk_size, chunk_overlap)
-            all_chunks.extend(traditional_chunks)
+            all_chunks.extend(create_traditional_chunks([doc], max_chunk_size, chunk_overlap))

     return all_chunks

@@ -181,23 +131,10 @@ def create_ast_chunks(
 def create_traditional_chunks(
     documents, chunk_size: int = 256, chunk_overlap: int = 128
 ) -> list[str]:
-    """
-    Create traditional text chunks using LlamaIndex SentenceSplitter.
-
-    Args:
-        documents: List of documents to chunk
-        chunk_size: Size of each chunk in characters
-        chunk_overlap: Overlap between chunks
-
-    Returns:
-        List of text chunks
-    """
-    # Handle invalid chunk_size values
+    """Create traditional text chunks using LlamaIndex SentenceSplitter."""
     if chunk_size <= 0:
         logger.warning(f"Invalid chunk_size={chunk_size}, using default value of 256")
         chunk_size = 256

-    # Ensure chunk_overlap is not negative and not larger than chunk_size
     if chunk_overlap < 0:
         chunk_overlap = 0
     if chunk_overlap >= chunk_size:
@@ -215,12 +152,9 @@ def create_traditional_chunks(
     try:
         nodes = node_parser.get_nodes_from_documents([doc])
         if nodes:
-            chunk_texts = [node.get_content() for node in nodes]
-            all_texts.extend(chunk_texts)
-            logger.debug(f"Created {len(chunk_texts)} traditional chunks from document")
+            all_texts.extend(node.get_content() for node in nodes)
     except Exception as e:
         logger.error(f"Traditional chunking failed for document: {e}")
-        # As last resort, add the raw content
         content = doc.get_content()
         if content and content.strip():
             all_texts.append(content.strip())
@@ -238,32 +172,13 @@ def create_text_chunks(
     code_file_extensions: Optional[list[str]] = None,
     ast_fallback_traditional: bool = True,
 ) -> list[str]:
-    """
-    Create text chunks from documents with optional AST support for code files.
-
-    Args:
-        documents: List of LlamaIndex Document objects
-        chunk_size: Size for traditional text chunks
-        chunk_overlap: Overlap for traditional text chunks
-        use_ast_chunking: Whether to use AST chunking for code files
-        ast_chunk_size: Size for AST chunks
-        ast_chunk_overlap: Overlap for AST chunks
-        code_file_extensions: Custom list of code file extensions
-        ast_fallback_traditional: Fall back to traditional chunking on AST errors
-
-    Returns:
-        List of text chunks
-    """
+    """Create text chunks from documents with optional AST support for code files."""
     if not documents:
         logger.warning("No documents provided for chunking")
         return []

-    # Create a local copy of supported extensions for this function call
     local_code_extensions = CODE_EXTENSIONS.copy()

-    # Update supported extensions if provided
     if code_file_extensions:
-        # Map extensions to languages (simplified mapping)
         ext_mapping = {
             ".py": "python",
             ".java": "java",
@@ -273,47 +188,32 @@ def create_text_chunks(
         }
         for ext in code_file_extensions:
             if ext.lower() not in local_code_extensions:
-                # Try to guess language from extension
                 if ext.lower() in ext_mapping:
                     local_code_extensions[ext.lower()] = ext_mapping[ext.lower()]
                 else:
                     logger.warning(f"Unsupported extension {ext}, will use traditional chunking")

     all_chunks = []

     if use_ast_chunking:
-        # Separate code and text documents using local extensions
         code_docs, text_docs = detect_code_files(documents, local_code_extensions)

-        # Process code files with AST chunking
         if code_docs:
-            logger.info(f"Processing {len(code_docs)} code files with AST chunking")
             try:
-                ast_chunks = create_ast_chunks(
-                    code_docs, max_chunk_size=ast_chunk_size, chunk_overlap=ast_chunk_overlap
-                )
-                all_chunks.extend(ast_chunks)
-                logger.info(f"Created {len(ast_chunks)} AST chunks from code files")
+                all_chunks.extend(
+                    create_ast_chunks(
+                        code_docs, max_chunk_size=ast_chunk_size, chunk_overlap=ast_chunk_overlap
+                    )
+                )
             except Exception as e:
                 logger.error(f"AST chunking failed: {e}")
                 if ast_fallback_traditional:
-                    logger.info("Falling back to traditional chunking for code files")
-                    traditional_code_chunks = create_traditional_chunks(
-                        code_docs, chunk_size, chunk_overlap
-                    )
-                    all_chunks.extend(traditional_code_chunks)
+                    all_chunks.extend(
+                        create_traditional_chunks(code_docs, chunk_size, chunk_overlap)
+                    )
                 else:
                     raise

-        # Process text files with traditional chunking
         if text_docs:
-            logger.info(f"Processing {len(text_docs)} text files with traditional chunking")
-            text_chunks = create_traditional_chunks(text_docs, chunk_size, chunk_overlap)
-            all_chunks.extend(text_chunks)
-            logger.info(f"Created {len(text_chunks)} traditional chunks from text files")
+            all_chunks.extend(create_traditional_chunks(text_docs, chunk_size, chunk_overlap))
     else:
-        # Use traditional chunking for all files
-        logger.info(f"Processing {len(documents)} documents with traditional chunking")
         all_chunks = create_traditional_chunks(documents, chunk_size, chunk_overlap)

     logger.info(f"Total chunks created: {len(all_chunks)}")
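With the logging and intermediate variables trimmed, `create_text_chunks` remains the single entry point: it routes code files to AST chunking and everything else through the sentence splitter. A usage sketch (the parameter names come from the removed docstring; the values mirror defaults visible in this diff, and the document list is assumed to come from a LlamaIndex loader):

chunks = create_text_chunks(
    documents,                      # mixed code + text documents from a loader
    chunk_size=256,                 # traditional chunk size
    chunk_overlap=128,              # traditional overlap
    use_ast_chunking=True,          # route code files through astchunk
    ast_chunk_size=512,             # max characters per AST chunk
    ast_chunk_overlap=64,           # AST-node overlap between chunks
    ast_fallback_traditional=True,  # degrade to SentenceSplitter if astchunk fails
)
print(f"{len(chunks)} chunks ready for embedding")

Passing `code_file_extensions=[".py", ".java"]` would extend detection beyond the built-in `CODE_EXTENSIONS`; extensions outside the internal mapping only trigger a warning and fall back to traditional chunking.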
@@ -1,6 +1,5 @@
 import argparse
 import asyncio
-import sys
 from pathlib import Path
 from typing import Any, Optional, Union

@@ -1216,13 +1215,8 @@ Examples:
         if use_ast:
             print("🧠 Using AST-aware chunking for code files")
             try:
-                # Import enhanced chunking utilities
-                # Add apps directory to path to import chunking utilities
-                apps_dir = Path(__file__).parent.parent.parent.parent.parent / "apps"
-                if apps_dir.exists():
-                    sys.path.insert(0, str(apps_dir))
-
-                from chunking import create_text_chunks
+                # Import enhanced chunking utilities from packaged module
+                from .chunking_utils import create_text_chunks

                 # Use enhanced chunking with AST support
                 all_texts = create_text_chunks(
@@ -1237,7 +1231,9 @@ Examples:
                 )

             except ImportError as e:
-                print(f"⚠️ AST chunking not available ({e}), falling back to traditional chunking")
+                print(
+                    f"⚠️ AST chunking utilities not available in package ({e}), falling back to traditional chunking"
+                )
                 use_ast = False

             if not use_ast:
@@ -2,6 +2,8 @@
 Transform your development workflow with intelligent code assistance using LEANN's semantic search directly in Claude Code.

+For agent-facing discovery details, see `llms.txt` in the repository root.
+
 ## Prerequisites

 Install LEANN globally for MCP integration (with default backend):
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

 [project]
 name = "leann"
-version = "0.3.2"
+version = "0.3.4"
 description = "LEANN - The smallest vector index in the world. RAG Everything with LEANN!"
 readme = "README.md"
 requires-python = ">=3.9"
@@ -99,11 +99,16 @@ wechat-exporter = "wechat_exporter.main:main"
 leann-core = { path = "packages/leann-core", editable = true }
 leann-backend-diskann = { path = "packages/leann-backend-diskann", editable = true }
 leann-backend-hnsw = { path = "packages/leann-backend-hnsw", editable = true }
+astchunk = { path = "packages/astchunk-leann", editable = true }

 [tool.ruff]
 target-version = "py39"
 line-length = 100
-extend-exclude = ["third_party"]
+extend-exclude = [
+    "third_party",
+    "apps/multimodal/vision-based-pdf-multi-vector/multi-vector-leann.py",
+    "apps/multimodal/vision-based-pdf-multi-vector/multi-vector-leann-similarity-map.py"
+]


 [tool.ruff.lint]
43
uv.lock
generated
@@ -201,7 +201,7 @@ wheels = [
 [[package]]
 name = "astchunk"
 version = "0.1.0"
-source = { registry = "https://pypi.org/simple" }
+source = { editable = "packages/astchunk-leann" }
 dependencies = [
     { name = "numpy", version = "2.0.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.10'" },
     { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version == '3.10.*'" },
@@ -214,10 +214,31 @@ dependencies = [
     { name = "tree-sitter-python" },
     { name = "tree-sitter-typescript" },
 ]
-sdist = { url = "https://files.pythonhosted.org/packages/db/2a/7a35e2fac7d550265ae2ee40651425083b37555f921d1a1b77c3f525e0df/astchunk-0.1.0.tar.gz", hash = "sha256:f4dff0ef8b3b3bcfeac363384db1e153f74d4c825dc2e35864abfab027713be4", size = 18093, upload-time = "2025-06-19T04:37:25.34Z" }
-wheels = [
-    { url = "https://files.pythonhosted.org/packages/be/84/5433ab0e933b572750cb16fd7edf3d6c7902b069461a22ec670042752a4d/astchunk-0.1.0-py3-none-any.whl", hash = "sha256:33ada9fc3620807fdda5846fa1948af463f281a60e0d43d4f3782b6dbb416d24", size = 15396, upload-time = "2025-06-19T04:37:23.87Z" },
-]
+
+[package.metadata]
+requires-dist = [
+    { name = "black", marker = "extra == 'dev'", specifier = ">=22.0.0" },
+    { name = "flake8", marker = "extra == 'dev'", specifier = ">=5.0.0" },
+    { name = "isort", marker = "extra == 'dev'", specifier = ">=5.10.0" },
+    { name = "mypy", marker = "extra == 'dev'", specifier = ">=1.0.0" },
+    { name = "myst-parser", marker = "extra == 'docs'", specifier = ">=0.18.0" },
+    { name = "numpy", specifier = ">=1.20.0" },
+    { name = "pre-commit", marker = "extra == 'dev'", specifier = ">=2.20.0" },
+    { name = "pyrsistent", specifier = ">=0.18.0" },
+    { name = "pytest", marker = "extra == 'dev'", specifier = ">=7.0.0" },
+    { name = "pytest", marker = "extra == 'test'", specifier = ">=7.0.0" },
+    { name = "pytest-cov", marker = "extra == 'dev'", specifier = ">=4.0.0" },
+    { name = "pytest-cov", marker = "extra == 'test'", specifier = ">=4.0.0" },
+    { name = "pytest-xdist", marker = "extra == 'test'", specifier = ">=2.5.0" },
+    { name = "sphinx", marker = "extra == 'docs'", specifier = ">=5.0.0" },
+    { name = "sphinx-rtd-theme", marker = "extra == 'docs'", specifier = ">=1.0.0" },
+    { name = "tree-sitter", specifier = ">=0.20.0" },
+    { name = "tree-sitter-c-sharp", specifier = ">=0.20.0" },
+    { name = "tree-sitter-java", specifier = ">=0.20.0" },
+    { name = "tree-sitter-python", specifier = ">=0.20.0" },
+    { name = "tree-sitter-typescript", specifier = ">=0.20.0" },
+]
+provides-extras = ["dev", "docs", "test"]

 [[package]]
 name = "asttokens"
@@ -1564,7 +1585,7 @@ name = "importlib-metadata"
 version = "8.7.0"
 source = { registry = "https://pypi.org/simple" }
 dependencies = [
-    { name = "zipp" },
+    { name = "zipp", marker = "python_full_version < '3.10'" },
 ]
 sdist = { url = "https://files.pythonhosted.org/packages/76/66/650a33bd90f786193e4de4b3ad86ea60b53c89b669a5c7be931fac31cdb0/importlib_metadata-8.7.0.tar.gz", hash = "sha256:d13b81ad223b890aa16c5471f2ac3056cf76c5f10f82d6f9292f0b415f389000", size = 56641, upload-time = "2025-04-27T15:29:01.736Z" }
 wheels = [
@@ -2117,7 +2138,7 @@ wheels = [

 [[package]]
 name = "leann-backend-diskann"
-version = "0.3.2"
+version = "0.3.4"
 source = { editable = "packages/leann-backend-diskann" }
 dependencies = [
     { name = "leann-core" },
@@ -2129,14 +2150,14 @@ dependencies = [

 [package.metadata]
 requires-dist = [
-    { name = "leann-core", specifier = "==0.3.2" },
+    { name = "leann-core", specifier = "==0.3.4" },
     { name = "numpy" },
     { name = "protobuf", specifier = ">=3.19.0" },
 ]

 [[package]]
 name = "leann-backend-hnsw"
-version = "0.3.2"
+version = "0.3.4"
 source = { editable = "packages/leann-backend-hnsw" }
 dependencies = [
     { name = "leann-core" },
@@ -2149,7 +2170,7 @@ dependencies = [

 [package.metadata]
 requires-dist = [
-    { name = "leann-core", specifier = "==0.3.2" },
+    { name = "leann-core", specifier = "==0.3.4" },
     { name = "msgpack", specifier = ">=1.0.0" },
     { name = "numpy" },
     { name = "pyzmq", specifier = ">=23.0.0" },
@@ -2157,7 +2178,7 @@ requires-dist = [

 [[package]]
 name = "leann-core"
-version = "0.3.2"
+version = "0.3.4"
 source = { editable = "packages/leann-core" }
 dependencies = [
     { name = "accelerate" },
@@ -2297,7 +2318,7 @@ test = [

 [package.metadata]
 requires-dist = [
-    { name = "astchunk", specifier = ">=0.1.0" },
+    { name = "astchunk", editable = "packages/astchunk-leann" },
     { name = "beautifulsoup4", marker = "extra == 'documents'", specifier = ">=4.13.0" },
     { name = "black", marker = "extra == 'dev'", specifier = ">=23.0" },
     { name = "boto3" },