Source estimation code.

This commit is contained in:
Brendan Smithyman
2015-06-29 18:45:44 -04:00
parent 0cf783d993
commit c507955903
+38
View File
@@ -528,6 +528,44 @@ class RemoteInterface(object):
def remoteDifferenceGatherFirst(self, *args):
self.remoteOpGatherFirst('-', *args)
def remoteSrcEstGatherFirst(self, keyresult, key1, key2, individual=False):
if self.useMPI:
root = 0
# # Gather
# code_reduce = 'temp_%(key)s = comm.reduce(%(key)s, root=%(root)d)'
# self.dview.execute(code_reduce%{'key': key1, 'root': root})
# SrcEst
if individual:
code_srcest = '%(keyresult)s = (%(key2)s.conj() * %(key1)s).sum(axis=1) / (%(key1)s.conj() * %(key1)s).sum(axis=1)'
else:
code_srcest = '%(keyresult)s = (%(key2)s.conj() * %(key1)s).sum() / (%(key1)s.conj() * %(key1)s).sum()'
self.e0.execute(code_srcest%{'key1': key1, 'key2': key2, 'keyresult': keyresult})
# Broadcast
code = 'if rank != %(root)d: %(key)s = None\n%(key)s = comm.bcast(%(key)s, root=%(root)d)'
self.dview.execute(code%{'key': keyresult, 'root': root})
else:
item1 = reduce(np.add, self.dview[key1])
item2 = self.e0[key2]
if individual:
item = (item2.conj() * item1).sum(axis=1) / (item1.conj() * item1).sum(axis=1)
else:
item = (item2.conj() * item1).sum() / (item1.conj() * item1).sum()
self.dview[keyresult] = item
def remoteApplySrc(self, keyData, keySrc):
code = '%(keyData)s = %(keySrc)s * %(keyData)s'
self.dview.execute(code%{'keyData': keyData, 'keySrc': keySrc})
# def normFromDifference(self, key):
# code = 'temp_norm%(key)s = (%(key)s * %(key)s.conj()).sum(0).sum(0)'