Coverage for gpaw/pipekmezey/wannier_basic.py: 23%

173 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-08 00:17 +0000

1""" Maximally localized Wannier Functions 

2 

3 Find the set of maximally localized Wannier functions 

4 using the spread functional of Marzari and Vanderbilt 

5 (PRB 56, 1997 page 12847). 

6 

7 this code is as in ASE but modified to use it with gpaw's wfs. 

8""" 

9 

10from time import time 

11from math import pi 

12import numpy as np 

13from ase.dft.kpoints import get_monkhorst_pack_size_and_offset 

14from ase.dft.wannier import calculate_weights, gram_schmidt 

15from ase.transport.tools import dagger 

16from ase.parallel import parprint 

17 

# Short module-wide alias for ase.transport.tools.dagger (used as dag(X) below).
dag = dagger

19 

20 

def random_orthogonal_matrix(dim, rng, real=False):
    """Generate a random orthogonal (real) or unitary (complex) matrix.

    A random symmetric matrix ``H`` is drawn from ``rng``.

    * ``real=True``: the columns of ``H`` are orthonormalized in place with
      ``gram_schmidt`` and ``H`` is returned.
    * ``real=False``: returns ``exp(iH) = V exp(i*eps) V^H``, where
      ``eps, V`` are the eigenpairs of ``H``.

    Parameters
    ----------
    dim : int
        Matrix dimension.
    rng : numpy.random.Generator
        Random number generator (one ``dim x dim`` draw is consumed).
    real : bool
        Select the real (orthogonal) or complex (unitary) variant.

    Returns
    -------
    numpy.ndarray
        ``(dim, dim)`` matrix.
    """
    H = rng.random((dim, dim))
    # Symmetrize: H = (H^T + H) / 2 (H is real, so H^T is its conjugate
    # transpose).
    H = 0.5 * (H.T + H)

    if real:
        gram_schmidt(H)
        return H

    # H is real symmetric, so use the Hermitian solver: unlike
    # np.linalg.eig it guarantees real eigenvalues and an orthonormal
    # eigenvector matrix, which makes the result exactly unitary (up to
    # rounding) even for nearly degenerate spectra.
    val, vec = np.linalg.eigh(H)
    return np.dot(vec * np.exp(1.j * val), vec.conj().T)

34 

35 

def md_min(func, step=.25, tolerance=1e-6, verbose=False, max_iter=None,
           **kwargs):
    """Maximize ``func`` with a damped, velocity-based ascent.

    ``func`` must provide ``get_gradients()``, ``step(V, **kwargs)`` and
    ``get_function_value()``.  Each iteration:

    * velocity components that point against the current gradient
      (``Re(dF * conj(V)) <= 0``) are reset to zero;
    * ``step * dF`` is added to the velocity and passed to ``func.step``;
    * ``step`` is halved whenever the function value decreased.

    The loop stops once the relative change of the function value is at
    most ``tolerance``.  The number of iterations is stored in
    ``func.niter`` (0 if no iteration was needed).

    Parameters
    ----------
    func : object
        Object to optimize (see above).
    step : float
        Initial step size.
    tolerance : float
        Threshold on ``|f - f_old| / |f|``.
    verbose : bool
        Print progress with ``parprint``.
    max_iter : int or None
        Maximum number of iterations; ``None`` (default) means unlimited.
    **kwargs
        Forwarded to ``func.step``.

    Raises
    ------
    RuntimeError
        If another iteration is required after ``max_iter`` iterations.
    """
    if verbose:
        parprint('Localize with step =', step,
                 'and tolerance =', tolerance)
    t = -time()
    fvalueold = 0.
    fvalue = fvalueold + 10
    count = 0
    func.niter = count
    V = np.zeros(func.get_gradients().shape, dtype=complex)

    # Relative-change test written without a division so that a function
    # value of exactly 0.0 cannot raise ZeroDivisionError.
    while abs(fvalue - fvalueold) > tolerance * abs(fvalue):
        if max_iter is not None and count >= max_iter:
            raise RuntimeError(
                'md_min did not converge in %d iterations' % max_iter)
        fvalueold = fvalue
        dF = func.get_gradients()
        # Drop velocity components that no longer point along the gradient.
        V *= (dF * V.conj()).real > 0
        V += step * dF
        func.step(V, **kwargs)
        fvalue = func.get_function_value()

        if fvalue < fvalueold:
            step *= 0.5
        count += 1
        func.niter = count

        if verbose:
            parprint('MDmin: iter=%s, step=%s, value=%s'
                     % (count, step, fvalue))
    t += time()
    if verbose:
        # max(count, 1) avoids dividing by zero when no iteration ran.
        parprint('%d iterations in %0.2f seconds(%0.2f ms/iter),'
                 ' endstep = %s'
                 % (count, t, t * 1000. / max(count, 1), step))

