Coverage for gpaw/lcao/tci.py: 99%

217 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-20 00:19 +0000

1import numpy as np 

2import scipy.sparse as sparse 

3from ase.neighborlist import PrimitiveNeighborList 

4# from ase.utils.timing import timer 

5from gpaw.utilities.tools import tri2full 

6 

7# from gpaw import debug 

8from gpaw.lcao.overlap import (FourierTransformer, TwoSiteOverlapCalculator, 

9 ManySiteOverlapCalculator, 

10 AtomicDisplacement, NullPhases, BlochPhases, 

11 DerivativeAtomicDisplacement) 

12 

13 

def get_cutoffs(f_Ij):
    """Return the largest cutoff radius of each group of functions.

    ``f_Ij`` is a sequence of sequences of objects providing
    ``get_cutoff()``.  The result has one entry per inner sequence: the
    maximum of the members' cutoffs, but never less than 0.001, so an
    empty group yields 0.001.
    """
    floor = 0.001  # 'paranoid zero'
    return [max([floor] + [f.get_cutoff() for f in f_j]) for f_j in f_Ij]

22 

23 

def get_lvalues(f_Ij):
    """Return the angular momentum numbers of all functions in ``f_Ij``.

    The result mirrors the nesting of ``f_Ij``: one list per inner
    sequence, holding ``f.get_angular_momentum_number()`` for each member.
    """
    l_Ij = []
    for f_j in f_Ij:
        l_Ij.append([f.get_angular_momentum_number() for f in f_j])
    return l_Ij

26 

27 

class AtomPairRegistry:
    """Table of displacement vectors between neighbouring atoms.

    The table maps an atom pair ``(a1, a2)`` to a list of
    ``(R_c, offset)`` tuples, where ``offset`` is the cell offset reported
    by the neighbor list and ``R_c`` is the corresponding displacement
    ``(spos_ac[a2] + offset - spos_ac[a1]) . cell_cv``.
    """

    def __init__(self, cutoff_a, pbc_c, cell_cv, spos_ac):
        """Build the table from per-atom cutoffs and scaled positions."""
        neighbors = PrimitiveNeighborList(cutoff_a, skin=0, sorted=True,
                                          self_interaction=True,
                                          use_scaled_positions=True)
        neighbors.update(pbc=pbc_c, cell=cell_cv, coordinates=spos_ac)

        registry = {}
        for a1, spos1_c in enumerate(spos_ac):
            a2_a, offset_ac = neighbors.get_neighbors(a1)
            for a2, offset in zip(a2_a, offset_ac):
                R_c = np.dot(spos_ac[a2] + offset - spos1_c, cell_cv)
                registry.setdefault((a1, a2), []).append((R_c, offset))

                # The reversed pair gets the negated displacement.  An
                # atom paired with itself at zero offset is its own
                # reverse, so it must not be registered twice.
                same_site = a1 == a2 and not offset.any()
                if not same_site:
                    registry.setdefault((a2, a1), []).append((-R_c, -offset))

        self.r_and_offset_aao = registry

    def get(self, a1, a2):
        """Return the ``(R_c, offset)`` list for a pair, or None if absent."""
        return self.r_and_offset_aao.get((a1, a2))

    def get_atompairs(self):
        """Return all registered ``(a1, a2)`` pairs as a sorted list."""
        return sorted(self.r_and_offset_aao)

57 

58 

