Coverage for gpaw/core/domain.py: 89%

82 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-14 00:18 +0000

1from __future__ import annotations 

2 

3from typing import TYPE_CHECKING, Sequence, Literal, Generic, TypeVar 

4 

5import numpy as np 

6from ase.geometry.cell import cellpar_to_cell 

7 

8from gpaw.fftw import get_efficient_fft_size 

9from gpaw.mpi import MPIComm, serial_comm 

10from gpaw.typing import (Array2D, ArrayLike, ArrayLike1D, ArrayLike2D, 

11 DTypeLike, Vector, Self) 

12 

13if TYPE_CHECKING: 

14 from gpaw.core import UGDesc 

15 from gpaw.core.arrays import DistributedArrays 

16 

17 

def normalize_cell(cell: ArrayLike) -> Array2D:
    """Convert a cell specification to a 3x3 float matrix of cell vectors.

    Accepted forms:

    * a 3x3 matrix (returned as a new float array),
    * three numbers, used as the diagonal of an orthorhombic cell,
    * six cell parameters ``(a, b, c, alpha, beta, gamma)``, converted
      with :func:`ase.geometry.cell.cellpar_to_cell`.

    >>> normalize_cell([1, 2, 3])
    array([[1., 0., 0.],
           [0., 2., 0.],
           [0., 0., 3.]])

    Raises
    ------
    ValueError
        If *cell* does not have one of the shapes listed above.
    """
    # np.array always copies, so the result never aliases the caller's data.
    cell = np.array(cell, float)
    if cell.shape == (3, 3):
        return cell
    if cell.shape == (3,):
        return np.diag(cell)
    if cell.shape == (6,):
        return cellpar_to_cell(cell)
    raise ValueError(
        'Cell must be a 3x3 matrix, three lengths or six cell parameters, '
        f'got shape {cell.shape}')

32 

33 

# Type of the distributed-array objects a Domain creates (the return type
# of Domain.empty() and Domain.zeros()); bound to DistributedArrays.
XArray = TypeVar('XArray', bound='DistributedArrays')

35 

36 

