张量操作
张量操作
张量拼接与切分
torch.cat()
1、torch.cat() 沿已有维度拼接张量
import torch

torch.manual_seed(1)

# ======================================= example 1 =======================================
# torch.cat: join a sequence of tensors along an EXISTING dimension.
# The result keeps the inputs' rank; only the size of `dim` grows.
flag = True
# flag = False
if flag:
    t = torch.ones((2, 3))
    # Two copies joined along rows: (2, 3) -> (4, 3)
    t_0 = torch.cat([t] * 2, dim=0)
    # Three copies joined along columns: (2, 3) -> (2, 9)
    t_1 = torch.cat([t] * 3, dim=1)
    print(f"t_0:{t_0} shape:{t_0.shape}\nt_1:{t_1} shape:{t_1.shape}")
OUT:
t_0:tensor([[1., 1., 1.],
[1., 1., 1.],
[1., 1., 1.],
[1., 1., 1.]]) shape:torch.Size([4, 3])
t_1:tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1., 1., 1., 1., 1.]]) shape:torch.Size([2, 9])
torch.stack()
2、torch.stack() 在新维度上堆叠张量
# ======================================= example 2 =======================================
# torch.stack: join tensors along a NEW dimension inserted at `dim`.
# Unlike torch.cat, the result has one more dimension than each input.
flag = True
# flag = False
if flag:
    t = torch.ones((2, 3))
    copies = [t] * 3
    # Insert a new dim 0 and stack along it: three (2, 3) -> (3, 2, 3)
    t_stack_0 = torch.stack(copies, dim=0)
    # Insert a new dim 2 and stack along it: three (2, 3) -> (2, 3, 3)
    t_stack_1 = torch.stack(copies, dim=2)
    print(f"t: {t} shape: {t.shape}")
    print(f"\nt_stack_0: {t_stack_0} shape: {t_stack_0.shape}")
    print(f"\nt_stack_1: {t_stack_1} shape: {t_stack_1.shape}")
OUT:
t: tensor([[1., 1., 1.],
[1., 1., 1.]]) shape: torch.Size([2, 3])
t_stack_0: tensor([[[1., 1., 1.],
[1., 1., 1.]],
[[1., 1., 1.],
[1., 1., 1.]],
[[1., 1., 1.],
[1., 1., 1.]]]) shape: torch.Size([3, 2, 3])
t_stack_1: tensor([[[1., 1., 1.],
[1., 1., 1.],
[1., 1., 1.]],
[[1., 1., 1.],
[1., 1., 1.],
[1., 1., 1.]]]) shape: torch.Size([2, 3, 3])
torch.chunk()
3、torch.chunk() 按份数切分张量
# ======================================= example 3 =======================================
# torch.chunk: split a tensor into `chunks` pieces along `dim`.
# When the size is not evenly divisible, the last piece is smaller
# (here 7 columns -> 3 + 3 + 1).
flag = True
# flag = False
if flag:
    a = torch.ones((2, 7))  # 7 columns, deliberately not divisible by 3
    # dim=1 is the column axis (0-based index 1, i.e. the second dimension)
    list_of_tensors = torch.chunk(a, dim=1, chunks=3)
    for number, piece in enumerate(list_of_tensors, start=1):
        print(f"第{number}个张量:{piece}, shape is {piece.shape}")
OUT:
第1个张量:tensor([[1., 1., 1.],
[1., 1., 1.]]), shape is torch.Size([2, 3])
第2个张量:tensor([[1., 1., 1.],
[1., 1., 1.]]), shape is torch.Size([2, 3])
第3个张量:tensor([[1.],
[1.]]), shape is torch.Size([2, 1])
torch.split()
4、torch.split() 按指定长度切分张量
# ======================================= example 4 =======================================
# torch.split: split a tensor along `dim` into pieces of the given sizes.
# When a list of sizes is passed, the sizes must sum exactly to t.shape[dim].
flag = True
# flag = False
if flag:
    t = torch.ones((2, 5))
    # [2, 1, 2] sums to 5 == t.shape[1], so this split succeeds
    list_of_tensors = torch.split(t, [2, 1, 2], dim=1)
    # Use a distinct loop name: reusing `t` here would overwrite the (2, 5)
    # input, and the counter-example below would then split the last (2, 2)
    # piece instead of the original tensor (failing for the wrong reason).
    for idx, piece in enumerate(list_of_tensors):
        print("第{}个张量:{}, shape is {}".format(idx+1, piece, piece.shape))
    # Counter-example: sum([2, 1, 1]) = 4 != 5, so torch.split raises.
    # Catch it so the demonstration shows the error without aborting the script.
    try:
        list_of_tensors = torch.split(t, [2, 1, 1], dim=1)
    except RuntimeError as err:
        print("torch.split failed as expected: {}".format(err))
    else:
        for idx, piece in enumerate(list_of_tensors):
            print("第{}个张量:{}, shape is {}".format(idx+1, piece, piece.shape))
