Warum schwankt die Ausführungszeit meines JAX-Skripts in identischen Umgebungen zufällig zwischen 6 s und 14 s? – Python

Python-Programme
Anonymous
 Warum schwankt die Ausführungszeit meines JAX-Skripts in identischen Umgebungen zufällig zwischen 6 s und 14 s?

Post by Anonymous »

Vorweg: Ich bin Programmieranfänger und noch nicht sehr geübt. Normalerweise läuft mein Skript in 14 Sekunden, aber nachdem ich die Umgebung kopiert und dafür einen Jupyter-Kernel installiert habe, läuft es manchmal stattdessen in 6 Sekunden. Das Problem ist, dass dieses Verhalten nicht konsistent ist – manchmal bleibt auch die kopierte Umgebung bei 14 Sekunden.

Code: Select all

from typing import Callable

import diffrax
import jax
import jax.lax as lax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from jaxtyping import Array, Float  # https://github.com/google/jaxtyping

# Switch JAX to 64-bit floats (its default is 32-bit). The very small
# tolerances and offsets used later in this file (eps=1e-18, rtol=1e-16)
# need double precision.
jax.config.update("jax_enable_x64", True)

from diffrax import diffeqsolve, ODETerm, SaveAt, Kvaerno5, Kvaerno3, Dopri8, Tsit5, PIDController
import interpax
from jaxopt import Bisection
from jax import jit
import equinox as eqx
from scipy.optimize import brentq

# --- Model constants -------------------------------------------------------
# Tiny offset added to t inside vector_field (keeps t away from exactly 0);
# solve_ode also scales the initial f by it.
eps = 1e-18
# Integer coefficient; vector_field's last term is multiplied by (N - 1).
N = 0
# Parameter entering alpha/beta in compute_anomalous (presumably the
# spatial dimension -- confirm).
d = 1.0

# --- Integration window and output grid ------------------------------------
t0, t_f = 0.0, 6.0  # start / end of the integration interval
dt0 = 1e-15         # initial step handed to diffeqsolve (adapted afterwards)
n = 1000            # number of equally spaced save points on [t0, t_f]

# --- Self-consistency iteration ---------------------------------------------
# NOTE: tol and max_iters are shadowed by same-named function parameters.
tol = 1e-6
max_iters = 50
# Weight of the newly computed eta in eta_next = (1 - damping) * eta + damping * eta_new;
# values below 0.5 stay closer to the previous eta.
damping = 0.3

# **Define the Differential Equation**
@jit
def vector_field(t, u, args):
    """Right-hand side of the ODE system for the state ``u = (f, y, u_legit)``.

    Args:
        t: integration variable.
        u: 3-tuple ``(f, y, u_legit)``; the third entry is not read.
        args: pair ``(alpha, beta)`` of ODE coefficients.

    Returns:
        Tuple ``(df/dt, dy/dt, du_legit/dt) = (y, rhs, f)``. Here ``rhs`` is
        the alpha/beta expression below, with ``t`` offset by the global
        ``eps`` and the last term scaled by the global ``(N - 1)``.
    """
    f, y, _ = u
    alpha, beta = args

    shifted_t = t + eps
    slope_factor = (y + 1) ** 2

    # Same operand order as the original formula, so results are bit-identical.
    source = -alpha * f * slope_factor + beta * shifted_t * y * slope_factor
    correction = (N - 1) * slope_factor * (shifted_t + f) ** (-2) * (shifted_t * y - f)

    return y, source - correction, f

# **General Solver Function**

@jit
def _solve_ode_raw(y_init, alpha, beta, solver):
    """Run diffeqsolve under jit without raising on solver failure.

    ``throw=False`` makes diffrax report failures (e.g. hitting ``max_steps``)
    through ``sol.result``. Without it, diffrax raises from inside the
    compiled computation, where a Python try/except written in the traced
    function can never catch the error.
    """
    term = ODETerm(vector_field)
    # State order (f, y, u_legit) matches vector_field.
    y0 = (y_init * eps, y_init, 0.0)
    saveat = SaveAt(ts=jnp.linspace(t0, t_f, n))
    return diffeqsolve(
        term, solver, t0, t_f, dt0, y0,
        args=(alpha, beta),
        saveat=saveat,
        stepsize_controller=PIDController(rtol=1e-16, atol=1e-19),
        max_steps=1_000_000,  # Adjust if needed
        throw=False,
    )


