Coverage for gpaw/doctools/codegraph.py: 97%
180 statements
coverage.py v7.7.1, created at 2025-07-14 00:18 +0000
"""Tool for generating graphs of objects."""
2from pathlib import Path
3from typing import Any
5import ase
6import numpy as np
7from gpaw.core.atom_arrays import AtomArraysLayout
8from gpaw.new.ase_interface import GPAW
9from gpaw.new.brillouin import BZPoints
10from gpaw.dft import Parameters
def create_nodes(obj, *objects, include):
    """Build the list of diagram nodes for one or more root objects.

    Each root object (``obj`` first, then ``objects`` in order) is expanded
    with :func:`create_node` and all reachable nodes are collected, keyed by
    class name.  Within the tree of ``obj`` a later node replaces an earlier
    one with the same name; nodes from ``objects`` are only added for class
    names not seen yet.

    Inheritance links are then added (see the helpers below), and finally
    :meth:`Node.fix` is called on every node.

    Parameters
    ----------
    obj:
        First root object.
    *objects:
        Further root objects whose classes should also appear.
    include:
        Predicate on an attribute value: true means the value becomes its
        own node (an arrow), false means a plain attribute.

    Returns
    -------
    list of Node
        All nodes, including synthetic base-class nodes.
    """
    # Dict comprehension: for the first root, the last node per name wins.
    nodes = {node.name: node for node in create_node(obj, include).nodes()}
    for other in objects:
        for node in create_node(other, include).nodes():
            nodes.setdefault(node.name, node)

    _link_subclasses(nodes)
    # New base nodes are appended after the existing ones.
    nodes.update(_link_bases(nodes))

    for node in nodes.values():
        node.fix()

    return list(nodes.values())


def _link_subclasses(nodes):
    """Set ``subclasses`` of each node to the nodes of its direct subclasses."""
    for node in nodes.values():
        node.subclasses = [nodes[cls.__name__]
                           for cls in node.obj.__class__.__subclasses__()
                           if cls.__name__ in nodes]


def _link_bases(nodes):
    """Link each node to the node of its first base class.

    A missing base node is synthesized from the subclass's object, plain
    attributes and arrows.  Only one level is added: the new base nodes are
    not themselves linked to their bases.

    Returns the newly created base nodes, keyed by class name.
    """
    new = {}
    for node in nodes.values():
        bases = node.obj.__class__.__bases__
        # ``object`` itself has no bases; classes deriving directly from
        # ``object`` get no base node.
        if not bases or bases[0] is object:
            continue
        cls = bases[0]
        base = nodes.get(cls.__name__) or new.get(cls.__name__)
        if base is None:
            # Copy the attribute list so base and subclass never share it.
            base = Node(node.obj, list(node.attrs), [], None)
            base.name = cls.__name__
            base.has = node.has.copy()
            new[cls.__name__] = base
        if node not in base.subclasses:
            base.subclasses.append(node)
        if node is not base:
            node.base = base
    return new
def create_node(obj, include):
    """Create a :class:`Node` for *obj*.

    The instance attributes in ``obj.__dict__`` whose names do not start
    with a double underscore are split in two groups. Names whose value
    satisfies ``include(value)`` become arrows, and :class:`Node` creates
    their child nodes recursively. All other names are listed as plain
    attributes.
    """
    plain = []
    followed = []
    for name, value in obj.__dict__.items():
        if name.startswith('__'):
            continue
        target = followed if include(value) else plain
        target.append(name)
    return Node(obj, plain, followed, include)
