Coverage for gpaw/tddft/utils.py: 39%

119 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-19 00:19 +0000

1# Written by Lauri Lehtovaara 2008 

2import numpy as np 

3 

4from gpaw.utilities.blas import axpy 

5 

6 

class MultiBlas:
    """Batched BLAS-like operations over a set of vectors.

    Every operation acts on the first ``nvec`` entries of its vector
    arguments (``x[i]``, ``y[i]`` for ``0 <= i < nvec``).  Coefficients
    may be a single scalar (any 0-d value: Python/NumPy int, float,
    complex, or a 0-d array) applied to every vector, or a sequence with
    one coefficient per vector.

    Parameters
    ----------
    gd
        Grid descriptor; only ``gd.comm.sum`` is used (in
        ``multi_zdotc``) to reduce the computed dot products.
    """

    def __init__(self, gd):
        self.gd = gd

    def multi_zaxpy(self, a, x, y, nvec):
        """Multivector ZAXPY: ``a_i * x[i] + y[i] => y[i]`` for ``i < nvec``.

        The coefficient is promoted to complex before being handed to
        ``axpy``, which updates ``y[i]`` in place.
        """
        # np.ndim(...) == 0 accepts ints and NumPy scalar types too; the
        # old isinstance(a, (float, complex)) test sent those to a[i].
        scalar = np.ndim(a) == 0
        for i in range(nvec):
            alpha = a if scalar else a[i]
            axpy(alpha * (1.0 + 0.0j), x[i], y[i])

    def multi_zdotc(self, s, x, y, nvec):
        """Multivector dot product ``s[i] = x[i]^H y[i]`` for ``i < nvec``.

        ``^H`` is the conjugate transpose: ``np.vdot`` flattens both
        operands and conjugates the first one.  ``s`` is filled in place
        (presumably a preallocated complex array -- confirm with callers),
        then passed to ``self.gd.comm.sum`` and returned.
        """
        for i in range(nvec):
            s[i] = np.vdot(x[i], y[i])
        self.gd.comm.sum(s)
        return s

    def multi_scale(self, a, x, nvec):
        """Multiscale: ``a_i * x[i] => x[i]`` in place for ``i < nvec``.

        Both the scalar and the per-vector form touch only the first
        ``nvec`` vectors; previously a scalar ``a`` scaled all of ``x``.
        """
        scalar = np.ndim(a) == 0
        for i in range(nvec):
            factor = a if scalar else a[i]
            x[i] *= factor

34 

35 

class BandPropertyMonitor:
    """Periodically gather one band attribute from every k-point.

    Parameters
    ----------
    wfs
        Wave-function object; ``wfs.kpt_u`` is iterated on each call.
    name
        Attribute name read from each k-point object via ``getattr``.
    interval
        Amount added to ``niter`` after every call.
    """

    def __init__(self, wfs, name, interval=1):
        self.wfs = wfs
        self.name = name
        self.interval = interval
        self.niter = 0

    def __call__(self):
        """Record the current data, then advance the iteration counter."""
        self.update(self.wfs)
        self.niter += self.interval

    def update(self, wfs):
        """Stack ``getattr(kpt, self.name)`` over ``wfs.kpt_u`` and write it."""
        # strictly serial XXX! -- no inter-process communication here.
        attribute = self.name
        stacked = np.array([getattr(kpt, attribute) for kpt in wfs.kpt_u])
        self.write(stacked)

    def write(self, data):
        """Output hook; a no-op here, overridden by subclasses."""

62 

63 

