ENH: Adds asset db downgrade management and tests

This commit is contained in:
jfkirk
2016-01-06 13:31:25 -05:00
parent db1e62971a
commit ece9e59ef9
5 changed files with 154 additions and 4 deletions
+3
View File
@@ -57,3 +57,6 @@ Markdown==2.6.2
futures==3.0.3
requests-futures==0.9.5
piprot==0.9.1
# For asset db management
alembic==0.7.7
+31
View File
@@ -55,6 +55,9 @@ from zipline.assets.asset_db_schema import (
ASSET_DB_VERSION,
_version_table_schema,
)
from zipline.assets.asset_db_migrations import (
downgrade
)
from zipline.errors import (
EquitiesNotFound,
FutureContractsNotFound,
@@ -64,6 +67,7 @@ from zipline.errors import (
SidAssignmentError,
SidsNotFound,
SymbolNotFound,
AssetDBImpossibleDowngrade,
)
from zipline.finance.trading import TradingEnvironment, noop_load
from zipline.utils.test_utils import (
@@ -1365,3 +1369,30 @@ class TestAssetDBVersioning(TestCase):
# Now that the versions match, this Finder should succeed
AssetFinder(engine=env.engine)
def test_downgrade(self):
# Attempt to downgrade a current assets db all the way down to v0
env = TradingEnvironment(load=noop_load)
conn = env.engine.connect()
downgrade(env.engine, 0)
# Verify that the db version is now 0
metadata = sa.MetaData(conn)
metadata.reflect(bind=env.engine)
version_table = metadata.tables['version_info']
check_version_info(version_table, 0)
# Check some of the v1-to-v0 downgrades
self.assertTrue('futures_contracts' in metadata.tables)
self.assertTrue('version_info' in metadata.tables)
self.assertFalse('tick_size' in
metadata.tables['futures_contracts'].columns)
self.assertTrue('contract_multiplier' in
metadata.tables['futures_contracts'].columns)
def test_impossible_downgrade(self):
# Attempt to downgrade a current assets db to a
# higher-than-current version
env = TradingEnvironment(load=noop_load)
with self.assertRaises(AssetDBImpossibleDowngrade):
downgrade(env.engine, ASSET_DB_VERSION + 5)
+109
View File
@@ -0,0 +1,109 @@
import sqlalchemy as sa
from alembic.migration import MigrationContext
from alembic.operations import Operations
from zipline.assets.asset_writer import write_version_info
from zipline.errors import AssetDBImpossibleDowngrade
def downgrade(engine, desired_version):
"""Downgrades the assets db at the given engine to the desired version.
Parameters
----------
engine : Engine
An SQLAlchemy engine to the assets database.
desired_version : int
The desired resulting version for the assets database.
"""
# Check the version of the db at the engine
conn = engine.connect()
metadata = sa.MetaData(conn)
metadata.reflect(bind=engine)
version_info_table = metadata.tables['version_info']
starting_version = sa.select((version_info_table.c.version,)).scalar()
# Check for accidental upgrade
if starting_version < desired_version:
raise AssetDBImpossibleDowngrade(db_version=starting_version,
desired_version=desired_version)
# Check if the desired version is already the db version
if starting_version == desired_version:
# No downgrade needed
return
# Create alembic context
ctx = MigrationContext.configure(conn)
op = Operations(ctx)
# Integer keys of downgrades to run
# E.g.: [5, 4, 3, 2] would downgrade v6 to v2
downgrade_keys = range(desired_version, starting_version)[::-1]
# Disable foreign keys until all downgrades are complete
_pragma_foreign_keys(conn, False)
# Execute the downgrades in order
for downgrade_key in downgrade_keys:
_downgrade_methods[downgrade_key](op, version_info_table)
# Re-enable foreign keys
_pragma_foreign_keys(conn, True)
def _pragma_foreign_keys(connection, on):
"""Sets the PRAGMA foreign_keys state of the SQLLite database. Disabling
the pragma allows for batch modification of tables with foreign keys.
Parameters
----------
connection : Connection
A SQLAlchemy connection to the db
on : bool
If true, PRAGMA foreign_keys will be set to ON. Otherwise, the PRAGMA
foreign_keys will be set to OFF.
"""
connection.execute("PRAGMA foreign_keys=%s" % ("ON" if on else "OFF"))
def _downgrade_v1_to_v0(op, version_info_table):
"""
Downgrade assets db by removing the 'tick_size' column and renaming the
'multiplier' column.
"""
version_info_table.delete().execute()
# Drop indices before batch
# This is to prevent index collision when creating the temp table
op.drop_index('ix_futures_contracts_root_symbol')
op.drop_index('ix_futures_contracts_symbol')
# Execute batch op to allow column modification in SQLLite
with op.batch_alter_table('futures_contracts') as batch_op:
# Rename 'multiplier'
batch_op.alter_column(column_name='multiplier',
new_column_name='contract_multiplier')
# Delete 'tick_size'
batch_op.drop_column('tick_size')
# Recreate indices after batch
op.create_index('ix_futures_contracts_root_symbol',
table_name='futures_contracts',
columns=['root_symbol'])
op.create_index('ix_futures_contracts_symbol',
table_name='futures_contracts',
columns=['symbol'],
unique=True)
write_version_info(version_info_table, 0)
# This dict contains references to downgrade methods that can be applied to an
# assets db. The resulting db's version is the key.
# e.g. The method at key '0' is the downgrade method from v1 to v0
_downgrade_methods = {
0: _downgrade_v1_to_v0,
}
+7
View File
@@ -507,3 +507,10 @@ class AssetDBVersionError(ZiplineError):
"Expected version: {expected_version}. Try rebuilding your asset "
"database or updating your version of Zipline."
)
class AssetDBImpossibleDowngrade(ZiplineError):
msg = (
"The existing Asset database is version: {db_version} which is lower "
"than the desired downgrade version: {desired_version}."
)
+4 -4
View File
@@ -126,8 +126,8 @@ def calc_period_stats(pos_stats, ending_cash):
net_leverage=net_leverage)
def calc_payout(contract_multiplier, amount, old_price, price):
return (price - old_price) * contract_multiplier * amount
def calc_payout(multiplier, amount, old_price, price):
return (price - old_price) * multiplier * amount
class PerformancePeriod(object):
@@ -235,7 +235,7 @@ class PerformancePeriod(object):
pos = positions[asset]
amount = pos.amount
payout = calc_payout(
asset.contract_multiplier,
asset.multiplier,
amount,
old_price,
pos.last_sale_price)
@@ -288,7 +288,7 @@ class PerformancePeriod(object):
amount = pos.amount
price = txn.price
cash_adj = calc_payout(
asset.contract_multiplier, amount, old_price, price)
asset.multiplier, amount, old_price, price)
self.adjust_cash(cash_adj)
if amount + txn.amount == 0:
del self._payout_last_sale_prices[asset]