# Source listing metadata (from the code viewer): 138 lines, 7.1 KiB, Python
import os
|
|
import torch
|
|
import numpy as np
|
|
import argparse
|
|
from tqdm import tqdm
|
|
import json
|
|
from contriever.src.contriever import load_retriever
|
|
|
|
# Threading-related environment variables, applied in a single update.
# NOTE(review): these are set after `import torch` has already run —
# presumably they are meant to constrain OpenMP/MKL threading; confirm they
# still take effect at this point.
os.environ.update(
    {
        "KMP_DUPLICATE_LIB_OK": "TRUE",
        "OMP_NUM_THREADS": "1",
        "KMP_BLOCKTIME": "0",
    }
)

# Restrict PyTorch to one intra-op thread.
torch.set_num_threads(1)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
|
def embed_queries(queries, model, tokenizer, model_name_or_path, per_gpu_batch_size=64):
    """Embed a sequence of texts with the given encoder, one batch at a time.

    Args:
        queries: Sized sequence of strings to embed. ``len()`` is used to
            detect the final (possibly partial) batch.
        model: Encoder invoked as ``model(**encoded_batch)``. It is cast to
            half precision and switched to eval mode. Its return value is
            used directly as the embedding, so it is assumed to be a tensor
            with one row per input text.
        tokenizer: Tokenizer exposing ``batch_encode_plus``.
        model_name_or_path: Currently unused; kept so existing callers keep
            working.
        per_gpu_batch_size: Number of texts that triggers a forward pass.

    Returns:
        A NumPy array holding the per-batch model outputs concatenated along
        axis 0. An empty ``queries`` yields an empty ``(0, 0)`` array instead
        of failing inside ``torch.cat``.
    """
    # NOTE(review): half precision is applied unconditionally, even when the
    # module-level `device` is the CPU — confirm the installed PyTorch build
    # supports fp16 kernels there.
    model = model.half()
    model.eval()

    num_queries = len(queries)
    if num_queries == 0:
        # torch.cat raises on an empty list, so short-circuit. float16 is
        # chosen to mirror the half-precision cast above.
        embeddings = np.empty((0, 0), dtype=np.float16)
        print(f"Query embeddings shape: {embeddings.shape}")
        return embeddings

    batch_outputs = []
    pending = []

    with torch.no_grad():
        # `total` is passed explicitly because enumerate() hides the length
        # from tqdm, which would otherwise show no progress fraction.
        progress = tqdm(enumerate(queries), total=num_queries, desc="Encoding queries")
        for index, query in progress:
            pending.append(query)

            # Flush only when the batch is full or this is the last query.
            is_last = index == num_queries - 1
            if len(pending) != per_gpu_batch_size and not is_last:
                continue

            encoded_batch = tokenizer.batch_encode_plus(
                pending,
                return_tensors="pt",
                max_length=512,
                padding=True,
                truncation=True,
            )
            encoded_batch = {name: tensor.to(device) for name, tensor in encoded_batch.items()}
            output = model(**encoded_batch)

            # NOTE(review): the raw model output is treated as the pooled
            # embedding. Presumably the loaded retriever pools internally;
            # other checkpoints may need e.g. `output.last_hidden_state[:, 0, :]`
            # — confirm for the model in use.
            batch_outputs.append(output.cpu())
            pending = []  # start a fresh batch

    embeddings = torch.cat(batch_outputs, dim=0).numpy()
    print(f"Query embeddings shape: {embeddings.shape}")
    return embeddings
