mirror of
https://github.com/wassname/catalyst.git
synced 2026-06-30 02:20:07 +08:00
BUG: Filter nans in DataFrame and Panel sources.
If a SID hasn't started trading yet, pandas' convention is to use nans. Before this change, zipline would raise an exception if there were nans in the input data. We now skip events where the prices contains a nan and has not been traded before (in which case forward fill). Fixes #446.
This commit is contained in:
@@ -79,6 +79,39 @@ class TestDataFrameSource(TestCase):
|
||||
self.assertTrue(isinstance(event['volume'], (integer_types)))
|
||||
self.assertEqual(next(stocks_iter), event['sid'])
|
||||
|
||||
def test_nan_filter_dataframe(self):
|
||||
dates = pd.date_range('1/1/2000', periods=2, freq='B', tz='UTC')
|
||||
df = pd.DataFrame(np.random.randn(2, 2),
|
||||
index=dates,
|
||||
columns=['A', 'B'])
|
||||
df.loc[dates[0], 'A'] = np.nan # should be filtered
|
||||
df.loc[dates[1], 'B'] = np.nan # should not be filtered
|
||||
source = DataFrameSource(df)
|
||||
event = next(source)
|
||||
self.assertEqual('B', event.sid)
|
||||
event = next(source)
|
||||
self.assertEqual('A', event.sid)
|
||||
event = next(source)
|
||||
self.assertEqual('B', event.sid)
|
||||
self.assertTrue(np.isnan(event.price))
|
||||
|
||||
def test_nan_filter_panel(self):
|
||||
dates = pd.date_range('1/1/2000', periods=2, freq='B', tz='UTC')
|
||||
df = pd.Panel(np.random.randn(2, 2, 2),
|
||||
major_axis=dates,
|
||||
items=['A', 'B'],
|
||||
minor_axis=['price', 'volume'])
|
||||
df.loc['A', dates[0], 'price'] = np.nan # should be filtered
|
||||
df.loc['B', dates[1], 'price'] = np.nan # should not be filtered
|
||||
source = DataPanelSource(df)
|
||||
event = next(source)
|
||||
self.assertEqual('B', event.sid)
|
||||
event = next(source)
|
||||
self.assertEqual('A', event.sid)
|
||||
event = next(source)
|
||||
self.assertEqual('B', event.sid)
|
||||
self.assertTrue(np.isnan(event.price))
|
||||
|
||||
|
||||
class TestRandomWalkSource(TestCase):
|
||||
def test_minute(self):
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
#
|
||||
# Copyright 2014 Quantopian, Inc.
|
||||
# Copyright 2015 Quantopian, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -16,6 +16,7 @@
|
||||
"""
|
||||
Tools to generate data sources.
|
||||
"""
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from zipline.gens.utils import hash_args
|
||||
@@ -31,10 +32,8 @@ class DataFrameSource(DataSource):
|
||||
* columns : sids
|
||||
* index : datetime
|
||||
|
||||
sids : list of values representing simulated internal sids
|
||||
start : start date
|
||||
delta : timedelta between internal events
|
||||
filter : filter to remove the sids
|
||||
:Note:
|
||||
Bars where the price is nan are filtered out.
|
||||
"""
|
||||
|
||||
def __init__(self, data, **kwargs):
|
||||
@@ -51,6 +50,8 @@ class DataFrameSource(DataSource):
|
||||
|
||||
self._raw_data = None
|
||||
|
||||
self.started_sids = set()
|
||||
|
||||
@property
|
||||
def mapping(self):
|
||||
return {
|
||||
@@ -68,6 +69,12 @@ class DataFrameSource(DataSource):
|
||||
for dt, series in self.data.iterrows():
|
||||
for sid, price in series.iteritems():
|
||||
if sid in self.sids:
|
||||
# Skip SIDs that can not be forward filled
|
||||
if np.isnan(price) and \
|
||||
sid not in self.started_sids:
|
||||
continue
|
||||
self.started_sids.add(sid)
|
||||
|
||||
event = {
|
||||
'dt': dt,
|
||||
'sid': sid,
|
||||
@@ -94,10 +101,8 @@ class DataPanelSource(DataSource):
|
||||
* major_axis : datetime
|
||||
* minor_axis : price, volume, ...
|
||||
|
||||
sids : list of values representing simulated internal sids
|
||||
start : start date
|
||||
delta : timedelta between internal events
|
||||
filter : filter to remove the sids
|
||||
:Note:
|
||||
Bars where the price is nan are filtered out.
|
||||
"""
|
||||
|
||||
def __init__(self, data, **kwargs):
|
||||
@@ -114,6 +119,8 @@ class DataPanelSource(DataSource):
|
||||
|
||||
self._raw_data = None
|
||||
|
||||
self.started_sids = set()
|
||||
|
||||
@property
|
||||
def mapping(self):
|
||||
mapping = {
|
||||
@@ -140,6 +147,12 @@ class DataPanelSource(DataSource):
|
||||
df = self.data.major_xs(dt)
|
||||
for sid, series in df.iteritems():
|
||||
if sid in self.sids:
|
||||
# Skip SIDs that can not be forward filled
|
||||
if np.isnan(series['price']) and \
|
||||
sid not in self.started_sids:
|
||||
continue
|
||||
self.started_sids.add(sid)
|
||||
|
||||
event = {
|
||||
'dt': dt,
|
||||
'sid': sid,
|
||||
|
||||
Reference in New Issue
Block a user