Coverage for gpaw/test/lcao/test_lcao_parallel.py: 10%
94 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-19 00:19 +0000
1import sys
3import pytest
4from gpaw.utilities import devnull
6from gpaw import GPAW, FermiDirac, KohnShamConvergenceError
7from gpaw.utilities import compiled_with_sl
8from gpaw.mpi import world
9from ase.build import molecule
# Compare LCAO total energies and forces obtained with several different
# parallelization layouts.  The whole module is skipped when fewer than
# four MPI ranks are available.
pytestmark = pytest.mark.skipif(
    world.size < 4,
    reason='world.size < 4',
)
def test_lcao_lcao_parallel():
    """Check that LCAO energies and forces agree across parallelizations.

    The nested ``run`` helper performs one GPAW LCAO calculation using the
    current contents of the shared ``parallel`` dict.  The first call of a
    series stores its free energy and forces as the reference.  Every later
    call in that series must agree with the reference to within
    ``tolerance``; otherwise an ``AssertionError`` is raised.

    Two series are run: spin-paired H2O, and spin-polarized, periodic NH2.
    The ScaLAPACK (``sl_default``) variants only run when GPAW was compiled
    with ScaLAPACK.  The rank counts in the comments below assume
    ``world.size == 4``.
    """
    tolerance = 4e-5

    # Mutated in place between runs.  basekwargs keeps a reference to this
    # very dict, so each new calculator sees the current parallelization.
    parallel = dict()

    basekwargs = dict(mode='lcao',
                      nbands=6,
                      parallel=parallel)

    Eref = None
    Fref_av = None

    def run(formula='H2O', vacuum=2.0, cell=None, pbc=0, **morekwargs):
        """Run one calculation and compare it against the series reference.

        The first call after ``Eref`` has been reset to None becomes the
        reference.  Extra keyword arguments are passed to ``GPAW`` on top of
        ``basekwargs``.

        Raises:
            AssertionError: if the energy or force error is above
                ``tolerance`` (or is NaN).
        """
        nonlocal Eref, Fref_av

        print(formula, parallel)
        system = molecule(formula)
        kwargs = dict(basekwargs)
        kwargs.update(morekwargs)
        calc = GPAW(**kwargs)
        system.calc = calc
        system.center(vacuum)
        if cell is not None:
            # Keep the vacuum-centred positions but impose the given cell.
            system.set_cell(cell)
        system.set_pbc(pbc)

        try:
            system.get_potential_energy()
        except KohnShamConvergenceError:
            # Deliberately tolerated: the (possibly unconverged) result is
            # still checked against the reference below.
            pass

        E = calc.hamiltonian.e_total_free
        F_av = calc.get_forces()

        if Eref is None:
            Eref = E
            Fref_av = F_av

        eerr = abs(E - Eref)
        ferr = abs(F_av - Fref_av).max()

        if calc.wfs.world.rank == 0:
            print('Energy', E)
            print()
            print('Forces')
            print(F_av)
            print()
            print('Errs', eerr, ferr)

        # Written as "not <=" so that a NaN error also counts as a failure
        # (NaN > tolerance would be False and pass silently).
        energy_failed = not eerr <= tolerance
        forces_failed = not ferr <= tolerance
        if not (energy_failed or forces_failed):
            return

        # Only rank 0 writes the diagnostics; other ranks discard them.
        if calc.wfs.world.rank == 0:
            stderr = sys.stderr
        else:
            stderr = devnull

        messages = []
        if energy_failed:
            print('Failed!', file=stderr)
            print('E = %f, Eref = %f' % (E, Eref), file=stderr)
            messages.append('Energy err larger than tolerance: %f' % eerr)
        if forces_failed:
            print('Failed!', file=stderr)
            print('Forces:', file=stderr)
            print(F_av, file=stderr)
            print(file=stderr)
            print('Ref forces:', file=stderr)
            print(Fref_av, file=stderr)
            print(file=stderr)
            messages.append('Force err larger than tolerance: %f' % ferr)
        print(file=stderr)
        print('Args:', file=stderr)
        print(formula, vacuum, cell, pbc, morekwargs, file=stderr)
        print(parallel, file=stderr)
        # Report every failing quantity, not just the last one checked.
        raise AssertionError('; '.join(messages))

    # --- Spin-paired H2O series -------------------------------------------

    # reference (GPAW's default layout):
    # state-parallelization = 1,
    # domain-decomposition = (1, 2, 2)
    run()

    # state-parallelization = 2,
    # domain-decomposition = (1, 2, 1)
    parallel['band'] = 2
    parallel['domain'] = (1, 2, world.size // 4)
    run()

    if compiled_with_sl():
        # state-parallelization = 2,
        # domain-decomposition = (1, 2, 1)
        # with blacs
        parallel['sl_default'] = (2, 2, 2)
        run()

        # state-parallelization = 1,
        # domain-decomposition = (1, 2, 2)
        # with blacs
        del parallel['band']
        del parallel['domain']
        run()

    # --- Spin-polarized, periodic NH2 series ------------------------------

    # Start from an empty parallel dict (basekwargs still refers to it) and
    # a fresh reference.
    parallel.clear()
    Eref = None
    Fref_av = None

    nh2_kwargs = dict(formula='NH2', vacuum=1.5, pbc=1, spinpol=1,
                      occupations=FermiDirac(0.1))

    parallel['domain'] = (1, 2, 1)

    # reference:
    # spin-polarization = 2,
    # state-parallelization = 1,
    # domain-decomposition = (1, 2, 1)
    run(**nh2_kwargs)

    # spin-polarization = 2,
    # state-parallelization = 2,
    # domain-decomposition = (1, 1, 1)
    del parallel['domain']
    parallel['band'] = 2
    run(**nh2_kwargs)

    if compiled_with_sl():
        # spin-polarization = 2,
        # state-parallelization = 2,
        # domain-decomposition = (1, 1, 1)
        # with blacs
        parallel['sl_default'] = (2, 1, 2)
        run(**nh2_kwargs)

        # spin-polarization = 2,
        # state-parallelization = 1,
        # domain-decomposition = (1, 2, 1)
        del parallel['band']
        parallel['domain'] = (1, 2, 1)
        run(**nh2_kwargs)

        # spin-polarization = 1,
        # state-parallelization = 1,
        # domain-decomposition = (1, 2, 2)
        parallel['domain'] = (1, 2, 2)
        run(**nh2_kwargs)