如何在PyTorch中进行矩阵乘法运算?
使用numpy,我可以像这样进行简单的矩阵乘法:
a = numpy.ones((3, 2))
b = numpy.ones((2, 1))
result = a.dot(b)
然而,这在PyTorch中不起作用:
a = torch.ones((3, 2))
b = torch.ones((2, 1))
result = torch.dot(a, b)
这段代码会报错如下:
RuntimeError: 期望1D张量,但得到了2D和2D张量
如何在PyTorch中执行矩阵乘法?