aboutsummaryrefslogtreecommitdiff
path: root/tests/test_get_reachable_observed_variables.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_get_reachable_observed_variables.py')
-rw-r--r--tests/test_get_reachable_observed_variables.py129
1 files changed, 129 insertions, 0 deletions
diff --git a/tests/test_get_reachable_observed_variables.py b/tests/test_get_reachable_observed_variables.py
new file mode 100644
index 0000000..d6590ad
--- /dev/null
+++ b/tests/test_get_reachable_observed_variables.py
@@ -0,0 +1,129 @@
+import numpy as np
+
+from test_belief_propagation import simple_model, simple_edges
+
+from beliefs.inference.belief_propagation import BeliefPropagation
+from beliefs.utils.random_variables import (
+ get_reachable_observed_variables_for_inferred_variables
+)
+
+
+def test_reachable_observed_vars_direct_common_effect(simple_model):
+ observed_vars = {'14': np.array([1,0]), 'x': np.array([1,0])}
+ infer = BeliefPropagation(simple_model)
+ infer.query(evidence=observed_vars)
+
+ expected = {'x', '14'}
+ observed = simple_model.reachable_observed_variables(
+ source='9',
+ observed=set(observed_vars.keys())
+ )
+ assert expected == observed
+
+
+def test_reachable_observed_vars_indirect_common_effect(simple_model):
+ observed_vars = {'12': np.array([1,0]), '14': np.array([1,0])}
+ infer = BeliefPropagation(simple_model)
+ infer.query(evidence=observed_vars)
+
+ expected = {'12', '14'}
+ observed = simple_model.reachable_observed_variables(
+ source='9',
+ observed=set(observed_vars.keys())
+ )
+ assert expected == observed
+
+
+def test_reachable_observed_vars_common_cause(simple_model):
+ observed_vars = {'10': np.array([0,1])}
+ infer = BeliefPropagation(simple_model)
+ infer.query(evidence=observed_vars)
+
+ expected = {'10'}
+ observed = simple_model.reachable_observed_variables(
+ source='9',
+ observed=set(observed_vars.keys())
+ )
+ assert expected == observed
+
+
+def test_reachable_observed_vars_blocked_common_cause(simple_model):
+ observed_vars = {'10': np.array([0,1]), '5': np.array([0,1])}
+ infer = BeliefPropagation(simple_model)
+ infer.query(evidence=observed_vars)
+
+ expected = {'5'}
+ observed = simple_model.reachable_observed_variables(
+ source='9',
+ observed=set(observed_vars.keys())
+ )
+ assert expected == observed
+
+
+def test_reachable_observed_vars_indirect_causal(simple_model):
+ observed_vars = {'1': np.array([0,1]), '2': np.array([1,0])}
+ infer = BeliefPropagation(simple_model)
+ infer.query(evidence=observed_vars)
+
+ expected = {'1', '2'}
+ observed = simple_model.reachable_observed_variables(
+ source='9',
+ observed=set(observed_vars.keys())
+ )
+ assert expected == observed
+
+
+def test_reachable_observed_vars_blocked_causal(simple_model):
+ observed_vars = {'1': np.array([0,1]), '2': np.array([1,0]), '3': np.array([0,1])}
+ infer = BeliefPropagation(simple_model)
+ infer.query(evidence=observed_vars)
+
+ expected = {'3'}
+ observed = simple_model.reachable_observed_variables(
+ source='9',
+ observed=set(observed_vars.keys())
+ )
+ assert expected == observed
+
+
+def test_reachable_observed_vars_indirect_evidential(simple_model):
+ observed_vars = {'13': np.array([1,0])}
+ infer = BeliefPropagation(simple_model)
+ infer.query(evidence=observed_vars)
+
+ expected = {'13'}
+ observed = simple_model.reachable_observed_variables(
+ source='9',
+ observed=set(observed_vars.keys())
+ )
+ assert expected == observed
+
+
+def test_reachable_observed_vars_blocked_evidential(simple_model):
+ observed_vars = {'x': np.array([1,0]), '13': np.array([1,0])}
+ infer = BeliefPropagation(simple_model)
+ infer.query(evidence=observed_vars)
+
+ expected = {'x'}
+ observed = simple_model.reachable_observed_variables(
+ source='9',
+ observed=set(observed_vars.keys())
+ )
+ assert expected == observed
+
+
+def test_get_reachable_obs_vars_for_inferred(simple_model):
+ observed_vars = {'6': np.array([1,0]), '7': np.array([1,0]), '10': np.array([1,0])}
+ infer = BeliefPropagation(simple_model)
+ infer.query(evidence=observed_vars)
+
+ print(set(simple_model.get_unobserved_variables_in_definite_state(observed_vars.keys())))
+ print(simple_model._get_ancestors_of(set(observed_vars.keys())))
+ expected = {'4': {'10'}, '1': {'10'}, '11': {'7', '6', '10'}, '2': {'10'},
+ '8': {'7', '6'}, '5': {'10'}, '3': {'10'}, '9': {'7', '6', '10'}}
+
+ observed = get_reachable_observed_variables_for_inferred_variables(
+ model=simple_model,
+ observed=set(observed_vars.keys())
+ )
+ assert expected == observed