Coverage for gpaw/test/poisson/test_poisson_moment.py: 99%

110 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-08 00:17 +0000

1import numpy as np 

2import pytest 

3 

4from ase.units import Bohr 

5from gpaw.poisson import PoissonSolver, NoInteractionPoissonSolver 

6from gpaw.poisson_moment import MomentCorrectionPoissonSolver, MomentCorrection 

7from gpaw.poisson_extravacuum import ExtraVacuumPoissonSolver 

8from gpaw.grid_descriptor import GridDescriptor 

9 

10 

@pytest.mark.parametrize('moment_corrections, expected_len', [
    (None, 0),
    ([], 0),
    (4, 1),
    (9, 1),
    ([dict(moms=range(4), center=np.array([1, 3, 5]))], 1),
    ([dict(moms=range(4), center=np.array([5, 3, 5])),
      dict(moms=range(4), center=np.array([7, 5, 3]))], 2)
])
def test_defaults(moment_corrections, expected_len):
    """Check how the constructor normalizes ``moment_corrections``.

    For every accepted input form (``None``, an empty list, an int, or a
    list of dicts), the solver should expose ``moment_corrections`` as a list
    with ``expected_len`` entries, each of them a ``MomentCorrection``.
    """
    poisson_ref = NoInteractionPoissonSolver()
    poisson = MomentCorrectionPoissonSolver(
        poissonsolver=poisson_ref,
        moment_corrections=moment_corrections)

    assert isinstance(poisson.moment_corrections, list), \
        poisson.moment_corrections
    assert len(poisson.moment_corrections) == expected_len
    # A generator expression lets all() stop at the first failure instead of
    # building a temporary list first
    assert all(isinstance(mom, MomentCorrection)
               for mom in poisson.moment_corrections)

31 

32 

@pytest.mark.parametrize('moment_corrections', [None, []])
def test_description_empty(moment_corrections):
    """With no corrections, the description still wraps the inner solver.

    Both ``None`` and ``[]`` mean "no moment corrections". The wrapper's
    description must contain the wrapped solver's description and report a
    count of zero corrections.
    """
    wrapped = NoInteractionPoissonSolver()
    solver = MomentCorrectionPoissonSolver(
        poissonsolver=wrapped,
        moment_corrections=moment_corrections)

    text = solver.get_description()
    inner_text = wrapped.get_description()

    for value in (text, inner_text):
        assert isinstance(value, str)
    assert inner_text in text
    assert '0 moment corrections' in text

50 

51 

@pytest.mark.parametrize('moment_corrections, expected_strings', [
    (4, ['1 moment corrections', 'center', 'range(0, 4)']),
    (9, ['1 moment corrections', 'center', 'range(0, 9)']),
    ([dict(moms=range(4), center=np.array([1, 1, 1]))],
     ['1 moment corrections', '[1.00, 1.00, 1.00]', 'range(0, 4)']),
    ([dict(moms=[1, 2, 3], center=np.array([1, 1, 1]))],
     ['1 moment corrections', '[1.00, 1.00, 1.00]', 'range(1, 4)']),
    ([dict(moms=[0, 2, 3], center=np.array([1, 1, 1]))],
     ['1 moment corrections', '[1.00, 1.00, 1.00]', '(0, 2, 3)']),
    ([dict(moms=range(4), center=np.array([2, 3, 4])),
      dict(moms=range(4), center=np.array([7.4, 3.1, 0.1]))],
     ['2 moment corrections', '[2.00, 3.00, 4.00]',
      '[7.40, 3.10, 0.10]', 'range(0, 4)']),
])
def test_description(moment_corrections, expected_strings):
    """The description is the inner solver's text followed by the moments.

    Every entry of ``expected_strings`` must appear in the part of the
    description that comes after the wrapped solver's own description.
    """
    wrapped = NoInteractionPoissonSolver()
    solver = MomentCorrectionPoissonSolver(
        poissonsolver=wrapped,
        moment_corrections=moment_corrections)

    text = solver.get_description()
    inner_text = wrapped.get_description()

    assert isinstance(text, str)
    assert isinstance(inner_text, str)

    # The wrapped solver's description must come first ...
    assert text.startswith(inner_text)

    # ... and the moment-correction details are appended after it
    tail = text[len(inner_text):]
    for needle in expected_strings:
        assert needle in tail, f'"{needle}" not in "{tail}"'

87 

88 

