Coverage for gpaw/grid_descriptor.py: 74%

383 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-08 00:17 +0000

1# Copyright (C) 2003 CAMP 

2# Please see the accompanying LICENSE file for further information. 

3 

"""Grid-descriptors

This module contains classes defining uniform 3D grids.
For radial grid descriptors, look at atom/radialgd.py.

"""

10 

11import numbers 

12from math import pi 

13from typing import Sequence 

14from numpy import lcm 

15from fractions import Fraction 

16 

17import numpy as np 

18 

19from scipy.ndimage import map_coordinates 

20 

21import gpaw.cgpaw as cgpaw 

22import gpaw.mpi as mpi 

23from gpaw.domain import Domain 

24from gpaw.new import prod 

25from gpaw.typing import Array1D, Array3D, Vector 

26from gpaw.utilities.blas import mmm, r2k, rk 

27 

# Passed positionally as the fourth argument of ``comm.send`` in
# GridDescriptor.distribute(), which afterwards waits on the returned
# requests.  NOTE(review): presumably that argument is the ``block``
# flag, so False selects non-blocking sends — confirm in gpaw.mpi.
NONBLOCKING = False

29 

30 

class GridBoundsError(ValueError):
    """Raised by GridDescriptor.get_boxes(cut=False) when a box crosses
    a non-periodic boundary of the cell."""

33 

34 

class BadGridError(ValueError):
    """Raised when a grid has too few points for the requested domain
    decomposition (some domains would be empty)."""

37 

38 

class GridDescriptor(Domain):
    r"""Descriptor-class for uniform 3D grid

    A ``GridDescriptor`` object holds information on how functions, such
    as wave functions and electron densities, are discretized in a
    certain domain in space. The main information here is how many
    grid points are used in each direction of the unit cell.

    There are methods for tasks such as allocating arrays, performing
    symmetry operations and integrating functions over space. All
    methods work correctly also when the domain is parallelized via
    domain decomposition.

    This is how a 2x2x2 3D array is laid out in memory::

        3-----7
        |\    |\
        | \   | \
        |  1-----5      z
        2--|--6  |   y  |
         \ |   \ |    \ |
          \|    \|     \|
           0-----4      +-----x

    Example:

    >>> a = np.zeros((2, 2, 2))
    >>> a.ravel()[:] = range(8)
    >>> a
    array([[[0., 1.],
            [2., 3.]],
    <BLANKLINE>
           [[4., 5.],
            [6., 7.]]])
    """

    ndim = 3  # dimension of ndarrays

76 

def __init__(self, N_c, cell_cv=[1, 1, 1], pbc_c=True,
             comm=None, parsize_c=None, allow_empty_domains=False):
    """Construct grid-descriptor object.

    parameters:

    N_c: 3 ints
        Number of grid points along axes.
    cell_cv: 3 float's or 3x3 floats
        Unit cell.
    pbc_c: one or three bools
        Periodic boundary conditions flag(s).
    comm: MPI-communicator
        Communicator for domain-decomposition (default: gpaw.mpi.world).
    parsize_c: tuple of 3 ints, a single int or None
        Number of domains.
    allow_empty_domains: bool
        Allow parallelization that would generate empty domains.

    Raises ValueError if N_c contains non-integer values, and
    BadGridError if some domain would be empty and allow_empty_domains
    is False.

    Note that if pbc_c[c] is False, then the actual number of gridpoints
    along axis c is one less than N_c[c].

    Attributes:

    ============== ====================================================
    ``dv``         Volume per grid point.
    ``volume``     Absolute volume of the unit cell.
    ``h_cv``       Grid step vectors (row c is cell_cv[c] / N_c[c]).
    ``N_c``        Array of the number of grid points along the axes.
    ``n_c``        Number of grid points on this CPU.
    ``beg_c``      Beginning of grid-point indices (inclusive).
    ``end_c``      End of grid-point indices (exclusive).
    ``n_cp``       Per axis, the domain boundaries (domain p along axis
                   c covers indices n_cp[c][p] to n_cp[c][p + 1]).
    ``orthogonal`` True if all off-diagonal elements of cell_cv are 0.
    ``comm``       MPI-communicator for domain decomposition.
    ============== ====================================================

    The length unit is Bohr.
    """

    # A single flag applies to all three axes (bools are ints too).
    if isinstance(pbc_c, int):
        pbc_c = (pbc_c,) * 3
    if comm is None:
        comm = mpi.world

    self.N_c = np.array(N_c, int)
    # The int conversion above truncates silently; reject e.g. 12.5.
    if (self.N_c != N_c).any():
        raise ValueError('Non-int number of grid points %s' % N_c)

    # NOTE(review): cell_cv, pbc_c, comm, parsize_c and parpos_c used
    # below are not assigned here; presumably Domain.__init__ sets them.
    Domain.__init__(self, cell_cv, pbc_c, comm, parsize_c, self.N_c)
    self.rank = self.comm.rank

    self.beg_c = np.empty(3, int)
    self.end_c = np.empty(3, int)

    self.n_cp = []
    for c in range(3):
        # Split N_c[c] points as evenly as possible over parsize_c[c]
        # domains; the +0.4999 makes np.around (approximately) round
        # non-integer boundaries up.
        n_p = (np.arange(self.parsize_c[c] + 1) * float(self.N_c[c]) /
               self.parsize_c[c])
        n_p = np.around(n_p + 0.4999).astype(int)

        # Along a non-periodic axis index 0 is not stored (one grid
        # point fewer, see note above), so the first domain starts at 1.
        if not self.pbc_c[c]:
            n_p[0] = 1

        # Two equal consecutive boundaries means an empty domain.
        if np.any(n_p[1:] == n_p[:-1]):
            if allow_empty_domains:
                # If there are empty domains, sort them to the end:
                # give every leading domain one point, e.g. N=3 points,
                # 5 domains, periodic -> boundaries 0,1,2,3,3,3.
                n_p[:] = (np.arange(self.parsize_c[c] + 1) +
                          1 - self.pbc_c[c]).clip(0, self.N_c[c])
            else:
                msg = ('Grid {} too small for {} cores!'
                       .format('x'.join(str(n) for n in self.N_c),
                               'x'.join(str(n) for n in self.parsize_c)))
                raise BadGridError(msg)

        # This rank's slab along axis c.
        self.beg_c[c] = n_p[self.parpos_c[c]]
        self.end_c[c] = n_p[self.parpos_c[c] + 1]
        self.n_cp.append(n_p)

    self.n_c = self.end_c - self.beg_c

    # Row c of h_cv is the step vector along axis c.
    self.h_cv = self.cell_cv / self.N_c[:, np.newaxis]
    self.volume = abs(np.linalg.det(self.cell_cv))
    self.dv = self.volume / self.N_c.prod()

    # Exact test: every off-diagonal element must be zero.
    self.orthogonal = not (self.cell_cv -
                           np.diag(self.cell_cv.diagonal())).any()

