mirror of
https://github.com/wassname/catalyst.git
synced 2026-07-01 09:42:50 +08:00
DEP: Removes sids field from SimulationParameters
This commit is contained in:
@@ -138,7 +138,6 @@ class TestMiscellaneousAPI(TestCase):
|
||||
sids = [1, 2]
|
||||
self.sim_params = factory.create_simulation_parameters(
|
||||
num_days=2,
|
||||
sids=sids,
|
||||
data_frequency='minute',
|
||||
emission_rate='minute',
|
||||
)
|
||||
|
||||
@@ -120,13 +120,11 @@ class TransformTestCase(TestCase):
|
||||
|
||||
minute_sim_ps = factory.create_simulation_parameters(
|
||||
num_days=3,
|
||||
sids=cls.sids,
|
||||
data_frequency='minute',
|
||||
emission_rate='minute',
|
||||
)
|
||||
daily_sim_ps = factory.create_simulation_parameters(
|
||||
num_days=30,
|
||||
sids=cls.sids,
|
||||
data_frequency='daily',
|
||||
emission_rate='daily',
|
||||
)
|
||||
|
||||
+14
-15
@@ -477,23 +477,22 @@ class TradingAlgorithm(object):
|
||||
self.sim_params.period_start = source.start
|
||||
if hasattr(source, 'end'):
|
||||
self.sim_params.period_end = source.end
|
||||
# The sids field of the source is the canonical reference for
|
||||
# sids in this run
|
||||
all_sids = [sid for s in self.sources for sid in s.sids]
|
||||
self.sim_params.sids = set(all_sids)
|
||||
# Check that all sids from the source are accounted for in
|
||||
# the AssetFinder
|
||||
for sid in self.sim_params.sids:
|
||||
try:
|
||||
self.asset_finder.retrieve_asset(sid)
|
||||
except SidNotFound:
|
||||
warnings.warn("No Asset found for sid '%s'. Make sure "
|
||||
"that the correct identifiers and asset "
|
||||
"metadata are passed to __init__()." % sid)
|
||||
# Changing period_start and period_close might require updating
|
||||
# of first_open and last_close.
|
||||
self.sim_params._update_internal()
|
||||
|
||||
# The sids field of the source is the reference for the universe at
|
||||
# the start of the run
|
||||
self._current_universe = set()
|
||||
for source in self.sources:
|
||||
for sid in source.sids:
|
||||
self._current_universe.add(sid)
|
||||
# Check that all sids from the source are accounted for in
|
||||
# the AssetFinder. This retrieve call will raise an exception if the
|
||||
# sid is not found.
|
||||
for sid in self._current_universe:
|
||||
self.asset_finder.retrieve_asset(sid)
|
||||
|
||||
# force a reset of the performance tracker, in case
|
||||
# this is a repeat run of the algorithm.
|
||||
self.perf_tracker = None
|
||||
@@ -505,7 +504,7 @@ class TradingAlgorithm(object):
|
||||
if self.history_specs:
|
||||
self.history_container = self.history_container_class(
|
||||
self.history_specs,
|
||||
self.sim_params.sids,
|
||||
self.current_universe(),
|
||||
self.sim_params.first_open,
|
||||
self.sim_params.data_frequency,
|
||||
)
|
||||
@@ -1183,7 +1182,7 @@ class TradingAlgorithm(object):
|
||||
self.register_trading_control(LongOnly())
|
||||
|
||||
def current_universe(self):
|
||||
return self.sim_params.sids
|
||||
return self._current_universe
|
||||
|
||||
@classmethod
|
||||
def all_api_methods(cls):
|
||||
|
||||
@@ -440,8 +440,7 @@ class SimulationParameters(object):
|
||||
def __init__(self, period_start, period_end,
|
||||
capital_base=10e3,
|
||||
emission_rate='daily',
|
||||
data_frequency='daily',
|
||||
sids=None):
|
||||
data_frequency='daily'):
|
||||
|
||||
self.period_start = period_start
|
||||
self.period_end = period_end
|
||||
@@ -449,7 +448,6 @@ class SimulationParameters(object):
|
||||
|
||||
self.emission_rate = emission_rate
|
||||
self.data_frequency = data_frequency
|
||||
self.sids = sids
|
||||
|
||||
# copied to algorithm's environment for runtime access
|
||||
self.arena = 'backtest'
|
||||
|
||||
@@ -48,8 +48,9 @@ class DataFrameSource(DataSource):
|
||||
self.end = kwargs.get('end', self.data.index[-1])
|
||||
|
||||
# Remap sids based on the trading environment
|
||||
self.identifiers = kwargs.get('sids', self.data.columns)
|
||||
env.update_asset_finder(identifiers=self.identifiers)
|
||||
env.update_asset_finder(
|
||||
identifiers=kwargs.get('sids', self.data.columns)
|
||||
)
|
||||
self.data.columns, _ = env.asset_finder.lookup_generic(
|
||||
self.data.columns, datetime.datetime.now()
|
||||
)
|
||||
@@ -78,22 +79,21 @@ class DataFrameSource(DataSource):
|
||||
def raw_data_gen(self):
|
||||
for dt, series in self.data.iterrows():
|
||||
for sid, price in series.iteritems():
|
||||
if sid in self.sids:
|
||||
# Skip SIDs that can not be forward filled
|
||||
if np.isnan(price) and \
|
||||
sid not in self.started_sids:
|
||||
continue
|
||||
self.started_sids.add(sid)
|
||||
# Skip SIDs that can not be forward filled
|
||||
if np.isnan(price) and \
|
||||
sid not in self.started_sids:
|
||||
continue
|
||||
self.started_sids.add(sid)
|
||||
|
||||
event = {
|
||||
'dt': dt,
|
||||
'sid': sid,
|
||||
'price': price,
|
||||
# Just chose something large
|
||||
# if no volume available.
|
||||
'volume': 1e9,
|
||||
}
|
||||
yield event
|
||||
event = {
|
||||
'dt': dt,
|
||||
'sid': sid,
|
||||
'price': price,
|
||||
# Just chose something large
|
||||
# if no volume available.
|
||||
'volume': 1e9,
|
||||
}
|
||||
yield event
|
||||
|
||||
@property
|
||||
def raw_data(self):
|
||||
@@ -126,8 +126,9 @@ class DataPanelSource(DataSource):
|
||||
self.end = kwargs.get('end', self.data.major_axis[-1])
|
||||
|
||||
# Remap sids based on the trading environment
|
||||
self.identifiers = kwargs.get('sids', self.data.items)
|
||||
env.update_asset_finder(identifiers=self.identifiers)
|
||||
env.update_asset_finder(
|
||||
identifiers=kwargs.get('sids', self.data.items)
|
||||
)
|
||||
self.data.items, _ = env.asset_finder.lookup_generic(
|
||||
self.data.items, datetime.datetime.now()
|
||||
)
|
||||
@@ -165,21 +166,20 @@ class DataPanelSource(DataSource):
|
||||
for dt in self.data.major_axis:
|
||||
df = self.data.major_xs(dt)
|
||||
for sid, series in df.iteritems():
|
||||
if sid in self.sids:
|
||||
# Skip SIDs that can not be forward filled
|
||||
if np.isnan(series['price']) and \
|
||||
sid not in self.started_sids:
|
||||
continue
|
||||
self.started_sids.add(sid)
|
||||
# Skip SIDs that can not be forward filled
|
||||
if np.isnan(series['price']) and \
|
||||
sid not in self.started_sids:
|
||||
continue
|
||||
self.started_sids.add(sid)
|
||||
|
||||
event = {
|
||||
'dt': dt,
|
||||
'sid': sid,
|
||||
}
|
||||
for field_name, value in series.iteritems():
|
||||
event[field_name] = value
|
||||
event = {
|
||||
'dt': dt,
|
||||
'sid': sid,
|
||||
}
|
||||
for field_name, value in series.iteritems():
|
||||
event[field_name] = value
|
||||
|
||||
yield event
|
||||
yield event
|
||||
|
||||
@property
|
||||
def raw_data(self):
|
||||
|
||||
@@ -43,7 +43,7 @@ __all__ = ['load_from_yahoo', 'load_bars_from_yahoo']
|
||||
def create_simulation_parameters(year=2006, start=None, end=None,
|
||||
capital_base=float("1.0e5"),
|
||||
num_days=None, load=None,
|
||||
sids=None, data_frequency='daily',
|
||||
data_frequency='daily',
|
||||
emission_rate='daily'):
|
||||
"""Construct a complete environment with reasonable defaults"""
|
||||
if start is None:
|
||||
@@ -60,7 +60,6 @@ def create_simulation_parameters(year=2006, start=None, end=None,
|
||||
period_start=start,
|
||||
period_end=end,
|
||||
capital_base=capital_base,
|
||||
sids=sids,
|
||||
data_frequency=data_frequency,
|
||||
emission_rate=emission_rate,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user