diff --git a/packages/leann-core/src/leann/api.py b/packages/leann-core/src/leann/api.py
index 144e858..094f136 100644
--- a/packages/leann-core/src/leann/api.py
+++ b/packages/leann-core/src/leann/api.py
@@ -46,6 +46,7 @@ def compute_embeddings(
             - "sentence-transformers": Use sentence-transformers library (default)
             - "mlx": Use MLX backend for Apple Silicon
             - "openai": Use OpenAI embedding API
+            - "gemini": Use Google Gemini embedding API
         use_server: Whether to use embedding server (True for search, False for build)
 
     Returns:
diff --git a/packages/leann-core/src/leann/chat.py b/packages/leann-core/src/leann/chat.py
index 11bbcee..a428462 100644
--- a/packages/leann-core/src/leann/chat.py
+++ b/packages/leann-core/src/leann/chat.py
@@ -680,6 +680,52 @@ class HFChat(LLMInterface):
         return response.strip()
 
 
+class GeminiChat(LLMInterface):
+    """LLM interface for Google Gemini models."""
+
+    def __init__(self, model: str = "gemini-2.5-flash", api_key: Optional[str] = None):
+        self.model = model
+        self.api_key = api_key or os.getenv("GEMINI_API_KEY")
+
+        if not self.api_key:
+            raise ValueError(
+                "Gemini API key is required. Set GEMINI_API_KEY environment variable or pass api_key parameter."
+            )
+
+        logger.info(f"Initializing Gemini Chat with model='{model}'")
+
+        try:
+            import google.genai as genai
+
+            self.client = genai.Client(api_key=self.api_key)
+        except ImportError:
+            raise ImportError(
+                "The 'google-genai' library is required for Gemini models. Please install it with 'uv pip install google-genai'."
+            )
+
+    def ask(self, prompt: str, **kwargs) -> str:
+        logger.info(f"Sending request to Gemini with model {self.model}")
+
+        try:
+            # Set generation configuration
+            generation_config = {
+                "temperature": kwargs.get("temperature", 0.7),
+                "max_output_tokens": kwargs.get("max_tokens", 1000),
+            }
+
+            # Handle top_p parameter
+            if "top_p" in kwargs:
+                generation_config["top_p"] = kwargs["top_p"]
+
+            response = self.client.models.generate_content(
+                model=self.model, contents=prompt, config=generation_config
+            )
+            return response.text.strip()
+        except Exception as e:
+            logger.error(f"Error communicating with Gemini: {e}")
+            return f"Error: Could not get a response from Gemini. Details: {e}"
+
+
 class OpenAIChat(LLMInterface):
     """LLM interface for OpenAI models."""
 
@@ -793,6 +839,8 @@ def get_llm(llm_config: Optional[dict[str, Any]] = None) -> LLMInterface:
         return HFChat(model_name=model or "deepseek-ai/deepseek-llm-7b-chat")
     elif llm_type == "openai":
         return OpenAIChat(model=model or "gpt-4o", api_key=llm_config.get("api_key"))
+    elif llm_type == "gemini":
+        return GeminiChat(model=model or "gemini-2.5-flash", api_key=llm_config.get("api_key"))
     elif llm_type == "simulated":
         return SimulatedChat()
     else:
diff --git a/packages/leann-core/src/leann/embedding_compute.py b/packages/leann-core/src/leann/embedding_compute.py
index 9cce58c..add435c 100644
--- a/packages/leann-core/src/leann/embedding_compute.py
+++ b/packages/leann-core/src/leann/embedding_compute.py
@@ -57,6 +57,8 @@ def compute_embeddings(
         return compute_embeddings_mlx(texts, model_name)
     elif mode == "ollama":
         return compute_embeddings_ollama(texts, model_name, is_build=is_build)
+    elif mode == "gemini":
+        return compute_embeddings_gemini(texts, model_name, is_build=is_build)
     else:
         raise ValueError(f"Unsupported embedding mode: {mode}")
 
@@ -658,3 +660,83 @@ def compute_embeddings_ollama(
     logger.info(f"Generated {len(embeddings)} embeddings, dimension: {embeddings.shape[1]}")
 
     return embeddings
+
+
+def compute_embeddings_gemini(
+    texts: list[str], model_name: str = "text-embedding-004", is_build: bool = False
+) -> np.ndarray:
+    """
+    Compute embeddings using Google Gemini API.
+
+    Args:
+        texts: List of texts to compute embeddings for
+        model_name: Gemini model name (default: "text-embedding-004")
+        is_build: Whether this is a build operation (shows progress bar)
+
+    Returns:
+        Embeddings array, shape: (len(texts), embedding_dim)
+    """
+    try:
+        import os
+
+        import google.genai as genai
+    except ImportError as e:
+        raise ImportError(f"Google GenAI package not installed: {e}")
+
+    api_key = os.getenv("GEMINI_API_KEY")
+    if not api_key:
+        raise RuntimeError("GEMINI_API_KEY environment variable not set")
+
+    # Cache Gemini client
+    cache_key = "gemini_client"
+    if cache_key in _model_cache:
+        client = _model_cache[cache_key]
+    else:
+        client = genai.Client(api_key=api_key)
+        _model_cache[cache_key] = client
+        logger.info("Gemini client cached")
+
+    logger.info(
+        f"Computing embeddings for {len(texts)} texts using Gemini API, model: '{model_name}'"
+    )
+
+    # Gemini supports batch embedding
+    max_batch_size = 100  # Conservative batch size for Gemini
+    all_embeddings = []
+
+    try:
+        from tqdm import tqdm
+
+        total_batches = (len(texts) + max_batch_size - 1) // max_batch_size
+        batch_range = range(0, len(texts), max_batch_size)
+        batch_iterator = tqdm(
+            batch_range, desc="Computing embeddings", unit="batch", total=total_batches
+        )
+    except ImportError:
+        # Fallback when tqdm is not available
+        batch_iterator = range(0, len(texts), max_batch_size)
+
+    for i in batch_iterator:
+        batch_texts = texts[i : i + max_batch_size]
+
+        try:
+            # Use the embed_content method from the new Google GenAI SDK
+            response = client.models.embed_content(
+                model=model_name,
+                contents=batch_texts,
+                config=genai.types.EmbedContentConfig(
+                    task_type="RETRIEVAL_DOCUMENT"  # For document embedding
+                ),
+            )
+
+            # Extract embeddings from response
+            for embedding_data in response.embeddings:
+                all_embeddings.append(embedding_data.values)
+        except Exception as e:
+            logger.error(f"Batch {i} failed: {e}")
+            raise
+
+    embeddings = np.array(all_embeddings, dtype=np.float32)
+    logger.info(f"Generated {len(embeddings)} embeddings, dimension: {embeddings.shape[1]}")
+
+    return embeddings
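
For reference, a minimal usage sketch of the two new code paths. Only the identifiers visible in the hunks above (GeminiChat, get_llm, compute_embeddings, the "gemini" mode string, GEMINI_API_KEY) come from the diff; the import paths, the "type"/"model" config keys, and the exact compute_embeddings signature are assumptions for illustration, not confirmed API.

import os

# Assumed module paths, mirroring packages/leann-core/src/leann/*.py
from leann.chat import get_llm
from leann.embedding_compute import compute_embeddings

# Both paths read the key from the environment when no api_key is passed.
assert os.getenv("GEMINI_API_KEY"), "export GEMINI_API_KEY before running"

# Chat: the new "gemini" branch in get_llm() returns a GeminiChat instance.
# The "type"/"model" keys are assumed to be how llm_config is populated.
llm = get_llm({"type": "gemini", "model": "gemini-2.5-flash"})
print(llm.ask("Say hello in one short sentence.", temperature=0.2, max_tokens=64))

# Embeddings: mode="gemini" dispatches to compute_embeddings_gemini().
# The positional order of compute_embeddings arguments is assumed here.
vectors = compute_embeddings(
    ["LEANN builds a vector index.", "Gemini embeddings example."],
    "text-embedding-004",
    mode="gemini",
)
print(vectors.shape)  # expected: (2, embedding_dim), float32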