mirror of
https://github.com/wassname/rl-portfolio-management.git
synced 2026-06-27 16:46:41 +08:00
more tests
This commit is contained in:
+7
-1
@@ -1,10 +1,16 @@
|
||||
from rl_portfolio_management.data.utils import random_shift, normalize, scale_to_start
|
||||
from rl_portfolio_management.util import sharpe, MDD
|
||||
from rl_portfolio_management.util import sharpe, MDD, softmax
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
|
||||
|
||||
def test_softmax():
|
||||
x = np.random.random((20, 20))
|
||||
y = softmax(x)
|
||||
np.testing.assert_almost_equal(y.sum(), 1)
|
||||
|
||||
|
||||
def test_maxdrawdown():
|
||||
assert MDD(np.array([0, 0, 0, 0, 1, 2, 3])) == 0
|
||||
assert MDD(np.array([0, 0, 0, 0, 1, 2, 1])) == -1
|
||||
|
||||
@@ -0,0 +1,24 @@
|
||||
import gym
|
||||
|
||||
from rl_portfolio_management.environments import PortfolioEnv, env_specs
|
||||
from rl_portfolio_management.wrappers import ConcatStates, SoftmaxActions
|
||||
|
||||
|
||||
def test_concat():
|
||||
env = gym.make("CryptoPortfolioEIIE-v0")
|
||||
env = ConcatStates(env)
|
||||
obs = env.reset()
|
||||
assert len(obs.shape) == 3
|
||||
action = env.action_space.sample()
|
||||
action /= action.sum()
|
||||
obs, rew, done, info = env.step(action)
|
||||
assert len(obs.shape) == 3
|
||||
|
||||
|
||||
def test_softmax():
|
||||
env = gym.make("CryptoPortfolioEIIE-v0")
|
||||
env = SoftmaxActions(env)
|
||||
obs = env.reset()
|
||||
# should be no problem with actions that don't sum to one
|
||||
action = env.action_space.sample() * 100
|
||||
obs, rew, done, info = env.step(action)
|
||||
Reference in New Issue
Block a user