Coverage for gpaw/new/wannier.py: 14%

79 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-09 00:21 +0000

1from __future__ import annotations 

2from math import factorial as fac 

3 

4import numpy as np 

5from ase.units import Bohr 

6 

7from gpaw.new.ibzwfs import IBZWaveFunctions 

8from gpaw.spline import Spline 

9from gpaw.typing import Array2D 

10 

11 

def get_wannier_integrals(ibzwfs: IBZWaveFunctions,
                          grid,
                          s: int,
                          k: int,
                          k1: int,
                          G_c,
                          nbands=None) -> Array2D:
    """Calculate integrals for maximally localized Wannier functions.

    Computes ``Z_nm = <psi_kn| exp(-i G.r) |psi_k1m>`` for spin ``s``.
    The pseudo part is evaluated on ``grid`` by its ``wannier_matrix``
    helper, and the PAW correction is added in place by
    ``add_wannier_correction``.  The result is then summed over
    ``grid.comm``.

    Parameters
    ----------
    ibzwfs:
        Wave functions.  Only the non-parallel k-point/spin case
        (``ibzwfs.comm.size == 1``) is supported.
    grid:
        Uniform grid onto which both sets of wave functions are
        transferred.
    s:
        Spin index, ``0 <= s < ibzwfs.nspins``.
    k, k1:
        Indices into ``ibzwfs.wfs_qs`` for the bra and ket states.
    G_c:
        Wave vector of the exponential.  It is passed unchanged to
        ``wannier_matrix`` and ``add_wannier_correction``.
    nbands:
        Number of bands to include.  ``None`` presumably means all bands
        (the PAW correction slices projections with ``[:nbands]``).

    Returns
    -------
    The band-by-band matrix ``Z_nn``.
    """
    ibzwfs.make_sure_wfs_are_read_from_gpw_file()
    # s indexes wfs_qs[k][s], so the upper bound must be exclusive.
    # The previous check (s <= nspins) let s == nspins through to an
    # IndexError, and a negative s would silently wrap around.
    assert 0 <= s < ibzwfs.nspins
    # XXX not for the kpoint/spin parallel case
    assert ibzwfs.comm.size == 1
    wfs = ibzwfs.wfs_qs[k][s].to_uniform_grid_wave_functions(grid, None)
    wfs1 = ibzwfs.wfs_qs[k1][s].to_uniform_grid_wave_functions(grid, None)
    # Pseudo (smooth) part evaluated on the grid:
    psit_nR = wfs.psit_nX.data
    psit1_nR = wfs1.psit_nX.data
    Z_nn = grid._gd.wannier_matrix(psit_nR, psit1_nR, G_c, nbands)
    # Add PAW corrections in place, then reduce over the grid communicator:
    add_wannier_correction(Z_nn, G_c, wfs, wfs1, nbands)
    grid.comm.sum(Z_nn)
    return Z_nn

34 

35 

def add_wannier_correction(Z_nn, G_c, wfs, wfs1, nbands):
    r"""Add the PAW correction to the Wannier integrals, in place.

    The full integrals (Eq. 27 of ref1) are::

        Z_nm = <psi_n| exp(-i G.r) |psi_m>

    They are approximated by the pseudo part ~Z_nm plus one atomic term
    per atom a::

        Z_nm ~= ~Z_nm + sum_a exp(-2 pi i G_c . R^a_c)
                        sum_ii' conj(P^a_ni) dO^a_ii' P^a_mi'

    This correction is an approximation.  It assumes the exponential
    varies slowly over the extent of the augmentation sphere.

    Parameters
    ----------
    Z_nn:
        Matrix holding the pseudo part.  It is updated in place, so it
        must accept complex additions.
    G_c:
        Wave vector, dotted with ``wfs.relpos_ac[a]``.
    wfs, wfs1:
        Bra and ket wave functions.  Their ``P_ani`` projections are
        used, and ``wfs`` also supplies ``setups`` and ``relpos_ac``.
    nbands:
        Number of leading bands to use.  ``None`` uses all bands.

    ref1: Thygesen et al, Phys. Rev. B 72, 125119 (2005)
    """
    P1_ani = wfs1.P_ani
    for a, P_ni in wfs.P_ani.items():
        # Slicing with nbands=None keeps every band.
        P_ni = P_ni[:nbands]
        P1_ni = P1_ani[a][:nbands]
        dO_ii = wfs.setups[a].dO_ii
        # G_c is presumably in reciprocal-lattice units and relpos_ac in
        # scaled coordinates, hence the explicit 2*pi in the phase.
        phase = np.exp(-2.j * np.pi * np.dot(G_c, wfs.relpos_ac[a]))
        Z_nn += phase * (P_ni.conj() @ dO_ii @ P1_ni.T)

65 

66 

