BUG: Fix time spent checking equality of floating point numbers.

The use of np.allclose introduced a severe performance penalty,
caused by the creation of two `np.array`s for each check.

Instead create and use a similar check which maintains tolerance
to floating point rounding, but operates only on scalars.
This commit is contained in:
Eddie Hebert
2013-04-16 13:09:26 -04:00
parent 9f0500aa33
commit bf1fc42acc
5 changed files with 45 additions and 14 deletions
+8 -3
View File
@@ -64,6 +64,7 @@ from dateutil.relativedelta import relativedelta
import zipline.finance.trading as trading
from zipline.utils.date_utils import epoch_now
import zipline.utils.math_utils as zp_math
import pandas as pd
@@ -94,7 +95,7 @@ def sharpe_ratio(algorithm_volatility, algorithm_return, treasury_return):
Returns:
float. The Sharpe ratio.
"""
if np.allclose(algorithm_volatility, 0):
if zp_math.tolerant_equals(algorithm_volatility, 0):
return 0.0
return (algorithm_return - treasury_return) / algorithm_volatility
@@ -121,7 +122,7 @@ def sortino_ratio(algorithm_returns, algorithm_period_return, mar):
downside = (rets[rets < mar] - mar) ** 2
dr = np.sqrt(downside.sum() / len(rets))
if np.allclose(dr, 0):
if zp_math.tolerant_equals(dr, 0):
return 0.0
return (algorithm_period_return - mar) / dr
@@ -144,7 +145,11 @@ def information_ratio(algorithm_returns, benchmark_returns):
relative_deviation = relative_returns.std(ddof=1)
if np.allclose(relative_deviation, 0) or np.isnan(relative_deviation):
if (
zp_math.tolerant_equals(relative_deviation, 0)
or
np.isnan(relative_deviation)
):
return 0.0
return np.mean(relative_returns) / relative_deviation
+13 -6
View File
@@ -18,8 +18,7 @@ import math
from copy import copy
from functools import partial
from zipline.protocol import DATASOURCE_TYPE
import numpy as np
import zipline.utils.math_utils as zp_math
from logbook import Processor
@@ -71,7 +70,11 @@ def transact_stub(slippage, commission, event, open_orders):
transactions = slippage.simulate(event, open_orders)
for transaction in transactions:
if transaction and not np.allclose(transaction.amount, 0):
if (
transaction
and not
zp_math.tolerant_equals(transaction.amount, 0)
):
direction = math.copysign(1, transaction.amount)
per_share, total_commission = commission.calculate(transaction)
transaction.price = transaction.price + (per_share * direction)
@@ -138,7 +141,7 @@ class VolumeShareSlippage(object):
open_amount = order.amount - order.filled
if np.allclose(open_amount, 0):
if zp_math.tolerant_equals(open_amount, 0):
continue
# check price limits, continue if the
@@ -150,7 +153,11 @@ class VolumeShareSlippage(object):
# price impact accounts for the total volume of transactions
# created against the current minute bar
remaining_volume = max_volume - total_volume
if remaining_volume <= 0 or np.allclose(remaining_volume, 0):
if (
remaining_volume <= 0
or
zp_math.tolerant_equals(remaining_volume, 0)
):
# we can't fill any more transactions
return txns
@@ -209,7 +216,7 @@ class FixedSlippage(object):
if not order.triggered:
continue
if np.allclose(order.amount, 0):
if zp_math.tolerant_equals(order.amount, 0):
return txns
txn = create_transaction(
+2 -2
View File
@@ -14,7 +14,6 @@
# limitations under the License.
import itertools
import math
import numpy as np
import uuid
from copy import copy
@@ -32,6 +31,7 @@ from zipline.finance.slippage import (
check_order_triggers
)
from zipline.finance.commission import PerShare
import zipline.utils.math_utils as zp_math
log = Logger('Trade Simulation')
@@ -83,7 +83,7 @@ class Blotter(object):
yield date, results
def process_trade(self, trade_event):
if np.allclose(trade_event.volume, 0):
if zp_math.tolerant_equals(trade_event.volume, 0):
# there are zero volume trade_events bc some stocks trade
# less frequently than once per minute.
return []
+2 -3
View File
@@ -16,10 +16,9 @@
from collections import defaultdict
from math import sqrt
import numpy as np
from zipline.errors import WrongDataForTransform
from zipline.transforms.utils import EventWindow, TransformMeta
import zipline.utils.math_utils as zp_math
class MovingStandardDev(object):
@@ -118,7 +117,7 @@ class MovingStandardDevWindow(EventWindow):
s_squared = (self.sum_sqr - self.sum * average) \
/ (len(self) - 1)
if np.allclose(0, s_squared):
if zp_math.tolerant_equals(0, s_squared):
return 0.0
stddev = sqrt(s_squared)
return stddev
+20
View File
@@ -0,0 +1,20 @@
#
# Copyright 2013 Quantopian, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
def tolerant_equals(a, b, atol=10e-7, rtol=10e-7):
return math.fabs(a - b) <= (atol + rtol * math.fabs(b))