关联网站:

einops官网
torch.einsum( equation , * operands ) → Tensor

对输入元素operands沿指定的维度、使用爱因斯坦求和符号的乘积求和。
参数:
-
equation ( string ) – 爱因斯坦求和的下标。
-
operands(List [ Tensor ])——计算爱因斯坦求和的张量。
Einsum允许计算许多常见的多维线性代数数组运算,方法是根据由equation给出的爱因斯坦求和约定,以速记(short-hand)格式表示它们。这种格式的细节在下面描述,但通常想法是operands 用一些下标标记输入的每个维度,并定义哪些下标是输出的一部分,operands然后通过对下标不属于输出维度的元素的乘积求和来计算输出。例如,矩阵乘法可以使用einsum计算为torch.einsum(“ij,jk->ik”, A, B)。这里,j 是求和下标,i 和 k 是输出下标(有关原因的更多详细信息,请参见下面的部分)。
equation 参数说明:
equation字符串以与维度相同的顺序指定输入的每个维度的下标( [a-z,A-Z] operands中的字母) ,用逗号 (‘,’) 分隔每个操作数的下标,例如’ij,jk’指定两个二维操作数的下标。标有相同下标的维度必须是可广播的,即它们的大小必须匹配或为1。例外情况是,如果对相同的输入操作数重复下标,在这种情况下,此操作数的标有此下标的维度必须在大小上匹配,并且操作数将被其沿这些维度的对角线替换。equation中只出现一次的下标将是输出的一部分,按字母顺序递增排序。输出是通过按元素乘以输入来计算的operands,它们的维度根据下标对齐,然后对下标不属于输出的维度求和。
或者,可以通过在等式末尾添加箭头 (->) 后跟输出下标来显式定义输出下标。例如,以下等式计算矩阵乘法的转置:‘ij,jk->ki’。对于某些输入操作数,输出下标必须至少出现一次,而对于输出则最多出现一次。
可以使用省略号 (...) 代替下标来广播省略号所涵盖的维度。每个输入操作数最多可以包含一个省略号,它将覆盖下标未覆盖的维度,例如,对于具有 5 维的输入操作数,等式“ab…c”中的省略号覆盖第三和第四维。省略号不需要覆盖operands中相同数量的维度,但省略号的“形状”(它们覆盖的维度的大小)必须一起传播。如果未使用箭头 (->) 表示法显式定义输出,则省略号将首先出现在输出(最左侧的维度)中,位于输入操作数仅出现一次的下标标签之前。例如下面的等式实现批量矩阵乘法’…ij,…jk’。
最后几点注意事项:equation可能在不同元素(下标、省略号、箭头和逗号)之间包含空格,但类似“…”的内容无效。空字符串 ’ ’ 对标量operands有效。
注:
- torch.einsum处理省略号 (‘…’) 的方式与 NumPy 不同,因为它允许对省略号覆盖的维度求和,也就是说,省略号不需要是输出的一部分。
- 此函数不会优化给定的表达式,因此用于相同计算的不同公式可能会运行得更快或消耗更少的内存。像 opt_einsum ( https://optimized-einsum.readthedocs.io/en/stable/
)这样的项目可以为你优化公式。
- 从 PyTorch 1.10 开始,还支持子列表格式(请参见下面的示例)。在这种格式中,每个操作数的下标由子列表指定,子列表是 [0, 52) 范围内的整数列表。这些子列表跟在它们的操作数之后,一个额外的子列表可以出现在输入的末尾以指定输出的下标。例如torch。einsum
(op1, sublist1, op2, sublist2, …, [subslist_out])。可以在子列表中提供Python
的Ellipsis对象,以启用广播,如上面的方程式部分所述。torch.einsum()
例:
# trace(迹) >>> torch.einsum('ii', torch.randn(4, 4)) tensor(-1.4157) # diagonal(对角线) >>> torch.einsum('ii->i', torch.randn(4, 4)) tensor([ 0.0266, 2.4750, -1.0881, -1.3075]) # outer product(外积) >>> x = torch.randn(5) tensor([-0.3550, -0.6059, -1.3375, -1.5649, 0.2675]) >>> y = torch.randn(4) tensor([-0.2202, -1.5290, -2.0062, 0.9600]) >>> torch.einsum('i,j->ij', x, y) tensor([[ 0.0782, 0.5428, 0.7122, -0.3408], [ 0.1334, 0.9264, 1.2156, -0.5817], [ 0.2945, 2.0451, 2.6834, -1.2840], [ 0.3445, 2.3927, 3.1396, -1.5023], [-0.0589, -0.4089, -0.5366, 0.2568]]) # batch matrix multiplication(批量矩阵乘法) >>> As = torch.randn(3,2,5) tensor([[[-0.0306, 0.8251, 0.0157, -0.4563, 0.5550], [-1.4550, 0.0762, 0.9258, 0.1198, -1.1737]], [[-0.4460, -0.7224, 0.7260, 0.7552, 0.0326], [-0.3904, -1.2392, 0.4848, -0.4756, 0.2301]], [[ 1.5307, 0.7668, -1.9426, 1.7473, -0.6258], [ 0.6758, 1.8240, -0.2053, 0.0973, -0.6118]]]) >>> Bs = torch.randn(3,5,4) tensor([[[-0.7054, -0.2155, -1.5458, -0.8236], [-1.4957, -2.2604, 0.6897, -1.0360], [ 1.2924, 0.2798, 1.0544, 0.3656], [-0.3993, -1.2463, -0.6601, 0.2706], [ 1.0727, 0.5418, -0.2516, -0.1133]], [[ 0.4215, 1.5712, -0.2351, 1.3741], [ 1.6418, 0.9806, -1.0259, -1.1297], [ 0.7326, 0.4989, 0.4404, 0.2975], [-0.6866, 0.5696, -0.8942, 0.6815], [ 1.7486, 0.5344, 0.0538, 0.5258]], [[ 1.6280, -1.3989, -0.2900, 0.0936], [-0.9436, -0.1766, 0.6780, 0.3152], [ 0.9645, -0.1199, -1.1644, -1.0290], [-0.2791, -0.8086, 0.2161, 0.7901], [ 1.3222, -1.4023, -2.4181, -1.2875]]]) >>> torch.einsum('bij,bjk->bik', As, Bs) tensor([[[-0.4147, -0.9847, 0.7946, -1.0103], [ 0.8020, -0.3849, 3.4942, 1.6233]], [[-1.3035, -0.5993, 0.4922, 0.9511], [-1.1150, -1.7346, 2.0142, 0.8047]], [[-1.4202, -2.5790, 4.2288, 4.5702], [-1.6549, -0.4636, 2.7802, 1.7141]]]) # with sublist format and ellipsis(带有子列表格式和省略号) >>> torch.einsum(As, [..., 0, 1], Bs, [..., 1, 2], [..., 0, 2]) tensor([[[-0.4147, -0.9847, 0.7946, -1.0103], [ 0.8020, -0.3849, 3.4942, 1.6233]], [[-1.3035, -0.5993, 0.4922, 0.9511], [-1.1150, -1.7346, 2.0142, 0.8047]], [[-1.4202, -2.5790, 4.2288, 4.5702], [-1.6549, -0.4636, 2.7802, 1.7141]]]) # batch permute(批量交换) >>> A = torch.randn(2, 3, 4, 5) >>> torch.einsum('...ij->...ji', A).shape torch.Size([2, 3, 5, 4]) # equivalent to torch.nn.functional.bilinear(等价于torch.nn.functional.bilinear) >>> A = torch.randn(3,5,4) >>> l = torch.randn(2,5) >>> r = torch.randn(2,4) >>> torch.einsum('bn,anm,bm->ba', l, A, r) tensor([[-0.3430, -5.2405, 0.4494], [ 0.3311, 5.5201, -3.0356]])