class BandPropertyWriter(BandPropertyMonitor):
    """BandPropertyMonitor that appends the raw data of each snapshot to a file.

    Every ``write`` appends ``data.tobytes()`` (C-order bytes of the
    array's own dtype, no header) and flushes, so the file is a flat
    concatenation of snapshots.
    """

    def __init__(self, filename, wfs, name, interval=1):
        BandPropertyMonitor.__init__(self, wfs, name, interval)
        # Binary mode: the payload is bytes, which a text-mode file
        # rejects with TypeError on Python 3.
        self.fileobj = open(filename, 'wb')

    def write(self, data):
        # ndarray.tostring() was deprecated and removed in NumPy 2.0;
        # tobytes() is the identical operation.
        self.fileobj.write(data.tobytes())
        self.fileobj.flush()

    def __del__(self):
        # fileobj does not exist if open() raised inside __init__.
        fileobj = getattr(self, 'fileobj', None)
        if fileobj is not None:
            fileobj.close()

75 

76 

class StaticOverlapMonitor:
    """Track overlaps of the current states with fixed reference states.

    For every k-point index ``u`` and band ``n`` the value

        sum_G conj(psit_nG) * wf_u[u]_G * dv
        + sum_a sum_ij conj(P^a_ni) * dO^a_ij * P_aui[a][u][j]

    is computed (``dO^a`` is ``wfs.setups[a].dO_ii``), and the stacked
    per-k-point results are passed to ``write``.

    Parameters
    ----------
    wfs
        Wave-function object read in ``update`` (``kpt_u``, ``gd.dv``,
        ``bd.nbands``, ``setups``).
    wf_u
        Reference wave functions, one per entry of ``wfs.kpt_u``; each is
        flattened before the grid sum.
    P_aui
        Reference projections, indexed as ``P_aui[a][u][j]``.
    interval
        Amount added to ``niter`` after every call.
    """

    def __init__(self, wfs, wf_u, P_aui, interval=1):
        self.niter = 0
        self.interval = interval

        self.wfs = wfs

        self.wf_u = wf_u
        self.P_aui = P_aui

    def __call__(self):
        """Record the current overlaps, then advance the iteration counter."""
        self.update(self.wfs)
        self.niter += self.interval

    def update(self, wfs, calculate_P_ani=False):
        """Compute the overlaps for every entry of ``wfs.kpt_u`` and write them.

        Raises
        ------
        NotImplementedError
            If ``calculate_P_ani`` is true; otherwise the existing
            ``kpt.P_ani`` are used.
        """
        # strictly serial XXX! -- no inter-process reduction is done here.
        nbands = wfs.bd.nbands
        Porb_un = []

        for u, kpt in enumerate(wfs.kpt_u):
            swf = self.wf_u[u].ravel()

            # Grid part: <psit_n | swf> * dv for all bands at once.
            psit_n = kpt.psit_nG.reshape((len(kpt.f_n), -1))
            Porb_n = np.dot(psit_n.conj(), swf) * wfs.gd.dv

            if calculate_P_ani:
                # wfs.pt.integrate(psit_nG, P_ani, kpt.q)
                raise NotImplementedError(
                    'In case you were wondering, TODO XXX')

            # Atomic corrections conj(P_ni) . dO_ii . sP_i for the first
            # nbands bands, as one matrix product instead of the former
            # triple Python loop over (n, i, j).
            for a, P_ni in kpt.P_ani.items():
                sP_i = np.asarray(self.P_aui[a][u])
                dO_ii = np.asarray(wfs.setups[a].dO_ii)
                P_ni = np.asarray(P_ni)[:nbands]
                Porb_n[:nbands] += P_ni.conj() @ dO_ii @ sP_i

            Porb_un.append(Porb_n)

        self.write(np.array(Porb_un))

    def write(self, data):
        """Output hook; a no-op here, overridden by subclasses."""
        pass

123 

124 

