# Implementing Bitonic Merge Sort in Vulkan Compute

A bitonic sorting network lays out a sequence of compare-and-swap operations that, when applied to an array of sortable elements, sorts these elements.

Fig. 0: Major steps of a bitonic merge sorting network applied to 1024 sortable elements. The pixels in the final square are sorted by brightness, left-to-right and top-to-bottom.

## Why Bitonic Merge Sort?

The beauty of this algorithm is that it maps really well onto parallel hardware, such as a GPU: given enough worker threads, the network has a relatively low depth of $$O(\log^2(n))$$ parallel steps. That is to say: It's quite fast. The implementation I'm discussing here was able to sort 1M random ints in 4ms. Compared to the 40ms I measured for C++'s std::sort, that's a 10x speedup! On top of that, the sequence of operations does not depend on the data, so for a given number of sortable elements, it will always take the same amount of time to complete.

When the algorithm executes, each worker thread performs one single operation for every pair of sortable elements: compare-and-swap. But how to choose the correct pairings for sortable elements? This is where the algorithm's complexity is buried.

## Constructing a Bitonic Sorting Network

First, let's look at a diagram of a relatively small sorting network, say, for 16 sortable elements. The diagram is for the alternative representation of a bitonic sorting network, as I found it on Wikipedia. I prefer this version because it is slightly simpler: all worker threads use the same direction for their compare-and-swap operations.

Fig. 1: Sorting diagram for 16 sortable elements. Horizontal lines represent array indices into our array of sortable elements, each vertical line represents one worker thread performing a compare-and-swap operation.

In the diagram, each worker thread is drawn as a vertical line, with the start- and endpoints marking the indices of the current pair of sortable elements. All coloured blocks which are in the same column form a step, which can be executed in parallel. We can see from the diagram that the number of threads per step is constant, and that it is always n/2, with n being the number of sortable elements.

Before we proceed to figuring out sort index pairs for each worker thread based on each thread's execution step and thread number, let's take a step back, look at the diagram through squinted eyes and see if some patterns emerge if we group similar looking blocks together.

## We can point out some macro patterns:

1. Green block: doubles in height until it reaches total number of elements. Let's call such a block flip
2. Yellow block: after each green block, there is a cascade of yellow blocks, each half the previous height, until its height spans only two sortable elements. Let's call such a block disperse
3. Blocks are periodic within their column.

We can write the following pseudo-code:

```c
int n = 16; // n is our total number of sortable elements, must be power of 2
for (int h = 2; h <= n; h *= 2) {
    do_flip(h);
    for (int hh = h/2; hh > 1; hh /= 2){
        do_disperse(hh);
    }
}
```
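To see which steps this loop structure produces, here is a small C++ sketch that records each step instead of executing it. The names `Block` and `network_steps` are made up for this illustration and do not come from the original test program.

```cpp
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

// Hypothetical helper type: which kind of block a step consists of.
enum class Block { Flip, Disperse };

// Lists the steps of the network in execution order as (block type, height).
std::vector<std::pair<Block, uint32_t>> network_steps(uint32_t n) {
    // n must be a power of two; the upper bound keeps h *= 2 from overflowing.
    assert(n >= 2 && (n & (n - 1)) == 0 && n <= (1u << 30));
    std::vector<std::pair<Block, uint32_t>> steps;
    for (uint32_t h = 2; h <= n; h *= 2) {
        steps.push_back({Block::Flip, h});
        for (uint32_t hh = h / 2; hh > 1; hh /= 2) {
            steps.push_back({Block::Disperse, hh});
        }
    }
    return steps;
}
```

For `n = 16` this yields 10 steps: one flip of height 2, then flip 4 followed by one disperse, flip 8 followed by two, and flip 16 followed by three.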

## Look at inner patterns:

Looking at the diagram again, we see that each of the two block types, green and yellow, follows a distinct pattern for how each worker thread grabs pairs of sortable elements: While worker thread lines inside green ("flip") blocks start at full height and converge towards the centre, in yellow ("disperse") blocks they start at half height, and then shift down. We also note that the only variable controlling the pattern is the overall height h of a yellow or green block.

We can therefore split the problem in two:

1. Find a rule to calculate pairs for a green (“flip”) block based on height h.
2. Find a rule to calculate pairs for a yellow (“disperse”) block based on height h.

## Index pairs for “flip” (green) blocks

Let's enumerate sortable element index pairs from our green blocks, and label them $$T_{0\dots t}$$, where $$t$$ represents the worker thread index to which a sortable element index pair is assigned:

\begin{align*} h = 2 &: T_0 [0:1] \\ h = 4 &: T_0 [0:3], T_1[1:2]\\ h = 8 &: T_0 [0:7], T_1 [1:6], T_2 [2:5], T_3 [3:4]\\ \end{align*}

We see that we can form pairs of indices based on the following rule:

$$T_t = [ t : h-t-1], \text{ for } t = 0,\dots,\frac{h}{2}-1$$

## Index pairs for “disperse” (yellow) blocks

Similarly, if we enumerate index pairs for yellow blocks we get

\begin{align*} h = 2 &: T_0 [0:1] \\ h = 4 &: T_0 [0:2], T_1[1:3]\\ h = 8 &: T_0 [0:4], T_1 [1:5], T_2 [2:6], T_3 [3:7]\\ \end{align*}

We see that we can form pairs of indices based on the following rule:

$$T_t = [t : t + h/2], \text{ for } t = 0,\dots,\frac{h}{2}-1$$
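Both rules are easy to check in code. The following C++ sketch implements them for a single block of height $$h$$ and verifies that the $$h/2$$ pairs touch every index inside the block exactly once. The names `IndexPair`, `flip_pair`, `disperse_pair` and `covers_block` are assumptions made for this example.

```cpp
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

using IndexPair = std::pair<uint32_t, uint32_t>;

// Pair for worker thread t inside one flip block of height h.
IndexPair flip_pair(uint32_t t, uint32_t h) {
    assert(h >= 2 && t < h / 2);
    return {t, h - t - 1};
}

// Pair for worker thread t inside one disperse block of height h.
IndexPair disperse_pair(uint32_t t, uint32_t h) {
    assert(h >= 2 && t < h / 2);
    return {t, t + h / 2};
}

// Returns true if the h/2 pairs produced by `rule` touch every index
// in [0, h) exactly once.
bool covers_block(IndexPair (*rule)(uint32_t, uint32_t), uint32_t h) {
    std::vector<int> hits(h, 0);
    for (uint32_t t = 0; t < h / 2; ++t) {
        const IndexPair p = rule(t, h);
        ++hits[p.first];
        ++hits[p.second];
    }
    for (int count : hits) {
        if (count != 1) {
            return false;
        }
    }
    return true;
}
```

That every index appears exactly once is what allows all threads of a step to run in parallel without two of them touching the same element.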

## Rules for rows of blocks

Now, let's look at the diagram again.

We notice that both green and yellow blocks repeat vertically until they fill out the full height of their column. We can model this repetition as a combination of modulo (the pattern is the same within a block) and an integer-division based offset.

As a first step, we make sure both rules for green and yellow blocks are contained within their respective height, $$h$$. Note that we define these rules now for the total number of sortable elements, $$n$$ instead of the earlier $$h$$. Instead of using $$t$$ “unbounded”, we introduce a new term around $$t$$:

$$T_t = [ q + (t\bmod \frac{h}{2}) : q + h-(t \bmod \frac{h}{2})-1], \text{ for } t = 0,\dots,\frac{n}{2}-1$$

$$T_t = [ q + (t \bmod \frac{h}{2}) : q + (t \bmod \frac{h}{2}) + \frac{h}{2}], \text{ for } t = 0,\dots,\frac{n}{2}-1$$

What does the new term $$(t \bmod \frac{h}{2})$$ do? It limits the output of the term for any given $$t$$ to $$0,\dots,\frac{h}{2}-1$$, something you will recognise as a condition for our earlier rules.

You will also have noted that both rules introduced a new variable $$q$$, which represents the offset. We can express $$q$$ fully as a function of $$h$$ and $$t$$:

$$q = \left\lfloor \frac{2t}{h} \right\rfloor \cdot h$$

In pseudo-code, this translates to:

```c
void calc_pairs(int t, int h){
    int q = ((2 * t) / h) * h;   // offset of the block this thread works in
    int half_h = h / 2;

    int flip[2];                 // index pair for a flip block
    flip[0] = q + (t % half_h);
    flip[1] = q + h - (t % half_h) - 1;

    int disp[2];                 // index pair for a disperse block
    disp[0] = q + (t % half_h);
    disp[1] = q + (t % half_h) + half_h;
}
```
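To convince ourselves that these index rules really sort, we can emulate the whole network on the CPU: a loop over `t` stands in for the `n/2` worker threads of each step. This is only a sketch under stated assumptions: the name `bitonic_sort_cpu` is made up here, and the comparison is chosen so that the result is ascending (the shaders further down swap when the lower index holds the smaller value, which gives descending order).

```cpp
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

// One compare-and-swap; this sketch sorts in ascending order.
void compare_and_swap(std::vector<uint32_t>& a, uint32_t x, uint32_t y) {
    if (a[x] > a[y]) {
        std::swap(a[x], a[y]);
    }
}

// Emulates one flip step: each "thread" t in [0, n/2) handles one pair.
void do_flip(std::vector<uint32_t>& a, uint32_t h) {
    const uint32_t n = static_cast<uint32_t>(a.size());
    const uint32_t half_h = h / 2;
    for (uint32_t t = 0; t < n / 2; ++t) {
        const uint32_t q = ((2 * t) / h) * h;
        compare_and_swap(a, q + (t % half_h), q + h - (t % half_h) - 1);
    }
}

// Emulates one disperse step, again with one pair per "thread".
void do_disperse(std::vector<uint32_t>& a, uint32_t h) {
    const uint32_t n = static_cast<uint32_t>(a.size());
    const uint32_t half_h = h / 2;
    for (uint32_t t = 0; t < n / 2; ++t) {
        const uint32_t q = ((2 * t) / h) * h;
        compare_and_swap(a, q + (t % half_h), q + (t % half_h) + half_h);
    }
}

// Runs the full network over a; a.size() must be a power of two.
void bitonic_sort_cpu(std::vector<uint32_t>& a) {
    const uint32_t n = static_cast<uint32_t>(a.size());
    assert(n != 0 && (n & (n - 1)) == 0 && n <= (1u << 30));
    for (uint32_t h = 2; h <= n; h *= 2) {
        do_flip(a, h);
        for (uint32_t hh = h / 2; hh > 1; hh /= 2) {
            do_disperse(a, hh);
        }
    }
}
```

Because the pairs within one step never overlap, the order in which the inner loop visits `t` does not matter, which is exactly the property the GPU version relies on.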

## Let's make a first POC!

In order to test these little formulas, and because I thought it would generate some nice visuals, I wrote a little C program which, when given a count of sortable elements, prints a diagram, and index pairs for each step of the resulting sorting network.

Fig. 2: Screengrab of test application generating sorting network index pairs. You can find the source code for this C application on GitHub.

# Implementing this as a Vulkan compute shader

Now that we have a reliable and tested method to generate index pairs, we can take a stab at implementing the algorithm using Vulkan Compute.

In Vulkan, compute tasks require us to bind a compute pipeline - which is thankfully very simple and takes a single compute shader. We'll get to the code for the compute shader in a little bit, but for now, let's look at what we need to specify in order to dispatch a compute command:

```c
// Provided by VK_VERSION_1_0
void vkCmdDispatch(
    VkCommandBuffer commandBuffer,
    uint32_t        groupCountX,
    uint32_t        groupCountY,
    uint32_t        groupCountZ);
```

We need to find the group counts in the x, y, and z dimensions for our dispatch command. Since we want each compute command to execute one step of our bitonic sorting network, let's stick to a one-dimensional invocation: I'm setting groupCountY = 1 and groupCountZ = 1, since these values must be at minimum 1.

## What do we set for groupCountX?

Well, this will depend on the size of our workgroup, and the total number of sortable elements. Let's say we have $$n=16$$ sortable elements. This means, we will need $$\frac{n}{2} = 8$$ worker threads. Let's say each workgroup contains 8 worker threads, this means we would need to call our dispatch with a groupCountX of 1.

Every group we dispatch runs one workgroup of thread workers. In shader code, we can specify how many threads we want to have per workgroup by specifying local_size_x, and in this particular case, we would choose 8.
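The relationship between the number of sortable elements, the workgroup size and groupCountX can be written down as a one-liner. The helper name `group_count_x` is an assumption for this sketch; it simply restates the arithmetic from the paragraph above.

```cpp
#include <cassert>
#include <cstdint>

// Number of workgroups needed so that n/2 worker threads run in total,
// given local_size_x threads per workgroup.
uint32_t group_count_x(uint32_t n, uint32_t local_size_x) {
    assert(local_size_x > 0);
    // Every workgroup is expected to be fully occupied.
    assert(n % (2 * local_size_x) == 0);
    return n / (2 * local_size_x);
}
```

With `n = 16` and `local_size_x = 8` this gives a groupCountX of 1, matching the example above.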

## Naive Bitonic Merge Sort

Now, let's look at a first stab at a shader implementation for bitonic merge sort:

```glsl
#version 450 core

// Sort 16 sortable elements
#define N_ELEMENTS 16

layout (local_size_x = (N_ELEMENTS / 2) ) in;

layout (set=0, binding = 0) buffer SortData
{
    uint value[];
};

// Workgroup local memory. We use this to minimise round-trips to global memory.
// Allows us to evaluate a sorting network of up to 1024 with a single shader invocation.
shared uint local_value[N_ELEMENTS];

void local_compare_and_swap(uvec2 idx){
    if (local_value[idx.x] < local_value[idx.y]) {
        uint tmp = local_value[idx.x];
        local_value[idx.x] = local_value[idx.y];
        local_value[idx.y] = tmp;
    }
}

void do_flip(uint h){
    uint t = gl_LocalInvocationID.x;
    uint half_h = h / 2;
    uint q = ((2 * t) / h) * h;
    uvec2 indices = q + uvec2( t % half_h, h - 1 - (t % half_h) );
    local_compare_and_swap(indices);
}

void do_disperse(uint h){
    uint t = gl_LocalInvocationID.x;
    uint half_h = h / 2;
    uint q = ((2 * t) / h) * h;
    uvec2 indices = q + uvec2( t % half_h, (t % half_h) + half_h );
    local_compare_and_swap(indices);
}

void main(){
    uint t = gl_LocalInvocationID.x;

    // Each local worker must save two elements to local memory, as there
    // are twice as many elements as workers.
    local_value[t*2]   = value[t*2];
    local_value[t*2+1] = value[t*2+1];

    uint n = N_ELEMENTS;

    for ( uint h = 2; h <= n; h *= 2 ) {
        barrier();
        do_flip(h);
        for ( uint hh = h / 2; hh > 1 ; hh /= 2 ) {
            barrier();
            do_disperse(hh);
        }
    }

    barrier();

    // Write local memory back to buffer
    value[t*2]   = local_value[t*2];
    value[t*2+1] = local_value[t*2+1];
}
```

Note that in this implementation we're using shared memory (the `shared uint local_value[N_ELEMENTS]` declaration). This is a pretty important optimisation. Reading from and writing to GPU global memory is pretty slow, and it is much faster to keep data in cache for as long as we can. The closest cache we have at hand in a compute shader is shared memory, which is shared amongst all threads within a local workgroup.

In our particular case, this means the 8 threads of our workgroup share the same cache, and because the cache is allocated at the size of the total number of sortable elements, n, we can fit all our sortable data into our local workgroup cache.

## Why use shader local memory?

Using local memory (that is local workgroup cache) allows us to execute multiple steps of the algorithm on a single compute shader invocation. This is much faster, because once the cache is populated, we can do all the work on local memory, before writing back our results to global GPU memory.

To prevent any data races on our local cache, we issue barrier() instructions whenever we want to make sure that all worker thread writes have completed, before we continue execution. Because we know that each step of the algorithm can execute fully in parallel, we need only to issue barriers after each step, and just before we access global memory.

## Know your limits: This won't scale

The shader above is a working implementation of a bitonic sorting network. It will, alas, not scale, as hardware limits apply. There are two major hardware limits: One is the total number of worker threads which can be issued per workgroup, the other is the total number of bytes which may be used as local cache. Fortunately, we can query these limits via Vulkan.

The relevant hardware limits are: maxComputeSharedMemorySize for the total number of bytes used for local workgroup cache, and maxComputeWorkGroupInvocations for the total number of worker threads for one workgroup.
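As a rough sketch of how these two limits bound the problem size, the following C++ function computes the largest power-of-two element count a single workgroup could sort in local memory. `ComputeLimits` is a hypothetical plain struct standing in for the two fields of the same name that you would read from `VkPhysicalDeviceLimits`; `max_local_elements` is likewise a name made up for this example, and elements are assumed to be 4-byte uints as in the shader above.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

// Hypothetical stand-in for the two relevant VkPhysicalDeviceLimits fields.
struct ComputeLimits {
    uint32_t maxComputeSharedMemorySize;     // bytes of workgroup-local memory
    uint32_t maxComputeWorkGroupInvocations; // worker threads per workgroup
};

// Largest power-of-two number of uint elements that one workgroup can sort
// entirely in local memory under the given limits.
uint64_t max_local_elements(const ComputeLimits& limits) {
    // Each worker thread handles two elements.
    const uint64_t by_threads = 2ull * limits.maxComputeWorkGroupInvocations;
    // Every element occupies sizeof(uint32_t) bytes of local memory.
    const uint64_t by_memory = limits.maxComputeSharedMemorySize / sizeof(uint32_t);
    const uint64_t cap = std::min(by_threads, by_memory);
    assert(cap >= 2);
    uint64_t n = 2;
    while (n * 2 <= cap) {
        n *= 2; // round down to a power of two
    }
    return n;
}
```

Whichever limit is tighter decides; the numbers used to try this out are illustrative and not taken from any particular GPU.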

How can we work around these limits, then?

## But this will scale: Bitonic Merge Sort for Arbitrary Power of 2 Sortable Elements

Ideally, we'd like to increase local_size_x to mirror any change in n. To keep things very simple, we'd like our worker thread count $$t = \frac{n}{2}$$, so that we could run everything inside of a single workgroup, and always invoke our compute shader with a groupCountX of 1.

But, as we have just established, even GPUs have a limited number of worker threads. Perhaps the best way to adapt to this inconvenience is to invoke our compute shaders with a groupCountX > 1, which means we must first see how we can divide the problem into g groups.

Let's look at our diagram again. Let's say we have a number n == 16 of sortable elements, but a maximum group size of 4 thread workers per workgroup. This would split our workload into two lanes (shaded light- and dark blue in the diagram):

Bitonic merge sort network, split over two workgroups (shown as a light- and a dark blue lane). Note: Thread workers are shown as vertical lines, with the dark circles marking the first and second parameter they take for their compare-and-swap operations.

### We notice the following patterns:

Up to and including step 6, and labelled local_bms in the diagram, no thread worker needs to reach out of their lane, all steps may operate on the local memory of a thread's workgroup.

Step 7 is special, as its compare-and-swap operations reach across the group's lane. Let's label this step big_flip.

In steps 8 - 10, a cascade of disperse operations, all compare-and-swap operations are self-contained within their lane: local_disp.

Not in the diagram, but by extrapolation plausible, is a step where a “disperse” operation reaches over the lane boundary: big_disperse.

## Update CPU code to allow large numbers of sortable elements

From the diagram above we have enumerated the following four cases:

1. local bitonic merge sort
2. big flip
3. local disperse
4. big disperse

Before we adapt our shader code to accommodate these four cases, we must rethink how we organise our CPU code.

```cpp
const uint32_t workgroup_count = n / ( workgroup_size_x * 2 );

uint32_t h = workgroup_size_x * 2;

local_bms( n, h );

// we must now double h, as this happens before every flip
h *= 2;

for ( ; h <= n; h *= 2 ) {
    big_flip( n, h );

    for ( uint32_t hh = h / 2; hh > 1; hh /= 2 ) {

        if ( hh <= workgroup_size_x * 2 ) {
            local_disperse( n, hh );
            break;
        } else {
            big_disperse( n, hh );
        }
    }
}
```
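To check this control flow without a GPU, the same loop can record which kernel it would dispatch instead of dispatching it. The struct `Dispatch` and the function `plan_dispatches` are hypothetical names for this sketch; the loop body mirrors the snippet above.

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <vector>

// Hypothetical record of one compute dispatch: which kernel, at which height.
struct Dispatch {
    std::string kernel; // "local_bms", "big_flip", "local_disperse", "big_disperse"
    uint32_t h;
};

// Lists the dispatches the CPU loop would issue, in order.
std::vector<Dispatch> plan_dispatches(uint32_t n, uint32_t workgroup_size_x) {
    const uint32_t lane = workgroup_size_x * 2; // elements per workgroup
    assert(lane >= 2 && (lane & (lane - 1)) == 0);
    assert((n & (n - 1)) == 0 && n >= lane && n <= (1u << 30));

    std::vector<Dispatch> plan;
    uint32_t h = lane;
    plan.push_back({"local_bms", h});

    h *= 2; // doubled before every flip
    for (; h <= n; h *= 2) {
        plan.push_back({"big_flip", h});
        for (uint32_t hh = h / 2; hh > 1; hh /= 2) {
            if (hh <= lane) {
                // The local kernel runs the remaining disperse cascade itself.
                plan.push_back({"local_disperse", hh});
                break;
            }
            plan.push_back({"big_disperse", hh});
        }
    }
    return plan;
}
```

For `n = 16` and a workgroup size of 4 this produces three dispatches: `local_bms` at height 8 (steps 1 to 6 in the diagram), `big_flip` at height 16 (step 7), and `local_disperse` at height 8 (steps 8 to 10).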

Each of the calls above ends in a call to dispatch(), which triggers a compute shader dispatch, followed by a compute shader memory barrier. The barrier (the `bufferMemoryBarrier` call at the end of `dispatch` in the following code snippet) is needed to enforce that the previous step has fully completed before the next step of the algorithm begins.

```cpp
// dispatch is defined first, so that the lambdas below can capture it.
auto dispatch = [ & ]( uint32_t n, uint32_t h ) {
    params.n = n;
    params.h = h;

    encoder
        .setArgumentData( LE_ARGUMENT_NAME( "Parameters" ), &params, sizeof( params ) )
        .dispatch( workgroup_count )
        .bufferMemoryBarrier( { LE_PIPELINE_STAGE_COMPUTE_SHADER_BIT },
                              { LE_PIPELINE_STAGE_COMPUTE_SHADER_BIT },
                              { LE_ACCESS_SHADER_READ_BIT },
                              app->pixels_data->handle );
};

// ----

auto local_bms = [ & ]( uint32_t n, uint32_t h ) {
    params.algorithm = Parameters::eAlgorithmVariant::eLocalBms;
    dispatch( n, h );
};

auto big_flip = [ & ]( uint32_t n, uint32_t h ) {
    params.algorithm = Parameters::eAlgorithmVariant::eBigFlip;
    dispatch( n, h );
};

auto local_disperse = [ & ]( uint32_t n, uint32_t h ) {
    params.algorithm = Parameters::eAlgorithmVariant::eLocalDisperse;
    dispatch( n, h );
};

auto big_disperse = [ & ]( uint32_t n, uint32_t h ) {
    params.algorithm = Parameters::eAlgorithmVariant::eBigDisperse;
    dispatch( n, h );
};
```

In addition to updating our CPU-side code, we also must update our shader code. Our compute shader must provide four different kernels, one each for local bitonic merge sort, big flip, local disperse, and big disperse. We could write four different shaders, and bind a different pipeline before calling dispatch, but for my proof of concept, I used a uniform parameter and some if-statements on the shader to invoke the chosen algorithm. Since, for each dispatch, all thread workers are running the same code path, there should not be any significant penalty for using an if-statement in this case.

Now, let's take a closer look at how we need to update our shader code so that we can sort arbitrary power-of-two sized arrays.

### Near compare-and-swaps: Local indices + offset

With local_disperse and local_bms operations, our workgroups do not reach out of their lane. Therefore we can keep calculating offsets based on local t and h as outlined above. All we have to do is to apply a constant offset which corresponds to the lane we're on, before we pull data to local memory, on which these algorithms operate. This offset r is calculated based on the height of a lane, multiplied by its index:

$$r = 2 * workgroup\_size\_x * workgroup\_id\_x$$

At the end of local_disperse and local_bms operations, we must make sure to apply the offset again, when we write our locally stored results back to global memory.
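A minimal sketch of this offset arithmetic, with function names (`lane_offset`, `load_indices`) invented for the example: each worker thread copies two neighbouring elements, so the global indices it touches are the lane offset plus twice its local index, and the next index after that.

```cpp
#include <cassert>
#include <cstdint>
#include <utility>

// Offset r of the first element belonging to a workgroup's lane.
uint32_t lane_offset(uint32_t workgroup_size_x, uint32_t workgroup_id_x) {
    return 2 * workgroup_size_x * workgroup_id_x;
}

// The two global indices that a worker thread copies to local memory
// before the local kernel runs, and writes back to afterwards.
std::pair<uint32_t, uint32_t> load_indices(uint32_t workgroup_size_x,
                                           uint32_t workgroup_id_x,
                                           uint32_t local_id_x) {
    assert(local_id_x < workgroup_size_x);
    const uint32_t r = lane_offset(workgroup_size_x, workgroup_id_x);
    return {r + 2 * local_id_x, r + 2 * local_id_x + 1};
}
```

With 4 threads per workgroup, the second workgroup starts at offset 8, and its last thread handles global elements 14 and 15.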

### Far compare-and-swaps: Global indices + splitting

For workgroups which reach out of their lane, we can't use local memory, and all index addresses must be global.

Let's look at the diagram again.

How do we enforce that any sortable element is only ever touched by a single worker thread from any workgroups invoked during such a big_[disperse|flip] step?

We don't operate on local $$t$$, but calculate global $$t'$$ for each worker thread. This applies for big_flip, and big_disperse:

$$t' = workgroup\_id\_x * workgroup\_size\_x + local\_id\_x$$

This gives us an updated rule for big_flip:

$$T_{t'} = [ q + (t'\bmod \frac{h}{2}) : q + h-(t' \bmod \frac{h}{2})-1], \text{ for } t' = 0,\dots,\frac{n}{2}-1$$

… and for big_disperse:

$$T_{t'} = [ q + (t' \bmod \frac{h}{2}) : q + (t' \bmod \frac{h}{2}) + \frac{h}{2}], \text{ for } t' = 0,\dots,\frac{n}{2}-1$$

Note that the global offset $$q$$ is still calculated as before, that is:

$$q = \left\lfloor \frac{2t'}{h} \right\rfloor \cdot h$$
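The question from above, whether every sortable element is touched by exactly one worker thread during a big step, can be answered by brute force on the CPU. The sketch below emulates all workgroups of one big_flip or big_disperse step using the global thread index; all function names are assumptions made for this illustration.

```cpp
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

using IndexPair = std::pair<uint32_t, uint32_t>;

// Global thread index t' from workgroup index, workgroup size and local index.
uint32_t global_thread_index(uint32_t workgroup_id_x, uint32_t workgroup_size_x,
                             uint32_t local_id_x) {
    assert(local_id_x < workgroup_size_x);
    return workgroup_id_x * workgroup_size_x + local_id_x;
}

// Global index pair for a big flip of height h.
IndexPair big_flip_pair(uint32_t t_prime, uint32_t h) {
    const uint32_t half_h = h / 2;
    const uint32_t q = ((2 * t_prime) / h) * h;
    return {q + (t_prime % half_h), q + h - (t_prime % half_h) - 1};
}

// Global index pair for a big disperse of height h.
IndexPair big_disperse_pair(uint32_t t_prime, uint32_t h) {
    const uint32_t half_h = h / 2;
    const uint32_t q = ((2 * t_prime) / h) * h;
    return {q + (t_prime % half_h), q + (t_prime % half_h) + half_h};
}

// Emulates one big step over all workgroups and reports whether each of
// the n elements was touched by exactly one worker thread.
bool touches_each_element_once(uint32_t n, uint32_t workgroup_size_x,
                               uint32_t h, bool is_flip) {
    assert(h >= 2 && h <= n && n % h == 0);
    assert(workgroup_size_x > 0 && n % (2 * workgroup_size_x) == 0);
    std::vector<int> hits(n, 0);
    const uint32_t workgroup_count = n / (2 * workgroup_size_x);
    for (uint32_t g = 0; g < workgroup_count; ++g) {
        for (uint32_t l = 0; l < workgroup_size_x; ++l) {
            const uint32_t t = global_thread_index(g, workgroup_size_x, l);
            const IndexPair p = is_flip ? big_flip_pair(t, h)
                                        : big_disperse_pair(t, h);
            ++hits[p.first];
            ++hits[p.second];
        }
    }
    for (int count : hits) {
        if (count != 1) {
            return false;
        }
    }
    return true;
}
```

For the diagram's case (`n = 16`, 4 threads per workgroup, big flip of height 16), the first thread of the second workgroup has `t' = 4` and compares elements 4 and 11.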

The last step is to put this all together, which gives us the following updated shader code:

```glsl
#version 450 core

// This shader implements a sorting network for 1024 elements.
//
// It follows the alternative notation for bitonic sorting networks, as given at:
// https://en.m.wikipedia.org/wiki/Bitonic_sorter#Alternative_representation

#extension GL_ARB_separate_shader_objects : enable
#extension GL_ARB_shading_language_420pack : enable

// Externally defined via shader compiler.
#ifndef LOCAL_SIZE_X
#define LOCAL_SIZE_X 1
#endif

// Note that there exist hardware limits -
// Look these up for your GPU via https://vulkan.gpuinfo.org/
//
// sizeof(local_value[LOCAL_SIZE_X * 2]) : Must be <= maxComputeSharedMemorySize
// LOCAL_SIZE_X                          : Must be <= maxComputeWorkGroupInvocations

// ENUM for uniform::Parameters.algorithm:
#define eLocalBms      0
#define eLocalDisperse 1
#define eBigFlip       2
#define eBigDisperse   3

layout (local_size_x = LOCAL_SIZE_X ) in; // Note hardware limit mentioned above!

layout (set=0, binding = 0) buffer SortData
{
    // This is our unsorted input buffer - tightly packed,
    // an array of N_GLOBAL_ELEMENTS elements.
    uint value[];
};

// Parameters set by the CPU before each dispatch.
layout (set=0, binding=1) uniform Parameters {
    uint n;
    uint h;
    uint algorithm;
} parameters;

// Workgroup local memory. We use this to minimise round-trips to global memory.
// It allows us to evaluate a sorting network of up to 1024 with one shader invocation.
shared uint local_value[LOCAL_SIZE_X * 2];

void global_compare_and_swap(ivec2 idx){
    if (value[idx.x] < value[idx.y]) {
        uint tmp = value[idx.x];
        value[idx.x] = value[idx.y];
        value[idx.y] = tmp;
    }
}

void big_flip( in uint n, in uint h) {
    // uint n                   // total number of sortable elements
    // uint h                   // flip height
    // uint gl_WorkGroupSize.x  // number of threads in block/workgroup:
    //                          // each thread deals with two sortable elements

    if ( gl_WorkGroupSize.x * 2 > h ) {
        return;
    }

    uint t_prime = gl_GlobalInvocationID.x;
    uint half_h = h >> 1; // Note: h >> 1 is equivalent to h / 2

    uint q = ((2 * t_prime) / h) * h;
    uint x = q + (t_prime % half_h);
    uint y = q + h - (t_prime % half_h) - 1;

    global_compare_and_swap(ivec2(x,y));
}

void big_disperse( in uint n, in uint h ) {
    // uint n                   // total number of sortable elements
    // uint h                   // disperse height
    // uint gl_WorkGroupSize.x  // number of threads in block/workgroup:
    //                          // each thread deals with two sortable elements

    if ( gl_WorkGroupSize.x * 2 > h ) {
        return;
    }

    uint t_prime = gl_GlobalInvocationID.x;
    uint half_h = h >> 1; // Note: h >> 1 is equivalent to h / 2

    uint q = ((2 * t_prime) / h) * h;
    uint x = q + (t_prime % (half_h));
    uint y = q + (t_prime % (half_h)) + half_h;

    global_compare_and_swap(ivec2(x,y));
}

// Performs compare-and-swap over elements held in shared,
// workgroup-local memory
void local_compare_and_swap(ivec2 idx){
    if (local_value[idx.x] < local_value[idx.y]) {
        uint tmp = local_value[idx.x];
        local_value[idx.x] = local_value[idx.y];
        local_value[idx.y] = tmp;
    }
}

// Performs full-height flip (h height) over locally available indices.
void local_flip(in uint h){
    uint t = gl_LocalInvocationID.x;
    barrier();

    uint half_h = h >> 1; // Note: h >> 1 is equivalent to h / 2
    ivec2 indices =
        ivec2( h * ( ( 2 * t ) / h ) ) +
        ivec2( t % half_h, h - 1 - ( t % half_h ) );

    local_compare_and_swap(indices);
}

// Performs progressively diminishing disperse operations (starting with height h)
// on locally available indices: e.g. h==8 -> 8 : 4 : 2.
// One disperse operation for every time we can halve h.
void local_disperse(in uint h){
    uint t = gl_LocalInvocationID.x;
    for ( ; h > 1 ; h /= 2 ) {

        barrier();

        uint half_h = h >> 1; // Note: h >> 1 is equivalent to h / 2
        ivec2 indices =
            ivec2( h * ( ( 2 * t ) / h ) ) +
            ivec2( t % half_h, half_h + ( t % half_h ) );

        local_compare_and_swap(indices);
    }
}

// Perform bitonic merge sort for local elements, up to a maximum number
// of elements h.
void local_bms(uint h){
    for ( uint hh = 2; hh <= h; hh <<= 1 ) { // note: hh <<= 1 is same as hh *= 2
        local_flip( hh );
        local_disperse( hh/2 );
    }
}

void main(){
    // this shader can be called in four different modes:
    // 1. local flip+disperse (up to n == local_size_x * 2)
    // 2. big flip
    // 3. big disperse
    // 4. local disperse

    uint t = gl_LocalInvocationID.x;

    // Calculate global offset for local workgroup
    uint offset = gl_WorkGroupSize.x * 2 * gl_WorkGroupID.x;

    if (parameters.algorithm <= eLocalDisperse){
        // pull to local memory
        // Each local worker must save two elements to local memory, as there
        // are twice as many elements as workers.
        local_value[t*2]   = value[offset+t*2];
        local_value[t*2+1] = value[offset+t*2+1];
    }

    switch (parameters.algorithm){
        case eLocalBms:
            local_bms(parameters.h);
            break;
        case eLocalDisperse:
            local_disperse(parameters.h);
            break;
        case eBigFlip:
            big_flip(parameters.n, parameters.h);
            break;
        case eBigDisperse:
            big_disperse(parameters.n, parameters.h);
            break;
    }

    // Write local memory back to buffer
    if (parameters.algorithm <= eLocalDisperse){
        barrier();
        // push to global memory
        value[offset+t*2]   = local_value[t*2];
        value[offset+t*2+1] = local_value[t*2+1];
    }
}
```

## Find a working example of this code in Island's examples directory

If you're interested in a working example project implementing this algorithm using the Island renderer, I recommend you take a look at Island's bitonic merge sort example which comes bundled with Island, on GitHub.

## Summary

This post describes how to implement bitonic merge sort in parallel GPU code, and the required CPU code to host and organise the dispatch of compute kernels. It was written to deepen my understanding, and to please a reader casually interested in GPU programming. If you've made it this far, thank you, you have been most indulgent:)

## What I learned so far

Implementing an algorithm on the GPU can be daunting, especially as limited debugging tools made it challenging to test whether the indices I calculated, for example, were correct. Writing a little C program which would generate and visualize indices proved invaluable, as it gave me a reliable method to test my code without having to run GPU debuggers, and ponder endless tables of numbers.

Speaking of which, I recommend generating visual debug outputs if you can: A graphical representation of your data is about 100x as readable as a tabular or textual representation, and it's much easier to discover bugs and patterns even at the first glance.

Writing this article helped me to deepen my understanding, and the process of having to explain the index calculations led me to simplify some rules which I was overthinking in the first working draft. It seems that conceptual elegance comes more easily when, after time, sharp edges are worn down by experience.

I've implemented this algorithm while participating in the Winter2'21 batch at the Recurse Center. Somebody there, I wish I could remember who, said a sentence which resonated quite deeply with me, and which I tried to apply whenever I got stuck implementing this algorithm:

> When in doubt, write code.

## Thank You

Special thanks go to fellow Recurser Julia Evans who gave me feedback on an earlier version of this post.