Coverage for gpaw/response/context.py: 79%
75 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-14 00:18 +0000
1from __future__ import annotations
2from typing import Union
3from pathlib import Path
4from time import ctime
5from sys import stdout
7from inspect import isgeneratorfunction
8from functools import wraps
10from ase.utils import IOContext
11from ase.utils.timing import Timer
13import gpaw.mpi as mpi
# A text-output destination: either a pathlib.Path or a string filename.
TXTFilename = Union[Path, str]

# Everything ResponseContext.from_input() accepts: a ResponseContext, a dict
# of ResponseContext keyword arguments, or a TXTFilename.  typing.Union
# flattens nested unions, so listing Path and str directly produces exactly
# the same type as nesting TXTFilename.
ResponseContextInput = Union['ResponseContext', dict, Path, str]
class ResponseContext:
    """Output file, timer and MPI communicator shared by response code.

    Text written with :meth:`print` goes to ``self.fd``, which is opened
    through ``self.iocontext`` (an ``ase.utils.IOContext``); timings are
    accumulated in ``self.timer``.
    """

    def __init__(self, txt: TXTFilename = '-',
                 timer=None, comm=mpi.world, mode='w'):
        """Store the communicator, open the output and set up the timer.

        Parameters
        ----------
        txt
            Output destination, handed to ``IOContext.openfile`` by
            :meth:`open`.
        timer
            Timer to record into; a fresh ``Timer()`` is used if falsy.
        comm
            MPI communicator (``gpaw.mpi.world`` by default).
        mode
            File mode used when opening ``txt``.
        """
        self.comm = comm
        self.iocontext = IOContext()
        self.open(txt, mode)
        self.set_timer(timer)

    @staticmethod
    def from_input(context: ResponseContextInput) -> ResponseContext:
        """Coerce ``context`` into a ResponseContext.

        An existing ResponseContext is returned unchanged, a dict is used
        as keyword arguments to the constructor, and a ``Path`` or ``str``
        is used as the ``txt`` argument.

        Raises
        ------
        ValueError
            If ``context`` is none of the accepted types.
        """
        if isinstance(context, ResponseContext):
            return context
        if isinstance(context, dict):
            return ResponseContext(**context)
        if isinstance(context, (Path, str)):  # TXTFilename
            return ResponseContext(txt=context)
        # Single formatted message: passing the object as a second exception
        # argument made str(err) render as a tuple.
        raise ValueError(
            f'Expected ResponseContextInput, got {context!r}')

    def open(self, txt, mode):
        """Open ``txt`` through ``self.iocontext`` and store it as ``self.fd``.

        Parameters
        ----------
        txt
            Output destination passed to ``IOContext.openfile``.
        mode
            File mode passed to ``IOContext.openfile``.
        """
        if txt is stdout and self.comm.rank != 0:
            # ``stdout`` is sys.stdout as bound when this module was
            # imported.  Non-zero ranks pass None instead, presumably so
            # that only rank 0 echoes output to the terminal.
            txt = None
        self.fd = self.iocontext.openfile(txt, self.comm, mode)

    def set_timer(self, timer):
        """Use ``timer``, or a new ``Timer()`` if ``timer`` is falsy."""
        self.timer = timer or Timer()

    def close(self):
        """Close ``self.iocontext``, through which the output was opened."""
        self.iocontext.close()

    def __del__(self):
        # Close the IOContext when the context object is finalized.
        self.close()

    def with_txt(self, txt, mode='w'):
        """Return a new context writing to ``txt``.

        The new context shares this context's communicator and timer.
        """
        return ResponseContext(txt=txt, comm=self.comm, timer=self.timer,
                               mode=mode)

    def print(self, *args, flush=True, **kwargs):
        """Print to the output file (``self.fd``), flushing by default."""
        print(*args, file=self.fd, flush=flush, **kwargs)

    def new_txt_and_timer(self, txt, timer=None):
        """Finish the current output file and continue in a new one.

        The timing report is written to the current file before it is
        closed; ``txt`` is then opened in write mode and ``timer`` (or a
        new ``Timer()`` if falsy) replaces the current timer.
        """
        self.write_timer()
        # Close old output file and create a new one.
        self.close()
        self.open(txt, mode='w')
        self.set_timer(timer)

    def write_timer(self):
        """Write the timer report followed by the current local time."""
        self.timer.write(self.fd)
        self.print(ctime())
class timer:
    """Decorator for timing a method call.

    The decorated method's instance must provide ``self.context.timer``
    with ``start(name)`` and ``stop()`` methods; each call is timed under
    the name given to the decorator.  For generator methods, only the work
    done inside the generator while producing each item is timed, not the
    time the consumer spends between items.

    NB: Includes copy-paste from ase, which is suboptimal...

    Example::

        from gpaw.response.context import timer

        class A:
            def __init__(self, context):
                self.context = context

            @timer('Add two numbers')
            def add(self, x, y):
                return x + y

    """

    def __init__(self, name):
        self.name = name

    def __call__(self, method):
        """Wrap ``method`` so each call is bracketed by timer start/stop."""
        if isgeneratorfunction(method):
            @wraps(method)
            def new_method(slf, *args, **kwargs):
                gen = method(slf, *args, **kwargs)
                while True:
                    # Time only the step that produces the next item; the
                    # timer is stopped before the item is handed out.
                    slf.context.timer.start(self.name)
                    try:
                        x = next(gen)
                    except StopIteration:
                        break
                    finally:
                        slf.context.timer.stop()
                    yield x
        else:
            @wraps(method)
            def new_method(slf, *args, **kwargs):
                slf.context.timer.start(self.name)
                try:
                    return method(slf, *args, **kwargs)
                finally:
                    # Stop even if ``method`` raised, matching the generator
                    # branch, so a failed call cannot leave the timer
                    # running and skew every later measurement.
                    try:
                        slf.context.timer.stop()
                    except IndexError:
                        # ``context.timer`` is looked up again here; if the
                        # method replaced it (e.g. via set_timer), the new
                        # timer may have nothing running -- tolerated as
                        # before.
                        pass
        return new_method