Coverage for gpaw/lcaotddft/utilities.py: 87%
114 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-19 00:19 +0000
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-19 00:19 +0000
1import numpy as np
3from gpaw.blacs import BlacsGrid
4from gpaw.blacs import Redistributor
def collect_uX(kd, comm, a_uX, s, k):
    """Collect the array for spin ``s`` and k-point ``k`` on the global master.

    The ``kd.comm`` rank that owns ``k`` picks its local array
    ``a_uX[q * kd.nspins + s]``.  If the owner is ``kd.comm`` rank 0 the
    array is returned as is; otherwise the ``comm`` master of the owning
    group sends it to ``kd.comm`` rank 0 (tag 2018), where it is received
    into a freshly allocated buffer.

    Parameters
    ----------
    kd
        K-point descriptor providing ``get_rank_and_index(k)``, ``nspins``
        and the k-point communicator ``comm``.
    comm
        Communicator orthogonal to ``kd.comm`` (ie, domainband_comm).
    a_uX
        Local list of arrays indexed by ``u = q * kd.nspins + s``.  The
        receive buffer is sized from ``a_uX[0]``, so all entries are assumed
        to share its shape and dtype.
    s, k
        Spin index and k-point index of the requested array.

    Returns
    -------
    The array on the rank that is rank 0 in both ``comm`` and ``kd.comm``;
    ``None`` on every other rank.
    """
    Xshape = a_uX[0].shape
    dtype = a_uX[0].dtype
    kpt_rank, q = kd.get_rank_and_index(k)
    u = q * kd.nspins + s
    if kd.comm.rank == kpt_rank:
        a_X = a_uX[u]
        # Comm master send to the global master
        if comm.rank == 0:
            if kpt_rank == 0:
                # The owner is the global master itself: nothing to send.
                return a_X
            kd.comm.ssend(a_X, 0, 2018)
    elif kd.comm.rank == 0 and comm.rank == 0:
        # Only kd.comm rank 0 may post the receive: the owner sends to rank 0
        # exclusively, so a receive posted by the comm master of any other
        # k-point group would never be matched.  (kpt_rank != 0 is implied
        # here because this rank is not the owner.)
        a_X = np.empty(Xshape, dtype=dtype)
        kd.comm.receive(a_X, kpt_rank, 2018)
        return a_X
    return None
def write_uX(kd, comm, writer, name, a_uX):
    """Write the arrays of all spins and k-points to ``writer`` as ``name``.

    An array of shape ``(kd.nspins, kd.nibzkpts) + a_uX[0].shape`` is
    declared on the writer and then filled one (spin, k-point) entry at a
    time, spin being the slow index.  Each entry is gathered with
    ``collect_uX``, so ``writer.fill`` receives ``None`` on the ranks where
    ``collect_uX`` returns ``None``.
    """
    sample = a_uX[0]
    entry_shape = sample.shape
    entry_dtype = sample.dtype
    full_shape = (kd.nspins, kd.nibzkpts) + entry_shape
    writer.add_array(name, full_shape, dtype=entry_dtype)
    for spin in range(kd.nspins):
        for ibzk in range(kd.nibzkpts):
            writer.fill(collect_uX(kd, comm, a_uX, spin, ibzk))
def read_uX(kpt_u, reader, name):
    """Read the array ``name`` for every local k-point.

    For each ``kpt`` in ``kpt_u`` the entry at indices ``(kpt.s, kpt.k)`` is
    fetched through ``reader.proxy`` and materialised with ``[:]``.

    Returns a list with one array per element of ``kpt_u``, in the same
    order.
    """
    # TODO: does this read on all the comm ranks in vain?
    return [reader.proxy(name, kpt.s, kpt.k)[:] for kpt in kpt_u]
def distribute_nM(ksl, a_nM):
    """Redistribute ``a_nM`` from the nM layout to the mm layout of ``ksl``.

    Without BLACS (``ksl.using_blacs`` false) the input is returned
    unchanged.  Otherwise the input is checked against ``ksl.nMdescriptor``
    and redistributed from ``ksl.nM_unique_descriptor`` to
    ``ksl.mmdescriptor``; ranks with ``ksl.gd.rank != 0`` contribute a
    zero array allocated from ``ksl.nM_unique_descriptor`` instead of their
    input.

    Returns the local array allocated from ``ksl.mmdescriptor``.
    """
    if not ksl.using_blacs:
        return a_nM

    elem_type = a_nM.dtype
    ksl.nMdescriptor.checkassert(a_nM)

    # Only gd rank 0 feeds its own data into the redistribution.
    if ksl.gd.rank == 0:
        source = a_nM
    else:
        source = ksl.nM_unique_descriptor.zeros(dtype=elem_type)

    redistributor = Redistributor(
        ksl.block_comm, ksl.nM_unique_descriptor, ksl.mmdescriptor)
    local_mm = ksl.mmdescriptor.empty(dtype=elem_type)
    redistributor.redistribute(source, local_mm, ksl.bd.nbands, ksl.nao)
    return local_mm
