diff --git a/SimPEG/Fields.py b/SimPEG/Fields.py index f81cfb67..edd4cd92 100644 --- a/SimPEG/Fields.py +++ b/SimPEG/Fields.py @@ -71,13 +71,7 @@ class Fields(object): if type(srcTestList) is slice: ind = srcTestList else: - if type(srcTestList) is not list: - srcTestList = [srcTestList] - for srcTest in srcTestList: - if srcTest not in self.survey.srcList: - raise KeyError('Invalid Source, not in survey list.') - - ind = np.in1d(self.survey.srcList, srcTestList) + ind = self.survey.getSourceIndex(srcTestList) return ind def _nameIndex(self, name, accessType): diff --git a/SimPEG/Survey.py b/SimPEG/Survey.py index b6a17f28..0fdb0cd1 100644 --- a/SimPEG/Survey.py +++ b/SimPEG/Survey.py @@ -231,6 +231,11 @@ class BaseSurvey(object): [self._sourceOrder.setdefault(src.uid, ii) for ii, src in enumerate(self._srcList)] def getSourceIndex(self, sources): + if type(sources) is not list: + sources = [sources] + for src in sources: + if getattr(src,'uid',None) is None: + raise KeyError('Source does not have a uid: %s'%str(src)) inds = map(lambda src: self._sourceOrder.get(src.uid, None), sources) if None in inds: raise KeyError('Some of the sources specified are not in this survey. %s'%str(inds)) diff --git a/SimPEG/Tests/test_SurveyAndData.py b/SimPEG/Tests/test_SurveyAndData.py index 6feccc71..f02f14e8 100644 --- a/SimPEG/Tests/test_SurveyAndData.py +++ b/SimPEG/Tests/test_SurveyAndData.py @@ -49,7 +49,5 @@ class TestData(unittest.TestCase): self.assertRaises(KeyError, survey.getSourceIndex, [SrcNotThere]) self.assertRaises(KeyError, survey.getSourceIndex, [srcs[1],srcs[2],SrcNotThere]) - - if __name__ == '__main__': unittest.main()