Ich möchte Zeilen identifizieren, die Duplikate in der Matrix haben. Ich verwende in Bezug auf den Speicher ziemlich groß und schwer zu handhaben.
Code: Select all
# This is just a toy sparse binary matrix with n = 10 and m = 100
A = torch.randint(0, 2, (10, 100), dtype=torch.float32).to_sparse()
hat
Code: Select all
B = A.T @ A # In PyTorch, this operation will also produce a sparse representation
zu vergleichen
Code: Select all
num_elements = A.sum(0)
duplicate_rows = torch.logical_and([
num_elements[B.indices()[0]] == num_elements[B.indices()[1]],
num_elements[B.indices()[0]] == B.values()
])