Coverage for gpaw/core/arrays.py: 56%

259 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-14 00:18 +0000

1from __future__ import annotations 

2 

3from typing import TYPE_CHECKING, Generic, TypeVar, Callable, Literal 

4 

5import gpaw.fftw as fftw 

6import numpy as np 

7from ase.io.ulm import NDArrayReader 

8from gpaw.core.domain import Domain 

9from gpaw.core.matrix import Matrix 

10from gpaw.mpi import MPIComm 

11from gpaw.typing import Array1D, Self, ArrayND 

12from gpaw.gpu import XP 

13from gpaw.new import trace 

14 

15if TYPE_CHECKING: 

16 from gpaw.core.uniform_grid import UGArray, UGDesc 

17 

18from gpaw.new import prod 

19 

# Type variable for the domain descriptor (grid, plane-wave set, ...) that a
# DistributedArrays subclass is parametrized over; must derive from Domain.
DomainType = TypeVar('DomainType', bound=Domain)

21 

22 

class XArrayWithNoData:
    """Stand-in for a distributed array whose storage was never allocated.

    Carries only the metadata (communicator, dimensions, descriptor and
    array module); ``data`` is always ``None``.
    """

    def __init__(self, comm, dims, desc, xp):
        self.comm = comm
        self.dims = dims
        self.desc = desc
        self.xp = xp
        # No storage is attached to this placeholder.
        self.data = None

    def morph(self, desc):
        """Refuse to morph: without data the wave functions can't be reused."""
        from gpaw.new.calculation import ReuseWaveFunctionsError
        raise ReuseWaveFunctionsError

38 

39 

