Change location of endpointName setting.

This commit is contained in:
Brendan Smithyman
2015-07-02 14:25:15 -04:00
parent c507955903
commit 12e64f418c
+18 -13
View File
@@ -144,18 +144,18 @@ except NameError:
class SystemSolver(object):
def __init__(self, problem, endpointName, schedule):
def __init__(self, problem, schedule):
self.problem = problem
self.endpointName = endpointName
self.remote = problem.remote
self.schedule = schedule
def __call__(self, entry, isrcs):
# TODO: Replace with SuperReference instances
fnformat = '%s.functions["%s"]'
fnRef = Reference(fnformat%(self.endpointName, self.schedule[entry]['solve']))
clearRef = Reference(fnformat%(self.endpointName, self.schedule[entry]['clear']))
fnRef = Reference(fnformat%(self.remote.endpointName, self.schedule[entry]['solve']))
clearRef = Reference(fnformat%(self.remote.endpointName, self.schedule[entry]['clear']))
reduceLabels = self.schedule[entry]['reduce']
dview = self.problem.remote.dview
@@ -178,7 +178,7 @@ class SystemSolver(object):
raise Exception('Scheduler must run over slice or None!')
# TODO: Replace w/ hook into Endpoint classes
systemsOnWorkers = dview['%s.localProblems.keys()'%self.endpointName]
systemsOnWorkers = dview['%s.localProblems.keys()'%self.remote.endpointName]
ids = dview['rank']
tags = set()
for ltags in systemsOnWorkers:
@@ -210,7 +210,7 @@ class SystemSolver(object):
iworks = 0
for work in self._subSlice(isrcs, int(round(chunksPerWorker*len(relIDs)))):
if work:
job = lview.apply(fnRef, Reference(self.endpointName), tag, work)
job = lview.apply(fnRef, Reference(self.remote.endpointName), tag, work)
systemJobs.append(job)
label = 'Compute: %d, %d, %d'%(tag[0], tag[1], iworks)
systemNodes.append(label)
@@ -231,7 +231,7 @@ class SystemSolver(object):
# TODO: Remove dependency on self._hasSystemRank, once the SuperReferences
# are able to be used. They will automatically schedule only on the
# correct (allowed) systems.
job = lview.apply(depend(self._hasSystemRank, tag, rank)(clearRef), Reference(self.endpointName), tag)
job = lview.apply(clearRef, Reference(self.remote.endpointName), tag, rank)
clearJobs.append(job)
label = 'Wrap: %d, %d, %d'%(tag[0],tag[1], i)
G.add_node(label, jobs=[job])
@@ -241,7 +241,7 @@ class SystemSolver(object):
for i, sjob in enumerate(systemJobs):
with lview.temp_flags(block=False, follow=sjob):
job = lview.apply(clearRef, Reference(self.endpointName), tag)
job = lview.apply(clearRef, Reference(self.remote.endpointName), tag)
clearJobs.append(job)
label = 'Wrap: %d, %d, %d'%(tag[0],tag[1],i)
G.add_node(label, jobs=[job])
@@ -257,7 +257,7 @@ class SystemSolver(object):
jobs = []
after = clearJobs
for label in reduceLabels:
job = self.problem.remote.reduceLB(Reference(self.endpointName), label, after=after)
job = self.problem.remote.reduceLB(Reference(self.remote.endpointName), label, after)
after = job
if job is not None:
jobs.append(job)
@@ -273,10 +273,10 @@ class SystemSolver(object):
# TODO: Hopefully obsoleted by SuperReference
@staticmethod
@interactive
def _hasSystemRank(tag, wid):
global localSystem
def _hasSystemRank(endpoint, tag, wid):
global rank
return (tag in localSystem) and (rank == wid)
return (tag in endpoint.localProblems) and (rank == wid)
@staticmethod
def _getChunks(problems, chunks=1):
@@ -292,7 +292,7 @@ class SystemSolver(object):
class RemoteInterface(object):
def __init__(self, profile=None, MPI=None, nThreads=1, bootstrap=None):
def __init__(self, profile=None, MPI=None, nThreads=1, bootstrap=None, endpointName='endpoint'):
# TODO: Add interface for namespace bootstrapping from
# the dispatcher / problem side
@@ -361,6 +361,8 @@ class RemoteInterface(object):
for command in bootstrap.strip().split('\n'):
dview.execute(command.strip())
self.endpointName = endpointName
@property
def nThreads(self):
return self._nThreads
@@ -599,6 +601,9 @@ class RemoteInterface(object):
# code = '%(endpoint)s.globalFields["%(key)s"] = comm.reduce(%(endpoint)s.localFields["%(key)s"], root=%(root)d)'
# exec(code%{'endpoint': endpoint, 'key': key, 'root': root})
if key not in endpoint.localFields:
endpoint.localFields[key] = endpoint.fieldspec[key]()
endpoint.globalFields[key] = comm.reduce(endpoint.localFields[key], root=root)
@staticmethod