[cnn]FashionMNIST训练+保存模型+调用模型判断给定图片

发布时间 2023-05-26 10:37:42作者: J1nWan
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as Data
from torchvision import datasets,transforms
import matplotlib.pyplot as plt
import numpy as np
input_size = 28   # image size: 28*28 pixels
num_class = 10    # number of label classes
num_epochs = 3    # total training epochs
batch_size = 64    # images per mini-batch

# FashionMNIST training split (60k 28x28 grayscale images); downloaded
# into ./data on first run, pixels converted to float tensors in [0, 1].
train_dataset = datasets.FashionMNIST(
  root='data',
  train=True,
  transform=transforms.ToTensor(),
  download=True,
)

# Held-out test split (10k images) with the same preprocessing.
test_dataset = datasets.FashionMNIST(
  root='data',
   train=False,
  transform=transforms.ToTensor(),
  download=True,
)

# Batched loaders. Shuffling the test loader only randomizes batch order;
# it does not change the accuracy computed over the whole set below.
train_loader = torch.utils.data.DataLoader(
  dataset = train_dataset,
  batch_size = batch_size,
  shuffle = True,
)
test_loader = torch.utils.data.DataLoader(
  dataset = test_dataset,
  batch_size = batch_size,
  shuffle = True,
)


class CNN(nn.Module):
    """Two-stage convolutional classifier for 28x28 single-channel images.

    Shape flow: (1, 28, 28) -> conv1 -> (16, 14, 14) -> conv2 -> (32, 7, 7)
    -> flatten -> linear -> class logits.
    """

    def __init__(self, num_classes=10):
        # `num_classes` generalizes the previously hard-coded 10 outputs;
        # the default keeps all existing callers working unchanged.
        super(CNN, self).__init__()
        self.conv1 = nn.Sequential(       # input (1, 28, 28)
            nn.Conv2d(
                in_channels=1,
                out_channels=16,          # number of feature maps
                kernel_size=5,
                stride=1,
                padding=2,                # "same" padding: 28 stays 28
            ),                            # -> (16, 28, 28)
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),  # 2x2 pooling -> (16, 14, 14)
        )
        self.conv2 = nn.Sequential(       # input (16, 14, 14)
            nn.Conv2d(16, 32, 5, 1, 2),   # -> (32, 14, 14)
            nn.ReLU(),
            nn.MaxPool2d(2),              # -> (32, 7, 7)
        )
        self.out = nn.Linear(32 * 7 * 7, num_classes)  # fully connected head

    def forward(self, x):
        """Return (logits, flattened_features) for a batch of images."""
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.view(x.size(0), -1)  # flatten to (batch_size, 32*7*7)
        output = self.out(x)
        return output, x
def accuracy(predictions, labels):
  """Return (number_correct, batch_size) for a batch of logits vs. labels."""
  predicted_classes = predictions.data.argmax(dim=1)
  correct = (predicted_classes == labels.data.view_as(predicted_classes)).sum()
  return correct, len(labels)
# Select the GPU when available, otherwise fall back to the CPU.
# (Removed two stray lines — a bare `device` expression and the literal
# 'cuda' — which were notebook REPL echoes pasted into the script, not code.)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
net = CNN().to(device)
criterion = nn.CrossEntropyLoss()  # loss for multi-class logits
# Adam optimizer over all trainable parameters
optimizer = optim.Adam(net.parameters(),lr = 0.001)

# Training loop: `num_epochs` passes over the training set.
# Fixed off-by-one: range(num_epochs+1) ran one extra epoch beyond the
# declared "total training epochs" constant.
for epoch in range(num_epochs):
  # per-epoch record of (correct, total) tuples for each training batch
  train_rights = []
  for batch_idx, (data, target) in enumerate(train_loader):
    data = data.to(device)
    target = target.to(device)
    net.train()
    output = net(data)[0]  # forward() returns (logits, features); keep logits
    loss = criterion(output, target)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    right = accuracy(output, target)
    train_rights.append(right)

    # Every 100 batches: evaluate on the full test set, report, checkpoint.
    if batch_idx % 100 == 0:
      net.eval()
      val_rights = []
      with torch.no_grad():  # no gradients needed during evaluation
        for (data, target) in test_loader:
          data = data.to(device)
          target = target.to(device)
          output = net(data)[0]
          right = accuracy(output, target)
          val_rights.append(right)
      # aggregate (correct, total) pairs across batches
      train_r = (sum([i[0] for i in train_rights]), sum(i[1] for i in train_rights))
      val_r = (sum([i[0] for i in val_rights]), sum(i[1] for i in val_rights))

      print('当前epoch:{}[{}/{}({:.0f}%)]\t损失:{:.2f}\t训练集准确率:{:.2f}%\t测试集准确率:{:.2f}%'.format(
        epoch,
        batch_idx * batch_size,
        len(train_loader.dataset),
        100. * batch_idx / len(train_loader),
        loss.data,
        100. * train_r[0].cpu().numpy() / train_r[1],
        100. * val_r[0].cpu().numpy() / val_r[1]
      )
      )
      # Saves the entire model object (architecture + weights) via pickle.
      torch.save(net, 'cnn_test.pt')

