Coverage for gpaw/new/symmetry.py: 87%

314 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-20 00:19 +0000

1from __future__ import annotations 

2 

3from collections import defaultdict 

4from functools import cached_property 

5from typing import Any, Iterable, Sequence 

6 

7import numpy as np 

8from ase import Atoms 

9from ase.units import Bohr 

10from gpaw import debug 

11from gpaw.core.domain import normalize_cell 

12from gpaw.new import zips 

13from gpaw.rotation import rotation 

14from gpaw.symmetry import Symmetry as OldSymmetry 

15from gpaw.symmetry import frac 

16from gpaw.typing import Array2D, Array3D, ArrayLike1D, ArrayLike2D, ArrayLike3D 

17 

18 

class SymmetryBrokenError(Exception):
    """Broken-symmetry error.

    Raised when atomic positions do not obey a symmetry operation that
    they were expected to satisfy.
    """

21 

22 

def create_symmetries_object(atoms: Atoms,
                             *,
                             setup_ids: Sequence | None = None,
                             magmoms: ArrayLike2D | None = None,
                             rotations: ArrayLike3D | None = None,
                             translations: ArrayLike2D | None = None,
                             atommaps: ArrayLike2D | None = None,
                             extra_ids: Sequence[int] | None = None,
                             tolerance: float | None = None,  # Å
                             point_group: bool = True,
                             symmorphic: bool = True,
                             _backwards_compatible=False) -> Symmetries:
    """Find symmetries from atoms object.

    Atoms are distinguished by an integer id built from their setup ids
    (atomic numbers when *setup_ids* is None), their magnetic moments
    (via ``safe_id``) and any user-supplied *extra_ids*.

    If *rotations* is None, the lattice point group is searched (only the
    identity when *point_group* is False) and pruned against the atomic
    positions; fractional translations are only searched when *symmorphic*
    is False.  Otherwise the given rotations/translations are used as-is,
    and atom maps are computed when *atommaps* is not supplied.

    *tolerance* defaults to 1e-5 (1e-7 when *_backwards_compatible*, in
    which mode the cell is converted to Bohr).

    A legacy ``gpaw.symmetry.Symmetry`` object mirroring the result is
    attached as ``_old_symmetry``.

    >>> atoms = Atoms('H', cell=[1, 1, 1], pbc=True)
    >>> sym = create_symmetries_object(atoms)
    >>> len(sym)
    48
    >>> sym.rotation_scc.shape
    (48, 3, 3)
    """
    if tolerance is None:
        tolerance = 1e-7 if _backwards_compatible else 1e-5

    cell_cv = atoms.cell.complete()
    if _backwards_compatible:
        # The old code worked in atomic units:
        cell_cv *= 1 / Bohr

    # Integer atom-ids combining setups, magnetic moments and extra ids:
    ids = atoms.numbers if setup_ids is None else integer_ids(setup_ids)
    if magmoms is not None:
        ids = integer_ids((old, m) for old, m in zips(ids, safe_id(magmoms)))
    if extra_ids is not None:
        ids = integer_ids((old, x) for old, x in zips(ids, extra_ids))

    common = dict(tolerance=tolerance,
                  _backwards_compatible=_backwards_compatible)

    if rotations is not None:
        # User-supplied operations:
        sym = Symmetries(cell=cell_cv,
                         rotations=rotations,
                         translations=translations,
                         atommaps=atommaps,
                         **common)
        if atommaps is None:
            sym = sym.with_atom_maps(atoms.get_scaled_positions(), ids=ids)
    else:
        # Search: start from the lattice (or identity only) and prune:
        if point_group:
            candidates = Symmetries.from_cell(cell_cv,
                                              pbc=atoms.pbc,
                                              **common)
        else:
            candidates = Symmetries(cell=cell_cv, **common)
        sym = candidates.analyze_positions(atoms.get_scaled_positions(),
                                           ids=ids,
                                           symmorphic=symmorphic)

    # Legacy object kept in sync with the new one:
    legacy = OldSymmetry(ids, cell_cv, atoms.pbc, tolerance,
                         point_group,
                         time_reversal='?',
                         symmorphic=symmorphic)
    sym._old_symmetry = legacy
    legacy.op_scc = sym.rotation_scc
    legacy.ft_sc = sym.translation_sc
    legacy.a_sa = sym.atommap_sa
    legacy.has_inversion = sym.has_inversion
    legacy.gcd_c = sym.gcd_c

    return sym

