Coverage for gpaw/test/core/test_atom_arrays.py: 100%
55 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-08 00:17 +0000
1import numpy as np
2import pytest
3from gpaw.core.atom_arrays import AtomArraysLayout, AtomDistribution
4from gpaw.mpi import world
5from gpaw.test.core.test_matrix_elements import comms
6from gpaw.gpu import cupy as cp
def test_aa_to_full():
    """Round-trip a symmetric 3x3 block through packed lower-triangle form.

    The block is packed with ``to_lower_triangle`` and unpacked again with
    ``to_full``; the unpacked block must equal the original matrix.
    """
    sym_ii = np.array([[1, 2, 4],
                       [2, 3, 5],
                       [4, 5, 6]], dtype=float)
    full_aii = AtomArraysLayout([(3, 3)]).empty()
    full_aii[0][:] = sym_ii
    packed_ap = full_aii.to_lower_triangle()
    # The expected packing lists the lower triangle of sym_ii row by row:
    # (0,0), (1,0), (1,1), (2,0), (2,1), (2,2).
    expected_p = [1, 2, 3, 4, 5, 6]
    assert (packed_ap[0] == expected_p).all()
    restored_aii = packed_ap.to_full()
    assert (restored_aii[0] == sym_ii).all()
def test_scatter_from():
    """Gather per-atom arrays owned by rank 0, then scatter them back out.

    Each atom ``a`` stores ``a`` in its first slice and ``2 * a`` in its
    second; after scattering into a new distribution every locally owned
    atom must still hold those values.
    """
    natoms = 9
    # Initial distribution: every atom is owned by rank 0.
    on_root = AtomDistribution([0] * natoms, world)
    src_asii = AtomArraysLayout([(3, 3)] * natoms,
                                atomdist=on_root).empty(2)
    for a, src_sii in src_asii.items():
        # Only rank 0 is expected to iterate over any atoms here.
        assert world.rank == 0
        src_sii[0] = a
        src_sii[1] = 2 * a
    gathered = src_asii.gather()
    if world.rank == 0:
        assert (src_asii.data == gathered.data).all()
    # Target distribution built by from_number_of_atoms (presumably spreads
    # the atoms over the ranks of world — TODO confirm).
    spread = AtomDistribution.from_number_of_atoms(natoms, world)
    dst_asii = src_asii.layout.new(atomdist=spread).empty(2)
    # gather() may return None on some ranks; pass None through in that case.
    payload = None if gathered is None else gathered.data
    dst_asii.scatter_from(payload)
    for a, dst_sii in dst_asii.items():
        assert (dst_sii[0] == a).all()
        assert (dst_sii[1] == 2 * a).all()
def test_gather():
    """Gather two atoms owned by rank 1 (rank 0 when running serially).

    The gather uses ``broadcast=True``, and the result is checked on every
    rank: atom ``a`` must hold the value ``a + 1``.
    """
    owner = 1 if world.size > 1 else 0
    D_asii = AtomArraysLayout(
        [(1, 1)] * 2,
        atomdist=AtomDistribution([owner] * 2, world)).empty(1)
    if world.rank == owner:
        # Atom 0 gets 1 and atom 1 gets 2.
        for a in range(2):
            D_asii[a][:] = a + 1
    D2_asii = D_asii.gather(broadcast=True)
    assert D2_asii.data.shape == (1, 2)
    for a, D_sii in D2_asii.items():
        assert D_sii[0, 0, 0] == a + 1
@pytest.mark.gpu
@pytest.mark.parametrize('domain_comm, band_comm', list(comms()))
@pytest.mark.parametrize('xp', [np, cp])
def test_P_ani_dH_aii(domain_comm, band_comm, xp):
    """Check ``block_diag_multiply`` of projections with per-atom matrices.

    Every element of dH is 1 and every projection is 1j, so the expected
    output for an atom with ``ni`` projectors is ``ni * 1j`` everywhere
    (consistent with each projection row multiplied by an all-ones
    ni x ni block).
    """
    ni_a = [2, 3, 4, 17]
    square_shapes = [(ni, ni) for ni in ni_a]
    dH_asii = AtomArraysLayout(square_shapes,
                               atomdist=domain_comm,
                               xp=xp).empty(1)
    dH_asii.data[:] = 1.0
    proj_layout = AtomArraysLayout(ni_a,
                                   dtype=complex,
                                   atomdist=domain_comm,
                                   xp=xp)
    P_ani = proj_layout.empty(10, comm=band_comm)
    P_ani.data[:] = 1.0j
    out_ani = P_ani.new()
    P_ani.block_diag_multiply(dH_asii, out_ani, index=0)
    for a, out_ni in out_ani.items():
        expected = ni_a[a] * 1.0j
        assert (out_ni == expected).all()