Merge Sort in Scala

Following my last post on Merge Sort, I decided to try out a possible implementation in Scala. We’ll also see a parameterized implementation that can sort lists of any type. My previous post aimed at explaining the algorithm and, for simplicity, used a concrete example with lists of integers. Let’s extend that.

As a first step, let’s start simple and analyse a possible implementation for lists of integers only.

def msort(xs: List[Int]): List[Int] = {
  // Merge two already-sorted lists into a single sorted list.
  def merge(xs: List[Int], ys: List[Int]): List[Int] = (xs, ys) match {
    case (Nil, ys) => ys
    case (xs, Nil) => xs
    case (x :: xs1, y :: ys1) =>
      if (x < y) x :: merge(xs1, ys)
      else y :: merge(xs, ys1)
  }

  val n = xs.length / 2
  if (n == 0) xs // zero or one element: already sorted
  else {
    val (fst, snd) = xs splitAt n
    merge(msort(fst), msort(snd))
  }
}

Let’s go through this. The first piece of code is the definition of the merge function, which combines two already-sorted lists. It pattern matches on the pair of input lists to distinguish three possible cases:

  • the first list is empty so the merge result is the second one;
  • the second list is empty so the merge result is the first one;
  • both lists have elements, so we compare their first elements: the smaller one becomes the head of the result list, and we proceed recursively with the remaining elements.
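To see the three cases in action, here is a small sketch that lifts merge out of msort so it can be called on its own. The standalone definition and the sample lists are only for illustration; they assume both inputs are already sorted, as they would be inside msort.

```scala
// Standalone copy of the merge step, for experimenting (illustrative only).
def merge(xs: List[Int], ys: List[Int]): List[Int] = (xs, ys) match {
  case (Nil, _) => ys // first list exhausted
  case (_, Nil) => xs // second list exhausted
  case (x :: xs1, y :: ys1) =>
    if (x < y) x :: merge(xs1, ys) // x is the smaller head
    else y :: merge(xs, ys1)       // y is smaller (or equal)
}

// Both inputs must already be sorted.
val merged = merge(List(1, 5, 23), List(-4, 2))
// merged == List(-4, 1, 2, 5, 23)
```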

msort then computes the midpoint of our list. If the result is 0, we have either an empty list or a list with only one element, and in that case the list is already sorted. Otherwise we split the list at that position, recursively sort each half, and merge the two sorted halves.
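The splitting step can be tried in isolation. This is just a sketch using the standard List methods length and splitAt on a sample list; the value names emptyMid and singleMid are made up for the example.

```scala
val nums = List(2, 5, 23, 1, -4)

val n = nums.length / 2            // integer division: 5 / 2 == 2
val (fst, snd) = nums.splitAt(n)   // fst == List(2, 5), snd == List(23, 1, -4)

// The base case: empty and single-element lists both give a midpoint of 0.
val emptyMid = List.empty[Int].length / 2  // 0
val singleMid = List(7).length / 2         // 0
```

Because the first half always has at least one element whenever n is greater than 0, both recursive calls work on strictly shorter lists, so the recursion terminates.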

Next step, we want to make this a little bit more generic. Let’s see how we would go about that.

def msort[T](xs: List[T])(lt: (T, T) => Boolean): List[T] = {
  def merge(xs: List[T], ys: List[T]): List[T] = (xs, ys) match {
    case (Nil, ys) => ys
    case (xs, Nil) => xs
    case (x::xs1, y::ys1) =>
      if (lt(x, y)) x :: merge(xs1, ys)
      else y :: merge(xs, ys1)
  }
  		
  val n = xs.length / 2
  if (n == 0)
    xs
  else {
    val (fst, snd) = xs splitAt n
    merge(msort(fst)(lt), msort(snd)(lt))
  }
}

As we can see, the code is fairly similar to the previous one. We replaced the Int type with a type parameter T. The most “strange” part of it is the (lt: (T, T) => Boolean) parameter. So, what does that do? In the first example, we found out whether one element is smaller than another using the built-in < operator on integers. Since an arbitrary type T has no such operator, we now need to supply the comparison function ourselves. A call to our function would look something like this:

val nums = List(2, 5, 23, 1, -4)
msort(nums)((x: Int, y: Int) => x < y)
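Since the comparison is now a parameter, the same function can sort in other orders or on other element types. The sketch below repeats the generic msort so it runs on its own; the sample lists and value names are made up for illustration. Because lt sits in a second parameter list, the compiler should already know T from the first one, so the lambda's parameter types can be left out.

```scala
def msort[T](xs: List[T])(lt: (T, T) => Boolean): List[T] = {
  def merge(xs: List[T], ys: List[T]): List[T] = (xs, ys) match {
    case (Nil, _) => ys
    case (_, Nil) => xs
    case (x :: xs1, y :: ys1) =>
      if (lt(x, y)) x :: merge(xs1, ys)
      else y :: merge(xs, ys1)
  }

  val n = xs.length / 2
  if (n == 0) xs
  else {
    val (fst, snd) = xs splitAt n
    merge(msort(fst)(lt), msort(snd)(lt))
  }
}

// Descending order: just flip the comparison.
val descending = msort(List(2, 5, 23, 1, -4))((x, y) => x > y)
// descending == List(23, 5, 2, 1, -4)

// Strings ordered by length rather than alphabetically.
val byLength = msort(List("pear", "fig", "banana"))((a, b) => a.length < b.length)
// byLength == List("fig", "pear", "banana")
```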

Can we go a little bit further? Having to pass around those “ugly” compare functions is rather boring. Let’s take another step.

import math.Ordering

def msort[T](xs: List[T])(implicit ord: Ordering[T]): List[T] = {
  def merge(xs: List[T], ys: List[T]): List[T] = (xs, ys) match {
    case (Nil, ys) => ys
    case (xs, Nil) => xs
    case (x::xs1, y::ys1) =>
      if (ord.lt(x, y)) x :: merge(xs1, ys)
      else y :: merge(xs, ys1)
  }
  		
  val n = xs.length / 2
  if (n == 0)
    xs
  else {
    val (fst, snd) = xs splitAt n
    merge(msort(fst), msort(snd))
  }
}

math.Ordering “defines many implicit objects to deal with subtypes of AnyVal (e.g. Int, Double), String, and others”. We can leverage this by using its lt (less than) method to compare elements. By declaring ord as an implicit parameter, we’re asking the compiler to find a suitable Ordering[T] for the element type and pass it in for us, including in the recursive calls. We can now call our function like this:

val nums = List(2, 5, 23, 1, -4)
msort(nums)

Just like we wanted. Simpler and less cumbersome. No need to pass that lt function any more.
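The implicit version still leaves room for custom orderings. As a rough sketch, the block below repeats msort so it is self-contained and then uses it with a hypothetical Person class, an implicit Ordering built with Ordering.by, and an explicitly passed reversed ordering. Person, byAge and the sample data are invented for this example.

```scala
import math.Ordering

def msort[T](xs: List[T])(implicit ord: Ordering[T]): List[T] = {
  def merge(xs: List[T], ys: List[T]): List[T] = (xs, ys) match {
    case (Nil, _) => ys
    case (_, Nil) => xs
    case (x :: xs1, y :: ys1) =>
      if (ord.lt(x, y)) x :: merge(xs1, ys)
      else y :: merge(xs, ys1)
  }

  val n = xs.length / 2
  if (n == 0) xs
  else {
    val (fst, snd) = xs splitAt n
    merge(msort(fst), msort(snd))
  }
}

// Built-in orderings are found automatically.
val words = msort(List("pear", "fig", "banana"))
// words == List("banana", "fig", "pear")

// Hypothetical type with no built-in ordering: we provide one ourselves.
case class Person(name: String, age: Int)
implicit val byAge: Ordering[Person] = Ordering.by((p: Person) => p.age)

val people = List(Person("Ana", 31), Person("Bo", 25), Person("Cy", 40))
val sortedNames = msort(people).map(_.name)
// sortedNames == List("Bo", "Ana", "Cy")

// The implicit argument can also be passed explicitly, e.g. to sort descending.
val descending = msort(List(2, 5, 23, 1, -4))(Ordering[Int].reverse)
// descending == List(23, 5, 2, 1, -4)
```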
