Coverage for gpaw/test/poisson/test_poisson_moment.py: 99%
110 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-08 00:17 +0000
1import numpy as np
2import pytest
4from ase.units import Bohr
5from gpaw.poisson import PoissonSolver, NoInteractionPoissonSolver
6from gpaw.poisson_moment import MomentCorrectionPoissonSolver, MomentCorrection
7from gpaw.poisson_extravacuum import ExtraVacuumPoissonSolver
8from gpaw.grid_descriptor import GridDescriptor
@pytest.mark.parametrize('moment_corrections, expected_len', [
    (None, 0),
    ([], 0),
    (4, 1),
    (9, 1),
    ([dict(moms=range(4), center=np.array([1, 3, 5]))], 1),
    ([dict(moms=range(4), center=np.array([5, 3, 5])),
      dict(moms=range(4), center=np.array([7, 5, 3]))], 2),
])
def test_defaults(moment_corrections, expected_len):
    """Every accepted ``moment_corrections`` form is normalized to a list.

    Whether None, an empty list, an int or a list of dicts is passed, the
    solver must expose a list of ``MomentCorrection`` objects of the
    expected length.
    """
    wrapped = NoInteractionPoissonSolver()
    solver = MomentCorrectionPoissonSolver(
        poissonsolver=wrapped,
        moment_corrections=moment_corrections)

    corrections = solver.moment_corrections
    assert isinstance(corrections, list), corrections
    assert len(corrections) == expected_len
    assert all(isinstance(corr, MomentCorrection) for corr in corrections)
@pytest.mark.parametrize('moment_corrections', [
    None,
    [],
])
def test_description_empty(moment_corrections):
    """Without corrections the description reports zero of them.

    The description of the wrapped solver must still be embedded in it.
    """
    inner = NoInteractionPoissonSolver()
    outer = MomentCorrectionPoissonSolver(
        poissonsolver=inner,
        moment_corrections=moment_corrections)

    outer_desc = outer.get_description()
    inner_desc = inner.get_description()

    for text in (outer_desc, inner_desc):
        assert isinstance(text, str)
    assert inner_desc in outer_desc
    assert '0 moment corrections' in outer_desc
@pytest.mark.parametrize('moment_corrections, expected_strings', [
    (4, ['1 moment corrections', 'center', 'range(0, 4)']),
    (9, ['1 moment corrections', 'center', 'range(0, 9)']),
    ([dict(moms=range(4), center=np.array([1, 1, 1]))],
     ['1 moment corrections', '[1.00, 1.00, 1.00]', 'range(0, 4)']),
    ([dict(moms=[1, 2, 3], center=np.array([1, 1, 1]))],
     ['1 moment corrections', '[1.00, 1.00, 1.00]', 'range(1, 4)']),
    ([dict(moms=[0, 2, 3], center=np.array([1, 1, 1]))],
     ['1 moment corrections', '[1.00, 1.00, 1.00]', '(0, 2, 3)']),
    ([dict(moms=range(4), center=np.array([2, 3, 4])),
      dict(moms=range(4), center=np.array([7.4, 3.1, 0.1]))],
     ['2 moment corrections', '[2.00, 3.00, 4.00]',
      '[7.40, 3.10, 0.10]', 'range(0, 4)']),
])
def test_description(moment_corrections, expected_strings):
    """The description is the wrapped solver's text followed by the moments."""
    inner = NoInteractionPoissonSolver()
    outer = MomentCorrectionPoissonSolver(
        poissonsolver=inner,
        moment_corrections=moment_corrections)

    outer_desc = outer.get_description()
    inner_desc = inner.get_description()

    assert isinstance(outer_desc, str)
    assert isinstance(inner_desc, str)

    # The wrapped solver's description must come first ...
    assert outer_desc.startswith(inner_desc)

    # ... and the remainder must mention every expected fragment
    tail = outer_desc[len(inner_desc):]
    for fragment in expected_strings:
        assert fragment in tail, f'"{fragment}" not in "{tail}"'
