diff options
Diffstat (limited to 'beliefs/utils')
-rw-r--r-- | beliefs/utils/math_helper.py | 14 | ||||
-rw-r--r-- | beliefs/utils/random_variables.py | 17 |
2 files changed, 21 insertions, 10 deletions
diff --git a/beliefs/utils/math_helper.py b/beliefs/utils/math_helper.py index a25ea68..12325e1 100644 --- a/beliefs/utils/math_helper.py +++ b/beliefs/utils/math_helper.py @@ -1,10 +1,16 @@ -"""Random math utils.""" +"""Math utils""" def is_kronecker_delta(vector): - """Returns True if vector is a kronecker delta vector, False otherwise. - Specific evidence ('YES' or 'NO') is a kronecker delta vector, whereas - virtual evidence ('MAYBE') is not. + """ + Check if vector is a kronecker delta. + + Args: + vector: iterable of numbers + Returns: + bool, True if vector is a kronecker delta vector, False otherwise. + In belief propagation, specific evidence (variable is directly observed) + is a kronecker delta vector, but virtual evidence is not. """ count = 0 for x in vector: diff --git a/beliefs/utils/random_variables.py b/beliefs/utils/random_variables.py index 1a0b0f7..cad07aa 100644 --- a/beliefs/utils/random_variables.py +++ b/beliefs/utils/random_variables.py @@ -1,3 +1,4 @@ +"""Utilities for working with models and random variables.""" def get_reachable_observed_variables_for_inferred_variables(model, observed=set()): @@ -6,12 +7,16 @@ def get_reachable_observed_variables_for_inferred_variables(model, observed=set( ("reachable observed variables") that influenced the beliefs of variables inferred to be in a definite state. - INPUT - model: instance of BayesianModel class or subclass - observed: set of labels (strings) corresponding to vars pinned to definite - state during inference. - RETURNS - dict, of form key - source label (a string), value - a list of strings + Args + model: instance of BayesianModel class or subclass + observed: set, + set of labels (strings) corresponding to variables pinned to a definite + state during inference. + Returns + dict, + key, value pairs {source_label_id: reachable_observed_vars}, where + source_label_id is an int or string, and reachable_observed_vars is a list + of label_ids """ if not observed: return {} |