Converting Python argparse arguments into class variables

Published: 2023-07-22 17:56:48  Author: 颀周

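Projects on GitHub usually run through argparse plus a bash script. That makes it quick to get a project running, but it is painful for debugging: every change to the arguments means editing the .sh file, and the code cannot simply be run in Jupyter.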

Below is example code that converts a parser into a class:

#%%
import argparse

parser = argparse.ArgumentParser()
 
parser.add_argument("--get_pred",
                    action='store_true',
                    help="Whether to get prediction results.")
parser.add_argument("--get_ig_pred",
                    action='store_true',
                    help="Whether to get integrated gradient at the predicted label.")
parser.add_argument("--get_ig_gold",
                    action='store_true',
                    help="Whether to get integrated gradient at the gold label.")
parser.add_argument("--get_base",
                    action='store_true',
                    help="Whether to get base values. ")
parser.add_argument("--batch_size",
                    default=16,
                    type=int,
                    help="Total batch size for cut.")
parser.add_argument("--num_batch",
                    default=10,
                    type=int,
                    help="Num batch of an example.")

#%% Convert
def print_store_actions(action, print_attrs=('type', 'help'), need_default=True):
    """Print selected attributes of an argparse action, then its default, as comment lines."""
    if print_attrs:
        # Attribute names that the action does not have are silently skipped
        parts = [str(getattr(action, name)).replace('\n', ' ')
                 for name in print_attrs if hasattr(action, name)]
        print('# ' + ', '.join(parts))
    if need_default:
        # repr() quotes string defaults and leaves numbers, booleans and None bare
        print('# default = ' + repr(action.default))

def parser_2_class(parser, print_attrs=('type', 'help'), need_default=True):
    """Print every argument as 'name = default', ready to paste into a class body."""
    for action in parser._actions:
        # Skip the built-in -h/--help action (positional arguments have no option strings)
        if '-h' in action.option_strings:
            continue
        # dest is the attribute name parse_args uses ('--batch-size' becomes 'batch_size')
        line = action.dest + ' = ' + repr(action.default)
        if print_attrs:
            print_store_actions(action, print_attrs, need_default)
            print(line)
        elif need_default:
            # No attribute comment line: put the default comment after the assignment
            print(line, end='  ')
            print_store_actions(action, print_attrs, need_default)
        else:
            print(line)

parser_2_class(parser, ['type', 'help'], True)
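If copying the printed lines by hand feels error-prone, the same idea can be packaged as a function that returns the whole class definition as a string, ready to be written to a .py file or pasted into a notebook cell. This is a minimal sketch, not the author's code: the helper name parser_to_class_source and the default class name args are hypothetical, it takes attribute names from action.dest, and, like the code above, it relies on argparse's private _actions list (plus the private _HelpAction class to skip -h).

```python
import argparse


def parser_to_class_source(parser, class_name="args"):
    """Hypothetical helper: render a parser's defaults as Python class source."""
    lines = [f"class {class_name}:"]
    for action in parser._actions:  # private attribute, as in the code above
        if isinstance(action, argparse._HelpAction):
            continue  # skip -h/--help
        help_text = (action.help or "").replace("\n", " ")
        lines.append(f"    # {action.type}, {help_text}")
        # repr() quotes strings and leaves numbers, booleans and None bare
        lines.append(f"    {action.dest} = {action.default!r}")
    if len(lines) == 1:
        lines.append("    pass")  # a class body cannot be empty
    return "\n".join(lines)


# A small demo parser (two of the arguments from the example above)
demo_parser = argparse.ArgumentParser()
demo_parser.add_argument("--get_pred", action="store_true",
                         help="Whether to get prediction results.")
demo_parser.add_argument("--batch_size", default=16, type=int,
                         help="Total batch size for cut.")

source = parser_to_class_source(demo_parser)
print(source)

# Self-check: execute the generated source and read the class back
namespace = {}
exec(source, namespace)
args = namespace["args"]
assert args.get_pred is False and args.batch_size == 16
```

Running exec on the generated source is only a self-check here; in practice the string would more likely be saved to a file and imported.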

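Then paste the printed output into a class that contains only member variables. Code that reads args.batch_size, args.get_pred and so on sees the same values it would get from the parser, so the project can be debugged conveniently (for example in Jupyter) without changing any of its other code: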

class args:
    # None, Whether to get prediction results.
    # default = False
    get_pred = False
    # None, Whether to get integrated gradient at the predicted label.
    # default = False
    get_ig_pred = False
    # None, Whether to get integrated gradient at the gold label.
    # default = False
    get_ig_gold = False
    # None, Whether to get base values. 
    # default = False
    get_base = False
    # <class 'int'>, Total batch size for cut.
    # default = 16
    batch_size = 16
    # <class 'int'>, Num batch of an example.
    # default = 10
    num_batch = 10
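To check that the class really reproduces the parser's variable space, note that parse_args accepts an explicit list of strings instead of reading sys.argv. Calling it with an empty list returns the defaults, which can be compared with the class; passing a list of flags simulates a particular command line without editing the .sh file. (In Jupyter, a bare parse_args() reads the kernel's own sys.argv and typically fails with an unrecognized-arguments error.) The sketch below assumes a trimmed two-argument copy of the parser and class above.

```python
import argparse

# A trimmed copy of the parser above (only two of its arguments)
parser = argparse.ArgumentParser()
parser.add_argument("--get_pred", action="store_true",
                    help="Whether to get prediction results.")
parser.add_argument("--batch_size", default=16, type=int,
                    help="Total batch size for cut.")


# The matching class-style args, as printed by parser_2_class
class args:
    get_pred = False
    batch_size = 16


# parse_args([]) ignores sys.argv and returns only the defaults
defaults = vars(parser.parse_args([]))
# A class's vars() also holds dunder entries such as __module__, so drop them
class_values = {k: v for k, v in vars(args).items() if not k.startswith("__")}
assert defaults == class_values

# Simulate "python main.py --get_pred --batch_size 32" without a shell script
simulated = parser.parse_args(["--get_pred", "--batch_size", "32"])
print(simulated.get_pred, simulated.batch_size)  # True 32
```

A plain class is enough when the project only reads attributes such as args.batch_size. If some code calls vars(args) instead, keep in mind that a class's vars() includes those extra dunder entries, which is why the comparison above filters them out.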