Coverage for gpaw/calcinfo.py: 96%

48 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-09 00:21 +0000

1from dataclasses import dataclass 

2from typing import Union 

3 

4from ase import Atoms 

5 

6from gpaw.new.calculation import DFTCalculation 

7from gpaw.new.ibzwfs import IBZ 

8from gpaw.new.logger import Logger 

9from gpaw.core import UGDesc 

10from gpaw.core.domain import Domain 

11from gpaw.setup import Setups 

12from gpaw.mpi import MPIComm 

13from gpaw.dft import Parameters 

14 

15 

@dataclass
class CalcInfo:
    """Summary of a DFT calculation's setup, obtained without running it.

    Instances are normally created by ``get_calculation_info``, which
    stores the ``txt`` and ``comm`` keywords in ``log`` and ``comm``
    rather than in ``input_params``.
    """

    atoms: Atoms                     # the atoms the info was computed for
    input_params: dict               # keyword arguments used for Parameters
    ibz: IBZ                         # irreducible Brillouin-zone k-points
    ncomponents: int                 # number of spin components
    nspins: int                      # number of spin channels
    nbands: int                      # number of bands
    nelectrons: float                # number of electrons
    setups: Setups                   # setups (pseudopotential data)
    grid: UGDesc                     # real-space grid descriptor
    wf_description: Union[Domain, None]                # None for LCAO mode
    communicators: Union[dict[str, MPIComm], None]     # named communicators
    comm: Union[MPIComm, None]       # the ``comm`` keyword, if given
    log: Union[Logger, str, None]    # the ``txt`` keyword, if given

    def update_params(self, **updated_params):
        """Return a new ``CalcInfo`` with ``updated_params`` applied.

        The stored ``log`` and ``comm`` are passed on again as ``txt`` and
        ``comm`` unless they are None or are overridden in
        ``updated_params``.
        """
        carried = {key: value
                   for key, value in (('txt', self.log), ('comm', self.comm))
                   if value is not None and key not in updated_params}
        merged = {**self.input_params, **carried, **updated_params}
        return get_calculation_info(self.atoms, **merged)

    def dft_calculation(self) -> DFTCalculation:
        """Create a ``DFTCalculation`` from a copy of the atoms and the
        stored input parameters, using the stored ``comm`` and ``log``."""
        atoms = self.atoms.copy()
        parameters = Parameters(**self.input_params)
        return DFTCalculation.from_parameters(atoms,
                                              parameters,
                                              comm=self.comm,
                                              log=self.log)

    def ase_calculator(self):
        """Create the DFT calculation and return its ``ase_calculator()``."""
        calculation = self.dft_calculation()
        return calculation.ase_calculator()

49 

50 

def get_calculation_info(atoms: Atoms,
                         **param_kwargs) -> CalcInfo:
    """
    Get information about a calculation, e.g. grid size, IBZ, nbands,
    parallelization, etc. without actually performing the calculation
    or initializing large arrays.

    Parameters
    ----------
    atoms : Atoms
        Atoms object (stored as-is in the result, not copied).
    **param_kwargs :
        Input parameters as keyword arguments.  The keywords ``txt`` and
        ``comm`` are removed from the parameters.  They are passed to the
        component builder as ``log`` and ``comm`` and stored on the
        result.  All remaining keywords are passed to ``Parameters``.

    Returns
    -------
    CalcInfo
        Information about the calculation with the given input parameters.

    CalcInfo attributes
    -------------------
    atoms : Atoms
        Atoms object
    input_params : dict
        Input parameters, excluding ``txt`` and ``comm``
    ibz : IBZ
        IBZ object with information about k-point grid
    ncomponents : int
        Number of spin components
    nspins : int
        Number of spin channels
    nbands : int
        Number of bands
    nelectrons : float
        Number of electrons
    setups : Setups
        Setups object with information about pseudopotentials
    grid : UGDesc
        Grid object with information about the real space grid
    wf_description : Domain or None
        Domain object with information about the wavefunctions
        (None for LCAO calculations)
    communicators : dict
        Dictionary with communicators for k-points, domains and bands
    comm : MPIComm or None
        MPI communicator (the ``comm`` keyword, or None if not given)
    log : Logger, str or None
        Log destination (the ``txt`` keyword, or None if not given)

    CalcInfo methods
    ----------------
    update_params
        Update input parameters and return new CalcInfo object
    dft_calculation
        Return DFTCalculation object with the given input parameters
    ase_calculator
        Return ASE calculator object with the given input parameters
    """
    # ``txt`` and ``comm`` are handled separately from the other input
    # parameters. They go to the builder as keyword arguments and are kept
    # out of ``input_params``, so ``input_params`` can later be passed
    # directly to ``Parameters`` again.
    # ``param_kwargs`` is a fresh dict for every call, so popping is safe.
    log = param_kwargs.pop('txt', None)
    comm = param_kwargs.pop('comm', None)

    builder = Parameters(**param_kwargs).dft_component_builder(
        atoms, comm=comm, log=log)

    # The keyword order below keeps the original order in which builder
    # attributes are accessed.
    return CalcInfo(atoms=atoms,
                    input_params=param_kwargs,
                    ibz=builder.ibz,
                    ncomponents=builder.ncomponents,
                    nspins=builder.nspins,
                    nbands=builder.nbands,
                    nelectrons=builder.nelectrons,
                    setups=builder.setups,
                    grid=builder.grid,
                    communicators=builder.communicators,
                    # Wave-function descriptions are only built for
                    # non-LCAO modes.
                    wf_description=(builder.create_wf_description()
                                    if builder.mode != 'lcao' else None),
                    comm=comm,
                    log=log)