class Node:
    """One box in the class diagram.

    A node wraps a sample instance ``obj`` of a class.

    Attributes
    ----------
    obj:
        The object the node was created from.
    name:
        Class name of ``obj``. It may be overwritten, e.g. for synthetic
        base-class nodes.
    attrs:
        Names of attributes listed as text inside the box.
    has:
        Maps an attribute name to the child :class:`Node` it points to.
    base:
        Node of the base class, or ``None``.
    subclasses:
        Nodes of subclasses.
    rgb:
        Fill colour string passed to graphviz, or ``None``.
    """

    def __init__(self, obj, attrs, arrows, include):
        self.obj = obj
        self.name = obj.__class__.__name__
        self.attrs = attrs
        # Recursively create child nodes for the attributes to follow.
        self.has = {key: create_node(getattr(obj, key), include)
                    for key in arrows}
        self.base = None
        self.subclasses = []
        self.rgb = None

    def __repr__(self):
        base_name = self.base.name if self.base is not None else None
        return (f'Node({self.name}, {self.attrs}, {list(self.has)}, ' +
                f'{base_name}, ' +
                f'{[o.name for o in self.subclasses]})')

    def nodes(self):
        """Yield this node and, depth-first, every node reachable via arrows."""
        yield self
        for node in self.has.values():
            yield from node.nodes()

    def keys(self):
        """Return the set of attribute and arrow names of this node."""
        return set(self.attrs + list(self.has))

    def superclass(self):
        """Return the top-most node reachable by following ``base`` links."""
        return self if self.base is None else self.base.superclass()

    def fix(self):
        """Show keys shared with the subclasses only on this node.

        With a single subclass, this node keeps all of its own keys and the
        subclass drops them.  With several subclasses, the keys common to
        this node and to every subclass stay here and are dropped from the
        subclasses.  Other keys remain where they are.
        """
        if not self.subclasses:
            return
        keys = self.keys()
        if len(self.subclasses) > 1:
            # Only hoist keys this node can actually show.  A key common to
            # all subclasses but missing here must stay on the subclasses,
            # otherwise it would vanish from the diagram.
            for node in self.subclasses:
                keys &= node.keys()
        self.attrs = [attr for attr in self.attrs if attr in keys]
        self.has = {key: value for key, value in self.has.items()
                    if key in keys}
        for node in self.subclasses:
            node.attrs = [attr for attr in node.attrs if attr not in keys]
            node.has = {key: value for key, value in node.has.items()
                        if key not in keys}

    def color(self, rgb):
        """Set the fill colour of this node and, recursively, its subclasses."""
        self.rgb = rgb
        for obj in self.subclasses:
            obj.color(rgb)

    def plot(self, g):
        """Add this node as a record-shaped box to the graphviz graph *g*."""
        kwargs = {'style': 'filled',
                  'fillcolor': self.rgb} if self.rgb else {}
        if self.attrs:
            # Record label "{Name | attr1: T1\nattr2: T2}".  The literal
            # backslash-n is interpreted by graphviz as a line break.
            a = r'\n'.join(add_type(attr, getattr(self.obj, attr))
                           for attr in self.attrs)
            txt = f'{{{self.name} | {a}}}'
        else:
            txt = self.name
        g.node(self.name, txt, **kwargs)
def add_type(name: str, obj: Any) -> str:
    """Return ``'<name>: <class name of obj>'`` for use in a record label."""
    # Named ``type_name`` rather than ``type`` so the builtin is not shadowed.
    type_name = obj.__class__.__name__
    return f'{name}: {type_name}'
def plot_graph(figname, nodes, colors=None, replace=None):
    """Render *nodes* as an SVG class diagram using graphviz.

    Parameters
    ----------
    figname:
        Base name passed to ``graphviz.Digraph.render``.  The dot source file
        ``figname`` is deleted after rendering.
    nodes:
        Nodes, e.g. from :func:`create_nodes`.
    colors:
        Optional map from node name to fill colour.  Each colour is also
        applied to all subclasses of that node.
    replace:
        Optional map from attribute name to the edge label to show instead.
    """
    import graphviz

    # ``None`` defaults avoid shared mutable default arguments.
    colors = {} if colors is None else colors
    replace = {} if replace is None else replace

    g = graphviz.Digraph(node_attr={'shape': 'record'})

    for node in nodes:
        if node.name in colors:
            node.color(colors[node.name])

    for node in nodes:
        node.plot(g)
        for key, value in node.has.items():
            # Arrows point to the top of the target's class hierarchy.
            g.edge(node.name, value.superclass().name,
                   label=replace.get(key, key))
        if node.base is not None:
            # Inheritance edge, drawn with an open arrowhead.
            g.edge(node.base.name, node.name, arrowhead='onormal')

    g.render(figname, format='svg')

    try:
        Path(figname).unlink()  # remove "dot" file
    except FileNotFoundError:
        pass
def abc():
    """Draw a small example diagram, ``abc``, from three toy classes."""
    class A:
        def __init__(self, b):
            self.a = 1
            self.b = b

        def m(self):
            pass

    class B:
        pass

    class C(B):
        pass

    toy_names = {'A', 'B', 'C'}
    nodes = create_nodes(
        A(C()), B(),
        # Exact membership: ``name in 'ABC'`` would be a substring test and
        # also accept names such as 'AB', 'BC' or ''.
        include=lambda obj: obj.__class__.__name__ in toy_names)
    plot_graph('abc', nodes, {'B': '#ffddff'})
