Coverage for gpaw/pair_overlap.py: 10%

243 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-08 00:17 +0000

1import numpy as np 

2 

3from gpaw import debug 

4from gpaw.mpi import world 

5from gpaw.overlap import Overlap 

6from gpaw.utilities import unpack_hermitian 

7from gpaw.lfc import LocalizedFunctionsCollection as LFC 

8 

9 

def mpi_debug(x, ordered=True):
    """Debug-message hook used by the overlap routines; currently silenced.

    Parameters
    ----------
    x : object
        The message that would be reported.
    ordered : bool
        Accepted for call compatibility; ignored.

    Returns
    -------
    None
        Always; the hook deliberately does nothing.
    """
    # Output intentionally disabled -- both arguments are ignored.
    return

12 

13 

class PairOverlap:
    """Base class for overlap matrices between pairs of atom-centered functions.

    Pair matrices are stored as one dense ``(nproj, nproj)`` array.  The
    row/column block belonging to atom ``a`` spans ``ni_a[a]:ni_a[a + 1]``,
    where ``ni_a`` is the cumulative sum of ``setup.ni`` over ``setups``.
    Subclasses implement :meth:`calculate_overlaps`.
    """

    def __init__(self, gd, setups):
        self.gd = gd
        self.setups = setups
        # Block offsets: atom a owns indices ni_a[a]:ni_a[a + 1].
        self.ni_a = np.cumsum([0] + [setup.ni for setup in self.setups])

    def __len__(self):
        """Return the total number of functions ``sum(setup.ni)``."""
        return self.ni_a[-1].item()

    def _pair_slices(self, a1, a2):
        """Return the (row, column) slices of the (a1, a2) block."""
        return (slice(self.ni_a[a1], self.ni_a[a1 + 1]),
                slice(self.ni_a[a2], self.ni_a[a2 + 1]))

    def assign_atomic_pair_matrix(self, X_aa, a1, a2, dX_ii):
        """Write ``dX_ii`` into the (a1, a2) block of ``X_aa`` in place."""
        X_aa[self._pair_slices(a1, a2)] = dX_ii

    def extract_atomic_pair_matrix(self, X_aa, a1, a2):
        """Return the (a1, a2) block of ``X_aa``.

        For a NumPy array this is a basic-slicing view, so in-place updates
        of the result (e.g. ``+=``) modify ``X_aa`` itself.
        """
        return X_aa[self._pair_slices(a1, a2)]

    def calculate_overlaps(self, spos_ac, lfc1, lfc2=None):
        """Compute the pair overlaps of lfc1 with lfc2 (lfc1 if None).

        Must be implemented by subclasses.  ``NotImplementedError`` is a
        subclass of ``RuntimeError``, so callers catching the latter still
        work.
        """
        raise NotImplementedError('This is a virtual member function.')

    def calculate_atomic_pair_overlaps(
            self, lfs1, lfs2):  # XXX Move some code here from above...
        """Compute the overlaps of two atoms' functions; subclass hook."""
        raise NotImplementedError('This is a virtual member function.')

37 

38 

