温馨提示×

温馨提示×

您好,登录后才能下订单哦!

密码登录×
登录注册×
其他方式登录
点击 登录注册 即表示同意《亿速云用户服务条款》

如何解析Pytorch基础中网络参数初始化问题

发布时间:2021-12-04 18:35:07 来源:亿速云 阅读:206 作者:柒染 栏目:大数据

如何解析Pytorch基础中网络参数初始化问题,很多新手对此不是很清楚,为了帮助大家解决这个难题,下面小编将为大家详细讲解,有这方面需求的人可以来学习下,希望你能有所收获。

参数访问和遍历:

对于模型参数,我们可以进行访问;

由于Sequential由Module继承而来,所以可以使用Module钟的parameter()或者named_parameters方法来访问所有的参数;

例如,对于使用Sequential搭建的网络,可以使用下列for循环直接进行遍历:

for name, param in net.named_parameters():
    print(name, param.size())

当然,也可以使用索引来按层访问,因为本身网络也是按层搭建的:

for name, param in net[0].named_parameters():
    print(name, param.size(), type(param))

当我们获取某一层的参数信息后,可以使用data()和grad()函数来进行值和梯度的访问:

weight_0 = list(net[0].parameters())[0]
print(weight_0.data)
print(weight_0.grad) # 反向传播前梯度为None
Y.backward()
print(weight_0.grad)

参数初始化问题:

当我们参用for循环获取每层参数,可以采用如下形式对w和偏置b进行初值设定:

for name, param in net.named_parameters():
    if 'weight' in name:
        init.normal_(param, mean=0, std=0.01)
        print(name, param.data)

for name, param in net.named_parameters():
    if 'bias' in name:
        init.constant_(param, val=0)
        print(name, param.data)

当然,我们也可以进行初始化函数的自定义设置:

def init_weight_(tensor):
    with torch.no_grad():
        tensor.uniform_(-10, 10)
        tensor *= (tensor.abs() >= 5).float()

for name, param in net.named_parameters():
    if 'weight' in name:
        init_weight_(param)
        print(name, param.data)

这里注意一下torch.no_grad()的问题;

该形式表示该参数并不随着backward进行更改,常常用来进行局部网络参数固定的情况;

如该连接所示:关于no_grad()

共享参数:

可以自定义Module类,在forward中多次调用同一个层实现;

如上章节的代码所示:

class FancyMLP(nn.Module):
    def __init__(self, **kwargs):
        super(FancyMLP, self).__init__(**kwargs)
        self.rand_weight = torch.rand((20, 20), requires_grad=False) # 不可训练参数(常数参数)
        self.linear = nn.Linear(20, 20)
    def forward(self, x):
        x = self.linear(x)
        # 使用创建的常数参数,以及nn.functional中的relu函数和mm函数
        x = nn.functional.relu(torch.mm(x, self.rand_weight.data) + 1)
        # 复用全连接层。等价于两个全连接层共享参数
        x = self.linear(x)
        # 控制流,这里我们需要调用item函数来返回标量进行比较
        while x.norm().item() > 1:
            x /= 2
        if x.norm().item() < 0.8:
            x *= 10
        return x.sum()

所以可以看到,相当于同时在同一个网络中调用两次相同的Linear实例,所以变相实现了参数共享;

suo'yi注意一下,如果传入Sequential模块的多层都是同一个Module实例的话,则他们共享参数;

看完上述内容是否对您有帮助呢?如果还想对相关知识有进一步的了解或阅读更多相关文章,请关注亿速云行业资讯频道,感谢您对亿速云的支持。

向AI问一下细节

免责声明:本站发布的内容(图片、视频和文字)以原创、转载和分享为主,文章观点不代表本网站立场,如果涉及侵权请联系站长邮箱:is@yisu.com进行举报,并提供相关证据,一经查实,将立刻删除涉嫌侵权内容。

AI