# Filename: evaluate_results_xai_line_sync.py import openai import json import os import time from dotenv import load_dotenv from tqdm import tqdm from collections import defaultdict import concurrent.futures from typing import List, Dict, Any, Tuple # --- Configuration --- load_dotenv() OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") if not OPENAI_API_KEY: raise ValueError("Please set the OPENAI_API_KEY in your .env file") try: client = openai.OpenAI( api_key=OPENAI_API_KEY, ) except ImportError: print("Please install the latest OpenAI library: pip install --upgrade openai") exit() except openai.AuthenticationError: print("OpenAI library reported an AuthenticationError. Ensure OPENAI_API_KEY is correct.") exit() LLM_MODEL = "gpt-3.5-turbo" # Using OpenAI's standard model MAX_RETRIES = 5 INITIAL_RETRY_DELAY_SECONDS = 5 REQUEST_TIMEOUT_SECONDS = 90 MAX_WORKERS = 10 # Number of parallel workers # --- File Paths (Adjust as needed) --- # User provided paths QUERIES_FILE_PATH = "/opt/dlami/nvme/scaling_out/examples/enron_eval_retrieval.jsonl" RAW_PASSAGES_FILE_PATH = "/opt/dlami/nvme/scaling_out/passages/enron_emails/1-shards/raw_passages-0-of-1.jsonl" RESULTS_FILE_PATH = "search_results_top10_bm25.jsonl" # This file's Nth line corresponds to QUERIES_FILE_PATH's Nth line OUTPUT_EVALUATION_FILE = "llm_containment_evaluations_xai_line_sync.jsonl" # --- LLM Prompt Definitions for Containment (Same as before) --- CONTAINMENT_SYSTEM_PROMPT = """You are an AI evaluator. Your task is to determine if the core information presented in the 'Retrieved Passage' is directly contained within *any* of the text snippets provided in the 'Ground Truth Email Snippets' list.""" CONTAINMENT_USER_TEMPLATE = """Retrieved Passage: "{retrieved_passage_text}" --- Ground Truth Email Snippets (Parts of the correct source email): {ground_truth_snippets_formatted_list} --- Is the core information of the 'Retrieved Passage' directly present or fully contained within *any* of the 'Ground Truth Email Snippets' listed above? - Focus on whether the specific facts or statements in the 'Retrieved Passage' can be found within the ground truth snippets. - Ignore minor formatting differences. If the retrieved passage is a direct quote or a very close paraphrase of content within the ground truth snippets, answer YES. - Respond YES if the Retrieved Passage's content is clearly represented in one or more of the ground truth snippets. - Respond NO if the Retrieved Passage's content is not found, is contradictory, or introduces significant information not present in the ground truth snippets. Your response must be a single word: YES or NO. """ # --- Data Loading Functions --- def load_queries_as_list(file_path): """ Loads queries from a jsonl file into a list, preserving order. Each item in the list is a dict containing original_id, query_text, and ground_truth_message_ids. """ queries_list = [] try: with open(file_path, 'r', encoding='utf-8') as f: for line_num, line in enumerate(f): try: data = json.loads(line) required_keys = ["id", "query", "ground_truth_message_ids"] if not all(key in data for key in required_keys): print(f"Warning: Skipping line {line_num + 1} in query file due to missing keys: {line.strip()}") continue if not isinstance(data["ground_truth_message_ids"], list): print(f"Warning: 'ground_truth_message_ids' is not a list in line {line_num + 1}. Skipping: {line.strip()}") continue queries_list.append({ "original_id": data["id"], # Store the original ID from the file "query_text": data["query"], "ground_truth_message_ids": data["ground_truth_message_ids"] }) except json.JSONDecodeError: print(f"Warning: Skipping malformed JSON line {line_num + 1} in query file: {line.strip()}") except FileNotFoundError: print(f"Error: Queries file not found at {file_path}") exit() print(f"Loaded {len(queries_list)} queries (as a list) from {file_path}") return queries_list def load_all_passages_by_message_id(raw_passages_file_path): """Loads all raw passages into memory, grouped by message_id. (Same as before)""" passages_dict = defaultdict(list) # ... (implementation from previous script, no changes needed here) ... print(f"Loading all raw passages from {raw_passages_file_path} into memory...") try: with open(raw_passages_file_path, 'r', encoding='utf-8') as f: for line_num, line in enumerate(f): try: data = json.loads(line) if "message_id" in data and "text" in data: passages_dict[data["message_id"]].append(data["text"]) else: print(f"Warning: Skipping line {line_num+1} in raw passages file due to missing 'message_id' or 'text'.") except json.JSONDecodeError: print(f"Warning: Skipping malformed JSON line {line_num + 1} in raw passages file: {line.strip()}") print(f"Finished loading raw passages. Found {len(passages_dict)} unique message IDs.") except FileNotFoundError: print(f"Error: Raw passages file not found at {raw_passages_file_path}") exit() except MemoryError: print("Error: Ran out of memory loading all raw passages. Consider an indexed approach.") exit() return dict(passages_dict) def load_search_results_as_list(file_path): """Loads search results from a jsonl file into a list, preserving order.""" results_list = [] # ... (implementation similar to load_queries_as_list, parsing each line as a dict) ... try: with open(file_path, 'r', encoding='utf-8') as f: for line_num, line in enumerate(f): try: data = json.loads(line) # We expect "query_id" (though not used for matching) and "passages" if "passages" not in data: # query_id might be implicitly by order print(f"Warning: Skipping line {line_num + 1} in search results file due to missing 'passages' key: {line.strip()}") continue results_list.append(data) except json.JSONDecodeError: print(f"Warning: Skipping malformed JSON line {line_num + 1} in search results file: {line.strip()}") except FileNotFoundError: print(f"Error: Search results file not found at {file_path}") exit() print(f"Loaded {len(results_list)} search result sets (as a list) from {file_path}") return results_list def format_ground_truth_snippets(snippet_list): """Formats the list of ground truth snippets for the prompt. (Same as before)""" # ... (implementation from previous script) ... if not snippet_list: return " [No ground truth snippets found for the target message ID(s)]" formatted = [] for i, snippet in enumerate(snippet_list): display_snippet = (snippet[:500] + '...') if len(snippet) > 500 else snippet formatted.append(f" {i+1}. {display_snippet}") return "\n".join(formatted) # --- LLM API Call Function --- def get_llm_containment_evaluation(retrieved_passage_text: str, ground_truth_snippets_list: List[str], query_id_for_log: str, passage_identifier_info: str, query_text_for_context: str = "") -> str: """Calls the OpenAI API with retry logic.""" formatted_gt_snippets = format_ground_truth_snippets(ground_truth_snippets_list) # max_gt_chars_in_prompt = 5000 # Arbitrary limit, adjust as needed # if len(formatted_gt_snippets) > max_gt_chars_in_prompt: # print(f"Warning: Ground truth snippets for Q_log_id:{query_id_for_log} are too long ({len(formatted_gt_snippets)} chars), truncating for LLM prompt.") # formatted_gt_snippets = formatted_gt_snippets[:max_gt_chars_in_prompt] + "\n [... Snippets Truncated ...]" user_prompt = CONTAINMENT_USER_TEMPLATE.format( retrieved_passage_text=retrieved_passage_text, ground_truth_snippets_formatted_list=formatted_gt_snippets ) messages = [ {"role": "system", "content": CONTAINMENT_SYSTEM_PROMPT}, {"role": "user", "content": user_prompt} ] current_retry_delay = INITIAL_RETRY_DELAY_SECONDS for attempt in range(MAX_RETRIES): try: response = client.chat.completions.create( model=LLM_MODEL, messages=messages, temperature=0.0, max_tokens=10, timeout=REQUEST_TIMEOUT_SECONDS ) answer = response.choices[0].message.content.strip().upper() if answer in ["YES", "NO"]: return answer else: print(f"Warning: Unexpected LLM response content '{answer[:100]}' for Q_log_id:{query_id_for_log} P:{passage_identifier_info}. Defaulting to NO.") return "NO" except openai.APIConnectionError as e: error_message = f"API Connection Error (Attempt {attempt + 1}/{MAX_RETRIES}): {e}" except openai.RateLimitError as e: error_message = f"API Rate Limit Error (Attempt {attempt + 1}/{MAX_RETRIES}): {e}" except openai.APIStatusError as e: error_message = f"API Status Error (Attempt {attempt + 1}/{MAX_RETRIES}): {e.status_code} - {e.response}" if e.status_code == 401: return "ERROR_AUTH" if e.status_code == 500: pass else: return "ERROR_API_CLIENT" except Exception as e: error_message = f"Unexpected error with OpenAI lib (Attempt {attempt + 1}/{MAX_RETRIES}): {type(e).__name__} - {e}" print(f"{error_message}. Query Log ID: {query_id_for_log}, Passage: {passage_identifier_info}") if "ERROR_AUTH" in error_message or "ERROR_API_CLIENT" in error_message: break if attempt < MAX_RETRIES - 1: print(f"Retrying in {current_retry_delay} seconds...") time.sleep(current_retry_delay) current_retry_delay = min(current_retry_delay * 2, 60) else: print(f"Max retries ({MAX_RETRIES}) reached for Q_log_id:{query_id_for_log} P:{passage_identifier_info}. Skipping.") return "ERROR_MAX_RETRIES" return "ERROR_MAX_RETRIES" def process_query_passage_pair(args: Tuple[Dict[str, Any], Dict[str, Any], Dict[str, List[str]], set]) -> List[Dict[str, Any]]: """Process a single query-passage pair for parallel execution.""" query_info, result_item, passages_lookup, already_evaluated = args evaluations = [] query_original_id = query_info["original_id"] query_text = query_info["query_text"] target_message_ids = query_info.get("ground_truth_message_ids", []) if not target_message_ids: return evaluations ground_truth_snippets = [] for msg_id_in_query_file in target_message_ids: msg_id_to_lookup = msg_id_in_query_file if msg_id_in_query_file.startswith("<") and msg_id_in_query_file.endswith(">"): msg_id_to_lookup = msg_id_in_query_file[1:-1] snippets = passages_lookup.get(msg_id_to_lookup) if snippets: ground_truth_snippets.extend(snippets) if not ground_truth_snippets: return evaluations retrieved_passages = result_item.get("passages", []) if not retrieved_passages: return evaluations for passage_idx, passage_obj in enumerate(retrieved_passages): if not isinstance(passage_obj, dict): print(f"Warning: Invalid passage format for Q_original_id:{query_original_id}, passage index {passage_idx}. Skipping passage.") continue retrieved_passage_text = passage_obj.get("text", "").strip() passage_identifier = passage_obj.get("passage_id", passage_obj.get("id", f"retrieved_idx_{passage_idx}")) evaluation_key = (query_original_id, passage_identifier) if evaluation_key in already_evaluated: continue passage_text_preview = (retrieved_passage_text[:75] + '...') if len(retrieved_passage_text) > 75 else retrieved_passage_text if not retrieved_passage_text: evaluation = "NO" else: evaluation = get_llm_containment_evaluation( retrieved_passage_text, ground_truth_snippets, query_original_id, passage_identifier, query_text ) if evaluation == "ERROR_AUTH": print("Authentication error with OpenAI API. Stopping script.") return evaluations evaluation_record = { "query_original_id": query_original_id, "passage_identifier": passage_identifier, "passage_text_preview": passage_text_preview, "evaluation": evaluation, "model_used": LLM_MODEL, "ground_truth_message_ids_checked": target_message_ids } evaluations.append(evaluation_record) return evaluations # --- Resume Logic --- def load_existing_evaluations(output_file): """Loads already evaluated query-passage pairs using 'passage_identifier' and 'query_original_id'. (Same as before, but keying with original_id)""" # ... (implementation from previous script, ensure it uses the correct ID for keys) ... evaluated_pairs = set() if os.path.exists(output_file): print(f"Loading existing containment evaluations from {output_file}...") with open(output_file, 'r', encoding='utf-8') as f: for line_num, line in enumerate(f): try: data = json.loads(line) # Key for resuming should be based on the logged original query ID query_original_id = data.get('query_original_id') passage_identifier = data.get('passage_identifier') if query_original_id is not None and passage_identifier is not None: evaluated_pairs.add((query_original_id, passage_identifier)) else: print(f"Warning: Could not identify query_original_id/passage_identifier in existing file line {line_num + 1}.") except json.JSONDecodeError: print(f"Warning: Skipping malformed line {line_num + 1} in existing file: {line.strip()}") except KeyError as e: print(f"Warning: Skipping line {line_num + 1} with missing key '{e}' in existing file: {line.strip()}") print(f"Loaded {len(evaluated_pairs)} existing evaluation records.") else: print(f"No existing evaluation file found at {output_file}. Starting fresh.") return evaluated_pairs # --- Main Execution Logic --- def main(): """Main function to run the containment evaluation process using parallel processing.""" print(f"Starting containment evaluation using OpenAI model: {LLM_MODEL} via OpenAI library interface.") # Load data as lists queries_list = load_queries_as_list(QUERIES_FILE_PATH) passages_lookup = load_all_passages_by_message_id(RAW_PASSAGES_FILE_PATH) search_results_list = load_search_results_as_list(RESULTS_FILE_PATH) if not queries_list or not search_results_list or not passages_lookup: print("Error loading one or more input files or raw passages. Exiting.") return # Determine the number of items to process num_items_to_process = min(len(queries_list), len(search_results_list)) print(f"Will process {num_items_to_process} query-result pairs.") already_evaluated = load_existing_evaluations(OUTPUT_EVALUATION_FILE) try: with open(OUTPUT_EVALUATION_FILE, 'a', encoding='utf-8') as outfile: # Prepare arguments for parallel processing process_args = [ (queries_list[i], search_results_list[i], passages_lookup, already_evaluated) for i in range(num_items_to_process) ] # Use ThreadPoolExecutor for parallel processing with concurrent.futures.ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor: # Submit all tasks and get futures futures = [executor.submit(process_query_passage_pair, args) for args in process_args] # Process results as they complete for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures), desc="Processing query-result pairs"): try: evaluations = future.result() for evaluation in evaluations: outfile.write(json.dumps(evaluation) + "\n") outfile.flush() # Update already_evaluated set already_evaluated.add((evaluation["query_original_id"], evaluation["passage_identifier"])) except Exception as e: print(f"Error processing query-result pair: {e}") except IOError as e: print(f"Error writing to output file {OUTPUT_EVALUATION_FILE}: {e}") return except Exception as e: print(f"An unexpected error occurred during the main processing loop: {e}") return print("\n--- Containment Evaluation Script Finished ---") # --- Final Summary Calculation --- print(f"Calculating final summary statistics from: {OUTPUT_EVALUATION_FILE}") final_query_containment_found = {} total_evaluated_pairs = 0 error_count = 0 evaluated_query_original_ids = set() try: with open(OUTPUT_EVALUATION_FILE, 'r', encoding='utf-8') as f: for line_num, line in enumerate(f): total_evaluated_pairs += 1 try: data = json.loads(line) q_original_id = data['query_original_id'] eval_result = data['evaluation'] evaluated_query_original_ids.add(q_original_id) if eval_result == "YES": final_query_containment_found[q_original_id] = True elif q_original_id not in final_query_containment_found: final_query_containment_found[q_original_id] = False if eval_result not in ["YES", "NO"]: error_count += 1 except (json.JSONDecodeError, KeyError) as e: print(f"Error reading line {line_num + 1} during summary: {e} - Line: {line.strip()}") error_count += 1 num_queries_with_any_contained = sum(1 for contained in final_query_containment_found.values() if contained) total_unique_queries_evaluated = len(evaluated_query_original_ids) if total_unique_queries_evaluated > 0: containment_rate_at_10 = num_queries_with_any_contained / total_unique_queries_evaluated print(f"\n--- Final Statistics (Containment Check) ---") print(f"Total unique queries processed (based on output file entries): {total_unique_queries_evaluated}") print(f"Number of queries with at least one contained passage (YES): {num_queries_with_any_contained}") print(f"Containment Match Rate @ Top 10 (Any YES): {containment_rate_at_10:.4f}") print(f"Total query-passage pairs processed (lines in output file): {total_evaluated_pairs}") if error_count > 0: print(f"Number of evaluation errors or non-YES/NO results: {error_count}") else: print("No evaluation results found to summarize.") except FileNotFoundError: print(f"Error: Output file {OUTPUT_EVALUATION_FILE} not found for summary.") except Exception as e: print(f"An unexpected error occurred during summary calculation: {e}") print(f"\nDetailed containment evaluations saved to: {OUTPUT_EVALUATION_FILE}") if __name__ == "__main__": # Dummy files for testing the line sync logic if not os.path.exists(QUERIES_FILE_PATH): print(f"Warning: {QUERIES_FILE_PATH} not found. Creating dummy file.") with open(QUERIES_FILE_PATH, 'w', encoding='utf-8') as f: json.dump({"id": "q_alpha", "query": "Query Alpha Text", "ground_truth_message_ids": [""]}, f); f.write("\n") # Line 0 json.dump({"id": "q_beta", "query": "Query Beta Text", "ground_truth_message_ids": [""]}, f); f.write("\n") # Line 1 json.dump({"id": "q_gamma", "query": "Query Gamma Text", "ground_truth_message_ids": [""]}, f); f.write("\n")# Line 2 if not os.path.exists(RAW_PASSAGES_FILE_PATH): print(f"Warning: {RAW_PASSAGES_FILE_PATH} not found. Creating dummy file.") with open(RAW_PASSAGES_FILE_PATH, 'w', encoding='utf-8') as f: json.dump({"text": "Content from message A snippet 1.", "id": 100, "message_id": ""}, f); f.write("\n") json.dump({"text": "Content from message A snippet 2.", "id": 101, "message_id": ""}, f); f.write("\n") json.dump({"text": "Content from message B.", "id": 200, "message_id": ""}, f); f.write("\n") json.dump({"text": "Content from message D (unrelated).", "id": 300, "message_id": ""}, f); f.write("\n") # RESULTS_FILE_PATH should have results corresponding line-by-line to QUERIES_FILE_PATH if not os.path.exists(RESULTS_FILE_PATH): print(f"Warning: {RESULTS_FILE_PATH} not found. Creating dummy file (2 entries).") with open(RESULTS_FILE_PATH, 'w', encoding='utf-8') as f: # Result for query "q_alpha" (line 0 in queries file) json.dump({"query_id": "this_can_be_ignored_if_line_sync", "passages": [{"id": 101, "text": "Content from message A snippet 2."}, {"id": 300, "text": "Content from message D (unrelated)."}]}, f); f.write("\n") # Result for query "q_beta" (line 1 in queries file) json.dump({"query_id": "this_too", "passages": [{"id": 999, "text": "Some other text."}, {"id": 200, "text": "Content from message B."}]}, f); f.write("\n") # Note: Only 2 result sets, but 3 queries in dummy QUERIES_FILE_PATH. # The script will process min(len(queries_list), len(search_results_list)) if you uncomment that logic, # or just len(search_results_list) as it's currently written for tqdm. main()