
Why does my JAX script randomly switch between 6 s and 14 s execution times in identical environments?

Posted: 17 Mar 2025, 00:34
by Anonymous
First of all, I should say that I'm really new to (and bad at) coding. Normally my script runs in 14 seconds, but after cloning the environment and installing a Jupyter kernel, it sometimes runs in 6 seconds instead. The problem is that this behavior is not consistent - sometimes the cloned environment stays stuck at 14 seconds.
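To check whether the difference is compile time rather than run time, the first and second call of a jitted function can be timed separately (a generic sketch, not my actual script; `f` is just a stand-in for the jitted work):

Code:

import time
import jax
import jax.numpy as jnp

@jax.jit
def f(x):  # stand-in for the real computation
    return jnp.dot(x, x)

x = jnp.ones((2000, 2000))

t0 = time.perf_counter()
f(x).block_until_ready()  # first call: trace + compile + run
print("cold (compile + run):", time.perf_counter() - t0)

t0 = time.perf_counter()
f(x).block_until_ready()  # later calls: run the cached executable only
print("warm (run only):", time.perf_counter() - t0)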

Code:

from typing import Callable

import diffrax
import jax
import jax.lax as lax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from jaxtyping import Array, Float  # https://github.com/google/jaxtyping

jax.config.update("jax_enable_x64", True)

from diffrax import diffeqsolve, ODETerm, SaveAt, Kvaerno5, Kvaerno3, Dopri8, Tsit5, PIDController
import interpax
from jaxopt import Bisection
from jax import jit
import equinox as eqx
from scipy.optimize import brentq

eps = 1e-18
N = 0

# **Global Parameters**
d = 1.0
t0 = 0.0   # Initial time
t_f = 6.0
dt0 = 1e-15  # Initial step size
tol = 1e-6  # Convergence tolerance for η
max_iters = 50  # Max iterations
damping = 0.3  # ✅ Damping factor: damping < 0.5 stays closer to the current value, damping > 0.5 moves closer to the new one
n = 1000

# **Define the Differential Equation**
@jit
def vector_field(t, u, args):
    f, y, u_legit = u
    α, β = args
    d_f = y
    d_y = (-α * f * (y + 1) ** 2 + β * (t + eps) * y * (y + 1) ** 2
           - (N - 1) * (1 + y) ** 2 * (t + eps + f) ** (-2) * ((t + eps) * y - f))
    d_u_legit = f
    return d_f, d_y, d_u_legit

# **General Solver Function**

@jit
def solve_ode(y_init, alpha, beta, solver=Tsit5()):
"""Solve the ODE for a given initial condition y_init, suppressing unwanted errors."""
try:
term = ODETerm(vector_field)
y0 = (y_init * eps, y_init, 0.0)
args = (alpha, beta)
saveat = SaveAt(ts=jnp.linspace(t0, t_f, n))

sol = diffeqsolve(
term, solver, t0, t_f, dt0, y0, args=args, saveat=saveat,
stepsize_controller=PIDController(rtol=1e-16, atol=1e-19),
max_steps=1_000_000  # Adjust if needed
)
return sol
except Exception as e:
if "maximum number of solver steps was reached" in str(e):
return None  # Ignore this error silently
else:
raise  # Let other errors pass through

def runs_successfully(y_init, alpha, beta):
"""Returns True if ODE solver runs without errors, False otherwise."""
try:
sol = solve_ode(y_init, alpha, beta)
if sol is None:
return False
return True
except Exception:  # Catch *any* solver failure quietly
return False  # Mark this initial condition as a failure

# Track previous successful y_low values
previous_y_lows = []

def bisection_search(alpha, beta, y_low=-0.35, y_high=-0.40, tol=1e-12, max_iter=50):
"""Find boundary value y_low where solver fails, with adaptive bounds."""

global previous_y_lows
cache = {}  # Cache to avoid redundant function calls

def cached_runs(y):
"""Check if solver runs successfully, with caching."""
if y in cache:
return cache[y]
result = runs_successfully(y, alpha, beta)
cache[y] = result
return result

# Ensure we have a valid starting point
if not cached_runs(y_low):
raise ValueError(f"❌ Lower bound y_low={y_low:.12f} must be a running case!")
if cached_runs(y_high):
raise ValueError(f"❌ Upper bound y_high={y_high:.12f} must be a failing case!")

# **Adaptive Search Range**
if len(previous_y_lows) > 2:
recent_changes = [abs(previous_y_lows[i] - previous_y_lows[i - 1]) for i in range(1, len(previous_y_lows))]
max_change = max(recent_changes) if recent_changes else float('inf')

# **Shrink range dynamically based on recent root stability**
shrink_factor = 0.5  # Shrink range by a fraction
if max_change <  0.01:  # If the root is stabilizing
range_shrink = shrink_factor * abs(y_high - y_low)
y_low = max(previous_y_lows[-1] - range_shrink, y_low)
y_high = min(previous_y_lows[-1] + range_shrink, y_high)
print(f"🔄 Adaptive Bisection: Narrowed range to [{y_low:.12f}, {y_high:.12f}]")

print(f"⚡ Starting Bisection with Initial Bounds: [{y_low:.12f}, {y_high:.12f}]")

for i in range(max_iter):
y_mid = (y_low + y_high) / 2
success = cached_runs(y_mid)

print(f"🔎 Iter {i+1}: y_low={y_low:.12f}, y_high={y_high:.12f}, y_mid={y_mid:.12f}, Success={success}")

if success:
y_low = y_mid
else:
y_high = y_mid

if abs(y_high - y_low) < tol:
break

previous_y_lows.append(y_low)

# Keep only last 5 values to track root changes efficiently
if len(previous_y_lows) > 5:
previous_y_lows.pop(0)

print(f"✅ Bisection Finished: Final y_low = {y_low:.12f}\n")

return y_low

# **Compute Anomalous Dimension (Direct Version)**
def compute_anomalous(eta_current):
"""Compute anomalous dimension given the current eta using ODE results directly."""
alpha = (d / 2 + 1 - eta_current / 2) / (1 - eta_current / (d + 2))
beta = (d / 2 - 1 + eta_current / 2) / (1 - eta_current / (d + 2))

# Find y_low
y_low = bisection_search(alpha, beta)

# Solve ODE
sol = solve_ode(y_low, alpha, beta)
if sol is None:
print(f"❌ Solver failed for y_low = {y_low:.9f}. Returning NaN for eta.")
return float('nan')

# Extract values directly from the ODE solution
x_vals = jnp.linspace(t0, t_f, n)
f_p = sol.ys[0]  # First derivative f'
d2fdx2 = sol.ys[1]  # Second derivative f''
potential = sol.ys[2]  # Potential function V(x)

spline = interpax.CubicSpline(x_vals, f_p, bc_type='natural')
spline_derivative = interpax.CubicSpline(x_vals, d2fdx2, bc_type='natural')

root_x = brentq(spline, a=1e-4, b=5.0, xtol=1e-12)

U_k_pp = spline_derivative(root_x)

third_spline = jax.grad(lambda x: spline_derivative(x))
U_k_3p = third_spline(root_x)

spline_points_prime = [spline(x) for x in x_vals]
spline_points_dprime = [spline_derivative(x) for x in x_vals]

print(f"📌 Root found at x = {root_x:.12f}")
print(f"📌 Derivative at rho_0 = {root_x:.12f} is f'(x) = {U_k_pp:.12f}")
print(f" This should be zero {spline(root_x)}")

# Compute new eta (anomalous dimension)
eta_new = U_k_3p ** 2 / (1 + U_k_pp) ** 4

# Debugging: Check if eta_new is NaN or out of range
if jnp.isnan(eta_new) or eta_new < 0: # Original : if jnp.isnan(eta_new) or eta_new < 0 or eta_new > 1
print(f"⚠ Warning: Unphysical eta_new={eta_new:.9f}.  Returning NaN.")
return float('nan')

# **Plot Results**
fig, axs = plt.subplots(3, 1, figsize=(10, 9), sharex=True)

axs[0].plot(x_vals, f_p, color='blue', label="First Derivative (f')")
axs[0].plot(x_vals, spline_points_prime, color='orange', linestyle='--' ,label="First Splined Derivative (f')")
axs[0].axvline(root_x, linestyle='dashed', color='red', label="Potential Minimum")
axs[0].set_ylabel("f'(x)")
axs[0].legend()
axs[0].set_title("First Derivative of f(x)")

axs[1].plot(x_vals, d2fdx2, color='green', label="Second Derivative (f'')")
axs[1].plot(x_vals, spline_points_dprime, color='orange', linestyle='--', label="Second Splined Derivative (f'')")
axs[1].axvline(root_x, linestyle='dashed', color='red', label="Potential Minimum")
axs[1].set_ylabel("f''(x)")
axs[1].legend()
axs[1].set_title("Second Derivative of f(x)")

axs[2].plot(x_vals, potential, color='black', label="Potential (V(x))")
axs[2].axvline(root_x, linestyle='dashed', color='red', label="Potential Minimum")
axs[2].set_xlabel("x")
axs[2].set_ylabel("V(x)")
axs[2].legend()
axs[2].set_title("Potential V(x)")

plt.tight_layout()
plt.show()

return float(eta_new)

# **Iterate to Find Self-Consistent η with Explicit Damping**
def find_self_consistent_eta(eta_init=1.06294761356, tol=1e-4, max_iters=1):
"""Iterate until η converges to a self-consistent value using damping."""
eta_current = eta_init
prev_values = []

for i in range(max_iters):
eta_calculated = compute_anomalous(eta_current)

# Check for NaN or invalid values
if jnp.isnan(eta_calculated):
print(f"❌ Iteration {i+1}: Computed NaN for eta. Stopping iteration.")
return None

# Apply damping explicitly
eta_next = (1 - damping) * eta_current + damping * eta_calculated

# Debugging prints
print(f"Iteration {i+1}:")
print(f"  - η_current (used for ODE) = {eta_current:.12f}")
print(f"  - η_calculated (new from ODE) = {eta_calculated:.12f}")
print(f"  - η_next (damped for next iteration) = {eta_next:.12f}\n")

# Detect infinite loops (if η keeps bouncing between a few values)
if eta_next in prev_values:
print(f"⚠ Warning: η is cycling between values. Consider adjusting damping.")
return eta_next
prev_values.append(eta_next)

# Check for convergence
if abs(eta_next - eta_current) < tol:
print("✅ Converged!")
return eta_next

eta_current = eta_next

print("⚠ Did not converge within max iterations.")
return eta_current

import sys
import os
import time

# Redirect stderr to null (suppress error messages)
sys.stderr = open(os.devnull, 'w')

# Start timer
start_time = time.time()

# Run function (using `jax.jit`, NOT `filter_jit`)
final_eta = find_self_consistent_eta()

# End timer
end_time = time.time()

# Reset stderr back to normal
sys.stderr = sys.__stderr__

# Print results
print(f"\n🎯 Final self-consistent η: {final_eta:.12f}")
print(f"⏱ Execution time: {end_time - start_time:.4f} seconds")
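If the gap turns out to be compile time, one way to make runs comparable across environments and kernel restarts is JAX's persistent compilation cache (a minimal sketch; the cache directory is a placeholder, and the option names assume jax >= 0.4.26):

Code:

import jax

# Persist compiled executables to disk so later runs reuse them.
jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")  # placeholder path
# Also cache small and quick-to-compile programs (the defaults skip them).
jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)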
I'm using JAX and Diffrax to solve differential equations, with jax==0.4.31 and diffrax==0.6.1, and I'm stuck with these versions; newer versions of JAX or Diffrax won't run it at all. After a restart it runs in 14 seconds again. My steps are:

1. Copy the environment
2. Check that the XLA settings are applied
3. Start Jupyter Notebook from the Anaconda Navigator
4. There is an error starting the kernel
5. Select the new kernel
6. Run the script

And this, as I said, sometimes runs in 14 seconds and sometimes in 6.

Tests:

Code:
import numpy as np
import jax.numpy as jnp
from jax import random
x = random.uniform(random.key(0), (10000, 10000))
%time jnp.dot(x, x).block_until_ready()
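Note that the first `%time` run includes XLA compilation of the matmul for that shape and dtype; a warmed-up, steady-state measurement would look like this (same benchmark, assuming an IPython session):

Code:

# Warm-up: the first call triggers compilation for this shape/dtype.
jnp.dot(x, x).block_until_ready()
# Steady-state timing, averaged over many repetitions (IPython magic).
%timeit jnp.dot(x, x).block_until_ready()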