Coverage for gpaw/atom/plot_dataset.py: 89%
287 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-19 00:19 +0000
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-19 00:19 +0000
1from __future__ import annotations
3import functools
4import textwrap
5import os
6from ast import literal_eval
7from collections.abc import Callable, Iterable
8from types import SimpleNamespace
9from typing import Any, TYPE_CHECKING
10from xml.dom import minidom
11from warnings import warn
13import numpy as np
15from .. import typing
16from ..basis_data import Basis, BasisPlotter
17from ..setup_data import SetupData, read_maybe_unzipping, search_for_file
18from .aeatom import AllElectronAtom, colors
19from .generator2 import (PAWSetupGenerator, parameters,
20 generate, plot_log_derivs)
21from .radialgd import AERadialGridDescriptor
23if TYPE_CHECKING:
24 from matplotlib.axes import Axes
25 from matplotlib.figure import Figure
# Layouts of the records consumed by the plotting functions below;
# the trailing comment on each slot names the quantity stored there
_PartialWaveItem = tuple[
    int,  # l
    int,  # n
    float,  # r_cut
    float,  # energy
    typing.Array1D,  # phi_g
    typing.Array1D,  # phit_g
]
_ProjectorItem = tuple[
    int,  # l
    int,  # n
    float,  # energy
    typing.Array1D,  # pt_g
]
def plot_partial_waves(ax: 'Axes',
                       symbol: str,
                       name: str,
                       rgd: AERadialGridDescriptor,
                       cutoff: float,
                       iterator: Iterable[_PartialWaveItem]) -> None:
    """
    Plot the partial waves and their pseudo counterparts onto `ax`.

    Parameters
    ----------
    ax
        Axes to draw on.
    symbol, name
        Only used for composing the plot title.
    rgd
        Radial grid descriptor, supplying the radii (`.r_g`), the grid
        size (`len()`) and the grid index of a radius (`.ceil()`).
    cutoff
        The x-axis spans from 0 to three times this value.
    iterator
        Items of the form `(l, n, r_cut, energy, phi_g, phit_g)`.

    Notes
    -----
    For each item, `phi_g * r_g` is drawn as a solid line and
    `phit_g * r_g` as a dashed line of the same color.
    Items sharing the same `l` share the base color `colors[l]`, each
    successive one being blended further towards the axes background.
    Items with `n == -1` are only drawn up to the grid index
    `rgd.ceil(r_cut)` and are labeled with a leading asterisk.
    """
    r_g = rgd.r_g
    bg_color = _get_patch_color(ax)
    group_by_l: dict[int, list[_PartialWaveItem]] = {}
    # Sort on the scalar fields (l, n, r_cut, energy) only: if tied items
    # were compared as whole tuples, the comparison would fall through to
    # the arrays and raise `ValueError` (ambiguous truth value);
    # the sort is stable, so tied items keep their incoming order
    for item in sorted(iterator, key=lambda item: item[:4]):
        group_by_l.setdefault(item[0], []).append(item)
    for l, items in group_by_l.items():
        weights = _get_blend_weights(len(items))
        for weight, (_, n, rcut, e, phi_g, phit_g) in zip(weights, items):
            if n == -1:
                gc = rgd.ceil(rcut)
                label = '*{} ({:.2f} Ha)'.format('spdf'[l], e)
            else:
                gc = len(rgd)
                label = '%d%s (%.2f Ha)' % (n, 'spdf'[l], e)
            color = _blend_colors(colors[l],
                                  background=bg_color,
                                  foreground_weight=weight)
            # Only the solid line is labeled, so that each partial wave
            # shows up once in the legend
            ax.plot(r_g[:gc], (phi_g * r_g)[:gc], color=color, label=label)
            ax.plot(r_g[:gc], (phit_g * r_g)[:gc], '--', color=color)
    ax.axis(xmin=0, xmax=3 * cutoff)
    ax.set_title(f'Partial waves: {symbol} {name}')
    ax.set_xlabel('radius [Bohr]')
    ax.set_ylabel(r'$r\phi_{n\ell}(r)$')
    ax.legend()
