aboutsummaryrefslogtreecommitdiff
path: root/beliefs/factors/discrete_factor.py
diff options
context:
space:
mode:
Diffstat (limited to 'beliefs/factors/discrete_factor.py')
-rw-r--r--beliefs/factors/discrete_factor.py32
1 files changed, 15 insertions, 17 deletions
diff --git a/beliefs/factors/discrete_factor.py b/beliefs/factors/discrete_factor.py
index b75da28..708f00c 100644
--- a/beliefs/factors/discrete_factor.py
+++ b/beliefs/factors/discrete_factor.py
@@ -18,7 +18,7 @@ class DiscreteFactor:
mapping variables to their states, of format {label_name: ['state1', 'state2']}
"""
self.variables = list(variables)
- self.cardinality = cardinality
+ self.cardinality = list(cardinality)
if values is None:
self._values = None
else:
@@ -28,6 +28,13 @@ class DiscreteFactor:
def __mul__(self, other):
return self.product(other)
+ def copy(self):
+ """Return a copy of the factor"""
+ return self.__class__(self.variables,
+ self.cardinality,
+ self._values,
+ copy.deepcopy(self.state_names))
+
@property
def values(self):
return self._values
@@ -56,7 +63,7 @@ class DiscreteFactor:
return self.values[tuple(state_coordinates)]
def add_new_variables_from_other_factor(self, other):
- """Add new variables to the factor."""
+ """Add new variables from `other` factor to the factor."""
extra_vars = set(other.variables) - set(self.variables)
# if all of these variables already exist there is nothing to do
if len(extra_vars) == 0:
@@ -69,33 +76,24 @@ class DiscreteFactor:
new_card_var = other.get_cardinality(extra_vars)
self.cardinality.extend([new_card_var[var] for var in extra_vars])
- return
def get_cardinality(self, variables):
return {var: self.cardinality[self.variables.index(var)] for var in variables}
def product(self, other):
- left = copy.deepcopy(self)
+ left = self.copy()
if isinstance(other, (int, float)):
- # TODO: handle case of multiplication by constant
- pass
+ return self.values * other
else:
- # assert right is a class or subclass of DiscreteFactor
- # that has attributes: variables, values; method: get_cardinality
- right = copy.deepcopy(other)
+ assert isinstance(other, DiscreteFactor), \
+ "__mul__ is only defined between subclasses of DiscreteFactor"
+ right = other.copy()
left.add_new_variables_from_other_factor(right)
right.add_new_variables_from_other_factor(left)
- print('var', left.variables)
- print(left.cardinality)
- print(left.values)
- print('var', right.variables)
- print(right.cardinality)
- print(right.values)
# reorder variables in right factor to match order in left
source_axes = list(range(right.values.ndim))
- print('source_axes', source_axes)
destination_axes = [right.variables.index(var) for var in left.variables]
right.variables = [right.variables[idx] for idx in destination_axes]
@@ -110,7 +108,7 @@ class DiscreteFactor:
vars: list,
variables over which to marginalize the factor
Returns
- DiscreteFactor
+ DiscreteFactor, whose scope is set(self.variables) - set(vars)
"""
phi = copy.deepcopy(self)