Coverage for gpaw/tddft/utils.py: 39%
119 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-19 00:19 +0000
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-19 00:19 +0000
1# Written by Lauri Lehtovaara 2008
2import numpy as np
4from gpaw.utilities.blas import axpy
class MultiBlas:
    """BLAS-like operations applied to the first ``nvec`` vectors of a set.

    ``gd`` is stored as given; only ``gd.comm.sum`` is used here (by
    ``multi_zdotc``).
    """

    def __init__(self, gd):
        self.gd = gd

    # Multivector ZAXPY: a x + y => y
    def multi_zaxpy(self, a, x, y, nvec):
        """Accumulate ``a * x[i]`` into ``y[i]`` for ``i < nvec``.

        ``a`` is either a single coefficient shared by all vectors (any
        zero-dimensional value: int, float, complex, NumPy scalar or 0-d
        array) or an indexable sequence holding one coefficient per vector.
        The coefficient is promoted to complex before it is passed to
        ``axpy``.
        """
        if np.ndim(a) == 0:
            # np.ndim also recognises ints and 0-d arrays, which the former
            # isinstance(a, (float, complex)) test sent to the a[i] branch.
            coef = a * (1 + 0j)
            for i in range(nvec):
                axpy(coef, x[i], y[i])
        else:
            for i in range(nvec):
                axpy(a[i] * (1.0 + 0.0j), x[i], y[i])

    # Multivector dot product, a^H b, where ^H is conjugate transpose
    def multi_zdotc(self, s, x, y, nvec):
        """Store ``vdot(x[i], y[i])`` in ``s[i]`` for ``i < nvec``.

        ``np.vdot`` conjugates its first argument.  After the local
        products are written, ``s`` is handed to ``self.gd.comm.sum`` and
        then returned.
        """
        for i in range(nvec):
            s[i] = np.vdot(x[i], y[i])
        self.gd.comm.sum(s)
        return s

    # Multiscale: a x => x
    def multi_scale(self, a, x, nvec):
        """Scale ``x`` in place by ``a``.

        With a single (zero-dimensional) coefficient the whole of ``x`` is
        multiplied at once; otherwise ``x[i]`` is multiplied by ``a[i]``
        for ``i < nvec``.
        """
        if np.ndim(a) == 0:
            x *= a
        else:
            for i in range(nvec):
                x[i] *= a[i]
class BandPropertyMonitor:
    """Sample one named attribute from every entry of ``wfs.kpt_u``.

    Calling the instance gathers the attribute called ``name`` from each
    k-point object, passes the stacked array to ``write`` and advances
    ``niter`` by ``interval``.  ``write`` does nothing here; subclasses
    override it.
    """

    def __init__(self, wfs, name, interval=1):
        self.wfs = wfs
        self.name = name
        self.interval = interval
        self.niter = 0

    def __call__(self):
        self.update(self.wfs)
        self.niter += self.interval

    def update(self, wfs):
        """Collect the monitored attribute from ``wfs.kpt_u`` and write it."""
        # strictly serial XXX!
        samples = [getattr(kpt, self.name) for kpt in wfs.kpt_u]
        self.write(np.array(samples))

    def write(self, data):
        """Receive the collected array; a no-op in this base class."""
        pass
class BandPropertyWriter(BandPropertyMonitor):
    """Band property monitor that writes each sample's raw bytes to a file.

    The file named ``filename`` is opened for writing on construction
    (truncating any existing content) and flushed after every sample.
    """

    def __init__(self, filename, wfs, name, interval=1):
        BandPropertyMonitor.__init__(self, wfs, name, interval)
        # Binary mode: write() emits the array's raw bytes, which a
        # text-mode handle rejects with TypeError on Python 3.
        self.fileobj = open(filename, 'wb')

    def write(self, data):
        """Write the raw bytes of ``data`` to the file and flush."""
        # tobytes() is the supported spelling of the deprecated tostring().
        self.fileobj.write(data.tobytes())
        self.fileobj.flush()

    def __del__(self):
        # fileobj is missing if open() raised in __init__; avoid a second
        # error during garbage collection in that case.
        fileobj = getattr(self, 'fileobj', None)
        if fileobj is not None:
            fileobj.close()
class StaticOverlapMonitor:
    """Project the current bands onto fixed reference functions.

    For every k-point index ``u`` the bands in ``kpt.psit_nG`` are
    overlapped with ``wf_u[u]``, and a correction built from ``kpt.P_ani``,
    ``wfs.setups[a].dO_ii`` and ``P_aui[a][u]`` is added.  Calling the
    instance runs ``update`` and advances ``niter`` by ``interval``.
    ``write`` does nothing here; subclasses override it.
    """

    def __init__(self, wfs, wf_u, P_aui, interval=1):
        self.wfs = wfs
        self.wf_u = wf_u
        self.P_aui = P_aui
        self.interval = interval
        self.niter = 0

    def __call__(self):
        self.update(self.wfs)
        self.niter += self.interval

    def update(self, wfs, calculate_P_ani=False):
        """Compute the overlaps for all k-points and pass them to ``write``.

        Raises NotImplementedError when ``calculate_P_ani`` is true and
        ``wfs.kpt_u`` is non-empty.
        """
        # strictly serial XXX!
        overlaps_un = []

        for u, kpt in enumerate(wfs.kpt_u):
            reference = self.wf_u[u].ravel()

            # Flatten each band to one row, then take conj(psi) . reference
            # weighted by the grid volume element.
            bands = kpt.psit_nG.reshape((len(kpt.f_n), -1))
            overlap_n = np.dot(bands.conj(), reference) * wfs.gd.dv

            P_ani = kpt.P_ani

            if calculate_P_ani:
                # wfs.pt.integrate(psit_nG, P_ani, kpt.q)
                raise NotImplementedError(
                    'In case you were wondering, TODO XXX')

            # Add sum_ij conj(P_ni) dO_ij sP_j for every atom index a.
            for a, P_ni in P_ani.items():
                sP_i = self.P_aui[a][u]
                for n in range(wfs.bd.nbands):
                    nproj = len(P_ni[0])
                    for i in range(nproj):
                        P_conj = P_ni[n][i].conj()
                        dO_row = wfs.setups[a].dO_ii[i]
                        for j in range(nproj):
                            overlap_n[n] += P_conj * dO_row[j] * sP_i[j]

            overlaps_un.append(overlap_n)

        self.write(np.array(overlaps_un))

    def write(self, data):
        """Receive the computed array; a no-op in this base class."""
        pass