class DistributedArrays(Generic[DomainType], XP):
    """Base class for arrays distributed over MPI communicators.

    The global data has shape ``dims + desc-shape``.  The first element of
    ``dims`` (typically a band index) is block-distributed over ``comm``;
    the descriptor part of the shape may additionally be distributed over
    ``domain_comm``.
    """

    # Domain descriptor (uniform grid, plane-wave set, ...) set by subclass.
    desc: DomainType

    def __init__(self,
                 dims: int | tuple[int, ...],
                 myshape: tuple[int, ...],
                 comm: MPIComm,
                 domain_comm: MPIComm,
                 data: np.ndarray | None,
                 dv: float,
                 dtype,
                 xp=None):
        """Create a distributed-array object.

        Parameters
        ----------
        dims:
            Global extra dimensions; a plain ``int`` is converted to a
            1-tuple.  ``dims[0]`` is distributed over ``comm``.
        myshape:
            Local shape of the domain part of the data.
        comm:
            Communicator distributing the first dimension.
        domain_comm:
            Communicator for the domain decomposition.
        data:
            Optional pre-allocated storage; its shape and dtype must match
            the local shape.  When ``None``, an uninitialized array is
            allocated with ``xp`` (or numpy).
        dv:
            Volume element (integration weight).
        dtype:
            Data type of the array elements.
        xp:
            Array module (numpy or cupy).  When ``data`` is given it is only
            used for a consistency check; otherwise it selects the allocator.
        """
        self.myshape = myshape
        self.comm = comm
        self.domain_comm = domain_comm
        self.dv = dv

        # convert int to tuple:
        self.dims = dims if isinstance(dims, tuple) else (dims,)

        if self.dims:
            # Block-distribute dims[0] over comm: each rank gets up to
            # ceil(dims[0] / comm.size) entries; trailing ranks may get an
            # empty block (d2 - d1 == 0).
            mydims0 = (self.dims[0] + comm.size - 1) // comm.size
            d1 = min(comm.rank * mydims0, self.dims[0])
            d2 = min((comm.rank + 1) * mydims0, self.dims[0])
            mydims0 = d2 - d1
            self.mydims = (mydims0,) + self.dims[1:]
        else:
            self.mydims = ()

        # Full local shape: local extra dimensions + local domain shape.
        fullshape = self.mydims + self.myshape

        if data is not None:
            if data.shape != fullshape:
                raise ValueError(
                    f'Bad shape for data: {data.shape} != {fullshape}')
            if data.dtype != dtype:
                raise ValueError(
                    f'Bad dtype for data: {data.dtype} != {dtype}')
            if xp is not None:
                # xp must agree with the kind of array actually supplied.
                assert (xp is np) == isinstance(
                    data, (np.ndarray, NDArrayReader)), xp
        else:
            data = (xp or np).empty(fullshape, dtype)

        self.data = data
        # Re-derive xp from the storage so self.xp always matches self.data.
        if isinstance(data, (np.ndarray, NDArrayReader)):
            xp = np
        else:
            from gpaw.gpu import cupy as cp
            xp = cp
        XP.__init__(self, xp)
        # Cached 2-D Matrix view of the data (see the matrix property).
        self._matrix: Matrix | None = None

    def new(self, data=None, dims=None) -> DistributedArrays:
        """Create new array of same kind (implemented by subclasses)."""
        raise NotImplementedError

    def create_work_buffer(self, data_buffer: np.ndarray):
        """Create new Distributed array object of same
        kind, to be used as a buffer array when doing
        sliced operations.

        Parameters
        ----------
        data_buffer:
            Array to use for storage.
        """
        assert isinstance(data_buffer, self.xp.ndarray)
        assert len(self.dims) >= 1
        # Reinterpret the raw buffer with this array's dtype.
        data_buffer = data_buffer.view(self.data.dtype)
        datasize = data_buffer.size
        X = self.data.shape[1:]
        nX = int(np.prod(X))
        # Choose mybands, s.t. they fit into
        # data_buffer. Hence, datasize divided by nX
        # rounded down.
        mybands = min(datasize // nX,
                      self.data.shape[0])
        data = data_buffer[:mybands * nX].reshape(
            (mybands,) + X)
        totalbands = self.comm.sum_scalar(mybands)
        # Dims is (totalbands,) + self.dims[1:], where
        # self.dims[1:] is extra dimensions, such as spin.
        return self.new(data=data,
                        dims=(totalbands,) + self.dims[1:])

    def copy(self):
        """Return a new array with a copy of the data."""
        return self.new(data=self.data.copy())

    def sanity_check(self) -> None:
        """Sanity check."""
        pass

    def __getitem__(self, index):
        # Indexing semantics are defined by subclasses.
        raise NotImplementedError

    def __bool__(self):
        # Truth value is ambiguous for a distributed array; refuse it.
        raise ValueError

    def __len__(self):
        # Length is the *global* size of the first extra dimension.
        return self.dims[0]

    def __iter__(self):
        for index in range(self.dims[0]):
            yield self[index]

    def flat(self):
        """Iterate over all extra-dimension index combinations."""
        if self.dims == ():
            yield self
        else:
            for index in np.indices(self.dims).reshape((len(self.dims), -1)).T:
                yield self[tuple(index)]

    def to_xp(self, xp):
        """Return an equivalent array using array module *xp* (numpy/cupy)."""
        if xp is self.xp:
            # Deliberately only allow the np -> np no-op case.
            assert xp is np, 'cp -> cp should not be needed!'
            return self
        if xp is np:
            return self.new(data=self.xp.asnumpy(self.data))
        else:
            return self.new(data=xp.asarray(self.data))

    @property
    def matrix(self) -> Matrix:
        """Lazily created 2-D Matrix view of the data.

        Global shape is ``(dims[0], prod(dims[1:]) * prod(myshape))`` with
        rows distributed over ``self.comm``.
        """
        if self._matrix is not None:
            return self._matrix

        nx = prod(self.myshape)
        shape = (self.dims[0], prod(self.dims[1:]) * nx)
        myshape = (self.mydims[0], prod(self.mydims[1:]) * nx)
        # Row-distributed over self.comm, no column distribution.
        dist = (self.comm, -1, 1)

        data = self.data.reshape(myshape)
        self._matrix = Matrix(*shape, data=data, dist=dist)

        return self._matrix

    @trace
    def matrix_elements(self,
                        other: Self,
                        *,
                        out: Matrix | None = None,
                        symmetric: bool | Literal['_default'] = '_default',
                        function=None,
                        domain_sum=True,
                        cc: bool = False) -> Matrix:
        """Compute the matrix of inner products between *self* and *other*.

        Parameters
        ----------
        other:
            Second set of functions; may be *self*.
        out:
            Optional output matrix; allocated if not given.
        symmetric:
            Whether the result is Hermitian; defaults to ``self is other``.
        function:
            Optional operator applied to *other* before the products
            (serial path requires ``symmetric``).
        domain_sum:
            Sum the result over ``domain_comm`` (set ``False`` when the
            caller does the reduction itself).
        cc:
            If ``False``, the result is complex-conjugated at the end.
        """
        if symmetric == '_default':
            symmetric = self is other

        comm = self.comm

        if out is None:
            out = Matrix(self.dims[0], other.dims[0],
                         dist=(comm, -1, 1),
                         dtype=self.desc.dtype,
                         xp=self.xp)

        if comm.size == 1:
            assert other.comm.size == 1
            if function:
                assert symmetric
                other = function(other)

            M1 = self.matrix
            M2 = other.matrix
            # out <- dv * M1 @ M2^C (conjugate of M2).
            out = M1.multiply(M2, opb='C', alpha=self.dv,
                              symmetric=symmetric, out=out)

            # Plane-wave expansion of real-valued
            # functions needs a correction:
            self._matrix_elements_correction(M1, M2, out, symmetric)
        else:
            # Parallel ring algorithms defined at module level below.
            if symmetric:
                _parallel_me_sym(self, out, function)
            else:
                _parallel_me(self, other, out)

        if not cc:
            out.complex_conjugate()

        if domain_sum:
            self.domain_comm.sum(out.data)
        return out

    def _matrix_elements_correction(self,
                                    M1: Matrix,
                                    M2: Matrix,
                                    out: Matrix,
                                    symmetric: bool) -> None:
        """Hook for PlaneWaveExpansion."""
        pass

    def abs_square(self,
                   weights: Array1D,
                   out: UGArray) -> None:
        """Add weighted absolute square of data to output array.

        See also :xkcd:`849`.
        """
        raise NotImplementedError

    def add_ked(self,
                weights: Array1D,
                out: UGArray) -> None:
        """Add weighted absolute square of gradient of data to output array."""
        raise NotImplementedError

    def gather(self, out=None, broadcast=False):
        """Gather the domain part onto rank 0 (implemented by subclasses)."""
        raise NotImplementedError

    def gathergather(self):
        """Gather over both domain and band distributions.

        Returns a serial array on the rank that holds everything;
        other ranks return ``None``.
        """
        a_xX = self.gather()  # gather X
        if a_xX is not None:
            m_xX = a_xX.matrix.gather()  # gather x
            if m_xX.dist.comm.rank == 0:
                data = m_xX.data
                if a_xX.data.dtype != data.dtype:
                    # Matrix stored the data as real pairs; view as complex.
                    data = data.view(complex)
                return self.desc.new(comm=None).from_data(data)

    def scatter_from(self, data: ArrayND | None = None) -> None:
        """Scatter data from rank 0 (implemented by subclasses)."""
        raise NotImplementedError

    def redist(self,
               domain,
               comm1: MPIComm, comm2: MPIComm) -> DistributedArrays:
        """Redistribute to a new *domain*: gather on comm1 rank 0,
        scatter on comm2 rank 0, then broadcast over comm2."""
        result = domain.empty(self.dims)
        if comm1.rank == 0:
            a = self.gather()
        else:
            a = None
        if comm2.rank == 0:
            result.scatter_from(a)
        comm2.broadcast(result.data, 0)
        return result

    def interpolate(self,
                    plan1: fftw.FFTPlans | None = None,
                    plan2: fftw.FFTPlans | None = None,
                    grid: UGDesc | None = None,
                    out: UGArray | None = None) -> UGArray:
        """Interpolate onto a finer grid (implemented by subclasses)."""
        raise NotImplementedError

    def integrate(self, other: Self | None = None) -> np.ndarray:
        """Integrate (optionally against *other*); subclasses implement."""
        raise NotImplementedError

    def norm2(self, kind: str = 'normal', skip_sum=False) -> np.ndarray:
        """Squared norm of each function; subclasses implement."""
        raise NotImplementedError

