mirror of
https://github.com/wassname/catalyst.git
synced 2026-07-01 09:59:46 +08:00
BLD: improved saving algo state
This commit is contained in:
@@ -42,6 +42,7 @@ from catalyst.exchange.simple_clock import SimpleClock
|
||||
from catalyst.exchange.stats_utils import get_pretty_stats, stats_to_s3, \
|
||||
stats_to_algo_folder
|
||||
from catalyst.finance.execution import MarketOrder
|
||||
from catalyst.finance.performance import PerformanceTracker
|
||||
from catalyst.finance.performance.period import calc_period_stats
|
||||
from catalyst.gens.tradesimulation import AlgorithmSimulator
|
||||
from catalyst.utils.api_support import api_method
|
||||
@@ -433,24 +434,37 @@ class ExchangeTradingAlgorithmLive(ExchangeTradingAlgorithmBase):
|
||||
|
||||
def _create_generator(self, sim_params):
|
||||
if self.perf_tracker is None:
|
||||
self.perf_tracker = get_algo_object(
|
||||
algo_name=self.algo_namespace,
|
||||
key='perf_tracker'
|
||||
self.perf_tracker = PerformanceTracker(
|
||||
sim_params=self.sim_params,
|
||||
trading_calendar=self.trading_calendar,
|
||||
env=self.trading_environment,
|
||||
)
|
||||
|
||||
# Unpacking the perf_tracker and positions if available
|
||||
perf = get_algo_object(
|
||||
algo_name=self.algo_namespace,
|
||||
key='perf_tracker',
|
||||
)
|
||||
if perf is not None:
|
||||
positions = get_algo_object(
|
||||
algo_name=self.algo_namespace,
|
||||
key='positions',
|
||||
)
|
||||
self.perf_tracker.period_start = perf['period_start']
|
||||
self.perf_tracker.position_tracker.positions = positions
|
||||
|
||||
# Call the simulation trading algorithm for side-effects:
|
||||
# it creates the perf tracker
|
||||
TradingAlgorithm._create_generator(self, sim_params)
|
||||
self.trading_client = ExchangeAlgorithmExecutor(
|
||||
self,
|
||||
sim_params,
|
||||
self.data_portal,
|
||||
self.clock,
|
||||
self._create_benchmark_source(),
|
||||
self.restrictions,
|
||||
algo=self,
|
||||
sim_params=sim_params,
|
||||
data_portal=self.data_portal,
|
||||
clock=self.clock,
|
||||
benchmark_source=self._create_benchmark_source(),
|
||||
restrictions=self.restrictions,
|
||||
universe_func=self._calculate_universe
|
||||
)
|
||||
|
||||
return self.trading_client.transform()
|
||||
|
||||
def updated_portfolio(self):
|
||||
@@ -658,14 +672,19 @@ class ExchangeTradingAlgorithmLive(ExchangeTradingAlgorithmBase):
|
||||
log.warn('unable to calculate performance: {}'.format(e))
|
||||
|
||||
# TODO: pickle does not seem to work in python 3
|
||||
try:
|
||||
save_algo_object(
|
||||
algo_name=self.algo_namespace,
|
||||
key='perf_tracker',
|
||||
obj=self.perf_tracker
|
||||
)
|
||||
except Exception as e:
|
||||
log.warn('unable to save minute perfs to disk: {}'.format(e))
|
||||
# try:
|
||||
save_algo_object(
|
||||
algo_name=self.algo_namespace,
|
||||
key='perf_tracker',
|
||||
obj=self.perf_tracker.to_dict(emission_type=self.data_frequency),
|
||||
)
|
||||
save_algo_object(
|
||||
algo_name=self.algo_namespace,
|
||||
key='positions',
|
||||
obj=self.perf_tracker.position_tracker.positions,
|
||||
)
|
||||
# except Exception as e:
|
||||
# log.warn('unable to save perf_tracker to disk: {}'.format(e))
|
||||
|
||||
self.current_day = data.current_dt.floor('1D')
|
||||
|
||||
|
||||
+41
-24
@@ -1,16 +1,17 @@
|
||||
from logbook import Logger
|
||||
|
||||
from catalyst.constants import LOG_LEVEL
|
||||
from catalyst.errors import SidsNotFound
|
||||
from catalyst.exchange.factory import find_exchanges
|
||||
|
||||
import pandas as pd
|
||||
|
||||
log = Logger('AssetFinderExchange', level=LOG_LEVEL)
|
||||
log = Logger('ExchangeAssetFinder', level=LOG_LEVEL)
|
||||
|
||||
|
||||
class AssetFinderExchange(object):
|
||||
def __init__(self):
|
||||
self._asset_cache = {}
|
||||
class ExchangeAssetFinder(object):
|
||||
def __init__(self, exchanges):
|
||||
self.exchanges = exchanges
|
||||
|
||||
@property
|
||||
def sids(self):
|
||||
@@ -19,7 +20,33 @@ class AssetFinderExchange(object):
|
||||
I don't think that we need this for live-trading.
|
||||
Leaving the list empty.
|
||||
"""
|
||||
return list()
|
||||
all_sids = []
|
||||
for exchange_name in self.exchanges:
|
||||
# This is what initializes each exchanges at the beginning
|
||||
# of an algo
|
||||
exchange = self.exchanges[exchange_name]
|
||||
exchange.init()
|
||||
|
||||
all_sids += [asset.sid for asset in exchange.assets]
|
||||
|
||||
sids = list(set(all_sids))
|
||||
return sids
|
||||
|
||||
def retrieve_asset(self, sid, default_none=False):
|
||||
"""
|
||||
Retrieve the first Asset found for a given sid.
|
||||
"""
|
||||
asset = None
|
||||
for exchange_name in self.exchanges:
|
||||
if asset is not None:
|
||||
break
|
||||
|
||||
exchange = self.exchanges[exchange_name]
|
||||
assets = [asset for asset in exchange.assets if asset.sid == sid]
|
||||
if assets:
|
||||
asset = assets[0]
|
||||
|
||||
return asset
|
||||
|
||||
def retrieve_all(self, sids, default_none=False):
|
||||
"""
|
||||
@@ -44,12 +71,13 @@ class AssetFinderExchange(object):
|
||||
SidsNotFound
|
||||
When a requested sid is not found and default_none=False.
|
||||
"""
|
||||
# for sid in sids:
|
||||
# if sid in self._asset_cache:
|
||||
# log.debug('got asset from cache: {}'.format(sid))
|
||||
# else:
|
||||
# log.debug('fetching asset: {}'.format(sid))
|
||||
return list()
|
||||
assets = []
|
||||
for exchange_name in self.exchanges:
|
||||
exchange = self.exchanges[exchange_name]
|
||||
xas = [asset for asset in exchange.assets if asset.sid in sids]
|
||||
assets += xas
|
||||
|
||||
return assets
|
||||
|
||||
def lookup_symbol(self, symbol, exchange, data_frequency=None,
|
||||
as_of_date=None, fuzzy=False):
|
||||
@@ -88,18 +116,7 @@ class AssetFinderExchange(object):
|
||||
"""
|
||||
log.debug('looking up symbol: {} {}'.format(symbol, exchange.name))
|
||||
|
||||
if data_frequency is not None:
|
||||
key = ','.join([exchange.name, symbol, data_frequency])
|
||||
|
||||
else:
|
||||
key = ','.join([exchange.name, symbol])
|
||||
|
||||
if key in self._asset_cache:
|
||||
return self._asset_cache[key]
|
||||
else:
|
||||
asset = exchange.get_asset(symbol, data_frequency)
|
||||
self._asset_cache[key] = asset
|
||||
return asset
|
||||
return exchange.get_asset(symbol, data_frequency)
|
||||
|
||||
def lifetimes(self, dates, include_start_date):
|
||||
"""
|
||||
@@ -160,6 +177,6 @@ class AssetFinderExchange(object):
|
||||
data.append(exists)
|
||||
|
||||
sids = [asset.sid for asset in exchange.assets]
|
||||
df = pd.DataFrame(data, index=dates, columns=sids)
|
||||
df = pd.DataFrame(data, index=dates, columns=exchange.assets)
|
||||
|
||||
return df
|
||||
@@ -11,12 +11,15 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from logbook import Logger
|
||||
from numpy import (
|
||||
iinfo,
|
||||
uint32,
|
||||
)
|
||||
|
||||
from catalyst.constants import LOG_LEVEL
|
||||
from catalyst.data.us_equity_pricing import BcolzDailyBarReader
|
||||
from catalyst.exchange.factory import get_exchange
|
||||
from catalyst.lib.adjusted_array import AdjustedArray
|
||||
from catalyst.errors import NoFurtherDataError
|
||||
from catalyst.pipeline.data import DataSet, Column
|
||||
@@ -26,6 +29,8 @@ from catalyst.utils.numpy_utils import float64_dtype
|
||||
|
||||
UINT32_MAX = iinfo(uint32).max
|
||||
|
||||
log = Logger('ExchangePriceLoader', level=LOG_LEVEL)
|
||||
|
||||
|
||||
class TradingPairPricing(DataSet):
|
||||
"""
|
||||
@@ -62,6 +67,7 @@ class ExchangePricingLoader(PipelineLoader):
|
||||
'Invalid data frequency: {}'.format(data_frequency)
|
||||
)
|
||||
|
||||
self.data_frequency = data_frequency
|
||||
self.raw_price_loader = reader
|
||||
self._columns = TradingPairPricing.columns
|
||||
self._all_sessions = all_sessions
|
||||
@@ -91,7 +97,21 @@ class ExchangePricingLoader(PipelineLoader):
|
||||
self._all_sessions, dates[0], dates[-1], shift=1,
|
||||
)
|
||||
colnames = [c.name for c in columns]
|
||||
raw_arrays = self.raw_price_loader.load_raw_arrays(
|
||||
|
||||
if len(assets) == 0:
|
||||
raise ValueError(
|
||||
'Pipeline cannot load data with eligible assets.'
|
||||
)
|
||||
|
||||
exchange_names = []
|
||||
for asset in assets:
|
||||
if asset.exchange not in exchange_names:
|
||||
exchange_names.append(asset.exchange)
|
||||
|
||||
exchange = get_exchange(exchange_names[0])
|
||||
reader = exchange.bundle.get_reader(self.data_frequency)
|
||||
|
||||
raw_arrays = reader.load_raw_arrays(
|
||||
colnames,
|
||||
start_date,
|
||||
end_date,
|
||||
|
||||
@@ -309,7 +309,8 @@ def get_algo_object(algo_name, key, environ=None, rel_path=None):
|
||||
return None
|
||||
|
||||
|
||||
def save_algo_object(algo_name, key, obj, environ=None, rel_path=None):
|
||||
def save_algo_object(algo_name, key, obj, environ=None, rel_path=None,
|
||||
how='pickle'):
|
||||
"""
|
||||
Serialize and save an object by algo name and key.
|
||||
|
||||
@@ -328,10 +329,15 @@ def save_algo_object(algo_name, key, obj, environ=None, rel_path=None):
|
||||
folder = os.path.join(folder, rel_path)
|
||||
ensure_directory(folder)
|
||||
|
||||
filename = os.path.join(folder, key + '.p')
|
||||
if how == 'json':
|
||||
filename = os.path.join(folder, '{}.json'.format(key))
|
||||
with open(filename, 'wt') as handle:
|
||||
json.dump(obj, handle, indent=4, default=symbols_serial)
|
||||
|
||||
with open(filename, 'wb') as handle:
|
||||
pickle.dump(obj, handle, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
else:
|
||||
filename = os.path.join(folder, '{}.p'.format(key))
|
||||
with open(filename, 'wb') as handle:
|
||||
pickle.dump(obj, handle, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
|
||||
|
||||
def get_algo_df(algo_name, key, environ=None, rel_path=None):
|
||||
|
||||
@@ -10,10 +10,15 @@ from catalyst.exchange.exchange_utils import get_exchange_auth, \
|
||||
get_exchange_folder, is_blacklist
|
||||
|
||||
log = Logger('factory', level=LOG_LEVEL)
|
||||
exchange_cache = dict()
|
||||
|
||||
|
||||
def get_exchange(exchange_name, base_currency=None, must_authenticate=False,
|
||||
skip_init=False):
|
||||
key = (exchange_name, base_currency)
|
||||
if key in exchange_cache:
|
||||
return exchange_cache[key]
|
||||
|
||||
exchange_auth = get_exchange_auth(exchange_name)
|
||||
|
||||
has_auth = (exchange_auth['key'] != '' and exchange_auth['secret'] != '')
|
||||
@@ -31,6 +36,7 @@ def get_exchange(exchange_name, base_currency=None, must_authenticate=False,
|
||||
secret=exchange_auth['secret'],
|
||||
base_currency=base_currency,
|
||||
)
|
||||
exchange_cache[key] = exchange
|
||||
|
||||
if not skip_init:
|
||||
exchange.init()
|
||||
|
||||
@@ -7,6 +7,7 @@ from abc import (
|
||||
)
|
||||
from uuid import uuid4
|
||||
|
||||
import six
|
||||
from six import (
|
||||
iteritems,
|
||||
with_metaclass,
|
||||
@@ -485,8 +486,9 @@ class SimplePipelineEngine(PipelineEngine):
|
||||
|
||||
if isinstance(term, LoadableTerm):
|
||||
term_key = loader_group_key(term)
|
||||
# TODO: temp workaround
|
||||
to_load = sorted(
|
||||
loader_groups[term_key],
|
||||
six.next(six.itervalues(loader_groups)),
|
||||
key=lambda t: t.dataset
|
||||
)
|
||||
loader = get_loader(term)
|
||||
|
||||
@@ -15,8 +15,6 @@ from catalyst.data.data_portal import DataPortal
|
||||
from catalyst.exchange.exchange_pricing_loader import ExchangePricingLoader, \
|
||||
TradingPairPricing
|
||||
from catalyst.exchange.factory import get_exchange
|
||||
from catalyst.pipeline import USEquityPricingLoader
|
||||
from catalyst.pipeline.data import USEquityPricing
|
||||
|
||||
try:
|
||||
from pygments import highlight
|
||||
@@ -41,7 +39,7 @@ from catalyst.exchange.exchange_algorithm import (
|
||||
)
|
||||
from catalyst.exchange.exchange_data_portal import DataPortalExchangeLive, \
|
||||
DataPortalExchangeBacktest
|
||||
from catalyst.exchange.asset_finder_exchange import AssetFinderExchange
|
||||
from catalyst.exchange.exchange_asset_finder import ExchangeAssetFinder
|
||||
from catalyst.exchange.exchange_errors import (
|
||||
ExchangeRequestError, ExchangeRequestErrorTooManyAttempts,
|
||||
BaseCurrencyNotFoundError, NotEnoughCapitalError)
|
||||
@@ -161,6 +159,7 @@ def _run(handle_data,
|
||||
exchange_name=exchange_name,
|
||||
base_currency=base_currency,
|
||||
must_authenticate=(live and not simulate_orders),
|
||||
skip_init=True,
|
||||
)
|
||||
|
||||
open_calendar = get_calendar('OPEN')
|
||||
@@ -176,7 +175,7 @@ def _run(handle_data,
|
||||
exchange_tz='UTC',
|
||||
asset_db_path=None # We don't need an asset db, we have exchanges
|
||||
)
|
||||
env.asset_finder = AssetFinderExchange()
|
||||
env.asset_finder = ExchangeAssetFinder(exchanges=exchanges)
|
||||
|
||||
def choose_loader(column):
|
||||
bound_cols = TradingPairPricing.columns
|
||||
|
||||
@@ -2,7 +2,7 @@ import pandas as pd
|
||||
from logbook import Logger
|
||||
|
||||
from catalyst import get_calendar
|
||||
from catalyst.exchange.asset_finder_exchange import AssetFinderExchange
|
||||
from catalyst.exchange.exchange_asset_finder import ExchangeAssetFinder
|
||||
from catalyst.exchange.exchange_data_portal import (
|
||||
DataPortalExchangeBacktest,
|
||||
DataPortalExchangeLive
|
||||
@@ -20,7 +20,7 @@ class TestExchangeDataPortal:
|
||||
log.info('creating bitfinex exchange')
|
||||
exchanges = get_exchanges(['bitfinex', 'bittrex', 'poloniex'])
|
||||
open_calendar = get_calendar('OPEN')
|
||||
asset_finder = AssetFinderExchange()
|
||||
asset_finder = ExchangeAssetFinder()
|
||||
|
||||
self.data_portal_live = DataPortalExchangeLive(
|
||||
exchanges=exchanges,
|
||||
|
||||
@@ -5,7 +5,7 @@ from logbook import Logger
|
||||
from pandas.util.testing import assert_frame_equal
|
||||
|
||||
from catalyst import get_calendar
|
||||
from catalyst.exchange.asset_finder_exchange import AssetFinderExchange
|
||||
from catalyst.exchange.exchange_asset_finder import ExchangeAssetFinder
|
||||
from catalyst.exchange.exchange_data_portal import DataPortalExchangeBacktest
|
||||
from catalyst.exchange.exchange_utils import get_candles_df
|
||||
from catalyst.exchange.factory import get_exchange
|
||||
@@ -24,7 +24,7 @@ class TestSuiteBundle:
|
||||
@staticmethod
|
||||
def get_data_portal(exchange_names):
|
||||
open_calendar = get_calendar('OPEN')
|
||||
asset_finder = AssetFinderExchange()
|
||||
asset_finder = ExchangeAssetFinder()
|
||||
|
||||
data_portal = DataPortalExchangeBacktest(
|
||||
exchange_names=exchange_names,
|
||||
|
||||
Reference in New Issue
Block a user