aboutsummaryrefslogtreecommitdiff
path: root/tests/test_belief_propagation.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_belief_propagation.py')
-rw-r--r--tests/test_belief_propagation.py64
1 files changed, 63 insertions, 1 deletions
diff --git a/tests/test_belief_propagation.py b/tests/test_belief_propagation.py
index 7a77311..1b8c0ac 100644
--- a/tests/test_belief_propagation.py
+++ b/tests/test_belief_propagation.py
@@ -3,10 +3,12 @@ import pytest
from pytest import approx
from beliefs.inference.belief_propagation import BeliefPropagation, ConflictingEvidenceError
+from beliefs.factors.cpd import TabularCPD
from beliefs.models.belief_update_node_model import (
BeliefUpdateNodeModel,
BernoulliOrNode,
- BernoulliAndNode
+ BernoulliAndNode,
+ Node
)
@@ -89,6 +91,41 @@ def mixed_cpd_model(edges_five_nodes):
'w': w_node})
+@pytest.fixture(scope='function')
+def custom_cpd_model():
+ """
+ Y-shaped model, with parents ,'u' and 'v' as Or-nodes, 'x' a node with
+ cardinality 3 and custom CPD, 'y' a node with cardinality 2 and custom CPD.
+ """
+ custom_cpd_x = TabularCPD(variable='x',
+ variable_card=3,
+ parents=['u', 'v'],
+ parents_card=[2, 2],
+ values=[[0.2, 0, 0.3, 0.1],
+ [0.4, 1, 0.7, 0.2],
+ [0.4, 0, 0, 0.7]],
+ state_names={'x': ['lo', 'med', 'hi'],
+ 'u': ['False', 'True'],
+ 'v': ['False', 'True']})
+ custom_cpd_y = TabularCPD(variable='y',
+ variable_card=2,
+ parents=['x'],
+ parents_card=[3],
+ values=[[0.3, 0.1, 0],
+ [0.7, 0.9, 1]],
+ state_names={'x': ['lo', 'med', 'hi'],
+ 'y': ['False', 'True']})
+
+ u_node = BernoulliOrNode(label_id='u', children=['x'], parents=[])
+ v_node = BernoulliOrNode(label_id='v', children=['x'], parents=[])
+ x_node = Node(children=['y'], cpd=custom_cpd_x)
+ y_node = Node(children=[], cpd=custom_cpd_y)
+ return BeliefUpdateNodeModel(nodes_dict={'u': u_node,
+ 'v': v_node,
+ 'x': x_node,
+ 'y': y_node})
+
+
def get_label_mapped_to_positive_belief(query_result):
"""Return a dictionary mapping each label_id to the probability of
the label being True."""
@@ -355,3 +392,28 @@ def test_conflicting_evidence_and_model(many_parents_and_model):
with pytest.raises(ConflictingEvidenceError) as err:
query_result = infer.query(evidence={'62': np.array([0, 1]), '112': np.array([1, 0])})
assert "Can't run belief propagation with conflicting evidence" in str(err)
+
+
+#==============================================================================================
+# Model with two custom cpds
+
+
+def test_no_evidence_custom_cpd_model(custom_cpd_model):
+ expected = {'x': np.array([0.15, 0.575, 0.275]),
+ 'v': np.array([0.5, 0.5]),
+ 'u': np.array([0.5, 0.5]),
+ 'y': np.array([0.1025, 0.8975])}
+ infer = BeliefPropagation(custom_cpd_model)
+ query_result = infer.query(evidence={})
+ compare_dictionaries(expected, query_result)
+
+
+def test_evidence_custom_cpd_model(custom_cpd_model):
+ """Custom node is observed to be in 'med' state."""
+ expected = {'x': np.array([0., 1., 0.]),
+ 'u': np.array([0.60869565, 0.39130435]),
+ 'v': np.array([0.47826087, 0.52173913]),
+ 'y': np.array([0.1, 0.9])}
+ infer = BeliefPropagation(custom_cpd_model)
+ query_result = infer.query(evidence={'x': np.array([0, 1, 0])})
+ compare_dictionaries(expected, query_result)