Coverage for gpaw/test/__init__.py: 70%

60 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-14 00:18 +0000

1from functools import wraps 

2from typing import Tuple 

3 

4import gpaw.mpi as mpi 

5import numpy as np 

6from gpaw.atom.configurations import parameters, tf_parameters 

7from gpaw.atom.generator import Generator 

8from gpaw.typing import Array1D 

9 

10 

def print_reference(data_i, name='ref_i', fmt='%.12le'):
    """Print ``data_i`` on rank 0 as a Python list literal.

    The printed text has the form::

        ref_i = [1.000000000000e+00,
                 2.000000000000e+00]

    and can be pasted directly into a test as reference data.

    Parameters
    ----------
    data_i
        Iterable of values to print.
    name
        Variable name written on the left-hand side of the assignment.
    fmt
        ``%``-style format applied to each value.  The ``l`` in the default
        ``'%.12le'`` is a length modifier that Python accepts and ignores.
    """
    if mpi.world.rank != 0:
        return
    # Continuation lines are indented past "<name> = [" so that all values
    # line up in one column.
    separator = ',\n' + ' ' * (len(name) + 4)
    values = separator.join(fmt % val for val in data_i)
    # Assemble the complete literal rather than printing a trailing comma and
    # erasing it with '\b': the backspace character survives in redirected
    # output (yielding invalid Python), and for empty input it erased '['.
    print('%s = [%s]' % (name, values))

21 

22 

def findpeak(x: Array1D, y: Array1D) -> Tuple[float, float]:
    """Find peak.

    Locate the maximum of ``y`` sampled at positions ``x``. A parabola is
    fitted through the largest sample and its two neighbours, and the vertex
    of that parabola is returned.

    >>> x = np.linspace(1, 5, 10)
    >>> y = 1 - (x - np.pi)**2
    >>> x0, y0 = findpeak(x, y)
    >>> f'x0={x0:.6f}, y0={y0:.6f}'
    'x0=3.141593, y0=1.000000'

    Parameters
    ----------
    x
        Sample positions (array-like, same length as ``y``).
    y
        Sample values (array-like).

    Returns
    -------
    x0, y0
        Position and value of the interpolated maximum.

    Raises
    ------
    ValueError
        If the largest sample is the first or last point.  It then has no
        neighbour on one side, so no parabola can be fitted around it.
    """
    x = np.asarray(x)
    y = np.asarray(y)
    i = int(y.argmax())
    if i == 0 or i == len(y) - 1:
        # Without this check, i == 0 gives the empty slice x[-1:2] and
        # i == len(y) - 1 gives a rank-deficient 2-point quadratic fit.
        raise ValueError(
            f'maximum at index {i} lies on the boundary of {len(y)} '
            'samples; a neighbour on each side is required')
    # Fit in coordinates relative to x[i] to keep the fit well conditioned.
    a, b, c = np.polyfit(x[i - 1:i + 2] - x[i], y[i - 1:i + 2], 2)
    # A maximum requires negative curvature.
    assert a < 0
    dx = -0.5 * b / a
    x0 = x[i] + dx
    return x0, a * dx**2 + b * dx + c

38 

39 

def gen(symbol, exx=False, name=None, yukawa_gamma=None,
        write_xml=False, **kwargs):
    """Generate a setup for ``symbol`` on rank 0 and broadcast it.

    Parameters
    ----------
    symbol
        Chemical symbol passed to ``Generator`` and used to look up the
        generation parameters.
    exx, name, yukawa_gamma, write_xml
        Forwarded unchanged to ``Generator.run``.
    **kwargs
        Forwarded to ``Generator``.  ``scalarrel`` defaults to ``True``.
        A truthy ``orbital_free`` selects ``tf_parameters`` (falling back to
        ``{'rcut': 0.9}`` for symbols not listed there).  Otherwise
        ``parameters[symbol]`` is used.

    Returns
    -------
    The object returned by ``Generator.run`` on rank 0, broadcast from
    rank 0 so that every rank returns it.
    """
    setup = None
    if mpi.rank == 0:
        kwargs.setdefault('scalarrel', True)
        g = Generator(symbol, **kwargs)
        # Check the value, not key presence: an explicit orbital_free=False
        # must not select the orbital-free (Thomas-Fermi) parameter table.
        if kwargs.get('orbital_free'):
            run_parameters = tf_parameters.get(symbol, {'rcut': 0.9})
        else:
            run_parameters = parameters[symbol]
        setup = g.run(exx=exx, name=name, yukawa_gamma=yukawa_gamma,
                      write_xml=write_xml, **run_parameters)
    setup = mpi.broadcast(setup, 0)
    return setup

57 

58 

def only_on_master(comm, broadcast=None):
    """Build a decorator that runs the decorated function on rank 0 only.

    On every other rank the call is skipped and ``None`` is used as the
    result.  All ranks then call ``comm.barrier()``.  If ``broadcast`` is
    given, every rank returns ``broadcast(result, comm=comm)``; otherwise
    each rank returns its local result (``None`` off rank 0).

    Parameters
    ----------
    comm
        communicator; must provide ``rank`` and ``barrier()``
    broadcast
        callable ``broadcast(value, comm=comm)`` used to share the return
        value, or `None` for no broadcasting
    """
    def decorator(func):
        @wraps(func)
        def on_master(*args, **kwargs):
            result = func(*args, **kwargs) if comm.rank == 0 else None
            # Keep the other ranks waiting until rank 0 has finished.
            comm.barrier()
            if broadcast is None:
                return result
            return broadcast(result, comm=comm)
        return on_master
    return decorator

83 

84 

def calculate_numerical_forces(atoms, eps=1e-6, iatoms=None, icarts=None):
    """Compute finite-difference forces on ``atoms``.

    If ``ase.calculators.fd.calculate_numerical_forces`` can be imported, the
    call is delegated to it with the same arguments.  Otherwise the function
    falls back to ``ase.calculators.test.numeric_force``.  That function is
    evaluated for each atom index in ``iatoms`` (default: all atoms) and each
    Cartesian direction in ``icarts`` (default: 0, 1, 2).

    Parameters
    ----------
    atoms
        Atoms object to differentiate (presumably an ASE ``Atoms`` with a
        calculator attached -- confirm against callers).
    eps
        Finite-difference displacement.
    iatoms
        Atom indices to evaluate, or ``None`` for every atom.
    icarts
        Cartesian components to evaluate, or ``None`` for all three.

    Returns
    -------
    In the fallback path, a NumPy array of shape
    ``(len(iatoms), len(icarts))``.  In the delegating path, whatever the
    ASE function returns.
    """
    try:
        from ase.calculators.fd import \
            calculate_numerical_forces as fd_forces
    except ImportError:
        # The installed ASE has no ase.calculators.fd module.
        fd_forces = None
    if fd_forces is not None:
        return fd_forces(atoms, eps, iatoms, icarts)

    from ase.calculators.test import numeric_force
    atom_indices = range(len(atoms)) if iatoms is None else iatoms
    directions = [0, 1, 2] if icarts is None else icarts
    forces = [[numeric_force(atoms, a, c, eps) for c in directions]
              for a in atom_indices]
    return np.array(forces)