在PyTorch中有四種類型的乘法運算(位置乘法、點積、矩陣與向量乘法、矩陣乘法),非常容易搞混,我們一起來看看這四種乘法運算的區別。
位置乘法
先構建兩個張量a,b他們都是4行5列。
a = torch.arange(20).reshape([4,5])
b = torch.randn([4,5])

位置乘法,顧名思義就是將兩個張量對應位置的元素進行乘法運算,運算符是*。
可以是兩個張量相乘,也可以是標量和張量相乘。
標量與張量相乘,是用標量與張量的每個元素相乘,結果張量的形狀不變。
4 * a

兩個張量相乘,是對應位置的元素相乘,結果張量的形狀不變。
a * b

點積
點積是兩個向量(也就是一維張量)對應位置的元素相乘后求和,結果是一個標量,使用dot函數進行計算。
先構建兩個向量a、b,點積操作要求兩個向量的數據類型要一致,因此a中指定數據類型為float。
a = torch.arange(6, dtype=torch.float32)
b = torch.ones(6)

執行點積操作,結果是一個標量。
torch.dot(a,b)

矩陣與向量乘法
矩陣(二維張量)與向量(一維張量)的乘法是將矩陣的每一行與向量進行點積,要求矩陣的列維數與向量的維數相同,結果的維數與行數相同。
使用mv函數進行運算。
構建一個4行5列的矩陣和一個維數為5的向量。
a = torch.arange(20,dtype=torch.float32).reshape([4,5])
b = torch.ones(5)

使用mv函數相乘后,結果是維數為4的向量。
torch.mv(a,b)

矩陣乘法
矩陣(二維張量)乘法是用第一個矩陣的行向量與第二個矩陣的列向量進行點積,要求第一個矩陣的列數與第二個矩陣的行數相同。
使用mm函數進行運算。
構建兩個矩陣,一個4行5列,一個5行6列
a = torch.arange(20,dtype=torch.float32).reshape([4,5])
b = torch.randn([5,6])

使用mm函數相乘后,結果是4行6列的矩陣。
torch.mm(a,b)
