Warum ist die manuelle (-ish) Berechnung von Protokoll-Sum-Exp schneller als Numpy/Scipy-Funktionen?
Posted: 26 Aug 2025, 11:13
Ich arbeite mit einem sehr großen Array und war überrascht, dass die manuelle Berechnung des Protokolls der Summe der Exponentials tatsächlich schneller ist als die integrierten Funktionen für diese Aufgabe von Scipy (scipy () und sogar numpy ()
Ich habe angenommen, dass die Bibliotheksimplementierungen optimierter sind. Mache ich etwas Dummes oder vermisse etwas Offensichtliches? < /P>
Code: Select all
scipy.special.logsumexp
Code: Select all
numpy.logaddexp.reduce
Ich habe angenommen, dass die Bibliotheksimplementierungen optimierter sind. Mache ich etwas Dummes oder vermisse etwas Offensichtliches? < /P>
Code: Select all
# test_logsumexp.py
from functools import partial
import timeit
import numpy as np
from scipy.special import logsumexp
def logsumexp_manual(a, axis=-1):
max_vals = np.max(a, axis=axis)
exp_terms = np.exp(a - max_vals[..., np.newaxis])
sum_exp_terms = np.sum(exp_terms, axis=axis)
return max_vals + np.log(sum_exp_terms)
# Validity check
arr = np.random.rand(10_000, 30, 3)
scipy_result = logsumexp(arr, axis=-1)
numpy_result = np.logaddexp.reduce(arr, axis=-1)
manual_result = logsumexp_manual(arr)
np.testing.assert_allclose(scipy_result, manual_result)
np.testing.assert_allclose(scipy_result, numpy_result)
n_loops = 10
setup = ("import numpy as np; "
"from scipy.special import logsumexp, softmax; "
"arr = np.random.rand(10_000, 30, 3)")
total_scipy = timeit.timeit(
"logsumexp(arr, axis=-1)",
setup=setup,
number=n_loops,
)
total_numpy = timeit.timeit(
"np.logaddexp.reduce(arr, axis=-1)",
setup=setup,
number=n_loops,
)
logsumexp_partial = partial(logsumexp_manual, a=arr)
total_manual = timeit.timeit(
logsumexp_partial,
setup=setup,
number=n_loops,
)
print(f"Scipy logsumexp: {total_scipy / n_loops:.6f} seconds per loop")
print(f"Numpy logaddexp: {total_numpy / n_loops:.6f} seconds per loop")
print(f"Manual logsumexp: {total_manual / n_loops:.6f} seconds per loop")
smallest = min(total_scipy, total_numpy, total_manual)
< /code>
Konsolenausgabe < /h1>
Scipy logsumexp: 0.070824 seconds per loop
Numpy logaddexp: 0.044226 seconds per loop
Manual logsumexp: 0.031754 seconds per loop