Coverage for gpaw/test/lcao/test_lcao_complicated.py: 76%

45 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-08 00:17 +0000

1import numpy as np 

2from ase.build import fcc111 

3 

4from gpaw import GPAW, LCAO, FermiDirac 

5from gpaw.mpi import world 

6from gpaw.utilities import compiled_with_sl 

7 

8# This test verifies that energy and forces are (approximately) 

9# parallelization independent 

10# 

11# Tests the LCAO energy and forces in non-orthogonal cell with 

12# simultaneous parallelization over bands, domains and k-points (if 

13# enough CPUs are available), both with and without scalapack 

14# (if scalapack is available). 

15# 

16# Run with 1, 2, 4 or 8 (best) CPUs. 

17# 

18# This test covers many cases not caught by lcao_parallel or 

19# lcao_parallel_kpt 

20# 

21# Written November 24, 2011, r8567 

22 

23 

def test_lcao_complicated():
    """Check that LCAO energy and forces do not depend on parallelization.

    A reference energy and force array is first computed over ``world``
    with default parallel settings.  The same system is then recomputed
    with explicit band/domain parallelization (and, if GPAW is compiled
    with ScaLAPACK, additionally with ``sl_auto=True``) and each result
    is compared against the reference.
    """
    system = fcc111('Au', size=(1, 3, 1))
    system.numbers[0] = 8
    # It is important that the number of atoms is uneven; this
    # tests the case where the band parallelization does not match
    # the partitioning of orbitals between atoms (the middle atom has
    # orbitals on distinct band descriptor ranks).

    system.center(vacuum=3.5, axis=2)
    # Fixed seed so the displaced geometry, and therefore the
    # reference values, are reproducible from run to run.
    system.rattle(stdev=0.2, seed=17)

    def calculate(parallel, comm=world, Eref=None, Fref=None):
        """Run one LCAO calculation, optionally checking it against a reference.

        parallel: dict passed as GPAW's ``parallel`` keyword.
        comm: communicator the calculator runs on.
        Eref, Fref: reference energy and forces.  When given, the energy
            must agree to within 1e-8 and the largest absolute force
            component difference must be below 1e-6.

        Returns the tuple ``(E, F)`` of potential energy and forces.
        """
        calc = GPAW(mode=LCAO(atomic_correction='sparse'),
                    basis=dict(O='dzp', Au='sz(dzp)'),
                    occupations=FermiDirac(0.1),
                    kpts=(4, 1, 1),
                    communicator=comm,
                    nbands=16,
                    parallel=parallel,
                    h=0.35)
        system.calc = calc
        E = system.get_potential_energy()
        F = system.get_forces()

        if world.rank == 0:
            print('Results')
            print('-----------')
            print(E)
            print(F)
            print('-----------')

        # Errors are formatted with %e: with %f, any error near the
        # tolerances below would be reported as "0.000000".
        if Eref is not None:
            Eerr = abs(E - Eref)
            assert Eerr < 1e-8, (
                'Bad E: err=%e; parallel=%s' % (Eerr, parallel))
        if Fref is not None:
            Ferr = np.abs(F - Fref).max()
            assert Ferr < 1e-6, (
                'Bad F: err=%e; parallel=%s' % (Ferr, parallel))
        return E, F

    # First calculate the reference energy and forces E and F.
    #
    # Set serial_reference to True to compute the reference entirely in
    # serial on rank 0.  The other ranks start from zeros, so world.sum
    # hands rank 0's values to every rank.  Normally the reference is
    # computed in parallel with default settings instead; that case is
    # covered well by other tests, so it can be trusted here.
    serial_reference = False
    if serial_reference:
        serial = world.new_communicator([0])
        E = 0.0
        F = np.zeros((len(system), 3))
        if world.rank == 0:
            E, F = calculate({}, serial)
        E = world.sum(E)
        world.sum(F)
    else:
        E, F = calculate({}, world)

    def check(parallel):
        """Recompute with ``parallel`` and compare against the reference."""
        return calculate(parallel, comm=world, Eref=E, Fref=F)

    assert world.size in [1, 2, 4, 8], ('Number of CPUs %d not supported'
                                        % world.size)

    parallel = dict(domain=1, band=1)
    if world.size % 2 == 0:
        parallel['band'] = 2
    if world.size % 4 == 0:
        parallel['domain'] = 2
    # With 8 CPUs this will also use k-point parallelization, so the
    # test should run with 8 CPUs for the best coverage of
    # parallelizations.

    if world.size > 1:
        check(parallel)
        if compiled_with_sl():
            # Same layout again, with ScaLAPACK enabled via sl_auto.
            # A new dict is built rather than mutating the one already
            # handed to the previous calculator.
            check({**parallel, 'sl_auto': True})