Deep Learning in Practice: ResNet18

Published 2023-04-24 16:51:56 · Author: 林每天都要努力

The name ResNet18 means the basic architecture is ResNet and the network is 18 layers deep. Here "depth" counts only the weighted layers, i.e. the convolutional layers and the fully connected (linear) layer; it does not count batch normalization, activation, or pooling layers.
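As a quick sanity check on that count: the canonical ResNet-18 stacks one initial convolution, four stages of two basic blocks (each block holding two 3x3 convolutions), and one final fully connected layer. A small worked sum, for illustration only:

# Counting the weighted layers of the canonical ResNet-18.
# (1x1 shortcut convolutions are conventionally not counted.)
conv1 = 1                 # initial convolution
stages = 4                # conv2_x .. conv5_x
blocks_per_stage = 2      # basic blocks per stage
convs_per_block = 2       # 3x3 convolutions per basic block
fc = 1                    # final fully connected classifier

depth = conv1 + stages * blocks_per_stage * convs_per_block + fc
print(depth)  # 18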

Model Implementation

import torch
from torch import nn
import torch.nn.functional as F

class ResBlk(nn.Module):
    '''
    Basic ResNet residual block: two 3x3 conv-BN layers plus a shortcut.
    '''

    def __init__(self, ch_in, ch_out):
        '''
        :param ch_in: number of input channels
        :param ch_out: number of output channels
        '''
        super(ResBlk, self).__init__()

        self.conv1 = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(ch_out)
        self.conv2 = nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(ch_out)

        # if the channel counts differ, a 1x1 conv adapts the shortcut
        self.extra = nn.Sequential()
        if ch_out != ch_in:
            # [b, ch_in, h, w] => [b, ch_out, h, w]
            self.extra = nn.Sequential(
                nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=1),
                nn.BatchNorm2d(ch_out)
            )

    def forward(self, x):
        '''
        :param x: input tensor of shape [b, ch_in, h, w]
        :return: output tensor of shape [b, ch_out, h, w]
        '''
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        # shortcut: self.extra maps x to [b, ch_out, h, w] so the shapes match
        # element-wise add, then ReLU (as in the original ResNet block)
        out = F.relu(self.extra(x) + out)
        return out


class ResNet18(nn.Module):
    '''
    ResNet-style network for CIFAR-10 (3x32x32 inputs, 10 classes):
    an initial conv layer, four residual blocks, and a linear classifier.
    '''
    def __init__(self):
        super(ResNet18, self).__init__()

        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64)
        )

        # followed by 4 residual blocks
        # [b, 64, h, w] => [b, 64, h, w]
        self.blk1 = ResBlk(64, 64)
        # [b, 64, h, w] => [b, 128, h, w]
        self.blk2 = ResBlk(64, 128)
        # [b, 128, h, w] => [b, 256, h, w]
        self.blk3 = ResBlk(128, 256)
        # [b, 256, h, w] => [b, 512, h, w]
        self.blk4 = ResBlk(256, 512)

        # all convs use stride 1 and there is no pooling, so a 32x32 input
        # stays 32x32; the flattened feature size is therefore 512*32*32
        self.outlayer = nn.Linear(512*32*32, 10)

    def forward(self, x):
        '''
        :param x: input tensor of shape [b, 3, 32, 32]
        :return: logits of shape [b, 10]
        '''
        x = F.relu(self.conv1(x))

        # [b, 64, h, w] => [b, 512, h, w]
        x = self.blk1(x)
        x = self.blk2(x)
        x = self.blk3(x)
        x = self.blk4(x)

        # flatten: [b, 512, 32, 32] => [b, 512*32*32]
        x = x.view(x.size(0), -1)

        x = self.outlayer(x)

        return x


def main():
    # sanity-check a residual block that changes the channel count
    blk = ResBlk(64, 128)
    tmp = torch.randn(2, 64, 32, 32)
    out = blk(tmp)
    print("blk:", out.shape)

    # sanity-check the full network on CIFAR-sized input
    model = ResNet18()
    tmp = torch.randn(2, 3, 32, 32)
    out = model(tmp)
    print("resnet:", out.shape)


if __name__ == '__main__':
    main()
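Running main() should print blk: torch.Size([2, 128, 32, 32]) and resnet: torch.Size([2, 10]): the block maps 64 input channels to 128 while preserving the 32x32 spatial size, and the full network maps a batch of CIFAR-sized images to 10 class logits. Because every convolution above uses stride 1 and there is no pooling, the feature maps never shrink, which is why outlayer must flatten 512*32*32 features and is by far the largest layer in the model.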

Training and Testing

import torch
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader
# from lenet5 import Lenet5  # only needed for the commented-out LeNet-5 baseline below
from torch import nn, optim
from resnet import ResNet18

def main():

    batch_size = 32
    epochs = 1000
    learn_rate = 1e-3

    # build the CIFAR-10 training set (32x32 RGB images, 10 classes)
    cifar_train = datasets.CIFAR10('cifar', train=True, transform=transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor()
    ]), download=True)

    # wrap it in a DataLoader to get shuffled mini-batches
    cifar_train = DataLoader(cifar_train, batch_size=batch_size, shuffle=True)

    # build the test set the same way
    cifar_test = datasets.CIFAR10('cifar', train=False, transform=transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor()
    ]), download=True)

    # no need to shuffle for evaluation
    cifar_test = DataLoader(cifar_test, batch_size=batch_size, shuffle=False)

    # peek at one batch to check the shapes
    x, label = next(iter(cifar_train))
    print("x:", x.shape, "label:", label.shape)
    # x: torch.Size([32, 3, 32, 32]) label: torch.Size([32])


    # fall back to CPU when CUDA is unavailable
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    #model = Lenet5().to(device)
    model = ResNet18().to(device)
    print(model)
    criteon = nn.CrossEntropyLoss().to(device)  # expects raw logits, so the model has no softmax
    optimizer = optim.Adam(model.parameters(), lr=learn_rate)


    for epoch in range(epochs):
        model.train()
        for batchidx, (x, label) in enumerate(cifar_train):
            x, label = x.to(device), label.to(device)

            # logits: [b, 10]
            logits = model(x)

            loss = criteon(logits, label)

            # backprop
            optimizer.zero_grad()  # clear gradients from the previous step
            loss.backward()
            optimizer.step()  # update the parameters

        # loss of the last batch in this epoch
        print("epoch:", epoch, "loss:", loss.item())

        model.eval()
        with torch.no_grad():
            # evaluate on the test set
            total_correct = 0
            total_num = 0
            for x, label in cifar_test:
                x, label = x.to(device), label.to(device)
                # [b, 10]
                logits = model(x)
                # [b]
                pred = logits.argmax(dim=1)

                # [b] vs [b] => scalar tensor
                total_correct += torch.eq(pred, label).float().sum().item()
                total_num += x.size(0)

        acc = total_correct / total_num
        print("epoch:", epoch, "acc:", acc)


if __name__ == '__main__':
    main()
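One practical extension, not in the original post: save the weights whenever the test accuracy improves, so the best model can be reloaded later. A minimal sketch using torch.save/torch.load; the filename is an arbitrary choice:

# Sketch: checkpoint the best model (an addition, not from the original post).
best_acc = 0.0  # initialize once, before the epoch loop

# ...inside the epoch loop, right after acc is computed:
if acc > best_acc:
    best_acc = acc
    torch.save(model.state_dict(), 'resnet18_cifar10.pth')  # arbitrary filename

# later, to reuse the trained weights:
model = ResNet18()
model.load_state_dict(torch.load('resnet18_cifar10.pth'))
model.eval()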