def solve_ode(y_init, alpha, beta, solver=Tsit5()):
    """Solve the ODE for initial value y_init, suppressing max-steps failures.

    Args:
        y_init: initial value of ``y``; ``f`` starts at ``y_init * eps``.
        alpha, beta: coefficients passed to vector_field as ``args``.
        solver: diffrax solver instance (default Tsit5).

    Returns:
        The diffrax Solution saved on ``linspace(t0, t_f, n)``, or None if
        the solver reached ``max_steps``.

    Raises:
        RuntimeError: if the solve failed for any other reason.
    """
    sol = _solve_ode_raw(y_init, alpha, beta, solver)
    # Converting the status flag to a Python bool waits for the compiled solve,
    # so this check sees the actual outcome.
    if bool(sol.result == diffrax.RESULTS.successful):
        return sol
    if bool(sol.result == diffrax.RESULTS.max_steps_reached):
        return None  # Ignore this failure silently
    raise RuntimeError(
        f"diffrax solve failed for y_init={y_init!r} "
        "with a result other than success or max_steps_reached"
    )

def runs_successfully(y_init, alpha, beta):
    """Report whether the ODE can be integrated from y_init with (alpha, beta).

    Returns:
        True if solve_ode returned a solution whose arrays were computed
        without error. False if it returned None or raised anything.
    """
    try:
        sol = solve_ode(y_init, alpha, beta)
        if sol is None:
            return False
        # JAX dispatches compiled work asynchronously, so an error raised
        # inside the solve may only surface once the results are waited on.
        # Block here so such an error is caught by this try.
        jax.block_until_ready(sol)
    except Exception:  # Deliberately broad: any solver failure means "does not run".
        return False
    return True

# Track previous successful y_low values (most recent last, at most 5 kept)
previous_y_lows = []


def _narrow_bracket(center, y_low, y_high, shrink_factor):
    """Return (new_low, new_high) offset by shrink_factor * |y_high - y_low| from center.

    Works for either ordering of (y_low, y_high). new_low is moved from
    center toward y_low and new_high toward y_high, and both are clamped to
    the closed interval spanned by the original bounds.
    """
    half_width = shrink_factor * abs(y_high - y_low)
    toward_high = 1.0 if y_high >= y_low else -1.0
    lower_edge, upper_edge = min(y_low, y_high), max(y_low, y_high)

    def clamp(value):
        return min(max(value, lower_edge), upper_edge)

    return (clamp(center - toward_high * half_width),
            clamp(center + toward_high * half_width))


def bisection_search(alpha, beta, y_low=-0.35, y_high=-0.40, tol=1e-12, max_iter=50,
                     predicate=None):
    """Find the boundary initial value between solver success and failure.

    Args:
        alpha, beta: ODE coefficients forwarded to the predicate.
        y_low: a value for which the predicate holds (a "running" case).
            It may be larger or smaller than y_high.
        y_high: a value for which the predicate does not hold.
        tol: stop once |y_high - y_low| < tol.
        max_iter: maximum number of bisection steps.
        predicate: callable ``(y, alpha, beta) -> bool``; defaults to
            runs_successfully.

    Returns:
        The last running value y_low (within tol of the boundary if converged).

    Raises:
        ValueError: if y_low does not run or y_high does run.

    Side effects:
        Prints progress. Appends the result to the module-level
        previous_y_lows (keeping at most 5 entries). Later calls use that
        history to try a narrower starting bracket.
    """
    if predicate is None:
        predicate = runs_successfully

    cache = {}  # y -> predicate result, avoids repeating expensive solves

    def cached_runs(y):
        """Evaluate predicate(y, alpha, beta) at most once per distinct y."""
        if y not in cache:
            cache[y] = predicate(y, alpha, beta)
        return cache[y]

    # Ensure we have a valid starting point
    if not cached_runs(y_low):
        raise ValueError(f"❌ Lower bound y_low={y_low:.12f} must be a running case!")
    if cached_runs(y_high):
        raise ValueError(f"❌ Upper bound y_high={y_high:.12f} must be a failing case!")

    # **Adaptive Search Range**
    if len(previous_y_lows) > 2:
        max_change = max(abs(b - a) for a, b in zip(previous_y_lows, previous_y_lows[1:]))
        shrink_factor = 0.5  # Shrink range by a fraction
        if max_change < 0.01:  # If the root is stabilizing
            new_low, new_high = _narrow_bracket(previous_y_lows[-1], y_low, y_high, shrink_factor)
            # Accept the narrowed bracket only if it still brackets the boundary;
            # otherwise bisection would silently converge to an endpoint.
            if cached_runs(new_low) and not cached_runs(new_high):
                y_low, y_high = new_low, new_high
                print(f"🔄 Adaptive Bisection: Narrowed range to [{y_low:.12f}, {y_high:.12f}]")
            else:
                print("🔄 Adaptive Bisection: narrowed range does not bracket the boundary; "
                      "keeping original bounds")

    print(f"⚡ Starting Bisection with Initial Bounds: [{y_low:.12f}, {y_high:.12f}]")

    for i in range(max_iter):
        y_mid = (y_low + y_high) / 2
        success = cached_runs(y_mid)

        print(f"🔎 Iter {i+1}: y_low={y_low:.12f}, y_high={y_high:.12f}, y_mid={y_mid:.12f}, Success={success}")

        if success:
            y_low = y_mid
        else:
            y_high = y_mid

        if abs(y_high - y_low) < tol:
            break

    previous_y_lows.append(y_low)
    # Keep only the last 5 values to track root changes
    del previous_y_lows[:-5]

    print(f"✅ Bisection Finished: Final y_low = {y_low:.12f}\n")

    return y_low