103 

104 

class Symmetries:
    def __init__(self,
                 *,
                 cell: ArrayLike1D | ArrayLike2D,
                 rotations: ArrayLike3D | None = None,
                 translations: ArrayLike2D | None = None,
                 atommaps: ArrayLike2D | None = None,
                 tolerance: float | None = None,
                 _backwards_compatible=False):
        """Symmetries object.

        "Rotations" here means rotations, mirror and inversion operations.

        Units of "cell" and "tolerance" should match.

        rotations:
            Integer 3x3 matrices in the basis of the cell vectors
            (identity only if None).
        translations:
            Fractional translations, one per rotation (zeros if None).
            Always stored as floats.
        atommaps:
            For each operation, the index of the atom each atom maps to
            (an empty (nsym, 0) array if None).

        >>> sym = Symmetries.from_cell([1, 2, 3])
        >>> sym.has_inversion
        True
        >>> len(sym)
        8
        >>> sym2 = sym.analyze_positions([[0, 0, 0], [0, 0, 0.4]], ids=[1, 2])
        >>> sym2.has_inversion
        False
        >>> len(sym2)
        4
        """
        self.cell_cv = normalize_cell(cell)
        if tolerance is None:
            tolerance = 1e-7 if _backwards_compatible else 1e-5
        self.tolerance = tolerance
        self._backwards_compatible = _backwards_compatible
        if rotations is None:
            rotations = [[[1, 0, 0], [0, 1, 0], [0, 0, 1]]]
        self.rotation_scc = np.array(rotations, dtype=int)
        # Rotations must be integer matrices (no silent truncation):
        assert (self.rotation_scc == rotations).all()
        if translations is None:
            self.translation_sc = np.zeros((len(self.rotation_scc), 3))
        else:
            # Fractional translations are real numbers; keep the dtype
            # consistent with the default branch above even when the
            # caller passes integer zeros:
            self.translation_sc = np.array(translations, dtype=float)
        if atommaps is None:
            self.atommap_sa = np.empty((len(self.rotation_scc), 0), int)
        else:
            self.atommap_sa = np.array(atommaps)
            assert self.atommap_sa.dtype == int

        # Legacy stuff:
        self.op_scc = self.rotation_scc  # old name
        self._old_symmetry: OldSymmetry

    @cached_property
    def symmorphic(self):
        """True if no operation carries a non-zero translation."""
        return not self.translation_sc.any()

    @cached_property
    def has_inversion(self):
        """True if pure inversion (-I, zero translation) is included."""
        inv_cc = -np.eye(3, dtype=int)
        for r_cc, t_c in zip(self.rotation_scc, self.translation_sc):
            if (r_cc == inv_cc).all() and not t_c.any():
                return True
        return False

    @classmethod
    def from_cell(cls,
                  cell: ArrayLike1D | ArrayLike2D,
                  *,
                  pbc: ArrayLike1D = (True, True, True),
                  tolerance: float | None = None,
                  _backwards_compatible=False) -> Symmetries:
        """Create object with all operations of the lattice itself.

        A single bool/int *pbc* is applied to all three axes.
        """
        if isinstance(pbc, int):
            pbc = (pbc,) * 3
        cell_cv = normalize_cell(cell)
        if tolerance is None:
            tolerance = 1e-7 if _backwards_compatible else 1e-5
        rotation_scc = find_lattice_symmetry(cell_cv, pbc, tolerance,
                                             _backwards_compatible)
        return cls(cell=cell_cv,
                   rotations=rotation_scc,
                   tolerance=tolerance,
                   _backwards_compatible=_backwards_compatible)

    def analyze_positions(self,
                          relative_positions: ArrayLike2D,
                          ids: Sequence[int],
                          *,
                          symmorphic: bool = True) -> Symmetries:
        """Return new object with only the operations the atoms obey."""
        return prune_symmetries(
            self, np.asarray(relative_positions), ids, symmorphic)

    def with_atom_maps(self,
                       relative_positions: Array2D,
                       ids: Sequence[int]) -> Symmetries:
        """Return copy with atom maps computed for the given atoms.

        Raises SymmetryBrokenError if one of the operations is not
        satisfied by the atoms.
        """
        relative_positions = np.asarray(relative_positions)
        atommap_sa = np.empty((len(self), len(relative_positions)), int)
        a_ij = defaultdict(list)
        for a, id in enumerate(ids):
            a_ij[id].append(a)
        for s, (U_cc, t_c) in enumerate(zip(self.rotation_scc,
                                             self.translation_sc)):
            map_a = self.check_one_symmetry(relative_positions,
                                            U_cc, t_c, a_ij)
            if map_a is None:
                raise SymmetryBrokenError(
                    f'Symmetry operation number {s} is not satisfied '
                    'by the atoms')
            atommap_sa[s] = map_a
        return Symmetries(cell=self.cell_cv,
                          rotations=self.rotation_scc,
                          translations=self.translation_sc,
                          atommaps=atommap_sa,
                          tolerance=self.tolerance,
                          _backwards_compatible=self._backwards_compatible)

    @classmethod
    def from_atoms(cls,
                   atoms,
                   *,
                   ids: Sequence[int] | None = None,
                   symmorphic: bool = True,
                   tolerance: float | None = None):
        """Find the symmetries of an ASE Atoms object.

        Atoms are told apart by *ids* (atomic numbers if None).
        """
        sym = cls.from_cell(atoms.cell,
                            pbc=atoms.pbc,
                            tolerance=tolerance)
        if ids is None:
            ids = atoms.numbers
        # analyze_positions() works with positions relative to the cell
        # vectors, not Cartesian positions:
        return sym.analyze_positions(atoms.get_scaled_positions(),
                                     ids=ids,
                                     symmorphic=symmorphic)

    def __len__(self):
        return len(self.rotation_scc)

    def __str__(self):
        lines = ['symmetry:',
                 f' number of symmetries: {len(self)}']
        if self.symmorphic:
            lines.append(' rotations: [')
            for rot_cc in self.rotation_scc:
                lines.append(f' {mat(rot_cc)},')
        else:
            nt = self.translation_sc.any(1).sum()
            lines.append(f' number of symmetries with translation: {nt}')
            lines.append(' rotations and translations: [')
            for rot_cc, t_c in zips(self.rotation_scc, self.translation_sc):
                a, b, c = t_c
                lines.append(f' [{mat(rot_cc)}, '
                             f'[{a:6.3f}, {b:6.3f}, {c:6.3f}]],')
        # Replace the trailing comma of the last entry with the closing
        # bracket:
        lines[-1] = lines[-1][:-1] + ']\n'
        return '\n'.join(lines)

    def check_positions(self, fracpos_ac):
        """Raise SymmetryBrokenError if the atoms violate an operation.

        For each operation, r_a @ U - t must equal the position of the
        mapped atom modulo lattice vectors, within the tolerance.  The
        error is measured per fractional component in backwards-compatible
        mode and as a Cartesian distance otherwise.
        """
        for U_cc, t_c, b_a in zip(self.rotation_scc,
                                  self.translation_sc,
                                  self.atommap_sa):
            error_ac = fracpos_ac @ U_cc - t_c - fracpos_ac[b_a]
            error_ac -= error_ac.round()
            if self._backwards_compatible:
                if abs(error_ac).max() > self.tolerance:
                    raise SymmetryBrokenError
            else:
                error_av = error_ac @ self.cell_cv
                if (error_av**2).sum(1).max() > self.tolerance**2:
                    raise SymmetryBrokenError

    def symmetrize_forces(self, F0_av):
        """Symmetrize forces.

        Averages, over all operations, the forces rotated with the
        Cartesian form C^-1 U C of each rotation and moved to the mapped
        atom.
        """
        F_av = np.zeros_like(F0_av)
        for map_a, op_cc in zip(self.atommap_sa, self.rotation_scc):
            op_vv = np.linalg.inv(self.cell_cv) @ op_cc @ self.cell_cv
            for a1, a2 in enumerate(map_a):
                F_av[a2] += np.dot(F0_av[a1], op_vv)
        return F_av / len(self)

    def lcm(self) -> list[int]:
        """Find least common multiple compatible with translations.

        For each axis: the lcm of the second values returned by
        ``frac(t, tol=1e-4)`` (presumably the denominators of rational
        approximations of the translations along that axis).
        """
        return [np.lcm.reduce([frac(t, tol=1e-4)[1] for t in t_s])
                for t_s in self.translation_sc.T]

    @cached_property
    def gcd_c(self):
        # Needed for old gpaw.utilities.gpts.get_number_of_grid_points()
        # function ...
        return np.array(self.lcm())

    def check_grid(self, N_c) -> bool:
        """Check that symmetries are commensurate with grid."""
        for U_cc, t_c in zip(self.rotation_scc, self.translation_sc):
            t_c = t_c * N_c
            # Make sure all grid-points map onto another grid-point:
            if (((N_c * U_cc).T % N_c).any() or
                not np.allclose(t_c, t_c.round())):
                return False
        return True

    def check_one_symmetry(self,
                           spos_ac,
                           op_cc,
                           ft_c,
                           a_ia):
        """Checks whether atoms satisfy one given symmetry operation.

        *a_ia* maps each atom id to the list of atom indices with that id.
        Returns an int array giving, for each atom b, the atom of the same
        id at spos_b @ op - ft (modulo lattice vectors), or None if some
        atom has no such partner.
        """
        a_a = np.zeros(len(spos_ac), int)
        for b_a in a_ia.values():
            spos_jc = spos_ac[b_a]
            for b in b_a:
                spos_c = np.dot(spos_ac[b], op_cc)
                sdiff_jc = spos_c - spos_jc - ft_c
                sdiff_jc -= sdiff_jc.round()
                if self._backwards_compatible:
                    indices = np.where(
                        abs(sdiff_jc).max(1) < self.tolerance)[0]
                else:
                    sdiff_jv = sdiff_jc @ self.cell_cv
                    indices = np.where(
                        (sdiff_jv**2).sum(1) < self.tolerance**2)[0]
                if len(indices) == 1:
                    a = indices[0]
                    a_a[b] = b_a[a]
                else:
                    # More than one match means the tolerance is larger
                    # than the distance between atoms of the same kind:
                    assert len(indices) == 0
                    return None

        return a_a

