This commit is contained in:
Rowan Cockett
2015-05-29 11:17:56 -07:00
parent 59fcd3925f
commit de27c4e4ec
3 changed files with 6 additions and 9 deletions
+1 -7
View File
@@ -71,13 +71,7 @@ class Fields(object):
if type(srcTestList) is slice:
ind = srcTestList
else:
if type(srcTestList) is not list:
srcTestList = [srcTestList]
for srcTest in srcTestList:
if srcTest not in self.survey.srcList:
raise KeyError('Invalid Source, not in survey list.')
ind = np.in1d(self.survey.srcList, srcTestList)
ind = self.survey.getSourceIndex(srcTestList)
return ind
def _nameIndex(self, name, accessType):
+5
View File
@@ -231,6 +231,11 @@ class BaseSurvey(object):
[self._sourceOrder.setdefault(src.uid, ii) for ii, src in enumerate(self._srcList)]
def getSourceIndex(self, sources):
if type(sources) is not list:
sources = [sources]
for src in sources:
if getattr(src,'uid',None) is None:
raise KeyError('Source does not have a uid: %s'%str(src))
inds = map(lambda src: self._sourceOrder.get(src.uid, None), sources)
if None in inds:
raise KeyError('Some of the sources specified are not in this survey. %s'%str(inds))
-2
View File
@@ -49,7 +49,5 @@ class TestData(unittest.TestCase):
self.assertRaises(KeyError, survey.getSourceIndex, [SrcNotThere])
self.assertRaises(KeyError, survey.getSourceIndex, [srcs[1],srcs[2],SrcNotThere])
if __name__ == '__main__':
unittest.main()