Coverage for gpaw/atom/plot_dataset.py: 89%

287 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-19 00:19 +0000

1from __future__ import annotations 

2 

3import functools 

4import textwrap 

5import os 

6from ast import literal_eval 

7from collections.abc import Callable, Iterable 

8from types import SimpleNamespace 

9from typing import Any, TYPE_CHECKING 

10from xml.dom import minidom 

11from warnings import warn 

12 

13import numpy as np 

14 

15from .. import typing 

16from ..basis_data import Basis, BasisPlotter 

17from ..setup_data import SetupData, read_maybe_unzipping, search_for_file 

18from .aeatom import AllElectronAtom, colors 

19from .generator2 import (PAWSetupGenerator, parameters, 

20 generate, plot_log_derivs) 

21from .radialgd import AERadialGridDescriptor 

22 

23if TYPE_CHECKING: 

24 from matplotlib.axes import Axes 

25 from matplotlib.figure import Figure 

26 

27 

# One partial wave, as consumed by `plot_partial_waves()`
_PartialWaveItem = tuple[
    int,  # l
    int,  # n (-1 gets a '*' label and is truncated at r_cut when plotted)
    float,  # r_cut
    float,  # energy
    typing.Array1D,  # phi_g
    typing.Array1D,  # phit_g
]
# One projector function, as consumed by `plot_projectors()`
_ProjectorItem = tuple[
    int,  # l
    int,  # n (-1 gets a '*' label when plotted)
    float,  # energy
    typing.Array1D,  # pt_g
]

38 

39 

def plot_partial_waves(ax: 'Axes',
                       symbol: str,
                       name: str,
                       rgd: AERadialGridDescriptor,
                       cutoff: float,
                       iterator: Iterable[_PartialWaveItem]) -> None:
    """Plot ``r * phi(r)`` (solid) and ``r * phit(r)`` (dashed) per wave.

    Waves sharing the same ``l`` share a base color, which is blended
    progressively into the axes background so that they stay apart.

    Parameters
    ----------
    ax
        Axes to draw on.
    symbol, name
        Used in the plot title.
    rgd
        Radial grid on which the waves are tabulated.
    cutoff
        The x-axis extends to ``3 * cutoff``.
    iterator
        ``(l, n, r_cut, energy, phi_g, phit_g)`` tuples; waves with
        ``n == -1`` are truncated at ``r_cut`` and labelled with a ``*``
        instead of their ``n``.
    """
    letters = 'spdfghi'  # Angular-momentum letters for l = 0, 1, ..., 6
    r_g = rgd.r_g
    group_by_l: dict[int, list[_PartialWaveItem]] = {}
    bg_color = _get_patch_color(ax)
    # Sort on the scalar fields `(l, n, r_cut, energy)` only: when these
    # all tie, plain tuple comparison would fall through to the numpy
    # arrays and raise a `ValueError`; the stable sort keeps ties in input
    # order instead
    for item in sorted(iterator, key=lambda wave: wave[:4]):
        group_by_l.setdefault(item[0], []).append(item)
    for l, items in group_by_l.items():
        weights = _get_blend_weights(len(items))
        for weight, (_, n, rcut, e, phi_g, phit_g) in zip(weights, items):
            if n == -1:
                gc = rgd.ceil(rcut)
                label = '*{} ({:.2f} Ha)'.format(letters[l], e)
            else:
                gc = len(rgd)
                label = '%d%s (%.2f Ha)' % (n, letters[l], e)
            color = _blend_colors(colors[l],
                                  background=bg_color,
                                  foreground_weight=weight)
            ax.plot(r_g[:gc], (phi_g * r_g)[:gc], color=color, label=label)
            ax.plot(r_g[:gc], (phit_g * r_g)[:gc], '--', color=color)
    ax.axis(xmin=0, xmax=3 * cutoff)
    ax.set_title(f'Partial waves: {symbol} {name}')
    ax.set_xlabel('radius [Bohr]')
    ax.set_ylabel(r'$r\phi_{n\ell}(r)$')
    ax.legend()

