Coverage for gpaw/pipekmezey/pipek_mezey_wannier.py: 93%
162 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-14 00:18 +0000
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-14 00:18 +0000
r"""
Objective function class for
Generalized Pipek-Mezey orbital localization.

Given a spin channel index, the objective function is

    P(W) = \sum_{A} \sum_{i} \left| Q^{A}_{ii}(W) \right|^{p}            (Eq. 1)

where p is a penalty degree with p > 1 or p < 1, but not p = 1
(note that p < 1 corresponds to minimization), and

    Q^{A}_{jj}(W) = \sum_{rs} W^{*}_{rj} \, Q^{A}_{rs} \, W_{sj}         (Eq. 2)

where r and s run over occupied states only.

The atomic matrix Q^{A}_{rs} can be defined with two methods.

Hirshfeld scheme: 'H'

    Q^{A}_{rs} = \int \phi^{*}_{r}(\mathbf{r}) \, w^{A}(\mathbf{r}) \,
                 \phi_{s}(\mathbf{r}) \, d\mathbf{r}                      (Eq. 4)

with w^{A}(\mathbf{r}) a weight function centered on atom A.
w^{A}(\mathbf{r}) is constructed from simple and general Gaussians.

Wigner-Seitz scheme: 'W'

    Q^{A}_{rs} = \int \phi^{*}_{r}(\mathbf{r}) \, O^{A}(\mathbf{r}) \,
                 \phi_{s}(\mathbf{r}) \, d\mathbf{r}                      (Eq. 5)

with O^{A}(\mathbf{r}) = 1 if
|\mathbf{r} - \mathbf{R}^{A}| < |\mathbf{r} - \mathbf{R}^{B}| for every
other atom B, and O^{A}(\mathbf{r}) = 0 otherwise.

All integrals are performed over the coarse grid descriptor (gd).
"""
50import numpy as np
51from scipy.linalg import inv, sqrtm
52from math import pi
53from ase.transport.tools import dagger
54from gpaw.pipekmezey.weightfunction import WeightFunc, WignerSeitz
55from gpaw.pipekmezey.wannier_basic import md_min, get_atoms_object_from_wfs
56from ase.dft.wannier import calculate_weights
57from ase.dft.kpoints import get_monkhorst_pack_size_and_offset
58from ase.parallel import world
def random_orthogonal(rng, s, dtype=float):
    """Return a random s x s orthogonal (real) or unitary (complex) matrix.

    A random matrix w is drawn with ``rng.random`` and made orthonormal
    by symmetric (Loewdin) orthogonalization, W = w (w^H w)^(-1/2), so
    that W W^H = I = W^H W.

    Parameters
    ----------
    rng : numpy.random.Generator
        Source of the random entries (``rng.random`` is called).
    s : int
        Matrix dimension.
    dtype : dtype-like
        Any complex dtype (``complex``, ``np.complex128``,
        ``np.dtype(complex)``, ...) adds a random imaginary part, so the
        result is unitary. Otherwise the result is real orthogonal.
    """
    w_r = rng.random((s, s))
    # np.dtype() normalizes builtins, numpy scalar types and dtype objects
    # alike; a plain ``dtype == complex`` is False for np.complex128.
    if np.issubdtype(np.dtype(dtype), np.complexfloating):
        w_r = w_r + 1.j * rng.random((s, s))
    return w_r.dot(inv(sqrtm(w_r.T.conj().dot(w_r))))
