Coverage for gpaw/nlopt/basic.py: 76%
59 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-14 00:18 +0000
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-14 00:18 +0000
1from dataclasses import dataclass
3import numpy as np
4from gpaw.mpi import MPIComm, broadcast
@dataclass
class NLOData:
    """Bundle of arrays used by the NLO routines, plus the MPI communicator.

    Every array is indexed as ``arr[s, k]`` by :meth:`distribute`, so the
    first two axes of each array are the ``s`` and ``k`` axes.  The physical
    meaning noted on each field is inferred from its name and from the
    original docstrings -- confirm against the code that fills them in.

    After :meth:`load`, only rank 0 holds real arrays; on every other rank
    the four array fields are ``None``.
    """

    # presumably k-point weights, axes (s, k) -- TODO confirm
    w_sk: np.ndarray
    # presumably occupation numbers, axes (s, k, n) -- TODO confirm
    f_skn: np.ndarray
    # presumably band energies, axes (s, k, n) -- TODO confirm
    E_skn: np.ndarray
    # momentum matrix elements; per (s, k) the shape is (3, nb, nb)
    # according to the original docstring
    p_skvnn: np.ndarray
    # communicator used for rank checks, barriers and point-to-point sends
    comm: MPIComm

    def write(self, filename) -> None:
        """Save the four arrays to ``filename`` from rank 0.

        All ranks then meet at a barrier, so no rank returns before rank 0
        has finished writing.

        Note: ``np.savez`` appends ``.npz`` to a string filename that does
        not already end with it; pass the full ``*.npz`` name so that
        :meth:`load` can be given the same string.
        """
        if self.comm.rank == 0:
            np.savez(filename, w_sk=self.w_sk, f_skn=self.f_skn,
                     E_skn=self.E_skn, p_skvnn=self.p_skvnn)
        self.comm.barrier()

    @classmethod
    def load(cls, filename, comm) -> 'NLOData':
        """Read a file produced by :meth:`write` and build an ``NLOData``.

        Input:
            filename    NLO data filename (an ``.npz`` archive)
            comm        Communicator stored on the returned object
        Output:
            An ``NLOData`` whose arrays are populated on rank 0 only; on
            all other ranks the four array fields are ``None``.  Use
            :meth:`distribute` to hand the data out to the other ranks.
        """
        keys = ('w_sk', 'f_skn', 'E_skn', 'p_skvnn')

        # Only the master touches the file.  The archive is opened as a
        # context manager so its file handle is closed once the arrays have
        # been read into memory (NpzFile loads lazily on item access).
        if comm.rank == 0:
            with np.load(filename) as nlo:
                arrays = [nlo[key] for key in keys]
        else:
            arrays = [None] * len(keys)
        comm.barrier()

        return cls(*arrays, comm)

    def distribute(self) -> dict:
        """Hand out the per-(s, k) slices held on rank 0 to all ranks.

        The k-point with index ``kk`` is assigned to rank ``kk % size``.
        Must be called on every rank of ``self.comm``: it contains
        collective broadcasts and matching send/receive pairs.

        Output:
            k_info   Dictionary keyed by ``s * nk + k``.  Each value is the
                     list ``[w, f, E, p]`` of slices ``arr[s, k]`` for the
                     (s, k) pairs owned by the calling rank.  On rank 0 the
                     entries are taken directly from the stored arrays
                     (original dtype); on other ranks they are freshly
                     received arrays of dtype ``complex``.
        """
        arr_list = [self.w_sk, self.f_skn, self.E_skn, self.p_skvnn]
        comm = self.comm
        size = comm.size
        rank = comm.rank

        # Only rank 0 has the arrays, so it determines the shapes and
        # checks that every array agrees on the number of k-points.
        if rank == 0:
            arr_shape = [arr.shape for arr in arr_list]
            ns = arr_shape[0][0]
            nk = arr_shape[0][1]
            for shape in arr_shape[1:]:
                # NOTE(review): assert is stripped under ``python -O``; kept
                # so that callers still see AssertionError as before.
                assert shape[1] == nk, 'Wrong shape for array.'
        else:
            arr_shape = None
            nk = None
            ns = None
        arr_shape = broadcast(arr_shape, root=0, comm=comm)
        nk = broadcast(nk, root=0, comm=comm)
        ns = broadcast(ns, root=0, comm=comm)

        # Distribute the data of k-points between cores
        k_info = {}
        for s1 in range(ns):
            for kk in range(nk):
                owner = kk % size
                if rank == 0:
                    if owner == 0:
                        # Rank 0 keeps its own share without any copy/cast.
                        k_info[s1 * nk + kk] = [arr[s1, kk]
                                                for arr in arr_list]
                    else:
                        # Ship each array's slice as complex, because the
                        # receiver allocates complex buffers.
                        for ii, arr in enumerate(arr_list):
                            data_k = np.array(arr[s1, kk], dtype=complex)
                            comm.send(data_k, owner, tag=ii * nk + kk)
                elif owner == rank:
                    dataset = []
                    for ii, cshape in enumerate(arr_shape):
                        # Drop the leading (s, k) axes to get the slice shape.
                        data_k = np.empty(cshape[2:], dtype=complex)
                        comm.receive(data_k, 0, tag=ii * nk + kk)
                        dataset.append(data_k)
                    k_info[s1 * nk + kk] = dataset

        return k_info