diff --git a/SimPEG/Parallel.py b/SimPEG/Parallel.py index a09547b6..7e15e177 100644 --- a/SimPEG/Parallel.py +++ b/SimPEG/Parallel.py @@ -528,6 +528,44 @@ class RemoteInterface(object): def remoteDifferenceGatherFirst(self, *args): self.remoteOpGatherFirst('-', *args) + def remoteSrcEstGatherFirst(self, keyresult, key1, key2, individual=False): + + if self.useMPI: + + root = 0 + + # # Gather + # code_reduce = 'temp_%(key)s = comm.reduce(%(key)s, root=%(root)d)' + # self.dview.execute(code_reduce%{'key': key1, 'root': root}) + + # SrcEst + if individual: + code_srcest = '%(keyresult)s = (%(key2)s.conj() * %(key1)s).sum(axis=1) / (%(key1)s.conj() * %(key1)s).sum(axis=1)' + else: + code_srcest = '%(keyresult)s = (%(key2)s.conj() * %(key1)s).sum() / (%(key1)s.conj() * %(key1)s).sum()' + self.e0.execute(code_srcest%{'key1': key1, 'key2': key2, 'keyresult': keyresult}) + + # Broadcast + code = 'if rank != %(root)d: %(key)s = None\n%(key)s = comm.bcast(%(key)s, root=%(root)d)' + self.dview.execute(code%{'key': keyresult, 'root': root}) + + else: + + item1 = reduce(np.add, self.dview[key1]) + item2 = self.e0[key2] + + if individual: + item = (item2.conj() * item1).sum(axis=1) / (item1.conj() * item1).sum(axis=1) + else: + item = (item2.conj() * item1).sum() / (item1.conj() * item1).sum() + + self.dview[keyresult] = item + + def remoteApplySrc(self, keyData, keySrc): + + code = '%(keyData)s = %(keySrc)s * %(keyData)s' + self.dview.execute(code%{'keyData': keyData, 'keySrc': keySrc}) + # def normFromDifference(self, key): # code = 'temp_norm%(key)s = (%(key)s * %(key)s.conj()).sum(0).sum(0)'