diff options
author | Cathy Yeh <cathy@driver.xyz> | 2018-01-18 21:57:50 -0800 |
---|---|---|
committer | Cathy Yeh <cathy@driver.xyz> | 2018-01-18 21:57:50 -0800 |
commit | 2366e92bdb9c81bc2bd7132a00ed5c16a5160c5e (patch) | |
tree | fc71d343eec17b59d8af81eb768e4fe2eab167c2 /beliefs/factors/discrete_factor.py | |
parent | 65d822247e30b6e104a8c09d3b930487b9f20a58 (diff) | |
parent | c93c352b2f68a2bbcde2241e61d9fb52504a67a9 (diff) | |
download | beliefs-2366e92bdb9c81bc2bd7132a00ed5c16a5160c5e.tar.gz beliefs-2366e92bdb9c81bc2bd7132a00ed5c16a5160c5e.tar.bz2 beliefs-2366e92bdb9c81bc2bd7132a00ed5c16a5160c5e.zip |
Merge branch 'generic_discrete_factor'. Implements explicit discrete factor methodsv0.1.0
Diffstat (limited to 'beliefs/factors/discrete_factor.py')
-rw-r--r-- | beliefs/factors/discrete_factor.py | 126 |
1 files changed, 126 insertions, 0 deletions
diff --git a/beliefs/factors/discrete_factor.py b/beliefs/factors/discrete_factor.py new file mode 100644 index 0000000..708f00c --- /dev/null +++ b/beliefs/factors/discrete_factor.py @@ -0,0 +1,126 @@ +import copy +import numpy as np + + +class DiscreteFactor: + + def __init__(self, variables, cardinality, values=None, state_names=None): + """ + Args + variables: list, + variables in the scope of the factor + cardinality: list, + cardinalities of each variable, where len(cardinality)=len(variables) + values: list, + row vector of values of variables with ordering such that right-most variables + defined in `variables` cycle through their values the fastest + state_names: dictionary, + mapping variables to their states, of format {label_name: ['state1', 'state2']} + """ + self.variables = list(variables) + self.cardinality = list(cardinality) + if values is None: + self._values = None + else: + self._values = np.array(values).reshape(self.cardinality) + self.state_names = state_names + + def __mul__(self, other): + return self.product(other) + + def copy(self): + """Return a copy of the factor""" + return self.__class__(self.variables, + self.cardinality, + self._values, + copy.deepcopy(self.state_names)) + + @property + def values(self): + return self._values + + def update_values(self, new_values): + """We make this available because _values is allowed to be None on init""" + self._values = np.array(new_values).reshape(self.cardinality) + + def get_value_for_state_vector(self, dict_of_states): + """ + Return the value for a dictionary of variable states. + + Args + dict_of_states: dictionary, + of format {label_name1: 'state1', label_name2: 'True'} + Returns + probability, a float, the factor value for a specific combination of variable states + """ + assert sorted(dict_of_states.keys()) == sorted(self.variables), \ + "The keys for the dictionary of states must match the variables in factor scope." + state_coordinates = [] + for var in self.variables: + var_state = dict_of_states[var] + idx_in_var_axis = self.state_names[var].index(var_state) + state_coordinates.append(idx_in_var_axis) + return self.values[tuple(state_coordinates)] + + def add_new_variables_from_other_factor(self, other): + """Add new variables from `other` factor to the factor.""" + extra_vars = set(other.variables) - set(self.variables) + # if all of these variables already exist there is nothing to do + if len(extra_vars) == 0: + return + # otherwise, extend the values array + slice_ = [slice(None)] * len(self.variables) + slice_.extend([np.newaxis] * len(extra_vars)) + self._values = self._values[slice_] + self.variables.extend(extra_vars) + + new_card_var = other.get_cardinality(extra_vars) + self.cardinality.extend([new_card_var[var] for var in extra_vars]) + + def get_cardinality(self, variables): + return {var: self.cardinality[self.variables.index(var)] for var in variables} + + def product(self, other): + left = self.copy() + + if isinstance(other, (int, float)): + return self.values * other + else: + assert isinstance(other, DiscreteFactor), \ + "__mul__ is only defined between subclasses of DiscreteFactor" + right = other.copy() + left.add_new_variables_from_other_factor(right) + right.add_new_variables_from_other_factor(left) + + # reorder variables in right factor to match order in left + source_axes = list(range(right.values.ndim)) + destination_axes = [right.variables.index(var) for var in left.variables] + right.variables = [right.variables[idx] for idx in destination_axes] + + # rearrange values in right factor to correspond to the reordered variables + right._values = np.moveaxis(right.values, source_axes, destination_axes) + left._values = left.values * right.values + return left + + def marginalize(self, vars): + """ + Args + vars: list, + variables over which to marginalize the factor + Returns + DiscreteFactor, whose scope is set(self.variables) - set(vars) + """ + phi = copy.deepcopy(self) + + var_indexes = [] + for var in vars: + if var not in phi.variables: + raise ValueError('{} not in scope'.format(var)) + else: + var_indexes.append(self.variables.index(var)) + + index_to_keep = sorted(set(range(len(self.variables))) - set(var_indexes)) + phi.variables = [self.variables[index] for index in index_to_keep] + phi.cardinality = [self.cardinality[index] for index in index_to_keep] + phi._values = np.sum(phi.values, axis=tuple(var_indexes)) + return phi |