mirror of
https://github.com/wassname/catalyst.git
synced 2026-07-03 18:21:25 +08:00
moving to class-style generators
This commit is contained in:
+30
-37
@@ -1,4 +1,7 @@
|
||||
import pytz
|
||||
from time import sleep
|
||||
|
||||
from pprint import pprint as pp
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from zipline.utils.factory import create_trading_environment
|
||||
@@ -18,53 +21,43 @@ if __name__ == "__main__":
|
||||
#Set up source a. One minute between events.
|
||||
args_a = tuple()
|
||||
kwargs_a = {
|
||||
'sids' : [1,2,3,4],
|
||||
'sids' : [2],
|
||||
'start' : datetime(2012,1,3,15, tzinfo = pytz.utc),
|
||||
'delta' : timedelta(hours = 1),
|
||||
'delta' : timedelta(minutes = 1),
|
||||
'filter' : filter
|
||||
}
|
||||
bundle_a = SourceBundle(SpecificEquityTrades, args_a, kwargs_a)
|
||||
source_a = SpecificEquityTrades(*args_a, **kwargs_a)
|
||||
|
||||
#Set up source b. Two minutes between events.
|
||||
args_b = tuple()
|
||||
kwargs_b = {
|
||||
'sids' : [1,2,3,4],
|
||||
'start' : datetime(2012,1,3,15, tzinfo = pytz.utc),
|
||||
'delta' : timedelta(hours = 1),
|
||||
'sids' : [2],
|
||||
'start' : datetime(2012,1,3,14, tzinfo = pytz.utc),
|
||||
'delta' : timedelta(minutes = 1),
|
||||
'filter' : filter
|
||||
}
|
||||
bundle_b = SourceBundle(SpecificEquityTrades, args_b, kwargs_b)
|
||||
|
||||
source_b = SpecificEquityTrades(*args_a, **kwargs_a)
|
||||
|
||||
#Set up source c. Three minutes between events.
|
||||
args_c = tuple()
|
||||
kwargs_c = {
|
||||
'sids' : [1,2,3,4],
|
||||
'start' : datetime(2012,1,3,15, tzinfo = pytz.utc),
|
||||
'delta' : timedelta(hours = 1),
|
||||
'filter' : filter
|
||||
}
|
||||
bundle_c = SourceBundle(SpecificEquityTrades, args_c, kwargs_c)
|
||||
|
||||
# sort_out = date_sorted_sources(source_a, source_b)
|
||||
|
||||
# passthrough = TransformBundle(Passthrough, (), {})
|
||||
# mavg_price = TransformBundle(MovingAverage, (timedelta(minutes = 20), ['price']), {})
|
||||
# tnfm_bundles = (passthrough, mavg_price)
|
||||
|
||||
# merge_out = merged_transforms(sort_out, tnfm_bundles)
|
||||
|
||||
# # for message in merge_out:
|
||||
# # print message
|
||||
|
||||
# algo = TestAlgorithm(2, 100, 100)
|
||||
# environment = create_trading_environment(year = 2012)
|
||||
# style = zp.SIMULATION_STYLE.FIXED_SLIPPAGE
|
||||
|
||||
# client_out = tsc(merge_out, algo, environment, style)
|
||||
# for message in client_out:
|
||||
# pp(message)
|
||||
|
||||
source_bundles = (bundle_a, bundle_b, bundle_c)
|
||||
# Pipe our sources into sort.
|
||||
sort_out = date_sorted_sources(source_bundles)
|
||||
|
||||
passthrough = TransformBundle(Passthrough, (), {})
|
||||
mavg_price = TransformBundle(MovingAverage, (timedelta(minutes = 20), ['price']), {})
|
||||
tnfm_bundles = (passthrough, mavg_price)
|
||||
|
||||
merge_out = merged_transforms(sort_out, tnfm_bundles)
|
||||
|
||||
# for message in merge_out:
|
||||
# print message
|
||||
|
||||
algo = TestAlgorithm(2, 100, 100)
|
||||
environment = create_trading_environment(year = 2012)
|
||||
style = zp.SIMULATION_STYLE.PARTIAL_VOLUME
|
||||
|
||||
client_out = tsc(merge_out, algo, environment, style)
|
||||
import nose.tools; nose.tools.set_trace()
|
||||
for message in client_out:
|
||||
pass
|
||||
|
||||
|
||||
|
||||
+73
-59
@@ -49,7 +49,7 @@ def fuzzy_dates(count = 500):
|
||||
for date in date_gen(count = count):
|
||||
yield date + timedelta(seconds = random.randint(-10, 10))
|
||||
|
||||
def SpecificEquityTrades(*args, **config):
|
||||
class SpecificEquityTrades(object):
|
||||
"""
|
||||
Yields all events in event_list that match the given sid_filter.
|
||||
If no event_list is specified, generates an internal stream of events
|
||||
@@ -57,71 +57,85 @@ def SpecificEquityTrades(*args, **config):
|
||||
|
||||
Configuration options:
|
||||
|
||||
count: integer representing number of trades
|
||||
sids : list of values representing simulated internal sids
|
||||
start: start date
|
||||
delta: timedelta between internal events
|
||||
|
||||
|
||||
count : integer representing number of trades
|
||||
sids : list of values representing simulated internal sids
|
||||
start : start date
|
||||
delta : timedelta between internal events
|
||||
filter : filter to remove the sids
|
||||
"""
|
||||
# We shouldn't get any positional arguments.
|
||||
assert args == ()
|
||||
|
||||
# Unpack config dictionary with default values.
|
||||
count = config.get('count', 500)
|
||||
sids = config.get('sids', [1, 2])
|
||||
start = config.get('start', datetime(2012, 6, 6, 0))
|
||||
delta = config.get('delta', timedelta(minutes = 1))
|
||||
def __init__(self, *args, **kwargs):
|
||||
# We shouldn't get any positional arguments.
|
||||
assert len(args) == 0
|
||||
|
||||
# Unpack config dictionary with default values.
|
||||
self.count = kwargs.get('count', 500)
|
||||
self.sids = kwargs.get('sids', [1, 2])
|
||||
self.start = kwargs.get('start', datetime(2012, 6, 6, 0))
|
||||
self.delta = kwargs.get('delta', timedelta(minutes = 1))
|
||||
|
||||
# Default to None for event_list and filter.
|
||||
self.event_list = kwargs.get('event_list')
|
||||
self.filter = kwargs.get('filter')
|
||||
|
||||
# Hash_value for downstream sorting.
|
||||
self.arg_string = hash_args(*args, **kwargs)
|
||||
|
||||
# Default to None for event_list and filter.
|
||||
event_list = config.get('event_list')
|
||||
filter = config.get('filter')
|
||||
def get_hash(self):
|
||||
return self.__class__.__name__ + "-" + self.arg_string
|
||||
|
||||
def __iter__(self):
|
||||
|
||||
if self.event_list:
|
||||
unfiltered = (event for event in event_list)
|
||||
|
||||
arg_string = hash_args(*args, **config)
|
||||
namestring = "SpecificEquityTrades" + arg_string
|
||||
# If we have an event_list, ignore the other arguments and use the list.
|
||||
# TODO: still append our namestring?
|
||||
if event_list:
|
||||
unfiltered = (event for event in event_list)
|
||||
# Set up iterators for each expected field.
|
||||
else:
|
||||
dates = date_gen(count=self.count,
|
||||
start=self.start,
|
||||
delta=self.delta
|
||||
)
|
||||
prices = mock_prices(self.count)
|
||||
volumes = mock_volumes(self.count)
|
||||
sids = cycle(self.sids)
|
||||
|
||||
# Combine the iterators into a single iterator of arguments
|
||||
arg_gen = izip(sids, prices, volumes, dates)
|
||||
|
||||
# Set up iterators for each expected field.
|
||||
else:
|
||||
dates = date_gen(count = count, start = start, delta = delta)
|
||||
prices = mock_prices(count)
|
||||
volumes = mock_volumes(count)
|
||||
# Convert argument packages into events.
|
||||
unfiltered = (create_trade(*args, source_id = self.get_hash())
|
||||
for args in arg_gen)
|
||||
|
||||
# If we specified a sid filter, filter out elements that don't
|
||||
# match the filter.
|
||||
if self.filter:
|
||||
filtered = ifilter(lambda event: event.sid in self.filter, unfiltered)
|
||||
|
||||
# Otherwise just use all events.
|
||||
else:
|
||||
filtered = unfiltered
|
||||
|
||||
# Return the filtered event stream.
|
||||
return filtered
|
||||
|
||||
|
||||
# !!!!!!! Deprecated for now !!!!!!!!!
|
||||
|
||||
def RandomEquityTrades(object):
|
||||
|
||||
def __init__(self):
|
||||
# We shouldn't get any positional args.
|
||||
assert args == ()
|
||||
|
||||
self.count = config.get('count', 500)
|
||||
self.sids = config.get('sids', [1,2])
|
||||
self.filter = config.get('filter')
|
||||
|
||||
dates = fuzzy_dates(count)
|
||||
prices = mock_prices(count, rand = True)
|
||||
volumes = mock_volumes(count, rand = True)
|
||||
sids = cycle(sids)
|
||||
|
||||
# Combine the iterators into a single iterator of arguments
|
||||
arg_gen = izip(sids, prices, volumes, dates)
|
||||
|
||||
# Convert argument packages into events.
|
||||
unfiltered = (create_trade(*args, source_id = namestring)
|
||||
for args in arg_gen)
|
||||
|
||||
# If we specified a sid filter, filter out elements that don't match the filter.
|
||||
if filter:
|
||||
filtered = ifilter(lambda event: event.sid in filter, unfiltered)
|
||||
|
||||
# Otherwise just use all events.
|
||||
else:
|
||||
filtered = unfiltered
|
||||
|
||||
# Return the filtered event stream.
|
||||
return filtered
|
||||
|
||||
def RandomEquityTrades(*args, **config):
|
||||
# We shouldn't get any positional args.
|
||||
assert args == ()
|
||||
|
||||
count = config.get('count', 500)
|
||||
sids = config.get('sids', [1,2])
|
||||
filter = config.get('filter')
|
||||
|
||||
dates = fuzzy_dates(count)
|
||||
prices = mock_prices(count, rand = True)
|
||||
volumes = mock_volumes(count, rand = True)
|
||||
sids = cycle(sids)
|
||||
|
||||
arg_gen = izip(sids, prices, volumes, dates)
|
||||
|
||||
unfiltered = (create_trade(*args) for args in arg_gen)
|
||||
|
||||
@@ -84,7 +84,7 @@ def trade_simulation_client(stream_in, algo, environment, sim_style):
|
||||
# Yields perf messages whenever it encounters them.
|
||||
perf_messages = algo_simulator(with_portfolio_and_perf_msg, sids, algo, open_orders)
|
||||
|
||||
for message in perf_messages:
|
||||
for message in perf_messages:
|
||||
yield message
|
||||
|
||||
|
||||
@@ -109,7 +109,7 @@ def algo_simulator(stream_in, sids, algo, order_book):
|
||||
sid=event.sid
|
||||
)
|
||||
log.debug(log)
|
||||
return
|
||||
return
|
||||
|
||||
order_book[sid].append(order)
|
||||
|
||||
@@ -123,18 +123,17 @@ def algo_simulator(stream_in, sids, algo, order_book):
|
||||
# events.
|
||||
algo.initialize()
|
||||
|
||||
this_snapshot_dt = None
|
||||
|
||||
universe = ndict()
|
||||
|
||||
for sid in sids:
|
||||
universe[sid] = ndict()
|
||||
universe.portfolio = None
|
||||
this_snapshot_dt = None
|
||||
|
||||
for event in stream_in:
|
||||
# Yield any perf messages received to be relayed back to the browser.
|
||||
if event.perf_message:
|
||||
yield event.perf_message
|
||||
del event['perf_message']
|
||||
|
||||
# This should only happen for the first event we run.
|
||||
if simulation_dt == None:
|
||||
@@ -151,7 +150,6 @@ def algo_simulator(stream_in, sids, algo, order_book):
|
||||
# If we are constructing a snapshot and we hit a new dt, call
|
||||
# handle_data and record how long it takes.
|
||||
else:
|
||||
|
||||
start_tic = datetime.now()
|
||||
algo.handle_data(universe)
|
||||
stop_tic = datetime.now()
|
||||
|
||||
Reference in New Issue
Block a user