Coverage for gpaw/hybrids/kpts.py: 65%

74 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-20 00:19 +0000

1import numpy as np 

2 

3from gpaw.kpt_descriptor import KPointDescriptor 

4from gpaw.projections import Projections 

5from gpaw.utilities.partition import AtomPartition 

6from gpaw.wavefunctions.arrays import PlaneWaveExpansionWaveFunctions 

7from gpaw.pw.descriptor import PWDescriptor 

8 

9 

class KPoint:
    """Container for the data of one k-point and spin used by hybrids code.

    Stores the PAW projections, occupation numbers, k-vector and k-point
    weight for a range of bands, plus an optional ``dPdR_aniv`` object.
    """

    def __init__(self,
                 proj,  # projections
                 f_n,  # occupations numbers between 0 and 1
                 k_c,  # k-vector in units of reciprocal cell
                 weight,  # weight of k-point
                 dPdR_aniv=None):
        """Store the k-point data.

        Parameters:
            proj: projections.
            f_n: occupation numbers between 0 and 1.
            k_c: k-vector in units of the reciprocal cell.
            weight: weight of the k-point.
            dPdR_aniv: optional; when omitted (or None) a new empty list is
                created for this instance.
        """
        self.proj = proj
        self.f_n = f_n
        self.k_c = k_c
        self.weight = weight
        # A literal ``[]`` default is evaluated once and shared by every
        # instance, so mutating it through one k-point would affect all
        # others. Create a fresh list per instance instead.
        self.dPdR_aniv = [] if dPdR_aniv is None else dPdR_aniv

22 

23 

class PWKPoint(KPoint):
    """K-point whose wave functions are held as a plane-wave expansion."""

    def __init__(self, psit, *kpoint_args):
        # psit: plane-wave expansion of the wave functions. Every remaining
        # positional argument is handed unchanged to KPoint.__init__.
        self.psit = psit
        super().__init__(*kpoint_args)

28 

29 

class RSKPoint(KPoint):
    """K-point whose wave functions are held on the real-space grid."""

    def __init__(self, u_nR, *kpoint_args):
        # u_nR: real-space wave functions (presumably the output of
        # to_real_space). Every remaining positional argument is handed
        # unchanged to KPoint.__init__.
        self.u_nR = u_nR
        super().__init__(*kpoint_args)

34 

35 

def to_real_space(psit, na=0, nb=None):
    """Inverse-FFT plane-wave wave functions to the real-space grid.

    The bands of ``psit`` are handled in chunks of ``comm.size`` bands.
    For each chunk, ``pd.alltoall1`` hands rank r the coefficients of band
    ``n1 + r`` (or None when the chunk is shorter than the communicator).
    That rank inverse-FFTs the band, and the result is then broadcast from
    rank ``n - n1``, so every rank ends up holding every band.

    Parameters:
        psit: plane-wave wave-function object. Its ``pd``, ``kpt`` and
            ``array`` attributes are read.
        na, nb: band slice ``[na:nb]`` of the result to return. ``nb=None``
            (the default) means "through the last band". An explicit
            ``nb=0`` selects an empty slice.

    Returns:
        The slice ``u_nR[na:nb]`` of the array holding all bands.
    """
    pd = psit.pd
    comm = pd.comm
    nranks = comm.size
    q = psit.kpt
    nbands = len(psit.array)
    if nb is None:
        # Only None means "all bands". The old ``nb or nbands`` also turned
        # an explicit nb=0 into nbands and returned every band.
        nb = nbands
    u_nR = pd.gd.empty(nbands, pd.dtype, global_array=True)
    for n1 in range(0, nbands, nranks):
        n2 = min(n1 + nranks, nbands)
        u_G = pd.alltoall1(psit.array[n1:n2], q)
        if u_G is not None:
            # This rank is responsible for band n1 + comm.rank.
            u_nR[n1 + comm.rank] = pd.ifft(u_G, local=True, safe=False, q=q)
        # Band n was computed on rank n - n1; share it with all ranks.
        for n in range(n1, n2):
            comm.broadcast(u_nR[n], n - n1)

    return u_nR[na:nb]

54 

55 

