Coverage for gpaw/lrtddft/kssingle.py: 95%
416 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-20 00:19 +0000
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-20 00:19 +0000
"""Kohn-Sham single particle excitations and related objects.

"""
4import sys
5import json
6import numpy as np
7from copy import copy
9from ase.units import Bohr, Hartree, alpha
11import gpaw.mpi as mpi
12from gpaw.utilities import packed_index
13from gpaw.lrtddft.excitation import Excitation, ExcitationList, get_filehandle
14from gpaw.pair_density import PairDensity
15from gpaw.fd_operators import Gradient
16from gpaw.utilities.tools import coordinates
17from .kssrestrictor import KSSRestrictor
class KSSingles(ExcitationList):
    """List of Kohn-Sham single particle excitations.

    Parameters
    ----------
    restrict: dict, optional
        Selection criteria passed on to a KSSRestrictor.  The keys read
        directly by this class are

        eps:
            Minimal occupation difference for a transition
            (default 0.001)
        istart:
            First occupied state to consider
        jend:
            Last unoccupied state to consider

        The restrictor also supplies the energy window via its
        ``emin_emax()`` (presumably from an ``energy_range`` entry given
        as [emin, emax] or emax in eV -- confirm in KSSRestrictor).
    log, txt:
        Passed on to ExcitationList.

    The transitions are evaluated by ``calculate(atoms, nspins)``, where
    ``atoms.calc`` is the calculator after a ground state calculation and
    ``nspins`` is the number of spins considered (only meaningful for a
    spin-unpolarised ground state calculation).
    """
    def __init__(self,
                 restrict=None,
                 log=None,
                 txt=None):
        ExcitationList.__init__(self, log=log, txt=txt)
        self.world = mpi.world

        self.restrict = KSSRestrictor()
        # None replaces the former shared mutable default ``{}``
        self.restrict.update({} if restrict is None else restrict)

    def calculate(self, atoms, nspins=None):
        """Evaluate the KS transitions of the ground state of ``atoms``.

        Parameters
        ----------
        atoms:
            Atoms object whose calculator (``atoms.calc``) holds a
            finished ground state calculation.
        nspins: int or None
            Number of spins considered, see ``select``.

        Returns
        -------
        self
        """
        calculator = atoms.calc
        self.calculator = calculator

        # LCAO calculation requires special actions
        self.lcao = calculator.wfs.mode == 'lcao'

        # deny hybrids as their empty states are wrong
        # (check currently disabled)
        # gsxc = calculator.hamiltonian.xc
        # hybrid = hasattr(gsxc, 'hybrid') and gsxc.hybrid > 0.0
        # assert(not hybrid)

        # ensure correctly initialized wave functions
        calculator.converge_wave_functions()
        self.world = calculator.wfs.world

        # parallelization over bands not yet supported
        assert calculator.wfs.bd.comm.size == 1

        # do the evaluation
        self.select(nspins)

        trkm = self.get_trk()
        self.log('KSS {} transitions (restrict={})'.format(
            len(self), self.restrict))
        self.log('KSS TRK sum %g (%g,%g,%g)' %
                 (np.sum(trkm) / 3., trkm[0], trkm[1], trkm[2]))
        pol = self.get_polarizabilities(lmax=3)
        self.log('KSS polarisabilities(l=0-3) %g, %g, %g, %g' %
                 tuple(pol.tolist()))
        return self

    @staticmethod
    def emin_emax(energy_range):
        """Convert an energy range in eV to (emin, emax) in Hartree.

        ``energy_range`` may be None (no limits), a pair [emin, emax] or
        a single number emax.
        """
        emin = -sys.float_info.max
        emax = sys.float_info.max
        if energy_range is None:
            return emin, emax
        try:
            emin, emax = energy_range
        except TypeError:
            # a single number is the upper limit only
            return emin, energy_range / Hartree
        return emin / Hartree, emax / Hartree

    def select(self, nspins=None):
        """Select KSSingles according to the restriction criteria.

        If the list was read from a file (no calculator attached), the
        entries rejected by ``self.restrict.is_good`` are removed.
        Otherwise all transitions i -> j (j > i) of the attached
        calculator that fulfil the criteria are evaluated and appended.

        Parameters
        ----------
        nspins: int or None
            For a spin-unpolarised ground state, a value larger than the
            number of ground state spins treats both spin channels
            explicitly, each with half the occupation difference.
        """
        # criteria
        emin, emax = self.restrict.emin_emax()
        istart = self.restrict['istart']
        jend = self.restrict['jend']
        eps = self.restrict['eps']

        if not hasattr(self, 'calculator'):  # I'm read from a file
            # throw away all not needed entries
            for i, ks in reversed(list(enumerate(self))):
                if not self.restrict.is_good(ks):
                    del self[i]
            return None

        paw = self.calculator
        wfs = paw.wfs
        self.dtype = wfs.dtype
        self.kpt_u = wfs.kpt_u

        if not self.lcao and self.kpt_u[0].psit_nG is None:
            raise RuntimeError('No wave functions in calculator!')

        # here, we need to take care of the spins also for
        # closed shell systems (Sz=0)
        # vspin is the virtual spin of the wave functions,
        # i.e. the spin used in the ground state calculation
        # pspin is the physical spin of the wave functions
        # i.e. the spin of the excited states
        self.nvspins = wfs.nspins
        self.npspins = wfs.nspins
        fijscale = 1
        ispins = [0]
        nks = wfs.kd.nibzkpts * wfs.kd.nspins
        if self.nvspins < 2:
            if (nspins or 0) > self.nvspins:
                self.npspins = nspins
                fijscale = 0.5
                ispins = [0, 1]
                nks *= 2

        kpt_comm = self.calculator.wfs.kd.comm
        nbands = len(self.kpt_u[0].f_n)
        # Both loops below walk (ispin, k, s) with the flat counter u into
        # ``take``; they must use the same spin range, the one ``nks`` was
        # sized with.
        nspins_k = wfs.kd.nspins

        # select: every rank marks the transitions of its own k-points,
        # the marks are then summed over the k-point communicator
        take = np.zeros((nks, nbands, nbands), dtype=int)
        u = 0
        for ispin in ispins:
            for k in range(wfs.kd.nibzkpts):
                q = k - wfs.kd.k0
                for s in range(nspins_k):
                    if 0 <= q < wfs.kd.mynk:
                        kpt = wfs.kpt_qs[q][s]
                        for i in range(nbands):
                            for j in range(i + 1, nbands):
                                fij = (kpt.f_n[i] - kpt.f_n[j]) / kpt.weight
                                epsij = kpt.eps_n[j] - kpt.eps_n[i]
                                if (fij > eps and
                                        epsij >= emin and epsij < emax and
                                        i >= istart and j <= jend):
                                    take[u, i, j] = 1
                    u += 1
        kpt_comm.sum(take)

        self.log()
        self.log('Kohn-Sham single transitions')
        self.log()

        # calculate in parallel; ranks not owning the k-point append an
        # empty entry that is filled by ``distribute`` below
        u = 0
        for ispin in ispins:
            for k in range(wfs.kd.nibzkpts):
                q = k - wfs.kd.k0
                for s in range(nspins_k):
                    for i in range(nbands):
                        for j in range(i + 1, nbands):
                            if take[u, i, j]:
                                if 0 <= q < wfs.kd.mynk:
                                    kpt = wfs.kpt_qs[q][s]
                                    pspin = max(kpt.s, ispin)
                                    self.append(
                                        KSSingle(i, j, pspin, kpt, paw,
                                                 fijscale=fijscale,
                                                 dtype=self.dtype))
                                else:
                                    self.append(KSSingle(i, j, pspin=0,
                                                         kpt=None, paw=paw,
                                                         dtype=self.dtype))
                    u += 1

        # distribute
        for kss in self:
            kss.distribute()

    @classmethod
    def read(cls, filename=None, fh=None, restrict=None, log=None):
        """Read myself from a file.

        Parameters
        ----------
        filename: str
            File to read.  It may contain other information (e.g. an
            LrTDDFT header) before the '# KSSingles' section.
        fh: file object
            Already opened file positioned at the '# KSSingles' header
            line.  Used instead of ``filename``; it is not closed here.
        restrict: dict, optional
            Additional restrictions.  If non-empty, the transitions read
            are filtered with them.
        log:
            Passed on to the constructor.

        Raises
        ------
        RuntimeError
            If the file does not contain KSSingles data.
        """
        assert (filename is not None) or (fh is not None)

        def fail(f):
            name = getattr(f, 'name', repr(f))
            raise RuntimeError(str(name) + ' does not contain ' +
                               cls.__name__ + ' data')

        f = get_filehandle(cls, filename) if fh is None else fh
        try:
            if fh is None:
                # there can be other information, i.e. the LrTDDFT header
                try:
                    content = f.read()
                    f.seek(content.index('# KSSingles'))
                    del content
                    f.readline()
                except ValueError:
                    fail(f)
            else:
                # we assume to be at the right place and read the header
                if not f.readline().strip() == '# KSSingles':
                    fail(f)

            words = f.readline().split()
            n = int(words[0])
            kssl = cls(log=log)
            if len(words) == 1:
                # very old output style for real wave functions
                # (finite systems)
                kssl.dtype = float
                restrict_from_file = {}
            else:
                if words[1].startswith('complex'):
                    kssl.dtype = complex
                else:
                    kssl.dtype = float
                restrict_from_file = json.loads(f.readline())
                if not isinstance(restrict_from_file, dict):
                    # old output style: only eps was stored
                    restrict_from_file = {'eps': restrict_from_file}
            kssl.npspins = 1
            for _ in range(n):
                kss = KSSingle(string=f.readline(), dtype=kssl.dtype)
                kssl.append(kss)
                kssl.npspins = max(kssl.npspins, kss.pspin + 1)
        finally:
            # close only what we opened ourselves, also on errors
            if fh is None:
                f.close()

        kssl.update()
        kssl.restrict.update(restrict_from_file)
        if restrict:
            kssl.restrict.update(restrict)
            kssl.select()

        return kssl

    def update(self):
        """Recompute index range and spin counts from the entries.

        Sets restrict['istart'] / restrict['jend'] (if the list is not
        empty), ``npspins`` and ``nvspins``, and drops cached arrays
        built by ``set_arrays``.
        """
        if len(self):
            istart = min(kss.i for kss in self)
            jend = max(0, max(kss.j for kss in self))
            self.restrict.update({'istart': istart, 'jend': jend})
        self.npspins = 2 if any(kss.pspin == 1 for kss in self) else 1
        self.nvspins = 2 if any(kss.spin == 1 for kss in self) else 1

        if hasattr(self, 'energies'):
            del self.energies

    def set_arrays(self):
        """Collect the per-transition data into arrays.

        The result is cached; ``update`` invalidates it.  ``muv`` and
        ``magn`` are None when no entry provides them.
        """
        if hasattr(self, 'energies'):
            return
        self.energies = np.array([kss.energy for kss in self])
        self.fij = np.array([kss.fij for kss in self])
        self.me = np.array([kss.me for kss in self])
        self.mur = np.array([kss.mur for kss in self])
        muv = [kss.muv for kss in self if kss.muv is not None]
        self.muv = np.array(muv) if len(muv) else None
        magn = [kss.magn for kss in self if kss.magn is not None]
        self.magn = np.array(magn) if len(magn) else None

    def write(self, filename=None, fh=None):
        """Write current state to a file.

        'filename' is the filename. If the filename ends in .gz,
        the file is automatically saved in compressed gzip format.

        'fh' is a filehandle. This can be used to write into already
        opened files.

        Only rank 0 of ``self.world`` writes.
        """
        if self.world.rank != 0:
            return

        f = get_filehandle(self, filename, mode='w') if fh is None else fh
        try:
            f.write('# KSSingles\n')
            f.write(f'{len(self)} {np.dtype(self.dtype)}\n')
            f.write(json.dumps(self.restrict.values) + '\n')
            for kss in self:
                f.write(kss.outstring())
        finally:
            if fh is None:
                f.close()

    def overlap(self, ov_nn, other):
        """Matrix element overlaps determined from wave function overlaps.

        Parameters
        ----------
        ov_nn: array
            Wave function overlap factors from a displaced calculator.
            Index 0 corresponds to our own wavefunctions conjugated and
            index 1 to the others' wavefunctions

        Returns
        -------
        ov_pp: array
            Overlap corresponding to matrix elements.
            Index 0 corresponds to our own matrix elements conjugated and
            index 1 to the others' matrix elements
        """
        n0 = len(self)
        n1 = len(other)
        ov_pp = np.zeros((n0, n1), dtype=ov_nn.dtype)
        i1_p = [ex.i for ex in other]
        j1_p = [ex.j for ex in other]
        for p0, ex0 in enumerate(self):
            ov_pp[p0, :] = ov_nn[ex0.i, i1_p].conj() * ov_nn[ex0.j, j1_p]
        return ov_pp