当前epoch:0[0/60000(0%)]	损失:2.31	训练集准确率:7.81%	测试集准确率:19.02%
当前epoch:0[6400/60000(11%)]	损失:0.70	训练集准确率:66.34%	测试集准确率:74.87%
当前epoch:0[12800/60000(21%)]	损失:0.39	训练集准确率:72.43%	测试集准确率:81.34%
当前epoch:0[19200/60000(32%)]	损失:0.49	训练集准确率:75.35%	测试集准确率:81.99%
当前epoch:0[25600/60000(43%)]	损失:0.49	训练集准确率:77.42%	测试集准确率:84.12%
当前epoch:0[32000/60000(53%)]	损失:0.31	训练集准确率:78.93%	测试集准确率:84.10%
当前epoch:0[38400/60000(64%)]	损失:0.35	训练集准确率:80.00%	测试集准确率:84.13%
当前epoch:0[44800/60000(75%)]	损失:0.34	训练集准确率:80.90%	测试集准确率:85.27%
当前epoch:0[51200/60000(85%)]	损失:0.31	训练集准确率:81.55%	测试集准确率:85.85%
当前epoch:0[57600/60000(96%)]	损失:0.48	训练集准确率:82.14%	测试集准确率:86.35%
当前epoch:1[0/60000(0%)]	损失:0.29	训练集准确率:89.06%	测试集准确率:86.13%
当前epoch:1[6400/60000(11%)]	损失:0.42	训练集准确率:87.19%	测试集准确率:84.78%
当前epoch:1[12800/60000(21%)]	损失:0.35	训练集准确率:86.99%	测试集准确率:87.09%
当前epoch:1[19200/60000(32%)]	损失:0.38	训练集准确率:87.33%	测试集准确率:86.61%
当前epoch:1[25600/60000(43%)]	损失:0.32	训练集准确率:87.55%	测试集准确率:86.70%
当前epoch:1[32000/60000(53%)]	损失:0.51	训练集准确率:87.72%	测试集准确率:87.08%
当前epoch:1[38400/60000(64%)]	损失:0.46	训练集准确率:87.88%	测试集准确率:87.95%
当前epoch:1[44800/60000(75%)]	损失:0.33	训练集准确率:87.94%	测试集准确率:88.00%
当前epoch:1[51200/60000(85%)]	损失:0.28	训练集准确率:88.04%	测试集准确率:88.51%
当前epoch:1[57600/60000(96%)]	损失:0.18	训练集准确率:88.15%	测试集准确率:87.43%
当前epoch:2[0/60000(0%)]	损失:0.17	训练集准确率:93.75%	测试集准确率:87.79%
当前epoch:2[6400/60000(11%)]	损失:0.29	训练集准确率:89.71%	测试集准确率:87.70%
当前epoch:2[12800/60000(21%)]	损失:0.27	训练集准确率:89.51%	测试集准确率:88.32%
当前epoch:2[19200/60000(32%)]	损失:0.20	训练集准确率:89.24%	测试集准确率:88.77%
当前epoch:2[25600/60000(43%)]	损失:0.23	训练集准确率:89.45%	测试集准确率:88.24%
当前epoch:2[32000/60000(53%)]	损失:0.21	训练集准确率:89.57%	测试集准确率:88.22%
当前epoch:2[38400/60000(64%)]	损失:0.30	训练集准确率:89.54%	测试集准确率:88.10%
当前epoch:2[44800/60000(75%)]	损失:0.19	训练集准确率:89.54%	测试集准确率:88.92%
当前epoch:2[51200/60000(85%)]	损失:0.45	训练集准确率:89.50%	测试集准确率:88.96%
当前epoch:2[57600/60000(96%)]	损失:0.20	训练集准确率:89.53%	测试集准确率:89.55%
当前epoch:3[0/60000(0%)]	损失:0.27	训练集准确率:93.75%	测试集准确率:88.16%
当前epoch:3[6400/60000(11%)]	损失:0.19	训练集准确率:90.24%	测试集准确率:89.76%
当前epoch:3[12800/60000(21%)]	损失:0.19	训练集准确率:90.10%	测试集准确率:89.41%
当前epoch:3[19200/60000(32%)]	损失:0.24	训练集准确率:90.32%	测试集准确率:89.48%
当前epoch:3[25600/60000(43%)]	损失:0.34	训练集准确率:90.42%	测试集准确率:89.58%
当前epoch:3[32000/60000(53%)]	损失:0.27	训练集准确率:90.30%	测试集准确率:88.86%
当前epoch:3[38400/60000(64%)]	损失:0.34	训练集准确率:90.28%	测试集准确率:89.39%
当前epoch:3[44800/60000(75%)]	损失:0.37	训练集准确率:90.36%	测试集准确率:88.66%
当前epoch:3[51200/60000(85%)]	损失:0.17	训练集准确率:90.36%	测试集准确率:89.72%
当前epoch:3[57600/60000(96%)]	损失:0.20	训练集准确率:90.41%	测试集准确率:89.29%
当前epoch:4[0/60000(0%)]	损失:0.15	训练集准确率:92.19%	测试集准确率:89.55%
当前epoch:4[6400/60000(11%)]	损失:0.30	训练集准确率:91.43%	测试集准确率:89.89%
当前epoch:4[12800/60000(21%)]	损失:0.15	训练集准确率:91.25%	测试集准确率:89.62%
当前epoch:4[19200/60000(32%)]	损失:0.20	训练集准确率:91.23%	测试集准确率:89.95%
当前epoch:4[25600/60000(43%)]	损失:0.16	训练集准确率:91.24%	测试集准确率:89.70%
当前epoch:4[32000/60000(53%)]	损失:0.21	训练集准确率:91.22%	测试集准确率:89.95%
当前epoch:4[38400/60000(64%)]	损失:0.33	训练集准确率:91.18%	测试集准确率:90.42%
当前epoch:4[44800/60000(75%)]	损失:0.19	训练集准确率:91.24%	测试集准确率:89.69%
当前epoch:4[51200/60000(85%)]	损失:0.26	训练集准确率:91.22%	测试集准确率:90.35%
当前epoch:4[57600/60000(96%)]	损失:0.28	训练集准确率:91.25%	测试集准确率:88.77%
当前epoch:5[0/60000(0%)]	损失:0.25	训练集准确率:93.75%	测试集准确率:89.79%
当前epoch:5[6400/60000(11%)]	损失:0.21	训练集准确率:91.21%	测试集准确率:89.90%
当前epoch:5[12800/60000(21%)]	损失:0.15	训练集准确率:91.51%	测试集准确率:90.71%
当前epoch:5[19200/60000(32%)]	损失:0.16	训练集准确率:91.77%	测试集准确率:90.45%
当前epoch:5[25600/60000(43%)]	损失:0.21	训练集准确率:91.84%	测试集准确率:90.56%
当前epoch:5[32000/60000(53%)]	损失:0.12	训练集准确率:91.86%	测试集准确率:89.10%
当前epoch:5[38400/60000(64%)]	损失:0.28	训练集准确率:91.82%	测试集准确率:90.42%
当前epoch:5[44800/60000(75%)]	损失:0.15	训练集准确率:91.88%	测试集准确率:90.19%
当前epoch:5[51200/60000(85%)]	损失:0.33	训练集准确率:91.87%	测试集准确率:90.03%
当前epoch:5[57600/60000(96%)]	损失:0.10	训练集准确率:91.80%	测试集准确率:90.74%
当前epoch:6[0/60000(0%)]	损失:0.15	训练集准确率:93.75%	测试集准确率:90.36%
当前epoch:6[6400/60000(11%)]	损失:0.31	训练集准确率:92.28%	测试集准确率:90.85%
当前epoch:6[12800/60000(21%)]	损失:0.23	训练集准确率:92.15%	测试集准确率:90.68%
当前epoch:6[19200/60000(32%)]	损失:0.15	训练集准确率:92.37%	测试集准确率:90.71%
当前epoch:6[25600/60000(43%)]	损失:0.31	训练集准确率:92.29%	测试集准确率:91.02%
当前epoch:6[32000/60000(53%)]	损失:0.21	训练集准确率:92.43%	测试集准确率:90.57%
当前epoch:6[38400/60000(64%)]	损失:0.25	训练集准确率:92.43%	测试集准确率:90.51%
当前epoch:6[44800/60000(75%)]	损失:0.21	训练集准确率:92.48%	测试集准确率:90.56%
当前epoch:6[51200/60000(85%)]	损失:0.07	训练集准确率:92.43%	测试集准确率:91.04%
当前epoch:6[57600/60000(96%)]	损失:0.14	训练集准确率:92.43%	测试集准确率:90.68%
当前epoch:7[0/60000(0%)]	损失:0.25	训练集准确率:89.06%	测试集准确率:91.18%
当前epoch:7[6400/60000(11%)]	损失:0.11	训练集准确率:92.51%	测试集准确率:91.09%
当前epoch:7[12800/60000(21%)]	损失:0.17	训练集准确率:92.98%	测试集准确率:91.21%
当前epoch:7[19200/60000(32%)]	损失:0.23	训练集准确率:93.06%	测试集准确率:90.80%
当前epoch:7[25600/60000(43%)]	损失:0.18	训练集准确率:92.95%	测试集准确率:91.39%
当前epoch:7[32000/60000(53%)]	损失:0.24	训练集准确率:93.01%	测试集准确率:91.06%
当前epoch:7[38400/60000(64%)]	损失:0.27	训练集准确率:92.94%	测试集准确率:91.18%
当前epoch:7[44800/60000(75%)]	损失:0.31	训练集准确率:92.77%	测试集准确率:90.88%
当前epoch:7[51200/60000(85%)]	损失:0.17	训练集准确率:92.73%	测试集准确率:91.42%
当前epoch:7[57600/60000(96%)]	损失:0.17	训练集准确率:92.75%	测试集准确率:90.75%
当前epoch:8[0/60000(0%)]	损失:0.15	训练集准确率:95.31%	测试集准确率:91.15%
当前epoch:8[6400/60000(11%)]	损失:0.18	训练集准确率:93.13%	测试集准确率:91.42%
当前epoch:8[12800/60000(21%)]	损失:0.12	训练集准确率:93.24%	测试集准确率:91.31%
当前epoch:8[19200/60000(32%)]	损失:0.27	训练集准确率:93.37%	测试集准确率:91.25%
当前epoch:8[25600/60000(43%)]	损失:0.17	训练集准确率:93.38%	测试集准确率:91.52%
当前epoch:8[32000/60000(53%)]	损失:0.19	训练集准确率:93.16%	测试集准确率:91.51%
当前epoch:8[38400/60000(64%)]	损失:0.26	训练集准确率:93.11%	测试集准确率:91.34%
当前epoch:8[44800/60000(75%)]	损失:0.44	训练集准确率:93.05%	测试集准确率:91.35%
当前epoch:8[51200/60000(85%)]	损失:0.31	训练集准确率:93.03%	测试集准确率:91.23%
当前epoch:8[57600/60000(96%)]	损失:0.22	训练集准确率:93.01%	测试集准确率:90.74%
当前epoch:9[0/60000(0%)]	损失:0.19	训练集准确率:93.75%	测试集准确率:91.15%
当前epoch:9[6400/60000(11%)]	损失:0.27	训练集准确率:93.64%	测试集准确率:90.78%
当前epoch:9[12800/60000(21%)]	损失:0.25	训练集准确率:93.73%	测试集准确率:91.31%
当前epoch:9[19200/60000(32%)]	损失:0.23	训练集准确率:93.42%	测试集准确率:89.56%
当前epoch:9[25600/60000(43%)]	损失:0.27	训练集准确率:93.24%	测试集准确率:90.82%
当前epoch:9[32000/60000(53%)]	损失:0.23	训练集准确率:93.33%	测试集准确率:91.29%
当前epoch:9[38400/60000(64%)]	损失:0.09	训练集准确率:93.31%	测试集准确率:91.24%
当前epoch:9[44800/60000(75%)]	损失:0.25	训练集准确率:93.31%	测试集准确率:90.78%
当前epoch:9[51200/60000(85%)]	损失:0.19	训练集准确率:93.33%	测试集准确率:91.34%
当前epoch:9[57600/60000(96%)]	损失:0.12	训练集准确率:93.35%	测试集准确率:91.30%
当前epoch:10[0/60000(0%)]	损失:0.17	训练集准确率:93.75%	测试集准确率:91.27%
当前epoch:10[6400/60000(11%)]	损失:0.13	训练集准确率:93.81%	测试集准确率:91.61%
当前epoch:10[12800/60000(21%)]	损失:0.22	训练集准确率:93.91%	测试集准确率:91.13%
当前epoch:10[19200/60000(32%)]	损失:0.14	训练集准确率:93.92%	测试集准确率:91.19%
当前epoch:10[25600/60000(43%)]	损失:0.22	训练集准确率:93.92%	测试集准确率:91.78%
当前epoch:10[32000/60000(53%)]	损失:0.15	训练集准确率:93.95%	测试集准确率:90.79%
当前epoch:10[38400/60000(64%)]	损失:0.09	训练集准确率:93.92%	测试集准确率:91.42%
当前epoch:10[44800/60000(75%)]	损失:0.12	训练集准确率:93.86%	测试集准确率:91.62%
当前epoch:10[51200/60000(85%)]	损失:0.14	训练集准确率:93.84%	测试集准确率:90.67%
当前epoch:10[57600/60000(96%)]	损失:0.13	训练集准确率:93.78%	测试集准确率:91.42%
from PIL import Image
# FashionMNIST class index -> human-readable label name.
labels_map = {
    0: "T-Shirt",
    1: "Trouser",
    2: "Pullover",
    3: "Dress",
    4: "Coat",
    5: "Sandal",
    6: "Shirt",
    7: "Sneaker",
    8: "Bag",
    9: "Ankle Boot",
}
import matplotlib.pyplot as plt
figure = plt.figure(figsize=(8, 8))
cols, rows = 1,1  # size of the display grid (here a single image)
for i in range(1, cols * rows + 1):  # subplot indices are 1-based
    sample_idx = torch.randint(len(train_dataset), 
                        size=(1,)).item() # pick one random training image
    
    img, label = train_dataset[sample_idx]
    print(img.shape)
    print(label)
    figure.add_subplot(rows, cols, i)  # place the i-th image in the grid
    #plt.title(labels_map[label])
    plt.axis("off")
    plt.imshow(img.squeeze(), cmap="gray")
