Files
Run-Skeleton-Run/ddpg/model.py
wassname f4c45b03db fix norm
2018-01-21 12:45:56 +08:00

570 lines
20 KiB
Python

import random
import numpy as np
import torch
import queue as py_queue
import time
import torch.nn as nn
from pprint import pprint
from ddpg.nets import Base, ActorHead, CriticHead, DynamicsHead
from common.torch_util import to_numpy, to_tensor, soft_update
from common.misc_util import create_if_need, set_global_seeds
from common.logger import Logger
from common.buffers import create_buffer
from common.normalizer import StaticNormalizer, Normalizer
from common.loss import create_loss, create_decay_fn
from common.env_wrappers import create_env
from common.random_process import create_random_process
def create_model(args):
    """Build the actor, critic and dynamics heads on top of one shared base.

    A single ``Base`` instance is constructed from the ``args.base_*``
    settings and handed to all three heads, so the heads share that object.
    Each head is pretty-printed before returning.

    Args:
        args: configuration namespace; reads ``n_observation``, ``n_action``,
            ``base_layers``, ``base_activation``, ``base_layer_norm``,
            ``base_parameters_noise``, ``base_parameters_noise_factorised``
            and ``action_scale``.

    Returns:
        Tuple ``(actor, critic, dynamics)``.
    """
    n_obs, n_act = args.n_observation, args.n_action

    shared_base = Base(
        n_obs, n_act, args.base_layers,
        activation=args.base_activation,
        layer_norm=args.base_layer_norm,
        parameters_noise=args.base_parameters_noise,
        parameters_noise_factorised=args.base_parameters_noise_factorised,
    )

    actor = ActorHead(
        shared_base, n_obs, n_act,
        action_scale=args.action_scale,
        last_activation=nn.Tanh,
    )
    critic = CriticHead(shared_base, n_obs, n_act)
    dynamics = DynamicsHead(shared_base, n_obs, n_act)

    heads = (actor, critic, dynamics)
    for head in heads:
        pprint(head)
    return heads
