Pytorch中主要使用view()
与reshape()
来改变tensor的shape。
torch.view()
通过共享内存地址的方式使用原tensor的基础数据,通过改变数据读取方式来返回一个具有新shape的新tensor;只能使用torch.Tensor.view()
方式调用;在使用时要求新shape与原shape的尺寸兼容,即函数只能应用于内存中连续存储的tensor,使用transpose
、permute
等函数改变tensor在内存内连续性后需使用contiguous()
方法返回拷贝后的值再调用该函数。
可参照下例辅助理解:
import torch
a = torch.arange(24).view(1,2,3,4)
b = a.view(1,3,2,4) # b.shape: 1 * 3 * 2 * 4
c = a.transpose(1,2) # c.shape: 1 * 3 * 2 * 4
# d = c.view(2, 12) # raise error because of the uncontinuous data.
d = c.contiguous().view(2, 12)
print(b)
'''
tensor([[[[ 0, 1, 2, 3],
[ 4, 5, 6, 7]],
[[ 8, 9, 10, 11],
[12, 13, 14, 15]],
[[16, 17, 18, 19],
[20, 21, 22, 23]]]])
'''
print(c)
'''
tensor([[[[ 0, 1, 2, 3],
[12, 13, 14, 15]],
[[ 4, 5, 6, 7],
[16, 17, 18, 19]],
[[ 8, 9, 10, 11],
[20, 21, 22, 23]]]])
'''
print(id(b) == id(c)) # False
print(id(b.data) == id(c.data)) # True
b[0, 0, :, :] = 100
print(a, b) # 'a' will also change its data.
torch.reshape()
通过拷贝并使用原tensor的基础数据(而非共享内存地址)以返回一个具有新shape的新tensor;可使用torch.reshape()
或torch.Tensor.reshape()
方法调用。此函数不依赖tensor在内存的连续性,当内存连续时,该函数与torch.view()
函数等价,当内存不连续时,会自动复制后再改变形状,相当于contiguous().view()
。此函数于Pytorch0.4时加入,解决了之前只有view
函数时的部分遗留问题。
可参照下例辅助理解:
import torch
a = torch.zeros(3, 2)
b = a.reshape(6)
c = a.t().reshape(6)
a.fill_(1)
print(b) # tensor([1., 1., 1., 1., 1., 1.])
print(c) # tensor([0., 0., 0., 0., 0., 0.])
对Pytorch中view函数和reshape函数的执行方式深入分析:在此过程中内存中数据分布并不发生改变,仅仅是数据读取方式发生了改变,更像是开创了一个特定shape的数组后单纯地将内存中数据逐个填入。
对比一下Pytorch和TensorFlow在更改tensor形状时的要求:假设我们有一个6*8大小的矩阵,希望将其转换成2*8*3的形状,TensorFlow会要求先将其拆成2*3*8再转成2*8*3;而Pytorch中可以直接转换而不报错,但这样的结果显然与我们想要的相去甚远,如果要正确转换格式,还是要先调换维度,再reshape/view。
一言以蔽之,Pytorch中改变矩阵shape的门槛更低,但也正是因此,更容易出错,对coder提出了更高的要求。
参照下例:
import torch
a = torch.zeros(6,5)
for i in range(6):
a[i,:] = i
print(a)
"""
tensor([[0., 0., 0., 0., 0.],
[1., 1., 1., 1., 1.],
[2., 2., 2., 2., 2.],
[3., 3., 3., 3., 3.],
[4., 4., 4., 4., 4.],
[5., 5., 5., 5., 5.]])
"""
b = a.view(2,5,3)
c = a.reshape(2,5,3)
print(b, c)
"""
'b' is same as 'c', which is as below:
tensor([[[0., 0., 0.],
[0., 0., 1.],
[1., 1., 1.],
[1., 2., 2.],
[2., 2., 2.]],
[[3., 3., 3.],
[3., 3., 4.],
[4., 4., 4.],
[4., 5., 5.],
[5., 5., 5.]]])
"""
d = a.reshape(2,3,5).transpose(1,2)
print(d)
"""
tensor([[[0., 1., 2.],
[0., 1., 2.],
[0., 1., 2.],
[0., 1., 2.],
[0., 1., 2.]],
[[3., 4., 5.],
[3., 4., 5.],
[3., 4., 5.],
[3., 4., 5.],
[3., 4., 5.]]])
"""
如果需要新tensor,使用copy()
;如果需要共享内存,使用view()
;无脑reshape()
不可取。