def plot_projectors(ax: 'Axes',
                    symbol: str,
                    name: str,
                    rgd: AERadialGridDescriptor,
                    cutoff: float,
                    iterator: Iterable[_ProjectorItem]) -> None:
    """
    Plot the projector functions onto `ax`.

    Parameters
    ----------
    ax
        Axes to draw on.
    symbol, name
        Only used for composing the plot title.
    rgd
        Radial grid descriptor, supplying the radii (`.r_g`).
    cutoff
        The x-axis spans from 0 to this value.
    iterator
        Items of the form `(l, n, energy, pt_g)`.

    Notes
    -----
    For each item, `pt_g * r_g` is drawn over the whole grid.
    Items sharing the same `l` share the base color `colors[l]`, each
    successive one being blended further towards the axes background.
    Items with `n == -1` are labeled with a leading asterisk.
    """
    r_g = rgd.r_g
    bg_color = _get_patch_color(ax)
    group_by_l: dict[int, list[_ProjectorItem]] = {}
    # Sort on the scalar fields (l, n, energy) only: if tied items were
    # compared as whole tuples, the comparison would fall through to the
    # arrays and raise `ValueError` (ambiguous truth value);
    # the sort is stable, so tied items keep their incoming order
    for item in sorted(iterator, key=lambda item: item[:3]):
        group_by_l.setdefault(item[0], []).append(item)
    for l, items in group_by_l.items():
        weights = _get_blend_weights(len(items))
        for weight, (_, n, e, pt_g) in zip(weights, items):
            if n == -1:
                label = '*{} ({:.2f} Ha)'.format('spdf'[l], e)
            else:
                label = '%d%s (%.2f Ha)' % (n, 'spdf'[l], e)
            color = _blend_colors(colors[l],
                                  background=bg_color,
                                  foreground_weight=weight)
            ax.plot(r_g, pt_g * r_g, color=color, label=label)
    ax.axis(xmin=0, xmax=cutoff)
    ax.set_title(f'Projectors: {symbol} {name}')
    ax.set_xlabel('radius [Bohr]')
    ax.set_ylabel(r'$r\tilde{p}(r)$')
    ax.legend()
def plot_potential_components(ax: 'Axes',
                              symbol: str,
                              name: str,
                              rgd: AERadialGridDescriptor,
                              cutoff: float,
                              components: dict[str, typing.Array1D]) -> None:
    """
    Plot the available components of the potential onto `ax`.

    Parameters
    ----------
    ax
        Axes to draw on.
    symbol, name
        Only used for composing the plot title.
    rgd
        Radial grid descriptor, supplying the radii (`.r_g`).
    cutoff
        The x-axis spans from 0 to twice this value.
    components
        Non-empty mapping from any of the keys `'xc'`, `'zero'`,
        `'hartree'`, `'pseudo'` and `'all_electron'` to the respective
        array on the radial grid.

    Notes
    -----
    The y-range is fitted to all components but `'all_electron'`,
    disregarding the first grid point, and always includes zero at
    its upper end.
    """
    assert components
    r_g = rgd.r_g
    # (key in `components`, legend label), in plotting (and color) order
    known_components = (('xc', 'xc'),
                        ('zero', '0'),
                        ('hartree', 'H'),
                        ('pseudo', 'ps'),
                        ('all_electron', 'ae'))
    for color, (key, label) in zip(colors, known_components):
        if key not in components:
            continue
        ax.plot(r_g, components[key], color=color, label=label)

    # Gather the extrema which determine the y-range
    lower_bounds = []
    upper_bounds = [0]
    for key, values in components.items():
        if key == 'all_electron':
            continue
        beyond_origin = values[1:]
        lower_bounds.append(beyond_origin.min())
        upper_bounds.append(beyond_origin.max())
    ax.axis(xmin=0,
            xmax=2 * cutoff,
            ymin=min(lower_bounds),
            ymax=max(upper_bounds))
    ax.set_title(f'Potential components: {symbol} {name}')
    ax.set_xlabel('radius [Bohr]')
    ax.set_ylabel('potential [Ha]')
    ax.legend()
