aboutsummaryrefslogtreecommitdiff
path: root/beliefs/models/DirectedGraph.py
blob: 84b3a02a6c5e4db7da077da566da4f21a0349bf6 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import networkx as nx


class DirectedGraph(nx.DiGraph):
    """
    Base class for all directed graphical models.

    A thin subclass of ``networkx.DiGraph`` that adds convenience accessors
    for the leaves, the roots and a topological ordering of the graph.
    """
    def __init__(self, edges=None, node_labels=None):
        """
        Input:
            edges: an edge list, e.g. [(parent1, child1), (parent1, child2)]
            node_labels: a list of strings of node labels

        Nodes referenced by `edges` are created first (in edge order). Any
        label in `node_labels` that is not already present is then added as
        a node, which is how isolated nodes (no edges) can be declared.
        """
        super().__init__()
        if edges is not None:
            self.add_edges_from(edges)
        if node_labels is not None:
            self.add_nodes_from(node_labels)

    def get_leaves(self):
        """
        Returns a list of leaves of the graph, i.e. nodes with no outgoing
        edges. An isolated node counts as both a leaf and a root.
        """
        return [node for node, out_degree in self.out_degree() if out_degree == 0]

    def get_roots(self):
        """
        Returns a list of roots of the graph, i.e. nodes with no incoming
        edges. An isolated node counts as both a root and a leaf.
        """
        return [node for node, in_degree in self.in_degree() if in_degree == 0]

    def get_topologically_sorted_nodes(self, reverse=False):
        """
        Returns a list of all nodes in topological order (every parent comes
        before its children).

        Input:
            reverse: if True, return the reversed order (children before
                parents).

        Raises:
            networkx.NetworkXUnfeasible: if the graph contains a cycle.
        """
        # In networkx >= 2.0, nx.topological_sort returns a one-shot generator.
        # Materialise it so that both branches return the same type (a list).
        # This lets callers index it, take len(), or iterate it more than once,
        # and any cycle error surfaces here rather than during later iteration.
        order = list(nx.topological_sort(self))
        if reverse:
            order.reverse()
        return order