Coverage for gpaw/new/pw/stress.py: 97%
86 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-14 00:18 +0000
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-14 00:18 +0000
1"""PW-mode stress tensor calculation."""
2from __future__ import annotations
4from typing import TYPE_CHECKING
6import numpy as np
7from gpaw.core.atom_arrays import AtomArrays
8from gpaw.gpu import synchronize, as_np
9from gpaw.new.ibzwfs import IBZWaveFunctions
10from gpaw.new.pwfd.wave_functions import PWFDWaveFunctions
11from gpaw.typing import Array2D
12from gpaw.core import PWArray
13from gpaw.utilities import as_real_dtype
14if TYPE_CHECKING:
15 from gpaw.new.pw.pot_calc import PlaneWavePotentialCalculator
def calculate_stress(pot_calc: PlaneWavePotentialCalculator,
                     ibzwfs, density, potential,
                     vt_g: PWArray,
                     nt_g: PWArray,
                     dedtaut_g: PWArray | None) -> Array2D:
    """Calculate symmetrized stress tensor.

    Which ranks add which terms:

    * Every rank adds the wave-function terms (kinetic + PAW) and the
      XC term.
    * Ranks with ``kpt_comm.rank == 0`` and ``band_comm.rank == 0`` also
      add:
      - the Hartree term;
      - the Poisson-solver term;
      - the ``vbar`` and ``nct`` terms;
      - the ``tauct`` term, when ``dedtaut_g`` is given.
    * Among those ranks, only ``domain_comm`` rank 0 also subtracts
      ``potential.e_stress`` on the diagonal.

    Processing after collection:

    1. The partial tensors are summed onto rank 0 of ``ibzwfs.comm``.
    2. The result is made symmetric and divided by the cell volume.
    3. It is averaged over the rotations of the IBZ symmetries.
    4. ``pot_calc.extensions_stress_contribution`` is added.
    5. The result is broadcast from rank 0.

    Parameters
    ----------
    pot_calc:
        Plane-wave potential calculator.  Provides the XC, Poisson-solver,
        vbar and extension stress terms.
    ibzwfs:
        Wave functions.  Also provides the communicators and the
        symmetry rotations.
    density:
        Density object (``nt_sR``, ``nct_aX``, ``tauct_aX`` and the
        compensation-charge coefficients).
    potential:
        Potential object (``vHt_x``, ``dH_asii``, ``e_stress``).
    vt_g, nt_g:
        Plane-wave arrays passed to the ``nct`` and ``vbar`` terms.
    dedtaut_g:
        Plane-wave array for the ``tauct`` term, or ``None`` to skip it.

    Returns
    -------
    (3, 3) NumPy array, identical on all ranks of ``ibzwfs.comm``.
    """
    comm = ibzwfs.comm
    xp = density.nt_sR.xp
    dom = density.nt_sR.desc

    ibzwfs.make_sure_wfs_are_read_from_gpw_file()
    s_vv = get_wfs_stress(ibzwfs, potential.dH_asii)
    s_vv += pot_calc.xc.stress_contribution(
        ibzwfs, density, pot_calc.interpolate)

    # Only one k-point/band rank adds the density/potential terms.
    # Presumably this keeps comm.sum() below from counting them more
    # than once.
    if ibzwfs.kpt_comm.rank == 0 and ibzwfs.band_comm.rank == 0:
        vHt_h = potential.vHt_x
        assert vHt_h is not None
        pw = vHt_h.desc
        G_Gv = xp.asarray(pw.G_plus_k_Gv)
        # View the complex coefficients as (real, imag) pairs, so that the
        # squared modulus becomes a sum over the trailing z axis.  Use the
        # matching real dtype (as get_kinetic_stress does) rather than a
        # hard-coded float64, which fails to reshape single-precision data.
        vHt_hz = vHt_h.data.view(as_real_dtype(vHt_h.data.dtype))
        vHt2_hz = vHt_hz.reshape((len(G_Gv), 2))**2
        s_vv += (xp.einsum('Gz, Gv, Gw -> vw', vHt2_hz, G_Gv, G_Gv) *
                 pw.dv / (2 * np.pi))
        Q_aL = density.calculate_compensation_charge_coefficients()
        s_vv += pot_calc.poisson_solver.stress_contribution(vHt_h, Q_aL)
        if ibzwfs.domain_comm.rank == 0:
            s_vv -= xp.eye(3) * potential.e_stress
        s_vv += pot_calc.vbar_ag.stress_contribution(nt_g)
        s_vv += density.nct_aX.stress_contribution(vt_g)

        if dedtaut_g is not None:
            s_vv += density.tauct_aX.stress_contribution(dedtaut_g)

    s_vv = as_np(s_vv)

    if xp is not np:
        synchronize()
    comm.sum(s_vv, 0)

    vol = dom.volume
    s_vv = 0.5 / vol * (s_vv + s_vv.T)

    # Symmetrize: average M^T s M over all rotations.  M_vv is built from
    # the integer rotation U_cc plus the cell and inverse cell.  Assuming
    # dom.icell is inv(cell_cv).T, M_vv is U_cc expressed in Cartesian
    # coordinates.
    sigma_vv = np.zeros((3, 3))
    cell_cv = dom.cell_cv
    icell_cv = dom.icell
    rotation_scc = ibzwfs.ibz.symmetries.rotation_scc
    for U_cc in rotation_scc:
        M_vv = (icell_cv.T @ (U_cc @ cell_cv)).T
        sigma_vv += M_vv.T @ s_vv @ M_vv
    sigma_vv /= len(rotation_scc)

    # Extension terms are added after symmetrization.
    sigma_vv += pot_calc.extensions_stress_contribution

    # Make sure all agree on the result (redundant calculation on
    # different cores involving BLAS might give slightly different
    # results):
    comm.broadcast(sigma_vv, 0)
    return sigma_vv
