mirror of
https://github.com/wassname/catalyst.git
synced 2026-06-30 13:10:49 +08:00
MAINT: Refactored serialization code.
This commit is contained in:
@@ -251,6 +251,10 @@ class Blotter(SerializeableZiplineObject):
|
||||
|
||||
yield txn, order
|
||||
|
||||
def __getinitargs__(self):
|
||||
# Ensure that init is called on deserialization
|
||||
return ()
|
||||
|
||||
def __getstate__(self):
|
||||
|
||||
state_to_save = ['new_orders', 'orders', '_status']
|
||||
@@ -259,8 +263,7 @@ class Blotter(SerializeableZiplineObject):
|
||||
if k in self.__dict__}
|
||||
|
||||
# Have to handle defaultdicts specially
|
||||
state_dict['open_orders'] = \
|
||||
self._defaultdict_list_get_state(self.open_orders)
|
||||
state_dict['open_orders'] = dict(self.open_orders)
|
||||
|
||||
STATE_VERSION = 1
|
||||
state_dict[VERSION_LABEL] = STATE_VERSION
|
||||
@@ -275,6 +278,10 @@ class Blotter(SerializeableZiplineObject):
|
||||
if version < OLDEST_SUPPORTED_STATE:
|
||||
raise BaseException("Blotter saved is state too old.")
|
||||
|
||||
open_orders = defaultdict(list)
|
||||
open_orders.update(state.pop('open_orders'))
|
||||
self.open_orders = open_orders
|
||||
|
||||
super(Blotter, self).__setstate__(state)
|
||||
|
||||
|
||||
|
||||
@@ -590,13 +590,13 @@ class PerformancePeriod(SerializeableZiplineObject):
|
||||
# msgpack will unpack it as a dict, causing KeyError
|
||||
# nastiness.
|
||||
state_dict['processed_transactions'] = \
|
||||
self._defaultdict_list_get_state(self.processed_transactions)
|
||||
dict(self.processed_transactions)
|
||||
state_dict['orders_by_modified'] = \
|
||||
self._defaultdict_ordered_get_state(self.orders_by_modified)
|
||||
dict(self.orders_by_modified)
|
||||
state_dict['positions'] = \
|
||||
self._positiondict_get_state(self.positions)
|
||||
dict(self.positions)
|
||||
state_dict['_positions_store'] = \
|
||||
self._positions_get_state(self._positions_store)
|
||||
dict(self._positions_store)
|
||||
|
||||
STATE_VERSION = 1
|
||||
state_dict[VERSION_LABEL] = STATE_VERSION
|
||||
@@ -611,6 +611,23 @@ class PerformancePeriod(SerializeableZiplineObject):
|
||||
if version < OLDEST_SUPPORTED_STATE:
|
||||
raise BaseException("PerformancePeriod saved state is too old.")
|
||||
|
||||
processed_transactions = defaultdict(list)
|
||||
processed_transactions.update(state.pop('processed_transactions'))
|
||||
|
||||
orders_by_modified = defaultdict(OrderedDict)
|
||||
orders_by_modified.update(state.pop('orders_by_modified'))
|
||||
|
||||
positions = positiondict()
|
||||
positions.update(state.pop('positions'))
|
||||
|
||||
_positions_store = zp.Positions()
|
||||
_positions_store.update(state.pop('_positions_store'))
|
||||
|
||||
self.processed_transactions = processed_transactions
|
||||
self.orders_by_modified = orders_by_modified
|
||||
self.positions = positions
|
||||
self._positions_store = _positions_store
|
||||
|
||||
super(PerformancePeriod, self).__setstate__(state)
|
||||
|
||||
self.initialize_position_calc_arrays()
|
||||
|
||||
@@ -517,6 +517,7 @@ class PerformanceTracker(SerializeableZiplineObject):
|
||||
# We have to restore the references to the objects,
|
||||
# as the perf periods have been reconstructed as different objects
|
||||
# with the same values.
|
||||
self.perf_periods[0] = self.minute_performance
|
||||
self.perf_periods[1] = self.cumulative_performance
|
||||
self.perf_periods[2] = self.todays_performance
|
||||
self.perf_periods[0] = self.cumulative_performance
|
||||
self.perf_periods[1] = self.todays_performance
|
||||
if self.sim_params.emission_rate == 'minute':
|
||||
self.perf_periods[2] = self.minute_performance
|
||||
|
||||
@@ -40,31 +40,3 @@ class SerializeableZiplineObject(object):
|
||||
Many objects require only this code.
|
||||
"""
|
||||
self.__dict__.update(state)
|
||||
|
||||
# =====================================================
|
||||
# These are helper methods for some problem data types.
|
||||
# =====================================================
|
||||
|
||||
def _defaultdict_list_get_state(self, d):
|
||||
return {
|
||||
'__original.type__': 'encoded.defaultdict_list',
|
||||
'as_dict': dict(d)
|
||||
}
|
||||
|
||||
def _defaultdict_ordered_get_state(self, d):
|
||||
return {
|
||||
'__original.type__': 'encoded.defaultdict_ordered',
|
||||
'as_dict': dict(d)
|
||||
}
|
||||
|
||||
def _positiondict_get_state(self, d):
|
||||
return {
|
||||
'__original.type__': 'encoded.positiondict',
|
||||
'as_dict': dict(d)
|
||||
}
|
||||
|
||||
def _positions_get_state(self, d):
|
||||
return {
|
||||
'__original.type__': 'encoded.Positions',
|
||||
'as_dict': dict(d)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user