Pytorch softmax:使用哪个维度?
- 论坛
- Pytorch softmax:使用哪个维度?
9 浏览
Pytorch softmax:使用哪个维度?
函数torch.nn.functional.softmax
有两个参数:input
和dim
。根据文档,softmax操作应用于input
沿指定的dim
的所有切片,并将其重新缩放,使得元素位于范围(0, 1)
内且总和为1。
假设input为:
input = torch.randn((3, 4, 5, 6))
如果我想要以下结果,使得数组中的每个条目都为1:
sum = torch.sum(input, dim=3) # sum的大小为(3, 4, 5, 1)
我应该如何应用softmax?
softmax(input, dim=0) # 方法0 softmax(input, dim=1) # 方法1 softmax(input, dim=2) # 方法2 softmax(input, dim=3) # 方法3
我的直觉告诉我是最后一个,但我不确定。因为英语不是我的母语,所以我对“along”一词的使用感到困惑。
我对“along”一词的含义不太清楚,所以我将使用一个可以澄清事情的示例。假设我们有一个大小为(s1, s2, s3, s4)的张量,我希望发生以下情况。