Coverage for gpaw/nlopt/basic.py: 76%

59 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-14 00:18 +0000

1from dataclasses import dataclass 

2 

3import numpy as np 

4from gpaw.mpi import MPIComm, broadcast 

5 

6 

@dataclass
class NLOData:
    """Input data for the nonlinear-optics (NLO) calculations.

    Every array is indexed first by spin ``s`` and then by k-point ``k``
    (see the ``_sk`` suffixes and the ``arr[s, k]`` slicing in
    :meth:`distribute`).  After :meth:`load`, only rank 0 of ``comm`` holds
    the arrays; the other ranks hold ``None`` until :meth:`distribute` hands
    them their share of the k-points.
    """

    # Presumably k-point weights, shape (ns, nk) -- TODO confirm.
    w_sk: np.ndarray
    # Presumably occupation numbers, shape (ns, nk, nb) -- TODO confirm.
    f_skn: np.ndarray
    # Presumably band energies, shape (ns, nk, nb) -- TODO confirm.
    E_skn: np.ndarray
    # Momentum matrix elements, shape (ns, nk, 3, nb, nb).
    p_skvnn: np.ndarray
    # MPI communicator used for file I/O and for distributing the data.
    comm: MPIComm

    # Names of the array fields (not a dataclass field: no annotation).
    # Used as the .npz keys and to fix the order of the per-k-point lists
    # returned by ``distribute``.
    _array_names = ('w_sk', 'f_skn', 'E_skn', 'p_skvnn')

    def write(self, filename):
        """Save the four arrays to ``filename`` with :func:`numpy.savez`.

        Only rank 0 writes.  All ranks then meet at a barrier, so no rank
        proceeds before the file has been written.
        """
        if self.comm.rank == 0:
            np.savez(filename, **{name: getattr(self, name)
                                  for name in self._array_names})
        self.comm.barrier()

    @classmethod
    def load(cls, filename, comm):
        """Load NLO data previously saved with :meth:`write`.

        Parameters
        ----------
        filename
            Path of the ``.npz`` file.  Only rank 0 reads it.
        comm
            MPI communicator, stored on the returned object.

        Returns
        -------
        NLOData
            On rank 0 the arrays hold the file contents.  On every other
            rank all array attributes are ``None``.
        """
        if comm.rank == 0:
            # np.load on an .npz returns an NpzFile that keeps the file
            # open; read every array now and close the handle.
            with np.load(filename) as npz:
                arrays = {name: npz[name] for name in cls._array_names}
        else:
            arrays = dict.fromkeys(cls._array_names)
        comm.barrier()

        return cls(comm=comm, **arrays)

    def distribute(self):
        """Scatter the per-(spin, k-point) data over the ranks of ``comm``.

        K-points are assigned round-robin: k-point ``k`` goes to rank
        ``k % comm.size``, for every spin.  Rank 0 must hold the full arrays
        (as after :meth:`load`).  It sends each slice ``arr[s, k]`` to the
        owning rank, and keeps its own slices without copying them.

        Returns
        -------
        dict
            Maps the flattened index ``s * nk + k`` to the list
            ``[w_sk[s, k], f_skn[s, k], E_skn[s, k], p_skvnn[s, k]]`` for
            every (s, k) owned by this rank.  Received slices have the same
            dtype as the arrays on rank 0.

        Raises
        ------
        ValueError
            On all ranks, if the arrays on rank 0 do not share the same
            leading (spin, k-point) shape.
        """
        comm = self.comm
        size = comm.size
        rank = comm.rank
        arrays = [getattr(self, name) for name in self._array_names]

        # Rank 0 checks the layout and describes every slice (shape and
        # dtype) so that the receivers can allocate matching buffers.
        layout = None
        if rank == 0:
            lead = arrays[0].shape[:2]
            if len(lead) == 2 and all(arr.shape[:2] == lead
                                      for arr in arrays):
                layout = (lead[0], lead[1],
                          [(arr.shape[2:], arr.dtype) for arr in arrays])
        layout = broadcast(layout, root=0, comm=comm)
        if layout is None:
            # Raised on every rank, so that no rank is left waiting for data.
            raise ValueError('Wrong shape for array.')
        ns, nk, slice_layout = layout

        k_info = {}
        for s1 in range(ns):
            for kk in range(nk):
                owner = kk % size
                index = s1 * nk + kk
                if rank == 0:
                    if owner == 0:
                        k_info[index] = [arr[s1, kk] for arr in arrays]
                    else:
                        for ii, arr in enumerate(arrays):
                            # np.array makes a contiguous copy with the
                            # original dtype (0-d for scalar entries).
                            comm.send(np.array(arr[s1, kk]), owner,
                                      tag=ii * nk + kk)
                elif owner == rank:
                    dataset = []
                    for ii, (shape, dtype) in enumerate(slice_layout):
                        data_k = np.empty(shape, dtype=dtype)
                        comm.receive(data_k, 0, tag=ii * nk + kk)
                        dataset.append(data_k)
                    k_info[index] = dataset

        return k_info