Coverage for gpaw/_broadcast_imports.py: 70%
115 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-07-19 00:19 +0000
"""Provide a mechanism to broadcast imports from master to other processes.

This reduces file system strain.

Use:

    with broadcast_imports:
        <execute import statements>

This temporarily overrides the Python import mechanism so that

  1) master executes and caches import metadata and code
  2) import metadata and code are broadcast to all processes
  3) other processes execute the import statements from memory

Warning: Do not perform any parallel operations while broadcast imports
are enabled. Non-master processes assume that they will receive module
data and will crash or deadlock if master sends anything else.
"""
22import os
23import sys
24import marshal
25from importlib.machinery import PathFinder, ModuleSpec
27from gpaw import GPAW_NO_C_EXTENSION, GPAW_MPI4PY
28import gpaw.cgpaw as cgpaw
# The compiled C extension must report the interface version (10) that this
# Python code requires; a missing ``version`` attribute counts as 0.  The
# check is skipped entirely when GPAW runs without its C extension.
cgpaw_version = getattr(cgpaw, 'version', 0)
if not GPAW_NO_C_EXTENSION and cgpaw_version != 10:
    improvement = ''
    if cgpaw_version == 9:
        # Extensions of version 9 get an extra hint about what a rebuild
        # brings.
        improvement = ('GPAW has now much reduced memory consumption due to '
                       'optimized pwlfc_expand function in new GPAW. Enjoy. ')

    # Written with double quotes on purpose: in single quotes, 'GPAW''s'
    # would be two adjacent literals that concatenate to "GPAWs".
    raise ImportError(improvement + "Please recompile GPAW's C-extensions!")
# Pick the communicator used by the broadcast machinery in this module:
#   * GPAW_MPI4PY set                    -> mpi4py's COMM_WORLD, wrapped;
#   * C extension offers Communicator    -> that communicator (after trying
#                                           to preload the MPI library when
#                                           ``_gpaw`` is not built into the
#                                           interpreter);
#   * neither                            -> None; the importer's enable()
#                                           and disable() then do nothing.
if GPAW_MPI4PY:
    from gpaw.mpi4pywrapper import MPI4PYWrapper
    from mpi4py.MPI import COMM_WORLD
    world = MPI4PYWrapper(COMM_WORLD)
elif not hasattr(cgpaw, 'Communicator'):
    world = None  # type: ignore
else:
    if '_gpaw' not in sys.builtin_module_names:
        # The library to preload can be overridden through $GPAW_MPI.
        libmpi = os.environ.get('GPAW_MPI', 'libmpi.so')
        import ctypes
        try:
            # Loaded with RTLD_GLOBAL, presumably so its MPI symbols are
            # visible to the extension module.  Failure is tolerated on
            # purpose (best effort).
            ctypes.CDLL(libmpi, ctypes.RTLD_GLOBAL)
        except OSError:
            pass
    world = cgpaw.Communicator()
def marshal_broadcast(obj):
    """Send a marshal-serializable object from rank 0 to every rank.

    Parameters
    ----------
    obj:
        On rank 0, the object to send; it must be serializable with
        :func:`marshal.dumps`.  On every other rank it must be ``None``.

    Returns
    -------
    The object sent by rank 0, deserialized locally on every rank
    (rank 0 included).

    Raises
    ------
    ImportError
        If the received bytes cannot be deserialized.  This typically means
        another parallel operation interleaved with the broadcast.
    """
    if world.rank == 0:
        buf = marshal.dumps(obj)
    else:
        assert obj is None
        buf = None

    # Assumed (C side, not visible here) to hand rank 0's bytes to all ranks.
    buf = cgpaw.globally_broadcast_bytes(buf)
    try:
        return marshal.loads(buf)
    except (EOFError, ValueError, TypeError) as err:
        # marshal.loads signals invalid data with any of these three
        # exception types, not only ValueError.
        msg = ('Parallel import failure -- probably received garbage. '
               'Error was: {}. This may happen if parallel operations are '
               'performed while parallel imports are enabled.'.format(err))
        raise ImportError(msg) from err
class BroadcastLoader:
    """Import loader that executes module code taken from a shared cache.

    ``module_cache`` maps a module's full name to
    ``((submodule_search_locations, origin), code)``.  On rank 0,
    :meth:`exec_module` fills that entry through the real loader of
    ``spec``; on the other ranks the entry must already be present (it
    arrives through the broadcast performed by ``BroadcastImporter``).
    """

    def __init__(self, spec, module_cache):
        self.spec = spec
        self.module_cache = module_cache

    def create_module(self, spec):
        """Let the import system build the (uninitialized) module as usual.

        A subclass of the module type could be returned here instead.
        """
        return None

    def exec_module(self, module):
        """Run *module*'s code from the cache, filling the cache on rank 0.

        Rank 0 does not fall back to the default loading mechanism: after
        caching, it executes the module from the cache exactly like the
        other ranks.
        """
        name = module.__name__
        if world.rank == 0:
            source_spec = self.spec
            code = source_spec.loader.get_code(name)
            location_info = (source_spec.submodule_search_locations,
                             source_spec.origin)
            self.module_cache[name] = (location_info, code)
        return self.load_from_cache(module)

    def load_from_cache(self, module):
        """Execute the cached code for *module* inside its namespace.

        Sets ``__file__`` (the cached origin) and ``__loader__``, registers
        the module in ``sys.modules`` and returns it.
        """
        name = module.__name__
        cached_meta, code = self.module_cache[name]
        module.__file__ = cached_meta[1]
        # NOTE(review): __package__, __path__ and __cached__ are not set
        # here; whether anything else is needed is an open question.
        module.__loader__ = self
        sys.modules[name] = module
        exec(code, module.__dict__)
        return module

    def __str__(self):
        spec = self.spec
        return '<{} for {}:{} [{} modules cached]>'.format(
            type(self).__name__, spec.name, spec.origin,
            len(self.module_cache))
