add share normalizer for state and reward

This commit is contained in:
wassname
2018-01-21 12:35:13 +08:00
parent cd949644d3
commit 6bb9c51403
3 changed files with 168 additions and 33 deletions
+97
View File
@@ -0,0 +1,97 @@
#######################################################################
# Copyright (C) 2017 Shangtong Zhang(zhangshangtong.cpp@gmail.com) #
# Permission given to modify the code as long as you keep this #
# declaration at the top #
#######################################################################
import torch
import numpy as np
class Normalizer:
def __init__(self, o_size):
self.stats = SharedStats(o_size)
def __call__(self, o_):
if np.isscalar(o_):
o = torch.FloatTensor([o_])
else:
o = torch.FloatTensor(o_)
self.stats.feed(o)
std = (self.stats.v + 1e-6) ** .5
o = (o - self.stats.m) / std
o = o.numpy()
if np.isscalar(o_):
o = np.asscalar(o)
else:
o = o.reshape(o_.shape)
return o
class StaticNormalizer:
def __init__(self, o_size):
self.offline_stats = SharedStats(o_size)
self.online_stats = SharedStats(o_size)
def __call__(self, o_):
if np.isscalar(o_):
o = torch.FloatTensor([o_])
else:
o = torch.FloatTensor(o_)
self.online_stats.feed(o)
if self.offline_stats.n[0] == 0:
return o_
std = (self.offline_stats.v + 1e-6) ** .5
o = (o - self.offline_stats.m) / std
o = o.numpy()
if np.isscalar(o_):
o = np.asscalar(o)
else:
o = o.reshape(o_.shape)
return o
class SharedStats:
def __init__(self, o_size):
self.m = torch.zeros(o_size)
self.v = torch.zeros(o_size)
self.n = torch.zeros(1)
self.m.share_memory_()
self.v.share_memory_()
self.n.share_memory_()
def feed(self, o):
n = self.n[0]
new_m = self.m * (n / (n + 1)) + o / (n + 1)
self.v.copy_(self.v * (n / (n + 1)) + (o - self.m) * (o - new_m) / (n + 1))
self.m.copy_(new_m)
self.n.add_(1)
def zero(self):
self.m.zero_()
self.v.zero_()
self.n.zero_()
def load(self, stats):
self.m.copy_(stats.m)
self.v.copy_(stats.v)
self.n.copy_(stats.n)
def merge(self, B):
A = self
n_A = self.n[0]
n_B = B.n[0]
n = n_A + n_B
delta = B.m - A.m
m = A.m + delta * n_B / n
v = A.v * n_A + B.v * n_B + delta * delta * n_A * n_B / n
v /= n
self.m.copy_(m)
self.v.copy_(v)
self.n.add_(B.n)
def state_dict(self):
return {'m': self.m.numpy(),
'v': self.v.numpy(),
'n': self.n.numpy()}
def load_state_dict(self, saved):
self.m = torch.FloatTensor(saved['m'])
self.v = torch.FloatTensor(saved['v'])
self.n = torch.FloatTensor(saved['n'])
+65 -29
View File
@@ -6,47 +6,51 @@ import time
import torch.nn as nn
from pprint import pprint
from ddpg.nets import Actor, Critic, Base, ActorHead, CriticHead, DynamicsHead
from ddpg.nets import Base, ActorHead, CriticHead, DynamicsHead
from common.torch_util import to_numpy, to_tensor, soft_update
from common.misc_util import create_if_need, set_global_seeds
from common.logger import Logger
from common.buffers import create_buffer
from common.normalizer import StaticNormalizer, Normalizer
from common.loss import create_loss, create_decay_fn
from common.env_wrappers import create_env
from common.random_process import create_random_process
def create_model(args):
# TODO still using actor layers etc
base = Base(
args.n_observation, args.n_action, args.actor_layers,
activation=args.actor_activation,
layer_norm=args.actor_layer_norm,
parameters_noise=args.actor_parameters_noise,
parameters_noise_factorised=args.actor_parameters_noise_factorised,
last_activation=nn.Tanh
args.n_observation, args.n_action, args.base_layers,
activation=args.base_activation,
layer_norm=args.base_layer_norm,
parameters_noise=args.base_parameters_noise,
parameters_noise_factorised=args.base_parameters_noise_factorised
)
actor = ActorHead(
base,
args.n_observation, args.n_action, args.actor_layers,
activation=args.actor_activation,
layer_norm=args.actor_layer_norm,
parameters_noise=args.actor_parameters_noise,
parameters_noise_factorised=args.actor_parameters_noise_factorised,
args.n_observation, args.n_action,
# args.actor_layers,
# activation=args.actor_activation,
# layer_norm=args.actor_layer_norm,
# parameters_noise=args.actor_parameters_noise,
# parameters_noise_factorised=args.actor_parameters_noise_factorised,
last_activation=nn.Tanh)
critic = CriticHead(
base,
args.n_observation, args.n_action, args.critic_layers,
activation=args.critic_activation,
layer_norm=args.critic_layer_norm,
parameters_noise=args.critic_parameters_noise,
parameters_noise_factorised=args.critic_parameters_noise_factorised)
args.n_observation, args.n_action,
# args.critic_layers,
# activation=args.critic_activation,
# layer_norm=args.critic_layer_norm,
# parameters_noise=args.critic_parameters_noise,
# parameters_noise_factorised=args.critic_parameters_noise_factorised
)
dynamics = DynamicsHead(
base,
args.n_observation, args.n_action, args.critic_layers,
activation=args.critic_activation,
layer_norm=args.critic_layer_norm,
parameters_noise=args.critic_parameters_noise,
parameters_noise_factorised=args.critic_parameters_noise_factorised)
args.n_observation, args.n_action,
# args.dynamics_layers,
# activation=args.dynamics_activation,
# layer_norm=args.dynamics_layer_norm,
# parameters_noise=args.dynamics_parameters_noise,
# parameters_noise_factorised=args.dynamics_parameters_noise_factorised
)
pprint(actor)
pprint(critic)
@@ -205,7 +209,7 @@ def create_act_update_fns(actor, critic, dynamics, target_actor, target_critic,
return act_fn, update_fn, save_fn
def train_multi_thread(actor, critic, dynamics, target_actor, target_critic, target_dynamics, args, prepare_fn, best_reward):
def train_multi_thread(actor, critic, dynamics, target_actor, target_critic, target_dynamics, args, prepare_fn, best_reward, shared_state_normalizer, shared_reward_normalizer):
workerseed = args.seed + 241 * args.thread
set_global_seeds(workerseed)
@@ -226,6 +230,9 @@ def train_multi_thread(actor, critic, dynamics, target_actor, target_critic, tar
env = create_env(args)
random_process = create_random_process(args)
state_normalizer = StaticNormalizer(env.state_dim)
reward_normalizer = StaticNormalizer(1)
actor_learning_rate_decay_fn = create_decay_fn(
"linear",
initial_value=args.actor_lr,
@@ -275,6 +282,7 @@ def train_multi_thread(actor, critic, dynamics, target_actor, target_critic, tar
env.seed(seed)
observation = env.reset() #seed=seed, difficulty=args.difficulty)
observation = state_normalizer(observation)
random_process.reset_states()
done = False
@@ -282,10 +290,14 @@ def train_multi_thread(actor, critic, dynamics, target_actor, target_critic, tar
action = act_fn(observation, noise=epsilon*random_process.sample())
next_observation, reward, done, _ = env.step(action)
buffer.add(observation, action, reward, next_observation, done)
episode_metrics["reward"] += reward
episode_metrics["step"] += 1
next_observation = state_normalizer(next_observation)
reward = reward_normalizer(reward)
buffer.add(observation, action, reward, next_observation, done)
if len(buffer) >= args.train_steps:
if args.prioritized_replay:
@@ -314,8 +326,7 @@ def train_multi_thread(actor, critic, dynamics, target_actor, target_critic, tar
episode += 1
if episode_metrics["reward"] > 15.0 * args.reward_scale \
and episode_metrics["reward"] > best_reward.value:
if episode_metrics["reward"] > best_reward.value:
best_reward.value = episode_metrics["reward"]
logger.scalar_summary("best reward", best_reward.value, episode)
save_fn(episode)
@@ -340,10 +351,18 @@ def train_multi_thread(actor, critic, dynamics, target_actor, target_critic, tar
if episode % args.save_step == 0:
save_fn(episode)
logger.info('episode %s, metrics %s', episode, episode_metrics)
if elapsed_time > 86400 * args.max_train_days:
episode = args.max_episodes + 1
# Sync normalizers
shared_state_normalizer.offline_stats.merge(state_normalizer.online_stats)
state_normalizer.online_stats.zero()
shared_reward_normalizer.offline_stats.merge(reward_normalizer.online_stats)
reward_normalizer.online_stats.zero()
save_fn(episode)
raise KeyboardInterrupt
@@ -460,7 +479,7 @@ def train_single_thread(
def play_single_thread(
actor, critic, dynamics, target_actor, target_critic, target_dynamics, args, prepare_fn,
global_episode, global_update_step, episodes_queue,
best_reward):
best_reward, shared_state_normalizer, shared_reward_normalizer):
workerseed = args.seed + 241 * args.thread
set_global_seeds(workerseed)
@@ -482,6 +501,10 @@ def play_single_thread(
cycle_len=epsilon_cycle_len,
num_cycles=args.max_episodes // epsilon_cycle_len)
env = create_env(args)
state_normalizer = StaticNormalizer(env.state_dim)
reward_normalizer = StaticNormalizer(1)
episode = 1
step = 0
start_time = time.time()
@@ -501,6 +524,7 @@ def play_single_thread(
env.seed(seed)
observation = env.reset()#seed=seed, difficulty=args.difficulty)
observation = state_normalizer(observation)
random_process.reset_states()
done = False
@@ -509,10 +533,15 @@ def play_single_thread(
action = act_fn(observation, noise=epsilon * random_process.sample())
next_observation, reward, done, _ = env.step(action)
replay.append((observation, action, reward, next_observation, done))
episode_metrics["reward"] += reward
episode_metrics["step"] += 1
next_observation = state_normalizer(next_observation)
reward = reward_normalizer(reward)
replay.append((observation, action, reward, next_observation, done))
observation = next_observation
episodes_queue.put(replay)
@@ -544,4 +573,11 @@ def play_single_thread(
if elapsed_time > 86400 * args.max_train_days:
global_episode.value = args.max_episodes * (args.num_threads - args.num_train_threads) + 1
# Sync normalizers
shared_state_normalizer.offline_stats.merge(state_normalizer.online_stats)
state_normalizer.online_stats.zero()
shared_reward_normalizer.offline_stats.merge(reward_normalizer.online_stats)
reward_normalizer.online_stats.zero()
raise KeyboardInterrupt
+6 -4
View File
@@ -147,6 +147,8 @@ def train(args, model_fn, act_update_fns, multi_thread, train_single, play_singl
json.dump(vars(args), fout, indent=4, ensure_ascii=False, sort_keys=True)
env = create_env(args)
shared_state_normalizer = StaticNormalizer(env.state_dim)
shared_reward_normalizer = StaticNormalizer(1)
if args.flip_state_action and hasattr(env, "state_transform"):
args.flip_states = env.state_transform.flip_states
@@ -202,14 +204,14 @@ def train(args, model_fn, act_update_fns, multi_thread, train_single, play_singl
if args.debug:
# # run a single thread in the foreground so we can debug easier
args.thread = 1
multi_thread(actor, critic, dynamics, target_actor, target_critic, target_dynamics, args, act_update_fns, best_reward)
multi_thread(actor, critic, dynamics, target_actor, target_critic, target_dynamics, args, act_update_fns, best_reward, shared_state_normalizer, shared_reward_normalizer)
else:
for rank in range(args.num_threads):
args.thread = rank
p = mp.Process(
target=multi_thread,
args=(actor, critic, dynamics, target_actor, target_critic, target_dynamics, args, act_update_fns,
best_reward))
best_reward, shared_state_normalizer, shared_reward_normalizer))
p.start()
processes.append(p)
else:
@@ -224,7 +226,7 @@ def train(args, model_fn, act_update_fns, multi_thread, train_single, play_singl
train_single(actor, critic, dynamics, target_actor, target_critic, target_dynamics, args, act_update_fns, global_episode, global_update_step, episodes_queue)
args.thread = 2
play_single(actor, critic, dynamics, target_actor, target_critic, target_dynamics, args, act_update_fns, global_episode, global_update_step, episodes_queue, best_reward)
play_single(actor, critic, dynamics, target_actor, target_critic, target_dynamics, args, act_update_fns, global_episode, global_update_step, episodes_queue, best_reward, shared_state_normalizer, shared_reward_normalizer)
else:
global_episode = Value("i", 0)
global_update_step = Value("i", 0)
@@ -241,7 +243,7 @@ def train(args, model_fn, act_update_fns, multi_thread, train_single, play_singl
target=play_single,
args=(actor, critic, dynamics, target_actor, target_critic, target_dynamics, args, act_update_fns,
global_episode, global_update_step, episodes_queue,
best_reward))
best_reward, shared_state_normalizer, shared_reward_normalizer))
p.start()
processes.append(p)