import argparse
import gc  # Import garbage collector interface
import os
import struct
import sys
import time

import numpy as np

# --- FourCCs (add more if needed) ---
INDEX_HNSW_FLAT_FOURCC = int.from_bytes(b"IHNf", "little")
# Add other HNSW fourccs if you expect different storage types inside HNSW
# INDEX_HNSW_PQ_FOURCC = int.from_bytes(b'IHNp', 'little')
# INDEX_HNSW_SQ_FOURCC = int.from_bytes(b'IHNs', 'little')
# INDEX_HNSW_CAGRA_FOURCC = int.from_bytes(b'IHNc', 'little')  # Example
EXPECTED_HNSW_FOURCCS = {INDEX_HNSW_FLAT_FOURCC}  # Modify if needed
NULL_INDEX_FOURCC = int.from_bytes(b"null", "little")


# --- Helper functions for reading/writing binary data ---
def read_struct(f, fmt):
    """Reads data according to the struct format."""
    size = struct.calcsize(fmt)
    data = f.read(size)
    if len(data) != size:
        raise EOFError(
            f"File ended unexpectedly reading struct fmt '{fmt}'. "
            f"Expected {size} bytes, got {len(data)}."
        )
    return struct.unpack(fmt, data)[0]


def read_vector_raw(f, element_fmt_char):
    """Reads a vector (size followed by data), returns count and raw bytes."""
    count = -1  # Initialize count
    total_bytes = -1  # Initialize total_bytes
    try:
        count = read_struct(f, "<q")  # Vector length is stored as a 64-bit count
        element_size = struct.calcsize(element_fmt_char)
        # --- FIX for MemoryError: Check for huge element count before allocation ---
        max_reasonable_count = 10 * (1024**3)  # sanity bound; adjust if needed
        if count > max_reasonable_count or count < 0:
            raise MemoryError(
                f"Vector count {count} seems unreasonably large, possibly due to "
                f"file corruption or incorrect format read."
            )
        total_bytes = count * element_size
        # --- FIX for MemoryError: Check for huge byte size before allocation ---
        max_reasonable_bytes = 50 * (1024**3)  # ~50 GB limit
        if total_bytes > max_reasonable_bytes or total_bytes < 0:  # Check for overflow
            raise MemoryError(
                f"Attempting to read {total_bytes} bytes ({count} elements * "
                f"{element_size} bytes/element), which exceeds the safety limit. "
                f"File might be corrupted or format mismatch."
            )
        data_bytes = f.read(total_bytes)
        if len(data_bytes) != total_bytes:
            raise EOFError(
                f"File ended unexpectedly reading vector data. "
                f"Expected {total_bytes} bytes, got {len(data_bytes)}."
            )
        return count, data_bytes
    except (MemoryError, OverflowError) as e:
        # Add context to the error message
        print(
            f"\nError during raw vector read (element_fmt='{element_fmt_char}', "
            f"count={count}, total_bytes={total_bytes}): {e}",
            file=sys.stderr,
        )
        raise e  # Re-raise the original error type


def read_numpy_vector(f, np_dtype, struct_fmt_char):
    """Reads a vector into a NumPy array."""
    count = -1  # Initialize count for robust error handling
    print(
        f"  Reading vector (dtype={np_dtype}, fmt='{struct_fmt_char}')... ",
        end="",
        flush=True,
    )
    try:
        count, data_bytes = read_vector_raw(f, struct_fmt_char)
        print(f"Count={count}, Bytes={len(data_bytes)}")
        if count > 0 and len(data_bytes) > 0:
            arr = np.frombuffer(data_bytes, dtype=np_dtype)
            if arr.size != count:
                raise ValueError(
                    f"Inconsistent array size after reading. Expected {count}, got {arr.size}"
                )
            return arr
        elif count == 0:
            return np.array([], dtype=np_dtype)
        else:
            raise ValueError("Read zero bytes but count > 0.")
    except MemoryError as e:
        # count is defined here (or -1 if the error was raised in read_struct)
        print(
            f"\nMemoryError creating NumPy array (dtype={np_dtype}, count={count}). {e}",
            file=sys.stderr,
        )
        raise e
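# A minimal round-trip sanity check for the readers above (illustrative sketch,
# not called by the converter): it serializes a small vector in the same
# "64-bit count, then raw little-endian data" layout that read_vector_raw
# expects and reads it back. The function name is hypothetical.
def _self_test_vector_roundtrip():
    import io

    original = np.arange(10, dtype=np.int32)
    buf = io.BytesIO()
    buf.write(struct.pack("<q", original.size))  # 64-bit count prefix
    buf.write(original.tobytes())  # raw element payload
    buf.seek(0)
    restored = read_numpy_vector(buf, np.int32, "i")
    assert np.array_equal(original, restored), "round-trip mismatch"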
{e}", file=sys.stderr, ) raise e except Exception as e: # Catch other potential errors like ValueError print( f"\nError reading numpy vector (dtype={np_dtype}, fmt='{struct_fmt_char}', count={count}): {e}", file=sys.stderr, ) raise e def write_numpy_vector(f, arr, struct_fmt_char): """Writes a NumPy array as a vector (size followed by data).""" count = arr.size f.write(struct.pack(" 0 else 0 def write_compact_format( f_out, original_hnsw_data, assign_probas_np, cum_nneighbor_per_level_np, levels_np, compact_level_ptr, compact_node_offsets_np, compact_neighbors_data, storage_fourcc, storage_data, ): """Write HNSW data in compact format following C++ read order exactly.""" # Write IndexHNSW Header f_out.write(struct.pack(" 1: f_out.write(struct.pack(" {output_filename}") start_time = time.time() original_hnsw_data = {} neighbors_np = None # Initialize to allow check in finally block try: with open(input_filename, "rb") as f_in, open(output_filename, "wb") as f_out: # --- Read IndexHNSW FourCC and Header --- print(f"[{time.time() - start_time:.2f}s] Reading Index HNSW header...") # ... (Keep the header reading logic as before) ... hnsw_index_fourcc = read_struct(f_in, " 1: original_hnsw_data["metric_arg"] = read_struct(f_in, " 0 else 0 if neighbors_np.size != expected_neighbors_size: print( f"Warning: neighbors vector size mismatch. Expected {expected_neighbors_size} based on offsets, got {neighbors_np.size}." ) gc.collect() original_hnsw_data["entry_point"] = read_struct(f_in, " 0 and i % (ntotal // 100 or 1) == 0: # Log progress roughly every 1% progress = (i / ntotal) * 100 elapsed = time.time() - start_time print( f"\r[{elapsed:.2f}s] Converting node {i}/{ntotal} ({progress:.1f}%)...", end="", ) node_max_level = levels_np[i] - 1 if node_max_level < -1: node_max_level = -1 node_ptr_start_index = current_level_ptr_idx compact_node_offsets_np[i] = node_ptr_start_index original_offset_start = offsets_np[i] num_pointers_expected = (node_max_level + 1) + 1 for level in range(node_max_level + 1): compact_level_ptr.append(current_data_idx) begin_orig_np = original_offset_start + get_cum_neighbors( cum_nneighbor_per_level_np, level ) end_orig_np = original_offset_start + get_cum_neighbors( cum_nneighbor_per_level_np, level + 1 ) begin_orig = int(begin_orig_np) end_orig = int(end_orig_np) neighbors_len = len(neighbors_np) # Cache length begin_orig = min(max(0, begin_orig), neighbors_len) end_orig = min(max(begin_orig, end_orig), neighbors_len) if begin_orig < end_orig: # Slicing creates a copy, could be memory intensive for large M # Consider iterating if memory becomes an issue here level_neighbors_slice = neighbors_np[begin_orig:end_orig] valid_neighbors_mask = level_neighbors_slice >= 0 num_valid = np.count_nonzero(valid_neighbors_mask) if num_valid > 0: # Append valid neighbors compact_neighbors_data.extend( level_neighbors_slice[valid_neighbors_mask] ) current_data_idx += num_valid total_valid_neighbors_counted += num_valid compact_level_ptr.append(current_data_idx) current_level_ptr_idx += num_pointers_expected compact_node_offsets_np[ntotal] = current_level_ptr_idx print( f"\r[{time.time() - start_time:.2f}s] Conversion loop finished. 
" ) # Clear progress line # --- Validation Checks --- print(f"[{time.time() - start_time:.2f}s] Running validation checks...") valid_check_passed = True # Check 1: Total valid neighbors count print(" Checking total valid neighbor count...") expected_valid_count = np.sum(neighbors_np >= 0) if total_valid_neighbors_counted != len(compact_neighbors_data): print( f"Error: Mismatch between counted valid neighbors ({total_valid_neighbors_counted}) and final compact_data size ({len(compact_neighbors_data)})!", file=sys.stderr, ) valid_check_passed = False if expected_valid_count != len(compact_neighbors_data): print( f"Error: Mismatch between NumPy count of valid neighbors ({expected_valid_count}) and final compact_data size ({len(compact_neighbors_data)})!", file=sys.stderr, ) valid_check_passed = False else: print(f" OK: Total valid neighbors = {len(compact_neighbors_data)}") # Check 2: Final pointer indices consistency print(" Checking final pointer indices...") if compact_node_offsets_np[ntotal] != len(compact_level_ptr): print( f"Error: Final node offset ({compact_node_offsets_np[ntotal]}) doesn't match level_ptr size ({len(compact_level_ptr)})!", file=sys.stderr, ) valid_check_passed = False if ( len(compact_level_ptr) > 0 and compact_level_ptr[-1] != len(compact_neighbors_data) ) or (len(compact_level_ptr) == 0 and len(compact_neighbors_data) != 0): last_ptr = compact_level_ptr[-1] if len(compact_level_ptr) > 0 else -1 print( f"Error: Last level pointer ({last_ptr}) doesn't match compact_data size ({len(compact_neighbors_data)})!", file=sys.stderr, ) valid_check_passed = False else: print(" OK: Final pointers match data size.") if not valid_check_passed: print( "Error: Validation checks failed. Output file might be incorrect.", file=sys.stderr, ) # Optional: Exit here if validation fails # return False # --- Explicitly delete large intermediate arrays --- print( f"[{time.time() - start_time:.2f}s] Deleting original neighbors and offsets arrays..." ) del neighbors_np del offsets_np gc.collect() print( f" CSR Stats: |data|={len(compact_neighbors_data)}, |level_ptr|={len(compact_level_ptr)}" ) # --- Write CSR HNSW graph data using unified function --- print( f"[{time.time() - start_time:.2f}s] Writing CSR HNSW graph data in FAISS-compatible order..." 
            # Determine storage fourcc and data based on prune_embeddings
            if prune_embeddings:
                print("  Pruning embeddings: Writing NULL storage marker.")
                output_storage_fourcc = NULL_INDEX_FOURCC
                storage_data = b""
            else:
                # Keep embeddings - read and preserve original storage data
                if storage_fourcc and storage_fourcc != NULL_INDEX_FOURCC:
                    print("  Preserving embeddings: Reading original storage data...")
                    storage_data = f_in.read()  # Read remaining storage data
                    output_storage_fourcc = storage_fourcc
                    print(f"  Read {len(storage_data)} bytes of storage data")
                else:
                    print("  No embeddings found in original file (NULL storage)")
                    output_storage_fourcc = NULL_INDEX_FOURCC
                    storage_data = b""

            # Use the unified write function
            write_compact_format(
                f_out,
                original_hnsw_data,
                assign_probas_np,
                cum_nneighbor_per_level_np,
                levels_np,
                compact_level_ptr,
                compact_node_offsets_np,
                compact_neighbors_data,
                output_storage_fourcc,
                storage_data,
            )

            # Clean up memory
            del assign_probas_np, cum_nneighbor_per_level_np, levels_np
            del compact_neighbors_data, compact_level_ptr, compact_node_offsets_np
            gc.collect()

        end_time = time.time()
        print(f"[{end_time - start_time:.2f}s] Conversion complete.")
        return True

    except FileNotFoundError:
        print(f"Error: Input file not found: {input_filename}", file=sys.stderr)
        return False
    except MemoryError as e:
        print(
            f"\nFatal MemoryError during conversion: {e}. Insufficient RAM.",
            file=sys.stderr,
        )
        # Clean up the potentially partially written output file
        try:
            os.remove(output_filename)
        except OSError:
            pass
        return False
    except EOFError as e:
        print(
            f"Error: Reached end of file unexpectedly reading {input_filename}. {e}",
            file=sys.stderr,
        )
        try:
            os.remove(output_filename)
        except OSError:
            pass
        return False
    except Exception as e:
        print(f"An unexpected error occurred during conversion: {e}", file=sys.stderr)
        import traceback

        traceback.print_exc()
        try:
            os.remove(output_filename)
        except OSError:
            pass
        return False
    # Ensure neighbors_np is deleted even if an error occurs after its allocation
    finally:
        try:
            if "neighbors_np" in locals() and neighbors_np is not None:
                del neighbors_np
                gc.collect()
        except NameError:
            pass


# --- Script Execution ---
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Convert a Faiss IndexHNSWFlat file to a CSR-based HNSW graph file."
    )
    parser.add_argument("input_index_file", help="Path to the input IndexHNSWFlat file")
    parser.add_argument(
        "output_csr_graph_file", help="Path to write the output CSR HNSW graph file"
    )
    parser.add_argument(
        "--prune-embeddings",
        action="store_true",
        default=True,
        help="Prune embedding storage (write NULL storage marker). On by default.",
    )
    parser.add_argument(
        "--keep-embeddings",
        action="store_true",
        help="Keep embedding storage (overrides --prune-embeddings)",
    )
    args = parser.parse_args()

    if not os.path.exists(args.input_index_file):
        print(f"Error: Input file not found: {args.input_index_file}", file=sys.stderr)
        sys.exit(1)
    if os.path.abspath(args.input_index_file) == os.path.abspath(
        args.output_csr_graph_file
    ):
        print("Error: Input and output filenames cannot be the same.", file=sys.stderr)
        sys.exit(1)

    prune_embeddings = args.prune_embeddings and not args.keep_embeddings
    success = convert_hnsw_graph_to_csr(
        args.input_index_file, args.output_csr_graph_file, prune_embeddings
    )
    if not success:
        sys.exit(1)
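# Example invocations (the script filename here is hypothetical):
#   python hnsw_to_csr.py index_hnsw_flat.faiss graph.csr
#   python hnsw_to_csr.py index_hnsw_flat.faiss graph.csr --keep-embeddings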