def get_kpt(wfs, k, spin, n1, n2):
    """Build a PWKPoint for bands n1:n2 of IBZ k-point k and the given spin.

    There are two code paths:

    * If wfs.world and wfs.gd.comm have the same size, the wave functions,
      projections and occupations are taken from the local k-point object
      ``wfs.kpt_qs[k][spin]`` as views/slices, with no communication.
    * Otherwise the data is redistributed over wfs.world:
      - plane-wave coefficients go onto a new PWDescriptor whose grid
        descriptor uses wfs.world;
      - projections are assembled on rank 0 and then spread over the ranks
        atom by atom;
      - occupations are broadcast from rank 0.
      This branch calls wfs.world.broadcast, so presumably every rank of
      wfs.world must call get_kpt with the same arguments.

    Parameters:
        wfs: wave-function object. Reads kd, world, gd, kpt_qs, ecut, dtype,
            fftwflags, setups, nspins and the collect_* / get_* helpers.
        k: index into wfs.kd.ibzk_kc and wfs.kd.weight_k.
        spin: spin index.
        n1, n2: band range [n1, n2).

    Returns:
        PWKPoint(psit, proj, f_n, k_c, weight). f_n is divided by
        ``weight * (2 // wfs.nspins)``, which according to the original
        comment scales the occupations to [0, 1].
    """
    k_c = wfs.kd.ibzk_kc[k]
    weight = wfs.kd.weight_k[k]

    if wfs.world.size == wfs.gd.comm.size:
        # Easy: the data is available locally; take views/slices of the
        # requested bands.
        # NOTE(review): kpt_qs is indexed with the IBZ index k, which assumes
        # the local k-point index q equals k in this branch. Confirm.
        kpt = wfs.kpt_qs[k][spin]
        psit = kpt.psit.view(n1, n2)
        proj = kpt.projections.view(n1, n2)
        f_n = kpt.f_n[n1:n2]
    else:
        # Need to redistribute things:
        # Plane-wave descriptor for this single k-point, on a grid descriptor
        # that uses the full world communicator.
        gd = wfs.gd.new_descriptor(comm=wfs.world)
        kd = KPointDescriptor([k_c])
        pd = PWDescriptor(wfs.ecut, gd, wfs.dtype, kd, wfs.fftwflags)
        psit = PlaneWaveExpansionWaveFunctions(n2 - n1,
                                               pd,
                                               dtype=wfs.dtype,
                                               spin=spin)
        for n in range(n1, n2):
            psit_G = wfs.get_wave_function_array(n, k, spin,
                                                 realspace=False, cut=False)
            if isinstance(psit_G, float):
                # Presumably a placeholder returned on ranks that do not
                # receive the collected array; _distribute then gets None.
                psit_G = None
            else:
                # Keep only the first pd.ngmax coefficients, since the
                # array was requested with cut=False.
                psit_G = psit_G[:pd.ngmax]
            # Spread band n over the new descriptor's layout, filling slot
            # n - n1 of psit.array.
            psit._distribute(psit_G, psit.array[n - n1])

        # Projections: the collected array is sliced to bands n1:n2 only on
        # rank 0, which is where the data is used below.
        P_nI = wfs.collect_projections(k, spin)
        if wfs.world.rank == 0:
            P_nI = P_nI[n1:n2]
        natoms = len(wfs.setups)
        # First assign every atom to rank 0, the rank that holds P_nI ...
        rank_a = np.zeros(natoms, int)
        atom_partition = AtomPartition(wfs.world, rank_a)
        nproj_a = [setup.ni for setup in wfs.setups]
        proj = Projections(n2 - n1,
                           nproj_a,
                           atom_partition,
                           spin=spin,
                           dtype=wfs.dtype,
                           data=P_nI)

        # ... then spread the atoms evenly over the ranks in contiguous
        # blocks, e.g. 4 atoms on 2 ranks -> rank_a == [0, 0, 1, 1].
        rank_a = np.linspace(0, wfs.world.size, len(wfs.setups),
                             endpoint=False).astype(int)
        atom_partition = AtomPartition(wfs.world, rank_a)
        proj = proj.redist(atom_partition)

        # Occupations: rank 0 slices bands n1:n2; the other ranks allocate
        # a receive buffer of the same length; then broadcast from rank 0.
        f_n = wfs.collect_occupations(k, spin)
        if wfs.world.rank != 0:
            f_n = np.empty(n2 - n1)
        else:
            f_n = f_n[n1:n2]
        wfs.world.broadcast(f_n, 0)

    # Divide by the k-point weight and by 2 // nspins (2 for nspins == 1,
    # 1 for nspins == 2). The division creates a new array, so kpt.f_n is
    # not modified in the "easy" branch.
    f_n = f_n / (weight * (2 // wfs.nspins))  # scale to [0, 1]

    return PWKPoint(psit, proj, f_n, k_c, weight)