Coverage for gpaw/wavefunctions/arrays.py: 80%

282 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-09 00:21 +0000

1import numpy as np 

2 

3from gpaw.matrix import Matrix, create_distribution 

4 

5 

class MatrixInFile:
    """Placeholder for a matrix whose data still lives in a file.

    Exposes ``shape``, ``dtype``, ``array`` and ``dist`` so that wave-function
    objects can inspect it (and later call ``read_from_file``) before the
    coefficients are loaded into memory.
    """

    def __init__(self, M, N, dtype, data, dist):
        self.shape = (M, N)
        self.dtype = dtype
        self.array = data  # pointer to data in a file
        self.dist = create_distribution(M, N, *dist)

    def __len__(self):
        """Return the number of rows.

        Needed because wave-function objects compute their length as
        ``len(self.matrix)`` even when the data has not been read yet.
        """
        return self.shape[0]

12 

13 

class ArrayWaveFunctions:
    """Wave functions stored as the rows of a matrix.

    The coefficients are held either in memory (a ``Matrix``) or, until
    ``read_from_file()`` is called, behind a ``MatrixInFile`` handle.
    Subclasses supply ``array``, ``dv``, ``kpt``, ``new()``, ``view()`` and
    ``_distribute()``, which the methods below rely on.
    """

    def __init__(self, M, N, dtype, data, dist, collinear):
        """Set up the backing matrix.

        M, N: number of rows (bands) and columns; N is doubled when
            ``collinear`` is False.
        dtype: element type passed on to the matrix.
        data: None or a numpy array (data is then held in memory); any
            other object is treated as a reference to data in a file.
        dist: distribution description passed on to the matrix.
        collinear: False for wave functions with two components per band.
        """
        self.collinear = collinear
        if not collinear:
            N *= 2  # two components per band
        if data is None or isinstance(data, np.ndarray):
            self.matrix = Matrix(M, N, dtype, data, dist)
            self.in_memory = True
        else:
            self.matrix = MatrixInFile(M, N, dtype, data, dist)
            self.in_memory = False
        self.comm = None  # grid communicator; subclasses set it
        self.dtype = self.matrix.dtype

    def __len__(self):
        """Return ``len()`` of the backing matrix (used as band count)."""
        return len(self.matrix)

    def multiply(self, alpha, opa, b, opb, beta, c, symmetric):
        """Multiply ``self`` and ``b`` into ``c`` via ``Matrix.multiply``.

        With ``opa == 'N'`` the product contracts over the (possibly
        distributed) coefficient axis.  If the grid communicator has more
        than one rank, ``c`` then only holds a partial result and is
        flagged as needing a sum over ``self.comm``.
        """
        self.matrix.multiply(alpha, opa, b.matrix, opb, beta, c, symmetric)
        if opa == 'N' and self.comm:
            if self.comm.size > 1:
                c.comm = self.comm
                c.state = 'a sum is needed'
            assert opb in 'TC' and b.comm is self.comm

    def matrix_elements(self, other=None, out=None, symmetric=False, cc=False,
                        operator=None, result=None, serial=False):
        """Return matrix elements between these wave functions and ``other``.

        If ``other`` is None or another ``ArrayWaveFunctions``, ``cc`` must
        be True and the result is a ``Matrix`` (created when ``out`` is
        None):

        * ``other is None``: symmetric band-distributed product of ``self``
          with itself, optionally applying ``operator`` into ``result``.
        * ``serial=False``: non-symmetric band-distributed product.
        * ``serial=True``: local ``self.dv``-weighted product with
          ``other`` in ``'C'`` mode.

        Otherwise ``other`` must provide ``integrate(array, P_ani, kpt)``
        and ``out`` (supplied by the caller) must provide ``items()``.
        """
        if other is None or isinstance(other, ArrayWaveFunctions):
            if out is None:
                # Explicit None test: a 0-band `other` is falsy through
                # __len__ and must not silently be replaced by `self`.
                ncols = len(self if other is None else other)
                out = Matrix(len(self), ncols, dtype=self.dtype,
                             dist=(self.matrix.dist.comm,
                                   self.matrix.dist.rows,
                                   self.matrix.dist.columns))
            assert cc
            if other is None:
                assert symmetric
                operate_and_multiply(self, self.dv, out, operator, result)
            elif not serial:
                assert not symmetric
                operate_and_multiply_not_symmetric(self, self.dv, out,
                                                   other)
            else:
                self.multiply(self.dv, 'N', other, 'C', 0.0, out, symmetric)
        else:
            assert not cc
            P_ani = dict(out.items())
            other.integrate(self.array, P_ani, self.kpt)
        return out

    def add(self, lfc, coefs):
        """Add ``lfc`` functions weighted by ``coefs`` onto ``self.array``."""
        lfc.add(self.array, dict(coefs.items()), self.kpt)

    def apply(self, func, out=None):
        """Call ``func(self.array, out.array)`` and return ``out``.

        A new object from ``self.new()`` is used when ``out`` is None.
        """
        # `out or ...` would discard a valid but empty (0-band) `out`,
        # because __len__ makes it falsy.
        if out is None:
            out = self.new()
        func(self.array, out.array)
        return out

    def __setitem__(self, i, x):
        """Let ``x`` write itself into the matrix (the index is ignored)."""
        x.eval(self.matrix)

    def __iadd__(self, other):
        """Let ``other`` add itself into the matrix (``eval`` with 1.0)."""
        other.eval(self.matrix, 1.0)
        return self

    def eval(self, matrix):
        """Copy this object's coefficients into ``matrix.array``."""
        matrix.array[:] = self.matrix.array

    def read_from_file(self):
        """Read wave functions from file into memory."""
        matrix = Matrix(*self.matrix.shape,
                        dtype=self.dtype, dist=self.matrix.dist)
        # Read band by band to save memory
        rows = matrix.dist.rows
        blocksize = (matrix.shape[0] + rows - 1) // rows
        for myn, psit_G in enumerate(matrix.array):
            # Global index of this locally stored band.
            n = matrix.dist.comm.rank * blocksize + myn
            if self.comm.rank == 0:
                # Only grid rank 0 reads; convert between the stored and
                # requested dtype (real data is kept as float views).
                big_psit_G = self.array[n]
                if big_psit_G.dtype == complex and self.dtype == float:
                    big_psit_G = big_psit_G.view(float)
                elif big_psit_G.dtype == float and self.dtype == complex:
                    big_psit_G = np.asarray(big_psit_G, complex)
            else:
                big_psit_G = None
            self._distribute(big_psit_G, psit_G)
        self.matrix = matrix
        self.in_memory = True