70 

71 

def plot_projectors(ax: 'Axes',
                    symbol: str,
                    name: str,
                    rgd: AERadialGridDescriptor,
                    cutoff: float,
                    iterator: Iterable[_ProjectorItem]) -> None:
    """Plot ``r * pt(r)`` for each projector function.

    Projectors sharing the same ``l`` share a base color, which is
    blended progressively into the axes background.

    Parameters
    ----------
    ax
        Axes to draw on.
    symbol, name
        Used in the plot title.
    rgd
        Radial grid on which the projectors are tabulated.
    cutoff
        The x-axis extends to ``cutoff``.
    iterator
        ``(l, n, energy, pt_g)`` tuples; those with ``n == -1`` are
        labelled with a ``*`` instead of their ``n``.
    """
    letters = 'spdfghi'  # Angular-momentum letters for l = 0, 1, ..., 6
    r_g = rgd.r_g
    group_by_l: dict[int, list[_ProjectorItem]] = {}
    bg_color = _get_patch_color(ax)
    # Sort on the scalar fields `(l, n, energy)` only: when these all tie,
    # plain tuple comparison would fall through to the numpy arrays and
    # raise a `ValueError`; the stable sort keeps ties in input order
    for item in sorted(iterator, key=lambda proj: proj[:3]):
        group_by_l.setdefault(item[0], []).append(item)
    for l, items in group_by_l.items():
        weights = _get_blend_weights(len(items))
        for weight, (_, n, e, pt_g) in zip(weights, items):
            if n == -1:
                label = '*{} ({:.2f} Ha)'.format(letters[l], e)
            else:
                label = '%d%s (%.2f Ha)' % (n, letters[l], e)
            color = _blend_colors(colors[l],
                                  background=bg_color,
                                  foreground_weight=weight)
            ax.plot(r_g, pt_g * r_g, color=color, label=label)
    ax.axis(xmin=0, xmax=cutoff)
    ax.set_title(f'Projectors: {symbol} {name}')
    ax.set_xlabel('radius [Bohr]')
    ax.set_ylabel(r'$r\tilde{p}(r)$')
    ax.legend()

100 

101 

def plot_potential_components(ax: 'Axes',
                              symbol: str,
                              name: str,
                              rgd: AERadialGridDescriptor,
                              cutoff: float,
                              components: dict[str, typing.Array1D]) -> None:
    """Plot the potential components on the radial grid.

    Recognized keys of `components` and their legend labels:
    ``'xc'`` (xc), ``'zero'`` (0), ``'hartree'`` (H), ``'pseudo'`` (ps) and
    ``'all_electron'`` (ae). Other keys are not plotted, but still count
    towards the y-limits.

    The x-axis extends to ``2 * cutoff``. The y-range covers every
    component except ``'all_electron'`` (unless it is the only one),
    skipping the first grid point and ignoring NaNs, and always
    includes 0.
    """
    assert components
    radial_grid = rgd.r_g
    for color, (key, label) in zip(
            colors,
            [('xc', 'xc'), ('zero', '0'), ('hartree', 'H'),
             ('pseudo', 'ps'), ('all_electron', 'ae')]):
        if key in components:
            ax.plot(radial_grid, components[key], color=color, label=label)
    # The all-electron potential is left out of the y-limits; fall back on
    # it only when nothing else is available (avoids `min()` of nothing)
    arrays = [array for key, array in components.items()
              if key != 'all_electron'] or list(components.values())
    # NaN-aware reductions: a NaN would otherwise propagate into the axis
    # limits, which matplotlib rejects
    ymin = min(np.nanmin(array[1:]) for array in arrays)
    ymax = max(0, *(np.nanmax(array[1:]) for array in arrays))
    ax.axis(xmin=0, xmax=2 * cutoff, ymin=ymin, ymax=ymax)
    ax.set_title(f'Potential components: {symbol} {name}')
    ax.set_xlabel('radius [Bohr]')
    ax.set_ylabel('potential [Ha]')
    ax.legend()

