Aufrufen mehrerer Funktionen auf mehreren Zeilen eines TensorsPython

Python-Programme
Anonymous
 Aufrufen mehrerer Funktionen auf mehreren Zeilen eines Tensors

Post by Anonymous »

import torch

x = torch.ones(3, 3)

factors = [lambda x: 2*x, lambda x: 3*x, lambda x: 4*x]
indices = torch.tensor([0, 1, 2])

def multiply_row_by_factor(row, idx):
return factors[idx](row)

result = torch.vmap(multiply_row_by_factor, in_dims=(0, 0))(x, indices)

# Original Tensor
# tensor([[1., 1., 1.],
# [1., 1., 1.],
# [1., 1., 1.]])

# Desired Result
# tensor([[2., 2., 2.],
# [3., 3., 3.],
# [4., 4., 4.]])
< /code>
Wie der Titel sagt, suche ich nach einer Möglichkeit, mehrere Funktionen auf mehreren Zeilen eines Tensors aufzurufen. Ich zeige ein minimal reproduzierbares Beispiel für die Einfachheit. Mir ist bewusst, dass VMAP nur mit einer Funktion aufgerufen werden soll. Ich benutze es hier nur als Beispiel, um zu kommunizieren, was ich versuche zu tun. Dieser spezielle Ansatz funktioniert nicht, da IDX ein Charge ist. Die Funktionen hier sind Lambdas, aber in Wirklichkeit bestehen meine Funktionen aus komplexen Transformationen, die ich lieber nicht zersetzen würde, damit dies funktioniert. < /P>
Gibt es eine Möglichkeit, so etwas zu erreichen ? Etwas Reinigerer als Pytorch -Streams?

Quick Reply

Change Text Case: 
   
  • Similar Topics
    Replies
    Views
    Last post