This commit is contained in:
wassname
2017-11-12 12:19:06 +08:00
parent 99eb4190cb
commit fc1ad88e79
2 changed files with 20 additions and 11 deletions
@@ -46,7 +46,8 @@ class DataSrc(object):
# dataframe to matrix
self.asset_names = df.columns.levels[0].tolist()
self.features = df.columns.levels[1].tolist()
self._data = df.as_matrix().reshape((len(self.asset_names), -1, len(self.features)))
self._data = df.as_matrix().reshape(
(len(self.asset_names), -1, len(self.features)))
self._times = df.index
self.price_columns = ['close', 'high', 'low', 'open']
@@ -67,7 +68,7 @@ class DataSrc(object):
def _step(self):
# get history matrix from dataframe
data_window = self.data[:, self.step:self.step +
self.window_length].copy()
self.window_length].copy()
# (eq 18) prices are divided by open price
# While the paper says open/close, it only makes sense with close/open
@@ -75,7 +76,8 @@ class DataSrc(object):
if self.scale:
# scale prices by dividing price columns by the last open price
last_open_price = data_window[:, -1, 0]
data_window[:, :, :nb_pc] /= last_open_price[:, np.newaxis, np.newaxis]
data_window[:, :, :nb_pc] /= last_open_price[:,
np.newaxis, np.newaxis]
if self.scale_extra_cols:
# normalize non price columns
@@ -104,7 +106,7 @@ class DataSrc(object):
data = self._data[:, self.idx -
self.window_length:self.idx + self.steps + 1].copy()
self.times = self._times[self.idx -
self.window_length:self.idx + self.steps + 1]
self.window_length:self.idx + self.steps + 1]
# augment data to prevent overfitting
data += np.random.normal(loc=0, scale=self.augment, size=data.shape)
+14 -7
View File
@@ -64,7 +64,8 @@ def test_portfolio_env_random_agent(spec_id):
df_info = pd.DataFrame(env.infos)
final_value = df_info.portfolio_value.iloc[-1]
market_value = df_info.market_value.iloc[-1]
np.testing.assert_allclose(final_value, market_value, rtol=0.1, err_msg='should be similar to market values after 20 random steps')
np.testing.assert_allclose(final_value, market_value, rtol=0.1,
err_msg='should be similar to market values after 20 random steps')
@pytest.mark.parametrize("spec_id", env_specs)
@@ -98,8 +99,10 @@ def test_scaled_non_price_cols():
# if normalized: for a large window, mean non_prices should be near mean=0, std=1
non_price_std = stds[3:]
np.testing.assert_almost_equal(non_price_means, [0, 0], decimal=1, err_msg='non price columns should be normalized to be close to one')
np.testing.assert_allclose(non_price_std, [1, 1], rtol=0.1, err_msg='non price columns should be normalized to be close to one')
np.testing.assert_almost_equal(non_price_means, [
0, 0], decimal=1, err_msg='non price columns should be normalized to be close to one')
np.testing.assert_allclose(non_price_std, [
1, 1], rtol=0.1, err_msg='non price columns should be normalized to be close to one')
def test_scaled():
@@ -115,10 +118,12 @@ def test_scaled():
obs1 = env1.reset()
nb_price_cols = len(env1.src.price_columns) - 1
assert (obs0["history"][:, :, :nb_price_cols] != obs1["history"][:, :, :nb_price_cols]).all(), 'scaled and non-scaled data should differ'
assert (obs0["history"][:, :, :nb_price_cols] != obs1["history"][:, :,
:nb_price_cols]).all(), 'scaled and non-scaled data should differ'
# if scaled by last opening price: for a small window, mean prices should be near 1
np.testing.assert_allclose(obs1["history"][:, -1, :nb_price_cols], 1, rtol=0.1, err_msg='last prices should be normalized to be close to one')
np.testing.assert_allclose(obs1["history"][:, -1, :nb_price_cols], 1,
rtol=0.1, err_msg='last prices should be normalized to be close to one')
@pytest.mark.parametrize("spec_id", env_specs)
@@ -161,7 +166,9 @@ def test_costs(spec_id):
env.reset()
obs, reward, done, info = env.step(np.array([1, 0, 0, 0]))
obs, reward, done, info = env.step(np.array([0, 1, 0, 0]))
np.testing.assert_almost_equal(info['cost'], env.sim.cost, err_msg='trading 100% cash for asset1 should cost 1*trading_cost')
np.testing.assert_almost_equal(
info['cost'], env.sim.cost, err_msg='trading 100% cash for asset1 should cost 1*trading_cost')
obs, reward, done, info = env.step(np.array([0, 0, 1, 0]))
np.testing.assert_almost_equal(info['cost'], env.sim.cost * 2, err_msg='trading 100% asset1 for asset2 should cost 2*trading_cost')
np.testing.assert_almost_equal(
info['cost'], env.sim.cost * 2, err_msg='trading 100% asset1 for asset2 should cost 2*trading_cost')