mirror of
https://github.com/wassname/catalyst.git
synced 2026-06-27 17:47:56 +08:00
Merge pull request #1788 from quantopian/rounding-cutoff-2
Do not explicitly round asset prices
This commit is contained in:
+128
-22
@@ -19,41 +19,64 @@ from numpy.testing import assert_almost_equal
|
||||
import pandas as pd
|
||||
from pandas.tslib import Timedelta
|
||||
|
||||
from zipline.assets import Equity
|
||||
from zipline.assets import Equity, Future
|
||||
from zipline.data.data_portal import HISTORY_FREQUENCIES, OHLCV_FIELDS
|
||||
from zipline.data.minute_bars import (
|
||||
FUTURES_MINUTES_PER_DAY,
|
||||
US_EQUITIES_MINUTES_PER_DAY,
|
||||
)
|
||||
from zipline.testing import parameter_space
|
||||
from zipline.testing.fixtures import (
|
||||
ZiplineTestCase,
|
||||
WithTradingSessions,
|
||||
WithDataPortal,
|
||||
alias,
|
||||
)
|
||||
from zipline.testing.predicates import assert_equal
|
||||
from zipline.utils.numpy_utils import float64_dtype
|
||||
|
||||
|
||||
class DataPortalTestBase(WithDataPortal,
|
||||
WithTradingSessions,
|
||||
ZiplineTestCase):
|
||||
|
||||
ASSET_FINDER_EQUITY_SIDS = (1,)
|
||||
ASSET_FINDER_EQUITY_SIDS = (1, 2)
|
||||
START_DATE = pd.Timestamp('2016-08-01')
|
||||
END_DATE = pd.Timestamp('2016-08-08')
|
||||
|
||||
TRADING_CALENDAR_STRS = ('NYSE', 'CME')
|
||||
TRADING_CALENDAR_STRS = ('NYSE', 'us_futures')
|
||||
|
||||
EQUITY_DAILY_BAR_SOURCE_FROM_MINUTE = True
|
||||
|
||||
# Since the future with sid 10001 has a tick size of 0.0001, its prices
|
||||
# should be rounded out to 4 decimal places. To test that this rounding
|
||||
# occurs correctly, store its prices out to 5 decimal places by using a
|
||||
# multiplier of 100,000 when writing its values.
|
||||
OHLC_RATIOS_PER_SID = {10001: 100000}
|
||||
|
||||
@classmethod
def make_root_symbols_info(cls):
    """
    Build the root-symbol table used by this fixture.

    Returns
    -------
    pd.DataFrame
        One row per root symbol ('BAR' and 'BUZ') with the columns
        'root_symbol', 'root_symbol_id' and 'exchange'.
    """
    # The first argument of a classmethod is the class itself, so it is
    # named ``cls`` (it was previously spelled ``self``).
    return pd.DataFrame({
        'root_symbol': ['BAR', 'BUZ'],
        'root_symbol_id': [1, 2],
        'exchange': ['CME', 'CME'],
    })
