Binary Heap

The code presented here is for a binary heap. It’s important to note that this isn’t a heap sort, just something that constructs a heap. You can sort with it (the Pop method will return each item in sorted order). This code could be used as the basis of a priority queue.

The important methods are Add, where we place the new item at the bottom of the heap and move it up to its proper position, and Remove, where we replace the removed item with the last item in the heap (reducing the heap size by one in the process) and then move that item down to its new position.

The backing storage for this is a simple array which allows us to store a binary tree without having to maintain pointers to the child or parent nodes.

/// <summary>
/// Array-backed binary heap. By default the largest item (per <see cref="IComparable.CompareTo"/>)
/// sits at the top; pass <c>minHeap = true</c> to keep the smallest item at the top instead.
/// </summary>
/// <remarks>
/// The tree is stored implicitly: the children of the item at index <c>i</c> live at
/// <c>2i + 1</c> and <c>2i + 2</c>, and its parent at <c>(i - 1) / 2</c>. Only the first
/// <see cref="Count"/> slots of the backing array hold live items.
/// Null items are allowed and order before every non-null item.
/// </remarks>
/// <typeparam name="T">Item type; ordering comes from its <see cref="IComparable"/> implementation.</typeparam>
public class BinaryHeap<T> : ICollection<T> where T : IComparable {
    private int capacity = 0;
    private int count = 0;
    private T[] items;
    private Boolean isMinHeap = false;

    #region Constructors
    /// <summary>Creates an empty max-heap with no pre-allocated storage.</summary>
    public BinaryHeap() {
        Construct(null, false, 0);
    }

    /// <summary>Creates an empty heap; <paramref name="minHeap"/> selects min- or max-ordering.</summary>
    public BinaryHeap(Boolean minHeap) {
        Construct(null, minHeap, 0);
    }

    /// <summary>Creates an empty max-heap with room for <paramref name="cap"/> items.</summary>
    public BinaryHeap(int cap) {
        Construct(null, false, cap);
    }

    /// <summary>Creates an empty heap with the given ordering and initial capacity.</summary>
    public BinaryHeap(Boolean minHeap, int cap) {
        Construct(null, minHeap, cap);
    }

    /// <summary>Creates a max-heap containing every item of <paramref name="enumerator"/>.</summary>
    public BinaryHeap(IEnumerable<T> enumerator) {
        Construct(enumerator, false, 0);
    }

    /// <summary>Creates a heap with the given ordering containing every item of <paramref name="enumerator"/>.</summary>
    public BinaryHeap(IEnumerable<T> enumerator, Boolean minHeap) {
        Construct(enumerator, minHeap, 0);
    }

    /// <summary>Creates a max-heap with the given initial capacity containing every item of <paramref name="enumerator"/>.</summary>
    public BinaryHeap(IEnumerable<T> enumerator, int cap) {
        Construct(enumerator, false, cap);
    }

    /// <summary>Creates a heap with the given ordering and initial capacity containing every item of <paramref name="enumerator"/>.</summary>
    public BinaryHeap(IEnumerable<T> enumerator, Boolean minHeap, int cap) {
        Construct(enumerator, minHeap, cap);
    }
    #endregion

    #region Properties
    /// <summary>Size of the backing array. Setting it reallocates the storage.</summary>
    /// <exception cref="ArgumentOutOfRangeException">The new value is smaller than <see cref="Count"/>.</exception>
    public int Capacity {
        get {
            return capacity;
        }

        set {
            // The one-argument constructor takes a parameter NAME, so pass name and message separately.
            if (value < count) throw new ArgumentOutOfRangeException("value", "Capacity is set to a value less than current.");
            capacity = value;
            Array.Resize(ref items, capacity);
        }
    }

    /// <summary>The item at the top of the heap (largest for a max-heap, smallest for a min-heap), without removing it.</summary>
    /// <exception cref="InvalidOperationException">The heap is empty.</exception>
    public T Top {
        get {
            if (count == 0) throw new InvalidOperationException("The BinaryHeap is empty");
            return items[0];
        }
    }
    #endregion

