Coverage for gpaw/doctools/codegraph.py: 97%

180 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-14 00:18 +0000

"""Tool for generating graphs of objects."""

2from pathlib import Path 

3from typing import Any 

4 

5import ase 

6import numpy as np 

7from gpaw.core.atom_arrays import AtomArraysLayout 

8from gpaw.new.ase_interface import GPAW 

9from gpaw.new.brillouin import BZPoints 

10from gpaw.dft import Parameters 

11 

12 

def create_nodes(obj, *objects, include):
    """Collect class-diagram nodes for one or more root objects.

    Each root is expanded with :func:`create_node` and the resulting nodes
    are de-duplicated by class name.  Afterwards subclass/base links are
    filled in, a synthetic node is added for every direct base class (the
    first entry of ``__bases__``, ignoring ``object``) that has no node yet,
    and :meth:`Node.fix` is called on every node to hoist shared attributes
    into base nodes.

    Parameters
    ----------
    obj, *objects:
        Root objects.  Within the tree of the first root, a later node with
        an already seen class name replaces the earlier one (keeping the
        earlier position); the remaining roots only add class names that
        have not been seen before.
    include:
        Predicate handed to :func:`create_node`; it decides which attribute
        values become child nodes (arrows).

    Returns
    -------
    list of Node
        All nodes, in insertion order (synthesized base nodes last).
    """
    registry = {}
    # First root: plain assignment, so a duplicate name overwrites the value.
    for node in create_node(obj, include).nodes():
        registry[node.name] = node
    # Other roots: only contribute class names not registered yet.
    for extra in objects:
        for node in create_node(extra, include).nodes():
            registry.setdefault(node.name, node)

    # Link every node to the registered nodes of its direct subclasses.
    for node in registry.values():
        node.subclasses = [registry[sub.__name__]
                           for sub in node.obj.__class__.__subclasses__()
                           if sub.__name__ in registry]

    # Attach each node to a node for its first base class, creating that
    # base node on the fly when the class was never instantiated.
    synthesized = {}
    for node in registry.values():
        parent_cls = node.obj.__class__.__bases__[0]
        if parent_cls is object:
            continue
        parent_name = parent_cls.__name__
        parent = registry.get(parent_name) or synthesized.get(parent_name)
        if parent is None:
            # The synthetic base reuses the subclass instance (so attribute
            # types can be looked up when plotting) and a copy of its
            # arrows; Node.fix() trims it down afterwards.
            parent = Node(node.obj, node.attrs, [], None)
            parent.name = parent_name
            parent.has = node.has.copy()
            synthesized[parent_name] = parent
        if node not in parent.subclasses:
            parent.subclasses.append(node)
        if node is not parent:
            node.base = parent
    registry.update(synthesized)

    for node in registry.values():
        node.fix()

    return list(registry.values())

52 

53 

def create_node(obj, include):
    """Wrap *obj* in a :class:`Node`, recursing into selected attributes.

    Every entry of ``obj.__dict__`` whose name does not start with ``'__'``
    goes into one of two lists.  If ``include(value)`` is true, the name is
    an arrow to a child node (built recursively by ``Node``).  Otherwise it
    is a plain attribute shown inside the box.
    """
    plain, linked = [], []
    for key, value in obj.__dict__.items():
        if key[:2] == '__':
            continue
        target = linked if include(value) else plain
        target.append(key)
    return Node(obj, plain, linked, include)

64 

65 

