Identifizieren und Entfernen von doppelten Spalten/Zeilen in einer spärlichen Binärmatrix in Pytorch
Posted: 05 Mar 2025, 13:57
Nehmen wir an, wir haben eine -Binärmatrix a mit Form N x m ,
Ich möchte Zeilen identifizieren, die Duplikate in der Matrix haben. Ich verwende in Bezug auf den Speicher ziemlich groß und schwer zu handhaben.
Intuitiv können wir das Punktprodukt dieser Matrix ausführen, das eine neue m x m matrix erstellt, die in den Begriffen I, j enthält, die Anzahl von 1s, die der Index i in derselben Position des Indexs des Indexs> bei Dimension 0 .
hat
Zu diesem Zeitpunkt habe ich versucht, diese Werte zu kombinieren und sie mit A.sum (0) ,
zu vergleichen
Aber das hat nicht funktioniert!>
Ich möchte Zeilen identifizieren, die Duplikate in der Matrix haben. Ich verwende in Bezug auf den Speicher ziemlich groß und schwer zu handhaben.
Code: Select all
# This is just a toy sparse binary matrix with n = 10 and m = 100
A = torch.randint(0, 2, (10, 100), dtype=torch.float32).to_sparse()
hat
Code: Select all
B = A.T @ A # In PyTorch, this operation will also produce a sparse representation
zu vergleichen
Code: Select all
num_elements = A.sum(0)
duplicate_rows = torch.logical_and([
num_elements[B.indices()[0]] == num_elements[B.indices()[1]],
num_elements[B.indices()[0]] == B.values()
])