Coverage for gpaw/pipekmezey/wannier_basic.py: 23%
173 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-08 00:17 +0000
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-08 00:17 +0000
1""" Maximally localized Wannier Functions
3 Find the set of maximally localized Wannier functions
4 using the spread functional of Marzari and Vanderbilt
5 (PRB 56, 1997 page 12847).
This code follows the ASE implementation, modified to work with GPAW's wave-function objects (wfs).
8"""
10from time import time
11from math import pi
12import numpy as np
13from ase.dft.kpoints import get_monkhorst_pack_size_and_offset
14from ase.dft.wannier import calculate_weights, gram_schmidt
15from ase.transport.tools import dagger
16from ase.parallel import parprint
# Short module-level alias for ASE's ``dagger`` helper.  It is used as the
# "^d" (dagger, presumably conjugate-transpose) operation on matrices,
# e.g. ``Z^d`` in the rotation update of WannierLocalization.step.
dag = dagger
def random_orthogonal_matrix(dim, rng, real=False):
    """Generate a random orthogonal (real) or unitary (complex) matrix.

    A random real symmetric matrix ``H`` is drawn from ``rng`` and
    symmetrized as ``(H + H^T) / 2``.

    Parameters
    ----------
    dim : int
        Size of the square matrix.
    rng : numpy.random.Generator
        Random number generator; ``rng.random`` is used to draw ``H``.
    real : bool
        If True, the columns of ``H`` are orthonormalized in place with
        Gram-Schmidt and the resulting real orthogonal matrix is returned.
        Otherwise the complex unitary matrix ``exp(iH)`` is returned.

    Returns
    -------
    numpy.ndarray
        ``(dim, dim)`` orthogonal (``real=True``) or unitary matrix.
    """
    H = rng.random((dim, dim))
    # H is real, so H.T is its Hermitian conjugate; this is the same
    # symmetrization as (H^d + H) / 2.
    H = 0.5 * (H + H.T)

    if real:
        gram_schmidt(H)
        return H

    # H is real symmetric: eigh (unlike the general eig) guarantees real
    # eigenvalues and orthonormal eigenvectors, so
    # vec diag(exp(i*val)) vec^d is exactly unitary.
    val, vec = np.linalg.eigh(H)
    return np.dot(vec * np.exp(1.j * val), vec.conj().T)
def md_min(func, step=.25, tolerance=1e-6, verbose=False, **kwargs):
    """Maximize ``func`` with a damped, momentum-style gradient ascent.

    ``func`` must provide ``get_gradients()``, ``step(V, **kwargs)`` and
    ``get_function_value()``.  Each iteration:

    * zeroes the velocity ``V`` element-wise wherever it no longer points
      along the current gradient (``Re(dF * conj(V)) <= 0``),
    * adds ``step * dF`` to ``V`` and applies ``func.step(V, **kwargs)``,
    * halves ``step`` whenever the function value decreased.

    Iteration stops once the relative change of the function value is
    ``<= tolerance``.  The number of iterations performed is stored in
    ``func.niter`` (0 if the loop body never runs).

    Parameters
    ----------
    func : object
        Object exposing the three methods listed above.
    step : float
        Initial step length.
    tolerance : float
        Convergence threshold on the relative change of the function value.
    verbose : bool
        Print progress with ``parprint``.
    **kwargs
        Forwarded unchanged to ``func.step``.
    """
    if verbose:
        parprint('Localize with step =', step,
                 'and tolerance =', tolerance)
    t = -time()
    fvalueold = 0.
    fvalue = fvalueold + 10
    count = 0
    # Set up front so ``niter`` exists even if no iteration is performed
    # (e.g. tolerance >= 1 makes the initial criterion false).
    func.niter = count
    V = np.zeros(func.get_gradients().shape, dtype=complex)

    while abs((fvalue - fvalueold) / fvalue) > tolerance:
        fvalueold = fvalue
        dF = func.get_gradients()
        # Keep only velocity components still aligned with the gradient.
        V *= (dF * V.conj()).real > 0
        V += step * dF
        func.step(V, **kwargs)
        fvalue = func.get_function_value()

        if fvalue < fvalueold:
            # The value went down: we overshot, so shrink the step.
            step *= 0.5
        count += 1
        func.niter = count

        if verbose:
            parprint('MDmin: iter=%s, step=%s, value=%s'
                     % (count, step, fvalue))
    t += time()
    if verbose:
        # Avoid ZeroDivisionError when no iteration was performed.
        ms_per_iter = t * 1000. / count if count else 0.
        parprint('%d iterations in %0.2f seconds(%0.2f ms/iter),'
                 ' endstep = %s'
                 % (count, t, ms_per_iter, step))
