import copy
import numpy as np
class DiscreteFactor:
    """A factor (table of values) over a set of discrete variables.

    Values are stored as an ndarray with one axis per variable, in the order
    of ``self.variables``; once fully materialized, axis ``i`` has length
    ``self.cardinality[i]``.
    """

    def __init__(self, variables, cardinality, values=None, state_names=None):
        """
        Args
            variables: list,
                variables in the scope of the factor
            cardinality: list,
                cardinalities of each variable, where len(cardinality)=len(variables)
            values: list,
                row vector of values of variables with ordering such that right-most variables
                defined in `variables` cycle through their values the fastest
            state_names: dictionary,
                mapping variables to their states, of format {label_name: ['state1', 'state2']}
        """
        self.variables = list(variables)
        self.cardinality = list(cardinality)
        if values is None:
            self._values = None
        else:
            # np.array copies the input; C-order reshape makes the right-most
            # variable vary fastest, matching the documented flat layout.
            self._values = np.array(values).reshape(self.cardinality)
        self.state_names = state_names

    def __mul__(self, other):
        return self.product(other)

    def __rmul__(self, other):
        # Element-wise multiplication commutes, so ``3 * phi`` equals ``phi * 3``.
        return self.product(other)

    def copy(self):
        """Return a copy of the factor (values copied, state_names deep-copied)."""
        return self.__class__(self.variables,
                              self.cardinality,
                              self._values,
                              copy.deepcopy(self.state_names))

    @property
    def values(self):
        """The factor's value array, or None if values were never set."""
        return self._values

    def update_values(self, new_values):
        """We make this available because _values is allowed to be None on init"""
        self._values = np.array(new_values).reshape(self.cardinality)

    def get_value_for_state_vector(self, dict_of_states):
        """
        Return the value for a dictionary of variable states.

        Args
            dict_of_states: dictionary,
                of format {label_name1: 'state1', label_name2: 'True'}
        Returns
            the factor value for that specific combination of variable states
        Raises
            AssertionError: if the dictionary keys differ from the factor's scope
            ValueError: if a state is not listed in ``state_names`` for its variable
        """
        # Set comparison avoids sorting keys that may not be mutually orderable.
        assert set(dict_of_states) == set(self.variables), \
            "The keys for the dictionary of states must match the variables in factor scope."
        state_coordinates = tuple(
            self.state_names[var].index(dict_of_states[var]) for var in self.variables
        )
        return self.values[state_coordinates]

    def add_new_variables_from_other_factor(self, other):
        """Add the variables of `other` that are missing from this factor's scope.

        New variables are appended in the order they appear in
        ``other.variables``. Each gets a length-1 axis in the values array (so
        it broadcasts against ``other``), while ``cardinality`` records its true
        cardinality taken from ``other``. When both factors carry
        ``state_names``, the new variables' state names are copied over too.
        Mutates ``self`` in place.
        """
        existing = set(self.variables)
        # An ordered list (not a set) keeps the resulting scope order deterministic.
        extra_vars = [var for var in other.variables if var not in existing]
        # if all of these variables already exist there is nothing to do
        if not extra_vars:
            return
        if self._values is not None:
            # Append one broadcastable length-1 axis per new variable.
            self._values = self._values.reshape(
                self._values.shape + (1,) * len(extra_vars)
            )
        new_card_var = other.get_cardinality(extra_vars)
        self.variables.extend(extra_vars)
        self.cardinality.extend(new_card_var[var] for var in extra_vars)
        if self.state_names is not None and other.state_names is not None:
            # Build a new dict so a state_names mapping shared with a caller is not mutated.
            merged = dict(self.state_names)
            for var in extra_vars:
                if var in other.state_names:
                    merged[var] = list(other.state_names[var])
            self.state_names = merged

    def get_cardinality(self, variables):
        """Return {variable: cardinality} for each of the given variables."""
        return {var: self.cardinality[self.variables.index(var)] for var in variables}

    def product(self, other):
        """Multiply this factor by a scalar or by another factor.

        Args
            other: int, float, numpy scalar, or DiscreteFactor
        Returns
            For a scalar ``other``: a numpy array equal to ``self.values * other``
            (an array, not a DiscreteFactor).
            For a DiscreteFactor: a new DiscreteFactor whose scope is the union of
            both scopes -- this factor's variables first, in their original order,
            followed by variables only present in ``other``.
        Raises
            AssertionError: if ``other`` is neither a scalar nor a DiscreteFactor
        """
        if isinstance(other, (int, float, np.number)):
            return self.values * other
        assert isinstance(other, DiscreteFactor), \
            "__mul__ is only defined between subclasses of DiscreteFactor"
        left = self.copy()
        right = other.copy()
        left.add_new_variables_from_other_factor(right)
        right.add_new_variables_from_other_factor(left)
        # Both factors now share one scope. Axis j of the reordered right factor
        # must be its current axis perm[j], which is exactly np.transpose(perm).
        # (np.moveaxis(vals, range(n), perm) applies the inverse permutation.)
        perm = [right.variables.index(var) for var in left.variables]
        right._values = np.transpose(right.values, perm)
        right.variables = [right.variables[idx] for idx in perm]
        right.cardinality = [right.cardinality[idx] for idx in perm]
        # Length-1 axes on either side broadcast up to the full joint shape.
        left._values = left.values * right.values
        return left

    def marginalize(self, vars):
        """
        Sum the factor over the given variables.

        Args
            vars: list,
                variables over which to marginalize the factor; a variable
                listed more than once is summed out once
        Returns
            DiscreteFactor, whose scope is set(self.variables) - set(vars)
        Raises
            ValueError: if a variable in `vars` is not in the factor's scope
        """
        phi = copy.deepcopy(self)
        var_indexes = []
        for var in vars:
            if var not in self.variables:
                raise ValueError('{} not in scope'.format(var))
            idx = self.variables.index(var)
            # np.sum rejects duplicate axes, so each variable is recorded once.
            if idx not in var_indexes:
                var_indexes.append(idx)
        index_to_keep = [i for i in range(len(self.variables)) if i not in var_indexes]
        phi.variables = [self.variables[i] for i in index_to_keep]
        phi.cardinality = [self.cardinality[i] for i in index_to_keep]
        phi._values = np.sum(phi.values, axis=tuple(var_indexes))
        if isinstance(phi.state_names, dict):
            # Drop the state names of the variables that were summed out.
            removed = {self.variables[i] for i in var_indexes}
            phi.state_names = {k: v for k, v in phi.state_names.items() if k not in removed}
        return phi