diff --git a/zipline/gens/tradegens.py b/zipline/gens/tradegens.py index 2e8f6bea..801d98f4 100644 --- a/zipline/gens/tradegens.py +++ b/zipline/gens/tradegens.py @@ -5,18 +5,22 @@ and zipline development import random import pytz -from itertools import chain, cycle, ifilter, izip +from itertools import chain, cycle, ifilter, izip, repeat from datetime import datetime, timedelta from zipline.gens.utils import hash_args, create_trade def date_gen(start = datetime(2006, 6, 6, 12, tzinfo=pytz.utc), delta = timedelta(minutes = 1), - count = 100): + count = 100, + repeats = None): """ Utility to generate a stream of dates. """ - return (start + (i * delta) for i in xrange(count)) + if repeats: + return (start + (i * delta) for i in xrange(count) for n in xrange(repeats)) + else: + return (start + (i * delta) for i in xrange(count)) def mock_prices(count, rand = False): """ @@ -74,6 +78,7 @@ class SpecificEquityTrades(object): self.sids = kwargs.get('sids', [1, 2]) self.start = kwargs.get('start', datetime(2008, 6, 6, 15, tzinfo = pytz.utc)) self.delta = kwargs.get('delta', timedelta(minutes = 1)) + self.concurrent = kwargs.get('concurrent', False) # Default to None for event_list and filter. self.event_list = kwargs.get('event_list') @@ -103,20 +108,35 @@ class SpecificEquityTrades(object): # Set up iterators for each expected field. else: - dates = date_gen(count=self.count, - start=self.start, - delta=self.delta - ) - prices = mock_prices(self.count) - volumes = mock_volumes(self.count) - sids = cycle(self.sids) + if self.concurrent: + # in this context the count is the number of + # trades per sid, not the total. + dates = date_gen( + count=self.count, + start=self.start, + delta=self.delta, + repeats=len(self.sids), + ) - # Combine the iterators into a single iterator of arguments - arg_gen = izip(sids, prices, volumes, dates) - # Convert argument packages into events. - unfiltered = (create_trade(*args, source_id = self.get_hash()) - for args in arg_gen) + else: + + dates = date_gen( + count=self.count, + start=self.start, + delta=self.delta + ) + + prices = mock_prices(self.count) + volumes = mock_volumes(self.count) + + sids = cycle(self.sids) + # Combine the iterators into a single iterator of arguments + arg_gen = izip(sids, prices, volumes, dates) + + # Convert argument packages into events. + unfiltered = (create_trade(*args, source_id = self.get_hash()) + for args in arg_gen) # If we specified a sid filter, filter out elements that don't # match the filter. diff --git a/zipline/lines.py b/zipline/lines.py index 9bc8b3ac..1c3a558f 100644 --- a/zipline/lines.py +++ b/zipline/lines.py @@ -269,7 +269,12 @@ class SimulatedTrading(object): of StatefulTransform objects. """ assert isinstance(config, dict) - sid = config['sid'] + sid_list = config.get('sid_list') + if not sid_list: + sid = config.get('sid') + sid_list = [sid] + + concurrent_trades = config.get('concurrent_trades', False) #-------------------- # Trading Environment @@ -307,17 +312,17 @@ class SimulatedTrading(object): #------------------- # Trade Source #------------------- - sids = [sid] - #------------------- if config.has_key('trade_source'): trade_source = config['trade_source'] else: trade_source = factory.create_daily_trade_source( - sids, + sid_list, trade_count, - trading_environment + trading_environment, + concurrent=concurrent_trades ) + #------------------- # Transforms #------------------- diff --git a/zipline/utils/factory.py b/zipline/utils/factory.py index e3d92443..cf2168fb 100644 --- a/zipline/utils/factory.py +++ b/zipline/utils/factory.py @@ -174,7 +174,7 @@ def create_random_trade_source(sid, trade_count, trading_environment): return source -def create_daily_trade_source(sids, trade_count, trading_environment): +def create_daily_trade_source(sids, trade_count, trading_environment, concurrent=False): """ creates trade_count trades for each sid in sids list. @@ -189,11 +189,12 @@ def create_daily_trade_source(sids, trade_count, trading_environment): sids, trade_count, timedelta(days=1), - trading_environment + trading_environment, + concurrent=concurrent ) -def create_minutely_trade_source(sids, trade_count, trading_environment): +def create_minutely_trade_source(sids, trade_count, trading_environment, concurrent=False): """ creates trade_count trades for each sid in sids list. @@ -208,10 +209,11 @@ def create_minutely_trade_source(sids, trade_count, trading_environment): sids, trade_count, timedelta(minutes=1), - trading_environment + trading_environment, + concurrent=concurrent ) -def create_trade_source(sids, trade_count, trade_time_increment, trading_environment): +def create_trade_source(sids, trade_count, trade_time_increment, trading_environment, concurrent=False): args = tuple() kwargs = { @@ -219,7 +221,8 @@ def create_trade_source(sids, trade_count, trade_time_increment, trading_environ 'sids' : sids, 'start' : trading_environment.first_open, 'delta' : trade_time_increment, - 'filter' : sids + 'filter' : sids, + 'concurrent' : concurrent } source = SpecificEquityTrades(*args, **kwargs)