import math
import pickle
import sys
import time
from pathlib import Path

import numpy as np

# --- Configuration ---
DOMAIN_NAME = "rpj_wiki"  # Domain name used for finding passages
EMBEDDER_NAME = "facebook/contriever-msmarco"  # Used in paths
ORIGINAL_EMBEDDING_SHARD_ID = 0  # The shard ID of the embedding file we are loading

# Define the base directory
SCALING_OUT_DIR = Path("/powerrag/scaling_out").resolve()

# Original data paths (using functions similar to your utils).
# Assuming embeddings for rpj_wiki are in a single file despite passage sharding;
# adjust NUM_SHARDS_EMBEDDING if embeddings are also sharded.
NUM_SHARDS_EMBEDDING = 1
ORIGINAL_EMBEDDING_FILE_TEMPLATE = (
    SCALING_OUT_DIR
    / "embeddings/{embedder_name}/{domain_name}/{total_shards}-shards/passages_{shard_id:02d}.pkl"
)
ORIGINAL_EMBEDDING_FILE = str(ORIGINAL_EMBEDDING_FILE_TEMPLATE).format(
    embedder_name=EMBEDDER_NAME,
    domain_name=DOMAIN_NAME,
    total_shards=NUM_SHARDS_EMBEDDING,
    shard_id=ORIGINAL_EMBEDDING_SHARD_ID,
)

# Passage paths
NUM_SHARDS_PASSAGE = 8  # As specified in your original utils (NUM_SHARDS['rpj_wiki'])
ORIGINAL_PASSAGE_FILE_TEMPLATE = (
    SCALING_OUT_DIR
    / "passages/{domain_name}/{total_shards}-shards/raw_passages-{shard_id}-of-{total_shards}.pkl"
)

# New identifier for the sampled dataset
NEW_DATASET_NAME = "rpj_wiki_1M"
# Fraction to sample (1/60)
SAMPLE_FRACTION = 1 / 60

# Output paths for the new sampled dataset
OUTPUT_EMBEDDING_DIR = SCALING_OUT_DIR / "embeddings" / EMBEDDER_NAME / NEW_DATASET_NAME / "1-shards"
OUTPUT_PASSAGE_DIR = SCALING_OUT_DIR / "passages" / NEW_DATASET_NAME / "1-shards"
OUTPUT_EMBEDDING_FILE = OUTPUT_EMBEDDING_DIR / f"passages_{ORIGINAL_EMBEDDING_SHARD_ID:02d}.pkl"
# The new passage file represents the *single* shard of the sampled data
OUTPUT_PASSAGE_FILE = OUTPUT_PASSAGE_DIR / "raw_passages-0-of-1.pkl"

# --- Directory Setup ---
print("Creating output directories if they don't exist...")
OUTPUT_EMBEDDING_DIR.mkdir(parents=True, exist_ok=True)
OUTPUT_PASSAGE_DIR.mkdir(parents=True, exist_ok=True)
print(f"Embeddings output dir: {OUTPUT_EMBEDDING_DIR}")
print(f"Passages output dir: {OUTPUT_PASSAGE_DIR}")


