aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCathy Yeh <cathy@driver.xyz>2017-12-12 19:58:12 -0800
committerCathy Yeh <cathy@driver.xyz>2017-12-13 18:45:03 -0800
commitb3b8bb68d6d590175a07dfc4022b4903d63222e5 (patch)
treedc4377a608e172994c773cdf4ebf079c104cfeed
parent2f4de4ae0b28e0e5ee2a5be6955366267fbc2404 (diff)
downloadbeliefs-b3b8bb68d6d590175a07dfc4022b4903d63222e5.tar.gz
beliefs-b3b8bb68d6d590175a07dfc4022b4903d63222e5.tar.bz2
beliefs-b3b8bb68d6d590175a07dfc4022b4903d63222e5.zip
Bernoulli Or/And Node access msg values by state names
-rw-r--r--beliefs/factors/bernoulli_and_cpd.py2
-rw-r--r--beliefs/factors/bernoulli_or_cpd.py2
-rw-r--r--beliefs/models/belief_update_node_model.py91
3 files changed, 69 insertions, 26 deletions
diff --git a/beliefs/factors/bernoulli_and_cpd.py b/beliefs/factors/bernoulli_and_cpd.py
index adf5ed5..15802c2 100644
--- a/beliefs/factors/bernoulli_and_cpd.py
+++ b/beliefs/factors/bernoulli_and_cpd.py
@@ -20,7 +20,7 @@ class BernoulliAndCPD(TabularCPD):
variable_card=2,
parents=parents,
parents_card=[2]*len(parents),
- values=[],
+ values=None,
state_names={var: ['False', 'True'] for var in [variable] + parents})
self._values = None
diff --git a/beliefs/factors/bernoulli_or_cpd.py b/beliefs/factors/bernoulli_or_cpd.py
index 6e01cf9..5b661a1 100644
--- a/beliefs/factors/bernoulli_or_cpd.py
+++ b/beliefs/factors/bernoulli_or_cpd.py
@@ -20,7 +20,7 @@ class BernoulliOrCPD(TabularCPD):
variable_card=2,
parents=parents,
parents_card=[2]*len(parents),
- values=[],
+ values=None,
state_names={var: ['False', 'True'] for var in [variable] + parents})
self._values = None
diff --git a/beliefs/models/belief_update_node_model.py b/beliefs/models/belief_update_node_model.py
index 820ee0c..cd8ba8c 100644
--- a/beliefs/models/belief_update_node_model.py
+++ b/beliefs/models/belief_update_node_model.py
@@ -174,13 +174,13 @@ class Node:
message_type: MessageType enum
Returns:
- msg_values: list of message values (each an np.array)
+ msg_values: list of DiscreteFactors containing message values (np.arrays)
"""
if message_type == MessageType.LAMBDA:
- msg_values = [msg.values for msg in self.lambda_received_msgs.values()]
+ msgs = [msg for msg in self.lambda_received_msgs.values()]
elif message_type == MessageType.PI:
- msg_values = [msg.values for msg in self.pi_received_msgs.values()]
- return msg_values
+ msgs = [msg for msg in self.pi_received_msgs.values()]
+ return msgs
def validate_and_return_msgs_received_for_msg_type(self, message_type):
"""
@@ -192,17 +192,17 @@ class Node:
message_type: MessageType enum
Returns:
- msg_values: list of message values (each an np.array)
+ msgs: list of DiscreteFactors containing message values (np.array)
"""
- msg_values = self._return_msgs_received_for_msg_type(message_type)
+ msgs = self._return_msgs_received_for_msg_type(message_type)
- if any(msg is None for msg in msg_values):
+ if any(msg.values is None for msg in msgs):
raise ValueError(
"Missing value for {msg_type} msg from child: can't compute {msg_type}_agg."
.format(msg_type=message_type.value)
)
else:
- return msg_values
+ return msgs
def compute_pi_agg(self):
# TODO: implement explict factor product operation
@@ -212,8 +212,10 @@ class Node:
if len(self.children) == 0:
return self.lambda_agg.values
else:
- lambda_msg_values =\
- self.validate_and_return_msgs_received_for_msg_type(MessageType.LAMBDA)
+ lambda_msg_values = [
+ msg.values for msg in
+ self.validate_and_return_msgs_received_for_msg_type(MessageType.LAMBDA)
+ ]
self.update_lambda_agg(reduce(np.multiply, lambda_msg_values))
return self.lambda_agg.values
@@ -295,14 +297,30 @@ class BernoulliOrNode(Node):
parents):
super().__init__(children=children, cpd=BernoulliOrCPD(label_id, parents))
+ def _init_aggregate_values(self):
+ variable = self.cpd.variable
+ return DiscreteFactor(variables=[self.cpd.variable],
+ cardinality=[self.cardinality],
+ values=None,
+ state_names={variable: self.cpd.state_names[variable]})
+
+ def _init_pi_received_msgs(self, parents):
+ msgs = {}
+ for k in parents:
+ kth_cardinality = self.cpd.cardinality[self.cpd.variables.index(k)]
+ msgs[k] = DiscreteFactor(variables=[k],
+ cardinality=[kth_cardinality],
+ values=None,
+ state_names={k: self.cpd.state_names[k]})
+ return msgs
+
def compute_pi_agg(self):
if len(self.parents) == 0:
self.update_pi_agg(self.cpd.values)
else:
- pi_msg_values = self.validate_and_return_msgs_received_for_msg_type(MessageType.PI)
- parents_p0 = [p[0] for p in pi_msg_values]
- # Parents are Bernoulli variables with pi_msg_values (surrogate prior probabilities)
- # of p = [P(False), P(True)]
+ pi_msgs = self.validate_and_return_msgs_received_for_msg_type(MessageType.PI)
+ parents_p0 = [p.get_value_for_state_vector({p.variables[0]: 'False'})
+ for p in pi_msgs]
p_0 = reduce(lambda x, y: x*y, parents_p0)
p_1 = 1 - p_0
self.update_pi_agg(np.array([p_0, p_1]))
@@ -314,10 +332,14 @@ class BernoulliOrNode(Node):
else:
# TODO: cleanup this validation
_ = self.validate_and_return_msgs_received_for_msg_type(MessageType.PI)
- p0_excluding_k = [p.values[0] for par_id, p in self.pi_received_msgs.items() if par_id != parent_k]
+ p0_excluding_k = [p.get_value_for_state_vector({p.variables[0]: 'False'})
+ for par_id, p in self.pi_received_msgs.items() if par_id != parent_k]
p0_product = reduce(lambda x, y: x*y, p0_excluding_k, 1)
- lambda_0 = self.lambda_agg.values[1] + (self.lambda_agg.values[0] - self.lambda_agg.values[1])*p0_product
- lambda_1 = self.lambda_agg.values[1]
+
+ lambda_agg_0 = self.lambda_agg.get_value_for_state_vector({self.label_id: 'False'})
+ lambda_agg_1 = self.lambda_agg.get_value_for_state_vector({self.label_id: 'True'})
+ lambda_0 = lambda_agg_1 + (lambda_agg_0 - lambda_agg_1)*p0_product
+ lambda_1 = lambda_agg_1
lambda_msg = np.array([lambda_0, lambda_1])
if not any(lambda_msg):
raise InvalidLambdaMsgToParent
@@ -331,14 +353,30 @@ class BernoulliAndNode(Node):
parents):
super().__init__(children=children, cpd=BernoulliAndCPD(label_id, parents))
+ def _init_aggregate_values(self):
+ variable = self.cpd.variable
+ return DiscreteFactor(variables=[self.cpd.variable],
+ cardinality=[self.cardinality],
+ values=None,
+ state_names={variable: self.cpd.state_names[variable]})
+
+ def _init_pi_received_msgs(self, parents):
+ msgs = {}
+ for k in parents:
+ kth_cardinality = self.cpd.cardinality[self.cpd.variables.index(k)]
+ msgs[k] = DiscreteFactor(variables=[k],
+ cardinality=[kth_cardinality],
+ values=None,
+ state_names={k: self.cpd.state_names[k]})
+ return msgs
+
def compute_pi_agg(self):
if len(self.parents) == 0:
self.update_pi_agg(self.cpd.values)
else:
- pi_msg_values = self.validate_and_return_msgs_received_for_msg_type(MessageType.PI)
- parents_p1 = [p[1] for p in pi_msg_values]
- # Parents are Bernoulli variables with pi_msg_values (surrogate prior probabilities)
- # of p = [P(False), P(True)]
+ pi_msgs = self.validate_and_return_msgs_received_for_msg_type(MessageType.PI)
+ parents_p1 = [p.get_value_for_state_vector({p.variables[0]: 'True'})
+ for p in pi_msgs]
p_1 = reduce(lambda x, y: x*y, parents_p1)
p_0 = 1 - p_1
self.update_pi_agg(np.array([p_0, p_1]))
@@ -350,10 +388,15 @@ class BernoulliAndNode(Node):
else:
# TODO: cleanup this validation
_ = self.validate_and_return_msgs_received_for_msg_type(MessageType.PI)
- p1_excluding_k = [p.values[1] for par_id, p in self.pi_received_msgs.items() if par_id != parent_k]
+ p1_excluding_k = [p.get_value_for_state_vector({p.variables[0]: 'True'})
+ for par_id, p in self.pi_received_msgs.items() if par_id != parent_k]
p1_product = reduce(lambda x, y: x*y, p1_excluding_k, 1)
- lambda_0 = self.lambda_agg.values[0]
- lambda_1 = self.lambda_agg.values[0] + (self.lambda_agg.values[1] - self.lambda_agg.values[0])*p1_product
+
+ lambda_agg_0 = self.lambda_agg.get_value_for_state_vector({self.label_id: 'False'})
+ lambda_agg_1 = self.lambda_agg.get_value_for_state_vector({self.label_id: 'True'})
+
+ lambda_0 = lambda_agg_0
+ lambda_1 = lambda_agg_0 + (lambda_agg_1 - lambda_agg_0)*p1_product
lambda_msg = np.array([lambda_0, lambda_1])
if not any(lambda_msg):
raise InvalidLambdaMsgToParent