BLD: improved saving algo state

This commit is contained in:
Frederic Fortier
2017-12-23 21:42:04 -05:00
parent 67bd5c8f6a
commit e9714cfb32
9 changed files with 125 additions and 56 deletions
+37 -18
View File
@@ -42,6 +42,7 @@ from catalyst.exchange.simple_clock import SimpleClock
from catalyst.exchange.stats_utils import get_pretty_stats, stats_to_s3, \
stats_to_algo_folder
from catalyst.finance.execution import MarketOrder
from catalyst.finance.performance import PerformanceTracker
from catalyst.finance.performance.period import calc_period_stats
from catalyst.gens.tradesimulation import AlgorithmSimulator
from catalyst.utils.api_support import api_method
@@ -433,24 +434,37 @@ class ExchangeTradingAlgorithmLive(ExchangeTradingAlgorithmBase):
def _create_generator(self, sim_params):
if self.perf_tracker is None:
self.perf_tracker = get_algo_object(
algo_name=self.algo_namespace,
key='perf_tracker'
self.perf_tracker = PerformanceTracker(
sim_params=self.sim_params,
trading_calendar=self.trading_calendar,
env=self.trading_environment,
)
# Unpacking the perf_tracker and positions if available
perf = get_algo_object(
algo_name=self.algo_namespace,
key='perf_tracker',
)
if perf is not None:
positions = get_algo_object(
algo_name=self.algo_namespace,
key='positions',
)
self.perf_tracker.period_start = perf['period_start']
self.perf_tracker.position_tracker.positions = positions
# Call the simulation trading algorithm for side-effects:
# it creates the perf tracker
TradingAlgorithm._create_generator(self, sim_params)
self.trading_client = ExchangeAlgorithmExecutor(
self,
sim_params,
self.data_portal,
self.clock,
self._create_benchmark_source(),
self.restrictions,
algo=self,
sim_params=sim_params,
data_portal=self.data_portal,
clock=self.clock,
benchmark_source=self._create_benchmark_source(),
restrictions=self.restrictions,
universe_func=self._calculate_universe
)
return self.trading_client.transform()
def updated_portfolio(self):
@@ -658,14 +672,19 @@ class ExchangeTradingAlgorithmLive(ExchangeTradingAlgorithmBase):
log.warn('unable to calculate performance: {}'.format(e))
# TODO: pickle does not seem to work in python 3
try:
save_algo_object(
algo_name=self.algo_namespace,
key='perf_tracker',
obj=self.perf_tracker
)
except Exception as e:
log.warn('unable to save minute perfs to disk: {}'.format(e))
# try:
save_algo_object(
algo_name=self.algo_namespace,
key='perf_tracker',
obj=self.perf_tracker.to_dict(emission_type=self.data_frequency),
)
save_algo_object(
algo_name=self.algo_namespace,
key='positions',
obj=self.perf_tracker.position_tracker.positions,
)
# except Exception as e:
# log.warn('unable to save perf_tracker to disk: {}'.format(e))
self.current_day = data.current_dt.floor('1D')
@@ -1,16 +1,17 @@
from logbook import Logger
from catalyst.constants import LOG_LEVEL
from catalyst.errors import SidsNotFound
from catalyst.exchange.factory import find_exchanges
import pandas as pd
log = Logger('AssetFinderExchange', level=LOG_LEVEL)
log = Logger('ExchangeAssetFinder', level=LOG_LEVEL)
class AssetFinderExchange(object):
def __init__(self):
self._asset_cache = {}
class ExchangeAssetFinder(object):
def __init__(self, exchanges):
self.exchanges = exchanges
@property
def sids(self):
@@ -19,7 +20,33 @@ class AssetFinderExchange(object):
I don't think that we need this for live-trading.
Leaving the list empty.
"""
return list()
all_sids = []
for exchange_name in self.exchanges:
# This is what initializes each exchange at the beginning
# of an algo
exchange = self.exchanges[exchange_name]
exchange.init()
all_sids += [asset.sid for asset in exchange.assets]
sids = list(set(all_sids))
return sids
def retrieve_asset(self, sid, default_none=False):
    """
    Retrieve the first Asset found for a given sid.

    The exchanges in ``self.exchanges`` are scanned in iteration order
    and the first asset whose ``sid`` equals the requested one is
    returned.

    Parameters
    ----------
    sid
        Identifier compared against the ``sid`` attribute of every asset
        of every exchange.
    default_none : bool
        Not used by this implementation; kept so the signature stays
        compatible with callers. ``None`` is returned for an unknown sid
        whatever its value.

    Returns
    -------
    asset
        The first matching asset, or ``None`` if no exchange has an
        asset with this sid.
    """
    for exchange_name in self.exchanges:
        exchange = self.exchanges[exchange_name]

        # Use a dedicated loop variable: under Python 2 a list
        # comprehension variable leaks into the enclosing scope, so
        # reusing the result name there could leave it bound to an
        # unrelated asset when this exchange has no match.
        for candidate in exchange.assets:
            if candidate.sid == sid:
                return candidate

    return None
def retrieve_all(self, sids, default_none=False):
"""
@@ -44,12 +71,13 @@ class AssetFinderExchange(object):
SidsNotFound
When a requested sid is not found and default_none=False.
"""
# for sid in sids:
# if sid in self._asset_cache:
# log.debug('got asset from cache: {}'.format(sid))
# else:
# log.debug('fetching asset: {}'.format(sid))
return list()
assets = []
for exchange_name in self.exchanges:
exchange = self.exchanges[exchange_name]
xas = [asset for asset in exchange.assets if asset.sid in sids]
assets += xas
return assets
def lookup_symbol(self, symbol, exchange, data_frequency=None,
as_of_date=None, fuzzy=False):
@@ -88,18 +116,7 @@ class AssetFinderExchange(object):
"""
log.debug('looking up symbol: {} {}'.format(symbol, exchange.name))
if data_frequency is not None:
key = ','.join([exchange.name, symbol, data_frequency])
else:
key = ','.join([exchange.name, symbol])
if key in self._asset_cache:
return self._asset_cache[key]
else:
asset = exchange.get_asset(symbol, data_frequency)
self._asset_cache[key] = asset
return asset
return exchange.get_asset(symbol, data_frequency)
def lifetimes(self, dates, include_start_date):
"""
@@ -160,6 +177,6 @@ class AssetFinderExchange(object):
data.append(exists)
sids = [asset.sid for asset in exchange.assets]
df = pd.DataFrame(data, index=dates, columns=sids)
df = pd.DataFrame(data, index=dates, columns=exchange.assets)
return df
+21 -1
View File
@@ -11,12 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from logbook import Logger
from numpy import (
iinfo,
uint32,
)
from catalyst.constants import LOG_LEVEL
from catalyst.data.us_equity_pricing import BcolzDailyBarReader
from catalyst.exchange.factory import get_exchange
from catalyst.lib.adjusted_array import AdjustedArray
from catalyst.errors import NoFurtherDataError
from catalyst.pipeline.data import DataSet, Column
@@ -26,6 +29,8 @@ from catalyst.utils.numpy_utils import float64_dtype
UINT32_MAX = iinfo(uint32).max
log = Logger('ExchangePriceLoader', level=LOG_LEVEL)
class TradingPairPricing(DataSet):
"""
@@ -62,6 +67,7 @@ class ExchangePricingLoader(PipelineLoader):
'Invalid data frequency: {}'.format(data_frequency)
)
self.data_frequency = data_frequency
self.raw_price_loader = reader
self._columns = TradingPairPricing.columns
self._all_sessions = all_sessions
@@ -91,7 +97,21 @@ class ExchangePricingLoader(PipelineLoader):
self._all_sessions, dates[0], dates[-1], shift=1,
)
colnames = [c.name for c in columns]
raw_arrays = self.raw_price_loader.load_raw_arrays(
if len(assets) == 0:
raise ValueError(
'Pipeline cannot load data with eligible assets.'
)
exchange_names = []
for asset in assets:
if asset.exchange not in exchange_names:
exchange_names.append(asset.exchange)
exchange = get_exchange(exchange_names[0])
reader = exchange.bundle.get_reader(self.data_frequency)
raw_arrays = reader.load_raw_arrays(
colnames,
start_date,
end_date,
+10 -4
View File
@@ -309,7 +309,8 @@ def get_algo_object(algo_name, key, environ=None, rel_path=None):
return None
def save_algo_object(algo_name, key, obj, environ=None, rel_path=None):
def save_algo_object(algo_name, key, obj, environ=None, rel_path=None,
how='pickle'):
"""
Serialize and save an object by algo name and key.
@@ -328,10 +329,15 @@ def save_algo_object(algo_name, key, obj, environ=None, rel_path=None):
folder = os.path.join(folder, rel_path)
ensure_directory(folder)
filename = os.path.join(folder, key + '.p')
if how == 'json':
filename = os.path.join(folder, '{}.json'.format(key))
with open(filename, 'wt') as handle:
json.dump(obj, handle, indent=4, default=symbols_serial)
with open(filename, 'wb') as handle:
pickle.dump(obj, handle, protocol=pickle.HIGHEST_PROTOCOL)
else:
filename = os.path.join(folder, '{}.p'.format(key))
with open(filename, 'wb') as handle:
pickle.dump(obj, handle, protocol=pickle.HIGHEST_PROTOCOL)
def get_algo_df(algo_name, key, environ=None, rel_path=None):
+6
View File
@@ -10,10 +10,15 @@ from catalyst.exchange.exchange_utils import get_exchange_auth, \
get_exchange_folder, is_blacklist
log = Logger('factory', level=LOG_LEVEL)
exchange_cache = dict()
def get_exchange(exchange_name, base_currency=None, must_authenticate=False,
skip_init=False):
key = (exchange_name, base_currency)
if key in exchange_cache:
return exchange_cache[key]
exchange_auth = get_exchange_auth(exchange_name)
has_auth = (exchange_auth['key'] != '' and exchange_auth['secret'] != '')
@@ -31,6 +36,7 @@ def get_exchange(exchange_name, base_currency=None, must_authenticate=False,
secret=exchange_auth['secret'],
base_currency=base_currency,
)
exchange_cache[key] = exchange
if not skip_init:
exchange.init()
+3 -1
View File
@@ -7,6 +7,7 @@ from abc import (
)
from uuid import uuid4
import six
from six import (
iteritems,
with_metaclass,
@@ -485,8 +486,9 @@ class SimplePipelineEngine(PipelineEngine):
if isinstance(term, LoadableTerm):
term_key = loader_group_key(term)
# TODO: temp workaround
to_load = sorted(
loader_groups[term_key],
six.next(six.itervalues(loader_groups)),
key=lambda t: t.dataset
)
loader = get_loader(term)
+3 -4
View File
@@ -15,8 +15,6 @@ from catalyst.data.data_portal import DataPortal
from catalyst.exchange.exchange_pricing_loader import ExchangePricingLoader, \
TradingPairPricing
from catalyst.exchange.factory import get_exchange
from catalyst.pipeline import USEquityPricingLoader
from catalyst.pipeline.data import USEquityPricing
try:
from pygments import highlight
@@ -41,7 +39,7 @@ from catalyst.exchange.exchange_algorithm import (
)
from catalyst.exchange.exchange_data_portal import DataPortalExchangeLive, \
DataPortalExchangeBacktest
from catalyst.exchange.asset_finder_exchange import AssetFinderExchange
from catalyst.exchange.exchange_asset_finder import ExchangeAssetFinder
from catalyst.exchange.exchange_errors import (
ExchangeRequestError, ExchangeRequestErrorTooManyAttempts,
BaseCurrencyNotFoundError, NotEnoughCapitalError)
@@ -161,6 +159,7 @@ def _run(handle_data,
exchange_name=exchange_name,
base_currency=base_currency,
must_authenticate=(live and not simulate_orders),
skip_init=True,
)
open_calendar = get_calendar('OPEN')
@@ -176,7 +175,7 @@ def _run(handle_data,
exchange_tz='UTC',
asset_db_path=None # We don't need an asset db, we have exchanges
)
env.asset_finder = AssetFinderExchange()
env.asset_finder = ExchangeAssetFinder(exchanges=exchanges)
def choose_loader(column):
bound_cols = TradingPairPricing.columns
+2 -2
View File
@@ -2,7 +2,7 @@ import pandas as pd
from logbook import Logger
from catalyst import get_calendar
from catalyst.exchange.asset_finder_exchange import AssetFinderExchange
from catalyst.exchange.exchange_asset_finder import ExchangeAssetFinder
from catalyst.exchange.exchange_data_portal import (
DataPortalExchangeBacktest,
DataPortalExchangeLive
@@ -20,7 +20,7 @@ class TestExchangeDataPortal:
log.info('creating bitfinex exchange')
exchanges = get_exchanges(['bitfinex', 'bittrex', 'poloniex'])
open_calendar = get_calendar('OPEN')
asset_finder = AssetFinderExchange()
asset_finder = ExchangeAssetFinder()
self.data_portal_live = DataPortalExchangeLive(
exchanges=exchanges,
+2 -2
View File
@@ -5,7 +5,7 @@ from logbook import Logger
from pandas.util.testing import assert_frame_equal
from catalyst import get_calendar
from catalyst.exchange.asset_finder_exchange import AssetFinderExchange
from catalyst.exchange.exchange_asset_finder import ExchangeAssetFinder
from catalyst.exchange.exchange_data_portal import DataPortalExchangeBacktest
from catalyst.exchange.exchange_utils import get_candles_df
from catalyst.exchange.factory import get_exchange
@@ -24,7 +24,7 @@ class TestSuiteBundle:
@staticmethod
def get_data_portal(exchange_names):
open_calendar = get_calendar('OPEN')
asset_finder = AssetFinderExchange()
asset_finder = ExchangeAssetFinder()
data_portal = DataPortalExchangeBacktest(
exchange_names=exchange_names,