mirror of
https://github.com/wassname/catalyst.git
synced 2026-06-30 10:50:00 +08:00
ENH: Added versioning logic to objects.
In order to be able to load from saved state generated by old code, we need to have a notion of the version of the saved state.
This commit is contained in:
@@ -35,7 +35,10 @@ from zipline.finance.commission import PerShare
|
||||
log = Logger('Blotter')
|
||||
|
||||
from zipline.utils.protocol_utils import Enum
|
||||
from zipline.utils.serialization_utils import SerializeableZiplineObject
|
||||
from zipline.utils.serialization_utils import (
|
||||
SerializeableZiplineObject,
|
||||
VERSION_LABEL
|
||||
)
|
||||
|
||||
ORDER_STATUS = Enum(
|
||||
'OPEN',
|
||||
@@ -259,8 +262,21 @@ class Blotter(SerializeableZiplineObject):
|
||||
state_dict['open_orders'] = \
|
||||
self._defaultdict_list_get_state(self.open_orders)
|
||||
|
||||
STATE_VERSION = 1
|
||||
state_dict[VERSION_LABEL] = STATE_VERSION
|
||||
|
||||
return state_dict
|
||||
|
||||
def __setstate__(self, state):
|
||||
|
||||
OLDEST_SUPPORTED_STATE = 1
|
||||
version = state.pop(VERSION_LABEL)
|
||||
|
||||
if version < OLDEST_SUPPORTED_STATE:
|
||||
raise BaseException("Blotter saved is state too old.")
|
||||
|
||||
super(Blotter, self).__setstate__(state)
|
||||
|
||||
|
||||
class Order(SerializeableZiplineObject):
|
||||
def __init__(self, dt, sid, amount, stop=None, limit=None, filled=0,
|
||||
@@ -401,8 +417,22 @@ class Order(SerializeableZiplineObject):
|
||||
return text_type(repr(self))
|
||||
|
||||
def __getstate__(self):
|
||||
|
||||
state_dict = super(Order, self).__getstate__()
|
||||
|
||||
state_dict['_status'] = self._status
|
||||
|
||||
STATE_VERSION = 1
|
||||
state_dict[VERSION_LABEL] = STATE_VERSION
|
||||
|
||||
return state_dict
|
||||
|
||||
def __setstate__(self, state):
|
||||
|
||||
OLDEST_SUPPORTED_STATE = 1
|
||||
version = state.pop(VERSION_LABEL)
|
||||
|
||||
if version < OLDEST_SUPPORTED_STATE:
|
||||
raise BaseException("Order saved state is too old.")
|
||||
|
||||
super(Order, self).__setstate__(state)
|
||||
|
||||
@@ -13,7 +13,10 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from zipline.utils.serialization_utils import SerializeableZiplineObject
|
||||
from zipline.utils.serialization_utils import (
|
||||
SerializeableZiplineObject,
|
||||
VERSION_LABEL
|
||||
)
|
||||
|
||||
|
||||
class PerShare(SerializeableZiplineObject):
|
||||
@@ -52,6 +55,25 @@ class PerShare(SerializeableZiplineObject):
|
||||
commission = max(commission, self.min_trade_cost)
|
||||
return abs(commission / transaction.amount), commission
|
||||
|
||||
def __getstate__(self):
|
||||
|
||||
state_dict = super(PerShare, self).__getstate__()
|
||||
|
||||
STATE_VERSION = 1
|
||||
state_dict[VERSION_LABEL] = STATE_VERSION
|
||||
|
||||
return state_dict
|
||||
|
||||
def __setstate__(self, state):
|
||||
|
||||
OLDEST_SUPPORTED_STATE = 1
|
||||
version = state.pop(VERSION_LABEL)
|
||||
|
||||
if version < OLDEST_SUPPORTED_STATE:
|
||||
raise BaseException("PerShare saved state is too old.")
|
||||
|
||||
super(PerShare, self).__setstate__(state)
|
||||
|
||||
|
||||
class PerTrade(SerializeableZiplineObject):
|
||||
"""
|
||||
@@ -79,6 +101,25 @@ class PerTrade(SerializeableZiplineObject):
|
||||
|
||||
return abs(self.cost / transaction.amount), self.cost
|
||||
|
||||
def __getstate__(self):
|
||||
|
||||
state_dict = super(PerTrade, self).__getstate__()
|
||||
|
||||
STATE_VERSION = 1
|
||||
state_dict[VERSION_LABEL] = STATE_VERSION
|
||||
|
||||
return state_dict
|
||||
|
||||
def __setstate__(self, state):
|
||||
|
||||
OLDEST_SUPPORTED_STATE = 1
|
||||
version = state.pop(VERSION_LABEL)
|
||||
|
||||
if version < OLDEST_SUPPORTED_STATE:
|
||||
raise BaseException("PerTrade saved state is too old.")
|
||||
|
||||
super(PerTrade, self).__setstate__(state)
|
||||
|
||||
|
||||
class PerDollar(SerializeableZiplineObject):
|
||||
"""
|
||||
@@ -105,3 +146,22 @@ class PerDollar(SerializeableZiplineObject):
|
||||
"""
|
||||
cost_per_share = transaction.price * self.cost
|
||||
return cost_per_share, abs(transaction.amount) * cost_per_share
|
||||
|
||||
def __getstate__(self):
|
||||
|
||||
state_dict = super(PerDollar, self).__getstate__()
|
||||
|
||||
STATE_VERSION = 1
|
||||
state_dict[VERSION_LABEL] = STATE_VERSION
|
||||
|
||||
return state_dict
|
||||
|
||||
def __setstate__(self, state):
|
||||
|
||||
OLDEST_SUPPORTED_STATE = 1
|
||||
version = state.pop(VERSION_LABEL)
|
||||
|
||||
if version < OLDEST_SUPPORTED_STATE:
|
||||
raise BaseException("PerDollar saved state is too old.")
|
||||
|
||||
super(PerDollar, self).__setstate__(state)
|
||||
|
||||
@@ -92,7 +92,10 @@ from six import iteritems, itervalues
|
||||
import zipline.protocol as zp
|
||||
from . position import positiondict
|
||||
|
||||
from zipline.utils.serialization_utils import SerializeableZiplineObject
|
||||
from zipline.utils.serialization_utils import (
|
||||
SerializeableZiplineObject,
|
||||
VERSION_LABEL
|
||||
)
|
||||
|
||||
log = logbook.Logger('Performance')
|
||||
TRADE_TYPE = zp.DATASOURCE_TYPE.TRADE
|
||||
@@ -595,9 +598,19 @@ class PerformancePeriod(SerializeableZiplineObject):
|
||||
state_dict['_positions_store'] = \
|
||||
self._positions_get_state(self._positions_store)
|
||||
|
||||
STATE_VERSION = 1
|
||||
state_dict[VERSION_LABEL] = STATE_VERSION
|
||||
|
||||
return state_dict
|
||||
|
||||
def __setstate__(self, state):
|
||||
|
||||
OLDEST_SUPPORTED_STATE = 1
|
||||
version = state.pop(VERSION_LABEL)
|
||||
|
||||
if version < OLDEST_SUPPORTED_STATE:
|
||||
raise BaseException("PerformancePeriod saved state is too old.")
|
||||
|
||||
super(PerformancePeriod, self).__setstate__(state)
|
||||
|
||||
self.initialize_position_calc_arrays()
|
||||
|
||||
@@ -41,7 +41,10 @@ from math import (
|
||||
import logbook
|
||||
import zipline.protocol as zp
|
||||
|
||||
from zipline.utils.serialization_utils import SerializeableZiplineObject
|
||||
from zipline.utils.serialization_utils import (
|
||||
SerializeableZiplineObject,
|
||||
VERSION_LABEL
|
||||
)
|
||||
|
||||
log = logbook.Logger('Performance')
|
||||
|
||||
@@ -210,7 +213,23 @@ last_sale_price: {last_sale_price}"
|
||||
}
|
||||
|
||||
def __getstate__(self):
|
||||
return self.__dict__
|
||||
|
||||
state_dict = self.__dict__
|
||||
|
||||
STATE_VERSION = 1
|
||||
state_dict[VERSION_LABEL] = STATE_VERSION
|
||||
|
||||
return state_dict
|
||||
|
||||
def __setstate__(self, state):
|
||||
|
||||
OLDEST_SUPPORTED_STATE = 1
|
||||
version = state.pop(VERSION_LABEL)
|
||||
|
||||
if version < OLDEST_SUPPORTED_STATE:
|
||||
raise BaseException("Position saved state is too old.")
|
||||
|
||||
super(Position, self).__setstate__(state)
|
||||
|
||||
|
||||
class positiondict(dict):
|
||||
|
||||
@@ -71,7 +71,10 @@ from zipline.finance import trading
|
||||
from . period import PerformancePeriod
|
||||
|
||||
from zipline.finance.trading import with_environment
|
||||
from zipline.utils.serialization_utils import SerializeableZiplineObject
|
||||
from zipline.utils.serialization_utils import (
|
||||
SerializeableZiplineObject,
|
||||
VERSION_LABEL
|
||||
)
|
||||
|
||||
log = logbook.Logger('Performance')
|
||||
|
||||
@@ -493,9 +496,19 @@ class PerformanceTracker(SerializeableZiplineObject):
|
||||
|
||||
state_dict['_dividend_count'] = self._dividend_count
|
||||
|
||||
STATE_VERSION = 1
|
||||
state_dict[VERSION_LABEL] = STATE_VERSION
|
||||
|
||||
return state_dict
|
||||
|
||||
def __setstate__(self, state):
|
||||
|
||||
OLDEST_SUPPORTED_STATE = 1
|
||||
version = state.pop(VERSION_LABEL)
|
||||
|
||||
if version < OLDEST_SUPPORTED_STATE:
|
||||
raise BaseException("PerformanceTracker saved state is too old.")
|
||||
|
||||
super(PerformanceTracker, self).__setstate__(state)
|
||||
|
||||
# Handle the dividend frame specially
|
||||
|
||||
@@ -35,7 +35,10 @@ from . risk import (
|
||||
sortino_ratio,
|
||||
)
|
||||
|
||||
from zipline.utils.serialization_utils import SerializeableZiplineObject
|
||||
from zipline.utils.serialization_utils import (
|
||||
SerializeableZiplineObject,
|
||||
VERSION_LABEL
|
||||
)
|
||||
|
||||
log = logbook.Logger('Risk Cumulative')
|
||||
|
||||
@@ -462,9 +465,20 @@ algorithm_returns ({algo_count}) in range {start} : {end} on {dt}"
|
||||
{k: v for k, v in self.__dict__.iteritems() if
|
||||
(not k.startswith('_') and not k == 'treasury_curves')}
|
||||
|
||||
STATE_VERSION = 1
|
||||
state_dict[VERSION_LABEL] = STATE_VERSION
|
||||
|
||||
return state_dict
|
||||
|
||||
def __setstate__(self, state):
|
||||
|
||||
OLDEST_SUPPORTED_STATE = 1
|
||||
version = state.pop(VERSION_LABEL)
|
||||
|
||||
if version < OLDEST_SUPPORTED_STATE:
|
||||
raise BaseException("RiskMetricsCumulative \
|
||||
saved state is too old.")
|
||||
|
||||
super(RiskMetricsCumulative, self).__setstate__(state)
|
||||
|
||||
# This are big and we don't need to serialize them
|
||||
|
||||
@@ -36,7 +36,10 @@ from . risk import (
|
||||
sortino_ratio,
|
||||
)
|
||||
|
||||
from zipline.utils.serialization_utils import SerializeableZiplineObject
|
||||
from zipline.utils.serialization_utils import (
|
||||
SerializeableZiplineObject,
|
||||
VERSION_LABEL
|
||||
)
|
||||
|
||||
log = logbook.Logger('Risk Period')
|
||||
|
||||
@@ -312,9 +315,20 @@ class RiskMetricsPeriod(SerializeableZiplineObject):
|
||||
{k: v for k, v in self.__dict__.iteritems() if
|
||||
(not k.startswith('_') and not k == 'treasury_curves')}
|
||||
|
||||
STATE_VERSION = 1
|
||||
state_dict[VERSION_LABEL] = STATE_VERSION
|
||||
|
||||
return state_dict
|
||||
|
||||
def __setstate__(self, state):
|
||||
|
||||
OLDEST_SUPPORTED_STATE = 1
|
||||
version = state.pop(VERSION_LABEL)
|
||||
|
||||
if version < OLDEST_SUPPORTED_STATE:
|
||||
raise BaseException("RiskMetricsPeriod saved state \
|
||||
is too old.")
|
||||
|
||||
super(RiskMetricsPeriod, self).__setstate__(state)
|
||||
|
||||
self.treasury_curves = trading.environment.treasury_curves
|
||||
|
||||
@@ -61,7 +61,10 @@ from dateutil.relativedelta import relativedelta
|
||||
|
||||
from . period import RiskMetricsPeriod
|
||||
|
||||
from zipline.utils.serialization_utils import SerializeableZiplineObject
|
||||
from zipline.utils.serialization_utils import (
|
||||
SerializeableZiplineObject,
|
||||
VERSION_LABEL
|
||||
)
|
||||
|
||||
log = logbook.Logger('Risk Report')
|
||||
|
||||
@@ -147,4 +150,17 @@ class RiskReport(SerializeableZiplineObject):
|
||||
if '_dividend_count' in dir(self):
|
||||
state_dict['_dividend_count'] = self._dividend_count
|
||||
|
||||
STATE_VERSION = 1
|
||||
state_dict[VERSION_LABEL] = STATE_VERSION
|
||||
|
||||
return state_dict
|
||||
|
||||
def __setstate__(self, state):
|
||||
|
||||
OLDEST_SUPPORTED_STATE = 1
|
||||
version = state.pop(VERSION_LABEL)
|
||||
|
||||
if version < OLDEST_SUPPORTED_STATE:
|
||||
raise BaseException("RiskReport saved state is too old.")
|
||||
|
||||
super(RiskReport, self).__setstate__(state)
|
||||
|
||||
@@ -24,7 +24,10 @@ from functools import partial
|
||||
from six import with_metaclass
|
||||
|
||||
from zipline.protocol import DATASOURCE_TYPE
|
||||
from zipline.utils.serialization_utils import SerializeableZiplineObject
|
||||
from zipline.utils.serialization_utils import (
|
||||
SerializeableZiplineObject,
|
||||
VERSION_LABEL
|
||||
)
|
||||
|
||||
SELL = 1 << 0
|
||||
BUY = 1 << 1
|
||||
@@ -130,7 +133,23 @@ class Transaction(SerializeableZiplineObject):
|
||||
return py
|
||||
|
||||
def __getstate__(self):
|
||||
return self.__dict__
|
||||
|
||||
state_dict = self.__dict__
|
||||
|
||||
STATE_VERSION = 1
|
||||
state_dict[VERSION_LABEL] = STATE_VERSION
|
||||
|
||||
return state_dict
|
||||
|
||||
def __setstate__(self, state):
|
||||
|
||||
OLDEST_SUPPORTED_STATE = 1
|
||||
version = state.pop(VERSION_LABEL)
|
||||
|
||||
if version < OLDEST_SUPPORTED_STATE:
|
||||
raise BaseException("Transaction saved state is too old.")
|
||||
|
||||
super(Transaction, self).__setstate__(state)
|
||||
|
||||
|
||||
def create_transaction(event, order, price, amount):
|
||||
@@ -251,9 +270,22 @@ class VolumeShareSlippage(SlippageModel, SerializeableZiplineObject):
|
||||
)
|
||||
|
||||
def __getstate__(self):
|
||||
return self.__dict__
|
||||
|
||||
state_dict = self.__dict__
|
||||
|
||||
STATE_VERSION = 1
|
||||
state_dict[VERSION_LABEL] = STATE_VERSION
|
||||
|
||||
return state_dict
|
||||
|
||||
def __setstate__(self, state):
|
||||
|
||||
OLDEST_SUPPORTED_STATE = 1
|
||||
version = state.pop(VERSION_LABEL)
|
||||
|
||||
if version < OLDEST_SUPPORTED_STATE:
|
||||
raise BaseException("VolumeShareSlippage saved state is too old.")
|
||||
|
||||
self.__dict__.update(state)
|
||||
|
||||
|
||||
@@ -276,7 +308,20 @@ class FixedSlippage(SlippageModel, SerializeableZiplineObject):
|
||||
)
|
||||
|
||||
def __getstate__(self):
|
||||
return self.__dict__
|
||||
|
||||
state_dict = self.__dict__
|
||||
|
||||
STATE_VERSION = 1
|
||||
state_dict[VERSION_LABEL] = STATE_VERSION
|
||||
|
||||
return state_dict
|
||||
|
||||
def __setstate__(self, state):
|
||||
|
||||
OLDEST_SUPPORTED_STATE = 1
|
||||
version = state.pop(VERSION_LABEL)
|
||||
|
||||
if version < OLDEST_SUPPORTED_STATE:
|
||||
raise BaseException("FixedSlippage saved state is too old.")
|
||||
|
||||
self.__dict__.update(state)
|
||||
|
||||
+55
-4
@@ -22,7 +22,10 @@ from . utils.math_utils import nanstd, nanmean, nansum
|
||||
|
||||
from zipline.finance.trading import with_environment
|
||||
from zipline.utils.algo_instance import get_algo_instance
|
||||
from zipline.utils.serialization_utils import SerializeableZiplineObject
|
||||
from zipline.utils.serialization_utils import (
|
||||
SerializeableZiplineObject,
|
||||
VERSION_LABEL
|
||||
)
|
||||
|
||||
# Datasource type should completely determine the other fields of a
|
||||
# message with its type.
|
||||
@@ -140,7 +143,23 @@ class Portfolio(SerializeableZiplineObject):
|
||||
return "Portfolio({0})".format(self.__dict__)
|
||||
|
||||
def __getstate__(self):
|
||||
return self.__dict__
|
||||
|
||||
state_dict = self.__dict__
|
||||
|
||||
STATE_VERSION = 1
|
||||
state_dict[VERSION_LABEL] = STATE_VERSION
|
||||
|
||||
return state_dict
|
||||
|
||||
def __setstate__(self, state):
|
||||
|
||||
OLDEST_SUPPORTED_STATE = 1
|
||||
version = state.pop(VERSION_LABEL)
|
||||
|
||||
if version < OLDEST_SUPPORTED_STATE:
|
||||
raise BaseException("Portfolio saved state is too old.")
|
||||
|
||||
super(Portfolio, self).__setstate__(state)
|
||||
|
||||
|
||||
class Account(SerializeableZiplineObject):
|
||||
@@ -176,7 +195,23 @@ class Account(SerializeableZiplineObject):
|
||||
return "Account({0})".format(self.__dict__)
|
||||
|
||||
def __getstate__(self):
|
||||
return self.__dict__
|
||||
|
||||
state_dict = self.__dict__
|
||||
|
||||
STATE_VERSION = 1
|
||||
state_dict[VERSION_LABEL] = STATE_VERSION
|
||||
|
||||
return state_dict
|
||||
|
||||
def __setstate__(self, state):
|
||||
|
||||
OLDEST_SUPPORTED_STATE = 1
|
||||
version = state.pop(VERSION_LABEL)
|
||||
|
||||
if version < OLDEST_SUPPORTED_STATE:
|
||||
raise BaseException("Account saved state is too old.")
|
||||
|
||||
super(Account, self).__setstate__(state)
|
||||
|
||||
|
||||
class Position(SerializeableZiplineObject):
|
||||
@@ -194,7 +229,23 @@ class Position(SerializeableZiplineObject):
|
||||
return "Position({0})".format(self.__dict__)
|
||||
|
||||
def __getstate__(self):
|
||||
return self.__dict__
|
||||
|
||||
state_dict = self.__dict__
|
||||
|
||||
STATE_VERSION = 1
|
||||
state_dict[VERSION_LABEL] = STATE_VERSION
|
||||
|
||||
return state_dict
|
||||
|
||||
def __setstate__(self, state):
|
||||
|
||||
OLDEST_SUPPORTED_STATE = 1
|
||||
version = state.pop(VERSION_LABEL)
|
||||
|
||||
if version < OLDEST_SUPPORTED_STATE:
|
||||
raise BaseException("Protocol Position saved state is too old.")
|
||||
|
||||
super(Position, self).__setstate__(state)
|
||||
|
||||
|
||||
class Positions(dict):
|
||||
|
||||
@@ -1,3 +1,8 @@
|
||||
# Label for the serialization version field in the state returned by
|
||||
# __getstate__.
|
||||
VERSION_LABEL = '_stateversion_'
|
||||
|
||||
|
||||
class SerializeableZiplineObject(object):
|
||||
"""
|
||||
This class implements the basic set and get state methods used for
|
||||
|
||||
Reference in New Issue
Block a user