Pytorch框架学习路径(二:张量操作)

张量操作

张量操作

张量拼接与切分

torch.cat()

Pytorch框架学习路径(二:张量操作)

1、torch.cat创建张量
import torch
torch.manual_seed(1)

# ======================================= example 1 =======================================
# torch.cat

flag = True
# flag = False

if flag:
    t = torch.ones((2, 3))

    t_0 = torch.cat([t, t], dim=0)
    t_1 = torch.cat([t, t, t], dim=1)

    print("t_0:{} shape:{}\nt_1:{} shape:{}".format(t_0, t_0.shape, t_1, t_1.shape))

OUT:

t_0:tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]]) shape:torch.Size([4, 3])
t_1:tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1.]]) shape:torch.Size([2, 9])

torch.stack()

2、 torch.stack()创建张量
# ======================================= example 2 =======================================
# torch.stack

flag = True
# flag = False

if flag:
    t = torch.ones((2, 3))
	
	# 增加一个第0个维度,并在这个维度上合并
    t_stack_0 = torch.stack([t, t, t], dim=0) 
    # 增加一个第2个维度,并在这个维度上合并
    t_stack_1 = torch.stack([t, t, t], dim=2)
	
	print("t: {} shape: {}".format(t, t.shape))
    print("\nt_stack_0: {} shape: {}".format(t_stack_0, t_stack_0.shape))
    print("\nt_stack_1: {} shape: {}".format(t_stack_1, t_stack_1.shape))

OUT:

t: tensor([[1., 1., 1.],
        [1., 1., 1.]]) shape: torch.Size([2, 3])

t_stack_0: tensor([[[1., 1., 1.],
         [1., 1., 1.]],

        [[1., 1., 1.],
         [1., 1., 1.]],

        [[1., 1., 1.],
         [1., 1., 1.]]]) shape: torch.Size([3, 2, 3])

t_stack_1: tensor([[[1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.]],

        [[1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.]]]) shape: torch.Size([2, 3, 3])

torch.chunk()

Pytorch框架学习路径(二:张量操作)

3、torch.chunk()创建张量
# ======================================= example 3 =======================================
# torch.chunk

flag = True
# flag = False

if flag:
    a = torch.ones((2, 7))  # 7
    list_of_tensors = torch.chunk(a, dim=1, chunks=3)   # 按第一维度(也就是列)切分成3份

    for idx, t in enumerate(list_of_tensors):
        print("第{}个张量:{}, shape is {}".format(idx+1, t, t.shape))

OUT:

第1个张量:tensor([[1., 1., 1.],
        [1., 1., 1.]]), shape is torch.Size([2, 3])
第2个张量:tensor([[1., 1., 1.],
        [1., 1., 1.]]), shape is torch.Size([2, 3])
第3个张量:tensor([[1.],
        [1.]]), shape is torch.Size([2, 1])

torch.split()

Pytorch框架学习路径(二:张量操作)

4、torch.split()创建张量
# ======================================= example 4 =======================================
# torch.split

flag = True
# flag = False

if flag:
    t = torch.ones((2, 5))

    # list[2, 1, 2]求和一定要等于指定维度dim=1上的大小5,不然会报错
    list_of_tensors = torch.split(t, [2, 1, 2], dim=1)  # [2 , 1, 2]
    for idx, t in enumerate(list_of_tensors):
        print("第{}个张量:{}, shape is {}".format(idx+1, t, t.shape))
    
    # 举例:sum[2, 1, 1] = 4 不等于 5
    list_of_tensors = torch.split(t, [2, 1, 1], dim=1)
    for idx, t in enumerate(list_of_tensors):
        print("第{}个张量:{}, shape is {}".format(idx, t, t.shape))

OUT:

第1个张量:tensor([[1., 1.],
        [1., 1.]]), shape is torch.Size([2, 2])
第2个张量:tensor([[1.],
        [1.]]), shape is torch.Size([2, 1])
第3个张量:tensor([[1., 1.],
        [1., 1.]]), shape is torch.Size([2, 2])
Traceback (most recent call last):
  File "E:/Code/In_pytorch/pytorch_lesson_code/Hello_pytorch/lesson/lesson-03/lesson-03.py", line 70, in <module>
    list_of_tensors = torch.split(t, [2, 1, 1], dim=1)
  File "D:\ProgramData\Anaconda3\envs\py37\lib\site-packages\torch\functional.py", line 156, in split
    return tensor.split(split_size_or_sections, dim)
  File "D:\ProgramData\Anaconda3\envs\py37\lib\site-packages\torch\_tensor.py", line 518, in split
    return super(Tensor, self).split_with_sizes(split_size, dim)
RuntimeError: start (2) + length (1) exceeds dimension size (2).

张量索引

torch.index_select()

Pytorch框架学习路径(二:张量操作)

5、torch.index_select()
# ======================================= example 5 =======================================
# torch.index_select

flag = True
# flag = False

if flag:
    t = torch.randint(0, 9, size=(3, 3))
    idx = torch.tensor([0, 2], dtype=torch.long)    # float
    t_select_0 = torch.index_select(t, dim=0, index=idx)
    print("t:\n{}\nt_select_0:\n{}".format(t, t_select_0))
   
    print("\n")
    
    t_select_1 = torch.index_select(t, dim=1, index=idx)
    print("t:\n{}\nt_select_1:\n{}".format(t, t_select_1))