67 

68 

def get_atoms_object_from_wfs(wfs):
    """Reconstruct an ``ase.Atoms`` object from a wave-function object.

    Used when no calculator is available.  Symbols come from
    ``wfs.setups[a].symbol`` and positions from the scaled positions
    ``wfs.spos_ac`` and the cell ``wfs.gd.cell_cv``.  Both positions and
    cell are multiplied by ``ase.units.Bohr`` (i.e. the ``wfs`` quantities
    are taken to be in bohr).  ``pbc`` is left at the ``Atoms`` default.

    Parameters
    ----------
    wfs : GPAW wave-function object
        Must provide ``spos_ac``, ``gd.cell_cv`` and ``setups``.

    Returns
    -------
    ase.Atoms
    """
    from ase.units import Bohr
    from ase import Atoms

    cell_cv = wfs.gd.cell_cv
    # Scaled -> Cartesian: r_v = sum_c s_c a_cv.  Valid for any cell shape;
    # for an orthorhombic cell this equals s_c * diag(cell)_c.
    positions = np.dot(wfs.spos_ac, cell_cv) * Bohr
    symbols = [setup.symbol for setup in wfs.setups]

    return Atoms(symbols=symbols, positions=positions, cell=cell_cv * Bohr)

86 

87 

class WannierLocalization:
    """Maximally localized Wannier functions for the occupied states only.

    Intended for ODD calculations.  The number of Wannier functions is
    derived from the (integer) occupation numbers of the first k-point of
    the requested spin channel.  Overlap matrices between Bloch states at
    neighbouring k-points (``Z_dknn``) are computed once in the
    constructor; the rotation matrices ``U_kww`` are then optimized by
    :meth:`localize`.
    """

    def __init__(self, wfs, calc=None, spin=0, seed=None, verbose=False):
        """Compute the Bloch overlap matrices and a random initial rotation.

        Parameters
        ----------
        wfs : GPAW wave-function object
            Source of grid, k-points, occupations, wave functions and PAW
            projections.
        calc : calculator, optional
            If given, ``calc.atoms`` is used; otherwise an Atoms object is
            rebuilt from ``wfs`` with :func:`get_atoms_object_from_wfs`.
        spin : int
            Spin channel to localize.
        seed : int or None
            Seed for ``np.random.default_rng`` (initial rotation).
        verbose : bool
            Print optimizer progress in :meth:`localize`.
        """
        from ase.dft.wannier import get_kklst, get_invkklst

        # Bloch phase sign convention
        sign = -1
        self.wfs = wfs
        self.gd = self.wfs.gd
        self.ns = self.wfs.nspins
        self.dtype = wfs.dtype
        self.mode = getattr(self.wfs, 'mode', None)

        if calc is not None:
            self.atoms = calc.atoms
        else:
            self.atoms = get_atoms_object_from_wfs(self.wfs)

        nibzk = len(self.wfs.kd.ibzk_kc)
        nkpt_u = len(self.wfs.kpt_u)

        # Determine nocc from the first k-point of this spin: integer
        # occupations only.  The divisor is 2 for nspins == 1 and 1 for
        # nspins == 2.
        _, u = divmod(0 + nibzk * spin, nkpt_u)
        f_n = self.wfs.kpt_u[u].f_n
        self.nwannier = int(np.rint(f_n.sum()) /
                            (3 - self.ns))  # No fractional occ

        self.spin = spin
        self.verbose = verbose
        self.rng = np.random.default_rng(seed)

        bzk_kc = self.wfs.kd.bzk_kc
        assert len(self.wfs.kd.ibzk_kc) == len(bzk_kc)
        self.kptgrid = get_monkhorst_pack_size_and_offset(bzk_kc)[0]
        # Build a NEW array.  An in-place ``*=`` on wfs.kd.bzk_kc would
        # flip the sign of the calculator's own k-points (and flip it
        # back again on the next instantiation).
        self.kpt_kc = sign * bzk_kc

        self.Nk = len(self.kpt_kc)
        self.unitcell_cc = self.atoms.get_cell()
        self.largeunitcell_cc = (self.unitcell_cc.T * self.kptgrid).T
        self.weight_d, self.Gdir_dc = \
            calculate_weights(self.largeunitcell_cc)
        self.Ndir = len(self.weight_d)  # Number of directions

        # Get neighbor kpt list and inverse kpt list
        self.kklst_dk, k0_dkc = get_kklst(self.kpt_kc, self.Gdir_dc)
        self.invkklst_dk = get_invkklst(self.kklst_dk)

        Nw = self.nwannier
        Z_dknn = np.zeros((self.Ndir, self.Nk, Nw, Nw), dtype=complex)
        self.Z_dkww = np.empty((self.Ndir, self.Nk, Nw, Nw), dtype=complex)

        if self.mode == 'lcao' and self.wfs.kpt_u[0].psit_nG is None:
            self.wfs.initialize_wave_functions_from_lcao()

        # Loop invariants: integer grid-point indices of this domain and
        # the scaled atomic positions.
        r_Gc = np.indices(self.gd.n_c).T + self.gd.beg_c
        spos_ac = self.atoms.get_scaled_positions()

        for d in range(len(self.Gdir_dc)):
            for k in range(self.Nk):
                k1 = self.kklst_dk[d, k]
                Gc = self.kpt_kc[k1] - self.kpt_kc[k] - k0_dkc[d, k]
                # Det. kpt/spin index into wfs.kpt_u
                _, u = divmod(k + nibzk * spin, nkpt_u)
                _, u1 = divmod(k1 + nibzk * spin, nkpt_u)

                if self.mode == 'pw':
                    cmo = self.gd.zeros(Nw, dtype=self.wfs.dtype)
                    cmo1 = self.gd.zeros(Nw, dtype=self.wfs.dtype)
                    for i in range(Nw):
                        cmo[i] = self.wfs._get_wave_function_array(u, i)
                        cmo1[i] = self.wfs._get_wave_function_array(u1, i)
                else:
                    cmo = self.wfs.kpt_u[u].psit_nG[:Nw]
                    cmo1 = self.wfs.kpt_u[u1].psit_nG[:Nw]

                # Smooth part: <psi_k| exp(-2 pi i G.r) |psi_k1> on the grid
                e_G = np.exp(-2.j * pi * np.dot(r_Gc, Gc / self.gd.N_c).T)
                pw = (e_G * cmo.conj()).reshape((Nw, -1))
                Z_dknn[d, k] += \
                    np.inner(pw, cmo1.reshape((Nw, -1))) * self.gd.dv

                # PAW corrections from the projections held by this rank
                P_ani1 = self.wfs.kpt_u[u1].P_ani
                for a, P_ni in self.wfs.kpt_u[u].P_ani.items():
                    dS_ii = self.wfs.setups[a].dO_ii
                    P_n = P_ni[:Nw]
                    P_n1 = P_ani1[a][:Nw]
                    e = np.exp(-2.j * pi * np.dot(Gc, spos_ac[a]))
                    Z_dknn[d, k] += e * P_n.conj().dot(dS_ii.dot(P_n1.T))

        self.gd.comm.sum(Z_dknn)
        self.Z_dknn = Z_dknn

        self.initialize()

    def initialize(self):
        """Re-initialize the rotation matrices ``U_kww`` and update ``Z``.

        Takes no arguments.  A single matrix from
        :func:`random_orthogonal_matrix` (real when ``self.dtype`` is
        float, complex otherwise) is drawn from ``self.rng`` and used for
        every k-point, then :meth:`update` is called.
        """
        Nw = self.nwannier

        self.U_kww = np.zeros((self.Nk, Nw, Nw), self.dtype)
        real = self.dtype == float
        # Broadcast the same initial rotation to all k-points.
        self.U_kww[:] = random_orthogonal_matrix(Nw, self.rng, real=real)

        self.update()

    def update(self):
        """Recompute ``Z_dkww`` and ``Z_dww`` from the current ``U_kww``."""
        # Zk = U^d[k] Zbloch U[k1]
        for d in range(self.Ndir):
            for k in range(self.Nk):
                k1 = self.kklst_dk[d, k]
                self.Z_dkww[d, k] = np.dot(dag(self.U_kww[k]), np.dot(
                    self.Z_dknn[d, k], self.U_kww[k1]))

        # k-point average
        self.Z_dww = self.Z_dkww.sum(axis=1) / self.Nk

    def get_centers(self, scaled=False):
        """Calculate the Wannier centers.

        ::

            pos = L / 2pi * phase(diag(Z))

        Uses the first three directions of ``Z_dww``.  Scaled coordinates
        are wrapped into [0, 1); with ``scaled=False`` they are multiplied
        by ``largeunitcell_cc``.  Returns an array of shape ``(Nw, 3)``.
        """
        phase_dw = np.angle(self.Z_dww[:3].diagonal(0, 1, 2))
        coord_wc = phase_dw.T / (2.0 * pi) % 1
        if not scaled:
            coord_wc = np.dot(coord_wc, self.largeunitcell_cc)
        return coord_wc

    def localize(self, step=0.25, tolerance=1e-08,
                 updaterot=True):
        """Optimize the rotation to give maximal localization.

        Runs :func:`md_min` on this object with the given ``step`` and
        ``tolerance``; ``updaterot`` is forwarded to :meth:`step`.
        """
        md_min(self, step, tolerance, verbose=self.verbose,
               updaterot=updaterot)

    def get_function_value(self):
        """Calculate the value of the spread functional.

        ::

            Tr[|ZI|^2]=sum(I)sum(n) w_i|Z_(i)_nn|^2,

        where w_i are weights."""
        a_d = np.sum(np.abs(self.Z_dww.diagonal(0, 1, 2)) ** 2,
                     axis=1)
        return np.dot(a_d, self.weight_d).real

    def get_gradients(self):
        """Return the gradient used by :func:`md_min`.

        Follows ASE's Wannier implementation.  For every k-point an
        anti-Hermitian ``Nw x Nw`` matrix ``sum_d w_d (T - T^H)`` is built;
        the matrices are flattened and concatenated into one 1-D array of
        length ``Nk * Nw**2``.  Directions with ``|w_d| < 1e-6`` are skipped.
        """
        Nw = self.nwannier
        dU = []
        for k in range(self.Nk):
            Utemp_ww = np.zeros((Nw, Nw), complex)

            for d, weight in enumerate(self.weight_d):
                if abs(weight) < 1.0e-6:
                    continue

                diagZ_w = self.Z_dww[d].diagonal()
                # Zii_ww[i, j] = Z_ii for every column j
                Zii_ww = np.repeat(diagZ_w, Nw).reshape(Nw, Nw)
                k2 = self.invkklst_dk[d, k]
                Z_kww = self.Z_dkww[d]

                temp = Zii_ww.T * Z_kww[k].conj() - \
                    Zii_ww * Z_kww[k2].conj()
                Utemp_ww += weight * (temp - dag(temp))
            dU.append(Utemp_ww.ravel())

        return np.concatenate(dU)

    def step(self, dX, updaterot=True):
        """Rotate ``U_kww`` by the generators stored in ``dX`` and update.

        The first ``Nk * Nw**2`` entries of ``dX`` are reshaped to one
        ``Nw x Nw`` matrix ``A`` per k-point and each ``U`` is replaced by
        ``U exp(A^*)``.  Nothing is rotated when ``updaterot`` is False,
        but :meth:`update` is always called.
        """
        Nw = self.nwannier
        Nk = self.Nk
        if updaterot:
            A_kww = dX[:Nk * Nw ** 2].reshape(Nk, Nw, Nw)
            for U, A in zip(self.U_kww, A_kww):
                # eigh assumes H is Hermitian, i.e. A anti-Hermitian, which
                # holds for the (T - T^H) gradients from get_gradients.
                H = -1.j * A.conj()
                epsilon, Z = np.linalg.eigh(H)
                # Z contains the eigenvectors as COLUMNS.
                # With H = -i A^*, iH = A^*, so
                # dU = exp(iH) = Z exp(i*epsilon) Z^d = exp(A^*).
                dU = np.dot(Z * np.exp(1.j * epsilon), dag(Z))
                if U.dtype == float:
                    U[:] = np.dot(U, dU).real
                else:
                    U[:] = np.dot(U, dU)

        self.update()