Coverage for gpaw/core/pwacf.py: 94%
375 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-20 00:19 +0000
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-20 00:19 +0000
1from __future__ import annotations
3from math import pi
4from typing import TYPE_CHECKING
6import numpy as np
8from gpaw.core.atom_arrays import AtomArraysLayout, AtomDistribution
9from gpaw.core.atom_centered_functions import AtomCenteredFunctions
10from gpaw.core.matrix import Matrix
11from gpaw.core.uniform_grid import UGArray
12from gpaw.ffbt import rescaled_fourier_bessel_transform
13from gpaw.gpu import cupy_is_fake, gpu_gemm
14# from gpaw.lfc import BaseLFC
15from gpaw.new import prod, trace, tracectx
16from gpaw.new.c import pwlfc_expand, pwlfc_expand_gpu
17from gpaw.spherical_harmonics import Y, nablarlYL
18from gpaw.spline import Spline
19from gpaw.typing import ArrayLike1D
20from gpaw.utilities import as_complex_dtype, as_real_dtype
21from gpaw.utilities.blas import mmm
23if TYPE_CHECKING:
24 from gpaw.core.plane_waves import PWDesc, PWArray
class PWAtomCenteredFunctions(AtomCenteredFunctions):
    """Atom-centered functions handled through a plane-wave ``PWLFC``.

    The heavy lifting is delegated to a ``PWLFC`` object that is created
    on first use (see ``_lazy_init``) and thrown away again when the
    plane-wave descriptor is replaced (see ``change_cell``).
    """

    def __init__(self,
                 functions,
                 relpos,
                 pw,
                 atomdist=None,
                 integrals=None,
                 xp=None):
        """Store the plane-wave descriptor and array module.

        functions:
            Per-atom lists of radial functions, forwarded to the base
            class and later to ``PWLFC``.
        relpos:
            Atomic positions, forwarded to the base class.
        pw:
            Plane-wave descriptor.
        atomdist:
            Atom distribution; when None one is built in ``_lazy_init``.
        integrals:
            Forwarded unchanged to ``PWLFC``.
        xp:
            Array module; ``numpy`` is used when nothing is given.
        """
        AtomCenteredFunctions.__init__(self, functions, relpos, atomdist)
        self.pw = pw
        self.xp = np if not xp else xp
        self.integrals = integrals

    def new(self, pw, atomdist):
        """Return a copy bound to another descriptor and distribution."""
        # NOTE(review): ``integrals`` is not forwarded here, so the new
        # object is created with integrals=None - confirm this is intended.
        return PWAtomCenteredFunctions(functions=self.functions,
                                       relpos=self.relpos_ac,
                                       pw=pw,
                                       atomdist=atomdist,
                                       xp=self.xp)

    def _lazy_init(self):
        """Build the ``PWLFC`` object and the array layout on first use."""
        if self._lfc is not None:
            # Already set up.
            return

        lfc = PWLFC(self.functions,
                    self.pw,
                    xp=self.xp,
                    integrals=self.integrals)
        self._lfc = lfc

        if self._atomdist is None:
            natoms = len(self.relpos_ac)
            self._atomdist = AtomDistribution.from_number_of_atoms(
                natoms, self.pw.comm)

        lfc.set_positions(self.relpos_ac, self._atomdist)

        # One entry per atom: total number of (l, m) channels, i.e. the
        # sum of 2l+1 over that atom's radial functions.
        sizes = [sum(2 * f.l + 1 for f in funcs)
                 for funcs in self.functions]
        self._layout = AtomArraysLayout(sizes,
                                        self._atomdist,
                                        self.pw.dtype,
                                        xp=self.xp)

    def __repr__(self):
        text = super().__repr__()
        # Only mention xp when it is not plain numpy.
        return text if self.xp is np else text[:-1] + ', xp=cp)'

    def to_uniform_grid(self,
                        out: UGArray,
                        scale: float = 1.0) -> UGArray:
        """Add the functions in reciprocal space and transform to ``out``.

        A zero-initialised plane-wave array is created, ``add_to`` is
        called on it with ``scale``, and its inverse FFT is written into
        (and returned as) ``out``.
        """
        work_G = self.pw.zeros(xp=out.xp)
        self.add_to(work_G, scale)
        return work_G.ifft(out=out)

    def change_cell(self, new_pw):
        """Switch to ``new_pw`` and force re-initialisation on next use."""
        self._lfc = None
        self.pw = new_pw

    def multiply(self,
                 C_nM: Matrix,
                 out_nG: PWArray) -> None:
        """Convert from LCAO expansion to PW expansion.

        For every block of G-vectors, ``out_nG`` is overwritten (beta=0)
        with ``C_nM`` times the transposed expanded functions, scaled by
        ``1 / pw.dv``.
        """
        self._lazy_init()
        lfc = self._lfc
        assert lfc is not None

        prefactor = 1.0 / self.pw.dv
        on_cpu = self.xp is np
        for G1, G2 in lfc.block():
            f_GI = lfc.expand(G1, G2, cc=False)
            target_nG = out_nG.data[:, G1:G2]
            if lfc.real:
                # Reinterpret the output slice with the dtype of f_GI so
                # that both GEMM operands agree.
                target_nG = target_nG.view(f_GI.dtype)
            if on_cpu:
                mmm(prefactor, C_nM.data, 'N', f_GI, 'T', 0.0, target_nG)
            else:
                gpu_gemm('N', 'T',
                         C_nM.data, f_GI, target_nG, prefactor, 0.0)
