[SOLVED] Have float64 or float32 attribute in numba jitclass

Issue

How can I have a Numba jitclass with an attribute that can be either float64 or float32? With functions, the following code works:

import numba
import numpy as np
from numba import njit
from numba.experimental import jitclass


@njit()
def f(a):
    print(a.dtype)
    return a[0]


a = np.zeros(3)
f(a)
f(a.astype(np.float32))

whereas trying to use both float32 and float64 with a class attribute fails:

@jitclass([('arr', numba.types.float64[:])])
class MyClass():
    def __init__(self):
        pass

    def f(self, a):
        self.arr = a


myclass = MyClass()
myclass.f(np.zeros(3))
# The following line fails, because `arr` is declared as float64[:]:
myclass.f(np.zeros(3, dtype=np.float32))

Is there a workaround?

Solution

When you call MyClass(), Numba needs to instantiate the class. Because Numba only works with well-defined, strong types (this is what makes it fast and so useful), the fields of the class need to be typed before an object is instantiated. Thus, you cannot define the type of the MyClass fields when the method f is called, because this call is made by the CPython interpreter, which is dynamic. Note that a class usually has more than one method (otherwise it would not be very useful), which is why partial compilation is not really possible either.

One simple solution is to define two jitclass types from the same Python class, one per dtype:

import numba
import numpy as np
from numba.experimental import jitclass

class MyClass():
    def __init__(self):
        pass

    def f(self, a):
        self.arr = a

MyClass_float32 = jitclass([('arr', numba.types.float32[:])])(MyClass)
MyClass_float64 = jitclass([('arr', numba.types.float64[:])])(MyClass)

myclass = MyClass_float32()  # Compiles the float32 class on first use and creates an object
# The type of `self.arr` is already fixed to `float32[:]` here, before `f` is called.
myclass.f(np.zeros(3, dtype=np.float32))

myclass = MyClass_float64()
myclass.f(np.zeros(3, dtype=np.float64))

Answered By – Jérôme Richard

Answer Checked By – Senaida (BugsFixing Volunteer)