class KSSingle(Excitation, PairDensity):
    """Single Kohn-Sham transition i -> j containing all its indices.

    pspin=physical spin
    spin=virtual spin, i.e. spin in the ground state calc.
    kpt=the Kpoint object
    fijscale=weight for the occupation difference

    Matrix elements::

      me   = sqrt(fij*epsij) * <i|r|j>
      mur  = - <i|r|j>
      muv  = - <i|nabla|j>/omega_ij with omega_ij>0
      magn = <i|[r x nabla]|j> / (2 m_e c)
    """

    def __init__(self, iidx=None, jidx=None, pspin=None, kpt=None,
                 paw=None, string=None, fijscale=1, dtype=float):
        """Create a transition from a calculator or from a string.

        iidx: index of occupied state
        jidx: index of empty state
        pspin: physical spin
        kpt: kpoint object; None leaves the entry empty (zeros), to be
             filled by ``distribute`` from the rank owning the k-point
        paw: calculator
        string: string to be initialized from (see ``outstring``)
        fijscale: factor applied to the occupation difference fij
        dtype: dtype of matrix elements
        """
        if string is not None:
            self.fromstring(string, dtype)
            return None

        # normal entry

        PairDensity.__init__(self, paw)
        PairDensity.initialize(self, kpt, iidx, jidx)

        self.pspin = pspin

        self.energy = 0.0
        self.fij = 0.0

        self.me = np.zeros((3), dtype=dtype)
        self.mur = np.zeros((3), dtype=dtype)
        self.muv = np.zeros((3), dtype=dtype)
        self.magn = np.zeros((3), dtype=dtype)

        self.kpt_comm = paw.wfs.kd.comm

        # leave empty if not my kpt
        if kpt is None:
            return

        wfs = paw.wfs
        gd = wfs.gd

        self.energy = kpt.eps_n[jidx] - kpt.eps_n[iidx]
        self.fij = (kpt.f_n[iidx] - kpt.f_n[jidx]) * fijscale

        # calculate matrix elements -----------

        # length form ..........................

        # course grid contribution
        # <i|r|j> is the negative of the dipole moment (because of negative
        # e- charge)
        me = - gd.calculate_dipole_moment(self.get())

        # augmentation contributions
        ma = np.zeros(me.shape, dtype=dtype)
        pos_av = paw.atoms.get_positions() / Bohr
        for a, P_ni in kpt.P_ani.items():
            Ra = pos_av[a]
            Pi_i = P_ni[self.i].conj()
            Pj_i = P_ni[self.j]
            Delta_pL = wfs.setups[a].Delta_pL
            ni = len(Pi_i)
            ma0 = 0
            ma1 = np.zeros(me.shape, dtype=me.dtype)
            for i in range(ni):
                for j in range(ni):
                    pij = Pi_i[i] * Pj_i[j]
                    ij = packed_index(i, j, ni)
                    # L=0 term
                    ma0 += Delta_pL[ij, 0] * pij
                    # L=1 terms
                    if wfs.setups[a].lmax >= 1:
                        # see spherical_harmonics.py for
                        # L=1:y L=2:z; L=3:x
                        ma1 += np.array([Delta_pL[ij, 3], Delta_pL[ij, 1],
                                         Delta_pL[ij, 2]]) * pij
            ma += np.sqrt(4 * np.pi / 3) * ma1 + Ra * np.sqrt(4 * np.pi) * ma0
        gd.comm.sum(ma)

        self.me = np.sqrt(self.energy * self.fij) * (me + ma)
        self.mur = - (me + ma)

        # velocity form .............................

        if self.lcao:
            # Wave functions are stored per virtual (ground state) spin
            # kpt.s; pspin can exceed the number of ground state spins
            # (see KSSingles.select: pspin = max(kpt.s, ispin)).
            self.wfi = _get_and_distribute_wf(wfs, iidx, kpt.k, kpt.s)
            self.wfj = _get_and_distribute_wf(wfs, jidx, kpt.k, kpt.s)

        me = np.zeros(self.mur.shape, dtype=dtype)

        # get derivatives
        dtype = self.wfj.dtype
        dwfj_cg = gd.empty((3), dtype=dtype)
        if not hasattr(gd, 'ddr'):
            # gradient operators are cached on the grid descriptor
            gd.ddr = [Gradient(gd, c, dtype=dtype, n=2).apply
                      for c in range(3)]
        for c in range(3):
            gd.ddr[c](self.wfj, dwfj_cg[c], kpt.phase_cd)
            me[c] = gd.integrate(self.wfi.conj() * dwfj_cg[c])

        # XXX is this the best choice, maybe center of mass?
        origin = 0.5 * np.diag(paw.wfs.gd.cell_cv)

        # augmentation contributions

        # <psi_i|grad|psi_j>
        ma = np.zeros(me.shape, dtype=me.dtype)
        # Ra x <psi_i|grad|psi_j> for magnetic transition dipole
        mRa = np.zeros(me.shape, dtype=me.dtype)
        for a, P_ni in kpt.P_ani.items():
            Pi_i = P_ni[self.i].conj()
            Pj_i = P_ni[self.j]
            nabla_iiv = paw.wfs.setups[a].nabla_iiv
            ma_c = np.zeros(me.shape, dtype=me.dtype)
            for c in range(3):
                for i1, Pi in enumerate(Pi_i):
                    for i2, Pj in enumerate(Pj_i):
                        ma_c[c] += Pi * Pj * nabla_iiv[i1, i2, c]
            mRa += np.cross(paw.atoms[a].position / Bohr - origin, ma_c)
            ma += ma_c
        gd.comm.sum(ma)
        gd.comm.sum(mRa)

        self.muv = - (me + ma) / self.energy

        # magnetic transition dipole ................

        # m_ij = -(1/2c) <i|L|j> = i/2c <i|r x p|j>
        # see Autschbach et al., J. Chem. Phys., 116, 6930 (2002)

        r_cg, r2_g = coordinates(gd, origin=origin)
        magn = np.zeros(me.shape, dtype=dtype)

        # <psi_i|r x grad|psi_j>
        wfi_g = self.wfi.conj()
        for ci in range(3):
            cj = (ci + 1) % 3
            ck = (ci + 2) % 3
            magn[ci] = gd.integrate(wfi_g * r_cg[cj] * dwfj_cg[ck] -
                                    wfi_g * r_cg[ck] * dwfj_cg[cj])

        # augmentation contributions
        # <psi_i| r x nabla |psi_j>
        # = <psi_i| (r - Ra + Ra) x nabla |psi_j>
        # = <psi_i| (r - Ra) x nabla |psi_j> + Ra x <psi_i| nabla |psi_j>

        ma = np.zeros(magn.shape, dtype=magn.dtype)
        for a, P_ni in kpt.P_ani.items():
            Pi_i = P_ni[self.i].conj()
            Pj_i = P_ni[self.j]
            rxnabla_iiv = paw.wfs.setups[a].rxnabla_iiv
            for c in range(3):
                for i1, Pi in enumerate(Pi_i):
                    for i2, Pj in enumerate(Pj_i):
                        ma[c] += Pi * Pj * rxnabla_iiv[i1, i2, c]
        gd.comm.sum(ma)

        self.magn = alpha / 2. * (magn + ma + mRa)

    def distribute(self):
        """Distribute results to all cores.

        Entries created with kpt=None hold zeros, so summing over the
        k-point communicator yields the values of the owning rank.
        """
        self.spin = self.kpt_comm.sum_scalar(self.spin)
        self.pspin = self.kpt_comm.sum_scalar(self.pspin)
        self.k = self.kpt_comm.sum_scalar(self.k)
        self.weight = self.kpt_comm.sum_scalar(self.weight)
        self.energy = self.kpt_comm.sum_scalar(self.energy)
        self.fij = self.kpt_comm.sum_scalar(self.fij)

        self.kpt_comm.sum(self.me)
        self.kpt_comm.sum(self.mur)
        self.kpt_comm.sum(self.muv)
        self.kpt_comm.sum(self.magn)

    def _combine(self, other, op):
        """Return a copy of self with matrix elements op(self.x, other.x).

        The optional elements muv and magn (None for entries read from
        files without them) stay None if missing in either operand.
        """
        result = copy(self)
        result.me = op(self.me, other.me)
        result.mur = op(self.mur, other.mur)
        for name in ('muv', 'magn'):
            mine = getattr(self, name)
            theirs = getattr(other, name)
            if mine is None or theirs is None:
                setattr(result, name, None)
            else:
                setattr(result, name, op(mine, theirs))
        return result

    def __add__(self, other):
        """Add two KSSingles"""
        return self._combine(other, lambda a, b: a + b)

    def __sub__(self, other):
        """Subtract two KSSingles"""
        return self._combine(other, lambda a, b: a - b)

    def __rmul__(self, x):
        return self.__mul__(x)

    def __mul__(self, x):
        """Multiply a KSSingle with a real number (Python or numpy)."""
        if not isinstance(x, (int, float, np.integer, np.floating)):
            return NotImplemented
        result = copy(self)
        result.me = self.me * x
        result.mur = self.mur * x
        result.muv = None if self.muv is None else self.muv * x
        result.magn = None if self.magn is None else self.magn * x
        return result

    def __truediv__(self, x):
        return self.__mul__(1. / x)

    __div__ = __truediv__

    def fromstring(self, string, dtype=float):
        """Initialize from a line written by ``outstring``."""
        words = string.split()
        self.i = int(words.pop(0))
        self.j = int(words.pop(0))
        self.pspin = int(words.pop(0))
        self.spin = int(words.pop(0))
        if dtype == float:
            self.k = 0
            self.weight = 1
        else:
            self.k = int(words.pop(0))
            self.weight = float(words.pop(0))
        self.energy = float(words.pop(0))
        self.fij = float(words.pop(0))
        self.mur = np.array([dtype(words.pop(0)) for _ in range(3)])
        self.me = - self.mur * np.sqrt(self.energy * self.fij)
        # muv and magn are optional (absent in old files)
        self.muv = self.magn = None
        if len(words):
            self.muv = np.array([dtype(words.pop(0)) for _ in range(3)])
        if len(words):
            self.magn = np.array([dtype(words.pop(0)) for _ in range(3)])
        return None

    def outstring(self):
        """Return a one-line representation readable by ``fromstring``."""
        if self.mur.dtype == float:
            string = '{:d} {:d} {:d} {:d} {:.10g} {:f}'.format(
                self.i, self.j, self.pspin, self.spin, self.energy, self.fij)
        else:
            # energy with the same precision as in the real case
            string = (
                '{:d} {:d} {:d} {:d} {:d} {:.10g} {:.10g} {:g}'.format(
                    self.i, self.j, self.pspin, self.spin, self.k,
                    self.weight, self.energy, self.fij))
        string += ' '

        def format_me(me):
            string = ''
            if me.dtype == float:
                for m in me:
                    string += f' {m:.5e}'
            else:
                for m in me:
                    string += ' {0.real:.5e}{0.imag:+.5e}j'.format(m)
            return string

        string += ' ' + format_me(self.mur)
        if self.muv is not None:
            string += ' ' + format_me(self.muv)
        if self.magn is not None:
            string += ' ' + format_me(self.magn)
        string += '\n'

        return string

    def __str__(self):
        string = '# <KSSingle> %d->%d %d(%d) eji=%g[eV]' % \
            (self.i, self.j, self.pspin, self.spin,
             self.energy * Hartree)
        if self.me.dtype == float:
            string += f' ({self.me[0]:g},{self.me[1]:g},{self.me[2]:g})'
        else:
            string += f' kpt={self.k:d} w={self.weight:g}'
            string += ' ('
            # use velocity form
            s = - np.sqrt(self.energy * self.fij)
            for c, m in enumerate(s * self.me):
                string += '{0.real:.5e}{0.imag:+.5e}j'.format(m)
                if c < 2:
                    string += ','
            string += ')'
        return string

    def __eq__(self, other):
        """KSSingles are considred equal when their indices are equal."""
        try:
            return (self.pspin == other.pspin and self.k == other.k and
                    self.i == other.i and self.j == other.j)
        except AttributeError:
            # not a transition-like object
            return NotImplemented

    def __hash__(self):
        """Hash over exactly the indices compared in __eq__.

        Not cached, since the indices change e.g. in ``distribute``.
        """
        return hash((self.pspin, self.k, self.i, self.j))

    #
    # User interface: ##
    #

    def get_weight(self):
        return self.fij
def _get_and_distribute_wf(wfs, n, k, s):
    """Return wave function n of k-point k and spin s distributed on wfs.gd.

    The real-space array from ``wfs.get_wave_function_array`` is
    broadcast from rank 0 of ``wfs.world`` and then distributed over the
    grid descriptor's domain decomposition.
    """
    gd = wfs.gd
    wf = wfs.get_wave_function_array(n=n, k=k, s=s, realspace=True,
                                     periodic=False)
    if wfs.world.rank != 0:
        # receive buffer; presumably only rank 0 holds the collected
        # global array -- confirm in get_wave_function_array
        wf = gd.empty(dtype=wfs.dtype, global_array=True)
    wf = np.ascontiguousarray(wf)
    wfs.world.broadcast(wf, 0)
    return gd.distribute(wf)