Coverage for gpaw/wavefunctions/arrays.py: 80%
282 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-09 00:21 +0000
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-09 00:21 +0000
1import numpy as np
3from gpaw.matrix import Matrix, create_distribution
class MatrixInFile:
    """Placeholder for a matrix whose contents have not been loaded yet.

    Only metadata (shape, dtype and distribution) is kept in memory; the
    ``array`` attribute is the file-backed data object supplied by the
    caller.
    """

    def __init__(self, M, N, dtype, data, dist):
        # Handle to the data as it lives in the file (nothing is read here).
        self.array = data
        self.dtype = dtype
        self.shape = M, N
        self.dist = create_distribution(M, N, *dist)
class ArrayWaveFunctions:
    """Block of wave functions stored as the rows of a matrix.

    The data is held either in memory as a ``Matrix`` (``in_memory`` is
    True) or as a file-backed object wrapped in ``MatrixInFile`` until
    ``read_from_file`` is called.  Subclasses set ``comm``, ``dv`` and
    ``kpt`` and provide ``array``, ``new()`` and ``_distribute()``, which
    the methods below rely on.
    """

    def __init__(self, M, N, dtype, data, dist, collinear):
        """Create storage for ``M`` bands with ``N`` coefficients each.

        For non-collinear wave functions every band holds two components,
        so the number of columns is doubled.  ``data`` that is None or a
        NumPy array is wrapped in an in-memory ``Matrix``; any other object
        is treated as file-backed data.
        """
        self.collinear = collinear
        if not collinear:
            N *= 2
        if data is None or isinstance(data, np.ndarray):
            self.matrix = Matrix(M, N, dtype, data, dist)
            self.in_memory = True
        else:
            self.matrix = MatrixInFile(M, N, dtype, data, dist)
            self.in_memory = False
        self.comm = None  # set by subclasses to their grid communicator
        self.dtype = self.matrix.dtype

    def __len__(self):
        """Total number of bands, as reported by the underlying matrix."""
        return len(self.matrix)

    def multiply(self, alpha, opa, b, opb, beta, c, symmetric):
        """Multiply ``self`` with ``b`` into ``c`` via ``Matrix.multiply``.

        Presumably ``c = alpha * opa(self) @ opb(b) + beta * c`` (gemm-like
        semantics of ``Matrix.multiply`` -- TODO confirm).  When ``opa`` is
        'N' and a communicator with more than one rank is present, ``c`` is
        flagged ``'a sum is needed'`` because each rank only contributed its
        own part of the data.
        """
        self.matrix.multiply(alpha, opa, b.matrix, opb, beta, c, symmetric)
        if opa == 'N' and self.comm:
            if self.comm.size > 1:
                c.comm = self.comm
                c.state = 'a sum is needed'
            # Tuple membership: a substring test would also accept '' or 'TC'.
            assert opb in ('T', 'C') and b.comm is self.comm

    def matrix_elements(self, other=None, out=None, symmetric=False, cc=False,
                        operator=None, result=None, serial=False):
        """Compute matrix elements of these wave functions.

        other:
            None for matrix elements of ``self`` with itself (requires
            ``symmetric``; ``operator`` and ``result`` are forwarded to
            ``operate_and_multiply``), another ``ArrayWaveFunctions``
            object, or an object with an ``integrate(array, P_ani, kpt)``
            method (presumably localized functions, as in ``add``), in
            which case ``out`` must be a mapping whose items are forwarded.
        out:
            Result.  In the wave-function case a ``Matrix`` distributed like
            ``self.matrix`` is created when ``out`` is None.
        cc:
            Must be True in the wave-function case and False otherwise.
        serial:
            With ``other`` given, do a plain local product instead of
            exchanging band blocks between ranks.

        Returns ``out``.
        """
        if other is None or isinstance(other, ArrayWaveFunctions):
            if out is None:
                # Explicit None test: ``other or self`` would pick ``self``
                # whenever ``other`` has zero bands (len() == 0 is falsy).
                ncols = len(self if other is None else other)
                out = Matrix(len(self), ncols, dtype=self.dtype,
                             dist=(self.matrix.dist.comm,
                                   self.matrix.dist.rows,
                                   self.matrix.dist.columns))
            assert cc
            if other is None:
                assert symmetric
                operate_and_multiply(self, self.dv, out, operator, result)
            elif not serial:
                assert not symmetric
                operate_and_multiply_not_symmetric(self, self.dv, out,
                                                   other)
            else:
                self.multiply(self.dv, 'N', other, 'C', 0.0, out, symmetric)
        else:
            assert not cc
            P_ani = {a: P_ni for a, P_ni in out.items()}
            other.integrate(self.array, P_ani, self.kpt)
        return out

    def add(self, lfc, coefs):
        """Call ``lfc.add`` on this array with a dict copy of ``coefs``."""
        lfc.add(self.array, dict(coefs.items()), self.kpt)

    def apply(self, func, out=None):
        """Call ``func(self.array, out.array)`` and return ``out``.

        When ``out`` is None a fresh object from ``self.new()`` is used.
        """
        if out is None:
            # Don't test truthiness: an ``out`` with zero bands is falsy
            # but was explicitly requested by the caller.
            out = self.new()
        func(self.array, out.array)
        return out

    def __setitem__(self, i, x):
        """Let ``x`` evaluate itself into this matrix (``i`` is ignored)."""
        x.eval(self.matrix)

    def __iadd__(self, other):
        """Let ``other`` evaluate itself into this matrix with factor 1.0."""
        other.eval(self.matrix, 1.0)
        return self

    def eval(self, matrix):
        """Copy this object's matrix data into ``matrix`` in place."""
        matrix.array[:] = self.matrix.array

    def read_from_file(self):
        """Read wave functions from file into memory."""
        matrix = Matrix(*self.matrix.shape,
                        dtype=self.dtype, dist=self.matrix.dist)
        # Read band by band to save memory.  Bands are assumed to be
        # block-distributed over matrix.dist.comm in chunks of blocksize.
        rows = matrix.dist.rows
        blocksize = (matrix.shape[0] + rows - 1) // rows
        for myn, psit_G in enumerate(matrix.array):
            n = matrix.dist.comm.rank * blocksize + myn
            if self.comm.rank == 0:
                # Only rank 0 touches the file; convert between complex and
                # float storage if the file dtype differs from ours.
                big_psit_G = self.array[n]
                if big_psit_G.dtype == complex and self.dtype == float:
                    big_psit_G = big_psit_G.view(float)
                elif big_psit_G.dtype == float and self.dtype == complex:
                    big_psit_G = np.asarray(big_psit_G, complex)
            else:
                big_psit_G = None
            self._distribute(big_psit_G, psit_G)
        self.matrix = matrix
        self.in_memory = True
