Coverage for gpaw/nlopt/matrixel.py: 87%
132 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-20 00:19 +0000
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-20 00:19 +0000
1from __future__ import annotations
2from typing import TYPE_CHECKING
4from ase.parallel import parprint
5from ase.utils.timing import Timer
6from pathlib import Path
7import numpy as np
9from gpaw.new.ase_interface import ASECalculator
10from gpaw.nlopt.basic import NLOData
11from gpaw.utilities.progressbar import ProgressBar
13if TYPE_CHECKING:
14 from gpaw.nlopt.adapters import CollinearGSInfo, NoncollinearGSInfo
15 from gpaw.typing import ArrayND
def get_mml(gs: CollinearGSInfo | NoncollinearGSInfo,
            bands: slice,
            spin: int,
            timer: Timer | None = None) -> ArrayND:
    """
    Compute momentum matrix elements.

    For every k-point (q-index) held by this rank, the pseudo wave function
    contribution (plane-wave sum weighted by G + k, scaled by ``gs.ucvol``)
    is combined with the PAW correction built from the wave function
    projections and ``gs.nabla_aiiv``.

    Parameters
    ----------
    gs
        Ground state adapter.
    bands
        Range of band indices; ``bands.start`` and ``bands.stop`` must both
        be integers.
    spin
        Spin channel index (for spin-polarized systems 0 or 1); must be
        smaller than ``gs.ns``.
    timer
        Timer for monitoring code performance. A fresh ``Timer`` is created
        when omitted; it is written out on the k-point master only.

    Returns
    -------
    p_qvnn
        Momentum matrix elements for each local q-point, as a complex array
        of shape (nq, 3, nb, nb) with nb = bands.stop - bands.start.
    """
    timer = Timer() if timer is None else timer
    parprint(f'Calculating momentum matrix elements for spin channel {spin}.')

    assert spin < gs.ns, 'Wrong spin input'

    # Output buffer: one (3, nb, nb) block per locally stored q-index
    ibzwfs = gs.ibzwfs
    nbands = bands.stop - bands.start
    nq = len(ibzwfs.q_k.keys())
    p_qvnn = np.empty([nq, 3, nbands, nbands], dtype=complex)

    # Only the master of the k-point communicator reports progress
    progress = ProgressBar() if ibzwfs.kpt_comm.rank == 0 else None

    for wfs_s in ibzwfs.wfs_qs:
        wfs = gs.get_wfs(wfs_s, spin)

        with timer('Contribution from pseudo wave functions'):
            Gk_Gv, psit_nG = gs.get_plane_wave_coefficients(wfs, bands,
                                                            spin)
            mom_vnn = np.einsum('Gv,mG,nG->vmn',
                                Gk_Gv, psit_nG.conj(), psit_nG) * gs.ucvol

        with timer('Contribution from PAW corrections'):
            proj_ani = gs.get_wave_function_projections(wfs, bands, spin)
            for proj_ni, nabla_iiv in zip(proj_ani.values(), gs.nabla_aiiv):
                mom_vnn -= 1j * np.einsum('mi,nj,ijv->vmn',
                                          proj_ni.conj(), proj_ni, nabla_iiv)

        p_qvnn[wfs.q] = mom_vnn

        if progress is not None:
            progress.update(wfs.q / nq)

    if progress is not None:
        progress.finish()
        timer.write()

    return p_qvnn
def gather_to_master(p_qvnn, ibzwfs):
    """
    Collect k-point distributed matrix elements on the k-point master.

    Parameters
    ----------
    p_qvnn
        Locally held matrix elements, indexed by local q first; the trailing
        three axes are copied unchanged into the gathered array.
    ibzwfs
        Object providing ``kpt_comm`` (with ``rank``, ``size``, ``send`` and
        ``receive``) and ``rank_k``, an array mapping every global k-point
        index to the rank that owns it.

    Returns
    -------
    p_kvnn
        On the master: a complex array with one entry per global k-point,
        filled from every rank. On other ranks: an empty complex array with
        a leading axis of length 0.

    Notes
    -----
    Assumes each rank's local q ordering matches the ascending order of its
    k-points in ``rank_k`` — TODO confirm against the wave function layout.
    """
    kpt_comm = ibzwfs.kpt_comm
    tail_shape = p_qvnn.shape[1:4]

    # Non-master ranks just ship their block to rank 0 and keep nothing
    if kpt_comm.rank != 0:
        kpt_comm.send(p_qvnn, 0)
        return np.empty((0,) + tail_shape, complex)

    owner_k = ibzwfs.rank_k
    p_kvnn = np.empty((len(owner_k),) + tail_shape, complex)

    # Master's own contribution
    p_kvnn[np.where(owner_k == 0)[0]] = p_qvnn

    # Receive each remaining rank's block into its global k slots
    for source in range(1, kpt_comm.size):
        k_indices = np.where(owner_k == source)[0]
        recv_buf = np.zeros((len(k_indices),) + tail_shape, complex)
        kpt_comm.receive(recv_buf, source)
        p_kvnn[k_indices] = recv_buf

    return p_kvnn
