Coverage for gpaw/new/symmetry.py: 87%
314 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-20 00:19 +0000
1from __future__ import annotations
3from collections import defaultdict
4from functools import cached_property
5from typing import Any, Iterable, Sequence
7import numpy as np
8from ase import Atoms
9from ase.units import Bohr
10from gpaw import debug
11from gpaw.core.domain import normalize_cell
12from gpaw.new import zips
13from gpaw.rotation import rotation
14from gpaw.symmetry import Symmetry as OldSymmetry
15from gpaw.symmetry import frac
16from gpaw.typing import Array2D, Array3D, ArrayLike1D, ArrayLike2D, ArrayLike3D
class SymmetryBrokenError(Exception):
    """Raised when atomic positions do not satisfy a symmetry operation."""
def create_symmetries_object(atoms: Atoms,
                             *,
                             setup_ids: Sequence | None = None,
                             magmoms: ArrayLike2D | None = None,
                             rotations: ArrayLike3D | None = None,
                             translations: ArrayLike2D | None = None,
                             atommaps: ArrayLike2D | None = None,
                             extra_ids: Sequence[int] | None = None,
                             tolerance: float | None = None,
                             point_group: bool = True,
                             symmorphic: bool = True,
                             _backwards_compatible=False) -> Symmetries:
    """Find symmetries from atoms object.

    Two atoms are only treated as equivalent when they agree in atomic
    number (or setup id), magnetic moment (if *magmoms* is given) and
    user-supplied id (if *extra_ids* is given).

    If *rotations* is given, no search is done: the supplied operations
    are used as they are, and atom maps are computed from the positions
    when *atommaps* is not given.  Otherwise the lattice symmetries are
    found (identity only if *point_group* is false) and pruned against
    the atomic positions.

    *tolerance* defaults to 1e-5 (Å).  With *_backwards_compatible* the
    default is 1e-7 and the cell is converted to Bohr first.

    >>> atoms = Atoms('H', cell=[1, 1, 1], pbc=True)
    >>> sym = create_symmetries_object(atoms)
    >>> len(sym)
    48
    >>> sym.rotation_scc.shape
    (48, 3, 3)
    """
    legacy = _backwards_compatible
    if tolerance is None:
        tolerance = 1e-7 if legacy else 1e-5

    cell_cv = atoms.cell.complete()
    if legacy:
        cell_cv *= 1 / Bohr

    # Integer ids combining setups, magnetic moments and user ids:
    ids = atoms.numbers if setup_ids is None else integer_ids(setup_ids)
    if magmoms is not None:
        ids = integer_ids((i, m) for i, m in zips(ids, safe_id(magmoms)))
    if extra_ids is not None:
        ids = integer_ids((i, x) for i, x in zips(ids, extra_ids))

    common = dict(tolerance=tolerance, _backwards_compatible=legacy)
    if rotations is not None:
        # User-supplied operations:
        sym = Symmetries(cell=cell_cv,
                         rotations=rotations,
                         translations=translations,
                         atommaps=atommaps,
                         **common)
        if atommaps is None:
            sym = sym.with_atom_maps(atoms.get_scaled_positions(), ids=ids)
    else:
        if point_group:
            lattice_sym = Symmetries.from_cell(cell_cv,
                                               pbc=atoms.pbc,
                                               **common)
        else:
            # Identity only:
            lattice_sym = Symmetries(cell=cell_cv, **common)
        sym = lattice_sym.analyze_positions(atoms.get_scaled_positions(),
                                            ids=ids,
                                            symmorphic=symmorphic)

    # Legacy object mirroring the new one for old code paths:
    old = OldSymmetry(ids, cell_cv, atoms.pbc, tolerance,
                      point_group,
                      time_reversal='?',
                      symmorphic=symmorphic)
    old.op_scc = sym.rotation_scc
    old.ft_sc = sym.translation_sc
    old.a_sa = sym.atommap_sa
    old.has_inversion = sym.has_inversion
    old.gcd_c = sym.gcd_c
    sym._old_symmetry = old
    return sym
