Coverage for gpaw/new/__init__.py: 87%
68 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-20 00:19 +0000
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-20 00:19 +0000
1"""New ground-state DFT code."""
2from __future__ import annotations
3from collections import defaultdict
4from contextlib import contextmanager
5from time import time
6from typing import Iterable, TYPE_CHECKING
7if TYPE_CHECKING:
8 from gpaw.core import UGArray
11from gpaw.new.timer import trace, tracectx # noqa
def prod(iterable: Iterable[int]) -> int:
    """Return the product of all integers in *iterable*.

    The empty product is 1.

    >>> prod([])
    1
    >>> prod([2, 3])
    6
    """
    product = 1
    for factor in iterable:
        product *= factor
    return product
def zips(*iterables, strict=True):
    """Iterate over *iterables* in lockstep, like ``zip()``.

    Adapted from the reference implementation in PEP 618.  Unlike the
    builtin, ``strict`` defaults to True: once the first iterable runs out,
    a ``ValueError`` is raised if any other iterable turns out to be
    shorter or longer.  With ``strict=False`` iteration simply stops at the
    shortest iterable, exactly like ``zip()``.
    """
    if not iterables:
        return
    iterators = tuple(iter(source) for source in iterables)
    try:
        while True:
            row = []
            for it in iterators:
                row.append(next(it))
            yield tuple(row)
    except StopIteration:
        pass

    if not strict:
        return

    # ``row`` holds the values pulled before some iterator ran dry.  If it is
    # non-empty, argument number len(row) + 1 is shorter than the ones
    # before it.
    count = len(row)
    if count:
        plural = " " if count == 1 else "s 1-"
        msg = (f"zips() argument {count + 1} is shorter than "
               f"argument{plural}{count}")
        raise ValueError(msg)

    # The first iterator ended at the start of a row, so every other
    # iterator must be exhausted as well; one that still yields is longer.
    missing = object()
    for index, it in enumerate(iterators[1:], 1):
        if next(it, missing) is missing:
            continue
        plural = " " if index == 1 else "s 1-"
        msg = (f"zips() argument {index + 1} is longer than "
               f"argument{plural}{index}")
        raise ValueError(msg)
def spinsum(a_sX: UGArray, mean: bool = False) -> UGArray:
    """Sum (or average) the first two components of *a_sX*.

    When the first entry of ``a_sX.dims`` is 2 (presumably two spin
    channels), a freshly allocated array holding ``a_sX[0] + a_sX[1]`` is
    returned, scaled by 0.5 when *mean* is true.  Otherwise ``a_sX[0]`` is
    returned directly and *mean* is ignored.

    NOTE(review): in the second case the result is presumably a view into
    *a_sX* rather than a copy -- confirm before mutating it in place.
    """
    if a_sX.dims[0] != 2:
        return a_sX[0]
    # Allocate the result on the same backend (numpy/cupy) as the input.
    a_X = a_sX.desc.empty(xp=a_sX.xp)
    a_sX.data[:2].sum(axis=0, out=a_X.data)
    if mean:
        a_X.data *= 0.5
    return a_X
class Timer:
    """Accumulate wall-clock time spent in named code sections.

    Usage::

        timer = Timer()
        with timer('name'):
            ...
        timer.write(print)

    ``times`` maps each section name to its accumulated time in seconds.
    The ``'Total'`` entry covers the time since the timer was created; it is
    filled in by :meth:`write` (before that it holds the negated start time).
    Times come from ``time.time()``, which is not monotonic, so a system
    clock adjustment during a run can distort the numbers.
    """

    def __init__(self):
        self._start = time()
        self.times = defaultdict(float)
        # Kept so code that reads 'Total' before write() sees the same value
        # as before.
        self.times['Total'] = -self._start

    @contextmanager
    def __call__(self, name):
        """Add the time spent inside the ``with`` block to ``times[name]``.

        The time is recorded even if the block raises.
        """
        t1 = time()
        try:
            yield
        finally:
            self.times[name] += time() - t1

    def write(self, log):
        """Report the accumulated times, one line per section.

        log: callable that receives each output line as a single string.

        Each line shows the section name, its time in seconds and a bar whose
        length is proportional to the section's share of the total time.
        Calling this more than once is safe; 'Total' is refreshed each time.
        """
        # Recompute from the start time.  The previous in-place
        # ``+= time()`` produced a bogus total on a second call.
        total = self.times['Total'] = time() - self._start
        log('\ntiming: # [seconds]')
        n = max(len(name) for name in self.times) + 2  # name column width
        w = len(f'{total:.3f}')  # number column width
        # Bar budget in full glyphs, chosen so lines are at most 75
        # characters.  Clamped at zero so very long names cannot yield a
        # negative length (which would print a stray half glyph).
        N = max(71 - n - w, 0)
        for name, t in self.times.items():
            # Bar length in half-glyph units.  A zero total (possible with a
            # coarse clock) would otherwise raise ZeroDivisionError.
            m = max(round(2 * N * t / total), 0) if total > 0 else 0
            bar = '━' * (m // 2) + '╸' * (m % 2)
            log(f' {name + ":":{n}}{t:{w}.3f} # {bar}')