Coverage for gpaw/test/vdw/test_libvdwxc_functionals.py: 81%

63 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-20 00:19 +0000

1import pytest 

2from gpaw.utilities import compiled_with_libvdwxc 

3import numpy as np 

4from gpaw.grid_descriptor import GridDescriptor 

5from gpaw.xc.libvdwxc import vdw_df, vdw_df2, vdw_df_cx, \ 

6 vdw_optPBE, vdw_optB88, vdw_C09, vdw_beef, \ 

7 libvdwxc_has_mpi, libvdwxc_has_pfft 

8 

# Skip every test in this module when GPAW was built without libvdwxc.
pytestmark = pytest.mark.skipif(
    not compiled_with_libvdwxc(),
    reason='not compiled_with_libvdwxc()',
)

11 

12 

def test_vdw_libvdwxc_functionals():
    """Check that libvdwxc van der Waals functionals reproduce stored values.

    A pseudo-random density (``RandomState(0)``) on a 23 x 10 x 6 grid,
    periodic along the first and third axes, is passed to each libvdwxc
    functional in every backend mode usable for this run ('serial', 'mpi',
    'pfft').  For each functional, the returned energy ``E`` and the grid
    integral ``nv`` of ``n * v`` are compared with stored reference values.

    For a given mode, the test passes if *all* functionals match the current
    reference table to within ``tol``, or if *all* match the older table
    kept for compatibility.  Each functional is evaluated once per mode, and
    the same numbers are checked against both tables.  If no backend mode is
    usable, the test is skipped instead of passing without checking anything.
    """
    tol = 1e-14

    # Current reference values: functional class -> (E, nv).
    ref_current = {
        vdw_df: (-3.8730338590248383, -4.905182929615311),
        vdw_df2: (-3.9017518972499317, -4.933889152385742),
        vdw_df_cx: (-3.7262108992876644, -4.760433536500078),
        vdw_optPBE: (-3.7954301587466506, -4.834460613766266),
        vdw_optB88: (-3.8486341203104613, -4.879005564922708),
        vdw_C09: (-3.7071083039260464, -4.746114441237086),
        vdw_beef: (-3.8926148228224444, -4.961101745896925),
    }
    # Older reference values, still accepted for compatibility.
    ref_legacy = {
        vdw_df: (-3.8730338473027368, -4.905182296422721),
        vdw_df2: (-3.9017516508476211, -4.933888350723616),
        vdw_df_cx: (-3.7262108875655624, -4.760432903307487),
        vdw_optPBE: (-3.7954301470245491, -4.834459980573675),
        vdw_optB88: (-3.8486341085883597, -4.879004931730118),
        vdw_C09: (-3.7071082922039449, -4.746113808044496),
        vdw_beef: (-3.8926145764201334, -4.9611009442348015),
    }

    N_c = np.array([23, 10, 6])
    gd = GridDescriptor(N_c, N_c * 0.2, pbc_c=(1, 0, 1))

    # Fill the collected (global) density on rank 0 with a seeded generator,
    # then distribute it back into n_sg.
    n_sg = gd.zeros(1)
    nG_sg = gd.collect(n_sg)
    if gd.comm.rank == 0:
        gen = np.random.RandomState(0)
        nG_sg[:] = gen.rand(*nG_sg.shape)
    gd.distribute(nG_sg, n_sg)

    def evaluate(vdwxcclass, mode):
        """Run one functional in one backend mode; return (name, E, nv)."""
        xc = vdwxcclass(mode=mode)
        xc.initialize_backend(gd)
        if gd.comm.rank == 0:
            print()
            print(xc.libvdwxc.tostring())
        v_sg = gd.zeros(1)
        E = xc.calculate(gd, n_sg, v_sg)
        nv = gd.integrate(n_sg * v_sg, global_integral=True)[0]
        return xc.name, E, nv

    def deviations(results, refs):
        """Return [(class name, |E - Eref|, |nv - nvref|), ...]."""
        return [(cls.__name__, abs(E - refs[cls][0]), abs(nv - refs[cls][1]))
                for cls, (E, nv) in results.items()]

    def within_tol(devs):
        """True if every energy and nv deviation is strictly below tol."""
        return all(Eerr < tol and nverr < tol for _, Eerr, nverr in devs)

    modes_run = []
    for mode in ['serial', 'mpi', 'pfft']:
        if mode == 'serial' and gd.comm.size > 1:
            continue
        if mode == 'mpi' and not libvdwxc_has_mpi():
            continue
        if mode == 'pfft' and not libvdwxc_has_pfft():
            continue
        modes_run.append(mode)

        results = {}
        for vdwxcclass, (Eref, nvref) in ref_current.items():
            name, E, nv = evaluate(vdwxcclass, mode)
            results[vdwxcclass] = (E, nv)
            if gd.comm.rank == 0:
                print(name)
                print('=' * len(name))
                print('E  = %19.16f vs ref = %19.16f :: err = %10.6e'
                      % (E, Eref, abs(E - Eref)))
                print('nv = %19.16f vs ref = %19.16f :: err = %10.6e'
                      % (nv, nvref, abs(nv - nvref)))
                print()
                # Paste-ready entry for ref_current.
                print('Update:')
                print('    {}: ({!r}, {!r}),'.format(vdwxcclass.__name__,
                                                     E, nv))
            gd.comm.barrier()

        devs_current = deviations(results, ref_current)
        if not within_tol(devs_current):
            # Fall back to the older reference values (compatibility);
            # report both sets of deviations if neither table matches.
            devs_legacy = deviations(results, ref_legacy)
            assert within_tol(devs_legacy), (mode, devs_current, devs_legacy)

    if not modes_run:
        pytest.skip('no libvdwxc backend mode is usable in this run')