class Symmetries:
    """A set of symmetry operations of a crystal.

    Operation ``s`` consists of an integer matrix ``rotation_scc[s]``
    that acts on relative (fractional) coordinates from the right, a
    fractional translation ``translation_sc[s]`` and an atom map
    ``atommap_sa[s]`` (which has zero columns until atom positions have
    been analyzed).
    """

    def __init__(self,
                 *,
                 cell: ArrayLike1D | ArrayLike2D,
                 rotations: ArrayLike3D | None = None,
                 translations: ArrayLike2D | None = None,
                 atommaps: ArrayLike2D | None = None,
                 tolerance: float | None = None,
                 _backwards_compatible=False):
        """Symmetries object.

        "Rotations" here means rotations, mirror and inversion operations.

        Units of "cell" and "tolerance" should match.

        >>> sym = Symmetries.from_cell([1, 2, 3])
        >>> sym.has_inversion
        True
        >>> len(sym)
        8
        >>> sym2 = sym.analyze_positions([[0, 0, 0], [0, 0, 0.4]], ids=[1, 2])
        >>> sym2.has_inversion
        False
        >>> len(sym2)
        4
        """
        self.cell_cv = normalize_cell(cell)
        if tolerance is None:
            tolerance = 1e-7 if _backwards_compatible else 1e-5
        self.tolerance = tolerance
        self._backwards_compatible = _backwards_compatible
        if rotations is None:
            # Identity only:
            rotations = [[[1, 0, 0], [0, 1, 0], [0, 0, 1]]]
        self.rotation_scc = np.array(rotations, dtype=int)
        # Rotations must be exact integer matrices:
        assert (self.rotation_scc == rotations).all()
        if translations is None:
            self.translation_sc = np.zeros((len(self.rotation_scc), 3))
        else:
            self.translation_sc = np.array(translations)
        if atommaps is None:
            self.atommap_sa = np.empty((len(self.rotation_scc), 0), int)
        else:
            self.atommap_sa = np.array(atommaps)
            assert self.atommap_sa.dtype == int

        # Legacy stuff:
        self.op_scc = self.rotation_scc  # old name
        self._old_symmetry: OldSymmetry

    @cached_property
    def symmorphic(self):
        """True if no operation has a non-zero translation."""
        return not self.translation_sc.any()

    @cached_property
    def has_inversion(self):
        """True if pure inversion (no translation) is one of the operations."""
        inv_cc = -np.eye(3, dtype=int)
        for r_cc, t_c in zip(self.rotation_scc, self.translation_sc):
            if (r_cc == inv_cc).all() and not t_c.any():
                return True
        return False

    @classmethod
    def from_cell(cls,
                  cell: ArrayLike1D | ArrayLike2D,
                  *,
                  pbc: ArrayLike1D = (True, True, True),
                  tolerance: float | None = None,
                  _backwards_compatible=False) -> Symmetries:
        """Create the symmetries of the lattice alone (no translations).

        A single bool/int *pbc* is applied to all three directions.
        """
        if isinstance(pbc, int):
            pbc = (pbc,) * 3
        cell_cv = normalize_cell(cell)
        if tolerance is None:
            tolerance = 1e-7 if _backwards_compatible else 1e-5
        rotation_scc = find_lattice_symmetry(cell_cv, pbc, tolerance,
                                             _backwards_compatible)
        return cls(cell=cell_cv,
                   rotations=rotation_scc,
                   tolerance=tolerance,
                   _backwards_compatible=_backwards_compatible)

    def analyze_positions(self,
                          relative_positions: ArrayLike2D,
                          ids: Sequence[int],
                          *,
                          symmorphic: bool = True) -> Symmetries:
        """Return the subset of operations satisfied by the atoms.

        *relative_positions* are fractional coordinates and *ids* label
        which atoms are equivalent.  With ``symmorphic=False``, operations
        with fractional translations are also searched for.
        """
        return prune_symmetries(
            self, np.asarray(relative_positions), ids, symmorphic)

    def with_atom_maps(self,
                       relative_positions: Array2D,
                       ids: Sequence[int]) -> Symmetries:
        """Return a copy of these symmetries with atom maps attached.

        ``atommap_sa[s, a]`` is the index of the atom that atom ``a`` is
        mapped onto by operation ``s``.

        Raises SymmetryBrokenError if an operation is not satisfied by the
        given (fractional) positions.
        """
        relpos_ac = np.asarray(relative_positions)
        atommap_sa = np.empty((len(self), len(relpos_ac)), int)
        a_ij = defaultdict(list)
        for a, atom_id in enumerate(ids):
            a_ij[atom_id].append(a)
        for s, (U_cc, t_c) in enumerate(zip(self.rotation_scc,
                                            self.translation_sc)):
            map_a = self.check_one_symmetry(relpos_ac, U_cc, t_c, a_ij)
            if map_a is None:
                # Previously this silently tried to store None in an int
                # array (TypeError); report the real problem instead:
                raise SymmetryBrokenError(
                    f'Symmetry operation number {s} is not satisfied '
                    'by the atomic positions')
            atommap_sa[s] = map_a
        return Symmetries(cell=self.cell_cv,
                          rotations=self.rotation_scc,
                          translations=self.translation_sc,
                          atommaps=atommap_sa,
                          tolerance=self.tolerance,
                          _backwards_compatible=self._backwards_compatible)

    @classmethod
    def from_atoms(cls,
                   atoms,
                   *,
                   ids: Sequence[int] | None = None,
                   symmorphic: bool = True,
                   tolerance: float | None = None):
        """Find the symmetries of an ASE ``Atoms`` object.

        Atoms with equal *ids* (default: atomic numbers) are treated as
        equivalent.
        """
        # The symmetry analysis works in fractional coordinates, so the
        # scaled positions (and the matching completed cell) are used:
        sym = cls.from_cell(atoms.cell.complete(),
                            pbc=atoms.pbc,
                            tolerance=tolerance)
        if ids is None:
            ids = atoms.numbers
        return sym.analyze_positions(atoms.get_scaled_positions(),
                                     ids=ids,
                                     symmorphic=symmorphic)

    def __len__(self):
        return len(self.rotation_scc)

    def __str__(self):
        lines = ['symmetry:',
                 f' number of symmetries: {len(self)}']
        if self.symmorphic:
            lines.append(' rotations: [')
            for rot_cc in self.rotation_scc:
                lines.append(f' {mat(rot_cc)},')
        else:
            nt = self.translation_sc.any(1).sum()
            lines.append(f' number of symmetries with translation: {nt}')
            lines.append(' rotations and translations: [')
            for rot_cc, t_c in zips(self.rotation_scc, self.translation_sc):
                a, b, c = t_c
                lines.append(f' [{mat(rot_cc)}, '
                             f'[{a:6.3f}, {b:6.3f}, {c:6.3f}]],')
        # Replace the trailing comma of the last entry by a closing bracket:
        lines[-1] = lines[-1][:-1] + ']\n'
        return '\n'.join(lines)

    def check_positions(self, fracpos_ac):
        """Raise SymmetryBrokenError if the atom maps are not satisfied.

        In backwards-compatible mode the error is measured in fractional
        coordinates, otherwise as a Cartesian distance.
        """
        for U_cc, t_c, b_a in zip(self.rotation_scc,
                                  self.translation_sc,
                                  self.atommap_sa):
            error_ac = fracpos_ac @ U_cc - t_c - fracpos_ac[b_a]
            error_ac -= error_ac.round()
            if self._backwards_compatible:
                if abs(error_ac).max() > self.tolerance:
                    raise SymmetryBrokenError
            else:
                error_av = error_ac @ self.cell_cv
                if (error_av**2).sum(1).max() > self.tolerance**2:
                    raise SymmetryBrokenError

    def symmetrize_forces(self, F0_av):
        """Symmetrize forces (average over all operations)."""
        F_av = np.zeros_like(F0_av)
        cell_cv = self.cell_cv
        icell_vc = np.linalg.inv(cell_cv)  # loop invariant
        for map_a, op_cc in zip(self.atommap_sa, self.rotation_scc):
            # Operation expressed in Cartesian coordinates:
            op_vv = icell_vc @ op_cc @ cell_cv
            for a1, a2 in enumerate(map_a):
                F_av[a2] += np.dot(F0_av[a1], op_vv)
        return F_av / len(self)

    def lcm(self) -> list[int]:
        """Find least common multiple compatible with translations."""
        return [np.lcm.reduce([frac(t, tol=1e-4)[1] for t in t_s])
                for t_s in self.translation_sc.T]

    @cached_property
    def gcd_c(self):
        # Needed for old gpaw.utilities.gpts.get_number_of_grid_points()
        # function ...
        return np.array(self.lcm())

    def check_grid(self, N_c) -> bool:
        """Check that symmetries are commensurate with grid."""
        for U_cc, t_c in zip(self.rotation_scc, self.translation_sc):
            t_c = t_c * N_c
            # Make sure all grid-points map onto another grid-point:
            if (((N_c * U_cc).T % N_c).any() or
                not np.allclose(t_c, t_c.round())):
                return False
        return True

    def check_one_symmetry(self,
                           spos_ac,
                           op_cc,
                           ft_c,
                           a_ia):
        """Checks whether atoms satisfy one given symmetry operation.

        *spos_ac* are fractional positions and *a_ia* maps each atom id to
        the list of atom indices with that id.  Returns an integer array
        ``a_a`` where ``a_a[b]`` is the atom found at
        ``spos_ac[b] @ op_cc - ft_c`` (modulo lattice vectors), or None if
        some atom has no partner.  An AssertionError is raised if an atom
        has more than one partner within the tolerance.
        """
        a_a = np.zeros(len(spos_ac), int)
        for b_a in a_ia.values():
            spos_jc = spos_ac[b_a]
            for b in b_a:
                spos_c = np.dot(spos_ac[b], op_cc)
                sdiff_jc = spos_c - spos_jc - ft_c
                sdiff_jc -= sdiff_jc.round()
                if self._backwards_compatible:
                    indices = np.where(
                        abs(sdiff_jc).max(1) < self.tolerance)[0]
                else:
                    sdiff_jv = sdiff_jc @ self.cell_cv
                    indices = np.where(
                        (sdiff_jv**2).sum(1) < self.tolerance**2)[0]
                if len(indices) == 1:
                    a = indices[0]
                    a_a[b] = b_a[a]
                else:
                    assert len(indices) == 0
                    return None

        return a_a