def initial_wannier(ibzwfs: IBZWaveFunctions,
                    initialwannier, kpointgrid, fixedstates,
                    edf, spin, nbands):
    """Build starting rotation matrices from an initial Wannier guess.

    The Kohn-Sham states of the given spin are projected onto the
    localized functions described by ``initialwannier`` (see
    ``get_projections``).  ``ase.dft.wannier.rotation_from_projection``
    then turns each k-point's projections into a pair ``(U_ww, C_ul)``.
    It uses the first ``nbands`` bands and that k-point's entry of
    ``fixedstates``.

    ``kpointgrid`` and ``edf`` are not used by this function.

    Returns ``(C_kul, U_kww)``: a list with one C matrix per k-point and
    a NumPy array stacking the U matrices.
    """
    from ase.dft.wannier import rotation_from_projection

    projections_k = get_projections(ibzwfs, initialwannier, spin)
    rotations_k = [
        rotation_from_projection(proj_nw[:nbands], nfixed, ortho=True)
        for nfixed, proj_nw in zip(fixedstates, projections_k)
    ]
    C_kul = [C_ul for _, C_ul in rotations_k]
    U_kww = np.asarray([U_ww for U_ww, _ in rotations_k])
    return C_kul, U_kww

88 

89 

def get_projections(ibzwfs: IBZWaveFunctions,
                    locfun: str | list[tuple],
                    spin=0):
    """Project wave functions onto localized functions.

    Determine the projections of the Kohn-Sham eigenstates
    onto specified localized functions of the format::

        locfun = [[spos_c, l, sigma], [...]]

    ``spos_c`` can be an atom index (a Python or NumPy integer) or a
    scaled position vector.  ``l`` is the angular momentum, and ``sigma``
    is the (half-) width of the radial Gaussian.

    Return format is::

        f_kni = <psi_kn | f_i>

    where psi_kn are the wave functions, and f_i are the specified
    localized functions.  For Gaussian ``locfun`` the result has shape
    ``(len(ibzwfs.ibz), ibzwfs.nbands, sum(2l+1))``.  Rows for k-points
    of the other spin stay zero.

    As a special case, locfun can be the string 'projectors', in which
    case the bound state projectors are used as localized functions.
    """
    if isinstance(locfun, str):
        assert locfun == 'projectors'
        return _projector_projections(ibzwfs, spin)

    nkpts = len(ibzwfs.ibz)
    # Each (spos_c, l, sigma) entry contributes 2l+1 functions.  The
    # builtin sum keeps nbf an int even for an empty locfun; np.sum([])
    # would give the float 0.0, which np.zeros rejects as a shape.
    nbf = sum(2 * l + 1 for _, l, _ in locfun)
    f_knB = np.zeros((nkpts, ibzwfs.nbands, nbf), ibzwfs.dtype)
    relpos_ac = ibzwfs.wfs_qs[0][0].relpos_ac
    spos_bc, splines_b = _gaussian_splines(locfun, relpos_ac)

    assert ibzwfs.domain_comm.size == 1

    for wfs in ibzwfs:
        if wfs.spin != spin:
            continue
        psit_nX = wfs.psit_nX
        lf_blX = psit_nX.desc.atom_centered_functions(
            splines_b, spos_bc, cut=True)
        f_bnl = lf_blX.integrate(psit_nX)
        f_knB[wfs.q] = f_bnl.data
    return f_knB.conj()


def _projector_projections(ibzwfs, spin):
    """Return conj(P) restricted to bound-state projectors as f_kni."""
    f_kin = []
    for wfs in ibzwfs:
        if wfs.spin != spin:
            continue
        f_in = []
        for a, P_ni in wfs.P_ani.items():
            setup = wfs.setups[a]
            i = 0
            for l, n in zip(setup.l_j, setup.n_j):
                nm = 2 * l + 1
                # Only projectors with n >= 0 (the bound states) are
                # kept.  The column offset advances past every projector.
                if n >= 0:
                    f_in.extend(P_ni[:, j] for j in range(i, i + nm))
                i += nm
        f_kin.append(f_in)
    f_kni = np.array(f_kin).transpose(0, 2, 1)
    return f_kni.conj()


def _gaussian_splines(locfun, relpos_ac):
    """Return (positions, splines) for the Gaussian localized functions."""
    spos_bc = []
    splines_b = []
    for spos_c, l, sigma in locfun:
        # An integer index (Python or NumPy) selects an atom's position.
        if isinstance(spos_c, (int, np.integer)):
            spos_c = relpos_ac[spos_c]
        spos_bc.append(spos_c)
        # NOTE(review): alpha converts sigma with Bohr, but the grid
        # cutoff 10*sigma does not.  Kept as-is; confirm intended units.
        alpha = .5 * Bohr**2 / sigma**2
        r = np.linspace(0, 10. * sigma, 500)
        f_g = (fac(l) * (4 * alpha)**(l + 3 / 2.) *
               np.exp(-alpha * r**2) /
               (np.sqrt(4 * np.pi) * fac(2 * l + 1)))
        splines_b.append([Spline.from_data(l, rmax=r[-1], f_g=f_g)])
    return spos_bc, splines_b