mirror of
https://github.com/wassname/rl-portfolio-management.git
synced 2026-06-27 16:46:41 +08:00
fix softmax
This commit is contained in:
@@ -17,7 +17,7 @@ def MDD(returns):
|
||||
def softmax(w, t=1.0):
|
||||
"""softmax implemented in numpy."""
|
||||
log_eps = np.log(eps)
|
||||
w = np.clip(w, -log_eps, log_eps) # avoid inf/nan
|
||||
w = np.clip(w, log_eps, -log_eps) # avoid inf/nan
|
||||
e = np.exp(np.array(w) / t)
|
||||
dist = e / np.sum(e)
|
||||
return dist
|
||||
|
||||
@@ -10,6 +10,12 @@ def test_softmax():
|
||||
y = softmax(x)
|
||||
np.testing.assert_almost_equal(y.sum(), 1)
|
||||
|
||||
x = np.random.random([0, 1])
|
||||
y = softmax(x)
|
||||
assert y[0] < y[1]
|
||||
assert (y > 0).all()
|
||||
assert (y < 1).all()
|
||||
|
||||
|
||||
def test_maxdrawdown():
|
||||
assert MDD(np.array([0, 0, 0, 0, 1, 2, 3])) == 0
|
||||
|
||||
Reference in New Issue
Block a user