diff --git a/SimPEG/Parallel.py b/SimPEG/Parallel.py index e0c9352c..2df3cd28 100644 --- a/SimPEG/Parallel.py +++ b/SimPEG/Parallel.py @@ -38,14 +38,24 @@ class Endpoint(object): footprint on the remote workers. ''' - localFields = {} - globalFields = {} - localSystems = {} - functions = {} + systemFactory = lambda: None # Callable for constructing system / problem + localFields = {} # Dictionary for storing local fields + globalFields = {} # Dictionary for storing merged fields + localSystems = {} # Dictionary of local subsystem / problem objects + functions = {} # Dictionary of callables to carry out modelling / etc. + fieldspec = {} # Dictionary of callables to setup field storage objects - def __init__(self): + def setupLocalFields(self, whichfields=None): - pass + if whichfields is None: + self.localFields = {} + if getattr(self, 'fieldspec', None) is not None: + for fn in (whichfields or self.fieldspec): + self.localFields[fn] = self.fieldspec[fn]() + + def setupLocalSystem(self, subConfig): + + self.localSystems[subConfig['tag']] = self.systemFactory(subConfig) class SystemGraph(networkx.DiGraph): @@ -104,16 +114,18 @@ except NameError: class SystemSolver(object): - def __init__(self, dispatcher, schedule): + def __init__(self, dispatcher, endpoint, schedule): self.dispatcher = dispatcher + self.endpoint = endpoint self.schedule = schedule def __call__(self, entry, isrcs): # TODO: Replace with SuperReference instances - fnRef = self.schedule[entry]['solve'] - clearRef = self.schedule[entry]['clear'] + fnformat = '%s.functions["%s"]' + fnRef = Reference(fnformat%(self.endpoint, self.schedule[entry]['solve'])) + clearRef = Reference(fnformat%(self.endpoint, self.schedule[entry]['clear'])) reduceLabels = self.schedule[entry]['reduce'] dview = self.dispatcher.remote.dview @@ -143,7 +155,7 @@ class SystemSolver(object): isrcslist = [isrcs] # TODO: Replace w/ hook into Endpoint classes - systemsOnWorkers = dview['localSystem.keys()'] + systemsOnWorkers = dview['%s.localSystems.keys()'%self.endpoint] ids = dview['rank'] tags = set() for ltags in systemsOnWorkers: @@ -175,7 +187,7 @@ class SystemSolver(object): iworks = 0 for work in self._getChunks(isrcslist, int(round(chunksPerWorker*len(relIDs)))): if work: - job = lview.apply(fnRef, tag, work) + job = lview.apply(fnRef, Reference(self.endpoint), tag, work) systemJobs.append(job) label = 'Compute: %d, %d, %d'%(tag[0], tag[1], iworks) systemNodes.append(label) @@ -196,7 +208,7 @@ class SystemSolver(object): # TODO: Remove dependency on self._hasSystemRank, once the SuperReferences # are able to be used. They will automatically schedule only on the # correct (allowed) systems. - job = lview.apply(depend(self._hasSystemRank, tag, rank)(clearRef), tag) + job = lview.apply(depend(self._hasSystemRank, tag, rank)(clearRef), Reference(self.endpoint), tag) clearJobs.append(job) label = 'Wrap: %d, %d, %d'%(tag[0],tag[1], i) G.add_node(label, jobs=[job]) @@ -206,7 +218,7 @@ class SystemSolver(object): for i, sjob in enumerate(systemJobs): with lview.temp_flags(block=False, follow=sjob): - job = lview.apply(clearRef, tag) + job = lview.apply(clearRef, Reference(self.endpoint), tag) clearJobs.append(job) label = 'Wrap: %d, %d, %d'%(tag[0],tag[1],i) G.add_node(label, jobs=[job]) @@ -222,7 +234,8 @@ class SystemSolver(object): jobs = [] after = clearJobs for label in reduceLabels: - job = self.dispatcher.remote.reduceLB(label, after=after) + key = '%s.localFields["%s"]'%(self.endpoint, label) + job = self.dispatcher.remote.reduceLB(Reference(self.endpoint), key, after=after) after = job if job is not None: jobs.append(job) @@ -351,13 +364,13 @@ class RemoteInterface(object): return item - def reduceLB(self, key, after=None): + def reduceLB(self, endpoint, key, after=None): repeat = lambda value: (value for i in xrange(len(self.pclient.ids))) if self.useMPI: with self.lview.temp_flags(block=False, after=after): - job = self.lview.map(self._reduceJob, xrange(len(self.pclient.ids)), repeat(0), repeat(key)) + job = self.lview.map(self._reduceJob, xrange(len(self.pclient.ids)), repeat(0), endpoint, repeat(key)) return job @@ -485,19 +498,14 @@ class RemoteInterface(object): @staticmethod @interactive - def _reduceJob(worker, root, key): + def _reduceJob(worker, root, endpoint, key): from IPython.parallel.error import UnmetDependency if not rank == worker: raise UnmetDependency - from SimPEG.Utils import CommonReducer - - # exec('global %s'%key) - - code = 'globals()["%(key)s"] = comm.reduce(%(key)s, root=%(root)d)' + code = 'endpoint.globalFields[%(key)s] = comm.reduce(%(key)s, root=%(root)d)' exec(code%{'key': key, 'root': root}) - exec('globals()["%(key)s"] = %(key)s if %(key)s is not None else CommonReducer()'%{'key': key}) @staticmethod def _adjustMKLVectorization(nt=1):