Coverage for gpaw/new/lcao/wave_functions.py: 81%

119 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-09 00:21 +0000

1from __future__ import annotations 

2 

3import numpy as np 

4from gpaw.core.atom_arrays import (AtomArrays, AtomArraysLayout, 

5 AtomDistribution) 

6from gpaw.core.matrix import Matrix 

7from gpaw.mpi import MPIComm, receive, send, serial_comm 

8from gpaw.new.potential import Potential 

9from gpaw.new.pwfd.wave_functions import PWFDWaveFunctions 

10from gpaw.new.wave_functions import WaveFunctions 

11from gpaw.setup import Setups 

12from gpaw.typing import Array2D 

13 

14 

class LCAOWaveFunctions(WaveFunctions):
    """Wave functions expanded in an LCAO basis.

    The coefficients live in ``C_nM``: one row per band (``nbands`` is
    taken from its first dimension) and one column per basis function.
    Projections (``P_ani``) are computed lazily from ``C_nM`` and
    ``P_aMi``.
    """

    # Array namespace for this class's data (always NumPy here).
    xp = np

    def __init__(self,
                 *,
                 setups: Setups,
                 tci_derivatives,
                 basis,
                 C_nM: Matrix,
                 S_MM: Matrix,
                 T_MM: Matrix,
                 P_aMi,
                 relpos_ac: Array2D,
                 atomdist: AtomDistribution,
                 kpt_c=(0.0, 0.0, 0.0),
                 domain_comm: MPIComm = serial_comm,
                 spin: int = 0,
                 q: int = 0,
                 k: int = 0,
                 weight: float = 1.0,
                 ncomponents: int = 1):
        """Create LCAO wave functions.

        Parameters
        ----------
        setups:
            Atomic setups; ``setup.ni`` of each one gives the size of that
            atom's block in the projection layout.
        tci_derivatives:
            Stored as-is (presumably consumed by the force code -- confirm).
        basis:
            Basis-function object; its ``sphere_a`` drives the phase
            correction in :meth:`move`, and it builds densities and grid
            wave functions.
        C_nM:
            Coefficient matrix.  Its row count, dtype and distribution
            communicator define ``nbands``, ``dtype`` and the band
            communicator.
        S_MM:
            Matrix whose inverse Cholesky factor is exposed as ``L_MM``
            (presumably the basis-function overlap matrix).
        T_MM:
            Stored as-is; its shape/dtype also defines the receive buffer
            in :meth:`calculate_density_matrix`.
        P_aMi:
            Mapping from atom index to a matrix such that
            ``C_nM.data @ P_aMi[a]`` gives the projections of atom ``a``.
        relpos_ac, atomdist, kpt_c, domain_comm, spin, q, k, weight, \
ncomponents:
            Forwarded to :class:`WaveFunctions`.
        """
        super().__init__(setups=setups,
                         nbands=C_nM.shape[0],
                         spin=spin,
                         q=q,
                         k=k,
                         kpt_c=kpt_c,
                         weight=weight,
                         relpos_ac=relpos_ac,
                         atomdist=atomdist,
                         ncomponents=ncomponents,
                         dtype=C_nM.dtype,
                         domain_comm=domain_comm,
                         band_comm=C_nM.dist.comm)
        self.tci_derivatives = tci_derivatives
        self.basis = basis
        self.C_nM = C_nM
        self.T_MM = T_MM
        self.S_MM = S_MM
        self.P_aMi = P_aMi

        # Number of basis functions times the coefficient item size.
        self.bytes_per_band = (self.array_shape(global_shape=True)[0] *
                               C_nM.data.itemsize)

        # This is for TB-mode (and MYPY):
        self.V_MM: Matrix

        # Cache for the ``L_MM`` property; reset whenever atoms move.
        self._L_MM = None

    def move(self,
             relpos_ac: Array2D,
             atomdist: AtomDistribution,
             move_wave_functions) -> None:
        """Move the atoms to ``relpos_ac`` / ``atomdist``.

        The phase correction runs first because it compares the new
        positions with the current ``self.relpos_ac`` (presumably updated
        by ``super().move()``).  The cached ``L_MM`` is invalidated.
        """
        self._update_phases(relpos_ac)
        super().move(relpos_ac, atomdist, move_wave_functions)
        self._L_MM = None

    def _update_phases(self, relpos_ac):
        """Complex-rotate coefficients compensating discontinuous phase shift.

        This changes the coefficients to counteract the phase discontinuity
        of overlaps when atoms move across a cell boundary."""

        # We don't want to apply any phase shift unless we crossed a cell
        # boundary. So we round the shift to either 0 or 1.
        #
        # Example: relpos_ac goes from 0.01 to 0.99 -- this rounds to 1 and
        # we apply the phase. If someone moves an atom by half a cell
        # without crossing a boundary, then we are out of luck. But they
        # should have reinitialized from LCAO anyway.

        C_nM = self.C_nM.data
        if C_nM.dtype == float:
            # Real coefficients: there is no complex phase to apply.
            return
        diff_ac = (relpos_ac - self.relpos_ac).round()
        if not diff_ac.any():
            return
        phase_a = np.exp(2j * np.pi * diff_ac @ self.kpt_c)
        # Each atom's basis functions occupy a contiguous range of
        # ``sphere.Mmax`` columns, in atom order.
        M1 = 0
        for phase, sphere in zip(phase_a, self.basis.sphere_a):
            M2 = M1 + sphere.Mmax
            C_nM[:, M1:M2] *= phase
            M1 = M2

    @property
    def L_MM(self):
        """``S_MM`` after ``invcholesky()``, computed once and cached.

        The cache is cleared by :meth:`move`.  For ``ncomponents == 4`` the
        factor is placed twice along the diagonal of a zero-filled complex
        ``2M x 2M`` matrix.
        """
        if self._L_MM is None:
            S_MM = self.S_MM.copy()
            S_MM.invcholesky()
            if self.ncomponents < 4:
                self._L_MM = S_MM
            else:
                M, M = S_MM.shape
                L_sMsM = Matrix(2 * M, 2 * M, dtype=complex)
                L_sMsM.data[:] = 0.0
                L_sMsM.data[:M, :M] = S_MM.data
                L_sMsM.data[M:, M:] = S_MM.data
                self._L_MM = L_sMsM
        return self._L_MM

    def _short_string(self, global_shape):
        """Describe the per-band size for summary output."""
        return f'basis functions: {global_shape[0]}'

    def array_shape(self, global_shape=False):
        """Return the shape of one band's coefficients.

        Only the global shape (``C_nM.shape[1:]``, i.e. the number of basis
        functions) is supported.

        Raises
        ------
        NotImplementedError
            If ``global_shape`` is false.
        """
        if global_shape:
            return self.C_nM.shape[1:]
        raise NotImplementedError(
            'LCAOWaveFunctions.array_shape() requires global_shape=True')

    @property
    def _layout(self):
        """Layout with ``setup.ni`` entries per atom.

        Atoms are distributed over ``domain_comm`` according to the keys
        of ``P_aMi``.
        """
        atomdist = AtomDistribution.from_atom_indices(
            list(self.P_aMi),
            self.domain_comm,
            natoms=len(self.setups))
        return AtomArraysLayout([setup.ni for setup in self.setups],
                                atomdist=atomdist,
                                dtype=self.dtype)

    @property
    def P_ani(self):
        """Projections, computed on first access as ``C_nM @ P_aMi[a]``.

        Raises
        ------
        RuntimeError
            If ``C_nM`` is not a :class:`Matrix` (no coefficients).
        """
        if self._P_ani is None:
            # As a hack, builder.py injects a NaN in the first element of
            # C_nM.data in order for us to be able to tell that the
            # data is uninitialized.
            # NOTE(review): the check below tests the *type* of C_nM, not
            # for a NaN marker -- confirm which convention builder.py uses.
            #
            # Validate before touching the cache, and only store the result
            # once it is completely filled.  Otherwise a failed first access
            # would leave an uninitialized (``empty``) array cached that
            # later accesses would silently return.
            if not isinstance(self.C_nM, Matrix):
                raise RuntimeError('There are no projections or wavefunctions')
            P_ani = self._layout.empty(self.nbands,
                                       comm=self.C_nM.dist.comm)
            for a, P_Mi in self.P_aMi.items():
                P_ani[a][:] = self.C_nM.data @ P_Mi
            self._P_ani = P_ani

        return self._P_ani

    def add_to_density(self,
                       nt_sR,
                       D_asii: AtomArrays) -> None:
        """Add density from wave functions.

        Adds to ``nt_sR`` (the ``self.spin`` component, via
        ``basis.construct_density``) and ``D_asii``.  Both use the band
        weights ``weight * spin_degeneracy * myocc_n``.
        """
        rho_MM = self.calculate_density_matrix()
        self.basis.construct_density(rho_MM, nt_sR.data[self.spin], q=self.q)
        f_n = self.weight * self.spin_degeneracy * self.myocc_n
        self.add_to_atomic_density_matrices(f_n, D_asii)

    def gather_wave_function_coefficients(self) -> np.ndarray | None:
        """Return the gathered coefficient array, or None.

        None is returned wherever ``C_nM.gather()`` returns None
        (presumably every rank except the gathering root).
        """
        C_nM = self.C_nM.gather()
        if C_nM is not None:
            return C_nM.data
        return None

    def calculate_density_matrix(self,
                                 *,
                                 eigs=False,
                                 transposed=False) -> np.ndarray:
        """Calculate the density matrix.

        The density matrix is::

                   --  *
             ρ   = >  C   C   f
              μν   --  nμ  nν  n
                   n

        with ``f_n = weight * spin_degeneracy * myocc_n``.

        Parameters
        ----------
        eigs:
            If true, additionally multiply each ``f_n`` by the eigenvalue
            ``myeig_n``.
        transposed:
            If true, return the complex conjugate of the matrix above
            (equal to its transpose when the weights are real).

        Returns
        -------
        The density matrix in the LCAO basis.  It is computed on rank 0
        of ``domain_comm`` and broadcast, so every domain rank receives it.
        """
        if self.domain_comm.rank == 0:
            f_n = self.weight * self.spin_degeneracy * self.myocc_n
            if eigs:
                f_n *= self.myeig_n
            fC_nM = self.C_nM.copy()
            fC_nM.data *= f_n[:, None]
            rho_MM = fC_nM.multiply(self.C_nM, opa='C')
            if transposed:
                rho_MM.complex_conjugate()
            rho_MM_data = rho_MM.data
        else:
            # Receive buffer with the same shape and dtype as T_MM.
            rho_MM_data = np.empty_like(self.T_MM.data)
        self.domain_comm.broadcast(rho_MM_data, 0)

        return rho_MM_data

    def to_uniform_grid_wave_functions(self,
                                       grid,
                                       basis):
        """Convert to grid-based :class:`PWFDWaveFunctions`.

        ``grid`` is re-created for this k-point and dtype.  The values are
        filled by ``basis.lcao_to_grid``, and eigenvalues are copied when
        they are available.
        """
        grid = grid.new(kpt=self.kpt_c, dtype=self.dtype)
        psit_nR = grid.zeros(self.nbands, self.band_comm)
        basis.lcao_to_grid(self.C_nM.data, psit_nR.data, self.q)

        wfs = PWFDWaveFunctions.from_wfs(self, psit_nR)
        if self._eig_n is not None:
            wfs._eig_n = self._eig_n.copy()
        return wfs

    def collect(self,
                n1: int = 0,
                n2: int = 0) -> LCAOWaveFunctions | None:
        """Return a serial copy containing bands ``n1:n2``.

        A non-positive ``n2`` counts from the end: ``n2=0`` means
        "up to ``nbands``" and ``n2=-1`` means "all but the last band".
        Band parallelization is not supported.
        """
        # Quick'n'dirty implementation
        # We should generalize the PW+FD method
        assert self.band_comm.size == 1
        if n2 <= 0:
            # The old ``n2 or nbands + n2`` left negative values negative,
            # which gave the Matrix below a negative row count.
            n2 += self.nbands
        data = self.C_nM.data[n1:n2].copy()
        return LCAOWaveFunctions(
            setups=self.setups,
            tci_derivatives=self.tci_derivatives,
            basis=self.basis,
            # Size from the actual slice, so any valid (n1, n2) pair gives
            # a consistent shape.
            C_nM=Matrix(len(data),
                        self.C_nM.shape[1],
                        data=data),
            S_MM=self.S_MM,
            T_MM=self.T_MM,
            P_aMi=self.P_aMi,
            relpos_ac=self.relpos_ac,
            atomdist=self.atomdist.gather(),
            kpt_c=self.kpt_c,
            spin=self.spin,
            q=self.q,
            k=self.k,
            weight=self.weight,
            ncomponents=self.ncomponents)

    def force_contribution(self, potential: Potential, F_av: Array2D):
        """Add this object's LCAO force contributions to ``F_av``.

        ``F_av`` is returned (presumably modified in place).
        """
        from gpaw.new.lcao.forces import add_force_contributions
        add_force_contributions(self, potential, F_av)
        return F_av

    def send(self, rank, comm):
        """Send the k-point, coefficients and metadata to ``rank``.

        This is the counterpart of :meth:`receive`.
        """
        stuff = (self.kpt_c,
                 self.C_nM.data,
                 self.spin,
                 self.q,
                 self.k,
                 self.weight,
                 self.ncomponents)
        send(stuff, rank, comm)

    def receive(self, rank, comm):
        """Receive data sent by :meth:`send` and wrap it in a new object.

        The new object has ``S_MM``, ``T_MM`` and ``P_aMi`` set to None, so
        methods that use them (``L_MM``, ``P_ani``, ...) will not work on
        it.  Positions and atom distribution are taken from ``self``.
        """
        kpt_c, data, spin, q, k, weight, ncomponents = receive(rank, comm)
        return LCAOWaveFunctions(setups=self.setups,
                                 tci_derivatives=self.tci_derivatives,
                                 basis=self.basis,
                                 C_nM=Matrix(*data.shape, data=data),
                                 S_MM=None,
                                 T_MM=None,
                                 P_aMi=None,
                                 relpos_ac=self.relpos_ac,
                                 atomdist=self.atomdist.gather(),
                                 kpt_c=kpt_c,
                                 spin=spin,
                                 q=q,
                                 k=k,
                                 weight=weight,
                                 ncomponents=ncomponents)