Deep Learning: AlexNet on CIFAR-100 (Code)

Published 2023-04-03 16:55:11 · Author: JevonChao
# Import the required packages
import torch
#import wandb
import torch.nn as nn
from torchvision import transforms
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR100

# Use a Compose container to define the image preprocessing pipeline
transf = transforms.Compose([
    # Resize the 32x32 CIFAR-100 images to 224x224, the input size AlexNet expects
    transforms.Resize(224),
    # Convert the given image to a tensor of shape (C, H, W) with values in [0, 1]
    transforms.ToTensor()
])
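# Optional: adding per-channel normalization after ToTensor() usually
# stabilizes training. A hedged sketch; the mean/std values below are the
# commonly quoted approximate CIFAR-100 statistics, not verified here:
#
# transf = transforms.Compose([
#     transforms.Resize(224),
#     transforms.ToTensor(),
#     transforms.Normalize(mean=(0.5071, 0.4865, 0.4409),
#                          std=(0.2673, 0.2564, 0.2762))
# ])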
# Prepare the data
traindata = CIFAR100(
    # Directory where the dataset is stored
    root="./",
    # Whether this is the training split; True selects the training set
    train=True,
    # Apply the preprocessing defined above
    transform=transf,
    # Whether to download the dataset; True downloads it if not already present
    download=True
)
testdata = CIFAR100(
    root="./",
    train=False,
    transform=transf,
    download=True
)
# Define the data loaders
trainloader = DataLoader(
    # Dataset to load from
    traindata,
    # Batch size
    batch_size=128,
    # Whether to shuffle; True shuffles the training data each epoch
    shuffle=True
)
testloader = DataLoader(
    testdata,
    batch_size=128,
    shuffle=False
)
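# Quick sanity check (a minimal sketch, not part of the original post):
# fetch one batch and confirm the shapes the network below expects.
#
# imgs, labels = next(iter(trainloader))
# print(imgs.shape)    # torch.Size([128, 3, 224, 224])
# print(labels.shape)  # torch.Size([128])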
# Define the AlexNet network
class AlexNet(nn.Module):
    # Initialization
    def __init__(self):
        super(AlexNet, self).__init__()
        # Conv block 1: 3 -> 96 channels, 11x11 kernel, stride 4
        self.conv1 = nn.Conv2d(in_channels=3,
                               out_channels=96,
                               kernel_size=11,
                               padding=2,
                               stride=4)
        self.relu1 = nn.ReLU()
        self.max_pool1 = nn.MaxPool2d(kernel_size=3,
                                      stride=2,
                                      padding=0)

        # Conv block 2: 96 -> 256 channels, 5x5 kernel
        self.conv2 = nn.Conv2d(in_channels=96,
                               out_channels=256,
                               kernel_size=5,
                               padding=2,
                               stride=1)
        self.relu2 = nn.ReLU()
        self.max_pool2 = nn.MaxPool2d(kernel_size=3,
                                      stride=2,
                                      padding=0)

        # Conv block 3: 256 -> 384 channels, 3x3 kernel
        self.conv3 = nn.Conv2d(in_channels=256,
                               out_channels=384,
                               kernel_size=3,
                               padding=1,
                               stride=1)
        self.relu3 = nn.ReLU()

        # Conv block 4: 384 -> 384 channels, 3x3 kernel
        self.conv4 = nn.Conv2d(in_channels=384,
                               out_channels=384,
                               kernel_size=3,
                               padding=1,
                               stride=1)
        self.relu4 = nn.ReLU()

        # Conv block 5: 384 -> 256 channels, 3x3 kernel
        self.conv5 = nn.Conv2d(in_channels=384,
                               out_channels=256,
                               kernel_size=3,
                               padding=1,
                               stride=1)
        self.relu5 = nn.ReLU()
        self.max_pool5 = nn.MaxPool2d(kernel_size=3,
                                      stride=2,
                                      padding=0)
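        # Shape trace for a 224x224 input (pooling uses floor division),
        # added here as a reading aid:
        #   conv1: (224 + 2*2 - 11)//4 + 1 = 55 -> pool1: 27
        #   conv2: 27 -> pool2: 13
        #   conv3/conv4/conv5: 13 -> pool5: 6
        # so the flattened feature size below is 256*6*6.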

        # Classifier: dropout + three fully connected layers
        self.dropout1 = nn.Dropout(0.5)
        self.linear1 = nn.Linear(in_features=256*6*6,
                                 out_features=4096,
                                 bias=True)
        self.relu6 = nn.ReLU()

        self.dropout2 = nn.Dropout(0.5)
        self.linear2 = nn.Linear(in_features=4096,
                                 out_features=4096,
                                 bias=True)
        self.relu7 = nn.ReLU()

        # Output layer: 100 classes for CIFAR-100
        self.linear3 = nn.Linear(in_features=4096,
                                 out_features=100,
                                 bias=True)

    # Define the forward pass
    def forward(self, x):
        x = self.conv1(x)
        x = self.relu1(x)
        x = self.max_pool1(x)

        x = self.conv2(x)
        x = self.relu2(x)
        x = self.max_pool2(x)

        x = self.conv3(x)
        x = self.relu3(x)

        x = self.conv4(x)
        x = self.relu4(x)

        x = self.conv5(x)
        x = self.relu5(x)
        x = self.max_pool5(x)

        # Flatten the features before the fully connected layers
        # (essential: nn.Linear expects a 2-D (batch, features) input)
        x = x.view(x.shape[0], -1)

        x = self.dropout1(x)
        x = self.linear1(x)
        x = self.relu6(x)

        x = self.dropout2(x)
        x = self.linear2(x)
        x = self.relu7(x)

        x = self.linear3(x)

        return x
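# Sanity check (a minimal sketch, not part of the original post): a dummy
# forward pass confirms the architecture produces 100 logits per image.
#
# net = AlexNet()
# out = net(torch.randn(1, 3, 224, 224))
# print(out.shape)  # torch.Size([1, 100])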
# Define the training procedure
def train(net, train_loader, test_loader, device, l_r=0.0002, num_epochs=25):
    # Track the training run with wandb
    #experiment = wandb.init(project='AlexNet', resume='allow', anonymous='must')
    # Define the loss function
    criterion = nn.CrossEntropyLoss()
    # Define the optimizer
    optimizer = torch.optim.Adam(net.parameters(), lr=l_r)
    # Move the network to the target device
    net = net.to(device)
    # Start training
    for epoch in range(num_epochs):
        # Accumulated loss for this epoch
        train_loss = 0
        # Accumulated accuracy for this epoch
        test_corrects = 0
        # Put the model in training mode
        net.train()
        for step, (imgs, labels) in enumerate(train_loader):
            # Move the batch to the target device
            imgs = imgs.to(device)
            labels = labels.to(device)
            output = net(imgs)
            # Compute the loss
            loss = criterion(output, labels)
            # Zero the gradients
            optimizer.zero_grad()
            # Backpropagate the loss
            loss.backward()
            # Update the network parameters
            optimizer.step()
            train_loss += loss.item()
        # Put the model in evaluation mode and disable gradient tracking
        net.eval()
        with torch.no_grad():
            for step, (imgs, labels) in enumerate(test_loader):
                imgs = imgs.to(device)
                labels = labels.to(device)
                output = net(imgs)
                pre_lab = torch.argmax(output, 1)
                corrects = (torch.sum(pre_lab == labels.data).double() / imgs.size(0))
                test_corrects += corrects.item()
        # At the end of each epoch, log the metrics to wandb
        # experiment.log({
        #     'epoch': epoch,
        #     'train loss': train_loss / len(train_loader),
        #     'test acc': test_corrects / len(test_loader),
        # })
        print('Epoch: {}/{}'.format(epoch, num_epochs-1))
        print('{} Train Loss:{:.4f}'.format(epoch, train_loss / len(train_loader)))
        print('{} Test Acc:{:.4f}'.format(epoch, test_corrects / len(test_loader)))
        # Save the network parameters after this epoch (overwrites the previous file)
        torch.save(net.state_dict(), './net.pth')

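# Note: the test accuracy above averages per-batch accuracies, which
# slightly misweights the smaller final batch. An exact variant (a sketch,
# not from the original post) counts correct predictions over the whole set:
#
# def evaluate(net, loader, device):
#     net.eval()
#     correct, total = 0, 0
#     with torch.no_grad():
#         for imgs, labels in loader:
#             imgs, labels = imgs.to(device), labels.to(device)
#             preds = torch.argmax(net(imgs), dim=1)
#             correct += (preds == labels).sum().item()
#             total += labels.size(0)
#     return correct / total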
if __name__ == "__main__":
    # Select the training device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    net = AlexNet()
    train(net, trainloader, testloader, device, l_r=0.0003, num_epochs=10)
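# After training, the saved weights can be reloaded for inference.
# A minimal sketch (assumes ./net.pth was written by train() above):
#
# net = AlexNet()
# net.load_state_dict(torch.load('./net.pth', map_location=device))
# net.to(device)
# net.eval()
# with torch.no_grad():
#     imgs, labels = next(iter(testloader))
#     preds = torch.argmax(net(imgs.to(device)), dim=1)
#     print(preds[:10], labels[:10])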