Coverage for gpaw/new/wave_functions.py: 85%

108 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-19 00:19 +0000

1from __future__ import annotations 

2 

3from types import ModuleType 

4 

5import numpy as np 

6from gpaw.core.atom_arrays import AtomArrays, AtomDistribution 

7from gpaw.core.uniform_grid import UGArray 

8from gpaw.mpi import MPIComm, serial_comm 

9from gpaw.new import trace, zips 

10from gpaw.new.potential import Potential 

11from gpaw.setup import Setups 

12from gpaw.typing import Array1D, Array2D, ArrayND 

13 

14 

class WaveFunctions:
    """Common state and interface for one set of wave functions.

    An instance stores the indices (``spin``, ``q``, ``k``), the k-point
    ``kpt_c`` and its ``weight``, the communicators, and caches for the
    eigenvalues, occupation numbers and projections.

    Methods that only raise ``NotImplementedError`` here are presumably
    implemented by subclasses.
    """

    bytes_per_band: int
    xp: ModuleType  # numpy or cupy

    def __init__(self,
                 *,
                 setups: Setups,
                 nbands: int,
                 relpos_ac: Array2D,
                 atomdist: AtomDistribution,
                 spin: int = 0,
                 q: int = 0,
                 k: int = 0,
                 kpt_c=(0.0, 0.0, 0.0),
                 weight: float = 1.0,
                 ncomponents: int = 1,
                 dtype=float,
                 qspiral_v=None,
                 domain_comm: MPIComm = serial_comm,
                 band_comm: MPIComm = serial_comm):
        """Store the arguments and derive the spin and band bookkeeping.

        All arguments are keyword-only and are stored as attributes of
        the same name.

        Parameters
        ----------
        setups, relpos_ac, atomdist, q, k, kpt_c, weight, dtype, qspiral_v:
            Stored unchanged.
        nbands:
            Total number of bands; split over ``band_comm`` into the
            half-open range ``[self.n1, self.n2)`` owned by this rank.
        spin:
            Must be smaller than ``ncomponents`` (asserted).
        ncomponents:
            Used to derive ``nspins`` and ``spin_degeneracy``.
        domain_comm:
            Its size must equal ``atomdist.comm.size`` (asserted).
        band_comm:
            Its rank and size determine ``n1`` and ``n2``.
        """
        assert spin < ncomponents

        self.spin = spin
        self.q = q
        self.k = k
        self.setups = setups
        self.weight = weight
        self.ncomponents = ncomponents
        self.dtype = dtype
        self.kpt_c = kpt_c
        self.relpos_ac = relpos_ac
        self.atomdist = atomdist
        self.domain_comm = domain_comm
        self.band_comm = band_comm
        self.nbands = nbands
        self.qspiral_v = qspiral_v

        assert domain_comm.size == atomdist.comm.size

        # ncomponents -> nspins: 1 -> 1, 2 -> 2, 4 -> 1
        self.nspins = ncomponents % 3
        # ncomponents -> spin_degeneracy: 1 -> 2, 2 -> 1, 4 -> 1
        self.spin_degeneracy = ncomponents % 2 + 1

        # Caches; None until a value has been assigned.
        self._P_ani: AtomArrays | None = None

        self._eig_n: Array1D | None = None
        self._occ_n: Array1D | None = None

        # Ceiling division: every rank gets at most mynbands bands and
        # the range is clipped to nbands, so trailing ranks may be empty.
        mynbands = (nbands + band_comm.size - 1) // band_comm.size
        self.n1 = min(band_comm.rank * mynbands, nbands)
        self.n2 = min((band_comm.rank + 1) * mynbands, nbands)

    def __repr__(self):
        """Return a one-line summary of the main attributes.

        The communicators are shown as ``rank/size``.
        """
        dc = f'{self.domain_comm.rank}/{self.domain_comm.size}'
        bc = f'{self.band_comm.rank}/{self.band_comm.size}'
        return (f'{self.__class__.__name__}(nbands={self.nbands}, '
                f'spin={self.spin}, q={self.q}, k={self.k}, '
                f'weight={self.weight}, kpt_c={self.kpt_c}, '
                f'ncomponents={self.ncomponents}, dtype={self.dtype}, '
                f'domain_comm={dc}, band_comm={bc})')

    def array_shape(self, global_shape: bool = False) -> tuple[int, ...]:
        """Not implemented here."""
        raise NotImplementedError

    def add_to_density(self,
                       nt_sR: UGArray,
                       D_asii: AtomArrays) -> None:
        """Not implemented here."""
        raise NotImplementedError

    def add_to_ked(self,
                   taut_sR: UGArray) -> None:
        """Not implemented here."""
        raise NotImplementedError

    def orthonormalize(self, work_array_nX: ArrayND | None = None):
        """Not implemented here."""
        raise NotImplementedError

    def move(self,
             relpos_ac: Array2D,
             atomdist: AtomDistribution,
             move_wave_functions) -> None:
        """Store new positions and atom distribution; drop cached data.

        The cached projections and eigenvalues are reset to ``None``.
        ``move_wave_functions`` is not used by this implementation.
        """
        self.relpos_ac = relpos_ac
        self.atomdist = atomdist
        self._P_ani = None
        self._eig_n = None
        # NOTE(review): the occupation numbers are kept on purpose or by
        # accident?  The reset below is disabled - confirm.
        # self._occ_n = None

    def collect(self,
                n1: int = 0,
                n2: int = 0) -> WaveFunctions | None:
        """Not implemented here."""
        raise NotImplementedError

    @property
    def has_eigs(self) -> bool:
        # Checks if eigenvalues have been calculated,
        # that is, one scf step has been performed.
        return self._eig_n is not None

    @property
    def has_occs(self) -> bool:
        # Checks if occupations have been calculated,
        # that is, one scf step has been performed.
        # XXX: In theory, this should be the same as has_eigs,
        # however, there seems to be a discrepancy during
        # fixed density calculations.
        return self._occ_n is not None

    @property
    def eig_n(self) -> Array1D:
        """Eigenvalues; raises ``ValueError`` if they are not set."""
        if self._eig_n is None:
            raise ValueError('Eigenvalues eig_n not present')
        return self._eig_n

    @property
    def occ_n(self) -> Array1D:
        """Occupation numbers; raises ``ValueError`` if they are not set."""
        if self._occ_n is None:
            raise ValueError('Occupation numbers occ_n not present')
        return self._occ_n

    @property
    def myeig_n(self):
        """Slice ``[n1:n2]`` of ``eig_n`` (this rank's band range)."""
        return self.eig_n[self.n1:self.n2]

    @property
    def myocc_n(self):
        """Slice ``[n1:n2]`` of ``occ_n`` (this rank's band range)."""
        return self.occ_n[self.n1:self.n2]

    @property
    def P_ani(self) -> AtomArrays:
        """Projections; raises ``RuntimeError`` if they are not set."""
        if self._P_ani is None:
            raise RuntimeError('Projections P_ani not present')
        return self._P_ani

    @trace
    def add_to_atomic_density_matrices(self,
                                       occ_n,
                                       D_asii: AtomArrays) -> None:
        """Add occupation-weighted projection products to ``D_asii``.

        ``occ_n`` is converted with the array module of ``D_asii``.  For
        fewer than four components the real part of
        ``sum_n conj(P_ni) * occ_n * P_nj`` is added to the ``self.spin``
        entry of each atom's matrix; otherwise the work is delegated to
        ``add_to_4component_density_matrix``.
        """
        xp = D_asii.layout.xp
        occ_n = xp.asarray(occ_n)
        if self.ncomponents < 4:
            P_ani = self.P_ani
            for D_sii, P_ni in zips(D_asii.values(), P_ani.values()):
                D_sii[self.spin] += xp.einsum('ni, n, nj -> ij',
                                              P_ni.conj(), occ_n, P_ni).real
        else:
            for D_xii, P_nsi in zips(D_asii.values(), self.P_ani.values()):
                add_to_4component_density_matrix(D_xii, P_nsi, occ_n, xp)

    def send(self, kpt_comm, rank):
        """Not implemented here."""
        raise NotImplementedError

    def receive(self, kpt_comm, rank):
        """Not implemented here."""
        raise NotImplementedError

    def force_contribution(self, potential: Potential, F_av: Array2D):
        """Not implemented here."""
        raise NotImplementedError

    def gather_wave_function_coefficients(self) -> np.ndarray | None:
        """Not implemented here."""
        raise NotImplementedError

173 

174 

def add_to_4component_density_matrix(D_xii, P_nsi, occ_n, xp):
    """Accumulate occupation-weighted projection products into ``D_xii``.

    The block matrix ``sum_n conj(P_nsi) * occ_n * P_nzj`` is built with
    ``xp.einsum`` and four combinations of its 2x2 blocks are added in
    place to ``D_xii[0]`` ... ``D_xii[3]``.  ``xp`` is the array module
    that provides ``einsum``.  Nothing is returned.
    """
    weighted_ssii = xp.einsum('nsi, n, nzj -> szij',
                              P_nsi.conj(), occ_n, P_nsi)

    # Name the four blocks of the 2x2 structure once.
    d00_ii = weighted_ssii[0, 0]
    d01_ii = weighted_ssii[0, 1]
    d10_ii = weighted_ssii[1, 0]
    d11_ii = weighted_ssii[1, 1]

    D_xii[0] += d00_ii + d11_ii
    D_xii[1] += d01_ii + d10_ii
    D_xii[2] += -1j * (d01_ii - d10_ii)
    D_xii[3] += d00_ii - d11_ii