Coverage for gpaw/transformers.py: 81%

100 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-19 00:19 +0000

1# Copyright (C) 2003 CAMP 

2# Please see the accompanying LICENSE file for further information. 

3 

4"""Grid transformers. 

5 

6This module defines tools for doing interpolations/restrictions between 

7different uniform 3D grids.

8""" 

9 

10import numpy as np 

11 

12from gpaw import debug 

13from gpaw.gpu import cupy_is_fake 

14from gpaw.utilities import is_contiguous 

15import gpaw.cgpaw as cgpaw 

16 

17 

class _Transformer:
    """Restrict or interpolate between two uniform 3D grids.

    The global grid sizes must differ by exactly a factor of two along every
    axis.  If ``gdin.N_c == 2 * gdout.N_c`` the transformer restricts
    (fine -> coarse); otherwise ``gdout.N_c == 2 * gdin.N_c`` is asserted and
    it interpolates (coarse -> fine).  The numerical work is delegated to
    ``cgpaw.Transformer``.

    Parameters
    ----------
    gdin, gdout:
        Input and output grid descriptors.  This class reads ``N_c``,
        ``n_c``, ``beg_c`` and ``end_c`` from both, ``comm`` and
        ``neighbor_cd`` from ``gdin``, and calls ``gdout.empty()``.
    nn:
        Integer in 1..4.  The C constructor receives ``2 * nn``
        (presumably the stencil width -- confirm against the C code).
    dtype:
        ``float`` or ``complex``.
    xp:
        Array module.  When it is not numpy, a flag is passed to the C
        constructor and :meth:`apply` uses the GPU entry point (or the
        underlying host arrays when cupy is faked).
    """

    def __init__(self, gdin, gdout, nn=1, dtype=float, xp=np):
        self.gdin = gdin
        self.gdout = gdout
        self.nn = nn
        assert 1 <= nn <= 4
        self.dtype = dtype
        self.xp = xp

        pad_cd = np.empty((3, 2), int)
        neighborpad_cd = np.empty((3, 2), int)
        # The restriction branch below never fills skip_cd, but the array is
        # still passed to the C code.  Zero it so no uninitialised memory is
        # handed over.
        skip_cd = np.zeros((3, 2), int)

        if (gdin.N_c == 2 * gdout.N_c).all():
            # Restriction:
            pad_cd[:, 0] = 2 * nn - 1 - 2 * gdout.beg_c + gdin.beg_c
            pad_cd[:, 1] = 2 * nn - 2 + 2 * gdout.end_c - gdin.end_c
            neighborpad_cd[:, 0] = 2 * nn - 2 + 2 * gdout.beg_c - gdin.beg_c
            neighborpad_cd[:, 1] = 2 * nn - 1 - 2 * gdout.end_c + gdin.end_c
            self.interpolate = False
        else:
            assert (gdout.N_c == 2 * gdin.N_c).all()
            # Interpolation:
            pad_cd[:, 0] = nn - 1 - gdout.beg_c // 2 + gdin.beg_c
            pad_cd[:, 1] = nn + gdout.end_c // 2 - gdin.end_c
            neighborpad_cd[:, 0] = nn + gdout.beg_c // 2 - gdin.beg_c
            neighborpad_cd[:, 1] = nn - 1 - gdout.end_c // 2 + gdin.end_c
            # Parity of the output domain boundaries.
            skip_cd[:, 0] = gdout.beg_c % 2
            skip_cd[:, 1] = gdout.end_c % 2
            self.interpolate = True

        # Reject grids where the padded local input xy-plane (n + 2*nn - 1
        # points per axis) has more points than the local output xy-plane.
        # NOTE(review): the reason for this bound lives in the C code
        # (presumably a buffer size) -- confirm there.
        inpoints = (gdin.n_c[0] + 2 * nn - 1) * (gdin.n_c[1] + 2 * nn - 1)
        outpoints = gdout.n_c[0] * gdout.n_c[1]

        if inpoints > outpoints:
            points = ' x '.join([str(N) for N in gdin.N_c])
            raise ValueError('Cannot construct interpolator. Grid %s '
                             'may be too small' % points)

        assert (pad_cd.ravel() >= 0).all()
        self.ngpin = tuple(gdin.n_c)
        self.ngpout = tuple(gdout.n_c)
        assert dtype in [float, complex]

        self.pad_cd = pad_cd
        self.neighborpad_cd = neighborpad_cd
        self.skip_cd = skip_cd

        # Only hand a communicator to C when the domain is actually split.
        if gdin.comm.size > 1:
            comm = gdin.comm.get_c_object()
        else:
            comm = None

        self.transformer = cgpaw.Transformer(gdin.n_c, gdout.n_c,
                                             2 * nn, pad_cd,
                                             neighborpad_cd, skip_cd,
                                             gdin.neighbor_cd,
                                             dtype == float, comm,
                                             self.interpolate,
                                             xp is not np)

    def apply(self, input, output=None, phases=None):
        """Transform *input* onto the output grid and return the result.

        If *output* is None, it is allocated with ``gdout.empty`` using the
        leading (non-grid) dimensions of *input*, ``self.dtype`` and
        ``self.xp``.  *phases* is passed through unchanged to the C
        transformer.
        """
        if output is None:
            output = self.gdout.empty(input.shape[:-3], dtype=self.dtype,
                                      xp=self.xp)
        if self.xp is np:
            self.transformer.apply(input, output, phases)
        elif cupy_is_fake:
            # Fake cupy arrays wrap a host numpy array in ``_data``.
            self.transformer.apply(input._data, output._data, phases)
        else:
            self.transformer.apply_gpu(input.data.ptr,
                                       output.data.ptr,
                                       input.shape, input.dtype, phases)
        return output

    def get_async_sizes(self):
        """Return the C transformer's ``get_async_sizes()`` result."""
        return self.transformer.get_async_sizes()