161 

def __repr__(self):
    """Return a one-line summary of grid size, cell, PBC and MPI layout.

    For orthogonal cells only the diagonal of cell_cv is shown.
    """
    cell = (np.diag(self.cell_cv) if self.orthogonal
            else self.cell_cv).tolist()
    position = tuple(self.get_processor_position_from_rank())
    pbc = np.array(self.pbc_c).astype(int).tolist()
    template = ('GridDescriptor(%s, cell_cv=%s, pbc_c=%s, comm=[%d/%d, '
                'domain=%s], parsize=%s)')
    values = (self.N_c.tolist(), cell, pbc, self.comm.rank,
              self.comm.size, position, self.parsize_c.tolist())
    return template % values

174 

def new_descriptor(self, N_c=None, cell_cv=None, pbc_c=None,
                   comm=None, parsize_c=None, allow_empty_domains=False):
    """Create a new descriptor derived from this one.

    The result is built with ``self.__class__``, so subclasses are
    preserved.  Each argument left as None is taken from this
    descriptor; ``parsize_c`` is only inherited when the (possibly new)
    communicator has the same size as the current one.
    """
    if comm is None:
        comm = self.comm
    if parsize_c is None and comm.size == self.comm.size:
        parsize_c = self.parsize_c
    return self.__class__(self.N_c if N_c is None else N_c,
                          self.cell_cv if cell_cv is None else cell_cv,
                          self.pbc_c if pbc_c is None else pbc_c,
                          comm,
                          parsize_c,
                          allow_empty_domains)

194 

def coords(self, c, pad=True):
    """Return the coordinates of the grid points along axis ``c``.

    Coordinates are distances along cell vector ``c``; the axis length
    is the norm of ``cell_cv[c]``.  Along a non-periodic axis with
    ``pad=False`` the first point (at 0) is left out.

    Useful for plotting::

        import matplotlib.pyplot as plt
        plt.plot(gd.coords(0), data[:, 0, 0])
        plt.show()

    """
    length = np.linalg.norm(self.cell_cv[c])
    npoints = self.N_c[c]
    spacing = length / npoints
    include_origin = self.pbc_c[c] or pad
    first = (1 - include_origin) * spacing
    count = npoints - 1 + include_origin
    return np.linspace(first, length, count, endpoint=False)

210 

def get_grid_spacings(self):
    """Return the grid spacing along each axis.

    For axis c this is the distance between the lattice planes spanned
    by the two other cell vectors (1 / |column c of inv(cell_cv)|),
    divided by N_c[c].
    """
    inverse_cell = np.linalg.inv(self.cell_cv)
    column_norm_sq_c = (inverse_cell ** 2).sum(axis=0)
    plane_distance_c = column_norm_sq_c ** -0.5
    return plane_distance_c / self.N_c

214 

def get_size_of_global_array(self, pad=False):
    """Return the shape of a global (undistributed) array.

    With ``pad=True`` this is N_c; otherwise one point fewer along every
    non-periodic axis.
    """
    if not pad:
        return self.N_c - 1 + self.pbc_c
    return self.N_c

