From b84ac01cbfbc24e52f3f30ca70173fc03643974b Mon Sep 17 00:00:00 2001 From: jfkirk Date: Mon, 8 Jun 2015 16:29:19 -0400 Subject: [PATCH] ENH: Adds futures trading and asset management logic to TradingAlgorithm and performance classes --- zipline/algorithm.py | 157 +++++++++++----- zipline/data/__init__.py | 8 +- zipline/data/loader.py | 23 ++- zipline/errors.py | 95 +++++++++- zipline/examples/dual_ema_talib.py | 15 +- zipline/examples/dual_moving_average.py | 18 +- zipline/examples/olmar.py | 23 +-- zipline/examples/pairtrade.py | 30 ++-- zipline/examples/quantopian_buy_apple.py | 8 +- zipline/finance/controls.py | 97 ++++++---- zipline/finance/performance/period.py | 46 ++++- zipline/finance/performance/position.py | 5 +- .../finance/performance/position_tracker.py | 148 ++++++++++++--- zipline/finance/performance/tracker.py | 50 +++--- zipline/finance/trading.py | 50 ++++++ zipline/history/history_container.py | 2 +- zipline/protocol.py | 10 +- zipline/sources/data_frame_source.py | 40 +++-- zipline/sources/simulated.py | 2 + zipline/sources/test_source.py | 31 +++- zipline/test_algorithms.py | 170 +++++++++++------- zipline/utils/cli.py | 71 +++++++- zipline/utils/factory.py | 4 +- zipline/utils/security_list.py | 27 ++- zipline/utils/simfactory.py | 18 +- 25 files changed, 844 insertions(+), 304 deletions(-) diff --git a/zipline/algorithm.py b/zipline/algorithm.py index 023f1c7b..354a7fbe 100644 --- a/zipline/algorithm.py +++ b/zipline/algorithm.py @@ -39,9 +39,10 @@ from zipline.errors import ( UnsupportedCommissionModel, UnsupportedOrderParameters, UnsupportedSlippageModel, + SidNotFound, ) -from zipline.finance import trading +from zipline.finance.trading import TradingEnvironment from zipline.finance.blotter import Blotter from zipline.finance.commission import PerShare, PerTrade, PerDollar from zipline.finance.controls import ( @@ -64,6 +65,7 @@ from zipline.finance.slippage import ( SlippageModel, transact_partial ) +from zipline.assets import Asset, Future from zipline.gens.composites import date_sorted_sources from zipline.gens.tradesimulation import AlgorithmSimulator from zipline.sources import DataFrameSource, DataPanelSource @@ -133,8 +135,27 @@ class TradingAlgorithm(object): How much capital to start with. instant_fill : bool Whether to fill orders immediately or on next bar. - environment : str - The environment that this algorithm is running in. + asset_finder : An AssetFinder object + A new AssetFinder object to be used in this TradingEnvironment + asset_metadata: can be either: + - dict + - pandas.DataFrame + - object with 'read' property + If dict is provided, it must have the following structure: + * keys are the identifiers + * values are dicts containing the metadata, with the metadata + field name as the key + If pandas.DataFrame is provided, it must have the + following structure: + * column names must be the metadata fields + * index must be the different asset identifiers + * array contents should be the metadata value + If an object with a 'read' property is provided, 'read' must + return rows containing at least one of 'sid' or 'symbol' along + with the other metadata fields. + identifiers : List + Any asset identifiers that are not provided in the + asset_metadata, but will be traded by this TradingAlgorithm """ self.datetime = None @@ -167,10 +188,23 @@ class TradingAlgorithm(object): self.sim_params = kwargs.pop('sim_params', None) if self.sim_params is None: self.sim_params = create_simulation_parameters( - capital_base=self.capital_base + capital_base=self.capital_base, + start=kwargs.pop('start', None), + end=kwargs.pop('end', None) ) self.perf_tracker = PerformanceTracker(self.sim_params) + # Update the TradingEnvironment with the provided asset metadata + self.trading_environment = kwargs.pop('env', + TradingEnvironment.instance()) + self.trading_environment.update_asset_finder( + asset_finder=kwargs.pop('asset_finder', None), + asset_metadata=kwargs.pop('asset_metadata', None), + identifiers=kwargs.pop('identifiers', None) + ) + # Pull in the environment's new AssetFinder for quick reference + self.asset_finder = self.trading_environment.asset_finder + self.blotter = kwargs.pop('blotter', None) if not self.blotter: self.blotter = Blotter() @@ -322,11 +356,10 @@ class TradingAlgorithm(object): sim_params = self.sim_params if self.benchmark_return_source is None: - env = trading.environment if sim_params.data_frequency == 'minute' or \ sim_params.emission_rate == 'minute': def update_time(date): - return env.get_open_and_close(date)[1] + return self.trading_environment.get_open_and_close(date)[1] else: def update_time(date): return date @@ -336,7 +369,7 @@ class TradingAlgorithm(object): 'type': zipline.protocol.DATASOURCE_TYPE.BENCHMARK, 'source_id': 'benchmarks'}) for dt, ret in - trading.environment.benchmark_returns.iteritems() + self.trading_environment.benchmark_returns.iteritems() if dt.date() >= sim_params.period_start.date() and dt.date() <= sim_params.period_end.date() ] @@ -410,8 +443,7 @@ class TradingAlgorithm(object): If pandas.DataFrame is provided, it must have the following structure: - * column names must consist of ints representing the - different sids + * column names must be the different asset identifiers * index must be DatetimeIndex * array contents should be price info. @@ -420,9 +452,11 @@ class TradingAlgorithm(object): Daily performance metrics such as returns, alpha etc. """ + + # Ensure that source is a DataSource object if isinstance(source, list): if overwrite_sim_params: - warnings.warn("""List of sources passed, will not attempt to extract sids, and start and end + warnings.warn("""List of sources passed, will not attempt to extract start and end dates. Make sure to set the correct fields in sim_params passed to __init__().""", UserWarning) overwrite_sim_params = False @@ -443,8 +477,19 @@ class TradingAlgorithm(object): self.sim_params.period_start = source.start if hasattr(source, 'end'): self.sim_params.period_end = source.end + # The sids field of the source is the canonical reference for + # sids in this run all_sids = [sid for s in self.sources for sid in s.sids] self.sim_params.sids = set(all_sids) + # Check that all sids from the source are accounted for in + # the AssetFinder + for sid in self.sim_params.sids: + try: + self.asset_finder.retrieve_asset(sid) + except SidNotFound: + warnings.warn("No Asset found for sid '%s'. Make sure " + "that the correct identifiers and asset " + "metadata are passed to __init__()." % sid) # Changing period_start and period_close might require updating # of first_open and last_close. self.sim_params._update_internal() @@ -604,17 +649,52 @@ class TradingAlgorithm(object): def symbol(self, symbol_str): """ Default symbol lookup for any source that directly maps the - symbol to the identifier (e.g. yahoo finance). + symbol to the Asset (e.g. yahoo finance). """ - return symbol_str + asset, _ = self.asset_finder.lookup_generic( + asset_convertible_or_iterable=symbol_str, + as_of_date=self.datetime, + ) + return asset @api_method def symbols(self, *args): """ Default symbols lookup for any source that directly maps the - symbol to the identifier (e.g. yahoo finance). + symbol to the Asset (e.g. yahoo finance). """ - return args + return [self.symbol(identifier) for identifier in args] + + @api_method + def sid(self, a_sid): + """ + Default sid lookup for any source that directly maps the integer sid + to the Asset. + """ + return self.asset_finder.retrieve_asset(a_sid) + + def _calculate_order_value_amount(self, asset, value): + """ + Calculates how many shares/contracts to order based on the type of + asset being ordered. + """ + last_price = self.trading_client.current_data[asset].price + + if tolerant_equals(last_price, 0): + zero_message = "Price of 0 for {psid}; can't infer value".format( + psid=asset + ) + if self.logger: + self.logger.debug(zero_message) + # Don't place any order + return 0 + + if isinstance(asset, Future): + value_multiplier = asset.contract_multiplier + else: + value_multiplier = 1 + + return value / (last_price * value_multiplier) @api_method def order(self, sid, amount, @@ -655,7 +735,7 @@ class TradingAlgorithm(object): return self.blotter.order(sid, amount, style) def validate_order_params(self, - sid, + asset, amount, limit_price, stop_price, @@ -682,8 +762,14 @@ class TradingAlgorithm(object): msg="Passing both stop_price and style is not supported." ) + if not isinstance(asset, Asset): + raise UnsupportedOrderParameters( + msg="Passing non-Asset argument to 'order()' is not supported." + " Use 'sid()' or 'symbol()' methods to look up an Asset." + ) + for control in self.trading_controls: - control.validate(sid, + control.validate(asset, amount, self.updated_portfolio(), self.get_datetime(), @@ -718,6 +804,8 @@ class TradingAlgorithm(object): Place an order by desired value rather than desired number of shares. If the requested sid is found in the universe, the requested value is divided by its price to imply the number of shares to transact. + If the Asset being ordered is a Future, the 'value' calculated + is actually the exposure, as Futures have no 'value'. value > 0 :: Buy/Cover value < 0 :: Sell/Short @@ -726,21 +814,11 @@ class TradingAlgorithm(object): Stop order: order(sid, value, None, stop_price) StopLimit order: order(sid, value, limit_price, stop_price) """ - last_price = self.trading_client.current_data[sid].price - if tolerant_equals(last_price, 0): - zero_message = "Price of 0 for {psid}; can't infer value".format( - psid=sid - ) - if self.logger: - self.logger.debug(zero_message) - # Don't place any order - return - else: - amount = value / last_price - return self.order(sid, amount, - limit_price=limit_price, - stop_price=stop_price, - style=style) + amount = self._calculate_order_value_amount(sid, value) + return self.order(sid, amount, + limit_price=limit_price, + stop_price=stop_price, + style=style) @property def recorded_vars(self): @@ -855,7 +933,7 @@ class TradingAlgorithm(object): def order_percent(self, sid, percent, limit_price=None, stop_price=None, style=None): """ - Place an order in the specified security corresponding to the given + Place an order in the specified asset corresponding to the given percent of the current portfolio value. Note that percent must expressed as a decimal (0.50 means 50\%). @@ -898,15 +976,10 @@ class TradingAlgorithm(object): order. If the position does exist, this is equivalent to placing an order for the difference between the target value and the current value. + If the Asset being ordered is a Future, the 'target value' calculated + is actually the target exposure, as Futures have no 'value'. """ - last_price = self.trading_client.current_data[sid].price - if tolerant_equals(last_price, 0): - # Don't place an order - if self.logger: - zero_message = "Price of 0 for {psid}; can't infer value" - self.logger.debug(zero_message.format(psid=sid)) - return - target_amount = target / last_price + target_amount = self._calculate_order_value_amount(sid, target) return self.order_target(sid, target_amount, limit_price=limit_price, stop_price=stop_price, @@ -1066,7 +1139,7 @@ class TradingAlgorithm(object): increasing the absolute value of shares/dollar value exceeding one of these limits, raise a TradingControlException. """ - control = MaxPositionSize(sid=sid, + control = MaxPositionSize(asset=sid, max_shares=max_shares, max_notional=max_notional) self.register_trading_control(control) @@ -1081,7 +1154,7 @@ class TradingAlgorithm(object): If an algorithm attempts to place an order that would result in exceeding one of these limits, raise a TradingControlException. """ - control = MaxOrderSize(sid=sid, + control = MaxOrderSize(asset=sid, max_shares=max_shares, max_notional=max_notional) self.register_trading_control(control) diff --git a/zipline/data/__init__.py b/zipline/data/__init__.py index 57f8e9f0..478d6261 100644 --- a/zipline/data/__init__.py +++ b/zipline/data/__init__.py @@ -1,4 +1,8 @@ from . import loader -from .loader import load_from_yahoo, load_bars_from_yahoo +from .loader import ( + load_from_yahoo, load_bars_from_yahoo, load_prices_from_csv, + load_prices_from_csv_folder +) -__all__ = ['loader', 'load_from_yahoo', 'load_bars_from_yahoo'] +__all__ = ['loader', 'load_from_yahoo', 'load_bars_from_yahoo', + 'load_prices_from_csv', 'load_prices_from_csv_folder'] diff --git a/zipline/data/loader.py b/zipline/data/loader.py index 76e335f9..abc4caed 100644 --- a/zipline/data/loader.py +++ b/zipline/data/loader.py @@ -294,7 +294,7 @@ def load_from_yahoo(indexes=None, adjusted=True): """ Loads price data from Yahoo into a dataframe for each of the indicated - securities. By default, 'price' is taken from Yahoo's 'Adjusted Close', + assets. By default, 'price' is taken from Yahoo's 'Adjusted Close', which removes the impact of splits and dividends. If the argument 'adjusted' is False, then the non-adjusted 'close' field is used instead. @@ -367,3 +367,24 @@ def load_bars_from_yahoo(indexes=None, for col in adj_cols: panel[ticker][col] *= ratio_filtered return panel + + +def load_prices_from_csv(filepath, identifier_col, tz='UTC'): + data = pd.read_csv(filepath, index_col=identifier_col) + data.index = pd.DatetimeIndex(data.index, tz=tz) + data.sort_index(inplace=True) + return data + + +def load_prices_from_csv_folder(folderpath, identifier_col, tz='UTC'): + data = None + for file in os.listdir(folderpath): + if '.csv' not in file: + continue + raw = load_prices_from_csv(os.path.join(folderpath, file), + identifier_col, tz) + if data is None: + data = raw + else: + data = pd.concat([data, raw], axis=1) + return data diff --git a/zipline/errors.py b/zipline/errors.py index 38316ba4..43a2939f 100644 --- a/zipline/errors.py +++ b/zipline/errors.py @@ -175,7 +175,8 @@ class TradingControlViolation(ZiplineError): Raised if an order would violate a constraint set by a TradingControl. """ msg = """ -Order for {amount} shares of {sid} violates trading constraint {constraint}. +Order for {amount} shares of {asset} at {datetime} violates trading constraint +{constraint}. """.strip() @@ -188,3 +189,95 @@ class IncompatibleHistoryFrequency(ZiplineError): Requested history at frequency '{frequency}' cannot be created with data at frequency '{data_frequency}'. """.strip() + + +class MultipleSymbolsFound(ZiplineError): + """ + Raised when a symbol() call contains a symbol that changed over + time and is thus not resolvable without additional information + provided via as_of_date. + """ + msg = """ +Multiple symbols with the name '{symbol}' found. Use the +as_of_date' argument to to specify when the date symbol-lookup +should be valid. + +Possible options:{options} + """.strip() + + +class SymbolNotFound(ZiplineError): + """ + Raised when a symbol() call contains a non-existant symbol. + """ + msg = """ +Symbol '{symbol}' was not found. +""".strip() + + +class SidNotFound(ZiplineError): + """ + Raised when a retrieve_asset() call contains a non-existent sid. + """ + msg = """ +Asset with sid '{sid}' was not found. +""".strip() + + +class IdentifierNotFound(ZiplineError): + """ + Raised when a retrieve_asset_by_identifier() call contains a non-existent + identifier. + """ + msg = """ +Asset with identifier '{identifier}' was not found. +""".strip() + + +class InvalidAssetType(ZiplineError): + """ + Raised when an AssetFinder tries to build an Asset with an invalid + AssetType. + """ + msg = """ +AssetMetaData contained an invalid Asset type: '{asset_type}'. +""".strip() + + +class UpdateAssetFinderTypeError(ZiplineError): + """ + Raised when TradingEnvironment.update_asset_finder() gets an asset_finder + arg that is not of AssetFinder class. + """ + msg = """ +TradingEnvironment can not set asset_finder to object of class {cls}. +""".strip() + + +class ConsumeAssetMetaDataError(ZiplineError): + """ + Raised when AssetMetaData.consume() is called on an invalid object. + """ + msg = """ +AssetMetaData can not consume {obj}. MetaData must be a dict, a DataFrame, or +""".strip() + + +class NoSourceError(ZiplineError): + """ + Raised when no source is given to the pipeline + """ + msg = """ +No data source given. +""".strip() + + +class PipelineDateError(ZiplineError): + """ + Raised when only one date is passed to the pipeline + """ + msg = """ +Only one simulation date given. Please specify both the 'start' and 'end' for +the simulation, or neither. If neither is given, the start and end of the +DataSource will be used. Given start = '{start}', end = '{end}' +""".strip() diff --git a/zipline/examples/dual_ema_talib.py b/zipline/examples/dual_ema_talib.py index d078fbe5..96a0283d 100644 --- a/zipline/examples/dual_ema_talib.py +++ b/zipline/examples/dual_ema_talib.py @@ -29,7 +29,7 @@ from zipline.transforms.ta import EMA def initialize(context): - context.security = symbol('AAPL') + context.asset = symbol('AAPL') # Add 2 mavg transforms, one with a long window, one with a short window. context.short_ema_trans = EMA(timeperiod=20) @@ -49,17 +49,17 @@ def handle_data(context, data): sell = False if (short_ema > long_ema).all() and not context.invested: - order(context.security, 100) + order(context.asset, 100) context.invested = True buy = True elif (short_ema < long_ema).all() and context.invested: - order(context.security, -100) + order(context.asset, -100) context.invested = False sell = True - record(AAPL=data[context.security].price, - short_ema=short_ema[context.security], - long_ema=long_ema[context.security], + record(AAPL=data[context.asset].price, + short_ema=short_ema[context.asset], + long_ema=long_ema[context.asset], buy=buy, sell=sell) @@ -77,7 +77,8 @@ if __name__ == '__main__': data = load_from_yahoo(stocks=['AAPL'], indexes={}, start=start, end=end) - algo = TradingAlgorithm(initialize=initialize, handle_data=handle_data) + algo = TradingAlgorithm(initialize=initialize, handle_data=handle_data, + identifiers=['AAPL']) results = algo.run(data).dropna() fig = plt.figure() diff --git a/zipline/examples/dual_moving_average.py b/zipline/examples/dual_moving_average.py index 9d1708ac..5ee23ad7 100755 --- a/zipline/examples/dual_moving_average.py +++ b/zipline/examples/dual_moving_average.py @@ -31,6 +31,8 @@ def initialize(context): add_history(100, '1d', 'price') add_history(300, '1d', 'price') + context.sym = symbol('AAPL') + context.i = 0 @@ -46,17 +48,15 @@ def handle_data(context, data): short_mavg = history(100, '1d', 'price').mean() long_mavg = history(300, '1d', 'price').mean() - sym = symbol('AAPL') - # Trading logic - if short_mavg[sym] > long_mavg[sym]: + if short_mavg[context.sym] > long_mavg[context.sym]: # order_target orders as many shares as needed to # achieve the desired number of shares. - order_target(sym, 100) - elif short_mavg[sym] < long_mavg[sym]: - order_target(sym, 0) + order_target(context.sym, 100) + elif short_mavg[context.sym] < long_mavg[context.sym]: + order_target(context.sym, 0) # Save values for later inspection - record(AAPL=data[sym].price, - short_mavg=short_mavg[sym], - long_mavg=long_mavg[sym]) + record(AAPL=data[context.sym].price, + short_mavg=short_mavg[context.sym], + long_mavg=long_mavg[context.sym]) diff --git a/zipline/examples/olmar.py b/zipline/examples/olmar.py index c04f5c0d..a1277f20 100644 --- a/zipline/examples/olmar.py +++ b/zipline/examples/olmar.py @@ -24,6 +24,7 @@ STOCKS = ['AMD', 'CERN', 'COST', 'DELL', 'GPS', 'INTC', 'MMM'] # http://icml.cc/2012/papers/168.pdf def initialize(algo, eps=1, window_length=5): algo.stocks = STOCKS + algo.sids = [algo.symbol(symbol) for symbol in algo.stocks] algo.m = len(algo.stocks) algo.price = {} algo.b_t = np.ones(algo.m) / algo.m @@ -52,11 +53,11 @@ def handle_data(algo, data): x_tilde = np.zeros(m) b = np.zeros(m) - # find relative moving average price for each security - for i, stock in enumerate(algo.stocks): - price = data[stock].price + # find relative moving average price for each asset + for i, sid in enumerate(algo.sids): + price = data[sid].price # Relative mean deviation - x_tilde[i] = data[stock].mavg(algo.window_length) / price + x_tilde[i] = data[sid].mavg(algo.window_length) / price ########################### # Inside of OLMAR (algo 2) @@ -98,17 +99,17 @@ def rebalance_portfolio(algo, data, desired_port): positions_value = algo.portfolio.positions_value + \ algo.portfolio.cash - for i, stock in enumerate(algo.stocks): - current_amount[i] = algo.portfolio.positions[stock].amount - prices[i] = data[stock].price + for i, sid in enumerate(algo.sids): + current_amount[i] = algo.portfolio.positions[sid].amount + prices[i] = data[sid].price desired_amount = np.round(desired_port * positions_value / prices) algo.last_desired_port = desired_port diff_amount = desired_amount - current_amount - for i, stock in enumerate(algo.stocks): - algo.order(stock, diff_amount[i]) + for i, sid in enumerate(algo.sids): + algo.order(sid, diff_amount[i]) def simplex_projection(v, b=1): @@ -154,7 +155,9 @@ if __name__ == '__main__': end = datetime(2008, 1, 1, 0, 0, 0, 0, pytz.utc) data = load_from_yahoo(stocks=STOCKS, indexes={}, start=start, end=end) data = data.dropna() - olmar = TradingAlgorithm(handle_data=handle_data, initialize=initialize) + olmar = TradingAlgorithm(handle_data=handle_data, + initialize=initialize, + identifiers=STOCKS) results = olmar.run(data) results.portfolio_value.plot() pl.show() diff --git a/zipline/examples/pairtrade.py b/zipline/examples/pairtrade.py index 5b67681d..ca71d8f9 100755 --- a/zipline/examples/pairtrade.py +++ b/zipline/examples/pairtrade.py @@ -23,6 +23,7 @@ import pytz from zipline.algorithm import TradingAlgorithm from zipline.transforms import batch_transform from zipline.utils.factory import load_from_yahoo +from zipline.sources.data_frame_source import DataFrameSource @batch_transform @@ -59,11 +60,13 @@ class Pairtrade(TradingAlgorithm): self.window_length = window_length self.ols_transform = ols_transform(refresh_period=self.window_length, window_length=self.window_length) + self.PEP = self.symbol('PEP') + self.KO = self.symbol('KO') def handle_data(self, data): ###################################################### # 1. Compute regression coefficients between PEP and KO - params = self.ols_transform.handle_data(data, 'PEP', 'KO') + params = self.ols_transform.handle_data(data, self.PEP, self.KO) if params is None: return intercept, slope = params @@ -81,7 +84,8 @@ class Pairtrade(TradingAlgorithm): """1. Compute the spread given slope and intercept. 2. zscore the spread. """ - spread = (data['PEP'].price - (slope * data['KO'].price + intercept)) + spread = (data[self.PEP].price - + (slope * data[self.KO].price + intercept)) self.spreads.append(spread) spread_wind = self.spreads[-self.window_length:] zscore = (spread - np.mean(spread_wind)) / np.std(spread_wind) @@ -91,12 +95,12 @@ class Pairtrade(TradingAlgorithm): """Buy spread if zscore is > 2, sell if zscore < .5. """ if zscore >= 2.0 and not self.invested: - self.order('PEP', int(100 / data['PEP'].price)) - self.order('KO', -int(100 / data['KO'].price)) + self.order(self.PEP, int(100 / data[self.PEP].price)) + self.order(self.KO, -int(100 / data[self.KO].price)) self.invested = True elif zscore <= -2.0 and not self.invested: - self.order('PEP', -int(100 / data['PEP'].price)) - self.order('KO', int(100 / data['KO'].price)) + self.order(self.PEP, -int(100 / data[self.PEP].price)) + self.order(self.KO, int(100 / data[self.KO].price)) self.invested = True elif abs(zscore) < .5 and self.invested: self.sell_spread() @@ -107,23 +111,25 @@ class Pairtrade(TradingAlgorithm): decrease exposure, regardless of position long/short. buy for a short position, sell for a long. """ - ko_amount = self.portfolio.positions['KO'].amount - self.order('KO', -1 * ko_amount) - pep_amount = self.portfolio.positions['PEP'].amount - self.order('PEP', -1 * pep_amount) + ko_amount = self.portfolio.positions[self.KO].amount + self.order(self.KO, -1 * ko_amount) + pep_amount = self.portfolio.positions[self.PEP].amount + self.order(self.PEP, -1 * pep_amount) if __name__ == '__main__': start = datetime(2000, 1, 1, 0, 0, 0, 0, pytz.utc) end = datetime(2002, 1, 1, 0, 0, 0, 0, pytz.utc) data = load_from_yahoo(stocks=['PEP', 'KO'], indexes={}, start=start, end=end) + source = DataFrameSource(data) pairtrade = Pairtrade() - results = pairtrade.run(data) + results = pairtrade.run(source) data['spreads'] = np.nan ax1 = plt.subplot(211) - data[['PEP', 'KO']].plot(ax=ax1) + # TODO Bugged - indices are out of bounds + # data[[pairtrade.PEPsid, pairtrade.KOsid]].plot(ax=ax1) plt.ylabel('price') plt.setp(ax1.get_xticklabels(), visible=False) diff --git a/zipline/examples/quantopian_buy_apple.py b/zipline/examples/quantopian_buy_apple.py index 45beca4c..f07b24ca 100644 --- a/zipline/examples/quantopian_buy_apple.py +++ b/zipline/examples/quantopian_buy_apple.py @@ -19,15 +19,16 @@ import pytz from zipline import TradingAlgorithm from zipline.utils.factory import load_from_yahoo -from zipline.api import order +from zipline.api import order, symbol def initialize(context): context.test = 10 + context.aapl = symbol('AAPL') def handle_date(context, data): - order('AAPL', 10) + order(context.aapl, 10) print(context.test) @@ -39,7 +40,8 @@ if __name__ == '__main__': end=end) data = data.dropna() algo = TradingAlgorithm(initialize=initialize, - handle_data=handle_date) + handle_data=handle_date, + identifiers=['AAPL']) results = algo.run(data) results.portfolio_value.plot() pl.show() diff --git a/zipline/finance/controls.py b/zipline/finance/controls.py index b35b654d..c2470a27 100644 --- a/zipline/finance/controls.py +++ b/zipline/finance/controls.py @@ -37,7 +37,7 @@ class TradingControl(with_metaclass(abc.ABCMeta)): @abc.abstractmethod def validate(self, - sid, + asset, amount, portfolio, algo_datetime, @@ -46,21 +46,22 @@ class TradingControl(with_metaclass(abc.ABCMeta)): Before any order is executed by TradingAlgorithm, this method should be called *exactly once* on each registered TradingControl object. - If the specified sid and amount do not violate this TradingControl's + If the specified asset and amount do not violate this TradingControl's restraint given the information in `portfolio`, this method should return None and have no externally-visible side-effects. If the desired order violates this TradingControl's contraint, this - method should call self.fail(sid, amount). + method should call self.fail(asset, amount). """ raise NotImplementedError - def fail(self, sid, amount): + def fail(self, asset, amount, datetime): """ Raise a TradingControlViolation with information about the failure. """ - raise TradingControlViolation(sid=sid, + raise TradingControlViolation(asset=asset, amount=amount, + datetime=datetime, constraint=repr(self)) def __repr__(self): @@ -82,7 +83,7 @@ class MaxOrderCount(TradingControl): self.current_date = None def validate(self, - sid, + asset, amount, _portfolio, algo_datetime, @@ -98,13 +99,13 @@ class MaxOrderCount(TradingControl): self.current_date = algo_date if self.orders_placed >= self.max_count: - self.fail(sid, amount) + self.fail(asset, amount, algo_datetime) self.orders_placed += 1 class RestrictedListOrder(TradingControl): """ - TradingControl representing a restricted list of securities that + TradingControl representing a restricted list of assets that cannot be ordered by the algorithm. """ @@ -119,30 +120,30 @@ class RestrictedListOrder(TradingControl): self.restricted_list = restricted_list def validate(self, - sid, + asset, amount, _portfolio, _algo_datetime, _algo_current_data): """ - Fail if the sid is in the restricted_list. + Fail if the asset is in the restricted_list. """ - if sid in self.restricted_list: - self.fail(sid, amount) + if asset in self.restricted_list: + self.fail(asset, amount, _algo_datetime) class MaxOrderSize(TradingControl): """ TradingControl representing a limit on the magnitude of any single order - placed with the given security. Can be specified by share or by dollar + placed with the given asset. Can be specified by share or by dollar value. """ - def __init__(self, sid=None, max_shares=None, max_notional=None): - super(MaxOrderSize, self).__init__(sid=sid, + def __init__(self, asset=None, max_shares=None, max_notional=None): + super(MaxOrderSize, self).__init__(asset=asset, max_shares=max_shares, max_notional=max_notional) - self.sid = sid + self.asset = asset self.max_shares = max_shares self.max_notional = max_notional @@ -162,7 +163,7 @@ class MaxOrderSize(TradingControl): ) def validate(self, - sid, + asset, amount, portfolio, _algo_datetime, @@ -172,33 +173,33 @@ class MaxOrderSize(TradingControl): or self.max_notional. """ - if self.sid is not None and self.sid != sid: + if self.asset is not None and self.asset != asset: return if self.max_shares is not None and abs(amount) > self.max_shares: - self.fail(sid, amount) + self.fail(asset, amount, _algo_datetime) - current_sid_price = algo_current_data[sid].price - order_value = amount * current_sid_price + current_asset_price = algo_current_data[asset].price + order_value = amount * current_asset_price too_much_value = (self.max_notional is not None and abs(order_value) > self.max_notional) if too_much_value: - self.fail(sid, amount) + self.fail(asset, amount, _algo_datetime) class MaxPositionSize(TradingControl): """ TradingControl representing a limit on the maximum position size that can - be held by an algo for a given security. + be held by an algo for a given asset. """ - def __init__(self, sid=None, max_shares=None, max_notional=None): - super(MaxPositionSize, self).__init__(sid=sid, + def __init__(self, asset=None, max_shares=None, max_notional=None): + super(MaxPositionSize, self).__init__(asset=asset, max_shares=max_shares, max_notional=max_notional) - self.sid = sid + self.asset = asset self.max_shares = max_shares self.max_notional = max_notional @@ -218,7 +219,7 @@ class MaxPositionSize(TradingControl): ) def validate(self, - sid, + asset, amount, portfolio, algo_datetime, @@ -229,25 +230,25 @@ class MaxPositionSize(TradingControl): self.max_notional. """ - if self.sid is not None and self.sid != sid: + if self.asset is not None and self.asset != asset: return - current_share_count = portfolio.positions[sid].amount + current_share_count = portfolio.positions[asset].amount shares_post_order = current_share_count + amount too_many_shares = (self.max_shares is not None and abs(shares_post_order) > self.max_shares) if too_many_shares: - self.fail(sid, amount) + self.fail(asset, amount, algo_datetime) - current_price = algo_current_data[sid].price + current_price = algo_current_data[asset].price value_post_order = shares_post_order * current_price too_much_value = (self.max_notional is not None and abs(value_post_order) > self.max_notional) if too_much_value: - self.fail(sid, amount) + self.fail(asset, amount, algo_datetime) class LongOnly(TradingControl): @@ -256,17 +257,41 @@ class LongOnly(TradingControl): """ def validate(self, - sid, + asset, amount, portfolio, _algo_datetime, _algo_current_data): """ - Fail if we would hold negative shares of sid after completing this + Fail if we would hold negative shares of asset after completing this order. """ - if portfolio.positions[sid].amount + amount < 0: - self.fail(sid, amount) + if portfolio.positions[asset].amount + amount < 0: + self.fail(asset, amount, _algo_datetime) + + +class AssetDateBounds(TradingControl): + """ + TradingControl representing a prohibition against ordering an asset before + its start_date, or after its end_date. + """ + + def validate(self, + asset, + amount, + portfolio, + algo_datetime, + algo_current_data): + """ + Fail if the algo has passed this Asset's end_date, or before the + Asset's start date. + """ + # Fail if the algo is before this Asset's start_date + if asset.start_date and (algo_datetime < asset.start_date): + self.fail(asset, amount, algo_datetime) + # Fail if the algo has passed this Asset's end_date + if asset.end_date and (algo_datetime >= asset.end_date): + self.fail(asset, amount, algo_datetime) class AccountControl(with_metaclass(abc.ABCMeta)): diff --git a/zipline/finance/performance/period.py b/zipline/finance/performance/period.py index 155e992b..598843f0 100644 --- a/zipline/finance/performance/period.py +++ b/zipline/finance/performance/period.py @@ -31,7 +31,7 @@ omitted). | | end of the period | +---------------+------------------------------------------------------+ | cash_flow | the cash flow in the period (negative means spent) | - | | from buying and selling securities in the period. | + | | from buying and selling assets in the period. | | | Includes dividend payments in the period as well. | +---------------+------------------------------------------------------+ | starting_value| the total market value of the positions held at the | @@ -75,6 +75,9 @@ import logbook import numpy as np +from zipline.finance.trading import TradingEnvironment +from zipline.assets import Future + try: # optional cython based OrderedDict from cyordereddict import OrderedDict @@ -128,6 +131,7 @@ class PerformancePeriod(object): self.period_close = period_close self.ending_value = 0.0 + self.ending_exposure = 0.0 self.period_cash_flow = 0.0 self.pnl = 0.0 @@ -160,6 +164,7 @@ class PerformancePeriod(object): def rollover(self): self.starting_value = self.ending_value + self.starting_exposure = self.ending_exposure self.starting_cash = self.ending_cash self.period_cash_flow = 0.0 self.pnl = 0.0 @@ -187,6 +192,7 @@ class PerformancePeriod(object): def calculate_performance(self): self.ending_value = self.calculate_positions_value() + self.ending_exposure = self.calculate_positions_exposure() total_at_start = self.starting_cash + self.starting_value self.ending_cash = self.starting_cash + self.period_cash_flow @@ -215,7 +221,12 @@ class PerformancePeriod(object): self.orders_by_id[order.id] = order def handle_execution(self, txn): - self.period_cash_flow -= txn.price * txn.amount + asset = TradingEnvironment.instance().asset_finder.\ + retrieve_asset(txn.sid) + + # Futures experience no cash flow on transactions + if not isinstance(asset, Future): + self.period_cash_flow -= txn.price * txn.amount if self.keep_transactions: try: @@ -232,6 +243,10 @@ class PerformancePeriod(object): def position_amounts(self): return self.position_tracker.position_amounts + @position_proxy + def calculate_positions_exposure(self): + raise ProxyError() + @position_proxy def calculate_positions_value(self): raise ProxyError() @@ -244,6 +259,10 @@ class PerformancePeriod(object): def _long_exposure(self): raise ProxyError() + @position_proxy + def _long_value(self): + raise ProxyError() + @position_proxy def _shorts_count(self): raise ProxyError() @@ -252,19 +271,29 @@ class PerformancePeriod(object): def _short_exposure(self): raise ProxyError() + @position_proxy + def _short_value(self): + raise ProxyError() + @position_proxy def _gross_exposure(self): raise ProxyError() + @position_proxy + def _gross_value(self): + raise ProxyError() + @position_proxy def _net_exposure(self): raise ProxyError() + @position_proxy + def _net_value(self): + raise ProxyError() + @property def _net_liquidation_value(self): - return self.ending_cash + \ - self._long_exposure() + \ - self._short_exposure() + return self.ending_cash + self._long_value() + self._short_value() def _gross_leverage(self): net_liq = self._net_liquidation_value @@ -283,10 +312,12 @@ class PerformancePeriod(object): def __core_dict(self): rval = { 'ending_value': self.ending_value, + 'ending_exposure': self.ending_exposure, # this field is renamed to capital_used for backward # compatibility. 'capital_used': self.period_cash_flow, 'starting_value': self.starting_value, + 'starting_exposure': self.starting_exposure, 'starting_cash': self.starting_cash, 'ending_cash': self.ending_cash, 'portfolio_value': self.ending_cash + self.ending_value, @@ -298,6 +329,8 @@ class PerformancePeriod(object): 'net_leverage': self._net_leverage(), 'short_exposure': self._short_exposure(), 'long_exposure': self._long_exposure(), + 'short_value': self._short_value(), + 'long_value': self._long_value(), 'longs_count': self._longs_count(), 'shorts_count': self._shorts_count() } @@ -372,6 +405,7 @@ class PerformancePeriod(object): portfolio.start_date = self.period_open portfolio.positions = self.get_positions() portfolio.positions_value = self.ending_value + portfolio.positions_exposure = self.ending_exposure return portfolio def as_account(self): @@ -393,6 +427,8 @@ class PerformancePeriod(object): self.ending_cash + self.ending_value) account.total_positions_value = \ getattr(self, 'total_positions_value', self.ending_value) + account.total_positions_value = \ + getattr(self, 'total_positions_exposure', self.ending_exposure) account.regt_equity = \ getattr(self, 'regt_equity', self.ending_cash) account.regt_margin = \ diff --git a/zipline/finance/performance/position.py b/zipline/finance/performance/position.py index 2443ee12..9a28ba86 100644 --- a/zipline/finance/performance/position.py +++ b/zipline/finance/performance/position.py @@ -20,12 +20,11 @@ Position Tracking +-----------------+----------------------------------------------------+ | key | value | +=================+====================================================+ - | sid | the identifier for the security held in this | - | | position. | + | sid | the sid for the asset held in this position | +-----------------+----------------------------------------------------+ | amount | whole number of shares in the position | +-----------------+----------------------------------------------------+ - | last_sale_price | price at last sale of the security on the exchange | + | last_sale_price | price at last sale of the asset on the exchange | +-----------------+----------------------------------------------------+ | cost_basis | the volume weighted average price paid per share | +-----------------+----------------------------------------------------+ diff --git a/zipline/finance/performance/position_tracker.py b/zipline/finance/performance/position_tracker.py index bd3f674f..cc7581e6 100644 --- a/zipline/finance/performance/position_tracker.py +++ b/zipline/finance/performance/position_tracker.py @@ -1,5 +1,4 @@ from __future__ import division -from operator import mul import logbook import numpy as np @@ -10,8 +9,7 @@ try: from cyordereddict import OrderedDict except ImportError: from collections import OrderedDict -from six import iteritems -from six.moves import map, filter +from six import iteritems, itervalues from zipline.finance.slippage import Transaction from zipline.utils.serialization_utils import ( @@ -19,6 +17,10 @@ from zipline.utils.serialization_utils import ( ) import zipline.protocol as zp +from zipline.assets import ( + Equity, Future +) +from zipline.finance.trading import with_environment from . position import positiondict log = logbook.Logger('Performance') @@ -32,26 +34,69 @@ class PositionTracker(object): # Arrays for quick calculations of positions value self._position_amounts = OrderedDict() self._position_last_sale_prices = OrderedDict() + self._position_value_multipliers = OrderedDict() + self._position_exposure_multipliers = OrderedDict() self._unpaid_dividends = pd.DataFrame( columns=zp.DIVIDEND_PAYMENT_FIELDS, ) self._positions_store = zp.Positions() + # Cached for fast property calculation + self._position_values = None + self._position_exposures = None + + def _invalidate_cache(self): + self._position_values = None + self._position_exposures = None + + @with_environment() + def _retrieve_asset(self, sid, env=None): + return env.asset_finder.retrieve_asset(sid) + + def _update_multipliers(self, sid): + try: + self._position_value_multipliers[sid] + self._position_exposure_multipliers[sid] + except KeyError: + # Collect the value multipliers from applicable sids + asset = self._retrieve_asset(sid) + if isinstance(asset, Equity): + self._position_value_multipliers[sid] = 1 + self._position_exposure_multipliers[sid] = 1 + if isinstance(asset, Future): + self._position_value_multipliers[sid] = 0 + self._position_exposure_multipliers[sid] = \ + asset.contract_multiplier + def update_last_sale(self, event): # NOTE, PerformanceTracker already vetted as TRADE type sid = event.sid if sid not in self.positions: - return + return 0 price = event.price - if not checknull(price): - pos = self.positions[sid] - pos.last_sale_date = event.dt - pos.last_sale_price = price - self._position_last_sale_prices[sid] = price - self._position_values = None # invalidate cache - sid = event.sid - price = event.price + + if checknull(price): + return 0 + + pos = self.positions[sid] + old_price = pos.last_sale_price + pos.last_sale_date = event.dt + pos.last_sale_price = price + self._position_last_sale_prices[sid] = price + self._invalidate_cache() + + asset = self._retrieve_asset(sid) + if asset is None: + return 0 + + # Calculate cash adjustment on futures + cash_adjustment = 0 + if isinstance(asset, Future): + price_change = price - old_price + cash_adjustment = \ + price_change * asset.contract_multiplier * pos.amount + return cash_adjustment def update_positions(self, positions): # update positions in batch @@ -59,8 +104,8 @@ class PositionTracker(object): for sid, pos in iteritems(positions): self._position_amounts[sid] = pos.amount self._position_last_sale_prices[sid] = pos.last_sale_price - # Invalidate cache. - self._position_values = None # invalidate cache + self._update_multipliers(sid) + self._invalidate_cache() def update_position(self, sid, amount=None, last_sale_price=None, last_sale_date=None, cost_basis=None): @@ -70,6 +115,7 @@ class PositionTracker(object): pos.amount = amount self._position_amounts[sid] = amount self._position_values = None # invalidate cache + self._update_multipliers(sid=sid) if last_sale_price is not None: pos.last_sale_price = last_sale_price self._position_last_sale_prices[sid] = last_sale_price @@ -82,13 +128,13 @@ class PositionTracker(object): def execute_transaction(self, txn): # Update Position # ---------------- - sid = txn.sid position = self.positions[sid] position.update(txn) self._position_amounts[sid] = position.amount self._position_last_sale_prices[sid] = position.last_sale_price - self._position_values = None # invalidate cache + self._update_multipliers(sid) + self._invalidate_cache() def handle_commission(self, commission): # Adjust the cost basis of the stock if we own it @@ -96,8 +142,6 @@ class PositionTracker(object): self.positions[commission.sid].\ adjust_commission_cost_basis(commission) - _position_values = None - @property def position_values(self): """ @@ -105,33 +149,75 @@ class PositionTracker(object): self._position_last_sale_prices is changed. """ if self._position_values is None: - vals = list(map(mul, self._position_amounts.values(), - self._position_last_sale_prices.values())) - self._position_values = vals + iter_amount_price_multiplier = zip( + itervalues(self._position_amounts), + itervalues(self._position_last_sale_prices), + itervalues(self._position_value_multipliers), + ) + self._position_values = [ + price * amount * multiplier for + price, amount, multiplier in iter_amount_price_multiplier + ] return self._position_values + @property + def position_exposures(self): + """ + Invalidate any time self._position_amounts or + self._position_last_sale_prices is changed. + """ + if self._position_exposures is None: + iter_amount_price_multiplier = zip( + itervalues(self._position_amounts), + itervalues(self._position_last_sale_prices), + itervalues(self._position_exposure_multipliers), + ) + self._position_exposures = [ + price * amount * multiplier for + price, amount, multiplier in iter_amount_price_multiplier + ] + return self._position_exposures + def calculate_positions_value(self): if len(self.position_values) == 0: return np.float64(0) return sum(self.position_values) + def calculate_positions_exposure(self): + if len(self.position_exposures) == 0: + return np.float64(0) + + return sum(self.position_exposures) + def _longs_count(self): - return sum(map(lambda x: x > 0, self.position_values)) + return sum(1 for i in self.position_exposures if i > 0) def _long_exposure(self): - return sum(filter(lambda x: x > 0, self.position_values)) + return sum(i for i in self.position_exposures if i > 0) + + def _long_value(self): + return sum(i for i in self.position_values if i > 0) def _shorts_count(self): - return sum(map(lambda x: x < 0, self.position_values)) + return sum(1 for i in self.position_exposures if i < 0) def _short_exposure(self): - return sum(filter(lambda x: x < 0, self.position_values)) + return sum(i for i in self.position_exposures if i < 0) + + def _short_value(self): + return sum(i for i in self.position_values if i < 0) def _gross_exposure(self): return self._long_exposure() + abs(self._short_exposure()) + def _gross_value(self): + return self._long_value() + abs(self._short_value()) + def _net_exposure(self): + return self.calculate_positions_exposure() + + def _net_value(self): return self.calculate_positions_value() def handle_split(self, split): @@ -143,7 +229,8 @@ class PositionTracker(object): self._position_amounts[split.sid] = position.amount self._position_last_sale_prices[split.sid] = \ position.last_sale_price - self._position_values = None # invalidate cache + self._update_multipliers(split.sid) + self._invalidate_cache() return leftover_cash def _maybe_earn_dividend(self, dividend): @@ -204,15 +291,17 @@ class PositionTracker(object): stock = row['payment_sid'] share_count = row['share_count'] # note we create a Position for stock dividend if we don't - # already own the security + # already own the asset position = self.positions[stock] position.amount += share_count self._position_amounts[stock] = position.amount self._position_last_sale_prices[stock] = position.last_sale_price + self._update_multipliers(stock) + self._invalidate_cache() # Add cash equal to the net cash payed from all dividends. Note that - # "negative cash" is effectively paid if we're short a security, + # "negative cash" is effectively paid if we're short an asset, # representing the fact that we're required to reimburse the owner of # the stock for any dividends paid while borrowing. net_cash_payment = payments['cash_amount'].fillna(0).sum() @@ -290,5 +379,8 @@ class PositionTracker(object): # Arrays for quick calculations of positions value self._position_amounts = OrderedDict() self._position_last_sale_prices = OrderedDict() + self._position_value_multipliers = OrderedDict() + self._position_exposure_multipliers = OrderedDict() + self._invalidate_cache() self.update_positions(state['positions']) diff --git a/zipline/finance/performance/tracker.py b/zipline/finance/performance/tracker.py index 6ddf03e6..fdfd91e7 100644 --- a/zipline/finance/performance/tracker.py +++ b/zipline/finance/performance/tracker.py @@ -68,10 +68,9 @@ import pandas as pd from pandas.tseries.tools import normalize_date import zipline.finance.risk as risk -from zipline.finance import trading +from zipline.finance.trading import TradingEnvironment from . period import PerformancePeriod -from zipline.finance.trading import with_environment from zipline.utils.serialization_utils import ( VERSION_LABEL ) @@ -84,26 +83,23 @@ class PerformanceTracker(object): """ Tracks the performance of the algorithm. """ - - @with_environment() - def __init__(self, sim_params, env=None): + def __init__(self, sim_params): self.sim_params = sim_params + env = TradingEnvironment.instance() self.period_start = self.sim_params.period_start self.period_end = self.sim_params.period_end self.last_close = self.sim_params.last_close - first_open = self.sim_params.first_open.tz_convert( - trading.environment.exchange_tz) + first_open = self.sim_params.first_open.tz_convert(env.exchange_tz) self.day = pd.Timestamp(datetime(first_open.year, first_open.month, first_open.day), tz='UTC') - self.market_open, self.market_close = \ - trading.environment.get_open_and_close(self.day) + self.market_open, self.market_close = env.get_open_and_close(self.day) self.total_days = self.sim_params.days_in_period self.capital_base = self.sim_params.capital_base self.emission_rate = sim_params.emission_rate - all_trading_days = trading.environment.trading_days + all_trading_days = env.trading_days mask = ((all_trading_days >= normalize_date(self.period_start)) & (all_trading_days <= normalize_date(self.period_end))) @@ -287,10 +283,13 @@ class PerformanceTracker(object): return _dict def process_trade(self, event): - self.position_tracker.update_last_sale(event) + # update last sale, and pay out a cash adjustment + cash_adjustment = self.position_tracker.update_last_sale(event) + if cash_adjustment != 0: + for perf_period in self.perf_periods: + perf_period.handle_cash_payment(cash_adjustment) def process_transaction(self, event): - self.txn_count += 1 self.position_tracker.execute_transaction(event) for perf_period in self.perf_periods: @@ -342,7 +341,7 @@ class PerformanceTracker(object): if txn: self.process_transaction(txn) - def check_upcoming_dividends(self, midnight_of_date_that_just_ended): + def check_upcoming_dividends(self, next_trading_day): """ Check if we currently own any stocks with dividends whose ex_date is the next trading day. Track how much we should be payed on those @@ -357,17 +356,6 @@ class PerformanceTracker(object): # period, so bail. return - next_trading_day_idx = self.trading_days.get_loc( - midnight_of_date_that_just_ended, - ) + 1 - - if next_trading_day_idx < len(self.trading_days): - next_trading_day = self.trading_days[next_trading_day_idx] - else: - # Bail if the next trading day is outside our trading range, since - # we won't simulate the next day. - return - # Dividends whose ex_date is the next trading day. We need to check if # we own any of these stocks so we know to pay them out when the pay # date comes. @@ -412,7 +400,10 @@ class PerformanceTracker(object): # if this is the close, save the returns objects for cumulative risk # calculations and update dividends for the next day. if dt == self.market_close: - self.check_upcoming_dividends(todays_date) + next_trading_day = TradingEnvironment.instance().\ + next_trading_day(todays_date) + if next_trading_day: + self.check_upcoming_dividends(next_trading_day) def handle_intraday_market_close(self, new_mkt_open, new_mkt_close): """ @@ -454,16 +445,19 @@ class PerformanceTracker(object): return daily_update # move the market day markers forward + env = TradingEnvironment.instance() self.market_open, self.market_close = \ - trading.environment.next_open_and_close(self.day) - self.day = trading.environment.next_trading_day(self.day) + env.next_open_and_close(self.day) + self.day = env.next_trading_day(self.day) # Roll over positions to current day. self.todays_performance.rollover() self.todays_performance.period_open = self.market_open self.todays_performance.period_close = self.market_close - self.check_upcoming_dividends(completed_date) + next_trading_day = env.next_trading_day(completed_date) + if next_trading_day: + self.check_upcoming_dividends(next_trading_day) return daily_update diff --git a/zipline/finance/trading.py b/zipline/finance/trading.py index 3ac5c7cb..44e369b1 100644 --- a/zipline/finance/trading.py +++ b/zipline/finance/trading.py @@ -23,6 +23,8 @@ import numpy as np from zipline.data.loader import load_market_data from zipline.utils import tradingcalendar +from zipline.assets import AssetFinder +from zipline.errors import UpdateAssetFinderTypeError log = logbook.Logger('Trading') @@ -134,6 +136,8 @@ class TradingEnvironment(object): self.exchange_tz = exchange_tz + self.asset_finder = AssetFinder(trading_calendar=env_trading_calendar) + def __enter__(self, *args, **kwargs): global environment self.prev_environment = environment @@ -149,6 +153,52 @@ class TradingEnvironment(object): # stack. return False + def update_asset_finder(self, + clear_metadata=False, + asset_finder=None, + asset_metadata=None, + identifiers=None): + """ + Updates the AssetFinder using the provided asset metadata and + identifiers. + If clear_metadata is True, all metadata and assets held in the + asset_finder will be erased before new metadata is provided. + If asset_finder is provided, the existing asset_finder will be replaced + outright with the new asset_finder. + If asset_metadata is provided, the existing metadata will be cleared + and replaced with the provided metadata. + All identifiers will be inserted in the asset metadata if they are not + already present. + + :param clear_metadata: A boolean + :param asset_finder: An AssetFinder object to replace the environment's + existing asset_finder + :param asset_metadata: A dict, DataFrame, or readable object + :param identifiers: A list of identifiers to be inserted + :return: + """ + populate = False + if clear_metadata: + self.asset_finder.clear_metadata() + populate = True + + if asset_finder is not None: + if not isinstance(asset_finder, AssetFinder): + raise UpdateAssetFinderTypeError(cls=asset_finder.__class__) + self.asset_finder = asset_finder + + if asset_metadata is not None: + self.asset_finder.clear_metadata() + self.asset_finder.consume_metadata(asset_metadata) + populate = True + + if identifiers is not None: + self.asset_finder.consume_identifiers(identifiers) + populate = True + + if populate: + self.asset_finder.populate_cache() + def normalize_date(self, test_date): test_date = pd.Timestamp(test_date, tz='UTC') return pd.tseries.tools.normalize_date(test_date) diff --git a/zipline/history/history_container.py b/zipline/history/history_container.py index 8fc19b62..83c069b7 100644 --- a/zipline/history/history_container.py +++ b/zipline/history/history_container.py @@ -217,7 +217,7 @@ class HistoryContainer(object): history_specs (dict[Frequency:HistorySpec]): The starting history specs that this container should be able to service. - initial_sids (set[Security or Int]): The starting sids to watch. + initial_sids (set[Asset or Int]): The starting sids to watch. initial_dt (datetime): The datetime to start collecting history from. diff --git a/zipline/protocol.py b/zipline/protocol.py index 7fcfef05..f5d3f999 100644 --- a/zipline/protocol.py +++ b/zipline/protocol.py @@ -58,7 +58,12 @@ DIVIDEND_FIELDS = [ 'sid', ] # Expected fields/index values for a dividend payment Series. -DIVIDEND_PAYMENT_FIELDS = ['id', 'payment_sid', 'cash_amount', 'share_count'] +DIVIDEND_PAYMENT_FIELDS = [ + 'id', + 'payment_sid', + 'cash_amount', + 'share_count', +] def dividend_payment(data=None): @@ -73,7 +78,7 @@ def dividend_payment(data=None): payment. Additionally, if @data is non-empty, either data['cash_amount'] should be - nonzero or data['payment_sid'] should be a security identifier and + nonzero or data['payment_sid'] should be an asset identifier and data['share_count'] should be nonzero. The returned Series is given its id value as a name so that concatenating @@ -289,7 +294,6 @@ class SIDData(object): def __init__(self, sid, initial_values=None): self._sid = sid - self._freqstr = None # To check if we have data, we use the __len__ which depends on the diff --git a/zipline/sources/data_frame_source.py b/zipline/sources/data_frame_source.py index 7a0771f4..42d4e9ce 100644 --- a/zipline/sources/data_frame_source.py +++ b/zipline/sources/data_frame_source.py @@ -22,6 +22,7 @@ import pandas as pd from zipline.gens.utils import hash_args from zipline.sources.data_source import DataSource +from zipline.finance.trading import with_environment class DataFrameSource(DataSource): @@ -36,14 +37,23 @@ class DataFrameSource(DataSource): Bars where the price is nan are filtered out. """ - def __init__(self, data, **kwargs): + @with_environment() + def __init__(self, data, env=None, **kwargs): assert isinstance(data.index, pd.tseries.index.DatetimeIndex) - self.data = data + self.data = data.fillna(method='ffill') # Unpack config dictionary with default values. - self.sids = kwargs.get('sids', data.columns) - self.start = kwargs.get('start', data.index[0]) - self.end = kwargs.get('end', data.index[-1]) + self.start = kwargs.get('start', self.data.index[0]) + self.end = kwargs.get('end', self.data.index[-1]) + + # Remap sids based on the trading environment + self.identifiers = kwargs.get('sids', self.data.columns) + env.update_asset_finder(identifiers=self.identifiers) + self.data.columns = [ + env.asset_finder.retrieve_asset_by_identifier(identifier).sid + for identifier in self.data.columns + ] + self.sids = self.data.columns # Hash_value for downstream sorting. self.arg_string = hash_args(data, **kwargs) @@ -105,14 +115,24 @@ class DataPanelSource(DataSource): Bars where the price is nan are filtered out. """ - def __init__(self, data, **kwargs): + @with_environment() + def __init__(self, data, env=None, **kwargs): assert isinstance(data.major_axis, pd.tseries.index.DatetimeIndex) - self.data = data + self.data = data.fillna(value={'volume': 0}) + self.data = self.data.fillna(method='ffill') # Unpack config dictionary with default values. - self.sids = kwargs.get('sids', data.items) - self.start = kwargs.get('start', data.major_axis[0]) - self.end = kwargs.get('end', data.major_axis[-1]) + self.start = kwargs.get('start', self.data.major_axis[0]) + self.end = kwargs.get('end', self.data.major_axis[-1]) + + # Remap sids based on the trading environment + self.identifiers = kwargs.get('sids', self.data.items) + env.update_asset_finder(identifiers=self.identifiers) + self.data.items = [ + env.asset_finder.retrieve_asset_by_identifier(identifier).sid + for identifier in self.data.items + ] + self.sids = self.data.items # Hash_value for downstream sorting. self.arg_string = hash_args(data, **kwargs) diff --git a/zipline/sources/simulated.py b/zipline/sources/simulated.py index 8fd65806..45a91ed9 100644 --- a/zipline/sources/simulated.py +++ b/zipline/sources/simulated.py @@ -23,6 +23,7 @@ import pandas as pd from zipline.sources.data_source import DataSource from zipline.utils import tradingcalendar as calendar_nyse from zipline.gens.utils import hash_args +from zipline.finance import trading class RandomWalkSource(DataSource): @@ -92,6 +93,7 @@ class RandomWalkSource(DataSource): self.sd = sd self.sids = self.start_prices.keys() + trading.environment.update_asset_finder(identifiers=self.sids) self.open_and_closes = \ calendar.open_and_closes[self.start:self.end] diff --git a/zipline/sources/test_source.py b/zipline/sources/test_source.py index 5e64a320..998006b5 100644 --- a/zipline/sources/test_source.py +++ b/zipline/sources/test_source.py @@ -22,7 +22,6 @@ import pytz from six.moves import filter from datetime import datetime, timedelta import itertools -import numpy as np from six.moves import range @@ -112,8 +111,8 @@ class SpecificEquityTrades(object): delta : timedelta between internal events filter : filter to remove the sids """ - - def __init__(self, *args, **kwargs): + @with_environment() + def __init__(self, env=None, *args, **kwargs): # We shouldn't get any positional arguments. assert len(args) == 0 @@ -125,9 +124,7 @@ class SpecificEquityTrades(object): # This isn't really clean and ultimately I think this # class should serve a single purpose (either take an # event_list or autocreate events). - self.sids = kwargs.get( - 'sids', - np.unique([event.sid for event in self.event_list]).tolist()) + self.count = kwargs.get('count', len(self.event_list)) self.start = kwargs.get('start', self.event_list[0].dt) self.end = kwargs.get('end', self.event_list[-1].dt) self.delta = kwargs.get( @@ -135,9 +132,22 @@ class SpecificEquityTrades(object): self.event_list[1].dt - self.event_list[0].dt) self.concurrent = kwargs.get('concurrent', False) + self.identifiers = kwargs.get( + 'sids', + set(event.sid for event in self.event_list) + ) + env.update_asset_finder(identifiers=self.identifiers) + self.sids = [ + env.asset_finder.retrieve_asset_by_identifier(identifier).sid + for identifier in self.identifiers + ] + for event in self.event_list: + event.sid = env.asset_finder.\ + retrieve_asset_by_identifier(event.sid).sid + else: # Unpack config dictionary with default values. - self.sids = kwargs.get('sids', [1, 2]) + self.count = kwargs.get('count', 500) self.start = kwargs.get( 'start', datetime(2008, 6, 6, 15, tzinfo=pytz.utc)) @@ -149,6 +159,13 @@ class SpecificEquityTrades(object): timedelta(minutes=1)) self.concurrent = kwargs.get('concurrent', False) + self.identifiers = kwargs.get('sids', [1, 2]) + env.update_asset_finder(identifiers=self.identifiers) + self.sids = [ + env.asset_finder.retrieve_asset_by_identifier(identifier).sid + for identifier in self.identifiers + ] + # Hash_value for downstream sorting. self.arg_string = hash_args(*args, **kwargs) diff --git a/zipline/test_algorithms.py b/zipline/test_algorithms.py index 3d99f30b..7942d32c 100644 --- a/zipline/test_algorithms.py +++ b/zipline/test_algorithms.py @@ -85,14 +85,17 @@ from zipline.api import ( order, set_slippage, record, + sid, ) from zipline.errors import UnsupportedOrderParameters +from zipline.assets import Future, Equity from zipline.finance.execution import ( LimitOrder, MarketOrder, StopLimitOrder, StopOrder, ) +from zipline.finance.controls import AssetDateBounds from zipline.transforms import BatchTransform, batch_transform @@ -111,14 +114,14 @@ class TestAlgorithm(TradingAlgorithm): slippage=None, commission=None): self.count = order_count - self.sid = sid + self.asset = self.sid(sid) self.amount = amount self.incr = 0 if sid_filter: self.sid_filter = sid_filter else: - self.sid_filter = [self.sid] + self.sid_filter = [self.asset.sid] if slippage is not None: self.set_slippage(slippage) @@ -129,7 +132,7 @@ class TestAlgorithm(TradingAlgorithm): def handle_data(self, data): # place an order for amount shares of sid if self.incr < self.count: - self.order(self.sid, self.amount) + self.order(self.asset, self.amount) self.incr += 1 @@ -141,13 +144,13 @@ class HeavyBuyAlgorithm(TradingAlgorithm): """ def initialize(self, sid, amount): - self.sid = sid + self.asset = self.sid(sid) self.amount = amount self.incr = 0 def handle_data(self, data): # place an order for 100 shares of sid - self.order(self.sid, self.amount) + self.order(self.asset, self.amount) self.incr += 1 @@ -177,7 +180,7 @@ class ExceptionAlgorithm(TradingAlgorithm): def initialize(self, throw_from, sid): self.throw_from = throw_from - self.sid = sid + self.asset = self.sid(sid) if self.throw_from == "initialize": raise Exception("Algo exception in initialize") @@ -200,7 +203,7 @@ class ExceptionAlgorithm(TradingAlgorithm): if self.throw_from == "get_sid_filter": raise Exception("Algo exception in get_sid_filter") else: - return [self.sid] + return [self.asset] def set_transact_setter(self, txn_sim_callable): pass @@ -209,7 +212,7 @@ class ExceptionAlgorithm(TradingAlgorithm): class DivByZeroAlgorithm(TradingAlgorithm): def initialize(self, sid): - self.sid = sid + self.asset = self.sid(sid) self.incr = 0 def handle_data(self, data): @@ -222,7 +225,7 @@ class DivByZeroAlgorithm(TradingAlgorithm): class TooMuchProcessingAlgorithm(TradingAlgorithm): def initialize(self, sid): - self.sid = sid + self.asset = self.sid(sid) def handle_data(self, data): # Unless we're running on some sort of @@ -234,7 +237,7 @@ class TooMuchProcessingAlgorithm(TradingAlgorithm): class TimeoutAlgorithm(TradingAlgorithm): def initialize(self, sid): - self.sid = sid + self.asset = self.sid(sid) self.incr = 0 def handle_data(self, data): @@ -269,7 +272,7 @@ class TestOrderAlgorithm(TradingAlgorithm): assert self.portfolio.positions[0]['last_sale_price'] == \ data[0].price, "Orders not filled at current price." self.incr += 1 - self.order(0, 1) + self.order(self.sid(0), 1) class TestOrderInstantAlgorithm(TradingAlgorithm): @@ -286,7 +289,7 @@ class TestOrderInstantAlgorithm(TradingAlgorithm): assert self.portfolio.positions[0]['last_sale_price'] == \ self.last_price, "Orders was not filled at last price." self.incr += 2 - self.order_value(0, data[0].price * 2.) + self.order_value(self.sid(0), data[0].price * 2.) self.last_price = data[0].price @@ -311,7 +314,9 @@ class TestOrderStyleForwardingAlgorithm(TradingAlgorithm): assert len(self.portfolio.positions.keys()) == 0 method_to_check = getattr(self, self.method_name) - method_to_check(0, data[0].price, style=StopLimitOrder(10, 10)) + method_to_check(self.sid(0), + data[0].price, + style=StopLimitOrder(10, 10)) assert len(self.blotter.open_orders[0]) == 1 result = self.blotter.open_orders[0][0] @@ -335,7 +340,12 @@ class TestOrderValueAlgorithm(TradingAlgorithm): assert self.portfolio.positions[0]['last_sale_price'] == \ data[0].price, "Orders not filled at current price." self.incr += 2 - self.order_value(0, data[0].price * 2.) + + multiplier = 2. + if isinstance(self.sid(0), Future): + multiplier *= self.sid(0).contract_multiplier + + self.order_value(self.sid(0), data[0].price * multiplier) class TestTargetAlgorithm(TradingAlgorithm): @@ -352,7 +362,7 @@ class TestTargetAlgorithm(TradingAlgorithm): assert self.portfolio.positions[0]['last_sale_price'] == \ data[0].price, "Orders not filled at current price." self.target_shares = np.random.randint(1, 30) - self.order_target(0, self.target_shares) + self.order_target(self.sid(0), self.target_shares) class TestOrderPercentAlgorithm(TradingAlgorithm): @@ -363,7 +373,7 @@ class TestOrderPercentAlgorithm(TradingAlgorithm): def handle_data(self, data): if self.target_shares == 0: assert 0 not in self.portfolio.positions - self.order(0, 10) + self.order(self.sid(0), 10) self.target_shares = 10 return else: @@ -372,10 +382,17 @@ class TestOrderPercentAlgorithm(TradingAlgorithm): assert self.portfolio.positions[0]['last_sale_price'] == \ data[0].price, "Orders not filled at current price." - self.order_percent(0, .001) - self.target_shares += np.floor((.001 * - self.portfolio.portfolio_value) / - data[0].price) + self.order_percent(self.sid(0), .001) + + if isinstance(self.sid(0), Equity): + self.target_shares += np.floor( + (.001 * self.portfolio.portfolio_value) / data[0].price + ) + if isinstance(self.sid(0), Future): + self.target_shares += np.floor( + (.001 * self.portfolio.portfolio_value) / + (data[0].price * self.sid(0).contract_multiplier) + ) class TestTargetPercentAlgorithm(TradingAlgorithm): @@ -394,7 +411,7 @@ class TestTargetPercentAlgorithm(TradingAlgorithm): assert self.portfolio.positions[0]['last_sale_price'] == \ data[0].price, "Orders not filled at current price." self.sale_price = data[0].price - self.order_target_percent(0, .002) + self.order_target_percent(self.sid(0), .002) class TestTargetValueAlgorithm(TradingAlgorithm): @@ -405,7 +422,7 @@ class TestTargetValueAlgorithm(TradingAlgorithm): def handle_data(self, data): if self.target_shares == 0: assert 0 not in self.portfolio.positions - self.order(0, 10) + self.order(self.sid(0), 10) self.target_shares = 10 return else: @@ -415,9 +432,15 @@ class TestTargetValueAlgorithm(TradingAlgorithm): assert self.portfolio.positions[0]['last_sale_price'] == \ data[0].price, "Orders not filled at current price." - self.order_target_value(0, 20) + self.order_target_value(self.sid(0), 20) self.target_shares = np.round(20 / data[0].price) + if isinstance(self.sid(0), Equity): + self.target_shares = np.round(20 / data[0].price) + if isinstance(self.sid(0), Future): + self.target_shares = np.round( + 20 / (data[0].price * self.sid(0).contract_multiplier)) + ############################ # AccountControl Test Algos# @@ -468,6 +491,18 @@ class SetLongOnlyAlgorithm(TradingAlgorithm): self.set_long_only() +class SetAssetDateBoundsAlgorithm(TradingAlgorithm): + """ + Algorithm that tries to order 1 share of sid 0 on every bar and has an + AssetDateBounds() trading control in place. + """ + def initialize(self): + self.register_trading_control(AssetDateBounds()) + + def handle_data(algo, data): + algo.order(algo.sid(0), 1) + + class TestRegisterTransformAlgorithm(TradingAlgorithm): def initialize(self, *args, **kwargs): self.set_slippage(FixedSlippage()) @@ -484,7 +519,7 @@ class AmbitiousStopLimitAlgorithm(TradingAlgorithm): """ def initialize(self, *args, **kwargs): - self.sid = kwargs.pop('sid') + self.asset = self.sid(kwargs.pop('sid')) def handle_data(self, data): @@ -493,42 +528,42 @@ class AmbitiousStopLimitAlgorithm(TradingAlgorithm): ######## # Buy with low limit, shouldn't trigger. - self.order(self.sid, 100, limit_price=1) + self.order(self.asset, 100, limit_price=1) # But with high stop, shouldn't trigger - self.order(self.sid, 100, stop_price=10000000) + self.order(self.asset, 100, stop_price=10000000) # Buy with high limit (should trigger) but also high stop (should # prevent trigger). - self.order(self.sid, 100, limit_price=10000000, stop_price=10000000) + self.order(self.asset, 100, limit_price=10000000, stop_price=10000000) # Buy with low stop (should trigger), but also low limit (should # prevent trigger). - self.order(self.sid, 100, limit_price=1, stop_price=1) + self.order(self.asset, 100, limit_price=1, stop_price=1) ######### # Sells # ######### # Sell with high limit, shouldn't trigger. - self.order(self.sid, -100, limit_price=1000000) + self.order(self.asset, -100, limit_price=1000000) # Sell with low stop, shouldn't trigger. - self.order(self.sid, -100, stop_price=1) + self.order(self.asset, -100, stop_price=1) # Sell with low limit (should trigger), but also high stop (should # prevent trigger). - self.order(self.sid, -100, limit_price=1000000, stop_price=1000000) + self.order(self.asset, -100, limit_price=1000000, stop_price=1000000) # Sell with low limit (should trigger), but also low stop (should # prevent trigger). - self.order(self.sid, -100, limit_price=1, stop_price=1) + self.order(self.asset, -100, limit_price=1, stop_price=1) ################### # Rounding Checks # ################### - self.order(self.sid, 100, limit_price=.00000001) - self.order(self.sid, -100, stop_price=.00000001) + self.order(self.asset, 100, limit_price=.00000001) + self.order(self.asset, -100, stop_price=.00000001) ########################################## @@ -695,8 +730,8 @@ class BatchTransformAlgorithm(TradingAlgorithm): "batch transform is not updating for new kwargs" new_data = deepcopy(data) - for sid in new_data: - new_data[sid]['arbitrary'] = 123 + for sidint in new_data: + new_data[sidint]['arbitrary'] = 123 self.history_return_arbitrary_fields.append( self.return_arbitrary_fields.handle_data(new_data)) @@ -707,8 +742,7 @@ class BatchTransformAlgorithm(TradingAlgorithm): self.return_nan.handle_data(data)) else: nan_data = deepcopy(data) - for sid in nan_data.iterkeys(): - nan_data[sid].price = np.nan + nan_data.price = np.nan self.history_return_nan.append( self.return_nan.handle_data(nan_data)) @@ -810,7 +844,7 @@ class EmptyPositionsAlgorithm(TradingAlgorithm): def handle_data(self, data): if not self.ordered: for s in data: - self.order(s, 100) + self.order(self.sid(s), 100) self.ordered = True if not self.exited: @@ -821,7 +855,7 @@ class EmptyPositionsAlgorithm(TradingAlgorithm): (len(amounts) == len(data.keys())) ): for stock in self.portfolio.positions: - self.order(stock, -100) + self.order(self.sid(stock), -100) self.exited = True # Should be 0 when all positions are exited. @@ -834,7 +868,7 @@ class InvalidOrderAlgorithm(TradingAlgorithm): appropriate exceptions are raised. """ def initialize(self, *args, **kwargs): - self.sid = kwargs.pop('sids')[0] + self.asset = self.sid(kwargs.pop('sids')[0]) def handle_data(self, data): from zipline.api import ( @@ -849,40 +883,48 @@ class InvalidOrderAlgorithm(TradingAlgorithm): StopOrder(10), StopLimitOrder(10, 10)]: with assert_raises(UnsupportedOrderParameters): - order(self.sid, 10, limit_price=10, style=style) + order(self.asset, 10, limit_price=10, style=style) with assert_raises(UnsupportedOrderParameters): - order(self.sid, 10, stop_price=10, style=style) + order(self.asset, 10, stop_price=10, style=style) with assert_raises(UnsupportedOrderParameters): - order_value(self.sid, 300, limit_price=10, style=style) + order_value(self.asset, 300, limit_price=10, style=style) with assert_raises(UnsupportedOrderParameters): - order_value(self.sid, 300, stop_price=10, style=style) + order_value(self.asset, 300, stop_price=10, style=style) with assert_raises(UnsupportedOrderParameters): - order_percent(self.sid, .1, limit_price=10, style=style) + order_percent(self.asset, .1, limit_price=10, style=style) with assert_raises(UnsupportedOrderParameters): - order_percent(self.sid, .1, stop_price=10, style=style) + order_percent(self.asset, .1, stop_price=10, style=style) with assert_raises(UnsupportedOrderParameters): - order_target(self.sid, 100, limit_price=10, style=style) + order_target(self.asset, 100, limit_price=10, style=style) with assert_raises(UnsupportedOrderParameters): - order_target(self.sid, 100, stop_price=10, style=style) + order_target(self.asset, 100, stop_price=10, style=style) with assert_raises(UnsupportedOrderParameters): - order_target_value(self.sid, 100, limit_price=10, style=style) + order_target_value(self.asset, 100, + limit_price=10, + style=style) with assert_raises(UnsupportedOrderParameters): - order_target_value(self.sid, 100, stop_price=10, style=style) + order_target_value(self.asset, 100, + stop_price=10, + style=style) with assert_raises(UnsupportedOrderParameters): - order_target_percent(self.sid, .2, limit_price=10, style=style) + order_target_percent(self.asset, .2, + limit_price=10, + style=style) with assert_raises(UnsupportedOrderParameters): - order_target_percent(self.sid, .2, stop_price=10, style=style) + order_target_percent(self.asset, .2, + stop_price=10, + style=style) ############################## @@ -913,7 +955,7 @@ def handle_data_api(context, data): assert context.portfolio.positions[0]['last_sale_price'] == \ data[0].price, "Orders not filled at current price." context.incr += 1 - order(0, 1) + order(sid(0), 1) record(incr=context.incr) @@ -932,7 +974,8 @@ api_algo = """ from zipline.api import (order, set_slippage, FixedSlippage, - record) + record, + sid) def initialize(context): context.incr = 0 @@ -948,7 +991,7 @@ def handle_data(context, data): assert context.portfolio.positions[0]['last_sale_price'] == \ data[0].price, "Orders not filled at current price." context.incr += 1 - order(0, 1) + order(sid(0), 1) record(incr=context.incr) """ @@ -1009,18 +1052,19 @@ from zipline.api import (order, order_percent, order_target, order_target_value, - order_target_percent) + order_target_percent, + sid) def initialize(context): pass def handle_data(context, data): - order(0, 10) - order_value(0, 300) - order_percent(0, .1) - order_target(0, 100) - order_target_value(0, 100) - order_target_percent(0, .2) + order(sid(0), 10) + order_value(sid(0), 300) + order_percent(sid(0), .1) + order_target(sid(0), 100) + order_target_value(sid(0), 100) + order_target_percent(sid(0), .2) """ record_variables = """ diff --git a/zipline/utils/cli.py b/zipline/utils/cli.py index 017b011f..ad9d58d1 100644 --- a/zipline/utils/cli.py +++ b/zipline/utils/cli.py @@ -31,14 +31,15 @@ except: PYGMENTS = False import zipline +from zipline.errors import NoSourceError, PipelineDateError DEFAULTS = { - 'start': '2012-01-01', - 'end': '2012-12-31', 'data_frequency': 'daily', 'capital_base': '10e6', 'source': 'yahoo', - 'symbols': 'AAPL' + 'symbols': 'AAPL', + 'metadata_index': 'symbol', + 'source_time_column': 'Date', } @@ -99,9 +100,12 @@ def parse_args(argv, ipython_mode=False): parser.add_argument('--start', '-s') parser.add_argument('--end', '-e') parser.add_argument('--capital_base') - parser.add_argument('--source', choices=('yahoo',)) + parser.add_argument('--source', '-d') + parser.add_argument('--source_time_column', '-t') parser.add_argument('--symbols') parser.add_argument('--output', '-o') + parser.add_argument('--metadata_path', '-m') + parser.add_argument('--metadata_index', '-x') if ipython_mode: parser.add_argument('--local_namespace', action='store_true') @@ -153,14 +157,59 @@ def run_pipeline(print_algo=True, **kwargs): pygments syntax coloring if pygments is found. """ - start = pd.Timestamp(kwargs['start'], tz='UTC') - end = pd.Timestamp(kwargs['end'], tz='UTC') + start = kwargs['start'] + end = kwargs['end'] + # Compare against None because strings/timestamps may have been given + if start is not None: + start = pd.Timestamp(start, tz='UTC') + if end is not None: + end = pd.Timestamp(end, tz='UTC') + + # Fail out if only one bound is provided + if ((start is None) or (end is None)) and (start != end): + raise PipelineDateError(start=start, end=end) + + # Check if start and end are provided, and if the sim_params need to read + # a start and end from the DataSource + if start is None: + overwrite_sim_params = True + else: + overwrite_sim_params = False symbols = kwargs['symbols'].split(',') + asset_identifier = kwargs['metadata_index'] - if kwargs['source'] == 'yahoo': + # Pull asset metadata + asset_metadata = kwargs.get('asset_metadata', None) + asset_metadata_path = kwargs['metadata_path'] + # Read in a CSV file, if applicable + if asset_metadata_path is not None: + if os.path.isfile(asset_metadata_path): + asset_metadata = pd.read_csv(asset_metadata_path, + index_col=asset_identifier) + + source_arg = kwargs['source'] + source_time_column = kwargs['source_time_column'] + + if source_arg is None: + raise NoSourceError() + + elif source_arg == 'yahoo': source = zipline.data.load_bars_from_yahoo( stocks=symbols, start=start, end=end) + + elif os.path.isfile(source_arg): + source = zipline.data.load_prices_from_csv( + filepath=source_arg, + identifier_col=source_time_column + ) + + elif os.path.isdir(source_arg): + source = zipline.data.load_prices_from_csv_folder( + folderpath=source_arg, + identifier_col=source_time_column + ) + else: raise NotImplementedError( 'Source %s not implemented.' % kwargs['source']) @@ -188,9 +237,13 @@ def run_pipeline(print_algo=True, **kwargs): algo = zipline.TradingAlgorithm(script=algo_text, namespace=kwargs.get('namespace', {}), capital_base=float(kwargs['capital_base']), - algo_filename=kwargs.get('algofile')) + algo_filename=kwargs.get('algofile'), + asset_metadata=asset_metadata, + identifiers=symbols, + start=start, + end=end) - perf = algo.run(source) + perf = algo.run(source, overwrite_sim_params=overwrite_sim_params) output_fname = kwargs.get('output', None) if output_fname is not None: diff --git a/zipline/utils/factory.py b/zipline/utils/factory.py index eee575f2..beb05aa2 100644 --- a/zipline/utils/factory.py +++ b/zipline/utils/factory.py @@ -120,6 +120,7 @@ def create_trade_history(sid, prices, amounts, interval, sim_params, source_id="test_factory"): trades = [] current = sim_params.first_open + trading.environment.update_asset_finder(identifiers=[sid]) oneday = timedelta(days=1) use_midnight = interval >= oneday @@ -149,7 +150,6 @@ def create_dividend(sid, payment, declared_date, ex_date, pay_date): 'type': DATASOURCE_TYPE.DIVIDEND, 'source_id': 'MockDividendSource' }) - return div @@ -312,6 +312,8 @@ def create_test_df_source(sim_params=None, bars='daily'): df = pd.DataFrame(x, index=index, columns=[0]) + trading.environment.update_asset_finder(identifiers=[0]) + return DataFrameSource(df), df diff --git a/zipline/utils/security_list.py b/zipline/utils/security_list.py index 341b2135..c9ba8658 100644 --- a/zipline/utils/security_list.py +++ b/zipline/utils/security_list.py @@ -5,6 +5,7 @@ import os.path import pandas as pd import pytz import zipline +from zipline.finance.trading import with_environment DATE_FORMAT = "%Y%m%d" @@ -12,23 +13,16 @@ zipline_dir = os.path.dirname(zipline.__file__) SECURITY_LISTS_DIR = os.path.join(zipline_dir, 'resources', 'security_lists') -def loopback(symbol, *args, **kwargs): - return symbol - - class SecurityList(object): - def __init__(self, lookup_func, data, current_date_func): + def __init__(self, data, current_date_func): """ - lookup_func: function that takes a string symbol and a date and - returns a Security object. data: a nested dictionary: knowledge_date -> lookup_date -> {add: [symbol list], 'delete': []}, delete: [symbol list]} current_date_func: function taking no parameters, returning current datetime """ - self.lookup_func = lookup_func self.data = data self._cache = {} self._knowledge_dates = self.make_knowledge_dates(self.data) @@ -74,13 +68,17 @@ class SecurityList(object): self._cache[kd] = self._current_set return self._current_set - def update_current(self, effective_date, symbols, change_func): + @with_environment() + def update_current(self, effective_date, symbols, change_func, env=None): for symbol in symbols: - sid = self.lookup_func( + asset = env.asset_finder.lookup_symbol( symbol, as_of_date=effective_date ) - change_func(sid) + # Pass if no Asset exists for the symbol + if asset is None: + continue + change_func(asset.sid) class SecurityListSet(object): @@ -88,11 +86,7 @@ class SecurityListSet(object): # list implementations. security_list_type = SecurityList - def __init__(self, current_date_func, lookup_func=None): - if lookup_func is None: - self.lookup_func = loopback - else: - self.lookup_func = lookup_func + def __init__(self, current_date_func): self.current_date_func = current_date_func self._leveraged_etf = None @@ -100,7 +94,6 @@ class SecurityListSet(object): def leveraged_etf_list(self): if self._leveraged_etf is None: self._leveraged_etf = self.security_list_type( - self.lookup_func, load_from_directory('leveraged_etf_list'), self.current_date_func ) diff --git a/zipline/utils/simfactory.py b/zipline/utils/simfactory.py index cce56ab7..876d6734 100644 --- a/zipline/utils/simfactory.py +++ b/zipline/utils/simfactory.py @@ -7,7 +7,7 @@ def create_test_zipline(**config): """ :param config: A configuration object that is a dict with: - - sid - an integer, which will be used as the security ID. + - sid - an integer, which will be used as the asset ID. - order_count - the number of orders the test algo will place, defaults to 100 - order_amount - the number of shares per order, defaults to 100 @@ -25,10 +25,15 @@ def create_test_zipline(**config): :py:mod:`zipline.finance.trading` """ assert isinstance(config, dict) - sid_list = config.get('sid_list') - if not sid_list: - sid = config.get('sid') - sid_list = [sid] + + try: + sid_list = config['sid_list'] + except KeyError: + try: + sid_list = [config['sid']] + except KeyError: + raise Exception("simfactory create_test_zipline() requires " + "argument 'sid_list' or 'sid'") concurrent_trades = config.get('concurrent_trades', False) @@ -49,12 +54,13 @@ def create_test_zipline(**config): test_algo = config['algorithm'] else: test_algo = TestAlgorithm( - sid, + sid_list[0], order_amount, order_count, sim_params=config.get('sim_params', factory.create_simulation_parameters()), slippage=config.get('slippage'), + identifiers=sid_list ) # -------------------