diff --git a/zipline/finance/trading.py b/zipline/finance/trading.py index 5cc33f30..35eec0d1 100644 --- a/zipline/finance/trading.py +++ b/zipline/finance/trading.py @@ -95,7 +95,8 @@ class TradeSimulationClient(qmsg.Component): def run_algorithm(self): frame = self.get_frame() - self.algorithm.handle_frame(frame) + if len(frame) > 0: + self.algorithm.handle_frame(frame) def connect_order(self): return self.connect_push_socket(self.addresses['order_address']) @@ -129,6 +130,7 @@ class TradeSimulationClient(qmsg.Component): def get_frame(self): for event in self.event_queue: self.event_frame[event['sid']] = event + self.event_queue = [] return self.event_frame class OrderDataSource(qmsg.DataSource): diff --git a/zipline/lines.py b/zipline/lines.py index 16502fe3..dbef3930 100644 --- a/zipline/lines.py +++ b/zipline/lines.py @@ -261,7 +261,7 @@ class SimulatedTrading(object): """ assert isinstance(source, zmsg.DataSource) self.check_started() - source.set_filter('SID', self.algorithm.get_sid_filter) + source.set_filter('SID', self.algorithm.get_sid_filter()) self.sim.register_components([source]) self.sources[source.get_id] = source diff --git a/zipline/sources.py b/zipline/sources.py index 28d356b9..e1c2a122 100644 --- a/zipline/sources.py +++ b/zipline/sources.py @@ -13,12 +13,19 @@ class TradeDataSource(zm.DataSource): def send(self, event): """ + Sends the event iff it matches the internal SID filter. :param dict event: is a trade event with data as per :py:func: `zipline.protocol.TRADE_FRAME` :rtype: None """ + event.source_id = self.get_id - message = zp.DATASOURCE_FRAME(event) + if event.sid in self.filter['SID']: + + message = zp.DATASOURCE_FRAME(event) + else: + message = zp.DATASOURCE_FRAME(None) + self.data_socket.send(message) @@ -56,9 +63,8 @@ class RandomEquityTrades(TradeDataSource): "dt" : self.trade_start + (self.minute * self.incr), }) - self.send(event) - self.incr += 1 + class SpecificEquityTrades(TradeDataSource): diff --git a/zipline/test/algorithms.py b/zipline/test/algorithms.py index 8d8366bb..37444919 100644 --- a/zipline/test/algorithms.py +++ b/zipline/test/algorithms.py @@ -50,11 +50,13 @@ class TestAlgorithm(): self.incr = 0 self.done = False self.order = None + self.frame_count = 0 def set_order(self, order_callable): self.order = order_callable def handle_frame(self, frame): + self.frame_count += 1 for dt, s in frame.iteritems(): data = {} data.update(s) diff --git a/zipline/test/factory.py b/zipline/test/factory.py index 31f6b3af..18dfacf4 100644 --- a/zipline/test/factory.py +++ b/zipline/test/factory.py @@ -63,20 +63,27 @@ def create_trade(sid, price, amount, datetime): }) return row +def get_next_trading_dt(current, interval, trading_calendar): + next = current + while True: + next = next + interval + if trading_calendar.is_trading_day(next): + break + else: + next = next + timedelta(days=1) + + return next + def create_trade_history(sid, prices, amounts, start_time, interval, trading_calendar): i = 0 trades = [] current = start_time.replace(tzinfo = pytz.utc) for price, amount in zip(prices, amounts): - - if(trading_calendar.is_trading_day(current)): - trade = create_trade(sid, price, amount, current) - trades.append(trade) - - current = current + interval - else: - current = current + timedelta(days=1) + + current = get_next_trading_dt(current, interval, trading_calendar) + trade = create_trade(sid, price, amount, current) + trades.append(trade) return trades @@ -94,14 +101,10 @@ def create_txn_history(sid, priceList, amtList, startTime, interval, trading_cal current = startTime for price, amount in zip(priceList, amtList): + current = get_next_trading_dt(current, interval, trading_calendar) - if trading_calendar.is_trading_day(current): - txns.append(create_txn(sid, price, amount, current)) - current = current + interval - - else: - current = current + timedelta(days=1) - + txns.append(create_txn(sid, price, amount, current)) + current = current + interval return txns @@ -111,11 +114,16 @@ def create_returns(daycount, start, trading_calendar): current = start.replace(tzinfo=pytz.utc) one_day = timedelta(days = 1) while i < daycount: + current = get_next_trading_dt( + current, + one_day, + trading_calendar + ) i += 1 r = risk.DailyReturn(current, random.random()) test_range.append(r) - current = current + one_day - return [ x for x in test_range if(trading_calendar.is_trading_day(x.date)) ] + + return test_range def create_returns_from_range(start, end, trading_calendar): @@ -123,14 +131,11 @@ def create_returns_from_range(start, end, trading_calendar): end = end.replace(tzinfo=pytz.utc) one_day = timedelta(days = 1) test_range = [] - i = 0 while current <= end: - current = current + one_day - if(not trading_calendar.is_trading_day(current)): - continue + current = get_next_trading_dt(current, one_day, trading_calender) r = risk.DailyReturn(current, random.random()) - i += 1 test_range.append(r) + return test_range @@ -138,13 +143,11 @@ def create_returns_from_list(returns, start, trading_calendar): current = start.replace(tzinfo=pytz.utc) one_day = timedelta(days = 1) test_range = [] - i = 0 while len(test_range) < len(returns): - if(trading_calendar.is_trading_day(current)): - r = risk.DailyReturn(current, returns[i]) - i += 1 - test_range.append(r) - current = current + one_day + current = get_next_trading_dt(current, one_day, trading_calendar) + r = risk.DailyReturn(current, returns[i]) + test_range.append(r) + return sorted(test_range, key=lambda(x):x.date) def create_daily_trade_source(sids, trade_count, trading_environment): diff --git a/zipline/test/test_finance.py b/zipline/test/test_finance.py index 5edc0b0f..31b5b511 100644 --- a/zipline/test/test_finance.py +++ b/zipline/test/test_finance.py @@ -107,3 +107,43 @@ class FinanceTestCase(TestCase): SID, "Portfolio should have one position in " + str(SID) ) + + self.assertEqual( + zipline.algorithm.frame_count, + self.zipline_test_config['trade_count'], + "The algorithm should receive all trades." + ) + + @timed(DEFAULT_TIMEOUT) + def test_sid_filter(self): + """Ensure the algorithm's filter prevents events from arriving.""" + # create a test algorithm whose filter will not match any of the + # trade events sourced inside the zipline. + order_amount = 100 + order_count = 100 + no_match_sid = 222 + test_algo = TestAlgorithm( + no_match_sid, + order_amount, + order_count + ) + + self.zipline_test_config['trade_count'] = 200 + self.zipline_test_config['algorithm'] = test_algo + + zipline = SimulatedTrading.create_test_zipline(**self.zipline_test_config) + + zipline.simulate(blocking=True) + + #check that the algorithm received no events + self.assertEqual( + 0, + test_algo.frame_count, + "The algorithm should not receive any events due to filtering." + ) + + + + + +