# Files
# yichuan520030910320 46f6cc100b Initial commit
# 2025-06-30 09:05:05 +00:00
#
# 422 lines
# 18 KiB
# Python
# Raw Permalink Blame History
#
# This file contains ambiguous Unicode characters
# This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import torch
import numpy as np
from transformers import AutoTokenizer, AutoModel
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
from scipy.stats import kendalltau, spearmanr
# Pick the compute device: CUDA when present, otherwise Apple MPS, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"使用设备: {device}")
# Custom comparison function (inner-product based).
def compare(a, b):
    """Return the negated inner product of two vectors as a distance.

    Smaller return values mean the vectors are more similar; the negation
    keeps the sign convention of the code this script was derived from.
    """
    return -np.dot(a, b)
# Batched similarity computation.
def compute_similarities(queries, corpus):
    """Compute the query-by-document distance matrix in one matrix product.

    Entry ``[i, j]`` is the negated inner product of ``queries[i]`` and
    ``corpus[j]`` (the value ``compare`` yields for that pair), so smaller
    entries mean more similar pairs.

    Args:
        queries: Sequence of equal-length 1-D vectors, or a 2-D array with
            one query vector per row.
        corpus: Sequence of equal-length 1-D vectors, or a 2-D array with
            one document vector per row.

    Returns:
        A float64 array of shape ``(len(queries), len(corpus))``.
    """
    n_queries, n_docs = len(queries), len(corpus)
    # An empty side cannot be turned into a 2-D matrix; keep the documented
    # result shape explicitly instead of failing inside the product.
    if n_queries == 0 or n_docs == 0:
        return np.zeros((n_queries, n_docs))
    query_matrix = np.asarray(queries)
    corpus_matrix = np.asarray(corpus)
    # One matrix product replaces the Python-level loop over every pair.
    # The cast keeps the result float64 regardless of the embedding dtype.
    return (-(query_matrix @ corpus_matrix.T)).astype(np.float64)
