Network Structure

This section reads the UNetFormer implementation at the code level: how the image tensor's dimensions change along the way, how the window mechanism is implemented, and how the skip connections operate.

import timm
import torch.nn as nn
import torch.nn.functional as F  # used by Decoder below


class UNetFormer(nn.Module):
    def __init__(self,
                 decode_channels=64,
                 dropout=0.1,
                 backbone_name='swsl_resnet18',
                 pretrained=True,
                 window_size=8,
                 num_classes=6
                 ):
        super().__init__()
        # four-stage feature extractor; out_indices selects the stride-4/8/16/32 maps
        self.backbone = timm.create_model(backbone_name, features_only=True, output_stride=32,
                                          out_indices=(1, 2, 3, 4), pretrained=pretrained)
        encoder_channels = self.backbone.feature_info.channels()

        self.decoder = Decoder(encoder_channels, decode_channels, dropout, window_size, num_classes)

    def forward(self, x):
        h, w = x.size()[-2:]
        res1, res2, res3, res4 = self.backbone(x)
        if self.training:
            # training mode additionally returns the auxiliary head output
            x, ah = self.decoder(res1, res2, res3, res4, h, w)
            return x, ah
        else:
            x = self.decoder(res1, res2, res3, res4, h, w)
            return x

This code defines the top level of the network: the timm library builds a four-stage ResNet18 backbone, and the whole framework follows an encoder-decoder layout, returning an extra auxiliary output (ah) in training mode.
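A quick usage sketch, assuming the full UNetFormer/Decoder definitions from the source are in scope (pretrained=False just avoids a weight download):

import torch

model = UNetFormer(pretrained=False, num_classes=6)
model.eval()                      # eval mode: single output; train mode also returns ah
with torch.no_grad():
    out = model(torch.randn(1, 3, 512, 512))
print(out.shape)                  # torch.Size([1, 6, 512, 512]): logits at input resolution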

Encoder

The encoder body is built with the timm library, and its feature_info attribute is used to query the channel count of each feature level.
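That lookup can be tried in isolation (only timm is needed; the printed values hold for resnet18-type backbones):

import timm

backbone = timm.create_model('swsl_resnet18', features_only=True, output_stride=32,
                             out_indices=(1, 2, 3, 4), pretrained=False)
print(backbone.feature_info.channels())    # [64, 128, 256, 512]
print(backbone.feature_info.reduction())   # [4, 8, 16, 32] (stride of each level)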

Decoder

def forward(self, res1, res2, res3, res4, h, w):
    if self.training:
        x = self.b4(self.pre_conv(res4))   # GLTB on the deepest feature
        h4 = self.up4(x)                   # auxiliary branch, stage 4
        x = self.p3(x, res3)               # weighted fusion with skip res3
        x = self.b3(x)
        h3 = self.up3(x)                   # auxiliary branch, stage 3
        x = self.p2(x, res2)               # weighted fusion with skip res2
        x = self.b2(x)
        h2 = x                             # auxiliary branch, stage 2
        x = self.p1(x, res1)               # feature refinement head with skip res1
        x = self.segmentation_head(x)
        x = F.interpolate(x, size=(h, w), mode='bilinear', align_corners=False)
        ah = h4 + h3 + h2                  # sum of the aligned auxiliary features
        ah = self.aux_head(ah, h, w)
        return x, ah
    else:
        x = self.b4(self.pre_conv(res4))
        x = self.p3(x, res3)
        x = self.b3(x)
        x = self.p2(x, res2)
        x = self.b2(x)
        x = self.p1(x, res1)
        x = self.segmentation_head(x)
        x = F.interpolate(x, size=(h, w), mode='bilinear', align_corners=False)
        return x

The decoder alternates GLTB blocks (b2-b4) with WF weighted-fusion blocks (p2, p3) that merge the encoder skip features, and ends with the feature refinement head (p1). In training mode, three intermediate decoder features are upsampled to a common resolution and summed for an auxiliary supervision branch, as sketched below.
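A shape sketch of that auxiliary sum for a 512x512 input. The upsamplers here are hypothetical stand-ins for self.up4/self.up3; the assumption is that their scale factors (4 and 2) align all three maps at 1/8 resolution:

import torch
import torch.nn as nn

up4 = nn.UpsamplingBilinear2d(scale_factor=4)   # stand-in for self.up4
up3 = nn.UpsamplingBilinear2d(scale_factor=2)   # stand-in for self.up3

d = 64                              # decode_channels
f4 = torch.randn(1, d, 16, 16)      # decoder feature at 1/32 scale
f3 = torch.randn(1, d, 32, 32)      # at 1/16 scale
f2 = torch.randn(1, d, 64, 64)      # at 1/8 scale
ah = up4(f4) + up3(f3) + f2         # all three are (1, 64, 64, 64) after upsampling
print(ah.shape)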

GLTB (Global-Local Transformer Block)

class Block(nn.Module):
    def __init__(self, dim=256, num_heads=16, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.ReLU6, norm_layer=nn.BatchNorm2d, window_size=8):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = GlobalLocalAttention(dim, num_heads=num_heads, qkv_bias=qkv_bias, window_size=window_size)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, out_features=dim, act_layer=act_layer,
                       drop=drop)
        self.norm2 = norm_layer(dim)

    def forward(self, x):
        # pre-norm residual layout: attention stage, then MLP stage
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x

The block follows the standard transformer layout: two pre-normalized residual stages, global-local attention followed by a multilayer perceptron, each adding its (optionally DropPath-regularized) output back to the input.
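A stripped-down sketch of this residual pattern; nn.Identity() stands in for the attention and MLP submodules, which are not needed to show the layout:

import torch
import torch.nn as nn


class PreNormResidual(nn.Module):
    # minimal version of the Block layout: x + fn(norm(x))
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.BatchNorm2d(dim)   # Block's default norm_layer
        self.fn = fn

    def forward(self, x):
        return x + self.fn(self.norm(x))


stage = nn.Sequential(
    PreNormResidual(64, nn.Identity()),   # stand-in for GlobalLocalAttention
    PreNormResidual(64, nn.Identity()),   # stand-in for Mlp
)
print(stage(torch.randn(1, 64, 32, 32)).shape)   # shape-preserving: (1, 64, 32, 32)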

Global-Local Attention

  • Local attention: two convolution branches (3x3 and 1x1) whose outputs are summed


    self.local1 = ConvBN(dim, dim, kernel_size=3)
    self.local2 = ConvBN(dim, dim, kernel_size=1)
  • Global attention

    def forward(self, x):
        B, C, H, W = x.shape

        # local branch: 3x3 and 1x1 convolutions, summed
        local = self.local2(x) + self.local1(x)

        # pad H and W up to multiples of the window size
        x = self.pad(x, self.ws)
        B, C, Hp, Wp = x.shape
        qkv = self.qkv(x)  # one convolution producing q, k and v (3x the channels)

        # partition into ws x ws windows and split heads:
        # each window becomes an independent attention "batch" entry
        q, k, v = rearrange(qkv, 'b (qkv h d) (hh ws1) (ww ws2) -> qkv (b hh ww) h (ws1 ws2) d',
                            h=self.num_heads, d=C // self.num_heads,
                            hh=Hp // self.ws, ww=Wp // self.ws, qkv=3, ws1=self.ws, ws2=self.ws)

        dots = (q @ k.transpose(-2, -1)) * self.scale

        if self.relative_pos_embedding:
            relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
                self.ws * self.ws, self.ws * self.ws, -1)  # Wh*Ww, Wh*Ww, nH
            relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
            dots += relative_position_bias.unsqueeze(0)

        attn = dots.softmax(dim=-1)
        attn = attn @ v

        # merge windows and heads back into a feature map
        attn = rearrange(attn, '(b hh ww) h (ws1 ws2) d -> b (h d) (hh ws1) (ww ws2)',
                         h=self.num_heads, d=C // self.num_heads,
                         hh=Hp // self.ws, ww=Wp // self.ws, ws1=self.ws, ws2=self.ws)

        attn = attn[:, :, :H, :W]  # drop the window padding

        # horizontal and vertical pooling branches across window borders
        out = self.attn_x(F.pad(attn, pad=(0, 0, 0, 1), mode='reflect')) + \
              self.attn_y(F.pad(attn, pad=(0, 1, 0, 0), mode='reflect'))

        out = out + local  # fuse global and local branches
        out = self.pad_out(out)
        out = self.proj(out)
        out = out[:, :, :H, :W]

        return out

The qkv convolution triples the channel dimension; rearrange then partitions the padded feature map into ws x ws windows, flattening each window into a one-dimensional token sequence so attention scores can be computed per head within each window. A relative position bias is added to the scores to capture positional relationships inside the window, and after softmax the scores are multiplied with the values to give the self-attention output.
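To see what the first rearrange does, here is an isolated sketch with small, arbitrary sizes (requires einops):

import torch
from einops import rearrange

B, C, heads, ws = 2, 8, 2, 4
Hp = Wp = 8                              # already padded to a multiple of ws
qkv = torch.randn(B, 3 * C, Hp, Wp)      # output of the qkv convolution

q, k, v = rearrange(qkv, 'b (qkv h d) (hh ws1) (ww ws2) -> qkv (b hh ww) h (ws1 ws2) d',
                    qkv=3, h=heads, d=C // heads,
                    hh=Hp // ws, ww=Wp // ws, ws1=ws, ws2=ws)
# each of the B*2*2 = 8 windows is now its own batch entry of ws*ws = 16 tokens
print(q.shape)   # torch.Size([8, 2, 16, 4]) = (B*hh*ww, heads, ws*ws, C//heads)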

Afterwards, two pooling operations, one horizontal and one vertical (attn_x and attn_y), further aggregate the attention output.
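The excerpt above does not show how attn_x and attn_y are constructed; assuming they are stride-1 average pools with window-length directional kernels (a sketch consistent with the reflect padding in forward, not verbatim source), the shapes work out as follows:

import torch
import torch.nn as nn
import torch.nn.functional as F

ws = 8
# vertical and horizontal smoothing; the one-pixel reflect pad below supplies
# the asymmetric padding an even kernel needs to preserve H and W
attn_x = nn.AvgPool2d(kernel_size=(ws, 1), stride=1, padding=(ws // 2 - 1, 0))
attn_y = nn.AvgPool2d(kernel_size=(1, ws), stride=1, padding=(0, ws // 2 - 1))

attn = torch.randn(1, 64, 32, 32)
out = attn_x(F.pad(attn, pad=(0, 0, 0, 1), mode='reflect')) + \
      attn_y(F.pad(attn, pad=(0, 1, 0, 0), mode='reflect'))
print(out.shape)   # torch.Size([1, 64, 32, 32]), unchanged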

Finally, the local and global attention features are fused by summation.

WF Weighted Fusion

This module fuses the encoder residual (skip) feature with the decoder input feature using learnable weights.

class WF(nn.Module):
    def __init__(self, in_channels=128, decode_channels=128, eps=1e-8):
        super(WF, self).__init__()
        self.pre_conv = Conv(in_channels, decode_channels, kernel_size=1)

        # two learnable scalar fusion weights
        self.weights = nn.Parameter(torch.ones(2, dtype=torch.float32), requires_grad=True)
        self.eps = eps
        self.post_conv = ConvBNReLU(decode_channels, decode_channels, kernel_size=3)

    def forward(self, x, res):
        x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
        weights = nn.ReLU()(self.weights)                                # keep weights non-negative
        fuse_weights = weights / (torch.sum(weights, dim=0) + self.eps)  # normalize to sum ~1
        x = fuse_weights[0] * self.pre_conv(res) + fuse_weights[1] * x
        x = self.post_conv(x)
        return x
  • Upsample the input feature map by 2x
  • ReLU-activate and normalize the two learnable weights
  • Fuse the two features by weighted sum (see the numeric sketch below)
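A tiny numeric sketch of the weight normalization (values are arbitrary):

import torch

weights = torch.tensor([0.5, 1.5])    # hypothetical learned values
w = torch.relu(weights)               # clamp negatives to zero
fuse = w / (w.sum() + 1e-8)
print(fuse)                           # tensor([0.2500, 0.7500]): a learned 1:3 mix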

FRH (Feature Refinement Head)


class FeatureRefinementHead(nn.Module):
    def __init__(self, in_channels=64, decode_channels=64):
        super().__init__()
        self.pre_conv = Conv(in_channels, decode_channels, kernel_size=1)

        self.weights = nn.Parameter(torch.ones(2, dtype=torch.float32), requires_grad=True)
        self.eps = 1e-8
        self.post_conv = ConvBNReLU(decode_channels, decode_channels, kernel_size=3)

        # spatial (pixel) attention: depthwise 3x3 convolution + sigmoid gate
        self.pa = nn.Sequential(
            nn.Conv2d(decode_channels, decode_channels, kernel_size=3, padding=1, groups=decode_channels),
            nn.Sigmoid())
        # channel attention: squeeze-excite style global pooling + bottleneck
        self.ca = nn.Sequential(nn.AdaptiveAvgPool2d(1),
                                Conv(decode_channels, decode_channels // 16, kernel_size=1),
                                nn.ReLU6(),
                                Conv(decode_channels // 16, decode_channels, kernel_size=1),
                                nn.Sigmoid())

        self.shortcut = ConvBN(decode_channels, decode_channels, kernel_size=1)
        self.proj = SeparableConvBN(decode_channels, decode_channels, kernel_size=3)
        self.act = nn.ReLU6()

    def forward(self, x, res):
        # same weighted fusion as in WF
        x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
        weights = nn.ReLU()(self.weights)
        fuse_weights = weights / (torch.sum(weights, dim=0) + self.eps)
        x = fuse_weights[0] * self.pre_conv(res) + fuse_weights[1] * x
        x = self.post_conv(x)
        shortcut = self.shortcut(x)
        pa = self.pa(x) * x   # spatial gating
        ca = self.ca(x) * x   # channel gating
        x = pa + ca
        x = self.proj(x) + shortcut
        x = self.act(x)
        return x

The input feature is upsampled and weight-fused with the encoder residual, then a dual-branch structure processes channel and spatial information separately: the channel branch extracts its gate from global average pooling, while the spatial branch uses a depthwise convolution. The two gated results are summed, passed through a separable convolution, and added to the shortcut projection of the fused feature, which has the same size.
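A standalone sketch of the two gating branches, with plain nn.Conv2d replacing the Conv wrappers and an arbitrary channel count:

import torch
import torch.nn as nn

d = 64
x = torch.randn(1, d, 128, 128)

# spatial gate: a per-pixel weight from a depthwise 3x3 convolution
pa = nn.Sequential(
    nn.Conv2d(d, d, kernel_size=3, padding=1, groups=d),
    nn.Sigmoid())

# channel gate: a per-channel weight from globally pooled statistics
ca = nn.Sequential(
    nn.AdaptiveAvgPool2d(1),
    nn.Conv2d(d, d // 16, kernel_size=1),
    nn.ReLU6(),
    nn.Conv2d(d // 16, d, kernel_size=1),
    nn.Sigmoid())

out = pa(x) * x + ca(x) * x   # ca(x) is (1, d, 1, 1) and broadcasts over H and W
print(out.shape)              # torch.Size([1, 64, 128, 128])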