class StaticOverlapWriter(StaticOverlapMonitor):
    """StaticOverlapMonitor that appends the raw overlap arrays to a file.

    The constructor forwards ``wf_u`` and ``P_aui`` separately to
    ``StaticOverlapMonitor.__init__``.  The previous signature
    ``(filename, wfs, overlap, interval=1)`` forwarded only three values,
    so the fourth positional argument silently became ``P_aui`` and
    ``interval`` was always 1.  Positional callers keep the same data
    flow; ``interval`` now actually reaches the monitor.

    Every ``write`` appends ``data.tobytes()`` (C-order bytes of the
    array's own dtype, no header) and flushes.
    """

    def __init__(self, filename, wfs, wf_u, P_aui, interval=1):
        StaticOverlapMonitor.__init__(self, wfs, wf_u, P_aui, interval)
        # Binary mode: the payload is bytes, which a text-mode file
        # rejects with TypeError on Python 3.
        self.fileobj = open(filename, 'wb')

    def write(self, data):
        # ndarray.tostring() was deprecated and removed in NumPy 2.0;
        # tobytes() is the identical operation.
        self.fileobj.write(data.tobytes())
        self.fileobj.flush()

    def __del__(self):
        # fileobj does not exist if open() raised inside __init__.
        fileobj = getattr(self, 'fileobj', None)
        if fileobj is not None:
            fileobj.close()

136 

137 

class DynamicOverlapMonitor:
    """Periodically compute the overlap matrix of the current states.

    Parameters
    ----------
    wfs
        Wave-function object; ``wfs.kpt_u`` is iterated on each call.
    overlap
        Object providing ``setups`` (indexed by atom, each with a
        ``dO_ii`` attribute) and ``operator`` (providing
        ``calculate_matrix_elements``).
    interval
        Amount added to ``niter`` after every call.
    """

    def __init__(self, wfs, overlap, interval=1):
        self.wfs = wfs
        self.setups = overlap.setups
        self.operator = overlap.operator
        self.interval = interval
        self.niter = 0

    def __call__(self):
        """Record the current overlap matrices, then advance the counter."""
        self.update(self.wfs)
        self.niter += self.interval

    def update(self, wfs, calculate_P_ani=False):
        """Build one matrix per entry of ``wfs.kpt_u``, stack them, and write.

        Each matrix comes from ``operator.calculate_matrix_elements`` with
        the identity as the grid operator and each atom's ``dO_ii`` as
        the atomic correction.

        Raises
        ------
        NotImplementedError
            If ``calculate_P_ani`` is true and there is at least one k-point.
        """
        # strictly serial XXX! -- no inter-process communication here.
        def identity(psit):
            return psit

        matrices = []
        for kpt in wfs.kpt_u:
            psit_nG, P_ani = kpt.psit_nG, kpt.P_ani

            if calculate_P_ani:
                # wfs.pt.integrate(psit_nG, P_ani, kpt.q)
                raise NotImplementedError(
                    'In case you were wondering, TODO XXX')

            corrections = {a: self.setups[a].dO_ii for a in P_ani}
            matrices.append(self.operator.calculate_matrix_elements(
                psit_nG, P_ani, identity, corrections))

        self.write(np.array(matrices))

    def write(self, data):
        """Output hook; a no-op here, overridden by subclasses."""

178 

179 

class DynamicOverlapWriter(DynamicOverlapMonitor):
    """DynamicOverlapMonitor that appends the raw overlap matrices to a file.

    Every ``write`` appends ``data.tobytes()`` (C-order bytes of the
    array's own dtype, no header) and flushes, so the file is a flat
    concatenation of snapshots.
    """

    def __init__(self, filename, wfs, overlap, interval=1):
        DynamicOverlapMonitor.__init__(self, wfs, overlap, interval)
        # Binary mode: the payload is bytes, which a text-mode file
        # rejects with TypeError on Python 3.
        self.fileobj = open(filename, 'wb')

    def write(self, data):
        # ndarray.tostring() was deprecated and removed in NumPy 2.0;
        # tobytes() is the identical operation.
        self.fileobj.write(data.tobytes())
        self.fileobj.flush()

    def __del__(self):
        # fileobj does not exist if open() raised inside __init__.
        fileobj = getattr(self, 'fileobj', None)
        if fileobj is not None:
            fileobj.close()