torch.Size([1, 28, 28])
9

png

model_path = "cnn_test.pt"  # path of the .pt file saved during training

# NOTE(review): torch.load unpickles an arbitrary object — only load
# checkpoints you trust, and the CNN class must be importable here.
model = torch.load(model_path)
model.cuda()
print(model)

image_path = "1.png"  # path of the image to classify

# Load the image with PIL and convert it to a tensor.
image = Image.open(image_path)
print(image)
#image = image.convert('L')
transform = transforms.Compose([
    transforms.Resize((28,28)),
    #transforms.RandomResizedCrop(28),
    #transforms.Grayscale(),
    transforms.ToTensor(),
    # NOTE(review): training used plain ToTensor() with no Normalize, and
    # this normalizes 3 RGB channels before the later Grayscale() averages
    # them — confirm this preprocessing matches the training distribution.
    transforms.Normalize((.5, .5, .5), (.5, .5, .5)),
])

# Batch dimension added -> (1, 3, 28, 28); moved to GPU as float.
image_tensor = transform(image).unsqueeze(0).type(torch.cuda.FloatTensor)
#torch.transpose(image_tensor,1,2)
image_tensor.shape
CNN(
  (conv1): Sequential(
    (0): Conv2d(1, 16, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (conv2): Sequential(
    (0): Conv2d(16, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (out): Linear(in_features=1568, out_features=10, bias=True)
)
<PIL.PngImagePlugin.PngImageFile image mode=RGB size=102x103 at 0x1F19505CE48>





torch.Size([1, 3, 28, 28])
#image_tensor.shape
# Collapse the 3 RGB channels to 1, since the network expects (N, 1, 28, 28);
# transforms.Grayscale also operates on tensors shaped (..., 3, H, W).
get_gray = transforms.Compose([
    transforms.Grayscale(),
])
gray_image = get_gray(image_tensor)
gray_image.shape
torch.Size([1, 1, 28, 28])
# Forward pass without tracking gradients (inference only);
# model returns (logits, flattened_features).
with torch.no_grad():
    output = model(gray_image)
output[0].shape
torch.Size([1, 10])
# Predicted class = argmax over the 10 logits, mapped to its label name.
with torch.no_grad():
  test_output, _ = model(gray_image)
pred_y = (torch.max(test_output, 1)[1].data).cpu().numpy()
labels_map[int(pred_y)]
'Dress'
# Class probabilities: softmax over the logits (dim=1 is the class axis).
with torch.no_grad():
    output = model(gray_image)
softmax = torch.nn.Softmax(dim=1)
probs = softmax(output[0])
probs

tensor([[1.3675e-04, 4.2272e-01, 3.3998e-11, 5.7714e-01, 2.2880e-16, 4.4630e-10,
         1.5328e-11, 5.6080e-26, 3.3082e-10, 2.3702e-28]], device='cuda:0')
# Highest-probability class and its probability.
max_val, max_idx = torch.max(probs, dim=1)

# print the result
print("最大值为:", max_val.item())
print("最大值的索引为:", max_idx.item())
最大值为: 0.5771426558494568
最大值的索引为: 3
labels_map[max_idx.item()]  # map the predicted index back to its label name
'Dress'