Coverage for gpaw/core/domain.py: 89%
82 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-14 00:18 +0000
1from __future__ import annotations
3from typing import TYPE_CHECKING, Sequence, Literal, Generic, TypeVar
5import numpy as np
6from ase.geometry.cell import cellpar_to_cell
8from gpaw.fftw import get_efficient_fft_size
9from gpaw.mpi import MPIComm, serial_comm
10from gpaw.typing import (Array2D, ArrayLike, ArrayLike1D, ArrayLike2D,
11 DTypeLike, Vector, Self)
13if TYPE_CHECKING:
14 from gpaw.core import UGDesc
15 from gpaw.core.arrays import DistributedArrays
def normalize_cell(cell: ArrayLike) -> Array2D:
    """Return *cell* as a 3x3 float matrix of cell vectors.

    *cell* may be given as

    * a 3x3 matrix (returned as a float copy),
    * three lengths of an orthorhombic cell (placed on the diagonal),
    * six cell parameters ``(a, b, c, alpha, beta, gamma)`` as understood
      by :func:`ase.geometry.cell.cellpar_to_cell` (ASE convention, angles
      in degrees).

    >>> normalize_cell([1, 2, 3])
    array([[1., 0., 0.],
           [0., 2., 0.],
           [0., 0., 3.]])

    Raises ValueError for any other shape, instead of silently returning
    a malformed matrix or failing deep inside ``cellpar_to_cell``.
    """
    cell = np.array(cell, float)
    if cell.ndim == 2:
        if cell.shape != (3, 3):
            raise ValueError(
                f'Cell matrix must have shape (3, 3), not {cell.shape}')
        return cell
    if cell.ndim == 1:
        if len(cell) == 3:
            return np.diag(cell)
        if len(cell) == 6:
            return cellpar_to_cell(cell)
    raise ValueError('Cell must be a 3x3 matrix, three lengths or six '
                     f'cell parameters; got an array of shape {cell.shape}')
XArray = TypeVar('XArray', bound='DistributedArrays')


