Coverage for gpaw/test/nlopt/test_nlodata.py: 92%
51 statements
coverage.py v7.7.1, created at 2025-07-20 00:19 +0000
1import pytest
2import numpy as np
4from gpaw.mpi import world
5from gpaw.nlopt.basic import NLOData
def generate_testing_data(ns, nk, nb, seed=42):
    """Return reproducible random input arrays for NLOData tests.

    Parameters
    ----------
    ns, nk, nb : int
        Sizes of the ``s``, ``k`` and ``n`` axes of the generated arrays
        (presumably spin channels, k-points and bands).
    seed : int
        Seed for numpy's default random generator.  Equal seeds give
        identical arrays, which lets every MPI rank build the same data
        independently.

    Returns
    -------
    tuple
        ``(w_sk, f_skn, E_skn, p_skvnn)`` with shapes ``(ns, nk)``,
        ``(ns, nk, nb)``, ``(ns, nk, nb)`` and ``(ns, nk, 3, nb, nb)``.
        The first three are real; ``p_skvnn`` is complex.
    """
    # Honour the caller's seed; previously it was silently replaced by 42.
    rng = np.random.default_rng(seed=seed)

    w_sk = rng.random((ns, nk))
    f_skn = rng.random((ns, nk, nb))
    E_skn = rng.random((ns, nk, nb))
    p_skvnn = (rng.random((ns, nk, 3, nb, nb))
               + 1j * rng.random((ns, nk, 3, nb, nb)))

    return w_sk, f_skn, E_skn, p_skvnn
@pytest.mark.skipif(world.size > 1, reason='Serial only')
def test_write_load_serial(in_tmp_dir):
    """Write NLOData to an .npz file and read it back on a single process.

    Every array passed to the constructor must come back unchanged
    after the write/load round trip.
    """
    arrays = generate_testing_data(2, 4, 20)
    NLOData(*arrays, world).write('nlodata.npz')

    restored = NLOData.load('nlodata.npz', world)
    attr_names = ('w_sk', 'f_skn', 'E_skn', 'p_skvnn')
    for attr, expected in zip(attr_names, arrays):
        assert getattr(restored, attr) == pytest.approx(expected, abs=1e-16)
def test_serial_file_parallel_data(in_tmp_dir):
    """Distributing data held only on rank 0 must survive a file round trip.

    Rank 0 generates the random arrays and the other ranks pass ``None``.
    The per-rank dictionary returned by ``distribute()`` for the in-memory
    object is compared entry by entry with the one obtained after loading
    the file that object wrote.
    """
    # Random data only on rank = 0
    if world.rank == 0:
        w_sk, f_skn, E_skn, p_skvnn = generate_testing_data(2, 4, 20)
    else:
        w_sk = f_skn = E_skn = p_skvnn = None

    nlo = NLOData(w_sk, f_skn, E_skn, p_skvnn, world)
    nlo.write('nlodata.npz')
    k_info = nlo.distribute()

    loaded = NLOData.load('nlodata.npz', world)
    k_info_new = loaded.distribute()

    # Compare by key: zip() over two dict views would silently truncate
    # on a length mismatch and pairs entries by order, not by key.
    assert k_info_new.keys() == k_info.keys()
    for u, data in k_info.items():
        new_data = k_info_new[u]
        for i in range(4):
            assert new_data[i] == pytest.approx(data[i], abs=1e-16)
def test_write_load_parallel(in_tmp_dir):
    """Each rank's distributed share of loaded data must match the source.

    Every rank generates the same arrays (fixed seed), writes/loads them
    and checks the entries it receives from ``distribute()`` against the
    corresponding slices of the original arrays.
    """
    # Same random data array on each core
    w_sk, f_skn, E_skn, p_skvnn = generate_testing_data(2, 4, 20)

    nlo = NLOData(w_sk, f_skn, E_skn, p_skvnn, world)
    nlo.write('nlodata.npz')

    newdata = NLOData.load('nlodata.npz', world)
    k_info = newdata.distribute()

    nk = w_sk.shape[1]
    # Compare the distributed data with original data.  The key u is
    # treated as a flattened index u = s * nk + k; divmod recovers (s, k)
    # for any number of s-channels instead of assuming exactly two.
    for u, data in k_info.items():
        s, k = divmod(u, nk)
        assert data[0] == pytest.approx(w_sk[s, k], abs=1e-16)
        assert data[1] == pytest.approx(f_skn[s, k], abs=1e-16)
        assert data[2] == pytest.approx(E_skn[s, k], abs=1e-16)
        assert data[3] == pytest.approx(p_skvnn[s, k], abs=1e-16)