1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39
| import torch import torch.nn as nn
class CNNModel(nn.Module):
    """Small convolutional classifier that outputs per-class probabilities.

    The feature extractor is two identical downsampling stages (5x5
    convolution with stride 3, LeakyReLU, batch norm, 2x2 max pooling)
    followed by a 3x3 stride-2 convolution that collapses the features to a
    single channel. The resulting map is flattened and fed to a two-layer
    fully connected head that ends in a softmax over dimension 1.

    The first linear layer is fixed at 36 input features, so the flattened
    single-channel feature map must contain exactly 36 values per sample
    (for example a 3x512x512 input reduces to a 1x6x6 map).

    Args:
        out_channels: Number of values produced per sample by the final
            linear layer (and therefore by the softmax).
    """

    def __init__(self, out_channels=10):
        super().__init__()

        # Layers are created in the same order as they are applied so that
        # module indices (and seeded parameter initialisation) stay stable.
        feature_layers = [
            *self._downsample_stage(3, 16),
            *self._downsample_stage(16, 32),
            nn.Conv2d(32, 1, kernel_size=3, stride=2, padding=0),
        ]
        self.conv = nn.Sequential(*feature_layers)

        head_layers = [
            nn.Linear(36, 16),
            nn.LeakyReLU(0.2, inplace=True),
            nn.BatchNorm1d(16),
            nn.Linear(16, out_channels),
            nn.Softmax(dim=1),
        ]
        self.ful_layer = nn.Sequential(*head_layers)

    @staticmethod
    def _downsample_stage(in_ch, out_ch):
        """Return the four layers of one conv / activation / norm / pool stage."""
        return [
            nn.Conv2d(in_ch, out_ch, kernel_size=5, stride=3, padding=0),
            nn.LeakyReLU(0.2, inplace=True),
            nn.BatchNorm2d(out_ch),
            nn.MaxPool2d(2),
        ]

    def forward(self, x):
        """Run the convolutional stack, flatten per sample, apply the head.

        Args:
            x: Input batch passed straight to the first ``Conv2d(3, ...)``.

        Returns:
            The softmax output of the fully connected head, one row per
            sample.
        """
        feature_map = self.conv(x)
        batch_size = feature_map.size(0)
        # Collapse everything except the batch dimension.
        flattened = feature_map.view(batch_size, -1)
        return self.ful_layer(flattened)
# Build the network and export it to ONNX by tracing it with a dummy batch.
model = CNNModel()
# A freshly constructed nn.Module is in training mode. Switch to inference
# mode before exporting so BatchNorm uses its running statistics and the
# exported graph reflects inference behaviour.
model.eval()
# Random 16-sample batch of 3x512x512 inputs; only used to trace the graph.
x = torch.rand(16, 3, 512, 512)
torch.onnx.export(model, x, "CNN.onnx")
|