    #region Methods
    /// <summary>Removes and returns the item at the top of the heap.</summary>
    /// <exception cref="InvalidOperationException">The heap is empty.</exception>
    public T Pop() {
        if (count == 0) throw new InvalidOperationException("The BinaryHeap is empty");

        T value = items[0];
        RemoveAt(0);

        return value;
    }

    /// <summary>Shrinks the backing array so that <see cref="Capacity"/> equals <see cref="Count"/>.</summary>
    public void TrimExcess() {
        capacity = count;
        Array.Resize(ref items, capacity);
    }
    #endregion

    #region Private Methods
    // Shared constructor body: allocates the storage and adds any initial items one by one.
    private void Construct(IEnumerable<T> enumerator, Boolean minHeap, int cap) {
        if (cap < 0) throw new ArgumentOutOfRangeException("cap", "Capacity cannot be negative.");

        capacity = cap;
        isMinHeap = minHeap;
        items = new T[capacity];
        if (enumerator != null) {
            foreach (T item in enumerator) {
                this.Add(item);
            }
        }
    }

    private void Swap(int a, int b) {
        T temp = items[a];
        items[a] = items[b];
        items[b] = temp;
    }

    // Null-safe wrapper around CompareTo: null orders before any non-null value.
    private static int Compare(T a, T b) {
        if (a == null) return (b == null) ? 0 : -1;
        return a.CompareTo(b);
    }

    // True when a must sit above b in the tree for the configured ordering.
    private Boolean Precedes(T a, T b) {
        int order = Compare(a, b);
        return isMinHeap ? order < 0 : order > 0;
    }

    // Moves the item at pos up while it precedes its parent; returns the index where it settled.
    private int SiftUp(int pos) {
        while (pos > 0) {
            int parent = (pos - 1) / 2;
            if (!Precedes(items[pos], items[parent])) break;
            Swap(pos, parent);
            pos = parent;
        }
        return pos;
    }

    // Moves the item at pos down, swapping with whichever child should be on top, until neither child precedes it.
    private void SiftDown(int pos) {
        while (pos * 2 + 1 < count) {
            int child = pos * 2 + 1;
            int best = pos;
            if (Precedes(items[child], items[best])) {
                best = child;
            }
            if (child + 1 < count && Precedes(items[child + 1], items[best])) {
                best = child + 1;
            }
            if (best == pos) break;
            Swap(pos, best);
            pos = best;
        }
    }

    // Removes the item at index pos by moving the last item into its slot and restoring heap order.
    private void RemoveAt(int pos) {
        count--;
        T last = items[count];
        items[count] = default(T); // release the vacated slot so it neither leaks nor shows up in searches

        if (pos == count) return;  // the removed item was the last one; nothing to re-order

        items[pos] = last;

        // The moved item comes from a different subtree, so it may need to go UP as well as down.
        if (SiftUp(pos) == pos) {
            SiftDown(pos);
        }
    }
    #endregion

    #region Interfaces
    #region ICollection<T>
    /// <summary>Adds <paramref name="item"/> to the heap, doubling the storage when it is full.</summary>
    public void Add(T item) {
        if (count == capacity) {
            capacity = Math.Max(capacity * 2, 1);
            Array.Resize(ref items, capacity);
        }
        items[count] = item;
        count++;
        SiftUp(count - 1);
    }

    /// <summary>Removes every item. The capacity is left unchanged.</summary>
    public void Clear() {
        Array.Clear(items, 0, count); // drop references instead of just hiding them
        count = 0;
    }

    /// <summary>Returns true when the heap currently holds an item equal to <paramref name="item"/>.</summary>
    public bool Contains(T item) {
        // Only the first `count` slots are live; the rest of the array is spare capacity.
        return Array.IndexOf(items, item, 0, count) >= 0;
    }

    /// <summary>
    /// Copies the <see cref="Count"/> live items, in internal heap order (not sorted),
    /// into <paramref name="array"/> starting at <paramref name="arrayIndex"/>.
    /// </summary>
    public void CopyTo(T[] array, int arrayIndex) {
        Array.Copy(items, 0, array, arrayIndex, count);
    }

