feat: support more embedders

Author: Andy Lee
Date:   2025-07-06 00:35:07 +00:00
parent 0aa84e147b
commit 910927a405

6 changed files with 142 additions and 85 deletions


@@ -85,6 +85,7 @@ def create_hnsw_embedding_server(
     max_batch_size: int = 128,
     model_name: str = "sentence-transformers/all-mpnet-base-v2",
     custom_max_length_param: Optional[int] = None,
+    distance_metric: str = "mips",
 ):
     """
     Create and start a ZMQ-based embedding server for HNSW backend.
@@ -100,6 +101,7 @@ def create_hnsw_embedding_server(
         max_batch_size: Maximum batch size for processing
         model_name: Transformer model name
         custom_max_length_param: Custom max sequence length
+        distance_metric: The distance metric to use
     """
     print(f"Loading tokenizer for {model_name}...")
     tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
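
The new distance_metric keyword defaults to "mips" and is documented only as "The distance metric to use"; the diff does not show any validation of the value. A minimal sketch of a guard a caller could add, where the accepted set {"mips", "cosine", "l2"} is inferred from the branches later in this commit and the helper name is hypothetical:

# Hypothetical helper, not part of the commit. The accepted values are
# inferred from the "l2" vs. "mips or cosine" branches further down.
SUPPORTED_METRICS = {"mips", "cosine", "l2"}

def check_distance_metric(distance_metric: str) -> str:
    if distance_metric not in SUPPORTED_METRICS:
        raise ValueError(
            f"Unsupported distance_metric {distance_metric!r}; "
            f"expected one of {sorted(SUPPORTED_METRICS)}"
        )
    return distance_metric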
@@ -222,6 +224,7 @@ def create_hnsw_embedding_server(
     def process_batch(texts_batch, ids_batch, missing_ids):
         """Process a batch of texts and return embeddings"""
         _is_e5_model = "e5" in model_name.lower()
+        _is_bge_model = "bge" in model_name.lower()
         batch_size = len(texts_batch)
         # E5 model preprocessing
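
For context, the model-family detection in process_batch is a plain substring check on the model name. A standalone illustration with a few common Hugging Face model names (the names are my examples, not taken from the commit):

# Illustrative only: the same "e5" / "bge" substring checks as in process_batch.
for name in [
    "intfloat/e5-base-v2",
    "BAAI/bge-base-en-v1.5",
    "sentence-transformers/all-mpnet-base-v2",
]:
    lowered = name.lower()
    print(f"{name}: e5={'e5' in lowered}, bge={'bge' in lowered}")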
@@ -262,7 +265,9 @@ def create_hnsw_embedding_server(
             out = model(enc["input_ids"], enc["attention_mask"])
         with pool_timer.timing():
-            if not hasattr(out, 'last_hidden_state'):
+            if _is_bge_model:
+                pooled_embeddings = out.last_hidden_state[:, 0]
+            elif not hasattr(out, 'last_hidden_state'):
                 if isinstance(out, torch.Tensor) and len(out.shape) == 2:
                     pooled_embeddings = out
                 else:
@@ -279,7 +284,7 @@ def create_hnsw_embedding_server(
                 pooled_embeddings = sum_embeddings / sum_mask
             final_embeddings = pooled_embeddings
-            if _is_e5_model:
+            if _is_e5_model or _is_bge_model:
                 with norm_timer.timing():
                     final_embeddings = F.normalize(pooled_embeddings, p=2, dim=1)
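
Taken together, the two hunks above make BGE-style models use the [CLS] token embedding (out.last_hidden_state[:, 0]) instead of the mean-pooling path, and extend the L2 normalization from E5 models to BGE models as well. A self-contained sketch of the two pooling strategies, assuming a Hugging Face-style output with last_hidden_state and an attention mask (function and variable names are mine, not the file's):

import torch
import torch.nn.functional as F

def pool(last_hidden_state: torch.Tensor,
         attention_mask: torch.Tensor,
         use_cls: bool) -> torch.Tensor:
    # last_hidden_state: (batch, seq_len, hidden); attention_mask: (batch, seq_len)
    if use_cls:
        # BGE-style: take the hidden state of the first ([CLS]) token.
        pooled = last_hidden_state[:, 0]
    else:
        # Mean pooling over non-padding tokens, as in the existing code path.
        mask = attention_mask.unsqueeze(-1).float()
        pooled = (last_hidden_state * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)
    # E5 and BGE embeddings are L2-normalized before use.
    return F.normalize(pooled, p=2, dim=1)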
@@ -408,14 +413,14 @@ def create_hnsw_embedding_server(
         calc_timer = DeviceTimer("distance calculation", device)
         with calc_timer.timing():
             with torch.no_grad():
-                if is_similarity_metric():
-                    node_embeddings_np = node_embeddings_tensor.cpu().numpy()
-                    query_np = query_tensor.cpu().numpy()
-                    distances = -np.dot(node_embeddings_np, query_np)
-                else:
+                if distance_metric == "l2":
                     node_embeddings_np = node_embeddings_tensor.cpu().numpy().astype(np.float32)
                     query_np = query_tensor.cpu().numpy().astype(np.float32)
                     distances = np.sum(np.square(node_embeddings_np - query_np.reshape(1, -1)), axis=1)
+                else: # mips or cosine
+                    node_embeddings_np = node_embeddings_tensor.cpu().numpy()
+                    query_np = query_tensor.cpu().numpy()
+                    distances = -np.dot(node_embeddings_np, query_np)
         calc_timer.print_elapsed()
         try:
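
The rewritten block above dispatches on the new distance_metric argument instead of the old is_similarity_metric() helper: squared Euclidean distance for "l2", and a negated inner product for "mips" or "cosine" (cosine presumably relying on the embeddings already being unit-normalized, which the pooling change above guarantees for E5/BGE). A small NumPy sketch of the same two computations, with array and function names of my choosing:

import numpy as np

def neighbor_distances(node_embeddings: np.ndarray,
                       query: np.ndarray,
                       distance_metric: str = "mips") -> np.ndarray:
    """node_embeddings: (n, d); query: (d,). Smaller distance = closer."""
    if distance_metric == "l2":
        # Squared Euclidean distance, as in the "l2" branch above.
        diff = node_embeddings.astype(np.float32) - query.astype(np.float32).reshape(1, -1)
        return np.sum(np.square(diff), axis=1)
    # "mips" or "cosine": negate the dot product so larger similarity sorts first.
    return -np.dot(node_embeddings, query)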
@@ -572,6 +577,7 @@ if __name__ == "__main__":
     parser.add_argument("--model-name", type=str, default="sentence-transformers/all-mpnet-base-v2",
                         help="Embedding model name")
     parser.add_argument("--custom-max-length", type=int, default=None, help="Override model's default max sequence length")
+    parser.add_argument("--distance-metric", type=str, default="mips", help="Distance metric to use")
     args = parser.parse_args()
@@ -586,4 +592,5 @@ if __name__ == "__main__":
         max_batch_size=args.max_batch_size,
         model_name=args.model_name,
         custom_max_length_param=args.custom_max_length,
+        distance_metric=args.distance_metric,
     )
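
Finally, a hedged example of launching the updated entry point with the new flag. The script filename below is a placeholder (this page does not show the changed file's path), the model name is just one example of a BGE model, and only --distance-metric is new in this commit:

import subprocess
import sys

subprocess.run(
    [
        sys.executable, "embedding_server.py",    # placeholder filename
        "--model-name", "BAAI/bge-base-en-v1.5",  # example BGE model (my choice)
        "--distance-metric", "l2",                # new flag added by this commit
    ],
    check=True,
)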