Coverage for gpaw/test/cli/test_plot_dataset.py: 100%

78 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-14 00:18 +0000

1import argparse 

2import contextlib 

3import dataclasses 

4import os 

5import shlex 

6import subprocess 

7import sys 

8from typing import ContextManager 

9 

10import pytest 

11 

12from gpaw.setup_data import search_for_file 

13from gpaw.atom.plot_dataset import CLICommand, read_setup_file 

14 

15 

@dataclasses.dataclass
class SetupInfo:
    """
    Bundle of what a test needs to know about one setup (dataset) file.
    """
    # Name of the setup file, as handed to `search_for_file()`
    filename: str
    # Number of subplots to expect when plotting as much as the dataset
    # permits
    nplots: int
    # Context manager to run the plotting under (e.g. for catching the
    # expected warnings); a fresh do-nothing one if not supplied
    ctx: ContextManager = dataclasses.field(
        default_factory=contextlib.nullcontext,
    )

22 

23 

@pytest.fixture(scope='module')
def old_setup():
    """
    Legacy setup file: it carries no information from which the
    generator could be reconstructed, so the log derivatives cannot be
    plotted and warnings are expected when trying to.
    """
    setup_file = 'Ti.LDA'
    location = search_for_file(setup_file)[0]
    generator_data = read_setup_file(location).generatordata
    # Guard against the installed dataset being regenerated with the
    # extra generator data, which would invalidate this fixture
    assert '=' not in generator_data, (
        f'Setup {setup_file!r} has been updated to include more '
        'generator data, update the test')
    expect_warning = pytest.warns(match='cannot reconstruct')
    return SetupInfo(filename=setup_file, nplots=3, ctx=expect_warning)

37 

38 

@pytest.fixture(scope='module')
def new_setup():
    """
    New setup file, reconstruction possible -> all plots can be plotted

    Returns a `SetupInfo` for 'Cr.14.LDA' expecting 4 subplots, with
    the default context manager (no warnings expected).

    Raises `AssertionError` if the installed setup's generator data do
    not contain a '=', i.e. the dataset is not of the expected kind.
    """
    setup_file = 'Cr.14.LDA'
    # The first item returned by `search_for_file()` is what
    # `read_setup_file()` loads the installed dataset from
    installed_setup = read_setup_file(search_for_file(setup_file)[0])
    assert '=' in installed_setup.generatordata, (
        f"Setup {setup_file!r} doesn't have the expected generator data")
    return SetupInfo(setup_file, 4)

49 

50 

@pytest.mark.serial
@pytest.mark.parametrize('setup', ['old_setup', 'new_setup'])
@pytest.mark.parametrize(
    ('flags', 'search', 'use_cli', 'expected_nplots'),
    [
        # Minimal plot
        ('', False, False, 2),
        # --separate-figures (ignored)
        ('-s', False, False, 2),
        # Reconstruct the generator and make the full figure
        ('-p -l spd,-1:1:.05', False, False, None),
        # Same as the above, but also test the dataset file searching
        ('-p -l spd,-1:1:.05', True, False, None),
        # Same as the above, but also test the CLI (which runs in a
        # subprocess, hence no warnings to be seen here)
        ('-p -l spd,-1:1:.05', True, True, None),
    ])
def test_gpaw_plot_dataset(
        setup, flags, search, use_cli, expected_nplots, request, in_tmp_dir):
    """
    Test for `gpaw plot-dataset`.

    When `expected_nplots` is `None`, as many subplots as the dataset
    permits are requested, which may emit a warning; otherwise exactly
    `expected_nplots` subplots are expected and no warning.
    """
    info = request.getfixturevalue(setup)
    setup_file = info.filename
    outfile = 'output.png'
    files_before = set(os.listdir(os.curdir))
    created = {outfile}

    if expected_nplots is not None:
        # Definite number of plots requested, no warnings expected
        ctx = contextlib.nullcontext()
    else:
        # Plot whatever is possible; any resulting warnings are to be
        # caught by the context manager that comes with the setup
        expected_nplots, ctx = info.nplots, info.ctx

    args = [f'--write={outfile}'] + shlex.split(flags)
    if search:
        args.append('--search')
    else:
        # The file located by `search_for_file()` may be zipped and
        # thus can't be used directly: dump its content into a local
        # copy instead
        content = search_for_file(setup_file)[1]
        with open(setup_file, mode='wb') as fobj:
            fobj.write(content)
        created.add(setup_file)
    args.append(setup_file)

    if not use_cli:
        parser = argparse.ArgumentParser()
        CLICommand.add_arguments(parser)
        with ctx:
            axs = CLICommand.run(parser.parse_args(args))
        # Everything should have been drawn onto one single figure...
        figure_ids = {id(ax.get_figure()) for ax in axs}
        assert len(figure_ids) == 1
        # ... with the expected number of subplots
        assert len(axs) == expected_nplots, (
            repr([ax.get_title() for ax in axs]))
    else:
        # Can't inspect much from a subprocess, just see if it runs
        subprocess.check_call(
            [sys.executable, '-m', 'gpaw', '-T', 'plot-dataset', *args])

    # The output file (and the local dataset copy, if any) must be the
    # only additions to the working directory
    files_after = set(os.listdir(os.curdir))
    assert files_after == files_before | created
    assert files_after - files_before == created

117 

118 

@pytest.mark.serial
@pytest.mark.parametrize(
    ('write', 'basis'),
    [
        (False, False),  # Minimal
        (True, True),
    ])
def test_gpaw_dataset_plot(write, basis, in_tmp_dir):
    """
    Test for `gpaw dataset --plot`: run the subcommand in a subprocess
    and check that exactly the expected files show up in the working
    directory.
    """
    files_before = set(os.listdir(os.curdir))
    outfile = 'output.png'
    options = [f'--plot={outfile}']
    created = {outfile}
    # Each enabled switch adds an option and an expected output file
    switches = [(write, '-w', 'Ti.LDA'),
                (basis, '-b', 'Ti.dzp.basis')]
    for enabled, option, product in switches:
        if enabled:
            options.append(option)
            created.add(product)
    # The options go in between the subcommand and the element symbol
    command = [sys.executable, '-m', 'gpaw', '-T', 'dataset', *options, 'Ti']
    subprocess.check_call(command)
    files_after = set(os.listdir(os.curdir))
    assert files_after == files_before | created
    assert files_after - files_before == created

141 assert new_files - old_files == expected_files