# Filename: evaluate_results_xai_line_sync.py
import openai
import json
import os
import time
from dotenv import load_dotenv
from tqdm import tqdm
from collections import defaultdict
import concurrent.futures
from typing import List, Dict, Any, Tuple
# --- Configuration ---
load_dotenv()
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
if not OPENAI_API_KEY:
    raise ValueError("Please set the OPENAI_API_KEY in your .env file")
try:
    client = openai.OpenAI(
        api_key=OPENAI_API_KEY,
    )
except ImportError:
    print("Please install the latest OpenAI library: pip install --upgrade openai")
    exit()
except openai.AuthenticationError:
    print("OpenAI library reported an AuthenticationError. Ensure OPENAI_API_KEY is correct.")
    exit()
LLM_MODEL = "gpt-3.5-turbo" # Using OpenAI's standard model
MAX_RETRIES = 5
INITIAL_RETRY_DELAY_SECONDS = 5
REQUEST_TIMEOUT_SECONDS = 90
MAX_WORKERS = 10 # Number of parallel workers
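# With these settings, the retry loop in get_llm_containment_evaluation below waits
# roughly 5s, 10s, 20s, 40s between attempts (the delay doubles per retry, capped at 60s).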
# --- File Paths (Adjust as needed) ---
# User provided paths
QUERIES_FILE_PATH = "/opt/dlami/nvme/scaling_out/examples/enron_eval_retrieval.jsonl"
RAW_PASSAGES_FILE_PATH = "/opt/dlami/nvme/scaling_out/passages/enron_emails/1-shards/raw_passages-0-of-1.jsonl"
RESULTS_FILE_PATH = "search_results_top10_bm25.jsonl" # This file's Nth line corresponds to QUERIES_FILE_PATH's Nth line
OUTPUT_EVALUATION_FILE = "llm_containment_evaluations_xai_line_sync.jsonl"
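# Expected JSONL schemas (illustrative examples, inferred from the loaders and dummy data below):
#   queries file:      {"id": "q_alpha", "query": "...", "ground_truth_message_ids": ["<msg_A>"]}
#   raw passages file: {"text": "...", "id": 100, "message_id": "<msg_A>"}
#   results file:      {"query_id": "...", "passages": [{"id": 101, "text": "..."}, ...]}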
# --- LLM Prompt Definitions for Containment ---
CONTAINMENT_SYSTEM_PROMPT = """You are an AI evaluator. Your task is to determine if the core information presented in the 'Retrieved Passage' is directly contained within *any* of the text snippets provided in the 'Ground Truth Email Snippets' list."""
CONTAINMENT_USER_TEMPLATE = """Retrieved Passage:
"{retrieved_passage_text}"
---
Ground Truth Email Snippets (Parts of the correct source email):
{ground_truth_snippets_formatted_list}
---
Is the core information of the 'Retrieved Passage' directly present or fully contained within *any* of the 'Ground Truth Email Snippets' listed above?
- Focus on whether the specific facts or statements in the 'Retrieved Passage' can be found within the ground truth snippets.
- Ignore minor formatting differences. If the retrieved passage is a direct quote or a very close paraphrase of content within the ground truth snippets, answer YES.
- Respond YES if the Retrieved Passage's content is clearly represented in one or more of the ground truth snippets.
- Respond NO if the Retrieved Passage's content is not found, is contradictory, or introduces significant information not present in the ground truth snippets.
Your response must be a single word: YES or NO.
"""
# --- Data Loading Functions ---
def load_queries_as_list(file_path):
"""
Loads queries from a jsonl file into a list, preserving order.
Each item in the list is a dict containing original_id, query_text, and ground_truth_message_ids.
"""
queries_list = []
try:
with open(file_path, 'r', encoding='utf-8') as f:
for line_num, line in enumerate(f):
try:
data = json.loads(line)
required_keys = ["id", "query", "ground_truth_message_ids"]
if not all(key in data for key in required_keys):
print(f"Warning: Skipping line {line_num + 1} in query file due to missing keys: {line.strip()}")
continue
if not isinstance(data["ground_truth_message_ids"], list):
print(f"Warning: 'ground_truth_message_ids' is not a list in line {line_num + 1}. Skipping: {line.strip()}")
continue
queries_list.append({
"original_id": data["id"], # Store the original ID from the file
"query_text": data["query"],
"ground_truth_message_ids": data["ground_truth_message_ids"]
})
except json.JSONDecodeError:
print(f"Warning: Skipping malformed JSON line {line_num + 1} in query file: {line.strip()}")
except FileNotFoundError:
print(f"Error: Queries file not found at {file_path}")
exit()
print(f"Loaded {len(queries_list)} queries (as a list) from {file_path}")
return queries_list
def load_all_passages_by_message_id(raw_passages_file_path):
"""Loads all raw passages into memory, grouped by message_id. (Same as before)"""
passages_dict = defaultdict(list)
# ... (implementation from previous script, no changes needed here) ...
print(f"Loading all raw passages from {raw_passages_file_path} into memory...")
try:
with open(raw_passages_file_path, 'r', encoding='utf-8') as f:
for line_num, line in enumerate(f):
try:
data = json.loads(line)
if "message_id" in data and "text" in data:
passages_dict[data["message_id"]].append(data["text"])
else:
print(f"Warning: Skipping line {line_num+1} in raw passages file due to missing 'message_id' or 'text'.")
except json.JSONDecodeError:
print(f"Warning: Skipping malformed JSON line {line_num + 1} in raw passages file: {line.strip()}")
print(f"Finished loading raw passages. Found {len(passages_dict)} unique message IDs.")
except FileNotFoundError:
print(f"Error: Raw passages file not found at {raw_passages_file_path}")
exit()
except MemoryError:
print("Error: Ran out of memory loading all raw passages. Consider an indexed approach.")
exit()
return dict(passages_dict)
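# Note (illustrative, based on the dummy data below): the returned lookup maps
# message_id -> list of passage texts, e.g.
# {"<msg_A>": ["Content from message A snippet 1.", "Content from message A snippet 2."]}.
# With a single shard this holds the entire passage corpus in RAM.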
def load_search_results_as_list(file_path):
"""Loads search results from a jsonl file into a list, preserving order."""
results_list = []
# ... (implementation similar to load_queries_as_list, parsing each line as a dict) ...
try:
with open(file_path, 'r', encoding='utf-8') as f:
for line_num, line in enumerate(f):
try:
data = json.loads(line)
# We expect "query_id" (though not used for matching) and "passages"
if "passages" not in data: # query_id might be implicitly by order
print(f"Warning: Skipping line {line_num + 1} in search results file due to missing 'passages' key: {line.strip()}")
continue
results_list.append(data)
except json.JSONDecodeError:
print(f"Warning: Skipping malformed JSON line {line_num + 1} in search results file: {line.strip()}")
except FileNotFoundError:
print(f"Error: Search results file not found at {file_path}")
exit()
print(f"Loaded {len(results_list)} search result sets (as a list) from {file_path}")
return results_list
def format_ground_truth_snippets(snippet_list):
"""Formats the list of ground truth snippets for the prompt. (Same as before)"""
# ... (implementation from previous script) ...
if not snippet_list:
return " [No ground truth snippets found for the target message ID(s)]"
formatted = []
for i, snippet in enumerate(snippet_list):
display_snippet = (snippet[:500] + '...') if len(snippet) > 500 else snippet
formatted.append(f" {i+1}. {display_snippet}")
return "\n".join(formatted)
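# Example (illustrative):
#   format_ground_truth_snippets(["First snippet.", "Second snippet."])
#   -> " 1. First snippet.\n 2. Second snippet."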
# --- LLM API Call Function ---
def get_llm_containment_evaluation(retrieved_passage_text: str, ground_truth_snippets_list: List[str], query_id_for_log: str, passage_identifier_info: str, query_text_for_context: str = "") -> str:
"""Calls the OpenAI API with retry logic."""
formatted_gt_snippets = format_ground_truth_snippets(ground_truth_snippets_list)
# max_gt_chars_in_prompt = 5000 # Arbitrary limit, adjust as needed
# if len(formatted_gt_snippets) > max_gt_chars_in_prompt:
# print(f"Warning: Ground truth snippets for Q_log_id:{query_id_for_log} are too long ({len(formatted_gt_snippets)} chars), truncating for LLM prompt.")
# formatted_gt_snippets = formatted_gt_snippets[:max_gt_chars_in_prompt] + "\n [... Snippets Truncated ...]"
user_prompt = CONTAINMENT_USER_TEMPLATE.format(
retrieved_passage_text=retrieved_passage_text,
ground_truth_snippets_formatted_list=formatted_gt_snippets
)
messages = [
{"role": "system", "content": CONTAINMENT_SYSTEM_PROMPT},
{"role": "user", "content": user_prompt}
]
current_retry_delay = INITIAL_RETRY_DELAY_SECONDS
for attempt in range(MAX_RETRIES):
try:
response = client.chat.completions.create(
model=LLM_MODEL,
messages=messages,
temperature=0.0,
max_tokens=10,
timeout=REQUEST_TIMEOUT_SECONDS
)
answer = response.choices[0].message.content.strip().upper()
if answer in ["YES", "NO"]:
return answer
else:
print(f"Warning: Unexpected LLM response content '{answer[:100]}' for Q_log_id:{query_id_for_log} P:{passage_identifier_info}. Defaulting to NO.")
return "NO"
except openai.APIConnectionError as e:
error_message = f"API Connection Error (Attempt {attempt + 1}/{MAX_RETRIES}): {e}"
except openai.RateLimitError as e:
error_message = f"API Rate Limit Error (Attempt {attempt + 1}/{MAX_RETRIES}): {e}"
except openai.APIStatusError as e:
error_message = f"API Status Error (Attempt {attempt + 1}/{MAX_RETRIES}): {e.status_code} - {e.response}"
if e.status_code == 401:
return "ERROR_AUTH"
if e.status_code == 500:
pass
else:
return "ERROR_API_CLIENT"
except Exception as e:
error_message = f"Unexpected error with OpenAI lib (Attempt {attempt + 1}/{MAX_RETRIES}): {type(e).__name__} - {e}"
print(f"{error_message}. Query Log ID: {query_id_for_log}, Passage: {passage_identifier_info}")
if "ERROR_AUTH" in error_message or "ERROR_API_CLIENT" in error_message:
break
if attempt < MAX_RETRIES - 1:
print(f"Retrying in {current_retry_delay} seconds...")
time.sleep(current_retry_delay)
current_retry_delay = min(current_retry_delay * 2, 60)
else:
print(f"Max retries ({MAX_RETRIES}) reached for Q_log_id:{query_id_for_log} P:{passage_identifier_info}. Skipping.")
return "ERROR_MAX_RETRIES"
return "ERROR_MAX_RETRIES"
def process_query_passage_pair(args: Tuple[Dict[str, Any], Dict[str, Any], Dict[str, List[str]], set]) -> List[Dict[str, Any]]:
"""Process a single query-passage pair for parallel execution."""
query_info, result_item, passages_lookup, already_evaluated = args
evaluations = []
query_original_id = query_info["original_id"]
query_text = query_info["query_text"]
target_message_ids = query_info.get("ground_truth_message_ids", [])
if not target_message_ids:
return evaluations
ground_truth_snippets = []
for msg_id_in_query_file in target_message_ids:
msg_id_to_lookup = msg_id_in_query_file
if msg_id_in_query_file.startswith("<") and msg_id_in_query_file.endswith(">"):
msg_id_to_lookup = msg_id_in_query_file[1:-1]
snippets = passages_lookup.get(msg_id_to_lookup)
if snippets:
ground_truth_snippets.extend(snippets)
if not ground_truth_snippets:
return evaluations
retrieved_passages = result_item.get("passages", [])
if not retrieved_passages:
return evaluations
for passage_idx, passage_obj in enumerate(retrieved_passages):
if not isinstance(passage_obj, dict):
print(f"Warning: Invalid passage format for Q_original_id:{query_original_id}, passage index {passage_idx}. Skipping passage.")
continue
retrieved_passage_text = passage_obj.get("text", "").strip()
passage_identifier = passage_obj.get("passage_id", passage_obj.get("id", f"retrieved_idx_{passage_idx}"))
evaluation_key = (query_original_id, passage_identifier)
if evaluation_key in already_evaluated:
continue
passage_text_preview = (retrieved_passage_text[:75] + '...') if len(retrieved_passage_text) > 75 else retrieved_passage_text
if not retrieved_passage_text:
evaluation = "NO"
else:
evaluation = get_llm_containment_evaluation(
retrieved_passage_text,
ground_truth_snippets,
query_original_id,
passage_identifier,
query_text
)
if evaluation == "ERROR_AUTH":
print("Authentication error with OpenAI API. Stopping script.")
return evaluations
evaluation_record = {
"query_original_id": query_original_id,
"passage_identifier": passage_identifier,
"passage_text_preview": passage_text_preview,
"evaluation": evaluation,
"model_used": LLM_MODEL,
"ground_truth_message_ids_checked": target_message_ids
}
evaluations.append(evaluation_record)
return evaluations
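# Each emitted record becomes one JSONL line in the output file, e.g. (illustrative):
# {"query_original_id": "q_alpha", "passage_identifier": 101,
#  "passage_text_preview": "Content from message A snippet 2.", "evaluation": "YES",
#  "model_used": "gpt-3.5-turbo", "ground_truth_message_ids_checked": ["<msg_A>"]}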
# --- Resume Logic ---
def load_existing_evaluations(output_file):
"""Loads already evaluated query-passage pairs using 'passage_identifier' and 'query_original_id'. (Same as before, but keying with original_id)"""
# ... (implementation from previous script, ensure it uses the correct ID for keys) ...
evaluated_pairs = set()
if os.path.exists(output_file):
print(f"Loading existing containment evaluations from {output_file}...")
with open(output_file, 'r', encoding='utf-8') as f:
for line_num, line in enumerate(f):
try:
data = json.loads(line)
# Key for resuming should be based on the logged original query ID
query_original_id = data.get('query_original_id')
passage_identifier = data.get('passage_identifier')
if query_original_id is not None and passage_identifier is not None:
evaluated_pairs.add((query_original_id, passage_identifier))
else:
print(f"Warning: Could not identify query_original_id/passage_identifier in existing file line {line_num + 1}.")
except json.JSONDecodeError:
print(f"Warning: Skipping malformed line {line_num + 1} in existing file: {line.strip()}")
except KeyError as e:
print(f"Warning: Skipping line {line_num + 1} with missing key '{e}' in existing file: {line.strip()}")
print(f"Loaded {len(evaluated_pairs)} existing evaluation records.")
else:
print(f"No existing evaluation file found at {output_file}. Starting fresh.")
return evaluated_pairs
# --- Main Execution Logic ---
def main():
"""Main function to run the containment evaluation process using parallel processing."""
print(f"Starting containment evaluation using OpenAI model: {LLM_MODEL} via OpenAI library interface.")
# Load data as lists
queries_list = load_queries_as_list(QUERIES_FILE_PATH)
passages_lookup = load_all_passages_by_message_id(RAW_PASSAGES_FILE_PATH)
search_results_list = load_search_results_as_list(RESULTS_FILE_PATH)
if not queries_list or not search_results_list or not passages_lookup:
print("Error loading one or more input files or raw passages. Exiting.")
return
# Determine the number of items to process
num_items_to_process = min(len(queries_list), len(search_results_list))
print(f"Will process {num_items_to_process} query-result pairs.")
already_evaluated = load_existing_evaluations(OUTPUT_EVALUATION_FILE)
try:
with open(OUTPUT_EVALUATION_FILE, 'a', encoding='utf-8') as outfile:
# Prepare arguments for parallel processing
process_args = [
(queries_list[i], search_results_list[i], passages_lookup, already_evaluated)
for i in range(num_items_to_process)
]
# Use ThreadPoolExecutor for parallel processing
with concurrent.futures.ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
# Submit all tasks and get futures
futures = [executor.submit(process_query_passage_pair, args) for args in process_args]
# Process results as they complete
for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures), desc="Processing query-result pairs"):
try:
evaluations = future.result()
for evaluation in evaluations:
outfile.write(json.dumps(evaluation) + "\n")
outfile.flush()
# Update already_evaluated set
already_evaluated.add((evaluation["query_original_id"], evaluation["passage_identifier"]))
except Exception as e:
print(f"Error processing query-result pair: {e}")
except IOError as e:
print(f"Error writing to output file {OUTPUT_EVALUATION_FILE}: {e}")
return
except Exception as e:
print(f"An unexpected error occurred during the main processing loop: {e}")
return
print("\n--- Containment Evaluation Script Finished ---")
    # --- Final Summary Calculation ---
    print(f"Calculating final summary statistics from: {OUTPUT_EVALUATION_FILE}")
    final_query_containment_found = {}
    total_evaluated_pairs = 0
    error_count = 0
    evaluated_query_original_ids = set()
    try:
        with open(OUTPUT_EVALUATION_FILE, 'r', encoding='utf-8') as f:
            for line_num, line in enumerate(f):
                total_evaluated_pairs += 1
                try:
                    data = json.loads(line)
                    q_original_id = data['query_original_id']
                    eval_result = data['evaluation']
                    evaluated_query_original_ids.add(q_original_id)
                    if eval_result == "YES":
                        final_query_containment_found[q_original_id] = True
                    elif q_original_id not in final_query_containment_found:
                        final_query_containment_found[q_original_id] = False
                    if eval_result not in ["YES", "NO"]:
                        error_count += 1
                except (json.JSONDecodeError, KeyError) as e:
                    print(f"Error reading line {line_num + 1} during summary: {e} - Line: {line.strip()}")
                    error_count += 1
        num_queries_with_any_contained = sum(1 for contained in final_query_containment_found.values() if contained)
        total_unique_queries_evaluated = len(evaluated_query_original_ids)
        if total_unique_queries_evaluated > 0:
            containment_rate_at_10 = num_queries_with_any_contained / total_unique_queries_evaluated
            print("\n--- Final Statistics (Containment Check) ---")
            print(f"Total unique queries processed (based on output file entries): {total_unique_queries_evaluated}")
            print(f"Number of queries with at least one contained passage (YES): {num_queries_with_any_contained}")
            print(f"Containment Match Rate @ Top 10 (Any YES): {containment_rate_at_10:.4f}")
            print(f"Total query-passage pairs processed (lines in output file): {total_evaluated_pairs}")
            if error_count > 0:
                print(f"Number of evaluation errors or non-YES/NO results: {error_count}")
        else:
            print("No evaluation results found to summarize.")
    except FileNotFoundError:
        print(f"Error: Output file {OUTPUT_EVALUATION_FILE} not found for summary.")
    except Exception as e:
        print(f"An unexpected error occurred during summary calculation: {e}")
    print(f"\nDetailed containment evaluations saved to: {OUTPUT_EVALUATION_FILE}")
if __name__ == "__main__":
    # Create dummy input files for testing the line-sync logic when the real inputs are absent.
    if not os.path.exists(QUERIES_FILE_PATH):
        print(f"Warning: {QUERIES_FILE_PATH} not found. Creating dummy file.")
        with open(QUERIES_FILE_PATH, 'w', encoding='utf-8') as f:
            json.dump({"id": "q_alpha", "query": "Query Alpha Text", "ground_truth_message_ids": ["<msg_A>"]}, f); f.write("\n")  # Line 0
            json.dump({"id": "q_beta", "query": "Query Beta Text", "ground_truth_message_ids": ["<msg_B>"]}, f); f.write("\n")  # Line 1
            json.dump({"id": "q_gamma", "query": "Query Gamma Text", "ground_truth_message_ids": ["<msg_C>"]}, f); f.write("\n")  # Line 2
    if not os.path.exists(RAW_PASSAGES_FILE_PATH):
        print(f"Warning: {RAW_PASSAGES_FILE_PATH} not found. Creating dummy file.")
        with open(RAW_PASSAGES_FILE_PATH, 'w', encoding='utf-8') as f:
            json.dump({"text": "Content from message A snippet 1.", "id": 100, "message_id": "<msg_A>"}, f); f.write("\n")
            json.dump({"text": "Content from message A snippet 2.", "id": 101, "message_id": "<msg_A>"}, f); f.write("\n")
            json.dump({"text": "Content from message B.", "id": 200, "message_id": "<msg_B>"}, f); f.write("\n")
            json.dump({"text": "Content from message D (unrelated).", "id": 300, "message_id": "<msg_D>"}, f); f.write("\n")
    # RESULTS_FILE_PATH must contain results corresponding line-by-line to QUERIES_FILE_PATH.
    if not os.path.exists(RESULTS_FILE_PATH):
        print(f"Warning: {RESULTS_FILE_PATH} not found. Creating dummy file (2 entries).")
        with open(RESULTS_FILE_PATH, 'w', encoding='utf-8') as f:
            # Result for query "q_alpha" (line 0 in queries file)
            json.dump({"query_id": "this_can_be_ignored_if_line_sync", "passages": [{"id": 101, "text": "Content from message A snippet 2."}, {"id": 300, "text": "Content from message D (unrelated)."}]}, f); f.write("\n")
            # Result for query "q_beta" (line 1 in queries file)
            json.dump({"query_id": "this_too", "passages": [{"id": 999, "text": "Some other text."}, {"id": 200, "text": "Content from message B."}]}, f); f.write("\n")
    # Note: the dummy results file has only 2 result sets for the 3 dummy queries; main()
    # processes min(len(queries_list), len(search_results_list)) pairs, so the third query is skipped.
    main()