Upsample Nearest Implementation

1. 通过Reshape和Concat结合实现Upsample

由于一些框架,不支持Upsample,可以用ReShape和Concat这种常见的Op替换

class MyUpSample(nn.Module):
    dump_patches = True
    def __init__(self, scale_factor = 2):
        super(MyUpSample, self).__init__()
        #self._upsample = nn.Upsample(scale_factor=scale_factor)
    def forward(self, x):
        #return self._upsample(x)
        ss = x.size()
        x = x.view(ss[0], ss[1], ss[2] * ss[3], 1)  # 1
        ups1 = torch.cat((x, x), 3)  # 2
        ups1 = ups1.view(ss[0], ss[1], ss[2], 2 * ss[3])  # 3
        ups2 = torch.cat((ups1, ups1), 3)  # 4
        ups2 = ups2.view(ss[0], ss[1], ss[2] * 2, ss[3] * 2)  # 5
        return ups2

其中代码注释中的5个步骤,分别对应下图中的5步:

ml_upsample

Loading Disqus comments...
Table of Contents