class Domain(Generic[XArray]):
    """Base class holding the cell, boundary conditions, k-point,
    communicator and dtype shared by concrete domain descriptors.

    Subclasses implement :meth:`new`, :meth:`global_shape` and
    :meth:`empty`.
    """

    # Size in bytes of one array element; not set in this class.
    # NOTE(review): presumably assigned by subclasses -- confirm.
    itemsize: int

    def __init__(self,
                 cell: ArrayLike1D | ArrayLike2D,
                 pbc=(True, True, True),
                 kpt: Vector | None = None,
                 comm: MPIComm = serial_comm,
                 dtype: DTypeLike | None = None):
        """Create a domain.

        Parameters
        ----------
        cell:
            Unit cell in any form accepted by :func:`normalize_cell`
            (3x3 matrix, three diagonal lengths or six cell parameters).
        pbc:
            Periodic boundary conditions along the three cell vectors.
            A single bool/int is used for all three directions.
        kpt:
            k-point with three components (presumably in scaled
            coordinates -- confirm).  ``None`` means (0, 0, 0).
        comm:
            MPI communicator stored on the domain.
        dtype:
            float32, float64, complex64 or complex128.  Defaults to
            complex when *kpt* is given and to float otherwise.

        Raises
        ------
        ValueError
            If *dtype* is unsupported, if *dtype* is not complex for a
            non-zero k-point, or if the k-point has a non-zero component
            along a non-periodic direction.
        """
        if isinstance(pbc, int):
            # bool is a subclass of int, so True/False are broadcast too.
            pbc = (pbc,) * 3
        self.cell_cv = normalize_cell(cell)
        self.pbc_c = np.array(pbc, bool)
        self.comm = comm

        self.volume = abs(np.linalg.det(self.cell_cv))
        # Orthogonal means every off-diagonal element of the cell is zero.
        self.orthogonal = not (self.cell_cv -
                               np.diag(self.cell_cv.diagonal())).any()

        # np.dtype(None) is float64, so an unspecified dtype passes here.
        # Raise instead of assert so the check survives ``python -O``.
        if np.dtype(dtype) not in [np.float32, np.float64,
                                   np.complex64, np.complex128]:
            raise ValueError(f'Unsupported dtype: {dtype}')

        # XXX: Gotta be careful about precision here:
        if kpt is None:
            kpt = (0.0, 0.0, 0.0)
            if dtype is None:
                dtype = float
        elif dtype is None:
            dtype = complex

        self.kpt_c = np.array(kpt, float)
        self.dtype = np.dtype(dtype)  # type: ignore

        if self.kpt_c.any():
            # Test the dtype *kind* rather than comparing with ``float``:
            # the old comparison let float32 (or 'float64' given as a
            # string) through for a non-zero k-point.
            if self.dtype.kind != 'c':
                raise ValueError(f'dtype must be complex for kpt={kpt}')
            for periodic, k in zip(self.pbc_c, self.kpt_c):
                if not periodic and k != 0:
                    raise ValueError(f'Bad k-point {kpt} for pbc={pbc}')

        # Declared for type checkers only; not assigned in this class.
        self.myshape: tuple[int, ...]

    def new(self,
            *,
            kpt=None,
            dtype=None,
            comm: MPIComm | Literal['inherit'] | None = 'inherit'
            ) -> Self:
        """Return a new domain of the same type.

        Must be implemented by subclasses.  NOTE(review): presumably the
        keyword arguments replace the corresponding attributes, with
        ``comm='inherit'`` keeping the current communicator -- confirm in
        the subclasses.
        """
        raise NotImplementedError

    def __repr__(self):
        comm = self.comm
        k = f', kpt={self.kpt_c.tolist()}' if self.kpt_c.any() else ''
        # Show only the diagonal when the cell matrix is diagonal.
        if (self.cell_cv == np.diag(self.cell_cv.diagonal())).all():
            cell = self.cell_cv.diagonal().tolist()
        else:
            cell = self.cell_cv.tolist()
        return (f'Domain(cell={cell}, '
                f'pbc={self.pbc_c.tolist()}, '
                f'comm={comm.rank}/{comm.size}, '
                f'dtype={self.dtype}{k})')

    def global_shape(self) -> tuple[int, ...]:
        """Shape of the full (undistributed) array.  Implemented by
        subclasses."""
        raise NotImplementedError

    @property
    def cell(self):
        """Copy of the 3x3 cell matrix."""
        return self.cell_cv.copy()

    @property
    def pbc(self):
        """Copy of the periodic-boundary-condition flags."""
        return self.pbc_c.copy()

    @property
    def kpt(self):
        """Copy of the k-point."""
        return self.kpt_c.copy()

    def empty(self,
              shape: int | tuple[int, ...] = (),
              comm: MPIComm = serial_comm, xp=None) -> XArray:
        """Create an uninitialized array on this domain.

        Implemented by subclasses.  *shape* is the extra leading shape,
        *comm* the communicator for it and *xp* the array module
        (presumably numpy or cupy -- confirm).
        """
        raise NotImplementedError

    def zeros(self,
              shape: int | tuple[int, ...] = (),
              comm: MPIComm = serial_comm, xp=None) -> XArray:
        """Like :meth:`empty`, but with all data set to zero."""
        array = self.empty(shape, comm, xp=xp)
        array.data[:] = 0.0
        return array

    @property
    def icell(self):
        """Inverse of unit cell.

        >>> d = Domain([1, 2, 4])
        >>> d.icell
        array([[1.  , 0.  , 0.  ],
               [0.  , 0.5 , 0.  ],
               [0.  , 0.  , 0.25]])
        >>> d.cell @ d.icell.T
        array([[1., 0., 0.],
               [0., 1., 0.],
               [0., 0., 1.]])
        """
        # np.linalg.inv returns a new array, so no defensive copy is needed.
        return np.linalg.inv(self.cell_cv).T

    def uniform_grid_with_grid_spacing(self,
                                       grid_spacing: float,
                                       n: int = 1,
                                       factors: Sequence[int] = (2, 3, 5, 7)
                                       ) -> UGDesc:
        """Create a uniform grid with roughly *grid_spacing* between points.

        The number of points along cell vector c is L_c / grid_spacing,
        rounded to the nearest multiple of *n* (and at least *n*).  L_c is
        the distance between the two cell faces that do not contain
        vector c.  If *factors* is non-empty, each size is then adjusted
        by ``get_efficient_fft_size`` (presumably to a size with only
        those prime factors -- confirm).  Cell, pbc, k-point, dtype and
        communicator are taken from this domain.
        """
        from gpaw.core import UGDesc

        # Columns of inv(cell) are reciprocal vectors b_c with a_i . b_j
        # = delta_ij, and 1 / |b_c| is the spacing of the lattice planes
        # perpendicular to b_c.
        L_c = (np.linalg.inv(self.cell_cv)**2).sum(0)**-0.5
        size_c = np.maximum(n, (L_c / grid_spacing / n + 0.5).astype(int) * n)
        if factors:
            size_c = np.array([get_efficient_fft_size(N, n, factors)
                               for N in size_c])
        return UGDesc(size=size_c,
                      cell=self.cell_cv,
                      pbc=self.pbc_c,
                      kpt=self.kpt_c,
                      dtype=self.dtype,
                      comm=self.comm)