diff options
-rw-r--r-- | beliefs/factors/bernoulli_or_cpd.py (renamed from beliefs/factors/BernoulliOrCPD.py) | 2 | ||||
-rw-r--r-- | beliefs/factors/cpd.py (renamed from beliefs/factors/CPD.py) | 0 | ||||
-rw-r--r-- | beliefs/inference/belief_propagation.py | 44 | ||||
-rw-r--r-- | beliefs/models/DirectedGraph.py | 36 | ||||
-rw-r--r-- | beliefs/models/base_models.py (renamed from beliefs/models/BayesianModel.py) | 43 | ||||
-rw-r--r-- | beliefs/models/belief_update_node_model.py (renamed from beliefs/models/beliefupdate/Node.py) | 158 | ||||
-rw-r--r-- | beliefs/models/beliefupdate/BeliefUpdateNodeModel.py | 91 | ||||
-rw-r--r-- | beliefs/models/beliefupdate/BernoulliOrNode.py | 47 | ||||
-rw-r--r-- | beliefs/utils/edges_helper.py | 136 | ||||
-rw-r--r-- | tests/test_belief_propagation.py | 12 |
10 files changed, 219 insertions, 350 deletions
diff --git a/beliefs/factors/BernoulliOrCPD.py b/beliefs/factors/bernoulli_or_cpd.py index 2c6a31e..bfb3a95 100644 --- a/beliefs/factors/BernoulliOrCPD.py +++ b/beliefs/factors/bernoulli_or_cpd.py @@ -1,6 +1,6 @@ import numpy as np -from beliefs.factors.CPD import TabularCPD +from beliefs.factors.cpd import TabularCPD class BernoulliOrCPD(TabularCPD): diff --git a/beliefs/factors/CPD.py b/beliefs/factors/cpd.py index a286aaa..a286aaa 100644 --- a/beliefs/factors/CPD.py +++ b/beliefs/factors/cpd.py diff --git a/beliefs/inference/belief_propagation.py b/beliefs/inference/belief_propagation.py index 02f5595..7ec648d 100644 --- a/beliefs/inference/belief_propagation.py +++ b/beliefs/inference/belief_propagation.py @@ -1,11 +1,17 @@ import numpy as np from collections import namedtuple +import logging -from beliefs.models.beliefupdate.Node import InvalidLambdaMsgToParent -from beliefs.models.beliefupdate.BeliefUpdateNodeModel import BeliefUpdateNodeModel +from beliefs.models.belief_update_node_model import ( + InvalidLambdaMsgToParent, + BeliefUpdateNodeModel +) from beliefs.utils.math_helper import is_kronecker_delta +logger = logging.getLogger(__name__) + + MsgPassers = namedtuple('MsgPassers', ['msg_receiver', 'msg_sender']) @@ -51,7 +57,7 @@ class BeliefPropagation: return node_to_update_label_id, msg_sender_label_id = nodes_to_update.pop() - print("Node", node_to_update_label_id) + logging.info("Node: %s", node_to_update_label_id) node = self.model.nodes_dict[node_to_update_label_id] @@ -59,8 +65,8 @@ class BeliefPropagation: # outgoing msg from the node to update parent_ids = set(node.parents) - set([msg_sender_label_id]) child_ids = set(node.children) - set([msg_sender_label_id]) - print("parent_ids:", parent_ids) - print("child_ids:", child_ids) + logging.info("parent_ids: %s", str(parent_ids)) + logging.info("child_ids: %s", str(child_ids)) if msg_sender_label_id is not None: # update triggered by receiving a message, not pinning to evidence @@ -68,9 +74,9 @@ class BeliefPropagation: if node_to_update_label_id not in evidence: node.compute_pi_agg() - print("belief propagation pi_agg", node.pi_agg) + logging.info("belief propagation pi_agg: %s", np.array2string(node.pi_agg)) node.compute_lambda_agg() - print("belief propagation lambda_agg", node.lambda_agg) + logging.info("belief propagation lambda_agg: %s", np.array2string(node.lambda_agg)) for parent_id in parent_ids: try: @@ -114,13 +120,13 @@ class BeliefPropagation: for child in node.lambda_received_msgs.keys(): node.update_lambda_msg_from_child(child=child, new_value=ones_vector) - print("Finished initializing Lambda(x) and lambda_received_msgs per node.") + logging.info("Finished initializing Lambda(x) and lambda_received_msgs per node.") - print("Start downward sweep from nodes. Sending Pi messages only.") + logging.info("Start downward sweep from nodes. Sending Pi messages only.") topdown_order = self.model.get_topologically_sorted_nodes(reverse=False) for node_id in topdown_order: - print('label in iteration through top-down order:', node_id) + logging.info('label in iteration through top-down order: %s', str(node_id)) node_sending_msg = self.model.nodes_dict[node_id] child_ids = node_sending_msg.children @@ -129,9 +135,9 @@ class BeliefPropagation: node_sending_msg.compute_pi_agg() for child_id in child_ids: - print("child", child_id) + logging.info("child: %s", str(child_id)) new_pi_msg = node_sending_msg.compute_pi_msg_to_child(child_k=child_id) - print(new_pi_msg) + logging.info("new_pi_msg: %s", np.array2string(new_pi_msg)) child_node = self.model.nodes_dict[child_id] child_node.update_pi_msg_from_parent(parent=node_id, @@ -158,10 +164,9 @@ class BeliefPropagation: self.model.nodes_dict[evidence_id].lambda_agg = \ self.model.nodes_dict[evidence_id].lambda_agg * observed_value - nodes_to_update.add(MsgPassers(msg_receiver=evidence_id, - msg_sender=None)) + nodes_to_update = [MsgPassers(msg_receiver=evidence_id, msg_sender=None)] - self._belief_propagation(nodes_to_update=nodes_to_update, + self._belief_propagation(nodes_to_update=set(nodes_to_update), evidence=evidence) def query(self, evidence={}): @@ -179,12 +184,13 @@ class BeliefPropagation: Example ------- - >> from label_graph_service.pgm.inference.belief_propagation import BeliefPropagation - >> from label_graph_service.pgm.models.BernoulliOrModel import BernoulliOrModel + >> import numpy as np + >> from beliefs.inference.belief_propagation import BeliefPropagation + >> from beliefs.models.belief_update_node_model import BeliefUpdateNodeModel, BernoulliOrNode >> edges = [('1', '3'), ('2', '3'), ('3', '5')] - >> model = BernoulliOrModel(edges) + >> model = BeliefUpdateNodeModel.init_from_edges(edges, BernoulliOrNode) >> infer = BeliefPropagation(model) - >> result = infer.query({'2': np.array([0, 1])}) + >> result = infer.query(evidence={'2': np.array([0, 1])}) """ if not self.model.all_nodes_are_fully_initialized: self.initialize_model() diff --git a/beliefs/models/DirectedGraph.py b/beliefs/models/DirectedGraph.py deleted file mode 100644 index 84b3a02..0000000 --- a/beliefs/models/DirectedGraph.py +++ /dev/null @@ -1,36 +0,0 @@ -import networkx as nx - - -class DirectedGraph(nx.DiGraph): - """ - Base class for all directed graphical models. - """ - def __init__(self, edges=None, node_labels=None): - """ - Input: - edges: an edge list, e.g. [(parent1, child1), (parent1, child2)] - node_labels: a list of strings of node labels - """ - super().__init__() - if edges is not None: - self.add_edges_from(edges) - if node_labels is not None: - self.add_nodes_from(node_labels) - - def get_leaves(self): - """ - Returns a list of leaves of the graph. - """ - return [node for node, out_degree in self.out_degree() if out_degree == 0] - - def get_roots(self): - """ - Returns a list of roots of the graph. - """ - return [node for node, in_degree in self.in_degree() if in_degree == 0] - - def get_topologically_sorted_nodes(self, reverse=False): - if reverse: - return list(reversed(list(nx.topological_sort(self)))) - else: - return nx.topological_sort(self) diff --git a/beliefs/models/BayesianModel.py b/beliefs/models/base_models.py index b57f968..cb91566 100644 --- a/beliefs/models/BayesianModel.py +++ b/beliefs/models/base_models.py @@ -1,10 +1,43 @@ -import copy import networkx as nx -from beliefs.models.DirectedGraph import DirectedGraph from beliefs.utils.math_helper import is_kronecker_delta +class DirectedGraph(nx.DiGraph): + """ + Base class for all directed graphical models. + """ + def __init__(self, edges=None, node_labels=None): + """ + Input: + edges: an edge list, e.g. [(parent1, child1), (parent1, child2)] + node_labels: a list of strings of node labels + """ + super().__init__() + if edges is not None: + self.add_edges_from(edges) + if node_labels is not None: + self.add_nodes_from(node_labels) + + def get_leaves(self): + """ + Returns a list of leaves of the graph. + """ + return [node for node, out_degree in self.out_degree() if out_degree == 0] + + def get_roots(self): + """ + Returns a list of roots of the graph. + """ + return [node for node, in_degree in self.in_degree() if in_degree == 0] + + def get_topologically_sorted_nodes(self, reverse=False): + if reverse: + return list(reversed(list(nx.topological_sort(self)))) + else: + return nx.topological_sort(self) + + class BayesianModel(DirectedGraph): """ Bayesian model stores nodes and edges described by conditional probability @@ -69,8 +102,8 @@ class BayesianModel(DirectedGraph): return vars_in_definite_state - observed def _get_ancestors_of(self, observed): - """Return list of ancestors of observed labels, including the observed labels themselves.""" - ancestors = observed.copy() + """Return list of ancestors of observed labels""" + ancestors = set() for label in observed: ancestors.update(nx.ancestors(self, label)) return ancestors @@ -87,7 +120,9 @@ class BayesianModel(DirectedGraph): reachable_observed_vars: set of strings, observed labels (variables with direct evidence) that are reachable from the source label. """ + # ancestors of observed labels, including observed labels ancestors_of_observed = self._get_ancestors_of(observed) + ancestors_of_observed.update(observed) visit_list = set() visit_list.add((source, 'up')) diff --git a/beliefs/models/beliefupdate/Node.py b/beliefs/models/belief_update_node_model.py index daa2f14..667e0f1 100644 --- a/beliefs/models/beliefupdate/Node.py +++ b/beliefs/models/belief_update_node_model.py @@ -1,6 +1,13 @@ +import copy +from enum import Enum import numpy as np +import itertools from functools import reduce -from enum import Enum + +import networkx as nx + +from beliefs.models.base_models import BayesianModel +from beliefs.factors.bernoulli_or_cpd import BernoulliOrCPD class InvalidLambdaMsgToParent(Exception): @@ -13,6 +20,98 @@ class MessageType(Enum): PI = 'pi' +class BeliefUpdateNodeModel(BayesianModel): + """ + A Bayesian model storing nodes (e.g. Node or BernoulliOrNode) implementing properties + and methods for Pearl's belief update algorithm. + + ref: "Fusion, Propagation, and Structuring in Belief Networks" + Artificial Intelligence 29 (1986) 241-288 + + """ + def __init__(self, nodes_dict): + """ + Input: + nodes_dict: dict + a dict key, value pair as {label_id: instance_of_node_class_or_subclass} + """ + super().__init__(edges=self._get_edges_from_nodes(nodes_dict.values()), + variables=list(nodes_dict.keys()), + cpds=[node.cpd for node in nodes_dict.values()]) + + self.nodes_dict = nodes_dict + + @classmethod + def init_from_edges(cls, edges, node_class): + """Create nodes from the same node class. + + Input: + edges: list of edge tuples of form ('parent', 'child') + node_class: the Node class or subclass from which to + create all the nodes from edges. + """ + nodes = set() + g = nx.DiGraph(edges) + + for label in set(itertools.chain(*edges)): + node = node_class(label_id=label, + children=list(g.successors(label)), + parents=list(g.predecessors(label))) + nodes.add(node) + nodes_dict = {node.label_id: node for node in nodes} + return cls(nodes_dict) + + @staticmethod + def _get_edges_from_nodes(nodes): + """ + Return list of all directed edges in nodes. + + Args: + nodes: an iterable of objects of the Node class or subclass + Returns: + edges: list of edge tuples + """ + edges = set() + for node in nodes: + if node.parents: + edge_tuples = zip(node.parents, [node.label_id]*len(node.parents)) + edges.update(edge_tuples) + return list(edges) + + def set_boundary_conditions(self): + """ + 1. Root nodes: if x is a node with no parents, set Pi(x) = prior + probability of x. + + 2. Leaf nodes: if x is a node with no children, set Lambda(x) + to an (unnormalized) unit vector, of length the cardinality of x. + """ + for root in self.get_roots(): + self.nodes_dict[root].pi_agg = self.nodes_dict[root].cpd.values + + for leaf in self.get_leaves(): + self.nodes_dict[leaf].lambda_agg = np.ones([self.nodes_dict[leaf].cardinality]) + + @property + def all_nodes_are_fully_initialized(self): + """ + Returns True if, for all nodes in the model, all lambda and pi + messages and lambda_agg and pi_agg are not None, else False. + """ + for node in self.nodes_dict.values(): + if not node.is_fully_initialized: + return False + return True + + def copy(self): + """ + Returns a copy of the model. + """ + copy_nodes = copy.deepcopy(self.nodes_dict) + copy_model = self.__class__(nodes_dict=copy_nodes) + return copy_model + + class Node: """A node in a DAG with methods to compute the belief (marginal probability of the node given evidence) and compute pi/lambda messages to/from its neighbors @@ -102,8 +201,8 @@ class Node: if any(msg is None for msg in msg_values): raise ValueError( - "Missing value for {msg_type} msg from child: can't compute {msg_type}_agg.". - format(msg_type=message_type.value) + "Missing value for {msg_type} msg from child: can't compute {msg_type}_agg." + .format(msg_type=message_type.value) ) else: return msg_values @@ -122,16 +221,16 @@ class Node: def _update_received_msg_by_key(self, received_msg_dict, key, new_value): if key not in received_msg_dict.keys(): - raise ValueError("Label id '{}' to update message isn't in allowed set of keys: {}". - format(key, received_msg_dict.keys())) + raise ValueError("Label id '{}' to update message isn't in allowed set of keys: {}" + .format(key, received_msg_dict.keys())) if not isinstance(new_value, np.ndarray): - raise TypeError("Expected a new value of type numpy.ndarray, but got type {}". - format(type(new_value))) + raise TypeError("Expected a new value of type numpy.ndarray, but got type {}" + .format(type(new_value))) if new_value.shape != (self.cardinality,): - raise ValueError("Expected new value to be of dimensions ({},) but got {} instead". - format(self.cardinality, new_value.shape)) + raise ValueError("Expected new value to be of dimensions ({},) but got {} instead" + .format(self.cardinality, new_value.shape)) received_msg_dict[key] = new_value def update_pi_msg_from_parent(self, parent, new_value): @@ -152,8 +251,7 @@ class Node: return self._normalize( np.nan_to_num(np.divide(self.belief, lambda_msg_from_child))) else: - raise ValueError("Can't compute pi message to child_{} without having received" \ - "a lambda message from that child.") + raise ValueError("Can't compute pi message to child_{} without having received a lambda message from that child.") def compute_lambda_msg_to_parent(self, parent_k): # TODO: implement explict factor product operation @@ -177,3 +275,41 @@ class Node: return False return True + + +class BernoulliOrNode(Node): + def __init__(self, + label_id, + children, + parents): + super().__init__(label_id=label_id, + children=children, + parents=parents, + cardinality=2, + cpd=BernoulliOrCPD(label_id, parents)) + + def compute_pi_agg(self): + if not self.parents: + self.pi_agg = self.cpd.values + else: + pi_msg_values = self.validate_and_return_msgs_received_for_msg_type(MessageType.PI) + parents_p0 = [p[0] for p in pi_msg_values] + p_0 = reduce(lambda x, y: x*y, parents_p0) + p_1 = 1 - p_0 + self.pi_agg = np.array([p_0, p_1]) + return self.pi_agg + + def compute_lambda_msg_to_parent(self, parent_k): + if np.array_equal(self.lambda_agg, np.ones([self.cardinality])): + return np.ones([self.cardinality]) + else: + # TODO: cleanup this validation + _ = self.validate_and_return_msgs_received_for_msg_type(MessageType.PI) + p0_excluding_k = [msg[0] for par_id, msg in self.pi_received_msgs.items() if par_id != parent_k] + p0_product = reduce(lambda x, y: x*y, p0_excluding_k, 1) + lambda_0 = self.lambda_agg[1] + (self.lambda_agg[0] - self.lambda_agg[1])*p0_product + lambda_1 = self.lambda_agg[1] + lambda_msg = np.array([lambda_0, lambda_1]) + if not any(lambda_msg): + raise InvalidLambdaMsgToParent + return self._normalize(lambda_msg) diff --git a/beliefs/models/beliefupdate/BeliefUpdateNodeModel.py b/beliefs/models/beliefupdate/BeliefUpdateNodeModel.py deleted file mode 100644 index d74eaa7..0000000 --- a/beliefs/models/beliefupdate/BeliefUpdateNodeModel.py +++ /dev/null @@ -1,91 +0,0 @@ -import copy -import numpy as np - -from beliefs.models.BayesianModel import BayesianModel -from beliefs.utils.edges_helper import EdgesHelper - - -class BeliefUpdateNodeModel(BayesianModel): - """ - A Bayesian model storing nodes (e.g. Node or BernoulliOrNode) implementing properties - and methods for Pearl's belief update algorithm. - - ref: "Fusion, Propagation, and Structuring in Belief Networks" - Artificial Intelligence 29 (1986) 241-288 - - """ - def __init__(self, nodes_dict): - """ - Input: - nodes_dict: dict - a dict key, value pair as {label_id: instance_of_node_class_or_subclass} - """ - super().__init__(edges=self._get_edges_from_nodes(nodes_dict.values()), - variables=list(nodes_dict.keys()), - cpds=[node.cpd for node in nodes_dict.values()]) - - self.nodes_dict = nodes_dict - - @classmethod - def from_edges(cls, edges, node_class): - """Create nodes from the same node class. - - Input: - edges: list of edge tuples of form ('parent', 'child') - node_class: the Node class or subclass from which to - create all the nodes from edges. - """ - edges_helper = EdgesHelper(edges) - nodes = edges_helper.create_nodes_from_edges(node_class) - nodes_dict = {node.label_id: node for node in nodes} - return cls(nodes_dict) - - @staticmethod - def _get_edges_from_nodes(nodes): - """ - Return list of all directed edges in nodes. - - Args: - nodes: an iterable of objects of the Node class or subclass - Returns: - edges: list of edge tuples - """ - edges = set() - for node in nodes: - if node.parents: - edge_tuples = zip(node.parents, [node.label_id]*len(node.parents)) - edges.update(edge_tuples) - return list(edges) - - def set_boundary_conditions(self): - """ - 1. Root nodes: if x is a node with no parents, set Pi(x) = prior - probability of x. - - 2. Leaf nodes: if x is a node with no children, set Lambda(x) - to an (unnormalized) unit vector, of length the cardinality of x. - """ - for root in self.get_roots(): - self.nodes_dict[root].pi_agg = self.nodes_dict[root].cpd.values - - for leaf in self.get_leaves(): - self.nodes_dict[leaf].lambda_agg = np.ones([self.nodes_dict[leaf].cardinality]) - - @property - def all_nodes_are_fully_initialized(self): - """ - Returns True if, for all nodes in the model, all lambda and pi - messages and lambda_agg and pi_agg are not None, else False. - """ - for node in self.nodes_dict.values(): - if not node.is_fully_initialized: - return False - return True - - def copy(self): - """ - Returns a copy of the model. - """ - copy_nodes = copy.deepcopy(self.nodes_dict) - copy_model = self.__class__(nodes_dict=copy_nodes) - return copy_model diff --git a/beliefs/models/beliefupdate/BernoulliOrNode.py b/beliefs/models/beliefupdate/BernoulliOrNode.py deleted file mode 100644 index 3386275..0000000 --- a/beliefs/models/beliefupdate/BernoulliOrNode.py +++ /dev/null @@ -1,47 +0,0 @@ -import numpy as np -from functools import reduce - -from beliefs.models.beliefupdate.Node import ( - Node, - MessageType, - InvalidLambdaMsgToParent -) -from beliefs.factors.BernoulliOrCPD import BernoulliOrCPD - - -class BernoulliOrNode(Node): - def __init__(self, - label_id, - children, - parents): - super().__init__(label_id=label_id, - children=children, - parents=parents, - cardinality=2, - cpd=BernoulliOrCPD(label_id, parents)) - - def compute_pi_agg(self): - if not self.parents: - self.pi_agg = self.cpd.values - else: - pi_msg_values = self.validate_and_return_msgs_received_for_msg_type(MessageType.PI) - parents_p0 = [p[0] for p in pi_msg_values] - p_0 = reduce(lambda x, y: x*y, parents_p0) - p_1 = 1 - p_0 - self.pi_agg = np.array([p_0, p_1]) - return self.pi_agg - - def compute_lambda_msg_to_parent(self, parent_k): - if np.array_equal(self.lambda_agg, np.ones([self.cardinality])): - return np.ones([self.cardinality]) - else: - # TODO: cleanup this validation - _ = self.validate_and_return_msgs_received_for_msg_type(MessageType.PI) - p0_excluding_k = [msg[0] for par_id, msg in self.pi_received_msgs.items() if par_id != parent_k] - p0_product = reduce(lambda x, y: x*y, p0_excluding_k, 1) - lambda_0 = self.lambda_agg[1] + (self.lambda_agg[0] - self.lambda_agg[1])*p0_product - lambda_1 = self.lambda_agg[1] - lambda_msg = np.array([lambda_0, lambda_1]) - if not any(lambda_msg): - raise InvalidLambdaMsgToParent - return self._normalize(lambda_msg) diff --git a/beliefs/utils/edges_helper.py b/beliefs/utils/edges_helper.py deleted file mode 100644 index 130686c..0000000 --- a/beliefs/utils/edges_helper.py +++ /dev/null @@ -1,136 +0,0 @@ -from collections import defaultdict - -from beliefs.models.beliefupdate.Node import Node -from beliefs.factors.BernoulliOrCPD import BernoulliOrCPD - - -class EdgesHelper: - """Class with convenience methods for working with edges.""" - def __init__(self, edges): - self.edges = edges - - def get_label_to_children_dict(self): - """returns dictionary keyed on label, with value a set of children""" - label_to_children_dict = defaultdict(set) - for parent, child in self.edges: - label_to_children_dict[parent].add(child) - return label_to_children_dict - - def get_label_to_parents_dict(self): - """returns dictionary keyed on label, with value a set of parents - Only used to help create dummy factors from edges (not for algo). - """ - label_to_parents_dict = defaultdict(set) - - for parent, child in self.edges: - label_to_parents_dict[child].add(parent) - return label_to_parents_dict - - def get_labels_from_edges(self): - """Return the set of labels that comprise the vertices of a list of edge tuples.""" - all_labels = set() - for parent, child in self.edges: - all_labels.update({parent, child}) - return all_labels - - def create_cpds_from_edges(self, CPD=BernoulliOrCPD): - """ - Create factors from list of edges. - - Input: - cpd: a factor class, assumed initialization takes in a label_id, the label_id of - the child (should = label_id of the node), and set of label_ids of parents. - - Returns: - factors: a set of (unique) factors of the graph - """ - labels = self.get_labels_from_edges() - label_to_parents = self.get_label_to_parents_dict() - - factors = set() - - for label in labels: - parents = label_to_parents[label] - cpd = CPD(label, parents) - factors.add(cpd) - return factors - - def get_label_to_factor_dict(self, CPD=BernoulliOrCPD): - """Create a dictionary mapping each label_id to the CPD/factor where - that label_id is a child. - - Returns: - label_to_factor: dict mapping each label to the cpd that - has that label as a child. - """ - factors = self.create_cpds_from_edges(CPD=CPD) - - label_to_factor = dict() - for factor in factors: - label_to_factor[factor.child] = factor - return label_to_factor - - def get_label_to_node_dict(self, CPD=BernoulliOrCPD): - """Create a dictionary mapping each label_id to a Node instance. - - Returns: - label_to_node: dict mapping each label to the node that has that - label as a label_id. - """ - nodes = self.create_nodes_from_edges() - - label_to_node = dict() - for node in nodes: - label_to_node[node.label_id] = node - return label_to_node - - def get_label_to_node_dict_for_manual_cpds(self, cpds_list): - """Create a dictionary mapping each label_id to a node that is - instantiated with a manually defined pgmpy factor instance. - - Input: - cpds_list - list of instances of pgmpy factors, e.g. TabularCPD - - Returns: - label_to_node: dict mapping each label to the node that has that - label as a label_id. - """ - label_to_children = self.get_label_to_children_dict() - label_to_parents = self.get_label_to_parents_dict() - - label_to_node = dict() - for cpd in cpds_list: - label_id = cpd.variable - - node = Node(label_id=label_id, - children=label_to_children[label_id], - parents=label_to_parents[label_id], - cardinality=2, - cpd=cpd) - label_to_node[label_id] = node - - return label_to_node - - def create_nodes_from_edges(self, node_class): - """ - Create instances of the node_class. Assumes the node class is - initialized by label_id, children, and parents. - - Returns: - nodes: a set of (unique) nodes of the graph - """ - labels = self.get_labels_from_edges() - labels_to_parents = self.get_label_to_parents_dict() - labels_to_children = self.get_label_to_children_dict() - - nodes = set() - - for label in labels: - parents = list(labels_to_parents[label]) - children = list(labels_to_children[label]) - - node = node_class(label_id=label, - children=children, - parents=parents) - nodes.add(node) - return nodes diff --git a/tests/test_belief_propagation.py b/tests/test_belief_propagation.py index 264ddae..5c5a612 100644 --- a/tests/test_belief_propagation.py +++ b/tests/test_belief_propagation.py @@ -3,8 +3,10 @@ import pytest from pytest import approx from beliefs.inference.belief_propagation import BeliefPropagation, ConflictingEvidenceError -from beliefs.models.beliefupdate.BeliefUpdateNodeModel import BeliefUpdateNodeModel -from beliefs.models.beliefupdate.BernoulliOrNode import BernoulliOrNode +from beliefs.models.belief_update_node_model import ( + BeliefUpdateNodeModel, + BernoulliOrNode +) @pytest.fixture(scope='module') @@ -37,17 +39,17 @@ def many_parents_edges(): @pytest.fixture(scope='function') def four_node_model(edges_four_nodes): - return BeliefUpdateNodeModel.from_edges(edges_four_nodes, BernoulliOrNode) + return BeliefUpdateNodeModel.init_from_edges(edges_four_nodes, BernoulliOrNode) @pytest.fixture(scope='function') def simple_model(simple_edges): - return BeliefUpdateNodeModel.from_edges(simple_edges, BernoulliOrNode) + return BeliefUpdateNodeModel.init_from_edges(simple_edges, BernoulliOrNode) @pytest.fixture(scope='function') def many_parents_model(many_parents_edges): - return BeliefUpdateNodeModel.from_edges(many_parents_edges, BernoulliOrNode) + return BeliefUpdateNodeModel.init_from_edges(many_parents_edges, BernoulliOrNode) @pytest.fixture(scope='function') |