Coverage for gpaw/directmin/scf_helper.py: 97%

88 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-12 00:18 +0000

1import warnings 

2import numpy as np 

3 

4from ase.units import Ha 

5 

6from gpaw.directmin.tools import (sort_orbitals_according_to_energies, 

7 get_n_occ) 

8 

9 

def do_if_converged(eigensolver_name, wfs, ham, dens, log):
    """Finalize an ETDM calculation after the SCF loop has converged.

    Does nothing unless ``eigensolver_name`` is ``'etdm-lcao'`` or
    ``'etdm-fdpw'``.  Otherwise it:

    - brings the orbitals to a canonical representation;
    - logs how many energy/gradient (e/g) evaluations were needed;
    - for FD/PW, optionally converges the unoccupied states
      (``solver.converge_unocc``);
    - calls ``solver.update_ks_energy`` and ``ham.get_energy``;
    - sorts the orbitals by energy.

    When the occupations are named ``'mom'``, it also runs the MOM
    variational-collapse check.  If ``wfs.occupations.update_numbers`` is
    false, it restores ``solver.initial_occupation_numbers``.

    Parameters
    ----------
    eigensolver_name : str
        Eigensolver name; only ``'etdm-lcao'`` and ``'etdm-fdpw'`` are
        handled.
    wfs
        Wave-function object.  Its ``eigensolver`` attribute is the ETDM
        solver finalized here.
    ham
        Hamiltonian.  ``ham.get_energy`` is called with the solver's SIC
        energy (0.0 if the solver has none).
    dens
        Density object, forwarded to the solver methods.
    log
        Callable used to write progress messages.
    """
    name = eigensolver_name
    if name == 'etdm-lcao' or name == 'etdm-fdpw':
        # ``occupations`` may lack a ``name``; None never matches 'mom'.
        occ_name = getattr(wfs.occupations, 'name', None)
        solver = wfs.eigensolver
        # Functional name comes from ``dm_helper.func`` when present (the
        # LCAO branch below relies on ``dm_helper``), otherwise from ``odd``.
        if hasattr(solver, 'dm_helper'):
            func_name = solver.dm_helper.func.name
        elif hasattr(solver, 'odd'):
            func_name = solver.odd.name
        # NOTE(review): if the solver has neither attribute, ``func_name``
        # is unbound here and this line raises NameError.
        sic_calc = 'SIC' in func_name
    else:
        return

    # Optional solver attributes: default to no SIC energy and no
    # orbital constraints.
    if hasattr(solver, 'e_sic'):
        e_sic = solver.e_sic
    else:
        e_sic = 0.0

    if hasattr(solver, 'constraints'):
        constraints = solver.constraints
    else:
        constraints = None

    if eigensolver_name == 'etdm-lcao':
        # Canonicalize the orbitals of every k-point (timed section).
        # NOTE(review): the two positional ``False`` flags are defined by
        # ``dm_helper.update_to_canonical_orbitals``; confirm their meaning.
        with ((wfs.timer('Get canonical representation'))):
            for kpt in wfs.kpt_u:
                solver.dm_helper.update_to_canonical_orbitals(
                    wfs, ham, kpt, False, False)

        log('\nOccupied states converged after'
            ' {:d} e/g evaluations'.format(solver.eg_count))

    elif eigensolver_name == 'etdm-fdpw':
        solver.choose_optimal_orbitals(wfs)
        # e/g counters: niter1 is the main loop.  niter2 / niter3 come from
        # the optional ``iloop`` / ``outer_iloop`` and stay 0 when the
        # corresponding loop is not used.
        niter1 = solver.eg_count
        niter2 = 0
        niter3 = 0

        iloop1 = solver.iloop is not None
        iloop2 = solver.outer_iloop is not None
        if iloop1:
            niter2 = solver.total_eg_count_iloop
        if iloop2:
            niter3 = solver.total_eg_count_outer_iloop

        # When ``iloop`` is active, KS and SIC e/g counts are reported
        # separately.
        if iloop1 and iloop2:
            log(
                '\nOccupied states converged after'
                ' {:d} KS and {:d} SIC e/g '
                'evaluations'.format(niter3,
                                     niter2 + niter3))
        elif not iloop1 and iloop2:
            log(
                '\nOccupied states converged after'
                ' {:d} e/g evaluations'.format(niter3))
        elif iloop1 and not iloop2:
            log(
                '\nOccupied states converged after'
                ' {:d} KS and {:d} SIC e/g '
                'evaluations'.format(niter1, niter2))
        else:
            log(
                '\nOccupied states converged after'
                ' {:d} e/g evaluations'.format(niter1))
        if solver.converge_unocc:
            log('Converge unoccupied states:')
            # Error bound handed to ``run_unocc``: the solver's current
            # error scaled by Ha**2 / nvalence (presumably a unit/per-electron
            # conversion; confirm).
            # NOTE(review): if ``error`` is a NumPy array, ``*=`` also
            # mutates ``solver.error`` in place.
            max_er = wfs.eigensolver.error
            max_er *= Ha ** 2 / wfs.nvalence
            solver.run_unocc(ham, wfs, dens, max_er, log)
        else:
            log('Unoccupied states are not converged.')
            # NOTE(review): confirm this reset belongs only to the
            # ``else`` branch and not to both branches.
            solver.initialized = False

        # SIC functionals use rewrite_psi=False (presumably so the
        # canonical transformation does not overwrite the optimal
        # orbitals; confirm).
        rewrite_psi = True
        if sic_calc:
            rewrite_psi = False

        solver.get_canonical_representation(ham, wfs, rewrite_psi)

    # NOTE(review): confirm whether this MOM check is meant for both
    # eigensolvers or only for the FD/PW branch above.
    if occ_name == 'mom':
        check_mom_no_update_of_occupations(wfs)

    # Refresh the energies, then order the orbitals by energy
    # (any solver constraints are forwarded).
    solver.update_ks_energy(ham, wfs, dens)
    ham.get_energy(0.0, wfs, kin_en_using_band=False, e_sic=e_sic)
    sort_orbitals_according_to_energies(ham, wfs, constraints)

    if eigensolver_name == 'etdm-lcao':
        solver.set_ref_orbitals_and_a_vec(wfs)

    # MOM with frozen occupation numbers: restore the solver's initial
    # occupations.
    if occ_name == 'mom':
        not_update = not wfs.occupations.update_numbers
        if not_update:
            wfs.occupations.numbers = solver.initial_occupation_numbers

