mirror of
https://github.com/wassname/catalyst.git
synced 2026-06-30 19:58:17 +08:00
Merge branch 'run-algo-and-fix-124' of https://github.com/caioaao/catalyst into caioaao-run-algo-and-fix-124
This commit is contained in:
+266
-223
@@ -70,36 +70,7 @@ class _RunAlgoError(click.ClickException, ValueError):
|
||||
return self.pyfunc_msg
|
||||
|
||||
|
||||
def _run(handle_data,
|
||||
initialize,
|
||||
before_trading_start,
|
||||
analyze,
|
||||
algofile,
|
||||
algotext,
|
||||
defines,
|
||||
data_frequency,
|
||||
capital_base,
|
||||
data,
|
||||
bundle,
|
||||
bundle_timestamp,
|
||||
start,
|
||||
end,
|
||||
output,
|
||||
print_algo,
|
||||
local_namespace,
|
||||
environ,
|
||||
live,
|
||||
exchange,
|
||||
algo_namespace,
|
||||
base_currency,
|
||||
live_graph,
|
||||
analyze_live,
|
||||
simulate_orders,
|
||||
stats_output):
|
||||
"""Run a backtest for the given algorithm.
|
||||
|
||||
This is shared between the cli and :func:`catalyst.run_algo`.
|
||||
"""
|
||||
def _build_namespace(algotext, local_namespace, defines):
|
||||
if algotext is not None:
|
||||
if local_namespace:
|
||||
ip = get_ipython() # noqa
|
||||
@@ -113,173 +84,197 @@ def _run(handle_data,
|
||||
except ValueError:
|
||||
raise ValueError(
|
||||
'invalid define %r, should be of the form name=value' %
|
||||
assign,
|
||||
)
|
||||
assign)
|
||||
try:
|
||||
# evaluate in the same namespace so names may refer to
|
||||
# eachother
|
||||
namespace[name] = eval(value, namespace)
|
||||
except Exception as e:
|
||||
raise ValueError(
|
||||
'failed to execute definition for name %r: %s' % (name, e),
|
||||
)
|
||||
'failed to execute definition for name %r: %s' % (name, e))
|
||||
elif defines:
|
||||
raise _RunAlgoError(
|
||||
'cannot pass define without `algotext`',
|
||||
"cannot pass '-D' / '--define' without '-t' / '--algotext'",
|
||||
)
|
||||
"cannot pass '-D' / '--define' without '-t' / '--algotext'")
|
||||
else:
|
||||
namespace = {}
|
||||
if algofile is not None:
|
||||
algotext = algofile.read()
|
||||
|
||||
if print_algo:
|
||||
if PYGMENTS:
|
||||
highlight(
|
||||
algotext,
|
||||
PythonLexer(),
|
||||
TerminalFormatter(),
|
||||
outfile=sys.stdout,
|
||||
)
|
||||
else:
|
||||
click.echo(algotext)
|
||||
return namespace
|
||||
|
||||
mode = 'paper-trading' if simulate_orders else 'live-trading' \
|
||||
if live else 'backtest'
|
||||
log.info('running algo in {mode} mode'.format(mode=mode))
|
||||
|
||||
def _mode(simulate_orders, live):
|
||||
if not live:
|
||||
return 'backtest'
|
||||
elif simulate_orders:
|
||||
return 'paper-trading'
|
||||
else:
|
||||
return 'live-trading'
|
||||
|
||||
|
||||
def _build_exchanges_dict(exchange, live, simulate_orders, base_currency):
|
||||
exchange_name = exchange
|
||||
if exchange_name is None:
|
||||
raise ValueError('Please specify at least one exchange.')
|
||||
|
||||
exchange_list = [x.strip().lower() for x in exchange.split(',')]
|
||||
|
||||
exchanges = dict()
|
||||
for exchange_name in exchange_list:
|
||||
exchanges[exchange_name] = get_exchange(
|
||||
exchange_name=exchange_name,
|
||||
exchanges = {exchange_name: get_exchange(
|
||||
exchange_name=exchange_name,
|
||||
base_currency=base_currency,
|
||||
must_authenticate=(live and not simulate_orders))
|
||||
for exchange_name in exchange_list}
|
||||
|
||||
return exchanges
|
||||
|
||||
|
||||
def _pretty_print_code(algotext):
|
||||
if PYGMENTS:
|
||||
highlight(
|
||||
algotext,
|
||||
PythonLexer(),
|
||||
TerminalFormatter(),
|
||||
outfile=sys.stdout)
|
||||
else:
|
||||
click.echo(algotext)
|
||||
|
||||
|
||||
def _choose_loader(data_frequency, column):
|
||||
bound_cols = TradingPairPricing.columns
|
||||
if column in bound_cols:
|
||||
return ExchangePricingLoader(data_frequency)
|
||||
raise ValueError(
|
||||
"No PipelineLoader registered for column %s." % column)
|
||||
|
||||
|
||||
def _get_live_time_range():
|
||||
start = pd.Timestamp.utcnow()
|
||||
# TODO: fix the end data.
|
||||
end = start + timedelta(hours=8760)
|
||||
return start, end
|
||||
|
||||
|
||||
def _data_for_live_trading(sim_params, exchanges, env, open_calendar):
|
||||
data = DataPortalExchangeLive(
|
||||
exchanges=exchanges,
|
||||
asset_finder=env.asset_finder,
|
||||
trading_calendar=open_calendar,
|
||||
first_trading_day=pd.to_datetime('today', utc=True))
|
||||
|
||||
return data
|
||||
|
||||
|
||||
# TODO use proper retry here
|
||||
def _fetch_capital_base(base_currency, exchange_name, exchange,
|
||||
attempt_index=0):
|
||||
"""
|
||||
Fetch the base currency amount required to bootstrap
|
||||
the algorithm against the exchange.
|
||||
|
||||
The algorithm cannot continue without this value.
|
||||
|
||||
:param exchange: the targeted exchange
|
||||
:param attempt_index:
|
||||
:return capital_base: the amount of base currency available for
|
||||
trading
|
||||
"""
|
||||
try:
|
||||
log.debug('retrieving capital base in {} to bootstrap '
|
||||
'exchange {}'.format(base_currency, exchange_name))
|
||||
balances = exchange.get_balances()
|
||||
except ExchangeRequestError as e:
|
||||
if attempt_index < 20:
|
||||
log.warn(
|
||||
'could not retrieve balances on {}: {}'.format(
|
||||
exchange.name, e))
|
||||
sleep(5)
|
||||
return _fetch_capital_base(base_currency, exchange_name, exchange,
|
||||
attempt_index + 1)
|
||||
|
||||
else:
|
||||
raise ExchangeRequestErrorTooManyAttempts(
|
||||
attempts=attempt_index,
|
||||
error=e)
|
||||
|
||||
if base_currency in balances:
|
||||
base_currency_available = balances[base_currency]['free']
|
||||
log.info(
|
||||
'base currency available in the account: {} {}'.format(
|
||||
base_currency_available, base_currency))
|
||||
|
||||
return base_currency_available
|
||||
else:
|
||||
raise BaseCurrencyNotFoundError(
|
||||
base_currency=base_currency,
|
||||
must_authenticate=(live and not simulate_orders),
|
||||
skip_init=True,
|
||||
)
|
||||
exchange=exchange_name)
|
||||
|
||||
open_calendar = get_calendar('OPEN')
|
||||
|
||||
env = TradingEnvironment(
|
||||
load=partial(
|
||||
load_crypto_market_data,
|
||||
environ=environ,
|
||||
start_dt=start,
|
||||
end_dt=end
|
||||
),
|
||||
environ=environ,
|
||||
exchange_tz='UTC',
|
||||
asset_db_path=None # We don't need an asset db, we have exchanges
|
||||
)
|
||||
env.asset_finder = ExchangeAssetFinder(exchanges=exchanges)
|
||||
def _algorithm_class_for_live(algo_namespace, live_graph, stats_output,
|
||||
analyze_live, base_currency, simulate_orders,
|
||||
exchanges, capital_base):
|
||||
if not simulate_orders:
|
||||
for exchange_name in exchanges:
|
||||
exchange = exchanges[exchange_name]
|
||||
balance = _fetch_capital_base(base_currency, exchange_name,
|
||||
exchange)
|
||||
|
||||
def choose_loader(column):
|
||||
bound_cols = TradingPairPricing.columns
|
||||
if column in bound_cols:
|
||||
return ExchangePricingLoader(data_frequency)
|
||||
raise ValueError(
|
||||
"No PipelineLoader registered for column %s." % column
|
||||
)
|
||||
|
||||
if live:
|
||||
start = pd.Timestamp.utcnow()
|
||||
|
||||
# TODO: fix the end data.
|
||||
end = start + timedelta(hours=8760)
|
||||
|
||||
data = DataPortalExchangeLive(
|
||||
exchanges=exchanges,
|
||||
asset_finder=env.asset_finder,
|
||||
trading_calendar=open_calendar,
|
||||
first_trading_day=pd.to_datetime('today', utc=True)
|
||||
)
|
||||
|
||||
def fetch_capital_base(exchange, attempt_index=0):
|
||||
"""
|
||||
Fetch the base currency amount required to bootstrap
|
||||
the algorithm against the exchange.
|
||||
|
||||
The algorithm cannot continue without this value.
|
||||
|
||||
:param exchange: the targeted exchange
|
||||
:param attempt_index:
|
||||
:return capital_base: the amount of base currency available for
|
||||
trading
|
||||
"""
|
||||
try:
|
||||
log.debug('retrieving capital base in {} to bootstrap '
|
||||
'exchange {}'.format(base_currency, exchange_name))
|
||||
balances = exchange.get_balances()
|
||||
except ExchangeRequestError as e:
|
||||
if attempt_index < 20:
|
||||
log.warn(
|
||||
'could not retrieve balances on {}: {}'.format(
|
||||
exchange.name, e
|
||||
)
|
||||
)
|
||||
sleep(5)
|
||||
return fetch_capital_base(exchange, attempt_index + 1)
|
||||
|
||||
else:
|
||||
raise ExchangeRequestErrorTooManyAttempts(
|
||||
attempts=attempt_index,
|
||||
error=e
|
||||
)
|
||||
|
||||
if base_currency in balances:
|
||||
base_currency_available = balances[base_currency]['free']
|
||||
log.info(
|
||||
'base currency available in the account: {} {}'.format(
|
||||
base_currency_available, base_currency
|
||||
)
|
||||
)
|
||||
|
||||
return base_currency_available
|
||||
else:
|
||||
raise BaseCurrencyNotFoundError(
|
||||
if balance < capital_base:
|
||||
raise NotEnoughCapitalError(
|
||||
exchange=exchange_name,
|
||||
base_currency=base_currency,
|
||||
exchange=exchange_name
|
||||
)
|
||||
balance=balance,
|
||||
capital_base=capital_base)
|
||||
|
||||
if not simulate_orders:
|
||||
for exchange_name in exchanges:
|
||||
exchange = exchanges[exchange_name]
|
||||
balance = fetch_capital_base(exchange)
|
||||
algorithm_class = partial(
|
||||
ExchangeTradingAlgorithmLive,
|
||||
exchanges=exchanges,
|
||||
algo_namespace=algo_namespace,
|
||||
live_graph=live_graph,
|
||||
simulate_orders=simulate_orders,
|
||||
stats_output=stats_output,
|
||||
analyze_live=analyze_live,)
|
||||
|
||||
if balance < capital_base:
|
||||
raise NotEnoughCapitalError(
|
||||
exchange=exchange_name,
|
||||
base_currency=base_currency,
|
||||
balance=balance,
|
||||
capital_base=capital_base,
|
||||
)
|
||||
return algorithm_class
|
||||
|
||||
sim_params = create_simulation_parameters(
|
||||
start=start,
|
||||
end=end,
|
||||
capital_base=capital_base,
|
||||
emission_rate='minute',
|
||||
data_frequency='minute'
|
||||
)
|
||||
|
||||
# TODO: use the constructor instead
|
||||
sim_params._arena = 'live'
|
||||
def _bundle_trading_environment(bundle_data, environ):
|
||||
prefix, connstr = re.split(
|
||||
r'sqlite:///',
|
||||
str(bundle_data.asset_finder.engine.url),
|
||||
maxsplit=1)
|
||||
if prefix:
|
||||
raise ValueError(
|
||||
"invalid url %r, must begin with 'sqlite:///'" %
|
||||
str(bundle_data.asset_finder.engine.url))
|
||||
|
||||
algorithm_class = partial(
|
||||
ExchangeTradingAlgorithmLive,
|
||||
exchanges=exchanges,
|
||||
algo_namespace=algo_namespace,
|
||||
live_graph=live_graph,
|
||||
simulate_orders=simulate_orders,
|
||||
stats_output=stats_output,
|
||||
analyze_live=analyze_live,
|
||||
)
|
||||
elif exchanges:
|
||||
return TradingEnvironment(asset_db_path=connstr, environ=environ)
|
||||
|
||||
|
||||
def _build_live_algo_and_data(sim_params, exchanges, env, open_calendar,
|
||||
simulate_orders, algo_namespace, capital_base,
|
||||
live_graph, stats_output, analyze_live,
|
||||
base_currency, namespace, choose_loader,
|
||||
algorithm_class_kwargs):
|
||||
sim_params._arena = 'live' # TODO: use the constructor instead
|
||||
|
||||
data = _data_for_live_trading(sim_params, exchanges, env, open_calendar)
|
||||
|
||||
algorithm_class = _algorithm_class_for_live(
|
||||
algo_namespace, live_graph, stats_output, analyze_live,
|
||||
base_currency, simulate_orders, exchanges, capital_base)
|
||||
|
||||
return data, algorithm_class(
|
||||
namespace=namespace,
|
||||
env=env,
|
||||
get_pipeline_loader=choose_loader,
|
||||
sim_params=sim_params,
|
||||
**algorithm_class_kwargs)
|
||||
|
||||
|
||||
def _build_backtest_algo_and_data(
|
||||
exchanges, bundle, env, environ, bundle_timestamp, open_calendar,
|
||||
start, end, namespace, choose_loader, sim_params,
|
||||
algorithm_class_kwargs):
|
||||
if exchanges:
|
||||
# Removed the existing Poloniex fork to keep things simple
|
||||
# We can add back the complexity if required.
|
||||
|
||||
@@ -293,41 +288,19 @@ def _run(handle_data,
|
||||
asset_finder=None,
|
||||
trading_calendar=open_calendar,
|
||||
first_trading_day=start,
|
||||
last_available_session=end
|
||||
)
|
||||
|
||||
sim_params = create_simulation_parameters(
|
||||
start=start,
|
||||
end=end,
|
||||
capital_base=capital_base,
|
||||
data_frequency=data_frequency,
|
||||
emission_rate=data_frequency,
|
||||
)
|
||||
last_available_session=end)
|
||||
|
||||
algorithm_class = partial(
|
||||
ExchangeTradingAlgorithmBacktest,
|
||||
exchanges=exchanges
|
||||
)
|
||||
|
||||
exchanges=exchanges)
|
||||
elif bundle is not None:
|
||||
bundle_data = load(
|
||||
bundle,
|
||||
environ,
|
||||
bundle_timestamp,
|
||||
)
|
||||
# TODO This branch should probably be removed or fixed: it doesn't even
|
||||
# build `algorithm_class`, so it will break when trying to instantiate
|
||||
# it.
|
||||
bundle_data = load(bundle, environ, bundle_timestamp)
|
||||
|
||||
prefix, connstr = re.split(
|
||||
r'sqlite:///',
|
||||
str(bundle_data.asset_finder.engine.url),
|
||||
maxsplit=1,
|
||||
)
|
||||
if prefix:
|
||||
raise ValueError(
|
||||
"invalid url %r, must begin with 'sqlite:///'" %
|
||||
str(bundle_data.asset_finder.engine.url),
|
||||
)
|
||||
env = _bundle_trading_environment(bundle_data, environ)
|
||||
|
||||
env = TradingEnvironment(asset_db_path=connstr, environ=environ)
|
||||
first_trading_day = \
|
||||
bundle_data.equity_minute_bar_reader.first_trading_day
|
||||
|
||||
@@ -336,27 +309,103 @@ def _run(handle_data,
|
||||
first_trading_day=first_trading_day,
|
||||
equity_minute_reader=bundle_data.equity_minute_bar_reader,
|
||||
equity_daily_reader=bundle_data.equity_daily_bar_reader,
|
||||
adjustment_reader=bundle_data.adjustment_reader,
|
||||
)
|
||||
adjustment_reader=bundle_data.adjustment_reader)
|
||||
|
||||
perf = algorithm_class(
|
||||
return data, algorithm_class(
|
||||
namespace=namespace,
|
||||
env=env,
|
||||
get_pipeline_loader=choose_loader,
|
||||
sim_params=sim_params,
|
||||
**{
|
||||
'initialize': initialize,
|
||||
'handle_data': handle_data,
|
||||
'before_trading_start': before_trading_start,
|
||||
'analyze': analyze,
|
||||
} if algotext is None else {
|
||||
'algo_filename': getattr(algofile, 'name', '<algorithm>'),
|
||||
'script': algotext,
|
||||
}
|
||||
).run(
|
||||
**algorithm_class_kwargs)
|
||||
|
||||
|
||||
def _build_algo_and_data(handle_data, initialize, before_trading_start,
|
||||
analyze, algofile, algotext, defines, data_frequency,
|
||||
capital_base, data, bundle, bundle_timestamp, start,
|
||||
end, output, print_algo, local_namespace, environ,
|
||||
live, exchange, algo_namespace, base_currency,
|
||||
live_graph, analyze_live, simulate_orders,
|
||||
stats_output):
|
||||
namespace = _build_namespace(algotext, local_namespace, defines)
|
||||
if algotext is not None:
|
||||
algotext = algofile.read()
|
||||
|
||||
if print_algo:
|
||||
_pretty_print_code(algotext)
|
||||
|
||||
mode = _mode(simulate_orders, live)
|
||||
log.info('running algo in {mode} mode'.format(mode=mode))
|
||||
|
||||
exchanges = _build_exchanges_dict(exchange, live, simulate_orders,
|
||||
base_currency)
|
||||
|
||||
open_calendar = get_calendar('OPEN')
|
||||
|
||||
env = TradingEnvironment(
|
||||
load=partial(load_crypto_market_data, environ=environ, start_dt=start,
|
||||
end_dt=end),
|
||||
environ=environ,
|
||||
exchange_tz='UTC',
|
||||
asset_db_path=None) # We don't need an asset db, we have exchanges
|
||||
|
||||
env.asset_finder = ExchangeAssetFinder(exchanges=exchanges)
|
||||
|
||||
choose_loader = partial(_choose_loader, data_frequency)
|
||||
|
||||
if live:
|
||||
start, end = _get_live_time_range()
|
||||
data_frequency = 'minute' # TODO double check if this is the desired behavior
|
||||
|
||||
sim_params = create_simulation_parameters(
|
||||
start=start,
|
||||
end=end,
|
||||
capital_base=capital_base,
|
||||
emission_rate=data_frequency,
|
||||
data_frequency=data_frequency)
|
||||
|
||||
if algotext is None:
|
||||
algorithm_class_kwargs = {'initialize': initialize,
|
||||
'handle_data': handle_data,
|
||||
'before_trading_start': before_trading_start,
|
||||
'analyze': analyze}
|
||||
else:
|
||||
algorithm_class_kwargs = {'algo_filename': getattr(algofile, 'name',
|
||||
'<algorithm>'),
|
||||
'script': algotext}
|
||||
|
||||
if live:
|
||||
return _build_live_algo_and_data(
|
||||
sim_params, exchanges, env, open_calendar, simulate_orders,
|
||||
algo_namespace, capital_base, live_graph, stats_output,
|
||||
analyze_live, base_currency, namespace, choose_loader,
|
||||
algorithm_class_kwargs)
|
||||
else:
|
||||
return _build_backtest_algo_and_data(
|
||||
exchanges, bundle, env, environ, bundle_timestamp, open_calendar,
|
||||
start, end, namespace, choose_loader, sim_params,
|
||||
algorithm_class_kwargs)
|
||||
|
||||
|
||||
def _run(handle_data, initialize, before_trading_start, analyze, algofile,
|
||||
algotext, defines, data_frequency, capital_base, data, bundle,
|
||||
bundle_timestamp, start, end, output, print_algo, local_namespace,
|
||||
environ, live, exchange, algo_namespace, base_currency, live_graph,
|
||||
analyze_live, simulate_orders, stats_output):
|
||||
"""Run an algorithm in backtest,
|
||||
paper-trading or live-trading mode.
|
||||
|
||||
This is shared between the cli and :func:`catalyst.run_algo`.
|
||||
"""
|
||||
|
||||
data, algorithm = _build_algo_and_data(
|
||||
handle_data, initialize, before_trading_start, analyze, algofile,
|
||||
algotext, defines, data_frequency, capital_base, data, bundle,
|
||||
bundle_timestamp, start, end, output, print_algo, local_namespace,
|
||||
environ, live, exchange, algo_namespace, base_currency, live_graph,
|
||||
analyze_live, simulate_orders, stats_output)
|
||||
perf = algorithm.run(
|
||||
data,
|
||||
overwrite_sim_params=False,
|
||||
)
|
||||
overwrite_sim_params=False)
|
||||
|
||||
if output == '-':
|
||||
click.echo(str(perf))
|
||||
@@ -413,8 +462,7 @@ def load_extensions(default, extensions, strict, environ, reload=False):
|
||||
# without `strict` we should just log the failure
|
||||
warnings.warn(
|
||||
'Failed to load extension: %r\n%s' % (ext, e),
|
||||
stacklevel=2
|
||||
)
|
||||
stacklevel=2)
|
||||
else:
|
||||
_loaded_extensions.add(ext)
|
||||
|
||||
@@ -513,8 +561,7 @@ def run_algorithm(initialize,
|
||||
catalyst.data.bundles.bundles : The available data bundles.
|
||||
"""
|
||||
load_extensions(
|
||||
default_extension, extensions, strict_extensions, environ
|
||||
)
|
||||
default_extension, extensions, strict_extensions, environ)
|
||||
|
||||
if capital_base is None:
|
||||
raise ValueError(
|
||||
@@ -522,8 +569,7 @@ def run_algorithm(initialize,
|
||||
'amount of base currency available for trading. For example, '
|
||||
'if the `capital_base` is 5ETH, the '
|
||||
'`order_target_percent(asset, 1)` command will order 5ETH worth '
|
||||
'of the specified asset.'
|
||||
)
|
||||
'of the specified asset.')
|
||||
# I'm not sure that we need this since the modified DataPortal
|
||||
# does not require extensions to be explicitly loaded.
|
||||
|
||||
@@ -541,13 +587,11 @@ def run_algorithm(initialize,
|
||||
elif len(non_none_data) != 1:
|
||||
raise ValueError(
|
||||
'must specify one of `data`, `data_portal`, or `bundle`,'
|
||||
' got: %r' % non_none_data,
|
||||
)
|
||||
' got: %r' % non_none_data)
|
||||
|
||||
elif 'bundle' not in non_none_data and bundle_timestamp is not None:
|
||||
raise ValueError(
|
||||
'cannot specify `bundle_timestamp` without passing `bundle`',
|
||||
)
|
||||
'cannot specify `bundle_timestamp` without passing `bundle`')
|
||||
return _run(
|
||||
handle_data=handle_data,
|
||||
initialize=initialize,
|
||||
@@ -574,5 +618,4 @@ def run_algorithm(initialize,
|
||||
live_graph=live_graph,
|
||||
analyze_live=analyze_live,
|
||||
simulate_orders=simulate_orders,
|
||||
stats_output=stats_output
|
||||
)
|
||||
stats_output=stats_output)
|
||||
|
||||
Reference in New Issue
Block a user