JAX: `lax.scan` über ein Array mit zusätzlichem Argument – Python

Python-Programme
Anonymous
 JAX: `lax.scan` über ein Array mit zusätzlichem Argument

Post by Anonymous »

Ich verwende `jax.lax.scan`, um eine Systemdynamik vorwärts zu propagieren. Um die Parameter der Dynamik (`dt` und `k`) einzubinden, verwende ich entweder ein `lambda`, eine innere Funktion (`def`) oder den Carry (d. h. ich trage sie in der `scan`-Schleife mit). Die Ausführungszeit ist in allen drei Fällen fast gleich, wahrscheinlich weil die Dynamik einfach ist. Was ist die JAX-freundlichste und idiomatischste Art, Parameter fest in die Schleife einzubinden („einzubacken“)?

Code: Select all

import jax
import jax.numpy as jnp
import time

# Number of scan steps; each step consumes one column of the control array u.
N: int = 100

def dynamics(x, u, dt, k):
    """Return the state after one step of the dynamics.

    The angle ``x[2]`` rotates the first control ``u[0]`` into the first two
    velocity components; ``u[1]`` is used directly as the third component.
    The state is then updated as ``x + vel * dt - k * vel``.

    Args:
        x: Current state; ``x[2]`` is used as an angle.
        u: Controls for this step; ``u[0]`` and ``u[1]`` are read.
        dt: Step size multiplying the velocity.
        k: Coefficient of the extra ``-k * vel`` term.
            NOTE(review): this term is not scaled by ``dt`` — confirm intended.

    Returns:
        ``x`` plus the (3-element) velocity update.
    """
    heading = x[2]
    speed = u[0]
    velocity = jnp.array([
        jnp.cos(heading) * speed,
        jnp.sin(heading) * speed,
        u[1],
    ])
    # Same evaluation order as the original expression: (x + v*dt) - k*v.
    return x + velocity * dt - k * velocity

def one_step(state, u, dt, k):
    """Scan body: advance the state one step using control column ``kk``.

    Args:
        state: Carry tuple ``(x, kk)`` — current state and step counter.
        u: Control array; column ``u[:, kk]`` is used at this step.
        dt: Forwarded unchanged to ``dynamics``.
        k: Forwarded unchanged to ``dynamics``.

    Returns:
        ``((x_next, kk + 1), x_next)`` — the new carry, and the per-step
        output that ``lax.scan`` stacks.
    """
    x, step = state
    control = u[:, step]
    x_new = dynamics(x, control, dt, k)
    new_carry = (x_new, step + 1)
    return new_carry, x_new

# uses a lambda for scan
def inner_lambda(x0, u, dt, k):
    """Roll the dynamics forward ``N`` steps, closing over ``u``, ``dt``, ``k``.

    The parameters are captured by a lambda, so they are not part of the
    scan carry; only ``(x, kk)`` is threaded through the loop.

    Args:
        x0: Initial state.
        u: Control array indexed as ``u[:, kk]``; needs at least ``N`` columns.
        dt: Step-size parameter forwarded to the step function.
        k: Parameter forwarded to the step function.

    Returns:
        ``((x_final, N), xs)`` as returned by ``jax.lax.scan``; ``xs`` stacks
        the state after every step.

    Raises:
        ValueError: if ``u`` has fewer than ``N`` columns.
    """
    # JAX clamps out-of-bounds gather indices instead of raising, so with too
    # few columns u[:, kk] would silently repeat the last control. The shape
    # is static under jit, so this check runs at trace time.
    if u.shape[1] < N:
        raise ValueError(
            f"u must have at least N={N} columns, got {u.shape[1]}"
        )
    # scan calls the body as f(carry, x); with xs=None the per-step input is
    # always None, hence `_` (the previous name `input` shadowed the builtin).
    step_fn = lambda state, _: one_step(state, u, dt, k)  # noqa: E731
    state0 = (x0, 0)
    return jax.lax.scan(step_fn, state0, None, length=N)

