Coverage for gpaw/io/tar.py: 24%

139 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-14 00:18 +0000

1import numbers 

2import tarfile 

3import xml.sax 

4 

5import numpy as np 

6 

7from gpaw.mpi import broadcast as mpi_broadcast 

8from gpaw.mpi import world 

9 

# Per-element sizes in bytes.  ``intsize`` is fixed at 4, which matches the
# ``np.int32`` dtype that Reader.get_data_type() uses for 'int' arrays; the
# float and complex sizes are taken from NumPy's native dtypes.
intsize = 4
floatsize = np.dtype(float).itemsize
complexsize = np.dtype(complex).itemsize

# Lookup from the type names 'int', 'float' and 'complex' to the sizes above.
itemsizes = dict(int=intsize, float=floatsize, complex=complexsize)

14 

15 

class FileReference:
    """Common base class for objects that hold a reference to a file.

    The actual I/O classes implementing the referencing should inherit
    from this class and provide ``__init__``, ``__len__`` and
    ``__getitem__``.  Iteration and conversion to a NumPy array are
    implemented here in terms of those methods.
    """

    def __init__(self):
        raise NotImplementedError('Should be implemented in derived classes')

    def __len__(self):
        raise NotImplementedError('Should be implemented in derived classes')

    def __iter__(self):
        """Yield ``self[0]``, ``self[1]``, ..., ``self[len(self) - 1]``."""
        for i in range(len(self)):
            yield self[i]

    def __getitem__(self, indices):
        # The index parameter must be part of the signature: without it
        # ``self[i]`` fails with a TypeError before this informative
        # NotImplementedError can be raised.
        raise NotImplementedError('Should be implemented in derived classes')

    def __array__(self, dtype=None, copy=None):
        """Return the full contents (``self[::]``) for NumPy conversion.

        Parameters
        ----------
        dtype : data-type, optional
            If given, the result is converted to this dtype.  NumPy
            passes it when e.g. ``np.asarray(ref, dtype=...)`` is used.
        copy : bool, optional
            Accepted so that the signature matches the ``__array__``
            protocol of recent NumPy versions; it is not otherwise used.
        """
        array = self[::]
        if dtype is not None:
            array = np.asarray(array, dtype=dtype)
        return array

36 

37 

def open(filename, mode='r', comm=world):
    """Return a Reader for the given ``.gpw`` file.

    ``'.gpw'`` is appended to ``filename`` when it does not already end
    with that suffix.  Only ``mode='r'`` is accepted (checked with an
    assertion); ``comm`` is handed on to the Reader unchanged.
    """
    path = filename if filename.endswith('.gpw') else filename + '.gpw'

    assert mode == 'r'
    return Reader(path, comm)

44 

45 