def find_lattice_symmetry(cell_cv, pbc_c, tol, _backwards_compatible=False):
    """Determine list of symmetry operations.

    Candidates are all 3**9 = 19683 integer 3x3 matrices with entries in
    {-1, 0, 1} (in the basis of the cell vectors).  A candidate is kept
    when it conserves the metric ``cell_cv @ cell_cv.T`` (summed absolute
    deviation at most *tol*, or *tol*\\*\\*2 in the new mode) and does not
    mix axes that differ in periodicity.
    """
    candidates_scc = (1 - np.indices([3] * 9)).reshape(
        (3, 3, 3**9)).transpose((2, 0, 1))

    metric_cc = cell_cv @ cell_cv.T
    transformed_scc = np.einsum('sij, jk, slk -> sil',
                                candidates_scc, metric_cc, candidates_scc,
                                optimize=True)
    deviation_s = abs(transformed_scc - metric_cc).sum(2).sum(1)
    threshold = tol if _backwards_compatible else tol**2
    keep_s = deviation_s <= threshold

    # Entries coupling a periodic with a non-periodic axis must vanish:
    mixed_cc = np.logical_xor.outer(pbc_c, pbc_c)
    keep_s &= ~candidates_scc[:, mixed_cc].any(axis=1)
    return candidates_scc[keep_s]
