Coverage for gpaw/_broadcast_imports.py: 70%

115 statements  

« prev     ^ index     » next       coverage.py v7.7.1, created at 2025-07-19 00:19 +0000

1"""Provide mechanism to broadcast imports from master to other processes. 

2 

3This reduces file system strain. 

4 

5Use: 

6 

7 with broadcast_imports(): 

8 <execute import statements> 

9 

10This temporarily overrides the Python import mechanism so that 

11 

12 1) master executes and caches import metadata and code 

13 2) import metadata and code are broadcast to all processes 

14 3) other processes execute the import statements from memory 

15 

16Warning: Do not perform any parallel operations while broadcast imports 

17are enabled. Non-master processes assume that they will receive module 

18data and will crash or deadlock if master sends anything else. 

19""" 

20 

21 

22import os 

23import sys 

24import marshal 

25from importlib.machinery import PathFinder, ModuleSpec 

26 

27from gpaw import GPAW_NO_C_EXTENSION, GPAW_MPI4PY 

28import gpaw.cgpaw as cgpaw 

29 

# Refuse to continue with a C extension whose version does not match.
# ``cgpaw.version`` is read with a default of 0 for modules lacking the
# attribute; only version 10 is accepted unless GPAW_NO_C_EXTENSION is set.
cgpaw_version = getattr(cgpaw, 'version', 0)
if not GPAW_NO_C_EXTENSION and cgpaw_version != 10:
    improvement = ''
    if cgpaw_version == 9:
        improvement = ('GPAW has now much reduced memory consumption due to '
                       'optimized pwlfc_expand function in new GPAW. Enjoy. ')

    # Double quotes are needed here: 'GPAW''s' is two adjacent literals and
    # would concatenate to "GPAWs", dropping the apostrophe.
    raise ImportError(improvement + "Please recompile GPAW's C-extensions!")

38 

39 

# Select the module-level communicator ``world``.  The first branch that
# applies wins: the mpi4py wrapper, then the communicator from the C
# extension, and finally None when neither is available.
if GPAW_MPI4PY:
    from gpaw.mpi4pywrapper import MPI4PYWrapper
    from mpi4py.MPI import COMM_WORLD
    world = MPI4PYWrapper(COMM_WORLD)
elif hasattr(cgpaw, 'Communicator'):
    if '_gpaw' not in sys.builtin_module_names:
        # _gpaw is not compiled into the interpreter: try to load the MPI
        # library (name taken from $GPAW_MPI, default 'libmpi.so') with
        # RTLD_GLOBAL before creating the communicator.
        # NOTE(review): presumably this makes the MPI symbols visible to
        # the extension module -- confirm.
        import ctypes
        libmpi = os.environ.get('GPAW_MPI', 'libmpi.so')
        try:
            ctypes.CDLL(libmpi, ctypes.RTLD_GLOBAL)
        except OSError:
            # Best effort only; a library that cannot be loaded is ignored.
            pass
    world = cgpaw.Communicator()
else:
    world = None  # type: ignore

55 

56 

def marshal_broadcast(obj):
    """Send ``obj`` from rank 0 to the other processes using ``marshal``.

    Rank 0 serializes ``obj`` with ``marshal.dumps``; every other rank must
    pass ``None``.  The buffer is handed to
    ``cgpaw.globally_broadcast_bytes`` and whatever that call returns is
    deserialized with ``marshal.loads``.

    Returns:
        The unmarshalled object.

    Raises:
        ImportError: if the received bytes are not valid marshal data.  The
            original exception is attached as ``__cause__``.
    """
    if world.rank == 0:
        buf = marshal.dumps(obj)
    else:
        assert obj is None
        buf = None

    buf = cgpaw.globally_broadcast_bytes(buf)
    try:
        return marshal.loads(buf)
    except (EOFError, ValueError, TypeError) as err:
        # marshal.loads reports invalid input with any of these three
        # exceptions (e.g. EOFError for empty or truncated data), so all of
        # them are translated into the same explanatory ImportError.
        msg = ('Parallel import failure -- probably received garbage.  '
               'Error was: {}.  This may happen if parallel operations are '
               'performed while parallel imports are enabled.'.format(err))
        raise ImportError(msg) from err

72 

73 

