diff --git a/tests/test_data_portal.py b/tests/test_data_portal.py index 345e0fc7..26391f60 100644 --- a/tests/test_data_portal.py +++ b/tests/test_data_portal.py @@ -296,6 +296,31 @@ class DataPortalTestBase(WithDataPortal, for field in expected.keys()] assert_almost_equal(array(list(expected.values())), result) + def test_get_spot_value_multiple_assets(self): + equity = self.asset_finder.retrieve_asset(1) + future = self.asset_finder.retrieve_asset(10000) + trading_calendar = self.trading_calendars['CME'] + dts = trading_calendar.minutes_for_session(self.trading_days[3]) + + # We expect the outputs to be lists of spot values. + expected = pd.DataFrame( + { + equity: [nan, nan, nan, nan, 0, 101.3], + future: [203.5, 203.9, 203.1, 203.3, 2003, 203.3], + }, + index=['open', 'high', 'low', 'close', 'volume', 'price'], + ) + result = [ + self.data_portal.get_spot_value( + assets=[equity, future], + field=field, + dt=dts[1], + data_frequency='minute', + ) + for field in expected.index + ] + assert_almost_equal(expected.values.tolist(), result) + def test_bar_count_for_simple_transforms(self): # July 2015 # Su Mo Tu We Th Fr Sa diff --git a/zipline/data/data_portal.py b/zipline/data/data_portal.py index 3b77cdaa..762dd636 100644 --- a/zipline/data/data_portal.py +++ b/zipline/data/data_portal.py @@ -24,7 +24,13 @@ from pandas.tslib import normalize_date from six import iteritems from six.moves import reduce -from zipline.assets import Asset, Future, Equity +from zipline.assets import ( + Asset, + AssetConvertible, + Equity, + Future, + PricingDataAssociable, +) from zipline.assets.continuous_futures import ContinuousFuture from zipline.data.continuous_future_reader import ( ContinuousFutureSessionBarReader, @@ -436,15 +442,15 @@ class DataPortal(object): except KeyError: return np.NaN - def get_spot_value(self, asset, field, dt, data_frequency): + def get_spot_value(self, assets, field, dt, data_frequency): """ Public API method that returns a scalar value representing the value of the desired asset's field at either the given dt. Parameters ---------- - asset : Asset - The asset whose data is desired. + assets : Asset, ContinuousFuture, or iterable of same. + The asset or assets whose data is desired. field : {'open', 'high', 'low', 'close', 'volume', 'price', 'last_traded'} The desired field of the asset. @@ -463,41 +469,65 @@ class DataPortal(object): ``field`` is 'volume' the value will be a int. If the ``field`` is 'last_traded' the value will be a Timestamp. """ - if self._is_extra_source(asset, field, self._augmented_sources_map): - return self._get_fetcher_value(asset, field, dt) - - if field not in BASE_FIELDS: - raise KeyError("Invalid column: " + str(field)) + assets_is_scalar = False + if isinstance(assets, (AssetConvertible, PricingDataAssociable)): + assets_is_scalar = True + else: + # If 'assets' was not one of the expected types then it should be + # an iterable. + try: + iter(assets) + except TypeError: + raise TypeError( + "Unexpected 'assets' value of type {}." + .format(type(assets)) + ) session_label = self.trading_calendar.minute_to_session_label(dt) - if dt < asset.start_date or \ - (data_frequency == "daily" and - session_label > asset.end_date) or \ - (data_frequency == "minute" and - session_label > asset.end_date): - if field == "volume": - return 0 - elif field != "last_traded": - return np.NaN - elif field == "contract": - return None + def get_single_asset_value(asset): + if self._is_extra_source( + asset, field, self._augmented_sources_map): + return self._get_fetcher_value(asset, field, dt) - if data_frequency == "daily": - if field == "contract": - return self._get_current_contract(asset, session_label) + if field not in BASE_FIELDS: + raise KeyError("Invalid column: " + str(field)) + + if dt < asset.start_date or \ + (data_frequency == "daily" and + session_label > asset.end_date) or \ + (data_frequency == "minute" and + session_label > asset.end_date): + if field == "volume": + return 0 + elif field != "last_traded": + return np.NaN + elif field == "contract": + return None + + if data_frequency == "daily": + if field == "contract": + return self._get_current_contract(asset, session_label) + else: + return self._get_daily_spot_value( + asset, field, session_label, + ) else: - return self._get_daily_spot_value(asset, field, session_label) + if field == "last_traded": + return self.get_last_traded_dt(asset, dt, 'minute') + elif field == "price": + return self._get_minute_spot_value( + asset, "close", dt, ffill=True, + ) + elif field == "contract": + return self._get_current_contract(asset, dt) + else: + return self._get_minute_spot_value(asset, field, dt) + + if assets_is_scalar: + return get_single_asset_value(assets) else: - if field == "last_traded": - return self.get_last_traded_dt(asset, dt, 'minute') - elif field == "price": - return self._get_minute_spot_value(asset, "close", dt, - ffill=True) - elif field == "contract": - return self._get_current_contract(asset, dt) - else: - return self._get_minute_spot_value(asset, field, dt) + return list(map(get_single_asset_value, assets)) def get_adjustments(self, assets, field, dt, perspective_dt): """