class GridPairOverlap(PairOverlap):
    """Pair overlaps obtained by direct summation over the real-space grid.

    For each atom pair, contributions from every combination of the grid
    boxes returned by ``gd.get_boxes`` around the two atoms are accumulated
    into the corresponding pair block.  The resulting matrix is then summed
    over ``gd.comm``.
    """

    @staticmethod
    def _box_window(beg_c, end_c, origin_c):
        """Index tuple selecting grid points beg_c:end_c of a box at origin_c.

        The leading (function) axis is kept whole.  A *tuple* is required:
        NumPy >= 1.23 treats a list of slices as an index array and fails.
        """
        return (slice(None),) + tuple(
            slice(b, e) for b, e in zip(beg_c - origin_c, end_c - origin_c))

    def calculate_overlaps(self, spos_ac, lfc1, lfc2=None):
        """Return the grid overlap matrix between lfc1 and lfc2.

        Parameters
        ==========
        spos_ac: ndarray
            Scaled atomic positions, indexed by atom.
        lfc1: localized functions
            Functions on the bra side.
        lfc2: localized functions, optional
            Functions on the ket side; defaults to ``lfc1``.

        If both are ``LocalizedFunctionsCollection`` instances, the work is
        delegated to :meth:`calculate_overlaps2`.  Otherwise neither may be
        one, and both must provide ``spline_aj``, ``spos_ac`` and ``lfs_a``.

        Returns
        =======
        X_aa: ndarray
            Float array of shape ``(len(self), len(self))``.  Its (a1, a2)
            block holds the overlaps of atom a1's functions in lfc1 with
            atom a2's functions in lfc2, summed over ``gd.comm``.
        """
        # CONDITION: The two sets of splines must belong to the same kpoint!

        if lfc2 is None:
            lfc2 = lfc1

        if isinstance(lfc1, LFC) and isinstance(lfc2, LFC):
            return self.calculate_overlaps2(spos_ac, lfc1, lfc2)

        assert not isinstance(lfc1, LFC) and not isinstance(lfc2, LFC)

        nproj = len(self)
        X_aa = np.zeros((nproj, nproj), dtype=float)  # XXX always float?

        if debug:
            if world.rank == 0:
                print('DEBUG INFO')

            mpi_debug('lfc1.lfs_a.keys(): %s' % lfc1.lfs_a.keys())
            mpi_debug('lfc2.lfs_a.keys(): %s' % lfc2.lfs_a.keys())
            mpi_debug('N_c=%s, beg_c=%s, end_c=%s' %
                      (self.gd.N_c, self.gd.beg_c, self.gd.end_c))

        assert len(lfc1.spline_aj) == len(lfc1.spos_ac)  # not distributed
        assert len(lfc2.spline_aj) == len(lfc2.spos_ac)  # not distributed
        # assert lfc1.lfs_a.keys() == lfc2.lfs_a.keys()
        # XXX must they be equal?!?

        # Both loops are over all atoms in all domains
        for a1, spline1_j in enumerate(lfc1.spline_aj):
            # We assume that all functions have the same cut-off:
            rcut1 = spline1_j[0].get_cutoff()
            if debug:
                mpi_debug('a1=%d, spos1_c=%s, rcut1=%g, ni1=%d' %
                          (a1, spos_ac[a1], rcut1, self.setups[a1].ni))

            for a2, spline2_j in enumerate(lfc2.spline_aj):
                # We assume that all functions have the same cut-off:
                rcut2 = spline2_j[0].get_cutoff()
                if debug:
                    mpi_debug(' a2=%d, spos2_c=%s, rcut2=%g, ni2=%d' %
                              (a2, spos_ac[a2], rcut2, self.setups[a2].ni))

                # View into X_aa: the in-place additions below fill it.
                X_ii = self.extract_atomic_pair_matrix(X_aa, a1, a2)

                # loop over lfs1.box_b instead?
                for b1, (beg1_c, end1_c, sdisp1_c) in enumerate(
                        self.gd.get_boxes(spos_ac[a1], rcut1, cut=False)):
                    if debug:
                        mpi_debug(
                            ' b1=%d, beg1_c=%s, end1_c=%s, sdisp1_c=%s' %
                            (b1, beg1_c, end1_c, sdisp1_c),
                            ordered=False)

                    # Atom a1 has at least one piece so the LFC has LocFuncs
                    lfs1 = lfc1.lfs_a[a1]

                    # Similarly, the LocFuncs must have the piece at hand
                    box1 = lfs1.box_b[b1]

                    if debug:
                        assert lfs1.dtype == lfc1.dtype
                        assert self.setups[a1].ni == lfs1.ni, (
                            'setups[%d].ni=%d, lfc1.lfs_a[%d].ni=%d'
                            % (a1, self.setups[a1].ni, a1, lfs1.ni))

                    # loop over lfs2.box_b instead?
                    for b2, (beg2_c, end2_c, sdisp2_c) in enumerate(
                            self.gd.get_boxes(spos_ac[a2], rcut2,
                                              cut=False)):
                        if debug:
                            mpi_debug(
                                ' b2=%d, beg2_c=%s, end2_c=%s, sdisp2'
                                'c=%s' % (b2, beg2_c, end2_c, sdisp2_c),
                                ordered=False)

                        # Atom a2 has at least one piece so the LFC has
                        # LocFuncs
                        lfs2 = lfc2.lfs_a[a2]

                        # Similarly, the LocFuncs must have the piece at hand
                        box2 = lfs2.box_b[b2]

                        if debug:
                            assert lfs2.dtype == lfc2.dtype
                            assert self.setups[a2].ni == lfs2.ni, (
                                'setups[%d].ni=%d, lfc2.lfs_a[%d].ni=%d'
                                % (a2, self.setups[a2].ni, a2, lfs2.ni))

                        # Find the intersection of the two boxes
                        beg_c = np.maximum(beg1_c, beg2_c)
                        end_c = np.minimum(end1_c, end2_c)

                        if debug:
                            mpi_debug(' beg_c=%s, end_c=%s, size_c=%s' %
                                      (beg_c, end_c, tuple(end_c - beg_c)),
                                      ordered=False)

                        # Intersection is non-empty, add overlap contribution
                        if (beg_c < end_c).all():
                            bra_iB1 = box1.get_functions()
                            w1slice = self._box_window(beg_c, end_c, beg1_c)

                            ket_iB2 = box2.get_functions()
                            w2slice = self._box_window(beg_c, end_c, beg2_c)

                            X_ii += self.gd.dv * np.inner(
                                bra_iB1[w1slice].reshape((lfs1.ni, -1)),
                                ket_iB2[w2slice].reshape((lfs2.ni, -1)))
                            # XXX phase factors for kpoints

                            del bra_iB1, ket_iB2

        self.gd.comm.sum(X_aa)  # better to sum over X_ii?
        return X_aa

    def calculate_overlaps2(self, spos_ac, lfc1, lfc2=None):
        """Grid overlap matrix for ``LocalizedFunctionsCollection`` inputs.

        This has the same contract as :meth:`calculate_overlaps`, but it
        loops over ``lfc.atom_indices``.  Each spline is evaluated directly
        on the intersection of the two boxes with ``spline.get_functions``.

        Returns
        =======
        X_aa: ndarray
            Float array of shape ``(len(self), len(self))``, summed over
            ``gd.comm``.
        """
        # CONDITION: The two sets of splines must belong to the same kpoint!

        if lfc2 is None:
            lfc2 = lfc1

        assert isinstance(lfc1, LFC) and isinstance(lfc2, LFC)

        nproj = len(self)
        X_aa = np.zeros((nproj, nproj), dtype=float)  # XXX always float?

        if debug:
            if world.rank == 0:
                print('DEBUG INFO')

            mpi_debug('len(lfc1.sphere_a): %d, lfc1.atom_indices: %s' %
                      (len(lfc1.sphere_a), lfc1.atom_indices))
            mpi_debug('len(lfc2.sphere_a): %d, lfc2.atom_indices: %s' %
                      (len(lfc2.sphere_a), lfc2.atom_indices))
            mpi_debug('N_c=%s, beg_c=%s, end_c=%s' %
                      (self.gd.N_c, self.gd.beg_c, self.gd.end_c))

            # XXX must they be equal?!?
            assert len(lfc1.sphere_a) == len(lfc2.sphere_a)

        # Both a-loops are over all relevant atoms which affect this domain
        for a1 in lfc1.atom_indices:
            sphere1 = lfc1.sphere_a[a1]

            # We assume that all functions have the same cut-off:
            spline1_j = sphere1.spline_j
            rcut1 = spline1_j[0].get_cutoff()
            if debug:
                mpi_debug('a1=%d, spos1_c=%s, rcut1=%g, ni1=%d' %
                          (a1, spos_ac[a1], rcut1, self.setups[a1].ni),
                          ordered=False)

            for a2 in lfc2.atom_indices:
                sphere2 = lfc2.sphere_a[a2]

                # We assume that all functions have the same cut-off:
                spline2_j = sphere2.spline_j
                rcut2 = spline2_j[0].get_cutoff()
                if debug:
                    mpi_debug(' a2=%d, spos2_c=%s, rcut2=%g, ni2=%d' %
                              (a2, spos_ac[a2], rcut2, self.setups[a2].ni),
                              ordered=False)

                # View into X_aa: the in-place additions below fill it.
                X_ii = self.extract_atomic_pair_matrix(X_aa, a1, a2)

                # loop over lfs1.box_b instead?
                for b1, (beg1_c, end1_c, sdisp1_c) in enumerate(
                        self.gd.get_boxes(spos_ac[a1], rcut1, cut=False)):
                    if debug:
                        mpi_debug(
                            ' b1=%d, beg1_c=%s, end1_c=%s, sdisp1_c=%s' %
                            (b1, beg1_c, end1_c, sdisp1_c),
                            ordered=False)

                    # loop over lfs2.box_b instead?
                    for b2, (beg2_c, end2_c, sdisp2_c) in enumerate(
                            self.gd.get_boxes(spos_ac[a2], rcut2,
                                              cut=False)):
                        if debug:
                            mpi_debug(
                                ' b2=%d, beg2_c=%s, end2_c=%s,'
                                'sdisp2_c=%s' % (b2, beg2_c, end2_c,
                                                 sdisp2_c),
                                ordered=False)

                        # Find the intersection of the two boxes
                        beg_c = np.maximum(beg1_c, beg2_c)
                        end_c = np.minimum(end1_c, end2_c)

                        if debug:
                            mpi_debug(' beg_c=%s, end_c=%s, size_c=%s' %
                                      (beg_c, end_c, tuple(end_c - beg_c)),
                                      ordered=False)

                        # Intersection is non-empty, add overlap contribution
                        if (beg_c < end_c).all():
                            # i1/i2 track the row/column offset of each
                            # spline's functions within the atom block.
                            i1 = 0
                            for spline1 in spline1_j:
                                bra1_mB = spline1.get_functions(
                                    self.gd, beg_c, end_c,
                                    spos_ac[a1] - sdisp1_c)
                                nm1 = bra1_mB.shape[0]

                                i2 = 0
                                for spline2 in spline2_j:
                                    ket2_mB = spline2.get_functions(
                                        self.gd, beg_c, end_c,
                                        spos_ac[a2] - sdisp2_c)
                                    nm2 = ket2_mB.shape[0]

                                    X_mm = X_ii[i1:i1 + nm1, i2:i2 + nm2]
                                    X_mm += self.gd.dv * np.inner(
                                        bra1_mB.reshape((nm1, -1)),
                                        ket2_mB.reshape((nm2, -1)))
                                    # XXX phase factors for kpoints

                                    del ket2_mB
                                    i2 += nm2

                                del bra1_mB
                                i1 += nm1

        self.gd.comm.sum(X_aa)  # better to sum over X_ii?
        return X_aa

