NLP应用 | 保存checkpoint模型

发布时间 2023-07-07 21:30:12作者: 张Zong在修行

需求描述:

当我们训练模型的时候,我们要训练很多训练步数,我们想要保存训练到一定阶段的checkpoint模型参数,并把这些checkpoint模型保存到一个指定的文件夹下。在文件夹下我们最多保存keep_checkpoint_max个checkpoint模型的文件。保存到output文件夹下。每save_checkpoint_steps步去保存一次。

如果保存的checkpoint模型已经达到最大数量,那么就把最早保存的文件删除,然后在保存现在的checkpoint模型的文件。

文件名是后面是保存的第几次。

代码梳理

首先我们定义一个checkpoint模型保存的函数:

def save_checkpoint(step, epoch, model, optimizer, params):
    if dist.get_rank() == 0:
        state = {
            "step": step,
            "epoch": epoch,
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict()
        }
        utils.save(state, params.output, params.keep_checkpoint_max)

我们定义了一个保存模型的函数,需要传入的参数为训练步数(step)、数据集训练次数(epoch)、模型(model)、优化器(optimizer)、参数集(params)。

我们定义了需要保存的信息字典:

state = {
    "step": step,
    "epoch": epoch,
    "model": model.state_dict(),
    "optimizer": optimizer.state_dict()
}

字典里保存了训练步数、epoch数、模型参数和优化器参数。

然后传给我们自己的一个工具类的保存函数utils.save()

接下来我们看一下工具包保存checkpoint模型的实现。

import os
import glob
import torch


def oldest_checkpoint(path):
    names = glob.glob(os.path.join(path, "*.pt"))

    if not names:
        return None

    oldest_counter = 10000000
    checkpoint_name = names[0]

    for name in names:
        counter = name.rstrip(".pt").split("-")[-1]

        if not counter.isdigit():
            continue
        else:
            counter = int(counter)

        if counter < oldest_counter:
            checkpoint_name = name
            oldest_counter = counter

    return checkpoint_name


def latest_checkpoint(path):
    names = glob.glob(os.path.join(path, "*.pt"))

    if not names:
        return None

    latest_counter = 0
    checkpoint_name = names[0]

    for name in names:
        counter = name.rstrip(".pt").split("-")[-1]

        if not counter.isdigit():
            continue
        else:
            counter = int(counter)

        if counter > latest_counter:
            checkpoint_name = name
            latest_counter = counter

    return checkpoint_name


def save(state, path, max_to_keep=None):
    checkpoints = glob.glob(os.path.join(path, "*.pt"))

    if max_to_keep and len(checkpoints) >= max_to_keep:
        checkpoint = oldest_checkpoint(path)
        os.remove(checkpoint)

    if not checkpoints:
        counter = 1
    else:
        checkpoint = latest_checkpoint(path)
        counter = int(checkpoint.rstrip(".pt").split("-")[-1]) + 1

    checkpoint = os.path.join(path, "model-%d.pt" % counter)
    print("Saving checkpoint: %s" % checkpoint)
    torch.save(state, checkpoint)

我们首先来看一下save()函数的实现。

def save(state, path, max_to_keep=None):
    checkpoints = glob.glob(os.path.join(path, "*.pt"))

    if max_to_keep and len(checkpoints) >= max_to_keep:
        checkpoint = oldest_checkpoint(path)
        os.remove(checkpoint)

    if not checkpoints:
        counter = 1
    else:
        checkpoint = latest_checkpoint(path)
        counter = int(checkpoint.rstrip(".pt").split("-")[-1]) + 1

    checkpoint = os.path.join(path, "model-%d.pt" % counter)
    print("Saving checkpoint: %s" % checkpoint)
    torch.save(state, checkpoint)

刚刚我们从save_checkpoint函数中传入到save函数三个参数,我们一个个看一下。

  • state:需要保存的信息,类型是字典类型的数据
  • path:我们在命令行输入的output路径,用来保存模型的路径
  • max_to_keep:keep_checkpoint_max,这个参数的作用就是在文件夹下我们最多保存keep_checkpoint_max个checkpoint模型的文件

save函数的流程:

第一,我们先查看一下这个文件夹下有多少.pt结尾的文件,以列表的方式保存到checkpoints变量中。

checkpoints = glob.glob(os.path.join(path, "*.pt"))

第二,如果传入了max_to_keep参数,并且文件夹中目前的checkpoint模型的文件大于或者等于最大达到保存的文件数时,我们寻找文件夹下最先保存的checkpoint模型的文件,然后删除这个文件。如果没有超过,这段不执行。

if max_to_keep and len(checkpoints) >= max_to_keep:
    checkpoint = oldest_checkpoint(path)
    os.remove(checkpoint)

第三,如果最开始文件夹下没有任何checkpoint模型的文件,文件计数(counter)加一。如果有文件的话,找到最新保存的checkpoint模型文件的文件名,提取文件名中的数字,然后加一,作为当前保存的文件的数字尾缀。

if not checkpoints:
    counter = 1
else:
    checkpoint = latest_checkpoint(path)
    counter = int(checkpoint.rstrip(".pt").split("-")[-1]) + 1

第四,拼接路径和文件的名字,传给checkpoint变量。

checkpoint = os.path.join(path, "model-%d.pt" % counter)
print("Saving checkpoint: %s" % checkpoint)

第五,使用pytorch的torch.save()函数进行模型和训练相关参数的保存。

torch.save(state, checkpoint)

上面save函数调用了两个函数:

  • oldest_checkpoint:返回最早保存的checkpoint模型文件的文件名
  • latest_checkpoint:返回最新保存的checkpoint模型文件的文件名

自己仔细看代码实现逻辑十分好懂,自己看一下吧。

到这里我们已经知道模型如何保存的实现了,上面需求描述的也大都实现了,但是缺一个训练多少步进行调用这个函数,在训练的过程中,如下代码所示:

if step % params.save_checkpoint_steps == 0:
    save_checkpoint(step, epoch, model, optimizer, params)

代码的意思是,当训练步数对多少步数保存一次的参数(save_checkpoint_steps)进行取余,如果为零,表示save_checkpoint_steps步训练到了,需要保存了,然后执行我们实现的save_checkpoint函数对模型的checkpoint进行保存。

代码中用到的参数来源:

  • 一部分是执行命令的时候用户传入的
  • 一部分是代码设置的默认参数,这些参数也可以在命令行进行指定

总结:

  • 我们这样是为了需要像文中的需求进行具体的代码解决方法,这些代码实现是正确的,只需要用户在自己的项目中把这些代码设计到合适的位置。我仅在文中进行了保存checkpoint文件的思路梳理。