diff options
Diffstat (limited to 'beliefs/factors/cpd.py')
-rw-r--r-- | beliefs/factors/cpd.py | 29 |
1 files changed, 14 insertions, 15 deletions
diff --git a/beliefs/factors/cpd.py b/beliefs/factors/cpd.py index 9e7191f..c7883c9 100644 --- a/beliefs/factors/cpd.py +++ b/beliefs/factors/cpd.py @@ -1,3 +1,4 @@ +import copy import numpy as np from beliefs.factors.discrete_factor import DiscreteFactor @@ -7,16 +8,18 @@ class TabularCPD(DiscreteFactor): Defines the conditional probability table for a discrete variable whose parents are also discrete. """ - def __init__(self, variable, variable_card, - parents=[], parents_card=[], values=[], state_names=None): + def __init__(self, variable, variable_card, parents=[], parents_card=[], + values=[], state_names=None): """ - Args: - variable: int or string - variable_card: int - parents: optional, list of int and/or strings - parents_card: optional, list of int - values: optional, 2d list or array - state_names: dictionary (optional), + Args + variable: int or string + variable_card: int + parents: list, + (optional) list of int and/or strings + parents_card: list, + (optional) list of int + values: 2-d list or array (optional) + state_names: dictionary (optional), mapping variables to their states, of format {label_name: ['state1', 'state2']} """ super().__init__(variables=[variable] + parents, @@ -24,7 +27,7 @@ class TabularCPD(DiscreteFactor): values=values, state_names=state_names) self.variable = variable - self.parents = parents + self.parents = list(parents) def get_values(self): """ @@ -36,8 +39,4 @@ class TabularCPD(DiscreteFactor): return self.values.reshape(self.cardinality[0], np.prod(self.cardinality[1:])) def copy(self): - return self.__class__(self.variable, - self.cardinality[0], - self.parents, - self.cardinality[1:], - self._values) + return copy.deepcopy(self) |