220 

def flat_index(self, G_c):
    """Return the C-order flat index of global grid point ``G_c`` within
    this domain's local array."""
    i0, i1, i2 = G_c - self.beg_c
    _, n1, n2 = self.n_c
    return (i0 * n1 + i1) * n2 + i2

224 

def get_slice(self):
    """Return slices selecting this domain's part of a global unpadded
    array.  Indices are shifted down by one along non-periodic axes,
    where the global array has one point fewer."""
    slices = []
    for beg, end, periodic in zip(self.beg_c, self.end_c, self.pbc_c):
        slices.append(slice(beg - 1 + periodic, end - 1 + periodic))
    return slices

228 

def zeros(self, n=(), dtype=float, global_array=False, pad=False, xp=np):
    """Return a new zero-filled 3D array for this domain.

    n: int or tuple of ints
        Extra leading dimensions (default: none).
    dtype:
        Element type (default ``float``).
    global_array: bool
        Allocate an array spanning all domains instead of only the local
        part.
    pad: bool
        For global arrays, use the padded shape N_c.
    xp:
        Array module used for the allocation (default: numpy).
    """
    fill_with_zeros = True
    return self._new_array(n, dtype, fill_with_zeros, global_array, pad, xp)

238 

def empty(self, n=(), dtype=float, global_array=False, pad=False, xp=np):
    """Return a new uninitialized 3D array for this domain.

    n: int or tuple of ints
        Extra leading dimensions (default: none).
    dtype:
        Element type (default ``float``).
    global_array: bool
        Allocate an array spanning all domains instead of only the local
        part.
    pad: bool
        For global arrays, use the padded shape N_c.
    xp:
        Array module used for the allocation (default: numpy).
    """
    fill_with_zeros = False
    return self._new_array(n, dtype, fill_with_zeros, global_array, pad, xp)

248 

def _new_array(self, n=(), dtype=float, zero=True,
               global_array=False, pad=False, xp=np):
    """Allocate an array whose trailing three axes are the grid.

    n: int or sequence of ints
        Extra leading dimensions.  Any sequence (tuple, list, shape of
        another array) is accepted.
    dtype:
        Element type.
    zero: bool
        Fill with zeros (``xp.zeros``) instead of ``xp.empty``.
    global_array: bool
        Use the global shape from get_size_of_global_array(pad) instead
        of the local shape n_c.
    pad: bool
        Passed to get_size_of_global_array() for global arrays.
    xp:
        Array module providing ``zeros``/``empty`` (default: numpy).
    """
    if global_array:
        grid_shape = self.get_size_of_global_array(pad)
    else:
        grid_shape = self.n_c

    if isinstance(n, numbers.Integral):
        n = (n,)

    # tuple(n) lets callers pass a list (e.g. [2, 3]) as well as a tuple;
    # previously a list raised TypeError on concatenation.
    shape = tuple(n) + tuple(grid_shape)

    if zero:
        return xp.zeros(shape, dtype)
    return xp.empty(shape, dtype)

265 

def get_axial_communicator(self, axis):
    """Return a communicator for the line of domains along ``axis`` that
    contains this rank.

    The member ranks share this rank's processor position on the two
    other axes and are ordered by their position along ``axis``.
    """
    position_c = self.parpos_c.copy()

    def rank_at(i):
        # Move the copied position along `axis` and look up its rank.
        position_c[axis] = i
        return self.get_rank_from_processor_position(position_c)

    ranks = [rank_at(i) for i in range(self.parsize_c[axis])]
    return self.comm.new_communicator(ranks)

274 

