diff --git a/zipline/assets/asset_writer.py b/zipline/assets/asset_writer.py index a4da5aa1..99ef87fe 100644 --- a/zipline/assets/asset_writer.py +++ b/zipline/assets/asset_writer.py @@ -13,47 +13,11 @@ import sqlalchemy as sa from zipline.errors import SidAssignmentError from zipline.assets._assets import Asset +SQLITE_MAX_VARIABLE_NUMBER = 999 + # Define a namedtuple for use with the load_data and _load_data methods AssetData = namedtuple('AssetData', 'equities futures exchanges root_symbols') -# Expected fields for an Asset's metadata -ASSET_TABLE_FIELDS = frozenset({ - 'sid', - 'symbol', - 'asset_name', - 'start_date', - 'end_date', - 'first_traded', - 'exchange', -}) - -# Expected fields for a Future's metadata -FUTURE_TABLE_FIELDS = ASSET_TABLE_FIELDS | { - 'notice_date', - 'expiration_date', - 'auto_close_date', - 'contract_multiplier', -} - -# Expected fields for an Equity's metadata -EQUITY_TABLE_FIELDS = ASSET_TABLE_FIELDS | { - 'company_symbol', - 'share_class_symbol', - 'fuzzy_symbol', -} - -EXCHANGE_TABLE_FIELDS = frozenset({ - 'exchange', - 'timezone', -}) - -ROOT_SYMBOL_TABLE_FIELDS = frozenset({ - 'root_symbol', - 'root_symbol_id', - 'sector', - 'description', - 'exchange', -}) # Default values for the equities DataFrame _equities_defaults = { @@ -162,21 +126,21 @@ def _generate_output_dataframe(data_subset, defaults): """ # The columns provided. cols = set(data_subset.columns) - desired_cols = {col for col in defaults.keys()} + desired_cols = set(defaults) # Drop columns with unrecognised headers. - data_subset.drop(cols - (cols & desired_cols), + data_subset.drop(cols - desired_cols, axis=1, inplace=True) # Get those columns which we need but # for which no data has been supplied. - need = desired_cols - set(data_subset.columns) + need = desired_cols - cols # Combine the users supplied data with our required columns. output = pd.concat( (data_subset, pd.DataFrame( - _dict_subset(defaults, need), + {k: defaults[k] for k in need}, data_subset.index, )), axis=1, @@ -186,13 +150,6 @@ def _generate_output_dataframe(data_subset, defaults): return output -def _dict_subset(dict_, subset): - res = {} - for k in subset: - res[k] = dict_[k] - return res - - class AssetDBWriter(with_metaclass(ABCMeta)): """ Class used to write arbitrary data to SQLite database. @@ -209,10 +166,34 @@ class AssetDBWriter(with_metaclass(ABCMeta)): Returns data in standard format. """ + CHUNK_SIZE = SQLITE_MAX_VARIABLE_NUMBER + + def __init__(self, equities=None, futures=None, exchanges=None, + root_symbols=None): + + if equities is None: + equities = self.defaultval() + self._equities = equities + + if futures is None: + futures = self.defaultval() + self._futures = futures + + if exchanges is None: + exchanges = self.defaultval() + self._exchanges = exchanges + + if root_symbols is None: + root_symbols = self.defaultval() + self._root_symbols = root_symbols + + @abstractmethod + def defaultval(self): + raise NotImplementedError + def write_all(self, engine, - allow_sid_assignment=True, - constraints=True): + allow_sid_assignment=True): """ Write pre-supplied data to SQLite. Parameters @@ -230,7 +211,7 @@ class AssetDBWriter(with_metaclass(ABCMeta)): # Begin an SQL transaction. with engine.begin() as txn: # Create SQL tables. - self.init_db(txn, constraints) + self.init_db(txn) # Get the data to add to SQL. data = self.load_data() # Write the data to SQL. @@ -239,45 +220,40 @@ class AssetDBWriter(with_metaclass(ABCMeta)): self._write_futures(data.futures, txn) self._write_equities(data.equities, txn) - def _write_exchanges(self, exchanges, bind=None): - recs = exchanges.reset_index().rename_axis( - {'index': 'exchange'}, - 1, - ).to_dict('records') - # In SQLAlchemy, insert().values([]) will insert NULLs, - # hence we check first to avoid violating NOT NULL constraints. - if recs: - self.futures_exchanges.insert().values(recs).execute(bind=bind) + def _write_df_to_table(self, df, tbl, bind): + df.to_sql( + tbl.name, + bind.connection, + index_label=[col.name for col in tbl.primary_key.columns][0], + if_exists='append', + chunksize=self.CHUNK_SIZE, + ) - def _write_root_symbols(self, root_symbols, bind=None): - recs = root_symbols.reset_index().rename_axis( - {'index': 'root_symbol'}, - 1, - ).to_dict('records') - if recs: - self.futures_root_symbols.insert().values(recs).execute(bind=bind) + def _write_assets(self, assets, asset_tbl, asset_type, bind): + self._write_df_to_table(assets, asset_tbl, bind) - def _write_futures(self, futures, bind=None): - recs = futures.reset_index().rename_axis( - {'index': 'sid'}, - 1, - ).to_dict('records') - for record in recs: - self.futures_contracts.insert().values([record]).execute(bind=bind) - self.asset_router.insert().values([(record['sid'], 'future')])\ - .execute(bind=bind) + pd.DataFrame({self.asset_router.c.sid.name: assets.index.values, + self.asset_router.c.asset_type.name: asset_type}).to_sql( + self.asset_router.name, + bind.connection, + if_exists='append', + index=False, + chunksize=self.CHUNK_SIZE, + ) - def _write_equities(self, equities, bind=None): - recs = equities.reset_index().rename_axis( - {'index': 'sid'}, - 1, - ).to_dict('records') - for record in recs: - self.equities.insert().values([record]).execute(bind=bind) - self.asset_router.insert().values((record['sid'], 'equity'))\ - .execute(bind=bind) + def _write_exchanges(self, exchanges, bind): + self._write_df_to_table(exchanges, self.futures_exchanges, bind) - def init_db(self, engine, constraints=True): + def _write_root_symbols(self, root_symbols, bind): + self._write_df_to_table(root_symbols, self.futures_root_symbols, bind) + + def _write_futures(self, futures, bind): + self._write_assets(futures, self.futures_contracts, 'future', bind) + + def _write_equities(self, equities, bind): + self._write_assets(equities, self.equities, 'equity', bind) + + def init_db(self, engine): """Connect to database and create tables. Parameters @@ -287,7 +263,8 @@ class AssetDBWriter(with_metaclass(ABCMeta)): constraints : bool, optional If True, create SQL ForeignKey and PrimaryKey constraints. """ - self.sql_metadata = metadata = sa.MetaData(bind=engine) + metadata = sa.MetaData(bind=engine) + self.equities = sa.Table( 'equities', metadata, @@ -296,16 +273,16 @@ class AssetDBWriter(with_metaclass(ABCMeta)): sa.Integer, unique=True, nullable=False, - primary_key=constraints, + primary_key=True, ), sa.Column('symbol', sa.Text), sa.Column('company_symbol', sa.Text, index=True), sa.Column('share_class_symbol', sa.Text), sa.Column('fuzzy_symbol', sa.Text, index=True), sa.Column('asset_name', sa.Text), - sa.Column('start_date', sa.Integer, default=0), - sa.Column('end_date', sa.Integer), - sa.Column('first_traded', sa.Integer), + sa.Column('start_date', sa.Integer, default=0, nullable=False), + sa.Column('end_date', sa.Integer, nullable=False), + sa.Column('first_traded', sa.Integer, nullable=False), sa.Column('exchange', sa.Text), ) self.futures_exchanges = sa.Table( @@ -316,7 +293,7 @@ class AssetDBWriter(with_metaclass(ABCMeta)): sa.Text, unique=True, nullable=False, - primary_key=constraints, + primary_key=True, ), sa.Column('timezone', sa.Text), ) @@ -328,7 +305,7 @@ class AssetDBWriter(with_metaclass(ABCMeta)): sa.Text, unique=True, nullable=False, - primary_key=constraints, + primary_key=True, ), sa.Column('root_symbol_id', sa.Integer), sa.Column('sector', sa.Text), @@ -336,8 +313,7 @@ class AssetDBWriter(with_metaclass(ABCMeta)): sa.Column( 'exchange', sa.Text, - *((sa.ForeignKey(self.futures_exchanges.c.exchange),) - if constraints else ()) + sa.ForeignKey(self.futures_exchanges.c.exchange), ), ) self.futures_contracts = sa.Table( @@ -348,28 +324,26 @@ class AssetDBWriter(with_metaclass(ABCMeta)): sa.Integer, unique=True, nullable=False, - primary_key=constraints, + primary_key=True, ), sa.Column('symbol', sa.Text), sa.Column( 'root_symbol', sa.Text, - *((sa.ForeignKey(self.futures_root_symbols.c.root_symbol),) - if constraints else ()) + sa.ForeignKey(self.futures_root_symbols.c.root_symbol), ), sa.Column('asset_name', sa.Text), - sa.Column('start_date', sa.Integer, default=0), - sa.Column('end_date', sa.Integer), - sa.Column('first_traded', sa.Integer), + sa.Column('start_date', sa.Integer, default=0, nullable=False), + sa.Column('end_date', sa.Integer, nullable=False), + sa.Column('first_traded', sa.Integer, nullable=False), sa.Column( 'exchange', sa.Text, - *((sa.ForeignKey(self.futures_exchanges.c.exchange),) - if constraints else ()) + sa.ForeignKey(self.futures_exchanges.c.exchange), ), - sa.Column('notice_date', sa.Integer), - sa.Column('expiration_date', sa.Integer), - sa.Column('auto_close_date', sa.Integer), + sa.Column('notice_date', sa.Integer, nullable=False), + sa.Column('expiration_date', sa.Integer, nullable=False), + sa.Column('auto_close_date', sa.Integer, nullable=False), sa.Column('contract_multiplier', sa.Float), ) self.asset_router = sa.Table( @@ -380,7 +354,7 @@ class AssetDBWriter(with_metaclass(ABCMeta)): sa.Integer, unique=True, nullable=False, - primary_key=constraints), + primary_key=True), sa.Column('asset_type', sa.Text), ) # Create the SQL tables if they do not already exist. @@ -400,10 +374,10 @@ class AssetDBWriter(with_metaclass(ABCMeta)): ############################### # HACK: If company_name is provided, map it to asset_name - if ('company_name' in data.equities.columns) \ - and ('asset_name' not in data.equities.columns): + if ('company_name' in data.equities.columns + and 'asset_name' not in data.equities.columns): data.equities['asset_name'] = data.equities['company_name'] - if ('file_name' in data.equities.columns): + if 'file_name' in data.equities.columns: data.equities['symbol'] = data.equities['file_name'] equities_output = _generate_output_dataframe( @@ -431,12 +405,9 @@ class AssetDBWriter(with_metaclass(ABCMeta)): equities_output.fuzzy_symbol.str.upper() # Convert date columns to UNIX Epoch integers (nanoseconds) - equities_output['start_date'] = \ - equities_output['start_date'].apply(self.convert_datetime) - equities_output['end_date'] = \ - equities_output['end_date'].apply(self.convert_datetime) - equities_output['first_traded'] = \ - equities_output['first_traded'].apply(self.convert_datetime) + for date_col in ('start_date', 'end_date', 'first_traded'): + equities_output[date_col] = \ + self.dt_to_epoch_ns(equities_output[date_col]) ############################## # Generate futures DataFrame # @@ -448,18 +419,10 @@ class AssetDBWriter(with_metaclass(ABCMeta)): ) # Convert date columns to UNIX Epoch integers (nanoseconds) - futures_output['start_date'] = \ - futures_output['start_date'].apply(self.convert_datetime) - futures_output['end_date'] = \ - futures_output['end_date'].apply(self.convert_datetime) - futures_output['first_traded'] = \ - futures_output['first_traded'].apply(self.convert_datetime) - futures_output['notice_date'] = \ - futures_output['notice_date'].apply(self.convert_datetime) - futures_output['expiration_date'] = \ - futures_output['expiration_date'].apply(self.convert_datetime) - futures_output['auto_close_date'] = \ - futures_output['auto_close_date'].apply(self.convert_datetime) + for date_col in ('start_date', 'end_date', 'first_traded', + 'notice_date', 'expiration_date', 'auto_close_date'): + futures_output[date_col] = \ + self.dt_to_epoch_ns(futures_output[date_col]) # Convert symbols and root_symbols to upper case. futures_output['symbol'] = futures_output.symbol.str.upper() @@ -488,56 +451,15 @@ class AssetDBWriter(with_metaclass(ABCMeta)): exchanges=exchanges_output, root_symbols=root_symbols_output) - def convert_datetime(self, dt): - """Convert a datetime variable to integer of nanoseconds - since UNIX Epoch. - - Parameters - ---------- - dt : datetime-coercible - A string, int or pd.Timestamp instance representing a datetime, or - None/NaN. - - Returns - ------- - int - nanoseconds since UNIX Epoch, or None if parameter 'dt' is null. - """ - - # Check for null parameter - if pd.isnull(dt): - return None - - # If no timezone is specified, assume UTC. - # Otherwise, convert to UTC. + @staticmethod + def dt_to_epoch_ns(dt_series): + index = pd.to_datetime(dt_series.values) try: - dt = pd.Timestamp(dt).tz_localize('UTC') + index = index.tz_localize('UTC') except TypeError: - dt = pd.Timestamp(dt).tz_convert('UTC') + index = index.tz_convert('UTC') - # Get seconds from UNIX Epoch - total_seconds_from_epoch = self._seconds_from_unix_time(dt) - - # Return nanoseconds since UNIX Epoch - return int(total_seconds_from_epoch * 1000000000) - - def _seconds_from_unix_time(self, dt): - """Return seconds between dt and UNIX Epoch. - - Parameters - ---------- - dt: pandas.Timestamp - The time for which to calculate seconds since UNIX Epoch. - - Returns - ------- - float - Seconds between dt and UNIX Epoch. - - """ - epoch = pd.to_datetime(0, utc=True) - delta = dt - epoch - return delta.total_seconds() + return index.view(int) @abstractmethod def _load_data(self): @@ -559,28 +481,7 @@ class AssetDBWriterFromList(AssetDBWriter): Class used to write list data to SQLite database. """ - def __init__(self, equities=None, futures=None, exchanges=None, - root_symbols=None): - - if equities is not None: - self._equities = equities - else: - self._equities = [] - - if futures is not None: - self._futures = futures - else: - self._futures = [] - - if exchanges is not None: - self._exchanges = exchanges - else: - self._exchanges = [] - - if root_symbols is not None: - self._root_symbols = root_symbols - else: - self._root_symbols = [] + defaultval = list def _load_data(self): @@ -654,28 +555,7 @@ class AssetDBWriterFromDictionary(AssetDBWriter): {id_0: {attribute_1 : ...}, id_1: {attribute_2: ...}, ...} """ - def __init__(self, equities=None, futures=None, exchanges=None, - root_symbols=None): - - if equities is not None: - self._equities = equities - else: - self._equities = {} - - if futures is not None: - self._futures = futures - else: - self._futures = {} - - if exchanges is not None: - self._exchanges = exchanges - else: - self._exchanges = {} - - if root_symbols is not None: - self._root_symbols = root_symbols - else: - self._root_symbols = {} + defaultval = dict def _load_data(self): @@ -696,42 +576,21 @@ class AssetDBWriterFromDataFrame(AssetDBWriter): Class used to write pandas.DataFrame data to SQLite database. """ - def __init__(self, equities=None, futures=None, exchanges=None, - root_symbols=None): - - if equities is not None: - self._equities = equities - else: - self._equities = pd.DataFrame() - - if futures is not None: - self._futures = futures - else: - self._futures = pd.DataFrame() - - if exchanges is not None: - self._exchanges = exchanges - else: - self._exchanges = pd.DataFrame() - - if root_symbols is not None: - self._root_symbols = root_symbols - else: - self._root_symbols = pd.DataFrame() + defaultval = pd.DataFrame def _load_data(self): # Check whether identifier columns have been provided. # If they have, set the index to this column. # If not, assume the index already cotains the identifier information. - if 'sid' in self._equities.columns: - self._equities.set_index(['sid'], inplace=True) - if 'sid' in self._futures.columns: - self._futures.set_index(['sid'], inplace=True) - if 'exchange_id' in self._exchanges.columns: - self._exchanges.set_index(['exchange'], inplace=True) - if 'root_symbol_id' in self._root_symbols.columns: - self._root_symbols.set_index(['root_symbol'], inplace=True) + for df, id_col in [ + (self._equities, 'sid'), + (self._futures, 'sid'), + (self._exchanges, 'exchange'), + (self._root_symbols, 'root_symbol'), + ]: + if id_col in df.columns: + df.set_index([id_col], inplace=True) return AssetData(equities=self._equities, futures=self._futures, diff --git a/zipline/assets/assets.py b/zipline/assets/assets.py index ae5b92ef..a7ceb7af 100644 --- a/zipline/assets/assets.py +++ b/zipline/assets/assets.py @@ -13,16 +13,15 @@ # limitations under the License. from abc import ABCMeta -from functools import partial from numbers import Integral -from operator import getitem, itemgetter +from operator import itemgetter import warnings from logbook import Logger import numpy as np import pandas as pd from pandas.tseries.tools import normalize_date -from six import with_metaclass, string_types +from six import with_metaclass, string_types, viewkeys import sqlalchemy as sa from toolz import compose @@ -37,8 +36,6 @@ from zipline.assets import ( Asset, Equity, Future, ) from zipline.assets.asset_writer import ( - FUTURE_TABLE_FIELDS, - EQUITY_TABLE_FIELDS, split_delimited_symbol, ) @@ -67,9 +64,9 @@ def _convert_asset_timestamp_fields(dict): """ Takes in a dict of Asset init args and converts dates to pd.Timestamps """ - for key, value in dict.items(): - if (key in _asset_timestamp_fields) and (value is not None): - dict[key] = pd.Timestamp(value, tz='UTC') + for key in (_asset_timestamp_fields & viewkeys(dict)): + value = pd.Timestamp(dict[key], tz='UTC') + dict[key] = None if pd.isnull(value) else value class AssetFinder(object): @@ -84,60 +81,13 @@ class AssetFinder(object): self.engine = engine metadata = sa.MetaData(bind=engine) - self.equities = equities = sa.Table( - 'equities', - metadata, - autoload=True, - autoload_with=engine, - ) - self.futures_exchanges = sa.Table( - 'futures_exchanges', - metadata, - autoload=True, - autoload_with=engine, - ) - self.futures_root_symbols = sa.Table( - 'futures_root_symbols', - metadata, - autoload=True, - autoload_with=engine, - ) - self.futures_contracts = futures_contracts = sa.Table( - 'futures_contracts', - metadata, - autoload=True, - autoload_with=engine, - ) - self.asset_router = sa.Table( - 'asset_router', - metadata, - autoload=True, - autoload_with=engine, - ) - # Create the equity and future queries once. - _equity_sid = equities.c.sid - _equity_by_sid = sa.select( - tuple(map(partial(getitem, equities.c), EQUITY_TABLE_FIELDS)), - ) + table_names = ['equities', 'futures_exchanges', 'futures_root_symbols', + 'futures_contracts', 'asset_router'] + metadata.reflect(only=table_names) + for table_name in table_names: + setattr(self, table_name, metadata.tables[table_name]) - def select_equity_by_sid(sid): - return _equity_by_sid.where(_equity_sid == int(sid)) - - self.select_equity_by_sid = select_equity_by_sid - - _future_sid = futures_contracts.c.sid - _future_by_sid = sa.select( - tuple(map( - partial(getitem, futures_contracts.c), - FUTURE_TABLE_FIELDS, - )), - ) - - def select_future_by_sid(sid): - return _future_by_sid.where(_future_sid == int(sid)) - - self.select_future_by_sid = select_future_by_sid # Cache for lookup of assets by sid, the objects in the asset lookp may # be shared with the results from equity and future lookup caches. # @@ -203,51 +153,46 @@ class AssetFinder(object): raise SidNotFound(sid=sid) def retrieve_all(self, sids, default_none=False): - return [self.retrieve_asset(sid) for sid in sids] + return [self.retrieve_asset(sid, default_none) for sid in sids] def _retrieve_equity(self, sid): """ Retrieve the Equity object of a given sid. """ - try: - return self._equity_cache[sid] - except KeyError: - pass - - data = self.select_equity_by_sid(sid).execute().fetchone() - # Convert 'data' from a RowProxy object to a dict, to allow assignment - data = dict(data.items()) - if data: - _convert_asset_timestamp_fields(data) - - equity = Equity(**data) - else: - equity = None - - self._equity_cache[sid] = equity - return equity + return self._retrieve_asset( + sid, self._equity_cache, self.equities, Equity, + ) def _retrieve_futures_contract(self, sid): """ Retrieve the Future object of a given sid. """ + return self._retrieve_asset( + sid, self._future_cache, self.futures_contracts, Future, + ) + + @staticmethod + def _select_asset_by_sid(asset_tbl, sid): + return sa.select([asset_tbl]).where(asset_tbl.c.sid == int(sid)) + + def _retrieve_asset(self, sid, cache, asset_tbl, asset_type): try: - return self._future_cache[sid] + return cache[sid] except KeyError: pass - data = self.select_future_by_sid(sid).execute().fetchone() + data = self._select_asset_by_sid(asset_tbl, sid).execute().fetchone() # Convert 'data' from a RowProxy object to a dict, to allow assignment data = dict(data.items()) if data: _convert_asset_timestamp_fields(data) - future = Future(**data) + asset = asset_type(**data) else: - future = None + asset = None - self._future_cache[sid] = future - return future + cache[sid] = asset + return asset def lookup_symbol(self, symbol, as_of_date, fuzzy=False): """ @@ -558,7 +503,7 @@ class AssetFinder(object): # Handle missing assets if len(missing) > 0: - warnings.warn("Missing assets for identifiers: " + missing) + warnings.warn("Missing assets for identifiers: %s" % missing) # Return a list of the sids of the found assets return [asset.sid for asset in matches] diff --git a/zipline/finance/trading.py b/zipline/finance/trading.py index 0d003326..ddd68f5c 100644 --- a/zipline/finance/trading.py +++ b/zipline/finance/trading.py @@ -121,7 +121,10 @@ class TradingEnvironment(object): else: self.engine = engine = asset_db_path - self.asset_finder = AssetFinder(engine) + if engine is not None: + self.asset_finder = AssetFinder(engine) + else: + self.asset_finder = None def write_data(self, engine=None, @@ -196,7 +199,7 @@ class TradingEnvironment(object): .write_all(self.engine, allow_sid_assignment=allow_sid_assignment) def _write_data_dicts(self, equities=None, futures=None, exchanges=None, - root_symbols=None, allow_sid_assignment=True): + root_symbols=None): AssetDBWriterFromDictionary(equities, futures, exchanges, root_symbols)\ .write_all(self.engine) diff --git a/zipline/pipeline/loaders/_adjustments.pyx b/zipline/pipeline/loaders/_adjustments.pyx index 572ceda4..9a6a8b42 100644 --- a/zipline/pipeline/loaders/_adjustments.pyx +++ b/zipline/pipeline/loaders/_adjustments.pyx @@ -28,6 +28,9 @@ ctypedef object DatetimeIndex_t ctypedef object Int64Index_t from zipline.lib.adjustment import Float64Multiply +from zipline.assets.asset_writer import ( + SQLITE_MAX_VARIABLE_NUMBER as SQLITE_MAX_IN_STATEMENT, +) _SID_QUERY_TEMPLATE = """ SELECT DISTINCT sid FROM {0} @@ -44,7 +47,6 @@ FROM {0} WHERE sid IN ({1}) AND effective_date >= {2} AND effective_date <= {3} """ -cdef int SQLITE_MAX_IN_STATEMENT = 999 EPOCH = Timestamp(0, tz='UTC') cdef set _get_sids_from_table(object db,