Introduction to Algorithms, 2022-3
Week 13: Notes

counting sort

In a lecture a few weeks ago, we saw that any comparison-based sorting algorithm must perform at least on the order of N log N comparisons in the worst case, so no such algorithm can run faster than O(N log N).

However there do exist some sorting algorithms that can run in linear time! These algorithms are not comparison-based, and so they are not as general as the other sorting algorithms we've studied in this course. In other words, they will work only on certain types of input data.

Counting sort is one linear-time sorting algorithm. In its simplest form, counting sort works on an input array of integers in the range 0 .. (M – 1) for some fixed integer M. The algorithm traverses the input array and builds a table that maps each integer k (0 ≤ k < M) to a value count[k] indicating how many times it occurred in the input. After that, it makes a second pass, iterating over the table of counts and writing each integer k back to the array exactly count[k] times, in increasing order of k.

For example, suppose that M = 5 and the input array is

2 4 0 1 2 1 4 4 2 0

The table of counts will look like this:

k         0  1  2  3  4
count[k]  2  2  3  0  3

So during the second pass the algorithm will write two 0s, two 1s, three 2s and so on. The result will be

0 0 1 1 2 2 2 4 4 4

Here is an implementation:

# Sort the array a, which contains integers in the range 0 .. (m - 1).
def counting_sort(a, m):
    count = m * [0]

    for x in a:
        count[x] += 1

    pos = 0
    for k in range(m):
        for i in range(count[k]):
            a[pos] = k
            pos += 1
    

Let's consider the running time of this function. All values in the input are known to be less than M. Let N be the length of the input array a. Then it will take time O(M) to allocate the array 'count', as well as to examine each of its values in the second pass. It will take time O(N) to read each value from the input array and to write new values into the array on the second pass. The total running time will be O(M + N). If we consider M to be a constant, this is O(N).
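To make the two passes concrete, here is a standalone trace of the algorithm on the example input from above:

```python
a = [2, 4, 0, 1, 2, 1, 4, 4, 2, 0]

# First pass: count how many times each value occurs.
count = 5 * [0]
for x in a:
    count[x] += 1
print(count)   # [2, 2, 3, 0, 3]

# Second pass: write each value k exactly count[k] times.
out = []
for k in range(5):
    out += count[k] * [k]
print(out)     # [0, 0, 1, 1, 2, 2, 2, 4, 4, 4]
```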

stable sorting

In earlier weeks we saw that we can modify any sorting algorithm so that it sorts using a key function, which maps each element to a key used for sorting. For example, suppose that we would like to sort this array of strings by length:

>>> a = ['town', 'grape', 'prague', 'two', 'night', 'day', 'one', 'lift']

Then the key function maps each string to its length. Python's built-in sort can use a key function:

>>> sorted(a, key = len)
['two', 'day', 'one', 'town', 'lift', 'grape', 'night', 'prague']

We see that the output array contains all strings of length 3, then strings of length 4, and so on. Furthermore, for any length k, the strings of length k appear in the same order as in the input array. That is guaranteed because Python's built-in sort is stable. A stable sort preserves the order of values with the same key.

Some sorting algorithms are naturally stable, and others are not. For example, bubblesort is stable because it will never swap two values with the same key, so their order will be preserved throughout the sort. Of the other algorithms that we've seen in this course, insertion sort and merge sort are stable; selection sort, quicksort and heap sort are not. You may want to think about why.
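For example, here is a minimal selection sort applied to a small made-up list of pairs, sorted by their first component. The long-distance swap in the first step moves one value with key 2 past another, so equal keys no longer appear in their original order; this is why selection sort is not stable:

```python
# A minimal selection sort, sorting pairs by their first component.
def selection_sort(a):
    for i in range(len(a)):
        # Find the smallest remaining key and swap it into position i.
        m = i
        for j in range(i + 1, len(a)):
            if a[j][0] < a[m][0]:
                m = j
        a[i], a[m] = a[m], a[i]

pairs = [(2, 'a'), (2, 'b'), (1, 'c')]
selection_sort(pairs)
# The first swap moves (2, 'a') past (2, 'b').
print(pairs)   # [(1, 'c'), (2, 'b'), (2, 'a')]
```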

counting sort with a key

Let's modify the counting sort algorithm that we saw above so that it uses a key function. The input to the algorithm will be an array a, plus a key function f that maps each element of a to an integer in the range 0 .. (M – 1) for some fixed integer M. We'd like to produce an array containing all the values in a, sorted by key. As an additional requirement, we would like this sort to be stable.

As one possible approach, we can allocate an array of M buckets, each of which initially contains the empty list. We can iterate over the input array; for each value x, we append x to the list in bucket f(x). After that, we need only concatenate the lists in all the buckets. Here is an implementation in Python:

# Sort the array a using a key function that maps values to integers
# in the range 0 .. m - 1.  Return a new array with the sorted values.
def counting_sort_by(a, m, f):
    vals = [ [] for _ in range(m) ]
    for x in a:
        vals[f(x)].append(x)
    
    a = []
    for vs in vals:
        a += vs
    return a

Assume that the input array a has N elements, and that the key function f runs in constant time. How long will this function take to run, as a function of N? Allocating the buckets takes O(1), since there are M buckets and M is a constant. The first 'for' loop above will run in O(N) since Python's append() method runs in O(1) on average. The second 'for' loop will also run in O(N), since it appends every element exactly once. The total running time is O(N). The sort is stable, since each bucket contains values in the same order in which they appeared in the original array.
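To check this implementation, we can sort the array of strings from the previous section by length (the function is repeated here so the snippet runs on its own):

```python
# counting_sort_by as defined above, repeated so this example is standalone.
def counting_sort_by(a, m, f):
    vals = [[] for _ in range(m)]
    for x in a:
        vals[f(x)].append(x)
    out = []
    for vs in vals:
        out += vs
    return out

a = ['town', 'grape', 'prague', 'two', 'night', 'day', 'one', 'lift']
b = counting_sort_by(a, 10, len)
print(b)
# ['two', 'day', 'one', 'town', 'lift', 'grape', 'night', 'prague']
```

The output matches the result of Python's built-in stable sort from the previous section.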

As a second possible approach, we can build a table 'count' that maps each key to its count, i.e. to the number of elements with that key. We can then adjust this table so that for each key k, count[k] is the first index at which elements with key k will be stored. For example, suppose that M = 4 and the table is initially [2, 5, 3, 7], meaning that there are 2 values with key 0, 5 values with key 1, and so on. After we adjust the table, it will have the values [0, 2, 7, 10], meaning that in the output array values with key 0 will begin at index 0, values with key 1 will begin at index 2, and so on. Each element in the adjusted table contains the sum of all the previous elements in the original table.
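The adjustment step can be checked directly on the example table:

```python
# Adjust the example count table [2, 5, 3, 7] into starting indices.
count = [2, 5, 3, 7]
total = 0
for i in range(len(count)):
    total, count[i] = total + count[i], total
print(count)   # [0, 2, 7, 10]
```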

Finally, we can allocate a new array b of the appropriate length, then iterate over a and copy each element x to b at the index count[f(x)]. We will also increment that index in the table, so that the next value with the same key will be inserted in a subsequent position.

An implementation in Python looks like this:

# Sort the array a using a key function that maps values to integers
# in the range 0 .. m - 1.  Return a new array with the sorted values.
def counting_sort_by(a, m, f):
    count = m * [0]
    for x in a:
        count[f(x)] += 1

    # Compute prefix sums: count[i] becomes the first output index for key i.
    total = 0
    for i in range(m):
        total, count[i] = total + count[i], total

    b = len(a) * [None]
    for x in a:
        k = f(x)
        b[count[k]] = x
        count[k] += 1
    
    return b

This implementation will also run in O(N), and also produces a stable sort.
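As a quick check that the sort is stable, here is the function applied to a small made-up list of pairs, sorted by their first component (the function is repeated so the snippet runs on its own):

```python
# The second counting_sort_by from above, repeated so this is standalone.
def counting_sort_by(a, m, f):
    count = m * [0]
    for x in a:
        count[f(x)] += 1
    # Prefix sums: count[i] becomes the first output index for key i.
    total = 0
    for i in range(m):
        total, count[i] = total + count[i], total
    b = len(a) * [None]
    for x in a:
        k = f(x)
        b[count[k]] = x
        count[k] += 1
    return b

pairs = [(2, 'a'), (0, 'b'), (2, 'c'), (1, 'd'), (0, 'e')]
b = counting_sort_by(pairs, 3, lambda p: p[0])
print(b)   # [(0, 'b'), (0, 'e'), (1, 'd'), (2, 'a'), (2, 'c')]
```

Pairs with equal keys, such as (2, 'a') and (2, 'c'), keep their original relative order.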

In this section we have seen two implementations of counting sort with a key. The first is simpler, though it will probably use a bit more memory due to the overhead of storing a dynamic array in each bucket.

radix sort

When we use a counting sort to sort integers in the range 0 .. M – 1, the sort will allocate an array with M elements. That is reasonable as long as M is not too large. But what if M is a large number such as 1,000,000,000 or 1,000,000,000,000? We probably cannot allocate that much memory.

Radix sort is an algorithm that can sort an array of integers in the range 0 .. (M – 1) for some fixed integer M. Like counting sort, it runs in time O(N), where N is the number of integers in the input array. However, it will generally use far less memory than a counting sort, and is practical even for large values of M such as M = 1,000,000,000,000.

Radix sort works by performing a series of passes. In each pass, it uses a stable counting sort to sort the input numbers by a certain digit. In the first pass, this is the last digit of the input numbers. In the second pass, it is the second-to-last digit, and so on. After all passes are complete, the array will be sorted.

For example, consider the operation of radix sort on these numbers:

503 223 652 251 521 602

In the first pass, we sort by the last digit. The sort is stable, so if numbers have the same digit, their order will be preserved:

251 521 652 602 503 223

Next we sort by the middle digit:

602 503 521 223 251 652

Finally, we sort by the first digit:

223 251 503 521 602 652

The sort is complete, and the numbers are in order.

Why does radix sort work? After the first pass, the numbers are sorted by their last digit. In the second pass, we sort numbers by the second-to-last digit. Ties will be resolved using the previous ordering (since counting sort is stable), which is by the last digit, so after the second pass all numbers will be sorted by their last two digits. After another pass, the numbers will be sorted by their last three digits, and so on.

Here is an implementation in Python:

# Sort an array a of numbers that have up to k decimal digits.
def radix_sort(a, k):
    p = 1
    for i in range(k):
        a = counting_sort_by(a, 10, lambda x: (x // p) % 10)
        p *= 10
    return a

If we consider the number of digits k to be a constant, then the algorithm makes a constant number of passes. Each pass runs in O(N), so the total running time is k · O(N) = O(N).
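To check this on the worked example, here is a standalone sketch pairing radix_sort with the bucket version of counting_sort_by from earlier in these notes:

```python
# The bucket version of counting_sort_by, repeated so this is standalone.
def counting_sort_by(a, m, f):
    vals = [[] for _ in range(m)]
    for x in a:
        vals[f(x)].append(x)
    out = []
    for vs in vals:
        out += vs
    return out

# Sort an array a of numbers that have up to k decimal digits.
def radix_sort(a, k):
    p = 1
    for _ in range(k):
        a = counting_sort_by(a, 10, lambda x: (x // p) % 10)
        p *= 10
    return a

result = radix_sort([503, 223, 652, 251, 521, 602], 3)
print(result)   # [223, 251, 503, 521, 602, 652]
```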

The code above uses base 10, but of course the base is arbitrary. For efficiency, it may be better to use a base such as 1000, i.e. to treat each number as a series of digits in base 1000. Then we will need fewer passes; on the other hand, then each individual counting sort will need to allocate a larger array, and will take a bit longer since it needs to traverse that entire array.