class TCIExpansions:
    """Precomputed expansions for two-center integrals, stored per species.

    Attributes set by the constructor:

    * ``O_expansions``: from ``calculate_expansions`` with the basis
      functions on both sides.
    * ``T_expansions``: from ``calculate_kinetic_expansions`` with the
      basis functions.
    * ``P_expansions``: from ``calculate_expansions`` with the projectors
      on the first side and the basis functions on the second.
    * ``I_a``: the species index of each atom, as passed in.
    * ``rcmax_I``, ``phit_rcmax_I``, ``pt_rcmax_I``: per-species cutoffs
      (overall, basis functions only, projectors only).
    """

    def __init__(self, phit_Ij, pt_Ij, I_a):
        """Transform the functions and calculate all expansions.

        ``phit_Ij`` and ``pt_Ij`` hold, for each species, the basis
        functions and the projector functions respectively; both must have
        the same number of species.  ``I_a`` gives the species index of
        each atom.
        """
        assert len(pt_Ij) == len(phit_Ij)

        # Cutoffs by species:
        pt_rcmax_I = get_cutoffs(pt_Ij)
        phit_rcmax_I = get_cutoffs(phit_Ij)
        rcmax_I = [max(rc1, rc2) for rc1, rc2
                   in zip(pt_rcmax_I, phit_rcmax_I)]

        # The extra 1e-3 keeps max() well defined when there are no species.
        transformer = FourierTransformer(rcut=max(rcmax_I + [1e-3]), N=2**10)
        tsoc = TwoSiteOverlapCalculator(transformer)
        msoc = ManySiteOverlapCalculator(tsoc, I_a, I_a)
        phit_Ijq = msoc.transform(phit_Ij)
        pt_Ijq = msoc.transform(pt_Ij)
        pt_l_Ij = get_lvalues(pt_Ij)
        phit_l_Ij = get_lvalues(phit_Ij)
        self.O_expansions = msoc.calculate_expansions(phit_l_Ij, phit_Ijq,
                                                      phit_l_Ij, phit_Ijq)
        self.T_expansions = msoc.calculate_kinetic_expansions(phit_l_Ij,
                                                              phit_Ijq)
        self.P_expansions = msoc.calculate_expansions(pt_l_Ij, pt_Ijq,
                                                      phit_l_Ij, phit_Ijq)
        self.I_a = I_a  # Actually I_a belongs outside, like spos_ac.
        self.rcmax_I = rcmax_I
        self.phit_rcmax_I = phit_rcmax_I
        self.pt_rcmax_I = pt_rcmax_I

    @classmethod
    def new_from_setups(cls, setups):
        """Alternate constructor taking a setups container.

        ``setups.setups.values()`` supplies the distinct setup objects
        (the species); iterating ``setups`` itself yields one setup per
        atom.  Each species contributes ``basis_functions_J`` and ``pt_j``.

        The instance is created through ``cls`` so that subclasses get an
        instance of their own type.
        """
        setups_I = list(setups.setups.values())
        I_setup = {setup: I for I, setup in enumerate(setups_I)}
        I_a = [I_setup[setup] for setup in setups]

        return cls([s.basis_functions_J for s in setups_I],
                   [s.pt_j for s in setups_I],
                   I_a)

    def get_tci_calculator(self, cell_cv, spos_ac, pbc_c, ibzk_qc, dtype):
        """Return a TCICalculator using these expansions."""
        return TCICalculator(self, cell_cv, spos_ac, pbc_c, ibzk_qc, dtype)

    def get_manytci_calculator(self, setups, gd, spos_ac, ibzk_qc, dtype,
                               timer):
        """Return a ManyTCICalculator using these expansions."""
        return ManyTCICalculator(self, setups, gd, spos_ac, ibzk_qc, dtype,
                                 timer)

106 

107 

