Identifizieren und Entfernen von doppelten Spalten/Zeilen in einer spärlichen Binärmatrix in PytorchPython

Python-Programme
Anonymous
 Identifizieren und Entfernen von doppelten Spalten/Zeilen in einer spärlichen Binärmatrix in Pytorch

Post by Anonymous »

Nehmen wir an, wir haben eine -Binärmatrix a mit Form N x m ,
Ich möchte Zeilen identifizieren, die Duplikate in der Matrix haben. Ich verwende in Bezug auf den Speicher ziemlich groß und schwer zu handhaben.

Code: Select all

# This is just a toy sparse binary matrix with n = 10 and m = 100
A = torch.randint(0, 2, (10, 100), dtype=torch.float32).to_sparse()
Intuitiv können wir das Punktprodukt dieser Matrix ausführen, das eine neue m x m matrix erstellt, die in den Begriffen I, j enthält, die Anzahl von 1s, die der Index i in derselben Position des Indexs des Indexs> bei Dimension 0 .
hat

Code: Select all

B = A.T @ A # In PyTorch, this operation will also produce a sparse representation
Zu diesem Zeitpunkt habe ich versucht, diese Werte zu kombinieren und sie mit A.sum (0) ,
zu vergleichen

Code: Select all

num_elements = A.sum(0)
duplicate_rows = torch.logical_and([
num_elements[B.indices()[0]] == num_elements[B.indices()[1]],
num_elements[B.indices()[0]] == B.values()
])
Aber das hat nicht funktioniert!>

Quick Reply

Change Text Case: 
   
  • Similar Topics
    Replies
    Views
    Last post