diff options
author | Cathy Yeh <cathy@driver.xyz> | 2017-12-11 18:56:15 -0800 |
---|---|---|
committer | Cathy Yeh <cathy@driver.xyz> | 2017-12-11 18:56:15 -0800 |
commit | 65d822247e30b6e104a8c09d3b930487b9f20a58 (patch) | |
tree | d44b83f001ab352b30e17ab981295c2ee70a4d56 /beliefs | |
parent | 26b43410569044aff46053cae7c68862825dd4ec (diff) | |
parent | 7b5c17c316481edbbd13815390d0b34fb50a03a6 (diff) | |
download | beliefs-e3e0589969b0660d7fa94bf55515d5ba31f5c6e7.tar.gz beliefs-e3e0589969b0660d7fa94bf55515d5ba31f5c6e7.tar.bz2 beliefs-e3e0589969b0660d7fa94bf55515d5ba31f5c6e7.zip |
LGS-173 Merge branch 'bernoulli_and_node'v0.0.3
Diffstat (limited to 'beliefs')
-rw-r--r-- | beliefs/factors/bernoulli_and_cpd.py | 45 | ||||
-rw-r--r-- | beliefs/factors/bernoulli_or_cpd.py | 7 | ||||
-rw-r--r-- | beliefs/models/belief_update_node_model.py | 49 |
3 files changed, 96 insertions, 5 deletions
diff --git a/beliefs/factors/bernoulli_and_cpd.py b/beliefs/factors/bernoulli_and_cpd.py new file mode 100644 index 0000000..fdb0c25 --- /dev/null +++ b/beliefs/factors/bernoulli_and_cpd.py @@ -0,0 +1,45 @@ +import numpy as np + +from beliefs.factors.cpd import TabularCPD + + +class BernoulliAndCPD(TabularCPD): + """CPD class for a Bernoulli random variable whose relationship to its + parents (also Bernoulli random variables) is described by AND logic. + + If all of the variable's parents are True, then the variable + is True, and False otherwise. + """ + def __init__(self, variable, parents=[]): + """ + Args: + variable: int or string + parents: optional, list of int and/or strings + """ + super().__init__(variable=variable, + variable_card=2, + parents=parents, + parents_card=[2]*len(parents), + values=[]) + self._values = None + + @property + def values(self): + if self._values is None: + self._values = self._build_kwise_values_array(len(self.variables)) + self._values = self._values.reshape(self.cardinality) + return self._values + + @staticmethod + def _build_kwise_values_array(k): + # special case a completely independent factor, and + # return the uniform prior + if k == 1: + return np.array([0.5, 0.5]) + + # values are stored as a row vector using an ordering such that + # the right-most variables as defined in [variable].extend(parents) + # cycle through their values the fastest. + return np.array( + [1.]*(2**(k-1)-1) + [0.] + [0.,]*(2**(k-1)-1) + [1.] + ) diff --git a/beliefs/factors/bernoulli_or_cpd.py b/beliefs/factors/bernoulli_or_cpd.py index bfb3a95..12ee2f6 100644 --- a/beliefs/factors/bernoulli_or_cpd.py +++ b/beliefs/factors/bernoulli_or_cpd.py @@ -21,11 +21,11 @@ class BernoulliOrCPD(TabularCPD): parents=parents, parents_card=[2]*len(parents), values=[]) - self._values = [] + self._values = None @property def values(self): - if not any(self._values): + if self._values is None: self._values = self._build_kwise_values_array(len(self.variables)) self._values = self._values.reshape(self.cardinality) return self._values @@ -37,6 +37,9 @@ class BernoulliOrCPD(TabularCPD): if k == 1: return np.array([0.5, 0.5]) + # values are stored as a row vector using an ordering such that + # the right-most variables as defined in [variable].extend(parents) + # cycle through their values the fastest. return np.array( [1.,] + [0.]*(2**(k-1)-1) + [0.,] + [1.]*(2**(k-1)-1) ) diff --git a/beliefs/models/belief_update_node_model.py b/beliefs/models/belief_update_node_model.py index 667e0f1..1c3ba6e 100644 --- a/beliefs/models/belief_update_node_model.py +++ b/beliefs/models/belief_update_node_model.py @@ -8,6 +8,7 @@ import networkx as nx from beliefs.models.base_models import BayesianModel from beliefs.factors.bernoulli_or_cpd import BernoulliOrCPD +from beliefs.factors.bernoulli_and_cpd import BernoulliAndCPD class InvalidLambdaMsgToParent(Exception): @@ -212,7 +213,7 @@ class Node: raise NotImplementedError def compute_lambda_agg(self): - if not self.children: + if len(self.children) == 0: return self.lambda_agg else: lambda_msg_values = self.validate_and_return_msgs_received_for_msg_type(MessageType.LAMBDA) @@ -289,11 +290,13 @@ class BernoulliOrNode(Node): cpd=BernoulliOrCPD(label_id, parents)) def compute_pi_agg(self): - if not self.parents: + if len(self.parents) == 0: self.pi_agg = self.cpd.values else: pi_msg_values = self.validate_and_return_msgs_received_for_msg_type(MessageType.PI) parents_p0 = [p[0] for p in pi_msg_values] + # Parents are Bernoulli variables with pi_msg_values (surrogate prior probabilities) + # of p = [P(False), P(True)] p_0 = reduce(lambda x, y: x*y, parents_p0) p_1 = 1 - p_0 self.pi_agg = np.array([p_0, p_1]) @@ -305,7 +308,7 @@ class BernoulliOrNode(Node): else: # TODO: cleanup this validation _ = self.validate_and_return_msgs_received_for_msg_type(MessageType.PI) - p0_excluding_k = [msg[0] for par_id, msg in self.pi_received_msgs.items() if par_id != parent_k] + p0_excluding_k = [p[0] for par_id, p in self.pi_received_msgs.items() if par_id != parent_k] p0_product = reduce(lambda x, y: x*y, p0_excluding_k, 1) lambda_0 = self.lambda_agg[1] + (self.lambda_agg[0] - self.lambda_agg[1])*p0_product lambda_1 = self.lambda_agg[1] @@ -313,3 +316,43 @@ class BernoulliOrNode(Node): if not any(lambda_msg): raise InvalidLambdaMsgToParent return self._normalize(lambda_msg) + + +class BernoulliAndNode(Node): + def __init__(self, + label_id, + children, + parents): + super().__init__(label_id=label_id, + children=children, + parents=parents, + cardinality=2, + cpd=BernoulliAndCPD(label_id, parents)) + + def compute_pi_agg(self): + if len(self.parents) == 0: + self.pi_agg = self.cpd.values + else: + pi_msg_values = self.validate_and_return_msgs_received_for_msg_type(MessageType.PI) + parents_p1 = [p[1] for p in pi_msg_values] + # Parents are Bernoulli variables with pi_msg_values (surrogate prior probabilities) + # of p = [P(False), P(True)] + p_1 = reduce(lambda x, y: x*y, parents_p1) + p_0 = 1 - p_1 + self.pi_agg = np.array([p_0, p_1]) + return self.pi_agg + + def compute_lambda_msg_to_parent(self, parent_k): + if np.array_equal(self.lambda_agg, np.ones([self.cardinality])): + return np.ones([self.cardinality]) + else: + # TODO: cleanup this validation + _ = self.validate_and_return_msgs_received_for_msg_type(MessageType.PI) + p1_excluding_k = [p[1] for par_id, p in self.pi_received_msgs.items() if par_id != parent_k] + p1_product = reduce(lambda x, y: x*y, p1_excluding_k, 1) + lambda_0 = self.lambda_agg[0] + lambda_1 = self.lambda_agg[0] + (self.lambda_agg[1] - self.lambda_agg[0])*p1_product + lambda_msg = np.array([lambda_0, lambda_1]) + if not any(lambda_msg): + raise InvalidLambdaMsgToParent + return self._normalize(lambda_msg) |