def integrate(self, a_xg, b_yg=None,
              global_integral=True, hermitian=False):
    """Integrate function(s) over domain.

    a_xg: ndarray
        Function(s) to be integrated.  The last three axes are the grid;
        any leading axes (``x``) index separate functions.
    b_yg: ndarray
        If present, integrate a_xg.conj() * b_yg.
    global_integral: bool
        If the array(s) are distributed over several domains, then the
        total sum will be returned. To get the local contribution
        only, use global_integral=False.
    hermitian: bool
        Result is hermitian.

    Returns:
        Without b_yg: an array of shape ``a_xg.shape[:-3]``.  With no
        leading axes this is a scalar (the value returned by
        comm.sum_scalar when global_integral is True, otherwise a NumPy
        scalar).
        With b_yg: an array of shape
        ``a_xg.shape[:-3] + b_yg.shape[:-3]``, or a Python scalar (via
        ``.item()``) when that shape is empty.
    """

    xshape = a_xg.shape[:-3]

    if b_yg is None:
        # Only one array:
        result = a_xg.reshape(xshape + (-1,)).sum(axis=-1) * self.dv
        if global_integral:
            if result.ndim == 0:
                # Scalars go through sum_scalar; arrays are summed in
                # place by comm.sum.
                result = self.comm.sum_scalar(result.item())
            else:
                self.comm.sum(result)
        return result

    # Flatten both sets of functions to (nfunctions, npoints) matrices.
    gsize = prod(a_xg.shape[-3:])
    A_xg = np.ascontiguousarray(a_xg.reshape((-1, gsize)))
    B_yg = np.ascontiguousarray(b_yg.reshape((-1, gsize)))

    # NOTE(review): the buffer dtype follows a_xg only; a real a_xg
    # combined with a complex b_yg would get a real buffer — confirm
    # callers never mix dtypes.
    result_yx = np.zeros((len(B_yg), len(A_xg)), A_xg.dtype)

    if a_xg is b_yg:
        # Same array on both sides.  Presumably rk performs a rank-k
        # update dv * A A^H — confirm in gpaw.utilities.blas, including
        # which triangle of result_yx gets filled.
        rk(self.dv, A_xg, 0.0, result_yx)
    elif hermitian:
        # Caller promises a hermitian result.  Presumably r2k forms
        # 0.5 * dv * (A B^H + B A^H) — confirm in gpaw.utilities.blas.
        r2k(0.5 * self.dv, A_xg, B_yg, 0.0, result_yx)
    else:
        # gemm(self.dv, A_xg, B_yg, 0.0, result_yx, 'c')
        # Presumably result_yx = dv * B_yg @ A_xg^H ('C' = conjugate
        # transpose) — confirm mmm's argument convention.
        mmm(self.dv, B_yg, 'N', A_xg, 'C', 0.0, result_yx)

    if global_integral:
        self.comm.sum(result_yx)

    # Transpose to (x, y) order and restore the leading shapes.
    yshape = b_yg.shape[:-3]
    result = result_yx.T.reshape(xshape + yshape)

    if result.ndim == 0:
        return result.item()
    else:
        return result

327 

def coarsen(self):
    """Return a coarsened `GridDescriptor` object.

    The returned descriptor has half as many grid points along each axis
    (2x2x2 fewer in total).  Raises ValueError if any entry of N_c is
    odd.
    """
    if np.any(self.N_c % 2 != 0):
        raise ValueError('Grid %s not divisible by 2!' % self.N_c)
    half_c = self.N_c // 2
    return self.new_descriptor(half_c)

337 

def refine(self):
    """Return a refined `GridDescriptor` object.

    The returned descriptor has twice as many grid points along each
    axis (2x2x2 more in total).
    """
    doubled_c = 2 * self.N_c
    return self.new_descriptor(doubled_c)

343 

def get_boxes(self, spos_c, rcut, cut=True):
    """Find boxes enclosing sphere.

    spos_c: scaled (fractional) position of the sphere centre.
    rcut: sphere radius.
    cut: if True, clip the box at non-periodic boundaries; if False,
        raise GridBoundsError when it crosses one.

    Returns a list of ``(beg_c, end_c, disp)`` tuples.  ``beg_c`` and
    ``end_c`` are global grid indices (begin inclusive, end exclusive)
    restricted to this domain.  ``disp`` is the periodic-image offset of
    that piece, in units of the cell vectors.  Only non-empty boxes are
    returned.
    """
    N_c = self.N_c
    # Sphere radius in grid-index units along each axis.
    # NOTE(review): assumes icell_cv (set outside this class, presumably
    # in Domain) is the inverse cell, so its row norms are 1 / plane
    # spacing — confirm.
    ncut = rcut * (self.icell_cv**2).sum(axis=1)**0.5 * self.N_c
    npos_c = spos_c * N_c
    # Unwrapped index range [beg_c, end_c) covering the sphere; may
    # extend below 0 or beyond N_c.
    beg_c = np.ceil(npos_c - ncut).astype(int)
    end_c = np.ceil(npos_c + ncut).astype(int)

    if cut:
        # Clip to the cell along non-periodic axes.
        for c in range(3):
            if not self.pbc_c[c]:
                if beg_c[c] < 0:
                    beg_c[c] = 0
                if end_c[c] > N_c[c]:
                    end_c[c] = N_c[c]
    else:
        # Refuse boxes that cross a non-periodic boundary.
        for c in range(3):
            if (not self.pbc_c[c] and
                (beg_c[c] < 0 or end_c[c] > N_c[c])):
                msg = ('Box at %.3f %.3f %.3f crosses boundary. '
                       'Beg. of box %s, end of box %s, max box size %s' %
                       (tuple(spos_c) + (beg_c, end_c, self.N_c)))
                raise GridBoundsError(msg)

    # range_c[c]: unwrapped (begin, end) index pairs along axis c, each
    # lying within a single periodic image and intersected with this
    # domain's [self.beg_c[c], self.end_c[c]).
    range_c = ([], [], [])

    for c in range(3):
        b = beg_c[c]
        e = b

        while e < end_c[c]:
            # b wrapped into [0, N_c[c]).
            b0 = b % N_c[c]

            # End of the current periodic image (or of the box).
            e = min(end_c[c], b + N_c[c] - b0)

            # Move the start up to this domain's first index.
            if b0 < self.beg_c[c]:
                b1 = b + self.beg_c[c] - b0
            else:
                b1 = b

            # e expressed in the same wrapped frame as b0.
            e0 = b0 - b + e

            # Move the end down to this domain's last index.
            if e0 > self.end_c[c]:
                e1 = e - (e0 - self.end_c[c])
            else:
                e1 = e
            if e1 > b1:
                range_c[c].append((b1, e1))
            b = e

    boxes = []

    # Combine the per-axis pieces into 3D boxes.
    for b0, e0 in range_c[0]:
        for b1, e1 in range_c[1]:
            for b2, e2 in range_c[2]:
                b = np.array((b0, b1, b2))
                e = np.array((e0, e1, e2))
                # Wrap the start into the cell; disp records the shift
                # in units of cell vectors.
                beg_c = np.array((b0 % N_c[0], b1 % N_c[1], b2 % N_c[2]))
                end_c = beg_c + e - b
                disp = (b - beg_c) / N_c
                # Intersect with this domain and keep non-empty boxes.
                beg_c = np.maximum(beg_c, self.beg_c)
                end_c = np.minimum(end_c, self.end_c)
                if (beg_c[0] < end_c[0] and
                    beg_c[1] < end_c[1] and
                    beg_c[2] < end_c[2]):
                    boxes.append((beg_c, end_c, disp))

    return boxes