class Domain(Generic[XArray]):
    """Unit cell with boundary conditions, k-point, dtype and communicator.

    Base class: ``new()``, ``global_shape()`` and ``empty()`` raise
    ``NotImplementedError`` here and are meant to be provided by
    subclasses.
    """

    # Declared but not assigned in this class; presumably set by
    # subclasses (size in bytes of one data element?) — TODO confirm.
    itemsize: int

    def __init__(self,
                 cell: ArrayLike1D | ArrayLike2D,
                 pbc=(True, True, True),
                 kpt: Vector | None = None,
                 comm: MPIComm = serial_comm,
                 dtype: DTypeLike | None = None):
        """Create a domain.

        Parameters
        ----------
        cell:
            3x3 matrix, three lengths or six cell parameters
            (see ``normalize_cell()``).
        pbc:
            Periodic boundary conditions, one flag per cell direction.
            A single scalar flag is used for all three directions.
        kpt:
            k-point, one component per cell direction.  Passing any
            k-point (even zero) makes the default dtype complex.
        comm:
            MPI communicator.
        dtype:
            float32, float64, complex64 or complex128.  Defaults to
            float when *kpt* is None and to complex otherwise.

        Raises
        ------
        ValueError
            If a non-zero k-point is combined with a real dtype, or has a
            non-zero component along a non-periodic direction.
        """
        if np.ndim(pbc) == 0:
            # One flag (bool, int or numpy scalar) for all directions.
            pbc = (pbc,) * 3
        self.cell_cv = normalize_cell(cell)
        self.pbc_c = np.array(pbc, bool)
        self.comm = comm

        self.volume = abs(np.linalg.det(self.cell_cv))
        # Orthogonal here means: all off-diagonal elements are zero.
        self.orthogonal = not (self.cell_cv -
                               np.diag(self.cell_cv.diagonal())).any()

        # XXX: Gotta be careful about precision here:
        if dtype is None:
            # An explicitly given k-point (even Gamma) implies complex data.
            dtype = float if kpt is None else complex
        if kpt is None:
            kpt = (0.0, 0.0, 0.0)

        # Normalize float / 'float64' / np.float64 / ... to a single dtype
        # object so the checks below do not depend on how it was spelled.
        self.dtype = np.dtype(dtype)  # type: ignore
        assert self.dtype in [np.float32, np.float64,
                              np.complex64, np.complex128], dtype

        self.kpt_c = np.array(kpt, float)

        if self.kpt_c.any():
            # A Bloch phase cannot be represented with real numbers.
            if self.dtype.kind != 'c':
                raise ValueError(f'dtype must be complex for kpt={kpt}')
            for p, k in zip(self.pbc_c, self.kpt_c):
                if not p and k != 0:
                    raise ValueError(f'Bad k-point {kpt} for pbc={pbc}')

        # Declared only; presumably assigned by subclasses — TODO confirm.
        self.myshape: tuple[int, ...]

    def new(self,
            *,
            kpt=None,
            dtype=None,
            comm: MPIComm | Literal['inherit'] | None = 'inherit'
            ) -> Self:
        """Not implemented in this base class.

        The keyword arguments suggest subclasses return a similar object
        with *kpt*, *dtype* and/or *comm* replaced — TODO confirm against
        the subclass implementations.
        """
        raise NotImplementedError

    def __repr__(self):
        """Show cell, pbc, communicator rank/size and dtype.

        A diagonal cell is shown as its diagonal only; the k-point is
        included only when it is non-zero.
        """
        comm = self.comm
        if self.kpt_c.any():
            k = f', kpt={self.kpt_c.tolist()}'
        else:
            k = ''
        if (self.cell_cv == np.diag(self.cell_cv.diagonal())).all():
            cell = self.cell_cv.diagonal().tolist()
        else:
            cell = self.cell_cv.tolist()
        return (f'Domain(cell={cell}, '
                f'pbc={self.pbc_c.tolist()}, '
                f'comm={comm.rank}/{comm.size}, '
                f'dtype={self.dtype}{k})')

    def global_shape(self) -> tuple[int, ...]:
        """Not implemented in this base class."""
        raise NotImplementedError

    @property
    def cell(self):
        """Copy of the 3x3 cell matrix."""
        return self.cell_cv.copy()

    @property
    def pbc(self):
        """Copy of the periodic-boundary-condition flags."""
        return self.pbc_c.copy()

    @property
    def kpt(self):
        """Copy of the k-point."""
        return self.kpt_c.copy()

    def empty(self,
              shape: int | tuple[int, ...] = (),
              comm: MPIComm = serial_comm, xp=None) -> XArray:
        """Not implemented in this base class."""
        raise NotImplementedError

    def zeros(self,
              shape: int | tuple[int, ...] = (),
              comm: MPIComm = serial_comm, xp=None) -> XArray:
        """Return ``self.empty(shape, comm, xp=xp)`` with its data zeroed."""
        array = self.empty(shape, comm, xp=xp)
        array.data[:] = 0.0
        return array

    @property
    def icell(self):
        """Transposed inverse of the unit cell.

        Rows are the reciprocal cell vectors without the 2π factor, so
        ``cell @ icell.T`` is the identity:

        >>> d = Domain([1, 2, 4])
        >>> d.icell
        array([[1.  , 0.  , 0.  ],
               [0.  , 0.5 , 0.  ],
               [0.  , 0.  , 0.25]])
        >>> d.cell @ d.icell.T
        array([[1., 0., 0.],
               [0., 1., 0.],
               [0., 0., 1.]])
        """
        # Use cell_cv directly: np.linalg.inv already returns a new array.
        return np.linalg.inv(self.cell_cv).T

    def uniform_grid_with_grid_spacing(self,
                                       grid_spacing: float,
                                       n: int = 1,
                                       factors: Sequence[int] = (2, 3, 5, 7)
                                       ) -> UGDesc:
        """Create a uniform grid with spacing close to *grid_spacing*.

        Along each cell direction, the number of grid points is the
        distance between the two opposite cell faces divided by
        *grid_spacing*, rounded to the nearest multiple of *n* (and at
        least *n*).  If *factors* is non-empty, each size is then passed
        through ``get_efficient_fft_size(N, n, factors)`` (presumably to
        get an FFT-friendly size built from those factors).

        The grid inherits cell, pbc, k-point, dtype and communicator from
        this domain.
        """
        from gpaw.core import UGDesc

        # 1 / |column c of inv(cell)| = height of the cell perpendicular
        # to the face spanned by the other two cell vectors.
        L_c = (np.linalg.inv(self.cell_cv)**2).sum(0)**-0.5
        size_c = np.maximum(n, (L_c / grid_spacing / n + 0.5).astype(int) * n)
        if factors:
            size_c = np.array([get_efficient_fft_size(N, n, factors)
                               for N in size_c])
        return UGDesc(size=size_c,
                      cell=self.cell_cv,
                      pbc=self.pbc_c,
                      kpt=self.kpt_c,
                      dtype=self.dtype,
                      comm=self.comm)