Coverage for gpaw/nlopt/matrixel.py: 87%

132 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-20 00:19 +0000

1from __future__ import annotations 

2from typing import TYPE_CHECKING 

3 

4from ase.parallel import parprint 

5from ase.utils.timing import Timer 

6from pathlib import Path 

7import numpy as np 

8 

9from gpaw.new.ase_interface import ASECalculator 

10from gpaw.nlopt.basic import NLOData 

11from gpaw.utilities.progressbar import ProgressBar 

12 

13if TYPE_CHECKING: 

14 from gpaw.nlopt.adapters import CollinearGSInfo, NoncollinearGSInfo 

15 from gpaw.typing import ArrayND 

16 

17 

def get_mml(gs: CollinearGSInfo | NoncollinearGSInfo,
            bands: slice,
            spin: int,
            timer: Timer | None = None) -> ArrayND:
    """
    Evaluate the momentum matrix elements for one spin channel.

    For every set of wave functions held locally, the result is the
    plane-wave (pseudo wave function) contribution minus ``1j`` times the
    PAW correction built from the projections and ``gs.nabla_aiiv``.

    Parameters
    ----------
    gs
        Ground state adapter.
    bands
        Range of band indices; both ``start`` and ``stop`` are read.
    spin
        Spin channel index, must be smaller than ``gs.ns``.
    timer
        Timer for monitoring code performance.  A new one is created
        when omitted.

    Returns
    -------
    p_qvnn
        Momentum matrix elements for each local q-point, with shape
        ``(nq, 3, nb, nb)``.
    """

    # Fall back to a private timer when the caller did not supply one
    timer = Timer() if timer is None else timer
    parprint(f'Calculating momentum matrix elements for spin channel {spin}.')

    # Validate the requested spin channel
    assert spin < gs.ns, 'Wrong spin input'

    # Output array: one (3, nb, nb) block per local q-index
    ibzwfs = gs.ibzwfs
    is_master = ibzwfs.kpt_comm.rank == 0
    nbands = bands.stop - bands.start
    nlocal = len(ibzwfs.q_k.keys())  # Number of q-indices on this core
    p_qvnn = np.empty([nlocal, 3, nbands, nbands], dtype=complex)

    # Only the master rank of the k-point communicator reports progress;
    # creating the bar prints the initial 0 % state
    progress = ProgressBar() if is_master else None

    for wfs_s in ibzwfs.wfs_qs:
        wfs = gs.get_wfs(wfs_s, spin)

        with timer('Contribution from pseudo wave functions'):
            G_plus_k_Gv, u_nG = gs.get_plane_wave_coefficients(
                wfs, bands, spin)
            pseudo_vnn = np.einsum('Gv,mG,nG->vmn',
                                   G_plus_k_Gv, u_nG.conj(), u_nG)
            p_vnn = pseudo_vnn * gs.ucvol

        with timer('Contribution from PAW corrections'):
            P_ani = gs.get_wave_function_projections(wfs, bands, spin)
            # One projection array per atom, paired with its nabla matrix
            for P_ni, nabla_iiv in zip(P_ani.values(), gs.nabla_aiiv):
                correction_vnn = np.einsum('mi,nj,ijv->vmn',
                                           P_ni.conj(), P_ni, nabla_iiv)
                p_vnn -= 1j * correction_vnn

        p_qvnn[wfs.q] = p_vnn

        if is_master:
            progress.update(wfs.q / nlocal)

    # NOTE(review): indentation was lost in the report this was recovered
    # from; timer.write() is assumed to be master-only -- confirm against
    # the original module.
    if is_master:
        progress.finish()
        timer.write()

    return p_qvnn

87 

88 