def make_nlodata(calc: ASECalculator | str | Path,
                 spin_string: str = 'all',
                 ni: int | None = None,
                 nf: int | None = None) -> NLOData:
    """
    Calculate and return all required NLO data: w_sk, f_skn, E_skn, p_skvnn.

    Parameters
    ----------
    calc
        Calculator or string/path pointing to a .gpw file.
    spin_string
        String denoting which spin channels to include ('all', 's0', 's1').
        Non-collinear calculations only support 'all', since the
        contributions of both spinor components are summed.
    ni
        First band to compute the mml (default 0).
    nf
        End (exclusive) of the band range for the mml. Values nf <= 0 are
        taken relative to the number of bands; the default is all bands.

    Returns
    -------
    NLOData
        Data object carrying required matrix elements for NLO calculations.
        For collinear calculations w_sk, f_skn and E_skn are restricted to
        the same spin channels as p_skvnn.

    Raises
    ------
    TypeError
        If calc is neither a calculator nor a str / Path.
    NotImplementedError
        If spin_string is not one of 'all', 's0', 's1'.
    ValueError
        If a single spin channel is requested for a non-collinear
        calculation, or if the band range is empty or out of bounds.
    """

    if not isinstance(calc, ASECalculator):
        if not isinstance(calc, (str, Path)):
            raise TypeError('Input must be a calculator or a string / path '
                            'pointing to a calculator.')
        from gpaw.new.ase_interface import GPAW
        calc = GPAW(calc, txt=None, parallel={'domain': 1, 'band': 1})
    assert not calc.symmetry.point_group, \
        'Point group symmetry should be off.'

    gs: CollinearGSInfo | NoncollinearGSInfo
    if calc.dft.density.collinear:
        from gpaw.nlopt.adapters import CollinearGSInfo
        gs = CollinearGSInfo(calc)
    else:
        from gpaw.nlopt.adapters import NoncollinearGSInfo
        gs = NoncollinearGSInfo(calc)

    # Start the timer
    timer = Timer()

    # Parse spin string
    ns = gs.ns
    if spin_string == 'all':
        spins = list(range(ns))
    elif spin_string == 's0':
        spins = [0]
    elif spin_string == 's1':
        spins = [1]
        assert spins[0] < ns, 'Wrong spin input'
    else:
        raise NotImplementedError
    if not gs.collinear and spin_string != 'all':
        # Both spinor components are summed below, so a single one is
        # meaningless (and previously crashed with an IndexError).
        raise ValueError("Non-collinear calculations require "
                         "spin_string='all'.")

    # Parse band input
    ibzwfs = gs.ibzwfs
    nb_full = ibzwfs.nbands
    ni = 0 if ni is None else int(ni)
    nf = nb_full if nf is None else int(nf)
    if nf <= 0:
        nf += nb_full
    if not 0 <= ni < nf <= nb_full:
        raise ValueError(f'Invalid band range ni={ni}, nf={nf} '
                         f'(number of bands: {nb_full}).')
    bands = slice(ni, nf)

    # Memory estimate (heuristic; 16 bytes per complex128 value)
    nk = len(ibzwfs.rank_k)  # Total number of k-points
    est_mem = 2 * 3 * nk * (nf - ni)**2 * 16 / 2**30
    parprint(f'At least {est_mem:.2f} GB of memory is required on master.')

    # Get the energy and Fermi-Dirac occupations (data is only in master)
    with timer('Get energies and fermi levels'):
        E_skn, f_skn = ibzwfs.get_all_eigs_and_occs()
        w_sk = np.array([ibzwfs.ibz.weight_k
                         for _ in range(gs.ndensities)])
        w_sk *= gs.bzvol * ibzwfs.spin_degeneracy

    # Keep weights, occupations and energies aligned with the spin channels
    # for which matrix elements are computed (collinear case only).
    if gs.collinear and spins != list(range(ns)):
        w_sk = w_sk[spins]
        if E_skn.size:  # presumably empty on non-master ranks — confirm
            E_skn = E_skn[spins]
            f_skn = f_skn[spins]

    # Compute the momentum matrix elements
    with timer('Compute the momentum matrix elements'):
        p_sqvnn = [get_mml(gs, bands, spin, timer) for spin in spins]
        if not gs.collinear:
            # Sum the contributions of both spinor components
            p_sqvnn = [p_sqvnn[0] + p_sqvnn[1]]

    with timer('Gather the data to master'):
        p_skvnn = [gather_to_master(p_qvnn, ibzwfs) for p_qvnn in p_sqvnn]

    # Save the output to the file
    return NLOData(w_sk=w_sk,
                   f_skn=f_skn[:, :, bands],
                   E_skn=E_skn[:, :, bands],
                   p_skvnn=np.array(p_skvnn, complex),
                   comm=ibzwfs.kpt_comm)
