mirror of
https://github.com/wassname/catalyst.git
synced 2026-06-30 15:15:51 +08:00
BLD: improved stats display to better support multiple assets per algo
This commit is contained in:
@@ -97,10 +97,9 @@ def handle_data(context, data):
|
||||
# the record() method to save it. This data will be available as
|
||||
# a parameter of the analyze() function for further analysis.
|
||||
record(
|
||||
price=price,
|
||||
volume=current['volume'],
|
||||
price_change=price_change,
|
||||
rsi=rsi[-1],
|
||||
volume=(context.market, current['volume']),
|
||||
price_change=(context.market, price_change),
|
||||
rsi=(context.market, rsi[-1]),
|
||||
cash=cash
|
||||
)
|
||||
|
||||
@@ -278,6 +277,6 @@ if __name__ == '__main__':
|
||||
algo_namespace=NAMESPACE,
|
||||
base_currency='eth',
|
||||
live_graph=False,
|
||||
simulate_orders=False,
|
||||
simulate_orders=True,
|
||||
stats_output=None
|
||||
)
|
||||
|
||||
@@ -614,12 +614,12 @@ class ExchangeTradingAlgorithmLive(ExchangeTradingAlgorithmBase):
|
||||
|
||||
self.add_exposure_stats(frame_stats)
|
||||
|
||||
print_df = pd.DataFrame(list(self.frame_stats))
|
||||
# print_df = pd.DataFrame(list(self.frame_stats))
|
||||
log.info(
|
||||
'statistics for the last {stats_minutes} minutes:\n{stats}'.format(
|
||||
stats_minutes=self.stats_minutes,
|
||||
stats=get_pretty_stats(
|
||||
df=print_df,
|
||||
stats=self.frame_stats,
|
||||
recorded_cols=recorded_cols,
|
||||
num_rows=self.stats_minutes
|
||||
)
|
||||
@@ -644,7 +644,7 @@ class ExchangeTradingAlgorithmLive(ExchangeTradingAlgorithmBase):
|
||||
if 's3://' in self.stats_output:
|
||||
stats_to_s3(
|
||||
uri=self.stats_output,
|
||||
df=print_df,
|
||||
stats=self.frame_stats,
|
||||
algo_namespace=self.algo_namespace,
|
||||
recorded_cols=recorded_cols,
|
||||
)
|
||||
|
||||
@@ -6,7 +6,7 @@ from logbook import Logger
|
||||
|
||||
from catalyst.constants import LOG_LEVEL
|
||||
from catalyst.exchange.exchange_errors import ExchangeRequestError, \
|
||||
ExchangePortfolioDataError, OrphanOrderError, ExchangeTransactionError
|
||||
ExchangePortfolioDataError, ExchangeTransactionError
|
||||
from catalyst.finance.blotter import Blotter
|
||||
from catalyst.finance.commission import CommissionModel
|
||||
from catalyst.finance.order import ORDER_STATUS
|
||||
@@ -175,6 +175,8 @@ class ExchangeBlotter(Blotter):
|
||||
|
||||
@expect_types(asset=TradingPair)
|
||||
def order(self, asset, amount, style, order_id=None):
|
||||
log.debug('ordering {} {}'.format(amount, asset.symbol))
|
||||
|
||||
if self.simulate_orders:
|
||||
return super(ExchangeBlotter, self).order(
|
||||
asset, amount, style, order_id
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
import numbers
|
||||
|
||||
import copy
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import boto3
|
||||
import time
|
||||
|
||||
from catalyst.assets._assets import TradingPair
|
||||
|
||||
s3 = boto3.resource('s3')
|
||||
|
||||
|
||||
@@ -123,50 +126,115 @@ def vwap(df):
|
||||
return ret
|
||||
|
||||
|
||||
def format_positions(positions):
|
||||
parts = []
|
||||
for position in positions:
|
||||
msg = '{amount:.2f}{base} cost basis {cost_basis:.8f}{quote}'.format(
|
||||
amount=position['amount'],
|
||||
base=position['sid'].base_currency,
|
||||
cost_basis=position['cost_basis'],
|
||||
quote=position['sid'].quote_currency
|
||||
)
|
||||
parts.append(msg)
|
||||
return ', '.join(parts)
|
||||
def set_position_row(row, position_index, recorded_cols=None):
|
||||
"""
|
||||
Apply the position data as individual columns.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
row: dict[str, Object]
|
||||
position_index: int
|
||||
recorded_cols: list[str]
|
||||
If a recorded_col contains a tuple which first value is an asset
|
||||
matching a position, its value will be displayed with the
|
||||
position and not in the index.
|
||||
|
||||
def prepare_stats(df, recorded_cols=None):
|
||||
columns = ['starting_cash', 'ending_cash', 'portfolio_value',
|
||||
'pnl', 'long_exposure', 'short_exposure', 'orders',
|
||||
'transactions', 'positions']
|
||||
Returns
|
||||
-------
|
||||
|
||||
"""
|
||||
position = row['positions'][position_index]
|
||||
|
||||
asset = position['sid']
|
||||
row['symbol'] = asset.symbol
|
||||
|
||||
columns = ['amount', 'cost_basis', 'last_sale_price']
|
||||
for column in columns:
|
||||
row[column] = position[column]
|
||||
|
||||
columns.insert(0, 'symbol')
|
||||
|
||||
if recorded_cols is not None:
|
||||
for column in recorded_cols[:]:
|
||||
value = row[column]
|
||||
if type(value) in [list, tuple] and \
|
||||
isinstance(value[0], TradingPair) and asset == value[0]:
|
||||
row[column] = value[1]
|
||||
|
||||
columns.append(column)
|
||||
# Removing the asset specific entries
|
||||
recorded_cols.remove(column)
|
||||
|
||||
return columns
|
||||
|
||||
|
||||
def prepare_stats(stats, recorded_cols=None):
|
||||
"""
|
||||
Prepare the stats DataFrame for user-friendly output.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
stats: list[Object]
|
||||
recorded_cols: list[str]
|
||||
|
||||
Returns
|
||||
-------
|
||||
|
||||
"""
|
||||
position_cols = None
|
||||
|
||||
# Using a copy since we are adding rows inside the loop.
|
||||
for row_index, row_data in enumerate(list(stats)):
|
||||
if len(row_data['positions']) == 1:
|
||||
row = stats[row_index]
|
||||
columns = set_position_row(row, 0, recorded_cols)
|
||||
|
||||
elif len(row_data['positions']) > 1:
|
||||
for pos_index, position in enumerate(row_data['positions']):
|
||||
if pos_index > 0:
|
||||
row = row_data
|
||||
stats.append(row)
|
||||
|
||||
else:
|
||||
row = stats[row_index]
|
||||
|
||||
columns = set_position_row(row, pos_index, recorded_cols)
|
||||
|
||||
else:
|
||||
break
|
||||
|
||||
if position_cols is None:
|
||||
position_cols = columns
|
||||
|
||||
df = pd.DataFrame(list(stats))
|
||||
|
||||
index_cols = [
|
||||
'period_close', 'starting_cash', 'ending_cash', 'portfolio_value',
|
||||
'pnl', 'long_exposure', 'short_exposure', 'orders', 'transactions',
|
||||
]
|
||||
if recorded_cols is not None:
|
||||
for column in recorded_cols:
|
||||
columns.append(column)
|
||||
|
||||
df = df.copy(True)
|
||||
|
||||
df.set_index('period_close', drop=True, inplace=True)
|
||||
df.dropna(axis=1, how='all', inplace=True)
|
||||
index_cols.append(column)
|
||||
|
||||
df['orders'] = df['orders'].apply(lambda orders: len(orders))
|
||||
df['transactions'] = df['transactions'].apply(
|
||||
lambda transactions: len(transactions)
|
||||
)
|
||||
df['positions'] = df['positions'].apply(format_positions)
|
||||
|
||||
return df, columns
|
||||
df.set_index(index_cols, drop=True, inplace=True)
|
||||
df.dropna(axis=1, how='all', inplace=True)
|
||||
|
||||
return df, position_cols
|
||||
|
||||
|
||||
def get_pretty_stats(df, recorded_cols=None, num_rows=10):
|
||||
def get_pretty_stats(stats, recorded_cols=None, num_rows=10):
|
||||
"""
|
||||
Format and print the last few rows of a statistics DataFrame.
|
||||
See the pyfolio project for the data structure.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
df: pd.DataFrame
|
||||
stats: list[Object]
|
||||
num_rows: int
|
||||
|
||||
Returns
|
||||
@@ -174,7 +242,7 @@ def get_pretty_stats(df, recorded_cols=None, num_rows=10):
|
||||
str
|
||||
|
||||
"""
|
||||
df, columns = prepare_stats(df, recorded_cols=recorded_cols)
|
||||
df, columns = prepare_stats(stats, recorded_cols=recorded_cols)
|
||||
|
||||
pd.set_option('display.expand_frame_repr', False)
|
||||
pd.set_option('precision', 3)
|
||||
@@ -191,21 +259,21 @@ def get_pretty_stats(df, recorded_cols=None, num_rows=10):
|
||||
)
|
||||
|
||||
|
||||
def get_csv_stats(df, recorded_cols=None):
|
||||
def get_csv_stats(stats, recorded_cols=None):
|
||||
"""
|
||||
Create a CSV buffer from the stats DataFrame.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
path: str
|
||||
df: pd.DataFrame
|
||||
stats: list[Object]
|
||||
recorded_cols: list[str]
|
||||
|
||||
Returns
|
||||
-------
|
||||
|
||||
"""
|
||||
df, columns = prepare_stats(df, recorded_cols=recorded_cols)
|
||||
df, columns = prepare_stats(stats, recorded_cols=recorded_cols)
|
||||
|
||||
return df.to_csv(
|
||||
None,
|
||||
@@ -214,8 +282,8 @@ def get_csv_stats(df, recorded_cols=None):
|
||||
).encode()
|
||||
|
||||
|
||||
def stats_to_s3(uri, df, algo_namespace, recorded_cols=None):
|
||||
bytes_to_write = get_csv_stats(df, recorded_cols=recorded_cols)
|
||||
def stats_to_s3(uri, stats, algo_namespace, recorded_cols=None):
|
||||
bytes_to_write = get_csv_stats(stats, recorded_cols=recorded_cols)
|
||||
|
||||
timestr = time.strftime('%Y%m%d')
|
||||
|
||||
|
||||
Reference in New Issue
Block a user