feat: openai embeddings

Andy Lee
2025-07-17 17:02:47 -07:00
parent 90d9f27383
commit a13c527e39
6 changed files with 311 additions and 49 deletions

View File

@@ -0,0 +1,108 @@
#!/usr/bin/env python3
"""
OpenAI Embedding Example
Complete example showing how to build and search with OpenAI embeddings using HNSW backend.
"""
import os
import dotenv
from pathlib import Path
from leann.api import LeannBuilder, LeannSearcher

# Load environment variables
dotenv.load_dotenv()


def main():
    # Check if OpenAI API key is available
    api_key = os.getenv("OPENAI_API_KEY")
    if not api_key:
        print("ERROR: OPENAI_API_KEY environment variable not set")
        return False
    print(f"✅ OpenAI API key found: {api_key[:10]}...")

    # Sample texts
    sample_texts = [
        "Machine learning is a powerful technology that enables computers to learn from data.",
        "Natural language processing helps computers understand and generate human language.",
        "Deep learning uses neural networks with multiple layers to solve complex problems.",
        "Computer vision allows machines to interpret and understand visual information.",
        "Reinforcement learning trains agents to make decisions through trial and error.",
        "Data science combines statistics, math, and programming to extract insights from data.",
        "Artificial intelligence aims to create machines that can perform human-like tasks.",
        "Python is a popular programming language used extensively in data science and AI.",
        "Neural networks are inspired by the structure and function of the human brain.",
        "Big data refers to extremely large datasets that require special tools to process."
    ]

    INDEX_DIR = Path("./simple_openai_test_index")
    INDEX_PATH = str(INDEX_DIR / "simple_test.leann")

    print(f"\n=== Building Index with OpenAI Embeddings ===")
    print(f"Index path: {INDEX_PATH}")

    try:
        # Use proper configuration for OpenAI embeddings
        builder = LeannBuilder(
            backend_name="hnsw",
            embedding_model="text-embedding-3-small",
            embedding_mode="openai",
            # HNSW settings for OpenAI embeddings
            M=16,  # Smaller graph degree
            efConstruction=64,  # Smaller construction complexity
            is_compact=True,  # Enable compact storage for recompute
            is_recompute=True,  # MUST enable for OpenAI embeddings
            num_threads=1,
        )

        print(f"Adding {len(sample_texts)} texts to the index...")
        for i, text in enumerate(sample_texts):
            metadata = {"id": f"doc_{i}", "topic": "AI"}
            builder.add_text(text, metadata)

        print("Building index...")
        builder.build_index(INDEX_PATH)
        print(f"✅ Index built successfully!")
    except Exception as e:
        print(f"❌ Error building index: {e}")
        import traceback
        traceback.print_exc()
        return False

    print(f"\n=== Testing Search ===")
    try:
        searcher = LeannSearcher(INDEX_PATH)
        test_queries = [
            "What is machine learning?",
            "How do neural networks work?",
            "Programming languages for data science"
        ]
        for query in test_queries:
            print(f"\n🔍 Query: '{query}'")
            results = searcher.search(query, top_k=3)
            print(f"  Found {len(results)} results:")
            for i, result in enumerate(results):
                print(f"    {i+1}. Score: {result.score:.4f}")
                print(f"       Text: {result.text[:80]}...")

        print(f"\n✅ Search test completed successfully!")
        return True
    except Exception as e:
        print(f"❌ Error during search: {e}")
        import traceback
        traceback.print_exc()
        return False


if __name__ == "__main__":
    success = main()
    if success:
        print(f"\n🎉 Simple OpenAI index test completed successfully!")
    else:
        print(f"\n💥 Simple OpenAI index test failed!")

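A minimal follow-up sketch of reusing the index built above (the query text is illustrative): because build_index() now records embedding_mode in the index metadata (see the api.py changes below), a separate process can reopen the same index without repeating the OpenAI settings, as long as OPENAI_API_KEY is still set for query-time embedding.

from leann.api import LeannSearcher

# Reopen the index built by the script above; the embedding mode is read from meta.json.
searcher = LeannSearcher("./simple_openai_test_index/simple_test.leann")
for result in searcher.search("trial and error learning", top_k=2):
    print(f"{result.score:.4f}  {result.text[:60]}")
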
View File

@@ -162,7 +162,7 @@ def create_embedding_server_thread(
model_name="sentence-transformers/all-mpnet-base-v2",
max_batch_size=128,
passages_file: Optional[str] = None,
use_mlx: bool = False,
embedding_mode: str = "sentence-transformers",
enable_warmup: bool = False,
):
"""
@@ -182,10 +182,27 @@ def create_embedding_server_thread(
print(f"{RED}Port {zmq_port} is already in use{RESET}")
return
if use_mlx:
# Auto-detect mode based on model name if not explicitly set
if embedding_mode == "sentence-transformers" and model_name.startswith("text-embedding-"):
embedding_mode = "openai"
if embedding_mode == "mlx":
from leann.api import compute_embeddings_mlx
import torch
print("INFO: Using MLX for embeddings")
else:
# Set device to CPU for compatibility with DeviceTimer class
device = torch.device("cpu")
cuda_available = False
mps_available = False
elif embedding_mode == "openai":
from leann.api import compute_embeddings_openai
import torch
print("INFO: Using OpenAI API for embeddings")
# Set device to CPU for compatibility with DeviceTimer class
device = torch.device("cpu")
cuda_available = False
mps_available = False
elif embedding_mode == "sentence-transformers":
# Initialize the model
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
import torch
@@ -216,6 +233,8 @@ def create_embedding_server_thread(
print(f"INFO: Using FP16 precision with model: {model_name}")
except Exception as e:
print(f"WARNING: Model optimization failed: {e}")
else:
raise ValueError(f"Unsupported embedding mode: {embedding_mode}. Supported modes: sentence-transformers, mlx, openai")
# Load passages from file if provided
if passages_file and os.path.exists(passages_file):
@@ -303,7 +322,7 @@ def create_embedding_server_thread(
self.start_time = 0
self.end_time = 0
if not use_mlx and torch.cuda.is_available():
if embedding_mode == "sentence-transformers" and torch.cuda.is_available():
self.start_event = torch.cuda.Event(enable_timing=True)
self.end_event = torch.cuda.Event(enable_timing=True)
else:
@@ -317,25 +336,25 @@ def create_embedding_server_thread(
self.end()
def start(self):
if not use_mlx and torch.cuda.is_available():
if embedding_mode == "sentence-transformers" and torch.cuda.is_available():
torch.cuda.synchronize()
self.start_event.record()
else:
if not use_mlx and self.device.type == "mps":
if embedding_mode == "sentence-transformers" and self.device.type == "mps":
torch.mps.synchronize()
self.start_time = time.time()
def end(self):
if not use_mlx and torch.cuda.is_available():
if embedding_mode == "sentence-transformers" and torch.cuda.is_available():
self.end_event.record()
torch.cuda.synchronize()
else:
if not use_mlx and self.device.type == "mps":
if embedding_mode == "sentence-transformers" and self.device.type == "mps":
torch.mps.synchronize()
self.end_time = time.time()
def elapsed_time(self):
if not use_mlx and torch.cuda.is_available():
if embedding_mode == "sentence-transformers" and torch.cuda.is_available():
return self.start_event.elapsed_time(self.end_event) / 1000.0
else:
return self.end_time - self.start_time
@@ -571,13 +590,15 @@ def create_embedding_server_thread(
chunk_texts = texts[i:end_idx]
chunk_ids = node_ids[i:end_idx]
if use_mlx:
if embedding_mode == "mlx":
embeddings_chunk = compute_embeddings_mlx(chunk_texts, model_name)
else:
elif embedding_mode == "openai":
embeddings_chunk = compute_embeddings_openai(chunk_texts, model_name)
else: # sentence-transformers
embeddings_chunk = process_batch_pytorch(chunk_texts, chunk_ids, missing_ids)
all_embeddings.append(embeddings_chunk)
if not use_mlx:
if embedding_mode == "sentence-transformers":
if cuda_available:
torch.cuda.empty_cache()
elif device.type == "mps":
@@ -586,9 +607,11 @@ def create_embedding_server_thread(
hidden = np.vstack(all_embeddings)
print(f"INFO: Combined embeddings shape: {hidden.shape}")
else:
if use_mlx:
if embedding_mode == "mlx":
hidden = compute_embeddings_mlx(texts, model_name)
else:
elif embedding_mode == "openai":
hidden = compute_embeddings_openai(texts, model_name)
else: # sentence-transformers
hidden = process_batch_pytorch(texts, node_ids, missing_ids)
# Serialize the response
@@ -610,7 +633,7 @@ def create_embedding_server_thread(
print(f"INFO: Serialize time: {ser_end - ser_start:.6f} seconds")
if not use_mlx:
if embedding_mode == "sentence-transformers":
if device.type == "cuda":
torch.cuda.synchronize()
elif device.type == "mps":
@@ -653,14 +676,14 @@ def create_embedding_server(
lazy_load_passages=False,
model_name="sentence-transformers/all-mpnet-base-v2",
passages_file: Optional[str] = None,
use_mlx: bool = False,
embedding_mode: str = "sentence-transformers",
enable_warmup: bool = False,
):
"""
The original create_embedding_server function remains unchanged.
This is the blocking version, intended to be run directly.
"""
create_embedding_server_thread(zmq_port, model_name, max_batch_size, passages_file, use_mlx, enable_warmup)
create_embedding_server_thread(zmq_port, model_name, max_batch_size, passages_file, embedding_mode, enable_warmup)
if __name__ == "__main__":
@@ -677,9 +700,17 @@ if __name__ == "__main__":
parser.add_argument("--lazy-load-passages", action="store_true", default=True)
parser.add_argument("--model-name", type=str, default="sentence-transformers/all-mpnet-base-v2",
help="Embedding model name")
parser.add_argument("--use-mlx", action="store_true", default=False, help="Use MLX backend for embeddings")
parser.add_argument("--embedding-mode", type=str, default="sentence-transformers",
choices=["sentence-transformers", "mlx", "openai"],
help="Embedding backend mode")
parser.add_argument("--use-mlx", action="store_true", default=False, help="Use MLX backend for embeddings (deprecated: use --embedding-mode mlx)")
parser.add_argument("--disable-warmup", action="store_true", default=False, help="Disable warmup requests on server start")
args = parser.parse_args()
# Handle backward compatibility with use_mlx
embedding_mode = args.embedding_mode
if args.use_mlx:
embedding_mode = "mlx"
create_embedding_server(
domain=args.domain,
@@ -693,6 +724,6 @@ if __name__ == "__main__":
lazy_load_passages=args.lazy_load_passages,
model_name=args.model_name,
passages_file=args.passages_file,
use_mlx=args.use_mlx,
embedding_mode=embedding_mode,
enable_warmup=not args.disable_warmup,
)
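
A hedged usage sketch of the updated entry point (the import path is not shown in this diff, the port value is a placeholder, and parameter names mirror the blocking wrapper's call above): the server can be started with the new embedding_mode parameter, or the mode can be left to auto-detection, since "text-embedding-*" model names are promoted to "openai".

# from <embedding_server module> import create_embedding_server_thread  # import path assumed
create_embedding_server_thread(
    zmq_port=5557,                        # placeholder port
    model_name="text-embedding-3-small",
    max_batch_size=64,
    passages_file=None,
    embedding_mode="openai",              # or omit: the "text-embedding-" prefix auto-selects OpenAI mode
    enable_warmup=False,
)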

View File

@@ -150,7 +150,7 @@ def create_hnsw_embedding_server(
model_name: str = "sentence-transformers/all-mpnet-base-v2",
custom_max_length_param: Optional[int] = None,
distance_metric: str = "mips",
use_mlx: bool = False,
embedding_mode: str = "sentence-transformers",
enable_warmup: bool = False,
):
"""
@@ -170,13 +170,22 @@ def create_hnsw_embedding_server(
distance_metric: The distance metric to use
enable_warmup: Whether to perform warmup requests on server start
"""
if not use_mlx:
# Handle different embedding modes directly in HNSW server
# Auto-detect mode based on model name if not explicitly set
if embedding_mode == "sentence-transformers" and model_name.startswith("text-embedding-"):
embedding_mode = "openai"
if embedding_mode == "openai":
print(f"Using OpenAI API mode for {model_name}")
tokenizer = None # No local tokenizer needed for OpenAI API
elif embedding_mode == "mlx":
print(f"Using MLX mode for {model_name}")
tokenizer = None # MLX handles tokenization separately
else: # sentence-transformers
print(f"Loading tokenizer for {model_name}...")
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
print(f"Tokenizer loaded successfully!")
else:
print("Using MLX mode - tokenizer will be loaded separately")
tokenizer = None
# Device setup
mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
@@ -199,15 +208,17 @@ def create_hnsw_embedding_server(
print(f"Starting HNSW server on port {zmq_port} with model {model_name}")
print(f"Loading model {model_name}... (this may take a while if downloading)")
if use_mlx:
if embedding_mode == "mlx":
# For MLX models, we need to use the MLX embedding computation
print("MLX model detected - using MLX backend for embeddings")
model = None # We'll handle MLX separately
tokenizer = None
elif embedding_mode == "openai":
# For OpenAI API, no local model needed
print("OpenAI API mode - no local model loading required")
model = None
else:
# Use standard transformers for non-MLX models
# Use standard transformers for sentence-transformers models
model = AutoModel.from_pretrained(model_name).to(device).eval()
tokenizer = AutoTokenizer.from_pretrained(model_name)
print(f"Model {model_name} loaded successfully!")
# Check port availability
@@ -355,9 +366,12 @@ def create_hnsw_embedding_server(
def process_batch(texts_batch, ids_batch, missing_ids):
"""Process a batch of texts and return embeddings"""
# Handle MLX models separately
if use_mlx:
# Handle different embedding modes
if embedding_mode == "mlx":
return _process_batch_mlx(texts_batch, ids_batch, missing_ids)
elif embedding_mode == "openai":
from leann.api import compute_embeddings_openai
return compute_embeddings_openai(texts_batch, model_name)
_is_e5_model = "e5" in model_name.lower()
_is_bge_model = "bge" in model_name.lower()
@@ -795,14 +809,33 @@ def create_hnsw_embedding_server(
)
continue
# Standard embedding request
# Handle direct text embedding request (for OpenAI mode)
if embedding_mode == "openai" and isinstance(request_payload, list) and len(request_payload) > 0:
# Check if this is a direct text request (list of strings)
if all(isinstance(item, str) for item in request_payload):
print(f"Processing direct text embedding request for {len(request_payload)} texts")
try:
from leann.api import compute_embeddings_openai
embeddings = compute_embeddings_openai(request_payload, model_name)
response = embeddings.tolist()
socket.send(msgpack.packb(response))
e2e_end = time.time()
print(f"Text embedding E2E time: {e2e_end - e2e_start:.6f} seconds")
continue
except Exception as e:
print(f"ERROR: Failed to compute OpenAI embeddings: {e}")
socket.send(msgpack.packb([]))
continue
# Standard embedding request (passage ID lookup)
if (
not isinstance(request_payload, list)
or len(request_payload) != 1
or not isinstance(request_payload[0], list)
):
print(
f"Error: Invalid MessagePack request format. Expected [[ids...]], got: {type(request_payload)}"
f"Error: Invalid MessagePack request format. Expected [[ids...]] or [texts...], got: {type(request_payload)}"
)
socket.send(msgpack.packb([[], []]))
continue
@@ -986,11 +1019,18 @@ if __name__ == "__main__":
parser.add_argument(
"--distance-metric", type=str, default="mips", help="Distance metric to use"
)
parser.add_argument(
"--embedding-mode",
type=str,
default="sentence-transformers",
choices=["sentence-transformers", "mlx", "openai"],
help="Embedding backend mode"
)
parser.add_argument(
"--use-mlx",
action="store_true",
default=False,
help="Use MLX for model inference",
help="Use MLX for model inference (deprecated: use --embedding-mode mlx)",
)
parser.add_argument(
"--disable-warmup",
@@ -1000,6 +1040,11 @@ if __name__ == "__main__":
)
args = parser.parse_args()
# Handle backward compatibility with use_mlx
embedding_mode = args.embedding_mode
if args.use_mlx:
embedding_mode = "mlx"
# Create and start the HNSW embedding server
create_hnsw_embedding_server(
@@ -1013,6 +1058,6 @@ if __name__ == "__main__":
model_name=args.model_name,
custom_max_length_param=args.custom_max_length,
distance_metric=args.distance_metric,
use_mlx=args.use_mlx,
embedding_mode=embedding_mode,
enable_warmup=not args.disable_warmup,
)
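
For the direct-text request path added above, a minimal client sketch might look like this (the port is a placeholder, and the server is assumed to have been started in OpenAI mode): a plain msgpack-encoded list of strings selects the text path, and the reply is the list of embedding vectors, or an empty list on failure.

import msgpack
import numpy as np
import zmq

context = zmq.Context()
socket = context.socket(zmq.REQ)
socket.connect("tcp://127.0.0.1:5557")  # placeholder port

texts = ["What is machine learning?", "How do neural networks work?"]
socket.send(msgpack.packb(texts))          # list of strings -> direct text embedding request
reply = msgpack.unpackb(socket.recv())     # list of vectors, or [] if the OpenAI call failed
embeddings = np.asarray(reply, dtype=np.float32)
print(embeddings.shape)                    # e.g. (2, 1536) for text-embedding-3-small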

View File

@@ -18,11 +18,40 @@ from .chat import get_llm
def compute_embeddings(
chunks: List[str], model_name: str, use_mlx: bool = False
chunks: List[str],
model_name: str,
mode: str = "sentence-transformers"
) -> np.ndarray:
"""Computes embeddings using sentence-transformers or MLX for consistent results."""
if use_mlx:
"""
Computes embeddings using different backends.
Args:
chunks: List of text chunks to embed
model_name: Name of the embedding model
mode: Embedding backend mode. Options:
- "sentence-transformers": Use sentence-transformers library (default)
- "mlx": Use MLX backend for Apple Silicon
- "openai": Use OpenAI embedding API
Returns:
numpy array of embeddings
"""
# Auto-detect mode based on model name if not explicitly set
if mode == "sentence-transformers" and model_name.startswith("text-embedding-"):
mode = "openai"
if mode == "mlx":
return compute_embeddings_mlx(chunks, model_name)
elif mode == "openai":
return compute_embeddings_openai(chunks, model_name)
elif mode == "sentence-transformers":
return compute_embeddings_sentence_transformers(chunks, model_name)
else:
raise ValueError(f"Unsupported embedding mode: {mode}. Supported modes: sentence-transformers, mlx, openai")
def compute_embeddings_sentence_transformers(chunks: List[str], model_name: str) -> np.ndarray:
    """Computes embeddings using sentence-transformers library."""
    try:
        from sentence_transformers import SentenceTransformer
    except ImportError as e:
@@ -53,6 +82,49 @@ def compute_embeddings
    return embeddings


def compute_embeddings_openai(chunks: List[str], model_name: str) -> np.ndarray:
    """Computes embeddings using OpenAI API."""
    try:
        import openai
        import os
    except ImportError as e:
        raise RuntimeError(
            "openai not available. Install with: uv pip install openai"
        ) from e

    # Get API key from environment
    api_key = os.getenv("OPENAI_API_KEY")
    if not api_key:
        raise RuntimeError("OPENAI_API_KEY environment variable not set")

    client = openai.OpenAI(api_key=api_key)

    print(f"INFO: Computing embeddings for {len(chunks)} chunks using OpenAI model '{model_name}'...")

    # OpenAI has a limit on batch size and input length
    max_batch_size = 100  # Conservative batch size
    all_embeddings = []

    for i in range(0, len(chunks), max_batch_size):
        batch_chunks = chunks[i:i + max_batch_size]
        print(f"INFO: Processing batch {i//max_batch_size + 1}/{(len(chunks) + max_batch_size - 1)//max_batch_size}")
        try:
            response = client.embeddings.create(
                model=model_name,
                input=batch_chunks
            )
            batch_embeddings = [embedding.embedding for embedding in response.data]
            all_embeddings.extend(batch_embeddings)
        except Exception as e:
            print(f"ERROR: Failed to get embeddings for batch starting at {i}: {e}")
            raise

    embeddings = np.array(all_embeddings, dtype=np.float32)
    print(f"INFO: Generated {len(embeddings)} embeddings with dimension {embeddings.shape[1]}")
    return embeddings


def compute_embeddings_mlx(chunks: List[str], model_name: str) -> np.ndarray:
    """Computes embeddings using an MLX model."""
    try:
@@ -140,7 +212,7 @@ class LeannBuilder:
backend_name: str,
embedding_model: str = "facebook/contriever-msmarco",
dimensions: Optional[int] = None,
use_mlx: bool = False,
embedding_mode: str = "sentence-transformers",
**backend_kwargs,
):
self.backend_name = backend_name
@@ -152,7 +224,7 @@ class LeannBuilder:
self.backend_factory = backend_factory
self.embedding_model = embedding_model
self.dimensions = dimensions
self.use_mlx = use_mlx
self.embedding_mode = embedding_mode
self.backend_kwargs = backend_kwargs
self.chunks: List[Dict[str, Any]] = []
@@ -168,7 +240,7 @@ class LeannBuilder:
raise ValueError("No chunks added.")
if self.dimensions is None:
self.dimensions = len(
compute_embeddings(["dummy"], self.embedding_model, self.use_mlx)[0]
compute_embeddings(["dummy"], self.embedding_model, self.embedding_mode)[0]
)
path = Path(index_path)
index_dir = path.parent
@@ -195,7 +267,7 @@ class LeannBuilder:
pickle.dump(offset_map, f)
texts_to_embed = [c["text"] for c in self.chunks]
embeddings = compute_embeddings(
texts_to_embed, self.embedding_model, self.use_mlx
texts_to_embed, self.embedding_model, self.embedding_mode
)
string_ids = [chunk["id"] for chunk in self.chunks]
current_backend_kwargs = {**self.backend_kwargs, "dimensions": self.dimensions}
@@ -210,7 +282,7 @@ class LeannBuilder:
"embedding_model": self.embedding_model,
"dimensions": self.dimensions,
"backend_kwargs": self.backend_kwargs,
"use_mlx": self.use_mlx,
"embedding_mode": self.embedding_mode,
"passage_sources": [
{
"type": "jsonl",
@@ -241,7 +313,11 @@ class LeannSearcher:
self.meta_data = json.load(f)
backend_name = self.meta_data["backend_name"]
self.embedding_model = self.meta_data["embedding_model"]
self.use_mlx = self.meta_data.get("use_mlx", False)
# Support both old and new format
self.embedding_mode = self.meta_data.get("embedding_mode", "sentence-transformers")
# Backward compatibility with use_mlx
if self.meta_data.get("use_mlx", False):
self.embedding_mode = "mlx"
self.passage_manager = PassageManager(self.meta_data.get("passage_sources", []))
backend_factory = BACKEND_REGISTRY.get(backend_name)
if backend_factory is None:

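A short usage sketch of the reworked dispatcher (shape and dimension are illustrative): the third argument selects the backend, and model names starting with "text-embedding-" are auto-promoted to the "openai" mode even when the caller keeps the default, so existing positional call sites such as LeannBuilder's keep working. OPENAI_API_KEY must be set in the environment.

from leann.api import compute_embeddings

# Explicit mode selection
vecs = compute_embeddings(["hello world"], "text-embedding-3-small", mode="openai")

# Auto-detection: the "text-embedding-" prefix switches the default mode to "openai"
vecs = compute_embeddings(["hello world"], "text-embedding-3-small")
print(vecs.shape, vecs.dtype)  # e.g. (1, 1536) float32
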
View File

@@ -177,7 +177,7 @@ class EmbeddingServerManager:
self.server_port: Optional[int] = None
# atexit.register(self.stop_server)
def start_server(self, port: int, model_name: str, **kwargs) -> bool:
def start_server(self, port: int, model_name: str, embedding_mode: str = "sentence-transformers", **kwargs) -> bool:
"""
Starts the embedding server process.
@@ -310,8 +310,8 @@ class EmbeddingServerManager:
command.extend(["--passages-file", str(kwargs["passages_file"])])
# if "distance_metric" in kwargs and kwargs["distance_metric"]:
# command.extend(["--distance-metric", kwargs["distance_metric"]])
if "use_mlx" in kwargs and kwargs["use_mlx"]:
command.extend(["--use-mlx"])
if embedding_mode != "sentence-transformers":
command.extend(["--embedding-mode", embedding_mode])
if "enable_warmup" in kwargs and not kwargs["enable_warmup"]:
command.extend(["--disable-warmup"])

View File

@@ -78,12 +78,14 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
"Cannot use recompute mode without 'embedding_model' in meta.json."
)
embedding_mode = self.meta.get("embedding_mode", "sentence-transformers")
server_started = self.embedding_server_manager.start_server(
port=port,
model_name=self.embedding_model,
passages_file=passages_source_file,
distance_metric=kwargs.get("distance_metric"),
use_mlx=kwargs.get("use_mlx", False),
embedding_mode=embedding_mode,
enable_warmup=kwargs.get("enable_warmup", False),
)
if not server_started:
@@ -120,8 +122,8 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
# Fallback to direct computation
from .api import compute_embeddings
use_mlx = self.meta.get("use_mlx", False)
return compute_embeddings([query], self.embedding_model, use_mlx)
embedding_mode = self.meta.get("embedding_mode", "sentence-transformers")
return compute_embeddings([query], self.embedding_model, embedding_mode)
def _compute_embedding_via_server(self, chunks: list, zmq_port: int) -> np.ndarray:
"""Compute embeddings using the ZMQ embedding server."""