    /// <summary>Number of items currently in the heap.</summary>
    public int Count {
        get { return count; }
    }

    /// <summary>Always false; the heap can be modified.</summary>
    public bool IsReadOnly {
        get { return false; }
    }

    /// <summary>
    /// Removes the first item (in internal array order) that compares equal to
    /// <paramref name="item"/> via <see cref="IComparable.CompareTo"/>.
    /// </summary>
    /// <returns>True when an item was removed; false when no match was found.</returns>
    public bool Remove(T item) {
        for (int i = 0; i < count; i++) {
            if (Compare(items[i], item) == 0) {
                RemoveAt(i);
                return true;
            }
        }

        return false;
    }

    /// <summary>
    /// Enumerates the live items in internal heap order (not sorted). The enumerator reads the
    /// current backing array directly and does not detect later modifications of the heap.
    /// </summary>
    public IEnumerator<T> GetEnumerator() {
        return new BinaryHeapEnumerator(items, count);
    }

    System.Collections.IEnumerator System.Collections.IEnumerable.GetEnumerator() {
        return GetEnumerator();
    }
    #endregion
    #endregion

    #region Helper Classes
    /// <summary>Forward-only enumerator over the first <c>count</c> slots of an item array.</summary>
    public class BinaryHeapEnumerator : IEnumerator<T> {
        int position = -1;
        int limit;
        T[] array;

        /// <param name="items">Array to read from.</param>
        /// <param name="count">Number of leading slots of <paramref name="items"/> to enumerate.</param>
        public BinaryHeapEnumerator(T[] items, int count) {
            array = items;
            limit = count;
        }

        /// <summary>The item at the current position.</summary>
        /// <exception cref="InvalidOperationException">The enumerator is before the first or after the last item.</exception>
        public T Current {
            get {
                // Check against limit, not just the array bounds: slots past limit are spare capacity.
                if (position < 0 || position >= limit || position >= array.Length) {
                    throw new InvalidOperationException();
                }
                return array[position];
            }
        }

        public void Dispose() {
        }

        object System.Collections.IEnumerator.Current {
            get {
                return Current;
            }
        }

        /// <summary>Advances to the next item; returns false once the end has been passed.</summary>
        public bool MoveNext() {
            if (position < limit) {
                position++;
            }
            return position < limit;
        }

        /// <summary>Moves back to just before the first item.</summary>
        public void Reset() {
            position = -1;
        }
    }
    #endregion
}

Matrix Multiplication

One time consuming task is multiplying large matrices. In this post we’ll look at ways to improve the speed of this process. We’ll be using a square matrix, but with simple modifications the code can be adapted to any type of matrix.

The straightforward way to multiply two matrices is:

// ijk order over rectangular (multi-dimensional) arrays: for each cell C[i,j],
// accumulate the products A[i,k] * B[k,j] over k.
// Uses +=, so C is presumably zero-initialised before the loops run — confirm at the call site.
for (int i = 0; i < N; i++) {
    for (int j = 0; j < N; j++) {
        for (int k = 0; k < N; k++) {
            C[i,j] += A[i,k] * B[k,j];
        }
    }
}

The result doesn’t depend on the order in which the loops run (i, j, k), but the compiler might be able to optimize the code differently depending on that order, so let’s try some timings of all variations of orders (ijk, ikj, jik, jki, kij, kji) for the multiplication of two 500×500 matrices of doubles.

We’ll run each test 10 times to balance for any system calls that might happen during the test:

Order Milliseconds
ijk 2,361
ikj 2,000
jik 2,315
jki 5,775
kij 2,089
kji 5,665

It seems that the order ikj is slightly faster, while both jki and kji are significantly slower.

But what about using jagged arrays (that aren’t really jagged, since all rows will have same size)? They might give better performance. Here is the same code above but using jagged arrays:

