注意图

发布时间 2023-12-20 20:57:28作者: 太好了还有脑子可以用
点击查看代码
import matplotlib.pyplot as plt

def visualize_attention(attention_map, original_image):
    plt.imshow(original_image[0], cmap='gray')
    plt.title('原始图像')
    plt.show()

    plt.imshow(attention_map[0, 0].detach().cpu().numpy(), cmap='hot', interpolation="nearest")
    plt.title('注意力图')
    plt.show()

# 假设你可以访问名为 'attention_layer' 的 ModulatedAttLayer 实例
# 以及一个名为 'input_image' 的输入张量

output, feature_maps = attention_layer(input_image)
visualize_attention(feature_maps[1], input_image)