Merge pull request #1054 from quantopian/with-trading-env

WithTradingEnvironmnet and WithSimParams
This commit is contained in:
Joe Jevnik
2016-03-15 19:35:33 -04:00
7 changed files with 157 additions and 34 deletions
+1 -2
View File
@@ -1,4 +1,3 @@
-e git://github.com/quantopian/datashape.git@9bd8fb970a0fc55e866a0b46b5101c9aa47e24ed#egg=datashape-dev
-e git://github.com/quantopian/odo.git@4f7f45fb039d89ea101803b95da21fc055901d66#egg=odo-dev
-e git://github.com/quantopian/blaze.git@0b6e76122a57c7115f18c6fdbd5fbab5501fd486#egg=blaze-dev
-e git://github.com/quantopian/blaze.git@32f39b1dadefc1e686c45dc23db0ecb56e753938#egg=blaze-dev
+23 -23
View File
@@ -108,7 +108,7 @@ class BlazeToPipelineTestCase(TestCase):
def test_tabular(self):
name = 'expr'
expr = bz.Data(self.df, name=name, dshape=self.dshape)
expr = bz.data(self.df, name=name, dshape=self.dshape)
ds = from_blaze(
expr,
loader=self.garbage_loader,
@@ -145,7 +145,7 @@ class BlazeToPipelineTestCase(TestCase):
def test_column(self):
exprname = 'expr'
expr = bz.Data(self.df, name=exprname, dshape=self.dshape)
expr = bz.data(self.df, name=exprname, dshape=self.dshape)
value = from_blaze(
expr.value,
loader=self.garbage_loader,
@@ -189,7 +189,7 @@ class BlazeToPipelineTestCase(TestCase):
self.assertEqual(value.dataset.__name__, exprname)
def test_missing_asof(self):
expr = bz.Data(
expr = bz.data(
self.df.loc[:, ['sid', 'value', 'timestamp']],
name='expr',
dshape="""
@@ -210,7 +210,7 @@ class BlazeToPipelineTestCase(TestCase):
self.assertIn(repr(str(expr.dshape.measure)), str(e.exception))
def test_auto_deltas(self):
expr = bz.Data(
expr = bz.data(
{'ds': self.df,
'ds_deltas': pd.DataFrame(columns=self.df.columns)},
dshape=var * Record((
@@ -233,7 +233,7 @@ class BlazeToPipelineTestCase(TestCase):
with warnings.catch_warnings(record=True) as ws:
warnings.simplefilter('always')
loader = BlazeLoader()
expr = bz.Data(self.df, dshape=self.dshape)
expr = bz.data(self.df, dshape=self.dshape)
from_blaze(
expr,
loader=loader,
@@ -247,7 +247,7 @@ class BlazeToPipelineTestCase(TestCase):
def test_auto_deltas_fail_raise(self):
loader = BlazeLoader()
expr = bz.Data(self.df, dshape=self.dshape)
expr = bz.data(self.df, dshape=self.dshape)
with self.assertRaises(ValueError) as e:
from_blaze(
expr,
@@ -257,7 +257,7 @@ class BlazeToPipelineTestCase(TestCase):
self.assertIn(str(expr), str(e.exception))
def test_non_numpy_field(self):
expr = bz.Data(
expr = bz.data(
[],
dshape="""
var * {
@@ -279,7 +279,7 @@ class BlazeToPipelineTestCase(TestCase):
# NOTE: This test will fail if we ever allow string types in
# the Pipeline API. If this happens, change the dtype of the `a` field
# of expr to another type we don't allow.
expr = bz.Data(
expr = bz.data(
[],
dshape="""
var * {
@@ -301,7 +301,7 @@ class BlazeToPipelineTestCase(TestCase):
)
def test_complex_expr(self):
expr = bz.Data(self.df, dshape=self.dshape)
expr = bz.data(self.df, dshape=self.dshape)
# put an Add in the table
expr_with_add = bz.transform(expr, value=expr.value + 1)
@@ -321,7 +321,7 @@ class BlazeToPipelineTestCase(TestCase):
missing_values=self.missing_values,
)
deltas = bz.Data(
deltas = bz.data(
pd.DataFrame(columns=self.df.columns),
dshape=self.dshape,
)
@@ -342,7 +342,7 @@ class BlazeToPipelineTestCase(TestCase):
)
def _test_id(self, df, dshape, expected, finder, add):
expr = bz.Data(df, name='expr', dshape=dshape)
expr = bz.data(df, name='expr', dshape=dshape)
loader = BlazeLoader()
ds = from_blaze(
expr,
@@ -375,7 +375,7 @@ class BlazeToPipelineTestCase(TestCase):
timedelta(hours=8, minutes=44)
).tz_convert('utc').tz_localize(None)
df.ix[3:5, 'timestamp'] = pd.Timestamp('2014-01-01 13:45')
expr = bz.Data(df, name='expr', dshape=self.dshape)
expr = bz.data(df, name='expr', dshape=self.dshape)
loader = BlazeLoader(data_query_time=time(8, 45), data_query_tz='EST')
ds = from_blaze(
expr,
@@ -861,9 +861,9 @@ class BlazeToPipelineTestCase(TestCase):
'int_value': (3, 4, 5),
})
df = df.append(extra_sid_df, ignore_index=True)
expr = bz.Data(df, name='expr', dshape=self.dshape)
deltas = bz.Data(df, dshape=self.dshape)
deltas = bz.Data(
expr = bz.data(df, name='expr', dshape=self.dshape)
deltas = bz.data(df, dshape=self.dshape)
deltas = bz.data(
odo(
bz.transform(
deltas,
@@ -916,14 +916,14 @@ class BlazeToPipelineTestCase(TestCase):
@with_extra_sid
def test_deltas_only_one_delta_in_universe(self, asset_info):
expr = bz.Data(self.df, name='expr', dshape=self.dshape)
expr = bz.data(self.df, name='expr', dshape=self.dshape)
deltas = pd.DataFrame({
'sid': [65, 66],
'asof_date': [self.dates[1], self.dates[0]],
'timestamp': [self.dates[2], self.dates[1]],
'value': [10, 11],
})
deltas = bz.Data(deltas, name='deltas', dshape=self.dshape)
deltas = bz.data(deltas, name='deltas', dshape=self.dshape)
expected_views = keymap(pd.Timestamp, {
'2014-01-02': np.array([[0.0, 11.0, 2.0],
[1.0, 2.0, 3.0]]),
@@ -968,8 +968,8 @@ class BlazeToPipelineTestCase(TestCase):
def test_deltas_macro(self):
asset_info = asset_infos[0][0]
expr = bz.Data(self.macro_df, name='expr', dshape=self.macro_dshape)
deltas = bz.Data(
expr = bz.data(self.macro_df, name='expr', dshape=self.macro_dshape)
deltas = bz.data(
self.macro_df.iloc[:-1],
name='deltas',
dshape=self.macro_dshape,
@@ -1023,8 +1023,8 @@ class BlazeToPipelineTestCase(TestCase):
'asof_date': repeated_dates,
'timestamp': repeated_dates,
})
expr = bz.Data(baseline, name='expr', dshape=self.dshape)
deltas = bz.Data(
expr = bz.data(baseline, name='expr', dshape=self.dshape)
deltas = bz.data(
odo(
bz.transform(
expr,
@@ -1094,8 +1094,8 @@ class BlazeToPipelineTestCase(TestCase):
'asof_date': base_dates,
'timestamp': base_dates,
})
expr = bz.Data(baseline, name='expr', dshape=self.macro_dshape)
deltas = bz.Data(baseline, name='deltas', dshape=self.macro_dshape)
expr = bz.data(baseline, name='expr', dshape=self.macro_dshape)
deltas = bz.data(baseline, name='deltas', dshape=self.macro_dshape)
deltas = bz.transform(
deltas,
value=deltas.value + 10,
+2 -2
View File
@@ -225,7 +225,7 @@ class BlazeCashBuybackAuthLoaderTestCase(CashBuybackAuthLoaderTestCase):
BlazeCashBuybackAuthLoaderTestCase,
self,
).loader_args(dates)
return (bz.Data(pd.concat(
return (bz.data(pd.concat(
pd.DataFrame({
BUYBACK_ANNOUNCEMENT_FIELD_NAME:
frame[BUYBACK_ANNOUNCEMENT_FIELD_NAME],
@@ -252,7 +252,7 @@ class BlazeShareBuybackAuthLoaderTestCase(ShareBuybackAuthLoaderTestCase):
BlazeShareBuybackAuthLoaderTestCase,
self,
).loader_args(dates)
return (bz.Data(pd.concat(
return (bz.data(pd.concat(
pd.DataFrame({
BUYBACK_ANNOUNCEMENT_FIELD_NAME:
frame[BUYBACK_ANNOUNCEMENT_FIELD_NAME],
+1 -1
View File
@@ -208,7 +208,7 @@ class BlazeEarningsCalendarLoaderTestCase(EarningsCalendarLoaderTestCase):
BlazeEarningsCalendarLoaderTestCase,
self,
).loader_args(dates)
return (bz.Data(pd.concat(
return (bz.data(pd.concat(
pd.DataFrame({
ANNOUNCEMENT_FIELD_NAME: df[ANNOUNCEMENT_FIELD_NAME],
TS_FIELD_NAME: df[TS_FIELD_NAME],
+1 -1
View File
@@ -241,7 +241,7 @@ class BlazeEventLoaderTestCase(TestCase):
TypeError, re.escape(ABSTRACT_CONCRETE_LOADER_ERROR)
):
BlazeEventDataSetLoaderNoConcreteLoader(
bz.Data(
bz.data(
pd.DataFrame({ANNOUNCEMENT_FIELD_NAME: dtx,
SID_FIELD_NAME: 0})
)
+1 -1
View File
@@ -1081,7 +1081,7 @@ def bind_expression_to_resources(expr, resources):
# prefixes symbol-manipulation methods with underscores to prevent
# collisions with data column names.
return expr._subs({
k: bz.Data(v, dshape=k.dshape) for k, v in iteritems(resources)
k: bz.data(v, dshape=k.dshape) for k, v in iteritems(resources)
})
+128 -4
View File
@@ -1,11 +1,13 @@
from unittest import TestCase
from contextlib2 import ExitStack
from logbook import NullHandler
import pandas as pd
from six import with_metaclass
from .core import tmp_asset_finder
from ..utils import tradingcalendar
from ..finance.trading import TradingEnvironment
from ..utils import tradingcalendar, factory
from ..utils.final import FinalMeta, final
@@ -71,6 +73,11 @@ class ZiplineTestCase(with_metaclass(FinalMeta, TestCase)):
def add_class_callback(cls, callback):
"""
Register a callback to be executed during tearDownClass.
Parameters
----------
callback : callable
The callback to invoke at the end of the test suite.
"""
return cls._class_teardown_stack.callback(callback)
@@ -108,30 +115,147 @@ class ZiplineTestCase(with_metaclass(FinalMeta, TestCase)):
def add_instance_callback(self, callback):
"""
Register a callback to be executed during tearDown.
Parameters
----------
callback : callable
The callback to invoke at the end of each test.
"""
return self._instance_teardown_stack.callback(callback)
class WithLogger(object):
"""ZiplineTestCase mixing providing cls.log_handler as an instance-level
fixture.
after init_instance_fixtures has been called `self.log_handler` will be a
new ``NullHandler``.
This behaviour may be overridden by defining a ``make_log_handler`` class
method which returns a new handler object.
"""
make_log_handler = NullHandler
@classmethod
def init_class_fixtures(cls):
super(WithLogger, cls).init_instance_fixtures()
cls.log_handler = cls.enter_class_context(
cls.make_log_handler().applicationbound(),
)
class WithAssetFinder(object):
"""
ZiplineTestCase mixin providing cls.asset_finder as a class-level fixture.
After init_class_fixtures has been called, `cls.asset_finder` is populated
with an AssetFinder. The default finder just a `tmp_asset_finder`
with an AssetFinder. The default finder just a `tmp_asset_finder` whose
arguments come from:
equities=cls.make_equities_info(),
futures=cls.make_futures_info(),
exchanges=cls.make_exchanges_info(),
root_symbols=cls.make_root_symbols_info(),
Each of these fields may be overridden with a function named
``make_*_info`` which returnsa dataframe to write.
This behavior can be altered by overriding `make_asset_finder` as a class
method.
"""
@classmethod
def _make_info(cls):
return None
make_equities_info = _make_info
make_futures_info = _make_info
make_exchanges_info = _make_info
make_root_symbols_info = _make_info
del _make_info
@classmethod
def make_asset_finder(cls):
return cls.enter_class_context(tmp_asset_finder())
return cls.enter_class_context(tmp_asset_finder(
equities=cls.equities_info,
futures=cls.futures_info,
exchanges=cls.exchanges_info,
root_symbols=cls.root_symbols_info,
))
@classmethod
def init_class_fixtures(cls):
super(WithAssetFinder, cls).init_class_fixtures()
cls.equities_info = cls.make_equities_info()
cls.futures_info = cls.make_futures_info()
cls.exchanges_info = cls.make_exchanges_info()
cls.root_symbols_info = cls.make_root_symbols_info()
cls.asset_finder = cls.make_asset_finder()
class WithTradingEnvironment(WithAssetFinder):
"""
ZiplineTestCase mixin providing cls.env as a class-level fixture.
After ``init_class_fixtures`` has been called, `cls.env` is populated
with a trading environment whose `asset_finder` is the result of
`cls.make_asset_finder`.
The ``load`` function may be provided by overriding the
``make_load_function`` class method.
This behavior can be altered by overriding `make_trading_environment` as a
class method.
"""
@classmethod
def make_load_function(cls):
return None
@classmethod
def make_trading_environment(cls):
return TradingEnvironment(
load=cls.make_load_function(),
asset_db_path=cls.asset_finder.engine,
)
@classmethod
def init_class_fixtures(cls):
super(WithTradingEnvironment, cls).init_class_fixtures()
cls.env = cls.make_trading_environment()
class WithSimParams(WithTradingEnvironment):
"""
ZiplineTestCase mixin providing cls.sim_params as a class level fixture.
The arguments used to construct the trading environment may be overridded
by putting ``SIM_PARAMS_{argname}`` in the class dict except for the
trading environment which is overridden with the mechanisms provided by
``WithTradingEnvironment``.
"""
SIM_PARAMS_YEAR = None
SIM_PARAMS_START = pd.Timestamp('2006-01-01')
SIM_PARAMS_END = pd.Timestamp('2006-12-31')
SIM_PARAMS_CAPITAL_BASE = float("1.0e5")
SIM_PARAMS_NUM_DAYS = None
SIM_PARAMS_DATA_FREQUENCY = 'daily'
SIM_PARAMS_EMISSION_RATE = 'daily'
@classmethod
def init_class_fixtures(cls):
super(WithSimParams, cls).init_class_fixtures()
cls.sim_params = factory.create_simulation_parameters(
year=cls.SIM_PARAMS_YEAR,
start=cls.SIM_PARAMS_START,
end=cls.SIM_PARAMS_END,
capital_base=cls.SIM_PARAMS_CAPITAL_BASE,
data_frequency=cls.SIM_PARAMS_DATA_FREQUENCY,
emission_rate=cls.SIM_PARAMS_EMISSION_RATE,
env=cls.env,
)
class WithNYSETradingDays(object):
"""
ZiplineTestCase mixin providing cls.trading_days as a class-level fixture.
@@ -144,7 +268,7 @@ class WithNYSETradingDays(object):
The default value of TRADING_DAY_COUNT is 126 (half a trading-year).
Inheritors can override TRADING_DAY_COUNT to request more or less data.
"""
DATA_MAX_DAY = pd.Timestamp('2015-01-02')
DATA_MAX_DAY = pd.Timestamp('2016-01-04')
TRADING_DAY_COUNT = 126
@classmethod