Merge pull request #1719 from quantopian/vectorize-spot-value

Allow DataPortal.get_spot_value to accept multiple assets
This commit is contained in:
David Michalowicz
2017-03-25 09:22:59 -04:00
committed by GitHub
2 changed files with 88 additions and 33 deletions
+25
View File
@@ -296,6 +296,31 @@ class DataPortalTestBase(WithDataPortal,
for field in expected.keys()]
assert_almost_equal(array(list(expected.values())), result)
def test_get_spot_value_multiple_assets(self):
equity = self.asset_finder.retrieve_asset(1)
future = self.asset_finder.retrieve_asset(10000)
trading_calendar = self.trading_calendars['CME']
dts = trading_calendar.minutes_for_session(self.trading_days[3])
# We expect the outputs to be lists of spot values.
expected = pd.DataFrame(
{
equity: [nan, nan, nan, nan, 0, 101.3],
future: [203.5, 203.9, 203.1, 203.3, 2003, 203.3],
},
index=['open', 'high', 'low', 'close', 'volume', 'price'],
)
result = [
self.data_portal.get_spot_value(
assets=[equity, future],
field=field,
dt=dts[1],
data_frequency='minute',
)
for field in expected.index
]
assert_almost_equal(expected.values.tolist(), result)
def test_bar_count_for_simple_transforms(self):
# July 2015
# Su Mo Tu We Th Fr Sa
+63 -33
View File
@@ -24,7 +24,13 @@ from pandas.tslib import normalize_date
from six import iteritems
from six.moves import reduce
from zipline.assets import Asset, Future, Equity
from zipline.assets import (
Asset,
AssetConvertible,
Equity,
Future,
PricingDataAssociable,
)
from zipline.assets.continuous_futures import ContinuousFuture
from zipline.data.continuous_future_reader import (
ContinuousFutureSessionBarReader,
@@ -436,15 +442,15 @@ class DataPortal(object):
except KeyError:
return np.NaN
def get_spot_value(self, asset, field, dt, data_frequency):
def get_spot_value(self, assets, field, dt, data_frequency):
"""
Public API method that returns a scalar value representing the value
of the desired asset's field at either the given dt.
Parameters
----------
asset : Asset
The asset whose data is desired.
assets : Asset, ContinuousFuture, or iterable of same.
The asset or assets whose data is desired.
field : {'open', 'high', 'low', 'close', 'volume',
'price', 'last_traded'}
The desired field of the asset.
@@ -463,41 +469,65 @@ class DataPortal(object):
``field`` is 'volume' the value will be a int. If the ``field`` is
'last_traded' the value will be a Timestamp.
"""
if self._is_extra_source(asset, field, self._augmented_sources_map):
return self._get_fetcher_value(asset, field, dt)
if field not in BASE_FIELDS:
raise KeyError("Invalid column: " + str(field))
assets_is_scalar = False
if isinstance(assets, (AssetConvertible, PricingDataAssociable)):
assets_is_scalar = True
else:
# If 'assets' was not one of the expected types then it should be
# an iterable.
try:
iter(assets)
except TypeError:
raise TypeError(
"Unexpected 'assets' value of type {}."
.format(type(assets))
)
session_label = self.trading_calendar.minute_to_session_label(dt)
if dt < asset.start_date or \
(data_frequency == "daily" and
session_label > asset.end_date) or \
(data_frequency == "minute" and
session_label > asset.end_date):
if field == "volume":
return 0
elif field != "last_traded":
return np.NaN
elif field == "contract":
return None
def get_single_asset_value(asset):
if self._is_extra_source(
asset, field, self._augmented_sources_map):
return self._get_fetcher_value(asset, field, dt)
if data_frequency == "daily":
if field == "contract":
return self._get_current_contract(asset, session_label)
if field not in BASE_FIELDS:
raise KeyError("Invalid column: " + str(field))
if dt < asset.start_date or \
(data_frequency == "daily" and
session_label > asset.end_date) or \
(data_frequency == "minute" and
session_label > asset.end_date):
if field == "volume":
return 0
elif field != "last_traded":
return np.NaN
elif field == "contract":
return None
if data_frequency == "daily":
if field == "contract":
return self._get_current_contract(asset, session_label)
else:
return self._get_daily_spot_value(
asset, field, session_label,
)
else:
return self._get_daily_spot_value(asset, field, session_label)
if field == "last_traded":
return self.get_last_traded_dt(asset, dt, 'minute')
elif field == "price":
return self._get_minute_spot_value(
asset, "close", dt, ffill=True,
)
elif field == "contract":
return self._get_current_contract(asset, dt)
else:
return self._get_minute_spot_value(asset, field, dt)
if assets_is_scalar:
return get_single_asset_value(assets)
else:
if field == "last_traded":
return self.get_last_traded_dt(asset, dt, 'minute')
elif field == "price":
return self._get_minute_spot_value(asset, "close", dt,
ffill=True)
elif field == "contract":
return self._get_current_contract(asset, dt)
else:
return self._get_minute_spot_value(asset, field, dt)
return list(map(get_single_asset_value, assets))
def get_adjustments(self, assets, field, dt, perspective_dt):
"""