diff --git a/zipline/assets/asset_writer.py b/zipline/assets/asset_writer.py index 5dcd4da8..ec47a366 100644 --- a/zipline/assets/asset_writer.py +++ b/zipline/assets/asset_writer.py @@ -2,6 +2,7 @@ from abc import ( ABCMeta, abstractmethod, ) +from collections import namedtuple import pandas as pd from six import with_metaclass @@ -9,6 +10,9 @@ import sqlalchemy as sa from zipline.errors import SidAssignmentError +# Define a namedtuple for use with the load_data and _load_data methods +AssetData = namedtuple('AssetData', 'equities futures exchanges root_symbols') + ASSET_FIELDS = frozenset({ 'sid', 'asset_type', @@ -66,21 +70,6 @@ ROOT_SYMBOL_TABLE_FIELDS = ({ }) -class AssetData(object): - """ Class to store collection of asset data. - """ - def __init__(self, equities=None, futures=None, - exchanges=None, root_symbols=None): - """ - Data supplied to this object should be of a consistent type. - """ - - self.equities = equities - self.futures = futures - self.exchanges = exchanges - self.root_symbols = root_symbols - - class AssetDBWriter(with_metaclass(ABCMeta)): """ Class used to write arbitrary data to SQLite database. @@ -124,12 +113,12 @@ class AssetDBWriter(with_metaclass(ABCMeta)): # Create SQL tables self.init_db(engine, constraints) # Get the data to add to SQL - equities, futures, exchanges, root_symbols = self.load_data() + data = self.load_data() with engine.begin() as txn: - self._write_exchanges(exchanges, txn) - self._write_root_symbols(root_symbols, txn) - self._write_futures(futures, txn) - self._write_equities(equities, fuzzy_char, txn) + self._write_exchanges(data.exchanges, txn) + self._write_root_symbols(data.root_symbols, txn) + self._write_futures(data.futures, txn) + self._write_equities(data.equities, fuzzy_char, txn) def _write_exchanges(self, exchanges, bind=None): self.futures_exchanges.insert().values( @@ -281,8 +270,7 @@ class AssetDBWriter(with_metaclass(ABCMeta)): equities, futures, exchanges, root_symbols """ - equities_data, futures_data, exchanges_data, root_symbols_data = \ - self._load_data() + data = self._load_data() # ******** Generate equities data ******** @@ -416,7 +404,10 @@ class AssetDBWriter(with_metaclass(ABCMeta)): copy=False ) - return equities_data, futures_data, exchanges_data, root_symbols_data + return AssetData(equities=equities_output, + futures=futures_output, + exchanges=exchanges_output, + root_symbols=root_symbols_output) @staticmethod def dict_subset(dict_, subset): @@ -490,10 +481,10 @@ class AssetDBWriterFromList(AssetDBWriter): _exchanges = pd.DataFrame.from_dict(_exchanges, orient='index') _root_symbols = pd.DataFrame.from_dict(_root_symbols, orient='index') - return asset_data(equities=_equities, - futures=_futures, - exchanges=_exchanges, - root_symbols=_root_symbols) + return AssetData(equities=_equities, + futures=_futures, + exchanges=_exchanges, + root_symbols=_root_symbols) class AssetDBWriterFromDictionary(AssetDBWriter): @@ -521,7 +512,10 @@ class AssetDBWriterFromDictionary(AssetDBWriter): _root_symbols = pd.DataFrame.from_dict(self._root_symbols, orient='index') - return _equities, _futures, _exchanges, _root_symbols + return AssetData(equities=_equities, + futures=_futures, + exchanges=_exchanges, + root_symbols=_root_symbols) class AssetDBWriterFromDataFrame(AssetDBWriter): @@ -539,5 +533,7 @@ class AssetDBWriterFromDataFrame(AssetDBWriter): def _load_data(self): - return self._equities, self._futures, self._exchanges, - self._root_symbols + return AssetData(equities=self._equities, + futures=self._futures, + exchanges=self._exchanges, + root_symbols=self._root_symbols)