class Reader(xml.sax.handler.ContentHandler):
    """Reader for a tar archive containing ``info.xml`` plus raw arrays.

    ``info.xml`` is parsed on construction.  Its ``parameter`` elements
    fill ``self.parameters``, ``array``/``dimension`` elements fill
    ``self.dtypes``, ``self.shapes`` and ``self.dims``, and the
    ``gpaw_io`` element decides whether array data must be byte-swapped.
    Array data is read from the archive member named after the array.

    The reader can be used as a context manager; the archive is closed
    on exit.
    """

    def __init__(self, name, comm=world):
        """Open the archive ``name`` and parse its ``info.xml`` member.

        ``comm`` is the communicator used when broadcasting data read on
        rank 0 (see ``get``).
        """
        self.comm = comm  # used for broadcasting replicated data
        self.master = (self.comm.rank == 0)
        self.dims = {}
        self.shapes = {}
        self.dtypes = {}
        self.parameters = {}
        xml.sax.handler.ContentHandler.__init__(self)
        self.tar = tarfile.open(name, 'r')
        try:
            with self.tar.extractfile('info.xml') as info:
                xml.sax.parse(info, self)
        except Exception:
            # Do not leak the open archive when the metadata is missing
            # or cannot be parsed.
            self.tar.close()
            raise

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    def startElement(self, tag, attrs):
        """SAX callback: record the metadata carried by one XML element."""
        if tag == 'gpaw_io':
            # Swap bytes only when the file's endianness differs from
            # that of the machine we are running on.
            self.byteswap = ((attrs['endianness'] == 'little') !=
                             np.little_endian)
        elif tag == 'array':
            name = attrs['name']
            self.dtypes[name] = attrs['type']
            self.shapes[name] = []
            # Remember the array so that the 'dimension' elements that
            # follow can be appended to its shape.
            self.name = name
        elif tag == 'dimension':
            n = int(attrs['length'])
            self.shapes[self.name].append(n)
            self.dims[attrs['name']] = n
        else:
            assert tag == 'parameter'
            # NOTE(review): eval() executes whatever expression the file
            # contains.  Only open files from trusted sources; consider
            # ast.literal_eval if all stored values are plain literals.
            try:
                value = eval(attrs['value'], {})
            except (SyntaxError, NameError):
                # Not a Python expression: keep the raw text.
                value = str(attrs['value'])
            self.parameters[attrs['name']] = value

    def dimension(self, name):
        """Return the length of the named dimension."""
        return self.dims[name]

    def __getitem__(self, name):
        """Return the value of the named parameter."""
        return self.parameters[name]

    def has_array(self, name):
        """Return True if an array called ``name`` was declared."""
        return name in self.shapes

    def get(self, name, *indices, **kwargs):
        """Read (part of) the array ``name``.

        ``indices`` select along the leading axes; the remaining axes
        are returned in full.  When every axis is indexed, a Python
        scalar is returned instead of an array.

        With the keyword ``broadcast=True`` only rank 0 reads the data
        and the result is broadcast over ``self.comm``; otherwise every
        calling rank reads for itself.
        """
        broadcast = kwargs.pop('broadcast', False)
        if self.master or not broadcast:
            fileobj, shape, size, dtype = self.get_file_object(name, indices)
            # np.fromstring is deprecated for binary input.  frombuffer
            # yields a read-only view of the bytes, so always take a
            # copy (byteswap() already returns a new array).
            raw = np.frombuffer(fileobj.read(size), dtype)
            array = raw.byteswap() if self.byteswap else raw.copy()
            if dtype == np.int32:
                # Stored as 32-bit integers; hand out the default int.
                array = np.asarray(array, int)
            array = array.reshape(shape)
            # ``shape`` is a list, so test its length rather than
            # comparing with the tuple ``()``, which is never equal.
            if len(shape) == 0:
                array = array.item()
        else:
            array = None

        if broadcast:
            array = mpi_broadcast(array, 0, self.comm)
        return array

    def get_reference(self, name, indices, length=None):
        """Return a TarFileReference to (part of) the array ``name``.

        Nothing is read here; the reference reads from the archive when
        it is indexed.  Integer arrays are not supported (asserted).
        """
        fileobj, shape, size, dtype = self.get_file_object(name, indices)
        assert dtype != np.int32
        return TarFileReference(fileobj, shape, dtype, self.byteswap, length)

    def get_file_object(self, name, indices):
        """Position a file object at the start of the selected sub-array.

        Returns ``(fileobj, shape, size, dtype)`` where ``shape`` is the
        list of remaining (unindexed) dimensions and ``size`` is the
        number of bytes that the sub-array occupies.
        """
        dtype, _type_name, itemsize = self.get_data_type(name)
        fileobj = self.tar.extractfile(name)
        n = len(indices)
        shape = self.shapes[name]
        size = itemsize * int(np.prod(shape[n:], dtype=int))
        # Walk the indexed axes from the last to the first; the stride
        # of an axis is the byte size of everything to its right.
        offset = 0
        stride = size
        for index, extent in zip(reversed(indices), reversed(shape[:n])):
            offset += index * stride
            stride *= extent
        fileobj.seek(offset)
        return fileobj, shape[n:], size, dtype

    def get_data_type(self, name):
        """Return ``(dtype, type_name, itemsize)`` for the array ``name``.

        ``type_name`` is the string from ``info.xml`` ('int', 'float' or
        'complex'); 'int' maps to ``np.int32``.
        """
        type_name = self.dtypes[name]
        dtype = np.dtype({'int': np.int32,
                          'float': float,
                          'complex': complex}[type_name])
        return dtype, type_name, dtype.itemsize

    def get_parameters(self):
        """Return the dictionary of all parameters."""
        return self.parameters

    def close(self):
        """Close the underlying tar archive."""
        self.tar.close()

139 

140 

class TarFileReference(FileReference):
    """Lazy reference to array data stored in an open file object.

    Data is read from ``fileobj`` only when the reference is indexed,
    either with ``[:]`` (everything) or with a single integer (one
    entry along the first axis).
    """

    def __init__(self, fileobj, shape, dtype, byteswap, length):
        """Remember where the data starts and how to interpret it.

        Parameters
        ----------
        fileobj : file object
            Must already be positioned at the first byte of the data;
            that position is recorded as the base offset.
        shape : sequence of int
            Shape of the referenced array.
        dtype : numpy.dtype
            Element type of the stored data.
        byteswap : bool
            Whether the bytes of each element must be swapped on read.
        length : int or None
            If truthy, results are truncated to the first ``length``
            entries along the last axis.
        """
        self.fileobj = fileobj
        self.shape = tuple(shape)
        self.dtype = dtype
        self.itemsize = dtype.itemsize
        self.byteswap = byteswap
        self.offset = fileobj.tell()
        self.length = length

    def __len__(self):
        return self.shape[0]

    def __getitem__(self, indices):
        """Read and return the whole array (``[:]``) or one row (``[int]``).

        Raises NotImplementedError for any other kind of index and
        IndexError for an integer outside ``range(-len(self), len(self))``.
        """
        if isinstance(indices, slice):
            start, stop, step = indices.indices(len(self))
            if start != 0 or step != 1 or stop != len(self):
                raise NotImplementedError('You can only slice a TarReference '
                                          'with [:] or [int]')
            shape = self.shape
            row = None
        elif isinstance(indices, numbers.Integral):
            nrows = len(self)
            row = int(indices)
            # Support negative indices and refuse to seek outside the
            # referenced data, which would silently return other bytes.
            if row < 0:
                row += nrows
            if not 0 <= row < nrows:
                raise IndexError('TarFileReference index out of range')
            shape = self.shape[1:]
        else:  # Probably tuple or ellipsis
            raise NotImplementedError('You can only slice a TarReference '
                                      'with [:] or [int]')

        size = int(np.prod(shape, dtype=int)) * self.itemsize
        offset = self.offset
        if row is not None:
            offset += row * size
        self.fileobj.seek(offset)

        # np.fromstring is deprecated for binary input; frombuffer gives
        # a read-only view of the bytes that were just read.
        array = np.frombuffer(self.fileobj.read(size), self.dtype)
        array = array.reshape(shape)
        if self.length:
            array = array[..., :self.length]
        # byteswap() returns a new array; otherwise copy explicitly so
        # the caller gets a writable array that owns its data.
        return array.byteswap() if self.byteswap else array.copy()