aboutsummaryrefslogtreecommitdiff
path: root/beliefs/factors
diff options
context:
space:
mode:
authorCathy Yeh <cathy@driver.xyz>2017-11-20 17:05:37 -0800
committerCathy Yeh <cathy@driver.xyz>2017-11-21 13:18:34 -0800
commitd166e36eaf5803af035e444628c67701322b0eb6 (patch)
tree3e715d2ab34ce447222ccfa11bcde31065faae26 /beliefs/factors
parent71e384a741e52f94882b14062a3dc10e5f391533 (diff)
downloadbeliefs-d166e36eaf5803af035e444628c67701322b0eb6.tar.gz
beliefs-d166e36eaf5803af035e444628c67701322b0eb6.tar.bz2
beliefs-d166e36eaf5803af035e444628c67701322b0eb6.zip
refactor msg passing methods to BeliefUpdateNodeModel from BayesianModel
Diffstat (limited to 'beliefs/factors')
-rw-r--r--beliefs/factors/BernoulliOrCPD.py13
-rw-r--r--beliefs/factors/CPD.py15
2 files changed, 21 insertions, 7 deletions
diff --git a/beliefs/factors/BernoulliOrCPD.py b/beliefs/factors/BernoulliOrCPD.py
index e4fcbf1..2c6a31e 100644
--- a/beliefs/factors/BernoulliOrCPD.py
+++ b/beliefs/factors/BernoulliOrCPD.py
@@ -10,17 +10,22 @@ class BernoulliOrCPD(TabularCPD):
If at least one of the variable's parents is True, then the variable
is True, and False otherwise.
"""
- def __init__(self, variable, parents=set()):
+ def __init__(self, variable, parents=[]):
+ """
+ Args:
+ variable: int or string
+ parents: optional, list of int and/or strings
+ """
super().__init__(variable=variable,
variable_card=2,
parents=parents,
parents_card=[2]*len(parents),
- values=None)
- self._values = None
+ values=[])
+ self._values = []
@property
def values(self):
- if self._values is None:
+ if not any(self._values):
self._values = self._build_kwise_values_array(len(self.variables))
self._values = self._values.reshape(self.cardinality)
return self._values
diff --git a/beliefs/factors/CPD.py b/beliefs/factors/CPD.py
index 8de47b3..a286aaa 100644
--- a/beliefs/factors/CPD.py
+++ b/beliefs/factors/CPD.py
@@ -6,7 +6,7 @@ class TabularCPD:
Defines the conditional probability table for a discrete variable
whose parents are also discrete.
- TODO: have this inherit from DiscreteFactor
+ TODO: have this inherit from DiscreteFactor implementing explicit factor methods
"""
def __init__(self, variable, variable_card,
parents=[], parents_card=[], values=[]):
@@ -22,9 +22,11 @@ class TabularCPD:
self.parents = parents
self.variables = [variable] + parents
self.cardinality = [variable_card] + parents_card
+ self._values = np.array(values)
- if values:
- self.values = np.array(values)
+ @property
+ def values(self):
+ return self._values
def get_values(self):
"""
@@ -34,3 +36,10 @@ class TabularCPD:
return self.values.reshape(1, np.prod(self.cardinality))
else:
return self.values.reshape(self.cardinality[0], np.prod(self.cardinality[1:]))
+
+ def copy(self):
+ return self.__class__(self.variable,
+ self.cardinality[0],
+ self.parents,
+ self.cardinality[1:],
+ self._values)