import operator

import numpy as np
import sympy

import sr
def test_2():
    """Check that SR recovers sin(x)*exp(x) from noiseless samples.

    The model is configured with ``sin`` and ``exp`` as unary operators and
    ``*`` as the only binary operator, each given as a (sympy, numpy/operator)
    callable pair.  After one search iteration the test expects exactly one
    best expression, symbolically equal to ``sin(x)*exp(x)``.
    """
    model = sr.SR(
        niterations=1,
        unary_operators={
            "sin": (sympy.sin, np.sin),
            "exp": (sympy.exp, np.exp),
        },
        binary_operators={"*": (operator.mul, operator.mul)},
        discrete_param_values=["(-5, 5)"],
        foundBreak=True,
        maxfev=200,
    )
    n = 100
    xmin, xmax = -5, 10
    # Seeded local generator: the sampled data is identical on every run
    # instead of depending on unseeded global NumPy random state.
    rng = np.random.default_rng(0)
    x = rng.uniform(xmin, xmax, n)
    y = np.sin(x) * np.exp(x)
    model.predict([x], y, ["x"])
    assert len(model.bestExpressions) == 1
    assert model.bestExpressions[0][0] == sympy.sympify("sin(x)*exp(x)")
</code>
Meine Implementierung der Funktion fit startet mehrere curve_fit-Aufrufe mit zufälligen Startwerten und rundet die Ergebnisse anschließend (für diskrete Werte):</p>
def fit(func, value_vars, y, p0, loss_func, eps, maxfev, discrete_values=None):
    """Fit the parameters of ``func`` to the data ``(value_vars, y)``.

    With no ``discrete_values`` this is a single ``curve_fit`` call starting
    from ``p0``.  With ``discrete_values`` it runs a random-restart search:
    up to ``maxfev`` times, ``curve_fit`` is started from a random discrete
    point, the fitted parameters are rounded back onto the discrete values,
    and the candidate with the lowest ``loss_func`` value is kept.

    Args:
        func: Model called as ``func(value_vars, *params)``.
        value_vars: Independent variable(s), passed to ``func``/``curve_fit``.
        y: Observed values to fit.
        p0: Initial parameter guess; ``len(p0)`` is the number of parameters.
        loss_func: Called as ``loss_func(func(value_vars, *params), y)``; only
            used in discrete mode.
        eps: Discrete mode stops early once the best loss drops below this.
        maxfev: ``maxfev`` for the initial ``curve_fit`` and the number of
            random restarts; each restart's ``curve_fit`` gets ``10 * maxfev``.
        discrete_values: Allowed discrete parameter values.  ``None`` or empty
            means a purely continuous fit.

    Returns:
        ``p0`` itself if the initial ``curve_fit`` raises ``RuntimeError``
        (the error is printed).  Otherwise the fitted parameters in
        continuous mode, or the best rounded candidate in discrete mode.
        In discrete mode this is the initial random candidate if it could
        not be evaluated or no restart improved on it.
    """
    # Initial fit from p0, shared by both modes.  In discrete mode the fitted
    # values are not used; only a convergence failure matters (abort with p0).
    try:
        value_params, _pcov = curve_fit(func, value_vars, y, p0=p0, maxfev=maxfev)
    except RuntimeError as e:
        print(e)
        return p0
    if discrete_values is None or len(discrete_values) == 0:
        return value_params

    best_x = random_discrete_values(len(p0), discrete_values)
    try:
        best_loss = loss_func(func(value_vars, *best_x), y)
    except (ValueError, OverflowError):
        # The starting candidate cannot be scored, so return it unchanged.
        return best_x

    for _attempt in range(maxfev):
        try:
            value_params, _pcov = curve_fit(
                func,
                value_vars,
                y,
                p0=random_discrete_values(len(p0), discrete_values),
                maxfev=10 * maxfev,
            )
            x = round_discrete_values(value_params, discrete_values)
            # Scored exactly like the starting candidate above.
            loss = loss_func(func(value_vars, *x), y)
            # any(x): a candidate whose parameters are all zero (falsy) is
            # never accepted.
            improved = loss < best_loss and any(x)
        except (RuntimeError, ValueError, OverflowError):
            # Best-effort restart: this random start did not converge or gave
            # an unusable candidate, so move on to the next one.
            continue
        if improved:
            best_loss, best_x = loss, x
            if best_loss < eps:
                break
    return best_x
</code>
Hier ist ein MCVE mit Optuna, das ebenfalls fehlschlägt:</p>
import math
import numpy as np
import random
import sympy
def range_discrete_values(discrete_values):
type_, min_, max_ = int, math.inf, -math.inf
for value in discrete_values:
if (type(value) is float):
type_ = float
min_ = min(min_, value)
max_ = max(max_, value)
elif (type(value) == str):
s = value
a, b = [float(x) for x in s[1:-1].split(",")]
assert(a