# The two model variants under comparison.
# NOTE: compare_embeddings() always loads weights from model_names[0]; the
# second entry only serves as the label of the int4-quantised run.
model_names = [
    "facebook/contriever-msmarco", # Contriever model
    "facebook/contriever-msmarco-int4" # Contriever model (int4)
]
# Extended sample corpus, split into topic groups of ten sentences each.
# The index ranges noted per group are relied on by the topic-distribution
# plot in analyze_results().
texts = [
    # Group 1: foxes and animals (indices 0-9)
    "The quick brown fox jumps over the lazy dog.",
    "A rapid auburn fox leaps above the inactive canine.",
    "The sly fox outsmarts the hunting hounds in the forest.",
    "Foxes are known for their cunning behavior and bushy tails.",
    "The red fox is the largest of the true foxes and the most common fox species.",
    "Dogs have been companions to humans for thousands of years.",
    "The lazy dog slept through the commotion of the playful fox.",
    "Wolves and foxes belong to the same family, Canidae.",
    "The arctic fox changes its coat color with the seasons.",
    "Domestic dogs come in hundreds of breeds of various sizes and appearances.",
    # Group 2: artificial intelligence and machine learning (indices 10-19)
    "Machine learning is a branch of artificial intelligence.",
    "Deep learning is a subset of machine learning.",
    "Neural networks are computing systems inspired by biological neural networks.",
    "AI systems can now beat human champions at complex games like chess and Go.",
    "Natural language processing allows computers to understand human language.",
    "Reinforcement learning involves training agents to make sequences of decisions.",
    "Computer vision enables machines to derive information from images and videos.",
    "The Turing test measures a machine's ability to exhibit intelligent behavior.",
    "Supervised learning uses labeled training data to learn the mapping function.",
    "Unsupervised learning finds patterns in data without pre-existing labels.",
    # Group 3: Paris and French landmarks (indices 20-29)
    "The Eiffel Tower is located in Paris, France.",
    "The Louvre Museum is in the city of Paris.",
    "Notre-Dame Cathedral is a medieval Catholic cathedral on the Île de la Cité in Paris.",
    "The Arc de Triomphe stands at the center of the Place Charles de Gaulle in Paris.",
    "The Seine River flows through the heart of Paris.",
    "Montmartre is a large hill in Paris's 18th arrondissement known for its artistic history.",
    "The Palace of Versailles is located in the Île-de-France region of France.",
    "The Champs-Élysées is an avenue in Paris famous for its theatres, cafés, and luxury shops.",
    "The Sacré-Cœur Basilica offers one of the most beautiful panoramic views of Paris.",
    "The Musée d'Orsay houses the largest collection of impressionist masterpieces in the world.",
    # Group 4: renewable energy (indices 30-39)
    "Solar panels convert sunlight into electricity.",
    "Wind turbines generate power from moving air.",
    "Hydroelectric power is generated from flowing water.",
    "Geothermal energy harnesses heat from within the Earth.",
    "Biomass energy comes from organic materials like plants and waste.",
    "Tidal energy uses the natural rise and fall of coastal tidal waters.",
    "Renewable energy sources can help reduce greenhouse gas emissions.",
    "Solar farms can span hundreds of acres with thousands of panels.",
    "Offshore wind farms are built in bodies of water to harvest wind energy.",
    "Energy storage systems are crucial for balancing renewable energy supply and demand.",
    # Group 5: programming languages (indices 40-49)
    "Python is a popular programming language for data science.",
    "JavaScript is commonly used for web development.",
    "Java is known for its 'write once, run anywhere' capability.",
    "C++ provides high-performance and close hardware control.",
    "Ruby is praised for its simplicity and productivity.",
    "PHP is a server-side scripting language designed for web development.",
    "Swift is used to develop applications for Apple platforms.",
    "Rust offers memory safety without using garbage collection.",
    "Go was designed at Google to improve programming productivity.",
    "Kotlin is fully interoperable with Java and provides more concise syntax.",
]
# Extended query set: three queries per topic, in the same topic order as
# the corpus groups in `texts`.
query_texts = [
    # Animal-related queries
    "A fox jumped over a dog.",
    "Wild animals and their behaviors in forests.",
    "Different species of foxes around the world.",
    # AI-related queries
    "Artificial intelligence and neural networks.",
    "Machine learning algorithms and applications.",
    "The future of deep learning technology.",
    # Paris-related queries
    "Famous landmarks in Paris, France.",
    "Tourist attractions along the Seine River.",
    "Historical buildings and museums in Paris.",
    # Energy-related queries
    "Renewable energy sources and sustainability.",
    "Solar and wind power generation technologies.",
    "Alternative clean energy solutions for the future.",
    # Programming-related queries
    "Computer programming languages comparison.",
    "Best languages for web development.",
    "Programming tools for data science applications."
]
# Embedding helper for BGE-style models.
def get_bge_embeddings(model, tokenizer, texts, device):
    """Encode ``texts`` into L2-normalised first-token ([CLS]) embeddings.

    The texts are tokenised and passed through ``model`` in batches of 16.
    Each text is represented by the hidden state of its first token, scaled
    to unit L2 norm.

    Args:
        model: Callable taking the tokenised batch as keyword arguments and
            returning an object with a ``last_hidden_state`` tensor.
        tokenizer: Callable that tokenises a list of strings into a batch
            object supporting ``.to(device)``.
        texts: Sequence of strings to encode.
        device: Device the tokenised batch is moved to.

    Returns:
        A NumPy array with one embedding row per input text.
    """
    chunk_size = 16
    collected = []
    for start in range(0, len(texts), chunk_size):
        chunk = texts[start:start + chunk_size]
        inputs = tokenizer(
            chunk,
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors='pt',
        ).to(device)
        with torch.no_grad():
            outputs = model(**inputs)
            # BGE pools by taking the hidden state of the first token.
            cls_vectors = outputs.last_hidden_state[:, 0]
            # Scale every row to unit L2 norm.
            unit_vectors = torch.nn.functional.normalize(cls_vectors, p=2, dim=1)
        collected.append(unit_vectors.cpu().numpy())
    return np.vstack(collected)