def prune_symmetries(sym: Symmetries,
                     relpos_ac: Array2D,
                     id_a: Sequence[int],
                     symmorphic: bool = True) -> Symmetries:
    """Remove symmetries that are not satisfied by the atoms.

    *relpos_ac* are fractional positions and *id_a* labels equivalent
    atoms.  Operations that work without translation come first (in the
    order of ``sym.rotation_scc``), followed by any operations found with
    a fractional translation (only searched for when *symmorphic* is
    false and the structure has no pure-translation symmetry).
    """
    if len(relpos_ac) == 0:
        return sym

    # Atom indices grouped by id - one group per combination of atomic
    # number, setup type, magnetic moment and basis set:
    atoms_by_id = defaultdict(list)
    for a, atom_id in enumerate(id_a):
        atoms_by_id[atom_id].append(a)

    # Reference group: the species of the first atom.
    ref_a = atoms_by_id[id_a[0]]

    def atom_map(op_cc, ft_c):
        return sym.check_one_symmetry(relpos_ac, op_cc, ft_c, atoms_by_id)

    if not symmorphic:
        # A pure translation that maps the structure onto itself means a
        # supercell; fractional translations are then disabled.
        identity_cc = np.identity(3, int)
        shift_sc = relpos_ac[ref_a[1:]] - relpos_ac[ref_a[0]]
        shift_sc -= np.rint(shift_sc)
        if any(atom_map(identity_cc, shift_c) is not None
               for shift_c in shift_sc):
            symmorphic = True

    plain_ops = []
    translated_ops = []
    for op_cc in sym.rotation_scc:
        map_a = atom_map(op_cc, [0, 0, 0])
        if map_a is not None:
            plain_ops.append((op_cc, [0, 0, 0], map_a))
            continue
        if symmorphic:
            continue
        # Candidate translations: rotated reference atoms relative to the
        # first (unrotated) reference atom.
        rotated_ac = np.dot(relpos_ac, op_cc)
        ft_jc = rotated_ac[ref_a] - relpos_ac[ref_a[0]]
        ft_jc -= np.rint(ft_jc)
        for ft_c in ft_jc:
            map_a = atom_map(op_cc, ft_c)
            if map_a is not None:
                translated_ops.append((op_cc, ft_c, map_a))

    kept = plain_ops + translated_ops
    pruned = Symmetries(cell=sym.cell_cv,
                        rotations=[op for op, _, _ in kept],
                        translations=[ft for _, ft, _ in kept],
                        atommaps=[amap for _, _, amap in kept],
                        tolerance=sym.tolerance,
                        _backwards_compatible=sym._backwards_compatible)
    if debug:
        pruned.check_positions(relpos_ac)
    return pruned
