文档首页 > > AI Gallery用户指南> 预置AI算法(官方发布)>

时序预测

时序预测

分享
更新时间:2020/12/14 GMT+08:00

概述

TS_Forecast是基于经典的机器学习和深度学习模型建立的可预测一段时间内的序列的框架。用户提供时间序列数据集,选择相应的机器学习或者深度学习模型,可获得预测精度和预测结果,预测结果以csv提供,并提供tensorboard可视化。深度学习方法同时支持CPU和GPU规格进行训练。

训练

  • 算法基本信息
    • 适用场景:时间序列预测
    • 支持的框架引擎:PyTorch-1.3.0-python3.6
    • 算法输入:

      序列类型的数据集,建议以train_per:valid_per=0.8:0.2或0.9:0.1的比例进行切分,其中80%(90%)的数据训练,20%(10%)的数据验证。内置的机器学习方法只支持单维度的时间序列,深度学习方法可支持多维度的时间序列。

      需要预测的长度。
      训练输入目录
        |- 时间序列数据.csv
      
      时间序列数据.csv的格式
      | 时间列  | 预测数据1  | 预测数据2   | 预测数据n  |
      | ------ | --------- | --------- | --------- |
      | ------ | --------- | --------- | --------- |
    • 算法输出:

      输出从验证数据后开始的一段时间内预测结果,预测结果以csv的格式存储。例如:如果是“0.8:0.2”,则从(0.8+0.2)的数据后开始预测,如果是“0.6:0.2”,则从(0.6+0.2)的数据后开始预测。

  • 训练参数

    名称

    默认值

    类型

    是否必填

    是否可修改

    描述

    mode

    lstm

    string

    模型类型。支持rnn、lstm、gru、lstnet、attention、arima、holt_winters。

    不同模型类型的特定参数请参见:

    forecast_period

    1

    int

    预测长度。

    date_frequency

    D

    string

    日期频率,可设置为年/季度/月/周/日/时/分/秒/毫秒/微秒。

    render

    true

    bool

    是否可视化预测结果曲线。

    train_per

    0.8

    float

    训练数据占比。

    valid_per

    0.2

    float

    验证数据占比。

    nan_deal

    zero

    string

    缺失数据填补方式,仅在数据有缺失时有意义。

    epochs

    10

    int

    迭代次数,仅在深度学习模型时有意义。

    step

    true

    bool

    是否单步预测,仅在深度学习模型时有意义。

    interval

    0.8

    float

    置信区间间隔,holt_winters不支持该参数。

    gpu

    1

    int

    GPU数目,仅在深度学习模型时有意义。

    表1 arima特定参数

    名称

    默认值

    类型

    是否必填

    描述

    order

    (0, 0, 0)

    tuple

    括号中的值分别填写自回归AR项、差分系数、移动平均MA项。数值之间请使用英文逗号隔开。

    seasonal_order

    (1, 1, 0, 12)

    tuple

    括号中的值分别填写自回归AR项、差分系数、移动平均MA项、季节性。数值之间请使用英文逗号隔开。

    auto

    false

    bool

    是否自动确定order和seasonal_order参数,不需要手工指定。

    表2 holt-winters特定参数

    名称

    默认值

    类型

    是否必填

    描述

    trend

    add

    string

    趋势项,支持add, mul、additive、multiplicative、None。

    seasonal

    add

    string

    季节项,支持add、mul、additive、multiplicative、None。

    seasonal_periods

    12

    int

    季节性。

    表3 rnn、lstm、gru、lstnet、attention默认参数

    名称

    默认值

    类型

    是否必填

    描述

    window

    168

    int

    滑动窗口大小。

    hidden_rnn

    100

    int

    隐藏层单元数目。

    n_layers

    1

    int

    网络层数。

    optim

    adam

    string

    优化器,支持adam、sgd。

    lr

    0.0001

    float

    学习率。

    batch_size

    128

    int

    批次大小。

    dropout

    0.1

    float

    dropout数值。

    loss

    mae

    string

    损失函数类型,支持mae、rmse。

    lstnet详细说明请参见论文地址

    表4 lstnet特定参数

    名称

    默认值

    类型

    是否必填

    描述

    hidden_cnn

    100

    int

    cnn单元数目。

    hidden_skip

    10

    int

    recurrent-skip项的输出单元数目。

    skip

    24

    int

    recurrent-skip项。

    cnn_kernel

    6

    int

    cnn-kernel的大小。

    highway_window

    24

    int

    highway网络的窗口大小。

    attention详细说明请参见论文地址

    表5 attention特定参数

    名称

    默认值

    类型

    是否必填

    描述

    k_dim

    64

    int

    false

    key维度。

    v_dim

    64

    int

    false

    value维度。

    n_head

    8

    int

    false

    head数目。

  • 训练输出文件

    训练完成后的输出文件如下

    训练输出目录
      |- model.pt
      |- forecast.csv
      |- events.out.tfevents.xxx

GPU/CPU推理

暂不支持

分享:

    相关文档

    相关产品