Coverage for gpaw/io/tar.py: 24%
139 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-14 00:18 +0000
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-14 00:18 +0000
1import numbers
2import tarfile
3import xml.sax
5import numpy as np
7from gpaw.mpi import broadcast as mpi_broadcast
8from gpaw.mpi import world
# Byte widths of the element types named in info.xml.  Integers use a
# fixed 4-byte width; float and complex widths come from NumPy's default
# float and complex dtypes.
intsize = 4
floatsize = np.dtype(float).itemsize
complexsize = np.dtype(complex).itemsize
itemsizes = dict(int=intsize, float=floatsize, complex=complexsize)
class FileReference:
    """Common base class for having reference to a file. The actual I/O
    classes implementing the referencing should be inherited from
    this class.

    Subclasses override ``__init__``, ``__len__`` and ``__getitem__``;
    this base class builds iteration and NumPy conversion on top of them.
    """

    def __init__(self):
        raise NotImplementedError('Should be implemented in derived classes')

    def __len__(self):
        raise NotImplementedError('Should be implemented in derived classes')

    def __iter__(self):
        # One item per leading index, each fetched through __getitem__.
        for i in range(len(self)):
            yield self[i]

    def __getitem__(self, indices):
        # The index parameter is required so that ``self[i]`` reaches this
        # method and raises NotImplementedError, rather than failing with a
        # TypeError about the number of arguments.
        raise NotImplementedError('Should be implemented in derived classes')

    def __array__(self, dtype=None, copy=None):
        """Return the complete referenced data as a NumPy array.

        Supports ``np.asarray(ref)`` and ``np.asarray(ref, dtype)``.
        ``dtype`` optionally converts the result.  ``copy`` is accepted
        for compatibility with the NumPy 2 ``__array__`` protocol and is
        ignored.
        """
        return np.asarray(self[::], dtype=dtype)
def open(filename, mode='r', comm=world):
    """Open a GPAW ``.gpw`` tar file for reading.

    Parameters
    ----------
    filename : str or os.PathLike
        Path of the file; ``'.gpw'`` is appended when it is missing.
    mode : str
        Only ``'r'`` is supported.
    comm : communicator
        Passed on to :class:`Reader` (used there for broadcasting).

    Returns
    -------
    Reader

    Raises
    ------
    ValueError
        If *mode* is not ``'r'``.
    """
    import os  # only needed here, for os.fspath

    # An explicit check instead of ``assert`` so it also runs under -O.
    if mode != 'r':
        raise ValueError("Only mode='r' is supported, got %r" % (mode,))
    filename = os.fspath(filename)  # accept pathlib.Path as well as str
    if not filename.endswith('.gpw'):
        filename += '.gpw'
    return Reader(filename, comm)