OUT:
第1个张量:tensor([[1., 1.],
[1., 1.]]), shape is torch.Size([2, 2])
第2个张量:tensor([[1.],
[1.]]), shape is torch.Size([2, 1])
第3个张量:tensor([[1., 1.],
[1., 1.]]), shape is torch.Size([2, 2])
Traceback (most recent call last):
File "E:/Code/In_pytorch/pytorch_lesson_code/Hello_pytorch/lesson/lesson-03/lesson-03.py", line 70, in <module>
list_of_tensors = torch.split(t, [2, 1, 1], dim=1)
File "D:\ProgramData\Anaconda3\envs\py37\lib\site-packages\torch\functional.py", line 156, in split
return tensor.split(split_size_or_sections, dim)
File "D:\ProgramData\Anaconda3\envs\py37\lib\site-packages\torch\_tensor.py", line 518, in split
return super(Tensor, self).split_with_sizes(split_size, dim)
RuntimeError: start (2) + length (1) exceeds dimension size (2).
张量索引
torch.index_select()
5、torch.index_select() |
# ======================================= example 5 =======================================
# torch.index_select: gather whole slices along `dim` at the given indices.
flag = True
# flag = False
if flag:
    t = torch.randint(0, 9, size=(3, 3))
    # The index must be an integer tensor (torch.long here), not a float tensor.
    idx = torch.tensor([0, 2], dtype=torch.long)

    t_select_0 = torch.index_select(t, dim=0, index=idx)  # rows 0 and 2
    print(f"t:\n{t}\nt_select_0:\n{t_select_0}")
    print("\n")

    t_select_1 = torch.index_select(t, dim=1, index=idx)  # columns 0 and 2
    print(f"t:\n{t}\nt_select_1:\n{t_select_1}")
OUT:
t:
tensor([[4, 5, 0],
[5, 7, 1],
[2, 5, 8]])
t_select_0:
tensor([[4, 5, 0],
[2, 5, 8]])
t:
tensor([[4, 5, 0],
[5, 7, 1],
[2, 5, 8]])
t_select_1:
tensor([[4, 0],
[5, 1],
[2, 8]])
torch.masked_select()
6、torch.masked_select() |
# ======================================= example 6 =======================================
# torch.masked_select: return a 1-D tensor of the elements where mask is True.
flag = True
# flag = False
if flag:
    t = torch.randint(0, 9, size=(3, 3))
    # le = "less than or equal" (t <= 5).
    # Sibling comparisons: ge (>=), gt (>), lt (<).
    mask = t.le(5)
    t_select = torch.masked_select(t, mask)
    print(f"t:\n{t}\nmask:\n{mask}\nt_select:\n{t_select} ")
