Coverage for gpaw/test/nlopt/test_nlodata.py: 92%

51 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-20 00:19 +0000

1import pytest 

2import numpy as np 

3 

4from gpaw.mpi import world 

5from gpaw.nlopt.basic import NLOData 

6 

7 

def generate_testing_data(ns, nk, nb, seed=42):
    """Generate reproducible random arrays for the NLOData tests.

    Parameters
    ----------
    ns, nk, nb : int
        Sizes of the leading axes of the returned arrays (judging by the
        array-name suffixes these are presumably spins, k-points and
        bands -- confirm against NLOData).
    seed : int
        Seed for the NumPy random generator. The same seed always
        produces the same arrays.

    Returns
    -------
    tuple
        ``(w_sk, f_skn, E_skn, p_skvnn)`` with shapes ``(ns, nk)``,
        ``(ns, nk, nb)``, ``(ns, nk, nb)`` and ``(ns, nk, 3, nb, nb)``.
        The first three are real; ``p_skvnn`` is complex.
    """
    # Honour the caller's seed; it used to be hard-coded to 42 here,
    # which made the ``seed`` argument a no-op.
    rng = np.random.default_rng(seed=seed)

    w_sk = rng.random((ns, nk))
    f_skn = rng.random((ns, nk, nb))
    E_skn = rng.random((ns, nk, nb))
    # Real part is drawn first, then the imaginary part; keep this order
    # so the default-seed data stays identical.
    p_skvnn = rng.random((ns, nk, 3, nb, nb)) \
        + 1j * rng.random((ns, nk, 3, nb, nb))

    return w_sk, f_skn, E_skn, p_skvnn

18 

19 

@pytest.mark.skipif(world.size > 1, reason='Serial only')
def test_write_load_serial(in_tmp_dir):
    """Write an NLOData object to disk and check it loads back unchanged.

    Skipped when running on more than one MPI rank.
    """
    expected = generate_testing_data(2, 4, 20)

    nlo = NLOData(*expected, world)
    nlo.write('nlodata.npz')

    newdata = NLOData.load('nlodata.npz', world)

    # Attribute names are listed in the same order as the generated arrays.
    attribute_names = ('w_sk', 'f_skn', 'E_skn', 'p_skvnn')
    for attribute, reference in zip(attribute_names, expected):
        loaded = getattr(newdata, attribute)
        assert loaded == pytest.approx(reference, abs=1e-16)

32 

33 

def test_serial_file_parallel_data(in_tmp_dir):
    """Check that distributing reloaded data matches the original.

    The arrays exist only on rank 0; every other rank passes ``None``.
    The object is written to file, distributed, loaded back and
    distributed again, and the two distributions are compared entry
    by entry.
    """
    # Random data only on rank = 0
    if world.rank == 0:
        arrays = generate_testing_data(2, 4, 20)
    else:
        arrays = (None, None, None, None)

    nlo = NLOData(*arrays, world)
    nlo.write('nlodata.npz')
    k_info = nlo.distribute()

    reloaded = NLOData.load('nlodata.npz', world)
    k_info_new = reloaded.distribute()

    pairs = zip(k_info_new.values(), k_info.values())
    for fresh, reference in pairs:
        # Each entry carries four arrays; compare them position by position.
        for i in range(4):
            assert fresh[i] == pytest.approx(reference[i], abs=1e-16)

55 

56 

def test_write_load_parallel(in_tmp_dir):
    """Write on all ranks, load, distribute and compare with the input.

    Every rank builds the same arrays, so each distributed entry can be
    checked against the matching slice of the original data.
    """
    # Same random data array on each core
    reference = generate_testing_data(2, 4, 20)
    nk = reference[0].shape[1]

    nlo = NLOData(*reference, world)
    nlo.write('nlodata.npz')

    newdata = NLOData.load('nlodata.npz', world)
    k_info = newdata.distribute()

    # Compare the distributed data with original data
    for u, data in k_info.items():
        # The first nk keys are indexed with s = 0, all later ones with s = 1.
        s = 1 if u >= nk else 0
        k = u % nk
        for i, array in enumerate(reference):
            assert data[i] == pytest.approx(array[s, k], abs=1e-16)