def create_act_update_fns(actor, critic, dynamics, target_actor, target_critic, target_dynamics, args):
    """Create the act / update / save closures for one set of networks.

    Args:
        actor, critic, dynamics: online networks; each gets its own Adam
            optimizer here.
        target_actor, target_critic, target_dynamics: target networks used to
            build the TD target.
        args: configuration namespace. Reads ``actor_lr``, ``critic_lr``,
            ``grad_clip``, ``gamma``, ``tau``, ``n_action``, ``logdir`` and,
            when present, ``dynamics_lr`` and ``flip_states``.

    Returns:
        Tuple ``(act_fn, update_fn, save_fn)``.
    """
    actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
    critic_optim = torch.optim.Adam(critic.parameters(), lr=args.critic_lr)
    # Use the dedicated dynamics learning rate; fall back to the actor rate
    # (the previous behaviour) when the config does not define one.
    dynamics_optim = torch.optim.Adam(
        dynamics.parameters(),
        lr=getattr(args, "dynamics_lr", args.actor_lr))

    criterion = create_loss(args)

    low_action_boundary = -1.
    high_action_boundary = 1.

    def _set_lr(optimizer, lr):
        # Apply one learning rate to every parameter group of `optimizer`.
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr

    def act_fn(observation, noise=0):
        """Return the actor's action for one observation, plus noise, clipped to [-1, 1]."""
        batch = to_tensor(np.array([observation], dtype=np.float32))
        action = to_numpy(actor(batch)).squeeze(0)
        action += noise
        action = np.clip(action, low_action_boundary, high_action_boundary)
        return action

    def update_fn(
            observations, actions, rewards, next_observations, dones, weights,
            actor_lr=1e-4, critic_lr=1e-3, dynamics_lr=1e-4):
        """Run one dynamics, critic and actor optimisation step on a batch.

        Returns:
            ``(metrics, info)``: ``metrics`` holds the three loss values,
            ``info["td_error"]`` the TD error as a numpy array.
        """
        if hasattr(args, "flip_states"):
            # Augment the batch with a mirrored copy: states go through
            # args.flip_states, actions get their two halves swapped.
            # NOTE(review): `weights` is not duplicated here — confirm the
            # shapes still line up when flip_states is configured.
            observations_flip = args.flip_states(observations)
            next_observations_flip = args.flip_states(next_observations)
            actions_flip = np.zeros_like(actions)
            actions_flip[:, :args.n_action // 2] = actions[:, args.n_action // 2:]
            actions_flip[:, args.n_action // 2:] = actions[:, :args.n_action // 2]
            observations = np.concatenate((observations, observations_flip))
            actions = np.concatenate((actions, actions_flip))
            rewards = np.tile(rewards.ravel(), 2)
            next_observations = np.concatenate((next_observations, next_observations_flip))
            dones = np.tile(dones.ravel(), 2)

        # The builtin `bool` replaces the `np.bool` alias, which newer NumPy
        # releases no longer provide.
        dones = dones[:, None].astype(bool)
        rewards = rewards[:, None].astype(np.float32)
        # After inversion the mask is 1.0 for non-terminal transitions.
        dones = to_tensor(np.invert(dones).astype(np.float32))
        rewards = to_tensor(rewards)
        weights = to_tensor(weights, requires_grad=False)

        # Dynamics update
        next_observations_pred = dynamics(
            to_tensor(observations),
            to_tensor(actions)
        )
        # NOTE(review): the weights are stacked exactly three times —
        # presumably to match a 3-dimensional observation; confirm.
        dynamics_loss = criterion(
            next_observations_pred,
            to_tensor(next_observations),
            weights=torch.stack([weights, weights, weights], 1)
        )
        dynamics.zero_grad()
        dynamics_loss.backward()
        torch.nn.utils.clip_grad_norm(dynamics.parameters(), args.grad_clip)
        # The scheduled rate must go to the dynamics optimizer (it used to be
        # written into the actor optimizer and never reached this step).
        _set_lr(dynamics_optim, dynamics_lr)
        dynamics_optim.step()

        # Critic update
        next_next_observations_pred = target_dynamics(
            to_tensor(next_observations, volatile=True),
            target_actor(to_tensor(next_observations, volatile=True)),
        )
        next_v_values = target_critic(next_next_observations_pred)
        next_v_values.volatile = False

        # TD target: reward plus discounted target value, masked at terminals.
        reward_predicted = dones * args.gamma * next_v_values
        td_target = rewards + reward_predicted

        critic.zero_grad()
        v_values = critic(
            dynamics(
                to_tensor(observations),
                to_tensor(actions)
            )
        )
        value_loss = criterion(v_values, td_target, weights=weights)
        value_loss.backward()
        torch.nn.utils.clip_grad_norm(critic.parameters(), args.grad_clip)
        _set_lr(critic_optim, critic_lr)
        critic_optim.step()

        # Actor update
        actor.zero_grad()
        policy_loss = -critic(
            dynamics(
                to_tensor(observations),
                actor(to_tensor(observations))
            )
        )
        policy_loss = torch.mean(policy_loss * weights)
        policy_loss.backward()
        torch.nn.utils.clip_grad_norm(actor.parameters(), args.grad_clip)
        _set_lr(actor_optim, actor_lr)
        actor_optim.step()

        # Target update
        # NOTE(review): target_dynamics is not soft-updated here — confirm
        # whether that is intentional.
        soft_update(target_actor, actor, args.tau)
        soft_update(target_critic, critic, args.tau)

        metrics = {
            "value_loss": value_loss,
            "policy_loss": policy_loss,
            "dynamics_loss": dynamics_loss
        }

        # Recompute the critic values after the step to report the TD error.
        td_v_values = critic(
            dynamics(
                to_tensor(observations, volatile=True, requires_grad=False),
                to_tensor(actions, volatile=True, requires_grad=False)
            )
        )
        td_error = td_target - td_v_values
        info = {
            "td_error": to_numpy(td_error)
        }
        return metrics, info

    def save_fn(episode=None):
        """Write the state dicts to ``args.logdir`` or its ``episode_<n>`` subdirectory."""
        if episode is None:
            save_path = args.logdir
        else:
            save_path = "{}/episode_{}".format(args.logdir, episode)
            create_if_need(save_path)

        torch.save(actor.state_dict(), "{}/actor_state_dict.pkl".format(save_path))
        torch.save(critic.state_dict(), "{}/critic_state_dict.pkl".format(save_path))
        torch.save(target_actor.state_dict(), "{}/target_actor_state_dict.pkl".format(save_path))
        torch.save(target_critic.state_dict(), "{}/target_critic_state_dict.pkl".format(save_path))
        # The dynamics networks are saved too, so a checkpoint holds every
        # network the update step uses.
        torch.save(dynamics.state_dict(), "{}/dynamics_state_dict.pkl".format(save_path))
        torch.save(target_dynamics.state_dict(), "{}/target_dynamics_state_dict.pkl".format(save_path))

    return act_fn, update_fn, save_fn
def train_multi_thread(
        actor, critic, dynamics, target_actor, target_critic, target_dynamics,
        args, prepare_fn, best_reward,
        shared_state_normalizer, shared_reward_normalizer):
    """Collect experience and train in the same worker until the budget runs out.

    The worker seeds itself from ``args.seed`` and ``args.thread``, redirects
    ``args.logdir`` to a per-thread subdirectory (this mutates ``args``), and
    loops over episodes: act, store the transition, and, once the buffer
    holds at least ``args.train_steps`` items, run one update per environment
    step. After every episode the local normalizer statistics are merged into
    the shared normalizers and reset.

    The loop ends after ``args.max_episodes`` episodes or once more than
    ``args.max_train_days`` days have elapsed.

    Raises:
        KeyboardInterrupt: always, after the final checkpoint is written.
    """
    set_global_seeds(args.seed + 241 * args.thread)

    args.logdir = "{}/thread_{}".format(args.logdir, args.thread)
    create_if_need(args.logdir)

    act_fn, update_fn, save_fn = prepare_fn(
        actor, critic, dynamics, target_actor, target_critic, target_dynamics, args)
    logger = Logger(args.logdir)

    def per_episode_linear(start, end):
        # Linear schedule indexed by episode number.
        return create_decay_fn(
            "linear",
            initial_value=start,
            final_value=end,
            max_step=args.max_episodes)

    replay_buffer = create_buffer(args)
    if args.prioritized_replay:
        beta_schedule = per_episode_linear(args.prioritized_replay_beta0, 1.0)

    env = create_env(args)
    noise_process = create_random_process(args)

    flat_obs_size = int(np.prod(env.observation_space.shape))
    state_normalizer = StaticNormalizer(flat_obs_size)
    reward_normalizer = StaticNormalizer(1)

    actor_lr_schedule = per_episode_linear(args.actor_lr, args.actor_lr_end)
    critic_lr_schedule = per_episode_linear(args.critic_lr, args.critic_lr_end)
    dynamics_lr_schedule = per_episode_linear(args.dynamics_lr, args.dynamics_lr_end)

    # Each worker draws its own exploration cycle length.
    cycle_len = random.randint(args.epsilon_cycle_len // 2, args.epsilon_cycle_len * 2)
    epsilon_schedule = create_decay_fn(
        "cycle",
        initial_value=args.initial_epsilon,
        final_value=args.final_epsilon,
        cycle_len=cycle_len,
        num_cycles=args.max_episodes // cycle_len)

    episode = 0
    total_steps = 0
    began_at = time.time()

    while episode < args.max_episodes:
        # A fresh environment is created every 100 episodes.
        if episode % 100 == 0:
            env = create_env(args)
        seed = random.randrange(2 ** 32 - 2)

        actor_lr = actor_lr_schedule(episode)
        critic_lr = critic_lr_schedule(episode)
        dynamics_lr = dynamics_lr_schedule(episode)

        epsilon = max(args.final_epsilon, epsilon_schedule(episode))
        epsilon = min(args.initial_epsilon, epsilon)

        episode_metrics = {
            "value_loss": 0.0,
            "policy_loss": 0.0,
            "dynamics_loss": 0.0,
            "reward": 0.0,
            "step": 0,
            "epsilon": epsilon,
        }

        env.seed(seed)
        observation = state_normalizer(env.reset())
        noise_process.reset_states()

        done = False
        while not done:
            action = act_fn(observation, noise=epsilon * noise_process.sample())
            next_observation, reward, done, _ = env.step(action)

            # Metrics track the raw reward; the buffer stores the normalized one.
            episode_metrics["reward"] += reward
            episode_metrics["step"] += 1

            next_observation = state_normalizer(next_observation)
            reward = reward_normalizer(reward)
            replay_buffer.add(observation, action, reward, next_observation, done)

            if len(replay_buffer) >= args.train_steps:
                if args.prioritized_replay:
                    sampled = replay_buffer.sample(
                        batch_size=args.batch_size, beta=beta_schedule(episode))
                    (obs_b, act_b, rew_b, next_obs_b, done_b,
                     weights, batch_idxes) = sampled
                else:
                    sampled = replay_buffer.sample(batch_size=args.batch_size)
                    (obs_b, act_b, rew_b, next_obs_b, done_b) = sampled
                    weights, batch_idxes = np.ones_like(rew_b), None

                step_metrics, step_info = update_fn(
                    obs_b, act_b, rew_b,
                    next_obs_b, done_b,
                    weights, actor_lr, critic_lr, dynamics_lr)

                if args.prioritized_replay:
                    new_priorities = np.abs(step_info["td_error"]) + 1e-6
                    replay_buffer.update_priorities(batch_idxes, new_priorities)

                for name, tensor in step_metrics.items():
                    episode_metrics[name] += to_numpy(tensor)[0]

            observation = next_observation

        episode += 1

        if episode_metrics["reward"] > best_reward.value:
            best_reward.value = episode_metrics["reward"]
            logger.scalar_summary("best reward", best_reward.value, episode)
            save_fn(episode)

        total_steps += episode_metrics["step"]
        elapsed_time = time.time() - began_at

        # Loss entries are summed per step, so they are logged as per-step means.
        steps_this_episode = episode_metrics["step"]
        for name, total in episode_metrics.items():
            if "loss" in name:
                total = total / steps_this_episode
            logger.scalar_summary(name, total, episode)

        logger.scalar_summary("episode per minute", episode / elapsed_time * 60, episode)
        logger.scalar_summary("step per second", total_steps / elapsed_time, episode)
        logger.scalar_summary("actor lr", actor_lr, episode)
        logger.scalar_summary("critic lr", critic_lr, episode)
        logger.scalar_summary("dynamics_lr", dynamics_lr, episode)

        if episode % args.save_step == 0:
            save_fn(episode)

        logger.info('episode %s, metrics %s', episode, episode_metrics)

        # Wall-clock budget exceeded: push the counter past the loop bound.
        if elapsed_time > 86400 * args.max_train_days:
            episode = args.max_episodes + 1

        # Sync normalizers
        shared_state_normalizer.offline_stats.merge(state_normalizer.online_stats)
        state_normalizer.online_stats.zero()
        shared_reward_normalizer.offline_stats.merge(reward_normalizer.online_stats)
        reward_normalizer.online_stats.zero()

    save_fn(episode)

    raise KeyboardInterrupt
def train_single_thread(
        actor, critic, dynamics, target_actor, target_critic, target_dynamics, args, prepare_fn,
        global_episode, global_update_step, episodes_queue):
    """Trainer worker: consume episodes from a queue and run update steps.

    The worker seeds itself from ``args.seed`` and ``args.thread`` and
    redirects ``args.logdir`` to a per-thread subdirectory (this mutates
    ``args``). Each iteration drains ``episodes_queue`` into the local replay
    buffer, then either performs one update (when the buffer holds at least
    ``args.train_steps`` items) or sleeps for a second.

    The loop runs while both shared counters, ``global_episode`` and
    ``global_update_step``, are below their limits.

    Raises:
        KeyboardInterrupt: always, after the final checkpoint is written.
    """
    set_global_seeds(args.seed + 241 * args.thread)

    args.logdir = "{}/thread_{}".format(args.logdir, args.thread)
    create_if_need(args.logdir)

    _, update_fn, save_fn = prepare_fn(
        actor, critic, dynamics, target_actor, target_critic, target_dynamics, args)
    logger = Logger(args.logdir)

    def per_update_linear(start, end):
        # Linear schedule indexed by this worker's update step.
        return create_decay_fn(
            "linear",
            initial_value=start,
            final_value=end,
            max_step=args.max_update_steps)

    def clamp(value, low, high):
        # Keep `value` inside [low, high].
        return min(high, max(low, value))

    def budget_left():
        # True while neither shared counter has reached its limit.
        return (
            global_episode.value < args.max_episodes * (args.num_threads - args.num_train_threads)
            and global_update_step.value < args.max_update_steps * args.num_train_threads
        )

    replay_buffer = create_buffer(args)
    if args.prioritized_replay:
        beta_schedule = per_update_linear(args.prioritized_replay_beta0, 1.0)

    actor_lr_schedule = per_update_linear(args.actor_lr, args.actor_lr_end)
    critic_lr_schedule = per_update_linear(args.critic_lr, args.critic_lr_end)
    dynamics_lr_schedule = per_update_linear(args.dynamics_lr, args.dynamics_lr_end)

    def absorb_pending_episodes():
        # Move every queued episode into the buffer; return transitions added.
        added = 0
        while True:
            try:
                replay = episodes_queue.get_nowait()
                for (observation, action, reward, next_observation, done) in replay:
                    replay_buffer.add(observation, action, reward, next_observation, done)
                added += len(replay)
            except py_queue.Empty:
                break
        return added

    update_step = 0
    # Starts at 1 so the "updates per example" ratio never divides by zero.
    received_examples = 1

    while budget_left():
        actor_lr = clamp(actor_lr_schedule(update_step), args.actor_lr_end, args.actor_lr)
        critic_lr = clamp(critic_lr_schedule(update_step), args.critic_lr_end, args.critic_lr)
        dynamics_lr = clamp(
            dynamics_lr_schedule(update_step), args.dynamics_lr_end, args.dynamics_lr)

        received_examples += absorb_pending_episodes()

        if len(replay_buffer) < args.train_steps:
            # Not enough data yet; wait for the players to produce more.
            time.sleep(1)
        else:
            if args.prioritized_replay:
                beta = clamp(beta_schedule(update_step), args.prioritized_replay_beta0, 1.0)
                sampled = replay_buffer.sample(batch_size=args.batch_size, beta=beta)
                (obs_b, act_b, rew_b, next_obs_b, done_b,
                 weights, batch_idxes) = sampled
            else:
                sampled = replay_buffer.sample(batch_size=args.batch_size)
                (obs_b, act_b, rew_b, next_obs_b, done_b) = sampled
                weights, batch_idxes = np.ones_like(rew_b), None

            step_metrics, step_info = update_fn(
                obs_b, act_b, rew_b,
                next_obs_b, done_b,
                weights, actor_lr, critic_lr, dynamics_lr)

            update_step += 1
            global_update_step.value += 1

            if args.prioritized_replay:
                new_priorities = np.abs(step_info["td_error"]) + 1e-6
                replay_buffer.update_priorities(batch_idxes, new_priorities)

            for name, metric in step_metrics.items():
                logger.scalar_summary(name, metric, update_step)

            logger.scalar_summary("actor lr", actor_lr, update_step)
            logger.scalar_summary("critic lr", critic_lr, update_step)
            logger.scalar_summary("dynamics lr", dynamics_lr, update_step)

            if update_step % args.save_step == 0:
                save_fn(update_step)

        logger.scalar_summary("buffer size", len(replay_buffer), global_episode.value)
        logger.scalar_summary(
            "updates per example",
            update_step * args.batch_size / received_examples,
            global_episode.value)

    save_fn(update_step)

    raise KeyboardInterrupt
def play_single_thread(
        actor, critic, dynamics, target_actor, target_critic, target_dynamics, args, prepare_fn,
        global_episode, global_update_step, episodes_queue,
        best_reward, shared_state_normalizer, shared_reward_normalizer):
    """Player worker: roll out episodes and push them to ``episodes_queue``.

    The worker seeds itself from ``args.seed`` and ``args.thread`` and
    redirects ``args.logdir`` to a per-thread subdirectory (this mutates
    ``args``). Each finished episode is put on the queue as a list of
    ``(observation, action, reward, next_observation, done)`` tuples, where
    observation and reward have been passed through the local normalizers.
    After every episode the local normalizer statistics are merged into the
    shared normalizers and reset.

    The loop runs while both shared counters are below their limits. When
    more than ``args.max_train_days`` days have elapsed, ``global_episode``
    is pushed past its limit so the loop condition fails.

    Raises:
        KeyboardInterrupt: always, once the loop ends.
    """
    workerseed = args.seed + 241 * args.thread
    set_global_seeds(workerseed)

    args.logdir = "{}/thread_{}".format(args.logdir, args.thread)
    create_if_need(args.logdir)

    act_fn, _, save_fn = prepare_fn(
        actor, critic, dynamics, target_actor, target_critic, target_dynamics, args)
    logger = Logger(args.logdir)

    # The environment is created once here. It used to be built twice in a
    # row, discarding the first instance without using or closing it.
    env = create_env(args)
    random_process = create_random_process(args)

    # Each worker draws its own exploration cycle length.
    epsilon_cycle_len = random.randint(args.epsilon_cycle_len // 2, args.epsilon_cycle_len * 2)
    epsilon_decay_fn = create_decay_fn(
        "cycle",
        initial_value=args.initial_epsilon,
        final_value=args.final_epsilon,
        cycle_len=epsilon_cycle_len,
        num_cycles=args.max_episodes // epsilon_cycle_len)

    state_normalizer = StaticNormalizer(int(np.prod(env.observation_space.shape)))
    reward_normalizer = StaticNormalizer(1)

    episode = 1
    step = 0
    start_time = time.time()

    while global_episode.value < args.max_episodes * (args.num_threads - args.num_train_threads) \
            and global_update_step.value < args.max_update_steps * args.num_train_threads:
        # A fresh environment is created every 100 episodes.
        if episode % 100 == 0:
            env = create_env(args)
        seed = random.randrange(2 ** 32 - 2)

        epsilon = min(args.initial_epsilon, max(args.final_epsilon, epsilon_decay_fn(episode)))

        episode_metrics = {
            "reward": 0.0,
            "step": 0,
            "epsilon": epsilon
        }

        env.seed(seed)
        observation = env.reset()
        observation = state_normalizer(observation)
        random_process.reset_states()
        done = False

        replay = []
        while not done:
            action = act_fn(observation, noise=epsilon * random_process.sample())
            next_observation, reward, done, _ = env.step(action)

            # Metrics track the raw reward; the replay stores the normalized one.
            episode_metrics["reward"] += reward
            episode_metrics["step"] += 1

            next_observation = state_normalizer(next_observation)
            reward = reward_normalizer(reward)
            replay.append((observation, action, reward, next_observation, done))

            observation = next_observation

        episodes_queue.put(replay)

        episode += 1
        # NOTE(review): this read-modify-write on a shared counter is not
        # guarded by a lock here — confirm how the counter is created.
        global_episode.value += 1

        if episode_metrics["reward"] > best_reward.value:
            best_reward.value = episode_metrics["reward"]
            logger.scalar_summary("best reward", best_reward.value, episode)

            # Only checkpoint new bests above this absolute reward threshold.
            if episode_metrics["reward"] > 15.0 * args.reward_scale:
                save_fn(episode)

        step += episode_metrics["step"]
        elapsed_time = time.time() - start_time

        for key, value in episode_metrics.items():
            logger.scalar_summary(key, value, episode)
        logger.scalar_summary(
            "episode per minute",
            episode / elapsed_time * 60,
            episode)
        logger.scalar_summary(
            "step per second",
            step / elapsed_time,
            episode)

        # Wall-clock budget exceeded: push the shared counter past its limit.
        if elapsed_time > 86400 * args.max_train_days:
            global_episode.value = args.max_episodes * (args.num_threads - args.num_train_threads) + 1

        # Sync normalizers
        shared_state_normalizer.offline_stats.merge(state_normalizer.online_stats)
        state_normalizer.online_stats.zero()
        shared_reward_normalizer.offline_stats.merge(reward_normalizer.online_stats)
        reward_normalizer.online_stats.zero()

    raise KeyboardInterrupt