MAINT: Refactored serialization code.

This commit is contained in:
Delaney Granizo-Mackenzie
2015-03-02 14:10:35 -05:00
parent ccbc52d803
commit ca210f0778
4 changed files with 34 additions and 37 deletions
+9 -2
View File
@@ -251,6 +251,10 @@ class Blotter(SerializeableZiplineObject):
yield txn, order
def __getinitargs__(self):
# Ensure that init is called on deserialization
return ()
def __getstate__(self):
state_to_save = ['new_orders', 'orders', '_status']
@@ -259,8 +263,7 @@ class Blotter(SerializeableZiplineObject):
if k in self.__dict__}
# Have to handle defaultdicts specially
state_dict['open_orders'] = \
self._defaultdict_list_get_state(self.open_orders)
state_dict['open_orders'] = dict(self.open_orders)
STATE_VERSION = 1
state_dict[VERSION_LABEL] = STATE_VERSION
@@ -275,6 +278,10 @@ class Blotter(SerializeableZiplineObject):
if version < OLDEST_SUPPORTED_STATE:
raise BaseException("Blotter saved is state too old.")
open_orders = defaultdict(list)
open_orders.update(state.pop('open_orders'))
self.open_orders = open_orders
super(Blotter, self).__setstate__(state)
+21 -4
View File
@@ -590,13 +590,13 @@ class PerformancePeriod(SerializeableZiplineObject):
# msgpack will unpack it as a dict, causing KeyError
# nastiness.
state_dict['processed_transactions'] = \
self._defaultdict_list_get_state(self.processed_transactions)
dict(self.processed_transactions)
state_dict['orders_by_modified'] = \
self._defaultdict_ordered_get_state(self.orders_by_modified)
dict(self.orders_by_modified)
state_dict['positions'] = \
self._positiondict_get_state(self.positions)
dict(self.positions)
state_dict['_positions_store'] = \
self._positions_get_state(self._positions_store)
dict(self._positions_store)
STATE_VERSION = 1
state_dict[VERSION_LABEL] = STATE_VERSION
@@ -611,6 +611,23 @@ class PerformancePeriod(SerializeableZiplineObject):
if version < OLDEST_SUPPORTED_STATE:
raise BaseException("PerformancePeriod saved state is too old.")
processed_transactions = defaultdict(list)
processed_transactions.update(state.pop('processed_transactions'))
orders_by_modified = defaultdict(OrderedDict)
orders_by_modified.update(state.pop('orders_by_modified'))
positions = positiondict()
positions.update(state.pop('positions'))
_positions_store = zp.Positions()
_positions_store.update(state.pop('_positions_store'))
self.processed_transactions = processed_transactions
self.orders_by_modified = orders_by_modified
self.positions = positions
self._positions_store = _positions_store
super(PerformancePeriod, self).__setstate__(state)
self.initialize_position_calc_arrays()
+4 -3
View File
@@ -517,6 +517,7 @@ class PerformanceTracker(SerializeableZiplineObject):
# We have to restore the references to the objects,
# as the perf periods have been reconstructed as different objects
# with the same values.
self.perf_periods[0] = self.minute_performance
self.perf_periods[1] = self.cumulative_performance
self.perf_periods[2] = self.todays_performance
self.perf_periods[0] = self.cumulative_performance
self.perf_periods[1] = self.todays_performance
if self.sim_params.emission_rate == 'minute':
self.perf_periods[2] = self.minute_performance
-28
View File
@@ -40,31 +40,3 @@ class SerializeableZiplineObject(object):
Many objects require only this code.
"""
self.__dict__.update(state)
# =====================================================
# These are helper methods for some problem data types.
# =====================================================
def _defaultdict_list_get_state(self, d):
return {
'__original.type__': 'encoded.defaultdict_list',
'as_dict': dict(d)
}
def _defaultdict_ordered_get_state(self, d):
return {
'__original.type__': 'encoded.defaultdict_ordered',
'as_dict': dict(d)
}
def _positiondict_get_state(self, d):
return {
'__original.type__': 'encoded.positiondict',
'as_dict': dict(d)
}
def _positions_get_state(self, d):
return {
'__original.type__': 'encoded.Positions',
'as_dict': dict(d)
}