diff --git a/ddpg/model.py b/ddpg/model.py index 4043191..e40dd5f 100644 --- a/ddpg/model.py +++ b/ddpg/model.py @@ -75,7 +75,7 @@ def create_act_update_fns(actor, critic, dynamics, target_actor, target_critic, def update_fn( observations, actions, rewards, next_observations, dones, weights, - actor_lr=1e-4, critic_lr=1e-3): + actor_lr=1e-4, critic_lr=1e-3, dynamics_lr=1e-4): nonlocal actor, critic, dynamics, target_actor, target_critic, target_dynamics, actor_optim, critic_optim, dynamics_optim if hasattr(args, "flip_states"): @@ -112,7 +112,7 @@ def create_act_update_fns(actor, critic, dynamics, target_actor, target_critic, dynamics_loss.backward() torch.nn.utils.clip_grad_norm(dynamics.parameters(), args.grad_clip) for param_group in actor_optim.param_groups: - param_group["lr"] = actor_lr # TODO change to dynamics lr + param_group["lr"] = dynamics_lr # TODO change to dynamics lr dynamics_optim.step() # Critic update @@ -172,7 +172,8 @@ def create_act_update_fns(actor, critic, dynamics, target_actor, target_critic, metrics = { "value_loss": value_loss, - "policy_loss": policy_loss + "policy_loss": policy_loss, + "dynamics_loss": dynamics_loss } td_v_values = critic( @@ -233,9 +234,14 @@ def train_multi_thread(actor, critic, dynamics, target_actor, target_critic, tar max_step=args.max_episodes) critic_learning_rate_decay_fn = create_decay_fn( "linear", - initial_value=args.critic_lr, + initial_value=arstep_metricsgs.critic_lr, final_value=args.critic_lr_end, max_step=args.max_episodes) + dynamics_learning_rate_decay_fn = create_decay_fn( + "linear", + initial_value=args.dynamics_lr, + final_value=args.dynamics_lr_end, + max_step=args.max_episodes) epsilon_cycle_len = random.randint(args.epsilon_cycle_len // 2, args.epsilon_cycle_len * 2) @@ -256,11 +262,13 @@ def train_multi_thread(actor, critic, dynamics, target_actor, target_critic, tar actor_lr = actor_learning_rate_decay_fn(episode) critic_lr = critic_learning_rate_decay_fn(episode) + dynamics_lr = dynamics_learning_rate_decay_fn(episode) epsilon = min(args.initial_epsilon, max(args.final_epsilon, epsilon_decay_fn(episode))) episode_metrics = { "value_loss": 0.0, "policy_loss": 0.0, + "dynamics_loss": 0.0, "reward": 0.0, "step": 0, "epsilon": epsilon @@ -293,7 +301,7 @@ def train_multi_thread(actor, critic, dynamics, target_actor, target_critic, tar step_metrics, step_info = update_fn( tr_observations, tr_actions, tr_rewards, tr_next_observations, tr_dones, - weights, actor_lr, critic_lr) + weights, actor_lr, critic_lr, dynamics_lr) if args.prioritized_replay: new_priorities = np.abs(step_info["td_error"]) + 1e-6 @@ -329,6 +337,7 @@ def train_multi_thread(actor, critic, dynamics, target_actor, target_critic, tar episode) logger.scalar_summary("actor lr", actor_lr, episode) logger.scalar_summary("critic lr", critic_lr, episode) + logger.scalar_summary("dynamics_lr", dynamics_lr, episode) if episode % args.save_step == 0: save_fn(episode) @@ -373,6 +382,11 @@ def train_single_thread( initial_value=args.critic_lr, final_value=args.critic_lr_end, max_step=args.max_update_steps) + dynamics_learning_rate_decay_fn = create_decay_fn( + "linear", + initial_value=args.dynamics_lr, + final_value=args.dynamics_lr_end, + max_step=args.max_update_steps) update_step = 0 received_examples = 1 # just hack @@ -380,8 +394,10 @@ def train_single_thread( and global_update_step.value < args.max_update_steps * args.num_train_threads: actor_lr = actor_learning_rate_decay_fn(update_step) critic_lr = critic_learning_rate_decay_fn(update_step) + dynamics_lr = dynamics_learning_rate_decay_fn(update_step) actor_lr = min(args.actor_lr, max(args.actor_lr_end, actor_lr)) + dynamics_lr = min(args.dynamics_lr, max(args.dynamics_lr_end, dynamics_lr)) critic_lr = min(args.critic_lr, max(args.critic_lr_end, critic_lr)) while True: @@ -410,7 +426,7 @@ def train_single_thread( step_metrics, step_info = update_fn( tr_observations, tr_actions, tr_rewards, tr_next_observations, tr_dones, - weights, actor_lr, critic_lr) + weights, actor_lr, critic_lr, dynamics_lr) update_step += 1 global_update_step.value += 1 @@ -425,6 +441,7 @@ def train_single_thread( logger.scalar_summary("actor lr", actor_lr, update_step) logger.scalar_summary("critic lr", critic_lr, update_step) + logger.scalar_summary("dynamics lr", dynamics_lr, update_step) if update_step % args.save_step == 0: save_fn(update_step) diff --git a/ddpg/nets.py b/ddpg/nets.py index 96cacd4..1b11856 100644 --- a/ddpg/nets.py +++ b/ddpg/nets.py @@ -154,19 +154,18 @@ class CriticHead(nn.Module): return x - class ActorHead(nn.Module): def __init__(self, base, n_observation, n_action, layers, activation=torch.nn.ELU, layer_norm=False, parameters_noise=False, parameters_noise_factorised=False, - last_activation=torch.nn.Tanh, init_w=3e-3): + last_activation=lambda x: x, init_w=3e-3): super(ActorHead, self).__init__() self.base = base self.policy_net = LinearNet( layers=[self.base.feature_net.output_shape, n_action], - activation=last_activation, + # activation=last_activation, layer_norm=False ) self.init_weights(init_w) @@ -183,7 +182,6 @@ class ActorHead(nn.Module): return x - class DynamicsHead(nn.Module): def __init__(self, base, n_observation, n_action, layers, activation=torch.nn.ELU, @@ -192,15 +190,34 @@ class DynamicsHead(nn.Module): init_w=3e-3): super(DynamicsHead, self).__init__() self.base = base - self.value_net = nn.Linear(self.base.feature_net.output_shape + n_action, n_observation) + + if parameters_noise: + def linear_layer(x_in, x_out): + return NoisyLinear(x_in, x_out, factorised=parameters_noise_factorised) + else: + linear_layer = nn.Linear + + # self.value_net = nn.Linear(self.base.feature_net.output_shape + n_action, self.base.feature_net.output_shape) + self.value_net = LinearNet( + layers=[self.base.feature_net.output_shape + n_action, self.base.feature_net.output_shape], + activation=activation, + layer_norm=layer_norm, + linear_layer=linear_layer + ) + self.value_net2 = nn.Linear(self.base.feature_net.output_shape, n_observation) self.init_weights(init_w) def init_weights(self, init_w): - self.value_net.weight.data.uniform_(-init_w, init_w) + self.value_net2.weight.data.uniform_(-init_w, init_w) + # self.value_net.weight.data.uniform_(-init_w, init_w) + for layer in self.value_net.net: + if isinstance(layer, nn.Linear): + layer.weight.data.uniform_(-init_w, init_w) def forward(self, observation, action): action = to_torch_variable(action) x = self.base.forward(observation) x = torch.cat((x, action), dim=1) x = self.value_net.forward(x) + x = self.value_net2.forward(x) return x diff --git a/ddpg/train.py b/ddpg/train.py index d41b3f4..538836c 100644 --- a/ddpg/train.py +++ b/ddpg/train.py @@ -31,7 +31,7 @@ def parse_args(): parser.add_argument('--reward-scale', type=float, default=1.) boolean_flag(parser, "flip-state-action", default=False) - for agent in ["actor", "critic"]: + for agent in ["actor", "critic", "dynamics"]: parser.add_argument('--{}-layers'.format(agent), type=str, default="64-64") parser.add_argument('--{}-activation'.format(agent), type=str, default="relu") boolean_flag(parser, "{}-layer-norm".format(agent), default=False) @@ -195,10 +195,18 @@ def train(args, model_fn, act_update_fns, multi_thread, train_single, play_singl best_reward = Value("f", 0.0) try: + if args.num_threads == args.num_train_threads == 1: - # run a single thread in the foreground so we can debug easier + # # run a single thread in the foreground so we can debug easier + # args.thread = 1 + # multi_thread(actor, critic, dynamics, target_actor, target_critic, target_dynamics, args, act_update_fns, best_reward) + + # or debug the single thread funcs args.thread = 1 - multi_thread(actor, critic, dynamics, target_actor, target_critic, target_dynamics, args, act_update_fns, best_reward) + global_episode = Value("i", 0) + global_update_step = Value("i", 0) + episodes_queue = mp.Queue() + train_single(actor, critic, dynamics, target_actor, target_critic, target_dynamics, args, act_update_fns, global_episode, global_update_step, episodes_queue) elif args.num_threads == args.num_train_threads: for rank in range(args.num_threads): args.thread = rank