mirror of
https://github.com/wassname/Run-Skeleton-Run.git
synced 2026-06-27 19:45:48 +08:00
add shared normalizer for state and reward
This commit is contained in:
@@ -0,0 +1,97 @@
|
||||
#######################################################################
|
||||
# Copyright (C) 2017 Shangtong Zhang(zhangshangtong.cpp@gmail.com) #
|
||||
# Permission given to modify the code as long as you keep this #
|
||||
# declaration at the top #
|
||||
#######################################################################
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
class Normalizer:
    """Normalise inputs with running statistics that include the current input.

    Every call first feeds the input into ``self.stats`` and then
    standardises it with the updated running mean and variance.
    """

    def __init__(self, o_size):
        """Create the running statistics.

        Args:
            o_size: Size passed to ``SharedStats`` for the mean/variance
                tensors.
        """
        self.stats = SharedStats(o_size)

    def __call__(self, o_):
        """Update the statistics with ``o_`` and return it normalised.

        Args:
            o_: A scalar (as judged by ``np.isscalar``) or anything
                ``torch.FloatTensor`` accepts (ndarray, list, tuple).

        Returns:
            A Python scalar when ``o_`` was a scalar, otherwise a NumPy
            array with the same shape as ``o_``.
        """
        is_scalar = np.isscalar(o_)
        if is_scalar:
            o = torch.FloatTensor([o_])
        else:
            o = torch.FloatTensor(o_)
        self.stats.feed(o)
        # The small epsilon keeps the division finite while variance is zero.
        std = (self.stats.v + 1e-6) ** .5
        o = ((o - self.stats.m) / std).numpy()
        if is_scalar:
            # ndarray.item() replaces np.asscalar, removed in NumPy 1.23.
            return o.item()
        # np.shape also handles lists/tuples, which have no .shape attribute.
        return o.reshape(np.shape(o_))
|
||||
|
||||
class StaticNormalizer:
    """Normalise inputs with fixed statistics while collecting new ones.

    ``online_stats`` is updated on every call. Normalisation itself uses
    ``offline_stats``, which this class never updates on its own; until
    ``offline_stats`` holds at least one sample the input is returned
    unchanged.
    """

    def __init__(self, o_size):
        """Create the offline (applied) and online (collected) statistics.

        Args:
            o_size: Size passed to ``SharedStats`` for both sets of
                statistics.
        """
        self.offline_stats = SharedStats(o_size)
        self.online_stats = SharedStats(o_size)

    def __call__(self, o_):
        """Record ``o_`` in ``online_stats`` and normalise it.

        Args:
            o_: A scalar (as judged by ``np.isscalar``) or anything
                ``torch.FloatTensor`` accepts (ndarray, list, tuple).

        Returns:
            ``o_`` itself when ``offline_stats`` is empty. Otherwise a
            Python scalar when ``o_`` was a scalar, or a NumPy array with
            the same shape as ``o_``.
        """
        is_scalar = np.isscalar(o_)
        if is_scalar:
            o = torch.FloatTensor([o_])
        else:
            o = torch.FloatTensor(o_)
        self.online_stats.feed(o)
        if self.offline_stats.n[0] == 0:
            # Nothing to normalise with yet: pass the input through untouched.
            return o_
        # The small epsilon keeps the division finite while variance is zero.
        std = (self.offline_stats.v + 1e-6) ** .5
        o = ((o - self.offline_stats.m) / std).numpy()
        if is_scalar:
            # ndarray.item() replaces np.asscalar, removed in NumPy 1.23.
            return o.item()
        # np.shape also handles lists/tuples, which have no .shape attribute.
        return o.reshape(np.shape(o_))
