基于Tensorflow的Faster-Rcnn的断点续训

发布时间 2023-06-12 17:17:30作者: 瑾明

一、前言

  最近在学习目标检测,到github上找了一个开源的Faster-RCNN项目(Tensorflow),项目地址是:https://github.com/dBeker/Faster-RCNN-TensorFlow-Python3

  根据网上的各种教程,模型训练还算顺利,不过这个项目缺少断点续训的功能。也就是中途误操作导致训练中止,就只能从头开始训练,模型的训练还是需要比较长的时间,没有断点续训不是很方便。因此在原项目的基础上新增了断点续训功能。

二、断点续训

  找到项目根目录下的train.py文件,在 last_snapshot_iter = 0 这行代码后新增以下代码块:

        ckpt = tf.train.get_checkpoint_state("./default/voc_2007_trainval/default")
        if ckpt and ckpt.model_checkpoint_path:
            self.saver.restore(sess,ckpt.model_checkpoint_path) #恢复当前会话sess,将ckpt中的值赋给w和b
            last_checkpoint = ckpt.model_checkpoint_path #最近模型路径
            ins_start = last_checkpoint.index("iter_")+5
            ins_end = last_checkpoint.index(".ckpt")
            last_iter = last_checkpoint[ins_start:ins_end] #最近模型的迭代次数
            last_snapshot_iter = int(last_iter)    

 

  加完代码之后,训练中止时,执行python train.py,即可自动检测是否断点续训。如果想重新开始训练模型,将 default/voc_2007_trainval/default 目录下的内容删除即可。