Coverage for gpaw/point_groups/check.py: 96%

99 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-09 00:21 +0000

1"""Symmetry checking code.""" 

2import sys 

3from typing import Any, Dict, List, Sequence, Union 

4 

5import numpy as np 

6from ase import Atoms 

7from numpy.linalg import det, inv, solve 

8from scipy.ndimage import map_coordinates 

9 

10from gpaw.typing import Array1D, Array2D, Array3D, ArrayLike 

11 

12from . import PointGroup 

13 

14Axis = Union[str, Sequence[float], Array1D, None] 

15 

16 

class SymmetryChecker:
    """Check point-group symmetries of atoms and of functions on grids.

    The point group is placed at *center* and optionally re-oriented via
    the *x*, *y*, *z* axes.  Grid functions are sampled on a sphere of
    points (radius *radius*, spacing *grid_spacing*) around the center.
    """

    def __init__(self,
                 group: Union[str, PointGroup],
                 center: ArrayLike,
                 radius: float = 2.0,
                 x: Axis = None,
                 y: Axis = None,
                 z: Axis = None,
                 grid_spacing: float = 0.2):
        """Check point-group symmetries.

        If a non-standard orientation is desired then two of
        *x*, *y*, *z* can be specified.
        """
        if isinstance(group, str):
            group = PointGroup(group)
        self.group = group
        self.normalized_table = group.get_normalized_table()
        self.points = sphere(radius, grid_spacing)
        self.center = center
        self.grid_spacing = grid_spacing
        self.rotation = rotation_matrix([x, y, z])

    def check_atoms(self, atoms: Atoms, tol: float = 1e-5) -> bool:
        """Check if atoms have all the symmetries.

        Returns True when every operation of the group maps each atom onto
        an atom with the same atomic number, within a distance *tol*.
        Along periodic directions the minimum-image distance is used.

        Unit of *tol* is Angstrom.
        """
        numbers = atoms.numbers
        pbc = atoms.pbc
        periodic = pbc.any()
        positions = (atoms.positions - self.center) @ self.rotation.T
        if periodic:
            # Only needed for minimum-image wrapping.  A purely finite
            # system may have a zero (singular) cell, so we do not invert
            # it unless some direction is periodic.
            cell = np.asarray(atoms.cell) @ self.rotation.T
            icell = inv(cell)
        for op in self.group.operations.values():
            for i, pos in enumerate(positions @ op.T):
                diff = pos - positions
                if periodic:
                    sdiff = diff @ icell
                    sdiff -= sdiff.round() * pbc
                    diff = sdiff @ cell
                dist2 = (diff**2).sum(1)
                j = dist2.argmin()
                # The image of atom i must coincide with an atom j of the
                # same element.
                if dist2[j] > tol**2 or numbers[j] != numbers[i]:
                    return False
        return True

    def check_function(self,
                       function: Array3D,
                       grid_vectors: Array2D = None) -> Dict[str, Any]:
        """Check function on uniform grid.

        *function* is treated as periodic on its grid.  The rows of
        *grid_vectors* are the grid-step vectors (identity if omitted), so
        grid index ``i`` sits at Cartesian position ``i @ grid_vectors``.

        Returns a dict with:

        * ``'symmetry'``: the entry of ``group.symmetries`` with the largest
          coefficient,
        * ``'norm'``: sum of ``function**2`` over the whole grid times the
          grid-cell volume,
        * ``'overlaps'``: one overlap per group operation, each taken with
          the values for the first operation,
        * ``'characters'``: coefficients obtained by solving
          ``normalized_table.T @ c = reduced_overlaps``, keyed by symmetry.
        """
        if grid_vectors is None:
            grid_vectors = np.eye(3)
        dv = abs(det(grid_vectors))
        norm1 = (function**2).sum() * dv
        # Maps Cartesian row vectors to fractional grid indices.
        M = inv(grid_vectors).T
        overlaps: List[float] = []
        for op in self.group.operations.values():
            # Express the operation in the lab frame.
            op = self.rotation.T @ op @ self.rotation
            pts = (self.points @ op.T + self.center) @ M.T
            pts %= function.shape
            # 'grid-wrap' is periodic with period function.shape.  Plain
            # 'wrap' makes the first and last samples coincide, i.e. it
            # has period n - 1, which is wrong for a periodic grid.
            values = map_coordinates(function, pts.T, mode='grid-wrap')
            if not overlaps:
                # NOTE(review): presumably the first operation is the
                # identity, so overlaps[0] is the norm inside the sphere.
                values1 = values
            overlaps.append(values.dot(values1) * self.grid_spacing**3)

        # Average the overlaps over consecutive runs of group.nops
        # operations (presumably one run per class) and normalize.
        reduced_overlaps = []
        i1 = 0
        for n in self.group.nops:
            i2 = i1 + n
            reduced_overlaps.append(sum(overlaps[i1:i2]) / n / overlaps[0])
            i1 = i2

        characters = solve(self.normalized_table.T, reduced_overlaps)
        best = self.group.symmetries[characters.argmax()]

        return {'symmetry': best,
                'norm': norm1,
                'overlaps': overlaps,
                'characters': {symmetry: value
                               for symmetry, value
                               in zip(self.group.symmetries, characters)}}

    def check_band(self,
                   calc,
                   band: int,
                   spin: int = 0) -> Dict[str, Any]:
        """Check wave function from GPAW calculation.

        The grid-step vectors are the cell vectors divided by the number
        of grid points along each axis.  See check_function() for the
        returned dict.
        """
        wfs = calc.get_pseudo_wave_function(band, spin=spin)
        grid_vectors = (calc.atoms.cell.T / wfs.shape).T
        return self.check_function(wfs, grid_vectors)

    def check_calculation(self,
                          calc,
                          n1: int,
                          n2: int,
                          spin: int = 0,
                          output: str = '-') -> None:
        """Check several wave functions from GPAW calculation.

        Writes one table row per band in ``range(n1, n2)`` (all remaining
        bands if *n2* is falsy) to *output*, or to stdout if *output* is
        ``'-'``.
        """
        header = (f'{"band":>4} {"energy":>9} {"norm":>8} {"normcut":>8} '
                  f'{"best":>8}' +
                  ''.join(f'{sym:>8}' for sym in self.group.symmetries))
        lines = [header]
        n2 = n2 or calc.get_number_of_bands()
        eigenvalues = calc.get_eigenvalues(spin=spin)
        for n in range(n1, n2):
            dct = self.check_band(calc, n, spin)
            best = dct['symmetry']
            norm = dct['norm']
            normcut = dct['overlaps'][0]
            eig = eigenvalues[n]
            lines.append(
                f'{n:4} {eig:9.3f} {norm:8.3f} {normcut:8.3f} {best:>8}' +
                ''.join(f'{x:8.3f}'
                        for x in dct['characters'].values()))

        text = '\n'.join(lines) + '\n'
        if output == '-':
            sys.stdout.write(text)
        else:
            # The context manager closes the file even if writing fails.
            with open(output, 'w') as fd:
                fd.write(text)

