Coverage for gpaw/core/arrays.py: 56%
259 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-14 00:18 +0000
1from __future__ import annotations
3from typing import TYPE_CHECKING, Generic, TypeVar, Callable, Literal
5import gpaw.fftw as fftw
6import numpy as np
7from ase.io.ulm import NDArrayReader
8from gpaw.core.domain import Domain
9from gpaw.core.matrix import Matrix
10from gpaw.mpi import MPIComm
11from gpaw.typing import Array1D, Self, ArrayND
12from gpaw.gpu import XP
13from gpaw.new import trace
15if TYPE_CHECKING:
16 from gpaw.core.uniform_grid import UGArray, UGDesc
18from gpaw.new import prod
20DomainType = TypeVar('DomainType', bound=Domain)
class XArrayWithNoData:
    """Stand-in for a distributed array that carries metadata only.

    ``data`` is always None; attempting to morph it to a new descriptor
    raises ``ReuseWaveFunctionsError`` (the name suggests this signals
    that cached wave functions cannot be reused — confirm with callers).
    """

    def __init__(self, comm, dims, desc, xp):
        self.comm = comm
        self.dims = dims
        self.desc = desc
        self.xp = xp
        # No actual array payload:
        self.data = None

    def morph(self, desc):
        # Local import keeps the gpaw.new.calculation dependency lazy.
        from gpaw.new.calculation import ReuseWaveFunctionsError
        raise ReuseWaveFunctionsError
