ENH: Added versioning logic to objects.

In order to be able to load from saved state generated by old
code, we need to have a notion of the version of the saved state.
This commit is contained in:
Delaney Granizo-Mackenzie
2015-02-26 18:15:16 -05:00
parent 64eed84bff
commit c6596e2ee2
11 changed files with 297 additions and 17 deletions
+31 -1
View File
@@ -35,7 +35,10 @@ from zipline.finance.commission import PerShare
log = Logger('Blotter')
from zipline.utils.protocol_utils import Enum
from zipline.utils.serialization_utils import SerializeableZiplineObject
from zipline.utils.serialization_utils import (
SerializeableZiplineObject,
VERSION_LABEL
)
ORDER_STATUS = Enum(
'OPEN',
@@ -259,8 +262,21 @@ class Blotter(SerializeableZiplineObject):
state_dict['open_orders'] = \
self._defaultdict_list_get_state(self.open_orders)
STATE_VERSION = 1
state_dict[VERSION_LABEL] = STATE_VERSION
return state_dict
def __setstate__(self, state):
OLDEST_SUPPORTED_STATE = 1
version = state.pop(VERSION_LABEL)
if version < OLDEST_SUPPORTED_STATE:
raise BaseException("Blotter saved is state too old.")
super(Blotter, self).__setstate__(state)
class Order(SerializeableZiplineObject):
def __init__(self, dt, sid, amount, stop=None, limit=None, filled=0,
@@ -401,8 +417,22 @@ class Order(SerializeableZiplineObject):
return text_type(repr(self))
def __getstate__(self):
state_dict = super(Order, self).__getstate__()
state_dict['_status'] = self._status
STATE_VERSION = 1
state_dict[VERSION_LABEL] = STATE_VERSION
return state_dict
def __setstate__(self, state):
OLDEST_SUPPORTED_STATE = 1
version = state.pop(VERSION_LABEL)
if version < OLDEST_SUPPORTED_STATE:
raise BaseException("Order saved state is too old.")
super(Order, self).__setstate__(state)
+61 -1
View File
@@ -13,7 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from zipline.utils.serialization_utils import SerializeableZiplineObject
from zipline.utils.serialization_utils import (
SerializeableZiplineObject,
VERSION_LABEL
)
class PerShare(SerializeableZiplineObject):
@@ -52,6 +55,25 @@ class PerShare(SerializeableZiplineObject):
commission = max(commission, self.min_trade_cost)
return abs(commission / transaction.amount), commission
def __getstate__(self):
state_dict = super(PerShare, self).__getstate__()
STATE_VERSION = 1
state_dict[VERSION_LABEL] = STATE_VERSION
return state_dict
def __setstate__(self, state):
OLDEST_SUPPORTED_STATE = 1
version = state.pop(VERSION_LABEL)
if version < OLDEST_SUPPORTED_STATE:
raise BaseException("PerShare saved state is too old.")
super(PerShare, self).__setstate__(state)
class PerTrade(SerializeableZiplineObject):
"""
@@ -79,6 +101,25 @@ class PerTrade(SerializeableZiplineObject):
return abs(self.cost / transaction.amount), self.cost
def __getstate__(self):
state_dict = super(PerTrade, self).__getstate__()
STATE_VERSION = 1
state_dict[VERSION_LABEL] = STATE_VERSION
return state_dict
def __setstate__(self, state):
OLDEST_SUPPORTED_STATE = 1
version = state.pop(VERSION_LABEL)
if version < OLDEST_SUPPORTED_STATE:
raise BaseException("PerTrade saved state is too old.")
super(PerTrade, self).__setstate__(state)
class PerDollar(SerializeableZiplineObject):
"""
@@ -105,3 +146,22 @@ class PerDollar(SerializeableZiplineObject):
"""
cost_per_share = transaction.price * self.cost
return cost_per_share, abs(transaction.amount) * cost_per_share
def __getstate__(self):
state_dict = super(PerDollar, self).__getstate__()
STATE_VERSION = 1
state_dict[VERSION_LABEL] = STATE_VERSION
return state_dict
def __setstate__(self, state):
OLDEST_SUPPORTED_STATE = 1
version = state.pop(VERSION_LABEL)
if version < OLDEST_SUPPORTED_STATE:
raise BaseException("PerDollar saved state is too old.")
super(PerDollar, self).__setstate__(state)
+14 -1
View File
@@ -92,7 +92,10 @@ from six import iteritems, itervalues
import zipline.protocol as zp
from . position import positiondict
from zipline.utils.serialization_utils import SerializeableZiplineObject
from zipline.utils.serialization_utils import (
SerializeableZiplineObject,
VERSION_LABEL
)
log = logbook.Logger('Performance')
TRADE_TYPE = zp.DATASOURCE_TYPE.TRADE
@@ -595,9 +598,19 @@ class PerformancePeriod(SerializeableZiplineObject):
state_dict['_positions_store'] = \
self._positions_get_state(self._positions_store)
STATE_VERSION = 1
state_dict[VERSION_LABEL] = STATE_VERSION
return state_dict
def __setstate__(self, state):
OLDEST_SUPPORTED_STATE = 1
version = state.pop(VERSION_LABEL)
if version < OLDEST_SUPPORTED_STATE:
raise BaseException("PerformancePeriod saved state is too old.")
super(PerformancePeriod, self).__setstate__(state)
self.initialize_position_calc_arrays()
+21 -2
View File
@@ -41,7 +41,10 @@ from math import (
import logbook
import zipline.protocol as zp
from zipline.utils.serialization_utils import SerializeableZiplineObject
from zipline.utils.serialization_utils import (
SerializeableZiplineObject,
VERSION_LABEL
)
log = logbook.Logger('Performance')
@@ -210,7 +213,23 @@ last_sale_price: {last_sale_price}"
}
def __getstate__(self):
return self.__dict__
state_dict = self.__dict__
STATE_VERSION = 1
state_dict[VERSION_LABEL] = STATE_VERSION
return state_dict
def __setstate__(self, state):
OLDEST_SUPPORTED_STATE = 1
version = state.pop(VERSION_LABEL)
if version < OLDEST_SUPPORTED_STATE:
raise BaseException("Position saved state is too old.")
super(Position, self).__setstate__(state)
class positiondict(dict):
+14 -1
View File
@@ -71,7 +71,10 @@ from zipline.finance import trading
from . period import PerformancePeriod
from zipline.finance.trading import with_environment
from zipline.utils.serialization_utils import SerializeableZiplineObject
from zipline.utils.serialization_utils import (
SerializeableZiplineObject,
VERSION_LABEL
)
log = logbook.Logger('Performance')
@@ -493,9 +496,19 @@ class PerformanceTracker(SerializeableZiplineObject):
state_dict['_dividend_count'] = self._dividend_count
STATE_VERSION = 1
state_dict[VERSION_LABEL] = STATE_VERSION
return state_dict
def __setstate__(self, state):
OLDEST_SUPPORTED_STATE = 1
version = state.pop(VERSION_LABEL)
if version < OLDEST_SUPPORTED_STATE:
raise BaseException("PerformanceTracker saved state is too old.")
super(PerformanceTracker, self).__setstate__(state)
# Handle the dividend frame specially
+15 -1
View File
@@ -35,7 +35,10 @@ from . risk import (
sortino_ratio,
)
from zipline.utils.serialization_utils import SerializeableZiplineObject
from zipline.utils.serialization_utils import (
SerializeableZiplineObject,
VERSION_LABEL
)
log = logbook.Logger('Risk Cumulative')
@@ -462,9 +465,20 @@ algorithm_returns ({algo_count}) in range {start} : {end} on {dt}"
{k: v for k, v in self.__dict__.iteritems() if
(not k.startswith('_') and not k == 'treasury_curves')}
STATE_VERSION = 1
state_dict[VERSION_LABEL] = STATE_VERSION
return state_dict
def __setstate__(self, state):
OLDEST_SUPPORTED_STATE = 1
version = state.pop(VERSION_LABEL)
if version < OLDEST_SUPPORTED_STATE:
raise BaseException("RiskMetricsCumulative \
saved state is too old.")
super(RiskMetricsCumulative, self).__setstate__(state)
# This are big and we don't need to serialize them
+15 -1
View File
@@ -36,7 +36,10 @@ from . risk import (
sortino_ratio,
)
from zipline.utils.serialization_utils import SerializeableZiplineObject
from zipline.utils.serialization_utils import (
SerializeableZiplineObject,
VERSION_LABEL
)
log = logbook.Logger('Risk Period')
@@ -312,9 +315,20 @@ class RiskMetricsPeriod(SerializeableZiplineObject):
{k: v for k, v in self.__dict__.iteritems() if
(not k.startswith('_') and not k == 'treasury_curves')}
STATE_VERSION = 1
state_dict[VERSION_LABEL] = STATE_VERSION
return state_dict
def __setstate__(self, state):
OLDEST_SUPPORTED_STATE = 1
version = state.pop(VERSION_LABEL)
if version < OLDEST_SUPPORTED_STATE:
raise BaseException("RiskMetricsPeriod saved state \
is too old.")
super(RiskMetricsPeriod, self).__setstate__(state)
self.treasury_curves = trading.environment.treasury_curves
+17 -1
View File
@@ -61,7 +61,10 @@ from dateutil.relativedelta import relativedelta
from . period import RiskMetricsPeriod
from zipline.utils.serialization_utils import SerializeableZiplineObject
from zipline.utils.serialization_utils import (
SerializeableZiplineObject,
VERSION_LABEL
)
log = logbook.Logger('Risk Report')
@@ -147,4 +150,17 @@ class RiskReport(SerializeableZiplineObject):
if '_dividend_count' in dir(self):
state_dict['_dividend_count'] = self._dividend_count
STATE_VERSION = 1
state_dict[VERSION_LABEL] = STATE_VERSION
return state_dict
def __setstate__(self, state):
OLDEST_SUPPORTED_STATE = 1
version = state.pop(VERSION_LABEL)
if version < OLDEST_SUPPORTED_STATE:
raise BaseException("RiskReport saved state is too old.")
super(RiskReport, self).__setstate__(state)
+49 -4
View File
@@ -24,7 +24,10 @@ from functools import partial
from six import with_metaclass
from zipline.protocol import DATASOURCE_TYPE
from zipline.utils.serialization_utils import SerializeableZiplineObject
from zipline.utils.serialization_utils import (
SerializeableZiplineObject,
VERSION_LABEL
)
SELL = 1 << 0
BUY = 1 << 1
@@ -130,7 +133,23 @@ class Transaction(SerializeableZiplineObject):
return py
def __getstate__(self):
return self.__dict__
state_dict = self.__dict__
STATE_VERSION = 1
state_dict[VERSION_LABEL] = STATE_VERSION
return state_dict
def __setstate__(self, state):
OLDEST_SUPPORTED_STATE = 1
version = state.pop(VERSION_LABEL)
if version < OLDEST_SUPPORTED_STATE:
raise BaseException("Transaction saved state is too old.")
super(Transaction, self).__setstate__(state)
def create_transaction(event, order, price, amount):
@@ -251,9 +270,22 @@ class VolumeShareSlippage(SlippageModel, SerializeableZiplineObject):
)
def __getstate__(self):
return self.__dict__
state_dict = self.__dict__
STATE_VERSION = 1
state_dict[VERSION_LABEL] = STATE_VERSION
return state_dict
def __setstate__(self, state):
OLDEST_SUPPORTED_STATE = 1
version = state.pop(VERSION_LABEL)
if version < OLDEST_SUPPORTED_STATE:
raise BaseException("VolumeShareSlippage saved state is too old.")
self.__dict__.update(state)
@@ -276,7 +308,20 @@ class FixedSlippage(SlippageModel, SerializeableZiplineObject):
)
def __getstate__(self):
return self.__dict__
state_dict = self.__dict__
STATE_VERSION = 1
state_dict[VERSION_LABEL] = STATE_VERSION
return state_dict
def __setstate__(self, state):
OLDEST_SUPPORTED_STATE = 1
version = state.pop(VERSION_LABEL)
if version < OLDEST_SUPPORTED_STATE:
raise BaseException("FixedSlippage saved state is too old.")
self.__dict__.update(state)
+55 -4
View File
@@ -22,7 +22,10 @@ from . utils.math_utils import nanstd, nanmean, nansum
from zipline.finance.trading import with_environment
from zipline.utils.algo_instance import get_algo_instance
from zipline.utils.serialization_utils import SerializeableZiplineObject
from zipline.utils.serialization_utils import (
SerializeableZiplineObject,
VERSION_LABEL
)
# Datasource type should completely determine the other fields of a
# message with its type.
@@ -140,7 +143,23 @@ class Portfolio(SerializeableZiplineObject):
return "Portfolio({0})".format(self.__dict__)
def __getstate__(self):
return self.__dict__
state_dict = self.__dict__
STATE_VERSION = 1
state_dict[VERSION_LABEL] = STATE_VERSION
return state_dict
def __setstate__(self, state):
OLDEST_SUPPORTED_STATE = 1
version = state.pop(VERSION_LABEL)
if version < OLDEST_SUPPORTED_STATE:
raise BaseException("Portfolio saved state is too old.")
super(Portfolio, self).__setstate__(state)
class Account(SerializeableZiplineObject):
@@ -176,7 +195,23 @@ class Account(SerializeableZiplineObject):
return "Account({0})".format(self.__dict__)
def __getstate__(self):
return self.__dict__
state_dict = self.__dict__
STATE_VERSION = 1
state_dict[VERSION_LABEL] = STATE_VERSION
return state_dict
def __setstate__(self, state):
OLDEST_SUPPORTED_STATE = 1
version = state.pop(VERSION_LABEL)
if version < OLDEST_SUPPORTED_STATE:
raise BaseException("Account saved state is too old.")
super(Account, self).__setstate__(state)
class Position(SerializeableZiplineObject):
@@ -194,7 +229,23 @@ class Position(SerializeableZiplineObject):
return "Position({0})".format(self.__dict__)
def __getstate__(self):
return self.__dict__
state_dict = self.__dict__
STATE_VERSION = 1
state_dict[VERSION_LABEL] = STATE_VERSION
return state_dict
def __setstate__(self, state):
OLDEST_SUPPORTED_STATE = 1
version = state.pop(VERSION_LABEL)
if version < OLDEST_SUPPORTED_STATE:
raise BaseException("Protocol Position saved state is too old.")
super(Position, self).__setstate__(state)
class Positions(dict):
+5
View File
@@ -1,3 +1,8 @@
# Label for the serialization version field in the state returned by
# __getstate__.
VERSION_LABEL = '_stateversion_'
class SerializeableZiplineObject(object):
"""
This class implements the basic set and get state methods used for