Coverage for gpaw/utilities/dipole.py: 92%
53 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-14 00:18 +0000
1"""Calculate dipole matrix elements."""
2from __future__ import annotations
3import numpy as np
4from ase.units import Bohr
5from gpaw.new.ase_interface import ASECalculator, GPAW
6from gpaw.typing import Array3D
7from gpaw.new.lcao.wave_functions import LCAOWaveFunctions
8from gpaw.new.pwfd.wave_functions import PWFDWaveFunctions
def dipole_matrix_elements(*args, **kwargs):
    """Deprecated.

    Use
    ``gpaw.new.pwfd.wave_functions.PWFDWaveFunctions.dipole_matrix_elements``
    instead.

    Raises
    ------
    DeprecationWarning
        Always, whatever the arguments; the message names the replacement.
    """
    # Raised as an exception (not issued via ``warnings.warn``), so every
    # call fails; the message tells the user where to go instead.
    raise DeprecationWarning(
        'dipole_matrix_elements() is deprecated; use '
        'gpaw.new.pwfd.wave_functions.PWFDWaveFunctions'
        '.dipole_matrix_elements() instead.')
def dipole_matrix_elements_from_calc(calc: ASECalculator,
                                     n1: int,
                                     n2: int,
                                     ) -> list[Array3D]:
    """Return dipole matrix elements for bands ``n1 <= n < n2`` (units: eÅ).

    Only Gamma-point-only calculations are accepted (asserted below).

    Parameters
    ----------
    calc:
        Calculator whose ``dft.ibzwfs`` holds the wave functions.
    n1, n2:
        Band range.

    Returns
    -------
    list of arrays
        One ``(n2 - n1, n2 - n1, 3)`` array per entry of ``wfs_qs[0]``
        (presumably one per spin), broadcast from rank 0 so every rank
        holds the same values.
    """
    ibzwfs = calc.dft.ibzwfs
    assert ibzwfs.ibz.bz.gamma_only

    nbands = n2 - n1
    result: list[Array3D] = []
    for spin_wfs in ibzwfs.wfs_qs[0]:
        if isinstance(spin_wfs, LCAOWaveFunctions):
            # LCAO wave functions are first expanded on the grid of the
            # pseudo density, using the Hamiltonian's basis set.
            basis = calc.dft.scf_loop.hamiltonian.basis
            grid = calc.dft.density.nt_sR.desc
            spin_wfs = spin_wfs.to_uniform_grid_wave_functions(grid, basis)

        collected = spin_wfs.collect(n1, n2)
        if collected is None:
            # This rank got no collected bands: allocate a receive buffer
            # of the right shape for the broadcast from rank 0.
            d_nnv = np.empty((nbands, nbands, 3))
        else:
            assert isinstance(collected, PWFDWaveFunctions)
            d_nnv = collected.dipole_matrix_elements() * Bohr

        calc.comm.broadcast(d_nnv, 0)
        result.append(d_nnv)

    return result
56def main(argv: list[str] = None) -> None:
57 import argparse
59 parser = argparse.ArgumentParser(
60 prog='python3 -m gpaw.utilities.dipole',
61 description='Calculate dipole matrix elements. Units: eÅ.')
63 add = parser.add_argument
65 add('file', metavar='input-file',
66 help='GPW-file with wave functions.')
67 add('-n', '--band-range', nargs=2, type=int, default=[0, 0],
68 metavar=('n1', 'n2'), help='Include bands: n1 <= n < n2.')
70 args = parser.parse_intermixed_args(argv)
72 calc = GPAW(args.file)
74 n1, n2 = args.band_range
75 nbands = calc.get_number_of_bands()
76 n2 = n2 or n2 + nbands
78 d_snnv = dipole_matrix_elements_from_calc(calc, n1, n2)
80 if calc.comm.rank > 0:
81 return
83 print('Number of bands:', nbands)
84 print('Number of valence electrons:', calc.get_number_of_electrons())
85 print('Units: eÅ')
86 print()
88 for spin, d_nnv in enumerate(d_snnv):
89 print(f'Spin={spin}')
91 for direction, d_nn in zip('xyz', d_nnv.T):
92 print(f' <{direction}>',
93 ''.join(f'{n:8}' for n in range(n1, n2)))
94 for n in range(n1, n2):
95 print(f'{n:4}', ''.join(f'{d:8.3f}' for d in d_nn[n - n1]))
if __name__ == "__main__":
    # Script entry point (``python3 -m gpaw.utilities.dipole``); passing
    # argv=None makes argparse read the process command line.
    main(argv=None)