Coverage for gpaw/lcaotddft/wfwriter.py: 95%
91 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-19 00:19 +0000
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-19 00:19 +0000
1import ase.io.ulm as ulm
3from gpaw.io import Writer
5from gpaw.lcaotddft.observer import TDDFTObserver
class WaveFunctionReader:
    """Reader for files written by ``WaveFunctionWriter``.

    An instance is either the *main* reader, which opens ``filename``
    itself and owns the underlying ``ulm.Reader``, or an *item* reader
    created through ``reader[index]``, which wraps one item of the main
    reader and shares its ``version``, ``split`` and ``filename``.

    Attributes that are not set on the instance are looked up on the
    underlying ulm reader (see ``__getattr__``).

    Parameters
    ----------
    filename:
        Path of the main file.  Ignored when ``index`` is given.
    index:
        Item index inside ``wfreader``; ``None`` creates a main reader.
    wfreader:
        The main ``WaveFunctionReader`` the item is taken from.  Required
        when ``index`` is given.

    Raises
    ------
    RuntimeError
        If the main file does not carry the ``WaveFunctionWriter.ulmtag``
        tag.
    """

    def __init__(self, filename, index=None, wfreader=None):
        if index is None:
            reader = ulm.Reader(filename)
            try:
                tag = reader.get_tag()
                if tag != WaveFunctionWriter.ulmtag:
                    raise RuntimeError('Unknown tag %s' % tag)
            except Exception:
                # Do not leave the file open when it is rejected.
                reader.close()
                raise
            self.reader = reader
            self.filename = filename
            self._is_main_reader = True
        else:
            self.index = index
            self.reader = wfreader.reader[index]
            self.version = wfreader.version
            self.split = wfreader.split
            self.filename = wfreader.filename
            self._is_main_reader = False

    def __getattr__(self, attr):
        """Resolve attributes that are not set on the instance.

        Lookup order:

        1. the underlying ulm reader;
        2. ``wave_functions`` from the separate split file, when the
           data was written in split mode;
        3. compatibility defaults (``split``) and the derived
           ``split_filename``.

        Raises ``AttributeError`` if none of the above applies.
        """
        # __getattr__ only runs after normal lookup has failed, so the
        # underlying reader is fetched from __dict__ directly: on a
        # partially constructed instance ``self.reader`` would re-enter
        # this method and recurse without bound.
        reader = self.__dict__.get('reader')
        if reader is None:
            raise AttributeError(attr)

        try:
            return getattr(reader, attr)
        except AttributeError:
            pass
        except KeyError:  # backwards compatibility with ase-3.19.0
            pass

        # Split reader handling
        if attr == 'wave_functions' and self.split:
            if 'splitreader' not in self.__dict__:
                splitreader = ulm.Reader(self.split_filename)
                try:
                    tag = splitreader.get_tag()
                    if tag != WaveFunctionWriter.ulmtag_split:
                        raise RuntimeError('Unknown tag %s' % tag)
                except Exception:
                    splitreader.close()
                    raise
                # Cache the reader only once its tag has been verified.
                self.splitreader = splitreader
            return getattr(self.splitreader, attr)

        # Compatibility for older versions
        if attr == 'split':
            return False

        if attr == 'split_filename':
            name, ext = tuple(self.filename.rsplit('.', 1))
            if self.version < 3:
                fname = '%s-%06d-%s.%s' % (name, self.niter, self.action, ext)
            else:
                fname = '%s-%06d.%s' % (name, self.index, ext)
            return fname

        if attr == 'version':
            # The generic message below reads self.version, which would
            # recurse when 'version' itself is the missing attribute.
            raise AttributeError("Attribute 'version' not defined")

        raise AttributeError('Attribute %s not defined in version %s' %
                             (repr(attr), repr(self.version)))

    def __len__(self):
        """Return the length of the underlying reader."""
        return len(self.reader)

    def __getitem__(self, index):
        """Return an item reader for entry ``index`` of this reader."""
        return WaveFunctionReader(None, index, self)

    def close(self):
        """Close the split reader, if open, and the main file if owned.

        Safe to call on a partially constructed instance.
        """
        splitreader = self.__dict__.pop('splitreader', None)
        if splitreader is not None:
            splitreader.close()
        # Item readers share the main reader's file and must not close it.
        if self.__dict__.get('_is_main_reader', False):
            self.reader.close()

    def __del__(self):
        self.close()
