简单神经网络(py)

发布时间 2023-11-19 09:03:57 作者: 小菜碟子
  1 import numpy
  2 #激活函数库
  3 import scipy.special
  4 
  5 import matplotlib.pyplot
  6 
  7 #neutral network class definition
  8 class neutralNetwork:
  9     def __init__(self,inputnodes,hiddennodes,outputnodes,learningrate):
 10         #定义各个节点
 11         self.inodes=inputnodes
 12         self.hnodes=hiddennodes
 13         self.onodes=outputnodes
 14 
 15         #初始化权重矩阵(利用正态分布)
 16         self.win=numpy.random.normal(0.0,pow(self.hnodes,-0.5),(self.hnodes,self.inodes))
 17         self.who=numpy.random.normal(0.0,pow(self.onodes,-0.5),(self.onodes,self.hnodes))
 18 
 19         #定义激活函数
 20         self.activation_function=lambda x: scipy.special.expit(x)
 21 
 22         #初始化学习率
 23         self.lr=learningrate
 24         pass
 25 
 26     #训练网络并更新权重
 27     def train(self,inputs_list,targets_list):
 28         inputs=numpy.array(inputs_list,ndmin=2).T
 29         targets=numpy.array(targets_list,ndmin=2).T
 30 
 31         hidden_inputs=numpy.dot(self.win,inputs)
 32         hidden_outputs=self.activation_function(hidden_inputs)
 33 
 34         final_inputs=numpy.dot(self.who,hidden_outputs)
 35         final_outputs=self.activation_function(final_inputs)
 36 
 37         output_errors=targets-final_outputs
 38         hidden_errors=numpy.dot(self.who.T,output_errors)
 39 
 40         self.who+=self.lr*numpy.dot((output_errors*final_outputs*(1.0-final_outputs)),numpy.transpose(hidden_outputs))
 41         self.win+=self.lr*numpy.dot((hidden_errors*hidden_outputs*(1.0-hidden_outputs)),numpy.transpose(inputs))
 42 
 43         pass
 44 
 45     #查询每次输出结果
 46     def query(self,inputs_list):
 47         inputs=numpy.array(inputs_list,ndmin=2).T
 48 
 49         hidden_inputs=numpy.dot(self.win,inputs)
 50         hidden_outputs=self.activation_function(hidden_inputs)
 51 
 52         final_inputs=numpy.dot(self.who,hidden_outputs)
 53         final_outputs=self.activation_function(final_inputs)
 54 
 55         return final_outputs
 56         pass
 57 
 58 #inputnode是像素的大小28*28
 59 input_nodes=784
 60 #选择比inputnode小的,强迫网络总结输入主要特点
 61 hidden_nodes=100
 62 #手写一共十个数字,所以设置outputnode为10
 63 output_nodes=10
 64 
 65 learning_rate=0.3
 66 
 67 #训练2世代(太大会过度拟合)
 68 epoches=2
 69 
 70 n=neutralNetwork(input_nodes,hidden_nodes,output_nodes,learning_rate)
 71 
 72 #加载mnist训练集
 73 training_data_file=open("mnist_train.csv",'r')
 74 training_data_list=training_data_file.readlines()
 75 training_data_file.close()
 76 
 77 #用训练集训练网络
 78 for e in range(epoches):
 79     for record in training_data_list:
 80         all_values=record.split(',')
 81 
 82         #转化成input矩阵格式(非0:会造成网络崩溃;除以最大像素是255得到0.01-0.99;激活函数不能达到1)
 83         inputs=(numpy.asfarray(all_values[1:])/255.0*0.99)+0.01
 84 
 85         #设置目标输出:不能为0和1,否则会存在饱和网络(为了无限接近不可能的值0和1)
 86         targets=numpy.zeros(output_nodes)+0.01
 87         targets[int(all_values[0])]=0.99
 88         n.train(inputs,targets)
 89         pass
 90     pass
 91 
 92 #测试网络
 93 test_data_file=open("mnist_test.csv",'r')
 94 test_data_list=test_data_file.readlines()
 95 test_data_file.close()
 96 
 97 scorecard=[]
 98 
 99 for record in test_data_list:
100     all_values=record.split(',')
101     correct_label=int(all_values[0])
102     print(correct_label,"correct label")
103     image_array=numpy.asfarray(all_values[1:]).reshape((28,28))
104     matplotlib.pyplot.imshow(image_array,cmap='Greys',interpolation='None')
105 
106     inputs=(numpy.asfarray(all_values[1:])/255.0*0.99)+0.01
107 
108     outputs=n.query(inputs)
109 
110     label=numpy.argmax(outputs)
111     print(label,"network's answer:")
112 
113     if(label==correct_label):
114         scorecard.append(1)
115     else:
116         scorecard.append(0)
117     pass
118 
119 scorecard_array=numpy.asfarray(scorecard)
120 print("performance=",scorecard_array.sum()/scorecard_array.size)