|
|
|
|
def main():
    """Command-line entry point for the embedding debug tool.

    Loads the encoder named by ``--model``, gathers queries from
    ``--input-file`` and/or ``--text`` (falling back to one built-in example
    query), embeds them together with two hard-coded passages, and prints
    the Euclidean distance from the first query to each passage along with
    which passage is closer.
    """
    parser = argparse.ArgumentParser(description="Debug embedding tool")
    # %(default)s keeps the help text in sync with the actual default; the
    # previous text named a different model than the one really used.
    parser.add_argument("--model", type=str, default="Alibaba-NLP/gte-Qwen2-1.5B-instruct",
                        help="Model name for embedding (default: %(default)s)")
    parser.add_argument("--batch-size", type=int, default=32,
                        help="Batch size for encoding (default: 32)")
    parser.add_argument("--input-file", type=str,
                        help="Input file with queries (JSON lines format with 'query' field)")
    # NOTE(review): --output-file and --save-text are parsed but never read
    # below, so nothing is written to disk. Either wire them up or drop them.
    parser.add_argument("--output-file", type=str, default="embeddings.npy",
                        help="Output file to save embeddings (default: embeddings.npy)")
    parser.add_argument("--text", type=str, nargs="+",
                        help="Direct text input to embed (can provide multiple)")
    parser.add_argument("--save-text", action="store_true",
                        help="Save the input text alongside embeddings")

    args = parser.parse_args()

    # Load model
    print(f"Loading query encoder: {args.model}")
    query_encoder, query_tokenizer, _ = load_retriever(args.model)
    query_encoder = query_encoder.to(device)
    query_encoder.eval()

    # Collect queries from every provided source.
    queries = []

    # From file if provided
    if args.input_file:
        print(f"Loading queries from: {args.input_file}")
        # Explicit UTF-8 so decoding does not depend on the platform locale.
        with open(args.input_file, "r", encoding="utf-8") as f:
            for line in f:
                # Skip blank lines (e.g. a trailing newline) rather than
                # letting json.loads fail on them.
                if not line.strip():
                    continue
                data = json.loads(line)
                queries.append(data["query"])

    # From command line if provided
    if args.text:
        print(f"Using {len(args.text)} queries from command line")
        queries.extend(args.text)

    # If no queries, use some examples
    if not queries:
        print("No queries provided, using example queries")
        queries = [
            "Were there any variances detected for hour 6 on 3/9/01?"
        ]

    print(f"Embedding {len(queries)} queries")
    for i, q in enumerate(queries[:5]):  # Print first 5 queries
        print(f"Query {i+1}: {q}")
    if len(queries) > 5:
        print(f"... and {len(queries)-5} more")

    # Encode queries
    embeddings = embed_queries(
        queries, query_encoder, query_tokenizer, args.model, per_gpu_batch_size=args.batch_size
    )

    # Two fixed reference passages to compare against the first query.
    passages = [
        "Start Date: 3/9/01; HourAhead hour: 6; No ancillary schedules awarded. Variances detected. Variances detected in Generation schedule. Variances detected in Energy Import/Export schedule. LOG MESSAGES: PARSING FILE -->> O:\\Portland\\WestDesk\\California Scheduling\\ISO Final Schedules\\2001030906.txt ---- Generation Schedule ---- $$$ Variance found in table tblGEN_SCHEDULE. Details: (Hour: 6 / Preferred: 20.00 / Final: 19.80) TRANS_TYPE: FINAL SC_ID: TOSCO MKT_TYPE: 2 TRANS_DATE: 3/9/01 UNIT_ID: UNCHEM_1_UNIT $$$ Variance found in table tblGEN_SCHEDULE. Details: (Hour: 6 / Preferred: 29.00 / Final: 28.20) TRANS_TYPE: FINAL SC_ID: ARCO MKT_TYPE: 2 TRANS_DATE: 3/9/01 UNIT_ID: CARBGN_6_UNIT 1 $$$ Variance found in table tblGEN_SCHEDULE. Details: (Hour: 6 / Preferred: 45.00 / Final: 43.80) TRANS_TYPE: FINAL SC_ID: DELANO MKT_TYPE: 2 TRANS_DATE: 3/9/01 UNIT_ID: PANDOL_6_UNIT $$$ Variance found in table tblGEN_SCHEDULE. Details: (Hour: 6 / Preferred: 13.00 / Final: 12.60) TRANS_TYPE: FINAL SC_ID: Wheelabrat MKT_TYPE: 2 TRANS_DATE: 3/9/01 UNIT_ID: MARTEL_2_AMFOR ---- Energy Import/Export Schedule ---- $$$ Variance found in table tblINTCHG_IMPEXP. Details: (Hour: 6 / Preferred: 62.00 / Final: 60.40) TRANS_TYPE: FINAL SC_ID: ECTstCA MKT_TYPE: 2 TRANS_DATE: 3/9/01 TIE_POINT: PVERDE_5_DEVERS INTERCHG_ID: EPMI_CISO_5001 ENGY_TYPE: FIRM $$$ Variance found in table tblINTCHG_IMPEXP. Details: (Hour: 6 / Preferred: 63.00 / Final: 61.23) TRANS_TYPE: FINAL SC_ID: ECTstSW MKT_TYPE: 2 TRANS_DATE: 3/9/01 TIE_POINT: PVERDE_5_DEVERS INTERCHG_ID: EPMI_CISO_5000 ENGY_TYPE: FIRM $$$ Variance found in table tblINTCHG_IMPEXP. Details: (Hour: 6 / Preferred: 17.00 / Final: 11.00) TRANS_TYPE: FINAL SC_ID: ECTRT MKT_TYPE: 2 TRANS_DATE: 3/9/01 TIE_POINT: SYLMAR_2_NOB INTERCHG_ID: EPMI_CISO_LUCKY ENGY_TYPE: NFRM",
        "Start Date: 3/30/01; HourAhead hour: 15; No ancillary schedules awarded. Variances detected. Variances detected in Generation schedule. LOG MESSAGES: PARSING FILE -->> O:\\Portland\\WestDesk\\California Scheduling\\ISO Final Schedules\\2001033015.txt ---- Generation Schedule ---- $$$ Variance found in table tblGEN_SCHEDULE. Details: (Hour: 15 / Preferred: 0.00 / Final: 0.00) TRANS_TYPE: FINAL SC_ID: ARCO MKT_TYPE: 2 TRANS_DATE: 3/30/01 UNIT_ID: CARBGN_6_UNIT 1 $$$ Variance found in table tblGEN_SCHEDULE. Details: (Hour: 15 / Preferred: 45.00 / Final: 44.00) TRANS_TYPE: FINAL SC_ID: DELANO MKT_TYPE: 2 TRANS_DATE: 3/30/01 UNIT_ID: PANDOL_6_UNIT"
    ]

    # Embed passages (the same encoder and tokenizer are reused for passages).
    passage_embeddings = embed_queries(passages, query_encoder, query_tokenizer, args.model, per_gpu_batch_size=args.batch_size)

    # Euclidean distance between the first query and passage 0
    distance_0 = np.linalg.norm(embeddings[0] - passage_embeddings[0])
    print(f"Distance between query 0 and passage 0: {distance_0}")

    # Euclidean distance between the first query and passage 1
    distance_1 = np.linalg.norm(embeddings[0] - passage_embeddings[1])
    print(f"Distance between query 0 and passage 1: {distance_1}")

    # Report which passage is closer; a tie is reported as passage 1.
    if distance_0 < distance_1:
        print("Query 0 is closer to passage 0")
    else:
        print("Query 0 is closer to passage 1")

    print("Done!")
|
|
|
|
# Run the debug tool only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|