Coverage for gpaw/core/atom_arrays.py: 75%

317 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-14 00:18 +0000

1from __future__ import annotations 

2 

3import numbers 

4from typing import Sequence, overload, Literal 

5 

6import numpy as np 

7from gpaw.core.matrix import Matrix 

8from gpaw.gpu import cupy as cp, XP 

9from gpaw.mpi import MPIComm, serial_comm 

10from gpaw.new import prod, zips 

11from gpaw.typing import Array1D, ArrayLike1D 

12from gpaw.new.c import dH_aii_times_P_ani_gpu 

13from gpaw.utilities import as_real_dtype 

14 

15 

class AtomArraysLayout(XP):
    def __init__(self,
                 shapes: Sequence[int | tuple[int, ...]],
                 atomdist: AtomDistribution | MPIComm = serial_comm,
                 dtype=float,
                 xp=None):
        """Description of layout of atom arrays.

        Parameters
        ----------
        shapes:
            Shapes of arrays - one for each atom.  A plain integer ``n``
            is treated as the one-dimensional shape ``(n,)``.
        atomdist:
            Distribution of atoms.  If a communicator is given instead of
            an AtomDistribution, all atoms are assigned to rank 0 of that
            communicator.
        dtype:
            Data-type (float or complex).
        xp:
            Array module to use for storage (defaults to numpy).
        """
        self.shape_a = [shape if isinstance(shape, tuple) else (shape,)
                        for shape in shapes]
        if not isinstance(atomdist, AtomDistribution):
            atomdist = AtomDistribution(np.zeros(len(shapes), int), atomdist)
        self.atomdist = atomdist
        self.dtype = np.dtype(dtype)
        XP.__init__(self, np if xp is None else xp)

        # Total number of elements summed over all atoms:
        self.size = sum(prod(shape) for shape in self.shape_a)

        # For every atom on this rank: (atom index, start, stop) where
        # start:stop is the atom's slice of the flattened local storage.
        self.myindices = []
        self.mysize = 0
        I1 = 0
        for a in atomdist.indices:
            I2 = I1 + prod(self.shape_a[a])
            self.myindices.append((a, I1, I2))
            self.mysize += I2 - I1
            I1 = I2

    def __len__(self):
        """Return the total number of atoms (on all ranks)."""
        return len(self.shape_a)

    def __repr__(self):
        return (f'AtomArraysLayout({self.shape_a}, {self.atomdist}, '
                f'{self.dtype}, xp={self.xp.__name__})')

    def new(self, atomdist=None, dtype=None, xp=None):
        """Create new AtomArraysLayout object.

        The shapes are kept.  Every argument that is left as ``None`` is
        taken from this layout.

        Parameters
        ----------
        atomdist:
            New distribution of atoms (or a communicator).
        dtype:
            New data-type.
        xp:
            New array module.
        """
        # Explicit ``is None`` tests: AtomDistribution defines __len__, so
        # an empty distribution is falsy, and dtype objects may be falsy
        # as well - ``x or default`` would silently discard such values.
        if atomdist is None:
            atomdist = self.atomdist
        if dtype is None:
            dtype = self.dtype
        if xp is None:
            xp = self.xp
        return AtomArraysLayout(self.shape_a, atomdist, dtype, xp)

    def empty(self,
              dims: int | tuple[int, ...] = (),
              comm: MPIComm = serial_comm) -> AtomArrays:
        """Create new AtomArrays object.

        The contents of the new object are uninitialized.

        Parameters
        ----------
        dims:
            Extra dimensions.
        comm:
            Distribute dimensions along this communicator.
        """
        return AtomArrays(self, dims, comm)

    def zeros(self,
              dims: int | tuple[int, ...] = (),
              comm: MPIComm = serial_comm) -> AtomArrays:
        """Create new AtomArrays object filled with zeros.

        Parameters
        ----------
        dims:
            Extra dimensions.
        comm:
            Distribute dimensions along this communicator.
        """
        aa = self.empty(dims, comm)
        aa.data[:] = 0.0
        return aa

    def sizes(self) -> tuple[list[dict[int, int]], Array1D]:
        """Compute array sizes for all ranks.

        Returns a list with one dict per rank, mapping atom index to the
        number of elements for that atom, and an array with the total
        number of elements on each rank.

        >>> AtomArraysLayout([3, 4]).sizes()
        ([{0: 3, 1: 4}], array([7]))
        """
        comm = self.atomdist.comm
        size_ra: list[dict[int, int]] = [{} for _ in range(comm.size)]
        size_r = np.zeros(comm.size, int)
        for a, (rank, shape) in enumerate(zips(self.atomdist.rank_a,
                                               self.shape_a)):
            size = prod(shape)
            size_ra[rank][a] = size
            size_r[rank] += size
        return size_ra, size_r

