关于tensorflow2.x保存模型及加载模型的方法及对比

发布时间 2023-07-10 21:36:21作者: waterdoor

以下方法都是个人实际中测试和使用的方法,tf2版本在2.3~2.7之间

1、model.save() and model.load()

保存模型:这个方法可以直接将训练后的权重和训练的参数保存下来,一般我个人使用的.h5为后缀把模型整个保存下来。

步骤如下:

(1)创建模型,像添加积木一样对模型添加需要的卷积,池化等操作

 (2)配置神经网络的优化器,计算梯度的方法  

 (3)保存模型

 

加载模型:这样保存下来的模型可以直接将h5文件保存下来,不需要先加载模型

 

 

2、model.save_weight() and model.load_weight()

 (1)这里采用继承Model这个类去实现神经网络(比第一种方法更加常用且受规范)

 下面的方法就是当我们保存模型的权重参数,但是没有保存模型的结构

 加载模型

需要先把模型的结构导入过来,再load模型的参数进去才能进行推理

 

3、model.checkpoint

 这个用的比较少,看这样加载模型的方式,可能跟第二种类似

 

 

 

注意:这里的模型都指定了输入的图片的尺寸,如果想输入的图片尺寸不受限制,那么不要使用flatten拉直神经网络,可以使用全卷积来再进行softmax即可

tf.keras.layers.GlobalAveragePooling2D()