from typing import Optional

import gpaw.mpi


class LrCommunicators:
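    """Communicators for an LrTDDFT2 calculation.

    Splits a parent MPI communicator into domain decomposition groups
    and electron-hole pair groups.
    """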

    def __init__(self, world, dd_size: Optional[int],
                 eh_size: Optional[int] = None):
        """Create communicators for an LrTDDFT calculation.

        Parameters
        ----------
        world
            MPI parent communicator (usually ``gpaw.mpi.world``)
        dd_size
            Number of domains for domain decomposition
        eh_size
            Number of groups for parallelization over electron-hole pairs

        Note
        ----
        Sizes must match, i.e., ``world.size`` must be equal to
        ``dd_size * eh_size``, e.g., 1024 = 64 * 16.

        Tip
        ---
        Use enough processes for domain decomposition (``dd_size``) to
        fit everything (easily) into memory, and use the remaining
        processes for electron-hole pairs, as the K-matrix build is
        trivially parallel over them.

        Pass ``lr_comms.dd_comm`` to the ground-state calculator when
        reading it for LrTDDFT.

        Examples
        --------
        For 8 MPI processes::

            lr_comms = LrCommunicators(gpaw.mpi.world, 4, 2)
            txt = 'lr_%06d_%06d.txt' % (lr_comms.dd_comm.rank,
                                        lr_comms.eh_comm.rank)
            calc = GPAW('unocc.gpw', communicator=lr_comms.dd_comm)
            lr = LrTDDFT2(calc, lr_communicators=lr_comms, txt=txt)
        """

        self.parent_comm = None
        self.dd_comm = None
        self.eh_comm = None

        self.world = world
        self.dd_size = dd_size
        self.eh_size = eh_size

        if self.world is None:
            return
        if self.dd_size is None:
            return

        if self.eh_size is None:
            self.eh_size = self.world.size // self.dd_size

        self.parent_comm = self.world

        if self.world.size != self.dd_size * self.eh_size:
            raise RuntimeError('Domain decomposition processes (dd_size) '
                               'times electron-hole processes (eh_size) '
                               'does not match the total number of '
                               'processes (world.size != dd_size * eh_size)')

        dd_ranks = []
        eh_ranks = []
        for k in range(self.world.size):
            if k // self.dd_size == self.world.rank // self.dd_size:
                dd_ranks.append(k)
            if k % self.dd_size == self.world.rank % self.dd_size:
                eh_ranks.append(k)
        self.dd_comm = self.world.new_communicator(dd_ranks)
        self.eh_comm = self.world.new_communicator(eh_ranks)
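        # Ranks are grouped so that each consecutive block of dd_size
        # world ranks shares a domain decomposition communicator, while
        # ranks with the same offset within their block share an
        # electron-hole communicator. For example, with world.size = 8
        # and dd_size = 4 (so eh_size = 2):
        #   dd groups: {0, 1, 2, 3} and {4, 5, 6, 7}
        #   eh groups: {0, 4}, {1, 5}, {2, 6} and {3, 7}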

    def initialize(self, calc):
        if self.parent_comm is None:
            if calc is not None:
                self.dd_comm = calc.density.gd.comm
                self.parent_comm = self.dd_comm.parent
                if self.parent_comm.size != self.dd_comm.size:
                    raise RuntimeError(
                        'Invalid communicators in LrTDDFT2. The ground '
                        'state calculator domain decomposition '
                        'communicator and its parent (or actually its '
                        "parent's parent) have different sizes. Please "
                        'set up LrCommunicators explicitly to avoid '
                        'this, or contact the developers if this is '
                        'intentional.')
                self.eh_comm = gpaw.mpi.serial_comm
            else:
                self.parent_comm = gpaw.mpi.serial_comm
                self.dd_comm = gpaw.mpi.serial_comm
                self.eh_comm = gpaw.mpi.serial_comm
        else:
            # Check that parent_comm is valid
            if self.parent_comm != self.eh_comm.parent:
                raise RuntimeError(
                    'Invalid communicators in LrTDDFT2. The LrTDDFT2 '
                    'parent communicator is not the parent of the '
                    'electron-hole communicator. Please set up '
                    'LrCommunicators explicitly to avoid this.')
            if self.parent_comm != self.dd_comm.parent:
                raise RuntimeError(
                    'Invalid communicators in LrTDDFT2. The LrTDDFT2 '
                    'parent communicator is not the parent of the '
                    'domain decomposition communicator. Please set up '
                    'LrCommunicators explicitly to avoid this.')

    # Do not use; too slow... unless absolutely necessary.
    # def index_of_kss(self, i, p):
    #     for (ind, kss) in enumerate(self.kss_list):
    #         if kss.occ_ind == i and kss.unocc_ind == p:
    #             return ind
    #     return None

    def get_local_eh_index(self, ip):
        if ip % self.eh_comm.size != self.eh_comm.rank:
            return None
        return ip // self.eh_comm.size
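    # Electron-hole pairs are distributed round-robin over eh_comm:
    # global index ip lives on rank ip % eh_comm.size with local index
    # ip // eh_comm.size. For example, with eh_comm.size = 2, rank 0
    # owns ip = 0, 2, 4, ... and rank 1 owns ip = 1, 3, 5, ..., each
    # with local indices 0, 1, 2, ...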

    def get_local_dd_index(self, jq):
        if jq % self.dd_comm.size != self.dd_comm.rank:
            return None
        return jq // self.dd_comm.size

    def get_global_eh_index(self, lip):
        return lip * self.eh_comm.size + self.eh_comm.rank

    def get_global_dd_index(self, ljq):
        return ljq * self.dd_comm.size + self.dd_comm.rank
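    # The get_global_* methods invert the get_local_* mapping on the
    # owning rank: get_global_eh_index(get_local_eh_index(ip)) == ip
    # whenever get_local_eh_index(ip) is not None, and likewise for
    # the dd indices.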

    def get_matrix_elem_proc_and_index(self, ip, jq):
        ehproc = ip % self.eh_comm.size
        ddproc = jq % self.dd_comm.size
        # dd ranks are contiguous within a block of world ranks, so the
        # owning world rank is ehproc * dd_comm.size + ddproc.
        proc = ehproc * self.dd_comm.size + ddproc
        lip = ip // self.eh_comm.size
        # Local dd index uses dd_comm.size, consistent with
        # get_local_dd_index above.
        ljq = jq // self.dd_comm.size
        return (proc, ehproc, ddproc, lip, ljq)
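    # Worked example (illustrative numbers only): with eh_comm.size = 2
    # and dd_comm.size = 4, matrix element (ip=5, jq=6) lives on
    # ehproc = 5 % 2 = 1 and ddproc = 6 % 4 = 2, i.e. world rank
    # proc = 1 * 4 + 2 = 6, with local indices lip = 5 // 2 = 2 and
    # ljq = 6 // 4 = 1.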