Coverage for gpaw/wannier/edmiston_ruedenberg.py: 72%
39 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-20 00:19 +0000
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-20 00:19 +0000
"""Edmiston-Ruedenberg localization."""
2from math import pi
4import numpy as np
6import gpaw.cgpaw as cgpaw
7from .overlaps import WannierOverlaps
8from .functions import WannierFunctions
class LocalizationNotConvergedError(Exception):
    """Raised when localization fails to converge within ``maxiter`` steps."""
def localize(overlaps: WannierOverlaps,
             maxiter: int = 100,
             tolerance: float = 1e-5,
             verbose: bool = True) -> WannierFunctions:
    """Simple localization.

    Orthorhombic cell, no k-points, only occupied states.

    Parameters
    ----------
    overlaps:
        Overlap data.  Its ``atoms.cell`` must be orthorhombic and its
        ``monkhorst_pack_size`` must equal ``(1, 1, 1)``.
    maxiter:
        Maximum number of calls to ``cgpaw.localize``.
    tolerance:
        The iteration stops as soon as the value returned by
        ``cgpaw.localize`` grows by less than this amount compared to
        the previous iteration.
    verbose:
        If true, print one table row (iteration, value, change) per
        iteration.

    Returns
    -------
    WannierFunctions
        Built from ``overlaps.atoms``, the computed centers, the last
        value returned by ``cgpaw.localize`` and ``[U_nn]``.

    Raises
    ------
    NotImplementedError
        If the cell is not orthorhombic.
    AssertionError
        If the Monkhorst-Pack size is not ``(1, 1, 1)``.
    LocalizationNotConvergedError
        If the stopping criterion is not met within ``maxiter``
        iterations (this includes ``maxiter <= 0``).
    """
    if not overlaps.atoms.cell.orthorhombic:
        raise NotImplementedError('An orthogonal cell is required')
    # np.array_equal works for both tuples and arrays, whereas
    # ``(x == (1, 1, 1)).all()`` breaks when x is a plain tuple.
    assert np.array_equal(overlaps.monkhorst_pack_size, (1, 1, 1)), \
        'A (1, 1, 1) Monkhorst-Pack grid is required'

    nbands = overlaps.nbands

    # One nbands x nbands overlap matrix per Cartesian direction.
    Z_nnc = np.empty((nbands, nbands, 3), complex)
    for c, direction in enumerate([(1, 0, 0), (0, 1, 0), (0, 0, 1)]):
        Z_nnc[:, :, c] = overlaps.overlap(bz_index=0, direction=direction)

    U_nn = np.identity(nbands)

    if verbose:
        # Same column widths (4, 10, 10) as the per-iteration rows below.
        print(f"{'iter':>4} {'value':>10} {'change':>10}")
        print('---- ---------- ----------')

    old = 0.0
    for iteration in range(maxiter):
        # NOTE(review): cgpaw.localize presumably updates Z_nnc and U_nn
        # in place (nothing else here modifies them) - confirm against
        # the C implementation.
        value = cgpaw.localize(Z_nnc, U_nn)
        if verbose:
            print(f'{iteration:4} {value:10.3f} {value - old:10.6f}')
        if value - old < tolerance:
            break
        old = value
    else:
        # Loop ran to completion without hitting ``break``.
        raise LocalizationNotConvergedError(
            f'Did not converge in {maxiter} iterations')

    # Find centers:
    # The phase of each diagonal element gives a scaled coordinate, which
    # is wrapped into [0, 1) and multiplied by the cell.
    scaled_nc = -np.angle(Z_nnc.diagonal()).T / (2 * pi)
    centers_nv = (scaled_nc % 1.0).dot(overlaps.atoms.cell)

    return WannierFunctions(overlaps.atoms, centers_nv, value, [U_nn])
if __name__ == '__main__':
    # Script usage: first argument is handed to GPAW(), second argument
    # is the number of Wannier functions.
    import sys

    from gpaw import GPAW
    from gpaw.wannier.overlaps import calculate_overlaps

    calculator = GPAW(sys.argv[1])
    n_wannier = int(sys.argv[2])

    wannier_overlaps = calculate_overlaps(calculator, n_wannier)
    functions = wannier_overlaps.localize_er(verbose=True)

    # Report the centers, then hand them to the atoms editor.
    print(functions.centers)
    functions.centers_as_atoms().edit()