class Reader(xml.sax.handler.ContentHandler):
    """Reader for the tar-based ``.gpw`` file format.

    The archive holds an ``info.xml`` member that describes parameters,
    dimensions and arrays, plus one member per array with its raw element
    data.  Parameters are available as ``reader[name]``.  Arrays are read
    eagerly with :meth:`get` or lazily through :meth:`get_reference`.
    Instances can be used as context managers to close the archive.
    """

    def __init__(self, name, comm=world):
        self.comm = comm  # used for broadcasting replicated data
        self.master = (self.comm.rank == 0)
        self.dims = {}
        self.shapes = {}
        self.dtypes = {}
        self.parameters = {}
        # Default for files whose info.xml lacks a <gpaw_io> element.
        self.byteswap = False
        xml.sax.handler.ContentHandler.__init__(self)
        self.tar = tarfile.open(name, 'r')
        try:
            with self.tar.extractfile('info.xml') as f:
                xml.sax.parse(f, self)
        except BaseException:
            # Do not leak the open archive if the header cannot be read.
            self.tar.close()
            raise

    def startElement(self, tag, attrs):
        """SAX callback: record one element of ``info.xml``."""
        if tag == 'gpaw_io':
            # Swap bytes when the file's endianness differs from ours.
            self.byteswap = ((attrs['endianness'] == 'little') !=
                             np.little_endian)
        elif tag == 'array':
            name = attrs['name']
            self.dtypes[name] = attrs['type']
            self.shapes[name] = []
            # Following <dimension> elements belong to this array.
            self.name = name
        elif tag == 'dimension':
            n = int(attrs['length'])
            self.shapes[self.name].append(n)
            self.dims[attrs['name']] = n
        else:
            assert tag == 'parameter'
            # SECURITY NOTE(review): eval() executes Python taken from the
            # file, so only open .gpw files from trusted sources.  Values
            # that do not evaluate are kept as plain strings.
            try:
                value = eval(attrs['value'], {})
            except (SyntaxError, NameError):
                value = str(attrs['value'])
            self.parameters[attrs['name']] = value

    def dimension(self, name):
        """Return the length of the named dimension."""
        return self.dims[name]

    def __getitem__(self, name):
        """Return the value of parameter *name*."""
        return self.parameters[name]

    def has_array(self, name):
        """Return True if info.xml declares an array called *name*."""
        return name in self.shapes

    def get(self, name, *indices, broadcast=False):
        """Read array *name*, or the sub-array selected by leading *indices*.

        Integer arrays (stored as int32) are converted to NumPy's default
        int.  When every dimension is indexed, a Python scalar is returned.
        With ``broadcast=True`` only the master rank reads the file and the
        result is broadcast over ``self.comm``.
        """
        if self.master or not broadcast:
            fileobj, shape, size, dtype = self.get_file_object(name, indices)
            with fileobj:  # close the extracted member once it has been read
                data = fileobj.read(size)
            # frombuffer + copy replaces the deprecated binary mode of
            # np.fromstring and keeps the array writable.
            array = np.frombuffer(data, dtype).copy()
            if self.byteswap:
                array = array.byteswap()
            if dtype == np.int32:
                array = np.asarray(array, int)
            array.shape = shape
            # shape is a list: compare by emptiness (a list never == ()).
            if not shape:
                array = array.item()
        else:
            array = None

        if broadcast:
            array = mpi_broadcast(array, 0, self.comm)
        return array

    def get_reference(self, name, indices, length=None):
        """Return a lazy :class:`TarFileReference` to *name* at *indices*.

        The extracted member is left open because the reference reads from
        it on demand.  A truthy *length* truncates the last axis of every
        item read.  Integer arrays are not supported here.
        """
        fileobj, shape, size, dtype = self.get_file_object(name, indices)
        assert dtype != np.int32
        return TarFileReference(fileobj, shape, dtype, self.byteswap, length)

    def get_file_object(self, name, indices):
        """Open member *name* and seek to the sub-array given by *indices*.

        Returns ``(fileobj, shape, size, dtype)``.  ``shape`` is the list of
        remaining (unindexed) dimensions and ``size`` is the byte count of
        that sub-array.  Negative indices count from the end.

        Raises IndexError for too many or out-of-range indices.
        """
        dtype, _, itemsize = self.get_data_type(name)
        shape = self.shapes[name]
        n = len(indices)
        if n > len(shape):
            raise IndexError('too many indices (%d) for array %r of rank %d'
                             % (n, name, len(shape)))
        # Validate before opening the member so nothing leaks on error.
        checked = []
        for index, length in zip(indices, shape):
            if index < 0:
                index += length
            if not 0 <= index < length:
                raise IndexError('index out of range for array %r with '
                                 'shape %s' % (name, tuple(shape)))
            checked.append(index)
        size = itemsize * np.prod(shape[n:], dtype=int)
        # Row-major (C-order) byte offset of the selected sub-array.
        offset = 0
        stride = size
        for i in range(n - 1, -1, -1):
            offset += checked[i] * stride
            stride *= shape[i]
        fileobj = self.tar.extractfile(name)
        fileobj.seek(offset)
        return fileobj, shape[n:], size, dtype

    def get_data_type(self, name):
        """Return ``(dtype, type_name, itemsize)`` for array *name*."""
        type_name = self.dtypes[name]
        dtype = np.dtype({'int': np.int32,
                          'float': float,
                          'complex': complex}[type_name])
        return dtype, type_name, dtype.itemsize

    def get_parameters(self):
        """Return the dict of all parameters parsed from info.xml."""
        return self.parameters

    def close(self):
        """Close the underlying tar archive."""
        self.tar.close()

    def __enter__(self):
        return self

    def __exit__(self, *exc_info):
        self.close()
class TarFileReference(FileReference):
    """Lazy reference to an array stored in an extracted tar member.

    Data is read on demand: ``ref[i]`` reads the sub-array at leading
    index *i* (negative values count from the end) and ``ref[:]`` reads
    everything.  Any other indexing raises NotImplementedError.
    """

    def __init__(self, fileobj, shape, dtype, byteswap, length):
        self.fileobj = fileobj
        self.shape = tuple(shape)
        self.dtype = dtype
        self.itemsize = dtype.itemsize
        self.byteswap = byteswap
        # Start of the referenced data in fileobj; reads seek relative to it.
        self.offset = fileobj.tell()
        # When truthy, items are truncated to this many entries along the
        # last axis.
        self.length = length

    def __len__(self):
        return self.shape[0]

    def __getitem__(self, indices):
        if isinstance(indices, slice):
            start, stop, step = indices.indices(len(self))
            if start != 0 or step != 1 or stop != len(self):
                raise NotImplementedError('You can only slice a TarReference '
                                          'with [:] or [int]')
            indices = ()
        elif isinstance(indices, numbers.Integral):
            index = int(indices)
            count = len(self)
            # Normalize Python-style.  Unchecked negative values would
            # otherwise silently read data stored before this reference.
            if index < 0:
                index += count
            if not 0 <= index < count:
                raise IndexError('index %d is out of range for length %d'
                                 % (int(indices), count))
            indices = (index,)
        else:  # Probably tuple or ellipsis
            raise NotImplementedError('You can only slice a TarReference '
                                      'with [:] or [int]')

        n = len(indices)
        size = np.prod(self.shape[n:], dtype=int) * self.itemsize
        offset = self.offset
        if indices:
            # C order: consecutive leading indices are `size` bytes apart.
            offset += indices[0] * size
        self.fileobj.seek(offset)
        # frombuffer + copy replaces the deprecated binary mode of
        # np.fromstring and keeps the result writable.
        array = np.frombuffer(self.fileobj.read(size), self.dtype).copy()
        if self.byteswap:
            array = array.byteswap()
        array.shape = self.shape[n:]
        if self.length:
            array = array[..., :self.length].copy()
        return array