{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"button": false,
"deletable": true,
"new_sheet": false,
"run_control": {
"read_only": false
}
},
"source": [
"Probability theory allows us to compute the likelihood of certain events, given assumptioons about the components of the event. A Bayesian network, or Bayes net for short, is a data structure to represent a joint probability distribution over several random variables, and do inference on it. \n",
"\n",
"As an example, here is a network with five random variables, each with its conditional probability table, and with arrows from parent to child variables. The story, from Judea Pearl, is that there is a house burglar alarm, which can be triggered by either a burglary or an earthquake. If the alarm sounds, one or both of the neighbors, John and Mary, might call the owwner to say the alarm is sounding.\n",
"\n",
"<p><img src=\"http://norvig.com/ipython/burglary2.jpg\">\n",
"\n",
"We implement this with the help of seven Python classes:\n",
"\n",
"\n",
"## `BayesNet()`\n",
"\n",
"A `BayesNet` is a graph (as in the diagram above) where each node represents a random variable, and the edges are parent→child links. You can construct an empty graph with `BayesNet()`, then add variables one at a time with the method call `.add(`*variable_name, parent_names, cpt*`)`, where the names are strings, and each of the `parent_names` must already have been `.add`ed.\n",
"\n",
"## `Variable(`*name, cpt, parents*`)`\n",
"\n",
"A random variable; the ovals in the diagram above. The value of a variable depends on the value of the parents, in a probabilistic way specified by the variable's conditional probability table (CPT). Given the parents, the variable is independent of all the other variables. For example, if I know whether *Alarm* is true or false, then I know the probability of *JohnCalls*, and evidence about the other variables won't give me any more information about *JohnCalls*. Each row of the CPT uses the same order of variables as the list of parents.\n",
"We will only allow variables with a finite discrete domain; not continuous values. \n",
"\n",
"## `ProbDist(`*mapping*`)`<br>`Factor(`*mapping*`)`\n",
"A probability distribution is a mapping of `{outcome: probability}` for every outcome of a random variable. \n",
"You can give `ProbDist` the same arguments that you would give to the `dict` initializer, for example\n",
"`ProbDist(sun=0.6, rain=0.1, cloudy=0.3)`.\n",
"As a shortcut for Boolean Variables, you can say `ProbDist(0.95)` instead of `ProbDist({T: 0.95, F: 0.05})`. \n",
"In a probability distribution, every value is between 0 and 1, and the values sum to 1.\n",
"A `Factor` is similar to a probability distribution, except that the values need not sum to 1. Factors\n",
"are used in the variable elimination inference method.\n",
"A mapping of `{Variable: value, ...}` pairs, describing the exact values for a set of variables—the things we know for sure.\n",
"A conditional probability table (or *CPT*) describes the probability of each possible outcome value of a random variable, given the values of the parent variables. A `CPTable` is a a mapping, `{tuple: probdist, ...}`, where each tuple lists the values of each of the parent variables, in order, and each probability distribution says what the possible outcomes are, given those values of the parents. The `CPTable` for *Alarm* in the diagram above would be represented as follows:\n",
" CPTable({(T, T): .95,\n",
" (T, F): .94,\n",
" (F, T): .29,\n",
" (F, F): .001},\n",
" [Burglary, Earthquake])\n",
" \n",
"How do you read this? Take the second row, \"`(T, F): .94`\". This means that when the first parent (`Burglary`) is true, and the second parent (`Earthquake`) is fale, then the probability of `Alarm` being true is .94. Note that the .94 is an abbreviation for `ProbDist({T: .94, F: .06})`.\n",
" \n",
"## `T = Bool(True); F = Bool(False)`\n",
"When I used `bool` values (`True` and `False`), it became hard to read rows in CPTables, because the columns didn't line up:\n",
"\n",
" (True, True, False, False, False)\n",
" (False, False, False, False, True)\n",
" (True, False, False, True, True)\n",
" \n",
"Therefore, I created the `Bool` class, with constants `T` and `F` such that `T == True` and `F == False`, and now rows are easier to read:\n",
" (T, T, F, F, F)\n",
" (F, F, F, F, T)\n",
" (T, F, F, T, T)\n",
" \n",
"Here is the code for these classes:"
]
},
{
"cell_type": "code",
"metadata": {
"button": false,
"collapsed": true,
"deletable": true,
"new_sheet": false,
"run_control": {
"read_only": false
}
},
"outputs": [],
"source": [
"from collections import defaultdict, Counter\n",
"import itertools\n",
"import math\n",
"import random\n",
"\n",
" \"Bayesian network: a graph of variables connected by parent links.\"\n",
" \n",
" def __init__(self): \n",
" self.variables = [] # List of variables, in parent-first topological sort order\n",
" self.lookup = {} # Mapping of {variable_name: variable} pairs\n",
" \n",
" def add(self, name, parentnames, cpt):\n",
" \"Add a new Variable to the BayesNet. Parentnames must have been added previously.\"\n",
" parents = [self.lookup[name] for name in parentnames]\n",
" self.variables.append(var)\n",
" self.lookup[name] = var\n",
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
" return self\n",
" \n",
"class Variable(object):\n",
" \"A discrete random variable; conditional on zero or more parent Variables.\"\n",
" \n",
" def __init__(self, name, cpt, parents=()):\n",
" \"A variable has a name, list of parent variables, and a Conditional Probability Table.\"\n",
" self.__name__ = name\n",
" self.parents = parents\n",
" self.cpt = CPTable(cpt, parents)\n",
" self.domain = set(itertools.chain(*self.cpt.values())) # All the outcomes in the CPT\n",
" \n",
" def __repr__(self): return self.__name__\n",
" \n",
"class Factor(dict): \"An {outcome: frequency} mapping.\"\n",
"\n",
"class ProbDist(Factor):\n",
" \"\"\"A Probability Distribution is an {outcome: probability} mapping. \n",
" The values are normalized to sum to 1.\n",
" ProbDist(0.75) is an abbreviation for ProbDist({T: 0.75, F: 0.25}).\"\"\"\n",
" def __init__(self, mapping=(), **kwargs):\n",
" if isinstance(mapping, float):\n",
" mapping = {T: mapping, F: 1 - mapping}\n",
" self.update(mapping, **kwargs)\n",
" normalize(self)\n",
" \n",
"class Evidence(dict): \n",
" \"A {variable: value} mapping, describing what we know for sure.\"\n",
" \n",
"class CPTable(dict):\n",
" \"A mapping of {row: ProbDist, ...} where each row is a tuple of values of the parent variables.\"\n",
" \n",
" def __init__(self, mapping, parents=()):\n",
" \"\"\"Provides two shortcuts for writing a Conditional Probability Table. \n",
" With no parents, CPTable(dist) means CPTable({(): dist}).\n",
" With one parent, CPTable({val: dist,...}) means CPTable({(val,): dist,...}).\"\"\"\n",
" if len(parents) == 0 and not (isinstance(mapping, dict) and set(mapping.keys()) == {()}):\n",
" mapping = {(): mapping}\n",
" for (row, dist) in mapping.items():\n",
" if len(parents) == 1 and not isinstance(row, tuple): \n",
" row = (row,)\n",
" self[row] = ProbDist(dist)\n",
"\n",
"class Bool(int):\n",
" \"Just like `bool`, except values display as 'T' and 'F' instead of 'True' and 'False'\"\n",
" __str__ = __repr__ = lambda self: 'T' if self else 'F'\n",
" \n",
"T = Bool(True)\n",
"F = Bool(False)"
]
},
{
"cell_type": "markdown",
]
},
{
"cell_type": "code",
},
"outputs": [],
"source": [
"def P(var, evidence={}):\n",
" \"The probability distribution for P(variable | evidence), when all parent variables are known (in evidence).\"\n",
" row = tuple(evidence[parent] for parent in var.parents)\n",
" return var.cpt[row]\n",
"def normalize(dist):\n",
" \"Normalize a {key: value} distribution so values sum to 1.0. Mutates dist and returns it.\"\n",
" total = sum(dist.values())\n",
" for key in dist:\n",
" dist[key] = dist[key] / total\n",
" assert 0 <= dist[key] <= 1, \"Probabilities must be between 0 and 1.\"\n",
" return dist\n",
"\n",
"def sample(probdist):\n",
" \"Randomly sample an outcome from a probability distribution.\"\n",
" r = random.random() # r is a random point in the probability distribution\n",
" c = 0.0 # c is the cumulative probability of outcomes seen so far\n",
" for outcome in probdist:\n",
" c += probdist[outcome]\n",
" if r <= c:\n",
" return outcome\n",
" \n",
"def globalize(mapping):\n",
" \"Given a {name: value} mapping, export all the names to the `globals()` namespace.\"\n",
" globals().update(mapping)"
]
},
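The cumulative walk used by `sample` can be sanity-checked empirically. The following is a minimal sketch, not part of the notebook's own code: it assumes a plain `dict` stands in for a `ProbDist`, and the names `sample_cumulative` and `sample_with_choices` are hypothetical. The second function swaps in the standard library's `random.choices` as an alternative way to make a weighted draw.

```python
import random
from collections import Counter

def sample_cumulative(probdist, rng=random):
    "Walk the outcomes, accumulating probability until it reaches a random point r."
    r = rng.random()
    c = 0.0
    outcome = None
    for outcome in probdist:
        c += probdist[outcome]
        if r <= c:
            return outcome
    return outcome  # Fallback if round-off leaves the running total slightly below r

def sample_with_choices(probdist, rng=random):
    "Hypothetical alternative: let random.choices perform the weighted draw."
    outcomes = list(probdist)
    weights = [probdist[o] for o in outcomes]
    return rng.choices(outcomes, weights=weights, k=1)[0]

rng = random.Random(42)  # Seeded so that the run is repeatable
dist = {'sun': 0.6, 'rain': 0.1, 'cloudy': 0.3}
N = 20000
counts_a = Counter(sample_cumulative(dist, rng) for _ in range(N))
counts_b = Counter(sample_with_choices(dist, rng) for _ in range(N))
for outcome in dist:
    print(outcome, counts_a[outcome] / N, counts_b[outcome] / N)
```

With enough draws, both columns of observed frequencies should sit close to the probabilities in `dist`.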
{
"cell_type": "markdown",
"Here are some examples of using the classes:"
]
},
{
"cell_type": "code",
},
"outputs": [],
"source": [
"# Example random variable: Earthquake:\n",
"# An earthquake occurs on 0.002 of days, independent of any other variables.\n",
"Earthquake = Variable('Earthquake', 0.002)"
]
},
{
"cell_type": "code",
},
"outputs": [
{
"data": {
"text/plain": [
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# The probability distribution for Earthquake\n",
"P(Earthquake)"
]
},
{
"cell_type": "code",
},
"outputs": [
{
"data": {
"text/plain": [
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Get the probability of a specific outcome by subscripting the probability distribution\n",
"P(Earthquake)[T]"
"cell_type": "code",
"execution_count": 6,
"outputs": [
{
"data": {
"text/plain": [
"F"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"# Randomly sample from the distribution:\n",
"sample(P(Earthquake))"
]
},
{
"cell_type": "code",
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"Counter({F: 99793, T: 207})"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
],
"source": [
"# Randomly sample 100,000 times, and count up the results:\n",
"Counter(sample(P(Earthquake)) for i in range(100000))"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Two equivalent ways of specifying the same Boolean probability distribution:\n",
"assert ProbDist(0.75) == ProbDist({T: 0.75, F: 0.25})"
]
},
{
"cell_type": "code",
},
"outputs": [
{
"data": {
"text/plain": [
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Two equivalent ways of specifying the same non-Boolean probability distribution:\n",
"assert ProbDist(win=15, lose=3, tie=2) == ProbDist({'win': 15, 'lose': 3, 'tie': 2})\n",
"ProbDist(win=15, lose=3, tie=2)"
"cell_type": "code",
"execution_count": 10,
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"{'a': 1, 'b': 2, 'c': 3, 'd': 4}"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
],
"source": [
"# The difference between a Factor and a ProbDist--the ProbDist is normalized:\n",
"Factor(a=1, b=2, c=3, d=4)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"collapsed": false
"outputs": [
{
"data": {
"text/plain": [
"{'a': 0.1, 'b': 0.2, 'c': 0.3, 'd': 0.4}"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"ProbDist(a=1, b=2, c=3, d=4)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"Here is how we define the Bayes net from the diagram above:"
]
},
{
"cell_type": "code",
},
"outputs": [],
"source": [
"alarm_net = (BayesNet()\n",
" .add('Burglary', [], 0.001)\n",
" .add('Earthquake', [], 0.002)\n",
" .add('Alarm', ['Burglary', 'Earthquake'], {(T, T): 0.95, (T, F): 0.94, (F, T): 0.29, (F, F): 0.001})\n",
" .add('JohnCalls', ['Alarm'], {T: 0.90, F: 0.05})\n",
" .add('MaryCalls', ['Alarm'], {T: 0.70, F: 0.01})) "
]
},
{
"cell_type": "code",
"outputs": [
{
"data": {
"text/plain": [
"[Burglary, Earthquake, Alarm, JohnCalls, MaryCalls]"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"# Make Burglary, Earthquake, etc. be global variables\n",
"globalize(alarm_net.lookup) \n",
"alarm_net.variables"
]
},
{
"cell_type": "code",
},
"outputs": [
{
"data": {
"text/plain": [
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Probability distribution of a Burglary\n",
"P(Burglary)"
]
},
{
"cell_type": "code",
},
"outputs": [
{
"data": {
"text/plain": [
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Probability of Alarm going off, given a Burglary and not an Earthquake:\n",
"P(Alarm, {Burglary: T, Earthquake: F})"
]
},
{
"cell_type": "code",
},
"outputs": [
{
"data": {
"text/plain": [
"{(F, F): {F: 0.999, T: 0.001},\n",
" (F, T): {F: 0.71, T: 0.29},\n",
" (T, F): {F: 0.06000000000000005, T: 0.94},\n",
" (T, T): {F: 0.050000000000000044, T: 0.95}}"
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Where that came from: the (T, F) row of Alarm's CPT:\n",
"Alarm.cpt"
]
},
{
"cell_type": "markdown",
"metadata": {
"button": false,
"deletable": true,
"new_sheet": false,
"run_control": {
"read_only": false
}
},
"source": [
"# Bayes Nets as Joint Probability Distributions\n",
"\n",
"A Bayes net is a compact way of specifying a full joint distribution over all the variables in the network. Given a set of variables {*X*<sub>1</sub>, ..., *X*<sub>*n*</sub>}, the full joint distribution is:\n",
"\n",
"P(*X*<sub>1</sub>=*x*<sub>1</sub>, ..., *X*<sub>*n*</sub>=*x*<sub>*n*</sub>) = <font size=large>Π</font><sub>*i*</sub> P(*X*<sub>*i*</sub> = *x*<sub>*i*</sub> | parents(*X*<sub>*i*</sub>))\n",
"\n",
"For a network with *n* variables, each of which has *b* values, there are *b<sup>n</sup>* rows in the joint distribution (for example, a billion rows for 30 Boolean variables), making it impractical to explicitly create the joint distribution for large networks. But for small networks, the function `joint_distribution` creates the distribution, which can be instructive to look at, and can be used to do inference. "
]
},
{
"cell_type": "code",
},
"outputs": [],
"source": [
"def joint_distribution(net):\n",
" \"Given a Bayes net, create the joint distribution over all variables.\"\n",
" return ProbDist({row: prod(P_xi_given_parents(var, row, net)\n",
" for var in net.variables)\n",
" for row in all_rows(net)})\n",
"def all_rows(net): return itertools.product(*[var.domain for var in net.variables])\n",
"\n",
"def P_xi_given_parents(var, row, net):\n",
" \"The probability that var = xi, given the values in this row.\"\n",
" dist = P(var, Evidence(zip(net.variables, row)))\n",
" xi = row[net.variables.index(var)]\n",
" return dist[xi]\n",
"\n",
"def prod(numbers):\n",
" \"The product of numbers: prod([2, 3, 5]) == 30. Analogous to `sum([2, 3, 5]) == 10`.\"\n",
" result = 1\n",
" for x in numbers:\n",
" result *= x\n",
" return result"
]
},
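To make the product formula concrete, here is a small self-contained sketch that computes one row of the joint distribution by hand. It does not use the notebook's classes: the plain-dict tables `P_TRUE` and `PARENTS`, the list `ORDER`, and the helper `row_probability` are assumptions made for illustration, with the probabilities copied from the burglary network.

```python
import itertools

# P(variable = True | parent values), keyed by a tuple of parent values.
# The numbers are the ones used for the burglary network in this notebook.
P_TRUE = {
    'Burglary':   {(): 0.001},
    'Earthquake': {(): 0.002},
    'Alarm':      {(True, True): 0.95, (True, False): 0.94,
                   (False, True): 0.29, (False, False): 0.001},
    'JohnCalls':  {(True,): 0.90, (False,): 0.05},
    'MaryCalls':  {(True,): 0.70, (False,): 0.01},
}
PARENTS = {
    'Burglary': (), 'Earthquake': (),
    'Alarm': ('Burglary', 'Earthquake'),
    'JohnCalls': ('Alarm',), 'MaryCalls': ('Alarm',),
}
ORDER = ['Burglary', 'Earthquake', 'Alarm', 'JohnCalls', 'MaryCalls']

def row_probability(row):
    "Multiply P(X_i = x_i | parents(X_i)) over every variable in the row."
    assignment = dict(zip(ORDER, row))
    p = 1.0
    for var in ORDER:
        parent_values = tuple(assignment[par] for par in PARENTS[var])
        p_true = P_TRUE[var][parent_values]
        p *= p_true if assignment[var] else 1 - p_true
    return p

# Alarm sounds, no burglary, no earthquake, and both neighbors call.
p = row_probability((False, False, True, True, True))
# Every row of the joint, summed; this should come out very close to 1.
total = sum(row_probability(r) for r in itertools.product([True, False], repeat=5))
print(p, total)
```

The chosen row multiplies out as 0.999 × 0.998 × 0.001 × 0.90 × 0.70, which is about 0.000628.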
{
"cell_type": "code",
},
"outputs": [
{
"data": {
"text/plain": [
"{(F, F, F, F, F),\n",
" (F, F, F, F, T),\n",
" (F, F, F, T, F),\n",
" (F, F, F, T, T),\n",
" (F, F, T, F, F),\n",
" (F, F, T, F, T),\n",
" (F, F, T, T, F),\n",
" (F, F, T, T, T),\n",
" (F, T, F, F, F),\n",
" (F, T, F, F, T),\n",
" (F, T, F, T, F),\n",
" (F, T, F, T, T),\n",
" (F, T, T, F, F),\n",
" (F, T, T, F, T),\n",
" (F, T, T, T, F),\n",
" (F, T, T, T, T),\n",
" (T, F, F, F, F),\n",
" (T, F, F, F, T),\n",
" (T, F, F, T, F),\n",
" (T, F, F, T, T),\n",
" (T, F, T, F, F),\n",
" (T, F, T, F, T),\n",
" (T, F, T, T, F),\n",
" (T, F, T, T, T),\n",
" (T, T, F, F, F),\n",
" (T, T, F, F, T),\n",
" (T, T, F, T, F),\n",
" (T, T, F, T, T),\n",
" (T, T, T, F, F),\n",
" (T, T, T, F, T),\n",
" (T, T, T, T, F),\n",
" (T, T, T, T, T)}"
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# All rows in the joint distribution (2**5 == 32 rows)\n",
"set(all_rows(alarm_net))"
]
},
{
"cell_type": "code",
"collapsed": false
},
"outputs": [],
"source": [
"# Let's work through just one row of the table:\n",
"row = (F, F, F, F, F)"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# This is the probability distribution for Alarm\n",
"P(Alarm, {Burglary: F, Earthquake: F})"
]
},
{
"cell_type": "code",
},
"outputs": [
{
"data": {
"text/plain": [
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Here's the probability that Alarm is false, given the parent values in this row:\n",
"P_xi_given_parents(Alarm, row, alarm_net)"
]
},
{
"cell_type": "code",
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"{(F, F, F, F, F): 0.9367427006190001,\n",
" (F, F, F, F, T): 0.009462047481000001,\n",
" (F, F, F, T, F): 0.04930224740100002,\n",
" (F, F, F, T, T): 0.0004980024990000002,\n",
" (F, F, T, F, F): 2.9910060000000004e-05,\n",
" (F, F, T, F, T): 6.979013999999999e-05,\n",
" (F, F, T, T, F): 0.00026919054000000005,\n",
" (F, F, T, T, T): 0.00062811126,\n",
" (F, T, F, F, F): 0.0013341744900000002,\n",
" (F, T, F, F, T): 1.3476510000000005e-05,\n",
" (F, T, F, T, F): 7.021971000000001e-05,\n",
" (F, T, F, T, T): 7.092900000000001e-07,\n",
" (F, T, T, F, F): 1.7382600000000002e-05,\n",
" (F, T, T, F, T): 4.0559399999999997e-05,\n",
" (F, T, T, T, F): 0.00015644340000000006,\n",
" (F, T, T, T, T): 0.00036503460000000007,\n",
" (T, F, F, F, F): 5.631714000000006e-05,\n",
" (T, F, F, F, T): 5.688600000000006e-07,\n",
" (T, F, F, T, F): 2.9640600000000033e-06,\n",
" (T, F, F, T, T): 2.9940000000000035e-08,\n",
" (T, F, T, F, F): 2.8143600000000003e-05,\n",
" (T, F, T, F, T): 6.56684e-05,\n",
" (T, F, T, T, F): 0.0002532924000000001,\n",
" (T, F, T, T, T): 0.0005910156000000001,\n",
" (T, T, F, F, F): 9.40500000000001e-08,\n",
" (T, T, F, F, T): 9.50000000000001e-10,\n",
" (T, T, F, T, F): 4.9500000000000054e-09,\n",
" (T, T, F, T, T): 5.0000000000000066e-11,\n",
" (T, T, T, F, F): 5.7e-08,\n",
" (T, T, T, F, T): 1.3299999999999996e-07,\n",
" (T, T, T, T, F): 5.130000000000002e-07,\n",
" (T, T, T, T, T): 1.1970000000000001e-06}"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
],
"source": [
"# The full joint distribution:\n",
"joint_distribution(alarm_net)"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {
"collapsed": false
{
"name": "stdout",
"output_type": "stream",
"text": [
"[Burglary, Earthquake, Alarm, JohnCalls, MaryCalls]\n"
]
},
{
"data": {
"text/plain": [
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Probability that \"the alarm has sounded, but neither a burglary nor an earthquake has occurred, \n",
"# and both John and Mary call\" (page 514 says it should be 0.000628)\n",
"\n",
"print(alarm_net.variables)\n",
"joint_distribution(alarm_net)[F, F, T, T, T]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Inference by Querying the Joint Distribution\n",
"\n",
"We can use `P(variable, evidence)` to get the probability of aa variable, if we know the vaues of all the parent variables. But what if we don't know? Bayes nets allow us to calculate the probability, but the calculation is not just a lookup in the CPT; it is a global calculation across the whole net. One inefficient but straightforward way of doing the calculation is to create the joint probability distribution, then pick out just the rows that\n",
"match the evidence variables, and for each row check what the value of the query variable is, and increment the probability for that value accordningly:"
]
},
{
"cell_type": "code",
"collapsed": false
},
"outputs": [],
"source": [
"def enumeration_ask(X, evidence, net):\n",
" \"The probability distribution for query variable X in a belief net, given evidence.\"\n",
" i = net.variables.index(X) # The index of the query variable X in the row\n",
" dist = defaultdict(float) # The resulting probability distribution over X\n",
" for (row, p) in joint_distribution(net).items():\n",
" if matches_evidence(row, evidence, net):\n",
" dist[row[i]] += p\n",
" return ProbDist(dist)\n",
"\n",
"def matches_evidence(row, evidence, net):\n",
" \"Does the tuple of values for this row agree with the evidence?\"\n",
" return all(evidence[v] == row[net.variables.index(v)]\n",
" for v in evidence)"
]
},
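The filter-and-sum idea is easier to follow on a net small enough to check by hand. The sketch below is an illustration only: the two-variable net `Rain → WetGrass`, its probabilities, and the function name `query` are hypothetical and do not come from the notebook. It stores the joint distribution as a plain dict, keeps only the rows that match the evidence, adds up their probabilities per value of the query variable, and then normalizes.

```python
from collections import defaultdict

# Hypothetical net Rain -> WetGrass; the numbers are made up for illustration.
# P(Rain) = 0.2, P(WetGrass | Rain) = 0.9, P(WetGrass | not Rain) = 0.1
VARIABLES = ('Rain', 'WetGrass')
joint = {
    (True,  True):  0.2 * 0.9,
    (True,  False): 0.2 * 0.1,
    (False, True):  0.8 * 0.1,
    (False, False): 0.8 * 0.9,
}

def query(X, evidence, joint, variables):
    "Distribution over X given evidence, computed by filtering rows of the joint."
    i = variables.index(X)
    dist = defaultdict(float)
    for row, p in joint.items():
        # Keep only rows that agree with every evidence variable.
        if all(row[variables.index(v)] == value for v, value in evidence.items()):
            dist[row[i]] += p
    total = sum(dist.values())
    return {value: p / total for value, p in dist.items()}

posterior = query('Rain', {'WetGrass': True}, joint, VARIABLES)
print(posterior)
```

Here the matching rows carry probability 0.18 (rain) and 0.08 (no rain), so the posterior for rain is 0.18 / 0.26, roughly 0.69.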
{
"cell_type": "code",
"execution_count": 25,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"{F: 0.9931237539265789, T: 0.006876246073421024}"
]
},
"execution_count": 25,
"metadata": {},
"output_type": "execute_result"
],
"source": [
"# The probability of a Burgalry, given that John calls but Mary does not: \n",
"enumeration_ask(Burglary, {JohnCalls: F, MaryCalls: T}, alarm_net)"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"{F: 0.03368899586522123, T: 0.9663110041347788}"
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# The probability of an Alarm, given that there is an Earthquake and Mary calls:\n",
"enumeration_ask(Alarm, {MaryCalls: T, Earthquake: T}, alarm_net)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Variable Elimination\n",
"\n",
"The `enumeration_ask` algorithm takes time and space that is exponential in the number of variables. That is, first it creates the joint distribution, of size *b<sup>n</sup>*, and then it sums out the values for the rows that match the evidence. We can do better than that if we interleave the joining of variables with the summing out of values.\n",
"This approach is called *variable elimination*. The key insight is that\n",
"when we compute\n",
"\n",
"P(*X*<sub>1</sub>=*x*<sub>1</sub>, ..., *X*<sub>*n*</sub>=*x*<sub>*n*</sub>) = <font size=large>Π</font><sub>*i*</sub> P(*X*<sub>*i*</sub> = *x*<sub>*i*</sub> | parents(*X*<sub>*i*</sub>))\n",
"\n",
"we are repeating the calculation of, say, P(*X*<sub>*3*</sub> = *x*<sub>*4*</sub> | parents(*X*<sub>*3*</sub>))\n",
"multiple times, across multiple rows of the joint distribution.\n",
"\n",
"\n"
]
},
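One way to see the saving is on a three-variable chain. The sketch below is an illustration under stated assumptions, not the variable elimination algorithm this notebook refers to: the chain `A → B → C`, its made-up probabilities, and the function names are all hypothetical. It computes P(*C*) twice, once by summing over every row of the joint and once by summing out *A* and then *B*, so that each intermediate result is computed once and reused.

```python
import itertools

# Hypothetical chain A -> B -> C of Boolean variables; the numbers are made up.
VALUES = (True, False)
P_A = {True: 0.3, False: 0.7}
P_B_given_A = {True:  {True: 0.8, False: 0.2},    # P_B_given_A[a][b]
               False: {True: 0.1, False: 0.9}}
P_C_given_B = {True:  {True: 0.6, False: 0.4},    # P_C_given_B[b][c]
               False: {True: 0.05, False: 0.95}}

def brute_force_P_C():
    "Build every row of the joint, then sum the rows for each value of C."
    dist = {c: 0.0 for c in VALUES}
    for a, b, c in itertools.product(VALUES, repeat=3):
        dist[c] += P_A[a] * P_B_given_A[a][b] * P_C_given_B[b][c]
    return dist

def eliminate_P_C():
    "Sum out A first, then B, reusing the intermediate factor over B."
    # factor_B[b] = sum over a of P(a) * P(b | a)
    factor_B = {b: sum(P_A[a] * P_B_given_A[a][b] for a in VALUES) for b in VALUES}
    # P(c) = sum over b of factor_B[b] * P(c | b)
    return {c: sum(factor_B[b] * P_C_given_B[b][c] for b in VALUES) for c in VALUES}

brute = brute_force_P_C()
eliminated = eliminate_P_C()
print(brute, eliminated)
```

Both routes give the same distribution. On a chain of *n* variables the brute-force sum visits *b<sup>n</sup>* rows, while summing out one variable at a time only ever handles factors over one or two variables.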
{
"cell_type": "code",
"execution_count": 27,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# TODO: Copy over and update Variable Elimination algorithm. Also, sampling algorithms."
]
},
{
"cell_type": "markdown",
"metadata": {
"button": false,
"deletable": true,
"new_sheet": false,
"run_control": {
"read_only": false
}
},
"source": [
"In this net, whether a patient gets the flu is dependent on whether they were vaccinated, and having the flu influences whether they get a fever or headache. Here `Fever` is a non-Boolean variable, with three values, `no`, `mild`, and `high`."
"cell_type": "code",
"execution_count": 28,
"metadata": {
"button": false,
"deletable": true,
"new_sheet": false,
"run_control": {
"read_only": false
}
},
"outputs": [],
"source": [
"flu_net = (BayesNet()\n",
" .add('Vaccinated', [], 0.60)\n",
" .add('Flu', ['Vaccinated'], {T: 0.002, F: 0.02})\n",
" .add('Fever', ['Flu'], {T: ProbDist(no=25, mild=25, high=50),\n",
" F: ProbDist(no=97, mild=2, high=1)})\n",
" .add('Headache', ['Flu'], {T: 0.5, F: 0.03}))\n",
"\n",
"globalize(flu_net.lookup)"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"{F: 0.9616440110625343, T: 0.03835598893746573}"
]
},
"execution_count": 29,
"metadata": {},
"output_type": "execute_result"
}
],
"# If you just have a headache, you probably don't have the Flu.\n",
"enumeration_ask(Flu, {Headache: T, Fever: 'no'}, flu_net)"
]
},
{
"cell_type": "code",
"metadata": {
"button": false,
"collapsed": false,
"deletable": true,
"new_sheet": false,
"run_control": {
"read_only": false
}
},
"outputs": [
{
"data": {
"text/plain": [
"{F: 0.9914651882096696, T: 0.008534811790330398}"
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Even more so if you were vaccinated.\n",
"enumeration_ask(Flu, {Headache: T, Fever: 'no', Vaccinated: T}, flu_net)"
]
},
{
"cell_type": "code",
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"{F: 0.9194016377587207, T: 0.08059836224127925}"
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# But if you were not vaccinated, there is a higher chance you have the flu.\n",
"enumeration_ask(Flu, {Headache: T, Fever: 'no', Vaccinated: F}, flu_net)"
]
},
{
"cell_type": "code",
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"{F: 0.1904145077720207, T: 0.8095854922279793}"