Thursday, September 01, 2016

The United States of Monad


A couple of months ago I was chatting with approximately half of my readership about Monads, and more specifically this post which asked: Why do we have monads?

Two of the reasons I gave were exception handling and state. I have already touched on the former, and again refer you to Railway Oriented Programming for a brilliant explanation of how to handle exceptions in a purely functional manner.

What you’ll learn today

  • Why we need to handle state functionally in the first place
  • How to convert imperative for and while loops to functional ones that embed the local state
  • How this can be formalised using the State monad and F# computation expressions
  • How to extend this to keep track of a global state without resorting to a mutable variable

As ever, the code is on GitHub.


Handling state functionally

State is a funny old thing. It seems pretty harmless in small doses, but trying to control state globally and across threads and processes is a nightmare.

If we handle state using purely functional techniques we can control this — we will see how it is threaded through our computations rather than simply ambient, which allows us to make stronger statements about the concurrency and parallelism of our program.

Before we see that, let’s go back to basics and look at handling state in the smallest scope possible, a loop.


Keeping it in the loop

Let’s consider a very simple example. We have an action, say printing to the console, that we want to do a set number of times. Each time we do it, we want to also print how many times we’ve done it, i.e. the state of the computation. Keeping it even simpler, let’s fix the number of iterations:

let startVal = 1
let endVal = 10

We can achieve our goal in a number of ways, the first being a standard for loop. This doesn’t look like it’s impure, but we have a variable i below that will be mutated on each iteration of the loop body:

let forLoop act = 
    for i = startVal to endVal do
        act (i)

The first step towards a functional approach is to notice that a for loop is the same as a while loop that has a mutable integer variable. We thus rewrite:

let whileLoop act = 
    let mutable i = startVal
    while i <= endVal do
        act (i)
        i <- i + 1

We have seen before how to eliminate this kind of variable using an inner recursive function. In the function below we thread the state through the computation as the parameter of our aux function:

let recLoop act = 
    let rec aux i = 
        match i > endVal with
        | true -> ()
        | _ -> 
            act (i)
            aux (i + 1)
    aux startVal

We now have state being treated in a pure, functional manner. No variables are being mutated; each recursive call simply receives the next value of the state as its argument.

Let’s run the code:

let action i = printf "%d " i

forLoop action
printf "%s" System.Environment.NewLine
whileLoop action
printf "%s" System.Environment.NewLine
recLoop action
printf "%s" System.Environment.NewLine

Thankfully, the output is exactly the same in each case!

What have we learnt?

We can handle state in a functional manner by recursive function calls that take the state as an input.
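
The same idea works for any kind of state, not just a counter. As a rough sketch (loopWhile is a made-up helper name, not something from the original code), a recursive loop can thread an arbitrary state value through each iteration:

```fsharp
// Hypothetical helper: keeps applying `step` to the state for as long as
// `keepGoing` says so, then returns the final state. Nothing is mutated.
let rec loopWhile keepGoing step state =
    if keepGoing state then
        loopWhile keepGoing step (step state)
    else
        state

// The state here is a pair: the counter plus the values visited so far
// (collected in reverse order, because consing onto a list is cheap).
let finalCounter, visited =
    loopWhile (fun (i, _) -> i <= 10) (fun (i, seen) -> (i + 1, i :: seen)) (1, [])

printfn "%d %A" finalCounter (List.rev visited)
```

The counter and the list of visited values travel together as one tuple, which is exactly the trick the rest of the post builds on.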


Monads, monads everywhere

What on earth does the above have to do with monads?

Let’s take our example a step further and say that we want to do something really complicated, like summing a list of integers. In this instance, all of the ‘temporary’ state (i.e. the ‘sum so far’) is contained entirely within the loop construct, totally hidden from the consumer of the function. We can do this once more using for, while or rec:

let forSum list = 
    let mutable sum = 0
    for elt in list do
        sum <- sum + elt
    sum

let whileSum list = 
    let mutable sum = 0
    let mutable i = 0
    while i < List.length list do
        sum <- sum + list.[i]
        i <- i + 1
    sum

let recSum list = 
    let rec aux remainderOfList sumSoFar = 
        match remainderOfList with
        | [] -> sumSoFar
        | head :: tail -> aux tail (sumSoFar + head)
    aux list 0

The key thing to note above is that in the recursive example, nothing is marked as mutable. This means we’ve now seen a technique that helps us:

  1. Run loops without introducing a mutable variable
  2. Keep track of temporary state whilst keeping functional purity
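
If the accumulator in recSum looks familiar, that may be because it has the shape of a fold. A minimal sketch, assuming nothing beyond the standard List.fold function (foldSum is an illustrative name, not from the original code):

```fsharp
// recSum repeated from above so this sketch stands alone.
let recSum list =
    let rec aux remainderOfList sumSoFar =
        match remainderOfList with
        | [] -> sumSoFar
        | head :: tail -> aux tail (sumSoFar + head)
    aux list 0

// Illustrative name: the same accumulator threading, delegated to the
// standard library's List.fold (folder, initial state, list).
let foldSum list =
    List.fold (fun sumSoFar elt -> sumSoFar + elt) 0 list

printfn "%d %d" (recSum [1; 2; 3; 4]) (foldSum [1; 2; 3; 4])
```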

This quickly breaks down when doing anything non-local — imagine how you might update a global variable from inside a pure function that itself returns a value. To do this, we combine what we’ve seen before and can write the function such that it takes in the global state as an input, and returns the (updated) global state as well as whatever other value the function returns:

let sumListAndUpdateState list state = 
    (recSum list, state)

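The above shows the shape of the solution: return a tuple of the function result plus the state. In this toy example the state is handed back unchanged; a function that really needed to update it would return the new value in the second slot.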
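
To see what threading the state by hand looks like when the state really does change, here is a small sketch. The function sumListAndCountCall and the idea of using a call counter as the state are assumptions made purely for illustration:

```fsharp
// recSum repeated from above so this sketch stands alone.
let recSum list =
    let rec aux remainderOfList sumSoFar =
        match remainderOfList with
        | [] -> sumSoFar
        | head :: tail -> aux tail (sumSoFar + head)
    aux list 0

// Hypothetical variant: the state is a call counter that goes up by one
// every time the function runs. Returns (result, newState).
let sumListAndCountCall list callCount =
    (recSum list, callCount + 1)

// Each call has to be given the state that the previous call returned.
let total1, calls1 = sumListAndCountCall [1; 2; 3] 0
let total2, calls2 = sumListAndCountCall [10; 20] calls1

printfn "%d %d %d" total1 total2 calls2
```

Nothing is mutated, but every caller now has to remember to pass calls1 along rather than the original 0, which is easy to get wrong once there are more than a couple of calls.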

How do we generalise this so that we don’t have to explicitly pass around the state on every function call? The answer is the state monad. There are many explanations of it out there, but I will walk you through the F# one myself. It all looks pretty complicated, but I’m hoping you get the gist rather than get bogged down by the details — we’re creating a way to thread state silently through any computation.

Being precise, the state monad is based on a type: one that encodes a function taking a state of type 'a and returning a result-state tuple where the result is of type 'b:

type StateMonad<'a,'b> = 'a -> ('b * 'a)

To be classed as a monad, we need to be able to do some stuff with it. First, we need to be able to create an instance of it from a plain value, leaving whatever state is passed in untouched:

let returnS a = (fun s -> a, s)

We also need to be able to chain instances: run one, then feed its result (and the state it leaves behind) into a function that produces the next:

let (>>=) x f = 
    (fun s0 -> 
        let a, s = x s0
        f a s)
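
A small sketch of these two operations working together may help. The tick step below is a hypothetical example, not part of the original code: it returns the current counter as its result and increments the state.

```fsharp
let returnS a = (fun s -> a, s)

let (>>=) x f =
    (fun s0 ->
        let a, s = x s0
        f a s)

// Hypothetical stateful step: result is the current counter,
// new state is the counter plus one.
let tick = (fun (s: int) -> s, s + 1)

// Two ticks in a row, then wrap both observed values up as the result.
let twoTicks =
    tick >>= (fun first ->
        tick >>= (fun second ->
            returnS (first, second)))

// Nothing has run yet; twoTicks is just a function waiting for a state.
let result, finalState = twoTicks 10

printfn "%A %d" result finalState
```

Starting from a state of 10, the first tick sees 10, the second sees 11, and the state left at the end is 12. At no point did we pass the state explicitly between the two ticks.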

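Finally, we want to combine two instances when we only care about the state: run the first, throw away its result, and hand its final state to the second: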

let (>=>) m1 m2 = 
    (fun s0 ->
        let _, s = m1 s0
        m2 s)
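
As a quick illustration of >=> on its own (addOne and double are made-up steps for this sketch), note how only the state flows from the first computation into the second, while the first result is dropped:

```fsharp
let (>=>) m1 m2 =
    (fun s0 ->
        let _, s = m1 s0
        m2 s)

// Hypothetical steps: each returns a label as its result and a new state.
let addOne = (fun (s: int) -> "added one", s + 1)
let double = (fun (s: int) -> "doubled", s * 2)

let addThenDouble = addOne >=> double

// The state goes 3 -> 4 -> 8; the label from addOne is discarded.
let label, finalState = addThenDouble 3

printfn "%s %d" label finalState
```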

Done as an F# computation expression, this looks like:

type StateBuilder() = 
    member m.Bind(x, f) = x >>= f
    member m.Return a = returnS a
    member m.ReturnFrom(x) = x

let state = new StateBuilder()
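
To demystify the builder a little, here is a sketch of one computation written twice: once with the state { } syntax, and once as the nested Bind and Return calls that the compiler roughly translates it into. The tick step is again a hypothetical example, and the hand-written expansion is an approximation of the translation rather than the compiler's exact output.

```fsharp
let returnS a = (fun s -> a, s)
let (>>=) x f = (fun s0 -> let a, s = x s0 in f a s)

type StateBuilder() =
    member m.Bind(x, f) = x >>= f
    member m.Return a = returnS a
    member m.ReturnFrom(x) = x

let state = new StateBuilder()

// Hypothetical step: return the counter, then increment it.
let tick = (fun (s: int) -> s, s + 1)

// Written with the computation expression...
let sugared =
    state {
        let! first = tick
        let! second = tick
        return first + second
    }

// ...and approximately what it expands to.
let desugared =
    state.Bind(tick, fun first ->
        state.Bind(tick, fun second ->
            state.Return(first + second)))

printfn "%A %A" (sugared 5) (desugared 5)
```

Each let! becomes a call to Bind, with the rest of the block as the continuation, and return becomes a call to Return.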

We can then define some helper functions that we can use in our state computation expressions: the first gets the current state; the second sets it; the third executes our computation and returns the final state.

let getState = (fun s -> s, s)
let setState s = (fun _ -> (), s)
let Execute m s = m s |> snd
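
Here is a small sketch of the helpers in use. Evaluate is a hypothetical counterpart to Execute that I am assuming for illustration (it keeps the result and drops the final state), and doubleAndReportOld is likewise a made-up example:

```fsharp
let returnS a = (fun s -> a, s)
let (>>=) x f = (fun s0 -> let a, s = x s0 in f a s)

type StateBuilder() =
    member m.Bind(x, f) = x >>= f
    member m.Return a = returnS a
    member m.ReturnFrom(x) = x

let state = new StateBuilder()

let getState = (fun s -> s, s)
let setState s = (fun _ -> (), s)
let Execute m s = m s |> snd

// Hypothetical counterpart to Execute: keep the result, drop the final state.
let Evaluate m s = m s |> fst

// Double the state, and return the value it had beforehand.
let doubleAndReportOld =
    state {
        let! old = getState
        do! setState (old * 2)
        return old
    }

printfn "%d %d" (Execute doubleAndReportOld 21) (Evaluate doubleAndReportOld 21)
```

Execute gives back the final state (42 here) and Evaluate gives back the result (21), which makes the split between "what the computation returns" and "what it leaves behind" easy to see.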

Now that we have the concept of threading state through a function encoded in a generic type, let’s put it to use!

A contrived example is our previous one summing a list — it’s overkill for a monad as it only ever used local state, but here goes:

let stateSum list = 
    let rec aux t = 
        state { 
            match t with
            | head :: tail -> 
                let! s = getState
                do! setState (s + head)
                return! aux tail
            | [] -> return ()
        }
    Execute (aux list) 0

Let’s focus on the aux computation expression. The most obvious thing is that the state isn’t an input! This is because aux actually returns a function: one that takes in the initial state and returns the function return value (which in this case is unit but could be anything) plus the final state.

Behind the scenes it is going to call Bind to combine the recursive calls to aux into one big computation, each stage of which will take the sum of the list so far plus the tail of the list as inputs. When we have an empty list, we return the state plus the function ‘return value’. Here, the state is the value we want so we can just return unit.
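
To make that unwinding concrete, here is roughly what aux looks like with the computation expression expanded by hand and the helper functions inlined. The names auxExplicit and stateSumExplicit are invented for this sketch, and the expansion is an approximation of what the builder does rather than the compiler's literal output:

```fsharp
// Hand-expanded version of aux: every step is an ordinary function
// from a state to a (result, state) pair.
let rec auxExplicit t =
    match t with
    | head :: tail ->
        (fun (s0: int) ->
            let s, _ = s0, s0           // let! s = getState
            let _, s2 = (), s + head    // do! setState (s + head)
            auxExplicit tail s2)        // return! aux tail
    | [] -> (fun s -> (), s)            // return ()

// Execute (aux list) 0: run from state 0 and keep only the final state.
let stateSumExplicit list = auxExplicit list 0 |> snd

printfn "%d" (stateSumExplicit [1; 2; 3])
```

For the list [1; 2; 3] the state goes 0, 1, 3, 6, and the final 6 is what comes back.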

Of course, the state doesn’t have to be an integer, and the way we update it doesn’t have to be addition. We can write a generic fold function using our state computation expression as follows:

let fold aggregator initialValue list = 
    let rec aux t = 
        state { 
            match t with
            | head :: tail -> 
                let! s = getState
                do! setState (aggregator s head)
                return! aux tail
            | [] -> return ()
        }
    Execute (aux list) initialValue

Now we are passing in everything we need to control the aggregation, and we are left with a function that takes in a way to update the state from each list element, an initial state, and a list, returning the final state. Complicated, but beautiful to reason about.
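
A few example calls may make the generality clearer. This sketch repeats the definitions so that it stands alone; the particular aggregators are just illustrations, and the comparison against the standard List.fold is there as a sanity check:

```fsharp
let returnS a = (fun s -> a, s)
let (>>=) x f = (fun s0 -> let a, s = x s0 in f a s)

type StateBuilder() =
    member m.Bind(x, f) = x >>= f
    member m.Return a = returnS a
    member m.ReturnFrom(x) = x

let state = new StateBuilder()
let getState = (fun s -> s, s)
let setState s = (fun _ -> (), s)
let Execute m s = m s |> snd

let fold aggregator initialValue list =
    let rec aux t =
        state {
            match t with
            | head :: tail ->
                let! s = getState
                do! setState (aggregator s head)
                return! aux tail
            | [] -> return ()
        }
    Execute (aux list) initialValue

// Integer state: a plain sum.
let total = fold (+) 0 [1; 2; 3; 4]

// List state: build up the reversed list.
let reversed = fold (fun acc x -> x :: acc) [] [1; 2; 3]

// Integer state over strings: length of the longest word.
let longest = fold (fun best (w: string) -> max best w.Length) 0 ["a"; "monad"; "is"]

printfn "%d %A %d" total reversed longest
```

The argument order (aggregator, initial value, list) matches List.fold, so the two can be swapped for one another.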

Next, let’s see how we can handle global state.


Globalisation

So far, we have shown a number of ways in which local state can be handled. This is a fundamentally different (and easier) problem than that of global state.

Let’s use the example of a function that will only run an action five times. On the sixth run it will instead perform a failure action.

Here’s how it looks normally: we create a mutable variable and read that from inside the (impure, yuck) function:

let mutable globalCounter = 0

let canOnlyRunFiveTimes passAction failAction = 
    match globalCounter < 5 with
    | true ->            
        passAction globalCounter
        globalCounter <- globalCounter + 1
    | false -> 
        failAction globalCounter
        globalCounter <- -1

Using our state monad, we can do exactly the same thing but we have controlled the impurity:

let monadicCounter = 0

let canOnlyRunFiveTimesWithStateMonad passAction failAction = 
    state { 
        let! s = getState
        match s < 5 with
        | true -> 
            passAction s
            do! setState (s + 1)
        | false -> 
            failAction s
            do! setState (-1)
    }

The best way to see this in action is by running some tests.

First, let’s run our function five times using the version that reads from a global mutable variable. The counter should then be 5! Then run it a sixth time; the counter should be -1 to signify an error.

[<Test>]
let ``Global mutable state``() = 
  for _ in 1 .. 5 do canOnlyRunFiveTimes (printf "Run %d \r\n") (printf "Not run %d \r\n")
  globalCounter =! 5
  canOnlyRunFiveTimes (printf "Run %d \r\n") (printf "Not run %d \r\n")
  globalCounter =! -1

Here’s how it looks using global monadic state (no mutable keyword, whoop!!). We use our >=> operator from before that glues together two monad thingies to run the computation more than once:

[<Test>]
let ``Global monadic state``() =
  let monadicCounter = 0
  let m = canOnlyRunFiveTimesWithStateMonad (printf "Run %d \r\n") (printf "Not run %d \r\n") 
  let composedFiveTimes =  m >=> m >=> m >=> m >=> m
  let composedSixTimes =  composedFiveTimes >=> m
  Execute composedFiveTimes monadicCounter =! 5
  Execute composedSixTimes monadicCounter =! -1

The results are exactly the same!
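
Writing m >=> m >=> m by hand gets old quickly. As a sketch of one way around it, repeatM below is a hypothetical helper that sequences n copies of the same computation, and step is a side-effect-free stand-in for canOnlyRunFiveTimesWithStateMonad, written directly as a state function to keep the example short:

```fsharp
let (>=>) m1 m2 =
    (fun s0 ->
        let _, s = m1 s0
        m2 s)

let Execute m s = m s |> snd

// Hypothetical helper: glue n copies of m together with >=>.
// List.reduce throws on an empty list, so n must be at least 1.
let repeatM n m = List.replicate n m |> List.reduce (>=>)

// Stand-in for the counter logic: increment while below 5, otherwise
// set the state to -1. The result is unit, as in the original.
let step = (fun (s: int) -> if s < 5 then ((), s + 1) else ((), -1))

let afterFive = Execute (repeatM 5 step) 0
let afterSix = Execute (repeatM 6 step) 0

printfn "%d %d" afterFive afterSix
```

One thing this makes easy to try: because the error marker -1 is itself less than 5, a seventh run would count as a pass and set the state to 0. The same is true of the mutable version.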


Recap

Let’s look at what I promised you:

What you’ll learn today

  • Why we need to handle state functionally in the first place
  • How to convert imperative for and while loops to functional ones that embed the local state
  • How this can be formalised using the State monad and F# computation expressions
  • How to extend this to keep track of a global state without resorting to a mutable variable

I hope that you feel satisfied that you know these things now — if you don’t, let me know which bits are a struggle and how I can improve my explanation!