def get_rml(E_n, p_vnn, pol_v, Etol=1e-6):
    """
    Compute the position matrix elements.

    Parameters
    ----------
    E_n
        Band energies (1D array-like of length nb).
    p_vnn
        Momentum matrix elements; ``p_vnn[v]`` is an (nb, nb) matrix.
    pol_v
        Indices (0-2) of the Cartesian components to compute (tensor
        element); duplicates are ignored.
    Etol
        Tolerance in energy to consider degeneracy.

    Returns
    -------
    r_vnn
        Position matrix elements, r_nm = p_nm / (i (E_n - E_m)), shape
        (3, nb, nb). Entries with |E_n - E_m| < Etol and components not in
        pol_v are zero.
    D_vnn
        Velocity difference matrix elements, D_nm = p_nn - p_mm, shape
        (3, nb, nb). Components not in pol_v are zero.
    """

    E_n = np.asarray(E_n)
    nb = len(E_n)
    r_vnn = np.zeros((3, nb, nb), complex)
    D_vnn = np.zeros((3, nb, nb), complex)

    # Energy differences E_n - E_m; degenerate pairs get a dummy 1 so the
    # division below is safe, and are zeroed afterwards.
    E_nn = E_n[:, None] - E_n[None, :]
    degenerate_nn = np.abs(E_nn) < Etol
    E_nn[degenerate_nn] = 1

    for v in set(pol_v):
        r_nn = p_vnn[v] / (1j * E_nn)
        r_nn[degenerate_nn] = 0
        r_vnn[v] = r_nn
        p_n = np.diag(p_vnn[v])
        D_vnn[v] = p_n[:, None] - p_n[None, :]

    return r_vnn, D_vnn
def get_derivative(E_n, r_vnn, D_vnn, pol_v, Etol=1e-6):
    """
    Compute the generalised derivative of position matrix elements.

    With w_nm = E_n - E_m and a, b taken from pol_v, this evaluates

        r_nm;b^a = [r_nm^a D_mn^b + r_nm^b D_mn^a
                    + i sum_l (r_nl^a r_lm^b w_lm - r_nl^b w_nl r_lm^a)]
                   / w_nm

    and sets it to zero for (near-)degenerate pairs |w_nm| < Etol. The true
    energy differences are used inside the sums, so values stored in r_vnn
    at degenerate positions (e.g. its diagonal) are not given a spurious
    weight.

    Parameters
    ----------
    E_n
        Band energies (1D array-like of length nb).
    r_vnn
        Position matrix elements; ``r_vnn[v]`` is an (nb, nb) matrix.
    D_vnn
        Velocity difference matrix elements; ``D_vnn[v]`` is (nb, nb).
    pol_v
        Indices (0-2) of the Cartesian components (tensor element);
        duplicates are ignored.
    Etol
        Tolerance in energy to consider degeneracy.

    Returns
    -------
    rd_vvnn
        Generalised derivative of position matrix elements, shape
        (3, 3, nb, nb); component pairs not in pol_v are zero.
    """

    E_n = np.asarray(E_n)
    nb = len(E_n)
    rd_vvnn = np.zeros((3, 3, nb, nb), complex)

    E_nn = E_n[:, None] - E_n[None, :]
    degenerate_nn = np.abs(E_nn) < Etol
    # Safe denominator only; the sums below use the true differences
    Ediv_nn = np.where(degenerate_nn, 1, E_nn)

    components = set(pol_v)
    for a in components:
        for b in components:
            tmp = (r_vnn[a] * D_vnn[b].T
                   + r_vnn[b] * D_vnn[a].T
                   + 1j * (r_vnn[a] @ (r_vnn[b] * E_nn))
                   - 1j * ((r_vnn[b] * E_nn) @ r_vnn[a])) / Ediv_nn
            tmp[degenerate_nn] = 0
            rd_vvnn[a, b] = tmp

    return rd_vvnn