diff --git a/tests/test_sources.py b/tests/test_sources.py index 8e9bc24b..a5e1e833 100644 --- a/tests/test_sources.py +++ b/tests/test_sources.py @@ -79,6 +79,39 @@ class TestDataFrameSource(TestCase): self.assertTrue(isinstance(event['volume'], (integer_types))) self.assertEqual(next(stocks_iter), event['sid']) + def test_nan_filter_dataframe(self): + dates = pd.date_range('1/1/2000', periods=2, freq='B', tz='UTC') + df = pd.DataFrame(np.random.randn(2, 2), + index=dates, + columns=['A', 'B']) + df.loc[dates[0], 'A'] = np.nan # should be filtered + df.loc[dates[1], 'B'] = np.nan # should not be filtered + source = DataFrameSource(df) + event = next(source) + self.assertEqual('B', event.sid) + event = next(source) + self.assertEqual('A', event.sid) + event = next(source) + self.assertEqual('B', event.sid) + self.assertTrue(np.isnan(event.price)) + + def test_nan_filter_panel(self): + dates = pd.date_range('1/1/2000', periods=2, freq='B', tz='UTC') + df = pd.Panel(np.random.randn(2, 2, 2), + major_axis=dates, + items=['A', 'B'], + minor_axis=['price', 'volume']) + df.loc['A', dates[0], 'price'] = np.nan # should be filtered + df.loc['B', dates[1], 'price'] = np.nan # should not be filtered + source = DataPanelSource(df) + event = next(source) + self.assertEqual('B', event.sid) + event = next(source) + self.assertEqual('A', event.sid) + event = next(source) + self.assertEqual('B', event.sid) + self.assertTrue(np.isnan(event.price)) + class TestRandomWalkSource(TestCase): def test_minute(self): diff --git a/zipline/sources/data_frame_source.py b/zipline/sources/data_frame_source.py index 4dfa225a..7a0771f4 100644 --- a/zipline/sources/data_frame_source.py +++ b/zipline/sources/data_frame_source.py @@ -1,5 +1,5 @@ # -# Copyright 2014 Quantopian, Inc. +# Copyright 2015 Quantopian, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,6 +16,7 @@ """ Tools to generate data sources. """ +import numpy as np import pandas as pd from zipline.gens.utils import hash_args @@ -31,10 +32,8 @@ class DataFrameSource(DataSource): * columns : sids * index : datetime - sids : list of values representing simulated internal sids - start : start date - delta : timedelta between internal events - filter : filter to remove the sids + :Note: + Bars where the price is nan are filtered out. """ def __init__(self, data, **kwargs): @@ -51,6 +50,8 @@ class DataFrameSource(DataSource): self._raw_data = None + self.started_sids = set() + @property def mapping(self): return { @@ -68,6 +69,12 @@ class DataFrameSource(DataSource): for dt, series in self.data.iterrows(): for sid, price in series.iteritems(): if sid in self.sids: + # Skip SIDs that can not be forward filled + if np.isnan(price) and \ + sid not in self.started_sids: + continue + self.started_sids.add(sid) + event = { 'dt': dt, 'sid': sid, @@ -94,10 +101,8 @@ class DataPanelSource(DataSource): * major_axis : datetime * minor_axis : price, volume, ... - sids : list of values representing simulated internal sids - start : start date - delta : timedelta between internal events - filter : filter to remove the sids + :Note: + Bars where the price is nan are filtered out. """ def __init__(self, data, **kwargs): @@ -114,6 +119,8 @@ class DataPanelSource(DataSource): self._raw_data = None + self.started_sids = set() + @property def mapping(self): mapping = { @@ -140,6 +147,12 @@ class DataPanelSource(DataSource): df = self.data.major_xs(dt) for sid, series in df.iteritems(): if sid in self.sids: + # Skip SIDs that can not be forward filled + if np.isnan(series['price']) and \ + sid not in self.started_sids: + continue + self.started_sids.add(sid) + event = { 'dt': dt, 'sid': sid,