From 501fd58fdfa6f05d512a735830791fbcff6cef36 Mon Sep 17 00:00:00 2001 From: Stewart Douglas Date: Sat, 8 Aug 2015 23:10:40 -0400 Subject: [PATCH] ENH: Replace update_asset_finder with write_data The write_data methods invokes the relevant AssetDBWriter subclass to write data to the database. update_asset_finder is no longer a relevant method since the AssetFinder is strictly a reader class. --- tests/test_blotter.py | 2 +- tests/test_perf_tracking.py | 12 ++-- tests/test_security_list.py | 5 +- zipline/algorithm.py | 8 +-- zipline/errors.py | 10 ---- zipline/finance/trading.py | 100 +++++++++++++++++++-------------- zipline/sources/simulated.py | 4 +- zipline/sources/test_source.py | 4 +- zipline/utils/factory.py | 4 +- 9 files changed, 76 insertions(+), 73 deletions(-) diff --git a/tests/test_blotter.py b/tests/test_blotter.py index 71c3d144..89226339 100644 --- a/tests/test_blotter.py +++ b/tests/test_blotter.py @@ -38,7 +38,7 @@ class BlotterTestCase(TestCase): @with_environment() def setUp(self, env=None): setup_logger(self) - env.update_asset_finder(identifiers=[24]) + env.write_data(equities_identifiers=[24]) def tearDown(self): teardown_logger(self) diff --git a/tests/test_perf_tracking.py b/tests/test_perf_tracking.py index a1f3b747..c04dc57e 100644 --- a/tests/test_perf_tracking.py +++ b/tests/test_perf_tracking.py @@ -1894,7 +1894,7 @@ class TestPerformanceTracker(unittest.TestCase): foosid = 1 barsid = 2 - env.update_asset_finder(identifiers=[foosid, barsid]) + env.write_data(equities_identifiers=[foosid, barsid]) foo_event_1 = factory.create_trade(foosid, 10.0, 20, start_dt) order_event_1 = Order(sid=foo_event_1.sid, @@ -1989,7 +1989,7 @@ class TestPerformanceTracker(unittest.TestCase): @with_environment() def test_close_position_event(self, env=None): - env.update_asset_finder(identifiers=[1, 2]) + env.write_data(equities_identifiers=[1, 2]) pt = perf.PositionTracker() dt = pd.Timestamp("1984/03/06 3:00PM") pos1 = perf.Position(1, amount=np.float64(120.0), @@ -2133,9 +2133,7 @@ class TestPositionTracker(unittest.TestCase): 2: {'asset_type': 'future', 'contract_multiplier': 1000}} asset_finder = AssetFinder() - env.update_asset_finder( - asset_finder=asset_finder, - asset_metadata=metadata) + env.write_data(equities_data=metadata) pt = perf.PositionTracker() dt = pd.Timestamp("1984/03/06 3:00PM") pos1 = perf.Position(1, amount=np.float64(100.0), @@ -2163,7 +2161,7 @@ class TestPositionTracker(unittest.TestCase): 'contract_multiplier': 1000}, 4: {'asset_type': 'future', 'contract_multiplier': 1000}} - env.update_asset_finder(asset_metadata=metadata) + env.write_data(equities_data=metadata) pt = perf.PositionTracker() dt = pd.Timestamp("1984/03/06 3:00PM") pos1 = perf.Position(1, amount=np.float64(10.0), @@ -2197,7 +2195,7 @@ class TestPositionTracker(unittest.TestCase): metadata = {1: {'asset_type': 'equity'}, 2: {'asset_type': 'future', 'contract_multiplier': 1000}} - env.update_asset_finder(asset_metadata=metadata) + env.write_data(equities_data=metadata) pt = perf.PositionTracker() dt = pd.Timestamp("1984/03/06 3:00PM") pos1 = perf.Position(1, amount=np.float64(120.0), diff --git a/tests/test_security_list.py b/tests/test_security_list.py index a72ebff3..00b6a08a 100644 --- a/tests/test_security_list.py +++ b/tests/test_security_list.py @@ -66,9 +66,8 @@ class SecurityListTestCase(TestCase): self.trading_day_before_first_kd = datetime( 2015, 1, 23, 0, 0, tzinfo=pytz.utc) - env.update_asset_finder( - clear_metadata=True, - identifiers=["BZQ", "URTY", "JFT", "AAPL", "GOOG"] + env.write_data( + equities_identifiers=["BZQ", "URTY", "JFT", "AAPL", "GOOG"] ) setup_logger(self) diff --git a/zipline/algorithm.py b/zipline/algorithm.py index 39389b30..cf4bb5c9 100644 --- a/zipline/algorithm.py +++ b/zipline/algorithm.py @@ -205,11 +205,11 @@ class TradingAlgorithm(object): # Update the TradingEnvironment with the provided asset metadata self.trading_environment = kwargs.pop('env', TradingEnvironment.instance()) - self.trading_environment.update_asset_finder( - asset_finder=kwargs.pop('asset_finder', None), - asset_metadata=kwargs.pop('asset_metadata', None), - identifiers=kwargs.pop('identifiers', None) + self.trading_environment.write_data( + equities_data=kwargs.pop('asset_metadata', None), + equities_identifiers=kwargs.pop('identifiers', None), ) + # Pull in the environment's new AssetFinder for quick reference self.asset_finder = self.trading_environment.asset_finder self.init_engine(kwargs.pop('ffc_loader', None)) diff --git a/zipline/errors.py b/zipline/errors.py index 3e9a9d60..b8941934 100644 --- a/zipline/errors.py +++ b/zipline/errors.py @@ -243,16 +243,6 @@ AssetMetaData contained an invalid Asset type: '{asset_type}'. """.strip() -class UpdateAssetFinderTypeError(ZiplineError): - """ - Raised when TradingEnvironment.update_asset_finder() gets an asset_finder - arg that is not of AssetFinder class. - """ - msg = """ -TradingEnvironment can not set asset_finder to object of class {cls}. -""".strip() - - class ConsumeAssetMetaDataError(ZiplineError): """ Raised when AssetFinder.consume() is called on an invalid object. diff --git a/zipline/finance/trading.py b/zipline/finance/trading.py index 3b0d4af7..e6e886d3 100644 --- a/zipline/finance/trading.py +++ b/zipline/finance/trading.py @@ -22,6 +22,7 @@ import sqlite3 import pandas as pd import numpy as np from sqlalchemy import create_engine +# from multipledispatch import dispatch from zipline.data.loader import load_market_data from zipline.utils import tradingcalendar @@ -31,8 +32,7 @@ from zipline.assets.asset_writer import ( AssetDBWriterFromDictionary, AssetDBWriterFromDataFrame) from zipline.errors import ( - NoFurtherDataError, - UpdateAssetFinderTypeError, + NoFurtherDataError ) @@ -157,51 +157,67 @@ class TradingEnvironment(object): # stack. return False - def update_asset_finder(self, - clear_metadata=False, - asset_finder=None, - asset_metadata=None, - identifiers=None): + def write_data(self, + engine=None, + equities_data={}, + futures_data={}, + exchanges_data={}, + root_symbols_data={}, + equities_identifiers=[], + futures_identifiers=[], + exchanges_identifiers=[], + root_symbols_identifiers=[]): + """ Write the supplied data to the database. + + Parameters + ---------- + equities_data: dict + A dictionary of equity metadata + futures_data: dict + A dictionary of futures metadata + exchanges_data: dict + A dictionary of exchanges metadata + root_symbols_data: dict + A dictionary of root symbols metadata + equities_identifiers: list + A list of equities identifiers (sids or symbols) + futures_identifiers: list + A list of futures identifiers (sids or symbols) + exchanges_identifiers: list + A list of exchanges identifiers (ids or names) + root_symbols_identifiers: list + A list of root symbols identifiers (ids or symbols) """ - Updates the AssetFinder using the provided asset metadata and - identifiers. - If clear_metadata is True, all metadata and assets held in the - asset_finder will be erased before new metadata is provided. - If asset_finder is provided, the existing asset_finder will be replaced - outright with the new asset_finder. - If asset_metadata is provided, the existing metadata will be cleared - and replaced with the provided metadata. - All identifiers will be inserted in the asset metadata if they are not - already present. - :param clear_metadata: A boolean - :param asset_finder: An AssetFinder object to replace the environment's - existing asset_finder - :param asset_metadata: A dict, DataFrame, or readable object - :param identifiers: A list of identifiers to be inserted - :return: - """ - if clear_metadata: - self.engine = create_engine('sqlite:///:memory:') + if engine: + self.engine = engine - if asset_finder is not None: - if not isinstance(asset_finder, AssetFinder): - raise UpdateAssetFinderTypeError(cls=asset_finder.__class__) - self.asset_finder = asset_finder + if (equities_data or futures_data or exchanges_data or + root_symbols_data): + self._write_data_dicts(equities_data, futures_data, + exchanges_data, root_symbols_data) - if asset_metadata is not None: - self.engine = create_engine('sqlite:///:memory:') - if isinstance(asset_metadata, dict): - asset_writer = AssetDBWriterFromDictionary( - equities=asset_metadata) - elif isinstance(asset_metadata, pd.DataFrame): - asset_writer = AssetDBWriterFromDataFrame( - equities=asset_metadata) - asset_writer.write_all(self.engine) + if (equities_identifiers or futures_identifiers or + exchanges_identifiers or root_symbols_identifiers): + self._write_data_lists(equities_identifiers, + futures_identifiers, + exchanges_identifiers, + root_symbols_identifiers) - if identifiers is not None: - asset_writer = AssetDBWriterFromList(equities=identifiers) - asset_writer.write_all(self.engine) + def _write_data_lists(self, equities=[], futures=[], + exchanges=[], root_symbols=[]): + AssetDBWriterFromList(equities, futures, exchanges, root_symbols)\ + .write_all(self.engine) + + def _write_data_dicts(self, equities={}, futures={}, + exchanges={}, root_symbols={}): + AssetDBWriterFromDictionary(equities, futures, exchanges, root_symbols)\ + .write_all(self.engine) + + def _write_data_dataframes(self, equities, futures, + exchanges, root_symbols): + AssetDBWriterFromDataFrame(equities, futures, exchanges, root_symbols)\ + .write_all(self.engine) def normalize_date(self, test_date): test_date = pd.Timestamp(test_date, tz='UTC') diff --git a/zipline/sources/simulated.py b/zipline/sources/simulated.py index d0a9c63d..66c4f679 100644 --- a/zipline/sources/simulated.py +++ b/zipline/sources/simulated.py @@ -93,8 +93,8 @@ class RandomWalkSource(DataSource): self.sd = sd self.sids = self.start_prices.keys() - TradingEnvironment.instance().update_asset_finder( - identifiers=self.sids + TradingEnvironment.instance().write_data( + equities_identifiers=self.sids ) self.open_and_closes = \ diff --git a/zipline/sources/test_source.py b/zipline/sources/test_source.py index f91b42e1..512309cd 100644 --- a/zipline/sources/test_source.py +++ b/zipline/sources/test_source.py @@ -136,7 +136,7 @@ class SpecificEquityTrades(object): 'sids', set(event.sid for event in self.event_list) ) - env.update_asset_finder(identifiers=self.identifiers) + env.write_data(equities_identifiers=self.identifiers) assets_by_identifier = {} for identifier in self.identifiers: assets_by_identifier[identifier] = env.asset_finder.\ @@ -160,7 +160,7 @@ class SpecificEquityTrades(object): self.concurrent = kwargs.get('concurrent', False) self.identifiers = kwargs.get('sids', [1, 2]) - env.update_asset_finder(identifiers=self.identifiers) + env.write_data(equities_identifiers=self.identifiers) assets_by_identifier = {} for identifier in self.identifiers: assets_by_identifier[identifier] = env.asset_finder.\ diff --git a/zipline/utils/factory.py b/zipline/utils/factory.py index 915efe6f..17fca76e 100644 --- a/zipline/utils/factory.py +++ b/zipline/utils/factory.py @@ -119,7 +119,7 @@ def create_trade_history(sid, prices, amounts, interval, sim_params, source_id="test_factory"): trades = [] current = sim_params.first_open - trading.environment.update_asset_finder(identifiers=[sid]) + trading.environment.write_data(equities_identifiers=[sid]) oneday = timedelta(days=1) use_midnight = interval >= oneday @@ -311,7 +311,7 @@ def create_test_df_source(sim_params=None, bars='daily'): df = pd.DataFrame(x, index=index, columns=[0]) - trading.environment.update_asset_finder(identifiers=[0]) + trading.environment.write_data(equities_identifiers=[0]) return DataFrameSource(df), df