126 

127 

def _get_setup_symbol_and_name(setup: SetupData) -> tuple[str, str]:
    """Return the ``(setup.symbol, setup.setupname)`` pair."""
    element = setup.symbol
    setup_name = setup.setupname
    return element, setup_name

130 

131 

def _get_gen_symbol_and_name(gen: PAWSetupGenerator) -> tuple[str, str]:
    """Return the symbol and the `xc.name` of the generator's `aea`."""
    atom = gen.aea
    xc_name = atom.xc.name
    return atom.symbol, xc_name

135 

136 

def _get_setup_cutoff(setup: SetupData) -> float:
    """Return the cutoff radius used to scale the plot x-axes.

    This is `setup.r0` if set. Otherwise it is looked up in the generator
    `parameters` under ``f'{symbol}{Nv}'``: the ``'r0'`` entry of the
    optional third (extra) item if present, else the smallest of the
    cutoff radii (which may be a single number).
    """
    r0 = setup.r0
    if r0 is not None:
        return r0

    # `.r0` may be `None` for older setups
    params = parameters[f'{setup.symbol}{setup.Nv}']
    if len(params) == 3:
        _, radii, extra = params
    else:
        _, radii = params
        extra = {}

    if 'r0' in extra:  # E.g. N5
        explicit_r0 = extra['r0']
        if TYPE_CHECKING:
            assert isinstance(explicit_r0, float)
        return explicit_r0

    if isinstance(radii, Iterable):
        return min(radii)
    return radii  # A lone radius

158 

159 

def _normalize_with_radial_grid(array: typing.Array1D,
                                rgd: AERadialGridDescriptor) -> typing.Array1D:
    """Return a new grid array holding ``array / rgd.r_g``.

    The first point is not divided but set to NaN. This presumably avoids
    a division by ``r_g[0] == 0`` at the origin.
    """
    r_g = rgd.r_g
    normalized = rgd.zeros()
    normalized[1:] = array[1:] / r_g[1:]
    normalized[0] = np.nan
    return normalized

166 

167 

def _get_blend_weights(n: int, attenuation: float = .5) -> typing.Array1D:
    """Return ``n`` weights ``1, r, r**2, ...`` with ``r = 1 - attenuation``."""
    ratio = 1 - attenuation
    exponents = np.arange(n)
    return ratio ** exponents

170 

171 

def _get_patch_color(ax: 'Axes') -> tuple[float, float, float]:
    """Return the RGB face color of the axes' background patch.

    Falls back on white (``'w'``) if `ax` has no patch or the face color
    is `None`.
    """
    from matplotlib.colors import to_rgb
    try:
        facecolor = ax.patch.get_facecolor()
    except AttributeError:
        facecolor = None
    if facecolor is None:
        facecolor = 'w'
    return to_rgb(facecolor)

181 

182 

def _blend_colors(color, background='w', foreground_weight=1.):
    """Linearly mix `color` into `background` and return an RGB array.

    The result is ``foreground_weight * color
    + (1 - foreground_weight) * background``.
    """
    # Too troublesome to type this and refactor to have `mypy`
    # understand what we're doing, not worth it
    from matplotlib.colors import to_rgb

    fg_rgb, bg_rgb = (np.array(to_rgb(c)) for c in (color, background))
    return fg_rgb * foreground_weight + bg_rgb * (1 - foreground_weight)

191 

192 

