feat: support more embedders
@@ -85,6 +85,7 @@ def create_hnsw_embedding_server(
     max_batch_size: int = 128,
     model_name: str = "sentence-transformers/all-mpnet-base-v2",
     custom_max_length_param: Optional[int] = None,
+    distance_metric: str = "mips",
 ):
     """
     Create and start a ZMQ-based embedding server for HNSW backend.
@@ -100,6 +101,7 @@ def create_hnsw_embedding_server(
         max_batch_size: Maximum batch size for processing
         model_name: Transformer model name
         custom_max_length_param: Custom max sequence length
+        distance_metric: The distance metric to use
     """
     print(f"Loading tokenizer for {model_name}...")
     tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
@@ -222,6 +224,7 @@ def create_hnsw_embedding_server(
     def process_batch(texts_batch, ids_batch, missing_ids):
         """Process a batch of texts and return embeddings"""
         _is_e5_model = "e5" in model_name.lower()
+        _is_bge_model = "bge" in model_name.lower()
         batch_size = len(texts_batch)
 
         # E5 model preprocessing
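The new _is_bge_model flag mirrors the existing _is_e5_model substring check, so any model name containing "bge" (e.g. BAAI/bge-base-en-v1.5) takes the BGE-specific pooling and normalization paths in the hunks below. The E5 preprocessing the comment refers to is elided from this diff; as a rough sketch of what E5-family models typically expect (the helper name is mine, not from this commit):

# Hypothetical helper (not from this commit): E5 checkpoints are trained
# with "query: " / "passage: " prefixes, so texts are usually prefixed
# before tokenization.
def add_e5_prefixes(texts, is_query=False):
    prefix = "query: " if is_query else "passage: "
    return [prefix + text for text in texts]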
@@ -262,7 +265,9 @@ def create_hnsw_embedding_server(
         out = model(enc["input_ids"], enc["attention_mask"])
 
         with pool_timer.timing():
-            if not hasattr(out, 'last_hidden_state'):
+            if _is_bge_model:
+                pooled_embeddings = out.last_hidden_state[:, 0]
+            elif not hasattr(out, 'last_hidden_state'):
                 if isinstance(out, torch.Tensor) and len(out.shape) == 2:
                     pooled_embeddings = out
                 else:
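The branch order matters here: BGE models are checked first and pooled by taking the hidden state of the first token (CLS pooling, which is what BGE models are trained for), model outputs without a last_hidden_state attribute fall through to the raw-tensor path, and everything else gets the masked mean pooling visible in the next hunk. A self-contained sketch of the two pooling strategies, assuming a standard (batch, seq, hidden) output:

import torch

def cls_pool(last_hidden_state: torch.Tensor) -> torch.Tensor:
    # BGE-style pooling: the embedding is the first ([CLS]) token's state.
    return last_hidden_state[:, 0]

def mean_pool(last_hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    # Masked mean pooling: average real-token states, ignoring padding.
    mask = attention_mask.unsqueeze(-1).float()     # (batch, seq, 1)
    summed = (last_hidden_state * mask).sum(dim=1)  # (batch, hidden)
    counts = mask.sum(dim=1).clamp(min=1e-9)        # guard against empty rows
    return summed / counts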
@@ -279,7 +284,7 @@ def create_hnsw_embedding_server(
             pooled_embeddings = sum_embeddings / sum_mask
 
         final_embeddings = pooled_embeddings
-        if _is_e5_model:
+        if _is_e5_model or _is_bge_model:
            with norm_timer.timing():
                 final_embeddings = F.normalize(pooled_embeddings, p=2, dim=1)
 
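Extending L2 normalization to BGE (not only E5) is what makes the "mips" default sensible for both families: on unit-length vectors the inner product equals cosine similarity, so a maximum-inner-product search ranks candidates identically to cosine search. A quick check of that equivalence:

import torch
import torch.nn.functional as F

a = F.normalize(torch.randn(4, 768), p=2, dim=1)  # normalized batch
b = F.normalize(torch.randn(768), p=2, dim=0)     # normalized query
# On unit vectors the dot product *is* the cosine similarity.
assert torch.allclose(a @ b, F.cosine_similarity(a, b.unsqueeze(0)), atol=1e-6)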
@@ -408,14 +413,14 @@ def create_hnsw_embedding_server(
         calc_timer = DeviceTimer("distance calculation", device)
         with calc_timer.timing():
             with torch.no_grad():
-                if is_similarity_metric():
-                    node_embeddings_np = node_embeddings_tensor.cpu().numpy()
-                    query_np = query_tensor.cpu().numpy()
-                    distances = -np.dot(node_embeddings_np, query_np)
-                else:
+                if distance_metric == "l2":
                     node_embeddings_np = node_embeddings_tensor.cpu().numpy().astype(np.float32)
                     query_np = query_tensor.cpu().numpy().astype(np.float32)
                     distances = np.sum(np.square(node_embeddings_np - query_np.reshape(1, -1)), axis=1)
+                else:  # mips or cosine
+                    node_embeddings_np = node_embeddings_tensor.cpu().numpy()
+                    query_np = query_tensor.cpu().numpy()
+                    distances = -np.dot(node_embeddings_np, query_np)
         calc_timer.print_elapsed()
 
         try:
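The rewritten branch replaces the old is_similarity_metric() test with an explicit check on the new distance_metric parameter. The convention is that smaller values win, matching an ascending sort over candidates: the similarity metrics (mips, cosine) return negated dot products, and l2 returns the squared Euclidean distance, where the square root is safe to skip because it does not change the ranking. Condensed as a standalone function (the name is mine, not the repo's):

import numpy as np

def candidate_distances(nodes: np.ndarray, query: np.ndarray, metric: str) -> np.ndarray:
    """Return per-node scores where smaller means closer."""
    if metric == "l2":
        # Squared Euclidean distance; omitting sqrt preserves the order.
        return np.sum(np.square(nodes - query.reshape(1, -1)), axis=1)
    # mips / cosine: negate the inner product so the best match is smallest.
    return -np.dot(nodes, query)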
@@ -572,6 +577,7 @@ if __name__ == "__main__":
     parser.add_argument("--model-name", type=str, default="sentence-transformers/all-mpnet-base-v2",
                         help="Embedding model name")
     parser.add_argument("--custom-max-length", type=int, default=None, help="Override model's default max sequence length")
+    parser.add_argument("--distance-metric", type=str, default="mips", help="Distance metric to use")
 
     args = parser.parse_args()
 
@@ -586,4 +592,5 @@ if __name__ == "__main__":
         max_batch_size=args.max_batch_size,
         model_name=args.model_name,
         custom_max_length_param=args.custom_max_length,
+        distance_metric=args.distance_metric,
     )
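Putting the new flag together with the existing ones, a launch might look like the following. The script's filename is not shown in this diff, so it is a placeholder, and BAAI/bge-base-en-v1.5 is just one example of a model name the new "bge" detection would match:

python embedding_server.py \
    --model-name BAAI/bge-base-en-v1.5 \
    --custom-max-length 512 \
    --distance-metric cosine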