class PipekMezey:
    """General Pipek-Mezey Wannier functions.

    J. Chem. Theory Comput. 2017, 13, 2, 460-474

    Parameters
    ----------
    wfs : GPAW wfs object
        Wave functions to localize. Ignored when ``calc`` is given.
    calc : GPAW calculator object
        If given, ``calc.wfs`` and ``calc.atoms`` are used.
    method : str
        Charge partitioning scheme: 'W' Wigner-Seitz or 'H' Hirshfeld.
    penalty : float
        Penalty exponent. Only ``abs(penalty)`` is stored as
        ``self.penalty``, so its sign is discarded.
        NOTE(review): no method of this class reads ``self.penalty``;
        ``get_function_value`` squares the diagonal elements regardless.
    spin : int
        Spin channel index.
    mu : float
        Variance of the Hirshfeld weight function (used if method='H').
    dtype : dtype
        Real or complex rotation matrices; defaults to ``wfs.dtype``.
    seed : int
        Seed for the random initial guess of the unitary matrices.
    """

    def __init__(self, wfs=None, calc=None,
                 method='W', penalty=2.0, spin=0,
                 mu=None, dtype=None, seed=None):
        """Build the atomic overlap matrices and a random initial W_k.

        Raises
        ------
        ValueError
            If neither ``wfs`` nor ``calc`` is given.
        """
        from ase.dft.wannier import get_kklst, get_invkklst

        # Explicit None checks, independent of the wfs object's truthiness.
        if wfs is None and calc is None:
            raise ValueError('either wfs or calc must be given')

        self.wfs = calc.wfs if calc is not None else wfs  # CMOs
        self.mode = getattr(self.wfs, 'mode', None)

        self.method = method  # Charge partitioning scheme
        self.penalty = abs(penalty)  # penalty exponent (sign discarded)
        self.mu = mu  # WF variance (if 'H')

        self.gd = self.wfs.gd
        # Allow complex rotations
        self.dtype = dtype if dtype is not None else self.wfs.dtype
        self.setups = self.wfs.setups

        # Atoms object: from the calculator, otherwise rebuilt from wfs
        if calc is not None:
            self.atoms = calc.atoms
        else:
            self.atoms = get_atoms_object_from_wfs(self.wfs)

        self.Na = len(self.atoms)
        self.ns = self.wfs.nspins
        self.spin = spin
        self.niter = 0
        self.rng = np.random.default_rng(seed)

        # Determine nocc: integer occupations only. Count bands of the
        # first k-point of this spin until an occupation drops to ~0; the
        # length bound avoids an IndexError when every band is occupied.
        f_n = self.wfs.kpt_u[self._get_u(0)].f_n
        self.nocc = 0
        while self.nocc < len(f_n) and f_n[self.nocc] > 1e-10:
            self.nocc += 1

        # Hold on to
        self.P = 0
        self.P_n = []
        self.Qa_nn = np.zeros((self.Na, self.nocc, self.nocc))

        # kpts and dirs
        bzk_kc = self.wfs.kd.bzk_kc
        # kpt_u is indexed by BZ k-point index below, so the IBZ must
        # coincide in size with the full BZ (no symmetry reduction).
        assert len(self.wfs.kd.ibzk_kc) == len(bzk_kc)
        self.kgd = get_monkhorst_pack_size_and_offset(bzk_kc)[0]
        # Bloch phase sign conv. GPAW. Negate a copy: an in-place
        # ``*= -1`` on the shared array would also flip wfs.kd.bzk_kc.
        self.k_kc = -bzk_kc

        # pbc-lattice
        self.Nk = len(self.k_kc)
        self.W_k = np.zeros((self.Nk, self.nocc, self.nocc),
                            dtype=self.dtype)

        # Expand cell to capture Bloch states
        largecell = (self.atoms.cell.T * self.kgd).T
        self.wd, self.Gd = calculate_weights(largecell)
        self.Nd = len(self.wd)

        # Get neighbor kpt list and inverse kpt list
        self.lst_dk, k0_dk = get_kklst(self.k_kc, self.Gd)
        self.invlst_dk = get_invkklst(self.lst_dk)

        if calc is not None and self.wfs.kpt_u[0].psit_nG is None:
            self.wfs.initialize_wave_functions_from_restart_file()

        # initialize wfs array if lcao
        if self.mode == 'lcao' and self.wfs.kpt_u[0].psit_nG is None:
            self.wfs.initialize_wave_functions_from_lcao()

        # Using WFa and k-d lists make overlap matrix
        self.Qadk_nm = self._calculate_overlaps(k0_dk)
        self.Qadk_nn = np.zeros_like(self.Qadk_nm)

        # Initial W_k: Start from random WW*=I
        for k in range(self.Nk):
            self.W_k[k] = random_orthogonal(self.rng, self.nocc,
                                            dtype=self.dtype)
        if world is not None:
            world.broadcast(self.W_k, 0)

        # Given all matrices, update
        self.update()
        self.initialized = True

    def _get_u(self, k):
        """Return the kpt_u index of BZ k-point ``k`` in spin ``self.spin``."""
        nibzk = len(self.wfs.kd.ibzk_kc)
        return (k + nibzk * self.spin) % len(self.wfs.kpt_u)

    def _get_occupied_states(self, u):
        """Return the first ``nocc`` wave functions of ``kpt_u[u]``.

        In 'pw' mode each band is fetched as a grid array with
        ``wfs._get_wave_function_array``; otherwise ``psit_nG`` is sliced.
        """
        if self.mode == 'pw':
            cmo = self.gd.zeros(self.nocc, dtype=self.wfs.dtype)
            for i in range(self.nocc):
                cmo[i] = self.wfs._get_wave_function_array(u, i)
            return cmo
        return self.wfs.kpt_u[u].psit_nG[:self.nocc]

    def _calculate_overlaps(self, k0_dk):
        """Return the atomic overlaps between neighbouring k-points.

        The returned complex array has shape (Na, Nd, Nk, nocc, nocc).
        Element [a, d, k] is
        ``gd.integrate(psi_k, exp(-2 pi i Gc.s) * w_a * psi_k1)`` plus a
        PAW correction, with ``k1 = lst_dk[d, k]``. It is summed over the
        domains of ``gd.comm`` before being returned.
        """
        Qadk_nm = np.zeros((self.Na, self.Nd, self.Nk,
                            self.nocc, self.nocc), complex)
        # Independent of d and k: grid-point indices of this domain and
        # scaled atomic positions.
        r_Gc = np.indices(self.gd.n_c).T + self.gd.beg_c
        spos_ac = self.atoms.get_scaled_positions()

        for d in range(len(self.Gd)):
            for k in range(self.Nk):
                k1 = self.lst_dk[d, k]
                Gc = self.k_kc[k1] - self.k_kc[k] - k0_dk[d, k]
                # Det. kpt/spin
                u = self._get_u(k)
                u1 = self._get_u(k1)
                cmo = np.asarray(self._get_occupied_states(u),
                                 dtype=complex)
                cmo1 = self._get_occupied_states(u1)
                # Inner product phase factor for this k -> k1 pair
                e_G = np.exp(-2j * pi * np.dot(r_Gc, Gc / self.gd.N_c).T)
                # WFs per atom
                for a in range(self.Na):
                    WF = self.get_weight_function_atom(a)
                    Qadk_nm[a, d, k] += self.gd.integrate(
                        cmo, e_G * WF * cmo1, global_integral=False)
                # PAW corrections
                P_ani1 = self.wfs.kpt_u[u1].P_ani
                for A, P_ni in self.wfs.kpt_u[u].P_ani.items():
                    dS_ii = self.setups[A].dO_ii
                    P_n = P_ni[:self.nocc]
                    P_n1 = P_ani1[A][:self.nocc]
                    # Phase factor is an approx. PRB 72, 125119 (2005)
                    e = np.exp(-2j * pi * np.dot(Gc, spos_ac[A]))
                    Qadk_nm[A, d, k] += e * P_n.conj().dot(dS_ii.dot(P_n1.T))

        # Sum over domains
        self.gd.comm.sum(Qadk_nm)
        return Qadk_nm

    def step(self, dX):
        """Apply a rotation step to every W_k, then call ``update()``.

        The first Nk * nocc**2 entries of ``dX`` are reshaped into one
        nocc x nocc matrix A per k-point. Each W_k is multiplied on the
        right by dU = Z exp(i eps) Z^H, where (eps, Z) are the eigenpairs
        of H = -1j * conj(A). Real-valued W_k keep only the real part.
        """
        No = self.nocc
        Nk = self.Nk

        A_kww = dX[:Nk * No ** 2].reshape(Nk, No, No)
        for U, A in zip(self.W_k, A_kww):
            H = -1.j * A.conj()
            epsilon, Z = np.linalg.eigh(H)
            dU = np.dot(Z * np.exp(1.j * epsilon), dagger(Z))
            if U.dtype == float:
                U[:] = np.dot(U, dU).real
            else:
                U[:] = np.dot(U, dU)
        self.update()

    def get_weight_function_atom(self, index):
        """Return the weight function of atom ``index`` on ``self.gd``.

        Uses WeightFunc with variance ``self.mu`` for method 'H' and
        WignerSeitz for method 'W'.

        Raises
        ------
        ValueError
            For any other ``self.method``.
        """
        if self.method == 'H':
            return WeightFunc(self.gd,
                              self.atoms,
                              [index],
                              mu=self.mu
                              ).construct_weight_function()
        if self.method == 'W':
            return WignerSeitz(self.gd,
                               self.atoms,
                               index
                               ).construct_weight_function()
        raise ValueError('check method')

    def localize(self, step=0.25, tolerance=1e-8, verbose=False):
        """Optimize W_k by calling ``md_min(self, step, tolerance, verbose)``."""
        md_min(self, step, tolerance, verbose)

    def update(self):
        """Rotate Qadk_nm with the current W_k and refresh ``Qad_nn``.

        ``Qad_nn`` is the k-average of the (complex) rotated matrices.
        """
        self.update_matrices()
        # Update PCM
        self.Qad_nn = self.Qadk_nn.sum(axis=2) / self.Nk

    def update_matrices(self):
        """Set Qadk_nn[:, d, k] = W_k^H Qadk_nm[:, d, k] W_k1.

        Here ``k1 = lst_dk[d, k]``.
        """
        # Using new W_k rotate states; matmul broadcasts over atoms.
        for d in range(self.Nd):
            for k in range(self.Nk):
                k1 = self.lst_dk[d, k]
                self.Qadk_nn[:, d, k] = self.W_k[k].T.conj() @ (
                    self.Qadk_nm[:, d, k] @ self.W_k[k1])

    def get_function_value(self):
        """Return the objective value P and append it to ``self.P_n``.

        For each atom a and direction d, the magnitudes |Qadk_nn| are
        averaged over k. The squared diagonal of that average is weighted
        by wd[d], summed over a, d and the diagonal index, then divided by
        sum(wd). The exponent is fixed at 2 (``self.penalty`` is not read).

        NOTE(review): this averages magnitudes over k, whereas ``Qad_nn``
        (used by ``get_gradients``) averages complex values; the two agree
        only for a single k-point.
        """
        # Over k
        Qad_nn = np.sum(abs(self.Qadk_nn), axis=2) / self.Nk
        # Diagonal elements, shape (Na, Nd, nocc)
        Qad_i = np.diagonal(Qad_nn, axis1=2, axis2=3)
        wd = np.asarray(self.wd)
        # Over a, d and diag
        self.P = np.sum(Qad_i ** 2 * wd[:, None]) / np.sum(wd)
        self.P_n.append(self.P)
        return self.P

    def get_gradients(self):
        """Return the gradient as one flattened nocc x nocc block per k-point.

        Each block is the sum over atoms and directions of
        wd[d] * (temp - temp^H). The matrix temp combines the diagonal of
        ``Qad_nn`` with the conjugated Qadk_nn at k and at the inverse
        neighbour ``invlst_dk[d, k]``.
        """
        No = self.nocc
        dW = []

        for k in range(self.Nk):
            Wtemp = np.zeros((No, No), complex)

            for a in range(self.Na):
                for d, wd in enumerate(self.wd):
                    diagQ = self.Qad_nn[a, d].diagonal()
                    # Qa_ii[i, j] = diagQ[i]; its transpose holds diagQ[j]
                    Qa_ii = np.repeat(diagQ, No).reshape(No, No)
                    k2 = self.invlst_dk[d, k]
                    Qk_nn = self.Qadk_nn[a, d]
                    temp = Qa_ii.T * Qk_nn[k].conj() - \
                        Qa_ii * Qk_nn[k2].conj()
                    Wtemp += wd * (temp - dagger(temp))

            dW.append(Wtemp.ravel())

        return np.concatenate(dW)