# **Compute Anomalous Dimension (Direct Version)**
def _plot_solution(x_vals, f_p, d2fdx2, potential, spline_points_prime, spline_points_dprime, root_x):
    """Plot the three saved ODE components, their splines, and the spline root."""
    _fig, axs = plt.subplots(3, 1, figsize=(10, 9), sharex=True)

    axs[0].plot(x_vals, f_p, color='blue', label="First Derivative (f')")
    axs[0].plot(x_vals, spline_points_prime, color='orange', linestyle='--', label="First Splined Derivative (f')")
    axs[0].axvline(root_x, linestyle='dashed', color='red', label="Potential Minimum")
    axs[0].set_ylabel("f'(x)")
    axs[0].legend()
    axs[0].set_title("First Derivative of f(x)")

    axs[1].plot(x_vals, d2fdx2, color='green', label="Second Derivative (f'')")
    axs[1].plot(x_vals, spline_points_dprime, color='orange', linestyle='--', label="Second Splined Derivative (f'')")
    axs[1].axvline(root_x, linestyle='dashed', color='red', label="Potential Minimum")
    axs[1].set_ylabel("f''(x)")
    axs[1].legend()
    axs[1].set_title("Second Derivative of f(x)")

    axs[2].plot(x_vals, potential, color='black', label="Potential (V(x))")
    axs[2].axvline(root_x, linestyle='dashed', color='red', label="Potential Minimum")
    axs[2].set_xlabel("x")
    axs[2].set_ylabel("V(x)")
    axs[2].legend()
    axs[2].set_title("Potential V(x)")

    plt.tight_layout()
    plt.show()


def compute_anomalous(eta_current):
    """Compute a new anomalous dimension eta from one ODE solve at ``eta_current``.

    Steps:
    1. Derive (alpha, beta) from eta_current and the global d.
    2. Find the boundary initial value y_low with bisection_search and solve
       the ODE there.
    3. Fit natural cubic splines to the first two saved components.
    4. Locate the root of the first spline on [1e-4, 5.0] with brentq.
    5. Return U_k_3p**2 / (1 + U_k_pp)**4, where U_k_pp is the second spline
       at the root and U_k_3p is its derivative there.

    The solution is also plotted.

    Args:
        eta_current: current estimate of eta.

    Returns:
        eta_new as a Python float, or NaN if the solve at y_low failed or
        eta_new is NaN or negative.

    Raises:
        ValueError: from brentq if the first spline has no sign change on
            [1e-4, 5.0], or from bisection_search if its default bounds do
            not bracket the boundary.
    """
    alpha = (d / 2 + 1 - eta_current / 2) / (1 - eta_current / (d + 2))
    beta = (d / 2 - 1 + eta_current / 2) / (1 - eta_current / (d + 2))

    # Find y_low
    y_low = bisection_search(alpha, beta)

    # Solve ODE
    sol = solve_ode(y_low, alpha, beta)
    if sol is None:
        print(f"❌ Solver failed for y_low = {y_low:.9f}. Returning NaN for eta.")
        return float('nan')

    # Same grid as the SaveAt points used by solve_ode.
    x_vals = jnp.linspace(t0, t_f, n)
    f_p = sol.ys[0]        # component f of the state (labelled f' in the plots)
    d2fdx2 = sol.ys[1]     # component y = df/dt (labelled f'')
    potential = sol.ys[2]  # component u_legit = integral of f (labelled V(x))

    spline = interpax.CubicSpline(x_vals, f_p, bc_type='natural')
    spline_derivative = interpax.CubicSpline(x_vals, d2fdx2, bc_type='natural')

    root_x = brentq(spline, a=1e-4, b=5.0, xtol=1e-12)

    U_k_pp = spline_derivative(root_x)
    third_spline = jax.grad(lambda x: spline_derivative(x))
    U_k_3p = third_spline(root_x)

    # Evaluate each spline on the whole grid in one vectorized call instead of
    # 2 * n separate scalar calls (each its own JAX dispatch).
    spline_points_prime = spline(x_vals)
    spline_points_dprime = spline_derivative(x_vals)

    print(f"📌 Root found at x = {root_x:.12f}")
    print(f"📌 Derivative at rho_0 = {root_x:.12f} is f'(x) = {U_k_pp:.12f}")
    print(f" This should be zero {spline(root_x)}")

    # Compute new eta (anomalous dimension)
    eta_new = U_k_3p ** 2 / (1 + U_k_pp) ** 4

    # Reject NaN or negative eta. (An upper bound eta_new > 1 was previously
    # also rejected; that check is currently disabled.)
    if jnp.isnan(eta_new) or eta_new < 0:
        print(f"⚠ Warning: Unphysical eta_new={eta_new:.9f}.  Returning NaN.")
        return float('nan')

    _plot_solution(x_vals, f_p, d2fdx2, potential, spline_points_prime, spline_points_dprime, root_x)

    return float(eta_new)

