Coverage for gpaw/analyse/overlap.py: 100%

58 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-20 00:19 +0000

1import numpy as np 

2 

3from gpaw.utilities import packed_index 

4 

5 

class Overlap:
    """Wave function overlap of two GPAW calculators.

    The calculator given to the constructor is "mine"; the one passed to
    :meth:`pseudo` / :meth:`full` is "other".  Both must have the same
    number of irreducible k-points, and every k-point must be stored on
    the calling rank (enforced by assertions).
    """

    def __init__(self, calc):
        """Store *calc* and converge its wave functions.

        Sets ``nb``, ``nk`` and ``ns`` (number of bands, IBZ k-points and
        spins) plus the shortcuts ``gd`` and ``kd`` to ``calc.wfs.gd`` and
        ``calc.wfs.kd``.
        """
        self.calc = calc
        self.nb, self.nk, self.ns = self.number_of_states(calc)
        self.gd = self.calc.wfs.gd
        self.kd = self.calc.wfs.kd

    def number_of_states(self, calc):
        """Converge the wave functions of *calc* and return its dimensions.

        Returns
        -------
        tuple
            ``(number of bands, number of IBZ k-points, number of spins)``
        """
        # we will need the wave functions
        calc.converge_wave_functions()
        return (calc.get_number_of_bands(), len(calc.get_ibz_k_points()),
                calc.get_number_of_spins())

    def _get_kpts(self, other, k, myspin, otherspin):
        """Return (my k-point object, other's k-point object) for IBZ k."""
        # XXX what if parallelized over spin or kpoints ?
        kpt_rank, q = self.kd.get_rank_and_index(k)
        assert self.kd.comm.rank == kpt_rank, 'k-point not on this rank'
        okd = other.wfs.kd
        # The rank returned by other's k-point descriptor refers to other's
        # own communicator, so compare against that one.
        kpt_rank, qo = okd.get_rank_and_index(k)
        assert okd.comm.rank == kpt_rank, 'k-point not on this rank'
        return (self.calc.wfs.kpt_qs[q][myspin],
                other.wfs.kpt_qs[qo][otherspin])

    def pseudo(self, other, myspin=0, otherspin=0, normalize=True):
        r"""Overlap with pseudo wave functions only

        Parameters
        ----------
        other: gpaw
            gpaw-object containing pseudo wave functions
        myspin: int
            spin channel of this object's wave functions
        otherspin: int
            spin channel of other's wave functions
        normalize: bool
            normalize pseudo wave functions in the overlap integral

        Returns
        -------
        out: array
            u_kij = \int dx mypsitilde_ki^*(x) otherpsitilde_kj(x)
        """
        nbo, nko, _ = self.number_of_states(other)
        assert self.nk == nko  # XXX allow for different number of k-points ?

        overlap_knn = []
        for k in range(self.nk):
            mkpt, okpt = self._get_kpts(other, k, myspin, otherspin)
            psit_nG = mkpt.psit_nG
            psito_nG = okpt.psit_nG
            if normalize:
                # Norms are only needed for normalization.
                norm_n = self.gd.integrate(psit_nG.conj() * psit_nG)
                normo_n = other.wfs.gd.integrate(psito_nG.conj() * psito_nG)

            overlap_nn = np.zeros((self.nb, nbo), dtype=self.calc.wfs.dtype)
            for n in range(self.nb):
                # Broadcasting one conjugated state against all of other's
                # states avoids materializing nbo copies with np.repeat.
                overlap_nn[n] = self.gd.integrate(psit_nG[n].conj() *
                                                  psito_nG)
                if normalize:
                    overlap_nn[n] /= np.sqrt(norm_n[n] * normo_n)
            overlap_knn.append(overlap_nn)
        return np.array(overlap_knn)

    def _paw_correction(self, mkpt, okpt, template_nn):
        """Return the atom-centred correction to the overlap matrix.

        For every atom ``a`` in ``mkpt.P_ani`` this accumulates
        ``sum_ij conj(mP_ni) Delta_ij oP_mj``, where ``Delta_ij`` is the
        unpacked form of ``sqrt(4 pi) * setups[a].Delta_pL[:, 0]``
        (presumably the PAW one-centre overlap correction).  The result
        has the shape and dtype of *template_nn* and is summed over
        ``calc.wfs.gd.comm``.
        """
        aov_nn = np.zeros_like(template_nn)
        for a, mP_ni in mkpt.P_ani.items():
            mP_ni = np.asarray(mP_ni)
            oP_ni = np.asarray(okpt.P_ani[a])
            ni = mP_ni.shape[1]
            assert oP_ni.shape[1] == ni
            Delta_p = (np.sqrt(4 * np.pi) *
                       self.calc.wfs.setups[a].Delta_pL[:, 0])
            # Unpack the symmetric packed vector into a full ni x ni matrix.
            Delta_ii = np.array([[Delta_p[packed_index(i, j, ni)]
                                  for j in range(ni)]
                                 for i in range(ni)])
            # Matrix form of the sum over bands (n0, n1) and projectors (i, j).
            aov_nn += mP_ni.conj() @ Delta_ii @ oP_ni.T
        # Projections are presumably distributed over domains; reduce them.
        self.calc.wfs.gd.comm.sum(aov_nn)
        return aov_nn

    def full(self, other, myspin=0, otherspin=0):
        r"""Overlap of Kohn-Sham states including local terms.

        Parameters
        ----------
        other: gpaw
            gpaw-object containing wave functions
        myspin: int
            spin channel of this object's wave functions
        otherspin: int
            spin channel of other's wave functions

        Returns
        -------
        out: array
            u_kij = \int dx mypsi_ki^*(x) otherpsi_kj(x)
            (not normalized)
        """
        # The pseudo part must use the same spin channels as the PAW
        # correction added below.
        ov_knn = self.pseudo(other, myspin, otherspin, normalize=False)
        for k in range(self.nk):
            mkpt, okpt = self._get_kpts(other, k, myspin, otherspin)
            ov_knn[k] += self._paw_correction(mkpt, okpt, ov_knn[k])
        return ov_knn