Coverage for gpaw/lcaotddft/wfwriter.py: 95%

91 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-19 00:19 +0000

1import ase.io.ulm as ulm 

2 

3from gpaw.io import Writer 

4 

5from gpaw.lcaotddft.observer import TDDFTObserver 

6 

7 

class WaveFunctionReader:
    """Reader for files written by ``WaveFunctionWriter``.

    The top-level reader (``index is None``) opens *filename* and checks
    that its ULM tag is ``WaveFunctionWriter.ulmtag``.  ``reader[i]``
    returns a sub-reader bound to item ``i`` of the same file; it copies
    ``version``, ``split`` and ``filename`` from its parent.  Attributes
    not found on the instance are resolved by ``__getattr__``.

    Parameters
    ----------
    filename : str or None
        Path of the main file.  Ignored for sub-readers, which take
        ``wfreader.filename``.
    index : int or None
        ``None`` for the top-level reader, otherwise the item index.
    wfreader : WaveFunctionReader or None
        Parent reader; required when *index* is not ``None``.

    Raises
    ------
    RuntimeError
        If the main file carries an unexpected ULM tag.
    """

    def __init__(self, filename, index=None, wfreader=None):
        if index is None:
            reader = ulm.Reader(filename)
            tag = reader.get_tag()
            if tag != WaveFunctionWriter.ulmtag:
                # Do not leave the rejected file open.
                reader.close()
                raise RuntimeError('Unknown tag %s' % tag)
            self.reader = reader
            self.filename = filename
            self._is_main_reader = True
        else:
            self.index = index
            self.reader = wfreader.reader[index]
            self.version = wfreader.version
            self.split = wfreader.split
            self.filename = wfreader.filename
            # Sub-readers share the parent's file; only the parent closes it.
            self._is_main_reader = False

    def __getattr__(self, attr):
        """Resolve attributes that normal lookup did not find.

        The steps are tried in this order:

        1. Forward to the wrapped ULM reader.
        2. For split files, read ``wave_functions`` from the separate split
           file. That file is opened lazily and cached as ``splitreader``.
        3. Fall back to defaults for attributes that older file versions
           lack (``split``, ``split_filename``).

        Raises
        ------
        AttributeError
            If none of the steps provides *attr*.
        """
        if attr in ('reader', 'splitreader', '_is_main_reader'):
            # These are plain instance attributes managed by this class.
            # Forwarding them to self.reader would recurse forever when
            # __init__ failed before self.reader was assigned
            # (e.g. __del__ -> close() -> hasattr(self, 'splitreader')).
            raise AttributeError(attr)

        try:
            return getattr(self.reader, attr)
        except AttributeError:
            pass
        except KeyError:  # backwards compatibility with ase-3.19.0
            pass

        # Split reader handling
        if attr == 'wave_functions' and self.split:
            splitreader = self.__dict__.get('splitreader')
            if splitreader is None:
                splitreader = ulm.Reader(self.split_filename)
                # Cache before checking so that close() releases it even
                # if the tag check below fails.
                self.splitreader = splitreader
                tag = splitreader.get_tag()
                assert tag == WaveFunctionWriter.ulmtag_split
            return getattr(splitreader, attr)

        # Compatibility for older versions
        if attr == 'split':
            return False

        if attr == 'split_filename':
            name, ext = tuple(self.filename.rsplit('.', 1))
            if self.version < 3:
                fname = '%s-%06d-%s.%s' % (name, self.niter, self.action, ext)
            else:
                fname = '%s-%06d.%s' % (name, self.index, ext)
            return fname

        if attr == 'version':
            # Formatting the message below reads self.version, which would
            # re-enter this method and recurse.
            raise AttributeError('Attribute %s not defined' % repr(attr))

        raise AttributeError('Attribute %s not defined in version %s' %
                             (repr(attr), repr(self.version)))

    def __len__(self):
        """Return the number of items in the wrapped ULM reader."""
        return len(self.reader)

    def __getitem__(self, index):
        """Return a sub-reader for item *index* of this file."""
        return WaveFunctionReader(None, index, self)

    def close(self):
        """Close the split-file reader if one was opened.

        The main file is closed only by the top-level reader.  This also
        works on an instance whose ``__init__`` did not complete.
        """
        splitreader = self.__dict__.pop('splitreader', None)
        if splitreader is not None:
            splitreader.close()
        if self.__dict__.get('_is_main_reader', False):
            self.reader.close()

    def __del__(self):
        self.close()

