import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as Data
from torchvision import datasets,transforms
import matplotlib.pyplot as plt
import numpy as np
# Hyper-parameters for the Fashion-MNIST CNN experiment.
input_size = 28 # image side length: samples are 28*28 (NOTE(review): not referenced later in this file)
num_class = 10 # number of target classes (NOTE(review): the final Linear layer hard-codes 10)
num_epochs = 3 # total number of training epochs
batch_size = 64 # images per mini-batch
# Fashion-MNIST splits: downloaded under ./data on first run, with each
# PIL image converted to a float tensor in [0, 1].
_to_tensor = transforms.ToTensor()
train_dataset = datasets.FashionMNIST(
    root='data', train=True, transform=_to_tensor, download=True)
test_dataset = datasets.FashionMNIST(
    root='data', train=False, transform=_to_tensor, download=True)
# Mini-batch iterators over the two splits.
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,  # reshuffle training samples every epoch
)
# Fixed order for evaluation: shuffling the test set has no benefit and
# makes the periodic accuracy logs non-reproducible (was shuffle=True).
test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    batch_size=batch_size,
    shuffle=False,
)
class CNN(nn.Module):
    """Two-block convolutional classifier for 1x28x28 Fashion-MNIST images.

    forward() returns both the class logits and the flattened feature
    vector that feeds the final linear layer.
    """

    def __init__(self):
        super().__init__()
        # (1, 28, 28) -> conv 5x5 pad 2 -> (16, 28, 28) -> 2x2 maxpool -> (16, 14, 14)
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=16,
                      kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )
        # (16, 14, 14) -> conv 5x5 pad 2 -> (32, 14, 14) -> 2x2 maxpool -> (32, 7, 7)
        self.conv2 = nn.Sequential(
            nn.Conv2d(16, 32, 5, 1, 2),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        # fully connected head over the flattened 32*7*7 feature map
        self.out = nn.Linear(32 * 7 * 7, 10)

    def forward(self, x):
        features = self.conv2(self.conv1(x))
        flat = features.view(features.size(0), -1)  # (batch, 32*7*7)
        return self.out(flat), flat
def accuracy(predictions, labels):
    """Return (#correct, #total) for a batch of class logits.

    The first element is a tensor (so it can be summed across batches
    and moved off the GPU later); the second is a plain int.
    """
    predicted_classes = predictions.detach().argmax(dim=1)
    correct = (predicted_classes == labels.detach().view_as(predicted_classes)).sum()
    return correct, len(labels)
# Select GPU when available, otherwise fall back to CPU.
# (Removed two pasted notebook echo lines — a bare `device` expression and
# the `'cuda'` output string — which were runtime no-ops.)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
net = CNN().to(device)
criterion = nn.CrossEntropyLoss()  # expects raw logits + integer class labels
# Adam over all model parameters
optimizer = optim.Adam(net.parameters(), lr=0.001)
# Train for num_epochs epochs, evaluating on the full test set every 100
# batches. Fixed off-by-one: `range(num_epochs + 1)` ran one extra epoch.
for epoch in range(num_epochs):
    train_rights = []  # (correct, total) per training batch, this epoch
    for batch_idx, (data, target) in enumerate(train_loader):
        data = data.to(device)
        target = target.to(device)
        net.train()
        output = net(data)[0]  # forward returns (logits, flat features)
        loss = criterion(output, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_rights.append(accuracy(output, target))
        if batch_idx % 100 == 0:
            # Periodic evaluation over the whole test set.
            net.eval()
            val_rights = []
            # no_grad: evaluation does not need autograd graphs (saves memory)
            with torch.no_grad():
                for data, target in test_loader:
                    data = data.to(device)
                    target = target.to(device)
                    output = net(data)[0]
                    val_rights.append(accuracy(output, target))
            # Aggregate (correct, total) pairs into epoch-level counts.
            train_r = (sum(i[0] for i in train_rights), sum(i[1] for i in train_rights))
            val_r = (sum(i[0] for i in val_rights), sum(i[1] for i in val_rights))
            print('当前epoch:{}[{}/{}({:.0f}%)]\t损失:{:.2f}\t训练集准确率:{:.2f}%\t测试集准确率:{:.2f}%'.format(
                epoch,
                batch_idx * batch_size,
                len(train_loader.dataset),
                100. * batch_idx / len(train_loader),
                loss.item(),  # .item() instead of deprecated .data
                100. * train_r[0].cpu().numpy() / train_r[1],
                100. * val_r[0].cpu().numpy() / val_r[1],
            ))
# Persist the whole model object (architecture + weights).
torch.save(net, 'cnn_test.pt')
当前epoch:0[0/60000(0%)] 损失:2.31 训练集准确率:7.81% 测试集准确率:19.02%
当前epoch:0[6400/60000(11%)] 损失:0.70 训练集准确率:66.34% 测试集准确率:74.87%
当前epoch:0[12800/60000(21%)] 损失:0.39 训练集准确率:72.43% 测试集准确率:81.34%
当前epoch:0[19200/60000(32%)] 损失:0.49 训练集准确率:75.35% 测试集准确率:81.99%
当前epoch:0[25600/60000(43%)] 损失:0.49 训练集准确率:77.42% 测试集准确率:84.12%
当前epoch:0[32000/60000(53%)] 损失:0.31 训练集准确率:78.93% 测试集准确率:84.10%
当前epoch:0[38400/60000(64%)] 损失:0.35 训练集准确率:80.00% 测试集准确率:84.13%
当前epoch:0[44800/60000(75%)] 损失:0.34 训练集准确率:80.90% 测试集准确率:85.27%
当前epoch:0[51200/60000(85%)] 损失:0.31 训练集准确率:81.55% 测试集准确率:85.85%
当前epoch:0[57600/60000(96%)] 损失:0.48 训练集准确率:82.14% 测试集准确率:86.35%
当前epoch:1[0/60000(0%)] 损失:0.29 训练集准确率:89.06% 测试集准确率:86.13%
当前epoch:1[6400/60000(11%)] 损失:0.42 训练集准确率:87.19% 测试集准确率:84.78%
当前epoch:1[12800/60000(21%)] 损失:0.35 训练集准确率:86.99% 测试集准确率:87.09%
当前epoch:1[19200/60000(32%)] 损失:0.38 训练集准确率:87.33% 测试集准确率:86.61%
当前epoch:1[25600/60000(43%)] 损失:0.32 训练集准确率:87.55% 测试集准确率:86.70%
当前epoch:1[32000/60000(53%)] 损失:0.51 训练集准确率:87.72% 测试集准确率:87.08%
当前epoch:1[38400/60000(64%)] 损失:0.46 训练集准确率:87.88% 测试集准确率:87.95%
当前epoch:1[44800/60000(75%)] 损失:0.33 训练集准确率:87.94% 测试集准确率:88.00%
当前epoch:1[51200/60000(85%)] 损失:0.28 训练集准确率:88.04% 测试集准确率:88.51%
当前epoch:1[57600/60000(96%)] 损失:0.18 训练集准确率:88.15% 测试集准确率:87.43%
当前epoch:2[0/60000(0%)] 损失:0.17 训练集准确率:93.75% 测试集准确率:87.79%
当前epoch:2[6400/60000(11%)] 损失:0.29 训练集准确率:89.71% 测试集准确率:87.70%
当前epoch:2[12800/60000(21%)] 损失:0.27 训练集准确率:89.51% 测试集准确率:88.32%
当前epoch:2[19200/60000(32%)] 损失:0.20 训练集准确率:89.24% 测试集准确率:88.77%
当前epoch:2[25600/60000(43%)] 损失:0.23 训练集准确率:89.45% 测试集准确率:88.24%
当前epoch:2[32000/60000(53%)] 损失:0.21 训练集准确率:89.57% 测试集准确率:88.22%
当前epoch:2[38400/60000(64%)] 损失:0.30 训练集准确率:89.54% 测试集准确率:88.10%
当前epoch:2[44800/60000(75%)] 损失:0.19 训练集准确率:89.54% 测试集准确率:88.92%
当前epoch:2[51200/60000(85%)] 损失:0.45 训练集准确率:89.50% 测试集准确率:88.96%
当前epoch:2[57600/60000(96%)] 损失:0.20 训练集准确率:89.53% 测试集准确率:89.55%
当前epoch:3[0/60000(0%)] 损失:0.27 训练集准确率:93.75% 测试集准确率:88.16%
当前epoch:3[6400/60000(11%)] 损失:0.19 训练集准确率:90.24% 测试集准确率:89.76%
当前epoch:3[12800/60000(21%)] 损失:0.19 训练集准确率:90.10% 测试集准确率:89.41%
当前epoch:3[19200/60000(32%)] 损失:0.24 训练集准确率:90.32% 测试集准确率:89.48%
当前epoch:3[25600/60000(43%)] 损失:0.34 训练集准确率:90.42% 测试集准确率:89.58%
当前epoch:3[32000/60000(53%)] 损失:0.27 训练集准确率:90.30% 测试集准确率:88.86%
当前epoch:3[38400/60000(64%)] 损失:0.34 训练集准确率:90.28% 测试集准确率:89.39%
当前epoch:3[44800/60000(75%)] 损失:0.37 训练集准确率:90.36% 测试集准确率:88.66%
当前epoch:3[51200/60000(85%)] 损失:0.17 训练集准确率:90.36% 测试集准确率:89.72%
当前epoch:3[57600/60000(96%)] 损失:0.20 训练集准确率:90.41% 测试集准确率:89.29%
当前epoch:4[0/60000(0%)] 损失:0.15 训练集准确率:92.19% 测试集准确率:89.55%
当前epoch:4[6400/60000(11%)] 损失:0.30 训练集准确率:91.43% 测试集准确率:89.89%
当前epoch:4[12800/60000(21%)] 损失:0.15 训练集准确率:91.25% 测试集准确率:89.62%
当前epoch:4[19200/60000(32%)] 损失:0.20 训练集准确率:91.23% 测试集准确率:89.95%
当前epoch:4[25600/60000(43%)] 损失:0.16 训练集准确率:91.24% 测试集准确率:89.70%
当前epoch:4[32000/60000(53%)] 损失:0.21 训练集准确率:91.22% 测试集准确率:89.95%
当前epoch:4[38400/60000(64%)] 损失:0.33 训练集准确率:91.18% 测试集准确率:90.42%
当前epoch:4[44800/60000(75%)] 损失:0.19 训练集准确率:91.24% 测试集准确率:89.69%
当前epoch:4[51200/60000(85%)] 损失:0.26 训练集准确率:91.22% 测试集准确率:90.35%
当前epoch:4[57600/60000(96%)] 损失:0.28 训练集准确率:91.25% 测试集准确率:88.77%
当前epoch:5[0/60000(0%)] 损失:0.25 训练集准确率:93.75% 测试集准确率:89.79%
当前epoch:5[6400/60000(11%)] 损失:0.21 训练集准确率:91.21% 测试集准确率:89.90%
当前epoch:5[12800/60000(21%)] 损失:0.15 训练集准确率:91.51% 测试集准确率:90.71%
当前epoch:5[19200/60000(32%)] 损失:0.16 训练集准确率:91.77% 测试集准确率:90.45%
当前epoch:5[25600/60000(43%)] 损失:0.21 训练集准确率:91.84% 测试集准确率:90.56%
当前epoch:5[32000/60000(53%)] 损失:0.12 训练集准确率:91.86% 测试集准确率:89.10%
当前epoch:5[38400/60000(64%)] 损失:0.28 训练集准确率:91.82% 测试集准确率:90.42%
当前epoch:5[44800/60000(75%)] 损失:0.15 训练集准确率:91.88% 测试集准确率:90.19%
当前epoch:5[51200/60000(85%)] 损失:0.33 训练集准确率:91.87% 测试集准确率:90.03%
当前epoch:5[57600/60000(96%)] 损失:0.10 训练集准确率:91.80% 测试集准确率:90.74%
当前epoch:6[0/60000(0%)] 损失:0.15 训练集准确率:93.75% 测试集准确率:90.36%
当前epoch:6[6400/60000(11%)] 损失:0.31 训练集准确率:92.28% 测试集准确率:90.85%
当前epoch:6[12800/60000(21%)] 损失:0.23 训练集准确率:92.15% 测试集准确率:90.68%
当前epoch:6[19200/60000(32%)] 损失:0.15 训练集准确率:92.37% 测试集准确率:90.71%
当前epoch:6[25600/60000(43%)] 损失:0.31 训练集准确率:92.29% 测试集准确率:91.02%
当前epoch:6[32000/60000(53%)] 损失:0.21 训练集准确率:92.43% 测试集准确率:90.57%
当前epoch:6[38400/60000(64%)] 损失:0.25 训练集准确率:92.43% 测试集准确率:90.51%
当前epoch:6[44800/60000(75%)] 损失:0.21 训练集准确率:92.48% 测试集准确率:90.56%
当前epoch:6[51200/60000(85%)] 损失:0.07 训练集准确率:92.43% 测试集准确率:91.04%
当前epoch:6[57600/60000(96%)] 损失:0.14 训练集准确率:92.43% 测试集准确率:90.68%
当前epoch:7[0/60000(0%)] 损失:0.25 训练集准确率:89.06% 测试集准确率:91.18%
当前epoch:7[6400/60000(11%)] 损失:0.11 训练集准确率:92.51% 测试集准确率:91.09%
当前epoch:7[12800/60000(21%)] 损失:0.17 训练集准确率:92.98% 测试集准确率:91.21%
当前epoch:7[19200/60000(32%)] 损失:0.23 训练集准确率:93.06% 测试集准确率:90.80%
当前epoch:7[25600/60000(43%)] 损失:0.18 训练集准确率:92.95% 测试集准确率:91.39%
当前epoch:7[32000/60000(53%)] 损失:0.24 训练集准确率:93.01% 测试集准确率:91.06%
当前epoch:7[38400/60000(64%)] 损失:0.27 训练集准确率:92.94% 测试集准确率:91.18%
当前epoch:7[44800/60000(75%)] 损失:0.31 训练集准确率:92.77% 测试集准确率:90.88%
当前epoch:7[51200/60000(85%)] 损失:0.17 训练集准确率:92.73% 测试集准确率:91.42%
当前epoch:7[57600/60000(96%)] 损失:0.17 训练集准确率:92.75% 测试集准确率:90.75%
当前epoch:8[0/60000(0%)] 损失:0.15 训练集准确率:95.31% 测试集准确率:91.15%
当前epoch:8[6400/60000(11%)] 损失:0.18 训练集准确率:93.13% 测试集准确率:91.42%
当前epoch:8[12800/60000(21%)] 损失:0.12 训练集准确率:93.24% 测试集准确率:91.31%
当前epoch:8[19200/60000(32%)] 损失:0.27 训练集准确率:93.37% 测试集准确率:91.25%
当前epoch:8[25600/60000(43%)] 损失:0.17 训练集准确率:93.38% 测试集准确率:91.52%
当前epoch:8[32000/60000(53%)] 损失:0.19 训练集准确率:93.16% 测试集准确率:91.51%
当前epoch:8[38400/60000(64%)] 损失:0.26 训练集准确率:93.11% 测试集准确率:91.34%
当前epoch:8[44800/60000(75%)] 损失:0.44 训练集准确率:93.05% 测试集准确率:91.35%
当前epoch:8[51200/60000(85%)] 损失:0.31 训练集准确率:93.03% 测试集准确率:91.23%
当前epoch:8[57600/60000(96%)] 损失:0.22 训练集准确率:93.01% 测试集准确率:90.74%
当前epoch:9[0/60000(0%)] 损失:0.19 训练集准确率:93.75% 测试集准确率:91.15%
当前epoch:9[6400/60000(11%)] 损失:0.27 训练集准确率:93.64% 测试集准确率:90.78%
当前epoch:9[12800/60000(21%)] 损失:0.25 训练集准确率:93.73% 测试集准确率:91.31%
当前epoch:9[19200/60000(32%)] 损失:0.23 训练集准确率:93.42% 测试集准确率:89.56%
当前epoch:9[25600/60000(43%)] 损失:0.27 训练集准确率:93.24% 测试集准确率:90.82%
当前epoch:9[32000/60000(53%)] 损失:0.23 训练集准确率:93.33% 测试集准确率:91.29%
当前epoch:9[38400/60000(64%)] 损失:0.09 训练集准确率:93.31% 测试集准确率:91.24%
当前epoch:9[44800/60000(75%)] 损失:0.25 训练集准确率:93.31% 测试集准确率:90.78%
当前epoch:9[51200/60000(85%)] 损失:0.19 训练集准确率:93.33% 测试集准确率:91.34%
当前epoch:9[57600/60000(96%)] 损失:0.12 训练集准确率:93.35% 测试集准确率:91.30%
当前epoch:10[0/60000(0%)] 损失:0.17 训练集准确率:93.75% 测试集准确率:91.27%
当前epoch:10[6400/60000(11%)] 损失:0.13 训练集准确率:93.81% 测试集准确率:91.61%
当前epoch:10[12800/60000(21%)] 损失:0.22 训练集准确率:93.91% 测试集准确率:91.13%
当前epoch:10[19200/60000(32%)] 损失:0.14 训练集准确率:93.92% 测试集准确率:91.19%
当前epoch:10[25600/60000(43%)] 损失:0.22 训练集准确率:93.92% 测试集准确率:91.78%
当前epoch:10[32000/60000(53%)] 损失:0.15 训练集准确率:93.95% 测试集准确率:90.79%
当前epoch:10[38400/60000(64%)] 损失:0.09 训练集准确率:93.92% 测试集准确率:91.42%
当前epoch:10[44800/60000(75%)] 损失:0.12 训练集准确率:93.86% 测试集准确率:91.62%
当前epoch:10[51200/60000(85%)] 损失:0.14 训练集准确率:93.84% 测试集准确率:90.67%
当前epoch:10[57600/60000(96%)] 损失:0.13 训练集准确率:93.78% 测试集准确率:91.42%
from PIL import Image
# Human-readable names for the ten Fashion-MNIST class indices.
labels_map = dict(enumerate([
    "T-Shirt",
    "Trouser",
    "Pullover",
    "Dress",
    "Coat",
    "Sandal",
    "Shirt",
    "Sneaker",
    "Bag",
    "Ankle Boot",
]))
# Show a few random training images. (Removed duplicate
# `import matplotlib.pyplot as plt` — already imported at the top of the
# file — plus dead commented-out code and a misleading "3x3" comment.)
figure = plt.figure(figsize=(8, 8))
cols, rows = 1, 1  # how many sample images to display (cols*rows)
for i in range(1, cols * rows + 1):
    # pick one training image at random
    sample_idx = torch.randint(len(train_dataset), size=(1,)).item()
    img, label = train_dataset[sample_idx]
    print(img.shape)
    print(label)
    figure.add_subplot(rows, cols, i)  # place the i-th image on the grid
    plt.axis("off")
    plt.imshow(img.squeeze(), cmap="gray")
torch.Size([1, 28, 28])
9
model_path = "cnn_test.pt"  # path to the saved .pt model file
# map_location lets a GPU-trained checkpoint load on CPU-only machines.
# NOTE(review): torch.load unpickles arbitrary objects — only load
# checkpoints from trusted sources.
model = torch.load(model_path, map_location=device)
model.to(device)  # was model.cuda(), which crashes without a GPU
print(model)

image_path = "1.png"  # image to classify
image = Image.open(image_path)
print(image)

transform = transforms.Compose([
    transforms.Resize((28, 28)),
    transforms.ToTensor(),
    # NOTE(review): normalizes 3 RGB channels here, then the script
    # grayscales afterwards, while training used unnormalized 1-channel
    # inputs — confirm this preprocessing mismatch is intended.
    transforms.Normalize((.5, .5, .5), (.5, .5, .5)),
])
# ToTensor already yields float32, so .to(device) replaces the deprecated,
# CUDA-only `.type(torch.cuda.FloatTensor)`.
image_tensor = transform(image).unsqueeze(0).to(device)
CNN(
(conv1): Sequential(
(0): Conv2d(1, 16, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
(1): ReLU()
(2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
(conv2): Sequential(
(0): Conv2d(16, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
(1): ReLU()
(2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
(out): Linear(in_features=1568, out_features=10, bias=True)
)
<PIL.PngImagePlugin.PngImageFile image mode=RGB size=102x103 at 0x1F19505CE48>
torch.Size([1, 3, 28, 28])
# Collapse the 3 RGB channels to one luminance channel so the input
# matches the 1-channel images the CNN was trained on.
# (Removed a trailing bare `gray_image.shape` notebook echo — a no-op.)
get_gray = transforms.Compose([
    transforms.Grayscale(),
])
gray_image = get_gray(image_tensor)  # (1, 1, 28, 28) per the recorded run
torch.Size([1, 1, 28, 28])
# Classify the image: argmax over the ten logits.
# (Collapsed two identical forward passes into one and removed the bare
# `output[0].shape` / result echoes left over from the notebook.)
with torch.no_grad():
    test_output, _ = model(gray_image)
pred_y = (torch.max(test_output, 1)[1].data).cpu().numpy()
predicted_label = labels_map[int(pred_y)]
# Class probabilities for the same image.
# (Removed pasted notebook output lines — one called the undefined name
# `tensor`, which would raise NameError at runtime — and a redundant
# third forward pass.)
with torch.no_grad():
    logits, _ = model(gray_image)
probs = F.softmax(logits, dim=1)  # same as torch.nn.Softmax(dim=1)(logits)
max_val, max_idx = torch.max(probs, dim=1)
# print the result
print("最大值为:", max_val.item())
print("最大值的索引为:", max_idx.item())
predicted_class = labels_map[max_idx.item()]