[SOLVED] vmap ops.index_update in JAX

Issue

I have the following code, which uses a simple for loop. Is there a way to vmap it? Here is the original code:

import numpy as np 
import jax.numpy as jnp
import jax.scipy.signal as jscp
from scipy import signal
import jax

data = np.random.rand(192,334)

a = [1,-1.086740193996892,0.649914553946275,-0.124948974636730]
b = [0.054778173164082,0.164334519492245,0.164334519492245,0.054778173164082]
impulse = signal.lfilter(b, a, [1] + [0]*99) 
impulse_20 = impulse[:20]
impulse_20 = jnp.asarray(impulse_20)

@jax.jit
def filter_jax(y):
    for ind in range(0, len(y)):
        y = jax.ops.index_update(y, jax.ops.index[:, ind], jscp.convolve(impulse_20, y[:,ind])[:-19])
    return y

jnpData = jnp.asarray(data)

%timeit filter_jax(jnpData).block_until_ready()
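jax.ops.index_update was deprecated and later removed from JAX, so the loop above fails on current releases. The following is a minimal sketch of the same loop using the y.at[...].set(...) indexed-update syntax that replaced it. The name filter_jax_at and the smaller 48 x 80 input are illustrative choices, not from the question.

```python
import numpy as np
import jax
import jax.numpy as jnp
import jax.scipy.signal as jscp
from scipy import signal

# Same IIR coefficients as the question; keep the first 20 taps of the impulse response.
a = [1, -1.086740193996892, 0.649914553946275, -0.124948974636730]
b = [0.054778173164082, 0.164334519492245, 0.164334519492245, 0.054778173164082]
impulse_20 = jnp.asarray(signal.lfilter(b, a, [1] + [0] * 99)[:20])

@jax.jit
def filter_jax_at(y):
    # As in the question, the bound is len(y) (the number of rows), so only the
    # first y.shape[0] columns are filtered. The Python loop is unrolled when traced.
    for ind in range(len(y)):
        filtered = jscp.convolve(impulse_20, y[:, ind])[:-19]  # 'full' output trimmed to column length
        y = y.at[:, ind].set(filtered)  # functional update; replaces jax.ops.index_update
    return y

# Hypothetical smaller input (the question uses 192 x 334) to keep compile time short.
data = np.random.rand(48, 80)
out = filter_jax_at(jnp.asarray(data))
```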

And here is my attempt at using vmap:

def paraUpdate(y, ind):
    return jax.ops.index_update(y, jax.ops.index[:, ind], jscp.convolve(impulse_20, y[:,ind])[:-19])

@jax.jit
def filter_jax2(y):
  ranger = range(0, len(y))
  return jax.vmap(paraUpdate, y)(ranger)

But I receive the following error:

TypeError: vmap in_axes must be an int, None, or (nested) container
with those types as leaves, but got
Traced<ShapedArray(float32[192,334])>with<DynamicJaxprTrace(level=0/1)>.

I’m a little confused, since the elements of the range are ints, so I’m not sure what’s going on.

Ultimately, I’m trying to optimize this small piece of code to run as fast as possible.
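%timeit is an IPython magic. In a plain Python script, a comparable measurement can be taken with the standard timeit module. The following is a sketch under two assumptions: a warm-up call is made first so that tracing and compilation are excluded, and block_until_ready is used so that JAX's asynchronous dispatch does not hide the actual compute time. The helper bench and the stand-in function column_means are hypothetical. Substitute filter_jax or any of the other variants.

```python
import timeit
import numpy as np
import jax
import jax.numpy as jnp

def bench(fn, x, number=100):
    """Average seconds per call of a jitted fn, excluding compile time."""
    fn(x).block_until_ready()  # warm-up: triggers tracing and compilation once
    # block_until_ready waits for the asynchronously dispatched result to finish
    total = timeit.timeit(lambda: fn(x).block_until_ready(), number=number)
    return total / number

# Hypothetical stand-in workload; pass a filter function here instead.
@jax.jit
def column_means(y):
    return y.mean(axis=0)

x = jnp.asarray(np.random.rand(192, 334))
per_call = bench(column_means, x)
print(f"{per_call * 1e6:.1f} µs per call")
```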

Solution

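jax.vmap expresses computations in which a single operation is applied independently to each slice of an input along a mapped axis. Your function is a bit different: you have a single operation iteratively applied to a single input, with each step's result fed into the next. As for the error itself, the second positional argument of jax.vmap is in_axes, so jax.vmap(paraUpdate, y) passes the array y where an int, None, or a container of those is expected.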
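In this particular loop, however, the iterations happen to be independent. Iteration ind reads column ind before anything has written to it, and it writes only that column. A vmap formulation is therefore possible if the mapped function returns just the filtered column rather than a whole updated copy of y. Vmapping paraUpdate as written would instead produce one full copy of y per index. Below is a sketch with the hypothetical names filter_column and filter_jax_vmap and a smaller illustrative input. Here, in_axes=(None, 0) shares y across the batch and maps over the index array.

```python
import numpy as np
import jax
import jax.numpy as jnp
import jax.scipy.signal as jscp
from scipy import signal

a = [1, -1.086740193996892, 0.649914553946275, -0.124948974636730]
b = [0.054778173164082, 0.164334519492245, 0.164334519492245, 0.054778173164082]
impulse_20 = jnp.asarray(signal.lfilter(b, a, [1] + [0] * 99)[:20])

def filter_column(y, ind):
    # Return only the filtered column, not a whole updated copy of y.
    return jscp.convolve(impulse_20, y[:, ind])[:-19]

@jax.jit
def filter_jax_vmap(y):
    n = y.shape[0]  # same bound as the question's loop: the number of rows
    # in_axes=(None, 0): y is shared by every batch element; the index array is mapped.
    # out_axes=1 places each filtered column as a column of the result: shape (rows, n).
    cols = jax.vmap(filter_column, in_axes=(None, 0), out_axes=1)(y, jnp.arange(n))
    return y.at[:, :n].set(cols)

data = np.random.rand(48, 80)  # hypothetical smaller input
out = filter_jax_vmap(jnp.asarray(data))
```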

Fortunately, JAX provides lax.scan, which can handle this situation. The implementation would look something like this:

from jax import lax

def paraUpdate(y, ind):
    return jax.ops.index_update(y, jax.ops.index[:, ind], jscp.convolve(impulse_20, y[:,ind])[:-19]), ind

@jax.jit
def filter_jax2(y):
  ranger = jnp.arange(len(y))
  return lax.scan(paraUpdate, y, ranger)[0]

print(np.allclose(filter_jax(jnpData), filter_jax2(jnpData)))
# True

%timeit filter_jax(jnpData).block_until_ready()
# 10 loops, best of 3: 28.6 ms per loop

%timeit filter_jax2(jnpData).block_until_ready()
# 1000 loops, best of 3: 519 µs per loop
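lax.scan calls the step function as f(carry, x) and expects a (new_carry, per_step_output) pair back. It threads the carry through each element of xs and stacks the per-step outputs. On current JAX releases, where jax.ops.index_update is no longer available, the scan version can be written with .at[...].set(...). The sketch below also returns None as the per-step output, so scan does not stack the indices. The names update_column and filter_jax_scan and the smaller input are illustrative.

```python
import numpy as np
import jax
import jax.numpy as jnp
import jax.scipy.signal as jscp
from jax import lax
from scipy import signal

a = [1, -1.086740193996892, 0.649914553946275, -0.124948974636730]
b = [0.054778173164082, 0.164334519492245, 0.164334519492245, 0.054778173164082]
impulse_20 = jnp.asarray(signal.lfilter(b, a, [1] + [0] * 99)[:20])

def update_column(y, ind):
    # Scan step: the carry is the whole array; ind is the current element of xs.
    y = y.at[:, ind].set(jscp.convolve(impulse_20, y[:, ind])[:-19])
    return y, None  # nothing to stack per step

@jax.jit
def filter_jax_scan(y):
    final_y, _ = lax.scan(update_column, y, jnp.arange(y.shape[0]))
    return final_y

data = np.random.rand(48, 80)  # hypothetical smaller input
out = filter_jax_scan(jnp.asarray(data))
```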

If you change your algorithm so that you’re applying the operation to every column in the array, rather than only the first len(y) columns (len(y) is the number of rows, 192 here, out of 334 columns), it can be expressed with vmap like this:

@jax.jit
def filter_jax3(y):
  f = lambda col: jscp.convolve(impulse_20, col)[:-19]
  return jax.vmap(f, in_axes=1, out_axes=1)(y)
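Because filter_jax3 filters every column, its output matches filter_jax only in the first len(y) columns. The sketch below places both variants side by side. Slicing y[:, :n] before vmapping reproduces the question's loop exactly. The names filter_col and filter_jax3_first_n and the smaller input are illustrative.

```python
import numpy as np
import jax
import jax.numpy as jnp
import jax.scipy.signal as jscp
from scipy import signal

a = [1, -1.086740193996892, 0.649914553946275, -0.124948974636730]
b = [0.054778173164082, 0.164334519492245, 0.164334519492245, 0.054778173164082]
impulse_20 = jnp.asarray(signal.lfilter(b, a, [1] + [0] * 99)[:20])

def filter_col(col):
    return jscp.convolve(impulse_20, col)[:-19]

@jax.jit
def filter_jax3(y):
    # Every column is filtered independently: map over axis 1, stack results along axis 1.
    return jax.vmap(filter_col, in_axes=1, out_axes=1)(y)

@jax.jit
def filter_jax3_first_n(y):
    # Matches the question's loop: only the first len(y) columns change.
    n = y.shape[0]
    return y.at[:, :n].set(jax.vmap(filter_col, in_axes=1, out_axes=1)(y[:, :n]))

data = np.random.rand(48, 80)  # hypothetical smaller input
all_cols = filter_jax3(jnp.asarray(data))
first_n = filter_jax3_first_n(jnp.asarray(data))
```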

Answered By – jakevdp

Answer Checked By – Clifford M. (BugsFixing Volunteer)
