clean up nets and args

This commit is contained in:
wassname
2018-01-21 12:35:29 +08:00
parent 6bb9c51403
commit c4065bb7db
2 changed files with 19 additions and 114 deletions
+15 -109
View File
@@ -21,90 +21,12 @@ def fanin_init(size, fanin=None):
return torch.Tensor(size).uniform_(-v, v)
class Actor(nn.Module):
    """Policy network: a ``LinearNet`` feature stack plus a one-layer policy head.

    Args:
        n_observation: width of the first layer of the feature net.
        n_action: output width of the policy head.
        layers: list of layer sizes appended after ``n_observation`` to form
            the feature net's layer list.
        activation: activation passed to the feature ``LinearNet``.
        layer_norm: forwarded to the feature ``LinearNet``.
        parameters_noise: when true, the feature net is built from
            ``NoisyLinear`` layers instead of ``nn.Linear``.
        parameters_noise_factorised: forwarded to ``NoisyLinear`` as
            ``factorised``; only used when ``parameters_noise`` is true.
        last_activation: activation passed to the policy-head ``LinearNet``.
        init_w: half-width of the uniform range used to initialise the
            policy head's ``nn.Linear`` weights.
    """

    def __init__(self, n_observation, n_action,
                 layers, activation=torch.nn.ELU,
                 layer_norm=False,
                 parameters_noise=False, parameters_noise_factorised=False,
                 last_activation=torch.nn.Tanh, init_w=3e-3):
        super(Actor, self).__init__()

        # Choose the factory LinearNet uses to build each (in, out) layer.
        if parameters_noise:
            def linear_layer(x_in, x_out):
                return NoisyLinear(x_in, x_out, factorised=parameters_noise_factorised)
        else:
            linear_layer = nn.Linear

        self.feature_net = LinearNet(
            layers=[n_observation] + layers,
            activation=activation,
            layer_norm=layer_norm,
            linear_layer=linear_layer)
        # The policy head never receives the noisy-layer factory; it relies on
        # LinearNet's default linear layer.
        self.policy_net = LinearNet(
            layers=[self.feature_net.output_shape, n_action],
            activation=last_activation,
            layer_norm=False
        )
        self.init_weights(init_w)

    def init_weights(self, init_w):
        """Initialise ``nn.Linear`` weights of the feature net and policy head.

        Feature-net layers get ``fanin_init``; policy-head layers are drawn
        uniformly from ``[-init_w, init_w]``. Only layers that are instances
        of ``nn.Linear`` are touched.
        """
        for layer in self.feature_net.net:
            if isinstance(layer, nn.Linear):
                layer.weight.data = fanin_init(layer.weight.data.size())

        # Fix: this loop previously walked feature_net.net a second time,
        # which overwrote the fan-in initialisation above and left the policy
        # head at its default initialisation. The small uniform range is meant
        # for the output layer, as in the sibling Critic / DynamicsHead.
        for layer in self.policy_net.net:
            if isinstance(layer, nn.Linear):
                layer.weight.data.uniform_(-init_w, init_w)

    def forward(self, observation):
        """Return the policy output for ``observation``.

        The observation is first converted with ``to_torch_variable``, then
        passed through the feature net and the policy head.
        """
        x = to_torch_variable(observation)
        x = self.feature_net.forward(x)
        x = self.policy_net.forward(x)
        return x
class Critic(nn.Module):
    """Value network over a concatenated (observation, action) input.

    A ``LinearNet`` feature stack whose first layer has width
    ``n_observation + n_action`` feeds a single ``nn.Linear`` that produces
    one output per row.

    Args:
        n_observation: observation width (added to ``n_action`` for the
            input layer size).
        n_action: action width.
        layers: list of layer sizes appended after the input size.
        activation: activation passed to the feature ``LinearNet``.
        layer_norm: forwarded to the feature ``LinearNet``.
        parameters_noise: when true, feature layers are built with
            ``NoisyLinear`` rather than ``nn.Linear``.
        parameters_noise_factorised: forwarded to ``NoisyLinear`` as
            ``factorised``.
        init_w: half-width of the uniform range for the output layer weights.
    """

    def __init__(self, n_observation, n_action,
                 layers, activation=torch.nn.ELU,
                 layer_norm=False,
                 parameters_noise=False, parameters_noise_factorised=False,
                 init_w=3e-3):
        super(Critic, self).__init__()

        # Factory handed to LinearNet for building each (in, out) layer.
        if not parameters_noise:
            make_layer = nn.Linear
        else:
            make_layer = lambda n_in, n_out: NoisyLinear(
                n_in, n_out, factorised=parameters_noise_factorised)

        input_and_hidden = [n_observation + n_action] + layers
        self.feature_net = LinearNet(
            layers=input_and_hidden,
            activation=activation,
            layer_norm=layer_norm,
            linear_layer=make_layer)

        # The output layer is always a plain nn.Linear with a single unit.
        self.value_net = nn.Linear(self.feature_net.output_shape, 1)
        self.init_weights(init_w)

    def init_weights(self, init_w):
        """Apply ``fanin_init`` to feature ``nn.Linear`` layers and a small
        uniform init in ``[-init_w, init_w]`` to the output layer."""
        plain_linears = [m for m in self.feature_net.net
                         if isinstance(m, nn.Linear)]
        for module in plain_linears:
            module.weight.data = fanin_init(module.weight.data.size())
        self.value_net.weight.data.uniform_(-init_w, init_w)

    def forward(self, observation, action):
        """Concatenate ``observation`` and ``action`` along dim 1 and return
        the output of the value layer applied to the extracted features."""
        joined = torch.cat([observation, action], dim=1)
        features = self.feature_net.forward(joined)
        return self.value_net.forward(features)
