mirror of
https://github.com/wassname/catalyst.git
synced 2026-06-30 06:44:50 +08:00
+109
-250
@@ -13,47 +13,11 @@ import sqlalchemy as sa
|
||||
from zipline.errors import SidAssignmentError
|
||||
from zipline.assets._assets import Asset
|
||||
|
||||
SQLITE_MAX_VARIABLE_NUMBER = 999
|
||||
|
||||
# Define a namedtuple for use with the load_data and _load_data methods
|
||||
AssetData = namedtuple('AssetData', 'equities futures exchanges root_symbols')
|
||||
|
||||
# Expected fields for an Asset's metadata
|
||||
ASSET_TABLE_FIELDS = frozenset({
|
||||
'sid',
|
||||
'symbol',
|
||||
'asset_name',
|
||||
'start_date',
|
||||
'end_date',
|
||||
'first_traded',
|
||||
'exchange',
|
||||
})
|
||||
|
||||
# Expected fields for a Future's metadata
|
||||
FUTURE_TABLE_FIELDS = ASSET_TABLE_FIELDS | {
|
||||
'notice_date',
|
||||
'expiration_date',
|
||||
'auto_close_date',
|
||||
'contract_multiplier',
|
||||
}
|
||||
|
||||
# Expected fields for an Equity's metadata
|
||||
EQUITY_TABLE_FIELDS = ASSET_TABLE_FIELDS | {
|
||||
'company_symbol',
|
||||
'share_class_symbol',
|
||||
'fuzzy_symbol',
|
||||
}
|
||||
|
||||
EXCHANGE_TABLE_FIELDS = frozenset({
|
||||
'exchange',
|
||||
'timezone',
|
||||
})
|
||||
|
||||
ROOT_SYMBOL_TABLE_FIELDS = frozenset({
|
||||
'root_symbol',
|
||||
'root_symbol_id',
|
||||
'sector',
|
||||
'description',
|
||||
'exchange',
|
||||
})
|
||||
|
||||
# Default values for the equities DataFrame
|
||||
_equities_defaults = {
|
||||
@@ -162,21 +126,21 @@ def _generate_output_dataframe(data_subset, defaults):
|
||||
"""
|
||||
# The columns provided.
|
||||
cols = set(data_subset.columns)
|
||||
desired_cols = {col for col in defaults.keys()}
|
||||
desired_cols = set(defaults)
|
||||
|
||||
# Drop columns with unrecognised headers.
|
||||
data_subset.drop(cols - (cols & desired_cols),
|
||||
data_subset.drop(cols - desired_cols,
|
||||
axis=1,
|
||||
inplace=True)
|
||||
|
||||
# Get those columns which we need but
|
||||
# for which no data has been supplied.
|
||||
need = desired_cols - set(data_subset.columns)
|
||||
need = desired_cols - cols
|
||||
|
||||
# Combine the users supplied data with our required columns.
|
||||
output = pd.concat(
|
||||
(data_subset, pd.DataFrame(
|
||||
_dict_subset(defaults, need),
|
||||
{k: defaults[k] for k in need},
|
||||
data_subset.index,
|
||||
)),
|
||||
axis=1,
|
||||
@@ -186,13 +150,6 @@ def _generate_output_dataframe(data_subset, defaults):
|
||||
return output
|
||||
|
||||
|
||||
def _dict_subset(dict_, subset):
|
||||
res = {}
|
||||
for k in subset:
|
||||
res[k] = dict_[k]
|
||||
return res
|
||||
|
||||
|
||||
class AssetDBWriter(with_metaclass(ABCMeta)):
|
||||
"""
|
||||
Class used to write arbitrary data to SQLite database.
|
||||
@@ -209,10 +166,34 @@ class AssetDBWriter(with_metaclass(ABCMeta)):
|
||||
Returns data in standard format.
|
||||
|
||||
"""
|
||||
CHUNK_SIZE = SQLITE_MAX_VARIABLE_NUMBER
|
||||
|
||||
def __init__(self, equities=None, futures=None, exchanges=None,
|
||||
root_symbols=None):
|
||||
|
||||
if equities is None:
|
||||
equities = self.defaultval()
|
||||
self._equities = equities
|
||||
|
||||
if futures is None:
|
||||
futures = self.defaultval()
|
||||
self._futures = futures
|
||||
|
||||
if exchanges is None:
|
||||
exchanges = self.defaultval()
|
||||
self._exchanges = exchanges
|
||||
|
||||
if root_symbols is None:
|
||||
root_symbols = self.defaultval()
|
||||
self._root_symbols = root_symbols
|
||||
|
||||
@abstractmethod
|
||||
def defaultval(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def write_all(self,
|
||||
engine,
|
||||
allow_sid_assignment=True,
|
||||
constraints=True):
|
||||
allow_sid_assignment=True):
|
||||
""" Write pre-supplied data to SQLite.
|
||||
|
||||
Parameters
|
||||
@@ -230,7 +211,7 @@ class AssetDBWriter(with_metaclass(ABCMeta)):
|
||||
# Begin an SQL transaction.
|
||||
with engine.begin() as txn:
|
||||
# Create SQL tables.
|
||||
self.init_db(txn, constraints)
|
||||
self.init_db(txn)
|
||||
# Get the data to add to SQL.
|
||||
data = self.load_data()
|
||||
# Write the data to SQL.
|
||||
@@ -239,45 +220,40 @@ class AssetDBWriter(with_metaclass(ABCMeta)):
|
||||
self._write_futures(data.futures, txn)
|
||||
self._write_equities(data.equities, txn)
|
||||
|
||||
def _write_exchanges(self, exchanges, bind=None):
|
||||
recs = exchanges.reset_index().rename_axis(
|
||||
{'index': 'exchange'},
|
||||
1,
|
||||
).to_dict('records')
|
||||
# In SQLAlchemy, insert().values([]) will insert NULLs,
|
||||
# hence we check first to avoid violating NOT NULL constraints.
|
||||
if recs:
|
||||
self.futures_exchanges.insert().values(recs).execute(bind=bind)
|
||||
def _write_df_to_table(self, df, tbl, bind):
|
||||
df.to_sql(
|
||||
tbl.name,
|
||||
bind.connection,
|
||||
index_label=[col.name for col in tbl.primary_key.columns][0],
|
||||
if_exists='append',
|
||||
chunksize=self.CHUNK_SIZE,
|
||||
)
|
||||
|
||||
def _write_root_symbols(self, root_symbols, bind=None):
|
||||
recs = root_symbols.reset_index().rename_axis(
|
||||
{'index': 'root_symbol'},
|
||||
1,
|
||||
).to_dict('records')
|
||||
if recs:
|
||||
self.futures_root_symbols.insert().values(recs).execute(bind=bind)
|
||||
def _write_assets(self, assets, asset_tbl, asset_type, bind):
|
||||
self._write_df_to_table(assets, asset_tbl, bind)
|
||||
|
||||
def _write_futures(self, futures, bind=None):
|
||||
recs = futures.reset_index().rename_axis(
|
||||
{'index': 'sid'},
|
||||
1,
|
||||
).to_dict('records')
|
||||
for record in recs:
|
||||
self.futures_contracts.insert().values([record]).execute(bind=bind)
|
||||
self.asset_router.insert().values([(record['sid'], 'future')])\
|
||||
.execute(bind=bind)
|
||||
pd.DataFrame({self.asset_router.c.sid.name: assets.index.values,
|
||||
self.asset_router.c.asset_type.name: asset_type}).to_sql(
|
||||
self.asset_router.name,
|
||||
bind.connection,
|
||||
if_exists='append',
|
||||
index=False,
|
||||
chunksize=self.CHUNK_SIZE,
|
||||
)
|
||||
|
||||
def _write_equities(self, equities, bind=None):
|
||||
recs = equities.reset_index().rename_axis(
|
||||
{'index': 'sid'},
|
||||
1,
|
||||
).to_dict('records')
|
||||
for record in recs:
|
||||
self.equities.insert().values([record]).execute(bind=bind)
|
||||
self.asset_router.insert().values((record['sid'], 'equity'))\
|
||||
.execute(bind=bind)
|
||||
def _write_exchanges(self, exchanges, bind):
|
||||
self._write_df_to_table(exchanges, self.futures_exchanges, bind)
|
||||
|
||||
def init_db(self, engine, constraints=True):
|
||||
def _write_root_symbols(self, root_symbols, bind):
|
||||
self._write_df_to_table(root_symbols, self.futures_root_symbols, bind)
|
||||
|
||||
def _write_futures(self, futures, bind):
|
||||
self._write_assets(futures, self.futures_contracts, 'future', bind)
|
||||
|
||||
def _write_equities(self, equities, bind):
|
||||
self._write_assets(equities, self.equities, 'equity', bind)
|
||||
|
||||
def init_db(self, engine):
|
||||
"""Connect to database and create tables.
|
||||
|
||||
Parameters
|
||||
@@ -287,7 +263,8 @@ class AssetDBWriter(with_metaclass(ABCMeta)):
|
||||
constraints : bool, optional
|
||||
If True, create SQL ForeignKey and PrimaryKey constraints.
|
||||
"""
|
||||
self.sql_metadata = metadata = sa.MetaData(bind=engine)
|
||||
metadata = sa.MetaData(bind=engine)
|
||||
|
||||
self.equities = sa.Table(
|
||||
'equities',
|
||||
metadata,
|
||||
@@ -296,16 +273,16 @@ class AssetDBWriter(with_metaclass(ABCMeta)):
|
||||
sa.Integer,
|
||||
unique=True,
|
||||
nullable=False,
|
||||
primary_key=constraints,
|
||||
primary_key=True,
|
||||
),
|
||||
sa.Column('symbol', sa.Text),
|
||||
sa.Column('company_symbol', sa.Text, index=True),
|
||||
sa.Column('share_class_symbol', sa.Text),
|
||||
sa.Column('fuzzy_symbol', sa.Text, index=True),
|
||||
sa.Column('asset_name', sa.Text),
|
||||
sa.Column('start_date', sa.Integer, default=0),
|
||||
sa.Column('end_date', sa.Integer),
|
||||
sa.Column('first_traded', sa.Integer),
|
||||
sa.Column('start_date', sa.Integer, default=0, nullable=False),
|
||||
sa.Column('end_date', sa.Integer, nullable=False),
|
||||
sa.Column('first_traded', sa.Integer, nullable=False),
|
||||
sa.Column('exchange', sa.Text),
|
||||
)
|
||||
self.futures_exchanges = sa.Table(
|
||||
@@ -316,7 +293,7 @@ class AssetDBWriter(with_metaclass(ABCMeta)):
|
||||
sa.Text,
|
||||
unique=True,
|
||||
nullable=False,
|
||||
primary_key=constraints,
|
||||
primary_key=True,
|
||||
),
|
||||
sa.Column('timezone', sa.Text),
|
||||
)
|
||||
@@ -328,7 +305,7 @@ class AssetDBWriter(with_metaclass(ABCMeta)):
|
||||
sa.Text,
|
||||
unique=True,
|
||||
nullable=False,
|
||||
primary_key=constraints,
|
||||
primary_key=True,
|
||||
),
|
||||
sa.Column('root_symbol_id', sa.Integer),
|
||||
sa.Column('sector', sa.Text),
|
||||
@@ -336,8 +313,7 @@ class AssetDBWriter(with_metaclass(ABCMeta)):
|
||||
sa.Column(
|
||||
'exchange',
|
||||
sa.Text,
|
||||
*((sa.ForeignKey(self.futures_exchanges.c.exchange),)
|
||||
if constraints else ())
|
||||
sa.ForeignKey(self.futures_exchanges.c.exchange),
|
||||
),
|
||||
)
|
||||
self.futures_contracts = sa.Table(
|
||||
@@ -348,28 +324,26 @@ class AssetDBWriter(with_metaclass(ABCMeta)):
|
||||
sa.Integer,
|
||||
unique=True,
|
||||
nullable=False,
|
||||
primary_key=constraints,
|
||||
primary_key=True,
|
||||
),
|
||||
sa.Column('symbol', sa.Text),
|
||||
sa.Column(
|
||||
'root_symbol',
|
||||
sa.Text,
|
||||
*((sa.ForeignKey(self.futures_root_symbols.c.root_symbol),)
|
||||
if constraints else ())
|
||||
sa.ForeignKey(self.futures_root_symbols.c.root_symbol),
|
||||
),
|
||||
sa.Column('asset_name', sa.Text),
|
||||
sa.Column('start_date', sa.Integer, default=0),
|
||||
sa.Column('end_date', sa.Integer),
|
||||
sa.Column('first_traded', sa.Integer),
|
||||
sa.Column('start_date', sa.Integer, default=0, nullable=False),
|
||||
sa.Column('end_date', sa.Integer, nullable=False),
|
||||
sa.Column('first_traded', sa.Integer, nullable=False),
|
||||
sa.Column(
|
||||
'exchange',
|
||||
sa.Text,
|
||||
*((sa.ForeignKey(self.futures_exchanges.c.exchange),)
|
||||
if constraints else ())
|
||||
sa.ForeignKey(self.futures_exchanges.c.exchange),
|
||||
),
|
||||
sa.Column('notice_date', sa.Integer),
|
||||
sa.Column('expiration_date', sa.Integer),
|
||||
sa.Column('auto_close_date', sa.Integer),
|
||||
sa.Column('notice_date', sa.Integer, nullable=False),
|
||||
sa.Column('expiration_date', sa.Integer, nullable=False),
|
||||
sa.Column('auto_close_date', sa.Integer, nullable=False),
|
||||
sa.Column('contract_multiplier', sa.Float),
|
||||
)
|
||||
self.asset_router = sa.Table(
|
||||
@@ -380,7 +354,7 @@ class AssetDBWriter(with_metaclass(ABCMeta)):
|
||||
sa.Integer,
|
||||
unique=True,
|
||||
nullable=False,
|
||||
primary_key=constraints),
|
||||
primary_key=True),
|
||||
sa.Column('asset_type', sa.Text),
|
||||
)
|
||||
# Create the SQL tables if they do not already exist.
|
||||
@@ -400,10 +374,10 @@ class AssetDBWriter(with_metaclass(ABCMeta)):
|
||||
###############################
|
||||
|
||||
# HACK: If company_name is provided, map it to asset_name
|
||||
if ('company_name' in data.equities.columns) \
|
||||
and ('asset_name' not in data.equities.columns):
|
||||
if ('company_name' in data.equities.columns
|
||||
and 'asset_name' not in data.equities.columns):
|
||||
data.equities['asset_name'] = data.equities['company_name']
|
||||
if ('file_name' in data.equities.columns):
|
||||
if 'file_name' in data.equities.columns:
|
||||
data.equities['symbol'] = data.equities['file_name']
|
||||
|
||||
equities_output = _generate_output_dataframe(
|
||||
@@ -431,12 +405,9 @@ class AssetDBWriter(with_metaclass(ABCMeta)):
|
||||
equities_output.fuzzy_symbol.str.upper()
|
||||
|
||||
# Convert date columns to UNIX Epoch integers (nanoseconds)
|
||||
equities_output['start_date'] = \
|
||||
equities_output['start_date'].apply(self.convert_datetime)
|
||||
equities_output['end_date'] = \
|
||||
equities_output['end_date'].apply(self.convert_datetime)
|
||||
equities_output['first_traded'] = \
|
||||
equities_output['first_traded'].apply(self.convert_datetime)
|
||||
for date_col in ('start_date', 'end_date', 'first_traded'):
|
||||
equities_output[date_col] = \
|
||||
self.dt_to_epoch_ns(equities_output[date_col])
|
||||
|
||||
##############################
|
||||
# Generate futures DataFrame #
|
||||
@@ -448,18 +419,10 @@ class AssetDBWriter(with_metaclass(ABCMeta)):
|
||||
)
|
||||
|
||||
# Convert date columns to UNIX Epoch integers (nanoseconds)
|
||||
futures_output['start_date'] = \
|
||||
futures_output['start_date'].apply(self.convert_datetime)
|
||||
futures_output['end_date'] = \
|
||||
futures_output['end_date'].apply(self.convert_datetime)
|
||||
futures_output['first_traded'] = \
|
||||
futures_output['first_traded'].apply(self.convert_datetime)
|
||||
futures_output['notice_date'] = \
|
||||
futures_output['notice_date'].apply(self.convert_datetime)
|
||||
futures_output['expiration_date'] = \
|
||||
futures_output['expiration_date'].apply(self.convert_datetime)
|
||||
futures_output['auto_close_date'] = \
|
||||
futures_output['auto_close_date'].apply(self.convert_datetime)
|
||||
for date_col in ('start_date', 'end_date', 'first_traded',
|
||||
'notice_date', 'expiration_date', 'auto_close_date'):
|
||||
futures_output[date_col] = \
|
||||
self.dt_to_epoch_ns(futures_output[date_col])
|
||||
|
||||
# Convert symbols and root_symbols to upper case.
|
||||
futures_output['symbol'] = futures_output.symbol.str.upper()
|
||||
@@ -488,56 +451,15 @@ class AssetDBWriter(with_metaclass(ABCMeta)):
|
||||
exchanges=exchanges_output,
|
||||
root_symbols=root_symbols_output)
|
||||
|
||||
def convert_datetime(self, dt):
|
||||
"""Convert a datetime variable to integer of nanoseconds
|
||||
since UNIX Epoch.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
dt : datetime-coercible
|
||||
A string, int or pd.Timestamp instance representing a datetime, or
|
||||
None/NaN.
|
||||
|
||||
Returns
|
||||
-------
|
||||
int
|
||||
nanoseconds since UNIX Epoch, or None if parameter 'dt' is null.
|
||||
"""
|
||||
|
||||
# Check for null parameter
|
||||
if pd.isnull(dt):
|
||||
return None
|
||||
|
||||
# If no timezone is specified, assume UTC.
|
||||
# Otherwise, convert to UTC.
|
||||
@staticmethod
|
||||
def dt_to_epoch_ns(dt_series):
|
||||
index = pd.to_datetime(dt_series.values)
|
||||
try:
|
||||
dt = pd.Timestamp(dt).tz_localize('UTC')
|
||||
index = index.tz_localize('UTC')
|
||||
except TypeError:
|
||||
dt = pd.Timestamp(dt).tz_convert('UTC')
|
||||
index = index.tz_convert('UTC')
|
||||
|
||||
# Get seconds from UNIX Epoch
|
||||
total_seconds_from_epoch = self._seconds_from_unix_time(dt)
|
||||
|
||||
# Return nanoseconds since UNIX Epoch
|
||||
return int(total_seconds_from_epoch * 1000000000)
|
||||
|
||||
def _seconds_from_unix_time(self, dt):
|
||||
"""Return seconds between dt and UNIX Epoch.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
dt: pandas.Timestamp
|
||||
The time for which to calculate seconds since UNIX Epoch.
|
||||
|
||||
Returns
|
||||
-------
|
||||
float
|
||||
Seconds between dt and UNIX Epoch.
|
||||
|
||||
"""
|
||||
epoch = pd.to_datetime(0, utc=True)
|
||||
delta = dt - epoch
|
||||
return delta.total_seconds()
|
||||
return index.view(int)
|
||||
|
||||
@abstractmethod
|
||||
def _load_data(self):
|
||||
@@ -559,28 +481,7 @@ class AssetDBWriterFromList(AssetDBWriter):
|
||||
Class used to write list data to SQLite database.
|
||||
"""
|
||||
|
||||
def __init__(self, equities=None, futures=None, exchanges=None,
|
||||
root_symbols=None):
|
||||
|
||||
if equities is not None:
|
||||
self._equities = equities
|
||||
else:
|
||||
self._equities = []
|
||||
|
||||
if futures is not None:
|
||||
self._futures = futures
|
||||
else:
|
||||
self._futures = []
|
||||
|
||||
if exchanges is not None:
|
||||
self._exchanges = exchanges
|
||||
else:
|
||||
self._exchanges = []
|
||||
|
||||
if root_symbols is not None:
|
||||
self._root_symbols = root_symbols
|
||||
else:
|
||||
self._root_symbols = []
|
||||
defaultval = list
|
||||
|
||||
def _load_data(self):
|
||||
|
||||
@@ -654,28 +555,7 @@ class AssetDBWriterFromDictionary(AssetDBWriter):
|
||||
{id_0: {attribute_1 : ...}, id_1: {attribute_2: ...}, ...}
|
||||
"""
|
||||
|
||||
def __init__(self, equities=None, futures=None, exchanges=None,
|
||||
root_symbols=None):
|
||||
|
||||
if equities is not None:
|
||||
self._equities = equities
|
||||
else:
|
||||
self._equities = {}
|
||||
|
||||
if futures is not None:
|
||||
self._futures = futures
|
||||
else:
|
||||
self._futures = {}
|
||||
|
||||
if exchanges is not None:
|
||||
self._exchanges = exchanges
|
||||
else:
|
||||
self._exchanges = {}
|
||||
|
||||
if root_symbols is not None:
|
||||
self._root_symbols = root_symbols
|
||||
else:
|
||||
self._root_symbols = {}
|
||||
defaultval = dict
|
||||
|
||||
def _load_data(self):
|
||||
|
||||
@@ -696,42 +576,21 @@ class AssetDBWriterFromDataFrame(AssetDBWriter):
|
||||
Class used to write pandas.DataFrame data to SQLite database.
|
||||
"""
|
||||
|
||||
def __init__(self, equities=None, futures=None, exchanges=None,
|
||||
root_symbols=None):
|
||||
|
||||
if equities is not None:
|
||||
self._equities = equities
|
||||
else:
|
||||
self._equities = pd.DataFrame()
|
||||
|
||||
if futures is not None:
|
||||
self._futures = futures
|
||||
else:
|
||||
self._futures = pd.DataFrame()
|
||||
|
||||
if exchanges is not None:
|
||||
self._exchanges = exchanges
|
||||
else:
|
||||
self._exchanges = pd.DataFrame()
|
||||
|
||||
if root_symbols is not None:
|
||||
self._root_symbols = root_symbols
|
||||
else:
|
||||
self._root_symbols = pd.DataFrame()
|
||||
defaultval = pd.DataFrame
|
||||
|
||||
def _load_data(self):
|
||||
|
||||
# Check whether identifier columns have been provided.
|
||||
# If they have, set the index to this column.
|
||||
# If not, assume the index already cotains the identifier information.
|
||||
if 'sid' in self._equities.columns:
|
||||
self._equities.set_index(['sid'], inplace=True)
|
||||
if 'sid' in self._futures.columns:
|
||||
self._futures.set_index(['sid'], inplace=True)
|
||||
if 'exchange_id' in self._exchanges.columns:
|
||||
self._exchanges.set_index(['exchange'], inplace=True)
|
||||
if 'root_symbol_id' in self._root_symbols.columns:
|
||||
self._root_symbols.set_index(['root_symbol'], inplace=True)
|
||||
for df, id_col in [
|
||||
(self._equities, 'sid'),
|
||||
(self._futures, 'sid'),
|
||||
(self._exchanges, 'exchange'),
|
||||
(self._root_symbols, 'root_symbol'),
|
||||
]:
|
||||
if id_col in df.columns:
|
||||
df.set_index([id_col], inplace=True)
|
||||
|
||||
return AssetData(equities=self._equities,
|
||||
futures=self._futures,
|
||||
|
||||
+30
-85
@@ -13,16 +13,15 @@
|
||||
# limitations under the License.
|
||||
|
||||
from abc import ABCMeta
|
||||
from functools import partial
|
||||
from numbers import Integral
|
||||
from operator import getitem, itemgetter
|
||||
from operator import itemgetter
|
||||
import warnings
|
||||
|
||||
from logbook import Logger
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from pandas.tseries.tools import normalize_date
|
||||
from six import with_metaclass, string_types
|
||||
from six import with_metaclass, string_types, viewkeys
|
||||
import sqlalchemy as sa
|
||||
from toolz import compose
|
||||
|
||||
@@ -37,8 +36,6 @@ from zipline.assets import (
|
||||
Asset, Equity, Future,
|
||||
)
|
||||
from zipline.assets.asset_writer import (
|
||||
FUTURE_TABLE_FIELDS,
|
||||
EQUITY_TABLE_FIELDS,
|
||||
split_delimited_symbol,
|
||||
)
|
||||
|
||||
@@ -67,9 +64,9 @@ def _convert_asset_timestamp_fields(dict):
|
||||
"""
|
||||
Takes in a dict of Asset init args and converts dates to pd.Timestamps
|
||||
"""
|
||||
for key, value in dict.items():
|
||||
if (key in _asset_timestamp_fields) and (value is not None):
|
||||
dict[key] = pd.Timestamp(value, tz='UTC')
|
||||
for key in (_asset_timestamp_fields & viewkeys(dict)):
|
||||
value = pd.Timestamp(dict[key], tz='UTC')
|
||||
dict[key] = None if pd.isnull(value) else value
|
||||
|
||||
|
||||
class AssetFinder(object):
|
||||
@@ -84,60 +81,13 @@ class AssetFinder(object):
|
||||
|
||||
self.engine = engine
|
||||
metadata = sa.MetaData(bind=engine)
|
||||
self.equities = equities = sa.Table(
|
||||
'equities',
|
||||
metadata,
|
||||
autoload=True,
|
||||
autoload_with=engine,
|
||||
)
|
||||
self.futures_exchanges = sa.Table(
|
||||
'futures_exchanges',
|
||||
metadata,
|
||||
autoload=True,
|
||||
autoload_with=engine,
|
||||
)
|
||||
self.futures_root_symbols = sa.Table(
|
||||
'futures_root_symbols',
|
||||
metadata,
|
||||
autoload=True,
|
||||
autoload_with=engine,
|
||||
)
|
||||
self.futures_contracts = futures_contracts = sa.Table(
|
||||
'futures_contracts',
|
||||
metadata,
|
||||
autoload=True,
|
||||
autoload_with=engine,
|
||||
)
|
||||
self.asset_router = sa.Table(
|
||||
'asset_router',
|
||||
metadata,
|
||||
autoload=True,
|
||||
autoload_with=engine,
|
||||
)
|
||||
|
||||
# Create the equity and future queries once.
|
||||
_equity_sid = equities.c.sid
|
||||
_equity_by_sid = sa.select(
|
||||
tuple(map(partial(getitem, equities.c), EQUITY_TABLE_FIELDS)),
|
||||
)
|
||||
table_names = ['equities', 'futures_exchanges', 'futures_root_symbols',
|
||||
'futures_contracts', 'asset_router']
|
||||
metadata.reflect(only=table_names)
|
||||
for table_name in table_names:
|
||||
setattr(self, table_name, metadata.tables[table_name])
|
||||
|
||||
def select_equity_by_sid(sid):
|
||||
return _equity_by_sid.where(_equity_sid == int(sid))
|
||||
|
||||
self.select_equity_by_sid = select_equity_by_sid
|
||||
|
||||
_future_sid = futures_contracts.c.sid
|
||||
_future_by_sid = sa.select(
|
||||
tuple(map(
|
||||
partial(getitem, futures_contracts.c),
|
||||
FUTURE_TABLE_FIELDS,
|
||||
)),
|
||||
)
|
||||
|
||||
def select_future_by_sid(sid):
|
||||
return _future_by_sid.where(_future_sid == int(sid))
|
||||
|
||||
self.select_future_by_sid = select_future_by_sid
|
||||
# Cache for lookup of assets by sid, the objects in the asset lookp may
|
||||
# be shared with the results from equity and future lookup caches.
|
||||
#
|
||||
@@ -203,51 +153,46 @@ class AssetFinder(object):
|
||||
raise SidNotFound(sid=sid)
|
||||
|
||||
def retrieve_all(self, sids, default_none=False):
|
||||
return [self.retrieve_asset(sid) for sid in sids]
|
||||
return [self.retrieve_asset(sid, default_none) for sid in sids]
|
||||
|
||||
def _retrieve_equity(self, sid):
|
||||
"""
|
||||
Retrieve the Equity object of a given sid.
|
||||
"""
|
||||
try:
|
||||
return self._equity_cache[sid]
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
data = self.select_equity_by_sid(sid).execute().fetchone()
|
||||
# Convert 'data' from a RowProxy object to a dict, to allow assignment
|
||||
data = dict(data.items())
|
||||
if data:
|
||||
_convert_asset_timestamp_fields(data)
|
||||
|
||||
equity = Equity(**data)
|
||||
else:
|
||||
equity = None
|
||||
|
||||
self._equity_cache[sid] = equity
|
||||
return equity
|
||||
return self._retrieve_asset(
|
||||
sid, self._equity_cache, self.equities, Equity,
|
||||
)
|
||||
|
||||
def _retrieve_futures_contract(self, sid):
|
||||
"""
|
||||
Retrieve the Future object of a given sid.
|
||||
"""
|
||||
return self._retrieve_asset(
|
||||
sid, self._future_cache, self.futures_contracts, Future,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _select_asset_by_sid(asset_tbl, sid):
|
||||
return sa.select([asset_tbl]).where(asset_tbl.c.sid == int(sid))
|
||||
|
||||
def _retrieve_asset(self, sid, cache, asset_tbl, asset_type):
|
||||
try:
|
||||
return self._future_cache[sid]
|
||||
return cache[sid]
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
data = self.select_future_by_sid(sid).execute().fetchone()
|
||||
data = self._select_asset_by_sid(asset_tbl, sid).execute().fetchone()
|
||||
# Convert 'data' from a RowProxy object to a dict, to allow assignment
|
||||
data = dict(data.items())
|
||||
if data:
|
||||
_convert_asset_timestamp_fields(data)
|
||||
|
||||
future = Future(**data)
|
||||
asset = asset_type(**data)
|
||||
else:
|
||||
future = None
|
||||
asset = None
|
||||
|
||||
self._future_cache[sid] = future
|
||||
return future
|
||||
cache[sid] = asset
|
||||
return asset
|
||||
|
||||
def lookup_symbol(self, symbol, as_of_date, fuzzy=False):
|
||||
"""
|
||||
@@ -558,7 +503,7 @@ class AssetFinder(object):
|
||||
|
||||
# Handle missing assets
|
||||
if len(missing) > 0:
|
||||
warnings.warn("Missing assets for identifiers: " + missing)
|
||||
warnings.warn("Missing assets for identifiers: %s" % missing)
|
||||
|
||||
# Return a list of the sids of the found assets
|
||||
return [asset.sid for asset in matches]
|
||||
|
||||
@@ -121,7 +121,10 @@ class TradingEnvironment(object):
|
||||
else:
|
||||
self.engine = engine = asset_db_path
|
||||
|
||||
self.asset_finder = AssetFinder(engine)
|
||||
if engine is not None:
|
||||
self.asset_finder = AssetFinder(engine)
|
||||
else:
|
||||
self.asset_finder = None
|
||||
|
||||
def write_data(self,
|
||||
engine=None,
|
||||
@@ -196,7 +199,7 @@ class TradingEnvironment(object):
|
||||
.write_all(self.engine, allow_sid_assignment=allow_sid_assignment)
|
||||
|
||||
def _write_data_dicts(self, equities=None, futures=None, exchanges=None,
|
||||
root_symbols=None, allow_sid_assignment=True):
|
||||
root_symbols=None):
|
||||
AssetDBWriterFromDictionary(equities, futures, exchanges, root_symbols)\
|
||||
.write_all(self.engine)
|
||||
|
||||
|
||||
@@ -28,6 +28,9 @@ ctypedef object DatetimeIndex_t
|
||||
ctypedef object Int64Index_t
|
||||
|
||||
from zipline.lib.adjustment import Float64Multiply
|
||||
from zipline.assets.asset_writer import (
|
||||
SQLITE_MAX_VARIABLE_NUMBER as SQLITE_MAX_IN_STATEMENT,
|
||||
)
|
||||
|
||||
_SID_QUERY_TEMPLATE = """
|
||||
SELECT DISTINCT sid FROM {0}
|
||||
@@ -44,7 +47,6 @@ FROM {0}
|
||||
WHERE sid IN ({1}) AND effective_date >= {2} AND effective_date <= {3}
|
||||
"""
|
||||
|
||||
cdef int SQLITE_MAX_IN_STATEMENT = 999
|
||||
EPOCH = Timestamp(0, tz='UTC')
|
||||
|
||||
cdef set _get_sids_from_table(object db,
|
||||
|
||||
Reference in New Issue
Block a user