322 

323 

def find_lattice_symmetry(cell_cv, pbc_c, tol, _backwards_compatible=False):
    """Determine list of symmetry operations.

    Candidates are all 3x3 integer matrices U (in the basis of the cell
    vectors) with entries -1, 0 and 1.  A candidate survives if it
    conserves the cell metric (summed absolute deviation at most *tol*,
    or ``tol**2`` unless *_backwards_compatible*) and does not couple a
    periodic axis with a non-periodic one.  Returns an int array of shape
    (nsym, 3, 3).
    """
    # All 3**9 = 19683 candidate matrices, enumerated lexicographically
    # over the entry values (1, 0, -1):
    candidates_scc = np.moveaxis(1 - np.indices((3,) * 9),
                                 0, -1).reshape(-1, 3, 3)

    # The metric G = C C^T must be conserved: U G U^T == G.
    metric_cc = cell_cv @ cell_cv.T
    transformed_scc = np.einsum('sij, jk, slk -> sil',
                                candidates_scc, metric_cc, candidates_scc,
                                optimize=True)
    deviation_s = np.abs(transformed_scc - metric_cc).sum(axis=2).sum(axis=1)
    threshold = tol if _backwards_compatible else tol**2
    candidates_scc = candidates_scc[deviation_s <= threshold]

    # An operation must not swap axes that don't have the same PBC:
    mixed_cc = np.logical_xor.outer(pbc_c, pbc_c)
    couples_s = (candidates_scc * mixed_cc).any(axis=(1, 2))
    return candidates_scc[~couples_s]

