更新时间:2023-11-16 GMT+08:00
分享

使用PyTorch进行线性回归

在FunctionGraph控制台中，将torch添加为函数的公共依赖，如图1所示。

图1 torch添加为公共依赖

在函数代码中导入torch并使用，示例代码如下（手动计算梯度，实现简单的线性回归）：

# -*- coding:utf-8 -*-
import json
# 导入torch依赖
import torch as t
import numpy as np
def handler(event, context):
    """FunctionGraph entry point: run the demo training, then echo the event.

    Args:
        event: Invocation event; it is serialised with ``json.dumps`` and
            returned as the response body.
        context: Runtime context passed in by the platform (unused here).

    Returns:
        A dict with ``statusCode`` 200, ``isBase64Encoded`` False, the
        JSON-encoded event as ``body`` and a JSON ``Content-Type`` header.
    """
    print("start training!")
    train()
    print("finished!")
    headers = {"Content-Type": "application/json"}
    response = {
        "statusCode": 200,
        "isBase64Encoded": False,
        "body": json.dumps(event),
        "headers": headers,
    }
    return response
 
 
def get_fake_data(batch_size=8):
    """Generate a random batch of noisy samples from the line ``y = 2x + 3``.

    ``x`` is drawn uniformly from ``[0, 20)``. ``y`` equals ``2 * x + 3``
    plus Gaussian noise with standard deviation 3, since
    ``(1 + randn) * 3 == 3 + 3 * randn``.

    Args:
        batch_size: Number of samples (rows) to generate.

    Returns:
        Tuple ``(x, y)`` of float tensors, each of shape ``(batch_size, 1)``.
    """
    shape = (batch_size, 1)
    # Draw the inputs first and the noise second, so the RNG stream is
    # consumed in a fixed order.
    inputs = 20 * t.rand(*shape)
    noise = t.randn(*shape)
    targets = 2 * inputs + 3 * (1 + noise)
    return inputs, targets
 
def train(iterations=2000, lr=0.001, seed=1000, data_fn=None, log_every=10):
    """Fit ``y = w * x + b`` by plain gradient descent with hand-written gradients.

    No autograd is used. The gradient of the loss
    ``0.5 * sum((y_pred - y) ** 2)`` is computed explicitly and applied
    in place to ``w`` and ``b``.

    Args:
        iterations: Number of gradient-descent steps.
        lr: Learning rate (step size).
        seed: Value passed to ``torch.manual_seed`` before ``w`` is initialised.
        data_fn: Zero-argument callable that returns a fresh ``(x, y)`` batch
            on every step, both float tensors of shape ``(batch, 1)``.
            Defaults to ``get_fake_data``.
        log_every: Print the current ``w`` and ``b`` every this many steps.
            ``0`` or ``None`` disables logging.

    Returns:
        Tuple ``(w, b)`` of the fitted ``(1, 1)`` parameter tensors.
    """
    if data_fn is None:
        data_fn = get_fake_data
    t.manual_seed(seed)

    w = t.rand(1, 1)
    b = t.zeros(1, 1)

    for ii in range(iterations):
        x, y = data_fn()
        y_pred = x.mm(w) + b.expand_as(y)

        # For loss = 0.5 * sum((y_pred - y) ** 2):
        # dloss/dy_pred = y_pred - y, then apply the chain rule to w and b.
        dy_pred = y_pred - y
        dw = x.t().mm(dy_pred)
        db = dy_pred.sum()
        w.sub_(lr * dw)
        b.sub_(lr * db)

        if log_every and ii % log_every == 0:
            print("w=", w.item(), "b=", b.item())

    return w, b
分享:

    相关文档

    相关产品