Warum ist Array-Manipulation in JAX viel langsamer? – Python

Python-Programme
Anonymous
Warum ist Array-Manipulation in JAX viel langsamer?

Post by Anonymous »

Ich arbeite daran, eine transformationslastige numerische Pipeline von NumPy nach JAX zu portieren, um die JIT-Beschleunigung zu nutzen. Dabei habe ich jedoch festgestellt, dass einige grundlegende Operationen wie `broadcast_to` und `moveaxis` in JAX – selbst mit JIT – deutlich langsamer sind als in NumPy, und zwar selbst bei großen Batchgrößen wie 3.000.000, bei denen ich erwarten würde, dass JAX viel schneller ist.

Code: Select all

### Benchmark: moveaxis + broadcast_to ###
NumPy: moveaxis + broadcast_to → 0.000116 s
JAX: moveaxis + broadcast_to → 0.204249 s
JAX JIT: moveaxis + broadcast_to → 0.054713 s

### Benchmark: broadcast_to only ###
NumPy: broadcast_to → 0.000059 s
JAX: broadcast_to → 0.062167 s
JAX JIT: broadcast_to → 0.057625 s
Mache ich etwas falsch? Gibt es bessere Möglichkeiten, diese Art von Manipulationen auszuführen?

import timeit

import jax
import jax.numpy as jnp
import numpy as np
from jax import jit

# Base transformation matrix: the 4x4 identity with 0.5 placed in entry
# [0, 3]. np.eye yields float64, matching the dtype NumPy infers for a
# literal that mixes ints and the float 0.5.
M_np = np.eye(4)
M_np[0, 3] = 0.5

# The same matrix as a JAX array. NOTE(review): JAX's default 32-bit mode
# would store this as float32 unless x64 is enabled -- confirm if it matters.
M_jax = jnp.array(M_np)

# Batch size: how many copies of the matrix each benchmark produces.
n = 1_000_000

print("### Benchmark: moveaxis + broadcast_to ###")

# All timings below are the TOTAL over `number=10` runs (timeit semantics),
# not a per-call time.

# NumPy
# NOTE: np.broadcast_to and np.moveaxis both return *views* of M_np -- no
# data is copied -- so this measures only view construction, which does not
# grow with n. JAX has no array views, so the JAX variants below build a full
# (n, 4, 4) array on the device; the two numbers measure different work.
t_numpy = timeit.timeit(
    lambda: np.moveaxis(np.broadcast_to(M_np[:, :, None], (4, 4, n)), 2, 0),
    number=10,
)
print(f"NumPy: moveaxis + broadcast_to → {t_numpy:.6f} s")


# JAX (eager, op-by-op dispatch)
def _eager_broadcast_and_move_jax():
    """Run the eager JAX pipeline once and block until the result is ready."""
    return jnp.moveaxis(
        jnp.broadcast_to(M_jax[:, :, None], (4, 4, n)), 2, 0
    ).block_until_ready()


# Warm-up: eager JAX compiles and caches a kernel per primitive and shape on
# first use. Without this call the first timed iteration would include that
# compilation, unlike the JIT benchmark, which is warmed up before timing.
_eager_broadcast_and_move_jax()

t_jax = timeit.timeit(_eager_broadcast_and_move_jax, number=10)
print(f"JAX: moveaxis + broadcast_to → {t_jax:.6f} s")

# JAX JIT
def _broadcast_and_move(M, batch=None):
    """Return ``batch`` stacked copies of the 2-D matrix ``M``, batch axis first.

    Mirrors the NumPy pipeline: append a trailing axis, broadcast it to
    ``batch``, then move that axis to the front.

    Args:
        M: 2-D array (the benchmark passes the 4x4 ``M_jax``).
        batch: number of copies. Defaults to the module-level ``n``. It is a
            static argument under ``jit``, so each distinct value compiles
            once.

    Returns:
        Array of shape ``(batch, *M.shape)``.
    """
    if batch is None:
        batch = n  # module-level batch size, read when the function is traced
    # jnp.broadcast_to(M, (batch, *M.shape)) gives the same result without
    # the explicit axis move; the two-step form is kept so this benchmark
    # performs the same operations as the NumPy variant.
    return jnp.moveaxis(jnp.broadcast_to(M[:, :, None], (*M.shape, batch)), 2, 0)


# Shapes must be concrete Python ints under jit, hence the static argument.
broadcast_and_move_jax = jit(_broadcast_and_move, static_argnames="batch")

# Warm-up: the first call traces and compiles, so it is kept out of the timing.
broadcast_and_move_jax(M_jax).block_until_ready()


def _run_jit_broadcast_and_move():
    """Make one compiled call and wait for the device so the timing is complete."""
    return broadcast_and_move_jax(M_jax).block_until_ready()


t_jit = timeit.timeit(_run_jit_broadcast_and_move, number=10)
print(f"JAX JIT: moveaxis + broadcast_to → {t_jit:.6f} s")

print("\n### Benchmark: broadcast_to only ###")

# NumPy
# NOTE: np.broadcast_to returns a read-only *view* (no copy), so this times
# only view construction. jnp.broadcast_to below has no view semantics and
# produces a full (4, 4, n) array, so the comparison is not like-for-like.
t_numpy_b = timeit.timeit(
    lambda: np.broadcast_to(M_np[:, :, None], (4, 4, n)),
    number=10,
)
print(f"NumPy: broadcast_to → {t_numpy_b:.6f} s")


# JAX (eager, op-by-op dispatch)
def _eager_broadcast_only_jax():
    """Run eager jnp.broadcast_to once and block until the result is ready."""
    return jnp.broadcast_to(M_jax[:, :, None], (4, 4, n)).block_until_ready()


# Warm-up: keeps first-use compilation of the eager primitive out of the
# timed runs, matching how the JIT benchmark is measured.
_eager_broadcast_only_jax()

t_jax_b = timeit.timeit(_eager_broadcast_only_jax, number=10)
print(f"JAX: broadcast_to → {t_jax_b:.6f} s")

# JAX JIT
def _broadcast_only(M, batch=None):
    """Broadcast the 2-D matrix ``M`` along a new trailing batch axis.

    Args:
        M: 2-D array (the benchmark passes the 4x4 ``M_jax``).
        batch: size of the new trailing axis. Defaults to the module-level
            ``n``. It is a static argument under ``jit``.

    Returns:
        Array of shape ``(*M.shape, batch)`` whose every trailing slice is ``M``.
    """
    if batch is None:
        batch = n  # module-level batch size, read when the function is traced
    return jnp.broadcast_to(M[:, :, None], (*M.shape, batch))


# Shapes must be concrete Python ints under jit, hence the static argument.
broadcast_only_jax = jit(_broadcast_only, static_argnames="batch")

# Warm-up: trigger tracing and compilation before anything is timed.
broadcast_only_jax(M_jax).block_until_ready()


def _run_jit_broadcast_only():
    """Make one compiled broadcast_to call and block until the result is ready."""
    return broadcast_only_jax(M_jax).block_until_ready()


t_jit_b = timeit.timeit(_run_jit_broadcast_only, number=10)
print(f"JAX JIT: broadcast_to → {t_jit_b:.6f} s")

Quick Reply

Change Text Case: 
   
  • Similar Topics
    Replies
    Views
    Last post