aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCathy Yeh <cathy@driver.xyz>2017-12-12 21:28:26 -0800
committerCathy Yeh <cathy@driver.xyz>2017-12-13 18:45:03 -0800
commit10f5c49ea6767f54d59f88eb4064bb4959d14c6b (patch)
tree7ce485c2e678c593fd474537e3a125cd783c79cc
parentb3b8bb68d6d590175a07dfc4022b4903d63222e5 (diff)
downloadbeliefs-10f5c49ea6767f54d59f88eb4064bb4959d14c6b.tar.gz
beliefs-10f5c49ea6767f54d59f88eb4064bb4959d14c6b.tar.bz2
beliefs-10f5c49ea6767f54d59f88eb4064bb4959d14c6b.zip
implement explicit factor methods for compute_pi_agg and compute_lambda_msg_to_parent in Node
-rw-r--r--beliefs/factors/discrete_factor.py7
-rw-r--r--beliefs/models/belief_update_node_model.py37
-rw-r--r--tests/test_belief_propagation.py64
3 files changed, 96 insertions, 12 deletions
diff --git a/beliefs/factors/discrete_factor.py b/beliefs/factors/discrete_factor.py
index da8e6bf..b75da28 100644
--- a/beliefs/factors/discrete_factor.py
+++ b/beliefs/factors/discrete_factor.py
@@ -86,9 +86,16 @@ class DiscreteFactor:
right = copy.deepcopy(other)
left.add_new_variables_from_other_factor(right)
right.add_new_variables_from_other_factor(left)
+ print('var', left.variables)
+ print(left.cardinality)
+ print(left.values)
+ print('var', right.variables)
+ print(right.cardinality)
+ print(right.values)
# reorder variables in right factor to match order in left
source_axes = list(range(right.values.ndim))
+ print('source_axes', source_axes)
destination_axes = [right.variables.index(var) for var in left.variables]
right.variables = [right.variables[idx] for idx in destination_axes]
diff --git a/beliefs/models/belief_update_node_model.py b/beliefs/models/belief_update_node_model.py
index cd8ba8c..17e98fa 100644
--- a/beliefs/models/belief_update_node_model.py
+++ b/beliefs/models/belief_update_node_model.py
@@ -205,25 +205,30 @@ class Node:
return msgs
def compute_pi_agg(self):
- # TODO: implement explict factor product operation
- raise NotImplementedError
+ if len(self.parents) == 0:
+ self.update_pi_agg(self.cpd.values)
+ else:
+ factors_to_multiply = [self.cpd]
+ pi_msgs = self.validate_and_return_msgs_received_for_msg_type(MessageType.PI)
+ factors_to_multiply.extend(pi_msgs)
+
+ factor_product = reduce(lambda phi1, phi2: phi1*phi2, factors_to_multiply)
+ self.update_pi_agg(factor_product.marginalize(self.parents).values)
+ pi_msgs = self.validate_and_return_msgs_received_for_msg_type(MessageType.PI)
def compute_lambda_agg(self):
- if len(self.children) == 0:
- return self.lambda_agg.values
- else:
+ if len(self.children) != 0:
lambda_msg_values = [
msg.values for msg in
self.validate_and_return_msgs_received_for_msg_type(MessageType.LAMBDA)
]
self.update_lambda_agg(reduce(np.multiply, lambda_msg_values))
- return self.lambda_agg.values
def update_pi_agg(self, new_value):
- self.pi_agg.update_values(np.array(new_value).reshape(self.cardinality))
+ self.pi_agg.update_values(new_value)
def update_lambda_agg(self, new_value):
- self.lambda_agg.update_values(np.array(new_value).reshape(self.cardinality))
+ self.lambda_agg.update_values(new_value)
def _update_received_msg_by_key(self, received_msg_dict, key, new_value, message_type):
if key not in received_msg_dict.keys():
@@ -242,7 +247,8 @@ class Node:
if new_value.shape != expected_shape:
raise ValueError("Expected new value to be of dimensions ({},) but got {} instead"
.format(expected_shape, new_value.shape))
- received_msg_dict[key]._values = new_value
+ # received_msg_dict[key]._values = new_value
+ received_msg_dict[key].update_values(new_value)
def update_pi_msg_from_parent(self, parent, new_value):
self._update_received_msg_by_key(received_msg_dict=self.pi_received_msgs,
@@ -267,8 +273,17 @@ class Node:
raise ValueError("Can't compute pi message to child_{} without having received a lambda message from that child.")
def compute_lambda_msg_to_parent(self, parent_k):
- # TODO: implement explict factor product operation
- raise NotImplementedError
+ if np.array_equal(self.lambda_agg.values, np.ones([self.cardinality])):
+ return np.ones([self.cardinality])
+ else:
+ factors_to_multiply = [self.cpd]
+ pi_msgs_excl_k = [msg for par_id, msg in self.pi_received_msgs.items()
+ if par_id != parent_k]
+ factors_to_multiply.extend(pi_msgs_excl_k)
+ factor_product = reduce(lambda phi1, phi2: phi1*phi2, factors_to_multiply)
+ new_factor = factor_product.marginalize(list(set(self.parents) - set([parent_k])))
+ lambda_msg_to_k = (self.lambda_agg * new_factor).marginalize([self.lambda_agg.variables[0]])
+ return self._normalize(lambda_msg_to_k.values)
@property
def is_fully_initialized(self):
diff --git a/tests/test_belief_propagation.py b/tests/test_belief_propagation.py
index 7a77311..1b8c0ac 100644
--- a/tests/test_belief_propagation.py
+++ b/tests/test_belief_propagation.py
@@ -3,10 +3,12 @@ import pytest
from pytest import approx
from beliefs.inference.belief_propagation import BeliefPropagation, ConflictingEvidenceError
+from beliefs.factors.cpd import TabularCPD
from beliefs.models.belief_update_node_model import (
BeliefUpdateNodeModel,
BernoulliOrNode,
- BernoulliAndNode
+ BernoulliAndNode,
+ Node
)
@@ -89,6 +91,41 @@ def mixed_cpd_model(edges_five_nodes):
'w': w_node})
+@pytest.fixture(scope='function')
+def custom_cpd_model():
+ """
+ Y-shaped model, with parents ,'u' and 'v' as Or-nodes, 'x' a node with
+ cardinality 3 and custom CPD, 'y' a node with cardinality 2 and custom CPD.
+ """
+ custom_cpd_x = TabularCPD(variable='x',
+ variable_card=3,
+ parents=['u', 'v'],
+ parents_card=[2, 2],
+ values=[[0.2, 0, 0.3, 0.1],
+ [0.4, 1, 0.7, 0.2],
+ [0.4, 0, 0, 0.7]],
+ state_names={'x': ['lo', 'med', 'hi'],
+ 'u': ['False', 'True'],
+ 'v': ['False', 'True']})
+ custom_cpd_y = TabularCPD(variable='y',
+ variable_card=2,
+ parents=['x'],
+ parents_card=[3],
+ values=[[0.3, 0.1, 0],
+ [0.7, 0.9, 1]],
+ state_names={'x': ['lo', 'med', 'hi'],
+ 'y': ['False', 'True']})
+
+ u_node = BernoulliOrNode(label_id='u', children=['x'], parents=[])
+ v_node = BernoulliOrNode(label_id='v', children=['x'], parents=[])
+ x_node = Node(children=['y'], cpd=custom_cpd_x)
+ y_node = Node(children=[], cpd=custom_cpd_y)
+ return BeliefUpdateNodeModel(nodes_dict={'u': u_node,
+ 'v': v_node,
+ 'x': x_node,
+ 'y': y_node})
+
+
def get_label_mapped_to_positive_belief(query_result):
"""Return a dictionary mapping each label_id to the probability of
the label being True."""
@@ -355,3 +392,28 @@ def test_conflicting_evidence_and_model(many_parents_and_model):
with pytest.raises(ConflictingEvidenceError) as err:
query_result = infer.query(evidence={'62': np.array([0, 1]), '112': np.array([1, 0])})
assert "Can't run belief propagation with conflicting evidence" in str(err)
+
+
+#==============================================================================================
+# Model with two custom cpds
+
+
+def test_no_evidence_custom_cpd_model(custom_cpd_model):
+ expected = {'x': np.array([0.15, 0.575, 0.275]),
+ 'v': np.array([0.5, 0.5]),
+ 'u': np.array([0.5, 0.5]),
+ 'y': np.array([0.1025, 0.8975])}
+ infer = BeliefPropagation(custom_cpd_model)
+ query_result = infer.query(evidence={})
+ compare_dictionaries(expected, query_result)
+
+
+def test_evidence_custom_cpd_model(custom_cpd_model):
+ """Custom node is observed to be in 'med' state."""
+ expected = {'x': np.array([0., 1., 0.]),
+ 'u': np.array([0.60869565, 0.39130435]),
+ 'v': np.array([0.47826087, 0.52173913]),
+ 'y': np.array([0.1, 0.9])}
+ infer = BeliefPropagation(custom_cpd_model)
+ query_result = infer.query(evidence={'x': np.array([0, 1, 0])})
+ compare_dictionaries(expected, query_result)