|
||||
|
||||
@classmethod
|
||||
def make_futures_info(cls):
|
||||
trading_sessions = cls.trading_sessions['CME']
|
||||
trading_sessions = cls.trading_sessions['us_futures']
|
||||
return pd.DataFrame({
|
||||
'sid': [10000],
|
||||
'root_symbol': ['BAR'],
|
||||
'symbol': ['BARA'],
|
||||
'start_date': [trading_sessions[1]],
|
||||
'end_date': [cls.END_DATE],
|
||||
'sid': [10000, 10001],
|
||||
'root_symbol': ['BAR', 'BUZ'],
|
||||
'symbol': ['BARA', 'BUZZ'],
|
||||
'start_date': [trading_sessions[1], trading_sessions[0]],
|
||||
'end_date': [cls.END_DATE, cls.END_DATE],
|
||||
# TODO: Make separate from 'end_date'
|
||||
'notice_date': [cls.END_DATE],
|
||||
'expiration_date': [cls.END_DATE],
|
||||
'multiplier': [500],
|
||||
'exchange': ['CME'],
|
||||
'notice_date': [cls.END_DATE, cls.END_DATE],
|
||||
'expiration_date': [cls.END_DATE, cls.END_DATE],
|
||||
'tick_size': [0.01, 0.0001],
|
||||
'multiplier': [500, 50000],
|
||||
'exchange': ['CME', 'CME'],
|
||||
})
|
||||
|
||||
@classmethod
|
||||
@@ -102,13 +125,25 @@ class DataPortalTestBase(WithDataPortal,
|
||||
'volume': full(len(dts), 0),
|
||||
},
|
||||
index=dts))
|
||||
yield 1, pd.concat(dfs)
|
||||
asset1_df = pd.concat(dfs)
|
||||
yield 1, asset1_df
|
||||
|
||||
asset2_df = pd.DataFrame(
|
||||
{
|
||||
'open': 1.0055,
|
||||
'high': 1.0059,
|
||||
'low': 1.0051,
|
||||
'close': 1.0055,
|
||||
'volume': 100,
|
||||
},
|
||||
index=asset1_df.index,
|
||||
)
|
||||
yield 2, asset2_df
|
||||
|
||||
@classmethod
|
||||
def make_future_minute_bar_data(cls):
|
||||
asset = cls.asset_finder.retrieve_asset(10000)
|
||||
trading_calendar = cls.trading_calendars[asset.exchange]
|
||||
trading_sessions = cls.trading_sessions[asset.exchange]
|
||||
trading_calendar = cls.trading_calendars[Future]
|
||||
trading_sessions = cls.trading_sessions['us_futures']
|
||||
# No data on first day, future asset intentionally not on the same
|
||||
# dates as equities, so that cross-wiring of results do not create a
|
||||
# false positive.
|
||||
@@ -154,7 +189,21 @@ class DataPortalTestBase(WithDataPortal,
|
||||
'volume': full(len(dts), 0),
|
||||
},
|
||||
index=dts))
|
||||
yield asset.sid, pd.concat(dfs)
|
||||
asset10000_df = pd.concat(dfs)
|
||||
yield 10000, asset10000_df
|
||||
|
||||
missing_dts = trading_calendar.minutes_for_session(trading_sessions[0])
|
||||
asset10001_df = pd.DataFrame(
|
||||
{
|
||||
'open': 1.00549,
|
||||
'high': 1.00591,
|
||||
'low': 1.00507,
|
||||
'close': 1.0055,
|
||||
'volume': 100,
|
||||
},
|
||||
index=missing_dts.append(asset10000_df.index),
|
||||
)
|
||||
yield 10001, asset10001_df
|
||||
|
||||
def test_get_last_traded_equity_minute(self):
|
||||
trading_calendar = self.trading_calendars[Equity]
|
||||
@@ -180,7 +229,7 @@ class DataPortalTestBase(WithDataPortal,
|
||||
|
||||
def test_get_last_traded_future_minute(self):
|
||||
asset = self.asset_finder.retrieve_asset(10000)
|
||||
trading_calendar = self.trading_calendars[asset.exchange]
|
||||
trading_calendar = self.trading_calendars[Future]
|
||||
# Case: Missing data at front of data set, and request dt is before
|
||||
# first value.
|
||||
dts = trading_calendar.minutes_for_session(self.trading_days[0])
|
||||
@@ -258,7 +307,7 @@ class DataPortalTestBase(WithDataPortal,
|
||||
assert_almost_equal(array(list(expected.values())), result)
|
||||
|
||||
def test_get_spot_value_future_minute(self):
|
||||
trading_calendar = self.trading_calendars['CME']
|
||||
trading_calendar = self.trading_calendars[Future]
|
||||
asset = self.asset_finder.retrieve_asset(10000)
|
||||
dts = trading_calendar.minutes_for_session(self.trading_days[3])
|
||||
|
||||
@@ -299,7 +348,7 @@ class DataPortalTestBase(WithDataPortal,
|
||||
def test_get_spot_value_multiple_assets(self):
|
||||
equity = self.asset_finder.retrieve_asset(1)
|
||||
future = self.asset_finder.retrieve_asset(10000)
|
||||
trading_calendar = self.trading_calendars['CME']
|
||||
trading_calendar = self.trading_calendars[Future]
|
||||
dts = trading_calendar.minutes_for_session(self.trading_days[3])
|
||||
|
||||
# We expect the outputs to be lists of spot values.
|
||||
@@ -384,7 +433,7 @@ class DataPortalTestBase(WithDataPortal,
|
||||
"return that as the last trade on the fifth.")
|
||||
|
||||
future = self.asset_finder.retrieve_asset(10000)
|
||||
calendar = self.trading_calendars[future.exchange]
|
||||
calendar = self.trading_calendars[Future]
|
||||
minutes = calendar.minutes_for_session(self.trading_days[3])
|
||||
result = self.data_portal.get_last_traded_dt(future,
|
||||
minutes[3],
|
||||
@@ -405,6 +454,63 @@ class DataPortalTestBase(WithDataPortal,
|
||||
splits = self.data_portal.get_splits([], self.trading_days[2])
|
||||
self.assertEqual([], splits)
|
||||
|
||||
@parameter_space(frequency=HISTORY_FREQUENCIES, field=OHLCV_FIELDS)
def test_price_rounding(self, frequency, field):
    """
    Request a one-bar history window for an equity, a future and a
    continuous future, and compare it against the expected prices for
    every combination of history frequency and OHLCV field.
    """
    equity = self.asset_finder.retrieve_asset(2)
    future = self.asset_finder.retrieve_asset(10001)
    cf = self.data_portal.asset_finder.create_continuous_future(
        'BUZ', 0, 'calendar', None,
    )
    first_minute = self.nyse_calendar.minutes_for_session(
        self.trading_days[0],
    )[0]

    if frequency == '1m':
        # A single minute bar: volume is that of one bar.
        minute = first_minute
        data_frequency = 'minute'
        equity_bar_count = 1
        future_bar_count = 1
    else:
        # A daily bar: volume is summed over every minute of the session.
        minute = first_minute.normalize()
        data_frequency = 'daily'
        equity_bar_count = US_EQUITIES_MINUTES_PER_DAY
        future_bar_count = FUTURES_MINUTES_PER_DAY

    # Equity prices should be floored to three decimal places.
    expected_equity_values = dict.fromkeys(
        ('open', 'high', 'low', 'close'), 1.005,
    )
    expected_equity_values['volume'] = 100 * equity_bar_count

    # Futures prices should be rounded to four decimal places.
    expected_future_values = dict(
        open=1.0055,
        high=1.0059,
        low=1.0051,
        close=1.0055,
        volume=100 * future_bar_count,
    )

    result = self.data_portal.get_history_window(
        assets=[equity, future, cf],
        end_dt=minute,
        bar_count=1,
        frequency=frequency,
        field=field,
        data_frequency=data_frequency,
    )

    equity_expected = expected_equity_values[field]
    future_expected = expected_future_values[field]
    expected_result = pd.DataFrame(
        {
            equity: equity_expected,
            future: future_expected,
            cf: future_expected,
        },
        index=[minute],
        dtype=float64_dtype,
    )

    assert_equal(result, expected_result)
|
||||
|
||||
|
||||
class TestDataPortal(DataPortalTestBase):
|
||||
DATA_PORTAL_LAST_AVAILABLE_SESSION = None
|
||||
|
||||
@@ -0,0 +1,13 @@
|
||||
|
||||
from unittest import TestCase
|
||||
|
||||
from zipline.utils.math_utils import number_of_decimal_places
|
||||
|
||||
|
||||
class MathUtilsTestCase(TestCase):
    """Unit tests for helpers imported from zipline.utils.math_utils."""

    def test_number_of_decimal_places(self):
        """Check an int, a float, a numeric string and a negative float."""
        cases = (
            (1, 0),
            (3.14, 2),
            ('3.14', 2),
            (-3.14, 2),
        )
        for value, expected_places in cases:
            self.assertEqual(number_of_decimal_places(value), expected_places)
|
||||
@@ -26,16 +26,20 @@ from toolz import sliding_window
|
||||
|
||||
from six import with_metaclass
|
||||
|
||||
from zipline.assets import Equity
|
||||
from zipline.assets import Equity, Future
|
||||
from zipline.assets.continuous_futures import ContinuousFuture
|
||||
from zipline.lib._int64window import AdjustedArrayWindow as Int64Window
|
||||
from zipline.lib._float64window import AdjustedArrayWindow as Float64Window
|
||||
from zipline.lib.adjustment import Float64Multiply, Float64Add
|
||||
from zipline.utils.cache import ExpiringCache
|
||||
from zipline.utils.math_utils import number_of_decimal_places
|
||||
from zipline.utils.memoize import lazyval
|
||||
from zipline.utils.numpy_utils import float64_dtype
|
||||
from zipline.utils.pandas_utils import find_in_sorted_index
|
||||
|
||||
# Default number of decimal places used for rounding asset prices.
|
||||
DEFAULT_ASSET_PRICE_DECIMALS = 3
|
||||
|
||||
|
||||
class HistoryCompatibleUSEquityAdjustmentReader(object):
|
||||
|
||||
@@ -343,6 +347,21 @@ class HistoryLoader(with_metaclass(ABCMeta)):
|
||||
def _array(self, start, end, assets, field):
|
||||
pass
|
||||
|
||||
def _decimal_places_for_asset(self, asset, reference_date):
    """
    Return the number of decimal places to use for ``asset``'s prices.

    A Future with a truthy ``tick_size`` uses the decimal places of that
    tick size. A ContinuousFuture uses the tick size of the contract
    returned by ``contract_before_auto_close`` for ``reference_date``,
    when there is one and its tick size is truthy. Every other case
    falls back to DEFAULT_ASSET_PRICE_DECIMALS.
    """
    tick_size = None

    if isinstance(asset, Future) and asset.tick_size:
        tick_size = asset.tick_size
    elif isinstance(asset, ContinuousFuture):
        # Tick size should be the same for all contracts of a continuous
        # future, so arbitrarily get the contract with next upcoming auto
        # close date.
        ordered = self._asset_finder.get_ordered_contracts(
            asset.root_symbol,
        )
        sid = ordered.contract_before_auto_close(reference_date.value)
        if sid is not None:
            contract_tick = self._asset_finder.retrieve_asset(sid).tick_size
            if contract_tick:
                tick_size = contract_tick

    if tick_size is None:
        return DEFAULT_ASSET_PRICE_DECIMALS
    return number_of_decimal_places(tick_size)
|
||||
|
||||
def _ensure_sliding_windows(self, assets, dts, field,
|
||||
is_perspective_after):
|
||||
"""
|
||||
@@ -438,7 +457,8 @@ class HistoryLoader(with_metaclass(ABCMeta)):
|
||||
adjs,
|
||||
offset,
|
||||
size,
|
||||
int(is_perspective_after)
|
||||
int(is_perspective_after),
|
||||
self._decimal_places_for_asset(asset, dts[-1]),
|
||||
)
|
||||
sliding_window = SlidingWindow(window, size, start_ix, offset)
|
||||
asset_windows[asset] = sliding_window
|
||||
@@ -533,7 +553,7 @@ class HistoryLoader(with_metaclass(ABCMeta)):
|
||||
return concatenate(
|
||||
[window.get(end_ix) for window in block],
|
||||
axis=1,
|
||||
).round(3)
|
||||
)
|
||||
|
||||
|
||||
class DailyHistoryLoader(HistoryLoader):
|
||||
|
||||
@@ -12,7 +12,7 @@ zipline.lib._intwindow
|
||||
zipline.lib._datewindow
|
||||
"""
|
||||
from numpy cimport ndarray
|
||||
from numpy import asanyarray
|
||||
from numpy import asanyarray, dtype, issubdtype
|
||||
|
||||
|
||||
class Exhausted(Exception):
|
||||
@@ -32,6 +32,10 @@ cdef class AdjustedArrayWindow:
|
||||
|
||||
The arrays yielded by this iterator are always views over the underlying
|
||||
data.
|
||||
|
||||
The `rounding_places` attribute is an integer used to specify the number of
|
||||
decimal places to which the data should be rounded, given that the data is
|
||||
of dtype float. If `rounding_places` is None, no rounding occurs.
|
||||
"""
|
||||
cdef:
|
||||
# ctype must be defined by the file into which this is being copied.
|
||||
@@ -40,6 +44,7 @@ cdef class AdjustedArrayWindow:
|
||||
readonly Py_ssize_t window_length
|
||||
Py_ssize_t anchor, max_anchor, next_adj
|
||||
Py_ssize_t perspective_offset
|
||||
object rounding_places
|
||||
dict adjustments
|
||||
list adjustment_indices
|
||||
ndarray output
|
||||
@@ -50,7 +55,8 @@ cdef class AdjustedArrayWindow:
|
||||
dict adjustments not None,
|
||||
Py_ssize_t offset,
|
||||
Py_ssize_t window_length,
|
||||
Py_ssize_t perspective_offset):
|
||||
Py_ssize_t perspective_offset,
|
||||
object rounding_places):
|
||||
self.data = data
|
||||
self.view_kwargs = view_kwargs
|
||||
self.adjustments = adjustments
|
||||
@@ -67,6 +73,7 @@ cdef class AdjustedArrayWindow:
|
||||
"is perspective_offset={0}".format(
|
||||
perspective_offset))
|
||||
self.perspective_offset = perspective_offset
|
||||
self.rounding_places = rounding_places
|
||||
self.max_anchor = data.shape[0]
|
||||
|
||||
self.next_adj = self.pop_next_adj()
|
||||
@@ -138,6 +145,9 @@ cdef class AdjustedArrayWindow:
|
||||
new_out = asanyarray(self.data[anchor - self.window_length:anchor])
|
||||
if view_kwargs:
|
||||
new_out = new_out.view(**view_kwargs)
|
||||
if self.rounding_places is not None and \
|
||||
issubdtype(new_out.dtype, dtype('float64')):
|
||||
new_out = new_out.round(self.rounding_places)
|
||||
new_out.setflags(write=False)
|
||||
self.output = new_out
|
||||
|
||||
|
||||
@@ -227,6 +227,7 @@ class AdjustedArray(object):
|
||||
offset,
|
||||
window_length,
|
||||
perspective_offset,
|
||||
rounding_places=None,
|
||||
)
|
||||
|
||||
def inspect(self):
|
||||
|
||||
@@ -1136,6 +1136,7 @@ class WithBcolzFutureMinuteBarReader(WithFutureMinuteBarData, WithTmpDir):
|
||||
zipline.testing.create_minute_bar_data
|
||||
"""
|
||||
BCOLZ_FUTURE_MINUTE_BAR_PATH = 'minute_future_pricing'
|
||||
OHLC_RATIOS_PER_SID = None
|
||||
|
||||
@classmethod
|
||||
def make_bcolz_future_minute_bar_rootdir_path(cls):
|
||||
@@ -1155,6 +1156,7 @@ class WithBcolzFutureMinuteBarReader(WithFutureMinuteBarData, WithTmpDir):
|
||||
days[0],
|
||||
days[-1],
|
||||
FUTURES_MINUTES_PER_DAY,
|
||||
ohlc_ratios_per_sid=cls.OHLC_RATIOS_PER_SID,
|
||||
)
|
||||
writer.write(cls.make_future_minute_bar_data())
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from decimal import Decimal
|
||||
import math
|
||||
|
||||
from numpy import isnan
|
||||
@@ -77,3 +78,20 @@ def round_if_near_integer(a, epsilon=1e-4):
|
||||
return round(a)
|
||||
else:
|
||||
return a
|
||||
|
||||
|
||||
def number_of_decimal_places(n):
    """
    Compute the number of decimal places in a number.

    The count is read from the exponent of ``Decimal(str(n))``, so the
    result reflects the textual form of ``n``.

    Parameters
    ----------
    n : int, float, or str
        The number to inspect.

    Returns
    -------
    int
        The number of digits after the decimal point; never negative.

    Examples
    --------
    >>> number_of_decimal_places(1)
    0
    >>> number_of_decimal_places(3.14)
    2
    >>> number_of_decimal_places('3.14')
    2
    >>> number_of_decimal_places(1e16)
    0
    """
    exponent = Decimal(str(n)).as_tuple().exponent
    # Scientific notation with a positive exponent (e.g. str(1e16) is
    # '1e+16') produces a positive Decimal exponent. Such a value has no
    # fractional digits, so clamp at zero rather than return a negative
    # count.
    return max(0, -exponent)
|
||||
|
||||
Reference in New Issue
Block a user