# Embedding helper for Contriever models.
def get_contriever_embeddings(model, tokenizer, texts, device, use_int4=False):
    """Encode ``texts`` into L2-normalised, mean-pooled embeddings.

    The texts are tokenised and passed through ``model`` in batches of 16.
    Each text is represented by the attention-mask-weighted mean of its
    token hidden states, scaled to unit L2 norm.

    Pooling and normalisation always run in float32. A model that emits
    half-precision hidden states (as may happen for the quantised variant,
    which is loaded with a float16 compute dtype) would otherwise be pooled
    and later compared in float16, adding rounding error that has nothing
    to do with the model weights being compared.

    Args:
        model: Callable taking the tokenised batch as keyword arguments and
            returning an object with a ``last_hidden_state`` tensor.
        tokenizer: Callable that tokenises a list of strings into a batch
            object supporting ``.to(device)`` and ``['attention_mask']``.
        texts: Sequence of strings to encode.
        device: Device the tokenised batch is moved to.
        use_int4: Not used by the computation; kept so existing callers
            that pass it continue to work.

    Returns:
        A float32 NumPy array with one embedding row per input text.

    Raises:
        ValueError: If ``texts`` is empty (``np.vstack`` has nothing to stack).
    """
    batch_size = 16
    all_embeddings = []
    for start in range(0, len(texts), batch_size):
        batch_texts = texts[start:start + batch_size]
        encoded_input = tokenizer(batch_texts, padding=True, truncation=True,
                                  max_length=512, return_tensors='pt').to(device)
        with torch.no_grad():
            model_output = model(**encoded_input)
            # Upcast before any arithmetic so sums over tokens are exact
            # enough even when the model runs in half precision.
            hidden = model_output.last_hidden_state.float()
            mask = encoded_input['attention_mask'].unsqueeze(-1).to(
                device=hidden.device, dtype=hidden.dtype)
            # Contriever uses mean pooling over the unmasked tokens. The
            # clamp avoids 0/0 (NaN) for a row whose mask is entirely zero.
            token_counts = mask.sum(1).clamp(min=1.0)
            embeddings = (hidden * mask).sum(1) / token_counts
            # Scale every row to unit L2 norm.
            normalized_embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
        all_embeddings.append(normalized_embeddings.cpu().numpy())
    return np.vstack(all_embeddings)
# Main comparison routine and its model-loading helper.
def _load_contriever_model(use_int4):
    """Load the base model named by ``model_names[0]`` in eval mode.

    Args:
        use_int4: When true, try to load the weights 4-bit quantised via
            ``BitsAndBytesConfig``. If that fails for any reason, the error
            is printed and the standard model is loaded instead.

    Returns:
        The loaded model, switched to evaluation mode.
    """
    if use_int4:
        print("应用int4量化...")
        try:
            from transformers import BitsAndBytesConfig
            quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4"
            )
            model = AutoModel.from_pretrained(
                model_names[0],  # same base weights as the standard run
                quantization_config=quantization_config,
                device_map="auto"
            )
            print("成功加载int4模型")
        except Exception as e:
            # Deliberate best-effort: quantisation needs optional packages
            # and suitable hardware, so any failure falls back to the
            # standard model rather than aborting the whole comparison.
            print(f"int4加载失败: {e}")
            print("回退到标准模型...")
            model = AutoModel.from_pretrained(model_names[0]).to(device)
    else:
        model = AutoModel.from_pretrained(model_names[0]).to(device)
    model.eval()
    return model


