From d166e36eaf5803af035e444628c67701322b0eb6 Mon Sep 17 00:00:00 2001 From: Cathy Yeh Date: Mon, 20 Nov 2017 17:05:37 -0800 Subject: refactor msg passing methods to BeliefUpdateNodeModel from BayesianModel --- beliefs/models/BayesianModel.py | 76 ++------- beliefs/models/BernoulliOrModel.py | 17 -- .../models/beliefupdate/BeliefUpdateNodeModel.py | 91 +++++++++++ beliefs/models/beliefupdate/BernoulliOrNode.py | 47 ++++++ beliefs/models/beliefupdate/Node.py | 179 +++++++++++++++++++++ 5 files changed, 332 insertions(+), 78 deletions(-) delete mode 100644 beliefs/models/BernoulliOrModel.py create mode 100644 beliefs/models/beliefupdate/BeliefUpdateNodeModel.py create mode 100644 beliefs/models/beliefupdate/BernoulliOrNode.py create mode 100644 beliefs/models/beliefupdate/Node.py (limited to 'beliefs/models') diff --git a/beliefs/models/BayesianModel.py b/beliefs/models/BayesianModel.py index 6257a57..b57f968 100644 --- a/beliefs/models/BayesianModel.py +++ b/beliefs/models/BayesianModel.py @@ -1,9 +1,7 @@ import copy -import numpy as np import networkx as nx from beliefs.models.DirectedGraph import DirectedGraph -from beliefs.utils.edges_helper import EdgesHelper from beliefs.utils.math_helper import is_kronecker_delta @@ -12,74 +10,30 @@ class BayesianModel(DirectedGraph): Bayesian model stores nodes and edges described by conditional probability distributions. """ - def __init__(self, edges, nodes_dict=None): + def __init__(self, edges=[], variables=[], cpds=[]): """ - Input: - edges: list of edge tuples of form ('parent', 'child') - nodes: (optional) dict - a dict key, value pair as {label_id: instance_of_node_class_or_subclass} - """ - if nodes_dict is not None: - super().__init__(edges, nodes_dict.keys()) - else: - super().__init__(edges) - self.nodes_dict = nodes_dict - - @classmethod - def from_node_class(cls, edges, node_class): - """Automatically create all nodes from the same node class + Base class for Bayesian model. Input: - edges: list of edge tuples of form ('parent', 'child') - node_class: (optional) the Node class or subclass from which to - create all the nodes from edges. - """ - nodes = cls.create_nodes(edges, node_class) - return cls.__init__(edges=edges, nodes=nodes) - - @staticmethod - def create_nodes(edges, node_class): - """Returns list of Node instances created from edges using - the default node_class""" - edges_helper = EdgesHelper(edges) - nodes = edges_helper.create_nodes_from_edges(node_class=node_class) - label_to_node = dict() - for node in nodes: - label_to_node[node.label_id] = node - return label_to_node - - def set_boundary_conditions(self): - """ - 1. Root nodes: if x is a node with no parents, set Pi(x) = prior - probability of x. - - 2. Leaf nodes: if x is a node with no children, set Lambda(x) - to an (unnormalized) unit vector, of length the cardinality of x. - """ - for root in self.get_roots(): - self.nodes_dict[root].pi_agg = self.nodes_dict[root].cpd.values - - for leaf in self.get_leaves(): - self.nodes_dict[leaf].lambda_agg = np.ones([self.nodes_dict[leaf].cardinality]) - - @property - def all_nodes_are_fully_initialized(self): - """ - Returns True if, for all nodes in the model, all lambda and pi - messages and lambda_agg and pi_agg are not None, else False. + edges: (optional) list of edges, + tuples of form ('parent', 'child') + variables: (optional) list of str or int + labels for variables + cpds: (optional) list of CPDs + TabularCPD class or subclass """ - for node in self.nodes_dict.values(): - if not node.is_fully_initialized: - return False - return True + super().__init__() + super().add_edges_from(edges) + super().add_nodes_from(variables) + self.cpds = cpds def copy(self): """ Returns a copy of the model. """ - copy_edges = list(self.edges()).copy() - copy_nodes = copy.deepcopy(self.nodes_dict) - copy_model = self.__class__(edges=copy_edges, nodes=copy_nodes) + copy_model = self.__class__(edges=list(self.edges()).copy(), + variables=list(self.nodes()).copy(), + cpds=[cpd.copy() for cpd in self.cpds]) return copy_model def get_variables_in_definite_state(self): diff --git a/beliefs/models/BernoulliOrModel.py b/beliefs/models/BernoulliOrModel.py deleted file mode 100644 index bf2b44c..0000000 --- a/beliefs/models/BernoulliOrModel.py +++ /dev/null @@ -1,17 +0,0 @@ -from beliefs.models.BayesianModel import BayesianModel -from beliefs.types.BernoulliOrNode import BernoulliOrNode - - -class BernoulliOrModel(BayesianModel): - """ - BernoulliOrModel stores node instances of BernoulliOrNodes (Bernoulli - variables associated with an OR conditional probability distribution). - """ - def __init__(self, edges, nodes=None): - """ - Input: - edges: an edge list, e.g. [(parent1, child1), (parent1, child2)] - """ - if nodes is None: - nodes = self.create_nodes(edges, node_class=BernoulliOrNode) - super().__init__(edges, nodes_dict=nodes) diff --git a/beliefs/models/beliefupdate/BeliefUpdateNodeModel.py b/beliefs/models/beliefupdate/BeliefUpdateNodeModel.py new file mode 100644 index 0000000..d74eaa7 --- /dev/null +++ b/beliefs/models/beliefupdate/BeliefUpdateNodeModel.py @@ -0,0 +1,91 @@ +import copy +import numpy as np + +from beliefs.models.BayesianModel import BayesianModel +from beliefs.utils.edges_helper import EdgesHelper + + +class BeliefUpdateNodeModel(BayesianModel): + """ + A Bayesian model storing nodes (e.g. Node or BernoulliOrNode) implementing properties + and methods for Pearl's belief update algorithm. + + ref: "Fusion, Propagation, and Structuring in Belief Networks" + Artificial Intelligence 29 (1986) 241-288 + + """ + def __init__(self, nodes_dict): + """ + Input: + nodes_dict: dict + a dict key, value pair as {label_id: instance_of_node_class_or_subclass} + """ + super().__init__(edges=self._get_edges_from_nodes(nodes_dict.values()), + variables=list(nodes_dict.keys()), + cpds=[node.cpd for node in nodes_dict.values()]) + + self.nodes_dict = nodes_dict + + @classmethod + def from_edges(cls, edges, node_class): + """Create nodes from the same node class. + + Input: + edges: list of edge tuples of form ('parent', 'child') + node_class: the Node class or subclass from which to + create all the nodes from edges. + """ + edges_helper = EdgesHelper(edges) + nodes = edges_helper.create_nodes_from_edges(node_class) + nodes_dict = {node.label_id: node for node in nodes} + return cls(nodes_dict) + + @staticmethod + def _get_edges_from_nodes(nodes): + """ + Return list of all directed edges in nodes. + + Args: + nodes: an iterable of objects of the Node class or subclass + Returns: + edges: list of edge tuples + """ + edges = set() + for node in nodes: + if node.parents: + edge_tuples = zip(node.parents, [node.label_id]*len(node.parents)) + edges.update(edge_tuples) + return list(edges) + + def set_boundary_conditions(self): + """ + 1. Root nodes: if x is a node with no parents, set Pi(x) = prior + probability of x. + + 2. Leaf nodes: if x is a node with no children, set Lambda(x) + to an (unnormalized) unit vector, of length the cardinality of x. + """ + for root in self.get_roots(): + self.nodes_dict[root].pi_agg = self.nodes_dict[root].cpd.values + + for leaf in self.get_leaves(): + self.nodes_dict[leaf].lambda_agg = np.ones([self.nodes_dict[leaf].cardinality]) + + @property + def all_nodes_are_fully_initialized(self): + """ + Returns True if, for all nodes in the model, all lambda and pi + messages and lambda_agg and pi_agg are not None, else False. + """ + for node in self.nodes_dict.values(): + if not node.is_fully_initialized: + return False + return True + + def copy(self): + """ + Returns a copy of the model. + """ + copy_nodes = copy.deepcopy(self.nodes_dict) + copy_model = self.__class__(nodes_dict=copy_nodes) + return copy_model diff --git a/beliefs/models/beliefupdate/BernoulliOrNode.py b/beliefs/models/beliefupdate/BernoulliOrNode.py new file mode 100644 index 0000000..3386275 --- /dev/null +++ b/beliefs/models/beliefupdate/BernoulliOrNode.py @@ -0,0 +1,47 @@ +import numpy as np +from functools import reduce + +from beliefs.models.beliefupdate.Node import ( + Node, + MessageType, + InvalidLambdaMsgToParent +) +from beliefs.factors.BernoulliOrCPD import BernoulliOrCPD + + +class BernoulliOrNode(Node): + def __init__(self, + label_id, + children, + parents): + super().__init__(label_id=label_id, + children=children, + parents=parents, + cardinality=2, + cpd=BernoulliOrCPD(label_id, parents)) + + def compute_pi_agg(self): + if not self.parents: + self.pi_agg = self.cpd.values + else: + pi_msg_values = self.validate_and_return_msgs_received_for_msg_type(MessageType.PI) + parents_p0 = [p[0] for p in pi_msg_values] + p_0 = reduce(lambda x, y: x*y, parents_p0) + p_1 = 1 - p_0 + self.pi_agg = np.array([p_0, p_1]) + return self.pi_agg + + def compute_lambda_msg_to_parent(self, parent_k): + if np.array_equal(self.lambda_agg, np.ones([self.cardinality])): + return np.ones([self.cardinality]) + else: + # TODO: cleanup this validation + _ = self.validate_and_return_msgs_received_for_msg_type(MessageType.PI) + p0_excluding_k = [msg[0] for par_id, msg in self.pi_received_msgs.items() if par_id != parent_k] + p0_product = reduce(lambda x, y: x*y, p0_excluding_k, 1) + lambda_0 = self.lambda_agg[1] + (self.lambda_agg[0] - self.lambda_agg[1])*p0_product + lambda_1 = self.lambda_agg[1] + lambda_msg = np.array([lambda_0, lambda_1]) + if not any(lambda_msg): + raise InvalidLambdaMsgToParent + return self._normalize(lambda_msg) diff --git a/beliefs/models/beliefupdate/Node.py b/beliefs/models/beliefupdate/Node.py new file mode 100644 index 0000000..daa2f14 --- /dev/null +++ b/beliefs/models/beliefupdate/Node.py @@ -0,0 +1,179 @@ +import numpy as np +from functools import reduce +from enum import Enum + + +class InvalidLambdaMsgToParent(Exception): + """Computed invalid lambda msg to send to parent.""" + pass + + +class MessageType(Enum): + LAMBDA = 'lambda' + PI = 'pi' + + +class Node: + """A node in a DAG with methods to compute the belief (marginal probability + of the node given evidence) and compute pi/lambda messages to/from its neighbors + to update its belief. + + Implemented from Pearl's belief propagation algorithm. + """ + def __init__(self, + label_id, + children, + parents, + cardinality, + cpd): + """ + Args + label_id: str + children: set of strings + parents: set of strings + cardinality: int, cardinality of the random variable the node represents + cpd: an instance of a conditional probability distribution, + e.g. BernoulliOrCPD or TabularCPD + """ + self.label_id = label_id + self.children = children + self.parents = parents + self.cardinality = cardinality + self.cpd = cpd + + self.pi_agg = None # np.array dimensions [1, cardinality] + self.lambda_agg = None # np.array dimensions [1, cardinality] + + self.pi_received_msgs = self._init_received_msgs(parents) + self.lambda_received_msgs = self._init_received_msgs(children) + + @classmethod + def from_cpd_class(cls, + label_id, + children, + parents, + cardinality, + cpd_class): + cpd = cpd_class(label_id, parents) + return cls(label_id, children, parents, cardinality, cpd) + + @property + def belief(self): + if self.pi_agg.any() and self.lambda_agg.any(): + belief = np.multiply(self.pi_agg, self.lambda_agg) + return self._normalize(belief) + else: + return None + + def _normalize(self, value): + return value/value.sum() + + @staticmethod + def _init_received_msgs(keys): + return {k: None for k in keys} + + def _return_msgs_received_for_msg_type(self, message_type): + """ + Input: + message_type: MessageType enum + + Returns: + msg_values: list of message values (each an np.array) + """ + if message_type == MessageType.LAMBDA: + msg_values = [msg for msg in self.lambda_received_msgs.values()] + elif message_type == MessageType.PI: + msg_values = [msg for msg in self.pi_received_msgs.values()] + return msg_values + + def validate_and_return_msgs_received_for_msg_type(self, message_type): + """ + Check that all messages have been received from children (parents). + Raise error if all messages have not been received. Called + before calculating lambda_agg (pi_agg). + + Input: + message_type: MessageType enum + + Returns: + msg_values: list of message values (each an np.array) + """ + msg_values = self._return_msgs_received_for_msg_type(message_type) + + if any(msg is None for msg in msg_values): + raise ValueError( + "Missing value for {msg_type} msg from child: can't compute {msg_type}_agg.". + format(msg_type=message_type.value) + ) + else: + return msg_values + + def compute_pi_agg(self): + # TODO: implement explict factor product operation + raise NotImplementedError + + def compute_lambda_agg(self): + if not self.children: + return self.lambda_agg + else: + lambda_msg_values = self.validate_and_return_msgs_received_for_msg_type(MessageType.LAMBDA) + self.lambda_agg = reduce(np.multiply, lambda_msg_values) + return self.lambda_agg + + def _update_received_msg_by_key(self, received_msg_dict, key, new_value): + if key not in received_msg_dict.keys(): + raise ValueError("Label id '{}' to update message isn't in allowed set of keys: {}". + format(key, received_msg_dict.keys())) + + if not isinstance(new_value, np.ndarray): + raise TypeError("Expected a new value of type numpy.ndarray, but got type {}". + format(type(new_value))) + + if new_value.shape != (self.cardinality,): + raise ValueError("Expected new value to be of dimensions ({},) but got {} instead". + format(self.cardinality, new_value.shape)) + received_msg_dict[key] = new_value + + def update_pi_msg_from_parent(self, parent, new_value): + self._update_received_msg_by_key(received_msg_dict=self.pi_received_msgs, + key=parent, + new_value=new_value) + + def update_lambda_msg_from_child(self, child, new_value): + self._update_received_msg_by_key(received_msg_dict=self.lambda_received_msgs, + key=child, + new_value=new_value) + + def compute_pi_msg_to_child(self, child_k): + lambda_msg_from_child = self.lambda_received_msgs[child_k] + if lambda_msg_from_child is not None: + with np.errstate(divide='ignore', invalid='ignore'): + # 0/0 := 0 + return self._normalize( + np.nan_to_num(np.divide(self.belief, lambda_msg_from_child))) + else: + raise ValueError("Can't compute pi message to child_{} without having received" \ + "a lambda message from that child.") + + def compute_lambda_msg_to_parent(self, parent_k): + # TODO: implement explict factor product operation + raise NotImplementedError + + @property + def is_fully_initialized(self): + """ + Returns True if all lambda and pi messages and lambda_agg and + pi_agg are not None, else False. + """ + lambda_msgs = self._return_msgs_received_for_msg_type(MessageType.LAMBDA) + if any(msg is None for msg in lambda_msgs): + return False + + pi_msgs = self._return_msgs_received_for_msg_type(MessageType.PI) + if any(msg is None for msg in pi_msgs): + return False + + if (self.pi_agg is None) or (self.lambda_agg is None): + return False + + return True -- cgit v1.2.3