def _get_setup_symbol_and_name(setup: SetupData) -> tuple[str, str]:
    """
    Return the `.symbol` and the `.setupname` of `setup` as a pair.
    """
    symbol = setup.symbol
    name = setup.setupname
    return symbol, name
def _get_gen_symbol_and_name(gen: PAWSetupGenerator) -> tuple[str, str]:
    """
    Return the symbol and the XC-functional name of the atom object
    (`gen.aea`) carried by the generator, as a pair.
    """
    atom = gen.aea
    symbol = atom.symbol
    xc_name = atom.xc.name
    return symbol, xc_name
def _get_setup_cutoff(setup: SetupData) -> float:
    """
    Return the cutoff radius to be used when plotting `setup`.

    This is `setup.r0` if available; otherwise the value is looked up
    in the tabulated generator `parameters` under the key
    `f'{setup.symbol}{setup.Nv}'`: an explicit `'r0'` among the extra
    parameters takes precedence, then the smallest of the radii.
    """
    r0 = setup.r0
    if r0 is not None:
        return r0

    # `.r0` can be `None` for 'old setups', whatever that means
    entry = parameters[f'{setup.symbol}{setup.Nv}']
    if len(entry) == 3:
        _, radii, extra = entry
    else:  # No extra parameters tabulated
        _, radii = entry
        extra = {}
    if 'r0' in extra:  # E.g. N5
        tabulated_r0 = extra['r0']
        if TYPE_CHECKING:
            assert isinstance(tabulated_r0, float)
        return tabulated_r0
    # The radii may be tabulated as a bare number
    candidates = radii if isinstance(radii, Iterable) else [radii]
    return min(candidates)
def _normalize_with_radial_grid(array: typing.Array1D,
                                rgd: AERadialGridDescriptor) -> typing.Array1D:
    """
    Return `array` divided pointwise by the radii `rgd.r_g`.

    The first grid point is excluded from the division and set to NaN
    instead (presumably because the grid starts at r = 0).
    """
    normalized = rgd.zeros()
    normalized[1:] = array[1:] / rgd.r_g[1:]
    normalized[0] = np.nan
    return normalized
def _get_blend_weights(n: int, attenuation: float = .5) -> typing.Array1D:
    """
    Return the `n` weights `(1 - attenuation) ** i` for `i = 0 ... n - 1`,
    i.e. a geometric sequence starting at 1.
    """
    retained_fraction = 1 - attenuation
    exponents = np.arange(n)
    return retained_fraction ** exponents
def _get_patch_color(ax: 'Axes') -> tuple[float, float, float]:
    """
    Return the face color of the background patch of `ax` as an RGB
    triplet, falling back to white if it cannot be retrieved or is
    `None`.
    """
    from matplotlib.colors import to_rgb

    try:
        facecolor = ax.patch.get_facecolor()
    except AttributeError:
        facecolor = None
    return to_rgb('w' if facecolor is None else facecolor)
def _blend_colors(color, background='w', foreground_weight=1.):
    """
    Return the RGB array resulting from mixing `color` with
    `background`, where `color` contributes with `foreground_weight`
    and `background` with the remainder (`1 - foreground_weight`).
    """
    # Too troublesome to type this and refactor to have `mypy`
    # understand what we're doing, not worth it
    from matplotlib.colors import to_rgb

    foreground_rgb = np.array(to_rgb(color))
    background_rgb = np.array(to_rgb(background))
    background_weight = 1 - foreground_weight
    return (foreground_rgb * foreground_weight
            + background_rgb * background_weight)
