Coverage for gpaw/wannier/edmiston_ruedenberg.py: 72%

39 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-20 00:19 +0000

1"""Edmiston-ruedenberg localization.""" 

2from math import pi 

3 

4import numpy as np 

5 

6import gpaw.cgpaw as cgpaw 

7from .overlaps import WannierOverlaps 

8from .functions import WannierFunctions 

9 

10 

class LocalizationNotConvergedError(Exception):
    """Signals that the localization loop hit ``maxiter`` before converging.

    Raised by ``localize`` when the change in the localization value never
    drops below the requested tolerance within the allowed iterations.
    """

13 

14 

def localize(overlaps: WannierOverlaps,
             maxiter: int = 100,
             tolerance: float = 1e-5,
             verbose: bool = True) -> WannierFunctions:
    """Simple localization.

    Orthorhombic cell, no k-points, only occupied states.

    Parameters
    ----------
    overlaps:
        Overlap matrices of the states.  Only ``bz_index=0`` is used, along
        the three Cartesian directions.
    maxiter:
        Maximum number of calls to the ``cgpaw.localize`` iteration.
    tolerance:
        Iteration stops as soon as the change of the localization value
        between two consecutive iterations is smaller than this (a decrease
        also counts as converged).
    verbose:
        If true, print a table with iteration number, value and change.

    Returns
    -------
    WannierFunctions
        Built from ``overlaps.atoms``, the Cartesian Wannier centers, the
        final localization value and the list ``[U_nn]`` holding the single
        (Gamma-point) rotation matrix.

    Raises
    ------
    NotImplementedError
        If the unit cell is not orthorhombic or the k-point grid is not
        1x1x1.
    LocalizationNotConvergedError
        If the value has not converged after ``maxiter`` iterations.
    """
    if not overlaps.atoms.cell.orthorhombic:
        raise NotImplementedError('An orthogonal cell is required')
    # Explicit check instead of ``assert`` so it survives ``python -O``;
    # ``np.asarray`` also accepts a plain tuple for the grid size.
    if not (np.asarray(overlaps.monkhorst_pack_size) == 1).all():
        raise NotImplementedError(
            'Only a 1x1x1 Monkhorst-Pack grid (no k-points) is supported')

    nbands = overlaps.nbands

    # Stack the Gamma-point overlap matrices for the x, y and z directions
    # along the last axis.
    Z_nnc = np.empty((nbands, nbands, 3), complex)
    for c, direction in enumerate([(1, 0, 0), (0, 1, 0), (0, 0, 1)]):
        Z_nnc[:, :, c] = overlaps.overlap(bz_index=0, direction=direction)

    U_nn = np.identity(nbands)

    if verbose:
        print('iter value change')
        print('---- ---------- ----------')

    old = 0.0
    for iteration in range(maxiter):
        # NOTE(review): Z_nnc and U_nn are never reassigned here, so the
        # centers and the returned U_nn only reflect the localization if
        # cgpaw.localize updates them in place -- confirm in the C code.
        value = cgpaw.localize(Z_nnc, U_nn)
        if verbose:
            print(f'{iteration:4} {value:10.3f} {value - old:10.6f}')
        if value - old < tolerance:
            break
        old = value
    else:
        # Reached only when the loop ran out without ``break``; this also
        # covers maxiter <= 0, where ``value`` would otherwise be unbound.
        raise LocalizationNotConvergedError(
            f'Did not converge in {maxiter} iterations')

    # Find centers: the phase of the diagonal elements of Z gives the
    # scaled (fractional) coordinate along each cell vector.  For a
    # (n, n, 3) array, ``diagonal()`` returns shape (3, n); ``.T`` makes it
    # (n, 3).
    scaled_nc = -np.angle(Z_nnc.diagonal()).T / (2 * pi)
    # Wrap into [0, 1) and convert to Cartesian coordinates.
    centers_nv = (scaled_nc % 1.0).dot(overlaps.atoms.cell)

    return WannierFunctions(overlaps.atoms, centers_nv, value, [U_nn])

54 

55 

if __name__ == '__main__':
    # Command-line driver: restore a GPAW calculation from the file given
    # as the first argument, compute overlaps for the requested number of
    # Wannier functions and localize them.
    import sys
    from gpaw import GPAW
    from gpaw.wannier.overlaps import calculate_overlaps

    # Fail with a readable message instead of an IndexError when arguments
    # are missing (extra arguments are ignored, as before).
    if len(sys.argv) < 3:
        sys.exit(f'Usage: {sys.argv[0]} <gpw-file> <nwannier>')

    calc = GPAW(sys.argv[1])
    try:
        nwannier = int(sys.argv[2])
    except ValueError:
        sys.exit(f'nwannier must be an integer, got {sys.argv[2]!r}')
    overlaps = calculate_overlaps(calc, nwannier)
    wan = overlaps.localize_er(verbose=True)
    print(wan.centers)
    # Open the Wannier centers (as atoms) in the ASE GUI.
    wan.centers_as_atoms().edit()