Coverage for gpaw/poisson_moment.py: 99%

141 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-08 00:17 +0000

1from typing import Any, Dict, Optional, List, Sequence, Union 

2 

3import numpy as np 

4from ase.units import Bohr 

5from ase.utils.timing import Timer 

6from gpaw.poisson import _PoissonSolver, create_poisson_solver 

7from gpaw.utilities.gauss import Gaussian 

8from gpaw.typing import Array1D 

9from gpaw.utilities.timing import nulltimer, NullTimer 

10 

11from ase.utils.timing import timer 

12 

13 

# Accepted forms of a ``moment_corrections`` argument: either a single
# integer, or a list of dictionaries with string keys.
MomentCorrectionsType = Union[int, List[Dict[str, Any]]]

15 

16 

class MomentCorrection:
    """A set of multipole moments together with the center they refer to.

    Parameters
    ----------
    center
        Center of the correction, or None. A given center is divided by
        ``Bohr`` before being stored in ``self.center``; ``todict`` and the
        string representations multiply by ``Bohr`` again.
        NOTE(review): None presumably means "center of the cell" (``__str__``
        prints it as 'center') -- confirm against the Gaussian helper.
    moms
        Multipole moments to correct; stored as a numpy array in
        ``self.moms``.
    """

    def __init__(self,
                 center: Optional[Union[Sequence, Array1D]],
                 moms: Union[int, Sequence[int]]):
        if center is not None:
            center = np.asarray(center) / Bohr
        self.center = center
        self.moms = np.asarray(moms)

    def todict(self) -> Dict[str, Any]:
        """Return dictionary description, converting the moment correction
        from units of Bohr to Ångström.

        The returned dictionary has the keys ``'moms'`` and ``'center'``.
        """
        center = self.center
        if center is not None:
            # Out-of-place product so that self.center is left untouched
            center = center * Bohr

        return dict(moms=self.moms, center=center)

    def __str__(self) -> str:
        """Return a short description of the form ``[center] moments``."""
        if self.center is None:
            center = 'center'
        else:
            center = ', '.join([f'{x:.2f}' for x in self.center * Bohr])

        # np.allclose() of an empty difference array is True, so an empty
        # moment list must be excluded explicitly; otherwise self.moms[0]
        # below raises IndexError.
        if self.moms.size > 0 and np.allclose(np.diff(self.moms), 1):
            # Increasing sequence
            moms = f'range({self.moms[0]}, {self.moms[-1] + 1})'
        else:
            # List of integers
            _moms = ', '.join([f'{m:.0f}' for m in self.moms])
            moms = f'({_moms})'
        return f'[{center}] {moms}'

    def __repr__(self) -> str:
        """Return ``<repr of moms> @ <repr of center times Bohr>``."""
        center = self.center
        if center is not None:
            # Must not be written as ``center *= Bohr``: that multiplies the
            # array stored in self.center in place, so every call to repr()
            # would rescale the stored center.
            center = center * Bohr
        return f'{repr(self.moms)} @ {repr(center)}'

59 

60 

