From 2ae20b8cd3c7693a07473b92f164a6b13aeb50d5 Mon Sep 17 00:00:00 2001 From: rowanc1 Date: Fri, 25 Apr 2014 15:48:38 -0700 Subject: [PATCH] updates to survey. --- SimPEG/Survey.py | 346 ++++++++++++++++++++++-------------- SimPEG/Tests/test_Survey.py | 51 ++++++ 2 files changed, 268 insertions(+), 129 deletions(-) create mode 100644 SimPEG/Tests/test_Survey.py diff --git a/SimPEG/Survey.py b/SimPEG/Survey.py index 4d6a4cf3..dc5aeadb 100644 --- a/SimPEG/Survey.py +++ b/SimPEG/Survey.py @@ -1,6 +1,195 @@ import Utils, numpy as np, scipy.sparse as sp +class BaseRx(object): + """SimPEG Receiver Object""" + + locs = None #: Locations (nRx x nDim) + + knownRxTypes = None #: Set this to a list of strings to ensure that txType is known + + projGLoc = 'CC' #: Projection grid location, default is CC + + storeProjections = True #: Store calls to getP (organized by mesh) + + def __init__(self, locs, rxType, **kwargs): + self.locs = locs + self.rxType = rxType + self._Ps = {} + Utils.setKwargs(self, **kwargs) + + @property + def rxType(self): + """Receiver Type""" + return getattr(self, '_rxType', None) + @rxType.setter + def rxType(self, value): + known = self.knownRxTypes + if known is not None: + assert value in known, "rxType must be in ['%s']" % ("', '".join(known)) + self._rxType = value + + @property + def nD(self): + """Number of data in the receiver.""" + return self.locs.shape[0] + + def getP(self, mesh): + """ + Returns the projection matrices as a + list for all components collected by + the receivers. + + .. note:: + + Projection matrices are stored as a dictionary listed by meshes. + """ + if mesh in self._Ps: + return self._Ps[mesh] + + P = mesh.getInterpolationMat(self.locs, self.projGLoc) + if self.storeProjections: + self._Ps[mesh] = P + return P + + +class BaseTimeRx(BaseRx): + """SimPEG Receiver Object""" + + times = None #: Times when the receivers were active. + + def __init__(self, locs, times, rxType, **kwargs): + self.times = times + BaseRx.__init__(self, locs, rxType, **kwargs) + + @property + def nD(self): + """Number of data in the receiver.""" + return self.locs.shape[0] * len(self.times) + + def getP(self, mesh, timeMesh): + """ + Returns the projection matrices as a + list for all components collected by + the receivers. + + .. note:: + + Projection matrices are stored as a dictionary (mesh, timeMesh) + """ + if (mesh, timeMesh) in self._Ps: + return self._Ps[(mesh, timeMesh)] + + Ps = mesh.getInterpolationMat(self.locs, self.projGLoc) + Pt = timeMesh.getInterpolationMat(self.times, 'N') + P = sp.kron(Pt, Ps) + + if self.storeProjections: + self._Ps[(mesh, timeMesh)] = P + + return P + + +class BaseTx(object): + """SimPEG Transmitter Object""" + + loc = None #: Location [x,y,z] + + rxList = None #: SimPEG Receiver List + rxPair = BaseRx + + knownTxTypes = None #: Set this to a list of strings to ensure that txType is known + + def __init__(self, loc, txType, rxList, **kwargs): + assert type(rxList) is list, 'rxList must be a list' + for rx in rxList: + assert isinstance(rx, self.rxPair), 'rxList must be a %s'%self.rxListPair.__name__ + assert len(set(rxList)) == len(rxList), 'The rxList must be unique' + + self.loc = loc + self.txType = txType + self.rxList = rxList + Utils.setKwargs(self, **kwargs) + + @property + def txType(self): + """Transmitter Type""" + return getattr(self, '_txType', None) + @txType.setter + def txType(self, value): + known = self.knownTxTypes + if known is not None: + assert value in known, "txType must be in ['%s']" % ("', '".join(known)) + self._txType = value + + @property + def nD(self): + """Number of data""" + return self.vnD.sum() + + @property + def vnD(self): + """Vector number of data""" + return np.array([rx.nD for rx in self.rxList]) + + +class Data(object): + """Fancy data storage by Tx and Rx""" + + def __init__(self, survey, v=None): + self.survey = survey + self._dataDict = {} + for tx in self.survey.txList: + self._dataDict[tx] = {} + if v is not None: + self.fromvec(v) + + def _ensureCorrectKey(self, key): + if type(key) is tuple: + if len(key) is not 2: + raise KeyError('Key must be [Tx, Rx]') + if key[0] not in self.survey.txList: + raise KeyError('Tx Key must be a transmitter in the survey.') + if key[1] not in key[0].rxList: + raise KeyError('Rx Key must be a receiver for the transmitter.') + return key + elif isinstance(key, self.survey.txPair): + if key not in self.survey.txList: + raise KeyError('Key must be a transmitter in the survey.') + return key, None + else: + raise KeyError('Key must be [Tx] or [Tx,Rx]') + + def __setitem__(self, key, value): + tx, rx = self._ensureCorrectKey(key) + assert rx is not None, 'set data using [Tx, Rx]' + assert type(value) == np.ndarray, 'value must by ndarray' + assert value.size == rx.nD, "value must have the same number of data as the transmitter." + self._dataDict[tx][rx] = Utils.mkvc(value) + + def __getitem__(self, key): + tx, rx = self._ensureCorrectKey(key) + if rx is not None: + if rx not in self._dataDict[tx]: + raise Exception('Data for receiver has not yet been set.') + return self._dataDict[tx][rx] + + return np.concatenate([self[tx,rx] for rx in tx.rxList]) + + def tovec(self): + return np.concatenate([self[tx] for tx in self.survey.txList]) + + def fromvec(self, v): + v = Utils.mkvc(v) + assert v.size == self.survey.nD, 'v must have the correct number of data.' + indBot, indTop = 0, 0 + for tx in self.survey.txList: + for rx in tx.rxList: + indTop += rx.nD + self[tx, rx] = v[indBot:indTop] + indBot += rx.nD + + class BaseSurvey(object): """Survey holds the observed data, and the standard deviations.""" @@ -16,6 +205,20 @@ class BaseSurvey(object): def __init__(self, **kwargs): Utils.setKwargs(self, **kwargs) + txPair = BaseTx #: Transmitter Pair + + @property + def txList(self): + """Transmitter List""" + return getattr(self, '_txList', None) + + @txList.setter + def txList(self, value): + assert type(value) is list, 'txList must be a list' + assert np.all([isinstance(tx, self.txPair) for tx in value]), 'All transmitters must be instances of %s' % self.txPair.__name__ + assert len(set(value)) == len(value), 'The txList must be unique' + self._txList = value + @property def prob(self): """ @@ -48,14 +251,22 @@ class BaseSurvey(object): self._prob = None @property - def nD(self): - """Number of data.""" - if hasattr(self, 'dobs'): - return self.dobs.size - raise NotImplemented('Number of data is unknown.') + def ispaired(self): return self.prob is not None @property - def ispaired(self): return self.prob is not None + def nD(self): + """Number of data""" + return self.vnD.sum() + + @property + def vnD(self): + """Vector number of data""" + return np.array([tx.nD for tx in self.txList]) + + @property + def nTx(self): + """Number of Transmitters""" + return len(self.txList) @Utils.count @Utils.requires('prob') @@ -182,129 +393,6 @@ class BaseSurvey(object): # self._phi_d_target = value -class BaseRx(object): - """SimPEG Receiver Object""" - - locs = None #: Locations (nRx x nDim) - - knownRxTypes = None #: Set this to a list of strings to ensure that txType is known - - projGLoc = 'CC' #: Projection grid location, default is CC - - def __init__(self, locs, rxType, **kwargs): - self.locs = locs - self.rxType = rxType - self._Ps = {} - Utils.setKwargs(self, **kwargs) - - @property - def rxType(self): - """Receiver Type""" - return getattr(self, '_rxType', None) - @rxType.setter - def rxType(self, value): - known = self.knownRxTypes - if known is not None: - assert value in known, "rxType must be in ['%s']" % ("', '".join(known)) - self._rxType = value - - @property - def nD(self): - """Number of data in the receiver.""" - return self.locs.shape[0] - - def getP(self, mesh): - """ - Returns the projection matrices as a - list for all components collected by - the receivers. - - .. note:: - - Projection matrices are stored as a dictionary listed by meshes. - """ - if mesh not in self._Ps: - self._Ps[mesh] = mesh.getInterpolationMat(self.locs, self.projGLoc) - P = self._Ps[mesh] - return P - - -class BaseTimeRx(BaseRx): - """SimPEG Receiver Object""" - - times = None #: Times when the receivers were active. - - def __init__(self, locs, times, rxType, **kwargs): - self.times = times - BaseRx.__init__(self, locs, rxType, **kwargs) - - @property - def nD(self): - """Number of data in the receiver.""" - return self.locs.shape[0] * len(self.times) - - def getP(self, mesh, timeMesh): - """ - Returns the projection matrices as a - list for all components collected by - the receivers. - - .. note:: - - Projection matrices are stored as a dictionary (mesh, timeMesh) - """ - if (mesh, timeMesh) not in self._Ps: - Ps = mesh.getInterpolationMat(self.locs, self.projGLoc) - Pt = timeMesh.getInterpolationMat(self.times, 'N') - self._Ps[(mesh, timeMesh)] = sp.kron(Pt, Ps) - - P = self._Ps[(mesh, timeMesh)] - return P - - - -class BaseTx(object): - """SimPEG Transmitter Object""" - - loc = None #: Location [x,y,z] - - rxList = None #: SimPEG Receiver List - rxPair = BaseRx - - knownTxTypes = None #: Set this to a list of strings to ensure that txType is known - - def __init__(self, loc, txType, rxList, **kwargs): - assert type(rxList) is list, 'rxList must be a list' - for rx in rxList: - assert isinstance(rx, self.rxPair), 'rxList must be a %s'%self.rxListPair.__name__ - assert len(set(rxList)) == len(rxList), 'The rxList must be unique' - - self.loc = loc - self.txType = txType - self.rxList = rxList - Utils.setKwargs(self, **kwargs) - - @property - def txType(self): - """Transmitter Type""" - return getattr(self, '_txType', None) - @txType.setter - def txType(self, value): - known = self.knownTxTypes - if known is not None: - assert value in known, "txType must be in ['%s']" % ("', '".join(known)) - self._txType = value - - @property - def nD(self): - """Number of data""" - return self.vnD.sum() - - @property - def vnD(self): - """Vector number of data""" - return np.array([rx.nD for rx in self.rxList]) - if __name__ == '__main__': d = BaseData() d.dpred() diff --git a/SimPEG/Tests/test_Survey.py b/SimPEG/Tests/test_Survey.py new file mode 100644 index 00000000..f9866bb5 --- /dev/null +++ b/SimPEG/Tests/test_Survey.py @@ -0,0 +1,51 @@ +import unittest +from SimPEG import * + +class DataTest(unittest.TestCase): + + def setUp(self): + mesh = Mesh.TensorMesh([np.ones(n)*5 for n in [10,11,12]],[0,0,-30]) + x = np.linspace(5,10,3) + XYZ = Utils.ndgrid(x,x,np.r_[0.]) + txLoc = np.r_[0,0,0.] + rxList0 = Survey.BaseRx(XYZ, 'exi') + Tx0 = Survey.BaseTx(txLoc, 'VMD', [rxList0]) + rxList1 = Survey.BaseRx(XYZ, 'bxi') + Tx1 = Survey.BaseTx(txLoc, 'VMD', [rxList1]) + rxList2 = Survey.BaseRx(XYZ, 'bxi') + Tx2 = Survey.BaseTx(txLoc, 'VMD', [rxList2]) + rxList3 = Survey.BaseRx(XYZ, 'bxi') + Tx3 = Survey.BaseTx(txLoc, 'VMD', [rxList3]) + Tx4 = Survey.BaseTx(txLoc, 'VMD', [rxList0, rxList1, rxList2, rxList3]) + txList = [Tx0,Tx1,Tx2,Tx3,Tx4] + survey = Survey.BaseSurvey(txList=txList) + self.D = Survey.Data(survey) + self.Tx0 = Tx0 + self.Tx1 = Tx1 + self.mesh = mesh + self.XYZ = XYZ + + def test_data(self): + V = [] + for tx in self.D.survey.txList: + for rx in tx.rxList: + v = np.random.rand(rx.nD) + V += [v] + self.D[tx, rx] = v + self.assertTrue(np.all(v == self.D[tx, rx])) + V = np.concatenate(V) + self.assertTrue(np.all(V == Utils.mkvc(self.D))) + + D2 = Survey.Data(self.D.survey, V) + self.assertTrue(np.all(Utils.mkvc(D2) == Utils.mkvc(self.D))) + + + def test_uniqueTxs(self): + txs = self.D.survey.txList + txs += [txs[0]] + self.assertRaises(AssertionError, Survey.BaseSurvey, txList=txs) + + + +if __name__ == '__main__': + unittest.main()