Coverage for gpaw/new/__init__.py: 87%

68 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-20 00:19 +0000

1"""New ground-state DFT code.""" 

2from __future__ import annotations 

3from collections import defaultdict 

4from contextlib import contextmanager 

5from time import time 

6from typing import Iterable, TYPE_CHECKING 

7if TYPE_CHECKING: 

8 from gpaw.core import UGArray 

9 

10 

11from gpaw.new.timer import trace, tracectx # noqa 

12 

13 

def prod(iterable: Iterable[int]) -> int:
    """Return the product of all integers in *iterable*.

    An empty iterable yields the multiplicative identity, 1.

    >>> prod([])
    1
    >>> prod([2, 3])
    6
    """
    total = 1
    for factor in iterable:
        total *= factor
    return total

26 

27 

def zips(*iterables, strict=True):
    """Iterate over *iterables* in lockstep, like the builtin ``zip()``.

    Backport of the ``strict`` keyword from PEP 618 (available on the
    builtin from Python 3.10).  Unlike the builtin, ``strict`` defaults to
    ``True`` here.

    With ``strict=True`` a ``ValueError`` is raised, after the last
    complete tuple has been yielded, if the iterables turn out not to have
    the same length.  With ``strict=False`` iteration silently stops at the
    shortest iterable.  Calling it with no iterables yields nothing.
    """
    if not iterables:
        return
    sources = tuple(iter(it) for it in iterables)
    try:
        while True:
            row = []
            for source in sources:
                row.append(next(source))
            yield tuple(row)
    except StopIteration:
        # How many iterables still produced a value in the final,
        # incomplete round.  Zero means the *first* iterable ran out.
        taken = len(row)
    if not strict:
        return
    if taken:
        # Iterable number ``taken + 1`` ran out before the ones ahead of it.
        plural = " " if taken == 1 else "s 1-"
        raise ValueError(
            f"zips() argument {taken + 1} is shorter than argument{plural}{taken}")
    # The first iterable is exhausted; every other one must be exhausted too.
    missing = object()
    for k in range(1, len(sources)):
        if next(sources[k], missing) is not missing:
            plural = " " if k == 1 else "s 1-"
            raise ValueError(
                f"zips() argument {k + 1} is longer than argument{plural}{k}")

54 

55 

def spinsum(a_sX: UGArray, mean: bool = False) -> UGArray:
    """Combine the leading-axis components of *a_sX* into a single array.

    If the leading dimension is exactly 2, a new array (allocated from
    ``a_sX.desc``) is returned, holding the sum of the two components, or
    their average when *mean* is true.  For any other leading dimension
    the first component ``a_sX[0]`` is returned as-is (not a copy made
    here, and *mean* is ignored).

    NOTE(review): the leading axis is presumably the spin index (``s``
    in the name); confirm against callers.
    """
    if a_sX.dims[0] != 2:
        return a_sX[0]
    out_X = a_sX.desc.empty(xp=a_sX.xp)
    # Reduce straight into the new buffer to avoid a temporary.
    a_sX.data[:2].sum(axis=0, out=out_X.data)
    if mean:
        out_X.data *= 0.5
    return out_X

64 

65 

class Timer:
    """Accumulate wall-clock time spent in named sections.

    Usage::

        timer = Timer()
        with timer('some section'):
            ...
        timer.write(print)

    ``self.times`` maps section name to accumulated seconds.  The special
    ``'Total'`` entry holds the time elapsed since the timer was created,
    as of the most recent call to :meth:`write` (0.0 before that).
    """

    def __init__(self):
        self.times = defaultdict(float)
        # Creation time; 'Total' is measured relative to this.  Kept
        # separately so write() can be called repeatedly.
        self._start = time()
        # Insert 'Total' first so it is listed first by write()
        # (dicts preserve insertion order).
        self.times['Total'] = 0.0

    @contextmanager
    def __call__(self, name):
        """Time the enclosed ``with`` block and add it to ``times[name]``.

        The elapsed time is recorded even if the block raises.
        """
        t1 = time()
        try:
            yield
        finally:
            t2 = time()
            self.times[name] += t2 - t1

    def write(self, log):
        """Report all timings with a proportional bar for each section.

        Parameters
        ----------
        log:
            Callable invoked once per output line with a single string.

        Refreshes ``times['Total']`` to the elapsed time since creation.
        Safe to call more than once.
        """
        total = time() - self._start
        self.times['Total'] = total
        log('\ntiming: # [seconds]')
        # Column width for "name:" plus at least one space of padding.
        n = max(len(name) for name in self.times) + 2
        # Width of the time column, set by the largest value (the total).
        w = len(f'{total:.3f}')
        # Characters left over for the bar.
        N = 71 - n - w
        for name, t in self.times.items():
            # Bar length in half-character units.  A full '━' is 2 units
            # and a trailing '╸' is 1.  Guard against a zero total (the
            # clock may not have advanced), which would divide by zero.
            m = int(round(2 * N * t / total)) if total > 0 else 0
            bar = '━' * (m // 2) + '╸' * (m % 2)
            log(f' {name + ":":{n}}{t:{w}.3f} # {bar}')