71 

72 

class WaveFunctionWriter(TDDFTObserver):
    """TDDFT observer that stores wave functions and occupations.

    At construction, a new file is created when ``paw.niter == 0``.
    Otherwise the existing *filename* is checked with
    ``WaveFunctionReader`` and then appended to.  Each ``_update`` writes
    the following to the main file:

    * ``niter``, ``time`` and ``action``;
    * ``kick_strength`` when the action is ``'kick'``;
    * a ``wave_functions`` child with the wave functions and occupations.

    With ``split=True`` the ``wave_functions`` child goes to a separate
    file named ``<name>-<index:06d>.<ext>`` instead of the main file.

    Parameters
    ----------
    paw : calculator
        Object providing ``niter``, ``world``, ``time``, ``action``,
        ``kick_strength`` and ``wfs`` as used below.
    filename : str
        Main output file.  It must contain a ``'.'`` when ``split`` is used,
        because the split names are derived from it.
    split : bool
        Write wave functions to separate files.  When appending, the value
        stored in the existing file overrides this argument.
    interval : int
        Passed to ``TDDFTObserver``; presumably the update interval in
        propagation steps (see TDDFTObserver).
    """
    version = 3
    ulmtag = 'WFW'
    ulmtag_split = ulmtag + 'split'

    def __init__(self, paw, filename, split=False, interval=1):
        TDDFTObserver.__init__(self, paw, interval)
        self.split = split
        if paw.niter == 0:
            self.writer = Writer(filename, paw.world, mode='w',
                                 tag=self.__class__.ulmtag)
            self.writer.write(version=self.__class__.version)
            self.writer.write(split=self.split)
            self.writer.sync()
            # The header (version, split) is the first synced item, so
            # snapshots start at index 1.
            self.index = 1
        else:
            # Check the earlier file; close it even if the check fails.
            reader = WaveFunctionReader(filename)
            try:
                assert reader.version == self.__class__.version
                self.split = reader.split  # Use the earlier split value
                self.index = len(reader)
            finally:
                reader.close()

            # Append to earlier file
            self.writer = Writer(filename, paw.world, mode='a',
                                 tag=self.__class__.ulmtag)

        if self.split:
            name, ext = tuple(filename.rsplit('.', 1))
            self.split_filename_fmt = name + '-%06d.' + ext

    def _update(self, paw):
        """Append one snapshot of *paw* and advance ``self.index``."""
        # Write metadata to main writer
        self.writer.write(niter=paw.niter, time=paw.time, action=paw.action)
        if paw.action == 'kick':
            self.writer.write(kick_strength=paw.kick_strength)

        if self.split:
            # Use separate writer for actual data
            filename = self.split_filename_fmt % self.index
            writer = Writer(filename, paw.world, mode='w',
                            tag=self.__class__.ulmtag_split)
        else:
            # Use the same writer for actual data
            writer = self.writer
        try:
            w = writer.child('wave_functions')
            paw.wfs.write_wave_functions(w)
            paw.wfs.write_occupations(w)
        finally:
            if self.split:
                # Release the per-snapshot file even if writing failed.
                writer.close()
        # Sync the main writer
        self.writer.sync()
        self.index += 1

    def __del__(self):
        # __init__ may have raised before self.writer was created
        # (e.g. when the earlier file fails the version check).
        writer = getattr(self, 'writer', None)
        if writer is not None:
            writer.close()