def get_plot_pwaves_params_from_generator(
    gen: PAWSetupGenerator,
) -> tuple[str, str,
           AERadialGridDescriptor, float,
           Iterable[_PartialWaveItem]]:
    """
    Gather from `gen` the arguments `plot_partial_waves()` takes after
    `ax`: symbol, name, radial grid descriptor, cutoff (`gen.rcmax`),
    and a one-shot iterator over the partial-wave items.
    """
    def iter_items(indexed_waves):
        # The position of the waves object in `gen.waves_l` serves as `l`
        for l, waves in indexed_waves:
            per_state = zip(waves.n_n, waves.e_n,
                            waves.phi_ng, waves.phit_ng)
            for n, e, phi_g, phit_g in per_state:
                yield l, n, waves.rcut, e, phi_g, phit_g

    symbol, name = _get_gen_symbol_and_name(gen)
    rgd = gen.rgd
    cutoff = gen.rcmax
    return symbol, name, rgd, cutoff, iter_items(enumerate(gen.waves_l))
def get_plot_pwaves_params_from_setup(
    setup: SetupData,
) -> tuple[str, str,
           AERadialGridDescriptor, float,
           Iterable[_PartialWaveItem]]:
    """
    Gather from `setup` the arguments `plot_partial_waves()` takes after
    `ax`: symbol, name, radial grid descriptor, cutoff, and a one-shot
    iterator over the partial-wave items.
    """
    symbol, name = _get_setup_symbol_and_name(setup)
    rgd = setup.rgd
    cutoff = _get_setup_cutoff(setup)
    # One item per stored state
    items = zip(setup.l_j, setup.n_j, setup.rcut_j, setup.eps_j,
                setup.phi_jg, setup.phit_jg)
    return symbol, name, rgd, cutoff, items
def get_plot_projs_params_from_generator(
    gen: PAWSetupGenerator,
) -> tuple[str, str, AERadialGridDescriptor, float, Iterable[_ProjectorItem]]:
    """
    Gather from `gen` the arguments `plot_projectors()` takes after
    `ax`: symbol, name, radial grid descriptor, cutoff (`gen.rcmax`),
    and a one-shot iterator over the projector items.
    """
    def iter_items(indexed_waves):
        # The position of the waves object in `gen.waves_l` serves as `l`
        for l, waves in indexed_waves:
            for n, e, pt_g in zip(waves.n_n, waves.e_n, waves.pt_ng):
                yield l, n, e, pt_g

    symbol, name = _get_gen_symbol_and_name(gen)
    rgd = gen.rgd
    cutoff = gen.rcmax
    return symbol, name, rgd, cutoff, iter_items(enumerate(gen.waves_l))
def get_plot_projs_params_from_setup(
    setup: SetupData,
) -> tuple[str, str, AERadialGridDescriptor, float, Iterable[_ProjectorItem]]:
    """
    Gather from `setup` the arguments `plot_projectors()` takes after
    `ax`: symbol, name, radial grid descriptor, cutoff, and a one-shot
    iterator over the projector items.
    """
    symbol, name = _get_setup_symbol_and_name(setup)
    rgd = setup.rgd
    cutoff = _get_setup_cutoff(setup)
    # One item per stored projector
    items = zip(setup.l_j, setup.n_j, setup.eps_j, setup.pt_jg)
    return symbol, name, rgd, cutoff, items
def get_plot_pot_comps_params_from_generator(
    gen: PAWSetupGenerator,
) -> tuple[str, str, AERadialGridDescriptor, float, dict[str, typing.Array1D]]:
    """
    Gather from `gen` the arguments `plot_potential_components()` takes
    after `ax`: symbol, name, radial grid descriptor, cutoff
    (`gen.rcmax`), and the mapping of the potential components.

    All components but `'xc'` (taken as-is from `gen.vxct_g`) are
    stored on `gen` multiplied by the radius, and are thus divided by
    the radial grid here.
    """
    assert gen.vtr_g is not None  # Appease `mypy`

    rgd = gen.rgd
    times_r = (('zero', gen.v0r_g),
               ('hartree', gen.vHtr_g),
               ('pseudo', gen.vtr_g),
               ('all_electron', gen.aea.vr_sg[0]))
    components = {'xc': gen.vxct_g}
    for key, array in times_r:
        components[key] = _normalize_with_radial_grid(array, rgd=rgd)
    symbol, name = _get_gen_symbol_and_name(gen)
    return symbol, name, rgd, gen.rcmax, components
