by Anonymous » 25 Aug 2025, 03:06
Ich verwende jax.lax.scan, um eine Systemdynamik vorwärts zu simulieren (Rollout). Um die Parameter der Dynamik (dt und k) zu erfassen,
verwende ich lambda, def oder carry (d. h. ich führe sie im Carry der scan-Schleife mit). Die Ausführungszeit ist fast gleich, wahrscheinlich weil die Dynamik einfach ist. Was ist die JAX-freundlichste und idiomatischste Art, solche Parameter fest einzubauen?
Code: Select all
import jax
import jax.numpy as jnp
import time
N = 100
def dynamics(x, u, dt, k):
    """Return the next state for state ``x`` and control ``u``.

    A velocity vector ``[cos(x[2]) * u[0], sin(x[2]) * u[0], u[1]]`` is
    built from the third state component and the two control components.
    The result is ``x + vel * dt - k * vel``.
    """
    heading = x[2]
    speed = u[0]
    velocity = jnp.array(
        [jnp.cos(heading) * speed, jnp.sin(heading) * speed, u[1]]
    )
    return x + velocity * dt - k * velocity
def one_step(state, u, dt, k):
    """Scan body that advances the state using column ``kk`` of ``u``.

    ``state`` is a pair ``(x, kk)``. The function returns the new carry
    ``(x_next, kk + 1)`` together with ``x_next`` as the per-step output.
    """
    current_x, step_index = state
    control = u[:, step_index]
    advanced = dynamics(current_x, control, dt, k)
    new_carry = (advanced, step_index + 1)
    return new_carry, advanced
# Variant: close over u, dt and k with a lambda passed to scan.
def inner_lambda(x0, u, dt, k):
    """Run ``N`` scan steps of ``one_step`` starting from ``(x0, 0)``.

    ``u``, ``dt`` and ``k`` are captured by a lambda rather than passed
    through the carry. Returns whatever ``jax.lax.scan`` returns.
    """
    return jax.lax.scan(
        lambda carry, _: one_step(carry, u, dt, k),
        (x0, 0),
        None,
        length=N,
    )
# Variant: close over u, dt and k with a nested function passed to scan.
def inner_def(x0, u, dt, k):
    """Run ``N`` scan steps of ``one_step`` starting from ``(x0, 0)``.

    ``u``, ``dt`` and ``k`` are captured by a locally defined scan body.
    Returns whatever ``jax.lax.scan`` returns.
    """

    def body(carry, _):
        # The scanned input is unused; scan is called with xs=None.
        return one_step(carry, u, dt, k)

    initial_carry = (x0, 0)
    return jax.lax.scan(body, initial_carry, None, length=N)
def one_step_carry(state, input):
    """Scan body whose carry also holds ``u``, ``dt`` and ``k``.

    ``state`` is the tuple ``(x, kk, u, dt, k)``; ``input`` is not used.
    Returns the updated carry (with ``kk`` incremented and the parameters
    passed through unchanged) and ``x_next`` as the per-step output.
    """
    current_x, step_index, controls, step_size, gain = state
    advanced = dynamics(current_x, controls[:, step_index], step_size, gain)
    new_carry = (advanced, step_index + 1, controls, step_size, gain)
    return new_carry, advanced
# Variant: thread u, dt and k through the scan carry.
def carry_all(x0, u, dt, k):
    """Run ``N`` scan steps of ``one_step_carry``.

    The initial carry is ``(x0, 0, u, dt, k)``. Returns whatever
    ``jax.lax.scan`` returns.
    """
    initial_carry = (x0, 0, u, dt, k)
    return jax.lax.scan(one_step_carry, initial_carry, None, length=N)
# Inputs shared by all benchmarked variants.
x0 = jnp.array([2.0, 3.0, 4.0])
key = jax.random.PRNGKey(0)
u = jax.random.uniform(key, (2, N))  # one control column per scan step
dt = 0.01
k = 1e-6
# --- Benchmark helper ---
def benchmark(fn, name):
    """Time 1000 cached calls of the jitted ``fn`` and print the total.

    ``fn`` is wrapped in ``jax.jit`` and called with the module-level
    ``x0``, ``u``, ``dt`` and ``k``. The first call, which includes
    compilation, is made before the timer starts. ``name`` labels the
    printed line.
    """
    fn_jit = jax.jit(fn)

    # Warm-up call (includes compilation); excluded from the timing.
    carry, _ = fn_jit(x0, u, dt, k)
    carry[0].block_until_ready()

    # perf_counter is the clock intended for measuring short intervals;
    # time.time() can jump if the system clock is adjusted.
    start = time.perf_counter()
    for _ in range(1000):
        # Cached execution; block so the work is finished before timing stops.
        carry, _ = fn_jit(x0, u, dt, k)
        carry[0].block_until_ready()
    elapsed = time.perf_counter() - start

    print(f"{name:10s} | cached run: {elapsed:.6f}s")
# --- Run benchmarks ---
for variant, label in (
    (inner_lambda, "lambda"),
    (inner_def, "def"),
    (carry_all, "carry_all"),
):
    benchmark(variant, label)
