import torch
from torch import nn
from torchsummary import summary


class Residual(nn.Module):
    """Basic residual block: two 3x3 convolutions with an optional 1x1 shortcut."""

    def __init__(self, input_channels, num_channels, use_1conv=False, strides=1):
        super(Residual, self).__init__()
        self.ReLu = nn.ReLU()
        self.conv1 = nn.Conv2d(in_channels=input_channels, out_channels=num_channels,
                               kernel_size=3, padding=1, stride=strides)
        self.conv2 = nn.Conv2d(in_channels=num_channels, out_channels=num_channels,
                               kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(num_channels)
        self.bn2 = nn.BatchNorm2d(num_channels)
        if use_1conv:
            # 1x1 convolution on the shortcut path to match channels and resolution
            self.conv3 = nn.Conv2d(in_channels=input_channels, out_channels=num_channels,
                                   kernel_size=1, stride=strides)
        else:
            self.conv3 = None

    def forward(self, x):
        conv1_output = self.conv1(x)
        bn1_output = self.bn1(conv1_output)
        relu_output = self.ReLu(bn1_output)
        conv2_output = self.conv2(relu_output)
        y = self.bn2(conv2_output)
        if self.conv3:
            x = self.conv3(x)
        # Add the shortcut, then apply the final ReLU
        y = self.ReLu(y + x)
        return y


class ResNet18(nn.Module):
    def __init__(self, Residual):
        super(ResNet18, self).__init__()
        # Stem: 7x7 convolution followed by max pooling
        self.b1 = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=7, stride=2, padding=3),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )
        # Four stages of two residual blocks each; b3-b5 downsample with stride 2
        self.b2 = nn.Sequential(Residual(64, 64, use_1conv=False, strides=1),
                                Residual(64, 64, use_1conv=False, strides=1))
        self.b3 = nn.Sequential(Residual(64, 128, use_1conv=True, strides=2),
                                Residual(128, 128, use_1conv=False, strides=1))
        self.b4 = nn.Sequential(Residual(128, 256, use_1conv=True, strides=2),
                                Residual(256, 256, use_1conv=False, strides=1))
        self.b5 = nn.Sequential(Residual(256, 512, use_1conv=True, strides=2),
                                Residual(512, 512, use_1conv=False, strides=1))
        # Head: global average pooling + fully connected layer (38 classes)
        self.b6 = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)),
                                nn.Flatten(),
                                nn.Linear(512, 38))

    def forward(self, x):
        x = self.b1(x)
        x = self.b2(x)
        x = self.b3(x)
        x = self.b4(x)
        x = self.b5(x)
        x = self.b6(x)
        return x


if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = ResNet18(Residual).to(device)
    print(summary(model, (3, 224, 224)))
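    # Optional sanity check (a minimal sketch, not part of the original listing):
    # with random weights and a random 224x224 RGB batch, the forward pass should
    # return logits of shape (batch_size, 38), matching the nn.Linear(512, 38) head in b6.
    x = torch.randn(2, 3, 224, 224, device=device)
    with torch.no_grad():
        print(model(x).shape)  # expected: torch.Size([2, 38])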