mirror of
https://github.com/wassname/catalyst.git
synced 2026-06-28 17:00:51 +08:00
ENH: Adds the option to force calendar registration
This commit is contained in:
@@ -19,6 +19,7 @@ from os.path import (
|
||||
join,
|
||||
)
|
||||
from unittest import TestCase
|
||||
from collections import namedtuple
|
||||
|
||||
import pandas as pd
|
||||
import pytz
|
||||
@@ -31,7 +32,57 @@ from pandas import (
|
||||
)
|
||||
from pandas.util.testing import assert_frame_equal
|
||||
|
||||
from zipline.errors import (
|
||||
CalendarNameCollision,
|
||||
InvalidCalendarName,
|
||||
)
|
||||
from zipline.utils.calendars.exchange_calendar_nyse import NYSEExchangeCalendar
|
||||
from zipline.utils.calendars.exchange_calendar import(
|
||||
register_calendar,
|
||||
deregister_calendar,
|
||||
get_calendar,
|
||||
clear_calendars,
|
||||
)
|
||||
|
||||
|
||||
class CalendarRegistrationTestCase(TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.dummy_cal_type = namedtuple('DummyCal', ('name'))
|
||||
|
||||
def tearDown(self):
|
||||
clear_calendars()
|
||||
|
||||
def test_register_calendar(self):
|
||||
# Build a fake calendar
|
||||
dummy_cal = self.dummy_cal_type('DMY')
|
||||
|
||||
# Try to register and retrieve the calendar
|
||||
register_calendar(dummy_cal)
|
||||
retr_cal = get_calendar('DMY')
|
||||
self.assertEqual(dummy_cal, retr_cal)
|
||||
|
||||
# Try to register again, expecting a name collision
|
||||
with self.assertRaises(CalendarNameCollision):
|
||||
register_calendar(dummy_cal)
|
||||
|
||||
# Deregister the calendar and ensure that it is removed
|
||||
deregister_calendar('DMY')
|
||||
with self.assertRaises(InvalidCalendarName):
|
||||
get_calendar('DMY')
|
||||
|
||||
def test_force_registration(self):
|
||||
dummy_nyse = self.dummy_cal_type('NYSE')
|
||||
|
||||
# Get the actual NYSE calendar
|
||||
real_nyse = get_calendar('NYSE')
|
||||
|
||||
# Force a registration of the dummy NYSE
|
||||
register_calendar(dummy_nyse, force=True)
|
||||
|
||||
# Ensure that the dummy overwrote the real calendar
|
||||
retr_cal = get_calendar('NYSE')
|
||||
self.assertNotEqual(real_nyse, retr_cal)
|
||||
|
||||
|
||||
class ExchangeCalendarTestBase(object):
|
||||
|
||||
@@ -496,8 +496,52 @@ def get_calendar(name):
|
||||
return _static_calendars[name]
|
||||
|
||||
|
||||
def register_calendar(calendar):
|
||||
def deregister_calendar(cal_name):
|
||||
"""
|
||||
If a calendar is registered with the given name, it is de-registered.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
cal_name : str
|
||||
The name of the calendar to be deregistered.
|
||||
"""
|
||||
try:
|
||||
_static_calendars.pop(cal_name)
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
|
||||
def clear_calendars():
|
||||
"""
|
||||
Deregisters all current registered calendars
|
||||
"""
|
||||
_static_calendars.clear()
|
||||
|
||||
|
||||
def register_calendar(calendar, force=False):
|
||||
"""
|
||||
Registers a calendar for retrieval by the get_calendar method.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
calendar : ExchangeCalendar
|
||||
The calendar to be registered for retrieval.
|
||||
force : bool, optional
|
||||
If True, old calendars will be overwritten on a name collision.
|
||||
If False, name collisions will raise an exception. Default: False.
|
||||
|
||||
Raises
|
||||
------
|
||||
CalendarNameCollision
|
||||
If a calendar is already registered with the given calendar's name.
|
||||
"""
|
||||
# If we are forcing the registration, remove an existing calendar with the
|
||||
# same name.
|
||||
if force:
|
||||
deregister_calendar(calendar.name)
|
||||
|
||||
# Check if we are already holding a calendar with the same name
|
||||
if calendar.name in _static_calendars:
|
||||
raise CalendarNameCollision(calendar_name=calendar.name)
|
||||
|
||||
_static_calendars[calendar.name] = calendar
|
||||
|
||||
Reference in New Issue
Block a user