Coverage for gpaw/utilities/timelimit.py: 80%

99 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-19 00:19 +0000

1import time 

2import numpy as np 

3 

4from gpaw.analyse.observers import Observer 

5 

6 

def time_to_seconds(timestr):
    """Convert time to seconds.

    Parameters:

    timestr: float or string
        Float in seconds or string in format
        'DD-HH:MM:SS', 'HH:MM:SS', 'MM:SS', or 'SS'.

    Returns:

    The time in seconds as a float.

    Raises:

    ValueError
        If the string contains more than one '-' separator or a field
        that cannot be converted to an integer.
    """
    # Anything float() understands (numbers, plain 'SS' strings) is
    # returned directly; only formatted strings reach the parser below.
    try:
        return float(timestr)
    except ValueError:
        pass

    seconds = 0.0
    day_split = timestr.split('-')
    if len(day_split) > 1:
        # Explicit check instead of `assert`: asserts are removed under
        # `python -O`, which would let malformed input yield a wrong value.
        if len(day_split) != 2:
            raise ValueError(f'Invalid time string: {timestr!r}')
        seconds += int(day_split[0]) * 24 * 60 * 60
        timestr = day_split[1]

    # Walk the colon-separated fields from the right (seconds first),
    # multiplying the weight by 60 for each field further to the left.
    mult = 1
    for field in reversed(timestr.split(':')):
        seconds += int(field) * mult
        mult *= 60
    return seconds

32 

33 

class TimeLimiter(Observer):
    """Class for automatically breaking the loops of GPAW calculation.

    The time estimation is done by a polynomial fit to
    the data `(i, dt)`, where `i` is the iteration index and
    `dt` is the calculation time of that iteration.

    The loop is broken by adjusting paw.maxiter value (or equivalent).

    The timing data (`time_t`, `iteridx_t`, `p_k`) and the optional
    output file are kept on rank 0 of `paw.world` only; the decision
    made there is broadcast to the other ranks in `has_time()`.
    """

    # Keywords for supported loops
    scf = 'scf'
    tddft = 'tddft'

    def __init__(self, paw, timestart=None, timelimit='10:00',
                 output=None, interval=1):
        """__init__ method

        Parameters:

        paw:
            GPAW calculator
        timestart: float
            The start time defining the "zero time".
            Format: as given by time.time().
            If None, the time of construction is used.
        timelimit: float or string
            The allowed run time counted from `timestart`.
            Format: any supported by function `time_to_seconds()`.
        output: str
            The name of the output file for dumping the time estimates.
            If None, nothing is written.
        interval: int
            Passed on to `Observer.__init__` and `paw.attach`
            (presumably the number of iterations between calls to
            `update()` -- confirm against `Observer`).
        """
        Observer.__init__(self, interval)
        self.timelimit = time_to_seconds(timelimit)
        if timestart is None:
            self.time0 = time.time()
        else:
            self.time0 = timestart
        self.comm = paw.world
        self.output = output
        self.do_output = self.output is not None
        # Only the master rank owns the output file.
        if self.comm.rank == 0 and self.do_output:
            self.outf = open(self.output, 'w')
        # No loop is controlled until reset() has been called.
        self.loop = None
        paw.attach(self, interval, paw)

    def reset(self, loop, order=0, min_updates=5):
        """Reset the time estimation.

        Parameters:

        loop: str
            The keyword of the controlled loop.
        order: int
            The polynomial order of the fit used to estimate
            the run time between each update.
        min_updates: int
            The minimum number of updates until time estimates are given.

        Raises:

        RuntimeError
            If `loop` is not one of the supported keywords.
        """
        if loop not in [self.scf, self.tddft]:
            raise RuntimeError(f'Unsupported loop type: {loop}')
        self.loop = loop
        if self.comm.rank == 0:
            self.order = order
            # At least order + 1 samples are kept before fitting.
            self.min_updates = max(min_updates, order + 1)
            self.time_t = [time.time()]  # Add the initial time
            self.iteridx_t = []

    def update(self, paw):
        """Update time estimate and break calculation if necessary."""
        # Select the iteration index
        if self.loop is None:
            return
        elif self.loop == self.scf:
            iteridx = paw.scf.niter
        elif self.loop == self.tddft:
            iteridx = paw.niter

        # Update the arrays
        if self.comm.rank == 0:
            self.time_t.append(time.time())
            self.iteridx_t.append(iteridx)
            # Invalidate the cached fit; eta() recomputes it on demand.
            self.p_k = None

            if self.do_output:
                timediff = self.time_t[-1] - self.time_t[-2]
                line = 'update %12d %12.4f' % (iteridx, timediff)
                self.outf.write('%s\n' % line)

        # Check if there is time to do the next iteration
        if not self.has_time(iteridx + self.interval):
            # The calling loop is assumed to do "niter += 1"
            # after calling observers
            paw.log('{}: Breaking the loop '
                    'due to the time limit'.format(self.__class__.__name__))
            if self.loop == self.scf:
                paw.scf.maxiter = iteridx
            elif self.loop == self.tddft:
                paw.maxiter = iteridx

    def eta(self, iteridx):
        """Estimate the time required to calculate the iteration of
        the given index `iteridx`.

        `iteridx` may be a single number or a sequence of indices; for a
        sequence the estimates are summed.

        Returns the estimate in seconds on rank 0 (0.0 while fewer than
        `min_updates` updates have been recorded) and None on all
        other ranks.
        """
        if self.comm.rank == 0:
            if len(self.iteridx_t) < self.min_updates:
                eta = 0.0
            else:
                if self.p_k is None:
                    iteridx_t = np.array(self.iteridx_t)
                    time_t = np.array(self.time_t)
                    # Duration of each recorded iteration.
                    timediff_t = time_t[1:] - time_t[:-1]

                    self.p_k = np.polyfit(iteridx_t,
                                          timediff_t,
                                          self.order)
                if type(iteridx) in (int, float):
                    iteridx = [iteridx]
                iteridx_i = np.array(iteridx)
                # Negative extrapolations are clamped to zero.
                eta = max(0.0, np.sum(np.polyval(self.p_k, iteridx_i)))

            if self.do_output:
                line = 'eta %12s %12.4f' % (iteridx, eta)
                self.outf.write('%s\n' % line)
            return eta
        else:
            return None

    def has_time(self, iteridx):
        """Check if there is still time to calculate the iteration of
        the given index `iteridx`.

        Returns True when `timelimit` is None, when fewer than
        `min_updates` updates have been recorded, or when the estimated
        time is smaller than the time remaining; False otherwise.
        """
        if self.timelimit is None:
            return True
        # Calculate eta on master and broadcast to all ranks
        data_i = np.empty(1, dtype=int)
        if self.comm.rank == 0:
            if len(self.iteridx_t) < self.min_updates:
                data_i[0] = True
            else:
                time_required = self.eta(iteridx)
                time_available = self.timelimit - (time.time() - self.time0)
                data_i[0] = time_required < time_available
        self.comm.broadcast(data_i, 0)
        return bool(data_i[0])

    def __del__(self):
        # __del__ also runs for instances whose __init__ raised part-way
        # (e.g. on an invalid `timelimit`), in which case `comm`,
        # `do_output` and `outf` may not exist. Close the file only if
        # it was actually opened.
        outf = getattr(self, 'outf', None)
        if outf is not None:
            outf.close()