Coverage for gpaw/lcao/atomic_correction.py: 99%
78 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-19 00:19 +0000
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-19 00:19 +0000
1import numpy as np
3from gpaw.utilities.blas import mmm
4from gpaw.utilities import unpack_hermitian
class BaseAtomicCorrection:
    """Shared driver code for atomic corrections in LCAO mode.

    The class stores the per-atom overlap corrections ``dS_aii`` together
    with the ``Mstart``/``Mstop`` bounds and hands them to
    :meth:`calculate`, which concrete subclasses must implement.
    """
    name = 'base'
    description = 'base class for atomic corrections with LCAO'

    def __init__(self, dS_aii, Mstart, Mstop):
        """Store the overlap corrections and the ``Mstart:Mstop`` range.

        ``Mstart`` and ``Mstop`` are passed unchanged to every
        :meth:`calculate` call issued by this class.
        """
        self.dS_aii = dS_aii
        self.Mstart = Mstart
        self.Mstop = Mstop

    @staticmethod
    def get_dS(atom_partition, setups):
        """Return an array dictionary holding ``setup.dO_ii`` per atom.

        The container is created by ``atom_partition.arraydict`` with one
        shape per setup; each entry it contains is then filled in place
        with the ``dO_ii`` matrix of the corresponding setup.
        """
        dO_aii = atom_partition.arraydict([setup.dO_ii.shape
                                           for setup in setups])
        for a in dO_aii:
            dO_aii[a][:] = setups[a].dO_ii
        return dO_aii

    def calculate_hamiltonian(self, kpt, dH_asp, H_MM, yy):
        """Apply the atomic Hamiltonian correction for ``kpt`` to ``H_MM``.

        For every atom in ``dH_asp`` the entry selected by ``kpt.s`` is
        expanded with ``unpack_hermitian`` and multiplied by ``yy``.  The
        resulting per-atom matrices are passed to :meth:`calculate`
        together with ``kpt.q``.
        """
        # Same per-atom shapes as the overlap corrections, but with the
        # dtype of the incoming Hamiltonian corrections.
        dH_aii = dH_asp.partition.arraydict(self.dS_aii.shapes_a,
                                            dtype=dH_asp.dtype)

        for a in dH_asp:
            dH_aii[a] = yy * unpack_hermitian(dH_asp[a][kpt.s])

        self.calculate(kpt.q, dH_aii, H_MM, self.Mstart, self.Mstop)

    def add_overlap_correction(self, S_qMM):
        """Apply the overlap correction ``dS_aii`` to each ``S_qMM[q]``."""
        for q, S_MM in enumerate(S_qMM):
            self.calculate(q, self.dS_aii, S_MM, self.Mstart, self.Mstop)

    def calculate(self, q, dX_aii, X_MM, Mstart=None, Mstop=None):
        """Apply the correction built from ``dX_aii`` to ``X_MM``.

        Must be overridden by subclasses.  ``Mstart`` and ``Mstop`` are
        accepted here because the driver methods of this class always pass
        them; they default to ``None`` so that the older three-argument
        call form keeps working.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError
class DenseAtomicCorrection(BaseAtomicCorrection):
    """Atomic correction evaluated with dense matrix products via ``mmm``."""
    name = 'dense'
    description = 'dense with blas'

    def __init__(self, P_aqMi, dS_aii, Mstart, Mstop):
        """Keep ``P_aqMi`` (indexed as ``P_aqMi[a][q]``) next to the base data."""
        BaseAtomicCorrection.__init__(self, dS_aii, Mstart, Mstop)
        self.P_aqMi = P_aqMi

    @classmethod
    def new_from_wfs(cls, wfs):
        """Build an instance from the attributes of a ``wfs`` object."""
        projectors = wfs.P_aqMi
        overlap_aii = cls.get_dS(wfs.atom_partition, wfs.setups)
        layout = wfs.ksl
        return cls(projectors, overlap_aii, layout.Mstart, layout.Mstop)

    def calculate(self, q, dX_aii, X_MM, Mstart, Mstop):
        """Accumulate the correction from ``dX_aii`` into ``X_MM``.

        Only the rows ``Mstart:Mstop`` of each projector matrix enter the
        final product.
        """
        out_dtype = X_MM.dtype
        projectors = self.P_aqMi

        # P_aqMi is distributed over domains (a) and bands (M), so the
        # correction X_MM = sum(P dX P) computed here contains only the
        # contributions of the local atoms.  The caller has to sum the
        # result over gd.comm to collect every 'a' contribution, and only
        # the local slice of bands is computed on this rank.
        for a, dX_ii in dX_aii.items():
            P_Mi = projectors[a][q]
            assert out_dtype == P_Mi.dtype

            ncols = dX_ii.shape[1]
            nrows = P_Mi.shape[0]
            # Zero-initialised scratch array: ATLAS can't handle an
            # uninitialized output array.
            scratch_iM = np.zeros((ncols, nrows), out_dtype)

            # NOTE(review): presumably mmm(alpha, A, opa, B, opb, beta, C)
            # follows the BLAS gemm convention -- confirm against
            # gpaw.utilities.blas.
            mmm(1.0, np.asarray(dX_ii, out_dtype), 'N',
                P_Mi, 'C',
                0.0, scratch_iM)
            mmm(1.0, P_Mi[Mstart:Mstop], 'N',
                scratch_iM, 'N',
                1.0, X_MM)

    def calculate_projections(self, wfs, kpt):
        """Fill every ``kpt.P_ani[a]`` from ``kpt.C_nM`` and ``wfs.P_aqMi``."""
        for a, P_ni in kpt.P_ani.items():
            # ATLAS can't handle an uninitialized output array, so give
            # the output defined contents before the product is written.
            P_ni.fill(117)
            coefficients_nM = kpt.C_nM
            P_Mi = wfs.P_aqMi[a][kpt.q]
            mmm(1.0, coefficients_nM, 'N', P_Mi, 'N', 0.0, P_ni)
class SparseAtomicCorrection(BaseAtomicCorrection):
    """Atomic correction evaluated with scipy sparse matrices."""
    name = 'sparse'
    description = 'sparse using scipy'

    def __init__(self, Psparse_qIM, P_indices, dS_aii, Mstart, Mstop,
                 tolerance=1e-12):
        """Store the sparse projectors (one per ``q``) and the index map.

        ``P_indices[a]`` yields the ``(I1, I2)`` range of atom ``a`` and
        ``P_indices.max`` the total dimension used for the block matrix.
        """
        BaseAtomicCorrection.__init__(self, dS_aii, Mstart, Mstop)
        self.Psparse_qIM = Psparse_qIM
        self.P_indices = P_indices
        # The tolerance is stored but not used at the moment, although it
        # could be used to speed things up: zeroing elements that are very
        # close to zero often increases sparsity somewhat, even for very
        # small threshold values.
        self.tolerance = tolerance

    @classmethod
    def new_from_wfs(cls, wfs):
        """Build an instance from the attributes of a ``wfs`` object."""
        sparse_projectors = wfs.P_qIM
        index_map = wfs.setups.projector_indices()
        overlap_aii = cls.get_dS(wfs.atom_partition, wfs.setups)
        layout = wfs.ksl
        return cls(sparse_projectors, index_map, overlap_aii,
                   layout.Mstart, layout.Mstop)

    def calculate(self, q, dX_aii, X_MM, Mstart, Mstop):
        """Add the correction built from ``dX_aii`` to ``X_MM`` in place."""
        index_map = self.P_indices
        size = index_map.max

        # Imported locally, as in the original code, so scipy is only
        # needed once this method actually runs.
        import scipy.sparse as sparse

        # Assemble the per-atom matrices as diagonal blocks of one sparse
        # matrix, then convert to CSR for the products below.
        blocks_II = sparse.dok_matrix((size, size), dtype=X_MM.dtype)
        for a in dX_aii:
            first, last = index_map[a]
            blocks_II[first:last, first:last] = dX_aii[a]
        dX_II = blocks_II.tocsr()

        P_IM = self.Psparse_qIM[q]
        # Conjugate transpose of the columns Mstart:Mstop of the projectors.
        Pdagger_MI = P_IM[:, Mstart:Mstop].transpose().conj()
        dXP_IM = dX_II.dot(P_IM)
        correction_MM = Pdagger_MI.dot(dXP_IM)
        X_MM[:, :] += correction_MM.todense()

    def calculate_projections(self, wfs, kpt):
        """Fill every ``kpt.P_ani[a]`` from the sparse projectors and ``kpt.C_nM``."""
        index_map = self.P_indices
        conj_coefficients_Mn = kpt.C_nM.T.conj()
        P_In = self.Psparse_qIM[kpt.q].dot(conj_coefficients_Mn)
        for a in kpt.P_ani:
            first, last = index_map[a]
            # Take the rows belonging to atom a and write their conjugate
            # transpose into the existing output array.
            kpt.P_ani[a][:, :] = P_In[first:last, :].T.conj()