412 

def get_nearest_grid_point(self, spos_c, force_to_this_domain=False):
    """Return the index of the grid point nearest to ``spos_c``, relative
    to this domain's first index (``beg_c``).

    The nearest grid point can be on a different CPU than the one the
    nucleus belongs to (i.e. the result can be negative, or not smaller
    than n_c), in which case something clever should be done.  With
    ``force_to_this_domain=True`` the point is clamped into this domain,
    which is consistent with self.get_rank_from_position(spos_c).
    """
    g_c = np.around(self.N_c * spos_c).astype(int)
    if force_to_this_domain:
        # Same as max(g, beg) followed by min(g, end - 1) per axis.
        g_c = np.clip(g_c, self.beg_c, self.end_c - 1)
    return g_c - self.beg_c

428 

def plane_wave(self, k_c):
    """Evaluate the plane wave exp(i k.r) on this domain's grid points.

    ``k_c`` is the wave vector in units of the reciprocal lattice
    vectors.  The result has shape ``n_c``.
    """
    grid_index = np.indices(self.n_c).T + self.beg_c
    phase = np.dot(grid_index, k_c / self.N_c).T
    return np.exp(2j * pi * phase)

443 

def symmetrize(self, a_g, op_scc, ft_sc=None):
    """Symmetrize a_g in place by averaging over symmetry operations.

    a_g: this rank's part of the grid array; overwritten with the
        result.
    op_scc: sequence of 3x3 matrices, one per symmetry operation
        (presumably acting on scaled/grid coordinates — confirm in
        cgpaw.symmetrize).
    ft_sc: optional fractional translations, one 3-vector per
        operation, in units of the cell vectors (ft_sc[s] * N_c is
        the grid shift).  An all-zero array is treated as None.

    Raises ValueError if ft_sc * N_c is not integer (to within
    check_grid_compatibility's tolerance); the message suggests the
    nearest compatible grid.

    The full array is collected on rank 0, transformed there with cgpaw
    and redistributed to all ranks.
    """
    if len(op_scc) == 1:
        # Single operation (presumably the identity): nothing to do.
        return

    if ft_sc is not None and not ft_sc.any():
        ft_sc = None

    if ft_sc is not None:
        compat = self.check_grid_compatibility(ft_sc)
        if not compat:
            newN_c = self.get_nearest_compatible_grid(ft_sc)
            e = 'The specified number of grid points, ' \
                + str(self.N_c) + ', is not compatible with the'\
                ' symmetry of the atoms. Nearest compatible grid'\
                ' size is ' + str(newN_c) + '.'
            raise ValueError(e)

    A_g = self.collect(a_g)
    if self.comm.rank == 0:
        # B_g accumulates the transformed copies of A_g; the division
        # by len(op_scc) below turns the sum into an average.
        B_g = np.zeros_like(A_g)
        for s, op_cc in enumerate(op_scc):
            if ft_sc is None:
                # 1 - pbc_c: per-axis non-periodic flags for the C code.
                cgpaw.symmetrize(A_g, B_g, op_cc, 1 - self.pbc_c)
            else:
                # Fractional translation as an integer grid shift.
                t_c = (ft_sc[s] * self.N_c).round().astype(int)
                cgpaw.symmetrize_ft(A_g, B_g, op_cc, t_c,
                                    1 - self.pbc_c)
    else:
        B_g = None
    # Scatter the result back into a_g on every rank.
    self.distribute(B_g, a_g)
    a_g /= len(op_scc)

477 

