diff --git a/common/normalizer.py b/common/normalizer.py new file mode 100644 index 0000000..525fd92 --- /dev/null +++ b/common/normalizer.py @@ -0,0 +1,97 @@ +####################################################################### +# Copyright (C) 2017 Shangtong Zhang(zhangshangtong.cpp@gmail.com) # +# Permission given to modify the code as long as you keep this # +# declaration at the top # +####################################################################### +import torch +import numpy as np + +class Normalizer: + def __init__(self, o_size): + self.stats = SharedStats(o_size) + + def __call__(self, o_): + if np.isscalar(o_): + o = torch.FloatTensor([o_]) + else: + o = torch.FloatTensor(o_) + self.stats.feed(o) + std = (self.stats.v + 1e-6) ** .5 + o = (o - self.stats.m) / std + o = o.numpy() + if np.isscalar(o_): + o = np.asscalar(o) + else: + o = o.reshape(o_.shape) + return o + +class StaticNormalizer: + def __init__(self, o_size): + self.offline_stats = SharedStats(o_size) + self.online_stats = SharedStats(o_size) + + def __call__(self, o_): + if np.isscalar(o_): + o = torch.FloatTensor([o_]) + else: + o = torch.FloatTensor(o_) + self.online_stats.feed(o) + if self.offline_stats.n[0] == 0: + return o_ + std = (self.offline_stats.v + 1e-6) ** .5 + o = (o - self.offline_stats.m) / std + o = o.numpy() + if np.isscalar(o_): + o = np.asscalar(o) + else: + o = o.reshape(o_.shape) + return o + +class SharedStats: + def __init__(self, o_size): + self.m = torch.zeros(o_size) + self.v = torch.zeros(o_size) + self.n = torch.zeros(1) + self.m.share_memory_() + self.v.share_memory_() + self.n.share_memory_() + + def feed(self, o): + n = self.n[0] + new_m = self.m * (n / (n + 1)) + o / (n + 1) + self.v.copy_(self.v * (n / (n + 1)) + (o - self.m) * (o - new_m) / (n + 1)) + self.m.copy_(new_m) + self.n.add_(1) + + def zero(self): + self.m.zero_() + self.v.zero_() + self.n.zero_() + + def load(self, stats): + self.m.copy_(stats.m) + self.v.copy_(stats.v) + self.n.copy_(stats.n) + + def merge(self, B): + A = self + n_A = self.n[0] + n_B = B.n[0] + n = n_A + n_B + delta = B.m - A.m + m = A.m + delta * n_B / n + v = A.v * n_A + B.v * n_B + delta * delta * n_A * n_B / n + v /= n + self.m.copy_(m) + self.v.copy_(v) + self.n.add_(B.n) + + def state_dict(self): + return {'m': self.m.numpy(), + 'v': self.v.numpy(), + 'n': self.n.numpy()} + + def load_state_dict(self, saved): + self.m = torch.FloatTensor(saved['m']) + self.v = torch.FloatTensor(saved['v']) + self.n = torch.FloatTensor(saved['n']) diff --git a/ddpg/model.py b/ddpg/model.py index 40a6648..a48fdaa 100644 --- a/ddpg/model.py +++ b/ddpg/model.py @@ -6,47 +6,51 @@ import time import torch.nn as nn from pprint import pprint -from ddpg.nets import Actor, Critic, Base, ActorHead, CriticHead, DynamicsHead +from ddpg.nets import Base, ActorHead, CriticHead, DynamicsHead from common.torch_util import to_numpy, to_tensor, soft_update from common.misc_util import create_if_need, set_global_seeds from common.logger import Logger from common.buffers import create_buffer +from common.normalizer import StaticNormalizer, Normalizer from common.loss import create_loss, create_decay_fn from common.env_wrappers import create_env from common.random_process import create_random_process def create_model(args): - # TODO still using actor layers etc base = Base( - args.n_observation, args.n_action, args.actor_layers, - activation=args.actor_activation, - layer_norm=args.actor_layer_norm, - parameters_noise=args.actor_parameters_noise, - parameters_noise_factorised=args.actor_parameters_noise_factorised, - last_activation=nn.Tanh + args.n_observation, args.n_action, args.base_layers, + activation=args.base_activation, + layer_norm=args.base_layer_norm, + parameters_noise=args.base_parameters_noise, + parameters_noise_factorised=args.base_parameters_noise_factorised ) actor = ActorHead( base, - args.n_observation, args.n_action, args.actor_layers, - activation=args.actor_activation, - layer_norm=args.actor_layer_norm, - parameters_noise=args.actor_parameters_noise, - parameters_noise_factorised=args.actor_parameters_noise_factorised, + args.n_observation, args.n_action, + # args.actor_layers, + # activation=args.actor_activation, + # layer_norm=args.actor_layer_norm, + # parameters_noise=args.actor_parameters_noise, + # parameters_noise_factorised=args.actor_parameters_noise_factorised, last_activation=nn.Tanh) critic = CriticHead( base, - args.n_observation, args.n_action, args.critic_layers, - activation=args.critic_activation, - layer_norm=args.critic_layer_norm, - parameters_noise=args.critic_parameters_noise, - parameters_noise_factorised=args.critic_parameters_noise_factorised) + args.n_observation, args.n_action, + # args.critic_layers, + # activation=args.critic_activation, + # layer_norm=args.critic_layer_norm, + # parameters_noise=args.critic_parameters_noise, + # parameters_noise_factorised=args.critic_parameters_noise_factorised + ) dynamics = DynamicsHead( base, - args.n_observation, args.n_action, args.critic_layers, - activation=args.critic_activation, - layer_norm=args.critic_layer_norm, - parameters_noise=args.critic_parameters_noise, - parameters_noise_factorised=args.critic_parameters_noise_factorised) + args.n_observation, args.n_action, + # args.dynamics_layers, + # activation=args.dynamics_activation, + # layer_norm=args.dynamics_layer_norm, + # parameters_noise=args.dynamics_parameters_noise, + # parameters_noise_factorised=args.dynamics_parameters_noise_factorised + ) pprint(actor) pprint(critic) @@ -205,7 +209,7 @@ def create_act_update_fns(actor, critic, dynamics, target_actor, target_critic, return act_fn, update_fn, save_fn -def train_multi_thread(actor, critic, dynamics, target_actor, target_critic, target_dynamics, args, prepare_fn, best_reward): +def train_multi_thread(actor, critic, dynamics, target_actor, target_critic, target_dynamics, args, prepare_fn, best_reward, shared_state_normalizer, shared_reward_normalizer): workerseed = args.seed + 241 * args.thread set_global_seeds(workerseed) @@ -226,6 +230,9 @@ def train_multi_thread(actor, critic, dynamics, target_actor, target_critic, tar env = create_env(args) random_process = create_random_process(args) + state_normalizer = StaticNormalizer(env.state_dim) + reward_normalizer = StaticNormalizer(1) + actor_learning_rate_decay_fn = create_decay_fn( "linear", initial_value=args.actor_lr, @@ -275,6 +282,7 @@ def train_multi_thread(actor, critic, dynamics, target_actor, target_critic, tar env.seed(seed) observation = env.reset() #seed=seed, difficulty=args.difficulty) + observation = state_normalizer(observation) random_process.reset_states() done = False @@ -282,10 +290,14 @@ def train_multi_thread(actor, critic, dynamics, target_actor, target_critic, tar action = act_fn(observation, noise=epsilon*random_process.sample()) next_observation, reward, done, _ = env.step(action) - buffer.add(observation, action, reward, next_observation, done) episode_metrics["reward"] += reward episode_metrics["step"] += 1 + next_observation = state_normalizer(next_observation) + reward = reward_normalizer(reward) + + buffer.add(observation, action, reward, next_observation, done) + if len(buffer) >= args.train_steps: if args.prioritized_replay: @@ -314,8 +326,7 @@ def train_multi_thread(actor, critic, dynamics, target_actor, target_critic, tar episode += 1 - if episode_metrics["reward"] > 15.0 * args.reward_scale \ - and episode_metrics["reward"] > best_reward.value: + if episode_metrics["reward"] > best_reward.value: best_reward.value = episode_metrics["reward"] logger.scalar_summary("best reward", best_reward.value, episode) save_fn(episode) @@ -340,10 +351,18 @@ def train_multi_thread(actor, critic, dynamics, target_actor, target_critic, tar if episode % args.save_step == 0: save_fn(episode) + logger.info('episode %s, metrics %s', episode, episode_metrics) if elapsed_time > 86400 * args.max_train_days: episode = args.max_episodes + 1 + # Sync normalizers + shared_state_normalizer.offline_stats.merge(state_normalizer.online_stats) + state_normalizer.online_stats.zero() + + shared_reward_normalizer.offline_stats.merge(reward_normalizer.online_stats) + reward_normalizer.online_stats.zero() + save_fn(episode) raise KeyboardInterrupt @@ -460,7 +479,7 @@ def train_single_thread( def play_single_thread( actor, critic, dynamics, target_actor, target_critic, target_dynamics, args, prepare_fn, global_episode, global_update_step, episodes_queue, - best_reward): + best_reward, shared_state_normalizer, shared_reward_normalizer): workerseed = args.seed + 241 * args.thread set_global_seeds(workerseed) @@ -482,6 +501,10 @@ def play_single_thread( cycle_len=epsilon_cycle_len, num_cycles=args.max_episodes // epsilon_cycle_len) + env = create_env(args) + state_normalizer = StaticNormalizer(env.state_dim) + reward_normalizer = StaticNormalizer(1) + episode = 1 step = 0 start_time = time.time() @@ -501,6 +524,7 @@ def play_single_thread( env.seed(seed) observation = env.reset()#seed=seed, difficulty=args.difficulty) + observation = state_normalizer(observation) random_process.reset_states() done = False @@ -509,10 +533,15 @@ def play_single_thread( action = act_fn(observation, noise=epsilon * random_process.sample()) next_observation, reward, done, _ = env.step(action) - replay.append((observation, action, reward, next_observation, done)) + episode_metrics["reward"] += reward episode_metrics["step"] += 1 + next_observation = state_normalizer(next_observation) + reward = reward_normalizer(reward) + + replay.append((observation, action, reward, next_observation, done)) + observation = next_observation episodes_queue.put(replay) @@ -544,4 +573,11 @@ def play_single_thread( if elapsed_time > 86400 * args.max_train_days: global_episode.value = args.max_episodes * (args.num_threads - args.num_train_threads) + 1 + # Sync normalizers + shared_state_normalizer.offline_stats.merge(state_normalizer.online_stats) + state_normalizer.online_stats.zero() + + shared_reward_normalizer.offline_stats.merge(reward_normalizer.online_stats) + reward_normalizer.online_stats.zero() + raise KeyboardInterrupt diff --git a/ddpg/train.py b/ddpg/train.py index a4f0965..4669d8f 100644 --- a/ddpg/train.py +++ b/ddpg/train.py @@ -147,6 +147,8 @@ def train(args, model_fn, act_update_fns, multi_thread, train_single, play_singl json.dump(vars(args), fout, indent=4, ensure_ascii=False, sort_keys=True) env = create_env(args) + shared_state_normalizer = StaticNormalizer(env.state_dim) + shared_reward_normalizer = StaticNormalizer(1) if args.flip_state_action and hasattr(env, "state_transform"): args.flip_states = env.state_transform.flip_states @@ -202,14 +204,14 @@ def train(args, model_fn, act_update_fns, multi_thread, train_single, play_singl if args.debug: # # run a single thread in the foreground so we can debug easier args.thread = 1 - multi_thread(actor, critic, dynamics, target_actor, target_critic, target_dynamics, args, act_update_fns, best_reward) + multi_thread(actor, critic, dynamics, target_actor, target_critic, target_dynamics, args, act_update_fns, best_reward, shared_state_normalizer, shared_reward_normalizer) else: for rank in range(args.num_threads): args.thread = rank p = mp.Process( target=multi_thread, args=(actor, critic, dynamics, target_actor, target_critic, target_dynamics, args, act_update_fns, - best_reward)) + best_reward, shared_state_normalizer, shared_reward_normalizer)) p.start() processes.append(p) else: @@ -224,7 +226,7 @@ def train(args, model_fn, act_update_fns, multi_thread, train_single, play_singl train_single(actor, critic, dynamics, target_actor, target_critic, target_dynamics, args, act_update_fns, global_episode, global_update_step, episodes_queue) args.thread = 2 - play_single(actor, critic, dynamics, target_actor, target_critic, target_dynamics, args, act_update_fns, global_episode, global_update_step, episodes_queue, best_reward) + play_single(actor, critic, dynamics, target_actor, target_critic, target_dynamics, args, act_update_fns, global_episode, global_update_step, episodes_queue, best_reward, shared_state_normalizer, shared_reward_normalizer) else: global_episode = Value("i", 0) global_update_step = Value("i", 0) @@ -241,7 +243,7 @@ def train(args, model_fn, act_update_fns, multi_thread, train_single, play_singl target=play_single, args=(actor, critic, dynamics, target_actor, target_critic, target_dynamics, args, act_update_fns, global_episode, global_update_step, episodes_queue, - best_reward)) + best_reward, shared_state_normalizer, shared_reward_normalizer)) p.start() processes.append(p)