moving to class-style generators

This commit is contained in:
scottsanderson
2012-08-02 14:44:11 -04:00
parent bf3d9cef02
commit 8bcaed9044
3 changed files with 107 additions and 102 deletions
+30 -37
View File
@@ -1,4 +1,7 @@
import pytz
from time import sleep
from pprint import pprint as pp
from datetime import datetime, timedelta
from zipline.utils.factory import create_trading_environment
@@ -18,53 +21,43 @@ if __name__ == "__main__":
#Set up source a. One minute between events.
args_a = tuple()
kwargs_a = {
'sids' : [1,2,3,4],
'sids' : [2],
'start' : datetime(2012,1,3,15, tzinfo = pytz.utc),
'delta' : timedelta(hours = 1),
'delta' : timedelta(minutes = 1),
'filter' : filter
}
bundle_a = SourceBundle(SpecificEquityTrades, args_a, kwargs_a)
source_a = SpecificEquityTrades(*args_a, **kwargs_a)
#Set up source b. Two minutes between events.
args_b = tuple()
kwargs_b = {
'sids' : [1,2,3,4],
'start' : datetime(2012,1,3,15, tzinfo = pytz.utc),
'delta' : timedelta(hours = 1),
'sids' : [2],
'start' : datetime(2012,1,3,14, tzinfo = pytz.utc),
'delta' : timedelta(minutes = 1),
'filter' : filter
}
bundle_b = SourceBundle(SpecificEquityTrades, args_b, kwargs_b)
source_b = SpecificEquityTrades(*args_a, **kwargs_a)
#Set up source c. Three minutes between events.
args_c = tuple()
kwargs_c = {
'sids' : [1,2,3,4],
'start' : datetime(2012,1,3,15, tzinfo = pytz.utc),
'delta' : timedelta(hours = 1),
'filter' : filter
}
bundle_c = SourceBundle(SpecificEquityTrades, args_c, kwargs_c)
# sort_out = date_sorted_sources(source_a, source_b)
# passthrough = TransformBundle(Passthrough, (), {})
# mavg_price = TransformBundle(MovingAverage, (timedelta(minutes = 20), ['price']), {})
# tnfm_bundles = (passthrough, mavg_price)
# merge_out = merged_transforms(sort_out, tnfm_bundles)
# # for message in merge_out:
# # print message
# algo = TestAlgorithm(2, 100, 100)
# environment = create_trading_environment(year = 2012)
# style = zp.SIMULATION_STYLE.FIXED_SLIPPAGE
# client_out = tsc(merge_out, algo, environment, style)
# for message in client_out:
# pp(message)
source_bundles = (bundle_a, bundle_b, bundle_c)
# Pipe our sources into sort.
sort_out = date_sorted_sources(source_bundles)
passthrough = TransformBundle(Passthrough, (), {})
mavg_price = TransformBundle(MovingAverage, (timedelta(minutes = 20), ['price']), {})
tnfm_bundles = (passthrough, mavg_price)
merge_out = merged_transforms(sort_out, tnfm_bundles)
# for message in merge_out:
# print message
algo = TestAlgorithm(2, 100, 100)
environment = create_trading_environment(year = 2012)
style = zp.SIMULATION_STYLE.PARTIAL_VOLUME
client_out = tsc(merge_out, algo, environment, style)
import nose.tools; nose.tools.set_trace()
for message in client_out:
pass
+73 -59
View File
@@ -49,7 +49,7 @@ def fuzzy_dates(count = 500):
for date in date_gen(count = count):
yield date + timedelta(seconds = random.randint(-10, 10))
def SpecificEquityTrades(*args, **config):
class SpecificEquityTrades(object):
"""
Yields all events in event_list that match the given sid_filter.
If no event_list is specified, generates an internal stream of events
@@ -57,71 +57,85 @@ def SpecificEquityTrades(*args, **config):
Configuration options:
count: integer representing number of trades
sids : list of values representing simulated internal sids
start: start date
delta: timedelta between internal events
count : integer representing number of trades
sids : list of values representing simulated internal sids
start : start date
delta : timedelta between internal events
filter : filter to remove the sids
"""
# We shouldn't get any positional arguments.
assert args == ()
# Unpack config dictionary with default values.
count = config.get('count', 500)
sids = config.get('sids', [1, 2])
start = config.get('start', datetime(2012, 6, 6, 0))
delta = config.get('delta', timedelta(minutes = 1))
def __init__(self, *args, **kwargs):
# We shouldn't get any positional arguments.
assert len(args) == 0
# Unpack config dictionary with default values.
self.count = kwargs.get('count', 500)
self.sids = kwargs.get('sids', [1, 2])
self.start = kwargs.get('start', datetime(2012, 6, 6, 0))
self.delta = kwargs.get('delta', timedelta(minutes = 1))
# Default to None for event_list and filter.
self.event_list = kwargs.get('event_list')
self.filter = kwargs.get('filter')
# Hash_value for downstream sorting.
self.arg_string = hash_args(*args, **kwargs)
# Default to None for event_list and filter.
event_list = config.get('event_list')
filter = config.get('filter')
def get_hash(self):
return self.__class__.__name__ + "-" + self.arg_string
def __iter__(self):
if self.event_list:
unfiltered = (event for event in event_list)
arg_string = hash_args(*args, **config)
namestring = "SpecificEquityTrades" + arg_string
# If we have an event_list, ignore the other arguments and use the list.
# TODO: still append our namestring?
if event_list:
unfiltered = (event for event in event_list)
# Set up iterators for each expected field.
else:
dates = date_gen(count=self.count,
start=self.start,
delta=self.delta
)
prices = mock_prices(self.count)
volumes = mock_volumes(self.count)
sids = cycle(self.sids)
# Combine the iterators into a single iterator of arguments
arg_gen = izip(sids, prices, volumes, dates)
# Set up iterators for each expected field.
else:
dates = date_gen(count = count, start = start, delta = delta)
prices = mock_prices(count)
volumes = mock_volumes(count)
# Convert argument packages into events.
unfiltered = (create_trade(*args, source_id = self.get_hash())
for args in arg_gen)
# If we specified a sid filter, filter out elements that don't
# match the filter.
if self.filter:
filtered = ifilter(lambda event: event.sid in self.filter, unfiltered)
# Otherwise just use all events.
else:
filtered = unfiltered
# Return the filtered event stream.
return filtered
# !!!!!!! Deprecated for now !!!!!!!!!
def RandomEquityTrades(object):
def __init__(self):
# We shouldn't get any positional args.
assert args == ()
self.count = config.get('count', 500)
self.sids = config.get('sids', [1,2])
self.filter = config.get('filter')
dates = fuzzy_dates(count)
prices = mock_prices(count, rand = True)
volumes = mock_volumes(count, rand = True)
sids = cycle(sids)
# Combine the iterators into a single iterator of arguments
arg_gen = izip(sids, prices, volumes, dates)
# Convert argument packages into events.
unfiltered = (create_trade(*args, source_id = namestring)
for args in arg_gen)
# If we specified a sid filter, filter out elements that don't match the filter.
if filter:
filtered = ifilter(lambda event: event.sid in filter, unfiltered)
# Otherwise just use all events.
else:
filtered = unfiltered
# Return the filtered event stream.
return filtered
def RandomEquityTrades(*args, **config):
# We shouldn't get any positional args.
assert args == ()
count = config.get('count', 500)
sids = config.get('sids', [1,2])
filter = config.get('filter')
dates = fuzzy_dates(count)
prices = mock_prices(count, rand = True)
volumes = mock_volumes(count, rand = True)
sids = cycle(sids)
arg_gen = izip(sids, prices, volumes, dates)
unfiltered = (create_trade(*args) for args in arg_gen)
+4 -6
View File
@@ -84,7 +84,7 @@ def trade_simulation_client(stream_in, algo, environment, sim_style):
# Yields perf messages whenever it encounters them.
perf_messages = algo_simulator(with_portfolio_and_perf_msg, sids, algo, open_orders)
for message in perf_messages:
for message in perf_messages:
yield message
@@ -109,7 +109,7 @@ def algo_simulator(stream_in, sids, algo, order_book):
sid=event.sid
)
log.debug(log)
return
return
order_book[sid].append(order)
@@ -123,18 +123,17 @@ def algo_simulator(stream_in, sids, algo, order_book):
# events.
algo.initialize()
this_snapshot_dt = None
universe = ndict()
for sid in sids:
universe[sid] = ndict()
universe.portfolio = None
this_snapshot_dt = None
for event in stream_in:
# Yield any perf messages received to be relayed back to the browser.
if event.perf_message:
yield event.perf_message
del event['perf_message']
# This should only happen for the first event we run.
if simulation_dt == None:
@@ -151,7 +150,6 @@ def algo_simulator(stream_in, sids, algo, order_book):
# If we are constructing a snapshot and we hit a new dt, call
# handle_data and record how long it takes.
else:
start_tic = datetime.now()
algo.handle_data(universe)
stop_tic = datetime.now()