aboutsummaryrefslogtreecommitdiff
path: root/beliefs
diff options
context:
space:
mode:
authorCathy Yeh <cathy@driver.xyz>2017-12-12 21:28:26 -0800
committerCathy Yeh <cathy@driver.xyz>2017-12-13 18:45:03 -0800
commit10f5c49ea6767f54d59f88eb4064bb4959d14c6b (patch)
tree7ce485c2e678c593fd474537e3a125cd783c79cc /beliefs
parentb3b8bb68d6d590175a07dfc4022b4903d63222e5 (diff)
downloadbeliefs-10f5c49ea6767f54d59f88eb4064bb4959d14c6b.tar.gz
beliefs-10f5c49ea6767f54d59f88eb4064bb4959d14c6b.tar.bz2
beliefs-10f5c49ea6767f54d59f88eb4064bb4959d14c6b.zip
implement explicit factor methods for compute_pi_agg and compute_lambda_msg_to_parent in Node
Diffstat (limited to 'beliefs')
-rw-r--r--beliefs/factors/discrete_factor.py7
-rw-r--r--beliefs/models/belief_update_node_model.py37
2 files changed, 33 insertions, 11 deletions
diff --git a/beliefs/factors/discrete_factor.py b/beliefs/factors/discrete_factor.py
index da8e6bf..b75da28 100644
--- a/beliefs/factors/discrete_factor.py
+++ b/beliefs/factors/discrete_factor.py
@@ -86,9 +86,16 @@ class DiscreteFactor:
right = copy.deepcopy(other)
left.add_new_variables_from_other_factor(right)
right.add_new_variables_from_other_factor(left)
+ print('var', left.variables)
+ print(left.cardinality)
+ print(left.values)
+ print('var', right.variables)
+ print(right.cardinality)
+ print(right.values)
# reorder variables in right factor to match order in left
source_axes = list(range(right.values.ndim))
+ print('source_axes', source_axes)
destination_axes = [right.variables.index(var) for var in left.variables]
right.variables = [right.variables[idx] for idx in destination_axes]
diff --git a/beliefs/models/belief_update_node_model.py b/beliefs/models/belief_update_node_model.py
index cd8ba8c..17e98fa 100644
--- a/beliefs/models/belief_update_node_model.py
+++ b/beliefs/models/belief_update_node_model.py
@@ -205,25 +205,30 @@ class Node:
return msgs
def compute_pi_agg(self):
- # TODO: implement explict factor product operation
- raise NotImplementedError
+ if len(self.parents) == 0:
+ self.update_pi_agg(self.cpd.values)
+ else:
+ factors_to_multiply = [self.cpd]
+ pi_msgs = self.validate_and_return_msgs_received_for_msg_type(MessageType.PI)
+ factors_to_multiply.extend(pi_msgs)
+
+ factor_product = reduce(lambda phi1, phi2: phi1*phi2, factors_to_multiply)
+ self.update_pi_agg(factor_product.marginalize(self.parents).values)
+ pi_msgs = self.validate_and_return_msgs_received_for_msg_type(MessageType.PI)
def compute_lambda_agg(self):
- if len(self.children) == 0:
- return self.lambda_agg.values
- else:
+ if len(self.children) != 0:
lambda_msg_values = [
msg.values for msg in
self.validate_and_return_msgs_received_for_msg_type(MessageType.LAMBDA)
]
self.update_lambda_agg(reduce(np.multiply, lambda_msg_values))
- return self.lambda_agg.values
def update_pi_agg(self, new_value):
- self.pi_agg.update_values(np.array(new_value).reshape(self.cardinality))
+ self.pi_agg.update_values(new_value)
def update_lambda_agg(self, new_value):
- self.lambda_agg.update_values(np.array(new_value).reshape(self.cardinality))
+ self.lambda_agg.update_values(new_value)
def _update_received_msg_by_key(self, received_msg_dict, key, new_value, message_type):
if key not in received_msg_dict.keys():
@@ -242,7 +247,8 @@ class Node:
if new_value.shape != expected_shape:
raise ValueError("Expected new value to be of dimensions ({},) but got {} instead"
.format(expected_shape, new_value.shape))
- received_msg_dict[key]._values = new_value
+ # received_msg_dict[key]._values = new_value
+ received_msg_dict[key].update_values(new_value)
def update_pi_msg_from_parent(self, parent, new_value):
self._update_received_msg_by_key(received_msg_dict=self.pi_received_msgs,
@@ -267,8 +273,17 @@ class Node:
raise ValueError("Can't compute pi message to child_{} without having received a lambda message from that child.")
def compute_lambda_msg_to_parent(self, parent_k):
- # TODO: implement explict factor product operation
- raise NotImplementedError
+ if np.array_equal(self.lambda_agg.values, np.ones([self.cardinality])):
+ return np.ones([self.cardinality])
+ else:
+ factors_to_multiply = [self.cpd]
+ pi_msgs_excl_k = [msg for par_id, msg in self.pi_received_msgs.items()
+ if par_id != parent_k]
+ factors_to_multiply.extend(pi_msgs_excl_k)
+ factor_product = reduce(lambda phi1, phi2: phi1*phi2, factors_to_multiply)
+ new_factor = factor_product.marginalize(list(set(self.parents) - set([parent_k])))
+ lambda_msg_to_k = (self.lambda_agg * new_factor).marginalize([self.lambda_agg.variables[0]])
+ return self._normalize(lambda_msg_to_k.values)
@property
def is_fully_initialized(self):