欢迎来到代码驿站!

Python代码

当前位置:首页 > 软件编程 > Python代码

浅谈Pytorch中的torch.gather函数的含义

时间:2020-10-14 10:03:47|栏目:Python代码|点击:

pytorch中的gather函数

pytorch比tensorflow更加编程友好,所以准备用pytorch试着做最近要做的一些实验。

立个flag开始学习pytorch,新开一个分类整理学习pytorch中的一些踩到的泥坑。

今天刚开始接触,读了一下documentation,写一个一开始每太搞懂的函数gather

b = torch.Tensor([[1,2,3],[4,5,6]])
print b
index_1 = torch.LongTensor([[0,1],[2,0]])
index_2 = torch.LongTensor([[0,1,1],[0,0,0]])
print torch.gather(b, dim=1, index=index_1)
print torch.gather(b, dim=0, index=index_2)

观察它的输出结果:

 1 2 3
 4 5 6
[torch.FloatTensor of size 2x3]


 1 2
 6 4
[torch.FloatTensor of size 2x2]


 1 5 6
 1 2 3
[torch.FloatTensor of size 2x3]

这里是官方文档的解释

torch.gather(input, dim, index, out=None) → Tensor

 Gathers values along an axis specified by dim.

 For a 3-D tensor the output is specified by:

 out[i][j][k] = input[index[i][j][k]][j][k] # dim=0
 out[i][j][k] = input[i][index[i][j][k]][k] # dim=1
 out[i][j][k] = input[i][j][index[i][j][k]] # dim=2

 Parameters: 

  input (Tensor) ?C The source tensor
  dim (int) ?C The axis along which to index
  index (LongTensor) ?C The indices of elements to gather
  out (Tensor, optional) ?C Destination tensor

 Example:

 >>> t = torch.Tensor([[1,2],[3,4]])
 >>> torch.gather(t, 1, torch.LongTensor([[0,0],[1,0]]))
  1 1
  4 3
 [torch.FloatTensor of size 2x2]

可以看出,gather的作用是这样的,index实际上是索引,具体是行还是列的索引要看前面dim 的指定,比如对于我们的栗子,【1,2,3;4,5,6,】,指定dim=1,也就是横向,那么索引就是列号。index的大小就是输出的大小,所以比如index是【1,0;0,0】,那么看index第一行,1列指的是2, 0列指的是1,同理,第二行为4,4 。这样就输入为【2,1;4,4】,参考这样的解释看上面的输出结果,即可理解gather的含义。

gather在one-hot为输出的多分类问题中,可以把最大值坐标作为index传进去,然后提取到每一行的正确预测结果,这也是gather可能的一个作用。

上一篇:Python列表list操作符实例分析【标准类型操作符、切片、连接字符、列表解析、重复操作等】

栏    目:Python代码

下一篇:python实现网页自动签到功能

本文标题:浅谈Pytorch中的torch.gather函数的含义

本文地址:http://www.codeinn.net/misctech/10987.html

推荐教程

广告投放 | 联系我们 | 版权申明

重要申明:本站所有的文章、图片、评论等,均由网友发表或上传并维护或收集自网络,属个人行为,与本站立场无关。

如果侵犯了您的权利,请与我们联系,我们将在24小时内进行处理、任何非本站因素导致的法律后果,本站均不负任何责任。

联系QQ:914707363 | 邮箱:codeinn#126.com(#换成@)

Copyright © 2020 代码驿站 版权所有