gather 函数的声明为:
torch.gather(input, dim, index, *, sparse_grad=False, out=None) → Tensor
根据gather函数的声明,可以看到gather函数主要由三个参数,分别是:
input
dim
index
下面,我结合2个具体的例子,来给出其具体的计算方法:
例子1=====================================
import torch
torch.manual_seed(100)
x=torch.randn(2,3)
index = torch.LongTensor([[0,1,1]])
a = torch.gather(x, 0, index)
此时,x的值为:
tensor([[ 0.3607, -0.2859, -0.3938],
[ 0.2429, -1.3833, -2.3134]])
a的值为:
tensor([[ 0.3607, -1.3833, -2.3134]])
下面,我们来看看其具体的计算过程:
第一步,获得index的index
index的值为[[0,1,1]],其每个元素对应的index为:
(0, 0)、(0, 1)、(0, 2)
第二步,看dim的值
在调用gather函数的时候,需要指定dim的值,此处我们的dim值为0
第三步,根据dim用index的值来替换第一步中得到的对应维度的值
因为dim=0,因此我们用index的值,来代替第一步中得到的索引中第一个维度的值,替换后的值为:
(0, 0)、(1, 1)、(1, 2)
第四步,根据第三步得到的新的索引值,在input中进行取值
input中的(0, 0)、(1, 1)、(1, 2)分别对应了 0.3607、 -1.3833、 -2.3134,而这正是我们得到的计算结果。
链接:https://www.jianshu.com/p/2ff9a6905abe
例子2==========================================
# gather 按规则收集
input = torch.arange(15).view(3, 5)
# 生成一个查找规则:(张量b的元素都是对应张量a的索引)
index = torch.zeros_like(input)
index[1][2] = 1
index[0][0] = 1
# 3.根据维度dim开始查找:
r1 = input.gather(0, index) # dim=0
r2 = input.gather(1, index) # dim=1
print(input)
print(index)
print(r1)
print(r2)
"""
tensor([[ 0, 1, 2, 3, 4],
[ 5, 6, 7, 8, 9],
[10, 11, 12, 13, 14]])
tensor([[1, 0, 0, 0, 0],
[0, 0, 1, 0, 0],
[0, 0, 0, 0, 0]])
tensor([[5, 1, 2, 3, 4],
[0, 1, 7, 3, 4],
[0, 1, 2, 3, 4]])
tensor([[ 1, 0, 0, 0, 0],
[ 5, 5, 6, 5, 5],
[10, 10, 10, 10, 10]])
"""
处理过程讲解:
变量生成和修改掠过,我们只讲gather的过程
第一步,获得index的每个元素的位置坐标index
index的值为
[
[1, 0, 0, 0, 0],
[0, 0, 1, 0, 0],
[0, 0, 0, 0, 0]
]
其每个元素对应的index为:
(0, 0)、(0, 1)、(0, 2)、(0, 3)、(0, 4)
(1, 0)、(1, 1)、(1, 2)、(1, 3)、(1, 4)
(2, 0)、(2, 1)、(2, 2)、(2, 3)、(2, 4)
第二步,根据dim的值用index矩阵的值来替换第一步中得到的对应维度的值
因为dim=0,因此我们用index的值,来代替第一步中得到的索引中第1个维度的值,替换后的值为:
(1, 0)、(0, 1)、(0, 2)、(0, 3)、(0, 4)
(0, 0)、(0, 1)、(1, 2)、(0, 3)、(0, 4)
(0, 0)、(0, 1)、(0, 2)、(0, 3)、(0, 4)
第三步,根据第二步得到的新的索引值,在input中进行取值
(1, 0)=5、(0, 1)=1、(0, 2)=2、(0, 3)=3、(0, 4)=4
(0, 0)=0、(0, 1)=1、(1, 2)=7、(0, 3)=3、(0, 4)=4
(0, 0)=0、(0, 1)=1、(0, 2)=2、(0, 3)=3、(0, 4)=4
就得到output的值如下:
[
[5, 1, 2, 3, 4],
[0, 1, 7, 3, 4],
[0, 1, 2, 3, 4]
]