diff --git a/ddpg/model.py b/ddpg/model.py index 331ee6a..4be5649 100644 --- a/ddpg/model.py +++ b/ddpg/model.py @@ -216,7 +216,7 @@ def train_multi_thread(actor, critic, dynamics, target_actor, target_critic, tar env = create_env(args) random_process = create_random_process(args) - state_normalizer = StaticNormalizer(env.state_dim) + state_normalizer = StaticNormalizer(int(np.prod(env.observation_space.shape))) reward_normalizer = StaticNormalizer(1) actor_learning_rate_decay_fn = create_decay_fn( @@ -488,7 +488,7 @@ def play_single_thread( num_cycles=args.max_episodes // epsilon_cycle_len) env = create_env(args) - state_normalizer = StaticNormalizer(env.state_dim) + state_normalizer = StaticNormalizer(int(np.prod(env.observation_space.shape))) reward_normalizer = StaticNormalizer(1) episode = 1 diff --git a/ddpg/train.py b/ddpg/train.py index dfcbe35..bb80b88 100644 --- a/ddpg/train.py +++ b/ddpg/train.py @@ -5,10 +5,12 @@ import copy import torch import torch.multiprocessing as mp from multiprocessing import Value +import numpy as np from common.misc_util import boolean_flag, str2params, create_if_need from common.env_wrappers import create_env from common.torch_util import activations, hard_update +from common.normalizer import StaticNormalizer from ddpg.model import create_model, create_act_update_fns, train_multi_thread, \ train_single_thread, play_single_thread @@ -148,7 +150,7 @@ def train(args, model_fn, act_update_fns, multi_thread, train_single, play_singl json.dump(vars(args), fout, indent=4, ensure_ascii=False, sort_keys=True) env = create_env(args) - shared_state_normalizer = StaticNormalizer(env.state_dim) + shared_state_normalizer = StaticNormalizer(int(np.prod(env.observation_space.shape))) shared_reward_normalizer = StaticNormalizer(1) if args.flip_state_action and hasattr(env, "state_transform"):