class VGG(nn.Module):
    """VGG-16 convolutional backbone that returns the output of all five pooling stages.

    Only the 13 convolutional layers of VGG-16 are built; there is no fully
    connected classifier. The attribute names (``conv1_1`` ... ``conv5_3``)
    determine this module's ``state_dict`` keys, so keep them stable to stay
    compatible with saved checkpoints.

    Args:
        pretrained: if truthy, copy the convolutional weights of
            ``models.vgg16(pretrained=...)`` into this module.
    """

    def __init__(self, pretrained=True):
        super(VGG, self).__init__()
        # ReLU(inplace=True) is safe here: each ReLU only overwrites the fresh
        # conv output it is given, which nothing else reads (see forward()).
        # Stage 1: 3 -> 64 channels, pool1 output at 1/2 resolution.
        self.conv1_1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.relu1_1 = nn.ReLU(inplace=True)
        self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.relu1_2 = nn.ReLU(inplace=True)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        # Stage 2: 64 -> 128 channels, pool2 output at 1/4 resolution.
        self.conv2_1 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.relu2_1 = nn.ReLU(inplace=True)
        self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.relu2_2 = nn.ReLU(inplace=True)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        # Stage 3: 128 -> 256 channels, pool3 output at 1/8 resolution.
        self.conv3_1 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.relu3_1 = nn.ReLU(inplace=True)
        self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.relu3_2 = nn.ReLU(inplace=True)
        self.conv3_3 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.relu3_3 = nn.ReLU(inplace=True)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        # Stage 4: 256 -> 512 channels, pool4 output at 1/16 resolution.
        self.conv4_1 = nn.Conv2d(256, 512, kernel_size=3, padding=1)
        self.relu4_1 = nn.ReLU(inplace=True)
        self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.relu4_2 = nn.ReLU(inplace=True)
        self.conv4_3 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.relu4_3 = nn.ReLU(inplace=True)
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)
        # Stage 5: 512 -> 512 channels, pool5 output at 1/32 resolution.
        self.conv5_1 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.relu5_1 = nn.ReLU(inplace=True)
        self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.relu5_2 = nn.ReLU(inplace=True)
        self.conv5_3 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.relu5_3 = nn.ReLU(inplace=True)
        self.pool5 = nn.MaxPool2d(kernel_size=2, stride=2)
        # Load the conv weights of torchvision's pretrained VGG-16.
        # NOTE(review): newer torchvision releases deprecate `pretrained=` in
        # favour of `weights=`; confirm against the installed version.
        if pretrained:
            pretrained_model = models.vgg16(pretrained=pretrained)
            self._load_vgg16_features(pretrained_model.state_dict())

    def _load_vgg16_features(self, pretrained_params):
        """Copy the ``features.*`` tensors of a VGG-16 state dict into this module.

        Only keys starting with ``"features."`` are used; other entries (e.g. a
        fully connected classifier) are ignored. The selected tensors are matched
        positionally, in iteration order, to this module's parameters
        (``conv1_1.weight``, ``conv1_1.bias``, ..., ``conv5_3.bias``). This relies
        on the source dict listing the conv layers in network order.

        Args:
            pretrained_params: mapping from state-dict key to tensor.

        Raises:
            ValueError: if the number of ``features.*`` tensors differs from the
                number of entries in this module's state dict, or a tensor's
                shape does not match its target.
        """
        own_state = self.state_dict()
        feature_keys = [key for key in pretrained_params if key.startswith("features.")]
        if len(feature_keys) != len(own_state):
            raise ValueError(
                f"expected {len(own_state)} 'features.*' tensors in the pretrained "
                f"state dict, found {len(feature_keys)}"
            )
        new_dict = {}
        for own_key, src_key in zip(own_state, feature_keys):
            tensor = pretrained_params[src_key]
            if tensor.shape != own_state[own_key].shape:
                raise ValueError(
                    f"shape mismatch loading {src_key!r} into {own_key!r}: "
                    f"{tuple(tensor.shape)} vs {tuple(own_state[own_key].shape)}"
                )
            new_dict[own_key] = tensor
        self.load_state_dict(new_dict)

    def forward(self, x):
        """Run the backbone and collect the output of every pooling stage.

        Args:
            x: input batch; the first conv expects 3 input channels
               (presumably ``(N, 3, H, W)`` images; confirm against callers).

        Returns:
            Tuple ``(pool1, pool2, pool3, pool4, pool5)`` with 64, 128, 256, 512
            and 512 channels. Each 2x2/stride-2 max-pool halves H and W, rounding
            down, so the maps are at 1/2 ... 1/32 of the input resolution.
        """
        x = self.relu1_1(self.conv1_1(x))
        x = self.relu1_2(self.conv1_2(x))
        pool1 = self.pool1(x)
        x = self.relu2_1(self.conv2_1(pool1))
        x = self.relu2_2(self.conv2_2(x))
        pool2 = self.pool2(x)
        x = self.relu3_1(self.conv3_1(pool2))
        x = self.relu3_2(self.conv3_2(x))
        x = self.relu3_3(self.conv3_3(x))
        pool3 = self.pool3(x)
        x = self.relu4_1(self.conv4_1(pool3))
        x = self.relu4_2(self.conv4_2(x))
        x = self.relu4_3(self.conv4_3(x))
        pool4 = self.pool4(x)
        x = self.relu5_1(self.conv5_1(pool4))
        x = self.relu5_2(self.conv5_2(x))
        x = self.relu5_3(self.conv5_3(x))
        pool5 = self.pool5(x)
        return pool1, pool2, pool3, pool4, pool5
