# cgit page header: about / summary / refs / log / tree / commit / diff
# path: root/beliefs/utils/edges_helper.py
# blob: 130686c0e590627d945b23ac2816e22f33a74dad (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
from collections import defaultdict

from beliefs.models.beliefupdate.Node import Node
from beliefs.factors.BernoulliOrCPD import BernoulliOrCPD


class EdgesHelper:
    """Class with convenience methods for working with edges.

    An edge is a ``(parent, child)`` pair of labels.  ``edges`` is iterated
    afresh by every method, so it must be a re-iterable collection (e.g. a
    list of tuples), not a one-shot generator.
    """
    def __init__(self, edges):
        # Stored as given (no copy); every helper below re-reads it.
        self.edges = edges

    def get_label_to_children_dict(self):
        """Return a dictionary keyed on label, with value a set of children.

        Returns:
          label_to_children_dict: defaultdict(set) mapping each parent label
                                  to the set of its child labels.  Because it
                                  is a defaultdict, indexing a label that has
                                  no children yields (and stores) an empty set.
        """
        label_to_children_dict = defaultdict(set)
        for parent, child in self.edges:
            label_to_children_dict[parent].add(child)
        return label_to_children_dict

    def get_label_to_parents_dict(self):
        """Return a dictionary keyed on label, with value a set of parents.

        Only used to help create dummy factors from edges (not for algo).

        Returns:
          label_to_parents_dict: defaultdict(set) mapping each child label
                                 to the set of its parent labels.  Because it
                                 is a defaultdict, indexing a label that has
                                 no parents yields (and stores) an empty set.
        """
        label_to_parents_dict = defaultdict(set)
        for parent, child in self.edges:
            label_to_parents_dict[child].add(parent)
        return label_to_parents_dict

    def get_labels_from_edges(self):
        """Return the set of labels that comprise the vertices of a list of edge tuples."""
        # Unpacking into (parent, child) keeps the requirement that every
        # edge is exactly a pair.
        return {label
                for parent, child in self.edges
                for label in (parent, child)}

    def create_cpds_from_edges(self, CPD=BernoulliOrCPD):
        """
        Create factors from list of edges.

        Input:
          CPD: a factor class; it is called as ``CPD(label, parents)`` where
               label is the label_id of the child (i.e. of the node) and
               parents is the set of label_ids of that node's parents.

        Returns:
          factors: a set of (unique) factors of the graph, one per label
        """
        label_to_parents = self.get_label_to_parents_dict()

        # label_to_parents is a defaultdict, so labels that never appear as
        # a child (roots) get an empty parent set rather than a KeyError.
        return {CPD(label, label_to_parents[label])
                for label in self.get_labels_from_edges()}

    def get_label_to_factor_dict(self, CPD=BernoulliOrCPD):
        """Create a dictionary mapping each label_id to the CPD/factor where
        that label_id is a child.

        Input:
          CPD: factor class forwarded to create_cpds_from_edges.

        Returns:
          label_to_factor: dict mapping each label to the cpd that
                          has that label as a child.
        """
        factors = self.create_cpds_from_edges(CPD=CPD)
        return {factor.child: factor for factor in factors}

    def get_label_to_node_dict(self, CPD=BernoulliOrCPD):
        """Create a dictionary mapping each label_id to a Node instance.

        Input:
          CPD: accepted for backward compatibility; it is not used here.

        Returns:
          label_to_node: dict mapping each label to the node that has that
                         label as a label_id.
        """
        # create_nodes_from_edges requires a node class; calling it with no
        # argument raised TypeError on every call.
        # NOTE(review): assumes Node can be built from label_id, children and
        # parents alone (as create_nodes_from_edges expects) -- confirm that
        # Node's cardinality/cpd arguments are optional.
        nodes = self.create_nodes_from_edges(node_class=Node)
        return {node.label_id: node for node in nodes}

    def get_label_to_node_dict_for_manual_cpds(self, cpds_list):
        """Create a dictionary mapping each label_id to a node that is
        instantiated with a manually defined pgmpy factor instance.

        Input:
          cpds_list - list of instances of pgmpy factors, e.g. TabularCPD;
                      each factor's ``variable`` attribute is used as the
                      node's label_id.

        Returns:
          label_to_node: dict mapping each label to the node that has that
                         label as a label_id.  Only labels that have a cpd in
                         cpds_list get an entry.
        """
        label_to_children = self.get_label_to_children_dict()
        label_to_parents = self.get_label_to_parents_dict()

        label_to_node = dict()
        for cpd in cpds_list:
            label_id = cpd.variable

            # Both lookups are on defaultdicts, so a label with no
            # children/parents gets an empty set.  Cardinality is fixed at 2
            # for every node created here.
            label_to_node[label_id] = Node(label_id=label_id,
                                           children=label_to_children[label_id],
                                           parents=label_to_parents[label_id],
                                           cardinality=2,
                                           cpd=cpd)
        return label_to_node

    def create_nodes_from_edges(self, node_class):
        """
        Create instances of the node_class.  Assumes the node class is
        initialized by label_id, children, and parents.

        Input:
          node_class: class (or callable) invoked with the keyword arguments
                      label_id, children and parents; children and parents
                      are passed as lists.

        Returns:
          nodes: a set of (unique) nodes of the graph, one per label
        """
        labels_to_parents = self.get_label_to_parents_dict()
        labels_to_children = self.get_label_to_children_dict()

        nodes = set()
        for label in self.get_labels_from_edges():
            # defaultdict lookups: roots get [] parents, leaves get [] children.
            node = node_class(label_id=label,
                              children=list(labels_to_children[label]),
                              parents=list(labels_to_parents[label]))
            nodes.add(node)
        return nodes