Coverage for gpaw/new/pwfd/rmmdiis.py: 92%

73 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-08 00:17 +0000

1from __future__ import annotations 

2 

3import warnings 

4from pprint import pformat 

5 

6import numpy as np 

7from gpaw.gpu import as_np 

8from gpaw.new import zips as zip 

9from gpaw.new.pwfd.eigensolver import PWFDEigensolver, calculate_residuals 

10from gpaw.new.pwfd.wave_functions import PWFDWaveFunctions 

11 

12 

class RMMDIIS(PWFDEigensolver):
    def __init__(self,
                 nbands: int,
                 wf_grid,
                 band_comm,
                 hamiltonian,
                 converge_bands='occupied',
                 niter: int = 1,
                 trial_step: float | None = None,
                 scalapack_parameters=None,
                 max_buffer_mem: int = 200 * 1024 ** 2):
        """RMM-DIIS eigensolver.

        Solution steps are:

        * Subspace diagonalization
        * Calculation of residuals
        * Improvement of wave functions: psi' = psi + lambda PR + lambda PR'
        * Orthonormalization

        Parameters
        ==========
        nbands, wf_grid, band_comm, scalapack_parameters:
            Accepted for signature compatibility; not used by this class.
        hamiltonian:
            Passed on to :class:`PWFDEigensolver`.
        converge_bands:
            Passed on to :class:`PWFDEigensolver` (default ``'occupied'``).
        niter:
            Only a single RMM-DIIS step per call is supported; any value
            other than 1 triggers a warning and is otherwise ignored.
        trial_step:
            Step length for final step.  Use None for using the previously
            optimized step lengths.
        max_buffer_mem:
            Buffer memory budget passed on to :class:`PWFDEigensolver`
            (presumably in bytes, given the ``200 * 1024 ** 2`` default).
        """
        if niter != 1:
            warnings.warn(f'Ignoring niter={niter} in RMMDIIS')
        super().__init__(hamiltonian, converge_bands,
                         max_buffer_mem=max_buffer_mem)
        self.trial_step = trial_step

    def __str__(self):
        return pformat(dict(name='RMMDIIS',
                            converge_bands=self.converge_bands))

    def _initialize(self, ibzwfs):
        """Allocate the scratch storage used by :meth:`iterate1`.

        One work array (holds the residuals) and two data buffers (used as
        the two per-block work buffers in :func:`block_step`).
        """
        super()._initialize(ibzwfs)
        self._allocate_work_arrays(ibzwfs, shape=(1,))
        self._allocate_buffer_arrays(ibzwfs, shape=(2,))

    def iterate1(self,
                 wfs: PWFDWaveFunctions,
                 Ht, dH, dS_aii, weight_n):
        """Do one RMM-DIIS step on the wave functions of *wfs*.

        See here:

            https://gpaw.readthedocs.io/documentation/rmm-diis.html

        Returns ``weight_n @ residual_nX.norm2()`` evaluated on the residuals
        *before* the update, or ``np.inf`` when *weight_n* is None.
        """
        psit_nX = wfs.psit_nX
        mynbands = psit_nX.mydims[0]

        residual_nX = psit_nX.new(data=self.work_arrays[0, :mynbands])

        P_ani = wfs.P_ani
        work1_ani = P_ani.new()
        work2_ani = P_ani.new()

        wfs.subspace_diagonalize(Ht, dH,
                                 psit2_nX=residual_nX,
                                 data_buffer=self.data_buffers[0])
        calculate_residuals(wfs.psit_nX, residual_nX, wfs.pt_aiX,
                            wfs.P_ani, wfs.myeig_n,
                            dH, dS_aii, work1_ani, work2_ani)

        work1_nX = psit_nX.create_work_buffer(self.data_buffers[0])
        work2_nX = psit_nX.create_work_buffer(self.data_buffers[1])
        blocksize = work1_nX.data.shape[0]
        P1_ani = P_ani.layout.empty(blocksize)
        P2_ani = P_ani.layout.empty(blocksize)
        if weight_n is None:
            error = np.inf
        else:
            error = weight_n @ as_np(residual_nX.norm2())

        # The number of block steps is derived from totals summed over
        # psit_nX.comm, so every rank loops the same number of times
        # (presumably because block_step involves collective operations).
        comm = psit_nX.comm
        blocksize_world = comm.sum_scalar(blocksize)
        totalbands = comm.sum_scalar(mynbands)
        nblocks = len(range(0, totalbands, blocksize_world))
        for i1 in range(nblocks):
            # Clamp BOTH ends: a rank with fewer local bands than the
            # others must get an empty block, never a negative length
            # (a negative stop would silently slice "all but the last k").
            n1 = min(i1 * blocksize, mynbands)
            n2 = min(n1 + blocksize, mynbands)
            nblock = n2 - n1
            if nblock < blocksize:
                # Block lengths never increase after the first short one,
                # so shrinking the projection buffers cumulatively is safe.
                P1_ani = P1_ani[:, :nblock]
                P2_ani = P2_ani[:, :nblock]
            block_step(
                psit_nX[n1:n2],
                residual_nX[n1:n2],
                wfs.pt_aiX, wfs.myeig_n[n1:n2], Ht, dH, dS_aii,
                self.trial_step,
                work1_nX[:nblock],
                work2_nX[:nblock],
                P1_ani, P2_ani,
                self.preconditioner)

        # The wave functions changed in place: drop the cached projections
        # and re-orthonormalize (residual_nX is handed over, presumably as
        # scratch space).
        wfs._P_ani = None
        wfs.orthonormalized = False
        wfs.orthonormalize(residual_nX)
        return error

