Coverage for gpaw/pipekmezey/pipek_mezey_wannier.py: 93%

162 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-14 00:18 +0000

r"""
Objective function class for
generalized Pipek-Mezey orbital localization.

Given a spin channel index, the objective function is

    P(W) = \sum_{A} \sum_{i} \left| Q^{A}_{ii}(W) \right|^{p}          (Eq. 1)

where p is a penalty degree: p > 1 or p < 1, but not p = 1
(note that p < 1 corresponds to minimization), and

    Q^{A}_{jj}(W) = \sum_{rs} W^{*}_{rj} \, Q^{A}_{rs} \, W_{sj}       (Eq. 2)

The indices r and s run over occupied states only.

Q^{A}_{rs} can be defined with two methods:

Hirshfeld scheme: 'H'

    Q^{A}_{rs} = \int \Phi^{*}_{r}(\mathbf{r}) \, w_{A}(\mathbf{r}) \,
                 \Phi_{s}(\mathbf{r}) \, d\mathbf{r}                   (Eq. 4)

with w_{A}(\mathbf{r}) a weight function with center on atom A.
w_{A}(\mathbf{r}) is constructed from simple and general gaussians.

and Wigner-Seitz scheme: 'W'

    Q^{A}_{rs} = \int \Phi^{*}_{r}(\mathbf{r}) \, O_{A}(\mathbf{r}) \,
                 \Phi_{s}(\mathbf{r}) \, d\mathbf{r}                   (Eq. 5)

with O_{A}(\mathbf{r}) = 1 if |\mathbf{r} - \mathbf{R}_{A}| >
|\mathbf{r} - \mathbf{R}_{B}|, 0 otherwise.

All integrals are performed over the coarse gd.

"""

50import numpy as np 

51from scipy.linalg import inv, sqrtm 

52from math import pi 

53from ase.transport.tools import dagger 

54from gpaw.pipekmezey.weightfunction import WeightFunc, WignerSeitz 

55from gpaw.pipekmezey.wannier_basic import md_min, get_atoms_object_from_wfs 

56from ase.dft.wannier import calculate_weights 

57from ase.dft.kpoints import get_monkhorst_pack_size_and_offset 

58from ase.parallel import world 

59 

60 

def random_orthogonal(rng, s, dtype=float):
    """Return a random ``s`` x ``s`` matrix W with W W* = I = W* W.

    Parameters
    ----------
    rng : numpy random generator
        Source of the random entries; only ``rng.random`` is called.
    s : int
        Dimension of the square matrix.
    dtype : dtype
        When equal to ``complex`` a second random draw is used as the
        imaginary part, giving a unitary matrix. Otherwise the matrix is
        built from a single real draw.

    Returns
    -------
    ndarray
        The seed matrix A orthonormalized as A (A* A)^(-1/2).
    """
    seed_matrix = rng.random((s, s))
    if dtype == complex:
        # Second draw supplies the imaginary part.
        seed_matrix = seed_matrix + 1.j * rng.random((s, s))
    gram = seed_matrix.T.conj().dot(seed_matrix)
    # Loewdin-type orthonormalization: A (A* A)^(-1/2).
    inverse_root = inv(sqrtm(gram))
    return seed_matrix.dot(inverse_root)

68 

69 

