mirror of
https://github.com/wassname/catalyst.git
synced 2026-07-03 20:37:52 +08:00
ENH: Removes use of lookup_generic in DataFrame index mapping
This commit is contained in:
@@ -719,7 +719,7 @@ class AssetFinderTestCase(TestCase):
|
||||
# No contracts exist after 12/14/2015, so we should get none
|
||||
self.assertIsNone(finder.lookup_future_by_expiration('AD', dt, jan_16))
|
||||
|
||||
def test_map_identifier_list_to_sids(self):
|
||||
def test_map_identifier_index_to_sids(self):
|
||||
|
||||
# Build an empty finder and some Assets
|
||||
dt = pd.Timestamp('2014-01-01', tz='UTC')
|
||||
@@ -731,12 +731,12 @@ class AssetFinderTestCase(TestCase):
|
||||
|
||||
# Check for correct mapping and types
|
||||
pre_map = [asset1, asset2, asset200, asset201]
|
||||
post_map = finder.map_identifier_list_to_sids(pre_map, dt)
|
||||
post_map = finder.map_identifier_index_to_sids(pre_map, dt)
|
||||
self.assertListEqual([1, 2, 200, 201], post_map)
|
||||
for sid in post_map:
|
||||
self.assertIsInstance(sid, int)
|
||||
|
||||
# Change order and check mapping again
|
||||
pre_map = [asset201, asset2, asset200, asset1]
|
||||
post_map = finder.map_identifier_list_to_sids(pre_map, dt)
|
||||
post_map = finder.map_identifier_index_to_sids(pre_map, dt)
|
||||
self.assertListEqual([201, 2, 200, 1], post_map)
|
||||
|
||||
@@ -76,7 +76,7 @@ class TestDataFrameSource(TestCase):
|
||||
'volume', 'price']
|
||||
|
||||
copy_panel = data.copy()
|
||||
sids = finder.map_identifier_list_to_sids(
|
||||
sids = finder.map_identifier_index_to_sids(
|
||||
data.items, data.major_axis[0]
|
||||
)
|
||||
copy_panel.items = sids
|
||||
|
||||
@@ -463,16 +463,17 @@ class TradingAlgorithm(object):
|
||||
# if DataFrame provided, map columns to sids and wrap
|
||||
# in DataFrameSource
|
||||
copy_frame = source.copy()
|
||||
copy_frame.columns = self.asset_finder.map_identifier_list_to_sids(
|
||||
source.columns, source.index[0]
|
||||
)
|
||||
copy_frame.columns = \
|
||||
self.asset_finder.map_identifier_index_to_sids(
|
||||
source.columns, source.index[0]
|
||||
)
|
||||
source = DataFrameSource(copy_frame)
|
||||
|
||||
elif isinstance(source, pd.Panel):
|
||||
# If Panel provided, map items to sids and wrap
|
||||
# in DataPanelSource
|
||||
copy_panel = source.copy()
|
||||
copy_panel.items = self.asset_finder.map_identifier_list_to_sids(
|
||||
copy_panel.items = self.asset_finder.map_identifier_index_to_sids(
|
||||
source.items, source.major_axis[0]
|
||||
)
|
||||
source = DataPanelSource(copy_panel)
|
||||
|
||||
+40
-11
@@ -32,6 +32,7 @@ from zipline.errors import (
|
||||
SidAssignmentError,
|
||||
SidNotFound,
|
||||
SymbolNotFound,
|
||||
MapAssetIdentifierIndexError,
|
||||
)
|
||||
from zipline.assets._assets import (
|
||||
Asset, Equity, Future
|
||||
@@ -532,29 +533,57 @@ class AssetFinder(object):
|
||||
self._lookup_generic_scalar(obj, as_of_date, matches, missing)
|
||||
return matches, missing
|
||||
|
||||
def map_identifier_list_to_sids(self, index, as_of_date):
|
||||
def map_identifier_index_to_sids(self, index, as_of_date):
|
||||
"""
|
||||
This method is for use in sanitizing a user's DataFrame inputs.
|
||||
Takes the given indices, inserts them in to the AssetFinder as
|
||||
identifiers, rebuilds the Assets, and returns a new object that
|
||||
contains the sids of the given identifiers.
|
||||
This method is for use in sanitizing a user's DataFrame or Panel
|
||||
inputs.
|
||||
|
||||
:param index: The index to be mapped
|
||||
:return: The new index of sids
|
||||
Takes the given index of identifiers, checks their types, builds assets
|
||||
if necessary, and returns a list of the sids that correspond to the
|
||||
input index.
|
||||
|
||||
Parameters
|
||||
__________
|
||||
index : Iterable
|
||||
An iterable containing ints, strings, or Assets
|
||||
as_of_date : pandas.Timestamp
|
||||
A date to be used to resolve any dual-mapped symbols
|
||||
|
||||
Returns
|
||||
_______
|
||||
List
|
||||
A list of integer sids corresponding to the input index
|
||||
"""
|
||||
# Populate the caches with the given indices
|
||||
# This method assumes that the type of the objects in the index is
|
||||
# consistent and can, therefore, be taken from the first identifier
|
||||
first_identifier = index[0]
|
||||
|
||||
# Ensure that input is AssetConvertible (integer, string, or Asset)
|
||||
if not isinstance(first_identifier, AssetConvertible):
|
||||
raise MapAssetIdentifierIndexError(obj=first_identifier)
|
||||
|
||||
# If sids are provided, no mapping is necessary
|
||||
if isinstance(first_identifier, Integral):
|
||||
return index
|
||||
|
||||
# If symbols or Assets are provided, construction and mapping is
|
||||
# necessary
|
||||
self.consume_identifiers(index)
|
||||
self.populate_cache()
|
||||
|
||||
# Find all of the newly built assets corresponding to the indices
|
||||
found, missing = self.lookup_generic(index, as_of_date)
|
||||
# Look up all Assets for mapping
|
||||
matches = []
|
||||
missing = []
|
||||
for identifier in index:
|
||||
self._lookup_generic_scalar(identifier, as_of_date,
|
||||
matches, missing)
|
||||
|
||||
# Handle missing assets
|
||||
if len(missing) > 0:
|
||||
warnings.warn("Missing assets for identifiers: " + missing)
|
||||
|
||||
# Return a list of the sids of the found assets
|
||||
return [asset.sid for asset in found]
|
||||
return [asset.sid for asset in matches]
|
||||
|
||||
def insert_metadata(self, identifier, **kwargs):
|
||||
"""
|
||||
|
||||
+12
-1
@@ -246,7 +246,7 @@ TradingEnvironment can not set asset_finder to object of class {cls}.
|
||||
|
||||
class ConsumeAssetMetaDataError(ZiplineError):
|
||||
"""
|
||||
Raised when AssetMetaData.consume() is called on an invalid object.
|
||||
Raised when AssetFinder.consume() is called on an invalid object.
|
||||
"""
|
||||
msg = """
|
||||
AssetFinder can not consume metadata of type {obj}. Metadata must be a dict, a
|
||||
@@ -255,6 +255,17 @@ must contain both or one of 'sid' or 'symbol'.
|
||||
""".strip()
|
||||
|
||||
|
||||
class MapAssetIdentifierIndexError(ZiplineError):
|
||||
"""
|
||||
Raised when AssetMetaData.map_identifier_index_to_sids() is called on an
|
||||
index of invalid objects.
|
||||
"""
|
||||
msg = """
|
||||
AssetFinder can not map an index with values of type {obj}. Asset indices of
|
||||
DataFrames or Panels must be integer sids, string symbols, or Asset objects.
|
||||
""".strip()
|
||||
|
||||
|
||||
class SidAssignmentError(ZiplineError):
|
||||
"""
|
||||
Raised when an AssetFinder tries to build an Asset that does not have a sid
|
||||
|
||||
Reference in New Issue
Block a user