Coverage for gpaw/new/lcao/builder.py: 96%

99 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-14 00:18 +0000

1import numpy as np 

2from gpaw.core.matrix import Matrix, MatrixWithNoData 

3from gpaw.lcao.tci import TCIExpansions 

4from gpaw.new import zips 

5from gpaw.new.fd.builder import FDDFTComponentsBuilder 

6from gpaw.new.lcao.ibzwfs import LCAOIBZWaveFunctions 

7from gpaw.new.lcao.forces import TCIDerivatives 

8from gpaw.new.lcao.hamiltonian import LCAOHamiltonian 

9from gpaw.new.lcao.hybrids import HybridXCFunctional 

10from gpaw.new.lcao.wave_functions import LCAOWaveFunctions 

11from gpaw.utilities.timing import NullTimer 

12 

13 

class LCAODFTComponentsBuilder(FDDFTComponentsBuilder):
    """Components builder for LCAO mode, layered on FDDFTComponentsBuilder.

    The basis set returned by :meth:`create_basis_set` is cached on
    ``self.basis`` because both the Hamiltonian operator and the
    eigensolver are constructed from it.
    """

    def __init__(self, atoms, params, *, comm, log):
        """Initialise the parent builder and store LCAO-specific settings."""
        super().__init__(atoms, params, comm=comm, log=log)
        self.distribution = params.mode.distribution
        # Populated by create_basis_set().
        self.basis = None

    def create_wf_description(self):
        """Not provided by this builder; always raises NotImplementedError."""
        raise NotImplementedError

    def create_xc_functional(self):
        """Return the XC functional object.

        A HybridXCFunctional is returned when the XC name is one of
        'HSE06', 'PBE0' or 'EXX'; otherwise the parent class decides.
        """
        # NOTE(review): the name is looked up as xc['name'] here, while
        # create_eigensolver() uses xc.name -- confirm both are intended.
        if self.params.xc['name'] not in ['HSE06', 'PBE0', 'EXX']:
            return super().create_xc_functional()
        return HybridXCFunctional(self.params.xc)

    def create_basis_set(self):
        """Create the basis set via the FD builder and cache it."""
        basis = FDDFTComponentsBuilder.create_basis_set(self)
        self.basis = basis
        return self.basis

    def create_hamiltonian_operator(self):
        """Return an LCAOHamiltonian built from the cached basis set."""
        return LCAOHamiltonian(self.basis)

    def create_eigensolver(self, hamiltonian):
        """Build the LCAO eigensolver.

        If the configured eigensolver is a DefaultEigensolver it is first
        replaced by a named one: 'hybrid' for the XC names 'HSE06', 'PBE0'
        and 'EXX', and 'lcao' for everything else.  The *hamiltonian*
        argument is not used by this implementation.
        """
        from gpaw.dft import DefaultEigensolver

        eigensolver = self.params.eigensolver
        if isinstance(eigensolver, DefaultEigensolver):
            is_hybrid = self.params.xc.name in ['HSE06', 'PBE0', 'EXX']
            name = 'hybrid' if is_hybrid else 'lcao'
            eigensolver = eigensolver.from_param(
                {'name': name, **eigensolver.params})
        return eigensolver.build_lcao(self.basis,
                                      self.relpos_ac,
                                      self.grid.cell_cv,
                                      self.ibz.symmetries)

    def read_ibz_wave_functions(self, reader):
        """Create IBZ wave functions from *reader* and return them.

        When the reader holds 'coefficients', a proxy to them is passed on
        and its ``scale`` attribute is set to ``reader.bohr**1.5`` for
        reader versions 0 through 3, and to 1 for all other versions.
        """
        scale = reader.bohr**1.5 if 0 <= reader.version < 4 else 1

        basis = self.create_basis_set()
        potential = self.create_potential_calculator()

        coefficients = None
        if 'coefficients' in reader.wave_functions:
            coefficients = reader.wave_functions.proxy('coefficients')
            coefficients.scale = scale

        ibzwfs = self.create_ibz_wave_functions(
            basis, potential, coefficients=coefficients)

        # Fill in eigenvalues, occupation numbers and the like.
        self.read_wavefunction_values(reader, ibzwfs)
        return ibzwfs

    def create_ibz_wave_functions(self,
                                  basis,
                                  potential,
                                  *,
                                  coefficients=None):
        """Return the IBZ wave functions produced by create_lcao_ibzwfs().

        The *potential* argument is accepted but not used here, and the
        second item returned by create_lcao_ibzwfs() is discarded.
        """
        ibzwfs, _unused = create_lcao_ibzwfs(
            basis=basis,
            ibz=self.ibz,
            communicators=self.communicators,
            setups=self.setups,
            relpos_ac=self.relpos_ac,
            grid=self.grid,
            dtype=self.dtype,
            nbands=self.nbands,
            ncomponents=self.ncomponents,
            atomdist=self.atomdist,
            nelectrons=self.nelectrons,
            coefficients=coefficients)
        return ibzwfs

86 

87 

