Coverage for gpaw/test/test_libelpa.py: 98%
58 statements
Generated by coverage.py v7.7.1 on 2025-07-19 00:19 +0000
1import pytest
2from gpaw.utilities.elpa import LibElpa
3import numpy as np
4import scipy as sp
5from gpaw.blacs import BlacsGrid
6from gpaw.mpi import world
# Module-wide marker: every test in this file is skipped whenever
# LibElpa.have_elpa() returns a falsy value.
pytestmark = pytest.mark.skipif(
    not LibElpa.have_elpa(),
    reason='not LibElpa.have_elpa()',
)
@pytest.mark.ci
@pytest.mark.parametrize('dtype', [float, complex])
@pytest.mark.parametrize('eigensolver', ['elpa', 'scalapack'])
@pytest.mark.parametrize('eigentype', ['normal', 'general'])
def test_libelpa(dtype, eigensolver, eigentype):
    """Check distributed eigenvalues against a serial NumPy/SciPy reference.

    A random Hermitian (real symmetric for ``dtype=float``) M x M matrix is
    built on rank 0 and distributed over a BLACS grid. It is then
    diagonalized with either LibElpa or the descriptor's ScaLAPACK
    divide-and-conquer routines.

    For ``eigentype == 'general'`` the generalized problem
    ``A x = eps S x`` is solved with a fixed Hermitian, positive-definite
    overlap matrix ``S``. Eigenvalues are compared on rank 0.
    """
    # Import the submodule explicitly: a plain ``import scipy`` does not
    # guarantee that ``scipy.linalg`` is available as an attribute.
    from scipy.linalg import eigh as generalized_eigh

    rng = np.random.RandomState(87878787)

    # One process, or a (size // 2) x 2 process grid.
    if world.size == 1:
        shape = 1, 1
    else:
        shape = world.size // 2, 2
    bg = BlacsGrid(world, *shape)

    M = 8
    blocksize = 2

    desc = bg.new_descriptor(M, M, blocksize, blocksize)
    sdesc = desc.as_serial()

    Aserial = sdesc.zeros(dtype=dtype)
    if world.rank == 0:
        Aserial[:] = rng.rand(*Aserial.shape)
        if dtype == complex:
            Aserial.imag += rng.rand(*Aserial.shape)
        # A + A^H is Hermitian, so both triangles hold consistent data.
        Aserial += Aserial.T.copy().conj()
    A = desc.distribute_from_master(Aserial)
    C2 = desc.zeros(dtype=dtype)
    eps2 = np.zeros(M)

    if eigentype == 'normal':
        if world.rank == 0:
            # Reference eigenvalues (ascending); eigenvectors are not needed.
            eps1 = np.linalg.eigvalsh(Aserial)

        # Copies are passed, presumably because the solvers may overwrite
        # their input matrices.
        if eigensolver == 'elpa':
            elpa = LibElpa(desc)
            elpa.diagonalize(A.copy(), C2, eps2)
        elif eigensolver == 'scalapack':
            desc.diagonalize_dc(A.copy(), C2, eps2)
    elif eigentype == 'general':
        Sserial = sdesc.zeros(dtype=dtype)
        if world.rank == 0:
            # Identity plus small Hermitian off-diagonal couplings. Both
            # conjugate partners are set so that the reference solver and
            # the distributed solvers see the same matrix regardless of
            # which triangle each of them reads. The couplings act on
            # disjoint index pairs (1, 3) and (2, 4), so the eigenvalues of
            # S lie in [0.5, 1.5] and S is positive definite.
            Sserial[:] = np.eye(M)
            Sserial[3, 1] += 0.5
            Sserial[1, 3] += 0.5
            if dtype == complex:
                Sserial[2, 4] += 0.2j
                Sserial[4, 2] -= 0.2j
        S = desc.distribute_from_master(Sserial)

        if world.rank == 0:
            eps1 = generalized_eigh(Aserial, Sserial, eigvals_only=True)

        if eigensolver == 'elpa':
            elpa = LibElpa(desc)
            elpa.general_diagonalize(A.copy(), S.copy(), C2, eps2)
        elif eigensolver == 'scalapack':
            desc.general_diagonalize_dc(A.copy(), S.copy(), C2, eps2)

    # Only rank 0 holds the serial reference eigenvalues.
    if world.rank == 0:
        print(eps1)
        print(eps2)
        err = np.abs(eps1 - eps2).max()
        assert err < 1e-13, err