[SOLVED] Why is this function slower in JAX vs numpy?

Issue

I have the numpy function below, which I’m trying to optimize using JAX, but the JAX version turns out to be slower.

Could someone point out how to improve the performance here? I suspect the list comprehension used for Cg_new is responsible, but breaking it apart doesn’t yield any performance gain in JAX.

import numpy as np 

def testFunction_numpy(C, Mi, C_new, Mi_new):
    Wg_new = np.zeros((len(Mi_new[:,0]), len(Mi[0])))
    Cg_new = np.zeros((1, len(Mi[0])))
    invertCsensor_new = np.linalg.inv(C_new)

    Wg_new = np.dot(invertCsensor_new, Mi_new)
    Cg_new = [np.dot(((-0.5*(Mi_new[:,m].conj().T))), (Wg_new[:,m])) for m in range(0, len(Mi[0]))] 

    return C_new, Mi_new, Wg_new, Cg_new

C = np.random.rand(483,483)
Mi = np.random.rand(483,8)
C_new = np.random.rand(198,198)
Mi_new = np.random.rand(198,8)

%timeit testFunction_numpy(C, Mi, C_new, Mi_new)
#1000 loops, best of 3: 1.73 ms per loop

Here’s the JAX equivalent:

import jax.numpy as jnp
import numpy as np
import jax

def testFunction_JAX(C, Mi, C_new, Mi_new):
    Wg_new = jnp.zeros((len(Mi_new[:,0]), len(Mi[0])))
    Cg_new = jnp.zeros((1, len(Mi[0])))
    invertCsensor_new = jnp.linalg.inv(C_new)

    Wg_new = jnp.dot(invertCsensor_new, Mi_new)
    Cg_new = [jnp.dot(((-0.5*(Mi_new[:,m].conj().T))), (Wg_new[:,m])) for m in range(0, len(Mi[0]))] 

    return C_new, Mi_new, Wg_new, Cg_new

C = np.random.rand(483,483)
Mi = np.random.rand(483,8)
C_new = np.random.rand(198,198)
Mi_new = np.random.rand(198,8)

C = jnp.asarray(C)
Mi = jnp.asarray(Mi)
C_new = jnp.asarray(C_new)
Mi_new = jnp.asarray(Mi_new)

jitter = jax.jit(testFunction_JAX) 

%timeit jitter(C, Mi, C_new, Mi_new)
#1 loop, best of 3: 4.96 ms per loop

Solution

For general considerations on benchmark comparisons between JAX and NumPy, see https://jax.readthedocs.io/en/latest/faq.html#is-jax-faster-than-numpy

As for your particular code: when JAX jit compilation encounters Python control flow, including list comprehensions, it effectively flattens the loop and stages the full sequence of operations. This can lead to slow jit compile times and suboptimal code. Fortunately, the list comprehension in your function is readily expressed in terms of native numpy broadcasting. Additionally, there are two other improvements you can make:

  • there is no need to forward declare Wg_new and Cg_new before computing them
  • when computing dot(inv(A), B), it is much more efficient and precise to use np.linalg.solve rather than explicitly computing the inverse.
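To make the loop flattening concrete, here is a minimal sketch (not part of the original answer) that uses jax.make_jaxpr to compare how many operations are staged for the list comprehension versus the broadcast form. The function names per_column_loop and vectorized and the all-ones inputs are made up for illustration, and the exact equation counts depend on the JAX version, so none are quoted here.

```python
import jax
import jax.numpy as jnp

def per_column_loop(M, W):
    # Python-level loop: during tracing, the body is staged once per column.
    return [jnp.dot(-0.5 * M[:, m].conj(), W[:, m]) for m in range(M.shape[1])]

def vectorized(M, W):
    # Same quantity via broadcasting: one multiply and one reduction over rows.
    return -0.5 * (M.conj() * W).sum(0)

# Illustrative inputs with the same shape as Mi_new in the question.
M = jnp.ones((198, 8))
W = jnp.ones((198, 8))

# make_jaxpr traces the function and returns the staged operations.
loop_jaxpr = jax.make_jaxpr(per_column_loop)(M, W)
vec_jaxpr = jax.make_jaxpr(vectorized)(M, W)

n_loop_eqns = len(loop_jaxpr.jaxpr.eqns)
n_vec_eqns = len(vec_jaxpr.jaxpr.eqns)
print("staged equations, loop version:      ", n_loop_eqns)
print("staged equations, vectorized version:", n_vec_eqns)
```

The loop version's equation count grows with the number of columns, while the vectorized version's stays fixed.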

Making these three improvements to both the numpy and JAX versions results in the following:

def testFunction_numpy_v2(C, Mi, C_new, Mi_new):
    Wg_new = np.linalg.solve(C_new, Mi_new)
    Cg_new = -0.5 * (Mi_new.conj() * Wg_new).sum(0)
    return C_new, Mi_new, Wg_new, Cg_new

@jax.jit
def testFunction_JAX_v2(C, Mi, C_new, Mi_new):
    Wg_new = jnp.linalg.solve(C_new, Mi_new)
    Cg_new = -0.5 * (Mi_new.conj() * Wg_new).sum(0)
    return C_new, Mi_new, Wg_new, Cg_new

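# C, Mi, C_new, Mi_new are the numpy arrays from the question; make JAX copies
C_jax, Mi_jax = jnp.asarray(C), jnp.asarray(Mi)
C_new_jax, Mi_new_jax = jnp.asarray(C_new), jnp.asarray(Mi_new)

%timeit testFunction_numpy_v2(C, Mi, C_new, Mi_new)
# 1000 loops, best of 3: 1.11 ms per loop
%timeit testFunction_JAX_v2(C_jax, Mi_jax, C_new_jax, Mi_new_jax)
# 1000 loops, best of 3: 1.35 ms per loop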
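As a sanity check that the rewrite computes the same quantities as the original, here is a small numpy-only sketch. It assumes a diagonally dominant C_new (an addition of mine, so that the matrix is well conditioned and the comparison is numerically meaningful), which differs from the purely random matrix in the question.

```python
import numpy as np

rng = np.random.default_rng(0)

# Assumption for illustration: adding a large diagonal keeps C_new well conditioned.
C_new = rng.random((198, 198)) + 198 * np.eye(198)
Mi_new = rng.random((198, 8))

# Original approach: explicit inverse followed by a matrix product.
Wg_inv = np.dot(np.linalg.inv(C_new), Mi_new)
# Rewritten approach: solve the linear system directly.
Wg_solve = np.linalg.solve(C_new, Mi_new)

# Original approach: one dot product per column in a list comprehension.
Cg_loop = np.array([
    np.dot(-0.5 * Mi_new[:, m].conj().T, Wg_solve[:, m])
    for m in range(Mi_new.shape[1])
])
# Rewritten approach: elementwise product summed over rows.
Cg_vec = -0.5 * (Mi_new.conj() * Wg_solve).sum(0)

same_W = np.allclose(Wg_inv, Wg_solve)
same_Cg = np.allclose(Cg_loop, Cg_vec)
print("solve matches inv-based result:", same_W)
print("broadcast matches list comprehension:", same_Cg)
```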

Both functions are noticeably faster than before thanks to the improved implementation. JAX is still slower than numpy here, however. This is somewhat to be expected: for a function this simple, JAX and numpy both generate effectively the same short series of BLAS and LAPACK calls, executed on a CPU. There’s simply not much room for improvement over numpy’s reference implementation, and with such small arrays JAX’s overhead is apparent.
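When timing JAX outside of %timeit, keep in mind that JAX dispatches work asynchronously and that the first call to a jitted function includes compilation. The sketch below is one way to account for both, assuming a hypothetical helper time_ms and a trimmed-down function solve_jax that takes only the two arguments actually used. It is not the benchmark from the answer, and absolute numbers will vary by machine.

```python
import time

import numpy as np
import jax
import jax.numpy as jnp

@jax.jit
def solve_jax(C_new, Mi_new):
    # Trimmed-down variant of testFunction_JAX_v2 (illustrative name).
    Wg_new = jnp.linalg.solve(C_new, Mi_new)
    Cg_new = -0.5 * (Mi_new.conj() * Wg_new).sum(0)
    return Wg_new, Cg_new

def time_ms(fn, *args, repeats=100):
    """Hypothetical helper: mean wall-clock time per call, in milliseconds."""
    start = time.perf_counter()
    for _ in range(repeats):
        outputs = fn(*args)
        # Block until the asynchronously dispatched results are ready.
        for x in outputs:
            x.block_until_ready()
    return (time.perf_counter() - start) / repeats * 1e3

rng = np.random.default_rng(0)
C_new = jnp.asarray(rng.random((198, 198)))
Mi_new = jnp.asarray(rng.random((198, 8)))

# Warm-up call so that compilation time is excluded from the measurement.
Wg_new, Cg_new = solve_jax(C_new, Mi_new)
Wg_new.block_until_ready()
Cg_new.block_until_ready()

jax_ms = time_ms(solve_jax, C_new, Mi_new)
print(f"JAX (compiled, synchronized): {jax_ms:.3f} ms per call")
```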

Answered By – jakevdp

Answer Checked By – Cary Denson (BugsFixing Admin)
