Coverage for gpaw/test/core/test_atom_arrays.py: 100%

55 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-08 00:17 +0000

1import numpy as np 

2import pytest 

3from gpaw.core.atom_arrays import AtomArraysLayout, AtomDistribution 

4from gpaw.mpi import world 

5from gpaw.test.core.test_matrix_elements import comms 

6from gpaw.gpu import cupy as cp 

7 

8 

def test_aa_to_full():
    """Round-trip a symmetric 3x3 atom block through lower-triangle packing.

    Fills a single-atom array with a symmetric matrix, packs it with
    ``to_lower_triangle`` and checks both the packed values and that
    ``to_full`` reproduces the original matrix exactly.
    """
    full_ii = np.array([[1.0, 2.0, 4.0],
                        [2.0, 3.0, 5.0],
                        [4.0, 5.0, 6.0]])
    arrays = AtomArraysLayout([(3, 3)]).empty()
    arrays[0][:] = full_ii
    packed = arrays.to_lower_triangle()
    # Expected values read the lower triangle row by row:
    # (0,0), (1,0), (1,1), (2,0), (2,1), (2,2).
    expected_packed = [1, 2, 3, 4, 5, 6]
    assert (packed[0] == expected_packed).all()
    unpacked = packed.to_full()
    assert (unpacked[0] == full_ii).all()

18 

19 

def test_scatter_from():
    """Gather atom arrays held on rank 0, then scatter them over all ranks.

    Atom ``a`` gets the value ``a`` in spin slot 0 and ``2 * a`` in spin
    slot 1. After redistributing with
    ``AtomDistribution.from_number_of_atoms``, every rank checks the atoms
    it received.
    """
    natoms = 9
    # Every atom is initially assigned to rank 0.
    on_root = AtomDistribution([0] * natoms, world)
    src = AtomArraysLayout([(3, 3)] * natoms, atomdist=on_root).empty(2)
    for a, b_sii in src.items():
        # Only rank 0 owns atoms, so only rank 0 should reach this loop body.
        assert world.rank == 0
        b_sii[0] = a
        b_sii[1] = 2 * a
    # Collective: every rank must call gather(); it may return None
    # (presumably on non-root ranks).
    gathered = src.gather()
    if world.rank == 0:
        assert (src.data == gathered.data).all()
    spread = AtomDistribution.from_number_of_atoms(natoms, world)
    dst = src.layout.new(atomdist=spread).empty(2)
    # Collective: ranks without gathered data pass None.
    dst.scatter_from(None if gathered is None else gathered.data)
    for a, b_sii in dst.items():
        assert (b_sii[0] == a).all()
        assert (b_sii[1] == 2 * a).all()

37 

38 

def test_gather():
    """Gather two atoms owned by rank 1 (rank 0 in serial runs) to all ranks.

    The owning rank fills atom ``a`` with ``a + 1``. After
    ``gather(broadcast=True)``, every rank checks the full result.
    """
    # Equivalent to min(1, world.size - 1): rank 1 if it exists, else rank 0.
    owner = 1 if world.size > 1 else 0
    natoms = 2
    layout = AtomArraysLayout(
        [(1, 1)] * natoms,
        atomdist=AtomDistribution([owner] * natoms, world))
    D_asii = layout.empty(1)
    if world.rank == owner:
        for a in range(natoms):
            D_asii[a][:] = a + 1
    # Collective call; broadcast=True is expected to give every rank a result.
    D2_asii = D_asii.gather(broadcast=True)
    # empty(1) gives a leading dimension of 1; two atoms with one element each.
    assert D2_asii.data.shape == (1, 2)
    for a, D_sii in D2_asii.items():
        assert D_sii[0, 0, 0] == a + 1

52 

53 

@pytest.mark.gpu
@pytest.mark.parametrize('domain_comm, band_comm', list(comms()))
@pytest.mark.parametrize('xp', [np, cp])
def test_P_ani_dH_aii(domain_comm, band_comm, xp):
    """Check ``block_diag_multiply`` of projections with constant matrices.

    Every element of ``dH_asii`` is set to 1.0 and every projection to 1j.
    For each atom ``a``, the test expects every output element to equal
    ``ni_a[a] * 1j``.
    """
    ni_a = [2, 3, 4, 17]
    matrix_shapes = [(ni, ni) for ni in ni_a]
    dH_asii = AtomArraysLayout(matrix_shapes,
                               atomdist=domain_comm,
                               xp=xp).empty(1)
    dH_asii.data[:] = 1.0
    proj_layout = AtomArraysLayout(ni_a,
                                   dtype=complex,
                                   atomdist=domain_comm,
                                   xp=xp)
    P_ani = proj_layout.empty(10, comm=band_comm)
    P_ani.data[:] = 1.0j
    out_ani = P_ani.new()
    # index=0 selects the first slot of dH_asii (presumably spin 0) — confirm.
    P_ani.block_diag_multiply(dH_asii, out_ani, index=0)
    for a, out_ni in out_ani.items():
        expected = ni_a[a] * 1.0j
        assert (out_ni == expected).all()