class BroadcastImporter:
    """Meta path finder that lets all ranks import the modules read by rank 0.

    Intended use (the module-level ``broadcast_imports`` instance)::

        with broadcast_imports:
            import some.module

    While enabled, rank 0 finds modules with ``PathFinder`` and, through
    ``BroadcastLoader``, records their metadata and code objects in
    ``module_cache``; :meth:`disable` broadcasts that cache.  The other
    ranks receive the cache in :meth:`enable` and serve the same imports
    from memory, falling back to ``PathFinder`` for modules not in it.
    :meth:`enable` and :meth:`disable` do nothing when ``world`` is None.
    """

    def __init__(self):
        # Full module name -> ((submodule_search_locations, origin), code)
        # for the modules imported in the current enable/disable window.
        self.module_cache = {}
        # Names of the modules cached in previous windows.
        self.cached_modules = []

    def find_spec(self, fullname, path=None, target=None):
        """Return a spec for *fullname* served by a ``BroadcastLoader``.

        On rank 0, return None (deferring to the next finder) when
        ``PathFinder`` cannot find the module, when the found spec has no
        loader, or when the loader yields no code object (e.g. C
        extensions).  On other ranks, modules missing from the received
        cache are resolved by ``PathFinder`` directly.
        """
        if world.rank == 0:
            spec = PathFinder.find_spec(fullname, path, target)
            if spec is None:
                return None

            if spec.loader is None:
                return None

            # NOTE(review): this code object only serves to detect modules
            # without Python code; BroadcastLoader.exec_module calls
            # get_code() again on rank 0.
            code = spec.loader.get_code(fullname)
            if code is None:  # C extensions
                return None

            loader = BroadcastLoader(spec, self.module_cache)
            assert fullname == spec.name

            searchloc = spec.submodule_search_locations
            spec = ModuleSpec(fullname, loader, origin=spec.origin,
                              is_package=searchloc is not None)
            if searchloc is not None:
                spec.submodule_search_locations += searchloc
            return spec
        else:
            if fullname not in self.module_cache:
                # Could this in principle interfere with builtin imports?
                return PathFinder.find_spec(fullname, path, target)

            searchloc, origin = self.module_cache[fullname][0]
            loader = BroadcastLoader(None, self.module_cache)
            spec = ModuleSpec(fullname, loader, origin=origin,
                              is_package=searchloc is not None)
            if searchloc is not None:
                spec.submodule_search_locations += searchloc
            loader.spec = spec  # XXX loader.loader is still None
            return spec

    def broadcast(self):
        """Send ``module_cache`` from rank 0; other ranks replace theirs."""
        if world.size == 1:
            return
        if world.rank == 0:
            marshal_broadcast(self.module_cache)
        else:
            self.module_cache = marshal_broadcast(None)

    def enable(self):
        """Install this finder at the front of ``sys.meta_path``.

        Non-root ranks then receive rank 0's cache (sent from
        :meth:`disable`).  If receiving fails, the finder is uninstalled
        again before the error propagates.
        """
        if world is None:
            return

        # There is the question of whether we lose anything by inserting
        # ourselves further on in the meta_path list.  Maybe not, and maybe
        # that is a less violent act.
        sys.meta_path.insert(0, self)
        if world.rank != 0:
            try:
                self.broadcast()
            except BaseException:
                sys.meta_path.remove(self)
                raise

    def disable(self):
        """Broadcast the cache (rank 0), archive its names, uninstall finder.

        The finder is removed from ``sys.meta_path`` even if the broadcast
        raises.
        """
        if world is None:
            return

        try:
            if world.rank == 0:
                self.broadcast()
            self.cached_modules += self.module_cache.keys()
            self.module_cache = {}
        finally:
            # Remove this exact finder rather than whatever sits at index 0:
            # code imported inside the block may have inserted its own
            # finder in front of us, and an ``assert`` guard would vanish
            # under ``python -O``.
            sys.meta_path.remove(self)

    def __enter__(self):
        self.enable()
        return self

    def __exit__(self, *args):
        self.disable()
# Shared module-level importer; use it as a context manager around import
# statements: ``with broadcast_imports: ...``.
broadcast_imports = BroadcastImporter()