def get_plot_pot_comps_params_from_setup(
    setup: SetupData,
) -> tuple[str, str, AERadialGridDescriptor, float, dict[str, typing.Array1D]]:
    """
    Gather from `setup` the arguments `plot_potential_components()`
    takes after `ax`: symbol, name, radial grid descriptor, cutoff, and
    the mapping of the potential components.

    The components comprise `'zero'` (from `setup.vbar_g`),
    `'all_electron'` (recalculated with an `AllElectronAtom`), and, if
    `setup.vt_g` is available, `'pseudo'`.
    """
    symbol, xc_name = _get_setup_symbol_and_name(setup)
    rgd = setup.rgd
    prefactor = (4 * np.pi) ** -.5
    components = {'zero': setup.vbar_g * prefactor}

    # Reconstruct the AEA object
    # (Note: this misses the empty bound states from projectors)
    configuration = list(zip(setup.n_j, setup.l_j,
                             setup.f_j, setup.eps_j))
    aea = AllElectronAtom(symbol, xc_name,
                          Z=setup.Z,
                          configuration=configuration)
    if setup.has_corehole:
        aea.add(setup.ncorehole, setup.lcorehole, -setup.fcorehole)
    aea.initialize(rgd.N)
    aea.run()
    aea.scalar_relativistic = setup.type == 'scalar-relativistic'
    aea.refine()
    components['all_electron'] = _normalize_with_radial_grid(aea.vr_sg[0],
                                                             rgd=rgd)

    # FIXME: inconsistent with the `PAWSetupGenerator` results
    # # Re-calculate the XC and Hartree parts
    # from ..xc import XC
    # from .all_electron import calculate_density, calculate_potentials

    # n_g = calculate_density(setup.f_j,
    #                         setup.phi_jg * rgd.r_g[None, :],
    #                         rgd.r_g)
    # n_g += setup.nc_g * prefactor
    # _, vHr_g, xc, _ = calculate_potentials(rgd, XC(xc_name), n_g,
    #                                        setup.Z)
    # hartree = normalize(vHr_g)
    # components.update(xc=xc, hartree=hartree)

    if setup.vt_g is not None:
        components['pseudo'] = setup.vt_g * prefactor
    return symbol, xc_name, rgd, _get_setup_cutoff(setup), components
def reconstruct_paw_gen(setup: SetupData,
                        basis: Basis | None = None) -> PAWSetupGenerator:
    """
    Regenerate the `PAWSetupGenerator` object from the parameters
    recorded in `setup.generatordata`.

    `v0` is passed as `None` unless the recorded parameters provide it.
    If `basis` is given, it is attached to the generator as `.basis`.
    """
    params: dict[str, Any] = {'v0': None}
    params.update(parse_generator_data(setup.generatordata))
    gen = generate(**params)
    if basis is None:
        return gen
    gen.basis = basis
    return gen
def read_basis_file(basis: str) -> Basis:
    """
    Read the basis set stored in the file `basis`.

    The symbol and the name of the basis set are inferred from the base
    name of the file, which is expected to have the form
    `<symbol>[.<name>].basis[.gz]` (`<name>` may itself contain dots).

    Raises
    ------
    ValueError
        If the base name does not have the expected form.
    """
    file_name = os.path.basename(basis)
    symbol, *chunks = file_name.split('.')
    if chunks and chunks[-1] == 'gz':  # Compressed file
        chunks.pop()
    # Explicit check instead of an `assert`, which would be stripped
    # when running with `-O`
    if not chunks or chunks.pop() != 'basis':
        raise ValueError(
            f'cannot infer the symbol and name of the basis set from '
            f'the file name {file_name!r}; '
            f'expected <symbol>[.<name>].basis[.gz]')
    name = '.'.join(chunks)
    return Basis.read_xml(symbol, name, basis)
