mirror of
https://github.com/wassname/catalyst.git
synced 2026-06-28 19:15:15 +08:00
BLD: tested stats with multiple assets
This commit is contained in:
@@ -96,13 +96,13 @@ def handle_data(context, data):
|
||||
# Now that we've collected all current data for this frame, we use
|
||||
# the record() method to save it. This data will be available as
|
||||
# a parameter of the analyze() function for further analysis.
|
||||
|
||||
record(
|
||||
volume=(context.market, current['volume']),
|
||||
price_change=(context.market, price_change),
|
||||
rsi=(context.market, rsi[-1]),
|
||||
cash=cash
|
||||
)
|
||||
|
||||
# We are trying to avoid over-trading by limiting our trades to
|
||||
# one per day.
|
||||
if context.traded_today:
|
||||
|
||||
@@ -326,7 +326,7 @@ class ExchangeTradingAlgorithmLive(ExchangeTradingAlgorithmBase):
|
||||
self.retry_order = 2
|
||||
self.retry_delay = 5
|
||||
|
||||
self.stats_minutes = 5
|
||||
self.stats_minutes = 20
|
||||
|
||||
super(ExchangeTradingAlgorithmLive, self).__init__(*args, **kwargs)
|
||||
|
||||
|
||||
@@ -176,6 +176,9 @@ class ExchangeBlotter(Blotter):
|
||||
@expect_types(asset=TradingPair)
|
||||
def order(self, asset, amount, style, order_id=None):
|
||||
log.debug('ordering {} {}'.format(amount, asset.symbol))
|
||||
if amount == 0:
|
||||
log.warn('skipping 0 amount orders')
|
||||
return None
|
||||
|
||||
if self.simulate_orders:
|
||||
return super(ExchangeBlotter, self).order(
|
||||
|
||||
@@ -126,15 +126,15 @@ def vwap(df):
|
||||
return ret
|
||||
|
||||
|
||||
def set_position_row(row, position_index, recorded_cols=None):
|
||||
def set_position_row(row, asset, asset_values=list()):
|
||||
"""
|
||||
Apply the position data as individual columns.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
row: dict[str, Object]
|
||||
position_index: int
|
||||
recorded_cols: list[str]
|
||||
asset: TradingPair
|
||||
asset_values: list[str]
|
||||
If a recorded_col contains a tuple which first value is an asset
|
||||
matching a position, its value will be displayed with the
|
||||
position and not in the index.
|
||||
@@ -143,32 +143,31 @@ def set_position_row(row, position_index, recorded_cols=None):
|
||||
-------
|
||||
|
||||
"""
|
||||
position = row['positions'][position_index]
|
||||
|
||||
asset = position['sid']
|
||||
asset_cols = ['symbol']
|
||||
row['symbol'] = asset.symbol
|
||||
|
||||
position = next((p for p in row['positions'] if p['sid'] == asset), None)
|
||||
|
||||
columns = ['amount', 'cost_basis', 'last_sale_price']
|
||||
for column in columns:
|
||||
row[column] = position[column]
|
||||
if position is not None:
|
||||
row[column] = position[column]
|
||||
|
||||
columns.insert(0, 'symbol')
|
||||
else:
|
||||
row[column] = 0
|
||||
|
||||
if recorded_cols is not None:
|
||||
for column in recorded_cols[:]:
|
||||
value = row[column]
|
||||
if type(value) in [list, tuple] and \
|
||||
isinstance(value[0], TradingPair) and asset == value[0]:
|
||||
row[column] = value[1]
|
||||
asset_cols.append(column)
|
||||
|
||||
columns.append(column)
|
||||
# Removing the asset specific entries
|
||||
recorded_cols.remove(column)
|
||||
values = asset_values[asset] if asset in asset_values else list()
|
||||
for column in values:
|
||||
row[column] = values[column]
|
||||
|
||||
return columns
|
||||
asset_cols.append(column)
|
||||
|
||||
return asset_cols
|
||||
|
||||
|
||||
def prepare_stats(stats, recorded_cols=None):
|
||||
def prepare_stats(stats, recorded_cols=list()):
|
||||
"""
|
||||
Prepare the stats DataFrame for user-friendly output.
|
||||
|
||||
@@ -181,30 +180,43 @@ def prepare_stats(stats, recorded_cols=None):
|
||||
-------
|
||||
|
||||
"""
|
||||
position_cols = None
|
||||
asset_cols = list()
|
||||
|
||||
# Using a copy since we are adding rows inside the loop.
|
||||
for row_index, row_data in enumerate(list(stats)):
|
||||
if len(row_data['positions']) == 1:
|
||||
row = stats[row_index]
|
||||
columns = set_position_row(row, 0, recorded_cols)
|
||||
assets = [p['sid'] for p in row_data['positions']]
|
||||
|
||||
elif len(row_data['positions']) > 1:
|
||||
for pos_index, position in enumerate(row_data['positions']):
|
||||
if pos_index > 0:
|
||||
row = row_data
|
||||
asset_values = dict()
|
||||
for column in recorded_cols[:]:
|
||||
value = row_data[column]
|
||||
if type(value) is dict:
|
||||
for asset in value:
|
||||
if not isinstance(asset, TradingPair):
|
||||
break
|
||||
|
||||
if asset not in assets:
|
||||
assets.append(asset)
|
||||
|
||||
if asset not in asset_values:
|
||||
asset_values[asset] = dict()
|
||||
|
||||
asset_values[asset][column] = value[asset]
|
||||
|
||||
if len(assets) == 1:
|
||||
row = stats[row_index]
|
||||
asset_cols = set_position_row(row, assets[0], asset_values)
|
||||
|
||||
elif len(assets) > 1:
|
||||
for asset_index, asset in enumerate(assets):
|
||||
if asset_index > 0:
|
||||
row = copy.deepcopy(row_data)
|
||||
stats.append(row)
|
||||
|
||||
else:
|
||||
row = stats[row_index]
|
||||
|
||||
columns = set_position_row(row, pos_index, recorded_cols)
|
||||
|
||||
else:
|
||||
break
|
||||
|
||||
if position_cols is None:
|
||||
position_cols = columns
|
||||
asset_cols = set_position_row(row, assets[asset_index],
|
||||
asset_values)
|
||||
|
||||
df = pd.DataFrame(list(stats))
|
||||
|
||||
@@ -212,6 +224,9 @@ def prepare_stats(stats, recorded_cols=None):
|
||||
'period_close', 'starting_cash', 'ending_cash', 'portfolio_value',
|
||||
'pnl', 'long_exposure', 'short_exposure', 'orders', 'transactions',
|
||||
]
|
||||
|
||||
# Removing the asset specific entries
|
||||
recorded_cols = [x for x in recorded_cols if x not in asset_cols]
|
||||
if recorded_cols is not None:
|
||||
for column in recorded_cols:
|
||||
index_cols.append(column)
|
||||
@@ -223,8 +238,9 @@ def prepare_stats(stats, recorded_cols=None):
|
||||
|
||||
df.set_index(index_cols, drop=True, inplace=True)
|
||||
df.dropna(axis=1, how='all', inplace=True)
|
||||
df.sort_index(inplace=True)
|
||||
|
||||
return df, position_cols
|
||||
return df, asset_cols
|
||||
|
||||
|
||||
def get_pretty_stats(stats, recorded_cols=None, num_rows=10):
|
||||
@@ -245,7 +261,7 @@ def get_pretty_stats(stats, recorded_cols=None, num_rows=10):
|
||||
df, columns = prepare_stats(stats, recorded_cols=recorded_cols)
|
||||
|
||||
pd.set_option('display.expand_frame_repr', False)
|
||||
pd.set_option('precision', 3)
|
||||
pd.set_option('precision', 8)
|
||||
pd.set_option('display.width', 1000)
|
||||
pd.set_option('display.max_colwidth', 1000)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user