Coverage for gpaw/pair_overlap.py: 10%
243 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-08 00:17 +0000
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-08 00:17 +0000
1import numpy as np
3from gpaw import debug
4from gpaw.mpi import world
5from gpaw.overlap import Overlap
6from gpaw.utilities import unpack_hermitian
7from gpaw.lfc import LocalizedFunctionsCollection as LFC
def mpi_debug(x, ordered=True):
    """Debug printer that has been deliberately silenced.

    Both arguments are accepted and ignored so that existing call sites
    can stay in place; nothing is printed and ``None`` is returned.
    """
    return
class PairOverlap:
    """Base class for matrices coupling the atomic functions of atom pairs.

    The matrices handled here are square. Rows and columns are the
    per-atom function indices ``i`` (``setup.ni`` of them per atom),
    stacked atom by atom in the order of ``setups``.
    """

    def __init__(self, gd, setups):
        self.gd = gd
        self.setups = setups
        # Block offsets: atom a occupies rows/columns ni_a[a]:ni_a[a + 1].
        self.ni_a = np.cumsum([0] + [setup.ni for setup in self.setups])

    def __len__(self):
        """Return the total number of atomic function indices."""
        return self.ni_a[-1].item()

    def _atomic_pair_slices(self, a1, a2):
        """Return the (row, column) index tuple of the (a1, a2) block."""
        return (slice(self.ni_a[a1], self.ni_a[a1 + 1]),
                slice(self.ni_a[a2], self.ni_a[a2 + 1]))

    def assign_atomic_pair_matrix(self, X_aa, a1, a2, dX_ii):
        """Write ``dX_ii`` into the (a1, a2) block of ``X_aa`` in place."""
        X_aa[self._atomic_pair_slices(a1, a2)] = dX_ii

    def extract_atomic_pair_matrix(self, X_aa, a1, a2):
        """Return the (a1, a2) block of ``X_aa``.

        For an ndarray this is a view (basic slicing), so in-place updates
        of the result write through to ``X_aa``.
        """
        return X_aa[self._atomic_pair_slices(a1, a2)]

    def calculate_overlaps(self, spos_ac, lfc1, lfc2=None):
        """Return the full pair-overlap matrix; must be overridden.

        Raises NotImplementedError (a subclass of RuntimeError, so existing
        ``except RuntimeError`` handlers keep working).
        """
        raise NotImplementedError('This is a virtual member function.')

    def calculate_atomic_pair_overlaps(
            self, lfs1, lfs2):  # XXX Move some code here from above...
        """Return overlaps for a single atom pair; must be overridden."""
        raise NotImplementedError('This is a virtual member function.')
