Coverage for gpaw/lcaotddft/utilities.py: 87%

114 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-19 00:19 +0000

1import numpy as np 

2 

3from gpaw.blacs import BlacsGrid 

4from gpaw.blacs import Redistributor 

5 

6 

def collect_uX(kd, comm, a_uX, s, k):
    """Gather the array for spin ``s`` and k-point ``k`` onto the master.

    Parameters
    ----------
    kd
        K-point descriptor; provides ``comm`` (the k-point communicator),
        ``nspins`` and ``get_rank_and_index(k)``.
    comm
        A communicator orthogonal to ``kd.comm`` (ie, domainband_comm).
        Only its rank 0 takes part in the transfer.
    a_uX
        Local arrays indexed by ``u = q * kd.nspins + s``, where ``q`` is
        the local index of ``k`` on its owning k-point rank.
    s, k
        Spin index and global k-point index to collect.

    Returns
    -------
    The array on the process with ``kd.comm.rank == 0`` and
    ``comm.rank == 0``; ``None`` on every other process.
    """
    Xshape = a_uX[0].shape
    dtype = a_uX[0].dtype
    kpt_rank, q = kd.get_rank_and_index(k)
    u = q * kd.nspins + s
    if kd.comm.rank == kpt_rank:
        a_X = a_uX[u]
        # Comm master sends to the global master
        if comm.rank == 0:
            if kpt_rank == 0:
                # assert world.rank == 0
                return a_X
            kd.comm.ssend(a_X, 0, 2018)
        return None
    # The ssend above targets kd.comm rank 0 only.  Other non-owning
    # k-point ranks must not post a receive, or they would wait for a
    # message that is never sent to them.  Here kpt_rank != 0 is implied.
    if kd.comm.rank == 0 and comm.rank == 0:
        # assert world.rank == 0
        a_X = np.empty(Xshape, dtype=dtype)
        kd.comm.receive(a_X, kpt_rank, 2018)
        return a_X
    return None

27 

28 

def write_uX(kd, comm, writer, name, a_uX):
    """Write per-(spin, k-point) arrays as one ``(nspins, nibzkpts, *X)`` array.

    Each entry is gathered with collect_uX and appended via ``writer.fill``
    in spin-major, k-point-minor order.
    """
    first = a_uX[0]
    full_shape = (kd.nspins, kd.nibzkpts) + first.shape
    writer.add_array(name, full_shape, dtype=first.dtype)
    spin_kpt_pairs = ((s, k)
                      for s in range(kd.nspins)
                      for k in range(kd.nibzkpts))
    for s, k in spin_kpt_pairs:
        writer.fill(collect_uX(kd, comm, a_uX, s, k))

38 

39 

def read_uX(kpt_u, reader, name):
    """Read the ``(kpt.s, kpt.k)`` slice of array ``name`` for each local kpt.

    Returns a list ordered like ``kpt_u``.
    """
    # TODO: does this read on all the comm ranks in vain?
    return [reader.proxy(name, kpt.s, kpt.k)[:] for kpt in kpt_u]

48 

49 

def distribute_nM(ksl, a_nM):
    """Redistribute ``a_nM`` onto the BLACS layout ``ksl.mmdescriptor``.

    Returns ``a_nM`` unchanged when ``ksl.using_blacs`` is false.
    Otherwise only the input of ``ksl.gd`` rank 0 is used as data.
    Every other gd rank substitutes a zero array from
    ``ksl.nM_unique_descriptor``.
    """
    if not ksl.using_blacs:
        return a_nM

    dtype = a_nM.dtype
    ksl.nMdescriptor.checkassert(a_nM)
    source_nM = (a_nM if ksl.gd.rank == 0
                 else ksl.nM_unique_descriptor.zeros(dtype=dtype))

    redistributor = Redistributor(ksl.block_comm,
                                  ksl.nM_unique_descriptor,
                                  ksl.mmdescriptor)
    target_mm = ksl.mmdescriptor.empty(dtype=dtype)
    redistributor.redistribute(source_nM, target_mm,
                               ksl.bd.nbands, ksl.nao)
    return target_mm

65 

66 

def collect_MM(ksl, a_mm):
    """Gather a BLACS-distributed ``a_mm`` into a full ``(nao, nao)`` matrix.

    The target layout is a 1x1 BlacsGrid over ``ksl.block_comm`` with a
    single ``nao x nao`` block. Presumably this means the full matrix lands
    on one process of ``block_comm``; confirm against the BlacsGrid docs.
    Returns ``a_mm`` unchanged when ``ksl.using_blacs`` is false.
    """
    if not ksl.using_blacs:
        return a_mm

    nao = ksl.nao
    full_desc = BlacsGrid(ksl.block_comm, 1, 1).new_descriptor(
        nao, nao, nao, nao)
    redist = Redistributor(ksl.block_comm, ksl.mmdescriptor, full_desc)
    gathered = full_desc.empty(dtype=a_mm.dtype)
    redist.redistribute(a_mm, gathered)
    return gathered

82 

83 

def collect_uMM(kd, ksl, a_uMM, s, k):
    """Gather ``a_uMM`` for spin ``s`` and k-point ``k``.

    This treats ``a_uMM`` as the single component ``w = 0`` of a
    collect_wuMM input.
    """
    single_component = [a_uMM]
    return collect_wuMM(kd, ksl, single_component, 0, s, k)

86 

