mirror of
https://github.com/wassname/simpeg.git
synced 2026-06-27 19:32:36 +08:00
fixes #99
This commit is contained in:
+1
-7
@@ -71,13 +71,7 @@ class Fields(object):
|
||||
if type(srcTestList) is slice:
|
||||
ind = srcTestList
|
||||
else:
|
||||
if type(srcTestList) is not list:
|
||||
srcTestList = [srcTestList]
|
||||
for srcTest in srcTestList:
|
||||
if srcTest not in self.survey.srcList:
|
||||
raise KeyError('Invalid Source, not in survey list.')
|
||||
|
||||
ind = np.in1d(self.survey.srcList, srcTestList)
|
||||
ind = self.survey.getSourceIndex(srcTestList)
|
||||
return ind
|
||||
|
||||
def _nameIndex(self, name, accessType):
|
||||
|
||||
@@ -231,6 +231,11 @@ class BaseSurvey(object):
|
||||
[self._sourceOrder.setdefault(src.uid, ii) for ii, src in enumerate(self._srcList)]
|
||||
|
||||
def getSourceIndex(self, sources):
|
||||
if type(sources) is not list:
|
||||
sources = [sources]
|
||||
for src in sources:
|
||||
if getattr(src,'uid',None) is None:
|
||||
raise KeyError('Source does not have a uid: %s'%str(src))
|
||||
inds = map(lambda src: self._sourceOrder.get(src.uid, None), sources)
|
||||
if None in inds:
|
||||
raise KeyError('Some of the sources specified are not in this survey. %s'%str(inds))
|
||||
|
||||
@@ -49,7 +49,5 @@ class TestData(unittest.TestCase):
|
||||
self.assertRaises(KeyError, survey.getSourceIndex, [SrcNotThere])
|
||||
self.assertRaises(KeyError, survey.getSourceIndex, [srcs[1],srcs[2],SrcNotThere])
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user