def get_plot_pwaves_params_from_generator(
    gen: PAWSetupGenerator,
) -> tuple[str, str,
           AERadialGridDescriptor, float,
           Iterable[_PartialWaveItem]]:
    """Collect the `plot_partial_waves()` arguments (except `ax`) from `gen`.

    Returns ``(symbol, name, rgd, cutoff, waves)``, where `cutoff` is
    `gen.rcmax` and `waves` lazily yields one
    ``(l, n, rcut, e, phi_g, phit_g)`` tuple per partial wave.
    """
    symbol, name = _get_gen_symbol_and_name(gen)
    rgd = gen.rgd
    cutoff = gen.rcmax
    waves = ((l, n, waves_for_l.rcut, e, phi_g, phit_g)
             for l, waves_for_l in enumerate(gen.waves_l)
             for n, e, phi_g, phit_g in zip(waves_for_l.n_n,
                                            waves_for_l.e_n,
                                            waves_for_l.phi_ng,
                                            waves_for_l.phit_ng))
    return symbol, name, rgd, cutoff, waves

205 

206 

def get_plot_pwaves_params_from_setup(
    setup: SetupData,
) -> tuple[str, str,
           AERadialGridDescriptor, float,
           Iterable[_PartialWaveItem]]:
    """Collect the `plot_partial_waves()` arguments (except `ax`) from `setup`.

    Returns ``(symbol, name, rgd, cutoff, waves)``, where `waves` zips the
    per-wave arrays ``l_j, n_j, rcut_j, eps_j, phi_jg, phit_jg``.
    """
    symbol, name = _get_setup_symbol_and_name(setup)
    rgd = setup.rgd
    cutoff = _get_setup_cutoff(setup)
    waves = zip(setup.l_j, setup.n_j, setup.rcut_j, setup.eps_j,
                setup.phi_jg, setup.phit_jg)
    return symbol, name, rgd, cutoff, waves

217 

218 

def get_plot_projs_params_from_generator(
    gen: PAWSetupGenerator,
) -> tuple[str, str, AERadialGridDescriptor, float, Iterable[_ProjectorItem]]:
    """Collect the `plot_projectors()` arguments (except `ax`) from `gen`.

    Returns ``(symbol, name, rgd, cutoff, projectors)``, where `cutoff` is
    `gen.rcmax` and `projectors` lazily yields ``(l, n, e, pt_g)`` tuples.
    """
    symbol, name = _get_gen_symbol_and_name(gen)
    rgd = gen.rgd
    cutoff = gen.rcmax
    projectors = ((l, n, e, pt_g)
                  for l, waves_for_l in enumerate(gen.waves_l)
                  for n, e, pt_g in zip(waves_for_l.n_n,
                                        waves_for_l.e_n,
                                        waves_for_l.pt_ng))
    return symbol, name, rgd, cutoff, projectors

228 

229 

def get_plot_projs_params_from_setup(
    setup: SetupData,
) -> tuple[str, str, AERadialGridDescriptor, float, Iterable[_ProjectorItem]]:
    """Collect the `plot_projectors()` arguments (except `ax`) from `setup`.

    Returns ``(symbol, name, rgd, cutoff, projectors)``, where
    `projectors` zips ``l_j, n_j, eps_j, pt_jg``.
    """
    symbol, name = _get_setup_symbol_and_name(setup)
    rgd = setup.rgd
    cutoff = _get_setup_cutoff(setup)
    projectors = zip(setup.l_j, setup.n_j, setup.eps_j, setup.pt_jg)
    return symbol, name, rgd, cutoff, projectors

237 

238 

def get_plot_pot_comps_params_from_generator(
    gen: PAWSetupGenerator,
) -> tuple[str, str, AERadialGridDescriptor, float, dict[str, typing.Array1D]]:
    """Collect the `plot_potential_components()` arguments from `gen`.

    Returns ``(symbol, name, rgd, cutoff, components)``. `components` maps
    ``'xc'`` to `gen.vxct_g` as-is. The ``'zero'``, ``'hartree'``,
    ``'pseudo'`` and ``'all_electron'`` arrays (`v0r_g`, `vHtr_g`, `vtr_g`,
    `aea.vr_sg[0]`) are divided by the radial grid first.
    """
    assert gen.vtr_g is not None  # Appease `mypy`

    rgd = gen.rgd
    normalize = functools.partial(_normalize_with_radial_grid, rgd=rgd)
    r_weighted = (('zero', gen.v0r_g),
                  ('hartree', gen.vHtr_g),
                  ('pseudo', gen.vtr_g),
                  ('all_electron', gen.aea.vr_sg[0]))
    normalized = {key: normalize(array) for key, array in r_weighted}
    components = {'xc': gen.vxct_g, **normalized}
    symbol, name = _get_gen_symbol_and_name(gen)
    return symbol, name, rgd, gen.rcmax, components

