diff --git a/examples/main_cli_example.py b/examples/main_cli_example.py
index 2d4662e..dd8ae93 100644
--- a/examples/main_cli_example.py
+++ b/examples/main_cli_example.py
@@ -29,7 +29,7 @@ node_parser = DoclingNodeParser(
 )
 print("Loading documents...")
 documents = SimpleDirectoryReader(
-    "examples/pangu",
+    "examples/data",
     recursive=True,
     file_extractor=file_extractor,
     encoding="utf-8",
@@ -42,7 +42,7 @@ for doc in documents:
     for node in nodes:
         all_texts.append(node.get_content())
 
-INDEX_DIR = Path("./test_pdf_index_pangu")
+INDEX_DIR = Path("./test_pdf_index_pangu_hnsw")
 INDEX_PATH = str(INDEX_DIR / "pdf_documents.leann")
 
 if not INDEX_DIR.exists():
@@ -52,7 +52,7 @@ if not INDEX_DIR.exists():
 
 # CSR compact mode with recompute
 builder = LeannBuilder(
-    backend_name="hnsw",
+    backend_name="diskann",
     embedding_model="facebook/contriever",
     graph_degree=32,
     complexity=64,
diff --git a/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_backend.py b/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_backend.py
index e838014..e41c6b6 100644
--- a/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_backend.py
+++ b/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_backend.py
@@ -330,7 +330,7 @@ class HNSWSearcher(LeannBackendSearcherInterface):
         """Search using HNSW index with optional recompute functionality"""
         from . import faiss
 
-        ef = kwargs.get("ef", 200)
+        ef = kwargs.get("complexity", 200)
 
         if self.is_pruned:
             print(f"INFO: Index is pruned - ensuring embedding server is running for recompute.")
diff --git a/packages/leann-core/src/leann/api.py b/packages/leann-core/src/leann/api.py
index 014264a..9c851d6 100644
--- a/packages/leann-core/src/leann/api.py
+++ b/packages/leann-core/src/leann/api.py
@@ -1,3 +1,4 @@
+import torch
 from .registry import BACKEND_REGISTRY
 from .interface import LeannBackendFactoryInterface
 from typing import List, Dict, Any, Optional
@@ -34,6 +35,11 @@ def _compute_embeddings(chunks: List[str], model_name: str) -> np.ndarray:
     model = SentenceTransformer(model_name)
     model = model.half()
     print(f"INFO: Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}'...")
+    # Use a CUDA GPU, or Apple's MPS backend, when available
+    if torch.cuda.is_available():
+        model = model.to("cuda")
+    elif torch.backends.mps.is_available():
+        model = model.to("mps")
     embeddings = model.encode(chunks, show_progress_bar=True)
     return np.asarray(embeddings, dtype=np.float32)
 
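
Note on the example-script hunk: switching backend_name from "hnsw" to "diskann" keeps the same builder signature, since graph_degree and complexity are generic graph-index knobs (for DiskANN they presumably map to the R and L build parameters). A minimal sketch of the call after this patch, restating only the constructor arguments visible in the diff:

    # Builder config taken verbatim from the patched example; the mapping of
    # graph_degree/complexity onto DiskANN's R/L is an assumption, not
    # confirmed by this diff.
    builder = LeannBuilder(
        backend_name="diskann",                # was "hnsw" before this patch
        embedding_model="facebook/contriever",
        graph_degree=32,   # max out-degree of the graph (DiskANN's R, presumably)
        complexity=64,     # build-time candidate-list size (DiskANN's L, presumably)
    )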
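
Note on the hnsw_backend.py hunk: the search path now reads the beam width from kwargs["complexity"] instead of kwargs["ef"], unifying the parameter name across backends; a caller still passing ef= will now silently fall back to the default of 200. A sketch of what the value controls, assuming a plain upstream faiss HNSW index rather than LEANN's patched build:

    # Hypothetical standalone example; LEANN wraps its own faiss fork, so
    # names other than efSearch below are illustrative.
    import faiss
    import numpy as np

    dim = 768
    index = faiss.IndexHNSWFlat(dim, 32)  # 32 = graph_degree (HNSW's M)
    index.add(np.random.rand(1000, dim).astype("float32"))

    def search(query: np.ndarray, top_k: int, **kwargs):
        # "complexity" plays the role of HNSW's efSearch: the size of the
        # candidate queue explored per query. Larger means better recall
        # at the cost of latency.
        index.hnsw.efSearch = kwargs.get("complexity", 200)
        return index.search(query, top_k)

    distances, ids = search(np.random.rand(1, dim).astype("float32"), top_k=5)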
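
Note on the api.py hunk: embedding computation now moves the model to a CUDA GPU when present, falls back to Apple's MPS backend on Apple Silicon, and otherwise stays on CPU. A self-contained sketch of the same pattern, assuming sentence-transformers and torch are installed (the model name and input are placeholders); combining model.half() with MPS is worth verifying separately, since fp16 operator coverage there has historically been uneven:

    import torch
    from sentence_transformers import SentenceTransformer

    def pick_device() -> str:
        # Prefer CUDA, then Apple-silicon MPS, then CPU.
        if torch.cuda.is_available():
            return "cuda"
        if torch.backends.mps.is_available():
            return "mps"
        return "cpu"

    # SentenceTransformer accepts a device string directly, which avoids
    # the separate .to(...) call used in the patch.
    model = SentenceTransformer("facebook/contriever", device=pick_device())
    embeddings = model.encode(["example chunk"], show_progress_bar=True)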