Coverage for gpaw/new/pw/stress.py: 97%

86 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-14 00:18 +0000

1"""PW-mode stress tensor calculation.""" 

2from __future__ import annotations 

3 

4from typing import TYPE_CHECKING 

5 

6import numpy as np 

7from gpaw.core.atom_arrays import AtomArrays 

8from gpaw.gpu import synchronize, as_np 

9from gpaw.new.ibzwfs import IBZWaveFunctions 

10from gpaw.new.pwfd.wave_functions import PWFDWaveFunctions 

11from gpaw.typing import Array2D 

12from gpaw.core import PWArray 

13from gpaw.utilities import as_real_dtype 

14if TYPE_CHECKING: 

15 from gpaw.new.pw.pot_calc import PlaneWavePotentialCalculator 

16 

17 

def calculate_stress(pot_calc: PlaneWavePotentialCalculator,
                     ibzwfs, density, potential,
                     vt_g: PWArray,
                     nt_g: PWArray,
                     dedtaut_g: PWArray | None) -> Array2D:
    """Return the symmetrized PW-mode stress tensor.

    The wave-function (kinetic + PAW) and XC contributions are computed on
    every rank of ``ibzwfs.comm``.  The Hartree, Poisson-solver,
    ``e_stress``, ``vbar``, ``nct`` and optional ``tauct`` terms are only
    added on ranks with ``kpt_comm.rank == 0 and band_comm.rank == 0``
    (presumably so that the following ``comm.sum`` does not count them more
    than once).  The reduced tensor is made symmetric, divided by the cell
    volume, averaged over ``ibzwfs.ibz.symmetries.rotation_scc``, shifted by
    ``pot_calc.extensions_stress_contribution`` and broadcast from rank 0.

    Parameters
    ----------
    pot_calc:
        Supplies the XC functional, the interpolation routine, the Poisson
        solver, ``vbar_ag`` and the extension stress contribution.
    ibzwfs:
        Wave functions; also provides the communicators and symmetries.
    density:
        Provides ``nt_sR`` (array backend and grid descriptor), the
        compensation-charge coefficients and ``nct_aX``/``tauct_aX``.
    potential:
        Provides ``dH_asii``, ``vHt_x`` and ``e_stress``.
    vt_g:
        Passed to ``density.nct_aX.stress_contribution``.
    nt_g:
        Passed to ``pot_calc.vbar_ag.stress_contribution``.
    dedtaut_g:
        Passed to ``density.tauct_aX.stress_contribution``; the term is
        skipped when ``None``.

    Returns
    -------
    A 3x3 NumPy array, identical on all ranks of ``ibzwfs.comm``.
    """
    world = ibzwfs.comm
    xp = density.nt_sR.xp
    grid = density.nt_sR.desc

    ibzwfs.make_sure_wfs_are_read_from_gpw_file()
    stress_vv = get_wfs_stress(ibzwfs, potential.dH_asii)
    stress_vv += pot_calc.xc.stress_contribution(
        ibzwfs, density, pot_calc.interpolate)

    owns_density_terms = (ibzwfs.kpt_comm.rank == 0 and
                          ibzwfs.band_comm.rank == 0)
    if owns_density_terms:
        vHt_h = potential.vHt_x
        assert vHt_h is not None
        hartree_pw = vHt_h.desc
        G_Gv = xp.asarray(hartree_pw.G_plus_k_Gv)
        # Squared (real, imag) parts of each Hartree coefficient.
        # NOTE(review): view(float) assumes double-precision complex data.
        vHt2_hz = vHt_h.data.view(float).reshape((len(G_Gv), 2))**2
        hartree_vv = xp.einsum('Gz, Gv, Gw -> vw', vHt2_hz, G_Gv, G_Gv)
        stress_vv += hartree_vv * hartree_pw.dv / (2 * np.pi)
        Q_aL = density.calculate_compensation_charge_coefficients()
        stress_vv += pot_calc.poisson_solver.stress_contribution(vHt_h, Q_aL)
        if ibzwfs.domain_comm.rank == 0:
            # Diagonal term added by a single domain rank only.
            stress_vv -= xp.eye(3) * potential.e_stress
        stress_vv += pot_calc.vbar_ag.stress_contribution(nt_g)
        stress_vv += density.nct_aX.stress_contribution(vt_g)
        if dedtaut_g is not None:
            stress_vv += density.tauct_aX.stress_contribution(dedtaut_g)

    # Bring the partial tensor to the host before the MPI reduction.
    stress_vv = as_np(stress_vv)
    if xp is not np:
        synchronize()
    world.sum(stress_vv, 0)

    volume = grid.volume
    stress_vv = 0.5 / volume * (stress_vv + stress_vv.T)

    # Average over the point-group rotations (given in scaled coordinates).
    sigma_vv = np.zeros((3, 3))
    cell_cv = grid.cell_cv
    icell_cv = grid.icell
    rotation_scc = ibzwfs.ibz.symmetries.rotation_scc
    for U_cc in rotation_scc:
        M_vv = (icell_cv.T @ (U_cc @ cell_cv)).T
        sigma_vv += M_vv.T @ stress_vv @ M_vv
    sigma_vv /= len(rotation_scc)

    sigma_vv += pot_calc.extensions_stress_contribution

    # Redundant work on different cores (e.g. BLAS calls) may differ in the
    # last bits, so every rank adopts rank 0's result:
    world.broadcast(sigma_vv, 0)
    return sigma_vv

