[SOLVED] How could I speed up my Python code: sphere contact detection (collision) using spatial searching



I am working on a spatial search case for spheres in which I want to find connected spheres. To do so, I searched around each sphere for spheres whose centers are within a distance of the maximum sphere diameter from the searched sphere's center. At first, I tried to use scipy-related methods, but the scipy method takes longer than the equivalent numpy method. For scipy, I first determined the number K of nearest spheres and then found them with cKDTree.query, which led to more time consumption. However, it is slower than the numpy method even when the first step is omitted and a constant value is used instead (omitting the first step is not a good idea in this case). This is contrary to my expectations of scipy's spatial search speed. So I tried to replace some numpy lines with list loops so that I could speed them up with numba prange. Numba runs the code a little faster, but I believe this code can be optimized for better performance, perhaps by vectorization, by using other alternative numpy modules, or by using numba in another way. I iterate over all spheres to avoid probable out-of-memory problems and …, since the number of spheres is high.
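For comparison, here is a minimal scipy-only sketch, assuming the goal is the list of touching pairs and their overlaps. The SCIPY method in the code below builds a new cKDTree for every sphere, whereas `cKDTree.query_pairs` can return every pair of centres within `dia_max` from a single tree. The function name `contacts_query_pairs` and the random test data are illustrative and not part of the original code.

```python
import numpy as np
from scipy.spatial import cKDTree


def contacts_query_pairs(poss, radii):
    """Return (pairs, gaps) for every pair of overlapping spheres.

    One cKDTree is built for all centres; query_pairs then returns each
    candidate pair (i, j) with i < j and centre distance <= dia_max, and
    only those candidates are checked against r_i + r_j.
    """
    dia_max = 2.0 * radii.max()
    tree = cKDTree(poss)
    pairs = tree.query_pairs(r=dia_max, output_type='ndarray')
    pairs = pairs.reshape(-1, 2).astype(np.int64)  # also handles "no pairs"
    i, j = pairs[:, 0], pairs[:, 1]
    dist = np.linalg.norm(poss[i] - poss[j], axis=1)
    gaps = dist - (radii[i] + radii[j])  # <= 0 means the spheres touch
    touching = gaps <= 0.0
    return pairs[touching], gaps[touching]


rnd = np.random.RandomState(70)
radii = rnd.uniform(0.0005, 0.1, 500)
poss = rnd.uniform(-1.0, 1.0, (500, 3))
pairs, gaps = contacts_query_pairs(poss, radii)
```

The work is proportional to the number of candidate pairs rather than to n per sphere, which is why rebuilding the tree inside the loop is the expensive part of the scipy path.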

import numpy as np
import numba as nb
from scipy.spatial import cKDTree, distance

# ---------------------------- input data ----------------------------
""" For testing by prepared files:
radii = np.load('a.npy')     # shape: (n-spheres, )     must be loaded by np.load('a.npy') or np.loadtxt('radii_large.csv')
poss = np.load('b.npy')      # shape: (n-spheres, 3)    must be loaded by np.load('b.npy') or np.loadtxt('pos_large.csv', delimiter=',')
"""

rnd = np.random.RandomState(70)
data_volume = 200000

radii = rnd.uniform(0.0005, 0.122, data_volume)
dia_max = 2 * radii.max()

x = rnd.uniform(-1.02, 1.02, (data_volume, 1))
y = rnd.uniform(-3.52, 3.52, (data_volume, 1))
z = rnd.uniform(-1.02, -0.575, (data_volume, 1))
poss = np.hstack((x, y, z))
# --------------------------------------------------------------------

# @nb.jit('float64[:,::1](float64[:,::1], float64[::1])', forceobj=True, parallel=True)
def ends_gap(poss, dia_max):
    particle_corsp_overlaps = np.array([], dtype=np.float64)
    ends_ind = np.empty([1, 2], dtype=np.int64)
    """ using list looping """
    # particle_corsp_overlaps = []
    # ends_ind = []

    # for particle_idx in nb.prange(len(poss)):  # by list looping
    for particle_idx in range(len(poss)):
        unshared_idx = np.delete(np.arange(len(poss)), particle_idx)                                                    # <--- relatively high time consumer
        poss_without = poss[unshared_idx]

        """ # SCIPY method ---------------------------------------------------------------------------------------------
        nears_i_ind = cKDTree(poss_without).query_ball_point(poss[particle_idx], r=dia_max)         # <--- high time consumer
        if len(nears_i_ind) > 0:
            dist_i, dist_i_ind = cKDTree(poss_without[nears_i_ind]).query(poss[particle_idx], k=len(nears_i_ind))       # <--- high time consumer
            if not isinstance(dist_i, float):
                dist_i[dist_i_ind] = dist_i.copy()
        """  # NUMPY method --------------------------------------------------------------------------------------------
        lx_limit_idx = poss_without[:, 0] <= poss[particle_idx][0] + dia_max
        ux_limit_idx = poss_without[:, 0] >= poss[particle_idx][0] - dia_max
        ly_limit_idx = poss_without[:, 1] <= poss[particle_idx][1] + dia_max
        uy_limit_idx = poss_without[:, 1] >= poss[particle_idx][1] - dia_max
        lz_limit_idx = poss_without[:, 2] <= poss[particle_idx][2] + dia_max
        uz_limit_idx = poss_without[:, 2] >= poss[particle_idx][2] - dia_max

        nears_i_ind = np.where(lx_limit_idx & ux_limit_idx & ly_limit_idx & uy_limit_idx & lz_limit_idx & uz_limit_idx)[0]
        if len(nears_i_ind) > 0:
            dist_i = distance.cdist(poss_without[nears_i_ind], poss[particle_idx][None, :]).squeeze()                   # <--- relatively high time consumer
        # """  # -------------------------------------------------------------------------------------------------------
            contact_check = dist_i - (radii[unshared_idx][nears_i_ind] + radii[particle_idx])
            connected = contact_check[contact_check <= 0]

            particle_corsp_overlaps = np.concatenate((particle_corsp_overlaps, connected))
            """ using list looping """
            # if len(connected) > 0:
            #    for value_ in connected:
            #        particle_corsp_overlaps.append(value_)

            contacts_ind = np.where([contact_check <= 0])[1]
            contacts_sec_ind = np.array(nears_i_ind)[contacts_ind]
            sphere_olps_ind = np.where((poss[:, None] == poss_without[contacts_sec_ind][None, :]).all(axis=2))[0]       # <--- high time consumer

            ends_ind_mod_temp = np.array([np.repeat(particle_idx, len(sphere_olps_ind)), sphere_olps_ind], dtype=np.int64).T
            ends_ind = np.concatenate((ends_ind, ends_ind_mod_temp))
            """ using list looping """
            # for contacted_idx in sphere_olps_ind:
            #    ends_ind.append([particle_idx, contacted_idx])

    # ends_ind_org = np.array(ends_ind)  # using lists
    ends_ind_org = ends_ind[1:]  # skip the uninitialized first row created by np.empty([1, 2])
    ends_ind, ends_ind_idx = np.unique(np.sort(ends_ind_org), axis=0, return_index=True)                                # <--- relatively high time consumer
    gap = np.array(particle_corsp_overlaps)[ends_ind_idx]
    return gap, ends_ind, ends_ind_idx, ends_ind_org
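In the loop above, `poss_without` is `poss[unshared_idx]` with `unshared_idx = np.delete(np.arange(len(poss)), particle_idx)`, and the line that builds `sphere_olps_ind` (marked as a high time consumer) recovers the original indices by comparing coordinates against all of `poss`. Since deleting one entry shifts every later index down by one, the same indices can be obtained arithmetically (equivalently, `unshared_idx[contacts_sec_ind]`), assuming no two spheres share identical coordinates (the coordinate comparison would report both of them). `original_index` is a hypothetical helper name.

```python
import numpy as np


def original_index(k, removed):
    """Map positions in np.delete(np.arange(n), removed) back to 0..n-1.

    Every position at or after the removed one was shifted left by one, so
    adding 1 to those positions recovers the original index. This replaces
    an (n x k x 3) coordinate comparison with O(k) arithmetic.
    """
    k = np.asarray(k, dtype=np.int64)
    return k + (k >= removed)


n, removed = 10, 4
unshared_idx = np.delete(np.arange(n), removed)
k = np.array([0, 3, 4, 8])
print(original_index(k, removed))  # [0 3 5 9]
```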

In one of my tests on 23,000 spheres, the scipy, numpy, and numba-aided methods finished the loop in about 400, 200, and 180 seconds, respectively, using a Colab TPU; for 500,000 spheres it takes 3.5 hours. These execution times are not satisfactory at all for my project, where the number of spheres may be up to 1,000,000 for a medium data volume. I will call this code many times in my main code and am looking for ways to run it in milliseconds (as fast as possible). Is this possible?
I would appreciate it if anyone could speed up the code as needed.


  • This code must be executable with python 3.7+, on CPU and GPU.
  • This code must be applicable to data sizes of at least 300,000 spheres.
  • All numpy, scipy, and … equivalent modules that replace my hand-written code and make it significantly faster will be upvoted.

I would appreciate any recommendations or explanations about:

  1. Which method could be faster for this problem?
  2. Why is scipy not faster than the other methods in this case, and where could it be helpful for this problem?
  3. Choosing between iterative methods and matrix-form methods is confusing for me. Iterative methods use less memory and can be tuned with numba and …, but I think they are not comparable with matrix methods like numpy and … (which depend on memory limits) for huge numbers of spheres. In this case, perhaps I could drop the iteration by using numpy, but I strongly suspect that it cannot be handled due to the huge matrix operations and out-of-memory problems.
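Regarding question 3, a middle ground between per-sphere iteration and a full n × n matrix is to process blocks of rows: each block is vectorized, but the temporary arrays never exceed `chunk × n`. This is a minimal sketch; `contacts_chunked` and `chunk` are hypothetical names. It is still O(n²) work, so it bounds memory but not time, which is why a spatial tree is still needed for hundreds of thousands of spheres.

```python
import numpy as np


def contacts_chunked(poss, radii, chunk=256):
    """Exact brute-force contact search, processed one block of rows at a time.

    Each block compares spheres start..stop-1 with spheres start..n-1, so the
    temporary arrays have shape (chunk, n - start, 3) instead of (n, n, 3).
    """
    n = len(poss)
    found_i, found_j, found_gap = [], [], []
    for start in range(0, n, chunk):
        stop = min(start + chunk, n)
        diff = poss[start:stop, None, :] - poss[None, start:, :]
        dist = np.sqrt(np.einsum('ijk,ijk->ij', diff, diff))
        gap = dist - (radii[start:stop, None] + radii[None, start:])
        rows, cols = np.nonzero(gap <= 0.0)
        keep = rows < cols  # both local indices are offset by `start`
        rows, cols = rows[keep], cols[keep]
        found_i.append(rows + start)
        found_j.append(cols + start)
        found_gap.append(gap[rows, cols])
    if not found_i:
        return np.empty((0, 2), dtype=np.int64), np.empty(0)
    pairs = np.column_stack((np.concatenate(found_i), np.concatenate(found_j)))
    return pairs.astype(np.int64), np.concatenate(found_gap)


rnd = np.random.RandomState(70)
radii = rnd.uniform(0.0005, 0.1, 600)
poss = rnd.uniform(-1.0, 1.0, (600, 3))
pairs, gaps = contacts_chunked(poss, radii, chunk=128)
```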

Prepared sample test data:

Poss data: 23000, 500000
Radii data: 23000, 500000
Line-by-line speed test logs: time consumption of the scipy and numpy methods for the two test cases.


Based on previous answers, I designed an efficient algorithm with a much lower memory footprint that is much faster than the previous ones (especially on the large dataset). That being said, this algorithm is far more complex and pushes the limits of Python and Numba.

The key issue of the previous algorithms is that they set a dia_max threshold that is much bigger than actually required. Indeed, dia_max is set to the maximum possible diameter so as to be sure not to miss any overlap. The thing is, the big dataset contains balls of very different sizes, and some of them are huge. This means that the previous algorithms were searching within a very large radius around many small balls. The result was thousands of neighbours to check per ball, while only a few can truly overlap.
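A minimal sketch of how much a per-ball bound can save, using SciPy's `cKDTree` in place of scikit-learn's `BallTree` (an assumption made only for this illustration). Ball j can touch ball i only if the centre distance is at most r_i + r_j ≤ r_i + r_max, so each ball may search within `radii + r_max` instead of the uniform `2 * r_max`. The data mimics the question's random generator with fewer balls; the size of the reduction depends on the radius distribution.

```python
import numpy as np
from scipy.spatial import cKDTree

rnd = np.random.RandomState(70)
n = 5000
radii = rnd.uniform(0.0005, 0.122, n)
poss = np.column_stack((rnd.uniform(-1.02, 1.02, n),
                        rnd.uniform(-3.52, 3.52, n),
                        rnd.uniform(-1.02, -0.575, n)))

tree = cKDTree(poss)
r_max = radii.max()
# Uniform radius: every ball searches within dia_max = 2 * r_max.
uniform = tree.query_ball_point(poss, r=2.0 * r_max, return_length=True)
# Per-ball radius: still safe, since r_j <= r_max for every ball j.
per_ball = tree.query_ball_point(poss, r=radii + r_max, return_length=True)
print(uniform.sum(), per_ball.sum())  # candidate counts (self included)
```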

One solution to efficiently address this problem is to split the balls into different groups based on their size. The idea is to first sort the balls by radius, then split the sorted balls into two groups, then independently query neighbours between each possible pair of groups, and finally merge the data so as to apply the previous algorithm (with some additional optimizations). More specifically, the query is applied between small balls and big ones, small balls and other small ones, big balls and other big ones, and big balls and small ones. Since no ball in a group is bigger than the group's largest ball, searching around ball i within r_i plus the largest radius of the queried group cannot miss an overlap.
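A sequential sketch of the small/big split, assuming SciPy's `cKDTree` in place of scikit-learn's `BallTree` and returning plain (i, j) pairs instead of the answer's segmented size/value arrays. `grouped_contacts` and `split_frac` are hypothetical names; the 3/4 split mirrors `split_ind = len(radii) * 3 // 4` in the code further down.

```python
import itertools

import numpy as np
from scipy.spatial import cKDTree


def grouped_contacts(poss, radii, split_frac=0.75):
    """Contact search with the balls split into a 'small' and a 'big' group.

    A ball i of group A is searched in group B within r_i + max(radii[B]).
    Every ball j of B has r_j <= max(radii[B]), so no overlap can be missed,
    and small balls only pay for a large radius when searching the big group.
    """
    order = np.argsort(radii)
    split = int(len(radii) * split_frac)
    groups = [order[:split], order[split:]]  # small balls, big balls
    rows_all, cols_all = [], []
    for a in groups:          # searched group
        for b in groups:      # reference group
            if len(a) == 0 or len(b) == 0:
                continue
            tree = cKDTree(poss[b])
            hits = tree.query_ball_point(poss[a], r=radii[a] + radii[b].max())
            sizes = np.fromiter(map(len, hits), dtype=np.int64, count=len(hits))
            local = np.fromiter(itertools.chain.from_iterable(hits),
                                dtype=np.int64, count=int(sizes.sum()))
            rows_all.append(np.repeat(a, sizes))  # global index of the searched ball
            cols_all.append(b[local])             # global index of the neighbour
    rows = np.concatenate(rows_all)
    cols = np.concatenate(cols_all)
    # Touching pairs are found from both sides; keeping i < j keeps each of
    # them once and drops the self-matches.
    keep = rows < cols
    rows, cols = rows[keep], cols[keep]
    gaps = np.linalg.norm(poss[rows] - poss[cols], axis=1) - (radii[rows] + radii[cols])
    touching = gaps <= 0.0
    return np.column_stack((rows[touching], cols[touching])), gaps[touching]


rnd = np.random.RandomState(3)
n = 800
is_big = rnd.rand(n) < 0.05
radii = np.where(is_big, rnd.uniform(0.1, 0.3, n), rnd.uniform(0.001, 0.03, n))
poss = rnd.uniform(-1.0, 1.0, (n, 3))
pairs, gaps = grouped_contacts(poss, radii)
```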

Another key point to speed this up is to run the different neighbour queries in parallel using joblib. This solution is far from perfect, since the BallTree object needs to be duplicated, which is inefficient, but this is mandatory because of the way parallelism is currently done in CPython (i.e. GIL, pickling, etc.). Using a package that supports parallel requests could bypass this inherent limitation of CPython, but the existing packages that do so do not seem to provide an interface sufficiently useful for this problem, or are not optimized enough to be actually useful.
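A minimal sketch of the same fan-out of four independent queries, using the standard library's `concurrent.futures.ThreadPoolExecutor` in place of joblib (a swapped-in technique) and SciPy's `cKDTree` in place of scikit-learn's `BallTree`. `count_neighbours`, the radii and the point sets are illustrative. As in the answer, each task builds its own tree; whether the threads actually overlap depends on the query releasing the GIL, which this sketch does not measure.

```python
from concurrent.futures import ThreadPoolExecutor

import numpy as np
from scipy.spatial import cKDTree


def count_neighbours(searched_pts, ref_pts, max_dist):
    # Each task builds its own tree, like the duplicated BallTree in the answer
    tree = cKDTree(ref_pts)
    return tree.query_ball_point(searched_pts, r=max_dist, return_length=True)


rnd = np.random.RandomState(0)
small = rnd.uniform(0.0, 1.0, (3000, 3))
big = rnd.uniform(0.0, 1.0, (1000, 3))
tasks = [(small, small, 0.05), (small, big, 0.08),
         (big, small, 0.08), (big, big, 0.1)]

with ThreadPoolExecutor(max_workers=4) as pool:
    # pool.map preserves the task order, so results[k] belongs to tasks[k]
    results = list(pool.map(lambda t: count_neighbours(*t), tasks))
```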

Finally, the Numba code can be strongly optimized by removing almost all the very expensive (implicit) array allocations. Using an in-place sorting algorithm optimized for small arrays also significantly improves the execution time (mainly because the default Numba implementation performs several expensive allocations and is not optimized for small arrays). In addition, the final np.unique operation can be completely rewritten with a basic loop, since the main loop iterates over balls with increasing IDs (hence already sorted).
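The np.unique replacement can be checked on a tiny example. This NumPy sketch applies the same `left < right` test as the final loop of `compute` below, assuming the directed contact rows are ordered by the first index and, for each ball, by the sorted neighbour index (which appears to be why the neighbour indices are sorted). Under that ordering, the first copy of each contact {i, j} is always the row (i, j) with i < j, so a single pass gives the same pairs and indices as `np.unique(np.sort(...), return_index=True)`. The arrays are made-up example data.

```python
import numpy as np

# Directed contacts as produced by the per-ball loop: each contact {i, j}
# appears twice, as (i, j) and as (j, i), with rows ordered by ball index.
ends_ind_org = np.array([[0, 2], [0, 5], [2, 0], [3, 4], [4, 3], [5, 0]])
overlaps = np.array([-0.1, -0.2, -0.1, -0.3, -0.3, -0.2])

# What the question does: sort each row, then keep the first copy of each pair
ref_pairs, ref_idx = np.unique(np.sort(ends_ind_org), axis=0, return_index=True)

# Single pass: keep only the orientation with left < right
mask = ends_ind_org[:, 0] < ends_ind_org[:, 1]
pairs = ends_ind_org[mask]
idx = np.flatnonzero(mask)
gap = overlaps[idx]
print(pairs.tolist(), idx.tolist())  # [[0, 2], [0, 5], [3, 4]] [0, 1, 3]
```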

Here is the resulting code:

import numpy as np
import numba as nb
from sklearn.neighbors import BallTree
from joblib import Parallel, delayed

def flatten_neighbours(arr):
    sizes = np.fromiter(map(len, arr), count=len(arr), dtype=np.int64)
    values = np.concatenate(arr, dtype=np.int64)
    return sizes, values

def find_neighbours(searched_pts, ref_pts, max_dist):
    balltree = BallTree(ref_pts, leaf_size=16, metric='euclidean')
    res = balltree.query_radius(searched_pts, r=max_dist)
    return flatten_neighbours(res)

def vstack_neighbours(top_infos, bottom_infos):
    top_sizes, top_values = top_infos
    bottom_sizes, bottom_values = bottom_infos
    return np.concatenate([top_sizes, bottom_sizes]), np.concatenate([top_values, bottom_values])

@nb.njit('(Tuple([int64[::1],int64[::1]]), Tuple([int64[::1],int64[::1]]), int64)')
def hstack_neighbours(left_infos, right_infos, offset):
    left_sizes, left_values = left_infos
    right_sizes, right_values = right_infos
    n = left_sizes.size
    out_sizes = np.empty(n, dtype=np.int64)
    out_values = np.empty(left_values.size + right_values.size, dtype=np.int64)
    left_cur, right_cur, out_cur = 0, 0, 0
    right_values += offset
    for i in range(n):
        left, right = left_sizes[i], right_sizes[i]
        full = left + right
        out_values[out_cur:out_cur+left] = left_values[left_cur:left_cur+left]
        out_values[out_cur+left:out_cur+full] = right_values[right_cur:right_cur+right]
        out_sizes[i] = full
        left_cur += left
        right_cur += right
        out_cur += full
    return out_sizes, out_values

@nb.njit('(int64[::1], int64[::1], int64[::1], int64[::1])')
def reorder_neighbours(in_sizes, in_values, index, reverse_index):
    n = reverse_index.size
    out_sizes = np.empty_like(in_sizes)
    out_values = np.empty_like(in_values)
    in_offsets = np.empty_like(in_sizes)
    s, cur = 0, 0

    for i in range(n):
        in_offsets[i] = s
        s += in_sizes[i]

    for i in range(n):
        in_ind = reverse_index[i]
        size = in_sizes[in_ind]
        in_offset = in_offsets[in_ind]
        out_sizes[i] = size
        for j in range(size):
            out_values[cur+j] = index[in_values[in_offset+j]]
        cur += size

    return out_sizes, out_values

@nb.njit('(int64[::1],)')
def small_inplace_sort(arr):
    if len(arr) < 80:
        # Basic insertion sort
        i = 1
        while i < len(arr):
            x = arr[i]
            j = i - 1
            while j >= 0 and arr[j] > x:
                arr[j+1] = arr[j]
                j = j - 1
            arr[j+1] = x
            i += 1
    else:
        # Bigger arrays: fall back on the default in-place sort
        arr.sort()

@nb.njit('(float64[:, ::1], float64[::1], int64[::1], int64[::1])')
def compute(poss, radii, neighbours_sizes, neighbours_values):
    n, m = neighbours_sizes.size, np.max(neighbours_sizes)

    # Big buffers allocated with the maximum size.
    # Thanks to virtual memory, they do not take more memory than actually needed.
    particle_corsp_overlaps = np.empty(neighbours_values.size, dtype=np.float64)
    ends_ind_org = np.empty((neighbours_values.size, 2), dtype=np.int64)  # ball indices

    in_offset = 0
    out_offset = 0

    buff1 = np.empty(m, dtype=np.int64)
    buff2 = np.empty(m, dtype=np.float64)
    buff3 = np.empty(m, dtype=np.int64)  # ball indices of the contacts

    for particle_idx in range(n):
        size = neighbours_sizes[particle_idx]
        cur = 0

        for i in range(size):
            value = neighbours_values[in_offset+i]
            if value != particle_idx:
                buff1[cur] = value
                cur += 1

        nears_i_ind = buff1[0:cur]
        small_inplace_sort(nears_i_ind)  # Note: bottleneck of this function
        in_offset += size

        if len(nears_i_ind) == 0:
            continue

        x1, y1, z1 = poss[particle_idx]
        cur = 0

        for i in range(len(nears_i_ind)):
            index = nears_i_ind[i]
            x2, y2, z2 = poss[index]
            dist = np.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2 + (z2 - z1) ** 2)
            contact_check = dist - (radii[index] + radii[particle_idx])
            if contact_check <= 0.0:
                buff2[cur] = contact_check
                buff3[cur] = index
                cur += 1

        particle_corsp_overlaps[out_offset:out_offset+cur] = buff2[0:cur]

        contacts_sec_ind = buff3[0:cur]
        sphere_olps_ind = contacts_sec_ind

        for i in range(cur):
            ends_ind_org[out_offset+i, 0] = particle_idx
            ends_ind_org[out_offset+i, 1] = sphere_olps_ind[i]

        out_offset += cur

    # Truncate the views to their real size
    particle_corsp_overlaps = particle_corsp_overlaps[:out_offset]
    ends_ind_org = ends_ind_org[:out_offset]

    assert len(ends_ind_org) % 2 == 0
    size = len(ends_ind_org)//2
    ends_ind = np.empty((size,2), dtype=np.int64)
    ends_ind_idx = np.empty(size, dtype=np.int64)
    gap = np.empty(size, dtype=np.float64)
    cur = 0

    # Find efficiently duplicates (replace np.unique+np.sort)
    for i in range(len(ends_ind_org)):
        left, right = ends_ind_org[i]
        if left < right:
            ends_ind[cur, 0] = left
            ends_ind[cur, 1] = right
            ends_ind_idx[cur] = i
            gap[cur] = particle_corsp_overlaps[i]
            cur += 1

    return gap, ends_ind, ends_ind_idx, ends_ind_org

def ends_gap(poss, radii):
    assert poss.size >= 1

    # Sort the balls
    index = np.argsort(radii)
    reverse_index = np.empty(index.size, np.int64)
    reverse_index[index] = np.arange(index.size, dtype=np.int64)
    sorted_poss = poss[index]
    sorted_radii = radii[index]

    # Split them in two groups: the small and the big ones
    split_ind = len(radii) * 3 // 4
    small_poss, big_poss = np.split(sorted_poss, [split_ind])
    small_radii, big_radii = np.split(sorted_radii, [split_ind])
    max_small_radii = sorted_radii[max(split_ind, 0)]
    max_big_radii = sorted_radii[-1]

    # Find the neighbours in parallel
    result = Parallel(n_jobs=4, backend='threading')([
        delayed(find_neighbours)(small_poss, small_poss, small_radii+max_small_radii),
        delayed(find_neighbours)(small_poss, big_poss,   small_radii+max_big_radii  ),
        delayed(find_neighbours)(big_poss,   small_poss, big_radii+max_small_radii  ),
        delayed(find_neighbours)(big_poss,   big_poss,   big_radii+max_big_radii    )
    ])
    small_small_neighbours = result[0]
    small_big_neighbours = result[1]
    big_small_neighbours = result[2]
    big_big_neighbours = result[3]

    # Merge the (segmented) arrays in a big one
    neighbours_sizes, neighbours_values = vstack_neighbours(
        hstack_neighbours(small_small_neighbours, small_big_neighbours, split_ind),
        hstack_neighbours(big_small_neighbours, big_big_neighbours, split_ind)
    )

    # Reverse the indices.
    # Note that the results in `neighbours_values` associated to 
    # `neighbours_sizes[i]` are subsets of `query_radius([poss[i]], r=dia_max)`
    # on a `BallTree(poss)`.
    res = reorder_neighbours(neighbours_sizes, neighbours_values, index, reverse_index)
    neighbours_sizes, neighbours_values = res

    # Finally compute the neighbours with a method similar to the 
    # previous one, but using a much faster optimized code.
    return compute(poss, radii, neighbours_sizes, neighbours_values)

result = ends_gap(poss, radii)

Here are the results (still on the same i5-9600KF machine):

Small dataset:
 - Reference optimized Numba code:    256 ms
 - This highly-optimized Numba code:   82 ms

Big dataset:
 - Reference optimized Numba code:    42.7 s  (take about 7~8 GiB of RAM)
 - This highly-optimized Numba code:   4.2 s  (take about  1  GiB of RAM)

Thus the new algorithm is about 3.1 times faster on the small dataset (in addition to the previous optimizations), and about 10 times faster on the big dataset! This is three orders of magnitude faster than the initially posted algorithms.

Note that 80% of the time is spent in the BallTree query (which is already mostly parallel). The main Numba computing function takes only 12% of the time, and more than 75% of that time is spent sorting the input indices. As a result, the neighbourhood search is clearly the bottleneck. It could be improved a bit by splitting the current queries into multiple smaller ones, but this would make the code even more complex for a relatively small improvement (e.g. 1.5x faster). Note that more complex code is harder to maintain and modifications are bug-prone. Thus, I think moving to a native language to overcome the limitations of Python is the best solution to increase performance. That being said, writing faster native code for this problem is far from simple (unless you find a good k-d tree, octree or ball tree library). Still, it is certainly better than optimizing this code further.


A profiling analysis shows that at least 50% of the time in scikit-learn's BallTree is spent in unoptimized scalar loops that could use SIMD instructions like AVX-2 (and loop unrolling) to be about 4 times faster. Additionally, some multi-threading issues are also visible (the 4 threads at the top are the joblib workers; the light-green sections are idle time):


This shows that this implementation is sub-optimal. One possible way to easily improve the execution time may be to optimize the hot loops of the scikit-learn BallTree implementation. Another strategy could be to try to use threads more efficiently (possibly by releasing the GIL in some parts of the scikit-learn module).

The BallTree class of scikit-learn is written in Cython (like KDTree, it is built on the BinaryTree class), so you can try to rebuild the package on your machine and simply tweak the compiler optimizations. Using the flags -O3 -march=native -ffast-math should enable the compiler to use faster SIMD instructions and more aggressive optimizations, resulting in a significant speed-up. Note that using -ffast-math is unsafe, as it assumes the scikit-learn code never uses NaN, Inf or -0 values (otherwise the result is completely undefined) and that floating-point operations are associative (which can produce different results). That being said, such an option is critical to improve the automatic vectorization of numerical code.

As for the GIL, one can see that it is released in the query_radius function, but this does not seem to be the case for the BallTree constructor. Maybe the simplest solution is to implement a parallel version of query/query_radius, as SciPy did.
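For reference, SciPy's `cKDTree` exposes its parallel queries through a `workers` argument (SciPy 1.6 and later; older releases called it `n_jobs`). This minimal sketch only shows the call on arbitrary random data with an arbitrary radius; it is not a drop-in replacement for the BallTree-based pipeline above.

```python
import numpy as np
from scipy.spatial import cKDTree

rnd = np.random.RandomState(0)
poss = rnd.uniform(-1.0, 1.0, (20000, 3))
tree = cKDTree(poss)

# workers=-1 spreads the queries over all CPU cores; workers=1 is sequential.
counts_par = tree.query_ball_point(poss, r=0.05, workers=-1, return_length=True)
counts_seq = tree.query_ball_point(poss, r=0.05, workers=1, return_length=True)
```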

Answered By – Jérôme Richard

Answer Checked By – Katrina (BugsFixing Volunteer)