class Base(nn.Module):
def __init__(self, n_observation, n_action,
layers, activation=torch.nn.ELU,
layer_norm=False,
parameters_noise=False, parameters_noise_factorised=False,
last_activation=torch.nn.Tanh, init_w=3e-3):
init_w=3e-3):
super(Base, self).__init__()
if parameters_noise:
@@ -136,9 +58,6 @@ class Base(nn.Module):
class CriticHead(nn.Module):
def __init__(self, base, n_observation, n_action,
layers, activation=torch.nn.ELU,
layer_norm=False,
parameters_noise=False, parameters_noise_factorised=False,
init_w=3e-3):
super(CriticHead, self).__init__()
self.base = base
@@ -156,16 +75,15 @@ class CriticHead(nn.Module):
class ActorHead(nn.Module):
def __init__(self, base, n_observation, n_action,
layers, activation=torch.nn.ELU,
layer_norm=False,
parameters_noise=False, parameters_noise_factorised=False,
last_activation=lambda x: x, init_w=3e-3):
last_activation=torch.nn.Tanh,
action_scale=1, init_w=3e-3):
super(ActorHead, self).__init__()
self.base = base
self.action_scale = action_scale
self.policy_net = LinearNet(
layers=[self.base.feature_net.output_shape, n_action],
# activation=last_activation,
activation=last_activation,
layer_norm=False
)
self.init_weights(init_w)
@@ -178,39 +96,28 @@ class ActorHead(nn.Module):
def forward(self, observation):
x = observation
x = self.base.forward(x)
x = self.policy_net.forward(x)
x = self.policy_net.forward(x) * self.action_scale
return x
class DynamicsHead(nn.Module):
def __init__(self, base, n_observation, n_action,
layers, activation=torch.nn.ELU,
layer_norm=False,
parameters_noise=False, parameters_noise_factorised=False,
last_activation=torch.nn.Tanh,
init_w=3e-3):
super(DynamicsHead, self).__init__()
self.base = base
if parameters_noise:
def linear_layer(x_in, x_out):
return NoisyLinear(x_in, x_out, factorised=parameters_noise_factorised)
else:
linear_layer = nn.Linear
# self.value_net = nn.Linear(self.base.feature_net.output_shape + n_action, self.base.feature_net.output_shape)
self.value_net = LinearNet(
layers=[self.base.feature_net.output_shape + n_action, self.base.feature_net.output_shape],
activation=activation,
layer_norm=layer_norm,
linear_layer=linear_layer
self.dynamics_net = LinearNet(
layers=[self.base.feature_net.output_shape+n_action, n_observation],
activation=last_activation,
layer_norm=False
)
self.value_net2 = nn.Linear(self.base.feature_net.output_shape, n_observation)
# self.dynamics_net = nn.Linear(self.base.feature_net.output_shape+n_action, n_observation)
self.init_weights(init_w)
def init_weights(self, init_w):
self.value_net2.weight.data.uniform_(-init_w, init_w)
# self.value_net.weight.data.uniform_(-init_w, init_w)
for layer in self.value_net.net:
# self.dynamics_net.weight.data.uniform_(-init_w, init_w)
for layer in self.dynamics_net.net:
if isinstance(layer, nn.Linear):
layer.weight.data.uniform_(-init_w, init_w)
@@ -218,6 +125,5 @@ class DynamicsHead(nn.Module):
action = to_torch_variable(action)
x = self.base.forward(observation)
x = torch.cat((x, action), dim=1)
x = self.value_net.forward(x)
x = self.value_net2.forward(x)
x = self.dynamics_net.forward(x)
return x
+4 -5
View File
@@ -32,13 +32,14 @@ def parse_args():
boolean_flag(parser, "flip-state-action", default=False)
boolean_flag(parser, 'debug', default=False, help="Run in single threaded mode for debugging")
for agent in ["actor", "critic", "dynamics"]:
for agent in ["base"]:
parser.add_argument('--{}-layers'.format(agent), type=str, default="64-64")
parser.add_argument('--{}-activation'.format(agent), type=str, default="relu")
boolean_flag(parser, "{}-layer-norm".format(agent), default=False)
boolean_flag(parser, "{}-parameters-noise".format(agent), default=False)
boolean_flag(parser, "{}-parameters-noise-factorised".format(agent), default=False)
for agent in ["actor", "critic", "dynamics"]:
parser.add_argument('--{}-lr'.format(agent), type=float, default=1e-3)
parser.add_argument('--{}-lr-end'.format(agent), type=float, default=5e-5)
@@ -157,11 +158,9 @@ def train(args, model_fn, act_update_fns, multi_thread, train_single, play_singl
args.n_action = env.action_space.shape[0]
args.n_observation = env.observation_space.shape[0]
args.actor_layers = str2params(args.actor_layers)
args.critic_layers = str2params(args.critic_layers)
args.base_layers = str2params(args.base_layers)
args.actor_activation = activations[args.actor_activation]
args.critic_activation = activations[args.critic_activation]
args.base_activation = activations[args.base_activation]
actor, critic, dynamics = model_fn(args)