@pytest.mark.parametrize('moment_corrections, expected_string', [
    ([], 'no corrections'),
    (4, 'array([0, 1, 2, 3]) @ None'),
    (9, 'array([0, 1, 2, 3, 4, 5, 6, 7, 8]) @ None'),
    ([dict(moms=range(4), center=np.array([1., 1., 1.]))],
     'array([0, 1, 2, 3]) @ array([1., 1., 1.])'),
    ([dict(moms=[1, 2, 3], center=np.array([1., 1., 1.]))],
     'array([1, 2, 3]) @ array([1., 1., 1.])'),
    ([dict(moms=[0, 2, 3], center=np.array([1., 1., 1.]))],
     'array([0, 2, 3]) @ array([1., 1., 1.])'),
    ([dict(moms=range(4), center=np.array([2, 3, 4])),
      dict(moms=range(4), center=np.array([7.4, 3.1, 0.1]))],
     '2 corrections'),
])
def test_repr(moment_corrections, expected_string):
    """repr() names the class and summarizes the moment corrections."""
    solver = MomentCorrectionPoissonSolver(
        poissonsolver=NoInteractionPoissonSolver(),
        moment_corrections=moment_corrections)

    rep = repr(solver)
    expected_repr = f'MomentCorrectionPoissonSolver ({expected_string})'

    assert isinstance(rep, str)
    assert rep == expected_repr, f'{rep} not equal to {expected_repr}'
@pytest.fixture
def gd():
    """Grid descriptor with 16 x 16 x 48 points on a 1 x 1 x 3 cell.

    The cell lengths are in Bohr (the moment-correction test below relies
    on this when converting centers to Angstrom). The third positional
    argument is False; presumably this is the periodicity flag, i.e. a
    non-periodic cell -- confirm against GridDescriptor.
    """
    points_per_length = 16
    return GridDescriptor(
        (points_per_length, points_per_length, 3 * points_per_length),
        (1, 1, 3),
        False)
@pytest.mark.parametrize('moment_corrections', [
    4,
    9,
    [dict(moms=range(4), center=np.array([1, 1, 1]))],
    [dict(moms=range(4), center=np.array([2, 3, 4])),
     dict(moms=range(4), center=np.array([7.4, 3.1, 0.1]))],
])
def test_write(gd, moment_corrections):
    """The solver's todict() output can be written with gpaw's Writer."""
    inner = PoissonSolver()
    inner.set_grid_descriptor(gd)

    solver = MomentCorrectionPoissonSolver(
        poissonsolver=inner,
        moment_corrections=moment_corrections)
    solver.set_grid_descriptor(gd)

    from gpaw.io import Writer
    from gpaw.mpi import world

    # Going through the Writer checks that everything is JSON serializable;
    # the written data itself is discarded.
    writer = Writer('/dev/null', world)
    writer.child('poisson').write(**solver.todict())
    writer.close()
@pytest.fixture
def rho_g(gd):
    """Model charge density made of two dipole-like blobs along z.

    For z0 in (1, 2) it adds
    ``10 * (z - z0) * exp(-20 * |r - (0.5, 0.5, z0)|**2)``,
    i.e. a Gaussian multiplied by a factor that is odd in ``z - z0``.
    """
    r_vg = gd.get_grid_point_coordinates()
    density_g = gd.zeros()
    for z0 in (1, 2):
        center_v = np.array([.5, .5, z0])
        # Squared distance of every grid point to the blob center
        dist2_g = ((r_vg.T - center_v).T ** 2).sum(axis=0)
        density_g += 10 * (r_vg[2] - z0) * np.exp(-20 * dist2_g)
    return density_g
@pytest.fixture
def poisson_solve(gd, rho_g):
    """Return a function that solves for the potential of ``rho_g``.

    The returned function attaches ``gd`` to the given solver, solves into a
    fresh zero-initialized array and returns that array.
    """
    def solve_with(solver):
        solver.set_grid_descriptor(gd)
        potential_g = gd.zeros()
        solver.solve(potential_g, rho_g)
        return potential_g

    return solve_with
