Compare commits
1 commit
fix/empty-...feature/ad

| Author | SHA1 | Date |
|---|---|---|
| | 61b1691448 | |
```diff
@@ -46,6 +46,7 @@ def compute_embeddings(
         - "sentence-transformers": Use sentence-transformers library (default)
         - "mlx": Use MLX backend for Apple Silicon
         - "openai": Use OpenAI embedding API
+        - "gemini": Use Google Gemini embedding API
         use_server: Whether to use embedding server (True for search, False for build)

     Returns:
```
```diff
@@ -306,23 +307,6 @@ class LeannBuilder:
     def build_index(self, index_path: str):
         if not self.chunks:
             raise ValueError("No chunks added.")
-
-        # Filter out invalid/empty text chunks early to keep passage and embedding counts aligned
-        valid_chunks: list[dict[str, Any]] = []
-        skipped = 0
-        for chunk in self.chunks:
-            text = chunk.get("text", "")
-            if isinstance(text, str) and text.strip():
-                valid_chunks.append(chunk)
-            else:
-                skipped += 1
-        if skipped > 0:
-            print(
-                f"Warning: Skipping {skipped} empty/invalid text chunk(s). Processing {len(valid_chunks)} valid chunks"
-            )
-        self.chunks = valid_chunks
-        if not self.chunks:
-            raise ValueError("All provided chunks are empty or invalid. Nothing to index.")
         if self.dimensions is None:
             self.dimensions = len(
                 compute_embeddings(
```
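The hunk above drops the early empty-chunk filter from `build_index`, which kept passage and embedding counts aligned. For reference, a minimal standalone sketch of that filter's behaviour, with hypothetical chunk data:

```python
# Standalone sketch of the filtering removed above; chunk data is hypothetical.
chunks = [
    {"text": "LEANN builds a compact vector index."},  # kept
    {"text": "   "},                                   # whitespace-only: skipped
    {"text": ""},                                      # empty: skipped
    {"id": 3},                                         # no "text" key: skipped
]

valid_chunks = [
    c for c in chunks
    if isinstance(c.get("text", ""), str) and c.get("text", "").strip()
]
skipped = len(chunks) - len(valid_chunks)
print(f"Warning: Skipping {skipped} empty/invalid text chunk(s). "
      f"Processing {len(valid_chunks)} valid chunks")
# Warning: Skipping 3 empty/invalid text chunk(s). Processing 1 valid chunks
```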
```diff
@@ -680,6 +680,52 @@ class HFChat(LLMInterface):
         return response.strip()


+class GeminiChat(LLMInterface):
+    """LLM interface for Google Gemini models."""
+
+    def __init__(self, model: str = "gemini-2.5-flash", api_key: Optional[str] = None):
+        self.model = model
+        self.api_key = api_key or os.getenv("GEMINI_API_KEY")
+
+        if not self.api_key:
+            raise ValueError(
+                "Gemini API key is required. Set GEMINI_API_KEY environment variable or pass api_key parameter."
+            )
+
+        logger.info(f"Initializing Gemini Chat with model='{model}'")
+
+        try:
+            import google.genai as genai
+
+            self.client = genai.Client(api_key=self.api_key)
+        except ImportError:
+            raise ImportError(
+                "The 'google-genai' library is required for Gemini models. Please install it with 'uv pip install google-genai'."
+            )
+
+    def ask(self, prompt: str, **kwargs) -> str:
+        logger.info(f"Sending request to Gemini with model {self.model}")
+
+        try:
+            # Set generation configuration
+            generation_config = {
+                "temperature": kwargs.get("temperature", 0.7),
+                "max_output_tokens": kwargs.get("max_tokens", 1000),
+            }
+
+            # Handle top_p parameter
+            if "top_p" in kwargs:
+                generation_config["top_p"] = kwargs["top_p"]
+
+            response = self.client.models.generate_content(
+                model=self.model, contents=prompt, config=generation_config
+            )
+            return response.text.strip()
+        except Exception as e:
+            logger.error(f"Error communicating with Gemini: {e}")
+            return f"Error: Could not get a response from Gemini. Details: {e}"
+
+
 class OpenAIChat(LLMInterface):
     """LLM interface for OpenAI models."""

```
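A hedged usage sketch for the new class, assuming the diff lands as shown, `google-genai` is installed, and `GEMINI_API_KEY` is set; the prompt and parameter values are illustrative:

```python
# Hypothetical direct use of the GeminiChat class added above.
# Requires: uv pip install google-genai, and GEMINI_API_KEY in the environment.
chat = GeminiChat(model="gemini-2.5-flash")

answer = chat.ask(
    "Summarize what a vector index does.",
    temperature=0.2,   # forwarded as generation_config["temperature"]
    max_tokens=256,    # mapped to generation_config["max_output_tokens"]
    top_p=0.9,         # optional; only set when present in kwargs
)
print(answer)
```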
```diff
@@ -793,6 +839,8 @@ def get_llm(llm_config: Optional[dict[str, Any]] = None) -> LLMInterface:
         return HFChat(model_name=model or "deepseek-ai/deepseek-llm-7b-chat")
     elif llm_type == "openai":
         return OpenAIChat(model=model or "gpt-4o", api_key=llm_config.get("api_key"))
+    elif llm_type == "gemini":
+        return GeminiChat(model=model or "gemini-2.5-flash", api_key=llm_config.get("api_key"))
     elif llm_type == "simulated":
         return SimulatedChat()
     else:
```
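If `get_llm` derives `llm_type` and `model` from config keys named `type` and `model` (an assumption; only the `api_key` lookup is visible in this hunk), selecting the new backend could look like:

```python
# Hypothetical config-driven selection; the "type"/"model" key names are assumed.
llm_config = {
    "type": "gemini",             # routes to the new elif branch above
    "model": "gemini-2.5-flash",  # also the fallback default when omitted
    "api_key": None,              # None -> GeminiChat falls back to GEMINI_API_KEY
}
llm = get_llm(llm_config)
print(llm.ask("Hello from LEANN!"))
```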
```diff
@@ -57,6 +57,8 @@ def compute_embeddings(
         return compute_embeddings_mlx(texts, model_name)
     elif mode == "ollama":
         return compute_embeddings_ollama(texts, model_name, is_build=is_build)
+    elif mode == "gemini":
+        return compute_embeddings_gemini(texts, model_name, is_build=is_build)
     else:
         raise ValueError(f"Unsupported embedding mode: {mode}")

```
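With the dispatcher extended as above, callers reach the Gemini path through the new mode string. A sketch, assuming the parameter order implied by the branches visible in the hunk:

```python
# Hypothetical call through the extended dispatcher; parameter order assumed
# from the branches shown above.
texts = ["first passage", "second passage"]
embeddings = compute_embeddings(
    texts,
    "text-embedding-004",  # model_name for the gemini path
    mode="gemini",
    is_build=True,         # forwarded to compute_embeddings_gemini
)
print(embeddings.shape)  # expected: (2, embedding_dim)
```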
```diff
@@ -244,16 +246,6 @@ def compute_embeddings_openai(texts: list[str], model_name: str) -> np.ndarray:
     except ImportError as e:
         raise ImportError(f"OpenAI package not installed: {e}")

-    # Validate input list
-    if not texts:
-        raise ValueError("Cannot compute embeddings for empty text list")
-    # Extra validation: abort early if any item is empty/whitespace
-    invalid_count = sum(1 for t in texts if not isinstance(t, str) or not t.strip())
-    if invalid_count > 0:
-        raise ValueError(
-            f"Found {invalid_count} empty/invalid text(s) in input. Upstream should filter before calling OpenAI."
-        )
-
     api_key = os.getenv("OPENAI_API_KEY")
     if not api_key:
         raise RuntimeError("OPENAI_API_KEY environment variable not set")
```
```diff
@@ -668,3 +660,83 @@ def compute_embeddings_ollama(
     logger.info(f"Generated {len(embeddings)} embeddings, dimension: {embeddings.shape[1]}")

     return embeddings
+
+
+def compute_embeddings_gemini(
+    texts: list[str], model_name: str = "text-embedding-004", is_build: bool = False
+) -> np.ndarray:
+    """
+    Compute embeddings using Google Gemini API.
+
+    Args:
+        texts: List of texts to compute embeddings for
+        model_name: Gemini model name (default: "text-embedding-004")
+        is_build: Whether this is a build operation (shows progress bar)
+
+    Returns:
+        Embeddings array, shape: (len(texts), embedding_dim)
+    """
+    try:
+        import os
+
+        import google.genai as genai
+    except ImportError as e:
+        raise ImportError(f"Google GenAI package not installed: {e}")
+
+    api_key = os.getenv("GEMINI_API_KEY")
+    if not api_key:
+        raise RuntimeError("GEMINI_API_KEY environment variable not set")
+
+    # Cache Gemini client
+    cache_key = "gemini_client"
+    if cache_key in _model_cache:
+        client = _model_cache[cache_key]
+    else:
+        client = genai.Client(api_key=api_key)
+        _model_cache[cache_key] = client
+        logger.info("Gemini client cached")
+
+    logger.info(
+        f"Computing embeddings for {len(texts)} texts using Gemini API, model: '{model_name}'"
+    )
+
+    # Gemini supports batch embedding
+    max_batch_size = 100  # Conservative batch size for Gemini
+    all_embeddings = []
+
+    try:
+        from tqdm import tqdm
+
+        total_batches = (len(texts) + max_batch_size - 1) // max_batch_size
+        batch_range = range(0, len(texts), max_batch_size)
+        batch_iterator = tqdm(
+            batch_range, desc="Computing embeddings", unit="batch", total=total_batches
+        )
+    except ImportError:
+        # Fallback when tqdm is not available
+        batch_iterator = range(0, len(texts), max_batch_size)
+
+    for i in batch_iterator:
+        batch_texts = texts[i : i + max_batch_size]
+
+        try:
+            # Use the embed_content method from the new Google GenAI SDK
+            response = client.models.embed_content(
+                model=model_name,
+                contents=batch_texts,
+                config=genai.types.EmbedContentConfig(
+                    task_type="RETRIEVAL_DOCUMENT"  # For document embedding
+                ),
+            )
+
+            # Extract embeddings from response
+            for embedding_data in response.embeddings:
+                all_embeddings.append(embedding_data.values)
+        except Exception as e:
+            logger.error(f"Batch {i} failed: {e}")
+            raise
+
+    embeddings = np.array(all_embeddings, dtype=np.float32)
+    logger.info(f"Generated {len(embeddings)} embeddings, dimension: {embeddings.shape[1]}")
+
+    return embeddings
```
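The new helper batches requests with ceiling division; a quick check of that arithmetic, followed by a hypothetical direct call (needs `google-genai` and `GEMINI_API_KEY`):

```python
# Ceiling division used above to count batches: (n + b - 1) // b
n, max_batch_size = 250, 100
total_batches = (n + max_batch_size - 1) // max_batch_size
assert total_batches == 3  # batches of 100, 100, and 50

# Hypothetical direct call to the new helper.
embeddings = compute_embeddings_gemini(
    ["doc one", "doc two"],
    model_name="text-embedding-004",
    is_build=True,  # per the docstring, marks a build operation
)
print(embeddings.shape)  # e.g. (2, 768) for text-embedding-004
```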