diff options
Diffstat (limited to 'beliefs/factors/discrete_factor.py')
-rw-r--r-- | beliefs/factors/discrete_factor.py | 32 |
1 files changed, 15 insertions, 17 deletions
diff --git a/beliefs/factors/discrete_factor.py b/beliefs/factors/discrete_factor.py index b75da28..708f00c 100644 --- a/beliefs/factors/discrete_factor.py +++ b/beliefs/factors/discrete_factor.py @@ -18,7 +18,7 @@ class DiscreteFactor: mapping variables to their states, of format {label_name: ['state1', 'state2']} """ self.variables = list(variables) - self.cardinality = cardinality + self.cardinality = list(cardinality) if values is None: self._values = None else: @@ -28,6 +28,13 @@ class DiscreteFactor: def __mul__(self, other): return self.product(other) + def copy(self): + """Return a copy of the factor""" + return self.__class__(self.variables, + self.cardinality, + self._values, + copy.deepcopy(self.state_names)) + @property def values(self): return self._values @@ -56,7 +63,7 @@ class DiscreteFactor: return self.values[tuple(state_coordinates)] def add_new_variables_from_other_factor(self, other): - """Add new variables to the factor.""" + """Add new variables from `other` factor to the factor.""" extra_vars = set(other.variables) - set(self.variables) # if all of these variables already exist there is nothing to do if len(extra_vars) == 0: @@ -69,33 +76,24 @@ class DiscreteFactor: new_card_var = other.get_cardinality(extra_vars) self.cardinality.extend([new_card_var[var] for var in extra_vars]) - return def get_cardinality(self, variables): return {var: self.cardinality[self.variables.index(var)] for var in variables} def product(self, other): - left = copy.deepcopy(self) + left = self.copy() if isinstance(other, (int, float)): - # TODO: handle case of multiplication by constant - pass + return self.values * other else: - # assert right is a class or subclass of DiscreteFactor - # that has attributes: variables, values; method: get_cardinality - right = copy.deepcopy(other) + assert isinstance(other, DiscreteFactor), \ + "__mul__ is only defined between subclasses of DiscreteFactor" + right = other.copy() left.add_new_variables_from_other_factor(right) right.add_new_variables_from_other_factor(left) - print('var', left.variables) - print(left.cardinality) - print(left.values) - print('var', right.variables) - print(right.cardinality) - print(right.values) # reorder variables in right factor to match order in left source_axes = list(range(right.values.ndim)) - print('source_axes', source_axes) destination_axes = [right.variables.index(var) for var in left.variables] right.variables = [right.variables[idx] for idx in destination_axes] @@ -110,7 +108,7 @@ class DiscreteFactor: vars: list, variables over which to marginalize the factor Returns - DiscreteFactor + DiscreteFactor, whose scope is set(self.variables) - set(vars) """ phi = copy.deepcopy(self) |