Updated: 2021-12-17 GMT+08:00

Input Argument Parsing

Your code must be able to receive the following parameters sent by the server:

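  • type: task type, either "training" or "evaluation" (default "training"). "training" denotes a training task; "evaluation" denotes an evaluation task.
  • train_url: output path for the trained model.
  • data_url: path to the dataset.
  • model_path: path to the common model.
  • subModelLearningRate: learning rate (float, default 0.01).
  • subModelIterateTimes: number of iterations (int, default 1).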

import argparse


def parse_args() -> argparse.Namespace:
    """
    Parse args on edge side, DO NOT modify.
    args:
    type: training or evaluation, directed by server
    model_path: path to store common model
    data_url: path to fetch data, usually in local
    train_url: path to save model after training
    subModelLearningRate: learning rate
    subModelIterateTimes: number of batches per local train
    """

    parser = argparse.ArgumentParser(description="Federated Learning",
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--type", type=str, default="training", choices=["training", "evaluation"])
    parser.add_argument("--model_path", type=str, help="the path to store common model")
    parser.add_argument("--data_url", type=str, help="the path to fetch data")
    parser.add_argument("--train_url", type=str, help="the path to save model after training")
    parser.add_argument("--subModelLearningRate", type=float, default=0.01)
    parser.add_argument("--subModelIterateTimes", type=int, default=1)
    # Unrecognized arguments are returned in `_` and ignored instead of causing an error.
    args, _ = parser.parse_known_args()

    return args
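The template only parses the flags; what the job does with them is left to your own code. The sketch below is one possible way to consume them, assuming a hypothetical `run()` entry point and a hypothetical `build_parser()` helper (neither is part of the original template; `build_parser()` copies the template's argument definitions so the example runs on its own). It illustrates two behaviors of the template: `parse_known_args()` collects flags it does not recognize into a separate list rather than exiting, and `choices` rejects any `--type` other than "training" or "evaluation". The two branches return a description string in place of real training or evaluation work.

```python
import argparse
from typing import List


def build_parser() -> argparse.ArgumentParser:
    # Hypothetical helper: repeats the argument definitions from parse_args()
    # so that this sketch runs on its own.
    parser = argparse.ArgumentParser(description="Federated Learning")
    parser.add_argument("--type", type=str, default="training",
                        choices=["training", "evaluation"])
    parser.add_argument("--model_path", type=str)
    parser.add_argument("--data_url", type=str)
    parser.add_argument("--train_url", type=str)
    parser.add_argument("--subModelLearningRate", type=float, default=0.01)
    parser.add_argument("--subModelIterateTimes", type=int, default=1)
    return parser


def run(argv: List[str]) -> str:
    # Hypothetical entry point. Unrecognized flags are collected in `unknown`
    # instead of making argparse exit with an error.
    args, unknown = build_parser().parse_known_args(argv)
    if args.type == "training":
        # Placeholder for real training: load args.model_path, train on
        # args.data_url, and save the updated model to args.train_url.
        return (f"train lr={args.subModelLearningRate} "
                f"iters={args.subModelIterateTimes} -> {args.train_url}")
    # `choices` has already rejected every other value, so this is evaluation.
    # Placeholder for real evaluation of args.model_path on args.data_url.
    return f"evaluate {args.model_path} on {args.data_url}"
```

For example, `run(["--type", "evaluation", "--model_path", "/tmp/m", "--data_url", "/tmp/d"])` returns `"evaluate /tmp/m on /tmp/d"`; the paths are placeholders. Using `parse_known_args()` rather than `parse_args()` keeps the job running if the server passes parameters the script does not declare; the trade-off is that a misspelled flag is silently ignored rather than reported.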