256 

257 

def get_plot_pot_comps_params_from_setup(
    setup: SetupData,
) -> tuple[str, str, AERadialGridDescriptor, float, dict[str, typing.Array1D]]:
    """Collect the `plot_potential_components()` arguments from `setup`.

    Returns ``(symbol, xc_name, rgd, cutoff, components)``. `components`
    holds ``'zero'`` (`vbar_g` scaled by ``(4 pi)**-.5``) and
    ``'all_electron'``. The latter is recomputed by re-running an
    `AllElectronAtom` with the setup's configuration and divided by the
    radial grid. When `setup.vt_g` is not `None`, ``'pseudo'`` (the
    scaled `vt_g`) is added as well.
    """
    prefactor = (4 * np.pi) ** -.5
    zero = setup.vbar_g * prefactor
    pseudo = None if setup.vt_g is None else setup.vt_g * prefactor
    symbol, xc_name = _get_setup_symbol_and_name(setup)
    rgd = setup.rgd

    # Reconstruct the AEA object
    # (Note: this misses the empty bound states from projectors)
    configuration = list(zip(setup.n_j, setup.l_j, setup.f_j, setup.eps_j))
    aea = AllElectronAtom(symbol, xc_name,
                          Z=setup.Z,
                          configuration=configuration)
    if setup.has_corehole:
        aea.add(setup.ncorehole, setup.lcorehole, -setup.fcorehole)
    aea.initialize(rgd.N)
    aea.run()
    # The scalar-relativistic flag is only set after the initial `run()`,
    # so it only affects the subsequent `refine()`
    aea.scalar_relativistic = setup.type == 'scalar-relativistic'
    aea.refine()
    components = {
        'zero': zero,
        'all_electron': _normalize_with_radial_grid(aea.vr_sg[0], rgd=rgd),
    }

    # FIXME: inconsistent with the `PAWSetupGenerator` results
    # # Re-calculate the XC and Hartree parts
    # from ..xc import XC
    # from .all_electron import calculate_density, calculate_potentials

    # n_g = calculate_density(setup.f_j,
    #                         setup.phi_jg * rgd.r_g[None, :],
    #                         rgd.r_g)
    # n_g += setup.nc_g * prefactor
    # _, vHr_g, xc, _ = calculate_potentials(rgd, XC(xc_name), n_g,
    #                                        setup.Z)
    # hartree = normalize(vHr_g)
    # components.update(xc=xc, hartree=hartree)

    if pseudo is not None:
        components['pseudo'] = pseudo
    cutoff = _get_setup_cutoff(setup)
    return symbol, xc_name, rgd, cutoff, components

303 

304 

def reconstruct_paw_gen(setup: SetupData,
                        basis: Basis | None = None) -> PAWSetupGenerator:
    """Re-run `generate()` with the parameters parsed from `setup`.

    The keyword arguments come from `setup.generatordata`, with ``v0``
    defaulting to `None` if not stored there. If `basis` is given, it is
    attached to the returned generator as `.basis`.
    """
    generator_kwargs: dict[str, Any] = dict(v0=None)
    generator_kwargs.update(parse_generator_data(setup.generatordata))
    gen = generate(**generator_kwargs)
    if basis is None:
        return gen
    gen.basis = basis
    return gen

312 

313 