349 

350 

def prune_symmetries(sym: Symmetries,
                     relpos_ac: Array2D,
                     id_a: Sequence[int],
                     symmorphic: bool = True) -> Symmetries:
    """Remove symmetries that are not satisfied by the atoms.

    Each rotation of *sym* is first tried without translation.  When
    *symmorphic* is False, rotations that fail are retried with the
    fractional translations that bring an atom of the first atom's
    species onto that first atom; these operations are placed after the
    pure rotations.  If a pure (identity) translation already maps the
    atoms onto themselves (a supercell), fractional translations are not
    searched at all.
    """
    if len(relpos_ac) == 0:
        return sym

    # Group atom indices by id (one group per combination of atomic
    # number, setup type, magnetic moment, basis set, ...):
    atoms_by_id = defaultdict(list)
    for atom_index, atom_id in enumerate(id_a):
        atoms_by_id[atom_id].append(atom_index)

    # Candidate translations are generated from one species only:
    reference_atoms = atoms_by_id[id_a[0]]
    first = reference_atoms[0]

    def satisfied(op_cc, ft_c):
        return sym.check_one_symmetry(relpos_ac, op_cc, ft_c, atoms_by_id)

    if not symmorphic:
        # If supercell, disable fractional translations:
        pure_shifts_jc = relpos_ac[reference_atoms[1:]] - relpos_ac[first]
        pure_shifts_jc -= np.rint(pure_shifts_jc)
        identity_cc = np.identity(3, int)
        if any(satisfied(identity_cc, shift_c) is not None
               for shift_c in pure_shifts_jc):
            symmorphic = True

    plain = []
    translated = []
    for op_cc in sym.rotation_scc:
        atommap_a = satisfied(op_cc, [0, 0, 0])
        if atommap_a is not None:
            plain.append((op_cc, [0, 0, 0], atommap_a))
            continue
        if symmorphic:
            continue
        # Try the translations that move a rotated reference atom onto
        # the first reference atom:
        rotated_ac = np.dot(relpos_ac, op_cc)
        candidates_jc = rotated_ac[reference_atoms] - relpos_ac[first]
        candidates_jc -= np.rint(candidates_jc)
        for ft_c in candidates_jc:
            atommap_a = satisfied(op_cc, ft_c)
            if atommap_a is not None:
                translated.append((op_cc, ft_c, atommap_a))

    # Operations with fractional translations go at the end:
    operations = plain + translated

    pruned = Symmetries(cell=sym.cell_cv,
                        rotations=[op for op, _, _ in operations],
                        translations=[ft for _, ft, _ in operations],
                        atommaps=[amap for _, _, amap in operations],
                        tolerance=sym.tolerance,
                        _backwards_compatible=sym._backwards_compatible)
    if debug:
        pruned.check_positions(relpos_ac)
    return pruned

