Coverage for gpaw/calcinfo.py: 96%
48 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-09 00:21 +0000
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-09 00:21 +0000
1from dataclasses import dataclass
2from typing import Union
4from ase import Atoms
6from gpaw.new.calculation import DFTCalculation
7from gpaw.new.ibzwfs import IBZ
8from gpaw.new.logger import Logger
9from gpaw.core import UGDesc
10from gpaw.core.domain import Domain
11from gpaw.setup import Setups
12from gpaw.mpi import MPIComm
13from gpaw.dft import Parameters
@dataclass
class CalcInfo:
    """Lightweight summary of a DFT calculation's setup.

    Instances are normally produced by ``get_calculation_info`` and hold
    the quantities a component builder determined (k-points, bands, grid,
    setups, parallelization) without running the calculation itself.
    """

    atoms: Atoms
    # Parameters passed to ``Parameters(...)``; ``txt``/``comm`` are kept
    # separately in ``log``/``comm`` when built by get_calculation_info.
    input_params: dict
    ibz: IBZ
    ncomponents: int
    nspins: int
    nbands: int
    nelectrons: float
    setups: Setups
    grid: UGDesc
    # None when the calculation mode is 'lcao'.
    wf_description: Union[Domain, None]
    communicators: Union[dict[str, MPIComm], None]
    comm: Union[MPIComm, None]
    log: Union[Logger, str, None]

    def update_params(self, **updated_params):
        """Return a fresh CalcInfo with some input parameters replaced.

        The stored ``log`` and ``comm`` are forwarded as ``txt``/``comm``
        unless they are ``None`` or explicitly overridden by the caller.
        """
        carried_over = {}
        if 'txt' not in updated_params and self.log is not None:
            carried_over['txt'] = self.log
        if 'comm' not in updated_params and self.comm is not None:
            carried_over['comm'] = self.comm
        # Later mappings win: caller overrides > carried log/comm > stored.
        merged = {**self.input_params, **carried_over, **updated_params}
        return get_calculation_info(self.atoms, **merged)

    def dft_calculation(self) -> DFTCalculation:
        """Build a DFTCalculation from a copy of the atoms and the params."""
        atoms_copy = self.atoms.copy()
        parameters = Parameters(**self.input_params)
        return DFTCalculation.from_parameters(atoms_copy,
                                              parameters,
                                              comm=self.comm,
                                              log=self.log)

    def ase_calculator(self):
        """Return the ASE calculator of a newly built DFTCalculation."""
        calculation = self.dft_calculation()
        return calculation.ase_calculator()
def get_calculation_info(atoms: Atoms,
                         **param_kwargs) -> CalcInfo:
    """
    Get information about a calculation, e.g. grid size, IBZ, nbands,
    parallelization, etc. without actually performing the calculation
    or initializing large arrays.

    Parameters
    ----------
    atoms : Atoms
        Atoms object
    **param_kwargs :
        Input parameters as keyword arguments.  The keys ``txt`` and
        ``comm`` are removed before the remaining parameters are passed
        to ``Parameters``; they are given to the component builder as
        ``log`` and ``comm`` and stored on the result as
        ``CalcInfo.log`` and ``CalcInfo.comm``.

    Returns
    -------
    CalcInfo
        Information about the calculation with the given input parameters.

    CalcInfo attributes
    -------------------
    atoms : Atoms
        Atoms object
    input_params : dict
        Input parameters (without ``txt`` and ``comm``)
    ibz : IBZ
        IBZ object with information about k-point grid
    ncomponents : int
        Number of spin components
    nspins : int
        Number of spin channels
    nbands : int
        Number of bands
    nelectrons : float
        Number of electrons, as determined by the component builder
    setups : Setups
        Setups object with information about pseudopotentials
    grid : UGDesc
        Grid object with information about the real space grid
    wf_description : Domain or None
        Domain object with information about the wavefunctions
        (None for LCAO calculations)
    communicators : dict
        Dictionary with communicators for k-points, domains and bands
    comm : MPIComm or None
        MPI communicator given as the ``comm`` input parameter, if any
    log : Logger, str or None
        Value given as the ``txt`` input parameter, if any

    CalcInfo methods
    ----------------
    update_params
        Update input parameters and return new CalcInfo object
    dft_calculation
        Return DFTCalculation object with the given input parameters
    ase_calculator
        Return ASE calculator for the given input parameters
    """
    # 'txt' and 'comm' are handled separately: they go to the builder as
    # log/comm instead of into Parameters(**param_kwargs).
    log = param_kwargs.pop('txt', None)
    comm = param_kwargs.pop('comm', None)
    dft_builder = Parameters(**param_kwargs).dft_component_builder(
        atoms, comm=comm, log=log)
    dft_params = CalcInfo(atoms=atoms,
                          input_params=param_kwargs,
                          ibz=dft_builder.ibz,
                          ncomponents=dft_builder.ncomponents,
                          nspins=dft_builder.nspins,
                          nbands=dft_builder.nbands,
                          nelectrons=dft_builder.nelectrons,
                          setups=dft_builder.setups,
                          grid=dft_builder.grid,
                          communicators=dft_builder.communicators,
                          # Only non-LCAO modes get a wavefunction domain.
                          wf_description=dft_builder.create_wf_description()
                          if dft_builder.mode != 'lcao' else None,
                          comm=comm,
                          log=log)
    return dft_params