Coverage for gpaw/hybrids/forces.py: 100%

71 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-12 00:18 +0000

1# from typing import Tuple 

2 

3import numpy as np 

4 

5from gpaw.kpt_descriptor import KPointDescriptor 

6from gpaw.mpi import serial_comm, broadcast 

7from gpaw.pw.descriptor import PWDescriptor 

8from gpaw.pw.lfc import PWLFC 

9from .kpts import get_kpt 

10 

11 

def calculate_forces(wfs, coulomb, sym, paw_s, ftol=1e-9) -> np.ndarray:
    """Compute the hybrid-functional force term for every atom.

    For each spin channel, every IBZ k-point is fetched with
    :func:`get_kpt` (restricted to the first ``nocc`` bands), given the
    projector derivatives of all atoms, and the resulting list is handed
    to :func:`forces`, which accumulates into one shared ``(natoms, 3)``
    array.  That array, divided by the number of spins, is returned.

    Parameters
    ----------
    wfs : wave-function object
        Supplies ``kd``, ``nspins``, ``kpt_u``, ``pt``, ``spos_ac`` and
        ``world``.
    coulomb, sym :
        Forwarded unchanged to :func:`forces`.
    paw_s : sequence
        Per-spin PAW data; ``paw_s[spin]`` is forwarded to :func:`forces`.
    ftol : float
        A band counts as occupied when ``f_n / weight > ftol``.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(natoms, 3)``.
    """
    kd = wfs.kd
    nspins = wfs.nspins
    natoms = len(wfs.spos_ac)

    # Largest number of occupied bands on any local k-point, then the
    # maximum over kd.comm so that every rank uses the same band count.
    # NOTE(review): this is a *count* of occupied bands, and bands are
    # later taken as [:nocc]; that presumes occupied bands come first
    # (no holes below the highest occupied band) -- confirm.
    local_nocc = max(np.count_nonzero(kpt.f_n / kpt.weight > ftol)
                     for kpt in wfs.kpt_u)
    nocc = kd.comm.max_scalar(int(local_nocc))

    # Projector derivatives from wfs.pt.derivative for the first nocc
    # bands of each local k-point, keyed by (spin, k).  Each value is
    # looked up per atom with .get(a) below.
    dPdR_skaniv = {}
    for kpt in wfs.kpt_u:
        dPdR_skaniv[(kpt.s, kpt.k)] = wfs.pt.derivative(kpt.psit_nG[:nocc],
                                                        q=kpt.k)

    F_av = np.zeros((natoms, 3))
    for spin in range(nspins):
        spin_kpts = []
        for k in range(kd.nibzkpts):
            kpt = get_kpt(wfs, k, spin, 0, nocc)
            owner_a = kpt.proj.atom_partition.rank_a
            # Give every rank the derivatives of all atoms: each atom's
            # array is broadcast over wfs.world from the rank that owns
            # that atom in the projections' atom partition.
            # NOTE(review): the (spin, k) lookup runs on every rank, which
            # presumes each rank's kpt_u covers all (spin, k) pairs (no
            # k-point/spin parallelization).  Also, rank_a holds ranks of
            # the atom partition's communicator while the broadcast runs
            # over wfs.world -- presumably these coincide; confirm.
            kpt.dPdR_aniv = [broadcast(dPdR_skaniv[(spin, k)].get(a),
                                       owner_a[a],
                                       wfs.world)
                             for a in range(natoms)]
            spin_kpts.append(kpt)
        # forces() adds this spin's contribution into F_av in place.
        forces(spin_kpts, paw_s[spin], wfs, sym, coulomb, F_av)

    return F_av / nspins

41 

42 

def forces(kpts, paw, wfs, sym, coulomb, F_av):
    """Add one spin channel's force contributions to ``F_av`` in place.

    Two terms are accumulated:

    * a sum over the k-point pairs yielded by ``sym.pairs``.  For each
      pair, a plane-wave descriptor at ``-(k2 - k1)`` is built on a serial
      grid descriptor, together with a ``ghat`` LFC placed at the atomic
      positions and the ``coulomb`` potential on that descriptor.  The
      pair is then evaluated by :func:`calculate_exx_for_pair` and the
      result is scaled by ``1 / nbzkpts**2``;
    * a per-atom PAW term built from ``8 * paw.VV_aii[a] +
      4 * paw.VC_aii[a]``, contracted with each k-point's projector
      derivatives (``dPdR_aniv``), projections and occupations and
      weighted by ``kpt.weight``.

    Parameters
    ----------
    kpts : list
        K-point objects for one spin.  Each needs ``psit.pd``, ``proj``,
        ``f_n``, ``weight`` and ``dPdR_aniv``; the pair term additionally
        uses ``k_c`` of the pair members.
    paw :
        PAW data for this spin (``VV_aii``, ``VC_aii``; also forwarded to
        :func:`calculate_exx_for_pair`).
    wfs :
        Wave-function object (``world``, ``setups``, ``spos_ac``, ``kd``).
    sym :
        Provides ``pairs(kpts, wfs, spos_ac)``.
    coulomb :
        Provides ``get_potential(pd)``.
    F_av : numpy.ndarray
        ``(natoms, 3)`` accumulator, modified in place; nothing is
        returned.
    """
    base_pd = kpts[0].psit.pd
    # Pair quantities use a grid descriptor without domain decomposition.
    serial_gd = base_pd.gd.new_descriptor(comm=serial_comm)
    world = wfs.world

    pair_iter = sym.pairs(kpts, wfs, wfs.spos_ac)
    for _i1, _i2, _s, k1, k2, count in pair_iter:
        q_c = k2.k_c - k1.k_c
        pair_pd = PWDescriptor(base_pd.ecut, serial_gd, base_pd.dtype,
                               kd=KPointDescriptor([-q_c]))

        ghat = PWLFC([setup.ghat_l for setup in wfs.setups], pair_pd)
        ghat.set_positions(wfs.spos_ac)

        pair_f_av = calculate_exx_for_pair(k1, k2, ghat,
                                           coulomb.get_potential(pair_pd),
                                           world, paw, count)
        F_av += pair_f_av * (1 / wfs.kd.nbzkpts**2)

    for a, VV_ii in paw.VV_aii.items():
        combined_ii = 8 * VV_ii + 4 * paw.VC_aii[a]
        for kpt in kpts:
            paw_v = np.einsum('ij, niv, nj, n -> v',
                              combined_ii,
                              kpt.dPdR_aniv[a].conj(),
                              kpt.proj[a],
                              kpt.f_n).real
            F_av[a] -= paw_v * kpt.weight

