From 8bcaed9044d67480194a7a6e19d806946a7fbbbe Mon Sep 17 00:00:00 2001 From: scottsanderson Date: Thu, 2 Aug 2012 14:44:11 -0400 Subject: [PATCH] moving to class-style generators --- zipline/gens/examples.py | 67 ++++++++-------- zipline/gens/tradegens.py | 132 ++++++++++++++++++-------------- zipline/gens/tradesimulation.py | 10 +-- 3 files changed, 107 insertions(+), 102 deletions(-) diff --git a/zipline/gens/examples.py b/zipline/gens/examples.py index 967d0808..50b955e4 100644 --- a/zipline/gens/examples.py +++ b/zipline/gens/examples.py @@ -1,4 +1,7 @@ import pytz +from time import sleep + +from pprint import pprint as pp from datetime import datetime, timedelta from zipline.utils.factory import create_trading_environment @@ -18,53 +21,43 @@ if __name__ == "__main__": #Set up source a. One minute between events. args_a = tuple() kwargs_a = { - 'sids' : [1,2,3,4], + 'sids' : [2], 'start' : datetime(2012,1,3,15, tzinfo = pytz.utc), - 'delta' : timedelta(hours = 1), + 'delta' : timedelta(minutes = 1), 'filter' : filter } - bundle_a = SourceBundle(SpecificEquityTrades, args_a, kwargs_a) + source_a = SpecificEquityTrades(*args_a, **kwargs_a) #Set up source b. Two minutes between events. args_b = tuple() kwargs_b = { - 'sids' : [1,2,3,4], - 'start' : datetime(2012,1,3,15, tzinfo = pytz.utc), - 'delta' : timedelta(hours = 1), + 'sids' : [2], + 'start' : datetime(2012,1,3,14, tzinfo = pytz.utc), + 'delta' : timedelta(minutes = 1), 'filter' : filter } - bundle_b = SourceBundle(SpecificEquityTrades, args_b, kwargs_b) - + source_b = SpecificEquityTrades(*args_a, **kwargs_a) + #Set up source c. Three minutes between events. - args_c = tuple() - kwargs_c = { - 'sids' : [1,2,3,4], - 'start' : datetime(2012,1,3,15, tzinfo = pytz.utc), - 'delta' : timedelta(hours = 1), - 'filter' : filter - } - bundle_c = SourceBundle(SpecificEquityTrades, args_c, kwargs_c) + + # sort_out = date_sorted_sources(source_a, source_b) + +# passthrough = TransformBundle(Passthrough, (), {}) +# mavg_price = TransformBundle(MovingAverage, (timedelta(minutes = 20), ['price']), {}) +# tnfm_bundles = (passthrough, mavg_price) + +# merge_out = merged_transforms(sort_out, tnfm_bundles) + +# # for message in merge_out: +# # print message + +# algo = TestAlgorithm(2, 100, 100) +# environment = create_trading_environment(year = 2012) +# style = zp.SIMULATION_STYLE.FIXED_SLIPPAGE + +# client_out = tsc(merge_out, algo, environment, style) +# for message in client_out: + # pp(message) - source_bundles = (bundle_a, bundle_b, bundle_c) - # Pipe our sources into sort. - sort_out = date_sorted_sources(source_bundles) - - passthrough = TransformBundle(Passthrough, (), {}) - mavg_price = TransformBundle(MovingAverage, (timedelta(minutes = 20), ['price']), {}) - tnfm_bundles = (passthrough, mavg_price) - - merge_out = merged_transforms(sort_out, tnfm_bundles) - - # for message in merge_out: -# print message - - algo = TestAlgorithm(2, 100, 100) - environment = create_trading_environment(year = 2012) - style = zp.SIMULATION_STYLE.PARTIAL_VOLUME - - client_out = tsc(merge_out, algo, environment, style) - import nose.tools; nose.tools.set_trace() - for message in client_out: - pass diff --git a/zipline/gens/tradegens.py b/zipline/gens/tradegens.py index 7420e1b4..8552b530 100644 --- a/zipline/gens/tradegens.py +++ b/zipline/gens/tradegens.py @@ -49,7 +49,7 @@ def fuzzy_dates(count = 500): for date in date_gen(count = count): yield date + timedelta(seconds = random.randint(-10, 10)) -def SpecificEquityTrades(*args, **config): +class SpecificEquityTrades(object): """ Yields all events in event_list that match the given sid_filter. If no event_list is specified, generates an internal stream of events @@ -57,71 +57,85 @@ def SpecificEquityTrades(*args, **config): Configuration options: - count: integer representing number of trades - sids : list of values representing simulated internal sids - start: start date - delta: timedelta between internal events - - + count : integer representing number of trades + sids : list of values representing simulated internal sids + start : start date + delta : timedelta between internal events + filter : filter to remove the sids """ - # We shouldn't get any positional arguments. - assert args == () - # Unpack config dictionary with default values. - count = config.get('count', 500) - sids = config.get('sids', [1, 2]) - start = config.get('start', datetime(2012, 6, 6, 0)) - delta = config.get('delta', timedelta(minutes = 1)) + def __init__(self, *args, **kwargs): + # We shouldn't get any positional arguments. + assert len(args) == 0 + + # Unpack config dictionary with default values. + self.count = kwargs.get('count', 500) + self.sids = kwargs.get('sids', [1, 2]) + self.start = kwargs.get('start', datetime(2012, 6, 6, 0)) + self.delta = kwargs.get('delta', timedelta(minutes = 1)) + + # Default to None for event_list and filter. + self.event_list = kwargs.get('event_list') + self.filter = kwargs.get('filter') + + # Hash_value for downstream sorting. + self.arg_string = hash_args(*args, **kwargs) - # Default to None for event_list and filter. - event_list = config.get('event_list') - filter = config.get('filter') + def get_hash(self): + return self.__class__.__name__ + "-" + self.arg_string + + def __iter__(self): + + if self.event_list: + unfiltered = (event for event in event_list) - arg_string = hash_args(*args, **config) - namestring = "SpecificEquityTrades" + arg_string - # If we have an event_list, ignore the other arguments and use the list. - # TODO: still append our namestring? - if event_list: - unfiltered = (event for event in event_list) + # Set up iterators for each expected field. + else: + dates = date_gen(count=self.count, + start=self.start, + delta=self.delta + ) + prices = mock_prices(self.count) + volumes = mock_volumes(self.count) + sids = cycle(self.sids) + + # Combine the iterators into a single iterator of arguments + arg_gen = izip(sids, prices, volumes, dates) - # Set up iterators for each expected field. - else: - dates = date_gen(count = count, start = start, delta = delta) - prices = mock_prices(count) - volumes = mock_volumes(count) + # Convert argument packages into events. + unfiltered = (create_trade(*args, source_id = self.get_hash()) + for args in arg_gen) + + # If we specified a sid filter, filter out elements that don't + # match the filter. + if self.filter: + filtered = ifilter(lambda event: event.sid in self.filter, unfiltered) + + # Otherwise just use all events. + else: + filtered = unfiltered + + # Return the filtered event stream. + return filtered + + +# !!!!!!! Deprecated for now !!!!!!!!! + +def RandomEquityTrades(object): + + def __init__(self): + # We shouldn't get any positional args. + assert args == () + + self.count = config.get('count', 500) + self.sids = config.get('sids', [1,2]) + self.filter = config.get('filter') + + dates = fuzzy_dates(count) + prices = mock_prices(count, rand = True) + volumes = mock_volumes(count, rand = True) sids = cycle(sids) - # Combine the iterators into a single iterator of arguments - arg_gen = izip(sids, prices, volumes, dates) - - # Convert argument packages into events. - unfiltered = (create_trade(*args, source_id = namestring) - for args in arg_gen) - - # If we specified a sid filter, filter out elements that don't match the filter. - if filter: - filtered = ifilter(lambda event: event.sid in filter, unfiltered) - - # Otherwise just use all events. - else: - filtered = unfiltered - - # Return the filtered event stream. - return filtered - -def RandomEquityTrades(*args, **config): - # We shouldn't get any positional args. - assert args == () - - count = config.get('count', 500) - sids = config.get('sids', [1,2]) - filter = config.get('filter') - - dates = fuzzy_dates(count) - prices = mock_prices(count, rand = True) - volumes = mock_volumes(count, rand = True) - sids = cycle(sids) - arg_gen = izip(sids, prices, volumes, dates) unfiltered = (create_trade(*args) for args in arg_gen) diff --git a/zipline/gens/tradesimulation.py b/zipline/gens/tradesimulation.py index 74048982..04567fac 100644 --- a/zipline/gens/tradesimulation.py +++ b/zipline/gens/tradesimulation.py @@ -84,7 +84,7 @@ def trade_simulation_client(stream_in, algo, environment, sim_style): # Yields perf messages whenever it encounters them. perf_messages = algo_simulator(with_portfolio_and_perf_msg, sids, algo, open_orders) - for message in perf_messages: + for message in perf_messages: yield message @@ -109,7 +109,7 @@ def algo_simulator(stream_in, sids, algo, order_book): sid=event.sid ) log.debug(log) - return + return order_book[sid].append(order) @@ -123,18 +123,17 @@ def algo_simulator(stream_in, sids, algo, order_book): # events. algo.initialize() - this_snapshot_dt = None - universe = ndict() - for sid in sids: universe[sid] = ndict() universe.portfolio = None + this_snapshot_dt = None for event in stream_in: # Yield any perf messages received to be relayed back to the browser. if event.perf_message: yield event.perf_message + del event['perf_message'] # This should only happen for the first event we run. if simulation_dt == None: @@ -151,7 +150,6 @@ def algo_simulator(stream_in, sids, algo, order_book): # If we are constructing a snapshot and we hit a new dt, call # handle_data and record how long it takes. else: - start_tic = datetime.now() algo.handle_data(universe) stop_tic = datetime.now()