aboutsummaryrefslogtreecommitdiff
path: root/beliefs/utils/edges_helper.py
diff options
context:
space:
mode:
Diffstat (limited to 'beliefs/utils/edges_helper.py')
-rw-r--r--beliefs/utils/edges_helper.py12
1 files changed, 6 insertions, 6 deletions
diff --git a/beliefs/utils/edges_helper.py b/beliefs/utils/edges_helper.py
index 7ac783c..c959a3b 100644
--- a/beliefs/utils/edges_helper.py
+++ b/beliefs/utils/edges_helper.py
@@ -1,7 +1,7 @@
from collections import defaultdict
from beliefs.types.Node import Node
-from beliefs.factors.BernoulliOrFactor import BernoulliOrFactor
+from beliefs.factors.BernoulliOrCPD import BernoulliOrCPD
class EdgesHelper:
@@ -33,7 +33,7 @@ class EdgesHelper:
all_labels.update({parent, child})
return all_labels
- def create_cpds_from_edges(self, CPD=BernoulliOrFactor):
+ def create_cpds_from_edges(self, CPD=BernoulliOrCPD):
"""
Create factors from list of edges.
@@ -55,7 +55,7 @@ class EdgesHelper:
factors.add(cpd)
return factors
- def get_label_to_factor_dict(self, CPD=BernoulliOrFactor):
+ def get_label_to_factor_dict(self, CPD=BernoulliOrCPD):
"""Create a dictionary mapping each label_id to the CPD/factor where
that label_id is a child.
@@ -70,7 +70,7 @@ class EdgesHelper:
label_to_factor[factor.child] = factor
return label_to_factor
- def get_label_to_node_dict(self, CPD=BernoulliOrFactor):
+ def get_label_to_node_dict(self, CPD=BernoulliOrCPD):
"""Create a dictionary mapping each label_id to a Node instance.
Returns:
@@ -126,8 +126,8 @@ class EdgesHelper:
nodes = set()
for label in labels:
- parents = labels_to_parents[label]
- children = labels_to_children[label]
+ parents = list(labels_to_parents[label])
+ children = list(labels_to_children[label])
node = node_class(label_id=label,
children=children,