From 10f5c49ea6767f54d59f88eb4064bb4959d14c6b Mon Sep 17 00:00:00 2001 From: Cathy Yeh Date: Tue, 12 Dec 2017 21:28:26 -0800 Subject: implement explicit factor methods for compute_pi_agg and compute_lambda_msg_to_parent in Node --- tests/test_belief_propagation.py | 64 +++++++++++++++++++++++++++++++++++++++- 1 file changed, 63 insertions(+), 1 deletion(-) (limited to 'tests/test_belief_propagation.py') diff --git a/tests/test_belief_propagation.py b/tests/test_belief_propagation.py index 7a77311..1b8c0ac 100644 --- a/tests/test_belief_propagation.py +++ b/tests/test_belief_propagation.py @@ -3,10 +3,12 @@ import pytest from pytest import approx from beliefs.inference.belief_propagation import BeliefPropagation, ConflictingEvidenceError +from beliefs.factors.cpd import TabularCPD from beliefs.models.belief_update_node_model import ( BeliefUpdateNodeModel, BernoulliOrNode, - BernoulliAndNode + BernoulliAndNode, + Node ) @@ -89,6 +91,41 @@ def mixed_cpd_model(edges_five_nodes): 'w': w_node}) +@pytest.fixture(scope='function') +def custom_cpd_model(): + """ + Y-shaped model, with parents ,'u' and 'v' as Or-nodes, 'x' a node with + cardinality 3 and custom CPD, 'y' a node with cardinality 2 and custom CPD. + """ + custom_cpd_x = TabularCPD(variable='x', + variable_card=3, + parents=['u', 'v'], + parents_card=[2, 2], + values=[[0.2, 0, 0.3, 0.1], + [0.4, 1, 0.7, 0.2], + [0.4, 0, 0, 0.7]], + state_names={'x': ['lo', 'med', 'hi'], + 'u': ['False', 'True'], + 'v': ['False', 'True']}) + custom_cpd_y = TabularCPD(variable='y', + variable_card=2, + parents=['x'], + parents_card=[3], + values=[[0.3, 0.1, 0], + [0.7, 0.9, 1]], + state_names={'x': ['lo', 'med', 'hi'], + 'y': ['False', 'True']}) + + u_node = BernoulliOrNode(label_id='u', children=['x'], parents=[]) + v_node = BernoulliOrNode(label_id='v', children=['x'], parents=[]) + x_node = Node(children=['y'], cpd=custom_cpd_x) + y_node = Node(children=[], cpd=custom_cpd_y) + return BeliefUpdateNodeModel(nodes_dict={'u': u_node, + 'v': v_node, + 'x': x_node, + 'y': y_node}) + + def get_label_mapped_to_positive_belief(query_result): """Return a dictionary mapping each label_id to the probability of the label being True.""" @@ -355,3 +392,28 @@ def test_conflicting_evidence_and_model(many_parents_and_model): with pytest.raises(ConflictingEvidenceError) as err: query_result = infer.query(evidence={'62': np.array([0, 1]), '112': np.array([1, 0])}) assert "Can't run belief propagation with conflicting evidence" in str(err) + + +#============================================================================================== +# Model with two custom cpds + + +def test_no_evidence_custom_cpd_model(custom_cpd_model): + expected = {'x': np.array([0.15, 0.575, 0.275]), + 'v': np.array([0.5, 0.5]), + 'u': np.array([0.5, 0.5]), + 'y': np.array([0.1025, 0.8975])} + infer = BeliefPropagation(custom_cpd_model) + query_result = infer.query(evidence={}) + compare_dictionaries(expected, query_result) + + +def test_evidence_custom_cpd_model(custom_cpd_model): + """Custom node is observed to be in 'med' state.""" + expected = {'x': np.array([0., 1., 0.]), + 'u': np.array([0.60869565, 0.39130435]), + 'v': np.array([0.47826087, 0.52173913]), + 'y': np.array([0.1, 0.9])} + infer = BeliefPropagation(custom_cpd_model) + query_result = infer.query(evidence={'x': np.array([0, 1, 0])}) + compare_dictionaries(expected, query_result) -- cgit v1.2.3