Coverage for gpaw/test/lcao/test_dipole_transition.py: 100%

58 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-14 00:18 +0000

1import numpy as np 

2import pytest 

3 

4from ase.parallel import world, parprint 

5from ase.units import Bohr 

6from gpaw import GPAW 

7from gpaw.lcao.dipoletransition import get_dipole_transitions 

8from gpaw.utilities.dipole import dipole_matrix_elements_from_calc 

9from gpaw.lrtddft.kssingle import KSSingles 

10 

11 

@pytest.mark.old_gpaw_only
def test_dipole_transition(gpw_files, tmp_path_factory):
    """Check dipole matrix-elements for H2O.

    Dipole transitions from ``get_dipole_transitions`` are checked for
    shape, for |d[i, j]| == |d[j, i]| symmetry and against a few reference
    values. They are then compared against
    ``dipole_matrix_elements_from_calc`` (new-GPAW calculator, rank 0 only)
    and the ``KSSingles`` lrtddft implementation.
    """
    calc = GPAW(gpw_files['h2o_lcao'])
    # Initialize calculator if necessary
    if not hasattr(calc.wfs, 'C_nM'):
        calc.initialize_positions(calc.atoms)
    dip_skvnm = get_dipole_transitions(calc.wfs).real
    parprint("Dipole moments calculated")
    # Presumably indexed (spin, k-point, direction v, band n, band m);
    # the printout labels below use v = 0/1/2 as x/y/z.
    assert dip_skvnm.shape == (1, 1, 3, 6, 6)
    dip_vnm = dip_skvnm[0, 0] * Bohr

    print(world.rank, dip_vnm[0, 0, 3])

    # check symmetry: abs(d[i,j]) == abs(d[j,i])
    # Only magnitudes are compared; the signs of d[i,j] and d[j,i] may
    # differ. The tolerance matches the reference-value checks below.
    for v in range(3):
        assert abs(dip_vnm[v].T) == pytest.approx(abs(dip_vnm[v]), abs=1e-4)

    # Check numerical value of a few elements - signs might change!
    assert 0.0693 == pytest.approx(abs(dip_vnm[2, 0, 4]), abs=1e-4)
    assert 0.1014 == pytest.approx(abs(dip_vnm[1, 0, 5]), abs=1e-4)
    assert 0.1709 == pytest.approx(abs(dip_vnm[0, 3, 4]), abs=1e-4)

    # some printout for manual inspection, if wanted
    f = 6 * "{:+.4f} "
    for c in range(3):
        for i in range(6):
            parprint(f.format(*dip_vnm[c, i]))
        parprint("")

    # ------------------------------------------------------------------------
    # compare to utilities implementation
    if world.rank == 0:
        from gpaw.new.ase_interface import GPAW as NewGPAW
        from gpaw.mpi import serial_comm
        refcalc = NewGPAW(gpw_files['h2o_lcao'],
                          communicator=serial_comm)
        uref = dipole_matrix_elements_from_calc(refcalc, 0, 6)
        uref = uref[0]
        # Note the different axis order: (band, band, direction).
        assert uref.shape == (6, 6, 3)
    # NOTE: Comparing implementations of r gauge and v gauge is tricky, as they
    # tend to be numerically inequivalent.

    # compare to lrtddft implementation
    # (run on every rank, outside the rank-0 guard)
    kss = KSSingles()
    atoms = calc.atoms
    atoms.calc = calc
    kss.calculate(calc.atoms, 1)
    lrref = []
    lrrefv = []
    for ex in kss:
        lrref.append(-1. * ex.mur * Bohr)
        lrrefv.append(-1. * ex.muv * Bohr)
    lrref = np.array(lrref)
    lrrefv = np.array(lrrefv)

    # Additional benefit: tests equivalence of r gauge implementations
    # (uref only exists on rank 0).
    if world.rank == 0:
        for i, (m, n, v) in enumerate([[4, 0, 2],
                                       [5, 0, 1],
                                       [4, 1, 1],
                                       [5, 1, 2],
                                       [4, 2, 2],
                                       [5, 2, 1],
                                       [4, 3, 0]]):
            assert abs(lrref[i, v]) == pytest.approx(abs(uref[m, n, v]),
                                                     abs=1e-4)

    # some printout for manual inspection, if wanted
    parprint("         r-gauge  lrtddft(v) raman(v)")
    f = "{} {:+.4f}    {:+.4f}    {:+.4f}"
    parprint(f.format('0->4 (z)', lrref[0, 2], lrrefv[0, 2], dip_vnm[2, 0, 4]))
    parprint(f.format('0->5 (y)', lrref[1, 1], lrrefv[1, 1], dip_vnm[1, 0, 5]))
    parprint(f.format('1->4 (y)', lrref[2, 1], lrrefv[2, 1], dip_vnm[1, 1, 4]))
    parprint(f.format('1->5 (z)', lrref[3, 2], lrrefv[3, 2], dip_vnm[2, 1, 5]))
    parprint(f.format('2->4 (z)', lrref[4, 2], lrrefv[4, 2], dip_vnm[2, 2, 4]))
    parprint(f.format('2->5 (y)', lrref[5, 1], lrrefv[5, 1], dip_vnm[1, 2, 5]))
    parprint(f.format('3->4 (x)', lrref[6, 0], lrrefv[6, 0], dip_vnm[0, 3, 4]))