Stable Sampling Trees


TLDR: Sampling dynamic categorical distributions robustly, really fast.

Sampling trees, how to implement them, and the perils of using floats.

Suppose we want to sample from a set of 8 objects $X=\{x_0,\dots,x_7\}$ proportional to the following weights:

We could do this using the following steps: 1) build a prefix sum (cumulative sum) of the weights, 2) generate a uniformly random integer between 0 (inclusive) and the total weight (exclusive), 3) binary search for the index whose prefix sum is the first that is strictly larger than the integer.
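
Here is a minimal sketch of this static approach in numpy, using the example leaf weights from the figures ($[77, 57, 48, 56, 18, 21, 45, 26]$, which sum to $348$):

    import numpy as np

    rng = np.random.default_rng(0)
    weights = np.array([77, 57, 48, 56, 18, 21, 45, 26])    # example weights, total 348

    prefix = np.cumsum(weights)                  # 1) prefix sums: [77, 134, 182, 238, ...]
    r = rng.integers(0, prefix[-1])              # 2) uniform integer in [0, total weight)
    index = int(np.searchsorted(prefix, r, side="right"))   # 3) first prefix strictly larger than r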

What if we need to change weights, or even add or remove objects? Following the above, we would have to recompute the prefix sum from scratch, taking $O(n)$ time ($n$ is the number of objects). Sampling trees take $O(\log n)$ time to sample and $O(\log n)$ time to update/add/remove a weight. Much more efficient if we need to perform even a handful of updates!

In Reinforcement Learning, Prioritised Experience Replay (PER) uses sampling trees to sample transitions proportional to temporal difference error. Here, the total number of objects (transitions) is massive but only a tiny number of weights are updated after each training iteration. In my own research, I've used sampling trees to efficiently simulate $k$-means++ sampling in special kernel spaces related to graph clustering.

In this post, I'll go over what sampling trees are and how they work, why floating point values can be evil and how to mitigate their effects on sampling trees, and finally a couple of tips on implementation optimisations. I'll hopefully add some benchmarks in the future.

What are sampling trees?

Given the weights in Fig 1, consider a balanced binary tree where these weights are leaves and internal nodes store the sum of the weights of their children:

Why is this useful? Well if we traverse the tree from the root by flipping a coin weighted by child weights, then traversing to a leaf node corresponds exactly to sampling proportional to the weights. As an example, consider the following sequence of coin tosses:

  • At the root we went left with probability $\frac{238}{348}$
  • Then right with probability $\frac{104}{238}$
  • Then right with probability $\frac{56}{104}$ to get to the object with index $3$.

The probability of this traversal was $\frac{238}{348}\frac{104}{238}\frac{56}{104} = \frac{56}{348}$ which is exactly the probability of sampling this object proportional to its weight.

To see why this works, we can write the probability of sampling $x\in X$ proportional to its weight as follows:

For any set of nested sets $S_1\subset S_2\subset S_3 \subset \dots \subset S_m \subseteq X$ where $x\in S_1$, we have $$\Pr\left[x\ \text{is sampled} \right] = \frac{w(x)}{w(X)} = \frac{w(S_m)}{w(X)}\frac{w(S_{m-1})}{w(S_m)}\dots \frac{w(x)}{w(S_1)}$$ where $w(x)$ is the weight of object $x$ and $w(S)$ is the sum of the weights of objects in $S$. Each subset corresponds to a node along the path from the root to $x$. In the telescoping product, each term represents an intermediate coin toss and everything cancels out except what we want, $\frac{w(x)}{w(X)}$.
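
To make the traversal concrete, here is a minimal sketch of sampling by descent. It assumes a binary tree stored as a flat array in the layout introduced in the construction section below (internal nodes first, then the leaf weights); the function name and signature are just for illustration.

    import random

    def sample(tree, n_internal):
        """Walk from the root to a leaf, flipping a coin weighted by the child sums.

        `tree` is a flat array with the internal nodes first, then the leaf weights,
        so leaf j lives at index n_internal + j. Binary arity is assumed.
        """
        i = 0
        while i < n_internal:                    # stop once we reach a leaf
            left, right = 2 * i + 1, 2 * i + 2
            w_left = tree[left]
            w_right = tree[right] if right < len(tree) else 0.0
            # go left with probability w_left / (w_left + w_right)
            i = left if random.random() * (w_left + w_right) < w_left else right
        return i - n_internal                    # convert the tree index back to an object index

    # The binary tree from the figures: 7 internal nodes followed by the 8 leaf weights.
    tree = [348, 238, 110, 134, 104, 39, 71, 77, 57, 48, 56, 18, 21, 45, 26]
    print(sample(tree, n_internal=7))            # returns 3 with probability 56/348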

Updating Sampling Trees

Modifying weights

Say we want to change the weight of the object (leaf) with index 3 from $56$ to $5$. To maintain the invariant that each internal node holds the sum of its children, we need to walk up the path from the leaf to the root and recompute internal values. I'll call this update method "rebuild from children".

As an alternative, we could have just subtracted the difference $(56-5)$ from each internal node on the path to the root. I'll call this update method "propagate differences". As we will see later, this can become wildly numerically unstable if the weights are represented with floating point. However, this approach can be a lot faster: in theory we could update all the internal nodes on the path to the root simultaneously, or even use atomics to apply multiple leaf updates at the same time.
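
As a rough sketch of both methods, again assuming the flat-array layout (leaf $j$ lives at index $I + j$, where $I$ is the number of internal nodes, and the parent of index $i$ is at $(i-1)//d$); the function names are just for illustration:

    def update_rebuild(tree, n_internal, leaf, new_weight, d=2):
        """'Rebuild from children': set the leaf, then recompute each ancestor from its children."""
        i = n_internal + leaf
        tree[i] = new_weight
        while i != 0:
            i = (i - 1) // d                                   # parent index
            first = d * i + 1                                  # index of the first child
            tree[i] = sum(tree[first : min(first + d, len(tree))])

    def update_propagate(tree, n_internal, leaf, new_weight, d=2):
        """'Propagate differences': add (new - old) to the leaf and to every ancestor."""
        i = n_internal + leaf
        diff = new_weight - tree[i]
        while True:
            tree[i] += diff
            if i == 0:
                break
            i = (i - 1) // d

    # e.g. update_rebuild(tree, 7, leaf=3, new_weight=5) drops the example root from 348 to 297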

Adding new objects

Adding a new leaf is straightforward. We just add a leaf at the end of the tree, promoting an existing leaf to an internal node if necessary. Then we rebuild internal node values along the path from the new leaf to the root. Notice this keeps the tree balanced. Here we add a new leaf with weight 100.

Deleting nodes

To delete a leaf, we swap it with the final leaf in the tree and remove it from its new position. Then we rebuild internal node values along the path from the removed position to the root, and also from the swapped leaf to the root. If a leaf is left dangling as the only child of its parent, we absorb it into its parent. This process keeps the tree balanced.

Given the sampling tree in Fig 5, Fig 6 shows what the sampling tree looks like after deleting the leaf with index 4. First this node is swapped with the node at the end of the tree, the one with index 8. Then we delete node 4 and absorb the dangling leaf, node 0, into its parent (the internal node with value 177 in Fig 5). Finally, we rebuild internal nodes up to the root from this node and from the node we swapped (node 8).

Constructing Sampling Trees

We've seen how to update weights, add and delete leaves, but what if we want to build a sampling tree from scratch, given a list of weights? What about if we want to use higher arity trees? Consider the following constructed sampling tree:

We can use a Vec, or dynamic array, to store the tree. Let $d$ be the arity of the tree we wish to construct.

  • The root lives at the zeroth index
  • The $j$th child of a node at index $i$ is at index $(d\cdot i+1+j)$, where $0\le j<d$.
  • The parent of a node at index $i$ is at index $(i-1)//d$ (where $//$ is floor division).

The Vec representing the above ternary tree is just

    [348 182  95  71  77  57  48  56  18  21  45  26]

Given these rules, to construct the sampling tree, we just copy the weights to the end of the vec, then build internal nodes out of their children, back to front. This takes $O(n)$ time. The only question is: how many internal nodes do we need to allocate space for? In the above example, we needed $4$ internal nodes for a tree with $8$ leaves and arity $3$. The trick is to count edges.

Let $I$ be the number of internal nodes, and let $L$ be the number of leaves. Looking at the tree, every node except the root has an incoming edge, so the total number of edges is $I+L-1$. Also, every internal node has exactly $d$ outgoing edges (for now we'll ignore the fact that the last internal node may have fewer). Thus, by counting the total number of edges in two different ways, we get the following equation, $dI = I+L-1$. Rearranging, we get $I = \frac{L-1}{d-1}$. To account for that last internal node with potentially fewer than $d$ children, we'll need to take the ceiling of this to find the number of internal nodes to allocate space for: $$I=\left\lceil \frac{L-1}{d-1}\right\rceil.$$

Using the equality that $\left\lceil \frac{x}{y}\right\rceil = \left\lfloor \frac{x+y-1}{y}\right\rfloor = \left\lfloor \frac{x-1}{y} + 1\right\rfloor = (x-1)//y +1$, we get $$I = (L-2)//(d-1) + 1.$$ For the sake of efficiency, $d$ should be a power of 2 so that finding children and parent indices can be reduced to bit shifting and addition (no integer multiplications/divisions required!).
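
A minimal sketch of this construction for a general arity, using the index rules and the formula for $I$ above (the function name is just for illustration):

    def build(weights, d=2):
        """Build a sampling tree as a flat list: internal nodes first, then the leaf weights."""
        n_leaves = len(weights)
        n_internal = (n_leaves - 2) // (d - 1) + 1 if n_leaves > 1 else 0
        tree = [0.0] * n_internal + list(weights)      # copy the weights to the end
        for i in range(n_internal - 1, -1, -1):        # build internal nodes back to front
            first = d * i + 1                          # first child of node i
            tree[i] = sum(tree[first : min(first + d, len(tree))])
        return tree, n_internal

    # Reproduces the ternary example: [348, 182, 95, 71, 77, 57, 48, 56, 18, 21, 45, 26]
    tree, n_internal = build([77, 57, 48, 56, 18, 21, 45, 26], d=3)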

Floats are evil (a detour)

Use integer-valued weights in your sampling trees if you possibly can. If that's not possible, then read on.

We first need to discuss why float truncation is evil and show that your intuition is probably wrong. Then we can talk about what specifically can go wrong with floats for weights in a sampling tree.

If you haven't been stung by floats before, let me explain why data structures involving floating-point arithmetic can introduce bugs that only appear after days, weeks, or even years of operation. Let's consider one of the simplest possible floating-point data structures: a running sum. Given a stream of positive floats, its job is to maintain the sum of all the numbers seen so far.

    import numpy as np

    class RunningSum:
        def __init__(self):
            self.total = np.float32(0.0)

        def update(self, value):
            self.total += np.float32(value)

        def get(self):
            return self.total

Now let's try 100 million updates over 3 runs and plot the mean relative error between our running sum and the actual sum (or at least what numpy says when we use double precision). We sample points geometrically between 1 and 1 million so the input contains numbers across a wide range of magnitudes.
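
Roughly, each run looks like the following sketch (scaled down here; the exact sampling scheme and the bookkeeping for the plots are incidental):

    rng = np.random.default_rng(0)
    n_updates = 1_000_000               # scaled down; the plots use 100 million updates

    running = RunningSum()
    exact = np.float64(0.0)             # "ground truth" accumulated in double precision
    for _ in range(n_updates):
        v = 10.0 ** rng.uniform(0, 6)   # geometrically spread between 1 and 1 million
        running.update(v)
        exact += np.float64(v)

    rel_error = (np.float64(running.get()) - exact) / exact   # signed relative error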

As Fig 8 shows, the error appears to grow linearly with the number of updates. This is bad, really bad! You might have expected truncation errors to cancel out in the long run; Fig 8 shows this is not true. From the signed error, we can see that a systematic bias is introduced which pulls the running sum below the true value. There are a couple of standard techniques for dealing with this truncation error. If you know all the updates will be positive, you can use Kahan summation, which attempts to track the bits that are lost due to truncation:

    class KahanSum:
        def __init__(self):
            self.total = np.float32(0.0)
            self.c = np.float32(0.0)  # c tracks the low-order bits lost to truncation

        def update(self, value):
            y = np.float32(value) - self.c     # try to add the low-order bits back in
            t = self.total + y                 # compute the new total
            self.c = (t - self.total) - y      # update the low-order bits lost computing t
            self.total = t                     # update the running total

        def get(self):
            return self.total

If we plot the unsigned relative error of both methods, we get Fig 9. Kahan has almost zero relative error which doesn't appear to grow over time at all. The cost is that we need the extra self.c accumulator and require 4 FLOPs instead of just 1.

In terms of accuracy, Hallman and Ipsen proved that the relative error of Kahan Summation is at most proportional to $3u + 4nu^2 + O(u^3)$, where $u$ is the unit roundoff and $n$ is the number of terms in the sum.

To give some perspective, for single precision (32 bits), $u\approx 10^{-8}$, and for double precision (64 bits), $u\approx 10^{-16}$. We would have to do roughly $n=\frac{1}{u}$ updates before the error due to $n$ becomes noticeable (larger than $3u$). For single precision this is roughly $10^{8}$ terms; for double precision, roughly $10^{16}$. Depending on the application, double precision may be preferred, as an error increasing at a rate of a single unit roundoff every $10^{16}$ updates is probably acceptable.

Managing relative error in sampling trees with floats

Whenever a leaf changes, every internal node on the path to the root must be updated: we can either rebuild each node by summing the weights of its children, or add the difference between the leaf's weight before and after the update (see the Modifying weights section for an example). While elegant and faster ($O(\log n)$ vs $O(d\log n)$, where $n$ is the size of the current tree), the second approach is severely susceptible to truncation errors, because the difference can be far smaller than the weight at the internal node, and weights can also drift over time due to truncation errors.

Managing internal node updates that rebuild from children

Consider the first method, where we rebuild each internal node's value from its children. Then the only source of truncation error is propagating updates up the tree, rebuilding from children.

Because the sum at each internal node only involves $d$ (the arity) children, and only $\log n$ sums are computed to update the root, the relative error at any internal node is bounded by $O(u\cdot d \cdot \log n)$, where $n$ is the size of the tree. This is completely independent of the number of updates and is therefore a solid option in practice.

Managing internal node updates that propagate weight differences

If we want to use the fancier method of propagating updates with differences, then the relative error starts to depend on the number of updates, not just the unit roundoff and the structure of the tree.

To deal with the relative error at each node (naively $O(u\cdot m)$, where $m$ is the number of updates to that node), we could store a Kahan accumulator at each node. If we expect both positive and negative differences to be propagated (realistic), we can store the total as a positive and a negative part, each with its own Kahan accumulator, or use Neumaier's variant, which handles mixed-sign updates with a single accumulator. A nice bit of Python lore is that Neumaier's variant was only added to Python's built-in sum function in version 3.12.
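
For reference, here is a minimal sketch of Neumaier's variant as a standalone accumulator; in a sampling tree, each internal node would carry its own total/compensation pair (which is what the `n_accs` array in the Rust sketch later is for).

    class NeumaierSum:
        def __init__(self):
            self.total = np.float32(0.0)
            self.c = np.float32(0.0)     # running compensation for lost low-order bits

        def update(self, value):
            value = np.float32(value)
            t = self.total + value
            if abs(self.total) >= abs(value):
                self.c += (self.total - t) + value   # low-order bits of value were lost
            else:
                self.c += (value - t) + self.total   # low-order bits of total were lost
            self.total = t

        def get(self):
            return self.total + self.c   # apply the compensation once, when reading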

The bottom line is that we need to maintain extra information (the accumulators) if we want to use this update method while being numerically stable as the number of updates grows.

Implementing Sampling Trees

Struct of Arrays

In almost all cases, a struct-of-arrays layout is preferable, since the tree layout ensures children are next to each other in memory. This means computing the sum of the child weights is very cache friendly for low-arity trees. Depending on which update method you use (rebuild from children vs difference propagation with accumulators), this might look something like the following in Rust:

    struct TreeData<const ARITY: usize> {
        data: Vec<f32>,
        n_accs: Vec<f32>, // if we use Neumaier accumulators

        // Probably also include HashMaps mapping node names
        // to a unique index and back, and other HashMaps mapping
        // the unique indices to tree indices and back.
        // Using the newtype pattern for index types makes life a lot easier!
    }

Handling auxiliary information

In more complicated applications, like my dynamic spectral clustering paper, the weights at internal nodes are not stored explicitly, but are computed on the fly from auxiliary information associated with the set of leaves each node represents. This information is usually updatable in the same way as the weights, so it fits nicely into the struct-of-arrays layout. As an example, if the objects in the sampling tree are nodes in a graph, a typical piece of auxiliary information is volume, the sum of degrees over a set of nodes. Here, each internal node stores the total volume of its children and each leaf stores a degree.

Arity should be a power of 2:

If we can prove at compile time that the arity is a power of two, then we (or rather the compiler) can replace the multiplication and division in the parent/child index calculations with left and right bit shifts. A dirty way of doing this in (stable) Rust is using const generics, a macro, and a trait bound:

    pub trait PowerOfTwo {} // a marker trait
    pub struct ConstPow2<const N: usize>; // a dummy struct 

    macro_rules! impl_power_of_two_up_to_1024 {
        ( $( $pow:expr ),* ) => {
            $(
                // impl the marker trait for each input
                impl PowerOfTwo for ConstPow2<$pow> {}
            )*
        };
    }
    impl_power_of_two_up_to_1024! {
        2, 4, 8, 16, 32, 64, 128, 256, 512, 1024
    }

    // usage:
    impl<const ARITY: usize> TreeData<ARITY> 
        where ConstPow2<ARITY>: PowerOfTwo {
        // Implementation    
    }

Batching updates:

One big optimisation that improves update throughput and numerical stability is batching internal node updates. If we sample infrequently, we can delay internal node updates and process them together. This has a massive impact regardless of whether we update internal nodes using the rebuild-from-children or the difference-propagation approach. In the rebuild-from-children case, given a batch of leaves to propagate updates from, we update internal nodes level by level, starting from the deepest leaves (a tricky case is managing leaves on the second-to-last level). This saves a massive amount of time because every time two update paths collide, we only continue with one of them. This is the same reason building a sampling tree from scratch takes $O(n)$ time, but performing $n$ updates without any batching takes $O(n\log n)$ time.

Given the tree in Figure 10, the red nodes highlight the nodes that have to be updated if we update the leaves with indices 0, 3, 5, and 7. If we were to update them as a batch, the root would only have to be rebuilt once instead of 4 times; likewise, the internal nodes with values 238 and 110 would only have to be rebuilt once instead of twice. For very large trees, this makes a massive difference.
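
Here is a rough sketch of a batched rebuild-from-children update, again assuming the flat-array layout. Sorting the dirty internal nodes by decreasing index is one simple way to process children before parents (and to sidestep the second-to-last-level subtlety), since in this layout a parent always has a smaller index than its children:

    def batch_update(tree, n_internal, updates, d=2):
        """Apply a batch of leaf updates, rebuilding each affected internal node only once.

        `updates` maps object index -> new weight.
        """
        dirty = set()
        for leaf, new_weight in updates.items():
            i = n_internal + leaf
            tree[i] = new_weight
            while i != 0:                       # mark every ancestor as dirty
                i = (i - 1) // d
                if i in dirty:
                    break                       # this path has already been marked
                dirty.add(i)
        for i in sorted(dirty, reverse=True):   # children are always rebuilt before parents
            first = d * i + 1
            tree[i] = sum(tree[first : min(first + d, len(tree))])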

Since batching reduces the number of times internal nodes need to be rebuilt, it also helps numerical stability under the difference-propagation approach, since it cuts down the number of updates required at each node (assuming we apply differences to nodes level by level too). Naively, if we merge differences together from child nodes, the relative error at the root after $b$ batches will be $O(u\cdot b\cdot d\cdot \log n)$. Using Kahan/Neumaier accumulators will help even more. Alternatively, we can periodically use the rebuild-from-children approach to bring the relative error back down to $O(u\cdot d\cdot \log n)$.

Benchmarks

When I have time, I plan to write a numerically stable version of my dynamic spectral clustering algorithm, and I'm going to benchmark these optimisations. I'll add the results here.