Coverage for gpaw/new/brillouin.py: 90%

92 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-19 00:19 +0000

1"""Brillouin-zone sampling.""" 

2from __future__ import annotations 

3from typing import TYPE_CHECKING 

4import numpy as np 

5from ase.dft.kpoints import monkhorst_pack 

6from gpaw.mpi import MPIComm 

7from gpaw.typing import Array1D, ArrayLike2D 

8from gpaw.symmetry import reduce_kpts 

9if TYPE_CHECKING: 

10 from gpaw.new.symmetry import Symmetries 

11 

12 

class BZPoints:
    """Collection of k-points sampling the Brillouin zone.

    Parameters
    ----------
    points:
        Anything ``np.array`` turns into a 2-d array with three columns,
        one row per k-point.  Stored as ``kpt_Kc``.
    """

    def __init__(self, points: ArrayLike2D):
        self.kpt_Kc = np.array(points)
        assert self.kpt_Kc.ndim == 2
        assert self.kpt_Kc.shape[1] == 3
        # Exactly one k-point whose coordinates are all zero.
        self.gamma_only = len(self.kpt_Kc) == 1 and not self.kpt_Kc.any()

    def __len__(self):
        """Number of k-points in the BZ."""
        return len(self.kpt_Kc)

    def __repr__(self):
        if self.gamma_only:
            return 'BZPoints([<gamma only>])'
        return f'BZPoints([<{len(self)} points>])'

    def reduce(self,
               symmetries: Symmetries,
               *,
               comm: MPIComm | None = None,
               strict: bool = True,
               use_time_reversal: bool = True,
               tolerance: float = 1e-7) -> IBZ:
        """Find irreducible set of k-points.

        Parameters
        ----------
        symmetries:
            Symmetry operations.  ``len(symmetries)``, ``has_inversion`` and
            ``rotation_scc`` are used here.
        comm:
            Communicator forwarded to ``gpaw.symmetry.reduce_kpts``.
        strict:
            If true, raise ``ValueError`` when the k-point set is not
            compatible with the symmetries (see Raises).
        use_time_reversal:
            Whether ``reduce_kpts`` may use time-reversal symmetry.  It is
            forced off when ``symmetries.has_inversion`` is true.
        tolerance:
            Forwarded to ``reduce_kpts``.

        Returns
        -------
        IBZ
            The irreducible k-points together with the BZ <-> IBZ mappings
            and weights.

        Raises
        ------
        ValueError
            If ``strict`` and the ``bz2bz_Ks`` table returned by
            ``reduce_kpts`` contains -1.
        """
        if not use_time_reversal and len(symmetries) == 1:
            # Only one symmetry operation and no time reversal.  Every BZ
            # point becomes its own IBZ point with uniform weight 1/N, and
            # it is mapped by symmetry 0 onto itself.
            N = len(self)
            return IBZ(symmetries,
                       self,
                       ibz2bz=np.arange(N),
                       bz2ibz=np.arange(N),
                       weights=np.ones(N) / N,
                       bz2bz_Ks=np.arange(N).reshape((N, 1)),
                       s_K=np.zeros(N, int),
                       time_reversal_K=np.zeros(N, bool))

        if symmetries.has_inversion:
            # NOTE(review): presumably because inversion already yields the
            # same reduction that time reversal would give.  Confirm.
            use_time_reversal = False
        (_, weight_k, sym_K, time_reversal_K, bz2ibz_K, ibz2bz_k,
         bz2bz_Ks) = reduce_kpts(self.kpt_Kc,
                                 symmetries.rotation_scc,
                                 use_time_reversal,
                                 comm,
                                 tolerance)

        if strict and -1 in bz2bz_Ks:
            raise ValueError(
                'Your k-points are not as symmetric as your crystal!')

        return IBZ(symmetries, self, ibz2bz_k, bz2ibz_K, weight_k, bz2bz_Ks,
                   sym_K, time_reversal_K)

63 

64 

class BZBandPath(BZPoints):
    """BZ points taken from a band-path object.

    The object itself is kept as ``self.band_path``.  Its ``kpts``
    attribute supplies the k-points handed to ``BZPoints``.
    """

    def __init__(self, band_path):
        self.band_path = band_path
        path_kpts = band_path.kpts
        super().__init__(path_kpts)

69 

70 

