mirror of
https://github.com/wassname/catalyst.git
synced 2026-07-01 02:29:54 +08:00
ENH: Adds futures trading and asset management logic to TradingAlgorithm and performance classes
This commit is contained in:
+115
-42
@@ -39,9 +39,10 @@ from zipline.errors import (
|
||||
UnsupportedCommissionModel,
|
||||
UnsupportedOrderParameters,
|
||||
UnsupportedSlippageModel,
|
||||
SidNotFound,
|
||||
)
|
||||
|
||||
from zipline.finance import trading
|
||||
from zipline.finance.trading import TradingEnvironment
|
||||
from zipline.finance.blotter import Blotter
|
||||
from zipline.finance.commission import PerShare, PerTrade, PerDollar
|
||||
from zipline.finance.controls import (
|
||||
@@ -64,6 +65,7 @@ from zipline.finance.slippage import (
|
||||
SlippageModel,
|
||||
transact_partial
|
||||
)
|
||||
from zipline.assets import Asset, Future
|
||||
from zipline.gens.composites import date_sorted_sources
|
||||
from zipline.gens.tradesimulation import AlgorithmSimulator
|
||||
from zipline.sources import DataFrameSource, DataPanelSource
|
||||
@@ -133,8 +135,27 @@ class TradingAlgorithm(object):
|
||||
How much capital to start with.
|
||||
instant_fill : bool <default: False>
|
||||
Whether to fill orders immediately or on next bar.
|
||||
environment : str <default: 'zipline'>
|
||||
The environment that this algorithm is running in.
|
||||
asset_finder : An AssetFinder object
|
||||
A new AssetFinder object to be used in this TradingEnvironment
|
||||
asset_metadata: can be either:
|
||||
- dict
|
||||
- pandas.DataFrame
|
||||
- object with 'read' property
|
||||
If dict is provided, it must have the following structure:
|
||||
* keys are the identifiers
|
||||
* values are dicts containing the metadata, with the metadata
|
||||
field name as the key
|
||||
If pandas.DataFrame is provided, it must have the
|
||||
following structure:
|
||||
* column names must be the metadata fields
|
||||
* index must be the different asset identifiers
|
||||
* array contents should be the metadata value
|
||||
If an object with a 'read' property is provided, 'read' must
|
||||
return rows containing at least one of 'sid' or 'symbol' along
|
||||
with the other metadata fields.
|
||||
identifiers : List
|
||||
Any asset identifiers that are not provided in the
|
||||
asset_metadata, but will be traded by this TradingAlgorithm
|
||||
"""
|
||||
self.datetime = None
|
||||
|
||||
@@ -167,10 +188,23 @@ class TradingAlgorithm(object):
|
||||
self.sim_params = kwargs.pop('sim_params', None)
|
||||
if self.sim_params is None:
|
||||
self.sim_params = create_simulation_parameters(
|
||||
capital_base=self.capital_base
|
||||
capital_base=self.capital_base,
|
||||
start=kwargs.pop('start', None),
|
||||
end=kwargs.pop('end', None)
|
||||
)
|
||||
self.perf_tracker = PerformanceTracker(self.sim_params)
|
||||
|
||||
# Update the TradingEnvironment with the provided asset metadata
|
||||
self.trading_environment = kwargs.pop('env',
|
||||
TradingEnvironment.instance())
|
||||
self.trading_environment.update_asset_finder(
|
||||
asset_finder=kwargs.pop('asset_finder', None),
|
||||
asset_metadata=kwargs.pop('asset_metadata', None),
|
||||
identifiers=kwargs.pop('identifiers', None)
|
||||
)
|
||||
# Pull in the environment's new AssetFinder for quick reference
|
||||
self.asset_finder = self.trading_environment.asset_finder
|
||||
|
||||
self.blotter = kwargs.pop('blotter', None)
|
||||
if not self.blotter:
|
||||
self.blotter = Blotter()
|
||||
@@ -322,11 +356,10 @@ class TradingAlgorithm(object):
|
||||
sim_params = self.sim_params
|
||||
|
||||
if self.benchmark_return_source is None:
|
||||
env = trading.environment
|
||||
if sim_params.data_frequency == 'minute' or \
|
||||
sim_params.emission_rate == 'minute':
|
||||
def update_time(date):
|
||||
return env.get_open_and_close(date)[1]
|
||||
return self.trading_environment.get_open_and_close(date)[1]
|
||||
else:
|
||||
def update_time(date):
|
||||
return date
|
||||
@@ -336,7 +369,7 @@ class TradingAlgorithm(object):
|
||||
'type': zipline.protocol.DATASOURCE_TYPE.BENCHMARK,
|
||||
'source_id': 'benchmarks'})
|
||||
for dt, ret in
|
||||
trading.environment.benchmark_returns.iteritems()
|
||||
self.trading_environment.benchmark_returns.iteritems()
|
||||
if dt.date() >= sim_params.period_start.date() and
|
||||
dt.date() <= sim_params.period_end.date()
|
||||
]
|
||||
@@ -410,8 +443,7 @@ class TradingAlgorithm(object):
|
||||
|
||||
If pandas.DataFrame is provided, it must have the
|
||||
following structure:
|
||||
* column names must consist of ints representing the
|
||||
different sids
|
||||
* column names must be the different asset identifiers
|
||||
* index must be DatetimeIndex
|
||||
* array contents should be price info.
|
||||
|
||||
@@ -420,9 +452,11 @@ class TradingAlgorithm(object):
|
||||
Daily performance metrics such as returns, alpha etc.
|
||||
|
||||
"""
|
||||
|
||||
# Ensure that source is a DataSource object
|
||||
if isinstance(source, list):
|
||||
if overwrite_sim_params:
|
||||
warnings.warn("""List of sources passed, will not attempt to extract sids, and start and end
|
||||
warnings.warn("""List of sources passed, will not attempt to extract start and end
|
||||
dates. Make sure to set the correct fields in sim_params passed to
|
||||
__init__().""", UserWarning)
|
||||
overwrite_sim_params = False
|
||||
@@ -443,8 +477,19 @@ class TradingAlgorithm(object):
|
||||
self.sim_params.period_start = source.start
|
||||
if hasattr(source, 'end'):
|
||||
self.sim_params.period_end = source.end
|
||||
# The sids field of the source is the canonical reference for
|
||||
# sids in this run
|
||||
all_sids = [sid for s in self.sources for sid in s.sids]
|
||||
self.sim_params.sids = set(all_sids)
|
||||
# Check that all sids from the source are accounted for in
|
||||
# the AssetFinder
|
||||
for sid in self.sim_params.sids:
|
||||
try:
|
||||
self.asset_finder.retrieve_asset(sid)
|
||||
except SidNotFound:
|
||||
warnings.warn("No Asset found for sid '%s'. Make sure "
|
||||
"that the correct identifiers and asset "
|
||||
"metadata are passed to __init__()." % sid)
|
||||
# Changing period_start and period_close might require updating
|
||||
# of first_open and last_close.
|
||||
self.sim_params._update_internal()
|
||||
@@ -604,17 +649,52 @@ class TradingAlgorithm(object):
|
||||
def symbol(self, symbol_str):
|
||||
"""
|
||||
Default symbol lookup for any source that directly maps the
|
||||
symbol to the identifier (e.g. yahoo finance).
|
||||
symbol to the Asset (e.g. yahoo finance).
|
||||
"""
|
||||
return symbol_str
|
||||
asset, _ = self.asset_finder.lookup_generic(
|
||||
asset_convertible_or_iterable=symbol_str,
|
||||
as_of_date=self.datetime,
|
||||
)
|
||||
return asset
|
||||
|
||||
@api_method
|
||||
def symbols(self, *args):
|
||||
"""
|
||||
Default symbols lookup for any source that directly maps the
|
||||
symbol to the identifier (e.g. yahoo finance).
|
||||
symbol to the Asset (e.g. yahoo finance).
|
||||
"""
|
||||
return args
|
||||
return [self.symbol(identifier) for identifier in args]
|
||||
|
||||
@api_method
|
||||
def sid(self, a_sid):
|
||||
"""
|
||||
Default sid lookup for any source that directly maps the integer sid
|
||||
to the Asset.
|
||||
"""
|
||||
return self.asset_finder.retrieve_asset(a_sid)
|
||||
|
||||
def _calculate_order_value_amount(self, asset, value):
|
||||
"""
|
||||
Calculates how many shares/contracts to order based on the type of
|
||||
asset being ordered.
|
||||
"""
|
||||
last_price = self.trading_client.current_data[asset].price
|
||||
|
||||
if tolerant_equals(last_price, 0):
|
||||
zero_message = "Price of 0 for {psid}; can't infer value".format(
|
||||
psid=asset
|
||||
)
|
||||
if self.logger:
|
||||
self.logger.debug(zero_message)
|
||||
# Don't place any order
|
||||
return 0
|
||||
|
||||
if isinstance(asset, Future):
|
||||
value_multiplier = asset.contract_multiplier
|
||||
else:
|
||||
value_multiplier = 1
|
||||
|
||||
return value / (last_price * value_multiplier)
|
||||
|
||||
@api_method
|
||||
def order(self, sid, amount,
|
||||
@@ -655,7 +735,7 @@ class TradingAlgorithm(object):
|
||||
return self.blotter.order(sid, amount, style)
|
||||
|
||||
def validate_order_params(self,
|
||||
sid,
|
||||
asset,
|
||||
amount,
|
||||
limit_price,
|
||||
stop_price,
|
||||
@@ -682,8 +762,14 @@ class TradingAlgorithm(object):
|
||||
msg="Passing both stop_price and style is not supported."
|
||||
)
|
||||
|
||||
if not isinstance(asset, Asset):
|
||||
raise UnsupportedOrderParameters(
|
||||
msg="Passing non-Asset argument to 'order()' is not supported."
|
||||
" Use 'sid()' or 'symbol()' methods to look up an Asset."
|
||||
)
|
||||
|
||||
for control in self.trading_controls:
|
||||
control.validate(sid,
|
||||
control.validate(asset,
|
||||
amount,
|
||||
self.updated_portfolio(),
|
||||
self.get_datetime(),
|
||||
@@ -718,6 +804,8 @@ class TradingAlgorithm(object):
|
||||
Place an order by desired value rather than desired number of shares.
|
||||
If the requested sid is found in the universe, the requested value is
|
||||
divided by its price to imply the number of shares to transact.
|
||||
If the Asset being ordered is a Future, the 'value' calculated
|
||||
is actually the exposure, as Futures have no 'value'.
|
||||
|
||||
value > 0 :: Buy/Cover
|
||||
value < 0 :: Sell/Short
|
||||
@@ -726,21 +814,11 @@ class TradingAlgorithm(object):
|
||||
Stop order: order(sid, value, None, stop_price)
|
||||
StopLimit order: order(sid, value, limit_price, stop_price)
|
||||
"""
|
||||
last_price = self.trading_client.current_data[sid].price
|
||||
if tolerant_equals(last_price, 0):
|
||||
zero_message = "Price of 0 for {psid}; can't infer value".format(
|
||||
psid=sid
|
||||
)
|
||||
if self.logger:
|
||||
self.logger.debug(zero_message)
|
||||
# Don't place any order
|
||||
return
|
||||
else:
|
||||
amount = value / last_price
|
||||
return self.order(sid, amount,
|
||||
limit_price=limit_price,
|
||||
stop_price=stop_price,
|
||||
style=style)
|
||||
amount = self._calculate_order_value_amount(sid, value)
|
||||
return self.order(sid, amount,
|
||||
limit_price=limit_price,
|
||||
stop_price=stop_price,
|
||||
style=style)
|
||||
|
||||
@property
|
||||
def recorded_vars(self):
|
||||
@@ -855,7 +933,7 @@ class TradingAlgorithm(object):
|
||||
def order_percent(self, sid, percent,
|
||||
limit_price=None, stop_price=None, style=None):
|
||||
"""
|
||||
Place an order in the specified security corresponding to the given
|
||||
Place an order in the specified asset corresponding to the given
|
||||
percent of the current portfolio value.
|
||||
|
||||
Note that percent must expressed as a decimal (0.50 means 50\%).
|
||||
@@ -898,15 +976,10 @@ class TradingAlgorithm(object):
|
||||
order. If the position does exist, this is equivalent to placing an
|
||||
order for the difference between the target value and the
|
||||
current value.
|
||||
If the Asset being ordered is a Future, the 'target value' calculated
|
||||
is actually the target exposure, as Futures have no 'value'.
|
||||
"""
|
||||
last_price = self.trading_client.current_data[sid].price
|
||||
if tolerant_equals(last_price, 0):
|
||||
# Don't place an order
|
||||
if self.logger:
|
||||
zero_message = "Price of 0 for {psid}; can't infer value"
|
||||
self.logger.debug(zero_message.format(psid=sid))
|
||||
return
|
||||
target_amount = target / last_price
|
||||
target_amount = self._calculate_order_value_amount(sid, target)
|
||||
return self.order_target(sid, target_amount,
|
||||
limit_price=limit_price,
|
||||
stop_price=stop_price,
|
||||
@@ -1066,7 +1139,7 @@ class TradingAlgorithm(object):
|
||||
increasing the absolute value of shares/dollar value exceeding one of
|
||||
these limits, raise a TradingControlException.
|
||||
"""
|
||||
control = MaxPositionSize(sid=sid,
|
||||
control = MaxPositionSize(asset=sid,
|
||||
max_shares=max_shares,
|
||||
max_notional=max_notional)
|
||||
self.register_trading_control(control)
|
||||
@@ -1081,7 +1154,7 @@ class TradingAlgorithm(object):
|
||||
If an algorithm attempts to place an order that would result in
|
||||
exceeding one of these limits, raise a TradingControlException.
|
||||
"""
|
||||
control = MaxOrderSize(sid=sid,
|
||||
control = MaxOrderSize(asset=sid,
|
||||
max_shares=max_shares,
|
||||
max_notional=max_notional)
|
||||
self.register_trading_control(control)
|
||||
|
||||
@@ -1,4 +1,8 @@
|
||||
from . import loader
|
||||
from .loader import load_from_yahoo, load_bars_from_yahoo
|
||||
from .loader import (
|
||||
load_from_yahoo, load_bars_from_yahoo, load_prices_from_csv,
|
||||
load_prices_from_csv_folder
|
||||
)
|
||||
|
||||
__all__ = ['loader', 'load_from_yahoo', 'load_bars_from_yahoo']
|
||||
__all__ = ['loader', 'load_from_yahoo', 'load_bars_from_yahoo',
|
||||
'load_prices_from_csv', 'load_prices_from_csv_folder']
|
||||
|
||||
+22
-1
@@ -294,7 +294,7 @@ def load_from_yahoo(indexes=None,
|
||||
adjusted=True):
|
||||
"""
|
||||
Loads price data from Yahoo into a dataframe for each of the indicated
|
||||
securities. By default, 'price' is taken from Yahoo's 'Adjusted Close',
|
||||
assets. By default, 'price' is taken from Yahoo's 'Adjusted Close',
|
||||
which removes the impact of splits and dividends. If the argument
|
||||
'adjusted' is False, then the non-adjusted 'close' field is used instead.
|
||||
|
||||
@@ -367,3 +367,24 @@ def load_bars_from_yahoo(indexes=None,
|
||||
for col in adj_cols:
|
||||
panel[ticker][col] *= ratio_filtered
|
||||
return panel
|
||||
|
||||
|
||||
def load_prices_from_csv(filepath, identifier_col, tz='UTC'):
|
||||
data = pd.read_csv(filepath, index_col=identifier_col)
|
||||
data.index = pd.DatetimeIndex(data.index, tz=tz)
|
||||
data.sort_index(inplace=True)
|
||||
return data
|
||||
|
||||
|
||||
def load_prices_from_csv_folder(folderpath, identifier_col, tz='UTC'):
|
||||
data = None
|
||||
for file in os.listdir(folderpath):
|
||||
if '.csv' not in file:
|
||||
continue
|
||||
raw = load_prices_from_csv(os.path.join(folderpath, file),
|
||||
identifier_col, tz)
|
||||
if data is None:
|
||||
data = raw
|
||||
else:
|
||||
data = pd.concat([data, raw], axis=1)
|
||||
return data
|
||||
|
||||
+94
-1
@@ -175,7 +175,8 @@ class TradingControlViolation(ZiplineError):
|
||||
Raised if an order would violate a constraint set by a TradingControl.
|
||||
"""
|
||||
msg = """
|
||||
Order for {amount} shares of {sid} violates trading constraint {constraint}.
|
||||
Order for {amount} shares of {asset} at {datetime} violates trading constraint
|
||||
{constraint}.
|
||||
""".strip()
|
||||
|
||||
|
||||
@@ -188,3 +189,95 @@ class IncompatibleHistoryFrequency(ZiplineError):
|
||||
Requested history at frequency '{frequency}' cannot be created with data
|
||||
at frequency '{data_frequency}'.
|
||||
""".strip()
|
||||
|
||||
|
||||
class MultipleSymbolsFound(ZiplineError):
|
||||
"""
|
||||
Raised when a symbol() call contains a symbol that changed over
|
||||
time and is thus not resolvable without additional information
|
||||
provided via as_of_date.
|
||||
"""
|
||||
msg = """
|
||||
Multiple symbols with the name '{symbol}' found. Use the
|
||||
as_of_date' argument to to specify when the date symbol-lookup
|
||||
should be valid.
|
||||
|
||||
Possible options:{options}
|
||||
""".strip()
|
||||
|
||||
|
||||
class SymbolNotFound(ZiplineError):
|
||||
"""
|
||||
Raised when a symbol() call contains a non-existant symbol.
|
||||
"""
|
||||
msg = """
|
||||
Symbol '{symbol}' was not found.
|
||||
""".strip()
|
||||
|
||||
|
||||
class SidNotFound(ZiplineError):
|
||||
"""
|
||||
Raised when a retrieve_asset() call contains a non-existent sid.
|
||||
"""
|
||||
msg = """
|
||||
Asset with sid '{sid}' was not found.
|
||||
""".strip()
|
||||
|
||||
|
||||
class IdentifierNotFound(ZiplineError):
|
||||
"""
|
||||
Raised when a retrieve_asset_by_identifier() call contains a non-existent
|
||||
identifier.
|
||||
"""
|
||||
msg = """
|
||||
Asset with identifier '{identifier}' was not found.
|
||||
""".strip()
|
||||
|
||||
|
||||
class InvalidAssetType(ZiplineError):
|
||||
"""
|
||||
Raised when an AssetFinder tries to build an Asset with an invalid
|
||||
AssetType.
|
||||
"""
|
||||
msg = """
|
||||
AssetMetaData contained an invalid Asset type: '{asset_type}'.
|
||||
""".strip()
|
||||
|
||||
|
||||
class UpdateAssetFinderTypeError(ZiplineError):
|
||||
"""
|
||||
Raised when TradingEnvironment.update_asset_finder() gets an asset_finder
|
||||
arg that is not of AssetFinder class.
|
||||
"""
|
||||
msg = """
|
||||
TradingEnvironment can not set asset_finder to object of class {cls}.
|
||||
""".strip()
|
||||
|
||||
|
||||
class ConsumeAssetMetaDataError(ZiplineError):
|
||||
"""
|
||||
Raised when AssetMetaData.consume() is called on an invalid object.
|
||||
"""
|
||||
msg = """
|
||||
AssetMetaData can not consume {obj}. MetaData must be a dict, a DataFrame, or
|
||||
""".strip()
|
||||
|
||||
|
||||
class NoSourceError(ZiplineError):
|
||||
"""
|
||||
Raised when no source is given to the pipeline
|
||||
"""
|
||||
msg = """
|
||||
No data source given.
|
||||
""".strip()
|
||||
|
||||
|
||||
class PipelineDateError(ZiplineError):
|
||||
"""
|
||||
Raised when only one date is passed to the pipeline
|
||||
"""
|
||||
msg = """
|
||||
Only one simulation date given. Please specify both the 'start' and 'end' for
|
||||
the simulation, or neither. If neither is given, the start and end of the
|
||||
DataSource will be used. Given start = '{start}', end = '{end}'
|
||||
""".strip()
|
||||
|
||||
@@ -29,7 +29,7 @@ from zipline.transforms.ta import EMA
|
||||
|
||||
|
||||
def initialize(context):
|
||||
context.security = symbol('AAPL')
|
||||
context.asset = symbol('AAPL')
|
||||
|
||||
# Add 2 mavg transforms, one with a long window, one with a short window.
|
||||
context.short_ema_trans = EMA(timeperiod=20)
|
||||
@@ -49,17 +49,17 @@ def handle_data(context, data):
|
||||
sell = False
|
||||
|
||||
if (short_ema > long_ema).all() and not context.invested:
|
||||
order(context.security, 100)
|
||||
order(context.asset, 100)
|
||||
context.invested = True
|
||||
buy = True
|
||||
elif (short_ema < long_ema).all() and context.invested:
|
||||
order(context.security, -100)
|
||||
order(context.asset, -100)
|
||||
context.invested = False
|
||||
sell = True
|
||||
|
||||
record(AAPL=data[context.security].price,
|
||||
short_ema=short_ema[context.security],
|
||||
long_ema=long_ema[context.security],
|
||||
record(AAPL=data[context.asset].price,
|
||||
short_ema=short_ema[context.asset],
|
||||
long_ema=long_ema[context.asset],
|
||||
buy=buy,
|
||||
sell=sell)
|
||||
|
||||
@@ -77,7 +77,8 @@ if __name__ == '__main__':
|
||||
data = load_from_yahoo(stocks=['AAPL'], indexes={}, start=start,
|
||||
end=end)
|
||||
|
||||
algo = TradingAlgorithm(initialize=initialize, handle_data=handle_data)
|
||||
algo = TradingAlgorithm(initialize=initialize, handle_data=handle_data,
|
||||
identifiers=['AAPL'])
|
||||
results = algo.run(data).dropna()
|
||||
|
||||
fig = plt.figure()
|
||||
|
||||
@@ -31,6 +31,8 @@ def initialize(context):
|
||||
add_history(100, '1d', 'price')
|
||||
add_history(300, '1d', 'price')
|
||||
|
||||
context.sym = symbol('AAPL')
|
||||
|
||||
context.i = 0
|
||||
|
||||
|
||||
@@ -46,17 +48,15 @@ def handle_data(context, data):
|
||||
short_mavg = history(100, '1d', 'price').mean()
|
||||
long_mavg = history(300, '1d', 'price').mean()
|
||||
|
||||
sym = symbol('AAPL')
|
||||
|
||||
# Trading logic
|
||||
if short_mavg[sym] > long_mavg[sym]:
|
||||
if short_mavg[context.sym] > long_mavg[context.sym]:
|
||||
# order_target orders as many shares as needed to
|
||||
# achieve the desired number of shares.
|
||||
order_target(sym, 100)
|
||||
elif short_mavg[sym] < long_mavg[sym]:
|
||||
order_target(sym, 0)
|
||||
order_target(context.sym, 100)
|
||||
elif short_mavg[context.sym] < long_mavg[context.sym]:
|
||||
order_target(context.sym, 0)
|
||||
|
||||
# Save values for later inspection
|
||||
record(AAPL=data[sym].price,
|
||||
short_mavg=short_mavg[sym],
|
||||
long_mavg=long_mavg[sym])
|
||||
record(AAPL=data[context.sym].price,
|
||||
short_mavg=short_mavg[context.sym],
|
||||
long_mavg=long_mavg[context.sym])
|
||||
|
||||
+13
-10
@@ -24,6 +24,7 @@ STOCKS = ['AMD', 'CERN', 'COST', 'DELL', 'GPS', 'INTC', 'MMM']
|
||||
# http://icml.cc/2012/papers/168.pdf
|
||||
def initialize(algo, eps=1, window_length=5):
|
||||
algo.stocks = STOCKS
|
||||
algo.sids = [algo.symbol(symbol) for symbol in algo.stocks]
|
||||
algo.m = len(algo.stocks)
|
||||
algo.price = {}
|
||||
algo.b_t = np.ones(algo.m) / algo.m
|
||||
@@ -52,11 +53,11 @@ def handle_data(algo, data):
|
||||
x_tilde = np.zeros(m)
|
||||
b = np.zeros(m)
|
||||
|
||||
# find relative moving average price for each security
|
||||
for i, stock in enumerate(algo.stocks):
|
||||
price = data[stock].price
|
||||
# find relative moving average price for each asset
|
||||
for i, sid in enumerate(algo.sids):
|
||||
price = data[sid].price
|
||||
# Relative mean deviation
|
||||
x_tilde[i] = data[stock].mavg(algo.window_length) / price
|
||||
x_tilde[i] = data[sid].mavg(algo.window_length) / price
|
||||
|
||||
###########################
|
||||
# Inside of OLMAR (algo 2)
|
||||
@@ -98,17 +99,17 @@ def rebalance_portfolio(algo, data, desired_port):
|
||||
positions_value = algo.portfolio.positions_value + \
|
||||
algo.portfolio.cash
|
||||
|
||||
for i, stock in enumerate(algo.stocks):
|
||||
current_amount[i] = algo.portfolio.positions[stock].amount
|
||||
prices[i] = data[stock].price
|
||||
for i, sid in enumerate(algo.sids):
|
||||
current_amount[i] = algo.portfolio.positions[sid].amount
|
||||
prices[i] = data[sid].price
|
||||
|
||||
desired_amount = np.round(desired_port * positions_value / prices)
|
||||
|
||||
algo.last_desired_port = desired_port
|
||||
diff_amount = desired_amount - current_amount
|
||||
|
||||
for i, stock in enumerate(algo.stocks):
|
||||
algo.order(stock, diff_amount[i])
|
||||
for i, sid in enumerate(algo.sids):
|
||||
algo.order(sid, diff_amount[i])
|
||||
|
||||
|
||||
def simplex_projection(v, b=1):
|
||||
@@ -154,7 +155,9 @@ if __name__ == '__main__':
|
||||
end = datetime(2008, 1, 1, 0, 0, 0, 0, pytz.utc)
|
||||
data = load_from_yahoo(stocks=STOCKS, indexes={}, start=start, end=end)
|
||||
data = data.dropna()
|
||||
olmar = TradingAlgorithm(handle_data=handle_data, initialize=initialize)
|
||||
olmar = TradingAlgorithm(handle_data=handle_data,
|
||||
initialize=initialize,
|
||||
identifiers=STOCKS)
|
||||
results = olmar.run(data)
|
||||
results.portfolio_value.plot()
|
||||
pl.show()
|
||||
|
||||
@@ -23,6 +23,7 @@ import pytz
|
||||
from zipline.algorithm import TradingAlgorithm
|
||||
from zipline.transforms import batch_transform
|
||||
from zipline.utils.factory import load_from_yahoo
|
||||
from zipline.sources.data_frame_source import DataFrameSource
|
||||
|
||||
|
||||
@batch_transform
|
||||
@@ -59,11 +60,13 @@ class Pairtrade(TradingAlgorithm):
|
||||
self.window_length = window_length
|
||||
self.ols_transform = ols_transform(refresh_period=self.window_length,
|
||||
window_length=self.window_length)
|
||||
self.PEP = self.symbol('PEP')
|
||||
self.KO = self.symbol('KO')
|
||||
|
||||
def handle_data(self, data):
|
||||
######################################################
|
||||
# 1. Compute regression coefficients between PEP and KO
|
||||
params = self.ols_transform.handle_data(data, 'PEP', 'KO')
|
||||
params = self.ols_transform.handle_data(data, self.PEP, self.KO)
|
||||
if params is None:
|
||||
return
|
||||
intercept, slope = params
|
||||
@@ -81,7 +84,8 @@ class Pairtrade(TradingAlgorithm):
|
||||
"""1. Compute the spread given slope and intercept.
|
||||
2. zscore the spread.
|
||||
"""
|
||||
spread = (data['PEP'].price - (slope * data['KO'].price + intercept))
|
||||
spread = (data[self.PEP].price -
|
||||
(slope * data[self.KO].price + intercept))
|
||||
self.spreads.append(spread)
|
||||
spread_wind = self.spreads[-self.window_length:]
|
||||
zscore = (spread - np.mean(spread_wind)) / np.std(spread_wind)
|
||||
@@ -91,12 +95,12 @@ class Pairtrade(TradingAlgorithm):
|
||||
"""Buy spread if zscore is > 2, sell if zscore < .5.
|
||||
"""
|
||||
if zscore >= 2.0 and not self.invested:
|
||||
self.order('PEP', int(100 / data['PEP'].price))
|
||||
self.order('KO', -int(100 / data['KO'].price))
|
||||
self.order(self.PEP, int(100 / data[self.PEP].price))
|
||||
self.order(self.KO, -int(100 / data[self.KO].price))
|
||||
self.invested = True
|
||||
elif zscore <= -2.0 and not self.invested:
|
||||
self.order('PEP', -int(100 / data['PEP'].price))
|
||||
self.order('KO', int(100 / data['KO'].price))
|
||||
self.order(self.PEP, -int(100 / data[self.PEP].price))
|
||||
self.order(self.KO, int(100 / data[self.KO].price))
|
||||
self.invested = True
|
||||
elif abs(zscore) < .5 and self.invested:
|
||||
self.sell_spread()
|
||||
@@ -107,23 +111,25 @@ class Pairtrade(TradingAlgorithm):
|
||||
decrease exposure, regardless of position long/short.
|
||||
buy for a short position, sell for a long.
|
||||
"""
|
||||
ko_amount = self.portfolio.positions['KO'].amount
|
||||
self.order('KO', -1 * ko_amount)
|
||||
pep_amount = self.portfolio.positions['PEP'].amount
|
||||
self.order('PEP', -1 * pep_amount)
|
||||
ko_amount = self.portfolio.positions[self.KO].amount
|
||||
self.order(self.KO, -1 * ko_amount)
|
||||
pep_amount = self.portfolio.positions[self.PEP].amount
|
||||
self.order(self.PEP, -1 * pep_amount)
|
||||
|
||||
if __name__ == '__main__':
|
||||
start = datetime(2000, 1, 1, 0, 0, 0, 0, pytz.utc)
|
||||
end = datetime(2002, 1, 1, 0, 0, 0, 0, pytz.utc)
|
||||
data = load_from_yahoo(stocks=['PEP', 'KO'], indexes={},
|
||||
start=start, end=end)
|
||||
source = DataFrameSource(data)
|
||||
|
||||
pairtrade = Pairtrade()
|
||||
results = pairtrade.run(data)
|
||||
results = pairtrade.run(source)
|
||||
data['spreads'] = np.nan
|
||||
|
||||
ax1 = plt.subplot(211)
|
||||
data[['PEP', 'KO']].plot(ax=ax1)
|
||||
# TODO Bugged - indices are out of bounds
|
||||
# data[[pairtrade.PEPsid, pairtrade.KOsid]].plot(ax=ax1)
|
||||
plt.ylabel('price')
|
||||
plt.setp(ax1.get_xticklabels(), visible=False)
|
||||
|
||||
|
||||
@@ -19,15 +19,16 @@ import pytz
|
||||
from zipline import TradingAlgorithm
|
||||
from zipline.utils.factory import load_from_yahoo
|
||||
|
||||
from zipline.api import order
|
||||
from zipline.api import order, symbol
|
||||
|
||||
|
||||
def initialize(context):
|
||||
context.test = 10
|
||||
context.aapl = symbol('AAPL')
|
||||
|
||||
|
||||
def handle_date(context, data):
|
||||
order('AAPL', 10)
|
||||
order(context.aapl, 10)
|
||||
print(context.test)
|
||||
|
||||
|
||||
@@ -39,7 +40,8 @@ if __name__ == '__main__':
|
||||
end=end)
|
||||
data = data.dropna()
|
||||
algo = TradingAlgorithm(initialize=initialize,
|
||||
handle_data=handle_date)
|
||||
handle_data=handle_date,
|
||||
identifiers=['AAPL'])
|
||||
results = algo.run(data)
|
||||
results.portfolio_value.plot()
|
||||
pl.show()
|
||||
|
||||
+61
-36
@@ -37,7 +37,7 @@ class TradingControl(with_metaclass(abc.ABCMeta)):
|
||||
|
||||
@abc.abstractmethod
|
||||
def validate(self,
|
||||
sid,
|
||||
asset,
|
||||
amount,
|
||||
portfolio,
|
||||
algo_datetime,
|
||||
@@ -46,21 +46,22 @@ class TradingControl(with_metaclass(abc.ABCMeta)):
|
||||
Before any order is executed by TradingAlgorithm, this method should be
|
||||
called *exactly once* on each registered TradingControl object.
|
||||
|
||||
If the specified sid and amount do not violate this TradingControl's
|
||||
If the specified asset and amount do not violate this TradingControl's
|
||||
restraint given the information in `portfolio`, this method should
|
||||
return None and have no externally-visible side-effects.
|
||||
|
||||
If the desired order violates this TradingControl's contraint, this
|
||||
method should call self.fail(sid, amount).
|
||||
method should call self.fail(asset, amount).
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def fail(self, sid, amount):
|
||||
def fail(self, asset, amount, datetime):
|
||||
"""
|
||||
Raise a TradingControlViolation with information about the failure.
|
||||
"""
|
||||
raise TradingControlViolation(sid=sid,
|
||||
raise TradingControlViolation(asset=asset,
|
||||
amount=amount,
|
||||
datetime=datetime,
|
||||
constraint=repr(self))
|
||||
|
||||
def __repr__(self):
|
||||
@@ -82,7 +83,7 @@ class MaxOrderCount(TradingControl):
|
||||
self.current_date = None
|
||||
|
||||
def validate(self,
|
||||
sid,
|
||||
asset,
|
||||
amount,
|
||||
_portfolio,
|
||||
algo_datetime,
|
||||
@@ -98,13 +99,13 @@ class MaxOrderCount(TradingControl):
|
||||
self.current_date = algo_date
|
||||
|
||||
if self.orders_placed >= self.max_count:
|
||||
self.fail(sid, amount)
|
||||
self.fail(asset, amount, algo_datetime)
|
||||
self.orders_placed += 1
|
||||
|
||||
|
||||
class RestrictedListOrder(TradingControl):
|
||||
"""
|
||||
TradingControl representing a restricted list of securities that
|
||||
TradingControl representing a restricted list of assets that
|
||||
cannot be ordered by the algorithm.
|
||||
"""
|
||||
|
||||
@@ -119,30 +120,30 @@ class RestrictedListOrder(TradingControl):
|
||||
self.restricted_list = restricted_list
|
||||
|
||||
def validate(self,
|
||||
sid,
|
||||
asset,
|
||||
amount,
|
||||
_portfolio,
|
||||
_algo_datetime,
|
||||
_algo_current_data):
|
||||
"""
|
||||
Fail if the sid is in the restricted_list.
|
||||
Fail if the asset is in the restricted_list.
|
||||
"""
|
||||
if sid in self.restricted_list:
|
||||
self.fail(sid, amount)
|
||||
if asset in self.restricted_list:
|
||||
self.fail(asset, amount, _algo_datetime)
|
||||
|
||||
|
||||
class MaxOrderSize(TradingControl):
|
||||
"""
|
||||
TradingControl representing a limit on the magnitude of any single order
|
||||
placed with the given security. Can be specified by share or by dollar
|
||||
placed with the given asset. Can be specified by share or by dollar
|
||||
value.
|
||||
"""
|
||||
|
||||
def __init__(self, sid=None, max_shares=None, max_notional=None):
|
||||
super(MaxOrderSize, self).__init__(sid=sid,
|
||||
def __init__(self, asset=None, max_shares=None, max_notional=None):
|
||||
super(MaxOrderSize, self).__init__(asset=asset,
|
||||
max_shares=max_shares,
|
||||
max_notional=max_notional)
|
||||
self.sid = sid
|
||||
self.asset = asset
|
||||
self.max_shares = max_shares
|
||||
self.max_notional = max_notional
|
||||
|
||||
@@ -162,7 +163,7 @@ class MaxOrderSize(TradingControl):
|
||||
)
|
||||
|
||||
def validate(self,
|
||||
sid,
|
||||
asset,
|
||||
amount,
|
||||
portfolio,
|
||||
_algo_datetime,
|
||||
@@ -172,33 +173,33 @@ class MaxOrderSize(TradingControl):
|
||||
or self.max_notional.
|
||||
"""
|
||||
|
||||
if self.sid is not None and self.sid != sid:
|
||||
if self.asset is not None and self.asset != asset:
|
||||
return
|
||||
|
||||
if self.max_shares is not None and abs(amount) > self.max_shares:
|
||||
self.fail(sid, amount)
|
||||
self.fail(asset, amount, _algo_datetime)
|
||||
|
||||
current_sid_price = algo_current_data[sid].price
|
||||
order_value = amount * current_sid_price
|
||||
current_asset_price = algo_current_data[asset].price
|
||||
order_value = amount * current_asset_price
|
||||
|
||||
too_much_value = (self.max_notional is not None and
|
||||
abs(order_value) > self.max_notional)
|
||||
|
||||
if too_much_value:
|
||||
self.fail(sid, amount)
|
||||
self.fail(asset, amount, _algo_datetime)
|
||||
|
||||
|
||||
class MaxPositionSize(TradingControl):
|
||||
"""
|
||||
TradingControl representing a limit on the maximum position size that can
|
||||
be held by an algo for a given security.
|
||||
be held by an algo for a given asset.
|
||||
"""
|
||||
|
||||
def __init__(self, sid=None, max_shares=None, max_notional=None):
|
||||
super(MaxPositionSize, self).__init__(sid=sid,
|
||||
def __init__(self, asset=None, max_shares=None, max_notional=None):
|
||||
super(MaxPositionSize, self).__init__(asset=asset,
|
||||
max_shares=max_shares,
|
||||
max_notional=max_notional)
|
||||
self.sid = sid
|
||||
self.asset = asset
|
||||
self.max_shares = max_shares
|
||||
self.max_notional = max_notional
|
||||
|
||||
@@ -218,7 +219,7 @@ class MaxPositionSize(TradingControl):
|
||||
)
|
||||
|
||||
def validate(self,
|
||||
sid,
|
||||
asset,
|
||||
amount,
|
||||
portfolio,
|
||||
algo_datetime,
|
||||
@@ -229,25 +230,25 @@ class MaxPositionSize(TradingControl):
|
||||
self.max_notional.
|
||||
"""
|
||||
|
||||
if self.sid is not None and self.sid != sid:
|
||||
if self.asset is not None and self.asset != asset:
|
||||
return
|
||||
|
||||
current_share_count = portfolio.positions[sid].amount
|
||||
current_share_count = portfolio.positions[asset].amount
|
||||
shares_post_order = current_share_count + amount
|
||||
|
||||
too_many_shares = (self.max_shares is not None and
|
||||
abs(shares_post_order) > self.max_shares)
|
||||
if too_many_shares:
|
||||
self.fail(sid, amount)
|
||||
self.fail(asset, amount, algo_datetime)
|
||||
|
||||
current_price = algo_current_data[sid].price
|
||||
current_price = algo_current_data[asset].price
|
||||
value_post_order = shares_post_order * current_price
|
||||
|
||||
too_much_value = (self.max_notional is not None and
|
||||
abs(value_post_order) > self.max_notional)
|
||||
|
||||
if too_much_value:
|
||||
self.fail(sid, amount)
|
||||
self.fail(asset, amount, algo_datetime)
|
||||
|
||||
|
||||
class LongOnly(TradingControl):
|
||||
@@ -256,17 +257,41 @@ class LongOnly(TradingControl):
|
||||
"""
|
||||
|
||||
def validate(self,
|
||||
sid,
|
||||
asset,
|
||||
amount,
|
||||
portfolio,
|
||||
_algo_datetime,
|
||||
_algo_current_data):
|
||||
"""
|
||||
Fail if we would hold negative shares of sid after completing this
|
||||
Fail if we would hold negative shares of asset after completing this
|
||||
order.
|
||||
"""
|
||||
if portfolio.positions[sid].amount + amount < 0:
|
||||
self.fail(sid, amount)
|
||||
if portfolio.positions[asset].amount + amount < 0:
|
||||
self.fail(asset, amount, _algo_datetime)
|
||||
|
||||
|
||||
class AssetDateBounds(TradingControl):
|
||||
"""
|
||||
TradingControl representing a prohibition against ordering an asset before
|
||||
its start_date, or after its end_date.
|
||||
"""
|
||||
|
||||
def validate(self,
|
||||
asset,
|
||||
amount,
|
||||
portfolio,
|
||||
algo_datetime,
|
||||
algo_current_data):
|
||||
"""
|
||||
Fail if the algo has passed this Asset's end_date, or before the
|
||||
Asset's start date.
|
||||
"""
|
||||
# Fail if the algo is before this Asset's start_date
|
||||
if asset.start_date and (algo_datetime < asset.start_date):
|
||||
self.fail(asset, amount, algo_datetime)
|
||||
# Fail if the algo has passed this Asset's end_date
|
||||
if asset.end_date and (algo_datetime >= asset.end_date):
|
||||
self.fail(asset, amount, algo_datetime)
|
||||
|
||||
|
||||
class AccountControl(with_metaclass(abc.ABCMeta)):
|
||||
|
||||
@@ -31,7 +31,7 @@ omitted).
|
||||
| | end of the period |
|
||||
+---------------+------------------------------------------------------+
|
||||
| cash_flow | the cash flow in the period (negative means spent) |
|
||||
| | from buying and selling securities in the period. |
|
||||
| | from buying and selling assets in the period. |
|
||||
| | Includes dividend payments in the period as well. |
|
||||
+---------------+------------------------------------------------------+
|
||||
| starting_value| the total market value of the positions held at the |
|
||||
@@ -75,6 +75,9 @@ import logbook
|
||||
|
||||
import numpy as np
|
||||
|
||||
from zipline.finance.trading import TradingEnvironment
|
||||
from zipline.assets import Future
|
||||
|
||||
try:
|
||||
# optional cython based OrderedDict
|
||||
from cyordereddict import OrderedDict
|
||||
@@ -128,6 +131,7 @@ class PerformancePeriod(object):
|
||||
self.period_close = period_close
|
||||
|
||||
self.ending_value = 0.0
|
||||
self.ending_exposure = 0.0
|
||||
self.period_cash_flow = 0.0
|
||||
self.pnl = 0.0
|
||||
|
||||
@@ -160,6 +164,7 @@ class PerformancePeriod(object):
|
||||
|
||||
def rollover(self):
|
||||
self.starting_value = self.ending_value
|
||||
self.starting_exposure = self.ending_exposure
|
||||
self.starting_cash = self.ending_cash
|
||||
self.period_cash_flow = 0.0
|
||||
self.pnl = 0.0
|
||||
@@ -187,6 +192,7 @@ class PerformancePeriod(object):
|
||||
|
||||
def calculate_performance(self):
|
||||
self.ending_value = self.calculate_positions_value()
|
||||
self.ending_exposure = self.calculate_positions_exposure()
|
||||
|
||||
total_at_start = self.starting_cash + self.starting_value
|
||||
self.ending_cash = self.starting_cash + self.period_cash_flow
|
||||
@@ -215,7 +221,12 @@ class PerformancePeriod(object):
|
||||
self.orders_by_id[order.id] = order
|
||||
|
||||
def handle_execution(self, txn):
|
||||
self.period_cash_flow -= txn.price * txn.amount
|
||||
asset = TradingEnvironment.instance().asset_finder.\
|
||||
retrieve_asset(txn.sid)
|
||||
|
||||
# Futures experience no cash flow on transactions
|
||||
if not isinstance(asset, Future):
|
||||
self.period_cash_flow -= txn.price * txn.amount
|
||||
|
||||
if self.keep_transactions:
|
||||
try:
|
||||
@@ -232,6 +243,10 @@ class PerformancePeriod(object):
|
||||
def position_amounts(self):
|
||||
return self.position_tracker.position_amounts
|
||||
|
||||
@position_proxy
|
||||
def calculate_positions_exposure(self):
|
||||
raise ProxyError()
|
||||
|
||||
@position_proxy
|
||||
def calculate_positions_value(self):
|
||||
raise ProxyError()
|
||||
@@ -244,6 +259,10 @@ class PerformancePeriod(object):
|
||||
def _long_exposure(self):
|
||||
raise ProxyError()
|
||||
|
||||
@position_proxy
|
||||
def _long_value(self):
|
||||
raise ProxyError()
|
||||
|
||||
@position_proxy
|
||||
def _shorts_count(self):
|
||||
raise ProxyError()
|
||||
@@ -252,19 +271,29 @@ class PerformancePeriod(object):
|
||||
def _short_exposure(self):
|
||||
raise ProxyError()
|
||||
|
||||
@position_proxy
|
||||
def _short_value(self):
|
||||
raise ProxyError()
|
||||
|
||||
@position_proxy
|
||||
def _gross_exposure(self):
|
||||
raise ProxyError()
|
||||
|
||||
@position_proxy
|
||||
def _gross_value(self):
|
||||
raise ProxyError()
|
||||
|
||||
@position_proxy
|
||||
def _net_exposure(self):
|
||||
raise ProxyError()
|
||||
|
||||
@position_proxy
|
||||
def _net_value(self):
|
||||
raise ProxyError()
|
||||
|
||||
@property
|
||||
def _net_liquidation_value(self):
|
||||
return self.ending_cash + \
|
||||
self._long_exposure() + \
|
||||
self._short_exposure()
|
||||
return self.ending_cash + self._long_value() + self._short_value()
|
||||
|
||||
def _gross_leverage(self):
|
||||
net_liq = self._net_liquidation_value
|
||||
@@ -283,10 +312,12 @@ class PerformancePeriod(object):
|
||||
def __core_dict(self):
|
||||
rval = {
|
||||
'ending_value': self.ending_value,
|
||||
'ending_exposure': self.ending_exposure,
|
||||
# this field is renamed to capital_used for backward
|
||||
# compatibility.
|
||||
'capital_used': self.period_cash_flow,
|
||||
'starting_value': self.starting_value,
|
||||
'starting_exposure': self.starting_exposure,
|
||||
'starting_cash': self.starting_cash,
|
||||
'ending_cash': self.ending_cash,
|
||||
'portfolio_value': self.ending_cash + self.ending_value,
|
||||
@@ -298,6 +329,8 @@ class PerformancePeriod(object):
|
||||
'net_leverage': self._net_leverage(),
|
||||
'short_exposure': self._short_exposure(),
|
||||
'long_exposure': self._long_exposure(),
|
||||
'short_value': self._short_value(),
|
||||
'long_value': self._long_value(),
|
||||
'longs_count': self._longs_count(),
|
||||
'shorts_count': self._shorts_count()
|
||||
}
|
||||
@@ -372,6 +405,7 @@ class PerformancePeriod(object):
|
||||
portfolio.start_date = self.period_open
|
||||
portfolio.positions = self.get_positions()
|
||||
portfolio.positions_value = self.ending_value
|
||||
portfolio.positions_exposure = self.ending_exposure
|
||||
return portfolio
|
||||
|
||||
def as_account(self):
|
||||
@@ -393,6 +427,8 @@ class PerformancePeriod(object):
|
||||
self.ending_cash + self.ending_value)
|
||||
account.total_positions_value = \
|
||||
getattr(self, 'total_positions_value', self.ending_value)
|
||||
account.total_positions_value = \
|
||||
getattr(self, 'total_positions_exposure', self.ending_exposure)
|
||||
account.regt_equity = \
|
||||
getattr(self, 'regt_equity', self.ending_cash)
|
||||
account.regt_margin = \
|
||||
|
||||
@@ -20,12 +20,11 @@ Position Tracking
|
||||
+-----------------+----------------------------------------------------+
|
||||
| key | value |
|
||||
+=================+====================================================+
|
||||
| sid | the identifier for the security held in this |
|
||||
| | position. |
|
||||
| sid | the sid for the asset held in this position |
|
||||
+-----------------+----------------------------------------------------+
|
||||
| amount | whole number of shares in the position |
|
||||
+-----------------+----------------------------------------------------+
|
||||
| last_sale_price | price at last sale of the security on the exchange |
|
||||
| last_sale_price | price at last sale of the asset on the exchange |
|
||||
+-----------------+----------------------------------------------------+
|
||||
| cost_basis | the volume weighted average price paid per share |
|
||||
+-----------------+----------------------------------------------------+
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
from __future__ import division
|
||||
from operator import mul
|
||||
|
||||
import logbook
|
||||
import numpy as np
|
||||
@@ -10,8 +9,7 @@ try:
|
||||
from cyordereddict import OrderedDict
|
||||
except ImportError:
|
||||
from collections import OrderedDict
|
||||
from six import iteritems
|
||||
from six.moves import map, filter
|
||||
from six import iteritems, itervalues
|
||||
|
||||
from zipline.finance.slippage import Transaction
|
||||
from zipline.utils.serialization_utils import (
|
||||
@@ -19,6 +17,10 @@ from zipline.utils.serialization_utils import (
|
||||
)
|
||||
|
||||
import zipline.protocol as zp
|
||||
from zipline.assets import (
|
||||
Equity, Future
|
||||
)
|
||||
from zipline.finance.trading import with_environment
|
||||
from . position import positiondict
|
||||
|
||||
log = logbook.Logger('Performance')
|
||||
@@ -32,26 +34,69 @@ class PositionTracker(object):
|
||||
# Arrays for quick calculations of positions value
|
||||
self._position_amounts = OrderedDict()
|
||||
self._position_last_sale_prices = OrderedDict()
|
||||
self._position_value_multipliers = OrderedDict()
|
||||
self._position_exposure_multipliers = OrderedDict()
|
||||
self._unpaid_dividends = pd.DataFrame(
|
||||
columns=zp.DIVIDEND_PAYMENT_FIELDS,
|
||||
)
|
||||
self._positions_store = zp.Positions()
|
||||
|
||||
# Cached for fast property calculation
|
||||
self._position_values = None
|
||||
self._position_exposures = None
|
||||
|
||||
def _invalidate_cache(self):
|
||||
self._position_values = None
|
||||
self._position_exposures = None
|
||||
|
||||
@with_environment()
|
||||
def _retrieve_asset(self, sid, env=None):
|
||||
return env.asset_finder.retrieve_asset(sid)
|
||||
|
||||
def _update_multipliers(self, sid):
|
||||
try:
|
||||
self._position_value_multipliers[sid]
|
||||
self._position_exposure_multipliers[sid]
|
||||
except KeyError:
|
||||
# Collect the value multipliers from applicable sids
|
||||
asset = self._retrieve_asset(sid)
|
||||
if isinstance(asset, Equity):
|
||||
self._position_value_multipliers[sid] = 1
|
||||
self._position_exposure_multipliers[sid] = 1
|
||||
if isinstance(asset, Future):
|
||||
self._position_value_multipliers[sid] = 0
|
||||
self._position_exposure_multipliers[sid] = \
|
||||
asset.contract_multiplier
|
||||
|
||||
def update_last_sale(self, event):
|
||||
# NOTE, PerformanceTracker already vetted as TRADE type
|
||||
sid = event.sid
|
||||
if sid not in self.positions:
|
||||
return
|
||||
return 0
|
||||
|
||||
price = event.price
|
||||
if not checknull(price):
|
||||
pos = self.positions[sid]
|
||||
pos.last_sale_date = event.dt
|
||||
pos.last_sale_price = price
|
||||
self._position_last_sale_prices[sid] = price
|
||||
self._position_values = None # invalidate cache
|
||||
sid = event.sid
|
||||
price = event.price
|
||||
|
||||
if checknull(price):
|
||||
return 0
|
||||
|
||||
pos = self.positions[sid]
|
||||
old_price = pos.last_sale_price
|
||||
pos.last_sale_date = event.dt
|
||||
pos.last_sale_price = price
|
||||
self._position_last_sale_prices[sid] = price
|
||||
self._invalidate_cache()
|
||||
|
||||
asset = self._retrieve_asset(sid)
|
||||
if asset is None:
|
||||
return 0
|
||||
|
||||
# Calculate cash adjustment on futures
|
||||
cash_adjustment = 0
|
||||
if isinstance(asset, Future):
|
||||
price_change = price - old_price
|
||||
cash_adjustment = \
|
||||
price_change * asset.contract_multiplier * pos.amount
|
||||
return cash_adjustment
|
||||
|
||||
def update_positions(self, positions):
|
||||
# update positions in batch
|
||||
@@ -59,8 +104,8 @@ class PositionTracker(object):
|
||||
for sid, pos in iteritems(positions):
|
||||
self._position_amounts[sid] = pos.amount
|
||||
self._position_last_sale_prices[sid] = pos.last_sale_price
|
||||
# Invalidate cache.
|
||||
self._position_values = None # invalidate cache
|
||||
self._update_multipliers(sid)
|
||||
self._invalidate_cache()
|
||||
|
||||
def update_position(self, sid, amount=None, last_sale_price=None,
|
||||
last_sale_date=None, cost_basis=None):
|
||||
@@ -70,6 +115,7 @@ class PositionTracker(object):
|
||||
pos.amount = amount
|
||||
self._position_amounts[sid] = amount
|
||||
self._position_values = None # invalidate cache
|
||||
self._update_multipliers(sid=sid)
|
||||
if last_sale_price is not None:
|
||||
pos.last_sale_price = last_sale_price
|
||||
self._position_last_sale_prices[sid] = last_sale_price
|
||||
@@ -82,13 +128,13 @@ class PositionTracker(object):
|
||||
def execute_transaction(self, txn):
|
||||
# Update Position
|
||||
# ----------------
|
||||
|
||||
sid = txn.sid
|
||||
position = self.positions[sid]
|
||||
position.update(txn)
|
||||
self._position_amounts[sid] = position.amount
|
||||
self._position_last_sale_prices[sid] = position.last_sale_price
|
||||
self._position_values = None # invalidate cache
|
||||
self._update_multipliers(sid)
|
||||
self._invalidate_cache()
|
||||
|
||||
def handle_commission(self, commission):
|
||||
# Adjust the cost basis of the stock if we own it
|
||||
@@ -96,8 +142,6 @@ class PositionTracker(object):
|
||||
self.positions[commission.sid].\
|
||||
adjust_commission_cost_basis(commission)
|
||||
|
||||
_position_values = None
|
||||
|
||||
@property
|
||||
def position_values(self):
|
||||
"""
|
||||
@@ -105,33 +149,75 @@ class PositionTracker(object):
|
||||
self._position_last_sale_prices is changed.
|
||||
"""
|
||||
if self._position_values is None:
|
||||
vals = list(map(mul, self._position_amounts.values(),
|
||||
self._position_last_sale_prices.values()))
|
||||
self._position_values = vals
|
||||
iter_amount_price_multiplier = zip(
|
||||
itervalues(self._position_amounts),
|
||||
itervalues(self._position_last_sale_prices),
|
||||
itervalues(self._position_value_multipliers),
|
||||
)
|
||||
self._position_values = [
|
||||
price * amount * multiplier for
|
||||
price, amount, multiplier in iter_amount_price_multiplier
|
||||
]
|
||||
return self._position_values
|
||||
|
||||
@property
|
||||
def position_exposures(self):
|
||||
"""
|
||||
Invalidate any time self._position_amounts or
|
||||
self._position_last_sale_prices is changed.
|
||||
"""
|
||||
if self._position_exposures is None:
|
||||
iter_amount_price_multiplier = zip(
|
||||
itervalues(self._position_amounts),
|
||||
itervalues(self._position_last_sale_prices),
|
||||
itervalues(self._position_exposure_multipliers),
|
||||
)
|
||||
self._position_exposures = [
|
||||
price * amount * multiplier for
|
||||
price, amount, multiplier in iter_amount_price_multiplier
|
||||
]
|
||||
return self._position_exposures
|
||||
|
||||
def calculate_positions_value(self):
|
||||
if len(self.position_values) == 0:
|
||||
return np.float64(0)
|
||||
|
||||
return sum(self.position_values)
|
||||
|
||||
def calculate_positions_exposure(self):
|
||||
if len(self.position_exposures) == 0:
|
||||
return np.float64(0)
|
||||
|
||||
return sum(self.position_exposures)
|
||||
|
||||
def _longs_count(self):
|
||||
return sum(map(lambda x: x > 0, self.position_values))
|
||||
return sum(1 for i in self.position_exposures if i > 0)
|
||||
|
||||
def _long_exposure(self):
|
||||
return sum(filter(lambda x: x > 0, self.position_values))
|
||||
return sum(i for i in self.position_exposures if i > 0)
|
||||
|
||||
def _long_value(self):
|
||||
return sum(i for i in self.position_values if i > 0)
|
||||
|
||||
def _shorts_count(self):
|
||||
return sum(map(lambda x: x < 0, self.position_values))
|
||||
return sum(1 for i in self.position_exposures if i < 0)
|
||||
|
||||
def _short_exposure(self):
|
||||
return sum(filter(lambda x: x < 0, self.position_values))
|
||||
return sum(i for i in self.position_exposures if i < 0)
|
||||
|
||||
def _short_value(self):
|
||||
return sum(i for i in self.position_values if i < 0)
|
||||
|
||||
def _gross_exposure(self):
|
||||
return self._long_exposure() + abs(self._short_exposure())
|
||||
|
||||
def _gross_value(self):
|
||||
return self._long_value() + abs(self._short_value())
|
||||
|
||||
def _net_exposure(self):
|
||||
return self.calculate_positions_exposure()
|
||||
|
||||
def _net_value(self):
|
||||
return self.calculate_positions_value()
|
||||
|
||||
def handle_split(self, split):
|
||||
@@ -143,7 +229,8 @@ class PositionTracker(object):
|
||||
self._position_amounts[split.sid] = position.amount
|
||||
self._position_last_sale_prices[split.sid] = \
|
||||
position.last_sale_price
|
||||
self._position_values = None # invalidate cache
|
||||
self._update_multipliers(split.sid)
|
||||
self._invalidate_cache()
|
||||
return leftover_cash
|
||||
|
||||
def _maybe_earn_dividend(self, dividend):
|
||||
@@ -204,15 +291,17 @@ class PositionTracker(object):
|
||||
stock = row['payment_sid']
|
||||
share_count = row['share_count']
|
||||
# note we create a Position for stock dividend if we don't
|
||||
# already own the security
|
||||
# already own the asset
|
||||
position = self.positions[stock]
|
||||
|
||||
position.amount += share_count
|
||||
self._position_amounts[stock] = position.amount
|
||||
self._position_last_sale_prices[stock] = position.last_sale_price
|
||||
self._update_multipliers(stock)
|
||||
self._invalidate_cache()
|
||||
|
||||
# Add cash equal to the net cash payed from all dividends. Note that
|
||||
# "negative cash" is effectively paid if we're short a security,
|
||||
# "negative cash" is effectively paid if we're short an asset,
|
||||
# representing the fact that we're required to reimburse the owner of
|
||||
# the stock for any dividends paid while borrowing.
|
||||
net_cash_payment = payments['cash_amount'].fillna(0).sum()
|
||||
@@ -290,5 +379,8 @@ class PositionTracker(object):
|
||||
# Arrays for quick calculations of positions value
|
||||
self._position_amounts = OrderedDict()
|
||||
self._position_last_sale_prices = OrderedDict()
|
||||
self._position_value_multipliers = OrderedDict()
|
||||
self._position_exposure_multipliers = OrderedDict()
|
||||
self._invalidate_cache()
|
||||
|
||||
self.update_positions(state['positions'])
|
||||
|
||||
@@ -68,10 +68,9 @@ import pandas as pd
|
||||
from pandas.tseries.tools import normalize_date
|
||||
|
||||
import zipline.finance.risk as risk
|
||||
from zipline.finance import trading
|
||||
from zipline.finance.trading import TradingEnvironment
|
||||
from . period import PerformancePeriod
|
||||
|
||||
from zipline.finance.trading import with_environment
|
||||
from zipline.utils.serialization_utils import (
|
||||
VERSION_LABEL
|
||||
)
|
||||
@@ -84,26 +83,23 @@ class PerformanceTracker(object):
|
||||
"""
|
||||
Tracks the performance of the algorithm.
|
||||
"""
|
||||
|
||||
@with_environment()
|
||||
def __init__(self, sim_params, env=None):
|
||||
def __init__(self, sim_params):
|
||||
|
||||
self.sim_params = sim_params
|
||||
env = TradingEnvironment.instance()
|
||||
|
||||
self.period_start = self.sim_params.period_start
|
||||
self.period_end = self.sim_params.period_end
|
||||
self.last_close = self.sim_params.last_close
|
||||
first_open = self.sim_params.first_open.tz_convert(
|
||||
trading.environment.exchange_tz)
|
||||
first_open = self.sim_params.first_open.tz_convert(env.exchange_tz)
|
||||
self.day = pd.Timestamp(datetime(first_open.year, first_open.month,
|
||||
first_open.day), tz='UTC')
|
||||
self.market_open, self.market_close = \
|
||||
trading.environment.get_open_and_close(self.day)
|
||||
self.market_open, self.market_close = env.get_open_and_close(self.day)
|
||||
self.total_days = self.sim_params.days_in_period
|
||||
self.capital_base = self.sim_params.capital_base
|
||||
self.emission_rate = sim_params.emission_rate
|
||||
|
||||
all_trading_days = trading.environment.trading_days
|
||||
all_trading_days = env.trading_days
|
||||
mask = ((all_trading_days >= normalize_date(self.period_start)) &
|
||||
(all_trading_days <= normalize_date(self.period_end)))
|
||||
|
||||
@@ -287,10 +283,13 @@ class PerformanceTracker(object):
|
||||
return _dict
|
||||
|
||||
def process_trade(self, event):
|
||||
self.position_tracker.update_last_sale(event)
|
||||
# update last sale, and pay out a cash adjustment
|
||||
cash_adjustment = self.position_tracker.update_last_sale(event)
|
||||
if cash_adjustment != 0:
|
||||
for perf_period in self.perf_periods:
|
||||
perf_period.handle_cash_payment(cash_adjustment)
|
||||
|
||||
def process_transaction(self, event):
|
||||
|
||||
self.txn_count += 1
|
||||
self.position_tracker.execute_transaction(event)
|
||||
for perf_period in self.perf_periods:
|
||||
@@ -342,7 +341,7 @@ class PerformanceTracker(object):
|
||||
if txn:
|
||||
self.process_transaction(txn)
|
||||
|
||||
def check_upcoming_dividends(self, midnight_of_date_that_just_ended):
|
||||
def check_upcoming_dividends(self, next_trading_day):
|
||||
"""
|
||||
Check if we currently own any stocks with dividends whose ex_date is
|
||||
the next trading day. Track how much we should be payed on those
|
||||
@@ -357,17 +356,6 @@ class PerformanceTracker(object):
|
||||
# period, so bail.
|
||||
return
|
||||
|
||||
next_trading_day_idx = self.trading_days.get_loc(
|
||||
midnight_of_date_that_just_ended,
|
||||
) + 1
|
||||
|
||||
if next_trading_day_idx < len(self.trading_days):
|
||||
next_trading_day = self.trading_days[next_trading_day_idx]
|
||||
else:
|
||||
# Bail if the next trading day is outside our trading range, since
|
||||
# we won't simulate the next day.
|
||||
return
|
||||
|
||||
# Dividends whose ex_date is the next trading day. We need to check if
|
||||
# we own any of these stocks so we know to pay them out when the pay
|
||||
# date comes.
|
||||
@@ -412,7 +400,10 @@ class PerformanceTracker(object):
|
||||
# if this is the close, save the returns objects for cumulative risk
|
||||
# calculations and update dividends for the next day.
|
||||
if dt == self.market_close:
|
||||
self.check_upcoming_dividends(todays_date)
|
||||
next_trading_day = TradingEnvironment.instance().\
|
||||
next_trading_day(todays_date)
|
||||
if next_trading_day:
|
||||
self.check_upcoming_dividends(next_trading_day)
|
||||
|
||||
def handle_intraday_market_close(self, new_mkt_open, new_mkt_close):
|
||||
"""
|
||||
@@ -454,16 +445,19 @@ class PerformanceTracker(object):
|
||||
return daily_update
|
||||
|
||||
# move the market day markers forward
|
||||
env = TradingEnvironment.instance()
|
||||
self.market_open, self.market_close = \
|
||||
trading.environment.next_open_and_close(self.day)
|
||||
self.day = trading.environment.next_trading_day(self.day)
|
||||
env.next_open_and_close(self.day)
|
||||
self.day = env.next_trading_day(self.day)
|
||||
|
||||
# Roll over positions to current day.
|
||||
self.todays_performance.rollover()
|
||||
self.todays_performance.period_open = self.market_open
|
||||
self.todays_performance.period_close = self.market_close
|
||||
|
||||
self.check_upcoming_dividends(completed_date)
|
||||
next_trading_day = env.next_trading_day(completed_date)
|
||||
if next_trading_day:
|
||||
self.check_upcoming_dividends(next_trading_day)
|
||||
|
||||
return daily_update
|
||||
|
||||
|
||||
@@ -23,6 +23,8 @@ import numpy as np
|
||||
|
||||
from zipline.data.loader import load_market_data
|
||||
from zipline.utils import tradingcalendar
|
||||
from zipline.assets import AssetFinder
|
||||
from zipline.errors import UpdateAssetFinderTypeError
|
||||
|
||||
|
||||
log = logbook.Logger('Trading')
|
||||
@@ -134,6 +136,8 @@ class TradingEnvironment(object):
|
||||
|
||||
self.exchange_tz = exchange_tz
|
||||
|
||||
self.asset_finder = AssetFinder(trading_calendar=env_trading_calendar)
|
||||
|
||||
def __enter__(self, *args, **kwargs):
|
||||
global environment
|
||||
self.prev_environment = environment
|
||||
@@ -149,6 +153,52 @@ class TradingEnvironment(object):
|
||||
# stack.
|
||||
return False
|
||||
|
||||
def update_asset_finder(self,
|
||||
clear_metadata=False,
|
||||
asset_finder=None,
|
||||
asset_metadata=None,
|
||||
identifiers=None):
|
||||
"""
|
||||
Updates the AssetFinder using the provided asset metadata and
|
||||
identifiers.
|
||||
If clear_metadata is True, all metadata and assets held in the
|
||||
asset_finder will be erased before new metadata is provided.
|
||||
If asset_finder is provided, the existing asset_finder will be replaced
|
||||
outright with the new asset_finder.
|
||||
If asset_metadata is provided, the existing metadata will be cleared
|
||||
and replaced with the provided metadata.
|
||||
All identifiers will be inserted in the asset metadata if they are not
|
||||
already present.
|
||||
|
||||
:param clear_metadata: A boolean
|
||||
:param asset_finder: An AssetFinder object to replace the environment's
|
||||
existing asset_finder
|
||||
:param asset_metadata: A dict, DataFrame, or readable object
|
||||
:param identifiers: A list of identifiers to be inserted
|
||||
:return:
|
||||
"""
|
||||
populate = False
|
||||
if clear_metadata:
|
||||
self.asset_finder.clear_metadata()
|
||||
populate = True
|
||||
|
||||
if asset_finder is not None:
|
||||
if not isinstance(asset_finder, AssetFinder):
|
||||
raise UpdateAssetFinderTypeError(cls=asset_finder.__class__)
|
||||
self.asset_finder = asset_finder
|
||||
|
||||
if asset_metadata is not None:
|
||||
self.asset_finder.clear_metadata()
|
||||
self.asset_finder.consume_metadata(asset_metadata)
|
||||
populate = True
|
||||
|
||||
if identifiers is not None:
|
||||
self.asset_finder.consume_identifiers(identifiers)
|
||||
populate = True
|
||||
|
||||
if populate:
|
||||
self.asset_finder.populate_cache()
|
||||
|
||||
def normalize_date(self, test_date):
|
||||
test_date = pd.Timestamp(test_date, tz='UTC')
|
||||
return pd.tseries.tools.normalize_date(test_date)
|
||||
|
||||
@@ -217,7 +217,7 @@ class HistoryContainer(object):
|
||||
history_specs (dict[Frequency:HistorySpec]): The starting history
|
||||
specs that this container should be able to service.
|
||||
|
||||
initial_sids (set[Security or Int]): The starting sids to watch.
|
||||
initial_sids (set[Asset or Int]): The starting sids to watch.
|
||||
|
||||
initial_dt (datetime): The datetime to start collecting history from.
|
||||
|
||||
|
||||
+7
-3
@@ -58,7 +58,12 @@ DIVIDEND_FIELDS = [
|
||||
'sid',
|
||||
]
|
||||
# Expected fields/index values for a dividend payment Series.
|
||||
DIVIDEND_PAYMENT_FIELDS = ['id', 'payment_sid', 'cash_amount', 'share_count']
|
||||
DIVIDEND_PAYMENT_FIELDS = [
|
||||
'id',
|
||||
'payment_sid',
|
||||
'cash_amount',
|
||||
'share_count',
|
||||
]
|
||||
|
||||
|
||||
def dividend_payment(data=None):
|
||||
@@ -73,7 +78,7 @@ def dividend_payment(data=None):
|
||||
payment.
|
||||
|
||||
Additionally, if @data is non-empty, either data['cash_amount'] should be
|
||||
nonzero or data['payment_sid'] should be a security identifier and
|
||||
nonzero or data['payment_sid'] should be an asset identifier and
|
||||
data['share_count'] should be nonzero.
|
||||
|
||||
The returned Series is given its id value as a name so that concatenating
|
||||
@@ -289,7 +294,6 @@ class SIDData(object):
|
||||
|
||||
def __init__(self, sid, initial_values=None):
|
||||
self._sid = sid
|
||||
|
||||
self._freqstr = None
|
||||
|
||||
# To check if we have data, we use the __len__ which depends on the
|
||||
|
||||
@@ -22,6 +22,7 @@ import pandas as pd
|
||||
from zipline.gens.utils import hash_args
|
||||
|
||||
from zipline.sources.data_source import DataSource
|
||||
from zipline.finance.trading import with_environment
|
||||
|
||||
|
||||
class DataFrameSource(DataSource):
|
||||
@@ -36,14 +37,23 @@ class DataFrameSource(DataSource):
|
||||
Bars where the price is nan are filtered out.
|
||||
"""
|
||||
|
||||
def __init__(self, data, **kwargs):
|
||||
@with_environment()
|
||||
def __init__(self, data, env=None, **kwargs):
|
||||
assert isinstance(data.index, pd.tseries.index.DatetimeIndex)
|
||||
|
||||
self.data = data
|
||||
self.data = data.fillna(method='ffill')
|
||||
# Unpack config dictionary with default values.
|
||||
self.sids = kwargs.get('sids', data.columns)
|
||||
self.start = kwargs.get('start', data.index[0])
|
||||
self.end = kwargs.get('end', data.index[-1])
|
||||
self.start = kwargs.get('start', self.data.index[0])
|
||||
self.end = kwargs.get('end', self.data.index[-1])
|
||||
|
||||
# Remap sids based on the trading environment
|
||||
self.identifiers = kwargs.get('sids', self.data.columns)
|
||||
env.update_asset_finder(identifiers=self.identifiers)
|
||||
self.data.columns = [
|
||||
env.asset_finder.retrieve_asset_by_identifier(identifier).sid
|
||||
for identifier in self.data.columns
|
||||
]
|
||||
self.sids = self.data.columns
|
||||
|
||||
# Hash_value for downstream sorting.
|
||||
self.arg_string = hash_args(data, **kwargs)
|
||||
@@ -105,14 +115,24 @@ class DataPanelSource(DataSource):
|
||||
Bars where the price is nan are filtered out.
|
||||
"""
|
||||
|
||||
def __init__(self, data, **kwargs):
|
||||
@with_environment()
|
||||
def __init__(self, data, env=None, **kwargs):
|
||||
assert isinstance(data.major_axis, pd.tseries.index.DatetimeIndex)
|
||||
|
||||
self.data = data
|
||||
self.data = data.fillna(value={'volume': 0})
|
||||
self.data = self.data.fillna(method='ffill')
|
||||
# Unpack config dictionary with default values.
|
||||
self.sids = kwargs.get('sids', data.items)
|
||||
self.start = kwargs.get('start', data.major_axis[0])
|
||||
self.end = kwargs.get('end', data.major_axis[-1])
|
||||
self.start = kwargs.get('start', self.data.major_axis[0])
|
||||
self.end = kwargs.get('end', self.data.major_axis[-1])
|
||||
|
||||
# Remap sids based on the trading environment
|
||||
self.identifiers = kwargs.get('sids', self.data.items)
|
||||
env.update_asset_finder(identifiers=self.identifiers)
|
||||
self.data.items = [
|
||||
env.asset_finder.retrieve_asset_by_identifier(identifier).sid
|
||||
for identifier in self.data.items
|
||||
]
|
||||
self.sids = self.data.items
|
||||
|
||||
# Hash_value for downstream sorting.
|
||||
self.arg_string = hash_args(data, **kwargs)
|
||||
|
||||
@@ -23,6 +23,7 @@ import pandas as pd
|
||||
from zipline.sources.data_source import DataSource
|
||||
from zipline.utils import tradingcalendar as calendar_nyse
|
||||
from zipline.gens.utils import hash_args
|
||||
from zipline.finance import trading
|
||||
|
||||
|
||||
class RandomWalkSource(DataSource):
|
||||
@@ -92,6 +93,7 @@ class RandomWalkSource(DataSource):
|
||||
self.sd = sd
|
||||
|
||||
self.sids = self.start_prices.keys()
|
||||
trading.environment.update_asset_finder(identifiers=self.sids)
|
||||
|
||||
self.open_and_closes = \
|
||||
calendar.open_and_closes[self.start:self.end]
|
||||
|
||||
@@ -22,7 +22,6 @@ import pytz
|
||||
from six.moves import filter
|
||||
from datetime import datetime, timedelta
|
||||
import itertools
|
||||
import numpy as np
|
||||
|
||||
from six.moves import range
|
||||
|
||||
@@ -112,8 +111,8 @@ class SpecificEquityTrades(object):
|
||||
delta : timedelta between internal events
|
||||
filter : filter to remove the sids
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
@with_environment()
|
||||
def __init__(self, env=None, *args, **kwargs):
|
||||
# We shouldn't get any positional arguments.
|
||||
assert len(args) == 0
|
||||
|
||||
@@ -125,9 +124,7 @@ class SpecificEquityTrades(object):
|
||||
# This isn't really clean and ultimately I think this
|
||||
# class should serve a single purpose (either take an
|
||||
# event_list or autocreate events).
|
||||
self.sids = kwargs.get(
|
||||
'sids',
|
||||
np.unique([event.sid for event in self.event_list]).tolist())
|
||||
self.count = kwargs.get('count', len(self.event_list))
|
||||
self.start = kwargs.get('start', self.event_list[0].dt)
|
||||
self.end = kwargs.get('end', self.event_list[-1].dt)
|
||||
self.delta = kwargs.get(
|
||||
@@ -135,9 +132,22 @@ class SpecificEquityTrades(object):
|
||||
self.event_list[1].dt - self.event_list[0].dt)
|
||||
self.concurrent = kwargs.get('concurrent', False)
|
||||
|
||||
self.identifiers = kwargs.get(
|
||||
'sids',
|
||||
set(event.sid for event in self.event_list)
|
||||
)
|
||||
env.update_asset_finder(identifiers=self.identifiers)
|
||||
self.sids = [
|
||||
env.asset_finder.retrieve_asset_by_identifier(identifier).sid
|
||||
for identifier in self.identifiers
|
||||
]
|
||||
for event in self.event_list:
|
||||
event.sid = env.asset_finder.\
|
||||
retrieve_asset_by_identifier(event.sid).sid
|
||||
|
||||
else:
|
||||
# Unpack config dictionary with default values.
|
||||
self.sids = kwargs.get('sids', [1, 2])
|
||||
self.count = kwargs.get('count', 500)
|
||||
self.start = kwargs.get(
|
||||
'start',
|
||||
datetime(2008, 6, 6, 15, tzinfo=pytz.utc))
|
||||
@@ -149,6 +159,13 @@ class SpecificEquityTrades(object):
|
||||
timedelta(minutes=1))
|
||||
self.concurrent = kwargs.get('concurrent', False)
|
||||
|
||||
self.identifiers = kwargs.get('sids', [1, 2])
|
||||
env.update_asset_finder(identifiers=self.identifiers)
|
||||
self.sids = [
|
||||
env.asset_finder.retrieve_asset_by_identifier(identifier).sid
|
||||
for identifier in self.identifiers
|
||||
]
|
||||
|
||||
# Hash_value for downstream sorting.
|
||||
self.arg_string = hash_args(*args, **kwargs)
|
||||
|
||||
|
||||
+107
-63
@@ -85,14 +85,17 @@ from zipline.api import (
|
||||
order,
|
||||
set_slippage,
|
||||
record,
|
||||
sid,
|
||||
)
|
||||
from zipline.errors import UnsupportedOrderParameters
|
||||
from zipline.assets import Future, Equity
|
||||
from zipline.finance.execution import (
|
||||
LimitOrder,
|
||||
MarketOrder,
|
||||
StopLimitOrder,
|
||||
StopOrder,
|
||||
)
|
||||
from zipline.finance.controls import AssetDateBounds
|
||||
from zipline.transforms import BatchTransform, batch_transform
|
||||
|
||||
|
||||
@@ -111,14 +114,14 @@ class TestAlgorithm(TradingAlgorithm):
|
||||
slippage=None,
|
||||
commission=None):
|
||||
self.count = order_count
|
||||
self.sid = sid
|
||||
self.asset = self.sid(sid)
|
||||
self.amount = amount
|
||||
self.incr = 0
|
||||
|
||||
if sid_filter:
|
||||
self.sid_filter = sid_filter
|
||||
else:
|
||||
self.sid_filter = [self.sid]
|
||||
self.sid_filter = [self.asset.sid]
|
||||
|
||||
if slippage is not None:
|
||||
self.set_slippage(slippage)
|
||||
@@ -129,7 +132,7 @@ class TestAlgorithm(TradingAlgorithm):
|
||||
def handle_data(self, data):
|
||||
# place an order for amount shares of sid
|
||||
if self.incr < self.count:
|
||||
self.order(self.sid, self.amount)
|
||||
self.order(self.asset, self.amount)
|
||||
self.incr += 1
|
||||
|
||||
|
||||
@@ -141,13 +144,13 @@ class HeavyBuyAlgorithm(TradingAlgorithm):
|
||||
"""
|
||||
|
||||
def initialize(self, sid, amount):
|
||||
self.sid = sid
|
||||
self.asset = self.sid(sid)
|
||||
self.amount = amount
|
||||
self.incr = 0
|
||||
|
||||
def handle_data(self, data):
|
||||
# place an order for 100 shares of sid
|
||||
self.order(self.sid, self.amount)
|
||||
self.order(self.asset, self.amount)
|
||||
self.incr += 1
|
||||
|
||||
|
||||
@@ -177,7 +180,7 @@ class ExceptionAlgorithm(TradingAlgorithm):
|
||||
def initialize(self, throw_from, sid):
|
||||
|
||||
self.throw_from = throw_from
|
||||
self.sid = sid
|
||||
self.asset = self.sid(sid)
|
||||
|
||||
if self.throw_from == "initialize":
|
||||
raise Exception("Algo exception in initialize")
|
||||
@@ -200,7 +203,7 @@ class ExceptionAlgorithm(TradingAlgorithm):
|
||||
if self.throw_from == "get_sid_filter":
|
||||
raise Exception("Algo exception in get_sid_filter")
|
||||
else:
|
||||
return [self.sid]
|
||||
return [self.asset]
|
||||
|
||||
def set_transact_setter(self, txn_sim_callable):
|
||||
pass
|
||||
@@ -209,7 +212,7 @@ class ExceptionAlgorithm(TradingAlgorithm):
|
||||
class DivByZeroAlgorithm(TradingAlgorithm):
|
||||
|
||||
def initialize(self, sid):
|
||||
self.sid = sid
|
||||
self.asset = self.sid(sid)
|
||||
self.incr = 0
|
||||
|
||||
def handle_data(self, data):
|
||||
@@ -222,7 +225,7 @@ class DivByZeroAlgorithm(TradingAlgorithm):
|
||||
class TooMuchProcessingAlgorithm(TradingAlgorithm):
|
||||
|
||||
def initialize(self, sid):
|
||||
self.sid = sid
|
||||
self.asset = self.sid(sid)
|
||||
|
||||
def handle_data(self, data):
|
||||
# Unless we're running on some sort of
|
||||
@@ -234,7 +237,7 @@ class TooMuchProcessingAlgorithm(TradingAlgorithm):
|
||||
class TimeoutAlgorithm(TradingAlgorithm):
|
||||
|
||||
def initialize(self, sid):
|
||||
self.sid = sid
|
||||
self.asset = self.sid(sid)
|
||||
self.incr = 0
|
||||
|
||||
def handle_data(self, data):
|
||||
@@ -269,7 +272,7 @@ class TestOrderAlgorithm(TradingAlgorithm):
|
||||
assert self.portfolio.positions[0]['last_sale_price'] == \
|
||||
data[0].price, "Orders not filled at current price."
|
||||
self.incr += 1
|
||||
self.order(0, 1)
|
||||
self.order(self.sid(0), 1)
|
||||
|
||||
|
||||
class TestOrderInstantAlgorithm(TradingAlgorithm):
|
||||
@@ -286,7 +289,7 @@ class TestOrderInstantAlgorithm(TradingAlgorithm):
|
||||
assert self.portfolio.positions[0]['last_sale_price'] == \
|
||||
self.last_price, "Orders was not filled at last price."
|
||||
self.incr += 2
|
||||
self.order_value(0, data[0].price * 2.)
|
||||
self.order_value(self.sid(0), data[0].price * 2.)
|
||||
self.last_price = data[0].price
|
||||
|
||||
|
||||
@@ -311,7 +314,9 @@ class TestOrderStyleForwardingAlgorithm(TradingAlgorithm):
|
||||
assert len(self.portfolio.positions.keys()) == 0
|
||||
|
||||
method_to_check = getattr(self, self.method_name)
|
||||
method_to_check(0, data[0].price, style=StopLimitOrder(10, 10))
|
||||
method_to_check(self.sid(0),
|
||||
data[0].price,
|
||||
style=StopLimitOrder(10, 10))
|
||||
|
||||
assert len(self.blotter.open_orders[0]) == 1
|
||||
result = self.blotter.open_orders[0][0]
|
||||
@@ -335,7 +340,12 @@ class TestOrderValueAlgorithm(TradingAlgorithm):
|
||||
assert self.portfolio.positions[0]['last_sale_price'] == \
|
||||
data[0].price, "Orders not filled at current price."
|
||||
self.incr += 2
|
||||
self.order_value(0, data[0].price * 2.)
|
||||
|
||||
multiplier = 2.
|
||||
if isinstance(self.sid(0), Future):
|
||||
multiplier *= self.sid(0).contract_multiplier
|
||||
|
||||
self.order_value(self.sid(0), data[0].price * multiplier)
|
||||
|
||||
|
||||
class TestTargetAlgorithm(TradingAlgorithm):
|
||||
@@ -352,7 +362,7 @@ class TestTargetAlgorithm(TradingAlgorithm):
|
||||
assert self.portfolio.positions[0]['last_sale_price'] == \
|
||||
data[0].price, "Orders not filled at current price."
|
||||
self.target_shares = np.random.randint(1, 30)
|
||||
self.order_target(0, self.target_shares)
|
||||
self.order_target(self.sid(0), self.target_shares)
|
||||
|
||||
|
||||
class TestOrderPercentAlgorithm(TradingAlgorithm):
|
||||
@@ -363,7 +373,7 @@ class TestOrderPercentAlgorithm(TradingAlgorithm):
|
||||
def handle_data(self, data):
|
||||
if self.target_shares == 0:
|
||||
assert 0 not in self.portfolio.positions
|
||||
self.order(0, 10)
|
||||
self.order(self.sid(0), 10)
|
||||
self.target_shares = 10
|
||||
return
|
||||
else:
|
||||
@@ -372,10 +382,17 @@ class TestOrderPercentAlgorithm(TradingAlgorithm):
|
||||
assert self.portfolio.positions[0]['last_sale_price'] == \
|
||||
data[0].price, "Orders not filled at current price."
|
||||
|
||||
self.order_percent(0, .001)
|
||||
self.target_shares += np.floor((.001 *
|
||||
self.portfolio.portfolio_value) /
|
||||
data[0].price)
|
||||
self.order_percent(self.sid(0), .001)
|
||||
|
||||
if isinstance(self.sid(0), Equity):
|
||||
self.target_shares += np.floor(
|
||||
(.001 * self.portfolio.portfolio_value) / data[0].price
|
||||
)
|
||||
if isinstance(self.sid(0), Future):
|
||||
self.target_shares += np.floor(
|
||||
(.001 * self.portfolio.portfolio_value) /
|
||||
(data[0].price * self.sid(0).contract_multiplier)
|
||||
)
|
||||
|
||||
|
||||
class TestTargetPercentAlgorithm(TradingAlgorithm):
|
||||
@@ -394,7 +411,7 @@ class TestTargetPercentAlgorithm(TradingAlgorithm):
|
||||
assert self.portfolio.positions[0]['last_sale_price'] == \
|
||||
data[0].price, "Orders not filled at current price."
|
||||
self.sale_price = data[0].price
|
||||
self.order_target_percent(0, .002)
|
||||
self.order_target_percent(self.sid(0), .002)
|
||||
|
||||
|
||||
class TestTargetValueAlgorithm(TradingAlgorithm):
|
||||
@@ -405,7 +422,7 @@ class TestTargetValueAlgorithm(TradingAlgorithm):
|
||||
def handle_data(self, data):
|
||||
if self.target_shares == 0:
|
||||
assert 0 not in self.portfolio.positions
|
||||
self.order(0, 10)
|
||||
self.order(self.sid(0), 10)
|
||||
self.target_shares = 10
|
||||
return
|
||||
else:
|
||||
@@ -415,9 +432,15 @@ class TestTargetValueAlgorithm(TradingAlgorithm):
|
||||
assert self.portfolio.positions[0]['last_sale_price'] == \
|
||||
data[0].price, "Orders not filled at current price."
|
||||
|
||||
self.order_target_value(0, 20)
|
||||
self.order_target_value(self.sid(0), 20)
|
||||
self.target_shares = np.round(20 / data[0].price)
|
||||
|
||||
if isinstance(self.sid(0), Equity):
|
||||
self.target_shares = np.round(20 / data[0].price)
|
||||
if isinstance(self.sid(0), Future):
|
||||
self.target_shares = np.round(
|
||||
20 / (data[0].price * self.sid(0).contract_multiplier))
|
||||
|
||||
|
||||
############################
|
||||
# AccountControl Test Algos#
|
||||
@@ -468,6 +491,18 @@ class SetLongOnlyAlgorithm(TradingAlgorithm):
|
||||
self.set_long_only()
|
||||
|
||||
|
||||
class SetAssetDateBoundsAlgorithm(TradingAlgorithm):
|
||||
"""
|
||||
Algorithm that tries to order 1 share of sid 0 on every bar and has an
|
||||
AssetDateBounds() trading control in place.
|
||||
"""
|
||||
def initialize(self):
|
||||
self.register_trading_control(AssetDateBounds())
|
||||
|
||||
def handle_data(algo, data):
|
||||
algo.order(algo.sid(0), 1)
|
||||
|
||||
|
||||
class TestRegisterTransformAlgorithm(TradingAlgorithm):
|
||||
def initialize(self, *args, **kwargs):
|
||||
self.set_slippage(FixedSlippage())
|
||||
@@ -484,7 +519,7 @@ class AmbitiousStopLimitAlgorithm(TradingAlgorithm):
|
||||
"""
|
||||
|
||||
def initialize(self, *args, **kwargs):
|
||||
self.sid = kwargs.pop('sid')
|
||||
self.asset = self.sid(kwargs.pop('sid'))
|
||||
|
||||
def handle_data(self, data):
|
||||
|
||||
@@ -493,42 +528,42 @@ class AmbitiousStopLimitAlgorithm(TradingAlgorithm):
|
||||
########
|
||||
|
||||
# Buy with low limit, shouldn't trigger.
|
||||
self.order(self.sid, 100, limit_price=1)
|
||||
self.order(self.asset, 100, limit_price=1)
|
||||
|
||||
# But with high stop, shouldn't trigger
|
||||
self.order(self.sid, 100, stop_price=10000000)
|
||||
self.order(self.asset, 100, stop_price=10000000)
|
||||
|
||||
# Buy with high limit (should trigger) but also high stop (should
|
||||
# prevent trigger).
|
||||
self.order(self.sid, 100, limit_price=10000000, stop_price=10000000)
|
||||
self.order(self.asset, 100, limit_price=10000000, stop_price=10000000)
|
||||
|
||||
# Buy with low stop (should trigger), but also low limit (should
|
||||
# prevent trigger).
|
||||
self.order(self.sid, 100, limit_price=1, stop_price=1)
|
||||
self.order(self.asset, 100, limit_price=1, stop_price=1)
|
||||
|
||||
#########
|
||||
# Sells #
|
||||
#########
|
||||
|
||||
# Sell with high limit, shouldn't trigger.
|
||||
self.order(self.sid, -100, limit_price=1000000)
|
||||
self.order(self.asset, -100, limit_price=1000000)
|
||||
|
||||
# Sell with low stop, shouldn't trigger.
|
||||
self.order(self.sid, -100, stop_price=1)
|
||||
self.order(self.asset, -100, stop_price=1)
|
||||
|
||||
# Sell with low limit (should trigger), but also high stop (should
|
||||
# prevent trigger).
|
||||
self.order(self.sid, -100, limit_price=1000000, stop_price=1000000)
|
||||
self.order(self.asset, -100, limit_price=1000000, stop_price=1000000)
|
||||
|
||||
# Sell with low limit (should trigger), but also low stop (should
|
||||
# prevent trigger).
|
||||
self.order(self.sid, -100, limit_price=1, stop_price=1)
|
||||
self.order(self.asset, -100, limit_price=1, stop_price=1)
|
||||
|
||||
###################
|
||||
# Rounding Checks #
|
||||
###################
|
||||
self.order(self.sid, 100, limit_price=.00000001)
|
||||
self.order(self.sid, -100, stop_price=.00000001)
|
||||
self.order(self.asset, 100, limit_price=.00000001)
|
||||
self.order(self.asset, -100, stop_price=.00000001)
|
||||
|
||||
|
||||
##########################################
|
||||
@@ -695,8 +730,8 @@ class BatchTransformAlgorithm(TradingAlgorithm):
|
||||
"batch transform is not updating for new kwargs"
|
||||
|
||||
new_data = deepcopy(data)
|
||||
for sid in new_data:
|
||||
new_data[sid]['arbitrary'] = 123
|
||||
for sidint in new_data:
|
||||
new_data[sidint]['arbitrary'] = 123
|
||||
|
||||
self.history_return_arbitrary_fields.append(
|
||||
self.return_arbitrary_fields.handle_data(new_data))
|
||||
@@ -707,8 +742,7 @@ class BatchTransformAlgorithm(TradingAlgorithm):
|
||||
self.return_nan.handle_data(data))
|
||||
else:
|
||||
nan_data = deepcopy(data)
|
||||
for sid in nan_data.iterkeys():
|
||||
nan_data[sid].price = np.nan
|
||||
nan_data.price = np.nan
|
||||
self.history_return_nan.append(
|
||||
self.return_nan.handle_data(nan_data))
|
||||
|
||||
@@ -810,7 +844,7 @@ class EmptyPositionsAlgorithm(TradingAlgorithm):
|
||||
def handle_data(self, data):
|
||||
if not self.ordered:
|
||||
for s in data:
|
||||
self.order(s, 100)
|
||||
self.order(self.sid(s), 100)
|
||||
self.ordered = True
|
||||
|
||||
if not self.exited:
|
||||
@@ -821,7 +855,7 @@ class EmptyPositionsAlgorithm(TradingAlgorithm):
|
||||
(len(amounts) == len(data.keys()))
|
||||
):
|
||||
for stock in self.portfolio.positions:
|
||||
self.order(stock, -100)
|
||||
self.order(self.sid(stock), -100)
|
||||
self.exited = True
|
||||
|
||||
# Should be 0 when all positions are exited.
|
||||
@@ -834,7 +868,7 @@ class InvalidOrderAlgorithm(TradingAlgorithm):
|
||||
appropriate exceptions are raised.
|
||||
"""
|
||||
def initialize(self, *args, **kwargs):
|
||||
self.sid = kwargs.pop('sids')[0]
|
||||
self.asset = self.sid(kwargs.pop('sids')[0])
|
||||
|
||||
def handle_data(self, data):
|
||||
from zipline.api import (
|
||||
@@ -849,40 +883,48 @@ class InvalidOrderAlgorithm(TradingAlgorithm):
|
||||
StopOrder(10), StopLimitOrder(10, 10)]:
|
||||
|
||||
with assert_raises(UnsupportedOrderParameters):
|
||||
order(self.sid, 10, limit_price=10, style=style)
|
||||
order(self.asset, 10, limit_price=10, style=style)
|
||||
|
||||
with assert_raises(UnsupportedOrderParameters):
|
||||
order(self.sid, 10, stop_price=10, style=style)
|
||||
order(self.asset, 10, stop_price=10, style=style)
|
||||
|
||||
with assert_raises(UnsupportedOrderParameters):
|
||||
order_value(self.sid, 300, limit_price=10, style=style)
|
||||
order_value(self.asset, 300, limit_price=10, style=style)
|
||||
|
||||
with assert_raises(UnsupportedOrderParameters):
|
||||
order_value(self.sid, 300, stop_price=10, style=style)
|
||||
order_value(self.asset, 300, stop_price=10, style=style)
|
||||
|
||||
with assert_raises(UnsupportedOrderParameters):
|
||||
order_percent(self.sid, .1, limit_price=10, style=style)
|
||||
order_percent(self.asset, .1, limit_price=10, style=style)
|
||||
|
||||
with assert_raises(UnsupportedOrderParameters):
|
||||
order_percent(self.sid, .1, stop_price=10, style=style)
|
||||
order_percent(self.asset, .1, stop_price=10, style=style)
|
||||
|
||||
with assert_raises(UnsupportedOrderParameters):
|
||||
order_target(self.sid, 100, limit_price=10, style=style)
|
||||
order_target(self.asset, 100, limit_price=10, style=style)
|
||||
|
||||
with assert_raises(UnsupportedOrderParameters):
|
||||
order_target(self.sid, 100, stop_price=10, style=style)
|
||||
order_target(self.asset, 100, stop_price=10, style=style)
|
||||
|
||||
with assert_raises(UnsupportedOrderParameters):
|
||||
order_target_value(self.sid, 100, limit_price=10, style=style)
|
||||
order_target_value(self.asset, 100,
|
||||
limit_price=10,
|
||||
style=style)
|
||||
|
||||
with assert_raises(UnsupportedOrderParameters):
|
||||
order_target_value(self.sid, 100, stop_price=10, style=style)
|
||||
order_target_value(self.asset, 100,
|
||||
stop_price=10,
|
||||
style=style)
|
||||
|
||||
with assert_raises(UnsupportedOrderParameters):
|
||||
order_target_percent(self.sid, .2, limit_price=10, style=style)
|
||||
order_target_percent(self.asset, .2,
|
||||
limit_price=10,
|
||||
style=style)
|
||||
|
||||
with assert_raises(UnsupportedOrderParameters):
|
||||
order_target_percent(self.sid, .2, stop_price=10, style=style)
|
||||
order_target_percent(self.asset, .2,
|
||||
stop_price=10,
|
||||
style=style)
|
||||
|
||||
|
||||
##############################
|
||||
@@ -913,7 +955,7 @@ def handle_data_api(context, data):
|
||||
assert context.portfolio.positions[0]['last_sale_price'] == \
|
||||
data[0].price, "Orders not filled at current price."
|
||||
context.incr += 1
|
||||
order(0, 1)
|
||||
order(sid(0), 1)
|
||||
|
||||
record(incr=context.incr)
|
||||
|
||||
@@ -932,7 +974,8 @@ api_algo = """
|
||||
from zipline.api import (order,
|
||||
set_slippage,
|
||||
FixedSlippage,
|
||||
record)
|
||||
record,
|
||||
sid)
|
||||
|
||||
def initialize(context):
|
||||
context.incr = 0
|
||||
@@ -948,7 +991,7 @@ def handle_data(context, data):
|
||||
assert context.portfolio.positions[0]['last_sale_price'] == \
|
||||
data[0].price, "Orders not filled at current price."
|
||||
context.incr += 1
|
||||
order(0, 1)
|
||||
order(sid(0), 1)
|
||||
|
||||
record(incr=context.incr)
|
||||
"""
|
||||
@@ -1009,18 +1052,19 @@ from zipline.api import (order,
|
||||
order_percent,
|
||||
order_target,
|
||||
order_target_value,
|
||||
order_target_percent)
|
||||
order_target_percent,
|
||||
sid)
|
||||
|
||||
def initialize(context):
|
||||
pass
|
||||
|
||||
def handle_data(context, data):
|
||||
order(0, 10)
|
||||
order_value(0, 300)
|
||||
order_percent(0, .1)
|
||||
order_target(0, 100)
|
||||
order_target_value(0, 100)
|
||||
order_target_percent(0, .2)
|
||||
order(sid(0), 10)
|
||||
order_value(sid(0), 300)
|
||||
order_percent(sid(0), .1)
|
||||
order_target(sid(0), 100)
|
||||
order_target_value(sid(0), 100)
|
||||
order_target_percent(sid(0), .2)
|
||||
"""
|
||||
|
||||
record_variables = """
|
||||
|
||||
+62
-9
@@ -31,14 +31,15 @@ except:
|
||||
PYGMENTS = False
|
||||
|
||||
import zipline
|
||||
from zipline.errors import NoSourceError, PipelineDateError
|
||||
|
||||
DEFAULTS = {
|
||||
'start': '2012-01-01',
|
||||
'end': '2012-12-31',
|
||||
'data_frequency': 'daily',
|
||||
'capital_base': '10e6',
|
||||
'source': 'yahoo',
|
||||
'symbols': 'AAPL'
|
||||
'symbols': 'AAPL',
|
||||
'metadata_index': 'symbol',
|
||||
'source_time_column': 'Date',
|
||||
}
|
||||
|
||||
|
||||
@@ -99,9 +100,12 @@ def parse_args(argv, ipython_mode=False):
|
||||
parser.add_argument('--start', '-s')
|
||||
parser.add_argument('--end', '-e')
|
||||
parser.add_argument('--capital_base')
|
||||
parser.add_argument('--source', choices=('yahoo',))
|
||||
parser.add_argument('--source', '-d')
|
||||
parser.add_argument('--source_time_column', '-t')
|
||||
parser.add_argument('--symbols')
|
||||
parser.add_argument('--output', '-o')
|
||||
parser.add_argument('--metadata_path', '-m')
|
||||
parser.add_argument('--metadata_index', '-x')
|
||||
if ipython_mode:
|
||||
parser.add_argument('--local_namespace', action='store_true')
|
||||
|
||||
@@ -153,14 +157,59 @@ def run_pipeline(print_algo=True, **kwargs):
|
||||
pygments syntax coloring if pygments is found.
|
||||
|
||||
"""
|
||||
start = pd.Timestamp(kwargs['start'], tz='UTC')
|
||||
end = pd.Timestamp(kwargs['end'], tz='UTC')
|
||||
start = kwargs['start']
|
||||
end = kwargs['end']
|
||||
# Compare against None because strings/timestamps may have been given
|
||||
if start is not None:
|
||||
start = pd.Timestamp(start, tz='UTC')
|
||||
if end is not None:
|
||||
end = pd.Timestamp(end, tz='UTC')
|
||||
|
||||
# Fail out if only one bound is provided
|
||||
if ((start is None) or (end is None)) and (start != end):
|
||||
raise PipelineDateError(start=start, end=end)
|
||||
|
||||
# Check if start and end are provided, and if the sim_params need to read
|
||||
# a start and end from the DataSource
|
||||
if start is None:
|
||||
overwrite_sim_params = True
|
||||
else:
|
||||
overwrite_sim_params = False
|
||||
|
||||
symbols = kwargs['symbols'].split(',')
|
||||
asset_identifier = kwargs['metadata_index']
|
||||
|
||||
if kwargs['source'] == 'yahoo':
|
||||
# Pull asset metadata
|
||||
asset_metadata = kwargs.get('asset_metadata', None)
|
||||
asset_metadata_path = kwargs['metadata_path']
|
||||
# Read in a CSV file, if applicable
|
||||
if asset_metadata_path is not None:
|
||||
if os.path.isfile(asset_metadata_path):
|
||||
asset_metadata = pd.read_csv(asset_metadata_path,
|
||||
index_col=asset_identifier)
|
||||
|
||||
source_arg = kwargs['source']
|
||||
source_time_column = kwargs['source_time_column']
|
||||
|
||||
if source_arg is None:
|
||||
raise NoSourceError()
|
||||
|
||||
elif source_arg == 'yahoo':
|
||||
source = zipline.data.load_bars_from_yahoo(
|
||||
stocks=symbols, start=start, end=end)
|
||||
|
||||
elif os.path.isfile(source_arg):
|
||||
source = zipline.data.load_prices_from_csv(
|
||||
filepath=source_arg,
|
||||
identifier_col=source_time_column
|
||||
)
|
||||
|
||||
elif os.path.isdir(source_arg):
|
||||
source = zipline.data.load_prices_from_csv_folder(
|
||||
folderpath=source_arg,
|
||||
identifier_col=source_time_column
|
||||
)
|
||||
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
'Source %s not implemented.' % kwargs['source'])
|
||||
@@ -188,9 +237,13 @@ def run_pipeline(print_algo=True, **kwargs):
|
||||
algo = zipline.TradingAlgorithm(script=algo_text,
|
||||
namespace=kwargs.get('namespace', {}),
|
||||
capital_base=float(kwargs['capital_base']),
|
||||
algo_filename=kwargs.get('algofile'))
|
||||
algo_filename=kwargs.get('algofile'),
|
||||
asset_metadata=asset_metadata,
|
||||
identifiers=symbols,
|
||||
start=start,
|
||||
end=end)
|
||||
|
||||
perf = algo.run(source)
|
||||
perf = algo.run(source, overwrite_sim_params=overwrite_sim_params)
|
||||
|
||||
output_fname = kwargs.get('output', None)
|
||||
if output_fname is not None:
|
||||
|
||||
@@ -120,6 +120,7 @@ def create_trade_history(sid, prices, amounts, interval, sim_params,
|
||||
source_id="test_factory"):
|
||||
trades = []
|
||||
current = sim_params.first_open
|
||||
trading.environment.update_asset_finder(identifiers=[sid])
|
||||
|
||||
oneday = timedelta(days=1)
|
||||
use_midnight = interval >= oneday
|
||||
@@ -149,7 +150,6 @@ def create_dividend(sid, payment, declared_date, ex_date, pay_date):
|
||||
'type': DATASOURCE_TYPE.DIVIDEND,
|
||||
'source_id': 'MockDividendSource'
|
||||
})
|
||||
|
||||
return div
|
||||
|
||||
|
||||
@@ -312,6 +312,8 @@ def create_test_df_source(sim_params=None, bars='daily'):
|
||||
|
||||
df = pd.DataFrame(x, index=index, columns=[0])
|
||||
|
||||
trading.environment.update_asset_finder(identifiers=[0])
|
||||
|
||||
return DataFrameSource(df), df
|
||||
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ import os.path
|
||||
import pandas as pd
|
||||
import pytz
|
||||
import zipline
|
||||
from zipline.finance.trading import with_environment
|
||||
|
||||
|
||||
DATE_FORMAT = "%Y%m%d"
|
||||
@@ -12,23 +13,16 @@ zipline_dir = os.path.dirname(zipline.__file__)
|
||||
SECURITY_LISTS_DIR = os.path.join(zipline_dir, 'resources', 'security_lists')
|
||||
|
||||
|
||||
def loopback(symbol, *args, **kwargs):
|
||||
return symbol
|
||||
|
||||
|
||||
class SecurityList(object):
|
||||
|
||||
def __init__(self, lookup_func, data, current_date_func):
|
||||
def __init__(self, data, current_date_func):
|
||||
"""
|
||||
lookup_func: function that takes a string symbol and a date and
|
||||
returns a Security object.
|
||||
data: a nested dictionary:
|
||||
knowledge_date -> lookup_date ->
|
||||
{add: [symbol list], 'delete': []}, delete: [symbol list]}
|
||||
current_date_func: function taking no parameters, returning
|
||||
current datetime
|
||||
"""
|
||||
self.lookup_func = lookup_func
|
||||
self.data = data
|
||||
self._cache = {}
|
||||
self._knowledge_dates = self.make_knowledge_dates(self.data)
|
||||
@@ -74,13 +68,17 @@ class SecurityList(object):
|
||||
self._cache[kd] = self._current_set
|
||||
return self._current_set
|
||||
|
||||
def update_current(self, effective_date, symbols, change_func):
|
||||
@with_environment()
|
||||
def update_current(self, effective_date, symbols, change_func, env=None):
|
||||
for symbol in symbols:
|
||||
sid = self.lookup_func(
|
||||
asset = env.asset_finder.lookup_symbol(
|
||||
symbol,
|
||||
as_of_date=effective_date
|
||||
)
|
||||
change_func(sid)
|
||||
# Pass if no Asset exists for the symbol
|
||||
if asset is None:
|
||||
continue
|
||||
change_func(asset.sid)
|
||||
|
||||
|
||||
class SecurityListSet(object):
|
||||
@@ -88,11 +86,7 @@ class SecurityListSet(object):
|
||||
# list implementations.
|
||||
security_list_type = SecurityList
|
||||
|
||||
def __init__(self, current_date_func, lookup_func=None):
|
||||
if lookup_func is None:
|
||||
self.lookup_func = loopback
|
||||
else:
|
||||
self.lookup_func = lookup_func
|
||||
def __init__(self, current_date_func):
|
||||
self.current_date_func = current_date_func
|
||||
self._leveraged_etf = None
|
||||
|
||||
@@ -100,7 +94,6 @@ class SecurityListSet(object):
|
||||
def leveraged_etf_list(self):
|
||||
if self._leveraged_etf is None:
|
||||
self._leveraged_etf = self.security_list_type(
|
||||
self.lookup_func,
|
||||
load_from_directory('leveraged_etf_list'),
|
||||
self.current_date_func
|
||||
)
|
||||
|
||||
@@ -7,7 +7,7 @@ def create_test_zipline(**config):
|
||||
"""
|
||||
:param config: A configuration object that is a dict with:
|
||||
|
||||
- sid - an integer, which will be used as the security ID.
|
||||
- sid - an integer, which will be used as the asset ID.
|
||||
- order_count - the number of orders the test algo will place,
|
||||
defaults to 100
|
||||
- order_amount - the number of shares per order, defaults to 100
|
||||
@@ -25,10 +25,15 @@ def create_test_zipline(**config):
|
||||
:py:mod:`zipline.finance.trading`
|
||||
"""
|
||||
assert isinstance(config, dict)
|
||||
sid_list = config.get('sid_list')
|
||||
if not sid_list:
|
||||
sid = config.get('sid')
|
||||
sid_list = [sid]
|
||||
|
||||
try:
|
||||
sid_list = config['sid_list']
|
||||
except KeyError:
|
||||
try:
|
||||
sid_list = [config['sid']]
|
||||
except KeyError:
|
||||
raise Exception("simfactory create_test_zipline() requires "
|
||||
"argument 'sid_list' or 'sid'")
|
||||
|
||||
concurrent_trades = config.get('concurrent_trades', False)
|
||||
|
||||
@@ -49,12 +54,13 @@ def create_test_zipline(**config):
|
||||
test_algo = config['algorithm']
|
||||
else:
|
||||
test_algo = TestAlgorithm(
|
||||
sid,
|
||||
sid_list[0],
|
||||
order_amount,
|
||||
order_count,
|
||||
sim_params=config.get('sim_params',
|
||||
factory.create_simulation_parameters()),
|
||||
slippage=config.get('slippage'),
|
||||
identifiers=sid_list
|
||||
)
|
||||
|
||||
# -------------------
|
||||
|
||||
Reference in New Issue
Block a user