def compare_embeddings():
    """Embed the corpus and queries with each model variant and rank documents.

    For every entry of ``model_names`` the base model is loaded (the second
    entry as an int4-quantised variant), ``texts`` and ``query_texts`` are
    embedded with ``get_contriever_embeddings``, the distance matrix is built
    with ``compute_similarities`` and every query's documents are ranked from
    most to least similar. The top three hits of the first query are printed
    per model as a quick sanity check.

    Returns:
        A dict keyed by model display name. Each value is a dict with
        ``'corpus_embeddings'``, ``'query_embeddings'``,
        ``'similarity_scores'`` (queries x documents, smaller is more
        similar) and ``'ranked_indices'`` (``"query_<n>"`` -> document
        indices sorted by ascending score).
    """
    results = {}
    tokenizer = None
    for i, model_name in enumerate(model_names):
        use_int4 = i == 1
        # Give the second model a distinct display name so the two result
        # entries can always be told apart.
        model_display_name = "facebook/contriever-msmarco-int4" if use_int4 else model_name
        print(f"\n======= 加载模型 {i+1}: {model_display_name} =======")
        # Both variants share one tokenizer, so it is loaded only once.
        if tokenizer is None:
            tokenizer = AutoTokenizer.from_pretrained(model_names[0])
        model = _load_contriever_model(use_int4)
        print(f"使用 {model_display_name} 生成嵌入...")
        corpus_embeddings = get_contriever_embeddings(model, tokenizer, texts, device, use_int4)
        query_embeddings = get_contriever_embeddings(model, tokenizer, query_texts, device, use_int4)
        # Drop the reference so this model can be freed before the next
        # variant is loaded, keeping peak memory down.
        del model
        print(f"语料库嵌入形状: {corpus_embeddings.shape}")
        print(f"查询嵌入形状: {query_embeddings.shape}")
        similarity_scores = compute_similarities(query_embeddings, corpus_embeddings)
        # Scores are negated inner products, so an ascending sort puts the
        # most similar document first.
        ranked_indices = {
            f"query_{j+1}": np.argsort(scores)
            for j, scores in enumerate(similarity_scores)
        }
        results[model_display_name] = {
            'corpus_embeddings': corpus_embeddings,
            'query_embeddings': query_embeddings,
            'similarity_scores': similarity_scores,
            'ranked_indices': ranked_indices
        }
        # Print a preview for this model right away: top 3 of the first query.
        print(f"\n=== {model_display_name} 初步结果 ===")
        query_idx = 0
        ranked_idx = ranked_indices[f"query_{query_idx+1}"]
        print(f"查询: '{query_texts[query_idx]}'")
        print("排名前3位的文本:")
        for j, idx in enumerate(ranked_idx[:3]):
            score = similarity_scores[query_idx][idx]
            print(f" {j+1}. [ID:{idx}] {texts[idx]} (分数: {score:.4f})")
    return results
# Result analysis: console report plus three saved figures.
def _rank_agreement(ranking_a, ranking_b, top_k=10):
    """Measure how well two full rankings of the same documents agree near the top.

    The correlations are computed on the rank positions that the two
    rankings assign to the same documents (the union of both top-``top_k``
    sets). Correlating the raw document-ID sequences instead would make the
    result depend on the arbitrary numbering of the documents.

    Args:
        ranking_a: Document indices ordered best-first; a permutation of
            ``range(n)``.
        ranking_b: Same, for the second model and the same ``n`` documents.
        top_k: How many leading results of each ranking to consider.

    Returns:
        ``(kendall_tau, spearman_rho, overlap, k)`` where ``overlap`` is the
        number of documents shared by both top lists and ``k`` the actual
        length of those lists (smaller than ``top_k`` for a tiny corpus).
        Both correlations are NaN when fewer than two documents are compared.
    """
    ranking_a = np.asarray(ranking_a)
    ranking_b = np.asarray(ranking_b)
    top_a = set(ranking_a[:top_k].tolist())
    top_b = set(ranking_b[:top_k].tolist())
    k = min(top_k, len(ranking_a), len(ranking_b))
    overlap = len(top_a & top_b)

    # Inverse permutations: position_x[doc] is the rank of doc under model x.
    position_a = np.empty(len(ranking_a), dtype=int)
    position_a[ranking_a] = np.arange(len(ranking_a))
    position_b = np.empty(len(ranking_b), dtype=int)
    position_b[ranking_b] = np.arange(len(ranking_b))

    candidates = sorted(top_a | top_b)
    if len(candidates) < 2:
        return float("nan"), float("nan"), overlap, k
    kt, _ = kendalltau(position_a[candidates], position_b[candidates])
    sr, _ = spearmanr(position_a[candidates], position_b[candidates])
    return kt, sr, overlap, k


