Coverage for gpaw/new/pw/fulldiag.py: 15%

91 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-20 00:19 +0000

1from __future__ import annotations 

2 

3import numpy as np 

4from gpaw.core.atom_arrays import AtomArrays 

5from gpaw.core.matrix import Matrix, create_distribution 

6from gpaw.core.plane_waves import (PWAtomCenteredFunctions, 

7 PWArray, PWDesc) 

8from gpaw.core.uniform_grid import UGArray 

9from gpaw.new.pwfd.wave_functions import PWFDWaveFunctions 

10from gpaw.typing import Array2D 

11from gpaw.new.ibzwfs import IBZWaveFunctions 

12from gpaw.new.wave_functions import WaveFunctions 

13from gpaw.new.potential import Potential 

14from gpaw.new.smearing import OccupationNumberCalculator 

15 

16 

def pw_matrix(pw: PWDesc,
              pt_aiG: PWAtomCenteredFunctions,
              dH_aii: AtomArrays,
              dS_aii: list[Array2D],
              vt_R: UGArray,
              dedtaut_R: UGArray | None,
              comm) -> tuple[Matrix, Matrix]:
    """Build Hamiltonian and overlap matrices in a plane-wave basis.

    Matrix elements are taken between plane waves::

        O_GG' = ∫ exp(-iG·r) Õ exp(iG'·r) dr

    for the PAW pseudo operators::

        H̃ = T̂ + ṽ(r) δ(r - r')
            + Σ_aij p̃ᵃᵢ(r - Rᵃ) ΔHᵃᵢⱼ p̃ᵃⱼ(r' - Rᵃ)

        S̃ = δ(r - r')
            + Σ_aij p̃ᵃᵢ(r - Rᵃ) ΔSᵃᵢⱼ p̃ᵃⱼ(r' - Rᵃ)

    Parameters
    ----------
    pw:
        Plane-wave descriptor.  Must have complex dtype (asserted).
    pt_aiG:
        Projector functions.  They are expanded into a dense ``f_GI``
        array via the private ``_lfc.expand()``.
    dH_aii:
        Per-atom ΔH matrices, iterated with ``.items()``.
    dS_aii:
        Per-atom ΔS matrices, looked up by the same atom index.
    vt_R:
        Local potential on the real-space grid.
    dedtaut_R:
        Optional grid quantity.  When given, an extra term built from it
        and the reciprocal-lattice vectors is added to H (presumably the
        meta-GGA -½∇·(∂e/∂τ ∇) contribution; NOTE(review): confirm).
    comm:
        Communicator over which the matrix rows are distributed.

    Returns
    -------
    (H_GG, S_GG):
        Row-distributed ``Matrix`` objects.  Each rank fills only the rows
        returned by ``my_row_range()``.  Row G is filled from the response
        to a unit plane wave at G, and the projector term conjugates the
        row-side projectors.  The stored arrays therefore look like the
        complex conjugate of the matrices written above.  NOTE(review):
        confirm; ``diagonalize`` conjugates the eigenvectors after ``eigh``.
    """
    assert pw.dtype == complex
    ng = pw.shape[0]
    distribution = create_distribution(ng, ng, comm, -1, 1)
    H_GG = distribution.matrix(complex)
    S_GG = distribution.matrix(complex)
    g_start, g_stop = distribution.my_row_range()
    my_rows = range(g_start, g_stop)

    # Scratch buffers in reciprocal and real space, reused for every row.
    coef_G = pw.empty()
    assert isinstance(coef_G, PWArray)  # TODO (original note: "Fix this!")
    grid_R = vt_R.desc.new(dtype=complex).zeros()
    assert isinstance(grid_R, UGArray)  # TODO (original note: "Fix this!")
    dv = pw.dv

    def transform(G: int, amplitude: complex, field_R: UGArray):
        """Return FFT(field_R * IFFT(amplitude at G)).

        The result is a view into ``coef_G``, which the next call
        overwrites.
        """
        coef_G.data[:] = 0.0
        coef_G.data[G] = amplitude
        coef_G.ifft(out=grid_R)
        grid_R.data *= field_R.data
        grid_R.fft(out=coef_G)
        return coef_G.data

    # Local-potential part: one FFT round trip per owned row.
    for G in my_rows:
        H_GG.data[G - g_start] = dv * transform(G, 1.0, vt_R)

    # Optional dedtaut part: one round trip per row and Cartesian direction.
    if dedtaut_R is not None:
        G_Gv = pw.reciprocal_vectors()
        for G in my_rows:
            row = G - g_start
            for v in range(3):
                H_GG.data[row] += (-0.5j * dv * G_Gv[:, v] *
                                   transform(G, 1j * G_Gv[G, v], dedtaut_R))

    # Kinetic energy is diagonal in G; the bare overlap is dv * identity.
    H_GG.add_to_diagonal(dv * pw.ekin_G[g_start:g_stop])
    S_GG.data[:] = 0.0
    S_GG.add_to_diagonal(dv)

    # PAW corrections: dense projectors times block-diagonal ΔH / ΔS.
    pt_aiG._lazy_init()
    assert pt_aiG._lfc is not None
    f_GI = pt_aiG._lfc.expand()
    nproj = f_GI.shape[1]
    dH_II = np.zeros((nproj, nproj))
    dS_II = np.zeros((nproj, nproj))
    offset = 0
    for a, dH_ii in dH_aii.items():
        dS_ii = dS_aii[a]
        block = slice(offset, offset + len(dS_ii))
        dH_II[block, block] = dH_ii
        dS_II[block, block] = dS_ii
        offset = block.stop

    fconj_gI = f_GI[g_start:g_stop].conj()
    H_GG.data += (fconj_gI @ dH_II) @ f_GI.T
    S_GG.data += (fconj_gI @ dS_II) @ f_GI.T

    return H_GG, S_GG

