MAINT: PR cleanup

This commit is contained in:
Jean Bredeche
2016-07-27 15:52:34 -04:00
parent 6020752a1d
commit 97ccb54326
10 changed files with 98 additions and 122 deletions
+17 -14
View File
@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from collections import namedtuple
import datetime
from datetime import timedelta
@@ -34,7 +35,6 @@ import pytz
from pandas.io.common import PerformanceWarning
from zipline import run_algorithm
from tests.warnings_catcher import WarningsCatcher
from zipline import TradingAlgorithm
from zipline.api import FixedSlippage
from zipline.assets import Equity, Future
@@ -1959,7 +1959,9 @@ def handle_data(context, data):
pass
""")
with WarningsCatcher([PerformanceWarning]) as w:
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("ignore", PerformanceWarning)
algo = TradingAlgorithm(
script=algocode,
sim_params=sim_params,
@@ -1967,18 +1969,19 @@ def handle_data(context, data):
)
algo.run(self.data_portal)
self.assertEqual(len(w), 2)
for i, warning in enumerate(w):
self.assertIsInstance(warning.message, UserWarning)
self.assertEqual(
warning.message.args[0],
'Got a time rule for the second positional argument '
'date_rule. You should use keyword argument '
'time_rule= when calling schedule_function without '
'specifying a date_rule'
)
# The warnings come from line 13 and 14 in the algocode
self.assertEqual(warning.lineno, 13 + i)
self.assertEqual(len(w), 2)
for i, warning in enumerate(w):
self.assertIsInstance(warning.message, UserWarning)
self.assertEqual(
warning.message.args[0],
'Got a time rule for the second positional argument '
'date_rule. You should use keyword argument '
'time_rule= when calling schedule_function without '
'specifying a date_rule'
)
# The warnings come from line 13 and 14 in the algocode
self.assertEqual(warning.lineno, 13 + i)
self.assertEqual(
algo.done_at_open,
+12 -7
View File
@@ -5,7 +5,6 @@ import numpy as np
import pandas as pd
from pandas.io.common import PerformanceWarning
from tests.warnings_catcher import WarningsCatcher
from zipline import TradingAlgorithm
from zipline.finance.trading import SimulationParameters
from zipline.protocol import BarData
@@ -292,7 +291,8 @@ class TestAPIShim(WithDataPortal, WithSimParams, ZiplineTestCase):
cease to be supported, we also want to assert that we're seeing a
deprecation warning.
"""
with WarningsCatcher([PerformanceWarning]) as w:
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("ignore", PerformanceWarning)
warnings.simplefilter("default", ZiplineDeprecationWarning)
algo = self.create_algo(sid_accessor_algo)
algo.run(self.data_portal)
@@ -320,7 +320,8 @@ class TestAPIShim(WithDataPortal, WithSimParams, ZiplineTestCase):
We also want to assert that we warn that iterating over the assets
in `data` is deprecated.
"""
with WarningsCatcher([PerformanceWarning]) as w:
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("ignore", PerformanceWarning)
warnings.simplefilter("default", ZiplineDeprecationWarning)
algo = self.create_algo(data_items_algo)
algo.run(self.data_portal)
@@ -344,7 +345,8 @@ class TestAPIShim(WithDataPortal, WithSimParams, ZiplineTestCase):
)
def test_iterate_data(self):
with WarningsCatcher([PerformanceWarning]) as w:
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("ignore", PerformanceWarning)
warnings.simplefilter("default", ZiplineDeprecationWarning)
algo = self.create_algo(simple_algo)
@@ -374,7 +376,8 @@ class TestAPIShim(WithDataPortal, WithSimParams, ZiplineTestCase):
)
def test_history(self):
with WarningsCatcher([PerformanceWarning]) as w:
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("ignore", PerformanceWarning)
warnings.simplefilter("default", ZiplineDeprecationWarning)
sim_params = self.sim_params.create_new(
@@ -415,7 +418,8 @@ class TestAPIShim(WithDataPortal, WithSimParams, ZiplineTestCase):
expected_vol_with_split)
def test_simple_transforms(self):
with WarningsCatcher([PerformanceWarning]) as w:
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("ignore", PerformanceWarning)
warnings.simplefilter("default", ZiplineDeprecationWarning)
sim_params = SimulationParameters(
@@ -485,7 +489,8 @@ class TestAPIShim(WithDataPortal, WithSimParams, ZiplineTestCase):
self.assertAlmostEqual(346, algo.returns)
def test_manipulation(self):
with WarningsCatcher([PerformanceWarning]) as w:
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("ignore", PerformanceWarning)
warnings.simplefilter("default", ZiplineDeprecationWarning)
algo = self.create_algo(simple_algo)
+2 -3
View File
@@ -1,7 +1,6 @@
from datetime import time
from unittest import TestCase
import pandas as pd
import numpy as np
from zipline.gens.sim_engine import (
MinuteSimulationClock,
SESSION_START,
@@ -26,8 +25,8 @@ class TestClock(TestCase):
)
trading_o_and_c = cls.nyse_calendar.schedule.ix[cls.sessions]
cls.opens = trading_o_and_c['market_open'].values.astype(np.int64)
cls.closes = trading_o_and_c['market_close'].values.astype(np.int64)
cls.opens = trading_o_and_c['market_open']
cls.closes = trading_o_and_c['market_close']
def test_bts_before_session(self):
clock = MinuteSimulationClock(
+1
View File
@@ -120,6 +120,7 @@ class MinuteToDailyAggregationTestCase(WithBcolzEquityMinuteBarReader,
self.equity_daily_aggregator = DailyHistoryAggregator(
self.trading_calendar.schedule.market_open,
self.bcolz_equity_minute_bar_reader,
self.trading_calendar
)
@parameterized.expand([
-32
View File
@@ -1,32 +0,0 @@
from warnings import catch_warnings, WarningMessage
class WarningsCatcher(catch_warnings):
"""
Subclass of warnings.catch_warnings that takes a list of warning types to
ignore.
"""
def __init__(self, types_to_ignore=None):
super(WarningsCatcher, self).__init__(record=True)
self._types_to_ignore = set(types_to_ignore or [])
def __enter__(self):
if self._entered:
raise RuntimeError("Cannot enter %r twice" % self)
self._entered = True
self._filters = self._module.filters
self._module.filters = self._filters[:]
self._showwarning = self._module.showwarning
if self._record:
log = []
def showwarning(*args, **kwargs):
if args[1] in self._types_to_ignore:
return
log.append(WarningMessage(*args, **kwargs))
self._module.showwarning = showwarning
return log
else:
return None
+3 -3
View File
@@ -463,11 +463,11 @@ cdef class BarData:
cdef bool _can_trade_for_asset(self, asset, dt, adjusted_dt, data_portal):
session_label = normalize_date(dt) # FIXME
if not asset._is_alive_for_session(session_label):
if not asset.is_alive_for_session(session_label):
# asset isn't alive
return False
if not asset._asset_exchange_open(dt):
if not asset.is_exchange_open(dt):
# exchange isn't open
return False
@@ -525,7 +525,7 @@ cdef class BarData:
cdef bool _is_stale_for_asset(self, asset, dt, adjusted_dt, data_portal):
session_label = normalize_date(dt) # FIXME
if not asset._is_alive_for_session(session_label):
if not asset.is_alive_for_session(session_label):
return False
current_volume = data_portal.get_spot_value(
+2 -4
View File
@@ -496,13 +496,11 @@ class TradingAlgorithm(object):
"""
trading_o_and_c = self.trading_calendar.schedule.ix[
self.sim_params.sessions]
market_closes = trading_o_and_c['market_close'].values.astype(np.int64)
market_closes = trading_o_and_c['market_close']
minutely_emission = False
if self.sim_params.data_frequency == 'minute':
market_opens = trading_o_and_c['market_open'].values.astype(
np.int64
)
market_opens = trading_o_and_c['market_open']
minutely_emission = self.sim_params.emission_rate == "minute"
else:
+4 -6
View File
@@ -101,8 +101,6 @@ cdef class Asset:
self.first_traded = first_traded
self.auto_close_date = auto_close_date
def __int__(self):
return self.sid
@@ -201,7 +199,7 @@ cdef class Asset:
"""
return cls(**dict_)
def _is_alive_for_session(self, session_label):
def is_alive_for_session(self, session_label):
"""
Returns whether the asset is alive at the given dt.
@@ -222,7 +220,7 @@ cdef class Asset:
return ref_start <= session_label.value <= ref_end
def _asset_exchange_open(self, dt_minute):
def is_exchange_open(self, dt_minute):
"""
Parameters
----------
@@ -233,10 +231,10 @@ cdef class Asset:
-------
boolean: whether the asset's exchange is open at the given minute.
"""
calendar = self._exchange_trading_calendar_for_asset()
calendar = self.exchange_trading_calendar()
return calendar.is_open_on_minute(dt_minute)
def _exchange_trading_calendar_for_asset(self):
def exchange_trading_calendar(self):
"""
Get the calendar for this asset's exchange.
+5 -5
View File
@@ -99,7 +99,7 @@ class DailyHistoryAggregator(object):
session_label = self._trading_calendar.minute_to_session_label(dt)
for asset in assets:
if not asset._is_alive_for_session(session_label):
if not asset.is_alive_for_session(session_label):
opens.append(np.NaN)
continue
@@ -168,7 +168,7 @@ class DailyHistoryAggregator(object):
session_label = self._trading_calendar.minute_to_session_label(dt)
for asset in assets:
if not asset._is_alive_for_session(session_label):
if not asset.is_alive_for_session(session_label):
highs.append(np.NaN)
continue
@@ -237,7 +237,7 @@ class DailyHistoryAggregator(object):
session_label = self._trading_calendar.minute_to_session_label(dt)
for asset in assets:
if not asset._is_alive_for_session(session_label):
if not asset.is_alive_for_session(session_label):
lows.append(np.NaN)
continue
@@ -307,7 +307,7 @@ class DailyHistoryAggregator(object):
session_label = self._trading_calendar.minute_to_session_label(dt)
for asset in assets:
if not asset._is_alive_for_session(session_label):
if not asset.is_alive_for_session(session_label):
closes.append(np.NaN)
continue
@@ -367,7 +367,7 @@ class DailyHistoryAggregator(object):
session_label = self._trading_calendar.minute_to_session_label(dt)
for asset in assets:
if not asset._is_alive_for_session(session_label):
if not asset.is_alive_for_session(session_label):
volumes.append(0)
continue
+52 -48
View File
@@ -30,11 +30,10 @@ cpdef enum:
BEFORE_TRADING_START_BAR = 4
cdef class MinuteSimulationClock:
cdef object sessions
cdef bool minute_emission
cdef np.int64_t[:] market_opens, market_closes
cdef object before_trading_start_minutes
cdef dict minutes_by_session, minutes_to_session
cdef np.int64_t[:] market_opens_nanos, market_closes_nanos, bts_nanos, \
sessions_nanos
cdef dict minutes_by_session
def __init__(self,
sessions,
@@ -43,71 +42,76 @@ cdef class MinuteSimulationClock:
before_trading_start_minutes,
minute_emission=False):
self.minute_emission = minute_emission
self.market_opens = market_opens
self.market_closes = market_closes
self.sessions = sessions
self.market_opens_nanos = market_opens.values.astype(np.int64)
self.market_closes_nanos = market_closes.values.astype(np.int64)
self.sessions_nanos = sessions.values.astype(np.int64)
self.bts_nanos = before_trading_start_minutes.values.astype(np.int64)
self.minutes_by_session = self.calc_minutes_by_session()
self.before_trading_start_minutes = before_trading_start_minutes
@cython.boundscheck(False)
@cython.wraparound(False)
cdef np.ndarray[np.int64_t, ndim=1] market_minutes(self, np.intp_t i):
cdef np.int64_t[:] market_opens, market_closes
market_opens = self.market_opens
market_closes = self.market_closes
return np.arange(market_opens[i],
market_closes[i] + _nanos_in_minute,
_nanos_in_minute)
@cython.boundscheck(False)
@cython.wraparound(False)
cdef dict calc_minutes_by_session(self):
cdef dict minutes_by_session
cdef int session_idx
cdef object session
cdef np.int64_t session_nano
cdef np.ndarray[np.int64_t, ndim=1] minutes_nanos
minutes_by_session = {}
for session_idx, session in enumerate(self.sessions):
minutes_by_session[session] = pd.to_datetime(
self.market_minutes(session_idx), utc=True, box=True)
for session_idx, session_nano in enumerate(self.sessions_nanos):
minutes_nanos = np.arange(
self.market_opens_nanos[session_idx],
self.market_closes_nanos[session_idx] + _nanos_in_minute,
_nanos_in_minute
)
minutes_by_session[session_nano] = pd.to_datetime(
minutes_nanos, utc=True, box=True
)
return minutes_by_session
def __iter__(self):
minute_emission = self.minute_emission
for idx, session in enumerate(self.sessions):
yield session, SESSION_START
for idx, session_nano in enumerate(self.sessions_nanos):
yield pd.Timestamp(session_nano, tz='UTC'), SESSION_START
bts_minute = self.before_trading_start_minutes[idx]
regular_minutes = self.minutes_by_session[session]
bts_minute = pd.Timestamp(self.bts_nanos[idx], tz='UTC')
regular_minutes = self.minutes_by_session[session_nano]
# we have to search anew every session, because there is no
# guarantee that any two session start on the same minute
bts_idx = regular_minutes.searchsorted(bts_minute)
if bts_idx == len(regular_minutes):
# before_trading_start is after the last close, so don't emit
# it
for minute in regular_minutes:
yield minute, BAR
if minute_emission:
yield minute, MINUTE_END
if bts_minute > regular_minutes[-1]:
# before_trading_start is after the last close,
# so don't emit it
for minute, evt in self._get_minutes_for_list(
regular_minutes,
minute_emission
):
yield minute, evt
else:
# we have to search anew every session, because there is no
# guarantee that any two session start on the same minute
bts_idx = regular_minutes.searchsorted(bts_minute)
# emit all the minutes before bts_minute
for minute in regular_minutes[0:bts_idx]:
yield minute, BAR
if minute_emission:
yield minute, MINUTE_END
for minute, evt in self._get_minutes_for_list(
regular_minutes[0:bts_idx],
minute_emission
):
yield minute, evt
yield bts_minute, BEFORE_TRADING_START_BAR
# emit all the minutes after bts_minute
for minute in regular_minutes[bts_idx:]:
yield minute, BAR
if minute_emission:
yield minute, MINUTE_END
for minute, evt in self._get_minutes_for_list(
regular_minutes[bts_idx:],
minute_emission
):
yield minute, evt
yield regular_minutes[-1], SESSION_END
def _get_minutes_for_list(self, minutes, minute_emission):
for minute in minutes:
yield minute, BAR
if minute_emission:
yield minute, MINUTE_END