Merge remote-tracking branch 'origin/main' into financebench

This commit is contained in:
Andy Lee
2025-09-23 21:52:14 -07:00
36 changed files with 6472 additions and 4077 deletions

View File

@@ -10,7 +10,7 @@ import sys
import threading
import time
from pathlib import Path
from typing import Optional
from typing import Any, Optional
import msgpack
import numpy as np
@@ -24,13 +24,35 @@ logger = logging.getLogger(__name__)
log_level = getattr(logging, LOG_LEVEL, logging.WARNING)
logger.setLevel(log_level)
# Ensure we have a handler if none exists
# Ensure we have handlers if none exist
if not logger.handlers:
handler = logging.StreamHandler()
stream_handler = logging.StreamHandler()
formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
handler.setFormatter(formatter)
logger.addHandler(handler)
logger.propagate = False
stream_handler.setFormatter(formatter)
logger.addHandler(stream_handler)
log_path = os.getenv("LEANN_HNSW_LOG_PATH")
if log_path:
try:
file_handler = logging.FileHandler(log_path, mode="a", encoding="utf-8")
file_formatter = logging.Formatter(
"%(asctime)s - %(levelname)s - [pid=%(process)d] %(message)s"
)
file_handler.setFormatter(file_formatter)
logger.addHandler(file_handler)
except Exception as exc: # pragma: no cover - best effort logging
logger.warning(f"Failed to attach file handler for log path {log_path}: {exc}")
logger.propagate = False
_RAW_PROVIDER_OPTIONS = os.getenv("LEANN_EMBEDDING_OPTIONS")
try:
PROVIDER_OPTIONS: dict[str, Any] = (
json.loads(_RAW_PROVIDER_OPTIONS) if _RAW_PROVIDER_OPTIONS else {}
)
except json.JSONDecodeError:
logger.warning("Failed to parse LEANN_EMBEDDING_OPTIONS; ignoring provider options")
PROVIDER_OPTIONS = {}
def create_hnsw_embedding_server(
@@ -167,7 +189,12 @@ def create_hnsw_embedding_server(
):
last_request_type = "text"
last_request_length = len(request)
embeddings = compute_embeddings(request, model_name, mode=embedding_mode)
embeddings = compute_embeddings(
request,
model_name,
mode=embedding_mode,
provider_options=PROVIDER_OPTIONS,
)
rep_socket.send(msgpack.packb(embeddings.tolist()))
e2e_end = time.time()
logger.info(f"⏱️ Text embedding E2E time: {e2e_end - e2e_start:.6f}s")
@@ -217,7 +244,10 @@ def create_hnsw_embedding_server(
if texts:
try:
embeddings = compute_embeddings(
texts, model_name, mode=embedding_mode
texts,
model_name,
mode=embedding_mode,
provider_options=PROVIDER_OPTIONS,
)
logger.info(
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
@@ -283,7 +313,12 @@ def create_hnsw_embedding_server(
if texts:
try:
embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
embeddings = compute_embeddings(
texts,
model_name,
mode=embedding_mode,
provider_options=PROVIDER_OPTIONS,
)
logger.info(
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
)