def read_basis_file(basis: str) -> Basis:
    """Read a basis set from the file `basis`.

    The file name must look like ``<symbol>[.<name>].basis[.gz]``. The
    symbol and the (possibly empty, possibly dotted) name are taken
    from it.

    Raises
    ------
    ValueError
        If the file name does not follow that pattern.
    """
    fields = os.path.basename(basis).split('.')
    if fields[-1] == 'gz':
        fields = fields[:-1]
    # Explicit check instead of `assert`, which is stripped under `-O`
    if len(fields) < 2 or fields[-1] != 'basis':
        raise ValueError(f'cannot parse the basis file name {basis!r}; '
                         'expected `<symbol>[.<name>].basis[.gz]`')
    symbol, *chunks, _ = fields
    return Basis.read_xml(symbol, '.'.join(chunks), basis)

321 

322 

def read_setup_file(dataset: str) -> SetupData:
    """Read a PAW dataset from an (optionally gzipped) XML file.

    The element symbol and the XC name are taken from the file name,
    which is expected to look like ``<symbol>[.<name>].<xc>[.gz]``.

    If the parsed setup has no `generatordata`, it is recovered from the
    text of the XML's ``<generator>`` element, dedented and with
    surrounding newlines stripped. If there is not exactly one such
    element, or it holds no text, `generatordata` is left untouched
    instead of raising, so that callers can deal with the missing data.
    """
    fields = os.path.basename(dataset).split('.')
    if fields[-1] == 'gz':
        fields = fields[:-1]
    symbol, *_, xc = fields
    setup = SetupData(symbol, xc, readxml=False)
    # Read (and possibly decompress) the file once, reuse for both parsers
    source = read_maybe_unzipping(dataset)
    setup.read_xml(source)
    if not setup.generatordata:
        generators = (minidom.parseString(source)
                      .getElementsByTagName('generator'))
        if len(generators) == 1:
            # Gather text (incl. CDATA) nodes; skip e.g. comment nodes
            texts = [node.data for node in generators[0].childNodes
                     if isinstance(node, minidom.Text)]
            if texts:
                setup.generatordata = (textwrap.dedent(''.join(texts))
                                       .strip('\n'))
    return setup

336 

337 

def parse_generator_data(data: str | None) -> dict[str, Any]:
    """Parse ``key=value`` lines (e.g. `setup.generatordata`) into a dict.

    Each line is stripped of surrounding whitespace and trailing commas.
    Then it is split at the first ``=``. The key (whitespace-stripped) must
    be an identifier, and the value must be a Python literal (see
    `ast.literal_eval`). Lines failing either test are skipped.

    Parameters
    ----------
    data
        The text to parse; `None` or empty text yields an empty dict.

    Returns
    -------
    Mapping from keys to evaluated values (later lines win on duplicates).
    """
    params: dict[str, Any] = {}
    if not data:
        return params
    for line in data.splitlines():
        # Strip whitespace *before* the commas so that e.g. 'x=1, ' is not
        # evaluated as the tuple `(1,)`
        key, sep, value = line.strip().rstrip(',').partition('=')
        key = key.strip()
        if not (sep and key.isidentifier()):
            continue
        try:
            params[key] = literal_eval(value.strip())
        except (ValueError, TypeError, SyntaxError,
                MemoryError, RecursionError):
            # Not a literal (e.g. a continuation line): skip it
            continue
    return params

350 

351 