class PipekMezey:
    """Generalized Pipek-Mezey Wannier functions.

    Reference: J. Chem. Theory Comput. 2017, 13, 2, 460–474

    Parameters
    ----------
    wfs : GPAW wfs object
        Used when ``calc`` is not given.
    calc : GPAW calculator object
        When given, ``calc.wfs`` and ``calc.atoms`` are used.
    method : string
        Charge partitioning scheme: 'W' Wigner-Seitz or 'H' Hirshfeld.
    penalty : int
        positive (int) value for maximization (localization)
        negative (int) value for minimization (delocalized)
        Only ``abs(penalty)`` is stored. The objective function and the
        gradients implemented in this class use a fixed exponent of 2.
    spin : int
        spin channel index
    mu : float
        variance for Hirshfeld density (passed to ``WeightFunc``)
    dtype : dtype
        real or complex rotation matrix; defaults to ``wfs.dtype``
    seed : int
        seed for random initial guess for unitary matrix

    Attributes
    ----------
    W_k : ndarray, shape (Nk, nocc, nocc)
        Rotation matrix for every k-point.
    Qadk_nm : ndarray, shape (Na, Nd, Nk, nocc, nocc)
        Atom-resolved overlap matrices of the unrotated states.
    Qadk_nn : ndarray, same shape as ``Qadk_nm``
        The overlap matrices rotated by the current ``W_k``.
    Qad_nn : ndarray, shape (Na, Nd, nocc, nocc)
        ``Qadk_nn`` averaged over k-points.
    P, P_n : float, list
        Latest objective value and the history of computed values.
    """

    def __init__(self, wfs=None, calc=None,
                 method='W', penalty=2.0, spin=0,
                 mu=None, dtype=None, seed=None):
        from ase.dft.wannier import get_kklst, get_invkklst

        # Explicit test: the outcome must not depend on the truthiness of
        # the wfs object.
        assert wfs is not None or calc is not None, \
            'either wfs or calc must be given'

        self.wfs = calc.wfs if calc is not None else wfs  # CMOs
        self.mode = getattr(self.wfs, 'mode', None)

        self.method = method  # Charge partitioning scheme
        self.penalty = abs(penalty)  # penalty exponent
        self.mu = mu  # WF variance (if 'H')

        self.gd = self.wfs.gd
        # Allow complex rotations
        self.dtype = dtype if dtype is not None else self.wfs.dtype

        self.setups = self.wfs.setups

        # Make atoms object from setups
        if calc is not None:
            self.atoms = calc.atoms
        else:
            self.atoms = get_atoms_object_from_wfs(self.wfs)

        self.Na = len(self.atoms)
        self.ns = self.wfs.nspins
        self.spin = spin
        self.niter = 0
        self.rng = np.random.default_rng(seed)

        # Determine nocc: integer occupations only.  The occupations are
        # read from the first k-point of the requested spin channel.
        u = (len(self.wfs.kd.ibzk_kc) * spin) % len(self.wfs.kpt_u)
        self.nocc = self._count_occupied(self.wfs.kpt_u[u].f_n)

        # Hold on to
        self.P = 0
        self.P_n = []
        self.Qa_nn = np.zeros((self.Na, self.nocc, self.nocc))

        # kpts and dirs
        bzk_kc = self.wfs.kd.bzk_kc
        assert len(self.wfs.kd.ibzk_kc) == len(bzk_kc), \
            'number of irreducible k-points must equal the number of ' \
            'BZ k-points'

        self.kgd = get_monkhorst_pack_size_and_offset(bzk_kc)[0]
        # Bloch phase sign conv. GPAW.  Negate into a NEW array: an in-place
        # ``*= -1`` would flip the sign of the k-points stored on wfs.kd.
        self.k_kc = -np.asarray(bzk_kc)

        # pbc-lattice
        self.Nk = len(self.k_kc)
        self.W_k = np.zeros((self.Nk, self.nocc, self.nocc),
                            dtype=self.dtype)

        # Expand cell to capture Bloch states
        largecell = (self.atoms.cell.T * self.kgd).T
        self.wd, self.Gd = calculate_weights(largecell)
        self.Nd = len(self.wd)

        # Get neighbor kpt list and inverse kpt list
        self.lst_dk, k0_dk = get_kklst(self.k_kc, self.Gd)
        self.invlst_dk = get_invkklst(self.lst_dk)

        if calc is not None and self.wfs.kpt_u[0].psit_nG is None:
            self.wfs.initialize_wave_functions_from_restart_file()

        # initialize wfs array if lcao
        if self.mode == 'lcao' and self.wfs.kpt_u[0].psit_nG is None:
            self.wfs.initialize_wave_functions_from_lcao()

        # Using WFa and k-d lists make overlap matrix
        Qadk_nm = self._overlap_matrices(k0_dk)

        # Sum over domains
        self.gd.comm.sum(Qadk_nm)
        self.Qadk_nm = Qadk_nm.copy()
        self.Qadk_nn = np.zeros_like(self.Qadk_nm)

        # Initial W_k: Start from random WW*=I
        for k in range(self.Nk):
            self.W_k[k] = random_orthogonal(self.rng, self.nocc,
                                            dtype=self.dtype)
        if world is not None:
            world.broadcast(self.W_k, 0)

        # Given all matrices, update
        self.update()
        self.initialized = True

    @staticmethod
    def _count_occupied(f_n, tol=1e-10):
        """Count the leading entries of ``f_n`` that are larger than ``tol``.

        Counting stops at the first entry not above ``tol`` or at the end
        of the array, so a fully occupied band set does not index past the
        last band.
        """
        nbands = len(f_n)
        nocc = 0
        while nocc < nbands and f_n[nocc] > tol:
            nocc += 1
        return nocc

    def _overlap_matrices(self, k0_dk):
        """Build the atom-resolved overlaps for all directions and k-points.

        Parameters
        ----------
        k0_dk : array
            Second return value of ``get_kklst``; subtracted from the
            k-point difference to form ``Gc``.

        Returns
        -------
        ndarray, shape (Na, Nd, Nk, nocc, nocc), complex
            Local (not yet domain-summed) overlap matrices.
        """
        Qadk_nm = np.zeros((self.Na,
                            self.Nd,
                            self.Nk,
                            self.nocc, self.nocc), complex)

        nibz = len(self.wfs.kd.ibzk_kc)
        nu = len(self.wfs.kpt_u)

        # Loop invariants: grid point indices and scaled atomic positions.
        index_Gc = np.indices(self.gd.n_c).T + self.gd.beg_c
        spos_ac = self.atoms.get_scaled_positions()

        for d in range(len(self.Gd)):
            for k in range(self.Nk):
                k1 = self.lst_dk[d, k]
                k0 = k0_dk[d, k]
                Gc = self.k_kc[k1] - self.k_kc[k] - k0
                # Det. kpt/spin
                u = (k + nibz * self.spin) % nu
                u1 = (k1 + nibz * self.spin) % nu

                if self.mode == 'pw':
                    cmo = self.gd.zeros(self.nocc, dtype=self.wfs.dtype)
                    cmo1 = self.gd.zeros(self.nocc, dtype=self.wfs.dtype)
                    for i in range(self.nocc):
                        cmo[i] = self.wfs._get_wave_function_array(u, i)
                        cmo1[i] = self.wfs._get_wave_function_array(u1, i)
                else:
                    cmo = self.wfs.kpt_u[u].psit_nG[:self.nocc]
                    cmo1 = self.wfs.kpt_u[u1].psit_nG[:self.nocc]
                cmo_c = np.asarray(cmo, dtype=complex)

                # Inner product
                e_G = np.exp(-2j * pi *
                             np.dot(index_Gc, Gc / self.gd.N_c).T)
                # WFs per atom
                for atom in self.atoms:
                    WF = self.get_weight_function_atom(atom.index)
                    pw = (e_G * WF * cmo1)
                    Qadk_nm[atom.index, d, k] += \
                        self.gd.integrate(cmo_c,
                                          pw,
                                          global_integral=False)

                # PAW corrections
                P_ani1 = self.wfs.kpt_u[u1].P_ani
                for A, P_ni in self.wfs.kpt_u[u].P_ani.items():
                    dS_ii = self.setups[A].dO_ii
                    P_n = P_ni[:self.nocc]
                    P_n1 = P_ani1[A][:self.nocc]
                    # Phase factor is an approx. PRB 72, 125119 (2005)
                    e = np.exp(-2j * pi * np.dot(Gc, spos_ac[A]))
                    Qadk_nm[A, d, k] += \
                        e * P_n.conj().dot(dS_ii.dot(P_n1.T))

        return Qadk_nm

    def step(self, dX):
        """Rotate every ``W_k`` by the unitary generated from ``dX``.

        The first ``Nk * nocc**2`` entries of ``dX`` are read as one
        (nocc, nocc) matrix A per k-point. With H = -i conj(A), each
        ``W_k`` is multiplied from the right by exp(iH), evaluated through
        the eigendecomposition of H. Real rotation matrices keep only the
        real part. The rotated overlap matrices are refreshed afterwards.
        """
        No = self.nocc
        Nk = self.Nk

        A_kww = dX[:Nk * No ** 2].reshape(Nk, No, No)
        for U, A in zip(self.W_k, A_kww):
            H = -1.j * A.conj()
            epsilon, Z = np.linalg.eigh(H)
            dU = np.dot(Z * np.exp(1.j * epsilon), Z.conj().T)
            if U.dtype == float:
                U[:] = np.dot(U, dU).real
            else:
                U[:] = np.dot(U, dU)
        self.update()

    def get_weight_function_atom(self, index):
        """Construct the weight function of atom ``index``.

        Uses ``WeightFunc`` for method 'H' and ``WignerSeitz`` for
        method 'W'.

        Raises
        ------
        ValueError
            If ``self.method`` is neither 'H' nor 'W'.
        """
        if self.method == 'H':
            return WeightFunc(self.gd,
                              self.atoms,
                              [index],
                              mu=self.mu
                              ).construct_weight_function()
        if self.method == 'W':
            return WignerSeitz(self.gd,
                               self.atoms,
                               index
                               ).construct_weight_function()
        raise ValueError('check method')

    def localize(self, step=0.25, tolerance=1e-8, verbose=False):
        """Run ``md_min`` on this object with the given settings."""
        md_min(self, step, tolerance, verbose)

    def update(self):
        """Refresh ``Qadk_nn`` from ``W_k`` and its k-point average."""
        self.update_matrices()
        # Update PCM
        self.Qad_nn = self.Qadk_nn.sum(axis=2) / self.Nk

    def update_matrices(self):
        """Rotate the overlap matrices with the current ``W_k``.

        For every direction d and k-point k, with k1 = lst_dk[d, k]:
        Qadk_nn[a, d, k] = W_k[k]^dagger Qadk_nm[a, d, k] W_k[k1].
        """
        for d in range(self.Nd):
            for k in range(self.Nk):
                k1 = self.lst_dk[d, k]
                Wk_dagger = self.W_k[k].T.conj()
                # matmul broadcasts over the leading atom axis.
                self.Qadk_nn[:, d, k] = np.matmul(
                    Wk_dagger,
                    np.matmul(self.Qadk_nm[:, d, k], self.W_k[k1]))

    def get_function_value(self):
        """Evaluate the objective, store it in ``P`` and append to ``P_n``.

        The element-wise absolute value of ``Qadk_nn`` is averaged over
        k-points. Its diagonal entries are then squared, weighted by
        ``wd`` per direction, summed over atoms, directions and states,
        and divided by ``sum(wd)``.

        NOTE(review): the exponent is fixed to 2 and ``self.penalty`` is
        not used here -- confirm this is intended.
        """
        # Over k
        Qad_nn = np.sum(abs(self.Qadk_nn), axis=2) / self.Nk
        # Only the diagonal enters the objective: shape (Na, Nd, nocc)
        Qad_n = np.diagonal(Qad_nn, axis1=2, axis2=3)
        wd = np.asarray(self.wd)
        # Over d, a and diag
        self.P = np.sum(Qad_n ** 2 * wd[np.newaxis, :, np.newaxis])

        self.P /= np.sum(wd)
        self.P_n.append(self.P)

        return self.P

    def get_gradients(self):
        """Return the flattened gradient with respect to every ``W_k``.

        Returns
        -------
        ndarray, complex, length Nk * nocc**2
            One anti-Hermitian (nocc, nocc) block per k-point, built from
            ``Qad_nn`` and ``Qadk_nn`` as set by ``update``.
        """
        No = self.nocc
        dW = []

        for k in range(self.Nk):
            Wtemp = np.zeros((No, No), complex)

            for a in range(self.Na):
                for d, wd in enumerate(self.wd):
                    diagQ = self.Qad_nn[a, d].diagonal()
                    k2 = self.invlst_dk[d, k]
                    Qk_nn = self.Qadk_nn[a, d]
                    # diagQ scales columns in the first term and rows in
                    # the second term.
                    temp = diagQ[np.newaxis, :] * Qk_nn[k].conj() - \
                        diagQ[:, np.newaxis] * Qk_nn[k2].conj()
                    Wtemp += wd * (temp - temp.conj().T)

            dW.append(Wtemp.ravel())

        return np.concatenate(dW)