def check_grid_compatibility(self, ft_sc):
    """Return True if the grid is compatible with the fractional
    translations ``ft_sc``, i.e. every ``ft_sc * N_c`` is an integer to
    within ``np.allclose`` tolerance (atol=1e-6)."""
    shift_sc = ft_sc * self.N_c
    nearest_sc = shift_sc.round().astype(int)
    return np.allclose(shift_sc, nearest_sc, atol=1e-6)

484 

def get_nearest_compatible_grid(self, ft_sc):
    """Suggest the grid closest to N_c that is compatible with ft_sc.

    For each axis c the translations ft_sc[:, c] are converted to
    fractions (denominators limited to 1000).  N is compatible when it
    is a multiple of the least common multiple of those denominators.
    The nearest such multiple is returned; the larger one wins on ties
    or when the smaller one would be zero.
    """
    def lcm_of_denominators(c):
        denominators = [
            Fraction(str(ft_c[c])).limit_denominator(1000).denominator
            for ft_c in ft_sc]
        return lcm.reduce(denominators)

    best_c = []
    for c, N in enumerate(self.N_c):
        step = lcm_of_denominators(c)
        lower = N - (N % step)
        upper = lower + step
        if lower > 0 and np.abs(lower - N) < np.abs(upper - N):
            best_c.append(lower)
        else:
            best_c.append(upper)
    return np.array(best_c, dtype=float).astype(int)

498 

def collect(self, a_xg, out=None, broadcast=False):
    """Collect distributed array to master-CPU or all CPU's.

    a_xg: this rank's local part; the last three axes are the grid.
    out: optional preallocated global array, used on rank 0 and in the
        serial case.
    broadcast: if True, every rank receives the full array.

    Returns the global array on rank 0, and on all ranks when broadcast
    is True.  Other ranks get np.nan when broadcast is False.  In a
    serial run with out=None, a_xg itself is returned (no copy).

    NOTE(review): on non-root ranks ``out`` is ignored (with broadcast a
    fresh array is allocated) — confirm callers do not rely on it there.
    """
    if self.comm.size == 1:
        if out is None:
            return a_xg
        out[:] = a_xg
        return out

    xshape = a_xg.shape[:-3]

    # Collect all arrays on the master:
    if self.rank != 0:
        # There can be several sends before the corresponding receives
        # are posted, so use synchronous send here
        self.comm.ssend(a_xg, 0, 301)
        if broadcast:
            A_xg = self.empty(xshape, a_xg.dtype, global_array=True)
            self.comm.broadcast(A_xg, 0)
            return A_xg
        else:
            return np.nan

    # Put the subdomains from the slaves into the big array
    # for the whole domain:
    if out is None:
        A_xg = self.empty(xshape, a_xg.dtype, global_array=True)
    else:
        A_xg = out
    parsize_c = self.parsize_c
    # NOTE(review): assumes rank r owns domain position (n0, n1, n2) in
    # C order — presumably how Domain numbers ranks; confirm.
    r = 0
    for n0 in range(parsize_c[0]):
        # Domain boundaries made relative to rank 0's beg_c, i.e.
        # indices into the global (unpadded) array.
        b0, e0 = self.n_cp[0][n0:n0 + 2] - self.beg_c[0]
        for n1 in range(parsize_c[1]):
            b1, e1 = self.n_cp[1][n1:n1 + 2] - self.beg_c[1]
            for n2 in range(parsize_c[2]):
                b2, e2 = self.n_cp[2][n2:n2 + 2] - self.beg_c[2]
                if r != 0:
                    # Receive rank r's part; rank 0 copies its own.
                    a_xg = np.empty(xshape +
                                    ((e0 - b0), (e1 - b1), (e2 - b2)),
                                    a_xg.dtype.char)
                    self.comm.receive(a_xg, r, 301)
                A_xg[..., b0:e0, b1:e1, b2:e2] = a_xg
                r += 1
    if broadcast:
        self.comm.broadcast(A_xg, 0)
    return A_xg

545 

def distribute(self, B_xg, out=None):
    """Distribute full array B_xg to subdomains, result in out.

    B_xg is not used by the slaves (i.e. it should be None on all
    slaves).  out should be allocated on all nodes and will be
    overwritten; on the master it is allocated when None.

    NOTE(review): on non-root ranks ``out=None`` fails, because the
    allocation branch reads ``B_xg.shape`` — callers must pass out there.

    Returns out, or B_xg itself in a serial run with out=None.
    """

    if self.comm.size == 1:
        if out is None:
            return B_xg
        out[:] = B_xg
        return out

    if out is None:
        out = self.empty(B_xg.shape[:-3], dtype=B_xg.dtype)

    if self.rank != 0:
        # Non-root ranks just receive their part from rank 0.
        self.comm.receive(out, 0, 42)
    else:
        parsize_c = self.parsize_c
        requests = []
        # NOTE(review): assumes rank r owns domain position
        # (n0, n1, n2) in C order — same convention as collect().
        r = 0
        for n0 in range(parsize_c[0]):
            # Domain boundaries made relative to rank 0's beg_c, i.e.
            # indices into the global (unpadded) array.
            b0, e0 = self.n_cp[0][n0:n0 + 2] - self.beg_c[0]
            for n1 in range(parsize_c[1]):
                b1, e1 = self.n_cp[1][n1:n1 + 2] - self.beg_c[1]
                for n2 in range(parsize_c[2]):
                    b2, e2 = self.n_cp[2][n2:n2 + 2] - self.beg_c[2]
                    if r != 0:
                        a_xg = B_xg[..., b0:e0, b1:e1, b2:e2].copy()
                        request = self.comm.send(a_xg, r, 42, NONBLOCKING)
                        # Remember to store a reference to the
                        # send buffer (a_xg) so that it isn't
                        # deallocated:
                        requests.append((request, a_xg))
                    else:
                        out[:] = B_xg[..., b0:e0, b1:e1, b2:e2]
                    r += 1

        # Wait for every send to complete before the buffers are freed.
        for request, a_xg in requests:
            self.comm.wait(request)

    return out