class MonkhorstPackKPoints(BZPoints):
    """Monkhorst-Pack grid of k-points plus a constant shift.

    ``size`` is passed unchanged to ``ase.dft.kpoints.monkhorst_pack`` and
    stored as given in ``size_c``.  ``shift`` is added to every generated
    point and stored as an array in ``shift_c``.
    """

    def __init__(self, size, shift=(0, 0, 0)):
        self.size_c = size
        self.shift_c = np.array(shift)
        grid_Kc = monkhorst_pack(size)
        super().__init__(grid_Kc + shift)

    def __repr__(self):
        size, shift = self.size_c, self.shift_c
        return f'MonkhorstPackKPoints({size}, shift={shift})'

    def __str__(self):
        # Unpacking enforces exactly three entries in both size and shift.
        n1, n2, n3 = self.size_c
        s1, s2, s3 = self.shift_c
        size_line = f'monkhorst-pack size: [{n1}, {n2}, {n3}]'
        shift_line = f'monkhorst-pack shift: [{s1}, {s2}, {s3}]'
        return f'{size_line}\n{shift_line}\n'

85 

86 

class IBZ:
    """Irreducible k-points and their relation to the full BZ sampling.

    Suffix conventions: ``K`` indexes points of ``bz``, ``k`` indexes
    irreducible points, and ``s`` indexes symmetry operations.

    Parameters
    ----------
    symmetries:
        Symmetry operations used for the reduction.
    bz:
        The full set of k-points.
    ibz2bz:
        Index array selecting the irreducible points from ``bz.kpt_Kc``.
    bz2ibz:
        For each BZ point, the index of its irreducible point.
    weights:
        Weight of each irreducible point.
    bz2bz_Ks, s_K, time_reversal_K:
        Optional symmetry-mapping tables, stored unchanged.  A -1 in
        ``bz2bz_Ks`` is reported by ``__str__`` as an asymmetric k-point set.
    """

    def __init__(self,
                 symmetries: Symmetries,
                 bz: BZPoints,
                 ibz2bz, bz2ibz, weights,
                 bz2bz_Ks=None, s_K=None, time_reversal_K=None):
        self.symmetries = symmetries
        self.bz = bz
        self.weight_k = weights
        # Coordinates of the irreducible points, selected from the BZ set.
        self.kpt_kc = bz.kpt_Kc[ibz2bz]
        self.ibz2bz_k = ibz2bz
        self.bz2ibz_K = bz2ibz
        self.bz2bz_Ks = bz2bz_Ks
        self.s_K = s_K
        self.time_reversal_K = time_reversal_K

    def __len__(self):
        """Number of k-points in the IBZ."""
        return len(self.kpt_kc)

    def __repr__(self):
        return (f'IBZ(<points: {len(self)}, '
                f'symmetries: {len(self.symmetries)}>)')

    def __str__(self):
        """Human-readable summary listing the points and weights."""
        N = len(self)
        txt = ('bz sampling:\n'
               f' number of bz points: {len(self.bz)}\n'
               f' number of ibz points: {N}\n')

        if self.bz2bz_Ks is not None and -1 in self.bz2bz_Ks:
            txt += ' your k-points are not as symmetric as your crystal!\n'

        if isinstance(self.bz, MonkhorstPackKPoints):
            txt += ' ' + str(self.bz).replace('\n', '\n ', 1)

        txt += ' points and weights: [\n'
        # Long lists show the first ten points and the last one.  The
        # "# ..." marker is written only when points are really omitted.
        # With N == 11 nothing is skipped, so no marker is needed.
        truncated = N > 11
        shown = [*range(10), N - 1] if truncated else range(N)
        for k in shown:
            if truncated and k == N - 1:
                txt += ' # ...\n'
            a, b, c = self.kpt_kc[k]
            w = self.weight_k[k]
            t = ',' if k < N - 1 else ']'
            txt += (f' [[{a:12.8f}, {b:12.8f}, {c:12.8f}], '
                    f'{w:.8f}]{t} # {k}\n')
        return txt

    def ranks(self, comm: MPIComm) -> Array1D:
        """Distribute k-points over MPI-communicator.

        Returns the owning rank of each IBZ point, computed from
        ``comm.size``.  The call below resolves to the module-level
        ``ranks`` function, because class scope is not visible inside
        methods.
        """
        return ranks(comm.size, len(self))

141 

142 

def ranks(N, K) -> Array1D:
    """Distribute k-points over MPI-communicator.

    Assign each of ``K`` k-points to one of ``N`` ranks in contiguous,
    non-decreasing blocks.  With ``n, x = divmod(K, N)``, the first
    ``N - x`` ranks own ``n`` points each and the last ``x`` ranks own
    ``n + 1`` points each.

    >>> ranks(4, 6)
    array([0, 1, 2, 2, 3, 3])
    """
    per_rank, extra = divmod(K, N)
    # Number of k-points owned by each rank, in rank order.
    counts = [per_rank] * (N - extra) + [per_rank + 1] * extra
    return np.repeat(np.arange(N), counts)
156 return rnks