diff options
author | Cathy Yeh <cathy@driver.xyz> | 2018-01-18 21:57:50 -0800 |
---|---|---|
committer | Cathy Yeh <cathy@driver.xyz> | 2018-01-18 21:57:50 -0800 |
commit | 2366e92bdb9c81bc2bd7132a00ed5c16a5160c5e (patch) | |
tree | fc71d343eec17b59d8af81eb768e4fe2eab167c2 /beliefs/factors/cpd.py | |
parent | 65d822247e30b6e104a8c09d3b930487b9f20a58 (diff) | |
parent | c93c352b2f68a2bbcde2241e61d9fb52504a67a9 (diff) | |
download | beliefs-2366e92bdb9c81bc2bd7132a00ed5c16a5160c5e.tar.gz beliefs-2366e92bdb9c81bc2bd7132a00ed5c16a5160c5e.tar.bz2 beliefs-2366e92bdb9c81bc2bd7132a00ed5c16a5160c5e.zip |
Merge branch 'generic_discrete_factor'. Implements explicit discrete factor methodsv0.1.0
Diffstat (limited to 'beliefs/factors/cpd.py')
-rw-r--r-- | beliefs/factors/cpd.py | 45 |
1 files changed, 21 insertions, 24 deletions
diff --git a/beliefs/factors/cpd.py b/beliefs/factors/cpd.py index a286aaa..c7883c9 100644 --- a/beliefs/factors/cpd.py +++ b/beliefs/factors/cpd.py @@ -1,32 +1,33 @@ +import copy import numpy as np +from beliefs.factors.discrete_factor import DiscreteFactor -class TabularCPD: +class TabularCPD(DiscreteFactor): """ Defines the conditional probability table for a discrete variable whose parents are also discrete. - - TODO: have this inherit from DiscreteFactor implementing explicit factor methods """ - def __init__(self, variable, variable_card, - parents=[], parents_card=[], values=[]): + def __init__(self, variable, variable_card, parents=[], parents_card=[], + values=[], state_names=None): """ - Args: - variable: int or string - variable_card: int - parents: optional, list of int and/or strings - parents_card: optional, list of int - values: optional, 2d list or array + Args + variable: int or string + variable_card: int + parents: list, + (optional) list of int and/or strings + parents_card: list, + (optional) list of int + values: 2-d list or array (optional) + state_names: dictionary (optional), + mapping variables to their states, of format {label_name: ['state1', 'state2']} """ + super().__init__(variables=[variable] + parents, + cardinality=[variable_card] + parents_card, + values=values, + state_names=state_names) self.variable = variable - self.parents = parents - self.variables = [variable] + parents - self.cardinality = [variable_card] + parents_card - self._values = np.array(values) - - @property - def values(self): - return self._values + self.parents = list(parents) def get_values(self): """ @@ -38,8 +39,4 @@ class TabularCPD: return self.values.reshape(self.cardinality[0], np.prod(self.cardinality[1:])) def copy(self): - return self.__class__(self.variable, - self.cardinality[0], - self.parents, - self.cardinality[1:], - self._values) + return copy.deepcopy(self) |