|
||||
|
||||
class SharedStats:
    """Running mean, variance and sample count held in shared-memory tensors.

    ``m`` is the running mean, ``v`` the running (population) variance and
    ``n`` a one-element tensor holding the sample count. All three are moved
    to shared memory with ``share_memory_()``. No locking is performed by
    this class.
    """

    def __init__(self, o_size):
        """Allocate zeroed statistics.

        Args:
            o_size: Size of the mean and variance tensors.
        """
        self.m = torch.zeros(o_size)
        self.v = torch.zeros(o_size)
        self.n = torch.zeros(1)
        self.m.share_memory_()
        self.v.share_memory_()
        self.n.share_memory_()

    def feed(self, o):
        """Fold one sample into the running mean and variance, in place.

        Args:
            o: Tensor compatible with ``self.m`` for elementwise arithmetic.
        """
        n = self.n[0]
        new_m = self.m * (n / (n + 1)) + o / (n + 1)
        # Incremental variance update: uses both the old and the new mean.
        self.v.copy_(self.v * (n / (n + 1)) + (o - self.m) * (o - new_m) / (n + 1))
        self.m.copy_(new_m)
        self.n.add_(1)

    def zero(self):
        """Reset mean, variance and count to zero, in place."""
        self.m.zero_()
        self.v.zero_()
        self.n.zero_()

    def load(self, stats):
        """Copy the statistics of another ``SharedStats`` into this one.

        Args:
            stats: Object exposing ``m``, ``v`` and ``n`` tensors.
        """
        self.m.copy_(stats.m)
        self.v.copy_(stats.v)
        self.n.copy_(stats.n)

    def merge(self, B):
        """Combine the statistics of ``B`` into this object, in place.

        The result is the mean/variance of the union of both sample sets.
        ``B`` is not modified.

        Args:
            B: Another ``SharedStats`` with matching tensor sizes.
        """
        n_B = B.n[0]
        if float(n_B) == 0:
            # Nothing to merge. Without this guard two empty sets would
            # divide 0 by 0 and fill the statistics with NaN.
            return
        A = self
        n_A = self.n[0]
        n = n_A + n_B
        delta = B.m - A.m
        m = A.m + delta * n_B / n
        v = A.v * n_A + B.v * n_B + delta * delta * n_A * n_B / n
        v /= n
        self.m.copy_(m)
        self.v.copy_(v)
        self.n.add_(B.n)

    def state_dict(self):
        """Return a snapshot of the statistics as NumPy arrays.

        The arrays are copies, so later calls to ``feed``/``zero``/``merge``
        do not change a previously returned snapshot.

        Returns:
            Dict with keys ``'m'``, ``'v'`` and ``'n'``.
        """
        return {'m': self.m.numpy().copy(),
                'v': self.v.numpy().copy(),
                'n': self.n.numpy().copy()}

    def load_state_dict(self, saved):
        """Restore statistics from a dict produced by ``state_dict``.

        Values are copied into the existing tensors so their storage (and
        any sharing of it) is preserved. If a saved array has a different
        shape, the attribute is replaced by a new tensor that is moved to
        shared memory.

        Args:
            saved: Mapping with keys ``'m'``, ``'v'`` and ``'n'``.
        """
        for key in ('m', 'v', 'n'):
            loaded = torch.FloatTensor(saved[key])
            current = getattr(self, key)
            if current.shape == loaded.shape:
                current.copy_(loaded)
            else:
                replacement = loaded.clone()
                replacement.share_memory_()
                setattr(self, key, replacement)
