aboutsummaryrefslogtreecommitdiff
path: root/beliefs/models/belief_update_node_model.py
diff options
context:
space:
mode:
Diffstat (limited to 'beliefs/models/belief_update_node_model.py')
-rw-r--r--beliefs/models/belief_update_node_model.py37
1 files changed, 26 insertions, 11 deletions
diff --git a/beliefs/models/belief_update_node_model.py b/beliefs/models/belief_update_node_model.py
index cd8ba8c..17e98fa 100644
--- a/beliefs/models/belief_update_node_model.py
+++ b/beliefs/models/belief_update_node_model.py
@@ -205,25 +205,30 @@ class Node:
return msgs
def compute_pi_agg(self):
- # TODO: implement explict factor product operation
- raise NotImplementedError
+ if len(self.parents) == 0:
+ self.update_pi_agg(self.cpd.values)
+ else:
+ factors_to_multiply = [self.cpd]
+ pi_msgs = self.validate_and_return_msgs_received_for_msg_type(MessageType.PI)
+ factors_to_multiply.extend(pi_msgs)
+
+ factor_product = reduce(lambda phi1, phi2: phi1*phi2, factors_to_multiply)
+ self.update_pi_agg(factor_product.marginalize(self.parents).values)
+ pi_msgs = self.validate_and_return_msgs_received_for_msg_type(MessageType.PI)
def compute_lambda_agg(self):
- if len(self.children) == 0:
- return self.lambda_agg.values
- else:
+ if len(self.children) != 0:
lambda_msg_values = [
msg.values for msg in
self.validate_and_return_msgs_received_for_msg_type(MessageType.LAMBDA)
]
self.update_lambda_agg(reduce(np.multiply, lambda_msg_values))
- return self.lambda_agg.values
def update_pi_agg(self, new_value):
- self.pi_agg.update_values(np.array(new_value).reshape(self.cardinality))
+ self.pi_agg.update_values(new_value)
def update_lambda_agg(self, new_value):
- self.lambda_agg.update_values(np.array(new_value).reshape(self.cardinality))
+ self.lambda_agg.update_values(new_value)
def _update_received_msg_by_key(self, received_msg_dict, key, new_value, message_type):
if key not in received_msg_dict.keys():
@@ -242,7 +247,8 @@ class Node:
if new_value.shape != expected_shape:
raise ValueError("Expected new value to be of dimensions ({},) but got {} instead"
.format(expected_shape, new_value.shape))
- received_msg_dict[key]._values = new_value
+ # received_msg_dict[key]._values = new_value
+ received_msg_dict[key].update_values(new_value)
def update_pi_msg_from_parent(self, parent, new_value):
self._update_received_msg_by_key(received_msg_dict=self.pi_received_msgs,
@@ -267,8 +273,17 @@ class Node:
raise ValueError("Can't compute pi message to child_{} without having received a lambda message from that child.")
def compute_lambda_msg_to_parent(self, parent_k):
- # TODO: implement explict factor product operation
- raise NotImplementedError
+ if np.array_equal(self.lambda_agg.values, np.ones([self.cardinality])):
+ return np.ones([self.cardinality])
+ else:
+ factors_to_multiply = [self.cpd]
+ pi_msgs_excl_k = [msg for par_id, msg in self.pi_received_msgs.items()
+ if par_id != parent_k]
+ factors_to_multiply.extend(pi_msgs_excl_k)
+ factor_product = reduce(lambda phi1, phi2: phi1*phi2, factors_to_multiply)
+ new_factor = factor_product.marginalize(list(set(self.parents) - set([parent_k])))
+ lambda_msg_to_k = (self.lambda_agg * new_factor).marginalize([self.lambda_agg.variables[0]])
+ return self._normalize(lambda_msg_to_k.values)
@property
def is_fully_initialized(self):