# --- The result ---
lambda | cached run: 0.677526s
def | cached run: 0.673526s
carry_all | cached run: 0.686203s
Ich verwende jax.lax.scan, um eine Systemdynamik vorwärts zu simulieren (Rollout). Um die Parameter der Dynamik ([code]dt[/code] und [code]k[/code]) zu erfassen, verwende ich lambda, def oder carry (d. h. ich führe sie im Carry der scan-Schleife mit). Die Ausführungszeit ist fast gleich, wahrscheinlich weil die Dynamik einfach ist. Was ist die JAX-freundlichste und idiomatischste Art, solche Parameter fest einzubauen?[code]import jax
import jax.numpy as jnp
import time
N = 100
def dynamics(x, u, dt, k):
    """Return the next state for state ``x`` and control ``u``.

    A velocity vector ``[cos(x[2]) * u[0], sin(x[2]) * u[0], u[1]]`` is
    built from the third state component and the two control components.
    The result is ``x + vel * dt - k * vel``.
    """
    heading = x[2]
    speed = u[0]
    velocity = jnp.array(
        [jnp.cos(heading) * speed, jnp.sin(heading) * speed, u[1]]
    )
    return x + velocity * dt - k * velocity
def one_step(state, u, dt, k):
    """Scan body that advances the state using column ``kk`` of ``u``.

    ``state`` is a pair ``(x, kk)``. The function returns the new carry
    ``(x_next, kk + 1)`` together with ``x_next`` as the per-step output.
    """
    current_x, step_index = state
    control = u[:, step_index]
    advanced = dynamics(current_x, control, dt, k)
    new_carry = (advanced, step_index + 1)
    return new_carry, advanced
# Variant: close over u, dt and k with a lambda passed to scan.
def inner_lambda(x0, u, dt, k):
    """Run ``N`` scan steps of ``one_step`` starting from ``(x0, 0)``.

    ``u``, ``dt`` and ``k`` are captured by a lambda rather than passed
    through the carry. Returns whatever ``jax.lax.scan`` returns.
    """
    return jax.lax.scan(
        lambda carry, _: one_step(carry, u, dt, k),
        (x0, 0),
        None,
        length=N,
    )
# Variant: close over u, dt and k with a nested function passed to scan.
def inner_def(x0, u, dt, k):
    """Run ``N`` scan steps of ``one_step`` starting from ``(x0, 0)``.

    ``u``, ``dt`` and ``k`` are captured by a locally defined scan body.
    Returns whatever ``jax.lax.scan`` returns.
    """

    def body(carry, _):
        # The scanned input is unused; scan is called with xs=None.
        return one_step(carry, u, dt, k)

    initial_carry = (x0, 0)
    return jax.lax.scan(body, initial_carry, None, length=N)
def one_step_carry(state, input):
    """Scan body whose carry also holds ``u``, ``dt`` and ``k``.

    ``state`` is the tuple ``(x, kk, u, dt, k)``; ``input`` is not used.
    Returns the updated carry (with ``kk`` incremented and the parameters
    passed through unchanged) and ``x_next`` as the per-step output.
    """
    current_x, step_index, controls, step_size, gain = state
    advanced = dynamics(current_x, controls[:, step_index], step_size, gain)
    new_carry = (advanced, step_index + 1, controls, step_size, gain)
    return new_carry, advanced
# Variant: thread u, dt and k through the scan carry.
def carry_all(x0, u, dt, k):
    """Run ``N`` scan steps of ``one_step_carry``.

    The initial carry is ``(x0, 0, u, dt, k)``. Returns whatever
    ``jax.lax.scan`` returns.
    """
    initial_carry = (x0, 0, u, dt, k)
    return jax.lax.scan(one_step_carry, initial_carry, None, length=N)
# Inputs shared by all benchmarked variants.
x0 = jnp.array([2.0, 3.0, 4.0])
key = jax.random.PRNGKey(0)
u = jax.random.uniform(key, (2, N))  # one control column per scan step
dt = 0.01
k = 1e-6
# --- Benchmark helper ---
def benchmark(fn, name):
    """Time 1000 cached calls of the jitted ``fn`` and print the total.

    ``fn`` is wrapped in ``jax.jit`` and called with the module-level
    ``x0``, ``u``, ``dt`` and ``k``. The first call, which includes
    compilation, is made before the timer starts. ``name`` labels the
    printed line.
    """
    fn_jit = jax.jit(fn)

    # Warm-up call (includes compilation); excluded from the timing.
    carry, _ = fn_jit(x0, u, dt, k)
    carry[0].block_until_ready()

    # perf_counter is the clock intended for measuring short intervals;
    # time.time() can jump if the system clock is adjusted.
    start = time.perf_counter()
    for _ in range(1000):
        # Cached execution; block so the work is finished before timing stops.
        carry, _ = fn_jit(x0, u, dt, k)
        carry[0].block_until_ready()
    elapsed = time.perf_counter() - start

    print(f"{name:10s} | cached run: {elapsed:.6f}s")
# --- Run benchmarks ---
for variant, label in (
    (inner_lambda, "lambda"),
    (inner_def, "def"),
    (carry_all, "carry_all"),
):
    benchmark(variant, label)
# --- The result ---
lambda | cached run: 0.677526s
def | cached run: 0.673526s
carry_all | cached run: 0.686203s
[/code]