aboutsummaryrefslogtreecommitdiff
path: root/beliefs/factors
diff options
context:
space:
mode:
authorCathy Yeh <cathy@driver.xyz>2017-12-13 18:47:32 -0800
committerCathy Yeh <cathy@driver.xyz>2017-12-14 15:39:52 -0800
commit7053fefc6f9e43b1e252d1f551401a7a70b52e93 (patch)
tree2b232454fbc71f4acc877cbf03f61565bae88b98 /beliefs/factors
parent10f5c49ea6767f54d59f88eb4064bb4959d14c6b (diff)
downloadbeliefs-7053fefc6f9e43b1e252d1f551401a7a70b52e93.tar.gz
beliefs-7053fefc6f9e43b1e252d1f551401a7a70b52e93.tar.bz2
beliefs-7053fefc6f9e43b1e252d1f551401a7a70b52e93.zip
cleanup print statements, stale comments, minor TODOs
Diffstat (limited to 'beliefs/factors')
-rw-r--r--beliefs/factors/bernoulli_and_cpd.py7
-rw-r--r--beliefs/factors/bernoulli_or_cpd.py7
-rw-r--r--beliefs/factors/cpd.py29
-rw-r--r--beliefs/factors/discrete_factor.py32
4 files changed, 37 insertions, 38 deletions
diff --git a/beliefs/factors/bernoulli_and_cpd.py b/beliefs/factors/bernoulli_and_cpd.py
index 15802c2..291398f 100644
--- a/beliefs/factors/bernoulli_and_cpd.py
+++ b/beliefs/factors/bernoulli_and_cpd.py
@@ -12,9 +12,10 @@ class BernoulliAndCPD(TabularCPD):
"""
def __init__(self, variable, parents=[]):
"""
- Args:
- variable: int or string
- parents: optional, list of int and/or strings
+ Args
+ variable: int or string
+ parents: list,
+ (optional) list of int and/or strings
"""
super().__init__(variable=variable,
variable_card=2,
diff --git a/beliefs/factors/bernoulli_or_cpd.py b/beliefs/factors/bernoulli_or_cpd.py
index 5b661a1..b5e6ae5 100644
--- a/beliefs/factors/bernoulli_or_cpd.py
+++ b/beliefs/factors/bernoulli_or_cpd.py
@@ -12,9 +12,10 @@ class BernoulliOrCPD(TabularCPD):
"""
def __init__(self, variable, parents=[]):
"""
- Args:
- variable: int or string
- parents: optional, list of int and/or strings
+ Args
+ variable: int or string
+ parents: list,
+ (optional) list of int and/or strings
"""
super().__init__(variable=variable,
variable_card=2,
diff --git a/beliefs/factors/cpd.py b/beliefs/factors/cpd.py
index 9e7191f..c7883c9 100644
--- a/beliefs/factors/cpd.py
+++ b/beliefs/factors/cpd.py
@@ -1,3 +1,4 @@
+import copy
import numpy as np
from beliefs.factors.discrete_factor import DiscreteFactor
@@ -7,16 +8,18 @@ class TabularCPD(DiscreteFactor):
Defines the conditional probability table for a discrete variable
whose parents are also discrete.
"""
- def __init__(self, variable, variable_card,
- parents=[], parents_card=[], values=[], state_names=None):
+ def __init__(self, variable, variable_card, parents=[], parents_card=[],
+ values=[], state_names=None):
"""
- Args:
- variable: int or string
- variable_card: int
- parents: optional, list of int and/or strings
- parents_card: optional, list of int
- values: optional, 2d list or array
- state_names: dictionary (optional),
+ Args
+ variable: int or string
+ variable_card: int
+ parents: list,
+ (optional) list of int and/or strings
+ parents_card: list,
+ (optional) list of int
+ values: 2-d list or array (optional)
+ state_names: dictionary (optional),
mapping variables to their states, of format {label_name: ['state1', 'state2']}
"""
super().__init__(variables=[variable] + parents,
@@ -24,7 +27,7 @@ class TabularCPD(DiscreteFactor):
values=values,
state_names=state_names)
self.variable = variable
- self.parents = parents
+ self.parents = list(parents)
def get_values(self):
"""
@@ -36,8 +39,4 @@ class TabularCPD(DiscreteFactor):
return self.values.reshape(self.cardinality[0], np.prod(self.cardinality[1:]))
def copy(self):
- return self.__class__(self.variable,
- self.cardinality[0],
- self.parents,
- self.cardinality[1:],
- self._values)
+ return copy.deepcopy(self)
diff --git a/beliefs/factors/discrete_factor.py b/beliefs/factors/discrete_factor.py
index b75da28..708f00c 100644
--- a/beliefs/factors/discrete_factor.py
+++ b/beliefs/factors/discrete_factor.py
@@ -18,7 +18,7 @@ class DiscreteFactor:
mapping variables to their states, of format {label_name: ['state1', 'state2']}
"""
self.variables = list(variables)
- self.cardinality = cardinality
+ self.cardinality = list(cardinality)
if values is None:
self._values = None
else:
@@ -28,6 +28,13 @@ class DiscreteFactor:
def __mul__(self, other):
return self.product(other)
+ def copy(self):
+ """Return a copy of the factor"""
+ return self.__class__(self.variables,
+ self.cardinality,
+ self._values,
+ copy.deepcopy(self.state_names))
+
@property
def values(self):
return self._values
@@ -56,7 +63,7 @@ class DiscreteFactor:
return self.values[tuple(state_coordinates)]
def add_new_variables_from_other_factor(self, other):
- """Add new variables to the factor."""
+ """Add new variables from `other` factor to the factor."""
extra_vars = set(other.variables) - set(self.variables)
# if all of these variables already exist there is nothing to do
if len(extra_vars) == 0:
@@ -69,33 +76,24 @@ class DiscreteFactor:
new_card_var = other.get_cardinality(extra_vars)
self.cardinality.extend([new_card_var[var] for var in extra_vars])
- return
def get_cardinality(self, variables):
return {var: self.cardinality[self.variables.index(var)] for var in variables}
def product(self, other):
- left = copy.deepcopy(self)
+ left = self.copy()
if isinstance(other, (int, float)):
- # TODO: handle case of multiplication by constant
- pass
+ return self.values * other
else:
- # assert right is a class or subclass of DiscreteFactor
- # that has attributes: variables, values; method: get_cardinality
- right = copy.deepcopy(other)
+ assert isinstance(other, DiscreteFactor), \
+ "__mul__ is only defined between subclasses of DiscreteFactor"
+ right = other.copy()
left.add_new_variables_from_other_factor(right)
right.add_new_variables_from_other_factor(left)
- print('var', left.variables)
- print(left.cardinality)
- print(left.values)
- print('var', right.variables)
- print(right.cardinality)
- print(right.values)
# reorder variables in right factor to match order in left
source_axes = list(range(right.values.ndim))
- print('source_axes', source_axes)
destination_axes = [right.variables.index(var) for var in left.variables]
right.variables = [right.variables[idx] for idx in destination_axes]
@@ -110,7 +108,7 @@ class DiscreteFactor:
vars: list,
variables over which to marginalize the factor
Returns
- DiscreteFactor
+ DiscreteFactor, whose scope is set(self.variables) - set(vars)
"""
phi = copy.deepcopy(self)