def _get_figures_and_axes(
        ngraphs: int,
        separate_figures: bool = False) -> tuple[list['Figure'], list['Axes']]:
    """Create the figure(s) and ``ngraphs`` axes to draw on.

    Parameters
    ----------
    ngraphs
        Number of axes needed; at most 6 unless `separate_figures`.
    separate_figures
        If true, each axes gets its own figure. Otherwise all axes are
        panels of one figure laid out as 1xN (N <= 2), 2x2 (N <= 4) or
        2x3 (N <= 6), with unused panels removed.

    Returns
    -------
    ``(figs, axes)``, both of length ``ngraphs``; ``figs[i]`` contains
    ``axes[i]`` (the same figure repeated when sharing one). Both lists are
    empty if ``ngraphs`` is 0, in which case no figure is created.
    """
    from matplotlib import pyplot as plt

    if separate_figures:
        figs = []
        ax_objs = []
        for _ in range(ngraphs):
            fig = plt.figure()
            figs.append(fig)
            ax_objs.append(fig.gca())
        return figs, ax_objs

    if ngraphs <= 0:
        # `fig.subplots(1, 0)` would fail; there is nothing to draw anyway
        return [], []
    assert ngraphs <= 6, f'Too many plots; expected <= 6, got {ngraphs}'
    if ngraphs > 4:
        layout = 2, 3
    elif ngraphs > 2:
        layout = 2, 2
    else:
        layout = 1, ngraphs

    fig = plt.figure()
    # `squeeze=False` always gives a 2D array of axes; with the default,
    # a 1x1 layout returns a bare `Axes`, which has no `.flatten()`
    subplots = fig.subplots(*layout, squeeze=False).flatten()  # type: ignore
    ntrimmed = layout[0] * layout[1] - ngraphs
    if ntrimmed:
        assert ntrimmed > 0, (f'Too many plots {ngraphs!r} '
                              f'for the layout {layout!r}')
        for ax in subplots[-ntrimmed:]:  # Remove unused subplots
            ax.remove()

    return [fig] * ngraphs, subplots[:ngraphs].tolist()

384 

385 

def plot_dataset(
    setup: SetupData,
    *,
    basis: Basis | None = None,
    gen: PAWSetupGenerator | None = None,
    plot_potential_components: bool = True,
    plot_partial_waves: bool = True,
    plot_projectors: bool = True,
    plot_logarithmic_derivatives: str | None = None,
    separate_figures: bool = False,
    savefig: str | None = None,
) -> tuple[list['Axes'], str | None]:
    """
    Plot (parts of) a PAW dataset.

    Up to these plots are drawn, in order: logarithmic derivatives (if
    `plot_logarithmic_derivatives` is given), potential components,
    partial waves, projectors and the basis (if `basis` is given).

    If `gen` is not given but the logarithmic derivatives or the potential
    components are requested, a `PAWSetupGenerator` is reconstructed from
    `setup.generatordata`. If that data is missing or unparsable, a warning
    is issued and the logarithmic derivatives are skipped. Whenever a
    generator is available, the plotted data come from it rather than from
    `setup`.

    If `savefig` is given, all plots are drawn as panels of one figure,
    which is written to that file.

    Return
    ------
    2-tuple: `tuple[list[Axes], <filename> | None]`
    """
    wants_gen = bool(plot_logarithmic_derivatives
                     or plot_potential_components)
    if gen is None and wants_gen:
        data = setup.generatordata
        if not parse_generator_data(data):
            data_status = 'malformed' if data else 'missing'
            msg = ('cannot reconstruct the `PAWSetupGenerator` object '
                   f'({data_status} `setup.generatordata`), '
                   'so the logarithmic derivatives and/or '
                   '(some of) the potential components cannot be plotted')
            warn(msg, stacklevel=2)
            plot_logarithmic_derivatives = None
        else:
            gen = reconstruct_paw_gen(setup, basis)

    if gen is not None:
        # TODO: maybe we can compare the `ppw_iter` and `pp_iter`
        # between the stored and regenerated values for verification
        (symbol, name,
         rgd, cutoff, ppw_iter) = get_plot_pwaves_params_from_generator(gen)
        pp_iter = get_plot_projs_params_from_generator(gen)[-1]
        pot_comps = get_plot_pot_comps_params_from_generator(gen)[-1]
    else:
        (symbol, name,
         rgd, cutoff, ppw_iter) = get_plot_pwaves_params_from_setup(setup)
        pp_iter = get_plot_projs_params_from_setup(setup)[-1]
        pot_comps = get_plot_pot_comps_params_from_setup(setup)[-1]

    plots: list[Callable] = []
    if plot_logarithmic_derivatives:
        assert gen is not None
        plots.append(functools.partial(
            plot_log_derivs, gen, plot_logarithmic_derivatives, True))

    # The boolean flags shadow the module-level plotting functions of the
    # same names, so the latter are looked up in `globals()`
    shared_kwargs = dict(symbol=symbol, name=name, rgd=rgd, cutoff=cutoff)
    optional_plots = [
        (plot_potential_components, 'plot_potential_components',
         dict(components=pot_comps)),
        (plot_partial_waves, 'plot_partial_waves', dict(iterator=ppw_iter)),
        (plot_projectors, 'plot_projectors', dict(iterator=pp_iter)),
    ]
    for requested, func_name, data_kwargs in optional_plots:
        if requested:
            plots.append(functools.partial(
                globals()[func_name], **shared_kwargs, **data_kwargs))

    if basis is not None:
        plots.append(functools.partial(BasisPlotter().plot, basis))

    # Saving to a file requires everything to be in a single figure
    use_separate_figures = separate_figures if savefig is None else False
    figs, ax_objs = _get_figures_and_axes(len(plots), use_separate_figures)
    assert len(figs) == len(ax_objs) == len(plots)
    for ax, plot_func in zip(ax_objs, plots):
        plot_func(ax=ax)

    if savefig is not None:
        assert len({id(fig) for fig in figs}) == 1
        shared_fig = figs[0]
        assert shared_fig is not None
        shared_fig.savefig(savefig)

    return ax_objs, savefig

