mirror of
https://github.com/wassname/catalyst.git
synced 2026-06-27 19:30:28 +08:00
MAINT: PR cleanup
This commit is contained in:
+17
-14
@@ -12,6 +12,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import warnings
|
||||
from collections import namedtuple
|
||||
import datetime
|
||||
from datetime import timedelta
|
||||
@@ -34,7 +35,6 @@ import pytz
|
||||
from pandas.io.common import PerformanceWarning
|
||||
|
||||
from zipline import run_algorithm
|
||||
from tests.warnings_catcher import WarningsCatcher
|
||||
from zipline import TradingAlgorithm
|
||||
from zipline.api import FixedSlippage
|
||||
from zipline.assets import Equity, Future
|
||||
@@ -1959,7 +1959,9 @@ def handle_data(context, data):
|
||||
pass
|
||||
""")
|
||||
|
||||
with WarningsCatcher([PerformanceWarning]) as w:
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
warnings.simplefilter("ignore", PerformanceWarning)
|
||||
|
||||
algo = TradingAlgorithm(
|
||||
script=algocode,
|
||||
sim_params=sim_params,
|
||||
@@ -1967,18 +1969,19 @@ def handle_data(context, data):
|
||||
)
|
||||
algo.run(self.data_portal)
|
||||
|
||||
self.assertEqual(len(w), 2)
|
||||
for i, warning in enumerate(w):
|
||||
self.assertIsInstance(warning.message, UserWarning)
|
||||
self.assertEqual(
|
||||
warning.message.args[0],
|
||||
'Got a time rule for the second positional argument '
|
||||
'date_rule. You should use keyword argument '
|
||||
'time_rule= when calling schedule_function without '
|
||||
'specifying a date_rule'
|
||||
)
|
||||
# The warnings come from line 13 and 14 in the algocode
|
||||
self.assertEqual(warning.lineno, 13 + i)
|
||||
self.assertEqual(len(w), 2)
|
||||
|
||||
for i, warning in enumerate(w):
|
||||
self.assertIsInstance(warning.message, UserWarning)
|
||||
self.assertEqual(
|
||||
warning.message.args[0],
|
||||
'Got a time rule for the second positional argument '
|
||||
'date_rule. You should use keyword argument '
|
||||
'time_rule= when calling schedule_function without '
|
||||
'specifying a date_rule'
|
||||
)
|
||||
# The warnings come from line 13 and 14 in the algocode
|
||||
self.assertEqual(warning.lineno, 13 + i)
|
||||
|
||||
self.assertEqual(
|
||||
algo.done_at_open,
|
||||
|
||||
+12
-7
@@ -5,7 +5,6 @@ import numpy as np
|
||||
import pandas as pd
|
||||
from pandas.io.common import PerformanceWarning
|
||||
|
||||
from tests.warnings_catcher import WarningsCatcher
|
||||
from zipline import TradingAlgorithm
|
||||
from zipline.finance.trading import SimulationParameters
|
||||
from zipline.protocol import BarData
|
||||
@@ -292,7 +291,8 @@ class TestAPIShim(WithDataPortal, WithSimParams, ZiplineTestCase):
|
||||
cease to be supported, we also want to assert that we're seeing a
|
||||
deprecation warning.
|
||||
"""
|
||||
with WarningsCatcher([PerformanceWarning]) as w:
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
warnings.simplefilter("ignore", PerformanceWarning)
|
||||
warnings.simplefilter("default", ZiplineDeprecationWarning)
|
||||
algo = self.create_algo(sid_accessor_algo)
|
||||
algo.run(self.data_portal)
|
||||
@@ -320,7 +320,8 @@ class TestAPIShim(WithDataPortal, WithSimParams, ZiplineTestCase):
|
||||
We also want to assert that we warn that iterating over the assets
|
||||
in `data` is deprecated.
|
||||
"""
|
||||
with WarningsCatcher([PerformanceWarning]) as w:
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
warnings.simplefilter("ignore", PerformanceWarning)
|
||||
warnings.simplefilter("default", ZiplineDeprecationWarning)
|
||||
algo = self.create_algo(data_items_algo)
|
||||
algo.run(self.data_portal)
|
||||
@@ -344,7 +345,8 @@ class TestAPIShim(WithDataPortal, WithSimParams, ZiplineTestCase):
|
||||
)
|
||||
|
||||
def test_iterate_data(self):
|
||||
with WarningsCatcher([PerformanceWarning]) as w:
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
warnings.simplefilter("ignore", PerformanceWarning)
|
||||
warnings.simplefilter("default", ZiplineDeprecationWarning)
|
||||
|
||||
algo = self.create_algo(simple_algo)
|
||||
@@ -374,7 +376,8 @@ class TestAPIShim(WithDataPortal, WithSimParams, ZiplineTestCase):
|
||||
)
|
||||
|
||||
def test_history(self):
|
||||
with WarningsCatcher([PerformanceWarning]) as w:
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
warnings.simplefilter("ignore", PerformanceWarning)
|
||||
warnings.simplefilter("default", ZiplineDeprecationWarning)
|
||||
|
||||
sim_params = self.sim_params.create_new(
|
||||
@@ -415,7 +418,8 @@ class TestAPIShim(WithDataPortal, WithSimParams, ZiplineTestCase):
|
||||
expected_vol_with_split)
|
||||
|
||||
def test_simple_transforms(self):
|
||||
with WarningsCatcher([PerformanceWarning]) as w:
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
warnings.simplefilter("ignore", PerformanceWarning)
|
||||
warnings.simplefilter("default", ZiplineDeprecationWarning)
|
||||
|
||||
sim_params = SimulationParameters(
|
||||
@@ -485,7 +489,8 @@ class TestAPIShim(WithDataPortal, WithSimParams, ZiplineTestCase):
|
||||
self.assertAlmostEqual(346, algo.returns)
|
||||
|
||||
def test_manipulation(self):
|
||||
with WarningsCatcher([PerformanceWarning]) as w:
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
warnings.simplefilter("ignore", PerformanceWarning)
|
||||
warnings.simplefilter("default", ZiplineDeprecationWarning)
|
||||
|
||||
algo = self.create_algo(simple_algo)
|
||||
|
||||
+2
-3
@@ -1,7 +1,6 @@
|
||||
from datetime import time
|
||||
from unittest import TestCase
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from zipline.gens.sim_engine import (
|
||||
MinuteSimulationClock,
|
||||
SESSION_START,
|
||||
@@ -26,8 +25,8 @@ class TestClock(TestCase):
|
||||
)
|
||||
|
||||
trading_o_and_c = cls.nyse_calendar.schedule.ix[cls.sessions]
|
||||
cls.opens = trading_o_and_c['market_open'].values.astype(np.int64)
|
||||
cls.closes = trading_o_and_c['market_close'].values.astype(np.int64)
|
||||
cls.opens = trading_o_and_c['market_open']
|
||||
cls.closes = trading_o_and_c['market_close']
|
||||
|
||||
def test_bts_before_session(self):
|
||||
clock = MinuteSimulationClock(
|
||||
|
||||
@@ -120,6 +120,7 @@ class MinuteToDailyAggregationTestCase(WithBcolzEquityMinuteBarReader,
|
||||
self.equity_daily_aggregator = DailyHistoryAggregator(
|
||||
self.trading_calendar.schedule.market_open,
|
||||
self.bcolz_equity_minute_bar_reader,
|
||||
self.trading_calendar
|
||||
)
|
||||
|
||||
@parameterized.expand([
|
||||
|
||||
@@ -1,32 +0,0 @@
|
||||
from warnings import catch_warnings, WarningMessage
|
||||
|
||||
|
||||
class WarningsCatcher(catch_warnings):
|
||||
"""
|
||||
Subclass of warnings.catch_warnings that takes a list of warning types to
|
||||
ignore.
|
||||
"""
|
||||
def __init__(self, types_to_ignore=None):
|
||||
super(WarningsCatcher, self).__init__(record=True)
|
||||
|
||||
self._types_to_ignore = set(types_to_ignore or [])
|
||||
|
||||
def __enter__(self):
|
||||
if self._entered:
|
||||
raise RuntimeError("Cannot enter %r twice" % self)
|
||||
self._entered = True
|
||||
self._filters = self._module.filters
|
||||
self._module.filters = self._filters[:]
|
||||
self._showwarning = self._module.showwarning
|
||||
if self._record:
|
||||
log = []
|
||||
|
||||
def showwarning(*args, **kwargs):
|
||||
if args[1] in self._types_to_ignore:
|
||||
return
|
||||
log.append(WarningMessage(*args, **kwargs))
|
||||
|
||||
self._module.showwarning = showwarning
|
||||
return log
|
||||
else:
|
||||
return None
|
||||
@@ -463,11 +463,11 @@ cdef class BarData:
|
||||
|
||||
cdef bool _can_trade_for_asset(self, asset, dt, adjusted_dt, data_portal):
|
||||
session_label = normalize_date(dt) # FIXME
|
||||
if not asset._is_alive_for_session(session_label):
|
||||
if not asset.is_alive_for_session(session_label):
|
||||
# asset isn't alive
|
||||
return False
|
||||
|
||||
if not asset._asset_exchange_open(dt):
|
||||
if not asset.is_exchange_open(dt):
|
||||
# exchange isn't open
|
||||
return False
|
||||
|
||||
@@ -525,7 +525,7 @@ cdef class BarData:
|
||||
cdef bool _is_stale_for_asset(self, asset, dt, adjusted_dt, data_portal):
|
||||
session_label = normalize_date(dt) # FIXME
|
||||
|
||||
if not asset._is_alive_for_session(session_label):
|
||||
if not asset.is_alive_for_session(session_label):
|
||||
return False
|
||||
|
||||
current_volume = data_portal.get_spot_value(
|
||||
|
||||
@@ -496,13 +496,11 @@ class TradingAlgorithm(object):
|
||||
"""
|
||||
trading_o_and_c = self.trading_calendar.schedule.ix[
|
||||
self.sim_params.sessions]
|
||||
market_closes = trading_o_and_c['market_close'].values.astype(np.int64)
|
||||
market_closes = trading_o_and_c['market_close']
|
||||
minutely_emission = False
|
||||
|
||||
if self.sim_params.data_frequency == 'minute':
|
||||
market_opens = trading_o_and_c['market_open'].values.astype(
|
||||
np.int64
|
||||
)
|
||||
market_opens = trading_o_and_c['market_open']
|
||||
|
||||
minutely_emission = self.sim_params.emission_rate == "minute"
|
||||
else:
|
||||
|
||||
@@ -101,8 +101,6 @@ cdef class Asset:
|
||||
self.first_traded = first_traded
|
||||
self.auto_close_date = auto_close_date
|
||||
|
||||
|
||||
|
||||
def __int__(self):
|
||||
return self.sid
|
||||
|
||||
@@ -201,7 +199,7 @@ cdef class Asset:
|
||||
"""
|
||||
return cls(**dict_)
|
||||
|
||||
def _is_alive_for_session(self, session_label):
|
||||
def is_alive_for_session(self, session_label):
|
||||
"""
|
||||
Returns whether the asset is alive at the given dt.
|
||||
|
||||
@@ -222,7 +220,7 @@ cdef class Asset:
|
||||
|
||||
return ref_start <= session_label.value <= ref_end
|
||||
|
||||
def _asset_exchange_open(self, dt_minute):
|
||||
def is_exchange_open(self, dt_minute):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
@@ -233,10 +231,10 @@ cdef class Asset:
|
||||
-------
|
||||
boolean: whether the asset's exchange is open at the given minute.
|
||||
"""
|
||||
calendar = self._exchange_trading_calendar_for_asset()
|
||||
calendar = self.exchange_trading_calendar()
|
||||
return calendar.is_open_on_minute(dt_minute)
|
||||
|
||||
def _exchange_trading_calendar_for_asset(self):
|
||||
def exchange_trading_calendar(self):
|
||||
"""
|
||||
Get the calendar for this asset's exchange.
|
||||
|
||||
|
||||
@@ -99,7 +99,7 @@ class DailyHistoryAggregator(object):
|
||||
session_label = self._trading_calendar.minute_to_session_label(dt)
|
||||
|
||||
for asset in assets:
|
||||
if not asset._is_alive_for_session(session_label):
|
||||
if not asset.is_alive_for_session(session_label):
|
||||
opens.append(np.NaN)
|
||||
continue
|
||||
|
||||
@@ -168,7 +168,7 @@ class DailyHistoryAggregator(object):
|
||||
session_label = self._trading_calendar.minute_to_session_label(dt)
|
||||
|
||||
for asset in assets:
|
||||
if not asset._is_alive_for_session(session_label):
|
||||
if not asset.is_alive_for_session(session_label):
|
||||
highs.append(np.NaN)
|
||||
continue
|
||||
|
||||
@@ -237,7 +237,7 @@ class DailyHistoryAggregator(object):
|
||||
session_label = self._trading_calendar.minute_to_session_label(dt)
|
||||
|
||||
for asset in assets:
|
||||
if not asset._is_alive_for_session(session_label):
|
||||
if not asset.is_alive_for_session(session_label):
|
||||
lows.append(np.NaN)
|
||||
continue
|
||||
|
||||
@@ -307,7 +307,7 @@ class DailyHistoryAggregator(object):
|
||||
session_label = self._trading_calendar.minute_to_session_label(dt)
|
||||
|
||||
for asset in assets:
|
||||
if not asset._is_alive_for_session(session_label):
|
||||
if not asset.is_alive_for_session(session_label):
|
||||
closes.append(np.NaN)
|
||||
continue
|
||||
|
||||
@@ -367,7 +367,7 @@ class DailyHistoryAggregator(object):
|
||||
session_label = self._trading_calendar.minute_to_session_label(dt)
|
||||
|
||||
for asset in assets:
|
||||
if not asset._is_alive_for_session(session_label):
|
||||
if not asset.is_alive_for_session(session_label):
|
||||
volumes.append(0)
|
||||
continue
|
||||
|
||||
|
||||
+52
-48
@@ -30,11 +30,10 @@ cpdef enum:
|
||||
BEFORE_TRADING_START_BAR = 4
|
||||
|
||||
cdef class MinuteSimulationClock:
|
||||
cdef object sessions
|
||||
cdef bool minute_emission
|
||||
cdef np.int64_t[:] market_opens, market_closes
|
||||
cdef object before_trading_start_minutes
|
||||
cdef dict minutes_by_session, minutes_to_session
|
||||
cdef np.int64_t[:] market_opens_nanos, market_closes_nanos, bts_nanos, \
|
||||
sessions_nanos
|
||||
cdef dict minutes_by_session
|
||||
|
||||
def __init__(self,
|
||||
sessions,
|
||||
@@ -43,71 +42,76 @@ cdef class MinuteSimulationClock:
|
||||
before_trading_start_minutes,
|
||||
minute_emission=False):
|
||||
self.minute_emission = minute_emission
|
||||
self.market_opens = market_opens
|
||||
self.market_closes = market_closes
|
||||
self.sessions = sessions
|
||||
|
||||
self.market_opens_nanos = market_opens.values.astype(np.int64)
|
||||
self.market_closes_nanos = market_closes.values.astype(np.int64)
|
||||
self.sessions_nanos = sessions.values.astype(np.int64)
|
||||
self.bts_nanos = before_trading_start_minutes.values.astype(np.int64)
|
||||
|
||||
self.minutes_by_session = self.calc_minutes_by_session()
|
||||
|
||||
self.before_trading_start_minutes = before_trading_start_minutes
|
||||
|
||||
@cython.boundscheck(False)
|
||||
@cython.wraparound(False)
|
||||
cdef np.ndarray[np.int64_t, ndim=1] market_minutes(self, np.intp_t i):
|
||||
cdef np.int64_t[:] market_opens, market_closes
|
||||
|
||||
market_opens = self.market_opens
|
||||
market_closes = self.market_closes
|
||||
|
||||
return np.arange(market_opens[i],
|
||||
market_closes[i] + _nanos_in_minute,
|
||||
_nanos_in_minute)
|
||||
|
||||
@cython.boundscheck(False)
|
||||
@cython.wraparound(False)
|
||||
cdef dict calc_minutes_by_session(self):
|
||||
cdef dict minutes_by_session
|
||||
cdef int session_idx
|
||||
cdef object session
|
||||
cdef np.int64_t session_nano
|
||||
cdef np.ndarray[np.int64_t, ndim=1] minutes_nanos
|
||||
|
||||
minutes_by_session = {}
|
||||
for session_idx, session in enumerate(self.sessions):
|
||||
minutes_by_session[session] = pd.to_datetime(
|
||||
self.market_minutes(session_idx), utc=True, box=True)
|
||||
for session_idx, session_nano in enumerate(self.sessions_nanos):
|
||||
minutes_nanos = np.arange(
|
||||
self.market_opens_nanos[session_idx],
|
||||
self.market_closes_nanos[session_idx] + _nanos_in_minute,
|
||||
_nanos_in_minute
|
||||
)
|
||||
minutes_by_session[session_nano] = pd.to_datetime(
|
||||
minutes_nanos, utc=True, box=True
|
||||
)
|
||||
return minutes_by_session
|
||||
|
||||
def __iter__(self):
|
||||
minute_emission = self.minute_emission
|
||||
|
||||
for idx, session in enumerate(self.sessions):
|
||||
yield session, SESSION_START
|
||||
for idx, session_nano in enumerate(self.sessions_nanos):
|
||||
yield pd.Timestamp(session_nano, tz='UTC'), SESSION_START
|
||||
|
||||
bts_minute = self.before_trading_start_minutes[idx]
|
||||
regular_minutes = self.minutes_by_session[session]
|
||||
bts_minute = pd.Timestamp(self.bts_nanos[idx], tz='UTC')
|
||||
regular_minutes = self.minutes_by_session[session_nano]
|
||||
|
||||
# we have to search anew every session, because there is no
|
||||
# guarantee that any two session start on the same minute
|
||||
bts_idx = regular_minutes.searchsorted(bts_minute)
|
||||
|
||||
if bts_idx == len(regular_minutes):
|
||||
# before_trading_start is after the last close, so don't emit
|
||||
# it
|
||||
for minute in regular_minutes:
|
||||
yield minute, BAR
|
||||
if minute_emission:
|
||||
yield minute, MINUTE_END
|
||||
if bts_minute > regular_minutes[-1]:
|
||||
# before_trading_start is after the last close,
|
||||
# so don't emit it
|
||||
for minute, evt in self._get_minutes_for_list(
|
||||
regular_minutes,
|
||||
minute_emission
|
||||
):
|
||||
yield minute, evt
|
||||
else:
|
||||
# we have to search anew every session, because there is no
|
||||
# guarantee that any two session start on the same minute
|
||||
bts_idx = regular_minutes.searchsorted(bts_minute)
|
||||
|
||||
# emit all the minutes before bts_minute
|
||||
for minute in regular_minutes[0:bts_idx]:
|
||||
yield minute, BAR
|
||||
if minute_emission:
|
||||
yield minute, MINUTE_END
|
||||
for minute, evt in self._get_minutes_for_list(
|
||||
regular_minutes[0:bts_idx],
|
||||
minute_emission
|
||||
):
|
||||
yield minute, evt
|
||||
|
||||
yield bts_minute, BEFORE_TRADING_START_BAR
|
||||
|
||||
# emit all the minutes after bts_minute
|
||||
for minute in regular_minutes[bts_idx:]:
|
||||
yield minute, BAR
|
||||
if minute_emission:
|
||||
yield minute, MINUTE_END
|
||||
for minute, evt in self._get_minutes_for_list(
|
||||
regular_minutes[bts_idx:],
|
||||
minute_emission
|
||||
):
|
||||
yield minute, evt
|
||||
|
||||
yield regular_minutes[-1], SESSION_END
|
||||
|
||||
def _get_minutes_for_list(self, minutes, minute_emission):
|
||||
for minute in minutes:
|
||||
yield minute, BAR
|
||||
if minute_emission:
|
||||
yield minute, MINUTE_END
|
||||
|
||||
Reference in New Issue
Block a user