Coverage for gpaw/lcaotddft/ksdecomposition.py: 66%
444 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-08 00:17 +0000
1import numpy as np
2from typing import NamedTuple
4from ase.units import Hartree, Bohr
6from ase.io.ulm import Reader
7from gpaw.io import Writer
8from gpaw.external import ConstantElectricField
9from gpaw.kpoint import KPoint
10from gpaw.lcaotddft.hamiltonian import KickHamiltonian
11from gpaw.lcaotddft.utilities import collect_MM
12from gpaw.lcaotddft.utilities import distribute_nM
13from gpaw.lcaotddft.utilities import read_uMM
14from gpaw.lcaotddft.utilities import write_uMM
15from gpaw.lcaotddft.utilities import read_uX, write_uX
16from gpaw.utilities.scalapack import \
17 pblas_simple_gemm, pblas_simple_hemm, scalapack_tri2full
18from gpaw.utilities.tools import tri2full
def gauss_ij(energy_i, energy_j, sigma):
    """Return the normalized Gaussian of every pairwise energy difference.

    Element ``[i, j]`` is ``exp(-(e_i - e_j)**2 / (2 sigma**2))`` divided by
    ``sigma * sqrt(2 pi)``, so each row and column is a Gaussian of unit
    area centered on the corresponding energy.

    Parameters
    ----------
    energy_i, energy_j
        1-D NumPy arrays of energies (same units as ``sigma``).
    sigma
        Gaussian width.
    """
    diff_ij = energy_i[:, None] - energy_j[None, :]
    prefactor = 1.0 / (sigma * np.sqrt(2 * np.pi))
    return prefactor * np.exp(-0.5 * diff_ij**2 / sigma**2)
def get_bfs_maps(calc):
    """Map every LCAO basis-function index M to its atom and angular momentum.

    Each radial spline with angular momentum ``l`` on atom ``a`` contributes
    ``2 * l + 1`` consecutive basis functions, in the order the atoms and
    splines appear in ``calc.wfs.basis_functions.sphere_a``.

    Returns
    -------
    a_M : numpy.ndarray
        Atom index of each basis function.
    l_M : numpy.ndarray
        Angular momentum number of each basis function.
    """
    atom_of_M = []
    angmom_of_M = []
    for atom_index, sphere in enumerate(calc.wfs.basis_functions.sphere_a):
        for spline in sphere.spline_j:
            ang = spline.get_angular_momentum_number()
            n_m = 2 * ang + 1  # number of m-components for this l
            atom_of_M.extend([atom_index] * n_m)
            angmom_of_M.extend([ang] * n_m)
    return np.array(atom_of_M), np.array(angmom_of_M)
class DummyKsl(NamedTuple):
    """Minimal stand-in for a Kohn-Sham layouts object.

    Used when a decomposition is loaded from file without a calculator.
    Only the ``using_blacs`` flag is provided.
    """

    # Whether ScaLAPACK/BLACS distributed matrices are in use.
    using_blacs: bool = False
