aboutsummaryrefslogtreecommitdiff
path: root/beliefs/models
diff options
context:
space:
mode:
Diffstat (limited to 'beliefs/models')
-rw-r--r--beliefs/models/BayesianModel.py24
-rw-r--r--beliefs/models/BernoulliOrModel.py2
-rw-r--r--beliefs/models/DirectedGraph.py11
3 files changed, 19 insertions, 18 deletions
diff --git a/beliefs/models/BayesianModel.py b/beliefs/models/BayesianModel.py
index bdfd037..6257a57 100644
--- a/beliefs/models/BayesianModel.py
+++ b/beliefs/models/BayesianModel.py
@@ -12,18 +12,18 @@ class BayesianModel(DirectedGraph):
Bayesian model stores nodes and edges described by conditional probability
distributions.
"""
- def __init__(self, edges, nodes=None):
+ def __init__(self, edges, nodes_dict=None):
"""
Input:
edges: list of edge tuples of form ('parent', 'child')
nodes: (optional) dict
a dict key, value pair as {label_id: instance_of_node_class_or_subclass}
"""
- if nodes is not None:
- super().__init__(edges, nodes.keys())
+ if nodes_dict is not None:
+ super().__init__(edges, nodes_dict.keys())
else:
super().__init__(edges)
- self.nodes = nodes
+ self.nodes_dict = nodes_dict
@classmethod
def from_node_class(cls, edges, node_class):
@@ -57,10 +57,10 @@ class BayesianModel(DirectedGraph):
to an (unnormalized) unit vector, of length the cardinality of x.
"""
for root in self.get_roots():
- self.nodes[root].pi_agg = self.nodes[root].cpd.values
+ self.nodes_dict[root].pi_agg = self.nodes_dict[root].cpd.values
for leaf in self.get_leaves():
- self.nodes[leaf].lambda_agg = np.ones([self.nodes[leaf].cardinality])
+ self.nodes_dict[leaf].lambda_agg = np.ones([self.nodes_dict[leaf].cardinality])
@property
def all_nodes_are_fully_initialized(self):
@@ -68,7 +68,7 @@ class BayesianModel(DirectedGraph):
Returns True if, for all nodes in the model, all lambda and pi
messages and lambda_agg and pi_agg are not None, else False.
"""
- for node in self.nodes.values():
+ for node in self.nodes_dict.values():
if not node.is_fully_initialized:
return False
return True
@@ -77,8 +77,8 @@ class BayesianModel(DirectedGraph):
"""
Returns a copy of the model.
"""
- copy_edges = self.edges().copy()
- copy_nodes = copy.deepcopy(self.nodes)
+ copy_edges = list(self.edges()).copy()
+ copy_nodes = copy.deepcopy(self.nodes_dict)
copy_model = self.__class__(edges=copy_edges, nodes=copy_nodes)
return copy_model
@@ -90,7 +90,7 @@ class BayesianModel(DirectedGraph):
RETURNS
set of strings (labels)
"""
- return {label for label, node in self.nodes.items() if is_kronecker_delta(node.belief)}
+ return {label for label, node in self.nodes_dict.items() if is_kronecker_delta(node.belief)}
def get_unobserved_variables_in_definite_state(self, observed=set()):
"""
@@ -105,9 +105,9 @@ class BayesianModel(DirectedGraph):
# Assert that beliefs of directly observed vars are kronecker deltas
for label in observed:
- assert is_kronecker_delta(self.nodes[label].belief), \
+ assert is_kronecker_delta(self.nodes_dict[label].belief), \
("Observed label has belief {} but should be kronecker delta"
- .format(self.nodes[label].belief))
+ .format(self.nodes_dict[label].belief))
vars_in_definite_state = self.get_variables_in_definite_state()
assert observed <= vars_in_definite_state, \
diff --git a/beliefs/models/BernoulliOrModel.py b/beliefs/models/BernoulliOrModel.py
index da18fb6..bf2b44c 100644
--- a/beliefs/models/BernoulliOrModel.py
+++ b/beliefs/models/BernoulliOrModel.py
@@ -14,4 +14,4 @@ class BernoulliOrModel(BayesianModel):
"""
if nodes is None:
nodes = self.create_nodes(edges, node_class=BernoulliOrNode)
- super().__init__(edges, nodes=nodes)
+ super().__init__(edges, nodes_dict=nodes)
diff --git a/beliefs/models/DirectedGraph.py b/beliefs/models/DirectedGraph.py
index 8fce894..84b3a02 100644
--- a/beliefs/models/DirectedGraph.py
+++ b/beliefs/models/DirectedGraph.py
@@ -21,15 +21,16 @@ class DirectedGraph(nx.DiGraph):
"""
Returns a list of leaves of the graph.
"""
- return [node for node, out_degree in self.out_degree_iter() if
- out_degree == 0]
+ return [node for node, out_degree in self.out_degree() if out_degree == 0]
def get_roots(self):
"""
Returns a list of roots of the graph.
"""
- return [node for node, in_degree in self.in_degree().items() if
- in_degree == 0]
+ return [node for node, in_degree in self.in_degree() if in_degree == 0]
def get_topologically_sorted_nodes(self, reverse=False):
- return nx.topological_sort(self, reverse=reverse)
+ if reverse:
+ return list(reversed(list(nx.topological_sort(self))))
+ else:
+ return nx.topological_sort(self)