NaN vermeiden, weil die Optax -BFGS -Implementierung implementiert ist?Python

Python-Programme
Guest
 NaN vermeiden, weil die Optax -BFGS -Implementierung implementiert ist?

Post by Guest »

Ich bin ziemlich neu in Optax. Es gibt diese Zeile in scale_by_lbfgs , die die Gewichtsberechnung in optax/optax/_src/Transformation durchführt.

Code: Select all

def scale_by_lbfgs(
...
def update_fn(
...
weight = jnp.where(
vdot_diff_params_updates == 0.0, 0.0, 1.0 / vdot_diff_params_updates
)
Dies gibt nan , wenn vdot_diff_params_updates klein ist, aber ungleich Null. Gibt es eine gute Möglichkeit, dies zu vermeiden, z. B. Bearbeiten von Zuständen oder Gradientenkappen?

Quick Reply

Change Text Case: 
   
  • Similar Topics
    Replies
    Views
    Last post