Coverage for gpaw/transformers.py: 81%
100 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-19 00:19 +0000
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-19 00:19 +0000
1# Copyright (C) 2003 CAMP
2# Please see the accompanying LICENSE file for further information.
4"""Grid transformers.
6This module defines tools for doing interpolations/restrictions between
7different uniform 3D grids.
8"""
10import numpy as np
12from gpaw import debug
13from gpaw.gpu import cupy_is_fake
14from gpaw.utilities import is_contiguous
15import gpaw.cgpaw as cgpaw
18class _Transformer:
19 def __init__(self, gdin, gdout, nn=1, dtype=float, xp=np):
20 self.gdin = gdin
21 self.gdout = gdout
22 self.nn = nn
23 assert 1 <= nn <= 4
24 self.dtype = dtype
25 self.xp = xp
27 pad_cd = np.empty((3, 2), int)
28 neighborpad_cd = np.empty((3, 2), int)
29 skip_cd = np.empty((3, 2), int)
31 if (gdin.N_c == 2 * gdout.N_c).all():
32 # Restriction:
33 pad_cd[:, 0] = 2 * nn - 1 - 2 * gdout.beg_c + gdin.beg_c
34 pad_cd[:, 1] = 2 * nn - 2 + 2 * gdout.end_c - gdin.end_c
35 neighborpad_cd[:, 0] = 2 * nn - 2 + 2 * gdout.beg_c - gdin.beg_c
36 neighborpad_cd[:, 1] = 2 * nn - 1 - 2 * gdout.end_c + gdin.end_c
37 self.interpolate = False
38 else:
39 assert (gdout.N_c == 2 * gdin.N_c).all()
40 # Interpolation:
41 pad_cd[:, 0] = nn - 1 - gdout.beg_c // 2 + gdin.beg_c
42 pad_cd[:, 1] = nn + gdout.end_c // 2 - gdin.end_c
43 neighborpad_cd[:, 0] = nn + gdout.beg_c // 2 - gdin.beg_c
44 neighborpad_cd[:, 1] = nn - 1 - gdout.end_c // 2 + gdin.end_c
45 skip_cd[:, 0] = gdout.beg_c % 2
46 skip_cd[:, 1] = gdout.end_c % 2
47 self.interpolate = True
49 inpoints = (gdin.n_c[0] + 2 * nn - 1) * (gdin.n_c[1] + 2 * nn - 1)
50 outpoints = gdout.n_c[0] * gdout.n_c[1]
52 if inpoints > outpoints:
53 points = ' x '.join([str(N) for N in gdin.N_c])
54 raise ValueError('Cannot construct interpolator. Grid %s '
55 'may be too small' % points)
57 assert (pad_cd.ravel() >= 0).all()
58 self.ngpin = tuple(gdin.n_c)
59 self.ngpout = tuple(gdout.n_c)
60 assert dtype in [float, complex]
62 self.pad_cd = pad_cd
63 self.neighborpad_cd = neighborpad_cd
64 self.skip_cd = skip_cd
66 if gdin.comm.size > 1:
67 comm = gdin.comm.get_c_object()
68 else:
69 comm = None
71 self.transformer = cgpaw.Transformer(gdin.n_c, gdout.n_c,
72 2 * nn, pad_cd,
73 neighborpad_cd, skip_cd,
74 gdin.neighbor_cd,
75 dtype == float, comm,
76 self.interpolate,
77 xp is not np)
79 def apply(self, input, output=None, phases=None):
80 if output is None:
81 output = self.gdout.empty(input.shape[:-3], dtype=self.dtype,
82 xp=self.xp)
83 if self.xp is np:
84 self.transformer.apply(input, output, phases)
85 elif cupy_is_fake:
86 self.transformer.apply(input._data, output._data, phases)
87 else:
88 self.transformer.apply_gpu(input.data.ptr,
89 output.data.ptr,
90 input.shape, input.dtype, phases)
91 return output
93 def get_async_sizes(self):
94 return self.transformer.get_async_sizes()
class TransformerWrapper:
    """Debug wrapper that checks arguments before delegating to a transformer.

    ``Transformer()`` wraps ``_Transformer`` instances in this class when
    ``gpaw.debug`` is set. The checks are deliberately ``assert``s
    because they exist only for debugging.
    """

    def __init__(self, transformer):
        # Mirror the attributes callers read from an unwrapped transformer.
        self.transformer = transformer
        self.dtype = transformer.dtype
        self.ngpin = transformer.ngpin
        self.ngpout = transformer.ngpout
        self.nn = transformer.nn

    def apply(self, input, output=None, phases=None):
        """Validate *input*/*output*/*phases*, then call the wrapped apply.

        Checks, via ``is_contiguous``, that the arrays are contiguous
        (presumably also with matching dtype; see gpaw.utilities) and that
        their last three axes match ``ngpin``/``ngpout``. For non-float
        transformers it also checks that *phases* is a complex array of
        shape (3, 2).
        """
        assert is_contiguous(input, self.dtype)
        assert input.shape[-3:] == self.ngpin
        if output is not None:
            assert is_contiguous(output, self.dtype)
            assert output.shape[-3:] == self.ngpout
        # Test for None explicitly so that a missing *phases* fails this
        # assertion instead of raising AttributeError on ``None.dtype``.
        # The (3, 2) shape presumably means one phase per axis and side;
        # confirm.
        assert (self.dtype == float or
                (phases is not None and
                 phases.dtype == complex and
                 phases.shape == (3, 2)))

        return self.transformer.apply(input, output, phases)

    def get_async_sizes(self):
        """Forward to the wrapped transformer."""
        return self.transformer.get_async_sizes()