129 

130 

def sphere(radius: float, grid_spacing: float) -> Array2D:
    """Return the points of a cubic grid that lie inside a sphere.

    The grid has spacing *grid_spacing* and contains the origin.  A point
    is kept when its distance from the origin is at most *radius*.  The
    result has one row per point and three columns.

    >>> points = sphere(1.1, 1.0)
    >>> points.shape
    (7, 3)
    """
    # One extra layer so the 1D axis always reaches past the radius.
    half_width = int(radius / grid_spacing) + 1
    axis = np.linspace(-half_width, half_width,
                       2 * half_width + 1) * grid_spacing
    xs, ys, zs = np.meshgrid(axis, axis, axis, indexing='ij')
    cube = np.stack([xs.ravel(), ys.ravel(), zs.ravel()], axis=1)
    inside = (cube**2).sum(axis=1) <= radius**2
    return cube[inside]

143 

144 

def rotation_matrix(axes: Sequence[Axis]) -> Array2D:
    """Calculate rotation matrix.

    *axes* holds the new x, y and z directions, in that order.  Either all
    three are None, which gives the identity, or exactly one is None.  In
    the latter case the missing axis is the cross product of the other two
    in cyclic order (x = y × z, y = z × x, z = x × y).  The rows of the
    returned 3x3 matrix are the normalized axes.  The two given axes are
    assumed to be orthogonal; this is not checked.

    Raises ValueError if *axes* does not have three entries, or if it has
    anything other than zero or exactly one None among them.

    >>> rotation_matrix(['-y', 'x', None])
    array([[ 0, -1,  0],
           [ 1,  0,  0],
           [ 0,  0,  1]])
    """
    axes = list(axes)
    # Explicit exceptions rather than assert: asserts vanish under -O,
    # which would silently overwrite a given z axis.
    if len(axes) != 3:
        raise ValueError(f'Expected three axes (x, y, z), got {len(axes)}')
    missing = [i for i, axis in enumerate(axes) if axis is None]
    if len(missing) == 3:
        return np.eye(3)
    if len(missing) != 1:
        raise ValueError('Exactly two of the x, y, z axes must be given '
                         f'(got {3 - len(missing)})')

    j = missing[0]
    vectors = [None if axis is None else normalize(axis) for axis in axes]
    # Negative indices make the completion cyclic: j=0 -> (1, 2),
    # j=1 -> (2, 0), j=2 -> (0, 1).
    vectors[j] = np.cross(vectors[j - 2], vectors[j - 1])  # type: ignore

    return np.array(vectors)

168 

169 

def normalize(vector: Union[str, Sequence[float], Array1D]) -> Array1D:
    """Normalize a vector.

    The *vector* must be a sequence of three numbers or one of the following
    strings: x, y, z, -x, -y, -z.  Strings give integer unit vectors;
    numeric input is divided by its Euclidean norm.

    Raises ValueError for an unrecognized axis string or a zero-length
    vector.  A zero-length vector would otherwise produce NaNs.
    """
    if isinstance(vector, str):
        unit_vectors = {'x': np.array([1, 0, 0]),
                        'y': np.array([0, 1, 0]),
                        'z': np.array([0, 0, 1])}
        negate = vector.startswith('-')
        name = vector[1:] if negate else vector
        if name not in unit_vectors:
            raise ValueError(f'Unknown axis {vector!r}: expected one of '
                             'x, y, z, -x, -y, -z')
        unit = unit_vectors[name]
        return -unit if negate else unit

    array = np.array(vector)
    length = np.linalg.norm(array)
    if length == 0:
        raise ValueError('Cannot normalize a zero-length vector')
    return array / length