mirror of
https://github.com/wassname/catalyst.git
synced 2026-07-02 12:29:43 +08:00
ENH: Replace update_asset_finder with write_data
The write_data methods invokes the relevant AssetDBWriter subclass to write data to the database. update_asset_finder is no longer a relevant method since the AssetFinder is strictly a reader class.
This commit is contained in:
@@ -38,7 +38,7 @@ class BlotterTestCase(TestCase):
|
||||
@with_environment()
|
||||
def setUp(self, env=None):
|
||||
setup_logger(self)
|
||||
env.update_asset_finder(identifiers=[24])
|
||||
env.write_data(equities_identifiers=[24])
|
||||
|
||||
def tearDown(self):
|
||||
teardown_logger(self)
|
||||
|
||||
@@ -1894,7 +1894,7 @@ class TestPerformanceTracker(unittest.TestCase):
|
||||
foosid = 1
|
||||
barsid = 2
|
||||
|
||||
env.update_asset_finder(identifiers=[foosid, barsid])
|
||||
env.write_data(equities_identifiers=[foosid, barsid])
|
||||
|
||||
foo_event_1 = factory.create_trade(foosid, 10.0, 20, start_dt)
|
||||
order_event_1 = Order(sid=foo_event_1.sid,
|
||||
@@ -1989,7 +1989,7 @@ class TestPerformanceTracker(unittest.TestCase):
|
||||
|
||||
@with_environment()
|
||||
def test_close_position_event(self, env=None):
|
||||
env.update_asset_finder(identifiers=[1, 2])
|
||||
env.write_data(equities_identifiers=[1, 2])
|
||||
pt = perf.PositionTracker()
|
||||
dt = pd.Timestamp("1984/03/06 3:00PM")
|
||||
pos1 = perf.Position(1, amount=np.float64(120.0),
|
||||
@@ -2133,9 +2133,7 @@ class TestPositionTracker(unittest.TestCase):
|
||||
2: {'asset_type': 'future',
|
||||
'contract_multiplier': 1000}}
|
||||
asset_finder = AssetFinder()
|
||||
env.update_asset_finder(
|
||||
asset_finder=asset_finder,
|
||||
asset_metadata=metadata)
|
||||
env.write_data(equities_data=metadata)
|
||||
pt = perf.PositionTracker()
|
||||
dt = pd.Timestamp("1984/03/06 3:00PM")
|
||||
pos1 = perf.Position(1, amount=np.float64(100.0),
|
||||
@@ -2163,7 +2161,7 @@ class TestPositionTracker(unittest.TestCase):
|
||||
'contract_multiplier': 1000},
|
||||
4: {'asset_type': 'future',
|
||||
'contract_multiplier': 1000}}
|
||||
env.update_asset_finder(asset_metadata=metadata)
|
||||
env.write_data(equities_data=metadata)
|
||||
pt = perf.PositionTracker()
|
||||
dt = pd.Timestamp("1984/03/06 3:00PM")
|
||||
pos1 = perf.Position(1, amount=np.float64(10.0),
|
||||
@@ -2197,7 +2195,7 @@ class TestPositionTracker(unittest.TestCase):
|
||||
metadata = {1: {'asset_type': 'equity'},
|
||||
2: {'asset_type': 'future',
|
||||
'contract_multiplier': 1000}}
|
||||
env.update_asset_finder(asset_metadata=metadata)
|
||||
env.write_data(equities_data=metadata)
|
||||
pt = perf.PositionTracker()
|
||||
dt = pd.Timestamp("1984/03/06 3:00PM")
|
||||
pos1 = perf.Position(1, amount=np.float64(120.0),
|
||||
|
||||
@@ -66,9 +66,8 @@ class SecurityListTestCase(TestCase):
|
||||
self.trading_day_before_first_kd = datetime(
|
||||
2015, 1, 23, 0, 0, tzinfo=pytz.utc)
|
||||
|
||||
env.update_asset_finder(
|
||||
clear_metadata=True,
|
||||
identifiers=["BZQ", "URTY", "JFT", "AAPL", "GOOG"]
|
||||
env.write_data(
|
||||
equities_identifiers=["BZQ", "URTY", "JFT", "AAPL", "GOOG"]
|
||||
)
|
||||
|
||||
setup_logger(self)
|
||||
|
||||
@@ -205,11 +205,11 @@ class TradingAlgorithm(object):
|
||||
# Update the TradingEnvironment with the provided asset metadata
|
||||
self.trading_environment = kwargs.pop('env',
|
||||
TradingEnvironment.instance())
|
||||
self.trading_environment.update_asset_finder(
|
||||
asset_finder=kwargs.pop('asset_finder', None),
|
||||
asset_metadata=kwargs.pop('asset_metadata', None),
|
||||
identifiers=kwargs.pop('identifiers', None)
|
||||
self.trading_environment.write_data(
|
||||
equities_data=kwargs.pop('asset_metadata', None),
|
||||
equities_identifiers=kwargs.pop('identifiers', None),
|
||||
)
|
||||
|
||||
# Pull in the environment's new AssetFinder for quick reference
|
||||
self.asset_finder = self.trading_environment.asset_finder
|
||||
self.init_engine(kwargs.pop('ffc_loader', None))
|
||||
|
||||
@@ -243,16 +243,6 @@ AssetMetaData contained an invalid Asset type: '{asset_type}'.
|
||||
""".strip()
|
||||
|
||||
|
||||
class UpdateAssetFinderTypeError(ZiplineError):
|
||||
"""
|
||||
Raised when TradingEnvironment.update_asset_finder() gets an asset_finder
|
||||
arg that is not of AssetFinder class.
|
||||
"""
|
||||
msg = """
|
||||
TradingEnvironment can not set asset_finder to object of class {cls}.
|
||||
""".strip()
|
||||
|
||||
|
||||
class ConsumeAssetMetaDataError(ZiplineError):
|
||||
"""
|
||||
Raised when AssetFinder.consume() is called on an invalid object.
|
||||
|
||||
+58
-42
@@ -22,6 +22,7 @@ import sqlite3
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from sqlalchemy import create_engine
|
||||
# from multipledispatch import dispatch
|
||||
|
||||
from zipline.data.loader import load_market_data
|
||||
from zipline.utils import tradingcalendar
|
||||
@@ -31,8 +32,7 @@ from zipline.assets.asset_writer import (
|
||||
AssetDBWriterFromDictionary,
|
||||
AssetDBWriterFromDataFrame)
|
||||
from zipline.errors import (
|
||||
NoFurtherDataError,
|
||||
UpdateAssetFinderTypeError,
|
||||
NoFurtherDataError
|
||||
)
|
||||
|
||||
|
||||
@@ -157,51 +157,67 @@ class TradingEnvironment(object):
|
||||
# stack.
|
||||
return False
|
||||
|
||||
def update_asset_finder(self,
|
||||
clear_metadata=False,
|
||||
asset_finder=None,
|
||||
asset_metadata=None,
|
||||
identifiers=None):
|
||||
def write_data(self,
|
||||
engine=None,
|
||||
equities_data={},
|
||||
futures_data={},
|
||||
exchanges_data={},
|
||||
root_symbols_data={},
|
||||
equities_identifiers=[],
|
||||
futures_identifiers=[],
|
||||
exchanges_identifiers=[],
|
||||
root_symbols_identifiers=[]):
|
||||
""" Write the supplied data to the database.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
equities_data: dict
|
||||
A dictionary of equity metadata
|
||||
futures_data: dict
|
||||
A dictionary of futures metadata
|
||||
exchanges_data: dict
|
||||
A dictionary of exchanges metadata
|
||||
root_symbols_data: dict
|
||||
A dictionary of root symbols metadata
|
||||
equities_identifiers: list
|
||||
A list of equities identifiers (sids or symbols)
|
||||
futures_identifiers: list
|
||||
A list of futures identifiers (sids or symbols)
|
||||
exchanges_identifiers: list
|
||||
A list of exchanges identifiers (ids or names)
|
||||
root_symbols_identifiers: list
|
||||
A list of root symbols identifiers (ids or symbols)
|
||||
"""
|
||||
Updates the AssetFinder using the provided asset metadata and
|
||||
identifiers.
|
||||
If clear_metadata is True, all metadata and assets held in the
|
||||
asset_finder will be erased before new metadata is provided.
|
||||
If asset_finder is provided, the existing asset_finder will be replaced
|
||||
outright with the new asset_finder.
|
||||
If asset_metadata is provided, the existing metadata will be cleared
|
||||
and replaced with the provided metadata.
|
||||
All identifiers will be inserted in the asset metadata if they are not
|
||||
already present.
|
||||
|
||||
:param clear_metadata: A boolean
|
||||
:param asset_finder: An AssetFinder object to replace the environment's
|
||||
existing asset_finder
|
||||
:param asset_metadata: A dict, DataFrame, or readable object
|
||||
:param identifiers: A list of identifiers to be inserted
|
||||
:return:
|
||||
"""
|
||||
if clear_metadata:
|
||||
self.engine = create_engine('sqlite:///:memory:')
|
||||
if engine:
|
||||
self.engine = engine
|
||||
|
||||
if asset_finder is not None:
|
||||
if not isinstance(asset_finder, AssetFinder):
|
||||
raise UpdateAssetFinderTypeError(cls=asset_finder.__class__)
|
||||
self.asset_finder = asset_finder
|
||||
if (equities_data or futures_data or exchanges_data or
|
||||
root_symbols_data):
|
||||
self._write_data_dicts(equities_data, futures_data,
|
||||
exchanges_data, root_symbols_data)
|
||||
|
||||
if asset_metadata is not None:
|
||||
self.engine = create_engine('sqlite:///:memory:')
|
||||
if isinstance(asset_metadata, dict):
|
||||
asset_writer = AssetDBWriterFromDictionary(
|
||||
equities=asset_metadata)
|
||||
elif isinstance(asset_metadata, pd.DataFrame):
|
||||
asset_writer = AssetDBWriterFromDataFrame(
|
||||
equities=asset_metadata)
|
||||
asset_writer.write_all(self.engine)
|
||||
if (equities_identifiers or futures_identifiers or
|
||||
exchanges_identifiers or root_symbols_identifiers):
|
||||
self._write_data_lists(equities_identifiers,
|
||||
futures_identifiers,
|
||||
exchanges_identifiers,
|
||||
root_symbols_identifiers)
|
||||
|
||||
if identifiers is not None:
|
||||
asset_writer = AssetDBWriterFromList(equities=identifiers)
|
||||
asset_writer.write_all(self.engine)
|
||||
def _write_data_lists(self, equities=[], futures=[],
|
||||
exchanges=[], root_symbols=[]):
|
||||
AssetDBWriterFromList(equities, futures, exchanges, root_symbols)\
|
||||
.write_all(self.engine)
|
||||
|
||||
def _write_data_dicts(self, equities={}, futures={},
|
||||
exchanges={}, root_symbols={}):
|
||||
AssetDBWriterFromDictionary(equities, futures, exchanges, root_symbols)\
|
||||
.write_all(self.engine)
|
||||
|
||||
def _write_data_dataframes(self, equities, futures,
|
||||
exchanges, root_symbols):
|
||||
AssetDBWriterFromDataFrame(equities, futures, exchanges, root_symbols)\
|
||||
.write_all(self.engine)
|
||||
|
||||
def normalize_date(self, test_date):
|
||||
test_date = pd.Timestamp(test_date, tz='UTC')
|
||||
|
||||
@@ -93,8 +93,8 @@ class RandomWalkSource(DataSource):
|
||||
self.sd = sd
|
||||
|
||||
self.sids = self.start_prices.keys()
|
||||
TradingEnvironment.instance().update_asset_finder(
|
||||
identifiers=self.sids
|
||||
TradingEnvironment.instance().write_data(
|
||||
equities_identifiers=self.sids
|
||||
)
|
||||
|
||||
self.open_and_closes = \
|
||||
|
||||
@@ -136,7 +136,7 @@ class SpecificEquityTrades(object):
|
||||
'sids',
|
||||
set(event.sid for event in self.event_list)
|
||||
)
|
||||
env.update_asset_finder(identifiers=self.identifiers)
|
||||
env.write_data(equities_identifiers=self.identifiers)
|
||||
assets_by_identifier = {}
|
||||
for identifier in self.identifiers:
|
||||
assets_by_identifier[identifier] = env.asset_finder.\
|
||||
@@ -160,7 +160,7 @@ class SpecificEquityTrades(object):
|
||||
self.concurrent = kwargs.get('concurrent', False)
|
||||
|
||||
self.identifiers = kwargs.get('sids', [1, 2])
|
||||
env.update_asset_finder(identifiers=self.identifiers)
|
||||
env.write_data(equities_identifiers=self.identifiers)
|
||||
assets_by_identifier = {}
|
||||
for identifier in self.identifiers:
|
||||
assets_by_identifier[identifier] = env.asset_finder.\
|
||||
|
||||
@@ -119,7 +119,7 @@ def create_trade_history(sid, prices, amounts, interval, sim_params,
|
||||
source_id="test_factory"):
|
||||
trades = []
|
||||
current = sim_params.first_open
|
||||
trading.environment.update_asset_finder(identifiers=[sid])
|
||||
trading.environment.write_data(equities_identifiers=[sid])
|
||||
|
||||
oneday = timedelta(days=1)
|
||||
use_midnight = interval >= oneday
|
||||
@@ -311,7 +311,7 @@ def create_test_df_source(sim_params=None, bars='daily'):
|
||||
|
||||
df = pd.DataFrame(x, index=index, columns=[0])
|
||||
|
||||
trading.environment.update_asset_finder(identifiers=[0])
|
||||
trading.environment.write_data(equities_identifiers=[0])
|
||||
|
||||
return DataFrameSource(df), df
|
||||
|
||||
|
||||
Reference in New Issue
Block a user