Coverage for gpaw/response/context.py: 79%

75 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-14 00:18 +0000

1from __future__ import annotations 

2from typing import Union 

3from pathlib import Path 

4from time import ctime 

5from sys import stdout 

6 

7from inspect import isgeneratorfunction 

8from functools import wraps 

9 

10from ase.utils import IOContext 

11from ase.utils.timing import Timer 

12 

13import gpaw.mpi as mpi 

14 

15 

# A text-output destination: a filesystem path or a plain string
# (ResponseContext uses the string '-' as its default).
TXTFilename = Union[Path, str]

# Everything ResponseContext.from_input() accepts: an existing context,
# a dict of ResponseContext keyword arguments, or a txt destination.
ResponseContextInput = Union[
    'ResponseContext',
    dict,
    TXTFilename,
]

18 

19 

class ResponseContext:
    """Output stream, timer and MPI communicator shared by response code.

    Bundles a text output stream (opened through an ase ``IOContext``), a
    ``Timer`` and an MPI communicator so that response objects can print
    and time their work through a single handle.
    """

    def __init__(self, txt: TXTFilename = '-',
                 timer=None, comm=mpi.world, mode='w'):
        """Create the context.

        Parameters
        ----------
        txt:
            Output destination handed to ``IOContext.openfile``
            (default ``'-'``).
        timer:
            Timer to use; a new ``Timer()`` is created when falsy.
        comm:
            MPI communicator (default ``gpaw.mpi.world``).
        mode:
            File mode used when opening ``txt``.
        """
        self.comm = comm
        self.iocontext = IOContext()
        self.open(txt, mode)
        self.set_timer(timer)

    @staticmethod
    def from_input(context: ResponseContextInput) -> ResponseContext:
        """Coerce ``context`` into a ResponseContext.

        An existing ResponseContext is returned unchanged, a dict is
        unpacked as keyword arguments to the constructor, and a str/Path
        is used as the ``txt`` destination of a new context.

        Raises
        ------
        ValueError
            If ``context`` is none of the accepted input types.
        """
        if isinstance(context, ResponseContext):
            return context
        if isinstance(context, dict):
            return ResponseContext(**context)
        if isinstance(context, (Path, str)):  # TXTFilename
            return ResponseContext(txt=context)
        # Build the message explicitly: passing the offending object as a
        # second exception argument makes str(error) render as a tuple.
        raise ValueError(
            f'Expected ResponseContextInput, got {context!r}')

    def open(self, txt, mode):
        """Open ``txt`` through the IOContext and store the stream as fd."""
        if txt is stdout and self.comm.rank != 0:
            # Only rank 0 should write to the real stdout; other ranks
            # pass None instead (presumably turned into a null sink by
            # IOContext.openfile -- confirm against ase.utils).
            txt = None
        self.fd = self.iocontext.openfile(txt, self.comm, mode)

    def set_timer(self, timer):
        """Use ``timer``, or a new ``Timer()`` if ``timer`` is falsy."""
        self.timer = timer or Timer()

    def close(self):
        """Close whatever the IOContext has opened."""
        self.iocontext.close()

    def __del__(self):
        # __init__ may fail before ``iocontext`` is assigned (e.g. unknown
        # keyword arguments coming from ``from_input(dict)``); skip the
        # close in that case instead of raising AttributeError from the
        # finalizer.
        if hasattr(self, 'iocontext'):
            self.close()

    def with_txt(self, txt, mode='w'):
        """Return a new context writing to ``txt``.

        The new context shares this context's communicator and timer.
        """
        return ResponseContext(txt=txt, comm=self.comm, timer=self.timer,
                               mode=mode)

    def print(self, *args, flush=True, **kwargs):
        """Print to this context's output stream (flushed by default)."""
        print(*args, file=self.fd, flush=flush, **kwargs)

    def new_txt_and_timer(self, txt, timer=None):
        """Write the current timings, then switch to ``txt`` and ``timer``.

        The old output stream is closed and ``txt`` is opened in 'w' mode.
        """
        self.write_timer()
        # Close old output file and create a new one
        self.close()
        self.open(txt, mode='w')
        self.set_timer(timer)

    def write_timer(self):
        """Write the timer report followed by the current time."""
        self.timer.write(self.fd)
        self.print(ctime())

70 

class timer:
    """Decorator for timing a method call.
    NB: Includes copy-paste from ase, which is suboptimal...

    The instance the decorated method is called on must expose
    ``self.context.timer`` with ``start(name)`` and ``stop()`` methods.
    For generator methods, only the time spent producing each item is
    timed, not the time the consumer spends between items.

    Example::

        from gpaw.response.context import timer

        class A:
            def __init__(self, context):
                self.context = context

            @timer('Add two numbers')
            def add(self, x, y):
                return x + y

    """
    def __init__(self, name):
        self.name = name

    def __call__(self, method):
        if isgeneratorfunction(method):
            @wraps(method)
            def new_method(slf, *args, **kwargs):
                gen = method(slf, *args, **kwargs)
                while True:
                    slf.context.timer.start(self.name)
                    try:
                        x = next(gen)
                    except StopIteration as exc:
                        # Propagate the generator's return value so that
                        # ``yield from`` callers still receive it.
                        return exc.value
                    finally:
                        slf.context.timer.stop()
                    yield x
        else:
            @wraps(method)
            def new_method(slf, *args, **kwargs):
                slf.context.timer.start(self.name)
                try:
                    return method(slf, *args, **kwargs)
                finally:
                    # Stop even if the method raised, so the timer stack
                    # is not left with a dangling entry.  The timer is
                    # re-read here because the method may have replaced
                    # it (e.g. via ResponseContext.new_txt_and_timer);
                    # presumably that is why IndexError from stopping a
                    # timer with nothing running is ignored.
                    try:
                        slf.context.timer.stop()
                    except IndexError:
                        pass
        return new_method