Coverage for gpaw/test/wannier/test_wannier90.py: 100%
77 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-14 00:18 +0000
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-14 00:18 +0000
1import pytest
2import os
3import numpy as np
4from gpaw import GPAW
5from gpaw.wannier90 import Wannier90
6from gpaw.wannier.w90 import read_wout_all
7from pathlib import Path
8from subprocess import PIPE, run
def out(executable='wannier90.x'):
    """Return what ``<executable> --version`` prints on stdout.

    The ``skipif`` markers of the Wannier90 tests call this at import time
    and look for a ``': 3.'`` version marker in the returned text.

    Parameters
    ----------
    executable:
        Name or path of the program to query (default ``'wannier90.x'``).

    Returns
    -------
    str
        The captured stdout.  If the program cannot be launched at all
        (e.g. it is not on ``PATH``), an empty string is returned, so the
        dependent tests are skipped instead of the module failing to import.
    """
    try:
        # Pass an argv list instead of going through a shell.  The command
        # is fixed, so no quoting or shell features are needed.
        result = run([executable, '--version'],
                     stdout=PIPE,
                     stderr=PIPE,
                     universal_newlines=True)
    except OSError:
        # Without a shell, a missing executable raises
        # (FileNotFoundError/PermissionError) rather than producing empty
        # stdout.  Keep the old "empty output" contract explicitly.
        return ''
    return result.stdout
@pytest.mark.wannier
@pytest.mark.serial
# NOTE(review): ': 3.' only matches 3.x version strings, although the reason
# says "at least 3.0" -- confirm whether newer major versions should run.
@pytest.mark.skipif(': 3.' not in out(),
                    reason="requires at least Wannier90 version 3.0")
@pytest.mark.parametrize('mode', ['sym', 'nosym'])
def test_wannier90(gpw_files, mode, in_tmp_dir, wannier90):
    """Run the GPAW -> Wannier90 workflow for GaAs and check the output.

    The ground state is loaded either with k-point symmetry reduction
    (``mode='sym'``) or without it (``mode='nosym'``).  Both modes are
    checked against the same reference Wannier centres and total spread.
    Finally the UNK wavefunction files written by GPAW are checked.
    """
    # Per-atom orbital indices handed to Wannier90 as ``orbitals_ai``:
    # none for atom 0, indices 0-3 for atom 1.
    o_ai = [[], [0, 1, 2, 3]]
    bands = range(4)

    if mode == 'sym':
        calc = GPAW(gpw_files['gaas_pw'])
        # The symmetric calculation must really use a reduced k-point set,
        # otherwise both modes would test the same thing.
        assert calc.wfs.kd.nbzkpts > calc.wfs.kd.nibzkpts
    else:
        calc = GPAW(gpw_files['gaas_pw_nosym'])
        assert calc.wfs.kd.nbzkpts == calc.wfs.kd.nibzkpts

    seed = f'GaAs_{mode}'

    wannier = Wannier90(calc,
                        seed=seed,
                        bands=bands,
                        orbitals_ai=o_ai)

    wannier.write_input(num_iter=1000,
                        plot=False)

    # Wannier90 pre-processing run (-pp).  Presumably it produces the data
    # the write_* calls below rely on -- confirm.  check=True reports a
    # failing binary here instead of as a missing .wout file later.
    run(['wannier90.x', '-pp', seed], check=True)
    wannier.write_projections()
    wannier.write_eigenvalues()
    wannier.write_overlaps()
    run(['wannier90.x', seed], check=True)

    with Path(f'{seed}.wout').open() as fd:
        w = read_wout_all(fd)
    centers = np.sum(np.array(w['centers']), axis=0)
    print('centers:', centers)
    centers_correct = np.array([5.68, 5.68, 5.68])
    assert np.allclose(centers, centers_correct, atol=1e-3)
    spreads = np.sum(np.array(w['spreads']))
    assert spreads == pytest.approx(9.9733, abs=0.002)

    # also test wavefunctions
    wannier.write_wavefunctions()
    check_wavefunctions()
@pytest.mark.wannier
@pytest.mark.serial
# NOTE(review): ': 3.' only matches 3.x version strings, although the reason
# says "at least 3.0" -- confirm whether newer major versions should run.
@pytest.mark.skipif(': 3.' not in out(),
                    reason="requires at least Wannier90 version 3.0")
def test_wannier90_soc(gpw_files, in_tmp_dir):
    """Run a spinor (``spinors=True``) Wannier90 workflow for Fe.

    Uses the non-symmetry-reduced Fe ground state.  It checks the summed
    Wannier centres and the total spread against reference values with
    fairly loose tolerances.
    """
    calc = GPAW(gpw_files['fe_pw_nosym'])
    seed = 'Fe'
    # The input must not be symmetry-reduced.
    assert calc.wfs.kd.nbzkpts == calc.wfs.kd.nibzkpts

    wannier = Wannier90(calc,
                        seed=seed,
                        bands=range(9),
                        spinors=True)

    wannier.write_input(num_iter=200,
                        dis_num_iter=500,
                        dis_mix_ratio=1.0)
    # check=True makes a failing Wannier90 run fail the test at this step
    # instead of surfacing later as a missing or stale .wout file.
    run(['wannier90.x', '-pp', seed], check=True)
    wannier.write_projections()
    wannier.write_eigenvalues()
    wannier.write_overlaps()

    run(['wannier90.x', seed], check=True)

    # Derive the output name from ``seed`` so the two cannot drift apart.
    with Path(f'{seed}.wout').open() as fd:
        w = read_wout_all(fd)
    centers = np.sum(np.array(w['centers']), axis=0)
    centers_correct = [12.9, 12.9, 12.9]
    assert np.allclose(centers, centers_correct, atol=0.19)
    spreads = np.sum(np.array(w['spreads']))
    assert spreads == pytest.approx(20.1, abs=0.6)
def check_wavefunctions():
    """Check the first two lines of ``UNK00001.1`` .. ``UNK00003.1``.

    The files are read from the current working directory.  For the k-th
    file (k = 1, 2, 3):

    * the first line must consist of the integers ``20 20 20 k 4``
      (presumably grid dimensions, file index and number of bands --
      TODO confirm against the Wannier90 UNK format), and
    * the second line holds a real and an imaginary part.  The modulus of
      that complex number must match a reference value within ``atol=1e-3``.

    Raises
    ------
    AssertionError
        If any header or modulus does not match.
    """
    expected_headers = [[20, 20, 20, 1, 4],
                        [20, 20, 20, 2, 4],
                        [20, 20, 20, 3, 4]]
    expected_abs = [0.0656, 0.0634, 0.0437]
    for k, (header_ref, abs_ref) in enumerate(
            zip(expected_headers, expected_abs), start=1):
        # Same names as the old f"UNK0000{k}.1" for k < 10, but the
        # zero padding stays correct for larger k as well.
        with open(f'UNK{k:05d}.1') as fd:
            # split() instead of split(' ') tolerates repeated, leading or
            # trailing whitespace, which would otherwise yield int('').
            header = [int(word) for word in fd.readline().split()]
            assert header == header_ref
            re_im = [float(word) for word in fd.readline().split()]
            value = re_im[0] + 1j * re_im[1]
            assert np.allclose(abs(value), abs_ref, atol=1e-3)