AI开发平台ModelArtsAI开发平台ModelArts

更新时间:2021/03/18 GMT+08:00
分享

强化学习

概述

该预置算法中,为用户提供了常用的强化学习算法,目前支持五种常用算法(DQN、PPO、A2C、IMPALA、APEX)。用户订阅之后,选择算法只需要设置对应的参数,即可方便地创建训练作业,开始训练相应的强化学习环境(内置环境或是自定义环境)。训练后保存的模型可直接在ModelArts平台部署在线服务进行推理。

单击此处订阅算法。

训练

  • 算法基本信息
    • 任务类型:强化学习
    • 支持的框架引擎:Tensorflow-1.13.1-python3.6&Ray-0.7.4
    • 内置及支持环境:classic_control, atari, toy_text
    • 算法输入:
      • 如果使用内置环境,设置相应环境参数即可,无需输入数据,选择一个空文件夹或位置即可。
      • 如果使用自定义的环境,输入为用户实现自定义环境的代码所在的文件夹路径信息,目录结构示例如下。

        训练输入目录

        |- __init__.py
        |- custom_env_name
          |- custom_env.py
          |- custom_model.py

        __init__.py控制从custom_env和custom_model中导入模块,如下所示。

        from custom_env.custom_env import create_custom_env
        from custom_env.custom_model import CustomNetwork
        __all__=["create_custom_env","CustomNetwork"]

        custom_env.py中为gym.Env类型的环境和create_custom_env函数,如下所示。

        class CusEnv(gym.Env):
            ...
        def create_custom_env(env_config):
            custom_env=CusEnv()
            return custom_env

        custom_model.py用来自定义模型,如下所示。

        class CustomNetwork(Model):
            ...
    • 算法输出:

      训练保存的模型、参数文件、checkponit等文件。

  • 训练参数

    名称

    默认值

    类型

    是否必填

    是否可修改

    描述

    use_preset_env

    TRUE

    bool

    是否选用内置环境。TRUE表示使用内置环境,FALSE表示不使用。

    preset_env_id

    CartPole-v0

    String

    所用内置环境ID,只有use_preset_env参数为TRUE时才生效。ModelArts支持的内置环境ID请参见表1

    use_custom_model

    FALSE

    bool

    是否使用自定义模型。用户自定义环境,同时需要使用自定义模型时,需要设置为TRUE。

    stop_criterion

    timesteps_total

    string

    训练停止准则,目前支持四种:timesteps_total、episodes_total、time_total_s、episode_reward_mean。

    stop_value

    100000

    int

    停止准则对应的次数。如果设置不同的停止准则,该参数也应设置为对应的合适值。请参见表2

    checkpoint_freq

    20

    int

    checkpoint保存的频率。

    rl_toolkit

    rllib

    string

    选取使用的框架,目前仅支持rllib。

    rl_algorithm

    dqn

    string

    选取训练算法,支持的算法有:dqn、ppo、impala、a2c、apex。

    data_url

    None

    string

    数据存储的位置,数据是指自定义环境代码文件。

    train_url

    None

    string

    训练结果输出的地址。

    log_url

    None

    string

    训练日志保存地址。

    num_gpus

    0

    int

    GPU数目。

    num_cpus

    1

    int

    CPU数目。

    redis_address

    localhost:6379

    string

    redis数据库地址信息。

    model_name

    None

    string

    自定义模型名字。

    gamma

    0.99

    float

    折扣率。

    lr

    0.00003

    float

    学习率。

    sample_batch_size

    64

    int

    采样样本大小。

    train_batch_size

    64

    int

    训练样本大小,如果修改默认值,需要保证训练样本的值大于等于采样样本的值。

    表1 支持的内置环境

    类型

    说明

    支持的内置环境ID

    Algorithms

    不同的算法任务

    Copy-v0、RepeatCopy-v0、ReversedAddition-v0、ReversedAddition3-v0、DuplicatedInput-v0、Reverse-v0

    Atari-Games

    不同的Atari电子游戏环境

    AirRaid-v0、Centipede-v0、Pitfall-v0、Venture-v0、CrazyClimber-v0、ElevatorAction-v0、Assault-v0、Pooyan-v0、Gravitar-v0、BankHeist-v0、Tennis-v0、Alien-v0、Atlantis-v0、DoubleDunk-v0、Adventure-v0、Solaris-v0、Bowling-v0、SpaceInvaders-v0、Boxing-v0、Robotank-v0、IceHockey-v0、KungFuMaster-v0、Freeway-v0、Krull-v0、Pong-v0、MsPacman-v0、Defender-v0、Phoenix-v0、Enduro-v0、WizardOfWor-v0、DemonAttack-v0、Riverraid-v0、Asterix-v0、Berzerk-v0、BeamRider-v0、Amidar-v0、Jamesbond-v0、MontezumaRevenge-v0、UpNDown-v0、JourneyEscape-v0、Skiing-v0、StarGunner-v0、VideoPinball-v0、BattleZone-v0、Tutankham-v0、RoadRunner-v0、Carnival-v0、Zaxxon-v0、Hero-v0、ChopperCommand-v0、NameThisGame-v0、Asteroids-v0、Qbert-v0、Frostbite-v0、YarsRevenge-v0、PrivateEye-v0、Seaquest-v0、TimePilot-v0、Gopher-v0、FishingDerby-v0、Breakout-v0、Kangaroo-v0

    Classic control

    经典强化学习文献中不同的控制理论任务

    CartPole-v0、CartPole-v1、MountainCar-v0、MountainCarContinuous-v0、Pendulum-v0、Acrobot-v1、LunarLander-v2、LunarLanderContinuous-v2、BipedalWalker-v3、BipedalWalkerHardcore-v3、CarRacing-v0

    Toy text

    简单地基于文本的玩具环境

    Blackjack-v0、FrozenLake-v0、FrozenLake8x8-v0、NChain-v0、Roulette-v0、Taxi-v3、GuessingGame-v0、HotterColder-v0

    表2 stop_criterion及stop_value参数设置说明

    stop_criterion参数

    含义

    设置stop_value说明

    timesteps_total

    总的时间步

    默认1000次环境交互调用优化器更新一次参数,该参数设置总的时间步,建议设置1000的整数倍。

    episodes_total

    总的训练episode

    训练周期数和参数更新次数不对应,更新参数以时间步为基准,不同环境,在相同时间步对应episode也不相同。

    time_total_s

    总时长

    单位秒,每一次迭代对应一定的时长,该参数设置总的训练时长。

    episode_reward_mean

    每个episode获得平均奖励

    对于不同环境,获得奖励的数值、难易程度、取值范围等都不同,需要根据具体环境设置合适的stop_value。

  • 训练输出文件

    训练完成后的输出文件如下

    |- Algorithm_Env_Name
      |- checkpoint_n
        |- checkpoint-n.tune_mettadata
        |- checkpoint-n
        |- params.okl
      |- params.pkl
      |- params.json
      |- progress.csv
      |- result.json
    
    |- model
      |- variables
        |- variables.data-00000-of-00001
        |- variables.index
      |- customize_service.py
      |- config.json
      |- saved_model.pb
    |- experiment_state.json