287 

288 

def _parallel_me(psit1_nX: DistributedArrays,
                 psit2_nX: DistributedArrays,
                 M_nn: Matrix) -> None:
    """Band-parallel matrix elements between two different sets.

    Ring algorithm: in each of ``comm.size`` steps every rank computes the
    local product of its ``psit1`` block with the ``psit2`` block currently
    in hand, while the next remote ``psit2`` block is in flight via
    nonblocking send/receive (tag 11).  ``M_nn`` rows correspond to this
    rank's ``psit1`` bands; columns are filled block by block.
    """
    comm = psit2_nX.comm
    nbands = psit2_nX.dims[0]

    # Local view of our own psit1 block.
    psit1_nX = psit1_nX[:]

    # Block size: ceil(nbands / comm.size).
    B = (nbands + comm.size - 1) // comm.size

    # n_r[r]:n_r[r+1] is the band range owned by rank r.
    n_r = [min(r * B, nbands) for r in range(comm.size + 1)]

    xp = psit1_nX.xp
    # Double buffering: receive into one buffer while computing on the other.
    buf1_nX = psit1_nX.desc.empty(B, xp=xp)
    buf2_nX = psit1_nX.desc.empty(B, xp=xp)
    psit_nX = psit2_nX

    for shift in range(comm.size):
        rrequest = None
        srequest = None

        if shift < comm.size - 1:
            # Send our own psit2 block forward; receive the block that we
            # will need in the next iteration from behind.
            srank = (comm.rank + shift + 1) % comm.size
            rrank = (comm.rank - shift - 1) % comm.size
            n1 = n_r[rrank]
            n2 = n_r[rrank + 1]
            mynb = n2 - n1
            if mynb > 0:
                rrequest = comm.receive(buf1_nX.data[:mynb], rrank, 11, False)
            if psit2_nX.data.size > 0:
                srequest = comm.send(psit2_nX.data, srank, 11, False)

        # Rank whose psit2 block we currently hold.
        r2 = (comm.rank - shift) % comm.size
        n1 = n_r[r2]
        n2 = n_r[r2 + 1]
        m_nn = psit1_nX.matrix_elements(psit_nX[:n2 - n1],
                                        cc=True, domain_sum=False)

        M_nn.data[:, n1:n2] = m_nn.data

        # Complete the communication before reusing the buffers.
        if rrequest:
            comm.wait(rrequest)
        if srequest:
            comm.wait(srequest)

        # Next iteration computes on the freshly received block.
        psit_nX = buf1_nX
        buf1_nX, buf2_nX = buf2_nX, buf1_nX