class DistributedArrays(Generic[DomainType], XP):
    """Base class for arrays distributed over MPI communicators.

    The leading dimensions (``dims``) are block-distributed over ``comm``
    (only the first of them is actually split); the trailing ``myshape``
    axes hold this rank's share of the domain, which is itself
    distributed over ``domain_comm``.
    """

    # Domain descriptor; assigned by concrete subclasses.
    desc: DomainType

    def __init__(self,
                 dims: int | tuple[int, ...],
                 myshape: tuple[int, ...],
                 comm: MPIComm,
                 domain_comm: MPIComm,
                 data: np.ndarray | None,
                 dv: float,
                 dtype,
                 xp=None):
        """Create a distributed-array object.

        ``dims`` may be a single int and is normalized to a tuple.  If
        ``data`` is None, a fresh uninitialized buffer is allocated with
        ``xp`` (NumPy when ``xp`` is None); otherwise ``data`` must match
        the local shape and dtype exactly.  ``dv`` is the volume element
        used as the scale factor in :meth:`matrix_elements`.
        """
        self.myshape = myshape
        self.comm = comm
        self.domain_comm = domain_comm
        self.dv = dv

        # convert int to tuple:
        self.dims = dims if isinstance(dims, tuple) else (dims,)

        if self.dims:
            # Block-distribute the first dimension over comm: every rank
            # gets ceil(dims[0] / comm.size) elements, except trailing
            # ranks which get the remainder (possibly zero).
            mydims0 = (self.dims[0] + comm.size - 1) // comm.size
            d1 = min(comm.rank * mydims0, self.dims[0])
            d2 = min((comm.rank + 1) * mydims0, self.dims[0])
            mydims0 = d2 - d1
            self.mydims = (mydims0,) + self.dims[1:]
        else:
            self.mydims = ()

        fullshape = self.mydims + self.myshape

        if data is not None:
            if data.shape != fullshape:
                raise ValueError(
                    f'Bad shape for data: {data.shape} != {fullshape}')
            if data.dtype != dtype:
                raise ValueError(
                    f'Bad dtype for data: {data.dtype} != {dtype}')
            if xp is not None:
                # A NumPy xp must come with a NumPy-like buffer (and a
                # non-NumPy xp with a non-NumPy buffer):
                assert (xp is np) == isinstance(
                    data, (np.ndarray, NDArrayReader)), xp
        else:
            data = (xp or np).empty(fullshape, dtype)

        self.data = data
        # Infer the array module (NumPy vs CuPy) from the buffer type:
        if isinstance(data, (np.ndarray, NDArrayReader)):
            xp = np
        else:
            from gpaw.gpu import cupy as cp
            xp = cp
        XP.__init__(self, xp)
        self._matrix: Matrix | None = None  # lazy cache for .matrix

    def new(self, data=None, dims=None) -> DistributedArrays:
        """Create a new array of the same kind (subclass responsibility)."""
        raise NotImplementedError

    def create_work_buffer(self, data_buffer: np.ndarray):
        """Create new Distributed array object of same
        kind, to be used as a buffer array when doing
        sliced operations.

        Parameters
        ----------
        data_buffer:
            Array to use for storage.
        """
        assert isinstance(data_buffer, self.xp.ndarray)
        assert len(self.dims) >= 1
        # Reinterpret the raw buffer with this array's dtype:
        data_buffer = data_buffer.view(self.data.dtype)
        datasize = data_buffer.size
        X = self.data.shape[1:]
        nX = int(np.prod(X))
        # Choose mybands, s.t. they fit into
        # data_buffer. Hence, datasize divided by nX
        # rounded down.
        mybands = min(datasize // nX,
                      self.data.shape[0])
        data = data_buffer[:mybands * nX].reshape(
            (mybands,) + X)
        totalbands = self.comm.sum_scalar(mybands)
        # Dims is (totalbands,) + self.dims[1:], where
        # self.dims[1:] is extra dimensions, such as spin.
        return self.new(data=data,
                        dims=(totalbands,) + self.dims[1:])

    def copy(self):
        """Return a deep copy (new data buffer, same distribution)."""
        return self.new(data=self.data.copy())

    def sanity_check(self) -> None:
        """Sanity check."""
        pass

    def __getitem__(self, index):
        raise NotImplementedError

    def __bool__(self):
        # Truth value of an array is ambiguous (same policy as NumPy):
        raise ValueError

    def __len__(self):
        # Length of the global leading dimension:
        return self.dims[0]

    def __iter__(self):
        # Iterate over the global leading dimension; __getitem__ is
        # supplied by subclasses.
        for index in range(self.dims[0]):
            yield self[index]

    def flat(self):
        """Yield self (0-d dims) or every multi-index slice of ``dims``."""
        if self.dims == ():
            yield self
        else:
            for index in np.indices(self.dims).reshape((len(self.dims), -1)).T:
                yield self[tuple(index)]

    def to_xp(self, xp):
        """Copy data to another array module (NumPy <-> CuPy)."""
        if xp is self.xp:
            assert xp is np, 'cp -> cp should not be needed!'
            return self
        if xp is np:
            return self.new(data=self.xp.asnumpy(self.data))
        else:
            return self.new(data=xp.asarray(self.data))

    @property
    def matrix(self) -> Matrix:
        """2-D matrix view of the data (rows distributed over ``comm``).

        The trailing dimensions are flattened into the column axis; the
        result is cached on first access.
        """
        if self._matrix is not None:
            return self._matrix

        nx = prod(self.myshape)
        shape = (self.dims[0], prod(self.dims[1:]) * nx)
        myshape = (self.mydims[0], prod(self.mydims[1:]) * nx)
        dist = (self.comm, -1, 1)  # row-distributed, single column block

        data = self.data.reshape(myshape)
        self._matrix = Matrix(*shape, data=data, dist=dist)

        return self._matrix

    @trace
    def matrix_elements(self,
                        other: Self,
                        *,
                        out: Matrix | None = None,
                        symmetric: bool | Literal['_default'] = '_default',
                        function=None,
                        domain_sum=True,
                        cc: bool = False) -> Matrix:
        """Calculate matrix elements between this array and ``other``.

        Computes ``dv * M1 @ M2^C`` from the 2-D matrix views; the result
        is complex-conjugated at the end unless ``cc`` is True, and summed
        over ``domain_comm`` when ``domain_sum`` is True.  ``symmetric``
        defaults to ``self is other``.  ``function``, if given, is applied
        to ``other`` first (serial path asserts it is only combined with
        ``symmetric``).
        """
        if symmetric == '_default':
            symmetric = self is other

        comm = self.comm

        if out is None:
            out = Matrix(self.dims[0], other.dims[0],
                         dist=(comm, -1, 1),
                         dtype=self.desc.dtype,
                         xp=self.xp)

        if comm.size == 1:
            assert other.comm.size == 1
            if function:
                assert symmetric
                other = function(other)

            M1 = self.matrix
            M2 = other.matrix
            out = M1.multiply(M2, opb='C', alpha=self.dv,
                              symmetric=symmetric, out=out)

            # Plane-wave expansion of real-valued
            # functions needs a correction:
            self._matrix_elements_correction(M1, M2, out, symmetric)
        else:
            # Band-parallel path: ring algorithm over comm.
            # NOTE(review): ``function`` is silently ignored in the
            # non-symmetric branch (the serial path asserts
            # function => symmetric) — confirm callers never pass it here.
            if symmetric:
                _parallel_me_sym(self, out, function)
            else:
                _parallel_me(self, other, out)

        if not cc:
            out.complex_conjugate()

        if domain_sum:
            self.domain_comm.sum(out.data)
        return out

    def _matrix_elements_correction(self,
                                    M1: Matrix,
                                    M2: Matrix,
                                    out: Matrix,
                                    symmetric: bool) -> None:
        """Hook for PlaneWaveExpansion."""
        pass

    def abs_square(self,
                   weights: Array1D,
                   out: UGArray) -> None:
        """Add weighted absolute square of data to output array.

        See also :xkcd:`849`.
        """
        raise NotImplementedError

    def add_ked(self,
                weights: Array1D,
                out: UGArray) -> None:
        """Add weighted absolute square of gradient of data to output array."""
        raise NotImplementedError

    def gather(self, out=None, broadcast=False):
        """Gather the domain-distributed axes (subclass responsibility)."""
        raise NotImplementedError

    def gathergather(self):
        """Gather both domain and band distribution onto the master rank.

        Returns a fully serial copy on the rank where both gathers land;
        all other ranks return None (implicitly).
        """
        a_xX = self.gather()  # gather X
        if a_xX is not None:
            m_xX = a_xX.matrix.gather()  # gather x
            if m_xX.dist.comm.rank == 0:
                data = m_xX.data
                if a_xX.data.dtype != data.dtype:
                    # Presumably the matrix view reinterpreted complex
                    # data as floats — restore the original dtype:
                    data = data.view(complex)
                return self.desc.new(comm=None).from_data(data)

    def scatter_from(self, data: ArrayND | None = None) -> None:
        raise NotImplementedError

    def redist(self,
               domain,
               comm1: MPIComm, comm2: MPIComm) -> DistributedArrays:
        """Redistribute to a new ``domain`` layout.

        Gathers on comm1's rank 0, scatters from comm2's rank 0 into the
        new layout and broadcasts the result over comm2.
        """
        result = domain.empty(self.dims)
        if comm1.rank == 0:
            a = self.gather()
        else:
            a = None
        if comm2.rank == 0:
            result.scatter_from(a)
        comm2.broadcast(result.data, 0)
        return result

    def interpolate(self,
                    plan1: fftw.FFTPlans | None = None,
                    plan2: fftw.FFTPlans | None = None,
                    grid: UGDesc | None = None,
                    out: UGArray | None = None) -> UGArray:
        raise NotImplementedError

    def integrate(self, other: Self | None = None) -> np.ndarray:
        raise NotImplementedError

    def norm2(self, kind: str = 'normal', skip_sum=False) -> np.ndarray:
        raise NotImplementedError
