diff --git a/pydynamo/core/graph.py b/pydynamo/core/graph.py new file mode 100644 index 0000000000000000000000000000000000000000..44c58f65496541ef39f8572e2eee20af5c3ac686 --- /dev/null +++ b/pydynamo/core/graph.py @@ -0,0 +1,173 @@ +"""Functions of System which handles networks. +""" +import networkx as nx + +# Graph +def get_cst_graph(self): + """Get the graph of influences for constants: an arrow from constant A to constant B if B needs A to be computed. + + Returns + ------- + networkx.DiGraph + Graph of constant influences. + """ + G = nx.DiGraph() + for c in self.eqs['cst']: + G.add_node(c) + args = self.eqs['cst'][c]['args'] + for node_type in args: + for node in args[node_type]: + if node_type == 'var': + raise(Exception( + f'Constant {c} is set with a variable ({"_".join(node)}):\n' + f'{c} = {self.eqs["cst"][c]["line"]}')) + if node_type == 'cst': + G.add_edge(node, c) + return G + +def get_init_graph(self): + """Get the graph of influences for variables at initialisation: an arrow from variable A to variable B if B needs A to be computed. + + Returns + ------- + networkx.DiGraph + Graph of variable initialisation influences. + """ + G = nx.DiGraph() + for v in self.eqs['init']: + G.add_node(v) + args = self.eqs['init'][v]['args'] + for av, ai in args['var']: + if av in self.eqs['update']: + G.add_edge(av, v) + for v in self.eqs['update']: + if not self.is_initialized(v): + G.add_node(v) + args = self.eqs['update'][v]['args'] + for av, ai in args['var']: + if av in self.eqs['update']: + G.add_edge(av, v) + return G + +def get_update_graph(self): + """Get the graph of influences for variables and their indices at updating step: an arrow from variable (A, i) to variable (B, k) if (B, k) needs (A, i) to be computed. + + Returns + ------- + networkx.DiGraph + Graph of variable influences at updating step. + """ + G = nx.DiGraph() + for v in self.eqs['update']: + G.add_node((v, 'k')) + args = self.eqs['update'][v]['args'] + for av, ai in args['var']: + if av in self.eqs['update']: + G.add_edge((av, ai), (v, 'k')) + return G + +def get_update_graph_quotient(self): + """Get the graph of influences for variables at updating step: an arrow from variable A to variable B if B needs A to be computed. + + Returns + ------- + networkx.DiGraph + Graph of variable influences at updating step. + """ + G = nx.DiGraph() + for v in self.eqs['update']: + G.add_node(v) + args = self.eqs['update'][v]['args'] + for av, _ in args['var']: + if av in self.eqs['update']: + G.add_edge(av, v) + return G + +def get_influence_graph(self): + """Get the graph of influences: an arrow from A to B if B needs A (at initialisation or updating step) to be computed. + + Returns + ------- + networkx.DiGraph + Graph of influences. + """ + G = nx.DiGraph() + for v in self.eqs['update']: + G.add_node(v) + args = self.eqs['update'][v]['args'] + for av, _ in args['var']: + if av in self.eqs['update']: + G.add_edge(av, v) + for ac in args['cst']: + if ac in self.eqs['cst']: + G.add_edge(ac, v) + + for v in self.eqs['init']: + G.add_node(v) + args = self.eqs['init'][v]['args'] + for av, _ in args['var']: + if av in self.eqs['init']: + G.add_edge(av, v) + for ac in args['cst']: + if ac in self.eqs['cst']: + G.add_edge(ac, v) + + for c in self.eqs['cst']: + G.add_node(c) + args = self.eqs['cst'][c]['args'] + for ac in args['cst']: + if ac in self.eqs['cst']: + G.add_edge(ac, c) + return G + +def set_order(self, gtype): + """Returns the order to set constants, intitialize or update variables. + + Parameters + ---------- + gtype : str + Type of graph, either 'cst', 'init', or 'update'. + """ + if gtype == 'cst': + G = self.get_cst_graph() + elif gtype == 'init': + G = self.get_init_graph() + elif gtype == 'update': + G = self.get_update_graph() + else: + raise Exception("Wrong type of graph") + return nx.topological_sort(G) + +def assert_update_acyclic(self): + """Assert that the updating graph is acyclic, and print the cycle in case there is some. + """ + G = self.get_update_graph() + b = '\\' + assert nx.is_directed_acyclic_graph(G), \ + "Update is not acyclic:\n"\ + + "\n".join(f"{'.'.join(j)} <- {'.'.join(i)}: " + f"{'.'.join(j)} = {re.sub(b+'_([jk])', '.'+b+'1',self.eqs['update'][j[0]]['line'])}" + for i, j in reversed(nx.find_cycle(G)))\ + + "\nPlease design an update scheme that is not cyclic." + +def assert_init_acyclic(self): + """Assert that the initialisation graph is acyclic, and print the cycle in case there is some. + """ + G = self.get_init_graph() + if not nx.is_directed_acyclic_graph(G): + msg = "Initialisation is not acyclic:\n" + for i, j in reversed(nx.find_cycle(G)): + line = self.eqs['update'][j]['line'] + if j in self.eqs['init']: + line = self.eqs['init'][j]['line'] + msg += f"{j} <- {i}: " + msg += f"{j}.i = {line} \n" + msg +="Please design an initialisation scheme that is not cyclic." + raise AssertionError(msg) + +def assert_cst_acyclic(self): + """Assert that the constant setting graph is acyclic. + """ + assert nx.is_directed_acyclic_graph(self.get_cst_graph()),\ + ("Cycle detected for constant equations", + nx.find_cycle(self.get_update_graph())) diff --git a/pydynamo/core/system.py b/pydynamo/core/system.py index 7ddc09ba453c28e00bcb419e5415c2d45e2bd3ab..2a6dbe652c938ef3849e87f08250811695cdd0d8 100644 --- a/pydynamo/core/system.py +++ b/pydynamo/core/system.py @@ -2,7 +2,6 @@ """ import inspect import numpy as np -import networkx as nx from itertools import chain import re @@ -29,6 +28,8 @@ class System: from .plot_system import plot, plot_non_linearity, plot_compare, show_influence_graph from .politics import new_cst_politic, new_var_politic, new_table_politic, new_politic + from .graph import get_cst_graph, get_init_graph, get_update_graph, get_influence_graph, set_order + from .graph import assert_cst_acyclic, assert_init_acyclic, assert_update_acyclic def __init__(self, code=None, prepare=True): """Initialise a System, empty or from pydynamo code. @@ -115,8 +116,6 @@ class System: if prepare: self.prepare() - - def add_comments(self, comments): """Add comments to the System. @@ -328,123 +327,6 @@ class System: """ return var in self.eqs['init'] - # Graph - def get_cst_graph(self): - """Get the graph of influences for constants: an arrow from constant A to constant B if B needs A to be computed. - - Returns - ------- - networkx.DiGraph - Graph of constant influences. - """ - G = nx.DiGraph() - for c in self.eqs['cst']: - G.add_node(c) - args = self.eqs['cst'][c]['args'] - for node_type in args: - for node in args[node_type]: - if node_type == 'var': - raise(Exception( - f'Constant {c} is set with a variable ({"_".join(node)}):\n' - f'{c} = {self.eqs["cst"][c]["line"]}')) - if node_type == 'cst': - G.add_edge(node, c) - return G - - def get_init_graph(self): - """Get the graph of influences for variables at initialisation: an arrow from variable A to variable B if B needs A to be computed. - - Returns - ------- - networkx.DiGraph - Graph of variable initialisation influences. - """ - G = nx.DiGraph() - for v in self.eqs['init']: - G.add_node(v) - args = self.eqs['init'][v]['args'] - for av, ai in args['var']: - if av in self.eqs['update']: - G.add_edge(av, v) - for v in self.eqs['update']: - if not self.is_initialized(v): - G.add_node(v) - args = self.eqs['update'][v]['args'] - for av, ai in args['var']: - if av in self.eqs['update']: - G.add_edge(av, v) - return G - - def get_update_graph(self): - """Get the graph of influences for variables and their indices at updating step: an arrow from variable (A, i) to variable (B, k) if (B, k) needs (A, i) to be computed. - - Returns - ------- - networkx.DiGraph - Graph of variable influences at updating step. - """ - G = nx.DiGraph() - for v in self.eqs['update']: - G.add_node((v, 'k')) - args = self.eqs['update'][v]['args'] - for av, ai in args['var']: - if av in self.eqs['update']: - G.add_edge((av, ai), (v, 'k')) - return G - - def get_update_graph_quotient(self): - """Get the graph of influences for variables at updating step: an arrow from variable A to variable B if B needs A to be computed. - - Returns - ------- - networkx.DiGraph - Graph of variable influences at updating step. - """ - G = nx.DiGraph() - for v in self.eqs['update']: - G.add_node(v) - args = self.eqs['update'][v]['args'] - for av, _ in args['var']: - if av in self.eqs['update']: - G.add_edge(av, v) - return G - - def get_influence_graph(self): - """Get the graph of influences: an arrow from A to B if B needs A (at initialisation or updating step) to be computed. - - Returns - ------- - networkx.DiGraph - Graph of influences. - """ - G = nx.DiGraph() - for v in self.eqs['update']: - G.add_node(v) - args = self.eqs['update'][v]['args'] - for av, _ in args['var']: - if av in self.eqs['update']: - G.add_edge(av, v) - for ac in args['cst']: - if ac in self.eqs['cst']: - G.add_edge(ac, v) - - for v in self.eqs['init']: - G.add_node(v) - args = self.eqs['init'][v]['args'] - for av, _ in args['var']: - if av in self.eqs['init']: - G.add_edge(av, v) - for ac in args['cst']: - if ac in self.eqs['cst']: - G.add_edge(ac, v) - - for c in self.eqs['cst']: - G.add_node(c) - args = self.eqs['cst'][c]['args'] - for ac in args['cst']: - if ac in self.eqs['cst']: - G.add_edge(ac, c) - return G # Assertions def assert_cst_defined(self): @@ -466,6 +348,7 @@ class System: assert (v in self.eqs['init'] or v in self.eqs['update']) , f'Error: Variable {v} neither updated nor initialized' + # Functions setting def set_fun(self, node, fun_name, args, line): """Set an updating, initialisation or constant setting function to the System. It evaluates a lambda function with the line equation inside. @@ -604,8 +487,8 @@ class System: def set_all_csts(self): """Set every constant constant according to its equation and arguments ONLY IF the constant is not set yet. """ - G = self.get_cst_graph() - for cst in nx.topological_sort(G): + + for cst in self.set_order('cst'): try: arg_names = self.eqs['cst'][cst]['args']['cst'] args = {name: getattr(self, name) for name in arg_names} @@ -698,7 +581,7 @@ class System: def _init_all(self): """Initialise every variable. Iterate over the initialisaition order and set the first value of each variable. """ - for var in nx.topological_sort(self.get_init_graph()): + for var in self.set_order('init'): if self.is_initialized(var): update_fun = getattr(self, 'init_' + var) args = self.eqs['init'][var]['args'] @@ -748,7 +631,7 @@ class System: and is ordered in the topological order of the updating graph """ self._update_loop = [] - for var, i in nx.topological_sort(self.get_update_graph()): + for var, i in self.set_order('update'): if i == 'k': u = {'var': var} try: @@ -846,9 +729,7 @@ class System: self.assert_cst_defined() self.assert_init_defined() self.assert_update_defined() - assert nx.is_directed_acyclic_graph(self.get_cst_graph()),\ - ("Cycle detected for constant equations", - nx.find_cycle(self.get_update_graph())) + self.assert_cst_acyclic() self.set_all_funs() self.set_all_csts() @@ -878,32 +759,7 @@ class System: self.eqs[f_type][var] = {'args':dargs,'line': None} fun.__okdic__ = True - def assert_update_acyclic(self): - """Assert that the updating graph is acyclic, and print the cycle in case there is some. - """ - G = self.get_update_graph() - b = '\\' - assert nx.is_directed_acyclic_graph(G), \ - "Update is not acyclic:\n"\ - + "\n".join(f"{'.'.join(j)} <- {'.'.join(i)}: " - f"{'.'.join(j)} = {re.sub(b+'_([jk])', '.'+b+'1',self.eqs['update'][j[0]]['line'])}" - for i, j in reversed(nx.find_cycle(G)))\ - + "\nPlease design an update scheme that is not cyclic." - - def assert_init_acyclic(self): - """Assert that the initialisation graph is acyclic, and print the cycle in case there is some. - """ - G = self.get_init_graph() - if not nx.is_directed_acyclic_graph(G): - msg = "Initialisation is not acyclic:\n" - for i, j in reversed(nx.find_cycle(G)): - line = self.eqs['update'][j]['line'] - if j in self.eqs['init']: - line = self.eqs['init'][j]['line'] - msg += f"{j} <- {i}: " - msg += f"{j}.i = {line} \n" - msg +="Please design an initialisation scheme that is not cyclic." - raise AssertionError(msg) + def run(self, N=None, dt=1): """ @@ -971,4 +827,3 @@ class System: return False except: return any(a!=b) return {cst: (v1, v2) for cst, (v1, v2) in both.items() if diff(v1, v2)} -