Coverage for gpaw/point_groups/check.py: 96%
99 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-09 00:21 +0000
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-09 00:21 +0000
1"""Symmetry checking code."""
2import sys
3from typing import Any, Dict, List, Sequence, Union
5import numpy as np
6from ase import Atoms
7from numpy.linalg import det, inv, solve
8from scipy.ndimage import map_coordinates
10from gpaw.typing import Array1D, Array2D, Array3D, ArrayLike
12from . import PointGroup
# Specification of one axis of the point-group frame: a string such as
# 'x' or '-y' (see normalize()), three numbers, or None for "not given"
# (rotation_matrix() fills in one missing axis, or uses the identity if
# all three are None).
Axis = Union[str, Sequence[float], Array1D, None]
class SymmetryChecker:
    """Check atoms and grid functions against the operations of a point group."""

    def __init__(self,
                 group: Union[str, PointGroup],
                 center: ArrayLike,
                 radius: float = 2.0,
                 x: Axis = None,
                 y: Axis = None,
                 z: Axis = None,
                 grid_spacing: float = 0.2):
        """Check point-group symmetries.

        If a non-standard orientation is desired then two of
        *x*, *y*, *z* can be specified.

        Parameters
        ----------
        group:
            A PointGroup instance or a name passed to ``PointGroup(name)``.
        center:
            Origin of the symmetry operations, in the same Cartesian frame
            as the atom positions / grid passed to the check methods.
        radius:
            Radius of the sphere of sample points used by check_function().
        x, y, z:
            Orientation of the point-group axes.  Leave all three as None
            for the standard orientation, or give exactly two of them.
        grid_spacing:
            Spacing of the sample points inside the sphere.
        """
        if isinstance(group, str):
            group = PointGroup(group)
        self.group = group
        self.normalized_table = group.get_normalized_table()
        self.points = sphere(radius, grid_spacing)
        # Store as a float array so that arithmetic with it always
        # broadcasts, whatever sequence type the caller passed.
        self.center = np.asarray(center, dtype=float)
        self.grid_spacing = grid_spacing
        # Rows are the point-group x, y and z axes in Cartesian coordinates.
        self.rotation = rotation_matrix([x, y, z])

    def check_atoms(self, atoms: Atoms, tol: float = 1e-5) -> bool:
        """Check if atoms have all the symmetries.

        Every operation of the group is applied to every atom (relative to
        *center*).  The atoms have the symmetry if each image lies within
        *tol* of an atom of the same element.  Along periodic directions
        (``atoms.pbc``) differences are shifted by whole lattice vectors
        before measuring distances.

        Unit of *tol* is Angstrom.
        """
        numbers = atoms.numbers
        # Work in the frame of the point-group axes: rotate both the
        # positions and the cell so that scaled <-> Cartesian conversions
        # below use one consistent cell.
        positions = (atoms.positions - self.center).dot(self.rotation.T)
        cell = atoms.cell.dot(self.rotation.T)
        icell = inv(cell)
        for op in self.group.operations.values():
            for i, pos in enumerate(positions.dot(op.T)):
                # Scaled differences from the image of atom i to all atoms,
                # wrapped only along periodic directions.
                sdiff = (pos - positions).dot(icell)
                sdiff -= sdiff.round() * atoms.pbc
                dist2 = (sdiff.dot(cell)**2).sum(1)
                j = dist2.argmin()
                if dist2[j] > tol**2 or numbers[j] != numbers[i]:
                    return False
        return True

    def check_function(self,
                       function: Array3D,
                       grid_vectors: Union[Array2D, None] = None
                       ) -> Dict[str, Any]:
        """Check function on uniform grid.

        The function is interpolated on the sphere of sample points after
        applying each group operation.  Each result is overlapped with the
        values for the first operation.

        Parameters
        ----------
        function:
            Values on a 3-d grid.
        grid_vectors:
            Rows are the Cartesian vectors between neighbouring grid points
            along each grid axis.  Defaults to the identity matrix.

        Returns
        -------
        dict with keys ``'symmetry'`` (name with the largest character),
        ``'norm'`` (sum of function**2 times volume element over the whole
        grid), ``'overlaps'`` (one value per operation) and
        ``'characters'`` (mapping symmetry name -> value).
        """
        if grid_vectors is None:
            grid_vectors = np.eye(3)
        dv = abs(det(grid_vectors))
        norm1 = (function**2).sum() * dv
        # Cartesian row vectors @ M.T give fractional grid indices.
        M = inv(grid_vectors).T
        overlaps: List[float] = []
        for op in self.group.operations.values():
            # Express the operation in the Cartesian frame.
            op = self.rotation.T @ op @ self.rotation
            pts = (self.points @ op.T + self.center) @ M.T
            pts %= function.shape
            # NOTE(review): since SciPy 1.6, mode='wrap' treats the first
            # and last samples as overlapping, while 'grid-wrap' is the
            # plain periodic extension.  Confirm which is intended here.
            values = map_coordinates(function, pts.T, mode='wrap')
            if not overlaps:
                # Reference values from the first operation (presumably
                # the identity E -- TODO confirm ordering in PointGroup).
                values1 = values
            overlaps.append(values.dot(values1) * self.grid_spacing**3)

        # Average the overlaps over consecutive groups of operations
        # (group.nops gives the group sizes; presumably the conjugacy
        # classes) and normalize by the first overlap.
        reduced_overlaps = []
        i1 = 0
        for n in self.group.nops:
            i2 = i1 + n
            reduced_overlaps.append(sum(overlaps[i1:i2]) / n / overlaps[0])
            i1 = i2

        characters = solve(self.normalized_table.T, reduced_overlaps)
        best = self.group.symmetries[characters.argmax()]

        return {'symmetry': best,
                'norm': norm1,
                'overlaps': overlaps,
                'characters': dict(zip(self.group.symmetries, characters))}

    def check_band(self,
                   calc,
                   band: int,
                   spin: int = 0) -> Dict[str, Any]:
        """Check wave function from GPAW calculation.

        The grid vectors are the rows of the unit cell divided by the
        number of grid points along the corresponding axis.  See
        check_function() for the returned dict.
        """
        wfs = calc.get_pseudo_wave_function(band, spin=spin)
        grid_vectors = (calc.atoms.cell.T / wfs.shape).T
        return self.check_function(wfs, grid_vectors)

    def check_calculation(self,
                          calc,
                          n1: int,
                          n2: int,
                          spin: int = 0,
                          output: str = '-') -> None:
        """Check several wave functions from GPAW calculation.

        Writes a table with one row per band ``n1 <= n < n2``.  If *n2* is
        falsy (e.g. 0), all bands from *n1* onwards are checked.  *output*
        is a file name, which is overwritten, or ``'-'`` for stdout.
        """
        lines = ['band    energy     norm  normcut     best ' +
                 ''.join(f'{sym:8}' for sym in self.group.symmetries)]
        n2 = n2 or calc.get_number_of_bands()
        # Fetch all eigenvalues once instead of once per band.
        eigenvalues = calc.get_eigenvalues(spin=spin)
        for n in range(n1, n2):
            dct = self.check_band(calc, n, spin)
            best = dct['symmetry']
            norm = dct['norm']
            # Norm inside the sampling sphere.
            normcut = dct['overlaps'][0]
            eig = eigenvalues[n]
            lines.append(
                f'{n:4} {eig:9.3f} {norm:8.3f} {normcut:8.3f} {best:>8}' +
                ''.join(f'{x:8.3f}'
                        for x in dct['characters'].values()))

        text = '\n'.join(lines) + '\n'
        if output == '-':
            sys.stdout.write(text)
        else:
            # Context manager: the file is closed even if writing fails.
            with open(output, 'w') as fd:
                fd.write(text)