def get_atoms_object_from_wfs(wfs):
    """Build an ASE ``Atoms`` object from a GPAW wave-function object.

    Chemical symbols are taken from ``wfs.setups``.  The cell is
    ``wfs.gd.cell_cv`` and the Cartesian positions are the scaled
    positions ``wfs.spos_ac`` mapped through that cell.  Both are
    multiplied by ``ase.units.Bohr`` (atomic units -> ASE length units).
    Periodic boundary conditions are left at the ASE default.
    """
    from ase import Atoms
    from ase.units import Bohr

    symbols = [setup.symbol for setup in wfs.setups]
    cell_cv = wfs.gd.cell_cv * Bohr
    # Full scaled -> Cartesian transform r_v = sum_c s_c * cell_cv[c, v].
    # Scaling by the cell diagonal only is wrong for non-orthogonal cells.
    positions = np.dot(wfs.spos_ac, cell_cv)
    return Atoms(symbols, positions=positions, cell=cell_cv)
class WannierLocalization:
    """Maximally localized Wannier functions
    for n_occ only - for ODD calculations.

    The constructor builds the overlap matrices ``Z_dknn`` between each
    k-point and its neighbour in every direction ``d``, using the lowest
    ``nwannier`` states of ``wfs`` (plus PAW corrections).  ``localize``
    then optimizes the rotation matrices ``U_kww`` by maximizing the
    functional returned by ``get_function_value`` with ``md_min``.
    """

    def __init__(self, wfs, calc=None, spin=0, seed=None, verbose=False):
        from ase.dft.wannier import get_kklst, get_invkklst

        # Bloch phase sign convention
        sign = -1
        self.wfs = wfs
        self.gd = self.wfs.gd
        self.ns = self.wfs.nspins
        self.dtype = wfs.dtype
        self.mode = getattr(self.wfs, 'mode', None)

        if calc is not None:
            self.atoms = calc.atoms
        else:
            self.atoms = get_atoms_object_from_wfs(self.wfs)

        # Determine nocc: integer occupations only
        _, u = divmod(len(self.wfs.kd.ibzk_kc) * spin, len(self.wfs.kpt_u))
        f_n = self.wfs.kpt_u[u].f_n
        # (3 - nspins) is 2 for spin-paired, 1 for spin-polarized runs.
        self.nwannier = int(np.rint(f_n.sum()) /
                            (3 - self.ns))  # No fractional occ

        self.spin = spin
        self.verbose = verbose
        self.rng = np.random.default_rng(seed)

        bzk_kc = self.wfs.kd.bzk_kc
        assert len(self.wfs.kd.ibzk_kc) == len(bzk_kc)
        self.kptgrid = get_monkhorst_pack_size_and_offset(bzk_kc)[0]
        # Out-of-place multiply: bzk_kc belongs to wfs.kd and must not be
        # negated in place (that would flip the calculator's k-points on
        # every construction of this class).
        self.kpt_kc = sign * bzk_kc

        self.Nk = len(self.kpt_kc)
        self.unitcell_cc = self.atoms.get_cell()
        self.largeunitcell_cc = (self.unitcell_cc.T * self.kptgrid).T
        self.weight_d, self.Gdir_dc = \
            calculate_weights(self.largeunitcell_cc)
        self.Ndir = len(self.weight_d)  # Number of directions

        # Get neighbor kpt list and inverse kpt list
        self.kklst_dk, k0_dkc = get_kklst(self.kpt_kc, self.Gdir_dc)
        self.invkklst_dk = get_invkklst(self.kklst_dk)

        Nw = self.nwannier
        self.Z_dkww = np.empty((self.Ndir, self.Nk, Nw, Nw),
                               dtype=complex)

        if self.mode == 'lcao' and self.wfs.kpt_u[0].psit_nG is None:
            self.wfs.initialize_wave_functions_from_lcao()

        self.Z_dknn = self._calculate_overlaps(k0_dkc)
        self.initialize()

    def _calculate_overlaps(self, k0_dkc):
        """Return the PAW-corrected overlap matrices Z_dknn.

        For each direction d and k-point k (with neighbour k1), accumulates
        sum_G e^{-2 pi i G.r} conj(psi_nk) psi_mk1 dv plus the PAW
        correction, then sums the result over ``gd.comm``.
        """
        Nw = self.nwannier
        gd = self.gd
        nibzk = len(self.wfs.kd.ibzk_kc)
        nkpt_u = len(self.wfs.kpt_u)
        kpt_kc = self.kpt_kc
        # Loop invariants: scaled atomic positions and the (global) grid
        # point indices used in the phase factor.
        spos_ac = self.atoms.get_scaled_positions()
        r_Gc = np.indices(gd.n_c).T + gd.beg_c

        Z_dknn = np.zeros((self.Ndir, self.Nk, Nw, Nw), dtype=complex)

        for d in range(len(self.Gdir_dc)):
            for k in range(self.Nk):
                k1 = self.kklst_dk[d, k]
                Gc = kpt_kc[k1] - kpt_kc[k] - k0_dkc[d, k]
                # Det. kpt/spin
                u = (k + nibzk * self.spin) % nkpt_u
                u1 = (k1 + nibzk * self.spin) % nkpt_u

                if self.mode == 'pw':
                    cmo = gd.zeros(Nw, dtype=self.wfs.dtype)
                    cmo1 = gd.zeros(Nw, dtype=self.wfs.dtype)
                    for i in range(Nw):
                        cmo[i] = self.wfs._get_wave_function_array(u, i)
                        cmo1[i] = self.wfs._get_wave_function_array(u1, i)
                else:
                    cmo = self.wfs.kpt_u[u].psit_nG[:Nw]
                    cmo1 = self.wfs.kpt_u[u1].psit_nG[:Nw]

                e_G = np.exp(-2.j * pi * np.dot(r_Gc, Gc / gd.N_c).T)
                pw = (e_G * cmo.conj()).reshape((Nw, -1))
                Z_dknn[d, k] += \
                    np.inner(pw, cmo1.reshape((Nw, -1))) * gd.dv

                # PAW corrections
                P_ani1 = self.wfs.kpt_u[u1].P_ani
                for a, P_ni in self.wfs.kpt_u[u].P_ani.items():
                    dS_ii = self.wfs.setups[a].dO_ii
                    e = np.exp(-2.j * pi * np.dot(Gc, spos_ac[a]))
                    Z_dknn[d, k] += e * P_ni[:Nw].conj().dot(
                        dS_ii.dot(P_ani1[a][:Nw].T))

        gd.comm.sum(Z_dknn)
        return Z_dknn

    def initialize(self):
        """Reset the rotation matrices U_kww to a random start and update.

        One random orthogonal (real ``dtype``) or unitary (complex
        ``dtype``) matrix is drawn from ``self.rng`` and assigned to every
        k-point; ``update`` then recomputes the rotated Z matrices.
        """
        Nw = self.nwannier

        # Set U to random (orthogonal) matrix
        self.U_kww = np.zeros((self.Nk, Nw, Nw), self.dtype)
        real = bool(self.dtype == float)
        # NOTE(review): the same matrix is broadcast to all k-points; a
        # per-k loop was left commented out here. Confirm this is intended
        # when Nk > 1.
        self.U_kww[:] = random_orthogonal_matrix(Nw, self.rng, real=real)

        self.update()

    def update(self):
        """Recompute Z_dkww and its k-point average Z_dww from U_kww."""
        # Calculate the Zk matrix from the rotation matrix:
        # Zk = U^d[k] Zbloch U[k1]
        for d in range(self.Ndir):
            for k in range(self.Nk):
                k1 = self.kklst_dk[d, k]
                self.Z_dkww[d, k] = np.dot(dag(self.U_kww[k]), np.dot(
                    self.Z_dknn[d, k], self.U_kww[k1]))

        # Update the new Z matrix
        self.Z_dww = self.Z_dkww.sum(axis=1) / self.Nk

    def get_centers(self, scaled=False):
        """Calculate the Wannier centers

        ::

          pos = L / 2pi * phase(diag(Z))

        Uses the first three directions; returns scaled coordinates in
        [0, 1) if ``scaled`` is True, otherwise Cartesian coordinates in
        the supercell ``largeunitcell_cc``.
        """
        coord_wc = \
            np.angle(self.Z_dww[:3].diagonal(0, 1, 2)).T / \
            (2.0 * pi) % 1
        if not scaled:
            coord_wc = np.dot(coord_wc, self.largeunitcell_cc)
        return coord_wc

    def localize(self, step=0.25, tolerance=1e-08,
                 updaterot=True):
        """Optimize rotation to give maximal localization (via md_min)."""
        md_min(self, step, tolerance, verbose=self.verbose,
               updaterot=updaterot)

    def get_function_value(self):
        """Calculate the value of the spread functional.

        ::

          Tr[|ZI|^2]=sum(I)sum(n) w_i|Z_(i)_nn|^2,

        where w_i are weights."""
        a_d = np.sum(np.abs(self.Z_dww.diagonal(0, 1, 2)) ** 2,
                     axis=1)
        return np.dot(a_d, self.weight_d).real

    def get_gradients(self):
        """Return the gradient w.r.t. the rotations, flattened over k.

        Directions with |weight| < 1e-6 are skipped.  The result is the
        concatenation of one raveled (Nw, Nw) anti-Hermitian matrix per
        k-point.
        """
        Nw = self.nwannier
        dU = []
        for k in range(self.Nk):
            Utemp_ww = np.zeros((Nw, Nw), complex)

            for d, weight in enumerate(self.weight_d):
                if abs(weight) < 1.0e-6:
                    continue

                diagZ_w = self.Z_dww[d].diagonal()
                Zii_ww = np.repeat(diagZ_w, Nw).reshape(Nw, Nw)
                k2 = self.invkklst_dk[d, k]
                Z_kww = self.Z_dkww[d]

                temp = Zii_ww.T * Z_kww[k].conj() - \
                    Zii_ww * Z_kww[k2].conj()
                Utemp_ww += weight * (temp - dag(temp))
            dU.append(Utemp_ww.ravel())

        return np.concatenate(dU)

    def step(self, dX, updaterot=True):
        """Rotate U_k <- U_k exp(-A_k) with A_k taken from ``dX``.

        ``dX`` holds one raveled (Nw, Nw) matrix per k-point.  With
        ``updaterot=False`` only ``update`` is called.
        """
        Nw = self.nwannier
        Nk = self.Nk
        if updaterot:
            A_kww = dX[:Nk * Nw ** 2].reshape(Nk, Nw, Nw)
            for U, A in zip(self.U_kww, A_kww):
                H = -1.j * A.conj()
                epsilon, Z = np.linalg.eigh(H)
                # Z contains the eigenvectors as COLUMNS.
                # Since H = iA, dU = exp(-A) = exp(iH) = ZDZ^d
                dU = np.dot(Z * np.exp(1.j * epsilon), dag(Z))
                if U.dtype == float:
                    U[:] = np.dot(U, dU).real
                else:
                    U[:] = np.dot(U, dU)

        self.update()