def code():
    """Draw the main GPAW class diagrams.

    Runs FD, PW and LCAO calculations for a single H atom and then writes
    three figures with :func:`plot_graph`: ``code``, ``acf`` and ``da``.
    """
    fd = GPAW(mode='fd', txt=None)
    pw = GPAW(mode='pw', txt=None)
    lcao = GPAW(mode='lcao', txt=None)
    a = ase.Atoms('H', cell=[2, 2, 2], pbc=1)

    # Minimal local root object whose only attribute is the calculator.
    # It is used as the first root of the ``code`` diagram.
    class Atoms:
        def __init__(self, calc):
            self.calc = calc

    a0 = Atoms(fd)
    fd.get_potential_energy(a)
    pw.get_potential_energy(a)
    lcao.get_potential_energy(a)
    ibzwfs = fd.dft.ibzwfs
    # Replace the nested wfs_qs[q][s] container of the FD calculator by its
    # first element.  The diagram then follows an arrow to one
    # wave-function object; the edge is relabelled 'wfs_qs[q][s]' below.
    # NOTE: this mutation is relied on by the ``da`` section further down,
    # which accesses ``fd.dft.ibzwfs.wfs_qs.psit_nX`` without indexing.
    ibzwfs.wfs_qs = ibzwfs.wfs_qs[0][0]

    # Fill colours, keyed by node name.  plot_graph() applies each colour to
    # that node and all of its subclasses.
    colors = {'BZPoints': '#ddffdd',
              'PotentialCalculator': '#ffdddd',
              'WaveFunctions': '#ddddff',
              'Eigensolver': '#ffffdd',
              'PoissonSolver': '#ffeedd',
              'Hamiltonian': '#eeeeee'}

    def include(obj):
        # Follow only objects whose class lives in the gpaw.new package.
        try:
            mod = obj.__module__
        except AttributeError:
            return False

        return mod.startswith('gpaw.new')

    # Extra roots.  create_nodes() adds their classes only when the name is
    # not already present from the FD tree.
    things = [pw, lcao,
              lcao.dft.ibzwfs.wfs_qs[0][0],
              BZPoints(np.zeros((5, 3)))]
    nodes = create_nodes(a0, *things, include=include)
    plot_graph('code', nodes, colors,
               replace={'wfs_qs': 'wfs_qs[q][s]'})

    # acf.svg:
    nodes = create_nodes(
        fd.dft.density.nct_aX,
        pw.dft.density.nct_aX,
        include=lambda obj: obj.__class__.__name__.startswith('Atom'))
    plot_graph('acf', nodes, {'AtomCenteredFunctions': '#ddffff'})

    # da.svg:
    # Follow gpaw.core objects, except '_lru_cache_wrapper' instances
    # (presumably functools.lru_cache-wrapped callables stored as attributes).
    nodes = create_nodes(
        fd.dft.ibzwfs.wfs_qs.psit_nX,
        pw.dft.ibzwfs.wfs_qs[0][0].psit_nX,
        include=lambda obj:
        getattr(obj, '__module__', '').startswith('gpaw.core') and
        obj.__class__.__name__ != '_lru_cache_wrapper')
    plot_graph('da', nodes, {'DistributedArrays': '#eeeeee',
                             'Domain': '#dddddd'})
def builders():
    """Draw the DFT-component-builder diagram, ``builder``, for FD, PW and LCAO."""
    atoms = ase.Atoms('H', cell=[2, 2, 2], pbc=1)
    component_builders = [Parameters(mode=mode).dft_component_builder(atoms)
                          for mode in ('fd', 'pw', 'lcao')]

    def include(obj):
        # Follow builder objects and the input-parameter object.
        class_name = obj.__class__.__name__
        return class_name.endswith('Builder') or class_name == 'InputParameters'

    nodes = create_nodes(*component_builders, include=include)
    plot_graph('builder', nodes, {'DFTComponentsBuilder': '#ffeedd'})
def aa():
    """Draw the atom-arrays diagram, ``aa``.

    The node named ``AtomArrays`` is dropped.  The ``DistributedArrays``
    node is then drawn under the name ``AtomArrays``, without its ``dv``
    attribute.  Only nodes whose names start with 'A' are plotted.
    """
    nodes = create_nodes(
        AtomArraysLayout([1]).empty(),
        include=lambda obj: obj.__class__.__name__.startswith('Atom'))
    nodes = [node for node in nodes if node.name != 'AtomArrays']
    # Debug prints of every node were removed here; they only cluttered
    # stdout during the documentation build.
    for node in nodes:
        if node.name == 'DistributedArrays':
            node.name = 'AtomArrays'
            node.attrs.remove('dv')
            break
    plot_graph('aa', [node for node in nodes if node.name.startswith('A')])
def main():
    """Generate every diagram, in a fixed order."""
    for make_figure in (abc, code, builders, aa):
        make_figure()
# Generate all diagrams when this module is run as a script.
if __name__ == '__main__':
    main()