# --- Helper Function to Load Passages ---
def load_all_passages(domain_name, num_shards, template):
    """Loads all passage shards and creates an ID-to-content map."""
    all_passages_list = []
    passage_id_to_content_map = {}
    print(f"Loading passages for domain '{domain_name}' from {num_shards} shards...")
    total_loaded = 0
    start_time = time.time()
    for shard_id in range(num_shards):
        shard_path_str = str(template).format(
            domain_name=domain_name,
            total_shards=num_shards,
            shard_id=shard_id,
        )
        shard_path = Path(shard_path_str)
        if not shard_path.exists():
            print(f"Warning: Passage shard file not found, skipping: {shard_path}")
            continue
        try:
            print(f"  Loading shard {shard_id} from {shard_path}...")
            with open(shard_path, 'rb') as f:
                shard_passages = pickle.load(f)  # Expected: list of dicts
            if not isinstance(shard_passages, list):
                print(f"Warning: Shard {shard_id} data is not a list.")
                continue
            all_passages_list.extend(shard_passages)
            # Build the map, ensuring IDs are strings for consistent lookup
            for passage_dict in shard_passages:
                if 'id' in passage_dict:
                    passage_id_to_content_map[str(passage_dict['id'])] = passage_dict
                else:
                    print(f"Warning: Passage dict in shard {shard_id} missing 'id' key.")
            print(f"  Loaded {len(shard_passages)} passages from shard {shard_id}.")
            total_loaded += len(shard_passages)
        except Exception as e:
            print(f"Error loading passage shard {shard_id} from {shard_path}: {e}")
    load_time = time.time() - start_time
    print(f"Finished loading passages. Total passages loaded: {total_loaded} in {load_time:.2f} seconds.")
    print(f"Total unique passages mapped by ID: {len(passage_id_to_content_map)}")
    return all_passages_list, passage_id_to_content_map


# --- Load Original Embeddings ---
print(f"\nLoading original embeddings from {ORIGINAL_EMBEDDING_FILE}...")
start_load_time = time.time()
try:
    with open(ORIGINAL_EMBEDDING_FILE, 'rb') as f:
        original_embedding_data = pickle.load(f)
except FileNotFoundError:
    print(f"Error: Original embedding file not found at {ORIGINAL_EMBEDDING_FILE}")
    sys.exit(1)
except Exception as e:
    print(f"Error loading embedding pickle file: {e}")
    sys.exit(1)
load_time = time.time() - start_load_time
print(f"Loaded original embeddings data in {load_time:.2f} seconds.")

# --- Extract and Validate Embeddings ---
try:
    if not isinstance(original_embedding_data, (list, tuple)) or len(original_embedding_data) != 2:
        raise TypeError("Expected embedding data to be a list or tuple of length 2 (ids, embeddings)")
    original_embedding_ids = original_embedding_data[0]  # Should be a list/iterable of IDs
    original_embeddings = original_embedding_data[1]  # Should be a NumPy array
    # Ensure IDs are in a list for easier indexing later if they aren't already
    if not isinstance(original_embedding_ids, list):
        print("Converting embedding IDs to list...")
        original_embedding_ids = list(original_embedding_ids)
    if not isinstance(original_embeddings, np.ndarray):
        raise TypeError("Expected second element of embedding data to be a NumPy array")
    print(f"Original data contains {len(original_embedding_ids)} embedding IDs.")
    print(f"Original embeddings shape: {original_embeddings.shape}, dtype: {original_embeddings.dtype}")
    if len(original_embedding_ids) != original_embeddings.shape[0]:
        raise ValueError(
            f"Mismatch! Number of embedding IDs ({len(original_embedding_ids)}) "
            f"does not match number of embeddings ({original_embeddings.shape[0]})"
        )
except (TypeError, ValueError, IndexError) as e:
    print(f"Error processing loaded embedding data: {e}")
    print("Please ensure the embedding pickle file contains: (list_of_passage_ids, numpy_embedding_array)")
    sys.exit(1)

total_embeddings = original_embeddings.shape[0]

# --- Load Original Passages ---
# This might take time and memory depending on the dataset size
_, passage_id_to_content_map = load_all_passages(
    DOMAIN_NAME, NUM_SHARDS_PASSAGE, ORIGINAL_PASSAGE_FILE_TEMPLATE
)
if not passage_id_to_content_map:
    print("Error: No passages were loaded. Cannot proceed with sampling.")
    sys.exit(1)

# --- Calculate Sample Size ---
num_samples = math.ceil(total_embeddings * SAMPLE_FRACTION)  # Use ceil to get at least 1/60th
print(f"\nTotal original embeddings: {total_embeddings}")
print(f"Sampling fraction: {SAMPLE_FRACTION:.6f} (1/60)")
print(f"Target number of samples: {num_samples}")
if num_samples > total_embeddings:
    print("Warning: Calculated sample size exceeds total embeddings. Using all embeddings.")
    num_samples = total_embeddings