class KohnShamDecomposition:
    """Decomposition of an LCAO density matrix into Kohn-Sham (i, a) pairs.

    The object is set up in one of two ways:

    * from a calculator ``paw`` and then :meth:`initialize`, which builds
      the pair basis from the ground state;
    * from a ULM file ``filename`` written by :meth:`write`. Most
      attributes are then read lazily on first access via
      :meth:`__getattr__`.

    Only a single spin and a single k-point are supported.
    """
    version = 1
    ulmtag = 'KSD'
    # Attributes written by write() and read back lazily via __getattr__
    readwrite_attrs = ['fermilevel', 'only_ia', 'w_p', 'f_p', 'ia_p',
                       'P_p', 'dm_vp', 'a_M', 'l_M']

    def __init__(self, paw=None, filename=None):
        """Create the decomposition from a calculator and/or a file.

        Parameters
        ----------
        paw
            Calculator providing ``world``, ``log``, ``wfs``, ``density``
            and ``comms['K']``.
        filename
            ULM file previously written by :meth:`write`.

        Raises
        ------
        ValueError
            If neither ``paw`` nor ``filename`` is given.
        RuntimeError
            If ``paw.wfs.kpt_u`` has more than one entry, or if the file
            tag is not ``ulmtag``.
        """
        self.filename = filename
        self.has_initialized = False
        self.reader = None
        if paw is None and filename is None:
            raise ValueError('Either paw or filename must be given')

        if paw is not None:
            self.world = paw.world
            self.log = paw.log
            self.ksl = paw.wfs.ksl
            self.kd = paw.wfs.kd
            self.bd = paw.wfs.bd
            self.kpt_u = paw.wfs.kpt_u
            self.density = paw.density
            self.comm = paw.comms['K']

            if len(paw.wfs.kpt_u) > 1:
                raise RuntimeError('K-points are not fully supported')

        if filename is not None:
            self.read(filename)

        if paw is not None:
            return

        # Create a dummy KohnShamLayouts object and one Gamma k-point
        # This is necessary to read attributes
        ns, nk = self.reader.eig_un.shape[:2]
        assert (ns, nk) == (1, 1), 'Spins and K-points not implemented'
        self.kpt_u = [KPoint(weightk=1, weight=1, s=0, k=0, q=0)]
        self.ksl = DummyKsl()

    def initialize(self, paw, min_occdiff=1e-3, only_ia=True):
        """Build the (i, a) pair basis from the ground state of ``paw``.

        On every rank this stores ``S_uMM``, ``C0_unM``, ``eig_un``,
        ``occ_un``, ``fermilevel``, ``a_M``, ``l_M`` and ``atoms``. On rank 0
        of ``self.comm`` it also stores the pair arrays:

        * ``w_p``: eigenvalue differences ``eig[a] - eig[i]``, sorted
          ascending;
        * ``f_p``: occupation differences ``occ[i] - occ[a]``;
        * ``ia_p``: the (i, a) index pairs;
        * ``P_p``: flat index of (i, a) in an (nbands, nbands) matrix;
        * ``dm_vp``: pair matrix elements of a constant-field kick
          Hamiltonian along x, y and z.

        Parameters
        ----------
        paw
            Calculator in its ground state.
        min_occdiff
            With ``only_ia``, pairs whose occupation difference is below
            this value are skipped.
        only_ia
            If True, only pairs with ``a > i`` are included; otherwise all
            (i, a) pairs are.

        Raises
        ------
        RuntimeError
            If bands are parallelized without ScaLAPACK.
        """
        if self.has_initialized:
            return
        paw.initialize_positions()

        assert self.bd.nbands == self.ksl.nao
        self.only_ia = only_ia

        if not self.ksl.using_blacs and self.bd.comm.size > 1:
            raise RuntimeError('Band parallelization without scalapack '
                               'is not supported')

        if self.kd.gamma:
            self.C0_dtype = float
        else:
            self.C0_dtype = complex

        # Take quantities
        self.fermilevel = paw.wfs.fermi_level
        self.S_uMM = []
        self.C0_unM = []
        self.eig_un = []
        self.occ_un = []
        for kpt in paw.wfs.kpt_u:
            S_MM = kpt.S_MM
            assert np.max(np.absolute(S_MM.imag)) == 0.0
            S_MM = np.ascontiguousarray(S_MM.real)
            if self.ksl.using_blacs:
                # NOTE(review): presumably completes S_MM from a single
                # stored triangle
                scalapack_tri2full(self.ksl.mmdescriptor, S_MM)
            self.S_uMM.append(S_MM)

            C_nM = kpt.C_nM
            if self.C0_dtype == float:
                assert np.max(np.absolute(C_nM.imag)) == 0.0
                C_nM = np.ascontiguousarray(C_nM.real)
            C_nM = distribute_nM(self.ksl, C_nM)
            self.C0_unM.append(C_nM)

            eig_n = paw.wfs.collect_eigenvalues(kpt.k, kpt.s)
            occ_n = paw.wfs.collect_occupations(kpt.k, kpt.s)
            self.eig_un.append(eig_n)
            self.occ_un.append(occ_n)

        self.a_M, self.l_M = get_bfs_maps(paw)
        self.atoms = paw.atoms

        # TODO: do the rest of the function with K-points

        # Construct p = (i, a) pairs
        u = 0
        eig_n = self.eig_un[u]
        occ_n = self.occ_un[u]
        C0_nM = self.C0_unM[u]

        # The pair arrays exist only on rank 0; other ranks still take part
        # in the collective matrix operations below.
        if self.comm.rank == 0:
            Nn = self.bd.nbands

            f_p = []
            w_p = []
            ia_p = []
            for i in range(Nn):
                a0 = i + 1 if only_ia else 0
                for a in range(a0, Nn):
                    f = occ_n[i] - occ_n[a]
                    if only_ia and f < min_occdiff:
                        continue
                    f_p.append(f)
                    w_p.append(eig_n[a] - eig_n[i])
                    ia_p.append((i, a))
            f_p = np.array(f_p)
            w_p = np.array(w_p)
            ia_p = np.array(ia_p, dtype=int)

            # Sort according to energy difference
            p_s = np.argsort(w_p)
            f_p = f_p[p_s]
            w_p = w_p[p_s]
            ia_p = ia_p[p_s]

            # Flat index of each (i, a) in a raveled (Nn, Nn) matrix
            Np = len(f_p)
            P_p = np.array([np.ravel_multi_index(ia_p[p], (Nn, Nn))
                            for p in range(Np)])

            dm_vp = np.empty((3, Np), dtype=float)

        for v in range(3):
            direction = np.zeros(3, dtype=float)
            direction[v] = 1.0
            cef = ConstantElectricField(Hartree / Bohr, direction)
            kick_hamiltonian = KickHamiltonian(paw.hamiltonian, paw.density,
                                               cef)
            dm_MM = paw.wfs.eigensolver.calculate_hamiltonian_matrix(
                kick_hamiltonian, paw.wfs, paw.wfs.kpt_u[u],
                add_kinetic=False, root=-1)

            # Transform to the KS basis: dm_nn = conj(C0) @ dm_MM @ C0.T
            if self.ksl.using_blacs:
                tmp_nM = self.ksl.mmdescriptor.zeros(dtype=C0_nM.dtype)
                pblas_simple_hemm(self.ksl.mmdescriptor,
                                  self.ksl.mmdescriptor,
                                  self.ksl.mmdescriptor,
                                  dm_MM, C0_nM.conj(), tmp_nM,
                                  side='R', uplo='L')
                dm_nn = self.ksl.mmdescriptor.zeros(dtype=C0_nM.dtype)
                pblas_simple_gemm(self.ksl.mmdescriptor,
                                  self.ksl.mmdescriptor,
                                  self.ksl.mmdescriptor,
                                  tmp_nM, C0_nM, dm_nn, transb='T')
            else:
                tri2full(dm_MM)
                dm_nn = np.dot(C0_nM.conj(), np.dot(dm_MM, C0_nM.T))

            dm_nn = collect_MM(self.ksl, dm_nn)
            if self.comm.rank == 0:
                dm_vp[v] = dm_nn.ravel()[P_p]

        if self.comm.rank == 0:
            self.w_p = w_p
            self.f_p = f_p
            self.ia_p = ia_p
            self.P_p = P_p
            self.dm_vp = dm_vp

        self.has_initialized = True

    def write(self, filename):
        """Write the decomposition to the ULM file ``filename``.

        Requires an object constructed with ``paw`` (uses ``self.log``,
        ``self.world``, ``self.kd`` and ``self.ksl``). The attributes in
        ``readwrite_attrs`` are written from rank 0 of ``self.comm`` only.
        The writer is closed even if writing fails.
        """
        from ase.io.trajectory import write_atoms

        self.log(f'{self.__class__.__name__}: Writing to {filename}')
        writer = Writer(filename, self.world, mode='w',
                        tag=self.__class__.ulmtag)
        try:
            writer.write(version=self.__class__.version)

            write_atoms(writer.child('atoms'), self.atoms)

            writer.write(ha=Hartree)
            write_uMM(self.kd, self.ksl, writer, 'S_uMM', self.S_uMM)
            write_uMM(self.kd, self.ksl, writer, 'C0_unM', self.C0_unM)
            write_uX(self.kd, self.ksl.block_comm, writer, 'eig_un',
                     self.eig_un)
            write_uX(self.kd, self.ksl.block_comm, writer, 'occ_un',
                     self.occ_un)

            if self.comm.rank == 0:
                for arg in self.readwrite_attrs:
                    writer.write(arg, getattr(self, arg))
        finally:
            writer.close()

    def read(self, filename):
        """Open ``filename`` for lazy reading.

        Raises
        ------
        RuntimeError
            If the file tag is not ``ulmtag``.
        """
        self.reader = Reader(filename)
        tag = self.reader.get_tag()
        if tag != self.__class__.ulmtag:
            raise RuntimeError('Unknown tag %s' % tag)
        self.version = self.reader.version

        # Do lazy reading in __getattr__ only if/when
        # the variables are required
        self.has_initialized = True

    def __getattr__(self, attr):
        """Lazily read or compute attributes that are not yet set.

        Every value produced here is cached on the instance with
        ``setattr``, so later accesses bypass this method.

        Raises
        ------
        AttributeError
            If the attribute is neither computable nor present in the
            reader.
        """
        # 'reader' itself must never be resolved through the reader. Without
        # this guard, a missing 'reader' (e.g. an instance created without
        # __init__, as copy/pickle do) recurses indefinitely.
        if attr == 'reader':
            raise AttributeError(attr)
        if attr in ['S_uMM', 'C0_unM']:
            val = read_uMM(self.kpt_u, self.ksl, self.reader, attr)
            setattr(self, attr, val)
            return val
        if attr in ['eig_un', 'occ_un']:
            val = read_uX(self.kpt_u, self.reader, attr)
            setattr(self, attr, val)
            return val
        if attr in ['C0S_unM']:
            # C0S_nM = C0_nM @ S_MM for each k-point/spin entry
            C0S_unM = []
            for u, kpt in enumerate(self.kpt_u):
                C0_nM = self.C0_unM[u]
                S_MM = self.S_uMM[u]
                if self.ksl.using_blacs:
                    C0S_nM = self.ksl.mmdescriptor.zeros(dtype=C0_nM.dtype)
                    pblas_simple_hemm(self.ksl.mmdescriptor,
                                      self.ksl.mmdescriptor,
                                      self.ksl.mmdescriptor,
                                      S_MM, C0_nM, C0S_nM,
                                      side='R', uplo='L')
                else:
                    C0S_nM = np.dot(C0_nM, S_MM)
                C0S_unM.append(C0S_nM)
            setattr(self, attr, C0S_unM)
            return C0S_unM
        if attr in ['weight_Mn']:
            # |C0|^2 normalized so that each column n sums to one
            assert self.world.size == 1
            C2_nM = np.absolute(self.C0_unM[0])**2
            val = C2_nM.T / np.sum(C2_nM, axis=1)
            setattr(self, attr, val)
            return val

        try:
            val = getattr(self.reader, attr)
            if attr == 'atoms':
                from ase.io.trajectory import read_atoms
                val = read_atoms(val)
            setattr(self, attr, val)
            return val
        except (KeyError, AttributeError):
            pass

        raise AttributeError('Attribute %s not defined in version %s' %
                             (repr(attr), repr(self.version)))

    def distribute(self, comm):
        """Distribute ``w_p``, ``f_p`` and ``dm_vp`` over ``comm``.

        Each rank receives ``Nq = ceil(Np / comm.size)`` entries. The global
        array is zero-padded to ``NQ = Nq * comm.size``. ``self.comm`` is
        replaced by ``comm``.
        """
        self.comm = comm
        N = comm.size
        self.Np = len(self.P_p)
        self.Nq = int(np.ceil(self.Np / float(N)))
        self.NQ = self.Nq * N
        self.w_q = self.distribute_p(self.w_p)
        self.f_q = self.distribute_p(self.f_p)
        self.dm_vq = self.distribute_xp(self.dm_vp)

    def distribute_p(self, a_p, a_q=None, root=0):
        """Scatter ``a_p`` from ``root`` and return this rank's chunk ``a_q``.

        ``a_p`` is zero-padded to length ``NQ`` before scattering. If
        ``a_q`` is given it is filled in place, otherwise a new array is
        allocated.
        """
        if a_q is None:
            a_q = np.zeros(self.Nq, dtype=a_p.dtype)
        if self.comm.rank == root:
            a_Q = np.append(a_p, np.zeros(self.NQ - self.Np, dtype=a_p.dtype))
        else:
            a_Q = None
        self.comm.scatter(a_Q, a_q, root)
        return a_q

    def collect_q(self, a_q, root=0):
        """Gather the chunks ``a_q`` to ``root``.

        Returns the first ``Np`` entries (padding removed) on ``root`` and
        None on other ranks.
        """
        if self.comm.rank == root:
            a_Q = np.zeros(self.NQ, dtype=a_q.dtype)
        else:
            a_Q = None
        self.comm.gather(a_q, root, a_Q)
        if self.comm.rank == root:
            a_p = a_Q[:self.Np]
        else:
            a_p = None
        return a_p

    def distribute_xp(self, a_xp):
        """Apply :meth:`distribute_p` to each row of ``a_xp``."""
        Nx = a_xp.shape[0]
        a_xq = np.zeros((Nx, self.Nq), dtype=a_xp.dtype)
        for x in range(Nx):
            self.distribute_p(a_xp[x], a_xq[x])
        return a_xq

    def transform(self, rho_uMM, broadcast=False):
        """Project an LCAO density matrix onto the (i, a) pair basis.

        Computes ``rho_nn = (C0 S) rho_MM (C0 S)^H`` and picks the ``P_p``
        elements. With ``only_ia`` the values are doubled.

        Returns
        -------
        list
            ``[rho_p]``. On ranks other than 0 ``rho_p`` is None unless
            ``broadcast`` is True.
        """
        assert len(rho_uMM) == 1, 'K-points not implemented'
        u = 0
        rho_MM = np.ascontiguousarray(rho_uMM[u])
        C0S_nM = self.C0S_unM[u].astype(rho_MM.dtype, copy=True)
        # KS decomposition
        if self.ksl.using_blacs:
            tmp_nM = self.ksl.mmdescriptor.zeros(dtype=rho_MM.dtype)
            pblas_simple_gemm(self.ksl.mmdescriptor,
                              self.ksl.mmdescriptor,
                              self.ksl.mmdescriptor,
                              C0S_nM, rho_MM, tmp_nM)
            rho_nn = self.ksl.mmdescriptor.zeros(dtype=rho_MM.dtype)
            pblas_simple_gemm(self.ksl.mmdescriptor,
                              self.ksl.mmdescriptor,
                              self.ksl.mmdescriptor,
                              tmp_nM, C0S_nM, rho_nn, transb='C')
        else:
            rho_nn = np.dot(np.dot(C0S_nM, rho_MM), C0S_nM.T.conj())

        rho_nn = collect_MM(self.ksl, rho_nn)
        if self.comm.rank == 0:
            rho_P = rho_nn.ravel()
            # Remove de-excitation terms
            rho_p = rho_P[self.P_p]
            if self.only_ia:
                rho_p *= 2
        else:
            rho_p = None

        if broadcast:
            if self.comm.rank != 0:
                rho_p = np.zeros_like(self.P_p, dtype=rho_MM.dtype)
            self.comm.broadcast(rho_p, 0)
        rho_up = [rho_p]
        return rho_up

    def ialims(self):
        """Return ``(imin, imax, amin, amax)`` over all pairs in ``ia_p``."""
        i_p = self.ia_p[:, 0]
        a_p = self.ia_p[:, 1]
        imin = np.min(i_p)
        imax = np.max(i_p)
        amin = np.min(a_p)
        amax = np.max(a_p)
        return imin, imax, amin, amax

    def M_p_to_M_ia(self, M_p):
        """Alias of :meth:`M_ia_from_M_p`."""
        return self.M_ia_from_M_p(M_p)

    def M_ia_from_M_p(self, M_p):
        """Scatter a per-pair array into a dense (i, a) matrix.

        The matrix covers the index range given by :meth:`ialims`. Entries
        without a pair are zero.
        """
        imin, imax, amin, amax = self.ialims()
        M_ia = np.zeros((imax - imin + 1, amax - amin + 1), dtype=M_p.dtype)
        for M, (i, a) in zip(M_p, self.ia_p):
            M_ia[i - imin, a - amin] = M
        return M_ia

    def plot_matrix(self, M_p):
        """Show a per-pair array as an (i, a) image with matplotlib."""
        import matplotlib.pyplot as plt
        M_ia = self.M_ia_from_M_p(M_p)
        plt.imshow(M_ia, interpolation='none')
        plt.xlabel('a')
        plt.ylabel('i')

    def get_dipole_moment_contributions(self, rho_up):
        """Return per-pair contributions ``-dm_vp * rho_p``, shape (3, Np)."""
        assert len(rho_up) == 1, 'K-points not implemented'
        u = 0
        rho_p = rho_up[u]
        dmrho_vp = - self.dm_vp * rho_p
        return dmrho_vp

    def get_dipole_moment(self, rho_up):
        """Return the dipole moment ``-dm_vp @ rho_p`` (length-3 vector)."""
        assert len(rho_up) == 1, 'K-points not implemented'
        u = 0
        rho_p = rho_up[u]
        dm_v = - np.dot(self.dm_vp, rho_p)
        return dm_v

    def get_density(self, wfs, rho_up, density='comp'):
        """Build a density from pair amplitudes ``rho_up``.

        The LCAO matrix ``C0_i^T rho_ia conj(C0_a)`` is symmetrized and
        passed to :func:`gpaw.lcaotddft.densitymatrix.get_density` with
        ``density`` as the density type.

        Raises
        ------
        NotImplementedError
            If ScaLAPACK is in use.
        """
        from gpaw.lcaotddft.densitymatrix import get_density

        if self.ksl.using_blacs:
            raise NotImplementedError('Scalapack is not supported')

        density_type = density
        assert len(rho_up) == 1, 'K-points not implemented'
        u = 0
        rho_p = rho_up[u]
        C0_nM = self.C0_unM[u]

        rho_ia = self.M_ia_from_M_p(rho_p)
        imin, imax, amin, amax = self.ialims()
        C0_iM = C0_nM[imin:(imax + 1)]
        C0_aM = C0_nM[amin:(amax + 1)]

        rho_MM = np.dot(C0_iM.T, np.dot(rho_ia, C0_aM.conj()))
        rho_MM = 0.5 * (rho_MM + rho_MM.T)

        return get_density(rho_MM, wfs, self.density, density_type, u)

    def get_contributions_table(self, weight_p, minweight=0.01,
                                zero_fermilevel=True):
        """Return a text table of the largest pair contributions.

        Pairs are listed in decreasing order of ``|weight_p|`` until the
        absolute weight drops below ``minweight``. The remaining weight is
        summarized in a 'rest' row, followed by a 'total' row. Energies are
        printed in eV, relative to the Fermi level if ``zero_fermilevel``.

        Percentages are relative to ``weight_p.sum()``; a zero total weight
        gives non-finite percentages.
        """
        assert weight_p.dtype == float
        u = 0  # TODO

        absweight_p = np.absolute(weight_p)
        tot_weight = weight_p.sum()
        propweight_p = weight_p / tot_weight * 100
        tot_propweight = propweight_p.sum()
        rest_weight = tot_weight
        rest_propweight = tot_propweight
        eig_n = self.eig_un[u].copy()
        if zero_fermilevel:
            eig_n -= self.fermilevel

        row_fmt = ' %6s %4d(%8.3f) -> %4d(%8.3f): %12.4f %14.4f %8.1f'
        sum_fmt = ' %39s: %12s %+14.4f %8.1f'
        lines = ['# %6s %4s(%8s) %4s(%8s) %12s %14s %8s' %
                 ('p', 'i', 'eV', 'a', 'eV', 'Ediff (eV)', 'weight', '%')]
        for p in np.argsort(absweight_p)[::-1]:
            if absweight_p[p] < minweight:
                break
            i, a = self.ia_p[p]
            lines.append(row_fmt %
                         (p, i, eig_n[i] * Hartree, a, eig_n[a] * Hartree,
                          self.w_p[p] * Hartree, weight_p[p],
                          propweight_p[p]))
            rest_weight -= weight_p[p]
            rest_propweight -= propweight_p[p]
        lines.append(sum_fmt % ('rest', '', rest_weight, rest_propweight))
        lines.append(sum_fmt % ('total', '', tot_weight, tot_propweight))
        return ''.join(line + '\n' for line in lines)

    def plot_TCM(self, weight_p, energy_o, energy_u, sigma,
                 zero_fermilevel=True, vmax='80%'):
        """Plot a transition contribution map and the DOS panels.

        Returns the axes ``(ax_tcm, ax_occ_dos, ax_unocc_dos)``.
        """
        from gpaw.lcaotddft.tcm import TCMPlotter
        plotter = TCMPlotter(self, energy_o, energy_u, sigma, zero_fermilevel)
        ax_tcm = plotter.plot_TCM(weight_p, vmax)
        ax_occ_dos, ax_unocc_dos = plotter.plot_DOS()
        return ax_tcm, ax_occ_dos, ax_unocc_dos

    def get_TCM(self, weight_p, eig_n, energy_o, energy_u, sigma):
        """Return the Gaussian-broadened transition map ``tcm_ou``.

        Only pairs within ``8 * sigma`` of both energy grids contribute.
        """
        flt_p = self.filter_by_x_ia(eig_n, energy_o, energy_u, 8 * sigma)
        weight_f = weight_p[flt_p]
        G_fo = gauss_ij(eig_n[self.ia_p[flt_p, 0]], energy_o, sigma)
        G_fu = gauss_ij(eig_n[self.ia_p[flt_p, 1]], energy_u, sigma)
        tcm_ou = np.dot(G_fo.T * weight_f, G_fu)
        return tcm_ou

    def get_DOS(self, eig_n, energy_o, energy_u, sigma):
        """Return unweighted broadened DOS on the two energy grids."""
        return self.get_weighted_DOS(1, eig_n, energy_o, energy_u, sigma)

    def get_weighted_DOS(self, weight_n, eig_n, energy_o, energy_u, sigma):
        """Return the broadened DOS weighted by ``weight_n``.

        A non-array ``weight_n`` is used as a constant weight for all states.
        Returns ``(dos_o, dos_u)`` on the ``energy_o`` and ``energy_u`` grids.
        """
        if not isinstance(weight_n, np.ndarray):
            # Assume float
            weight_n = weight_n * np.ones_like(eig_n)
        G_on = gauss_ij(energy_o, eig_n, sigma)
        G_un = gauss_ij(energy_u, eig_n, sigma)
        dos_o = np.dot(G_on, weight_n)
        dos_u = np.dot(G_un, weight_n)
        return dos_o, dos_u

    def get_weight_n_by_l(self, l):
        """Return the per-state weight of angular momentum ``l``.

        ``l`` is a Python or NumPy integer, or an iterable of integers
        whose weights are summed. Weights are taken from ``weight_Mn``.
        """
        if isinstance(l, (int, np.integer)):
            return np.sum(self.weight_Mn[self.l_M == l], axis=0)
        return np.sum([self.get_weight_n_by_l(l_) for l_ in l], axis=0)

    def get_weight_n_by_a(self, a):
        """Return the per-state weight of atom ``a``.

        ``a`` is a Python or NumPy integer, or an iterable of integers
        whose weights are summed. Weights are taken from ``weight_Mn``.
        """
        if isinstance(a, (int, np.integer)):
            return np.sum(self.weight_Mn[self.a_M == a], axis=0)
        return np.sum([self.get_weight_n_by_a(a_) for a_ in a], axis=0)

    def get_distribution_i(self, weight_p, energy_e, sigma,
                           zero_fermilevel=True):
        """Broaden ``weight_p`` over the energies of the i states (eV)."""
        eig_n, fermilevel = self.get_eig_n(zero_fermilevel)
        flt_p = self.filter_by_x_i(eig_n, energy_e, 8 * sigma)
        weight_f = weight_p[flt_p]
        G_fe = gauss_ij(eig_n[self.ia_p[flt_p, 0]], energy_e, sigma)
        dist_e = np.dot(G_fe.T, weight_f)
        return dist_e

    def get_distribution_a(self, weight_p, energy_e, sigma,
                           zero_fermilevel=True):
        """Broaden ``weight_p`` over the energies of the a states (eV)."""
        eig_n, fermilevel = self.get_eig_n(zero_fermilevel)
        flt_p = self.filter_by_x_a(eig_n, energy_e, 8 * sigma)
        weight_f = weight_p[flt_p]
        G_fe = gauss_ij(eig_n[self.ia_p[flt_p, 1]], energy_e, sigma)
        dist_e = np.dot(G_fe.T, weight_f)
        return dist_e

    def get_distribution_ia(self, weight_p, energy_o, energy_u, sigma,
                            zero_fermilevel=True):
        """
        Filter both i and a spaces as in TCM.

        Returns ``(dist_o, dist_u)``: ``weight_p`` broadened over the i and
        a state energies of the pairs passing both filters.
        """
        eig_n, fermilevel = self.get_eig_n(zero_fermilevel)
        flt_p = self.filter_by_x_ia(eig_n, energy_o, energy_u, 8 * sigma)
        weight_f = weight_p[flt_p]
        G_fo = gauss_ij(eig_n[self.ia_p[flt_p, 0]], energy_o, sigma)
        dist_o = np.dot(G_fo.T, weight_f)
        G_fu = gauss_ij(eig_n[self.ia_p[flt_p, 1]], energy_u, sigma)
        dist_u = np.dot(G_fu.T, weight_f)
        return dist_o, dist_u

    def get_distribution(self, weight_p, energy_e, sigma):
        """Broaden ``weight_p`` over the pair energies ``w_p`` (in eV)."""
        w_p = self.w_p * Hartree
        flt_p = self.filter_by_x_p(w_p, energy_e, 8 * sigma)
        weight_f = weight_p[flt_p]
        G_fe = gauss_ij(w_p[flt_p], energy_e, sigma)
        dist_e = np.dot(G_fe.T, weight_f)
        return dist_e

    def get_eig_n(self, zero_fermilevel=True):
        """Return ``(eig_n, fermilevel)`` in eV.

        If ``zero_fermilevel``, energies are shifted so the Fermi level is 0.
        """
        u = 0  # TODO
        eig_n = self.eig_un[u].copy()
        if zero_fermilevel:
            eig_n -= self.fermilevel
            fermilevel = 0.0
        else:
            fermilevel = self.fermilevel
        eig_n *= Hartree
        fermilevel *= Hartree
        return eig_n, fermilevel

    def filter_by_x_p(self, x_p, energy_e, buf):
        """Return a mask of ``x_p`` within ``[energy_e[0] - buf,
        energy_e[-1] + buf]`` (``energy_e`` assumed ascending)."""
        flt_p = np.logical_and((energy_e[0] - buf) <= x_p,
                               x_p <= (energy_e[-1] + buf))
        return flt_p

    def filter_by_x_i(self, x_n, energy_e, buf):
        """Apply :meth:`filter_by_x_p` to ``x_n`` of each pair's i state."""
        return self.filter_by_x_p(x_n[self.ia_p[:, 0]], energy_e, buf)

    def filter_by_x_a(self, x_n, energy_e, buf):
        """Apply :meth:`filter_by_x_p` to ``x_n`` of each pair's a state."""
        return self.filter_by_x_p(x_n[self.ia_p[:, 1]], energy_e, buf)

    def filter_by_x_ia(self, x_n, energy_o, energy_u, buf):
        """Mask of pairs whose i state is near ``energy_o`` and whose a
        state is near ``energy_u``."""
        flti_p = self.filter_by_x_i(x_n, energy_o, buf)
        flta_p = self.filter_by_x_a(x_n, energy_u, buf)
        flt_p = np.logical_and(flti_p, flta_p)
        return flt_p

    def __del__(self):
        # getattr with a default: __del__ may run on an instance whose
        # 'reader' was never set (e.g. created without __init__)
        reader = getattr(self, 'reader', None)
        if reader is not None:
            reader.close()