589 

def zero_pad(self, a_xg, global_array=True):
    """Pad an array with zeros as first element along non-periodic axes.

    global_array=True: a_xg is a global unpadded array; the result has
        shape N_c in its last three axes.
    global_array=False: a_xg is this rank's local part (shape n_c); only
        ranks at processor position 0 along an axis get padded there.

    If every axis is periodic, a_xg itself is returned (no copy).
    The shape check is intentionally redundant with ``global_array`` so
    that mismatches fail with a clear assertion.
    """
    grid_shape = a_xg.shape[-3:]
    pad_c = 1 - self.pbc_c
    if global_array:
        assert (grid_shape == self.N_c - pad_c).all(), grid_shape
        padded_shape = tuple(self.N_c)
    else:
        assert (grid_shape == self.n_c).all()
        # Only pad where the domain touches the lower cell edge.
        on_lower_edge_c = self.get_processor_position_from_rank() == 0
        pad_c *= on_lower_edge_c
        padded_shape = tuple(self.n_c + pad_c)

    if self.pbc_c.all():
        return a_xg

    px, py, pz = pad_c
    b_xg = np.zeros(a_xg.shape[:-3] + padded_shape, dtype=a_xg.dtype)
    b_xg[..., px:, py:, pz:] = a_xg
    return b_xg

619 

def dipole_moment(self,
                  rho_R: Array3D,
                  center_v: Vector = None) -> Array1D:
    """Calculate dipole moment of density.

    rho_R is this rank's part of the density.  Grid indices along each
    axis are wrapped into a window of width N_c centred on ``center_v``
    (Cartesian); by default the window is [0, N_c), i.e. centred on the
    middle of the unit cell.  Returns ``-sum(r * rho) * dv`` as a
    Cartesian vector, summed over ``self.comm``.
    """
    positions_c = [np.arange(beg, end, dtype=float)
                   for beg, end in zip(self.beg_c, self.end_c)]

    if center_v is not None:
        # Centre in grid-index units, shifted to the window's start.
        corner_c = ((np.linalg.solve(self.h_cv.T, center_v) % self.N_c) -
                    self.N_c / 2)
        for pos_r, corner, N in zip(positions_c, corner_c, self.N_c):
            pos_r -= corner
            pos_r %= N
            pos_r += corner

    # Project the density onto each axis.
    rho_xy = rho_R.sum(axis=2)
    rho_xz = rho_R.sum(axis=1)
    profile_c = [rho_xy.sum(axis=1), rho_xy.sum(axis=0), rho_xz.sum(axis=0)]

    moment_c = [np.dot(pos_r, prof_r)
                for pos_r, prof_r in zip(positions_c, profile_c)]
    d_v = -np.dot(moment_c, self.h_cv) * self.dv
    self.comm.sum(d_v)
    return d_v

648 

def calculate_dipole_moment(self, rho_g, center=False, origin_c=None):
    """Calculate dipole moment of density.

    Grid indices are measured from index 0 by default, from the cell
    middle (N_c / 2) when ``center`` is True, or from ``origin_c`` (grid
    index units) when given.  ``center`` and ``origin_c`` are mutually
    exclusive.  Returns ``-sum(r * rho) * dv`` as a Cartesian vector,
    summed over ``self.comm``.
    """
    index_cz = [np.arange(beg, end)
                for beg, end in zip(self.beg_c, self.end_c)]
    if center:
        assert origin_c is None
        index_cz = [idx - 0.5 * N for idx, N in zip(index_cz, self.N_c)]
    elif origin_c is not None:
        index_cz = [index_cz[c] - origin_c[c] for c in range(3)]

    # Project the density onto each axis.
    rho_xy = rho_g.sum(axis=2)
    rho_xz = rho_g.sum(axis=1)
    profile_cz = (rho_xy.sum(axis=1), rho_xy.sum(axis=0), rho_xz.sum(axis=0))

    weighted_c = [np.dot(idx, prof)
                  for idx, prof in zip(index_cz, profile_cz)]
    d_c = -np.dot(weighted_c, self.h_cv) * self.dv
    self.comm.sum(d_c)
    return d_c

