Dex: Differential Probabilistic Inference
This notebook develops an unconventional approach to probabilistic inference using Dex tables and auto-differentiation. It is based loosely on the approach described in:
A Differential Approach to Inference in Bayesian Networks, Adnan Darwiche (2003)
This approach can be thought of as a probabilistic programming language (PPL) in the sense that inference is separated from the modeling language. However, it does not require interrupting standard control flow.
Running Example 1: Coins and Dice
Let us start with a simple probabilistic modeling example to establish some notation.
In this example we have a coin and two weighted dice. We first flip the coin; if it is heads we roll die 1, and if it is tails we roll die 2.
This defines a generative process over two random variables $\mathbf{X} = \{ A, B \}$, the coin flip and the die roll respectively. We can write the process explicitly as,
$$a \sim Pr(A)$$ $$ b \sim Pr(B\ |\ A=a) $$
Probability Combinators
A discrete probability distribution is a normalized table of probabilities.
Distributions are easy to create. Here are a couple simple ones.
And they are displayed by their support.
Expectations can be taken over arbitrary tables.
To represent conditional probabilities such as $ Pr(B \ |\ A)$ we define a type alias.
With this machinery we can define distributions for the coin and the dice.
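As a rough illustration of this machinery, here is a minimal sketch in JAX; the names and the particular weights for the coin and the dice are assumptions for illustration, not the definitions in the Dex cells.

```python
import jax.numpy as jnp

def normalize(table):
    # A discrete distribution is a table of non-negative weights summing to one.
    return table / jnp.sum(table)

def expect(dist, values):
    # Expectation of an arbitrary table of values under a distribution.
    return jnp.sum(dist * values)

# Pr(A): the coin, over {heads, tails}.
coin = normalize(jnp.array([0.2, 0.8]))

# Pr(B | A): a conditional distribution is a table of distributions,
# here one weighted die per coin outcome.
dice = jnp.stack([normalize(jnp.array([1., 1., 1., 1., 1., 5.])),   # die 1 (heads)
                  normalize(jnp.array([5., 1., 1., 1., 1., 1.]))])  # die 2 (tails)
```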
Attempt 1: Observations and Marginalization
This allows us to compute the probability of any full observation from our model. $$Pr(A=a, B=b) = Pr(B=b\ |\ A=a) Pr(A=a) $$
However, this assumes that we have full observation of our variables. What if the coin is latent? This requires a sum. $$Pr(B) = \sum_a Pr(B\ |\ A=a) Pr(A=a) $$
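Continuing the JAX sketch above (with the illustrative coin and dice tables), these two cases end up as separate functions:

```python
# Two separate inference routines for the same model.

def joint(a, b):
    # Pr(A=a, B=b) = Pr(B=b | A=a) Pr(A=a), for a fully observed draw.
    return coin[a] * dice[a, b]

def marginal_b(b):
    # Pr(B=b) = sum_a Pr(B=b | A=a) Pr(A=a), when the coin is latent.
    return jnp.sum(coin * dice[:, b])

p_heads_and_six = joint(0, 5)     # heads followed by a roll of six
p_six           = marginal_b(5)   # a roll of six, coin marginalized out
```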
But now we have two separate functions for the same model! This feels unnecessary and bug-prone.
Attempt 2: Indicator Variables and the Network Polynomial
In order to make things simpler we will introduce explicit indicator variables $\lambda$ to the modeling language.
These can either be observed or latent. If a random variable is observed, we use an indicator that is 1 at the observed value and 0 elsewhere. The expectation over the indicator gives,
$$ Pr(A=a) = E_{a' \sim Pr(A)} [\lambda_{a'}] $$
If it is latent, the indicator is one everywhere.
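A minimal JAX sketch of the two kinds of indicator (the helper names observed and latent are illustrative):

```python
import jax.numpy as jnp

def observed(value, size):
    # Indicator for an observed variable: 1 at the observed state, 0 elsewhere.
    return jnp.zeros(size).at[value].set(1.0)

def latent(size):
    # Indicator for a latent variable: 1 everywhere, so no state is ruled out.
    return jnp.ones(size)
```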
The probability chain rule tells us that we can propagate conditioning. $$\sum_{b} Pr(A,\ B = b) = Pr(A) \sum_b Pr(B=b\ | A)$$
This implies that the expectation of these indicators factors as well.
$$E_{a, b\sim Pr(A,\ B)} [\lambda_a \lambda_b] = E_{a\sim Pr(A)}\left[ \lambda_a\, E_{b \sim Pr(B\ |\ A=a)} [\lambda_b] \right] $$
We can write one step of this chain rule really cleanly in Dex.
This allows us to finally write our model down in an extremely clean form.
$$a \sim Pr(A)$$ $$ b \sim Pr(B\ |\ A=a) $$
Now we can easily reproduce all the results above.
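Continuing the JAX sketch, here is one way this might look: a sample combinator takes a distribution and a continuation that scores the rest of the model, and the whole model becomes a single function of the indicators (the names sample and model are illustrative).

```python
# One chain-rule step, and the model as a single function of the indicators.

def sample(dist, f):
    # E_{x ~ dist}[ f(x) ]: weight the continuation f, which scores the
    # indicator for x and the rest of the model, by the distribution.
    return sum(p * f(i) for i, p in enumerate(dist))

def model(lam_a, lam_b):
    # The coin-and-die model as one function of the evidence indicators.
    return sample(coin, lambda a: lam_a[a] *
           sample(dice[a], lambda b: lam_b[b]))

p_heads_and_six = model(observed(0, 2), observed(5, 6))  # Pr(A=heads, B=6)
p_six           = model(latent(2), observed(5, 6))       # Pr(B=6), coin latent
```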
This representation for joint distributions is known as the network polynomial. This is a multi-linear function that uses indicator variables to represent data observations.
$$ f(\lambda) = \sum_{\mathbf{x}} \prod_{x, \mathbf{u}\in \mathbf{x}} \lambda_x \theta_{x|\mathbf{u}} $$
Here the $\theta$ are the model parameters; they play the same role as the tables above. The $\lambda$ are evidence indicators which represent the states of the variable instantiations.
The network polynomial can be used to compute marginal probabilities of any subset of variables. Let $\mathbf{e}$ be an observation of some subset of $\mathbf{X}$. Darwiche shows that
$$f[\mathbf{e}] = p(\mathbf{E} = \mathbf{e})$$
where $f[\mathbf{e}]$ denotes the polynomial evaluated with each $\lambda$ set to 1 if it is consistent (non-contradictory) with $\mathbf{e}$ and to 0 otherwise. Let's look at an example.
Differential Inference
The network polynomial is a convenient method for computing probabilities, but what makes it particularly useful is that it allows us to compute posterior probabilities simply using derivatives.
For example, consider the probability of the coin flip given an observation of the die roll. We can compute this using Bayes' Rule.
$$Pr(A | B=b) \propto Pr(B=b | A) Pr(A)$$
However, using the network polynomial we can compute this same term purely with derivatives. Computing partial derivatives with respect to the indicators directly yields joint terms.
$$\frac{\partial f[\mathbf{e}]}{\partial \lambda_x} = Pr(x, \mathbf{e})$$
This implies that the derivative of the log polynomial yields posterior terms.
$$\frac{\partial \log f[\mathbf{e}]}{\partial \lambda_x} = Pr(x\ |\ \mathbf{e})$$
Let us try this out. We can compute the posterior probability of the coin flip after observing the die roll.
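Continuing the JAX sketch, the posterior over the coin falls out of jax.grad applied to the log of the model:

```python
import jax

# Posterior over the coin given an observed roll of six: differentiate the
# log of the model with respect to the coin's indicators, with the coin latent.
def log_f(lam_a):
    return jnp.log(model(lam_a, observed(5, 6)))

posterior_a = jax.grad(log_f)(latent(2))   # the table Pr(A = a | B = 6)
```

Because the gradient is taken with respect to the whole indicator table, it returns the posterior over every value of the coin at once.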
And this yields exactly the term above! This is really neat: it doesn't require any model-specific inference code.
We can generalize this to compute a table of distributions.
Example 2: Bayes Nets
A classic example in probabilistic modeling is the wet grass Bayes net. In this example we need to infer the factors that could have led to the grass being wet.
More details on the problem are given here.
We now define the tables above.
And the architecture of the Bayes net.
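Sketched in JAX, the network might look as follows. The conditional probability tables here are the illustrative numbers from the common textbook statement of the problem, not necessarily the ones defined in the Dex cells.

```python
import jax
import jax.numpy as jnp

# Wet-grass network sketch. Index 0 = false, 1 = true throughout.
rain      = jnp.array([0.8, 0.2])         # Pr(Rain)
sprinkler = jnp.array([[0.6, 0.4],        # Pr(Sprinkler | Rain=false)
                       [0.99, 0.01]])     # Pr(Sprinkler | Rain=true)
grass_wet = jnp.array([[[1.0, 0.0],       # Pr(Wet | Sprinkler=f, Rain=f)
                        [0.2, 0.8]],      # Pr(Wet | Sprinkler=f, Rain=t)
                       [[0.1, 0.9],       # Pr(Wet | Sprinkler=t, Rain=f)
                        [0.01, 0.99]]])   # Pr(Wet | Sprinkler=t, Rain=t)

def sample(dist, f):
    return sum(p * f(i) for i, p in enumerate(dist))

def model(lam_r, lam_s, lam_w):
    # Network polynomial following the architecture of the net.
    return sample(rain, lambda r: lam_r[r] *
           sample(sprinkler[r], lambda s: lam_s[s] *
           sample(grass_wet[s, r], lambda w: lam_w[w])))

# Posterior over Rain given that the grass is observed to be wet.
ones = jnp.ones(2)
wet = jnp.zeros(2).at[1].set(1.0)
posterior_rain = jax.grad(
    lambda lam_r: jnp.log(model(lam_r, ones, wet)))(ones)
```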
Example 3: Dice Counting
Here's a classic elementary probability problem. Given two standard dice rolls, what is the probability distribution over their sum?
Helper functions for Dice sum
Here's the result.
We might also ask what the probability of the dice rolls is given an output value.
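A JAX sketch of both questions, assuming fair six-sided dice; encoding the sum as indices 0 through 10 for the values 2 through 12 is an illustrative choice.

```python
import jax
import jax.numpy as jnp

die = jnp.ones(6) / 6.0          # a fair six-sided die

def sample(dist, f):
    return sum(p * f(i) for i, p in enumerate(dist))

def model(lam_roll1, lam_roll2, lam_sum):
    # Indicators for each roll and for the sum (value 2..12 -> index 0..10).
    return sample(die, lambda i: lam_roll1[i] *
           sample(die, lambda j: lam_roll2[j] * lam_sum[i + j]))

ones6, ones11 = jnp.ones(6), jnp.ones(11)

# Distribution over the sum: derivatives of f with both rolls latent.
sum_dist = jax.grad(lambda lam_s: model(ones6, ones6, lam_s))(ones11)

# Posterior over the first roll given that the sum is 7 (index 5).
seven = jnp.zeros(11).at[5].set(1.0)
roll1_posterior = jax.grad(
    lambda lam_1: jnp.log(model(lam_1, ones6, seven)))(ones6)
```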
Discussion - Conditional Independence
One tricky problem for discrete PPLs is modeling conditional independence. Models can be very slow to compute if we are not careful to exploit conditional independence properties such as Markov assumptions.
For example, let us consider a more complex version of the coin flip example. We will flip three times. The choice of the second weighted coin depends on the first. The choice of the third weighted coin depends on the second.
$$a \sim Pr(A)$$ $$ b \sim Pr(B\ |\ A=a) $$ $$ c \sim Pr(C\ |\ B=b) $$
In this example $C$ is conditionally independent of $A$ given $B$.
We can be lazy and create the distributions randomly.
Now here is the generative process.
Note that, as written, this process looks like it does not take advantage of the conditional independence property of the model. The construction of the final coin sits inside the for constructor over the first coin, which itself contains the inner expectation. However, Dex knows that a is not used in the innermost construct. In theory it can lift that computation out of the loop and exploit the conditional independence.
Alternatively we can make this explicit and do the lifting ourselves.
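A JAX sketch of the two variants, with randomly generated conditional tables standing in for the Dex ones:

```python
import jax
import jax.numpy as jnp

# Arbitrary conditional tables for the three-coin chain.
ka, kb, kc = jax.random.split(jax.random.PRNGKey(0), 3)
pr_a   = jax.nn.softmax(jax.random.normal(ka, (2,)))             # Pr(A)
pr_b_a = jax.nn.softmax(jax.random.normal(kb, (2, 2)), axis=-1)  # Pr(B | A)
pr_c_b = jax.nn.softmax(jax.random.normal(kc, (2, 2)), axis=-1)  # Pr(C | B)

def sample(dist, f):
    return sum(p * f(i) for i, p in enumerate(dist))

def model_naive(lam_a, lam_b, lam_c):
    # Written directly from the generative process: the expectation over C
    # sits inside the loop over A even though it only depends on B.
    return sample(pr_a, lambda a: lam_a[a] *
           sample(pr_b_a[a], lambda b: lam_b[b] *
           sample(pr_c_b[b], lambda c: lam_c[c])))

def model_lifted(lam_a, lam_b, lam_c):
    # Exploit that C is independent of A given B: compute the inner
    # expectation over C once per value of b, outside the loop over a.
    inner = jnp.stack([sample(pr_c_b[b], lambda c: lam_c[c]) for b in range(2)])
    return sample(pr_a, lambda a: lam_a[a] *
           sample(pr_b_a[a], lambda b: lam_b[b] * inner[b]))
```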
Example 4: Monty Hall Problem
Perhaps the most celebrated elementary problem in conditional probability is the Monty Hall problem.
You are on a game show. The host asks you to pick a door at random to win a prize. After selecting a door, one of the remaining doors (without the prize) is removed. You are asked if you want to change your selection...
The generative model is relatively simple:
- We will first sample our pick and the prize door.
- Then we will consider changing our pick.
- Finally we will see if we won.
To check the odds we will compute the probability of winning conditioned on changing.
And compare it to the probability of winning with no change.
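A JAX sketch of the computation; here the decision to change is passed in as a plain parameter rather than modeled as a random variable, and the names and door numbering are illustrative.

```python
import jax.numpy as jnp

uniform3 = jnp.ones(3) / 3.0    # doors are numbered 0, 1, 2

def sample(dist, f):
    return sum(p * f(i) for i, p in enumerate(dist))

def host(prize, pick):
    # The host opens a door uniformly among those that are neither the prize
    # nor our pick.
    valid = jnp.array([float(d != prize and d != pick) for d in range(3)])
    return valid / jnp.sum(valid)

def model(change, lam_win):
    # Sample the prize door and our pick, let the host open a door, then
    # either keep our pick or switch to the remaining closed door.
    def outcome(prize, pick, opened):
        final = ({0, 1, 2} - {pick, opened}).pop() if change else pick
        return lam_win[int(final == prize)]
    return sample(uniform3, lambda prize:
           sample(uniform3, lambda pick:
           sample(host(prize, pick), lambda opened: outcome(prize, pick, opened))))

win = jnp.array([0.0, 1.0])            # indicator for "we won"
p_win_change = model(True, win)        # -> 2/3
p_win_stay   = model(False, win)       # -> 1/3
```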
Example 5: Hidden Markov Models
Finally we conclude with a more complex example. A hidden Markov model is one of the most widely used discrete time series models. It models the relationship between discrete hidden states $Z$ and emissions $X$.
It consists of three distributions: initial, transition, and emission.
The model itself takes the following form for $m$ steps. $$ z_0 \sim \text{initial}$$ $$ z_1 \sim \text{transition}(z_0)$$ $$ x_1 \sim \text{emission}(z_1)$$ $$ \ldots $$
This is implemented in reverse order for clarity (backward algorithm).
We can marginalize out over latents.
Or we can compute the posterior probabilities of specific values.
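A JAX sketch of the backward-style computation with randomly generated parameters; the sizes, names, and observation sequence are illustrative.

```python
import jax
import jax.numpy as jnp

Z, X, m = 3, 4, 5                      # hidden states, emission symbols, steps
k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
initial    = jax.nn.softmax(jax.random.normal(k1, (Z,)))
transition = jax.nn.softmax(jax.random.normal(k2, (Z, Z)), axis=-1)
emission   = jax.nn.softmax(jax.random.normal(k3, (Z, X)), axis=-1)
obs = jax.random.randint(k4, (m,), 0, X)     # observed emissions x_1 .. x_m

def sample(dist, f):
    return sum(p * f(i) for i, p in enumerate(dist))

def model(lam_x):
    # Network polynomial with the hidden chain summed out, written in the
    # backward style: score step i, then recurse on the rest of the chain.
    def rest(i, z_prev):
        if i == m:
            return 1.0
        return sample(transition[z_prev], lambda z:
               sample(emission[z], lambda x: lam_x[i, x]) * rest(i + 1, z))
    return sample(initial, lambda z0: rest(0, z0))

lam_obs = jax.nn.one_hot(obs, X)       # indicators for the observed emissions
likelihood = model(lam_obs)            # Pr(x_1, ..., x_m)

# Posterior over the emission at position 2 (0-indexed), given the others.
lam_free = lam_obs.at[2].set(jnp.ones(X))
posterior_x2 = jax.grad(lambda lam: jnp.log(model(lam)))(lam_free)[2]
```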
Example 5a. HMM Monoid
We can also write out an HMM using a Monoid. Here we define a monoid for square matrix multiplication.
We also define a Markov version of our sample function. Instead of summing out over the usage of its result, it constructs a matrix rather than a vector.
Here we write out the HMM using a forward-style approach. Each time through the loop, the accumulator represents the matrix of the joint likelihood from position $1$ to $i$.
At first glance, this seems much less efficient. Above we had an algorithm that only required $O(Z)$ storage, whereas this requires $O(Z^2)$. In exchange, this approach can in theory be parallelized over the $m$ steps of the sequence.
This should give the same result as before.
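The same computation sketched in JAX with explicit matrix products (reusing the parameters and indicators from the sketch above):

```python
import jax.numpy as jnp

# Each step contributes a Z x Z matrix
#   M_i[z_prev, z] = Pr(z | z_prev) * E_{x ~ emission(z)}[lambda_{x_i}]
# and the likelihood is initial^T (M_1 M_2 ... M_m) 1. Matrix multiplication
# is associative, so in principle the product can be computed as a parallel
# (monoid) scan over the m steps.

def step_matrix(lam_xi):
    return transition * (emission @ lam_xi)[None, :]

def likelihood_forward(lam_x):
    acc = jnp.eye(Z)                  # accumulated joint-likelihood matrix
    for i in range(m):
        acc = acc @ step_matrix(lam_x[i])
    return initial @ acc @ jnp.ones(Z)

forward_likelihood = likelihood_forward(lam_obs)   # should equal `likelihood`
```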
Unfortunately though, the code for monoids does not yet allow for auto-differentiation.
Fancier Distributions
Probability Exercises (from Stat 110 textbook)
A college has 10 (non-overlapping) time slots for its 10 courses, and blithely assigns courses to time slots randomly and independently. A student randomly chooses 3 of the courses to enroll in. What is the probability that there is a conflict in the student’s schedule?
A certain family has 6 children, consisting of 3 boys and 3 girls. Assuming that all birth orders are equally likely, what is the probability that the 3 eldest children are the 3 girls?