首页
读书
网课
《python》目录


正文

gather 函数的声明为:

torch.gather(input, dim, index, *, sparse_grad=False, out=None) → Tensor


根据gather函数的声明,可以看到gather函数主要由三个参数,分别是:

input

dim

index

下面,我结合2个具体的例子,来给出其具体的计算方法:

例子1=====================================

import torch

torch.manual_seed(100)

x=torch.randn(2,3)

index = torch.LongTensor([[0,1,1]])

a = torch.gather(x, 0, index)

此时,x的值为:

tensor([[ 0.3607, -0.2859, -0.3938],

        [ 0.2429, -1.3833, -2.3134]])

a的值为:

tensor([[ 0.3607, -1.3833, -2.3134]])

下面,我们来看看其具体的计算过程:

第一步,获得index的index

index的值为[[0,1,1]],其每个元素对应的index为:

(0, 0)、(0, 1)、(0, 2)

第二步,看dim的值

在调用gather函数的时候,需要指定dim的值,此处我们的dim值为0

第三步,根据dim用index的值来替换第一步中得到的对应维度的值

因为dim=0,因此我们用index的值,来代替第一步中得到的索引中第一个维度的值,替换后的值为:

(0, 0)、(1, 1)、(1, 2)

第四步,根据第三步得到的新的索引值,在input中进行取值

input中的(0, 0)、(1, 1)、(1, 2)分别对应了 0.3607、 -1.3833、 -2.3134,而这正是我们得到的计算结果。


链接:https://www.jianshu.com/p/2ff9a6905abe


例子2==========================================

# gather 按规则收集

input = torch.arange(15).view(3, 5)

# 生成一个查找规则:(张量b的元素都是对应张量a的索引)

index = torch.zeros_like(input)

index[1][2] = 1

index[0][0] = 1

# 3.根据维度dim开始查找:

r1 = input.gather(0, index)  # dim=0

r2 = input.gather(1, index)  # dim=1

print(input)

print(index)

print(r1)

print(r2)

"""

tensor([[ 0,  1,  2,  3,  4],

        [ 5,  6,  7,  8,  9],

        [10, 11, 12, 13, 14]])

tensor([[1, 0, 0, 0, 0],

        [0, 0, 1, 0, 0],

        [0, 0, 0, 0, 0]])

tensor([[5, 1, 2, 3, 4],

        [0, 1, 7, 3, 4],

        [0, 1, 2, 3, 4]])

tensor([[ 1,  0,  0,  0,  0],

        [ 5,  5,  6,  5,  5],

        [10, 10, 10, 10, 10]])

"""

处理过程讲解:

变量生成和修改掠过,我们只讲gather的过程

第一步,获得index的每个元素的位置坐标index

index的值为

[

        [1, 0, 0, 0, 0],

        [0, 0, 1, 0, 0],

        [0, 0, 0, 0, 0]

]

其每个元素对应的index为:

(0, 0)、(0, 1)、(0, 2)、(0, 3)、(0, 4)

(1, 0)、(1, 1)、(1, 2)、(1, 3)、(1, 4)

(2, 0)、(2, 1)、(2, 2)、(2, 3)、(2, 4)


第二步,根据dim的值用index矩阵的值来替换第一步中得到的对应维度的值

因为dim=0,因此我们用index的值,来代替第一步中得到的索引中第1个维度的值,替换后的值为:

(1, 0)、(0, 1)、(0, 2)、(0, 3)、(0, 4)

(0, 0)、(0, 1)、(1, 2)、(0, 3)、(0, 4)

(0, 0)、(0, 1)、(0, 2)、(0, 3)、(0, 4)


第三步,根据第二步得到的新的索引值,在input中进行取值

(1, 0)=5、(0, 1)=1、(0, 2)=2、(0, 3)=3、(0, 4)=4

(0, 0)=0、(0, 1)=1、(1, 2)=7、(0, 3)=3、(0, 4)=4

(0, 0)=0、(0, 1)=1、(0, 2)=2、(0, 3)=3、(0, 4)=4

就得到output的值如下:

[

        [5, 1, 2, 3, 4],

        [0, 1, 7, 3, 4],

        [0, 1, 2, 3, 4]

]



上一篇: 没有了
下一篇: 没有了
圣贤书院