温馨提示×

温馨提示×

您好,登录后才能下订单哦!

密码登录×
登录注册×
其他方式登录
点击 登录注册 即表示同意《亿速云用户服务条款》

Pytorch中的.backward()方法怎么用

发布时间:2021-07-12 15:36:11 来源:亿速云 阅读:318 作者:chen 栏目:大数据

这篇文章主要讲解了“Pytorch中的.backward()方法怎么用”,文中的讲解内容简单清晰,易于学习与理解,下面请大家跟着小编的思路慢慢深入,一起来研究和学习“Pytorch中的.backward()方法怎么用”吧!

PyTorch的主要功能和特点之一就是backword函数,我知道一些基本的导数:

Let, F = a*b
Where,
a = 10
b = 10∂F/∂a = b => ∂F/∂a = 20
∂F/∂b = a => ∂F/∂b = 10

让我们在PyTorch中实现:

Pytorch中的.backward()方法怎么用

如果a和b是向量,那么下面的代码似乎给出了一个错误:

Pytorch中的.backward()方法怎么用

RuntimeError: grad can be implicitly created only for scalar outputs

在文档中写道:当我们调用张量的反向函数时,如果张量是非标量(即它的数据有不止一个元素)并且要求梯度,那么这个函数还需要指定特定梯度。

这里F是非标量张量所以我们需要把梯度参数传递给和张量F维数相同的反向传播函数

Pytorch中的.backward()方法怎么用

在上面的代码示例中,将梯度参数传递给backword函数并给出了所需的梯度值a和b。但是,为什么我们必须将梯度参数传递给backword函数?

要理解这一点,我们需要了解.backward()函数是如何工作的。再次提到这些文档:

torch.autograd是一个计算向量-雅可比积的引擎。即给定任意向量v,计算其乘积J@v.T注:@表示矩阵乘法

一般来说,雅可比矩阵是一个全偏导数的矩阵。如果我们考虑函数y它有n维的输入向量x它有m维的输出。然后计算包含以J表示的所有偏导数的雅可比矩阵:

Pytorch中的.backward()方法怎么用

v为backword函数提供的外梯度。另外,需要注意的另一件重要的事情是,默认情况下F.backward()与F.backward(gradient=torch.tensor([1.])相同,所以默认情况下,当输出张量是标量时,我们不需要传递梯度参数,就像我们在第一个例子中所做的那样。

当输出张量为标量时,则v_vector的大小为1,即torch.tensor([1.]),可以用值1代替。这样就得到了完整的雅可比矩阵,也就是J@v。T = J

但是,当输出张量是非标量时,我们需要传递外部梯度向量v,得到的梯度计算雅可比向量积,即J@v.T

在这里,对于F = a*b在a = [10.0, 10.0] b =[20.0, 20.0]和v =[1]。1。我们得到∂F/∂a :

Pytorch中的.backward()方法怎么用

到目前为止,我们有:

Pytorch中的.backward()方法怎么用

我们引入一个新的变量G,它依赖于F

Pytorch中的.backward()方法怎么用

到目前为止都很好,但是让我们检查一下F的grad值也就是F.grad

Pytorch中的.backward()方法怎么用

我们得到None,并显示了一个警告

The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad attribute won't be populated during autograd.backward(). If you indeed want the gradient for a non-leaf Tensor, use .retain_grad() on the non-leaf Tensor......

在前向传播过程中,自动动态生成计算图。对于上面的代码示例,动态图如下:

Pytorch中的.backward()方法怎么用

从上面的计算图中,我们发现张量A和B是叶节点。我们可以用is_leaf来验证:

Pytorch中的.backward()方法怎么用

Torch backward()仅在默认情况下累积叶子节点张量的梯度。因此,F grad没有值,因为F张量不是叶子节点张量。为了积累非叶子节点的梯度,我们可以使用retain_grad方法如下:

Pytorch中的.backward()方法怎么用

在一般的情况下,我们的损失值张量是一个标量值,我们的权值参数是计算图的叶子节点,所以我们不会得出上面讨论的误差条件。但是了解这些特殊的情况,这有助于了解更多关于pytorch的功能,万一那天用上了呢,对吧。

感谢各位的阅读,以上就是“Pytorch中的.backward()方法怎么用”的内容了,经过本文的学习后,相信大家对Pytorch中的.backward()方法怎么用这一问题有了更深刻的体会,具体使用情况还需要大家实践验证。这里是亿速云,小编将为大家推送更多相关知识点的文章,欢迎关注!

向AI问一下细节

免责声明:本站发布的内容(图片、视频和文字)以原创、转载和分享为主,文章观点不代表本网站立场,如果涉及侵权请联系站长邮箱:is@yisu.com进行举报,并提供相关证据,一经查实,将立刻删除涉嫌侵权内容。

AI