BUG: Filter nans in DataFrame and Panel sources.

If a SID hasn't started trading yet, pandas' convention is to use nans.
Before this change, zipline would raise an exception if there were nans in the
input data.

We now skip events where the prices contains a nan and has not been traded
before (in which case forward fill).

Fixes #446.
This commit is contained in:
Thomas Wiecki
2015-01-02 13:42:59 +01:00
parent a257a43e99
commit d578d5825e
2 changed files with 55 additions and 9 deletions
+33
View File
@@ -79,6 +79,39 @@ class TestDataFrameSource(TestCase):
self.assertTrue(isinstance(event['volume'], (integer_types)))
self.assertEqual(next(stocks_iter), event['sid'])
def test_nan_filter_dataframe(self):
dates = pd.date_range('1/1/2000', periods=2, freq='B', tz='UTC')
df = pd.DataFrame(np.random.randn(2, 2),
index=dates,
columns=['A', 'B'])
df.loc[dates[0], 'A'] = np.nan # should be filtered
df.loc[dates[1], 'B'] = np.nan # should not be filtered
source = DataFrameSource(df)
event = next(source)
self.assertEqual('B', event.sid)
event = next(source)
self.assertEqual('A', event.sid)
event = next(source)
self.assertEqual('B', event.sid)
self.assertTrue(np.isnan(event.price))
def test_nan_filter_panel(self):
dates = pd.date_range('1/1/2000', periods=2, freq='B', tz='UTC')
df = pd.Panel(np.random.randn(2, 2, 2),
major_axis=dates,
items=['A', 'B'],
minor_axis=['price', 'volume'])
df.loc['A', dates[0], 'price'] = np.nan # should be filtered
df.loc['B', dates[1], 'price'] = np.nan # should not be filtered
source = DataPanelSource(df)
event = next(source)
self.assertEqual('B', event.sid)
event = next(source)
self.assertEqual('A', event.sid)
event = next(source)
self.assertEqual('B', event.sid)
self.assertTrue(np.isnan(event.price))
class TestRandomWalkSource(TestCase):
def test_minute(self):
+22 -9
View File
@@ -1,5 +1,5 @@
#
# Copyright 2014 Quantopian, Inc.
# Copyright 2015 Quantopian, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,6 +16,7 @@
"""
Tools to generate data sources.
"""
import numpy as np
import pandas as pd
from zipline.gens.utils import hash_args
@@ -31,10 +32,8 @@ class DataFrameSource(DataSource):
* columns : sids
* index : datetime
sids : list of values representing simulated internal sids
start : start date
delta : timedelta between internal events
filter : filter to remove the sids
:Note:
Bars where the price is nan are filtered out.
"""
def __init__(self, data, **kwargs):
@@ -51,6 +50,8 @@ class DataFrameSource(DataSource):
self._raw_data = None
self.started_sids = set()
@property
def mapping(self):
return {
@@ -68,6 +69,12 @@ class DataFrameSource(DataSource):
for dt, series in self.data.iterrows():
for sid, price in series.iteritems():
if sid in self.sids:
# Skip SIDs that can not be forward filled
if np.isnan(price) and \
sid not in self.started_sids:
continue
self.started_sids.add(sid)
event = {
'dt': dt,
'sid': sid,
@@ -94,10 +101,8 @@ class DataPanelSource(DataSource):
* major_axis : datetime
* minor_axis : price, volume, ...
sids : list of values representing simulated internal sids
start : start date
delta : timedelta between internal events
filter : filter to remove the sids
:Note:
Bars where the price is nan are filtered out.
"""
def __init__(self, data, **kwargs):
@@ -114,6 +119,8 @@ class DataPanelSource(DataSource):
self._raw_data = None
self.started_sids = set()
@property
def mapping(self):
mapping = {
@@ -140,6 +147,12 @@ class DataPanelSource(DataSource):
df = self.data.major_xs(dt)
for sid, series in df.iteritems():
if sid in self.sids:
# Skip SIDs that can not be forward filled
if np.isnan(series['price']) and \
sid not in self.started_sids:
continue
self.started_sids.add(sid)
event = {
'dt': dt,
'sid': sid,