mirror of
https://github.com/wassname/catalyst.git
synced 2026-07-04 07:50:08 +08:00
ENH: Adds versions to asset databases
This commit is contained in:
+49
-2
@@ -29,25 +29,32 @@ from pandas.tseries.tools import normalize_date
|
||||
from pandas.util.testing import assert_frame_equal
|
||||
|
||||
from numpy import full
|
||||
import sqlalchemy as sa
|
||||
|
||||
from zipline.assets import (
|
||||
Asset,
|
||||
Equity,
|
||||
Future,
|
||||
AssetFinder,
|
||||
AssetFinderCachedEquities
|
||||
AssetFinderCachedEquities,
|
||||
)
|
||||
|
||||
from zipline.assets.futures import (
|
||||
cme_code_to_month,
|
||||
FutureChain,
|
||||
month_to_cme_code
|
||||
)
|
||||
from zipline.assets.asset_writer import (
|
||||
check_version_info,
|
||||
write_version_info,
|
||||
ASSET_DB_VERSION,
|
||||
_version_table_schema,
|
||||
)
|
||||
from zipline.errors import (
|
||||
SymbolNotFound,
|
||||
MultipleSymbolsFound,
|
||||
SidAssignmentError,
|
||||
RootSymbolNotFound,
|
||||
AssetDBVersionError,
|
||||
)
|
||||
from zipline.finance.trading import TradingEnvironment, noop_load
|
||||
from zipline.utils.test_utils import (
|
||||
@@ -1082,3 +1089,43 @@ class TestFutureChain(TestCase):
|
||||
}
|
||||
for key in codes:
|
||||
self.assertEqual(codes[key], month_to_cme_code(key))
|
||||
|
||||
|
||||
class TestAssetDBVersioning(TestCase):
    """Tests for reading and writing the asset database version."""

    def test_check_version(self):
        """check_version_info accepts the current version only."""
        trading_env = TradingEnvironment(load=noop_load)
        # The asset finder exposes the reflected ``version`` table as an
        # attribute.
        table = trading_env.asset_finder.version

        # Expecting the current version succeeds.
        matches = check_version_info(table, ASSET_DB_VERSION)
        self.assertTrue(matches)

        # Expecting a version one below or one above the current version
        # must be reported as a mismatch.
        for wrong_version in (ASSET_DB_VERSION - 1, ASSET_DB_VERSION + 1):
            with self.assertRaises(AssetDBVersionError):
                check_version_info(table, wrong_version)

    def test_write_version(self):
        """write_version_info stores a version once and only once."""
        trading_env = TradingEnvironment(load=noop_load)
        table = _version_table_schema(sa.MetaData(bind=trading_env.engine))

        # Start from a version table with no rows.
        table.delete().execute()

        def stored_version():
            # Scalar value of the version column, or None if no row exists.
            return sa.select((table.c.version,)).scalar()

        # No version is present in the table.
        self.assertIsNone(stored_version())

        # With no version stored, the check reports False instead of raising.
        self.assertFalse(check_version_info(table, -2))

        # Once the version has been written, the check returns True.
        write_version_info(table, -2)
        self.assertTrue(check_version_info(table, -2))

        # The stored value is the one that was written.
        self.assertEqual(stored_version(), -2)

        # Writing a second version must be rejected by the database.
        with self.assertRaises(sa.exc.IntegrityError):
            write_version_info(table, -3)
|
||||
|
||||
@@ -25,7 +25,7 @@ import numpy as np
|
||||
from six import with_metaclass
|
||||
import sqlalchemy as sa
|
||||
|
||||
from zipline.errors import SidAssignmentError
|
||||
from zipline.errors import SidAssignmentError, AssetDBVersionError
|
||||
from zipline.assets._assets import Asset
|
||||
|
||||
SQLITE_MAX_VARIABLE_NUMBER = 999
|
||||
@@ -165,6 +165,90 @@ def _generate_output_dataframe(data_subset, defaults):
|
||||
return output
|
||||
|
||||
|
||||
# Define a version number for the database generated by these writers
|
||||
# Increment this version number any time a breaking change is made to the
|
||||
# schema and readers of the database
|
||||
ASSET_DB_VERSION = 0
|
||||
|
||||
|
||||
def check_version_info(version_table, expected_version):
    """
    Compare the version stored in the version table with an expected value.

    Parameters
    ----------
    version_table : sa.Table
        The version table of the asset database
    expected_version : int
        The version the caller expects the asset database to have

    Returns
    -------
    bool
        True if a version is stored and equals ``expected_version``.
        False if no version is stored.

    Raises
    ------
    AssetDBVersionError
        If a version is stored and is not equal to ``expected_version``.
    """
    # Scalar value of the version column; None when nothing is stored.
    stored_version = sa.select((version_table.c.version,)).scalar()

    # Missing version info is reported, not treated as an error, so the
    # caller can decide what to do about it.
    if stored_version is None:
        return False

    # A stored version that disagrees with the expectation is an error.
    if stored_version != expected_version:
        raise AssetDBVersionError(db_version=stored_version,
                                  expected_version=expected_version)

    # A matching version was found.
    return True
