Initial commit
This commit is contained in:
1
research/utils/.gitignore
vendored
Normal file
1
research/utils/.gitignore
vendored
Normal file
@@ -0,0 +1 @@
|
||||
analyze_diskann_graph
|
||||
227
research/utils/analyze_diskann_graph.cpp
Normal file
227
research/utils/analyze_diskann_graph.cpp
Normal file
@@ -0,0 +1,227 @@
|
||||
#include <cassert>
|
||||
#include <cstdint>
|
||||
#include <cstring>
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
#include <limits>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
static const size_t DISKANN_SECTOR_LEN = 4096; // Typical sector size
|
||||
|
||||
// ! Use float as CoordT
|
||||
using CoordT = float;
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
if (argc < 3) {
|
||||
std::cerr << "Usage: " << argv[0]
|
||||
<< " <diskann_index_file> <output_degree_file>" << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
std::string disk_index_path = argv[1];
|
||||
std::string output_degree_path = argv[2];
|
||||
std::ifstream in(disk_index_path, std::ios::binary);
|
||||
if (!in.is_open()) {
|
||||
std::cerr << "Failed to open file: " << disk_index_path << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
// =========== 1) Read meta information (corresponds to
|
||||
// save_bin<uint64_t>(...,...,...,1,0)) ============== Read bin header:
|
||||
// (npts_i32, dim_i32)
|
||||
int32_t meta_count_i32 = 0, meta_dim_i32 = 0;
|
||||
in.read(reinterpret_cast<char *>(&meta_count_i32), sizeof(int32_t));
|
||||
in.read(reinterpret_cast<char *>(&meta_dim_i32), sizeof(int32_t));
|
||||
size_t meta_count = static_cast<size_t>(meta_count_i32);
|
||||
size_t meta_dim = static_cast<size_t>(meta_dim_i32);
|
||||
|
||||
// According to the diskann::save_bin writing method, here meta_dim is usually
|
||||
// 1
|
||||
std::cout << "[LOG] meta_count=" << meta_count << ", meta_dim=" << meta_dim
|
||||
<< std::endl;
|
||||
if (meta_dim != 1) {
|
||||
std::cerr << "[ERROR] meta_dim != 1,不符合 create_disk_layout 的写盘约定。"
|
||||
<< std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
// Read meta array
|
||||
std::vector<uint64_t> meta(meta_count);
|
||||
in.read(reinterpret_cast<char *>(meta.data()), meta_count * sizeof(uint64_t));
|
||||
if (!in.good()) {
|
||||
std::cerr << "[ERROR] Failed to read meta array, file is incomplete."
|
||||
<< std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
// meta[0..] Metadata
|
||||
// 0: npts_64, 1: ndims_64, 2: medoid, 3: max_node_len, 4: nnodes_per_sector,
|
||||
// 5: vamana_frozen_num, 6: vamana_frozen_loc, 7: append_reorder_data, ...
|
||||
const uint64_t npts_64 = meta[0];
|
||||
const uint64_t ndims_64 = meta[1];
|
||||
const uint64_t medoid = meta[2];
|
||||
const uint64_t max_node_len = meta[3];
|
||||
const uint64_t nnodes_per_sector = meta[4];
|
||||
const uint64_t vamana_frozen_num = meta[5];
|
||||
const uint64_t vamana_frozen_loc = meta[6];
|
||||
const uint64_t append_reorder_data = meta[7];
|
||||
|
||||
std::cout << "[LOG] npts_64=" << npts_64 << " ndims_64=" << ndims_64
|
||||
<< " max_node_len=" << max_node_len
|
||||
<< " nnodes_per_sector=" << nnodes_per_sector << std::endl;
|
||||
// If append_reorder_data==1, it means that reorder_data is appended at the
|
||||
// end of the index, but it does not affect the degree statistics, we can
|
||||
// ignore that part of the vector.
|
||||
|
||||
// =========== 2) Skip the first sector (all empty/placeholder information)
|
||||
// ==============
|
||||
in.seekg(DISKANN_SECTOR_LEN, std::ios::beg);
|
||||
if (!in.good()) {
|
||||
std::cerr << "[ERROR] Failed to seek to the first sector." << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
// =========== 3) Calculate the total number of sectors ==============
|
||||
// In create_disk_layout:
|
||||
// If nnodes_per_sector > 0, then n_sectors = ceil(npts_64 /
|
||||
// nnodes_per_sector) Otherwise nsectors_per_node = ceil(max_node_len /
|
||||
// 4096), n_sectors = nsectors_per_node * npts_64
|
||||
uint64_t n_sectors = 0;
|
||||
if (nnodes_per_sector > 0) {
|
||||
// Equivalent to Roundup(npts_64, nnodes_per_sector) / nnodes_per_sector
|
||||
n_sectors = (npts_64 + nnodes_per_sector - 1) / nnodes_per_sector;
|
||||
} else {
|
||||
// multi-sector per node
|
||||
uint64_t nsectors_per_node =
|
||||
(max_node_len + DISKANN_SECTOR_LEN - 1) / DISKANN_SECTOR_LEN;
|
||||
n_sectors = nsectors_per_node * npts_64;
|
||||
}
|
||||
std::cout << "[LOG] estimated #sectors storing adjacency = " << n_sectors
|
||||
<< std::endl;
|
||||
|
||||
// =========== 4) Read the degree of all nodes in order ==============
|
||||
// The memory layout of adjacency_count in each node: offset = ndims_64 *
|
||||
// sizeof(CoordT) This is followed by 4 bytes for the number of neighbors
|
||||
// uint32_t If you want to read the complete neighbor list, it is
|
||||
// adjacency_count * sizeof(uint32_t) But we only count the count
|
||||
std::vector<uint32_t> degrees(npts_64, 0); // Store the degree of each node
|
||||
size_t node_id = 0; // Current node number
|
||||
// Buffer for reading one sector at a time
|
||||
std::vector<char> sector_buf(DISKANN_SECTOR_LEN, 0);
|
||||
// If nnodes_per_sector>0, it means that one sector holds multiple nodes
|
||||
// Otherwise, one node occupies nsectors_per_node sectors
|
||||
if (nnodes_per_sector > 0) {
|
||||
// Read one sector at a time
|
||||
for (uint64_t s = 0; s < n_sectors; s++) {
|
||||
in.read((char *)sector_buf.data(), DISKANN_SECTOR_LEN);
|
||||
if (!in.good()) {
|
||||
if (node_id < npts_64) {
|
||||
std::cerr << "[ERROR] Failed to read sector " << s
|
||||
<< ", nodes not finished, file error or incomplete."
|
||||
<< std::endl;
|
||||
return -1;
|
||||
}
|
||||
break; // If all nodes are read, you can exit
|
||||
}
|
||||
// Parse each node in sector_buf
|
||||
for (uint64_t i = 0; i < nnodes_per_sector; i++) {
|
||||
if (node_id >= npts_64)
|
||||
break; // All node degrees have been obtained
|
||||
// The starting offset of the node in sector_buf
|
||||
size_t node_offset = i * max_node_len;
|
||||
// offset first skips ndims_64 * sizeof(CoordT)
|
||||
size_t degree_offset = node_offset + ndims_64 * sizeof(CoordT);
|
||||
// Ensure not out of bounds
|
||||
if (degree_offset + sizeof(uint32_t) > sector_buf.size()) {
|
||||
std::cerr << "[ERROR] 不应该发生: 读取degree越过了扇区边界."
|
||||
<< std::endl;
|
||||
return -1;
|
||||
}
|
||||
uint32_t deg = 0;
|
||||
memcpy(°, sector_buf.data() + degree_offset, sizeof(uint32_t));
|
||||
degrees[node_id] = deg;
|
||||
node_id++;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Each node occupies nsectors_per_node sectors
|
||||
uint64_t nsectors_per_node =
|
||||
(max_node_len + DISKANN_SECTOR_LEN - 1) / DISKANN_SECTOR_LEN;
|
||||
// Read each node
|
||||
for (uint64_t n = 0; n < npts_64; n++) {
|
||||
// Read multiple sectors into a multi-sector buffer
|
||||
std::vector<char> node_buf(nsectors_per_node * DISKANN_SECTOR_LEN, 0);
|
||||
in.read((char *)node_buf.data(), node_buf.size());
|
||||
if (!in.good()) {
|
||||
std::cerr << "[ERROR] Failed to read sector corresponding to node " << n
|
||||
<< ", file error or incomplete." << std::endl;
|
||||
return -1;
|
||||
}
|
||||
// Parse the degree in node_buf
|
||||
size_t degree_offset = ndims_64 * sizeof(CoordT);
|
||||
if (degree_offset + sizeof(uint32_t) > node_buf.size()) {
|
||||
std::cerr << "[ERROR] Should not happen: reading degree beyond node "
|
||||
"region."
|
||||
<< std::endl;
|
||||
return -1;
|
||||
}
|
||||
uint32_t deg = 0;
|
||||
memcpy(°, node_buf.data() + degree_offset, sizeof(uint32_t));
|
||||
degrees[n] = deg;
|
||||
}
|
||||
}
|
||||
|
||||
// We assert here: node_id should equal npts_64 (in multi-node mode)
|
||||
if (nnodes_per_sector > 0) {
|
||||
if (node_id != npts_64) {
|
||||
std::cerr << "[ERROR] Actually read " << node_id
|
||||
<< " nodes, but meta npts_64=" << npts_64
|
||||
<< ", file may be incorrect or parsing method is wrong."
|
||||
<< std::endl;
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
// =========== 5) Calculate min / max / average degree ==============
|
||||
uint64_t sum_deg = 0;
|
||||
uint32_t min_deg = std::numeric_limits<uint32_t>::max();
|
||||
uint32_t max_deg = 0;
|
||||
|
||||
for (uint64_t n = 0; n < npts_64; n++) {
|
||||
uint32_t d = degrees[n];
|
||||
sum_deg += d;
|
||||
if (d < min_deg)
|
||||
min_deg = d;
|
||||
if (d > max_deg)
|
||||
max_deg = d;
|
||||
}
|
||||
double avg_deg = (npts_64 == 0) ? 0.0 : double(sum_deg) / double(npts_64);
|
||||
|
||||
// =========== 6) Output results ==============
|
||||
std::cout << "DiskANN index file: " << disk_index_path << std::endl;
|
||||
std::cout << "Total points: " << npts_64 << std::endl;
|
||||
std::cout << "Min degree : " << min_deg << std::endl;
|
||||
std::cout << "Max degree : " << max_deg << std::endl;
|
||||
std::cout << "Avg degree : " << avg_deg << std::endl;
|
||||
|
||||
// =========== 7) Write degrees to output file ==============
|
||||
std::ofstream out_deg(output_degree_path);
|
||||
if (!out_deg.is_open()) {
|
||||
std::cerr << "[ERROR] Failed to open output file: " << output_degree_path
|
||||
<< std::endl;
|
||||
// Don't necessarily exit, maybe just warn? Depends on desired behavior.
|
||||
// For now, we continue closing the input file.
|
||||
} else {
|
||||
std::cout << "[LOG] Writing degrees to " << output_degree_path << "..."
|
||||
<< std::endl;
|
||||
for (uint64_t n = 0; n < npts_64; n++) {
|
||||
out_deg << degrees[n] << std::endl;
|
||||
}
|
||||
out_deg.close();
|
||||
std::cout << "[LOG] Finished writing degrees." << std::endl;
|
||||
}
|
||||
|
||||
in.close();
|
||||
return 0;
|
||||
}
|
||||
187
research/utils/analyze_top3_positions.py
Normal file
187
research/utils/analyze_top3_positions.py
Normal file
@@ -0,0 +1,187 @@
|
||||
import pandas as pd
|
||||
import matplotlib.pyplot as plt
|
||||
import seaborn as sns
|
||||
import numpy as np
|
||||
import os
|
||||
import re
|
||||
|
||||
# 设置风格
|
||||
plt.style.use('ggplot')
|
||||
sns.set(font_scale=1.2)
|
||||
|
||||
# 读取数据 - 修改为自定义读取逻辑
|
||||
log_file = './top3_positions_log.txt'
|
||||
|
||||
# 手动解析文件
|
||||
data = []
|
||||
header = None
|
||||
with open(log_file, 'r') as f:
|
||||
lines = f.readlines()
|
||||
header = lines[0].strip().split(',')
|
||||
|
||||
# 检查是否存在ThreadID列
|
||||
has_thread_id = 'ThreadID' in header
|
||||
|
||||
for line in lines[1:]:
|
||||
# 跳过非数据行,如"Search X results:"
|
||||
if 'results:' in line or not ',' in line:
|
||||
continue
|
||||
|
||||
# 分割并解析数据行
|
||||
parts = line.strip().split(',')
|
||||
|
||||
# 检查数据是否符合格式
|
||||
if len(parts) >= 7: # 至少需要7个字段
|
||||
# 对于旧格式(无ThreadID)的数据
|
||||
if not has_thread_id and len(parts) == 7:
|
||||
data.append([parts[0], 0, parts[1], parts[2], parts[3], parts[4], parts[5], parts[6]])
|
||||
# 对于新格式(有ThreadID)的数据
|
||||
elif has_thread_id and len(parts) == 8:
|
||||
data.append(parts)
|
||||
# 处理不一致的格式
|
||||
elif not has_thread_id and len(parts) == 8:
|
||||
# 假设第二列是ThreadID
|
||||
data.append(parts)
|
||||
if not has_thread_id:
|
||||
has_thread_id = True
|
||||
header.insert(1, 'ThreadID')
|
||||
|
||||
# 确保header正确
|
||||
if not has_thread_id:
|
||||
header.insert(1, 'ThreadID')
|
||||
|
||||
# 创建DataFrame并确保列名正确
|
||||
if len(header) == 8: # 确保有8列
|
||||
df = pd.DataFrame(data, columns=header)
|
||||
else:
|
||||
# 如果header不正确,则使用默认列名
|
||||
default_header = ['Search#', 'ThreadID', 'FullSetSize', 'Rank', 'ID', 'PQ_Rank', 'PQ_Distance', 'Exact_Distance']
|
||||
df = pd.DataFrame(data, columns=default_header)
|
||||
|
||||
# 转换数值列
|
||||
df['Search#'] = pd.to_numeric(df['Search#'], errors='coerce').fillna(0).astype(int)
|
||||
df['ThreadID'] = pd.to_numeric(df['ThreadID'], errors='coerce').fillna(0).astype(int)
|
||||
df['FullSetSize'] = pd.to_numeric(df['FullSetSize'], errors='coerce').fillna(0).astype(int)
|
||||
df['Rank'] = pd.to_numeric(df['Rank'], errors='coerce').fillna(0).astype(int)
|
||||
df['ID'] = pd.to_numeric(df['ID'], errors='coerce').fillna(0).astype(int)
|
||||
df['PQ_Rank'] = pd.to_numeric(df['PQ_Rank'], errors='coerce').fillna(0).astype(int)
|
||||
df['PQ_Distance'] = pd.to_numeric(df['PQ_Distance'], errors='coerce').fillna(0).astype(float)
|
||||
df['Exact_Distance'] = pd.to_numeric(df['Exact_Distance'], errors='coerce').fillna(0).astype(float)
|
||||
|
||||
print(f"读取了 {len(df)} 行数据")
|
||||
print(f"搜索次数: {df['Search#'].nunique()}")
|
||||
print(f"线程数: {df['ThreadID'].nunique()}")
|
||||
|
||||
# 提取前3名的结果
|
||||
top3_df = df[df['Rank'] <= 3].copy()
|
||||
|
||||
# 分析PQ Rank的分布
|
||||
pq_positions = []
|
||||
for rank in [1, 2, 3]:
|
||||
rank_df = top3_df[top3_df['Rank'] == rank]
|
||||
pq_positions.append(rank_df['PQ_Rank'].values)
|
||||
|
||||
# 创建结果目录
|
||||
result_dir = './analysis_results'
|
||||
os.makedirs(result_dir, exist_ok=True)
|
||||
|
||||
# 1. 箱型图:展示top-3结果在PQ排序中的位置分布
|
||||
plt.figure(figsize=(10, 6))
|
||||
box_data = [top3_df[top3_df['Rank'] == i]['PQ_Rank'].values for i in [1, 2, 3]]
|
||||
sns.boxplot(data=box_data)
|
||||
plt.xticks([0, 1, 2], ['Top 1', 'Top 2', 'Top 3'])
|
||||
plt.ylabel('PQ Rank Position')
|
||||
plt.title('Distribution of PQ Ranks for Top-3 Exact Results')
|
||||
plt.savefig(os.path.join(result_dir, 'pq_rank_boxplot.png'), dpi=300)
|
||||
|
||||
# 2. 直方图:每个排名在PQ结果中的位置分布
|
||||
fig, axs = plt.subplots(1, 3, figsize=(18, 6))
|
||||
for i, rank in enumerate([1, 2, 3]):
|
||||
rank_df = top3_df[top3_df['Rank'] == rank]
|
||||
sns.histplot(x=rank_df['PQ_Rank'].values, bins=20, ax=axs[i])
|
||||
axs[i].set_title(f'Exact Rank {rank}')
|
||||
axs[i].set_xlabel('PQ Rank')
|
||||
axs[i].set_ylabel('Frequency')
|
||||
plt.tight_layout()
|
||||
plt.savefig(os.path.join(result_dir, 'pq_rank_histogram.png'), dpi=300)
|
||||
|
||||
# 3. 热力图:PQ排名与精确排名的关系
|
||||
plt.figure(figsize=(10, 8))
|
||||
# 只关注Top 20的排名
|
||||
bins = list(range(0, 22))
|
||||
pq_rank_bins = pd.cut(top3_df['PQ_Rank'], bins=bins)
|
||||
heatmap_data = pd.crosstab(pq_rank_bins, top3_df['Rank'])
|
||||
sns.heatmap(heatmap_data, cmap='YlGnBu', annot=True, fmt='d')
|
||||
plt.title('Heatmap of Exact Rank vs PQ Rank (Top 20)')
|
||||
plt.xlabel('Exact Rank')
|
||||
plt.ylabel('PQ Rank Range')
|
||||
plt.savefig(os.path.join(result_dir, 'rank_heatmap.png'), dpi=300)
|
||||
|
||||
# 4. 散点图:比较PQ距离和精确距离的关系
|
||||
plt.figure(figsize=(10, 8))
|
||||
sns.scatterplot(x=top3_df['Exact_Distance'], y=top3_df['PQ_Distance'], hue=top3_df['Rank'], palette='viridis')
|
||||
plt.title('PQ Distance vs Exact Distance')
|
||||
plt.xlabel('Exact Distance')
|
||||
plt.ylabel('PQ Distance')
|
||||
plt.legend(title='Exact Rank')
|
||||
# 添加对角线表示完美匹配
|
||||
min_val = min(top3_df['Exact_Distance'].min(), top3_df['PQ_Distance'].min())
|
||||
max_val = max(top3_df['Exact_Distance'].max(), top3_df['PQ_Distance'].max())
|
||||
plt.plot([min_val, max_val], [min_val, max_val], 'r--', alpha=0.5)
|
||||
plt.savefig(os.path.join(result_dir, 'distance_scatter.png'), dpi=300)
|
||||
|
||||
# 5. 折线图:PQ Rank随结果集大小的变化
|
||||
plt.figure(figsize=(12, 6))
|
||||
size_grouped = top3_df.groupby(['FullSetSize', 'Rank'])['PQ_Rank'].mean().reset_index()
|
||||
for rank in [1, 2, 3]:
|
||||
rank_data = size_grouped[size_grouped['Rank'] == rank]
|
||||
plt.plot(rank_data['FullSetSize'], rank_data['PQ_Rank'], marker='o', label=f'Rank {rank}')
|
||||
plt.xlabel('Result Set Size')
|
||||
plt.ylabel('Average PQ Rank')
|
||||
plt.title('Average PQ Rank by Result Set Size')
|
||||
plt.legend()
|
||||
plt.grid(True)
|
||||
plt.savefig(os.path.join(result_dir, 'pq_rank_by_size.png'), dpi=300)
|
||||
|
||||
# 6. 百分比热力图:在PQ排名前K的概率
|
||||
top_k_values = [1, 5, 10, 20, 50, 100, 200, 300, 500, 700, 800, 900]
|
||||
top_k_probs = []
|
||||
|
||||
for rank in [1, 2, 3]:
|
||||
rank_df = top3_df[top3_df['Rank'] == rank]
|
||||
probs = []
|
||||
for k in top_k_values:
|
||||
prob = (rank_df['PQ_Rank'] <= k).mean() * 100
|
||||
probs.append(prob)
|
||||
top_k_probs.append(probs)
|
||||
|
||||
plt.figure(figsize=(10, 6))
|
||||
sns.heatmap(top_k_probs, annot=True, fmt='.1f', cmap='YlGnBu',
|
||||
xticklabels=[f'Top-{k}' for k in top_k_values],
|
||||
yticklabels=['Rank 1', 'Rank 2', 'Rank 3'])
|
||||
plt.title('Probability (%) of Finding Exact Top-K Results in PQ Top-K')
|
||||
plt.xlabel('PQ Top-K')
|
||||
plt.ylabel('Exact Rank')
|
||||
plt.savefig(os.path.join(result_dir, 'topk_probability.png'), dpi=300)
|
||||
|
||||
# 7. 生成统计摘要报告
|
||||
with open(os.path.join(result_dir, 'summary_report.txt'), 'w') as f:
|
||||
f.write(f"数据分析摘要\n")
|
||||
f.write(f"=================\n")
|
||||
f.write(f"总搜索次数: {df['Search#'].nunique()}\n")
|
||||
f.write(f"使用线程数: {df['ThreadID'].nunique()}\n\n")
|
||||
|
||||
f.write("精确排名前3的结果在PQ排序中的平均位置:\n")
|
||||
for rank in [1, 2, 3]:
|
||||
avg_pq_rank = top3_df[top3_df['Rank'] == rank]['PQ_Rank'].mean()
|
||||
median_pq_rank = top3_df[top3_df['Rank'] == rank]['PQ_Rank'].median()
|
||||
f.write(f" 排名 {rank}: 平均位置 = {avg_pq_rank:.2f}, 中位数位置 = {median_pq_rank:.1f}\n")
|
||||
|
||||
f.write("\n各排名结果在PQ排序前K的命中率:\n")
|
||||
for rank in [1, 2, 3]:
|
||||
f.write(f" 排名 {rank}:\n")
|
||||
for k in top_k_values:
|
||||
hit_rate = (top3_df[top3_df['Rank'] == rank]['PQ_Rank'] <= k).mean() * 100
|
||||
f.write(f" 在PQ前 {k} 中的命中率: {hit_rate:.2f}%\n")
|
||||
|
||||
print(f"分析完成! 结果已保存到 {result_dir} 目录")
|
||||
137
research/utils/debug_embedd.py
Normal file
137
research/utils/debug_embedd.py
Normal file
@@ -0,0 +1,137 @@
|
||||
import os
|
||||
import torch
|
||||
import numpy as np
|
||||
import argparse
|
||||
from tqdm import tqdm
|
||||
import json
|
||||
from contriever.src.contriever import load_retriever
|
||||
|
||||
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
|
||||
os.environ["OMP_NUM_THREADS"] = "1"
|
||||
os.environ["KMP_BLOCKTIME"] = "0"
|
||||
|
||||
torch.set_num_threads(1)
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
def embed_queries(queries, model, tokenizer, model_name_or_path, per_gpu_batch_size=64):
|
||||
"""Embed queries using the model with batching"""
|
||||
model = model.half()
|
||||
model.eval()
|
||||
embeddings = []
|
||||
batch_question = []
|
||||
|
||||
with torch.no_grad():
|
||||
for k, query in tqdm(enumerate(queries), desc="Encoding queries"):
|
||||
batch_question.append(query)
|
||||
|
||||
# Process when batch is full or at the end
|
||||
if len(batch_question) == per_gpu_batch_size or k == len(queries) - 1:
|
||||
encoded_batch = tokenizer.batch_encode_plus(
|
||||
batch_question,
|
||||
return_tensors="pt",
|
||||
max_length=512,
|
||||
padding=True,
|
||||
truncation=True,
|
||||
)
|
||||
|
||||
encoded_batch = {k: v.to(device) for k, v in encoded_batch.items()}
|
||||
output = model(**encoded_batch)
|
||||
|
||||
# Contriever typically uses output.last_hidden_state pooling or something specialized
|
||||
# if "contriever" not in model_name_or_path:
|
||||
# output = output.last_hidden_state[:, 0, :]
|
||||
|
||||
embeddings.append(output.cpu())
|
||||
batch_question = [] # Reset batch
|
||||
|
||||
embeddings = torch.cat(embeddings, dim=0).numpy()
|
||||
print(f"Query embeddings shape: {embeddings.shape}")
|
||||
return embeddings
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Debug embedding tool")
|
||||
parser.add_argument("--model", type=str, default="Alibaba-NLP/gte-Qwen2-1.5B-instruct",
|
||||
help="Model name for embedding (default: facebook/contriever-msmarco)")
|
||||
parser.add_argument("--batch-size", type=int, default=32,
|
||||
help="Batch size for encoding (default: 32)")
|
||||
parser.add_argument("--input-file", type=str,
|
||||
help="Input file with queries (JSON lines format with 'query' field)")
|
||||
parser.add_argument("--output-file", type=str, default="embeddings.npy",
|
||||
help="Output file to save embeddings (default: embeddings.npy)")
|
||||
parser.add_argument("--text", type=str, nargs="+",
|
||||
help="Direct text input to embed (can provide multiple)")
|
||||
parser.add_argument("--save-text", action="store_true",
|
||||
help="Save the input text alongside embeddings")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Load model
|
||||
print(f"Loading query encoder: {args.model}")
|
||||
query_encoder, query_tokenizer, _ = load_retriever(args.model)
|
||||
query_encoder = query_encoder.to(device)
|
||||
query_encoder.eval()
|
||||
|
||||
# Get queries
|
||||
queries = []
|
||||
|
||||
# From file if provided
|
||||
if args.input_file:
|
||||
print(f"Loading queries from: {args.input_file}")
|
||||
with open(args.input_file, "r") as f:
|
||||
for line in f:
|
||||
data = json.loads(line)
|
||||
queries.append(data["query"])
|
||||
|
||||
# From command line if provided
|
||||
if args.text:
|
||||
print(f"Using {len(args.text)} queries from command line")
|
||||
queries.extend(args.text)
|
||||
|
||||
# If no queries, use some examples
|
||||
if not queries:
|
||||
print("No queries provided, using example queries")
|
||||
queries = [
|
||||
"Were there any variances detected for hour 6 on 3/9/01?"
|
||||
]
|
||||
|
||||
print(f"Embedding {len(queries)} queries")
|
||||
for i, q in enumerate(queries[:5]): # Print first 5 queries
|
||||
print(f"Query {i+1}: {q}")
|
||||
if len(queries) > 5:
|
||||
print(f"... and {len(queries)-5} more")
|
||||
|
||||
# Encode queries
|
||||
embeddings = embed_queries(
|
||||
queries, query_encoder, query_tokenizer, args.model, per_gpu_batch_size=args.batch_size
|
||||
)
|
||||
|
||||
|
||||
passages = [
|
||||
"Start Date: 3/9/01; HourAhead hour: 6; No ancillary schedules awarded. Variances detected. Variances detected in Generation schedule. Variances detected in Energy Import/Export schedule. LOG MESSAGES: PARSING FILE -->> O:\\Portland\\WestDesk\\California Scheduling\\ISO Final Schedules\\2001030906.txt ---- Generation Schedule ---- $$$ Variance found in table tblGEN_SCHEDULE. Details: (Hour: 6 / Preferred: 20.00 / Final: 19.80) TRANS_TYPE: FINAL SC_ID: TOSCO MKT_TYPE: 2 TRANS_DATE: 3/9/01 UNIT_ID: UNCHEM_1_UNIT $$$ Variance found in table tblGEN_SCHEDULE. Details: (Hour: 6 / Preferred: 29.00 / Final: 28.20) TRANS_TYPE: FINAL SC_ID: ARCO MKT_TYPE: 2 TRANS_DATE: 3/9/01 UNIT_ID: CARBGN_6_UNIT 1 $$$ Variance found in table tblGEN_SCHEDULE. Details: (Hour: 6 / Preferred: 45.00 / Final: 43.80) TRANS_TYPE: FINAL SC_ID: DELANO MKT_TYPE: 2 TRANS_DATE: 3/9/01 UNIT_ID: PANDOL_6_UNIT $$$ Variance found in table tblGEN_SCHEDULE. Details: (Hour: 6 / Preferred: 13.00 / Final: 12.60) TRANS_TYPE: FINAL SC_ID: Wheelabrat MKT_TYPE: 2 TRANS_DATE: 3/9/01 UNIT_ID: MARTEL_2_AMFOR ---- Energy Import/Export Schedule ---- $$$ Variance found in table tblINTCHG_IMPEXP. Details: (Hour: 6 / Preferred: 62.00 / Final: 60.40) TRANS_TYPE: FINAL SC_ID: ECTstCA MKT_TYPE: 2 TRANS_DATE: 3/9/01 TIE_POINT: PVERDE_5_DEVERS INTERCHG_ID: EPMI_CISO_5001 ENGY_TYPE: FIRM $$$ Variance found in table tblINTCHG_IMPEXP. Details: (Hour: 6 / Preferred: 63.00 / Final: 61.23) TRANS_TYPE: FINAL SC_ID: ECTstSW MKT_TYPE: 2 TRANS_DATE: 3/9/01 TIE_POINT: PVERDE_5_DEVERS INTERCHG_ID: EPMI_CISO_5000 ENGY_TYPE: FIRM $$$ Variance found in table tblINTCHG_IMPEXP. Details: (Hour: 6 / Preferred: 17.00 / Final: 11.00) TRANS_TYPE: FINAL SC_ID: ECTRT MKT_TYPE: 2 TRANS_DATE: 3/9/01 TIE_POINT: SYLMAR_2_NOB INTERCHG_ID: EPMI_CISO_LUCKY ENGY_TYPE: NFRM",
|
||||
"Start Date: 3/30/01; HourAhead hour: 15; No ancillary schedules awarded. Variances detected. Variances detected in Generation schedule. LOG MESSAGES: PARSING FILE -->> O:\\Portland\\WestDesk\\California Scheduling\\ISO Final Schedules\\2001033015.txt ---- Generation Schedule ---- $$$ Variance found in table tblGEN_SCHEDULE. Details: (Hour: 15 / Preferred: 0.00 / Final: 0.00) TRANS_TYPE: FINAL SC_ID: ARCO MKT_TYPE: 2 TRANS_DATE: 3/30/01 UNIT_ID: CARBGN_6_UNIT 1 $$$ Variance found in table tblGEN_SCHEDULE. Details: (Hour: 15 / Preferred: 45.00 / Final: 44.00) TRANS_TYPE: FINAL SC_ID: DELANO MKT_TYPE: 2 TRANS_DATE: 3/30/01 UNIT_ID: PANDOL_6_UNIT"
|
||||
]
|
||||
|
||||
# Embed passages
|
||||
passage_embeddings = embed_queries(passages, query_encoder, query_tokenizer, args.model, per_gpu_batch_size=args.batch_size)
|
||||
|
||||
|
||||
# distance with passages 0 and query
|
||||
distance_0 = np.linalg.norm(embeddings[0] - passage_embeddings[0])
|
||||
print(f"Distance between query 0 and passage 0: {distance_0}")
|
||||
|
||||
# distance with passages 1 and query
|
||||
distance_1 = np.linalg.norm(embeddings[0] - passage_embeddings[1])
|
||||
print(f"Distance between query 0 and passage 1: {distance_1}")
|
||||
|
||||
# print which one is closer
|
||||
if distance_0 < distance_1:
|
||||
print("Query 0 is closer to passage 0")
|
||||
else:
|
||||
print("Query 0 is closer to passage 1")
|
||||
|
||||
|
||||
|
||||
print("Done!")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
33
research/utils/dedup_eval_data.py
Executable file
33
research/utils/dedup_eval_data.py
Executable file
@@ -0,0 +1,33 @@
|
||||
import json
|
||||
import os
|
||||
|
||||
|
||||
input_file = "/gscratch/zlab/rulins/data/lm-eval-data/raw_mmlu.jsonl"
|
||||
output_file = "/gscratch/zlab/rulins/data/lm-eval-data/mmlu.jsonl"
|
||||
|
||||
|
||||
raw_data = []
|
||||
|
||||
with open(input_file, "r") as fin:
|
||||
for line in fin:
|
||||
raw_data.append(json.loads(line))
|
||||
|
||||
|
||||
def deduplicate_dicts(dict_list):
|
||||
unique_dicts = set()
|
||||
unique_items = []
|
||||
for item in dict_list:
|
||||
# Make a hashable version of the dictionary by sorting it
|
||||
hashable_item = tuple(sorted(item.items()))
|
||||
if hashable_item not in unique_dicts:
|
||||
unique_dicts.add(hashable_item)
|
||||
unique_items.append(item)
|
||||
return unique_items
|
||||
|
||||
|
||||
unique_data = deduplicate_dicts(raw_data)
|
||||
print(len(unique_data))
|
||||
|
||||
with open(output_file, "w") as fout:
|
||||
for ex in unique_data:
|
||||
fout.write(json.dumps(ex) + "\n")
|
||||
167
research/utils/deduplication.py
Executable file
167
research/utils/deduplication.py
Executable file
@@ -0,0 +1,167 @@
|
||||
import time
|
||||
import multiprocessing
|
||||
from datasketch import MinHash, MinHashLSH
|
||||
|
||||
|
||||
def shingle_document(text, shingle_size=13):
|
||||
"""Generate word-based shingles from a document."""
|
||||
# Split the text into words
|
||||
words = text.split()
|
||||
# Generate shingles that are sequences of 'shingle_size' consecutive words
|
||||
shingles = set(
|
||||
" ".join(words[i : i + shingle_size])
|
||||
for i in range(len(words) - shingle_size + 1)
|
||||
)
|
||||
return shingles
|
||||
|
||||
|
||||
m = MinHash(num_perm=128)
|
||||
perm = m.permutations
|
||||
|
||||
|
||||
def create_minhash(shingles, num_perm=128):
|
||||
"""Create a MinHash object from the set of shingles."""
|
||||
m = MinHash(permutations=perm)
|
||||
m.update_batch(map(lambda x: x.encode("utf-8"), shingles))
|
||||
# for shingle in shingles:
|
||||
# m.update(shingle.encode('utf-8'))
|
||||
return m
|
||||
|
||||
|
||||
def abstein_string_for_decon(string):
|
||||
# Abstein the reading comprehension subject in MMLU where a paragraph from Wikipedia is given in the question
|
||||
return "refers to the following information" in string
|
||||
|
||||
|
||||
def remove_duplicates_with_minhash(
|
||||
documents, string_for_decontamination=None, threshold=0.8, num_perm=128
|
||||
):
|
||||
# Apply 13-gram Jaccard similarity deduplication and removes ones with similarity > 80% compared to former docs.
|
||||
# Remove chunks shorter than 13 words.
|
||||
|
||||
# Create an LSH index
|
||||
lsh = MinHashLSH(threshold=threshold, num_perm=num_perm)
|
||||
|
||||
# Dictionary to store the MinHash of each document
|
||||
minhashes = {}
|
||||
|
||||
# Hash string for decontamination first so contaminated samples will be removed
|
||||
decon_offset = 0
|
||||
if string_for_decontamination is not None and not abstein_string_for_decon(
|
||||
string_for_decontamination
|
||||
):
|
||||
shingles = shingle_document(string_for_decontamination)
|
||||
m_decon = create_minhash(shingles, num_perm)
|
||||
lsh.insert(f"doc_{decon_offset}", m_decon)
|
||||
minhashes[decon_offset] = m_decon
|
||||
decon_offset = 1
|
||||
|
||||
# Populate the LSH index
|
||||
short_chunk_indices = []
|
||||
for idx, ctx in enumerate(documents, start=decon_offset):
|
||||
doc = ctx["retrieval text"]
|
||||
shingles = shingle_document(doc)
|
||||
if not shingles:
|
||||
short_chunk_indices.append(idx - decon_offset)
|
||||
m = create_minhash(shingles, num_perm)
|
||||
lsh.insert(f"doc_{idx}", m)
|
||||
minhashes[idx] = m
|
||||
|
||||
# List to keep track of non-duplicate document indices
|
||||
non_duplicate_indices = []
|
||||
|
||||
# Check each document against the LSH index
|
||||
for idx, m in minhashes.items():
|
||||
if idx < decon_offset:
|
||||
continue
|
||||
|
||||
# Query the LSH for near-duplicate candidates
|
||||
result = lsh.query(m)
|
||||
|
||||
# print(result)
|
||||
# print([minhashes[int(doc_id.split("_")[1])].jaccard(m) for doc_id in result])
|
||||
|
||||
# If the document is the only one in its bucket or it appears first in the list
|
||||
if all(
|
||||
minhashes[int(doc_id.split("_")[1])].jaccard(m) <= threshold
|
||||
or int(doc_id.split("_")[1]) >= idx
|
||||
for doc_id in result
|
||||
):
|
||||
non_duplicate_indices.append(idx - decon_offset)
|
||||
|
||||
# Return non-duplicate documents
|
||||
deduplicated_documents = [
|
||||
documents[i] for i in non_duplicate_indices if i not in short_chunk_indices
|
||||
]
|
||||
[doc.update({"quality score": 1}) for doc in deduplicated_documents]
|
||||
removed_documents = [doc for doc in documents if doc not in deduplicated_documents]
|
||||
[doc.update({"quality score": 0}) for doc in removed_documents]
|
||||
|
||||
print(f"Non-deduplication ctxs num: {len(deduplicated_documents)}")
|
||||
# for c in deduplicated_documents:
|
||||
# try:
|
||||
# print(c['retrieval text'][:10])
|
||||
# except:
|
||||
# print(c)
|
||||
# if len(deduplicated_documents[0]['retrieval text'].split(' ')) < 13:
|
||||
# import pdb; pdb.set_trace()
|
||||
return deduplicated_documents # + removed_documents
|
||||
|
||||
|
||||
def process_item(data_item):
|
||||
time.sleep(0.0001)
|
||||
id_, ex = data_item
|
||||
ex["ctxs"] = remove_duplicates_with_minhash(
|
||||
ex["ctxs"], string_for_decontamination=ex["raw_query"]
|
||||
)
|
||||
return id_, ex
|
||||
|
||||
|
||||
def multiprocess_deduplication(data):
|
||||
items_to_process = list(enumerate(data))
|
||||
pool = multiprocessing.Pool(processes=32)
|
||||
for result in pool.imap(process_item, items_to_process):
|
||||
id_, updated_ex = result
|
||||
data[id_] = updated_ex
|
||||
return data
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Example usage:
|
||||
question = (
|
||||
"Answer these questions:\n\nQ: when did the eagles win last super bowl?\nA:"
|
||||
)
|
||||
docs = [
|
||||
"Eagles won the Super Bowl.",
|
||||
"Machine learning provides the ability to automatically learn and improve from experience without being explicitly programmed."
|
||||
* 20,
|
||||
"Machine learning provides the ability to automatically learn and improve from experience without being explicitly programmed."
|
||||
* 20
|
||||
+ ".",
|
||||
"An entirely different document looks nothing like the others and should not be considered a duplicate."
|
||||
* 20,
|
||||
"Short sentence." * 1,
|
||||
"As someone who lived in Philly for about five years, I agree about the city\u2019s greatness \u2014 which makes the juxtaposition between its friendly day-to-day interactions and sometimes psychotic sports fandom even more jarring. The Eagles did win three NFL championships before the Super Bowl existed, most recently in 1960. But any fan who was following the team back then is now at least into their mid-60s, if not much older. It is, to say the least, a distant memory from another era. Granted, the Sixers went on their infamous tanking expedition during this span.",
|
||||
] * 1
|
||||
import time
|
||||
|
||||
num_ex = 1
|
||||
|
||||
start = time.time()
|
||||
data1 = []
|
||||
for _ in range(num_ex):
|
||||
cleaned_ex = remove_duplicates_with_minhash(
|
||||
[{"retrieval text": doc} for doc in docs], question
|
||||
)
|
||||
data1.append(cleaned_ex)
|
||||
time1 = time.time() - start
|
||||
|
||||
# ori_data = [{'raw_query': docs[0], 'ctxs': [{'retrieval text': doc} for doc in docs]}] * num_ex
|
||||
# start = time.time()
|
||||
# data2 = multiprocess_deduplication(ori_data)
|
||||
# time2 = time.time()-start
|
||||
|
||||
# assert data2[0]['ctxs'] == data1[0]
|
||||
|
||||
# print(time1)
|
||||
# print(time2)
|
||||
387
research/utils/demo_reader.cpp
Normal file
387
research/utils/demo_reader.cpp
Normal file
@@ -0,0 +1,387 @@
|
||||
/*
|
||||
Run with
|
||||
g++ ./demo_reader.cpp -o ./demo_reader && ./demo_reader --stats \
|
||||
/powerrag/scaling_out/indices/rpj_wiki/facebook/contriever-msmarco/diskann/_partition.bin
|
||||
\
|
||||
/powerrag/scaling_out/indices/rpj_wiki/facebook/contriever-msmarco/diskann/_disk_graph.index
|
||||
*/
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstdint>
|
||||
#include <cstdlib>
|
||||
#include <cstring>
|
||||
#include <fstream>
|
||||
#include <iomanip>
|
||||
#include <iostream>
|
||||
#include <limits> // Include for std::numeric_limits
|
||||
#include <string> // Include for std::string comparison
|
||||
#include <vector>
|
||||
|
||||
#define READ_U64(f, val) \
|
||||
f.read(reinterpret_cast<char *>(&val), sizeof(uint64_t))
|
||||
#define READ_U32(f, val) \
|
||||
f.read(reinterpret_cast<char *>(&val), sizeof(uint32_t))
|
||||
#define SECTOR_SIZE 4096
|
||||
|
||||
// 辅助:获取文件大小
|
||||
static size_t get_file_size(const std::string &fname) {
|
||||
std::ifstream ifs(fname, std::ios::binary | std::ios::ate);
|
||||
if (ifs.fail() || !ifs.is_open()) {
|
||||
return 0;
|
||||
}
|
||||
return static_cast<size_t>(ifs.tellg());
|
||||
}
|
||||
|
||||
// 打印 sector 的前若干 hex,用于debug
|
||||
static void print_hex(const char *buf, size_t len, size_t max_len = 64) {
|
||||
size_t show_len = (len < max_len) ? len : max_len;
|
||||
for (size_t i = 0; i < show_len; i++) {
|
||||
unsigned char c = (unsigned char)buf[i];
|
||||
std::cout << std::hex << std::setw(2) << std::setfill('0') << (unsigned)c
|
||||
<< " ";
|
||||
if ((i + 1) % 16 == 0)
|
||||
std::cout << "\n ";
|
||||
}
|
||||
std::cout << std::dec << "\n";
|
||||
}
|
||||
|
||||
/*
|
||||
修正后的 demo_reader:
|
||||
1) 从 partition.bin 读:
|
||||
- C, partition_nums, nd
|
||||
- graph_partitions[i]: 分区 i 的所有 nodeID
|
||||
- id2partition[nodeID]: nodeID => partition i
|
||||
2) 从 _disk_graph.index 读:
|
||||
a) sector0 里先有 2个 int: meta_n, meta_dim
|
||||
b) 再有 meta_n个 uint64_t
|
||||
例如: [0]=nd, [1]=dim, [2]=??, [3]=max_node_len, [4]=C, [5]..??,
|
||||
[8]=file_size... 具体位置要结合 relayout 的写法 c) graph_node_len =
|
||||
max_node_len - dim_in_meta*sizeof(float) 3) 用户给定 target_node_id =>
|
||||
partition_id= id2partition[node_id]
|
||||
在 graph_partitions[partition_id] 里找 node 的下标 j
|
||||
offset = (partition_id+1)*4096 => sector
|
||||
adjacency_offset= j*graph_node_len => neighbor_count => neighbors
|
||||
*/
|
||||
int main(int argc, char **argv) {
|
||||
bool calculate_stats = false;
|
||||
// int arg_offset = 0; // Offset for positional arguments
|
||||
std::string partition_bin;
|
||||
std::string graph_index;
|
||||
uint64_t target_node_id = 0; // Initialize
|
||||
|
||||
if (argc != 4) {
|
||||
std::cerr << "Usage:\n"
|
||||
<< " " << argv[0]
|
||||
<< " <partition.bin> <disk_graph.index> <target_node_id> (Reads "
|
||||
"adjacency for a node)\n"
|
||||
<< " " << argv[0]
|
||||
<< " --stats <partition.bin> <disk_graph.index> "
|
||||
"(Calculates degree statistics)\n";
|
||||
return 1;
|
||||
}
|
||||
|
||||
// Check if the first argument is the stats flag
|
||||
if (std::string(argv[1]) == "--stats") {
|
||||
calculate_stats = true;
|
||||
partition_bin = argv[2];
|
||||
graph_index = argv[3];
|
||||
std::cout << "Mode: Calculating Degree Statistics\n";
|
||||
} else {
|
||||
// Assume default mode (single node lookup)
|
||||
calculate_stats = false;
|
||||
partition_bin = argv[1];
|
||||
graph_index = argv[2];
|
||||
try { // Add error handling for stoull
|
||||
target_node_id = std::stoull(argv[3]);
|
||||
} catch (const std::invalid_argument &ia) {
|
||||
std::cerr << "Error: Invalid target_node_id: " << argv[3] << std::endl;
|
||||
return 1;
|
||||
} catch (const std::out_of_range &oor) {
|
||||
std::cerr << "Error: target_node_id out of range: " << argv[3]
|
||||
<< std::endl;
|
||||
return 1;
|
||||
}
|
||||
std::cout << "Mode: Single Node Lookup for Node ID " << target_node_id
|
||||
<< "\n";
|
||||
}
|
||||
|
||||
// 1) 读取 partition.bin
|
||||
std::ifstream pf(partition_bin, std::ios::binary);
|
||||
if (!pf.is_open()) {
|
||||
std::cerr << "Cannot open partition.bin: " << partition_bin << std::endl;
|
||||
return 1;
|
||||
}
|
||||
uint64_t C, partition_nums, nd;
|
||||
READ_U64(pf, C);
|
||||
READ_U64(pf, partition_nums);
|
||||
READ_U64(pf, nd);
|
||||
std::cout << "[partition.bin header] C=" << C
|
||||
<< ", partition_nums=" << partition_nums << ", nd=" << nd
|
||||
<< std::endl;
|
||||
|
||||
// 读取分区节点列表
|
||||
std::vector<std::vector<uint32_t>> graph_partitions(partition_nums);
|
||||
for (uint64_t i = 0; i < partition_nums; i++) {
|
||||
uint32_t psize;
|
||||
READ_U32(pf, psize);
|
||||
graph_partitions[i].resize(psize);
|
||||
pf.read(reinterpret_cast<char *>(graph_partitions[i].data()),
|
||||
psize * sizeof(uint32_t));
|
||||
}
|
||||
// 读取 _id2partition[node], 大小= nd
|
||||
std::vector<uint32_t> id2partition(nd);
|
||||
pf.read(reinterpret_cast<char *>(id2partition.data()), nd * sizeof(uint32_t));
|
||||
pf.close();
|
||||
std::cout << "Done loading partition info.\n";
|
||||
|
||||
if (target_node_id >= nd) {
|
||||
std::cerr << "target_node_id=" << target_node_id
|
||||
<< " out of range nd=" << nd << std::endl;
|
||||
return 1;
|
||||
}
|
||||
|
||||
// 2) 解析 _disk_graph.index
|
||||
std::ifstream gf(graph_index, std::ios::binary);
|
||||
if (!gf.is_open()) {
|
||||
std::cerr << "Cannot open disk_graph.index: " << graph_index << std::endl;
|
||||
return 1;
|
||||
}
|
||||
// (a) sector0 => 先读 2个 int
|
||||
int meta_n, meta_dim;
|
||||
gf.read((char *)&meta_n, sizeof(int));
|
||||
gf.read((char *)&meta_dim, sizeof(int));
|
||||
std::cout << "[debug] meta_n=" << meta_n << ", meta_dim=" << meta_dim << "\n";
|
||||
|
||||
// (b) 读 meta_n个 uint64_t
|
||||
std::vector<uint64_t> meta_info(meta_n);
|
||||
gf.read(reinterpret_cast<char *>(meta_info.data()),
|
||||
meta_n * sizeof(uint64_t));
|
||||
// 打印
|
||||
for (int i = 0; i < meta_n; i++) {
|
||||
std::cout << " meta_info[" << i << "]= " << meta_info[i] << "\n";
|
||||
}
|
||||
|
||||
size_t file_size = get_file_size(graph_index);
|
||||
std::cout << "[disk_graph.index size] " << file_size << " bytes\n";
|
||||
|
||||
// **根据 relayout log** 你说: meta_info[0]=nd=60450220, meta_info[1]=dim=769,
|
||||
// meta_info[2]=??(16495248?), meta_info[3]=max_node_len=3320,
|
||||
// meta_info[4]=16 (C),
|
||||
// meta_info[8]= 15475261440(文件大小)
|
||||
// 我们这里先手动解析:
|
||||
uint64_t nd_in_meta = meta_info[0];
|
||||
uint64_t dim_in_meta = meta_info[1];
|
||||
uint64_t max_node_len = meta_info[3];
|
||||
uint64_t c_in_meta = meta_info[4];
|
||||
uint64_t entire_file_sz = meta_info[8];
|
||||
|
||||
std::cout << "Based on meta_info:\n"
|
||||
<< " nd_in_meta= " << nd_in_meta
|
||||
<< ", dim_in_meta= " << dim_in_meta
|
||||
<< ", max_node_len= " << max_node_len
|
||||
<< ", c_in_meta= " << c_in_meta
|
||||
<< ", entire_file_size= " << entire_file_sz << "\n";
|
||||
|
||||
// 计算 graph_node_len
|
||||
uint64_t dim_size = dim_in_meta * sizeof(float);
|
||||
uint64_t graph_node_len = max_node_len - dim_size;
|
||||
std::cout << " => graph_node_len= " << graph_node_len << "\n\n";
|
||||
|
||||
if (calculate_stats) {
|
||||
// --- Degree Statistics Calculation Mode ---
|
||||
std::cout << " Calculated graph_node_len = " << graph_node_len << "\n\n";
|
||||
|
||||
if (nd == 0) {
|
||||
std::cerr << "Graph has 0 nodes (nd=0). Cannot calculate stats."
|
||||
<< std::endl;
|
||||
gf.close();
|
||||
return 1;
|
||||
}
|
||||
|
||||
uint32_t min_degree = std::numeric_limits<uint32_t>::max();
|
||||
uint32_t max_degree = 0;
|
||||
uint64_t total_degree = 0;
|
||||
uint64_t nodes_processed = 0;
|
||||
std::vector<char> sectorBuf(SECTOR_SIZE);
|
||||
|
||||
std::cout << "Calculating degrees for " << nd << " nodes across "
|
||||
<< partition_nums << " partitions..." << std::endl;
|
||||
|
||||
for (uint32_t p = 0; p < partition_nums; ++p) {
|
||||
uint64_t sector_offset = uint64_t(p + 1) * SECTOR_SIZE;
|
||||
gf.seekg(sector_offset, std::ios::beg);
|
||||
if (gf.fail()) {
|
||||
std::cerr << "Error seeking to sector offset for partition " << p
|
||||
<< std::endl;
|
||||
gf.close();
|
||||
return 1;
|
||||
}
|
||||
gf.read(sectorBuf.data(), SECTOR_SIZE);
|
||||
if (gf.fail() && !gf.eof()) {
|
||||
std::cerr << "Error reading sector data for partition " << p
|
||||
<< std::endl;
|
||||
gf.close();
|
||||
return 1;
|
||||
}
|
||||
gf.clear(); // Reset fail bits
|
||||
|
||||
const auto &part_list = graph_partitions[p];
|
||||
for (size_t j = 0; j < part_list.size(); ++j) {
|
||||
uint64_t node_offset = j * graph_node_len;
|
||||
if (node_offset + sizeof(uint32_t) > SECTOR_SIZE) {
|
||||
std::cerr << "Error: Node offset out of sector bounds.\n"
|
||||
<< " Partition=" << p << ", node_subIndex=" << j
|
||||
<< ", node_offset=" << node_offset
|
||||
<< ", graph_node_len=" << graph_node_len << std::endl;
|
||||
gf.close();
|
||||
return 1;
|
||||
}
|
||||
char *adjacency_ptr = sectorBuf.data() + node_offset;
|
||||
uint32_t neighbor_count = *reinterpret_cast<uint32_t *>(adjacency_ptr);
|
||||
min_degree = std::min(min_degree, neighbor_count);
|
||||
max_degree = std::max(max_degree, neighbor_count);
|
||||
total_degree += neighbor_count;
|
||||
nodes_processed++;
|
||||
}
|
||||
if (p % 10 == 0 || p == partition_nums - 1) {
|
||||
std::cout << " Processed partition " << p + 1 << " / "
|
||||
<< partition_nums << "...\r" << std::flush;
|
||||
}
|
||||
}
|
||||
std::cout << "\nFinished processing partitions." << std::endl;
|
||||
|
||||
if (nodes_processed != nd) {
|
||||
std::cerr << "Warning: Processed " << nodes_processed
|
||||
<< " nodes, but expected " << nd << std::endl;
|
||||
}
|
||||
|
||||
double avg_degree = (nd > 0) ? static_cast<double>(total_degree) / nd : 0.0;
|
||||
std::cout << "\n--- Degree Statistics ---\n";
|
||||
std::cout << "Min Degree: "
|
||||
<< (min_degree == std::numeric_limits<uint32_t>::max()
|
||||
? 0
|
||||
: min_degree)
|
||||
<< std::endl; // Handle case of 0 nodes
|
||||
std::cout << "Max Degree: " << max_degree << std::endl;
|
||||
std::cout << "Avg Degree: " << std::fixed << std::setprecision(2)
|
||||
<< avg_degree << std::endl;
|
||||
std::cout << "Total Degree (Sum): " << total_degree << std::endl;
|
||||
std::cout << "Nodes Processed: " << nodes_processed << std::endl;
|
||||
|
||||
} else {
|
||||
uint64_t nd_in_meta = meta_info[0];
|
||||
uint64_t c_in_meta = meta_info[4];
|
||||
uint64_t entire_file_sz = meta_info[8];
|
||||
std::cout << "Based on meta_info:\n"
|
||||
<< " nd_in_meta= " << nd_in_meta
|
||||
<< ", dim_in_meta= " << dim_in_meta
|
||||
<< ", max_node_len= " << max_node_len
|
||||
<< ", c_in_meta= " << c_in_meta
|
||||
<< ", entire_file_size= " << entire_file_sz << "\n";
|
||||
std::cout << " => graph_node_len= " << graph_node_len << "\n\n";
|
||||
|
||||
if (target_node_id >= nd) {
|
||||
std::cerr << "target_node_id=" << target_node_id
|
||||
<< " out of range nd=" << nd << std::endl;
|
||||
gf.close();
|
||||
return 1;
|
||||
}
|
||||
|
||||
// We need id2partition only for single-node lookup
|
||||
std::vector<uint32_t> id2partition(nd);
|
||||
{ // Read id2partition again as it was skipped before
|
||||
std::ifstream pf_again(partition_bin, std::ios::binary);
|
||||
uint64_t header_offset =
|
||||
3 * sizeof(uint64_t); // Skip C, partition_nums, nd
|
||||
uint64_t partition_list_offset = 0;
|
||||
for (uint64_t i = 0; i < partition_nums; i++) {
|
||||
partition_list_offset += sizeof(uint32_t); // Size field
|
||||
partition_list_offset +=
|
||||
graph_partitions[i].size() * sizeof(uint32_t); // Data
|
||||
}
|
||||
pf_again.seekg(header_offset + partition_list_offset, std::ios::beg);
|
||||
pf_again.read(reinterpret_cast<char *>(id2partition.data()),
|
||||
nd * sizeof(uint32_t));
|
||||
// Error check pf_again if needed
|
||||
}
|
||||
|
||||
// 3) 找 target_node_id => partition_id => subIndex
|
||||
uint32_t partition_id = id2partition[target_node_id];
|
||||
if (partition_id >= partition_nums) {
|
||||
std::cerr << "Partition ID out-of-range for target node.\n";
|
||||
gf.close();
|
||||
return 1;
|
||||
}
|
||||
const auto &part_list = graph_partitions[partition_id]; // Use const ref
|
||||
auto it =
|
||||
std::find(part_list.begin(), part_list.end(), (uint32_t)target_node_id);
|
||||
if (it == part_list.end()) {
|
||||
std::cerr << "Cannot find node " << target_node_id << " in partition "
|
||||
<< partition_id << std::endl;
|
||||
gf.close();
|
||||
return 1;
|
||||
}
|
||||
size_t j = std::distance(part_list.begin(), it);
|
||||
|
||||
// 4) sector => (partition_id+1)* 4096
|
||||
uint64_t sector_offset = uint64_t(partition_id + 1) * SECTOR_SIZE;
|
||||
gf.seekg(sector_offset, std::ios::beg);
|
||||
std::vector<char> sectorBuf(SECTOR_SIZE);
|
||||
gf.read(sectorBuf.data(), SECTOR_SIZE);
|
||||
if (gf.fail() && !gf.eof()) {
|
||||
std::cerr << "Error reading sector data for partition " << partition_id
|
||||
<< std::endl;
|
||||
gf.close();
|
||||
return 1;
|
||||
}
|
||||
gf.clear(); // Reset fail bits
|
||||
|
||||
std::cout << "Partition #" << partition_id
|
||||
<< ", nodeCount= " << part_list.size()
|
||||
<< ", offset= " << sector_offset << "\n"
|
||||
<< " first64 hex:\n ";
|
||||
print_hex(sectorBuf.data(), SECTOR_SIZE, 64);
|
||||
|
||||
// adjacency_offset= j* graph_node_len
|
||||
uint64_t node_offset = j * graph_node_len;
|
||||
if (node_offset + sizeof(uint32_t) >
|
||||
SECTOR_SIZE) { // Check only for neighbor_count read first
|
||||
std::cerr << "Out-of-range. j=" << j << ", node_offset=" << node_offset
|
||||
<< ", node_offset+4=" << (node_offset + sizeof(uint32_t))
|
||||
<< "> 4096\n";
|
||||
gf.close();
|
||||
return 1;
|
||||
}
|
||||
|
||||
char *adjacency_ptr = sectorBuf.data() + node_offset;
|
||||
uint32_t neighbor_count = *reinterpret_cast<uint32_t *>(adjacency_ptr);
|
||||
std::cout << "[Node " << target_node_id << "] partition=" << partition_id
|
||||
<< ", subIndex=" << j << ", adjacency_offset=" << node_offset
|
||||
<< ", neighbor_count=" << neighbor_count << "\n";
|
||||
|
||||
size_t needed = neighbor_count * sizeof(uint32_t);
|
||||
if (node_offset + sizeof(uint32_t) + needed > SECTOR_SIZE) {
|
||||
std::cerr << "Neighbors partly out-of-range => neighbor_count="
|
||||
<< neighbor_count << "\n";
|
||||
// Option: Can still print partial list if needed, but indicating it's
|
||||
// truncated
|
||||
gf.close();
|
||||
return 1; // Or handle differently
|
||||
}
|
||||
std::vector<uint32_t> neighbors(neighbor_count);
|
||||
memcpy(neighbors.data(), adjacency_ptr + sizeof(uint32_t), needed);
|
||||
|
||||
std::cout << " neighbors=[";
|
||||
for (size_t kk = 0; kk < std::min<size_t>(10, neighbor_count); kk++) {
|
||||
std::cout << neighbors[kk];
|
||||
if (kk + 1 < std::min<size_t>(10, neighbor_count))
|
||||
std::cout << ", ";
|
||||
}
|
||||
if (neighbor_count > 10)
|
||||
std::cout << " ... (total " << neighbor_count << ")";
|
||||
std::cout << "]\n";
|
||||
} // End of else (single node lookup mode)
|
||||
|
||||
gf.close();
|
||||
return 0;
|
||||
}
|
||||
15
research/utils/diskann_degree_distribution.fish
Executable file
15
research/utils/diskann_degree_distribution.fish
Executable file
@@ -0,0 +1,15 @@
|
||||
#! /bin/fish
|
||||
|
||||
# get the dir of this script
|
||||
set -x SCRIPT_DIR (dirname (realpath $0))
|
||||
|
||||
g++ $SCRIPT_DIR/analyze_diskann_graph.cpp -o $SCRIPT_DIR/analyze_diskann_graph
|
||||
|
||||
# get args
|
||||
set -x INDEX_PATH $argv[1]
|
||||
|
||||
./analyze_diskann_graph $INDEX_PATH $INDEX_PATH.degree_distribution.txt
|
||||
|
||||
python plot_degree_distribution.py $INDEX_PATH.degree_distribution.txt
|
||||
|
||||
rm $INDEX_PATH.degree_distribution.txt
|
||||
30
research/utils/download_mac.fish
Normal file
30
research/utils/download_mac.fish
Normal file
@@ -0,0 +1,30 @@
|
||||
#!/usr/bin/env fish
|
||||
|
||||
set scaling_out_dir "/Users/ec2-user/scaling_out"
|
||||
|
||||
# Define an array of paths to download
|
||||
set paths \
|
||||
"examples/" \
|
||||
"indices/rpj_wiki/facebook/contriever-msmarco/diskann/_disk_graph.index" \
|
||||
"indices/rpj_wiki/facebook/contriever-msmarco/diskann/_partition.bin" \
|
||||
"indices/rpj_wiki/facebook/contriever-msmarco/diskann/ann_disk.index_medoids.bin" \
|
||||
"indices/rpj_wiki/facebook/contriever-msmarco/diskann/ann_disk.index_centroids.bin" \
|
||||
"indices/rpj_wiki/facebook/contriever-msmarco/diskann/ann_disk.index_max_base_norm.bin" \
|
||||
"embeddings/facebook/contriever-msmarco/rpj_wiki/compressed_10/" \
|
||||
"passages/rpj_wiki/8-shards/" \
|
||||
"indices/rpj_wiki/facebook/contriever-msmarco/flat_results_nq_k3.json"
|
||||
|
||||
# Download each path using a for loop
|
||||
for path in $paths
|
||||
echo "Downloading $path..."
|
||||
# if ends with /, then create the directory
|
||||
if string match -q "*/" $path
|
||||
echo "Creating directory $scaling_out_dir/$path"
|
||||
mkdir -p "$scaling_out_dir/$path"
|
||||
aws s3 cp "s3://retrieval-scaling-out/$path" "$scaling_out_dir/$path" --recursive
|
||||
else
|
||||
aws s3 cp "s3://retrieval-scaling-out/$path" "$scaling_out_dir/$path"
|
||||
end
|
||||
end
|
||||
|
||||
echo "Download completed."
|
||||
422
research/utils/embedding_comp.py
Normal file
422
research/utils/embedding_comp.py
Normal file
@@ -0,0 +1,422 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
from transformers import AutoTokenizer, AutoModel
|
||||
import pandas as pd
|
||||
import matplotlib.pyplot as plt
|
||||
import seaborn as sns
|
||||
from tqdm import tqdm
|
||||
from scipy.stats import kendalltau, spearmanr
|
||||
|
||||
# 设置设备
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else
|
||||
("mps" if torch.backends.mps.is_available() else "cpu"))
|
||||
print(f"使用设备: {device}")
|
||||
|
||||
# 定义自定义比较函数(基于内积)
|
||||
def compare(a, b):
|
||||
"""
|
||||
计算两个向量的内积,并返回其负值作为距离度量
|
||||
数值越小表示越相似(与提供的代码一致)
|
||||
"""
|
||||
result = np.dot(a, b)
|
||||
return -result # 返回负值,与原代码一致
|
||||
|
||||
# 批量计算相似度
|
||||
def compute_similarities(queries, corpus):
|
||||
"""计算查询向量与语料库向量之间的相似度矩阵"""
|
||||
similarities = np.zeros((len(queries), len(corpus)))
|
||||
for i, query in enumerate(queries):
|
||||
for j, doc in enumerate(corpus):
|
||||
similarities[i, j] = compare(query, doc)
|
||||
return similarities
|
||||
|
||||
# 加载两个模型
|
||||
model_names = [
|
||||
"facebook/contriever-msmarco", # Contriever模型
|
||||
"facebook/contriever-msmarco-int4" # Contriever模型 (int4)
|
||||
]
|
||||
|
||||
# 扩展的样本文本 - 分为多个主题组
|
||||
texts = [
|
||||
# 组1: 关于狐狸和动物 (0-9)
|
||||
"The quick brown fox jumps over the lazy dog.",
|
||||
"A rapid auburn fox leaps above the inactive canine.",
|
||||
"The sly fox outsmarts the hunting hounds in the forest.",
|
||||
"Foxes are known for their cunning behavior and bushy tails.",
|
||||
"The red fox is the largest of the true foxes and the most common fox species.",
|
||||
"Dogs have been companions to humans for thousands of years.",
|
||||
"The lazy dog slept through the commotion of the playful fox.",
|
||||
"Wolves and foxes belong to the same family, Canidae.",
|
||||
"The arctic fox changes its coat color with the seasons.",
|
||||
"Domestic dogs come in hundreds of breeds of various sizes and appearances.",
|
||||
|
||||
# 组2: 人工智能和机器学习 (10-19)
|
||||
"Machine learning is a branch of artificial intelligence.",
|
||||
"Deep learning is a subset of machine learning.",
|
||||
"Neural networks are computing systems inspired by biological neural networks.",
|
||||
"AI systems can now beat human champions at complex games like chess and Go.",
|
||||
"Natural language processing allows computers to understand human language.",
|
||||
"Reinforcement learning involves training agents to make sequences of decisions.",
|
||||
"Computer vision enables machines to derive information from images and videos.",
|
||||
"The Turing test measures a machine's ability to exhibit intelligent behavior.",
|
||||
"Supervised learning uses labeled training data to learn the mapping function.",
|
||||
"Unsupervised learning finds patterns in data without pre-existing labels.",
|
||||
|
||||
# 组3: 巴黎和法国地标 (20-29)
|
||||
"The Eiffel Tower is located in Paris, France.",
|
||||
"The Louvre Museum is in the city of Paris.",
|
||||
"Notre-Dame Cathedral is a medieval Catholic cathedral on the Île de la Cité in Paris.",
|
||||
"The Arc de Triomphe stands at the center of the Place Charles de Gaulle in Paris.",
|
||||
"The Seine River flows through the heart of Paris.",
|
||||
"Montmartre is a large hill in Paris's 18th arrondissement known for its artistic history.",
|
||||
"The Palace of Versailles is located in the Île-de-France region of France.",
|
||||
"The Champs-Élysées is an avenue in Paris famous for its theatres, cafés, and luxury shops.",
|
||||
"The Sacré-Cœur Basilica offers one of the most beautiful panoramic views of Paris.",
|
||||
"The Musée d'Orsay houses the largest collection of impressionist masterpieces in the world.",
|
||||
|
||||
# 组4: 可再生能源 (30-39)
|
||||
"Solar panels convert sunlight into electricity.",
|
||||
"Wind turbines generate power from moving air.",
|
||||
"Hydroelectric power is generated from flowing water.",
|
||||
"Geothermal energy harnesses heat from within the Earth.",
|
||||
"Biomass energy comes from organic materials like plants and waste.",
|
||||
"Tidal energy uses the natural rise and fall of coastal tidal waters.",
|
||||
"Renewable energy sources can help reduce greenhouse gas emissions.",
|
||||
"Solar farms can span hundreds of acres with thousands of panels.",
|
||||
"Offshore wind farms are built in bodies of water to harvest wind energy.",
|
||||
"Energy storage systems are crucial for balancing renewable energy supply and demand.",
|
||||
|
||||
# 组5: 编程语言 (40-49)
|
||||
"Python is a popular programming language for data science.",
|
||||
"JavaScript is commonly used for web development.",
|
||||
"Java is known for its 'write once, run anywhere' capability.",
|
||||
"C++ provides high-performance and close hardware control.",
|
||||
"Ruby is praised for its simplicity and productivity.",
|
||||
"PHP is a server-side scripting language designed for web development.",
|
||||
"Swift is used to develop applications for Apple platforms.",
|
||||
"Rust offers memory safety without using garbage collection.",
|
||||
"Go was designed at Google to improve programming productivity.",
|
||||
"Kotlin is fully interoperable with Java and provides more concise syntax.",
|
||||
]
|
||||
|
||||
# 扩展的查询句子
|
||||
query_texts = [
|
||||
# 动物相关查询
|
||||
"A fox jumped over a dog.",
|
||||
"Wild animals and their behaviors in forests.",
|
||||
"Different species of foxes around the world.",
|
||||
|
||||
# AI相关查询
|
||||
"Artificial intelligence and neural networks.",
|
||||
"Machine learning algorithms and applications.",
|
||||
"The future of deep learning technology.",
|
||||
|
||||
# 巴黎相关查询
|
||||
"Famous landmarks in Paris, France.",
|
||||
"Tourist attractions along the Seine River.",
|
||||
"Historical buildings and museums in Paris.",
|
||||
|
||||
# 能源相关查询
|
||||
"Renewable energy sources and sustainability.",
|
||||
"Solar and wind power generation technologies.",
|
||||
"Alternative clean energy solutions for the future.",
|
||||
|
||||
# 编程相关查询
|
||||
"Computer programming languages comparison.",
|
||||
"Best languages for web development.",
|
||||
"Programming tools for data science applications."
|
||||
]
|
||||
|
||||
# 函数:获取BGE模型的嵌入
|
||||
def get_bge_embeddings(model, tokenizer, texts, device):
|
||||
# 处理大量文本时分批进行
|
||||
batch_size = 16
|
||||
all_embeddings = []
|
||||
|
||||
for i in range(0, len(texts), batch_size):
|
||||
batch_texts = texts[i:i+batch_size]
|
||||
encoded_input = tokenizer(batch_texts, padding=True, truncation=True,
|
||||
max_length=512, return_tensors='pt').to(device)
|
||||
|
||||
with torch.no_grad():
|
||||
model_output = model(**encoded_input)
|
||||
|
||||
# BGE使用[CLS]标记
|
||||
embeddings = model_output.last_hidden_state[:, 0]
|
||||
# 归一化嵌入
|
||||
normalized_embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
|
||||
all_embeddings.append(normalized_embeddings.cpu().numpy())
|
||||
|
||||
return np.vstack(all_embeddings)
|
||||
|
||||
# 函数:获取Contriever模型的嵌入
|
||||
def get_contriever_embeddings(model, tokenizer, texts, device, use_int4=False):
|
||||
# 处理大量文本时分批进行
|
||||
batch_size = 16
|
||||
all_embeddings = []
|
||||
|
||||
for i in range(0, len(texts), batch_size):
|
||||
batch_texts = texts[i:i+batch_size]
|
||||
encoded_input = tokenizer(batch_texts, padding=True, truncation=True,
|
||||
max_length=512, return_tensors='pt').to(device)
|
||||
|
||||
with torch.no_grad():
|
||||
model_output = model(**encoded_input)
|
||||
|
||||
# Contriever使用平均池化
|
||||
attention_mask = encoded_input['attention_mask'].unsqueeze(-1)
|
||||
embeddings = (model_output.last_hidden_state * attention_mask).sum(1) / attention_mask.sum(1)
|
||||
# 归一化嵌入
|
||||
normalized_embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
|
||||
all_embeddings.append(normalized_embeddings.cpu().numpy())
|
||||
|
||||
return np.vstack(all_embeddings)
|
||||
|
||||
# 主函数
|
||||
def compare_embeddings():
|
||||
results = {}
|
||||
|
||||
for i, model_name in enumerate(model_names):
|
||||
model_display_name = model_name
|
||||
# 给第二个模型一个不同的显示名称,以便区分
|
||||
if i == 1:
|
||||
model_display_name = "facebook/contriever-msmarco-int4"
|
||||
|
||||
print(f"\n======= 加载模型 {i+1}: {model_display_name} =======")
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_names[0]) # 两个模型使用相同的tokenizer
|
||||
|
||||
# 如果是第二个模型(int4版本),进行量化
|
||||
if i == 1:
|
||||
print("应用int4量化...")
|
||||
try:
|
||||
from transformers import BitsAndBytesConfig
|
||||
quantization_config = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_compute_dtype=torch.float16,
|
||||
bnb_4bit_use_double_quant=True,
|
||||
bnb_4bit_quant_type="nf4"
|
||||
)
|
||||
model = AutoModel.from_pretrained(
|
||||
model_names[0], # 使用相同的基础模型
|
||||
quantization_config=quantization_config,
|
||||
device_map="auto"
|
||||
)
|
||||
print("成功加载int4模型")
|
||||
except Exception as e:
|
||||
print(f"int4加载失败: {e}")
|
||||
print("回退到标准模型...")
|
||||
model = AutoModel.from_pretrained(model_names[0]).to(device)
|
||||
else:
|
||||
model = AutoModel.from_pretrained(model_names[0]).to(device)
|
||||
|
||||
model.eval()
|
||||
|
||||
print(f"使用 {model_display_name} 生成嵌入...")
|
||||
# 所有模型都使用contriever
|
||||
use_int4 = i == 1
|
||||
corpus_embeddings = get_contriever_embeddings(model, tokenizer, texts, device, use_int4)
|
||||
query_embeddings = get_contriever_embeddings(model, tokenizer, query_texts, device, use_int4)
|
||||
|
||||
print(f"语料库嵌入形状: {corpus_embeddings.shape}")
|
||||
print(f"查询嵌入形状: {query_embeddings.shape}")
|
||||
|
||||
# 使用自定义函数计算相似度
|
||||
similarity_scores = compute_similarities(query_embeddings, corpus_embeddings)
|
||||
|
||||
# 对每个查询,按相似度排序文本索引(较小的值表示更相似)
|
||||
ranked_indices = {}
|
||||
for j, scores in enumerate(similarity_scores):
|
||||
# 按相似度从低到高排序(因为我们返回的是负内积值)
|
||||
sorted_indices = np.argsort(scores)
|
||||
ranked_indices[f"query_{j+1}"] = sorted_indices
|
||||
|
||||
results[model_display_name] = {
|
||||
'corpus_embeddings': corpus_embeddings,
|
||||
'query_embeddings': query_embeddings,
|
||||
'similarity_scores': similarity_scores,
|
||||
'ranked_indices': ranked_indices
|
||||
}
|
||||
|
||||
# 立即打印这个模型的一些结果作为验证
|
||||
print(f"\n=== {model_display_name} 初步结果 ===")
|
||||
# 显示第一个查询的前3个结果
|
||||
query_idx = 0
|
||||
ranked_idx = ranked_indices[f"query_{query_idx+1}"]
|
||||
top_texts = [texts[idx] for idx in ranked_idx[:3]]
|
||||
print(f"查询: '{query_texts[query_idx]}'")
|
||||
print(f"排名前3位的文本:")
|
||||
for j, text in enumerate(top_texts):
|
||||
idx = ranked_idx[j]
|
||||
score = similarity_scores[query_idx][idx]
|
||||
print(f" {j+1}. [ID:{idx}] {text} (分数: {score:.4f})")
|
||||
|
||||
return results
|
||||
|
||||
# 分析结果
|
||||
def analyze_results(results):
|
||||
models = list(results.keys())
|
||||
|
||||
# 1. 比较相似度分数
|
||||
print("\n=== 相似度分数比较 ===")
|
||||
for model_name, result in results.items():
|
||||
similarities = result['similarity_scores'].flatten()
|
||||
print(f"{model_name} 相似度统计:")
|
||||
print(f" 平均值: {similarities.mean():.4f}")
|
||||
print(f" 最小值: {similarities.min():.4f}")
|
||||
print(f" 最大值: {similarities.max():.4f}")
|
||||
print(f" 标准差: {similarities.std():.4f}")
|
||||
|
||||
# 2. 比较排序结果(针对每个查询显示前5个结果)
|
||||
print("\n=== 排序结果比较 ===")
|
||||
for query_idx in range(len(query_texts)):
|
||||
query_key = f"query_{query_idx+1}"
|
||||
print(f"\n查询 {query_idx+1}: '{query_texts[query_idx]}'")
|
||||
|
||||
for model_name in models:
|
||||
ranked_idx = results[model_name]['ranked_indices'][query_key]
|
||||
top_texts = [texts[idx] for idx in ranked_idx[:5]]
|
||||
print(f"{model_name} 排名前5位的文本:")
|
||||
for i, text in enumerate(top_texts):
|
||||
idx = ranked_idx[i]
|
||||
score = results[model_name]['similarity_scores'][query_idx][idx]
|
||||
print(f" {i+1}. [ID:{idx}] {text} (分数: {score:.4f})")
|
||||
|
||||
# 3. 排序一致性分析
|
||||
print("\n=== 模型间排序一致性分析 ===")
|
||||
kendall_tau_scores = []
|
||||
spearman_scores = []
|
||||
|
||||
for query_idx in range(len(query_texts)):
|
||||
query_key = f"query_{query_idx+1}"
|
||||
|
||||
# 获取各模型的排序结果(只比较前10个结果)
|
||||
model1_top10 = results[models[0]]['ranked_indices'][query_key][:10]
|
||||
model2_top10 = results[models[1]]['ranked_indices'][query_key][:10]
|
||||
|
||||
# 计算排序一致性
|
||||
kt, _ = kendalltau(model1_top10, model2_top10)
|
||||
sr, _ = spearmanr(model1_top10, model2_top10)
|
||||
|
||||
kendall_tau_scores.append(kt)
|
||||
spearman_scores.append(sr)
|
||||
|
||||
# 计算前10个结果的重叠率
|
||||
overlap = len(set(model1_top10) & set(model2_top10))
|
||||
overlap_rate = overlap / 10.0
|
||||
|
||||
print(f"查询 {query_idx+1} '{query_texts[query_idx]}':")
|
||||
print(f" Kendall's Tau = {kt:.4f}, Spearman's rank correlation = {sr:.4f}")
|
||||
print(f" 前10结果重叠率: {overlap_rate:.2f} ({overlap}/10)")
|
||||
|
||||
print(f"\n平均 Kendall's Tau: {np.mean(kendall_tau_scores):.4f}")
|
||||
print(f"平均 Spearman's rank correlation: {np.mean(spearman_scores):.4f}")
|
||||
|
||||
# 4. 可视化相似度分布差异
|
||||
plt.figure(figsize=(12, 6))
|
||||
for i, model_name in enumerate(models):
|
||||
sns.histplot(results[model_name]['similarity_scores'].flatten(),
|
||||
kde=True, label=model_name, alpha=0.6)
|
||||
|
||||
plt.title('不同模型的相似度分布')
|
||||
plt.xlabel('相似度得分(越小越相似)')
|
||||
plt.ylabel('频率')
|
||||
plt.legend()
|
||||
plt.savefig('similarity_distribution.png')
|
||||
print("已保存相似度分布图表到 'similarity_distribution.png'")
|
||||
|
||||
# 5. 可视化主题相关性
|
||||
plt.figure(figsize=(15, 10))
|
||||
|
||||
# 为每个主题组定义颜色
|
||||
topic_colors = {
|
||||
'动物': 'blue',
|
||||
'AI': 'red',
|
||||
'巴黎': 'green',
|
||||
'能源': 'purple',
|
||||
'编程': 'orange'
|
||||
}
|
||||
|
||||
# 定义主题组范围
|
||||
topic_ranges = {
|
||||
'动物': (0, 10),
|
||||
'AI': (10, 20),
|
||||
'巴黎': (20, 30),
|
||||
'能源': (30, 40),
|
||||
'编程': (40, 50)
|
||||
}
|
||||
|
||||
# 对每个查询显示前10个结果的主题分布
|
||||
query_groups = [
|
||||
[0, 1, 2], # 动物查询组
|
||||
[3, 4, 5], # AI查询组
|
||||
[6, 7, 8], # 巴黎查询组
|
||||
[9, 10, 11], # 能源查询组
|
||||
[12, 13, 14] # 编程查询组
|
||||
]
|
||||
|
||||
for group_idx, group in enumerate(query_groups):
|
||||
plt.subplot(len(query_groups), 1, group_idx+1)
|
||||
|
||||
# 为每个模型计算主题分布
|
||||
bar_width = 0.35
|
||||
bar_positions = np.arange(len(topic_ranges))
|
||||
|
||||
for model_idx, model_name in enumerate(models):
|
||||
# 统计每个主题在前10个结果中的出现次数
|
||||
topic_counts = {topic: 0 for topic in topic_ranges.keys()}
|
||||
|
||||
for query_idx in group:
|
||||
query_key = f"query_{query_idx+1}"
|
||||
top10 = results[model_name]['ranked_indices'][query_key][:10]
|
||||
|
||||
for idx in top10:
|
||||
for topic, (start, end) in topic_ranges.items():
|
||||
if start <= idx < end:
|
||||
topic_counts[topic] += 1
|
||||
|
||||
# 绘制主题分布柱状图
|
||||
plt.bar(bar_positions + (model_idx * bar_width),
|
||||
list(topic_counts.values()),
|
||||
bar_width,
|
||||
label=model_name)
|
||||
|
||||
plt.title(f"查询组 {group_idx+1}: {', '.join([query_texts[i] for i in group[:1]])}")
|
||||
plt.xticks(bar_positions + bar_width/2, list(topic_ranges.keys()))
|
||||
plt.ylabel('前10结果中的出现次数')
|
||||
plt.legend()
|
||||
|
||||
plt.tight_layout()
|
||||
plt.savefig('topic_distribution.png')
|
||||
print("已保存主题分布图表到 'topic_distribution.png'")
|
||||
|
||||
# 6. 可视化查询与相关文档的相似度热图
|
||||
plt.figure(figsize=(15, 12))
|
||||
|
||||
for i, model_name in enumerate(models):
|
||||
plt.subplot(2, 1, i+1)
|
||||
|
||||
# 获取相似度矩阵(负数越小表示越相似)
|
||||
sim_matrix = results[model_name]['similarity_scores']
|
||||
|
||||
# 将负值转换为正值以便可视化(越大表示越相似)
|
||||
sim_matrix_viz = -sim_matrix
|
||||
|
||||
# 绘制热图
|
||||
sns.heatmap(sim_matrix_viz, cmap='YlGnBu',
|
||||
xticklabels=[f"Doc{i}" for i in range(len(texts))],
|
||||
yticklabels=[f"Q{i+1}" for i in range(len(query_texts))],
|
||||
cbar_kws={'label': '相似度(越高越相似)'})
|
||||
|
||||
plt.title(f"{model_name} 相似度热图")
|
||||
plt.xlabel('文档ID')
|
||||
plt.ylabel('查询ID')
|
||||
|
||||
plt.tight_layout()
|
||||
plt.savefig('similarity_heatmap.png')
|
||||
print("已保存相似度热图到 'similarity_heatmap.png'")
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("开始比较嵌入模型...")
|
||||
results = compare_embeddings()
|
||||
analyze_results(results)
|
||||
print("\n比较完成!")
|
||||
444
research/utils/evaluate_results.py
Normal file
444
research/utils/evaluate_results.py
Normal file
@@ -0,0 +1,444 @@
|
||||
# Filename: evaluate_results_xai_line_sync.py
|
||||
import openai
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from dotenv import load_dotenv
|
||||
from tqdm import tqdm
|
||||
from collections import defaultdict
|
||||
import concurrent.futures
|
||||
from typing import List, Dict, Any, Tuple
|
||||
|
||||
# --- Configuration ---
|
||||
load_dotenv()
|
||||
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
|
||||
if not OPENAI_API_KEY:
|
||||
raise ValueError("Please set the OPENAI_API_KEY in your .env file")
|
||||
|
||||
try:
|
||||
client = openai.OpenAI(
|
||||
api_key=OPENAI_API_KEY,
|
||||
)
|
||||
except ImportError:
|
||||
print("Please install the latest OpenAI library: pip install --upgrade openai")
|
||||
exit()
|
||||
except openai.AuthenticationError:
|
||||
print("OpenAI library reported an AuthenticationError. Ensure OPENAI_API_KEY is correct.")
|
||||
exit()
|
||||
|
||||
LLM_MODEL = "gpt-3.5-turbo" # Using OpenAI's standard model
|
||||
MAX_RETRIES = 5
|
||||
INITIAL_RETRY_DELAY_SECONDS = 5
|
||||
REQUEST_TIMEOUT_SECONDS = 90
|
||||
MAX_WORKERS = 10 # Number of parallel workers
|
||||
|
||||
# --- File Paths (Adjust as needed) ---
|
||||
# User provided paths
|
||||
QUERIES_FILE_PATH = "/opt/dlami/nvme/scaling_out/examples/enron_eval_retrieval.jsonl"
|
||||
RAW_PASSAGES_FILE_PATH = "/opt/dlami/nvme/scaling_out/passages/enron_emails/1-shards/raw_passages-0-of-1.jsonl"
|
||||
RESULTS_FILE_PATH = "search_results_top10_bm25.jsonl" # This file's Nth line corresponds to QUERIES_FILE_PATH's Nth line
|
||||
OUTPUT_EVALUATION_FILE = "llm_containment_evaluations_xai_line_sync.jsonl"
|
||||
|
||||
# --- LLM Prompt Definitions for Containment (Same as before) ---
|
||||
CONTAINMENT_SYSTEM_PROMPT = """You are an AI evaluator. Your task is to determine if the core information presented in the 'Retrieved Passage' is directly contained within *any* of the text snippets provided in the 'Ground Truth Email Snippets' list."""
|
||||
CONTAINMENT_USER_TEMPLATE = """Retrieved Passage:
|
||||
"{retrieved_passage_text}"
|
||||
|
||||
---
|
||||
Ground Truth Email Snippets (Parts of the correct source email):
|
||||
{ground_truth_snippets_formatted_list}
|
||||
---
|
||||
|
||||
Is the core information of the 'Retrieved Passage' directly present or fully contained within *any* of the 'Ground Truth Email Snippets' listed above?
|
||||
- Focus on whether the specific facts or statements in the 'Retrieved Passage' can be found within the ground truth snippets.
|
||||
- Ignore minor formatting differences. If the retrieved passage is a direct quote or a very close paraphrase of content within the ground truth snippets, answer YES.
|
||||
- Respond YES if the Retrieved Passage's content is clearly represented in one or more of the ground truth snippets.
|
||||
- Respond NO if the Retrieved Passage's content is not found, is contradictory, or introduces significant information not present in the ground truth snippets.
|
||||
|
||||
Your response must be a single word: YES or NO.
|
||||
"""
|
||||
|
||||
# --- Data Loading Functions ---
|
||||
|
||||
def load_queries_as_list(file_path):
|
||||
"""
|
||||
Loads queries from a jsonl file into a list, preserving order.
|
||||
Each item in the list is a dict containing original_id, query_text, and ground_truth_message_ids.
|
||||
"""
|
||||
queries_list = []
|
||||
try:
|
||||
with open(file_path, 'r', encoding='utf-8') as f:
|
||||
for line_num, line in enumerate(f):
|
||||
try:
|
||||
data = json.loads(line)
|
||||
required_keys = ["id", "query", "ground_truth_message_ids"]
|
||||
if not all(key in data for key in required_keys):
|
||||
print(f"Warning: Skipping line {line_num + 1} in query file due to missing keys: {line.strip()}")
|
||||
continue
|
||||
if not isinstance(data["ground_truth_message_ids"], list):
|
||||
print(f"Warning: 'ground_truth_message_ids' is not a list in line {line_num + 1}. Skipping: {line.strip()}")
|
||||
continue
|
||||
queries_list.append({
|
||||
"original_id": data["id"], # Store the original ID from the file
|
||||
"query_text": data["query"],
|
||||
"ground_truth_message_ids": data["ground_truth_message_ids"]
|
||||
})
|
||||
except json.JSONDecodeError:
|
||||
print(f"Warning: Skipping malformed JSON line {line_num + 1} in query file: {line.strip()}")
|
||||
except FileNotFoundError:
|
||||
print(f"Error: Queries file not found at {file_path}")
|
||||
exit()
|
||||
print(f"Loaded {len(queries_list)} queries (as a list) from {file_path}")
|
||||
return queries_list
|
||||
|
||||
def load_all_passages_by_message_id(raw_passages_file_path):
|
||||
"""Loads all raw passages into memory, grouped by message_id. (Same as before)"""
|
||||
passages_dict = defaultdict(list)
|
||||
# ... (implementation from previous script, no changes needed here) ...
|
||||
print(f"Loading all raw passages from {raw_passages_file_path} into memory...")
|
||||
try:
|
||||
with open(raw_passages_file_path, 'r', encoding='utf-8') as f:
|
||||
for line_num, line in enumerate(f):
|
||||
try:
|
||||
data = json.loads(line)
|
||||
if "message_id" in data and "text" in data:
|
||||
passages_dict[data["message_id"]].append(data["text"])
|
||||
else:
|
||||
print(f"Warning: Skipping line {line_num+1} in raw passages file due to missing 'message_id' or 'text'.")
|
||||
except json.JSONDecodeError:
|
||||
print(f"Warning: Skipping malformed JSON line {line_num + 1} in raw passages file: {line.strip()}")
|
||||
print(f"Finished loading raw passages. Found {len(passages_dict)} unique message IDs.")
|
||||
except FileNotFoundError:
|
||||
print(f"Error: Raw passages file not found at {raw_passages_file_path}")
|
||||
exit()
|
||||
except MemoryError:
|
||||
print("Error: Ran out of memory loading all raw passages. Consider an indexed approach.")
|
||||
exit()
|
||||
return dict(passages_dict)
|
||||
|
||||
def load_search_results_as_list(file_path):
|
||||
"""Loads search results from a jsonl file into a list, preserving order."""
|
||||
results_list = []
|
||||
# ... (implementation similar to load_queries_as_list, parsing each line as a dict) ...
|
||||
try:
|
||||
with open(file_path, 'r', encoding='utf-8') as f:
|
||||
for line_num, line in enumerate(f):
|
||||
try:
|
||||
data = json.loads(line)
|
||||
# We expect "query_id" (though not used for matching) and "passages"
|
||||
if "passages" not in data: # query_id might be implicitly by order
|
||||
print(f"Warning: Skipping line {line_num + 1} in search results file due to missing 'passages' key: {line.strip()}")
|
||||
continue
|
||||
results_list.append(data)
|
||||
except json.JSONDecodeError:
|
||||
print(f"Warning: Skipping malformed JSON line {line_num + 1} in search results file: {line.strip()}")
|
||||
except FileNotFoundError:
|
||||
print(f"Error: Search results file not found at {file_path}")
|
||||
exit()
|
||||
print(f"Loaded {len(results_list)} search result sets (as a list) from {file_path}")
|
||||
return results_list
|
||||
|
||||
|
||||
def format_ground_truth_snippets(snippet_list):
|
||||
"""Formats the list of ground truth snippets for the prompt. (Same as before)"""
|
||||
# ... (implementation from previous script) ...
|
||||
if not snippet_list:
|
||||
return " [No ground truth snippets found for the target message ID(s)]"
|
||||
formatted = []
|
||||
for i, snippet in enumerate(snippet_list):
|
||||
display_snippet = (snippet[:500] + '...') if len(snippet) > 500 else snippet
|
||||
formatted.append(f" {i+1}. {display_snippet}")
|
||||
return "\n".join(formatted)
|
||||
|
||||
# --- LLM API Call Function ---
|
||||
def get_llm_containment_evaluation(retrieved_passage_text: str, ground_truth_snippets_list: List[str], query_id_for_log: str, passage_identifier_info: str, query_text_for_context: str = "") -> str:
|
||||
"""Calls the OpenAI API with retry logic."""
|
||||
formatted_gt_snippets = format_ground_truth_snippets(ground_truth_snippets_list)
|
||||
# max_gt_chars_in_prompt = 5000 # Arbitrary limit, adjust as needed
|
||||
# if len(formatted_gt_snippets) > max_gt_chars_in_prompt:
|
||||
# print(f"Warning: Ground truth snippets for Q_log_id:{query_id_for_log} are too long ({len(formatted_gt_snippets)} chars), truncating for LLM prompt.")
|
||||
# formatted_gt_snippets = formatted_gt_snippets[:max_gt_chars_in_prompt] + "\n [... Snippets Truncated ...]"
|
||||
|
||||
user_prompt = CONTAINMENT_USER_TEMPLATE.format(
|
||||
retrieved_passage_text=retrieved_passage_text,
|
||||
ground_truth_snippets_formatted_list=formatted_gt_snippets
|
||||
)
|
||||
messages = [
|
||||
{"role": "system", "content": CONTAINMENT_SYSTEM_PROMPT},
|
||||
{"role": "user", "content": user_prompt}
|
||||
]
|
||||
|
||||
current_retry_delay = INITIAL_RETRY_DELAY_SECONDS
|
||||
for attempt in range(MAX_RETRIES):
|
||||
try:
|
||||
response = client.chat.completions.create(
|
||||
model=LLM_MODEL,
|
||||
messages=messages,
|
||||
temperature=0.0,
|
||||
max_tokens=10,
|
||||
timeout=REQUEST_TIMEOUT_SECONDS
|
||||
)
|
||||
answer = response.choices[0].message.content.strip().upper()
|
||||
if answer in ["YES", "NO"]:
|
||||
return answer
|
||||
else:
|
||||
print(f"Warning: Unexpected LLM response content '{answer[:100]}' for Q_log_id:{query_id_for_log} P:{passage_identifier_info}. Defaulting to NO.")
|
||||
return "NO"
|
||||
except openai.APIConnectionError as e:
|
||||
error_message = f"API Connection Error (Attempt {attempt + 1}/{MAX_RETRIES}): {e}"
|
||||
except openai.RateLimitError as e:
|
||||
error_message = f"API Rate Limit Error (Attempt {attempt + 1}/{MAX_RETRIES}): {e}"
|
||||
except openai.APIStatusError as e:
|
||||
error_message = f"API Status Error (Attempt {attempt + 1}/{MAX_RETRIES}): {e.status_code} - {e.response}"
|
||||
if e.status_code == 401:
|
||||
return "ERROR_AUTH"
|
||||
if e.status_code == 500:
|
||||
pass
|
||||
else:
|
||||
return "ERROR_API_CLIENT"
|
||||
except Exception as e:
|
||||
error_message = f"Unexpected error with OpenAI lib (Attempt {attempt + 1}/{MAX_RETRIES}): {type(e).__name__} - {e}"
|
||||
|
||||
print(f"{error_message}. Query Log ID: {query_id_for_log}, Passage: {passage_identifier_info}")
|
||||
if "ERROR_AUTH" in error_message or "ERROR_API_CLIENT" in error_message:
|
||||
break
|
||||
|
||||
if attempt < MAX_RETRIES - 1:
|
||||
print(f"Retrying in {current_retry_delay} seconds...")
|
||||
time.sleep(current_retry_delay)
|
||||
current_retry_delay = min(current_retry_delay * 2, 60)
|
||||
else:
|
||||
print(f"Max retries ({MAX_RETRIES}) reached for Q_log_id:{query_id_for_log} P:{passage_identifier_info}. Skipping.")
|
||||
return "ERROR_MAX_RETRIES"
|
||||
return "ERROR_MAX_RETRIES"
|
||||
|
||||
def process_query_passage_pair(args: Tuple[Dict[str, Any], Dict[str, Any], Dict[str, List[str]], set]) -> List[Dict[str, Any]]:
|
||||
"""Process a single query-passage pair for parallel execution."""
|
||||
query_info, result_item, passages_lookup, already_evaluated = args
|
||||
evaluations = []
|
||||
|
||||
query_original_id = query_info["original_id"]
|
||||
query_text = query_info["query_text"]
|
||||
target_message_ids = query_info.get("ground_truth_message_ids", [])
|
||||
|
||||
if not target_message_ids:
|
||||
return evaluations
|
||||
|
||||
ground_truth_snippets = []
|
||||
for msg_id_in_query_file in target_message_ids:
|
||||
msg_id_to_lookup = msg_id_in_query_file
|
||||
if msg_id_in_query_file.startswith("<") and msg_id_in_query_file.endswith(">"):
|
||||
msg_id_to_lookup = msg_id_in_query_file[1:-1]
|
||||
|
||||
snippets = passages_lookup.get(msg_id_to_lookup)
|
||||
if snippets:
|
||||
ground_truth_snippets.extend(snippets)
|
||||
|
||||
if not ground_truth_snippets:
|
||||
return evaluations
|
||||
|
||||
retrieved_passages = result_item.get("passages", [])
|
||||
if not retrieved_passages:
|
||||
return evaluations
|
||||
|
||||
for passage_idx, passage_obj in enumerate(retrieved_passages):
|
||||
if not isinstance(passage_obj, dict):
|
||||
print(f"Warning: Invalid passage format for Q_original_id:{query_original_id}, passage index {passage_idx}. Skipping passage.")
|
||||
continue
|
||||
|
||||
retrieved_passage_text = passage_obj.get("text", "").strip()
|
||||
passage_identifier = passage_obj.get("passage_id", passage_obj.get("id", f"retrieved_idx_{passage_idx}"))
|
||||
|
||||
evaluation_key = (query_original_id, passage_identifier)
|
||||
if evaluation_key in already_evaluated:
|
||||
continue
|
||||
|
||||
passage_text_preview = (retrieved_passage_text[:75] + '...') if len(retrieved_passage_text) > 75 else retrieved_passage_text
|
||||
|
||||
if not retrieved_passage_text:
|
||||
evaluation = "NO"
|
||||
else:
|
||||
evaluation = get_llm_containment_evaluation(
|
||||
retrieved_passage_text,
|
||||
ground_truth_snippets,
|
||||
query_original_id,
|
||||
passage_identifier,
|
||||
query_text
|
||||
)
|
||||
if evaluation == "ERROR_AUTH":
|
||||
print("Authentication error with OpenAI API. Stopping script.")
|
||||
return evaluations
|
||||
|
||||
evaluation_record = {
|
||||
"query_original_id": query_original_id,
|
||||
"passage_identifier": passage_identifier,
|
||||
"passage_text_preview": passage_text_preview,
|
||||
"evaluation": evaluation,
|
||||
"model_used": LLM_MODEL,
|
||||
"ground_truth_message_ids_checked": target_message_ids
|
||||
}
|
||||
evaluations.append(evaluation_record)
|
||||
|
||||
return evaluations
|
||||
|
||||
# --- Resume Logic ---
|
||||
def load_existing_evaluations(output_file):
|
||||
"""Loads already evaluated query-passage pairs using 'passage_identifier' and 'query_original_id'. (Same as before, but keying with original_id)"""
|
||||
# ... (implementation from previous script, ensure it uses the correct ID for keys) ...
|
||||
evaluated_pairs = set()
|
||||
if os.path.exists(output_file):
|
||||
print(f"Loading existing containment evaluations from {output_file}...")
|
||||
with open(output_file, 'r', encoding='utf-8') as f:
|
||||
for line_num, line in enumerate(f):
|
||||
try:
|
||||
data = json.loads(line)
|
||||
# Key for resuming should be based on the logged original query ID
|
||||
query_original_id = data.get('query_original_id')
|
||||
passage_identifier = data.get('passage_identifier')
|
||||
if query_original_id is not None and passage_identifier is not None:
|
||||
evaluated_pairs.add((query_original_id, passage_identifier))
|
||||
else:
|
||||
print(f"Warning: Could not identify query_original_id/passage_identifier in existing file line {line_num + 1}.")
|
||||
except json.JSONDecodeError:
|
||||
print(f"Warning: Skipping malformed line {line_num + 1} in existing file: {line.strip()}")
|
||||
except KeyError as e:
|
||||
print(f"Warning: Skipping line {line_num + 1} with missing key '{e}' in existing file: {line.strip()}")
|
||||
print(f"Loaded {len(evaluated_pairs)} existing evaluation records.")
|
||||
else:
|
||||
print(f"No existing evaluation file found at {output_file}. Starting fresh.")
|
||||
return evaluated_pairs
|
||||
|
||||
# --- Main Execution Logic ---
|
||||
|
||||
def main():
|
||||
"""Main function to run the containment evaluation process using parallel processing."""
|
||||
print(f"Starting containment evaluation using OpenAI model: {LLM_MODEL} via OpenAI library interface.")
|
||||
|
||||
# Load data as lists
|
||||
queries_list = load_queries_as_list(QUERIES_FILE_PATH)
|
||||
passages_lookup = load_all_passages_by_message_id(RAW_PASSAGES_FILE_PATH)
|
||||
search_results_list = load_search_results_as_list(RESULTS_FILE_PATH)
|
||||
|
||||
if not queries_list or not search_results_list or not passages_lookup:
|
||||
print("Error loading one or more input files or raw passages. Exiting.")
|
||||
return
|
||||
|
||||
# Determine the number of items to process
|
||||
num_items_to_process = min(len(queries_list), len(search_results_list))
|
||||
print(f"Will process {num_items_to_process} query-result pairs.")
|
||||
|
||||
already_evaluated = load_existing_evaluations(OUTPUT_EVALUATION_FILE)
|
||||
|
||||
try:
|
||||
with open(OUTPUT_EVALUATION_FILE, 'a', encoding='utf-8') as outfile:
|
||||
# Prepare arguments for parallel processing
|
||||
process_args = [
|
||||
(queries_list[i], search_results_list[i], passages_lookup, already_evaluated)
|
||||
for i in range(num_items_to_process)
|
||||
]
|
||||
|
||||
# Use ThreadPoolExecutor for parallel processing
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
|
||||
# Submit all tasks and get futures
|
||||
futures = [executor.submit(process_query_passage_pair, args) for args in process_args]
|
||||
|
||||
# Process results as they complete
|
||||
for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures), desc="Processing query-result pairs"):
|
||||
try:
|
||||
evaluations = future.result()
|
||||
for evaluation in evaluations:
|
||||
outfile.write(json.dumps(evaluation) + "\n")
|
||||
outfile.flush()
|
||||
# Update already_evaluated set
|
||||
already_evaluated.add((evaluation["query_original_id"], evaluation["passage_identifier"]))
|
||||
except Exception as e:
|
||||
print(f"Error processing query-result pair: {e}")
|
||||
|
||||
except IOError as e:
|
||||
print(f"Error writing to output file {OUTPUT_EVALUATION_FILE}: {e}")
|
||||
return
|
||||
except Exception as e:
|
||||
print(f"An unexpected error occurred during the main processing loop: {e}")
|
||||
return
|
||||
|
||||
print("\n--- Containment Evaluation Script Finished ---")
|
||||
|
||||
# --- Final Summary Calculation ---
|
||||
print(f"Calculating final summary statistics from: {OUTPUT_EVALUATION_FILE}")
|
||||
final_query_containment_found = {}
|
||||
total_evaluated_pairs = 0
|
||||
error_count = 0
|
||||
evaluated_query_original_ids = set()
|
||||
|
||||
try:
|
||||
with open(OUTPUT_EVALUATION_FILE, 'r', encoding='utf-8') as f:
|
||||
for line_num, line in enumerate(f):
|
||||
total_evaluated_pairs += 1
|
||||
try:
|
||||
data = json.loads(line)
|
||||
q_original_id = data['query_original_id']
|
||||
eval_result = data['evaluation']
|
||||
evaluated_query_original_ids.add(q_original_id)
|
||||
|
||||
if eval_result == "YES":
|
||||
final_query_containment_found[q_original_id] = True
|
||||
elif q_original_id not in final_query_containment_found:
|
||||
final_query_containment_found[q_original_id] = False
|
||||
if eval_result not in ["YES", "NO"]:
|
||||
error_count += 1
|
||||
except (json.JSONDecodeError, KeyError) as e:
|
||||
print(f"Error reading line {line_num + 1} during summary: {e} - Line: {line.strip()}")
|
||||
error_count += 1
|
||||
|
||||
num_queries_with_any_contained = sum(1 for contained in final_query_containment_found.values() if contained)
|
||||
total_unique_queries_evaluated = len(evaluated_query_original_ids)
|
||||
|
||||
if total_unique_queries_evaluated > 0:
|
||||
containment_rate_at_10 = num_queries_with_any_contained / total_unique_queries_evaluated
|
||||
print(f"\n--- Final Statistics (Containment Check) ---")
|
||||
print(f"Total unique queries processed (based on output file entries): {total_unique_queries_evaluated}")
|
||||
print(f"Number of queries with at least one contained passage (YES): {num_queries_with_any_contained}")
|
||||
print(f"Containment Match Rate @ Top 10 (Any YES): {containment_rate_at_10:.4f}")
|
||||
print(f"Total query-passage pairs processed (lines in output file): {total_evaluated_pairs}")
|
||||
if error_count > 0:
|
||||
print(f"Number of evaluation errors or non-YES/NO results: {error_count}")
|
||||
else:
|
||||
print("No evaluation results found to summarize.")
|
||||
except FileNotFoundError:
|
||||
print(f"Error: Output file {OUTPUT_EVALUATION_FILE} not found for summary.")
|
||||
except Exception as e:
|
||||
print(f"An unexpected error occurred during summary calculation: {e}")
|
||||
|
||||
print(f"\nDetailed containment evaluations saved to: {OUTPUT_EVALUATION_FILE}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Dummy files for testing the line sync logic
|
||||
if not os.path.exists(QUERIES_FILE_PATH):
|
||||
print(f"Warning: {QUERIES_FILE_PATH} not found. Creating dummy file.")
|
||||
with open(QUERIES_FILE_PATH, 'w', encoding='utf-8') as f:
|
||||
json.dump({"id": "q_alpha", "query": "Query Alpha Text", "ground_truth_message_ids": ["<msg_A>"]}, f); f.write("\n") # Line 0
|
||||
json.dump({"id": "q_beta", "query": "Query Beta Text", "ground_truth_message_ids": ["<msg_B>"]}, f); f.write("\n") # Line 1
|
||||
json.dump({"id": "q_gamma", "query": "Query Gamma Text", "ground_truth_message_ids": ["<msg_C>"]}, f); f.write("\n")# Line 2
|
||||
|
||||
if not os.path.exists(RAW_PASSAGES_FILE_PATH):
|
||||
print(f"Warning: {RAW_PASSAGES_FILE_PATH} not found. Creating dummy file.")
|
||||
with open(RAW_PASSAGES_FILE_PATH, 'w', encoding='utf-8') as f:
|
||||
json.dump({"text": "Content from message A snippet 1.", "id": 100, "message_id": "<msg_A>"}, f); f.write("\n")
|
||||
json.dump({"text": "Content from message A snippet 2.", "id": 101, "message_id": "<msg_A>"}, f); f.write("\n")
|
||||
json.dump({"text": "Content from message B.", "id": 200, "message_id": "<msg_B>"}, f); f.write("\n")
|
||||
json.dump({"text": "Content from message D (unrelated).", "id": 300, "message_id": "<msg_D>"}, f); f.write("\n")
|
||||
|
||||
# RESULTS_FILE_PATH should have results corresponding line-by-line to QUERIES_FILE_PATH
|
||||
if not os.path.exists(RESULTS_FILE_PATH):
|
||||
print(f"Warning: {RESULTS_FILE_PATH} not found. Creating dummy file (2 entries).")
|
||||
with open(RESULTS_FILE_PATH, 'w', encoding='utf-8') as f:
|
||||
# Result for query "q_alpha" (line 0 in queries file)
|
||||
json.dump({"query_id": "this_can_be_ignored_if_line_sync", "passages": [{"id": 101, "text": "Content from message A snippet 2."}, {"id": 300, "text": "Content from message D (unrelated)."}]}, f); f.write("\n")
|
||||
# Result for query "q_beta" (line 1 in queries file)
|
||||
json.dump({"query_id": "this_too", "passages": [{"id": 999, "text": "Some other text."}, {"id": 200, "text": "Content from message B."}]}, f); f.write("\n")
|
||||
# Note: Only 2 result sets, but 3 queries in dummy QUERIES_FILE_PATH.
|
||||
# The script will process min(len(queries_list), len(search_results_list)) if you uncomment that logic,
|
||||
# or just len(search_results_list) as it's currently written for tqdm.
|
||||
|
||||
main()
|
||||
44
research/utils/experiment.md
Normal file
44
research/utils/experiment.md
Normal file
@@ -0,0 +1,44 @@
|
||||
# Recompute Embeddings Saved
|
||||
|
||||
```console
|
||||
python ./demo/main.py --mode serve --engine sglang --load-indices diskann --port 8082 --domain rpj_wiki --lazy --recompute --dedup --use-partition
|
||||
python ./demo/embedding_server.py --domain rpj_wiki
|
||||
python ./demo/test_serve.py --port 8082 --nprobe 80 --re --dedup
|
||||
```
|
||||
|
||||
Result:
|
||||
```
|
||||
Evaluation Results for nprobe = 80:
|
||||
Final Recall Rate: 0.9333
|
||||
Average total latency: 2.427s
|
||||
Average search time: 2.414s
|
||||
```
|
||||
|
||||
其中,use-partition也可以不加,也可以跑。不加的效果如下:
|
||||
```
|
||||
Results for nprobe = 80:
|
||||
Final Recall Rate: 0.9333
|
||||
Average total latency: 2.434s
|
||||
Average search time: 2.421s
|
||||
```
|
||||
|
||||
# Recompute Embeddings + Loading from disk
|
||||
|
||||
Remove `--dedup --use-partition`
|
||||
|
||||
```console
|
||||
python ./demo/main.py --mode serve --engine sglang --load-indices diskann --port 8082 --domain rpj_wiki --lazy --recompute
|
||||
python ./demo/embedding_server.py --domain rpj_wiki
|
||||
python ./demo/test_serve.py --port 8082 --nprobe 80 --re
|
||||
```
|
||||
|
||||
Result:
|
||||
```
|
||||
Evaluation Results for nprobe = 80:
|
||||
Evaluation Results for nprobe = 80:
|
||||
Average F1 Score: 0.5708
|
||||
Average Exact Match Score: 0.4500
|
||||
Average Recall Rate: 0.9333
|
||||
Average total latency: 3.709s
|
||||
Average search time: 3.696s
|
||||
```
|
||||
599
research/utils/extract_results.py
Executable file
599
research/utils/extract_results.py
Executable file
@@ -0,0 +1,599 @@
|
||||
import os
|
||||
import math
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import re
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib.ticker as ticker
|
||||
import json
|
||||
import pickle
|
||||
import pdb
|
||||
|
||||
|
||||
"""
|
||||
Automatic result extraction for BM25.
|
||||
"""
|
||||
|
||||
|
||||
def extract_data_to_table(directory_path):
|
||||
# Regular expression pattern to match the data format in file content
|
||||
content_pattern = (
|
||||
r"# tokens: (\d+(\.\d+)?)\tLM PPL: (\d+(\.\d+)?)\tPPL: (\d+(\.\d+)?)"
|
||||
)
|
||||
# Regular expression pattern to extract info from file names
|
||||
file_name_pattern_M = r"(.+)-(\d+)M-seed_(\d+).txt"
|
||||
file_name_pattern = r"(.+)-(\d+)-seed_(\d+).txt"
|
||||
|
||||
# Data storage
|
||||
data = []
|
||||
|
||||
# Iterating through each file in the directory
|
||||
for file_name in os.listdir(directory_path):
|
||||
# Checking if the file name matches the pattern
|
||||
file_match_M = re.match(file_name_pattern_M, file_name)
|
||||
file_match = re.match(file_name_pattern, file_name)
|
||||
if file_match_M:
|
||||
domain, num_samples, seed = file_match_M.groups()
|
||||
|
||||
# Reading the file and extracting data
|
||||
file_path = os.path.join(directory_path, file_name)
|
||||
with open(file_path, "r") as file:
|
||||
for line in file:
|
||||
# Searching for the pattern in each line
|
||||
content_match = re.search(content_pattern, line)
|
||||
if content_match:
|
||||
# Extracting values
|
||||
tokens, lm_ppl, ppl = (
|
||||
content_match.groups()[0],
|
||||
content_match.groups()[2],
|
||||
content_match.groups()[4],
|
||||
)
|
||||
# Adding the extracted data and extra info to the list
|
||||
data.append(
|
||||
{
|
||||
"Domain": domain,
|
||||
"Samples": int(num_samples) * 1e6,
|
||||
"Seed": int(seed),
|
||||
"#eval_tokens": float(tokens),
|
||||
"LM_PPL": float(lm_ppl),
|
||||
"PPL": float(ppl),
|
||||
}
|
||||
)
|
||||
elif file_match:
|
||||
domain, num_samples, seed = file_match.groups()
|
||||
|
||||
# Reading the file and extracting data
|
||||
file_path = os.path.join(directory_path, file_name)
|
||||
with open(file_path, "r") as file:
|
||||
for line in file:
|
||||
# Searching for the pattern in each line
|
||||
content_match = re.search(content_pattern, line)
|
||||
if content_match:
|
||||
# Extracting values
|
||||
tokens, lm_ppl, ppl = (
|
||||
content_match.groups()[0],
|
||||
content_match.groups()[2],
|
||||
content_match.groups()[4],
|
||||
)
|
||||
# Adding the extracted data and extra info to the list
|
||||
data.append(
|
||||
{
|
||||
"Domain": domain,
|
||||
"Samples": int(num_samples),
|
||||
"Seed": int(seed),
|
||||
"#eval_tokens": float(tokens),
|
||||
"LM_PPL": float(lm_ppl),
|
||||
"PPL": float(ppl),
|
||||
}
|
||||
)
|
||||
|
||||
df = pd.DataFrame(data)
|
||||
grouped_df = df.groupby(["Domain", "Samples", "#eval_tokens"]).mean()
|
||||
|
||||
return df, grouped_df
|
||||
|
||||
|
||||
"""
|
||||
Automatic resutls extraction for dense retrieval. (new)
|
||||
"""
|
||||
|
||||
|
||||
def extract_dense_scaling_results(log_files, domain=None, plot=None):
|
||||
# Regular expression pattern to match the key-value pairs in the input string
|
||||
pattern = r"(\w[\w #]+) = ([\w.]+)"
|
||||
|
||||
data_list = []
|
||||
for file in log_files:
|
||||
with open(file, "r") as file:
|
||||
for line in file:
|
||||
# Use re.findall to extract all matches of the pattern
|
||||
matches = re.findall(pattern, line)
|
||||
|
||||
if matches:
|
||||
data_dict = {
|
||||
key.replace(" ", "_").lower(): (
|
||||
None
|
||||
if value == "None"
|
||||
else float(value)
|
||||
if value.replace(".", "", 1).isdigit()
|
||||
else value
|
||||
)
|
||||
for key, value in matches
|
||||
}
|
||||
data_list.append(data_dict)
|
||||
|
||||
df = pd.DataFrame(data_list)
|
||||
if "total_shards" in df.columns:
|
||||
df["subsample_ratio"] = df["sampled_shards"] / df["total_shards"]
|
||||
else:
|
||||
df["subsample_ratio"] = 1 / df["total_shards"]
|
||||
df = df.sort_values(by="subsample_ratio")
|
||||
print(df.head)
|
||||
|
||||
if plot:
|
||||
# Setting the plot size for better visibility
|
||||
plt.figure(figsize=(10, 6))
|
||||
|
||||
# Plotting
|
||||
for concate_k in df["concate_k"].unique():
|
||||
subset = df[df["concate_k"] == concate_k]
|
||||
if concate_k == 0:
|
||||
perplexity_when_concate_k_0 = subset["perplexity"].mean()
|
||||
plt.axhline(
|
||||
y=perplexity_when_concate_k_0,
|
||||
color="r",
|
||||
linestyle="-",
|
||||
label="Closed-book",
|
||||
)
|
||||
else:
|
||||
plt.plot(
|
||||
subset["subsample_ratio"],
|
||||
subset["perplexity"],
|
||||
label=f"Concate_k = {concate_k}",
|
||||
)
|
||||
|
||||
plt.title(f"Perplexity Change with Total Shards -- {domain}")
|
||||
plt.xlabel("Subsample Ratio")
|
||||
plt.ylabel("Perplexity")
|
||||
plt.legend()
|
||||
plt.grid(True)
|
||||
plt.savefig(plot)
|
||||
return df
|
||||
|
||||
|
||||
def plot_mmlu():
|
||||
# C4 results
|
||||
labels = [
|
||||
"LM-only",
|
||||
"top-1 w/ 1/32 C4 datastore",
|
||||
"top-1 w/ 2/32 C4 datastore",
|
||||
"top-1 w/ 3/32 C4 datastore",
|
||||
"top-1 w/ 4/32 C4 datastore",
|
||||
"top-1 w/ 5/32 C4 datastore",
|
||||
"top-1 w/ 6/32 C4 datastore",
|
||||
]
|
||||
x = [0, 1, 2, 3, 4, 5, 6]
|
||||
few_shot_0_concat_1 = [30.69, 32.81, 32.05, 32.55, 32.57, 33.03, 32.88]
|
||||
few_shot_1_concat_1 = [39.67, 41.03, 41.74, 42.1, 42.62, 41.55, 42.09]
|
||||
few_shot_5_concat_1 = [42.47, 43.75, 44.37, 44.1, 44.84, 43.95, 44.49]
|
||||
|
||||
# Plotting the data
|
||||
plt.figure(figsize=(14, 8))
|
||||
|
||||
# Plot for few_shot_0_concat_1
|
||||
plt.plot(
|
||||
x,
|
||||
few_shot_0_concat_1,
|
||||
marker="o",
|
||||
linestyle="-",
|
||||
color="blue",
|
||||
label="Few-shot k=0, Concat k=1",
|
||||
)
|
||||
|
||||
# Plot for few_shot_1_concat_1
|
||||
plt.plot(
|
||||
x,
|
||||
few_shot_1_concat_1,
|
||||
marker="s",
|
||||
linestyle="-",
|
||||
color="red",
|
||||
label="Few-shot k=1, Concat k=1",
|
||||
)
|
||||
|
||||
# Plot for few_shot_5_concat_1
|
||||
plt.plot(
|
||||
x,
|
||||
few_shot_5_concat_1,
|
||||
marker="^",
|
||||
linestyle="-",
|
||||
color="green",
|
||||
label="Few-shot k=5, Concat k=1",
|
||||
)
|
||||
|
||||
# Adding details
|
||||
plt.title("MMLU Performance")
|
||||
plt.xlabel("Retrieval-based LM Datastore Composition")
|
||||
plt.ylabel("Accuracy")
|
||||
plt.xticks(ticks=x, labels=labels, rotation=45, ha="right")
|
||||
plt.legend()
|
||||
plt.tight_layout()
|
||||
plt.grid(True)
|
||||
plt.savefig("mmlu_c4_scaling.png")
|
||||
|
||||
|
||||
def extract_lm_eval_results(
|
||||
result_dir, task_name, model_name, n_shot_list, n_doc_list, datastore_name_filter=""
|
||||
):
|
||||
markers = ["o", "s", "^", "D", "*", "p", "H", "x"]
|
||||
colors = plt.cm.tab20.colors
|
||||
|
||||
all_data = []
|
||||
for subdir, dirs, files in os.walk(result_dir):
|
||||
num_ints = len(os.path.basename(subdir).split("-"))
|
||||
for file in files:
|
||||
if file.endswith(".jsonl"):
|
||||
file_path = os.path.join(subdir, file)
|
||||
with open(file_path, "r") as f:
|
||||
for line in f:
|
||||
data = json.loads(line)
|
||||
data["SubdirLevel"] = num_ints
|
||||
data["n-shot"], data["n-doc"] = (
|
||||
int(data["n-shot"]),
|
||||
int(data["n-doc"]),
|
||||
)
|
||||
data["Value"] = float(data["Value"])
|
||||
all_data.append(data)
|
||||
|
||||
filtered_data = [
|
||||
d
|
||||
for d in all_data
|
||||
if datastore_name_filter in result_dir
|
||||
and d["n-shot"] in n_shot_list
|
||||
and d["n-doc"] in n_doc_list
|
||||
and d["SubdirLevel"] > 0
|
||||
]
|
||||
|
||||
plot_data = {}
|
||||
for d in filtered_data:
|
||||
key = (d["n-shot"], d["n-doc"])
|
||||
plot_data.setdefault(key, []).append((d["SubdirLevel"], d["Value"]))
|
||||
|
||||
sorted_keys = sorted(plot_data.keys(), key=lambda x: (x[0], x[1]))
|
||||
|
||||
closed_book_values = {}
|
||||
for i, key in enumerate(sorted_keys):
|
||||
n_shot, n_doc = key
|
||||
if n_doc == 0:
|
||||
value = plot_data[key][-1][-1]
|
||||
closed_book_values.update({n_shot: value})
|
||||
|
||||
plt.figure(figsize=(15, 10))
|
||||
for i, key in enumerate(sorted_keys):
|
||||
n_shot, n_doc = key
|
||||
if n_doc == 0:
|
||||
continue
|
||||
values = plot_data[key]
|
||||
values.append(
|
||||
(0, closed_book_values[n_shot])
|
||||
if n_shot in closed_book_values.keys()
|
||||
else (0, None)
|
||||
)
|
||||
values.sort() # Ensure the values are sorted by SubdirLevel
|
||||
x_values, y_values = zip(*values) # Unzip the tuple pairs to separate lists
|
||||
marker = markers[n_shot] if n_doc else ""
|
||||
color = colors[i % len(colors)] # Choose a color from the colormap
|
||||
label = f"n-shot={n_shot}, n-doc={n_doc}"
|
||||
plt.plot(
|
||||
x_values, y_values, marker=marker, color=color, linestyle="-", label=label
|
||||
)
|
||||
|
||||
# plt.gca().yaxis.set_major_locator(ticker.MaxNLocator(nbins='auto', steps=[1, 2, 5, 10]))
|
||||
|
||||
if subject_name == "mmlu":
|
||||
plot_dir = os.path.join("plots", "mmlu")
|
||||
else:
|
||||
plot_dir = "plots"
|
||||
os.makedirs(plot_dir, exist_ok=True)
|
||||
|
||||
plt.xlabel("Number of Index Shards")
|
||||
plt.ylabel("Accuracy")
|
||||
plt.title(f"{task_name} scaling performance with {model_name}")
|
||||
plt.legend()
|
||||
plt.grid(True)
|
||||
plt.savefig(f"{plot_dir}/{task_name}_{model_name}.png")
|
||||
|
||||
return all_data
|
||||
|
||||
|
||||
def plot_mmlu_persub_figures(directory="plots"):
|
||||
files = [
|
||||
file
|
||||
for file in os.listdir(directory)
|
||||
if file.startswith("mmlu_") and file.endswith(".png")
|
||||
]
|
||||
plots_per_figure = 16
|
||||
for i in range(0, len(files), plots_per_figure):
|
||||
# Create a new figure
|
||||
fig, axs = plt.subplots(4, 4, figsize=(20, 20))
|
||||
|
||||
# Flatten the axis array for easy indexing
|
||||
axs = axs.flatten()
|
||||
|
||||
# Iterate over each subplot in the current figure
|
||||
for ax, file in zip(axs, files[i : i + plots_per_figure]):
|
||||
# Read the image file
|
||||
img = plt.imread(os.path.join(directory, file))
|
||||
|
||||
# Display the image in the subplot
|
||||
ax.imshow(img)
|
||||
ax.set_title(file)
|
||||
ax.axis("off") # Hide axes
|
||||
|
||||
# Adjust layout and display the figure
|
||||
plt.tight_layout()
|
||||
plt.savefig(f"mmlu_persub_{i}.png")
|
||||
|
||||
|
||||
def plot_calibration_figures(domain, shard_id=8, show_ci=True, show_all_points=False):
|
||||
if show_all_points:
|
||||
show_ci = False
|
||||
|
||||
data_path = f"out_calibration/{shard_id}_shard_{domain}/calibration_results_decon_rpj_{domain}_None_samples.pkl"
|
||||
|
||||
with open(data_path, "rb") as file:
|
||||
all_results = pickle.load(file)
|
||||
|
||||
all_lm_losses = [item[0] for item in all_results]
|
||||
all_retrieval_scores = [item[1] for item in all_results]
|
||||
print(f"Total {len(all_lm_losses)} examples.")
|
||||
|
||||
# Compute PPL of top-1 doc v.s. golden doc from top-100
|
||||
losses_top1 = [losses[0] for losses in all_lm_losses]
|
||||
avg_losses_top1 = sum(losses_top1) / len(losses_top1)
|
||||
ppl_losses_top1 = math.exp(avg_losses_top1)
|
||||
|
||||
lossed_top100_gold = [min(losses) for losses in all_lm_losses]
|
||||
avg_losses_top100_gold = sum(lossed_top100_gold) / len(lossed_top100_gold)
|
||||
ppl_lossed_top100_gold = math.exp(avg_losses_top100_gold)
|
||||
|
||||
print(
|
||||
f"Top-1 doc PPL: {ppl_losses_top1:.4f}\nGold doc from top-100 PPL: {ppl_lossed_top100_gold:.4f}"
|
||||
)
|
||||
|
||||
# Calibration plot
|
||||
lm_losses = np.array(all_lm_losses)
|
||||
retrieval_scores = np.array(all_retrieval_scores)
|
||||
|
||||
from scipy.special import softmax
|
||||
import scipy.stats as stats
|
||||
|
||||
softmax_lm_losses = softmax(lm_losses, axis=1)
|
||||
softmax_retrieval_scores = softmax(retrieval_scores, axis=1)
|
||||
|
||||
if show_all_points:
|
||||
lm_losses = lm_losses.flatten()
|
||||
retrieval_scores = retrieval_scores.flatten()
|
||||
|
||||
plt.figure(figsize=(8, 6))
|
||||
plt.plot(lm_losses, retrieval_scores, marker="o", linestyle="")
|
||||
plt.title(f"Calibration Curve with {shard_id} Shards")
|
||||
plt.xlabel("LM Losses")
|
||||
plt.ylabel("Retrieval Scores")
|
||||
plt.grid(True)
|
||||
plt.savefig(f"out_calibration/calibration_all_{shard_id}_shard_{domain}.png")
|
||||
|
||||
elif show_ci:
|
||||
lm_losses_mean = np.mean(lm_losses, axis=0)
|
||||
retrieval_scores_mean = np.mean(retrieval_scores, axis=0)
|
||||
|
||||
lm_losses_sem = stats.sem(lm_losses, axis=0)
|
||||
retrieval_scores_sem = stats.sem(retrieval_scores, axis=0)
|
||||
|
||||
# Assuming a 95% confidence interval, z-score is approximately 1.96 for a normal distribution
|
||||
z_score = 1.96
|
||||
losses_ci = lm_losses_sem * z_score
|
||||
retrieval_ci = retrieval_scores_sem * z_score
|
||||
|
||||
plt.figure(figsize=(10, 6))
|
||||
plt.errorbar(
|
||||
lm_losses_mean,
|
||||
retrieval_scores_mean,
|
||||
xerr=losses_ci,
|
||||
yerr=retrieval_ci,
|
||||
fmt="o",
|
||||
ecolor="lightgray",
|
||||
alpha=0.5,
|
||||
capsize=5,
|
||||
)
|
||||
plt.xlabel("LM Losses")
|
||||
plt.ylabel("Retrieval Scores")
|
||||
plt.title(
|
||||
f"Calibration plot for {shard_id}-shard {domain} with Confidence Intervals"
|
||||
)
|
||||
plt.grid(True)
|
||||
plt.savefig(f"out_calibration/calibration_ci_{shard_id}_shard_{domain}.png")
|
||||
|
||||
else:
|
||||
lm_losses = np.mean(lm_losses, axis=0)
|
||||
retrieval_scores = np.mean(retrieval_scores, axis=0)
|
||||
|
||||
plt.figure(figsize=(8, 6))
|
||||
plt.plot(lm_losses, retrieval_scores, marker="o", linestyle="")
|
||||
plt.title(f"Calibration Curve with {shard_id} Shards")
|
||||
plt.xlabel("LM Losses")
|
||||
plt.ylabel("Retrieval Scores")
|
||||
plt.grid(True)
|
||||
plt.savefig(f"out_calibration/calibration_{shard_id}_shard_{domain}.png")
|
||||
|
||||
return ppl_losses_top1, ppl_lossed_top100_gold, all_lm_losses, all_retrieval_scores
|
||||
|
||||
|
||||
def plot_top1_vs_best_doc(domain, total_shards=8):
|
||||
lm_only_ppl = {
|
||||
"books": 21.5250,
|
||||
"stackexchange": 11.5948,
|
||||
"wiki": 14.0729,
|
||||
}
|
||||
|
||||
top1_losses, best_losses = [], []
|
||||
for shard_id in range(1, total_shards + 1):
|
||||
top1_loss, best_loss, _, _ = plot_calibration_figures(domain, shard_id)
|
||||
top1_losses.append(top1_loss)
|
||||
best_losses.append(best_loss)
|
||||
|
||||
x = [i for i in range(1, total_shards + 1)]
|
||||
plt.figure(figsize=(10, 6))
|
||||
|
||||
# Plotting
|
||||
if lm_only_ppl[domain]:
|
||||
plt.axhline(
|
||||
y=lm_only_ppl[domain], color="r", linestyle="-", label="Closed-book"
|
||||
)
|
||||
|
||||
plt.plot(x, top1_losses, label=f"Top-1 Doc")
|
||||
plt.plot(x, best_losses, label=f"Gold Doc")
|
||||
|
||||
plt.title(f"Perplexity Change with Total Shards")
|
||||
plt.xlabel("Number of Shards")
|
||||
plt.ylabel("Perplexity")
|
||||
plt.legend()
|
||||
plt.grid(True)
|
||||
plt.savefig(f"best_plot_{domain}.png")
|
||||
|
||||
|
||||
def plot_top1_vs_best_doc_per_sample(domain, shard_id, show_top_k=10, special_mark_k=0):
|
||||
_, _, all_lm_losses, all_retrieval_scores = plot_calibration_figures(
|
||||
domain, shard_id
|
||||
)
|
||||
all_sorted_lm_losses, all_sorted_retrieval_scores = [], []
|
||||
for lm_losses, retrieval_scores in zip(all_retrieval_scores, all_lm_losses):
|
||||
sorted_scores, sorted_losses = zip(
|
||||
*sorted(zip(retrieval_scores, lm_losses), reverse=True)
|
||||
)
|
||||
all_sorted_lm_losses.append(sorted_losses)
|
||||
all_sorted_retrieval_scores.append(sorted_scores)
|
||||
|
||||
num_samples = len(all_lm_losses)
|
||||
x = [i for i in range(num_samples)]
|
||||
plt.figure(figsize=(25, 6))
|
||||
|
||||
# Plotting
|
||||
for i in range(show_top_k - 1, -1, -1):
|
||||
plt.plot(
|
||||
x,
|
||||
[losses[i] for losses in all_sorted_lm_losses],
|
||||
label=f"Top-{i + 1}th Doc",
|
||||
marker="x" if i == special_mark_k else "o",
|
||||
linestyle="",
|
||||
)
|
||||
|
||||
plt.title(f"Per-sample Loss of {domain} with 1 retrieved doc")
|
||||
plt.xlabel("Index of the Evaluation Sample")
|
||||
plt.ylabel("Loss")
|
||||
plt.legend(loc="upper left", bbox_to_anchor=(1, 1))
|
||||
plt.grid(True)
|
||||
plt.savefig(f"per_sample_{domain}.png")
|
||||
|
||||
|
||||
def compute_variance_across_hards(path, n_shot=5, n_doc=3):
|
||||
all_data = []
|
||||
for subdir, dirs, files in os.walk(path):
|
||||
num_ints = len(os.path.basename(subdir).split("-"))
|
||||
for file in files:
|
||||
if file.endswith(".jsonl"):
|
||||
file_path = os.path.join(subdir, file)
|
||||
with open(file_path, "r") as f:
|
||||
for line in f:
|
||||
data = json.loads(line)
|
||||
data["SubdirLevel"] = num_ints
|
||||
data["n-shot"], data["n-doc"] = (
|
||||
int(data["n-shot"]),
|
||||
int(data["n-doc"]),
|
||||
)
|
||||
data["Value"] = float(data["Value"])
|
||||
all_data.append(data)
|
||||
|
||||
plot_data = {}
|
||||
for d in all_data:
|
||||
key = (d["n-shot"], d["n-doc"])
|
||||
plot_data.setdefault(key, []).append((d["SubdirLevel"], d["Value"]))
|
||||
|
||||
files_end = [d.split("/")[-1] for d, _, _ in os.walk(path)]
|
||||
shard_ids = [int(i) for i in files_end[1:]]
|
||||
key = n_shot, n_doc
|
||||
values = plot_data[key]
|
||||
_, y_values = zip(*values)
|
||||
|
||||
plt.figure(figsize=(10, 6))
|
||||
try:
|
||||
plt.plot(shard_ids, y_values, marker="o", linestyle="")
|
||||
except:
|
||||
print(f"mismatched size for {key}: {len(shard_ids)}, {len(y_values)}")
|
||||
|
||||
print(y_values)
|
||||
print(f"Saving to {f'per_sample_{files_end[0]}.png'}")
|
||||
|
||||
plt.xlabel("Single-shard Index ID")
|
||||
plt.ylabel("PPL")
|
||||
plt.grid(True)
|
||||
plt.savefig(f"per_sample_{files_end[0]}.png")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# # Replace with your directory path
|
||||
# directory_path = "out/2023_dec_25_single_domain"
|
||||
|
||||
# # Extracting data to a table with additional information
|
||||
# df, grouped_df = extract_data_to_table(directory_path)
|
||||
# print(grouped_df)
|
||||
# print(grouped_df.index.get_level_values("Samples (M)").to_numpy())
|
||||
|
||||
plot_info_list = [
|
||||
# {'logfile': 'rpj_c4.log', 'domain': 'rpj-c4', 'plot': 'scaling_c4_single_index_plot.png'},
|
||||
# {'logfile': 'rpj_arxiv.log', 'domain': 'rpj-arxiv', 'plot': 'scaling_arxiv_plot.png'},
|
||||
# {'logfile': 'rpj_book_scaling.log', 'domain': 'rpj-book', 'plot': 'scaling_book_plot.png'},
|
||||
# {'logfile': 'rpj_github_scaling.log', 'domain': 'rpj-github', 'plot': 'scaling_github_plot.png'},
|
||||
# {'logfile': 'rpj_stackexchange_scaling.log', 'domain': 'rpj-stackexchange', 'plot': 'scaling_stackexchange_plot.png'},
|
||||
# {'logfile': 'rpj_wiki.log', 'domain': 'rpj-wiki', 'plot': 'scaling_wiki_plot.png'},
|
||||
# {'logfile': 'out/2024_apr_decon/decon_rpj_wiki_contriever_ppl.log', 'domain': 'rpj-wiki-decon-contriever', 'plot': 'scaling_wiki_decon_plot_contriever.png'},
|
||||
# {'logfile': 'out/2024_apr_decon/decon_rpj_book_contriever_ppl.log', 'domain': 'rpj-book-decon-contriever', 'plot': 'scaling_book_decon_plot_contriever.png'},
|
||||
# {'logfile': 'out/2024_apr_decon/decon_rpj_arxiv_contriever_ppl.log', 'domain': 'rpj-arxiv-decon-contriever', 'plot': 'scaling_arxiv_decon_plot_contriever.png'},
|
||||
# {'logfile': 'out/2024_apr_decon/decon_rpj_stackexchange_contriever_ppl.log', 'domain': 'rpj-stackexchange-decon-contriever', 'plot': 'scaling_stackexchange_decon_plot_contriever.png'},
|
||||
# {'logfile': 'out/2024_apr_decon/decon_rpj_stackexchange_dragon_ppl.log', 'domain': 'rpj-stackexchange-decon-dragon', 'plot': 'scaling_stackexchange_decon_plot_dragon.png'},
|
||||
# {'logfile': 'out/2024_apr_decon/decon_rpj_wiki_dragon_ppl.log', 'domain': 'rpj-wiki-decon-dragon', 'plot': 'scaling_wiki_decon_plot_dragon.png'},
|
||||
# {'logfile': 'out/2024_apr_decon/decon_rpj_arxiv_dragon_ppl.log', 'domain': 'rpj-arxiv-decon-dragon', 'plot': 'scaling_arxiv_decon_plot_dragon.png'},
|
||||
# {'logfile': 'out/2024_apr_decon/decon_rpj_book_dragon_ppl.log', 'domain': 'rpj-book-decon-dragon', 'plot': 'scaling_book_decon_plot_dragon.png'},
|
||||
]
|
||||
|
||||
# for plot_info in plot_info_list:
|
||||
# extract_dense_scaling_results([plot_info['logfile']], plot_info['domain'], plot_info['plot'])
|
||||
|
||||
model_name = "lclm"
|
||||
subject_name = "gsm8k"
|
||||
datastore_name = "c4"
|
||||
result_dir = f"/gscratch/zlab/rulins/Scaling/lm_eval_results/{model_name}"
|
||||
|
||||
all_subjects = [
|
||||
file
|
||||
for file in os.listdir(result_dir)
|
||||
if subject_name in file and datastore_name in file
|
||||
]
|
||||
for subject in all_subjects:
|
||||
file_name = subject
|
||||
print(file_name)
|
||||
extract_lm_eval_results(
|
||||
os.path.join(result_dir, file_name),
|
||||
subject,
|
||||
model_name,
|
||||
[0, 5], # few-shot
|
||||
[0, 3], # n-doc
|
||||
file_name,
|
||||
)
|
||||
|
||||
# plot_mmlu_persub_figures("plots/mmlu")
|
||||
|
||||
# compute_variance_across_hards(f'/gscratch/zlab/rulins/Scaling/lm_eval_results/llama2-7b/subsample/nq_open-rpj_c4-32_shards')
|
||||
# compute_variance_across_hards(f'/gscratch/zlab/rulins/Scaling/lm_eval_results/llama2-7b/subsample/medqa_4options-rpj_c4-32_shards')
|
||||
|
||||
# plot_calibration_figures(domain='wiki', shard_id=1, show_all_points=True)
|
||||
# plot_top1_vs_best_doc_per_sample(domain='stackexchange', shard_id=1, show_top_k=10, special_mark_k=0)
|
||||
20
research/utils/faiss.md
Normal file
20
research/utils/faiss.md
Normal file
@@ -0,0 +1,20 @@
|
||||
sudo apt-get install libgflags-dev
|
||||
|
||||
uv pip install swig
|
||||
|
||||
sudo apt-get update && sudo apt-get install libzmq3-dev libmsgpack-dev pkg-config
|
||||
|
||||
/home/ubuntu/Power-RAG/.venv/bin/cmake -B build -DCMAKE_EXPORT_COMPILE_COMMANDS:BOOL=TRUE -DFAISS_ENABLE_PYTHON=ON -DFAISS_ENABLE_GPU=OFF -DCMAKE_BUILD_TYPE=Debug .
|
||||
|
||||
make -C build -j faiss && make -C build -j swigfaiss && uv pip install -e build/faiss/python
|
||||
|
||||
|
||||
|
||||
## Some outdated info (may not needeed)
|
||||
|
||||
&& cp ./build/faiss/python/_swigfaiss.so /home/andy/Power-RAG/.venv/lib/python3.10/site-packages/faiss-1.10.0-py3.10.egg/faiss/
|
||||
|
||||
|
||||
export LD_PRELOAD="/lib/x86_64-linux-gnu/libiomp5.so /lib/x86_64-linux-gnu/libmkl_core.so /lib/x86_64-linux-gnu/libmkl_intel_lp64.so /lib/x86_64-linux-gnu/libmkl_intel_thread.so"
|
||||
|
||||
set -x LD_PRELOAD "/lib/x86_64-linux-gnu/libiomp5.so /lib/x86_64-linux-gnu/libmkl_core.so /lib/x86_64-linux-gnu/libmkl_intel_lp64.so /lib/x86_64-linux-gnu/libmkl_intel_thread.so"
|
||||
739
research/utils/find_probe.py
Normal file
739
research/utils/find_probe.py
Normal file
@@ -0,0 +1,739 @@
|
||||
#!/usr/bin/env python3
|
||||
import subprocess
|
||||
import json
|
||||
import re
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
from typing import List, Dict, Any, Tuple, Optional
|
||||
import os
|
||||
import time
|
||||
import sys
|
||||
import argparse
|
||||
import concurrent.futures
|
||||
import signal
|
||||
import psutil
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--path-suffix", type=str, default="", help="Path suffix for the index")
|
||||
parser.add_argument("--pq-compressed", type=int, default=None)
|
||||
parser.add_argument("--beam-width", type=int, default=2, help="DiskANN beam width for search (controls number of IO requests per iteration)")
|
||||
parser.add_argument("--index-type", type=str, default="diskann", help="Index type to test (default: diskann)")
|
||||
parser.add_argument("--task", type=str, default="nq", help="Task to run (default: nq)")
|
||||
parser.add_argument("--max-workers", type=int, default=1, help="Maximum number of concurrent processes")
|
||||
parser.add_argument("--timeout", type=int, default=1800, help="Timeout for each process in seconds")
|
||||
parser.add_argument("--retry-count", type=int, default=2, help="Number of retries for failed runs")
|
||||
parser.add_argument(
|
||||
"--target-recalls",
|
||||
type=float,
|
||||
nargs='+',
|
||||
default=[0.85, 0.90, 0.95],
|
||||
help="Target recalls to achieve (e.g., --target-recalls 0.85 0.90 0.95)"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
path_suffix = args.path_suffix
|
||||
|
||||
pq_compressed = args.pq_compressed
|
||||
beam_width = args.beam_width
|
||||
max_workers = args.max_workers
|
||||
timeout = args.timeout
|
||||
retry_count = args.retry_count
|
||||
|
||||
TARGET_RECALLS = args.target_recalls
|
||||
|
||||
task = args.task
|
||||
|
||||
# Process management
|
||||
running_processes = {} # PID -> Process object
|
||||
|
||||
# Based on previous data, search around these values
|
||||
if args.index_type == "diskann":
|
||||
if task == "nq":
|
||||
if pq_compressed is None:
|
||||
NPROBE_RANGES = {
|
||||
0.85: range(10, 50),
|
||||
0.90: range(62, 67), # Narrow range around 64 (63, 64, 65, 66)
|
||||
0.95: range(190, 195) # Narrow range around 192 (190, 191, 192, 193, 194)
|
||||
}
|
||||
elif pq_compressed == 10:
|
||||
NPROBE_RANGES = {
|
||||
0.85: range(10, 70),
|
||||
0.90: range(90, 127), # Narrow range around 64 (63, 64, 65, 66)
|
||||
0.95: range(200, 384) # Narrow range around 192 (190, 191, 192, 193, 194)
|
||||
}
|
||||
elif pq_compressed == 20:
|
||||
NPROBE_RANGES = {
|
||||
0.85: range(10, 50),
|
||||
0.90: range(64, 128), # Narrow range around 64 (63, 64, 65, 66)
|
||||
0.95: range(188, 192) # Narrow range around 192 (190, 191, 192, 193, 194)
|
||||
}
|
||||
elif pq_compressed == 5:
|
||||
NPROBE_RANGES = {
|
||||
0.85: range(10, 500),
|
||||
0.90: range(768, 2000), # Narrow range around 64 (63, 64, 65, 66)
|
||||
0.95: range(3000, 4096) # Narrow range around 192 (190, 191, 192, 193, 194)
|
||||
}
|
||||
elif task == "trivia":
|
||||
if pq_compressed is None:
|
||||
NPROBE_RANGES = {
|
||||
0.85: range(90, 150),
|
||||
0.90: range(150, 200), # Narrow range around 64 (63, 64, 65, 66)
|
||||
0.95: range(200, 300) # Narrow range around 192 (190, 191, 192, 193, 194)
|
||||
}
|
||||
elif task == "gpqa":
|
||||
if pq_compressed is None:
|
||||
NPROBE_RANGES = {
|
||||
0.85: range(1, 30),
|
||||
0.90: range(1, 30), # Narrow range around 64 (63, 64, 65, 66)
|
||||
0.95: range(1, 30) # Narrow range around 192 (190, 191, 192, 193, 194)
|
||||
}
|
||||
elif task == "hotpot":
|
||||
if pq_compressed is None:
|
||||
NPROBE_RANGES = {
|
||||
0.85: range(19, 160),
|
||||
0.90: range(120, 210), # Narrow range around 64 (63, 64, 65, 66)
|
||||
0.95: range(1000, 1200) # Narrow range around 192 (190, 191, 192, 193, 194)
|
||||
}
|
||||
elif args.index_type == "ivf_disk":
|
||||
if task == "nq":
|
||||
assert pq_compressed is None
|
||||
NPROBE_RANGES = {
|
||||
0.85: range(13, 16),
|
||||
0.90: range(30,40),
|
||||
0.95: range(191, 194)
|
||||
}
|
||||
elif task == "trivia":
|
||||
assert pq_compressed is None
|
||||
NPROBE_RANGES = {
|
||||
0.85: range(13, 50),
|
||||
0.90: range(30, 100),
|
||||
0.95: range(100, 400)
|
||||
}
|
||||
elif task == "gpqa":
|
||||
assert pq_compressed is None
|
||||
NPROBE_RANGES = {
|
||||
0.85: range(1, 30),
|
||||
0.90: range(1, 30), # Narrow range around 64 (63, 64, 65, 66)
|
||||
0.95: range(1, 30) # Narrow range around 192 (190, 191, 192, 193, 194)
|
||||
}
|
||||
elif task == "hotpot":
|
||||
NPROBE_RANGES = {
|
||||
0.85: range(13, 100),
|
||||
0.90: range(30, 200),
|
||||
0.95: range(191, 700)
|
||||
}
|
||||
elif args.index_type == "hnsw":
|
||||
if task == "nq":
|
||||
NPROBE_RANGES = {
|
||||
0.85: range(130, 140),
|
||||
0.90: range(550, 666),
|
||||
0.95: range(499, 1199),
|
||||
}
|
||||
if task == "gpqa":
|
||||
NPROBE_RANGES = {
|
||||
0.85: range(40, 70),
|
||||
0.90: range(60, 100),
|
||||
0.95: range(200, 500),
|
||||
}
|
||||
elif task == "hotpot":
|
||||
NPROBE_RANGES = {
|
||||
0.85: range(450, 480),
|
||||
0.90: range(1000, 1300),
|
||||
0.95: range(2000, 4000),
|
||||
}
|
||||
elif task == "trivia":
|
||||
NPROBE_RANGES = {
|
||||
0.85: range(100, 400),
|
||||
0.90: range(700, 1800),
|
||||
0.95: range(506, 1432)
|
||||
}
|
||||
|
||||
# Create a directory for logs if it doesn't exist
|
||||
os.makedirs("nprobe_logs", exist_ok=True)
|
||||
|
||||
# Set up signal handling for clean termination
|
||||
def signal_handler(sig, frame):
|
||||
print("Received termination signal. Cleaning up running processes...")
|
||||
for pid, process in running_processes.items():
|
||||
try:
|
||||
if process.poll() is None: # Process is still running
|
||||
print(f"Terminating process {pid}...")
|
||||
process.terminate()
|
||||
time.sleep(0.5)
|
||||
if process.poll() is None: # If still running after terminate
|
||||
print(f"Killing process {pid}...")
|
||||
process.kill()
|
||||
|
||||
# Kill any child processes
|
||||
try:
|
||||
parent = psutil.Process(pid)
|
||||
children = parent.children(recursive=True)
|
||||
for child in children:
|
||||
print(f"Killing child process {child.pid}...")
|
||||
child.kill()
|
||||
except (psutil.NoSuchProcess, psutil.AccessDenied):
|
||||
pass
|
||||
except:
|
||||
pass
|
||||
|
||||
print("All processes terminated. Exiting.")
|
||||
sys.exit(1)
|
||||
|
||||
# Register signal handlers
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
signal.signal(signal.SIGTERM, signal_handler)
|
||||
|
||||
def run_batch_demo(nprobe: int, retry: int = 0) -> Optional[float]:
|
||||
"""Run main.py in batch mode with a specific nprobe value and extract the recall."""
|
||||
command = f"python -u ./demo/main.py --search-only --load-indices {args.index_type} --domain rpj_wiki --lazy-load-passages --nprobe {nprobe} --task {task} --skip-passages"
|
||||
if pq_compressed is not None:
|
||||
command += f" --diskann-search-memory-maximum {pq_compressed}"
|
||||
if beam_width is not None:
|
||||
command += f" --diskann-beam-width {beam_width}"
|
||||
if args.index_type == "hnsw":
|
||||
command += f" --hnsw-old"
|
||||
# command += " --embedder intfloat/multilingual-e5-small"
|
||||
|
||||
cmd = [
|
||||
"fish", "-c",
|
||||
# f"set -gx LD_PRELOAD \"/lib/x86_64-linux-gnu/libmkl_core.so /lib/x86_64-linux-gnu/libmkl_intel_lp64.so /lib/x86_64-linux-gnu/libmkl_intel_thread.so /lib/x86_64-linux-gnu/libiomp5.so\" && "
|
||||
"source ./.venv/bin/activate.fish &&"
|
||||
+ command
|
||||
]
|
||||
|
||||
print(f"Running with nprobe={nprobe}, beam_width={beam_width}, retry={retry}/{retry_count}")
|
||||
log_file = f"nprobe_logs/nprobe_{nprobe}_beam{beam_width}_{path_suffix}_retry{retry}.log"
|
||||
|
||||
try:
|
||||
# Also save the command to the log file
|
||||
with open(log_file, "w") as f:
|
||||
f.write(f"Command: {cmd[1]}\n\n")
|
||||
f.write(f"Start time: {time.strftime('%Y-%m-%d %H:%M:%S')}\n\n")
|
||||
f.write("=== OUTPUT BEGINS ===\n")
|
||||
|
||||
# Run the command and tee the output to both stdout and the log file
|
||||
with open(log_file, "a") as f:
|
||||
process = subprocess.Popen(
|
||||
cmd,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
text=True,
|
||||
bufsize=1 # Line buffered
|
||||
)
|
||||
|
||||
# Register the process for cleanup
|
||||
pid = process.pid
|
||||
running_processes[pid] = process
|
||||
|
||||
# Process output line by line for real-time logging
|
||||
if process.stdout: # Check if stdout is not None
|
||||
# Set a timeout
|
||||
start_time = time.time()
|
||||
current_output = ""
|
||||
|
||||
while process.poll() is None:
|
||||
# Check for timeout
|
||||
if time.time() - start_time > timeout:
|
||||
print(f"Process timeout for nprobe={nprobe}, killing...")
|
||||
f.write("\n\nProcess timed out, killing...\n")
|
||||
process.terminate()
|
||||
time.sleep(0.5)
|
||||
if process.poll() is None:
|
||||
process.kill()
|
||||
|
||||
# Clean up child processes
|
||||
try:
|
||||
parent = psutil.Process(pid)
|
||||
children = parent.children(recursive=True)
|
||||
for child in children:
|
||||
child.kill()
|
||||
except (psutil.NoSuchProcess, psutil.AccessDenied):
|
||||
pass
|
||||
|
||||
if pid in running_processes:
|
||||
del running_processes[pid]
|
||||
|
||||
# Retry if we have attempts left
|
||||
if retry < retry_count:
|
||||
print(f"Retrying nprobe={nprobe}...")
|
||||
return run_batch_demo(nprobe, retry + 1)
|
||||
return None
|
||||
|
||||
# Read output with a small timeout to allow for process checking
|
||||
try:
|
||||
line = process.stdout.readline()
|
||||
if not line:
|
||||
time.sleep(0.1) # Small pause to avoid busy waiting
|
||||
continue
|
||||
|
||||
print(line, end='') # Print to console
|
||||
f.write(line) # Write to log file
|
||||
f.flush() # Make sure it's written immediately
|
||||
except:
|
||||
time.sleep(0.1)
|
||||
|
||||
exit_code = process.wait()
|
||||
|
||||
# Process complete, remove from running list
|
||||
if pid in running_processes:
|
||||
del running_processes[pid]
|
||||
|
||||
f.write(f"\nExit code: {exit_code}\n")
|
||||
f.write(f"End time: {time.strftime('%Y-%m-%d %H:%M:%S')}\n")
|
||||
|
||||
# Re-read the log file to extract recall rate
|
||||
with open(log_file, "r") as f:
|
||||
log_content = f.read()
|
||||
|
||||
# Try multiple patterns to find recall rate
|
||||
recall = None
|
||||
patterns = [
|
||||
fr"Avg recall rate for {args.index_type}: ([0-9.]+)",
|
||||
r"recall: ([0-9.]+)",
|
||||
fr"{args.index_type}.*?recall.*?([0-9.]+)",
|
||||
fr"recall.*?{args.index_type}.*?([0-9.]+)"
|
||||
]
|
||||
|
||||
for pattern in patterns:
|
||||
matches = re.findall(pattern, log_content, re.IGNORECASE)
|
||||
if matches:
|
||||
try:
|
||||
recall = float(matches[-1]) # Take the last one if multiple matches
|
||||
print(f"Found recall rate using pattern: {pattern}")
|
||||
break
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
if recall is None:
|
||||
print(f"Warning: Could not extract recall rate from output log {log_file}")
|
||||
# Try to find any number that looks like a recall rate (between 0 and 1)
|
||||
possible_recalls = re.findall(r"recall.*?([0-9]+\.[0-9]+)", log_content, re.IGNORECASE)
|
||||
if possible_recalls:
|
||||
try:
|
||||
recall_candidates = [float(r) for r in possible_recalls if 0 <= float(r) <= 1]
|
||||
if recall_candidates:
|
||||
recall = recall_candidates[-1] # Take the last one
|
||||
print(f"Guessed recall rate: {recall} (based on pattern matching)")
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
if recall is None:
|
||||
# Log this failure with more context
|
||||
with open("nprobe_logs/failed_recalls.log", "a") as f:
|
||||
f.write(f"Failed to extract recall for nprobe={nprobe} at {time.strftime('%Y-%m-%d %H:%M:%S')}\n")
|
||||
|
||||
# Retry if we have attempts left
|
||||
if retry < retry_count:
|
||||
print(f"Retrying nprobe={nprobe} due to failed recall extraction...")
|
||||
return run_batch_demo(nprobe, retry + 1)
|
||||
|
||||
return None
|
||||
|
||||
print(f"nprobe={nprobe}, recall={recall:.4f}")
|
||||
return recall
|
||||
|
||||
except subprocess.TimeoutExpired:
|
||||
print(f"Command timed out for nprobe={nprobe}")
|
||||
with open(log_file, "a") as f:
|
||||
f.write("\n\nCommand timed out after 1800 seconds\n")
|
||||
|
||||
# Retry if we have attempts left
|
||||
if retry < retry_count:
|
||||
print(f"Retrying nprobe={nprobe}...")
|
||||
return run_batch_demo(nprobe, retry + 1)
|
||||
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error running command for nprobe={nprobe}: {e}")
|
||||
with open(log_file, "a") as f:
|
||||
f.write(f"\n\nError: {e}\n")
|
||||
|
||||
# Retry if we have attempts left
|
||||
if retry < retry_count:
|
||||
print(f"Retrying nprobe={nprobe} due to error: {e}...")
|
||||
return run_batch_demo(nprobe, retry + 1)
|
||||
|
||||
return None
|
||||
|
||||
def batch_run_nprobe_values(nprobe_values):
|
||||
"""Run multiple nprobe values in parallel with a thread pool."""
|
||||
results = {}
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
future_to_nprobe = {executor.submit(run_batch_demo, nprobe): nprobe for nprobe in nprobe_values}
|
||||
|
||||
for future in concurrent.futures.as_completed(future_to_nprobe):
|
||||
nprobe = future_to_nprobe[future]
|
||||
try:
|
||||
recall = future.result()
|
||||
if recall is not None:
|
||||
results[nprobe] = recall
|
||||
print(f"Completed nprobe={nprobe} with recall={recall:.4f}")
|
||||
except Exception as e:
|
||||
print(f"Error processing nprobe={nprobe}: {e}")
|
||||
|
||||
return results
|
||||
|
||||
def adaptive_search_nprobe(target_recall: float, min_nprobe: int, max_nprobe: int, tolerance: float = 0.001) -> Dict:
|
||||
"""
|
||||
Use an adaptive search strategy to find the optimal nprobe value for a target recall.
|
||||
Combines binary search with exploration to handle non-linear relationships.
|
||||
|
||||
Args:
|
||||
target_recall: The target recall to achieve
|
||||
min_nprobe: Minimum nprobe value to start search
|
||||
max_nprobe: Maximum nprobe value for search
|
||||
tolerance: How close we need to get to the target_recall
|
||||
|
||||
Returns:
|
||||
Dictionary with the best nprobe, achieved recall, and other metadata
|
||||
"""
|
||||
print(f"\nAdaptive searching for nprobe that achieves {target_recall*100:.1f}% recall...")
|
||||
print(f"Search range: {min_nprobe} - {max_nprobe}")
|
||||
|
||||
with open(f"nprobe_logs/summary_{path_suffix}.log", "a") as f:
|
||||
f.write(f"\nAdaptive searching for nprobe that achieves {target_recall*100:.1f}% recall...\n")
|
||||
f.write(f"Search range: {min_nprobe} - {max_nprobe}\n")
|
||||
|
||||
best_result = {"nprobe": None, "recall": None, "difference": float('inf')}
|
||||
all_results = {"nprobe": [], "recall": []}
|
||||
|
||||
# Save initial file for this search
|
||||
search_results_file = f"nprobe_logs/search_results_{path_suffix}_{target_recall:.2f}.json"
|
||||
search_data = {
|
||||
"target": target_recall,
|
||||
"current_best": best_result,
|
||||
"all_results": all_results,
|
||||
"timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
|
||||
"search_range": {"min": min_nprobe, "max": max_nprobe}
|
||||
}
|
||||
|
||||
with open(search_results_file, "w") as f:
|
||||
json.dump(search_data, f, indent=2)
|
||||
|
||||
# Start with a strategic sampling to understand the recall curve
|
||||
# Choose more points if the range is large
|
||||
range_size = max_nprobe - min_nprobe
|
||||
if range_size > 500:
|
||||
num_initial_samples = 5
|
||||
elif range_size > 100:
|
||||
num_initial_samples = 4
|
||||
else:
|
||||
num_initial_samples = 3
|
||||
|
||||
sample_points = [min_nprobe]
|
||||
step = range_size // (num_initial_samples - 1)
|
||||
for i in range(1, num_initial_samples - 1):
|
||||
sample_points.append(min_nprobe + i * step)
|
||||
sample_points.append(max_nprobe)
|
||||
|
||||
# Run initial sample points in parallel
|
||||
initial_results = batch_run_nprobe_values(sample_points)
|
||||
|
||||
# Update all_results and best_result based on initial_results
|
||||
for nprobe, recall in initial_results.items():
|
||||
all_results["nprobe"].append(nprobe)
|
||||
all_results["recall"].append(recall)
|
||||
|
||||
diff = abs(recall - target_recall)
|
||||
if diff < best_result["difference"]:
|
||||
best_result = {"nprobe": nprobe, "recall": recall, "difference": diff}
|
||||
|
||||
# Update search results file
|
||||
search_data = {
|
||||
"target": target_recall,
|
||||
"current_best": best_result,
|
||||
"all_results": all_results,
|
||||
"timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
|
||||
"search_range": {"min": min_nprobe, "max": max_nprobe}
|
||||
}
|
||||
with open(search_results_file, "w") as f:
|
||||
json.dump(search_data, f, indent=2)
|
||||
|
||||
# Check if we've already reached target within tolerance
|
||||
if best_result["difference"] <= tolerance:
|
||||
print(f"Found good enough nprobe value: {best_result['nprobe']} with recall {best_result['recall']:.4f}")
|
||||
return best_result
|
||||
|
||||
# Analyze initial results to decide on next strategy
|
||||
# Sort results by nprobe
|
||||
sorted_results = sorted([(n, r) for n, r in zip(all_results["nprobe"], all_results["recall"])])
|
||||
nprobes, recalls = zip(*sorted_results)
|
||||
|
||||
# Check if the relationship is monotonic
|
||||
is_monotonic = all(recalls[i] <= recalls[i+1] for i in range(len(recalls)-1)) or \
|
||||
all(recalls[i] >= recalls[i+1] for i in range(len(recalls)-1))
|
||||
|
||||
if is_monotonic:
|
||||
print("Relationship appears monotonic, proceeding with binary search.")
|
||||
# Find the two closest points that bracket the target
|
||||
bracket_low, bracket_high = None, None
|
||||
for i in range(len(recalls)-1):
|
||||
if (recalls[i] <= target_recall <= recalls[i+1]) or (recalls[i] >= target_recall >= recalls[i+1]):
|
||||
bracket_low, bracket_high = nprobes[i], nprobes[i+1]
|
||||
break
|
||||
|
||||
if bracket_low is None:
|
||||
# Target is outside our current range, adjust range
|
||||
if all(r < target_recall for r in recalls):
|
||||
# All recalls are too low, need to increase nprobe
|
||||
bracket_low = nprobes[-1]
|
||||
bracket_high = min(max_nprobe, nprobes[-1] * 2)
|
||||
else:
|
||||
# All recalls are too high, need to decrease nprobe
|
||||
bracket_low = max(min_nprobe, nprobes[0] // 2)
|
||||
bracket_high = nprobes[0]
|
||||
|
||||
# Binary search between bracket_low and bracket_high
|
||||
while abs(bracket_high - bracket_low) > 3:
|
||||
mid_nprobe = (bracket_low + bracket_high) // 2
|
||||
if mid_nprobe in initial_results:
|
||||
mid_recall = initial_results[mid_nprobe]
|
||||
else:
|
||||
mid_recall = run_batch_demo(mid_nprobe)
|
||||
if mid_recall is not None:
|
||||
all_results["nprobe"].append(mid_nprobe)
|
||||
all_results["recall"].append(mid_recall)
|
||||
|
||||
diff = abs(mid_recall - target_recall)
|
||||
if diff < best_result["difference"]:
|
||||
best_result = {"nprobe": mid_nprobe, "recall": mid_recall, "difference": diff}
|
||||
|
||||
# Update search results file
|
||||
search_data["current_best"] = best_result
|
||||
search_data["all_results"] = all_results
|
||||
with open(search_results_file, "w") as f:
|
||||
json.dump(search_data, f, indent=2)
|
||||
|
||||
# Check if we're close enough
|
||||
if mid_recall is not None:
|
||||
if abs(mid_recall - target_recall) <= tolerance:
|
||||
break
|
||||
|
||||
# Adjust brackets
|
||||
if mid_recall < target_recall:
|
||||
bracket_low = mid_nprobe
|
||||
else:
|
||||
bracket_high = mid_nprobe
|
||||
else:
|
||||
# If we failed to get a result, try a different point
|
||||
bracket_high = mid_nprobe - 1
|
||||
else:
|
||||
print("Relationship appears non-monotonic, using adaptive sampling.")
|
||||
# For non-monotonic relationships, we'll use adaptive sampling
|
||||
# First, find the best current point
|
||||
best_idx = recalls.index(min(recalls, key=lambda r: abs(r - target_recall)))
|
||||
best_nprobe = nprobes[best_idx]
|
||||
|
||||
# Try points around the best point with decreasing radius
|
||||
radius = max(50, (max_nprobe - min_nprobe) // 10)
|
||||
min_radius = 3
|
||||
|
||||
while radius >= min_radius:
|
||||
# Try points at current radius around best_nprobe
|
||||
test_points = []
|
||||
lower_bound = max(min_nprobe, best_nprobe - radius)
|
||||
upper_bound = min(max_nprobe, best_nprobe + radius)
|
||||
|
||||
if lower_bound not in initial_results and lower_bound != best_nprobe:
|
||||
test_points.append(lower_bound)
|
||||
if upper_bound not in initial_results and upper_bound != best_nprobe:
|
||||
test_points.append(upper_bound)
|
||||
|
||||
# Add a point in the middle if range is large enough
|
||||
if upper_bound - lower_bound > 2*radius/3 and len(test_points) < max_workers:
|
||||
mid_point = (lower_bound + upper_bound) // 2
|
||||
if mid_point not in initial_results and mid_point != best_nprobe:
|
||||
test_points.append(mid_point)
|
||||
|
||||
# Run tests
|
||||
if test_points:
|
||||
new_results = batch_run_nprobe_values(test_points)
|
||||
initial_results.update(new_results)
|
||||
|
||||
# Update all_results and best_result
|
||||
for nprobe, recall in new_results.items():
|
||||
all_results["nprobe"].append(nprobe)
|
||||
all_results["recall"].append(recall)
|
||||
|
||||
diff = abs(recall - target_recall)
|
||||
if diff < best_result["difference"]:
|
||||
best_result = {"nprobe": nprobe, "recall": recall, "difference": diff}
|
||||
best_nprobe = nprobe # Update the center for next iteration
|
||||
|
||||
# Update search results file
|
||||
search_data["current_best"] = best_result
|
||||
search_data["all_results"] = all_results
|
||||
with open(search_results_file, "w") as f:
|
||||
json.dump(search_data, f, indent=2)
|
||||
|
||||
# Check if we're close enough
|
||||
if best_result["difference"] <= tolerance:
|
||||
break
|
||||
|
||||
# Reduce radius for next iteration
|
||||
radius = max(min_radius, radius // 2)
|
||||
|
||||
# After search, do a final fine-tuning around the best result
|
||||
if best_result["nprobe"] is not None:
|
||||
fine_tune_range = range(max(min_nprobe, best_result["nprobe"] - 2),
|
||||
min(max_nprobe, best_result["nprobe"] + 3))
|
||||
|
||||
fine_tune_points = [n for n in fine_tune_range if n not in all_results["nprobe"]]
|
||||
if fine_tune_points:
|
||||
fine_tune_results = batch_run_nprobe_values(fine_tune_points)
|
||||
|
||||
for nprobe, recall in fine_tune_results.items():
|
||||
all_results["nprobe"].append(nprobe)
|
||||
all_results["recall"].append(recall)
|
||||
|
||||
diff = abs(recall - target_recall)
|
||||
if diff < best_result["difference"]:
|
||||
best_result = {"nprobe": nprobe, "recall": recall, "difference": diff}
|
||||
|
||||
# Final update to search results file
|
||||
search_data["current_best"] = best_result
|
||||
search_data["all_results"] = all_results
|
||||
search_data["search_range"] = {"min": min_nprobe, "max": max_nprobe, "phase": "fine_tune"}
|
||||
with open(search_results_file, "w") as f:
|
||||
json.dump(search_data, f, indent=2)
|
||||
|
||||
return best_result
|
||||
|
||||
def find_optimal_nprobe_values():
|
||||
"""Find the optimal nprobe values for target recall rates using adaptive search."""
|
||||
# Dictionary to store results for each target recall
|
||||
results = {}
|
||||
# Dictionary to store all nprobe-recall pairs for plotting
|
||||
all_data = {target: {"nprobe": [], "recall": []} for target in TARGET_RECALLS}
|
||||
|
||||
# Create a summary file for all runs
|
||||
with open(f"nprobe_logs/summary_{path_suffix}.log", "w") as f:
|
||||
f.write(f"Find optimal nprobe values - started at {time.strftime('%Y-%m-%d %H:%M:%S')}\n\n")
|
||||
f.write(f"Target recalls: {TARGET_RECALLS}\n")
|
||||
f.write(f"nprobe ranges: {NPROBE_RANGES}\n\n")
|
||||
f.write(f"Max workers: {max_workers}\n")
|
||||
f.write(f"Timeout per process: {timeout}s\n")
|
||||
f.write(f"Retry count: {retry_count}\n\n")
|
||||
|
||||
for target in TARGET_RECALLS:
|
||||
# Use the existing NPROBE_RANGES to determine min and max values
|
||||
min_nprobe = min(NPROBE_RANGES[target])
|
||||
max_nprobe = max(NPROBE_RANGES[target])
|
||||
|
||||
print(f"\nUsing NPROBE_RANGES for target {target*100:.1f}%: {min_nprobe} to {max_nprobe}")
|
||||
|
||||
# Run adaptive search instead of binary search
|
||||
best_result = adaptive_search_nprobe(
|
||||
target_recall=target,
|
||||
min_nprobe=min_nprobe,
|
||||
max_nprobe=max_nprobe
|
||||
)
|
||||
|
||||
results[target] = best_result
|
||||
|
||||
# Save all tested points to all_data for plotting
|
||||
search_results_file = f"nprobe_logs/search_results_{path_suffix}_{target:.2f}.json"
|
||||
try:
|
||||
with open(search_results_file, "r") as f:
|
||||
search_data = json.load(f)
|
||||
if "all_results" in search_data:
|
||||
all_data[target]["nprobe"] = search_data["all_results"]["nprobe"]
|
||||
all_data[target]["recall"] = search_data["all_results"]["recall"]
|
||||
except (FileNotFoundError, json.JSONDecodeError) as e:
|
||||
print(f"Warning: Could not load search results for {target}: {e}")
|
||||
|
||||
print(f"For target recall {target*100:.1f}%:")
|
||||
print(f" Best nprobe value: {best_result['nprobe']}")
|
||||
print(f" Achieved recall: {best_result['recall']:.4f}")
|
||||
print(f" Difference: {best_result['difference']:.4f}")
|
||||
|
||||
with open(f"nprobe_logs/summary_{path_suffix}.log", "a") as f:
|
||||
f.write(f"For target recall {target*100:.1f}%:\n")
|
||||
f.write(f" Best nprobe value: {best_result['nprobe']}\n")
|
||||
f.write(f" Achieved recall: {best_result['recall']:.4f}\n")
|
||||
f.write(f" Difference: {best_result['difference']:.4f}\n")
|
||||
|
||||
# Plot the results if we have data
|
||||
if all_data and any(data["nprobe"] for data in all_data.values()):
|
||||
plt.figure(figsize=(10, 6))
|
||||
|
||||
# Plot each target's data
|
||||
for target in TARGET_RECALLS:
|
||||
if not all_data[target]["nprobe"]:
|
||||
continue
|
||||
|
||||
nprobe_values = all_data[target]["nprobe"]
|
||||
recall_values = all_data[target]["recall"]
|
||||
|
||||
# Sort data points for better visualization
|
||||
sorted_points = sorted(zip(nprobe_values, recall_values))
|
||||
sorted_nprobe, sorted_recall = zip(*sorted_points) if sorted_points else ([], [])
|
||||
|
||||
plt.plot(sorted_nprobe, sorted_recall, 'o-',
|
||||
label=f"Target {target*100:.1f}%, Best={results[target]['nprobe']}")
|
||||
|
||||
# Mark the optimal point
|
||||
opt_nprobe = results[target]["nprobe"]
|
||||
opt_recall = results[target]["recall"]
|
||||
plt.plot(opt_nprobe, opt_recall, 'r*', markersize=15)
|
||||
|
||||
# Add a horizontal line at the target recall
|
||||
plt.axhline(y=target, color='gray', linestyle='--', alpha=0.5)
|
||||
|
||||
plt.xlabel('nprobe value')
|
||||
plt.ylabel('Recall rate')
|
||||
plt.title(f'Recall Rate vs nprobe Value (Max Workers: {max_workers})')
|
||||
plt.grid(True)
|
||||
plt.legend()
|
||||
plt.savefig(f'nprobe_logs/nprobe_vs_recall_{path_suffix}.png')
|
||||
print(f"Plot saved to nprobe_logs/nprobe_vs_recall_{path_suffix}.png")
|
||||
else:
|
||||
print("No data to plot.")
|
||||
with open(f"nprobe_logs/summary_{path_suffix}.log", "a") as f:
|
||||
f.write("No data to plot.\n")
|
||||
|
||||
# Save final results
|
||||
with open(f"nprobe_logs/optimal_nprobe_values_{path_suffix}.json", "w") as f:
|
||||
json.dump(results, f, indent=2)
|
||||
|
||||
with open(f"nprobe_logs/summary_{path_suffix}.log", "a") as f:
|
||||
f.write(f"\nFind optimal nprobe values - finished at {time.strftime('%Y-%m-%d %H:%M:%S')}\n")
|
||||
if results:
|
||||
f.write("\nOptimal nprobe values for target recall rates:\n")
|
||||
for target, data in results.items():
|
||||
f.write(f"{target*100:.1f}% recall: nprobe={data['nprobe']} (actual recall: {data['recall']:.4f})\n")
|
||||
else:
|
||||
f.write("No optimal nprobe values found.\n")
|
||||
|
||||
return results
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
results = find_optimal_nprobe_values()
|
||||
|
||||
if not results:
|
||||
print("No optimal nprobe values found.")
|
||||
sys.exit(1)
|
||||
|
||||
print("\nOptimal nprobe values for target recall rates:")
|
||||
for target, data in results.items():
|
||||
print(f"{target*100:.1f}% recall: nprobe={data['nprobe']} (actual recall: {data['recall']:.4f})")
|
||||
|
||||
# Generate the command for running the latency test with the optimal nprobe values
|
||||
optimal_values = [data["nprobe"] for target, data in sorted(results.items())]
|
||||
test_cmd = f"source ./.venv/bin/activate.fish && cd ~ && python ./Power-RAG/demo/test_serve.py --nprobe_values {' '.join(map(str, optimal_values))}"
|
||||
|
||||
print("\nRun this command to test latency with the optimal nprobe values:")
|
||||
print(test_cmd)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\nScript interrupted by user. Cleaning up running processes...")
|
||||
signal_handler(signal.SIGINT, None)
|
||||
sys.exit(1)
|
||||
except Exception as e:
|
||||
# Clean up any running processes before re-raising
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
signal_handler(signal.SIGINT, None)
|
||||
raise e
|
||||
60
research/utils/generate_dataset_cache.fish
Normal file
60
research/utils/generate_dataset_cache.fish
Normal file
@@ -0,0 +1,60 @@
|
||||
#!/usr/bin/env fish
|
||||
|
||||
# Set default parameters
|
||||
set domain "rpj_wiki"
|
||||
set embedder "facebook/contriever-msmarco"
|
||||
set k 5
|
||||
set tasks "nq" "trivia" "hotpot" "gpqa"
|
||||
|
||||
# Parse command line arguments
|
||||
for i in (seq 1 (count $argv))
|
||||
switch $argv[$i]
|
||||
case "--domain"
|
||||
set domain $argv[(math $i + 1)]
|
||||
case "--embedder"
|
||||
set embedder $argv[(math $i + 1)]
|
||||
case "--k"
|
||||
set k $argv[(math $i + 1)]
|
||||
case "--tasks"
|
||||
set j (math $i + 1)
|
||||
set tasks
|
||||
while test $j -le (count $argv) && not string match -q -- "--*" $argv[$j]
|
||||
set -a tasks $argv[$j]
|
||||
set j (math $j + 1)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
echo "Running with the following parameters:"
|
||||
echo "Domain: $domain"
|
||||
echo "Embedder: $embedder"
|
||||
echo "k: $k"
|
||||
echo "Datasets: $tasks"
|
||||
|
||||
# Create directory for results
|
||||
set results_dir "retrieval_results"
|
||||
mkdir -p $results_dir
|
||||
|
||||
# Process each dataset using retrieval_demo directly
|
||||
for task in $tasks
|
||||
echo ""
|
||||
echo "===== Processing dataset: $task ====="
|
||||
|
||||
# Step 1: Run retrieval_demo with flat index to generate cache and get results
|
||||
echo "Running retrieval for $task..."
|
||||
echo "python demo/main.py --domain $domain --task $task --search --load flat --lazy"
|
||||
python demo/main.py --domain $domain --task $task --search --load flat --lazy
|
||||
|
||||
# Check if successful
|
||||
if test $status -ne 0
|
||||
echo "Retrieval for $task failed"
|
||||
continue
|
||||
end
|
||||
|
||||
echo "Completed processing for $task"
|
||||
echo "--------------------------------"
|
||||
end
|
||||
|
||||
echo "All operations completed successfully!"
|
||||
echo "The cache files have been created at the locations specified by get_flat_cache_path() in config.py"
|
||||
echo "You can now use test_all_datasets.py to view the results"
|
||||
61
research/utils/mem_monitor.fish
Normal file
61
research/utils/mem_monitor.fish
Normal file
@@ -0,0 +1,61 @@
|
||||
#!/usr/bin/env fish
|
||||
|
||||
# 用法: ./mem_monitor.fish <PID> [interval_seconds=5]
|
||||
# 比如: ./mem_monitor.fish 110303 5
|
||||
# 会在当前目录写一个 mem_usage_110303.log,每隔5秒记录一次RSS和VSZ(单位MB)。
|
||||
|
||||
function usage
|
||||
echo "用法: mem_monitor.fish <PID> [interval_seconds=5]"
|
||||
exit 1
|
||||
end
|
||||
|
||||
if test (count $argv) -lt 1
|
||||
usage
|
||||
end
|
||||
|
||||
set pid $argv[1]
|
||||
set interval 5
|
||||
if test (count $argv) -gt 1
|
||||
set interval $argv[2]
|
||||
end
|
||||
|
||||
# 输出到 mem_usage_<pid>.log
|
||||
set logfile mem_usage_$pid.log
|
||||
echo "写入日志: $logfile"
|
||||
echo "timestamp,rss_MB,vms_MB" > $logfile
|
||||
|
||||
# 轮询检查
|
||||
while true
|
||||
# 获取 RSS/VSZ (KB) 值
|
||||
# 兼容 macOS 的 ps 命令,不使用 Linux 特有的选项
|
||||
set proc_line (ps -p $pid -o rss,vsz | tail -n +2 2>/dev/null)
|
||||
|
||||
# 若取不到(进程已退出),则停止
|
||||
if test -z "$proc_line"
|
||||
echo "进程 $pid 已退出或不存在."
|
||||
exit 0
|
||||
end
|
||||
|
||||
# 将单行字符串通过空格拆分为数组,如 ("79673856" "95904664")
|
||||
set arr (string split ' ' (string trim $proc_line))
|
||||
if test (count $arr) -lt 2
|
||||
echo "解析 ps 输出时出现意外: $proc_line"
|
||||
exit 1
|
||||
end
|
||||
|
||||
# 分别赋值
|
||||
set rss_kb $arr[1]
|
||||
set vsz_kb $arr[2]
|
||||
|
||||
# 时间戳
|
||||
set t (date "+%Y-%m-%d %H:%M:%S")
|
||||
|
||||
# 转换成 MB
|
||||
set rss_MB (math "$rss_kb / 1024.0")
|
||||
set vsz_MB (math "$vsz_kb / 1024.0")
|
||||
|
||||
# 写日志
|
||||
echo "$t,$rss_MB,$vsz_MB" >> $logfile
|
||||
|
||||
sleep $interval
|
||||
end
|
||||
76
research/utils/modelfile.fish
Normal file
76
research/utils/modelfile.fish
Normal file
@@ -0,0 +1,76 @@
|
||||
#!/usr/bin/env fish
|
||||
|
||||
# 创建无chat template的纯llama模型,对齐sglang的generate行为
|
||||
# 支持llama3.2的1b和3b型号,以及llama3.1的8b型号
|
||||
|
||||
function create_llama32_model
|
||||
set model_size $argv[1]
|
||||
set model_name "llama3.2:$model_size-pure"
|
||||
|
||||
echo "创建 $model_name 模型..."
|
||||
|
||||
# 创建临时Modelfile
|
||||
echo "FROM llama3.2:$model_size
|
||||
TEMPLATE \"\"
|
||||
PARAMETER stop \"\"
|
||||
PARAMETER stop \"<|start_header_id|>\"
|
||||
PARAMETER stop \"<|end_header_id|>\"
|
||||
PARAMETER stop \"<|eot_id|>\"
|
||||
PARAMETER stop \"USER:\"
|
||||
PARAMETER stop \"ASSISTANT:\"
|
||||
PARAMETER temperature 1.0
|
||||
PARAMETER num_ctx 4096
|
||||
PARAMETER seed 1234
|
||||
PARAMETER num_predict 100" > Modelfile
|
||||
|
||||
# 创建模型
|
||||
ollama create $model_name -f ./Modelfile
|
||||
|
||||
# 清理临时文件
|
||||
rm Modelfile
|
||||
|
||||
echo "$model_name 创建完成"
|
||||
end
|
||||
|
||||
function create_llama31_model
|
||||
set model_size $argv[1]
|
||||
set model_name "llama3.1:$model_size-pure"
|
||||
|
||||
echo "创建 $model_name 模型..."
|
||||
|
||||
# 创建临时Modelfile
|
||||
echo "FROM llama3.1:$model_size
|
||||
TEMPLATE \"\"
|
||||
PARAMETER stop \"\"
|
||||
PARAMETER stop \"<|start_header_id|>\"
|
||||
PARAMETER stop \"<|end_header_id|>\"
|
||||
PARAMETER stop \"<|eot_id|>\"
|
||||
PARAMETER stop \"USER:\"
|
||||
PARAMETER stop \"ASSISTANT:\"
|
||||
PARAMETER temperature 1.0
|
||||
PARAMETER num_ctx 4096
|
||||
PARAMETER seed 1234
|
||||
PARAMETER num_predict 100" > Modelfile
|
||||
|
||||
# 创建模型
|
||||
ollama create $model_name -f ./Modelfile
|
||||
|
||||
# 清理临时文件
|
||||
rm Modelfile
|
||||
|
||||
echo "$model_name 创建完成"
|
||||
end
|
||||
|
||||
# 创建Llama 3.2的1b和3b模型
|
||||
for size in 1b 3b
|
||||
create_llama32_model $size
|
||||
end
|
||||
|
||||
# 创建Llama 3.1的8b模型
|
||||
create_llama31_model 8b
|
||||
|
||||
echo "完成! 所有纯文本llama模型已创建"
|
||||
echo "使用方法: "
|
||||
echo "- ollama run llama3.2:1b-pure \"你的提示\""
|
||||
echo "- ollama run llama3.2:3b-pure \"你的提示\""
|
||||
echo "- ollama run llama3.1:8b-pure \"你的提示\""
|
||||
193
research/utils/parse_logs_draw_skip_reorder.py
Normal file
193
research/utils/parse_logs_draw_skip_reorder.py
Normal file
@@ -0,0 +1,193 @@
|
||||
import argparse
|
||||
import re
|
||||
import matplotlib.pyplot as plt
|
||||
import os
|
||||
|
||||
|
||||
def parse_log(log_file_path):
|
||||
"""
|
||||
Parses the log file to extract relevant data for accuracy-recall curve comparison.
|
||||
Args:
|
||||
log_file_path (str): Path to the log file to parse.
|
||||
Returns:
|
||||
dict: A dictionary containing extracted results.
|
||||
"""
|
||||
data = {
|
||||
"recalls_with_skip": [],
|
||||
"f1_scores_with_skip": [],
|
||||
"exact_match_scores_with_skip": [],
|
||||
"recalls_without_skip": [],
|
||||
"f1_scores_without_skip": [],
|
||||
"exact_match_scores_without_skip": [],
|
||||
"nprobe_values": [],
|
||||
}
|
||||
|
||||
with open(log_file_path, "r") as file:
|
||||
logs = file.readlines()
|
||||
|
||||
# Variables to track the state during parsing
|
||||
is_skip_reorder_true = False
|
||||
is_skip_reorder_false = False
|
||||
current_nprobe = None
|
||||
|
||||
for line in logs:
|
||||
# Debug: print the current line being processed
|
||||
# print(f"Processing line: {line.strip()}")
|
||||
|
||||
# Check for skip_reorder flag
|
||||
if "skip_search_reorder=True" in line:
|
||||
is_skip_reorder_true = True
|
||||
is_skip_reorder_false = False
|
||||
elif "skip_search_reorder=False" in line:
|
||||
is_skip_reorder_true = False
|
||||
is_skip_reorder_false = True
|
||||
|
||||
# Extract nprobe values (assuming they are given before the experiment)
|
||||
nprobe_match = re.search(r"nprobe=(\d+)", line)
|
||||
if nprobe_match:
|
||||
current_nprobe = int(nprobe_match.group(1))
|
||||
if current_nprobe not in data["nprobe_values"]:
|
||||
data["nprobe_values"].append(current_nprobe)
|
||||
print(f"Found nprobe value: {current_nprobe}")
|
||||
|
||||
# Extract average recall rate
|
||||
avg_recall_match = re.search(
|
||||
r"Avg recall rate for (flat|diskann): ([0-9\.e\-]+)", line
|
||||
)
|
||||
if avg_recall_match:
|
||||
recall_value = float(avg_recall_match.group(2))
|
||||
print(
|
||||
f"Found avg recall rate: {recall_value} for {avg_recall_match.group(1)} in line {line!r}"
|
||||
)
|
||||
|
||||
if "flat" in avg_recall_match.group(1):
|
||||
# data["recalls_without_skip"].append(recall_value)
|
||||
pass
|
||||
elif "diskann" in avg_recall_match.group(1):
|
||||
if is_skip_reorder_true:
|
||||
data["recalls_with_skip"].append(recall_value)
|
||||
elif is_skip_reorder_false:
|
||||
data["recalls_without_skip"].append(recall_value)
|
||||
|
||||
# Extract exact_match, f1, and recall scores from evaluation results
|
||||
eval_match = re.search(
|
||||
r"\{'exact_match': ([0-9\.]+), 'exact_match_stderr': [0-9\.]+, 'f1': ([0-9\.]+), 'f1_stderr': [0-9\.]+",
|
||||
line,
|
||||
)
|
||||
if eval_match:
|
||||
exact_match = float(eval_match.group(1))
|
||||
f1 = float(eval_match.group(2))
|
||||
|
||||
print(f"Found evaluation results -> Exact Match: {exact_match}, F1: {f1}")
|
||||
|
||||
# Add to appropriate list based on skip_reorder flag
|
||||
if is_skip_reorder_true:
|
||||
data["exact_match_scores_with_skip"].append(exact_match)
|
||||
data["f1_scores_with_skip"].append(f1)
|
||||
elif is_skip_reorder_false:
|
||||
data["exact_match_scores_without_skip"].append(exact_match)
|
||||
data["f1_scores_without_skip"].append(f1)
|
||||
|
||||
return data
|
||||
|
||||
|
||||
def plot_skip_reorder_comparison(data, output_dir):
|
||||
"""
|
||||
绘制带有和不带 skip_reorder 参数的准确率-召回率曲线。
|
||||
|
||||
Args:
|
||||
data: The parsed data including recalls, f1 scores, and exact match scores.
|
||||
output_dir: Path where the plot will be saved.
|
||||
"""
|
||||
recalls_with_skip = data["recalls_with_skip"]
|
||||
f1_scores_with_skip = data["f1_scores_with_skip"]
|
||||
exact_match_scores_with_skip = data["exact_match_scores_with_skip"]
|
||||
recalls_without_skip = data["recalls_without_skip"]
|
||||
f1_scores_without_skip = data["f1_scores_without_skip"]
|
||||
exact_match_scores_without_skip = data["exact_match_scores_without_skip"]
|
||||
nprobe_values = data["nprobe_values"]
|
||||
|
||||
plt.figure(figsize=(10, 6))
|
||||
|
||||
# Check if data lists are not empty and have the same length before plotting
|
||||
if (
|
||||
recalls_with_skip
|
||||
and len(recalls_with_skip) == len(f1_scores_with_skip)
|
||||
and len(recalls_with_skip) == len(exact_match_scores_with_skip)
|
||||
):
|
||||
plt.plot(
|
||||
recalls_with_skip,
|
||||
f1_scores_with_skip,
|
||||
"bo-",
|
||||
label="F1 Score (with skip_reorder)",
|
||||
markersize=8,
|
||||
linewidth=2,
|
||||
)
|
||||
plt.plot(
|
||||
recalls_with_skip,
|
||||
exact_match_scores_with_skip,
|
||||
"rs-",
|
||||
label="Exact Match (with skip_reorder)",
|
||||
markersize=8,
|
||||
linewidth=2,
|
||||
)
|
||||
|
||||
if (
|
||||
recalls_without_skip
|
||||
and len(recalls_without_skip) == len(f1_scores_without_skip)
|
||||
and len(recalls_without_skip) == len(exact_match_scores_without_skip)
|
||||
):
|
||||
plt.plot(
|
||||
recalls_without_skip,
|
||||
f1_scores_without_skip,
|
||||
"go-",
|
||||
label="F1 Score (without skip_reorder)",
|
||||
markersize=8,
|
||||
linewidth=2,
|
||||
)
|
||||
plt.plot(
|
||||
recalls_without_skip,
|
||||
exact_match_scores_without_skip,
|
||||
"ms-",
|
||||
label="Exact Match (without skip_reorder)",
|
||||
markersize=8,
|
||||
linewidth=2,
|
||||
)
|
||||
|
||||
plt.xlabel("Recall")
|
||||
plt.ylabel("Score")
|
||||
plt.title("Recall vs Accuracy Comparison")
|
||||
plt.legend()
|
||||
plt.grid(True)
|
||||
plt.xlim(0.0, 1.0)
|
||||
|
||||
# Save the plot only if data is present
|
||||
if len(nprobe_values) > 0:
|
||||
plot_path = os.path.join(
|
||||
output_dir,
|
||||
f"recall_vs_acc_comparison.png",
|
||||
)
|
||||
plt.savefig(plot_path, dpi=300, bbox_inches="tight")
|
||||
print(f"Plot saved to {plot_path}")
|
||||
else:
|
||||
print("No valid data to plot.")
|
||||
|
||||
plt.close()
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser(description="Parse log file and plot results")
|
||||
parser.add_argument(
|
||||
"log_file_path", type=str, help="Path to the log file"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_dir", type=str, help="Path to the output directory", default="skip_reorder_comparison"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Parse the log
|
||||
parsed_data = parse_log(args.log_file_path)
|
||||
|
||||
print(parsed_data)
|
||||
|
||||
# Plot the data
|
||||
plot_skip_reorder_comparison(parsed_data, args.output_dir)
|
||||
114
research/utils/plot_degree_distribution.py
Normal file
114
research/utils/plot_degree_distribution.py
Normal file
@@ -0,0 +1,114 @@
|
||||
import argparse
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import os
|
||||
|
||||
def plot_degree_distribution(degree_file_path: str, output_image_path: str):
|
||||
"""
|
||||
Reads a file containing node degrees (one per line) and plots the
|
||||
degree distribution as a histogram.
|
||||
|
||||
Args:
|
||||
degree_file_path: Path to the file containing degrees.
|
||||
output_image_path: Path to save the output plot image.
|
||||
"""
|
||||
try:
|
||||
# Read degrees from the file
|
||||
degrees = np.loadtxt(degree_file_path, dtype=int)
|
||||
print(f"[LOG] Read {len(degrees)} degrees from {degree_file_path}")
|
||||
|
||||
if len(degrees) == 0:
|
||||
print("[WARN] Degree file is empty. No plot generated.")
|
||||
return
|
||||
|
||||
# Calculate basic statistics
|
||||
min_deg = np.min(degrees)
|
||||
max_deg = np.max(degrees)
|
||||
avg_deg = np.mean(degrees)
|
||||
median_deg = np.median(degrees)
|
||||
|
||||
print(f"[LOG] Degree Stats: Min={min_deg}, Max={max_deg}, Avg={avg_deg:.2f}, Median={median_deg}")
|
||||
|
||||
# Plotting the distribution
|
||||
plt.figure(figsize=(10, 6))
|
||||
# Determine appropriate number of bins, maybe max_deg+1 if not too large
|
||||
# Or use automatic binning like 'auto' or Sturges' rule etc.
|
||||
# Using max_deg - min_deg + 1 bins can be too many if the range is large
|
||||
# Let's try 'auto' binning first
|
||||
n_bins = 'auto'
|
||||
# If max_deg is reasonably small, we can use exact bins
|
||||
if max_deg <= 1000: # Heuristic threshold
|
||||
n_bins = max_deg - min_deg + 1
|
||||
|
||||
counts, bin_edges, patches = plt.hist(degrees, bins=n_bins, edgecolor='black', alpha=0.7)
|
||||
plt.xlabel("Node Degree")
|
||||
plt.ylabel("Number of Nodes")
|
||||
plt.title(f"Degree Distribution (from {os.path.basename(degree_file_path)})")
|
||||
plt.grid(axis='y', linestyle='--', alpha=0.7)
|
||||
|
||||
# Add text for statistics on the plot
|
||||
stats_text = (
|
||||
f"Total Nodes: {len(degrees)}\n"
|
||||
f"Min Degree: {min_deg}\n"
|
||||
f"Max Degree: {max_deg}\n"
|
||||
f"Avg Degree: {avg_deg:.2f}\n"
|
||||
f"Median Degree: {median_deg}"
|
||||
)
|
||||
# Position the text box; adjust as needed
|
||||
plt.text(0.95, 0.95, stats_text, transform=plt.gca().transAxes,
|
||||
fontsize=9, verticalalignment='top', horizontalalignment='right',
|
||||
bbox=dict(boxstyle='round,pad=0.5', fc='wheat', alpha=0.5))
|
||||
|
||||
|
||||
plt.tight_layout()
|
||||
plt.savefig(output_image_path)
|
||||
print(f"[LOG] Degree distribution plot saved to {output_image_path}")
|
||||
# plt.show() # Uncomment if you want to display the plot interactively
|
||||
|
||||
# Create weighted degree distribution plot
|
||||
plt.figure(figsize=(10, 6))
|
||||
# Calculate weighted distribution (degree * number of nodes)
|
||||
unique_degrees, degree_counts = np.unique(degrees, return_counts=True)
|
||||
weighted_counts = unique_degrees * degree_counts
|
||||
|
||||
plt.bar(unique_degrees, weighted_counts, edgecolor='black', alpha=0.7)
|
||||
plt.xlabel("Node Degree")
|
||||
plt.ylabel("Degree × Number of Nodes")
|
||||
plt.title(f"Weighted Degree Distribution (from {os.path.basename(degree_file_path)})")
|
||||
plt.grid(axis='y', linestyle='--', alpha=0.7)
|
||||
|
||||
# Add text for statistics on the plot
|
||||
plt.text(0.95, 0.95, stats_text, transform=plt.gca().transAxes,
|
||||
fontsize=9, verticalalignment='top', horizontalalignment='right',
|
||||
bbox=dict(boxstyle='round,pad=0.5', fc='wheat', alpha=0.5))
|
||||
|
||||
# Generate weighted output filename based on the original output path
|
||||
weighted_output_path = os.path.splitext(output_image_path)[0] + "_weighted" + os.path.splitext(output_image_path)[1]
|
||||
|
||||
plt.tight_layout()
|
||||
plt.savefig(weighted_output_path)
|
||||
print(f"[LOG] Weighted degree distribution plot saved to {weighted_output_path}")
|
||||
|
||||
except FileNotFoundError:
|
||||
print(f"[ERROR] Degree file not found: {degree_file_path}")
|
||||
except Exception as e:
|
||||
print(f"[ERROR] An error occurred: {e}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Plot the degree distribution from a file containing node degrees."
|
||||
)
|
||||
parser.add_argument(
|
||||
"degree_file",
|
||||
type=str,
|
||||
help="Path to the input file containing node degrees (one degree per line)."
|
||||
)
|
||||
parser.add_argument(
|
||||
"-o", "--output",
|
||||
type=str,
|
||||
default="degree_distribution.png",
|
||||
help="Path to save the output plot image (default: degree_distribution.png)."
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
plot_degree_distribution(args.degree_file, args.output)
|
||||
244
research/utils/prepare_query_files.py
Normal file
244
research/utils/prepare_query_files.py
Normal file
@@ -0,0 +1,244 @@
|
||||
#!/usr/bin/env python3
|
||||
import os
|
||||
import json
|
||||
import argparse
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
# Ensure project root is in path for imports
|
||||
project_root = Path(__file__).resolve().parent.parent
|
||||
sys.path.append(str(project_root))
|
||||
|
||||
try:
|
||||
from datasets import load_dataset, Dataset, IterableDataset
|
||||
from tqdm import tqdm
|
||||
except ImportError:
|
||||
print("Error: Required libraries 'datasets' or 'tqdm' not found.")
|
||||
print("Please install them using: pip install datasets tqdm")
|
||||
sys.exit(1)
|
||||
|
||||
from demo.config import TASK_CONFIGS, get_example_path
|
||||
|
||||
# Color constants for output
|
||||
RED = "\033[91m"
|
||||
GREEN = "\033[92m"
|
||||
YELLOW = "\033[93m"
|
||||
RESET = "\033[0m"
|
||||
|
||||
# --- Dataset Specific Loading Configurations ---
|
||||
# You might need to adjust these based on the exact dataset structure on Hugging Face
|
||||
DATASET_LOAD_INFO = {
|
||||
"nq": {
|
||||
"hf_name": "nq_open",
|
||||
"split": "validation", # NQ Open doesn't have a standard HF dataset, usually custom splits are used.
|
||||
# This entry assumes a custom formatted source or might need manual creation.
|
||||
# Let's mark it as needing manual setup for now.
|
||||
"query_key": "question",
|
||||
"needs_manual_setup": True,
|
||||
"manual_setup_instructions": "nq_open requires a pre-formatted file. Please ensure it exists at the target path."
|
||||
},
|
||||
"trivia": {
|
||||
"hf_name": "trivia_qa",
|
||||
"subset": "rc.nocontext", # Use rc.nocontext as a valid config
|
||||
"split": "validation",
|
||||
"query_key": "question",
|
||||
"needs_manual_setup": False
|
||||
},
|
||||
"hotpot": {
|
||||
"hf_name": "hotpot_qa",
|
||||
"subset": "distractor", # Explicitly choose the 'distractor' config
|
||||
"split": "validation",
|
||||
"query_key": "question",
|
||||
"needs_manual_setup": False
|
||||
},
|
||||
"gpqa": {
|
||||
"hf_name": "Idavidrein/gpqa", # Corrected HF identifier
|
||||
"subset": "gpqa_main", # Use subset (name) for the configuration
|
||||
"split": "train", # Align with evaluation_demo
|
||||
"query_key": "Question", # CORRECTED: Use uppercase 'Q' as found in the dataset item
|
||||
"needs_manual_setup": False # Assuming this config loads correctly now
|
||||
},
|
||||
"retrievalqa": {
|
||||
"hf_name": "aialt/RetrievalQA",
|
||||
"split": "train",
|
||||
"query_key": "question",
|
||||
"needs_manual_setup": False,
|
||||
"custom_loading": True # Flag to use custom loading logic
|
||||
}
|
||||
}
|
||||
# --- End Dataset Specific Loading Configurations ---
|
||||
|
||||
def format_query(original_query: str) -> str:
|
||||
# """Formats the query string according to the NQ example."""
|
||||
# # Basic check to prevent double formatting if somehow the prefix is already there
|
||||
# if original_query.startswith("Answer these questions:"):
|
||||
# return original_query
|
||||
# return f"Answer these questions:\n\nQ: {original_query}?\nA:"
|
||||
return original_query
|
||||
|
||||
def load_retrievalqa():
|
||||
"""Custom function to load the RetrievalQA dataset with its complex structure.
|
||||
Downloads the JSONL file directly and processes it line by line to avoid schema issues.
|
||||
"""
|
||||
import requests
|
||||
import json
|
||||
import tempfile
|
||||
|
||||
try:
|
||||
print(f"{YELLOW}Attempting to directly download and parse RetrievalQA dataset...{RESET}")
|
||||
url = "https://huggingface.co/datasets/aialt/RetrievalQA/resolve/main/retrievalqa.jsonl"
|
||||
|
||||
# Create a temp file to store the downloaded data
|
||||
with tempfile.NamedTemporaryFile(mode='w+', delete=False) as tmp_file:
|
||||
# Download the file
|
||||
response = requests.get(url, stream=True)
|
||||
response.raise_for_status() # Ensure we got a successful response
|
||||
|
||||
# Process the dataset line by line to avoid schema issues
|
||||
data = []
|
||||
line_count = 0
|
||||
|
||||
for line in response.iter_lines(decode_unicode=True):
|
||||
if line: # Skip empty lines
|
||||
try:
|
||||
item = json.loads(line)
|
||||
# Extract just what we need - the question
|
||||
if "question" in item:
|
||||
data.append({"question": item["question"]})
|
||||
line_count += 1
|
||||
if line_count % 500 == 0:
|
||||
print(f"{YELLOW}Processed {line_count} lines...{RESET}")
|
||||
except json.JSONDecodeError as e:
|
||||
print(f"{RED}Error parsing JSON at line {line_count}: {e}{RESET}")
|
||||
continue
|
||||
|
||||
print(f"{GREEN}Successfully parsed {len(data)} questions from RetrievalQA dataset.{RESET}")
|
||||
return data
|
||||
except Exception as e:
|
||||
print(f"{RED}Custom loading for RetrievalQA failed: {e}{RESET}")
|
||||
raise
|
||||
|
||||
def prepare_file(task: str, force_overwrite: bool = False):
|
||||
"""Loads, formats, and saves the query file for a specific task."""
|
||||
print(f"--- Processing task: {task} ---")
|
||||
|
||||
if task not in TASK_CONFIGS:
|
||||
print(f"{RED}Error: Task '{task}' not found in TASK_CONFIGS in config.py.{RESET}")
|
||||
return False
|
||||
|
||||
if task not in DATASET_LOAD_INFO:
|
||||
print(f"{RED}Error: Loading configuration for task '{task}' not defined in DATASET_LOAD_INFO.{RESET}")
|
||||
return False
|
||||
|
||||
config = TASK_CONFIGS[task]
|
||||
load_info = DATASET_LOAD_INFO[task]
|
||||
target_path = Path(config.query_path) # Use the path from config.py
|
||||
|
||||
if target_path.exists() and not force_overwrite:
|
||||
print(f"{YELLOW}Target file already exists: {target_path}. Skipping.{RESET}")
|
||||
print(f"Use --force to overwrite.")
|
||||
return True
|
||||
|
||||
# Initialize query_key before the try block uses it in except
|
||||
query_key: Optional[str] = None
|
||||
try:
|
||||
# Use custom loading for retrievalqa
|
||||
if task == "retrievalqa" and load_info.get('custom_loading', False):
|
||||
raw_dataset = load_retrievalqa()
|
||||
|
||||
# Custom handling for retrievalqa data format (list of dicts)
|
||||
print(f"Formatting and saving queries to {target_path}...")
|
||||
target_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
count = 0
|
||||
with open(target_path, "w", encoding="utf-8") as f_out:
|
||||
for item in tqdm(raw_dataset, desc=f"Formatting {task}"):
|
||||
if "question" in item:
|
||||
formatted_query = format_query(item["question"])
|
||||
f_out.write(json.dumps({"query": formatted_query}) + "\n")
|
||||
count += 1
|
||||
|
||||
print(f"{GREEN}Successfully generated query file for {task} with {count} queries: {target_path}{RESET}")
|
||||
return True
|
||||
else:
|
||||
print(f"Loading raw dataset: {load_info['hf_name']} (subset: {load_info.get('subset')}, split: {load_info['split']}) ...")
|
||||
raw_dataset = load_dataset(
|
||||
load_info['hf_name'],
|
||||
name=load_info.get('subset'), # Pass the config name via 'name' (subset)
|
||||
split=load_info['split']
|
||||
)
|
||||
|
||||
query_key = load_info['query_key']
|
||||
print(f"Formatting and saving queries to {target_path}...")
|
||||
target_path.parent.mkdir(parents=True, exist_ok=True) # Ensure directory exists
|
||||
|
||||
count = 0
|
||||
with open(target_path, "w", encoding="utf-8") as f_out:
|
||||
for item in tqdm(raw_dataset, desc=f"Formatting {task}"):
|
||||
if not isinstance(item, dict) or query_key not in item:
|
||||
print(f"{YELLOW}Warning: Skipping item due to unexpected format or missing key '{query_key}'. Item: {item}{RESET}")
|
||||
continue
|
||||
original_query = item[query_key]
|
||||
formatted_query = format_query(original_query)
|
||||
f_out.write(json.dumps({"query": formatted_query}) + "\n")
|
||||
count += 1
|
||||
|
||||
print(f"{GREEN}Successfully generated query file for {task} with {count} queries: {target_path}{RESET}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"{RED}Error processing task '{task}': {e}{RESET}")
|
||||
key_info = f"query key ('{query_key}')" if query_key else "query key (not assigned)"
|
||||
print(f"Check dataset name ('{load_info['hf_name']}'), subset ('{load_info.get('subset')}'), split ('{load_info['split']}'), and {key_info}.")
|
||||
print(f"Target path was: {target_path}")
|
||||
# Attempt to clean up potentially incomplete file
|
||||
if target_path.exists():
|
||||
try:
|
||||
target_path.unlink()
|
||||
print(f"{YELLOW}Cleaned up potentially incomplete file: {target_path}{RESET}")
|
||||
except OSError as unlink_e:
|
||||
print(f"{RED}Error cleaning up file {target_path}: {unlink_e}{RESET}")
|
||||
return False # Indicate failure
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Prepare formatted query files for datasets.")
|
||||
parser.add_argument(
|
||||
"--tasks",
|
||||
nargs="+",
|
||||
default=["nq", "trivia", "hotpot", "gpqa", "retrievalqa"],
|
||||
choices=list(TASK_CONFIGS.keys()),
|
||||
help="Which tasks to prepare query files for."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--force",
|
||||
action="store_true",
|
||||
help="Overwrite existing query files."
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
print(f"Starting query file preparation for tasks: {', '.join(args.tasks)}")
|
||||
if args.force:
|
||||
print(f"{YELLOW}Force overwrite enabled.{RESET}")
|
||||
|
||||
success_count = 0
|
||||
fail_count = 0
|
||||
for task in args.tasks:
|
||||
if prepare_file(task, args.force):
|
||||
success_count += 1
|
||||
else:
|
||||
fail_count += 1
|
||||
|
||||
print("\n--- Preparation Summary ---")
|
||||
print(f"Tasks processed: {len(args.tasks)}")
|
||||
print(f"Successful: {success_count}")
|
||||
print(f"Failed/Skipped due to errors or manual setup needed: {fail_count}")
|
||||
if fail_count > 0:
|
||||
print(f"{RED}Some tasks requires manual intervention or failed. Please check the logs above.{RESET}")
|
||||
sys.exit(1)
|
||||
else:
|
||||
print(f"{GREEN}All specified tasks prepared successfully.{RESET}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
5
research/utils/run_bm25.fish
Normal file
5
research/utils/run_bm25.fish
Normal file
@@ -0,0 +1,5 @@
|
||||
python demo/main.py --task nq --search --use-original --load bm25 --skip-pa
|
||||
|
||||
python demo/main.py --task trivia --search --use-original --load bm25 --skip-pa
|
||||
python demo/main.py --task gpqa --search --use-original --load bm25 --skip-pa
|
||||
python demo/main.py --task hotpot --search --use-original --load bm25 --skip-pa
|
||||
313
research/utils/run_recall_experiment.py
Normal file
313
research/utils/run_recall_experiment.py
Normal file
@@ -0,0 +1,313 @@
|
||||
import subprocess
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
from datetime import datetime
|
||||
import matplotlib.pyplot as plt
|
||||
import pandas as pd
|
||||
import argparse
|
||||
import numpy as np # Added for sorting unique nprobes
|
||||
|
||||
# --- Configuration ---
|
||||
PYTHON_EXE = "python" # Or specify the full path if needed
|
||||
SCRIPT_TO_RUN = "run_server.py"
|
||||
LOG_DIR = "exp_logs" # Directory where run_server.py saves logs/summaries
|
||||
RESULTS_DIR = "experiment_results" # Directory to save plot and data
|
||||
os.makedirs(RESULTS_DIR, exist_ok=True)
|
||||
|
||||
# --- Argument Parsing for Flexibility ---
|
||||
parser = argparse.ArgumentParser(description="Run recall experiments or plot results from existing CSV.")
|
||||
parser.add_argument('--nprobes', type=str, default="2,4,8,16,32,64", help='Comma-separated list of nprobe values (used when running experiments)')
|
||||
parser.add_argument('--degrees', type=str, default="None,30,240", help='Comma-separated list of degree values (use None for default, used when running experiments)')
|
||||
parser.add_argument('--task', type=str, default="nq", help='Task argument for run_server.py (used when running experiments, or inferred from --input-csv)')
|
||||
parser.add_argument('--input-csv', type=str, default=None, help='Path to an existing CSV file to plot directly, skipping experiments.') # New argument
|
||||
args = parser.parse_args()
|
||||
|
||||
# --- Initialize Variables ---
|
||||
results_df = None
|
||||
task_name = args.task # Default task name
|
||||
NPROBE_VALUES = []
|
||||
DEGREE_VALUES = []
|
||||
|
||||
# --- Mode Selection: Run Experiments or Plot from CSV ---
|
||||
|
||||
if args.input_csv:
|
||||
# --- Plot from CSV Mode ---
|
||||
print(f"--- Plotting from existing CSV: {args.input_csv} ---")
|
||||
try:
|
||||
results_df = pd.read_csv(args.input_csv)
|
||||
print(f"Loaded data with {len(results_df)} rows.")
|
||||
|
||||
# ---- NEW: Replace 'default' string with 60 if present ----
|
||||
if 'degree' in results_df.columns and results_df['degree'].dtype == 'object': # Check if column exists and might contain strings
|
||||
results_df['degree'] = results_df['degree'].replace('default', 60)
|
||||
# Attempt to convert the column to numeric after replacement
|
||||
results_df['degree'] = pd.to_numeric(results_df['degree'], errors='coerce')
|
||||
print("Replaced 'default' degree values with 60 and converted column to numeric.")
|
||||
# ---- END NEW ----
|
||||
|
||||
# Infer task name from filename if possible
|
||||
match = re.search(r'results_([^_]+)_[\d_]+\.csv', os.path.basename(args.input_csv))
|
||||
if match:
|
||||
task_name = match.group(1)
|
||||
print(f"Inferred task name: {task_name}")
|
||||
else:
|
||||
print(f"Could not infer task name from filename, using default: {task_name}")
|
||||
|
||||
# Get NPROBE_VALUES from loaded data for plotting ticks
|
||||
if 'nprobe' in results_df.columns:
|
||||
NPROBE_VALUES = sorted(results_df['nprobe'].unique())
|
||||
print(f"Nprobe values from data: {NPROBE_VALUES}")
|
||||
else:
|
||||
print("Warning: 'nprobe' column not found in CSV. Plot ticks might be incorrect.")
|
||||
|
||||
|
||||
except FileNotFoundError:
|
||||
print(f"Error: Input CSV file not found: {args.input_csv}")
|
||||
exit(1)
|
||||
except Exception as e:
|
||||
print(f"Error reading CSV file {args.input_csv}: {e}")
|
||||
exit(1)
|
||||
|
||||
else:
|
||||
# --- Run Experiments Mode ---
|
||||
print("--- Running New Experiments ---")
|
||||
# Parse nprobe values
|
||||
try:
|
||||
NPROBE_VALUES = [int(p.strip()) for p in args.nprobes.split(',')]
|
||||
except ValueError:
|
||||
print("Error: Invalid nprobe values. Please provide comma-separated integers.")
|
||||
exit(1)
|
||||
|
||||
# Parse degree values
|
||||
DEGREE_VALUES = []
|
||||
for d_str in args.degrees.split(','):
|
||||
d_str = d_str.strip()
|
||||
if d_str.lower() == 'none':
|
||||
DEGREE_VALUES.append(None)
|
||||
else:
|
||||
try:
|
||||
DEGREE_VALUES.append(int(d_str))
|
||||
except ValueError:
|
||||
print(f"Error: Invalid degree value '{d_str}'. Use 'None' or integers.")
|
||||
exit(1)
|
||||
|
||||
print(f"Nprobe values to test: {NPROBE_VALUES}")
|
||||
print(f"Degree values to test: {DEGREE_VALUES}")
|
||||
print(f"Task: {task_name}") # Use task_name
|
||||
|
||||
|
||||
# --- Helper Functions (Only needed for experiment mode) ---
|
||||
def parse_recall_from_summary(summary_file_path):
|
||||
"""Parses the recall rate from the summary file."""
|
||||
try:
|
||||
with open(summary_file_path, 'r', encoding='utf-8') as f:
|
||||
content = f.read()
|
||||
# Regex to find the recall rate line
|
||||
match = re.search(r"Average Recall Rate:\s*([\d.]+)", content)
|
||||
if match:
|
||||
return float(match.group(1))
|
||||
else:
|
||||
print(f"Warning: Could not find recall rate in {summary_file_path}")
|
||||
return None
|
||||
except FileNotFoundError:
|
||||
print(f"Error: Summary file not found at {summary_file_path}")
|
||||
return None
|
||||
except Exception as e:
|
||||
print(f"Error reading or parsing summary file {summary_file_path}: {e}")
|
||||
return None
|
||||
|
||||
def find_summary_file(output_text):
|
||||
"""Finds the summary file path from the script's output."""
|
||||
# Regex to find the summary file path line
|
||||
match = re.search(r"Summary written to:\s*(.*\.txt)", output_text)
|
||||
if match:
|
||||
return match.group(1).strip()
|
||||
else:
|
||||
# Fallback: Search for any summary file pattern in the log directory if not found in stdout
|
||||
print("Warning: Could not find summary file path in script output. Searching log directory...")
|
||||
try:
|
||||
# Look for the most recent summary file
|
||||
log_files = [os.path.join(LOG_DIR, f) for f in os.listdir(LOG_DIR) if f.startswith("summary_") and f.endswith(".txt")]
|
||||
if log_files:
|
||||
latest_summary = max(log_files, key=os.path.getmtime)
|
||||
print(f"Found potential summary file by search: {latest_summary}")
|
||||
return latest_summary
|
||||
except FileNotFoundError:
|
||||
print(f"Warning: Log directory '{LOG_DIR}' not found during fallback search.")
|
||||
except Exception as e:
|
||||
print(f"Error during fallback summary file search: {e}")
|
||||
print("Fallback search failed.")
|
||||
return None
|
||||
|
||||
|
||||
# --- Main Experiment Loop ---
|
||||
results = []
|
||||
start_experiment_time = time.time()
|
||||
|
||||
for degree in DEGREE_VALUES:
|
||||
for nprobe in NPROBE_VALUES:
|
||||
run_start_time = time.time()
|
||||
degree_str = str(degree) if degree is not None else "Default"
|
||||
print(f"\n--- Running Experiment: degree={degree_str}, nprobe={nprobe}, task={task_name} ---") # Use task_name
|
||||
|
||||
# Base command
|
||||
cmd = [
|
||||
PYTHON_EXE, "-u", SCRIPT_TO_RUN,
|
||||
"--nprobe", str(nprobe),
|
||||
"--task", task_name # Use task_name
|
||||
]
|
||||
# Add degree if specified
|
||||
if degree is not None:
|
||||
cmd.extend(["--degree", str(degree)])
|
||||
|
||||
print(f"Executing command: {' '.join(cmd)}")
|
||||
|
||||
recall_rate = None
|
||||
summary_file = None
|
||||
process_returncode = -1 # Default to error
|
||||
|
||||
try:
|
||||
# Run the script and capture output
|
||||
process = subprocess.run(
|
||||
cmd,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
check=False, # Check returncode manually
|
||||
encoding='utf-8',
|
||||
errors='replace'
|
||||
)
|
||||
process_returncode = process.returncode
|
||||
|
||||
print(f"Command finished with return code: {process.returncode}")
|
||||
# Uncomment below to see full output for debugging
|
||||
# print("--- stdout ---")
|
||||
# print(process.stdout)
|
||||
# print("--- stderr ---")
|
||||
# print(process.stderr)
|
||||
# print("--------------")
|
||||
|
||||
# Attempt to find summary file path from stdout
|
||||
summary_file = find_summary_file(process.stdout)
|
||||
|
||||
if summary_file:
|
||||
# Give the filesystem a moment before reading
|
||||
time.sleep(1)
|
||||
recall_rate = parse_recall_from_summary(summary_file)
|
||||
else:
|
||||
print("ERROR: Could not locate summary file for this run.")
|
||||
|
||||
if process.returncode != 0:
|
||||
print(f"Warning: Script execution failed (return code {process.returncode}) for degree={degree_str}, nprobe={nprobe}.")
|
||||
# Recall might still be None or potentially parsed if summary existed
|
||||
|
||||
except FileNotFoundError:
|
||||
print(f"CRITICAL ERROR: Could not find script '{SCRIPT_TO_RUN}' or Python executable '{PYTHON_EXE}'.")
|
||||
# No result to append for this specific run in this case
|
||||
continue # Skip to next iteration
|
||||
except Exception as e:
|
||||
print(f"CRITICAL ERROR running experiment for degree={degree_str}, nprobe={nprobe}: {e}")
|
||||
# Append error result
|
||||
results.append({
|
||||
"degree": 60 if degree is None else degree, # <-- MODIFIED: Use 60 for None
|
||||
"nprobe": nprobe,
|
||||
"recall": None,
|
||||
"duration_s": time.time() - run_start_time,
|
||||
"return_code": process_returncode, # Use captured or default error code
|
||||
"summary_file": summary_file,
|
||||
"error": str(e)
|
||||
})
|
||||
continue # Skip to next iteration
|
||||
|
||||
run_duration = time.time() - run_start_time
|
||||
print(f"Result: degree={degree_str}, nprobe={nprobe}, recall={recall_rate}, duration={run_duration:.2f}s")
|
||||
results.append({
|
||||
"degree": 60 if degree is None else degree, # <-- MODIFIED: Use 60 for None
|
||||
"nprobe": nprobe,
|
||||
"recall": recall_rate,
|
||||
"duration_s": run_duration,
|
||||
"return_code": process_returncode,
|
||||
"summary_file": summary_file,
|
||||
"error": None if process_returncode == 0 and recall_rate is not None else "Run failed or recall not found"
|
||||
})
|
||||
|
||||
# Optional: add a small delay between runs if needed
|
||||
# time.sleep(5)
|
||||
|
||||
# --- Post-Experiment Processing ---
|
||||
print("\n--- Experiment Complete. Processing Results ---")
|
||||
total_duration = time.time() - start_experiment_time
|
||||
print(f"Total experiment duration: {total_duration:.2f}s")
|
||||
|
||||
if not results:
|
||||
print("No results collected. Exiting.")
|
||||
exit()
|
||||
|
||||
# Convert to DataFrame
|
||||
results_df = pd.DataFrame(results)
|
||||
|
||||
# Save results to CSV
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
results_csv_path = os.path.join(RESULTS_DIR, f"experiment_results_{task_name}_{timestamp}.csv") # Use task_name
|
||||
try:
|
||||
results_df.to_csv(results_csv_path, index=False)
|
||||
print(f"Results saved to: {results_csv_path}")
|
||||
except Exception as e:
|
||||
print(f"Error saving results to CSV: {e}")
|
||||
|
||||
|
||||
# --- Data Processing and Plotting (Common to both modes) ---
|
||||
|
||||
if results_df is None or results_df.empty:
|
||||
print("No data available to plot. Exiting.")
|
||||
exit()
|
||||
|
||||
print("\n--- Generating Plot ---")
|
||||
|
||||
# Filter out runs where recall could not be parsed or is missing
|
||||
plot_df = results_df.dropna(subset=['recall'])
|
||||
|
||||
# Filter out rows where degree is 'default' as we need numeric degree for calculation
|
||||
# Also ensure nprobe is numeric
|
||||
plot_df_numeric = plot_df[pd.to_numeric(plot_df['degree'], errors='coerce').notna()].copy()
|
||||
plot_df_numeric['degree'] = pd.to_numeric(plot_df_numeric['degree'])
|
||||
plot_df_numeric['nprobe'] = pd.to_numeric(plot_df_numeric['nprobe']) # Ensure nprobe is numeric
|
||||
|
||||
if plot_df_numeric.empty:
|
||||
print("No successful runs with numeric degree and recall values found. Cannot generate plot.")
|
||||
else:
|
||||
# Calculate the new x-axis value
|
||||
plot_df_numeric['degree_times_nprobe'] = plot_df_numeric['degree'] * plot_df_numeric['nprobe']
|
||||
|
||||
# Convert 'degree' column back to string for legend grouping
|
||||
plot_df_numeric['degree_label'] = plot_df_numeric['degree'].astype(int).astype(str)
|
||||
|
||||
# Plotting
|
||||
plt.figure(figsize=(12, 7))
|
||||
|
||||
# Group by degree and plot recall vs degree * nprobe
|
||||
# Sort group by the new x-axis value for correct line plotting
|
||||
for degree_label, group in plot_df_numeric.groupby('degree_label'):
|
||||
group = group.sort_values('degree_times_nprobe')
|
||||
plt.plot(group['degree_times_nprobe'], group['recall'], marker='o', linestyle='-', label=f'Degree={degree_label}')
|
||||
|
||||
plt.xlabel("Degree * Nprobe") # Updated X-axis label
|
||||
plt.ylabel("Average Recall Rate")
|
||||
plt.title(f"Recall Rate vs. Degree * Nprobe (Task: {task_name})") # Updated title
|
||||
plt.legend(title="Graph Degree")
|
||||
plt.grid(True, which="both", linestyle='--', linewidth=0.5)
|
||||
# plt.xscale('log', base=2) # Removed log scale, let matplotlib decide or adjust later if needed
|
||||
# plt.xticks(...) # Removed custom ticks, let matplotlib decide
|
||||
|
||||
# Save plot
|
||||
plot_timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") # Use a new timestamp for the plot
|
||||
# Updated filename
|
||||
plot_path = os.path.join(RESULTS_DIR, f"recall_vs_degree_nprobe_{task_name}_{plot_timestamp}.png")
|
||||
try:
|
||||
plt.savefig(plot_path)
|
||||
print(f"Plot saved to: {plot_path}")
|
||||
except Exception as e:
|
||||
print(f"Error saving plot: {e}")
|
||||
# plt.show() # Uncomment to display the plot interactively
|
||||
|
||||
print("\nDone.")
|
||||
89
research/utils/s3.md
Normal file
89
research/utils/s3.md
Normal file
@@ -0,0 +1,89 @@
|
||||
# How to Download Needed Data from S3
|
||||
|
||||
## Install AWS CLI v2
|
||||
|
||||
Install AWS CLI v2 by following the instructions at https://docs.aws.amazon.com/cli/latest/userguide/getting-started-install.html.
|
||||
|
||||
|
||||
curl "https://awscli.amazonaws.com/awscli-exe-linux-x86_64.zip" -o "awscliv2.zip"
|
||||
unzip awscliv2.zip
|
||||
sudo ./aws/install
|
||||
|
||||
## Configure SSO
|
||||
|
||||
Run the following command:
|
||||
|
||||
```console
|
||||
aws configure sso
|
||||
```
|
||||
|
||||
Example output:
|
||||
```
|
||||
(retrieval_scaling) (base) ➜ retrieval_scaling git:(main) ✗ aws configure sso
|
||||
SSO session name (Recommended): yichuan
|
||||
SSO start URL [None]: https://ucberkeley.awsapps.com/start#/
|
||||
SSO region [None]: us-west-2
|
||||
SSO registration scopes [sso:account:access]:
|
||||
Attempting to automatically open the SSO authorization page in your default browser.
|
||||
If the browser does not open or you wish to use a different device to authorize this request, open the following URL:
|
||||
|
||||
https://oidc.us-west-2.amazonaws.com/authorize?response_type=code&client_id=i3YtHZTRneXEIApSyvdgSHVzLXdlc3QtMg&redirect_uri=http%3A%2F%2F127.0.0.1%3A37899%2Foauth%2Fcallback&state=5f52320e-0929-4e44-83c7-f6bd9b492010&code_challenge_method=S256&scopes=sso%3Aaccount%3Aaccess&code_challenge=HYnZ4Pc-tqI8CdJb6qEAR0LjI1_UjN-zln26lqJKeL8
|
||||
The only AWS account available to you is: 976193267581
|
||||
Using the account ID 976193267581
|
||||
There are 2 roles available to you.
|
||||
Using the role name "UCB-FederatedAdmins"
|
||||
Default client Region [None]:
|
||||
CLI default output format (json if not specified) [None]:
|
||||
Profile name [UCB-FederatedAdmins-976193267581]:
|
||||
To use this profile, specify the profile name using --profile, as shown:
|
||||
|
||||
aws sts get-caller-identity --profile UCB-FederatedAdmins-976193267581
|
||||
```
|
||||
|
||||
After configuration, you must include `--profile UCB-FederatedAdmins-976193267581` with each AWS operation to use the SSO credentials.
|
||||
|
||||
## Refresh the SSO
|
||||
|
||||
If you encounter the error `Error when retrieving token from sso: Token has expired and refresh failed`, simply run the SSO configuration command again.
|
||||
|
||||
## Download S3 Data
|
||||
|
||||
All data is stored in `s3://retrieval-scaling-out`, which includes 4 directories:
|
||||
- embeddings/
|
||||
- examples/
|
||||
- indices/
|
||||
- passages/
|
||||
|
||||
Download the data using AWS CLI:
|
||||
|
||||
```console
|
||||
aws s3 cp s3://retrieval-scaling-out ~/scaling_out --profile UCB-FederatedAdmins-976193267581
|
||||
|
||||
aws s3 cp s3://retrieval-scaling-out/examples/test_c4.jsonl ~/examples/scaling_out --profile UCB-FederatedAdmins-976193267581
|
||||
```
|
||||
|
||||
### Faster Download Options
|
||||
|
||||
To accelerate downloads, you can try the following methods:
|
||||
|
||||
Use multipart downloads:
|
||||
```console
|
||||
aws s3 cp s3://retrieval-scaling-out ~/scaling_out --profile UCB-FederatedAdmins-976193267581 --recursive --multipart-threshold 128MB --multipart-chunksize 512MB
|
||||
```
|
||||
|
||||
Configure higher concurrency:
|
||||
```console
|
||||
aws configure set default.s3.max_concurrent_requests 50
|
||||
aws configure set default.s3.max_queue_size 10000
|
||||
```
|
||||
|
||||
Utilize S3 Transfer Acceleration:
|
||||
```console
|
||||
aws s3 cp s3://retrieval-scaling-out ~/scaling_out --profile UCB-FederatedAdmins-976193267581 --recursive --endpoint-url https://s3-accelerate.amazonaws.com
|
||||
```
|
||||
|
||||
Or use alternative tools like `s5cmd`:
|
||||
```console
|
||||
pip install s5cmd
|
||||
s5cmd --profile UCB-FederatedAdmins-976193267581 cp s3://retrieval-scaling-out/* ~/scaling_out/
|
||||
```
|
||||
212
research/utils/seperate_hnsw_flat.py
Normal file
212
research/utils/seperate_hnsw_flat.py
Normal file
@@ -0,0 +1,212 @@
|
||||
import struct
|
||||
import sys
|
||||
import numpy as np
|
||||
import os
|
||||
|
||||
INDEX_FLAT_L2_FOURCC = int.from_bytes(b'IxF2', 'little')
|
||||
INDEX_FLAT_IP_FOURCC = int.from_bytes(b'IxFI', 'little')
|
||||
INDEX_HNSW_FLAT_FOURCC = int.from_bytes(b'IHNf', 'little')
|
||||
INDEX_HNSW_PQ_FOURCC = int.from_bytes(b'IHNp', 'little')
|
||||
INDEX_HNSW_SQ_FOURCC = int.from_bytes(b'IHNs', 'little')
|
||||
INDEX_HNSW_2L_FOURCC = int.from_bytes(b'IHN2', 'little')
|
||||
INDEX_HNSW_CAGRA_FOURCC = int.from_bytes(b'IHNc', 'little')
|
||||
NULL_INDEX_FOURCC = int.from_bytes(b'null', 'little')
|
||||
|
||||
HNSW_FOURCCS = {
|
||||
INDEX_HNSW_FLAT_FOURCC,
|
||||
INDEX_HNSW_PQ_FOURCC,
|
||||
INDEX_HNSW_SQ_FOURCC,
|
||||
INDEX_HNSW_2L_FOURCC,
|
||||
INDEX_HNSW_CAGRA_FOURCC,
|
||||
}
|
||||
FLAT_FOURCCS = {INDEX_FLAT_L2_FOURCC, INDEX_FLAT_IP_FOURCC}
|
||||
|
||||
|
||||
# --- Helper functions for reading binary data ---
|
||||
|
||||
def read_struct(f, fmt):
|
||||
"""Reads data according to the struct format."""
|
||||
size = struct.calcsize(fmt)
|
||||
data = f.read(size)
|
||||
if len(data) != size:
|
||||
raise EOFError("File ended unexpectedly.")
|
||||
return struct.unpack(fmt, data)[0]
|
||||
|
||||
def read_vector(f, element_fmt):
|
||||
"""Reads a vector (size followed by data)."""
|
||||
count = read_struct(f, '<Q') # size_t usually 64-bit unsigned
|
||||
element_size = struct.calcsize(element_fmt)
|
||||
data_bytes = f.read(count * element_size)
|
||||
if len(data_bytes) != count * element_size:
|
||||
raise EOFError("File ended unexpectedly when reading vector data.")
|
||||
# Unpack the elements individually if needed, or return raw bytes
|
||||
# For simplicity here, we'll return the raw bytes and size
|
||||
return count, data_bytes
|
||||
|
||||
def read_vector_data(f, element_fmt, count):
|
||||
"""Reads vector data when the count is known."""
|
||||
element_size = struct.calcsize(element_fmt)
|
||||
data_bytes = f.read(count * element_size)
|
||||
if len(data_bytes) != count * element_size:
|
||||
raise EOFError("File ended unexpectedly when reading vector data.")
|
||||
return data_bytes
|
||||
|
||||
# --- Main Separation Logic ---
|
||||
|
||||
def separate_hnsw_flat(input_filename, graph_output_filename, storage_output_filename):
|
||||
"""
|
||||
Separates an IndexHNSWFlat file into graph and storage components.
|
||||
"""
|
||||
print(f"Processing: {input_filename}")
|
||||
try:
|
||||
with open(input_filename, 'rb') as f_in, \
|
||||
open(graph_output_filename, 'wb') as f_graph_out, \
|
||||
open(storage_output_filename, 'wb') as f_storage_out:
|
||||
|
||||
# 1. Read and write HNSW FourCC
|
||||
hnsw_fourcc = read_struct(f_in, '<I')
|
||||
if hnsw_fourcc != INDEX_HNSW_FLAT_FOURCC:
|
||||
print(f"Error: Expected IndexHNSWFlat FourCC ({INDEX_HNSW_FLAT_FOURCC:08x}), "
|
||||
f"but got {hnsw_fourcc:08x}. Is this an IndexHNSWFlat file?", file=sys.stderr)
|
||||
return False
|
||||
f_graph_out.write(struct.pack('<I', hnsw_fourcc))
|
||||
print(f" Index type: HNSWFlat ({hnsw_fourcc:08x})")
|
||||
|
||||
# 2. Read and write Index Header
|
||||
# d, ntotal, dummy1, dummy2, is_trained, metric_type, [metric_arg]
|
||||
d = read_struct(f_in, '<i')
|
||||
ntotal = read_struct(f_in, '<q') # idx_t is int64 in Faiss default
|
||||
dummy1 = read_struct(f_in, '<q')
|
||||
dummy2 = read_struct(f_in, '<q')
|
||||
is_trained = read_struct(f_in, '?') # bool -> 1 byte
|
||||
metric_type = read_struct(f_in, '<i')
|
||||
metric_arg = 0.0
|
||||
header_data = struct.pack('<iq?i', d, ntotal, # omit dummies here
|
||||
is_trained, metric_type)
|
||||
if metric_type > 1:
|
||||
metric_arg = read_struct(f_in, '<f')
|
||||
header_data += struct.pack('<f', metric_arg)
|
||||
|
||||
# Write header *without* dummies to graph file
|
||||
# We'll reconstruct the full header later if needed, but for now
|
||||
# just keep the essential parts. Alternatively, write the exact bytes read.
|
||||
# Let's write exact bytes for simplicity of reassembly
|
||||
f_graph_out.write(struct.pack('<i', d))
|
||||
f_graph_out.write(struct.pack('<q', ntotal))
|
||||
f_graph_out.write(struct.pack('<q', dummy1))
|
||||
f_graph_out.write(struct.pack('<q', dummy2))
|
||||
f_graph_out.write(struct.pack('?', is_trained))
|
||||
f_graph_out.write(struct.pack('<i', metric_type))
|
||||
if metric_type > 1:
|
||||
f_graph_out.write(struct.pack('<f', metric_arg))
|
||||
|
||||
print(f" Dimensions (d): {d}")
|
||||
print(f" Num vectors (ntotal): {ntotal}")
|
||||
print(f" Is trained: {is_trained}")
|
||||
print(f" Metric type: {metric_type}")
|
||||
if metric_type > 1:
|
||||
print(f" Metric arg: {metric_arg}")
|
||||
|
||||
# 3. Read and write HNSW struct data
|
||||
print(" Reading HNSW graph data...")
|
||||
# assign_probas (vector<double>)
|
||||
count, data = read_vector(f_in, '<d')
|
||||
f_graph_out.write(struct.pack('<Q', count))
|
||||
f_graph_out.write(data)
|
||||
print(f" assign_probas size: {count}")
|
||||
|
||||
# cum_nneighbor_per_level (vector<int>)
|
||||
count, data = read_vector(f_in, '<i')
|
||||
f_graph_out.write(struct.pack('<Q', count))
|
||||
f_graph_out.write(data)
|
||||
print(f" cum_nneighbor_per_level size: {count}")
|
||||
|
||||
# levels (vector<int>) - Store node levels
|
||||
count, data = read_vector(f_in, '<i')
|
||||
f_graph_out.write(struct.pack('<Q', count))
|
||||
f_graph_out.write(data)
|
||||
print(f" levels size: {count}")
|
||||
|
||||
# offsets (vector<size_t>) - Store offsets for neighbors
|
||||
count, data = read_vector(f_in, '<Q')
|
||||
f_graph_out.write(struct.pack('<Q', count))
|
||||
f_graph_out.write(data)
|
||||
print(f" offsets size: {count}")
|
||||
|
||||
# neighbors (vector<storage_idx_t> -> int32_t typically)
|
||||
count, data = read_vector(f_in, '<i') # Assuming storage_idx_t is int32
|
||||
f_graph_out.write(struct.pack('<Q', count))
|
||||
f_graph_out.write(data)
|
||||
print(f" neighbors size: {count}")
|
||||
|
||||
# entry_point, max_level, efConstruction, efSearch
|
||||
entry_point = read_struct(f_in, '<i')
|
||||
max_level = read_struct(f_in, '<i')
|
||||
efConstruction = read_struct(f_in, '<i')
|
||||
efSearch = read_struct(f_in, '<i')
|
||||
# Read and discard the dummy upper_beam
|
||||
_ = read_struct(f_in, '<i')
|
||||
|
||||
f_graph_out.write(struct.pack('<i', entry_point))
|
||||
f_graph_out.write(struct.pack('<i', max_level))
|
||||
f_graph_out.write(struct.pack('<i', efConstruction))
|
||||
f_graph_out.write(struct.pack('<i', efSearch))
|
||||
f_graph_out.write(struct.pack('<i', 1)) # Write dummy upper_beam back
|
||||
print(f" entry_point: {entry_point}")
|
||||
print(f" max_level: {max_level}")
|
||||
print(f" efConstruction: {efConstruction}")
|
||||
print(f" efSearch: {efSearch}")
|
||||
|
||||
|
||||
# --- Storage Part ---
|
||||
print(" Reading storage (IndexFlat) data...")
|
||||
storage_start_pos = f_in.tell()
|
||||
|
||||
# 4. Check: Read the storage FourCC (should be IndexFlat)
|
||||
storage_fourcc = read_struct(f_in, '<I')
|
||||
if storage_fourcc not in FLAT_FOURCCS:
|
||||
print(f"Error: Expected IndexFlat FourCC ({list(FLAT_FOURCCS)}), "
|
||||
f"but got {storage_fourcc:08x} after HNSW data.", file=sys.stderr)
|
||||
return False
|
||||
print(f" Storage type: IndexFlat ({storage_fourcc:08x})")
|
||||
|
||||
# 5. Read the rest of the file as storage data
|
||||
f_in.seek(storage_start_pos) # Go back to start of storage
|
||||
storage_data = f_in.read() # Read everything remaining
|
||||
f_storage_out.write(storage_data)
|
||||
print(f" Wrote {len(storage_data)} bytes to storage file.")
|
||||
|
||||
# 6. Final Check: Did we reach the end of the input file?
|
||||
if f_in.read(1):
|
||||
print("Warning: Unexpected data found after storage part in input file.", file=sys.stderr)
|
||||
|
||||
print(f"Separation complete:")
|
||||
print(f" Graph structure: {graph_output_filename}")
|
||||
print(f" Vector storage: {storage_output_filename}")
|
||||
return True
|
||||
|
||||
except EOFError as e:
|
||||
print(f"Error: Reached end of file unexpectedly. The input file might be incomplete or corrupted. {e}", file=sys.stderr)
|
||||
return False
|
||||
except Exception as e:
|
||||
print(f"An error occurred: {e}", file=sys.stderr)
|
||||
return False
|
||||
|
||||
# --- Example Usage ---
|
||||
if __name__ == "__main__":
|
||||
if len(sys.argv) != 2:
|
||||
print("Usage: python separate_hnsw_flat.py <input_index_file>")
|
||||
sys.exit(1)
|
||||
|
||||
input_file = sys.argv[1]
|
||||
base_name = os.path.splitext(input_file)[0]
|
||||
graph_file = base_name + ".hnsw_graph"
|
||||
storage_file = base_name + ".flat_storage"
|
||||
|
||||
if not os.path.exists(input_file):
|
||||
print(f"Error: Input file not found: {input_file}", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
success = separate_hnsw_flat(input_file, graph_file, storage_file)
|
||||
if not success:
|
||||
sys.exit(1)
|
||||
18
research/utils/simple_search.fish
Normal file
18
research/utils/simple_search.fish
Normal file
@@ -0,0 +1,18 @@
|
||||
|
||||
|
||||
set -l BIG_GRAPH /opt/dlami/nvme/scaling_out/indices/rpj_wiki/facebook/contriever-msmarco/hnsw/hnsw_IP_M30_efC128.index \
|
||||
/opt/dlami/nvme/scaling_out/embeddings/facebook/contriever-msmarco/rpj_wiki/1-shards/indices/99_4_degree_based_hnsw_IP_M32_efC256/index.faiss \
|
||||
/opt/dlami/nvme/scaling_out/embeddings/facebook/contriever-msmarco/rpj_wiki/1-shards/indices/d9_hnsw_IP_M8_efC128/index.faiss \
|
||||
/opt/dlami/nvme/scaling_out/embeddings/facebook/contriever-msmarco/rpj_wiki/1-shards/indices/half_edges_IP_M32_efC128/index.faiss
|
||||
|
||||
set -l LABELS hnsw_IP_M30_efC128 \
|
||||
99_4_degree_based_hnsw_IP_M32_efC256 \
|
||||
d9_hnsw_IP_M8_efC128 \
|
||||
half_edges_IP_M32_efC128
|
||||
|
||||
for i in (seq 1 4)
|
||||
set graph (string join / $BIG_GRAPH[$i])
|
||||
set label (string join / $LABELS[$i])
|
||||
echo "Building HNSW index with $label..."
|
||||
python ./faiss/demo/large_graph_simple_build.py --index-file $graph
|
||||
end
|
||||
59
research/utils/subsample_data_new.py
Executable file
59
research/utils/subsample_data_new.py
Executable file
@@ -0,0 +1,59 @@
|
||||
import os
|
||||
import time
|
||||
import json
|
||||
import numpy as np
|
||||
import multiprocessing
|
||||
|
||||
|
||||
def subsample_jsonl_random(input_file_path, output_file_path, ratio=0.1, seed=42):
|
||||
"""
|
||||
Subsamples 10% of the data from a JSONL file efficiently.
|
||||
|
||||
Args:
|
||||
input_file_path (str): Path to the input JSONL file.
|
||||
output_file_path (str): Path to the output JSONL file where the subsample will be saved.
|
||||
seed (int): Seed for the random number generator to ensure reproducibility.
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
# First pass: count the number of lines in the file
|
||||
line_count = 0
|
||||
with open(input_file_path, "r", encoding="utf-8") as file:
|
||||
for _ in file:
|
||||
line_count += 1
|
||||
print(f"Total lines: {line_count}")
|
||||
|
||||
# Calculate indices for 10% sample
|
||||
np.random.seed(seed)
|
||||
sample_size = int(line_count * ratio)
|
||||
selected_indices = set(np.random.choice(line_count, sample_size, replace=False))
|
||||
|
||||
# Second pass: write the selected lines to the output file
|
||||
print(f"Subsampling {sample_size} lines")
|
||||
current_index = 0
|
||||
with (
|
||||
open(input_file_path, "r", encoding="utf-8") as input_file,
|
||||
open(output_file_path, "w", encoding="utf-8") as output_file,
|
||||
):
|
||||
for line in input_file:
|
||||
if current_index in selected_indices:
|
||||
output_file.write(line)
|
||||
current_index += 1
|
||||
|
||||
end_time = time.time()
|
||||
print(
|
||||
f"Time: {(end_time - start_time) / 60:.2f} minutes\tRaw Size: {line_count}\t Sampled Size: {sample_size}"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# input_dir = '/mnt/md-256k/redpajama_v1/common_crawl_2023_06'
|
||||
# output_dir = '/mnt/md-256k/massive_ds_data/subsampled_0.1/rpj_common_crawl_2023_06'
|
||||
input_dir = "/mnt/md-256k/massive_ds_data/full/dpr_wiki"
|
||||
output_dir = "/mnt/md-256k/massive_ds_data/subsampled_0.1/dpr_wiki"
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
for filename in os.listdir(input_dir):
|
||||
input_path = os.path.join(input_dir, filename)
|
||||
output_path = os.path.join(output_dir, filename)
|
||||
subsample_jsonl_random(input_path, output_path)
|
||||
52
research/utils/timing.py
Executable file
52
research/utils/timing.py
Executable file
@@ -0,0 +1,52 @@
|
||||
import os
|
||||
import time
|
||||
import functools
|
||||
|
||||
|
||||
class time_exec:
|
||||
def __init__(self, func):
|
||||
functools.update_wrapper(self, func)
|
||||
self.func = func
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
start_time = time.time()
|
||||
result = self.func(*args, **kwargs)
|
||||
end_time = time.time()
|
||||
execution_time = end_time - start_time
|
||||
timing_log = (
|
||||
f"Function '{self.func.__name__}' executed in {execution_time:.2f} seconds"
|
||||
)
|
||||
print(timing_log)
|
||||
return result, execution_time
|
||||
|
||||
|
||||
class Logger:
|
||||
def __init__(self, args):
|
||||
self.log_file = args.log_file
|
||||
self.ds_domain = args.domain
|
||||
self.seed = args.seed
|
||||
self.datastore_size = args.sample_size
|
||||
self.stride = args.stride
|
||||
self.max_seq_length = args.max_seq_length
|
||||
self.merge = args.merge
|
||||
self.prefix = f"{self.ds_domain}\t{self.seed}\t{self.datastore_size}\t{self.stride}\t{self.max_seq_length}\t{self.merge}"
|
||||
|
||||
def log_results(
|
||||
self,
|
||||
time_sample=None,
|
||||
time_chunk=None,
|
||||
time_index=None,
|
||||
time_eval=None,
|
||||
num_eval=None,
|
||||
perplexity=None,
|
||||
):
|
||||
# Create the log entry
|
||||
log_entry = f"{self.prefix}\t{time_sample}\t{time_chunk}\t{time_index}\t{time_eval}\t{num_eval}\t{perplexity}\n"
|
||||
|
||||
# Open the file in append mode. Creates the file if it doesn't exist.
|
||||
with open(self.log_file, "a") as file:
|
||||
file.write(log_entry)
|
||||
|
||||
def log_string(self, log_string):
|
||||
with open(self.log_file, "a") as file:
|
||||
file.write(log_string)
|
||||
Reference in New Issue
Block a user