blob: 84b3a02a6c5e4db7da077da566da4f21a0349bf6 (
plain) (
tree)
|
|
import networkx as nx
class DirectedGraph(nx.DiGraph):
    """
    Base class for all directed graphical models.

    A ``networkx.DiGraph`` subclass that adds convenience accessors for
    leaves, roots and topological orderings of the nodes.
    """

    def __init__(self, edges=None, node_labels=None):
        """
        Build the graph from an optional edge list and optional node labels.

        Input:
            edges: an edge list, e.g. [(parent1, child1), (parent1, child2)]
            node_labels: a list of strings of node labels
        """
        super().__init__()
        if edges is not None:
            self.add_edges_from(edges)
        # Labels are added after the edges. Presumably this is done so that
        # nodes which appear in no edge (isolated nodes) are still registered.
        if node_labels is not None:
            self.add_nodes_from(node_labels)

    def get_leaves(self):
        """
        Returns a list of leaves of the graph, i.e. nodes with out-degree 0
        (no outgoing edges).
        """
        return [node for node, out_degree in self.out_degree() if out_degree == 0]

    def get_roots(self):
        """
        Returns a list of roots of the graph, i.e. nodes with in-degree 0
        (no incoming edges).
        """
        return [node for node, in_degree in self.in_degree() if in_degree == 0]

    def get_topologically_sorted_nodes(self, reverse=False):
        """
        Returns a list of the graph's nodes in topological order.

        For every edge (parent, child), parent precedes child in the result.

        Input:
            reverse: if True, return the order reversed, so every child
                precedes its parents.

        Returns:
            A new list on every call, whatever the value of ``reverse``.

        Raises:
            Whatever ``nx.topological_sort`` raises. Per the networkx docs,
            this is ``NetworkXUnfeasible`` when the graph contains a cycle.
        """
        # Materialize the order once. nx.topological_sort yields lazily in
        # networkx >= 2, so returning it directly would give callers a
        # one-shot iterator on this path and a list on the reversed path.
        ordered = list(nx.topological_sort(self))
        if reverse:
            ordered.reverse()
        return ordered
|