mirror of
https://github.com/wassname/Run-Skeleton-Run.git
synced 2026-06-27 17:14:10 +08:00
add dynamics lr
This commit is contained in:
+23
-6
@@ -75,7 +75,7 @@ def create_act_update_fns(actor, critic, dynamics, target_actor, target_critic,
|
||||
|
||||
def update_fn(
|
||||
observations, actions, rewards, next_observations, dones, weights,
|
||||
actor_lr=1e-4, critic_lr=1e-3):
|
||||
actor_lr=1e-4, critic_lr=1e-3, dynamics_lr=1e-4):
|
||||
nonlocal actor, critic, dynamics, target_actor, target_critic, target_dynamics, actor_optim, critic_optim, dynamics_optim
|
||||
|
||||
if hasattr(args, "flip_states"):
|
||||
@@ -112,7 +112,7 @@ def create_act_update_fns(actor, critic, dynamics, target_actor, target_critic,
|
||||
dynamics_loss.backward()
|
||||
torch.nn.utils.clip_grad_norm(dynamics.parameters(), args.grad_clip)
|
||||
for param_group in actor_optim.param_groups:
|
||||
param_group["lr"] = actor_lr # TODO change to dynamics lr
|
||||
param_group["lr"] = dynamics_lr # TODO change to dynamics lr
|
||||
dynamics_optim.step()
|
||||
|
||||
# Critic update
|
||||
@@ -172,7 +172,8 @@ def create_act_update_fns(actor, critic, dynamics, target_actor, target_critic,
|
||||
|
||||
metrics = {
|
||||
"value_loss": value_loss,
|
||||
"policy_loss": policy_loss
|
||||
"policy_loss": policy_loss,
|
||||
"dynamics_loss": dynamics_loss
|
||||
}
|
||||
|
||||
td_v_values = critic(
|
||||
@@ -233,9 +234,14 @@ def train_multi_thread(actor, critic, dynamics, target_actor, target_critic, tar
|
||||
max_step=args.max_episodes)
|
||||
critic_learning_rate_decay_fn = create_decay_fn(
|
||||
"linear",
|
||||
initial_value=args.critic_lr,
|
||||
initial_value=arstep_metricsgs.critic_lr,
|
||||
final_value=args.critic_lr_end,
|
||||
max_step=args.max_episodes)
|
||||
dynamics_learning_rate_decay_fn = create_decay_fn(
|
||||
"linear",
|
||||
initial_value=args.dynamics_lr,
|
||||
final_value=args.dynamics_lr_end,
|
||||
max_step=args.max_episodes)
|
||||
|
||||
epsilon_cycle_len = random.randint(args.epsilon_cycle_len // 2, args.epsilon_cycle_len * 2)
|
||||
|
||||
@@ -256,11 +262,13 @@ def train_multi_thread(actor, critic, dynamics, target_actor, target_critic, tar
|
||||
|
||||
actor_lr = actor_learning_rate_decay_fn(episode)
|
||||
critic_lr = critic_learning_rate_decay_fn(episode)
|
||||
dynamics_lr = dynamics_learning_rate_decay_fn(episode)
|
||||
epsilon = min(args.initial_epsilon, max(args.final_epsilon, epsilon_decay_fn(episode)))
|
||||
|
||||
episode_metrics = {
|
||||
"value_loss": 0.0,
|
||||
"policy_loss": 0.0,
|
||||
"dynamics_loss": 0.0,
|
||||
"reward": 0.0,
|
||||
"step": 0,
|
||||
"epsilon": epsilon
|
||||
@@ -293,7 +301,7 @@ def train_multi_thread(actor, critic, dynamics, target_actor, target_critic, tar
|
||||
step_metrics, step_info = update_fn(
|
||||
tr_observations, tr_actions, tr_rewards,
|
||||
tr_next_observations, tr_dones,
|
||||
weights, actor_lr, critic_lr)
|
||||
weights, actor_lr, critic_lr, dynamics_lr)
|
||||
|
||||
if args.prioritized_replay:
|
||||
new_priorities = np.abs(step_info["td_error"]) + 1e-6
|
||||
@@ -329,6 +337,7 @@ def train_multi_thread(actor, critic, dynamics, target_actor, target_critic, tar
|
||||
episode)
|
||||
logger.scalar_summary("actor lr", actor_lr, episode)
|
||||
logger.scalar_summary("critic lr", critic_lr, episode)
|
||||
logger.scalar_summary("dynamics_lr", dynamics_lr, episode)
|
||||
|
||||
if episode % args.save_step == 0:
|
||||
save_fn(episode)
|
||||
@@ -373,6 +382,11 @@ def train_single_thread(
|
||||
initial_value=args.critic_lr,
|
||||
final_value=args.critic_lr_end,
|
||||
max_step=args.max_update_steps)
|
||||
dynamics_learning_rate_decay_fn = create_decay_fn(
|
||||
"linear",
|
||||
initial_value=args.dynamics_lr,
|
||||
final_value=args.dynamics_lr_end,
|
||||
max_step=args.max_update_steps)
|
||||
|
||||
update_step = 0
|
||||
received_examples = 1 # just hack
|
||||
@@ -380,8 +394,10 @@ def train_single_thread(
|
||||
and global_update_step.value < args.max_update_steps * args.num_train_threads:
|
||||
actor_lr = actor_learning_rate_decay_fn(update_step)
|
||||
critic_lr = critic_learning_rate_decay_fn(update_step)
|
||||
dynamics_lr = dynamics_learning_rate_decay_fn(update_step)
|
||||
|
||||
actor_lr = min(args.actor_lr, max(args.actor_lr_end, actor_lr))
|
||||
dynamics_lr = min(args.dynamics_lr, max(args.dynamics_lr_end, dynamics_lr))
|
||||
critic_lr = min(args.critic_lr, max(args.critic_lr_end, critic_lr))
|
||||
|
||||
while True:
|
||||
@@ -410,7 +426,7 @@ def train_single_thread(
|
||||
step_metrics, step_info = update_fn(
|
||||
tr_observations, tr_actions, tr_rewards,
|
||||
tr_next_observations, tr_dones,
|
||||
weights, actor_lr, critic_lr)
|
||||
weights, actor_lr, critic_lr, dynamics_lr)
|
||||
|
||||
update_step += 1
|
||||
global_update_step.value += 1
|
||||
@@ -425,6 +441,7 @@ def train_single_thread(
|
||||
|
||||
logger.scalar_summary("actor lr", actor_lr, update_step)
|
||||
logger.scalar_summary("critic lr", critic_lr, update_step)
|
||||
logger.scalar_summary("dynamics lr", dynamics_lr, update_step)
|
||||
|
||||
if update_step % args.save_step == 0:
|
||||
save_fn(update_step)
|
||||
|
||||
+23
-6
@@ -154,19 +154,18 @@ class CriticHead(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
|
||||
class ActorHead(nn.Module):
|
||||
def __init__(self, base, n_observation, n_action,
|
||||
layers, activation=torch.nn.ELU,
|
||||
layer_norm=False,
|
||||
parameters_noise=False, parameters_noise_factorised=False,
|
||||
last_activation=torch.nn.Tanh, init_w=3e-3):
|
||||
last_activation=lambda x: x, init_w=3e-3):
|
||||
super(ActorHead, self).__init__()
|
||||
self.base = base
|
||||
|
||||
self.policy_net = LinearNet(
|
||||
layers=[self.base.feature_net.output_shape, n_action],
|
||||
activation=last_activation,
|
||||
# activation=last_activation,
|
||||
layer_norm=False
|
||||
)
|
||||
self.init_weights(init_w)
|
||||
@@ -183,7 +182,6 @@ class ActorHead(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
|
||||
class DynamicsHead(nn.Module):
|
||||
def __init__(self, base, n_observation, n_action,
|
||||
layers, activation=torch.nn.ELU,
|
||||
@@ -192,15 +190,34 @@ class DynamicsHead(nn.Module):
|
||||
init_w=3e-3):
|
||||
super(DynamicsHead, self).__init__()
|
||||
self.base = base
|
||||
self.value_net = nn.Linear(self.base.feature_net.output_shape + n_action, n_observation)
|
||||
|
||||
if parameters_noise:
|
||||
def linear_layer(x_in, x_out):
|
||||
return NoisyLinear(x_in, x_out, factorised=parameters_noise_factorised)
|
||||
else:
|
||||
linear_layer = nn.Linear
|
||||
|
||||
# self.value_net = nn.Linear(self.base.feature_net.output_shape + n_action, self.base.feature_net.output_shape)
|
||||
self.value_net = LinearNet(
|
||||
layers=[self.base.feature_net.output_shape + n_action, self.base.feature_net.output_shape],
|
||||
activation=activation,
|
||||
layer_norm=layer_norm,
|
||||
linear_layer=linear_layer
|
||||
)
|
||||
self.value_net2 = nn.Linear(self.base.feature_net.output_shape, n_observation)
|
||||
self.init_weights(init_w)
|
||||
|
||||
def init_weights(self, init_w):
|
||||
self.value_net.weight.data.uniform_(-init_w, init_w)
|
||||
self.value_net2.weight.data.uniform_(-init_w, init_w)
|
||||
# self.value_net.weight.data.uniform_(-init_w, init_w)
|
||||
for layer in self.value_net.net:
|
||||
if isinstance(layer, nn.Linear):
|
||||
layer.weight.data.uniform_(-init_w, init_w)
|
||||
|
||||
def forward(self, observation, action):
|
||||
action = to_torch_variable(action)
|
||||
x = self.base.forward(observation)
|
||||
x = torch.cat((x, action), dim=1)
|
||||
x = self.value_net.forward(x)
|
||||
x = self.value_net2.forward(x)
|
||||
return x
|
||||
|
||||
+11
-3
@@ -31,7 +31,7 @@ def parse_args():
|
||||
parser.add_argument('--reward-scale', type=float, default=1.)
|
||||
boolean_flag(parser, "flip-state-action", default=False)
|
||||
|
||||
for agent in ["actor", "critic"]:
|
||||
for agent in ["actor", "critic", "dynamics"]:
|
||||
parser.add_argument('--{}-layers'.format(agent), type=str, default="64-64")
|
||||
parser.add_argument('--{}-activation'.format(agent), type=str, default="relu")
|
||||
boolean_flag(parser, "{}-layer-norm".format(agent), default=False)
|
||||
@@ -195,10 +195,18 @@ def train(args, model_fn, act_update_fns, multi_thread, train_single, play_singl
|
||||
best_reward = Value("f", 0.0)
|
||||
|
||||
try:
|
||||
|
||||
if args.num_threads == args.num_train_threads == 1:
|
||||
# run a single thread in the foreground so we can debug easier
|
||||
# # run a single thread in the foreground so we can debug easier
|
||||
# args.thread = 1
|
||||
# multi_thread(actor, critic, dynamics, target_actor, target_critic, target_dynamics, args, act_update_fns, best_reward)
|
||||
|
||||
# or debug the single thread funcs
|
||||
args.thread = 1
|
||||
multi_thread(actor, critic, dynamics, target_actor, target_critic, target_dynamics, args, act_update_fns, best_reward)
|
||||
global_episode = Value("i", 0)
|
||||
global_update_step = Value("i", 0)
|
||||
episodes_queue = mp.Queue()
|
||||
train_single(actor, critic, dynamics, target_actor, target_critic, target_dynamics, args, act_update_fns, global_episode, global_update_step, episodes_queue)
|
||||
elif args.num_threads == args.num_train_threads:
|
||||
for rank in range(args.num_threads):
|
||||
args.thread = rank
|
||||
|
||||
Reference in New Issue
Block a user