Coverage for gpaw/test/test_libelpa.py: 98%

58 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-19 00:19 +0000

1import pytest 

2from gpaw.utilities.elpa import LibElpa 

3import numpy as np 

4import scipy as sp 

5from gpaw.blacs import BlacsGrid 

6from gpaw.mpi import world 

7 

# Skip every test in this module unless LibElpa.have_elpa() reports that
# ELPA support is available.
pytestmark = pytest.mark.skipif(
    not LibElpa.have_elpa(),
    reason='not LibElpa.have_elpa()',
)

10 

11 

@pytest.mark.ci
@pytest.mark.parametrize('dtype', [float, complex])
@pytest.mark.parametrize('eigensolver', ['elpa', 'scalapack'])
@pytest.mark.parametrize('eigentype', ['normal', 'general'])
def test_libelpa(dtype, eigensolver, eigentype):
    """Check distributed eigenvalues against a serial reference.

    A random Hermitian matrix A (and, for ``eigentype == 'general'``, a
    Hermitian positive-definite overlap matrix S) is built on rank 0,
    distributed over a BLACS grid and diagonalized with either LibElpa or
    the descriptor's divide-and-conquer ScaLAPACK routines.  Rank 0 then
    compares the resulting eigenvalues with ``np.linalg.eigh`` /
    ``scipy.linalg.eigh`` and requires a maximum absolute deviation
    below 1e-13.
    """
    rng = np.random.RandomState(87878787)

    # Serial runs use a 1x1 grid; parallel runs use a (size // 2) x 2 grid.
    if world.size == 1:
        shape = 1, 1
    else:
        shape = world.size // 2, 2
    bg = BlacsGrid(world, *shape)

    M = 8
    blocksize = 2

    desc = bg.new_descriptor(M, M, blocksize, blocksize)
    sdesc = desc.as_serial()

    # Random Hermitian A, generated on rank 0 only and then distributed.
    Aserial = sdesc.zeros(dtype=dtype)
    if world.rank == 0:
        Aserial[:] = rng.rand(*Aserial.shape)
        if dtype == complex:
            Aserial.imag += rng.rand(*Aserial.shape)
        # A + A^H is Hermitian (the diagonal's imaginary parts cancel).
        Aserial += Aserial.T.copy().conj()
    A = desc.distribute_from_master(Aserial)
    C2 = desc.zeros(dtype=dtype)
    eps2 = np.zeros(M)

    if eigentype == 'normal':
        if world.rank == 0:
            eps1, _ = np.linalg.eigh(Aserial)

        if eigensolver == 'elpa':
            elpa = LibElpa(desc)
            elpa.diagonalize(A.copy(), C2, eps2)
        elif eigensolver == 'scalapack':
            desc.diagonalize_dc(A.copy(), C2, eps2)
        else:
            raise ValueError(f'unknown eigensolver: {eigensolver!r}')
    elif eigentype == 'general':
        Sserial = sdesc.zeros(dtype=dtype)
        if world.rank == 0:
            Sserial[:] = np.eye(M)
            # Perturb both triangles so that S is genuinely Hermitian.
            # scipy.linalg.eigh reads only one triangle of S, and the
            # distributed solvers presumably do too.  A one-sided
            # perturbation would make the test depend on which triangle
            # each solver reads, and the complex entry would be ignored.
            # The perturbations keep S positive definite (eigenvalues in
            # [0.5, 1.5]).
            Sserial[3, 1] += 0.5
            Sserial[1, 3] += 0.5
            if dtype == complex:
                Sserial[2, 4] += 0.2j
                Sserial[4, 2] -= 0.2j
        S = desc.distribute_from_master(Sserial)

        if world.rank == 0:
            eps1, _ = sp.linalg.eigh(Aserial, Sserial)

        if eigensolver == 'elpa':
            elpa = LibElpa(desc)
            elpa.general_diagonalize(A.copy(), S.copy(), C2, eps2)
        elif eigensolver == 'scalapack':
            desc.general_diagonalize_dc(A.copy(), S.copy(), C2, eps2)
        else:
            raise ValueError(f'unknown eigensolver: {eigensolver!r}')

    # Only rank 0 holds the serial reference eigenvalues.
    if world.rank == 0:
        print(eps1)
        print(eps2)
        err = np.abs(eps1 - eps2).max()
        assert err < 1e-13, err