From 12e64f418c5424cf712700bfbe4cbc7a786d2a0d Mon Sep 17 00:00:00 2001 From: Brendan Smithyman Date: Thu, 2 Jul 2015 14:25:15 -0400 Subject: [PATCH] Change location of endpointName setting. --- SimPEG/Parallel.py | 31 ++++++++++++++++++------------- 1 file changed, 18 insertions(+), 13 deletions(-) diff --git a/SimPEG/Parallel.py b/SimPEG/Parallel.py index 7e15e177..b1ab9b34 100644 --- a/SimPEG/Parallel.py +++ b/SimPEG/Parallel.py @@ -144,18 +144,18 @@ except NameError: class SystemSolver(object): - def __init__(self, problem, endpointName, schedule): + def __init__(self, problem, schedule): self.problem = problem - self.endpointName = endpointName + self.remote = problem.remote self.schedule = schedule def __call__(self, entry, isrcs): # TODO: Replace with SuperReference instances fnformat = '%s.functions["%s"]' - fnRef = Reference(fnformat%(self.endpointName, self.schedule[entry]['solve'])) - clearRef = Reference(fnformat%(self.endpointName, self.schedule[entry]['clear'])) + fnRef = Reference(fnformat%(self.remote.endpointName, self.schedule[entry]['solve'])) + clearRef = Reference(fnformat%(self.remote.endpointName, self.schedule[entry]['clear'])) reduceLabels = self.schedule[entry]['reduce'] dview = self.problem.remote.dview @@ -178,7 +178,7 @@ class SystemSolver(object): raise Exception('Scheduler must run over slice or None!') # TODO: Replace w/ hook into Endpoint classes - systemsOnWorkers = dview['%s.localProblems.keys()'%self.endpointName] + systemsOnWorkers = dview['%s.localProblems.keys()'%self.remote.endpointName] ids = dview['rank'] tags = set() for ltags in systemsOnWorkers: @@ -210,7 +210,7 @@ class SystemSolver(object): iworks = 0 for work in self._subSlice(isrcs, int(round(chunksPerWorker*len(relIDs)))): if work: - job = lview.apply(fnRef, Reference(self.endpointName), tag, work) + job = lview.apply(fnRef, Reference(self.remote.endpointName), tag, work) systemJobs.append(job) label = 'Compute: %d, %d, %d'%(tag[0], tag[1], iworks) systemNodes.append(label) @@ -231,7 +231,7 @@ class SystemSolver(object): # TODO: Remove dependency on self._hasSystemRank, once the SuperReferences # are able to be used. They will automatically schedule only on the # correct (allowed) systems. - job = lview.apply(depend(self._hasSystemRank, tag, rank)(clearRef), Reference(self.endpointName), tag) + job = lview.apply(clearRef, Reference(self.remote.endpointName), tag, rank) clearJobs.append(job) label = 'Wrap: %d, %d, %d'%(tag[0],tag[1], i) G.add_node(label, jobs=[job]) @@ -241,7 +241,7 @@ class SystemSolver(object): for i, sjob in enumerate(systemJobs): with lview.temp_flags(block=False, follow=sjob): - job = lview.apply(clearRef, Reference(self.endpointName), tag) + job = lview.apply(clearRef, Reference(self.remote.endpointName), tag) clearJobs.append(job) label = 'Wrap: %d, %d, %d'%(tag[0],tag[1],i) G.add_node(label, jobs=[job]) @@ -257,7 +257,7 @@ class SystemSolver(object): jobs = [] after = clearJobs for label in reduceLabels: - job = self.problem.remote.reduceLB(Reference(self.endpointName), label, after=after) + job = self.problem.remote.reduceLB(Reference(self.remote.endpointName), label, after) after = job if job is not None: jobs.append(job) @@ -273,10 +273,10 @@ class SystemSolver(object): # TODO: Hopefully obsoleted by SuperReference @staticmethod @interactive - def _hasSystemRank(tag, wid): - global localSystem + def _hasSystemRank(endpoint, tag, wid): global rank - return (tag in localSystem) and (rank == wid) + + return (tag in endpoint.localProblems) and (rank == wid) @staticmethod def _getChunks(problems, chunks=1): @@ -292,7 +292,7 @@ class SystemSolver(object): class RemoteInterface(object): - def __init__(self, profile=None, MPI=None, nThreads=1, bootstrap=None): + def __init__(self, profile=None, MPI=None, nThreads=1, bootstrap=None, endpointName='endpoint'): # TODO: Add interface for namespace bootstrapping from # the dispatcher / problem side @@ -361,6 +361,8 @@ class RemoteInterface(object): for command in bootstrap.strip().split('\n'): dview.execute(command.strip()) + self.endpointName = endpointName + @property def nThreads(self): return self._nThreads @@ -599,6 +601,9 @@ class RemoteInterface(object): # code = '%(endpoint)s.globalFields["%(key)s"] = comm.reduce(%(endpoint)s.localFields["%(key)s"], root=%(root)d)' # exec(code%{'endpoint': endpoint, 'key': key, 'root': root}) + if key not in endpoint.localFields: + endpoint.localFields[key] = endpoint.fieldspec[key]() + endpoint.globalFields[key] = comm.reduce(endpoint.localFields[key], root=root) @staticmethod