diff options
author | Cathy Yeh <cathy@driver.xyz> | 2017-11-20 11:40:02 -0800 |
---|---|---|
committer | Cathy Yeh <cathy@driver.xyz> | 2017-11-20 11:40:02 -0800 |
commit | 71e384a741e52f94882b14062a3dc10e5f391533 (patch) | |
tree | 669b8c78e3c9c7e44cf58692fef81836b8cc94b9 /beliefs/utils/edges_helper.py | |
parent | b16e990b7e4d00e427d4445ba38eef0fb967963a (diff) | |
download | beliefs-71e384a741e52f94882b14062a3dc10e5f391533.tar.gz beliefs-71e384a741e52f94882b14062a3dc10e5f391533.tar.bz2 beliefs-71e384a741e52f94882b14062a3dc10e5f391533.zip |
BernoulliOrCPD inherits from TabularCPD
Diffstat (limited to 'beliefs/utils/edges_helper.py')
-rw-r--r-- | beliefs/utils/edges_helper.py | 12 |
1 files changed, 6 insertions, 6 deletions
diff --git a/beliefs/utils/edges_helper.py b/beliefs/utils/edges_helper.py index 7ac783c..c959a3b 100644 --- a/beliefs/utils/edges_helper.py +++ b/beliefs/utils/edges_helper.py @@ -1,7 +1,7 @@ from collections import defaultdict from beliefs.types.Node import Node -from beliefs.factors.BernoulliOrFactor import BernoulliOrFactor +from beliefs.factors.BernoulliOrCPD import BernoulliOrCPD class EdgesHelper: @@ -33,7 +33,7 @@ class EdgesHelper: all_labels.update({parent, child}) return all_labels - def create_cpds_from_edges(self, CPD=BernoulliOrFactor): + def create_cpds_from_edges(self, CPD=BernoulliOrCPD): """ Create factors from list of edges. @@ -55,7 +55,7 @@ class EdgesHelper: factors.add(cpd) return factors - def get_label_to_factor_dict(self, CPD=BernoulliOrFactor): + def get_label_to_factor_dict(self, CPD=BernoulliOrCPD): """Create a dictionary mapping each label_id to the CPD/factor where that label_id is a child. @@ -70,7 +70,7 @@ class EdgesHelper: label_to_factor[factor.child] = factor return label_to_factor - def get_label_to_node_dict(self, CPD=BernoulliOrFactor): + def get_label_to_node_dict(self, CPD=BernoulliOrCPD): """Create a dictionary mapping each label_id to a Node instance. Returns: @@ -126,8 +126,8 @@ class EdgesHelper: nodes = set() for label in labels: - parents = labels_to_parents[label] - children = labels_to_children[label] + parents = list(labels_to_parents[label]) + children = list(labels_to_children[label]) node = node_class(label_id=label, children=children, |