mirror of
https://github.com/wassname/catalyst.git
synced 2026-06-28 17:00:51 +08:00
ENH: Adds asset db downgrade management and tests
This commit is contained in:
@@ -57,3 +57,6 @@ Markdown==2.6.2
|
||||
futures==3.0.3
|
||||
requests-futures==0.9.5
|
||||
piprot==0.9.1
|
||||
|
||||
# For asset db management
|
||||
alembic==0.7.7
|
||||
|
||||
@@ -55,6 +55,9 @@ from zipline.assets.asset_db_schema import (
|
||||
ASSET_DB_VERSION,
|
||||
_version_table_schema,
|
||||
)
|
||||
from zipline.assets.asset_db_migrations import (
|
||||
downgrade
|
||||
)
|
||||
from zipline.errors import (
|
||||
EquitiesNotFound,
|
||||
FutureContractsNotFound,
|
||||
@@ -64,6 +67,7 @@ from zipline.errors import (
|
||||
SidAssignmentError,
|
||||
SidsNotFound,
|
||||
SymbolNotFound,
|
||||
AssetDBImpossibleDowngrade,
|
||||
)
|
||||
from zipline.finance.trading import TradingEnvironment, noop_load
|
||||
from zipline.utils.test_utils import (
|
||||
@@ -1365,3 +1369,30 @@ class TestAssetDBVersioning(TestCase):
|
||||
|
||||
# Now that the versions match, this Finder should succeed
|
||||
AssetFinder(engine=env.engine)
|
||||
|
||||
def test_downgrade(self):
|
||||
# Attempt to downgrade a current assets db all the way down to v0
|
||||
env = TradingEnvironment(load=noop_load)
|
||||
conn = env.engine.connect()
|
||||
downgrade(env.engine, 0)
|
||||
|
||||
# Verify that the db version is now 0
|
||||
metadata = sa.MetaData(conn)
|
||||
metadata.reflect(bind=env.engine)
|
||||
version_table = metadata.tables['version_info']
|
||||
check_version_info(version_table, 0)
|
||||
|
||||
# Check some of the v1-to-v0 downgrades
|
||||
self.assertTrue('futures_contracts' in metadata.tables)
|
||||
self.assertTrue('version_info' in metadata.tables)
|
||||
self.assertFalse('tick_size' in
|
||||
metadata.tables['futures_contracts'].columns)
|
||||
self.assertTrue('contract_multiplier' in
|
||||
metadata.tables['futures_contracts'].columns)
|
||||
|
||||
def test_impossible_downgrade(self):
|
||||
# Attempt to downgrade a current assets db to a
|
||||
# higher-than-current version
|
||||
env = TradingEnvironment(load=noop_load)
|
||||
with self.assertRaises(AssetDBImpossibleDowngrade):
|
||||
downgrade(env.engine, ASSET_DB_VERSION + 5)
|
||||
|
||||
@@ -0,0 +1,109 @@
|
||||
import sqlalchemy as sa
|
||||
from alembic.migration import MigrationContext
|
||||
from alembic.operations import Operations
|
||||
|
||||
from zipline.assets.asset_writer import write_version_info
|
||||
from zipline.errors import AssetDBImpossibleDowngrade
|
||||
|
||||
|
||||
def downgrade(engine, desired_version):
|
||||
"""Downgrades the assets db at the given engine to the desired version.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
engine : Engine
|
||||
An SQLAlchemy engine to the assets database.
|
||||
desired_version : int
|
||||
The desired resulting version for the assets database.
|
||||
"""
|
||||
|
||||
# Check the version of the db at the engine
|
||||
conn = engine.connect()
|
||||
metadata = sa.MetaData(conn)
|
||||
metadata.reflect(bind=engine)
|
||||
version_info_table = metadata.tables['version_info']
|
||||
starting_version = sa.select((version_info_table.c.version,)).scalar()
|
||||
|
||||
# Check for accidental upgrade
|
||||
if starting_version < desired_version:
|
||||
raise AssetDBImpossibleDowngrade(db_version=starting_version,
|
||||
desired_version=desired_version)
|
||||
|
||||
# Check if the desired version is already the db version
|
||||
if starting_version == desired_version:
|
||||
# No downgrade needed
|
||||
return
|
||||
|
||||
# Create alembic context
|
||||
ctx = MigrationContext.configure(conn)
|
||||
op = Operations(ctx)
|
||||
|
||||
# Integer keys of downgrades to run
|
||||
# E.g.: [5, 4, 3, 2] would downgrade v6 to v2
|
||||
downgrade_keys = range(desired_version, starting_version)[::-1]
|
||||
|
||||
# Disable foreign keys until all downgrades are complete
|
||||
_pragma_foreign_keys(conn, False)
|
||||
|
||||
# Execute the downgrades in order
|
||||
for downgrade_key in downgrade_keys:
|
||||
_downgrade_methods[downgrade_key](op, version_info_table)
|
||||
|
||||
# Re-enable foreign keys
|
||||
_pragma_foreign_keys(conn, True)
|
||||
|
||||
|
||||
def _pragma_foreign_keys(connection, on):
|
||||
"""Sets the PRAGMA foreign_keys state of the SQLLite database. Disabling
|
||||
the pragma allows for batch modification of tables with foreign keys.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
connection : Connection
|
||||
A SQLAlchemy connection to the db
|
||||
on : bool
|
||||
If true, PRAGMA foreign_keys will be set to ON. Otherwise, the PRAGMA
|
||||
foreign_keys will be set to OFF.
|
||||
"""
|
||||
connection.execute("PRAGMA foreign_keys=%s" % ("ON" if on else "OFF"))
|
||||
|
||||
|
||||
def _downgrade_v1_to_v0(op, version_info_table):
|
||||
"""
|
||||
Downgrade assets db by removing the 'tick_size' column and renaming the
|
||||
'multiplier' column.
|
||||
"""
|
||||
version_info_table.delete().execute()
|
||||
|
||||
# Drop indices before batch
|
||||
# This is to prevent index collision when creating the temp table
|
||||
op.drop_index('ix_futures_contracts_root_symbol')
|
||||
op.drop_index('ix_futures_contracts_symbol')
|
||||
|
||||
# Execute batch op to allow column modification in SQLLite
|
||||
with op.batch_alter_table('futures_contracts') as batch_op:
|
||||
|
||||
# Rename 'multiplier'
|
||||
batch_op.alter_column(column_name='multiplier',
|
||||
new_column_name='contract_multiplier')
|
||||
|
||||
# Delete 'tick_size'
|
||||
batch_op.drop_column('tick_size')
|
||||
|
||||
# Recreate indices after batch
|
||||
op.create_index('ix_futures_contracts_root_symbol',
|
||||
table_name='futures_contracts',
|
||||
columns=['root_symbol'])
|
||||
op.create_index('ix_futures_contracts_symbol',
|
||||
table_name='futures_contracts',
|
||||
columns=['symbol'],
|
||||
unique=True)
|
||||
|
||||
write_version_info(version_info_table, 0)
|
||||
|
||||
# This dict contains references to downgrade methods that can be applied to an
|
||||
# assets db. The resulting db's version is the key.
|
||||
# e.g. The method at key '0' is the downgrade method from v1 to v0
|
||||
_downgrade_methods = {
|
||||
0: _downgrade_v1_to_v0,
|
||||
}
|
||||
@@ -507,3 +507,10 @@ class AssetDBVersionError(ZiplineError):
|
||||
"Expected version: {expected_version}. Try rebuilding your asset "
|
||||
"database or updating your version of Zipline."
|
||||
)
|
||||
|
||||
|
||||
class AssetDBImpossibleDowngrade(ZiplineError):
|
||||
msg = (
|
||||
"The existing Asset database is version: {db_version} which is lower "
|
||||
"than the desired downgrade version: {desired_version}."
|
||||
)
|
||||
|
||||
@@ -126,8 +126,8 @@ def calc_period_stats(pos_stats, ending_cash):
|
||||
net_leverage=net_leverage)
|
||||
|
||||
|
||||
def calc_payout(contract_multiplier, amount, old_price, price):
|
||||
return (price - old_price) * contract_multiplier * amount
|
||||
def calc_payout(multiplier, amount, old_price, price):
|
||||
return (price - old_price) * multiplier * amount
|
||||
|
||||
|
||||
class PerformancePeriod(object):
|
||||
@@ -235,7 +235,7 @@ class PerformancePeriod(object):
|
||||
pos = positions[asset]
|
||||
amount = pos.amount
|
||||
payout = calc_payout(
|
||||
asset.contract_multiplier,
|
||||
asset.multiplier,
|
||||
amount,
|
||||
old_price,
|
||||
pos.last_sale_price)
|
||||
@@ -288,7 +288,7 @@ class PerformancePeriod(object):
|
||||
amount = pos.amount
|
||||
price = txn.price
|
||||
cash_adj = calc_payout(
|
||||
asset.contract_multiplier, amount, old_price, price)
|
||||
asset.multiplier, amount, old_price, price)
|
||||
self.adjust_cash(cash_adj)
|
||||
if amount + txn.amount == 0:
|
||||
del self._payout_last_sale_prices[asset]
|
||||
|
||||
Reference in New Issue
Block a user