OUT:
t:
tensor([[4, 5, 0],
[5, 7, 1],
[2, 5, 8]])
mask:
tensor([[ True, True, True],
[ True, False, True],
[ True, True, False]])
t_select:
tensor([4, 5, 0, 5, 1, 2, 5])
张量变换
torch.reshape()
7、torch.reshape() |
# ======================================= example 7 =======================================
# torch.reshape: return a tensor with the same elements in a new shape.
# For a contiguous input (such as the freshly created torch.randperm output
# here) the result is a view sharing storage with the input, so the write to
# `t` below is visible through `t_reshape`. Otherwise reshape may copy.
flag = True
# flag = False
if flag:
    t = torch.randperm(8)
    # -1 lets PyTorch infer that dimension: 8 / (2 * 2) = 2 -> shape (2, 2, 2)
    t_reshape = torch.reshape(t, (-1, 2, 2))
    print("t:{}\nt_reshape:\n{}".format(t, t_reshape))
    t[0] = 1024
    print("\n")
    print("t:{}\nt_reshape:\n{}".format(t, t_reshape))
    # id(t.data) is NOT a memory address. Each `.data` access builds a new
    # temporary tensor object, and CPython may hand the next one the same
    # id, so equal ids prove nothing. data_ptr() returns the address of the
    # first element, which matches when the two tensors share storage.
    print("t.data 内存地址:{}".format(t.data_ptr()))
    print("t_reshape.data 内存地址:{}".format(t_reshape.data_ptr()))
OUT:
t:tensor([5, 4, 2, 6, 7, 3, 1, 0])
t_reshape:
tensor([[[5, 4],
[2, 6]],
[[7, 3],
[1, 0]]])
t:tensor([1024, 4, 2, 6, 7, 3, 1, 0])
t_reshape:
tensor([[[1024, 4],
[ 2, 6]],
[[ 7, 3],
[ 1, 0]]])
t.data 内存地址:1747722786440
t_reshape.data 内存地址:1747722786440
torch.transpose()
8、torch.transpose() |
# ======================================= example 8 =======================================
# torch.transpose: swap two dimensions (dim0 and dim1) of a tensor.
flag = True
# flag = False
if flag:
    t = torch.rand((2, 3, 4))
    # Swap dims 1 and 2: (2, 3, 4) -> (2, 4, 3). The result shares storage with `t`.
    # A single swap cannot turn c*h*w into h*w*c; that needs t.permute(1, 2, 0).
    t_transpose = torch.transpose(t, dim0=1, dim1=2)
    print("t shape:{}\nt_transpose shape: {}".format(t.shape, t_transpose.shape))
OUT:
t shape:torch.Size([2, 3, 4])
t_transpose shape: torch.Size([2, 4, 3])
torch.squeeze()
9、torch.squeeze() |
# ======================================= example 9 =======================================
# torch.squeeze: remove dimensions whose size is 1.
flag = True
# flag = False
if flag:
    t = torch.rand((1, 2, 3, 1))
    # No dim given: drop every size-1 dim -> (2, 3)
    t_sq = torch.squeeze(t)
    # dim 0 has size 1, so it is removed -> (2, 3, 1)
    t_0 = torch.squeeze(t, dim=0)
    # dim 1 has size 2, so nothing is removed -> (1, 2, 3, 1)
    t_1 = torch.squeeze(t, dim=1)
    for name, value in (("t", t), ("t_sq", t_sq), ("t_0", t_0), ("t_1", t_1)):
        print("{}.shape: {}".format(name, value.shape))
OUT:
t.shape: torch.Size([1, 2, 3, 1])
t_sq.shape: torch.Size([2, 3])
t_0.shape: torch.Size([2, 3, 1])
t_1.shape: torch.Size([1, 2, 3, 1])
张量数学运算
torch.add()
# ======================================= example 10 =======================================
# torch.add: out = input + alpha * other.
flag = True
# flag = False
if flag:
    t_0 = torch.randn((3, 3))
    t_1 = torch.ones_like(t_0)
    # Computes t_0 + 10 * t_1. The scale goes in the keyword `alpha`; the old
    # positional form torch.add(t_0, 10, t_1) is a deprecated overload.
    t_add = torch.add(t_0, t_1, alpha=10)
    print("t_0:\n{}\nt_1:\n{}\nt_add_10:\n{}".format(t_0, t_1, t_add))
OUT:
t_0:
tensor([[ 0.6614, 0.2669, 0.0617],
[ 0.6213, -0.4519, -0.1661],
[-1.5228, 0.3817, -1.0276]])
t_1:
tensor([[1., 1., 1.],
[1., 1., 1.],
[1., 1., 1.]])
t_add_10:
tensor([[10.6614, 10.2669, 10.0617],
[10.6213, 9.5481, 9.8339],
[ 8.4772, 10.3817, 8.9724]])
文章出处登录后可见!
已经登录?立即刷新