Updated on 2023-05-05 GMT+08:00

Gradient Boosted Tree Classification

Overview

The Gradient Boosted Tree (GBT) Classification node generates a binary classification model. It is an iterative classification algorithm based on decision trees: the algorithm builds decision tree models iteratively, and each tree is constructed by optimizing the loss function along its gradient, moving the prediction step by step from a baseline value toward the target value. Intuitively, each new model corrects the cases that the previous model predicted incorrectly, so the model improves with every iteration and ultimately achieves good predictive performance.
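The additive structure behind this description can be sketched with the standard gradient-boosting update (general notation, not specific to this operator):

F_m(x) = F_{m-1}(x) + \eta \, h_m(x)

where F_m is the model after m iterations, h_m is the decision tree fitted at iteration m to the negative gradient of the loss, and \eta is the learning rate (the step_size parameter described below).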

The loss function of gradient boosted tree classification is the log-likelihood loss, shown below:

L(y, F) = 2 \sum_{i=1}^{N} \log\left(1 + \exp\left(-2 \, y_i F(x_i)\right)\right)

where N is the number of samples, x_i is the feature vector of sample i, y_i is the label of sample i, and F(x_i) is the label predicted for sample i.

Input

Parameter | Sub-parameter | Description
--------- | ------------- | -----------
inputs | dataframe | inputs is of dictionary type; dataframe is a DataFrame object in pyspark
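As an illustration of this input, here is a minimal sketch of building such a DataFrame with plain pyspark; the column names and values are hypothetical and not part of this operator's interface:

from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("gbt_classification_demo").getOrCreate()

# Hypothetical feature columns and a binary label column; replace with real data.
df = spark.createDataFrame(
    [(1.0, 2.0, 0), (2.0, 1.5, 1), (0.5, 3.0, 0), (3.0, 0.5, 1)],
    ["column_a", "column_b", "label"],
)

inputs = {"dataframe": df}  # matches the inputs dictionary expected by the node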

Output

A model of the Spark pipeline type (PipelineModel)

Parameter Description

Parameter | Sub-parameter | Description
--------- | ------------- | -----------
input_features_str | - | String made up of the input column names separated by commas, for example: "column_a" or "column_a,column_b"
label_col | - | Target column
classifier_label_index_col | - | New column name of the target column after label encoding; default: "label_index"
classifier_feature_vector_col | - | Name of the feature vector column the operator takes as input; default: "model_features"
prediction_index_col | - | Index column corresponding to the predicted label output by the operator; default: "prediction_index"
prediction_col | - | Name of the predicted label column output by the operator; default: "prediction"
max_depth | - | Maximum depth of each tree; default: 5
max_bins | - | Maximum number of bins; default: 32
min_instances_per_node | - | Minimum number of instances each child node must contain after a split; default: 1
min_info_gain | - | Minimum information gain required for a split; default: 0
max_iter | - | Maximum number of iterations; default: 20
step_size | - | Step size (learning rate); default: 0.1
subsampling_rate | - | Fraction of the training set sampled for training each tree; default: 1.0
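For orientation, these settings correspond closely to the constructor arguments of GBTClassifier in pyspark.ml. The sketch below shows that mapping under the assumption that the operator wraps Spark ML; it is an illustration, not the operator's internal code:

from pyspark.ml.classification import GBTClassifier

# Left: Spark ML keyword arguments. Right (comments): this operator's parameters.
gbt = GBTClassifier(
    labelCol="label_index",           # classifier_label_index_col
    featuresCol="model_features",     # classifier_feature_vector_col
    predictionCol="prediction_index", # prediction_index_col
    maxDepth=5,                       # max_depth
    maxBins=32,                       # max_bins
    minInstancesPerNode=1,            # min_instances_per_node
    minInfoGain=0.0,                  # min_info_gain
    lossType="logistic",              # loss_type (fixed to "logistic" in the example below)
    maxIter=20,                       # max_iter
    stepSize=0.1,                     # step_size
    subsamplingRate=1.0,              # subsampling_rate
)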

Example

inputs = {
    "dataframe": None  # @input {"label":"dataframe","type":"DataFrame"}
}
params = {
    "inputs": inputs,
    "b_output_action": True,
    "b_use_default_encoder": True,
    "input_features_str": "",  # @param {"label": "input_features_str", "type": "string", "required": "false", "helpTip": ""}
    "outer_pipeline_stages": None,
    "label_col": "",  # @param {"label": "label_col", "type": "string", "required": "true", "helpTip": "target label column"}
    "classifier_label_index_col": "label_index",  # @param {"label": "classifier_label_index_col", "type": "string", "required": "true", "helpTip": ""}
    "classifier_feature_vector_col": "model_features",  # @param {"label": "classifier_feature_vector_col", "type": "string", "required": "true", "helpTip": ""}
    "prediction_index_col": "prediction_index",  # @param {"label": "prediction_index_col", "type": "string", "required": "true", "helpTip": ""}
    "prediction_col": "prediction",  # @param {"label": "prediction_col", "type": "string", "required": "true", "helpTip": ""}
    "max_depth": 5,  # @param {"label": "max_depth", "type": "integer", "required": "true", "range": "(0,2147483647]", "helpTip": ""}
    "max_bins": 32,  # @param {"label": "max_bins", "type": "integer", "required": "true", "range": "(0,2147483647]", "helpTip": ""}
    "min_instances_per_node": 1,  # @param {"label": "min_instances_per_node", "type": "integer", "required": "true", "range":"(0,2147483647]", "helpTip": ""}
    "min_info_gain": 0.0,  # @param {"label": "min_info_gain", "type": "number", "required": "true", "range": "[0,none)", "helpTip": ""}
    "loss_type": "logistic",
    "max_iter": 20,  # @param {"label": "max_iter", "type": "integer", "required": "true", "range": "(0,2147483647]", "helpTip": ""}
    "step_size": 0.1,  # @param {"label": "step_size", "type": "number", "required": "true", "range": "(0,none)", "helpTip": ""}
    "subsampling_rate": 1.0  # @param {"label": "subsampling_rate", "type": "number", "required": "true", "range": "(0,1.0]", "helpTip": ""}
}
gbt_classifier____id___ = MLSGBTClassifier(**params)
gbt_classifier____id___.run()
# @output {"label":"pipeline_model","name":"gbt_classifier____id___.get_outputs()['output_port_1']","type":"PipelineModel"}
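
The object behind output_port_1 is a Spark PipelineModel, so new data can be scored with the usual Spark transform pattern. A minimal sketch, assuming scoring_df is a hypothetical DataFrame that contains the same input columns used for training:

# Retrieve the trained Spark PipelineModel and apply it to new data.
pipeline_model = gbt_classifier____id___.get_outputs()['output_port_1']
predictions = pipeline_model.transform(scoring_df)
predictions.select("prediction").show(5)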
