Zexian Li

一文详解Pytorch中view()和reshape()的细微区别

2020-12-22 · 5 min read

Pytorch中主要使用view()reshape()来改变tensor的shape。

torch.view()

torch.view()通过共享内存地址的方式使用原tensor的基础数据,通过改变数据读取方式来返回一个具有新shape的新tensor;只能使用torch.Tensor.view()方式调用;在使用时要求新shape与原shape的尺寸兼容,即函数只能应用于内存中连续存储的tensor,使用transposepermute等函数改变tensor在内存内连续性后需使用contiguous()方法返回拷贝后的值再调用该函数。
可参照下例辅助理解:

import torch

a = torch.arange(24).view(1,2,3,4)
b = a.view(1,3,2,4)     # b.shape: 1 * 3 * 2 * 4  
c = a.transpose(1,2)    # c.shape: 1 * 3 * 2 * 4
# d = c.view(2, 12)     # raise error because of the uncontinuous data.
d = c.contiguous().view(2, 12)
print(b)
'''
tensor([[[[ 0,  1,  2,  3],
          [ 4,  5,  6,  7]],

         [[ 8,  9, 10, 11],
          [12, 13, 14, 15]],

         [[16, 17, 18, 19],
          [20, 21, 22, 23]]]])
'''
print(c)
'''
tensor([[[[ 0,  1,  2,  3],
          [12, 13, 14, 15]],

         [[ 4,  5,  6,  7],
          [16, 17, 18, 19]],

         [[ 8,  9, 10, 11],
          [20, 21, 22, 23]]]])
'''
print(id(b) == id(c))           # False
print(id(b.data) == id(c.data)) # True

b[0, 0, :, :] = 100
print(a, b) # 'a' will also change its data.

torch.reshape()

torch.reshape()通过拷贝并使用原tensor的基础数据(而非共享内存地址)以返回一个具有新shape的新tensor;可使用torch.reshape()torch.Tensor.reshape()方法调用。此函数不依赖tensor在内存的连续性,当内存连续时,该函数与torch.view()函数等价,当内存不连续时,会自动复制后再改变形状,相当于contiguous().view()。此函数于Pytorch0.4时加入,解决了之前只有view函数时的部分遗留问题。
可参照下例辅助理解:

import torch
a = torch.zeros(3, 2)
b = a.reshape(6)
c = a.t().reshape(6)
a.fill_(1)
print(b)    # tensor([1., 1., 1., 1., 1., 1.])
print(c)    # tensor([0., 0., 0., 0., 0., 0.])

Pytorch与TensorFlow对比

对Pytorch中view函数和reshape函数的执行方式深入分析:在此过程中内存中数据分布并不发生改变,仅仅是数据读取方式发生了改变,更像是开创了一个特定shape的数组后单纯地将内存中数据逐个填入。
对比一下Pytorch和TensorFlow在更改tensor形状时的要求:假设我们有一个6*8大小的矩阵,希望将其转换成2*8*3的形状,TensorFlow会要求先将其拆成2*3*8再转成2*8*3;而Pytorch中可以直接转换而不报错,但这样的结果显然与我们想要的相去甚远,如果要正确转换格式,还是要先调换维度,再reshape/view。
一言以蔽之,Pytorch中改变矩阵shape的门槛更低,但也正是因此,更容易出错,对coder提出了更高的要求。
参照下例:

import torch

a = torch.zeros(6,5)
for i in range(6):
    a[i,:] = i
print(a)
"""
tensor([[0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1.],
        [2., 2., 2., 2., 2.],
        [3., 3., 3., 3., 3.],
        [4., 4., 4., 4., 4.],
        [5., 5., 5., 5., 5.]])
"""

b = a.view(2,5,3)
c = a.reshape(2,5,3)
print(b, c)
"""
'b' is same as 'c', which is as below:
tensor([[[0., 0., 0.],
         [0., 0., 1.],
         [1., 1., 1.],
         [1., 2., 2.],
         [2., 2., 2.]],

        [[3., 3., 3.],
         [3., 3., 4.],
         [4., 4., 4.],
         [4., 5., 5.],
         [5., 5., 5.]]])
"""
d = a.reshape(2,3,5).transpose(1,2)
print(d)
"""
tensor([[[0., 1., 2.],
         [0., 1., 2.],
         [0., 1., 2.],
         [0., 1., 2.],
         [0., 1., 2.]],

        [[3., 4., 5.],
         [3., 4., 5.],
         [3., 4., 5.],
         [3., 4., 5.],
         [3., 4., 5.]]])
"""

总结

如果需要新tensor,使用copy();如果需要共享内存,使用view();无脑reshape()不可取。

Bad decisions make good stories.