665 

def wannier_matrix(self, psit_nG, psit_nG1, G_c, nbands=None):
    """Wannier localization integrals

    The soft part of Z is given by (Eq. 27 ref1)::

        Z_nm = < psit_n | exp(-i G.r) | psit_m >

    psit_nG and psit_nG1 are the set of wave functions for the two
    different spin/kpoints in question; G_c is given in units of the
    reciprocal lattice vectors.  Only the first ``nbands`` bands are
    used (default: all of psit_nG).  The result is the local
    contribution, an (nbands, nbands) complex array.

    ref1: Thygesen et al, Phys. Rev. B 72, 125119 (2005)
    """
    if nbands is None:
        nbands = len(psit_nG)

    if nbands == 0:
        return np.zeros((0, 0), complex)

    grid_index = np.indices(self.n_c).T + self.beg_c
    phase_G = np.exp(-2j * pi * np.dot(grid_index, G_c / self.N_c).T)
    bra_nG = (phase_G * psit_nG[:nbands].conj()).reshape((nbands, -1))
    ket_nG = psit_nG1[:nbands].reshape((nbands, -1))
    return np.inner(bra_nG, ket_nG) * self.dv

692 

def find_center(self, a_R):
    """Calculate center of positive function.

    Only orthogonal cells are supported.  Along each axis of length L
    the centre is taken from the phase of the integral of
    ``a_R * exp(2 pi i r / L)``, which gives a position in [0, L).
    """
    assert self.orthogonal
    r_vR = self.get_grid_point_coordinates()
    a_R = a_R.astype(complex)

    def periodic_center(L, r_R):
        z = self.integrate(a_R, np.exp(2j * pi / L * r_R))
        return np.angle(z) / (2 * pi) * L % L

    return np.array([periodic_center(L, r_R)
                     for L, r_R in zip(self.cell_cv.diagonal(), r_vR)])

703 

def bytecount(self, dtype=float):
    """Return the number of bytes of a local grid array of ``dtype``."""
    itemsize = np.array(1, dtype).itemsize
    npoints = int(np.prod(self.n_c))
    return npoints * itemsize

707 

def get_grid_point_coordinates(self, dtype=float, global_array=False):
    """Return Cartesian coordinates of the grid points, shape (3,) + n_c.

    With ``global_array=True`` the coordinates of the whole grid are
    collected and broadcast to every rank.
    """
    index_Gc = np.indices(self.n_c, dtype).T + self.beg_c
    r_vG = np.dot(index_Gc, self.h_cv).T.copy()
    if not global_array:
        return r_vG
    return self.collect(r_vG, broadcast=True)  # XXX waste!

716 

def get_grid_point_distance_vectors(self, r_v, mic=True, dtype=float):
    """Return vectors from position r_v to every grid point of the domain.

    The result has shape (3,) + n_c (Cartesian components first).

    mic: if true adopts the minimum image convention
         (procedure by W. Smith in 'The Minimum image convention in
         Non-Cubic MD cells', March 29, 1989) along periodic axes.
    """
    s_Gc = (np.indices(self.n_c, dtype).T + self.beg_c) / self.N_c
    r_c = np.linalg.solve(self.cell_cv.T, r_v)
    # Wrap along periodic axes twice: one `% 1.0` can give exactly 1.0
    # for tiny negative inputs (-1.5625e-25 % 1.0 == 1.0), and the
    # second pass maps that back to 0.0.
    for _ in range(2):
        r_c = np.where(self.pbc_c, r_c % 1.0, r_c)
    s_Gc -= r_c

    if mic:
        s_Gc -= self.pbc_c * (2 * s_Gc).astype(int)
        # sanity check
        assert (s_Gc * self.pbc_c >= -0.5).all()
        assert (s_Gc * self.pbc_c <= 0.5).all()

    return np.dot(s_Gc, self.cell_cv).T.copy()

739 

def interpolate_grid_points(self, spos_nc, vt_g):
    """Return interpolated values.

    Interpolates the global array ``vt_g`` (zero-padded along
    non-periodic axes first) at the scaled positions ``spos_nc``, one
    row per point, using cubic splines with periodic wrapping.

    This doesn't work in parallel, since it would require
    communication between neighbouring grids.
    """
    assert self.comm.size == 1

    padded_g = self.zero_pad(vt_g)
    coords_cn = (spos_nc * self.N_c).T
    return map_coordinates(padded_g, coords_cn, order=3, mode='wrap')

756 

def is_my_grid_point(self, R_c: Sequence[int]) -> bool:
    """Return True if global grid index R_c lies in this domain, i.e.
    beg_c <= R_c < end_c along every axis."""
    inside_c = np.logical_and(np.less_equal(self.beg_c, R_c),
                              np.less(R_c, self.end_c))
    return inside_c.all()