aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCathy Yeh <cathy@driver.xyz>2017-12-08 16:01:37 -0800
committerCathy Yeh <cathy@driver.xyz>2017-12-11 18:50:00 -0800
commit06626854ca893b44c128ca333fb5623591134746 (patch)
tree7f18bde2c2cc31b7d13c50ff34c17182e8adaccc
parent4373157138e85d2dbad9672cef5963a27a3d962c (diff)
downloadbeliefs-06626854ca893b44c128ca333fb5623591134746.tar.gz
beliefs-06626854ca893b44c128ca333fb5623591134746.tar.bz2
beliefs-06626854ca893b44c128ca333fb5623591134746.zip
tests for belief propagation with AND and mixed AND and OR nodes
-rw-r--r--beliefs/factors/bernoulli_or_cpd.py2
-rw-r--r--tests/test_belief_propagation.py122
2 files changed, 113 insertions, 11 deletions
diff --git a/beliefs/factors/bernoulli_or_cpd.py b/beliefs/factors/bernoulli_or_cpd.py
index bfb3a95..162e156 100644
--- a/beliefs/factors/bernoulli_or_cpd.py
+++ b/beliefs/factors/bernoulli_or_cpd.py
@@ -25,7 +25,7 @@ class BernoulliOrCPD(TabularCPD):
@property
def values(self):
- if not any(self._values):
+ if len(self._values) == 0:
self._values = self._build_kwise_values_array(len(self.variables))
self._values = self._values.reshape(self.cardinality)
return self._values
diff --git a/tests/test_belief_propagation.py b/tests/test_belief_propagation.py
index 5c5a612..7a77311 100644
--- a/tests/test_belief_propagation.py
+++ b/tests/test_belief_propagation.py
@@ -5,13 +5,14 @@ from pytest import approx
from beliefs.inference.belief_propagation import BeliefPropagation, ConflictingEvidenceError
from beliefs.models.belief_update_node_model import (
BeliefUpdateNodeModel,
- BernoulliOrNode
+ BernoulliOrNode,
+ BernoulliAndNode
)
@pytest.fixture(scope='module')
-def edges_four_nodes():
- """Edges define a polytree with 4 nodes (connected in an X-shape with the
+def edges_five_nodes():
+ """Edges define a polytree with 5 nodes (connected in an X-shape with the
node, 'x', at the center of the X."""
edges = [('u', 'x'), ('v', 'x'), ('x', 'y'), ('x', 'z')]
return edges
@@ -38,8 +39,8 @@ def many_parents_edges():
@pytest.fixture(scope='function')
-def four_node_model(edges_four_nodes):
- return BeliefUpdateNodeModel.init_from_edges(edges_four_nodes, BernoulliOrNode)
+def five_node_model(edges_five_nodes):
+ return BeliefUpdateNodeModel.init_from_edges(edges_five_nodes, BernoulliOrNode)
@pytest.fixture(scope='function')
@@ -53,11 +54,41 @@ def many_parents_model(many_parents_edges):
@pytest.fixture(scope='function')
+def many_parents_and_model(many_parents_edges):
+ return BeliefUpdateNodeModel.init_from_edges(many_parents_edges, BernoulliAndNode)
+
+
+@pytest.fixture(scope='function')
def one_node_model():
a_node = BernoulliOrNode(label_id='x', children=[], parents=[])
return BeliefUpdateNodeModel(nodes_dict={'x': a_node})
+@pytest.fixture(scope='function')
+def five_node_and_model(edges_five_nodes):
+ return BeliefUpdateNodeModel.init_from_edges(edges_five_nodes, BernoulliAndNode)
+
+
+@pytest.fixture(scope='function')
+def mixed_cpd_model(edges_five_nodes):
+ """
+ X-shaped 5 node model plus one more node, 'w', with edge from 'w' to 'z'.
+ 'z' is an AND node while all other nodes are OR nodes.
+ """
+ u_node = BernoulliOrNode(label_id='u', children=['x'], parents=[])
+ v_node = BernoulliOrNode(label_id='v', children=['x'], parents=[])
+ x_node = BernoulliOrNode(label_id='x', children=['y', 'z'], parents=['u', 'v'])
+ y_node = BernoulliOrNode(label_id='y', children=[], parents=['x'])
+ z_node = BernoulliAndNode(label_id='z', children=[], parents=['x', 'w'])
+ w_node = BernoulliOrNode(label_id='w', children=['z'], parents=[])
+ return BeliefUpdateNodeModel(nodes_dict={'u': u_node,
+ 'v': v_node,
+ 'x': x_node,
+ 'y': y_node,
+ 'z': z_node,
+ 'w': w_node})
+
+
def get_label_mapped_to_positive_belief(query_result):
"""Return a dictionary mapping each label_id to the probability of
the label being True."""
@@ -118,27 +149,90 @@ def test_NO_evidence_one_node_model(one_node_model):
#==============================================================================================
-# Tests of 4-node, 4-edge model
+# Tests of 5-node, 4-edge model
-def test_no_evidence_four_node_model(four_node_model):
+def test_no_evidence_five_node_model(five_node_model):
expected = {'x': 1-0.5**2}
- infer = BeliefPropagation(four_node_model)
+ infer = BeliefPropagation(five_node_model)
query_result = infer.query(evidence={})
result = get_label_mapped_to_positive_belief(query_result)
compare_dictionaries(expected, result)
-def test_virtual_evidence_for_node_x_four_node_model(four_node_model):
+def test_virtual_evidence_for_node_x_five_node_model(five_node_model):
"""Virtual evidence for node x."""
expected = {'x': 0.967741935483871, 'y': 0.967741935483871, 'z': 0.967741935483871,
'u': 0.6451612903225806, 'v': 0.6451612903225806}
- infer = BeliefPropagation(four_node_model)
+ infer = BeliefPropagation(five_node_model)
query_result = infer.query(evidence={'x': np.array([1, 10])})
result = get_label_mapped_to_positive_belief(query_result)
compare_dictionaries(expected, result)
#==============================================================================================
+# Tests of 5-node, 4-edge model with AND cpds
+
+def test_no_evidence_five_node_and_model(five_node_and_model):
+ expected = {'x': 0.5**2}
+ infer = BeliefPropagation(five_node_and_model)
+ query_result = infer.query(evidence={})
+ result = get_label_mapped_to_positive_belief(query_result)
+ compare_dictionaries(expected, result)
+
+
+def test_one_parent_false_five_node_and_model(five_node_and_model):
+ expected = {'x': 0}
+ infer = BeliefPropagation(five_node_and_model)
+ query_result = infer.query(evidence={'u': np.array([1,0])})
+ result = get_label_mapped_to_positive_belief(query_result)
+ compare_dictionaries(expected, result)
+
+
+def test_one_parent_true_five_node_and_model(five_node_and_model):
+ expected = {'x': 0.5}
+ infer = BeliefPropagation(five_node_and_model)
+ query_result = infer.query(evidence={'u': np.array([0,1])})
+ result = get_label_mapped_to_positive_belief(query_result)
+ compare_dictionaries(expected, result)
+
+
+def test_both_parents_true_five_node_and_model(five_node_and_model):
+ expected = {'x': 1, 'y': 1, 'z': 1}
+ infer = BeliefPropagation(five_node_and_model)
+ query_result = infer.query(evidence={'u': np.array([0,1]), 'v': np.array([0,1])})
+ result = get_label_mapped_to_positive_belief(query_result)
+ compare_dictionaries(expected, result)
+
+
+#==============================================================================================
+# Tests of mixed cpd model (all CPDs are OR, except for one AND node with 2 parents)
+
+
+def test_no_evidence_mixed_cpd_model(mixed_cpd_model):
+ expected = {'x': 1-0.5**2, 'z': 0.5*(1-0.5**2)}
+ infer = BeliefPropagation(mixed_cpd_model)
+ query_result = infer.query(evidence={})
+ result = get_label_mapped_to_positive_belief(query_result)
+ compare_dictionaries(expected, result)
+
+
+def test_x_false_w_true_mixed_cpd_model(mixed_cpd_model):
+ expected = {'u': 0, 'v': 0, 'y': 0, 'z': 0}
+ infer = BeliefPropagation(mixed_cpd_model)
+ query_result = infer.query(evidence={'x': np.array([1,0]), 'w': np.array([0,1])})
+ result = get_label_mapped_to_positive_belief(query_result)
+ compare_dictionaries(expected, result)
+
+
+def test_x_true_w_true_mixed_cpd_model(mixed_cpd_model):
+ expected = {'y': 1, 'z': 1}
+ infer = BeliefPropagation(mixed_cpd_model)
+ query_result = infer.query(evidence={'x': np.array([0,1]), 'w': np.array([0,1])})
+ result = get_label_mapped_to_positive_belief(query_result)
+ compare_dictionaries(expected, result)
+
+
+#==============================================================================================
# Tests of simple BernoulliOr polytree model
def test_no_evidence_simple_model(simple_model):
@@ -253,3 +347,11 @@ def test_negative_evidence_node_62(many_parents_model):
query_result = infer.query(evidence={'62': np.array([1, 0])})
result = get_label_mapped_to_positive_belief(query_result)
compare_dictionaries(expected, result)
+
+
+def test_conflicting_evidence_and_model(many_parents_and_model):
+ """If one of the parents of node 62 is False, then node 62 has to be False."""
+ infer = BeliefPropagation(many_parents_and_model)
+ with pytest.raises(ConflictingEvidenceError) as err:
+ query_result = infer.query(evidence={'62': np.array([0, 1]), '112': np.array([1, 0])})
+ assert "Can't run belief propagation with conflicting evidence" in str(err)