69 

70 

def calculate_exx_for_pair(k1,
                           k2,
                           ghat,
                           v_G,
                           comm,
                           paw,
                           count) -> np.ndarray:
    """Return the ``(natoms, 3)`` force contribution of one k-point pair.

    For every band ``n1`` of ``k1`` and every ``k2`` band ``n2`` assigned
    to this rank:

    1. the product ``u1_R * conj(u2_R)`` is Fourier transformed with
       ``ghat.pd.fft``;
    2. ``ghat`` functions with coefficients ``Q_annL`` are added;
    3. the result is multiplied by ``v_G``.

    Two terms are then subtracted from the forces, both weighted by
    ``f1 * f2 * 2 * count``:

    * the ``ghat.derivative`` term, contracted with ``conj(Q_annL)``;
    * the ``ghat.integrate`` term, mapped through ``paw.Delta_aiiL`` and
      contracted with ``conj(k1.dPdR_aniv)`` and ``k2.proj``, with an
      extra factor of 2.

    Parameters
    ----------
    k1, k2 :
        K-point objects with ``u_nR``, ``proj`` and ``f_n``; ``k1`` must
        also carry ``dPdR_aniv``.
    ghat :
        LFC object on the pair's plane-wave descriptor (``ghat.pd``).
    v_G :
        Array multiplied onto the transformed, augmented pair products.
    comm :
        Communicator, used only to split the ``n2`` bands over its ranks
        when ``k1 is k2``.
        NOTE(review): for ``k1 is not k2`` every rank evaluates all ``n2``
        bands; presumably the caller distributes such pairs -- confirm.
    paw :
        PAW data providing ``Delta_aiiL``.
    count :
        Weight applied to this pair (presumably its symmetry
        multiplicity -- confirm against ``sym.pairs``).

    Returns
    -------
    numpy.ndarray
        Array of shape ``(len(paw.Delta_aiiL), 3)``.
    """
    N1 = len(k1.u_nR)
    N2 = len(k2.u_nR)

    # Q_annL[a][n1, n2, L]: projections of k1 and conj(k2) contracted
    # through Delta_iiL for atom a.
    Q_annL = [np.einsum('mi, ijL, nj -> mnL',
                        k1.proj[a],
                        Delta_iiL,
                        k2.proj[a].conj())
              for a, Delta_iiL in enumerate(paw.Delta_aiiL)]
    F_av = np.zeros((len(Q_annL), 3))

    # The n2 range does not depend on n1, so it is decided once here.  For
    # a diagonal pair (k1 is k2) the bands are split over the ranks of
    # comm in contiguous blocks of ceil(N1 / comm.size); otherwise this
    # rank handles all of k2's bands.
    if k1 is k2:
        block = (N1 + comm.size - 1) // comm.size
        n2a = min(comm.rank * block, N2)
        n2b = min(n2a + block, N2)
    else:
        n2a = 0
        n2b = N2
    nbands2 = n2b - n2a

    if nbands2 == 0:
        # No n2 bands on this rank (e.g. more ranks than bands): every
        # einsum below would run over an empty band axis, so the
        # contribution is exactly zero.
        return F_av

    # Buffer sized exactly to this rank's n2 bands; reused for every n1.
    rho_nG = ghat.pd.empty(nbands2, k1.u_nR.dtype)
    f2_n = k2.f_n[n2a:n2b]
    u2_nR = k2.u_nR[n2a:n2b]
    proj2 = {}  # per-atom k2 projections for the n2 range, filled lazily

    for n1, u1_R in enumerate(k1.u_nR):
        ff_n = k1.f_n[n1] * f2_n * 2 * count

        for rho_G, u2_R in zip(rho_nG, u2_nR):
            rho_G[:] = ghat.pd.fft(u1_R * u2_R.conj())

        Q_nL_a = {a: Q_nnL[n1, n2a:n2b] for a, Q_nnL in enumerate(Q_annL)}
        ghat.add(rho_nG, Q_nL_a)

        # From here on rho_nG holds v_G times the augmented pair products.
        rho_nG *= v_G

        for a, v_nLv in ghat.derivative(rho_nG).items():
            F_av[a] -= np.einsum('n, nL, nLv -> v',
                                 ff_n,
                                 Q_nL_a[a].conj(),
                                 v_nLv).real

        for a, v_nL in ghat.integrate(rho_nG).items():
            if a not in proj2:
                proj2[a] = k2.proj[a][n2a:n2b]
            v_iin = paw.Delta_aiiL[a].dot(v_nL.T)
            F_av[a] -= np.einsum('ijn, iv, nj, n -> v',
                                 v_iin,
                                 k1.dPdR_aniv[a][n1].conj(),
                                 proj2[a],
                                 ff_n).real * 2

    return F_av