77 

78 

def get_wfs_stress(ibzwfs: IBZWaveFunctions,
                   dH_asii: AtomArrays) -> Array2D:
    """Sum kinetic and PAW stress over the wave functions in ``ibzwfs``.

    Each wave-function object yielded by iterating ``ibzwfs`` must be a
    ``PWFDWaveFunctions``.  Its bands are weighted by
    ``weight * spin_degeneracy * myocc_n``.  The result is a 3x3 array on
    the ``ibzwfs.xp`` backend; no MPI reduction is done here.
    """
    xp = ibzwfs.xp
    total_vv = xp.zeros((3, 3))
    for psi in ibzwfs:
        assert isinstance(psi, PWFDWaveFunctions)
        weights = psi.weight * psi.spin_degeneracy * psi.myocc_n
        f_n = xp.asarray(weights)
        # Accumulate one term at a time to keep the summation order fixed.
        total_vv += get_kinetic_stress(psi, f_n)
        total_vv += get_paw_stress(psi, dH_asii, f_n)
    return total_vv

89 

90 

def get_kinetic_stress(wfs: PWFDWaveFunctions,
                       occ_n) -> Array2D:
    """Kinetic-energy stress of one set of plane-wave wave functions.

    Returns ``-dv * sum_{n,G} f_n |c_nG|^2 (G+k)_v (G+k)_w``.  Here
    ``|c_nG|^2`` is the sum of squares of the two components obtained by
    viewing the coefficient array as ``as_real_dtype(dtype)`` pairs.  For a
    real ``dtype`` the prefactor is doubled; presumably only half of the
    plane waves are stored in that case (TODO confirm).
    """
    coefs_nG = wfs.psit_nX
    desc = coefs_nG.desc
    xp = coefs_nG.xp
    real_dtype = as_real_dtype(desc.dtype)
    # Reinterpret each coefficient as a (component 0, component 1) pair.
    coefs_nGz = coefs_nG.data.view(real_dtype).reshape(
        coefs_nG.data.shape + (2,))
    weight_G = xp.einsum('n, nGz, nGz -> G', occ_n, coefs_nGz, coefs_nGz)
    Gk_Gv = xp.asarray(desc.G_plus_k_Gv)
    sigma_vv = xp.einsum('G, Gv, Gw -> vw', weight_G, Gk_Gv, Gk_Gv)
    if np.issubdtype(desc.dtype, np.floating):
        prefactor = 2 * desc.dv
    else:
        prefactor = desc.dv
    return -prefactor * sigma_vv

105 

106 

def get_paw_stress(wfs: PWFDWaveFunctions,
                   dH_asii: AtomArrays,
                   occ_n) -> Array2D:
    """PAW-correction stress of one set of wave functions.

    For every atom ``a`` with projections ``P_ni`` this forms
    ``a_ni = (f_n P_ni) @ dH_ii - (f_n eps_n P_ni) @ dO_ii``.  Here
    ``dH_ii`` is ``dH_asii[a][wfs.spin]`` and ``dO_ii`` comes from
    ``wfs.setups[a]``.  The arrays ``2 * conj(a_ni)`` are passed to
    ``wfs.pt_aiX.stress_contribution``.  The real part of
    ``sum_a vdot(P_ni, a_ni)`` is then subtracted from the diagonal of that
    result.
    """
    xp = wfs.xp
    eig_n1 = xp.asarray(wfs.myeig_n[:, None])
    coef_ani = {}
    diag = 0.0
    for a, P_ni in wfs.P_ani.items():
        weighted_ni = P_ni * occ_n[:, None]
        dH_ii = dH_asii[a][wfs.spin]
        dS_ii = xp.asarray(wfs.setups[a].dO_ii)
        a_ni = weighted_ni @ dH_ii - (weighted_ni * eig_n1) @ dS_ii
        diag += xp.vdot(P_ni, a_ni)
        coef_ani[a] = 2 * a_ni.conj()
    projector_vv = wfs.pt_aiX.stress_contribution(wfs.psit_nX, coef_ani)
    return projector_vv - float(diag.real) * xp.eye(3)