Coverage for gpaw/lrtddft2/lr_communicators.py: 80%
61 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-14 00:18 +0000
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-14 00:18 +0000
1import gpaw.mpi
class LrCommunicators:
    """Communicators for a parallel LrTDDFT2 calculation.

    The parent communicator is split into two sub-communicators:

    * ``dd_comm`` (domain decomposition): parent ranks in the same block of
      ``dd_size`` consecutive ranks.
    * ``eh_comm`` (electron-hole pairs): parent ranks that have the same
      position ``rank % dd_size`` within their block.

    Electron-hole indices are distributed round-robin over ``eh_comm`` and
    domain-decomposition indices round-robin over ``dd_comm`` (see the
    ``get_*_index`` helpers).
    """

    def __init__(self, world, dd_size: 'int | None',
                 eh_size: 'int | None' = None):
        """Create communicators for LrTDDFT calculation.

        Parameters
        ----------
        world
            MPI parent communicator (usually ``gpaw.mpi.world``). If
            ``None``, no communicators are created here; they are set up
            later by :meth:`initialize`.
        dd_size
            Number of domains for domain decomposition (the size of
            ``dd_comm``). If ``None``, setup is deferred to
            :meth:`initialize` as well.
        eh_size
            Number of groups for parallelization over electron-hole pairs
            (the size of ``eh_comm``). Defaults to
            ``world.size // dd_size``.

        Raises
        ------
        RuntimeError
            If ``dd_size`` is not positive, or if
            ``world.size != dd_size * eh_size``.

        Note
        ----
        Sizes must match, i.e., world.size must be equal to
        dd_size x eh_size, e.g., 1024 = 64*16

        Tip
        ---
        Use enough processes for domain decomposition (dd_size) to fit
        everything (easily) into memory, and use the remaining processes
        for electron-hole pairs as K-matrix build is trivially parallel
        over them.

        Pass ``lr_comms.dd_comm`` to ground state calc when
        reading for LrTDDFT.

        Examples
        --------
        For 8 MPI processes::

          lr_comms = LrCommunicators(gpaw.mpi.world, 4, 2)
          txt = 'lr_%06d_%06d.txt' % (lr_comms.dd_comm.rank,
                                      lr_comms.eh_comm.rank)
          calc = GPAW('unocc.gpw', communicator=lr_comms.dd_comm)
          lr = LrTDDFT2(calc, lr_communicators=lr_comms, txt=txt)
        """
        # Set below, or later by initialize() when setup is deferred.
        self.parent_comm = None
        self.dd_comm = None
        self.eh_comm = None

        self.world = world
        self.dd_size = dd_size
        self.eh_size = eh_size

        if self.world is None or self.dd_size is None:
            return

        # Guard before the division below; a zero or negative dd_size
        # would otherwise give ZeroDivisionError or empty rank groups.
        if self.dd_size < 1:
            raise RuntimeError('Domain decomposition size (dd_size) must '
                               'be a positive integer')

        if self.eh_size is None:
            self.eh_size = self.world.size // self.dd_size

        self.parent_comm = self.world

        if self.world.size != self.dd_size * self.eh_size:
            raise RuntimeError('Domain decomposition processes (dd_size) '
                               'times electron-hole (eh_size) processes '
                               'does not match with total processes '
                               '(world size != dd_size * eh_size)')

        rank = self.world.rank
        # This rank's block of dd_size consecutive parent ranks. The block
        # never exceeds world.size, because world.size is a multiple of
        # dd_size (checked above).
        first = (rank // self.dd_size) * self.dd_size
        dd_ranks = list(range(first, first + self.dd_size))
        # Parent ranks at the same offset (rank % dd_size) in every block.
        eh_ranks = list(range(rank % self.dd_size, self.world.size,
                              self.dd_size))
        self.dd_comm = self.world.new_communicator(dd_ranks)
        self.eh_comm = self.world.new_communicator(eh_ranks)

    def initialize(self, calc):
        """Finish communicator setup and validate the result.

        If the communicators were not created by the constructor, they are
        taken from ``calc``. In that case ``dd_comm`` becomes
        ``calc.density.gd.comm``, ``parent_comm`` becomes its parent, and
        ``eh_comm`` is serial. If ``calc`` is ``None``, all three
        communicators are ``gpaw.mpi.serial_comm``.

        If the communicators already exist, checks that ``parent_comm`` is
        the parent of both ``eh_comm`` and ``dd_comm``.

        Raises
        ------
        RuntimeError
            If the communicators are inconsistent.
        """
        if self.parent_comm is None:
            if calc is not None:
                self.dd_comm = calc.density.gd.comm
                # NOTE(review): assumes gd.comm has a non-None ``parent``;
                # otherwise the size check below fails with AttributeError.
                self.parent_comm = self.dd_comm.parent
                if self.parent_comm.size != self.dd_comm.size:
                    raise RuntimeError(
                        'Invalid communicators in LrTDDFT2. Ground state '
                        'calculator domain decomposition communicator and '
                        'its parent (or actually its parent parent) has '
                        'different size. Please set up LrCommunicators '
                        'explicitly to avoid this. Or contact developers '
                        'if this is intentional.'
                    )
                self.eh_comm = gpaw.mpi.serial_comm
            else:
                self.parent_comm = gpaw.mpi.serial_comm
                self.dd_comm = gpaw.mpi.serial_comm
                self.eh_comm = gpaw.mpi.serial_comm
        else:
            # Check that parent_comm is valid
            if self.parent_comm != self.eh_comm.parent:
                raise RuntimeError(
                    'Invalid communicators in LrTDDFT2. LrTDDFT2 parent '
                    'communicator is not parent of electron-hole '
                    'communicator. Please set up LrCommunicators explicitly '
                    'to avoid this.')
            if self.parent_comm != self.dd_comm.parent:
                raise RuntimeError(
                    'Invalid communicators in LrTDDFT2. LrTDDFT2 parent '
                    'communicator is not parent of domain decomposition '
                    'communicator. Please set up LrCommunicators explicitly '
                    'to avoid this.')

    def get_local_eh_index(self, ip):
        """Return the local index of global electron-hole index ``ip``.

        Indices are dealt round-robin over ``eh_comm``. Returns ``None``
        if ``ip`` is owned by another rank of ``eh_comm``.
        """
        if ip % self.eh_comm.size != self.eh_comm.rank:
            return None
        return ip // self.eh_comm.size

    def get_local_dd_index(self, jq):
        """Return the local index of global domain index ``jq``.

        Indices are dealt round-robin over ``dd_comm``. Returns ``None``
        if ``jq`` is owned by another rank of ``dd_comm``.
        """
        if jq % self.dd_comm.size != self.dd_comm.rank:
            return None
        return jq // self.dd_comm.size

    def get_global_eh_index(self, lip):
        """Inverse of :meth:`get_local_eh_index` for this rank."""
        return lip * self.eh_comm.size + self.eh_comm.rank

    def get_global_dd_index(self, ljq):
        """Inverse of :meth:`get_local_dd_index` for this rank."""
        return ljq * self.dd_comm.size + self.dd_comm.rank

    def get_matrix_elem_proc_and_index(self, ip, jq):
        """Locate matrix element ``(ip, jq)`` in the distributed layout.

        Returns ``(proc, ehproc, ddproc, lip, ljq)``:

        * ``ehproc`` / ``ddproc``: the owning ranks in ``eh_comm`` and
          ``dd_comm``.
        * ``proc = ehproc * dd_size + ddproc``: the owning parent rank,
          matching the rank layout built in ``__init__``.
        * ``lip`` / ``ljq``: the local indices on that owner, consistent
          with :meth:`get_local_eh_index` and :meth:`get_local_dd_index`.
        """
        ehproc = ip % self.eh_comm.size
        ddproc = jq % self.dd_comm.size
        proc = ehproc * self.dd_comm.size + ddproc
        lip = ip // self.eh_comm.size
        # Column indices are distributed over dd_comm, so divide by its
        # size (not eh_comm's).
        ljq = jq // self.dd_comm.size
        return (proc, ehproc, ddproc, lip, ljq)