Coverage for gpaw/test/lcao/test_lcao_parallel.py: 10%

94 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-19 00:19 +0000

1import sys 

2 

3import pytest 

4from gpaw.utilities import devnull 

5 

6from gpaw import GPAW, FermiDirac, KohnShamConvergenceError 

7from gpaw.utilities import compiled_with_sl 

8from gpaw.mpi import world 

9from ase.build import molecule 

10 

# Calculates energy and forces for various parallelizations and checks
# that they agree with the first (reference) parallelization.

12 

# Every test in this module needs at least four MPI ranks; skip otherwise.
_enough_ranks = world.size >= 4
pytestmark = pytest.mark.skipif(not _enough_ranks, reason='world.size < 4')

15 

16 

def test_lcao_lcao_parallel():
    """Check that LCAO energies and forces agree across parallelizations.

    Two series of calculations are run: a molecule built with the default
    formula of ``run`` (H2O), and a spin-polarized, periodic NH2 system.
    Within each series the ``parallel`` settings (band, domain and, when
    GPAW is compiled with ScaLAPACK, ``sl_default``) are varied between
    calculations.  The first calculation of a series is the reference; every
    later one must reproduce its free energy and forces to within
    ``tolerance``, otherwise AssertionError is raised.
    """
    tolerance = 4e-5

    # ``basekwargs`` holds a reference to this dict, and ``run`` copies
    # ``basekwargs`` at call time, so the dict is mutated in place between
    # calculations instead of being rebound.
    parallel = dict()

    basekwargs = dict(mode='lcao',
                      nbands=6,
                      parallel=parallel)

    # Reference results of the current series; filled in by the first
    # ``run`` call after they have been reset to None.
    Eref = None
    Fref_av = None

    def run(formula='H2O', vacuum=2.0, cell=None, pbc=0, **morekwargs):
        """Compute ``formula`` with the current ``parallel`` settings.

        The first call of a series stores its free energy and forces as the
        reference.  Later calls print diagnostics to stderr (on rank 0 only)
        and raise AssertionError if the energy or any force component
        differs from the reference by more than ``tolerance``.
        """
        nonlocal Eref, Fref_av

        print(formula, parallel)
        system = molecule(formula)
        kwargs = dict(basekwargs)
        kwargs.update(morekwargs)
        calc = GPAW(**kwargs)
        system.calc = calc
        system.center(vacuum)
        if cell is not None:
            system.set_cell(cell)
        system.set_pbc(pbc)

        try:
            system.get_potential_energy()
        except KohnShamConvergenceError:
            # Non-convergence is tolerated on purpose: whatever state was
            # reached is still compared against the reference below.
            pass

        E = calc.hamiltonian.e_total_free
        F_av = calc.get_forces()

        if Eref is None:
            Eref = E
            Fref_av = F_av

        eerr = abs(E - Eref)
        ferr = abs(F_av - Fref_av).max()

        if calc.wfs.world.rank == 0:
            print('Energy', E)
            print()
            print('Forces')
            print(F_av)
            print()
            print('Errs', eerr, ferr)

        if eerr > tolerance or ferr > tolerance:
            # Only rank 0 writes diagnostics; other ranks discard them.
            if calc.wfs.world.rank == 0:
                stderr = sys.stderr
            else:
                stderr = devnull
            # Collect every failure so that the assertion message reports
            # both the energy and the force error when both are too large.
            messages = []
            if eerr > tolerance:
                print('Failed!', file=stderr)
                print('E = %f, Eref = %f' % (E, Eref), file=stderr)
                messages.append(
                    'Energy err larger than tolerance: %f' % eerr)
            if ferr > tolerance:
                print('Failed!', file=stderr)
                print('Forces:', file=stderr)
                print(F_av, file=stderr)
                print(file=stderr)
                print('Ref forces:', file=stderr)
                print(Fref_av, file=stderr)
                print(file=stderr)
                messages.append(
                    'Force err larger than tolerance: %f' % ferr)
            print(file=stderr)
            print('Args:', file=stderr)
            print(formula, vacuum, cell, pbc, morekwargs, file=stderr)
            print(parallel, file=stderr)
            raise AssertionError('; '.join(messages))

    have_scalapack = compiled_with_sl()

    # reference:
    # state-parallelization = 1,
    # domain-decomposition = (1, 2, 2)
    run()

    # state-parallelization = 2,
    # domain-decomposition = (1, 2, world.size // 4), i.e. (1, 2, 1) on
    # four ranks
    parallel['band'] = 2
    parallel['domain'] = (1, 2, world.size // 4)
    run()

    if have_scalapack:
        # state-parallelization = 2,
        # domain-decomposition = (1, 2, 1)
        # with blacs
        parallel['sl_default'] = (2, 2, 2)
        run()

        # state-parallelization = 1,
        # domain-decomposition = (1, 2, 2)
        # with blacs
        del parallel['band']
        del parallel['domain']
        run()

    # perform spin polarization test
    # Start a fresh series: empty the shared ``parallel`` dict (still
    # referenced by ``basekwargs``) and forget the previous reference.
    parallel.clear()
    Eref = None
    Fref_av = None

    NH2_kwargs = dict(formula='NH2', vacuum=1.5, pbc=1, spinpol=1,
                      occupations=FermiDirac(0.1))

    # NOTE(review): spinpol is on for every NH2 run, so the
    # "spin-polarization = N" annotations below presumably describe how the
    # spin channels are distributed over ranks -- confirm.
    parallel['domain'] = (1, 2, 1)

    # reference:
    # spin-polarization = 2,
    # state-parallelization = 1,
    # domain-decomposition = (1, 2, 1)
    run(**NH2_kwargs)

    # spin-polarization = 1,
    # state-parallelization= 2,
    # domain-decomposition = (1, 1, 1)
    del parallel['domain']
    parallel['band'] = 2
    run(**NH2_kwargs)

    if have_scalapack:
        # spin-polarization = 2,
        # state-parallelization = 2,
        # domain-decomposition = (1, 1, 1)
        # with blacs
        parallel['sl_default'] = (2, 1, 2)
        run(**NH2_kwargs)

        # spin-polarization = 2,
        # state-parallelization = 1,
        # domain-decomposition = (1, 2, 1)
        del parallel['band']
        parallel['domain'] = (1, 2, 1)
        run(**NH2_kwargs)

        # spin-polarization = 1,
        # state-parallelization = 1,
        # domain-decomposition = (1, 2, 2)
        parallel['domain'] = (1, 2, 2)
        run(**NH2_kwargs)