|
||||
|
||||
|
||||
def write_version_info(version_table, version_value):
    """
    Insert a row holding the given version into the version table.

    Parameters
    ----------
    version_table : sa.Table
        The version table of the asset database
    version_value : int
        The version to write in to the database

    """
    # The statement is executed without an explicit connection, so this
    # presumably relies on the table's metadata being bound to an engine.
    insert_statement = sa.insert(
        version_table,
        values={'version': version_value},
    )
    insert_statement.execute()
|
||||
|
||||
|
||||
def _version_table_schema(metadata):
    """
    Declare the ``version`` table on the given metadata and return it.

    Parameters
    ----------
    metadata : sa.MetaData
        The metadata object the table is attached to

    Returns
    -------
    sa.Table
        The ``version`` table with an ``id`` and a ``version`` column.
    """
    id_column = sa.Column(
        'id',
        sa.Integer,
        unique=True,
        nullable=False,
        primary_key=True,
    )
    version_column = sa.Column(
        'version',
        sa.Integer,
        unique=True,
        nullable=False,
    )
    # Intended to keep this table to a single entry: any row whose id is
    # greater than 1 is rejected. NOTE(review): presumably relies on the
    # database assigning ids starting from 1 -- confirm.
    single_entry_check = sa.CheckConstraint('id <= 1')

    return sa.Table(
        'version',
        metadata,
        id_column,
        version_column,
        single_entry_check,
    )
|
||||
|
||||
|
||||
class AssetDBWriter(with_metaclass(ABCMeta)):
|
||||
"""
|
||||
Class used to write arbitrary data to SQLite database.
|
||||
@@ -373,8 +457,16 @@ class AssetDBWriter(with_metaclass(ABCMeta)):
|
||||
primary_key=True),
|
||||
sa.Column('asset_type', sa.Text),
|
||||
)
|
||||
self.version_table = _version_table_schema(metadata)
|
||||
# Create the SQL tables if they do not already exist.
|
||||
metadata.create_all(checkfirst=True)
|
||||
|
||||
# Check if the version is mismatched or, if it is not present, add it
|
||||
version_found = check_version_info(self.version_table,
|
||||
ASSET_DB_VERSION)
|
||||
if not version_found:
|
||||
write_version_info(self.version_table, ASSET_DB_VERSION)
|
||||
|
||||
return metadata
|
||||
|
||||
def load_data(self):
|
||||
|
||||
@@ -37,6 +37,8 @@ from zipline.assets import (
|
||||
)
|
||||
from zipline.assets.asset_writer import (
|
||||
split_delimited_symbol,
|
||||
check_version_info,
|
||||
ASSET_DB_VERSION,
|
||||
)
|
||||
|
||||
log = Logger('assets.py')
|
||||
@@ -80,14 +82,18 @@ class AssetFinder(object):
|
||||
self.engine = engine
|
||||
metadata = sa.MetaData(bind=engine)
|
||||
|
||||
table_names = ['equities', 'futures_exchanges', 'futures_root_symbols',
|
||||
'futures_contracts', 'asset_router']
|
||||
table_names = ['version', 'equities', 'futures_exchanges',
|
||||
'futures_root_symbols', 'futures_contracts',
|
||||
'asset_router']
|
||||
metadata.reflect(only=table_names)
|
||||
for table_name in table_names:
|
||||
setattr(self, table_name, metadata.tables[table_name])
|
||||
|
||||
# Cache for lookup of assets by sid, the objects in the asset lookp may
|
||||
# be shared with the results from equity and future lookup caches.
|
||||
# Check the version info of the db for compatibility
|
||||
check_version_info(self.version, ASSET_DB_VERSION)
|
||||
|
||||
# Cache for lookup of assets by sid, the objects in the asset lookup
|
||||
# may be shared with the results from equity and future lookup caches.
|
||||
#
|
||||
# The top level cache exists to minimize lookups on the asset type
|
||||
# routing.
|
||||
|
||||
@@ -439,3 +439,15 @@ class PositionTrackerMissingAssetFinder(ZiplineError):
|
||||
"not have an AssetFinder. This may be caused by a failure to properly "
|
||||
"de-serialize a TradingAlgorithm."
|
||||
)
|
||||
|
||||
|
||||
class AssetDBVersionError(ZiplineError):
    """
    Raised when the version number stored in the ``version`` table of an
    asset database does not match the version that was expected. The
    AssetDBWriter and AssetFinder both expect the ASSET_DB_VERSION defined
    in asset_writer.py.
    """
    # Template fields: ``db_version`` is the version read from the database,
    # ``expected_version`` is the version it was compared against.
    msg = (
        "The existing Asset database has an incorrect version: "
        "{db_version}. Expected version: {expected_version}. "
        "Try rebuilding your asset database or updating your "
        "version of Zipline."
    )
|
||||
|
||||
Reference in New Issue
Block a user