Batch-Matrixmultiplikation in NumpyPython

Python-Programme
Guest
 Batch-Matrixmultiplikation in Numpy

Post by Guest »

Ich habe zwei Numpy-Arrays a und b der Form [5, 5, 5] bzw. [5, 5]. Für a und b ist der erste Eintrag in der Form die Stapelgröße. Wenn ich die Matrixmultiplikationsoption ausführe, erhalte ich ein Array der Form [5, 5, 5]. Ein MWE ist wie folgt.

Code: Select all

import numpy as np

a = np.ones((5, 5, 5))
b = np.random.randint(0, 10, (5, 5))
c = a @ b
# c.shape is (5, 5, 5)
Angenommen, ich würde eine Schleife über die Stapelgröße ausführen, d. h. a[0] @ b[0].T, dann würde das zu einem Array der Form [5 , 1]. Wenn ich schließlich alle Ergebnisse entlang der Achse 1 verkette, erhalte ich ein resultierendes Array mit der Form [5, 5]. Der folgende Code beschreibt diese Zeilen besser.

Code: Select all

a = np.ones((5, 5, 5))
b = np.random.randint(0, 10, (5, 5))
c = []
for i in range(5):
c.append(a[i] @ b[i].T)
c = np.concatenate([d[:, None] for d in c], axis=1).T
# c.shape evaluates to be (5, 5)
Kann ich die oben genannte Funktionalität erhalten, ohne eine Schleife zu verwenden? PyTorch bietet beispielsweise eine Funktion namens Torch.bmm, um dies zu berechnen. Danke.

Quick Reply

Change Text Case: 
   
  • Similar Topics
    Replies
    Views
    Last post