class MomentCorrectionPoissonSolver(_PoissonSolver):
    """Wrapper for the poisson solver that includes moment corrections

    Parameters
    ----------
    poissonsolver
        underlying poisson solver
    moment_corrections
        list of moment corrections, expressed as dictionaries
        `{'moms': ..., 'center': ...}` that specify the multipole moments
        and their centres.

        moment_corrections = [{'moms': moms_list1, 'center': center1},
                              {'moms': moms_list2, 'center': center2},
                              ...]

        Here moms_listX is list of integers of multipole moments to be
        corrected at centerX. For example moms_list=range(4) corresponds to
        s, p_x, p_y and p_z type multipoles.

        Optionally setting moment_corrections to an integer m is equivalent to
        including multipole moments corresponding to range(m) at the center of
        the cell

        None means that no moment corrections are applied.
    timer
        timer object; its ``start``/``stop`` methods are called around the
        steps of ``_solve`` unless a ``timer`` keyword is passed to ``solve``

    """

    def __init__(self,
                 poissonsolver: Union[_PoissonSolver, Dict[str, Any]],
                 moment_corrections: Optional[MomentCorrectionsType],
                 timer: Union[NullTimer, Timer] = nulltimer):
        """Store the wrapped solver and normalise ``moment_corrections``
        into a list of MomentCorrection objects.

        Raises ValueError if ``moment_corrections`` is neither None, an
        integer nor a list.
        """
        self._initialized = False
        self.poissonsolver = create_poisson_solver(poissonsolver)
        self.timer = timer

        if moment_corrections is None:
            self.moment_corrections = []
        elif isinstance(moment_corrections, int):
            # Integer shorthand: moments range(m) with no explicit center
            moms = range(moment_corrections)
            center = None
            self.moment_corrections = [MomentCorrection(moms=moms,
                                                        center=center)]
        elif isinstance(moment_corrections, list):
            assert all('moms' in mom and 'center' in mom
                       for mom in moment_corrections), \
                (f'{self.__class__.__name__}: each element in '
                 'moment_correction must be a dictionary '
                 'with the keys "moms" and "center"')
            self.moment_corrections = [MomentCorrection(**mom)
                                       for mom in moment_corrections]
        else:
            raise ValueError(f'{self.__class__.__name__}: moment_correction '
                             'must be a list of dictionaries')

    def todict(self):
        """Return a dictionary with the solver name, the description of the
        wrapped solver and the list of moment correction dictionaries."""
        mom_corr = [mom.todict() for mom in self.moment_corrections]
        d = {'name': 'MomentCorrectionPoissonSolver',
             'poissonsolver': self.poissonsolver.todict(),
             'moment_corrections': mom_corr}

        return d

    def set_grid_descriptor(self, gd):
        """Pass the grid descriptor to the wrapped solver and keep it as
        ``self.gd``."""
        self.poissonsolver.set_grid_descriptor(gd)
        self.gd = gd

    def get_description(self) -> str:
        """Return the description of the wrapped solver followed by one line
        per moment correction."""
        description = self.poissonsolver.get_description()
        n = len(self.moment_corrections)

        # NOTE(review): the leading whitespace inside the two f-strings below
        # is reproduced as it appears in the rendered report, which may have
        # collapsed a longer indent -- confirm against the repository file.
        lines = [description]
        lines.extend([f' {n} moment corrections:'])
        lines.extend([f' {str(mom)}'
                      for mom in self.moment_corrections])

        return '\n'.join(lines)

    @timer('Poisson initialize')
    def _init(self):
        """Initialise the wrapped solver and the moment corrections once.

        Raises NotImplementedError if the grid is not orthogonal or has any
        periodic direction.
        """
        if self._initialized:
            return
        self.poissonsolver._init()

        if not self.gd.orthogonal or self.gd.pbc_c.any():
            raise NotImplementedError('Only orthogonal unit cells '
                                      'and non-periodic boundary '
                                      'conditions are tested')
        self.load_moment_corrections_gauss()

        self._initialized = True

    @timer('Load moment corrections')
    def load_moment_corrections_gauss(self):
        """Create one Gaussian helper and one grid mask per correction.

        Fills ``self.gauss_i``, ``self.mom_ij`` and ``self.mask_ig``. Every
        grid point is marked (mask value 1) for exactly one correction: the
        one with the smallest ``gauss.r`` value at that point, the first
        such correction winning ties.
        NOTE(review): ``gauss.r`` is presumably the distance of each grid
        point from the correction center -- confirm in gpaw.utilities.gauss.
        """
        self.gauss_i = []
        self.mom_ij = []
        self.mask_ig = []

        if len(self.moment_corrections) == 0:
            return

        mask_ir = []
        r_ir = []

        for rmom in self.moment_corrections:
            gauss = Gaussian(self.gd, center=rmom.center)
            self.gauss_i.append(gauss)
            r_ir.append(gauss.r.ravel())
            mask_ir.append(self.gd.zeros(dtype=int).ravel())
            self.mom_ij.append(rmom.moms)

        r_ir = np.array(r_ir)
        mask_ir = np.array(mask_ir)

        # One vectorised argmin over the correction axis replaces a Python
        # loop over every grid point. Like the per-column argmin it returns
        # the first minimum, so ties are resolved identically.
        nearest_r = np.argmin(r_ir, axis=0)
        mask_ir[nearest_r, np.arange(r_ir.shape[1])] = 1

        for mask_r in mask_ir:
            self.mask_ig.append(mask_r.reshape(self.gd.n_c))

    def solve(self, phi, rho, **kwargs):
        """Initialise if necessary, then solve; returns what ``_solve``
        returns."""
        self._init()
        return self._solve(phi, rho, **kwargs)

    @timer('Solve')
    def _solve(self, phi, rho, **kwargs):
        """Solve with the wrapped solver, applying the moment corrections.

        ``phi`` is updated in place. Keyword arguments are forwarded to the
        wrapped solver; a ``timer`` keyword, if present, is also used here
        instead of ``self.timer``. Returns the value returned by the wrapped
        solver's ``solve``.
        """
        timer = kwargs.get('timer', self.timer)

        if len(self.moment_corrections) > 0:
            assert not self.gd.pbc_c.any()

            timer.start('Multipole moment corrections')

            rho_neutral = rho * 0.0
            phi_cor_g = self.gd.zeros()
            for gauss, mask_g, mom_j in zip(self.gauss_i, self.mask_ig,
                                            self.mom_ij):
                rho_masked = rho * mask_g
                # NOTE(review): remove_moment presumably subtracts the
                # moment from rho_masked in place and returns the matching
                # potential -- confirm in gpaw.utilities.gauss.
                for mom in mom_j:
                    phi_cor_g += gauss.remove_moment(rho_masked, mom)
                # NOTE(review): accumulated once per correction; the
                # flattened source lost the nesting level, so confirm this
                # is outside the inner loop in the repository file.
                rho_neutral += rho_masked

            # Remove multipoles for better initial guess
            phi -= phi_cor_g

            timer.stop('Multipole moment corrections')

            timer.start('Solve neutral')
            niter = self.poissonsolver.solve(phi, rho_neutral, **kwargs)
            timer.stop('Solve neutral')

            timer.start('Multipole moment corrections')
            # correct error introduced by removing multipoles
            phi += phi_cor_g
            timer.stop('Multipole moment corrections')

            return niter
        else:
            return self.poissonsolver.solve(phi, rho, **kwargs)

    def estimate_memory(self, mem):
        """Add the wrapped solver's estimate and one grid-sized array per
        moment correction (the masks) to ``mem``."""
        self.poissonsolver.estimate_memory(mem)
        gdbytes = self.gd.bytecount()
        if self.moment_corrections is not None:
            mem.subnode('moment_corrections masks',
                        len(self.moment_corrections) * gdbytes)

    def __repr__(self):
        """Return the class name followed by a summary of the corrections:
        'no corrections', the repr of a single correction, or their count."""
        if len(self.moment_corrections) == 0:
            corrections_str = 'no corrections'
        elif len(self.moment_corrections) < 2:
            corrections_str = f'{repr(self.moment_corrections[0])}'
        else:
            corrections_str = f'{len(self.moment_corrections)} corrections'

        representation = f'MomentCorrectionPoissonSolver ({corrections_str})'
        return representation