fix: build with direct embedding
This commit is contained in:
@@ -141,6 +141,14 @@ class HNSWSearcher(BaseSearcher):
|
|||||||
raise RuntimeError("Index is pruned but recompute is disabled.")
|
raise RuntimeError("Index is pruned but recompute is disabled.")
|
||||||
|
|
||||||
self._index = faiss.read_index(str(index_file), faiss.IO_FLAG_MMAP, hnsw_config)
|
self._index = faiss.read_index(str(index_file), faiss.IO_FLAG_MMAP, hnsw_config)
|
||||||
|
|
||||||
|
# Load label mapping
|
||||||
|
label_map_file = self.index_dir / "leann.labels.map"
|
||||||
|
if not label_map_file.exists():
|
||||||
|
raise FileNotFoundError(f"Label map file not found at {label_map_file}")
|
||||||
|
|
||||||
|
with open(label_map_file, "rb") as f:
|
||||||
|
self.label_map = pickle.load(f)
|
||||||
|
|
||||||
def search(
|
def search(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -20,7 +20,8 @@ from .chat import get_llm
|
|||||||
def compute_embeddings(
|
def compute_embeddings(
|
||||||
chunks: List[str],
|
chunks: List[str],
|
||||||
model_name: str,
|
model_name: str,
|
||||||
mode: str = "sentence-transformers"
|
mode: str = "sentence-transformers",
|
||||||
|
use_server: bool = True
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
"""
|
"""
|
||||||
Computes embeddings using different backends.
|
Computes embeddings using different backends.
|
||||||
@@ -32,6 +33,7 @@ def compute_embeddings(
|
|||||||
- "sentence-transformers": Use sentence-transformers library (default)
|
- "sentence-transformers": Use sentence-transformers library (default)
|
||||||
- "mlx": Use MLX backend for Apple Silicon
|
- "mlx": Use MLX backend for Apple Silicon
|
||||||
- "openai": Use OpenAI embedding API
|
- "openai": Use OpenAI embedding API
|
||||||
|
use_server: Whether to use embedding server (True for search, False for build)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
numpy array of embeddings
|
numpy array of embeddings
|
||||||
@@ -45,13 +47,23 @@ def compute_embeddings(
|
|||||||
elif mode == "openai":
|
elif mode == "openai":
|
||||||
return compute_embeddings_openai(chunks, model_name)
|
return compute_embeddings_openai(chunks, model_name)
|
||||||
elif mode == "sentence-transformers":
|
elif mode == "sentence-transformers":
|
||||||
return compute_embeddings_sentence_transformers(chunks, model_name)
|
return compute_embeddings_sentence_transformers(chunks, model_name, use_server=use_server)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported embedding mode: {mode}. Supported modes: sentence-transformers, mlx, openai")
|
raise ValueError(f"Unsupported embedding mode: {mode}. Supported modes: sentence-transformers, mlx, openai")
|
||||||
|
|
||||||
|
|
||||||
def compute_embeddings_sentence_transformers(chunks: List[str], model_name: str) -> np.ndarray:
|
def compute_embeddings_sentence_transformers(chunks: List[str], model_name: str, use_server: bool = True) -> np.ndarray:
|
||||||
"""Computes embeddings using sentence-transformers via embedding server."""
|
"""Computes embeddings using sentence-transformers.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
chunks: List of text chunks to embed
|
||||||
|
model_name: Name of the sentence transformer model
|
||||||
|
use_server: If True, use embedding server (good for search). If False, use direct computation (good for build).
|
||||||
|
"""
|
||||||
|
if not use_server:
|
||||||
|
print(f"INFO: Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}' (direct)...")
|
||||||
|
return _compute_embeddings_sentence_transformers_direct(chunks, model_name)
|
||||||
|
|
||||||
print(
|
print(
|
||||||
f"INFO: Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}' (via embedding server)..."
|
f"INFO: Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}' (via embedding server)..."
|
||||||
)
|
)
|
||||||
@@ -296,7 +308,7 @@ class LeannBuilder:
|
|||||||
raise ValueError("No chunks added.")
|
raise ValueError("No chunks added.")
|
||||||
if self.dimensions is None:
|
if self.dimensions is None:
|
||||||
self.dimensions = len(
|
self.dimensions = len(
|
||||||
compute_embeddings(["dummy"], self.embedding_model, self.embedding_mode)[0]
|
compute_embeddings(["dummy"], self.embedding_model, self.embedding_mode, use_server=False)[0]
|
||||||
)
|
)
|
||||||
path = Path(index_path)
|
path = Path(index_path)
|
||||||
index_dir = path.parent
|
index_dir = path.parent
|
||||||
@@ -323,7 +335,7 @@ class LeannBuilder:
|
|||||||
pickle.dump(offset_map, f)
|
pickle.dump(offset_map, f)
|
||||||
texts_to_embed = [c["text"] for c in self.chunks]
|
texts_to_embed = [c["text"] for c in self.chunks]
|
||||||
embeddings = compute_embeddings(
|
embeddings = compute_embeddings(
|
||||||
texts_to_embed, self.embedding_model, self.embedding_mode
|
texts_to_embed, self.embedding_model, self.embedding_mode, use_server=False
|
||||||
)
|
)
|
||||||
string_ids = [chunk["id"] for chunk in self.chunks]
|
string_ids = [chunk["id"] for chunk in self.chunks]
|
||||||
current_backend_kwargs = {**self.backend_kwargs, "dimensions": self.dimensions}
|
current_backend_kwargs = {**self.backend_kwargs, "dimensions": self.dimensions}
|
||||||
|
|||||||
Reference in New Issue
Block a user