[SOLVED] Select random element from a set, faster than linear time (Haskell)

Issue

I’d like to create this function, which selects a random element from a Set:

randElem :: (RandomGen g) => Set a -> g -> (a, g)

Simple listy implementations can be written. For example (code updated, verified working):

import Data.Set as Set
import System.Random (getStdGen, randomR, RandomGen)

randElem :: (RandomGen g) => Set a -> g -> (a, g)
randElem s g = (Set.toList s !! n, g')
    where (n, g') = randomR (0, Set.size s - 1) g

-- simple test drive
main = do g <- getStdGen
          print . fst $ randElem s g
    where s = Set.fromList [1,3,5,7,9]

But using !! incurs a linear lookup cost for large (randomly selected) n. Is there a faster way to select a random element in a Set? Ideally, repeated random selections should produce a uniform distribution over all options, meaning it does not prefer some elements over others.


Edit: some great ideas are popping up in the answers, so I just wanted to throw a couple more clarifications on what exactly I’m looking for. I asked this question with Sets as the solution to this situation in mind. I’ll prefer answers that both

  1. avoid using any outside-the-function bookkeeping beyond the Set’s internals, and
  2. maintain good performance (better than O(n) on average) even though the function is only used once per unique set.

I also have this love of working code, so expect (at minimum) a +1 from me if your answer includes a working solution.

Solution

Here’s an idea: You could do interval bisection.

  1. Set.size s is constant time. Use randomR to pick how far into the set the selected element lies.
  2. Do split with values between the current findMin and findMax until you reach the element at that position (splitMember is illustrated below). If you really fear that the set is made up of, say, reals and is extremely tightly clustered, you can recompute findMin and findMax on each step to guarantee knocking off some elements each time.
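
For reference, the split step relies on Data.Set.split / Data.Set.splitMember, which partition a set around a pivot in O(log n). A minimal illustration (the values here are chosen only as an example):

import qualified Data.Set as Set

-- Illustration only: splitMember pivots a set into
-- (elements < pivot, is the pivot present?, elements > pivot), each call O(log n).
main :: IO ()
main = do
    let s = Set.fromList [1, 3, 5, 7, 9] :: Set.Set Int
    print (Set.splitMember 5 s)  -- (fromList [1,3],True,fromList [7,9])
    print (Set.splitMember 4 s)  -- (fromList [1,3],False,fromList [5,7,9])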

The worst-case performance would be O(n log n), essentially no worse (up to a log factor) than your current solution, but under only rather weak conditions, namely that the set not be entirely clustered around some accumulation point, the average performance should be roughly O((log n)^2), which is close to constant. If it is a set of integers, you get O(log n * log m), where m is the initial range of the set; it is only reals (or other data types whose order type has accumulation points) that could cause really nasty performance in an interval bisection.

PS. This produces a perfectly even distribution, as long as you watch for off-by-ones and make sure it is possible to select the elements at both the very top and the very bottom.

Edit: added ‘code’

Here is the code for the bisection; watch for off-by-ones. One thing: check out how mid is generated, since it will need some tweaking depending on whether you want it to work with sets of ints or of reals (interval bisection is inherently topological, and oughtn't to work quite the same for sets with different order types); an integer variant is sketched after the code.

import Data.Set as Set
import System.Random (randomR, RandomGen)

-- Return the element at 0-based index n (in ascending order) via interval bisection.
-- mid is computed with (/), so this version expects a Fractional element type;
-- for a set of integers, compute mid with `div` instead (see the sketch below).
getNth :: (Ord a, Fractional a) => (Set a, Int) -> a
getNth (s, n)
    | n == 0                      = Set.findMin s
    | n + 1 == Set.size s         = Set.findMax s
    | n < Set.size bott           = getNth (bott, n)
    | pres && Set.size bott == n  = mid
    | pres                        = getNth (top, n - Set.size bott - 1)
    | otherwise                   = getNth (top, n - Set.size bott)
    where mid               = (Set.findMax s - Set.findMin s) / 2 + Set.findMin s
          (bott, pres, top) = Set.splitMember mid s

randElem :: (Ord a, Fractional a, RandomGen g) => Set a -> g -> (a, g)
randElem s g = (getNth (s, n), g')
    where (n, g') = randomR (0, Set.size s - 1) g
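
The answer notes that mid needs tweaking for sets of ints. Purely as a sketch (not part of the original answer), here is an integer variant that computes the midpoint with `div`, under the hypothetical names getNthInt and randElemInt, together with a rough tally of 10000 draws to eyeball the even distribution mentioned in the PS above:

import qualified Data.Map as Map
import qualified Data.Set as Set
import System.Random (getStdGen, randomR, RandomGen)

-- Integer variant of getNth: same bisection, but the midpoint uses `div`.
getNthInt :: Set.Set Int -> Int -> Int
getNthInt s n
    | n == 0                      = Set.findMin s
    | n + 1 == Set.size s         = Set.findMax s
    | n < Set.size bott           = getNthInt bott n
    | pres && Set.size bott == n  = mid
    | pres                        = getNthInt top (n - Set.size bott - 1)
    | otherwise                   = getNthInt top (n - Set.size bott)
  where
    mid               = Set.findMin s + (Set.findMax s - Set.findMin s) `div` 2
    (bott, pres, top) = Set.splitMember mid s

randElemInt :: RandomGen g => Set.Set Int -> g -> (Int, g)
randElemInt s g = (getNthInt s n, g')
  where (n, g') = randomR (0, Set.size s - 1) g

-- Rough sanity check of the distribution: tally 10000 draws and eyeball the counts.
main :: IO ()
main = do
    g0 <- getStdGen
    let s = Set.fromList [1, 3, 5, 7, 9]
        step (acc, g) _ = let (x, g') = randElemInt s g
                          in (Map.insertWith (+) x (1 :: Int) acc, g')
        (counts, _) = foldl step (Map.empty, g0) [1 .. 10000 :: Int]
    print counts  -- each of 1,3,5,7,9 should appear roughly 2000 times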

Answered By – Nicholas Wilson

