From 14ee13fadb340ddbbc56e663aca1bc7c6b9e04f1 Mon Sep 17 00:00:00 2001 From: Rowan Cockett Date: Fri, 29 May 2015 10:18:26 -0700 Subject: [PATCH 1/6] remove redundant mkvc option --- SimPEG/Problem.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/SimPEG/Problem.py b/SimPEG/Problem.py index e729ee10..87111b31 100644 --- a/SimPEG/Problem.py +++ b/SimPEG/Problem.py @@ -263,7 +263,7 @@ class TimeFields(Fields): for i, TIND_i in enumerate(timeII): fieldI = pointerFields[:,:,i] if fieldI.shape[0] == fieldI.size: - fieldI = Utils.mkvc(fieldI,1) + fieldI = Utils.mkvc(fieldI) out[i] = func(fieldI, srcII, TIND_i) if out[i].ndim == 1: out[i] = out[i][:,np.newaxis,np.newaxis] From 7e171ede05a06844b6323e09711ba6be9b4ec95d Mon Sep 17 00:00:00 2001 From: Rowan Cockett Date: Fri, 29 May 2015 10:26:53 -0700 Subject: [PATCH 2/6] Move Fields Objects to their own file. --- SimPEG/Fields.py | 273 +++++++++++++++++++++++++++++++++++++++++++++ SimPEG/Problem.py | 275 +--------------------------------------------- 2 files changed, 274 insertions(+), 274 deletions(-) create mode 100644 SimPEG/Fields.py diff --git a/SimPEG/Fields.py b/SimPEG/Fields.py new file mode 100644 index 00000000..f81cfb67 --- /dev/null +++ b/SimPEG/Fields.py @@ -0,0 +1,273 @@ +import Utils, numpy as np, scipy.sparse as sp + +class Fields(object): + """Fancy Field Storage + + u[:,'phi'] = phi + print u[src0,'phi'] + + """ + + knownFields = None #: Known fields, a dict with locations, e.g. {"e": "E", "phi": "CC"} + aliasFields = None #: Aliased fields, a dict with [alias, location, function], e.g. {"b":["e","F",lambda(F,e,ind)]} + dtype = float #: dtype is the type of the storage matrix. This can be a dictionary. + + def __init__(self, mesh, survey, **kwargs): + self.survey = survey + self.mesh = mesh + Utils.setKwargs(self, **kwargs) + self._fields = {} + + if self.knownFields is None: + raise Exception('knownFields cannot be set to None') + if self.aliasFields is None: + self.aliasFields = {} + + allFields = [k for k in self.knownFields] + [a for a in self.aliasFields] + assert len(allFields) == len(set(allFields)), 'Aliased fields and Known Fields have overlapping definitions.' + self.startup() + + def startup(self): + pass + + @property + def approxSize(self): + """The approximate cost to storing all of the known fields.""" + sz = 0.0 + for f in self.knownFields: + loc =self.knownFields[f] + sz += np.array(self._storageShape(loc)).prod()*8.0/(1024**2) + return "%e MB"%sz + + def _storageShape(self, loc): + nSrc = self.survey.nSrc + + nP = {'CC': self.mesh.nC, + 'N': self.mesh.nN, + 'F': self.mesh.nF, + 'E': self.mesh.nE}[loc] + + return (nP, nSrc) + + def _initStore(self, name): + if name in self._fields: + return self._fields[name] + + assert name in self.knownFields, 'field name is not known.' + + loc = self.knownFields[name] + + if type(self.dtype) is dict: + dtype = self.dtype[name] + else: + dtype = self.dtype + field = np.zeros(self._storageShape(loc), dtype=dtype) + + self._fields[name] = field + + return field + + def _srcIndex(self, srcTestList): + if type(srcTestList) is slice: + ind = srcTestList + else: + if type(srcTestList) is not list: + srcTestList = [srcTestList] + for srcTest in srcTestList: + if srcTest not in self.survey.srcList: + raise KeyError('Invalid Source, not in survey list.') + + ind = np.in1d(self.survey.srcList, srcTestList) + return ind + + def _nameIndex(self, name, accessType): + + if type(name) is slice: + assert name == slice(None,None,None), 'Fancy field name slicing is not supported... yet.' + name = None + + if name is None: + return + if accessType=='set' and name not in self.knownFields: + if name in self.aliasFields: + raise KeyError("Invalid field name (%s) for setter, you can't set an aliased property"%name) + else: + raise KeyError('Invalid field name (%s) for setter'%name) + + elif accessType=='get' and (name not in self.knownFields and name not in self.aliasFields): + raise KeyError('Invalid field name (%s) for getter'%name) + return name + + def _indexAndNameFromKey(self, key, accessType): + if type(key) is not tuple: + key = (key,) + if len(key) == 1: + key += (None,) + + assert len(key) == 2, 'must be [Src, fieldName]' + + srcTestList, name = key + name = self._nameIndex(name, accessType) + ind = self._srcIndex(srcTestList) + return ind, name + + def __setitem__(self, key, value): + ind, name = self._indexAndNameFromKey(key, 'set') + if name is None: + freq = key + assert type(value) is dict, 'New fields must be a dictionary, if field is not specified.' + newFields = value + elif name in self.knownFields: + newFields = {name: value} + else: + raise Exception('Unknown setter') + + for name in newFields: + field = self._initStore(name) + self._setField(field, newFields[name], name, ind) + + def __getitem__(self, key): + ind, name = self._indexAndNameFromKey(key, 'get') + if name is None: + out = {} + for name in self._fields: + out[name] = self._getField(name, ind) + return out + return self._getField(name, ind) + + def _setField(self, field, val, name, ind): + if isinstance(val, np.ndarray) and (field.shape[0] == field.size or val.ndim == 1): + val = Utils.mkvc(val,2) + field[:,ind] = val + + def _getField(self, name, ind): + if name in self._fields: + out = self._fields[name][:,ind] + else: + # Aliased fields + alias, loc, func = self.aliasFields[name] + + srcII = np.array(self.survey.srcList)[ind] + if isinstance(srcII, np.ndarray): + srcII = srcII.tolist() + if len(srcII) == 1: + srcII = srcII[0] + + if type(func) is str: + assert hasattr(self, func), 'The alias field function is a string, but it does not exist in the Fields class.' + func = getattr(self, func) + out = func(self._fields[alias][:,ind], srcII) + if out.shape[0] == out.size or out.ndim == 1: + out = Utils.mkvc(out,2) + return out + + def __contains__(self, other): + if other in self.aliasFields: + other = self.aliasFields[other][0] + return self._fields.__contains__(other) + + +class TimeFields(Fields): + """Fancy Field Storage for time domain problems + + u[:,'phi', timeInd] = phi + print u[src0,'phi'] + + """ + + def _storageShape(self, loc): + nP = {'CC': self.mesh.nC, + 'N': self.mesh.nN, + 'F': self.mesh.nF, + 'E': self.mesh.nE}[loc] + nSrc = self.survey.nSrc + nT = self.survey.prob.nT + 1 + return (nP, nSrc, nT) + + def _indexAndNameFromKey(self, key, accessType): + if type(key) is not tuple: + key = (key,) + if len(key) == 1: + key += (None,) + if len(key) == 2: + key += (slice(None,None,None),) + + assert len(key) == 3, 'must be [Src, fieldName, times]' + + srcTestList, name, timeInd = key + + name = self._nameIndex(name, accessType) + srcInd = self._srcIndex(srcTestList) + + return (srcInd, timeInd), name + + def _correctShape(self, name, ind, deflate=False): + srcInd, timeInd = ind + if name in self.knownFields: + loc = self.knownFields[name] + else: + loc = self.aliasFields[name][1] + nP, total_nSrc, total_nT = self._storageShape(loc) + nSrc = np.ones(total_nSrc, dtype=bool)[srcInd].sum() + nT = np.ones(total_nT, dtype=bool)[timeInd].sum() + shape = nP, nSrc, nT + if deflate: + shape = tuple([s for s in shape if s > 1]) + if len(shape) == 1: + shape = shape + (1,) + return shape + + def _setField(self, field, val, name, ind): + srcInd, timeInd = ind + shape = self._correctShape(name, ind) + if Utils.isScalar(val): + field[:,srcInd,timeInd] = val + return + if val.size != np.array(shape).prod(): + raise ValueError('Incorrect size for data.') + correctShape = field[:,srcInd,timeInd].shape + field[:,srcInd,timeInd] = val.reshape(correctShape, order='F') + + def _getField(self, name, ind): + srcInd, timeInd = ind + + if name in self._fields: + out = self._fields[name][:,srcInd,timeInd] + else: + # Aliased fields + alias, loc, func = self.aliasFields[name] + if type(func) is str: + assert hasattr(self, func), 'The alias field function is a string, but it does not exist in the Fields class.' + func = getattr(self, func) + pointerFields = self._fields[alias][:,srcInd,timeInd] + pointerShape = self._correctShape(alias, ind) + pointerFields = pointerFields.reshape(pointerShape, order='F') + + timeII = np.arange(self.survey.prob.nT + 1)[timeInd] + srcII = np.array(self.survey.srcList)[srcInd] + if isinstance(srcII, np.ndarray): + srcII = srcII.tolist() + if len(srcII) == 1: + srcII = srcII[0] + + if timeII.size == 1: + pointerShapeDeflated = self._correctShape(alias, ind, deflate=True) + pointerFields = pointerFields.reshape(pointerShapeDeflated, order='F') + out = func(pointerFields, srcII, timeII) + else: #loop over the time steps + nT = pointerShape[2] + out = range(nT) + for i, TIND_i in enumerate(timeII): + fieldI = pointerFields[:,:,i] + if fieldI.shape[0] == fieldI.size: + fieldI = Utils.mkvc(fieldI) + out[i] = func(fieldI, srcII, TIND_i) + if out[i].ndim == 1: + out[i] = out[i][:,np.newaxis,np.newaxis] + elif out[i].ndim == 2: + out[i] = out[i][:,:,np.newaxis] + out = np.concatenate(out, axis=2) + + shape = self._correctShape(name, ind, deflate=True) + return out.reshape(shape, order='F') + diff --git a/SimPEG/Problem.py b/SimPEG/Problem.py index 87111b31..75e9d4bf 100644 --- a/SimPEG/Problem.py +++ b/SimPEG/Problem.py @@ -1,280 +1,7 @@ import Utils, Survey, Models, numpy as np, scipy.sparse as sp Solver = Utils.SolverUtils.Solver import Maps, Mesh - - -class Fields(object): - """Fancy Field Storage - - u[:,'phi'] = phi - print u[src0,'phi'] - - """ - - knownFields = None #: Known fields, a dict with locations, e.g. {"e": "E", "phi": "CC"} - aliasFields = None #: Aliased fields, a dict with [alias, location, function], e.g. {"b":["e","F",lambda(F,e,ind)]} - dtype = float #: dtype is the type of the storage matrix. This can be a dictionary. - - def __init__(self, mesh, survey, **kwargs): - self.survey = survey - self.mesh = mesh - Utils.setKwargs(self, **kwargs) - self._fields = {} - - if self.knownFields is None: - raise Exception('knownFields cannot be set to None') - if self.aliasFields is None: - self.aliasFields = {} - - allFields = [k for k in self.knownFields] + [a for a in self.aliasFields] - assert len(allFields) == len(set(allFields)), 'Aliased fields and Known Fields have overlapping definitions.' - self.startup() - - def startup(self): - pass - - @property - def approxSize(self): - """The approximate cost to storing all of the known fields.""" - sz = 0.0 - for f in self.knownFields: - loc =self.knownFields[f] - sz += np.array(self._storageShape(loc)).prod()*8.0/(1024**2) - return "%e MB"%sz - - def _storageShape(self, loc): - nSrc = self.survey.nSrc - - nP = {'CC': self.mesh.nC, - 'N': self.mesh.nN, - 'F': self.mesh.nF, - 'E': self.mesh.nE}[loc] - - return (nP, nSrc) - - def _initStore(self, name): - if name in self._fields: - return self._fields[name] - - assert name in self.knownFields, 'field name is not known.' - - loc = self.knownFields[name] - - if type(self.dtype) is dict: - dtype = self.dtype[name] - else: - dtype = self.dtype - field = np.zeros(self._storageShape(loc), dtype=dtype) - - self._fields[name] = field - - return field - - def _srcIndex(self, srcTestList): - if type(srcTestList) is slice: - ind = srcTestList - else: - if type(srcTestList) is not list: - srcTestList = [srcTestList] - for srcTest in srcTestList: - if srcTest not in self.survey.srcList: - raise KeyError('Invalid Source, not in survey list.') - - ind = np.in1d(self.survey.srcList, srcTestList) - return ind - - def _nameIndex(self, name, accessType): - - if type(name) is slice: - assert name == slice(None,None,None), 'Fancy field name slicing is not supported... yet.' - name = None - - if name is None: - return - if accessType=='set' and name not in self.knownFields: - if name in self.aliasFields: - raise KeyError("Invalid field name (%s) for setter, you can't set an aliased property"%name) - else: - raise KeyError('Invalid field name (%s) for setter'%name) - - elif accessType=='get' and (name not in self.knownFields and name not in self.aliasFields): - raise KeyError('Invalid field name (%s) for getter'%name) - return name - - def _indexAndNameFromKey(self, key, accessType): - if type(key) is not tuple: - key = (key,) - if len(key) == 1: - key += (None,) - - assert len(key) == 2, 'must be [Src, fieldName]' - - srcTestList, name = key - name = self._nameIndex(name, accessType) - ind = self._srcIndex(srcTestList) - return ind, name - - def __setitem__(self, key, value): - ind, name = self._indexAndNameFromKey(key, 'set') - if name is None: - freq = key - assert type(value) is dict, 'New fields must be a dictionary, if field is not specified.' - newFields = value - elif name in self.knownFields: - newFields = {name: value} - else: - raise Exception('Unknown setter') - - for name in newFields: - field = self._initStore(name) - self._setField(field, newFields[name], name, ind) - - def __getitem__(self, key): - ind, name = self._indexAndNameFromKey(key, 'get') - if name is None: - out = {} - for name in self._fields: - out[name] = self._getField(name, ind) - return out - return self._getField(name, ind) - - def _setField(self, field, val, name, ind): - if isinstance(val, np.ndarray) and (field.shape[0] == field.size or val.ndim == 1): - val = Utils.mkvc(val,2) - field[:,ind] = val - - def _getField(self, name, ind): - if name in self._fields: - out = self._fields[name][:,ind] - else: - # Aliased fields - alias, loc, func = self.aliasFields[name] - - srcII = np.array(self.survey.srcList)[ind] - if isinstance(srcII, np.ndarray): - srcII = srcII.tolist() - if len(srcII) == 1: - srcII = srcII[0] - - if type(func) is str: - assert hasattr(self, func), 'The alias field function is a string, but it does not exist in the Fields class.' - func = getattr(self, func) - out = func(self._fields[alias][:,ind], srcII) - if out.shape[0] == out.size or out.ndim == 1: - out = Utils.mkvc(out,2) - return out - - def __contains__(self, other): - if other in self.aliasFields: - other = self.aliasFields[other][0] - return self._fields.__contains__(other) - - -class TimeFields(Fields): - """Fancy Field Storage for time domain problems - - u[:,'phi', timeInd] = phi - print u[src0,'phi'] - - """ - - def _storageShape(self, loc): - nP = {'CC': self.mesh.nC, - 'N': self.mesh.nN, - 'F': self.mesh.nF, - 'E': self.mesh.nE}[loc] - nSrc = self.survey.nSrc - nT = self.survey.prob.nT + 1 - return (nP, nSrc, nT) - - def _indexAndNameFromKey(self, key, accessType): - if type(key) is not tuple: - key = (key,) - if len(key) == 1: - key += (None,) - if len(key) == 2: - key += (slice(None,None,None),) - - assert len(key) == 3, 'must be [Src, fieldName, times]' - - srcTestList, name, timeInd = key - - name = self._nameIndex(name, accessType) - srcInd = self._srcIndex(srcTestList) - - return (srcInd, timeInd), name - - def _correctShape(self, name, ind, deflate=False): - srcInd, timeInd = ind - if name in self.knownFields: - loc = self.knownFields[name] - else: - loc = self.aliasFields[name][1] - nP, total_nSrc, total_nT = self._storageShape(loc) - nSrc = np.ones(total_nSrc, dtype=bool)[srcInd].sum() - nT = np.ones(total_nT, dtype=bool)[timeInd].sum() - shape = nP, nSrc, nT - if deflate: - shape = tuple([s for s in shape if s > 1]) - if len(shape) == 1: - shape = shape + (1,) - return shape - - def _setField(self, field, val, name, ind): - srcInd, timeInd = ind - shape = self._correctShape(name, ind) - if Utils.isScalar(val): - field[:,srcInd,timeInd] = val - return - if val.size != np.array(shape).prod(): - raise ValueError('Incorrect size for data.') - correctShape = field[:,srcInd,timeInd].shape - field[:,srcInd,timeInd] = val.reshape(correctShape, order='F') - - def _getField(self, name, ind): - srcInd, timeInd = ind - - if name in self._fields: - out = self._fields[name][:,srcInd,timeInd] - else: - # Aliased fields - alias, loc, func = self.aliasFields[name] - if type(func) is str: - assert hasattr(self, func), 'The alias field function is a string, but it does not exist in the Fields class.' - func = getattr(self, func) - pointerFields = self._fields[alias][:,srcInd,timeInd] - pointerShape = self._correctShape(alias, ind) - pointerFields = pointerFields.reshape(pointerShape, order='F') - - timeII = np.arange(self.survey.prob.nT + 1)[timeInd] - srcII = np.array(self.survey.srcList)[srcInd] - if isinstance(srcII, np.ndarray): - srcII = srcII.tolist() - if len(srcII) == 1: - srcII = srcII[0] - - if timeII.size == 1: - pointerShapeDeflated = self._correctShape(alias, ind, deflate=True) - pointerFields = pointerFields.reshape(pointerShapeDeflated, order='F') - out = func(pointerFields, srcII, timeII) - else: #loop over the time steps - nT = pointerShape[2] - out = range(nT) - for i, TIND_i in enumerate(timeII): - fieldI = pointerFields[:,:,i] - if fieldI.shape[0] == fieldI.size: - fieldI = Utils.mkvc(fieldI) - out[i] = func(fieldI, srcII, TIND_i) - if out[i].ndim == 1: - out[i] = out[i][:,np.newaxis,np.newaxis] - elif out[i].ndim == 2: - out[i] = out[i][:,:,np.newaxis] - out = np.concatenate(out, axis=2) - - shape = self._correctShape(name, ind, deflate=True) - return out.reshape(shape, order='F') - - +from Fields import Fields, TimeFields class BaseProblem(object): """ From 116f7620a6beeee411d53a6ac68b75bd686fc528 Mon Sep 17 00:00:00 2001 From: Rowan Cockett Date: Fri, 29 May 2015 10:32:48 -0700 Subject: [PATCH 3/6] Split up the tests --- .../Tests/{test_Survey.py => test_Fields.py} | 31 +++---------- SimPEG/Tests/test_SurveyAndData.py | 45 +++++++++++++++++++ 2 files changed, 51 insertions(+), 25 deletions(-) rename SimPEG/Tests/{test_Survey.py => test_Fields.py} (95%) create mode 100644 SimPEG/Tests/test_SurveyAndData.py diff --git a/SimPEG/Tests/test_Survey.py b/SimPEG/Tests/test_Fields.py similarity index 95% rename from SimPEG/Tests/test_Survey.py rename to SimPEG/Tests/test_Fields.py index efc1e901..22189a15 100644 --- a/SimPEG/Tests/test_Survey.py +++ b/SimPEG/Tests/test_Fields.py @@ -1,7 +1,8 @@ import unittest from SimPEG import * -class DataAndFieldsTest(unittest.TestCase): + +class FieldsTest(unittest.TestCase): def setUp(self): mesh = Mesh.TensorMesh([np.ones(n)*5 for n in [10,11,12]],[0,0,-30]) @@ -26,25 +27,6 @@ class DataAndFieldsTest(unittest.TestCase): self.mesh = mesh self.XYZ = XYZ - def test_overlappingFields(self): - self.assertRaises(AssertionError, Problem.Fields, self.F.mesh, self.F.survey, - knownFields={'b':'F'}, - aliasFields={'b':['b',(lambda F, b, ind: b)]}) - - def test_data(self): - V = [] - for src in self.D.survey.srcList: - for rx in src.rxList: - v = np.random.rand(rx.nD) - V += [v] - self.D[src, rx] = v - self.assertTrue(np.all(v == self.D[src, rx])) - V = np.concatenate(V) - self.assertTrue(np.all(V == Utils.mkvc(self.D))) - - D2 = Survey.Data(self.D.survey, V) - self.assertTrue(np.all(Utils.mkvc(D2) == Utils.mkvc(self.D))) - def test_contains(self): F = self.F nSrc = F.survey.nSrc @@ -55,10 +37,10 @@ class DataAndFieldsTest(unittest.TestCase): self.assertTrue('b' not in F) self.assertTrue('e' in F) - def test_uniqueSrcs(self): - srcs = self.D.survey.srcList - srcs += [srcs[0]] - self.assertRaises(AssertionError, Survey.BaseSurvey, srcList=srcs) + def test_overlappingFields(self): + self.assertRaises(AssertionError, Problem.Fields, self.F.mesh, self.F.survey, + knownFields={'b':'F'}, + aliasFields={'b':['b',(lambda F, b, ind: b)]}) def test_SetGet(self): F = self.F @@ -132,7 +114,6 @@ class FieldsTest_Alias(unittest.TestCase): Src4 = Survey.BaseSrc([rxList0, rxList1, rxList2, rxList3],loc=srcLoc) srcList = [Src0,Src1,Src2,Src3,Src4] survey = Survey.BaseSurvey(srcList=srcList) - self.D = Survey.Data(survey) self.F = Problem.Fields(mesh, survey, knownFields={'e':'E'}, aliasFields={'b':['e','F',(lambda e, ind: self.F.mesh.edgeCurl * e)]}) self.Src0 = Src0 self.Src1 = Src1 diff --git a/SimPEG/Tests/test_SurveyAndData.py b/SimPEG/Tests/test_SurveyAndData.py new file mode 100644 index 00000000..bb6b4645 --- /dev/null +++ b/SimPEG/Tests/test_SurveyAndData.py @@ -0,0 +1,45 @@ +import unittest +from SimPEG import * + +class TestData(unittest.TestCase): + + def setUp(self): + mesh = Mesh.TensorMesh([np.ones(n)*5 for n in [10,11,12]],[0,0,-30]) + x = np.linspace(5,10,3) + XYZ = Utils.ndgrid(x,x,np.r_[0.]) + srcLoc = np.r_[0,0,0.] + rxList0 = Survey.BaseRx(XYZ, 'exi') + Src0 = Survey.BaseSrc([rxList0], loc=srcLoc) + rxList1 = Survey.BaseRx(XYZ, 'bxi') + Src1 = Survey.BaseSrc([rxList1], loc=srcLoc) + rxList2 = Survey.BaseRx(XYZ, 'bxi') + Src2 = Survey.BaseSrc([rxList2], loc=srcLoc) + rxList3 = Survey.BaseRx(XYZ, 'bxi') + Src3 = Survey.BaseSrc([rxList3], loc=srcLoc) + Src4 = Survey.BaseSrc([rxList0, rxList1, rxList2, rxList3], loc=srcLoc) + srcList = [Src0,Src1,Src2,Src3,Src4] + survey = Survey.BaseSurvey(srcList=srcList) + self.D = Survey.Data(survey) + + def test_data(self): + V = [] + for src in self.D.survey.srcList: + for rx in src.rxList: + v = np.random.rand(rx.nD) + V += [v] + self.D[src, rx] = v + self.assertTrue(np.all(v == self.D[src, rx])) + V = np.concatenate(V) + self.assertTrue(np.all(V == Utils.mkvc(self.D))) + + D2 = Survey.Data(self.D.survey, V) + self.assertTrue(np.all(Utils.mkvc(D2) == Utils.mkvc(self.D))) + + def test_uniqueSrcs(self): + srcs = self.D.survey.srcList + srcs += [srcs[0]] + self.assertRaises(AssertionError, Survey.BaseSurvey, srcList=srcs) + + +if __name__ == '__main__': + unittest.main() From 59fcd3925f1df6ca74ec9672cde0b57b2d247390 Mon Sep 17 00:00:00 2001 From: Rowan Cockett Date: Fri, 29 May 2015 11:03:47 -0700 Subject: [PATCH 4/6] getSourceIndex --- SimPEG/Survey.py | 14 ++++++++++++-- SimPEG/Tests/test_SurveyAndData.py | 10 ++++++++++ 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/SimPEG/Survey.py b/SimPEG/Survey.py index 8abb547a..b6a17f28 100644 --- a/SimPEG/Survey.py +++ b/SimPEG/Survey.py @@ -1,4 +1,4 @@ -import Utils, numpy as np, scipy.sparse as sp +import Utils, numpy as np, scipy.sparse as sp, uuid class BaseRx(object): @@ -13,6 +13,7 @@ class BaseRx(object): storeProjections = True #: Store calls to getP (organized by mesh) def __init__(self, locs, rxType, **kwargs): + self.uid = str(uuid.uuid4()) self.locs = locs self.rxType = rxType self._Ps = {} @@ -124,7 +125,7 @@ class BaseSrc(object): for rx in rxList: assert isinstance(rx, self.rxPair), 'rxList must be a %s'%self.rxPair.__name__ assert len(set(rxList)) == len(rxList), 'The rxList must be unique' - + self.uid = str(uuid.uuid4()) self.rxList = rxList Utils.setKwargs(self, **kwargs) @@ -144,6 +145,7 @@ class Data(object): """Fancy data storage by Src and Rx""" def __init__(self, survey, v=None): + self.uid = str(uuid.uuid4()) self.survey = survey self._dataDict = {} for src in self.survey.srcList: @@ -225,6 +227,14 @@ class BaseSurvey(object): assert np.all([isinstance(src, self.srcPair) for src in value]), 'All sources must be instances of %s' % self.srcPair.__name__ assert len(set(value)) == len(value), 'The srcList must be unique' self._srcList = value + self._sourceOrder = dict() + [self._sourceOrder.setdefault(src.uid, ii) for ii, src in enumerate(self._srcList)] + + def getSourceIndex(self, sources): + inds = map(lambda src: self._sourceOrder.get(src.uid, None), sources) + if None in inds: + raise KeyError('Some of the sources specified are not in this survey. %s'%str(inds)) + return inds @property def prob(self): diff --git a/SimPEG/Tests/test_SurveyAndData.py b/SimPEG/Tests/test_SurveyAndData.py index bb6b4645..6feccc71 100644 --- a/SimPEG/Tests/test_SurveyAndData.py +++ b/SimPEG/Tests/test_SurveyAndData.py @@ -40,6 +40,16 @@ class TestData(unittest.TestCase): srcs += [srcs[0]] self.assertRaises(AssertionError, Survey.BaseSurvey, srcList=srcs) + def test_sourceIndex(self): + survey = self.D.survey + srcs = survey.srcList + assert survey.getSourceIndex([srcs[1],srcs[0]]) == [1,0] + assert survey.getSourceIndex([srcs[1],srcs[2],srcs[2]]) == [1,2,2] + SrcNotThere = Survey.BaseSrc(srcs[0].rxList, loc=np.r_[0,0,0]) + self.assertRaises(KeyError, survey.getSourceIndex, [SrcNotThere]) + self.assertRaises(KeyError, survey.getSourceIndex, [srcs[1],srcs[2],SrcNotThere]) + + if __name__ == '__main__': unittest.main() From de27c4e4ec94bcd18b5cb27bed81bf3d86aa3347 Mon Sep 17 00:00:00 2001 From: Rowan Cockett Date: Fri, 29 May 2015 11:17:56 -0700 Subject: [PATCH 5/6] fixes #99 --- SimPEG/Fields.py | 8 +------- SimPEG/Survey.py | 5 +++++ SimPEG/Tests/test_SurveyAndData.py | 2 -- 3 files changed, 6 insertions(+), 9 deletions(-) diff --git a/SimPEG/Fields.py b/SimPEG/Fields.py index f81cfb67..edd4cd92 100644 --- a/SimPEG/Fields.py +++ b/SimPEG/Fields.py @@ -71,13 +71,7 @@ class Fields(object): if type(srcTestList) is slice: ind = srcTestList else: - if type(srcTestList) is not list: - srcTestList = [srcTestList] - for srcTest in srcTestList: - if srcTest not in self.survey.srcList: - raise KeyError('Invalid Source, not in survey list.') - - ind = np.in1d(self.survey.srcList, srcTestList) + ind = self.survey.getSourceIndex(srcTestList) return ind def _nameIndex(self, name, accessType): diff --git a/SimPEG/Survey.py b/SimPEG/Survey.py index b6a17f28..0fdb0cd1 100644 --- a/SimPEG/Survey.py +++ b/SimPEG/Survey.py @@ -231,6 +231,11 @@ class BaseSurvey(object): [self._sourceOrder.setdefault(src.uid, ii) for ii, src in enumerate(self._srcList)] def getSourceIndex(self, sources): + if type(sources) is not list: + sources = [sources] + for src in sources: + if getattr(src,'uid',None) is None: + raise KeyError('Source does not have a uid: %s'%str(src)) inds = map(lambda src: self._sourceOrder.get(src.uid, None), sources) if None in inds: raise KeyError('Some of the sources specified are not in this survey. %s'%str(inds)) diff --git a/SimPEG/Tests/test_SurveyAndData.py b/SimPEG/Tests/test_SurveyAndData.py index 6feccc71..f02f14e8 100644 --- a/SimPEG/Tests/test_SurveyAndData.py +++ b/SimPEG/Tests/test_SurveyAndData.py @@ -49,7 +49,5 @@ class TestData(unittest.TestCase): self.assertRaises(KeyError, survey.getSourceIndex, [SrcNotThere]) self.assertRaises(KeyError, survey.getSourceIndex, [srcs[1],srcs[2],SrcNotThere]) - - if __name__ == '__main__': unittest.main() From 2827e85330312d11b01cb4daf0350a997e9a82a5 Mon Sep 17 00:00:00 2001 From: Rowan Cockett Date: Fri, 29 May 2015 11:36:56 -0700 Subject: [PATCH 6/6] fix mkvc to return an (n,1) array for consistency --- SimPEG/Fields.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/SimPEG/Fields.py b/SimPEG/Fields.py index edd4cd92..1801bedc 100644 --- a/SimPEG/Fields.py +++ b/SimPEG/Fields.py @@ -254,7 +254,7 @@ class TimeFields(Fields): for i, TIND_i in enumerate(timeII): fieldI = pointerFields[:,:,i] if fieldI.shape[0] == fieldI.size: - fieldI = Utils.mkvc(fieldI) + fieldI = Utils.mkvc(fieldI, 2) out[i] = func(fieldI, srcII, TIND_i) if out[i].ndim == 1: out[i] = out[i][:,np.newaxis,np.newaxis]