class BroadcastLoader:
    """Loader that executes modules from a shared in-memory cache.

    ``module_cache`` maps a module name to ``(metadata, code)`` where
    ``metadata`` is ``(submodule_search_locations, origin)`` and ``code`` is
    the module's code object.  ``spec`` is the spec this loader works for;
    on rank 0 its ``loader`` is used to obtain the code object.
    """

    def __init__(self, spec, module_cache):
        self.spec = spec
        self.module_cache = module_cache

    def create_module(self, spec):
        """Return None, which requests the default module creation.

        (A subclass of Module could be returned here instead if that were
        ever wanted.)
        """
        return None

    def exec_module(self, module):
        """Execute ``module``, filling the cache first on rank 0.

        On rank 0 the code object is fetched through ``self.spec.loader``
        and stored in the cache together with the spec's search locations
        and origin.  Every rank, including rank 0, then runs the module
        from the cache rather than through the default mechanism.
        """
        if world.rank == 0:
            name = module.__name__
            code = self.spec.loader.get_code(name)
            metadata = (self.spec.submodule_search_locations,
                        self.spec.origin)
            self.module_cache[name] = (metadata, code)

        return self.load_from_cache(module)

    def load_from_cache(self, module):
        """Run the cached code object for ``module`` in its namespace.

        Sets ``__file__`` to the cached origin and ``__loader__`` to this
        loader, registers the module in ``sys.modules`` and executes the
        code.  Returns ``module``.
        """
        name = module.__name__
        metadata, code = self.module_cache[name]
        module.__file__ = metadata[1]
        # __package__, __path__, __cached__?
        module.__loader__ = self
        sys.modules[name] = module
        exec(code, module.__dict__)
        return module

    def __str__(self):
        template = '<{} for {}:{} [{} modules cached]>'
        return template.format(self.__class__.__name__,
                               self.spec.name,
                               self.spec.origin,
                               len(self.module_cache))

113 

114 

class BroadcastImporter:
    """Meta path finder that shares imported module code between processes.

    While enabled, rank 0 resolves imports through ``PathFinder`` and
    records each module's code object in ``module_cache``.  The cache is
    broadcast when rank 0 calls ``disable()``; the other ranks obtain it
    during ``enable()`` and then import the cached modules from memory.

    Attributes:
        module_cache: dict mapping module name to
            ``((submodule_search_locations, origin), code)``.
        cached_modules: names of all modules that have been in the cache,
            accumulated on every ``disable()``.
    """

    def __init__(self):
        self.module_cache = {}
        self.cached_modules = []

    @staticmethod
    def _build_spec(fullname, loader, origin, searchloc):
        """Create a ModuleSpec; it is a package iff ``searchloc`` is given."""
        spec = ModuleSpec(fullname, loader, origin=origin,
                          is_package=searchloc is not None)
        if searchloc is not None:
            spec.submodule_search_locations += searchloc
        return spec

    def find_spec(self, fullname, path=None, target=None):
        """Return a spec backed by a BroadcastLoader, or defer.

        On rank 0, None is returned (leaving the import to the remaining
        finders) when ``PathFinder`` finds nothing, when the spec has no
        loader, or when the loader yields no code object.  On other ranks,
        modules missing from the cache are resolved by ``PathFinder``.
        """
        if world.rank == 0:
            spec = PathFinder.find_spec(fullname, path, target)
            if spec is None or spec.loader is None:
                return None

            code = spec.loader.get_code(fullname)
            if code is None:  # C extensions
                return None

            # The loader keeps the original spec: it needs the real loader
            # to fetch the code object in exec_module().
            loader = BroadcastLoader(spec, self.module_cache)
            assert fullname == spec.name
            return self._build_spec(fullname, loader, spec.origin,
                                    spec.submodule_search_locations)

        if fullname not in self.module_cache:
            # Could this in principle interfere with builtin imports?
            return PathFinder.find_spec(fullname, path, target)

        searchloc, origin = self.module_cache[fullname][0]
        loader = BroadcastLoader(None, self.module_cache)
        spec = self._build_spec(fullname, loader, origin, searchloc)
        loader.spec = spec  # XXX loader.loader is still None
        return spec

    def broadcast(self):
        """Transfer the module cache from rank 0 to the other ranks.

        Does nothing when there is a single process.  Rank 0 sends its
        cache; every other rank replaces its cache with the received one.
        """
        if world.size == 1:
            return
        if world.rank == 0:
            marshal_broadcast(self.module_cache)
        else:
            self.module_cache = marshal_broadcast(None)

    def enable(self):
        """Install this importer at the front of ``sys.meta_path``.

        Ranks other than 0 immediately call ``broadcast()`` to obtain the
        module cache.  No-op when there is no communicator.
        """
        if world is None:
            return

        # There is the question of whether we lose anything by inserting
        # ourselves further on in the meta_path list.  Maybe not, and maybe
        # that is a less violent act.
        sys.meta_path.insert(0, self)
        if world.rank != 0:
            self.broadcast()

    def disable(self):
        """Uninstall this importer; rank 0 broadcasts its cache first.

        The names currently in the cache are appended to
        ``cached_modules`` and the cache is emptied.  No-op when there is
        no communicator.

        Raises:
            ValueError: if this importer is not on ``sys.meta_path``.
        """
        if world is None:
            return

        if world.rank == 0:
            self.broadcast()
        self.cached_modules.extend(self.module_cache)
        self.module_cache = {}
        # Remove exactly this importer.  Popping index 0 would take out the
        # wrong finder (and leave this one installed) whenever a module
        # imported in the meantime put its own hook at the front.
        sys.meta_path.remove(self)

    def __enter__(self):
        self.enable()

    def __exit__(self, *args):
        self.disable()


# Module-level instance; use as ``with broadcast_imports: ...``.
broadcast_imports = BroadcastImporter()