class GridPairOverlap(PairOverlap):
    """Pair overlaps of localized functions integrated on the real-space grid.

    Both variants loop over every pair of atoms and every pair of boxes
    returned by ``gd.get_boxes`` (each box comes with a displacement
    ``sdisp_c``, presumably one per periodic image -- confirm). They
    accumulate ``dv * np.inner(f1, f2)`` over the intersection of the two
    boxes and finally sum the result over ``gd.comm``. No complex
    conjugation or k-point phase factors are applied (see the XXX notes).
    """

    @staticmethod
    def _box_intersection(beg1_c, end1_c, beg2_c, end2_c):
        """Return ``(beg_c, end_c)`` bounding the intersection of two boxes.

        The intersection is non-empty only if ``(beg_c < end_c).all()``.
        """
        beg_c = np.max((beg1_c, beg2_c), axis=0)
        end_c = np.min((end1_c, end2_c), axis=0)
        return beg_c, end_c

    @staticmethod
    def _box_window(beg_c, end_c, boxbeg_c):
        """Return an index tuple selecting grid points beg_c:end_c of a box.

        The box starts at ``boxbeg_c``. The leading ``slice(None)`` keeps
        the function-index axis. This must be a tuple: NumPy no longer
        accepts a list of slices as a multidimensional index.
        """
        return (slice(None),) + tuple(
            slice(b, e) for b, e in zip(beg_c - boxbeg_c, end_c - boxbeg_c))

    def calculate_overlaps(self, spos_ac, lfc1, lfc2=None):
        """Return the grid-integrated overlap matrix of two function sets.

        Parameters
        ----------
        spos_ac:
            Scaled atomic positions, indexed by atom.
        lfc1, lfc2:
            Collections of localized functions; ``lfc2`` defaults to
            ``lfc1``. If both are ``LFC`` instances the work is delegated
            to ``calculate_overlaps2``; otherwise neither may be one.

        Returns
        -------
        X_aa: ndarray
            Float array of shape ``(len(self), len(self))`` whose (a1, a2)
            block holds the overlaps between the functions of a1 and a2,
            summed over ``gd.comm``.
        """
        # CONDITION: The two sets of splines must belong to the same kpoint!
        if lfc2 is None:
            lfc2 = lfc1

        if isinstance(lfc1, LFC) and isinstance(lfc2, LFC):
            return self.calculate_overlaps2(spos_ac, lfc1, lfc2)

        assert not isinstance(lfc1, LFC) and not isinstance(lfc2, LFC)

        nproj = len(self)
        X_aa = np.zeros((nproj, nproj), dtype=float)  # XXX always float?

        if debug:
            if world.rank == 0:
                print('DEBUG INFO')

            mpi_debug('lfc1.lfs_a.keys(): %s' % lfc1.lfs_a.keys())
            mpi_debug('lfc2.lfs_a.keys(): %s' % lfc2.lfs_a.keys())
            mpi_debug('N_c=%s, beg_c=%s, end_c=%s' %
                      (self.gd.N_c, self.gd.beg_c, self.gd.end_c))

        assert len(lfc1.spline_aj) == len(lfc1.spos_ac)  # not distributed
        assert len(lfc2.spline_aj) == len(lfc2.spos_ac)  # not distributed
        # assert lfc1.lfs_a.keys() == lfc2.lfs_a.keys()
        # XXX must they be equal?!?

        # Both loops are over all atoms in all domains
        for a1, spline1_j in enumerate(lfc1.spline_aj):
            # We assume that all functions have the same cut-off:
            rcut1 = spline1_j[0].get_cutoff()
            if debug:
                mpi_debug('a1=%d, spos1_c=%s, rcut1=%g, ni1=%d' %
                          (a1, spos_ac[a1], rcut1, self.setups[a1].ni))

            for a2, spline2_j in enumerate(lfc2.spline_aj):
                # We assume that all functions have the same cut-off:
                rcut2 = spline2_j[0].get_cutoff()
                if debug:
                    mpi_debug(' a2=%d, spos2_c=%s, rcut2=%g, ni2=%d' %
                              (a2, spos_ac[a2], rcut2, self.setups[a2].ni))

                # View into X_aa: in-place additions below write through.
                X_ii = self.extract_atomic_pair_matrix(X_aa, a1, a2)

                # loop over lfs1.box_b instead?
                for b1, (beg1_c, end1_c, sdisp1_c) in enumerate(
                        self.gd.get_boxes(spos_ac[a1], rcut1, cut=False)):
                    if debug:
                        mpi_debug(
                            ' b1=%d, beg1_c=%s, end1_c=%s, sdisp1_c=%s' %
                            (b1, beg1_c, end1_c, sdisp1_c),
                            ordered=False)

                    # Atom a1 has at least one piece so the LFC has LocFuncs
                    lfs1 = lfc1.lfs_a[a1]

                    # Similarly, the LocFuncs must have the piece at hand
                    box1 = lfs1.box_b[b1]

                    if debug:
                        assert lfs1.dtype == lfc1.dtype
                        assert self.setups[a1].ni == lfs1.ni, (
                            'setups[%d].ni=%d, lfc1.lfs_a[%d].ni=%d'
                            % (a1, self.setups[a1].ni, a1, lfs1.ni))

                    # loop over lfs2.box_b instead?
                    for b2, (beg2_c, end2_c, sdisp2_c) in enumerate(
                            self.gd.get_boxes(spos_ac[a2], rcut2,
                                              cut=False)):
                        if debug:
                            mpi_debug(
                                ' b2=%d, beg2_c=%s, end2_c=%s, sdisp2_'
                                'c=%s' % (b2, beg2_c, end2_c, sdisp2_c),
                                ordered=False)

                        # Atom a2 has at least one piece so the LFC has
                        # LocFuncs
                        lfs2 = lfc2.lfs_a[a2]

                        # Similarly, the LocFuncs must have the piece at hand
                        box2 = lfs2.box_b[b2]

                        if debug:
                            assert lfs2.dtype == lfc2.dtype
                            assert self.setups[a2].ni == lfs2.ni, (
                                'setups[%d].ni=%d, lfc2.lfs_a[%d].ni=%d'
                                % (a2, self.setups[a2].ni, a2, lfs2.ni))

                        # Find the intersection of the two boxes
                        beg_c, end_c = self._box_intersection(
                            beg1_c, end1_c, beg2_c, end2_c)

                        if debug:
                            mpi_debug(' beg_c=%s, end_c=%s, size_c=%s' %
                                      (beg_c, end_c, tuple(end_c - beg_c)),
                                      ordered=False)

                        # Only a non-empty intersection contributes
                        if not (beg_c < end_c).all():
                            continue

                        bra_iB1 = box1.get_functions()
                        w1slice = self._box_window(beg_c, end_c, beg1_c)

                        ket_iB2 = box2.get_functions()
                        w2slice = self._box_window(beg_c, end_c, beg2_c)

                        X_ii += self.gd.dv * np.inner(
                            bra_iB1[w1slice].reshape((lfs1.ni, -1)),
                            ket_iB2[w2slice].reshape((lfs2.ni, -1)))
                        # XXX phase factors for kpoints

                        del bra_iB1, ket_iB2

        self.gd.comm.sum(X_aa)  # better to sum over X_ii?
        return X_aa

    def calculate_overlaps2(self, spos_ac, lfc1, lfc2=None):
        """Return the grid-integrated overlap matrix of two ``LFC`` objects.

        The splines of each atom are evaluated directly on the intersection
        of the two boxes via ``spline.get_functions``. Each spline
        contributes ``bra1_mB.shape[0]`` consecutive indices to its atom's
        block. Only atoms in ``lfc.atom_indices`` are visited. The result
        has shape ``(len(self), len(self))`` and is summed over
        ``gd.comm``.
        """
        # CONDITION: The two sets of splines must belong to the same kpoint!
        if lfc2 is None:
            lfc2 = lfc1

        assert isinstance(lfc1, LFC) and isinstance(lfc2, LFC)

        nproj = len(self)
        X_aa = np.zeros((nproj, nproj), dtype=float)  # XXX always float?

        if debug:
            if world.rank == 0:
                print('DEBUG INFO')

            mpi_debug('len(lfc1.sphere_a): %d, lfc1.atom_indices: %s' %
                      (len(lfc1.sphere_a), lfc1.atom_indices))
            mpi_debug('len(lfc2.sphere_a): %d, lfc2.atom_indices: %s' %
                      (len(lfc2.sphere_a), lfc2.atom_indices))
            mpi_debug('N_c=%s, beg_c=%s, end_c=%s' %
                      (self.gd.N_c, self.gd.beg_c, self.gd.end_c))

        if debug:
            # XXX must they be equal?!?
            assert len(lfc1.sphere_a) == len(lfc2.sphere_a)

        # Both a-loops are over all relevant atoms which affect this domain
        for a1 in lfc1.atom_indices:
            sphere1 = lfc1.sphere_a[a1]

            # We assume that all functions have the same cut-off:
            spline1_j = sphere1.spline_j
            rcut1 = spline1_j[0].get_cutoff()
            if debug:
                mpi_debug('a1=%d, spos1_c=%s, rcut1=%g, ni1=%d' %
                          (a1, spos_ac[a1], rcut1, self.setups[a1].ni),
                          ordered=False)

            for a2 in lfc2.atom_indices:
                sphere2 = lfc2.sphere_a[a2]

                # We assume that all functions have the same cut-off:
                spline2_j = sphere2.spline_j
                rcut2 = spline2_j[0].get_cutoff()
                if debug:
                    mpi_debug(' a2=%d, spos2_c=%s, rcut2=%g, ni2=%d' %
                              (a2, spos_ac[a2], rcut2, self.setups[a2].ni),
                              ordered=False)

                # View into X_aa: in-place additions below write through.
                X_ii = self.extract_atomic_pair_matrix(X_aa, a1, a2)

                # loop over lfs1.box_b instead?
                for b1, (beg1_c, end1_c, sdisp1_c) in enumerate(
                        self.gd.get_boxes(spos_ac[a1], rcut1, cut=False)):
                    if debug:
                        mpi_debug(
                            ' b1=%d, beg1_c=%s, end1_c=%s, sdisp1_c=%s' %
                            (b1, beg1_c, end1_c, sdisp1_c),
                            ordered=False)

                    # loop over lfs2.box_b instead?
                    for b2, (beg2_c, end2_c, sdisp2_c) in enumerate(
                            self.gd.get_boxes(spos_ac[a2], rcut2,
                                              cut=False)):
                        if debug:
                            mpi_debug(
                                ' b2=%d, beg2_c=%s, end2_c=%s, '
                                'sdisp2_c=%s' % (b2, beg2_c, end2_c, sdisp2_c),
                                ordered=False)

                        # Find the intersection of the two boxes
                        beg_c, end_c = self._box_intersection(
                            beg1_c, end1_c, beg2_c, end2_c)

                        if debug:
                            mpi_debug(' beg_c=%s, end_c=%s, size_c=%s' %
                                      (beg_c, end_c, tuple(end_c - beg_c)),
                                      ordered=False)

                        # Only a non-empty intersection contributes
                        if not (beg_c < end_c).all():
                            continue

                        i1 = 0
                        for spline1 in spline1_j:
                            bra1_mB = spline1.get_functions(
                                self.gd, beg_c, end_c,
                                spos_ac[a1] - sdisp1_c)
                            nm1 = bra1_mB.shape[0]

                            i2 = 0
                            for spline2 in spline2_j:
                                ket2_mB = spline2.get_functions(
                                    self.gd, beg_c, end_c,
                                    spos_ac[a2] - sdisp2_c)
                                nm2 = ket2_mB.shape[0]

                                X_ii[i1:i1 + nm1, i2:i2 + nm2] += (
                                    self.gd.dv * np.inner(
                                        bra1_mB.reshape((nm1, -1)),
                                        ket2_mB.reshape((nm2, -1))))
                                # XXX phase factors for kpoints

                                del ket2_mB
                                i2 += nm2

                            del bra1_mB
                            i1 += nm1

        self.gd.comm.sum(X_aa)  # better to sum over X_ii?
        return X_aa
