diff --git a/tests/test_sources.py b/tests/test_sources.py index 15d79e46..75de66cd 100644 --- a/tests/test_sources.py +++ b/tests/test_sources.py @@ -14,6 +14,7 @@ # limitations under the License. import pandas as pd import pytz +from itertools import cycle from unittest import TestCase @@ -59,16 +60,16 @@ class TestDataFrameSource(TestCase): start = pd.datetime(1993, 1, 1, 0, 0, 0, 0, pytz.utc) end = pd.datetime(2002, 1, 1, 0, 0, 0, 0, pytz.utc) data = factory.load_bars_from_yahoo(stocks=stocks, + indexes={}, start=start, end=end) + check_fields = ['sid', 'open', 'high', 'low', 'close', + 'volume', 'price'] source = DataPanelSource(data) + stocks_iter = cycle(stocks) for event in source: - self.assertTrue('sid' in event) - self.assertTrue('open' in event) - self.assertTrue('high' in event) - self.assertTrue('low' in event) - self.assertTrue('close' in event) - self.assertTrue('volume' in event) - self.assertTrue('price' in event) + for check_field in check_fields: + self.assertIn(check_field, event) self.assertTrue(isinstance(event['volume'], (int, long))) + self.assertEqual(stocks_iter.next(), event['sid']) diff --git a/zipline/sources/data_frame_source.py b/zipline/sources/data_frame_source.py index 7b0b5bd2..10bfe505 100644 --- a/zipline/sources/data_frame_source.py +++ b/zipline/sources/data_frame_source.py @@ -134,8 +134,9 @@ class DataPanelSource(DataSource): return self.arg_string def raw_data_gen(self): - for sid, dataframe in self.data.iteritems(): - for dt, series in dataframe.iterrows(): + for dt in self.data.major_axis: + df = self.data.major_xs(dt) + for sid, series in df.iterkv(): if sid in self.sids: event = { 'dt': dt,