Ich versuche, JAX für die Implementierung der Point -Cloud -Verarbeitung zu verwenden. Ich stellte jedoch fest, dass das Training aufgrund meiner Implementierung der folgenden index_points_3d Operation, die die Auswahl der Funktionen basierend auf 3D-Indizes durchführt, extrem langsam wird.import jax
import jax.numpy as jnp
@jax.jit
def index_points_3d(features, indices):
"""
Args:
features: shape (B, N, C)
indices: shape (B, npoint, nsample)
Returns:
shape (B, npoint, nsample, C)
"""
features_expanded = features[..., None, :]
idx_expanded = indices[..., None]
return jnp.take_along_axis(features_expanded, idx_expanded, axis=1)
< /code>
Als ich den Profiler aufzeichnete, stellte ich fest, dass diese Operation extreme Wiederholungen von Loop_dynamic_update_slice_fusion, Loop_add_fusion, input_Recuce_fusion und Loop_Select_fusion in der Backpropagation -Stufe in der Folge. src = "https://i.static.net/pbgiphzf.gif"/>
Der Vorwärtspass ist kein Problem, da das Lernen schnell verlief, als ich den Gradienten des Ausgabemmerkmals gestoppt habe. /> Ich bin nicht zutiefst vertraut mit Jax 'niedrigen Operationen, daher bin ich mir nicht sicher, ob dies eine grundlegende Einschränkung von Jax /XLA ist oder ob es einen effizienteren Ansatz gibt. Jede Hilfe oder Anleitung zur Optimierung dieser Operation wäre sehr geschätzt!
JAX -Punkt -Cloud -Verarbeitung: Slow Index_points_3d Operation, der extreme XLA -Fusionsschleifen in der Backpropagatio ⇐ Python
-
- Similar Topics
- Replies
- Views
- Last post