Coverage for gpaw/test/wannier/test_wannier90.py: 100%

77 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-14 00:18 +0000

import os
from functools import lru_cache
from pathlib import Path
from subprocess import PIPE, run

import numpy as np
import pytest

from gpaw import GPAW
from gpaw.wannier90 import Wannier90
from gpaw.wannier.w90 import read_wout_all

9 

10 

@lru_cache(maxsize=None)
def out():
    """Return the standard output of ``wannier90.x --version``.

    This is evaluated in ``skipif`` conditions at collection time (once per
    decorated test), so the result is cached to launch the executable only
    once per session.

    Returns an empty string when the executable cannot be launched (for
    example when it is not on ``PATH``), so a substring check on the
    version text simply finds no match.
    """
    try:
        # Pass an argv list instead of a shell string: no shell is needed.
        result = run(['wannier90.x', '--version'],
                     stdout=PIPE,
                     stderr=PIPE,
                     universal_newlines=True)
    except OSError:
        # Without a shell, a missing executable raises instead of printing
        # "command not found" to stderr; keep returning empty stdout.
        return ''
    return result.stdout

18 

19 

@pytest.mark.wannier
@pytest.mark.serial
@pytest.mark.skipif(': 3.' not in out(),
                    reason="requires at least Wannier90 version 3.0")
@pytest.mark.parametrize('mode', ['sym', 'nosym'])
def test_wannier90(gpw_files, mode, in_tmp_dir, wannier90):
    """Wannierize GaAs through the Wannier90 executable, with/without symmetry.

    The test runs these steps:

    * write the Wannier90 input file;
    * run the ``-pp`` preprocessing step;
    * write projections, eigenvalues and overlaps;
    * run the minimisation;
    * compare the summed Wannier centres and spreads parsed from
      ``<seed>.wout`` against reference values;
    * write and check the UNK wavefunction files.
    """
    # Presumably per-atom orbital indices: none on the first atom,
    # orbitals 0-3 on the second -- confirm against Wannier90.orbitals_ai.
    o_ai = [[], [0, 1, 2, 3]]
    bands = range(4)

    if mode == 'sym':
        calc = GPAW(gpw_files['gaas_pw'])
        # The symmetric calculation must actually be symmetry-reduced.
        assert calc.wfs.kd.nbzkpts > calc.wfs.kd.nibzkpts
    else:
        calc = GPAW(gpw_files['gaas_pw_nosym'])
        assert calc.wfs.kd.nbzkpts == calc.wfs.kd.nibzkpts

    seed = f'GaAs_{mode}'

    wannier = Wannier90(calc,
                        seed=seed,
                        bands=bands,
                        orbitals_ai=o_ai)

    wannier.write_input(num_iter=1000,
                        plot=False)

    # check=True: a failing wannier90.x run stops the test here, with its
    # exit status, instead of later on a missing or incomplete .wout file.
    run(['wannier90.x', '-pp', seed], check=True)
    wannier.write_projections()
    wannier.write_eigenvalues()
    wannier.write_overlaps()
    run(['wannier90.x', seed], check=True)

    with Path(f'{seed}.wout').open() as fd:
        w = read_wout_all(fd)
    centers = np.sum(np.array(w['centers']), axis=0)
    print('centers:', centers)
    centers_correct = np.array([5.68, 5.68, 5.68])
    assert np.allclose(centers, centers_correct, atol=1e-3)
    spreads = np.sum(np.array(w['spreads']))
    assert spreads == pytest.approx(9.9733, abs=0.002)

    # also test wavefunctions
    wannier.write_wavefunctions()
    check_wavefunctions()

63 

64 

@pytest.mark.wannier
@pytest.mark.serial
@pytest.mark.skipif(': 3.' not in out(),
                    reason="requires at least Wannier90 version 3.0")
def test_wannier90_soc(gpw_files, in_tmp_dir):
    """Wannierize Fe with spinor Wannier functions through Wannier90.

    The test uses the non-symmetry-reduced Fe calculation with
    ``spinors=True`` and disentanglement settings. It compares the summed
    Wannier centres and spreads parsed from ``<seed>.wout`` against
    reference values.
    """
    calc = GPAW(gpw_files['fe_pw_nosym'])
    seed = 'Fe'
    # The 'nosym' calculation must not be symmetry-reduced.
    assert calc.wfs.kd.nbzkpts == calc.wfs.kd.nibzkpts

    wannier = Wannier90(calc,
                        seed=seed,
                        bands=range(9),
                        spinors=True)

    wannier.write_input(num_iter=200,
                        dis_num_iter=500,
                        dis_mix_ratio=1.0)
    # check=True: a failing wannier90.x run stops the test here, with its
    # exit status, instead of later on a missing or incomplete .wout file.
    run(['wannier90.x', '-pp', seed], check=True)
    wannier.write_projections()
    wannier.write_eigenvalues()
    wannier.write_overlaps()

    run(['wannier90.x', seed], check=True)

    # Derive the output name from the seed so the two cannot drift apart.
    with Path(f'{seed}.wout').open() as fd:
        w = read_wout_all(fd)
    centers = np.sum(np.array(w['centers']), axis=0)
    centers_correct = [12.9, 12.9, 12.9]
    assert np.allclose(centers, centers_correct, atol=0.19)
    spreads = np.sum(np.array(w['spreads']))
    assert spreads == pytest.approx(20.1, abs=0.6)

96 

97 

def check_wavefunctions():
    """Check the header and first coefficient of the UNK wavefunction files.

    Reads ``UNK00001.1`` to ``UNK00003.1`` from the current directory. In
    each file k (1-based):

    * the first line must be the integer header ``[20, 20, 20, k, 4]``;
    * the complex number formed from the first two values on the second
      line (real, imaginary) must have a magnitude matching a reference
      value within ``atol=1e-3``.

    Raises:
        AssertionError: if a header or coefficient magnitude does not match.
    """
    expected_headers = [[20, 20, 20, 1, 4],
                        [20, 20, 20, 2, 4],
                        [20, 20, 20, 3, 4]]
    expected_abs = [0.0656, 0.0634, 0.0437]
    for k, (header_ref, abs_ref) in enumerate(
            zip(expected_headers, expected_abs), start=1):
        # Zero-pad to five digits instead of hard-coding "0000", which only
        # yields the right name for k < 10.
        with open(f"UNK{k:05d}.1") as fd:
            # split() (no argument) tolerates runs of spaces and the
            # trailing newline; split(' ') would yield empty fields.
            header = [int(field) for field in fd.readline().split()]
            assert header == header_ref
            values = [float(field) for field in fd.readline().split()]
        coeff = values[0] + 1j * values[1]
        assert np.allclose(abs(coeff), abs_ref, atol=1e-3)