Coverage for gpaw/nlopt/shift.py: 99%

69 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-20 00:19 +0000

1import numpy as np 

2from ase.parallel import parprint 

3from ase.units import _e, _hbar, Ha, J 

4from ase.utils.timing import Timer 

5 

6from gpaw.mpi import world 

7from gpaw.nlopt.matrixel import get_derivative, get_rml 

8from gpaw.utilities.progressbar import ProgressBar 

9 

10 

def get_shift(
        nlodata,
        freqs=(1.0,),
        eta=0.05,
        pol='yyy',
        eshift=0.0,
        ftol=1e-4, Etol=1e-6,
        band_n=None,
        out_name='shift.npy'):
    """
    Calculate RPA shift current for nonmagnetic semiconductors.

    Parameters
    ----------
    nlodata
        Data object of class NLOData.
    freqs
        Excitation frequencies in eV (a number, list, tuple or numpy array).
    eta
        Broadening in eV, a number or an array broadcastable against
        ``freqs`` (default 0.05 eV). The caller's object is not modified.
    pol
        Tensor element given as three letters from 'xyz' (default 'yyy').
    eshift
        Bandgap correction in eV (default 0.0), converted to Ha like the
        other energy inputs.
    ftol, Etol
        Tolerance in occupation (dimensionless) and energy (eV) to
        consider degeneracy.
    band_n
        List of bands in the sum (default 0 to nb).
    out_name
        Output filename (default 'shift.npy'); written on rank 0 only.

    Returns
    -------
    np.ndarray
        Array of shape (2, len(freqs)): row 0 holds the frequencies (eV),
        row 1 the shift-current conductivity in SI units [A / V^2].

    Raises
    ------
    ValueError
        If ``pol`` is not exactly three letters taken from 'xyz'.
    """

    # Start a timer
    timer = Timer()
    parprint(f'Calculating shift current (in {world.size:d} cores).')

    # Convert inputs in eV to Ha. Rebind to new objects instead of dividing
    # in place, so an array passed as ``eta`` is left untouched.
    freqs_l = np.atleast_1d(np.asarray(freqs, dtype=float))
    nw = len(freqs_l)
    w_l = freqs_l / Ha
    eta = np.asarray(eta, dtype=float) / Ha
    Etol = Etol / Ha
    # E_n and w_l are in Ha inside shift_current, so the gap
    # correction has to be expressed in Ha as well.
    eshift = eshift / Ha

    # Cartesian indices of the tensor element, e.g. 'xyz' -> [0, 1, 2]
    if len(pol) != 3 or not set(pol) <= set('xyz'):
        raise ValueError(
            f"pol must be three letters from 'xyz', got {pol!r}")
    pol_v = ['xyz'.index(ii) for ii in pol]

    parprint(f'Calculation for element {pol}.')

    # Load the required data
    with timer('Load and distribute the data'):
        k_info = nlodata.distribute()
        if k_info:
            # Number of bands from the second entry of the first k-point
            # tuple (the occupation array f_n in the loop below).
            nb = len(next(iter(k_info.values()))[1])
            nk = len(k_info) * world.size  # Approximately
            if band_n is None:
                band_n = list(range(nb))
            mem = 6 * 3 * nk * nb**2 * 16 / 2**20
            parprint(f'At least {mem:.2f} MB of memory is required.')

    # Progress bar exists on rank 0 only
    ncount = len(k_info)
    pb = ProgressBar() if world.rank == 0 else None

    # Initialize the output (one entry per frequency)
    sum2_l = np.zeros(nw, complex)

    # Do the calculations, k-point by k-point on this rank
    for count, (we, f_n, E_n, p_vnn) in enumerate(k_info.values()):
        with timer('Position matrix elements calculation'):
            r_vnn, D_vnn = get_rml(E_n, p_vnn, pol_v, Etol=Etol)

        with timer('Compute generalized derivative'):
            rd_vvnn = get_derivative(E_n, r_vnn, D_vnn, pol_v, Etol=Etol)

        with timer('Sum over bands'):
            kpt_l = shift_current(
                eta, w_l, f_n, E_n, r_vnn, rd_vvnn, pol_v,
                band_n, ftol, Etol, eshift)

        # Add it to previous with the k-point weight
        sum2_l += kpt_l * we

        # Print the progress
        if pb is not None:
            pb.update(count / ncount)

    if pb is not None:
        pb.finish()

    with timer('Gather data from cores'):
        world.sum(sum2_l)

    # Multiply prefactors
    prefactor = 1 / (2 * (2.0 * np.pi)**3)
    # Convert to SI units [A / V^2] = [C^3 / (J^2 * s)]
    prefactor *= _e**3 / (_hbar * (Ha / J))

    sigma_l = prefactor * sum2_l.real

    # A multi-col output: frequencies (eV) and the spectrum
    shift = np.vstack((freqs_l, sigma_l))

    # Save it to the file
    if world.rank == 0:
        np.save(out_name, shift)

    # Print the timing
    timer.write()

    return shift