337 

338 

def _parallel_me_sym(psit1_nX: DistributedArrays,
                     M_nn: Matrix,
                     operator: None | Callable[[DistributedArrays],
                                               DistributedArrays]
                     ) -> None:
    """Band-parallel matrix elements for the symmetric (Hermitian) case.

    Like :func:`_parallel_me`, but only about half of the ring steps are
    needed because the missing lower-triangle blocks can be reconstructed
    from the conjugate transpose of blocks computed on other ranks.  The
    ring exchange uses tag 11; the final triangle fill-in uses tag 12.
    ``operator`` is optionally applied to our own block at shift 0.
    """
    comm = psit1_nX.comm
    nbands = psit1_nX.dims[0]
    B = (nbands + comm.size - 1) // comm.size
    mynbands = psit1_nX.mydims[0]

    # n_r[r]:n_r[r+1] is the band range owned by rank r.
    n_r = [min(r * B, nbands) for r in range(comm.size + 1)]
    mynbands_r = [n_r[r + 1] - n_r[r] for r in range(comm.size)]
    assert mynbands_r[comm.rank] == mynbands

    xp = psit1_nX.xp
    psit2_nX = psit1_nX
    # Double buffering for the ring exchange.
    buf1_nX = psit1_nX.desc.empty(B, xp=xp)
    buf2_nX = psit1_nX.desc.empty(B, xp=xp)
    half = comm.size // 2

    for shift in range(half + 1):
        rrequest = None
        srequest = None

        if shift < half:
            srank = (comm.rank + shift + 1) % comm.size
            rrank = (comm.rank - shift - 1) % comm.size
            # For even comm.size the last exchange is only needed in one
            # direction; 'skip' suppresses the redundant half.
            skip = comm.size % 2 == 0 and shift == half - 1
            rmynb = mynbands_r[rrank]
            if not (skip and comm.rank < half) and rmynb > 0:
                rrequest = comm.receive(buf1_nX.data[:rmynb], rrank, 11, False)
            if not (skip and comm.rank >= half) and psit1_nX.data.size > 0:
                srequest = comm.send(psit1_nX.data, srank, 11, False)

        if shift == 0:
            # Apply the (optional) operator once, to our own block only.
            if operator is not None:
                op_psit1_nX = operator(psit1_nX)
            else:
                op_psit1_nX = psit1_nX
            op_psit1_nX = op_psit1_nX[:]  # local view

        # Skip the block that the symmetric fill-in below will provide.
        if not (comm.size % 2 == 0 and shift == half and comm.rank < half):
            r2 = (comm.rank - shift) % comm.size
            n1 = n_r[r2]
            n2 = n_r[r2 + 1]
            m_nn = op_psit1_nX.matrix_elements(psit2_nX[:n2 - n1],
                                               symmetric=(shift == 0),
                                               cc=True, domain_sum=False)
            M_nn.data[:, n1:n2] = m_nn.data

        # Complete communication before reusing the buffers.
        if rrequest:
            comm.wait(rrequest)
        if srequest:
            comm.wait(srequest)

        psit2_nX = buf1_nX
        buf1_nX, buf2_nX = buf2_nX, buf1_nX

    # Fill in the blocks that were never computed locally: each such block
    # is the conjugate transpose of a block owned by another rank (tag 12).
    requests = []
    blocks = []
    nrows = (comm.size - 1) // 2
    for row in range(nrows):
        for column in range(comm.size - nrows + row, comm.size):
            if comm.rank == row:
                n1 = n_r[column]
                n2 = n_r[column + 1]
                if mynbands > 0 and n2 > n1:
                    requests.append(
                        comm.send(M_nn.data[:, n1:n2].T.conj().copy(),
                                  column, 12, False))
            elif comm.rank == column:
                n1 = n_r[row]
                n2 = n_r[row + 1]
                if mynbands > 0 and n2 > n1:
                    block = xp.empty((mynbands, n2 - n1), M_nn.dtype)
                    blocks.append((n1, n2, block))
                    requests.append(comm.receive(block, row, 12, False))

    comm.waitall(requests)
    for n1, n2, block in blocks:
        M_nn.data[:, n1:n2] = block