mirror of
https://github.com/wassname/rl-portfolio-management.git
synced 2026-06-27 16:46:41 +08:00
concat other way round
This commit is contained in:
@@ -7,9 +7,16 @@ def concat_states(state):
|
||||
history = state["history"]
|
||||
weights = state["weights"]
|
||||
weight_insert_shape = (history.shape[0], 1, history.shape[2])
|
||||
weight_insert = np.ones(
|
||||
weight_insert_shape) * weights[1:, np.newaxis, np.newaxis]
|
||||
state = np.concatenate([history, weight_insert], axis=1)
|
||||
if len(weights) - 1 == history.shape[0]:
|
||||
weight_insert = np.ones(
|
||||
weight_insert_shape) * weights[1:, np.newaxis, np.newaxis]
|
||||
elif len(weights) - 1 == history.shape[2]:
|
||||
weight_insert = np.ones(
|
||||
weight_insert_shape) * weights[np.newaxis, np.newaxis, 1:]
|
||||
else:
|
||||
weight_insert = np.ones(
|
||||
weight_insert_shape) * weights[np.newaxis, 1:, np.newaxis]
|
||||
state = np.concatenate([weight_insert, history], axis=1)
|
||||
return state
|
||||
|
||||
|
||||
|
||||
+269
-238
File diff suppressed because one or more lines are too long
Reference in New Issue
Block a user