mirror of
https://github.com/wassname/simpeg.git
synced 2026-06-29 01:59:48 +08:00
Improvements to scheduler (now use slices).
Change name "dispatcher" to "problem".
This commit is contained in:
+22
-21
@@ -144,9 +144,9 @@ except NameError:
|
||||
|
||||
class SystemSolver(object):
|
||||
|
||||
def __init__(self, dispatcher, endpointName, schedule):
|
||||
def __init__(self, problem, endpointName, schedule):
|
||||
|
||||
self.dispatcher = dispatcher
|
||||
self.problem = problem
|
||||
self.endpointName = endpointName
|
||||
self.schedule = schedule
|
||||
|
||||
@@ -158,10 +158,10 @@ class SystemSolver(object):
|
||||
clearRef = Reference(fnformat%(self.endpointName, self.schedule[entry]['clear']))
|
||||
reduceLabels = self.schedule[entry]['reduce']
|
||||
|
||||
dview = self.dispatcher.remote.dview
|
||||
lview = self.dispatcher.remote.lview
|
||||
dview = self.problem.remote.dview
|
||||
lview = self.problem.remote.lview
|
||||
|
||||
chunksPerWorker = getattr(self.dispatcher, 'chunksPerWorker', 1)
|
||||
chunksPerWorker = getattr(self.problem, 'chunksPerWorker', 1)
|
||||
|
||||
G = SystemGraph()
|
||||
|
||||
@@ -170,19 +170,12 @@ class SystemSolver(object):
|
||||
|
||||
# Parse sources
|
||||
# TODO: Get from Survey somehow?
|
||||
nsrc = self.dispatcher.nsrc
|
||||
nsrc = self.problem.nsrc
|
||||
if isrcs is None:
|
||||
isrcslist = range(nsrc)
|
||||
isrcs = slice(None)
|
||||
|
||||
elif isinstance(isrcs, slice):
|
||||
isrcslist = range(isrcs.start or 0, isrcs.stop or nsrc, isrcs.step or 1)
|
||||
|
||||
else:
|
||||
try:
|
||||
_ = isrcs[0]
|
||||
isrcslist = isrcs
|
||||
except TypeError:
|
||||
isrcslist = [isrcs]
|
||||
elif not isinstance(isrcs, slice):
|
||||
raise Exception('Scheduler must run over slice or None!')
|
||||
|
||||
# TODO: Replace w/ hook into Endpoint classes
|
||||
systemsOnWorkers = dview['%s.localProblems.keys()'%self.endpointName]
|
||||
@@ -215,7 +208,7 @@ class SystemSolver(object):
|
||||
|
||||
with lview.temp_flags(block=False):
|
||||
iworks = 0
|
||||
for work in self._getChunks(isrcslist, int(round(chunksPerWorker*len(relIDs)))):
|
||||
for work in self._subSlice(isrcs, int(round(chunksPerWorker*len(relIDs)))):
|
||||
if work:
|
||||
job = lview.apply(fnRef, Reference(self.endpointName), tag, work)
|
||||
systemJobs.append(job)
|
||||
@@ -225,7 +218,7 @@ class SystemSolver(object):
|
||||
G.add_edge(tagNode, label)
|
||||
iworks += 1
|
||||
|
||||
if getattr(self.dispatcher, 'ensembleClear', False): # True for ensemble ending, False for individual ending
|
||||
if getattr(self.problem, 'ensembleClear', False): # True for ensemble ending, False for individual ending
|
||||
tagNode = 'Wrap: %d, %d'%tag
|
||||
for label in systemNodes:
|
||||
G.add_edge(label, tagNode)
|
||||
@@ -264,7 +257,7 @@ class SystemSolver(object):
|
||||
jobs = []
|
||||
after = clearJobs
|
||||
for label in reduceLabels:
|
||||
job = self.dispatcher.remote.reduceLB(Reference(self.endpointName), label, after=after)
|
||||
job = self.problem.remote.reduceLB(Reference(self.endpointName), label, after=after)
|
||||
after = job
|
||||
if job is not None:
|
||||
jobs.append(job)
|
||||
@@ -275,7 +268,7 @@ class SystemSolver(object):
|
||||
return G
|
||||
|
||||
def wait(self, G):
|
||||
self.dispatcher.remote.lview.wait(G.node['End']['jobs'] if G.node['End']['jobs'] else (G.node[wn]['jobs'] for wn in (G.predecessors(tn)[0] for tn in G.predecessors('End'))))
|
||||
self.problem.remote.lview.wait(G.node['End']['jobs'] if G.node['End']['jobs'] else (G.node[wn]['jobs'] for wn in (G.predecessors(tn)[0] for tn in G.predecessors('End'))))
|
||||
|
||||
# TODO: Hopefully obsoleted by SuperReference
|
||||
@staticmethod
|
||||
@@ -290,6 +283,13 @@ class SystemSolver(object):
|
||||
nproblems = len(problems)
|
||||
return (problems[i*nproblems // chunks: (i+1)*nproblems // chunks] for i in range(chunks))
|
||||
|
||||
@staticmethod
|
||||
def _subSlice(insl, chunks=1):
|
||||
start = insl.start or 0
|
||||
nproblems = insl.stop - start
|
||||
return [slice(start + i*nproblems/chunks, start + (i+1)*nproblems/chunks) for i in xrange(chunks)]
|
||||
|
||||
|
||||
class RemoteInterface(object):
|
||||
|
||||
def __init__(self, profile=None, MPI=None, nThreads=1, bootstrap=None):
|
||||
@@ -543,7 +543,8 @@ class RemoteInterface(object):
|
||||
|
||||
code = 'temp_norm = (%(key)s * %(key)s.conj()).sum(0).sum(0)'
|
||||
self.e0.execute(code%{'key': key})
|
||||
code = 'temp_norm = {key: np.sqrt(temp_norm[key] for key in temp_norm)}'
|
||||
code = 'temp_norm = {key: np.sqrt(temp_norm[key]).real for key in temp_norm}'
|
||||
self.e0.execute(code%{'key': key})
|
||||
result = CommonReducer(self.e0['temp_norm'])
|
||||
self.e0.execute('del temp_norm')
|
||||
|
||||
|
||||
Reference in New Issue
Block a user