elif num_samples <= 0:
    print("Error: Calculated sample size is zero or negative.")
    sys.exit(1)

# --- Perform Random Sampling (Based on Embeddings) ---
print("\nPerforming random sampling based on embeddings...")
start_sample_time = time.time()
# Set a seed for reproducibility if needed
# np.random.seed(42)
# Generate unique random indices from the embeddings list
sampled_indices = np.random.choice(total_embeddings, size=num_samples, replace=False)
# Retrieve the corresponding IDs and embeddings using the sampled indices
sampled_embedding_ids = [original_embedding_ids[i] for i in sampled_indices]
sampled_embeddings = original_embeddings[sampled_indices]
sample_time = time.time() - start_sample_time
print(f"Sampling completed in {sample_time:.2f} seconds.")
print(f"Sampled {len(sampled_embedding_ids)} IDs and embeddings.")
print(f"Sampled embeddings shape: {sampled_embeddings.shape}")

# --- Retrieve Corresponding Passages ---
print("\nRetrieving corresponding passages for sampled IDs...")
start_passage_retrieval_time = time.time()
sampled_passages = []
missing_ids_count = 0
for i, pid in enumerate(sampled_embedding_ids):
    # Convert pid to string for lookup in the map
    pid_str = str(pid)
    if pid_str in passage_id_to_content_map:
        sampled_passages.append(passage_id_to_content_map[pid_str])
    else:
        # This indicates an inconsistency between embedding IDs and passage IDs
        print(f"Warning: Passage ID '{pid_str}' (from embedding index {sampled_indices[i]}) not found in passage map.")
        missing_ids_count += 1
passage_retrieval_time = time.time() - start_passage_retrieval_time
print(f"Retrieved {len(sampled_passages)} passages in {passage_retrieval_time:.2f} seconds.")
if missing_ids_count > 0:
    print(f"Warning: Could not find passages for {missing_ids_count} sampled IDs.")
if not sampled_passages:
    print("Error: No corresponding passages found for the sampled embeddings. Check ID matching.")
    sys.exit(1)

# --- Prepare Output Data ---
# Embeddings output format: tuple(list_of_ids, numpy_array_of_embeddings)
output_embedding_data = (sampled_embedding_ids, sampled_embeddings)
# Passages output format: list[dict] (matching input shard format)
output_passage_data = sampled_passages

# --- Save Sampled Embeddings ---
print(f"\nSaving sampled embeddings to {OUTPUT_EMBEDDING_FILE}...")
start_save_time = time.time()
try:
    with open(OUTPUT_EMBEDDING_FILE, 'wb') as f:
        pickle.dump(output_embedding_data, f, protocol=pickle.HIGHEST_PROTOCOL)
except Exception as e:
    print(f"Error saving sampled embeddings: {e}")
    # Continue to try saving passages if desired, or sys.exit(1)
save_time = time.time() - start_save_time
print(f"Saved sampled embeddings in {save_time:.2f} seconds.")

# --- Save Sampled Passages ---
print(f"\nSaving sampled passages to {OUTPUT_PASSAGE_FILE}...")
start_save_time = time.time()
try:
    with open(OUTPUT_PASSAGE_FILE, 'wb') as f:
        pickle.dump(output_passage_data, f, protocol=pickle.HIGHEST_PROTOCOL)
except Exception as e:
    print(f"Error saving sampled passages: {e}")
    sys.exit(1)
save_time = time.time() - start_save_time
print(f"Saved sampled passages in {save_time:.2f} seconds.")

print("\nScript finished successfully.")
print(f"Sampled embeddings saved to: {OUTPUT_EMBEDDING_FILE}")
print(f"Sampled passages saved to: {OUTPUT_PASSAGE_FILE}")