95 

96 

class TransformerWrapper:
    """Argument-checking proxy around another transformer object.

    The ``Transformer`` factory in this module wraps its result in this
    class when ``gpaw.debug`` is set.  The checks use ``assert``, so they
    disappear under ``python -O``.
    """

    def __init__(self, transformer):
        self.transformer = transformer
        # Mirror the attributes that apply() validates against, in the
        # same order as they are read from the wrapped object.
        for attr_name in ('dtype', 'ngpin', 'ngpout', 'nn'):
            setattr(self, attr_name, getattr(transformer, attr_name))

    def apply(self, input, output=None, phases=None):
        """Validate the arguments, then forward to the wrapped ``apply``.

        Checks that *input* (and *output*, when given) are contiguous
        arrays of ``self.dtype`` whose last three dimensions match the
        expected local grid shapes.  For a non-float dtype, *phases* must
        be a complex array of shape (3, 2).
        """
        expected_dtype = self.dtype
        assert is_contiguous(input, expected_dtype)
        assert input.shape[-3:] == self.ngpin
        if output is not None:
            assert is_contiguous(output, expected_dtype)
            assert output.shape[-3:] == self.ngpout
        if expected_dtype != float:
            assert phases.dtype == complex and phases.shape == (3, 2)

        return self.transformer.apply(input, output, phases)

    def get_async_sizes(self):
        """Forward to the wrapped transformer's ``get_async_sizes``."""
        inner = self.transformer
        return inner.get_async_sizes()

119 

120 

class _CopyTransformer:
    """Identity "transformer" that copies its input unchanged.

    Returned by :func:`Transformer` for the special value ``nn=9``.
    """

    nn = 1

    def apply(self, input, output=None, phases=None):
        """Copy *input* into *output* and return *output*.

        If *output* is None, a fresh copy of *input* is returned instead,
        matching the calling convention of the real transformers.  *phases*
        is accepted for interface compatibility and ignored.
        """
        if output is None:
            return input.copy()
        output[:] = input
        return output


def Transformer(gdin, gdout, nn=1, dtype=float, xp=np):
    """Create a grid transformer from *gdin* to *gdout*.

    ``nn == 9`` is a special value.  It returns a plain copy transformer
    that ignores *gdin*, *gdout*, *dtype* and *xp*.  Otherwise a
    ``_Transformer`` is built, wrapped in ``TransformerWrapper`` (extra
    argument checks) when ``gpaw.debug`` is set.
    """
    if nn == 9:
        return _CopyTransformer()
    t = _Transformer(gdin, gdout, nn, dtype, xp)
    if debug:
        t = TransformerWrapper(t)
    return t

135 

136 

def multiple_transform_apply(transformerlist, inputs, outputs, phases=None):
    """Forward to the C-extension routine ``cgpaw.multiple_transform_apply``.

    All four arguments are passed through unchanged, and the C routine's
    result is returned.  (Presumably each transformer is applied to the
    matching input/output pair -- confirm against the C implementation.)
    """
    forwarded = (transformerlist, inputs, outputs, phases)
    return cgpaw.multiple_transform_apply(*forwarded)

140 

141 

def coefs(k, p):
    """Print Lagrange interpolation coefficients on an equispaced stencil.

    The nodes are ``0, p, 2p, ..., (k-1)p``.  For every node ``i``, one row
    is printed: the node position, followed by the Lagrange basis polynomial
    ``prod_{j != i} (x - j) / (i - j)`` evaluated at each integer ``x`` from
    ``(k//2 - 1)*p`` to ``(k//2)*p`` inclusive.  For even ``k`` this is the
    interval between the two central nodes.
    """
    nodes = list(range(0, k * p, p))
    first_x = (k // 2 - 1) * p
    last_x = k // 2 * p
    for node in nodes:
        print('%2d' % node, end=' ')
        for x in range(first_x, last_x + 1):
            numerator = denominator = 1
            for other in nodes:
                if other != node:
                    numerator *= x - other
                    denominator *= node - other
            print('%14.16f' % (numerator / denominator), end=' ')
        print()