This commit is contained in:
wassname
2018-01-21 12:45:56 +08:00
parent 11cb4d9cce
commit f4c45b03db
2 changed files with 5 additions and 3 deletions
+2 -2
View File
@@ -216,7 +216,7 @@ def train_multi_thread(actor, critic, dynamics, target_actor, target_critic, tar
env = create_env(args)
random_process = create_random_process(args)
state_normalizer = StaticNormalizer(env.state_dim)
state_normalizer = StaticNormalizer(int(np.prod(env.observation_space.shape)))
reward_normalizer = StaticNormalizer(1)
actor_learning_rate_decay_fn = create_decay_fn(
@@ -488,7 +488,7 @@ def play_single_thread(
num_cycles=args.max_episodes // epsilon_cycle_len)
env = create_env(args)
state_normalizer = StaticNormalizer(env.state_dim)
state_normalizer = StaticNormalizer(int(np.prod(env.observation_space.shape)))
reward_normalizer = StaticNormalizer(1)
episode = 1
+3 -1
View File
@@ -5,10 +5,12 @@ import copy
import torch
import torch.multiprocessing as mp
from multiprocessing import Value
import numpy as np
from common.misc_util import boolean_flag, str2params, create_if_need
from common.env_wrappers import create_env
from common.torch_util import activations, hard_update
from common.normalizer import StaticNormalizer
from ddpg.model import create_model, create_act_update_fns, train_multi_thread, \
train_single_thread, play_single_thread
@@ -148,7 +150,7 @@ def train(args, model_fn, act_update_fns, multi_thread, train_single, play_singl
json.dump(vars(args), fout, indent=4, ensure_ascii=False, sort_keys=True)
env = create_env(args)
shared_state_normalizer = StaticNormalizer(env.state_dim)
shared_state_normalizer = StaticNormalizer(int(np.prod(env.observation_space.shape)))
shared_reward_normalizer = StaticNormalizer(1)
if args.flip_state_action and hasattr(env, "state_transform"):