Coverage for gpaw/test/core/test_matrix_elements.py: 96%

55 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-09 00:21 +0000

1import pytest 

2 

3from gpaw.core import PWDesc, UGDesc 

4from gpaw.core.matrix import Matrix 

5from gpaw.mpi import world 

6 

7 

def comms():
    """Yield ``(domain_comm, band_comm)`` communicator pairs.

    For each block size ``s`` in 1, 2, 4, 8, the world ranks are split
    into blocks of ``s`` consecutive ranks (``domain_comm``).
    ``band_comm`` groups all ranks that share the same offset
    ``world.rank % s`` inside their block.

    Iteration stops as soon as ``s`` exceeds ``world.size``.
    Sizes that do not divide ``world.size`` are skipped: otherwise the last
    block would list ranks ``>= world.size``, which do not exist.
    The skip decision depends only on ``world.size``, so every rank makes
    the same ``new_communicator`` calls in the same order.
    That ordering matters if communicator creation is collective
    (NOTE(review): presumably MPI-collective — confirm in gpaw.mpi).
    """
    for s in [1, 2, 4, 8]:
        if s > world.size:
            return
        if world.size % s != 0:
            # An incomplete last block would reference non-existent ranks.
            continue
        first = world.rank // s * s
        domain_comm = world.new_communicator(range(first, first + s))
        band_comm = world.new_communicator(
            range(world.rank % s, world.size, s))
        yield domain_comm, band_comm

18 

19 

def func(f):
    """Return a scaled copy of *f*: every data entry multiplied by 2.3.

    Serves as a simple linear operator for the ``function`` argument of
    ``matrix_elements``.  The input object itself is not modified; only the
    copy returned by ``f.copy()`` is scaled in place.
    """
    factor = 2.3
    scaled = f.copy()
    scaled.data *= factor
    return scaled

25 

26 

# TODO: test also UGArray
@pytest.mark.parametrize('domain_comm, band_comm', list(comms()))
@pytest.mark.parametrize('dtype', [float, complex])
@pytest.mark.parametrize('nbands', [1, 7, 21])
@pytest.mark.parametrize('function', [None, func])
def test_me(domain_comm, band_comm, dtype, nbands, function):
    """Compare distributed plane-wave matrix elements with a serial result.

    ``f.matrix_elements(f, function=function)`` is evaluated on a randomized,
    distributed ``PWDesc`` array in two ways:

    * into a freshly allocated matrix;
    * into a caller-supplied ``out`` matrix pre-filled with huge values.

    Both results are gathered and compared with the same product computed
    from the fully gathered array.  When ``function`` is None, the
    non-symmetric case ``f.matrix_elements(g)`` for a second random array
    ``g`` is checked as well.
    """
    a = 2.5
    n = 20
    grid = UGDesc(cell=[a, a, a], size=(n, n, n))
    desc = PWDesc(ecut=50, cell=grid.cell)
    desc = desc.new(comm=domain_comm, dtype=dtype)
    f = desc.empty(nbands, comm=band_comm)
    f.randomize()

    M = f.matrix_elements(f, function=function)
    out = Matrix(nbands, nbands, dist=(band_comm, -1, 1), dtype=dtype)
    # Junk that overflows if doubled.  Any stale content of ``out`` that
    # leaks into the result makes the comparison below fail.
    out.data[:] = 1e308
    f.matrix_elements(f, function=function, out=out)

    # The gathers run on every rank.  Only the rank that receives the full
    # array (f1 is not None) does the serial reference calculation and checks.
    f1 = f.gathergather()
    M2 = M.gather()
    out2 = out.gather()
    if f1 is not None:
        M1 = f1.matrix_elements(f1, function=function)
        M1.tril2full()
        M2.tril2full()
        out2.tril2full()
        assert abs(M1.data - M2.data).max() < 1e-11
        # The ``out=`` variant must agree with the serial reference too.
        assert abs(M1.data - out2.data).max() < 1e-11

    if function is None:
        g = f.new()
        g.randomize()
        M = f.matrix_elements(g)

        # f is unchanged, so the gathered f1 from above is reused.
        g1 = g.gathergather()
        M2 = M.gather()
        if f1 is not None:
            M1 = f1.matrix_elements(g1)
            M1.tril2full()
            M2.tril2full()
            dM = M1.data - M2.data
            assert abs(dM).max() < 1e-11

69 

70 

if __name__ == '__main__':
    # Run one case by hand, using the first communicator pair (block size 1).
    # Unpacking drains the whole generator, exactly as list(...) would.
    (domain_comm, band_comm), *_ = comms()
    test_me(domain_comm, band_comm, float, 4, None)