From 3f1b0f79f29b040be0d4632e7aa76b01d790e39a Mon Sep 17 00:00:00 2001 From: Jean Bredeche Date: Thu, 5 May 2016 11:54:39 -0400 Subject: [PATCH] DEV: Ensure there are no duplicates in the data passed into TradingAlgorithm.run --- tests/test_panel_daily_bar_reader.py | 46 ++++++++++++++++++++++++++++ zipline/data/us_equity_pricing.py | 13 +++++++- 2 files changed, 58 insertions(+), 1 deletion(-) create mode 100644 tests/test_panel_daily_bar_reader.py diff --git a/tests/test_panel_daily_bar_reader.py b/tests/test_panel_daily_bar_reader.py new file mode 100644 index 00000000..be7b7c15 --- /dev/null +++ b/tests/test_panel_daily_bar_reader.py @@ -0,0 +1,46 @@ +# +# Copyright 2016 Quantopian, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import pandas as pd + +from zipline.data.us_equity_pricing import PanelDailyBarReader +from zipline.testing.fixtures import WithTradingEnvironment, ZiplineTestCase + + +class TestPanelDailyBarReader(WithTradingEnvironment, ZiplineTestCase): + def test_duplicate_values(self): + df = pd.DataFrame() + panel = pd.concat([pd.Panel({"X": df}), pd.Panel({"X": df})]) + + with self.assertRaises(ValueError) as e: + # panel's items has duplicates + PanelDailyBarReader(None, panel) + + self.assertEqual("Duplicated items found: ['X']", + str(e.exception)) + + with self.assertRaises(ValueError) as e: + # panel's major axis has duplicates + PanelDailyBarReader(None, panel.swapaxes(0, 1)) + + self.assertEqual("Duplicated items found: ['X']", + str(e.exception)) + + with self.assertRaises(ValueError) as e: + # panel's minor axis has duplicates + PanelDailyBarReader(None, panel.swapaxes(0, 2)) + + self.assertEqual("Duplicated items found: ['X']", + str(e.exception)) diff --git a/zipline/data/us_equity_pricing.py b/zipline/data/us_equity_pricing.py index fbfb1a06..7637c001 100644 --- a/zipline/data/us_equity_pricing.py +++ b/zipline/data/us_equity_pricing.py @@ -711,6 +711,17 @@ class PanelDailyBarReader(DailyBarReader): The first trading day in the dataset. """ def __init__(self, calendar, panel): + # check duplicates on all indices of panel + + for attr_name in ["items", "major_axis", "minor_axis"]: + index = getattr(panel, attr_name) + duplicates = index.duplicated() + + if duplicates.any(): + raise ValueError("Duplicated items found: {0}".format( + index[duplicates].values + )) + + panel = panel.copy() if 'volume' not in panel.items: # Fake volume if it does not exist. @@ -760,7 +771,7 @@ class PanelDailyBarReader(DailyBarReader): Returns -1 if the day is within the date range, but the price is 0. """ - return self.panel[sid, day, colname] + return self.panel.loc[sid, day, colname] def get_last_traded_dt(self, sid, dt): """