class SymmetrizationPlan:
    """Symmetrize atomic matrices (e.g. density matrices) over symmetries.

    ``l_aj[a]`` lists the angular momenta ``l`` of atom ``a``; each ``l``
    contributes a block of size ``2 * l + 1`` to the atomic matrices.
    """

    def __init__(self,
                 symmetries: Symmetries,
                 l_aj):
        self.symmetries = symmetries
        self.l_aj = l_aj
        # Rotations expressed in Cartesian coordinates:
        self.rotation_svv = np.einsum('vc, scd, dw -> svw',
                                      np.linalg.inv(symmetries.cell_cv),
                                      symmetries.rotation_scc,
                                      symmetries.cell_cv)
        # ``default=-1`` on the inner max: an atom without any l-channels
        # (empty l_j) must not crash the setup.
        lmax = max((max(l_j, default=-1) for l_j in l_aj), default=-1)
        # rotation_lsmm[l][s]: presumably the (2l+1)x(2l+1) matrix that
        # gpaw.rotation.rotation() gives for angular momentum l and
        # operation s (TODO confirm conventions in gpaw.rotation).
        self.rotation_lsmm = [
            np.array([rotation(l, r_vv) for r_vv in self.rotation_svv])
            for l in range(lmax + 1)]
        # Cache of block-diagonal rotation matrices, keyed by the l-list
        # AND the array module, so numpy and GPU arrays never get mixed:
        self._rotations: dict[tuple[tuple[int, ...], Any], Array3D] = {}

    def rotations(self, l_j, xp=np):
        """Return block-diagonal rotation matrices for all operations.

        The result has shape ``(nsym, ni, ni)`` with
        ``ni = sum(2 * l + 1 for l in l_j)`` and is an *xp* array.
        """
        key = (tuple(l_j), xp)
        rotation_sii = self._rotations.get(key)
        if rotation_sii is None:
            ni = sum(2 * l + 1 for l in l_j)
            rotation_sii = np.zeros((len(self.symmetries), ni, ni))
            i1 = 0
            for l in l_j:
                i2 = i1 + 2 * l + 1
                rotation_sii[:, i1:i2, i1:i2] = self.rotation_lsmm[l]
                i1 = i2
            rotation_sii = xp.asarray(rotation_sii)
            self._rotations[key] = rotation_sii
        return rotation_sii

    def apply_distributed(self, D_asii, dist_D_asii):
        """Write symmetrized matrices into *dist_D_asii*.

        For each atom ``a1`` owned by *dist_D_asii*, the result is the
        average over operations of ``R D_asii[a2] R^T``, where ``a2`` is
        the atom that ``a1`` is mapped onto.  Note that the leading index
        of ``D_sii`` runs over spins, while that of ``rotation_sii`` runs
        over symmetry operations.
        """
        for a1, D_sii in dist_D_asii.items():
            D_sii[:] = 0.0
            rotation_sii = self.rotations(self.l_aj[a1])
            for a2, rotation_ii in zips(self.symmetries.atommap_sa[:, a1],
                                        rotation_sii):
                D_sii += np.einsum('ij, sjk, lk -> sil',
                                   rotation_ii, D_asii[a2], rotation_ii)
        dist_D_asii.data *= 1.0 / len(self.symmetries)
