diff --git a/tests/test_batchtransform.py b/tests/test_batchtransform.py index 490e35f7..17a9b7e3 100644 --- a/tests/test_batchtransform.py +++ b/tests/test_batchtransform.py @@ -147,6 +147,16 @@ class TestBatchTransformMinutely(TestCase): for bt in algo.history[wl:]: self.assertEqual(len(bt), wl) + def test_window_length(self): + algo = BatchTransformAlgorithmMinute(sim_params=self.sim_params, + window_length=1, refresh_period=0) + algo.run(self.source) + wl = int(algo.window_length * 6.5 * 60) + np.testing.assert_array_equal(algo.history[:(wl - 1)], + [None] * (wl - 1)) + for bt in algo.history[wl:]: + self.assertEqual(len(bt), wl) + class TestBatchTransform(TestCase): def setUp(self): diff --git a/zipline/transforms/utils.py b/zipline/transforms/utils.py index 40963167..4b40768b 100644 --- a/zipline/transforms/utils.py +++ b/zipline/transforms/utils.py @@ -456,10 +456,15 @@ class BatchTransform(object): columns=sids)) # update trading day counters - if self.last_dt.day != event.dt.day: - self.last_dt = event.dt + _, mkt_close = trading.environment.get_open_and_close(event.dt) + if self.bars == 'daily': + # Daily bars have their dt set to midnight. + mkt_close = mkt_close.replace(hour=0, minute=0, second=0) + if event.dt >= mkt_close: self.trading_days_total += 1 + self.last_dt = event.dt + if self.trading_days_total >= self.window_length: self.full = True