Merge pull request #746 from quantopian/write_faster

Write faster
This commit is contained in:
Richard Frank
2015-10-02 12:21:27 -04:00
4 changed files with 147 additions and 338 deletions
+109 -250
View File
@@ -13,47 +13,11 @@ import sqlalchemy as sa
from zipline.errors import SidAssignmentError
from zipline.assets._assets import Asset
SQLITE_MAX_VARIABLE_NUMBER = 999
# Define a namedtuple for use with the load_data and _load_data methods
AssetData = namedtuple('AssetData', 'equities futures exchanges root_symbols')
# Expected fields for an Asset's metadata
ASSET_TABLE_FIELDS = frozenset({
'sid',
'symbol',
'asset_name',
'start_date',
'end_date',
'first_traded',
'exchange',
})
# Expected fields for a Future's metadata
FUTURE_TABLE_FIELDS = ASSET_TABLE_FIELDS | {
'notice_date',
'expiration_date',
'auto_close_date',
'contract_multiplier',
}
# Expected fields for an Equity's metadata
EQUITY_TABLE_FIELDS = ASSET_TABLE_FIELDS | {
'company_symbol',
'share_class_symbol',
'fuzzy_symbol',
}
EXCHANGE_TABLE_FIELDS = frozenset({
'exchange',
'timezone',
})
ROOT_SYMBOL_TABLE_FIELDS = frozenset({
'root_symbol',
'root_symbol_id',
'sector',
'description',
'exchange',
})
# Default values for the equities DataFrame
_equities_defaults = {
@@ -162,21 +126,21 @@ def _generate_output_dataframe(data_subset, defaults):
"""
# The columns provided.
cols = set(data_subset.columns)
desired_cols = {col for col in defaults.keys()}
desired_cols = set(defaults)
# Drop columns with unrecognised headers.
data_subset.drop(cols - (cols & desired_cols),
data_subset.drop(cols - desired_cols,
axis=1,
inplace=True)
# Get those columns which we need but
# for which no data has been supplied.
need = desired_cols - set(data_subset.columns)
need = desired_cols - cols
# Combine the users supplied data with our required columns.
output = pd.concat(
(data_subset, pd.DataFrame(
_dict_subset(defaults, need),
{k: defaults[k] for k in need},
data_subset.index,
)),
axis=1,
@@ -186,13 +150,6 @@ def _generate_output_dataframe(data_subset, defaults):
return output
def _dict_subset(dict_, subset):
res = {}
for k in subset:
res[k] = dict_[k]
return res
class AssetDBWriter(with_metaclass(ABCMeta)):
"""
Class used to write arbitrary data to SQLite database.
@@ -209,10 +166,34 @@ class AssetDBWriter(with_metaclass(ABCMeta)):
Returns data in standard format.
"""
CHUNK_SIZE = SQLITE_MAX_VARIABLE_NUMBER
def __init__(self, equities=None, futures=None, exchanges=None,
root_symbols=None):
if equities is None:
equities = self.defaultval()
self._equities = equities
if futures is None:
futures = self.defaultval()
self._futures = futures
if exchanges is None:
exchanges = self.defaultval()
self._exchanges = exchanges
if root_symbols is None:
root_symbols = self.defaultval()
self._root_symbols = root_symbols
@abstractmethod
def defaultval(self):
raise NotImplementedError
def write_all(self,
engine,
allow_sid_assignment=True,
constraints=True):
allow_sid_assignment=True):
""" Write pre-supplied data to SQLite.
Parameters
@@ -230,7 +211,7 @@ class AssetDBWriter(with_metaclass(ABCMeta)):
# Begin an SQL transaction.
with engine.begin() as txn:
# Create SQL tables.
self.init_db(txn, constraints)
self.init_db(txn)
# Get the data to add to SQL.
data = self.load_data()
# Write the data to SQL.
@@ -239,45 +220,40 @@ class AssetDBWriter(with_metaclass(ABCMeta)):
self._write_futures(data.futures, txn)
self._write_equities(data.equities, txn)
def _write_exchanges(self, exchanges, bind=None):
recs = exchanges.reset_index().rename_axis(
{'index': 'exchange'},
1,
).to_dict('records')
# In SQLAlchemy, insert().values([]) will insert NULLs,
# hence we check first to avoid violating NOT NULL constraints.
if recs:
self.futures_exchanges.insert().values(recs).execute(bind=bind)
def _write_df_to_table(self, df, tbl, bind):
df.to_sql(
tbl.name,
bind.connection,
index_label=[col.name for col in tbl.primary_key.columns][0],
if_exists='append',
chunksize=self.CHUNK_SIZE,
)
def _write_root_symbols(self, root_symbols, bind=None):
recs = root_symbols.reset_index().rename_axis(
{'index': 'root_symbol'},
1,
).to_dict('records')
if recs:
self.futures_root_symbols.insert().values(recs).execute(bind=bind)
def _write_assets(self, assets, asset_tbl, asset_type, bind):
self._write_df_to_table(assets, asset_tbl, bind)
def _write_futures(self, futures, bind=None):
recs = futures.reset_index().rename_axis(
{'index': 'sid'},
1,
).to_dict('records')
for record in recs:
self.futures_contracts.insert().values([record]).execute(bind=bind)
self.asset_router.insert().values([(record['sid'], 'future')])\
.execute(bind=bind)
pd.DataFrame({self.asset_router.c.sid.name: assets.index.values,
self.asset_router.c.asset_type.name: asset_type}).to_sql(
self.asset_router.name,
bind.connection,
if_exists='append',
index=False,
chunksize=self.CHUNK_SIZE,
)
def _write_equities(self, equities, bind=None):
recs = equities.reset_index().rename_axis(
{'index': 'sid'},
1,
).to_dict('records')
for record in recs:
self.equities.insert().values([record]).execute(bind=bind)
self.asset_router.insert().values((record['sid'], 'equity'))\
.execute(bind=bind)
def _write_exchanges(self, exchanges, bind):
self._write_df_to_table(exchanges, self.futures_exchanges, bind)
def init_db(self, engine, constraints=True):
def _write_root_symbols(self, root_symbols, bind):
self._write_df_to_table(root_symbols, self.futures_root_symbols, bind)
def _write_futures(self, futures, bind):
self._write_assets(futures, self.futures_contracts, 'future', bind)
def _write_equities(self, equities, bind):
self._write_assets(equities, self.equities, 'equity', bind)
def init_db(self, engine):
"""Connect to database and create tables.
Parameters
@@ -287,7 +263,8 @@ class AssetDBWriter(with_metaclass(ABCMeta)):
constraints : bool, optional
If True, create SQL ForeignKey and PrimaryKey constraints.
"""
self.sql_metadata = metadata = sa.MetaData(bind=engine)
metadata = sa.MetaData(bind=engine)
self.equities = sa.Table(
'equities',
metadata,
@@ -296,16 +273,16 @@ class AssetDBWriter(with_metaclass(ABCMeta)):
sa.Integer,
unique=True,
nullable=False,
primary_key=constraints,
primary_key=True,
),
sa.Column('symbol', sa.Text),
sa.Column('company_symbol', sa.Text, index=True),
sa.Column('share_class_symbol', sa.Text),
sa.Column('fuzzy_symbol', sa.Text, index=True),
sa.Column('asset_name', sa.Text),
sa.Column('start_date', sa.Integer, default=0),
sa.Column('end_date', sa.Integer),
sa.Column('first_traded', sa.Integer),
sa.Column('start_date', sa.Integer, default=0, nullable=False),
sa.Column('end_date', sa.Integer, nullable=False),
sa.Column('first_traded', sa.Integer, nullable=False),
sa.Column('exchange', sa.Text),
)
self.futures_exchanges = sa.Table(
@@ -316,7 +293,7 @@ class AssetDBWriter(with_metaclass(ABCMeta)):
sa.Text,
unique=True,
nullable=False,
primary_key=constraints,
primary_key=True,
),
sa.Column('timezone', sa.Text),
)
@@ -328,7 +305,7 @@ class AssetDBWriter(with_metaclass(ABCMeta)):
sa.Text,
unique=True,
nullable=False,
primary_key=constraints,
primary_key=True,
),
sa.Column('root_symbol_id', sa.Integer),
sa.Column('sector', sa.Text),
@@ -336,8 +313,7 @@ class AssetDBWriter(with_metaclass(ABCMeta)):
sa.Column(
'exchange',
sa.Text,
*((sa.ForeignKey(self.futures_exchanges.c.exchange),)
if constraints else ())
sa.ForeignKey(self.futures_exchanges.c.exchange),
),
)
self.futures_contracts = sa.Table(
@@ -348,28 +324,26 @@ class AssetDBWriter(with_metaclass(ABCMeta)):
sa.Integer,
unique=True,
nullable=False,
primary_key=constraints,
primary_key=True,
),
sa.Column('symbol', sa.Text),
sa.Column(
'root_symbol',
sa.Text,
*((sa.ForeignKey(self.futures_root_symbols.c.root_symbol),)
if constraints else ())
sa.ForeignKey(self.futures_root_symbols.c.root_symbol),
),
sa.Column('asset_name', sa.Text),
sa.Column('start_date', sa.Integer, default=0),
sa.Column('end_date', sa.Integer),
sa.Column('first_traded', sa.Integer),
sa.Column('start_date', sa.Integer, default=0, nullable=False),
sa.Column('end_date', sa.Integer, nullable=False),
sa.Column('first_traded', sa.Integer, nullable=False),
sa.Column(
'exchange',
sa.Text,
*((sa.ForeignKey(self.futures_exchanges.c.exchange),)
if constraints else ())
sa.ForeignKey(self.futures_exchanges.c.exchange),
),
sa.Column('notice_date', sa.Integer),
sa.Column('expiration_date', sa.Integer),
sa.Column('auto_close_date', sa.Integer),
sa.Column('notice_date', sa.Integer, nullable=False),
sa.Column('expiration_date', sa.Integer, nullable=False),
sa.Column('auto_close_date', sa.Integer, nullable=False),
sa.Column('contract_multiplier', sa.Float),
)
self.asset_router = sa.Table(
@@ -380,7 +354,7 @@ class AssetDBWriter(with_metaclass(ABCMeta)):
sa.Integer,
unique=True,
nullable=False,
primary_key=constraints),
primary_key=True),
sa.Column('asset_type', sa.Text),
)
# Create the SQL tables if they do not already exist.
@@ -400,10 +374,10 @@ class AssetDBWriter(with_metaclass(ABCMeta)):
###############################
# HACK: If company_name is provided, map it to asset_name
if ('company_name' in data.equities.columns) \
and ('asset_name' not in data.equities.columns):
if ('company_name' in data.equities.columns
and 'asset_name' not in data.equities.columns):
data.equities['asset_name'] = data.equities['company_name']
if ('file_name' in data.equities.columns):
if 'file_name' in data.equities.columns:
data.equities['symbol'] = data.equities['file_name']
equities_output = _generate_output_dataframe(
@@ -431,12 +405,9 @@ class AssetDBWriter(with_metaclass(ABCMeta)):
equities_output.fuzzy_symbol.str.upper()
# Convert date columns to UNIX Epoch integers (nanoseconds)
equities_output['start_date'] = \
equities_output['start_date'].apply(self.convert_datetime)
equities_output['end_date'] = \
equities_output['end_date'].apply(self.convert_datetime)
equities_output['first_traded'] = \
equities_output['first_traded'].apply(self.convert_datetime)
for date_col in ('start_date', 'end_date', 'first_traded'):
equities_output[date_col] = \
self.dt_to_epoch_ns(equities_output[date_col])
##############################
# Generate futures DataFrame #
@@ -448,18 +419,10 @@ class AssetDBWriter(with_metaclass(ABCMeta)):
)
# Convert date columns to UNIX Epoch integers (nanoseconds)
futures_output['start_date'] = \
futures_output['start_date'].apply(self.convert_datetime)
futures_output['end_date'] = \
futures_output['end_date'].apply(self.convert_datetime)
futures_output['first_traded'] = \
futures_output['first_traded'].apply(self.convert_datetime)
futures_output['notice_date'] = \
futures_output['notice_date'].apply(self.convert_datetime)
futures_output['expiration_date'] = \
futures_output['expiration_date'].apply(self.convert_datetime)
futures_output['auto_close_date'] = \
futures_output['auto_close_date'].apply(self.convert_datetime)
for date_col in ('start_date', 'end_date', 'first_traded',
'notice_date', 'expiration_date', 'auto_close_date'):
futures_output[date_col] = \
self.dt_to_epoch_ns(futures_output[date_col])
# Convert symbols and root_symbols to upper case.
futures_output['symbol'] = futures_output.symbol.str.upper()
@@ -488,56 +451,15 @@ class AssetDBWriter(with_metaclass(ABCMeta)):
exchanges=exchanges_output,
root_symbols=root_symbols_output)
def convert_datetime(self, dt):
"""Convert a datetime variable to integer of nanoseconds
since UNIX Epoch.
Parameters
----------
dt : datetime-coercible
A string, int or pd.Timestamp instance representing a datetime, or
None/NaN.
Returns
-------
int
nanoseconds since UNIX Epoch, or None if parameter 'dt' is null.
"""
# Check for null parameter
if pd.isnull(dt):
return None
# If no timezone is specified, assume UTC.
# Otherwise, convert to UTC.
@staticmethod
def dt_to_epoch_ns(dt_series):
index = pd.to_datetime(dt_series.values)
try:
dt = pd.Timestamp(dt).tz_localize('UTC')
index = index.tz_localize('UTC')
except TypeError:
dt = pd.Timestamp(dt).tz_convert('UTC')
index = index.tz_convert('UTC')
# Get seconds from UNIX Epoch
total_seconds_from_epoch = self._seconds_from_unix_time(dt)
# Return nanoseconds since UNIX Epoch
return int(total_seconds_from_epoch * 1000000000)
def _seconds_from_unix_time(self, dt):
"""Return seconds between dt and UNIX Epoch.
Parameters
----------
dt: pandas.Timestamp
The time for which to calculate seconds since UNIX Epoch.
Returns
-------
float
Seconds between dt and UNIX Epoch.
"""
epoch = pd.to_datetime(0, utc=True)
delta = dt - epoch
return delta.total_seconds()
return index.view(int)
@abstractmethod
def _load_data(self):
@@ -559,28 +481,7 @@ class AssetDBWriterFromList(AssetDBWriter):
Class used to write list data to SQLite database.
"""
def __init__(self, equities=None, futures=None, exchanges=None,
root_symbols=None):
if equities is not None:
self._equities = equities
else:
self._equities = []
if futures is not None:
self._futures = futures
else:
self._futures = []
if exchanges is not None:
self._exchanges = exchanges
else:
self._exchanges = []
if root_symbols is not None:
self._root_symbols = root_symbols
else:
self._root_symbols = []
defaultval = list
def _load_data(self):
@@ -654,28 +555,7 @@ class AssetDBWriterFromDictionary(AssetDBWriter):
{id_0: {attribute_1 : ...}, id_1: {attribute_2: ...}, ...}
"""
def __init__(self, equities=None, futures=None, exchanges=None,
root_symbols=None):
if equities is not None:
self._equities = equities
else:
self._equities = {}
if futures is not None:
self._futures = futures
else:
self._futures = {}
if exchanges is not None:
self._exchanges = exchanges
else:
self._exchanges = {}
if root_symbols is not None:
self._root_symbols = root_symbols
else:
self._root_symbols = {}
defaultval = dict
def _load_data(self):
@@ -696,42 +576,21 @@ class AssetDBWriterFromDataFrame(AssetDBWriter):
Class used to write pandas.DataFrame data to SQLite database.
"""
def __init__(self, equities=None, futures=None, exchanges=None,
root_symbols=None):
if equities is not None:
self._equities = equities
else:
self._equities = pd.DataFrame()
if futures is not None:
self._futures = futures
else:
self._futures = pd.DataFrame()
if exchanges is not None:
self._exchanges = exchanges
else:
self._exchanges = pd.DataFrame()
if root_symbols is not None:
self._root_symbols = root_symbols
else:
self._root_symbols = pd.DataFrame()
defaultval = pd.DataFrame
def _load_data(self):
# Check whether identifier columns have been provided.
# If they have, set the index to this column.
# If not, assume the index already cotains the identifier information.
if 'sid' in self._equities.columns:
self._equities.set_index(['sid'], inplace=True)
if 'sid' in self._futures.columns:
self._futures.set_index(['sid'], inplace=True)
if 'exchange_id' in self._exchanges.columns:
self._exchanges.set_index(['exchange'], inplace=True)
if 'root_symbol_id' in self._root_symbols.columns:
self._root_symbols.set_index(['root_symbol'], inplace=True)
for df, id_col in [
(self._equities, 'sid'),
(self._futures, 'sid'),
(self._exchanges, 'exchange'),
(self._root_symbols, 'root_symbol'),
]:
if id_col in df.columns:
df.set_index([id_col], inplace=True)
return AssetData(equities=self._equities,
futures=self._futures,
+30 -85
View File
@@ -13,16 +13,15 @@
# limitations under the License.
from abc import ABCMeta
from functools import partial
from numbers import Integral
from operator import getitem, itemgetter
from operator import itemgetter
import warnings
from logbook import Logger
import numpy as np
import pandas as pd
from pandas.tseries.tools import normalize_date
from six import with_metaclass, string_types
from six import with_metaclass, string_types, viewkeys
import sqlalchemy as sa
from toolz import compose
@@ -37,8 +36,6 @@ from zipline.assets import (
Asset, Equity, Future,
)
from zipline.assets.asset_writer import (
FUTURE_TABLE_FIELDS,
EQUITY_TABLE_FIELDS,
split_delimited_symbol,
)
@@ -67,9 +64,9 @@ def _convert_asset_timestamp_fields(dict):
"""
Takes in a dict of Asset init args and converts dates to pd.Timestamps
"""
for key, value in dict.items():
if (key in _asset_timestamp_fields) and (value is not None):
dict[key] = pd.Timestamp(value, tz='UTC')
for key in (_asset_timestamp_fields & viewkeys(dict)):
value = pd.Timestamp(dict[key], tz='UTC')
dict[key] = None if pd.isnull(value) else value
class AssetFinder(object):
@@ -84,60 +81,13 @@ class AssetFinder(object):
self.engine = engine
metadata = sa.MetaData(bind=engine)
self.equities = equities = sa.Table(
'equities',
metadata,
autoload=True,
autoload_with=engine,
)
self.futures_exchanges = sa.Table(
'futures_exchanges',
metadata,
autoload=True,
autoload_with=engine,
)
self.futures_root_symbols = sa.Table(
'futures_root_symbols',
metadata,
autoload=True,
autoload_with=engine,
)
self.futures_contracts = futures_contracts = sa.Table(
'futures_contracts',
metadata,
autoload=True,
autoload_with=engine,
)
self.asset_router = sa.Table(
'asset_router',
metadata,
autoload=True,
autoload_with=engine,
)
# Create the equity and future queries once.
_equity_sid = equities.c.sid
_equity_by_sid = sa.select(
tuple(map(partial(getitem, equities.c), EQUITY_TABLE_FIELDS)),
)
table_names = ['equities', 'futures_exchanges', 'futures_root_symbols',
'futures_contracts', 'asset_router']
metadata.reflect(only=table_names)
for table_name in table_names:
setattr(self, table_name, metadata.tables[table_name])
def select_equity_by_sid(sid):
return _equity_by_sid.where(_equity_sid == int(sid))
self.select_equity_by_sid = select_equity_by_sid
_future_sid = futures_contracts.c.sid
_future_by_sid = sa.select(
tuple(map(
partial(getitem, futures_contracts.c),
FUTURE_TABLE_FIELDS,
)),
)
def select_future_by_sid(sid):
return _future_by_sid.where(_future_sid == int(sid))
self.select_future_by_sid = select_future_by_sid
# Cache for lookup of assets by sid, the objects in the asset lookp may
# be shared with the results from equity and future lookup caches.
#
@@ -203,51 +153,46 @@ class AssetFinder(object):
raise SidNotFound(sid=sid)
def retrieve_all(self, sids, default_none=False):
return [self.retrieve_asset(sid) for sid in sids]
return [self.retrieve_asset(sid, default_none) for sid in sids]
def _retrieve_equity(self, sid):
"""
Retrieve the Equity object of a given sid.
"""
try:
return self._equity_cache[sid]
except KeyError:
pass
data = self.select_equity_by_sid(sid).execute().fetchone()
# Convert 'data' from a RowProxy object to a dict, to allow assignment
data = dict(data.items())
if data:
_convert_asset_timestamp_fields(data)
equity = Equity(**data)
else:
equity = None
self._equity_cache[sid] = equity
return equity
return self._retrieve_asset(
sid, self._equity_cache, self.equities, Equity,
)
def _retrieve_futures_contract(self, sid):
"""
Retrieve the Future object of a given sid.
"""
return self._retrieve_asset(
sid, self._future_cache, self.futures_contracts, Future,
)
@staticmethod
def _select_asset_by_sid(asset_tbl, sid):
return sa.select([asset_tbl]).where(asset_tbl.c.sid == int(sid))
def _retrieve_asset(self, sid, cache, asset_tbl, asset_type):
try:
return self._future_cache[sid]
return cache[sid]
except KeyError:
pass
data = self.select_future_by_sid(sid).execute().fetchone()
data = self._select_asset_by_sid(asset_tbl, sid).execute().fetchone()
# Convert 'data' from a RowProxy object to a dict, to allow assignment
data = dict(data.items())
if data:
_convert_asset_timestamp_fields(data)
future = Future(**data)
asset = asset_type(**data)
else:
future = None
asset = None
self._future_cache[sid] = future
return future
cache[sid] = asset
return asset
def lookup_symbol(self, symbol, as_of_date, fuzzy=False):
"""
@@ -558,7 +503,7 @@ class AssetFinder(object):
# Handle missing assets
if len(missing) > 0:
warnings.warn("Missing assets for identifiers: " + missing)
warnings.warn("Missing assets for identifiers: %s" % missing)
# Return a list of the sids of the found assets
return [asset.sid for asset in matches]
+5 -2
View File
@@ -121,7 +121,10 @@ class TradingEnvironment(object):
else:
self.engine = engine = asset_db_path
self.asset_finder = AssetFinder(engine)
if engine is not None:
self.asset_finder = AssetFinder(engine)
else:
self.asset_finder = None
def write_data(self,
engine=None,
@@ -196,7 +199,7 @@ class TradingEnvironment(object):
.write_all(self.engine, allow_sid_assignment=allow_sid_assignment)
def _write_data_dicts(self, equities=None, futures=None, exchanges=None,
root_symbols=None, allow_sid_assignment=True):
root_symbols=None):
AssetDBWriterFromDictionary(equities, futures, exchanges, root_symbols)\
.write_all(self.engine)
+3 -1
View File
@@ -28,6 +28,9 @@ ctypedef object DatetimeIndex_t
ctypedef object Int64Index_t
from zipline.lib.adjustment import Float64Multiply
from zipline.assets.asset_writer import (
SQLITE_MAX_VARIABLE_NUMBER as SQLITE_MAX_IN_STATEMENT,
)
_SID_QUERY_TEMPLATE = """
SELECT DISTINCT sid FROM {0}
@@ -44,7 +47,6 @@ FROM {0}
WHERE sid IN ({1}) AND effective_date >= {2} AND effective_date <= {3}
"""
cdef int SQLITE_MAX_IN_STATEMENT = 999
EPOCH = Timestamp(0, tz='UTC')
cdef set _get_sids_from_table(object db,