diff --git a/tests/test_assets.py b/tests/test_assets.py index f081d939..6954db47 100644 --- a/tests/test_assets.py +++ b/tests/test_assets.py @@ -29,25 +29,32 @@ from pandas.tseries.tools import normalize_date from pandas.util.testing import assert_frame_equal from numpy import full +import sqlalchemy as sa from zipline.assets import ( Asset, Equity, Future, AssetFinder, - AssetFinderCachedEquities + AssetFinderCachedEquities, ) - from zipline.assets.futures import ( cme_code_to_month, FutureChain, month_to_cme_code ) +from zipline.assets.asset_writer import ( + check_version_info, + write_version_info, + ASSET_DB_VERSION, + _version_table_schema, +) from zipline.errors import ( SymbolNotFound, MultipleSymbolsFound, SidAssignmentError, RootSymbolNotFound, + AssetDBVersionError, ) from zipline.finance.trading import TradingEnvironment, noop_load from zipline.utils.test_utils import ( @@ -1082,3 +1089,43 @@ class TestFutureChain(TestCase): } for key in codes: self.assertEqual(codes[key], month_to_cme_code(key)) + + +class TestAssetDBVersioning(TestCase): + + def test_check_version(self): + env = TradingEnvironment(load=noop_load) + version_table = env.asset_finder.version + + self.assertTrue(check_version_info(version_table, ASSET_DB_VERSION)) + + # This should fail because the version is too low + with self.assertRaises(AssetDBVersionError): + check_version_info(version_table, ASSET_DB_VERSION - 1) + + # This should fail because the version is too high + with self.assertRaises(AssetDBVersionError): + check_version_info(version_table, ASSET_DB_VERSION + 1) + + def test_write_version(self): + env = TradingEnvironment(load=noop_load) + metadata = sa.MetaData(bind=env.engine) + version_table = _version_table_schema(metadata) + version_table.delete().execute() + + # Assert that the version is not present in the table + self.assertIsNone(sa.select((version_table.c.version,)).scalar()) + + # This should return false because there is no version info in the db + self.assertFalse(check_version_info(version_table, -2)) + + # This return true because the version has been written + write_version_info(version_table, -2) + self.assertTrue(check_version_info(version_table, -2)) + + # Assert that the version is in the table and correct + self.assertEqual(sa.select((version_table.c.version,)).scalar(), -2) + + # Assert that trying to overwrite the version fails + with self.assertRaises(sa.exc.IntegrityError): + write_version_info(version_table, -3) diff --git a/zipline/assets/asset_writer.py b/zipline/assets/asset_writer.py index 5d6fc776..2e0d2ba3 100644 --- a/zipline/assets/asset_writer.py +++ b/zipline/assets/asset_writer.py @@ -25,7 +25,7 @@ import numpy as np from six import with_metaclass import sqlalchemy as sa -from zipline.errors import SidAssignmentError +from zipline.errors import SidAssignmentError, AssetDBVersionError from zipline.assets._assets import Asset SQLITE_MAX_VARIABLE_NUMBER = 999 @@ -165,6 +165,90 @@ def _generate_output_dataframe(data_subset, defaults): return output +# Define a version number for the database generated by these writers +# Increment this version number any time a breaking change is made to the +# schema and readers of the database +ASSET_DB_VERSION = 0 + + +def check_version_info(version_table, expected_version): + """ + Checks for a version value in the version table. + + Parameters + ---------- + version_table : sa.Table + The version table of the asset database + expected_version : int + The expected version of the asset database + + Returns + ------- + bool + True if the version information is present and correct. + False if the version information is missing. + + Raises + ------ + AssetDBVersionError + If the version is in the table and not equal to ASSET_DB_VERSION. + """ + + # Read the version out of the table + version_from_table = sa.select((version_table.c.version,)).scalar() + + # If a version exists in the table... + if version_from_table is not None: + + # Raise an error if the versions do not match + if (version_from_table != expected_version): + raise AssetDBVersionError(db_version=version_from_table, + expected_version=expected_version) + + # A matching version was found + return True + + # No version info found + return False + + +def write_version_info(version_table, version_value): + """ + Inserts the version value in to the version table. + + Parameters + ---------- + version_table : sa.Table + The version table of the asset database + version_value : int + The version to write in to the database + + """ + sa.insert(version_table, values={'version': version_value}).execute() + + +def _version_table_schema(metadata): + return sa.Table( + 'version', + metadata, + sa.Column( + 'id', + sa.Integer, + unique=True, + nullable=False, + primary_key=True, + ), + sa.Column( + 'version', + sa.Integer, + unique=True, + nullable=False, + ), + # This constraint ensures a single entry in this table + sa.CheckConstraint('id <= 1'), + ) + + class AssetDBWriter(with_metaclass(ABCMeta)): """ Class used to write arbitrary data to SQLite database. @@ -373,8 +457,16 @@ class AssetDBWriter(with_metaclass(ABCMeta)): primary_key=True), sa.Column('asset_type', sa.Text), ) + self.version_table = _version_table_schema(metadata) # Create the SQL tables if they do not already exist. metadata.create_all(checkfirst=True) + + # Check if the version is mismatched or, if it is not present, add it + version_found = check_version_info(self.version_table, + ASSET_DB_VERSION) + if not version_found: + write_version_info(self.version_table, ASSET_DB_VERSION) + return metadata def load_data(self): diff --git a/zipline/assets/assets.py b/zipline/assets/assets.py index fee1b35c..f6a62cf8 100644 --- a/zipline/assets/assets.py +++ b/zipline/assets/assets.py @@ -37,6 +37,8 @@ from zipline.assets import ( ) from zipline.assets.asset_writer import ( split_delimited_symbol, + check_version_info, + ASSET_DB_VERSION, ) log = Logger('assets.py') @@ -80,14 +82,18 @@ class AssetFinder(object): self.engine = engine metadata = sa.MetaData(bind=engine) - table_names = ['equities', 'futures_exchanges', 'futures_root_symbols', - 'futures_contracts', 'asset_router'] + table_names = ['version', 'equities', 'futures_exchanges', + 'futures_root_symbols', 'futures_contracts', + 'asset_router'] metadata.reflect(only=table_names) for table_name in table_names: setattr(self, table_name, metadata.tables[table_name]) - # Cache for lookup of assets by sid, the objects in the asset lookp may - # be shared with the results from equity and future lookup caches. + # Check the version info of the db for compatibility + check_version_info(self.version, ASSET_DB_VERSION) + + # Cache for lookup of assets by sid, the objects in the asset lookup + # may be shared with the results from equity and future lookup caches. # # The top level cache exists to minimize lookups on the asset type # routing. diff --git a/zipline/errors.py b/zipline/errors.py index eacfcc64..aab1f414 100644 --- a/zipline/errors.py +++ b/zipline/errors.py @@ -439,3 +439,15 @@ class PositionTrackerMissingAssetFinder(ZiplineError): "not have an AssetFinder. This may be caused by a failure to properly " "de-serialize a TradingAlgorithm." ) + + +class AssetDBVersionError(ZiplineError): + """ + Raised by an AssetDBWriter or AssetFinder if the version number in the + versions table does not match the ASSET_DB_VERSION in asset_writer.py. + """ + msg = ( + "The existing Asset database has an incorrect version: {db_version}. " + "Expected version: {expected_version}. Try rebuilding your asset " + "database or updating your version of Zipline." + )