千家信息网

pytorch带batch的tensor类型图像如何显示

发表于:2025-11-15 作者:千家信息网编辑
千家信息网最后更新 2025年11月15日,本篇内容主要讲解"pytorch带batch的tensor类型图像如何显示",感兴趣的朋友不妨来看看。本文介绍的方法操作简单快捷,实用性强。下面就让小编来带大家学习"pytorch带batch的ten
千家信息网最后更新 2025年11月15日pytorch带batch的tensor类型图像如何显示

本篇内容主要讲解"pytorch带batch的tensor类型图像如何显示",感兴趣的朋友不妨来看看。本文介绍的方法操作简单快捷,实用性强。下面就让小编来带大家学习"pytorch带batch的tensor类型图像如何显示"吧!

显示图像

绘图最常用的库就是matplotlib:

pip install matplotlib

显示图像会用到matplotlib.pyplot.imshow方法。查阅官方文档可知,该方法接收的图像的通道数要放到后面

数据加载器中数据的维度是[B, C, H, W],我们每次只拿一个数据出来就是[C, H, W],而matplotlib.pyplot.imshow要求的输入维度是[H, W, C],所以我们需要交换一下数据维度,把通道数放到最后面,这里用到pytorch里面的permute方法(transpose方法也行,不过要交换两次,没这个方便,numpy中的transpose方法倒是可以一次交换完成)

用法示例如下:

>>> x = torch.randn(2, 3, 5)>>> x.size()torch.Size([2, 3, 5])>>> x.permute(1, 2, 0).size()torch.Size([3, 5, 2])

代码示例

#%% 导入模块import torchimport matplotlib.pyplot as pltfrom torchvision.utils import make_gridfrom torch.utils.data import DataLoaderfrom torchvision import datasets, transforms#%% 下载数据集train_file = datasets.MNIST(    root='./dataset/',    train=True,    transform=transforms.Compose([        transforms.ToTensor(),        transforms.Normalize((0.1307,), (0.3081,))    ]),    download=True)#%% 制作数据加载器train_loader = DataLoader(    dataset=train_file,    batch_size=9,    shuffle=True)#%% 训练数据可视化images, labels = next(iter(train_loader))print(images.size())  # torch.Size([9, 1, 28, 28])plt.figure(figsize=(9, 9))for i in range(9):    plt.subplot(3, 3, i+1)    plt.title(labels[i].item())    plt.imshow(images[i].permute(1, 2, 0), cmap='gray')    plt.axis('off')plt.show()

这里以mnist数据集为例,演示一下显示效果。我这个代码其实还有一点小问题。数据增强的时候我不是进行标准化了嘛,就是在第7行代码:Normalize((0.1307,), (0.3081,))。

所以,如果你想查看训练集的原始图像,还得反标准化。

标准化:image = (image-mean)/std

反标准化:image = image*std+mean

我拿imagenet中的一个蚂蚁和蜜蜂的子集做了一下实验,标准化前后的区别还是很明显的

补充:PIL,plt显示tensor类型的图像

该方法针对显示Dataloader读取的图像

PIL 与plt中对应操作不同,但原理是一样的,我试过用下方代码Image的方法在plt上show失败了,原因暂且不知。

 # 方法1:Image.show() # transforms.ToPILImage()中有一句 # npimg = np.transpose(pic.numpy(), (1, 2, 0)) # 因此pic只能是3-D Tensor,所以要用image[0]消去batch那一维 img = transforms.ToPILImage(image[0]) img.show() # 方法2:plt.imshow(ndarray) img = image[0] # plt.imshow()只能接受3-D Tensor,所以也要用image[0]消去batch那一维 img = img.numpy() # FloatTensor转为ndarray img = np.transpose(img, (1,2,0)) # 把channel那一维放到最后 # 显示图片 plt.imshow(img) plt.show() cnt += 1

到此,相信大家对"pytorch带batch的tensor类型图像如何显示"有了更深的了解,不妨来实际操作一番吧!这里是网站,更多相关内容可以进入相关频道进行查询,关注我们,继续学习!

0