aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCathy Yeh <cathy@driver.xyz>2017-12-11 18:56:15 -0800
committerCathy Yeh <cathy@driver.xyz>2017-12-11 18:56:15 -0800
commit65d822247e30b6e104a8c09d3b930487b9f20a58 (patch)
treed44b83f001ab352b30e17ab981295c2ee70a4d56
parent26b43410569044aff46053cae7c68862825dd4ec (diff)
parent7b5c17c316481edbbd13815390d0b34fb50a03a6 (diff)
downloadbeliefs-65d822247e30b6e104a8c09d3b930487b9f20a58.tar.gz
beliefs-65d822247e30b6e104a8c09d3b930487b9f20a58.tar.bz2
beliefs-65d822247e30b6e104a8c09d3b930487b9f20a58.zip
LGS-173 Merge branch 'bernoulli_and_node'v0.0.3
-rw-r--r--VERSION2
-rw-r--r--beliefs/factors/bernoulli_and_cpd.py45
-rw-r--r--beliefs/factors/bernoulli_or_cpd.py7
-rw-r--r--beliefs/models/belief_update_node_model.py49
-rw-r--r--tests/test_belief_propagation.py122
5 files changed, 209 insertions, 16 deletions
diff --git a/VERSION b/VERSION
index 4e379d2..bcab45a 100644
--- a/VERSION
+++ b/VERSION
@@ -1 +1 @@
-0.0.2
+0.0.3
diff --git a/beliefs/factors/bernoulli_and_cpd.py b/beliefs/factors/bernoulli_and_cpd.py
new file mode 100644
index 0000000..fdb0c25
--- /dev/null
+++ b/beliefs/factors/bernoulli_and_cpd.py
@@ -0,0 +1,45 @@
+import numpy as np
+
+from beliefs.factors.cpd import TabularCPD
+
+
+class BernoulliAndCPD(TabularCPD):
+ """CPD class for a Bernoulli random variable whose relationship to its
+ parents (also Bernoulli random variables) is described by AND logic.
+
+ If all of the variable's parents are True, then the variable
+ is True, and False otherwise.
+ """
+ def __init__(self, variable, parents=[]):
+ """
+ Args:
+ variable: int or string
+ parents: optional, list of int and/or strings
+ """
+ super().__init__(variable=variable,
+ variable_card=2,
+ parents=parents,
+ parents_card=[2]*len(parents),
+ values=[])
+ self._values = None
+
+ @property
+ def values(self):
+ if self._values is None:
+ self._values = self._build_kwise_values_array(len(self.variables))
+ self._values = self._values.reshape(self.cardinality)
+ return self._values
+
+ @staticmethod
+ def _build_kwise_values_array(k):
+ # special case a completely independent factor, and
+ # return the uniform prior
+ if k == 1:
+ return np.array([0.5, 0.5])
+
+ # values are stored as a row vector using an ordering such that
+ # the right-most variables as defined in [variable].extend(parents)
+ # cycle through their values the fastest.
+ return np.array(
+ [1.]*(2**(k-1)-1) + [0.] + [0.,]*(2**(k-1)-1) + [1.]
+ )
diff --git a/beliefs/factors/bernoulli_or_cpd.py b/beliefs/factors/bernoulli_or_cpd.py
index bfb3a95..12ee2f6 100644
--- a/beliefs/factors/bernoulli_or_cpd.py
+++ b/beliefs/factors/bernoulli_or_cpd.py
@@ -21,11 +21,11 @@ class BernoulliOrCPD(TabularCPD):
parents=parents,
parents_card=[2]*len(parents),
values=[])
- self._values = []
+ self._values = None
@property
def values(self):
- if not any(self._values):
+ if self._values is None:
self._values = self._build_kwise_values_array(len(self.variables))
self._values = self._values.reshape(self.cardinality)
return self._values
@@ -37,6 +37,9 @@ class BernoulliOrCPD(TabularCPD):
if k == 1:
return np.array([0.5, 0.5])
+ # values are stored as a row vector using an ordering such that
+ # the right-most variables as defined in [variable].extend(parents)
+ # cycle through their values the fastest.
return np.array(
[1.,] + [0.]*(2**(k-1)-1) + [0.,] + [1.]*(2**(k-1)-1)
)
diff --git a/beliefs/models/belief_update_node_model.py b/beliefs/models/belief_update_node_model.py
index 667e0f1..1c3ba6e 100644
--- a/beliefs/models/belief_update_node_model.py
+++ b/beliefs/models/belief_update_node_model.py
@@ -8,6 +8,7 @@ import networkx as nx
from beliefs.models.base_models import BayesianModel
from beliefs.factors.bernoulli_or_cpd import BernoulliOrCPD
+from beliefs.factors.bernoulli_and_cpd import BernoulliAndCPD
class InvalidLambdaMsgToParent(Exception):
@@ -212,7 +213,7 @@ class Node:
raise NotImplementedError
def compute_lambda_agg(self):
- if not self.children:
+ if len(self.children) == 0:
return self.lambda_agg
else:
lambda_msg_values = self.validate_and_return_msgs_received_for_msg_type(MessageType.LAMBDA)
@@ -289,11 +290,13 @@ class BernoulliOrNode(Node):
cpd=BernoulliOrCPD(label_id, parents))
def compute_pi_agg(self):
- if not self.parents:
+ if len(self.parents) == 0:
self.pi_agg = self.cpd.values
else:
pi_msg_values = self.validate_and_return_msgs_received_for_msg_type(MessageType.PI)
parents_p0 = [p[0] for p in pi_msg_values]
+ # Parents are Bernoulli variables with pi_msg_values (surrogate prior probabilities)
+ # of p = [P(False), P(True)]
p_0 = reduce(lambda x, y: x*y, parents_p0)
p_1 = 1 - p_0
self.pi_agg = np.array([p_0, p_1])
@@ -305,7 +308,7 @@ class BernoulliOrNode(Node):
else:
# TODO: cleanup this validation
_ = self.validate_and_return_msgs_received_for_msg_type(MessageType.PI)
- p0_excluding_k = [msg[0] for par_id, msg in self.pi_received_msgs.items() if par_id != parent_k]
+ p0_excluding_k = [p[0] for par_id, p in self.pi_received_msgs.items() if par_id != parent_k]
p0_product = reduce(lambda x, y: x*y, p0_excluding_k, 1)
lambda_0 = self.lambda_agg[1] + (self.lambda_agg[0] - self.lambda_agg[1])*p0_product
lambda_1 = self.lambda_agg[1]
@@ -313,3 +316,43 @@ class BernoulliOrNode(Node):
if not any(lambda_msg):
raise InvalidLambdaMsgToParent
return self._normalize(lambda_msg)
+
+
+class BernoulliAndNode(Node):
+ def __init__(self,
+ label_id,
+ children,
+ parents):
+ super().__init__(label_id=label_id,
+ children=children,
+ parents=parents,
+ cardinality=2,
+ cpd=BernoulliAndCPD(label_id, parents))
+
+ def compute_pi_agg(self):
+ if len(self.parents) == 0:
+ self.pi_agg = self.cpd.values
+ else:
+ pi_msg_values = self.validate_and_return_msgs_received_for_msg_type(MessageType.PI)
+ parents_p1 = [p[1] for p in pi_msg_values]
+ # Parents are Bernoulli variables with pi_msg_values (surrogate prior probabilities)
+ # of p = [P(False), P(True)]
+ p_1 = reduce(lambda x, y: x*y, parents_p1)
+ p_0 = 1 - p_1
+ self.pi_agg = np.array([p_0, p_1])
+ return self.pi_agg
+
+ def compute_lambda_msg_to_parent(self, parent_k):
+ if np.array_equal(self.lambda_agg, np.ones([self.cardinality])):
+ return np.ones([self.cardinality])
+ else:
+ # TODO: cleanup this validation
+ _ = self.validate_and_return_msgs_received_for_msg_type(MessageType.PI)
+ p1_excluding_k = [p[1] for par_id, p in self.pi_received_msgs.items() if par_id != parent_k]
+ p1_product = reduce(lambda x, y: x*y, p1_excluding_k, 1)
+ lambda_0 = self.lambda_agg[0]
+ lambda_1 = self.lambda_agg[0] + (self.lambda_agg[1] - self.lambda_agg[0])*p1_product
+ lambda_msg = np.array([lambda_0, lambda_1])
+ if not any(lambda_msg):
+ raise InvalidLambdaMsgToParent
+ return self._normalize(lambda_msg)
diff --git a/tests/test_belief_propagation.py b/tests/test_belief_propagation.py
index 5c5a612..7a77311 100644
--- a/tests/test_belief_propagation.py
+++ b/tests/test_belief_propagation.py
@@ -5,13 +5,14 @@ from pytest import approx
from beliefs.inference.belief_propagation import BeliefPropagation, ConflictingEvidenceError
from beliefs.models.belief_update_node_model import (
BeliefUpdateNodeModel,
- BernoulliOrNode
+ BernoulliOrNode,
+ BernoulliAndNode
)
@pytest.fixture(scope='module')
-def edges_four_nodes():
- """Edges define a polytree with 4 nodes (connected in an X-shape with the
+def edges_five_nodes():
+ """Edges define a polytree with 5 nodes (connected in an X-shape with the
node, 'x', at the center of the X."""
edges = [('u', 'x'), ('v', 'x'), ('x', 'y'), ('x', 'z')]
return edges
@@ -38,8 +39,8 @@ def many_parents_edges():
@pytest.fixture(scope='function')
-def four_node_model(edges_four_nodes):
- return BeliefUpdateNodeModel.init_from_edges(edges_four_nodes, BernoulliOrNode)
+def five_node_model(edges_five_nodes):
+ return BeliefUpdateNodeModel.init_from_edges(edges_five_nodes, BernoulliOrNode)
@pytest.fixture(scope='function')
@@ -53,11 +54,41 @@ def many_parents_model(many_parents_edges):
@pytest.fixture(scope='function')
+def many_parents_and_model(many_parents_edges):
+ return BeliefUpdateNodeModel.init_from_edges(many_parents_edges, BernoulliAndNode)
+
+
+@pytest.fixture(scope='function')
def one_node_model():
a_node = BernoulliOrNode(label_id='x', children=[], parents=[])
return BeliefUpdateNodeModel(nodes_dict={'x': a_node})
+@pytest.fixture(scope='function')
+def five_node_and_model(edges_five_nodes):
+ return BeliefUpdateNodeModel.init_from_edges(edges_five_nodes, BernoulliAndNode)
+
+
+@pytest.fixture(scope='function')
+def mixed_cpd_model(edges_five_nodes):
+ """
+ X-shaped 5 node model plus one more node, 'w', with edge from 'w' to 'z'.
+ 'z' is an AND node while all other nodes are OR nodes.
+ """
+ u_node = BernoulliOrNode(label_id='u', children=['x'], parents=[])
+ v_node = BernoulliOrNode(label_id='v', children=['x'], parents=[])
+ x_node = BernoulliOrNode(label_id='x', children=['y', 'z'], parents=['u', 'v'])
+ y_node = BernoulliOrNode(label_id='y', children=[], parents=['x'])
+ z_node = BernoulliAndNode(label_id='z', children=[], parents=['x', 'w'])
+ w_node = BernoulliOrNode(label_id='w', children=['z'], parents=[])
+ return BeliefUpdateNodeModel(nodes_dict={'u': u_node,
+ 'v': v_node,
+ 'x': x_node,
+ 'y': y_node,
+ 'z': z_node,
+ 'w': w_node})
+
+
def get_label_mapped_to_positive_belief(query_result):
"""Return a dictionary mapping each label_id to the probability of
the label being True."""
@@ -118,27 +149,90 @@ def test_NO_evidence_one_node_model(one_node_model):
#==============================================================================================
-# Tests of 4-node, 4-edge model
+# Tests of 5-node, 4-edge model
-def test_no_evidence_four_node_model(four_node_model):
+def test_no_evidence_five_node_model(five_node_model):
expected = {'x': 1-0.5**2}
- infer = BeliefPropagation(four_node_model)
+ infer = BeliefPropagation(five_node_model)
query_result = infer.query(evidence={})
result = get_label_mapped_to_positive_belief(query_result)
compare_dictionaries(expected, result)
-def test_virtual_evidence_for_node_x_four_node_model(four_node_model):
+def test_virtual_evidence_for_node_x_five_node_model(five_node_model):
"""Virtual evidence for node x."""
expected = {'x': 0.967741935483871, 'y': 0.967741935483871, 'z': 0.967741935483871,
'u': 0.6451612903225806, 'v': 0.6451612903225806}
- infer = BeliefPropagation(four_node_model)
+ infer = BeliefPropagation(five_node_model)
query_result = infer.query(evidence={'x': np.array([1, 10])})
result = get_label_mapped_to_positive_belief(query_result)
compare_dictionaries(expected, result)
#==============================================================================================
+# Tests of 5-node, 4-edge model with AND cpds
+
+def test_no_evidence_five_node_and_model(five_node_and_model):
+ expected = {'x': 0.5**2}
+ infer = BeliefPropagation(five_node_and_model)
+ query_result = infer.query(evidence={})
+ result = get_label_mapped_to_positive_belief(query_result)
+ compare_dictionaries(expected, result)
+
+
+def test_one_parent_false_five_node_and_model(five_node_and_model):
+ expected = {'x': 0}
+ infer = BeliefPropagation(five_node_and_model)
+ query_result = infer.query(evidence={'u': np.array([1,0])})
+ result = get_label_mapped_to_positive_belief(query_result)
+ compare_dictionaries(expected, result)
+
+
+def test_one_parent_true_five_node_and_model(five_node_and_model):
+ expected = {'x': 0.5}
+ infer = BeliefPropagation(five_node_and_model)
+ query_result = infer.query(evidence={'u': np.array([0,1])})
+ result = get_label_mapped_to_positive_belief(query_result)
+ compare_dictionaries(expected, result)
+
+
+def test_both_parents_true_five_node_and_model(five_node_and_model):
+ expected = {'x': 1, 'y': 1, 'z': 1}
+ infer = BeliefPropagation(five_node_and_model)
+ query_result = infer.query(evidence={'u': np.array([0,1]), 'v': np.array([0,1])})
+ result = get_label_mapped_to_positive_belief(query_result)
+ compare_dictionaries(expected, result)
+
+
+#==============================================================================================
+# Tests of mixed cpd model (all CPDs are OR, except for one AND node with 2 parents)
+
+
+def test_no_evidence_mixed_cpd_model(mixed_cpd_model):
+ expected = {'x': 1-0.5**2, 'z': 0.5*(1-0.5**2)}
+ infer = BeliefPropagation(mixed_cpd_model)
+ query_result = infer.query(evidence={})
+ result = get_label_mapped_to_positive_belief(query_result)
+ compare_dictionaries(expected, result)
+
+
+def test_x_false_w_true_mixed_cpd_model(mixed_cpd_model):
+ expected = {'u': 0, 'v': 0, 'y': 0, 'z': 0}
+ infer = BeliefPropagation(mixed_cpd_model)
+ query_result = infer.query(evidence={'x': np.array([1,0]), 'w': np.array([0,1])})
+ result = get_label_mapped_to_positive_belief(query_result)
+ compare_dictionaries(expected, result)
+
+
+def test_x_true_w_true_mixed_cpd_model(mixed_cpd_model):
+ expected = {'y': 1, 'z': 1}
+ infer = BeliefPropagation(mixed_cpd_model)
+ query_result = infer.query(evidence={'x': np.array([0,1]), 'w': np.array([0,1])})
+ result = get_label_mapped_to_positive_belief(query_result)
+ compare_dictionaries(expected, result)
+
+
+#==============================================================================================
# Tests of simple BernoulliOr polytree model
def test_no_evidence_simple_model(simple_model):
@@ -253,3 +347,11 @@ def test_negative_evidence_node_62(many_parents_model):
query_result = infer.query(evidence={'62': np.array([1, 0])})
result = get_label_mapped_to_positive_belief(query_result)
compare_dictionaries(expected, result)
+
+
+def test_conflicting_evidence_and_model(many_parents_and_model):
+ """If one of the parents of node 62 is False, then node 62 has to be False."""
+ infer = BeliefPropagation(many_parents_and_model)
+ with pytest.raises(ConflictingEvidenceError) as err:
+ query_result = infer.query(evidence={'62': np.array([0, 1]), '112': np.array([1, 0])})
+ assert "Can't run belief propagation with conflicting evidence" in str(err)