Page 1 of 1

Warum ist die manuelle (-ish) Berechnung von Protokoll-Sum-Exp schneller als Numpy/Scipy-Funktionen?

Posted: 26 Aug 2025, 11:13
by Anonymous
Ich arbeite mit einem sehr großen Array und war überrascht, dass die manuelle Berechnung des Protokolls der Summe der Exponentials tatsächlich schneller ist als die integrierten Funktionen für diese Aufgabe von Scipy (scipy (

Code: Select all

scipy.special.logsumexp
) und sogar numpy (

Code: Select all

numpy.logaddexp.reduce
)
Ich habe angenommen, dass die Bibliotheksimplementierungen optimierter sind. Mache ich etwas Dummes oder vermisse etwas Offensichtliches? < /P>

Code: Select all

# test_logsumexp.py

from functools import partial
import timeit
import numpy as np
from scipy.special import logsumexp

def logsumexp_manual(a, axis=-1):
max_vals = np.max(a, axis=axis)
exp_terms = np.exp(a - max_vals[..., np.newaxis])
sum_exp_terms = np.sum(exp_terms, axis=axis)
return max_vals + np.log(sum_exp_terms)

# Validity check
arr = np.random.rand(10_000, 30, 3)
scipy_result = logsumexp(arr, axis=-1)
numpy_result = np.logaddexp.reduce(arr, axis=-1)
manual_result = logsumexp_manual(arr)
np.testing.assert_allclose(scipy_result, manual_result)
np.testing.assert_allclose(scipy_result, numpy_result)

n_loops = 10
setup = ("import numpy as np; "
"from scipy.special import logsumexp, softmax; "
"arr = np.random.rand(10_000, 30, 3)")

total_scipy = timeit.timeit(
"logsumexp(arr, axis=-1)",
setup=setup,
number=n_loops,
)

total_numpy = timeit.timeit(
"np.logaddexp.reduce(arr, axis=-1)",
setup=setup,
number=n_loops,
)

logsumexp_partial = partial(logsumexp_manual, a=arr)
total_manual = timeit.timeit(
logsumexp_partial,
setup=setup,
number=n_loops,
)

print(f"Scipy  logsumexp:  {total_scipy  / n_loops:.6f} seconds per loop")
print(f"Numpy  logaddexp:  {total_numpy  / n_loops:.6f} seconds per loop")
print(f"Manual logsumexp:  {total_manual / n_loops:.6f} seconds per loop")
smallest = min(total_scipy, total_numpy, total_manual)
< /code>
 Konsolenausgabe < /h1>
Scipy  logsumexp:   0.070824 seconds per loop
Numpy  logaddexp:   0.044226 seconds per loop
Manual logsumexp:   0.031754 seconds per loop