class TCICalculator:
    """High-level two-center integral calculator.

    This object is not aware of parallelization.  It works with any
    pair of atoms a1, a2.

    Create the object and calculate any interatomic overlap matrix as below.

      tci = TCICalculator(...)

    Projector/basis overlap <pt_i^a1|phi_mu> between atoms a1, a2:

      P_qim = tci.P(a1, a2)

    Derivatives of the above with respect to movement of a2:

      dPdR_qvim = tci.dPdR(a1, a2)

    Basis/basis overlap and kinetic matrix elements between atoms a1, a2:

      O_qmm, T_qmm = tci.O_T(a1, a2)

    Derivative of the above wrt. position of a2:

      dOdR_qvmm, dTdR_qvmm = tci.dOdR_dTdR(a1, a2)

    Each call returns None (projector variants) or (None, None)
    (basis/basis variants) when the two atoms have no displacement
    within range.
    """

    def __init__(self, tciexpansions, cell_cv, spos_ac, pbc_c, ibzk_qc,
                 dtype):
        self.tciexpansions = tciexpansions
        self.dtype = dtype

        I_a = tciexpansions.I_a

        def by_atom(values_I):
            # Expand a per-species list into a per-atom list.
            return [values_I[I] for I in I_a]

        # XXX It is somewhat nasty that rcmax depends on how long our
        # longest orbital happens to be
        # Cutoffs by atom:
        cutoff_a = by_atom(tciexpansions.rcmax_I)
        self.pt_rcmax_a = np.array(by_atom(tciexpansions.pt_rcmax_I))
        self.phit_rcmax_a = np.array(by_atom(tciexpansions.phit_rcmax_I))

        self.a1a2 = AtomPairRegistry(cutoff_a, pbc_c, cell_cv, spos_ac)

        self.ibzk_qc = ibzk_qc
        # Any nonzero k-point component selects BlochPhases.
        self.get_phases = BlochPhases if ibzk_qc.any() else NullPhases

        self.O_T = self._tci_shortcut(False, False)
        self.P = self._tci_shortcut(True, False)
        self.dOdR_dTdR = self._tci_shortcut(False, True)
        self.dPdR = self._tci_shortcut(True, True)

    def _tci_shortcut(self, P, derivative):
        """Return a function of (a1, a2) with P and derivative fixed."""
        return lambda a1, a2: self._calculate(a1, a2, P, derivative)

    def _calculate(self, a1, a2, P=False, derivative=False):
        """Calculate overlap of functions between atoms a1 and a2.

        With ``P`` true, the projectors on a1 are paired with the basis
        functions on a2 and a single array is returned.  Otherwise basis
        functions are used on both atoms and an (overlap, kinetic) pair
        is returned.  ``derivative`` selects DerivativeAtomicDisplacement
        and adds an axis of length 3 after the k-point axis.

        Returns None (or (None, None) when ``P`` is false) if no
        displacement between the atoms is shorter than the sum of the two
        relevant cutoffs.
        """
        no_overlap = None if P else (None, None)

        # We want to see quickly if there is no overlap because distance
        # outside bounding spheres.
        R_c_and_offset_a = self.a1a2.get(a1, a2)
        if R_c_and_offset_a is None:
            return no_overlap

        left_rcmax_a = self.pt_rcmax_a if P else self.phit_rcmax_a
        maxdist = left_rcmax_a[a1] + self.phit_rcmax_a[a2]

        # Keep only displacements shorter than maxdist:
        in_range = [entry for entry in R_c_and_offset_a
                    if np.linalg.norm(entry[0]) < maxdist]
        if not in_range:  # There was no overlap after all
            return no_overlap

        if derivative:
            make_displacement = DerivativeAtomicDisplacement
        else:
            make_displacement = AtomicDisplacement

        ibzk_qc = self.ibzk_qc
        nq = len(ibzk_qc)
        shape = (nq, 3) if derivative else (nq,)

        if P:
            expansions = [self.tciexpansions.P_expansions.get(a1, a2)]
        else:
            expansions = [self.tciexpansions.O_expansions.get(a1, a2),
                          self.tciexpansions.T_expansions.get(a1, a2)]
        arrays = [expansion.zeros(shape, dtype=self.dtype)
                  for expansion in expansions]

        for R_c, offset in in_range:
            phases = self.get_phases(ibzk_qc, offset)
            disp = make_displacement(None, a1, a2, R_c, offset, phases)

            # Sanity check; guaranteed by the filtering above.
            assert np.linalg.norm(R_c) < maxdist

            # Each output array is handed to evaluate_overlap once per
            # displacement (presumably accumulating contributions;
            # see gpaw.lcao.overlap).
            for expansion, array in zip(expansions, arrays):
                disp.evaluate_overlap(expansion, array)

        return arrays[0] if P else tuple(arrays)

226 

227 

