Coverage for gpaw/nlopt/adapters.py: 84%

64 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-14 00:18 +0000

1from __future__ import annotations 

2from types import SimpleNamespace 

3from typing import TYPE_CHECKING 

4 

5import numpy as np 

6 

7if TYPE_CHECKING: 

8 from gpaw.core.atom_arrays import AtomArrays 

9 from gpaw.core.plane_waves import PWArray 

10 from gpaw.new.ase_interface import ASECalculator 

11 from gpaw.new.pwfd.wave_functions import PWFDWaveFunctions 

12 from gpaw.typing import ArrayND 

13 

14 

class GSInfo:
    """
    This is the base class for the ground state adapters in the non-linear
    optics module. It is only compatible with GPAW_NEW.

    The class should never be instantiated directly; use the
    CollinearGSInfo or NoncollinearGSInfo subclasses instead.

    These subclasses are necessary due to the different ways in which the
    spin index is handled in collinear and noncollinear ground state
    calculations. Subclasses implement ``get_wfs``, ``_pw_data`` and
    ``_proj_data``.

    Attributes set by ``__init__``:

    nabla_aiiv
        ``setup.nabla_iiv`` for every setup of the calculation.
    ibzwfs
        The calculator's ``dft.ibzwfs`` object.
    collinear, ndensities
        Copied from the calculator's density.
    ucvol
        ``|det(cell)|`` of the density grid descriptor.
    bzvol
        ``|det(2 * pi * icell)|`` of the density grid descriptor.
    """

    def __init__(self,
                 calc: ASECalculator):
        """
        Collect the ground-state data used by the non-linear optics code.

        Raises
        ------
        ValueError
            If the calculator is not in plane-wave mode, is parallelised
            over domains or bands, or holds no wave-function data.
        """
        # Explicit check rather than ``assert``: assert statements are
        # stripped under ``python -O`` and would let non-PW calculators pass.
        if calc.params.mode.name != 'pw':
            raise ValueError('Calculator must be in plane wave mode.')

        dft = calc.dft
        self.nabla_aiiv = [setup.nabla_iiv for setup in dft.setups]

        ibzwfs = self.ibzwfs = dft.ibzwfs
        if not (ibzwfs.domain_comm.size == 1 and ibzwfs.band_comm.size == 1):
            raise ValueError('Calculator must be initialised with '
                             'only k-point parallelisation.')
        # A SimpleNamespace in place of the coefficient array signals that
        # the wave-function data is missing (e.g. a .gpw file without it).
        if isinstance(ibzwfs.wfs_qs[0][0].psit_nX, SimpleNamespace):
            raise ValueError('Calculator is missing wfs data. If loading from '
                             'a .gpw file, please recalculate wave functions.')

        density = dft.density
        self.collinear = density.collinear
        self.ndensities = density.ndensities

        grid = density.nt_sR.desc
        self.ucvol = np.abs(np.linalg.det(grid.cell))
        self.bzvol = np.abs(np.linalg.det(2 * np.pi * grid.icell))

    def get_plane_wave_coefficients(self,
                                    wfs: PWFDWaveFunctions,
                                    bands: slice,
                                    spin: int) -> tuple[ArrayND, ArrayND]:
        """
        Return the reciprocal vectors and plane-wave coefficients of ``bands``.

        Returns a tuple ``(G_plus_k_Gv, coefficients)``:

        G_plus_k_Gv
            The ``G + k`` vectors of the wave-function descriptor.
        coefficients
            The coefficient array for ``spin``, as extracted by the
            subclass's ``_pw_data``; shape (band index, reciprocal vector
            index).
        """
        psit_nG = wfs.psit_nX[bands]
        G_plus_k_Gv = psit_nG.desc.G_plus_k_Gv
        return G_plus_k_Gv, self._pw_data(psit_nG, spin)

    def get_wave_function_projections(self,
                                      wfs: PWFDWaveFunctions,
                                      bands: slice,
                                      spin: int) -> dict[int, ArrayND]:
        """
        Return the projections of the pseudo wfs onto the partial waves.

        Output is a dictionary with atom index keys and array values with
        shape (band index, partial wave index).
        """
        return self._proj_data(wfs.P_ani, bands, spin)

    def get_wfs(self,
                wfs_s: list[PWFDWaveFunctions],
                spin: int) -> PWFDWaveFunctions:
        """Select the wave functions to use for ``spin`` (subclass hook)."""
        raise NotImplementedError

    @staticmethod
    def _pw_data(psit: PWArray,
                 spin: int) -> ArrayND:
        """Extract the coefficient array for ``spin`` (subclass hook)."""
        raise NotImplementedError

    @staticmethod
    def _proj_data(P: AtomArrays,
                   bands: slice,
                   spin: int) -> dict[int, ArrayND]:
        """Extract per-atom projections for ``bands``/``spin`` (subclass hook)."""
        raise NotImplementedError

90 

91 

class CollinearGSInfo(GSInfo):
    """
    Ground state adapter for collinear calculations.

    Each spin channel has its own wave-function object, so the spin index
    is used to pick an entry of ``wfs_s``; the coefficient and projection
    arrays of that object are returned without any further spin slicing.
    """

    def __init__(self,
                 calc: ASECalculator):
        super().__init__(calc)
        # ns is taken directly from the density's number of components.
        self.ns = self.ndensities

    def get_wfs(self,
                wfs_s: list[PWFDWaveFunctions],
                spin: int) -> PWFDWaveFunctions:
        """Return the entry of ``wfs_s`` belonging to spin channel ``spin``."""
        selected_wfs = wfs_s[spin]
        return selected_wfs

    @staticmethod
    def _pw_data(psit_nG: PWArray,
                 _: int | None = None) -> ArrayND:
        """Return the raw coefficient array; the spin argument is ignored."""
        coefficients = psit_nG.data
        return coefficients

    @staticmethod
    def _proj_data(P_ani: AtomArrays,
                   bands: slice,
                   _: int | None = None) -> dict[int, ArrayND]:
        """Return ``{atom index: P_ni[bands]}``; the spin argument is ignored."""
        projections = {}
        for atom, P_ni in P_ani.items():
            projections[atom] = P_ni[bands]
        return projections

113 

114 

class NoncollinearGSInfo(GSInfo):
    """
    Ground state adapter for noncollinear calculations.

    All spin data lives in the first wave-function object of ``wfs_s``;
    the spin index is applied as the second axis when slicing the
    coefficient and projection arrays.
    """

    def __init__(self,
                 calc: ASECalculator):
        super().__init__(calc)
        # Fixed at two, independent of the density's number of components.
        self.ns = 2

    def get_wfs(self,
                wfs_s: list[PWFDWaveFunctions],
                _: int | None = None) -> PWFDWaveFunctions:
        """Return the first entry of ``wfs_s``; the spin argument is ignored."""
        only_wfs = wfs_s[0]
        return only_wfs

    @staticmethod
    def _pw_data(psit_nsG: PWArray,
                 spin: int) -> ArrayND:
        """Return the coefficients of ``spin`` (second axis of the data)."""
        data_nsG = psit_nsG.data
        return data_nsG[:, spin]

    @staticmethod
    def _proj_data(P_ansi: AtomArrays,
                   bands: slice,
                   spin: int) -> dict[int, ArrayND]:
        """Return ``{atom index: P_nsi[bands, spin]}`` for every atom."""
        projections = {}
        for atom, P_nsi in P_ansi.items():
            projections[atom] = P_nsi[bands, spin]
        return projections