文档首页/ AI开发平台ModelArts/ 常见问题/ Standard训练作业/ 编写训练代码/ 如何在训练中加载部分训练好的参数?
更新时间:2024-10-28 GMT+08:00

如何在训练中加载部分训练好的参数?

在训练作业时,需要从预训练的模型中加载部分参数,初始化当前模型。请您通过如下方式加载:

  1. 通过如下代码,您可以查看所有的参数。
    from moxing.tensorflow.utils.hyper_param_flags import mox_flags
    print(mox_flags.get_help())
  2. 通过如下方式控制载入模型时需要恢复的参数名。其中,“checkpoint_include_patterns”为需要恢复的参数,“checkpoint_exclude_patterns”为不需要恢复的参数。
    checkpoint_include_patterns: Variables names patterns to include when restoring checkpoint. Such as: conv2d/weights.
    checkpoint_exclude_patterns: Variables names patterns to include when restoring checkpoint. Such as: conv2d/weights.
  3. 通过以下方式控制需要训练的参数列表。其中,“trainable_include_patterns”为需要训练的参数列表,“trainable_exclude_patterns”为不需要训练的参数列表。
    --trainable_exclude_patterns: Variables names patterns to exclude for trainable variables. Such as: conv1,conv2.
    --trainable_include_patterns: Variables names patterns to include for trainable variables. Such as: logits.