From 97e980751f7d7bf6fee8f4a0f5ca7a5c83e7b0ef Mon Sep 17 00:00:00 2001 From: Stewart Douglas Date: Tue, 4 Aug 2015 17:54:36 -0400 Subject: [PATCH] MAINT: Integrate asset writer changes into TradingEnvironment --- tests/test_assets.py | 3 +-- zipline/assets/assets.py | 1 - zipline/finance/trading.py | 25 ++++++++++++++++++++----- 3 files changed, 21 insertions(+), 8 deletions(-) diff --git a/tests/test_assets.py b/tests/test_assets.py index 3bfb2470..fb5de11f 100644 --- a/tests/test_assets.py +++ b/tests/test_assets.py @@ -88,8 +88,7 @@ def build_lookup_generic_cases(): }, ], index='sid') - from nose.tools import set_trace; set_trace() - db_path = '/Users/stewart/temp.db' + db_path = '~/temp.db' conn = sqlite3.connect(db_path) asset_writer = AssetDBWriterFromDataFrame(equities=frame) asset_writer.write_all(conn) diff --git a/zipline/assets/assets.py b/zipline/assets/assets.py index e1e286ab..05c8bcb4 100644 --- a/zipline/assets/assets.py +++ b/zipline/assets/assets.py @@ -14,7 +14,6 @@ from abc import ABCMeta from numbers import Integral -# import sqlite3 from sqlite3 import Row import warnings diff --git a/zipline/finance/trading.py b/zipline/finance/trading.py index 0df434a5..cda614ed 100644 --- a/zipline/finance/trading.py +++ b/zipline/finance/trading.py @@ -17,6 +17,7 @@ import bisect import logbook import datetime from functools import wraps +import sqlite3 import pandas as pd import numpy as np @@ -24,6 +25,10 @@ import numpy as np from zipline.data.loader import load_market_data from zipline.utils import tradingcalendar from zipline.assets import AssetFinder +from zipline.assets.asset_writer import ( + AssetDBWriterFromList, + AssetDBWriterFromDictionary, + AssetDBWriterFromDataFrame) from zipline.errors import ( NoFurtherDataError, UpdateAssetFinderTypeError, @@ -132,7 +137,10 @@ class TradingEnvironment(object): self.exchange_tz = exchange_tz - self.asset_finder = AssetFinder() + self.conn = sqlite3.connect(':memory:') + asset_writer = AssetDBWriterFromDictionary() + asset_writer.write_all(self.conn) + self.asset_finder = AssetFinder(self.conn) def __enter__(self, *args, **kwargs): global environment @@ -174,7 +182,7 @@ class TradingEnvironment(object): :return: """ if clear_metadata: - self.asset_finder.clear_metadata() + self.conn = sqlite3.connect(':memory:') if asset_finder is not None: if not isinstance(asset_finder, AssetFinder): @@ -182,11 +190,18 @@ class TradingEnvironment(object): self.asset_finder = asset_finder if asset_metadata is not None: - self.asset_finder.clear_metadata() - self.asset_finder.consume_metadata(asset_metadata) + self.conn = sqlite3.connect(':memory:') + if isinstance(asset_metadata, dict): + asset_writer = AssetDBWriterFromDictionary( + equities=asset_metadata) + elif isinstance(asset_metadata, pd.DataFrame): + asset_writer = AssetDBWriterFromDataFrame( + equities=asset_metadata) + asset_writer.write_all(self.conn) if identifiers is not None: - self.asset_finder.consume_identifiers(identifiers) + asset_writer = AssetDBWriterFromList(equities=identifiers) + asset_writer.write_all(self.conn) def normalize_date(self, test_date): test_date = pd.Timestamp(test_date, tz='UTC')