def read_setup_file(dataset: str) -> SetupData:
    """
    Read the PAW dataset stored in the (possibly gzipped) XML file
    `dataset`.

    The symbol and the XC name are inferred from the base name of the
    file, which is split at the dots: the first chunk is the symbol,
    and the last one (disregarding a trailing `gz`) is the XC name.

    If the setup comes without `.generatordata` after reading, the
    text of the (one and only) `<generator>` element in the file is
    dedented and used instead.
    """
    symbol, *chunks, xc = os.path.basename(dataset).split('.')
    if xc == 'gz':  # Compressed file: the XC name is one chunk earlier
        *_, xc = chunks
    setup = SetupData(symbol, xc, readxml=False)
    # Read (and possibly decompress) the file only once; the content is
    # also what `minidom.parseString()` is to be fed with below
    source = read_maybe_unzipping(dataset)
    setup.read_xml(source)
    if not setup.generatordata:
        generator, = (minidom.parseString(source)
                      .getElementsByTagName('generator'))
        text, = generator.childNodes
        assert isinstance(text, minidom.Text)
        setup.generatordata = textwrap.dedent(text.data).strip('\n')
    return setup
def parse_generator_data(data: str) -> dict[str, Any]:
    """
    Parse the generator parameters from the text `data`.

    Each line is expected to have the form `<key>=<value>`, optionally
    followed by commas, where `<key>` is an identifier and `<value>` a
    Python literal; whitespace around the line, the key, and the value
    is disregarded.  Lines not conforming to this are skipped.

    Returns
    -------
    Mapping from the keys to the evaluated values, empty if no line
    could be parsed.
    """
    params: dict[str, Any] = {}
    for line in data.splitlines():
        # Strip the whitespace BEFORE the commas: a comma left in front
        # of trailing whitespace would turn the value into a 1-tuple
        key, sep, value = line.strip().rstrip(',').partition('=')
        key = key.strip()
        if not (sep and key.isidentifier()):
            continue
        try:
            params[key] = literal_eval(value.strip())
        except (ValueError, TypeError, SyntaxError,
                MemoryError, RecursionError):
            # What `literal_eval()` is documented to raise on malformed
            # input; such lines are skipped
            continue
    return params
def _get_figures_and_axes(
    ngraphs: int,
    separate_figures: bool = False,
) -> tuple[list['Figure'], list['Axes']]:
    """
    Create the figure(s) and axes for drawing `ngraphs` plots.

    Parameters
    ----------
    ngraphs
        Number of plots.
    separate_figures
        If true, each plot gets a figure of its own; otherwise, the
        plots become panels of one figure, laid out on a grid of one
        row (up to 2 plots), 2-by-2 (up to 4), or 2-by-3 (up to 6).

    Returns
    -------
    Figures and axes, one of each per plot; in the single-figure case,
    the same figure object is repeated.
    """
    from matplotlib import pyplot as plt

    if separate_figures:
        figs = []
        ax_objs = []
        for _ in range(ngraphs):
            fig = plt.figure()
            figs.append(fig)
            ax_objs.append(fig.gca())
        return figs, ax_objs

    assert ngraphs <= 6, f'Too many plots; expected <= 6, got {ngraphs}'
    if ngraphs > 4:
        layout = 2, 3
    elif ngraphs > 2:
        layout = 2, 2
    else:
        layout = 1, ngraphs

    fig = plt.figure()
    # `squeeze=False` makes `.subplots()` always return a 2D array; by
    # default a 1-by-1 layout gives a bare `Axes`, which has no
    # `.flatten()`, so that a single plot could not be drawn
    subplots = fig.subplots(*layout, squeeze=False).flatten()  # type: ignore
    ntrimmed = layout[0] * layout[1] - ngraphs
    if ntrimmed:
        assert ntrimmed > 0, (f'Too many plots {ngraphs!r} '
                              f'for the layout {layout!r}')
        for ax in subplots[-ntrimmed:]:  # Remove unused subplots
            ax.remove()

    return [fig] * ngraphs, subplots[:ngraphs].tolist()
