1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
|
import numpy as np
from test_belief_propagation import simple_model, simple_edges
from beliefs.inference.belief_propagation import BeliefPropagation
from beliefs.utils.random_variables import (
get_reachable_observed_variables_for_inferred_variables
)
def test_reachable_observed_vars_direct_common_effect(simple_model):
    """With '14' and 'x' observed, both are expected to be reachable from '9'.

    NOTE(review): the graph comes from the ``simple_model`` fixture, which is
    defined outside this file; the "direct common effect" reading is taken
    from the test name only -- confirm against the fixture's edges.
    """
    evidence = {
        '14': np.array([1, 0]),
        'x': np.array([1, 0]),
    }

    # Run inference with the evidence before checking reachability
    # (presumably this registers the observations on the model -- confirm).
    engine = BeliefPropagation(simple_model)
    engine.query(evidence=evidence)

    reachable = simple_model.reachable_observed_variables(
        source='9',
        observed=set(evidence),
    )
    expected = {'x', '14'}
    assert expected == reachable
def test_reachable_observed_vars_indirect_common_effect(simple_model):
    """With '12' and '14' observed, both are expected to be reachable from '9'.

    NOTE(review): the "indirect common effect" interpretation comes from the
    test name; the graph itself lives in the ``simple_model`` fixture, which
    is not visible here -- confirm against its edges.
    """
    evidence = {
        '12': np.array([1, 0]),
        '14': np.array([1, 0]),
    }

    # Run inference with the evidence before checking reachability
    # (presumably this registers the observations on the model -- confirm).
    engine = BeliefPropagation(simple_model)
    engine.query(evidence=evidence)

    reachable = simple_model.reachable_observed_variables(
        source='9',
        observed=set(evidence),
    )
    expected = {'12', '14'}
    assert expected == reachable
def test_reachable_observed_vars_common_cause(simple_model):
    """With only '10' observed, '10' is expected to be reachable from '9'.

    NOTE(review): the "common cause" interpretation comes from the test name;
    the graph is defined by the ``simple_model`` fixture outside this file --
    confirm against its edges.
    """
    evidence = {'10': np.array([0, 1])}

    # Run inference with the evidence before checking reachability
    # (presumably this registers the observations on the model -- confirm).
    engine = BeliefPropagation(simple_model)
    engine.query(evidence=evidence)

    reachable = simple_model.reachable_observed_variables(
        source='9',
        observed=set(evidence),
    )
    expected = {'10'}
    assert expected == reachable
def test_reachable_observed_vars_blocked_common_cause(simple_model):
    """With '10' and '5' observed, only '5' is expected to be reachable from '9'.

    NOTE(review): the test name suggests that observing '5' blocks the path
    from '9' to '10'; the graph is defined by the ``simple_model`` fixture
    outside this file -- confirm against its edges.
    """
    evidence = {
        '10': np.array([0, 1]),
        '5': np.array([0, 1]),
    }

    # Run inference with the evidence before checking reachability
    # (presumably this registers the observations on the model -- confirm).
    engine = BeliefPropagation(simple_model)
    engine.query(evidence=evidence)

    reachable = simple_model.reachable_observed_variables(
        source='9',
        observed=set(evidence),
    )
    expected = {'5'}
    assert expected == reachable
def test_reachable_observed_vars_indirect_causal(simple_model):
    """With '1' and '2' observed, both are expected to be reachable from '9'.

    NOTE(review): the "indirect causal" interpretation comes from the test
    name; the graph is defined by the ``simple_model`` fixture outside this
    file -- confirm against its edges.
    """
    evidence = {
        '1': np.array([0, 1]),
        '2': np.array([1, 0]),
    }

    # Run inference with the evidence before checking reachability
    # (presumably this registers the observations on the model -- confirm).
    engine = BeliefPropagation(simple_model)
    engine.query(evidence=evidence)

    reachable = simple_model.reachable_observed_variables(
        source='9',
        observed=set(evidence),
    )
    expected = {'1', '2'}
    assert expected == reachable
def test_reachable_observed_vars_blocked_causal(simple_model):
    """With '1', '2' and '3' observed, only '3' is expected to be reachable from '9'.

    NOTE(review): the test name suggests that observing '3' blocks the causal
    path from '1' and '2' to '9'; the graph is defined by the ``simple_model``
    fixture outside this file -- confirm against its edges.
    """
    evidence = {
        '1': np.array([0, 1]),
        '2': np.array([1, 0]),
        '3': np.array([0, 1]),
    }

    # Run inference with the evidence before checking reachability
    # (presumably this registers the observations on the model -- confirm).
    engine = BeliefPropagation(simple_model)
    engine.query(evidence=evidence)

    reachable = simple_model.reachable_observed_variables(
        source='9',
        observed=set(evidence),
    )
    expected = {'3'}
    assert expected == reachable
def test_reachable_observed_vars_indirect_evidential(simple_model):
    """With only '13' observed, '13' is expected to be reachable from '9'.

    NOTE(review): the "indirect evidential" interpretation comes from the
    test name; the graph is defined by the ``simple_model`` fixture outside
    this file -- confirm against its edges.
    """
    evidence = {'13': np.array([1, 0])}

    # Run inference with the evidence before checking reachability
    # (presumably this registers the observations on the model -- confirm).
    engine = BeliefPropagation(simple_model)
    engine.query(evidence=evidence)

    reachable = simple_model.reachable_observed_variables(
        source='9',
        observed=set(evidence),
    )
    expected = {'13'}
    assert expected == reachable
def test_reachable_observed_vars_blocked_evidential(simple_model):
    """With 'x' and '13' observed, only 'x' is expected to be reachable from '9'.

    NOTE(review): the test name suggests that observing 'x' blocks the
    evidential path from '13' to '9'; the graph is defined by the
    ``simple_model`` fixture outside this file -- confirm against its edges.
    """
    evidence = {
        'x': np.array([1, 0]),
        '13': np.array([1, 0]),
    }

    # Run inference with the evidence before checking reachability
    # (presumably this registers the observations on the model -- confirm).
    engine = BeliefPropagation(simple_model)
    engine.query(evidence=evidence)

    reachable = simple_model.reachable_observed_variables(
        source='9',
        observed=set(evidence),
    )
    expected = {'x'}
    assert expected == reachable
def test_get_reachable_obs_vars_for_inferred(simple_model):
    """Check the variable -> reachable-observed-variables mapping.

    With '6', '7' and '10' observed, the helper is expected to return a dict
    whose keys are variable names and whose values are the sets of observed
    variables reachable from each of them.

    NOTE(review): per the helper's name the keys are presumably the
    variables inferred to be in a definite state; the graph is defined by the
    ``simple_model`` fixture outside this file -- confirm against it.
    """
    observed_vars = {
        '6': np.array([1, 0]),
        '7': np.array([1, 0]),
        '10': np.array([1, 0]),
    }

    # Run inference with the evidence before computing reachability
    # (presumably this registers the observations on the model -- confirm).
    infer = BeliefPropagation(simple_model)
    infer.query(evidence=observed_vars)

    expected = {
        '1': {'10'},
        '2': {'10'},
        '3': {'10'},
        '4': {'10'},
        '5': {'10'},
        '8': {'7', '6'},
        '9': {'7', '6', '10'},
        '11': {'7', '6', '10'},
    }
    observed = get_reachable_observed_variables_for_inferred_variables(
        model=simple_model,
        observed=set(observed_vars.keys()),
    )
    assert expected == observed
|