101 

102 

def diagonalize(potential: Potential,
                ibzwfs: IBZWaveFunctions,
                occ_calc: OccupationNumberCalculator,
                nbands: int,
                nelectrons: float) -> IBZWaveFunctions:
    """Diagonalize the Hamiltonian directly in the plane-wave basis.

    For every wave-function object in ``ibzwfs``, this function:

    1. builds H and S with ``pw_matrix``;
    2. solves the generalized eigenproblem with ``Matrix.eigh`` using
       ``limit=nbands``;
    3. stores the resulting eigenvectors as new plane-wave coefficients,
       distributed over the band communicator.

    Parameters
    ----------
    potential:
        Supplies ``vt_sR``, ``dH_asii`` and the optional ``dedtaut_sR``.
    ibzwfs:
        Existing wave functions.  Provides the k-point/spin layout,
        communicators and setups.  Each entry must be a
        ``PWFDWaveFunctions`` with plane-wave projectors (asserted).
    occ_calc, nelectrons:
        Passed to ``calculate_occs`` on the result.
    nbands:
        Number of eigenpairs requested from ``eigh``.

    Returns
    -------
    IBZWaveFunctions
        A new object holding the new wave functions and eigenvalues, with
        occupation numbers already calculated.

    Raises
    ------
    AssertionError
        If the lowest eigenvalue is not above -1000 (see issue #241).
    """
    vt_sR = potential.vt_sR
    dH_asii = potential.dH_asii
    dedtaut_sR: UGArray | list[None]
    if potential.dedtaut_sR is None:
        dedtaut_sR = [None] * len(vt_sR)
    else:
        dedtaut_sR = potential.dedtaut_sR

    band_comm = ibzwfs.band_comm
    # Rows per rank when nbands is spread over band_comm (ceiling division).
    maxmynbands = (nbands + band_comm.size - 1) // band_comm.size

    wfs_qs: list[list[WaveFunctions]] = []
    for wfs_s in ibzwfs.wfs_qs:
        new_wfs_s: list[WaveFunctions] = []
        for wfs in wfs_s:
            dS_aii = [setup.dO_ii for setup in wfs.setups]
            assert isinstance(wfs, PWFDWaveFunctions)
            assert isinstance(wfs.pt_aiX, PWAtomCenteredFunctions)
            pw = wfs.psit_nX.desc
            spin = wfs.spin
            H_GG, S_GG = pw_matrix(pw,
                                   wfs.pt_aiX,
                                   dH_asii[:, spin],
                                   dS_aii,
                                   vt_sR[spin],
                                   dedtaut_sR[spin],
                                   band_comm)

            # After eigh, the eigenvector data is read back out of H_GG
            # (presumably eigh overwrites it in place; NOTE(review): confirm).
            # The conjugation undoes the conjugate storage used by pw_matrix.
            eig_n = H_GG.eigh(S_GG, limit=nbands)
            H_GG.complex_conjugate()
            assert eig_n[0] > -1000, 'See issue #241'

            psit_nG = pw.empty(nbands, comm=band_comm)
            mynbands, _ = psit_nG.data.shape
            C_nG = H_GG.new(
                dist=(band_comm, band_comm.size, 1, maxmynbands, 1))
            H_GG.redist(C_nG)
            psit_nG.data[:] = C_nG.data[:mynbands]

            new_wfs = PWFDWaveFunctions.from_wfs(wfs, psit_nX=psit_nG)
            new_wfs._eig_n = eig_n
            new_wfs_s.append(new_wfs)
        wfs_qs.append(new_wfs_s)

    new_ibzwfs = IBZWaveFunctions(
        ibzwfs.ibz,
        ncomponents=ibzwfs.ncomponents,
        wfs_qs=wfs_qs,
        kpt_comm=ibzwfs.kpt_comm,
        kpt_band_comm=ibzwfs.kpt_band_comm,
        comm=ibzwfs.comm)
    new_ibzwfs.calculate_occs(occ_calc, nelectrons)
    return new_ibzwfs