diff options
author | Cathy Yeh <cathy@driver.xyz> | 2017-12-03 20:38:28 -0800 |
---|---|---|
committer | Cathy Yeh <cathy@driver.xyz> | 2017-12-03 20:38:28 -0800 |
commit | 26b43410569044aff46053cae7c68862825dd4ec (patch) | |
tree | b184df84d416e2ddf837b25baadff4f9feaaa250 /tests/test_get_reachable_observed_variables.py | |
parent | 6a1b35f5bf122232d058ed0f3ea19c15629c0cbc (diff) | |
parent | c906bd37fba63ba706cc3b7802bfb18ffb05ee9a (diff) | |
download | beliefs-26b43410569044aff46053cae7c68862825dd4ec.tar.gz beliefs-26b43410569044aff46053cae7c68862825dd4ec.tar.bz2 beliefs-26b43410569044aff46053cae7c68862825dd4ec.zip |
LGS-164 belief propagation for polytrees, special case of OR cpds, refactored from LGSv0.0.2
Diffstat (limited to 'tests/test_get_reachable_observed_variables.py')
-rw-r--r-- | tests/test_get_reachable_observed_variables.py | 129 |
1 files changed, 129 insertions, 0 deletions
diff --git a/tests/test_get_reachable_observed_variables.py b/tests/test_get_reachable_observed_variables.py new file mode 100644 index 0000000..d6590ad --- /dev/null +++ b/tests/test_get_reachable_observed_variables.py @@ -0,0 +1,129 @@ +import numpy as np + +from test_belief_propagation import simple_model, simple_edges + +from beliefs.inference.belief_propagation import BeliefPropagation +from beliefs.utils.random_variables import ( + get_reachable_observed_variables_for_inferred_variables +) + + +def test_reachable_observed_vars_direct_common_effect(simple_model): + observed_vars = {'14': np.array([1,0]), 'x': np.array([1,0])} + infer = BeliefPropagation(simple_model) + infer.query(evidence=observed_vars) + + expected = {'x', '14'} + observed = simple_model.reachable_observed_variables( + source='9', + observed=set(observed_vars.keys()) + ) + assert expected == observed + + +def test_reachable_observed_vars_indirect_common_effect(simple_model): + observed_vars = {'12': np.array([1,0]), '14': np.array([1,0])} + infer = BeliefPropagation(simple_model) + infer.query(evidence=observed_vars) + + expected = {'12', '14'} + observed = simple_model.reachable_observed_variables( + source='9', + observed=set(observed_vars.keys()) + ) + assert expected == observed + + +def test_reachable_observed_vars_common_cause(simple_model): + observed_vars = {'10': np.array([0,1])} + infer = BeliefPropagation(simple_model) + infer.query(evidence=observed_vars) + + expected = {'10'} + observed = simple_model.reachable_observed_variables( + source='9', + observed=set(observed_vars.keys()) + ) + assert expected == observed + + +def test_reachable_observed_vars_blocked_common_cause(simple_model): + observed_vars = {'10': np.array([0,1]), '5': np.array([0,1])} + infer = BeliefPropagation(simple_model) + infer.query(evidence=observed_vars) + + expected = {'5'} + observed = simple_model.reachable_observed_variables( + source='9', + observed=set(observed_vars.keys()) + ) + assert expected == observed + + +def test_reachable_observed_vars_indirect_causal(simple_model): + observed_vars = {'1': np.array([0,1]), '2': np.array([1,0])} + infer = BeliefPropagation(simple_model) + infer.query(evidence=observed_vars) + + expected = {'1', '2'} + observed = simple_model.reachable_observed_variables( + source='9', + observed=set(observed_vars.keys()) + ) + assert expected == observed + + +def test_reachable_observed_vars_blocked_causal(simple_model): + observed_vars = {'1': np.array([0,1]), '2': np.array([1,0]), '3': np.array([0,1])} + infer = BeliefPropagation(simple_model) + infer.query(evidence=observed_vars) + + expected = {'3'} + observed = simple_model.reachable_observed_variables( + source='9', + observed=set(observed_vars.keys()) + ) + assert expected == observed + + +def test_reachable_observed_vars_indirect_evidential(simple_model): + observed_vars = {'13': np.array([1,0])} + infer = BeliefPropagation(simple_model) + infer.query(evidence=observed_vars) + + expected = {'13'} + observed = simple_model.reachable_observed_variables( + source='9', + observed=set(observed_vars.keys()) + ) + assert expected == observed + + +def test_reachable_observed_vars_blocked_evidential(simple_model): + observed_vars = {'x': np.array([1,0]), '13': np.array([1,0])} + infer = BeliefPropagation(simple_model) + infer.query(evidence=observed_vars) + + expected = {'x'} + observed = simple_model.reachable_observed_variables( + source='9', + observed=set(observed_vars.keys()) + ) + assert expected == observed + + +def test_get_reachable_obs_vars_for_inferred(simple_model): + observed_vars = {'6': np.array([1,0]), '7': np.array([1,0]), '10': np.array([1,0])} + infer = BeliefPropagation(simple_model) + infer.query(evidence=observed_vars) + + print(set(simple_model.get_unobserved_variables_in_definite_state(observed_vars.keys()))) + print(simple_model._get_ancestors_of(set(observed_vars.keys()))) + expected = {'4': {'10'}, '1': {'10'}, '11': {'7', '6', '10'}, '2': {'10'}, + '8': {'7', '6'}, '5': {'10'}, '3': {'10'}, '9': {'7', '6', '10'}} + + observed = get_reachable_observed_variables_for_inferred_variables( + model=simple_model, + observed=set(observed_vars.keys()) + ) + assert expected == observed |