mirror of
https://github.com/wassname/catalyst.git
synced 2026-07-05 21:39:31 +08:00
Merge pull request #1719 from quantopian/vectorize-spot-value
Allow DataPortal.get_spot_value to accept multiple assets
This commit is contained in:
@@ -296,6 +296,31 @@ class DataPortalTestBase(WithDataPortal,
|
||||
for field in expected.keys()]
|
||||
assert_almost_equal(array(list(expected.values())), result)
|
||||
|
||||
def test_get_spot_value_multiple_assets(self):
|
||||
equity = self.asset_finder.retrieve_asset(1)
|
||||
future = self.asset_finder.retrieve_asset(10000)
|
||||
trading_calendar = self.trading_calendars['CME']
|
||||
dts = trading_calendar.minutes_for_session(self.trading_days[3])
|
||||
|
||||
# We expect the outputs to be lists of spot values.
|
||||
expected = pd.DataFrame(
|
||||
{
|
||||
equity: [nan, nan, nan, nan, 0, 101.3],
|
||||
future: [203.5, 203.9, 203.1, 203.3, 2003, 203.3],
|
||||
},
|
||||
index=['open', 'high', 'low', 'close', 'volume', 'price'],
|
||||
)
|
||||
result = [
|
||||
self.data_portal.get_spot_value(
|
||||
assets=[equity, future],
|
||||
field=field,
|
||||
dt=dts[1],
|
||||
data_frequency='minute',
|
||||
)
|
||||
for field in expected.index
|
||||
]
|
||||
assert_almost_equal(expected.values.tolist(), result)
|
||||
|
||||
def test_bar_count_for_simple_transforms(self):
|
||||
# July 2015
|
||||
# Su Mo Tu We Th Fr Sa
|
||||
|
||||
+63
-33
@@ -24,7 +24,13 @@ from pandas.tslib import normalize_date
|
||||
from six import iteritems
|
||||
from six.moves import reduce
|
||||
|
||||
from zipline.assets import Asset, Future, Equity
|
||||
from zipline.assets import (
|
||||
Asset,
|
||||
AssetConvertible,
|
||||
Equity,
|
||||
Future,
|
||||
PricingDataAssociable,
|
||||
)
|
||||
from zipline.assets.continuous_futures import ContinuousFuture
|
||||
from zipline.data.continuous_future_reader import (
|
||||
ContinuousFutureSessionBarReader,
|
||||
@@ -436,15 +442,15 @@ class DataPortal(object):
|
||||
except KeyError:
|
||||
return np.NaN
|
||||
|
||||
def get_spot_value(self, asset, field, dt, data_frequency):
|
||||
def get_spot_value(self, assets, field, dt, data_frequency):
|
||||
"""
|
||||
Public API method that returns a scalar value representing the value
|
||||
of the desired asset's field at either the given dt.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
asset : Asset
|
||||
The asset whose data is desired.
|
||||
assets : Asset, ContinuousFuture, or iterable of same.
|
||||
The asset or assets whose data is desired.
|
||||
field : {'open', 'high', 'low', 'close', 'volume',
|
||||
'price', 'last_traded'}
|
||||
The desired field of the asset.
|
||||
@@ -463,41 +469,65 @@ class DataPortal(object):
|
||||
``field`` is 'volume' the value will be a int. If the ``field`` is
|
||||
'last_traded' the value will be a Timestamp.
|
||||
"""
|
||||
if self._is_extra_source(asset, field, self._augmented_sources_map):
|
||||
return self._get_fetcher_value(asset, field, dt)
|
||||
|
||||
if field not in BASE_FIELDS:
|
||||
raise KeyError("Invalid column: " + str(field))
|
||||
assets_is_scalar = False
|
||||
if isinstance(assets, (AssetConvertible, PricingDataAssociable)):
|
||||
assets_is_scalar = True
|
||||
else:
|
||||
# If 'assets' was not one of the expected types then it should be
|
||||
# an iterable.
|
||||
try:
|
||||
iter(assets)
|
||||
except TypeError:
|
||||
raise TypeError(
|
||||
"Unexpected 'assets' value of type {}."
|
||||
.format(type(assets))
|
||||
)
|
||||
|
||||
session_label = self.trading_calendar.minute_to_session_label(dt)
|
||||
|
||||
if dt < asset.start_date or \
|
||||
(data_frequency == "daily" and
|
||||
session_label > asset.end_date) or \
|
||||
(data_frequency == "minute" and
|
||||
session_label > asset.end_date):
|
||||
if field == "volume":
|
||||
return 0
|
||||
elif field != "last_traded":
|
||||
return np.NaN
|
||||
elif field == "contract":
|
||||
return None
|
||||
def get_single_asset_value(asset):
|
||||
if self._is_extra_source(
|
||||
asset, field, self._augmented_sources_map):
|
||||
return self._get_fetcher_value(asset, field, dt)
|
||||
|
||||
if data_frequency == "daily":
|
||||
if field == "contract":
|
||||
return self._get_current_contract(asset, session_label)
|
||||
if field not in BASE_FIELDS:
|
||||
raise KeyError("Invalid column: " + str(field))
|
||||
|
||||
if dt < asset.start_date or \
|
||||
(data_frequency == "daily" and
|
||||
session_label > asset.end_date) or \
|
||||
(data_frequency == "minute" and
|
||||
session_label > asset.end_date):
|
||||
if field == "volume":
|
||||
return 0
|
||||
elif field != "last_traded":
|
||||
return np.NaN
|
||||
elif field == "contract":
|
||||
return None
|
||||
|
||||
if data_frequency == "daily":
|
||||
if field == "contract":
|
||||
return self._get_current_contract(asset, session_label)
|
||||
else:
|
||||
return self._get_daily_spot_value(
|
||||
asset, field, session_label,
|
||||
)
|
||||
else:
|
||||
return self._get_daily_spot_value(asset, field, session_label)
|
||||
if field == "last_traded":
|
||||
return self.get_last_traded_dt(asset, dt, 'minute')
|
||||
elif field == "price":
|
||||
return self._get_minute_spot_value(
|
||||
asset, "close", dt, ffill=True,
|
||||
)
|
||||
elif field == "contract":
|
||||
return self._get_current_contract(asset, dt)
|
||||
else:
|
||||
return self._get_minute_spot_value(asset, field, dt)
|
||||
|
||||
if assets_is_scalar:
|
||||
return get_single_asset_value(assets)
|
||||
else:
|
||||
if field == "last_traded":
|
||||
return self.get_last_traded_dt(asset, dt, 'minute')
|
||||
elif field == "price":
|
||||
return self._get_minute_spot_value(asset, "close", dt,
|
||||
ffill=True)
|
||||
elif field == "contract":
|
||||
return self._get_current_contract(asset, dt)
|
||||
else:
|
||||
return self._get_minute_spot_value(asset, field, dt)
|
||||
return list(map(get_single_asset_value, assets))
|
||||
|
||||
def get_adjustments(self, assets, field, dt, perspective_dt):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user