@pytest.mark.parametrize('moment_corrections, expected_string', [
    ([], 'no corrections'),
    (4, 'array([0, 1, 2, 3]) @ None'),
    (9, 'array([0, 1, 2, 3, 4, 5, 6, 7, 8]) @ None'),
    ([dict(moms=range(4), center=np.array([1., 1., 1.]))],
     'array([0, 1, 2, 3]) @ array([1., 1., 1.])'),
    ([dict(moms=[1, 2, 3], center=np.array([1., 1., 1.]))],
     'array([1, 2, 3]) @ array([1., 1., 1.])'),
    ([dict(moms=[0, 2, 3], center=np.array([1., 1., 1.]))],
     'array([0, 2, 3]) @ array([1., 1., 1.])'),
    ([dict(moms=range(4), center=np.array([2, 3, 4])),
      dict(moms=range(4), center=np.array([7.4, 3.1, 0.1]))],
     '2 corrections'),
])
def test_repr(moment_corrections, expected_string):
    """``repr`` is the class name followed by a parenthesized summary."""
    solver = MomentCorrectionPoissonSolver(
        poissonsolver=NoInteractionPoissonSolver(),
        moment_corrections=moment_corrections)

    text = repr(solver)
    wanted = f'MomentCorrectionPoissonSolver ({expected_string})'

    assert isinstance(text, str)
    assert text == wanted, f'{text} not equal to {wanted}'

114 

115 

@pytest.fixture
def gd():
    """Grid descriptor with 16 x 16 x 48 points on a 1 x 1 x 3 cell.

    The cell is specified in Bohr. The third positional argument (``False``)
    presumably switches off periodic boundary conditions; confirm against
    ``GridDescriptor``.
    """
    points_per_unit = 16
    return GridDescriptor((points_per_unit, points_per_unit,
                           3 * points_per_unit),
                          (1, 1, 3),
                          False)

123 

124 

@pytest.mark.parametrize('moment_corrections', [
    4,
    9,
    [dict(moms=range(4), center=np.array([1, 1, 1]))],
    [dict(moms=range(4), center=np.array([2, 3, 4])),
     dict(moms=range(4), center=np.array([7.4, 3.1, 0.1]))],
])
def test_write(gd, moment_corrections):
    """``todict()`` output can be written with GPAW's ``Writer``.

    Passing the data through ``Writer`` checks that everything in
    ``todict()`` is JSON serializable. The output goes to ``/dev/null``.
    """
    inner = PoissonSolver()
    inner.set_grid_descriptor(gd)

    solver = MomentCorrectionPoissonSolver(
        poissonsolver=inner,
        moment_corrections=moment_corrections)
    solver.set_grid_descriptor(gd)

    from gpaw.io import Writer
    from gpaw.mpi import world

    sink = Writer('/dev/null', world)
    sink.child('poisson').write(**solver.todict())
    sink.close()

149 

150 

@pytest.fixture
def rho_g(gd):
    """Model charge density made of two lobes along z.

    For each center ``r0 = (0.5, 0.5, z0)`` with ``z0`` in {1, 2}, add
    ``10 * (z - z0) * exp(-20 * |r - r0|**2)``. Each term is odd in
    ``z - z0``, so each lobe is dipole-like along z.
    """
    r_vg = gd.get_grid_point_coordinates()
    # Reshape a 3-vector so it broadcasts against the leading (v) axis
    spatial_shape = (3,) + (1,) * (r_vg.ndim - 1)
    density_g = gd.zeros()
    for z0 in [1, 2]:
        r0_v = np.array([.5, .5, z0]).reshape(spatial_shape)
        dist2_g = np.sum((r_vg - r0_v) ** 2, axis=0)
        density_g += 10 * (r_vg[2, :] - z0) * np.exp(-20 * dist2_g)
    return density_g

163 

164 

@pytest.fixture
def poisson_solve(gd, rho_g):
    """Provide a callable that solves the Poisson equation for ``rho_g``.

    The callable attaches ``gd`` to the given solver, solves into a freshly
    zeroed array and returns that potential array.
    """
    def solve_with(solver):
        solver.set_grid_descriptor(gd)
        potential_g = gd.zeros()
        solver.solve(potential_g, rho_g)
        return potential_g

    return solve_with

176 

177 