class Node:
    """A box in the class diagram, built from one object instance.

    Attributes
    ----------
    obj:
        The instance described by this node; also used to look up the
        types of attributes when plotting.
    name:
        Class name of ``obj``; doubles as the graphviz node id.
    attrs:
        Attribute names drawn inside the box.
    has:
        Mapping from attribute name to child ``Node`` (drawn as arrows).
    base, subclasses:
        Inheritance links, filled in by ``create_nodes``.
    rgb:
        Fill colour, or None for an unfilled box.
    """

    def __init__(self, obj, attrs, arrows, include):
        self.obj = obj
        self.name = obj.__class__.__name__
        self.attrs = attrs
        # Recursively wrap every attribute that ``include`` selected.
        children = {}
        for key in arrows:
            children[key] = create_node(getattr(obj, key), include)
        self.has = children
        self.base = None
        self.subclasses = []
        self.rgb = None

    def __repr__(self):
        base_name = None if self.base is None else self.base.name
        sub_names = [sub.name for sub in self.subclasses]
        return (f'Node({self.name}, {self.attrs}, {list(self.has)}, '
                f'{base_name}, {sub_names})')

    def nodes(self):
        """Yield this node and then, depth-first, all nodes below it."""
        yield self
        for child in self.has.values():
            yield from child.nodes()

    def keys(self):
        """Return the set of attribute and arrow names of this node."""
        return {*self.attrs, *self.has}

    def superclass(self):
        """Return the top-most node reached by following ``base`` links."""
        if self.base is None:
            return self
        return self.base.superclass()

    def fix(self):
        """Hoist names shared by the subclasses into this (base) node.

        With several subclasses, the names common to all of them are kept
        here (restricted to names this node already has) and removed from
        every subclass.  With a single subclass, all names of this node are
        removed from that subclass.  Nodes without subclasses are untouched.
        """
        if not self.subclasses:
            return
        if len(self.subclasses) == 1:
            shared = self.keys()
        else:
            shared = set.intersection(
                *[sub.keys() for sub in self.subclasses])
        self.attrs = [name for name in self.attrs if name in shared]
        self.has = {name: child for name, child in self.has.items()
                    if name in shared}
        for sub in self.subclasses:
            sub.attrs = [name for name in sub.attrs if name not in shared]
            sub.has = {name: child for name, child in sub.has.items()
                       if name not in shared}

    def color(self, rgb):
        """Fill this node and, recursively, all its subclasses with *rgb*."""
        self.rgb = rgb
        for sub in self.subclasses:
            sub.color(rgb)

    def plot(self, g):
        """Add this node to the graphviz graph *g*.

        Without attributes the label is just the class name; otherwise it
        is a record-style label ``{name | row\\nrow...}`` whose rows come
        from ``add_type``.
        """
        style = {}
        if self.rgb:
            style = {'style': 'filled', 'fillcolor': self.rgb}
        if not self.attrs:
            label = self.name
        else:
            rows = r'\n'.join([add_type(attr, getattr(self.obj, attr))
                               for attr in self.attrs])
            label = f'{{{self.name} | {rows}}}'
        g.node(self.name, label, **style)

124 

125 

def add_type(name: str, obj: Any) -> str:
    """Return a record-label row ``'<name>: <class name of obj>'``."""
    # Named ``cls_name`` rather than ``type`` to avoid shadowing the builtin.
    cls_name = obj.__class__.__name__
    return f'{name}: {cls_name}'

129 

130 

def plot_graph(figname, nodes, colors=None, replace=None):
    """Render *nodes* as an SVG class diagram with graphviz.

    Parameters
    ----------
    figname:
        Output name passed to ``g.render(..., format='svg')``.  The DOT
        source file named *figname* is removed afterwards (if present).
    nodes:
        List of nodes, typically as returned by ``create_nodes``.  It is
        iterated twice, so it must not be a one-shot iterator.
    colors:
        Optional mapping from node name to fill colour.  The colour is also
        applied, recursively, to the subclasses of that node.
    replace:
        Optional mapping used to relabel arrows (attribute names).
    """
    import graphviz

    # None defaults instead of shared mutable ``{}`` defaults.
    colors = {} if colors is None else colors
    replace = {} if replace is None else replace

    g = graphviz.Digraph(node_attr={'shape': 'record'})

    for node in nodes:
        if node.name in colors:
            node.color(colors[node.name])

    # Child nodes stored in ``has`` can be same-named duplicates of the
    # node that was kept in *nodes*.  Only the kept node has its ``base``
    # link set, so resolve through it before walking up to the superclass.
    canonical = {node.name: node for node in nodes}

    for node in nodes:
        node.plot(g)
        for key, value in node.has.items():
            target = canonical.get(value.name, value).superclass()
            g.edge(node.name, target.name, label=replace.get(key, key))
        if node.base is not None:
            g.edge(node.base.name, node.name, arrowhead='onormal')

    g.render(figname, format='svg')

    try:
        Path(figname).unlink()  # remove "dot" file
    except FileNotFoundError:
        pass

153 

154 

def abc():
    """Render ``abc.svg``, a tiny self-contained example diagram.

    ``A`` holds an int (drawn as an attribute) and a ``C`` instance (drawn
    as an arrow).  ``C`` derives from ``B``, so the arrow points at ``B`` and
    an inheritance edge ``B -> C`` is drawn.  ``B`` (and therefore ``C``)
    is filled with ``#ffddff``.
    """
    class A:
        def __init__(self, b):
            self.a = 1
            self.b = b

        def m(self):
            pass

    class B:
        pass

    class C(B):
        pass

    # Exact membership test: the previous substring test
    # (``name in 'ABC'``) also accepted names such as 'AB', 'BC' or ''.
    demo_classes = {'A', 'B', 'C'}
    nodes = create_nodes(
        A(C()), B(),
        include=lambda obj: obj.__class__.__name__ in demo_classes)
    plot_graph('abc', nodes, {'B': '#ffddff'})

174 

175 

