aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCathy Yeh <cathy@driver.xyz>2017-12-12 16:11:54 -0800
committerCathy Yeh <cathy@driver.xyz>2017-12-13 18:45:03 -0800
commit70bdf07d25f41de1a9510b64267bfa29791760c7 (patch)
tree476c8bf2264f6a3a187e9f3351c91855ee524fdc
parentf6ab3e7b918396dee70dc4ff2dc3a1341aaeb97b (diff)
downloadbeliefs-70bdf07d25f41de1a9510b64267bfa29791760c7.tar.gz
beliefs-70bdf07d25f41de1a9510b64267bfa29791760c7.tar.bz2
beliefs-70bdf07d25f41de1a9510b64267bfa29791760c7.zip
change all msg datatypes from np.array -> DiscreteFactor
-rw-r--r--beliefs/inference/belief_propagation.py20
-rw-r--r--beliefs/models/belief_update_node_model.py110
2 files changed, 77 insertions, 53 deletions
diff --git a/beliefs/inference/belief_propagation.py b/beliefs/inference/belief_propagation.py
index 7ec648d..128f645 100644
--- a/beliefs/inference/belief_propagation.py
+++ b/beliefs/inference/belief_propagation.py
@@ -74,9 +74,9 @@ class BeliefPropagation:
if node_to_update_label_id not in evidence:
node.compute_pi_agg()
- logging.info("belief propagation pi_agg: %s", np.array2string(node.pi_agg))
+ logging.info("belief propagation pi_agg: %s", np.array2string(node.pi_agg.values))
node.compute_lambda_agg()
- logging.info("belief propagation lambda_agg: %s", np.array2string(node.lambda_agg))
+ logging.info("belief propagation lambda_agg: %s", np.array2string(node.lambda_agg.values))
for parent_id in parent_ids:
try:
@@ -97,7 +97,6 @@ class BeliefPropagation:
new_value=new_pi_msg)
nodes_to_update.add(MsgPassers(msg_receiver=child_id,
msg_sender=node_to_update_label_id))
-
self._belief_propagation(nodes_to_update, evidence)
def initialize_model(self):
@@ -115,8 +114,8 @@ class BeliefPropagation:
for node in self.model.nodes_dict.values():
ones_vector = np.ones([node.cardinality])
+ node.update_lambda_agg(ones_vector)
- node.lambda_agg = ones_vector
for child in node.lambda_received_msgs.keys():
node.update_lambda_msg_from_child(child=child,
new_value=ones_vector)
@@ -131,7 +130,7 @@ class BeliefPropagation:
node_sending_msg = self.model.nodes_dict[node_id]
child_ids = node_sending_msg.children
- if node_sending_msg.pi_agg is None:
+ if node_sending_msg.pi_agg.values is None:
node_sending_msg.compute_pi_agg()
for child_id in child_ids:
@@ -150,22 +149,19 @@ class BeliefPropagation:
a dict key, value pair as {var: state_of_var observed}
"""
for evidence_id, observed_value in evidence.items():
- nodes_to_update = set()
-
if evidence_id not in self.model.nodes_dict.keys():
raise KeyError("Evidence supplied for non-existent label_id: {}"
.format(evidence_id))
if is_kronecker_delta(observed_value):
# specific evidence
- self.model.nodes_dict[evidence_id].lambda_agg = observed_value
+ self.model.nodes_dict[evidence_id].update_lambda_agg(observed_value)
else:
# virtual evidence
- self.model.nodes_dict[evidence_id].lambda_agg = \
- self.model.nodes_dict[evidence_id].lambda_agg * observed_value
-
+ self.model.nodes_dict[evidence_id].update_lambda_agg(
+ self.model.nodes_dict[evidence_id].lambda_agg.values * observed_value
+ )
nodes_to_update = [MsgPassers(msg_receiver=evidence_id, msg_sender=None)]
-
self._belief_propagation(nodes_to_update=set(nodes_to_update),
evidence=evidence)
diff --git a/beliefs/models/belief_update_node_model.py b/beliefs/models/belief_update_node_model.py
index 1c3ba6e..1765ed9 100644
--- a/beliefs/models/belief_update_node_model.py
+++ b/beliefs/models/belief_update_node_model.py
@@ -7,6 +7,7 @@ from functools import reduce
import networkx as nx
from beliefs.models.base_models import BayesianModel
+from beliefs.factors.discrete_factor import DiscreteFactor
from beliefs.factors.bernoulli_or_cpd import BernoulliOrCPD
from beliefs.factors.bernoulli_and_cpd import BernoulliAndCPD
@@ -88,10 +89,10 @@ class BeliefUpdateNodeModel(BayesianModel):
to an (unnormalized) unit vector, of length the cardinality of x.
"""
for root in self.get_roots():
- self.nodes_dict[root].pi_agg = self.nodes_dict[root].cpd.values
+ self.nodes_dict[root].update_pi_agg(self.nodes_dict[root].cpd.values)
for leaf in self.get_leaves():
- self.nodes_dict[leaf].lambda_agg = np.ones([self.nodes_dict[leaf].cardinality])
+ self.nodes_dict[leaf].update_lambda_agg(np.ones([self.nodes_dict[leaf].cardinality]))
@property
def all_nodes_are_fully_initialized(self):
@@ -135,17 +136,18 @@ class Node:
cpd: an instance of a conditional probability distribution,
e.g. BernoulliOrCPD or TabularCPD
"""
- self.label_id = label_id
+ self.label_id = label_id # this can be obtained from cpd.variable
self.children = children
- self.parents = parents
- self.cardinality = cardinality
+ self.parents = parents # this can be obtained from cpd.variables[1:]
+ self.cardinality = cardinality # this can be obtained from cpd.cardinality[0]
self.cpd = cpd
- self.pi_agg = None # np.array dimensions [1, cardinality]
- self.lambda_agg = None # np.array dimensions [1, cardinality]
+ # instances of DiscreteFactor with `values` an np.array of dimensions [1, cardinality]
+ self.pi_agg = self._init_aggregate_values()
+ self.lambda_agg = self._init_aggregate_values()
- self.pi_received_msgs = self._init_received_msgs(parents)
- self.lambda_received_msgs = self._init_received_msgs(children)
+ self.pi_received_msgs = self._init_pi_received_msgs(parents)
+ self.lambda_received_msgs = {child: self._init_aggregate_values() for child in children}
@classmethod
def from_cpd_class(cls,
@@ -159,8 +161,8 @@ class Node:
@property
def belief(self):
- if self.pi_agg.any() and self.lambda_agg.any():
- belief = np.multiply(self.pi_agg, self.lambda_agg)
+ if any(self.pi_agg.values) and any(self.lambda_agg.values):
+ belief = (self.lambda_agg * self.pi_agg).values
return self._normalize(belief)
else:
return None
@@ -168,9 +170,21 @@ class Node:
def _normalize(self, value):
return value/value.sum()
- @staticmethod
- def _init_received_msgs(keys):
- return {k: None for k in keys}
+ def _init_aggregate_values(self):
+ return DiscreteFactor(variables=[self.cpd.variable],
+ cardinality=[self.cardinality],
+ values=None,
+ state_names=None)
+
+ def _init_pi_received_msgs(self, parents):
+ msgs = {}
+ for k in parents:
+ kth_cardinality = self.cpd.cardinality[self.cpd.variables.index(k)]
+ msgs[k] = DiscreteFactor(variables=[k],
+ cardinality=[kth_cardinality],
+ values=None,
+ state_names=None)
+ return msgs
def _return_msgs_received_for_msg_type(self, message_type):
"""
@@ -181,9 +195,9 @@ class Node:
msg_values: list of message values (each an np.array)
"""
if message_type == MessageType.LAMBDA:
- msg_values = [msg for msg in self.lambda_received_msgs.values()]
+ msg_values = [msg.values for msg in self.lambda_received_msgs.values()]
elif message_type == MessageType.PI:
- msg_values = [msg for msg in self.pi_received_msgs.values()]
+ msg_values = [msg.values for msg in self.pi_received_msgs.values()]
return msg_values
def validate_and_return_msgs_received_for_msg_type(self, message_type):
@@ -214,13 +228,20 @@ class Node:
def compute_lambda_agg(self):
if len(self.children) == 0:
- return self.lambda_agg
+ return self.lambda_agg.values
else:
- lambda_msg_values = self.validate_and_return_msgs_received_for_msg_type(MessageType.LAMBDA)
- self.lambda_agg = reduce(np.multiply, lambda_msg_values)
- return self.lambda_agg
+ lambda_msg_values =\
+ self.validate_and_return_msgs_received_for_msg_type(MessageType.LAMBDA)
+ self.update_lambda_agg(reduce(np.multiply, lambda_msg_values))
+ return self.lambda_agg.values
+
+ def update_pi_agg(self, new_value):
+ self.pi_agg.update_values(np.array(new_value).reshape(self.cardinality))
+
+ def update_lambda_agg(self, new_value):
+ self.lambda_agg.update_values(np.array(new_value).reshape(self.cardinality))
- def _update_received_msg_by_key(self, received_msg_dict, key, new_value):
+ def _update_received_msg_by_key(self, received_msg_dict, key, new_value, message_type):
if key not in received_msg_dict.keys():
raise ValueError("Label id '{}' to update message isn't in allowed set of keys: {}"
.format(key, received_msg_dict.keys()))
@@ -229,23 +250,30 @@ class Node:
raise TypeError("Expected a new value of type numpy.ndarray, but got type {}"
.format(type(new_value)))
- if new_value.shape != (self.cardinality,):
- raise ValueError("Expected new value to be of dimensions ({},) but got {} instead"
- .format(self.cardinality, new_value.shape))
- received_msg_dict[key] = new_value
+ if message_type == MessageType.LAMBDA:
+ expected_shape = (self.cardinality,)
+ elif message_type == MessageType.PI:
+ expected_shape = (self.cpd.cardinality[self.cpd.variables.index(key)],)
+
+ if new_value.shape != expected_shape:
+ raise ValueError("Expected new value to be of dimensions ({},) but got {} instead"
+ .format(expected_shape, new_value.shape))
+ received_msg_dict[key]._values = new_value
def update_pi_msg_from_parent(self, parent, new_value):
self._update_received_msg_by_key(received_msg_dict=self.pi_received_msgs,
key=parent,
- new_value=new_value)
+ new_value=new_value,
+ message_type=MessageType.PI)
def update_lambda_msg_from_child(self, child, new_value):
self._update_received_msg_by_key(received_msg_dict=self.lambda_received_msgs,
key=child,
- new_value=new_value)
+ new_value=new_value,
+ message_type=MessageType.LAMBDA)
def compute_pi_msg_to_child(self, child_k):
- lambda_msg_from_child = self.lambda_received_msgs[child_k]
+ lambda_msg_from_child = self.lambda_received_msgs[child_k].values
if lambda_msg_from_child is not None:
with np.errstate(divide='ignore', invalid='ignore'):
# 0/0 := 0
@@ -272,7 +300,7 @@ class Node:
if any(msg is None for msg in pi_msgs):
return False
- if (self.pi_agg is None) or (self.lambda_agg is None):
+ if (self.pi_agg.values is None) or (self.lambda_agg.values is None):
return False
return True
@@ -291,7 +319,7 @@ class BernoulliOrNode(Node):
def compute_pi_agg(self):
if len(self.parents) == 0:
- self.pi_agg = self.cpd.values
+ self.update_pi_agg(self.cpd.values)
else:
pi_msg_values = self.validate_and_return_msgs_received_for_msg_type(MessageType.PI)
parents_p0 = [p[0] for p in pi_msg_values]
@@ -299,19 +327,19 @@ class BernoulliOrNode(Node):
# of p = [P(False), P(True)]
p_0 = reduce(lambda x, y: x*y, parents_p0)
p_1 = 1 - p_0
- self.pi_agg = np.array([p_0, p_1])
+ self.update_pi_agg(np.array([p_0, p_1]))
return self.pi_agg
def compute_lambda_msg_to_parent(self, parent_k):
- if np.array_equal(self.lambda_agg, np.ones([self.cardinality])):
+ if np.array_equal(self.lambda_agg.values, np.ones([self.cardinality])):
return np.ones([self.cardinality])
else:
# TODO: cleanup this validation
_ = self.validate_and_return_msgs_received_for_msg_type(MessageType.PI)
- p0_excluding_k = [p[0] for par_id, p in self.pi_received_msgs.items() if par_id != parent_k]
+ p0_excluding_k = [p.values[0] for par_id, p in self.pi_received_msgs.items() if par_id != parent_k]
p0_product = reduce(lambda x, y: x*y, p0_excluding_k, 1)
- lambda_0 = self.lambda_agg[1] + (self.lambda_agg[0] - self.lambda_agg[1])*p0_product
- lambda_1 = self.lambda_agg[1]
+ lambda_0 = self.lambda_agg.values[1] + (self.lambda_agg.values[0] - self.lambda_agg.values[1])*p0_product
+ lambda_1 = self.lambda_agg.values[1]
lambda_msg = np.array([lambda_0, lambda_1])
if not any(lambda_msg):
raise InvalidLambdaMsgToParent
@@ -331,7 +359,7 @@ class BernoulliAndNode(Node):
def compute_pi_agg(self):
if len(self.parents) == 0:
- self.pi_agg = self.cpd.values
+ self.update_pi_agg(self.cpd.values)
else:
pi_msg_values = self.validate_and_return_msgs_received_for_msg_type(MessageType.PI)
parents_p1 = [p[1] for p in pi_msg_values]
@@ -339,19 +367,19 @@ class BernoulliAndNode(Node):
# of p = [P(False), P(True)]
p_1 = reduce(lambda x, y: x*y, parents_p1)
p_0 = 1 - p_1
- self.pi_agg = np.array([p_0, p_1])
+ self.update_pi_agg(np.array([p_0, p_1]))
return self.pi_agg
def compute_lambda_msg_to_parent(self, parent_k):
- if np.array_equal(self.lambda_agg, np.ones([self.cardinality])):
+ if np.array_equal(self.lambda_agg.values, np.ones([self.cardinality])):
return np.ones([self.cardinality])
else:
# TODO: cleanup this validation
_ = self.validate_and_return_msgs_received_for_msg_type(MessageType.PI)
- p1_excluding_k = [p[1] for par_id, p in self.pi_received_msgs.items() if par_id != parent_k]
+ p1_excluding_k = [p.values[1] for par_id, p in self.pi_received_msgs.items() if par_id != parent_k]
p1_product = reduce(lambda x, y: x*y, p1_excluding_k, 1)
- lambda_0 = self.lambda_agg[0]
- lambda_1 = self.lambda_agg[0] + (self.lambda_agg[1] - self.lambda_agg[0])*p1_product
+ lambda_0 = self.lambda_agg.values[0]
+ lambda_1 = self.lambda_agg.values[0] + (self.lambda_agg.values[1] - self.lambda_agg.values[0])*p1_product
lambda_msg = np.array([lambda_0, lambda_1])
if not any(lambda_msg):
raise InvalidLambdaMsgToParent