From 71e384a741e52f94882b14062a3dc10e5f391533 Mon Sep 17 00:00:00 2001
From: Cathy Yeh
Date: Mon, 20 Nov 2017 11:40:02 -0800
Subject: BernoulliOrCPD inherits from TabularCPD

---
 beliefs/factors/BernoulliOrCPD.py       | 37 +++++++++++++++++++++++++++++
 beliefs/factors/BernoulliOrFactor.py    | 42 ---------------------------------
 beliefs/factors/CPD.py                  | 36 ++++++++++++++++++++++++++++
 beliefs/inference/belief_propagation.py |  4 ++--
 beliefs/types/BernoulliOrNode.py        |  4 ++--
 beliefs/types/Node.py                   |  2 +-
 beliefs/utils/edges_helper.py           | 12 +++++-----
 tests/test_belief_propagation.py        |  3 ++-
 8 files changed, 86 insertions(+), 54 deletions(-)
 create mode 100644 beliefs/factors/BernoulliOrCPD.py
 delete mode 100644 beliefs/factors/BernoulliOrFactor.py
 create mode 100644 beliefs/factors/CPD.py

diff --git a/beliefs/factors/BernoulliOrCPD.py b/beliefs/factors/BernoulliOrCPD.py
new file mode 100644
index 0000000..e4fcbf1
--- /dev/null
+++ b/beliefs/factors/BernoulliOrCPD.py
@@ -0,0 +1,37 @@
+import numpy as np
+
+from beliefs.factors.CPD import TabularCPD
+
+
+class BernoulliOrCPD(TabularCPD):
+    """CPD class for a Bernoulli random variable whose relationship to its
+    parents (also Bernoulli random variables) is described by OR logic.
+
+    If at least one of the variable's parents is True, then the variable
+    is True, and False otherwise.
+    """
+    def __init__(self, variable, parents=set()):
+        super().__init__(variable=variable,
+                         variable_card=2,
+                         parents=parents,
+                         parents_card=[2]*len(parents),
+                         values=None)
+        self._values = None
+
+    @property
+    def values(self):
+        if self._values is None:
+            self._values = self._build_kwise_values_array(len(self.variables))
+            self._values = self._values.reshape(self.cardinality)
+        return self._values
+
+    @staticmethod
+    def _build_kwise_values_array(k):
+        # special case a completely independent factor, and
+        # return the uniform prior
+        if k == 1:
+            return np.array([0.5, 0.5])
+
+        return np.array(
+            [1.,] + [0.]*(2**(k-1)-1) + [0.,] + [1.]*(2**(k-1)-1)
+        )
diff --git a/beliefs/factors/BernoulliOrFactor.py b/beliefs/factors/BernoulliOrFactor.py
deleted file mode 100644
index 4f973ae..0000000
--- a/beliefs/factors/BernoulliOrFactor.py
+++ /dev/null
@@ -1,42 +0,0 @@
-import numpy as np
-
-
-class BernoulliOrFactor:
-    """CPD class for a Bernoulli random variable whose relationship to its
-    parents is described by OR logic.
-
-    If at least one of a child's parents is True, then the child is True, and
-    False otherwise."""
-    def __init__(self, child, parents=set()):
-        self.child = child
-        self.parents = set(parents)
-        self.variables = set([child] + list(parents))
-        self.cardinality = [2]*len(self.variables)
-        self._values = None
-
-    @property
-    def values(self):
-        if self._values is None:
-            self._values = self._build_kwise_values_array(len(self.variables))
-            self._values = self._values.reshape(self.cardinality)
-        return self._values
-
-    def get_values(self):
-        """
-        Returns the tabular cpd form of the values.
-        """
-        if len(self.cardinality) == 1:
-            return self.values.reshape(1, np.prod(self.cardinality))
-        else:
-            return self.values.reshape(self.cardinality[0], np.prod(self.cardinality[1:]))
-
-    @staticmethod
-    def _build_kwise_values_array(k):
-        # special case a completely independent factor, and
-        # return the uniform prior
-        if k == 1:
-            return np.array([0.5, 0.5])
-
-        return np.array(
-            [1.,] + [0.]*(2**(k-1)-1) + [0.,] + [1.]*(2**(k-1)-1)
-        )
diff --git a/beliefs/factors/CPD.py b/beliefs/factors/CPD.py
new file mode 100644
index 0000000..8de47b3
--- /dev/null
+++ b/beliefs/factors/CPD.py
@@ -0,0 +1,36 @@
+import numpy as np
+
+
+class TabularCPD:
+    """
+    Defines the conditional probability table for a discrete variable
+    whose parents are also discrete.
+
+    TODO: have this inherit from DiscreteFactor
+    """
+    def __init__(self, variable, variable_card,
+                 parents=[], parents_card=[], values=[]):
+        """
+        Args:
+            variable: int or string
+            variable_card: int
+            parents: optional, list of int and/or strings
+            parents_card: optional, list of int
+            values: optional, 2d list or array
+        """
+        self.variable = variable
+        self.parents = parents
+        self.variables = [variable] + parents
+        self.cardinality = [variable_card] + parents_card
+
+        if values:
+            self.values = np.array(values)
+
+    def get_values(self):
+        """
+        Returns the tabular cpd form of the values.
+        """
+        if len(self.cardinality) == 1:
+            return self.values.reshape(1, np.prod(self.cardinality))
+        else:
+            return self.values.reshape(self.cardinality[0], np.prod(self.cardinality[1:]))
diff --git a/beliefs/inference/belief_propagation.py b/beliefs/inference/belief_propagation.py
index ecd5e9c..37aa437 100644
--- a/beliefs/inference/belief_propagation.py
+++ b/beliefs/inference/belief_propagation.py
@@ -54,8 +54,8 @@ class BeliefPropagation:
 
         # exclude the message sender (either a parent or child) from getting an
         # outgoing msg from the node to update
