Pytorch softmax:使用哪个维度?

9 浏览
0 Comments

Pytorch softmax:使用哪个维度?

函数torch.nn.functional.softmax有两个参数:inputdim。根据文档,softmax操作应用于input沿指定的dim的所有切片,并将其重新缩放,使得元素位于范围(0, 1)内且总和为1。

假设input为:

input = torch.randn((3, 4, 5, 6))

如果我想要以下结果,使得数组中的每个条目都为1:

sum = torch.sum(input, dim=3) # sum的大小为(3, 4, 5, 1)

我应该如何应用softmax?

softmax(input, dim=0) # 方法0
softmax(input, dim=1) # 方法1
softmax(input, dim=2) # 方法2
softmax(input, dim=3) # 方法3

我的直觉告诉我是最后一个,但我不确定。因为英语不是我的母语,所以我对“along”一词的使用感到困惑。

我对“along”一词的含义不太清楚,所以我将使用一个可以澄清事情的示例。假设我们有一个大小为(s1, s2, s3, s4)的张量,我希望发生以下情况。

0