103 

104 

def check_eigensolver_state(eigensolver_name, wfs, ham, dens, log):
    """Prepare an ETDM eigensolver before a new SCF cycle.

    For ``'etdm-lcao'`` and ``'etdm-fdpw'`` this function:

    - zeroes ``eg_count`` and ``globaliters`` on the solver;
    - zeroes ``total_eg_count`` on the optional ``iloop`` and
      ``outer_iloop`` objects;
    - runs ``solver.check_assertions``;
    - calls ``solver.initialize_dm_helper`` when the solver has a
      ``dm_helper`` attribute set to None or is not yet initialized.

    Any other eigensolver name is a no-op.
    """
    solver = wfs.eigensolver
    if eigensolver_name not in ('etdm-lcao', 'etdm-fdpw'):
        return

    solver.eg_count = 0
    solver.globaliters = 0

    # Inner loops are optional: the attribute may be absent or None.
    for loop_name in ('iloop', 'outer_iloop'):
        inner_loop = getattr(solver, loop_name, None)
        if inner_loop is not None:
            inner_loop.total_eg_count = 0

    solver.check_assertions(wfs, dens)

    helper_unset = hasattr(solver, 'dm_helper') and solver.dm_helper is None
    if helper_unset or not solver.initialized:
        solver.initialize_dm_helper(wfs, ham, dens, log)

123 

124 

def check_mom_no_update_of_occupations(wfs):
    """Warn if MOM occupations signal variational collapse.

    Calls ``wfs.occupations.update_occupations()`` to obtain occupation
    numbers ``f_sn``.  For every k-point in ``wfs.kpt_u``, it inspects the
    first ``n_occ`` entries of that k-point's row, where ``n_occ`` is the
    first value returned by ``get_n_occ(kpt)``.  If any of those entries is
    exactly zero, a ``UserWarning`` about variational collapse is emitted.
    The warning is issued once per affected k-point, subject to the active
    warning filters.

    Parameters
    ----------
    wfs
        Wave-function object providing ``occupations``, the k-point
        descriptor ``kd`` (``nibzkpts``) and the k-points ``kpt_u``.
    """
    f_sn = wfs.occupations.update_occupations()
    for kpt in wfs.kpt_u:
        # Combined index s * nibzkpts + q into the first axis of f_sn.
        k = wfs.kd.nibzkpts * kpt.s + kpt.q
        # Only the occupied-orbital count is needed here.
        n_occ, _ = get_n_occ(kpt)
        # Guard: np.min raises on an empty slice when n_occ == 0.
        if n_occ != 0 and np.min(f_sn[k][:n_occ]) == 0:
            warnings.warn('MOM has detected variational collapse '
                          'after getting canonical orbitals. Check '
                          'that the orbitals are consistent with the '
                          'initial guess.')