102 

103 

class AtomDistribution:
    def __init__(self, ranks: ArrayLike1D, comm: MPIComm = serial_comm):
        """Atom-distribution.

        Parameters
        ----------
        ranks:
            List of ranks, one rank per atom.
        comm:
            MPI-communicator.
        """
        self.comm = comm
        self.rank_a = np.array(ranks)
        # Indices of the atoms living on this rank
        # (converted from np.int64 -> int):
        self.indices = [int(a) for a in np.where(self.rank_a == comm.rank)[0]]

    def __len__(self) -> int:
        """Return the total number of atoms (on all ranks)."""
        return len(self.rank_a)

    @classmethod
    def from_number_of_atoms(cls,
                             natoms: int,
                             comm: MPIComm = serial_comm) -> AtomDistribution:
        """Distribute atoms evenly.

        Atoms are assigned to ranks in contiguous blocks of
        ``ceil(natoms / comm.size)`` atoms, starting from rank 0.

        >>> AtomDistribution.from_number_of_atoms(3).rank_a
        array([0, 0, 0])
        """
        blocksize = (natoms + comm.size - 1) // comm.size
        rank_a = np.empty(natoms, int)
        a1 = 0
        for rank in range(comm.size):
            a2 = a1 + blocksize
            rank_a[a1:a2] = rank
            if a2 >= natoms:
                # All atoms have been assigned; remaining ranks get none.
                break
            a1 = a2
        return cls(rank_a, comm)

    @classmethod
    def from_atom_indices(cls,
                          atom_indices: Sequence[int],
                          comm: MPIComm = serial_comm,
                          *,
                          natoms: int | None = None) -> AtomDistribution:
        """Create distribution from atom indices.

        Parameters
        ----------
        atom_indices:
            Indices of the atoms living on this rank (may be empty).
        comm:
            MPI-communicator.
        natoms:
            Total number of atoms.  If not given, it is taken to be one
            more than the largest atom index found on any rank.

        >>> AtomDistribution.from_atom_indices([0, 1, 2]).rank_a
        array([0, 0, 0])
        """
        if natoms is None:
            # default=-1 lets a rank without atoms take part in the
            # collective call instead of raising ValueError in max().
            natoms = comm.max_scalar(max(atom_indices, default=-1)) + 1
        rank_a = np.zeros(natoms, int)  # type: ignore
        rank_a[atom_indices] = comm.rank
        # Every atom is non-zero on its owner only, so the sum over ranks
        # yields the owner rank for each atom.
        comm.sum(rank_a)
        return cls(rank_a, comm)

    def __repr__(self):
        return (f'AtomDistribution(ranks={self.rank_a}, '
                f'comm={self.comm.rank}/{self.comm.size})')

    def gather(self):
        """Return distribution with all atoms on rank 0 (default comm)."""
        return AtomDistribution(np.zeros(len(self.rank_a), int))

167 

168 