def code():
    """Render ``code.svg``, ``acf.svg`` and ``da.svg`` for the new GPAW code.

    Small hydrogen calculations in fd, pw and lcao mode are run first, so
    that fully initialised DFT objects exist to be inspected.
    """
    fd = GPAW(mode='fd', txt=None)
    pw = GPAW(mode='pw', txt=None)
    lcao = GPAW(mode='lcao', txt=None)
    a = ase.Atoms('H', cell=[2, 2, 2], pbc=1)

    class Atoms:
        # Minimal stand-in (presumably mimicking ``ase.Atoms.calc``) so that
        # the diagram starts at an "Atoms" box pointing at the calculator.
        def __init__(self, calc):
            self.calc = calc

    a0 = Atoms(fd)
    fd.get_potential_energy(a)
    pw.get_potential_energy(a)
    lcao.get_potential_energy(a)
    # Replace the nested wfs_qs[q][s] list by a single wave-functions object.
    # A list is not from a gpaw.new module, so ``include`` would not follow
    # it.  The arrow is relabelled "wfs_qs[q][s]" in the plot below.
    ibzwfs = fd.dft.ibzwfs
    ibzwfs.wfs_qs = ibzwfs.wfs_qs[0][0]

    colors = {'BZPoints': '#ddffdd',
              'PotentialCalculator': '#ffdddd',
              'WaveFunctions': '#ddddff',
              'Eigensolver': '#ffffdd',
              'PoissonSolver': '#ffeedd',
              'Hamiltonian': '#eeeeee'}

    def in_package(obj, prefix):
        # True if ``obj.__module__`` exists, is a string and starts with
        # *prefix*; a missing or None ``__module__`` gives False instead of
        # raising.
        mod = getattr(obj, '__module__', None)
        return isinstance(mod, str) and mod.startswith(prefix)

    def include(obj):
        return in_package(obj, 'gpaw.new')

    things = [pw, lcao,
              lcao.dft.ibzwfs.wfs_qs[0][0],
              BZPoints(np.zeros((5, 3)))]
    nodes = create_nodes(a0, *things, include=include)
    plot_graph('code', nodes, colors,
               replace={'wfs_qs': 'wfs_qs[q][s]'})

    # acf.svg:
    nodes = create_nodes(
        fd.dft.density.nct_aX,
        pw.dft.density.nct_aX,
        include=lambda obj: obj.__class__.__name__.startswith('Atom'))
    plot_graph('acf', nodes, {'AtomCenteredFunctions': '#ddffff'})

    # da.svg (``_lru_cache_wrapper`` objects, presumably cached methods
    # stored on instances, are excluded):
    nodes = create_nodes(
        fd.dft.ibzwfs.wfs_qs.psit_nX,
        pw.dft.ibzwfs.wfs_qs[0][0].psit_nX,
        include=lambda obj: (
            in_package(obj, 'gpaw.core') and
            obj.__class__.__name__ != '_lru_cache_wrapper'))
    plot_graph('da', nodes, {'DistributedArrays': '#eeeeee',
                             'Domain': '#dddddd'})

231 

232 

def builders():
    """Render ``builder.svg``: the DFT-component builders for three modes."""
    atoms = ase.Atoms('H', cell=[2, 2, 2], pbc=1)
    builder_objs = [Parameters(mode=mode).dft_component_builder(atoms)
                    for mode in ('fd', 'pw', 'lcao')]

    def is_builder(obj):
        # Follow builder objects and the InputParameters they hold.
        cls_name = obj.__class__.__name__
        return cls_name.endswith('Builder') or cls_name == 'InputParameters'

    nodes = create_nodes(*builder_objs, include=is_builder)
    plot_graph('builder', nodes, {'DFTComponentsBuilder': '#ffeedd'})

244 

245 

def aa():
    """Render ``aa.svg``, a diagram of the ``Atom*`` array classes.

    The concrete ``AtomArrays`` node is dropped and the first
    ``DistributedArrays`` node is relabelled ``AtomArrays`` (without its
    ``dv`` attribute).  Only nodes whose name starts with ``'A'`` are
    plotted.
    """
    everything = create_nodes(
        AtomArraysLayout([1]).empty(),
        include=lambda obj: obj.__class__.__name__.startswith('Atom'))
    nodes = [n for n in everything if n.name != 'AtomArrays']
    # Debug output: each node (and its arrows) up to the relabelled one.
    for n in nodes:
        print(n)
        print(n.has)
        if n.name != 'DistributedArrays':
            continue
        # NOTE(review): raises ValueError if 'dv' is not an attribute here.
        n.name = 'AtomArrays'
        n.attrs.remove('dv')
        break
    plot_graph('aa', [n for n in nodes if n.name[0] == 'A'])

259 

260 

def main():
    """Generate every diagram: abc, code/acf/da, builder and aa (SVG)."""
    for make_diagram in (abc, code, builders, aa):
        make_diagram()


if __name__ == '__main__':
    main()