aboutsummaryrefslogtreecommitdiff
path: root/beliefs/factors/cpd.py
diff options
context:
space:
mode:
Diffstat (limited to 'beliefs/factors/cpd.py')
-rw-r--r--beliefs/factors/cpd.py45
1 files changed, 21 insertions, 24 deletions
diff --git a/beliefs/factors/cpd.py b/beliefs/factors/cpd.py
index a286aaa..c7883c9 100644
--- a/beliefs/factors/cpd.py
+++ b/beliefs/factors/cpd.py
@@ -1,32 +1,33 @@
+import copy
import numpy as np
+from beliefs.factors.discrete_factor import DiscreteFactor
-class TabularCPD:
+class TabularCPD(DiscreteFactor):
"""
Defines the conditional probability table for a discrete variable
whose parents are also discrete.
-
- TODO: have this inherit from DiscreteFactor implementing explicit factor methods
"""
- def __init__(self, variable, variable_card,
- parents=[], parents_card=[], values=[]):
+ def __init__(self, variable, variable_card, parents=[], parents_card=[],
+ values=[], state_names=None):
"""
- Args:
- variable: int or string
- variable_card: int
- parents: optional, list of int and/or strings
- parents_card: optional, list of int
- values: optional, 2d list or array
+ Args
+ variable: int or string
+ variable_card: int
+ parents: list,
+ (optional) list of int and/or strings
+ parents_card: list,
+ (optional) list of int
+ values: 2-d list or array (optional)
+ state_names: dictionary (optional),
+ mapping variables to their states, of format {label_name: ['state1', 'state2']}
"""
+ super().__init__(variables=[variable] + parents,
+ cardinality=[variable_card] + parents_card,
+ values=values,
+ state_names=state_names)
self.variable = variable
- self.parents = parents
- self.variables = [variable] + parents
- self.cardinality = [variable_card] + parents_card
- self._values = np.array(values)
-
- @property
- def values(self):
- return self._values
+ self.parents = list(parents)
def get_values(self):
"""
@@ -38,8 +39,4 @@ class TabularCPD:
return self.values.reshape(self.cardinality[0], np.prod(self.cardinality[1:]))
def copy(self):
- return self.__class__(self.variable,
- self.cardinality[0],
- self.parents,
- self.cardinality[1:],
- self._values)
+ return copy.deepcopy(self)