|
||||
+65
-29
@@ -6,47 +6,51 @@ import time
|
||||
import torch.nn as nn
|
||||
from pprint import pprint
|
||||
|
||||
from ddpg.nets import Actor, Critic, Base, ActorHead, CriticHead, DynamicsHead
|
||||
from ddpg.nets import Base, ActorHead, CriticHead, DynamicsHead
|
||||
from common.torch_util import to_numpy, to_tensor, soft_update
|
||||
from common.misc_util import create_if_need, set_global_seeds
|
||||
from common.logger import Logger
|
||||
from common.buffers import create_buffer
|
||||
from common.normalizer import StaticNormalizer, Normalizer
|
||||
from common.loss import create_loss, create_decay_fn
|
||||
from common.env_wrappers import create_env
|
||||
from common.random_process import create_random_process
|
||||
|
||||
def create_model(args):
|
||||
# TODO still using actor layers etc
|
||||
base = Base(
|
||||
args.n_observation, args.n_action, args.actor_layers,
|
||||
activation=args.actor_activation,
|
||||
layer_norm=args.actor_layer_norm,
|
||||
parameters_noise=args.actor_parameters_noise,
|
||||
parameters_noise_factorised=args.actor_parameters_noise_factorised,
|
||||
last_activation=nn.Tanh
|
||||
args.n_observation, args.n_action, args.base_layers,
|
||||
activation=args.base_activation,
|
||||
layer_norm=args.base_layer_norm,
|
||||
parameters_noise=args.base_parameters_noise,
|
||||
parameters_noise_factorised=args.base_parameters_noise_factorised
|
||||
)
|
||||
actor = ActorHead(
|
||||
base,
|
||||
args.n_observation, args.n_action, args.actor_layers,
|
||||
activation=args.actor_activation,
|
||||
layer_norm=args.actor_layer_norm,
|
||||
parameters_noise=args.actor_parameters_noise,
|
||||
parameters_noise_factorised=args.actor_parameters_noise_factorised,
|
||||
args.n_observation, args.n_action,
|
||||
# args.actor_layers,
|
||||
# activation=args.actor_activation,
|
||||
# layer_norm=args.actor_layer_norm,
|
||||
# parameters_noise=args.actor_parameters_noise,
|
||||
# parameters_noise_factorised=args.actor_parameters_noise_factorised,
|
||||
last_activation=nn.Tanh)
|
||||
critic = CriticHead(
|
||||
base,
|
||||
args.n_observation, args.n_action, args.critic_layers,
|
||||
activation=args.critic_activation,
|
||||
layer_norm=args.critic_layer_norm,
|
||||
parameters_noise=args.critic_parameters_noise,
|
||||
parameters_noise_factorised=args.critic_parameters_noise_factorised)
|
||||
args.n_observation, args.n_action,
|
||||
# args.critic_layers,
|
||||
# activation=args.critic_activation,
|
||||
# layer_norm=args.critic_layer_norm,
|
||||
# parameters_noise=args.critic_parameters_noise,
|
||||
# parameters_noise_factorised=args.critic_parameters_noise_factorised
|
||||
)
|
||||
dynamics = DynamicsHead(
|
||||
base,
|
||||
args.n_observation, args.n_action, args.critic_layers,
|
||||
activation=args.critic_activation,
|
||||
layer_norm=args.critic_layer_norm,
|
||||
parameters_noise=args.critic_parameters_noise,
|
||||
parameters_noise_factorised=args.critic_parameters_noise_factorised)
|
||||
args.n_observation, args.n_action,
|
||||
# args.dynamics_layers,
|
||||
# activation=args.dynamics_activation,
|
||||
# layer_norm=args.dynamics_layer_norm,
|
||||
# parameters_noise=args.dynamics_parameters_noise,
|
||||
# parameters_noise_factorised=args.dynamics_parameters_noise_factorised
|
||||
)
|
||||
|
||||
pprint(actor)
|
||||
pprint(critic)
|
||||
@@ -205,7 +209,7 @@ def create_act_update_fns(actor, critic, dynamics, target_actor, target_critic,
|
||||
return act_fn, update_fn, save_fn
|
||||
|
||||
|
||||
def train_multi_thread(actor, critic, dynamics, target_actor, target_critic, target_dynamics, args, prepare_fn, best_reward):
|
||||
def train_multi_thread(actor, critic, dynamics, target_actor, target_critic, target_dynamics, args, prepare_fn, best_reward, shared_state_normalizer, shared_reward_normalizer):
|
||||
workerseed = args.seed + 241 * args.thread
|
||||
set_global_seeds(workerseed)
|
||||
|
||||
@@ -226,6 +230,9 @@ def train_multi_thread(actor, critic, dynamics, target_actor, target_critic, tar
|
||||
env = create_env(args)
|
||||
random_process = create_random_process(args)
|
||||
|
||||
state_normalizer = StaticNormalizer(env.state_dim)
|
||||
reward_normalizer = StaticNormalizer(1)
|
||||
|
||||
actor_learning_rate_decay_fn = create_decay_fn(
|
||||
"linear",
|
||||
initial_value=args.actor_lr,
|
||||
@@ -275,6 +282,7 @@ def train_multi_thread(actor, critic, dynamics, target_actor, target_critic, tar
|
||||
|
||||
env.seed(seed)
|
||||
observation = env.reset() #seed=seed, difficulty=args.difficulty)
|
||||
observation = state_normalizer(observation)
|
||||
random_process.reset_states()
|
||||
done = False
|
||||
|
||||
@@ -282,10 +290,14 @@ def train_multi_thread(actor, critic, dynamics, target_actor, target_critic, tar
|
||||
action = act_fn(observation, noise=epsilon*random_process.sample())
|
||||
next_observation, reward, done, _ = env.step(action)
|
||||
|
||||
buffer.add(observation, action, reward, next_observation, done)
|
||||
episode_metrics["reward"] += reward
|
||||
episode_metrics["step"] += 1
|
||||
|
||||
next_observation = state_normalizer(next_observation)
|
||||
reward = reward_normalizer(reward)
|
||||
|
||||
buffer.add(observation, action, reward, next_observation, done)
|
||||
|
||||
if len(buffer) >= args.train_steps:
|
||||
|
||||
if args.prioritized_replay:
|
||||
@@ -314,8 +326,7 @@ def train_multi_thread(actor, critic, dynamics, target_actor, target_critic, tar
|
||||
|
||||
episode += 1
|
||||
|
||||
if episode_metrics["reward"] > 15.0 * args.reward_scale \
|
||||
and episode_metrics["reward"] > best_reward.value:
|
||||
if episode_metrics["reward"] > best_reward.value:
|
||||
best_reward.value = episode_metrics["reward"]
|
||||
logger.scalar_summary("best reward", best_reward.value, episode)
|
||||
save_fn(episode)
|
||||
@@ -340,10 +351,18 @@ def train_multi_thread(actor, critic, dynamics, target_actor, target_critic, tar
|
||||
|
||||
if episode % args.save_step == 0:
|
||||
save_fn(episode)
|
||||
logger.info('episode %s, metrics %s', episode, episode_metrics)
|
||||
|
||||
if elapsed_time > 86400 * args.max_train_days:
|
||||
episode = args.max_episodes + 1
|
||||
|
||||
# Sync normalizers
|
||||
shared_state_normalizer.offline_stats.merge(state_normalizer.online_stats)
|
||||
state_normalizer.online_stats.zero()
|
||||
|
||||
shared_reward_normalizer.offline_stats.merge(reward_normalizer.online_stats)
|
||||
reward_normalizer.online_stats.zero()
|
||||
|
||||
save_fn(episode)
|
||||
|
||||
raise KeyboardInterrupt
|
||||
@@ -460,7 +479,7 @@ def train_single_thread(
|
||||
def play_single_thread(
|
||||
actor, critic, dynamics, target_actor, target_critic, target_dynamics, args, prepare_fn,
|
||||
global_episode, global_update_step, episodes_queue,
|
||||
best_reward):
|
||||
best_reward, shared_state_normalizer, shared_reward_normalizer):
|
||||
workerseed = args.seed + 241 * args.thread
|
||||
set_global_seeds(workerseed)
|
||||
|
||||
@@ -482,6 +501,10 @@ def play_single_thread(
|
||||
cycle_len=epsilon_cycle_len,
|
||||
num_cycles=args.max_episodes // epsilon_cycle_len)
|
||||
|
||||
env = create_env(args)
|
||||
state_normalizer = StaticNormalizer(env.state_dim)
|
||||
reward_normalizer = StaticNormalizer(1)
|
||||
|
||||
episode = 1
|
||||
step = 0
|
||||
start_time = time.time()
|
||||
@@ -501,6 +524,7 @@ def play_single_thread(
|
||||
|
||||
env.seed(seed)
|
||||
observation = env.reset()#seed=seed, difficulty=args.difficulty)
|
||||
observation = state_normalizer(observation)
|
||||
random_process.reset_states()
|
||||
done = False
|
||||
|
||||
@@ -509,10 +533,15 @@ def play_single_thread(
|
||||
action = act_fn(observation, noise=epsilon * random_process.sample())
|
||||
next_observation, reward, done, _ = env.step(action)
|
||||
|
||||
replay.append((observation, action, reward, next_observation, done))
|
||||
|
||||
episode_metrics["reward"] += reward
|
||||
episode_metrics["step"] += 1
|
||||
|
||||
next_observation = state_normalizer(next_observation)
|
||||
reward = reward_normalizer(reward)
|
||||
|
||||
replay.append((observation, action, reward, next_observation, done))
|
||||
|
||||
observation = next_observation
|
||||
|
||||
episodes_queue.put(replay)
|
||||
@@ -544,4 +573,11 @@ def play_single_thread(
|
||||
if elapsed_time > 86400 * args.max_train_days:
|
||||
global_episode.value = args.max_episodes * (args.num_threads - args.num_train_threads) + 1
|
||||
|
||||
# Sync normalizers
|
||||
shared_state_normalizer.offline_stats.merge(state_normalizer.online_stats)
|
||||
state_normalizer.online_stats.zero()
|
||||
|
||||
shared_reward_normalizer.offline_stats.merge(reward_normalizer.online_stats)
|
||||
reward_normalizer.online_stats.zero()
|
||||
|
||||
raise KeyboardInterrupt
|
||||
|
||||
+6
-4
@@ -147,6 +147,8 @@ def train(args, model_fn, act_update_fns, multi_thread, train_single, play_singl
|
||||
json.dump(vars(args), fout, indent=4, ensure_ascii=False, sort_keys=True)
|
||||
|
||||
env = create_env(args)
|
||||
shared_state_normalizer = StaticNormalizer(env.state_dim)
|
||||
shared_reward_normalizer = StaticNormalizer(1)
|
||||
|
||||
if args.flip_state_action and hasattr(env, "state_transform"):
|
||||
args.flip_states = env.state_transform.flip_states
|
||||
@@ -202,14 +204,14 @@ def train(args, model_fn, act_update_fns, multi_thread, train_single, play_singl
|
||||
if args.debug:
|
||||
# # run a single thread in the foreground so we can debug easier
|
||||
args.thread = 1
|
||||
multi_thread(actor, critic, dynamics, target_actor, target_critic, target_dynamics, args, act_update_fns, best_reward)
|
||||
multi_thread(actor, critic, dynamics, target_actor, target_critic, target_dynamics, args, act_update_fns, best_reward, shared_state_normalizer, shared_reward_normalizer)
|
||||
else:
|
||||
for rank in range(args.num_threads):
|
||||
args.thread = rank
|
||||
p = mp.Process(
|
||||
target=multi_thread,
|
||||
args=(actor, critic, dynamics, target_actor, target_critic, target_dynamics, args, act_update_fns,
|
||||
best_reward))
|
||||
best_reward, shared_state_normalizer, shared_reward_normalizer))
|
||||
p.start()
|
||||
processes.append(p)
|
||||
else:
|
||||
@@ -224,7 +226,7 @@ def train(args, model_fn, act_update_fns, multi_thread, train_single, play_singl
|
||||
train_single(actor, critic, dynamics, target_actor, target_critic, target_dynamics, args, act_update_fns, global_episode, global_update_step, episodes_queue)
|
||||
|
||||
args.thread = 2
|
||||
play_single(actor, critic, dynamics, target_actor, target_critic, target_dynamics, args, act_update_fns, global_episode, global_update_step, episodes_queue, best_reward)
|
||||
play_single(actor, critic, dynamics, target_actor, target_critic, target_dynamics, args, act_update_fns, global_episode, global_update_step, episodes_queue, best_reward, shared_state_normalizer, shared_reward_normalizer)
|
||||
else:
|
||||
global_episode = Value("i", 0)
|
||||
global_update_step = Value("i", 0)
|
||||
@@ -241,7 +243,7 @@ def train(args, model_fn, act_update_fns, multi_thread, train_single, play_singl
|
||||
target=play_single,
|
||||
args=(actor, critic, dynamics, target_actor, target_critic, target_dynamics, args, act_update_fns,
|
||||
global_episode, global_update_step, episodes_queue,
|
||||
best_reward))
|
||||
best_reward, shared_state_normalizer, shared_reward_normalizer))
|
||||
p.start()
|
||||
processes.append(p)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user