Coverage for gpaw/utilities/dipole.py: 92%

53 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-14 00:18 +0000

1"""Calculate dipole matrix elements.""" 

2from __future__ import annotations 

3import numpy as np 

4from ase.units import Bohr 

5from gpaw.new.ase_interface import ASECalculator, GPAW 

6from gpaw.typing import Array3D 

7from gpaw.new.lcao.wave_functions import LCAOWaveFunctions 

8from gpaw.new.pwfd.wave_functions import PWFDWaveFunctions 

9 

10 

def dipole_matrix_elements(*args, **kwargs):
    """Deprecated.

    Use
    ``gpaw.new.pwfd.wave_functions.PWFDWaveFunctions.dipole_matrix_elements``
    instead.

    All arguments are ignored; calling this function always raises
    :class:`DeprecationWarning` (raised as an exception, not emitted as a
    warning) with a message pointing to the replacement.
    """
    raise DeprecationWarning(
        'dipole_matrix_elements() is deprecated: use '
        'gpaw.new.pwfd.wave_functions.PWFDWaveFunctions'
        '.dipole_matrix_elements() instead.')

19 

20 

def dipole_matrix_elements_from_calc(calc: ASECalculator,
                                     n1: int,
                                     n2: int,
                                     ) -> list[Array3D]:
    """Calculate dipole matrix-elements (units: eÅ).

    Only Gamma-point-only calculations are supported.

    Parameters
    ----------
    calc:
        Calculator holding the wave functions.
    n1, n2:
        Band range (n1 <= n < n2).

    Returns
    -------
    list of arrays
        One ``(n2 - n1, n2 - n1, 3)`` array for each entry of the first
        IBZ wave-function list (one per spin), broadcast from rank 0 of
        ``calc.comm`` so that every rank returns the same values.

    Raises
    ------
    ValueError
        If the calculation is not Gamma-point only.
    """
    ibzwfs = calc.dft.ibzwfs

    # Explicit check rather than ``assert``: under ``python -O`` an assert
    # is stripped and only the first k-point would silently be used.
    if not ibzwfs.ibz.bz.gamma_only:
        raise ValueError('Dipole matrix elements can only be calculated '
                         'for Gamma-point-only calculations.')

    nbands = n2 - n1
    d_snnv = []
    for wfs in ibzwfs.wfs_qs[0]:  # Gamma point: one entry per spin
        if isinstance(wfs, LCAOWaveFunctions):
            # LCAO wave functions are first expanded on the uniform grid
            # of the pseudo density.
            basis = calc.dft.scf_loop.hamiltonian.basis
            grid = calc.dft.density.nt_sR.desc
            wfs = wfs.to_uniform_grid_wave_functions(grid, basis)
        wfs12 = wfs.collect(n1, n2)
        if wfs12 is not None:
            assert isinstance(wfs12, PWFDWaveFunctions)
            d_nnv = wfs12.dipole_matrix_elements() * Bohr  # Bohr -> Å
        else:
            # Ranks that did not receive the collected bands allocate a
            # buffer and get rank 0's result through the broadcast below.
            d_nnv = np.empty((nbands, nbands, 3))
        calc.comm.broadcast(d_nnv, 0)
        d_snnv.append(d_nnv)

    return d_snnv

54 

55 

def _resolve_band_range(n1: int, n2: int, nbands: int) -> tuple[int, int]:
    """Convert a command-line band range into absolute band indices.

    A non-positive ``n2`` counts from the end (``0`` means "up to and
    including the last band") and so does a negative ``n1``.

    Raises ValueError if the resulting range is empty or exceeds
    ``0 <= n1 < n2 <= nbands``.
    """
    if n1 < 0:
        n1 += nbands
    if n2 <= 0:
        n2 += nbands
    if not 0 <= n1 < n2 <= nbands:
        raise ValueError(f'Invalid band range: {n1} <= n < {n2} '
                         f'(number of bands: {nbands})')
    return n1, n2


def _format_dipoles(d_snnv: list[Array3D], n1: int) -> list[str]:
    """Render per-spin dipole matrices as text lines.

    For every spin, one table is produced per Cartesian direction, with
    rows and columns labelled by band index starting at ``n1``.
    """
    lines = []
    for spin, d_nnv in enumerate(d_snnv):
        bands = range(n1, n1 + len(d_nnv))
        header = ''.join(f'{n:8}' for n in bands)
        lines.append(f'Spin={spin}')
        # transpose(2, 0, 1) gives d_vnn[v, n, m] == d_nnv[n, m, v].  A
        # plain ``.T`` would also swap n and m, printing the transpose.
        for direction, d_nn in zip('xyz', d_nnv.transpose(2, 0, 1)):
            lines.append(f' <{direction}> {header}')
            for n, d_n in zip(bands, d_nn):
                lines.append(f'{n:4} ' + ''.join(f'{d:8.3f}' for d in d_n))
    return lines


def main(argv: list[str] | None = None) -> None:
    """Command-line entry point: print dipole matrix elements (eÅ).

    Parameters
    ----------
    argv:
        Command-line arguments; ``None`` lets argparse use ``sys.argv[1:]``.
    """
    import argparse

    parser = argparse.ArgumentParser(
        prog='python3 -m gpaw.utilities.dipole',
        description='Calculate dipole matrix elements. Units: eÅ.')

    add = parser.add_argument

    add('file', metavar='input-file',
        help='GPW-file with wave functions.')
    add('-n', '--band-range', nargs=2, type=int, default=[0, 0],
        metavar=('n1', 'n2'),
        help='Include bands: n1 <= n < n2. A non-positive n2 and a '
        'negative n1 count from the end (default: all bands).')

    args = parser.parse_intermixed_args(argv)

    calc = GPAW(args.file)

    nbands = calc.get_number_of_bands()
    try:
        n1, n2 = _resolve_band_range(*args.band_range, nbands)
    except ValueError as err:
        parser.error(str(err))  # prints usage and exits

    d_snnv = dipole_matrix_elements_from_calc(calc, n1, n2)

    # All ranks take part in the calculation above; only rank 0 prints.
    if calc.comm.rank > 0:
        return

    print('Number of bands:', nbands)
    print('Number of valence electrons:', calc.get_number_of_electrons())
    print('Units: eÅ')
    print()

    for line in _format_dipoles(d_snnv, n1):
        print(line)

96 

97 

if __name__ == '__main__':
    # Script entry (``python3 -m gpaw.utilities.dipole``); argv=None lets
    # argparse read the real command line.
    main(argv=None)