Coverage for gpaw/hybrids/kpts.py: 65%
74 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-20 00:19 +0000
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-20 00:19 +0000
1import numpy as np
3from gpaw.kpt_descriptor import KPointDescriptor
4from gpaw.projections import Projections
5from gpaw.utilities.partition import AtomPartition
6from gpaw.wavefunctions.arrays import PlaneWaveExpansionWaveFunctions
7from gpaw.pw.descriptor import PWDescriptor
class KPoint:
    """Plain container for the data attached to one k-point.

    Attributes:
        proj: projections.
        f_n: occupation numbers between 0 and 1.
        k_c: k-vector in units of the reciprocal cell.
        weight: weight of the k-point.
        dPdR_aniv: stored exactly as passed in; an empty list when omitted.
            NOTE(review): the name suggests derivatives of the projections
            with respect to atomic positions -- confirm against callers.
    """

    def __init__(self,
                 proj,  # projections
                 f_n,  # occupations numbers between 0 and 1
                 k_c,  # k-vector in units of reciprocal cell
                 weight,  # weight of k-point
                 dPdR_aniv=None):
        self.proj = proj
        self.f_n = f_n
        self.k_c = k_c
        self.weight = weight
        # A literal ``[]`` default would be created once and shared by every
        # instance built without this argument, so appending to one k-point's
        # list would silently modify all the others.  Create a fresh list
        # per instance instead.
        self.dPdR_aniv = [] if dPdR_aniv is None else dPdR_aniv
class PWKPoint(KPoint):
    """K-point that additionally holds ``psit``.

    ``psit`` is the plane-wave expansion of the wave functions; every
    remaining positional argument is forwarded unchanged to ``KPoint``.
    """

    def __init__(self, psit, *args):
        super().__init__(*args)
        self.psit = psit  # plane-wave expansion of wfs
class RSKPoint(KPoint):
    """K-point that additionally holds ``u_nR``.

    ``u_nR`` is stored as given; every remaining positional argument is
    forwarded unchanged to ``KPoint``.
    NOTE(review): ``u_nR`` is presumably the real-space array produced by
    ``to_real_space`` -- confirm against callers.
    """

    def __init__(self, u_nR, *args):
        super().__init__(*args)
        self.u_nR = u_nR
def to_real_space(psit, na=0, nb=None):
    """Inverse-FFT the bands of ``psit`` and return the slice ``[na:nb]``.

    All ``len(psit.array)`` bands are transformed, in groups of
    ``pd.comm.size`` bands.  For each group ``pd.alltoall1`` is called; a
    rank that gets a non-``None`` result transforms band ``n1 + rank`` with
    ``pd.ifft``, after which every band of the group is broadcast from the
    rank that owns it.

    Parameters:
        psit: object providing ``pd``, ``kpt`` and ``array``.
        na: first band of the returned slice.
        nb: one past the last band of the returned slice; ``None`` means
            "up to the last band".

    Returns:
        ``u_nR[na:nb]`` where ``u_nR`` is allocated with
        ``pd.gd.empty(nbands, pd.dtype, global_array=True)``.
    """
    pd = psit.pd
    comm = pd.comm
    S = comm.size
    q = psit.kpt
    nbands = len(psit.array)
    # ``nb or nbands`` would turn an explicit ``nb=0`` into "all bands", so
    # an empty range such as (0, 0) returned everything.  Only substitute
    # the default when the caller really omitted ``nb``.
    if nb is None:
        nb = nbands
    u_nR = pd.gd.empty(nbands, pd.dtype, global_array=True)
    for n1 in range(0, nbands, S):
        n2 = min(n1 + S, nbands)
        # NOTE(review): presumably gathers one whole band per rank and
        # yields None on ranks beyond the end of the last, shorter group
        # -- confirm against PWDescriptor.alltoall1.
        u_G = pd.alltoall1(psit.array[n1:n2], q)
        if u_G is not None:
            n = n1 + comm.rank
            u_nR[n] = pd.ifft(u_G, local=True, safe=False, q=q)
        # Collective step: every rank takes part, whether or not it
        # transformed a band in this group.
        for n in range(n1, n2):
            comm.broadcast(u_nR[n], n - n1)

    return u_nR[na:nb]
def get_kpt(wfs, k, spin, n1, n2):
    """Return a ``PWKPoint`` for bands ``n1:n2`` of k-point ``k`` and ``spin``.

    If ``wfs.world`` and ``wfs.gd.comm`` have the same size, views of the
    existing per-k-point data are used directly.  Otherwise the
    wave-function coefficients, projections and occupations are collected
    and redistributed over ``wfs.world``.

    In both cases the occupations are divided by
    ``weight * (2 // wfs.nspins)`` to scale them to [0, 1].
    """
    world = wfs.world
    k_c = wfs.kd.ibzk_kc[k]
    weight = wfs.kd.weight_k[k]

    if world.size == wfs.gd.comm.size:
        # Easy: the data already has the layout we need.
        local_kpt = wfs.kpt_qs[k][spin]
        psit = local_kpt.psit.view(n1, n2)
        proj = local_kpt.projections.view(n1, n2)
        f_n = local_kpt.f_n[n1:n2]
    else:
        # Need to redistribute things over the world communicator.
        nbands = n2 - n1
        world_gd = wfs.gd.new_descriptor(comm=world)
        single_kd = KPointDescriptor([k_c])
        world_pd = PWDescriptor(wfs.ecut, world_gd, wfs.dtype, single_kd,
                                wfs.fftwflags)
        psit = PlaneWaveExpansionWaveFunctions(nbands,
                                               world_pd,
                                               dtype=wfs.dtype,
                                               spin=spin)
        for offset, band in enumerate(range(n1, n2)):
            coef_G = wfs.get_wave_function_array(band, k, spin,
                                                 realspace=False, cut=False)
            # NOTE(review): a float return value is treated as "no data on
            # this rank" -- confirm against get_wave_function_array.
            if isinstance(coef_G, float):
                coef_G = None
            else:
                coef_G = coef_G[:world_pd.ngmax]
            psit._distribute(coef_G, psit.array[offset])

        # Projections: first build them with every atom assigned to rank 0,
        # using the rows collected there ...
        P_nI = wfs.collect_projections(k, spin)
        if world.rank == 0:
            P_nI = P_nI[n1:n2]
        natoms = len(wfs.setups)
        all_on_root = AtomPartition(world, np.zeros(natoms, int))
        nproj_a = [setup.ni for setup in wfs.setups]
        proj = Projections(nbands,
                           nproj_a,
                           all_on_root,
                           spin=spin,
                           dtype=wfs.dtype,
                           data=P_nI)

        # ... then spread the atoms evenly over the ranks of world.
        spread_rank_a = np.linspace(0, world.size, natoms,
                                    endpoint=False).astype(int)
        proj = proj.redist(AtomPartition(world, spread_rank_a))

        # Occupations: rank 0 slices the collected values, the other ranks
        # allocate a receive buffer, then rank 0 broadcasts.
        f_n = wfs.collect_occupations(k, spin)
        if world.rank == 0:
            f_n = f_n[n1:n2]
        else:
            f_n = np.empty(nbands)
        world.broadcast(f_n, 0)

    occupation_scale = weight * (2 // wfs.nspins)
    f_n = f_n / occupation_scale  # scale to [0, 1]

    return PWKPoint(psit, proj, f_n, k_c, weight)