Coverage for gpaw/test/xc/test_qna_stress.py: 82%

49 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-12 00:18 +0000

1import pytest 

2import numpy as np 

3from ase.parallel import parprint 

4from gpaw import GPAW 

5from gpaw.mpi import world 

6 

7 

def _strain_matrices(component, d):
    """Return ``(plus, minus, weight)`` describing one finite-difference step.

    ``plus`` and ``minus`` are the 3x3 deformation matrices applied to the
    cell for the forward and backward displacement; ``weight`` is the
    factor multiplying ``d * V`` in the central-difference denominator.

    Raises ``ValueError`` when *component* is not a pair of indices in 0..2.
    """
    for i in range(3):
        j = (i - 2) % 3
        if component == (i, i):
            plus = np.eye(3)
            plus[i, i] += d
            # Derived from ``plus`` with the same ``-= 2 * d`` step so the
            # floating-point values are identical to the in-place update.
            minus = plus.copy()
            minus[i, i] -= 2 * d
            return plus, minus, 2
        if component == (i, j) or component == (j, i):
            # The shear is applied symmetrically to both off-diagonal
            # entries, hence the extra factor of two in the denominator.
            plus = np.eye(3)
            plus[i, j] = d
            plus[j, i] = d
            minus = np.eye(3)
            minus[i, j] = -d
            minus[j, i] = -d
            return plus, minus, 4
    raise ValueError(f'Invalid component {component}')


def numeric_stress(atoms, d=1e-6, component=None):
    """Estimate one stress-tensor component by central finite differences.

    The cell of *atoms* is deformed by ``+d`` and ``-d`` along the requested
    component (atomic positions are scaled with the cell), the potential
    energy is evaluated at both deformations with ``force_consistent=True``,
    and the energy difference is divided by the displacement and the
    volume of the undeformed cell.

    Parameters
    ----------
    atoms
        Object providing ``cell``, ``get_volume()``, ``set_cell()`` and
        ``get_potential_energy()``.  Its cell is modified during the call
        and reset to the initial cell before returning.
    d
        Magnitude of the strain displacement.
    component
        Tuple ``(i, j)`` with ``i`` and ``j`` in ``0..2`` selecting the
        tensor component.

    Returns
    -------
    The finite-difference estimate of the selected stress component.

    Raises
    ------
    ValueError
        If *component* does not name a valid tensor component (this
        includes the default ``None``).
    """
    # Validate first so an invalid request never triggers an energy
    # evaluation.
    plus, minus, weight = _strain_matrices(component, d)

    cell = atoms.cell.copy()
    V = atoms.get_volume()
    try:
        atoms.set_cell(np.dot(cell, plus), scale_atoms=True)
        eplus = atoms.get_potential_energy(force_consistent=True)

        atoms.set_cell(np.dot(cell, minus), scale_atoms=True)
        eminus = atoms.get_potential_energy(force_consistent=True)
    finally:
        # Always hand back the atoms with their reference cell, even when
        # an energy evaluation fails part-way through.
        atoms.set_cell(cell, scale_atoms=True)
    return (eplus - eminus) / (weight * d * V)

39 

40 

@pytest.mark.old_gpaw_only
@pytest.mark.skipif(world.size > 1, reason='See #898')
def test_xc_qna_stress(in_tmp_dir, gpw_files):
    """Compare the analytical stress with a finite-difference estimate.

    The calculator stored under ``gpw_files['Cu3Au_qna']`` is loaded, its
    cell is distorted, and the stress returned by ``get_stress`` is checked
    against ``numeric_stress`` for the selected tensor component.  The test
    is skipped when running on more than one MPI rank.
    """
    # Distortion applied to the stored cell before stresses are evaluated.
    distortion = [[1.02, 0, 0.03],
                  [0, 0.99, -0.02],
                  [0.2, -0.01, 1.03]]

    atoms = GPAW(gpw_files['Cu3Au_qna']).get_atoms()
    atoms.set_cell(np.dot(atoms.cell, distortion), scale_atoms=True)

    s_analytical = atoms.get_stress(voigt=False)
    print(s_analytical)

    components = [(0, 0), (0, 1), (0, 2), (1, 1), (1, 2), (2, 2)]
    # Only the (0, 1) component is exercised.
    for index in [1]:
        ij = components[index]
        reference = s_analytical[ij]
        s_numerical = numeric_stress(atoms, 1e-5, ij)
        s_err = s_numerical - reference

        parprint('Analytical stress:', reference)
        parprint('Numerical stress :', s_numerical)
        parprint('Error in stress :', s_err)
        assert np.abs(s_err) < 0.002