From d578d5825ec925037ddea10fff652ffbaf32153b Mon Sep 17 00:00:00 2001 From: Thomas Wiecki Date: Fri, 2 Jan 2015 13:42:59 +0100 Subject: [PATCH] BUG: Filter nans in DataFrame and Panel sources. If a SID hasn't started trading yet, pandas' convention is to use nans. Before this change, zipline would raise an exception if there were nans in the input data. We now skip events where the prices contains a nan and has not been traded before (in which case forward fill). Fixes #446. --- tests/test_sources.py | 33 ++++++++++++++++++++++++++++ zipline/sources/data_frame_source.py | 31 ++++++++++++++++++-------- 2 files changed, 55 insertions(+), 9 deletions(-) diff --git a/tests/test_sources.py b/tests/test_sources.py index 8e9bc24b..a5e1e833 100644 --- a/tests/test_sources.py +++ b/tests/test_sources.py @@ -79,6 +79,39 @@ class TestDataFrameSource(TestCase): self.assertTrue(isinstance(event['volume'], (integer_types))) self.assertEqual(next(stocks_iter), event['sid']) + def test_nan_filter_dataframe(self): + dates = pd.date_range('1/1/2000', periods=2, freq='B', tz='UTC') + df = pd.DataFrame(np.random.randn(2, 2), + index=dates, + columns=['A', 'B']) + df.loc[dates[0], 'A'] = np.nan # should be filtered + df.loc[dates[1], 'B'] = np.nan # should not be filtered + source = DataFrameSource(df) + event = next(source) + self.assertEqual('B', event.sid) + event = next(source) + self.assertEqual('A', event.sid) + event = next(source) + self.assertEqual('B', event.sid) + self.assertTrue(np.isnan(event.price)) + + def test_nan_filter_panel(self): + dates = pd.date_range('1/1/2000', periods=2, freq='B', tz='UTC') + df = pd.Panel(np.random.randn(2, 2, 2), + major_axis=dates, + items=['A', 'B'], + minor_axis=['price', 'volume']) + df.loc['A', dates[0], 'price'] = np.nan # should be filtered + df.loc['B', dates[1], 'price'] = np.nan # should not be filtered + source = DataPanelSource(df) + event = next(source) + self.assertEqual('B', event.sid) + event = next(source) + self.assertEqual('A', event.sid) + event = next(source) + self.assertEqual('B', event.sid) + self.assertTrue(np.isnan(event.price)) + class TestRandomWalkSource(TestCase): def test_minute(self): diff --git a/zipline/sources/data_frame_source.py b/zipline/sources/data_frame_source.py index 4dfa225a..7a0771f4 100644 --- a/zipline/sources/data_frame_source.py +++ b/zipline/sources/data_frame_source.py @@ -1,5 +1,5 @@ # -# Copyright 2014 Quantopian, Inc. +# Copyright 2015 Quantopian, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,6 +16,7 @@ """ Tools to generate data sources. """ +import numpy as np import pandas as pd from zipline.gens.utils import hash_args @@ -31,10 +32,8 @@ class DataFrameSource(DataSource): * columns : sids * index : datetime - sids : list of values representing simulated internal sids - start : start date - delta : timedelta between internal events - filter : filter to remove the sids + :Note: + Bars where the price is nan are filtered out. """ def __init__(self, data, **kwargs): @@ -51,6 +50,8 @@ class DataFrameSource(DataSource): self._raw_data = None + self.started_sids = set() + @property def mapping(self): return { @@ -68,6 +69,12 @@ class DataFrameSource(DataSource): for dt, series in self.data.iterrows(): for sid, price in series.iteritems(): if sid in self.sids: + # Skip SIDs that can not be forward filled + if np.isnan(price) and \ + sid not in self.started_sids: + continue + self.started_sids.add(sid) + event = { 'dt': dt, 'sid': sid, @@ -94,10 +101,8 @@ class DataPanelSource(DataSource): * major_axis : datetime * minor_axis : price, volume, ... - sids : list of values representing simulated internal sids - start : start date - delta : timedelta between internal events - filter : filter to remove the sids + :Note: + Bars where the price is nan are filtered out. """ def __init__(self, data, **kwargs): @@ -114,6 +119,8 @@ class DataPanelSource(DataSource): self._raw_data = None + self.started_sids = set() + @property def mapping(self): mapping = { @@ -140,6 +147,12 @@ class DataPanelSource(DataSource): df = self.data.major_xs(dt) for sid, series in df.iteritems(): if sid in self.sids: + # Skip SIDs that can not be forward filled + if np.isnan(series['price']) and \ + sid not in self.started_sids: + continue + self.started_sids.add(sid) + event = { 'dt': dt, 'sid': sid,