Coverage for gpaw/response/symmetry.py: 99%

122 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-08 00:17 +0000

1from __future__ import annotations 

2 

3from typing import Union 

4from dataclasses import dataclass 

5from collections.abc import Sequence 

6from functools import cached_property 

7 

8import numpy as np 

9 

10from gpaw.response.kpoints import KPointDomainGenerator 

11 

12 

@dataclass
class QSymmetries(Sequence):
    """Symmetry operations for a given q-point.

    We operate with several different symmetry indices:
    * u: indexes the unitary symmetries of the system. Length is nU.
    * S: extended symmetry index. In addition to the unitary symmetries
      (first nU indices) it also includes symmetries generated by a
      unitary symmetry transformation *followed* by a time-reversal.
      Length is 2 * nU.
    * s: reduced symmetry index. Includes all the "S-symmetries" which map
      the q-point in question onto itself (up to a reciprocal lattice
      vector). May be reduced further, if some of the symmetries have been
      disabled. Length is q-dependent and depends on user input.

    Indexing an instance with a reduced index s returns the tuple
    ``(U_cc, sign, shift_c)``.
    """
    q_c: np.ndarray
    U_ucc: np.ndarray  # unitary symmetry transformations
    S_s: np.ndarray  # extended symmetry index for each q-symmetry
    shift_sc: np.ndarray  # reciprocal lattice shifts, G = (T)Uq - q

    def __post_init__(self):
        self.nU = len(self.U_ucc)

    def __len__(self):
        return len(self.S_s)

    def __getitem__(self, s):
        """Return (U_cc, sign, shift_c) for the reduced symmetry index s."""
        return self.U_scc[s], self.sign_s[s], self.shift_sc[s]

    def unioperator(self, S):
        """Return the unitary operator underlying extended index S."""
        return self.U_ucc[S % self.nU]

    def timereversal(self, S):
        """Does the extended index S involve a time-reversal symmetry?"""
        return bool(S // self.nU)

    def sign(self, S):
        """Flip the sign under time-reversal."""
        return -1 if self.timereversal(S) else 1

    @cached_property
    def U_scc(self):
        """Unitary operator for each reduced symmetry index s.

        Uses integer-array indexing so that an empty S_s still yields an
        array of shape ``(0,) + U_ucc.shape[1:]`` instead of shape (0,).
        """
        S_s = np.asarray(self.S_s, dtype=int)
        return self.U_ucc[S_s % self.nU]

    @cached_property
    def sign_s(self):
        """Time-reversal sign (+1 or -1) for each reduced index s."""
        # Explicit dtype keeps the array integer even when S_s is empty.
        return np.array([self.sign(S) for S in self.S_s], dtype=int)

    @cached_property
    def ndirect(self):
        """Number of direct symmetries."""
        return int(np.count_nonzero(np.asarray(self.S_s) < self.nU))

    @property
    def nindirect(self):
        """Number of indirect symmetries."""
        return len(self) - self.ndirect

    def description(self) -> str:
        """Return string description of symmetry operations.

        Each operation is printed as its sign-adjusted 3x3 matrix, with up
        to ``nx`` operations side by side. Row blocks are separated by a
        blank line and the text starts with a newline. Returns '' when
        there are no symmetries.
        """
        nx = 6  # number of operations printed side by side
        nsym = len(self)
        blocks = []
        for start in range(0, nsym, nx):
            stop = min(start + nx, nsym)
            lines = []
            for c in range(3):
                entries = []
                for s in range(start, stop):
                    U_cc, sign, _ = self[s]
                    op_c = sign * U_cc[c]
                    entries.append(
                        f' ({op_c[0]:2d} {op_c[1]:2d} {op_c[2]:2d})')
                lines.append(''.join(entries) + '\n')
            blocks.append(''.join(lines))
        if not blocks:
            return ''
        # Leading newline, and a blank line between consecutive row blocks
        return '\n' + '\n'.join(blocks)

92 

93 

# Accepted forms of the ``qsymmetry`` argument: an existing analyzer, a dict
# of QSymmetryAnalyzer keyword arguments, or a single bool applied to both
# point-group and time-reversal symmetry.
QSymmetryInput = Union[
    'QSymmetryAnalyzer',
    dict,
    bool,
]

95 

96 

@dataclass
class QSymmetryAnalyzer:
    """Identifies symmetries of the k-grid, under which q is invariant.

    Parameters
    ----------
    point_group : bool
        Use point group symmetry.
    time_reversal : bool
        Use time-reversal symmetry (if applicable).
    """
    point_group: bool = True
    time_reversal: bool = True

    @staticmethod
    def from_input(qsymmetry: QSymmetryInput) -> QSymmetryAnalyzer:
        """Create an analyzer from the user-facing ``qsymmetry`` input.

        An existing QSymmetryAnalyzer is returned unchanged, a dict is
        unpacked as keyword arguments, and any other value (typically a
        bool) is used for both ``point_group`` and ``time_reversal``.
        """
        if isinstance(qsymmetry, QSymmetryAnalyzer):
            return qsymmetry
        if isinstance(qsymmetry, dict):
            return QSymmetryAnalyzer(**qsymmetry)
        return QSymmetryAnalyzer(point_group=qsymmetry,
                                 time_reversal=qsymmetry)

    @property
    def disabled(self):
        """True if both point-group and time-reversal symmetry are off."""
        return not (self.point_group or self.time_reversal)

    @property
    def disabled_symmetry_info(self):
        """Describe which symmetries were manually disabled ('' if none)."""
        if self.disabled:
            prefix = ''
        elif not self.point_group:
            prefix = 'point-group '
        elif not self.time_reversal:
            prefix = 'time-reversal '
        else:
            return ''
        return prefix + 'symmetry has been manually disabled'

    def analysis_info(self, symmetries):
        """Return a printable summary of the q-point symmetry analysis."""
        dsinfo = self.disabled_symmetry_info
        suffix = f' ({dsinfo})' if dsinfo else ''
        return '\n'.join([
            '',
            f'Symmetries of q_c{suffix}:',
            f'    Direct symmetries (Uq -> q): {symmetries.ndirect}',
            f'    Indirect symmetries (TUq -> q): {symmetries.nindirect}',
            f'In total {len(symmetries)} allowed symmetries.',
            symmetries.description()])

    def analyze(self, q_c, kpoints, context):
        """Analyze symmetries and set up KPointDomainGenerator.

        Prints the symmetry summary and the generator info string through
        ``context.print`` and returns ``(symmetries, generator)``.
        """
        symmetries = self.analyze_symmetries(q_c, kpoints.kd)
        generator = KPointDomainGenerator(symmetries, kpoints)
        context.print(self.analysis_info(symmetries))
        context.print(generator.get_infostring())
        return symmetries, generator

    def analyze_symmetries(self, q_c, kd):
        r"""Determine allowed symmetries.

        A direct symmetry U must fulfill::

            U \mathbf{q} = \mathbf{q} + \Delta

        Under time-reversal (indirect) it must fulfill::

            -U \mathbf{q} = \mathbf{q} + \Delta

        where :math:`\Delta` is a reciprocal lattice vector.

        Of the operations fulfilling this, all but the identity (and its
        time-reversed partner) are dropped when ``point_group`` is False;
        time-reversed ones are dropped when ``self.time_reversal`` or
        ``kd.symmetry.time_reversal`` is False, or when
        ``kd.symmetry.has_inversion`` is set; operations with a nonzero
        fractional translation (``kd.symmetry.ft_sc``) are always dropped.

        Returns
        -------
        QSymmetries
        """
        # Map q-point for each unitary symmetry u
        U_ucc = kd.symmetry.op_scc
        Uq_uc = np.dot(U_ucc, q_c)
        tol = kd.symmetry.tol
        nU = len(U_ucc)

        # Extended index S: the first nU entries are the direct operations
        # (Uq), the next nU the indirect ones (U followed by time-reversal,
        # i.e. -Uq). Shift G = (T)Uq - q for every S.
        shift_Sc = np.zeros((2 * nU, 3), float)
        shift_Sc[:nU] = Uq_uc - q_c
        shift_Sc[nU:] = -Uq_uc - q_c

        # An S-operation maps q onto itself iff its shift is a reciprocal
        # lattice vector, i.e. all components are (nearly) integer.
        is_qsymmetry_S = (abs(shift_Sc - shift_Sc.round()) < tol).all(axis=1)
        S_s = is_qsymmetry_S.nonzero()[0]

        # Predicates deciding which candidate S-operations are kept
        def is_identity(S):
            # True for the identity and its time-reversed partner
            return (U_ucc[S % nU] == np.eye(3)).all()

        def is_direct(S):
            # True if S involves no time-reversal
            return not bool(S // nU)

        def is_symmorphic(S):
            # True if the operation has no fractional translation
            return not bool(kd.symmetry.ft_sc[S % nU].any())

        # We always filter out non-symmorphic symmetries
        keep_predicates = [is_symmorphic]

        # Filter out point-group symmetry, if disabled
        if not self.point_group:
            keep_predicates.append(is_identity)

        # Filter out time-reversal, if inapplicable or disabled.
        # NOTE(review): time-reversal is dropped whenever has_inversion is
        # set, even if point_group=False has already removed inversion
        # from the allowed operations -- confirm this is intended.
        use_time_reversal = (kd.symmetry.time_reversal
                             and not kd.symmetry.has_inversion
                             and self.time_reversal)
        if not use_time_reversal:
            keep_predicates.append(is_direct)

        S_s = [S for S in S_s
               if all(keep(S) for keep in keep_predicates)]

        # All shifts should now be integer
        shift_sc = shift_Sc[S_s]
        assert (abs(shift_sc - shift_sc.round()) < tol).all()
        return QSymmetries(q_c, U_ucc, S_s, shift_sc.round().astype(int))