Dex: Differential Probabilistic Inference

This notebook develops an unconventional approach to probabilistic inference using Dex tables and auto-differentiation. It is based loosely on the approach described in:

A Differential Approach to Inference in Bayesian Networks, Adnan Darwiche (2003)

This approach can be thought of as a probabilistic programming language (PPL), in the sense that inference is separated from the modeling language. However, it does not require interrupting standard control flow.

Running Example 1: Coins and Dice

Let us start with a simple probabilistic modeling example to establish some notation.

In this example we have a coin and two dice, one fair and one weighted. We first flip the coin: if it is tails we roll die 1, and if it is heads we roll die 2.

Coin = Fin 2
[tails, heads] = for i:Coin. i
Dice = Range 1 7
roll = \i . (i - 1)@Dice
None = Fin 1
nil = 0@None
coin : Coin => Float = [0.2, 0.8]
dice_1 : Dice => Float = for i. 1.0 / 6.0
dice_2 : Dice => Float = [0.5, 0.1, 0.1, 0.1, 0.1, 0.1]

This defines a generative process over two random variables $\mathbf{X} = \{A, B\}$, the coin flip and the die roll respectively. We can write the process explicitly as

$$a \sim Pr(A)$$ $$b \sim Pr(B\ |\ A=a)$$

Probability Combinators

A discrete probability distribution is a normalized table of probabilities.

data Dist variables = AsDist (variables => Float)
def (??) (y:m) (AsDist x: Dist m) : Float = x.y

Distributions are easy to create. Here are a couple of simple ones.

def normalize (x: m=>Float) : Dist m = AsDist for i. x.i / sum x
def uniform : Dist m = normalize for i. 1.0
def delta (x:m) : Dist m = AsDist for i. select ((ordinal x) == (ordinal i)) 1.0 0.0
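For comparison, here is a rough Python analogue of these constructors. It is a sketch rather than part of the notebook: it assumes a distribution is simply a list of probabilities indexed by outcome ordinal, and it uses Fraction so that results are exact. The function names mirror the Dex ones but are otherwise illustrative.

```python
from fractions import Fraction

# Sketch: a distribution is a list of probabilities indexed by outcome ordinal.

def normalize(weights):
    """Divide each weight by the total so the table sums to one."""
    total = sum(weights)
    return [w / total for w in weights]

def uniform(n):
    """Uniform distribution over n outcomes."""
    return normalize([Fraction(1)] * n)

def delta(x, n):
    """Point mass on the outcome with ordinal x, out of n outcomes."""
    return [Fraction(1) if i == x else Fraction(0) for i in range(n)]

def support(dist):
    """(ordinal, probability) pairs with non-zero probability."""
    return [(i, p) for i, p in enumerate(dist) if p > 0]

print(uniform(3))            # three entries, each Fraction(1, 3)
print(support(delta(4, 6)))  # [(4, Fraction(1, 1))]
```

As in the Dex version, delta compares ordinals, so the point mass is placed by position rather than by the label an outcome displays.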
instance Arbitrary (Dist m)
  arb = \key.
    a = arb key
    normalize $ for i. abs a.i

And they are displayed by their support.

def support (AsDist x: Dist m) : List (m & Float) =
  concat $ for i. select (x.i > 0.0) (AsList 1 [(i, x.i)]) mempty
instance [Show m] Show (Dist m)
  show = \a.
    (AsList _ out) = support a
    concat $ for i. "Key: " <> (show $ fst out.i) <> " Prob: " <> (show $ snd out.i) <> "\n"

instance Show (Range i j)
  show = \a . show $ ordinal a
show $ (delta (4@_)):Dist Dice
(AsList 15 "Key: 4 Prob: 1
")

Expectations can be taken over arbitrary tables.

def expect [VSpace out] (AsDist x: Dist m) (y : m => out) : out =
  sum for m'. x.m' .* y.m'

To represent conditional probabilities such as $Pr(B\ |\ A)$ we define a type alias.

def Pr (b:Type) (a:Type): Type = a => Dist b

With this machinery we can define distributions for the coin and the dice.

p_A : Pr Coin None = [AsDist coin]
p_B_A : Pr Dice Coin = [AsDist dice_1, AsDist dice_2]

Attempt 1: Observations and Marginalization

This allows us to compute the probability of any full observation from our model.

$$Pr(A=a, B=b) = Pr(B=b\ |\ A=a) Pr(A=a)$$

def p_AB (a:Coin) (b:Dice) : Float = (a ?? p_A.nil) * (b ?? p_B_A.a)
p_AB heads (roll 6)
0.08

However, this assumes that we have full observation of our variables. What if the coin is latent? Then we have to sum it out.

$$Pr(B) = \sum_a Pr(B\ |\ A=a) Pr(A=a)$$

def p_B (b:Dice) : Float = sum for a. (a ?? p_A.nil) * (b ?? p_B_A.a)
p_B (roll 6)
0.113333

But now we have two separate functions for the same model! This is redundant and bug-prone.

Attempt 2: Indicator Variables and the Network Polynomial

To make things simpler, we will introduce explicit indicator variables $\lambda$ into the modeling language.

def Var (a:Type) : Type = a => Float

These can either be observed or latent. If a random variable $A$ is observed to equal $a$, its indicator is 1 at $a$ and 0 elsewhere, so the expectation of the indicator gives

$$p(A=a) = E_{a' \sim p(A)} \lambda_{a'}$$

If it is latent, the indicator is one everywhere.

def observed (x:a) : Var a = for i. select ((ordinal i) == (ordinal x)) 1.0 0.0
def latent : Var a = one

The probability chain rule tells us that we can propagate conditioning.

$$\sum_{b} Pr(A,\ B = b) = Pr(A) \sum_b Pr(B=b\ |\ A)$$

This implies that the expectation of these indicators factors as well.
$$E_{a, b\sim Pr(A,\ B)} \left[\lambda_a \lambda_b\right] = E_{a\sim Pr(A)}\left[ \lambda_a E_{b \sim Pr(B\ |\ A=a)} [\lambda_b] \right]$$

We can write one step of this chain rule cleanly in Dex.

def (~) (lambda:Var a) (pr: Dist a) (fn_a : a => Float) : Float =
  expect pr $ for a'. lambda.a' * fn_a.a'

This allows us to finally write our model down in a very clean form.

$$a \sim Pr(A)$$ $$b \sim Pr(B\ |\ A=a)$$

def coin_flip (a': Var Coin) (b': Var Dice) : Float = (a' ~ p_A.nil) (for a. (b' ~ p_B_A.a) one)

Now we can easily reproduce all the results above.

coin_flip (observed heads) (observed (roll 6))
0.08
coin_flip (latent) (observed (roll 6))
0.113333
coin_flip latent latent
1.
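The same nested evaluation can be sketched in Python. This is an illustration under stated assumptions, not the notebook's code: outcomes are encoded by ordinal (tails = 0, heads = 1, face k = k - 1), and sample is a hypothetical stand-in for the Dex ~ combinator that takes the expectation of "indicator times continuation" under a distribution.

```python
from fractions import Fraction as F

# Tables from the coins-and-dice example, indexed by ordinal:
# tails = 0, heads = 1, and die face k has ordinal k - 1.
coin = [F(2, 10), F(8, 10)]
dice_1 = [F(1, 6)] * 6
dice_2 = [F(5, 10)] + [F(1, 10)] * 5
p_B_A = [dice_1, dice_2]  # die 1 after tails, die 2 after heads

def observed(x, n):
    """Indicator vector: 1 at ordinal x, 0 elsewhere."""
    return [F(1) if i == x else F(0) for i in range(n)]

def latent(n):
    """A latent variable's indicator is 1 everywhere."""
    return [F(1)] * n

def sample(lam, dist, fn):
    """Hypothetical analogue of Dex's ~: E_{a ~ dist}[lam[a] * fn[a]]."""
    return sum(p * l * f for p, l, f in zip(dist, lam, fn))

def coin_flip(lam_a, lam_b):
    inner = [sample(lam_b, p_B_A[a], latent(6)) for a in range(2)]
    return sample(lam_a, coin, inner)

HEADS, SIX = 1, 5
print(float(coin_flip(observed(HEADS, 2), observed(SIX, 6))))  # 0.08
print(float(coin_flip(latent(2), observed(SIX, 6))))           # about 0.113333
print(coin_flip(latent(2), latent(6)))                         # 1
```

One model function covers all three queries; only the indicator vectors change, which is the point of Attempt 2.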

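This representation of the joint distribution is known as the network polynomial. It is a multilinear function that uses indicator variables to represent observations.

$$f(\lambda) = \sum_{\mathbf{x}} \prod_{x\mathbf{u} \sim \mathbf{x}} \lambda_x \theta_{x|\mathbf{u}}$$

The sum ranges over every full instantiation $\mathbf{x}$ of the variables, and the product ranges over each variable value $x$, together with the values $\mathbf{u}$ of its parents, that is consistent with $\mathbf{x}$. The $\theta_{x|\mathbf{u}}$ are the model parameters and play the same role as the conditional probability tables p_A and p_B_A above. The $\lambda_x$ are evidence indicators, which record the states of each variable that are consistent with the observations.

The network polynomial can be used to compute the marginal probability of observations on any subset of variables. Let $\mathbf{e}$ be an observation of some subset $\mathbf{E} \subseteq \mathbf{X}$. Darwiche shows that

$$f[\mathbf{e}] = p(\mathbf{E} = \mathbf{e})$$

where $f[\mathbf{e}]$ denotes $f$ evaluated with every $\lambda_x$ that is consistent (non-contradictory) with $\mathbf{e}$ set to 1 and every other $\lambda_x$ set to 0. This is exactly what observed and latent do: for example, coin_flip latent (observed (roll 6)) above evaluates $f[B=6] = Pr(B=6) \approx 0.113333$.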
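To see the identity $f[\mathbf{e}] = p(\mathbf{E} = \mathbf{e})$ concretely, the Python sketch below (an illustration, not the notebook's code) expands the coin-and-dice network polynomial as an explicit sum over all twelve full instantiations, sets the indicators from the evidence $B = 6$, and compares the result with the marginal computed directly. The helper indicators is hypothetical.

```python
from fractions import Fraction as F
from itertools import product

coin = [F(2, 10), F(8, 10)]                   # theta_a (tails = 0, heads = 1)
dice = [[F(1, 6)] * 6,                        # theta_{b | a = tails}
        [F(5, 10)] + [F(1, 10)] * 5]          # theta_{b | a = heads}

def network_polynomial(lam_a, lam_b):
    """Expanded form: sum over every full instantiation (a, b) of
    lambda_a * theta_a * lambda_b * theta_{b|a}."""
    return sum(lam_a[a] * coin[a] * lam_b[b] * dice[a][b]
               for a, b in product(range(2), range(6)))

def indicators(n, value=None):
    """Evidence indicators: 1 if consistent with the evidence, else 0.
    value=None means the variable is unobserved, so every state is consistent."""
    return [F(1) if value is None or i == value else F(0) for i in range(n)]

# Evidence e: B = 6 (ordinal 5); A is unobserved.
f_e = network_polynomial(indicators(2), indicators(6, value=5))
marginal = sum(coin[a] * dice[a][5] for a in range(2))  # Pr(B = 6) computed directly
print(f_e == marginal, f_e)  # True 17/150
```

The expanded sum has one term per full instantiation, so it grows exponentially with the number of variables; the nested Dex form computes the same polynomial in factored form.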

Differential Inference

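The network polynomial is a convenient method for computing probabilities, but what makes it particularly useful is that it allows us to compute posterior probabilities simply by taking derivatives.

For example, consider the probability of the coin flip given an observation of the die roll. We can compute this using Bayes' rule.

$$Pr(A\ |\ B=b) \propto Pr(B=b\ |\ A) Pr(A)$$

Darwiche shows that when $A$ is unobserved, the partial derivative $\partial f / \partial \lambda_a$ evaluated at the evidence $\mathbf{e}$ equals $Pr(A=a, \mathbf{e})$. Dividing by $f[\mathbf{e}] = Pr(\mathbf{e})$ therefore gives the posterior $Pr(A=a\ |\ \mathbf{e})$ without writing out Bayes' rule by hand.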
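This derivative trick can be checked outside Dex. The Python sketch below makes its own assumptions: it uses a tiny hand-written forward-mode dual-number class (not Dex's automatic differentiation) and relies on Darwiche's result that, for an unobserved variable, the partial derivative of $f$ with respect to $\lambda_a$ at the evidence equals $Pr(A=a, \mathbf{e})$. It differentiates the coin-and-dice network polynomial with respect to each coin indicator while the die is observed to show 6, then divides by $f$.

```python
from fractions import Fraction as F

class Dual:
    """Minimal forward-mode dual number carrying a value and one derivative."""
    def __init__(self, val, der=0):
        self.val, self.der = val, der

    def __add__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        return Dual(self.val + other.val, self.der + other.der)
    __radd__ = __add__

    def __mul__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        return Dual(self.val * other.val, self.val * other.der + self.der * other.val)
    __rmul__ = __mul__

coin = [F(2, 10), F(8, 10)]                          # tails = 0, heads = 1
dice = [[F(1, 6)] * 6, [F(5, 10)] + [F(1, 10)] * 5]

def f(lam_a, lam_b):
    """Network polynomial of the coins-and-dice model."""
    return sum(lam_a[a] * coin[a] * sum(lam_b[b] * dice[a][b] for b in range(6))
               for a in range(2))

def posterior_coin(b_obs):
    """Pr(A = a | B = b_obs) as (df / d lambda_a) / f, with A latent."""
    lam_b = [F(1) if b == b_obs else F(0) for b in range(6)]
    total = f([F(1), F(1)], lam_b)  # f[e] = Pr(B = b_obs)
    post = []
    for a in range(2):
        # Forward mode: seed a unit tangent on lambda_a only; all lambdas equal 1.
        lam_a = [Dual(F(1), F(1) if i == a else F(0)) for i in range(2)]
        post.append(f(lam_a, lam_b).der / total)
    return post

print(posterior_coin(5))  # [Fraction(5, 17), Fraction(12, 17)]
```

So $Pr(\text{tails}\ |\ B=6) = 5/17 \approx 0.294$, up from the prior 0.2, because the weighted die shows a 6 less often than the fair one. Forward mode needs one pass per indicator; reverse-mode differentiation would obtain every partial derivative in a single backward pass.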

grass : Pr Grass (Sprinkler & Rain) = for (s, r). bernoulli $ case s of
  {|nosprinkler=()|} -> case r of
    {|norain=()|} -> 0.0
    {|rain=()|} -> 0.8
  {|sprinkler=()|} -> case r of
    {|norain=()|} -> 0.9
    {|rain=()|} -> 0.99

And here is the architecture of the Bayes net.

def wet_naive (r' : Var Rain) (s' : Var Sprinkler) (g' : Var Grass) : Float =
  (r' ~ rain.nil) (for r. (s' ~ sprinkler.r) (for s. (g' ~ grass.(s,r)) one))

wet_naive (latent) (latent) (observed {|wet=()|})
0.44838

posterior (\x. wet_naive x (latent) (observed {|wet=()|}))
(AsDist [0.642312, 0.357688]@{norain: Unit | rain: Unit})

Example 3: Dice Counting

Here's a classic elementary probability problem. Given two standard dice rolls, what is the probability distribution over their sum?

DiceSum = Range 2 13

Helper functions for the dice sum:

def (+@+) (a:a') (b:b') : c = (((ordinal a) + (ordinal b))@_)
def roll_sum (x:Int) : DiceSum = (x - 2)@_

def two_dice (dice : Var (Dice & Dice)) (dicesum : Var DiceSum) : Float =
  (dice ~ uniform) (for (d1, d2). (dicesum ~ delta (d1 +@+ d2)) one)

Here's the result.

posterior (\m. two_dice latent m)
(AsDist [ 0.027778 , 0.055556 , 0.083333 , 0.111111 , 0.138889 , 0.166667 , 0.138889 , 0.111111 , 0.083333 , 0.055556 , 0.027778 ]@(%IntRange 2 13))

We might also ask for the probability of each pair of dice rolls given an observed sum.

support $ posterior (\m. two_dice m (observed (roll_sum 4)))
(AsList 3 [ (((0@%IntRange 1 7), (2@%IntRange 1 7)), 0.333333) , (((1@%IntRange 1 7), (1@%IntRange 1 7)), 0.333333) , (((2@%IntRange 1 7), (0@%IntRange 1 7)), 0.333333) ])
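As a cross-check, a brute-force Python enumeration (a sketch, independent of the Dex code) gives the same numbers. It prints die faces 1 to 6, whereas the Dex output above shows ordinals 0 to 5, so the Dex pairs (0, 2), (1, 1), (2, 0) correspond to (1, 3), (2, 2), (3, 1) here.

```python
from fractions import Fraction as F
from itertools import product

faces = range(1, 7)

# The pair (d1, d2) is uniform, 1/36 each, and the sum is a deterministic
# (delta) function of the pair, so Pr(sum = s) counts pairs adding to s.
sum_dist = {s: F(0) for s in range(2, 13)}
for d1, d2 in product(faces, faces):
    sum_dist[d1 + d2] += F(1, 36)

# Posterior over pairs given the sum: the likelihood is 0/1 and the prior is
# uniform, so the posterior is uniform over the consistent pairs.
observed_sum = 4
consistent = [(d1, d2) for d1, d2 in product(faces, faces) if d1 + d2 == observed_sum]
posterior = {pair: F(1, len(consistent)) for pair in consistent}

print(sum_dist[7])  # 1/6
print(posterior)    # {(1, 3): Fraction(1, 3), (2, 2): Fraction(1, 3), (3, 1): Fraction(1, 3)}
```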

Discussion - Conditional Independence

One tricky problem for discrete PPLs is modeling conditional independence. Models can be very slow to compute if we are not careful to exploit conditional independence properties such as Markov assumptions.

For example, consider a more complex version of the coin flip example with three flips. The weighted coin used for the second flip depends on the outcome of the first, and the weighted coin used for the third flip depends on the outcome of the second.

$$a \sim Pr(A)$$ $$b \sim Pr(B\ |\ A=a)$$ $$c \sim Pr(C\ |\ B=b)$$

In this example $C$ is conditionally independent of $A$ given $B$.

We can be lazy and create the distributions randomly.

coin1 : Pr Coin None = arb $ newKey 1
coin2 : Pr Coin Coin = arb $ newKey 2
coin3 : Pr Coin Coin = arb $ newKey 3

Now here is the generative process.

def coin_flip2 (a': Var Coin) (b': Var Coin) (c': Var Coin) : Float =
  (a' ~ coin1.nil) (for a. (b' ~ coin2.a) (for b. (c' ~ coin3.b) one))

As written, this process appears not to take advantage of the conditional independence in the model: the table for the final coin is built inside the for constructor over a, so naively it is recomputed for every value of a. However, a is not used in the innermost expression, so in principle Dex can lift that computation out of the loop and exploit the conditional independence.

Alternatively, we can make this explicit and do the lifting ourselves.

def coin_flip_opt2 (a': Var Coin) (b': Var Coin) (c': Var Coin) : Float =
  final_flip = for b. (c' ~ coin3.b) one
  (a' ~ coin1.nil) (for a. (b' ~ coin2.a) final_flip)

Example 4: Monty Hall Problem

Perhaps the most celebrated elementary problem in conditional probability is the Monty Hall problem. You are on a game show with three doors, one of which hides a prize, and the host asks you to pick a door. After you select a door, one of the remaining doors that does not hide the prize is removed. You are then asked whether you want to change your selection.

Doors = Fin 3
YesNo = { no:Unit | yes:Unit}
def yesno (x:Bool) : Dist YesNo = delta $ select x {|yes=()|} {|no=()|}

The generative model is relatively simple:

1. We first sample our pick and the door hiding the prize.
2. Then we consider changing our pick.
3. Finally we check whether we won. Changing wins exactly when the original pick was wrong, since the other losing door has been removed.

def monty_hall (change': Var YesNo) (win': Var YesNo) : Float =
  (one ~ uniform) (for (pick, correct): (Doors & Doors).
    (change' ~ uniform) (for change.
      win_dist = case change of
        {|yes=()|} -> yesno (pick /= correct)
        {|no=()|} -> yesno (pick == correct)
      (win' ~ win_dist) one))

To check the odds, we compute the probability of winning conditioned on changing.

{|yes=()|} ?? (posterior $ monty_hall (observed {|yes=()|}))
0.666667

And compare this to the probability of winning with no change.

{|yes=()|} ?? (posterior $ monty_hall (observed {|no=()|}))
0.333333
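The same two numbers follow from a direct enumeration over the nine equally likely (pick, correct) pairs. This Python sketch is not the Dex model itself; it assumes, as the Dex model does, that changing wins exactly when the first pick was wrong.

```python
from fractions import Fraction as F
from itertools import product

DOORS = range(3)

def win_probability(change):
    """Pr(win | change), enumerating the 9 equally likely (pick, correct) pairs."""
    wins = 0
    for pick, correct in product(DOORS, DOORS):
        # Changing wins exactly when the first pick was wrong.
        won = (pick != correct) if change else (pick == correct)
        wins += won
    return F(wins, len(DOORS) ** 2)

print(win_probability(True), win_probability(False))  # 2/3 1/3
```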

Example 5: Hidden Markov Models

We conclude with a more complex example. A hidden Markov model is one of the most widely used discrete time-series models. It models the relationship between discrete hidden states $Z$ and emissions $X$.

Z = Fin 5
X = Fin 10

It consists of three distributions: initial, transition, and emission.

initial : Pr Z None = arb $ newKey 1
emission : Pr X Z = arb $ newKey 2
transition : Pr Z Z = arb $ newKey 3

The model itself takes the following form for $m$ steps.

$$z_0 \sim \text{initial}$$ $$z_1 \sim \text{transition}(z_0)$$ $$x_1 \sim \text{emission}(z_1)$$ $$\vdots$$

This is implemented in reverse order, as in the backward algorithm.

def hmm (init': Var Z) (x': m => Var X) (z' : m => Var Z) : Float =
  (init' ~ initial.nil) $ yieldState one ( \future .
    for i:m.
      j = ((size m) - (ordinal i) - 1)@_
      future := for z.
        (x'.j ~ emission.z) (for _. (z'.j ~ transition.z) (get future)))

We can marginalize out the latent states.

hmm (observed (1@_)) (for i:(Fin 2). observed (1@_)) (for i. latent)
0.000285
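The Dex tables are drawn at random, so the number above cannot be reproduced here. Instead, the Python sketch below uses hypothetical two-state, two-symbol tables and checks that a backward pass with the same nesting as hmm (each step emits from the current state and then transitions, with all states latent) matches a brute-force sum over every hidden path. The backward pass costs $O(m|Z|^2)$ time, whereas the brute force is exponential in $m$.

```python
from itertools import product

# Hypothetical toy parameters (the notebook's tables are random).
initial = [0.6, 0.4]                    # Pr(initial state)
transition = [[0.7, 0.3], [0.2, 0.8]]   # transition[z][z_next]
emission = [[0.9, 0.1], [0.3, 0.7]]     # emission[z][x]
N = len(initial)

def backward(xs):
    """Likelihood of observed emissions xs with all states latent.
    Same nesting as hmm: step j emits xs[j] from the current state,
    then transitions; the loop runs from the last step to the first."""
    beta = [1.0] * N
    for x in reversed(xs):
        beta = [emission[z][x] * sum(transition[z][z2] * beta[z2] for z2 in range(N))
                for z in range(N)]
    return sum(initial[z] * beta[z] for z in range(N))

def brute_force(xs):
    """Sum the joint over every hidden path; exponential in len(xs)."""
    total = 0.0
    for path in product(range(N), repeat=len(xs) + 1):  # path[0]: initial state
        p = initial[path[0]]
        for j, x in enumerate(xs):
            p *= emission[path[j]][x] * transition[path[j]][path[j + 1]]
        total += p
    return total

xs = [1, 1, 0]
print(backward(xs), brute_force(xs))  # equal up to floating-point rounding
```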

Or we can compute the posterior probabilities of specific values.

posteriorTab $ \z . hmm (observed (1@_)) (for i:(Fin 2). observed (1@_)) z
[ (AsDist [0.004153, 0.162454, 0.539687, 0.069191, 0.224516])
, (AsDist [0.180291, 0.288983, 0.129116, 0.108978, 0.292633]) ]

Example 5a: HMM Monoid

We can also write an HMM using a monoid. Here we define a monoid for square matrix multiplication.

def MarkovMonoid (a:Type) : Monoid (a => a => Float) =
  M = a  -- XXX: Typing Monoid a below would quantify it over a, which we don't want
  named-instance result : Monoid (M => M => Float)
    mempty = for m1 m2. select ((ordinal m1) == (ordinal m2)) 1.0 0.0
    mcombine = \m1 m2. for i j. sum for k. m1.i.k * m2.k.j
  result

We also define a Markov version of our sampling function. Instead of summing over the usage of its result, it returns a vector of weighted probabilities; building one such vector per conditioning state gives a matrix.

def markov (lambda:Var a) (pr: Dist a) : a => Float =
  for a'. (a' ?? pr) * lambda.a'

Here we write out the HMM in a forward style. After step i, the accumulator holds the matrix of joint likelihoods for positions 1 through i.

def hmm_monoid (init': Var Z) (x': m => Var X) (z' : m => Var Z) : Float =
  scores = yieldAccum (MarkovMonoid Z) \ref .
    for i:m.
      ref += for z:Z.
        emit = (x'.i ~ emission.z) one
        emit .* (markov z'.i transition.z)
  (init' ~ initial.nil) $ for j. sum scores.j

At first glance, this seems much less efficient: the backward algorithm above only needs $O(|Z|)$ storage, whereas this approach needs $O(|Z|^2)$. In exchange, because matrix multiplication is associative, the products of the per-step matrices over the $m$ steps can in principle be computed in parallel, for example as a tree reduction.

This should give the same result as before.

hmm_monoid (observed (1@_)) (for i:(Fin 2). observed (1@_)) (for i. latent)
0.000285
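The associativity argument can be illustrated with a small Python sketch using hypothetical toy tables (not the notebook's random ones). Each step contributes a matrix whose row z holds emission(x | z) times transition(z' | z), as in hmm_monoid with latent states. Folding the matrices left to right and reducing them pairwise as a tree give the same likelihood, and the products within one level of the tree are independent of each other.

```python
# Hypothetical toy parameters (the notebook's tables are random).
initial = [0.6, 0.4]                    # Pr(initial state)
transition = [[0.7, 0.3], [0.2, 0.8]]   # transition[z][z_next]
emission = [[0.9, 0.1], [0.3, 0.7]]     # emission[z][x]
N = len(initial)

def matmul(a, b):
    """Monoid combine: square matrix product."""
    return [[sum(a[i][k] * b[k][j] for k in range(N)) for j in range(N)]
            for i in range(N)]

def identity():
    """Monoid identity element."""
    return [[1.0 if i == j else 0.0 for j in range(N)] for i in range(N)]

def step_matrix(x):
    """Row z: emit x from state z, then move to each next state (states latent)."""
    return [[emission[z][x] * transition[z][z2] for z2 in range(N)] for z in range(N)]

def sequential(mats):
    """Left-to-right fold, one matrix per step."""
    acc = identity()
    for m in mats:
        acc = matmul(acc, m)
    return acc

def tree(mats):
    """Pairwise reduction; correct because matrix multiplication is associative.
    Products within one level do not depend on each other."""
    if not mats:
        return identity()
    while len(mats) > 1:
        paired = [matmul(mats[i], mats[i + 1]) for i in range(0, len(mats) - 1, 2)]
        if len(mats) % 2:
            paired.append(mats[-1])  # odd one out carries over, keeping the order
        mats = paired
    return mats[0]

def likelihood(scores):
    """Weight each row by the initial distribution and sum (init' latent)."""
    return sum(initial[z] * sum(scores[z]) for z in range(N))

mats = [step_matrix(x) for x in [1, 1, 0, 1, 0]]
print(likelihood(sequential(mats)), likelihood(tree(mats)))  # equal up to rounding
```

The tree has depth about $\log_2 m$, at the cost of $O(|Z|^3)$ work per matrix product instead of the $O(|Z|^2)$ vector update of the backward pass.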

Unfortunately, the monoid code does not yet support automatic differentiation, so posteriors still have to be computed with the backward-style hmm:

posteriorTab $ \z . hmm (observed (1@_)) (for i:(Fin 2). observed (1@_)) z
[ (AsDist [0.004153, 0.162454, 0.539687, 0.069191, 0.224516])
, (AsDist [0.180291, 0.288983, 0.129116, 0.108978, 0.292633]) ]

Fancier Distributions

Sampling without replacement zeroes out the outcomes that have already been drawn and renormalizes the rest.

def without_replacement (y: n=>m) (AsDist x: Dist m) : Dist m =
  renorm = sum for n'. x.(y.n')
  AsDist $ for m'. case (any for n'. (ordinal (y.n')) == (ordinal m')) of
    False -> (x.m' / (1.0 - renorm))
    True -> 0.0
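In Python, the same operation on a list-of-probabilities distribution might look like the sketch below, a hypothetical helper mirroring the Dex one. One deliberate difference: it collects the drawn indices into a set, so a repeated index is removed and subtracted from the normalizer only once.

```python
from fractions import Fraction as F

def without_replacement(drawn, dist):
    """Zero out already-drawn outcomes and renormalize the rest.
    Hypothetical mirror of the Dex helper; a set avoids double-counting."""
    removed = set(drawn)
    mass = sum(dist[i] for i in removed)
    return [F(0) if i in removed else p / (1 - mass) for i, p in enumerate(dist)]

uniform6 = [F(1, 6)] * 6
print(without_replacement([0, 2], uniform6))
# outcomes 0 and 2 drop to 0; each of the other four becomes 1/4
```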

Probability Exercises (from Stat 110 textbook)

A college has 10 (non-overlapping) time slots for its 10 courses, and blithely assigns courses to time slots randomly and independently. A student randomly chooses 3 of the courses to enroll in. What is the probability that there is a conflict in the student’s schedule?

Slot = Fin 10
def courses (conflict:Var YesNo): Float = (one ~ uniform) (for (i,j,k):(Slot& Slot& Slot). (conflict ~ yesno ((i == j) || (j == k) || (i == k))) one)
courses (observed {|yes=()|})
0.28
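A brute-force check in Python (a sketch, not the Dex code): enumerate all $10^3$ equally likely slot assignments for the three chosen courses and count those with a repeated slot. The complement gives the closed form $1 - \frac{10 \cdot 9 \cdot 8}{10^3} = 0.28$.

```python
from fractions import Fraction as F
from itertools import product

SLOTS = range(10)

# Each of the 3 chosen courses sits in an independent, uniformly random slot.
conflicts = sum(1 for i, j, k in product(SLOTS, repeat=3) if i == j or j == k or i == k)
p_conflict = F(conflicts, 10 ** 3)
print(p_conflict, 1 - F(10 * 9 * 8, 10 ** 3))  # 7/25 7/25
```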

A certain family has 6 children, consisting of 3 boys and 3 girls. Assuming that all birth orders are equally likely, what is the probability that the 3 eldest children are the 3 girls?

Children = Fin 6
def birth (event:Var YesNo): Float = (one ~ uniform) (for i:Children. (one ~ without_replacement [i] uniform) (for j:Children. (one ~ without_replacement [i, j] uniform) (for k:Children. (event ~ yesno ((ordinal i < 3) && (ordinal j < 3) && (ordinal k < 3))) one)))
birth (observed {|yes=()|})
0.05
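The same answer follows from enumerating birth orders directly. This Python sketch treats children 0 to 2 as the girls (matching ordinal < 3 in the Dex code) and counts the orderings of all six children whose first three entries are exactly the girls. It also agrees with the sequential product $\frac{3}{6} \cdot \frac{2}{5} \cdot \frac{1}{4} = \frac{1}{20}$.

```python
from fractions import Fraction as F
from itertools import permutations

# A birth order is a permutation of the six children, eldest first;
# children 0, 1, 2 are the girls.
orders = list(permutations(range(6)))
hits = sum(1 for order in orders if set(order[:3]) == {0, 1, 2})
p = F(hits, len(orders))
print(p, F(3, 6) * F(2, 5) * F(1, 4))  # 1/20 1/20
```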