diff options
author | Cathy Yeh <cathy@driver.xyz> | 2017-12-13 18:47:32 -0800 |
---|---|---|
committer | Cathy Yeh <cathy@driver.xyz> | 2017-12-14 15:39:52 -0800 |
commit | 7053fefc6f9e43b1e252d1f551401a7a70b52e93 (patch) | |
tree | 2b232454fbc71f4acc877cbf03f61565bae88b98 /beliefs/factors | |
parent | 10f5c49ea6767f54d59f88eb4064bb4959d14c6b (diff) | |
download | beliefs-7053fefc6f9e43b1e252d1f551401a7a70b52e93.tar.gz beliefs-7053fefc6f9e43b1e252d1f551401a7a70b52e93.tar.bz2 beliefs-7053fefc6f9e43b1e252d1f551401a7a70b52e93.zip |
cleanup print statements, stale comments, minor TODOs
Diffstat (limited to 'beliefs/factors')
-rw-r--r-- | beliefs/factors/bernoulli_and_cpd.py | 7 | ||||
-rw-r--r-- | beliefs/factors/bernoulli_or_cpd.py | 7 | ||||
-rw-r--r-- | beliefs/factors/cpd.py | 29 | ||||
-rw-r--r-- | beliefs/factors/discrete_factor.py | 32 |
4 files changed, 37 insertions, 38 deletions
diff --git a/beliefs/factors/bernoulli_and_cpd.py b/beliefs/factors/bernoulli_and_cpd.py index 15802c2..291398f 100644 --- a/beliefs/factors/bernoulli_and_cpd.py +++ b/beliefs/factors/bernoulli_and_cpd.py @@ -12,9 +12,10 @@ class BernoulliAndCPD(TabularCPD): """ def __init__(self, variable, parents=[]): """ - Args: - variable: int or string - parents: optional, list of int and/or strings + Args + variable: int or string + parents: list, + (optional) list of int and/or strings """ super().__init__(variable=variable, variable_card=2, diff --git a/beliefs/factors/bernoulli_or_cpd.py b/beliefs/factors/bernoulli_or_cpd.py index 5b661a1..b5e6ae5 100644 --- a/beliefs/factors/bernoulli_or_cpd.py +++ b/beliefs/factors/bernoulli_or_cpd.py @@ -12,9 +12,10 @@ class BernoulliOrCPD(TabularCPD): """ def __init__(self, variable, parents=[]): """ - Args: - variable: int or string - parents: optional, list of int and/or strings + Args + variable: int or string + parents: list, + (optional) list of int and/or strings """ super().__init__(variable=variable, variable_card=2, diff --git a/beliefs/factors/cpd.py b/beliefs/factors/cpd.py index 9e7191f..c7883c9 100644 --- a/beliefs/factors/cpd.py +++ b/beliefs/factors/cpd.py @@ -1,3 +1,4 @@ +import copy import numpy as np from beliefs.factors.discrete_factor import DiscreteFactor @@ -7,16 +8,18 @@ class TabularCPD(DiscreteFactor): Defines the conditional probability table for a discrete variable whose parents are also discrete. """ - def __init__(self, variable, variable_card, - parents=[], parents_card=[], values=[], state_names=None): + def __init__(self, variable, variable_card, parents=[], parents_card=[], + values=[], state_names=None): """ - Args: - variable: int or string - variable_card: int - parents: optional, list of int and/or strings - parents_card: optional, list of int - values: optional, 2d list or array - state_names: dictionary (optional), + Args + variable: int or string + variable_card: int + parents: list, + (optional) list of int and/or strings + parents_card: list, + (optional) list of int + values: 2-d list or array (optional) + state_names: dictionary (optional), mapping variables to their states, of format {label_name: ['state1', 'state2']} """ super().__init__(variables=[variable] + parents, @@ -24,7 +27,7 @@ class TabularCPD(DiscreteFactor): values=values, state_names=state_names) self.variable = variable - self.parents = parents + self.parents = list(parents) def get_values(self): """ @@ -36,8 +39,4 @@ class TabularCPD(DiscreteFactor): return self.values.reshape(self.cardinality[0], np.prod(self.cardinality[1:])) def copy(self): - return self.__class__(self.variable, - self.cardinality[0], - self.parents, - self.cardinality[1:], - self._values) + return copy.deepcopy(self) diff --git a/beliefs/factors/discrete_factor.py b/beliefs/factors/discrete_factor.py index b75da28..708f00c 100644 --- a/beliefs/factors/discrete_factor.py +++ b/beliefs/factors/discrete_factor.py @@ -18,7 +18,7 @@ class DiscreteFactor: mapping variables to their states, of format {label_name: ['state1', 'state2']} """ self.variables = list(variables) - self.cardinality = cardinality + self.cardinality = list(cardinality) if values is None: self._values = None else: @@ -28,6 +28,13 @@ class DiscreteFactor: def __mul__(self, other): return self.product(other) + def copy(self): + """Return a copy of the factor""" + return self.__class__(self.variables, + self.cardinality, + self._values, + copy.deepcopy(self.state_names)) + @property def values(self): return self._values @@ -56,7 +63,7 @@ class DiscreteFactor: return self.values[tuple(state_coordinates)] def add_new_variables_from_other_factor(self, other): - """Add new variables to the factor.""" + """Add new variables from `other` factor to the factor.""" extra_vars = set(other.variables) - set(self.variables) # if all of these variables already exist there is nothing to do if len(extra_vars) == 0: @@ -69,33 +76,24 @@ class DiscreteFactor: new_card_var = other.get_cardinality(extra_vars) self.cardinality.extend([new_card_var[var] for var in extra_vars]) - return def get_cardinality(self, variables): return {var: self.cardinality[self.variables.index(var)] for var in variables} def product(self, other): - left = copy.deepcopy(self) + left = self.copy() if isinstance(other, (int, float)): - # TODO: handle case of multiplication by constant - pass + return self.values * other else: - # assert right is a class or subclass of DiscreteFactor - # that has attributes: variables, values; method: get_cardinality - right = copy.deepcopy(other) + assert isinstance(other, DiscreteFactor), \ + "__mul__ is only defined between subclasses of DiscreteFactor" + right = other.copy() left.add_new_variables_from_other_factor(right) right.add_new_variables_from_other_factor(left) - print('var', left.variables) - print(left.cardinality) - print(left.values) - print('var', right.variables) - print(right.cardinality) - print(right.values) # reorder variables in right factor to match order in left source_axes = list(range(right.values.ndim)) - print('source_axes', source_axes) destination_axes = [right.variables.index(var) for var in left.variables] right.variables = [right.variables[idx] for idx in destination_axes] @@ -110,7 +108,7 @@ class DiscreteFactor: vars: list, variables over which to marginalize the factor Returns - DiscreteFactor + DiscreteFactor, whose scope is set(self.variables) - set(vars) """ phi = copy.deepcopy(self) |