So skalieren Sie die Lernraten pro Parameter mit Penzai und OptaxPython

Python-Programme
Anonymous
 So skalieren Sie die Lernraten pro Parameter mit Penzai und Optax

Post by Anonymous »

Ich möchte ein einfaches Vorwärts -Neuralnetz ausbilden, das ich in Penzai eingebaut habe, aber ich möchte für jede Parametergruppe unterschiedliche Lernraten verwenden. Ich speichere den Lernrate -Skalierungsfaktor in den Metadaten des einzelnen Parameters , z. So wie dieses: < /p>

Code: Select all

Parameter(
label='mlp/Affine_0/Linear.weights',
value=,
metadata={'learning_rate': 0.0012755102040816326},
)
Ich benutze Penzais statefultrainer für das Training und deklariere den Optimierer wie diesen:

Code: Select all

optax.chain(
optax.scale_by_adam(),
scale_by_metadata_value("learning_rate"),
optax.scale_by_learning_rate(0.01),
)
wobei ich scale_by_metadata_value wie folgt definiere:

Code: Select all

def scale_by_metadata_value(metadata_field_name: str):
def init_fn(params):
learning_rates = jax.tree.map(
lambda param: param.metadata[metadata_field_name],
params,
is_leaf=(lambda node: isinstance(node, pz.ParameterValue))
)
return {"learning_rates": learning_rates}

def update_fn(updates, state, params):
del params
updates = jax.tree.map(
# This is where the TypeError is thrown:
lambda lr, g: lr * g, state["learning_rates"], updates
)
return updates, state

return optax.GradientTransformation(init_fn, update_fn)
< /code>
Wenn ich jedoch einen Trainingsschritt ausführe, erhalte ich < /p>
TypeError: unsupported operand type(s) for *: 'jaxlib.xla_extension.ArrayImpl' and 'ParameterValue'
Ich bin besonders verwirrt, weil alles funktioniert, wenn ich die Zeile scale_by_metadata_value ("Learning_Rate") entferne, obwohl optax.scale_by_learning_rate (0.01) im Wesentlichen das gleiche tut wie mit scale_by_MetadaTa_Value. Der beste Weg, um zu implementieren, scale_by_metadata_value ?

Code: Select all

import penzai.toolshed.basic_training
import penzai
import penzai.pz as pz
import jax
import jax.numpy as jnp
import optax
import numpy as np

model = pz.nn.Linear(
weights=pz.Parameter(
value=pz.nx.wrap(np.ones((8, 4)), "features", "features_out"),
label="linear",
metadata={"learning_rate": 0.5},
),
in_axis_names=("features",),
out_axis_names=("features_out",),
)

def softmax_cross_entropy_loss(
model, rng, state, current_input, current_target: pz.nx.NamedArray
):
del rng, state
logits: pz.nx.NamedArray = model(current_input)
loss = jnp.sum(
optax.losses.softmax_cross_entropy(
logits.unwrap("features_out"),
current_target.unwrap("features_out"),
)
)
return (loss, None, {"softmax_cross_entropy_loss": loss})

def scale_by_metadata_value(metadata_field_name: str):
def init_fn(params):
learning_rates = jax.tree.map(
lambda param: param.metadata[metadata_field_name],
params,
is_leaf=(lambda node: isinstance(node, pz.ParameterValue)),
)
return {"learning_rates": learning_rates}

def update_fn(updates, state, params):
del params
updates = jax.tree.map(lambda lr, g: lr * g, state["learning_rates"], updates)
return updates, state

return optax.GradientTransformation(init_fn, update_fn)

trainer = penzai.toolshed.basic_training.StatefulTrainer.build(
root_rng=jax.random.key(2025),
model=model,
optimizer_def=optax.chain(
optax.scale_by_adam(),
scale_by_metadata_value("learning_rate"),
optax.scale_by_learning_rate(0.01),
),
loss_fn=softmax_cross_entropy_loss,
jit=False,
)

trainer.step(
current_input=pz.nx.wrap(np.zeros(8), "features"),
current_target=pz.nx.wrap(np.ones(4), "features_out"),
)
# TypeError: unsupported operand type(s) for *: 'jaxlib.xla_extension.ArrayImpl' and 'ParameterValue'

Quick Reply

Change Text Case: 
   
  • Similar Topics
    Replies
    Views
    Last post