Fork/Join Framework & Merge Sort (Java)

We all know that Merge Sort is one of the fastest comparison-based sorting algorithms out there: it runs in O(n log n) time and is based on the divide-and-conquer technique. In today’s world of computing, computers usually have more than one CPU core, and to take advantage of all of those resources, we need to implement Merge Sort in a way that leverages all of those juicy cores.

Today I am going to show you how to implement Merge Sort using Java’s Fork/Join framework. But before we start, let’s get some background on the Fork/Join framework in Java.

So what is the Fork/Join framework?

This framework is built on three pieces: the ForkJoinPool class, which is a special kind of executor; two operations, the fork() and join() methods (and their different variants); and an internal algorithm named the work-stealing algorithm, which determines which task gets executed next. When a task is waiting for a child task to finish using the join() method, the thread that is executing that task does not just sit idle: it steals another task from the pool of waiting tasks and starts executing it. There are drawbacks, though. For example, the framework is a poor fit for tasks that block on I/O, because a worker thread that is blocked cannot run or steal other tasks.
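To make the pool and work stealing a bit more concrete, here is a minimal sketch that fills an array in parallel with a RecursiveAction. The names FillDemo, FillTask and THRESHOLD, and the threshold value, are hypothetical choices for illustration; ForkJoinPool, RecursiveAction, invokeAll() and getStealCount() are standard java.util.concurrent APIs. The steal count is only an estimate and will vary from run to run (it may even be 0 on a single-core machine).

```java
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveAction;

// Hypothetical demo: fill an array with i * i in parallel on a ForkJoinPool.
class FillDemo {
    // Hypothetical cut-off below which a task does its work directly.
    static final int THRESHOLD = 1_000;

    static class FillTask extends RecursiveAction {
        private final long[] out;
        private final int start, end; // half-open range [start, end)

        FillTask(long[] out, int start, int end) {
            this.out = out;
            this.start = start;
            this.end = end;
        }

        @Override
        protected void compute() {
            if (end - start <= THRESHOLD) {
                for (int i = start; i < end; i++) {
                    out[i] = (long) i * i;
                }
            } else {
                int middle = (start + end) >>> 1;
                // invokeAll forks the subtasks and waits for both; meanwhile idle
                // workers may steal the forked subtasks from this worker's queue.
                invokeAll(new FillTask(out, start, middle), new FillTask(out, middle, end));
            }
        }
    }

    static long[] fillSquares(int n, ForkJoinPool pool) {
        long[] out = new long[n];
        pool.invoke(new FillTask(out, 0, n));
        return out;
    }

    public static void main(String[] args) {
        ForkJoinPool pool = new ForkJoinPool(); // parallelism = available processors
        long[] squares = fillSquares(1_000_000, pool);
        System.out.println("last = " + squares[squares.length - 1]);
        // Estimated number of tasks stolen between workers; varies per run.
        System.out.println("steals = " + pool.getStealCount());
        pool.shutdown();
    }
}
```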

From Java 8 onward, the mighty Stream API also relies on this framework behind the scenes: parallel streams are executed on the Fork/Join common pool (ForkJoinPool.commonPool()).
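A quick way to check this, assuming a default JVM configuration, is to record which threads a parallel stream runs on. The class and method names below are hypothetical; LongStream, parallel(), peek() and ForkJoinPool.commonPool() are standard APIs.

```java
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ForkJoinPool;
import java.util.stream.LongStream;

// Hypothetical demo: which threads execute a parallel stream?
class ParallelStreamDemo {
    // Sums 1..n in parallel and records the name of every thread that did work.
    static long sumAndRecordThreads(long n, Set<String> threadNames) {
        return LongStream.rangeClosed(1, n)
                .parallel()
                .peek(x -> threadNames.add(Thread.currentThread().getName()))
                .sum();
    }

    public static void main(String[] args) {
        Set<String> names = ConcurrentHashMap.newKeySet(); // thread-safe set
        long sum = sumAndRecordThreads(1_000_000, names);
        System.out.println("sum = " + sum);
        System.out.println("common pool parallelism = "
                + ForkJoinPool.commonPool().getParallelism());
        System.out.println("threads = " + names);
    }
}
```

The sum is always 500000500000. The set of thread names depends on the machine, but it typically contains the calling thread (main) together with a few ForkJoinPool.commonPool-worker threads, since the caller takes part in the work.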

if (problem.size() > DEFAULT_SIZE) {
    childTask1 = new Task();
    childTask2 = new Task();
    childTask1.fork();
    childTask2.fork();
    childTaskResults1 = childTask1.join();
    childTaskResults2 = childTask2.join();
    taskResults = makeResults(childTaskResults1, childTaskResults2);
    return taskResults;
} else {
    taskResults = solveBasicProblem();
    return taskResults;
}

The snippet above shows the general shape of a Fork/Join task. It’s very similar to a normal recursive function: we have a base case, where the data is small enough for us to solve directly; otherwise, we split the data into two equal parts and process each part recursively in a child task.
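Here is a runnable instance of that template, assuming the problem is summing a long[] array. SumTask and the DEFAULT_SIZE value are hypothetical; the variable names mirror the pseudocode, and RecursiveTask is used because each task returns a result.

```java
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;

// Hypothetical concrete instance of the template: summing a long[] range.
class SumTask extends RecursiveTask<Long> {
    // Plays the role of DEFAULT_SIZE in the template; value chosen arbitrarily.
    static final int DEFAULT_SIZE = 10_000;

    private final long[] data;
    private final int start, end; // half-open range [start, end)

    SumTask(long[] data, int start, int end) {
        this.data = data;
        this.start = start;
        this.end = end;
    }

    @Override
    protected Long compute() {
        if (end - start > DEFAULT_SIZE) {
            int middle = (start + end) >>> 1;
            SumTask childTask1 = new SumTask(data, start, middle);
            SumTask childTask2 = new SumTask(data, middle, end);
            childTask1.fork();
            childTask2.fork();
            long childTaskResults1 = childTask1.join();
            long childTaskResults2 = childTask2.join();
            return childTaskResults1 + childTaskResults2; // makeResults(...)
        } else {
            long sum = 0; // solveBasicProblem(): sum the range directly
            for (int i = start; i < end; i++) {
                sum += data[i];
            }
            return sum;
        }
    }

    static long sum(long[] data) {
        return ForkJoinPool.commonPool().invoke(new SumTask(data, 0, data.length));
    }

    public static void main(String[] args) {
        long[] data = new long[1_000_000];
        for (int i = 0; i < data.length; i++) {
            data[i] = i + 1;
        }
        System.out.println(sum(data)); // 500000500000
    }
}
```

A common refinement is to fork only one child and call compute() on the other directly, so the current thread does useful work instead of just waiting; the sketch keeps the two-fork shape of the template above.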

Similarly, the Merge Sort algorithm stops recursing when the data that needs to be processed is small enough (usually just 1 element).

Here’s one way you can implement a normal (serial) Merge Sort.

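public class SerialMergeSort {

	// Sorts data[start, end) in place; end is exclusive.
	public void mergeSort(Comparable[] data, int start, int end) {
		if (end - start > 1) { // ranges of 0 or 1 element are already sorted
			int middle = (end + start) >>> 1; // unsigned shift avoids int overflow
			mergeSort(data, start, middle);
			mergeSort(data, middle, end);
			merge(data, start, middle, end);
		}
	}

	// Merges the sorted halves data[start, middle) and data[middle, end).
	private void merge(Comparable[] data, int start, int middle, int end) {
		int length = end - start;
		Comparable[] tmp = new Comparable[length];

		int i, j, index;
		i = start;
		j = middle;
		index = 0;

		// Take the smaller head element; "<=" keeps equal elements in order (stable).
		while ((i < middle) && (j < end)) {
			if (data[i].compareTo(data[j]) <= 0) {
				tmp[index] = data[i++];
			} else {
				tmp[index] = data[j++];
			}
			index++;
		}

		// Copy whatever is left of either half.
		while (i < middle) {
			tmp[index++] = data[i++];
		}

		while (j < end) {
			tmp[index++] = data[j++];
		}

		// Copy the merged result back into the original range.
		for (index = 0; index < length; index++) {
			data[index + start] = tmp[index];
		}
	}
}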

With the Fork/Join framework, most of the time you will choose between RecursiveAction, which does not return anything, and RecursiveTask, a generic abstract class whose compute() returns a result when it’s done. With Merge Sort, we need to do an extra processing step when both child tasks are done (merging the two sorted halves). Instead of blocking in join() until the children finish, our class will extend the CountedCompleter class, which runs a completion action automatically once all of its pending child tasks have completed. It has some important methods that we’re going to use.

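  • compute(): this is where the body of the Merge Sort algorithm is implemented.
  • addToPendingCount(1): adds 1 to this task’s pending count. A task’s onCompletion() runs only when tryComplete() reaches it while the count is zero, so a task that forks two children sets the count to 1: the first child to finish decrements it to zero, and the second one triggers completion.
  • onCompletion(CountedCompleter caller): called once the task’s pending count has dropped to zero, i.e. when all of its child tasks have finished.
  • tryComplete(): if the pending count is nonzero, decrements it; otherwise runs onCompletion() and then tries to complete the parent task in the same way.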
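The pending-count rules are the trickiest part, so here is a small, hypothetical demo (not part of the merge sort) that builds a binary tree of CountedCompleter tasks. Every internal node forks two children with a pending count of 1, and onCompletion() counts how many internal nodes were completed. The Node class and run() helper are my own names; CountedCompleter and its methods are the standard API.

```java
import java.util.concurrent.CountedCompleter;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.atomic.AtomicInteger;

// Hypothetical demo of CountedCompleter's pending count with two children.
class CompleterDemo {
    static class Node extends CountedCompleter<Void> {
        final int depth;
        final AtomicInteger completions; // counts onCompletion calls of split nodes

        Node(CountedCompleter<?> parent, int depth, AtomicInteger completions) {
            super(parent); // parent becomes this task's completer
            this.depth = depth;
            this.completions = completions;
        }

        @Override
        public void compute() {
            if (depth > 0) {
                // Two children but pending count 1: the first child's tryComplete()
                // decrements it to 0, the second finds 0 and completes this node.
                addToPendingCount(1);
                new Node(this, depth - 1, completions).fork();
                new Node(this, depth - 1, completions).fork();
            } else {
                tryComplete(); // leaf: no children, complete immediately
            }
        }

        @Override
        public void onCompletion(CountedCompleter<?> caller) {
            if (depth > 0) {
                completions.incrementAndGet();
            }
        }
    }

    static int run(int depth) {
        AtomicInteger completions = new AtomicInteger();
        ForkJoinPool.commonPool().invoke(new Node(null, depth, completions));
        return completions.get();
    }

    public static void main(String[] args) {
        // A full binary tree of depth 3 has 2^3 - 1 = 7 internal nodes.
        System.out.println(run(3)); // 7
    }
}
```

Each internal node’s onCompletion() runs exactly once, after both of its children are done, which is exactly the hook the merge step needs.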

Here is the most important part of the compute() method:

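	@Override
	public void compute() {
		if (end - start >= 1024) {
			// Still too big: split into two halves, each handled by a child task.
			middle = (end + start) >>> 1;
			MergeSortTask task1 = new MergeSortTask(data, start, middle, this);
			MergeSortTask task2 = new MergeSortTask(data, middle, end, this);
			// Two children -> pending count 1; the second child to finish triggers onCompletion().
			addToPendingCount(1);
			task1.fork();
			task2.fork();
		} else {
			// Small enough: sort serially, then signal completion to the parent.
			new SerialMergeSort().mergeSort(data, start, end);
			tryComplete();
		}
	}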

Let’s explain the code snippet above.

As you can see, it’s very similar to the serial Merge Sort, and we use SerialMergeSort for the base case. We also pass a reference to the parent task (this) to each child; it becomes the child’s completer, which is how completion propagates upward. tryComplete() internally calls onCompletion() once the pending count is zero, and when onCompletion() is finished, it tries to complete the parent task in the same way.

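If the data is still too big (1024 elements or more), we split it into two child MergeSortTasks and call addToPendingCount(1) so that this task’s onCompletion() runs only after both children have completed. Then we call fork() on task1 and task2 to schedule them for asynchronous execution. Note that the splitting branch does not call tryComplete() itself: the task is completed later by its children.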

Here’s the onCompletion() method, which does the merging:

	@Override
	public void onCompletion(CountedCompleter<?> caller) {
		// middle is only set when compute() split the range, so a leaf task
		// (already sorted by SerialMergeSort) has nothing to merge.
		if (middle == 0) {
			return;
		}
		// Merge the sorted halves data[start, middle) and data[middle, end).
		int length = end - start;
		Comparable[] tmp = new Comparable[length];

		int i, j, index;
		i = start;
		j = middle;
		index = 0;

		while ((i < middle) && (j < end)) {
			if (data[i].compareTo(data[j]) <= 0) {
				tmp[index] = data[i++];
			} else {
				tmp[index] = data[j++];
			}
			index++;
		}

		while (i < middle) {
			tmp[index++] = data[i++];
		}

		while (j < end) {
			tmp[index++] = data[j++];
		}

		for (index = 0; index < length; index++) {
			data[index + start] = tmp[index];
		}
	}

It’s the same logic as the merge() method of SerialMergeSort.
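For reference, here is a sketch of a complete, runnable MergeSortTask assembled from the two snippets above. The full source linked in this post may differ: the constructor signature, the static sort() helper and main() are my assumptions, and Arrays.sort(data, start, end) stands in for SerialMergeSort so that the block compiles on its own.

```java
import java.util.Arrays;
import java.util.Random;
import java.util.concurrent.CountedCompleter;
import java.util.concurrent.ForkJoinPool;

// Sketch of a full MergeSortTask; constructor, sort() and main() are assumptions.
@SuppressWarnings({"rawtypes", "unchecked"})
class MergeSortTask extends CountedCompleter<Void> {
    private final Comparable[] data;
    private final int start, end; // half-open range [start, end)
    private int middle;           // stays 0 for leaf tasks (nothing to merge)

    MergeSortTask(Comparable[] data, int start, int end, MergeSortTask parent) {
        super(parent); // parent is the completer notified by tryComplete()
        this.data = data;
        this.start = start;
        this.end = end;
    }

    @Override
    public void compute() {
        if (end - start >= 1024) {
            middle = (end + start) >>> 1;
            MergeSortTask task1 = new MergeSortTask(data, start, middle, this);
            MergeSortTask task2 = new MergeSortTask(data, middle, end, this);
            addToPendingCount(1); // two children -> pending count 1
            task1.fork();
            task2.fork();
        } else {
            Arrays.sort(data, start, end); // stand-in for SerialMergeSort
            tryComplete();
        }
    }

    @Override
    public void onCompletion(CountedCompleter<?> caller) {
        if (middle == 0) {
            return; // leaf task: already sorted in compute()
        }
        // Merge the sorted halves [start, middle) and [middle, end).
        Comparable[] tmp = new Comparable[end - start];
        int i = start, j = middle, index = 0;
        while (i < middle && j < end) {
            tmp[index++] = data[i].compareTo(data[j]) <= 0 ? data[i++] : data[j++];
        }
        while (i < middle) tmp[index++] = data[i++];
        while (j < end) tmp[index++] = data[j++];
        System.arraycopy(tmp, 0, data, start, tmp.length);
    }

    // Root task has no parent; invoke() returns once the whole tree has completed.
    static void sort(Comparable[] data) {
        ForkJoinPool.commonPool().invoke(new MergeSortTask(data, 0, data.length, null));
    }

    public static void main(String[] args) {
        Integer[] numbers = new Random(42).ints(100_000).boxed().toArray(Integer[]::new);
        Integer[] expected = numbers.clone();
        Arrays.sort(expected);
        sort(numbers);
        System.out.println(Arrays.equals(numbers, expected)); // true
    }
}
```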

You can download the full source code here.

This is already too long for a daily blog post, and I’m not a fan of overly long blog posts anyway. After running a benchmark on my laptop (4 cores, 8 threads), the Fork/Join version was usually more than twice as fast as the serial one on average.