aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCathy Yeh <cathy@driver.xyz>2017-11-13 15:34:33 -0800
committerCathy Yeh <cathy@driver.xyz>2017-11-17 13:48:32 -0800
commitb16e990b7e4d00e427d4445ba38eef0fb967963a (patch)
tree5ecb0137fde4779c0158eaf66f89a8df6288d250
parent77d8b323d4f6e05ca97d9cbef43ac85fd8040d61 (diff)
downloadbeliefs-b16e990b7e4d00e427d4445ba38eef0fb967963a.tar.gz
beliefs-b16e990b7e4d00e427d4445ba38eef0fb967963a.tar.bz2
beliefs-b16e990b7e4d00e427d4445ba38eef0fb967963a.zip
changes to work with bump from networkx 1.11 to 2.0
some nx functions now return iterators
-rw-r--r--beliefs/inference/belief_propagation.py22
-rw-r--r--beliefs/models/BayesianModel.py24
-rw-r--r--beliefs/models/BernoulliOrModel.py2
-rw-r--r--beliefs/models/DirectedGraph.py11
-rw-r--r--tests/test_belief_propagation.py3
5 files changed, 32 insertions, 30 deletions
diff --git a/beliefs/inference/belief_propagation.py b/beliefs/inference/belief_propagation.py
index 972fd5d..ecd5e9c 100644
--- a/beliefs/inference/belief_propagation.py
+++ b/beliefs/inference/belief_propagation.py
@@ -50,7 +50,7 @@ class BeliefPropagation:
node_to_update_label_id, msg_sender_label_id = nodes_to_update.pop()
print("Node", node_to_update_label_id)
- node = self.model.nodes[node_to_update_label_id]
+ node = self.model.nodes_dict[node_to_update_label_id]
# exclude the message sender (either a parent or child) from getting an
# outgoing msg from the node to update
@@ -75,7 +75,7 @@ class BeliefPropagation:
except InvalidLambdaMsgToParent:
raise ConflictingEvidenceError(evidence=evidence)
- parent_node = self.model.nodes[parent_id]
+ parent_node = self.model.nodes_dict[parent_id]
parent_node.update_lambda_msg_from_child(child=node_to_update_label_id,
new_value=new_lambda_msg)
nodes_to_update.add(MsgPassers(msg_receiver=parent_id,
@@ -83,7 +83,7 @@ class BeliefPropagation:
for child_id in child_ids:
new_pi_msg = node.compute_pi_msg_to_child(child_k=child_id)
- child_node = self.model.nodes[child_id]
+ child_node = self.model.nodes_dict[child_id]
child_node.update_pi_msg_from_parent(parent=node_to_update_label_id,
new_value=new_pi_msg)
nodes_to_update.add(MsgPassers(msg_receiver=child_id,
@@ -104,7 +104,7 @@ class BeliefPropagation:
"""
self.model.set_boundary_conditions()
- for node in self.model.nodes.values():
+ for node in self.model.nodes_dict.values():
ones_vector = np.ones([node.cardinality])
node.lambda_agg = ones_vector
@@ -119,7 +119,7 @@ class BeliefPropagation:
for node_id in topdown_order:
print('label in iteration through top-down order:', node_id)
- node_sending_msg = self.model.nodes[node_id]
+ node_sending_msg = self.model.nodes_dict[node_id]
child_ids = node_sending_msg.children
if node_sending_msg.pi_agg is None:
@@ -130,7 +130,7 @@ class BeliefPropagation:
new_pi_msg = node_sending_msg.compute_pi_msg_to_child(child_k=child_id)
print(new_pi_msg)
- child_node = self.model.nodes[child_id]
+ child_node = self.model.nodes_dict[child_id]
child_node.update_pi_msg_from_parent(parent=node_id,
new_value=new_pi_msg)
@@ -143,17 +143,17 @@ class BeliefPropagation:
for evidence_id, observed_value in evidence.items():
nodes_to_update = set()
- if evidence_id not in self.model.nodes.keys():
+ if evidence_id not in self.model.nodes_dict.keys():
raise KeyError("Evidence supplied for non-existent label_id: {}"
.format(evidence_id))
if is_kronecker_delta(observed_value):
# specific evidence
- self.model.nodes[evidence_id].lambda_agg = observed_value
+ self.model.nodes_dict[evidence_id].lambda_agg = observed_value
else:
# virtual evidence
- self.model.nodes[evidence_id].lambda_agg = \
- self.model.nodes[evidence_id].lambda_agg * observed_value
+ self.model.nodes_dict[evidence_id].lambda_agg = \
+ self.model.nodes_dict[evidence_id].lambda_agg * observed_value
nodes_to_update.add(MsgPassers(msg_receiver=evidence_id,
msg_sender=None))
@@ -189,4 +189,4 @@ class BeliefPropagation:
if evidence:
self._run_belief_propagation(evidence)
- return {label_id: node.belief for label_id, node in self.model.nodes.items()}
+ return {label_id: node.belief for label_id, node in self.model.nodes_dict.items()}
diff --git a/beliefs/models/BayesianModel.py b/beliefs/models/BayesianModel.py
index bdfd037..6257a57 100644
--- a/beliefs/models/BayesianModel.py
+++ b/beliefs/models/BayesianModel.py
@@ -12,18 +12,18 @@ class BayesianModel(DirectedGraph):
Bayesian model stores nodes and edges described by conditional probability
distributions.
"""
- def __init__(self, edges, nodes=None):
+ def __init__(self, edges, nodes_dict=None):
"""
Input:
edges: list of edge tuples of form ('parent', 'child')
nodes: (optional) dict
a dict key, value pair as {label_id: instance_of_node_class_or_subclass}
"""
- if nodes is not None:
- super().__init__(edges, nodes.keys())
+ if nodes_dict is not None:
+ super().__init__(edges, nodes_dict.keys())
else:
super().__init__(edges)
- self.nodes = nodes
+ self.nodes_dict = nodes_dict
@classmethod
def from_node_class(cls, edges, node_class):
@@ -57,10 +57,10 @@ class BayesianModel(DirectedGraph):
to an (unnormalized) unit vector, of length the cardinality of x.
"""
for root in self.get_roots():
- self.nodes[root].pi_agg = self.nodes[root].cpd.values
+ self.nodes_dict[root].pi_agg = self.nodes_dict[root].cpd.values
for leaf in self.get_leaves():
- self.nodes[leaf].lambda_agg = np.ones([self.nodes[leaf].cardinality])
+ self.nodes_dict[leaf].lambda_agg = np.ones([self.nodes_dict[leaf].cardinality])
@property
def all_nodes_are_fully_initialized(self):
@@ -68,7 +68,7 @@ class BayesianModel(DirectedGraph):
Returns True if, for all nodes in the model, all lambda and pi
messages and lambda_agg and pi_agg are not None, else False.
"""
- for node in self.nodes.values():
+ for node in self.nodes_dict.values():
if not node.is_fully_initialized:
return False
return True
@@ -77,8 +77,8 @@ class BayesianModel(DirectedGraph):
"""
Returns a copy of the model.
"""
- copy_edges = self.edges().copy()
- copy_nodes = copy.deepcopy(self.nodes)
+ copy_edges = list(self.edges()).copy()
+ copy_nodes = copy.deepcopy(self.nodes_dict)
copy_model = self.__class__(edges=copy_edges, nodes=copy_nodes)
return copy_model
@@ -90,7 +90,7 @@ class BayesianModel(DirectedGraph):
RETURNS
set of strings (labels)
"""
- return {label for label, node in self.nodes.items() if is_kronecker_delta(node.belief)}
+ return {label for label, node in self.nodes_dict.items() if is_kronecker_delta(node.belief)}
def get_unobserved_variables_in_definite_state(self, observed=set()):
"""
@@ -105,9 +105,9 @@ class BayesianModel(DirectedGraph):
# Assert that beliefs of directly observed vars are kronecker deltas
for label in observed:
- assert is_kronecker_delta(self.nodes[label].belief), \
+ assert is_kronecker_delta(self.nodes_dict[label].belief), \
("Observed label has belief {} but should be kronecker delta"
- .format(self.nodes[label].belief))
+ .format(self.nodes_dict[label].belief))
vars_in_definite_state = self.get_variables_in_definite_state()
assert observed <= vars_in_definite_state, \
diff --git a/beliefs/models/BernoulliOrModel.py b/beliefs/models/BernoulliOrModel.py
index da18fb6..bf2b44c 100644
--- a/beliefs/models/BernoulliOrModel.py
+++ b/beliefs/models/BernoulliOrModel.py
@@ -14,4 +14,4 @@ class BernoulliOrModel(BayesianModel):
"""
if nodes is None:
nodes = self.create_nodes(edges, node_class=BernoulliOrNode)
- super().__init__(edges, nodes=nodes)
+ super().__init__(edges, nodes_dict=nodes)
diff --git a/beliefs/models/DirectedGraph.py b/beliefs/models/DirectedGraph.py
index 8fce894..84b3a02 100644
--- a/beliefs/models/DirectedGraph.py
+++ b/beliefs/models/DirectedGraph.py
@@ -21,15 +21,16 @@ class DirectedGraph(nx.DiGraph):
"""
Returns a list of leaves of the graph.
"""
- return [node for node, out_degree in self.out_degree_iter() if
- out_degree == 0]
+ return [node for node, out_degree in self.out_degree() if out_degree == 0]
def get_roots(self):
"""
Returns a list of roots of the graph.
"""
- return [node for node, in_degree in self.in_degree().items() if
- in_degree == 0]
+ return [node for node, in_degree in self.in_degree() if in_degree == 0]
def get_topologically_sorted_nodes(self, reverse=False):
- return nx.topological_sort(self, reverse=reverse)
+ if reverse:
+ return list(reversed(list(nx.topological_sort(self))))
+ else:
+ return nx.topological_sort(self)
diff --git a/tests/test_belief_propagation.py b/tests/test_belief_propagation.py
index 223ba78..ef7ffb0 100644
--- a/tests/test_belief_propagation.py
+++ b/tests/test_belief_propagation.py
@@ -166,7 +166,8 @@ def test_belief_propagation_modify_model_inplace(simple_model):
_ = infer.query(evidence={})
assert simple_model.all_nodes_are_fully_initialized
- beliefs_from_model = {node_id: node.belief[1] for node_id, node in simple_model.nodes.items()}
+ beliefs_from_model = {node_id: node.belief[1] for
+ node_id, node in simple_model.nodes_dict.items()}
compare_dictionaries(expected, beliefs_from_model)