aboutsummaryrefslogtreecommitdiff
path: root/beliefs/utils/random_variables.py
blob: cad07aa1e022d3dbb6b3c8f568d9f38aff06c1c8 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
"""Utilities for working with models and random variables."""


def get_reachable_observed_variables_for_inferred_variables(model, observed=None):
    """
    After performing inference on a BayesianModel, get the labels of observed variables
    ("reachable observed variables") that influenced the beliefs of variables inferred
    to be in a definite state.

    Args
        model: instance of BayesianModel class or subclass. It must provide
            ``get_unobserved_variables_in_definite_state(observed)`` and
            ``reachable_observed_variables(var, observed)``.
        observed: set or None,
            set of labels (strings) corresponding to variables pinned to a definite
            state during inference. ``None`` (the default) or an empty set means
            nothing was observed.
    Returns
        dict,
            key, value pairs {source_label_id: reachable_observed_vars}, where
            source_label_id is each variable returned by
            ``model.get_unobserved_variables_in_definite_state(observed)`` and
            reachable_observed_vars is what
            ``model.reachable_observed_variables(var, observed)`` returns for it
            (presumably a collection of label_ids -- confirm against the model).
            An empty dict is returned, without consulting the model, when nothing
            was observed.
    """
    # ``None`` instead of a ``set()`` default: a mutable default is created once
    # and shared across every call, so any mutation would leak between calls.
    if not observed:
        return {}

    source_vars = model.get_unobserved_variables_in_definite_state(observed)

    return {var: model.reachable_observed_variables(var, observed) for var in source_vars}