OUT:

t:
tensor([[4, 5, 0],
        [5, 7, 1],
        [2, 5, 8]])
t_select_0:
tensor([[4, 5, 0],
        [2, 5, 8]])


t:
tensor([[4, 5, 0],
        [5, 7, 1],
        [2, 5, 8]])
t_select_1:
tensor([[4, 0],
        [5, 1],
        [2, 8]])

torch.masked_select()

Pytorch框架学习路径(二:张量操作)

6、torch.masked_select()
# ======================================= example 6 =======================================
# torch.masked_select

flag = True
# flag = False

if flag:

    t = torch.randint(0, 9, size=(3, 3))
    # le指:小于
    mask = t.le(5)  # ge is mean greater than or equal/   gt: greater than  le  lt
    t_select = torch.masked_select(t, mask)
    print("t:\n{}\nmask:\n{}\nt_select:\n{} ".format(t, mask, t_select))

OUT:

t:
tensor([[4, 5, 0],
        [5, 7, 1],
        [2, 5, 8]])
mask:
tensor([[ True,  True,  True],
        [ True, False,  True],
        [ True,  True, False]])
t_select:
tensor([4, 5, 0, 5, 1, 2, 5]) 

张量变换

torch.reshape()

Pytorch框架学习路径(二:张量操作)

7、torch.reshape()
# ======================================= example 7 =======================================
# torch.reshape

flag = True
# flag = False

if flag:
    t = torch.randperm(8)
    t_reshape = torch.reshape(t, (-1, 2, 2))    # -1
    print("t:{}\nt_reshape:\n{}".format(t, t_reshape))

    t[0] = 1024
    print("\n")
    print("t:{}\nt_reshape:\n{}".format(t, t_reshape))
    print("t.data 内存地址:{}".format(id(t.data)))
    print("t_reshape.data 内存地址:{}".format(id(t_reshape.data)))

OUT:

t:tensor([5, 4, 2, 6, 7, 3, 1, 0])
t_reshape:
tensor([[[5, 4],
         [2, 6]],

        [[7, 3],
         [1, 0]]])


t:tensor([1024,    4,    2,    6,    7,    3,    1,    0])
t_reshape:
tensor([[[1024,    4],
         [   2,    6]],

        [[   7,    3],
         [   1,    0]]])
t.data 内存地址:1747722786440
t_reshape.data 内存地址:1747722786440

torch.transpose()

Pytorch框架学习路径(二:张量操作)

8、torch.transpose()
# ======================================= example 8 =======================================
# torch.transpose

# flag = True
flag = False

if flag:
    # torch.transpose
    t = torch.rand((2, 3, 4))
    t_transpose = torch.transpose(t, dim0=1, dim1=2)    # c*h*w     h*w*c
    print("t shape:{}\nt_transpose shape: {}".format(t.shape, t_transpose.shape))

OUT:

t shape:torch.Size([2, 3, 4])
t_transpose shape: torch.Size([2, 4, 3])

torch.squeeze()

Pytorch框架学习路径(二:张量操作)

9、torch.squeeze()
# ======================================= example 9 =======================================
# torch.squeeze

flag = True
# flag = False

if flag:
    t = torch.rand((1, 2, 3, 1))
    t_sq = torch.squeeze(t)
    t_0 = torch.squeeze(t, dim=0)
    t_1 = torch.squeeze(t, dim=1)
    print("t.shape: {}".format(t.shape))
    print("t_sq.shape: {}".format(t_sq.shape))
    print("t_0.shape: {}".format(t_0.shape))
    print("t_1.shape: {}".format(t_1.shape))

OUT:

t.shape: torch.Size([1, 2, 3, 1])
t_sq.shape: torch.Size([2, 3])
t_0.shape: torch.Size([2, 3, 1])
t_1.shape: torch.Size([1, 2, 3, 1])

张量数学运算

Pytorch框架学习路径(二:张量操作)

torch.add()

Pytorch框架学习路径(二:张量操作)

# ======================================= example 8 =======================================
# torch.add

flag = True
# flag = False

if flag:
    t_0 = torch.randn((3, 3))
    t_1 = torch.ones_like(t_0)
    t_add = torch.add(t_0, 10, t_1)

    print("t_0:\n{}\nt_1:\n{}\nt_add_10:\n{}".format(t_0, t_1, t_add))

OUT:

t_0:
tensor([[ 0.6614,  0.2669,  0.0617],
        [ 0.6213, -0.4519, -0.1661],
        [-1.5228,  0.3817, -1.0276]])
t_1:
tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])
t_add_10:
tensor([[10.6614, 10.2669, 10.0617],
        [10.6213,  9.5481,  9.8339],
        [ 8.4772, 10.3817,  8.9724]])

文章出处登录后可见!

已经登录?立即刷新

共计人评分,平均

到目前为止还没有投票!成为第一位评论此文章。

(0)
心中带点小风骚的头像心中带点小风骚普通用户
上一篇 2022年5月30日 上午10:14
下一篇 2022年5月30日

相关推荐