87 

def collect_wuMM(kd, ksl, a_wuMM, w, s, k):
    """Gather matrix ``a_wuMM[w][u]`` for spin ``s``, k-point ``k`` on master.

    This function is based on
    gpaw/wavefunctions/base.py: WaveFunctions.collect_auxiliary()

    Parameters
    ----------
    kd
        K-point descriptor; provides ``comm``, ``nspins`` and
        ``get_rank_and_index(k)``.
    ksl
        LCAO Kohn-Sham layout; provides ``nao``, ``block_comm``, ``world``
        and the BLACS information used by collect_MM.
    a_wuMM
        Nested local matrices indexed ``[w][u]`` with
        ``u = q * kd.nspins + s``.
    w, s, k
        Component index, spin index and global k-point index.

    Returns
    -------
    The gathered matrix on the process with ``kd.comm.rank == 0`` and
    ``ksl.block_comm.rank == 0`` (asserted to be world rank 0).
    Returns ``None`` everywhere else.
    """
    dtype = a_wuMM[0][0].dtype
    NM = ksl.nao
    kpt_rank, q = kd.get_rank_and_index(k)
    u = q * kd.nspins + s
    if kd.comm.rank == kpt_rank:
        a_MM = a_wuMM[w][u]

        # Collect within blacs grid; every rank of this k-point group's
        # block_comm must take part.
        a_MM = collect_MM(ksl, a_MM)

        # KSL master sends a_MM to the global master
        if ksl.block_comm.rank == 0:
            if kpt_rank == 0:
                assert ksl.world.rank == 0
                # I have it already
                return a_MM
            kd.comm.send(a_MM, 0, 2017)
        return None
    # The send above targets kd.comm rank 0 only.  Block masters on other
    # non-owning k-point ranks must not post a receive (nor assert that
    # they are world master).  Here kpt_rank != 0 is implied.
    if kd.comm.rank == 0 and ksl.block_comm.rank == 0:
        assert ksl.world.rank == 0
        a_MM = np.empty((NM, NM), dtype=dtype)
        kd.comm.receive(a_MM, kpt_rank, 2017)
        return a_MM
    return None

116 

117 

def distribute_MM(ksl, a_MM):
    """Scatter a full ``(nao, nao)`` matrix onto ``ksl.mmdescriptor``.

    The source layout is a 1x1 BlacsGrid over ``ksl.block_comm``.
    Only the ``a_MM`` passed on block_comm rank 0 is used as data; other
    ranks substitute an uninitialised array of the source layout.
    Returns ``a_MM`` unchanged when ``ksl.using_blacs`` is false.
    """
    if not ksl.using_blacs:
        return a_MM

    dtype = a_MM.dtype
    nao = ksl.nao
    full_desc = BlacsGrid(ksl.block_comm, 1, 1).new_descriptor(
        nao, nao, nao, nao)
    redist = Redistributor(ksl.block_comm, full_desc, ksl.mmdescriptor)
    source = a_MM if ksl.block_comm.rank == 0 else full_desc.empty(dtype=dtype)

    local_mm = ksl.mmdescriptor.empty(dtype=dtype)
    redist.redistribute(source, local_mm)
    return local_mm

135 

136 

def write_uMM(kd, ksl, writer, name, a_uMM):
    """Write ``a_uMM`` as a single-component (``w = 0``) array via write_wuMM.

    The stored array therefore has a leading axis of length 1.
    """
    a_wuMM = [a_uMM]
    return write_wuMM(kd, ksl, writer, name, a_wuMM, wlist=[0])

139 

140 

def write_wuMM(kd, ksl, writer, name, a_wuMM, wlist):
    """Write the components ``wlist`` of ``a_wuMM`` as one array.

    The stored array has shape ``(len(wlist), nspins, nibzkpts, nao, nao)``.
    It is filled with ``writer.fill`` in (w, spin, k-point) order, one
    matrix at a time, each gathered with collect_wuMM.
    """
    nao = ksl.nao
    element_dtype = a_wuMM[0][0].dtype
    full_shape = (len(wlist), kd.nspins, kd.nibzkpts, nao, nao)
    writer.add_array(name, full_shape, dtype=element_dtype)
    index_triples = ((w, s, k)
                     for w in wlist
                     for s in range(kd.nspins)
                     for k in range(kd.nibzkpts))
    for w, s, k in index_triples:
        writer.fill(collect_wuMM(kd, ksl, a_wuMM, w, s, k))

152 

153 

def read_uMM(kpt_u, ksl, reader, name):
    """Read the ``w = 0`` component of ``name`` for the local k-points.

    Returns the per-kpt list produced by read_wuMM for that component.
    """
    a_wuMM = read_wuMM(kpt_u, ksl, reader, name, wlist=[0])
    return a_wuMM[0]

156 

157 

def read_wuMM(kpt_u, ksl, reader, name, wlist):
    """Read the ``(w, kpt.s, kpt.k)`` matrices and distribute them.

    Returns a nested list ``[iw][u]`` that follows the order of ``wlist``
    (outer) and ``kpt_u`` (inner). Every matrix is passed through
    distribute_MM, in the same (w, kpt) order on every rank.
    """
    # TODO: does this read on all the ksl ranks in vain?
    return [[distribute_MM(ksl, reader.proxy(name, w, kpt.s, kpt.k)[:])
             for kpt in kpt_u]
            for w in wlist]