add gpu chunk embedd and add complexity in hnsw
This commit is contained in:
@@ -29,7 +29,7 @@ node_parser = DoclingNodeParser(
|
||||
)
|
||||
print("Loading documents...")
|
||||
documents = SimpleDirectoryReader(
|
||||
"examples/pangu",
|
||||
"examples/data",
|
||||
recursive=True,
|
||||
file_extractor=file_extractor,
|
||||
encoding="utf-8",
|
||||
@@ -42,7 +42,7 @@ for doc in documents:
|
||||
for node in nodes:
|
||||
all_texts.append(node.get_content())
|
||||
|
||||
INDEX_DIR = Path("./test_pdf_index_pangu")
|
||||
INDEX_DIR = Path("./test_pdf_index_pangu_hnsw")
|
||||
INDEX_PATH = str(INDEX_DIR / "pdf_documents.leann")
|
||||
|
||||
if not INDEX_DIR.exists():
|
||||
@@ -52,7 +52,7 @@ if not INDEX_DIR.exists():
|
||||
|
||||
# CSR compact mode with recompute
|
||||
builder = LeannBuilder(
|
||||
backend_name="hnsw",
|
||||
backend_name="diskann",
|
||||
embedding_model="facebook/contriever",
|
||||
graph_degree=32,
|
||||
complexity=64,
|
||||
|
||||
@@ -330,7 +330,7 @@ class HNSWSearcher(LeannBackendSearcherInterface):
|
||||
"""Search using HNSW index with optional recompute functionality"""
|
||||
from . import faiss
|
||||
|
||||
ef = kwargs.get("ef", 200)
|
||||
ef = kwargs.get("complexity", 200)
|
||||
|
||||
if self.is_pruned:
|
||||
print(f"INFO: Index is pruned - ensuring embedding server is running for recompute.")
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import torch
|
||||
from .registry import BACKEND_REGISTRY
|
||||
from .interface import LeannBackendFactoryInterface
|
||||
from typing import List, Dict, Any, Optional
|
||||
@@ -34,6 +35,12 @@ def _compute_embeddings(chunks: List[str], model_name: str) -> np.ndarray:
|
||||
model = SentenceTransformer(model_name)
|
||||
model = model.half()
|
||||
print(f"INFO: Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}'...")
|
||||
# use acclerater GPU or MAC GPU
|
||||
import torch
|
||||
if torch.cuda.is_available():
|
||||
model = model.to("cuda")
|
||||
elif torch.backends.mps.is_available():
|
||||
model = model.to("mps")
|
||||
embeddings = model.encode(chunks, show_progress_bar=True)
|
||||
|
||||
return np.asarray(embeddings, dtype=np.float32)
|
||||
|
||||
Reference in New Issue
Block a user