mirror of
https://github.com/wassname/catalyst.git
synced 2026-07-04 20:06:18 +08:00
BUG: Fix time spent checking equality of floating point numbers.
The use of np.allclose introduced a severe performance penalty, caused by the creation of two `np.array`s for each check. Instead create and use a similar check which maintains tolerance to floating point rounding, but operates only on scalars.
This commit is contained in:
@@ -64,6 +64,7 @@ from dateutil.relativedelta import relativedelta
|
||||
|
||||
import zipline.finance.trading as trading
|
||||
from zipline.utils.date_utils import epoch_now
|
||||
import zipline.utils.math_utils as zp_math
|
||||
|
||||
import pandas as pd
|
||||
|
||||
@@ -94,7 +95,7 @@ def sharpe_ratio(algorithm_volatility, algorithm_return, treasury_return):
|
||||
Returns:
|
||||
float. The Sharpe ratio.
|
||||
"""
|
||||
if np.allclose(algorithm_volatility, 0):
|
||||
if zp_math.tolerant_equals(algorithm_volatility, 0):
|
||||
return 0.0
|
||||
|
||||
return (algorithm_return - treasury_return) / algorithm_volatility
|
||||
@@ -121,7 +122,7 @@ def sortino_ratio(algorithm_returns, algorithm_period_return, mar):
|
||||
downside = (rets[rets < mar] - mar) ** 2
|
||||
dr = np.sqrt(downside.sum() / len(rets))
|
||||
|
||||
if np.allclose(dr, 0):
|
||||
if zp_math.tolerant_equals(dr, 0):
|
||||
return 0.0
|
||||
|
||||
return (algorithm_period_return - mar) / dr
|
||||
@@ -144,7 +145,11 @@ def information_ratio(algorithm_returns, benchmark_returns):
|
||||
|
||||
relative_deviation = relative_returns.std(ddof=1)
|
||||
|
||||
if np.allclose(relative_deviation, 0) or np.isnan(relative_deviation):
|
||||
if (
|
||||
zp_math.tolerant_equals(relative_deviation, 0)
|
||||
or
|
||||
np.isnan(relative_deviation)
|
||||
):
|
||||
return 0.0
|
||||
|
||||
return np.mean(relative_returns) / relative_deviation
|
||||
|
||||
@@ -18,8 +18,7 @@ import math
|
||||
from copy import copy
|
||||
from functools import partial
|
||||
from zipline.protocol import DATASOURCE_TYPE
|
||||
|
||||
import numpy as np
|
||||
import zipline.utils.math_utils as zp_math
|
||||
|
||||
from logbook import Processor
|
||||
|
||||
@@ -71,7 +70,11 @@ def transact_stub(slippage, commission, event, open_orders):
|
||||
transactions = slippage.simulate(event, open_orders)
|
||||
|
||||
for transaction in transactions:
|
||||
if transaction and not np.allclose(transaction.amount, 0):
|
||||
if (
|
||||
transaction
|
||||
and not
|
||||
zp_math.tolerant_equals(transaction.amount, 0)
|
||||
):
|
||||
direction = math.copysign(1, transaction.amount)
|
||||
per_share, total_commission = commission.calculate(transaction)
|
||||
transaction.price = transaction.price + (per_share * direction)
|
||||
@@ -138,7 +141,7 @@ class VolumeShareSlippage(object):
|
||||
|
||||
open_amount = order.amount - order.filled
|
||||
|
||||
if np.allclose(open_amount, 0):
|
||||
if zp_math.tolerant_equals(open_amount, 0):
|
||||
continue
|
||||
|
||||
# check price limits, continue if the
|
||||
@@ -150,7 +153,11 @@ class VolumeShareSlippage(object):
|
||||
# price impact accounts for the total volume of transactions
|
||||
# created against the current minute bar
|
||||
remaining_volume = max_volume - total_volume
|
||||
if remaining_volume <= 0 or np.allclose(remaining_volume, 0):
|
||||
if (
|
||||
remaining_volume <= 0
|
||||
or
|
||||
zp_math.tolerant_equals(remaining_volume, 0)
|
||||
):
|
||||
# we can't fill any more transactions
|
||||
return txns
|
||||
|
||||
@@ -209,7 +216,7 @@ class FixedSlippage(object):
|
||||
if not order.triggered:
|
||||
continue
|
||||
|
||||
if np.allclose(order.amount, 0):
|
||||
if zp_math.tolerant_equals(order.amount, 0):
|
||||
return txns
|
||||
|
||||
txn = create_transaction(
|
||||
|
||||
@@ -14,7 +14,6 @@
|
||||
# limitations under the License.
|
||||
import itertools
|
||||
import math
|
||||
import numpy as np
|
||||
import uuid
|
||||
|
||||
from copy import copy
|
||||
@@ -32,6 +31,7 @@ from zipline.finance.slippage import (
|
||||
check_order_triggers
|
||||
)
|
||||
from zipline.finance.commission import PerShare
|
||||
import zipline.utils.math_utils as zp_math
|
||||
|
||||
log = Logger('Trade Simulation')
|
||||
|
||||
@@ -83,7 +83,7 @@ class Blotter(object):
|
||||
yield date, results
|
||||
|
||||
def process_trade(self, trade_event):
|
||||
if np.allclose(trade_event.volume, 0):
|
||||
if zp_math.tolerant_equals(trade_event.volume, 0):
|
||||
# there are zero volume trade_events bc some stocks trade
|
||||
# less frequently than once per minute.
|
||||
return []
|
||||
|
||||
@@ -16,10 +16,9 @@
|
||||
from collections import defaultdict
|
||||
from math import sqrt
|
||||
|
||||
import numpy as np
|
||||
|
||||
from zipline.errors import WrongDataForTransform
|
||||
from zipline.transforms.utils import EventWindow, TransformMeta
|
||||
import zipline.utils.math_utils as zp_math
|
||||
|
||||
|
||||
class MovingStandardDev(object):
|
||||
@@ -118,7 +117,7 @@ class MovingStandardDevWindow(EventWindow):
|
||||
s_squared = (self.sum_sqr - self.sum * average) \
|
||||
/ (len(self) - 1)
|
||||
|
||||
if np.allclose(0, s_squared):
|
||||
if zp_math.tolerant_equals(0, s_squared):
|
||||
return 0.0
|
||||
stddev = sqrt(s_squared)
|
||||
return stddev
|
||||
|
||||
@@ -0,0 +1,20 @@
|
||||
#
|
||||
# Copyright 2013 Quantopian, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
|
||||
|
||||
def tolerant_equals(a, b, atol=10e-7, rtol=10e-7):
|
||||
return math.fabs(a - b) <= (atol + rtol * math.fabs(b))
|
||||
Reference in New Issue
Block a user