class UniformGridWaveFunctions(ArrayWaveFunctions):
    """Wave functions sampled on the real-space uniform grid ``gd``.

    Each band is one row of ``self.matrix``; ``array`` exposes the local
    rows reshaped to ``(bands,) + tuple(gd.n_c)``.
    """

    def __init__(self, nbands, gd, dtype=None, data=None, kpt=None, dist=None,
                 spin=0, collinear=True):
        ngpts = gd.n_c.prod()
        ArrayWaveFunctions.__init__(self, nbands, ngpts, dtype, data, dist,
                                    collinear)

        M = self.matrix

        if data is None:
            # NOTE(review): presumably forces a contiguous array with the
            # local distribution shape -- TODO confirm intent.
            M.array = M.array.reshape(-1).reshape(M.dist.shape)

        self.myshape = (M.dist.shape[0],) + tuple(gd.n_c)
        self.gd = gd
        self.dv = gd.dv
        self.kpt = kpt
        self.spin = spin
        self.comm = gd.comm

    @property
    def array(self):
        """Local bands as a grid-shaped array, or the raw file object."""
        if self.in_memory:
            return self.matrix.array.reshape(self.myshape)
        else:
            return self.matrix.array

    def _distribute(self, big_psit_R, psit_R):
        """Distribute one global grid function into the local row."""
        self.gd.distribute(big_psit_R, psit_R.reshape(self.gd.n_c))

    def __repr__(self):
        # The base class has no custom __repr__ (object.__repr__ contains no
        # '('), so build the description from our own metadata instead.
        # matrix.shape works for both Matrix and MatrixInFile.
        nx, ny, nz = self.gd.get_size_of_global_array()
        return ('UniformGridWaveFunctions(nbands={}, dtype={}, '
                'gpts={}x{}x{})'.format(self.matrix.shape[0],
                                        np.dtype(self.dtype).name,
                                        nx, ny, nz))

    def new(self, buf=None, dist='inherit', nbands=None):
        """Return a new object on the same grid, k-point and spin.

        ``buf`` (optional) is passed on as the data; ``dist='inherit'``
        reuses this matrix's distribution; ``nbands`` defaults to
        ``len(self)``.
        """
        if dist == 'inherit':
            dist = self.matrix.dist
        if nbands is None:
            nbands = len(self)
        return UniformGridWaveFunctions(nbands,
                                        self.gd, self.dtype,
                                        buf,
                                        self.kpt, dist,
                                        self.spin, self.collinear)

    def view(self, n1, n2):
        """Wrap local bands ``n1:n2`` (passing the array slice as data)."""
        return UniformGridWaveFunctions(n2 - n1, self.gd, self.dtype,
                                        self.array[n1:n2],
                                        self.kpt, None,
                                        self.spin, self.collinear)

    def plot(self):
        """Plot band 0 along the last axis through the grid centre."""
        import matplotlib.pyplot as plt
        ax = plt.figure().add_subplot(111)
        a, b = self.array.shape[1:3]
        ax.plot(self.array[0, a // 2, b // 2])
        plt.show()
class PlaneWaveExpansionWaveFunctions(ArrayWaveFunctions):
    """Wave functions expanded in the plane waves described by ``pd``.

    Input data must be complex.  With ``dtype == float`` the coefficients
    are stored as pairs of floats (twice as many columns) and ``array``
    exposes them through a complex view.  Non-collinear wave functions hold
    two components per band.
    """

    def __init__(self, nbands, pd, dtype=None, data=None, kpt=0, dist=None,
                 spin=0, collinear=True):
        # Number of plane-wave coefficients for this k-point (presumably the
        # count local to this rank, given the "my" prefix -- TODO confirm).
        ng = ng0 = pd.myng_q[kpt]
        if data is not None:
            assert data.dtype == complex
        if dtype == float:
            # Complex coefficients stored as (real, imag) float pairs.
            ng *= 2
            if isinstance(data, np.ndarray):
                data = data.view(float)

        ArrayWaveFunctions.__init__(self, nbands, ng, dtype, data, dist,
                                    collinear)
        self.pd = pd
        self.gd = pd.gd
        self.comm = pd.gd.comm
        self.dv = pd.gd.dv / pd.gd.N_c.prod()
        self.kpt = kpt
        self.spin = spin
        if collinear:
            self.myshape = (self.matrix.dist.shape[0], ng0)
        else:
            self.myshape = (self.matrix.dist.shape[0], 2, ng0)

    @property
    def array(self):
        """Local coefficients.

        The raw file object when not in memory, a complex view for float
        storage, otherwise the matrix data reshaped to ``myshape``.
        """
        if not self.in_memory:
            return self.matrix.array
        elif self.dtype == float:
            return self.matrix.array.view(complex)
        else:
            return self.matrix.array.reshape(self.myshape)

    def _distribute(self, big_psit_G, psit_G):
        """Scatter one band into the local row ``psit_G``.

        ``big_psit_G`` may be None on ranks that did not read the data
        (``read_from_file`` passes None on every rank but 0); it is forwarded
        unchanged to ``pd.scatter`` in that case.
        """
        if self.collinear:
            if self.dtype == float:
                if big_psit_G is not None:
                    big_psit_G = big_psit_G.view(complex)
                psit_G = psit_G.view(complex)
            psit_G[:] = self.pd.scatter(big_psit_G, self.kpt)
        else:
            psit_sG = psit_G.reshape((2, -1))
            for s in range(2):
                # Indexing None would raise on non-root ranks, so only
                # select a component when the data is actually present.
                big_G = None if big_psit_G is None else big_psit_G[s]
                psit_sG[s] = self.pd.scatter(big_G, self.kpt)

    def matrix_elements(self, other=None, out=None, symmetric=False, cc=False,
                        operator=None, result=None, serial=False):
        """Compute matrix elements with wave functions or projectors.

        ``other`` is None (self with self; requires ``symmetric``), another
        ``ArrayWaveFunctions`` object (``cc`` must be True), or an object
        with ``integrate(array, P_ani, kpt)`` (``cc`` must be False and
        ``out`` must be a mapping).  In the serial wave-function case with
        float storage, the product is scaled by ``2 * dv`` and the
        column-0 contribution is subtracted once on rank 0 of ``gd.comm``
        (presumably the G=0 term, which must not be doubled).

        Returns ``out``.
        """
        if other is None or isinstance(other, ArrayWaveFunctions):
            if out is None:
                # Explicit None test: ``other or self`` would pick ``self``
                # whenever ``other`` has zero bands (len() == 0 is falsy).
                ncols = len(self if other is None else other)
                out = Matrix(len(self), ncols, dtype=self.dtype,
                             dist=(self.matrix.dist.comm,
                                   self.matrix.dist.rows,
                                   self.matrix.dist.columns))
            assert cc
            if other is None:
                assert symmetric
                operate_and_multiply(self, self.dv, out, operator, result)
            elif not serial:
                assert not symmetric
                operate_and_multiply_not_symmetric(self, self.dv, out,
                                                   other)
            elif self.dtype == complex:
                self.matrix.multiply(self.dv, 'N', other.matrix, 'C',
                                     0.0, out, symmetric)
            else:
                self.matrix.multiply(2 * self.dv, 'N', other.matrix, 'T',
                                     0.0, out, symmetric)
                if self.gd.comm.rank == 0:
                    correction = np.outer(self.matrix.array[:, 0],
                                          other.matrix.array[:, 0])
                    if symmetric:
                        out.array -= 0.5 * self.dv * (correction +
                                                      correction.T)
                    else:
                        out.array -= self.dv * correction
        else:
            assert not cc
            P_ani = {a: P_ni for a, P_ni in out.items()}
            other.integrate(self.array, P_ani, self.kpt)
        return out

    def new(self, buf=None, dist='inherit', nbands=None):
        """Return a new object with the same ``pd``, k-point, spin and
        collinearity.

        ``buf`` (optional) is flattened, truncated and reshaped to this
        object's local ``array`` shape before being used as data.
        ``dist='inherit'`` reuses this matrix's distribution; ``nbands``
        defaults to ``len(self)``.
        """
        if buf is not None:
            array = self.array
            buf = buf.ravel()[:array.size]
            buf.shape = array.shape
        if dist == 'inherit':
            dist = self.matrix.dist
        if nbands is None:
            nbands = len(self)
        return PlaneWaveExpansionWaveFunctions(nbands,
                                               self.pd, self.dtype,
                                               buf,
                                               self.kpt, dist,
                                               self.spin, self.collinear)

    def view(self, n1, n2):
        """Wrap local bands ``n1:n2`` (passing the array slice as data)."""
        return PlaneWaveExpansionWaveFunctions(n2 - n1, self.pd, self.dtype,
                                               self.array[n1:n2],
                                               self.kpt, None,
                                               self.spin, self.collinear)
def operate_and_multiply(psit1, dv, out, operator, psit2):
    """Fill ``out`` with the symmetric self-overlap-type matrix of ``psit1``.

    Rows of ``out`` belong to this rank's bands; columns span all ``N``
    bands.  Bands are block-distributed over ``psit1.matrix.dist.comm`` in
    chunks of ``n``.  Blocks of ``psit1`` are exchanged with non-blocking
    send/receive so that only about half of the column blocks are computed
    locally (``half + 1`` rounds); the remaining blocks are received at the
    end as conjugate transposes of blocks computed on other ranks.

    If ``operator`` is given it is called as
    ``operator(psit1.array, psit2.array)`` before the first product
    (presumably writing the operator applied to psit1 into psit2 -- TODO
    confirm); otherwise ``psit2`` is taken to be ``psit1`` itself.

    ``dv`` is not used in this body; each ``matrix_elements`` call uses the
    wave functions' own ``dv``.
    """
    if psit1.comm:
        if psit2 is not None:
            assert psit2.comm is psit1.comm
        if psit1.comm.size > 1:
            # Each rank holds only its part of the data, so the result
            # still has to be summed over psit1.comm.
            out.comm = psit1.comm
            out.state = 'a sum is needed'

    comm = psit1.matrix.dist.comm
    N = len(psit1)
    # Bands per rank (ceil(N / size)); min(..., N) below clips the last block.
    n = (N + comm.size - 1) // comm.size
    mynbands = len(psit1.matrix.array)

    # Two receive buffers that are swapped every round, so one can be
    # filled while the other is being used.
    buf1 = psit1.new(nbands=n, dist=None)
    buf2 = psit1.new(nbands=n, dist=None)
    half = comm.size // 2
    psit = psit1.view(0, mynbands)
    if psit2 is not None:
        psit2 = psit2.view(0, mynbands)

    for r in range(half + 1):
        rrequest = None
        srequest = None

        if r < half:
            # Round r: receive the block of rank - r - 1 and send our own
            # block to rank + r + 1, overlapping with the product below.
            srank = (comm.rank + r + 1) % comm.size
            rrank = (comm.rank - r - 1) % comm.size
            # For even comm.size, the last exchange pairs ranks exactly
            # half apart (srank == rrank); then only ranks >= half receive
            # and only ranks < half send.
            skip = (comm.size % 2 == 0 and r == half - 1)
            n1 = min(rrank * n, N)
            n2 = min(n1 + n, N)
            if not (skip and comm.rank < half) and n2 > n1:
                rrequest = comm.receive(buf1.array[:n2 - n1], rrank, 11, False)
            if not (skip and comm.rank >= half) and len(psit1.array) > 0:
                srequest = comm.send(psit1.array, srank, 11, False)

        if r == 0:
            if operator:
                operator(psit1.array, psit2.array)
            else:
                psit2 = psit

        # Ranks < half received nothing in the skipped exchange above, so
        # for even comm.size they have no block to process in the last round.
        if not (comm.size % 2 == 0 and r == half and comm.rank < half):
            m12 = psit2.matrix_elements(psit, symmetric=(r == 0), cc=True,
                                        serial=True)
            # Column range of the rank whose block ``psit`` currently holds.
            n1 = min(((comm.rank - r) % comm.size) * n, N)
            n2 = min(n1 + n, N)
            out.array[:, n1:n2] = m12.array[:, :n2 - n1]

        if rrequest:
            comm.wait(rrequest)
        if srequest:
            comm.wait(srequest)

        psit = buf1
        buf1, buf2 = buf2, buf1

    # Fill the column blocks that were not computed locally: rank ``row``
    # sends the conjugate transpose of its (row, column) block to rank
    # ``column``, which stores it as its (column, row) block.
    requests = []
    blocks = []
    nrows = (comm.size - 1) // 2
    for row in range(nrows):
        for column in range(comm.size - nrows + row, comm.size):
            if comm.rank == row:
                n1 = min(column * n, N)
                n2 = min(n1 + n, N)
                if mynbands > 0 and n2 > n1:
                    requests.append(
                        comm.send(out.array[:, n1:n2].T.conj().copy(),
                                  column, 12, False))
            elif comm.rank == column:
                n1 = min(row * n, N)
                n2 = min(n1 + n, N)
                if mynbands > 0 and n2 > n1:
                    block = np.empty((mynbands, n2 - n1), out.dtype)
                    blocks.append((n1, n2, block))
                    requests.append(comm.receive(block, row, 12, False))

    comm.waitall(requests)
    for n1, n2, block in blocks:
        out.array[:, n1:n2] = block
def operate_and_multiply_not_symmetric(psit1, dv, out, psit2):
    """Fill ``out`` with matrix elements between ``psit1`` and ``psit2``.

    Rows of ``out`` belong to this rank's ``psit1`` bands; columns span all
    ``N`` bands of ``psit2``.  No symmetry is exploited: in each of
    ``comm.size`` rounds every rank sends its own ``psit2`` block to
    ``rank + r + 1``, receives the block of ``rank - r - 1`` and computes
    one column block with the block received in the previous round.

    ``psit2`` must not be None here (``psit2.view`` is called
    unconditionally).  ``dv`` is not used in this body; the
    ``matrix_elements`` call uses the wave functions' own ``dv``.
    """
    if psit1.comm:
        if psit2 is not None:
            assert psit2.comm is psit1.comm
        if psit1.comm.size > 1:
            # Each rank holds only its part of the data, so the result
            # still has to be summed over psit1.comm.
            out.comm = psit1.comm
            out.state = 'a sum is needed'

    comm = psit1.matrix.dist.comm
    N = len(psit1)
    # Bands per rank (ceil(N / size)); min(..., N) below clips the last block.
    n = (N + comm.size - 1) // comm.size
    mynbands = len(psit1.matrix.array)

    # Two receive buffers that are swapped every round, so one can be
    # filled while the other is being used.
    buf1 = psit1.new(nbands=n, dist=None)
    buf2 = psit1.new(nbands=n, dist=None)

    psit1 = psit1.view(0, mynbands)
    psit = psit2.view(0, mynbands)
    for r in range(comm.size):
        rrequest = None
        srequest = None

        if r < comm.size - 1:
            srank = (comm.rank + r + 1) % comm.size
            rrank = (comm.rank - r - 1) % comm.size
            n1 = min(rrank * n, N)
            n2 = min(n1 + n, N)
            if n2 > n1:
                rrequest = comm.receive(buf1.array[:n2 - n1], rrank, 11, False)
            # NOTE(review): the guard checks psit1's local band count while
            # psit2's data is sent; presumably both share one distribution.
            if len(psit1.array) > 0:
                srequest = comm.send(psit2.array, srank, 11, False)

        m12 = psit1.matrix_elements(psit, cc=True, serial=True)
        # Column range of the rank whose block ``psit`` currently holds.
        n1 = min(((comm.rank - r) % comm.size) * n, N)
        n2 = min(n1 + n, N)
        out.array[:, n1:n2] = m12.array[:, :n2 - n1]

        if rrequest:
            comm.wait(rrequest)
        if srequest:
            comm.wait(srequest)

        psit = buf1
        buf1, buf2 = buf2, buf1