def _print_score_statistics(results):
    """Print mean, min, max and standard deviation of each model's scores."""
    print("\n=== 相似度分数比较 ===")
    for model_name, result in results.items():
        similarities = result['similarity_scores'].flatten()
        print(f"{model_name} 相似度统计:")
        print(f" 平均值: {similarities.mean():.4f}")
        print(f" 最小值: {similarities.min():.4f}")
        print(f" 最大值: {similarities.max():.4f}")
        print(f" 标准差: {similarities.std():.4f}")


def _print_top_rankings(results, models):
    """Print each model's five best documents for every query."""
    print("\n=== 排序结果比较 ===")
    for query_idx, query in enumerate(query_texts):
        query_key = f"query_{query_idx+1}"
        print(f"\n查询 {query_idx+1}: '{query}'")
        for model_name in models:
            ranked_idx = results[model_name]['ranked_indices'][query_key]
            scores = results[model_name]['similarity_scores'][query_idx]
            print(f"{model_name} 排名前5位的文本:")
            for position, idx in enumerate(ranked_idx[:5], start=1):
                print(f" {position}. [ID:{idx}] {texts[idx]} (分数: {scores[idx]:.4f})")


def _print_rank_consistency(results, models):
    """Print per-query and average rank agreement between the first two models."""
    print("\n=== 模型间排序一致性分析 ===")
    kendall_tau_scores = []
    spearman_scores = []
    for query_idx, query in enumerate(query_texts):
        query_key = f"query_{query_idx+1}"
        kt, sr, overlap, k = _rank_agreement(
            results[models[0]]['ranked_indices'][query_key],
            results[models[1]]['ranked_indices'][query_key],
        )
        kendall_tau_scores.append(kt)
        spearman_scores.append(sr)
        # Divide by the real list length so a corpus with fewer than ten
        # documents still reports a correct rate.
        overlap_rate = overlap / k if k else 0.0
        print(f"查询 {query_idx+1} '{query}':")
        print(f" Kendall's Tau = {kt:.4f}, Spearman's rank correlation = {sr:.4f}")
        print(f" 前10结果重叠率: {overlap_rate:.2f} ({overlap}/{k})")
    print(f"\n平均 Kendall's Tau: {np.mean(kendall_tau_scores):.4f}")
    print(f"平均 Spearman's rank correlation: {np.mean(spearman_scores):.4f}")


def _plot_score_distribution(results, models):
    """Save overlaid histograms of every model's scores to 'similarity_distribution.png'."""
    fig = plt.figure(figsize=(12, 6))
    for model_name in models:
        sns.histplot(results[model_name]['similarity_scores'].flatten(),
                     kde=True, label=model_name, alpha=0.6)
    plt.title('不同模型的相似度分布')
    plt.xlabel('相似度得分(越小越相似)')
    plt.ylabel('频率')
    plt.legend()
    plt.savefig('similarity_distribution.png')
    plt.close(fig)  # release the figure once it is on disk
    print("已保存相似度分布图表到 'similarity_distribution.png'")