class PWLFC:  # (BaseLFC)
    """Reciprocal-space plane-wave localized function collection."""

    def __init__(self,
                 functions,
                 pw: PWDesc,
                 *,
                 xp,
                 integrals: ArrayLike1D | float | None = None,
                 blocksize: int | None = 5000):
        """Reciprocal-space plane-wave localized function collection.

        functions: list of list of spline objects
            Splines, one inner list per atom (stored as ``spline_aj``).
        pw: PWDesc
            Plane-wave descriptor object.
        xp:
            Array module used for the work arrays (numpy or a GPU
            equivalent).
        integrals: scalar, sequence or None
            A scalar is broadcast to all atoms, None means 0.0 for all
            atoms, anything else is converted with ``np.array``.  Used in
            ``initialize`` to rescale l=0 functions when non-zero.
        blocksize: int
            Block-size to use when looping over G-vectors.  Use None for
            doing all G-vectors in one big block.
        """
        self.xp = xp
        self.pw = pw
        self.spline_aj = functions

        self.dtype = pw.dtype
        self.real = np.issubdtype(pw.dtype, np.floating)

        self.initialized = False

        # These will be filled in later (by initialize()):
        self.Y_GL = np.zeros((0, 0))
        self.f_Gs: np.ndarray = np.zeros((0, 0))
        self.l_s: np.ndarray | None = None
        self.a_J: np.ndarray | None = None
        self.s_J: np.ndarray | None = None
        self.I_J: np.ndarray | None = None
        self.lmax = -1

        if blocksize is not None and pw.maxmysize <= blocksize:
            # No need to block G-vectors
            blocksize = None
        self.blocksize = blocksize

        # These are set later in set_positions():
        self.eikR_a = None
        self.my_atom_indices = None
        self.my_indices = None
        self.pos_av = None
        self.nI = None

        self.comm = pw.comm

        natoms = len(functions)
        if integrals is None:
            self.integral_a = np.zeros(natoms)
        elif np.ndim(integrals) == 0:
            # Any scalar (float, int, numpy scalar) is broadcast.  Testing
            # only for ``float`` would turn e.g. ``integrals=1`` into a
            # 0-d array that cannot be indexed per atom later on.
            self.integral_a = np.zeros(natoms) + integrals
        else:
            self.integral_a = np.array(integrals)

    @trace
    def initialize(self) -> None:
        """Initialize position-independent stuff.

        Fills ``f_Gs`` (transformed radial functions per unique spline),
        ``l_s`` (angular momentum per unique spline), the per-function
        index arrays ``a_J``/``s_J``/``I_J`` and the spherical harmonics
        ``Y_GL``.  Does nothing on repeated calls.
        """
        if self.initialized:
            return

        xp = self.xp

        # Give every distinct spline a running index.
        splines: dict[Spline, int] = {}
        for spline_j in self.spline_aj:
            for spline in spline_j:
                if spline not in splines:
                    splines[spline] = len(splines)
        nsplines = len(splines)

        nJ = sum(len(spline_j) for spline_j in self.spline_aj)

        self.f_Gs = xp.empty(self.pw.myshape + (nsplines,),
                             dtype=as_real_dtype(self.dtype))
        self.l_s = np.empty(nsplines, np.int32)
        self.a_J = np.empty(nJ, np.int32)
        self.s_J = np.empty(nJ, np.int32)
        self.I_J = np.empty(nJ, np.int32)

        # |G+k| from the kinetic energies; identical for every spline,
        # so it is computed once.
        G_G = (2 * self.pw.ekin_G)**0.5

        # Fourier transform radial functions:
        J = 0
        done: set[Spline] = set()
        I = 0
        for a, spline_j in enumerate(self.spline_aj):
            for spline in spline_j:
                s = splines[spline]  # get spline index
                if spline not in done:
                    f = rescaled_fourier_bessel_transform(spline)
                    self.f_Gs[:, s] = xp.asarray(f.map(G_G))
                    l = spline.get_angular_momentum_number()
                    self.l_s[s] = l
                    integral = self.integral_a[a]
                    if l == 0 and integral != 0.0:
                        # Rescale using the first local G-vector entry.
                        # NOTE(review): this presumably assumes that
                        # entry is G=0, and it uses the integral of the
                        # first atom that owns a shared spline - confirm.
                        x = integral / self.f_Gs[0, s] * (4 * pi)**0.5
                        self.f_Gs[:, s] *= x
                    done.add(spline)
                self.a_J[J] = a
                self.s_J[J] = s
                self.I_J[J] = I
                I += 2 * spline.get_angular_momentum_number() + 1
                J += 1

        self.lmax = max(self.l_s, default=-1)

        # Spherical harmonics:
        G_Gv = self.pw.G_plus_k_Gv
        self.Y_GL = xp.empty((len(G_Gv), (self.lmax + 1)**2),
                             dtype=as_real_dtype(self.dtype))
        for L in range((self.lmax + 1)**2):
            self.Y_GL[:, L] = xp.asarray(Y(L, *G_Gv.T))

        # Move the index arrays to the array module in use.
        self.l_s = xp.asarray(self.l_s)
        self.a_J = xp.asarray(self.a_J)
        self.s_J = xp.asarray(self.s_J)
        self.I_J = xp.asarray(self.I_J)

        self.initialized = True

    def get_function_count(self, a):
        """Return the number of functions (sum of 2l+1) on atom ``a``."""
        return sum(2 * spline.get_angular_momentum_number() + 1
                   for spline in self.spline_aj[a])

    @trace
    def set_positions(self, spos_ac, atomdist):
        """Set atomic positions and the distribution of atoms over ranks.

        spos_ac:
            Positions that are multiplied with ``pw.cell`` and, for
            complex dtypes, with ``pw.kpt_c`` (presumably scaled
            positions).
        atomdist:
            Object whose ``rank_a`` gives the owning rank of every atom.
        """
        self.initialize()

        xp = self.xp
        real_dtype = as_real_dtype(self.dtype)

        if self.real:
            self.eikR_a = xp.ones(len(spos_ac), dtype=real_dtype)
        else:
            self.eikR_a = xp.asarray(
                np.exp(2j * pi * (spos_ac @ self.pw.kpt_c)),
                dtype=as_complex_dtype(self.dtype))
        self.pos_av = xp.asarray(np.dot(spos_ac, self.pw.cell),
                                 dtype=real_dtype)

        self.pos_avT = xp.asarray(self.pos_av.T, real_dtype)
        self.G_plus_k_Gv = self.xp.asarray(self.pw.G_plus_k_Gv, real_dtype)

        rank_a = atomdist.rank_a

        # Global function ranges [I1, I2) are counted over all atoms, but
        # only atoms owned by this rank are recorded.
        self.my_atom_indices = []
        self.my_indices = []
        I1 = 0
        for a, rank in enumerate(rank_a):
            I2 = I1 + self.get_function_count(a)
            if rank == self.comm.rank:
                self.my_atom_indices.append(a)
                self.my_indices.append((a, I1, I2))
            I1 = I2
        self.nI = I1

    @trace
    def expand(self, G1=0, G2=None, cc=False):
        """Expand functions in plane-waves.

        G1: int
            Start G-vector index.
        G2: int
            End G-vector index (None means all local G-vectors).
        cc: bool
            Complex conjugate.

        Returns the array ``f_GI``; for real dtypes it has two rows per
        G-vector (see the layout comment below).
        """
        xp = self.xp

        if G2 is None:
            G2 = self.Y_GL.shape[0]

        Gk_Gv = self.G_plus_k_Gv[G1:G2]
        pos_av = self.pos_av
        eikR_a = xp.asarray(self.eikR_a,
                            dtype=as_complex_dtype(self.dtype))

        f_Gs = self.f_Gs[G1:G2]
        Y_GL = self.Y_GL[G1:G2]

        if not self.real:
            f_GI = xp.empty((G2 - G1, self.nI), as_complex_dtype(self.dtype))
        else:
            # Special layout because BLAS does not have real-complex
            # multiplications.  f_GI(G,I) layout:
            #
            #    real(G1, 0),   real(G1, 1),   ...
            #    imag(G1, 0),   imag(G1, 1),   ...
            #    real(G1+1, 0), real(G1+1, 1), ...
            #    imag(G1+1, 0), imag(G1+1, 1), ...
            #    ...

            f_GI = xp.empty((2 * (G2 - G1), self.nI),
                            as_real_dtype(self.dtype))

        if xp is np:
            # Fast C-code:
            pwlfc_expand(f_Gs, Gk_Gv, pos_av, eikR_a, Y_GL,
                         self.l_s, self.a_J, self.s_J,
                         cc, f_GI)
        elif cupy_is_fake:
            # Fake cupy: hand the wrapped ``_data`` arrays to the C-code.
            pwlfc_expand(f_Gs._data, Gk_Gv._data, pos_av._data,
                         eikR_a._data, Y_GL._data,
                         self.l_s._data, self.a_J._data, self.s_J._data,
                         cc, f_GI._data)
        else:
            pwlfc_expand_gpu(f_Gs, Gk_Gv, pos_av, eikR_a, Y_GL,
                             self.l_s, self.a_J, self.s_J,
                             cc, f_GI, self.I_J)
        return f_GI

    def block(self, ensure_same_number_of_blocks=False):
        """Yield ``(G1, G2)`` index ranges covering the local G-vectors.

        Without a block size a single range ``(0, nG)`` is produced.
        With ``ensure_same_number_of_blocks`` the sequence is padded with
        empty ``(nG, nG)`` ranges until it has as many entries as a rank
        holding ``pw.maxmysize`` G-vectors would produce, so that loops
        containing collective communication run equally often everywhere.
        """
        nG = self.Y_GL.shape[0]
        B = self.blocksize
        if not B:
            yield 0, nG
            return

        for G1 in range(0, nG, B):
            yield G1, min(G1 + B, nG)

        if ensure_same_number_of_blocks:
            # Make sure we yield the same number of times:
            nb = (self.pw.maxmysize + B - 1) // B
            mynb = (nG + B - 1) // B
            # Pad with *all* missing blocks, not just one.
            for _ in range(nb - mynb):
                yield nG, nG  # empty block

    @trace
    def get_emiGR_Ga(self, G1, G2):
        """Return ``exp(-i (G+k).R_a) * eikR_a`` for G-vectors G1..G2."""
        Gk_Gv = self.G_plus_k_Gv[G1:G2]
        GkR_Ga = Gk_Gv @ self.pos_avT
        return self.xp.exp(-1j * GkR_Ga) * self.eikR_a

    @trace
    def add(self, a_xG, c_axi=1.0, q=None):
        """Add the functions, weighted by ``c_axi``, to ``a_xG`` in place.

        a_xG:
            Plane-wave coefficients that are accumulated into.
        c_axi:
            Either one float used for every function (``a_xG`` must then
            be one-dimensional) or per-atom coefficient arrays indexed by
            atom number.
        q:
            Not used.
        """
        if self.nI == 0:
            return
        c_xI = self.xp.empty(a_xG.shape[:-1] + (self.nI,), self.dtype)

        if isinstance(c_axi, float):
            assert a_xG.ndim == 1
            c_xI[:] = c_axi
        else:
            parallel = self.comm.size != 1
            if parallel:
                # Entries of atoms owned by other ranks must be zero
                # before the sum below.
                c_xI[:] = 0.0
            for a, I1, I2 in self.my_indices:
                c_xI[..., I1:I2] = c_axi[a] * self.eikR_a[a].conj()
            if parallel:
                self.comm.sum(c_xI)

        nx = prod(c_xI.shape[:-1])
        if nx == 0:
            return
        c_xI = c_xI.reshape((nx, self.nI))
        a_xG = a_xG.reshape((nx, a_xG.shape[-1])).view(self.dtype)

        for G1, G2 in self.block():
            f_GI = self.expand(G1, G2, cc=False)

            if self.real:
                # a_xG is viewed as real numbers: two columns per G.
                G1 *= 2
                G2 *= 2

            with tracectx('gemm'):
                if self.xp is np:
                    mmm(1.0 / self.pw.dv, c_xI, 'N', f_GI, 'T',
                        1.0, a_xG[:, G1:G2])
                else:
                    gpu_gemm('N', 'T',
                             c_xI, f_GI, a_xG[:, G1:G2],
                             1.0 / self.pw.dv, 1.0)

    @trace
    def integrate(self, a_xG, c_axi=None, q=-1, add_to=False):
        """Integrate ``a_xG`` against the functions.

        a_xG:
            Plane-wave coefficients.
        c_axi:
            Per-atom output arrays; entries of atoms owned by this rank
            are overwritten, or incremented when ``add_to`` is true.
        q:
            Not used.
        add_to: bool
            Accumulate into ``c_axi`` instead of overwriting.

        Returns ``c_axi``.
        """
        xp = self.xp
        if self.nI == 0:
            return c_axi
        c_xI = xp.zeros(a_xG.shape[:-1] + (self.nI,), self.dtype)

        nx = prod(c_xI.shape[:-1])
        if nx == 0:
            # Nothing to integrate; hand back the container like every
            # other exit of this method does.
            return c_axi
        b_xI = c_xI.reshape((nx, self.nI))
        a_xG = a_xG.reshape((nx, a_xG.shape[-1]))

        alpha = 1.0
        if self.real:
            alpha *= 2
            a_xG = a_xG.view(self.dtype)

        if c_axi is None:
            # NOTE(review): ``self.dict`` is not defined in this class as
            # shown (presumably it used to come from BaseLFC), and a_xG
            # has already been reshaped here - confirm before relying on
            # c_axi=None.
            c_axi = self.dict(a_xG.shape[:-1])

        x = 0.0  # beta: overwrite on the first block, accumulate after
        for G1, G2 in self.block():
            f_GI = self.expand(G1, G2, cc=not self.real)
            if self.real:
                if G1 == 0 and self.comm.rank == 0:
                    # Halve the first row on rank 0 - presumably the G=0
                    # term, which alpha=2 would otherwise count twice.
                    f_GI[0] *= 0.5
                G1 *= 2
                G2 *= 2
            if xp is np:
                mmm(alpha, a_xG[:, G1:G2], 'N', f_GI, 'N', x, b_xI)
            else:
                gpu_gemm('N', 'N',
                         a_xG[:, G1:G2], f_GI, b_xI,
                         alpha, x)
            x = 1.0

        self.comm.sum(b_xI)
        with tracectx('Displace integrals', gpu=True):
            if add_to:
                for a, I1, I2 in self.my_indices:
                    c_axi[a] += self.eikR_a[a] * c_xI[..., I1:I2]
            else:
                for a, I1, I2 in self.my_indices:
                    c_axi[a][:] = self.eikR_a[a] * c_xI[..., I1:I2]

        return c_axi

    @trace
    def derivative(self, a_xG, c_axiv=None, q=-1):
        """Integrate ``a_xG`` against position derivatives of the functions.

        a_xG:
            Plane-wave coefficients.
        c_axiv:
            Per-atom output arrays whose last axis holds the three
            Cartesian components; entries of atoms owned by this rank are
            overwritten.
        q:
            Not used.

        Returns ``c_axiv``.
        """
        xp = self.xp
        c_vxI = xp.zeros((3,) + a_xG.shape[:-1] + (self.nI,), self.dtype)
        nx = prod(c_vxI.shape[1:-1])
        if nx == 0:
            # Nothing to do; return the container like the normal exit.
            return c_axiv
        b_vxI = c_vxI.reshape((3, nx, self.nI))
        a_xG = a_xG.reshape((nx, a_xG.shape[-1])).view(self.dtype)

        alpha = 1.0

        if c_axiv is None:
            # NOTE(review): ``self.dict`` is not defined in this class as
            # shown (presumably it used to come from BaseLFC) - confirm
            # before relying on c_axiv=None.
            c_axiv = self.dict(a_xG.shape[:-1], derivative=True)

        x = 0.0  # beta: overwrite on the first block, accumulate after
        for G1, G2 in self.block():
            f_GI = self.expand(G1, G2, cc=True)
            G_Gv = xp.asarray(self.pw.G_plus_k_Gv[G1:G2],
                              dtype=as_real_dtype(self.dtype))
            if self.real:
                d_GI = xp.empty_like(f_GI)
                for v in range(3):
                    # Swap the real and imaginary rows of f_GI while
                    # scaling by component v of G+k.
                    d_GI[::2] = f_GI[1::2] * G_Gv[:, v, np.newaxis]
                    d_GI[1::2] = f_GI[::2] * G_Gv[:, v, np.newaxis]
                    if xp is np:
                        mmm(2 * alpha,
                            a_xG[:, 2 * G1:2 * G2], 'N',
                            d_GI, 'N',
                            x, b_vxI[v])
                    else:
                        gpu_gemm('N', 'N',
                                 a_xG[:, 2 * G1:2 * G2],
                                 d_GI,
                                 b_vxI[v],
                                 2 * alpha, x)
            else:
                for v in range(3):
                    if xp is np:
                        mmm(-alpha,
                            a_xG[:, G1:G2], 'N',
                            f_GI * G_Gv[:, v, np.newaxis], 'N',
                            x, b_vxI[v])
                    else:
                        gpu_gemm('N', 'N',
                                 a_xG[:, G1:G2],
                                 f_GI * G_Gv[:, v, np.newaxis],
                                 b_vxI[v],
                                 -alpha, x)
            x = 1.0

        self.comm.sum(c_vxI)

        for v in range(3):
            if self.real:
                for a, I1, I2 in self.my_indices:
                    c_axiv[a][..., v] = c_vxI[v, ..., I1:I2]
            else:
                for a, I1, I2 in self.my_indices:
                    c_axiv[a][..., v] = (1.0j * self.eikR_a[a] *
                                         c_vxI[v, ..., I1:I2])

        return c_axiv

    @trace
    def stress_tensor_contribution(self, a_xG, c_axi=1.0):
        """Return a 3x3 array with this collection's stress contribution.

        a_xG:
            Plane-wave coefficients.
        c_axi:
            Per-atom coefficients indexed by atom number, or one float
            that is used for every atom.
        """
        xp = self.xp

        # |G+k| from the kinetic energies; identical for every spline,
        # so it is computed once.
        G_G = (2 * self.pw.ekin_G)**0.5

        cache = {}
        things = []
        I1 = 0
        lmax = 0
        for a, spline_j in enumerate(self.spline_aj):
            for spline in spline_j:
                if spline not in cache:
                    s = rescaled_fourier_bessel_transform(spline)
                    f_G = []
                    dfdGoG_G = []
                    for G in G_G:
                        f, dfdG = s.get_value_and_derivative(G)
                        if G < 1e-10:
                            # Avoid dividing by (almost) zero.
                            G = 1.0
                        f_G.append(f)
                        dfdGoG_G.append(dfdG / G)
                    f_G = xp.array(f_G)
                    dfdGoG_G = xp.array(dfdGoG_G)
                    cache[spline] = (f_G, dfdGoG_G)
                else:
                    f_G, dfdGoG_G = cache[spline]
                l = spline.l
                lmax = max(l, lmax)
                I2 = I1 + 2 * l + 1
                things.append((a, l, I1, I2, f_G, dfdGoG_G))
                I1 = I2

        if isinstance(c_axi, float):
            c_axi = {a: c_axi for a in range(len(self.pos_av))}

        G0_Gv = self.pw.G_plus_k_Gv

        stress_vv = xp.zeros((3, 3))
        # Every rank must run the same number of blocks because the
        # helper below performs a comm.sum per call.
        for G1, G2 in self.block(ensure_same_number_of_blocks=True):
            G_Gv = G0_Gv[G1:G2]
            Z_LvG = xp.array([nablarlYL(L, G_Gv.T)
                              for L in range((lmax + 1)**2)])
            G_Gv = xp.asarray(G_Gv)
            aa_xG = a_xG[..., G1:G2]
            for v1 in range(3):
                for v2 in range(3):
                    stress_vv[v1, v2] += self._stress_tensor_contribution(
                        v1, v2, things, G1, G2, G_Gv, aa_xG, c_axi, Z_LvG)

        return stress_vv

    @trace
    def _stress_tensor_contribution(self, v1, v2, things, G1, G2,
                                    G_Gv, a_xG, c_axi, Z_LvG):
        """Return the (v1, v2) stress term for the block G1..G2.

        ``things`` holds one ``(a, l, I1, I2, f_G, dfdGoG_G)`` tuple per
        radial function, as assembled by ``stress_tensor_contribution``.
        """
        xp = self.xp
        f_IG = xp.empty((self.nI, G2 - G1), as_complex_dtype(self.dtype))
        emiGR_Ga = self.get_emiGR_Ga(G1, G2)
        Y_LG = self.Y_GL.T
        for a, l, I1, I2, f_G, dfdGoG_G in things:
            L1 = l**2
            L2 = (l + 1)**2
            f_IG[I1:I2] = (emiGR_Ga[:, a] * (-1.0j)**l *
                           (dfdGoG_G[G1:G2] * G_Gv[:, v1] * G_Gv[:, v2] *
                            Y_LG[L1:L2, G1:G2] +
                            f_G[G1:G2] * G_Gv[:, v1] * Z_LvG[L1:L2, v2]))

        c_xI = xp.zeros(a_xG.shape[:-1] + (self.nI,), self.pw.dtype)

        x = prod(c_xI.shape[:-1])
        if x == 0:
            return 0.0
        b_xI = c_xI.reshape((x, self.nI))
        a_xG = a_xG.reshape((x, a_xG.shape[-1]))

        alpha = 1.0
        if self.real:
            alpha = 2.0
            if G1 == 0 and self.pw.comm.rank == 0:
                # Halve the first column on rank 0 - presumably the G=0
                # term, which alpha=2 would otherwise count twice.
                f_IG[:, 0] *= 0.5
            f_IG = f_IG.view(as_real_dtype(f_IG.dtype))
            a_xG = a_xG.copy().view(as_real_dtype(f_IG.dtype))

        if xp is np:
            mmm(alpha, a_xG, 'N', f_IG, 'C', 0.0, b_xI)
        else:
            gpu_gemm('N', 'H', a_xG, f_IG, b_xI, alpha, 0.0)
        self.comm.sum(b_xI)

        stress = 0.0
        for a, I1, I2 in self.my_indices:
            stress -= self.eikR_a[a] * (c_axi[a] * c_xI[..., I1:I2]).sum()
        return stress.real