diff options
author | Cathy Yeh <cathy@driver.xyz> | 2017-11-20 11:40:02 -0800 |
---|---|---|
committer | Cathy Yeh <cathy@driver.xyz> | 2017-11-20 11:40:02 -0800 |
commit | 71e384a741e52f94882b14062a3dc10e5f391533 (patch) | |
tree | 669b8c78e3c9c7e44cf58692fef81836b8cc94b9 /beliefs/factors/CPD.py | |
parent | b16e990b7e4d00e427d4445ba38eef0fb967963a (diff) | |
download | beliefs-71e384a741e52f94882b14062a3dc10e5f391533.tar.gz beliefs-71e384a741e52f94882b14062a3dc10e5f391533.tar.bz2 beliefs-71e384a741e52f94882b14062a3dc10e5f391533.zip |
BernoulliOrCPD inherits from TabularCPD
Diffstat (limited to 'beliefs/factors/CPD.py')
-rw-r--r-- | beliefs/factors/CPD.py | 36 |
1 files changed, 36 insertions, 0 deletions
diff --git a/beliefs/factors/CPD.py b/beliefs/factors/CPD.py new file mode 100644 index 0000000..8de47b3 --- /dev/null +++ b/beliefs/factors/CPD.py @@ -0,0 +1,36 @@ +import numpy as np + + +class TabularCPD: + """ + Defines the conditional probability table for a discrete variable + whose parents are also discrete. + + TODO: have this inherit from DiscreteFactor + """ + def __init__(self, variable, variable_card, + parents=[], parents_card=[], values=[]): + """ + Args: + variable: int or string + variable_card: int + parents: optional, list of int and/or strings + parents_card: optional, list of int + values: optional, 2d list or array + """ + self.variable = variable + self.parents = parents + self.variables = [variable] + parents + self.cardinality = [variable_card] + parents_card + + if values: + self.values = np.array(values) + + def get_values(self): + """ + Returns the tabular cpd form of the values. + """ + if len(self.cardinality) == 1: + return self.values.reshape(1, np.prod(self.cardinality)) + else: + return self.values.reshape(self.cardinality[0], np.prod(self.cardinality[1:])) |