aboutsummaryrefslogtreecommitdiff
path: root/beliefs/factors/cpd.py
diff options
context:
space:
mode:
Diffstat (limited to 'beliefs/factors/cpd.py')
-rw-r--r--beliefs/factors/cpd.py29
1 files changed, 14 insertions, 15 deletions
diff --git a/beliefs/factors/cpd.py b/beliefs/factors/cpd.py
index 9e7191f..c7883c9 100644
--- a/beliefs/factors/cpd.py
+++ b/beliefs/factors/cpd.py
@@ -1,3 +1,4 @@
+import copy
import numpy as np
from beliefs.factors.discrete_factor import DiscreteFactor
@@ -7,16 +8,18 @@ class TabularCPD(DiscreteFactor):
Defines the conditional probability table for a discrete variable
whose parents are also discrete.
"""
- def __init__(self, variable, variable_card,
- parents=[], parents_card=[], values=[], state_names=None):
+ def __init__(self, variable, variable_card, parents=[], parents_card=[],
+ values=[], state_names=None):
"""
- Args:
- variable: int or string
- variable_card: int
- parents: optional, list of int and/or strings
- parents_card: optional, list of int
- values: optional, 2d list or array
- state_names: dictionary (optional),
+ Args
+ variable: int or string
+ variable_card: int
+ parents: list,
+ (optional) list of int and/or strings
+ parents_card: list,
+ (optional) list of int
+ values: 2-d list or array (optional)
+ state_names: dictionary (optional),
mapping variables to their states, of format {label_name: ['state1', 'state2']}
"""
super().__init__(variables=[variable] + parents,
@@ -24,7 +27,7 @@ class TabularCPD(DiscreteFactor):
values=values,
state_names=state_names)
self.variable = variable
- self.parents = parents
+ self.parents = list(parents)
def get_values(self):
"""
@@ -36,8 +39,4 @@ class TabularCPD(DiscreteFactor):
return self.values.reshape(self.cardinality[0], np.prod(self.cardinality[1:]))
def copy(self):
- return self.__class__(self.variable,
- self.cardinality[0],
- self.parents,
- self.cardinality[1:],
- self._values)
+ return copy.deepcopy(self)