aboutsummaryrefslogtreecommitdiff
path: root/beliefs/factors/discrete_factor.py
blob: da8e6bf411766ecf1feecaa45b4e0b119fe6827d (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import copy
import numpy as np


class DiscreteFactor:
    """
    A table of values defined over a set of discrete variables.

    The values are stored as a numpy array with one axis per variable, in the
    same order as `variables`; axis `i` has length `cardinality[i]`.
    """

    def __init__(self, variables, cardinality, values=None, state_names=None):
        """
        Args
            variables: list,
                variables in the scope of the factor
            cardinality: list,
                cardinalities of each variable, where len(cardinality)=len(variables)
            values: list,
                row vector of values of variables with ordering such that right-most variables
                defined in `variables` cycle through their values the fastest
            state_names: dictionary,
                mapping variables to their states, of format {label_name: ['state1', 'state2']}
        """
        self.variables = list(variables)
        # Copy into a list: the cardinality is extended in place when new
        # variables are added, which must not leak into the caller's sequence
        # (and must also work when a tuple was passed in).
        self.cardinality = list(cardinality)
        if values is None:
            self._values = None
        else:
            self._values = np.array(values).reshape(self.cardinality)
        self.state_names = state_names

    def __mul__(self, other):
        return self.product(other)

    def __rmul__(self, other):
        # The factor product is commutative, so `2 * factor` is `factor * 2`.
        return self.product(other)

    @property
    def values(self):
        return self._values

    def update_values(self, new_values):
        """We make this available because _values is allowed to be None on init"""
        self._values = np.array(new_values).reshape(self.cardinality)

    def get_value_for_state_vector(self, dict_of_states):
        """
        Return the value for a dictionary of variable states.

        Args
            dict_of_states: dictionary,
                of format {label_name1: 'state1', label_name2: 'True'}
        Returns
            probability, a float, the factor value for a specific combination of variable states
        Raises
            AssertionError, if the keys of `dict_of_states` are not exactly the
            variables in the scope of the factor
        """
        assert sorted(dict_of_states.keys()) == sorted(self.variables), \
            "The keys for the dictionary of states must match the variables in factor scope."
        # Translate each state label into its position along the variable's axis
        state_coordinates = tuple(
            self.state_names[var].index(dict_of_states[var])
            for var in self.variables
        )
        return self.values[state_coordinates]

    def add_new_variables_from_other_factor(self, other):
        """
        Extend the scope of this factor, in place, with the variables of
        `other` that it does not have yet.

        Each new variable is appended as a trailing axis of length 1 in the
        values array, so the array broadcasts against a factor that holds the
        full axis. `variables` and `cardinality` are extended accordingly, in
        the order in which the new variables appear in `other.variables`.

        Args
            other: DiscreteFactor,
                factor providing the new variables and their cardinalities
        """
        known_vars = set(self.variables)
        # Follow the order of other.variables (rather than iterating a set) so
        # the resulting axis order is the same on every run.
        extra_vars = [var for var in other.variables if var not in known_vars]
        # if all of these variables already exist there is nothing to do
        if not extra_vars:
            return
        # otherwise, extend the values array; the index has to be a tuple,
        # numpy does not accept a list of slices / newaxis as a multi-axis index
        index = [slice(None)] * len(self.variables)
        index.extend([np.newaxis] * len(extra_vars))
        self._values = self._values[tuple(index)]
        self.variables.extend(extra_vars)

        new_card_var = other.get_cardinality(extra_vars)
        self.cardinality.extend(new_card_var[var] for var in extra_vars)

        # Carry over the state names of the new variables so that states can
        # still be looked up by name. A new dict is built so that a dict
        # shared with another object is never modified.
        other_state_names = getattr(other, 'state_names', None)
        if self.state_names is not None and other_state_names is not None:
            merged_state_names = dict(self.state_names)
            for var in extra_vars:
                if var in other_state_names:
                    merged_state_names[var] = other_state_names[var]
            self.state_names = merged_state_names
        return

    def get_cardinality(self, variables):
        """
        Args
            variables: list,
                variables in the scope of the factor
        Returns
            dictionary, mapping each requested variable to its cardinality
        """
        return {var: self.cardinality[self.variables.index(var)] for var in variables}

    def product(self, other):
        """
        Multiply the factor by a constant or by another factor.

        Neither operand is modified; the result is a new factor.

        Args
            other: int, float, numpy scalar or DiscreteFactor,
                a constant scales every value of the factor; for a factor, the
                result is defined over the union of the variables of both
                operands, with the variables of this factor first
        Returns
            DiscreteFactor
        """
        left = copy.deepcopy(self)

        if isinstance(other, (int, float, np.number)):
            # multiplication by a constant scales every entry
            left._values = left.values * other
            return left

        # assert right is a class or subclass of DiscreteFactor
        # that has attributes: variables, values; method: get_cardinality
        right = copy.deepcopy(other)
        left.add_new_variables_from_other_factor(right)
        right.add_new_variables_from_other_factor(left)

        # reorder variables in right factor to match order in left:
        # axis i of the result must be the axis that holds left.variables[i]
        axes_order = [right.variables.index(var) for var in left.variables]
        right.variables = [right.variables[idx] for idx in axes_order]

        # rearrange values in right factor to correspond to the reordered variables
        right._values = np.transpose(right.values, axes_order)
        left._values = left.values * right.values
        return left

    def marginalize(self, vars):
        """
        Sum the given variables out of the factor.

        Args
            vars: list,
                variables over which to marginalize the factor
        Returns
            DiscreteFactor
        Raises
            ValueError, if one of the variables is not in the scope of the factor
        """
        phi = copy.deepcopy(self)

        var_indexes = []
        for var in vars:
            if var not in phi.variables:
                raise ValueError('{} not in scope'.format(var))
            var_indexes.append(self.variables.index(var))

        # numpy rejects repeated axes, so a variable listed twice is summed once
        axes_to_sum = sorted(set(var_indexes))
        index_to_keep = sorted(set(range(len(self.variables))) - set(axes_to_sum))
        phi.variables = [self.variables[index] for index in index_to_keep]
        phi.cardinality = [self.cardinality[index] for index in index_to_keep]
        phi._values = np.sum(phi.values, axis=tuple(axes_to_sum))
        return phi