mirror of
https://github.com/wassname/catalyst.git
synced 2026-06-29 04:42:38 +08:00
d578d5825e
If a SID hasn't started trading yet, pandas' convention is to use nans. Before this change, zipline would raise an exception if there were nans in the input data. We now skip events where the prices contains a nan and has not been traded before (in which case forward fill). Fixes #446.
162 lines
6.6 KiB
Python
162 lines
6.6 KiB
Python
#
|
|
# Copyright 2013 Quantopian, Inc.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
import pandas as pd
|
|
import pytz
|
|
from itertools import cycle
|
|
import numpy as np
|
|
|
|
from six import integer_types
|
|
|
|
from unittest import TestCase
|
|
|
|
import zipline.utils.factory as factory
|
|
from zipline.sources import (DataFrameSource,
|
|
DataPanelSource,
|
|
RandomWalkSource)
|
|
from zipline.utils import tradingcalendar as calendar_nyse
|
|
|
|
|
|
class TestDataFrameSource(TestCase):
|
|
def test_df_source(self):
|
|
source, df = factory.create_test_df_source()
|
|
assert isinstance(source.start, pd.lib.Timestamp)
|
|
assert isinstance(source.end, pd.lib.Timestamp)
|
|
|
|
for expected_dt, expected_price in df.iterrows():
|
|
sid0 = next(source)
|
|
|
|
assert expected_dt == sid0.dt
|
|
assert expected_price[0] == sid0.price
|
|
|
|
def test_df_sid_filtering(self):
|
|
_, df = factory.create_test_df_source()
|
|
source = DataFrameSource(df, sids=[0])
|
|
assert 1 not in [event.sid for event in source], \
|
|
"DataFrameSource should only stream selected sid 0, not sid 1."
|
|
|
|
def test_panel_source(self):
|
|
source, panel = factory.create_test_panel_source()
|
|
assert isinstance(source.start, pd.lib.Timestamp)
|
|
assert isinstance(source.end, pd.lib.Timestamp)
|
|
for event in source:
|
|
self.assertTrue('sid' in event)
|
|
self.assertTrue('arbitrary' in event)
|
|
self.assertTrue(hasattr(event, 'volume'))
|
|
self.assertTrue(hasattr(event, 'price'))
|
|
self.assertEquals(event['arbitrary'], 1.)
|
|
self.assertEquals(event['sid'], 0)
|
|
self.assertTrue(isinstance(event['volume'], int))
|
|
self.assertTrue(isinstance(event['arbitrary'], float))
|
|
|
|
def test_yahoo_bars_to_panel_source(self):
|
|
stocks = ['AAPL', 'GE']
|
|
start = pd.datetime(1993, 1, 1, 0, 0, 0, 0, pytz.utc)
|
|
end = pd.datetime(2002, 1, 1, 0, 0, 0, 0, pytz.utc)
|
|
data = factory.load_bars_from_yahoo(stocks=stocks,
|
|
indexes={},
|
|
start=start,
|
|
end=end)
|
|
|
|
check_fields = ['sid', 'open', 'high', 'low', 'close',
|
|
'volume', 'price']
|
|
source = DataPanelSource(data)
|
|
stocks_iter = cycle(stocks)
|
|
for event in source:
|
|
for check_field in check_fields:
|
|
self.assertIn(check_field, event)
|
|
self.assertTrue(isinstance(event['volume'], (integer_types)))
|
|
self.assertEqual(next(stocks_iter), event['sid'])
|
|
|
|
def test_nan_filter_dataframe(self):
|
|
dates = pd.date_range('1/1/2000', periods=2, freq='B', tz='UTC')
|
|
df = pd.DataFrame(np.random.randn(2, 2),
|
|
index=dates,
|
|
columns=['A', 'B'])
|
|
df.loc[dates[0], 'A'] = np.nan # should be filtered
|
|
df.loc[dates[1], 'B'] = np.nan # should not be filtered
|
|
source = DataFrameSource(df)
|
|
event = next(source)
|
|
self.assertEqual('B', event.sid)
|
|
event = next(source)
|
|
self.assertEqual('A', event.sid)
|
|
event = next(source)
|
|
self.assertEqual('B', event.sid)
|
|
self.assertTrue(np.isnan(event.price))
|
|
|
|
def test_nan_filter_panel(self):
|
|
dates = pd.date_range('1/1/2000', periods=2, freq='B', tz='UTC')
|
|
df = pd.Panel(np.random.randn(2, 2, 2),
|
|
major_axis=dates,
|
|
items=['A', 'B'],
|
|
minor_axis=['price', 'volume'])
|
|
df.loc['A', dates[0], 'price'] = np.nan # should be filtered
|
|
df.loc['B', dates[1], 'price'] = np.nan # should not be filtered
|
|
source = DataPanelSource(df)
|
|
event = next(source)
|
|
self.assertEqual('B', event.sid)
|
|
event = next(source)
|
|
self.assertEqual('A', event.sid)
|
|
event = next(source)
|
|
self.assertEqual('B', event.sid)
|
|
self.assertTrue(np.isnan(event.price))
|
|
|
|
|
|
class TestRandomWalkSource(TestCase):
|
|
def test_minute(self):
|
|
np.random.seed(123)
|
|
start_prices = {0: 100,
|
|
1: 500}
|
|
start = pd.Timestamp('1990-01-01', tz='UTC')
|
|
end = pd.Timestamp('1991-01-01', tz='UTC')
|
|
source = RandomWalkSource(start_prices=start_prices,
|
|
calendar=calendar_nyse, start=start,
|
|
end=end)
|
|
self.assertIsInstance(source.start, pd.lib.Timestamp)
|
|
self.assertIsInstance(source.end, pd.lib.Timestamp)
|
|
|
|
for event in source:
|
|
self.assertIn(event.sid, start_prices.keys())
|
|
self.assertIn(event.dt.replace(minute=0, hour=0),
|
|
calendar_nyse.trading_days)
|
|
self.assertGreater(event.dt, start)
|
|
self.assertLess(event.dt, end)
|
|
self.assertGreater(event.price, 0,
|
|
"price should never go negative.")
|
|
self.assertTrue(13 <= event.dt.hour <= 21,
|
|
"event.dt.hour == %i, not during market \
|
|
hours." % event.dt.hour)
|
|
|
|
def test_day(self):
|
|
np.random.seed(123)
|
|
start_prices = {0: 100,
|
|
1: 500}
|
|
start = pd.Timestamp('1990-01-01', tz='UTC')
|
|
end = pd.Timestamp('1992-01-01', tz='UTC')
|
|
source = RandomWalkSource(start_prices=start_prices,
|
|
calendar=calendar_nyse, start=start,
|
|
end=end, freq='daily')
|
|
self.assertIsInstance(source.start, pd.lib.Timestamp)
|
|
self.assertIsInstance(source.end, pd.lib.Timestamp)
|
|
|
|
for event in source:
|
|
self.assertIn(event.sid, start_prices.keys())
|
|
self.assertIn(event.dt.replace(minute=0, hour=0),
|
|
calendar_nyse.trading_days)
|
|
self.assertGreater(event.dt, start)
|
|
self.assertLess(event.dt, end)
|
|
self.assertGreater(event.price, 0,
|
|
"price should never go negative.")
|
|
self.assertEqual(event.dt.hour, 0)
|