def _parallel_me(psit1_nX: DistributedArrays,
                 psit2_nX: DistributedArrays,
                 M_nn: Matrix) -> None:
    """Band-parallel matrix elements, non-symmetric case.

    Ring algorithm: each rank circulates its local block of ``psit2_nX``
    around the communicator with non-blocking sends/receives; at every
    shift the currently held chunk is multiplied against the local
    ``psit1_nX`` bands and the product written into the matching column
    block of ``M_nn``.
    """
    comm = psit2_nX.comm
    nbands = psit2_nX.dims[0]

    psit1_nX = psit1_nX[:]  # local view

    # Maximum block size and the band range owned by each rank:
    B = (nbands + comm.size - 1) // comm.size
    n_r = [min(r * B, nbands) for r in range(comm.size + 1)]

    xp = psit1_nX.xp
    # Double buffering: receive into one buffer while computing with
    # the other.
    buf1_nX = psit1_nX.desc.empty(B, xp=xp)
    buf2_nX = psit1_nX.desc.empty(B, xp=xp)
    psit_nX = psit2_nX

    for shift in range(comm.size):
        rrequest = None
        srequest = None

        if shift < comm.size - 1:
            # Post non-blocking transfers for the next chunk:
            srank = (comm.rank + shift + 1) % comm.size
            rrank = (comm.rank - shift - 1) % comm.size
            n1 = n_r[rrank]
            n2 = n_r[rrank + 1]
            mynb = n2 - n1
            if mynb > 0:
                rrequest = comm.receive(buf1_nX.data[:mynb], rrank, 11, False)
            if psit2_nX.data.size > 0:
                srequest = comm.send(psit2_nX.data, srank, 11, False)

        # Which rank's bands do we currently hold?
        r2 = (comm.rank - shift) % comm.size
        n1 = n_r[r2]
        n2 = n_r[r2 + 1]
        m_nn = psit1_nX.matrix_elements(psit_nX[:n2 - n1],
                                        cc=True, domain_sum=False)

        M_nn.data[:, n1:n2] = m_nn.data

        if rrequest:
            comm.wait(rrequest)
        if srequest:
            comm.wait(srequest)

        # Use the freshly received chunk next; swap buffers so the next
        # receive does not overwrite it:
        psit_nX = buf1_nX
        buf1_nX, buf2_nX = buf2_nX, buf1_nX
