ENH: Adds futures trading and asset management logic to TradingAlgorithm and performance classes

This commit is contained in:
jfkirk
2015-06-08 16:29:19 -04:00
parent 035bfbd514
commit b84ac01cbf
25 changed files with 844 additions and 304 deletions
+115 -42
View File
@@ -39,9 +39,10 @@ from zipline.errors import (
UnsupportedCommissionModel,
UnsupportedOrderParameters,
UnsupportedSlippageModel,
SidNotFound,
)
from zipline.finance import trading
from zipline.finance.trading import TradingEnvironment
from zipline.finance.blotter import Blotter
from zipline.finance.commission import PerShare, PerTrade, PerDollar
from zipline.finance.controls import (
@@ -64,6 +65,7 @@ from zipline.finance.slippage import (
SlippageModel,
transact_partial
)
from zipline.assets import Asset, Future
from zipline.gens.composites import date_sorted_sources
from zipline.gens.tradesimulation import AlgorithmSimulator
from zipline.sources import DataFrameSource, DataPanelSource
@@ -133,8 +135,27 @@ class TradingAlgorithm(object):
How much capital to start with.
instant_fill : bool <default: False>
Whether to fill orders immediately or on next bar.
environment : str <default: 'zipline'>
The environment that this algorithm is running in.
asset_finder : An AssetFinder object
A new AssetFinder object to be used in this TradingEnvironment
asset_metadata: can be either:
- dict
- pandas.DataFrame
- object with 'read' property
If dict is provided, it must have the following structure:
* keys are the identifiers
* values are dicts containing the metadata, with the metadata
field name as the key
If pandas.DataFrame is provided, it must have the
following structure:
* column names must be the metadata fields
* index must be the different asset identifiers
* array contents should be the metadata value
If an object with a 'read' property is provided, 'read' must
return rows containing at least one of 'sid' or 'symbol' along
with the other metadata fields.
identifiers : List
Any asset identifiers that are not provided in the
asset_metadata, but will be traded by this TradingAlgorithm
"""
self.datetime = None
@@ -167,10 +188,23 @@ class TradingAlgorithm(object):
self.sim_params = kwargs.pop('sim_params', None)
if self.sim_params is None:
self.sim_params = create_simulation_parameters(
capital_base=self.capital_base
capital_base=self.capital_base,
start=kwargs.pop('start', None),
end=kwargs.pop('end', None)
)
self.perf_tracker = PerformanceTracker(self.sim_params)
# Update the TradingEnvironment with the provided asset metadata
self.trading_environment = kwargs.pop('env',
TradingEnvironment.instance())
self.trading_environment.update_asset_finder(
asset_finder=kwargs.pop('asset_finder', None),
asset_metadata=kwargs.pop('asset_metadata', None),
identifiers=kwargs.pop('identifiers', None)
)
# Pull in the environment's new AssetFinder for quick reference
self.asset_finder = self.trading_environment.asset_finder
self.blotter = kwargs.pop('blotter', None)
if not self.blotter:
self.blotter = Blotter()
@@ -322,11 +356,10 @@ class TradingAlgorithm(object):
sim_params = self.sim_params
if self.benchmark_return_source is None:
env = trading.environment
if sim_params.data_frequency == 'minute' or \
sim_params.emission_rate == 'minute':
def update_time(date):
return env.get_open_and_close(date)[1]
return self.trading_environment.get_open_and_close(date)[1]
else:
def update_time(date):
return date
@@ -336,7 +369,7 @@ class TradingAlgorithm(object):
'type': zipline.protocol.DATASOURCE_TYPE.BENCHMARK,
'source_id': 'benchmarks'})
for dt, ret in
trading.environment.benchmark_returns.iteritems()
self.trading_environment.benchmark_returns.iteritems()
if dt.date() >= sim_params.period_start.date() and
dt.date() <= sim_params.period_end.date()
]
@@ -410,8 +443,7 @@ class TradingAlgorithm(object):
If pandas.DataFrame is provided, it must have the
following structure:
* column names must consist of ints representing the
different sids
* column names must be the different asset identifiers
* index must be DatetimeIndex
* array contents should be price info.
@@ -420,9 +452,11 @@ class TradingAlgorithm(object):
Daily performance metrics such as returns, alpha etc.
"""
# Ensure that source is a DataSource object
if isinstance(source, list):
if overwrite_sim_params:
warnings.warn("""List of sources passed, will not attempt to extract sids, and start and end
warnings.warn("""List of sources passed, will not attempt to extract start and end
dates. Make sure to set the correct fields in sim_params passed to
__init__().""", UserWarning)
overwrite_sim_params = False
@@ -443,8 +477,19 @@ class TradingAlgorithm(object):
self.sim_params.period_start = source.start
if hasattr(source, 'end'):
self.sim_params.period_end = source.end
# The sids field of the source is the canonical reference for
# sids in this run
all_sids = [sid for s in self.sources for sid in s.sids]
self.sim_params.sids = set(all_sids)
# Check that all sids from the source are accounted for in
# the AssetFinder
for sid in self.sim_params.sids:
try:
self.asset_finder.retrieve_asset(sid)
except SidNotFound:
warnings.warn("No Asset found for sid '%s'. Make sure "
"that the correct identifiers and asset "
"metadata are passed to __init__()." % sid)
# Changing period_start and period_close might require updating
# of first_open and last_close.
self.sim_params._update_internal()
@@ -604,17 +649,52 @@ class TradingAlgorithm(object):
def symbol(self, symbol_str):
"""
Default symbol lookup for any source that directly maps the
symbol to the identifier (e.g. yahoo finance).
symbol to the Asset (e.g. yahoo finance).
"""
return symbol_str
asset, _ = self.asset_finder.lookup_generic(
asset_convertible_or_iterable=symbol_str,
as_of_date=self.datetime,
)
return asset
@api_method
def symbols(self, *args):
"""
Default symbols lookup for any source that directly maps the
symbol to the identifier (e.g. yahoo finance).
symbol to the Asset (e.g. yahoo finance).
"""
return args
return [self.symbol(identifier) for identifier in args]
@api_method
def sid(self, a_sid):
"""
Default sid lookup for any source that directly maps the integer sid
to the Asset.
"""
return self.asset_finder.retrieve_asset(a_sid)
def _calculate_order_value_amount(self, asset, value):
"""
Calculates how many shares/contracts to order based on the type of
asset being ordered.
"""
last_price = self.trading_client.current_data[asset].price
if tolerant_equals(last_price, 0):
zero_message = "Price of 0 for {psid}; can't infer value".format(
psid=asset
)
if self.logger:
self.logger.debug(zero_message)
# Don't place any order
return 0
if isinstance(asset, Future):
value_multiplier = asset.contract_multiplier
else:
value_multiplier = 1
return value / (last_price * value_multiplier)
@api_method
def order(self, sid, amount,
@@ -655,7 +735,7 @@ class TradingAlgorithm(object):
return self.blotter.order(sid, amount, style)
def validate_order_params(self,
sid,
asset,
amount,
limit_price,
stop_price,
@@ -682,8 +762,14 @@ class TradingAlgorithm(object):
msg="Passing both stop_price and style is not supported."
)
if not isinstance(asset, Asset):
raise UnsupportedOrderParameters(
msg="Passing non-Asset argument to 'order()' is not supported."
" Use 'sid()' or 'symbol()' methods to look up an Asset."
)
for control in self.trading_controls:
control.validate(sid,
control.validate(asset,
amount,
self.updated_portfolio(),
self.get_datetime(),
@@ -718,6 +804,8 @@ class TradingAlgorithm(object):
Place an order by desired value rather than desired number of shares.
If the requested sid is found in the universe, the requested value is
divided by its price to imply the number of shares to transact.
If the Asset being ordered is a Future, the 'value' calculated
is actually the exposure, as Futures have no 'value'.
value > 0 :: Buy/Cover
value < 0 :: Sell/Short
@@ -726,21 +814,11 @@ class TradingAlgorithm(object):
Stop order: order(sid, value, None, stop_price)
StopLimit order: order(sid, value, limit_price, stop_price)
"""
last_price = self.trading_client.current_data[sid].price
if tolerant_equals(last_price, 0):
zero_message = "Price of 0 for {psid}; can't infer value".format(
psid=sid
)
if self.logger:
self.logger.debug(zero_message)
# Don't place any order
return
else:
amount = value / last_price
return self.order(sid, amount,
limit_price=limit_price,
stop_price=stop_price,
style=style)
amount = self._calculate_order_value_amount(sid, value)
return self.order(sid, amount,
limit_price=limit_price,
stop_price=stop_price,
style=style)
@property
def recorded_vars(self):
@@ -855,7 +933,7 @@ class TradingAlgorithm(object):
def order_percent(self, sid, percent,
limit_price=None, stop_price=None, style=None):
"""
Place an order in the specified security corresponding to the given
Place an order in the specified asset corresponding to the given
percent of the current portfolio value.
Note that percent must expressed as a decimal (0.50 means 50\%).
@@ -898,15 +976,10 @@ class TradingAlgorithm(object):
order. If the position does exist, this is equivalent to placing an
order for the difference between the target value and the
current value.
If the Asset being ordered is a Future, the 'target value' calculated
is actually the target exposure, as Futures have no 'value'.
"""
last_price = self.trading_client.current_data[sid].price
if tolerant_equals(last_price, 0):
# Don't place an order
if self.logger:
zero_message = "Price of 0 for {psid}; can't infer value"
self.logger.debug(zero_message.format(psid=sid))
return
target_amount = target / last_price
target_amount = self._calculate_order_value_amount(sid, target)
return self.order_target(sid, target_amount,
limit_price=limit_price,
stop_price=stop_price,
@@ -1066,7 +1139,7 @@ class TradingAlgorithm(object):
increasing the absolute value of shares/dollar value exceeding one of
these limits, raise a TradingControlException.
"""
control = MaxPositionSize(sid=sid,
control = MaxPositionSize(asset=sid,
max_shares=max_shares,
max_notional=max_notional)
self.register_trading_control(control)
@@ -1081,7 +1154,7 @@ class TradingAlgorithm(object):
If an algorithm attempts to place an order that would result in
exceeding one of these limits, raise a TradingControlException.
"""
control = MaxOrderSize(sid=sid,
control = MaxOrderSize(asset=sid,
max_shares=max_shares,
max_notional=max_notional)
self.register_trading_control(control)
+6 -2
View File
@@ -1,4 +1,8 @@
from . import loader
from .loader import load_from_yahoo, load_bars_from_yahoo
from .loader import (
load_from_yahoo, load_bars_from_yahoo, load_prices_from_csv,
load_prices_from_csv_folder
)
__all__ = ['loader', 'load_from_yahoo', 'load_bars_from_yahoo']
__all__ = ['loader', 'load_from_yahoo', 'load_bars_from_yahoo',
'load_prices_from_csv', 'load_prices_from_csv_folder']
+22 -1
View File
@@ -294,7 +294,7 @@ def load_from_yahoo(indexes=None,
adjusted=True):
"""
Loads price data from Yahoo into a dataframe for each of the indicated
securities. By default, 'price' is taken from Yahoo's 'Adjusted Close',
assets. By default, 'price' is taken from Yahoo's 'Adjusted Close',
which removes the impact of splits and dividends. If the argument
'adjusted' is False, then the non-adjusted 'close' field is used instead.
@@ -367,3 +367,24 @@ def load_bars_from_yahoo(indexes=None,
for col in adj_cols:
panel[ticker][col] *= ratio_filtered
return panel
def load_prices_from_csv(filepath, identifier_col, tz='UTC'):
data = pd.read_csv(filepath, index_col=identifier_col)
data.index = pd.DatetimeIndex(data.index, tz=tz)
data.sort_index(inplace=True)
return data
def load_prices_from_csv_folder(folderpath, identifier_col, tz='UTC'):
data = None
for file in os.listdir(folderpath):
if '.csv' not in file:
continue
raw = load_prices_from_csv(os.path.join(folderpath, file),
identifier_col, tz)
if data is None:
data = raw
else:
data = pd.concat([data, raw], axis=1)
return data
+94 -1
View File
@@ -175,7 +175,8 @@ class TradingControlViolation(ZiplineError):
Raised if an order would violate a constraint set by a TradingControl.
"""
msg = """
Order for {amount} shares of {sid} violates trading constraint {constraint}.
Order for {amount} shares of {asset} at {datetime} violates trading constraint
{constraint}.
""".strip()
@@ -188,3 +189,95 @@ class IncompatibleHistoryFrequency(ZiplineError):
Requested history at frequency '{frequency}' cannot be created with data
at frequency '{data_frequency}'.
""".strip()
class MultipleSymbolsFound(ZiplineError):
"""
Raised when a symbol() call contains a symbol that changed over
time and is thus not resolvable without additional information
provided via as_of_date.
"""
msg = """
Multiple symbols with the name '{symbol}' found. Use the
as_of_date' argument to to specify when the date symbol-lookup
should be valid.
Possible options:{options}
""".strip()
class SymbolNotFound(ZiplineError):
"""
Raised when a symbol() call contains a non-existant symbol.
"""
msg = """
Symbol '{symbol}' was not found.
""".strip()
class SidNotFound(ZiplineError):
"""
Raised when a retrieve_asset() call contains a non-existent sid.
"""
msg = """
Asset with sid '{sid}' was not found.
""".strip()
class IdentifierNotFound(ZiplineError):
"""
Raised when a retrieve_asset_by_identifier() call contains a non-existent
identifier.
"""
msg = """
Asset with identifier '{identifier}' was not found.
""".strip()
class InvalidAssetType(ZiplineError):
"""
Raised when an AssetFinder tries to build an Asset with an invalid
AssetType.
"""
msg = """
AssetMetaData contained an invalid Asset type: '{asset_type}'.
""".strip()
class UpdateAssetFinderTypeError(ZiplineError):
"""
Raised when TradingEnvironment.update_asset_finder() gets an asset_finder
arg that is not of AssetFinder class.
"""
msg = """
TradingEnvironment can not set asset_finder to object of class {cls}.
""".strip()
class ConsumeAssetMetaDataError(ZiplineError):
"""
Raised when AssetMetaData.consume() is called on an invalid object.
"""
msg = """
AssetMetaData can not consume {obj}. MetaData must be a dict, a DataFrame, or
""".strip()
class NoSourceError(ZiplineError):
"""
Raised when no source is given to the pipeline
"""
msg = """
No data source given.
""".strip()
class PipelineDateError(ZiplineError):
"""
Raised when only one date is passed to the pipeline
"""
msg = """
Only one simulation date given. Please specify both the 'start' and 'end' for
the simulation, or neither. If neither is given, the start and end of the
DataSource will be used. Given start = '{start}', end = '{end}'
""".strip()
+8 -7
View File
@@ -29,7 +29,7 @@ from zipline.transforms.ta import EMA
def initialize(context):
context.security = symbol('AAPL')
context.asset = symbol('AAPL')
# Add 2 mavg transforms, one with a long window, one with a short window.
context.short_ema_trans = EMA(timeperiod=20)
@@ -49,17 +49,17 @@ def handle_data(context, data):
sell = False
if (short_ema > long_ema).all() and not context.invested:
order(context.security, 100)
order(context.asset, 100)
context.invested = True
buy = True
elif (short_ema < long_ema).all() and context.invested:
order(context.security, -100)
order(context.asset, -100)
context.invested = False
sell = True
record(AAPL=data[context.security].price,
short_ema=short_ema[context.security],
long_ema=long_ema[context.security],
record(AAPL=data[context.asset].price,
short_ema=short_ema[context.asset],
long_ema=long_ema[context.asset],
buy=buy,
sell=sell)
@@ -77,7 +77,8 @@ if __name__ == '__main__':
data = load_from_yahoo(stocks=['AAPL'], indexes={}, start=start,
end=end)
algo = TradingAlgorithm(initialize=initialize, handle_data=handle_data)
algo = TradingAlgorithm(initialize=initialize, handle_data=handle_data,
identifiers=['AAPL'])
results = algo.run(data).dropna()
fig = plt.figure()
+9 -9
View File
@@ -31,6 +31,8 @@ def initialize(context):
add_history(100, '1d', 'price')
add_history(300, '1d', 'price')
context.sym = symbol('AAPL')
context.i = 0
@@ -46,17 +48,15 @@ def handle_data(context, data):
short_mavg = history(100, '1d', 'price').mean()
long_mavg = history(300, '1d', 'price').mean()
sym = symbol('AAPL')
# Trading logic
if short_mavg[sym] > long_mavg[sym]:
if short_mavg[context.sym] > long_mavg[context.sym]:
# order_target orders as many shares as needed to
# achieve the desired number of shares.
order_target(sym, 100)
elif short_mavg[sym] < long_mavg[sym]:
order_target(sym, 0)
order_target(context.sym, 100)
elif short_mavg[context.sym] < long_mavg[context.sym]:
order_target(context.sym, 0)
# Save values for later inspection
record(AAPL=data[sym].price,
short_mavg=short_mavg[sym],
long_mavg=long_mavg[sym])
record(AAPL=data[context.sym].price,
short_mavg=short_mavg[context.sym],
long_mavg=long_mavg[context.sym])
+13 -10
View File
@@ -24,6 +24,7 @@ STOCKS = ['AMD', 'CERN', 'COST', 'DELL', 'GPS', 'INTC', 'MMM']
# http://icml.cc/2012/papers/168.pdf
def initialize(algo, eps=1, window_length=5):
algo.stocks = STOCKS
algo.sids = [algo.symbol(symbol) for symbol in algo.stocks]
algo.m = len(algo.stocks)
algo.price = {}
algo.b_t = np.ones(algo.m) / algo.m
@@ -52,11 +53,11 @@ def handle_data(algo, data):
x_tilde = np.zeros(m)
b = np.zeros(m)
# find relative moving average price for each security
for i, stock in enumerate(algo.stocks):
price = data[stock].price
# find relative moving average price for each asset
for i, sid in enumerate(algo.sids):
price = data[sid].price
# Relative mean deviation
x_tilde[i] = data[stock].mavg(algo.window_length) / price
x_tilde[i] = data[sid].mavg(algo.window_length) / price
###########################
# Inside of OLMAR (algo 2)
@@ -98,17 +99,17 @@ def rebalance_portfolio(algo, data, desired_port):
positions_value = algo.portfolio.positions_value + \
algo.portfolio.cash
for i, stock in enumerate(algo.stocks):
current_amount[i] = algo.portfolio.positions[stock].amount
prices[i] = data[stock].price
for i, sid in enumerate(algo.sids):
current_amount[i] = algo.portfolio.positions[sid].amount
prices[i] = data[sid].price
desired_amount = np.round(desired_port * positions_value / prices)
algo.last_desired_port = desired_port
diff_amount = desired_amount - current_amount
for i, stock in enumerate(algo.stocks):
algo.order(stock, diff_amount[i])
for i, sid in enumerate(algo.sids):
algo.order(sid, diff_amount[i])
def simplex_projection(v, b=1):
@@ -154,7 +155,9 @@ if __name__ == '__main__':
end = datetime(2008, 1, 1, 0, 0, 0, 0, pytz.utc)
data = load_from_yahoo(stocks=STOCKS, indexes={}, start=start, end=end)
data = data.dropna()
olmar = TradingAlgorithm(handle_data=handle_data, initialize=initialize)
olmar = TradingAlgorithm(handle_data=handle_data,
initialize=initialize,
identifiers=STOCKS)
results = olmar.run(data)
results.portfolio_value.plot()
pl.show()
+18 -12
View File
@@ -23,6 +23,7 @@ import pytz
from zipline.algorithm import TradingAlgorithm
from zipline.transforms import batch_transform
from zipline.utils.factory import load_from_yahoo
from zipline.sources.data_frame_source import DataFrameSource
@batch_transform
@@ -59,11 +60,13 @@ class Pairtrade(TradingAlgorithm):
self.window_length = window_length
self.ols_transform = ols_transform(refresh_period=self.window_length,
window_length=self.window_length)
self.PEP = self.symbol('PEP')
self.KO = self.symbol('KO')
def handle_data(self, data):
######################################################
# 1. Compute regression coefficients between PEP and KO
params = self.ols_transform.handle_data(data, 'PEP', 'KO')
params = self.ols_transform.handle_data(data, self.PEP, self.KO)
if params is None:
return
intercept, slope = params
@@ -81,7 +84,8 @@ class Pairtrade(TradingAlgorithm):
"""1. Compute the spread given slope and intercept.
2. zscore the spread.
"""
spread = (data['PEP'].price - (slope * data['KO'].price + intercept))
spread = (data[self.PEP].price -
(slope * data[self.KO].price + intercept))
self.spreads.append(spread)
spread_wind = self.spreads[-self.window_length:]
zscore = (spread - np.mean(spread_wind)) / np.std(spread_wind)
@@ -91,12 +95,12 @@ class Pairtrade(TradingAlgorithm):
"""Buy spread if zscore is > 2, sell if zscore < .5.
"""
if zscore >= 2.0 and not self.invested:
self.order('PEP', int(100 / data['PEP'].price))
self.order('KO', -int(100 / data['KO'].price))
self.order(self.PEP, int(100 / data[self.PEP].price))
self.order(self.KO, -int(100 / data[self.KO].price))
self.invested = True
elif zscore <= -2.0 and not self.invested:
self.order('PEP', -int(100 / data['PEP'].price))
self.order('KO', int(100 / data['KO'].price))
self.order(self.PEP, -int(100 / data[self.PEP].price))
self.order(self.KO, int(100 / data[self.KO].price))
self.invested = True
elif abs(zscore) < .5 and self.invested:
self.sell_spread()
@@ -107,23 +111,25 @@ class Pairtrade(TradingAlgorithm):
decrease exposure, regardless of position long/short.
buy for a short position, sell for a long.
"""
ko_amount = self.portfolio.positions['KO'].amount
self.order('KO', -1 * ko_amount)
pep_amount = self.portfolio.positions['PEP'].amount
self.order('PEP', -1 * pep_amount)
ko_amount = self.portfolio.positions[self.KO].amount
self.order(self.KO, -1 * ko_amount)
pep_amount = self.portfolio.positions[self.PEP].amount
self.order(self.PEP, -1 * pep_amount)
if __name__ == '__main__':
start = datetime(2000, 1, 1, 0, 0, 0, 0, pytz.utc)
end = datetime(2002, 1, 1, 0, 0, 0, 0, pytz.utc)
data = load_from_yahoo(stocks=['PEP', 'KO'], indexes={},
start=start, end=end)
source = DataFrameSource(data)
pairtrade = Pairtrade()
results = pairtrade.run(data)
results = pairtrade.run(source)
data['spreads'] = np.nan
ax1 = plt.subplot(211)
data[['PEP', 'KO']].plot(ax=ax1)
# TODO Bugged - indices are out of bounds
# data[[pairtrade.PEPsid, pairtrade.KOsid]].plot(ax=ax1)
plt.ylabel('price')
plt.setp(ax1.get_xticklabels(), visible=False)
+5 -3
View File
@@ -19,15 +19,16 @@ import pytz
from zipline import TradingAlgorithm
from zipline.utils.factory import load_from_yahoo
from zipline.api import order
from zipline.api import order, symbol
def initialize(context):
context.test = 10
context.aapl = symbol('AAPL')
def handle_date(context, data):
order('AAPL', 10)
order(context.aapl, 10)
print(context.test)
@@ -39,7 +40,8 @@ if __name__ == '__main__':
end=end)
data = data.dropna()
algo = TradingAlgorithm(initialize=initialize,
handle_data=handle_date)
handle_data=handle_date,
identifiers=['AAPL'])
results = algo.run(data)
results.portfolio_value.plot()
pl.show()
+61 -36
View File
@@ -37,7 +37,7 @@ class TradingControl(with_metaclass(abc.ABCMeta)):
@abc.abstractmethod
def validate(self,
sid,
asset,
amount,
portfolio,
algo_datetime,
@@ -46,21 +46,22 @@ class TradingControl(with_metaclass(abc.ABCMeta)):
Before any order is executed by TradingAlgorithm, this method should be
called *exactly once* on each registered TradingControl object.
If the specified sid and amount do not violate this TradingControl's
If the specified asset and amount do not violate this TradingControl's
restraint given the information in `portfolio`, this method should
return None and have no externally-visible side-effects.
If the desired order violates this TradingControl's contraint, this
method should call self.fail(sid, amount).
method should call self.fail(asset, amount).
"""
raise NotImplementedError
def fail(self, sid, amount):
def fail(self, asset, amount, datetime):
"""
Raise a TradingControlViolation with information about the failure.
"""
raise TradingControlViolation(sid=sid,
raise TradingControlViolation(asset=asset,
amount=amount,
datetime=datetime,
constraint=repr(self))
def __repr__(self):
@@ -82,7 +83,7 @@ class MaxOrderCount(TradingControl):
self.current_date = None
def validate(self,
sid,
asset,
amount,
_portfolio,
algo_datetime,
@@ -98,13 +99,13 @@ class MaxOrderCount(TradingControl):
self.current_date = algo_date
if self.orders_placed >= self.max_count:
self.fail(sid, amount)
self.fail(asset, amount, algo_datetime)
self.orders_placed += 1
class RestrictedListOrder(TradingControl):
"""
TradingControl representing a restricted list of securities that
TradingControl representing a restricted list of assets that
cannot be ordered by the algorithm.
"""
@@ -119,30 +120,30 @@ class RestrictedListOrder(TradingControl):
self.restricted_list = restricted_list
def validate(self,
sid,
asset,
amount,
_portfolio,
_algo_datetime,
_algo_current_data):
"""
Fail if the sid is in the restricted_list.
Fail if the asset is in the restricted_list.
"""
if sid in self.restricted_list:
self.fail(sid, amount)
if asset in self.restricted_list:
self.fail(asset, amount, _algo_datetime)
class MaxOrderSize(TradingControl):
"""
TradingControl representing a limit on the magnitude of any single order
placed with the given security. Can be specified by share or by dollar
placed with the given asset. Can be specified by share or by dollar
value.
"""
def __init__(self, sid=None, max_shares=None, max_notional=None):
super(MaxOrderSize, self).__init__(sid=sid,
def __init__(self, asset=None, max_shares=None, max_notional=None):
super(MaxOrderSize, self).__init__(asset=asset,
max_shares=max_shares,
max_notional=max_notional)
self.sid = sid
self.asset = asset
self.max_shares = max_shares
self.max_notional = max_notional
@@ -162,7 +163,7 @@ class MaxOrderSize(TradingControl):
)
def validate(self,
sid,
asset,
amount,
portfolio,
_algo_datetime,
@@ -172,33 +173,33 @@ class MaxOrderSize(TradingControl):
or self.max_notional.
"""
if self.sid is not None and self.sid != sid:
if self.asset is not None and self.asset != asset:
return
if self.max_shares is not None and abs(amount) > self.max_shares:
self.fail(sid, amount)
self.fail(asset, amount, _algo_datetime)
current_sid_price = algo_current_data[sid].price
order_value = amount * current_sid_price
current_asset_price = algo_current_data[asset].price
order_value = amount * current_asset_price
too_much_value = (self.max_notional is not None and
abs(order_value) > self.max_notional)
if too_much_value:
self.fail(sid, amount)
self.fail(asset, amount, _algo_datetime)
class MaxPositionSize(TradingControl):
"""
TradingControl representing a limit on the maximum position size that can
be held by an algo for a given security.
be held by an algo for a given asset.
"""
def __init__(self, sid=None, max_shares=None, max_notional=None):
super(MaxPositionSize, self).__init__(sid=sid,
def __init__(self, asset=None, max_shares=None, max_notional=None):
super(MaxPositionSize, self).__init__(asset=asset,
max_shares=max_shares,
max_notional=max_notional)
self.sid = sid
self.asset = asset
self.max_shares = max_shares
self.max_notional = max_notional
@@ -218,7 +219,7 @@ class MaxPositionSize(TradingControl):
)
def validate(self,
sid,
asset,
amount,
portfolio,
algo_datetime,
@@ -229,25 +230,25 @@ class MaxPositionSize(TradingControl):
self.max_notional.
"""
if self.sid is not None and self.sid != sid:
if self.asset is not None and self.asset != asset:
return
current_share_count = portfolio.positions[sid].amount
current_share_count = portfolio.positions[asset].amount
shares_post_order = current_share_count + amount
too_many_shares = (self.max_shares is not None and
abs(shares_post_order) > self.max_shares)
if too_many_shares:
self.fail(sid, amount)
self.fail(asset, amount, algo_datetime)
current_price = algo_current_data[sid].price
current_price = algo_current_data[asset].price
value_post_order = shares_post_order * current_price
too_much_value = (self.max_notional is not None and
abs(value_post_order) > self.max_notional)
if too_much_value:
self.fail(sid, amount)
self.fail(asset, amount, algo_datetime)
class LongOnly(TradingControl):
@@ -256,17 +257,41 @@ class LongOnly(TradingControl):
"""
def validate(self,
sid,
asset,
amount,
portfolio,
_algo_datetime,
_algo_current_data):
"""
Fail if we would hold negative shares of sid after completing this
Fail if we would hold negative shares of asset after completing this
order.
"""
if portfolio.positions[sid].amount + amount < 0:
self.fail(sid, amount)
if portfolio.positions[asset].amount + amount < 0:
self.fail(asset, amount, _algo_datetime)
class AssetDateBounds(TradingControl):
"""
TradingControl representing a prohibition against ordering an asset before
its start_date, or after its end_date.
"""
def validate(self,
asset,
amount,
portfolio,
algo_datetime,
algo_current_data):
"""
Fail if the algo has passed this Asset's end_date, or before the
Asset's start date.
"""
# Fail if the algo is before this Asset's start_date
if asset.start_date and (algo_datetime < asset.start_date):
self.fail(asset, amount, algo_datetime)
# Fail if the algo has passed this Asset's end_date
if asset.end_date and (algo_datetime >= asset.end_date):
self.fail(asset, amount, algo_datetime)
class AccountControl(with_metaclass(abc.ABCMeta)):
+41 -5
View File
@@ -31,7 +31,7 @@ omitted).
| | end of the period |
+---------------+------------------------------------------------------+
| cash_flow | the cash flow in the period (negative means spent) |
| | from buying and selling securities in the period. |
| | from buying and selling assets in the period. |
| | Includes dividend payments in the period as well. |
+---------------+------------------------------------------------------+
| starting_value| the total market value of the positions held at the |
@@ -75,6 +75,9 @@ import logbook
import numpy as np
from zipline.finance.trading import TradingEnvironment
from zipline.assets import Future
try:
# optional cython based OrderedDict
from cyordereddict import OrderedDict
@@ -128,6 +131,7 @@ class PerformancePeriod(object):
self.period_close = period_close
self.ending_value = 0.0
self.ending_exposure = 0.0
self.period_cash_flow = 0.0
self.pnl = 0.0
@@ -160,6 +164,7 @@ class PerformancePeriod(object):
def rollover(self):
self.starting_value = self.ending_value
self.starting_exposure = self.ending_exposure
self.starting_cash = self.ending_cash
self.period_cash_flow = 0.0
self.pnl = 0.0
@@ -187,6 +192,7 @@ class PerformancePeriod(object):
def calculate_performance(self):
self.ending_value = self.calculate_positions_value()
self.ending_exposure = self.calculate_positions_exposure()
total_at_start = self.starting_cash + self.starting_value
self.ending_cash = self.starting_cash + self.period_cash_flow
@@ -215,7 +221,12 @@ class PerformancePeriod(object):
self.orders_by_id[order.id] = order
def handle_execution(self, txn):
self.period_cash_flow -= txn.price * txn.amount
asset = TradingEnvironment.instance().asset_finder.\
retrieve_asset(txn.sid)
# Futures experience no cash flow on transactions
if not isinstance(asset, Future):
self.period_cash_flow -= txn.price * txn.amount
if self.keep_transactions:
try:
@@ -232,6 +243,10 @@ class PerformancePeriod(object):
def position_amounts(self):
return self.position_tracker.position_amounts
@position_proxy
def calculate_positions_exposure(self):
raise ProxyError()
@position_proxy
def calculate_positions_value(self):
raise ProxyError()
@@ -244,6 +259,10 @@ class PerformancePeriod(object):
def _long_exposure(self):
raise ProxyError()
@position_proxy
def _long_value(self):
raise ProxyError()
@position_proxy
def _shorts_count(self):
raise ProxyError()
@@ -252,19 +271,29 @@ class PerformancePeriod(object):
def _short_exposure(self):
raise ProxyError()
@position_proxy
def _short_value(self):
raise ProxyError()
@position_proxy
def _gross_exposure(self):
raise ProxyError()
@position_proxy
def _gross_value(self):
raise ProxyError()
@position_proxy
def _net_exposure(self):
raise ProxyError()
@position_proxy
def _net_value(self):
raise ProxyError()
@property
def _net_liquidation_value(self):
return self.ending_cash + \
self._long_exposure() + \
self._short_exposure()
return self.ending_cash + self._long_value() + self._short_value()
def _gross_leverage(self):
net_liq = self._net_liquidation_value
@@ -283,10 +312,12 @@ class PerformancePeriod(object):
def __core_dict(self):
rval = {
'ending_value': self.ending_value,
'ending_exposure': self.ending_exposure,
# this field is renamed to capital_used for backward
# compatibility.
'capital_used': self.period_cash_flow,
'starting_value': self.starting_value,
'starting_exposure': self.starting_exposure,
'starting_cash': self.starting_cash,
'ending_cash': self.ending_cash,
'portfolio_value': self.ending_cash + self.ending_value,
@@ -298,6 +329,8 @@ class PerformancePeriod(object):
'net_leverage': self._net_leverage(),
'short_exposure': self._short_exposure(),
'long_exposure': self._long_exposure(),
'short_value': self._short_value(),
'long_value': self._long_value(),
'longs_count': self._longs_count(),
'shorts_count': self._shorts_count()
}
@@ -372,6 +405,7 @@ class PerformancePeriod(object):
portfolio.start_date = self.period_open
portfolio.positions = self.get_positions()
portfolio.positions_value = self.ending_value
portfolio.positions_exposure = self.ending_exposure
return portfolio
def as_account(self):
@@ -393,6 +427,8 @@ class PerformancePeriod(object):
self.ending_cash + self.ending_value)
account.total_positions_value = \
getattr(self, 'total_positions_value', self.ending_value)
account.total_positions_value = \
getattr(self, 'total_positions_exposure', self.ending_exposure)
account.regt_equity = \
getattr(self, 'regt_equity', self.ending_cash)
account.regt_margin = \
+2 -3
View File
@@ -20,12 +20,11 @@ Position Tracking
+-----------------+----------------------------------------------------+
| key | value |
+=================+====================================================+
| sid | the identifier for the security held in this |
| | position. |
| sid | the sid for the asset held in this position |
+-----------------+----------------------------------------------------+
| amount | whole number of shares in the position |
+-----------------+----------------------------------------------------+
| last_sale_price | price at last sale of the security on the exchange |
| last_sale_price | price at last sale of the asset on the exchange |
+-----------------+----------------------------------------------------+
| cost_basis | the volume weighted average price paid per share |
+-----------------+----------------------------------------------------+
+120 -28
View File
@@ -1,5 +1,4 @@
from __future__ import division
from operator import mul
import logbook
import numpy as np
@@ -10,8 +9,7 @@ try:
from cyordereddict import OrderedDict
except ImportError:
from collections import OrderedDict
from six import iteritems
from six.moves import map, filter
from six import iteritems, itervalues
from zipline.finance.slippage import Transaction
from zipline.utils.serialization_utils import (
@@ -19,6 +17,10 @@ from zipline.utils.serialization_utils import (
)
import zipline.protocol as zp
from zipline.assets import (
Equity, Future
)
from zipline.finance.trading import with_environment
from . position import positiondict
log = logbook.Logger('Performance')
@@ -32,26 +34,69 @@ class PositionTracker(object):
# Arrays for quick calculations of positions value
self._position_amounts = OrderedDict()
self._position_last_sale_prices = OrderedDict()
self._position_value_multipliers = OrderedDict()
self._position_exposure_multipliers = OrderedDict()
self._unpaid_dividends = pd.DataFrame(
columns=zp.DIVIDEND_PAYMENT_FIELDS,
)
self._positions_store = zp.Positions()
# Cached for fast property calculation
self._position_values = None
self._position_exposures = None
def _invalidate_cache(self):
self._position_values = None
self._position_exposures = None
@with_environment()
def _retrieve_asset(self, sid, env=None):
return env.asset_finder.retrieve_asset(sid)
def _update_multipliers(self, sid):
try:
self._position_value_multipliers[sid]
self._position_exposure_multipliers[sid]
except KeyError:
# Collect the value multipliers from applicable sids
asset = self._retrieve_asset(sid)
if isinstance(asset, Equity):
self._position_value_multipliers[sid] = 1
self._position_exposure_multipliers[sid] = 1
if isinstance(asset, Future):
self._position_value_multipliers[sid] = 0
self._position_exposure_multipliers[sid] = \
asset.contract_multiplier
def update_last_sale(self, event):
# NOTE, PerformanceTracker already vetted as TRADE type
sid = event.sid
if sid not in self.positions:
return
return 0
price = event.price
if not checknull(price):
pos = self.positions[sid]
pos.last_sale_date = event.dt
pos.last_sale_price = price
self._position_last_sale_prices[sid] = price
self._position_values = None # invalidate cache
sid = event.sid
price = event.price
if checknull(price):
return 0
pos = self.positions[sid]
old_price = pos.last_sale_price
pos.last_sale_date = event.dt
pos.last_sale_price = price
self._position_last_sale_prices[sid] = price
self._invalidate_cache()
asset = self._retrieve_asset(sid)
if asset is None:
return 0
# Calculate cash adjustment on futures
cash_adjustment = 0
if isinstance(asset, Future):
price_change = price - old_price
cash_adjustment = \
price_change * asset.contract_multiplier * pos.amount
return cash_adjustment
def update_positions(self, positions):
# update positions in batch
@@ -59,8 +104,8 @@ class PositionTracker(object):
for sid, pos in iteritems(positions):
self._position_amounts[sid] = pos.amount
self._position_last_sale_prices[sid] = pos.last_sale_price
# Invalidate cache.
self._position_values = None # invalidate cache
self._update_multipliers(sid)
self._invalidate_cache()
def update_position(self, sid, amount=None, last_sale_price=None,
last_sale_date=None, cost_basis=None):
@@ -70,6 +115,7 @@ class PositionTracker(object):
pos.amount = amount
self._position_amounts[sid] = amount
self._position_values = None # invalidate cache
self._update_multipliers(sid=sid)
if last_sale_price is not None:
pos.last_sale_price = last_sale_price
self._position_last_sale_prices[sid] = last_sale_price
@@ -82,13 +128,13 @@ class PositionTracker(object):
def execute_transaction(self, txn):
# Update Position
# ----------------
sid = txn.sid
position = self.positions[sid]
position.update(txn)
self._position_amounts[sid] = position.amount
self._position_last_sale_prices[sid] = position.last_sale_price
self._position_values = None # invalidate cache
self._update_multipliers(sid)
self._invalidate_cache()
def handle_commission(self, commission):
# Adjust the cost basis of the stock if we own it
@@ -96,8 +142,6 @@ class PositionTracker(object):
self.positions[commission.sid].\
adjust_commission_cost_basis(commission)
_position_values = None
@property
def position_values(self):
"""
@@ -105,33 +149,75 @@ class PositionTracker(object):
self._position_last_sale_prices is changed.
"""
if self._position_values is None:
vals = list(map(mul, self._position_amounts.values(),
self._position_last_sale_prices.values()))
self._position_values = vals
iter_amount_price_multiplier = zip(
itervalues(self._position_amounts),
itervalues(self._position_last_sale_prices),
itervalues(self._position_value_multipliers),
)
self._position_values = [
price * amount * multiplier for
price, amount, multiplier in iter_amount_price_multiplier
]
return self._position_values
@property
def position_exposures(self):
"""
Invalidate any time self._position_amounts or
self._position_last_sale_prices is changed.
"""
if self._position_exposures is None:
iter_amount_price_multiplier = zip(
itervalues(self._position_amounts),
itervalues(self._position_last_sale_prices),
itervalues(self._position_exposure_multipliers),
)
self._position_exposures = [
price * amount * multiplier for
price, amount, multiplier in iter_amount_price_multiplier
]
return self._position_exposures
def calculate_positions_value(self):
if len(self.position_values) == 0:
return np.float64(0)
return sum(self.position_values)
def calculate_positions_exposure(self):
if len(self.position_exposures) == 0:
return np.float64(0)
return sum(self.position_exposures)
def _longs_count(self):
return sum(map(lambda x: x > 0, self.position_values))
return sum(1 for i in self.position_exposures if i > 0)
def _long_exposure(self):
return sum(filter(lambda x: x > 0, self.position_values))
return sum(i for i in self.position_exposures if i > 0)
def _long_value(self):
return sum(i for i in self.position_values if i > 0)
def _shorts_count(self):
return sum(map(lambda x: x < 0, self.position_values))
return sum(1 for i in self.position_exposures if i < 0)
def _short_exposure(self):
return sum(filter(lambda x: x < 0, self.position_values))
return sum(i for i in self.position_exposures if i < 0)
def _short_value(self):
return sum(i for i in self.position_values if i < 0)
def _gross_exposure(self):
return self._long_exposure() + abs(self._short_exposure())
def _gross_value(self):
return self._long_value() + abs(self._short_value())
def _net_exposure(self):
return self.calculate_positions_exposure()
def _net_value(self):
return self.calculate_positions_value()
def handle_split(self, split):
@@ -143,7 +229,8 @@ class PositionTracker(object):
self._position_amounts[split.sid] = position.amount
self._position_last_sale_prices[split.sid] = \
position.last_sale_price
self._position_values = None # invalidate cache
self._update_multipliers(split.sid)
self._invalidate_cache()
return leftover_cash
def _maybe_earn_dividend(self, dividend):
@@ -204,15 +291,17 @@ class PositionTracker(object):
stock = row['payment_sid']
share_count = row['share_count']
# note we create a Position for stock dividend if we don't
# already own the security
# already own the asset
position = self.positions[stock]
position.amount += share_count
self._position_amounts[stock] = position.amount
self._position_last_sale_prices[stock] = position.last_sale_price
self._update_multipliers(stock)
self._invalidate_cache()
# Add cash equal to the net cash payed from all dividends. Note that
# "negative cash" is effectively paid if we're short a security,
# "negative cash" is effectively paid if we're short an asset,
# representing the fact that we're required to reimburse the owner of
# the stock for any dividends paid while borrowing.
net_cash_payment = payments['cash_amount'].fillna(0).sum()
@@ -290,5 +379,8 @@ class PositionTracker(object):
# Arrays for quick calculations of positions value
self._position_amounts = OrderedDict()
self._position_last_sale_prices = OrderedDict()
self._position_value_multipliers = OrderedDict()
self._position_exposure_multipliers = OrderedDict()
self._invalidate_cache()
self.update_positions(state['positions'])
+22 -28
View File
@@ -68,10 +68,9 @@ import pandas as pd
from pandas.tseries.tools import normalize_date
import zipline.finance.risk as risk
from zipline.finance import trading
from zipline.finance.trading import TradingEnvironment
from . period import PerformancePeriod
from zipline.finance.trading import with_environment
from zipline.utils.serialization_utils import (
VERSION_LABEL
)
@@ -84,26 +83,23 @@ class PerformanceTracker(object):
"""
Tracks the performance of the algorithm.
"""
@with_environment()
def __init__(self, sim_params, env=None):
def __init__(self, sim_params):
self.sim_params = sim_params
env = TradingEnvironment.instance()
self.period_start = self.sim_params.period_start
self.period_end = self.sim_params.period_end
self.last_close = self.sim_params.last_close
first_open = self.sim_params.first_open.tz_convert(
trading.environment.exchange_tz)
first_open = self.sim_params.first_open.tz_convert(env.exchange_tz)
self.day = pd.Timestamp(datetime(first_open.year, first_open.month,
first_open.day), tz='UTC')
self.market_open, self.market_close = \
trading.environment.get_open_and_close(self.day)
self.market_open, self.market_close = env.get_open_and_close(self.day)
self.total_days = self.sim_params.days_in_period
self.capital_base = self.sim_params.capital_base
self.emission_rate = sim_params.emission_rate
all_trading_days = trading.environment.trading_days
all_trading_days = env.trading_days
mask = ((all_trading_days >= normalize_date(self.period_start)) &
(all_trading_days <= normalize_date(self.period_end)))
@@ -287,10 +283,13 @@ class PerformanceTracker(object):
return _dict
def process_trade(self, event):
self.position_tracker.update_last_sale(event)
# update last sale, and pay out a cash adjustment
cash_adjustment = self.position_tracker.update_last_sale(event)
if cash_adjustment != 0:
for perf_period in self.perf_periods:
perf_period.handle_cash_payment(cash_adjustment)
def process_transaction(self, event):
self.txn_count += 1
self.position_tracker.execute_transaction(event)
for perf_period in self.perf_periods:
@@ -342,7 +341,7 @@ class PerformanceTracker(object):
if txn:
self.process_transaction(txn)
def check_upcoming_dividends(self, midnight_of_date_that_just_ended):
def check_upcoming_dividends(self, next_trading_day):
"""
Check if we currently own any stocks with dividends whose ex_date is
the next trading day. Track how much we should be payed on those
@@ -357,17 +356,6 @@ class PerformanceTracker(object):
# period, so bail.
return
next_trading_day_idx = self.trading_days.get_loc(
midnight_of_date_that_just_ended,
) + 1
if next_trading_day_idx < len(self.trading_days):
next_trading_day = self.trading_days[next_trading_day_idx]
else:
# Bail if the next trading day is outside our trading range, since
# we won't simulate the next day.
return
# Dividends whose ex_date is the next trading day. We need to check if
# we own any of these stocks so we know to pay them out when the pay
# date comes.
@@ -412,7 +400,10 @@ class PerformanceTracker(object):
# if this is the close, save the returns objects for cumulative risk
# calculations and update dividends for the next day.
if dt == self.market_close:
self.check_upcoming_dividends(todays_date)
next_trading_day = TradingEnvironment.instance().\
next_trading_day(todays_date)
if next_trading_day:
self.check_upcoming_dividends(next_trading_day)
def handle_intraday_market_close(self, new_mkt_open, new_mkt_close):
"""
@@ -454,16 +445,19 @@ class PerformanceTracker(object):
return daily_update
# move the market day markers forward
env = TradingEnvironment.instance()
self.market_open, self.market_close = \
trading.environment.next_open_and_close(self.day)
self.day = trading.environment.next_trading_day(self.day)
env.next_open_and_close(self.day)
self.day = env.next_trading_day(self.day)
# Roll over positions to current day.
self.todays_performance.rollover()
self.todays_performance.period_open = self.market_open
self.todays_performance.period_close = self.market_close
self.check_upcoming_dividends(completed_date)
next_trading_day = env.next_trading_day(completed_date)
if next_trading_day:
self.check_upcoming_dividends(next_trading_day)
return daily_update
+50
View File
@@ -23,6 +23,8 @@ import numpy as np
from zipline.data.loader import load_market_data
from zipline.utils import tradingcalendar
from zipline.assets import AssetFinder
from zipline.errors import UpdateAssetFinderTypeError
log = logbook.Logger('Trading')
@@ -134,6 +136,8 @@ class TradingEnvironment(object):
self.exchange_tz = exchange_tz
self.asset_finder = AssetFinder(trading_calendar=env_trading_calendar)
def __enter__(self, *args, **kwargs):
global environment
self.prev_environment = environment
@@ -149,6 +153,52 @@ class TradingEnvironment(object):
# stack.
return False
def update_asset_finder(self,
clear_metadata=False,
asset_finder=None,
asset_metadata=None,
identifiers=None):
"""
Updates the AssetFinder using the provided asset metadata and
identifiers.
If clear_metadata is True, all metadata and assets held in the
asset_finder will be erased before new metadata is provided.
If asset_finder is provided, the existing asset_finder will be replaced
outright with the new asset_finder.
If asset_metadata is provided, the existing metadata will be cleared
and replaced with the provided metadata.
All identifiers will be inserted in the asset metadata if they are not
already present.
:param clear_metadata: A boolean
:param asset_finder: An AssetFinder object to replace the environment's
existing asset_finder
:param asset_metadata: A dict, DataFrame, or readable object
:param identifiers: A list of identifiers to be inserted
:return:
"""
populate = False
if clear_metadata:
self.asset_finder.clear_metadata()
populate = True
if asset_finder is not None:
if not isinstance(asset_finder, AssetFinder):
raise UpdateAssetFinderTypeError(cls=asset_finder.__class__)
self.asset_finder = asset_finder
if asset_metadata is not None:
self.asset_finder.clear_metadata()
self.asset_finder.consume_metadata(asset_metadata)
populate = True
if identifiers is not None:
self.asset_finder.consume_identifiers(identifiers)
populate = True
if populate:
self.asset_finder.populate_cache()
def normalize_date(self, test_date):
test_date = pd.Timestamp(test_date, tz='UTC')
return pd.tseries.tools.normalize_date(test_date)
+1 -1
View File
@@ -217,7 +217,7 @@ class HistoryContainer(object):
history_specs (dict[Frequency:HistorySpec]): The starting history
specs that this container should be able to service.
initial_sids (set[Security or Int]): The starting sids to watch.
initial_sids (set[Asset or Int]): The starting sids to watch.
initial_dt (datetime): The datetime to start collecting history from.
+7 -3
View File
@@ -58,7 +58,12 @@ DIVIDEND_FIELDS = [
'sid',
]
# Expected fields/index values for a dividend payment Series.
DIVIDEND_PAYMENT_FIELDS = ['id', 'payment_sid', 'cash_amount', 'share_count']
DIVIDEND_PAYMENT_FIELDS = [
'id',
'payment_sid',
'cash_amount',
'share_count',
]
def dividend_payment(data=None):
@@ -73,7 +78,7 @@ def dividend_payment(data=None):
payment.
Additionally, if @data is non-empty, either data['cash_amount'] should be
nonzero or data['payment_sid'] should be a security identifier and
nonzero or data['payment_sid'] should be an asset identifier and
data['share_count'] should be nonzero.
The returned Series is given its id value as a name so that concatenating
@@ -289,7 +294,6 @@ class SIDData(object):
def __init__(self, sid, initial_values=None):
self._sid = sid
self._freqstr = None
# To check if we have data, we use the __len__ which depends on the
+30 -10
View File
@@ -22,6 +22,7 @@ import pandas as pd
from zipline.gens.utils import hash_args
from zipline.sources.data_source import DataSource
from zipline.finance.trading import with_environment
class DataFrameSource(DataSource):
@@ -36,14 +37,23 @@ class DataFrameSource(DataSource):
Bars where the price is nan are filtered out.
"""
def __init__(self, data, **kwargs):
@with_environment()
def __init__(self, data, env=None, **kwargs):
assert isinstance(data.index, pd.tseries.index.DatetimeIndex)
self.data = data
self.data = data.fillna(method='ffill')
# Unpack config dictionary with default values.
self.sids = kwargs.get('sids', data.columns)
self.start = kwargs.get('start', data.index[0])
self.end = kwargs.get('end', data.index[-1])
self.start = kwargs.get('start', self.data.index[0])
self.end = kwargs.get('end', self.data.index[-1])
# Remap sids based on the trading environment
self.identifiers = kwargs.get('sids', self.data.columns)
env.update_asset_finder(identifiers=self.identifiers)
self.data.columns = [
env.asset_finder.retrieve_asset_by_identifier(identifier).sid
for identifier in self.data.columns
]
self.sids = self.data.columns
# Hash_value for downstream sorting.
self.arg_string = hash_args(data, **kwargs)
@@ -105,14 +115,24 @@ class DataPanelSource(DataSource):
Bars where the price is nan are filtered out.
"""
def __init__(self, data, **kwargs):
@with_environment()
def __init__(self, data, env=None, **kwargs):
assert isinstance(data.major_axis, pd.tseries.index.DatetimeIndex)
self.data = data
self.data = data.fillna(value={'volume': 0})
self.data = self.data.fillna(method='ffill')
# Unpack config dictionary with default values.
self.sids = kwargs.get('sids', data.items)
self.start = kwargs.get('start', data.major_axis[0])
self.end = kwargs.get('end', data.major_axis[-1])
self.start = kwargs.get('start', self.data.major_axis[0])
self.end = kwargs.get('end', self.data.major_axis[-1])
# Remap sids based on the trading environment
self.identifiers = kwargs.get('sids', self.data.items)
env.update_asset_finder(identifiers=self.identifiers)
self.data.items = [
env.asset_finder.retrieve_asset_by_identifier(identifier).sid
for identifier in self.data.items
]
self.sids = self.data.items
# Hash_value for downstream sorting.
self.arg_string = hash_args(data, **kwargs)
+2
View File
@@ -23,6 +23,7 @@ import pandas as pd
from zipline.sources.data_source import DataSource
from zipline.utils import tradingcalendar as calendar_nyse
from zipline.gens.utils import hash_args
from zipline.finance import trading
class RandomWalkSource(DataSource):
@@ -92,6 +93,7 @@ class RandomWalkSource(DataSource):
self.sd = sd
self.sids = self.start_prices.keys()
trading.environment.update_asset_finder(identifiers=self.sids)
self.open_and_closes = \
calendar.open_and_closes[self.start:self.end]
+24 -7
View File
@@ -22,7 +22,6 @@ import pytz
from six.moves import filter
from datetime import datetime, timedelta
import itertools
import numpy as np
from six.moves import range
@@ -112,8 +111,8 @@ class SpecificEquityTrades(object):
delta : timedelta between internal events
filter : filter to remove the sids
"""
def __init__(self, *args, **kwargs):
@with_environment()
def __init__(self, env=None, *args, **kwargs):
# We shouldn't get any positional arguments.
assert len(args) == 0
@@ -125,9 +124,7 @@ class SpecificEquityTrades(object):
# This isn't really clean and ultimately I think this
# class should serve a single purpose (either take an
# event_list or autocreate events).
self.sids = kwargs.get(
'sids',
np.unique([event.sid for event in self.event_list]).tolist())
self.count = kwargs.get('count', len(self.event_list))
self.start = kwargs.get('start', self.event_list[0].dt)
self.end = kwargs.get('end', self.event_list[-1].dt)
self.delta = kwargs.get(
@@ -135,9 +132,22 @@ class SpecificEquityTrades(object):
self.event_list[1].dt - self.event_list[0].dt)
self.concurrent = kwargs.get('concurrent', False)
self.identifiers = kwargs.get(
'sids',
set(event.sid for event in self.event_list)
)
env.update_asset_finder(identifiers=self.identifiers)
self.sids = [
env.asset_finder.retrieve_asset_by_identifier(identifier).sid
for identifier in self.identifiers
]
for event in self.event_list:
event.sid = env.asset_finder.\
retrieve_asset_by_identifier(event.sid).sid
else:
# Unpack config dictionary with default values.
self.sids = kwargs.get('sids', [1, 2])
self.count = kwargs.get('count', 500)
self.start = kwargs.get(
'start',
datetime(2008, 6, 6, 15, tzinfo=pytz.utc))
@@ -149,6 +159,13 @@ class SpecificEquityTrades(object):
timedelta(minutes=1))
self.concurrent = kwargs.get('concurrent', False)
self.identifiers = kwargs.get('sids', [1, 2])
env.update_asset_finder(identifiers=self.identifiers)
self.sids = [
env.asset_finder.retrieve_asset_by_identifier(identifier).sid
for identifier in self.identifiers
]
# Hash_value for downstream sorting.
self.arg_string = hash_args(*args, **kwargs)
+107 -63
View File
@@ -85,14 +85,17 @@ from zipline.api import (
order,
set_slippage,
record,
sid,
)
from zipline.errors import UnsupportedOrderParameters
from zipline.assets import Future, Equity
from zipline.finance.execution import (
LimitOrder,
MarketOrder,
StopLimitOrder,
StopOrder,
)
from zipline.finance.controls import AssetDateBounds
from zipline.transforms import BatchTransform, batch_transform
@@ -111,14 +114,14 @@ class TestAlgorithm(TradingAlgorithm):
slippage=None,
commission=None):
self.count = order_count
self.sid = sid
self.asset = self.sid(sid)
self.amount = amount
self.incr = 0
if sid_filter:
self.sid_filter = sid_filter
else:
self.sid_filter = [self.sid]
self.sid_filter = [self.asset.sid]
if slippage is not None:
self.set_slippage(slippage)
@@ -129,7 +132,7 @@ class TestAlgorithm(TradingAlgorithm):
def handle_data(self, data):
# place an order for amount shares of sid
if self.incr < self.count:
self.order(self.sid, self.amount)
self.order(self.asset, self.amount)
self.incr += 1
@@ -141,13 +144,13 @@ class HeavyBuyAlgorithm(TradingAlgorithm):
"""
def initialize(self, sid, amount):
self.sid = sid
self.asset = self.sid(sid)
self.amount = amount
self.incr = 0
def handle_data(self, data):
# place an order for 100 shares of sid
self.order(self.sid, self.amount)
self.order(self.asset, self.amount)
self.incr += 1
@@ -177,7 +180,7 @@ class ExceptionAlgorithm(TradingAlgorithm):
def initialize(self, throw_from, sid):
self.throw_from = throw_from
self.sid = sid
self.asset = self.sid(sid)
if self.throw_from == "initialize":
raise Exception("Algo exception in initialize")
@@ -200,7 +203,7 @@ class ExceptionAlgorithm(TradingAlgorithm):
if self.throw_from == "get_sid_filter":
raise Exception("Algo exception in get_sid_filter")
else:
return [self.sid]
return [self.asset]
def set_transact_setter(self, txn_sim_callable):
pass
@@ -209,7 +212,7 @@ class ExceptionAlgorithm(TradingAlgorithm):
class DivByZeroAlgorithm(TradingAlgorithm):
def initialize(self, sid):
self.sid = sid
self.asset = self.sid(sid)
self.incr = 0
def handle_data(self, data):
@@ -222,7 +225,7 @@ class DivByZeroAlgorithm(TradingAlgorithm):
class TooMuchProcessingAlgorithm(TradingAlgorithm):
def initialize(self, sid):
self.sid = sid
self.asset = self.sid(sid)
def handle_data(self, data):
# Unless we're running on some sort of
@@ -234,7 +237,7 @@ class TooMuchProcessingAlgorithm(TradingAlgorithm):
class TimeoutAlgorithm(TradingAlgorithm):
def initialize(self, sid):
self.sid = sid
self.asset = self.sid(sid)
self.incr = 0
def handle_data(self, data):
@@ -269,7 +272,7 @@ class TestOrderAlgorithm(TradingAlgorithm):
assert self.portfolio.positions[0]['last_sale_price'] == \
data[0].price, "Orders not filled at current price."
self.incr += 1
self.order(0, 1)
self.order(self.sid(0), 1)
class TestOrderInstantAlgorithm(TradingAlgorithm):
@@ -286,7 +289,7 @@ class TestOrderInstantAlgorithm(TradingAlgorithm):
assert self.portfolio.positions[0]['last_sale_price'] == \
self.last_price, "Orders was not filled at last price."
self.incr += 2
self.order_value(0, data[0].price * 2.)
self.order_value(self.sid(0), data[0].price * 2.)
self.last_price = data[0].price
@@ -311,7 +314,9 @@ class TestOrderStyleForwardingAlgorithm(TradingAlgorithm):
assert len(self.portfolio.positions.keys()) == 0
method_to_check = getattr(self, self.method_name)
method_to_check(0, data[0].price, style=StopLimitOrder(10, 10))
method_to_check(self.sid(0),
data[0].price,
style=StopLimitOrder(10, 10))
assert len(self.blotter.open_orders[0]) == 1
result = self.blotter.open_orders[0][0]
@@ -335,7 +340,12 @@ class TestOrderValueAlgorithm(TradingAlgorithm):
assert self.portfolio.positions[0]['last_sale_price'] == \
data[0].price, "Orders not filled at current price."
self.incr += 2
self.order_value(0, data[0].price * 2.)
multiplier = 2.
if isinstance(self.sid(0), Future):
multiplier *= self.sid(0).contract_multiplier
self.order_value(self.sid(0), data[0].price * multiplier)
class TestTargetAlgorithm(TradingAlgorithm):
@@ -352,7 +362,7 @@ class TestTargetAlgorithm(TradingAlgorithm):
assert self.portfolio.positions[0]['last_sale_price'] == \
data[0].price, "Orders not filled at current price."
self.target_shares = np.random.randint(1, 30)
self.order_target(0, self.target_shares)
self.order_target(self.sid(0), self.target_shares)
class TestOrderPercentAlgorithm(TradingAlgorithm):
@@ -363,7 +373,7 @@ class TestOrderPercentAlgorithm(TradingAlgorithm):
def handle_data(self, data):
if self.target_shares == 0:
assert 0 not in self.portfolio.positions
self.order(0, 10)
self.order(self.sid(0), 10)
self.target_shares = 10
return
else:
@@ -372,10 +382,17 @@ class TestOrderPercentAlgorithm(TradingAlgorithm):
assert self.portfolio.positions[0]['last_sale_price'] == \
data[0].price, "Orders not filled at current price."
self.order_percent(0, .001)
self.target_shares += np.floor((.001 *
self.portfolio.portfolio_value) /
data[0].price)
self.order_percent(self.sid(0), .001)
if isinstance(self.sid(0), Equity):
self.target_shares += np.floor(
(.001 * self.portfolio.portfolio_value) / data[0].price
)
if isinstance(self.sid(0), Future):
self.target_shares += np.floor(
(.001 * self.portfolio.portfolio_value) /
(data[0].price * self.sid(0).contract_multiplier)
)
class TestTargetPercentAlgorithm(TradingAlgorithm):
@@ -394,7 +411,7 @@ class TestTargetPercentAlgorithm(TradingAlgorithm):
assert self.portfolio.positions[0]['last_sale_price'] == \
data[0].price, "Orders not filled at current price."
self.sale_price = data[0].price
self.order_target_percent(0, .002)
self.order_target_percent(self.sid(0), .002)
class TestTargetValueAlgorithm(TradingAlgorithm):
@@ -405,7 +422,7 @@ class TestTargetValueAlgorithm(TradingAlgorithm):
def handle_data(self, data):
if self.target_shares == 0:
assert 0 not in self.portfolio.positions
self.order(0, 10)
self.order(self.sid(0), 10)
self.target_shares = 10
return
else:
@@ -415,9 +432,15 @@ class TestTargetValueAlgorithm(TradingAlgorithm):
assert self.portfolio.positions[0]['last_sale_price'] == \
data[0].price, "Orders not filled at current price."
self.order_target_value(0, 20)
self.order_target_value(self.sid(0), 20)
self.target_shares = np.round(20 / data[0].price)
if isinstance(self.sid(0), Equity):
self.target_shares = np.round(20 / data[0].price)
if isinstance(self.sid(0), Future):
self.target_shares = np.round(
20 / (data[0].price * self.sid(0).contract_multiplier))
############################
# AccountControl Test Algos#
@@ -468,6 +491,18 @@ class SetLongOnlyAlgorithm(TradingAlgorithm):
self.set_long_only()
class SetAssetDateBoundsAlgorithm(TradingAlgorithm):
"""
Algorithm that tries to order 1 share of sid 0 on every bar and has an
AssetDateBounds() trading control in place.
"""
def initialize(self):
self.register_trading_control(AssetDateBounds())
def handle_data(algo, data):
algo.order(algo.sid(0), 1)
class TestRegisterTransformAlgorithm(TradingAlgorithm):
def initialize(self, *args, **kwargs):
self.set_slippage(FixedSlippage())
@@ -484,7 +519,7 @@ class AmbitiousStopLimitAlgorithm(TradingAlgorithm):
"""
def initialize(self, *args, **kwargs):
self.sid = kwargs.pop('sid')
self.asset = self.sid(kwargs.pop('sid'))
def handle_data(self, data):
@@ -493,42 +528,42 @@ class AmbitiousStopLimitAlgorithm(TradingAlgorithm):
########
# Buy with low limit, shouldn't trigger.
self.order(self.sid, 100, limit_price=1)
self.order(self.asset, 100, limit_price=1)
# But with high stop, shouldn't trigger
self.order(self.sid, 100, stop_price=10000000)
self.order(self.asset, 100, stop_price=10000000)
# Buy with high limit (should trigger) but also high stop (should
# prevent trigger).
self.order(self.sid, 100, limit_price=10000000, stop_price=10000000)
self.order(self.asset, 100, limit_price=10000000, stop_price=10000000)
# Buy with low stop (should trigger), but also low limit (should
# prevent trigger).
self.order(self.sid, 100, limit_price=1, stop_price=1)
self.order(self.asset, 100, limit_price=1, stop_price=1)
#########
# Sells #
#########
# Sell with high limit, shouldn't trigger.
self.order(self.sid, -100, limit_price=1000000)
self.order(self.asset, -100, limit_price=1000000)
# Sell with low stop, shouldn't trigger.
self.order(self.sid, -100, stop_price=1)
self.order(self.asset, -100, stop_price=1)
# Sell with low limit (should trigger), but also high stop (should
# prevent trigger).
self.order(self.sid, -100, limit_price=1000000, stop_price=1000000)
self.order(self.asset, -100, limit_price=1000000, stop_price=1000000)
# Sell with low limit (should trigger), but also low stop (should
# prevent trigger).
self.order(self.sid, -100, limit_price=1, stop_price=1)
self.order(self.asset, -100, limit_price=1, stop_price=1)
###################
# Rounding Checks #
###################
self.order(self.sid, 100, limit_price=.00000001)
self.order(self.sid, -100, stop_price=.00000001)
self.order(self.asset, 100, limit_price=.00000001)
self.order(self.asset, -100, stop_price=.00000001)
##########################################
@@ -695,8 +730,8 @@ class BatchTransformAlgorithm(TradingAlgorithm):
"batch transform is not updating for new kwargs"
new_data = deepcopy(data)
for sid in new_data:
new_data[sid]['arbitrary'] = 123
for sidint in new_data:
new_data[sidint]['arbitrary'] = 123
self.history_return_arbitrary_fields.append(
self.return_arbitrary_fields.handle_data(new_data))
@@ -707,8 +742,7 @@ class BatchTransformAlgorithm(TradingAlgorithm):
self.return_nan.handle_data(data))
else:
nan_data = deepcopy(data)
for sid in nan_data.iterkeys():
nan_data[sid].price = np.nan
nan_data.price = np.nan
self.history_return_nan.append(
self.return_nan.handle_data(nan_data))
@@ -810,7 +844,7 @@ class EmptyPositionsAlgorithm(TradingAlgorithm):
def handle_data(self, data):
if not self.ordered:
for s in data:
self.order(s, 100)
self.order(self.sid(s), 100)
self.ordered = True
if not self.exited:
@@ -821,7 +855,7 @@ class EmptyPositionsAlgorithm(TradingAlgorithm):
(len(amounts) == len(data.keys()))
):
for stock in self.portfolio.positions:
self.order(stock, -100)
self.order(self.sid(stock), -100)
self.exited = True
# Should be 0 when all positions are exited.
@@ -834,7 +868,7 @@ class InvalidOrderAlgorithm(TradingAlgorithm):
appropriate exceptions are raised.
"""
def initialize(self, *args, **kwargs):
self.sid = kwargs.pop('sids')[0]
self.asset = self.sid(kwargs.pop('sids')[0])
def handle_data(self, data):
from zipline.api import (
@@ -849,40 +883,48 @@ class InvalidOrderAlgorithm(TradingAlgorithm):
StopOrder(10), StopLimitOrder(10, 10)]:
with assert_raises(UnsupportedOrderParameters):
order(self.sid, 10, limit_price=10, style=style)
order(self.asset, 10, limit_price=10, style=style)
with assert_raises(UnsupportedOrderParameters):
order(self.sid, 10, stop_price=10, style=style)
order(self.asset, 10, stop_price=10, style=style)
with assert_raises(UnsupportedOrderParameters):
order_value(self.sid, 300, limit_price=10, style=style)
order_value(self.asset, 300, limit_price=10, style=style)
with assert_raises(UnsupportedOrderParameters):
order_value(self.sid, 300, stop_price=10, style=style)
order_value(self.asset, 300, stop_price=10, style=style)
with assert_raises(UnsupportedOrderParameters):
order_percent(self.sid, .1, limit_price=10, style=style)
order_percent(self.asset, .1, limit_price=10, style=style)
with assert_raises(UnsupportedOrderParameters):
order_percent(self.sid, .1, stop_price=10, style=style)
order_percent(self.asset, .1, stop_price=10, style=style)
with assert_raises(UnsupportedOrderParameters):
order_target(self.sid, 100, limit_price=10, style=style)
order_target(self.asset, 100, limit_price=10, style=style)
with assert_raises(UnsupportedOrderParameters):
order_target(self.sid, 100, stop_price=10, style=style)
order_target(self.asset, 100, stop_price=10, style=style)
with assert_raises(UnsupportedOrderParameters):
order_target_value(self.sid, 100, limit_price=10, style=style)
order_target_value(self.asset, 100,
limit_price=10,
style=style)
with assert_raises(UnsupportedOrderParameters):
order_target_value(self.sid, 100, stop_price=10, style=style)
order_target_value(self.asset, 100,
stop_price=10,
style=style)
with assert_raises(UnsupportedOrderParameters):
order_target_percent(self.sid, .2, limit_price=10, style=style)
order_target_percent(self.asset, .2,
limit_price=10,
style=style)
with assert_raises(UnsupportedOrderParameters):
order_target_percent(self.sid, .2, stop_price=10, style=style)
order_target_percent(self.asset, .2,
stop_price=10,
style=style)
##############################
@@ -913,7 +955,7 @@ def handle_data_api(context, data):
assert context.portfolio.positions[0]['last_sale_price'] == \
data[0].price, "Orders not filled at current price."
context.incr += 1
order(0, 1)
order(sid(0), 1)
record(incr=context.incr)
@@ -932,7 +974,8 @@ api_algo = """
from zipline.api import (order,
set_slippage,
FixedSlippage,
record)
record,
sid)
def initialize(context):
context.incr = 0
@@ -948,7 +991,7 @@ def handle_data(context, data):
assert context.portfolio.positions[0]['last_sale_price'] == \
data[0].price, "Orders not filled at current price."
context.incr += 1
order(0, 1)
order(sid(0), 1)
record(incr=context.incr)
"""
@@ -1009,18 +1052,19 @@ from zipline.api import (order,
order_percent,
order_target,
order_target_value,
order_target_percent)
order_target_percent,
sid)
def initialize(context):
pass
def handle_data(context, data):
order(0, 10)
order_value(0, 300)
order_percent(0, .1)
order_target(0, 100)
order_target_value(0, 100)
order_target_percent(0, .2)
order(sid(0), 10)
order_value(sid(0), 300)
order_percent(sid(0), .1)
order_target(sid(0), 100)
order_target_value(sid(0), 100)
order_target_percent(sid(0), .2)
"""
record_variables = """
+62 -9
View File
@@ -31,14 +31,15 @@ except:
PYGMENTS = False
import zipline
from zipline.errors import NoSourceError, PipelineDateError
DEFAULTS = {
'start': '2012-01-01',
'end': '2012-12-31',
'data_frequency': 'daily',
'capital_base': '10e6',
'source': 'yahoo',
'symbols': 'AAPL'
'symbols': 'AAPL',
'metadata_index': 'symbol',
'source_time_column': 'Date',
}
@@ -99,9 +100,12 @@ def parse_args(argv, ipython_mode=False):
parser.add_argument('--start', '-s')
parser.add_argument('--end', '-e')
parser.add_argument('--capital_base')
parser.add_argument('--source', choices=('yahoo',))
parser.add_argument('--source', '-d')
parser.add_argument('--source_time_column', '-t')
parser.add_argument('--symbols')
parser.add_argument('--output', '-o')
parser.add_argument('--metadata_path', '-m')
parser.add_argument('--metadata_index', '-x')
if ipython_mode:
parser.add_argument('--local_namespace', action='store_true')
@@ -153,14 +157,59 @@ def run_pipeline(print_algo=True, **kwargs):
pygments syntax coloring if pygments is found.
"""
start = pd.Timestamp(kwargs['start'], tz='UTC')
end = pd.Timestamp(kwargs['end'], tz='UTC')
start = kwargs['start']
end = kwargs['end']
# Compare against None because strings/timestamps may have been given
if start is not None:
start = pd.Timestamp(start, tz='UTC')
if end is not None:
end = pd.Timestamp(end, tz='UTC')
# Fail out if only one bound is provided
if ((start is None) or (end is None)) and (start != end):
raise PipelineDateError(start=start, end=end)
# Check if start and end are provided, and if the sim_params need to read
# a start and end from the DataSource
if start is None:
overwrite_sim_params = True
else:
overwrite_sim_params = False
symbols = kwargs['symbols'].split(',')
asset_identifier = kwargs['metadata_index']
if kwargs['source'] == 'yahoo':
# Pull asset metadata
asset_metadata = kwargs.get('asset_metadata', None)
asset_metadata_path = kwargs['metadata_path']
# Read in a CSV file, if applicable
if asset_metadata_path is not None:
if os.path.isfile(asset_metadata_path):
asset_metadata = pd.read_csv(asset_metadata_path,
index_col=asset_identifier)
source_arg = kwargs['source']
source_time_column = kwargs['source_time_column']
if source_arg is None:
raise NoSourceError()
elif source_arg == 'yahoo':
source = zipline.data.load_bars_from_yahoo(
stocks=symbols, start=start, end=end)
elif os.path.isfile(source_arg):
source = zipline.data.load_prices_from_csv(
filepath=source_arg,
identifier_col=source_time_column
)
elif os.path.isdir(source_arg):
source = zipline.data.load_prices_from_csv_folder(
folderpath=source_arg,
identifier_col=source_time_column
)
else:
raise NotImplementedError(
'Source %s not implemented.' % kwargs['source'])
@@ -188,9 +237,13 @@ def run_pipeline(print_algo=True, **kwargs):
algo = zipline.TradingAlgorithm(script=algo_text,
namespace=kwargs.get('namespace', {}),
capital_base=float(kwargs['capital_base']),
algo_filename=kwargs.get('algofile'))
algo_filename=kwargs.get('algofile'),
asset_metadata=asset_metadata,
identifiers=symbols,
start=start,
end=end)
perf = algo.run(source)
perf = algo.run(source, overwrite_sim_params=overwrite_sim_params)
output_fname = kwargs.get('output', None)
if output_fname is not None:
+3 -1
View File
@@ -120,6 +120,7 @@ def create_trade_history(sid, prices, amounts, interval, sim_params,
source_id="test_factory"):
trades = []
current = sim_params.first_open
trading.environment.update_asset_finder(identifiers=[sid])
oneday = timedelta(days=1)
use_midnight = interval >= oneday
@@ -149,7 +150,6 @@ def create_dividend(sid, payment, declared_date, ex_date, pay_date):
'type': DATASOURCE_TYPE.DIVIDEND,
'source_id': 'MockDividendSource'
})
return div
@@ -312,6 +312,8 @@ def create_test_df_source(sim_params=None, bars='daily'):
df = pd.DataFrame(x, index=index, columns=[0])
trading.environment.update_asset_finder(identifiers=[0])
return DataFrameSource(df), df
+10 -17
View File
@@ -5,6 +5,7 @@ import os.path
import pandas as pd
import pytz
import zipline
from zipline.finance.trading import with_environment
DATE_FORMAT = "%Y%m%d"
@@ -12,23 +13,16 @@ zipline_dir = os.path.dirname(zipline.__file__)
SECURITY_LISTS_DIR = os.path.join(zipline_dir, 'resources', 'security_lists')
def loopback(symbol, *args, **kwargs):
return symbol
class SecurityList(object):
def __init__(self, lookup_func, data, current_date_func):
def __init__(self, data, current_date_func):
"""
lookup_func: function that takes a string symbol and a date and
returns a Security object.
data: a nested dictionary:
knowledge_date -> lookup_date ->
{add: [symbol list], 'delete': []}, delete: [symbol list]}
current_date_func: function taking no parameters, returning
current datetime
"""
self.lookup_func = lookup_func
self.data = data
self._cache = {}
self._knowledge_dates = self.make_knowledge_dates(self.data)
@@ -74,13 +68,17 @@ class SecurityList(object):
self._cache[kd] = self._current_set
return self._current_set
def update_current(self, effective_date, symbols, change_func):
@with_environment()
def update_current(self, effective_date, symbols, change_func, env=None):
for symbol in symbols:
sid = self.lookup_func(
asset = env.asset_finder.lookup_symbol(
symbol,
as_of_date=effective_date
)
change_func(sid)
# Pass if no Asset exists for the symbol
if asset is None:
continue
change_func(asset.sid)
class SecurityListSet(object):
@@ -88,11 +86,7 @@ class SecurityListSet(object):
# list implementations.
security_list_type = SecurityList
def __init__(self, current_date_func, lookup_func=None):
if lookup_func is None:
self.lookup_func = loopback
else:
self.lookup_func = lookup_func
def __init__(self, current_date_func):
self.current_date_func = current_date_func
self._leveraged_etf = None
@@ -100,7 +94,6 @@ class SecurityListSet(object):
def leveraged_etf_list(self):
if self._leveraged_etf is None:
self._leveraged_etf = self.security_list_type(
self.lookup_func,
load_from_directory('leveraged_etf_list'),
self.current_date_func
)
+12 -6
View File
@@ -7,7 +7,7 @@ def create_test_zipline(**config):
"""
:param config: A configuration object that is a dict with:
- sid - an integer, which will be used as the security ID.
- sid - an integer, which will be used as the asset ID.
- order_count - the number of orders the test algo will place,
defaults to 100
- order_amount - the number of shares per order, defaults to 100
@@ -25,10 +25,15 @@ def create_test_zipline(**config):
:py:mod:`zipline.finance.trading`
"""
assert isinstance(config, dict)
sid_list = config.get('sid_list')
if not sid_list:
sid = config.get('sid')
sid_list = [sid]
try:
sid_list = config['sid_list']
except KeyError:
try:
sid_list = [config['sid']]
except KeyError:
raise Exception("simfactory create_test_zipline() requires "
"argument 'sid_list' or 'sid'")
concurrent_trades = config.get('concurrent_trades', False)
@@ -49,12 +54,13 @@ def create_test_zipline(**config):
test_algo = config['algorithm']
else:
test_algo = TestAlgorithm(
sid,
sid_list[0],
order_amount,
order_count,
sim_params=config.get('sim_params',
factory.create_simulation_parameters()),
slippage=config.get('slippage'),
identifiers=sid_list
)
# -------------------