mirror of
https://github.com/wassname/Run-Skeleton-Run.git
synced 2026-06-27 18:41:01 +08:00
pytorch version
This commit is contained in:
+477
@@ -0,0 +1,477 @@
|
||||
import random
|
||||
import numpy as np
|
||||
import torch
|
||||
import queue as py_queue
|
||||
import time
|
||||
import torch.nn as nn
|
||||
from pprint import pprint
|
||||
|
||||
from ddpg.nets import Actor, Critic
|
||||
from common.torch_util import to_numpy, to_tensor, soft_update
|
||||
from common.misc_util import create_if_need, set_global_seeds
|
||||
from common.logger import Logger
|
||||
from common.buffers import create_buffer
|
||||
from common.loss import create_loss, create_decay_fn
|
||||
from common.env_wrappers import create_env
|
||||
from common.random_process import create_random_process
|
||||
|
||||
|
||||
def create_model(args):
|
||||
actor = Actor(
|
||||
args.n_observation, args.n_action, args.actor_layers,
|
||||
activation=args.actor_activation,
|
||||
layer_norm=args.actor_layer_norm,
|
||||
parameters_noise=args.actor_parameters_noise,
|
||||
parameters_noise_factorised=args.actor_parameters_noise_factorised,
|
||||
last_activation=nn.Tanh)
|
||||
critic = Critic(
|
||||
args.n_observation, args.n_action, args.critic_layers,
|
||||
activation=args.critic_activation,
|
||||
layer_norm=args.critic_layer_norm,
|
||||
parameters_noise=args.critic_parameters_noise,
|
||||
parameters_noise_factorised=args.critic_parameters_noise_factorised)
|
||||
|
||||
pprint(actor)
|
||||
pprint(critic)
|
||||
|
||||
return actor, critic
|
||||
|
||||
|
||||
def create_act_update_fns(actor, critic, target_actor, target_critic, args):
|
||||
actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
|
||||
critic_optim = torch.optim.Adam(critic.parameters(), lr=args.critic_lr)
|
||||
|
||||
criterion = create_loss(args)
|
||||
|
||||
low_action_boundary = -1.
|
||||
high_action_boundary = 1.
|
||||
|
||||
def act_fn(observation, noise=0):
|
||||
nonlocal actor
|
||||
action = to_numpy(actor(to_tensor(np.array([observation], dtype=np.float32)))).squeeze(0)
|
||||
action += noise
|
||||
action = np.clip(action, low_action_boundary, high_action_boundary)
|
||||
return action
|
||||
|
||||
def update_fn(
|
||||
observations, actions, rewards, next_observations, dones, weights,
|
||||
actor_lr=1e-4, critic_lr=1e-3):
|
||||
nonlocal actor, critic, target_actor, target_critic, actor_optim, critic_optim
|
||||
|
||||
if hasattr(args, "flip_states"):
|
||||
observations_flip = args.flip_states(observations)
|
||||
next_observations_flip = args.flip_states(next_observations)
|
||||
actions_flip = np.zeros_like(actions)
|
||||
actions_flip[:, :args.n_action // 2] = actions[:, args.n_action // 2:]
|
||||
actions_flip[:, args.n_action // 2:] = actions[:, :args.n_action // 2]
|
||||
|
||||
observations = np.concatenate((observations, observations_flip))
|
||||
actions = np.concatenate((actions, actions_flip))
|
||||
rewards = np.tile(rewards.ravel(), 2)
|
||||
next_observations = np.concatenate((next_observations, next_observations_flip))
|
||||
dones = np.tile(dones.ravel(), 2)
|
||||
|
||||
dones = dones[:, None].astype(np.bool)
|
||||
rewards = rewards[:, None].astype(np.float32)
|
||||
|
||||
dones = to_tensor(np.invert(dones).astype(np.float32))
|
||||
rewards = to_tensor(rewards)
|
||||
weights = to_tensor(weights, requires_grad=False)
|
||||
|
||||
next_v_values = target_critic(
|
||||
to_tensor(next_observations, volatile=True),
|
||||
target_actor(to_tensor(next_observations, volatile=True)),
|
||||
)
|
||||
next_v_values.volatile = False
|
||||
|
||||
reward_predicted = dones * args.gamma * next_v_values
|
||||
td_target = rewards + reward_predicted
|
||||
|
||||
# Critic update
|
||||
critic.zero_grad()
|
||||
|
||||
v_values = critic(to_tensor(observations), to_tensor(actions))
|
||||
value_loss = criterion(v_values, td_target, weights=weights)
|
||||
value_loss.backward()
|
||||
|
||||
torch.nn.utils.clip_grad_norm(critic.parameters(), args.grad_clip)
|
||||
for param_group in critic_optim.param_groups:
|
||||
param_group["lr"] = critic_lr
|
||||
|
||||
critic_optim.step()
|
||||
|
||||
# Actor update
|
||||
actor.zero_grad()
|
||||
|
||||
policy_loss = -critic(
|
||||
to_tensor(observations),
|
||||
actor(to_tensor(observations))
|
||||
)
|
||||
|
||||
policy_loss = torch.mean(policy_loss * weights)
|
||||
policy_loss.backward()
|
||||
|
||||
torch.nn.utils.clip_grad_norm(actor.parameters(), args.grad_clip)
|
||||
for param_group in actor_optim.param_groups:
|
||||
param_group["lr"] = actor_lr
|
||||
|
||||
actor_optim.step()
|
||||
|
||||
# Target update
|
||||
soft_update(target_actor, actor, args.tau)
|
||||
soft_update(target_critic, critic, args.tau)
|
||||
|
||||
metrics = {
|
||||
"value_loss": value_loss,
|
||||
"policy_loss": policy_loss
|
||||
}
|
||||
|
||||
td_v_values = critic(
|
||||
to_tensor(observations, volatile=True, requires_grad=False),
|
||||
to_tensor(actions, volatile=True, requires_grad=False))
|
||||
td_error = td_target - td_v_values
|
||||
|
||||
info = {
|
||||
"td_error": to_numpy(td_error)
|
||||
}
|
||||
|
||||
return metrics, info
|
||||
|
||||
def save_fn(episode=None):
|
||||
nonlocal actor, critic
|
||||
if episode is None:
|
||||
save_path = args.logdir
|
||||
else:
|
||||
save_path = "{}/episode_{}".format(args.logdir, episode)
|
||||
create_if_need(save_path)
|
||||
torch.save(actor.state_dict(), "{}/actor_state_dict.pkl".format(save_path))
|
||||
torch.save(critic.state_dict(), "{}/critic_state_dict.pkl".format(save_path))
|
||||
torch.save(target_actor.state_dict(), "{}/target_actor_state_dict.pkl".format(save_path))
|
||||
torch.save(target_critic.state_dict(), "{}/target_critic_state_dict.pkl".format(save_path))
|
||||
|
||||
return act_fn, update_fn, save_fn
|
||||
|
||||
|
||||
def train_multi_thread(actor, critic, target_actor, target_critic, args, prepare_fn, best_reward):
|
||||
workerseed = args.seed + 241 * args.thread
|
||||
set_global_seeds(workerseed)
|
||||
|
||||
args.logdir = "{}/thread_{}".format(args.logdir, args.thread)
|
||||
create_if_need(args.logdir)
|
||||
|
||||
act_fn, update_fn, save_fn = prepare_fn(actor, critic, target_actor, target_critic, args)
|
||||
logger = Logger(args.logdir)
|
||||
|
||||
buffer = create_buffer(args)
|
||||
if args.prioritized_replay:
|
||||
beta_deacy_fn = create_decay_fn(
|
||||
"linear",
|
||||
initial_value=args.prioritized_replay_beta0,
|
||||
final_value=1.0,
|
||||
max_step=args.max_episodes)
|
||||
|
||||
env = create_env(args)
|
||||
random_process = create_random_process(args)
|
||||
|
||||
actor_learning_rate_decay_fn = create_decay_fn(
|
||||
"linear",
|
||||
initial_value=args.actor_lr,
|
||||
final_value=args.actor_lr_end,
|
||||
max_step=args.max_episodes)
|
||||
critic_learning_rate_decay_fn = create_decay_fn(
|
||||
"linear",
|
||||
initial_value=args.critic_lr,
|
||||
final_value=args.critic_lr_end,
|
||||
max_step=args.max_episodes)
|
||||
|
||||
epsilon_cycle_len = random.randint(args.epsilon_cycle_len // 2, args.epsilon_cycle_len * 2)
|
||||
|
||||
epsilon_decay_fn = create_decay_fn(
|
||||
"cycle",
|
||||
initial_value=args.initial_epsilon,
|
||||
final_value=args.final_epsilon,
|
||||
cycle_len=epsilon_cycle_len,
|
||||
num_cycles=args.max_episodes // epsilon_cycle_len)
|
||||
|
||||
episode = 0
|
||||
step = 0
|
||||
start_time = time.time()
|
||||
while episode < args.max_episodes:
|
||||
if episode % 100 == 0:
|
||||
env = create_env(args)
|
||||
seed = random.randrange(2 ** 32 - 2)
|
||||
|
||||
actor_lr = actor_learning_rate_decay_fn(episode)
|
||||
critic_lr = critic_learning_rate_decay_fn(episode)
|
||||
epsilon = min(args.initial_epsilon, max(args.final_epsilon, epsilon_decay_fn(episode)))
|
||||
|
||||
episode_metrics = {
|
||||
"value_loss": 0.0,
|
||||
"policy_loss": 0.0,
|
||||
"reward": 0.0,
|
||||
"step": 0,
|
||||
"epsilon": epsilon
|
||||
}
|
||||
|
||||
observation = env.reset(seed=seed, difficulty=args.difficulty)
|
||||
random_process.reset_states()
|
||||
done = False
|
||||
|
||||
while not done:
|
||||
action = act_fn(observation, noise=epsilon*random_process.sample())
|
||||
next_observation, reward, done, _ = env.step(action)
|
||||
|
||||
buffer.add(observation, action, reward, next_observation, done)
|
||||
episode_metrics["reward"] += reward
|
||||
episode_metrics["step"] += 1
|
||||
|
||||
if len(buffer) >= args.train_steps:
|
||||
|
||||
if args.prioritized_replay:
|
||||
(tr_observations, tr_actions, tr_rewards, tr_next_observations, tr_dones,
|
||||
weights, batch_idxes) = \
|
||||
buffer.sample(batch_size=args.batch_size, beta=beta_deacy_fn(episode))
|
||||
else:
|
||||
(tr_observations, tr_actions, tr_rewards, tr_next_observations, tr_dones) = \
|
||||
buffer.sample(batch_size=args.batch_size)
|
||||
weights, batch_idxes = np.ones_like(tr_rewards), None
|
||||
|
||||
step_metrics, step_info = update_fn(
|
||||
tr_observations, tr_actions, tr_rewards,
|
||||
tr_next_observations, tr_dones,
|
||||
weights, actor_lr, critic_lr)
|
||||
|
||||
if args.prioritized_replay:
|
||||
new_priorities = np.abs(step_info["td_error"]) + 1e-6
|
||||
buffer.update_priorities(batch_idxes, new_priorities)
|
||||
|
||||
for key, value in step_metrics.items():
|
||||
value = to_numpy(value)[0]
|
||||
episode_metrics[key] += value
|
||||
|
||||
observation = next_observation
|
||||
|
||||
episode += 1
|
||||
|
||||
if episode_metrics["reward"] > 15.0 * args.reward_scale \
|
||||
and episode_metrics["reward"] > best_reward.value:
|
||||
best_reward.value = episode_metrics["reward"]
|
||||
logger.scalar_summary("best reward", best_reward.value, episode)
|
||||
save_fn(episode)
|
||||
|
||||
step += episode_metrics["step"]
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
for key, value in episode_metrics.items():
|
||||
value = value if "loss" not in key else value / episode_metrics["step"]
|
||||
logger.scalar_summary(key, value, episode)
|
||||
logger.scalar_summary(
|
||||
"episode per minute",
|
||||
episode / elapsed_time * 60,
|
||||
episode)
|
||||
logger.scalar_summary(
|
||||
"step per second",
|
||||
step / elapsed_time,
|
||||
episode)
|
||||
logger.scalar_summary("actor lr", actor_lr, episode)
|
||||
logger.scalar_summary("critic lr", critic_lr, episode)
|
||||
|
||||
if episode % args.save_step == 0:
|
||||
save_fn(episode)
|
||||
|
||||
if elapsed_time > 86400 * args.max_train_days:
|
||||
episode = args.max_episodes + 1
|
||||
|
||||
save_fn(episode)
|
||||
|
||||
raise KeyboardInterrupt
|
||||
|
||||
|
||||
def train_single_thread(
|
||||
actor, critic, target_actor, target_critic, args, prepare_fn,
|
||||
global_episode, global_update_step, episodes_queue):
|
||||
workerseed = args.seed + 241 * args.thread
|
||||
set_global_seeds(workerseed)
|
||||
|
||||
args.logdir = "{}/thread_{}".format(args.logdir, args.thread)
|
||||
create_if_need(args.logdir)
|
||||
|
||||
_, update_fn, save_fn = prepare_fn(actor, critic, target_actor, target_critic, args)
|
||||
|
||||
logger = Logger(args.logdir)
|
||||
|
||||
buffer = create_buffer(args)
|
||||
|
||||
if args.prioritized_replay:
|
||||
beta_deacy_fn = create_decay_fn(
|
||||
"linear",
|
||||
initial_value=args.prioritized_replay_beta0,
|
||||
final_value=1.0,
|
||||
max_step=args.max_update_steps)
|
||||
|
||||
actor_learning_rate_decay_fn = create_decay_fn(
|
||||
"linear",
|
||||
initial_value=args.actor_lr,
|
||||
final_value=args.actor_lr_end,
|
||||
max_step=args.max_update_steps)
|
||||
critic_learning_rate_decay_fn = create_decay_fn(
|
||||
"linear",
|
||||
initial_value=args.critic_lr,
|
||||
final_value=args.critic_lr_end,
|
||||
max_step=args.max_update_steps)
|
||||
|
||||
update_step = 0
|
||||
received_examples = 1 # just hack
|
||||
while global_episode.value < args.max_episodes * (args.num_threads - args.num_train_threads) \
|
||||
and global_update_step.value < args.max_update_steps * args.num_train_threads:
|
||||
actor_lr = actor_learning_rate_decay_fn(update_step)
|
||||
critic_lr = critic_learning_rate_decay_fn(update_step)
|
||||
|
||||
actor_lr = min(args.actor_lr, max(args.actor_lr_end, actor_lr))
|
||||
critic_lr = min(args.critic_lr, max(args.critic_lr_end, critic_lr))
|
||||
|
||||
while True:
|
||||
try:
|
||||
replay = episodes_queue.get_nowait()
|
||||
for (observation, action, reward, next_observation, done) in replay:
|
||||
buffer.add(observation, action, reward, next_observation, done)
|
||||
received_examples += len(replay)
|
||||
except py_queue.Empty:
|
||||
break
|
||||
|
||||
if len(buffer) >= args.train_steps:
|
||||
if args.prioritized_replay:
|
||||
beta = beta_deacy_fn(update_step)
|
||||
beta = min(1.0, max(args.prioritized_replay_beta0, beta))
|
||||
(tr_observations, tr_actions, tr_rewards, tr_next_observations, tr_dones,
|
||||
weights, batch_idxes) = \
|
||||
buffer.sample(
|
||||
batch_size=args.batch_size,
|
||||
beta=beta)
|
||||
else:
|
||||
(tr_observations, tr_actions, tr_rewards, tr_next_observations, tr_dones) = \
|
||||
buffer.sample(batch_size=args.batch_size)
|
||||
weights, batch_idxes = np.ones_like(tr_rewards), None
|
||||
|
||||
step_metrics, step_info = update_fn(
|
||||
tr_observations, tr_actions, tr_rewards,
|
||||
tr_next_observations, tr_dones,
|
||||
weights, actor_lr, critic_lr)
|
||||
|
||||
update_step += 1
|
||||
global_update_step.value += 1
|
||||
|
||||
if args.prioritized_replay:
|
||||
new_priorities = np.abs(step_info["td_error"]) + 1e-6
|
||||
buffer.update_priorities(batch_idxes, new_priorities)
|
||||
|
||||
for key, value in step_metrics.items():
|
||||
value = to_numpy(value)[0]
|
||||
logger.scalar_summary(key, value, update_step)
|
||||
|
||||
logger.scalar_summary("actor lr", actor_lr, update_step)
|
||||
logger.scalar_summary("critic lr", critic_lr, update_step)
|
||||
|
||||
if update_step % args.save_step == 0:
|
||||
save_fn(update_step)
|
||||
else:
|
||||
time.sleep(1)
|
||||
|
||||
logger.scalar_summary("buffer size", len(buffer), global_episode.value)
|
||||
logger.scalar_summary(
|
||||
"updates per example",
|
||||
update_step * args.batch_size / received_examples,
|
||||
global_episode.value)
|
||||
|
||||
save_fn(update_step)
|
||||
|
||||
raise KeyboardInterrupt
|
||||
|
||||
|
||||
def play_single_thread(
|
||||
actor, critic, target_actor, target_critic, args, prepare_fn,
|
||||
global_episode, global_update_step, episodes_queue,
|
||||
best_reward):
|
||||
workerseed = args.seed + 241 * args.thread
|
||||
set_global_seeds(workerseed)
|
||||
|
||||
args.logdir = "{}/thread_{}".format(args.logdir, args.thread)
|
||||
create_if_need(args.logdir)
|
||||
|
||||
act_fn, _, save_fn = prepare_fn(actor, critic, target_actor, target_critic, args)
|
||||
|
||||
logger = Logger(args.logdir)
|
||||
env = create_env(args)
|
||||
random_process = create_random_process(args)
|
||||
|
||||
epsilon_cycle_len = random.randint(args.epsilon_cycle_len // 2, args.epsilon_cycle_len * 2)
|
||||
|
||||
epsilon_decay_fn = create_decay_fn(
|
||||
"cycle",
|
||||
initial_value=args.initial_epsilon,
|
||||
final_value=args.final_epsilon,
|
||||
cycle_len=epsilon_cycle_len,
|
||||
num_cycles=args.max_episodes // epsilon_cycle_len)
|
||||
|
||||
episode = 1
|
||||
step = 0
|
||||
start_time = time.time()
|
||||
while global_episode.value < args.max_episodes * (args.num_threads - args.num_train_threads) \
|
||||
and global_update_step.value < args.max_update_steps * args.num_train_threads:
|
||||
if episode % 100 == 0:
|
||||
env = create_env(args)
|
||||
seed = random.randrange(2 ** 32 - 2)
|
||||
|
||||
epsilon = min(args.initial_epsilon, max(args.final_epsilon, epsilon_decay_fn(episode)))
|
||||
|
||||
episode_metrics = {
|
||||
"reward": 0.0,
|
||||
"step": 0,
|
||||
"epsilon": epsilon
|
||||
}
|
||||
|
||||
observation = env.reset(seed=seed, difficulty=args.difficulty)
|
||||
random_process.reset_states()
|
||||
done = False
|
||||
|
||||
replay = []
|
||||
while not done:
|
||||
action = act_fn(observation, noise=epsilon * random_process.sample())
|
||||
next_observation, reward, done, _ = env.step(action)
|
||||
|
||||
replay.append((observation, action, reward, next_observation, done))
|
||||
episode_metrics["reward"] += reward
|
||||
episode_metrics["step"] += 1
|
||||
|
||||
observation = next_observation
|
||||
|
||||
episodes_queue.put(replay)
|
||||
|
||||
episode += 1
|
||||
global_episode.value += 1
|
||||
|
||||
if episode_metrics["reward"] > best_reward.value:
|
||||
best_reward.value = episode_metrics["reward"]
|
||||
logger.scalar_summary("best reward", best_reward.value, episode)
|
||||
|
||||
if episode_metrics["reward"] > 15.0 * args.reward_scale:
|
||||
save_fn(episode)
|
||||
|
||||
step += episode_metrics["step"]
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
for key, value in episode_metrics.items():
|
||||
logger.scalar_summary(key, value, episode)
|
||||
logger.scalar_summary(
|
||||
"episode per minute",
|
||||
episode / elapsed_time * 60,
|
||||
episode)
|
||||
logger.scalar_summary(
|
||||
"step per second",
|
||||
step / elapsed_time,
|
||||
episode)
|
||||
|
||||
if elapsed_time > 86400 * args.max_train_days:
|
||||
global_episode.value = args.max_episodes * (args.num_threads - args.num_train_threads) + 1
|
||||
|
||||
raise KeyboardInterrupt
|
||||
Reference in New Issue
Block a user