284 

285 

class ProjectorPairOverlap(Overlap, GridPairOverlap):
    """Overlap operator with explicit two-center (projector-pair) terms.

    On top of :class:`Overlap`, this class keeps dense ``(nproj, nproj)``
    two-center matrices.  They are built from the grid overlaps of the
    projector functions ``wfs.pt`` and the per-atom ``setup.dO_ii``
    coefficients.  They are used to apply the overlap operator and its
    inverse.
    """

    def __init__(self, wfs, atoms):
        """Initialize from ``wfs`` and build the two-center matrices.

        Attributes:

        ============ ======================================================
        ``B_aa``     < p_i^a | p_i'^a' >
        ``xO_aa``    ``B_aa . dO_aa`` (rotated overlap coefficients)
        ``dC_aa``    solution of ``dC_aa (1 + xO_aa) = -dO_aa``
                     (two-center coefficients of the inverse overlap)
        ``xC_aa``    ``B_aa . dC_aa`` (rotated inverse coefficients)
        ============ ======================================================
        """

        Overlap.__init__(self, wfs.orthoksl, wfs.timer)
        GridPairOverlap.__init__(self, wfs.gd, wfs.setups)
        self.natoms = len(atoms)
        if debug:
            assert len(self.setups) == self.natoms
        self.update(wfs, atoms)

    def update(self, wfs, atoms):
        """Recompute ``B_aa``, ``xO_aa``, ``dC_aa`` and ``xC_aa``.

        Projector positions are taken from ``wfs.spos_ac``.  ``atoms`` is
        accepted for interface compatibility but is not used here.
        """
        self.timer.start('Update two-center overlap')

        nproj = len(self)

        # Projector-projector overlaps < p_i^a1 | p_i'^a2 > on the grid.
        self.B_aa = self.calculate_overlaps(wfs.spos_ac, wfs.pt)

        # Create two-center (block-diagonal) coefficients for overlap operator
        dO_aa = np.zeros((nproj, nproj), dtype=float)  # always float?
        for a, setup in enumerate(self.setups):
            self.assign_atomic_pair_matrix(dO_aa, a, a, setup.dO_ii)

        # Calculate two-center rotation matrix for overlap projections
        self.xO_aa = self.get_rotated_coefficients(dO_aa)

        # Calculate two-center coefficients for inverse overlap operator:
        # solve dC_aa (1 + xO_aa) = -dO_aa via the transposed system.
        lhs_aa = np.eye(nproj) + self.xO_aa
        rhs_aa = -dO_aa
        self.dC_aa = np.linalg.solve(lhs_aa.T, rhs_aa.T).T  # TODO parallel

        # Calculate two-center rotation matrix for inverse overlap projections
        self.xC_aa = self.get_rotated_coefficients(self.dC_aa)

        self.timer.stop('Update two-center overlap')

    def get_rotated_coefficients(self, X_aa):
        r"""Rotate two-center projector expansion coefficients with
        the projector-projector overlap integrals as basis.

        Performs the following operation and returns the result::

                       ---
             a1,a3     \      a1      a2      a2,a3
            Y       =   )   < p    |  p    >  X
             i1,i3     /      i1      i2      i2,i3
                       ---
                      a2,i2

        i.e. the matrix product ``B_aa . X_aa``.
        """
        return np.dot(self.B_aa, X_aa)

    def _add_pair_products(self, C_aa, P_axi, Q_axi, shape, dtype):
        """Add sum_a2 P_axi[a2] . C_{a1,a2}^T to Q_axi[a1] for every atom a1.

        Atoms not present in ``Q_axi`` get a temporary buffer.  That way
        ``gd.comm.sum`` is called once per atom, in the same order,
        whichever atoms are held locally.  Such buffers are discarded
        after the sum.
        """
        for a1 in range(self.natoms):
            if a1 in Q_axi:
                Q_xi = Q_axi[a1]
            else:
                # Atom a1 is not in domain so allocate a temporary buffer
                Q_xi = np.zeros(shape + (self.setups[a1].ni, ),
                                dtype=dtype)  # TODO
            for a2, P_xi in P_axi.items():
                C_ii = self.extract_atomic_pair_matrix(C_aa, a1, a2)
                # sum over a2 and last i in C_ii
                Q_xi += np.dot(P_xi, C_ii.T)
            self.gd.comm.sum(Q_xi)

    def apply_to_atomic_matrices(self, dI_asp, P_axi, wfs, kpt, shape=()):
        """Contract projections with two-center-rotated atomic matrices.

        First, build a block-diagonal matrix dI_aa from the packed per-atom
        matrices ``dI_asp[a][kpt.s]`` and rotate it with ``B_aa``.  Then
        return, for the atoms local to ``wfs.pt``,
        ``Q_axi[a1] = sum_a2 P_axi[a2] . (B_aa . dI_aa)_{a1,a2}^T``.
        """

        self.timer.start('Update two-center projections')

        nproj = len(self)
        dI_aa = np.zeros((nproj, nproj), dtype=float)  # always float?

        for a, dI_sp in dI_asp.items():
            dI_ii = unpack_hermitian(dI_sp[kpt.s])
            self.assign_atomic_pair_matrix(dI_aa, a, a, dI_ii)
        self.gd.comm.sum(dI_aa)  # TODO too heavy?

        dM_aa = self.get_rotated_coefficients(dI_aa)
        Q_axi = wfs.pt.dict(shape, zero=True)
        self._add_pair_products(dM_aa, P_axi, Q_axi, shape, wfs.pt.dtype)

        self.timer.stop('Update two-center projections')

        return Q_axi

    def apply(self,
              a_xG,
              b_xG,
              wfs,
              kpt,
              calculate_P_ani=True,
              extrapolate_P_ani=False):
        """Apply the overlap operator to a set of vectors.

        Parameters
        ==========
        a_xG: ndarray
            Set of vectors to which the overlap operator is applied.
        b_xG: ndarray, output
            Resulting S times a_xG vectors.
        wfs: wave-functions object
            Provides the projectors ``wfs.pt``.
        kpt: KPoint object
            k-point object defined in kpoint.py.
        calculate_P_ani: bool
            When True, the integrals of projector times vectors
            P_ni = <p_i | a_nG> are calculated.
            When False, existing kpt.P_ani are used.
        extrapolate_P_ani: bool
            When True, return the projections extrapolated with the
            two-center matrix xO_aa: P_axi[a1] + sum_a2 P_axi[a2] .
            xO_{a1,a2}^T for each local atom a1.
            When False, return the projections P_axi themselves.

        Returns
        =======
        Dict of per-atom projection arrays for the atoms local to
        ``wfs.pt``, as selected by ``extrapolate_P_ani``.
        """

        self.timer.start('Apply overlap')
        b_xG[:] = a_xG
        shape = a_xG.shape[:-3]
        P_axi = wfs.pt.dict(shape)

        if calculate_P_ani:
            wfs.pt.integrate(a_xG, P_axi, kpt.q)
        else:
            for a, P_ni in kpt.P_ani.items():
                P_axi[a][:] = P_ni

        Q_axi = wfs.pt.dict(shape)
        for a, Q_xi in Q_axi.items():
            Q_xi[:] = np.dot(P_axi[a], self.setups[a].dO_ii)

        wfs.pt.add(b_xG, Q_axi, kpt.q)
        self.timer.stop('Apply overlap')

        if not extrapolate_P_ani:
            return P_axi

        # Reuse Q_axi: start from the projections themselves, then add the
        # cross-atom extrapolation through xO_aa.
        for a, Q_xi in Q_axi.items():
            Q_xi[:] = P_axi[a]
        self._add_pair_products(self.xO_aa, P_axi, Q_axi, shape,
                                wfs.pt.dtype)
        return Q_axi

    def apply_inverse(self,
                      a_xG,
                      b_xG,
                      wfs,
                      kpt,
                      calculate_P_ani=True,
                      extrapolate_P_ani=False):
        """Apply the inverse overlap operator to a set of vectors.

        ``b_xG`` becomes ``a_xG`` plus the projector expansion with the
        two-center coefficients ``dC_aa``.  The parameters are as for
        :meth:`apply`.  When ``extrapolate_P_ani`` is True, the returned
        projections are extrapolated with ``xC_aa`` instead of ``xO_aa``;
        otherwise the projections ``P_axi`` of ``a_xG`` are returned.
        """

        self.timer.start('Apply inverse overlap')
        b_xG[:] = a_xG
        shape = a_xG.shape[:-3]
        P_axi = wfs.pt.dict(shape)

        if calculate_P_ani:
            wfs.pt.integrate(a_xG, P_axi, kpt.q)
        else:
            for a, P_ni in kpt.P_ani.items():
                P_axi[a][:] = P_ni

        # dC_aa are the inverse coefficients across atomic pairs
        Q_axi = wfs.pt.dict(shape, zero=True)
        self._add_pair_products(self.dC_aa, P_axi, Q_axi, shape,
                                wfs.pt.dtype)

        wfs.pt.add(b_xG, Q_axi, kpt.q)
        self.timer.stop('Apply inverse overlap')

        if not extrapolate_P_ani:
            return P_axi

        # Reuse Q_axi: start from the projections themselves, then add the
        # cross-atom extrapolation through xC_aa.
        for a, Q_xi in Q_axi.items():
            Q_xi[:] = P_axi[a]
        self._add_pair_products(self.xC_aa, P_axi, Q_axi, shape,
                                wfs.pt.dtype)
        return Q_axi