@pytest.fixture
def compare(gd, tolerance, cmp_begin):
    """Return a checker that asserts two potentials agree within tolerance.

    Both arrays are gathered with ``gd.collect`` and, on rank 0 of
    ``gd.comm``, the maximum absolute difference must be at most
    ``tolerance``. If ``cmp_begin`` is not None, only a central box is
    compared. Along each axis with N global points, the indices
    ``int(N * cmp_begin) <= i < int(N * (1 - cmp_begin))`` are kept.
    """
    # Some test cases compare in only a small region of space
    if cmp_begin is None:
        region = None
    else:
        cmp_end = 1 - cmp_begin
        # Basic slices pick exactly the same central box as np.ix_ over the
        # corresponding index ranges, but as views instead of copies. Using
        # ``region`` as the name keeps the builtin ``slice`` unshadowed.
        region = tuple(slice(int(N * cmp_begin), int(N * cmp_end))
                       for N in gd.get_size_of_global_array())

    def _compare(phi1_g, phi2_g):
        big_phi1_g = gd.collect(phi1_g)
        big_phi2_g = gd.collect(phi2_g)
        # Only rank 0 checks; presumably it is the rank that holds the
        # collected global arrays.
        if gd.comm.rank == 0:
            if region is not None:
                big_phi1_g = big_phi1_g[region]
                big_phi2_g = big_phi2_g[region]
            assert np.max(np.absolute(big_phi1_g - big_phi2_g)) == (
                pytest.approx(0.0, abs=tolerance))

    return _compare
@pytest.fixture
def poisson_ref(gd, ref):
    """Reference Poisson solver selected by the ``ref`` test parameter.

    'default': a plain PoissonSolver. It forces the potential to zero at the
    box boundaries, so the potential has the wrong shape near the
    boundaries but is nearly right in the center of the box.

    'extravac': an ExtraVacuumPoissonSolver using ``4 * gd.N_c`` grid points
    around the default solver. With this much extra vacuum the potential is
    well converged everywhere.

    Any other value raises ValueError.
    """
    default_solver = PoissonSolver()
    if ref == 'extravac':
        return ExtraVacuumPoissonSolver(
            gpts=4 * gd.N_c,
            poissonsolver_large=default_solver)
    if ref != 'default':
        raise ValueError(f'No such ref {ref}')
    return default_solver
@pytest.mark.parametrize(
    'ref, chain, moment_corrections, tolerance, cmp_begin', [
        # MomentCorrectionPoissonSolver without any moment corrections
        # should be exactly as the underlying solver
        ('default', False, None, 0.0, None),
        # It should also be possible to chain
        # default+extravacuum+moment correction.
        # With moment_corrections=None the MomentCorrection solver doesn't
        # actually do anything, so the potential should be identical to the
        # extra vacuum reference
        ('extravac', True, None, 0.0, None),
        # Test moment_corrections=int
        # The moment correction is applied to the center of the cell. This
        # is not enough to have a converged potential near the edges.
        # The closer we are to the center the better, though
        ('default', False, 4, 3.5e-2, 0.25),
        ('default', False, 4, 2.5e-2, 0.40),
        # Test moment_corrections=list
        # Remember that the solver expects Ångström units and we have
        # specified the grid in Bohr.
        # This should give a well converged potential everywhere, that we
        # can compare to the reference extravacuum potential
        ('extravac', False,
         [{'moms': range(4), 'center': np.array([.5, .5, 1]) * Bohr},
          {'moms': range(4), 'center': np.array([.5, .5, 2]) * Bohr}],
         3e-3, None),
        # It should be possible to chain
        # default+extravacuum+moment correction.
        # As the potential is already well converged, there should be
        # little change
        ('extravac', True,
         [{'moms': range(4), 'center': np.array([.5, .5, 1]) * Bohr},
          {'moms': range(4), 'center': np.array([.5, .5, 2]) * Bohr}],
         5e-4, None),
    ])
def test_poisson_moment_correction(gd, rho_g, poisson_solve,
                                   compare, poisson_ref,
                                   ref, chain, moment_corrections,
                                   tolerance, cmp_begin):
    """Compare a MomentCorrectionPoissonSolver with a reference solver.

    ``ref`` selects the reference solver through the ``poisson_ref``
    fixture. If ``chain`` is true, the moment-correction solver wraps that
    same reference solver. Otherwise it wraps a fresh default PoissonSolver.
    ``moment_corrections`` is passed to the solver under test, and
    ``tolerance``/``cmp_begin`` configure the ``compare`` fixture.
    """
    # Solve for the potential using the reference solver
    phiref_g = poisson_solve(poisson_ref)

    # Create a MomentCorrectionPoissonSolver with the parametrized
    # corrections. Previously moment_corrections=None was hard-coded, so
    # every case compared the reference solver with itself.
    inner = poisson_ref if chain else PoissonSolver()
    poisson = MomentCorrectionPoissonSolver(
        poissonsolver=inner,
        moment_corrections=moment_corrections)
    phi_g = poisson_solve(poisson)

    # Test the MomentCorrectionPoissonSolver
    compare(phi_g, phiref_g)