1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42
| import torch import torchvision import torch.nn as nn from torchsummary import summary
class Network(nn.Module):
    """Multilayer perceptron that flattens its input and maps it to ``output_size`` features.

    Layout of ``self.model`` (an ``nn.Sequential``)::

        Flatten -> [Linear -> BatchNorm1d -> ReLU] * len(hidden_sizes) -> Linear

    The final ``Linear`` has no activation, so ``forward`` returns raw
    (unnormalised) outputs.
    """

    def __init__(self, input_size, output_size, *, hidden_sizes=(512, 256, 128)):
        """Build the layer stack and initialise the Linear layers.

        Args:
            input_size: Number of features per sample after flattening, i.e. the
                product of all non-batch dimensions of the input tensor.
            output_size: Number of output features of the last Linear layer.
            hidden_sizes: Widths of the hidden stages, in order. The default
                ``(512, 256, 128)`` reproduces the original fixed architecture.
        """
        super().__init__()

        layers = [nn.Flatten()]  # nn.Flatten() keeps dim 0 (batch), flattens the rest
        in_features = input_size
        for width in hidden_sizes:
            # Every hidden stage uses the same Linear -> BN -> ReLU order. Each
            # pair of Linears is therefore separated by a non-linearity, and
            # ReLU is never applied twice in a row.
            layers += [
                nn.Linear(in_features, width),
                nn.BatchNorm1d(width),
                nn.ReLU(),
            ]
            in_features = width
        layers.append(nn.Linear(in_features, output_size))
        # NOTE: with the default hidden_sizes the 256->128 Linear sits at index 7
        # of self.model. Checkpoints saved from the miswired stack, which had it
        # at index 6, need their "model.6.*" keys renamed to "model.7.*".
        self.model = nn.Sequential(*layers)

        # Small-std normal init for every Linear weight and zero biases. BatchNorm
        # layers keep PyTorch's default initialisation.
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, 0, 0.01)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

    def forward(self, x):
        """Run ``x`` through the layer stack and return the last Linear's output.

        ``x`` is presumably shaped ``(batch, ...)`` with the trailing dimensions
        multiplying to ``input_size``. In training mode BatchNorm1d needs more
        than one sample per batch.
        """
        return self.model(x)
|