def collect_MM(ksl, a_mm):
    """Gather a matrix in the mm layout of ``ksl`` onto a 1 x 1 BLACS grid.

    Without BLACS (``ksl.using_blacs`` false) the input is returned
    unchanged.  Otherwise ``a_mm`` is redistributed from
    ``ksl.mmdescriptor`` to a descriptor created on a ``1 x 1`` grid over
    ``ksl.block_comm`` with all four size arguments equal to ``ksl.nao``
    (presumably a single nao x nao block holding the whole matrix -- confirm
    against the BlacsGrid API).

    Returns the array allocated from that descriptor.
    """
    if not ksl.using_blacs:
        return a_mm

    elem_type = a_mm.dtype
    nao = ksl.nao
    single_grid = BlacsGrid(ksl.block_comm, 1, 1)
    full_desc = single_grid.new_descriptor(nao, nao, nao, nao)
    gatherer = Redistributor(ksl.block_comm, ksl.mmdescriptor, full_desc)

    gathered = full_desc.empty(dtype=elem_type)
    gatherer.redistribute(a_mm, gathered)
    return gathered
def collect_uMM(kd, ksl, a_uMM, s, k):
    """Collect the matrix for spin ``s`` and k-point ``k``.

    Convenience wrapper around ``collect_wuMM``: ``a_uMM`` is wrapped into a
    one-element list and element ``w = 0`` is collected.
    """
    single_w = [a_uMM]
    return collect_wuMM(kd, ksl, single_w, 0, s, k)
def collect_wuMM(kd, ksl, a_wuMM, w, s, k):
    """Collect matrix ``w`` for spin ``s`` and k-point ``k`` on the master.

    This function is based on
    gpaw/wavefunctions/base.py: WaveFunctions.collect_auxiliary()

    The ``kd.comm`` rank that owns ``k`` picks ``a_wuMM[w][u]`` with
    ``u = q * kd.nspins + s`` and gathers it with ``collect_MM``.  If the
    owner is ``kd.comm`` rank 0 the gathered matrix is returned on its
    ``ksl.block_comm`` master; otherwise that master sends it to ``kd.comm``
    rank 0 (tag 2017), where it is received into a new ``(nao, nao)``
    buffer.

    Parameters
    ----------
    kd
        K-point descriptor providing ``get_rank_and_index(k)``, ``nspins``
        and the k-point communicator ``comm``.
    ksl
        Object providing ``nao``, ``block_comm``, ``world`` and whatever
        ``collect_MM`` requires.
    a_wuMM
        Nested list of local matrices indexed as ``a_wuMM[w][u]``.  The
        dtype of the receive buffer is taken from ``a_wuMM[0][0]``.
    w, s, k
        Index into the outer list, spin index and k-point index.

    Returns
    -------
    The ``(nao, nao)`` matrix on the rank that is rank 0 in both
    ``ksl.block_comm`` and ``kd.comm``; ``None`` on every other rank.
    """
    dtype = a_wuMM[0][0].dtype
    NM = ksl.nao
    kpt_rank, q = kd.get_rank_and_index(k)
    u = q * kd.nspins + s
    if kd.comm.rank == kpt_rank:
        a_MM = a_wuMM[w][u]

        # Collect within blacs grid
        a_MM = collect_MM(ksl, a_MM)

        # KSL master send a_MM to the global master
        if ksl.block_comm.rank == 0:
            if kpt_rank == 0:
                assert ksl.world.rank == 0
                # I have it already
                return a_MM
            kd.comm.send(a_MM, 0, 2017)
    elif kd.comm.rank == 0 and ksl.block_comm.rank == 0:
        # Only kd.comm rank 0 may receive: the owner sends to rank 0
        # exclusively.  Without the kd.comm.rank test the block_comm master
        # of every other k-point group would also end up here and fail the
        # assertion below (or wait for a message that is never sent).
        # kpt_rank != 0 is implied because this rank is not the owner.
        assert ksl.world.rank == 0
        a_MM = np.empty((NM, NM), dtype=dtype)
        kd.comm.receive(a_MM, kpt_rank, 2017)
        return a_MM
    return None
