concat other way round

This commit is contained in:
wassname
2017-11-13 07:48:25 +08:00
parent 3597f185cc
commit ea47fc42f7
2 changed files with 279 additions and 241 deletions
@@ -7,9 +7,16 @@ def concat_states(state):
history = state["history"]
weights = state["weights"]
weight_insert_shape = (history.shape[0], 1, history.shape[2])
weight_insert = np.ones(
weight_insert_shape) * weights[1:, np.newaxis, np.newaxis]
state = np.concatenate([history, weight_insert], axis=1)
if len(weights) - 1 == history.shape[0]:
weight_insert = np.ones(
weight_insert_shape) * weights[1:, np.newaxis, np.newaxis]
elif len(weights) - 1 == history.shape[2]:
weight_insert = np.ones(
weight_insert_shape) * weights[np.newaxis, np.newaxis, 1:]
else:
weight_insert = np.ones(
weight_insert_shape) * weights[np.newaxis, 1:, np.newaxis]
state = np.concatenate([weight_insert, history], axis=1)
return state
+269 -238
View File
File diff suppressed because one or more lines are too long