torchvision.transforms 数据预处理:ToTensor()

文章目录

    • 1、ToTensor() 函数的作用
    • 2、读取图像时 PIL 和 opencv 的选择
      • 2.1 使用 PIL
      • 2.2 使用 opencv
    • 3、ToTensor() 的使用
      • 3.1 关键知识点
      • 3.2 代码示例

ToTensor() 是pytorch中的数据预处理函数,包含在 torchvision.transforms 模块下。一般用于处理图像数据,所以其处理对象是 PIL Image 和 numpy.ndarray 。

1、ToTensor() 函数的作用

必须要声明不能只看函数名,就以为 ToTensor() 只是将图像转为 tensor,其实它的功能不止于此

看一下 ToTensor() 函数的源码:

class ToTensor:
    """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor. This transform does not support torchscript.

    Converts a PIL Image or numpy.ndarray (H x W x C) in the range
    [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]
    if the PIL Image belongs to one of the modes (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1)
    or if the numpy.ndarray has dtype = np.uint8

    In the other cases, tensors are returned without scaling.

    .. note::
        Because the input image is scaled to [0.0, 1.0], this transformation should not be used when
        transforming target image masks. See the `references`_ for implementing the transforms for image masks.

    .. _references: https://github.com/pytorch/vision/tree/main/references/segmentation
    """

大意是:

(1)将 PIL Image 或 numpy.ndarray 转为 tensor

(2)如果 PIL Image 属于 (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1) 中的一种图像类型,或者 numpy.ndarray 格式数据类型是 np.uint8 ,则将 [0, 255] 的数据转为 [0.0, 1.0] ,也就是说将所有数据除以 255 进行归一化。

(3)将 HWC 的图像格式转为 CHW 的 tensor 格式。CNN训练时需要的数据格式是[N,C,N,W],也就是说经过 ToTensor() 处理的图像可以直接输入到CNN网络中,不需要再进行reshape。

2、读取图像时 PIL 和 opencv 的选择

在自己建立 dataset 迭代器时,一般操作是检索数据集图像的路径,然后使用 PIL 库或 opencv库读取图片路径。

2.1 使用 PIL

import numpy as np
from PIL import Image

filePath="Dataset/FFHQ/00000.png"
img1=Image.open(filePath)
print(f"img1 = {img1}")    
# img1 = <PIL.PngImagePlugin.PngImageFile image mode=RGB size=128x128 at 0x253DC205A88>

img2 = np.array(img1)
print(f"img2 = {img2}")

"""
img2 = [[[  0 130 146]
  [  0 128 144]
  [  0 125 141]
  ...
  [133 162 164]
  [133 157 159]
  [134 157 163]]]
"""

可以看到,使用 PIL.Image 读取的图像是一种 PIL 类,mode=RGB,要想获得图像的像素值还需要将其转为 np.array 格式。

而 opencv 可以直接将图像读取为 np.array 格式,因此首选 opencv 。

2.2 使用 opencv

import cv2

filePath="Dataset/FFHQ/00000.png"
img=cv2.imread(filePath)
print(f"img.shape = {img.shape}")     # img.shape = (128, 128, 3)
print(f"img = {img}")     # img.dtype = uint8

"""
img = [[[146 130   0]
  [144 128   0]
  [141 125   0]
  ...
  [164 162 133]
  [159 157 133]
  [163 157 134]]]
"""

仔细对比PIL 和 opencv 的输出结果可以发现,PIL默认输出的图片格式为 RGB,而opencv输出的是BGR格式。

使用opencv读取的图像是[H,W,C]大小的,数据格式是 np.uint8 ,经过 ToTensor() 会进行归一化。而其他的数据类型(如 np.int8)经过 ToTensor() 数值不变,不进行归一化,后面会详细讲述。并且经过ToTensor()后图像格式变为 [C,H,W]。

3、ToTensor() 的使用

3.1 关键知识点

不管是使用 PLT还是opencv,最终得到都是 np.array类型。因此:

ToTensor() 是将 np.array 的数据 转为 tensor 格式

这里一定要明确几个点:

(1)np.array 整型的默认数据类型为 np.int32,经过 ToTensor() 后数值不变,不进行归一化。
(2)np.array 浮点型的默认数据类型为 np.float64,经过 ToTensor() 后数值不变,不进行归一化。
(3)opencv 读取的图像格式为 np.array,其数据类型为 np.uint8
    经过 ToTensor() 后数值由 [0,255] 变为 [0,1],通过将每个数据除以255进行归一化。
(4)经过 ToTensor() 后,HWC 的图像格式变为 CHW 的 tensor 格式。
(5)np.uint8 和 np.int8 不一样,uint8是无符号整型,数值都是正数。
(6)ToTensor() 可以处理任意 shape 的 np.array,并不只是三通道的图像数据。

3.2 代码示例

下面通过代码熟悉 ToTensor() 的使用,分别看一下 np.uint8 和 非 np.uint8 类型的 np.array 经过 ToTensor() 之后的输出。

(1) np.uint8 类型

import numpy as np
from torchvision import transforms

data = np.array([
    [0, 5, 10, 20, 0],
    [255, 125, 180, 255, 196]
], dtype=np.uint8)

tensor = transforms.ToTensor()(data)
print(tensor)
"""
tensor([[[0.0000, 0.0196, 0.0392, 0.0784, 0.0000],
         [1.0000, 0.4902, 0.7059, 1.0000, 0.7686]]])
"""

(2)非 np.uint8 类型

import numpy as np
from torchvision import transforms

data = np.array([
    [0, 5, 10, 20, 0],
    [255, 125, 180, 255, 196]
])      # data.dtype = int32

tensor = transforms.ToTensor()(data)
print(tensor)
"""
tensor([[[  0,   5,  10,  20,   0],
         [255, 125, 180, 255, 196]]], dtype=torch.int32)
"""

文章出处登录后可见!

已经登录?立即刷新

共计人评分,平均

到目前为止还没有投票!成为第一位评论此文章。

(0)
扎眼的阳光的头像扎眼的阳光普通用户
上一篇 2023年8月8日
下一篇 2023年8月8日

相关推荐