def plot_dataset(
    setup: SetupData,
    *,
    basis: Basis | None = None,
    gen: PAWSetupGenerator | None = None,
    plot_potential_components: bool = True,
    plot_partial_waves: bool = True,
    plot_projectors: bool = True,
    plot_logarithmic_derivatives: str | None = None,
    separate_figures: bool = False,
    savefig: str | None = None,
) -> tuple[list['Axes'], str | None]:
    """
    Plot the requested aspects of a PAW dataset.

    Parameters
    ----------
    setup
        Dataset to plot.
    basis
        Optional basis set; if given, it is drawn in an additional plot
        by a `BasisPlotter`, and attached to the generator object if
        that gets reconstructed.
    gen
        Optional generator object to take the plotted data from.
        If omitted while the logarithmic derivatives and/or the
        potential components are requested, its reconstruction from
        `setup.generatordata` is attempted; if that data is missing or
        malformed, a warning is issued, the logarithmic derivatives
        are dropped, and the remaining data are taken from `setup`.
    plot_potential_components, plot_partial_waves, plot_projectors
        Whether to include the respective plot.
    plot_logarithmic_derivatives
        If not empty, the specification string to be passed on to
        `plot_log_derivs()`.
    separate_figures
        Whether to draw each plot in a figure of its own, instead of as
        panels of the same figure; disregarded if `savefig` is given.
    savefig
        If given, name of the file to save the (single) figure to.

    Return
    ------
    2-tuple: `tuple[list[Axes], <filename> | None]`
    """
    needs_generator = (plot_logarithmic_derivatives
                       or plot_potential_components)
    if gen is None and needs_generator:
        data = setup.generatordata
        if parse_generator_data(data):
            gen = reconstruct_paw_gen(setup, basis)
        else:
            data_status = 'malformed' if data else 'missing'
            msg = ('cannot reconstruct the `PAWSetupGenerator` object '
                   f'({data_status} `setup.generatordata`), '
                   'so the logarithmic derivatives and/or '
                   '(some of) the potential components cannot be plotted')
            # Issued right here (not in a helper), so that `stacklevel`
            # points at the caller of this function
            warn(msg, stacklevel=2)
            plot_logarithmic_derivatives = None

    if gen is None:
        (symbol, name,
         rgd, cutoff, ppw_iter) = get_plot_pwaves_params_from_setup(setup)
        *_, pp_iter = get_plot_projs_params_from_setup(setup)
    else:
        # TODO: maybe we can compare the `ppw_iter` and `pp_iter`
        # between the stored and regenerated values for verification
        (symbol, name,
         rgd, cutoff, ppw_iter) = get_plot_pwaves_params_from_generator(gen)
        *_, pp_iter = get_plot_projs_params_from_generator(gen)

    plots: list[Callable] = []

    if plot_logarithmic_derivatives:
        assert gen is not None
        plots.append(functools.partial(
            plot_log_derivs, gen, plot_logarithmic_derivatives, True))
    if plot_potential_components:
        # Only gather the components if they are indeed to be plotted:
        # without a generator object, this entails setting up and
        # running an `AllElectronAtom` calculation
        if gen is None:
            *_, pot_comps = get_plot_pot_comps_params_from_setup(setup)
        else:
            *_, pot_comps = get_plot_pot_comps_params_from_generator(gen)
        plots.append(functools.partial(
            # Name clash with local variable
            globals()['plot_potential_components'],
            symbol=symbol, name=name, rgd=rgd, cutoff=cutoff,
            components=pot_comps))
    if plot_partial_waves:
        plots.append(functools.partial(
            # Name clash with local variable
            globals()['plot_partial_waves'],
            symbol=symbol, name=name, rgd=rgd, cutoff=cutoff,
            iterator=ppw_iter))
    if plot_projectors:
        plots.append(functools.partial(
            # Name clash with local variable
            globals()['plot_projectors'],
            symbol=symbol, name=name, rgd=rgd, cutoff=cutoff,
            iterator=pp_iter))

    if basis is not None:
        plots.append(functools.partial(BasisPlotter().plot, basis))

    if savefig is not None:
        # Everything has to end up in the one figure that is saved
        separate_figures = False
    figs, ax_objs = _get_figures_and_axes(len(plots), separate_figures)
    assert len(figs) == len(ax_objs) == len(plots)
    for ax, plot_func in zip(ax_objs, plots):
        plot_func(ax=ax)

    if savefig is not None:
        assert len({id(fig) for fig in figs}) == 1
        fig, *_ = figs
        assert fig is not None
        fig.savefig(savefig)

    return ax_objs, savefig
