最近看大佬的Transformer源码时发现大家都用到了这个库,查看了一下相关API,实现太简洁了,真正的所见即所得,速MARK!
import einops
x = einops.rearrange(x, 'n h w c -> n (h w) c')
x = einops.rearrange(x, 'n (h w) c -> n c h w', h=h1)
x = einops.rearrange(x, '(n1 n2) h w c -> (n1 h) (n2 w) c ', n1=2)
from einops.layes.torch import Rearrange
self.net = nn.Sequential(
nn.LayerNorm(dim),
Rearrange('n h w -> h w n')
)
# str: mean, min, max, sum, prod
x = einops.reduce(x, 'n c h w -> n h w', 'mean') # average over channel