about summary refs log tree commit diff
path: root/beliefs/models
diff options
context:
space:
mode:
author: Cathy Yeh <cathy@driver.xyz> 2017-11-20 17:05:37 -0800
committer: Cathy Yeh <cathy@driver.xyz> 2017-11-21 13:18:34 -0800
commit: d166e36eaf5803af035e444628c67701322b0eb6 (patch)
tree: 3e715d2ab34ce447222ccfa11bcde31065faae26 /beliefs/models
parent: 71e384a741e52f94882b14062a3dc10e5f391533 (diff)
download: beliefs-d166e36eaf5803af035e444628c67701322b0eb6.tar.gz
beliefs-d166e36eaf5803af035e444628c67701322b0eb6.tar.bz2
beliefs-d166e36eaf5803af035e444628c67701322b0eb6.zip
refactor msg passing methods to BeliefUpdateNodeModel from BayesianModel
Diffstat (limited to 'beliefs/models')
-rw-r--r-- beliefs/models/BayesianModel.py 76
-rw-r--r-- beliefs/models/BernoulliOrModel.py 17
-rw-r--r-- beliefs/models/beliefupdate/BeliefUpdateNodeModel.py 91
-rw-r--r-- beliefs/models/beliefupdate/BernoulliOrNode.py 47
-rw-r--r-- beliefs/models/beliefupdate/Node.py 179
5 files changed, 332 insertions, 78 deletions
diff --git a/beliefs/models/BayesianModel.py b/beliefs/models/BayesianModel.py
index 6257a57..b57f968 100644
--- a/beliefs/models/BayesianModel.py
+++ b/beliefs/models/BayesianModel.py
@@ -1,9 +1,7 @@
import copy
-import numpy as np
import networkx as nx
from beliefs.models.DirectedGraph import DirectedGraph
-from beliefs.utils.edges_helper import EdgesHelper
from beliefs.utils.math_helper import is_kronecker_delta
@@ -12,74 +10,30 @@ class BayesianModel(DirectedGraph):
Bayesian model stores nodes and edges described by conditional probability
distributions.
"""
- def __init__(self, edges, nodes_dict=None):
+ def __init__(self, edges=[], variables=[], cpds=[]):
"""
- Input:
- edges: list of edge tuples of form ('parent', 'child')
- nodes: (optional) dict
- a dict key, value pair as {label_id: instance_of_node_class_or_subclass}
- """
- if nodes_dict is not None:
- super().__init__(edges, nodes_dict.keys())
- else:
- super().__init__(edges)
- self.nodes_dict = nodes_dict
-
- @classmethod
- def from_node_class(cls, edges, node_class):
- """Automatically create all nodes from the same node class
+ Base class for Bayesian model.
Input:
- edges: list of edge tuples of form ('parent', 'child')
- node_class: (optional) the Node class or subclass from which to
- create all the nodes from edges.
- """
- nodes = cls.create_nodes(edges, node_class)
- return cls.__init__(edges=edges, nodes=nodes)
-
- @staticmethod
- def create_nodes(edges, node_class):
- """Returns list of Node instances created from edges using
- the default node_class"""
- edges_helper = EdgesHelper(edges)
- nodes = edges_helper.create_nodes_from_edges(node_class=node_class)
- label_to_node = dict()
- for node in nodes:
- label_to_node[node.label_id] = node
- return label_to_node
-
- def set_boundary_conditions(self):
- """
- 1. Root nodes: if x is a node with no parents, set Pi(x) = prior
- probability of x.
-
- 2. Leaf nodes: if x is a node with no children, set Lambda(x)
- to an (unnormalized) unit vector, of length the cardinality of x.
- """
- for root in self.get_roots():
- self.nodes_dict[root].pi_agg = self.nodes_dict[root].cpd.values
-
- for leaf in self.get_leaves():
- self.nodes_dict[leaf].lambda_agg = np.ones([self.nodes_dict[leaf].cardinality])
-
- @property
- def all_nodes_are_fully_initialized(self):
- """
- Returns True if, for all nodes in the model, all lambda and pi
- messages and lambda_agg and pi_agg are not None, else False.
+ edges: (optional) list of edges,
+ tuples of form ('parent', 'child')
+ variables: (optional) list of str or int
+ labels for variables
+ cpds: (optional) list of CPDs
+ TabularCPD class or subclass
"""
- for node in self.nodes_dict.values():
- if not node.is_fully_initialized:
- return False
- return True
+ super().__init__()
+ super().add_edges_from(edges)
+ super().add_nodes_from(variables)
+ self.cpds = cpds
def copy(self):
"""
Returns a copy of the model.
"""
- copy_edges = list(self.edges()).copy()
- copy_nodes = copy.deepcopy(self.nodes_dict)
- copy_model = self.__class__(edges=copy_edges, nodes=copy_nodes)
+ copy_model = self.__class__(edges=list(self.edges()).copy(),
+ variables=list(self.nodes()).copy(),
+ cpds=[cpd.copy() for cpd in self.cpds])
return copy_model
def get_variables_in_definite_state(self):
diff --git a/beliefs/models/BernoulliOrModel.py b/beliefs/models/BernoulliOrModel.py
deleted file mode 100644
index bf2b44c..0000000
--- a/beliefs/models/BernoulliOrModel.py
+++ /dev/null
@@ -1,17 +0,0 @@
-from beliefs.models.BayesianModel import BayesianModel
-from beliefs.types.BernoulliOrNode import BernoulliOrNode
-
-
-class BernoulliOrModel(BayesianModel):
- """
- BernoulliOrModel stores node instances of BernoulliOrNodes (Bernoulli
- variables associated with an OR conditional probability distribution).
- """
- def __init__(self, edges, nodes=None):
- """
- Input:
- edges: an edge list, e.g. [(parent1, child1), (parent1, child2)]
- """
- if nodes is None:
- nodes = self.create_nodes(edges, node_class=BernoulliOrNode)
- super().__init__(edges, nodes_dict=nodes)
diff --git a/beliefs/models/beliefupdate/BeliefUpdateNodeModel.py b/beliefs/models/beliefupdate/BeliefUpdateNodeModel.py
new file mode 100644
index 0000000..d74eaa7
--- /dev/null
+++ b/beliefs/models/beliefupdate/BeliefUpdateNodeModel.py
@@ -0,0 +1,91 @@
+import copy
+import numpy as np
+
+from beliefs.models.BayesianModel import BayesianModel
+from beliefs.utils.edges_helper import EdgesHelper
+
+
+class BeliefUpdateNodeModel(BayesianModel):
+ """
+ A Bayesian model storing nodes (e.g. Node or BernoulliOrNode) implementing properties
+ and methods for Pearl's belief update algorithm.
+
+ ref: "Fusion, Propagation, and Structuring in Belief Networks"
+ Artificial Intelligence 29 (1986) 241-288
+
+ """
+ def __init__(self, nodes_dict):
+ """
+ Input:
+ nodes_dict: dict
+ a dict key, value pair as {label_id: instance_of_node_class_or_subclass}
+ """
+ super().__init__(edges=self._get_edges_from_nodes(nodes_dict.values()),
+ variables=list(nodes_dict.keys()),
+ cpds=[node.cpd for node in nodes_dict.values()])
+
+ self.nodes_dict = nodes_dict
+
+ @classmethod
+ def from_edges(cls, edges, node_class):
+ """Create nodes from the same node class.
+
+ Input:
+ edges: list of edge tuples of form ('parent', 'child')
+ node_class: the Node class or subclass from which to
+ create all the nodes from edges.
+ """
+ edges_helper = EdgesHelper(edges)
+ nodes = edges_helper.create_nodes_from_edges(node_class)
+ nodes_dict = {node.label_id: node for node in nodes}
+ return cls(nodes_dict)
+
+ @staticmethod
+ def _get_edges_from_nodes(nodes):
+ """
+ Return list of all directed edges in nodes.
+
+ Args:
+ nodes: an iterable of objects of the Node class or subclass
+ Returns:
+ edges: list of edge tuples
+ """
+ edges = set()
+ for node in nodes:
+ if node.parents:
+ edge_tuples = zip(node.parents, [node.label_id]*len(node.parents))
+ edges.update(edge_tuples)
+ return list(edges)
+
+ def set_boundary_conditions(self):
+ """
+ 1. Root nodes: if x is a node with no parents, set Pi(x) = prior
+ probability of x.
+
+ 2. Leaf nodes: if x is a node with no children, set Lambda(x)
+ to an (unnormalized) vector of ones, with length equal to the cardinality of x.
+ """
+ for root in self.get_roots():
+ self.nodes_dict[root].pi_agg = self.nodes_dict[root].cpd.values
+
+ for leaf in self.get_leaves():
+ self.nodes_dict[leaf].lambda_agg = np.ones([self.nodes_dict[leaf].cardinality])
+
+ @property
+ def all_nodes_are_fully_initialized(self):
+ """
+ Returns True if, for all nodes in the model, all lambda and pi
+ messages and lambda_agg and pi_agg are not None, else False.
+ """
+ for node in self.nodes_dict.values():
+ if not node.is_fully_initialized:
+ return False
+ return True
+
+ def copy(self):
+ """
+ Returns a copy of the model.
+ """
+ copy_nodes = copy.deepcopy(self.nodes_dict)
+ copy_model = self.__class__(nodes_dict=copy_nodes)
+ return copy_model
diff --git a/beliefs/models/beliefupdate/BernoulliOrNode.py b/beliefs/models/beliefupdate/BernoulliOrNode.py
new file mode 100644
index 0000000..3386275
--- /dev/null
+++ b/beliefs/models/beliefupdate/BernoulliOrNode.py
@@ -0,0 +1,47 @@
+import numpy as np
+from functools import reduce
+
+from beliefs.models.beliefupdate.Node import (
+ Node,
+ MessageType,
+ InvalidLambdaMsgToParent
+)
+from beliefs.factors.BernoulliOrCPD import BernoulliOrCPD
+
+
+class BernoulliOrNode(Node):
+ def __init__(self,
+ label_id,
+ children,
+ parents):
+ super().__init__(label_id=label_id,
+ children=children,
+ parents=parents,
+ cardinality=2,
+ cpd=BernoulliOrCPD(label_id, parents))
+
+ def compute_pi_agg(self):
+ if not self.parents:
+ self.pi_agg = self.cpd.values
+ else:
+ pi_msg_values = self.validate_and_return_msgs_received_for_msg_type(MessageType.PI)
+ parents_p0 = [p[0] for p in pi_msg_values]
+ p_0 = reduce(lambda x, y: x*y, parents_p0)
+ p_1 = 1 - p_0
+ self.pi_agg = np.array([p_0, p_1])
+ return self.pi_agg
+
+ def compute_lambda_msg_to_parent(self, parent_k):
+ if np.array_equal(self.lambda_agg, np.ones([self.cardinality])):
+ return np.ones([self.cardinality])
+ else:
+ # TODO: clean up this validation
+ _ = self.validate_and_return_msgs_received_for_msg_type(MessageType.PI)
+ p0_excluding_k = [msg[0] for par_id, msg in self.pi_received_msgs.items() if par_id != parent_k]
+ p0_product = reduce(lambda x, y: x*y, p0_excluding_k, 1)
+ lambda_0 = self.lambda_agg[1] + (self.lambda_agg[0] - self.lambda_agg[1])*p0_product
+ lambda_1 = self.lambda_agg[1]
+ lambda_msg = np.array([lambda_0, lambda_1])
+ if not any(lambda_msg):
+ raise InvalidLambdaMsgToParent
+ return self._normalize(lambda_msg)
diff --git a/beliefs/models/beliefupdate/Node.py b/beliefs/models/beliefupdate/Node.py
new file mode 100644
index 0000000..daa2f14
--- /dev/null
+++ b/beliefs/models/beliefupdate/Node.py
@@ -0,0 +1,179 @@
+import numpy as np
+from functools import reduce
+from enum import Enum
+
+
+class InvalidLambdaMsgToParent(Exception):
+ """Computed invalid lambda msg to send to parent."""
+ pass
+
+
+class MessageType(Enum):
+ LAMBDA = 'lambda'
+ PI = 'pi'
+
+
+class Node:
+ """A node in a DAG with methods to compute the belief (marginal probability
+ of the node given evidence) and compute pi/lambda messages to/from its neighbors
+ to update its belief.
+
+ Implemented from Pearl's belief propagation algorithm.
+ """
+ def __init__(self,
+ label_id,
+ children,
+ parents,
+ cardinality,
+ cpd):
+ """
+ Args
+ label_id: str
+ children: set of strings
+ parents: set of strings
+ cardinality: int, cardinality of the random variable the node represents
+ cpd: an instance of a conditional probability distribution,
+ e.g. BernoulliOrCPD or TabularCPD
+ """
+ self.label_id = label_id
+ self.children = children
+ self.parents = parents
+ self.cardinality = cardinality
+ self.cpd = cpd
+
+ self.pi_agg = None # np.array of shape (cardinality,)
+ self.lambda_agg = None # np.array of shape (cardinality,)
+
+ self.pi_received_msgs = self._init_received_msgs(parents)
+ self.lambda_received_msgs = self._init_received_msgs(children)
+
+ @classmethod
+ def from_cpd_class(cls,
+ label_id,
+ children,
+ parents,
+ cardinality,
+ cpd_class):
+ cpd = cpd_class(label_id, parents)
+ return cls(label_id, children, parents, cardinality, cpd)
+
+ @property
+ def belief(self):
+ if self.pi_agg.any() and self.lambda_agg.any():
+ belief = np.multiply(self.pi_agg, self.lambda_agg)
+ return self._normalize(belief)
+ else:
+ return None
+
+ def _normalize(self, value):
+ return value/value.sum()
+
+ @staticmethod
+ def _init_received_msgs(keys):
+ return {k: None for k in keys}
+
+ def _return_msgs_received_for_msg_type(self, message_type):
+ """
+ Input:
+ message_type: MessageType enum
+
+ Returns:
+ msg_values: list of message values (each an np.array)
+ """
+ if message_type == MessageType.LAMBDA:
+ msg_values = [msg for msg in self.lambda_received_msgs.values()]
+ elif message_type == MessageType.PI:
+ msg_values = [msg for msg in self.pi_received_msgs.values()]
+ return msg_values
+
+ def validate_and_return_msgs_received_for_msg_type(self, message_type):
+ """
+ Check that all lambda (pi) messages have been received from children (parents).
+ Raise ValueError if any message has not been received. Called
+ before calculating lambda_agg (pi_agg).
+
+ Input:
+ message_type: MessageType enum
+
+ Returns:
+ msg_values: list of message values (each an np.array)
+ """
+ msg_values = self._return_msgs_received_for_msg_type(message_type)
+
+ if any(msg is None for msg in msg_values):
+ raise ValueError(
+ "Missing value for {msg_type} msg from child: can't compute {msg_type}_agg.".
+ format(msg_type=message_type.value)
+ )
+ else:
+ return msg_values
+
+ def compute_pi_agg(self):
+ # TODO: implement explicit factor product operation
+ raise NotImplementedError
+
+ def compute_lambda_agg(self):
+ if not self.children:
+ return self.lambda_agg
+ else:
+ lambda_msg_values = self.validate_and_return_msgs_received_for_msg_type(MessageType.LAMBDA)
+ self.lambda_agg = reduce(np.multiply, lambda_msg_values)
+ return self.lambda_agg
+
+ def _update_received_msg_by_key(self, received_msg_dict, key, new_value):
+ if key not in received_msg_dict.keys():
+ raise ValueError("Label id '{}' to update message isn't in allowed set of keys: {}".
+ format(key, received_msg_dict.keys()))
+
+ if not isinstance(new_value, np.ndarray):
+ raise TypeError("Expected a new value of type numpy.ndarray, but got type {}".
+ format(type(new_value)))
+
+ if new_value.shape != (self.cardinality,):
+ raise ValueError("Expected new value to be of dimensions ({},) but got {} instead".
+ format(self.cardinality, new_value.shape))
+ received_msg_dict[key] = new_value
+
+ def update_pi_msg_from_parent(self, parent, new_value):
+ self._update_received_msg_by_key(received_msg_dict=self.pi_received_msgs,
+ key=parent,
+ new_value=new_value)
+
+ def update_lambda_msg_from_child(self, child, new_value):
+ self._update_received_msg_by_key(received_msg_dict=self.lambda_received_msgs,
+ key=child,
+ new_value=new_value)
+
+ def compute_pi_msg_to_child(self, child_k):
+ lambda_msg_from_child = self.lambda_received_msgs[child_k]
+ if lambda_msg_from_child is not None:
+ with np.errstate(divide='ignore', invalid='ignore'):
+ # 0/0 := 0
+ return self._normalize(
+ np.nan_to_num(np.divide(self.belief, lambda_msg_from_child)))
+ else:
+ raise ValueError("Can't compute pi message to child_{} without having received" \
+ "a lambda message from that child.")
+
+ def compute_lambda_msg_to_parent(self, parent_k):
+ # TODO: implement explicit factor product operation
+ raise NotImplementedError
+
+ @property
+ def is_fully_initialized(self):
+ """
+ Returns True if all lambda and pi messages and lambda_agg and
+ pi_agg are not None, else False.
+ """
+ lambda_msgs = self._return_msgs_received_for_msg_type(MessageType.LAMBDA)
+ if any(msg is None for msg in lambda_msgs):
+ return False
+
+ pi_msgs = self._return_msgs_received_for_msg_type(MessageType.PI)
+ if any(msg is None for msg in pi_msgs):
+ return False
+
+ if (self.pi_agg is None) or (self.lambda_agg is None):
+ return False
+
+ return True