def sphere(radius: float, grid_spacing: float) -> Array2D:
    """Return the points of a cubic grid that lie inside a sphere.

    The grid is centred on the origin with spacing *grid_spacing*.  A point
    is kept when its distance from the origin is at most *radius*.  The
    result has one point per row, i.e. shape ``(npoints, 3)``.

    >>> points = sphere(1.1, 1.0)
    >>> points.shape
    (7, 3)
    """
    half_width = int(radius / grid_spacing) + 1
    ticks = np.linspace(-half_width, half_width,
                        2 * half_width + 1) * grid_spacing
    cube = np.stack(np.meshgrid(ticks, ticks, ticks, indexing='ij'), axis=-1)
    candidates = cube.reshape(-1, 3)
    inside = (candidates**2).sum(axis=1) <= radius**2
    return candidates[inside]
def rotation_matrix(axes: Sequence[Axis]) -> Array2D:
    """Calculate rotation matrix.

    *axes* holds the x, y and z axes of the new frame, in any form accepted
    by normalize().  If all of them are None, the identity is returned.
    Otherwise exactly one must be None.  The missing axis is the cross
    product of the other two, so the frame is right-handed.  The two given
    axes are assumed to be orthogonal; this is not checked.

    The rows of the returned 3x3 matrix are the normalized axes.

    >>> rotation_matrix(['-y', 'x', None])
    array([[ 0, -1,  0],
           [ 1,  0,  0],
           [ 0,  0,  1]])

    Raises ValueError if there are not three axes, or if not exactly one
    of them (or all of them) is None.
    """
    if all(axis is None for axis in axes):
        return np.eye(3)

    if len(axes) != 3:
        raise ValueError(f'Expected three axes, got {len(axes)}')
    missing = [i for i, axis in enumerate(axes) if axis is None]
    if len(missing) != 1:
        # Explicit check instead of assert, which disappears under -O and
        # would then silently overwrite a given axis.
        raise ValueError('Exactly two of the x, y and z axes must be given '
                         f'(got {3 - len(missing)})')
    j = missing[0]

    vectors = [normalize(axis) if axis is not None else None
               for axis in axes]
    # vectors[j - 2] and vectors[j - 1] are the cyclic predecessors of j,
    # giving z = x × y, x = y × z or y = z × x.
    vectors[j] = np.cross(vectors[j - 2], vectors[j - 1])  # type: ignore
    return np.array(vectors)
def normalize(vector: Union[str, Sequence[float], Array1D]) -> Array1D:
    """Return *vector* scaled to unit length.

    *vector* is either a sequence of three numbers, which is divided by
    its Euclidean norm, or one of the strings x, y, z, -x, -y, -z.  A
    string gives the corresponding (possibly negated) Cartesian unit
    vector as an integer array.
    """
    if not isinstance(vector, str):
        length = np.linalg.norm(vector)
        return np.array(vector) / length

    if vector[0] == '-':
        # Strip one minus sign and flip the result.
        positive = normalize(vector[1:])
        return -np.array(positive)

    unit_vectors = {'x': np.array([1, 0, 0]),
                    'y': np.array([0, 1, 0]),
                    'z': np.array([0, 0, 1])}
    return unit_vectors[vector]