Coverage for gpaw/test/poisson/test_fastpoisson.py: 100%

76 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-14 00:18 +0000

1# This test verifies that the FastPoissonSolver produces solutions to 

2# the Poisson equation with very small residuals for different cells, 

3# pbcs, and grids. 

4 

5import itertools 

6import numpy as np 

7from ase.build import bulk 

8from gpaw.poisson import FastPoissonSolver, BadAxesError 

9from gpaw.grid_descriptor import GridDescriptor 

10from gpaw.fd_operators import Laplace 

11from gpaw.mpi import world 

12from gpaw.utilities import h2gpts 

13 

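14# Tested: different pbcs and different cells (orthorhombic/general). 

15# Not yet tested: a charged system for pbc=000 (the density is always 

16# neutralized before solving below) and the use_cholesky keyword of 

17# FastPoissonSolver. 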

18 

19 

def test_poisson_fastpoisson():
    """Check FastPoissonSolver residuals over cells, pbcs, grids and stencils.

    For every cell yielded by ``icells()``, every periodic/non-periodic
    combination of the three axes, grid divisibility ``idiv`` in (4, 1) and
    finite-difference stencil size ``nn`` in (1, 3), a random charge-neutral
    density is solved for.  Applying ``Laplace`` (scaled by -1/(4 pi)) to the
    resulting potential must reproduce the density to within ``tolerance``.
    Combinations for which ``set_grid_descriptor`` raises ``BadAxesError``
    are skipped.
    """
    rng = np.random.RandomState(42)
    tolerance = 1e-12

    def icells():
        """Yield ``(name, cell)`` pairs for the cells under test.

        An orthorhombic box, an fcc111 slab with its third cell vector
        permuted into each of the three axis positions, and the bulk cells
        of Au, Fe and Sc (fcc, bcc and hcp).
        """
        yield 'diag', np.diag([3., 4., 5.])

        from ase.build import fcc111
        atoms = fcc111('Au', size=(1, 1, 1))
        atoms.center(vacuum=1, axis=2)
        yield 'fcc111@z', atoms.cell.copy()
        # Row permutations move the third cell vector (the one centred
        # with vacuum above) into the x and y positions respectively.
        yield 'fcc111@x', atoms.cell[[2, 0, 1]]
        yield 'fcc111@y', atoms.cell[[1, 2, 0]]

        for sym in ['Au', 'Fe', 'Sc']:
            yield sym, bulk(sym).cell.copy()

    def check_case(cellno, cellname, cell_cv, idiv, pbc, nn):
        """Solve one case and return the maximum residual.

        Raises BadAxesError (from ``set_grid_descriptor``) when the solver
        does not support this pbc/cell combination.
        """
        N_c = h2gpts(0.12, cell_cv, idiv=idiv)
        if idiv == 1:
            N_c += 1 - N_c % 2  # We want especially to test uneven grids
        gd = GridDescriptor(N_c, cell_cv, pbc_c=pbc)

        rho_g = gd.zeros()
        rho_g[:] = -0.3 + rng.rand(*rho_g.shape)

        # Shift the density by a constant so that it integrates to zero:
        npoints = gd.get_size_of_global_array().prod()
        rho_g -= gd.integrate(rho_g) / gd.dv / npoints
        assert abs(gd.integrate(rho_g)) < 1e-12

        ps = FastPoissonSolver(nn=nn)
        # Will raise BadAxesError for some pbc/cell combinations
        ps.set_grid_descriptor(gd)
        phi_g = gd.zeros()
        ps.solve(phi_g, rho_g)

        laplace = Laplace(gd, scale=-1.0 / (4.0 * np.pi), n=nn)
        rhotest_g = gd.zeros()
        laplace.apply(phi_g, rhotest_g)
        residual_g = np.abs(rhotest_g - rho_g)

        # The residual is not accurate at the ends of non-periodic
        # directions except for nn=1 (effectively only nn=1 uses the right
        # stencil at the boundary).  A correct check would need a Laplacian
        # with lower nn at the boundaries, so instead the nn - 1 outermost
        # planes on each non-periodic side are zeroed and not tested.
        if nn > 1:
            nexclude = nn - 1
            for c in range(3):
                if pbc[c]:
                    continue
                # View in which axis c is the leading dimension, so that
                # slicing the first index selects planes normal to axis c:
                X = residual_g.transpose(c, (c + 1) % 3, (c + 2) % 3)
                # Only zero the planes this domain actually owns:
                if gd.beg_c[c] == 1:
                    X[:nexclude] = 0.0
                if gd.end_c[c] == gd.N_c[c]:
                    X[-nexclude:] = 0.0
        maxerr = residual_g.max()

        state = 'ok' if maxerr < tolerance else 'FAIL'
        pbcstring = '{}{}{}'.format(*pbc)
        msg = ('{:2d} {:8s} grid={} pbc={} err[fast]={:8.5e} nn={:1d} {}'
               .format(cellno, cellname, N_c, pbcstring, maxerr, nn, state))
        if world.rank == 0:
            print(msg)
        return maxerr

    results = []
    for idiv in [4, 1]:
        for cellno, (cellname, cell_cv) in enumerate(icells()):
            for pbc in itertools.product(range(2), repeat=3):
                for nn in [1, 3]:
                    try:
                        err = check_case(cellno, cellname, cell_cv,
                                         idiv, pbc, nn)
                    except BadAxesError:
                        # Ignore incompatible pbc/cell combinations
                        continue
                    results.append(((cellname, idiv, pbc, nn), err))

    # Guard against a vacuous pass if every combination were skipped.
    assert results
    # `not err < tolerance` (rather than `err >= tolerance`) also flags NaN.
    failures = [(case, err) for case, err in results
                if not err < tolerance]
    assert not failures, failures