class GPUSymmetrizationPlan(SymmetrizationPlan):
    """Symmetrization via one precomputed matrix per orbit of atoms.

    For every orbit, all operations (including the 1/nsym averaging) are
    folded into a single matrix ``S_aZZ[a]`` (``a`` = representative
    atom) that acts on the entries of all atoms in the orbit.  For each
    atom, these entries are the index range ``start:start + ni**2``, where
    ``start`` comes from ``layout.myindices``.  ``apply`` is then one
    matrix-vector product per orbit and spin.
    """

    def __init__(self,
                 symmetries: Symmetries,
                 l_aj,
                 layout):
        super().__init__(symmetries, l_aj)

        xp = layout.xp
        a_sa = symmetries.atommap_sa

        ns = a_sa.shape[0]  # Number of symmetries
        na = a_sa.shape[1]  # Number of atoms

        if xp is np:
            # Import the submodule explicitly: a bare ``import scipy`` does
            # not make ``scipy.sparse`` available in all SciPy versions.
            import scipy.sparse as sparse
        else:
            from gpaw.gpu import cupyx
            sparse = cupyx.scipy.sparse

        # Find orbits, i.e. the sets {a_sa[s, a]: s in S} of atoms that
        # are mapped onto each other by the symmetry operations.
        cosets = {frozenset(a_sa[:, a]) for a in range(na)}

        S_aZZ = {}
        work = []
        for coset in map(list, cosets):
            nA = len(coset)  # Number of atoms in this orbit
            a = coset[0]  # Representative atom for this orbit
            # Position of each atom inside the orbit (avoids list.index):
            position = {atom: i for i, atom in enumerate(coset)}

            # The atomic density matrices transform as
            #   rho'_ii = R_sii rho_ii R^T_sii,
            # which equals vec(rho'_ii) = (R^s_ii ⊗ R^s_ii) vec(rho_ii).
            # Here we do the Kronecker product for each of the symmetry
            # transformations (and divide by ns to get the average).
            R_sii = xp.asarray(self.rotations(l_aj[a], xp))
            i2 = R_sii.shape[1]**2
            R_sPP = xp.einsum('sab, scd -> sacbd', R_sii, R_sii)
            R_sPP = R_sPP.reshape((ns, i2, i2)) / ns

            S_ZZ = xp.zeros((nA * i2,) * 2)

            # For each orbit, the symmetrization operation is represented
            # by a full matrix operating on a subset of indices of the
            # full array.  Row block = target atom a1, column block =
            # source atom a2 that a1 is mapped onto by operation s.
            for loca1, a1 in enumerate(coset):
                Z1 = loca1 * i2
                Z2 = Z1 + i2
                for s, a2 in enumerate(a_sa[:, a1]):
                    Z3 = position[a2] * i2
                    Z4 = Z3 + i2
                    S_ZZ[Z1:Z2, Z3:Z4] += R_sPP[s]
            # Utilize sparse matrices if sizes get out of hand.
            # Limit is hard-coded to 100 MiB per orbit.
            if S_ZZ.nbytes > 100 * 1024**2:
                S_ZZ = sparse.csr_matrix(S_ZZ)
            S_aZZ[a] = S_ZZ

            indices = []
            for a1 in coset:
                a1_, start, end = layout.myindices[a1]
                # When parallelization is done, this needs to be rewritten
                assert a1_ == a1
                indices.extend(range(start, start + i2))
            work.append((a, xp.array(indices)))

        self.work = work
        self.S_aZZ = S_aZZ
        self.xp = xp

    def apply(self, source, target):
        """Symmetrize *source* into *target*, spin by spin.

        Both are indexed as ``[spin, index]``; every orbit overwrites its
        own indices of *target*.
        """
        total = 0
        for a, ind in self.work:
            S_ZZ = self.S_aZZ[a]
            for spin in range(len(source)):
                total += len(ind)
                target[spin, ind] = S_ZZ @ source[spin, ind]
        # Sanity check (by counting) that the orbits cover all indices;
        # integer comparison avoids dividing by the number of spins.
        assert total == len(source) * source.shape[1]