414 

415 

class SymmetrizationPlan:
    """Symmetrize atomic matrices (e.g. density matrices) over symmetries.

    *l_aj* lists, for each atom, the angular momenta of its projectors.
    Block-diagonal rotation matrices are built per angular-momentum
    sequence and cached per array module (numpy, cupy, ...).
    """
    def __init__(self,
                 symmetries: Symmetries,
                 l_aj):
        self.symmetries = symmetries
        self.l_aj = l_aj
        # Cartesian form C^-1 U C of each rotation:
        self.rotation_svv = np.einsum('vc, scd, dw -> svw',
                                      np.linalg.inv(symmetries.cell_cv),
                                      symmetries.rotation_scc,
                                      symmetries.cell_cv)
        # Atoms without projectors (empty l_j) must not break the max():
        lmax = max((max(l_j, default=-1) for l_j in l_aj), default=-1)
        self.rotation_lsmm = [
            np.array([rotation(l, r_vv) for r_vv in self.rotation_svv])
            for l in range(lmax + 1)]
        # Cache keyed on (angular momenta, array module) so that numpy
        # and GPU callers never receive each other's arrays:
        self._rotations: dict[tuple, Array3D] = {}

    def rotations(self, l_j, xp=np):
        """Return block-diagonal rotation matrices (nsym, ni, ni).

        ni is the sum of 2l+1 over *l_j*; the result is converted with
        ``xp.asarray`` and cached.
        """
        key = (tuple(l_j), xp)
        rotation_sii = self._rotations.get(key)
        if rotation_sii is None:
            ni = sum(2 * l + 1 for l in l_j)
            rotation_sii = np.zeros((len(self.symmetries), ni, ni))
            i1 = 0
            for l in l_j:
                i2 = i1 + 2 * l + 1
                rotation_sii[:, i1:i2, i1:i2] = self.rotation_lsmm[l]
                i1 = i2
            rotation_sii = xp.asarray(rotation_sii)
            self._rotations[key] = rotation_sii
        return rotation_sii

    def apply_distributed(self, D_asii, dist_D_asii):
        """Symmetrize the locally stored atomic matrices.

        For every atom a1 in *dist_D_asii*, its matrices are replaced by
        the average over operations s of R_s D_{a2} R_s^T, where a2 is the
        atom that a1 maps to under s and D_{a2} is read from *D_asii*.
        Assumes *dist_D_asii* exposes ``items()`` and a ``data`` array
        (TODO confirm against the caller's container type).
        """
        for a1, D_sii in dist_D_asii.items():
            D_sii[:] = 0.0
            rotation_sii = self.rotations(self.l_aj[a1])
            for a2, rotation_ii in zips(self.symmetries.atommap_sa[:, a1],
                                        rotation_sii):
                D_sii += np.einsum('ij, sjk, lk -> sil',
                                   rotation_ii, D_asii[a2], rotation_ii)
        dist_D_asii.data *= 1.0 / len(self.symmetries)

