Updated: 2021-12-17 GMT+08:00
Input Parameter Parsing
Your code must be able to accept the following parameters sent by the server:
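- type: the task type, either "training" or "evaluation". "training" denotes a training task; "evaluation" denotes an evaluation task.
- train_url: output path for the trained model.
- data_url: path to the dataset.
- model_path: path to the common model.
- subModelLearningRate: learning rate.
- subModelIterateTimes: number of iterations.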
import argparse


def parse_args() -> argparse.Namespace:
    """
    Parse args on edge side, DO NOT modify.
    args:
        type: training or evaluation, directed by server
        model_path: path to store common model
        data_url: path to fetch data, usually in local
        train_url: path to save model after training
        subModelLearningRate: learning rate
        subModelIterateTimes: number of batches per local train
    """
    parser = argparse.ArgumentParser(description="Federated Learning",
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--type", type=str, default="training",
                        choices=["training", "evaluation"])
    parser.add_argument("--model_path", type=str, help="the path to store common model")
    parser.add_argument("--data_url", type=str, help="the path to fetch data")
    parser.add_argument("--train_url", type=str, help="the path to save model after training")
    parser.add_argument("--subModelLearningRate", type=float, default=0.01)
    parser.add_argument("--subModelIterateTimes", type=int, default=1)
    # parse_known_args() ignores any extra arguments the server may pass
    args, _ = parser.parse_known_args()
    return args
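As a minimal sketch of how the parsed values might be consumed, the block below rebuilds the same set of arguments and branches on `type`. The names `build_parser`, `train`, `evaluate`, and `main` are hypothetical and used only for illustration; the real training and evaluation logic is not described in this topic. Accepting an explicit `argv` list makes the behavior easy to try without the server.

```python
import argparse
from typing import List, Optional


def build_parser() -> argparse.ArgumentParser:
    # Mirrors the arguments accepted by parse_args(); hypothetical helper.
    parser = argparse.ArgumentParser(description="Federated Learning")
    parser.add_argument("--type", type=str, default="training",
                        choices=["training", "evaluation"])
    parser.add_argument("--model_path", type=str)
    parser.add_argument("--data_url", type=str)
    parser.add_argument("--train_url", type=str)
    parser.add_argument("--subModelLearningRate", type=float, default=0.01)
    parser.add_argument("--subModelIterateTimes", type=int, default=1)
    return parser


def train(args: argparse.Namespace) -> str:
    # Hypothetical placeholder for local training on data_url.
    return f"train lr={args.subModelLearningRate} iters={args.subModelIterateTimes}"


def evaluate(args: argparse.Namespace) -> str:
    # Hypothetical placeholder for evaluating the common model.
    return f"evaluate model={args.model_path}"


def main(argv: Optional[List[str]] = None) -> str:
    # parse_known_args() collects undeclared flags instead of raising an
    # error, so extra server-side parameters do not break the script.
    # argv=None falls back to sys.argv[1:].
    args, _ = build_parser().parse_known_args(argv)
    if args.type == "training":
        return train(args)
    return evaluate(args)
```

For example, `main(["--type", "evaluation", "--model_path", "/tmp/m"])` returns `"evaluate model=/tmp/m"`, and an unrecognized flag such as `--extra_flag x` is skipped rather than rejected because of `parse_known_args()`.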
