Coverage for gpaw/lrtddft2/k_matrix.py: 95%
250 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-14 00:18 +0000
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-14 00:18 +0000
1import os
2import glob
3import datetime
5import numpy as np
7import gpaw.mpi
8from gpaw.utilities import pack_density
9from gpaw.lrtddft2.eta import QuadraticETA
class Kmatrix:
    """Coupling (K) matrix of LR-TDDFT in the basis of KS singles.

    Elements K_{ip,jq} = <ip|f_Hxc|jq> are computed row by row by
    ``calculate`` and appended to per-rank text files. Rows already
    listed in the ``.ready_rows.`` files are skipped, so an interrupted
    calculation can be resumed. ``read_indices`` and ``read_values``
    gather the stored elements back into the local block of a 2D
    (ScaLapack-style) distributed matrix.

    Only the spin-unpolarized case is implemented (see FIXME notes).
    """

    def __init__(self, ks_singles, xc, deriv_scale=1e-5):
        """Set up the K-matrix for the given KS singles.

        Parameters
        ----------
        ks_singles:
            KS-singles object providing ``basefilename``, ``lr_comms``,
            ``calc`` and ``kss_list``.
        xc:
            XC object on which ``calculate`` and ``calculate_paw_correction``
            are called for the finite-difference f_xc.
        deriv_scale:
            Step h of the central finite difference used for f_xc.
        """
        self.basefilename = ks_singles.basefilename
        self.ks_singles = ks_singles
        self.lr_comms = self.ks_singles.lr_comms
        self.calc = self.ks_singles.calc
        self.xc = xc
        # [occ_ind, unocc_ind] pairs whose rows are already on disk
        self.ready_indices = []
        # local block of the K-matrix, filled by read_values()
        self.values = None
        self.K_matrix_ready = False
        self.deriv_scale = deriv_scale
        self.file_format = 0  # 0 = 'Casida', 1 = 'K-matrix'

        self.dVxct_gip_2 = None  # temporary for finite difference

        # Prefactors of the Hartree and XC kernel terms (1.0 = full term);
        # presumably meant for experiments such as switching a term off.
        self.fH_pre = 1.0
        self.fxc_pre = 1.0

    def initialize(self):
        """Mark the K-matrix as not yet calculated."""
        self.K_matrix_ready = False

    @staticmethod
    def _parse_ready_rows(data):
        """Return the ``[i, p]`` pairs listed in ready-rows text *data*.

        Blank or whitespace-only lines are skipped.
        """
        indices = []
        for line in data.splitlines():
            fields = line.split()
            if not fields:
                continue
            indices.append([int(fields[0]), int(fields[1])])
        return indices

    def read_indices(self):
        """Append to ``ready_indices`` all rows recorded as done on disk.

        Rank 0 of ``parent_comm`` reads every ``<basefilename>.ready_rows.*``
        file and broadcasts the text. Every rank therefore ends up with the
        same list.
        """
        data = None
        if self.lr_comms.parent_comm.rank == 0:
            chunks = []
            for ready_file in sorted(
                    glob.glob(self.basefilename + '.ready_rows.*')):
                if os.path.isfile(ready_file):
                    with open(ready_file, 'r', 1024 * 1024) as fd:
                        chunks.append(fd.read())
            # Join with '\n' so a file without a trailing newline cannot
            # merge its last line with the next file's first line.
            data = '\n'.join(chunks)

        data = gpaw.mpi.broadcast_string(data,
                                         root=0,
                                         comm=self.lr_comms.parent_comm)
        self.ready_indices.extend(self._parse_ready_rows(data))

    def _to_k_value(self, Kvalue, kss_ip, kss_jq):
        """Convert a value read from a K-matrix file to a K-matrix element.

        'K-matrix' files (format 1) store K directly. Old 'Casida' files
        (format 0) store 2 sqrt(de_ip de_jq df_ip df_jq) K, and that
        factor is divided out here.

        Raises
        ------
        RuntimeError
            If ``self.file_format`` is neither 0 nor 1.
        """
        if self.file_format == 1:
            return Kvalue
        if self.file_format == 0:
            return Kvalue / (
                2. * np.sqrt(kss_ip.energy_diff * kss_jq.energy_diff *
                             kss_ip.pop_diff * kss_jq.pop_diff))
        raise RuntimeError('Invalid K-matrix file format')

    def read_values(self):
        """ Read K-matrix (not the Casida form) ready for 2D ScaLapack array

        Each rank of ``parent_comm`` parses a disjoint subset of the
        ``<basefilename>.K_matrix.*`` files. It sends every element, and its
        transpose (only the lower triangle is stored), to the rank owning
        it. Each rank then assembles its local ``(nlrow, nlcol)`` block into
        ``self.values``.

        Raises
        ------
        RuntimeError
            If some local element was not found in any file, or the file
            format is invalid.
        """
        comm = self.lr_comms.parent_comm
        kss_list = self.ks_singles.kss_list

        # Create indexing: (occ, unocc) -> global matrix index,
        # and count local rows/cols
        nlrow = 0  # local rows
        nlcol = 0  # local cols
        self.lr_comms.index_map = {}
        for ip, kss in enumerate(kss_list):
            self.lr_comms.index_map[(kss.occ_ind, kss.unocc_ind)] = ip
            if self.lr_comms.get_local_eh_index(ip) is not None:
                nlrow += 1
            if self.lr_comms.get_local_dd_index(ip) is not None:
                nlcol += 1
        index_map = self.lr_comms.index_map

        # Lines destined for each receiving rank (read by this rank)
        elem_lists = {proc: [] for proc in range(comm.size)}

        comm.barrier()

        # Read every comm.size-th K-matrix file, starting from comm.rank.
        # glob() returns files in arbitrary order, so sort: all ranks must
        # agree on the numbering or files get skipped / read twice.
        K_files = sorted(glob.glob(self.basefilename + '.K_matrix.*'))
        for K_fn in K_files[comm.rank::comm.size]:
            with open(K_fn, 'r', 1024 * 1024) as fd:
                lines = fd.read().splitlines()
            for line in lines:
                # blank lines carry no data (and have no line[0])
                if not line.strip():
                    continue
                if line.startswith('#'):
                    if line.startswith('# K-matrix file'):
                        self.file_format = 1
                    continue

                elems = line.split()
                ip = index_map.get((int(elems[0]), int(elems[1])))
                jq = index_map.get((int(elems[2]), int(elems[3])))
                if ip is None or jq is None:
                    continue

                # where to send
                proc = self.lr_comms.get_matrix_elem_proc_and_index(
                    ip, jq)[0]
                elem_lists[proc].append(line + '\n')

                if ip == jq:
                    continue

                # where to send transposed
                proc = self.lr_comms.get_matrix_elem_proc_and_index(
                    jq, ip)[0]
                elem_lists[proc].append(line + '\n')

        for proc in range(comm.size):
            elem_lists[proc] = ''.join(elem_lists[proc])

        comm.barrier()

        # One 'K-matrix'-format file on any rank switches all ranks to it
        self.file_format = comm.max_scalar(self.file_format)

        # send and receive elem_list
        alltoall_dict = gpaw.mpi.alltoallv_string(elem_lists, comm)
        del elem_lists  # ready for garbage collection
        local_elem_list = ''.join(alltoall_dict.values())
        del alltoall_dict  # ready for garbage collection

        # Matrix build; NaN marks elements that were never filled
        K_matrix = np.full((nlrow, nlcol), np.nan)
        for line in local_elem_list.splitlines():
            elems = line.split()
            ip = index_map.get((int(elems[0]), int(elems[1])))
            jq = index_map.get((int(elems[2]), int(elems[3])))
            Kvalue = float(elems[4])

            # if not in index map, ignore
            if ip is None or jq is None:
                continue

            value = self._to_k_value(Kvalue, kss_list[ip], kss_list[jq])

            # if (ip,jq) on this proc
            lip = self.lr_comms.get_local_eh_index(ip)
            ljq = self.lr_comms.get_local_dd_index(jq)
            if lip is not None and ljq is not None:
                K_matrix[lip, ljq] = value

            # if (jq,ip) on this proc: the stored triangle is mirrored
            ljq = self.lr_comms.get_local_eh_index(jq)
            lip = self.lr_comms.get_local_dd_index(ip)
            if lip is not None and ljq is not None:
                K_matrix[ljq, lip] = value

        del local_elem_list  # ready for garbage collection

        # If any NaNs found, we did not read all matrix elements... BAD
        if np.isnan(K_matrix).any():
            raise RuntimeError(
                'Not all required K-matrix elements could be found.')

        self.values = K_matrix

    def _row_is_pending(self, ip, kss, ready):
        """True if row *ip* is local to this rank and not in *ready*."""
        return (self.lr_comms.get_local_eh_index(ip) is not None and
                (kss.occ_ind, kss.unocc_ind) not in ready)

    # FIXME: implement spin polarized
    def calculate(self):
        """Compute and store every missing row owned by this rank.

        For each local KS single ip not yet in ``ready_indices``, the
        lower-triangle elements K_{ip,jq} (jq <= ip) are computed. On
        ``dd_comm`` rank 0 they are appended to
        ``<basefilename>.K_matrix.<rank>of<size>``. The pair is then
        appended to ``.ready_rows.``, and progress/ETA lines go to ``.log.``
        (rank and size are those of ``eh_comm``). Files are closed even if
        the calculation raises.
        """
        # Check if already done before allocating
        if self.K_matrix_ready:
            return

        # set mirror of ready_indices for O(1) membership tests
        ready = {tuple(ind) for ind in self.ready_indices}

        # number of rows still to compute (used for the ETA)
        nrows = sum(1 for ip, kss in enumerate(self.ks_singles.kss_list)
                    if self._row_is_pending(ip, kss, ready))
        self.K_matrix_ready = (nrows == 0)
        if self.K_matrix_ready:
            return

        #######################################################################
        # saving CORRECT K-matrix, not its "Casida form" like in alpha version
        #######################################################################
        suffix = '%06dof%06d' % (self.lr_comms.eh_comm.rank,
                                 self.lr_comms.eh_comm.size)
        Kfn = self.basefilename + '.K_matrix.' + suffix
        rrfn = self.basefilename + '.ready_rows.' + suffix
        logfn = self.basefilename + '.log.' + suffix

        opened = []
        try:
            # Open only on dd_comm root
            if self.lr_comms.dd_comm.rank == 0:
                self.Kfile = open(Kfn, 'a+')
                opened.append(self.Kfile)
                self.ready_file = open(rrfn, 'a+')
                opened.append(self.ready_file)
                self.log_file = open(logfn, 'a+')
                opened.append(self.log_file)
                self.Kfile.write('# K-matrix file\n')

            self._calculate_rows(nrows, ready)
        finally:
            for fd in opened:
                fd.close()

    def _calculate_rows(self, nrows, ready):
        """Compute the pending rows; write them on dd_comm rank 0.

        Expects ``self.Kfile``, ``self.ready_file`` and ``self.log_file``
        to be open on dd_comm rank 0. Appends each finished pair to
        ``self.ready_indices`` and to *ready*.
        """
        density = self.calc.density
        finegd = density.finegd
        kss_list = self.ks_singles.kss_list
        is_writer = self.lr_comms.dd_comm.rank == 0

        self.poisson = self.calc.hamiltonian.poisson

        # Allocate grids for densities and potentials
        dnt_Gip = self.calc.wfs.gd.empty()
        dnt_gip = finegd.empty()
        drhot_gip = finegd.empty()
        dnt_Gjq = self.calc.wfs.gd.empty()
        dnt_gjq = finegd.empty()
        drhot_gjq = finegd.empty()

        nspins = density.nt_sg.shape[0]
        self.nt_g = finegd.zeros(nspins)

        dVht_gip = finegd.empty()
        dVxct_gip = finegd.zeros(nspins)

        # Init ETA
        self.matrix_eta = QuadraticETA()

        # Outer loop over KS singles
        for ip, kss_ip in enumerate(kss_list):
            if not self._row_is_pending(ip, kss_ip, ready):
                continue
            i = kss_ip.occ_ind
            p = kss_ip.unocc_ind

            # ETA
            if is_writer:
                self.matrix_eta.update()
                # add 21% extra time (1.1**2 = 1.21)
                eta = self.matrix_eta.eta((nrows + 1) * 1.1)
                self.log_file.write(
                    'Calculating pair %5d => %5d ( %s, ETA %9.1lfs )\n' %
                    (i, p, str(datetime.datetime.now()), eta))
                self.log_file.flush()

            # Pair density
            dnt_Gip[:] = 0.0
            dnt_gip[:] = 0.0
            drhot_gip[:] = 0.0
            kss_ip.calculate_pair_density(dnt_Gip, dnt_gip, drhot_gip)

            # Smooth Hartree "pair" potential
            # for compensated pair density drhot_gip
            dVht_gip[:] = 0.0
            self.poisson.solve(dVht_gip, drhot_gip, charge=None)

            # smooth XC "pair" potential
            self.calculate_smooth_xc_pair_potential(kss_ip, dnt_gip,
                                                    dVxct_gip)
            # paw term of XC "pair" potential
            I_asp = self.calculate_paw_xc_pair_potentials(kss_ip)

            # Inner loop over KS singles: only lower triangle (jq <= ip)
            K = []  # (j, q, value) of this row before writing to file
            for jq, kss_jq in enumerate(kss_list):
                if jq > ip:
                    break

                # Pair density dn_jq
                dnt_Gjq[:] = 0.0
                dnt_gjq[:] = 0.0
                drhot_gjq[:] = 0.0
                kss_jq.calculate_pair_density(dnt_Gjq, dnt_gjq, drhot_gjq)

                # Grid part: Hartree smooth part (compensated density
                # drhot_gjq -- RHOT_JQ HERE???) plus XC smooth part
                Ig = 0.0
                Ig += self.fH_pre * finegd.integrate(dVht_gip, drhot_gjq)
                Ig += self.fxc_pre * finegd.integrate(dVxct_gip, dnt_gjq)

                # Atomic corrections, summed over the dd_comm domains
                Ia = self.calculate_paw_fHXC_corrections(kss_ip, kss_jq,
                                                         I_asp)
                Ia = self.lr_comms.dd_comm.sum_scalar(Ia)

                # K_ip,jq += <ip|fHxc|jq>
                K.append((kss_jq.occ_ind, kss_jq.unocc_ind, Ig + Ia))

            # Write only on dd_comm root
            if is_writer:
                # Write i p j q Kipjq (format: -2345.789012345678), lower
                # triangle only. NOTE(review): the XC integral is taken with
                # a spin-resolved dVxct_gip, so the value presumably has one
                # entry per spin; [0] keeps spin 0 (spin-unpolarized only).
                for j, q, Kipjq in K:
                    self.Kfile.write("%5d %5d %5d %5d %22.16lf\n" %
                                     (i, p, j, q, Kipjq[0]))
                self.Kfile.flush()  # flush K-matrix before ready_rows

                # Write and flush ready rows
                self.ready_file.write("%d %d\n" % (i, p))
                self.ready_file.flush()

            # Update ready rows before continuing
            self.ready_indices.append([i, p])
            ready.add((i, p))

    def calculate_smooth_xc_pair_potential(self, kss_ip, dnt_gip, dVxct_gip):
        """Apply the finite-difference smooth f_xc to a pair density.

        Writes (v_xc[n + h dn] - v_xc[n - h dn]) / (2 h) into
        ``dVxct_gip``. Here h is ``self.deriv_scale``, dn is ``dnt_gip``,
        and dn is added to spin channel ``kss_ip.spin_ind`` only.
        ``self.nt_g`` (allocated in ``calculate``) is used as scratch
        density. Its other spin channels are not touched here.
        """
        if self.dVxct_gip_2 is None:
            self.dVxct_gip_2 = self.calc.density.finegd.zeros(
                self.calc.density.nt_sg.shape[0])
        dVxct_gip_2 = self.dVxct_gip_2

        finegd = self.calc.density.finegd
        nt_sg = self.calc.density.nt_sg
        s = kss_ip.spin_ind

        # finite difference plus, vxc+ = vxc(n + deriv_scale * dn)
        self.nt_g[s][:] = self.deriv_scale * dnt_gip
        self.nt_g[s][:] += nt_sg[s]
        dVxct_gip[:] = 0.0
        self.xc.calculate(finegd, self.nt_g, dVxct_gip)

        # finite difference minus, vxc- = vxc(n - deriv_scale * dn)
        self.nt_g[s][:] = -self.deriv_scale * dnt_gip
        self.nt_g[s][:] += nt_sg[s]
        dVxct_gip_2[:] = 0.0
        self.xc.calculate(finegd, self.nt_g, dVxct_gip_2)

        # central difference: fxc dn ~= (vxc+ - vxc-) / 2h
        dVxct_gip -= dVxct_gip_2
        dVxct_gip *= 1. / (2. * self.deriv_scale)

    def calculate_paw_xc_pair_potentials(self, kss_ip):
        """Finite-difference PAW XC response to the atomic pair density of ip.

        For every atom a, the packed pair density matrix D_ip is built from
        the projections P_ni[i] and P_ni[p]. The result is

            I_asp[a] = (X[D + h D_ip] - X[D - h D_ip]) / (2 h),

        where X is the array filled in by
        ``xc.calculate_paw_correction(setup, D_sp, X)`` (presumably
        dE_xc/dD_sp) and h is ``self.deriv_scale``.

        FIXME, only spin unpolarized works.
        """
        i = kss_ip.occ_ind
        p = kss_ip.unocc_ind
        s = kss_ip.spin_ind
        wfs = self.calc.wfs
        D_asp = self.calc.density.D_asp
        P_ani = wfs.kpt_u[s].P_ani

        I_asp = {}
        # NOTE(review): atoms are taken from kpt_u[kpt_ind] but projections
        # from kpt_u[spin_ind]; these coincide only when both indices are
        # equal (e.g. spin-unpolarized, single k-point) -- confirm.
        for a in wfs.kpt_u[kss_ip.kpt_ind].P_ani:
            I_sp = np.zeros_like(D_asp[a])
            I_sp_2 = np.zeros_like(D_asp[a])

            Pip_ni = P_ani[a]
            Dip_p = pack_density(np.outer(Pip_ni[i], Pip_ni[p]))

            # finite difference plus
            D_sp = D_asp[a].copy()
            D_sp[s] += self.deriv_scale * Dip_p
            self.xc.calculate_paw_correction(wfs.setups[a], D_sp, I_sp)

            # finite difference minus
            D_sp_2 = D_asp[a].copy()
            D_sp_2[s] -= self.deriv_scale * Dip_p
            self.xc.calculate_paw_correction(wfs.setups[a], D_sp_2, I_sp_2)

            # finite difference
            I_asp[a] = (I_sp - I_sp_2) / (2. * self.deriv_scale)

        return I_asp

    def calculate_paw_fHXC_corrections(self, kss_ip, kss_jq, I_asp):
        """Return the atomic (PAW) part of <ip|f_Hxc|jq> on this domain.

        The sum runs over the atoms in ``kpt_u[kss_jq.spin_ind].P_ani``.
        It adds the Hartree term 2 D_jq . M_pp . D_ip and the XC term
        I_asp[a][spin_jq] . D_jq, scaled by ``fH_pre`` and ``fxc_pre``.
        The caller sums the result over ``dd_comm``.
        """
        i = kss_ip.occ_ind
        p = kss_ip.unocc_ind
        j = kss_jq.occ_ind
        q = kss_jq.unocc_ind
        P_ip_ani = self.calc.wfs.kpt_u[kss_ip.spin_ind].P_ani
        P_jq_ani = self.calc.wfs.kpt_u[kss_jq.spin_ind].P_ani

        Ia = 0.0
        for a in P_jq_ani:
            Pip_ni = P_ip_ani[a]
            Dip_p = pack_density(np.outer(Pip_ni[i], Pip_ni[p]))

            Pjq_ni = P_jq_ani[a]
            Djq_p = pack_density(np.outer(Pjq_ni[j], Pjq_ni[q]))

            # Hartree part. For the factor of two see appendix A of
            # J. Chem. Phys. 128, 244101 (2008):
            #     2 sum_prst P_ip P_jr C_prst P_ks P_qt
            C_pp = self.calc.wfs.setups[a].M_pp
            Ia += self.fH_pre * 2.0 * np.dot(Djq_p, np.dot(C_pp, Dip_p))

            # XC part, CHECK THIS JQ EVERWHERE!!!
            Ia += self.fxc_pre * np.dot(I_asp[a][kss_jq.spin_ind], Djq_p)

        return Ia