456 

457 

class GPUSymmetrizationPlan(SymmetrizationPlan):
    """Symmetrization as one (possibly sparse) matrix per orbit of atoms.

    Matrices are vectorized, so the symmetrization of all atoms in an
    orbit becomes a single matrix-vector product, which suits GPUs.
    """
    def __init__(self,
                 symmetries: Symmetries,
                 l_aj,
                 layout):
        super().__init__(symmetries, l_aj)

        xp = layout.xp
        a_sa = symmetries.atommap_sa

        ns = a_sa.shape[0]  # Number of symmetries
        na = a_sa.shape[1]  # Number of atoms

        if xp is np:
            # Import the submodule explicitly: a bare "import scipy" does
            # not load scipy.sparse on older SciPy versions.
            from scipy import sparse
        else:
            from gpaw.gpu import cupyx
            sparse = cupyx.scipy.sparse

        # Find orbits, i.e. point group action,
        # which also equals to set of all cosets.
        # In practical terms, these are just atoms which map
        # to each other via symmetry operations.
        # Mathematically {{as: s∈ S}: a∈ A}, where a is an atom.
        cosets = {frozenset(a_sa[:, a]) for a in range(na)}

        S_aZZ = {}
        work = []
        for coset in map(list, cosets):
            nA = len(coset)  # Number of atoms in this orbit
            a = coset[0]  # Representative atom for coset

            # The atomic density matrices transform as
            # ρ'_ii = R_sii ρ_ii R^T_sii
            # Which equals to vec(ρ'_ii) = (R^s_ii ⊗ R^s_ii) vec(ρ_ii)
            # Here we do the Kronecker product for each of the
            # symmetry transformations.
            R_sii = xp.asarray(self.rotations(l_aj[a], xp))
            i2 = R_sii.shape[1]**2
            R_sPP = xp.einsum('sab, scd -> sacbd', R_sii, R_sii)
            R_sPP = R_sPP.reshape((ns, i2, i2)) / ns

            S_ZZ = xp.zeros((nA * i2,) * 2)

            # For each orbit, the symmetrization operation is represented
            # by a full matrix operating on a subset of indices to the
            # full array.
            for loca1, a1 in enumerate(coset):
                Z1 = loca1 * i2
                Z2 = Z1 + i2
                for s, a2 in enumerate(a_sa[:, a1]):
                    loca2 = coset.index(a2)
                    Z3 = loca2 * i2
                    Z4 = Z3 + i2
                    S_ZZ[Z1:Z2, Z3:Z4] += R_sPP[s]
            # Utilize sparse matrices if sizes get out of hand
            # Limit is hard coded to 100MB per orbit
            if S_ZZ.nbytes > 100 * 1024**2:
                S_ZZ = sparse.csr_matrix(S_ZZ)
            S_aZZ[a] = S_ZZ
            indices = []
            for loca1, a1 in enumerate(coset):
                a1_, start, end = layout.myindices[a1]
                # When parallelization is done, this needs to be rewritten
                assert a1_ == a1
                for X in range(i2):
                    indices.append(start + X)
            work.append((a, xp.array(indices)))

        self.work = work
        self.S_aZZ = S_aZZ
        self.xp = xp

    def apply(self, source, target):
        """Symmetrize *source* into *target*, orbit by orbit, per spin.

        Both arrays are indexed as [spin, packed matrix elements]; the
        assert checks that every element was covered exactly once.
        """
        total = 0
        for a, ind in self.work:
            for spin in range(len(source)):
                total += len(ind)
                target[spin, ind] = self.S_aZZ[a] @ source[spin, ind]
        assert total / len(source) == source.shape[1]

