aboutsummaryrefslogtreecommitdiff
path: root/beliefs/utils/edges_helper.py
diff options
context:
space:
mode:
authorCathy Yeh <cathy@driver.xyz>2017-11-20 11:40:02 -0800
committerCathy Yeh <cathy@driver.xyz>2017-11-20 11:40:02 -0800
commit71e384a741e52f94882b14062a3dc10e5f391533 (patch)
tree669b8c78e3c9c7e44cf58692fef81836b8cc94b9 /beliefs/utils/edges_helper.py
parentb16e990b7e4d00e427d4445ba38eef0fb967963a (diff)
downloadbeliefs-71e384a741e52f94882b14062a3dc10e5f391533.tar.gz
beliefs-71e384a741e52f94882b14062a3dc10e5f391533.tar.bz2
beliefs-71e384a741e52f94882b14062a3dc10e5f391533.zip
BernoulliOrCPD inherits from TabularCPD
Diffstat (limited to 'beliefs/utils/edges_helper.py')
-rw-r--r--beliefs/utils/edges_helper.py12
1 files changed, 6 insertions, 6 deletions
diff --git a/beliefs/utils/edges_helper.py b/beliefs/utils/edges_helper.py
index 7ac783c..c959a3b 100644
--- a/beliefs/utils/edges_helper.py
+++ b/beliefs/utils/edges_helper.py
@@ -1,7 +1,7 @@
from collections import defaultdict
from beliefs.types.Node import Node
-from beliefs.factors.BernoulliOrFactor import BernoulliOrFactor
+from beliefs.factors.BernoulliOrCPD import BernoulliOrCPD
class EdgesHelper:
@@ -33,7 +33,7 @@ class EdgesHelper:
all_labels.update({parent, child})
return all_labels
- def create_cpds_from_edges(self, CPD=BernoulliOrFactor):
+ def create_cpds_from_edges(self, CPD=BernoulliOrCPD):
"""
Create factors from list of edges.
@@ -55,7 +55,7 @@ class EdgesHelper:
factors.add(cpd)
return factors
- def get_label_to_factor_dict(self, CPD=BernoulliOrFactor):
+ def get_label_to_factor_dict(self, CPD=BernoulliOrCPD):
"""Create a dictionary mapping each label_id to the CPD/factor where
that label_id is a child.
@@ -70,7 +70,7 @@ class EdgesHelper:
label_to_factor[factor.child] = factor
return label_to_factor
- def get_label_to_node_dict(self, CPD=BernoulliOrFactor):
+ def get_label_to_node_dict(self, CPD=BernoulliOrCPD):
"""Create a dictionary mapping each label_id to a Node instance.
Returns:
@@ -126,8 +126,8 @@ class EdgesHelper:
nodes = set()
for label in labels:
- parents = labels_to_parents[label]
- children = labels_to_children[label]
+ parents = list(labels_to_parents[label])
+ children = list(labels_to_children[label])
node = node_class(label_id=label,
children=children,