Coverage for gpaw/lrtddft/finite_differences.py: 76%
71 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-14 00:18 +0000
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-14 00:18 +0000
1import os.path
2import numpy as np
3from ase import parallel as mpi
5from gpaw.lrtddft.excitation import ExcitationLogger
class FiniteDifference:
    """Central finite differences of a scalar property w.r.t. atom positions.

    For every atom ``a`` and Cartesian axis ``i`` the property is evaluated
    with the atom displaced by ``+d`` and by ``-d``, and the quantity
    ``(P(-d) - P(+d)) / (2 * d)`` is stored in ``self.value[a, i]``.
    Note the sign: this is the *negative* central-difference derivative
    (a force-like quantity when the property is an energy).
    """

    def __init__(self, atoms, propertyfunction,
                 save=False, name='fd', ending='',
                 d=0.001, parallel=1, log=None, txt='-', world=mpi.world):
        """
        atoms: Atoms object
            The atoms to work on.
        propertyfunction: function that returns a single number.
            The finite difference calculation is performed on this value.
            For proper parallel usage the function should either be
            a property of the atoms object
                fd = FiniteDifference(atoms, atoms.property_xyz)
            or an arbitrary function with the keyword "atoms"
                fd = FiniteDifference(atoms, function_xyz)
                xyz = fd.run(atoms=atoms)
        save: bool
            If true, the write method of the calculator is called after
            each displaced evaluation to save the displacement steps.
        name: string
            Name prefix for restart data.
        ending: string
            File ending appended to the restart file names.
        d: float
            Magnitude of displacements.
        parallel: int
            Splits ``world`` into 'parallel' groups that calculate
            displacements of different atoms independently.
            Ignored (serial run) if larger than ``world.size``.
        log: callable or None
            Logger used for progress output. If None, an ExcitationLogger
            is created and ``txt`` is assigned to its ``fd`` attribute.
        txt:
            Output target for the default logger (unused if ``log`` given).
        world: communicator
            Communicator over which the displacements are distributed.
        """
        self.atoms = atoms
        self.indices = np.arange(len(atoms))
        self.propertyfunction = propertyfunction
        self.save = save
        self.name = name
        self.ending = ending
        self.d = d
        self.world = world
        # Results array of shape (natoms, 3); re-zeroed by run() and
        # restart(). Allocated here so calculate()/restart() work before run().
        self.value = np.zeros((len(atoms), 3))

        if log is not None:
            self.log = log
        else:
            # NOTE(review): the default logger is bound to mpi.world rather
            # than to the ``world`` argument -- confirm this is intended when
            # a sub-communicator is passed.
            self.log = ExcitationLogger(world=mpi.world)
            self.log.fd = txt

        if parallel > world.size:
            self.log('#', (self.__class__.__name__ + ':'),
                     'Serial calculation, keyword parallel ignored.')
            parallel = 1
        self.parallel = parallel

        assert world.size % parallel == 0, \
            'world.size must be a multiple of parallel'

        # Number of ranks that work together on the same atoms.
        self.cores_per_atom = world.size // parallel
        # Index of the worker group this rank belongs to.
        myi = world.rank // self.cores_per_atom
        # Distribute the atoms round-robin over the worker groups.
        self.myindices = [a for a in range(len(self.atoms))
                          if a % parallel == myi]

    def _evaluate(self, savename, kwargs):
        """Return ``propertyfunction(**kwargs)``; save the calc if requested.

        ``kwargs`` is taken as a dict (not ``**kwargs``) so that user keyword
        arguments can never collide with ``savename``.
        """
        result = self.propertyfunction(**kwargs)
        if self.save:
            self.atoms.calc.write(savename)
        return result

    def calculate(self, a, i, filename='fd', **kwargs):
        """Evaluate the finite difference along the i'th axis of atom a.

        This triggers two calls to ``propertyfunction()``, with atom ``a``
        moved by +d and -d along axis ``i``, and stores
        ``(P(-d) - P(+d)) / (2 * d)`` in ``self.value[a, i]``.
        If ``save`` is true, the calculator is written to
        ``filename + '+' + ending`` and ``filename + '-' + ending`` after the
        respective evaluation. The original position of the atom is
        restored even if an evaluation raises.
        """
        if 'atoms' in kwargs:
            # Hand the function this object's (displaced) atoms.
            kwargs['atoms'] = self.atoms

        p0 = self.atoms.positions[a, i]
        try:
            self.atoms.positions[a, i] += self.d
            eplus = self._evaluate(filename + '+' + self.ending, kwargs)

            self.atoms.positions[a, i] -= 2 * self.d
            eminus = self._evaluate(filename + '-' + self.ending, kwargs)
        finally:
            # Never leave the structure displaced, also on error.
            self.atoms.positions[a, i] = p0

        self.value[a, i] = (eminus - eplus) / (2 * self.d)

        if self.parallel > 1 and self.world.rank == 0:
            # mpi.world.rank (global rank) distinguishes the output of
            # different sub-communicators when ``world`` is not mpi.world.
            self.log('# rank', mpi.world.rank, 'Atom', a,
                     'direction', i, 'FD: ', self.value[a, i])
        else:
            self.log('Atom', a, 'direction', i,
                     'FD: ', self.value[a, i])

    def run(self, **kwargs):
        """Evaluate finite differences for all atoms and directions.

        Each worker group computes the atoms in ``self.myindices``; the
        results are then summed over ``world``. All ``cores_per_atom`` ranks
        of a group compute the same entries, so the local array is divided
        by ``cores_per_atom`` before the sum.

        Returns ``self.value``, an array of shape (natoms, 3).
        """
        self.value = np.zeros((len(self.atoms), 3))

        mine = set(self.myindices)
        for filename, a, i in self.displacements():
            if a in mine:
                self.calculate(a, i, filename=filename, **kwargs)

        self.world.barrier()
        self.value /= self.cores_per_atom
        self.world.sum(self.value)

        return self.value

    def displacements(self):
        """Yield ``(filename, a, i)`` for each atom index and axis 0, 1, 2.

        The filename has the form ``'<name>_<a>_<x|y|z>'``.
        """
        for a in self.indices:
            for i, axis in enumerate('xyz'):
                yield '{}_{}_{}'.format(self.name, a, axis), a, i

    def restart(self, restartfunction, **kwargs):
        """Use restartfunction to recalculate values from saved files.

        If both files ``filename + '+' + ending`` and
        ``filename + '-' + ending`` exist for a displacement, restartfunction
        is called on each to obtain the property values. It is called as
        ``restartfunction(self, name, **kwargs)``, i.e. it receives this
        object and a file name, like the standard read() function.
        If the files are not found, a calculation is initiated instead.
        Can be called without a preceding run().

        Example:
            def re(self, name):
                calc = Calculator(restart=name)
                return calc.get_potential_energy()

            fd = FiniteDifference(atoms, atoms.get_potential_energy)
            fd.restart(re)

        Returns ``self.value``, an array of shape (natoms, 3).
        """
        # Every (a, i) entry is overwritten below, so start from zeros.
        # NOTE(review): unlike run(), this does not split the work over
        # ``parallel`` groups; every rank handles every displacement.
        self.value = np.zeros((len(self.atoms), 3))

        for filename, a, i in self.displacements():
            plusname = filename + '+' + self.ending
            minusname = filename + '-' + self.ending
            if os.path.isfile(plusname) and os.path.isfile(minusname):
                eplus = restartfunction(self, plusname, **kwargs)
                eminus = restartfunction(self, minusname, **kwargs)
                self.value[a, i] = (eminus - eplus) / (2 * self.d)
            else:
                self.calculate(a, i, filename=filename, **kwargs)

        return self.value