ENH: Replace update_asset_finder with write_data

The write_data methods invokes the relevant AssetDBWriter subclass
to write data to the database. update_asset_finder is no longer
a relevant method since the AssetFinder is strictly a reader class.
This commit is contained in:
Stewart Douglas
2015-08-08 23:10:40 -04:00
committed by jfkirk
parent 9660447ed0
commit 501fd58fdf
9 changed files with 76 additions and 73 deletions
+1 -1
View File
@@ -38,7 +38,7 @@ class BlotterTestCase(TestCase):
@with_environment()
def setUp(self, env=None):
setup_logger(self)
env.update_asset_finder(identifiers=[24])
env.write_data(equities_identifiers=[24])
def tearDown(self):
teardown_logger(self)
+5 -7
View File
@@ -1894,7 +1894,7 @@ class TestPerformanceTracker(unittest.TestCase):
foosid = 1
barsid = 2
env.update_asset_finder(identifiers=[foosid, barsid])
env.write_data(equities_identifiers=[foosid, barsid])
foo_event_1 = factory.create_trade(foosid, 10.0, 20, start_dt)
order_event_1 = Order(sid=foo_event_1.sid,
@@ -1989,7 +1989,7 @@ class TestPerformanceTracker(unittest.TestCase):
@with_environment()
def test_close_position_event(self, env=None):
env.update_asset_finder(identifiers=[1, 2])
env.write_data(equities_identifiers=[1, 2])
pt = perf.PositionTracker()
dt = pd.Timestamp("1984/03/06 3:00PM")
pos1 = perf.Position(1, amount=np.float64(120.0),
@@ -2133,9 +2133,7 @@ class TestPositionTracker(unittest.TestCase):
2: {'asset_type': 'future',
'contract_multiplier': 1000}}
asset_finder = AssetFinder()
env.update_asset_finder(
asset_finder=asset_finder,
asset_metadata=metadata)
env.write_data(equities_data=metadata)
pt = perf.PositionTracker()
dt = pd.Timestamp("1984/03/06 3:00PM")
pos1 = perf.Position(1, amount=np.float64(100.0),
@@ -2163,7 +2161,7 @@ class TestPositionTracker(unittest.TestCase):
'contract_multiplier': 1000},
4: {'asset_type': 'future',
'contract_multiplier': 1000}}
env.update_asset_finder(asset_metadata=metadata)
env.write_data(equities_data=metadata)
pt = perf.PositionTracker()
dt = pd.Timestamp("1984/03/06 3:00PM")
pos1 = perf.Position(1, amount=np.float64(10.0),
@@ -2197,7 +2195,7 @@ class TestPositionTracker(unittest.TestCase):
metadata = {1: {'asset_type': 'equity'},
2: {'asset_type': 'future',
'contract_multiplier': 1000}}
env.update_asset_finder(asset_metadata=metadata)
env.write_data(equities_data=metadata)
pt = perf.PositionTracker()
dt = pd.Timestamp("1984/03/06 3:00PM")
pos1 = perf.Position(1, amount=np.float64(120.0),
+2 -3
View File
@@ -66,9 +66,8 @@ class SecurityListTestCase(TestCase):
self.trading_day_before_first_kd = datetime(
2015, 1, 23, 0, 0, tzinfo=pytz.utc)
env.update_asset_finder(
clear_metadata=True,
identifiers=["BZQ", "URTY", "JFT", "AAPL", "GOOG"]
env.write_data(
equities_identifiers=["BZQ", "URTY", "JFT", "AAPL", "GOOG"]
)
setup_logger(self)
+4 -4
View File
@@ -205,11 +205,11 @@ class TradingAlgorithm(object):
# Update the TradingEnvironment with the provided asset metadata
self.trading_environment = kwargs.pop('env',
TradingEnvironment.instance())
self.trading_environment.update_asset_finder(
asset_finder=kwargs.pop('asset_finder', None),
asset_metadata=kwargs.pop('asset_metadata', None),
identifiers=kwargs.pop('identifiers', None)
self.trading_environment.write_data(
equities_data=kwargs.pop('asset_metadata', None),
equities_identifiers=kwargs.pop('identifiers', None),
)
# Pull in the environment's new AssetFinder for quick reference
self.asset_finder = self.trading_environment.asset_finder
self.init_engine(kwargs.pop('ffc_loader', None))
-10
View File
@@ -243,16 +243,6 @@ AssetMetaData contained an invalid Asset type: '{asset_type}'.
""".strip()
class UpdateAssetFinderTypeError(ZiplineError):
"""
Raised when TradingEnvironment.update_asset_finder() gets an asset_finder
arg that is not of AssetFinder class.
"""
msg = """
TradingEnvironment can not set asset_finder to object of class {cls}.
""".strip()
class ConsumeAssetMetaDataError(ZiplineError):
"""
Raised when AssetFinder.consume() is called on an invalid object.
+58 -42
View File
@@ -22,6 +22,7 @@ import sqlite3
import pandas as pd
import numpy as np
from sqlalchemy import create_engine
# from multipledispatch import dispatch
from zipline.data.loader import load_market_data
from zipline.utils import tradingcalendar
@@ -31,8 +32,7 @@ from zipline.assets.asset_writer import (
AssetDBWriterFromDictionary,
AssetDBWriterFromDataFrame)
from zipline.errors import (
NoFurtherDataError,
UpdateAssetFinderTypeError,
NoFurtherDataError
)
@@ -157,51 +157,67 @@ class TradingEnvironment(object):
# stack.
return False
def update_asset_finder(self,
clear_metadata=False,
asset_finder=None,
asset_metadata=None,
identifiers=None):
def write_data(self,
engine=None,
equities_data={},
futures_data={},
exchanges_data={},
root_symbols_data={},
equities_identifiers=[],
futures_identifiers=[],
exchanges_identifiers=[],
root_symbols_identifiers=[]):
""" Write the supplied data to the database.
Parameters
----------
equities_data: dict
A dictionary of equity metadata
futures_data: dict
A dictionary of futures metadata
exchanges_data: dict
A dictionary of exchanges metadata
root_symbols_data: dict
A dictionary of root symbols metadata
equities_identifiers: list
A list of equities identifiers (sids or symbols)
futures_identifiers: list
A list of futures identifiers (sids or symbols)
exchanges_identifiers: list
A list of exchanges identifiers (ids or names)
root_symbols_identifiers: list
A list of root symbols identifiers (ids or symbols)
"""
Updates the AssetFinder using the provided asset metadata and
identifiers.
If clear_metadata is True, all metadata and assets held in the
asset_finder will be erased before new metadata is provided.
If asset_finder is provided, the existing asset_finder will be replaced
outright with the new asset_finder.
If asset_metadata is provided, the existing metadata will be cleared
and replaced with the provided metadata.
All identifiers will be inserted in the asset metadata if they are not
already present.
:param clear_metadata: A boolean
:param asset_finder: An AssetFinder object to replace the environment's
existing asset_finder
:param asset_metadata: A dict, DataFrame, or readable object
:param identifiers: A list of identifiers to be inserted
:return:
"""
if clear_metadata:
self.engine = create_engine('sqlite:///:memory:')
if engine:
self.engine = engine
if asset_finder is not None:
if not isinstance(asset_finder, AssetFinder):
raise UpdateAssetFinderTypeError(cls=asset_finder.__class__)
self.asset_finder = asset_finder
if (equities_data or futures_data or exchanges_data or
root_symbols_data):
self._write_data_dicts(equities_data, futures_data,
exchanges_data, root_symbols_data)
if asset_metadata is not None:
self.engine = create_engine('sqlite:///:memory:')
if isinstance(asset_metadata, dict):
asset_writer = AssetDBWriterFromDictionary(
equities=asset_metadata)
elif isinstance(asset_metadata, pd.DataFrame):
asset_writer = AssetDBWriterFromDataFrame(
equities=asset_metadata)
asset_writer.write_all(self.engine)
if (equities_identifiers or futures_identifiers or
exchanges_identifiers or root_symbols_identifiers):
self._write_data_lists(equities_identifiers,
futures_identifiers,
exchanges_identifiers,
root_symbols_identifiers)
if identifiers is not None:
asset_writer = AssetDBWriterFromList(equities=identifiers)
asset_writer.write_all(self.engine)
def _write_data_lists(self, equities=[], futures=[],
exchanges=[], root_symbols=[]):
AssetDBWriterFromList(equities, futures, exchanges, root_symbols)\
.write_all(self.engine)
def _write_data_dicts(self, equities={}, futures={},
exchanges={}, root_symbols={}):
AssetDBWriterFromDictionary(equities, futures, exchanges, root_symbols)\
.write_all(self.engine)
def _write_data_dataframes(self, equities, futures,
exchanges, root_symbols):
AssetDBWriterFromDataFrame(equities, futures, exchanges, root_symbols)\
.write_all(self.engine)
def normalize_date(self, test_date):
test_date = pd.Timestamp(test_date, tz='UTC')
+2 -2
View File
@@ -93,8 +93,8 @@ class RandomWalkSource(DataSource):
self.sd = sd
self.sids = self.start_prices.keys()
TradingEnvironment.instance().update_asset_finder(
identifiers=self.sids
TradingEnvironment.instance().write_data(
equities_identifiers=self.sids
)
self.open_and_closes = \
+2 -2
View File
@@ -136,7 +136,7 @@ class SpecificEquityTrades(object):
'sids',
set(event.sid for event in self.event_list)
)
env.update_asset_finder(identifiers=self.identifiers)
env.write_data(equities_identifiers=self.identifiers)
assets_by_identifier = {}
for identifier in self.identifiers:
assets_by_identifier[identifier] = env.asset_finder.\
@@ -160,7 +160,7 @@ class SpecificEquityTrades(object):
self.concurrent = kwargs.get('concurrent', False)
self.identifiers = kwargs.get('sids', [1, 2])
env.update_asset_finder(identifiers=self.identifiers)
env.write_data(equities_identifiers=self.identifiers)
assets_by_identifier = {}
for identifier in self.identifiers:
assets_by_identifier[identifier] = env.asset_finder.\
+2 -2
View File
@@ -119,7 +119,7 @@ def create_trade_history(sid, prices, amounts, interval, sim_params,
source_id="test_factory"):
trades = []
current = sim_params.first_open
trading.environment.update_asset_finder(identifiers=[sid])
trading.environment.write_data(equities_identifiers=[sid])
oneday = timedelta(days=1)
use_midnight = interval >= oneday
@@ -311,7 +311,7 @@ def create_test_df_source(sim_params=None, bars='daily'):
df = pd.DataFrame(x, index=index, columns=[0])
trading.environment.update_asset_finder(identifiers=[0])
trading.environment.write_data(equities_identifiers=[0])
return DataFrameSource(df), df