PyTorch打印指定层的输入输出

features_in_hook = []
features_out_hook = []

def hook(module, fea_in, fea_out):
    #features_in_hook.clear()
    #features_out_hook.clear()
    features_in_hook.append(fea_in)
    features_out_hook.append(fea_out)
    return None


model.features[0].register_forward_hook(hook=hook)
outputs = model(images)
# print(features_in_hook)  # 打印指定层的输入
# print(features_out_hook)  # 打印指定层的输出

    for i in range(16):
        plt.subplot(4, 4, i + 1)
        pic = features_out_hook[0][0][i].cpu().detach().numpy()
        plt.imshow(pic, cmap='gray')
    plt.show()

发表回复

您的电子邮箱地址不会被公开。 必填项已用*标注

20 + 6 =