def get_wfs_stress(ibzwfs: IBZWaveFunctions,
                   dH_asii: AtomArrays) -> Array2D:
    """Return the wave-function part of the unnormalized stress tensor.

    Adds the kinetic and PAW contributions for every wave-function object
    yielded by iterating ``ibzwfs``.  Each object is weighted by
    ``weight * spin_degeneracy * myocc_n``.
    """
    xp = ibzwfs.xp
    stress_vv = xp.zeros((3, 3))
    for psi in ibzwfs:
        assert isinstance(psi, PWFDWaveFunctions)
        # Per-band weights: k-point weight x spin degeneracy x occupation.
        raw_n = psi.weight * psi.spin_degeneracy * psi.myocc_n
        f_n = xp.asarray(raw_n)
        stress_vv += get_kinetic_stress(psi, f_n)
        stress_vv += get_paw_stress(psi, dH_asii, f_n)
    return stress_vv
def get_kinetic_stress(wfs: PWFDWaveFunctions,
                       occ_n) -> Array2D:
    """Kinetic contribution to the unnormalized stress tensor.

    Returns ``-f * sum_{n,G} occ_n |c_nG|^2 (G+k)_v (G+k)_w``, where
    ``f`` is ``pw.dv``.  ``f`` is doubled when ``pw.dtype`` is a floating
    dtype.  Presumably that is the real-wave-function case, where only
    half of the G-vectors are stored (confirm).
    """
    psit_nG = wfs.psit_nX
    xp = psit_nG.xp
    pw = psit_nG.desc
    # Split each complex coefficient into (real, imag) along a new last
    # axis z, so that |c|^2 is a sum over z.
    re_im_nGz = psit_nG.data.view(as_real_dtype(pw.dtype))
    re_im_nGz = re_im_nGz.reshape(psit_nG.data.shape + (2,))
    density_G = xp.einsum('n, nGz, nGz -> G', occ_n, re_im_nGz, re_im_nGz)
    kG_Gv = xp.asarray(pw.G_plus_k_Gv)
    if np.issubdtype(pw.dtype, np.floating):
        prefactor = pw.dv * 2
    else:
        prefactor = pw.dv
    return -prefactor * xp.einsum('G, Gv, Gw -> vw', density_G, kG_Gv, kG_Gv)
def get_paw_stress(wfs: PWFDWaveFunctions,
                   dH_asii: AtomArrays,
                   occ_n) -> Array2D:
    """PAW projector contribution to the unnormalized stress tensor.

    For each atom ``a`` in ``wfs.P_ani`` it forms
    ``c_ni = (f_n P_ni) @ dH_ii - (f_n eps_n P_ni) @ dO_ii``, where
    ``f_n = occ_n`` and ``eps_n = wfs.myeig_n``.  The coefficients
    ``2 * conj(c_ni)`` are passed to ``wfs.pt_aiX.stress_contribution``.
    ``Re sum_a vdot(P_ni, c_ni)`` is then subtracted on the diagonal of
    the result.
    """
    xp = wfs.xp
    eps_n1 = xp.asarray(wfs.myeig_n[:, None])
    coef_ani = {}
    trace = 0.0
    for a, P_ni in wfs.P_ani.items():
        fP_ni = P_ni * occ_n[:, None]
        dH_ii = dH_asii[a][wfs.spin]
        dO_ii = xp.asarray(wfs.setups[a].dO_ii)
        # '*' and '@' share precedence and associate left to right, so
        # the eigenvalue weighting is applied before the matrix product.
        c_ni = fP_ni @ dH_ii - (fP_ni * eps_n1) @ dO_ii
        trace += xp.vdot(P_ni, c_ni)
        coef_ani[a] = 2 * c_ni.conj()
    proj_vv = wfs.pt_aiX.stress_contribution(wfs.psit_nX, coef_ani)
    return proj_vv - float(trace.real) * xp.eye(3)