mirror of
https://github.com/wassname/catalyst.git
synced 2026-07-01 16:52:26 +08:00
Got trading algorithm and related pieces
This commit is contained in:
@@ -0,0 +1,143 @@
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from datetime import time
|
||||
import logbook
|
||||
|
||||
import catalyst.protocol as zp
|
||||
from catalyst.algorithm import TradingAlgorithm
|
||||
from catalyst.exchange.exchange_clock import ExchangeClock
|
||||
from catalyst.gens.tradesimulation import AlgorithmSimulator
|
||||
from catalyst.errors import OrderInBeforeTradingStart
|
||||
from catalyst.utils.input_validation import error_keywords
|
||||
from catalyst.utils.api_support import (
|
||||
api_method,
|
||||
disallowed_in_before_trading_start)
|
||||
|
||||
from catalyst.utils.calendars.trading_calendar import days_at_time
|
||||
|
||||
log = logbook.Logger("ExchangeTradingAlgorithm")
|
||||
|
||||
|
||||
class ExchangeAlgorithmExecutor(AlgorithmSimulator):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(self.__class__, self).__init__(*args, **kwargs)
|
||||
|
||||
|
||||
class ExchangeTradingAlgorithm(TradingAlgorithm):
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.exchange = kwargs.pop('exchange', None)
|
||||
self.orders = {}
|
||||
|
||||
super(self.__class__, self).__init__(*args, **kwargs)
|
||||
|
||||
log.info("initialization done")
|
||||
|
||||
def _create_clock(self):
|
||||
# This method is taken from TradingAlgorithm.
|
||||
# The clock has been replaced to use RealtimeClock
|
||||
trading_o_and_c = self.trading_calendar.schedule.ix[
|
||||
self.sim_params.sessions]
|
||||
market_closes = trading_o_and_c['market_close']
|
||||
minutely_emission = False
|
||||
|
||||
if self.sim_params.data_frequency == 'minute':
|
||||
market_opens = trading_o_and_c['market_open']
|
||||
|
||||
minutely_emission = self.sim_params.emission_rate == "minute"
|
||||
else:
|
||||
# in daily mode, we want to have one bar per session, timestamped
|
||||
# as the last minute of the session.
|
||||
market_opens = market_closes
|
||||
|
||||
# The calendar's execution times are the minutes over which we actually
|
||||
# want to run the clock. Typically the execution times simply adhere to
|
||||
# the market open and close times. In the case of the futures calendar,
|
||||
# for example, we only want to simulate over a subset of the full 24
|
||||
# hour calendar, so the execution times dictate a market open time of
|
||||
# 6:31am US/Eastern and a close of 5:00pm US/Eastern.
|
||||
execution_opens = \
|
||||
self.trading_calendar.execution_time_from_open(market_opens)
|
||||
execution_closes = \
|
||||
self.trading_calendar.execution_time_from_close(market_closes)
|
||||
|
||||
# FIXME generalize these values
|
||||
before_trading_start_minutes = days_at_time(
|
||||
self.sim_params.sessions,
|
||||
time(8, 45),
|
||||
"US/Eastern"
|
||||
)
|
||||
|
||||
return ExchangeClock(
|
||||
self.sim_params.sessions,
|
||||
execution_opens,
|
||||
execution_closes,
|
||||
before_trading_start_minutes,
|
||||
minute_emission=minutely_emission,
|
||||
time_skew=self.exchange.time_skew
|
||||
)
|
||||
|
||||
def _create_generator(self, sim_params):
|
||||
# Call the simulation trading algorithm for side-effects:
|
||||
# it creates the perf tracker
|
||||
TradingAlgorithm._create_generator(self, sim_params)
|
||||
self.trading_client = ExchangeAlgorithmExecutor(
|
||||
self,
|
||||
sim_params,
|
||||
self.data_portal,
|
||||
self._create_clock(),
|
||||
self._create_benchmark_source(),
|
||||
self.restrictions,
|
||||
universe_func=self._calculate_universe
|
||||
)
|
||||
|
||||
return self.trading_client.transform()
|
||||
|
||||
def updated_portfolio(self):
|
||||
return self.exchange.portfolio
|
||||
|
||||
def updated_account(self):
|
||||
return self.exchange.account
|
||||
|
||||
@api_method
|
||||
@disallowed_in_before_trading_start(OrderInBeforeTradingStart())
|
||||
def order(self,
|
||||
asset,
|
||||
amount,
|
||||
limit_price=None,
|
||||
stop_price=None,
|
||||
style=None):
|
||||
amount, style = self._calculate_order(asset, amount,
|
||||
limit_price, stop_price, style)
|
||||
|
||||
return self.exchange.order(asset, amount, limit_price, stop_price, style)
|
||||
|
||||
@api_method
|
||||
def batch_market_order(self, share_counts):
|
||||
raise NotImplementedError()
|
||||
|
||||
@error_keywords(sid='Keyword argument `sid` is no longer supported for '
|
||||
'get_open_orders. Use `asset` instead.')
|
||||
@api_method
|
||||
def get_open_orders(self, asset=None):
|
||||
return self.exchange.get_open_orders(asset)
|
||||
|
||||
@api_method
|
||||
def get_order(self, order_id):
|
||||
return self.exchange.get_order(order_id)
|
||||
|
||||
@api_method
|
||||
def cancel_order(self, order_param):
|
||||
order_id = order_param
|
||||
if isinstance(order_param, zp.Order):
|
||||
order_id = order_param.id
|
||||
self.exchange.cancel_order(order_id)
|
||||
@@ -6,6 +6,7 @@ import json
|
||||
import time
|
||||
import requests
|
||||
import pandas as pd
|
||||
import collections
|
||||
from catalyst.protocol import Portfolio, Account
|
||||
# from websocket import create_connection
|
||||
from catalyst.exchange.exchange import Exchange
|
||||
@@ -15,6 +16,7 @@ from catalyst.finance.execution import (MarketOrder,
|
||||
LimitOrder,
|
||||
StopOrder,
|
||||
StopLimitOrder)
|
||||
from catalyst.data.data_portal import BASE_FIELDS
|
||||
|
||||
BITFINEX_URL = 'https://api.bitfinex.com'
|
||||
BITFINEX_KEY = 'hjZ7DZzwbBZsIZPWeSSQtrWCPNwyhxw96r3LnY7jtOH'
|
||||
@@ -84,6 +86,11 @@ class Bitfinex(Exchange):
|
||||
|
||||
return request
|
||||
|
||||
def _get_v2_symbol(self, asset):
|
||||
pair = asset.symbol.split('_')
|
||||
symbol = 't' + pair[0].upper() + pair[1].upper()
|
||||
return symbol
|
||||
|
||||
def _get_v2_symbols(self, assets):
|
||||
"""
|
||||
Workaround to support Bitfinex v2
|
||||
@@ -95,9 +102,7 @@ class Bitfinex(Exchange):
|
||||
|
||||
v2_symbols = []
|
||||
for asset in assets:
|
||||
pair = asset.symbol.split('_')
|
||||
symbol = 't' + pair[0].upper() + pair[1].upper()
|
||||
v2_symbols.append(symbol)
|
||||
v2_symbols.append(self._get_v2_symbol(asset))
|
||||
|
||||
return v2_symbols
|
||||
|
||||
@@ -214,13 +219,99 @@ class Bitfinex(Exchange):
|
||||
@property
|
||||
def time_skew(self):
|
||||
# TODO: research the time skew conditions
|
||||
return None
|
||||
return pd.Timedelta('0s')
|
||||
|
||||
def subscribe_to_market_data(self, symbol):
|
||||
pass
|
||||
|
||||
def get_spot_value(self, assets, field, dt, data_frequency):
|
||||
raise NotImplementedError()
|
||||
def get_spot_value(self, assets, field, dt=None, data_frequency='minute'):
|
||||
"""
|
||||
Public API method that returns a scalar value representing the value
|
||||
of the desired asset's field at either the given dt.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
assets : Asset, ContinuousFuture, or iterable of same.
|
||||
The asset or assets whose data is desired.
|
||||
field : {'open', 'high', 'low', 'close', 'volume',
|
||||
'price', 'last_traded'}
|
||||
The desired field of the asset.
|
||||
dt : pd.Timestamp
|
||||
The timestamp for the desired value.
|
||||
data_frequency : str
|
||||
The frequency of the data to query; i.e. whether the data is
|
||||
'daily' or 'minute' bars
|
||||
|
||||
Returns
|
||||
-------
|
||||
value : float, int, or pd.Timestamp
|
||||
The spot value of ``field`` for ``asset`` The return type is based
|
||||
on the ``field`` requested. If the field is one of 'open', 'high',
|
||||
'low', 'close', or 'price', the value will be a float. If the
|
||||
``field`` is 'volume' the value will be a int. If the ``field`` is
|
||||
'last_traded' the value will be a Timestamp.
|
||||
|
||||
Bitfinex timeframes
|
||||
-------------------
|
||||
Available values: '1m', '5m', '15m', '30m', '1h', '3h', '6h', '12h',
|
||||
'1D', '7D', '14D', '1M'
|
||||
"""
|
||||
if field not in BASE_FIELDS:
|
||||
raise KeyError('Invalid column: ' + str(field))
|
||||
|
||||
if isinstance(assets, collections.Iterable):
|
||||
values = list()
|
||||
for asset in assets:
|
||||
value = self.get_single_spot_value(
|
||||
asset, field, data_frequency)
|
||||
values.append(value)
|
||||
|
||||
return values
|
||||
else:
|
||||
return self.get_single_spot_value(
|
||||
assets, field, data_frequency)
|
||||
|
||||
def get_single_spot_value(self, asset, field, data_frequency):
|
||||
symbol = self._get_v2_symbol(asset)
|
||||
log.debug('fetching spot value for symbol {}'.format(symbol))
|
||||
|
||||
if data_frequency == 'minute':
|
||||
frequency = '1m'
|
||||
elif data_frequency == 'daily':
|
||||
frequency = '1d'
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
'Unsupported frequency %s' % data_frequency
|
||||
)
|
||||
|
||||
request = requests.get(
|
||||
'{url}/v2/candles/trade:{frequency}:{symbol}/last'.format(
|
||||
url=self.url,
|
||||
frequency=frequency,
|
||||
symbol=symbol
|
||||
)
|
||||
)
|
||||
candles = request.json()
|
||||
|
||||
if 'message' in candles:
|
||||
raise ValueError(
|
||||
'Unable to retrieve candles: %s' % candles['message']
|
||||
)
|
||||
|
||||
ohlc = dict(
|
||||
open=candles[1],
|
||||
high=candles[3],
|
||||
low=candles[4],
|
||||
close=candles[2],
|
||||
volume=candles[5],
|
||||
price=candles[2],
|
||||
last_traded=pd.Timestamp.utcfromtimestamp(candles[0] / 1000.0),
|
||||
)
|
||||
|
||||
if not field in ohlc:
|
||||
raise KeyError('Invalid column: ' + str(field))
|
||||
|
||||
return ohlc[field]
|
||||
|
||||
def order(self, asset, amount, limit_price, stop_price, style):
|
||||
"""Place an order.
|
||||
|
||||
@@ -0,0 +1,58 @@
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from catalyst.data.data_portal import DataPortal
|
||||
|
||||
from logbook import Logger
|
||||
|
||||
log = Logger('DataPortalExchange')
|
||||
|
||||
|
||||
class DataPortalExchange(DataPortal):
|
||||
def __init__(self, exchange, *args, **kwargs):
|
||||
self.exchange = exchange
|
||||
super(DataPortalExchange, self).__init__(*args, **kwargs)
|
||||
|
||||
def get_history_window(self,
|
||||
assets,
|
||||
end_dt,
|
||||
bar_count,
|
||||
frequency,
|
||||
field,
|
||||
data_frequency,
|
||||
ffill=True):
|
||||
history_window = super(self.__class__, self).get_history_window(
|
||||
assets,
|
||||
end_dt,
|
||||
bar_count,
|
||||
frequency,
|
||||
field,
|
||||
data_frequency,
|
||||
ffill)
|
||||
|
||||
# The returned dataframe contains today's value as a NaN because
|
||||
# end_dt points to the current wall clock. We drop today's
|
||||
# value to be in sync with the simulation's behavior.
|
||||
today = pd.to_datetime('now').date()
|
||||
return history_window[history_window.index.date != today]
|
||||
|
||||
def get_spot_value(self, assets, field, dt, data_frequency):
|
||||
return self.exchange.get_spot_value(assets, field, dt, data_frequency)
|
||||
|
||||
def get_adjusted_value(self, asset, field, dt,
|
||||
perspective_dt,
|
||||
data_frequency,
|
||||
spot_value=None):
|
||||
raise NotImplementedError("get_adjusted_value is not implemented yet!")
|
||||
@@ -184,6 +184,32 @@ class Exchange:
|
||||
|
||||
@abstractmethod
|
||||
def get_spot_value(self, assets, field, dt, data_frequency):
|
||||
"""
|
||||
Public API method that returns a scalar value representing the value
|
||||
of the desired asset's field at either the given dt.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
assets : Asset, ContinuousFuture, or iterable of same.
|
||||
The asset or assets whose data is desired.
|
||||
field : {'open', 'high', 'low', 'close', 'volume',
|
||||
'price', 'last_traded'}
|
||||
The desired field of the asset.
|
||||
dt : pd.Timestamp
|
||||
The timestamp for the desired value.
|
||||
data_frequency : str
|
||||
The frequency of the data to query; i.e. whether the data is
|
||||
'daily' or 'minute' bars
|
||||
|
||||
Returns
|
||||
-------
|
||||
value : float, int, or pd.Timestamp
|
||||
The spot value of ``field`` for ``asset`` The return type is based
|
||||
on the ``field`` requested. If the field is one of 'open', 'high',
|
||||
'low', 'close', or 'price', the value will be a float. If the
|
||||
``field`` is 'volume' the value will be a int. If the ``field`` is
|
||||
'last_traded' the value will be a Timestamp.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
|
||||
@@ -0,0 +1,81 @@
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from time import sleep
|
||||
|
||||
from logbook import Logger
|
||||
import pandas as pd
|
||||
|
||||
from catalyst.gens.sim_engine import (
|
||||
BAR,
|
||||
SESSION_START,
|
||||
SESSION_END,
|
||||
MINUTE_END,
|
||||
BEFORE_TRADING_START_BAR
|
||||
)
|
||||
|
||||
log = Logger('ExchangeClock')
|
||||
|
||||
|
||||
class ExchangeClock(object):
|
||||
"""Realtime clock for live trading.
|
||||
|
||||
This class is a drop-in replacement for
|
||||
:class:`zipline.gens.sim_engine.MinuteSimulationClock`.
|
||||
|
||||
This is a stripped down version because crypto exchanges run around the clock.
|
||||
|
||||
The :param:`time_skew` parameter represents the time difference between
|
||||
the Broker and the live trading machine's clock.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
sessions,
|
||||
execution_opens,
|
||||
execution_closes,
|
||||
before_trading_start_minutes,
|
||||
minute_emission,
|
||||
time_skew=pd.Timedelta("0s")):
|
||||
self.sessions = sessions
|
||||
self.execution_opens = execution_opens
|
||||
self.execution_closes = execution_closes
|
||||
self.before_trading_start_minutes = before_trading_start_minutes
|
||||
self.minute_emission = minute_emission
|
||||
self.time_skew = time_skew
|
||||
self._last_emit = None
|
||||
self._before_trading_start_bar_yielded = False
|
||||
|
||||
# It is expected to have this clock created once a day (ideally prior
|
||||
# to BEFORE_TRADING_START_BAR event). Multiple days (sessions) are
|
||||
# not supported.
|
||||
assert len(self.sessions) == 1
|
||||
|
||||
def __iter__(self):
|
||||
yield self.sessions[0], SESSION_START
|
||||
|
||||
while True:
|
||||
current_time = pd.Timestamp.utcnow()
|
||||
server_time = (current_time + self.time_skew).floor('1 min')
|
||||
|
||||
if (self._last_emit is None or
|
||||
server_time - self._last_emit >=
|
||||
pd.Timedelta('1 minute')):
|
||||
|
||||
self._last_emit = server_time
|
||||
yield server_time, BAR
|
||||
|
||||
if self.minute_emission:
|
||||
yield server_time, MINUTE_END
|
||||
|
||||
else:
|
||||
sleep(1)
|
||||
@@ -38,7 +38,7 @@ class BaseExchangeTestCase():
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def test_spot_value(self):
|
||||
def test_get_spot_value(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
|
||||
@@ -64,8 +64,20 @@ class BitfinexTestCase(BaseExchangeTestCase):
|
||||
log.info('canceled order: {}'.format(response))
|
||||
pass
|
||||
|
||||
def test_spot_value(self):
|
||||
log.info('spot valud not implemented')
|
||||
def test_get_spot_value(self):
|
||||
log.info('spot value not implemented')
|
||||
bitfinex = Bitfinex()
|
||||
assets = [
|
||||
bitfinex.get_asset('eth_usd'),
|
||||
bitfinex.get_asset('etc_usd'),
|
||||
bitfinex.get_asset('eos_usd'),
|
||||
]
|
||||
# assets = bitfinex.get_asset('eth_usd')
|
||||
value = bitfinex.get_spot_value(
|
||||
assets=assets,
|
||||
field='close',
|
||||
data_frequency='minute'
|
||||
)
|
||||
pass
|
||||
|
||||
def test_tickers(self):
|
||||
|
||||
Reference in New Issue
Block a user