Coverage for gpaw/overlap.py: 94%
36 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-20 00:19 +0000
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-20 00:19 +0000
1# Copyright (C) 2008 CSC Scientific Computing Ltd.
2# Please see the accompanying LICENSE file for further information.
4"""This module defines an Overlap operator.
6The module defines an overlap operator and implements overlap-related
7functions.
9"""
10import numpy as np
class OverlapCorrections:
    """Apply per-atom ``dO_ii`` correction matrices to projections.

    For every atom ``a`` listed in ``P.indices`` the column slice
    ``I1:I2`` of ``P.array`` is right-multiplied by ``setups[a].dO_ii``
    and the product is stored in the same slice of the output object.
    """

    def __init__(self, setups):
        # Indexed by atom index; each entry must expose a ``dO_ii`` matrix.
        self.setups = setups

    def apply(self, P, out=None):
        """Return ``P`` multiplied block-wise by the ``dO_ii`` matrices.

        Parameters
        ==========
        P: projections object
            Must provide ``indices`` (iterable of ``(a, I1, I2)`` tuples),
            an ``array`` attribute and a ``new()`` method.
        out: projections object or None
            Destination.  When None, a fresh one is made with ``P.new()``.
            Only the column slices listed in ``P.indices`` are written.

        Returns the destination object.
        """
        target = out if out is not None else P.new()
        for atom, start, stop in P.indices:
            correction_ii = self.setups[atom].dO_ii
            block = P.array[..., start:stop]
            # Slices are disjoint, so writing in place (out is P) is safe.
            target.array[..., start:stop] = np.dot(block, correction_ii)
        return target
class Overlap:
    """Overlap operator S

    This class contains information required to apply the
    overlap operator to a set of wavefunctions.
    """

    def __init__(self, timer=None):
        """Create the Overlap operator.

        Parameters
        ==========
        timer: object or None
            Stored as ``self.timer``; the methods of this class do not
            use it.
        """
        self.timer = timer

    def apply(self, a_xG, b_xG, wfs, kpt, calculate_P_ani=True):
        """Apply the overlap operator to a set of vectors.

        Computes ``b_xG = a_xG + sum_ai pt^a_i (P_ax . dO^a_ii)_i`` where
        ``P_axi = <pt^a_i | a_xG>`` and ``dO^a_ii = wfs.setups[a].dO_ii``.

        Parameters
        ==========
        a_xG: ndarray
            Set of vectors to which the overlap operator is applied.  The
            last three axes are the grid; any leading axes (``x``) are
            treated as a batch.
        b_xG: ndarray, output
            Resulting S times a_xG vectors.  Overwritten in place.
        wfs: wave-functions object
            Provides the projector functions ``wfs.pt`` (``dict``,
            ``integrate`` and ``add``) and ``wfs.setups[a].dO_ii``.
        kpt: KPoint object
            k-point object defined in kpoint.py; ``kpt.q`` is passed to
            the projector calls and ``kpt.P_ani`` is read when
            ``calculate_P_ani`` is False.
        calculate_P_ani: bool
            When True, the integrals of projector times vectors
            P_ni = <p_i | a_nG> are calculated.
            When False, the existing ``kpt.P_ani`` are used instead.
        """
        self._apply_projector_correction(a_xG, b_xG, wfs, kpt,
                                         calculate_P_ani, 'dO_ii')

    def apply_inverse(self, a_xG, b_xG, wfs, kpt, calculate_P_ani=True):
        """Apply approximate inverse overlap operator to wave functions.

        Same as :meth:`apply` (same parameters and in-place output in
        ``b_xG``), but uses each setup's ``dC_ii`` matrix instead of
        ``dO_ii``.
        """
        self._apply_projector_correction(a_xG, b_xG, wfs, kpt,
                                         calculate_P_ani, 'dC_ii')

    def _apply_projector_correction(self, a_xG, b_xG, wfs, kpt,
                                    calculate_P_ani, matrix_name):
        """Shared body of apply/apply_inverse.

        Copies ``a_xG`` into ``b_xG``, obtains projections ``P_axi``
        (computed or taken from ``kpt.P_ani``), multiplies them by the
        per-atom matrix ``getattr(wfs.setups[a], matrix_name)`` and adds
        the projector-weighted result onto ``b_xG``.
        """
        b_xG[:] = a_xG
        # Leading (batch) shape; the last three axes are the grid.
        P_axi = wfs.pt.dict(a_xG.shape[:-3])

        if calculate_P_ani:
            wfs.pt.integrate(a_xG, P_axi, kpt.q)
        else:
            # Copy into the freshly allocated arrays so kpt.P_ani is not
            # modified below.
            for a, P_ni in kpt.P_ani.items():
                P_axi[a][:] = P_ni

        for a, P_xi in P_axi.items():
            P_axi[a] = np.dot(P_xi, getattr(wfs.setups[a], matrix_name))
        wfs.pt.add(b_xG, P_axi, kpt.q)  # b_xG += sum_ai pt^a_i P_axi