diff --git a/zipline/assets/asset_writer.py b/zipline/assets/asset_writer.py index 543b53c2..328b02e1 100644 --- a/zipline/assets/asset_writer.py +++ b/zipline/assets/asset_writer.py @@ -33,10 +33,6 @@ SQLITE_MAX_VARIABLE_NUMBER = 999 # Define a namedtuple for use with the load_data and _load_data methods AssetData = namedtuple('AssetData', 'equities futures exchanges root_symbols') -# A list of the names of all tables in the assets db -table_names = ['version_info', 'equities', 'futures_exchanges', - 'futures_root_symbols', 'futures_contracts', 'asset_router'] - # Default values for the equities DataFrame _equities_defaults = { 'symbol': None, @@ -219,6 +215,121 @@ def write_version_info(version_table, version_value): sa.insert(version_table, values={'version': version_value}).execute() +# A list of the names of all tables in the assets db +asset_db_table_names = ['version_info', 'equities', 'futures_exchanges', + 'futures_root_symbols', 'futures_contracts', + 'asset_router'] + + +def _equities_table_schema(metadata): + return sa.Table( + 'equities', + metadata, + sa.Column( + 'sid', + sa.Integer, + unique=True, + nullable=False, + primary_key=True, + ), + sa.Column('symbol', sa.Text), + sa.Column('company_symbol', sa.Text, index=True), + sa.Column('share_class_symbol', sa.Text), + sa.Column('fuzzy_symbol', sa.Text, index=True), + sa.Column('asset_name', sa.Text), + sa.Column('start_date', sa.Integer, default=0, nullable=False), + sa.Column('end_date', sa.Integer, nullable=False), + sa.Column('first_traded', sa.Integer, nullable=False), + sa.Column('exchange', sa.Text), + ) + + +def _futures_exchanges_schema(metadata): + return sa.Table( + 'futures_exchanges', + metadata, + sa.Column( + 'exchange', + sa.Text, + unique=True, + nullable=False, + primary_key=True, + ), + sa.Column('timezone', sa.Text), + ) + + +def _futures_root_symbols_schema(metadata, futures_exchanges): + return sa.Table( + 'futures_root_symbols', + metadata, + sa.Column( + 'root_symbol', + sa.Text, + unique=True, + nullable=False, + primary_key=True, + ), + sa.Column('root_symbol_id', sa.Integer), + sa.Column('sector', sa.Text), + sa.Column('description', sa.Text), + sa.Column( + 'exchange', + sa.Text, + sa.ForeignKey(futures_exchanges.c.exchange), + ), + ) + + +def _futures_contracts_schema(metadata, futures_root_symbols, + futures_exchanges): + return sa.Table( + 'futures_contracts', + metadata, + sa.Column( + 'sid', + sa.Integer, + unique=True, + nullable=False, + primary_key=True, + ), + sa.Column('symbol', sa.Text, unique=True, index=True), + sa.Column( + 'root_symbol', + sa.Text, + sa.ForeignKey(futures_root_symbols.c.root_symbol), + index=True + ), + sa.Column('asset_name', sa.Text), + sa.Column('start_date', sa.Integer, default=0, nullable=False), + sa.Column('end_date', sa.Integer, nullable=False), + sa.Column('first_traded', sa.Integer, nullable=False), + sa.Column( + 'exchange', + sa.Text, + sa.ForeignKey(futures_exchanges.c.exchange), + ), + sa.Column('notice_date', sa.Integer, nullable=False), + sa.Column('expiration_date', sa.Integer, nullable=False), + sa.Column('auto_close_date', sa.Integer, nullable=False), + sa.Column('contract_multiplier', sa.Float), + ) + + +def _asset_router_schema(metadata): + return sa.Table( + 'asset_router', + metadata, + sa.Column( + 'sid', + sa.Integer, + unique=True, + nullable=False, + primary_key=True), + sa.Column('asset_type', sa.Text), + ) + + def _version_table_schema(metadata): return sa.Table( 'version_info', @@ -354,7 +465,7 @@ class AssetDBWriter(with_metaclass(ABCMeta)): True if any tables are present, otherwise False. """ conn = engine.connect() - for table_name in table_names: + for table_name in asset_db_table_names: if engine.dialect.has_table(conn, table_name): return True return False @@ -373,107 +484,26 @@ class AssetDBWriter(with_metaclass(ABCMeta)): tables_already_exist = self.check_for_tables(engine=engine) - self.equities = sa.Table( - 'equities', - metadata, - sa.Column( - 'sid', - sa.Integer, - unique=True, - nullable=False, - primary_key=True, - ), - sa.Column('symbol', sa.Text), - sa.Column('company_symbol', sa.Text, index=True), - sa.Column('share_class_symbol', sa.Text), - sa.Column('fuzzy_symbol', sa.Text, index=True), - sa.Column('asset_name', sa.Text), - sa.Column('start_date', sa.Integer, default=0, nullable=False), - sa.Column('end_date', sa.Integer, nullable=False), - sa.Column('first_traded', sa.Integer, nullable=False), - sa.Column('exchange', sa.Text), + self.equities = _equities_table_schema(metadata) + self.futures_exchanges = _futures_exchanges_schema(metadata) + self.futures_root_symbols = _futures_root_symbols_schema( + metadata=metadata, + futures_exchanges=self.futures_exchanges, ) - self.futures_exchanges = sa.Table( - 'futures_exchanges', - metadata, - sa.Column( - 'exchange', - sa.Text, - unique=True, - nullable=False, - primary_key=True, - ), - sa.Column('timezone', sa.Text), + self.futures_contracts = _futures_contracts_schema( + metadata=metadata, + futures_root_symbols=self.futures_root_symbols, + futures_exchanges=self.futures_exchanges, ) - self.futures_root_symbols = sa.Table( - 'futures_root_symbols', - metadata, - sa.Column( - 'root_symbol', - sa.Text, - unique=True, - nullable=False, - primary_key=True, - ), - sa.Column('root_symbol_id', sa.Integer), - sa.Column('sector', sa.Text), - sa.Column('description', sa.Text), - sa.Column( - 'exchange', - sa.Text, - sa.ForeignKey(self.futures_exchanges.c.exchange), - ), - ) - self.futures_contracts = sa.Table( - 'futures_contracts', - metadata, - sa.Column( - 'sid', - sa.Integer, - unique=True, - nullable=False, - primary_key=True, - ), - sa.Column('symbol', sa.Text, unique=True, index=True), - sa.Column( - 'root_symbol', - sa.Text, - sa.ForeignKey(self.futures_root_symbols.c.root_symbol), - index=True - ), - sa.Column('asset_name', sa.Text), - sa.Column('start_date', sa.Integer, default=0, nullable=False), - sa.Column('end_date', sa.Integer, nullable=False), - sa.Column('first_traded', sa.Integer, nullable=False), - sa.Column( - 'exchange', - sa.Text, - sa.ForeignKey(self.futures_exchanges.c.exchange), - ), - sa.Column('notice_date', sa.Integer, nullable=False), - sa.Column('expiration_date', sa.Integer, nullable=False), - sa.Column('auto_close_date', sa.Integer, nullable=False), - sa.Column('contract_multiplier', sa.Float), - ) - self.asset_router = sa.Table( - 'asset_router', - metadata, - sa.Column( - 'sid', - sa.Integer, - unique=True, - nullable=False, - primary_key=True), - sa.Column('asset_type', sa.Text), - ) - self.version_table = _version_table_schema(metadata) + self.asset_router = _asset_router_schema(metadata) + self.version_info = _version_table_schema(metadata) # Create the SQL tables if they do not already exist. metadata.create_all(checkfirst=True) if tables_already_exist: - check_version_info(self.version_table, ASSET_DB_VERSION) + check_version_info(self.version_info, ASSET_DB_VERSION) else: - write_version_info(self.version_table, ASSET_DB_VERSION) + write_version_info(self.version_info, ASSET_DB_VERSION) return metadata diff --git a/zipline/assets/assets.py b/zipline/assets/assets.py index c1d593bd..8d059072 100644 --- a/zipline/assets/assets.py +++ b/zipline/assets/assets.py @@ -39,7 +39,7 @@ from zipline.assets.asset_writer import ( split_delimited_symbol, check_version_info, ASSET_DB_VERSION, - table_names, + asset_db_table_names, ) log = Logger('assets.py') @@ -82,8 +82,8 @@ class AssetFinder(object): self.engine = engine metadata = sa.MetaData(bind=engine) - metadata.reflect(only=table_names) - for table_name in table_names: + metadata.reflect(only=asset_db_table_names) + for table_name in asset_db_table_names: setattr(self, table_name, metadata.tables[table_name]) # Check the version info of the db for compatibility