class ProjectorPairOverlap(Overlap, GridPairOverlap):
    """Overlap operator with two-center projector-projector corrections.

    Besides applying the overlap operator and its inverse, the ``apply*``
    methods can extrapolate the projections of their result from the
    projections of their input. They do this with precomputed two-center
    matrices built from the projector overlaps ``B_aa``, instead of
    integrating the projectors against the result again.
    """

    def __init__(self, wfs, atoms):
        """Set up the two-center matrices for ``wfs`` and ``atoms``.

        Attributes (square float arrays of shape ``(len(self), len(self))``):

        ============ ======================================================
        ``B_aa``     < p_i^a | p_i'^a' >, projector-projector overlaps
        ``xO_aa``    B_aa @ dO_aa, extrapolators for the overlap operator
        ``dC_aa``    two-center coefficients of the inverse overlap
        ``xC_aa``    B_aa @ dC_aa, extrapolators for the inverse overlap
        ============ ======================================================
        """
        Overlap.__init__(self, wfs.orthoksl, wfs.timer)
        GridPairOverlap.__init__(self, wfs.gd, wfs.setups)
        self.natoms = len(atoms)
        if debug:
            assert len(self.setups) == self.natoms
        self.update(wfs, atoms)

    def update(self, wfs, atoms):
        """Recompute ``B_aa``, ``xO_aa``, ``dC_aa`` and ``xC_aa``.

        ``atoms`` is currently unused; the projector positions come from
        ``wfs.spos_ac``.
        """
        self.timer.start('Update two-center overlap')

        nproj = len(self)

        # Projector-projector overlaps, integrated on the real-space grid.
        # (A disabled earlier variant assembled this matrix pair by pair
        # from projector_overlap_matrix2 and setup.B_ii.)
        self.B_aa = self.calculate_overlaps(wfs.spos_ac, wfs.pt)

        # Create two-center (block-diagonal) coefficients for overlap operator
        dO_aa = np.zeros((nproj, nproj), dtype=float)  # always float?
        for a, setup in enumerate(self.setups):
            self.assign_atomic_pair_matrix(dO_aa, a, a, setup.dO_ii)

        # Calculate two-center rotation matrix for overlap projections
        self.xO_aa = self.get_rotated_coefficients(dO_aa)

        # Calculate two-center coefficients for inverse overlap operator:
        # solves dC_aa @ (1 + xO_aa) = -dO_aa for dC_aa.
        lhs_aa = np.eye(nproj) + self.xO_aa
        rhs_aa = -dO_aa
        self.dC_aa = np.linalg.solve(lhs_aa.T, rhs_aa.T).T  # TODO parallel

        # Calculate two-center rotation matrix for inverse overlap projections
        self.xC_aa = self.get_rotated_coefficients(self.dC_aa)

        self.timer.stop('Update two-center overlap')

    def get_rotated_coefficients(self, X_aa):
        r"""Rotate two-center projector expansion coefficients with
        the projector-projector overlap integrals as basis.

        Performs the following operation and returns the result::

            Y[a1 i1, a3 i3] = sum_{a2, i2} < p^a1_i1 | p^a2_i2 > X[a2 i2, a3 i3]

        i.e. ``Y = B_aa @ X_aa``.
        """
        return np.dot(self.B_aa, X_aa)

    def _add_two_center_projections(self, X_aa, P_axi, Q_axi, shape, dtype):
        """Add the two-center contraction of ``P_axi`` with ``X_aa``.

        For every atom a1 this adds
        ``sum_{a2 in P_axi, i2} P_axi[a2][..., i2] * X[a1 i1, a2 i2]``
        to ``Q_axi[a1]`` and then sums it over ``gd.comm``. All atoms are
        visited on every rank, so ``gd.comm.sum`` is called once per atom
        everywhere (presumably required because the reduction is
        collective). Atoms not present in ``Q_axi`` get a temporary buffer
        of ``dtype`` that is discarded after the reduction.
        """
        for a1 in range(self.natoms):
            if a1 in Q_axi:
                Q_xi = Q_axi[a1]
            else:
                # Atom a1 is not in domain so allocate a temporary buffer
                Q_xi = np.zeros(shape + (self.setups[a1].ni, ),
                                dtype=dtype)  # TODO
            for a2, P_xi in P_axi.items():
                X_ii = self.extract_atomic_pair_matrix(X_aa, a1, a2)
                # sum over a2 and last i in X_ii
                Q_xi += np.dot(P_xi, X_ii.T)
            self.gd.comm.sum(Q_xi)

    def apply_to_atomic_matrices(self, dI_asp, P_axi, wfs, kpt, shape=()):
        """Contract projections with two-center rotated atomic matrices.

        The packed per-atom matrices ``dI_asp[a][kpt.s]`` are unpacked
        into a block-diagonal two-center matrix and rotated with
        ``B_aa``. The result is contracted with ``P_axi``.

        Returns
        -------
        Q_axi: dict
            ``wfs.pt.dict(shape)`` holding, for each local atom a1,
            ``sum_{a2} P_axi[a2] @ (B_aa @ dI_aa)[a1, a2].T``.
        """
        self.timer.start('Update two-center projections')

        nproj = len(self)
        dI_aa = np.zeros((nproj, nproj), dtype=float)  # always float?

        for a, dI_sp in dI_asp.items():
            dI_p = dI_sp[kpt.s]
            dI_ii = unpack_hermitian(dI_p)
            self.assign_atomic_pair_matrix(dI_aa, a, a, dI_ii)
        self.gd.comm.sum(dI_aa)  # TODO too heavy?

        dM_aa = self.get_rotated_coefficients(dI_aa)
        Q_axi = wfs.pt.dict(shape, zero=True)
        self._add_two_center_projections(dM_aa, P_axi, Q_axi, shape,
                                         wfs.pt.dtype)

        self.timer.stop('Update two-center projections')

        return Q_axi

    def apply(self,
              a_xG,
              b_xG,
              wfs,
              kpt,
              calculate_P_ani=True,
              extrapolate_P_ani=False):
        """Apply the overlap operator to a set of vectors.

        Parameters
        ==========
        a_xG: ndarray
            Set of vectors to which the overlap operator is applied.
        b_xG: ndarray, output
            Resulting S times a_xG vectors.
        wfs: wave-function object
            Provides the projector collection ``wfs.pt``.
        kpt: KPoint object
            k-point object defined in kpoint.py.
        calculate_P_ani: bool
            When True, the integrals of projector times vectors
            P_ni = <p_i | a_xG> are calculated.
            When False, the existing kpt.P_ani are used.
        extrapolate_P_ani: bool
            When True, return the projections of the result b_xG, obtained
            from P_ni through the two-center extrapolators xO_aa rather
            than by integrating the projectors against b_xG.
            When False, return the projections P_ni of a_xG.

        Returns
        =======
        dict of per-atom projection arrays (see extrapolate_P_ani).
        """
        self.timer.start('Apply overlap')
        b_xG[:] = a_xG
        shape = a_xG.shape[:-3]
        P_axi = wfs.pt.dict(shape)

        if calculate_P_ani:
            wfs.pt.integrate(a_xG, P_axi, kpt.q)
        else:
            for a, P_ni in kpt.P_ani.items():
                P_axi[a][:] = P_ni

        Q_axi = wfs.pt.dict(shape)
        for a, Q_xi in Q_axi.items():
            Q_xi[:] = np.dot(P_axi[a], self.setups[a].dO_ii)

        wfs.pt.add(b_xG, Q_axi, kpt.q)
        self.timer.stop('Apply overlap')

        if not extrapolate_P_ani:
            return P_axi

        # Q_axi has been consumed by pt.add; reuse it, starting from P_axi.
        for a1, Q_xi in Q_axi.items():
            Q_xi[:] = P_axi[a1]
        # xO_aa are the overlap extrapolators across atomic pairs
        self._add_two_center_projections(self.xO_aa, P_axi, Q_axi, shape,
                                         wfs.pt.dtype)
        return Q_axi

    def apply_inverse(self,
                      a_xG,
                      b_xG,
                      wfs,
                      kpt,
                      calculate_P_ani=True,
                      extrapolate_P_ani=False):
        """Apply the inverse overlap operator to a set of vectors.

        The parameters have the same meaning as for ``apply``, with
        ``b_xG`` receiving the inverse overlap applied to ``a_xG``. The
        projector coefficients come from the two-center matrix ``dC_aa``.
        When ``extrapolate_P_ani`` is True the projections of ``b_xG`` are
        extrapolated with ``xC_aa`` and returned; otherwise the projections
        of ``a_xG`` are returned.
        """
        self.timer.start('Apply inverse overlap')
        b_xG[:] = a_xG
        shape = a_xG.shape[:-3]
        P_axi = wfs.pt.dict(shape)

        if calculate_P_ani:
            wfs.pt.integrate(a_xG, P_axi, kpt.q)
        else:
            for a, P_ni in kpt.P_ani.items():
                P_axi[a][:] = P_ni

        # dC_aa are the inverse coefficients across atomic pairs
        Q_axi = wfs.pt.dict(shape, zero=True)
        self._add_two_center_projections(self.dC_aa, P_axi, Q_axi, shape,
                                         wfs.pt.dtype)

        wfs.pt.add(b_xG, Q_axi, kpt.q)
        self.timer.stop('Apply inverse overlap')

        if not extrapolate_P_ani:
            return P_axi

        # Q_axi has been consumed by pt.add; reuse it, starting from P_axi.
        for a1, Q_xi in Q_axi.items():
            Q_xi[:] = P_axi[a1]
        # xC_aa are the inverse extrapolators across atomic pairs
        self._add_two_center_projections(self.xC_aa, P_axi, Q_axi, shape,
                                         wfs.pt.dtype)
        return Q_axi