Coverage for gpaw/lrtddft2/k_matrix.py: 95%

250 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-14 00:18 +0000

1import os 

2import glob 

3import datetime 

4 

5import numpy as np 

6 

7import gpaw.mpi 

8from gpaw.utilities import pack_density 

9from gpaw.lrtddft2.eta import QuadraticETA 

10 

11 

class Kmatrix:
    """Coupling (K) matrix between Kohn-Sham single excitations.

    The elements ``K_ip,jq = <ip|f_Hxc|jq>`` are computed row by row in
    :meth:`calculate` and appended to ``<basefilename>.K_matrix.*`` files.
    Every finished row is recorded in ``<basefilename>.ready_rows.*``; rows
    listed there (see :meth:`read_indices`) are skipped by later calls to
    :meth:`calculate`.  :meth:`read_values` reads the files back and stores
    the local block of the matrix in ``self.values``.

    FIXME: only the spin-unpolarized case is implemented.
    """

    def __init__(self, ks_singles, xc, deriv_scale=1e-5):
        """Store references to the Kohn-Sham singles and the xc functional.

        Parameters
        ----------
        ks_singles:
            Object providing ``basefilename``, ``lr_comms``, ``calc`` and
            ``kss_list`` (the list of single excitations).
        xc:
            Exchange-correlation object; its ``calculate`` and
            ``calculate_paw_correction`` methods are used.
        deriv_scale:
            Step length of the central finite difference used for the
            xc kernel.
        """
        self.basefilename = ks_singles.basefilename
        self.ks_singles = ks_singles
        self.lr_comms = self.ks_singles.lr_comms
        self.calc = self.ks_singles.calc
        self.xc = xc
        # [occ_ind, unocc_ind] pairs of rows that are already on disk
        self.ready_indices = []
        # local block of the K-matrix, filled by read_values()
        self.values = None
        self.K_matrix_ready = False
        self.deriv_scale = deriv_scale
        self.file_format = 0  # 0 = 'Casida', 1 = 'K-matrix'

        self.dVxct_gip_2 = None  # temporary for finite difference

        # Expert knobs: prefactors multiplying the Hartree and the xc
        # contributions of every matrix element (1.0 = full kernel).
        self.fH_pre = 1.0
        self.fxc_pre = 1.0

    def initialize(self):
        """Mark the matrix as not ready so calculate() re-checks the rows."""
        self.K_matrix_ready = False

    @staticmethod
    def _parse_ready_rows(data):
        """Parse ready-rows text into a list of [occ_ind, unocc_ind]."""
        pairs = []
        for line in data.splitlines():
            fields = line.split()
            if not fields:
                # blank separator/trailing lines carry no row
                continue
            pairs.append([int(fields[0]), int(fields[1])])
        return pairs

    def read_indices(self):
        """Read all ready-rows files and extend ``self.ready_indices``.

        Rank 0 of the parent communicator reads the files and the text is
        broadcast to every rank.
        """
        data = None
        if self.lr_comms.parent_comm.rank == 0:
            chunks = []
            pattern = self.basefilename + '.ready_rows.*'
            for ready_file in sorted(glob.glob(pattern)):
                if os.path.isfile(ready_file):
                    with open(ready_file, 'r', 1024 * 1024) as fd:
                        chunks.append(fd.read())
            # Join with a newline: a file lacking the final newline must
            # not fuse its last row with the first row of the next file.
            data = '\n'.join(chunks)

        data = gpaw.mpi.broadcast_string(data,
                                         root=0,
                                         comm=self.lr_comms.parent_comm)
        self.ready_indices.extend(self._parse_ready_rows(data))

    def _to_k_value(self, value, kss_ip, kss_jq):
        """Convert a stored number to a K-matrix element.

        File format 1 stores the K-matrix itself; format 0 stores the
        'Casida form', which is divided by
        ``2 * sqrt(e_ip * e_jq * f_ip * f_jq)`` (energy and population
        differences of the two excitations).

        Raises
        ------
        RuntimeError
            If ``self.file_format`` is neither 0 nor 1.
        """
        if self.file_format == 1:
            return value
        if self.file_format == 0:
            return value / (
                2. * np.sqrt(kss_ip.energy_diff * kss_jq.energy_diff *
                             kss_ip.pop_diff * kss_jq.pop_diff))
        raise RuntimeError('Invalid K-matrix file format')

    def read_values(self):
        """Read K-matrix (not the Casida form) ready for 2D ScaLapack array.

        The K-matrix files are divided between the ranks of the parent
        communicator, every element is routed to the rank(s) owning it
        (and its transpose), and the local block is stored in
        ``self.values``.

        Raises
        ------
        RuntimeError
            If a required element is missing or the file format is invalid.
        """
        comms = self.lr_comms
        parent_comm = comms.parent_comm

        nlrow = 0  # local rows
        nlcol = 0  # local cols

        # (i, p) to matrix index map
        comms.index_map = {}
        for ip, kss in enumerate(self.ks_singles.kss_list):
            comms.index_map[(kss.occ_ind, kss.unocc_ind)] = ip
            if comms.get_local_eh_index(ip) is not None:
                nlrow += 1
            if comms.get_local_dd_index(ip) is not None:
                nlcol += 1
        index_map = comms.index_map

        # lines read by this rank, grouped by the rank that needs them
        elem_lists = {proc: [] for proc in range(parent_comm.size)}

        parent_comm.barrier()

        # glob() returns names in arbitrary order.  The round-robin split
        # below only works if every rank numbers the files identically,
        # hence the sort.
        K_files = sorted(glob.glob(self.basefilename + '.K_matrix.*'))
        for k, K_fn in enumerate(K_files):
            # read every "parent comm size"th file, starting from
            # parent comm rank
            if k % parent_comm.size != parent_comm.rank:
                continue

            with open(K_fn, 'r', 1024 * 1024) as fd:
                lines = fd.read().splitlines()

            for line in lines:
                if not line.strip():
                    continue
                if line.startswith('#'):
                    if line.startswith('# K-matrix file'):
                        self.file_format = 1
                    continue

                elems = line.split()
                ip = index_map.get((int(elems[0]), int(elems[1])))
                jq = index_map.get((int(elems[2]), int(elems[3])))
                if ip is None or jq is None:
                    continue

                # where to send
                proc = comms.get_matrix_elem_proc_and_index(ip, jq)[0]
                elem_lists[proc].append(line + '\n')

                if ip == jq:
                    continue

                # where to send transposed
                proc = comms.get_matrix_elem_proc_and_index(jq, ip)[0]
                elem_lists[proc].append(line + '\n')

        elem_lists = {proc: ''.join(chunks)
                      for proc, chunks in elem_lists.items()}

        parent_comm.barrier()

        # a single K-matrix header seen by any rank selects format 1
        self.file_format = parent_comm.max_scalar(self.file_format)

        # send and receive elem_list
        alltoall_dict = gpaw.mpi.alltoallv_string(elem_lists, parent_comm)
        # ready for garbage collection
        del elem_lists
        local_elem_list = ''.join(alltoall_dict.values())
        # ready for garbage collection
        del alltoall_dict

        # Matrix build; NaN marks elements that were never assigned
        K_matrix = np.full((nlrow, nlcol), np.nan)
        for line in local_elem_list.splitlines():
            fields = line.split()
            ip = index_map.get((int(fields[0]), int(fields[1])))
            jq = index_map.get((int(fields[2]), int(fields[3])))
            Kvalue = float(fields[4])

            # if not in index map, ignore
            if ip is None or jq is None:
                continue

            kss_ip = self.ks_singles.kss_list[ip]
            kss_jq = self.ks_singles.kss_list[jq]

            # local position of (ip, jq) and of its transpose (jq, ip)
            row = comms.get_local_eh_index(ip)
            col = comms.get_local_dd_index(jq)
            t_row = comms.get_local_eh_index(jq)
            t_col = comms.get_local_dd_index(ip)
            direct = row is not None and col is not None
            transposed = t_row is not None and t_col is not None
            if not (direct or transposed):
                continue

            value = self._to_k_value(Kvalue, kss_ip, kss_jq)
            if direct:
                K_matrix[row, col] = value
            if transposed:
                K_matrix[t_row, t_col] = value

        # ready for garbage collection
        del local_elem_list

        # If any NaNs found, we did not read all matrix elements... BAD
        if np.isnan(K_matrix).any():
            raise RuntimeError(
                'Not all required K-matrix elements could be found.')

        self.values = K_matrix

    def _pending_rows(self):
        """Return ``(ip, kss_ip)`` for local rows not yet marked ready."""
        done = {tuple(pair) for pair in self.ready_indices}
        return [(ip, kss_ip)
                for ip, kss_ip in enumerate(self.ks_singles.kss_list)
                if self.lr_comms.get_local_eh_index(ip) is not None
                and (kss_ip.occ_ind, kss_ip.unocc_ind) not in done]

    # FIXME: implement spin polarized
    def calculate(self):
        """Calculate the missing local rows and append them to the files.

        Returns immediately when the matrix is already flagged ready or
        when every local row is listed in ``self.ready_indices``.  The
        K-matrix, ready-rows and log files are opened (append mode) only
        on rank 0 of ``dd_comm`` and are closed even if the calculation
        raises.
        """
        # Check if already done before allocating
        if self.K_matrix_ready:
            return

        pending = self._pending_rows()
        self.K_matrix_ready = not pending
        if self.K_matrix_ready:
            return

        # Saving the CORRECT K-matrix, not its "Casida form" like in the
        # alpha version.
        suffix = '%06dof%06d' % (self.lr_comms.eh_comm.rank,
                                 self.lr_comms.eh_comm.size)
        Kfn = self.basefilename + '.K_matrix.' + suffix
        rrfn = self.basefilename + '.ready_rows.' + suffix
        logfn = self.basefilename + '.log.' + suffix

        opened = []
        try:
            # Open only on dd_comm root
            if self.lr_comms.dd_comm.rank == 0:
                self.Kfile = open(Kfn, 'a+')
                opened.append(self.Kfile)
                self.ready_file = open(rrfn, 'a+')
                opened.append(self.ready_file)
                self.log_file = open(logfn, 'a+')
                opened.append(self.log_file)
                self.Kfile.write('# K-matrix file\n')

            self._calculate_rows(pending)
        finally:
            for fd in opened:
                fd.close()

    def _calculate_rows(self, pending):
        """Compute and store the rows listed in *pending*."""
        density = self.calc.density
        finegd = density.finegd
        dd_comm = self.lr_comms.dd_comm
        nrows = len(pending)  # number of rows for timings

        self.poisson = self.calc.hamiltonian.poisson

        # Allocate grids for densities and potentials
        dnt_Gip = self.calc.wfs.gd.empty()
        dnt_gip = finegd.empty()
        drhot_gip = finegd.empty()
        dnt_Gjq = self.calc.wfs.gd.empty()
        dnt_gjq = finegd.empty()
        drhot_gjq = finegd.empty()

        self.nt_g = finegd.zeros(density.nt_sg.shape[0])

        dVht_gip = finegd.empty()
        dVxct_gip = finegd.zeros(density.nt_sg.shape[0])

        # Init ETA
        self.matrix_eta = QuadraticETA()

        finished = set()  # rows completed during this call
        for ip, kss_ip in pending:
            i = kss_ip.occ_ind
            p = kss_ip.unocc_ind
            if (i, p) in finished:
                continue

            # ETA
            if dd_comm.rank == 0:
                self.matrix_eta.update()
                # add 21% extra time (1.1**2 = 1.21)
                eta = self.matrix_eta.eta((nrows + 1) * 1.1)
                self.log_file.write(
                    'Calculating pair %5d => %5d ( %s, ETA %9.1lfs )\n' %
                    (i, p, str(datetime.datetime.now()), eta))
                self.log_file.flush()

            # Pair density
            dnt_Gip[:] = 0.0
            dnt_gip[:] = 0.0
            drhot_gip[:] = 0.0
            kss_ip.calculate_pair_density(dnt_Gip, dnt_gip, drhot_gip)

            # Smooth Hartree "pair" potential
            # for compensated pair density drhot_gip
            dVht_gip[:] = 0.0
            self.poisson.solve(dVht_gip, drhot_gip, charge=None)

            # smooth XC "pair" potential
            self.calculate_smooth_xc_pair_potential(kss_ip, dnt_gip,
                                                    dVxct_gip)
            # paw term of XC "pair" potential
            I_asp = self.calculate_paw_xc_pair_potentials(kss_ip)

            row = []  # storage for row before writing to file
            for jq, kss_jq in enumerate(self.ks_singles.kss_list):
                # Only lower triangle; jq only grows from here on
                if jq > ip:
                    break

                # Pair density dn_jq
                dnt_Gjq[:] = 0.0
                dnt_gjq[:] = 0.0
                drhot_gjq[:] = 0.0
                kss_jq.calculate_pair_density(dnt_Gjq, dnt_gjq, drhot_gjq)

                # Grid part: Hartree smooth part (RHOT_JQ HERE???) plus
                # XC smooth part
                Ig = (self.fH_pre * finegd.integrate(dVht_gip, drhot_gjq) +
                      self.fxc_pre * finegd.integrate(dVxct_gip, dnt_gjq))

                # Atomic corrections, summed over the domain
                Ia = self.calculate_paw_fHXC_corrections(kss_ip, kss_jq,
                                                         I_asp)
                Ia = dd_comm.sum_scalar(Ia)

                # K_ip,jq += <ip|fHxc|jq>
                row.append([i, p, kss_jq.occ_ind, kss_jq.unocc_ind,
                            Ig + Ia])

            # Write only on dd_comm root
            if dd_comm.rank == 0:
                self._write_row(kss_ip, row)

            # Update ready rows before continuing
            self.ready_indices.append([i, p])
            finished.add((i, p))

    def _write_row(self, kss_ip, row):
        """Append one row to the K-matrix file, then mark it ready."""
        # Write i p j q Kipjq (format: -2345.789012345678);
        # only the lower triangle of the K-matrix is stored.
        for i, p, j, q, Kipjq in row:
            # Kipjq is indexable; only its first component is stored
            self.Kfile.write("%5d %5d %5d %5d %22.16lf\n" %
                             (i, p, j, q, Kipjq[0]))
        self.Kfile.flush()  # flush K-matrix before ready_rows

        # Write and flush ready rows
        self.ready_file.write("%d %d\n" %
                              (kss_ip.occ_ind, kss_ip.unocc_ind))
        self.ready_file.flush()

    def calculate_smooth_xc_pair_potential(self, kss_ip, dnt_gip, dVxct_gip):
        """Fill *dVxct_gip* with the smooth xc "pair" potential.

        Central finite difference of the xc potential around the
        ground-state density along the pair density *dnt_gip*:
        ``(vxc(n + h * dn) - vxc(n - h * dn)) / (2 * h)`` with
        ``h = self.deriv_scale``.  Uses ``self.nt_g``, which is allocated
        by :meth:`calculate`, as work space.
        """
        density = self.calc.density
        if self.dVxct_gip_2 is None:
            self.dVxct_gip_2 = density.finegd.zeros(density.nt_sg.shape[0])
        minus = self.dVxct_gip_2

        s = kss_ip.spin_ind
        step = self.deriv_scale

        # finite difference plus, vxc+ = vxc(n + deriv_scale * dn)
        self.nt_g[s][:] = step * dnt_gip
        self.nt_g[s][:] += density.nt_sg[s]
        dVxct_gip[:] = 0.0
        self.xc.calculate(density.finegd, self.nt_g, dVxct_gip)

        # finite difference minus, vxc- = vxc(n - deriv_scale * dn)
        self.nt_g[s][:] = -step * dnt_gip
        self.nt_g[s][:] += density.nt_sg[s]
        minus[:] = 0.0
        self.xc.calculate(density.finegd, self.nt_g, minus)

        # finite difference approx for fxc: (vxc+ - vxc-) / 2h
        dVxct_gip -= minus
        dVxct_gip *= 1. / (2. * step)

    def calculate_paw_xc_pair_potentials(self, kss_ip):
        """Return the PAW xc "pair" potentials of *kss_ip* for each atom.

        For every atom the xc PAW correction is differentiated by a
        central finite difference of the atomic density matrix along the
        pair density matrix of the excitation.

        Returns
        -------
        dict
            Atom index -> array shaped like ``density.D_asp[a]``.
        """
        # FIXME, only spin unpolarized works
        density = self.calc.density
        wfs = self.calc.wfs
        i = kss_ip.occ_ind
        p = kss_ip.unocc_ind
        s = kss_ip.spin_ind
        step = self.deriv_scale

        I_asp = {}
        # NOTE(review): atoms are taken from kpt_u[kpt_ind] while the
        # projections below come from kpt_u[spin_ind]; kept as in the
        # original -- confirm which index is intended.
        for a, _ in wfs.kpt_u[kss_ip.kpt_ind].P_ani.items():
            I_sp = np.zeros_like(density.D_asp[a])
            I_sp_2 = np.zeros_like(density.D_asp[a])

            Pip_ni = wfs.kpt_u[s].P_ani[a]
            Dip_p = pack_density(np.outer(Pip_ni[i], Pip_ni[p]))

            # finite difference plus
            D_sp = density.D_asp[a].copy()
            D_sp[s] += step * Dip_p
            self.xc.calculate_paw_correction(wfs.setups[a], D_sp, I_sp)

            # finite difference minus
            D_sp_2 = density.D_asp[a].copy()
            D_sp_2[s] -= step * Dip_p
            self.xc.calculate_paw_correction(wfs.setups[a], D_sp_2, I_sp_2)

            # finite difference
            I_asp[a] = (I_sp - I_sp_2) / (2. * step)

        return I_asp

    def calculate_paw_fHXC_corrections(self, kss_ip, kss_jq, I_asp):
        """Return the atomic (PAW) part of ``<ip|f_Hxc|jq>`` on this rank.

        *I_asp* is the dictionary returned by
        :meth:`calculate_paw_xc_pair_potentials` for *kss_ip*.  The value
        is a sum over the atoms present in the local ``P_ani``; the
        caller sums it over the domain communicator.
        """
        wfs = self.calc.wfs
        i = kss_ip.occ_ind
        p = kss_ip.unocc_ind
        j = kss_jq.occ_ind
        q = kss_jq.unocc_ind

        Ia = 0.0
        for a, _ in wfs.kpt_u[kss_jq.spin_ind].P_ani.items():
            Pip_ni = wfs.kpt_u[kss_ip.spin_ind].P_ani[a]
            Dip_p = pack_density(np.outer(Pip_ni[i], Pip_ni[p]))

            Pjq_ni = wfs.kpt_u[kss_jq.spin_ind].P_ani[a]
            Djq_p = pack_density(np.outer(Pjq_ni[j], Pjq_ni[q]))

            # Hartree part
            C_pp = wfs.setups[a].M_pp
            # why factor of two here?
            # see appendix A of J. Chem. Phys. 128, 244101 (2008)
            #
            # 2 sum_prst P   P   C     P   P
            #             ip  jr  prst  ks  qt
            Ia += self.fH_pre * 2.0 * np.dot(Djq_p, np.dot(C_pp, Dip_p))

            # XC part, CHECK THIS JQ EVERYWHERE!!!
            Ia += self.fxc_pre * np.dot(I_asp[a][kss_jq.spin_ind], Djq_p)

        return Ia

509 

510 return Ia