101 

102 

class UniformGridWaveFunctions(ArrayWaveFunctions):
    """Wave functions given on a uniform real-space grid ``gd``."""

    def __init__(self, nbands, gd, dtype=None, data=None, kpt=None, dist=None,
                 spin=0, collinear=True):
        ngpts = gd.n_c.prod()
        ArrayWaveFunctions.__init__(self, nbands, ngpts, dtype, data, dist,
                                    collinear)

        M = self.matrix

        if data is None:
            # NOTE(review): reshapes the freshly allocated array to the
            # local distribution shape; purpose beyond that unclear.
            M.array = M.array.reshape(-1).reshape(M.dist.shape)

        # Local bands followed by the grid shape.
        self.myshape = (M.dist.shape[0],) + tuple(gd.n_c)
        self.gd = gd
        self.dv = gd.dv
        self.kpt = kpt
        self.spin = spin
        self.comm = gd.comm

    @property
    def array(self):
        """Local coefficients shaped ``myshape`` (raw file data if not read)."""
        if self.in_memory:
            return self.matrix.array.reshape(self.myshape)
        else:
            return self.matrix.array

    def _distribute(self, big_psit_R, psit_R):
        """Distribute one global grid function into the local ``psit_R``."""
        self.gd.distribute(big_psit_R, psit_R.reshape(self.gd.n_c))

    def __repr__(self):
        # The base class has no __repr__ of its own (object.__repr__ contains
        # no '('), so describe the backing matrix instead and strip its
        # "Name(...)" wrapper when there is one.
        s = repr(self.matrix)
        head, sep, tail = s.partition('(')
        if sep and tail.endswith(')'):
            s = tail[:-1]
        shape = self.gd.get_size_of_global_array()
        s = 'UniformGridWaveFunctions({}, gpts={}x{}x{})'.format(s, *shape)
        return s

    def new(self, buf=None, dist='inherit', nbands=None):
        """Return new wave functions on the same grid, k-point and spin."""
        if dist == 'inherit':
            dist = self.matrix.dist
        return UniformGridWaveFunctions(nbands or len(self),
                                        self.gd, self.dtype,
                                        buf,
                                        self.kpt, dist,
                                        self.spin)

    def view(self, n1, n2):
        """Return wave functions wrapping local bands ``n1:n2`` of ``array``."""
        return UniformGridWaveFunctions(n2 - n1, self.gd, self.dtype,
                                        self.array[n1:n2],
                                        self.kpt, None,
                                        self.spin)

    def plot(self):
        """Plot band 0 along the last axis through the grid centre."""
        import matplotlib.pyplot as plt
        ax = plt.figure().add_subplot(111)
        a, b, _ = self.array.shape[1:]
        ax.plot(self.array[0, a // 2, b // 2])
        plt.show()

159 

160 

class PlaneWaveExpansionWaveFunctions(ArrayWaveFunctions):
    """Wave functions given as plane-wave coefficients for k-point ``kpt``."""

    def __init__(self, nbands, pd, dtype=None, data=None, kpt=0, dist=None,
                 spin=0, collinear=True):
        # Presumably the number of locally stored plane waves for `kpt`.
        ng = ng0 = pd.myng_q[kpt]
        if data is not None:
            assert data.dtype == complex
        if dtype == float:
            # Real wave functions are stored as a float view of the complex
            # coefficients, so the matrix needs twice as many columns.
            ng *= 2
            if isinstance(data, np.ndarray):
                data = data.view(float)

        ArrayWaveFunctions.__init__(self, nbands, ng, dtype, data, dist,
                                    collinear)
        self.pd = pd
        self.gd = pd.gd
        self.comm = pd.gd.comm
        self.dv = pd.gd.dv / pd.gd.N_c.prod()
        self.kpt = kpt
        self.spin = spin
        if collinear:
            self.myshape = (self.matrix.dist.shape[0], ng0)
        else:
            self.myshape = (self.matrix.dist.shape[0], 2, ng0)

    @property
    def array(self):
        """Local complex coefficients (raw file data if not read yet)."""
        if not self.in_memory:
            return self.matrix.array
        elif self.dtype == float:
            return self.matrix.array.view(complex)
        else:
            return self.matrix.array.reshape(self.myshape)

    def _distribute(self, big_psit_G, psit_G):
        """Scatter one band from ``big_psit_G`` into the local ``psit_G``.

        ``big_psit_G`` is None on ranks that do not hold the full band
        (``read_from_file`` only provides it on grid rank 0).
        """
        if self.collinear:
            if self.dtype == float:
                if big_psit_G is not None:
                    big_psit_G = big_psit_G.view(complex)
                psit_G = psit_G.view(complex)
            psit_G[:] = self.pd.scatter(big_psit_G, self.kpt)
        else:
            # Two components per band.  Index the global array only where
            # it exists; other ranks pass None to scatter, as above.
            psit_sG = psit_G.reshape((2, -1))
            for s in range(2):
                big_G = None if big_psit_G is None else big_psit_G[s]
                psit_sG[s] = self.pd.scatter(big_G, self.kpt)

    def matrix_elements(self, other=None, out=None, symmetric=False, cc=False,
                        operator=None, result=None, serial=False):
        """Plane-wave version of ``ArrayWaveFunctions.matrix_elements``.

        For real (float) wave functions the serial product is
        ``2 * dv * A B^T`` minus a correction built from column 0 (the
        real part of the first coefficient) on grid rank 0.  NOTE(review):
        presumably only half of the G vectors are stored for real wave
        functions, so the first one would otherwise be counted twice.
        """
        if other is None or isinstance(other, ArrayWaveFunctions):
            if out is None:
                # Explicit None test: a 0-band `other` is falsy via __len__.
                ncols = len(self if other is None else other)
                out = Matrix(len(self), ncols, dtype=self.dtype,
                             dist=(self.matrix.dist.comm,
                                   self.matrix.dist.rows,
                                   self.matrix.dist.columns))
            assert cc
            if other is None:
                assert symmetric
                operate_and_multiply(self, self.dv, out, operator, result)
            elif not serial:
                assert not symmetric
                operate_and_multiply_not_symmetric(self, self.dv, out,
                                                   other)
            elif self.dtype == complex:
                self.matrix.multiply(self.dv, 'N', other.matrix, 'C',
                                     0.0, out, symmetric)
            else:
                self.matrix.multiply(2 * self.dv, 'N', other.matrix, 'T',
                                     0.0, out, symmetric)
                if self.gd.comm.rank == 0:
                    correction = np.outer(self.matrix.array[:, 0],
                                          other.matrix.array[:, 0])
                    if symmetric:
                        out.array -= 0.5 * self.dv * (correction +
                                                      correction.T)
                    else:
                        out.array -= self.dv * correction
        else:
            assert not cc
            P_ani = dict(out.items())
            other.integrate(self.array, P_ani, self.kpt)
        return out

    def new(self, buf=None, dist='inherit', nbands=None):
        """Return new wave functions, optionally backed by ``buf``.

        NOTE(review): ``buf`` is shaped like ``self.array`` even when
        ``nbands`` differs from ``len(self)`` -- confirm callers only pass
        a buffer together with a matching band count.
        """
        if buf is not None:
            array = self.array
            buf = buf.ravel()[:array.size]
            buf.shape = array.shape
        if dist == 'inherit':
            dist = self.matrix.dist
        return PlaneWaveExpansionWaveFunctions(nbands or len(self),
                                               self.pd, self.dtype,
                                               buf,
                                               self.kpt, dist,
                                               self.spin, self.collinear)

    def view(self, n1, n2):
        """Return wave functions wrapping local bands ``n1:n2`` of ``array``."""
        return PlaneWaveExpansionWaveFunctions(n2 - n1, self.pd, self.dtype,
                                               self.array[n1:n2],
                                               self.kpt, None,
                                               self.spin, self.collinear)

260 

261 

def operate_and_multiply(psit1, dv, out, operator, psit2):
    """Fill ``out`` with symmetric band-band matrix elements of ``psit1``.

    Bands are distributed over ``psit1.matrix.dist.comm`` in blocks of
    ``n = ceil(N / comm.size)`` bands.  Each rank computes its local rows
    with ``psit2.matrix_elements(psit, ...)``, where ``psit`` is first its
    own block and then blocks received from other ranks.  Only
    ``half + 1`` rounds are done.  A final exchange of conjugate-transposed
    blocks then completes, on every rank, all blocks of its block row up to
    and including the diagonal block (as follows from the indexing below).

    psit1: distributed wave functions.
    dv: not used here; the weight comes from ``matrix_elements``, which
        reads ``self.dv``.
    out: Matrix receiving the result; ``out.array`` has the local bands as
        rows and all N bands as columns.
    operator: optional callable ``operator(in_array, out_array)`` applied
        once to the local bands of ``psit1``, writing into ``psit2``.
    psit2: target of ``operator`` (required when ``operator`` is given);
        replaced by the plain local bands otherwise.
    """
    if psit1.comm:
        if psit2 is not None:
            assert psit2.comm is psit1.comm
        if psit1.comm.size > 1:
            # Partial result over the grid; the caller must still sum it.
            out.comm = psit1.comm
            out.state = 'a sum is needed'

    comm = psit1.matrix.dist.comm  # band communicator
    N = len(psit1)
    n = (N + comm.size - 1) // comm.size  # bands per rank (rounded up)
    mynbands = len(psit1.matrix.array)

    # Two receive buffers, swapped every round (one is being received into
    # while the other is used for computing).
    buf1 = psit1.new(nbands=n, dist=None)
    buf2 = psit1.new(nbands=n, dist=None)
    half = comm.size // 2
    psit = psit1.view(0, mynbands)
    if psit2 is not None:
        psit2 = psit2.view(0, mynbands)

    for r in range(half + 1):
        rrequest = None
        srequest = None

        if r < half:
            # Round r: send own bands to rank+r+1 and receive the bands of
            # rank-r-1 into buf1 (non-blocking).
            srank = (comm.rank + r + 1) % comm.size
            rrank = (comm.rank - r - 1) % comm.size
            # With an even number of ranks the pair at distance size/2
            # would be covered twice; in that last communication round only
            # ranks >= half receive and only ranks < half send.
            skip = (comm.size % 2 == 0 and r == half - 1)
            n1 = min(rrank * n, N)  # global band range owned by rrank
            n2 = min(n1 + n, N)
            if not (skip and comm.rank < half) and n2 > n1:
                rrequest = comm.receive(buf1.array[:n2 - n1], rrank, 11, False)
            if not (skip and comm.rank >= half) and len(psit1.array) > 0:
                srequest = comm.send(psit1.array, srank, 11, False)

        if r == 0:
            # Left-hand side for all rounds: operator applied to the own
            # bands, or the own bands themselves.
            if operator:
                operator(psit1.array, psit2.array)
            else:
                psit2 = psit

        # For even sizes, ranks < half received nothing in the previous
        # round (see `skip`), so they have nothing to compute in round half.
        if not (comm.size % 2 == 0 and r == half and comm.rank < half):
            # The diagonal block (r == 0) is computed as symmetric.
            m12 = psit2.matrix_elements(psit, symmetric=(r == 0), cc=True,
                                        serial=True)
            # Columns of the rank whose bands are currently in `psit`.
            n1 = min(((comm.rank - r) % comm.size) * n, N)
            n2 = min(n1 + n, N)
            out.array[:, n1:n2] = m12.array[:, :n2 - n1]

        # Finish this round's transfers before the buffers are reused.
        if rrequest:
            comm.wait(rrequest)
        if srequest:
            comm.wait(srequest)

        # The bands just received are used in the next round; the other
        # buffer becomes the next receive target.
        psit = buf1
        buf1, buf2 = buf2, buf1

    # The first `nrows` ranks computed wrap-around blocks with the last
    # ranks' bands.  Send their conjugate transposes to those ranks, which
    # store them as the mirrored block in their own rows.
    requests = []
    blocks = []
    nrows = (comm.size - 1) // 2
    for row in range(nrows):
        for column in range(comm.size - nrows + row, comm.size):
            if comm.rank == row:
                n1 = min(column * n, N)
                n2 = min(n1 + n, N)
                if mynbands > 0 and n2 > n1:
                    # .copy() gives a contiguous send buffer (presumably
                    # required by comm.send for the transposed view).
                    requests.append(
                        comm.send(out.array[:, n1:n2].T.conj().copy(),
                                  column, 12, False))
            elif comm.rank == column:
                n1 = min(row * n, N)
                n2 = min(n1 + n, N)
                if mynbands > 0 and n2 > n1:
                    block = np.empty((mynbands, n2 - n1), out.dtype)
                    blocks.append((n1, n2, block))
                    requests.append(comm.receive(block, row, 12, False))

    comm.waitall(requests)
    for n1, n2, block in blocks:
        out.array[:, n1:n2] = block

341 

342 

def operate_and_multiply_not_symmetric(psit1, dv, out, psit2):
    """Fill ``out`` with band-band matrix elements between psit1 and psit2.

    Bands are distributed over ``psit1.matrix.dist.comm`` in blocks of
    ``n = ceil(N / comm.size)`` bands.  In ``comm.size`` rounds, each rank
    computes ``psit1.matrix_elements(psit, ...)`` for its local psit1 bands
    against its own psit2 block and then against the psit2 blocks of every
    other rank, received one per round.  No symmetry is exploited.

    psit1: distributed wave functions (rows of the result).
    dv: not used here; the weight comes from ``matrix_elements``, which
        reads ``self.dv``.
    out: Matrix receiving the result; ``out.array`` has the local bands as
        rows and all N bands as columns.
    psit2: wave functions whose bands are passed around (columns).  It is
        required: ``psit2.view`` is called unconditionally below, even
        though the comm check tolerates None.
    """
    if psit1.comm:
        if psit2 is not None:
            assert psit2.comm is psit1.comm
        if psit1.comm.size > 1:
            # Partial result over the grid; the caller must still sum it.
            out.comm = psit1.comm
            out.state = 'a sum is needed'

    comm = psit1.matrix.dist.comm  # band communicator
    N = len(psit1)
    n = (N + comm.size - 1) // comm.size  # bands per rank (rounded up)
    mynbands = len(psit1.matrix.array)

    # Two receive buffers, swapped every round.
    buf1 = psit1.new(nbands=n, dist=None)
    buf2 = psit1.new(nbands=n, dist=None)

    psit1 = psit1.view(0, mynbands)
    psit = psit2.view(0, mynbands)
    for r in range(comm.size):
        rrequest = None
        srequest = None

        if r < comm.size - 1:
            # Send own psit2 bands to rank+r+1; receive those of rank-r-1
            # into buf1 for the next round (non-blocking).
            srank = (comm.rank + r + 1) % comm.size
            rrank = (comm.rank - r - 1) % comm.size
            n1 = min(rrank * n, N)
            n2 = min(n1 + n, N)
            if n2 > n1:
                rrequest = comm.receive(buf1.array[:n2 - n1], rrank, 11, False)
            # NOTE(review): the guard tests psit1's local size but psit2 is
            # sent; equivalent only if both share the band distribution.
            if len(psit1.array) > 0:
                srequest = comm.send(psit2.array, srank, 11, False)

        m12 = psit1.matrix_elements(psit, cc=True, serial=True)
        # Columns of the rank whose bands are currently in `psit`.
        n1 = min(((comm.rank - r) % comm.size) * n, N)
        n2 = min(n1 + n, N)
        out.array[:, n1:n2] = m12.array[:, :n2 - n1]

        # Finish this round's transfers before the buffers are reused.
        if rrequest:
            comm.wait(rrequest)
        if srequest:
            comm.wait(srequest)

        psit = buf1
        buf1, buf2 = buf2, buf1