def gather_to_master(p_qvnn, ibzwfs):
    """
    Collect the per-rank matrix elements on rank 0 of the k-point
    communicator.

    Parameters
    ----------
    p_qvnn
        Array holding the blocks owned by the calling rank; the first
        axis runs over the local q-indices.
    ibzwfs
        Object providing ``kpt_comm`` (with ``rank``, ``size``, ``send``
        and ``receive``) and ``rank_k``, the owning rank of every k-point.

    Returns
    -------
    p_kvnn
        On rank 0: array with one block per entry of ``rank_k``.
        On every other rank: an empty array whose first axis has length 0.
    """
    comm = ibzwfs.kpt_comm
    tail_shape = p_qvnn.shape[1:4]

    # Every rank except 0 ships its data and returns a placeholder
    if comm.rank != 0:
        comm.send(p_qvnn, 0)
        return np.empty((0,) + tail_shape, complex)

    rank_k = ibzwfs.rank_k
    p_kvnn = np.empty((len(rank_k),) + tail_shape, complex)

    # Blocks already present on rank 0
    p_kvnn[np.where(rank_k == 0)[0]] = p_qvnn

    # Blocks owned by the remaining ranks, received one rank at a time
    for source in range(1, comm.size):
        k_indices = np.where(rank_k == source)[0]
        incoming = np.zeros((len(k_indices),) + tail_shape, complex)
        comm.receive(incoming, source)
        p_kvnn[k_indices] = incoming

    return p_kvnn

113 

114 

def make_nlodata(calc: ASECalculator | str | Path,
                 spin_string: str = 'all',
                 ni: int | None = None,
                 nf: int | None = None) -> NLOData:
    """
    Calculate and return all required NLO data: w_sk, f_skn, E_skn, p_skvnn.

    Parameters
    ----------
    calc
        Calculator or string/path pointing to a .gpw file.
    spin_string
        String denoting which spin channels to include ('all', 's0', 's1').
    ni
        First band to compute the mml (defaults to 0).
    nf
        End of the band range, used as the exclusive stop of a slice
        (relative to the number of bands for nf <= 0; defaults to the
        number of bands).

    Returns
    -------
    NLOData
        Data object carrying required matrix elements for NLO calculations.

    Raises
    ------
    TypeError
        If ``calc`` is neither a calculator nor a string / path.
    NotImplementedError
        If ``spin_string`` is not one of 'all', 's0', 's1'.
    """

    # Load the calculator from file when a path was given
    if not isinstance(calc, ASECalculator):
        if not isinstance(calc, (str, Path)):
            raise TypeError('Input must be a calculator or a string / path '
                            'pointing to a calculator.')
        from gpaw.new.ase_interface import GPAW
        calc = GPAW(calc, txt=None, parallel={'domain': 1, 'band': 1})
    # NOTE(review): indentation was lost in the report this was recovered
    # from; the check is assumed to apply to supplied and loaded
    # calculators alike -- confirm against the original module.
    assert not calc.symmetry.point_group, \
        'Point group symmetry should be off.'

    # Pick the ground state adapter matching the density type
    gs: CollinearGSInfo | NoncollinearGSInfo
    if calc.dft.density.collinear:
        from gpaw.nlopt.adapters import CollinearGSInfo
        gs = CollinearGSInfo(calc)
    else:
        from gpaw.nlopt.adapters import NoncollinearGSInfo
        gs = NoncollinearGSInfo(calc)

    # Start the timer
    timer = Timer()

    # Parse spin string
    ns = gs.ns
    if spin_string == 'all':
        spins = list(range(ns))
    elif spin_string == 's0':
        spins = [0]
    elif spin_string == 's1':
        spins = [1]
        assert spins[0] < ns, 'Wrong spin input'
    else:
        raise NotImplementedError(
            f'Unknown spin_string {spin_string!r}; '
            "expected 'all', 's0' or 's1'.")

    # Parse band input; a non-positive nf counts from the last band
    ibzwfs = gs.ibzwfs
    nb_full = ibzwfs.nbands
    ni = int(ni) if ni is not None else 0
    nf = int(nf) if nf is not None else nb_full
    nf = nb_full + nf if (nf <= 0) else nf
    bands = slice(ni, nf)

    # Memory estimate
    nk = len(ibzwfs.rank_k)  # Total number of k-points
    est_mem = 2 * 3 * nk * (nf - ni)**2 * 16 / 2**30
    parprint(f'At least {est_mem:.2f} GB of memory is required on master.')

    # Get the energy and Fermi-Dirac occupations (data is only in master)
    with timer('Get energies and fermi levels'):
        E_skn, f_skn = ibzwfs.get_all_eigs_and_occs()
        w_sk = np.array([ibzwfs.ibz.weight_k for _ in range(gs.ndensities)])
        w_sk *= gs.bzvol * ibzwfs.spin_degeneracy

    # Compute the momentum matrix elements
    with timer('Compute the momentum matrix elements'):
        p_sqvnn = [get_mml(gs, bands, spin, timer) for spin in spins]
        if not gs.collinear:
            # The two channels are summed into a single entry
            p_sqvnn = [p_sqvnn[0] + p_sqvnn[1]]

    with timer('Gather the data to master'):
        p_skvnn = [gather_to_master(p_qvnn, ibzwfs) for p_qvnn in p_sqvnn]

    # Bundle everything into the returned data object
    return NLOData(w_sk=w_sk,
                   f_skn=f_skn[:, :, bands],
                   E_skn=E_skn[:, :, bands],
                   p_skvnn=np.array(p_skvnn, complex),
                   comm=ibzwfs.kpt_comm)