480 

481 

def main(args: SimpleNamespace) -> list['Axes']:
    """Run the command: read `args.dataset` and plot it.

    With `args.search`, the dataset is first located with
    `search_for_file()`. The plots are written to `args.outfile` if given;
    otherwise they are shown with `plt.show()`. Returns the plotted axes.
    """
    from matplotlib import pyplot as plt

    if args.search:
        found_path, _ = search_for_file(args.dataset)
        args.dataset = found_path
    setup = read_setup_file(args.dataset)
    ax_objs, fname = plot_dataset(
        setup,
        separate_figures=args.outfile is None and args.separate_figures,
        plot_potential_components=args.potential_components,
        plot_logarithmic_derivatives=args.logarithmic_derivatives,
        savefig=args.outfile,
    )
    assert ax_objs

    if fname is None:  # Nothing was written to a file
        plt.show()
    return ax_objs

500 

501 

class CLICommand:
    """Plot the PAW dataset,
    which by default includes the partial waves and the projectors.
    """
    # NOTE(review): the class docstring presumably doubles as the command's
    # help text, so it is kept verbatim

    @staticmethod
    def add_arguments(parser):
        """Register the command's options on `parser` (argparse-style)."""
        option_specs = [
            (('-p', '--potential-components'),
             dict(action='store_true',
                  help='Plot the potential components '
                  '(this reconstructs the full PAW setup generator object)')),
            (('-l', '--logarithmic-derivatives'),
             dict(metavar='spdfg,e1:e2:de,radius',
                  help='Plot logarithmic derivatives '
                  '(this reconstructs the full PAW setup generator object). '
                  'Example: -l spdf,-1:1:0.05,1.3. '
                  'Energy range and/or radius can be left out.')),
            (('-s', '--separate-figures'),
             dict(action='store_true',
                  help='If not plotting to a file, '
                  'plot the plots in separate figure windows/tabs, '
                  'instead of as subplots/panels in the same figure')),
            (('-S', '--search'),
             dict(action='store_true',
                  help='Look into the installed datasets (see `gpaw info`) for the '
                  'XML file, instead of treating it as a path')),
            (('-o', '--outfile', '--write'),
             dict(metavar='FILE',
                  help='Write the plots to FILE instead of `plt.show()`-ing them')),
            (('dataset',),
             dict(metavar='FILE',
                  help='XML file from which to read the PAW dataset')),
        ]
        for flags, options in option_specs:
            parser.add_argument(*flags, **options)

    run = staticmethod(main)