115 

116 

def block_step(psit_nX,
               R_nX,
               pt_aiX,
               eig_n,
               Ht,
               dH,
               dS_aii,
               trial_step,
               work1_nX,
               work2_nX,
               P1_ani,
               P2_ani,
               preconditioner) -> None:
    """Apply one RMM-DIIS update to a block of wave functions, in place.

    With P the preconditioner and R the residuals of the block:

    1. PR = P R, and dR = residual of PR (via ``Ht`` and
       ``calculate_residuals``),
    2. lambda_n = -<dR_n|R_n> / dR_n.norm2(), one value per band,
    3. psi += lambda PR and R += lambda dR,
    4. psi += step * P R', with step = lambda_n, or *trial_step* if given.

    See here:

        https://gpaw.readthedocs.io/documentation/rmm-diis.html

    *psit_nX* and *R_nX* are modified in place; *work1_nX*, *work2_nX*,
    *P1_ani* and *P2_ani* are overwritten as scratch.
    """
    xp = psit_nX.xp

    # work1 holds the preconditioned residual, work2 its residual vector.
    pr_nX, dr_nX = work1_nX, work2_nX

    ekin_n = preconditioner(psit_nX, R_nX, out=pr_nX)

    Ht(pr_nX, out=dr_nX)
    proj_ani = pt_aiX.integrate(pr_nX)
    calculate_residuals(pr_nX, dr_nX, pt_aiX, proj_ani, eig_n,
                        dH, dS_aii, P1_ani, P2_ani)

    numer_n = xp.asarray([-dr_X.integrate(r_X)
                          for dr_X, r_X in zip(dr_nX, R_nX)])
    denom_n = dr_nX.norm2()
    # One step length per band, broadcast over the remaining data axes.
    bshape = (numer_n.shape[0],) + (1,) * (psit_nX.data.ndim - 1)
    lam_n = (numer_n / denom_n).reshape(bshape)

    # First step: move psi along P R and update R consistently.
    pr_nX.data *= lam_n
    psit_nX.data += pr_nX.data
    dr_nX.data *= lam_n
    R_nX.data += dr_nX.data

    # Second step along the preconditioned, updated residual.
    preconditioner(psit_nX, R_nX, out=pr_nX, ekin_n=ekin_n)
    pr_nX.data *= lam_n if trial_step is None else trial_step
    psit_nX.data += pr_nX.data