更新时间:2021-03-18 GMT+08:00
分享

vec_axpy

功能说明

矢量每个element与标量求积后累加:

函数原型

vec_axpy(mask, dst, src, scalar, repeat_times, dst_rep_stride, src_rep_stride)

参数说明

请参见参数说明

dst/src/scalar支持的数据类型为:Tensor(float16,float32) ,且src和scalar操作数的类型需要保持一致 。

该接口支持的精度组合如下:

表1 支持的精度组合

类型

src.dtype

scalar.dtype

dst.dtype

并行度PAR/repeat

fp16

float16

float16

float16

128

fp32

float32

float32

float32

64

fmix

float16

float16

float32

64

返回值

注意事项

  • 请参见注意事项
  • 注意存在混合精度fmix的支持。
  • fmix模式下,src每次迭代仅选取前4个block参与计算。

调用示例

from te import tik
tik_instance = tik.Tik()
src_ub = tik_instance.Tensor("float16", (128,), name="src_ub", scope=tik.scope_ubuf)
scalar = 2
dst_ub = tik_instance.Tensor("float32", (128,), name="dst_ub", scope=tik.scope_ubuf)
tik_instance.vec_axpy(64, dst_ub, src_ub, scalar, 1, 8, 4)
分享:

    相关文档

    相关产品

close