fix softmax

This commit is contained in:
wassname
2017-11-15 10:44:44 +08:00
parent 9ea8f0d570
commit d054c79bd6
2 changed files with 7 additions and 1 deletions
+1 -1
View File
@@ -17,7 +17,7 @@ def MDD(returns):
def softmax(w, t=1.0):
"""softmax implemented in numpy."""
log_eps = np.log(eps)
w = np.clip(w, -log_eps, log_eps) # avoid inf/nan
w = np.clip(w, log_eps, -log_eps) # avoid inf/nan
e = np.exp(np.array(w) / t)
dist = e / np.sum(e)
return dist
+6
View File
@@ -10,6 +10,12 @@ def test_softmax():
y = softmax(x)
np.testing.assert_almost_equal(y.sum(), 1)
x = np.random.random([0, 1])
y = softmax(x)
assert y[0] < y[1]
assert (y > 0).all()
assert (y < 1).all()
def test_maxdrawdown():
assert MDD(np.array([0, 0, 0, 0, 1, 2, 3])) == 0