def Transformer(gdin, gdout, nn=1, dtype=float, xp=np):
    """Create an interpolator/restrictor between grids *gdin* and *gdout*.

    For ``nn != 9`` this returns a ``_Transformer`` (wrapped in
    ``TransformerWrapper`` when ``gpaw.debug`` is set).

    ``nn == 9`` is a special value that returns an identity "transformer".
    Its ``apply`` copies input into output. Like ``_Transformer.apply``,
    *output* is optional (then allocated via ``gdout.empty`` with *dtype*
    and *xp*), and the output array is returned.
    """
    if nn != 9:
        t = _Transformer(gdin, gdout, nn, dtype, xp)
        if debug:
            t = TransformerWrapper(t)
        return t

    class T:
        """Identity transformer: copies input to output unchanged."""
        nn = 1

        def apply(self, input, output=None, phases=None):
            # Match _Transformer.apply: allocate on gdout when no output is
            # given, and return the output so ``out = t.apply(a)`` works.
            if output is None:
                output = gdout.empty(input.shape[:-3], dtype=dtype, xp=xp)
            output[:] = input
            return output

    return T()
def multiple_transform_apply(transformerlist, inputs, outputs, phases=None):
    """Apply several transformers in a single call into the C extension.

    All four arguments are forwarded positionally to
    ``cgpaw.multiple_transform_apply``, and its result is returned as-is.
    NOTE(review): presumably the items of *transformerlist* must be the
    C-level objects (e.g. ``_Transformer.transformer``) rather than the
    Python wrappers; confirm against the C code.
    """
    forwarded = (transformerlist, inputs, outputs, phases)
    return cgpaw.multiple_transform_apply(*forwarded)
def coefs(k, p):
    """Print Lagrange interpolation coefficients for k equispaced nodes.

    The nodes sit at 0, p, 2*p, ..., (k - 1)*p. For every node, one line
    is printed: the node position, then the value of that node's Lagrange
    basis polynomial at each integer x from (k // 2 - 1) * p to
    (k // 2) * p inclusive, i.e. between the two central nodes.
    Presumably a helper for deriving interpolation stencil weights.
    """
    nodes = range(0, k * p, p)
    x_first = (k // 2 - 1) * p
    x_last = k // 2 * p
    for node in nodes:
        print('%2d' % node, end=' ')
        for x in range(x_first, x_last + 1):
            numer = denom = 1
            for other in nodes:
                if other != node:
                    numer *= x - other
                    denom *= node - other
            # Width 14 < precision 16, so the width never pads anything.
            print('%14.16f' % (numer / denom), end=' ')
        print()