aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCathy Yeh <cathy@driver.xyz>2017-12-03 19:16:32 -0800
committerCathy Yeh <cathy@driver.xyz>2017-12-03 20:35:30 -0800
commit8dc7ae89677fca16ee974a30cff8c4df53c955ce (patch)
tree6b021dcd7902a2952cc97872f6200469b7dab51b
parente5937060658f7e8ac484e5489f2b35b4ecb96d35 (diff)
downloadbeliefs-8dc7ae89677fca16ee974a30cff8c4df53c955ce.tar.gz
beliefs-8dc7ae89677fca16ee974a30cff8c4df53c955ce.tar.bz2
beliefs-8dc7ae89677fca16ee974a30cff8c4df53c955ce.zip
PR comments
-rw-r--r--beliefs/factors/bernoulli_or_cpd.py (renamed from beliefs/factors/BernoulliOrCPD.py)2
-rw-r--r--beliefs/factors/cpd.py (renamed from beliefs/factors/CPD.py)0
-rw-r--r--beliefs/inference/belief_propagation.py44
-rw-r--r--beliefs/models/DirectedGraph.py36
-rw-r--r--beliefs/models/base_models.py (renamed from beliefs/models/BayesianModel.py)43
-rw-r--r--beliefs/models/belief_update_node_model.py (renamed from beliefs/models/beliefupdate/Node.py)158
-rw-r--r--beliefs/models/beliefupdate/BeliefUpdateNodeModel.py91
-rw-r--r--beliefs/models/beliefupdate/BernoulliOrNode.py47
-rw-r--r--beliefs/utils/edges_helper.py136
-rw-r--r--tests/test_belief_propagation.py12
10 files changed, 219 insertions, 350 deletions
diff --git a/beliefs/factors/BernoulliOrCPD.py b/beliefs/factors/bernoulli_or_cpd.py
index 2c6a31e..bfb3a95 100644
--- a/beliefs/factors/BernoulliOrCPD.py
+++ b/beliefs/factors/bernoulli_or_cpd.py
@@ -1,6 +1,6 @@
import numpy as np
-from beliefs.factors.CPD import TabularCPD
+from beliefs.factors.cpd import TabularCPD
class BernoulliOrCPD(TabularCPD):
diff --git a/beliefs/factors/CPD.py b/beliefs/factors/cpd.py
index a286aaa..a286aaa 100644
--- a/beliefs/factors/CPD.py
+++ b/beliefs/factors/cpd.py
diff --git a/beliefs/inference/belief_propagation.py b/beliefs/inference/belief_propagation.py
index 02f5595..7ec648d 100644
--- a/beliefs/inference/belief_propagation.py
+++ b/beliefs/inference/belief_propagation.py
@@ -1,11 +1,17 @@
import numpy as np
from collections import namedtuple
+import logging
-from beliefs.models.beliefupdate.Node import InvalidLambdaMsgToParent
-from beliefs.models.beliefupdate.BeliefUpdateNodeModel import BeliefUpdateNodeModel
+from beliefs.models.belief_update_node_model import (
+ InvalidLambdaMsgToParent,
+ BeliefUpdateNodeModel
+)
from beliefs.utils.math_helper import is_kronecker_delta
+logger = logging.getLogger(__name__)
+
+
MsgPassers = namedtuple('MsgPassers', ['msg_receiver', 'msg_sender'])
@@ -51,7 +57,7 @@ class BeliefPropagation:
return
node_to_update_label_id, msg_sender_label_id = nodes_to_update.pop()
- print("Node", node_to_update_label_id)
+ logging.info("Node: %s", node_to_update_label_id)
node = self.model.nodes_dict[node_to_update_label_id]
@@ -59,8 +65,8 @@ class BeliefPropagation:
# outgoing msg from the node to update
parent_ids = set(node.parents) - set([msg_sender_label_id])
child_ids = set(node.children) - set([msg_sender_label_id])
- print("parent_ids:", parent_ids)
- print("child_ids:", child_ids)
+ logging.info("parent_ids: %s", str(parent_ids))
+ logging.info("child_ids: %s", str(child_ids))
if msg_sender_label_id is not None:
# update triggered by receiving a message, not pinning to evidence
@@ -68,9 +74,9 @@ class BeliefPropagation:
if node_to_update_label_id not in evidence:
node.compute_pi_agg()
- print("belief propagation pi_agg", node.pi_agg)
+ logging.info("belief propagation pi_agg: %s", np.array2string(node.pi_agg))
node.compute_lambda_agg()
- print("belief propagation lambda_agg", node.lambda_agg)
+ logging.info("belief propagation lambda_agg: %s", np.array2string(node.lambda_agg))
for parent_id in parent_ids:
try:
@@ -114,13 +120,13 @@ class BeliefPropagation:
for child in node.lambda_received_msgs.keys():
node.update_lambda_msg_from_child(child=child,
new_value=ones_vector)
- print("Finished initializing Lambda(x) and lambda_received_msgs per node.")
+ logging.info("Finished initializing Lambda(x) and lambda_received_msgs per node.")
- print("Start downward sweep from nodes. Sending Pi messages only.")
+ logging.info("Start downward sweep from nodes. Sending Pi messages only.")
topdown_order = self.model.get_topologically_sorted_nodes(reverse=False)
for node_id in topdown_order:
- print('label in iteration through top-down order:', node_id)
+ logging.info('label in iteration through top-down order: %s', str(node_id))
node_sending_msg = self.model.nodes_dict[node_id]
child_ids = node_sending_msg.children
@@ -129,9 +135,9 @@ class BeliefPropagation:
node_sending_msg.compute_pi_agg()
for child_id in child_ids:
- print("child", child_id)
+ logging.info("child: %s", str(child_id))
new_pi_msg = node_sending_msg.compute_pi_msg_to_child(child_k=child_id)
- print(new_pi_msg)
+ logging.info("new_pi_msg: %s", np.array2string(new_pi_msg))
child_node = self.model.nodes_dict[child_id]
child_node.update_pi_msg_from_parent(parent=node_id,
@@ -158,10 +164,9 @@ class BeliefPropagation:
self.model.nodes_dict[evidence_id].lambda_agg = \
self.model.nodes_dict[evidence_id].lambda_agg * observed_value
- nodes_to_update.add(MsgPassers(msg_receiver=evidence_id,
- msg_sender=None))
+ nodes_to_update = [MsgPassers(msg_receiver=evidence_id, msg_sender=None)]
- self._belief_propagation(nodes_to_update=nodes_to_update,
+ self._belief_propagation(nodes_to_update=set(nodes_to_update),
evidence=evidence)
def query(self, evidence={}):
@@ -179,12 +184,13 @@ class BeliefPropagation:
Example
-------
- >> from label_graph_service.pgm.inference.belief_propagation import BeliefPropagation
- >> from label_graph_service.pgm.models.BernoulliOrModel import BernoulliOrModel
+ >> import numpy as np
+ >> from beliefs.inference.belief_propagation import BeliefPropagation
+ >> from beliefs.models.belief_update_node_model import BeliefUpdateNodeModel, BernoulliOrNode
>> edges = [('1', '3'), ('2', '3'), ('3', '5')]
- >> model = BernoulliOrModel(edges)
+ >> model = BeliefUpdateNodeModel.init_from_edges(edges, BernoulliOrNode)
>> infer = BeliefPropagation(model)
- >> result = infer.query({'2': np.array([0, 1])})
+ >> result = infer.query(evidence={'2': np.array([0, 1])})
"""
if not self.model.all_nodes_are_fully_initialized:
self.initialize_model()
diff --git a/beliefs/models/DirectedGraph.py b/beliefs/models/DirectedGraph.py
deleted file mode 100644
index 84b3a02..0000000
--- a/beliefs/models/DirectedGraph.py
+++ /dev/null
@@ -1,36 +0,0 @@
-import networkx as nx
-
-
-class DirectedGraph(nx.DiGraph):
- """
- Base class for all directed graphical models.
- """
- def __init__(self, edges=None, node_labels=None):
- """
- Input:
- edges: an edge list, e.g. [(parent1, child1), (parent1, child2)]
- node_labels: a list of strings of node labels
- """
- super().__init__()
- if edges is not None:
- self.add_edges_from(edges)
- if node_labels is not None:
- self.add_nodes_from(node_labels)
-
- def get_leaves(self):
- """
- Returns a list of leaves of the graph.
- """
- return [node for node, out_degree in self.out_degree() if out_degree == 0]
-
- def get_roots(self):
- """
- Returns a list of roots of the graph.
- """
- return [node for node, in_degree in self.in_degree() if in_degree == 0]
-
- def get_topologically_sorted_nodes(self, reverse=False):
- if reverse:
- return list(reversed(list(nx.topological_sort(self))))
- else:
- return nx.topological_sort(self)
diff --git a/beliefs/models/BayesianModel.py b/beliefs/models/base_models.py
index b57f968..cb91566 100644
--- a/beliefs/models/BayesianModel.py
+++ b/beliefs/models/base_models.py
@@ -1,10 +1,43 @@
-import copy
import networkx as nx
-from beliefs.models.DirectedGraph import DirectedGraph
from beliefs.utils.math_helper import is_kronecker_delta
+class DirectedGraph(nx.DiGraph):
+ """
+ Base class for all directed graphical models.
+ """
+ def __init__(self, edges=None, node_labels=None):
+ """
+ Input:
+ edges: an edge list, e.g. [(parent1, child1), (parent1, child2)]
+ node_labels: a list of strings of node labels
+ """
+ super().__init__()
+ if edges is not None:
+ self.add_edges_from(edges)
+ if node_labels is not None:
+ self.add_nodes_from(node_labels)
+
+ def get_leaves(self):
+ """
+ Returns a list of leaves of the graph.
+ """
+ return [node for node, out_degree in self.out_degree() if out_degree == 0]
+
+ def get_roots(self):
+ """
+ Returns a list of roots of the graph.
+ """
+ return [node for node, in_degree in self.in_degree() if in_degree == 0]
+
+ def get_topologically_sorted_nodes(self, reverse=False):
+ if reverse:
+ return list(reversed(list(nx.topological_sort(self))))
+ else:
+ return nx.topological_sort(self)
+
+
class BayesianModel(DirectedGraph):
"""
Bayesian model stores nodes and edges described by conditional probability
@@ -69,8 +102,8 @@ class BayesianModel(DirectedGraph):
return vars_in_definite_state - observed
def _get_ancestors_of(self, observed):
- """Return list of ancestors of observed labels, including the observed labels themselves."""
- ancestors = observed.copy()
+ """Return list of ancestors of observed labels"""
+ ancestors = set()
for label in observed:
ancestors.update(nx.ancestors(self, label))
return ancestors
@@ -87,7 +120,9 @@ class BayesianModel(DirectedGraph):
reachable_observed_vars: set of strings, observed labels (variables with direct
evidence) that are reachable from the source label.
"""
+ # ancestors of observed labels, including observed labels
ancestors_of_observed = self._get_ancestors_of(observed)
+ ancestors_of_observed.update(observed)
visit_list = set()
visit_list.add((source, 'up'))
diff --git a/beliefs/models/beliefupdate/Node.py b/beliefs/models/belief_update_node_model.py
index daa2f14..667e0f1 100644
--- a/beliefs/models/beliefupdate/Node.py
+++ b/beliefs/models/belief_update_node_model.py
@@ -1,6 +1,13 @@
+import copy
+from enum import Enum
import numpy as np
+import itertools
from functools import reduce
-from enum import Enum
+
+import networkx as nx
+
+from beliefs.models.base_models import BayesianModel
+from beliefs.factors.bernoulli_or_cpd import BernoulliOrCPD
class InvalidLambdaMsgToParent(Exception):
@@ -13,6 +20,98 @@ class MessageType(Enum):
PI = 'pi'
+class BeliefUpdateNodeModel(BayesianModel):
+ """
+ A Bayesian model storing nodes (e.g. Node or BernoulliOrNode) implementing properties
+ and methods for Pearl's belief update algorithm.
+
+ ref: "Fusion, Propagation, and Structuring in Belief Networks"
+ Artificial Intelligence 29 (1986) 241-288
+
+ """
+ def __init__(self, nodes_dict):
+ """
+ Input:
+ nodes_dict: dict
+ a dict key, value pair as {label_id: instance_of_node_class_or_subclass}
+ """
+ super().__init__(edges=self._get_edges_from_nodes(nodes_dict.values()),
+ variables=list(nodes_dict.keys()),
+ cpds=[node.cpd for node in nodes_dict.values()])
+
+ self.nodes_dict = nodes_dict
+
+ @classmethod
+ def init_from_edges(cls, edges, node_class):
+ """Create nodes from the same node class.
+
+ Input:
+ edges: list of edge tuples of form ('parent', 'child')
+ node_class: the Node class or subclass from which to
+ create all the nodes from edges.
+ """
+ nodes = set()
+ g = nx.DiGraph(edges)
+
+ for label in set(itertools.chain(*edges)):
+ node = node_class(label_id=label,
+ children=list(g.successors(label)),
+ parents=list(g.predecessors(label)))
+ nodes.add(node)
+ nodes_dict = {node.label_id: node for node in nodes}
+ return cls(nodes_dict)
+
+ @staticmethod
+ def _get_edges_from_nodes(nodes):
+ """
+ Return list of all directed edges in nodes.
+
+ Args:
+ nodes: an iterable of objects of the Node class or subclass
+ Returns:
+ edges: list of edge tuples
+ """
+ edges = set()
+ for node in nodes:
+ if node.parents:
+ edge_tuples = zip(node.parents, [node.label_id]*len(node.parents))
+ edges.update(edge_tuples)
+ return list(edges)
+
+ def set_boundary_conditions(self):
+ """
+ 1. Root nodes: if x is a node with no parents, set Pi(x) = prior
+ probability of x.
+
+ 2. Leaf nodes: if x is a node with no children, set Lambda(x)
+ to an (unnormalized) unit vector, of length the cardinality of x.
+ """
+ for root in self.get_roots():
+ self.nodes_dict[root].pi_agg = self.nodes_dict[root].cpd.values
+
+ for leaf in self.get_leaves():
+ self.nodes_dict[leaf].lambda_agg = np.ones([self.nodes_dict[leaf].cardinality])
+
+ @property
+ def all_nodes_are_fully_initialized(self):
+ """
+ Returns True if, for all nodes in the model, all lambda and pi
+ messages and lambda_agg and pi_agg are not None, else False.
+ """
+ for node in self.nodes_dict.values():
+ if not node.is_fully_initialized:
+ return False
+ return True
+
+ def copy(self):
+ """
+ Returns a copy of the model.
+ """
+ copy_nodes = copy.deepcopy(self.nodes_dict)
+ copy_model = self.__class__(nodes_dict=copy_nodes)
+ return copy_model
+
+
class Node:
"""A node in a DAG with methods to compute the belief (marginal probability
of the node given evidence) and compute pi/lambda messages to/from its neighbors
@@ -102,8 +201,8 @@ class Node:
if any(msg is None for msg in msg_values):
raise ValueError(
- "Missing value for {msg_type} msg from child: can't compute {msg_type}_agg.".
- format(msg_type=message_type.value)
+ "Missing value for {msg_type} msg from child: can't compute {msg_type}_agg."
+ .format(msg_type=message_type.value)
)
else:
return msg_values
@@ -122,16 +221,16 @@ class Node:
def _update_received_msg_by_key(self, received_msg_dict, key, new_value):
if key not in received_msg_dict.keys():
- raise ValueError("Label id '{}' to update message isn't in allowed set of keys: {}".
- format(key, received_msg_dict.keys()))
+ raise ValueError("Label id '{}' to update message isn't in allowed set of keys: {}"
+ .format(key, received_msg_dict.keys()))
if not isinstance(new_value, np.ndarray):
- raise TypeError("Expected a new value of type numpy.ndarray, but got type {}".
- format(type(new_value)))
+ raise TypeError("Expected a new value of type numpy.ndarray, but got type {}"
+ .format(type(new_value)))
if new_value.shape != (self.cardinality,):
- raise ValueError("Expected new value to be of dimensions ({},) but got {} instead".
- format(self.cardinality, new_value.shape))
+ raise ValueError("Expected new value to be of dimensions ({},) but got {} instead"
+ .format(self.cardinality, new_value.shape))
received_msg_dict[key] = new_value
def update_pi_msg_from_parent(self, parent, new_value):
@@ -152,8 +251,7 @@ class Node:
return self._normalize(
np.nan_to_num(np.divide(self.belief, lambda_msg_from_child)))
else:
- raise ValueError("Can't compute pi message to child_{} without having received" \
- "a lambda message from that child.")
+ raise ValueError("Can't compute pi message to child_{} without having received a lambda message from that child.")
def compute_lambda_msg_to_parent(self, parent_k):
# TODO: implement explict factor product operation
@@ -177,3 +275,41 @@ class Node:
return False
return True
+
+
+class BernoulliOrNode(Node):
+ def __init__(self,
+ label_id,
+ children,
+ parents):
+ super().__init__(label_id=label_id,
+ children=children,
+ parents=parents,
+ cardinality=2,
+ cpd=BernoulliOrCPD(label_id, parents))
+
+ def compute_pi_agg(self):
+ if not self.parents:
+ self.pi_agg = self.cpd.values
+ else:
+ pi_msg_values = self.validate_and_return_msgs_received_for_msg_type(MessageType.PI)
+ parents_p0 = [p[0] for p in pi_msg_values]
+ p_0 = reduce(lambda x, y: x*y, parents_p0)
+ p_1 = 1 - p_0
+ self.pi_agg = np.array([p_0, p_1])
+ return self.pi_agg
+
+ def compute_lambda_msg_to_parent(self, parent_k):
+ if np.array_equal(self.lambda_agg, np.ones([self.cardinality])):
+ return np.ones([self.cardinality])
+ else:
+ # TODO: cleanup this validation
+ _ = self.validate_and_return_msgs_received_for_msg_type(MessageType.PI)
+ p0_excluding_k = [msg[0] for par_id, msg in self.pi_received_msgs.items() if par_id != parent_k]
+ p0_product = reduce(lambda x, y: x*y, p0_excluding_k, 1)
+ lambda_0 = self.lambda_agg[1] + (self.lambda_agg[0] - self.lambda_agg[1])*p0_product
+ lambda_1 = self.lambda_agg[1]
+ lambda_msg = np.array([lambda_0, lambda_1])
+ if not any(lambda_msg):
+ raise InvalidLambdaMsgToParent
+ return self._normalize(lambda_msg)
diff --git a/beliefs/models/beliefupdate/BeliefUpdateNodeModel.py b/beliefs/models/beliefupdate/BeliefUpdateNodeModel.py
deleted file mode 100644
index d74eaa7..0000000
--- a/beliefs/models/beliefupdate/BeliefUpdateNodeModel.py
+++ /dev/null
@@ -1,91 +0,0 @@
-import copy
-import numpy as np
-
-from beliefs.models.BayesianModel import BayesianModel
-from beliefs.utils.edges_helper import EdgesHelper
-
-
-class BeliefUpdateNodeModel(BayesianModel):
- """
- A Bayesian model storing nodes (e.g. Node or BernoulliOrNode) implementing properties
- and methods for Pearl's belief update algorithm.
-
- ref: "Fusion, Propagation, and Structuring in Belief Networks"
- Artificial Intelligence 29 (1986) 241-288
-
- """
- def __init__(self, nodes_dict):
- """
- Input:
- nodes_dict: dict
- a dict key, value pair as {label_id: instance_of_node_class_or_subclass}
- """
- super().__init__(edges=self._get_edges_from_nodes(nodes_dict.values()),
- variables=list(nodes_dict.keys()),
- cpds=[node.cpd for node in nodes_dict.values()])
-
- self.nodes_dict = nodes_dict
-
- @classmethod
- def from_edges(cls, edges, node_class):
- """Create nodes from the same node class.
-
- Input:
- edges: list of edge tuples of form ('parent', 'child')
- node_class: the Node class or subclass from which to
- create all the nodes from edges.
- """
- edges_helper = EdgesHelper(edges)
- nodes = edges_helper.create_nodes_from_edges(node_class)
- nodes_dict = {node.label_id: node for node in nodes}
- return cls(nodes_dict)
-
- @staticmethod
- def _get_edges_from_nodes(nodes):
- """
- Return list of all directed edges in nodes.
-
- Args:
- nodes: an iterable of objects of the Node class or subclass
- Returns:
- edges: list of edge tuples
- """
- edges = set()
- for node in nodes:
- if node.parents:
- edge_tuples = zip(node.parents, [node.label_id]*len(node.parents))
- edges.update(edge_tuples)
- return list(edges)
-
- def set_boundary_conditions(self):
- """
- 1. Root nodes: if x is a node with no parents, set Pi(x) = prior
- probability of x.
-
- 2. Leaf nodes: if x is a node with no children, set Lambda(x)
- to an (unnormalized) unit vector, of length the cardinality of x.
- """
- for root in self.get_roots():
- self.nodes_dict[root].pi_agg = self.nodes_dict[root].cpd.values
-
- for leaf in self.get_leaves():
- self.nodes_dict[leaf].lambda_agg = np.ones([self.nodes_dict[leaf].cardinality])
-
- @property
- def all_nodes_are_fully_initialized(self):
- """
- Returns True if, for all nodes in the model, all lambda and pi
- messages and lambda_agg and pi_agg are not None, else False.
- """
- for node in self.nodes_dict.values():
- if not node.is_fully_initialized:
- return False
- return True
-
- def copy(self):
- """
- Returns a copy of the model.
- """
- copy_nodes = copy.deepcopy(self.nodes_dict)
- copy_model = self.__class__(nodes_dict=copy_nodes)
- return copy_model
diff --git a/beliefs/models/beliefupdate/BernoulliOrNode.py b/beliefs/models/beliefupdate/BernoulliOrNode.py
deleted file mode 100644
index 3386275..0000000
--- a/beliefs/models/beliefupdate/BernoulliOrNode.py
+++ /dev/null
@@ -1,47 +0,0 @@
-import numpy as np
-from functools import reduce
-
-from beliefs.models.beliefupdate.Node import (
- Node,
- MessageType,
- InvalidLambdaMsgToParent
-)
-from beliefs.factors.BernoulliOrCPD import BernoulliOrCPD
-
-
-class BernoulliOrNode(Node):
- def __init__(self,
- label_id,
- children,
- parents):
- super().__init__(label_id=label_id,
- children=children,
- parents=parents,
- cardinality=2,
- cpd=BernoulliOrCPD(label_id, parents))
-
- def compute_pi_agg(self):
- if not self.parents:
- self.pi_agg = self.cpd.values
- else:
- pi_msg_values = self.validate_and_return_msgs_received_for_msg_type(MessageType.PI)
- parents_p0 = [p[0] for p in pi_msg_values]
- p_0 = reduce(lambda x, y: x*y, parents_p0)
- p_1 = 1 - p_0
- self.pi_agg = np.array([p_0, p_1])
- return self.pi_agg
-
- def compute_lambda_msg_to_parent(self, parent_k):
- if np.array_equal(self.lambda_agg, np.ones([self.cardinality])):
- return np.ones([self.cardinality])
- else:
- # TODO: cleanup this validation
- _ = self.validate_and_return_msgs_received_for_msg_type(MessageType.PI)
- p0_excluding_k = [msg[0] for par_id, msg in self.pi_received_msgs.items() if par_id != parent_k]
- p0_product = reduce(lambda x, y: x*y, p0_excluding_k, 1)
- lambda_0 = self.lambda_agg[1] + (self.lambda_agg[0] - self.lambda_agg[1])*p0_product
- lambda_1 = self.lambda_agg[1]
- lambda_msg = np.array([lambda_0, lambda_1])
- if not any(lambda_msg):
- raise InvalidLambdaMsgToParent
- return self._normalize(lambda_msg)
diff --git a/beliefs/utils/edges_helper.py b/beliefs/utils/edges_helper.py
deleted file mode 100644
index 130686c..0000000
--- a/beliefs/utils/edges_helper.py
+++ /dev/null
@@ -1,136 +0,0 @@
-from collections import defaultdict
-
-from beliefs.models.beliefupdate.Node import Node
-from beliefs.factors.BernoulliOrCPD import BernoulliOrCPD
-
-
-class EdgesHelper:
- """Class with convenience methods for working with edges."""
- def __init__(self, edges):
- self.edges = edges
-
- def get_label_to_children_dict(self):
- """returns dictionary keyed on label, with value a set of children"""
- label_to_children_dict = defaultdict(set)
- for parent, child in self.edges:
- label_to_children_dict[parent].add(child)
- return label_to_children_dict
-
- def get_label_to_parents_dict(self):
- """returns dictionary keyed on label, with value a set of parents
- Only used to help create dummy factors from edges (not for algo).
- """
- label_to_parents_dict = defaultdict(set)
-
- for parent, child in self.edges:
- label_to_parents_dict[child].add(parent)
- return label_to_parents_dict
-
- def get_labels_from_edges(self):
- """Return the set of labels that comprise the vertices of a list of edge tuples."""
- all_labels = set()
- for parent, child in self.edges:
- all_labels.update({parent, child})
- return all_labels
-
- def create_cpds_from_edges(self, CPD=BernoulliOrCPD):
- """
- Create factors from list of edges.
-
- Input:
- cpd: a factor class, assumed initialization takes in a label_id, the label_id of
- the child (should = label_id of the node), and set of label_ids of parents.
-
- Returns:
- factors: a set of (unique) factors of the graph
- """
- labels = self.get_labels_from_edges()
- label_to_parents = self.get_label_to_parents_dict()
-
- factors = set()
-
- for label in labels:
- parents = label_to_parents[label]
- cpd = CPD(label, parents)
- factors.add(cpd)
- return factors
-
- def get_label_to_factor_dict(self, CPD=BernoulliOrCPD):
- """Create a dictionary mapping each label_id to the CPD/factor where
- that label_id is a child.
-
- Returns:
- label_to_factor: dict mapping each label to the cpd that
- has that label as a child.
- """
- factors = self.create_cpds_from_edges(CPD=CPD)
-
- label_to_factor = dict()
- for factor in factors:
- label_to_factor[factor.child] = factor
- return label_to_factor
-
- def get_label_to_node_dict(self, CPD=BernoulliOrCPD):
- """Create a dictionary mapping each label_id to a Node instance.
-
- Returns:
- label_to_node: dict mapping each label to the node that has that
- label as a label_id.
- """
- nodes = self.create_nodes_from_edges()
-
- label_to_node = dict()
- for node in nodes:
- label_to_node[node.label_id] = node
- return label_to_node
-
- def get_label_to_node_dict_for_manual_cpds(self, cpds_list):
- """Create a dictionary mapping each label_id to a node that is
- instantiated with a manually defined pgmpy factor instance.
-
- Input:
- cpds_list - list of instances of pgmpy factors, e.g. TabularCPD
-
- Returns:
- label_to_node: dict mapping each label to the node that has that
- label as a label_id.
- """
- label_to_children = self.get_label_to_children_dict()
- label_to_parents = self.get_label_to_parents_dict()
-
- label_to_node = dict()
- for cpd in cpds_list:
- label_id = cpd.variable
-
- node = Node(label_id=label_id,
- children=label_to_children[label_id],
- parents=label_to_parents[label_id],
- cardinality=2,
- cpd=cpd)
- label_to_node[label_id] = node
-
- return label_to_node
-
- def create_nodes_from_edges(self, node_class):
- """
- Create instances of the node_class. Assumes the node class is
- initialized by label_id, children, and parents.
-
- Returns:
- nodes: a set of (unique) nodes of the graph
- """
- labels = self.get_labels_from_edges()
- labels_to_parents = self.get_label_to_parents_dict()
- labels_to_children = self.get_label_to_children_dict()
-
- nodes = set()
-
- for label in labels:
- parents = list(labels_to_parents[label])
- children = list(labels_to_children[label])
-
- node = node_class(label_id=label,
- children=children,
- parents=parents)
- nodes.add(node)
- return nodes
diff --git a/tests/test_belief_propagation.py b/tests/test_belief_propagation.py
index 264ddae..5c5a612 100644
--- a/tests/test_belief_propagation.py
+++ b/tests/test_belief_propagation.py
@@ -3,8 +3,10 @@ import pytest
from pytest import approx
from beliefs.inference.belief_propagation import BeliefPropagation, ConflictingEvidenceError
-from beliefs.models.beliefupdate.BeliefUpdateNodeModel import BeliefUpdateNodeModel
-from beliefs.models.beliefupdate.BernoulliOrNode import BernoulliOrNode
+from beliefs.models.belief_update_node_model import (
+ BeliefUpdateNodeModel,
+ BernoulliOrNode
+)
@pytest.fixture(scope='module')
@@ -37,17 +39,17 @@ def many_parents_edges():
@pytest.fixture(scope='function')
def four_node_model(edges_four_nodes):
- return BeliefUpdateNodeModel.from_edges(edges_four_nodes, BernoulliOrNode)
+ return BeliefUpdateNodeModel.init_from_edges(edges_four_nodes, BernoulliOrNode)
@pytest.fixture(scope='function')
def simple_model(simple_edges):
- return BeliefUpdateNodeModel.from_edges(simple_edges, BernoulliOrNode)
+ return BeliefUpdateNodeModel.init_from_edges(simple_edges, BernoulliOrNode)
@pytest.fixture(scope='function')
def many_parents_model(many_parents_edges):
- return BeliefUpdateNodeModel.from_edges(many_parents_edges, BernoulliOrNode)
+ return BeliefUpdateNodeModel.init_from_edges(many_parents_edges, BernoulliOrNode)
@pytest.fixture(scope='function')