def main(args: SimpleNamespace) -> list['Axes']:
    """
    Entry point of the command-line interface: read the dataset named
    by `args.dataset` (after resolving it with `search_for_file()` if
    `args.search` is set), plot it according to the other options, and
    show the plots unless they have been written to `args.outfile`.

    Returns the axes objects drawn on.
    """
    from matplotlib import pyplot as plt

    if args.search:
        args.dataset, _ = search_for_file(args.dataset)
    setup = read_setup_file(args.dataset)
    # Separate figures only make sense when not writing to a file
    sep_figs = args.outfile is None and args.separate_figures
    plot_options = dict(
        separate_figures=sep_figs,
        plot_potential_components=args.potential_components,
        plot_logarithmic_derivatives=args.logarithmic_derivatives,
        savefig=args.outfile)
    ax_objs, written_to = plot_dataset(setup, **plot_options)
    assert ax_objs

    if written_to is None:
        plt.show()
    return ax_objs
class CLICommand:
    """Plot the PAW dataset,
    which by default includes the partial waves and the projectors.
    """
    # Command-line front end: `add_arguments()` declares the options,
    # and `run()` (i.e. `main()`) consumes the parsed arguments

    @staticmethod
    def add_arguments(parser):
        # (names/flags, keyword arguments), in the order of registration
        argument_specs = [
            (('-p', '--potential-components'),
             dict(action='store_true',
                  help='Plot the potential components '
                  '(this reconstructs the full PAW setup generator '
                  'object)')),
            (('-l', '--logarithmic-derivatives'),
             dict(metavar='spdfg,e1:e2:de,radius',
                  help='Plot logarithmic derivatives '
                  '(this reconstructs the full PAW setup generator '
                  'object). '
                  'Example: -l spdf,-1:1:0.05,1.3. '
                  'Energy range and/or radius can be left out.')),
            (('-s', '--separate-figures'),
             dict(action='store_true',
                  help='If not plotting to a file, '
                  'plot the plots in separate figure windows/tabs, '
                  'instead of as subplots/panels in the same figure')),
            (('-S', '--search'),
             dict(action='store_true',
                  help='Look into the installed datasets '
                  '(see `gpaw info`) for the '
                  'XML file, instead of treating it as a path')),
            (('-o', '--outfile', '--write'),
             dict(metavar='FILE',
                  help='Write the plots to FILE instead of '
                  '`plt.show()`-ing them')),
            (('dataset',),
             dict(metavar='FILE',
                  help='XML file from which to read the PAW dataset')),
        ]
        for names, options in argument_specs:
            parser.add_argument(*names, **options)

    run = staticmethod(main)