class TestPanelDailyBarReader(ZiplineTestCase):
    """Construction-time validation checks for ``PanelDailyBarReader``."""

    def test_duplicate_values(self):
        """
        A panel with repeated labels on an axis must be rejected.

        The repeated labels start out on ``items``; every permutation of
        the three axes is tried so that the duplicates are seen on each
        axis position in turn, and the error message must name the axis
        that ends up holding them.
        """
        # Arbitrary fill; the test only cares about the axis labels.
        fill_value = 57
        source_panel = pd.Panel(
            fill_value,
            items=['a', 'b', 'b', 'a'],
            major_axis=['c'],
            minor_axis=['d'],
        )

        # Stand-in for the calendar argument, which this test never needs.
        # NOTE(review): presumably fails loudly if touched -- confirm
        # against zipline.testing.ExplodingObject.
        calendar = ExplodingObject()

        positional_names = ('items', 'major_axis', 'minor_axis')
        message_template = "Duplicate entries in Panel.{name}: ['a', 'b']."

        for axis_order in permutations(range(3)):
            # After transposing, original axis 0 (the one holding the
            # duplicates) sits at the position where 0 appears in the
            # permutation.
            duplicated_axis = positional_names[axis_order.index(0)]

            with self.assertRaises(ValueError) as e:
                PanelDailyBarReader(
                    calendar,
                    source_panel.transpose(*axis_order),
                )

            self.assertEqual(
                str(e.exception),
                message_template.format(name=duplicated_axis),
            )
def verify_indices_all_unique(obj):
    """
    Check that all axes of a pandas object are unique.

    Parameters
    ----------
    obj : pd.Series / pd.DataFrame / pd.Panel
        The object to validate.

    Returns
    -------
    obj : pd.Series / pd.DataFrame / pd.Panel
        The validated object, unchanged.  Returning it (rather than
        ``None``) lets this function be used as a value-producing
        preprocessor, e.g. ``preprocess(panel=call(verify_indices_all_unique))``.

    Raises
    ------
    ValueError
        If any axis has duplicate entries.  The message names the first
        offending axis and lists each duplicated label once, sorted.
    """
    # Axis attribute names keyed by dimensionality; ``ndim - 1`` maps a
    # 1-d Series to entry 0, a 2-d DataFrame to entry 1, and so on.
    axis_names = [
        ('index',),                             # Series
        ('index', 'columns'),                   # DataFrame
        ('items', 'major_axis', 'minor_axis'),  # Panel
    ][obj.ndim - 1]

    for axis_name, index in zip(axis_names, obj.axes):
        if index.is_unique:
            continue

        # ``duplicated()`` flags every occurrence after the first, so a
        # label present three or more times would otherwise be listed
        # repeatedly; collapse to distinct labels before sorting.
        raise ValueError(
            "Duplicate entries in {type}.{axis}: {dupes}.".format(
                type=type(obj).__name__,
                axis=axis_name,
                dupes=sorted(set(index[index.duplicated()])),
            )
        )

    # NOTE(review): the ``call`` adapter is defined outside this view; it
    # appears to substitute this return value for the decorated argument,
    # so the object must be handed back for valid input to pass through.
    return obj