def _parallel_me_sym(psit1_nX: DistributedArrays,
                     M_nn: Matrix,
                     operator: None | Callable[[DistributedArrays],
                                               DistributedArrays]
                     ) -> None:
    """Band-parallel matrix elements, symmetric (Hermitian) case.

    Like :func:`_parallel_me`, but exploits symmetry: only about half of
    the ring shifts are computed, and the missing blocks are filled in at
    the end by exchanging conjugate-transposed blocks between the rank
    pairs that hold the mirrored parts.  ``operator``, if given, is
    applied once to the local bands at shift 0.
    """
    comm = psit1_nX.comm
    nbands = psit1_nX.dims[0]
    B = (nbands + comm.size - 1) // comm.size
    mynbands = psit1_nX.mydims[0]

    # Band ranges and local band counts per rank:
    n_r = [min(r * B, nbands) for r in range(comm.size + 1)]
    mynbands_r = [n_r[r + 1] - n_r[r] for r in range(comm.size)]
    assert mynbands_r[comm.rank] == mynbands

    xp = psit1_nX.xp
    psit2_nX = psit1_nX
    # Double buffering for the ring communication:
    buf1_nX = psit1_nX.desc.empty(B, xp=xp)
    buf2_nX = psit1_nX.desc.empty(B, xp=xp)
    half = comm.size // 2

    for shift in range(half + 1):
        rrequest = None
        srequest = None

        if shift < half:
            srank = (comm.rank + shift + 1) % comm.size
            rrank = (comm.rank - shift - 1) % comm.size
            # For an even comm.size, the last transfer is only needed by
            # half of the ranks; skip the redundant half:
            skip = comm.size % 2 == 0 and shift == half - 1
            rmynb = mynbands_r[rrank]
            if not (skip and comm.rank < half) and rmynb > 0:
                rrequest = comm.receive(buf1_nX.data[:rmynb], rrank, 11, False)
            if not (skip and comm.rank >= half) and psit1_nX.data.size > 0:
                srequest = comm.send(psit1_nX.data, srank, 11, False)

        if shift == 0:
            # Apply the operator once, to the local bands only:
            if operator is not None:
                op_psit1_nX = operator(psit1_nX)
            else:
                op_psit1_nX = psit1_nX
            op_psit1_nX = op_psit1_nX[:]  # local view

        # For even comm.size the final shift is redundant for the lower
        # half of the ranks:
        if not (comm.size % 2 == 0 and shift == half and comm.rank < half):
            r2 = (comm.rank - shift) % comm.size
            n1 = n_r[r2]
            n2 = n_r[r2 + 1]
            m_nn = op_psit1_nX.matrix_elements(psit2_nX[:n2 - n1],
                                               symmetric=(shift == 0),
                                               cc=True, domain_sum=False)
            M_nn.data[:, n1:n2] = m_nn.data

        if rrequest:
            comm.wait(rrequest)
        if srequest:
            comm.wait(srequest)

        # Swap double buffers; computed-against chunk becomes free:
        psit2_nX = buf1_nX
        buf1_nX, buf2_nX = buf2_nX, buf1_nX

    # Fill in the blocks that were skipped thanks to symmetry: each
    # "row" rank sends the conjugate transpose of its computed block to
    # the "column" rank that owns the mirrored block.
    requests = []
    blocks = []
    nrows = (comm.size - 1) // 2
    for row in range(nrows):
        for column in range(comm.size - nrows + row, comm.size):
            if comm.rank == row:
                n1 = n_r[column]
                n2 = n_r[column + 1]
                if mynbands > 0 and n2 > n1:
                    requests.append(
                        comm.send(M_nn.data[:, n1:n2].T.conj().copy(),
                                  column, 12, False))
            elif comm.rank == column:
                n1 = n_r[row]
                n2 = n_r[row + 1]
                if mynbands > 0 and n2 > n1:
                    block = xp.empty((mynbands, n2 - n1), M_nn.dtype)
                    blocks.append((n1, n2, block))
                    requests.append(comm.receive(block, row, 12, False))

    comm.waitall(requests)
    for n1, n2, block in blocks:
        M_nn.data[:, n1:n2] = block