538 

539 

def mat(rot_cc) -> str:
    """Convert 3x3 matrix to str.

    Every entry is right-aligned in a field of width two, so that
    matrices printed below each other line up:

    >>> mat([[-1, 0, 0], [0, 1, 0], [0, 0, 1]])
    '[[-1,  0,  0], [ 0,  1,  0], [ 0,  0,  1]]'

    """
    rows = (', '.join(f'{r:2}' for r in rot_c) for rot_c in rot_cc)
    return '[[' + '], ['.join(rows) + ']]'

550 

551 

def integer_ids(ids: Iterable) -> list[int]:
    """Convert arbitrary ids to int ids.

    Hashable ids are numbered 0, 1, 2, ... in order of first appearance;
    equal ids get equal numbers.

    >>> integer_ids([(1, 'a'), (12, 'b'), (1, 'a')])
    [0, 1, 0]
    """
    numbering: dict[Any, int] = {}
    # setdefault assigns the next free number only on first sight:
    return [numbering.setdefault(key, len(numbering)) for key in ids]

567 

568 

def safe_id(magmom_av, tolerance=1e-3):
    """Convert magnetic moments to integer id's.

    While calculating id's for atoms, there may be rounding errors
    in the magnetic moments supplied.  Atom a gets as id the index of the
    first earlier atom whose moment lies within *tolerance* of its own
    (Euclidean norm of the difference), or its own index if there is no
    such atom.  Moments may be one number per atom (collinear) or one
    vector per atom (non-collinear), given as any array-like.

    >>> safe_id([1.01, 0.99, 0.5], tolerance=0.025)
    [0, 0, 2]
    """
    # Nested Python lists cannot be subtracted; work with an array:
    magmom_av = np.asarray(magmom_av)
    id_a = []
    for a, magmom_v in enumerate(magmom_av):
        quantized = a
        for a2 in range(a):
            if np.linalg.norm(magmom_av[a2] - magmom_v) < tolerance:
                quantized = a2
                break
        id_a.append(quantized)
    return id_a