mirror of
https://github.com/wassname/simpeg.git
synced 2026-06-29 21:26:31 +08:00
Source estimation code.
This commit is contained in:
@@ -528,6 +528,44 @@ class RemoteInterface(object):
|
||||
def remoteDifferenceGatherFirst(self, *args):
|
||||
self.remoteOpGatherFirst('-', *args)
|
||||
|
||||
def remoteSrcEstGatherFirst(self, keyresult, key1, key2, individual=False):
|
||||
|
||||
if self.useMPI:
|
||||
|
||||
root = 0
|
||||
|
||||
# # Gather
|
||||
# code_reduce = 'temp_%(key)s = comm.reduce(%(key)s, root=%(root)d)'
|
||||
# self.dview.execute(code_reduce%{'key': key1, 'root': root})
|
||||
|
||||
# SrcEst
|
||||
if individual:
|
||||
code_srcest = '%(keyresult)s = (%(key2)s.conj() * %(key1)s).sum(axis=1) / (%(key1)s.conj() * %(key1)s).sum(axis=1)'
|
||||
else:
|
||||
code_srcest = '%(keyresult)s = (%(key2)s.conj() * %(key1)s).sum() / (%(key1)s.conj() * %(key1)s).sum()'
|
||||
self.e0.execute(code_srcest%{'key1': key1, 'key2': key2, 'keyresult': keyresult})
|
||||
|
||||
# Broadcast
|
||||
code = 'if rank != %(root)d: %(key)s = None\n%(key)s = comm.bcast(%(key)s, root=%(root)d)'
|
||||
self.dview.execute(code%{'key': keyresult, 'root': root})
|
||||
|
||||
else:
|
||||
|
||||
item1 = reduce(np.add, self.dview[key1])
|
||||
item2 = self.e0[key2]
|
||||
|
||||
if individual:
|
||||
item = (item2.conj() * item1).sum(axis=1) / (item1.conj() * item1).sum(axis=1)
|
||||
else:
|
||||
item = (item2.conj() * item1).sum() / (item1.conj() * item1).sum()
|
||||
|
||||
self.dview[keyresult] = item
|
||||
|
||||
def remoteApplySrc(self, keyData, keySrc):
|
||||
|
||||
code = '%(keyData)s = %(keySrc)s * %(keyData)s'
|
||||
self.dview.execute(code%{'keyData': keyData, 'keySrc': keySrc})
|
||||
|
||||
# def normFromDifference(self, key):
|
||||
|
||||
# code = 'temp_norm%(key)s = (%(key)s * %(key)s.conj()).sum(0).sum(0)'
|
||||
|
||||
Reference in New Issue
Block a user