class WaveFunctionWriter(TDDFTObserver):
    """Observer that writes wave functions to file during propagation.

    Each update writes ``niter``, ``time`` and ``action`` (plus
    ``kick_strength`` for kicks) to the main file.  The wave functions
    and occupations go either into the same file or, when ``split`` is
    true, into a separate file per update named
    ``<name>-<index, 6 digits>.<ext>``.

    Parameters
    ----------
    paw:
        Calculator object; ``paw.niter``, ``paw.world``, ``paw.time``,
        ``paw.action`` and ``paw.wfs`` are used.
    filename:
        Path of the main file.  Must contain a ``.`` when split mode is
        active, since the split names are derived from name and extension.
    split:
        Write the actual data to separate files.  When continuing an
        earlier file (``paw.niter != 0``) the value stored in that file
        is used instead.
    interval:
        Passed on to ``TDDFTObserver``.

    Raises
    ------
    RuntimeError
        If an earlier file being continued was written with a different
        ``version``.
    """

    version = 3
    ulmtag = 'WFW'
    ulmtag_split = ulmtag + 'split'

    def __init__(self, paw, filename, split=False, interval=1):
        TDDFTObserver.__init__(self, paw, interval)
        self.split = split
        if paw.niter == 0:
            self.writer = Writer(filename, paw.world, mode='w',
                                 tag=self.__class__.ulmtag)
            self.writer.write(version=self.__class__.version)
            self.writer.write(split=self.split)
            self.writer.sync()
            self.index = 1
        else:
            # Check the earlier file
            reader = WaveFunctionReader(filename)
            try:
                if reader.version != self.__class__.version:
                    # An explicit raise keeps the check active under -O;
                    # appending to an incompatible file must not happen.
                    raise RuntimeError(
                        'Cannot append to version %s file with version %s '
                        'writer' % (repr(reader.version),
                                    repr(self.__class__.version)))
                self.split = reader.split  # Use the earlier split value
                self.index = len(reader)
            finally:
                # Release the file before reopening it for appending.
                reader.close()

            # Append to earlier file
            self.writer = Writer(filename, paw.world, mode='a',
                                 tag=self.__class__.ulmtag)

        if self.split:
            name, ext = tuple(filename.rsplit('.', 1))
            self.split_filename_fmt = name + '-%06d.' + ext

    def _update(self, paw):
        """Write the metadata and wave functions of the current step."""
        # Write metadata to main writer
        self.writer.write(niter=paw.niter, time=paw.time, action=paw.action)
        if paw.action == 'kick':
            self.writer.write(kick_strength=paw.kick_strength)

        if self.split:
            # Use separate writer for actual data
            filename = self.split_filename_fmt % self.index
            writer = Writer(filename, paw.world, mode='w',
                            tag=self.__class__.ulmtag_split)
        else:
            # Use the same writer for actual data
            writer = self.writer
        try:
            w = writer.child('wave_functions')
            paw.wfs.write_wave_functions(w)
            paw.wfs.write_occupations(w)
        finally:
            # The per-step writer is closed even if writing fails; the
            # main writer stays open for later steps.
            if self.split:
                writer.close()
        # Sync the main writer
        self.writer.sync()
        self.index += 1

    def __del__(self):
        # __init__ may have failed before the writer was created.
        writer = getattr(self, 'writer', None)
        if writer is not None:
            writer.close()