Coverage for gpaw/new/potential.py: 89%

76 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-12 00:18 +0000

1from __future__ import annotations 

2 

3import numpy as np 

4from ase.units import Bohr, Ha 

5 

6from gpaw.core.arrays import DistributedArrays as XArray 

7from gpaw.core.atom_arrays import AtomArrays, AtomDistribution 

8from gpaw.core.domain import Domain as XDesc 

9from gpaw.core import PWArray, UGArray, UGDesc 

10from gpaw.mpi import MPIComm, broadcast_float 

11from gpaw.new import zips 

12 

13 

class Potential:
    """Potential data of a calculation.

    Bundles the potential on a uniform grid (``vt_sR``), the atomic
    Hamiltonian matrices (``dH_asii``) and, optionally, the meta-GGA
    potential (``dedtaut_sR``) and the electrostatic potential
    (``vHt_x``).  Energies are stored internally in Hartree units:
    ``write_to_gpw`` multiplies by ``Ha`` before writing.
    """

    def __init__(self,
                 vt_sR: UGArray,
                 dH_asii: AtomArrays,
                 dedtaut_sR: UGArray | None,
                 vHt_x: XArray | None = None,
                 e_stress: float = np.nan):
        """Store the potential components.

        Parameters
        ----------
        vt_sR:
            Potential on the uniform grid (written as ``potential``).
        dH_asii:
            Atomic Hamiltonian matrices (written as
            ``atomic_hamiltonian_matrices``).
        dedtaut_sR:
            Meta-GGA potential, or ``None`` when not used.
        vHt_x:
            Electrostatic (Hartree) potential on a uniform grid or in a
            plane-wave basis, or ``None``.
        e_stress:
            Isotropic contribution to the stress tensor.  The default is
            NaN; presumably NaN means "not computed" — verify with callers.
        """
        self.vt_sR = vt_sR
        self.dH_asii = dH_asii
        self.dedtaut_sR = dedtaut_sR
        self.vHt_x = vHt_x  # initial guess for Hartree potential
        self.e_stress = e_stress  # isotropic contribution to stress tensor

    def __repr__(self):
        return (f'Potential({self.vt_sR}, {self.dH_asii}, '
                f'{self.dedtaut_sR})')

    def __str__(self) -> str:
        return (f'potential:\n'
                f'  grid points: {self.vt_sR.desc.size}\n')

    def dH(self, P_ani, out_ani, spin):
        """Apply the atomic Hamiltonian matrices to projections.

        Parameters
        ----------
        P_ani:
            Projections.  ``len(P_ani.dims) == 1`` selects the collinear
            path; otherwise the projections carry an extra spinor index
            (``P_nsi[:, 0]`` / ``P_nsi[:, 1]``).
        out_ani:
            Container that receives the result (written in place).
        spin:
            Spin index, forwarded to ``block_diag_multiply`` in the
            collinear case.

        Returns
        -------
        The ``out_ani`` container, in both the collinear and the
        non-collinear case.
        """
        if len(P_ani.dims) == 1:  # collinear wave functions
            P_ani.block_diag_multiply(self.dH_asii, out_ani, spin)
            return out_ani

        # Non-collinear wave functions:
        P_ansi = P_ani
        out_ansi = out_ani

        for (a, P_nsi), out_nsi in zips(P_ansi.items(), out_ansi.values()):
            # The four components v, x, y, z act like the 2x2 spin matrix
            # [[v + z, x - iy], [x + iy, v - z]] on the spinor index.
            v_ii, x_ii, y_ii, z_ii = (dh_ii.T for dh_ii in self.dH_asii[a])
            assert v_ii.dtype == complex
            out_nsi[:, 0] = (P_nsi[:, 0] @ (v_ii + z_ii) +
                             P_nsi[:, 1] @ (x_ii - 1j * y_ii))
            out_nsi[:, 1] = (P_nsi[:, 1] @ (v_ii - z_ii) +
                             P_nsi[:, 0] @ (x_ii + 1j * y_ii))
        return out_ansi

    def move(self, atomdist: AtomDistribution) -> None:
        """Move atoms inplace."""
        self.dH_asii = self.dH_asii.moved(atomdist)

    def redist(self,
               grid: UGDesc,
               desc: XDesc,
               atomdist: AtomDistribution,
               comm1: MPIComm,
               comm2: MPIComm) -> Potential:
        """Return a copy redistributed between ``comm1`` and ``comm2``.

        ``grid`` is the target descriptor for the grid quantities,
        ``desc`` the one for ``vHt_x`` and ``atomdist`` the one for
        ``dH_asii``.  The scalar ``e_stress`` is carried over unchanged
        so it is not lost during redistribution.
        """
        return Potential(
            self.vt_sR.redist(grid, comm1, comm2),
            self.dH_asii.redist(atomdist, comm1, comm2),
            None if self.dedtaut_sR is None else self.dedtaut_sR.redist(
                grid, comm1, comm2),
            None if self.vHt_x is None else self.vHt_x.redist(
                desc, comm1, comm2),
            e_stress=self.e_stress)

    def write_to_gpw(self, writer, flags):
        """Write the potential data with ``writer``.

        Energies are converted from Hartree with ``Ha``, and the meta-GGA
        potential is scaled by ``Bohr**3``.  ``flags.to_storage_dtype`` is
        applied to the grid quantities.  Only the rank on which the gather
        of ``dH_asii`` returns data does any writing.
        """
        # All gathers run on every rank before the early return, so that
        # collective gather operations stay matched across ranks.
        dH_asp = self.dH_asii.to_cpu().to_lower_triangle().gather()
        vt_sR = self.vt_sR.to_xp(np).gather()
        if self.dedtaut_sR is not None:
            dedtaut_sR = self.dedtaut_sR.to_xp(np).gather()
        if self.vHt_x is not None:
            vHt_x = self.vHt_x.to_xp(np).gather()
        if dH_asp is None:
            return

        writer.write(
            potential=flags.to_storage_dtype(vt_sR.data * Ha),
            atomic_hamiltonian_matrices=dH_asp.data * Ha)
        if self.vHt_x is not None:
            vHt_x_data = flags.to_storage_dtype(vHt_x.data * Ha)
            writer.write(electrostatic_potential=vHt_x_data)
        if self.dedtaut_sR is not None:
            dedtaut_sR_data = flags.to_storage_dtype(dedtaut_sR.data * Bohr**3)
            writer.write(mgga_potential=dedtaut_sR_data)

    def get_vacuum_level(self) -> float:
        """Return the vacuum level (multiplied by ``Ha``, i.e. in eV).

        The result depends on the grid and on ``vHt_x``:

        * It is NaN for a fully periodic grid.
        * It is 0.0 if any direction has zero boundary conditions.
        * It is NaN if ``vHt_x`` is neither a ``UGArray`` nor a ``PWArray``
          (TB mode).
        * Otherwise it is the mean of the electrostatic potential over the
          first grid plane of each non-periodic axis, averaged over those
          axes.

        Raises
        ------
        ValueError
            If no electrostatic potential is available.
        """
        grid = self.vt_sR.desc
        if grid.pbc_c.all():
            return np.nan
        if grid.zerobc_c.any():
            return 0.0
        if self.vHt_x is None:
            raise ValueError('No electrostatic potential')
        if isinstance(self.vHt_x, UGArray):
            vHt_r = self.vHt_x.gather()
        elif isinstance(self.vHt_x, PWArray):
            vHt_g = self.vHt_x.gather()
            if vHt_g is not None:
                vHt_r = vHt_g.ifft(grid=vHt_g.desc.minimal_uniform_grid())
            else:
                vHt_r = None
        else:
            return np.nan  # TB-mode
        vacuum_level = 0.0
        # Only the rank holding the gathered array computes the value;
        # the broadcast below makes all ranks agree.
        if vHt_r is not None:
            for c, periodic in enumerate(grid.pbc_c):
                if not periodic:
                    xp = vHt_r.xp
                    vacuum_level += float(xp.moveaxis(vHt_r.data,
                                                      c, 0)[0].mean())

            vacuum_level /= (3 - grid.pbc_c.sum())

        return broadcast_float(vacuum_level, grid.comm) * Ha