Why does my JAX script randomly alternate between 6s and 14s execution times in identical environments?
Posted: 17 Mar 2025, 00:34
First of all, I want to say that I'm really new to coding and not very good at it. Normally my script runs in 14 seconds, but after cloning the environment and installing a Jupyter kernel it sometimes runs in 6 seconds instead. The problem is that this behavior is not consistent - sometimes the cloned environment stays stuck at 14 seconds.
Code: Select all
from typing import Callable
import diffrax
import jax
import jax.lax as lax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from jaxtyping import Array, Float # https://github.com/google/jaxtyping
jax.config.update("jax_enable_x64", True)
from diffrax import diffeqsolve, ODETerm, SaveAt, Kvaerno5, Kvaerno3, Dopri8, Tsit5, PIDController
import interpax
from jaxopt import Bisection
from jax import jit
import equinox as eqx
from scipy.optimize import brentq
eps = 1e-18
N = 0
# **Global Parameters**
d = 1.0
t0 = 0.0 # Initial time
t_f = 6.0
dt0 = 1e-15 # Initial step size
tol = 1e-6 # Convergence tolerance for η
max_iters = 50 # Max iterations
damping = 0.3 # ✅ Small damping factor: below 0.5 stays closer to the current value, above 0.5 moves closer to the new one
n = 1000
# **Define the Differential Equation**
@jit
def vector_field(t, u, args):
    f, y, u_legit = u
    α, β = args
    d_f = y
    d_y = (- α * f * (y + 1) ** 2 + β * (t + eps) * y * (y + 1) ** 2
           - (N - 1) * (1 + y) ** 2 * (t + eps + f) ** (-2) * ((t + eps) * y - f))
    d_u_legit = f
    return d_f, d_y, d_u_legit
# **General Solver Function**
@jit
def solve_ode(y_init, alpha, beta, solver=Tsit5()):
    """Solve the ODE for a given initial condition y_init, suppressing unwanted errors."""
    try:
        term = ODETerm(vector_field)
        y0 = (y_init * eps, y_init, 0.0)
        args = (alpha, beta)
        saveat = SaveAt(ts=jnp.linspace(t0, t_f, n))
        sol = diffeqsolve(
            term, solver, t0, t_f, dt0, y0, args=args, saveat=saveat,
            stepsize_controller=PIDController(rtol=1e-16, atol=1e-19),
            max_steps=1_000_000  # Adjust if needed
        )
        return sol
    except Exception as e:
        if "maximum number of solver steps was reached" in str(e):
            return None  # Ignore this error silently
        else:
            raise  # Let other errors pass through
def runs_successfully(y_init, alpha, beta):
    """Returns True if the ODE solver runs without errors, False otherwise."""
    try:
        sol = solve_ode(y_init, alpha, beta)
        return sol is not None
    except Exception:  # Catch *any* solver failure quietly
        return False   # Mark this initial condition as a failure
# Track previous successful y_low values
previous_y_lows = []
def bisection_search(alpha, beta, y_low=-0.35, y_high=-0.40, tol=1e-12, max_iter=50):
    """Find the boundary value y_low where the solver starts to fail, with adaptive bounds."""
    global previous_y_lows
    cache = {}  # Cache to avoid redundant solver calls

    def cached_runs(y):
        """Check if the solver runs successfully, with caching."""
        if y in cache:
            return cache[y]
        result = runs_successfully(y, alpha, beta)
        cache[y] = result
        return result

    # Ensure we have a valid starting bracket
    if not cached_runs(y_low):
        raise ValueError(f"❌ Lower bound y_low={y_low:.12f} must be a running case!")
    if cached_runs(y_high):
        raise ValueError(f"❌ Upper bound y_high={y_high:.12f} must be a failing case!")

    # **Adaptive Search Range**
    if len(previous_y_lows) > 2:
        recent_changes = [abs(previous_y_lows[i] - previous_y_lows[i - 1]) for i in range(1, len(previous_y_lows))]
        max_change = max(recent_changes) if recent_changes else float('inf')
        # **Shrink range dynamically based on recent root stability**
        shrink_factor = 0.5  # Shrink the range by this fraction
        if max_change < 0.01:  # If the root is stabilizing
            range_shrink = shrink_factor * abs(y_high - y_low)
            y_low = max(previous_y_lows[-1] - range_shrink, y_low)
            y_high = min(previous_y_lows[-1] + range_shrink, y_high)
            print(f"🔄 Adaptive Bisection: Narrowed range to [{y_low:.12f}, {y_high:.12f}]")

    print(f"⚡ Starting Bisection with Initial Bounds: [{y_low:.12f}, {y_high:.12f}]")
    for i in range(max_iter):
        y_mid = (y_low + y_high) / 2
        success = cached_runs(y_mid)
        print(f"🔎 Iter {i+1}: y_low={y_low:.12f}, y_high={y_high:.12f}, y_mid={y_mid:.12f}, Success={success}")
        if success:
            y_low = y_mid
        else:
            y_high = y_mid
        if abs(y_high - y_low) < tol:
            break

    previous_y_lows.append(y_low)
    # Keep only the last 5 values to track root changes efficiently
    if len(previous_y_lows) > 5:
        previous_y_lows.pop(0)

    print(f"✅ Bisection Finished: Final y_low = {y_low:.12f}\n")
    return y_low
# **Compute Anomalous Dimension (Direct Version)**
def compute_anomalous(eta_current):
    """Compute the anomalous dimension for the current eta using the ODE results directly."""
    alpha = (d / 2 + 1 - eta_current / 2) / (1 - eta_current / (d + 2))
    beta = (d / 2 - 1 + eta_current / 2) / (1 - eta_current / (d + 2))

    # Find y_low
    y_low = bisection_search(alpha, beta)

    # Solve ODE
    sol = solve_ode(y_low, alpha, beta)
    if sol is None:
        print(f"❌ Solver failed for y_low = {y_low:.9f}. Returning NaN for eta.")
        return float('nan')

    # Extract values directly from the ODE solution
    x_vals = jnp.linspace(t0, t_f, n)
    f_p = sol.ys[0]        # First derivative f'
    d2fdx2 = sol.ys[1]     # Second derivative f''
    potential = sol.ys[2]  # Potential function V(x)

    spline = interpax.CubicSpline(x_vals, f_p, bc_type='natural')
    spline_derivative = interpax.CubicSpline(x_vals, d2fdx2, bc_type='natural')
    root_x = brentq(spline, a=1e-4, b=5.0, xtol=1e-12)
    U_k_pp = spline_derivative(root_x)
    third_spline = jax.grad(lambda x: spline_derivative(x))
    U_k_3p = third_spline(root_x)
    spline_points_prime = [spline(x) for x in x_vals]
    spline_points_dprime = [spline_derivative(x) for x in x_vals]

    print(f"📌 Root found at x = {root_x:.12f}")
    print(f"📌 Derivative at rho_0 = {root_x:.12f} is f'(x) = {U_k_pp:.12f}")
    print(f"   This should be zero: {spline(root_x)}")

    # Compute new eta (anomalous dimension)
    eta_new = U_k_3p ** 2 / (1 + U_k_pp) ** 4

    # Debugging: Check if eta_new is NaN or out of range
    if jnp.isnan(eta_new) or eta_new < 0:  # Original: if jnp.isnan(eta_new) or eta_new < 0 or eta_new > 1
        print(f"⚠ Warning: Unphysical eta_new={eta_new:.9f}. Returning NaN.")
        return float('nan')

    # **Plot Results**
    fig, axs = plt.subplots(3, 1, figsize=(10, 9), sharex=True)
    axs[0].plot(x_vals, f_p, color='blue', label="First Derivative (f')")
    axs[0].plot(x_vals, spline_points_prime, color='orange', linestyle='--', label="First Splined Derivative (f')")
    axs[0].axvline(root_x, linestyle='dashed', color='red', label="Potential Minimum")
    axs[0].set_ylabel("f'(x)")
    axs[0].legend()
    axs[0].set_title("First Derivative of f(x)")
    axs[1].plot(x_vals, d2fdx2, color='green', label="Second Derivative (f'')")
    axs[1].plot(x_vals, spline_points_dprime, color='orange', linestyle='--', label="Second Splined Derivative (f'')")
    axs[1].axvline(root_x, linestyle='dashed', color='red', label="Potential Minimum")
    axs[1].set_ylabel("f''(x)")
    axs[1].legend()
    axs[1].set_title("Second Derivative of f(x)")
    axs[2].plot(x_vals, potential, color='black', label="Potential (V(x))")
    axs[2].axvline(root_x, linestyle='dashed', color='red', label="Potential Minimum")
    axs[2].set_xlabel("x")
    axs[2].set_ylabel("V(x)")
    axs[2].legend()
    axs[2].set_title("Potential V(x)")
    plt.tight_layout()
    plt.show()

    return float(eta_new)
# **Iterate to Find Self-Consistent η with Explicit Damping**
def find_self_consistent_eta(eta_init=1.06294761356, tol=1e-4, max_iters=1):
    """Iterate until η converges to a self-consistent value using damping."""
    eta_current = eta_init
    prev_values = []

    for i in range(max_iters):
        eta_calculated = compute_anomalous(eta_current)

        # Check for NaN or invalid values
        if jnp.isnan(eta_calculated):
            print(f"❌ Iteration {i+1}: Computed NaN for eta. Stopping iteration.")
            return None

        # Apply damping explicitly
        eta_next = (1 - damping) * eta_current + damping * eta_calculated

        # Debugging prints
        print(f"Iteration {i+1}:")
        print(f"  - η_current (used for ODE) = {eta_current:.12f}")
        print(f"  - η_calculated (new from ODE) = {eta_calculated:.12f}")
        print(f"  - η_next (damped for next iteration) = {eta_next:.12f}\n")

        # Detect cycles (if η keeps bouncing between a few values)
        if eta_next in prev_values:
            print("⚠ Warning: η is cycling between values. Consider adjusting damping.")
            return eta_next
        prev_values.append(eta_next)

        # Check for convergence
        if abs(eta_next - eta_current) < tol:
            print("✅ Converged!")
            return eta_next
        eta_current = eta_next

    print("⚠ Did not converge within max iterations.")
    return eta_current
import sys
import os
import time
# Redirect stderr to null (suppress error messages)
sys.stderr = open(os.devnull, 'w')
# Start timer
start_time = time.time()
# Run function (using `jax.jit`, NOT `filter_jit`)
final_eta = find_self_consistent_eta()
# End timer
end_time = time.time()
# Reset stderr back to normal
sys.stderr = sys.__stderr__
# Print results
print(f"\n🎯 Final self-consistent η: {final_eta:.12f}")
print(f"⏱ Execution time: {end_time - start_time:.4f} seconds")
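While debugging, I also thought about logging recompilations to see whether the slow runs are recompiling while the fast runs hit a cache (a sketch I have not fully verified in my setup):
Code: Select all
# Log every XLA compilation; if the 14-second runs print compile messages
# and the 6-second runs don't, the difference is compile time, not the solver.
jax.config.update("jax_log_compiles", True)
final_eta = find_self_consistent_eta()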
I'm using JAX and Diffrax to solve differential equations, with jax==0.4.31 and diffrax==0.6.1, and I'm tied to exactly these versions - newer versions of JAX or Diffrax won't run this at all. After a restart it runs in 14 seconds again. My steps are (see the version-check sketch after this list):
1. Check whether the XLA settings are applied to the environment
2. Start Jupyter Notebook from the Anaconda Navigator
3. There is an error starting the kernel, so I select the new kernel
4. Run the script
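To make sure the two environments really are identical, I check the versions and XLA-related settings like this (a minimal sketch of what I run; it only uses standard jax/diffrax attributes):
Code: Select all
import os
import jax
import diffrax

# Confirm both environments run the pinned versions
print("jax:", jax.__version__)          # expecting 0.4.31
print("diffrax:", diffrax.__version__)  # expecting 0.6.1

# Compare XLA-related settings between the environments
print("XLA_FLAGS:", os.environ.get("XLA_FLAGS"))
print("backend:", jax.default_backend())
print("x64 enabled:", jax.config.jax_enable_x64)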
And this, as I said, sometimes runs in 14 seconds and sometimes in 6. Tests:
Code: Select all
import numpy as np
import jax.numpy as jnp
from jax import random
x = random.uniform(random.key(0), (10000, 10000))
%time jnp.dot(x, x).block_until_ready()
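The first %time above includes JIT compilation of the dot kernel, so as a variant I also time a warmed-up call (sketch, assuming IPython's %timeit magic is available):
Code: Select all
x = random.uniform(random.key(0), (10000, 10000))
jnp.dot(x, x).block_until_ready()           # warm-up call pays the one-off compile cost
%timeit jnp.dot(x, x).block_until_ready()   # times only the cached, compiled kernel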