class FCNs(nn.Module):
    """Fully convolutional segmentation network on a VGG-16 backbone.

    The decoder upsamples the deepest backbone map (pool5) by 2x five times with
    transposed convolutions. After each of the first four steps it adds the
    matching shallower pooling output (skip connection) and batch-normalizes.
    A 1x1 convolution then maps the result to per-class scores.

    Args:
        num_classes: number of output channels of the final 1x1 classifier.
        backbone: feature extractor to use; only ``"vgg"`` is supported.
        pretrained: forwarded to ``VGG``; pass ``False`` to skip loading the
            torchvision VGG-16 weights (e.g. offline, or before restoring a
            checkpoint).

    Raises:
        ValueError: if ``backbone`` is not ``"vgg"``.
    """

    def __init__(self, num_classes, backbone="vgg", pretrained=True):
        super(FCNs, self).__init__()
        self.num_classes = num_classes
        # Fail here rather than with an AttributeError on self.features in forward().
        if backbone != "vgg":
            raise ValueError(f"unsupported backbone {backbone!r}; only 'vgg' is supported")
        self.features = VGG(pretrained=pretrained)
        # Every ConvTranspose2d(kernel_size=3, stride=2, padding=1, output_padding=1)
        # exactly doubles H and W: (H - 1) * 2 - 2 * 1 + 3 + 1 = 2 * H.
        # deconv1: 1/32 -> 1/16 resolution, 512 channels (added to pool4).
        self.deconv1 = nn.ConvTranspose2d(512, 512, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.bn1 = nn.BatchNorm2d(512)
        self.relu1 = nn.ReLU()
        # deconv2: 1/16 -> 1/8 resolution, 256 channels (added to pool3).
        self.deconv2 = nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.bn2 = nn.BatchNorm2d(256)
        self.relu2 = nn.ReLU()
        # deconv3: 1/8 -> 1/4 resolution, 128 channels (added to pool2).
        self.deconv3 = nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.relu3 = nn.ReLU()
        # deconv4: 1/4 -> 1/2 resolution, 64 channels (added to pool1).
        self.deconv4 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.bn4 = nn.BatchNorm2d(64)
        self.relu4 = nn.ReLU()
        # deconv5: 1/2 -> 1/1 resolution, 32 channels (no skip connection).
        self.deconv5 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.bn5 = nn.BatchNorm2d(32)
        self.relu5 = nn.ReLU()
        self.classifier = nn.Conv2d(32, num_classes, kernel_size=1)

    def forward(self, x):
        """Return per-pixel class scores for ``x``.

        Args:
            x: input batch passed to the VGG backbone (3 input channels).
               Each skip addition needs the upsampled map to match the next
               shallower pooling output. This holds when H and W are multiples
               of 32; other sizes may raise a shape-mismatch error.

        Returns:
            Tensor with ``num_classes`` channels at 32x the spatial size of the
            deepest (pool5) feature map, i.e. the input size when H and W are
            multiples of 32.
        """
        pool1, pool2, pool3, pool4, pool5 = self.features(x)
        # Each stage: upsample, ReLU, add the skip map, then batch-normalize.
        y = self.bn1(self.relu1(self.deconv1(pool5)) + pool4)
        y = self.bn2(self.relu2(self.deconv2(y)) + pool3)
        y = self.bn3(self.relu3(self.deconv3(y)) + pool2)
        y = self.bn4(self.relu4(self.deconv4(y)) + pool1)
        y = self.bn5(self.relu5(self.deconv5(y)))
        return self.classifier(y)