如何在PyTorch中进行矩阵乘法运算?

10 浏览
0 Comments

如何在PyTorch中进行矩阵乘法运算?

使用numpy,我可以像这样进行简单的矩阵乘法:

a = numpy.ones((3, 2))
b = numpy.ones((2, 1))
result = a.dot(b)

然而,这在PyTorch中不起作用:

a = torch.ones((3, 2))
b = torch.ones((2, 1))
result = torch.dot(a, b)

这段代码会报错如下:

RuntimeError: 期望1D张量,但得到了2D和2D张量

如何在PyTorch中执行矩阵乘法?

0