diff --git a/zipline/finance/trading.py b/zipline/finance/trading.py index e6e886d3..f1bdf930 100644 --- a/zipline/finance/trading.py +++ b/zipline/finance/trading.py @@ -17,12 +17,10 @@ import bisect import logbook import datetime from functools import wraps -import sqlite3 import pandas as pd import numpy as np from sqlalchemy import create_engine -# from multipledispatch import dispatch from zipline.data.loader import load_market_data from zipline.utils import tradingcalendar @@ -163,59 +161,83 @@ class TradingEnvironment(object): futures_data={}, exchanges_data={}, root_symbols_data={}, + equities_df=pd.DataFrame(), + futures_df=pd.DataFrame(), + exchanges_df=pd.DataFrame(), + root_symbols_df=pd.DataFrame(), equities_identifiers=[], futures_identifiers=[], exchanges_identifiers=[], - root_symbols_identifiers=[]): + root_symbols_identifiers=[], + allow_sid_assignment=True): """ Write the supplied data to the database. Parameters ---------- - equities_data: dict + equities_data: dict, optional A dictionary of equity metadata - futures_data: dict + futures_data: dict, optional A dictionary of futures metadata - exchanges_data: dict + exchanges_data: dict, optional A dictionary of exchanges metadata - root_symbols_data: dict + root_symbols_data: dict, optional A dictionary of root symbols metadata - equities_identifiers: list - A list of equities identifiers (sids or symbols) - futures_identifiers: list - A list of futures identifiers (sids or symbols) - exchanges_identifiers: list + equities_df: pandas.DataFrame, optional + A pandas.DataFrame of equity metadata + futures_df: pandas.DataFrame, optional + A pandas.DataFrame of futures metadata + exchanges_df: pandas.DataFrame, optional + A pandas.DataFrame of exchanges metadata + root_symbols_df: pandas.DataFrame, optional + A pandas.DataFrame of root symbols metadata + equities_identifiers: list, optional + A list of equities identifiers (sids, symbols, Assets) + futures_identifiers: list, optional + A list of futures identifiers (sids, symbols, Assets) + exchanges_identifiers: list, optional A list of exchanges identifiers (ids or names) - root_symbols_identifiers: list + root_symbols_identifiers: list, optional A list of root symbols identifiers (ids or symbols) """ - if engine: self.engine = engine + # If any pandas.DataFrame data has been provided, + # write it to the database. + if not(equities_df.empty and futures_df.empty and + exchanges_df.empty and root_symbols_df.empty): + self._write_data_dataframes(equities_df, futures_df, + exchanges_df, root_symbols_df) + if (equities_data or futures_data or exchanges_data or root_symbols_data): self._write_data_dicts(equities_data, futures_data, exchanges_data, root_symbols_data) - if (equities_identifiers or futures_identifiers or - exchanges_identifiers or root_symbols_identifiers): - self._write_data_lists(equities_identifiers, - futures_identifiers, - exchanges_identifiers, - root_symbols_identifiers) + # These could be lists or other iterables such as a pandas.Index. + # For simplicity, don't check whether data has been provided. + self._write_data_lists(equities_identifiers, + futures_identifiers, + exchanges_identifiers, + root_symbols_identifiers, + allow_sid_assignment=allow_sid_assignment) def _write_data_lists(self, equities=[], futures=[], - exchanges=[], root_symbols=[]): + exchanges=[], root_symbols=[], + allow_sid_assignment=True): AssetDBWriterFromList(equities, futures, exchanges, root_symbols)\ - .write_all(self.engine) + .write_all(self.engine, allow_sid_assignment=allow_sid_assignment) def _write_data_dicts(self, equities={}, futures={}, - exchanges={}, root_symbols={}): + exchanges={}, root_symbols={}, + allow_sid_assignment=True): AssetDBWriterFromDictionary(equities, futures, exchanges, root_symbols)\ .write_all(self.engine) - def _write_data_dataframes(self, equities, futures, - exchanges, root_symbols): + def _write_data_dataframes(self, equities=pd.DataFrame(), + futures=pd.DataFrame(), + exchanges=pd.DataFrame(), + root_symbols=pd.DataFrame()): AssetDBWriterFromDataFrame(equities, futures, exchanges, root_symbols)\ .write_all(self.engine)