Coverage for gpaw/lrtddft/finite_differences.py: 76%

71 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-14 00:18 +0000

1import os.path 

2import numpy as np 

3from ase import parallel as mpi 

4 

5from gpaw.lrtddft.excitation import ExcitationLogger 

6 

7 

class FiniteDifference:
    """Central finite differences of a scalar property over all atomic
    Cartesian coordinates.

    For every atom ``a`` and axis ``i`` the atom is displaced by ``+d`` and
    ``-d``, the property is evaluated at both geometries, and
    ``(E(-d) - E(+d)) / (2 d)`` is stored in ``value[a, i]``. This is the
    *negative* of the central-difference derivative (for an energy property
    this is force-like).
    """

    def __init__(self, atoms, propertyfunction,
                 save=False, name='fd', ending='',
                 d=0.001, parallel=1, log=None, txt='-', world=mpi.world):
        """
        atoms: Atoms object
            The atoms to work on.
        propertyfunction: function that returns a single number.
            The finite difference calculation is performed on this value.
            For proper parallel usage the function should either be
            a property of the atoms object
                fd = FiniteDifference(atoms, atoms.property_xyz)
            or an arbitrary function with the keyword "atoms"
                fd = FiniteDifference(atoms, function_xyz)
                xyz = fd.run(atoms=atoms)
        d: float
            Magnitude of displacements.
        save: If True the write statement of the calculator is called
            to save the displacement steps.
        name: string
            Name for restart data
        ending: string
            File ending for restart data
        parallel: int
            splits the world communicator into 'parallel' groups that
            calculate displacements of different atoms individually.
        log: logger object, optional
            Used for output; if None an ExcitationLogger on ``world`` is
            created and pointed at ``txt``.
        txt: str
            Output target assigned to the default logger's ``fd``.
        world: communicator
            Communicator over which work is distributed and summed.
        """
        self.atoms = atoms
        self.indices = np.arange(len(atoms))
        self.propertyfunction = propertyfunction
        self.save = save
        self.name = name
        self.ending = ending
        self.d = d
        self.world = world
        # Result array; allocated by run() or lazily by _ensure_value().
        self.value = None

        if log is not None:
            self.log = log
        else:
            # Use the communicator we were given, not the global one.
            self.log = ExcitationLogger(world=world)
            self.log.fd = txt

        if parallel > world.size:
            self.log('#', (self.__class__.__name__ + ':'),
                     'Serial calculation, keyword parallel ignored.')
            parallel = 1
        self.parallel = parallel

        assert world.size % parallel == 0

        natoms = len(self.atoms)
        self.cores_per_atom = world.size // parallel
        # Index of the worker group this rank belongs to; all ranks of one
        # group share the same atoms.
        myi = world.rank // self.cores_per_atom
        # Round-robin distribution of atoms over the worker groups.
        self.myindices = [a for a in range(natoms) if a % parallel == myi]

    def _ensure_value(self):
        """Allocate the (natoms, 3) result array if it does not exist yet."""
        if getattr(self, 'value', None) is None:
            self.value = np.zeros([len(self.atoms), 3])

    def calculate(self, a, i, filename='fd', **kwargs):
        """Evaluate finite difference along i'th axis on a'th atom.

        This triggers two calls to propertyfunction(), with atom a moved
        plus/minus d in the i'th axial direction, respectively.
        If save is True the +/- states are written via the calculator's
        write() after each evaluation. The original position is restored
        even if an evaluation raises.
        """
        self._ensure_value()

        if 'atoms' in kwargs:
            kwargs['atoms'] = self.atoms

        p0 = self.atoms.positions[a, i]
        try:
            self.atoms.positions[a, i] += self.d
            eplus = self.propertyfunction(**kwargs)
            if self.save is True:
                self.atoms.calc.write(filename + '+' + self.ending)

            self.atoms.positions[a, i] -= 2 * self.d
            eminus = self.propertyfunction(**kwargs)
            if self.save is True:
                self.atoms.calc.write(filename + '-' + self.ending)
        finally:
            # Never leave the shared Atoms object displaced.
            self.atoms.positions[a, i] = p0

        self.value[a, i] = (eminus - eplus) / (2 * self.d)

        if self.parallel > 1 and self.world.rank == 0:
            self.log('# rank', mpi.world.rank, 'Atom', a,
                     'direction', i, 'FD: ', self.value[a, i])
        else:
            self.log('Atom', a, 'direction', i,
                     'FD: ', self.value[a, i])

    def run(self, **kwargs):
        """Evaluate finite differences for all atoms.

        Each worker group handles its own atoms; results are summed over
        ``world``. Every rank in a group computes the same entries, so the
        array is divided by ``cores_per_atom`` before the sum.
        """
        self.value = np.zeros([len(self.atoms), 3])

        for filename, a, i in self.displacements():
            if a in self.myindices:
                self.calculate(a, i, filename=filename, **kwargs)

        self.world.barrier()
        self.value /= self.cores_per_atom
        self.world.sum(self.value)

        return self.value

    def displacements(self):
        """Yield (filename, atom index, axis) for every displacement."""
        for a in self.indices:
            for i in range(3):
                filename = ('{}_{}_{}'.format(self.name, a, 'xyz'[i]))
                yield filename, a, i

    def restart(self, restartfunction, **kwargs):
        """Use restartfunction to recalculate values from the saved files.

        If both files with the corresponding name ('+' and '-') are found,
        restartfunction is called to get the property values. It is called
        as restartfunction(self, filename, **kwargs), similar to the
        standard read() function. If no file is found, a calculation is
        initiated. Can be called without a preceding run().

        Example:
            def re(self, name):
                calc = Calculator(restart=name)
                return calc.get_potential_energy()

            fd = FiniteDifference(atoms, atoms.get_potential_energy)
            fd.restart(re)
        """
        self._ensure_value()

        for filename, a, i in self.displacements():
            fplus = filename + '+' + self.ending
            fminus = filename + '-' + self.ending
            if os.path.isfile(fplus) and os.path.isfile(fminus):
                eplus = restartfunction(self, fplus, **kwargs)
                eminus = restartfunction(self, fminus, **kwargs)
                self.value[a, i] = (eminus - eplus) / (2 * self.d)
            else:
                self.calculate(a, i, filename=filename, **kwargs)

        return self.value