// Same ijk accumulation, but with jagged arrays (arrays of row arrays):
// every element access first indexes the row, then the column within it.
// Uses +=, so C is presumably zero-initialised before the loops run — confirm at the call site.
for (int i = 0; i < N; i++) {
    for (int j = 0; j < N; j++) {
        for (int k = 0; k < N; k++) {
            C[i][j] += A[i][k] * B[k][j];
        }
    }
}

And the results:

Order Milliseconds
ijk 3,622
ikj 1,212
jik 3,641
jki 7,897
kij 1,244
kji 7,773

Well that is interesting. Most of the methods are slower, except for ikj and kij which are almost twice as fast.

This leads into the best part about using jagged arrays. With the current code every access to an element is a double index. First to the row, then to the column. By introducing some extra variables we can optimize this to a single index in each loop:

// ikj order over jagged arrays with the row lookups hoisted out of the inner loops,
// so the innermost statement indexes each array only once.
for (int i = 0; i < N; i++) {
    // Rows of A and C depend only on i; fetch them once per outer iteration.
    double[] iRowA = A[i];
    double[] iRowC = C[i];
    for (int k = 0; k < N; k++) {
        // Row k of B and the scalar A[i][k] are constant for the whole j loop.
        double[] kRowB = B[k];
        double ikA = iRowA[k];
        for (int j = 0; j < N; j++) {
            iRowC[j] += ikA * kRowB[j];
        }
    }
}

Running this code gives an execution time of 410 milliseconds. Now we are about 5 times faster than our original method.

Can we improve on this? Of course we can. With .NET 4.0 came the introduction of PLINQ, which gives us an easy way to perform parallel tasks. By taking one of the indexes out of the loop, we can use it as a parameter to the method and calculate multiple rows at a time. Our code to do this looks like this:

// Work for a single row index i (i is supplied from outside this fragment):
// accumulates row i of the product into C[i].
double[] iRowA = A[i];
double[] iRowC = C[i];
for (int k = 0; k < N; k++) {
    // Row k of B and the scalar A[i][k] are constant for the whole j loop.
    double[] kRowB = B[k];
    double ikA = iRowA[k];
    for (int j = 0; j < N; j++) {
        iRowC[j] += ikA * kRowB[j];
    }
}

All we have to do now is call this method N times, with i ranging from 0 to N − 1. PLINQ gives us an easy way to do this:

// One integer per row index: 0 .. N-1.
var source = Enumerable.Range(0, N);
// AsParallel() turns the sequence into a parallel query.
var pquery = from num in source.AsParallel() select num;
// ForAll invokes the delegate once per row index; no ordering between rows is requested.
pquery.ForAll((e) => Popt(A, B, C, e));

Where Popt is our method name taking 3 jagged arrays (C = A * B) and the row to calculate (e).
The execution time of this latest code is 187 milliseconds. That’s over 12 times faster than our original code! With the magic of PLINQ the 500 row calculations in this example are spread across a pool of worker threads (by default no more than one per processor core, rather than one thread per row) and we don’t have to manage a single one of them; everything is handled for you.

The final bit of code for this post is the entire multiplication method:

// requires a global value N which is the dimension of the array
// Array C must be fully allocated or you'll get a null reference exception

// Runs Popt once for every row index 0 .. N-1 as a PLINQ parallel query.
// Each call receives the same three arrays plus its own row index.
void Topt(double[][] A, double[][] B, double[][] C) {
    Enumerable.Range(0, N)
        .AsParallel()
        .ForAll(row => Popt(A, B, C, row));
}

// Accumulates row i of the product A * B into row i of C.
// Walks k in the middle loop and j in the inner loop, so the inner loop
// reads one row of B and writes one row of C front to back.
void Popt(double[][] A, double[][] B, double[][] C, int i) {
    double[] aRow = A[i];
    double[] cRow = C[i];
    for (int k = 0; k < N; k++) {
        double[] bRow = B[k];
        double factor = aRow[k];   // A[i][k], constant across the inner loop
        for (int j = 0; j < N; j++) {
            cRow[j] += factor * bRow[j];
        }
    }
}

In a follow-up post we’ll make this more generic (removing the requirement of being an N×N matrix and the global N itself) and create a Matrix class.