-        parent_ids = node.parents - set([msg_sender_label_id])
-        child_ids = node.children - set([msg_sender_label_id])
+        parent_ids = set(node.parents) - set([msg_sender_label_id])
+        child_ids = set(node.children) - set([msg_sender_label_id])
         print("parent_ids:", parent_ids)
         print("child_ids:", child_ids)
 
diff --git a/beliefs/types/BernoulliOrNode.py b/beliefs/types/BernoulliOrNode.py
index 27da85a..ce497b9 100644
--- a/beliefs/types/BernoulliOrNode.py
+++ b/beliefs/types/BernoulliOrNode.py
@@ -6,7 +6,7 @@ from beliefs.types.Node import (
     MessageType,
     InvalidLambdaMsgToParent
 )
-from beliefs.factors.BernoulliOrFactor import BernoulliOrFactor
+from beliefs.factors.BernoulliOrCPD import BernoulliOrCPD
 
 
 class BernoulliOrNode(Node):
@@ -18,7 +18,7 @@ class BernoulliOrNode(Node):
                          children=children,
                          parents=parents,
                          cardinality=2,
-                         cpd=BernoulliOrFactor(label_id, parents))
+                         cpd=BernoulliOrCPD(label_id, parents))
 
     def compute_pi_agg(self):
         if not self.parents:
diff --git a/beliefs/types/Node.py b/beliefs/types/Node.py
index a8dca7c..a496acf 100644
--- a/beliefs/types/Node.py
+++ b/beliefs/types/Node.py
@@ -33,7 +33,7 @@ class Node:
             parents: set of strings
             cardinality: int, cardinality of the random variable the node represents
             cpd: an instance of a conditional probability distribution,
-                 e.g. BernoulliOrFactor or pgmpy's TabularCPD
+                 e.g. BernoulliOrCPD or TabularCPD
         """
         self.label_id = label_id
         self.children = children
diff --git a/beliefs/utils/edges_helper.py b/beliefs/utils/edges_helper.py
index 7ac783c..c959a3b 100644
--- a/beliefs/utils/edges_helper.py
+++ b/beliefs/utils/edges_helper.py
@@ -1,7 +1,7 @@
 from collections import defaultdict
 
 from beliefs.types.Node import Node
-from beliefs.factors.BernoulliOrFactor import BernoulliOrFactor
+from beliefs.factors.BernoulliOrCPD import BernoulliOrCPD
 
 
 class EdgesHelper:
@@ -33,7 +33,7 @@ class EdgesHelper:
             all_labels.update({parent, child})
         return all_labels
 
-    def create_cpds_from_edges(self, CPD=BernoulliOrFactor):
+    def create_cpds_from_edges(self, CPD=BernoulliOrCPD):
         """
         Create factors from list of edges.
 
@@ -55,7 +55,7 @@ class EdgesHelper:
             factors.add(cpd)
         return factors
 
-    def get_label_to_factor_dict(self, CPD=BernoulliOrFactor):
+    def get_label_to_factor_dict(self, CPD=BernoulliOrCPD):
         """Create a dictionary mapping each label_id to the CPD/factor where
         that label_id is a child.
 
@@ -70,7 +70,7 @@ class EdgesHelper:
             label_to_factor[factor.child] = factor
         return label_to_factor
 
-    def get_label_to_node_dict(self, CPD=BernoulliOrFactor):
+    def get_label_to_node_dict(self, CPD=BernoulliOrCPD):
         """Create a dictionary mapping each label_id to a Node instance.
 
         Returns:
@@ -126,8 +126,8 @@ class EdgesHelper:
 
         nodes = set()
         for label in labels:
-            parents = labels_to_parents[label]
-            children = labels_to_children[label]
+            parents = list(labels_to_parents[label])
+            children = list(labels_to_children[label])
 
             node = node_class(label_id=label,
                               children=children,
diff --git a/tests/test_belief_propagation.py b/tests/test_belief_propagation.py
index ef7ffb0..24ee94b 100644
--- a/tests/test_belief_propagation.py
+++ b/tests/test_belief_propagation.py
@@ -52,7 +52,8 @@ def many_parents_model(many_parents_edges):
 
 @pytest.fixture(scope='function')
 def one_node_model():
-    a_node = BernoulliOrNode(label_id='x', children=set(), parents=set())
+    a_node = BernoulliOrNode(label_id='x', children=[], parents=[])
+    # a_node = BernoulliOrNode(label_id='x', children=set(), parents=set())
     return BernoulliOrModel(edges=None, nodes={'x': a_node})
 
 
-- 
cgit v1.2.3