推理(目前仅支持在线服务)

  • 模型导入:“元模型来源”选择“从训练中选择”,选择训练作业及版本,系统会自动读取对应路径下的Python文件信息。
  • 推理输入:模型推理时,输入需要满足一定的格式要求,需要是JSON格式体。不同的算法,所需要的参数也不相同,使用时要结合具体情况。

    支持的五种算法不同的格式输入示例。

    • 对于不同的环境,其环境状态不同,对应的“observations”shape维度就不同,应用推理时需要明确,根据具体情况设置输入,下面示例中给出的是“CartPole-v0”环境对应的“observations”数据。
    • 在线部署推理仅为单步预测,如果希望实现完整的游戏推理,需要使用训练保存的模型,编写相应的推理程序。
    • DQN算法:
      {
      "seq_lens": [0],
      "is_training": "False",
      "observations": [[0.02160617659181513, 0.022702596137138, -0.026817784180154104, 0.04523499832683836]]
      }
    • PPO算法:
      {
      "is_training": "False",
      "observations": [[0.02160617659181513, 0.022702596137138, -0.026817784180154104, 0.04523499832683836]],
      "prev_action": [0],
      "prev_reward": [0.0],
      "seq_lens": [0]
      }
    • IMPALA算法:
      {
      "is_training": "False",
      "observations": [[0.02160617659181513, 0.022702596137138, -0.026817784180154104, 0.04523499832683836]],
      "prev_action": [0],
      "prev_reward": [0.0],
      "seq_lens": [0]
      }
    • A2C算法:
      {
      "is_training": "False",
      "observations": [[0.02160617659181513, 0.022702596137138, -0.026817784180154104, 0.04523499832683836]],
      "prev_action": [0],
      "prev_reward": [0.0],
      "seq_lens": [0]
      }
    • APEX算法:
      {
      "seq_lens": [0],
      "is_training": "False",
      "observations": [[0.02160617659181513, 0.022702596137138, -0.026817784180154104, 0.04523499832683836]]
      }
  • 推理输出

    推理输出内容示例如下所示。

    • DQN and APEX算法:
      {
      "actions": [1],
      "q_values": [0.1, 0.9]
      }
    • PPO and A2C算法:
      {
      "actions": [1],
      "action_prob": [0.50440347194],
      "action_logp": [-0.6843788027],
      "vf_preds": [0.005004595965]
      }
    • IMPALA算法:
      {
      "actions": [1],
      "action_prob": [0.50440347194],
      "action_logp": [-0.6843788027],
      "behaviour_logits": [0.005004595965, 0.0411932617]
      }

案例

分享:

    相关文档

    相关产品