class AtomArrays:
    def __init__(self,
                 layout: AtomArraysLayout,
                 dims: int | Sequence[int] = (),
                 comm: MPIComm = serial_comm,
                 data: np.ndarray | None = None):
        """AtomArrays object.

        Parameters
        ----------
        layout:
            Layout-description.
        dims:
            Extra dimensions.
        comm:
            Distribute dimensions along this communicator.
        data:
            Data array for storage.  Allocated (uninitialized) with the
            array module of the layout if not given.

        Raises
        ------
        ValueError
            If ``data`` is given and has the wrong shape or dtype.
        """
        myshape = (layout.mysize,)
        domain_comm = layout.atomdist.comm
        dtype = layout.dtype

        self.myshape = myshape
        self.comm = comm
        self.domain_comm = domain_comm

        # Convert a single integer (also NumPy integers) to a tuple:
        if isinstance(dims, numbers.Integral):
            self.dims = (dims,)
        else:
            self.dims = tuple(dims)

        if self.dims:
            # Only the first extra dimension is distributed over comm.
            d1, d2 = self.my_slice()
            mydims0 = d2 - d1
            self.mydims = (mydims0,) + self.dims[1:]
        else:
            self.mydims = ()

        fullshape = self.mydims + self.myshape

        if data is not None:
            if data.shape != fullshape:
                raise ValueError(
                    f'Bad shape for data: {data.shape} != {fullshape}')
            if data.dtype != dtype:
                raise ValueError(
                    f'Bad dtype for data: {data.dtype} != {dtype}')
        else:
            data = layout.xp.empty(fullshape, dtype)

        self.data = data
        self._matrix: Matrix | None = None  # matrix view

        self.layout = layout
        # Per-atom arrays obtained by slicing and reshaping self.data:
        self._arrays = {}
        for a, I1, I2 in layout.myindices:
            self._arrays[a] = self.data[..., I1:I2].reshape(
                self.mydims + layout.shape_a[a])

    def __len__(self) -> int:
        """Return the total number of atoms (on all ranks)."""
        return len(self.layout)

    def my_slice(self) -> tuple[int, int]:
        """Return start and stop of this rank's part of the first dimension.

        The first extra dimension is split over ``self.comm`` in blocks of
        ``ceil(dims[0] / comm.size)``; both values are clipped to
        ``dims[0]``.
        """
        mydims0 = (self.dims[0] + self.comm.size - 1) // self.comm.size
        d1 = min(self.comm.rank * mydims0, self.dims[0])
        d2 = min((self.comm.rank + 1) * mydims0, self.dims[0])
        return d1, d2

    def __repr__(self):
        txt = f'AtomArrays({self.layout}, dims={self.dims}'
        if self.comm.size > 1:
            txt += f', comm={self.comm.rank}/{self.comm.size}'
        return txt + ')'

    @property
    def matrix(self) -> Matrix:
        """Matrix view of the data.

        The Matrix object is created from a reshaped ``self.data`` on
        first access and cached.  At least one extra dimension is needed.
        """
        if self._matrix is not None:
            return self._matrix

        shape = (self.dims[0], prod(self.dims[1:]) * prod(self.myshape))
        myshape = (self.mydims[0], prod(self.mydims[1:]) * prod(self.myshape))
        dist = (self.comm, -1, 1)

        data = self.data.reshape(myshape)
        self._matrix = Matrix(*shape, data=data, dist=dist)

        return self._matrix

    def new(self, *, layout=None, data=None, xp=None):
        """Create new AtomArrays object of same kind.

        The extra dimensions and the communicator are kept.

        Parameters
        ----------
        layout:
            Layout-description (defaults to the layout of this object).
        data:
            Array to use for storage.
        xp:
            Array module for the new object.  Can not be combined with
            ``layout`` or ``data``.
        """
        if xp is not None:
            assert layout is None
            assert data is None
            if self.layout.xp is not xp:
                layout = self.layout.new(xp=xp)
        if layout is None:
            # Explicit test: AtomArraysLayout defines __len__, so a layout
            # without atoms is falsy and ``layout or self.layout`` would
            # discard it.
            layout = self.layout
        return AtomArrays(layout,
                          self.dims,
                          self.comm,
                          data=data)

    def to_cpu(self):
        """Return object with data in a numpy array (self if already so)."""
        if self.layout.xp is np:
            return self
        return self.new(layout=self.layout.new(xp=np),
                        data=cp.asnumpy(self.data))

    def to_xp(self, xp):
        """Return object with data converted to the array module ``xp``.

        ``self`` is returned when the data already lives in numpy and
        numpy is requested.
        """
        if self.layout.xp is xp:
            assert xp is np, 'cp -> cp should not be needed!'
            return self
        if xp is np:
            return self.new(layout=self.layout.new(xp=np),
                            data=cp.asnumpy(self.data))
        return self.new(layout=self.layout.new(xp=cp),
                        data=cp.asarray(self.data))

    @overload
    def __getitem__(self, a: int) -> np.ndarray:
        ...

    @overload
    def __getitem__(self, a: tuple) -> AtomArrays:
        ...

    def __getitem__(self, a):
        """Return array for atom ``a`` or a sub-selection for all atoms.

        With an integer, the array of that atom is returned (KeyError if
        the atom is not on this rank).  With a tuple ``(slice(None), x)``,
        ``x`` is used to index the first axis of the data and a new
        AtomArrays object sharing the layout is returned.
        """
        if isinstance(a, numbers.Integral):
            return self._arrays[a]
        assert len(self.dims) >= 1
        a0, a1 = a
        assert a0 == slice(None)
        data = self.data[a1]
        a_ai = AtomArrays(self.layout, dims=data.shape[:-1], data=data)
        return a_ai

    def copy(self):
        """Return new object with a copy of the data."""
        return self.new(data=self.data.copy())

    def get(self, a):
        """Return array for atom ``a`` or None if not on this rank."""
        return self._arrays.get(a)

    def __setitem__(self, a, value):
        """Copy ``value`` into the array of atom ``a``."""
        self._arrays[a][:] = value

    def __contains__(self, a):
        """Tell whether atom ``a`` is stored on this rank."""
        return a in self._arrays

    def items(self):
        """Return (atom index, array) pairs for the atoms on this rank."""
        return self._arrays.items()

    def keys(self):
        """Return indices of the atoms on this rank."""
        return self._arrays.keys()

    def values(self):
        """Return arrays of the atoms on this rank."""
        return self._arrays.values()

    @overload
    def gather(self, broadcast: Literal[False] = False, copy: bool = False
               ) -> AtomArrays | None:
        ...

    @overload
    def gather(self, broadcast: Literal[True], copy: bool = False
               ) -> AtomArrays:
        ...

    def gather(self, broadcast=False, copy=False):
        """Gather all atoms on master.

        Parameters
        ----------
        broadcast:
            Also send the gathered data from rank 0 to all other ranks of
            the atom-distribution communicator.
        copy:
            Only looked at when the atom-distribution communicator has
            size 1: return a new object with a copy of the data instead
            of ``self``.

        Returns
        -------
        The gathered AtomArrays object.  Without ``broadcast``, ranks
        other than 0 return None.
        """
        comm = self.layout.atomdist.comm
        if comm.size == 1:
            if copy:
                aa = self.new()
                aa.data[:] = self.data
                return aa
            return self

        if comm.rank == 0 or broadcast:
            aa = self.new(layout=self.layout.new(atomdist=serial_comm))
        else:
            aa = None

        if comm.rank == 0:
            size_ra, size_r = self.layout.sizes()
            n = prod(self.mydims)
            m = size_r.max()
            # One receive buffer, big enough for the largest rank:
            buffer = self.layout.xp.empty(n * m, self.layout.dtype)
            for rank in range(1, comm.size):
                buf = buffer[:n * size_r[rank]].reshape((n, size_r[rank]))
                comm.receive(buf, rank)
                # Unpack the atoms of this rank one after the other:
                b1 = 0
                for a, size in size_ra[rank].items():
                    b2 = b1 + size
                    A = aa[a]
                    A[:] = buf[:, b1:b2].reshape(A.shape)
                    b1 = b2
            # Atoms that were already on rank 0:
            for a, array in self._arrays.items():
                aa[a] = array
        else:
            comm.send(self.data, 0)

        if broadcast:
            comm.broadcast(aa.data, 0)

        return aa

    def scatter_from(self,
                     data: np.ndarray | AtomArrays | None = None) -> None:
        """Scatter atoms.

        Parameters
        ----------
        data:
            Data for all atoms: an array or an AtomArrays object (its
            ``data`` attribute is used).  It is not looked at on ranks
            other than 0 of the atom-distribution communicator.
        """
        if isinstance(data, AtomArrays):
            data = data.data
        comm = self.layout.atomdist.comm
        xp = self.layout.xp
        if comm.size == 1:
            assert data is not None
            self.data[:] = data
            return

        if comm.rank != 0:
            comm.receive(self.data, 0, 42)
            return

        size_ra, size_r = self.layout.sizes()
        aa = self.new(layout=self.layout.new(atomdist=serial_comm),
                      data=data)
        requests = []
        for rank, (totsize, size_a) in enumerate(zips(size_r, size_ra)):
            if rank != 0:
                # Pack the atoms of this rank into one send buffer:
                buf = xp.empty(self.mydims + (totsize,), self.layout.dtype)
                b1 = 0
                for a, size in size_a.items():
                    b2 = b1 + size
                    buf[..., b1:b2] = aa[a].reshape(self.mydims + (size,))
                    b1 = b2
                request = comm.send(buf, rank, 42, False)
                # Remember to store a reference to the
                # send buffer (buf) so that it isn't
                # deallocated
                requests.append((request, buf))
            else:
                for a in size_a:
                    self[a] = aa[a]

        for request, _ in requests:
            comm.wait(request)

    def to_lower_triangle(self):
        """Convert `N*N` matrices to `N*(N+1)/2` vectors.

        Only the lower triangle (including the diagonal) of each matrix
        is read.

        >>> a = AtomArraysLayout([(3, 3)]).empty()
        >>> a[0][:] = [[11, 12, 13],
        ...            [12, 22, 23],
        ...            [13, 23, 33]]
        >>> a.to_lower_triangle()[0]
        array([11., 12., 22., 13., 23., 33.])
        """
        shape_a = []
        for i1, i2 in self.layout.shape_a:
            assert i1 == i2
            shape_a.append((i1 * (i1 + 1) // 2,))
        xp = self.layout.xp
        layout = AtomArraysLayout(shape_a,
                                  self.layout.atomdist,
                                  dtype=self.layout.dtype,
                                  xp=xp)
        # NOTE(review): self.comm is not passed on here - presumably only
        # used with dims that are not distributed; confirm with callers.
        a_axp = layout.empty(self.dims)
        for a_xii, a_xp in zips(self.values(), a_axp.values()):
            i = a_xii.shape[-1]
            L = xp.tril_indices(i)
            for a_p, a_ii in zips(a_xp.reshape((-1, i * (i + 1) // 2)),
                                  a_xii.reshape((-1, i, i))):
                a_p[:] = a_ii[L]

        return a_axp

    def to_full(self):
        r"""Convert `N(N+1)/2` vectors to `N\times N` matrices.

        The vectors fill the lower triangles (including the diagonal) and
        the upper triangles are set to the complex conjugate of the
        transposed lower triangles.  The result uses the same array
        module as this object.

        >>> a = AtomArraysLayout([6]).empty()
        >>> a[0][:] = [1, 2, 3, 4, 5, 6]
        >>> a.to_full()[0]
        array([[1., 2., 4.],
               [2., 3., 5.],
               [4., 5., 6.]])
        """
        shape_a = []
        for (p,) in self.layout.shape_a:
            # Solve p = i * (i + 1) / 2 for i:
            i = int((2 * p + 0.25)**0.5)
            shape_a.append((i, i))
        xp = self.layout.xp
        layout = AtomArraysLayout(shape_a,
                                  self.layout.atomdist,
                                  dtype=self.layout.dtype,
                                  xp=xp)
        a_axii = layout.empty(self.dims)
        for a_xp, a_xii in zips(self.values(), a_axii.values()):
            i = a_xii.shape[-1]
            rows, cols = xp.tril_indices(i)
            # Write the conjugated, transposed elements first.  This also
            # hits the diagonal, which is then overwritten with the
            # unconjugated values when the lower triangle is filled in.
            a_xii[(..., cols, rows)] = a_xp.conj()
            a_xii[(..., rows, cols)] = a_xp
        return a_axii

    def moved(self, atomdist):
        """Return object with atoms distributed according to ``atomdist``.

        ``self`` is returned if the ranks of all atoms are unchanged.
        Only implemented for the case where the extra dimensions are not
        distributed (``self.comm.size == 1``).
        """
        if (self.layout.atomdist.rank_a == atomdist.rank_a).all():
            return self
        assert self.comm.size == 1
        layout = self.layout.new(atomdist=atomdist)
        new = layout.empty(self.dims)
        comm = atomdist.comm
        xp = self.layout.xp
        requests = []
        # Atoms currently on this rank: copy or send (non-blocking).
        for a, _, _ in self.layout.myindices:
            r = layout.atomdist.rank_a[a]
            if r == comm.rank:
                new[a][:] = self[a]
            else:
                requests.append(comm.send(xp.ascontiguousarray(self[a]),
                                          r, block=False))

        # Atoms that will be on this rank but are currently elsewhere:
        for a, _, _ in layout.myindices:
            r = self.layout.atomdist.rank_a[a]
            if r != comm.rank:
                target = new[a]
                buf = xp.empty_like(target)
                comm.receive(buf, r)
                target[:] = buf

        comm.waitall(requests)
        return new

    def redist(self,
               atomdist: AtomDistribution,
               comm1: MPIComm,
               comm2: MPIComm) -> AtomArrays:
        """Return object redistributed according to ``atomdist``.

        Ranks with ``comm1.rank == 0`` gather the atoms, ranks with
        ``comm2.rank == 0`` scatter them using the new layout and the
        result is finally broadcast over ``comm2`` from rank 0.
        """
        layout = self.layout.new(atomdist=atomdist)
        result = layout.empty(dims=self.dims)
        if comm1.rank == 0:
            a = self.gather()
        else:
            a = None
        if comm2.rank == 0:
            result.scatter_from(a)
        comm2.broadcast(result.data, 0)
        return result

    def block_diag_multiply(self,
                            block_diag_matrix_axii: AtomArrays,
                            out_ani: AtomArrays,
                            index: int | None = None) -> None:
        """Multiply by block diagonal matrix.

        with A, B and C referring to ``self``, ``block_diag_matrix_axii``
        and ``out_ani``::

          --  a    a      a
          >  A    B   -> C
          --  ni   ij     nj
          i

        If index is not None, ``block_diag_matrix_axii`` must have an extra
        dimension: :math:`B_{ij}^{ax}` and x=index is used.

        The result is written to ``out_ani``.
        """
        xp = self.layout.xp
        if xp is np:
            if index is not None:
                block_diag_matrix_axii = block_diag_matrix_axii[:, index]
            for P_ni, dX_ii, out_ni in zips(self.values(),
                                            block_diag_matrix_axii.values(),
                                            out_ani.values()):
                out_ni[:] = P_ni @ dX_ii
            return

        # Not numpy: hand the flat data arrays and the number of elements
        # per atom to dH_aii_times_P_ani_gpu.
        ni_a = xp.array(
            [I2 - I1 for a, I1, I2 in self.layout.myindices],
            dtype=np.int32)
        data = block_diag_matrix_axii.data
        if index is not None:
            data = data[index]
        if self.data.size > 0:
            realdtype = as_real_dtype(self.data.dtype)
            dH_aii_times_P_ani_gpu(xp.asarray(data, dtype=realdtype),
                                   ni_a, self.data, out_ani.data)