def create_lcao_ibzwfs(basis,
                       ibz, communicators, setups,
                       relpos_ac, grid, dtype,
                       nbands, ncomponents, atomdist, nelectrons,
                       coefficients=None):
    """Assemble the LCAO IBZ wave functions.

    Communicators are looked up in *communicators* under the keys
    'D', 'k', 'b', 'd' and 'w'.  The matrices and TCI objects come from
    tci_helper().  If *coefficients* is given, each wave-function object
    gets its coefficient matrix filled from
    ``coefficients.proxy(spin, k)``; otherwise a MatrixWithNoData is used
    in its place.  *nelectrons* is accepted but not used here.

    Returns the tuple ``(ibzwfs, tciexpansions)``.
    """
    kpt_band_comm = communicators['D']
    kpt_comm = communicators['k']
    band_comm = communicators['b']
    domain_comm = communicators['d']

    S_qMM, T_qMM, P_qaMi, tciexpansions, tci_derivatives = tci_helper(
        basis=basis,
        ibz=ibz,
        domain_comm=domain_comm,
        band_comm=band_comm,
        kpt_comm=kpt_comm,
        relpos_ac=relpos_ac,
        atomdist=atomdist,
        grid=grid,
        dtype=dtype,
        setups=setups)

    nao = setups.nao

    def create_wfs(spin, q, k, kpt_c, weight):
        # One row per band; the column count is doubled for ncomponents == 4.
        ncols = 2 * nao if ncomponents == 4 else nao
        matrix_kwargs = dict(dtype=dtype,
                             dist=(band_comm, band_comm.size, 1))
        if coefficients is None:
            C_nM = MatrixWithNoData(nbands, ncols, **matrix_kwargs)
        else:
            C_nM = Matrix(nbands, ncols, **matrix_kwargs)
            # Copy only the rows held by this rank of the distribution.
            my_rows = slice(*C_nM.dist.my_row_range())
            C_nM.data[:] = coefficients.proxy(spin, k)[my_rows]
        return LCAOWaveFunctions(
            setups=setups,
            tci_derivatives=tci_derivatives,
            basis=basis,
            C_nM=C_nM,
            S_MM=S_qMM[q],
            T_MM=T_qMM[q],
            P_aMi=P_qaMi[q],
            kpt_c=kpt_c,
            relpos_ac=relpos_ac,
            atomdist=atomdist,
            domain_comm=domain_comm,
            spin=spin,
            q=q,
            k=k,
            weight=weight,
            ncomponents=ncomponents)

    ibzwfs = LCAOIBZWaveFunctions.create(
        ibz=ibz,
        ncomponents=ncomponents,
        create_wfs_func=create_wfs,
        kpt_comm=kpt_comm,
        kpt_band_comm=kpt_band_comm,
        comm=communicators['w'])
    # The TCI machinery needs cell and pbc from somewhere, hence the grid.
    ibzwfs.grid = grid
    return ibzwfs, tciexpansions

144 

145 

def tci_helper(basis,
               ibz,
               domain_comm, band_comm, kpt_comm,
               relpos_ac, atomdist,
               grid,
               dtype,
               setups):
    """Compute the per-k-point matrices and TCI helper objects.

    Only the k-points whose entry in ``ibz.ranks(kpt_comm)`` equals
    ``kpt_comm.rank`` are handled.  For each of them the arrays returned
    by ``manytci.O_qMM_T_qMM`` are wrapped as Matrix objects distributed
    over *band_comm* (presumably overlap S and kinetic T -- TODO confirm).
    The S arrays also receive the ``P.conj() @ dO_ii @ P.T`` contribution
    of every atom in ``basis.my_atom_indices`` and are then summed over
    *domain_comm*.

    Returns the tuple
    ``(S_qMM, T_qMM, P_qaMi, tciexpansions, tci_derivatives)``.
    """
    on_this_rank_k = ibz.ranks(kpt_comm) == kpt_comm.rank
    kpt_qc = ibz.kpt_kc[on_this_rank_k]

    tciexpansions = TCIExpansions.new_from_setups(setups)
    manytci = tciexpansions.get_manytci_calculator(
        setups, grid._gd, relpos_ac,
        kpt_qc, dtype, NullTimer())

    my_atom_indices = basis.my_atom_indices
    M1, M2 = basis.Mstart, basis.Mstop
    S0_qMM, T0_qMM = manytci.O_qMM_T_qMM(domain_comm, M1, M2, True)
    if dtype == complex:
        # Flip the sign of the imaginary parts in place.
        for O0_qMM in (S0_qMM, T0_qMM):
            np.negative(O0_qMM.imag, out=O0_qMM.imag)

    P_aqMi = manytci.P_aqMi(my_atom_indices)
    nq = len(S0_qMM)
    P_qaMi = [{a: P_aqMi[a][q] for a in my_atom_indices}
              for q in range(nq)]

    # Add each atom's P^* dO P^T term to the locally held rows M1:M2.
    my_rows = slice(M1, M2)
    for a, P_qMi in P_aqMi.items():
        dO_ii = setups[a].dO_ii
        for P_Mi, S0_MM in zips(P_qMi, S0_qMM):
            S0_MM += P_Mi[my_rows].conj() @ dO_ii @ P_Mi.T
    domain_comm.sum(S0_qMM)

    # Disabled code kept from the original source:
    # self.atomic_correction= self.atomic_correction_cls.new_from_wfs(self)
    # self.atomic_correction.add_overlap_correction(newS_qMM)

    nao = setups.nao

    def as_matrices(O0_qMM):
        # Wrap every per-q array as a band_comm-distributed nao x nao Matrix.
        return [Matrix(nao, nao, data=O0_MM,
                       dist=(band_comm, band_comm.size, 1))
                for O0_MM in O0_qMM]

    S_qMM = as_matrices(S0_qMM)
    T_qMM = as_matrices(T0_qMM)

    # Same call order as before: every S matrix first, then every T matrix.
    for O_MM in [*S_qMM, *T_qMM]:
        O_MM.tril2full()

    tci_derivatives = TCIDerivatives(manytci, atomdist, nao)

    return S_qMM, T_qMM, P_qaMi, tciexpansions, tci_derivatives