Coverage for gpaw/core/uniform_grid.py: 81%
464 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-09 00:21 +0000
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-09 00:21 +0000
1from __future__ import annotations
3from functools import cached_property
4from math import pi
5from typing import Sequence, Literal, TYPE_CHECKING
6import numpy as np
8import gpaw.fftw as fftw
9from gpaw.core.arrays import DistributedArrays
10from gpaw.core.atom_centered_functions import UGAtomCenteredFunctions
11from gpaw.core.domain import Domain
12from gpaw.gpu import as_np, cupy_is_fake
13from gpaw.grid_descriptor import GridDescriptor
14from gpaw.mpi import MPIComm, serial_comm
15from gpaw.new import zips
16from gpaw.typing import (Array1D, Array2D, Array3D, Array4D, ArrayLike1D,
17 ArrayLike2D, Vector)
18from gpaw.new.c import add_to_density, add_to_density_gpu, symmetrize_ft
19from gpaw.fd_operators import Gradient
21if TYPE_CHECKING:
22 import plotly.graph_objects as go
class UGDesc(Domain['UGArray']):
    def __init__(self,
                 *,
                 cell: ArrayLike1D | ArrayLike2D,  # bohr
                 size: ArrayLike1D,
                 pbc=(True, True, True),
                 zerobc=(False, False, False),
                 kpt: Vector | None = None,  # in units of reciprocal cell
                 comm: MPIComm = serial_comm,
                 decomp: Sequence[Sequence[int]] | None = None,
                 dtype=None):
        """Description of 3D uniform grid.

        Parameters
        ----------
        cell:
            Unit cell given as three floats (orthorhombic grid), six floats
            (three lengths and the angles in degrees) or a 3x3 matrix
            (units: bohr).
        size:
            Number of grid points along axes.
        pbc:
            Periodic boundary conditions flag(s).
        zerobc:
            Zero-boundary conditions flag(s).  Skip first grid-point
            (assumed to be zero).  A single bool/int is used for all
            three axes.
        comm:
            Communicator for domain-decomposition.
        kpt:
            K-point for Bloch-boundary conditions specified in units of the
            reciprocal cell.
        decomp:
            Decomposition of the domain: for each axis, the list of
            grid-point indices where the parts begin, followed by the final
            end index.  Computed from ``comm`` when not given.
        dtype:
            Data-type (float or complex).

        Raises
        ------
        ValueError
            If an axis has both periodic and zero-boundary conditions.
        """
        self.size_c = np.array(size, int)
        if isinstance(zerobc, int):  # bool is a subclass of int
            zerobc = (zerobc,) * 3
        self.zerobc_c = np.array(zerobc, bool)

        if decomp is None:
            gd = GridDescriptor(size, pbc_c=~self.zerobc_c, comm=comm)
            decomp = gd.n_cp
        self.decomp_cp = [np.asarray(d) for d in decomp]

        # Number of domains along each axis and this rank's position:
        self.parsize_c = np.array([len(d_p) - 1 for d_p in self.decomp_cp])
        self.mypos_c = np.unravel_index(comm.rank, self.parsize_c)

        # Part of the grid owned by this rank (start inclusive, end
        # exclusive):
        self.start_c = np.array([d_p[p]
                                 for d_p, p
                                 in zips(self.decomp_cp, self.mypos_c)])
        self.end_c = np.array([d_p[p + 1]
                               for d_p, p
                               in zips(self.decomp_cp, self.mypos_c)])
        self.mysize_c = self.end_c - self.start_c

        Domain.__init__(self, cell, pbc, kpt, comm, dtype)
        self.myshape = tuple(self.mysize_c)

        # Volume per grid point:
        self.dv = self.volume / self.size_c.prod()

        self.itemsize = 8 if self.dtype == float else 16

        if (self.zerobc_c & self.pbc_c).any():
            raise ValueError('Bad boundary conditions')

    @property
    def size(self):
        """Size of uniform grid (a copy of ``size_c``)."""
        return self.size_c.copy()

    def global_shape(self) -> tuple[int, ...]:
        """Actual size of uniform grid.

        One point less than ``size`` along axes with zero-boundary
        conditions (the skipped first point is not stored).
        """
        return tuple(self.size_c - self.zerobc_c)

    def __repr__(self):
        return Domain.__repr__(self).replace(
            'Domain(',
            f'UGDesc(size={self.size_c.tolist()}, ')

    def _short_string(self, global_shape):
        return f'uniform wave function grid shape: {global_shape}'

    @cached_property
    def phase_factor_cd(self):
        """Phase factors for Bloch-boundary conditions.

        Element ``[c, d]`` is the phase factor for the neighbouring domain
        below (d=0) or above (d=1) along axis c.  It differs from 1 only
        when that neighbour lies across the cell boundary.
        """
        delta_d = np.array([-1, 1])
        disp_cd = np.empty((3, 2))
        for pos, size, disp_d in zips(self.mypos_c, self.parsize_c, disp_cd):
            # Floor-division is -1/+1 when stepping out of the range
            # 0..size-1, i.e. into a neighbouring periodic image:
            disp_d[:] = -((pos + delta_d) // size)
        return np.exp(2j * np.pi *
                      disp_cd *
                      self.kpt_c[:, np.newaxis])

    def new(self,
            *,
            kpt=None,
            dtype=None,
            comm: MPIComm | Literal['inherit'] | None = 'inherit',
            size=None,
            pbc=None,
            zerobc=None,
            decomp=None) -> UGDesc:
        """Create new uniform grid description.

        Arguments left as None are copied from this description.  The
        domain decomposition is reused only when the communicator is
        inherited and the size and boundary conditions are unchanged.
        ``comm=None`` gives a serial grid.
        """
        inherit_comm = isinstance(comm, str) and comm == 'inherit'
        reuse_decomp = (decomp is None and inherit_comm and
                        size is None and pbc is None and zerobc is None)
        if reuse_decomp:
            decomp = self.decomp_cp
        comm = self.comm if inherit_comm else comm
        return UGDesc(cell=self.cell_cv,
                      size=self.size_c if size is None else size,
                      pbc=self.pbc_c if pbc is None else pbc,
                      zerobc=self.zerobc_c if zerobc is None else zerobc,
                      kpt=(self.kpt_c if self.kpt_c.any() else None)
                      if kpt is None else kpt,
                      comm=comm or serial_comm,
                      decomp=decomp,
                      dtype=self.dtype if dtype is None else dtype)

    def empty(self,
              dims: int | tuple[int, ...] = (),
              comm: MPIComm = serial_comm,
              xp=np) -> UGArray:
        """Create new (uninitialized) UGArray object.

        Parameters
        ----------
        dims:
            Extra dimensions.
        comm:
            Distribute dimensions along this communicator.
        xp:
            Array module for the storage (numpy or cupy).
        """
        return UGArray(self, dims, comm, xp=xp)

    def from_data(self, data: np.ndarray) -> UGArray:
        """Wrap ``data`` in a UGArray; leading axes become extra dims."""
        return UGArray(self, data.shape[:-3], data=data)

    def blocks(self, data: np.ndarray):
        """Yield views of blocks of data.

        One view per domain, in rank order (C-order over the three
        domain indices).
        """
        s0, s1, s2 = self.parsize_c
        d0_p, d1_p, d2_p = (d_p - d_p[0] for d_p in self.decomp_cp)
        for p0 in range(s0):
            b0, e0 = d0_p[p0:p0 + 2]
            for p1 in range(s1):
                b1, e1 = d1_p[p1:p1 + 2]
                for p2 in range(s2):
                    b2, e2 = d2_p[p2:p2 + 2]
                    yield data[..., b0:e0, b1:e1, b2:e2]

    def xyz(self) -> Array4D:
        """Create array of (x, y, z) coordinates for this rank's points."""
        indices_Rc = np.indices(self.mysize_c).transpose((1, 2, 3, 0))
        indices_Rc += self.start_c
        # Row c of the matrix is cell vector c divided by size_c[c]:
        return indices_Rc @ (self.cell_cv.T / self.size_c).T

    def atom_centered_functions(self,
                                functions,
                                positions,
                                *,
                                qspiral_v=None,
                                atomdist=None,
                                integrals=None,
                                cut=False,
                                xp=None):
        """Create UGAtomCenteredFunctions object.

        Spin spirals (``qspiral_v``) are not supported on uniform grids.
        """
        assert qspiral_v is None
        return UGAtomCenteredFunctions(functions,
                                       positions,
                                       self,
                                       atomdist=atomdist,
                                       integrals=integrals,
                                       cut=cut,
                                       xp=xp)

    def transformer(self, other: UGDesc, stencil_range=3, xp=np):
        """Create transformer from one grid to another.

        (for interpolation and restriction).  Returns a function
        ``transform(functions, out=None)`` that allocates ``out`` on
        ``other`` when not given.
        """
        from gpaw.transformers import Transformer

        apply = Transformer(self._gd, other._gd, nn=stencil_range, xp=xp).apply

        def transform(functions, out=None):
            if out is None:
                out = other.empty(functions.dims, functions.comm, xp=xp)
            for a_R, b_R in zips(functions._arrays(), out._arrays()):
                apply(a_R, b_R)
            return out

        return transform

    def eikr(self, kpt_c: Vector | None = None) -> Array3D:
        """Plane wave exp(ik.r) on this rank's part of the grid.

        Parameters
        ----------
        kpt_c:
            k-point in units of the reciprocal cell.  Defaults to the
            UGDesc object's own k-point.
        """
        if kpt_c is None:
            kpt_c = self.kpt_c
        index_Rc = np.indices(self.mysize_c).T + self.start_c
        return np.exp(2j * pi * (index_Rc @ (kpt_c / self.size_c))).T

    @property
    def _gd(self):
        """Equivalent old-style GridDescriptor."""
        # Make sure gd can be pickled (in serial):
        comm = self.comm if self.comm.size > 1 else serial_comm

        return GridDescriptor(self.size_c,
                              cell_cv=self.cell_cv,
                              pbc_c=~self.zerobc_c,
                              comm=comm,
                              parsize_c=[len(d_p) - 1
                                         for d_p in self.decomp_cp])

    @classmethod
    def from_cell_and_grid_spacing(cls,
                                   cell: ArrayLike1D | ArrayLike2D,
                                   grid_spacing: float,
                                   pbc=(True, True, True),
                                   kpt: Vector | None = None,
                                   comm: MPIComm = serial_comm,
                                   dtype=None) -> UGDesc:
        """Create UGDesc from grid-spacing."""
        domain: Domain = Domain(cell, pbc, kpt, comm, dtype)
        return domain.uniform_grid_with_grid_spacing(grid_spacing)

    def fft_plans(self,
                  flags: int = fftw.MEASURE,
                  xp=np,
                  dtype=None) -> fftw.FFTPlans:
        """Create FFTW-plans.

        Only rank 0 of the domain communicator gets plans for the full
        grid; the other ranks get plans for an empty grid.
        """
        if dtype is None:
            dtype = self.dtype
        if self.comm.rank == 0:
            return fftw.create_plans(self.size_c, dtype, flags, xp)
        return fftw.create_plans([0, 0, 0], dtype)

    def ranks_from_fractional_positions(self,
                                        relpos_ac: Array2D) -> Array1D:
        """Rank of the domain containing each fractional position.

        Raises ValueError for positions outside [0, 1) along any axis.
        """
        rank_ac = np.floor(relpos_ac * self.parsize_c).astype(int)
        if (rank_ac < 0).any() or (rank_ac >= self.parsize_c).any():
            raise ValueError('Positions outside cell!')
        return np.ravel_multi_index(rank_ac.T, self.parsize_c)  # type: ignore

    def ekin_max(self) -> float:
        """Maximum value of ekin so that all 0.5 * G^2 < ekin.

        In 1D, this will be 0.5*(pi/h)^2 where h is the grid-spacing.
        """
        # (pi / |a_c|)^2, i.e. half a reciprocal lattice vector squared for
        # orthorhombic cells.  NOTE(review): uses cell-vector lengths, so
        # for non-orthogonal cells this is an approximation.
        b2_c = np.pi**2 / (self.cell_cv**2).sum(1)
        return 0.5 * (self.size_c**2 * b2_c).min()
class UGArray(DistributedArrays[UGDesc]):
    def __init__(self,
                 grid: UGDesc,
                 dims: int | tuple[int, ...] = (),
                 comm: MPIComm = serial_comm,
                 data: np.ndarray | None = None,
                 xp=None):
        """Object for storing function(s) on a uniform grid.

        Parameters
        ----------
        grid:
            Description of uniform grid.
        dims:
            Extra dimensions.
        comm:
            Distribute dimensions along this communicator.
        data:
            Data array for storage.
        xp:
            Array module (numpy or cupy).
        """
        DistributedArrays.__init__(self, dims, grid.myshape,
                                   comm, grid.comm, data, grid.dv,
                                   grid.dtype, xp)
        self.desc = grid

    def __repr__(self):
        txt = f'UGArray(grid={self.desc}, dims={self.dims}'
        if self.comm.size > 1:
            txt += f', comm={self.comm.rank}/{self.comm.size}'
        if self.xp is not np:
            txt += ', xp=cp'
        return txt + ')'

    def new(self, data=None, zeroed=False, dims=None):
        """Create new UGArray object of same kind.

        Parameters
        ----------
        data:
            Array to use for storage.
        zeroed:
            If True, set data to zero.
        dims:
            Extra dimensions (bands, spin, etc.), required if
            data does not fit the full array.
        """
        if dims:
            assert data is not None
        else:
            dims = self.dims
        if data is None:
            data = self.xp.empty_like(self.data)

        f_xR = UGArray(self.desc, dims, self.comm, data)
        if zeroed:
            f_xR.data[:] = 0.0
        return f_xR

    def __getitem__(self, index):
        data = self.data[index]
        return UGArray(data=data,
                       dims=data.shape[:-3],
                       grid=self.desc)

    def __imul__(self,
                 other: float | np.ndarray | UGArray
                 ) -> UGArray:
        """Multiply in place by a scalar, an array or another UGArray.

        Arrays must have the same three grid dimensions as ``self``;
        leading dimensions broadcast.
        """
        # Accept any numeric scalar (int, float, complex, numpy scalars),
        # not only Python floats:
        if isinstance(other, (int, float, complex, np.number)):
            self.data *= other
            return self
        if isinstance(other, UGArray):
            other = other.data
        assert other.shape[-3:] == self.data.shape[-3:]
        self.data *= other
        return self

    def __mul__(self,
                other: float | np.ndarray | UGArray
                ) -> UGArray:
        result = self.new(data=self.data.copy())
        result *= other
        return result

    def _arrays(self):
        """View of the data with all extra dimensions flattened into one."""
        return self.data.reshape((-1,) + self.data.shape[-3:])

    def xy(self, *axes: int | None) -> tuple[Array1D, Array1D]:
        """Extract x, y values along line.

        Useful for plotting::

          x, y = grid.xy(0, ..., 0)
          plt.plot(x, y)

        The ``...`` marks the grid axis to extract along.
        """
        assert len(axes) == 3 + len(self.dims)
        index = tuple([slice(0, None) if axis is None else axis
                       for axis in axes])
        y = self.data[index]  # type: ignore
        c = axes[-3:].index(...)
        grid = self.desc
        dx = (grid.cell_cv[c]**2).sum()**0.5 / grid.size_c[c]
        x = np.arange(grid.start_c[c], grid.end_c[c]) * dx
        return x, as_np(y)

    def to_complex(self) -> UGArray:
        """Return a copy with dtype=complex.

        The copy keeps the extra dimensions, their communicator and the
        array module (numpy/cupy) of ``self``.
        """
        c = self.desc.new(dtype=complex).empty(self.dims, self.comm,
                                               xp=self.xp)
        c.data[:] = self.data
        return c

    def scatter_from(self, data: np.ndarray | UGArray | None = None) -> None:
        """Scatter data from rank-0 to all ranks.

        Only rank 0 needs to pass ``data``; other ranks receive their block.
        """
        if isinstance(data, UGArray):
            data = data.data

        comm = self.desc.comm
        if comm.size == 1:
            self.data[:] = data
            return

        if comm.rank != 0:
            comm.receive(self.data, 0, 42)
            return

        requests = []
        assert isinstance(data, self.xp.ndarray)
        for rank, block in enumerate(self.desc.blocks(data)):
            if rank != 0:
                block = block.copy()
                request = comm.send(block, rank, 42, False)
                # Remember to store a reference to the
                # send buffer (block) so that it isn't
                # deallocated:
                requests.append((request, block))
            else:
                self.data[:] = block

        for request, _ in requests:
            comm.wait(request)

    def gather(self, out=None, broadcast=False):
        """Gather data from all ranks to rank-0.

        Returns the full array on rank 0 (on all ranks if ``broadcast``)
        and None on the other ranks.  In serial, ``self`` is returned.
        """
        assert out is None
        comm = self.desc.comm
        if comm.size == 1:
            return self

        if broadcast or comm.rank == 0:
            grid = self.desc.new(comm=serial_comm)
            out = grid.empty(self.dims, comm=self.comm, xp=self.xp)

        if comm.rank != 0:
            # There can be several sends before the corresponding receives
            # are posted, so use synchronous send here
            comm.ssend(self.data, 0, 301)
            if broadcast:
                comm.broadcast(out.data, 0)
                return out
            return

        # Put the subdomains from the slaves into the big array
        # for the whole domain:
        for rank, block in enumerate(self.desc.blocks(out.data)):
            if rank != 0:
                buf = self.xp.empty_like(block)
                comm.receive(buf, rank, 301)
                block[:] = buf
            else:
                block[:] = self.data

        if broadcast:
            comm.broadcast(out.data, 0)

        return out

    def fft(self, plan=None, pw=None, out=None):
        r"""Do FFT.

        Returns a PWArray with the plane wave coefficients::

            C(G) = 1/V * integral dr exp(-iG.r) f(r)

        where V is the cell volume.

        Parameters
        ----------
        plan:
            Plan for FFT.
        pw:
            Target PW description.
        out:
            Target PWArray object.

        Raises
        ------
        TypeError
            If the dtypes of the grid and the PW description differ.
        """
        assert self.dims == ()
        if out is None:
            assert pw is not None
            out = pw.empty(xp=self.xp)
        if pw is None:
            pw = out.desc
        if pw.dtype != self.desc.dtype:
            raise TypeError(
                f'Type mismatch: {self.desc.dtype} -> {pw.dtype}')
        source = self
        if self.desc.comm.size > 1:
            source = source.gather()
        if self.desc.comm.rank == 0:
            plan = plan or self.desc.fft_plans(xp=self.xp)
            coefs = plan.fft_sphere(source.data, pw)
        else:
            coefs = None

        out.scatter_from(coefs)

        return out

    def norm2(self):
        """Calculate integral over cell of absolute value squared.

        Returns integral |a(r)|^2 dr for each function (shape ``mydims``),
        summed over the domain communicator.
        """
        norm_x = []
        arrays_xR = self._arrays()
        for a_R in arrays_xR:
            norm_x.append(self.xp.vdot(a_R, a_R).real * self.desc.dv)
        result = self.xp.array(norm_x).reshape(self.mydims)
        self.desc.comm.sum(result)
        return result

    def integrate(self, other=None, skip_sum=False):
        """Integral of self or self times cc(other).

        With ``skip_sum=True`` only the local (this rank's) part is
        integrated.  0-d results are returned as Python scalars.
        """
        if other is not None:
            assert self.desc.dtype == other.desc.dtype
            a_xR = self._arrays()
            b_yR = other._arrays()
            a_xR = a_xR.reshape((len(a_xR), -1))
            b_yR = b_yR.reshape((len(b_yR), -1))
            result = (a_xR @ b_yR.T.conj()).reshape(self.dims + other.dims)
        else:
            # Make sure we have an array and not a scalar!
            result = self.xp.asarray(self.data.sum(axis=(-3, -2, -1)))

        if not skip_sum:
            self.desc.comm.sum(result)
        if result.ndim == 0:
            result = result.item()  # convert to scalar
        return result * self.desc.dv

    def to_pbc_grid(self):
        """Convert to UniformGrid with ``pbc=(True, True, True)``.

        The skipped (zero) boundary points are filled in with zeros.
        Returns ``self`` when there are no zero-boundary conditions.
        """
        if not self.desc.zerobc_c.any():
            return self
        grid = self.desc.new(zerobc=False)
        # Keep distribution of extra dims and the array module of self:
        new = grid.empty(self.dims, self.comm, xp=self.xp)
        new.data[:] = 0.0
        # NOTE(review): this places the data at the end of the local
        # array, which assumes a serial domain - confirm for parallel use.
        *_, i, j, k = self.data.shape
        new.data[..., -i:, -j:, -k:] = self.data
        return new

    def multiply_by_eikr(self, kpt_c: Vector | None = None) -> None:
        """Multiply by `exp(ik.r)` (no-op for the Gamma point)."""
        if kpt_c is None:
            kpt_c = self.desc.kpt_c
        else:
            kpt_c = np.asarray(kpt_c)
        if kpt_c.any():
            self.data *= self.desc.eikr(kpt_c)

    def interpolate(self,
                    plan1: fftw.FFTPlans | None = None,
                    plan2: fftw.FFTPlans | None = None,
                    grid: UGDesc | None = None,
                    out: UGArray | None = None) -> UGArray:
        """Interpolate to finer grid.

        Parameters
        ----------
        plan1:
            Plan for FFT (coarse grid).
        plan2:
            Plan for inverse FFT (fine grid).
        grid:
            Target grid.
        out:
            Target UGArray object.

        Raises
        ------
        ValueError
            If neither ``grid`` nor ``out`` is given, if either grid has
            zero-boundary conditions, or if the target grid is not finer
            along all axes.
        """
        if out is None:
            if grid is None:
                raise ValueError('Please specify "grid" or "out".')
            out = grid.empty(self.dims, xp=self.xp)

        if out.desc.zerobc_c.any() or self.desc.zerobc_c.any():
            raise ValueError('Grids must have zerobc=False!')

        if self.desc.comm.size > 1:
            # Do the work in serial on rank 0 and distribute the result:
            full = self.gather()
            if full is not None:
                result = full.interpolate(plan1, plan2,
                                          out.desc.new(comm=None))
                out.scatter_from(result.data)
            else:
                out.scatter_from()
            return out

        size1_c = self.desc.size_c
        size2_c = out.desc.size_c
        if (size2_c <= size1_c).any():
            raise ValueError('Too few points in target grid!')

        plan1 = plan1 or self.desc.fft_plans(xp=self.xp)
        plan2 = plan2 or out.desc.fft_plans(xp=self.xp)

        if self.dims:
            for src, dst in zips(self.flat(), out.flat()):
                src.interpolate(plan1, plan2, grid, dst)
            return out

        # Remove the Bloch phase so that the function is periodic:
        plan1.tmp_R[:] = self.data
        kpt_c = self.desc.kpt_c
        if kpt_c.any():
            plan1.tmp_R *= self.desc.eikr(-kpt_c)
        plan1.fft()

        a_Q = plan1.tmp_Q
        b_Q = plan2.tmp_Q

        e0, e1, e2 = 1 - size1_c % 2  # even or odd size
        a0, a1, a2 = size2_c // 2 - size1_c // 2
        b0, b1, b2 = size1_c + (a0, a1, a2)

        if self.desc.dtype == float:
            # Real-to-complex FFT: last axis holds only half the
            # coefficients and is not shifted.
            b2 = (b2 - a2) // 2 + 1
            a2 = 0
            axes = [0, 1]
        else:
            axes = [0, 1, 2]

        # Zero-pad the centred coefficients into the larger box:
        b_Q[:] = 0.0
        b_Q[a0:b0, a1:b1, a2:b2] = self.xp.fft.fftshift(a_Q, axes=axes)

        # For even sizes, split the Nyquist component equally between
        # +G and -G:
        if e0:
            b_Q[a0, a1:b1, a2:b2] *= 0.5
            b_Q[b0, a1:b1, a2:b2] = b_Q[a0, a1:b1, a2:b2]
            b0 += 1
        if e1:
            b_Q[a0:b0, a1, a2:b2] *= 0.5
            b_Q[a0:b0, b1, a2:b2] = b_Q[a0:b0, a1, a2:b2]
            b1 += 1
        if self.desc.dtype == complex:
            if e2:
                b_Q[a0:b0, a1:b1, a2] *= 0.5
                b_Q[a0:b0, a1:b1, b2] = b_Q[a0:b0, a1:b1, a2]
        else:
            if e2:
                b_Q[a0:b0, a1:b1, b2 - 1] *= 0.5

        b_Q[:] = self.xp.fft.ifftshift(b_Q, axes=axes)
        plan2.ifft()
        out.data[:] = plan2.tmp_R
        out.data *= (1.0 / self.data.size)
        out.multiply_by_eikr()
        return out

    def fft_restrict(self,
                     plan1: fftw.FFTPlans | None = None,
                     plan2: fftw.FFTPlans | None = None,
                     grid: UGDesc | None = None,
                     out: UGArray | None = None) -> UGArray:
        """Restrict to coarser grid.

        Parameters
        ----------
        plan1:
            Plan for FFT (fine grid).
        plan2:
            Plan for inverse FFT (coarse grid).
        grid:
            Target grid.
        out:
            Target UGArray object.

        Raises
        ------
        ValueError
            If neither ``grid`` nor ``out`` is given, if either grid has
            zero-boundary conditions, or if the target grid has more
            points than the source along some axis.
        """
        if out is None:
            if grid is None:
                raise ValueError('Please specify "grid" or "out".')
            out = grid.empty(self.dims, xp=self.xp)

        if out.desc.zerobc_c.any() or self.desc.zerobc_c.any():
            raise ValueError('Grids must have zerobc=False!')

        if self.desc.comm.size > 1:
            # Do the work in serial on rank 0 and distribute the result:
            full = self.gather()
            if full is not None:
                result = full.fft_restrict(plan1, plan2,
                                           out.desc.new(comm=None))
                out.scatter_from(result.data)
            else:
                out.scatter_from()
            return out

        size1_c = self.desc.size_c
        size2_c = out.desc.size_c
        if (size2_c > size1_c).any():
            raise ValueError('Too many points in target grid!')

        # Plans must use the same array module as the data (GPU support):
        plan1 = plan1 or self.desc.fft_plans(xp=self.xp)
        plan2 = plan2 or out.desc.fft_plans(xp=self.xp)

        if self.dims:
            for src, dst in zips(self.flat(), out.flat()):
                src.fft_restrict(plan1, plan2, grid, dst)
            return out

        # NOTE(review): unlike interpolate(), no exp(-ik.r) phase is
        # removed here, so this assumes a periodic function (Gamma point).
        plan1.tmp_R[:] = self.data
        a_Q = plan2.tmp_Q
        b_Q = plan1.tmp_Q

        e0, e1, e2 = 1 - size2_c % 2  # even or odd size
        a0, a1, a2 = size1_c // 2 - size2_c // 2
        b0, b1, b2 = size2_c // 2 + size1_c // 2 + 1

        if self.desc.dtype == float:
            # Real-to-complex FFT: last axis holds only half the
            # coefficients and is not shifted.
            b2 = size2_c[2] // 2 + 1
            a2 = 0
            axes = [0, 1]
        else:
            axes = [0, 1, 2]

        plan1.fft()
        b_Q[:] = self.xp.fft.fftshift(b_Q, axes=axes)

        # For even target sizes, average the +G and -G Nyquist components
        # into a single coefficient:
        if e0:
            b_Q[a0, a1:b1, a2:b2] += b_Q[b0 - 1, a1:b1, a2:b2]
            b_Q[a0, a1:b1, a2:b2] *= 0.5
            b0 -= 1
        if e1:
            b_Q[a0:b0, a1, a2:b2] += b_Q[a0:b0, b1 - 1, a2:b2]
            b_Q[a0:b0, a1, a2:b2] *= 0.5
            b1 -= 1
        if self.desc.dtype == complex and e2:
            b_Q[a0:b0, a1:b1, a2] += b_Q[a0:b0, a1:b1, b2 - 1]
            b_Q[a0:b0, a1:b1, a2] *= 0.5
            b2 -= 1

        # Cut out the central coefficients:
        a_Q[:] = b_Q[a0:b0, a1:b1, a2:b2]
        a_Q[:] = self.xp.fft.ifftshift(a_Q, axes=axes)
        plan2.ifft()
        out.data[:] = plan2.tmp_R
        out.data *= (1.0 / self.data.size)
        return out

    def abs_square(self,
                   weights: Array1D,
                   out: UGArray | None = None) -> None:
        """Add weighted absolute square of data to output array."""
        assert out is not None

        if self.xp is np:
            for f, psit_R in zips(weights, self.data):
                add_to_density(f, psit_R, out.data)
        elif cupy_is_fake:
            for f, psit_R in zips(weights, self.data):
                add_to_density(f, psit_R._data, out.data._data)  # type: ignore
        else:
            add_to_density_gpu(self.xp.asarray(weights), self.data, out.data)

    def symmetrize(self, rotation_scc, translation_sc):
        """Make data symmetric.

        The data is averaged over the given symmetry operations (rotation
        matrices and fractional translations).  The work is done on rank 0
        (on the CPU) and the result is scattered back.
        """
        if len(rotation_scc) == 1:
            return

        a_xR = self.gather()

        if a_xR is None:
            b_xR = None
        else:
            if self.xp is not np:
                a_xR = a_xR.to_xp(np)
            b_xR = a_xR.new()
            t_sc = (translation_sc * self.desc.size_c).round().astype(int)
            offset_c = np.array(self.desc.zerobc_c, dtype=int)
            for a_R, b_R in zips(a_xR._arrays(), b_xR._arrays()):
                b_R[:] = 0.0
                for r_cc, t_c in zips(rotation_scc, t_sc):
                    symmetrize_ft(a_R, b_R, r_cc, t_c, offset_c)
            if self.xp is not np:
                b_xR = b_xR.to_xp(self.xp)
        self.scatter_from(b_xR)

        self.data *= 1.0 / len(rotation_scc)

    def randomize(self, seed: int | None = None) -> None:
        """Insert random numbers between -0.5 and 0.5 into data.

        Real and imaginary parts are randomized independently for complex
        data.  The default seed differs between ranks.
        """
        if seed is None:
            seed = self.comm.rank + self.desc.comm.rank * self.comm.size
        rng = self.xp.random.default_rng(seed)
        a = self.data.view(float)
        rng.random(a.shape, out=a)
        a -= 0.5

    def moment(self):
        """Calculate moment of data.

        Returns the Cartesian vector integral r f(r) dr, summed over the
        domain communicator.
        """
        assert self.dims == ()
        ug = self.desc

        index_cr = [np.arange(ug.start_c[c], ug.end_c[c], dtype=float)
                    for c in range(3)]
        for index_r, size in zip(index_cr, ug.size_c):
            if index_r[0] == 0:
                # We have periodic bc's, so index 0 is the same as index
                # size (= last + 1).  Include both points with 0.5 weight:
                index_r[0] = 0.5 * size

        # Project the data onto each axis:
        rho_ijk = self.data
        rho_ij = rho_ijk.sum(axis=2)
        rho_ik = rho_ijk.sum(axis=1)
        rho_cr = [rho_ij.sum(axis=1), rho_ij.sum(axis=0), rho_ik.sum(axis=0)]
        if self.xp is not np:
            rho_cr = [rho_r.get() for rho_r in rho_cr]

        d_c = [index_r @ rho_r for index_r, rho_r in zips(index_cr, rho_cr)]
        d_v = (d_c / ug.size_c) @ ug.cell_cv * self.dv
        self.desc.comm.sum(d_v)
        return d_v

    def scaled(self, cell: float, values: float = 1.0) -> UGArray:
        """Create new scaled UGArray object.

        Unit cell axes are multiplied by `cell` and data by `values`.
        Size, boundary conditions, k-point, dtype and domain decomposition
        are kept.
        """
        grid = self.desc
        grid = UGDesc(cell=grid.cell_cv * cell,
                      size=grid.size_c,
                      pbc=grid.pbc_c,
                      zerobc=grid.zerobc_c,
                      kpt=(grid.kpt_c if grid.kpt_c.any() else None),
                      dtype=grid.dtype,
                      comm=grid.comm,
                      # Keep the decomposition so the data shape matches:
                      decomp=grid.decomp_cp)
        return UGArray(grid, self.dims, self.comm, self.data * values)

    def add_ked(self,
                occ_n: Array1D,
                taut_R: UGArray) -> None:
        """Add kinetic-energy-density contribution to ``taut_R``.

        For each function and weight f, adds 0.5 * f * |d psi / dx_v|^2
        summed over the three Cartesian directions v, using finite-difference
        gradients.
        """
        grad_v = [
            Gradient(self.desc._gd, v, n=3, dtype=self.desc.dtype)
            for v in range(3)]
        tmp_R = self.desc.empty()
        for f, psit_R in zips(occ_n, self):
            for grad in grad_v:
                grad(psit_R, tmp_R)
                add_to_density(0.5 * f, tmp_R.data, taut_R.data)

    def redist(self,
               domain: UGDesc,
               comm1: MPIComm, comm2: MPIComm) -> UGArray:
        a = super().redist(domain, comm1, comm2)
        assert isinstance(a, UGArray)
        return a

    def isosurface(self, show=True, **kwargs) -> go.Isosurface:
        """Create a plotly isosurface of the data (absolute value if complex).

        Extra keyword arguments override the default plotly settings.
        """
        import plotly.graph_objects as go
        # plotly needs host (numpy) arrays:
        values = as_np(self.data)
        assert values.ndim == 3
        if values.dtype == complex:
            values = abs(values)
        x, y, z = (c.T.flatten() for c in self.desc.xyz().T)
        vmin = values.min()
        vmax = values.max()
        kwargs = {
            'isomin': vmin + (vmax - vmin) * 0.1,
            'isomax': vmax - (vmax - vmin) * 0.1,
            'caps': dict(x_show=False,
                         y_show=False,
                         z_show=False),
            **kwargs}
        surf = go.Isosurface(x=x, y=y, z=z, value=values.flatten(),
                             **kwargs)
        if show:
            go.Figure(data=[surf]).show()
        return surf