Coverage for gpaw/response/symmetry.py: 99%
122 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-08 00:17 +0000
1from __future__ import annotations
3from typing import Union
4from dataclasses import dataclass
5from collections.abc import Sequence
6from functools import cached_property
8import numpy as np
10from gpaw.response.kpoints import KPointDomainGenerator
@dataclass
class QSymmetries(Sequence):
    """Symmetry operations for a given q-point.

    We operate with several different symmetry indices:
    * u: indexes the unitary symmetries of the system. Length is nU.
    * S: extended symmetry index. In addition to the unitary symmetries
      (first nU indices) it also includes symmetries generated by a
      unitary symmetry transformation *followed* by a time-reversal.
      Length is 2 * nU.
    * s: reduced symmetry index. Includes all the "S-symmetries" which map
      the q-point in question onto itself (up to a reciprocal lattice
      vector). May be reduced further, if some of the symmetries have been
      disabled. Length is q-dependent and depends on user input.

    Indexing an instance with a reduced index s returns the tuple
    ``(U_cc, sign, shift_c)`` of that q-symmetry.
    """
    q_c: np.ndarray
    U_ucc: np.ndarray  # unitary symmetry transformations
    S_s: np.ndarray  # extended symmetry index for each q-symmetry
    shift_sc: np.ndarray  # reciprocal lattice shifts, G = (T)Uq - q

    def __post_init__(self):
        self.nU = len(self.U_ucc)

    def __len__(self):
        return len(self.S_s)

    def __getitem__(self, s):
        """Return (U_cc, sign, shift_c) for the reduced symmetry index s."""
        return self.U_scc[s], self.sign_s[s], self.shift_sc[s]

    def unioperator(self, S):
        """Return the unitary operation underlying extended index S."""
        return self.U_ucc[S % self.nU]

    def timereversal(self, S):
        """Does the extended index S involve a time-reversal symmetry?"""
        return bool(S // self.nU)

    def sign(self, S):
        """Flip the sign under time-reversal."""
        if self.timereversal(S):
            return -1
        return 1

    @cached_property
    def U_scc(self):
        """Unitary operation of each q-symmetry (one per reduced index s)."""
        S_s = np.asarray(self.S_s, dtype=int)
        # Fancy indexing (instead of stacking a Python list) keeps the
        # trailing matrix dimensions even when there are no q-symmetries.
        return np.asarray(self.U_ucc)[S_s % self.nU]

    @cached_property
    def sign_s(self):
        """Sign (-1 under time-reversal, +1 otherwise) of each q-symmetry."""
        # Explicit dtype so that the empty case is an int array, too.
        return np.array([self.sign(S) for S in self.S_s], dtype=int)

    @cached_property
    def ndirect(self):
        """Number of direct symmetries."""
        return int(np.count_nonzero(np.asarray(self.S_s) < self.nU))

    @property
    def nindirect(self):
        """Number of indirect symmetries."""
        return len(self) - self.ndirect

    def description(self) -> str:
        """Return string description of symmetry operations.

        Each operation is shown as the matrix ``sign * U_cc``, printed as a
        column of three rows. Up to ``nx`` operations are placed side by
        side; further operations continue in a new group, separated from
        the previous one by a blank line. The result starts with a newline,
        or is the empty string if there are no symmetries.
        """
        isl = ['\n']
        # Operations per row. Non-symmorphic symmetries are not allowed
        # here (they would call for nx = 3).
        nx = 6
        for y in range((len(self) + nx - 1) // nx):
            for c in range(3):
                tisl = []
                for s in range(y * nx, min((y + 1) * nx, len(self))):
                    U_cc, sign, _ = self[s]
                    op_c = sign * U_cc[c]
                    tisl.append(f' ({op_c[0]:2d} {op_c[1]:2d} {op_c[2]:2d})')
                tisl.append('\n')
                isl.append(''.join(tisl))
            isl.append('\n')
        # Drop the trailing group separator.
        return ''.join(isl[:-1])
# Accepted forms of a ``qsymmetry`` argument (see
# QSymmetryAnalyzer.from_input): an analyzer instance, a dict of its
# keyword arguments, or a bool toggling all symmetries at once.
QSymmetryInput = Union[
    'QSymmetryAnalyzer',
    dict,
    bool,
]
@dataclass
class QSymmetryAnalyzer:
    """Identifies symmetries of the k-grid, under which q is invariant.

    Parameters
    ----------
    point_group : bool
        Use point group symmetry.
    time_reversal : bool
        Use time-reversal symmetry (if applicable).
    """
    point_group: bool = True
    time_reversal: bool = True

    @staticmethod
    def from_input(qsymmetry: 'QSymmetryInput') -> 'QSymmetryAnalyzer':
        """Create an analyzer from user input.

        An existing QSymmetryAnalyzer is returned unchanged, a dict is
        passed on as keyword arguments, and any other value (normally a
        bool) toggles point-group and time-reversal symmetry together.
        """
        if isinstance(qsymmetry, QSymmetryAnalyzer):
            return qsymmetry
        if isinstance(qsymmetry, dict):
            return QSymmetryAnalyzer(**qsymmetry)
        return QSymmetryAnalyzer(point_group=qsymmetry,
                                 time_reversal=qsymmetry)

    @property
    def disabled(self):
        """True when both point-group and time-reversal symmetry are off."""
        return not (self.point_group or self.time_reversal)

    @property
    def disabled_symmetry_info(self):
        """Describe which symmetries were manually disabled ('' if none)."""
        if self.disabled:
            txt = ''
        elif not self.point_group:
            txt = 'point-group '
        elif not self.time_reversal:
            txt = 'time-reversal '
        else:
            return ''
        return txt + 'symmetry has been manually disabled'

    def analysis_info(self, symmetries):
        """Return a printable summary of the given q-symmetries."""
        dsinfo = self.disabled_symmetry_info
        qinfo = f' ({dsinfo})' if dsinfo else ''
        return '\n'.join([
            '',
            f'Symmetries of q_c{qinfo}:',
            f' Direct symmetries (Uq -> q): {symmetries.ndirect}',
            f' Indirect symmetries (TUq -> q): {symmetries.nindirect}',
            f'In total {len(symmetries)} allowed symmetries.',
            symmetries.description()])

    def analyze(self, q_c, kpoints, context):
        """Analyze symmetries and set up KPointDomainGenerator.

        Prints the symmetry summary and the generator info through
        ``context.print`` and returns the pair (symmetries, generator).
        """
        symmetries = self.analyze_symmetries(q_c, kpoints.kd)
        generator = KPointDomainGenerator(symmetries, kpoints)
        context.print(self.analysis_info(symmetries))
        context.print(generator.get_infostring())
        return symmetries, generator

    def analyze_symmetries(self, q_c, kd):
        r"""Determine allowed symmetries.

        A direct symmetry U must fulfill::

          U \mathbf{q} = \mathbf{q} + \Delta

        Under time-reversal (indirect) it must fulfill::

          -U \mathbf{q} = \mathbf{q} + \Delta

        where :math:`\Delta` is a reciprocal lattice vector.

        Of the symmetries fulfilling this, the following are discarded:
        all non-identity point-group operations if ``point_group`` is
        disabled; all indirect symmetries if ``time_reversal`` is disabled,
        if ``kd.symmetry.time_reversal`` is false or if
        ``kd.symmetry.has_inversion`` is true; and always every symmetry
        with a nonzero fractional translation (non-symmorphic).
        """
        # Map q-point for each unitary symmetry
        U_ucc = kd.symmetry.op_scc  # here s is the unitary symmetry index
        Uq_uc = np.dot(U_ucc, q_c)
        nU = len(U_ucc)
        tol = kd.symmetry.tol

        # Shifts (T)Uq - q for every extended symmetry index S: the direct
        # symmetries occupy S < nU, the indirect ones nU <= S < 2 * nU.
        shift_Sc = np.concatenate([Uq_uc - q_c, -Uq_uc - q_c])

        # S is a q-symmetry if its shift is a reciprocal lattice vector,
        # i.e. integer within tolerance.
        is_qsymmetry_S = (
            abs(shift_Sc - shift_Sc.round()) < tol
        ).all(axis=1)

        # The indices of the allowed symmetries
        S_s = is_qsymmetry_S.nonzero()[0]

        use_time_reversal = (kd.symmetry.time_reversal
                             and not kd.symmetry.has_inversion
                             and self.time_reversal)
        identity_cc = np.eye(3)

        def is_allowed(S):
            u = S % nU
            # Without point-group symmetry only the identity survives
            if not self.point_group and not (U_ucc[u] == identity_cc).all():
                return False
            # Indirect symmetries only if time-reversal is applicable
            if S >= nU and not use_time_reversal:
                return False
            # We always filter out non-symmorphic symmetries
            return not kd.symmetry.ft_sc[u].any()

        S_s = [S for S in S_s if is_allowed(S)]

        # All shifts should now be integer
        shift_sc = shift_Sc[S_s]
        assert (
            abs(shift_sc - shift_sc.round()) < tol
        ).all()
        return QSymmetries(q_c, U_ucc, S_s, shift_sc.round().astype(int))