From e9714cfb32aac087ffbead47ae6b353d2b8f0653 Mon Sep 17 00:00:00 2001 From: Frederic Fortier Date: Sat, 23 Dec 2017 21:42:04 -0500 Subject: [PATCH] BLD: improved saving algo state --- catalyst/exchange/exchange_algorithm.py | 55 +++++++++++----- ...r_exchange.py => exchange_asset_finder.py} | 65 ++++++++++++------- catalyst/exchange/exchange_pricing_loader.py | 22 ++++++- catalyst/exchange/exchange_utils.py | 14 ++-- catalyst/exchange/factory.py | 6 ++ catalyst/pipeline/engine.py | 4 +- catalyst/utils/run_algo.py | 7 +- tests/exchange/test_data_portal.py | 4 +- tests/exchange/test_suite_bundle.py | 4 +- 9 files changed, 125 insertions(+), 56 deletions(-) rename catalyst/exchange/{asset_finder_exchange.py => exchange_asset_finder.py} (77%) diff --git a/catalyst/exchange/exchange_algorithm.py b/catalyst/exchange/exchange_algorithm.py index 0c602faf..f31cb7bd 100644 --- a/catalyst/exchange/exchange_algorithm.py +++ b/catalyst/exchange/exchange_algorithm.py @@ -42,6 +42,7 @@ from catalyst.exchange.simple_clock import SimpleClock from catalyst.exchange.stats_utils import get_pretty_stats, stats_to_s3, \ stats_to_algo_folder from catalyst.finance.execution import MarketOrder +from catalyst.finance.performance import PerformanceTracker from catalyst.finance.performance.period import calc_period_stats from catalyst.gens.tradesimulation import AlgorithmSimulator from catalyst.utils.api_support import api_method @@ -433,24 +434,37 @@ class ExchangeTradingAlgorithmLive(ExchangeTradingAlgorithmBase): def _create_generator(self, sim_params): if self.perf_tracker is None: - self.perf_tracker = get_algo_object( - algo_name=self.algo_namespace, - key='perf_tracker' + self.perf_tracker = PerformanceTracker( + sim_params=self.sim_params, + trading_calendar=self.trading_calendar, + env=self.trading_environment, ) + # Unpacking the perf_tracker and positions if available + perf = get_algo_object( + algo_name=self.algo_namespace, + key='perf_tracker', + ) + if perf is not None: + positions = get_algo_object( + algo_name=self.algo_namespace, + key='positions', + ) + self.perf_tracker.period_start = perf['period_start'] + self.perf_tracker.position_tracker.positions = positions + # Call the simulation trading algorithm for side-effects: # it creates the perf tracker TradingAlgorithm._create_generator(self, sim_params) self.trading_client = ExchangeAlgorithmExecutor( - self, - sim_params, - self.data_portal, - self.clock, - self._create_benchmark_source(), - self.restrictions, + algo=self, + sim_params=sim_params, + data_portal=self.data_portal, + clock=self.clock, + benchmark_source=self._create_benchmark_source(), + restrictions=self.restrictions, universe_func=self._calculate_universe ) - return self.trading_client.transform() def updated_portfolio(self): @@ -658,14 +672,19 @@ class ExchangeTradingAlgorithmLive(ExchangeTradingAlgorithmBase): log.warn('unable to calculate performance: {}'.format(e)) # TODO: pickle does not seem to work in python 3 - try: - save_algo_object( - algo_name=self.algo_namespace, - key='perf_tracker', - obj=self.perf_tracker - ) - except Exception as e: - log.warn('unable to save minute perfs to disk: {}'.format(e)) + # try: + save_algo_object( + algo_name=self.algo_namespace, + key='perf_tracker', + obj=self.perf_tracker.to_dict(emission_type=self.data_frequency), + ) + save_algo_object( + algo_name=self.algo_namespace, + key='positions', + obj=self.perf_tracker.position_tracker.positions, + ) + # except Exception as e: + # log.warn('unable to save perf_tracker to disk: {}'.format(e)) self.current_day = data.current_dt.floor('1D') diff --git a/catalyst/exchange/asset_finder_exchange.py b/catalyst/exchange/exchange_asset_finder.py similarity index 77% rename from catalyst/exchange/asset_finder_exchange.py rename to catalyst/exchange/exchange_asset_finder.py index 148aab2f..7ab00ef4 100644 --- a/catalyst/exchange/asset_finder_exchange.py +++ b/catalyst/exchange/exchange_asset_finder.py @@ -1,16 +1,17 @@ from logbook import Logger from catalyst.constants import LOG_LEVEL +from catalyst.errors import SidsNotFound from catalyst.exchange.factory import find_exchanges import pandas as pd -log = Logger('AssetFinderExchange', level=LOG_LEVEL) +log = Logger('ExchangeAssetFinder', level=LOG_LEVEL) -class AssetFinderExchange(object): - def __init__(self): - self._asset_cache = {} +class ExchangeAssetFinder(object): + def __init__(self, exchanges): + self.exchanges = exchanges @property def sids(self): @@ -19,7 +20,33 @@ class AssetFinderExchange(object): I don't think that we need this for live-trading. Leaving the list empty. """ - return list() + all_sids = [] + for exchange_name in self.exchanges: + # This is what initializes each exchanges at the beginning + # of an algo + exchange = self.exchanges[exchange_name] + exchange.init() + + all_sids += [asset.sid for asset in exchange.assets] + + sids = list(set(all_sids)) + return sids + + def retrieve_asset(self, sid, default_none=False): + """ + Retrieve the first Asset found for a given sid. + """ + asset = None + for exchange_name in self.exchanges: + if asset is not None: + break + + exchange = self.exchanges[exchange_name] + assets = [asset for asset in exchange.assets if asset.sid == sid] + if assets: + asset = assets[0] + + return asset def retrieve_all(self, sids, default_none=False): """ @@ -44,12 +71,13 @@ class AssetFinderExchange(object): SidsNotFound When a requested sid is not found and default_none=False. """ - # for sid in sids: - # if sid in self._asset_cache: - # log.debug('got asset from cache: {}'.format(sid)) - # else: - # log.debug('fetching asset: {}'.format(sid)) - return list() + assets = [] + for exchange_name in self.exchanges: + exchange = self.exchanges[exchange_name] + xas = [asset for asset in exchange.assets if asset.sid in sids] + assets += xas + + return assets def lookup_symbol(self, symbol, exchange, data_frequency=None, as_of_date=None, fuzzy=False): @@ -88,18 +116,7 @@ class AssetFinderExchange(object): """ log.debug('looking up symbol: {} {}'.format(symbol, exchange.name)) - if data_frequency is not None: - key = ','.join([exchange.name, symbol, data_frequency]) - - else: - key = ','.join([exchange.name, symbol]) - - if key in self._asset_cache: - return self._asset_cache[key] - else: - asset = exchange.get_asset(symbol, data_frequency) - self._asset_cache[key] = asset - return asset + return exchange.get_asset(symbol, data_frequency) def lifetimes(self, dates, include_start_date): """ @@ -160,6 +177,6 @@ class AssetFinderExchange(object): data.append(exists) sids = [asset.sid for asset in exchange.assets] - df = pd.DataFrame(data, index=dates, columns=sids) + df = pd.DataFrame(data, index=dates, columns=exchange.assets) return df diff --git a/catalyst/exchange/exchange_pricing_loader.py b/catalyst/exchange/exchange_pricing_loader.py index 3bf106e9..ed89c3c4 100644 --- a/catalyst/exchange/exchange_pricing_loader.py +++ b/catalyst/exchange/exchange_pricing_loader.py @@ -11,12 +11,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from logbook import Logger from numpy import ( iinfo, uint32, ) +from catalyst.constants import LOG_LEVEL from catalyst.data.us_equity_pricing import BcolzDailyBarReader +from catalyst.exchange.factory import get_exchange from catalyst.lib.adjusted_array import AdjustedArray from catalyst.errors import NoFurtherDataError from catalyst.pipeline.data import DataSet, Column @@ -26,6 +29,8 @@ from catalyst.utils.numpy_utils import float64_dtype UINT32_MAX = iinfo(uint32).max +log = Logger('ExchangePriceLoader', level=LOG_LEVEL) + class TradingPairPricing(DataSet): """ @@ -62,6 +67,7 @@ class ExchangePricingLoader(PipelineLoader): 'Invalid data frequency: {}'.format(data_frequency) ) + self.data_frequency = data_frequency self.raw_price_loader = reader self._columns = TradingPairPricing.columns self._all_sessions = all_sessions @@ -91,7 +97,21 @@ class ExchangePricingLoader(PipelineLoader): self._all_sessions, dates[0], dates[-1], shift=1, ) colnames = [c.name for c in columns] - raw_arrays = self.raw_price_loader.load_raw_arrays( + + if len(assets) == 0: + raise ValueError( + 'Pipeline cannot load data with eligible assets.' + ) + + exchange_names = [] + for asset in assets: + if asset.exchange not in exchange_names: + exchange_names.append(asset.exchange) + + exchange = get_exchange(exchange_names[0]) + reader = exchange.bundle.get_reader(self.data_frequency) + + raw_arrays = reader.load_raw_arrays( colnames, start_date, end_date, diff --git a/catalyst/exchange/exchange_utils.py b/catalyst/exchange/exchange_utils.py index a3f490f3..c73a1b35 100644 --- a/catalyst/exchange/exchange_utils.py +++ b/catalyst/exchange/exchange_utils.py @@ -309,7 +309,8 @@ def get_algo_object(algo_name, key, environ=None, rel_path=None): return None -def save_algo_object(algo_name, key, obj, environ=None, rel_path=None): +def save_algo_object(algo_name, key, obj, environ=None, rel_path=None, + how='pickle'): """ Serialize and save an object by algo name and key. @@ -328,10 +329,15 @@ def save_algo_object(algo_name, key, obj, environ=None, rel_path=None): folder = os.path.join(folder, rel_path) ensure_directory(folder) - filename = os.path.join(folder, key + '.p') + if how == 'json': + filename = os.path.join(folder, '{}.json'.format(key)) + with open(filename, 'wt') as handle: + json.dump(obj, handle, indent=4, default=symbols_serial) - with open(filename, 'wb') as handle: - pickle.dump(obj, handle, protocol=pickle.HIGHEST_PROTOCOL) + else: + filename = os.path.join(folder, '{}.p'.format(key)) + with open(filename, 'wb') as handle: + pickle.dump(obj, handle, protocol=pickle.HIGHEST_PROTOCOL) def get_algo_df(algo_name, key, environ=None, rel_path=None): diff --git a/catalyst/exchange/factory.py b/catalyst/exchange/factory.py index 7bc72f7a..8f8ecf34 100644 --- a/catalyst/exchange/factory.py +++ b/catalyst/exchange/factory.py @@ -10,10 +10,15 @@ from catalyst.exchange.exchange_utils import get_exchange_auth, \ get_exchange_folder, is_blacklist log = Logger('factory', level=LOG_LEVEL) +exchange_cache = dict() def get_exchange(exchange_name, base_currency=None, must_authenticate=False, skip_init=False): + key = (exchange_name, base_currency) + if key in exchange_cache: + return exchange_cache[key] + exchange_auth = get_exchange_auth(exchange_name) has_auth = (exchange_auth['key'] != '' and exchange_auth['secret'] != '') @@ -31,6 +36,7 @@ def get_exchange(exchange_name, base_currency=None, must_authenticate=False, secret=exchange_auth['secret'], base_currency=base_currency, ) + exchange_cache[key] = exchange if not skip_init: exchange.init() diff --git a/catalyst/pipeline/engine.py b/catalyst/pipeline/engine.py index 4a6538fc..4d7738fb 100644 --- a/catalyst/pipeline/engine.py +++ b/catalyst/pipeline/engine.py @@ -7,6 +7,7 @@ from abc import ( ) from uuid import uuid4 +import six from six import ( iteritems, with_metaclass, @@ -485,8 +486,9 @@ class SimplePipelineEngine(PipelineEngine): if isinstance(term, LoadableTerm): term_key = loader_group_key(term) + # TODO: temp workaround to_load = sorted( - loader_groups[term_key], + six.next(six.itervalues(loader_groups)), key=lambda t: t.dataset ) loader = get_loader(term) diff --git a/catalyst/utils/run_algo.py b/catalyst/utils/run_algo.py index 06c2e69b..4a5ba74e 100644 --- a/catalyst/utils/run_algo.py +++ b/catalyst/utils/run_algo.py @@ -15,8 +15,6 @@ from catalyst.data.data_portal import DataPortal from catalyst.exchange.exchange_pricing_loader import ExchangePricingLoader, \ TradingPairPricing from catalyst.exchange.factory import get_exchange -from catalyst.pipeline import USEquityPricingLoader -from catalyst.pipeline.data import USEquityPricing try: from pygments import highlight @@ -41,7 +39,7 @@ from catalyst.exchange.exchange_algorithm import ( ) from catalyst.exchange.exchange_data_portal import DataPortalExchangeLive, \ DataPortalExchangeBacktest -from catalyst.exchange.asset_finder_exchange import AssetFinderExchange +from catalyst.exchange.exchange_asset_finder import ExchangeAssetFinder from catalyst.exchange.exchange_errors import ( ExchangeRequestError, ExchangeRequestErrorTooManyAttempts, BaseCurrencyNotFoundError, NotEnoughCapitalError) @@ -161,6 +159,7 @@ def _run(handle_data, exchange_name=exchange_name, base_currency=base_currency, must_authenticate=(live and not simulate_orders), + skip_init=True, ) open_calendar = get_calendar('OPEN') @@ -176,7 +175,7 @@ def _run(handle_data, exchange_tz='UTC', asset_db_path=None # We don't need an asset db, we have exchanges ) - env.asset_finder = AssetFinderExchange() + env.asset_finder = ExchangeAssetFinder(exchanges=exchanges) def choose_loader(column): bound_cols = TradingPairPricing.columns diff --git a/tests/exchange/test_data_portal.py b/tests/exchange/test_data_portal.py index 29ef4d46..7febfcd5 100644 --- a/tests/exchange/test_data_portal.py +++ b/tests/exchange/test_data_portal.py @@ -2,7 +2,7 @@ import pandas as pd from logbook import Logger from catalyst import get_calendar -from catalyst.exchange.asset_finder_exchange import AssetFinderExchange +from catalyst.exchange.exchange_asset_finder import ExchangeAssetFinder from catalyst.exchange.exchange_data_portal import ( DataPortalExchangeBacktest, DataPortalExchangeLive @@ -20,7 +20,7 @@ class TestExchangeDataPortal: log.info('creating bitfinex exchange') exchanges = get_exchanges(['bitfinex', 'bittrex', 'poloniex']) open_calendar = get_calendar('OPEN') - asset_finder = AssetFinderExchange() + asset_finder = ExchangeAssetFinder() self.data_portal_live = DataPortalExchangeLive( exchanges=exchanges, diff --git a/tests/exchange/test_suite_bundle.py b/tests/exchange/test_suite_bundle.py index 5651d7f6..0a3a8796 100644 --- a/tests/exchange/test_suite_bundle.py +++ b/tests/exchange/test_suite_bundle.py @@ -5,7 +5,7 @@ from logbook import Logger from pandas.util.testing import assert_frame_equal from catalyst import get_calendar -from catalyst.exchange.asset_finder_exchange import AssetFinderExchange +from catalyst.exchange.exchange_asset_finder import ExchangeAssetFinder from catalyst.exchange.exchange_data_portal import DataPortalExchangeBacktest from catalyst.exchange.exchange_utils import get_candles_df from catalyst.exchange.factory import get_exchange @@ -24,7 +24,7 @@ class TestSuiteBundle: @staticmethod def get_data_portal(exchange_names): open_calendar = get_calendar('OPEN') - asset_finder = AssetFinderExchange() + asset_finder = ExchangeAssetFinder() data_portal = DataPortalExchangeBacktest( exchange_names=exchange_names,