# **Iterate to Find Self-Consistent η with Explicit Damping**
def find_self_consistent_eta(eta_init=1.06294761356, tol=1e-4, max_iters=1):
    """Run the damped fixed-point iteration for eta.

    Each step computes eta_calculated = compute_anomalous(eta). It then forms
    eta_next = (1 - damping) * eta + damping * eta_calculated, using the
    global ``damping``.

    Args:
        eta_init: starting value of eta.
        tol: convergence threshold on |eta_next - eta|.
        max_iters: maximum number of iterations (0 returns eta_init at once).

    Returns:
        eta_next when converged or when a previously seen value repeats.
        Otherwise the last eta after max_iters steps, or None if
        compute_anomalous produced NaN.
    """
    eta = eta_init
    seen = []

    for step in range(1, max_iters + 1):
        eta_calculated = compute_anomalous(eta)

        if jnp.isnan(eta_calculated):
            print(f"❌ Iteration {step}: Computed NaN for eta. Stopping iteration.")
            return None

        eta_next = (1 - damping) * eta + damping * eta_calculated

        print(f"Iteration {step}:")
        print(f"  - η_current (used for ODE) = {eta:.12f}")
        print(f"  - η_calculated (new from ODE) = {eta_calculated:.12f}")
        print(f"  - η_next (damped for next iteration) = {eta_next:.12f}\n")

        # An exactly repeated value means the iteration is cycling.
        if eta_next in seen:
            print(f"⚠ Warning: η is cycling between values. Consider adjusting damping.")
            return eta_next
        seen.append(eta_next)

        has_converged = abs(eta_next - eta) < tol
        if has_converged:
            print("✅ Converged!")
            return eta_next

        eta = eta_next

    print("⚠ Did not converge within max iterations.")
    return eta

import contextlib
import sys
import os
import time

# Silence Python-level stderr only while the computation runs.
# redirect_stderr restores whatever sys.stderr was before; in Jupyter that is
# the kernel's notebook stream, which differs from sys.__stderr__. The
# with-statement also closes the devnull handle, even if the run raises.
# NOTE: the timed region includes JIT compilation of the solver, every solve
# done during the bisection, and plotting (plt.show), so it is not a pure
# solver timing.
with open(os.devnull, "w") as devnull, contextlib.redirect_stderr(devnull):
    start_time = time.perf_counter()  # monotonic clock meant for intervals
    final_eta = find_self_consistent_eta()
    end_time = time.perf_counter()

# find_self_consistent_eta returns None when an iteration produced NaN.
if final_eta is None:
    print("\n🎯 Final self-consistent η: None (iteration stopped on NaN)")
else:
    print(f"\n🎯 Final self-consistent η: {final_eta:.12f}")
print(f"⏱ Execution time: {end_time - start_time:.4f} seconds")
</code>
Ich verwende JAX und Diffrax, um Differentialgleichungen zu lösen, konkret JAX 0.4.31 und Diffrax 0.6.1. Ich bin auf diese Versionen festgelegt, weil das Skript mit neueren Versionen von JAX oder Diffrax überhaupt nicht läuft. Nach einem Neustart läuft es wieder in 14 Sekunden. Meine Schritte:
1. Prüfen, ob die XLA-Einstellungen in der Umgebung angewendet werden.
2. Jupyter Notebook über den Anaconda Navigator starten.
3. Beim Starten des Kernels tritt ein Fehler auf.
4. Den neuen Kernel auswählen.
5. Das Skript ausführen.

Das läuft dann, wie gesagt, manchmal in 14 Sekunden und manchmal in 6 Sekunden. Tests:
# Quick JAX matrix-multiply timing check, presumably run in the same Jupyter
# kernel to compare environments.
# NOTE: `%time` is an IPython/Jupyter magic, not plain Python syntax.
import numpy as np
import jax.numpy as jnp
from jax import random
# 10000 x 10000 matrix of uniform random numbers from a fixed PRNG key.
x = random.uniform(random.key(0), (10000, 10000))
# block_until_ready() waits for the asynchronously dispatched result, so the
# measured time covers the matmul itself rather than just its dispatch.
%time jnp.dot(x, x).block_until_ready()

Quick Reply

Change Text Case: 
   
  • Similar Topics
    Replies
    Views
    Last post