Coverage for gpaw/test/poisson/test_fastpoisson.py: 100%
76 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-14 00:18 +0000
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-14 00:18 +0000
1# This test verifies that the FastPoissonSolver produces solutions to
2# the Poisson equation with very small residuals for different cells,
3# pbcs, and grids.
5import itertools
6import numpy as np
7from ase.build import bulk
8from gpaw.poisson import FastPoissonSolver, BadAxesError
9from gpaw.grid_descriptor import GridDescriptor
10from gpaw.fd_operators import Laplace
11from gpaw.mpi import world
12from gpaw.utilities import h2gpts
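# Covered: different pbcs, different cells (orthorhombic/general), and
# grids with both even (idiv=4) and odd (idiv=1) numbers of points.
# Not yet covered (TODO): a charged system for pbc=000 (the density is
# always neutralized in the test below) and the use_cholesky keyword.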
def test_poisson_fastpoisson():
    """Check that FastPoissonSolver yields solutions with tiny residuals.

    Loops over grid divisibility (``idiv`` 4 or 1; for 1 the grid is
    forced to an odd number of points), a set of cells, all 8 pbc
    patterns and stencil sizes ``nn`` 1 and 3.  For each configuration a
    random charge-neutral density ``rho`` is solved for ``phi`` and the
    residual ``|L phi - rho|`` with ``L = -nabla^2 / (4 pi)`` (the
    ``Laplace`` operator with the same ``nn``) must be below
    ``tolerance``.  Configurations rejected by the solver with
    ``BadAxesError`` are skipped.
    """
    rng = np.random.RandomState(42)
    tolerance = 1e-12

    def icells():
        """Yield ``(name, cell)`` pairs: orthorhombic, slabs, fcc/bcc/hcp."""
        yield 'diag', np.diag([3., 4., 5.])

        from ase.build import fcc111
        atoms = fcc111('Au', size=(1, 1, 1))
        atoms.center(vacuum=1, axis=2)
        # The same slab with its vacuum-padded cell vector (originally the
        # third one) placed at index 2, 0 and 1 respectively:
        yield 'fcc111@z', atoms.cell.copy()
        yield 'fcc111@x', atoms.cell[[2, 0, 1]]
        yield 'fcc111@y', atoms.cell[[1, 2, 0]]

        for sym in ['Au', 'Fe', 'Sc']:
            yield sym, bulk(sym).cell.copy()

    def check_case(cellno, cellname, cell_cv, idiv, pbc, nn):
        """Solve one configuration and return this domain's max residual.

        Raises BadAxesError (from ``set_grid_descriptor``) for pbc/cell
        combinations the solver does not support.
        """
        N_c = h2gpts(0.12, cell_cv, idiv=idiv)
        if idiv == 1:
            N_c += 1 - N_c % 2  # We want especially to test uneven grids
        gd = GridDescriptor(N_c, cell_cv, pbc_c=pbc)
        rho_g = gd.zeros()
        phi_g = gd.zeros()
        rho_g[:] = -0.3 + rng.rand(*rho_g.shape)

        # Neutralize charge by subtracting a uniform background:
        charge = gd.integrate(rho_g)
        npoints = gd.get_size_of_global_array().prod()
        rho_g -= charge / gd.dv / npoints
        assert abs(gd.integrate(rho_g)) < 1e-12

        ps = FastPoissonSolver(nn=nn)
        # Will raise BadAxesError for some pbc/cell combinations
        ps.set_grid_descriptor(gd)
        ps.solve(phi_g, rho_g)

        laplace = Laplace(gd, scale=-1.0 / (4.0 * np.pi), n=nn)
        rhotest_g = gd.zeros()
        laplace.apply(phi_g, rhotest_g)
        residual_g = np.abs(rhotest_g - rho_g)

        # The residual is not accurate at the ends of non-periodic
        # directions except for nn=1 (effectively only nn=1 uses the right
        # stencil at the boundary).  A correct check would need a
        # lower-order Laplacian there, so instead the nn - 1 outermost
        # points along each non-periodic axis are excluded by zeroing them.
        if nn > 1:
            exclude_points = nn - 1
            for c in range(3):
                if pbc[c]:
                    continue
                # View in which axis c becomes the zeroth dimension
                # (writes through to residual_g):
                X = residual_g.transpose(c, (c + 1) % 3, (c + 2) % 3)
                # NOTE(review): assumes non-periodic axes start at index 1
                # and end at N_c on the domains holding the boundary.
                if gd.beg_c[c] == 1:
                    X[:exclude_points] = 0.0
                if gd.end_c[c] == gd.N_c[c]:
                    X[-exclude_points:] = 0.0
        maxerr = residual_g.max()

        pbcstring = '{}{}{}'.format(*pbc)
        state = 'ok' if maxerr < tolerance else 'FAIL'
        msg = ('{:2d} {:8s} grid={} pbc={} err[fast]={:8.5e} nn={:1d} {}'
               .format(cellno, cellname, N_c, pbcstring, maxerr, nn, state))
        if world.rank == 0:
            print(msg)
        return maxerr

    results = []
    for idiv in [4, 1]:
        for cellno, (cellname, cell_cv) in enumerate(icells()):
            for pbc in itertools.product(range(2), repeat=3):
                for nn in [1, 3]:
                    try:
                        err = check_case(cellno, cellname, cell_cv,
                                         idiv, pbc, nn)
                    except BadAxesError:
                        # Ignore incompatible pbc/cell combinations
                        continue
                    results.append(((idiv, cellname, pbc, nn), err))

    # Guard against the test passing vacuously if everything was skipped.
    assert results, 'no configuration was accepted by FastPoissonSolver'
    # `not err < tolerance` (rather than `err >= tolerance`) keeps NaN
    # residuals counted as failures.
    failures = [(case, err) for case, err in results if not err < tolerance]
    assert not failures, failures