深度学习
作者:
Lemmon_kk
,
2025-01-09 14:09:17
,
所有人可见
,
阅读 6
# SAFM
class SAFM(nn.Module):
    """Modulate a feature map with multi-scale, depthwise-filtered context.

    The input channels are split into ``n_levels`` equal groups. Group 0 is
    filtered at full resolution by a depthwise 3x3 convolution. Each group
    ``i > 0`` is adaptively max-pooled to about ``1 / 2**i`` of the spatial
    size, filtered by its own depthwise 3x3 convolution, and upsampled back
    to the input size with nearest-neighbour interpolation. The groups are
    concatenated, mixed by a 1x1 convolution, passed through GELU, and the
    result multiplies the input element-wise.

    Args:
        dim: Number of input (and output) channels. Must be a positive
            multiple of ``n_levels`` so every group gets the same width.
        n_levels: Number of scales, i.e. number of channel groups.

    Raises:
        ValueError: If ``dim`` or ``n_levels`` is not positive, or if
            ``n_levels`` does not divide ``dim``.
    """

    def __init__(self, dim: int, n_levels: int = 4) -> None:
        super().__init__()
        # Fail early: with a non-divisible ``dim`` the chunks produced in
        # ``forward`` would not match the per-level conv widths and the module
        # could never run.
        if n_levels <= 0 or dim <= 0 or dim % n_levels != 0:
            raise ValueError(
                f"dim ({dim}) must be a positive multiple of "
                f"n_levels ({n_levels})"
            )
        self.n_levels = n_levels
        chunk_dim = dim // n_levels

        # Spatial weighting: one depthwise 3x3 conv per level.
        self.mfr = nn.ModuleList(
            [
                nn.Conv2d(chunk_dim, chunk_dim, 3, 1, 1, groups=chunk_dim)
                for _ in range(n_levels)
            ]
        )

        # Feature aggregation: 1x1 conv mixing all levels.
        self.aggr = nn.Conv2d(dim, dim, 1, 1, 0)

        # Activation applied to the aggregated map before modulation.
        self.act = nn.GELU()

    def forward(self, x: "torch.Tensor") -> "torch.Tensor":
        """Return ``x`` modulated by its aggregated multi-scale features.

        Args:
            x: Tensor whose last two dimensions are spatial (height, width)
                and whose dimension 1 holds ``dim`` channels.

        Returns:
            A tensor with the same shape as ``x``.
        """
        height, width = x.size()[-2:]
        chunks = x.chunk(self.n_levels, dim=1)

        feats = []
        for level, (chunk, conv) in enumerate(zip(chunks, self.mfr)):
            if level == 0:
                # Finest level: filter at the native resolution.
                feats.append(conv(chunk))
                continue

            # Clamp to 1 so inputs smaller than 2**level per side still
            # produce a valid (non-empty) pooled map.
            pooled_size = (
                max(height // 2**level, 1),
                max(width // 2**level, 1),
            )
            pooled = F.adaptive_max_pool2d(chunk, pooled_size)
            filtered = conv(pooled)
            feats.append(
                F.interpolate(filtered, size=(height, width), mode='nearest')
            )

        aggregated = self.aggr(torch.cat(feats, dim=1))
        return self.act(aggregated) * x