class StaticOverlapWriter(StaticOverlapMonitor):
    """Static overlap monitor that writes each sample's raw bytes to a file.

    The file named ``filename`` is opened for writing on construction
    (truncating any existing content) and flushed after every sample.
    """

    def __init__(self, filename, wfs, overlap, interval=1):
        # NOTE(review): StaticOverlapMonitor.__init__ takes
        # (wfs, wf_u, P_aui, interval=1), so the historical positional call
        # bound `overlap` to wf_u and `interval` to P_aui, leaving the
        # monitor's interval at its default.  That binding is preserved
        # here, spelled out with keywords, because existing callers may
        # rely on it -- confirm the intended arguments.
        StaticOverlapMonitor.__init__(self, wfs, wf_u=overlap,
                                      P_aui=interval)
        # Binary mode: write() emits the array's raw bytes, which a
        # text-mode handle rejects with TypeError on Python 3.
        self.fileobj = open(filename, 'wb')

    def write(self, data):
        """Write the raw bytes of ``data`` to the file and flush."""
        # tobytes() is the supported spelling of the deprecated tostring().
        self.fileobj.write(data.tobytes())
        self.fileobj.flush()

    def __del__(self):
        # fileobj is missing if open() raised in __init__; avoid a second
        # error during garbage collection in that case.
        fileobj = getattr(self, 'fileobj', None)
        if fileobj is not None:
            fileobj.close()
class DynamicOverlapMonitor:
    """Evaluate band-band matrix elements for every entry of ``wfs.kpt_u``.

    ``overlap`` supplies ``setups`` and ``operator``.  For each k-point,
    ``operator.calculate_matrix_elements`` is called with the bands, the
    projections ``kpt.P_ani``, an identity function and the per-atom
    ``dO_ii`` matrices.  Calling the instance runs ``update`` and advances
    ``niter`` by ``interval``.  ``write`` does nothing here; subclasses
    override it.
    """

    def __init__(self, wfs, overlap, interval=1):
        self.wfs = wfs
        self.operator = overlap.operator
        self.setups = overlap.setups
        self.interval = interval
        self.niter = 0

    def __call__(self):
        self.update(self.wfs)
        self.niter += self.interval

    def update(self, wfs, calculate_P_ani=False):
        """Compute the matrices for all k-points and pass them to ``write``.

        Raises NotImplementedError when ``calculate_P_ani`` is true and
        ``wfs.kpt_u`` is non-empty.
        """
        # Operator applied to the bands before the matrix elements are
        # formed: the identity.
        def identity(x):
            return x

        # strictly serial XXX!
        matrices = []

        for kpt in wfs.kpt_u:
            bands = kpt.psit_nG
            projections = kpt.P_ani

            if calculate_P_ani:
                # wfs.pt.integrate(psit_nG, P_ani, kpt.q)
                raise NotImplementedError(
                    'In case you were wondering, TODO XXX')

            # Construct the overlap matrix:
            dS_aii = {}
            for a in projections:
                dS_aii[a] = self.setups[a].dO_ii
            matrices.append(
                self.operator.calculate_matrix_elements(
                    bands, projections, identity, dS_aii))

        self.write(np.array(matrices))

    def write(self, data):
        """Receive the computed array; a no-op in this base class."""
        pass
class DynamicOverlapWriter(DynamicOverlapMonitor):
    """Dynamic overlap monitor that writes each sample's raw bytes to a file.

    The file named ``filename`` is opened for writing on construction
    (truncating any existing content) and flushed after every sample.
    """

    def __init__(self, filename, wfs, overlap, interval=1):
        DynamicOverlapMonitor.__init__(self, wfs, overlap, interval)
        # Binary mode: write() emits the array's raw bytes, which a
        # text-mode handle rejects with TypeError on Python 3.
        self.fileobj = open(filename, 'wb')

    def write(self, data):
        """Write the raw bytes of ``data`` to the file and flush."""
        # tobytes() is the supported spelling of the deprecated tostring().
        self.fileobj.write(data.tobytes())
        self.fileobj.flush()

    def __del__(self):
        # fileobj is missing if open() raised in __init__; avoid a second
        # error during garbage collection in that case.
        fileobj = getattr(self, 'fileobj', None)
        if fileobj is not None:
            fileobj.close()