def distribute_MM(ksl, a_MM):
    """Distribute a matrix from a 1 x 1 BLACS grid to the mm layout of ``ksl``.

    Without BLACS (``ksl.using_blacs`` false) the input is returned
    unchanged.  Otherwise the matrix is redistributed from a descriptor
    created on a ``1 x 1`` grid over ``ksl.block_comm`` (all four size
    arguments equal to ``ksl.nao``) to ``ksl.mmdescriptor``.  Only
    ``ksl.block_comm`` rank 0 passes its ``a_MM`` to the redistribution;
    the other ranks pass an uninitialised array allocated from the source
    descriptor, but their ``a_MM`` must still expose ``dtype``.

    Returns the local array allocated from ``ksl.mmdescriptor``.
    """
    if not ksl.using_blacs:
        return a_MM

    elem_type = a_MM.dtype
    nao = ksl.nao
    single_grid = BlacsGrid(ksl.block_comm, 1, 1)
    full_desc = single_grid.new_descriptor(nao, nao, nao, nao)
    scatterer = Redistributor(ksl.block_comm, full_desc, ksl.mmdescriptor)

    # Ranks other than block_comm rank 0 substitute a placeholder array.
    if ksl.block_comm.rank == 0:
        source = a_MM
    else:
        source = full_desc.empty(dtype=elem_type)

    local_mm = ksl.mmdescriptor.empty(dtype=elem_type)
    scatterer.redistribute(source, local_mm)
    return local_mm
def write_uMM(kd, ksl, writer, name, a_uMM):
    """Write the matrices of all spins and k-points to ``writer`` as ``name``.

    Convenience wrapper around ``write_wuMM``: ``a_uMM`` is wrapped into a
    one-element list and written with ``wlist=[0]``, so the stored array has
    a leading axis of length one.
    """
    single_w = [a_uMM]
    return write_wuMM(kd, ksl, writer, name, single_w, wlist=[0])
def write_wuMM(kd, ksl, writer, name, a_wuMM, wlist):
    """Write the matrices selected by ``wlist`` to ``writer`` as ``name``.

    An array of shape ``(len(wlist), kd.nspins, kd.nibzkpts, nao, nao)``
    with ``nao = ksl.nao`` and the dtype of ``a_wuMM[0][0]`` is declared on
    the writer.  It is then filled entry by entry, looping over ``wlist``
    (slowest), spins and k-points (fastest); every entry is gathered with
    ``collect_wuMM``, so ``writer.fill`` receives ``None`` on the ranks
    where ``collect_wuMM`` returns ``None``.

    NOTE(review): each ``w`` from ``wlist`` is used as an index into
    ``a_wuMM`` while the stored leading axis has length ``len(wlist)``;
    the two agree when ``wlist`` is ``range(len(a_wuMM))`` -- confirm that
    callers never pass anything else.
    """
    nao = ksl.nao
    entry_dtype = a_wuMM[0][0].dtype
    full_shape = (len(wlist), kd.nspins, kd.nibzkpts, nao, nao)
    writer.add_array(name, full_shape, dtype=entry_dtype)
    for w in wlist:
        for spin in range(kd.nspins):
            for ibzk in range(kd.nibzkpts):
                writer.fill(collect_wuMM(kd, ksl, a_wuMM, w, spin, ibzk))
def read_uMM(kpt_u, ksl, reader, name):
    """Read and distribute the matrix ``name`` for every local k-point.

    Convenience wrapper around ``read_wuMM`` with ``wlist=[0]``; returns
    the single inner list (one distributed matrix per element of
    ``kpt_u``).
    """
    a_wuMM = read_wuMM(kpt_u, ksl, reader, name, wlist=[0])
    return a_wuMM[0]
def read_wuMM(kpt_u, ksl, reader, name, wlist):
    """Read and distribute the matrices ``name`` for every ``w`` and k-point.

    For each ``w`` in ``wlist`` and each ``kpt`` in ``kpt_u`` the entry at
    indices ``(w, kpt.s, kpt.k)`` is fetched through ``reader.proxy``,
    materialised with ``[:]`` and passed through ``distribute_MM``.

    Returns a nested list indexed as ``[position in wlist][position in
    kpt_u]``.
    """
    # TODO: does this read on all the ksl ranks in vain?
    return [
        [
            distribute_MM(ksl, reader.proxy(name, w, kpt.s, kpt.k)[:])
            for kpt in kpt_u
        ]
        for w in wlist
    ]