aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCathy Yeh <cathy@driver.xyz>2017-11-20 11:40:02 -0800
committerCathy Yeh <cathy@driver.xyz>2017-11-20 11:40:02 -0800
commit71e384a741e52f94882b14062a3dc10e5f391533 (patch)
tree669b8c78e3c9c7e44cf58692fef81836b8cc94b9
parentb16e990b7e4d00e427d4445ba38eef0fb967963a (diff)
downloadbeliefs-71e384a741e52f94882b14062a3dc10e5f391533.tar.gz
beliefs-71e384a741e52f94882b14062a3dc10e5f391533.tar.bz2
beliefs-71e384a741e52f94882b14062a3dc10e5f391533.zip
BernoulliOrCPD inherits from TabularCPD
-rw-r--r--beliefs/factors/BernoulliOrCPD.py37
-rw-r--r--beliefs/factors/BernoulliOrFactor.py42
-rw-r--r--beliefs/factors/CPD.py36
-rw-r--r--beliefs/inference/belief_propagation.py4
-rw-r--r--beliefs/types/BernoulliOrNode.py4
-rw-r--r--beliefs/types/Node.py2
-rw-r--r--beliefs/utils/edges_helper.py12
-rw-r--r--tests/test_belief_propagation.py3
8 files changed, 86 insertions, 54 deletions
diff --git a/beliefs/factors/BernoulliOrCPD.py b/beliefs/factors/BernoulliOrCPD.py
new file mode 100644
index 0000000..e4fcbf1
--- /dev/null
+++ b/beliefs/factors/BernoulliOrCPD.py
@@ -0,0 +1,37 @@
+import numpy as np
+
+from beliefs.factors.CPD import TabularCPD
+
+
+class BernoulliOrCPD(TabularCPD):
+ """CPD class for a Bernoulli random variable whose relationship to its
+ parents (also Bernoulli random variables) is described by OR logic.
+
+ If at least one of the variable's parents is True, then the variable
+ is True, and False otherwise.
+ """
+ def __init__(self, variable, parents=set()):
+ super().__init__(variable=variable,
+ variable_card=2,
+ parents=parents,
+ parents_card=[2]*len(parents),
+ values=None)
+ self._values = None
+
+ @property
+ def values(self):
+ if self._values is None:
+ self._values = self._build_kwise_values_array(len(self.variables))
+ self._values = self._values.reshape(self.cardinality)
+ return self._values
+
+ @staticmethod
+ def _build_kwise_values_array(k):
+ # special case a completely independent factor, and
+ # return the uniform prior
+ if k == 1:
+ return np.array([0.5, 0.5])
+
+ return np.array(
+ [1.,] + [0.]*(2**(k-1)-1) + [0.,] + [1.]*(2**(k-1)-1)
+ )
diff --git a/beliefs/factors/BernoulliOrFactor.py b/beliefs/factors/BernoulliOrFactor.py
deleted file mode 100644
index 4f973ae..0000000
--- a/beliefs/factors/BernoulliOrFactor.py
+++ /dev/null
@@ -1,42 +0,0 @@
-import numpy as np
-
-
-class BernoulliOrFactor:
- """CPD class for a Bernoulli random variable whose relationship to its
- parents is described by OR logic.
-
- If at least one of a child's parents is True, then the child is True, and
- False otherwise."""
- def __init__(self, child, parents=set()):
- self.child = child
- self.parents = set(parents)
- self.variables = set([child] + list(parents))
- self.cardinality = [2]*len(self.variables)
- self._values = None
-
- @property
- def values(self):
- if self._values is None:
- self._values = self._build_kwise_values_array(len(self.variables))
- self._values = self._values.reshape(self.cardinality)
- return self._values
-
- def get_values(self):
- """
- Returns the tabular cpd form of the values.
- """
- if len(self.cardinality) == 1:
- return self.values.reshape(1, np.prod(self.cardinality))
- else:
- return self.values.reshape(self.cardinality[0], np.prod(self.cardinality[1:]))
-
- @staticmethod
- def _build_kwise_values_array(k):
- # special case a completely independent factor, and
- # return the uniform prior
- if k == 1:
- return np.array([0.5, 0.5])
-
- return np.array(
- [1.,] + [0.]*(2**(k-1)-1) + [0.,] + [1.]*(2**(k-1)-1)
- )
diff --git a/beliefs/factors/CPD.py b/beliefs/factors/CPD.py
new file mode 100644
index 0000000..8de47b3
--- /dev/null
+++ b/beliefs/factors/CPD.py
@@ -0,0 +1,36 @@
+import numpy as np
+
+
+class TabularCPD:
+ """
+ Defines the conditional probability table for a discrete variable
+ whose parents are also discrete.
+
+ TODO: have this inherit from DiscreteFactor
+ """
+ def __init__(self, variable, variable_card,
+ parents=[], parents_card=[], values=[]):
+ """
+ Args:
+ variable: int or string
+ variable_card: int
+ parents: optional, list of int and/or strings
+ parents_card: optional, list of int
+ values: optional, 2d list or array
+ """
+ self.variable = variable
+ self.parents = parents
+ self.variables = [variable] + parents
+ self.cardinality = [variable_card] + parents_card
+
+ if values:
+ self.values = np.array(values)
+
+ def get_values(self):
+ """
+ Returns the tabular cpd form of the values.
+ """
+ if len(self.cardinality) == 1:
+ return self.values.reshape(1, np.prod(self.cardinality))
+ else:
+ return self.values.reshape(self.cardinality[0], np.prod(self.cardinality[1:]))
diff --git a/beliefs/inference/belief_propagation.py b/beliefs/inference/belief_propagation.py
index ecd5e9c..37aa437 100644
--- a/beliefs/inference/belief_propagation.py
+++ b/beliefs/inference/belief_propagation.py
@@ -54,8 +54,8 @@ class BeliefPropagation:
# exclude the message sender (either a parent or child) from getting an
# outgoing msg from the node to update
- parent_ids = node.parents - set([msg_sender_label_id])
- child_ids = node.children - set([msg_sender_label_id])
+ parent_ids = set(node.parents) - set([msg_sender_label_id])
+ child_ids = set(node.children) - set([msg_sender_label_id])
print("parent_ids:", parent_ids)
print("child_ids:", child_ids)
diff --git a/beliefs/types/BernoulliOrNode.py b/beliefs/types/BernoulliOrNode.py
index 27da85a..ce497b9 100644
--- a/beliefs/types/BernoulliOrNode.py
+++ b/beliefs/types/BernoulliOrNode.py
@@ -6,7 +6,7 @@ from beliefs.types.Node import (
MessageType,
InvalidLambdaMsgToParent
)
-from beliefs.factors.BernoulliOrFactor import BernoulliOrFactor
+from beliefs.factors.BernoulliOrCPD import BernoulliOrCPD
class BernoulliOrNode(Node):
@@ -18,7 +18,7 @@ class BernoulliOrNode(Node):
children=children,
parents=parents,
cardinality=2,
- cpd=BernoulliOrFactor(label_id, parents))
+ cpd=BernoulliOrCPD(label_id, parents))
def compute_pi_agg(self):
if not self.parents:
diff --git a/beliefs/types/Node.py b/beliefs/types/Node.py
index a8dca7c..a496acf 100644
--- a/beliefs/types/Node.py
+++ b/beliefs/types/Node.py
@@ -33,7 +33,7 @@ class Node:
parents: set of strings
cardinality: int, cardinality of the random variable the node represents
cpd: an instance of a conditional probability distribution,
- e.g. BernoulliOrFactor or pgmpy's TabularCPD
+ e.g. BernoulliOrCPD or TabularCPD
"""
self.label_id = label_id
self.children = children
diff --git a/beliefs/utils/edges_helper.py b/beliefs/utils/edges_helper.py
index 7ac783c..c959a3b 100644
--- a/beliefs/utils/edges_helper.py
+++ b/beliefs/utils/edges_helper.py
@@ -1,7 +1,7 @@
from collections import defaultdict
from beliefs.types.Node import Node
-from beliefs.factors.BernoulliOrFactor import BernoulliOrFactor
+from beliefs.factors.BernoulliOrCPD import BernoulliOrCPD
class EdgesHelper:
@@ -33,7 +33,7 @@ class EdgesHelper:
all_labels.update({parent, child})
return all_labels
- def create_cpds_from_edges(self, CPD=BernoulliOrFactor):
+ def create_cpds_from_edges(self, CPD=BernoulliOrCPD):
"""
Create factors from list of edges.
@@ -55,7 +55,7 @@ class EdgesHelper:
factors.add(cpd)
return factors
- def get_label_to_factor_dict(self, CPD=BernoulliOrFactor):
+ def get_label_to_factor_dict(self, CPD=BernoulliOrCPD):
"""Create a dictionary mapping each label_id to the CPD/factor where
that label_id is a child.
@@ -70,7 +70,7 @@ class EdgesHelper:
label_to_factor[factor.child] = factor
return label_to_factor
- def get_label_to_node_dict(self, CPD=BernoulliOrFactor):
+ def get_label_to_node_dict(self, CPD=BernoulliOrCPD):
"""Create a dictionary mapping each label_id to a Node instance.
Returns:
@@ -126,8 +126,8 @@ class EdgesHelper:
nodes = set()
for label in labels:
- parents = labels_to_parents[label]
- children = labels_to_children[label]
+ parents = list(labels_to_parents[label])
+ children = list(labels_to_children[label])
node = node_class(label_id=label,
children=children,
diff --git a/tests/test_belief_propagation.py b/tests/test_belief_propagation.py
index ef7ffb0..24ee94b 100644
--- a/tests/test_belief_propagation.py
+++ b/tests/test_belief_propagation.py
@@ -52,7 +52,8 @@ def many_parents_model(many_parents_edges):
@pytest.fixture(scope='function')
def one_node_model():
- a_node = BernoulliOrNode(label_id='x', children=set(), parents=set())
+ a_node = BernoulliOrNode(label_id='x', children=[], parents=[])
+ # a_node = BernoulliOrNode(label_id='x', children=set(), parents=set())
return BernoulliOrModel(edges=None, nodes={'x': a_node})