ENH: Return namedtuple from load methods

This commit is contained in:
Stewart Douglas
2015-08-07 15:53:47 -04:00
committed by jfkirk
parent b658f579d2
commit d55920a43b
+26 -30
View File
@@ -2,6 +2,7 @@ from abc import (
ABCMeta,
abstractmethod,
)
from collections import namedtuple
import pandas as pd
from six import with_metaclass
@@ -9,6 +10,9 @@ import sqlalchemy as sa
from zipline.errors import SidAssignmentError
# Define a namedtuple for use with the load_data and _load_data methods
AssetData = namedtuple('AssetData', 'equities futures exchanges root_symbols')
ASSET_FIELDS = frozenset({
'sid',
'asset_type',
@@ -66,21 +70,6 @@ ROOT_SYMBOL_TABLE_FIELDS = ({
})
class AssetData(object):
""" Class to store collection of asset data.
"""
def __init__(self, equities=None, futures=None,
exchanges=None, root_symbols=None):
"""
Data supplied to this object should be of a consistent type.
"""
self.equities = equities
self.futures = futures
self.exchanges = exchanges
self.root_symbols = root_symbols
class AssetDBWriter(with_metaclass(ABCMeta)):
"""
Class used to write arbitrary data to SQLite database.
@@ -124,12 +113,12 @@ class AssetDBWriter(with_metaclass(ABCMeta)):
# Create SQL tables
self.init_db(engine, constraints)
# Get the data to add to SQL
equities, futures, exchanges, root_symbols = self.load_data()
data = self.load_data()
with engine.begin() as txn:
self._write_exchanges(exchanges, txn)
self._write_root_symbols(root_symbols, txn)
self._write_futures(futures, txn)
self._write_equities(equities, fuzzy_char, txn)
self._write_exchanges(data.exchanges, txn)
self._write_root_symbols(data.root_symbols, txn)
self._write_futures(data.futures, txn)
self._write_equities(data.equities, fuzzy_char, txn)
def _write_exchanges(self, exchanges, bind=None):
self.futures_exchanges.insert().values(
@@ -281,8 +270,7 @@ class AssetDBWriter(with_metaclass(ABCMeta)):
equities, futures, exchanges, root_symbols
"""
equities_data, futures_data, exchanges_data, root_symbols_data = \
self._load_data()
data = self._load_data()
# ******** Generate equities data ********
@@ -416,7 +404,10 @@ class AssetDBWriter(with_metaclass(ABCMeta)):
copy=False
)
return equities_data, futures_data, exchanges_data, root_symbols_data
return AssetData(equities=equities_output,
futures=futures_output,
exchanges=exchanges_output,
root_symbols=root_symbols_output)
@staticmethod
def dict_subset(dict_, subset):
@@ -490,10 +481,10 @@ class AssetDBWriterFromList(AssetDBWriter):
_exchanges = pd.DataFrame.from_dict(_exchanges, orient='index')
_root_symbols = pd.DataFrame.from_dict(_root_symbols, orient='index')
return asset_data(equities=_equities,
futures=_futures,
exchanges=_exchanges,
root_symbols=_root_symbols)
return AssetData(equities=_equities,
futures=_futures,
exchanges=_exchanges,
root_symbols=_root_symbols)
class AssetDBWriterFromDictionary(AssetDBWriter):
@@ -521,7 +512,10 @@ class AssetDBWriterFromDictionary(AssetDBWriter):
_root_symbols = pd.DataFrame.from_dict(self._root_symbols,
orient='index')
return _equities, _futures, _exchanges, _root_symbols
return AssetData(equities=_equities,
futures=_futures,
exchanges=_exchanges,
root_symbols=_root_symbols)
class AssetDBWriterFromDataFrame(AssetDBWriter):
@@ -539,5 +533,7 @@ class AssetDBWriterFromDataFrame(AssetDBWriter):
def _load_data(self):
return self._equities, self._futures, self._exchanges,
self._root_symbols
return AssetData(equities=self._equities,
futures=self._futures,
exchanges=self._exchanges,
root_symbols=self._root_symbols)