First of all, I want to say that I am really bad at coding and new to it. Normally my script runs in 14 seconds, but after copying the environment and installing a Jupyter kernel, it sometimes runs in 6 seconds instead. The problem is that this behavior is not consistent: sometimes the copied environment stays stuck at 14 seconds.
from typing import Callable
import diffrax
import jax
import jax.lax as lax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from jaxtyping import Array, Float # https://github.com/google/jaxtyping
jax.config.update("jax_enable_x64", True)
from diffrax import diffeqsolve, ODETerm, SaveAt, Kvaerno5, Kvaerno3, Dopri8, Tsit5, PIDController
import interpax
from jaxopt import Bisection
from jax import jit
import equinox as eqx
from scipy.optimize import brentq
eps = 1e-18
N = 0
# **Global Parameters**
d = 1.0
t0 = 0.0 # Initial time
t_f = 6.0
dt0 = 1e-15 # Initial step size; the PIDController adapts it from here
tol = 1e-6 # Convergence tolerance for η
max_iters = 50 # Max iterations
damping = 0.3 # Damping factor: for damping < 0.5 the iterate stays closer to the current value, for damping > 0.5 it moves closer to the newly computed one
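# i.e. eta_next = (1 - damping) * eta_current + damping * eta_calculated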
n = 1000
# **Define the Differential Equation**
@jit
def vector_field(t, u, args):
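    # State u = (f, y, u_legit): y = f', and u_legit integrates f (d_u_legit = f)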
    f, y, u_legit = u
    α, β = args
    d_f = y
    d_y = (- α * f * (y + 1) ** 2 + β * (t + eps) * y * (y + 1) ** 2
           - (N - 1) * (1 + y) ** 2 * (t + eps + f) ** (-2) * ((t + eps) * y - f))
    d_u_legit = f
    return d_f, d_y, d_u_legit
# **General Solver Function**
@jit
def solve_ode(y_init, alpha, beta, solver=Tsit5()):
    """Solve the ODE for a given initial condition y_init, suppressing unwanted errors."""
    try:
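        # NOTE: under jax.jit this try/except only runs at trace time; diffrax's
        # runtime "max steps reached" error is raised when the compiled function is
        # called, so the try/except in runs_successfully below is what catches it.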
        term = ODETerm(vector_field)
        y0 = (y_init * eps, y_init, 0.0)
        args = (alpha, beta)
        saveat = SaveAt(ts=jnp.linspace(t0, t_f, n))
        sol = diffeqsolve(
            term, solver, t0, t_f, dt0, y0, args=args, saveat=saveat,
            stepsize_controller=PIDController(rtol=1e-16, atol=1e-19),
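            # note: rtol here is below float64 machine epsilon (~2.2e-16)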
            max_steps=1_000_000  # Adjust if needed
        )
        return sol
    except Exception as e:
        if "maximum number of solver steps was reached" in str(e):
            return None  # Ignore this error silently
        else:
            raise  # Let other errors pass through
def runs_successfully(y_init, alpha, beta):
    """Returns True if ODE solver runs without errors, False otherwise."""
    try:
        sol = solve_ode(y_init, alpha, beta)
        if sol is None:
            return False
        return True
    except Exception:  # Catch *any* solver failure quietly
        return False  # Mark this initial condition as a failure
# Track previous successful y_low values
previous_y_lows = []
def bisection_search(alpha, beta, y_low=-0.35, y_high=-0.40, tol=1e-12, max_iter=50):
    """Find boundary value y_low where solver fails, with adaptive bounds."""
    global previous_y_lows
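    # Note: the default y_low (-0.35) is numerically greater than y_high (-0.40);
    # "low"/"high" label the running/failing side of the boundary, not magnitude.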
    cache = {}  # Cache to avoid redundant function calls
    def cached_runs(y):
        """Check if solver runs successfully, with caching."""
        if y in cache:
            return cache[y]
        result = runs_successfully(y, alpha, beta)
        cache[y] = result
        return result
    # Ensure we have a valid starting point
    if not cached_runs(y_low):
        raise ValueError(f"❌ Lower bound y_low={y_low:.12f} must be a running case!")
    if cached_runs(y_high):
        raise ValueError(f"❌ Upper bound y_high={y_high:.12f} must be a failing case!")
    # **Adaptive Search Range**
    if len(previous_y_lows) > 2:
        recent_changes = [abs(previous_y_lows[i] - previous_y_lows[i - 1]) for i in range(1, len(previous_y_lows))]
        max_change = max(recent_changes) if recent_changes else float('inf')
        # **Shrink range dynamically based on recent root stability**
        shrink_factor = 0.5  # Shrink range by a fraction
        if max_change < 0.01:  # If the root is stabilizing
            range_shrink = shrink_factor * abs(y_high - y_low)
            y_low = max(previous_y_lows[-1] - range_shrink, y_low)
            y_high = min(previous_y_lows[-1] + range_shrink, y_high)
            print(f"🔄 Adaptive Bisection: Narrowed range to [{y_low:.12f}, {y_high:.12f}]")
    print(f"⚡ Starting Bisection with Initial Bounds: [{y_low:.12f}, {y_high:.12f}]")
    for i in range(max_iter):
        y_mid = (y_low + y_high) / 2
        success = cached_runs(y_mid)
        print(f"🔎 Iter {i+1}: y_low={y_low:.12f}, y_high={y_high:.12f}, y_mid={y_mid:.12f}, Success={success}")
        if success:
            y_low = y_mid
        else:
            y_high = y_mid
        if abs(y_high - y_low) < tol:
            break
    previous_y_lows.append(y_low)
    # Keep only last 5 values to track root changes efficiently
    if len(previous_y_lows) > 5:
        previous_y_lows.pop(0)
    print(f"✅ Bisection Finished: Final y_low = {y_low:.12f}\n")
    return y_low
# **Compute Anomalous Dimension (Direct Version)**
def compute_anomalous(eta_current):
    """Compute anomalous dimension given the current eta using ODE results directly."""
    alpha = (d / 2 + 1 - eta_current / 2) / (1 - eta_current / (d + 2))
    beta = (d / 2 - 1 + eta_current / 2) / (1 - eta_current / (d + 2))
    # Find y_low
    y_low = bisection_search(alpha, beta)
    # Solve ODE
    sol = solve_ode(y_low, alpha, beta)
    if sol is None:
        print(f"❌ Solver failed for y_low = {y_low:.9f}. Returning NaN for eta.")
        return float('nan')
    # Extract values directly from the ODE solution
    x_vals = jnp.linspace(t0, t_f, n)
    f_p = sol.ys[0]  # First derivative f'
    d2fdx2 = sol.ys[1]  # Second derivative f''
    potential = sol.ys[2]  # Potential function V(x)
    spline = interpax.CubicSpline(x_vals, f_p, bc_type='natural')
    spline_derivative = interpax.CubicSpline(x_vals, d2fdx2, bc_type='natural')
    root_x = brentq(spline, a=1e-4, b=5.0, xtol=1e-12)
    U_k_pp = spline_derivative(root_x)
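    # Autodiff of the spline of f'' gives the third derivative f''' at the root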
    third_spline = jax.grad(lambda x: spline_derivative(x))
    U_k_3p = third_spline(root_x)
    spline_points_prime = [spline(x) for x in x_vals]
    spline_points_dprime = [spline_derivative(x) for x in x_vals]
    print(f"📌 Root found at x = {root_x:.12f}")
    print(f"📌 Derivative at rho_0 = {root_x:.12f} is f'(x) = {U_k_pp:.12f}")
    print(f"  This should be zero {spline(root_x)}")
    # Compute new eta (anomalous dimension)
    eta_new = U_k_3p ** 2 / (1 + U_k_pp) ** 4
    # Debugging: Check if eta_new is NaN or out of range
    if jnp.isnan(eta_new) or eta_new < 0:  # Original: if jnp.isnan(eta_new) or eta_new < 0 or eta_new > 1
        print(f"⚠ Warning: Unphysical eta_new={eta_new:.9f}. Returning NaN.")
        return float('nan')
    # **Plot Results**
    fig, axs = plt.subplots(3, 1, figsize=(10, 9), sharex=True)
    axs[0].plot(x_vals, f_p, color='blue', label="First Derivative (f')")
    axs[0].plot(x_vals, spline_points_prime, color='orange', linestyle='--', label="First Splined Derivative (f')")
    axs[0].axvline(root_x, linestyle='dashed', color='red', label="Potential Minimum")
    axs[0].set_ylabel("f'(x)")
    axs[0].legend()
    axs[0].set_title("First Derivative of f(x)")
    axs[1].plot(x_vals, d2fdx2, color='green', label="Second Derivative (f'')")
    axs[1].plot(x_vals, spline_points_dprime, color='orange', linestyle='--', label="Second Splined Derivative (f'')")
    axs[1].axvline(root_x, linestyle='dashed', color='red', label="Potential Minimum")
    axs[1].set_ylabel("f''(x)")
    axs[1].legend()
    axs[1].set_title("Second Derivative of f(x)")
    axs[2].plot(x_vals, potential, color='black', label="Potential (V(x))")
    axs[2].axvline(root_x, linestyle='dashed', color='red', label="Potential Minimum")
    axs[2].set_xlabel("x")
    axs[2].set_ylabel("V(x)")
    axs[2].legend()
    axs[2].set_title("Potential V(x)")
    plt.tight_layout()
    plt.show()
    return float(eta_new)
# **Iterate to Find Self-Consistent η with Explicit Damping**
def find_self_consistent_eta(eta_init=1.06294761356, tol=1e-4, max_iters=1):
    """Iterate until η converges to a self-consistent value using damping."""
    eta_current = eta_init
    prev_values = []
    for i in range(max_iters):
        eta_calculated = compute_anomalous(eta_current)
        # Check for NaN or invalid values
        if jnp.isnan(eta_calculated):
            print(f"❌ Iteration {i+1}: Computed NaN for eta. Stopping iteration.")
            return None
        # Apply damping explicitly
        eta_next = (1 - damping) * eta_current + damping * eta_calculated
        # Debugging prints
        print(f"Iteration {i+1}:")
        print(f" - η_current (used for ODE) = {eta_current:.12f}")
        print(f" - η_calculated (new from ODE) = {eta_calculated:.12f}")
        print(f" - η_next (damped for next iteration) = {eta_next:.12f}\n")
        # Detect infinite loops (if η keeps bouncing between a few values)
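        # (note: `in` uses exact float equality, so only exact repeats are caught)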
        if eta_next in prev_values:
            print(f"⚠ Warning: η is cycling between values. Consider adjusting damping.")
            return eta_next
        prev_values.append(eta_next)
        # Check for convergence
        if abs(eta_next - eta_current) < tol:
            print("✅ Converged!")
            return eta_next
        eta_current = eta_next
    print("⚠ Did not converge within max iterations.")
    return eta_current
import sys
import os
import time
# Redirect stderr to null (suppress error messages)
sys.stderr = open(os.devnull, 'w')
# Start timer
start_time = time.time()
# Run function (using `jax.jit`, NOT `filter_jit`)
final_eta = find_self_consistent_eta()
# End timer
end_time = time.time()
# Reset stderr back to normal
sys.stderr = sys.__stderr__
# Print results
print(f"\n🎯 Final self-consistent η: {final_eta:.12f}")
print(f"⏱ Execution time: {end_time - start_time:.4f} seconds")
I am using JAX and Diffrax to solve differential equations, with jax==0.4.31 and diffrax==0.6.1, and I am stuck on these versions; newer versions of JAX or Diffrax will not run this at all. After a restart it runs in 14 seconds again. My steps are:
1. Check that the XLA settings are applied to the environment.
2. Start Jupyter Notebook from the Anaconda Navigator.
3. There is an error starting the kernel.
4. Select the new kernel.
5. Run the script.
And this, as I said, sometimes runs in 14 seconds and sometimes in 6. To test, I run:
import numpy as np
import jax.numpy as jnp
from jax import random
x = random.uniform(random.key(0), (10000, 10000))
%time jnp.dot(x, x).block_until_ready()
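For reference, a minimal sketch (not part of the script above) to check which versions, backend, and XLA settings the kernel actually picked up; XLA_FLAGS only prints if it was set in the environment:
import os
import jax
import diffrax

print("jax version:    ", jax.__version__)        # expecting 0.4.31
print("diffrax version:", diffrax.__version__)    # expecting 0.6.1
print("backend:        ", jax.default_backend())  # e.g. "cpu" or "gpu"
print("devices:        ", jax.devices())
print("x64 enabled:    ", jax.config.jax_enable_x64)
print("XLA_FLAGS:      ", os.environ.get("XLA_FLAGS", "<not set>"))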