class ManyTCICalculator:
    """Assemble two-center integrals for many atoms into larger arrays.

    Wraps the calculator returned by ``tciexpansions.get_tci_calculator``
    and uses the index objects from ``setups.projector_indices()`` and
    ``setups.basis_indices()`` to place per-pair blocks.
    """

    def __init__(self, tciexpansions, setups, gd, spos_ac, ibzk_qc, dtype,
                 timer):
        self.tci = tciexpansions.get_tci_calculator(gd.cell_cv, spos_ac,
                                                    gd.pbc_c,
                                                    ibzk_qc, dtype)

        self.setups = setups
        self.dtype = dtype
        self.Pindices = setups.projector_indices()
        self.Mindices = setups.basis_indices()
        self.natoms = len(setups)
        self.nq = len(ibzk_qc)
        self.nao = self.Mindices.max
        self.timer = timer

    # @timer('tci-projectors')
    def P_aqMi(self, my_atom_indices, derivative=False):
        """Return a dict of projector/basis arrays, one per listed atom.

        For each a1 in ``my_atom_indices`` the array has shape
        (nq, nao, ni), or (nq, 3, nao, ni) if ``derivative`` is true,
        where ni is ``self.setups[a1].ni``.  The block of rows belonging
        to atom a2 holds the conjugate of the pair result with its last
        two axes swapped, or zeros where the pair result is None.
        Derivative arrays are multiplied by -1.
        """
        if derivative:
            overlap = self.tci.dPdR
            leading_shape = (self.nq, 3, self.nao)
        else:
            overlap = self.tci.P
            leading_shape = (self.nq, self.nao)

        Mindices = self.Mindices
        P_axMi = {}

        for a1 in my_atom_indices:
            ni = self.setups[a1].ni
            P_xMi = np.empty(leading_shape + (ni,), self.dtype)

            for a2 in range(self.natoms):
                N1, N2 = Mindices[a2]
                P_xim = overlap(a1, a2)
                if P_xim is None:
                    P_xMi[..., N1:N2, :] = 0.0
                else:
                    P_xMi[..., N1:N2, :] = P_xim.swapaxes(-2, -1).conj()
            P_axMi[a1] = P_xMi

        if derivative:
            # Flip the sign in place.
            for P_xMi in P_axMi.values():
                P_xMi *= -1.0
        return P_axMi

    # @timer('tci-sparseprojectors')
    def P_qIM(self, my_atom_indices):
        """Return one sparse CSR projector/basis matrix per k-point.

        Each matrix has shape (Pindices.max, Mindices.max).  Only pairs
        (a1, a2) with a1 in ``my_atom_indices`` and a non-None result
        are filled in.
        """
        overlap = self.tci.P
        shape = (self.Pindices.max, self.Mindices.max)
        P_qIM = [sparse.dok_matrix(shape, dtype=self.dtype)
                 for _ in range(self.nq)]

        for a1 in my_atom_indices:
            I1, I2 = self.Pindices[a1]

            # We can stride a2 over e.g. bd.comm and then do bd.comm.sum().
            # How should we do comm.sum() on a sparse matrix though?
            for a2 in range(self.natoms):
                M1, M2 = self.Mindices[a2]
                P_qim = overlap(a1, a2)
                if P_qim is None:
                    continue
                for q, P_IM in enumerate(P_qIM):
                    P_IM[I1:I2, M1:M2] = P_qim[q]

        return [P_IM.tocsr() for P_IM in P_qIM]

    # @timer('tci-bfs')
    def O_qMM_T_qMM(self, gdcomm, Mstart, Mstop, ignore_upper=False,
                    derivative=False):
        """Return overlap and kinetic arrays for rows Mstart:Mstop.

        Both arrays have shape (nq, Mstop - Mstart, nao), with an extra
        axis of length 3 after the first if ``derivative`` is true.  Only
        pairs with a2 <= a1 are computed, and a2 is strided over
        ``gdcomm.rank`` / ``gdcomm.size``.  Unless ``ignore_upper`` is
        true (or the arrays are empty), the full row range must be
        requested and each matrix is then passed to
        ``tri2full(..., UL='L', map=...)``, where the map is ``np.conj``
        or, for derivatives, a negated conjugate.
        """
        mynao = Mstop - Mstart
        Mindices = self.Mindices

        if derivative:
            pair_O_T = self.tci.dOdR_dTdR
            shape = (self.nq, 3, mynao, self.nao)
        else:
            pair_O_T = self.tci.O_T
            shape = (self.nq, mynao, self.nao)

        O_xMM = np.zeros(shape, self.dtype)
        T_xMM = np.zeros(shape, self.dtype)

        # XXX the a1/a2 loops are not yet well load balanced.
        for a1 in range(self.natoms):
            M1, M2 = Mindices[a1]
            if M2 <= Mstart or M1 >= Mstop:
                continue  # atom has no rows in the requested range

            # Row range of this atom relative to Mstart, clipped to the
            # requested window.
            myM1 = max(M1 - Mstart, 0)
            myM2 = min(M2 - Mstart, mynao)
            nM = myM2 - myM1
            assert nM > 0, nM

            # First row of the atom's own block that lies inside the window.
            m1 = max(Mstart - M1, 0)
            m2 = m1 + nM  # (Slice may go beyond end of matrix but OK)

            a2max = a1 + 1  # if not derivative else self.natoms

            for a2 in range(gdcomm.rank, a2max, gdcomm.size):
                O_xmm, T_xmm = pair_O_T(a1, a2)
                if O_xmm is None:
                    continue

                N1, N2 = Mindices[a2]
                O_xMM[..., myM1:myM2, N1:N2] = O_xmm[..., m1:m2, :]
                T_xMM[..., myM1:myM2, N1:N2] = T_xmm[..., m1:m2, :]

        # reshape() fails on size-0 arrays
        if not ignore_upper and O_xMM.size:
            assert mynao == self.nao
            assert O_xMM.shape[-2:] == (self.nao, self.nao)

            if derivative:
                def lumap(arr, out):
                    np.conj(arr, out)
                    out *= -1.0
            else:
                lumap = np.conj

            for arr_xMM in (O_xMM, T_xMM):
                for tmp_MM in arr_xMM.reshape(-1, self.nao, self.nao):
                    tri2full(tmp_MM, UL='L', map=lumap)

        return O_xMM, T_xMM