128 

129 

def shift_current(
        eta, w_l, f_n, E_n, r_vnn, rd_vvnn, pol_v,
        band_n=None, ftol=1e-4, Etol=1e-6, eshift=0.0):
    """
    Loop over band pairs to compute the shift current in the length gauge.

    Parameters
    ----------
    eta
        Broadening, in the same energy units as ``E_n`` (Ha when called from
        get_shift); a number or an array broadcastable against ``w_l``.
    w_l
        Real frequency array (list or numpy array), same units as ``E_n``.
    f_n
        Band occupation factors; pairs are weighted by f_n - f_m.
    E_n
        Band energies.
    r_vnn
        Position matrix elements, indexed [v, n, m].
    rd_vvnn
        Generalized derivative of the position matrix elements, indexed
        [v1, v2, n, m].
    pol_v
        Three Cartesian indices (a, b, c) of the tensor element.
    band_n
        Iterable of band indices to sum over (default: all bands).
    ftol
        Pairs with |f_n - f_m| <= ftol are skipped.
    Etol
        Unused here; kept for interface compatibility.
    eshift
        Bandgap correction added as (f_n - f_m) * eshift to E_m - E_n.

    Returns
    -------
    np.ndarray
        Complex array with one entry per frequency::

            2 pi * sum_{n<m} f_nm Im[r^b_mn r^{a;c}_nm + r^c_mn r^{a;b}_nm]
                   * (L(w - E_mn) - L(w + E_mn)),

        where L(x) = eta / (pi (x^2 + eta^2)).
    """

    # Initialize variables; materialize band_n so it can be iterated twice
    if band_n is None:
        band_n = range(len(f_n))
    bands = list(band_n)
    w_l = np.asarray(w_l)
    sum2_l = np.zeros(w_l.size, complex)

    # Loop-invariant slices of the polarization components
    a, b, c = pol_v[0], pol_v[1], pol_v[2]
    r_b, r_c = r_vnn[b], r_vnn[c]
    rd_ac, rd_ab = rd_vvnn[a, c], rd_vvnn[a, b]

    # Loop over bands
    for nni in bands:
        for mmi in bands:
            # Only m > n is kept (the original code notes this uses TRS)
            if mmi <= nni:
                continue
            fnm = f_n[nni] - f_n[mmi]
            Emn = E_n[mmi] - E_n[nni] + fnm * eshift

            # Two band part; pairs with (nearly) equal occupations vanish
            if np.abs(fnm) > ftol:
                weight = np.imag(r_b[mmi, nni] * rd_ac[nni, mmi]
                                 + r_c[mmi, nni] * rd_ab[nni, mmi])
                lorentz_l = (eta / (np.pi * ((w_l - Emn) ** 2 + eta ** 2))
                             - eta / (np.pi * ((w_l + Emn) ** 2 + eta ** 2)))
                sum2_l += fnm * (weight * lorentz_l)

    return 2 * np.pi * sum2_l