diff --git a/tests/test_assets.py b/tests/test_assets.py index 8eb6489c..8fb0646b 100644 --- a/tests/test_assets.py +++ b/tests/test_assets.py @@ -568,18 +568,24 @@ class AssetFinderTestCase(WithTradingCalendars, ZiplineTestCase): for asof in pd.date_range('2014-01-01', '2014-01-05', tz='utc'): # from 01 through 05 sid 0 held 'A' + A_result = finder.lookup_symbol('A', asof) assert_equal( - finder.lookup_symbol('A', asof), + A_result, finder.retrieve_asset(0), msg=str(asof), ) + # the symbol should always be the last symbol + assert_equal(A_result.symbol, 'B') # from 01 through 05 sid 1 held 'C' + C_result = finder.lookup_symbol('C', asof) assert_equal( - finder.lookup_symbol('C', asof), + C_result, finder.retrieve_asset(1), msg=str(asof), ) + # the symbol should always be the last symbol + assert_equal(C_result.symbol, 'A') # no one held 'B' before 06 with self.assertRaises(SymbolNotFound): @@ -596,20 +602,24 @@ class AssetFinderTestCase(WithTradingCalendars, ZiplineTestCase): # from 06 through 10 sid 0 held 'B' # we test through the 11th because sid 1 is the last to hold 'B' # so it should ffill + B_result = finder.lookup_symbol('B', asof) assert_equal( - finder.lookup_symbol('B', asof), + B_result, finder.retrieve_asset(0), msg=str(asof), ) + assert_equal(B_result.symbol, 'B') # from 06 through 10 sid 1 held 'A' # we test through the 11th because sid 1 is the last to hold 'A' # so it should ffill + A_result = finder.lookup_symbol('A', asof) assert_equal( - finder.lookup_symbol('A', asof), + A_result, finder.retrieve_asset(1), msg=str(asof), ) + assert_equal(A_result.symbol, 'A') def test_lookup_symbol(self): diff --git a/zipline/assets/assets.py b/zipline/assets/assets.py index b770014a..6bfeb9f4 100644 --- a/zipline/assets/assets.py +++ b/zipline/assets/assets.py @@ -440,22 +440,27 @@ class AssetFinder(object): def _select_asset_by_symbol(asset_tbl, symbol): return sa.select([asset_tbl]).where(asset_tbl.c.symbol == symbol) - def _lookup_most_recent_symbols(self, sids): + def _select_most_recent_symbols_chunk(self, sid_group): symbol_cols = self.equity_symbol_mappings.c + inner = sa.select( + (symbol_cols.sid,) + + tuple(map( + op.getitem(symbol_cols), + symbol_columns, + )), + ).where( + symbol_cols.sid.in_(map(int, sid_group)), + ).order_by( + symbol_cols.end_date.asc(), + ) + return sa.select(inner.c).group_by(inner.c.sid) + + def _lookup_most_recent_symbols(self, sids): symbols = { row.sid: {c: row[c] for c in symbol_columns} for row in concat( self.engine.execute( - sa.select( - (symbol_cols.sid,) + - tuple(map(op.getitem(symbol_cols), symbol_columns)), - ).where( - symbol_cols.sid.in_(map(int, sid_group)), - ).order_by( - symbol_cols.end_date.desc(), - ).group_by( - symbol_cols.sid, - ) + self._select_most_recent_symbols_chunk(sid_group), ).fetchall() for sid_group in partition_all( SQLITE_MAX_VARIABLE_NUMBER,