Updated: 2024-12-30 GMT+08:00

Ascend Migration: Fused-Operator API Replacement Examples

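Some native torch APIs are dispatched and executed as several small operators, which makes both dispatch and execution slow. Replacing them with NPU APIs enables fused operators and improves training performance.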

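Optimizer replacement

Replacing the optimizer usually yields a sizable performance gain, so consider first replacing the native torch optimizer with the NPU-affinity optimizer provided by Ascend. The following uses the AdamW optimizer as an example; other optimizers are replaced in the same way.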

  • torch_npu.optim.NpuFusedAdamW

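The native torch code example is as follows:

import torch
# AdamW has no momentum argument; its moment coefficients are passed as betas.
optimizer = torch.optim.AdamW(
  model.parameters(),
  lr=learning_rate,
  betas=betas,
  weight_decay=weight_decay
)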

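The torch_npu code example is as follows:

import torch
import torch_npu
from torch_npu.contrib import transfer_to_npu

# Same arguments as the native AdamW call; only the optimizer class changes.
optimizer = torch_npu.optim.NpuFusedAdamW(
  model.parameters(),
  lr=learning_rate,
  betas=betas,
  weight_decay=weight_decay
)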

Affinity API replacement

  • optimizer.clip_grad_norm_fused_

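    Before switching to the NPU-affinity gradient clipping API, make sure the code already uses an NPU-affinity optimizer.

    The native torch code example is as follows: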
    import torch
    optimizer = torch.optim.AdamW(model.parameters(), lr = lr)
    torch.nn.utils.clip_grad_norm_(parameters=model.parameters(), max_norm=10, norm_type=2)

    The torch_npu code example is as follows:

    import torch
    import torch_npu
    from torch_npu.contrib import transfer_to_npu
    
    optimizer = torch_npu.optim.NpuFusedAdamW(model.parameters(), lr = lr)
    optimizer.clip_grad_norm_fused_(max_norm=10, norm_type=2)
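
For orientation, clipping by global norm computes a single norm over all gradients and rescales every gradient when that norm exceeds max_norm. The pure-Python sketch below shows that arithmetic on plain lists. `clip_by_global_norm` is a hypothetical helper, not part of torch or torch_npu, and the small `eps` follows the behavior of torch.nn.utils.clip_grad_norm_; that the fused NPU call matches it exactly is an assumption.

```python
import math

def clip_by_global_norm(grads, max_norm, eps=1e-6):
    """Scale all gradients so their combined L2 norm is at most max_norm."""
    # One norm over every element of every gradient, not one norm per tensor.
    total_norm = math.sqrt(sum(g * g for grad in grads for g in grad))
    # The scale never exceeds 1, so small gradients are left unchanged.
    scale = min(1.0, max_norm / (total_norm + eps))
    clipped = [[g * scale for g in grad] for grad in grads]
    return clipped, total_norm

grads = [[3.0, 4.0], [12.0]]  # global L2 norm is 13
clipped, total_norm = clip_by_global_norm(grads, max_norm=10)
```

Because the norm is taken across all parameters at once, the relative direction of the update is preserved; only its length is capped.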
  • torch_npu.npu_confusion_transpose

    Example 1

    The native torch code example is as follows:
    import torch
    
    data = torch.rand(64, 3, 64, 128).cuda()
    batch, channel, height, width = data.shape
    result = torch.permute(data, (0, 2, 1, 3)).reshape(height, batch, channel*width)

    The torch_npu code example is as follows:

    import torch
    import torch_npu
    from torch_npu.contrib import transfer_to_npu
    
    data = torch.rand(64, 3, 64, 128).cuda()
    batch, channel, height, width = data.shape
    result = torch_npu.npu_confusion_transpose(data, (0, 2, 1, 3), (height, batch, channel*width), transpose_first=True)

    Example 2

    The native torch code example is as follows:
    import torch
    
    data = torch.rand(64, 3, 64, 128).cuda()
    batch, channel, height, width = data.shape
    result = data.view(batch, height*channel*width).transpose(1, 0)

    The torch_npu code example is as follows:

    import torch
    import torch_npu
    from torch_npu.contrib import transfer_to_npu
    
    data = torch.rand(64, 3, 64, 128).cuda()
    batch, channel, height, width = data.shape
    result = torch_npu.npu_confusion_transpose(data, (1, 0), (batch, height*channel*width), transpose_first=False)
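
The two examples differ only in the order of the permute and the reshape, which transpose_first selects. The shape arithmetic can be sketched in plain Python as below. `confusion_transpose_shape` is a hypothetical helper written for illustration, and it assumes the fused call is shape-equivalent to the native permute/reshape pairs shown in the two examples.

```python
from math import prod

def confusion_transpose_shape(in_shape, perm, shape, transpose_first):
    """Return the output shape of a fused permute/reshape pair (illustrative only)."""
    if transpose_first:
        # Permute the input first, then reshape the result to `shape`.
        permuted = tuple(in_shape[i] for i in perm)
        if prod(permuted) != prod(shape):
            raise ValueError("reshape target does not match the element count")
        return tuple(shape)
    # Reshape the input to `shape` first, then permute the reshaped tensor.
    if prod(in_shape) != prod(shape):
        raise ValueError("reshape target does not match the element count")
    return tuple(shape[i] for i in perm)

# Example 1: permute (0, 2, 1, 3), then reshape to (height, batch, channel*width).
example_one = confusion_transpose_shape((64, 3, 64, 128), (0, 2, 1, 3), (64, 64, 3 * 128), True)
# Example 2: view as (batch, height*channel*width), then transpose(1, 0).
example_two = confusion_transpose_shape((64, 3, 64, 128), (1, 0), (64, 64 * 3 * 128), False)
```

Note that with transpose_first=False the permutation indexes the reshaped dimensions, so perm has as many entries as shape, not as the input.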
  • torch_npu.npu_scaled_masked_softmax

    Note that the last dimension of the atten_mask and atten_scores tensors must be between 32 and 8192 and must be a multiple of 32.

    The native torch code example is as follows:
    import torch
    x = torch.randn([64, 8, 128, 256]).cuda()
    mask = torch.randn([1, 1, 128, 256]).cuda() >= 1
    scale = 0.8
    
    output = torch.softmax((x * scale).masked_fill(mask, -1*torch.inf), dim=-1)
    # shape is (64, 8, 128, 256)

    The torch_npu code example is as follows:

    import torch
    import torch_npu
    from torch_npu.contrib import transfer_to_npu
    
    x = torch.randn([64, 8, 128, 256]).cuda()
    mask = torch.randn([1, 1, 128, 256]).cuda() >= 1
    scale = 0.8
    
    output = torch_npu.npu_scaled_masked_softmax(x, mask, scale)
    # shape is (64, 8, 128, 256)
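
The last-dimension constraint above can be checked before switching to the fused call, and the computation itself is easy to follow on a single row. The sketch below does both in plain Python. `last_dim_is_supported` and `scaled_masked_softmax_row` are hypothetical helpers, and the row function mirrors the native torch expression in this section rather than the internals of the NPU kernel.

```python
import math

def last_dim_is_supported(shape, low=32, high=8192, multiple=32):
    """Check the stated constraint: last dim in [32, 8192] and a multiple of 32."""
    last = shape[-1]
    return low <= last <= high and last % multiple == 0

def scaled_masked_softmax_row(scores, mask, scale):
    """Softmax over one row after scaling; positions where mask is True get probability 0."""
    # Assumes at least one position in the row is unmasked.
    scaled = [float("-inf") if m else s * scale for s, m in zip(scores, mask)]
    peak = max(scaled)  # subtract the maximum for numerical stability
    exps = [math.exp(v - peak) for v in scaled]
    total = sum(exps)
    return [e / total for e in exps]

ok = last_dim_is_supported((64, 8, 128, 256))
row = scaled_masked_softmax_row([1.0, 2.0, 3.0], [False, True, False], scale=0.8)
```

A shape that fails the check can keep the native torch path.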
  • torch_npu.fast_gelu

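    Example 1

    This replaces torch.nn.functional.gelu. The two implementations differ, so the activation outputs will not be identical.

    The native torch code example is as follows: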
    import torch
    input_data = torch.rand(64, 32).cuda()
    result = torch.nn.functional.gelu(input_data)

    The torch_npu code example is as follows:

    import torch
    import torch_npu
    from torch_npu.contrib import transfer_to_npu
    
    input_data = torch.rand(64, 32).cuda()
    result = torch_npu.fast_gelu(input_data)

    Example 2

    Subclass torch.nn.GELU and override its forward method with torch_npu.fast_gelu.

    The native torch code example is as follows:
    import torch
    input_data = torch.rand(64, 32).cuda()
    gelu_module = torch.nn.GELU().cuda()
    result = gelu_module(input_data)

    The torch_npu code example is as follows:

    import torch
    import torch_npu
    from torch_npu.contrib import transfer_to_npu
    
    # Subclass torch.nn.GELU and override forward with torch_npu.fast_gelu
    class FastGelu(torch.nn.GELU):
        def forward(self, input_data):
            return torch_npu.fast_gelu(input_data)
    
    input_data = torch.rand(64, 32).cuda()
    fast_gelu_module = FastGelu().cuda()
    result = fast_gelu_module(input_data)
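
To get a feel for how large the output difference between an exact GELU and a fast approximation can be, the sketch below compares the erf-based GELU with the widely used sigmoid approximation x * sigmoid(1.702 * x). That torch_npu.fast_gelu uses exactly this formula is an assumption made for illustration; the function names are hypothetical.

```python
import math

def gelu_exact(x):
    """Erf-based GELU, the default form of torch.nn.functional.gelu."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

def gelu_sigmoid_approx(x):
    """Sigmoid approximation of GELU (assumed here, not confirmed for torch_npu)."""
    return x / (1.0 + math.exp(-1.702 * x))

# Largest absolute gap between the two on a grid over [-4, 4] with step 0.1.
max_gap = max(
    abs(gelu_exact(v / 10) - gelu_sigmoid_approx(v / 10))
    for v in range(-40, 41)
)
```

A gap of this kind is why results should be compared with a tolerance rather than for exact equality after the replacement.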
  • torch_npu.npu_rms_norm

    The input dtype must be float16, bfloat16, or float.

    The native torch code example is as follows:
    import torch

    class TorchRMSNorm(torch.nn.Module):
        def __init__(self, dim: int, eps = 1e-6):
            super().__init__()
            self.eps = eps
            # Keep the weight a registered parameter; the whole module is moved to the device below.
            self.weight = torch.nn.Parameter(torch.ones(dim))

        def _norm(self, x):
            return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

        def forward(self, x):
            output = self._norm(x.float()).type_as(x)
            return output * self.weight

    input_data = torch.randn(128, 256).cuda()
    # dim is the size of the last (normalized) dimension.
    torch_rms_norm = TorchRMSNorm(256).cuda()
    result = torch_rms_norm(input_data)

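    The torch_npu code example is as follows:

    import torch
    import torch_npu
    from torch_npu.contrib import transfer_to_npu

    class NpuRMSNorm(torch.nn.Module):
        def __init__(self, dim: int, eps = 1e-6):
            super().__init__()
            self.eps = eps
            # Keep the weight a registered parameter; the whole module is moved to the device below.
            self.weight = torch.nn.Parameter(torch.ones(dim))

        def forward(self, x):
            # npu_rms_norm returns a tuple; the first element is the normalized output.
            return torch_npu.npu_rms_norm(x, self.weight, epsilon=self.eps)[0]

    input_data = torch.randn(128, 256).cuda()
    # dim is the size of the last (normalized) dimension.
    npu_rms_norm = NpuRMSNorm(256).cuda()
    result = npu_rms_norm(input_data)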
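
As a small worked example of what the native _norm and weight multiplication compute, the plain-Python sketch below applies RMS normalization to one row. `rms_norm_row` is a hypothetical helper that mirrors the native torch code in this section; it says nothing about how the NPU kernel is implemented.

```python
import math

def rms_norm_row(x, weight, eps=1e-6):
    """Divide a row by its root mean square, then scale elementwise by weight."""
    mean_sq = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    return [v * inv_rms * w for v, w in zip(x, weight)]

# Mean of squares is (9 + 16) / 2 = 12.5, so each value is divided by sqrt(12.5).
row = rms_norm_row([3.0, 4.0], [1.0, 1.0])
```

With a weight of all ones, the output row has a mean square of approximately 1, which is the property the normalization is meant to provide.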
  • torch_npu.npu_swiglu

    The input dtype must be float16, bfloat16, or float.

    The native torch code example is as follows:
    import torch
    class TorchSwiGlu(torch.nn.Module):
        def __init__(self, dim = -1):
            super().__init__()
            self.dim = dim

        def _swiglu(self, x):
            # Split into two halves along self.dim: the first is gated by SiLU, the second is the multiplier.
            x = torch.chunk(x, 2, self.dim)
            return torch.nn.functional.silu(x[0]) * x[1]

        def forward(self, x):
            output = self._swiglu(x)
            return output

    input_data = torch.randn(128, 256).cuda()
    torch_swiglu = TorchSwiGlu()
    result = torch_swiglu(input_data)

    The torch_npu code example is as follows:

    import torch
    import torch_npu
    from torch_npu.contrib import transfer_to_npu

    class NpuSwiGlu(torch.nn.Module):
        def __init__(self, dim = -1):
            super().__init__()
            self.dim = dim

        def forward(self, x):
            # Use the dimension configured in the constructor.
            return torch_npu.npu_swiglu(x, dim=self.dim)

    input_data = torch.randn(128, 256).cuda()
    npu_swiglu = NpuSwiGlu()
    result = npu_swiglu(input_data)
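
The native SwiGLU above splits the input in half along the chosen dimension and multiplies SiLU of the first half by the second half, so the output is half as wide as the input. A plain-Python sketch of that computation on one row follows. `silu` and `swiglu_row` are hypothetical helpers mirroring the native torch code, not the NPU kernel.

```python
import math

def silu(v):
    """SiLU activation: v * sigmoid(v)."""
    return v / (1.0 + math.exp(-v))

def swiglu_row(x):
    """Split a row of even length into halves and return silu(first) * second."""
    half = len(x) // 2
    gate, value = x[:half], x[half:]
    return [silu(g) * u for g, u in zip(gate, value)]

# The first half [0.0, 1.0] is the gate; the second half [5.0, 7.0] is the multiplier.
out = swiglu_row([0.0, 1.0, 5.0, 7.0])
```

For the (128, 256) input used in this section, the same rule gives an output of shape (128, 128).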
  • torch_npu.npu_rotary_mul
    The native torch code example is as follows:
    import torch
    
    x = torch.rand([2, 8192, 5, 128]).cuda()
    r1 = torch.rand([1, 8192, 1, 128]).cuda()
    r2 = torch.rand([1, 8192, 1, 128]).cuda()
    
    def torch_func(x, r1, r2):
       x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
       # x1, x2 = torch.chunk(x, 2, -1)
       x_new = torch.cat((-x2, x1), dim=-1)
       output = r1 * x + r2 * x_new
       return output
    
    result = torch_func(x, r1, r2)

    The torch_npu code example is as follows:

    import torch
    import torch_npu
    from torch_npu.contrib import transfer_to_npu
    
    x = torch.rand([2, 8192, 5, 128]).cuda()
    r1 = torch.rand([1, 8192, 1, 128]).cuda()
    r2 = torch.rand([1, 8192, 1, 128]).cuda()
    
    result = torch_npu.npu_rotary_mul(x, r1, r2)
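
The native torch_func above computes r1 * x + r2 * rotate_half(x), where rotate_half moves the negated second half of the last dimension in front of the first half. The sketch below reproduces this on a single vector in plain Python. `rotate_half` and `rotary_mul_row` are hypothetical names, and taking r1 and r2 as the cosine and sine of one angle is an assumption about typical usage, not something the section states.

```python
import math

def rotate_half(x):
    """Return (-x2, x1), where x1 and x2 are the first and second halves of x."""
    half = len(x) // 2
    return [-v for v in x[half:]] + x[:half]

def rotary_mul_row(x, r1, r2):
    """Elementwise r1 * x + r2 * rotate_half(x), as in the native torch_func."""
    x_new = rotate_half(x)
    return [a * c + b * s for a, b, c, s in zip(x, x_new, r1, r2)]

theta = 0.3  # illustrative angle shared by all positions
x = [1.0, 2.0, 3.0, 4.0]
out = rotary_mul_row(x, [math.cos(theta)] * 4, [math.sin(theta)] * 4)
```

With cosine and sine of the same angle, each pair (x[i], x[i + half]) is rotated in its plane, so the length of the vector does not change.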
  • torch_npu.npu_fusion_attention
    The native torch code example is as follows:
    import torch

    class TorchFlashAttention():
        def supported_op_exec(self, query, key, value, atten_mask=None):
            scale = 0.099
            # Scaled dot-product scores in BNSD layout: (batch, heads, seq, seq).
            qk = torch.matmul(query, key.transpose(2, 3)).mul(scale)

            if atten_mask is not None:
                # Positions where the mask is True are excluded from attention.
                qk.masked_fill_(atten_mask, float('-inf'))
            softmax_res = torch.nn.functional.softmax(qk, dim=-1, dtype=torch.float32).to(torch.float16)
            output = torch.matmul(softmax_res, value)
            # Convert BNSD to BSH so the result matches the fused operator's layout.
            output = output.transpose(1, 2)
            output = output.reshape(output.shape[0], output.shape[1], -1)
            return output

        def test_torch_flash_attention(self):
            query = torch.randn(1, 32, 128, 128, dtype=torch.float16).cuda()
            key = torch.randn(1, 32, 128, 128, dtype=torch.float16).cuda()
            value = torch.randn(1, 32, 128, 128, dtype=torch.float16).cuda()
            atten_mask = torch.randn(1, 1, 128, 128, dtype=torch.float16).cuda() >= 0

            result = self.supported_op_exec(query, key, value, atten_mask=atten_mask)
            # result shape (1, 128, 4096)
            return result

    The torch_npu code example is as follows:

    import torch
    import torch_npu
    from torch_npu.contrib import transfer_to_npu
    
    
    class NPUFlashAttention():
    
        def npu_exec(self, query, key, value, atten_mask=None):
            scale = 0.099
            return torch_npu.npu_fusion_attention(
                query, key, value, head_num=32, input_layout="BSH", scale=scale, atten_mask=atten_mask)
    
        def trans_BNSD2BSH(self, tensor: torch.Tensor):
            tensor = torch.transpose(tensor, 1, 2)
            tensor = torch.reshape(tensor, (tensor.shape[0], tensor.shape[1], -1))
            return tensor
    
        def test_npu_flash_attention(self, device="npu"):
            query = torch.randn(1, 32, 128, 128, dtype=torch.float16)
            key = torch.randn(1, 32, 128, 128, dtype=torch.float16)
            value = torch.randn(1, 32, 128, 128, dtype=torch.float16)
            atten_mask = torch.randn(1, 1, 128, 128, dtype=torch.float16).npu() >= 0
    
            q_npu = self.trans_BNSD2BSH(query).npu()
            k_npu = self.trans_BNSD2BSH(key).npu()
            v_npu = self.trans_BNSD2BSH(value).npu()
    
            result, softmax_max, softmax_sum, softmax_out, seed, offset, numels = self.npu_exec(q_npu, k_npu, v_npu, atten_mask)
            # result shape (1, 128, 4096)
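
The layout conversion done by trans_BNSD2BSH, and the relation between head_num and the hidden size, can be checked with simple shape arithmetic. The sketch below uses two hypothetical helpers, `bnsd_to_bsh` and `head_dim_for_bsh`. It assumes that in the "BSH" layout the hidden size H equals head_num times the per-head dimension, which is what the transpose and reshape in the examples produce.

```python
def bnsd_to_bsh(shape):
    """Map a (batch, heads, seq_len, head_dim) shape to (batch, seq_len, heads * head_dim)."""
    batch, heads, seq_len, head_dim = shape
    return (batch, seq_len, heads * head_dim)

def head_dim_for_bsh(bsh_shape, head_num):
    """Recover the per-head dimension from a BSH shape, or fail if H is not divisible."""
    hidden = bsh_shape[-1]
    if hidden % head_num != 0:
        raise ValueError("hidden size must be divisible by head_num")
    return hidden // head_num

# Shapes used in the examples above: query, key, and value are (1, 32, 128, 128).
bsh = bnsd_to_bsh((1, 32, 128, 128))
head_dim = head_dim_for_bsh(bsh, head_num=32)
```

This is why both examples report a result shape of (1, 128, 4096): 32 heads of size 128 are merged into one hidden dimension.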
