Coverage for gpaw/nlopt/shift.py: 99%
69 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-20 00:19 +0000
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-20 00:19 +0000
1import numpy as np
2from ase.parallel import parprint
3from ase.units import _e, _hbar, Ha, J
4from ase.utils.timing import Timer
6from gpaw.mpi import world
7from gpaw.nlopt.matrixel import get_derivative, get_rml
8from gpaw.utilities.progressbar import ProgressBar
def get_shift(
        nlodata,
        freqs=(1.0,),
        eta=0.05,
        pol='yyy',
        eshift=0.0,
        ftol=1e-4, Etol=1e-6,
        band_n=None,
        out_name='shift.npy'):
    """
    Calculate RPA shift current for nonmagnetic semiconductors.

    Parameters
    ----------
    nlodata
        Data object of class NLOData.
    freqs
        Excitation frequencies in eV (a numpy array, list or tuple;
        default a single frequency of 1.0 eV).
    eta
        Broadening in eV, a number or an array broadcastable against
        ``freqs`` (default 0.05 eV).
    pol
        Tensor element as three Cartesian labels from 'xyz'
        (default 'yyy').
    eshift
        Scissors (band-gap) correction in eV (default 0 eV).
    ftol
        Band pairs whose occupation difference is not larger than
        ``ftol`` are left out of the sum.
    Etol
        Tolerance in eV used to treat band energies as degenerate.
    band_n
        List of bands in the sum (default 0 to nb).
    out_name
        Output filename (default 'shift.npy').  Pass None to skip
        writing the file.

    Returns
    -------
    np.ndarray
        Array of shape (2, len(freqs)).  Row 0 holds the frequencies
        (eV) and row 1 holds the shift-current spectrum in A / V^2.
    """

    # Start a timer
    timer = Timer()
    parprint(f'Calculating shift current (in {world.size:d} cores).')

    # Convert inputs in eV to Ha.  Rebind the names instead of using
    # ``/=`` so that an array passed by the caller (e.g. eta) is never
    # modified in place.
    freqs = np.asarray(freqs, dtype=float)
    nw = len(freqs)
    w_l = freqs / Ha
    eta = np.asarray(eta, dtype=float) / Ha
    Etol = Etol / Ha
    # eshift is added to band-energy differences that are compared with
    # w_l (in Ha) inside shift_current, so it needs the same conversion.
    eshift = eshift / Ha

    # Useful variables
    pol_v = ['xyz'.index(ii) for ii in pol]
    parprint(f'Calculation for element {pol}.')

    # Load the required data
    with timer('Load and distribute the data'):
        k_info = nlodata.distribute()
        if k_info:
            # Number of bands, taken from the occupations (f_n) of the
            # first local k-point
            tmp = next(iter(k_info.values()))
            nb = len(tmp[1])
            nk = len(k_info) * world.size  # Approximately
            if band_n is None:
                band_n = list(range(nb))
            mem = 6 * 3 * nk * nb**2 * 16 / 2**20
            parprint(f'At least {mem:.2f} MB of memory is required.')

    # Only the master rank draws the progress bar
    ncount = len(k_info)
    if world.rank == 0:
        pb = ProgressBar()

    # Initialize the outputs
    sum2_l = np.zeros(nw, complex)

    # Do the calculations over the k-points owned by this rank
    for count, (we, f_n, E_n, p_vnn) in enumerate(k_info.values()):
        with timer('Position matrix elements calculation'):
            r_vnn, D_vnn = get_rml(E_n, p_vnn, pol_v, Etol=Etol)

        with timer('Compute generalized derivative'):
            rd_vvnn = get_derivative(E_n, r_vnn, D_vnn, pol_v, Etol=Etol)

        with timer('Sum over bands'):
            tmp = shift_current(
                eta, w_l, f_n, E_n, r_vnn, rd_vvnn, pol_v,
                band_n, ftol, Etol, eshift)

        # Add it to previous with the k-point weight
        sum2_l += tmp * we

        # Print the progress
        if world.rank == 0:
            pb.update(count / ncount)

    if world.rank == 0:
        pb.finish()

    with timer('Gather data from cores'):
        world.sum(sum2_l)

    # Multiply prefactors
    prefactor = 1 / (2 * (2.0 * np.pi)**3)
    # Convert to SI units [A / V^2] = [C^3 / (J^2 * s)]
    prefactor *= _e**3 / (_hbar * (Ha / J))

    sigma_l = prefactor * sum2_l.real

    # A multi-col output: frequencies (eV) stacked on top of the spectrum
    shift = np.vstack((freqs, sigma_l))

    # Save it to the file
    if world.rank == 0 and out_name is not None:
        np.save(out_name, shift)

    # Print the timing on the master rank only, like the other outputs
    if world.rank == 0:
        timer.write()

    return shift
def shift_current(
        eta, w_l, f_n, E_n, r_vnn, rd_vvnn, pol_v,
        band_n=None, ftol=1e-4, Etol=1e-6, eshift=0.0):
    """
    Sum the shift-current integrand over band pairs (length gauge).

    Input:
        eta         Broadening, a number or an array broadcastable
                    against w_l (same energy unit as E_n)
        w_l         Real frequency array (get_shift passes freqs / Ha)
        f_n         Band occupations
        E_n         Band energies
        r_vnn       Position matrix elements, indexed [v, m, n]
        rd_vvnn     Generalized derivative of the position matrix
                    elements, indexed [v1, v2, n, m]
        pol_v       Cartesian indices (0-2) of the tensor element
        band_n      Bands included in the sum (default: all bands)
        ftol        Band pairs with |f_n - f_m| <= ftol are skipped
        Etol        Accepted for interface compatibility; unused here
        eshift      Scissors correction, added to E_m - E_n scaled by
                    the occupation difference f_n - f_m
    Output:
        Complex array of length w_l.size (imaginary part is zero)
    """

    if band_n is None:
        band_n = list(range(len(f_n)))
    total_l = np.zeros(w_l.size, complex)

    def lorentzian(x):
        # Normalized Lorentzian of width eta centred at x = 0
        return eta / (np.pi * (x ** 2 + eta ** 2))

    # Visit each unordered band pair once (m > n); the mirrored pair is
    # accounted for through time-reversal symmetry.
    pairs = ((n, m) for n in band_n for m in band_n if m > n)
    for n, m in pairs:
        df = f_n[n] - f_n[m]
        dE = E_n[m] - E_n[n] + df * eshift

        # Two-band part: only pairs with a real occupation difference
        if not np.abs(df) > ftol:
            continue

        a, b, c = pol_v[0], pol_v[1], pol_v[2]
        overlap = np.imag(
            r_vnn[b, m, n] * rd_vvnn[a, c, n, m]
            + r_vnn[c, m, n] * rd_vvnn[a, b, n, m])
        lineshape = lorentzian(w_l - dE) - lorentzian(w_l + dE)
        total_l += df * (overlap * lineshape)

    return 2 * np.pi * total_l