def _plot_topic_distribution(results, models):
    """Save per-query-group bar charts of top-10 topic counts to 'topic_distribution.png'."""
    fig = plt.figure(figsize=(15, 10))
    # Half-open corpus index ranges of each topic group in `texts`.
    topic_ranges = {
        '动物': (0, 10),
        'AI': (10, 20),
        '巴黎': (20, 30),
        '能源': (30, 40),
        '编程': (40, 50)
    }
    # Query indices belonging to each topic, in the same topic order.
    query_groups = [
        [0, 1, 2],     # animal queries
        [3, 4, 5],     # AI queries
        [6, 7, 8],     # Paris queries
        [9, 10, 11],   # energy queries
        [12, 13, 14]   # programming queries
    ]
    bar_width = 0.35
    bar_positions = np.arange(len(topic_ranges))
    for group_idx, group in enumerate(query_groups):
        plt.subplot(len(query_groups), 1, group_idx+1)
        for model_idx, model_name in enumerate(models):
            # Count how often each topic shows up in the top 10 of the
            # group's queries.
            topic_counts = {topic: 0 for topic in topic_ranges}
            for query_idx in group:
                query_key = f"query_{query_idx+1}"
                top10 = results[model_name]['ranked_indices'][query_key][:10]
                for idx in top10:
                    for topic, (start, end) in topic_ranges.items():
                        if start <= idx < end:
                            topic_counts[topic] += 1
            plt.bar(bar_positions + (model_idx * bar_width),
                    list(topic_counts.values()),
                    bar_width,
                    label=model_name)
        # The title shows only the first query of the group.
        plt.title(f"查询组 {group_idx+1}: {query_texts[group[0]]}")
        plt.xticks(bar_positions + bar_width/2, list(topic_ranges.keys()))
        plt.ylabel('前10结果中的出现次数')
        plt.legend()
    plt.tight_layout()
    plt.savefig('topic_distribution.png')
    plt.close(fig)  # release the figure once it is on disk
    print("已保存主题分布图表到 'topic_distribution.png'")


def _plot_similarity_heatmaps(results, models):
    """Save one query-by-document heatmap per model to 'similarity_heatmap.png'."""
    fig = plt.figure(figsize=(15, 12))
    doc_labels = [f"Doc{i}" for i in range(len(texts))]
    query_labels = [f"Q{i+1}" for i in range(len(query_texts))]
    for row, model_name in enumerate(models):
        plt.subplot(2, 1, row+1)
        # Stored scores are negated inner products; flip the sign so that
        # larger values mean more similar in the plot.
        sim_matrix_viz = -results[model_name]['similarity_scores']
        sns.heatmap(sim_matrix_viz, cmap='YlGnBu',
                    xticklabels=doc_labels,
                    yticklabels=query_labels,
                    cbar_kws={'label': '相似度(越高越相似)'})
        plt.title(f"{model_name} 相似度热图")
        plt.xlabel('文档ID')
        plt.ylabel('查询ID')
    plt.tight_layout()
    plt.savefig('similarity_heatmap.png')
    plt.close(fig)  # release the figure once it is on disk
    print("已保存相似度热图到 'similarity_heatmap.png'")


def analyze_results(results):
    """Report and visualise how the compared models differ.

    Prints score statistics, each model's top-5 documents per query and the
    rank agreement between the first two models, then writes three figures
    to the working directory: 'similarity_distribution.png',
    'topic_distribution.png' and 'similarity_heatmap.png'.

    Args:
        results: Mapping of model name to a dict holding at least
            ``'similarity_scores'`` (queries x documents array, smaller is
            more similar) and ``'ranked_indices'`` (``"query_<n>"`` ->
            document indices ordered best-first), as returned by
            ``compare_embeddings``. At least two models are required for
            the rank-agreement section.
    """
    models = list(results.keys())
    _print_score_statistics(results)
    _print_top_rankings(results, models)
    _print_rank_consistency(results, models)
    _plot_score_distribution(results, models)
    _plot_topic_distribution(results, models)
    _plot_similarity_heatmaps(results, models)
# Script entry point: run the comparison, then analyse what it produced.
if __name__ == "__main__":
    print("开始比较嵌入模型...")
    analyze_results(compare_embeddings())
    print("\n比较完成!")