Coverage for gpaw/test/core/test_matrix_elements.py: 96%
55 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-09 00:21 +0000
1import pytest
3from gpaw.core import PWDesc, UGDesc
4from gpaw.core.matrix import Matrix
5from gpaw.mpi import world
def comms():
    """Generate ``(domain_comm, band_comm)`` pairs built from ``world``.

    Group sizes 1, 2, 4 and 8 are tried in that order; generation stops at
    the first group size that exceeds ``world.size``.
    """
    for group_size in (1, 2, 4, 8):
        if world.size < group_size:
            return
        # Block of ``group_size`` consecutive ranks containing this rank.
        block_start = world.rank // group_size * group_size
        domain_ranks = range(block_start, block_start + group_size)
        domain_comm = world.new_communicator(domain_ranks)
        # Every ``group_size``-th rank, starting at this rank's offset
        # within its block.
        band_ranks = range(world.rank % group_size, world.size, group_size)
        band_comm = world.new_communicator(band_ranks)
        yield domain_comm, band_comm
def func(f):
    """Return a copy of *f* whose ``data`` has been scaled by 2.3.

    Used as the ``function`` operator argument in the matrix-element test;
    *f* itself is left unmodified because the scaling is applied to the copy.
    """
    scaled = f.copy()
    scaled.data *= 2.3
    return scaled
def _assert_matrices_agree(reference, gathered):
    """Call ``tril2full()`` on both matrices and compare their ``data``.

    The largest absolute element-wise difference must be below 1e-11.
    """
    reference.tril2full()
    gathered.tril2full()
    difference = reference.data - gathered.data
    assert abs(difference).max() < 1e-11


# TODO: test also UGArray
@pytest.mark.parametrize('domain_comm, band_comm', list(comms()))
@pytest.mark.parametrize('dtype', [float, complex])
@pytest.mark.parametrize('nbands', [1, 7, 21])
@pytest.mark.parametrize('function', [None, func])
def test_me(domain_comm, band_comm, dtype, nbands, function):
    """Check distributed ``matrix_elements`` against the gathered arrays.

    Matrix elements are computed on arrays distributed over *domain_comm*
    and *band_comm*, gathered, and compared with the matrix elements
    computed directly from the gathered arrays.  When *function* is None
    the same comparison is repeated for two different sets of arrays.
    """
    cell_length = 2.5
    grid_points = 20
    grid = UGDesc(cell=[cell_length] * 3, size=(grid_points,) * 3)
    pw = PWDesc(ecut=50, cell=grid.cell)
    pw = pw.new(comm=domain_comm, dtype=dtype)
    arrays = pw.empty(nbands, comm=band_comm)
    arrays.randomize()

    distributed = arrays.matrix_elements(arrays, function=function)

    # Same calculation again, this time writing into a pre-filled matrix.
    # NOTE(review): presumably guards against the old contents of ``out``
    # being scaled instead of overwritten -- the result is not inspected.
    out = Matrix(nbands, nbands, dist=(band_comm, -1, 1), dtype=dtype)
    out.data[:] = 1e308  # will overflow when multiplied by 2
    arrays.matrix_elements(arrays, function=function, out=out)

    # The gathered arrays are None on ranks that do not receive them, so
    # the comparison only runs where they are available.
    arrays_all = arrays.gathergather()
    gathered = distributed.gather()
    if arrays_all is not None:
        reference = arrays_all.matrix_elements(arrays_all, function=function)
        _assert_matrices_agree(reference, gathered)

    if function is not None:
        return

    # Without an operator: matrix elements between two different sets.
    other = arrays.new()
    other.randomize()
    distributed = arrays.matrix_elements(other)

    arrays_all = arrays.gathergather()
    other_all = other.gathergather()
    gathered = distributed.gather()
    if arrays_all is not None:
        reference = arrays_all.matrix_elements(other_all)
        _assert_matrices_agree(reference, gathered)
if __name__ == '__main__':
    # Manual run: one real-valued, 4-band case without an operator, using
    # the first communicator pair.
    pairs = list(comms())
    domain, band = pairs[0]
    test_me(domain, band, float, 4, None)
73 test_me(d, b, float, 4, None)