diff --git a/zipline/finance/blotter.py b/zipline/finance/blotter.py index 823b3358..53e809dd 100644 --- a/zipline/finance/blotter.py +++ b/zipline/finance/blotter.py @@ -35,7 +35,10 @@ from zipline.finance.commission import PerShare log = Logger('Blotter') from zipline.utils.protocol_utils import Enum -from zipline.utils.serialization_utils import SerializeableZiplineObject +from zipline.utils.serialization_utils import ( + SerializeableZiplineObject, + VERSION_LABEL +) ORDER_STATUS = Enum( 'OPEN', @@ -259,8 +262,21 @@ class Blotter(SerializeableZiplineObject): state_dict['open_orders'] = \ self._defaultdict_list_get_state(self.open_orders) + STATE_VERSION = 1 + state_dict[VERSION_LABEL] = STATE_VERSION + return state_dict + def __setstate__(self, state): + + OLDEST_SUPPORTED_STATE = 1 + version = state.pop(VERSION_LABEL) + + if version < OLDEST_SUPPORTED_STATE: + raise BaseException("Blotter saved is state too old.") + + super(Blotter, self).__setstate__(state) + class Order(SerializeableZiplineObject): def __init__(self, dt, sid, amount, stop=None, limit=None, filled=0, @@ -401,8 +417,22 @@ class Order(SerializeableZiplineObject): return text_type(repr(self)) def __getstate__(self): + state_dict = super(Order, self).__getstate__() state_dict['_status'] = self._status + STATE_VERSION = 1 + state_dict[VERSION_LABEL] = STATE_VERSION + return state_dict + + def __setstate__(self, state): + + OLDEST_SUPPORTED_STATE = 1 + version = state.pop(VERSION_LABEL) + + if version < OLDEST_SUPPORTED_STATE: + raise BaseException("Order saved state is too old.") + + super(Order, self).__setstate__(state) diff --git a/zipline/finance/commission.py b/zipline/finance/commission.py index f7bd84c7..5a5bfdea 100644 --- a/zipline/finance/commission.py +++ b/zipline/finance/commission.py @@ -13,7 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from zipline.utils.serialization_utils import SerializeableZiplineObject +from zipline.utils.serialization_utils import ( + SerializeableZiplineObject, + VERSION_LABEL +) class PerShare(SerializeableZiplineObject): @@ -52,6 +55,25 @@ class PerShare(SerializeableZiplineObject): commission = max(commission, self.min_trade_cost) return abs(commission / transaction.amount), commission + def __getstate__(self): + + state_dict = super(PerShare, self).__getstate__() + + STATE_VERSION = 1 + state_dict[VERSION_LABEL] = STATE_VERSION + + return state_dict + + def __setstate__(self, state): + + OLDEST_SUPPORTED_STATE = 1 + version = state.pop(VERSION_LABEL) + + if version < OLDEST_SUPPORTED_STATE: + raise BaseException("PerShare saved state is too old.") + + super(PerShare, self).__setstate__(state) + class PerTrade(SerializeableZiplineObject): """ @@ -79,6 +101,25 @@ class PerTrade(SerializeableZiplineObject): return abs(self.cost / transaction.amount), self.cost + def __getstate__(self): + + state_dict = super(PerTrade, self).__getstate__() + + STATE_VERSION = 1 + state_dict[VERSION_LABEL] = STATE_VERSION + + return state_dict + + def __setstate__(self, state): + + OLDEST_SUPPORTED_STATE = 1 + version = state.pop(VERSION_LABEL) + + if version < OLDEST_SUPPORTED_STATE: + raise BaseException("PerTrade saved state is too old.") + + super(PerTrade, self).__setstate__(state) + class PerDollar(SerializeableZiplineObject): """ @@ -105,3 +146,22 @@ class PerDollar(SerializeableZiplineObject): """ cost_per_share = transaction.price * self.cost return cost_per_share, abs(transaction.amount) * cost_per_share + + def __getstate__(self): + + state_dict = super(PerDollar, self).__getstate__() + + STATE_VERSION = 1 + state_dict[VERSION_LABEL] = STATE_VERSION + + return state_dict + + def __setstate__(self, state): + + OLDEST_SUPPORTED_STATE = 1 + version = state.pop(VERSION_LABEL) + + if version < OLDEST_SUPPORTED_STATE: + raise BaseException("PerDollar saved state is too old.") + + super(PerDollar, self).__setstate__(state) diff --git a/zipline/finance/performance/period.py b/zipline/finance/performance/period.py index 6a9ccaee..6f966259 100644 --- a/zipline/finance/performance/period.py +++ b/zipline/finance/performance/period.py @@ -92,7 +92,10 @@ from six import iteritems, itervalues import zipline.protocol as zp from . position import positiondict -from zipline.utils.serialization_utils import SerializeableZiplineObject +from zipline.utils.serialization_utils import ( + SerializeableZiplineObject, + VERSION_LABEL +) log = logbook.Logger('Performance') TRADE_TYPE = zp.DATASOURCE_TYPE.TRADE @@ -595,9 +598,19 @@ class PerformancePeriod(SerializeableZiplineObject): state_dict['_positions_store'] = \ self._positions_get_state(self._positions_store) + STATE_VERSION = 1 + state_dict[VERSION_LABEL] = STATE_VERSION + return state_dict def __setstate__(self, state): + + OLDEST_SUPPORTED_STATE = 1 + version = state.pop(VERSION_LABEL) + + if version < OLDEST_SUPPORTED_STATE: + raise BaseException("PerformancePeriod saved state is too old.") + super(PerformancePeriod, self).__setstate__(state) self.initialize_position_calc_arrays() diff --git a/zipline/finance/performance/position.py b/zipline/finance/performance/position.py index b7f2a2aa..9e309683 100644 --- a/zipline/finance/performance/position.py +++ b/zipline/finance/performance/position.py @@ -41,7 +41,10 @@ from math import ( import logbook import zipline.protocol as zp -from zipline.utils.serialization_utils import SerializeableZiplineObject +from zipline.utils.serialization_utils import ( + SerializeableZiplineObject, + VERSION_LABEL +) log = logbook.Logger('Performance') @@ -210,7 +213,23 @@ last_sale_price: {last_sale_price}" } def __getstate__(self): - return self.__dict__ + + state_dict = self.__dict__ + + STATE_VERSION = 1 + state_dict[VERSION_LABEL] = STATE_VERSION + + return state_dict + + def __setstate__(self, state): + + OLDEST_SUPPORTED_STATE = 1 + version = state.pop(VERSION_LABEL) + + if version < OLDEST_SUPPORTED_STATE: + raise BaseException("Position saved state is too old.") + + super(Position, self).__setstate__(state) class positiondict(dict): diff --git a/zipline/finance/performance/tracker.py b/zipline/finance/performance/tracker.py index d27ecf2e..9ff2a649 100644 --- a/zipline/finance/performance/tracker.py +++ b/zipline/finance/performance/tracker.py @@ -71,7 +71,10 @@ from zipline.finance import trading from . period import PerformancePeriod from zipline.finance.trading import with_environment -from zipline.utils.serialization_utils import SerializeableZiplineObject +from zipline.utils.serialization_utils import ( + SerializeableZiplineObject, + VERSION_LABEL +) log = logbook.Logger('Performance') @@ -493,9 +496,19 @@ class PerformanceTracker(SerializeableZiplineObject): state_dict['_dividend_count'] = self._dividend_count + STATE_VERSION = 1 + state_dict[VERSION_LABEL] = STATE_VERSION + return state_dict def __setstate__(self, state): + + OLDEST_SUPPORTED_STATE = 1 + version = state.pop(VERSION_LABEL) + + if version < OLDEST_SUPPORTED_STATE: + raise BaseException("PerformanceTracker saved state is too old.") + super(PerformanceTracker, self).__setstate__(state) # Handle the dividend frame specially diff --git a/zipline/finance/risk/cumulative.py b/zipline/finance/risk/cumulative.py index 0f5b5129..9dcc595c 100644 --- a/zipline/finance/risk/cumulative.py +++ b/zipline/finance/risk/cumulative.py @@ -35,7 +35,10 @@ from . risk import ( sortino_ratio, ) -from zipline.utils.serialization_utils import SerializeableZiplineObject +from zipline.utils.serialization_utils import ( + SerializeableZiplineObject, + VERSION_LABEL +) log = logbook.Logger('Risk Cumulative') @@ -462,9 +465,20 @@ algorithm_returns ({algo_count}) in range {start} : {end} on {dt}" {k: v for k, v in self.__dict__.iteritems() if (not k.startswith('_') and not k == 'treasury_curves')} + STATE_VERSION = 1 + state_dict[VERSION_LABEL] = STATE_VERSION + return state_dict def __setstate__(self, state): + + OLDEST_SUPPORTED_STATE = 1 + version = state.pop(VERSION_LABEL) + + if version < OLDEST_SUPPORTED_STATE: + raise BaseException("RiskMetricsCumulative \ + saved state is too old.") + super(RiskMetricsCumulative, self).__setstate__(state) # This are big and we don't need to serialize them diff --git a/zipline/finance/risk/period.py b/zipline/finance/risk/period.py index 51762e85..0a0b45e5 100644 --- a/zipline/finance/risk/period.py +++ b/zipline/finance/risk/period.py @@ -36,7 +36,10 @@ from . risk import ( sortino_ratio, ) -from zipline.utils.serialization_utils import SerializeableZiplineObject +from zipline.utils.serialization_utils import ( + SerializeableZiplineObject, + VERSION_LABEL +) log = logbook.Logger('Risk Period') @@ -312,9 +315,20 @@ class RiskMetricsPeriod(SerializeableZiplineObject): {k: v for k, v in self.__dict__.iteritems() if (not k.startswith('_') and not k == 'treasury_curves')} + STATE_VERSION = 1 + state_dict[VERSION_LABEL] = STATE_VERSION + return state_dict def __setstate__(self, state): + + OLDEST_SUPPORTED_STATE = 1 + version = state.pop(VERSION_LABEL) + + if version < OLDEST_SUPPORTED_STATE: + raise BaseException("RiskMetricsPeriod saved state \ + is too old.") + super(RiskMetricsPeriod, self).__setstate__(state) self.treasury_curves = trading.environment.treasury_curves diff --git a/zipline/finance/risk/report.py b/zipline/finance/risk/report.py index c7f73657..5252221b 100644 --- a/zipline/finance/risk/report.py +++ b/zipline/finance/risk/report.py @@ -61,7 +61,10 @@ from dateutil.relativedelta import relativedelta from . period import RiskMetricsPeriod -from zipline.utils.serialization_utils import SerializeableZiplineObject +from zipline.utils.serialization_utils import ( + SerializeableZiplineObject, + VERSION_LABEL +) log = logbook.Logger('Risk Report') @@ -147,4 +150,17 @@ class RiskReport(SerializeableZiplineObject): if '_dividend_count' in dir(self): state_dict['_dividend_count'] = self._dividend_count + STATE_VERSION = 1 + state_dict[VERSION_LABEL] = STATE_VERSION + return state_dict + + def __setstate__(self, state): + + OLDEST_SUPPORTED_STATE = 1 + version = state.pop(VERSION_LABEL) + + if version < OLDEST_SUPPORTED_STATE: + raise BaseException("RiskReport saved state is too old.") + + super(RiskReport, self).__setstate__(state) diff --git a/zipline/finance/slippage.py b/zipline/finance/slippage.py index 656e9aad..59109c38 100644 --- a/zipline/finance/slippage.py +++ b/zipline/finance/slippage.py @@ -24,7 +24,10 @@ from functools import partial from six import with_metaclass from zipline.protocol import DATASOURCE_TYPE -from zipline.utils.serialization_utils import SerializeableZiplineObject +from zipline.utils.serialization_utils import ( + SerializeableZiplineObject, + VERSION_LABEL +) SELL = 1 << 0 BUY = 1 << 1 @@ -130,7 +133,23 @@ class Transaction(SerializeableZiplineObject): return py def __getstate__(self): - return self.__dict__ + + state_dict = self.__dict__ + + STATE_VERSION = 1 + state_dict[VERSION_LABEL] = STATE_VERSION + + return state_dict + + def __setstate__(self, state): + + OLDEST_SUPPORTED_STATE = 1 + version = state.pop(VERSION_LABEL) + + if version < OLDEST_SUPPORTED_STATE: + raise BaseException("Transaction saved state is too old.") + + super(Transaction, self).__setstate__(state) def create_transaction(event, order, price, amount): @@ -251,9 +270,22 @@ class VolumeShareSlippage(SlippageModel, SerializeableZiplineObject): ) def __getstate__(self): - return self.__dict__ + + state_dict = self.__dict__ + + STATE_VERSION = 1 + state_dict[VERSION_LABEL] = STATE_VERSION + + return state_dict def __setstate__(self, state): + + OLDEST_SUPPORTED_STATE = 1 + version = state.pop(VERSION_LABEL) + + if version < OLDEST_SUPPORTED_STATE: + raise BaseException("VolumeShareSlippage saved state is too old.") + self.__dict__.update(state) @@ -276,7 +308,20 @@ class FixedSlippage(SlippageModel, SerializeableZiplineObject): ) def __getstate__(self): - return self.__dict__ + + state_dict = self.__dict__ + + STATE_VERSION = 1 + state_dict[VERSION_LABEL] = STATE_VERSION + + return state_dict def __setstate__(self, state): + + OLDEST_SUPPORTED_STATE = 1 + version = state.pop(VERSION_LABEL) + + if version < OLDEST_SUPPORTED_STATE: + raise BaseException("FixedSlippage saved state is too old.") + self.__dict__.update(state) diff --git a/zipline/protocol.py b/zipline/protocol.py index 038ab693..84294dad 100644 --- a/zipline/protocol.py +++ b/zipline/protocol.py @@ -22,7 +22,10 @@ from . utils.math_utils import nanstd, nanmean, nansum from zipline.finance.trading import with_environment from zipline.utils.algo_instance import get_algo_instance -from zipline.utils.serialization_utils import SerializeableZiplineObject +from zipline.utils.serialization_utils import ( + SerializeableZiplineObject, + VERSION_LABEL +) # Datasource type should completely determine the other fields of a # message with its type. @@ -140,7 +143,23 @@ class Portfolio(SerializeableZiplineObject): return "Portfolio({0})".format(self.__dict__) def __getstate__(self): - return self.__dict__ + + state_dict = self.__dict__ + + STATE_VERSION = 1 + state_dict[VERSION_LABEL] = STATE_VERSION + + return state_dict + + def __setstate__(self, state): + + OLDEST_SUPPORTED_STATE = 1 + version = state.pop(VERSION_LABEL) + + if version < OLDEST_SUPPORTED_STATE: + raise BaseException("Portfolio saved state is too old.") + + super(Portfolio, self).__setstate__(state) class Account(SerializeableZiplineObject): @@ -176,7 +195,23 @@ class Account(SerializeableZiplineObject): return "Account({0})".format(self.__dict__) def __getstate__(self): - return self.__dict__ + + state_dict = self.__dict__ + + STATE_VERSION = 1 + state_dict[VERSION_LABEL] = STATE_VERSION + + return state_dict + + def __setstate__(self, state): + + OLDEST_SUPPORTED_STATE = 1 + version = state.pop(VERSION_LABEL) + + if version < OLDEST_SUPPORTED_STATE: + raise BaseException("Account saved state is too old.") + + super(Account, self).__setstate__(state) class Position(SerializeableZiplineObject): @@ -194,7 +229,23 @@ class Position(SerializeableZiplineObject): return "Position({0})".format(self.__dict__) def __getstate__(self): - return self.__dict__ + + state_dict = self.__dict__ + + STATE_VERSION = 1 + state_dict[VERSION_LABEL] = STATE_VERSION + + return state_dict + + def __setstate__(self, state): + + OLDEST_SUPPORTED_STATE = 1 + version = state.pop(VERSION_LABEL) + + if version < OLDEST_SUPPORTED_STATE: + raise BaseException("Protocol Position saved state is too old.") + + super(Position, self).__setstate__(state) class Positions(dict): diff --git a/zipline/utils/serialization_utils.py b/zipline/utils/serialization_utils.py index 9da19aa8..60ba5845 100644 --- a/zipline/utils/serialization_utils.py +++ b/zipline/utils/serialization_utils.py @@ -1,3 +1,8 @@ +# Label for the serialization version field in the state returned by +# __getstate__. +VERSION_LABEL = '_stateversion_' + + class SerializeableZiplineObject(object): """ This class implements the basic set and get state methods used for