Provide Caching For Your Complex Functions

Function caching, or more specifically memoization, is an optimization technique that speeds up code that repeatedly calls an expensive function with the same inputs.  Let's take a look at the classic Fibonacci function:

    private int Fib(int val)
    {
        return val <= 2 ? 1 : Fib(val - 2) + Fib(val - 1);
    }

This is a recursive function that computes the nth value of the Fibonacci sequence, where n is passed as the argument.  As n increases, the running time grows exponentially, because each call recomputes the same smaller values over and over.  If the function is called multiple times with the same input, that cost is paid every time.  It would be nice if we could cache the result of the first call for a given input and, when the function is called again with that input, simply return the stored result.  It would also be nice to provide a generic way to do this, so that it applies to any function.  In this post, I will share a few classes that I wrote to facilitate this type of function caching.
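To put a number on that cost, the sketch below counts how many times the naive Fib is invoked. It is written as C# 9 top-level statements, and the `calls` counter is a hypothetical addition that is not part of the post's code.

```csharp
using System;

int calls = 0;   // hypothetical instrumentation: total invocations of Fib

// Same recursion as the post's Fib, plus a counter.
int Fib(int val)
{
    calls++;
    return val <= 2 ? 1 : Fib(val - 2) + Fib(val - 1);
}

int first = Fib(20);
int callsFirst = calls;

int second = Fib(20);                  // same input: all of the work is repeated
int callsSecond = calls - callsFirst;

// prints: Fib(20) = 6765: 13529 calls, then 13529 more for the repeat
Console.WriteLine($"Fib(20) = {first}: {callsFirst} calls, then {callsSecond} more for the repeat");
```

In general this version makes 2·Fib(n) − 1 calls for input n, so the work grows as fast as the Fibonacci numbers themselves, and a repeated input pays that full price again.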

The following class provides the caching mechanism for a function that takes a single argument.

    using System;
    using System.Collections.Generic;
    using System.Linq;

    public sealed class CachedFunction<T, TResult>
    {
        private readonly Func<T, TResult> _function;
        private readonly int _cacheSize;
        private readonly Dictionary<T, CacheInformation<T, TResult>> _cache;

        // The constructor for the CachedFunction class.
        private CachedFunction(Func<T, TResult> function, int cacheSize)
        {
            this._function = function;
            this._cacheSize = cacheSize;
            this._cache = new Dictionary<T, CacheInformation<T, TResult>>(this._cacheSize);
        }

        // A factory method that creates a new instance of the CachedFunction class
        // with the default cache size of 10.
        public static CachedFunction<T, TResult> Create(Func<T, TResult> function)
        {
            return Create(function, 10);
        }

        // A factory method that creates a new instance of the CachedFunction class.
        public static CachedFunction<T, TResult> Create(Func<T, TResult> function, int cacheSize)
        {
            if (function == null)
                throw new ArgumentNullException("function");
            if (cacheSize <= 0)
                throw new ArgumentOutOfRangeException("cacheSize", "The cache size must be at least 1.");

            return new CachedFunction<T, TResult>(function, cacheSize);
        }

        // Executes the function with the specified argument. If the argument exists
        // in the cache, the cached result is returned; otherwise the original function
        // is called and the result is added to the cache.
        public TResult Execute(T arg1)
        {
            if (this._cache.ContainsKey(arg1))
                return this.GetReturnFromCache(arg1);

            TResult returnValue = this._function(arg1);
            this.AddToCache(arg1, returnValue);
            return returnValue;
        }

        // Adds a new argument/result pair to the cache. If the cache is already full,
        // an entry with the smallest execution count is removed first. When several
        // entries tie, the first one in dictionary enumeration order is removed; that
        // order is not guaranteed to be insertion order, so it is not necessarily the oldest.
        private void AddToCache(T argument1, TResult returnValue)
        {
            if (this._cache.Count >= this._cacheSize)
            {
                int min = this._cache.Min(x => x.Value.ExecutionCount);
                this._cache.Remove(this._cache.First(x => x.Value.ExecutionCount == min).Key);
            }

            this._cache.Add(argument1, CacheInformation<T, TResult>.Create(argument1, returnValue));
        }

        // Returns the cached result that corresponds to the given argument.
        private TResult GetReturnFromCache(T argument1)
        {
            return this._cache[argument1].ReturnValue;
        }
    }
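The lookup, compute, and evict steps of Execute and AddToCache can also be expressed as a closure. The sketch below is C# 9 top-level statements; MemoizeLfu is a hypothetical name, and a value tuple stands in for CacheInformation. It follows the same policy (a hit bumps the count, new entries start at zero, the lowest count is evicted when full) and traces which entry gets evicted.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical closure-based equivalent of CachedFunction<T, TResult>.
Func<T, TResult> MemoizeLfu<T, TResult>(Func<T, TResult> function, int cacheSize)
{
    var cache = new Dictionary<T, (TResult Value, int Hits)>(cacheSize);
    return arg =>
    {
        if (cache.TryGetValue(arg, out var entry))
        {
            cache[arg] = (entry.Value, entry.Hits + 1);   // record the cache hit
            return entry.Value;
        }
        if (cache.Count >= cacheSize)
        {
            // Evict the entry with the fewest hits (OrderBy is stable, so ties go
            // to whichever entry the dictionary enumerates first).
            T victim = cache.OrderBy(kv => kv.Value.Hits).First().Key;
            cache.Remove(victim);
        }
        TResult result = function(arg);
        cache.Add(arg, (result, 0));
        return result;
    };
}

int runs = 0;   // how many times the underlying function really executes
// A cheap squaring lambda stands in for an expensive function.
var square = MemoizeLfu<int, int>(x => { runs++; return x * x; }, cacheSize: 2);

square(1); square(1);   // computed once, then one cache hit for 1
square(2);              // computed; cache is full: {1 (1 hit), 2 (0 hits)}
square(3);              // 2 has the fewest hits, so it is evicted
square(1);              // still cached
square(2);              // was evicted, so it is computed again (evicting 3)

// prints: underlying function ran 4 times
Console.WriteLine($"underlying function ran {runs} times");
```

One consequence of this policy is visible in the trace: a freshly added entry always has zero hits, so it is the first candidate for eviction the next time the cache fills up.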

The values stored in the cache dictionary are instances of a class that holds the function's argument, its return value, and the number of times the cached result has been returned.  That code is shown below:

    using System.Threading;

    internal sealed class CacheInformation<T, TResult>
    {
        private TResult _returnValue;
        private int _executionCount;

        private CacheInformation(T argument1, TResult returnValue)
        {
            this.ExecutionCount = 0;
            this.Argument1 = argument1;
            this.ReturnValue = returnValue;
        }

        public static CacheInformation<T, TResult> Create(T argument1, TResult returnValue)
        {
            return new CacheInformation<T, TResult>(argument1, returnValue);
        }

        public T Argument1 { get; private set; }

        // Every read of the cached value counts as one execution; the counter
        // is incremented atomically.
        public TResult ReturnValue
        {
            get
            {
                Interlocked.Increment(ref this._executionCount);
                return this._returnValue;
            }
            private set
            {
                this._returnValue = value;
            }
        }

        public int ExecutionCount
        {
            get
            {
                return this._executionCount;
            }
            private set
            {
                this._executionCount = value;
            }
        }
    }
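CacheInformation increments its counter with Interlocked.Increment, which hints at multithreaded use, but Dictionary<TKey, TValue> is not safe when one thread writes while others read or write, so CachedFunction as written should be used from one thread at a time. If concurrent access is needed, one option (a different technique from the post's, sketched here as C# 9 top-level statements; MemoizeConcurrent is a hypothetical name) is a ConcurrentDictionary whose values are Lazy<T>. For brevity it drops the size limit and eviction.

```csharp
using System;
using System.Collections.Concurrent;
using System.Threading;
using System.Threading.Tasks;

// Hypothetical thread-safe memoizer with no size limit. GetOrAdd returns the same
// Lazy<TResult> to every caller for a given key, and Lazy's default mode
// (ExecutionAndPublication) runs the wrapped function only once per argument,
// even when several threads request an uncached argument at the same moment.
Func<T, TResult> MemoizeConcurrent<T, TResult>(Func<T, TResult> function)
{
    var cache = new ConcurrentDictionary<T, Lazy<TResult>>();
    return arg => cache.GetOrAdd(arg, key => new Lazy<TResult>(() => function(key))).Value;
}

int runs = 0;
var slowSquare = MemoizeConcurrent<int, int>(x =>
{
    Interlocked.Increment(ref runs);   // count real executions
    Thread.Sleep(50);                  // simulate an expensive computation
    return x * x;
});

// 100 concurrent requests for the same argument.
Parallel.For(0, 100, _ => slowSquare(7));

// prints: square(7) = 49, computed 1 time(s)
Console.WriteLine($"square(7) = {slowSquare(7)}, computed {runs} time(s)");
```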

Now let's take a look at how we can implement the same Fibonacci function that we started with, this time using the caching mechanism we have just built.

    using System;

    class Program
    {
        private CachedFunction<int, int> myFunc;

        public Program()
        {
            myFunc = CachedFunction<int, int>.Create(x => Fib(x));
        }

        private int Fib(int val)
        {
            return val <= 2 ? 1 : Fib(val - 2) + Fib(val - 1);
        }

        public int DoFib(int val)
        {
            return myFunc.Execute(val);
        }

        static void Main(string[] args)
        {
            Program p = new Program();
            Console.WriteLine(p.DoFib(40));
            Console.WriteLine(p.DoFib(39));
            Console.WriteLine(p.DoFib(40));
            Console.WriteLine(p.DoFib(39));
            Console.WriteLine(p.DoFib(39));
            Console.WriteLine(p.DoFib(39));
            Console.WriteLine(p.DoFib(39));
            Console.WriteLine(p.DoFib(40));
            Console.WriteLine(p.DoFib(40));
            Console.WriteLine(p.DoFib(40));
            Console.WriteLine(p.DoFib(40));
        }
    }

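When the code above is executed, the first two calls (for 40 and 39) take some time to compute; every call after that returns immediately with the cached result for 39 or 40.  Note that only these top-level calls go through the cache: the recursive calls inside Fib still run uncached.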
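To check that claim without a stopwatch, the sketch below counts how often the wrapped function actually runs for the same sequence of inputs as Main. It uses a minimal, unbounded Memoize helper (a hypothetical name, written as C# 9 top-level statements) so that it runs on its own; with only two distinct inputs, CachedFunction's default size of 10 never evicts anything, so the count would be the same.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical unbounded memoizer for int -> int functions.
Func<int, int> Memoize(Func<int, int> function)
{
    var cache = new Dictionary<int, int>();
    return arg =>
    {
        if (!cache.TryGetValue(arg, out int result))
        {
            result = function(arg);   // cache miss: do the real work once
            cache.Add(arg, result);
        }
        return result;
    };
}

// The post's naive Fib.
int Fib(int val) => val <= 2 ? 1 : Fib(val - 2) + Fib(val - 1);

int computations = 0;
var doFib = Memoize(x => { computations++; return Fib(x); });

// Same call sequence as Main above.
int[] inputs = { 40, 39, 40, 39, 39, 39, 39, 40, 40, 40, 40 };
foreach (int n in inputs)
    Console.WriteLine(doFib(n));

// prints (after the results): 11 calls, 2 computations
Console.WriteLine($"{inputs.Length} calls, {computations} computations");
```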

This approach only pays off for functions that take some time to compute.  If the function is simple, like returning the square of its input, the overhead of caching may outweigh the cost of the operation itself.  For an expensive computation whose inputs are likely to repeat, however, it can save a lot of time.

The example I provided also only works for a function that takes a single argument, but it could easily be extended to functions that take multiple arguments by using a Tuple of the arguments as the cache key.
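A minimal sketch of the two-argument case, assuming System.Tuple (available since .NET 4.0) and written as C# 9 top-level statements; the Memoize name and the power function are hypothetical. Tuple<T1, T2> overrides Equals and GetHashCode to compare its components, so it works directly as a dictionary key.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical two-argument memoizer: both arguments are packed into one Tuple key.
Func<T1, T2, TResult> Memoize<T1, T2, TResult>(Func<T1, T2, TResult> function)
{
    var cache = new Dictionary<Tuple<T1, T2>, TResult>();
    return (a, b) =>
    {
        var key = Tuple.Create(a, b);
        if (!cache.TryGetValue(key, out TResult result))
        {
            result = function(a, b);   // first time this argument pair is seen
            cache.Add(key, result);
        }
        return result;
    };
}

int runs = 0;
var power = Memoize<int, int, long>((b, e) => { runs++; return (long)Math.Pow(b, e); });

long first = power(2, 10);   // computed
long again = power(2, 10);   // served from the cache
long other = power(10, 2);   // argument order matters: a different key, computed

// prints: 1024 1024 100, computed 2 times
Console.WriteLine($"{first} {again} {other}, computed {runs} times");
```

A value tuple such as `(T1, T2)` would work as the key too, since it also compares its components for equality.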

Comments

  • Anonymous
    December 07, 2008
    It would be interesting to compare the cost if the Fib function internally used the CachedFunction:

        private int Fib(int val)
        {
            return val <= 2 ? 1 : DoFib(val - 2) + DoFib(val - 1);
        }

  • Anonymous
    December 08, 2008
    Wow, an actually interesting post. Congrats. I agree with Anders: you should try doing the Fib using the cache (and post the time results). You'll take a hit on the first handful of iterations, but for something like that you could see some large gains with the higher numbers. Two problems that I can see:

  1. Can (should) only be used on functions that are deterministic, i.e. if your method is doing something based on timestamps, random numbers, etc., then you could be caching a result that could potentially change (think of caching a function that takes as a parameter how many times to roll a die and adds up all the rolls).
  2. This seems like it would be very useful for something that is computationally intense. However, with such an application, one would think that the set of all possible keys would be so large that you would not get enough "hits" to keep it in your cache. I've seen something similar to this before in embedded systems, where trig results were kept in lookup tables so you wouldn't have to compute cos(43), etc. Your code is nice because it's more generic.
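Both commenters suggest routing Fib's own recursive calls through the cache. A minimal sketch of that idea, written as C# 9 top-level statements, uses a plain, unbounded Dictionary<int, long> rather than CachedFunction, so no intermediate result is evicted. With it, each Fib(n) body runs once, compared with 2·Fib(40) − 1 = 204,668,309 calls for the naive version.

```csharp
using System;
using System.Collections.Generic;

var cache = new Dictionary<int, long>();
int bodyRuns = 0;   // how many times the recursive body actually executes

long Fib(int val)
{
    if (cache.TryGetValue(val, out long cached))
        return cached;                      // already computed

    bodyRuns++;                             // first time this n is seen
    long result = val <= 2 ? 1 : Fib(val - 2) + Fib(val - 1);
    cache[val] = result;
    return result;
}

long fib40 = Fib(40);

// prints: Fib(40) = 102334155, body evaluated 40 times
Console.WriteLine($"Fib(40) = {fib40}, body evaluated {bodyRuns} times");
```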