212 

213 

def get_rml(E_n, p_vnn, pol_v, Etol=1e-6):
    """
    Build position matrix elements from momentum matrix elements.

    Parameters
    ----------
    E_n
        Band energies.
    p_vnn
        Momentum matrix elements.
    pol_v
        Tensor element; only the distinct components it contains are
        evaluated, all other components stay zero.
    Etol
        Tolerance in energy to consider degeneracy.

    Returns
    -------
    r_vnn
        Position matrix elements, ``p / (1j * (E_n - E_m))``, set to zero
        for band pairs closer in energy than ``Etol``.
    D_vnn
        Velocity difference matrix elements, i.e. differences between
        the diagonal entries of ``p_vnn``.

    """

    nbands = len(E_n)
    dims = (3, nbands, nbands)
    r_vnn = np.zeros(dims, complex)
    D_vnn = np.zeros(dims, complex)

    # Pairwise energy differences E_n - E_m
    dE_nn = E_n[:, None] - E_n[None, :]
    degenerate = np.abs(dE_nn) < Etol
    # Placeholder value so the division below never hits a tiny number;
    # those entries are zeroed afterwards
    dE_nn[degenerate] = 1

    for v in set(pol_v):
        p_nn = p_vnn[v]

        r_nn = p_nn / (1j * dE_nn)
        r_nn[degenerate] = 0
        r_vnn[v] = r_nn

        diag_n = np.diag(p_nn)
        D_vnn[v] = diag_n[:, None] - diag_n[None, :]

    return r_vnn, D_vnn

255 

256 

def get_derivative(E_n, r_vnn, D_vnn, pol_v, Etol=1e-6):
    """
    Evaluate the generalised derivative of position matrix elements.

    Parameters
    ----------
    E_n
        Band energies.
    r_vnn
        Position matrix elements.
    D_vnn
        Velocity difference matrix elements.
    pol_v
        Tensor element; only pairs of the distinct components it contains
        are evaluated, all other entries stay zero.
    Etol
        Tolerance in energy to consider degeneracy.

    Returns
    -------
    rd_vvnn
        Generalised derivative of position matrix elements, set to zero
        for band pairs closer in energy than ``Etol``.

    """

    nbands = len(E_n)
    rd_vvnn = np.zeros((3, 3, nbands, nbands), complex)

    # Pairwise energy differences E_n - E_m
    dE_nn = E_n[:, None] - E_n[None, :]
    degenerate = np.abs(dE_nn) < Etol
    # Placeholder value so the division below never hits a tiny number;
    # those entries are zeroed afterwards
    dE_nn[degenerate] = 1

    components = set(pol_v)
    for a in components:
        r_a = r_vnn[a]
        Dt_a = D_vnn[a].T
        for b in components:
            r_b = r_vnn[b]
            rE_b = r_b * dE_nn
            numerator = (r_a * D_vnn[b].T
                         + r_b * Dt_a
                         + 1j * np.dot(r_a, rE_b)
                         - 1j * np.dot(rE_b, r_a))
            deriv = numerator / dE_nn
            deriv[degenerate] = 0
            rd_vvnn[a, b] = deriv

    return rd_vvnn