@pytest.fixture
def compare(gd, tolerance, cmp_begin):
    """Return a helper that asserts two potentials agree within ``tolerance``.

    Some test cases compare only a central region of space. When
    ``cmp_begin`` is not None, only grid indices in
    ``[int(N * cmp_begin), int(N * (1 - cmp_begin)))`` along each axis are
    compared. Both arrays are collected onto rank 0 of ``gd.comm``, and the
    check runs only there.
    """
    if cmp_begin is None:
        region = None
    else:
        cmp_end = 1 - cmp_begin
        # Basic slicing selects the same box as np.ix_ over index ranges.
        # The name also avoids shadowing the builtin ``slice``.
        region = tuple(slice(int(N * cmp_begin), int(N * cmp_end))
                       for N in gd.get_size_of_global_array())

    def _compare(phi1_g, phi2_g):
        big_phi1_g = gd.collect(phi1_g)
        big_phi2_g = gd.collect(phi2_g)
        if gd.comm.rank != 0:
            # Only rank 0 holds the collected global arrays
            return
        if region is not None:
            big_phi1_g = big_phi1_g[region]
            big_phi2_g = big_phi2_g[region]
        assert np.max(np.absolute(big_phi1_g - big_phi2_g)) == (
            pytest.approx(0.0, abs=tolerance))

    return _compare

200 

201 

@pytest.fixture
def poisson_ref(gd, ref):
    """Build the reference solver selected by ``ref``.

    ``'default'``: a plain ``PoissonSolver``. It forces the potential to zero
    at the box boundaries, so the potential has the wrong shape near the
    boundaries but is nearly right in the center of the box.

    ``'extravac'``: an ``ExtraVacuumPoissonSolver`` with ``4 * gd.N_c`` grid
    points and the plain solver on the large grid. With four times extra
    vacuum the potential is well converged everywhere.

    Raises ValueError for any other ``ref``.
    """
    plain = PoissonSolver()
    if ref == 'extravac':
        return ExtraVacuumPoissonSolver(gpts=4 * gd.N_c,
                                        poissonsolver_large=plain)
    if ref != 'default':
        raise ValueError(f'No such ref {ref}')
    return plain

220 

221 

@pytest.mark.parametrize('ref, moment_corrections, tolerance, cmp_begin', [
    # MomentCorrectionPoissonSolver without any moment corrections should be
    # exactly as the underlying solver
    ('default', None, 0.0, None),
    # It should also be possible to chain default+extravacuum+moment correction
    # With moment_correction=None the MomentCorrection solver doesn't actually
    # do anything, so the potential should be identical to the extra vacuum
    # reference
    ('extravac', None, 0.0, None),
    # Test moment_corrections=int
    # The moment correction is applied to the center of the cell. This is not
    # enough to have a converged potential near the edges
    # The closer we are to the center the better though
    ('default', 4, 3.5e-2, 0.25),
    ('default', 4, 2.5e-2, 0.40),
    # Test moment_corrections=list
    # Remember that the solver expects Ångström units and we have specified
    # the grid in Bohr
    # This should give a well converged potential everywhere, that we can
    # compare to the reference extravacuum potential
    ('extravac',
     [{'moms': range(4), 'center': np.array([.5, .5, 1]) * Bohr},
      {'moms': range(4), 'center': np.array([.5, .5, 2]) * Bohr}],
     3e-3, None),
    # It should be possible to chain default+extravacuum+moment correction
    # As the potential is already well converged, there should be little change
    # NOTE(review): the inputs are identical to the previous case; only the
    # tolerance is tighter, so this case subsumes the previous one.
    ('extravac',
     [{'moms': range(4), 'center': np.array([.5, .5, 1]) * Bohr},
      {'moms': range(4), 'center': np.array([.5, .5, 2]) * Bohr}],
     5e-4, None),
])
def test_poisson_moment_correction(gd, rho_g, poisson_solve,
                                   compare, poisson_ref,
                                   ref, moment_corrections,
                                   tolerance, cmp_begin):
    """Compare a moment-corrected solver with its wrapped reference solver.

    ``ref``, ``tolerance`` and ``cmp_begin`` are consumed by the
    ``poisson_ref`` and ``compare`` fixtures. ``moment_corrections`` is
    passed to the ``MomentCorrectionPoissonSolver`` that wraps the reference
    solver.
    """
    # Solve for the potential using the reference solver
    phiref_g = poisson_solve(poisson_ref)

    # Create a MomentCorrectionPoissonSolver with the parametrized corrections
    # and solve for the potential. Hard-coding None here would turn every case
    # into a trivial comparison of the reference with itself.
    poisson = MomentCorrectionPoissonSolver(
        poissonsolver=poisson_ref,
        moment_corrections=moment_corrections)
    phi_g = poisson_solve(poisson)

    # Test the MomentCorrectionPoissonSolver
    compare(phi_g, phiref_g)