add dynamics lr

This commit is contained in:
wassname
2018-01-19 12:08:32 +08:00
parent a87a3ad7bb
commit 201fa2400f
3 changed files with 57 additions and 15 deletions
+23 -6
View File
@@ -75,7 +75,7 @@ def create_act_update_fns(actor, critic, dynamics, target_actor, target_critic,
def update_fn(
observations, actions, rewards, next_observations, dones, weights,
actor_lr=1e-4, critic_lr=1e-3):
actor_lr=1e-4, critic_lr=1e-3, dynamics_lr=1e-4):
nonlocal actor, critic, dynamics, target_actor, target_critic, target_dynamics, actor_optim, critic_optim, dynamics_optim
if hasattr(args, "flip_states"):
@@ -112,7 +112,7 @@ def create_act_update_fns(actor, critic, dynamics, target_actor, target_critic,
dynamics_loss.backward()
torch.nn.utils.clip_grad_norm(dynamics.parameters(), args.grad_clip)
for param_group in actor_optim.param_groups:
param_group["lr"] = actor_lr # TODO change to dynamics lr
param_group["lr"] = dynamics_lr # TODO change to dynamics lr
dynamics_optim.step()
# Critic update
@@ -172,7 +172,8 @@ def create_act_update_fns(actor, critic, dynamics, target_actor, target_critic,
metrics = {
"value_loss": value_loss,
"policy_loss": policy_loss
"policy_loss": policy_loss,
"dynamics_loss": dynamics_loss
}
td_v_values = critic(
@@ -233,9 +234,14 @@ def train_multi_thread(actor, critic, dynamics, target_actor, target_critic, tar
max_step=args.max_episodes)
critic_learning_rate_decay_fn = create_decay_fn(
"linear",
initial_value=args.critic_lr,
initial_value=arstep_metricsgs.critic_lr,
final_value=args.critic_lr_end,
max_step=args.max_episodes)
dynamics_learning_rate_decay_fn = create_decay_fn(
"linear",
initial_value=args.dynamics_lr,
final_value=args.dynamics_lr_end,
max_step=args.max_episodes)
epsilon_cycle_len = random.randint(args.epsilon_cycle_len // 2, args.epsilon_cycle_len * 2)
@@ -256,11 +262,13 @@ def train_multi_thread(actor, critic, dynamics, target_actor, target_critic, tar
actor_lr = actor_learning_rate_decay_fn(episode)
critic_lr = critic_learning_rate_decay_fn(episode)
dynamics_lr = dynamics_learning_rate_decay_fn(episode)
epsilon = min(args.initial_epsilon, max(args.final_epsilon, epsilon_decay_fn(episode)))
episode_metrics = {
"value_loss": 0.0,
"policy_loss": 0.0,
"dynamics_loss": 0.0,
"reward": 0.0,
"step": 0,
"epsilon": epsilon
@@ -293,7 +301,7 @@ def train_multi_thread(actor, critic, dynamics, target_actor, target_critic, tar
step_metrics, step_info = update_fn(
tr_observations, tr_actions, tr_rewards,
tr_next_observations, tr_dones,
weights, actor_lr, critic_lr)
weights, actor_lr, critic_lr, dynamics_lr)
if args.prioritized_replay:
new_priorities = np.abs(step_info["td_error"]) + 1e-6
@@ -329,6 +337,7 @@ def train_multi_thread(actor, critic, dynamics, target_actor, target_critic, tar
episode)
logger.scalar_summary("actor lr", actor_lr, episode)
logger.scalar_summary("critic lr", critic_lr, episode)
logger.scalar_summary("dynamics_lr", dynamics_lr, episode)
if episode % args.save_step == 0:
save_fn(episode)
@@ -373,6 +382,11 @@ def train_single_thread(
initial_value=args.critic_lr,
final_value=args.critic_lr_end,
max_step=args.max_update_steps)
dynamics_learning_rate_decay_fn = create_decay_fn(
"linear",
initial_value=args.dynamics_lr,
final_value=args.dynamics_lr_end,
max_step=args.max_update_steps)
update_step = 0
received_examples = 1 # just hack
@@ -380,8 +394,10 @@ def train_single_thread(
and global_update_step.value < args.max_update_steps * args.num_train_threads:
actor_lr = actor_learning_rate_decay_fn(update_step)
critic_lr = critic_learning_rate_decay_fn(update_step)
dynamics_lr = dynamics_learning_rate_decay_fn(update_step)
actor_lr = min(args.actor_lr, max(args.actor_lr_end, actor_lr))
dynamics_lr = min(args.dynamics_lr, max(args.dynamics_lr_end, dynamics_lr))
critic_lr = min(args.critic_lr, max(args.critic_lr_end, critic_lr))
while True:
@@ -410,7 +426,7 @@ def train_single_thread(
step_metrics, step_info = update_fn(
tr_observations, tr_actions, tr_rewards,
tr_next_observations, tr_dones,
weights, actor_lr, critic_lr)
weights, actor_lr, critic_lr, dynamics_lr)
update_step += 1
global_update_step.value += 1
@@ -425,6 +441,7 @@ def train_single_thread(
logger.scalar_summary("actor lr", actor_lr, update_step)
logger.scalar_summary("critic lr", critic_lr, update_step)
logger.scalar_summary("dynamics lr", dynamics_lr, update_step)
if update_step % args.save_step == 0:
save_fn(update_step)
+23 -6
View File
@@ -154,19 +154,18 @@ class CriticHead(nn.Module):
return x
class ActorHead(nn.Module):
def __init__(self, base, n_observation, n_action,
layers, activation=torch.nn.ELU,
layer_norm=False,
parameters_noise=False, parameters_noise_factorised=False,
last_activation=torch.nn.Tanh, init_w=3e-3):
last_activation=lambda x: x, init_w=3e-3):
super(ActorHead, self).__init__()
self.base = base
self.policy_net = LinearNet(
layers=[self.base.feature_net.output_shape, n_action],
activation=last_activation,
# activation=last_activation,
layer_norm=False
)
self.init_weights(init_w)
@@ -183,7 +182,6 @@ class ActorHead(nn.Module):
return x
class DynamicsHead(nn.Module):
def __init__(self, base, n_observation, n_action,
layers, activation=torch.nn.ELU,
@@ -192,15 +190,34 @@ class DynamicsHead(nn.Module):
init_w=3e-3):
super(DynamicsHead, self).__init__()
self.base = base
self.value_net = nn.Linear(self.base.feature_net.output_shape + n_action, n_observation)
if parameters_noise:
def linear_layer(x_in, x_out):
return NoisyLinear(x_in, x_out, factorised=parameters_noise_factorised)
else:
linear_layer = nn.Linear
# self.value_net = nn.Linear(self.base.feature_net.output_shape + n_action, self.base.feature_net.output_shape)
self.value_net = LinearNet(
layers=[self.base.feature_net.output_shape + n_action, self.base.feature_net.output_shape],
activation=activation,
layer_norm=layer_norm,
linear_layer=linear_layer
)
self.value_net2 = nn.Linear(self.base.feature_net.output_shape, n_observation)
self.init_weights(init_w)
def init_weights(self, init_w):
self.value_net.weight.data.uniform_(-init_w, init_w)
self.value_net2.weight.data.uniform_(-init_w, init_w)
# self.value_net.weight.data.uniform_(-init_w, init_w)
for layer in self.value_net.net:
if isinstance(layer, nn.Linear):
layer.weight.data.uniform_(-init_w, init_w)
def forward(self, observation, action):
action = to_torch_variable(action)
x = self.base.forward(observation)
x = torch.cat((x, action), dim=1)
x = self.value_net.forward(x)
x = self.value_net2.forward(x)
return x
+11 -3
View File
@@ -31,7 +31,7 @@ def parse_args():
parser.add_argument('--reward-scale', type=float, default=1.)
boolean_flag(parser, "flip-state-action", default=False)
for agent in ["actor", "critic"]:
for agent in ["actor", "critic", "dynamics"]:
parser.add_argument('--{}-layers'.format(agent), type=str, default="64-64")
parser.add_argument('--{}-activation'.format(agent), type=str, default="relu")
boolean_flag(parser, "{}-layer-norm".format(agent), default=False)
@@ -195,10 +195,18 @@ def train(args, model_fn, act_update_fns, multi_thread, train_single, play_singl
best_reward = Value("f", 0.0)
try:
if args.num_threads == args.num_train_threads == 1:
# run a single thread in the foreground so we can debug easier
# # run a single thread in the foreground so we can debug easier
# args.thread = 1
# multi_thread(actor, critic, dynamics, target_actor, target_critic, target_dynamics, args, act_update_fns, best_reward)
# or debug the single thread funcs
args.thread = 1
multi_thread(actor, critic, dynamics, target_actor, target_critic, target_dynamics, args, act_update_fns, best_reward)
global_episode = Value("i", 0)
global_update_step = Value("i", 0)
episodes_queue = mp.Queue()
train_single(actor, critic, dynamics, target_actor, target_critic, target_dynamics, args, act_update_fns, global_episode, global_update_step, episodes_queue)
elif args.num_threads == args.num_train_threads:
for rank in range(args.num_threads):
args.thread = rank