# defines a new inner function for scan
def inner_def(x0, u, dt, k):
    """Roll the dynamics forward ``N`` steps using a nested step function.

    The nested function closes over ``u``, ``dt`` and ``k``; only
    ``(x, kk)`` is carried through the scan.

    Args:
        x0: Initial state.
        u: Control array indexed as ``u[:, kk]``; needs at least ``N`` columns.
        dt: Step-size parameter forwarded to the step function.
        k: Parameter forwarded to the step function.

    Returns:
        ``((x_final, N), xs)`` as returned by ``jax.lax.scan``.

    Raises:
        ValueError: if ``u`` has fewer than ``N`` columns.
    """
    # JAX clamps out-of-bounds gather indices, so a short `u` would silently
    # reuse its last column; fail loudly at trace time instead.
    if u.shape[1] < N:
        raise ValueError(
            f"u must have at least N={N} columns, got {u.shape[1]}"
        )

    def step_fn(state, _):
        # `_` is the per-step scan input, always None here (xs=None).
        return one_step(state, u, dt, k)

    return jax.lax.scan(step_fn, (x0, 0), None, length=N)

def one_step_carry(state, input):  # noqa: A002 - name kept for interface stability
    """Scan body that threads the parameters through the carry.

    Args:
        state: Carry tuple ``(x, kk, u, dt, k)``.
        input: Per-step scan input; not used.

    Returns:
        ``((x_next, kk + 1, u, dt, k), x_next)`` — ``u``, ``dt`` and ``k``
        are passed back unchanged so they stay in the carry.
    """
    x, step, controls, step_size, coef = state
    x_new = dynamics(x, controls[:, step], step_size, coef)
    new_carry = (x_new, step + 1, controls, step_size, coef)
    return new_carry, x_new

# carry everything for scan
def carry_all(x0, u, dt, k):
    """Roll the dynamics forward ``N`` steps with all parameters in the carry.

    Args:
        x0: Initial state.
        u: Control array indexed as ``u[:, kk]``; needs at least ``N`` columns.
        dt: Step-size parameter, carried through the scan.
        k: Parameter carried through the scan.

    Returns:
        ``((x_final, N, u, dt, k), xs)`` as returned by ``jax.lax.scan``.

    Raises:
        ValueError: if ``u`` has fewer than ``N`` columns.
    """
    # JAX clamps out-of-bounds gather indices, so a short `u` would silently
    # reuse its last column; fail loudly at trace time instead.
    if u.shape[1] < N:
        raise ValueError(
            f"u must have at least N={N} columns, got {u.shape[1]}"
        )
    init_carry = (x0, 0, u, dt, k)
    return jax.lax.scan(one_step_carry, init_carry, None, length=N)

# Initial state; dynamics() treats the third entry as an angle.
x0 = jnp.array([2.0, 3.0, 4.0])

# Scalar parameters of the dynamics.
dt = 0.01
k = 1e-6

# Random controls: 2 rows, one column per scan step.
key = jax.random.PRNGKey(0)
u = jax.random.uniform(key, shape=(2, N))

# --- Benchmark helper ---
def benchmark(fn, name, n_runs=1000):
    """JIT-compile ``fn`` and print the total time of ``n_runs`` cached calls.

    The first call triggers compilation and is excluded from the timing.
    Each call is synchronised on its whole output pytree, so JAX's
    asynchronous dispatch cannot make the loop look faster than the work.

    Args:
        fn: Function called as ``fn(x0, u, dt, k)`` with the module-level inputs.
        name: Label printed in the result line.
        n_runs: Number of timed calls (default 1000, as before).
    """
    fn_jit = jax.jit(fn)

    # Warm-up call: compiles and caches the executable.
    jax.block_until_ready(fn_jit(x0, u, dt, k))

    # perf_counter is monotonic and high-resolution, unlike time.time().
    t0 = time.perf_counter()
    for _ in range(n_runs):
        jax.block_until_ready(fn_jit(x0, u, dt, k))
    elapsed = time.perf_counter() - t0

    print(f"{name:10s} | cached run: {elapsed:.6f}s")

# --- Run benchmarks ---
# Same order as before: lambda closure, nested def, everything in the carry.
for variant, label in (
    (inner_lambda, "lambda"),
    (inner_def, "def"),
    (carry_all, "carry_all"),
):
    benchmark(variant, label)

# --- The result ---
lambda     | cached run: 0.677526s
def        | cached run: 0.673526s
carry_all  | cached run: 0.686203s

Quick Reply

Change Text Case: 
   
  • Similar Topics
    Replies
    Views
    Last post