def mat(rot_cc) -> str:
    """Format a 3x3 integer matrix as a nested-list string.

    Each entry is right-aligned in a field of width 2.

    >>> mat([[-1, 0, 0], [0, 1, 0], [0, 0, 1]])
    '[[-1,  0,  0], [ 0,  1,  0], [ 0,  0,  1]]'
    """
    rows = []
    for rot_c in rot_cc:
        rows.append(', '.join(format(r, '2') for r in rot_c))
    return '[[' + '], ['.join(rows) + ']]'
def integer_ids(ids: Iterable) -> list[int]:
    """Convert arbitrary hashable ids to consecutive int ids.

    Ids are numbered 0, 1, 2, ... in order of first appearance; a repeated
    id gets the number assigned at its first appearance.

    >>> integer_ids([(1, 'a'), (12, 'b'), (1, 'a')])
    [0, 1, 0]
    """
    numbering: dict[Any, int] = {}
    # len(numbering) is evaluated before setdefault inserts a new key:
    return [numbering.setdefault(key, len(numbering)) for key in ids]
def safe_id(magmom_av, tolerance=1e-3):
    """Convert magnetic moments to integer id's.

    Magnetic moments may carry small rounding errors, so they are not
    compared exactly.  Atom ``a`` gets the index of the first earlier atom
    whose moment differs from its own by less than *tolerance* (Euclidean
    norm of the difference); if there is no such atom, it gets its own
    index ``a``.

    Collinear (one number per atom) and non-collinear (one vector per
    atom) moments are accepted, as arrays or nested lists.

    >>> safe_id([1.01, 0.99, 0.5], tolerance=0.025)
    [0, 0, 2]
    """
    # Nested Python lists do not support element-wise subtraction:
    magmom_av = np.asarray(magmom_av)
    id_a = []
    for a, magmom_v in enumerate(magmom_av):
        quantized = a
        for a2 in range(a):
            if np.linalg.norm(magmom_av[a2] - magmom_v) < tolerance:
                quantized = a2
                break
        id_a.append(quantized)
    return id_a
591 return id_a