mirror of
https://github.com/wassname/Run-Skeleton-Run.git
synced 2026-06-27 19:45:48 +08:00
clean up nets and args
This commit is contained in:
+15
-109
@@ -21,90 +21,12 @@ def fanin_init(size, fanin=None):
|
||||
return torch.Tensor(size).uniform_(-v, v)
|
||||
|
||||
|
||||
class Actor(nn.Module):
|
||||
def __init__(self, n_observation, n_action,
|
||||
layers, activation=torch.nn.ELU,
|
||||
layer_norm=False,
|
||||
parameters_noise=False, parameters_noise_factorised=False,
|
||||
last_activation=torch.nn.Tanh, init_w=3e-3):
|
||||
super(Actor, self).__init__()
|
||||
|
||||
if parameters_noise:
|
||||
def linear_layer(x_in, x_out):
|
||||
return NoisyLinear(x_in, x_out, factorised=parameters_noise_factorised)
|
||||
else:
|
||||
linear_layer = nn.Linear
|
||||
|
||||
self.feature_net = LinearNet(
|
||||
layers=[n_observation] + layers,
|
||||
activation=activation,
|
||||
layer_norm=layer_norm,
|
||||
linear_layer=linear_layer)
|
||||
self.policy_net = LinearNet(
|
||||
layers=[self.feature_net.output_shape, n_action],
|
||||
activation=last_activation,
|
||||
layer_norm=False
|
||||
)
|
||||
self.init_weights(init_w)
|
||||
|
||||
def init_weights(self, init_w):
|
||||
for layer in self.feature_net.net:
|
||||
if isinstance(layer, nn.Linear):
|
||||
layer.weight.data = fanin_init(layer.weight.data.size())
|
||||
|
||||
for layer in self.feature_net.net:
|
||||
if isinstance(layer, nn.Linear):
|
||||
layer.weight.data.uniform_(-init_w, init_w)
|
||||
|
||||
def forward(self, observation):
|
||||
x = to_torch_variable(observation)
|
||||
x = self.feature_net.forward(x)
|
||||
x = self.policy_net.forward(x)
|
||||
return x
|
||||
|
||||
|
||||
class Critic(nn.Module):
|
||||
def __init__(self, n_observation, n_action,
|
||||
layers, activation=torch.nn.ELU,
|
||||
layer_norm=False,
|
||||
parameters_noise=False, parameters_noise_factorised=False,
|
||||
init_w=3e-3):
|
||||
super(Critic, self).__init__()
|
||||
|
||||
if parameters_noise:
|
||||
def linear_layer(x_in, x_out):
|
||||
return NoisyLinear(x_in, x_out, factorised=parameters_noise_factorised)
|
||||
else:
|
||||
linear_layer = nn.Linear
|
||||
|
||||
self.feature_net = LinearNet(
|
||||
layers=[n_observation + n_action] + layers,
|
||||
activation=activation,
|
||||
layer_norm=layer_norm,
|
||||
linear_layer=linear_layer)
|
||||
self.value_net = nn.Linear(self.feature_net.output_shape, 1)
|
||||
self.init_weights(init_w)
|
||||
|
||||
def init_weights(self, init_w):
|
||||
for layer in self.feature_net.net:
|
||||
if isinstance(layer, nn.Linear):
|
||||
layer.weight.data = fanin_init(layer.weight.data.size())
|
||||
|
||||
self.value_net.weight.data.uniform_(-init_w, init_w)
|
||||
|
||||
def forward(self, observation, action):
|
||||
x = torch.cat((observation, action), dim=1)
|
||||
x = self.feature_net.forward(x)
|
||||
x = self.value_net.forward(x)
|
||||
return x
|
||||
|
||||
|
||||
class Base(nn.Module):
|
||||
def __init__(self, n_observation, n_action,
|
||||
layers, activation=torch.nn.ELU,
|
||||
layer_norm=False,
|
||||
parameters_noise=False, parameters_noise_factorised=False,
|
||||
last_activation=torch.nn.Tanh, init_w=3e-3):
|
||||
init_w=3e-3):
|
||||
super(Base, self).__init__()
|
||||
|
||||
if parameters_noise:
|
||||
@@ -136,9 +58,6 @@ class Base(nn.Module):
|
||||
|
||||
class CriticHead(nn.Module):
|
||||
def __init__(self, base, n_observation, n_action,
|
||||
layers, activation=torch.nn.ELU,
|
||||
layer_norm=False,
|
||||
parameters_noise=False, parameters_noise_factorised=False,
|
||||
init_w=3e-3):
|
||||
super(CriticHead, self).__init__()
|
||||
self.base = base
|
||||
@@ -156,16 +75,15 @@ class CriticHead(nn.Module):
|
||||
|
||||
class ActorHead(nn.Module):
|
||||
def __init__(self, base, n_observation, n_action,
|
||||
layers, activation=torch.nn.ELU,
|
||||
layer_norm=False,
|
||||
parameters_noise=False, parameters_noise_factorised=False,
|
||||
last_activation=lambda x: x, init_w=3e-3):
|
||||
last_activation=torch.nn.Tanh,
|
||||
action_scale=1, init_w=3e-3):
|
||||
super(ActorHead, self).__init__()
|
||||
self.base = base
|
||||
self.action_scale = action_scale
|
||||
|
||||
self.policy_net = LinearNet(
|
||||
layers=[self.base.feature_net.output_shape, n_action],
|
||||
# activation=last_activation,
|
||||
activation=last_activation,
|
||||
layer_norm=False
|
||||
)
|
||||
self.init_weights(init_w)
|
||||
@@ -178,39 +96,28 @@ class ActorHead(nn.Module):
|
||||
def forward(self, observation):
|
||||
x = observation
|
||||
x = self.base.forward(x)
|
||||
x = self.policy_net.forward(x)
|
||||
x = self.policy_net.forward(x) * self.action_scale
|
||||
return x
|
||||
|
||||
|
||||
class DynamicsHead(nn.Module):
|
||||
def __init__(self, base, n_observation, n_action,
|
||||
layers, activation=torch.nn.ELU,
|
||||
layer_norm=False,
|
||||
parameters_noise=False, parameters_noise_factorised=False,
|
||||
last_activation=torch.nn.Tanh,
|
||||
init_w=3e-3):
|
||||
super(DynamicsHead, self).__init__()
|
||||
self.base = base
|
||||
|
||||
if parameters_noise:
|
||||
def linear_layer(x_in, x_out):
|
||||
return NoisyLinear(x_in, x_out, factorised=parameters_noise_factorised)
|
||||
else:
|
||||
linear_layer = nn.Linear
|
||||
|
||||
# self.value_net = nn.Linear(self.base.feature_net.output_shape + n_action, self.base.feature_net.output_shape)
|
||||
self.value_net = LinearNet(
|
||||
layers=[self.base.feature_net.output_shape + n_action, self.base.feature_net.output_shape],
|
||||
activation=activation,
|
||||
layer_norm=layer_norm,
|
||||
linear_layer=linear_layer
|
||||
self.dynamics_net = LinearNet(
|
||||
layers=[self.base.feature_net.output_shape+n_action, n_observation],
|
||||
activation=last_activation,
|
||||
layer_norm=False
|
||||
)
|
||||
self.value_net2 = nn.Linear(self.base.feature_net.output_shape, n_observation)
|
||||
# self.dynamics_net = nn.Linear(self.base.feature_net.output_shape+n_action, n_observation)
|
||||
self.init_weights(init_w)
|
||||
|
||||
def init_weights(self, init_w):
|
||||
self.value_net2.weight.data.uniform_(-init_w, init_w)
|
||||
# self.value_net.weight.data.uniform_(-init_w, init_w)
|
||||
for layer in self.value_net.net:
|
||||
# self.dynamics_net.weight.data.uniform_(-init_w, init_w)
|
||||
for layer in self.dynamics_net.net:
|
||||
if isinstance(layer, nn.Linear):
|
||||
layer.weight.data.uniform_(-init_w, init_w)
|
||||
|
||||
@@ -218,6 +125,5 @@ class DynamicsHead(nn.Module):
|
||||
action = to_torch_variable(action)
|
||||
x = self.base.forward(observation)
|
||||
x = torch.cat((x, action), dim=1)
|
||||
x = self.value_net.forward(x)
|
||||
x = self.value_net2.forward(x)
|
||||
x = self.dynamics_net.forward(x)
|
||||
return x
|
||||
|
||||
+4
-5
@@ -32,13 +32,14 @@ def parse_args():
|
||||
boolean_flag(parser, "flip-state-action", default=False)
|
||||
boolean_flag(parser, 'debug', default=False, help="Run in single threaded mode for debugging")
|
||||
|
||||
for agent in ["actor", "critic", "dynamics"]:
|
||||
for agent in ["base"]:
|
||||
parser.add_argument('--{}-layers'.format(agent), type=str, default="64-64")
|
||||
parser.add_argument('--{}-activation'.format(agent), type=str, default="relu")
|
||||
boolean_flag(parser, "{}-layer-norm".format(agent), default=False)
|
||||
boolean_flag(parser, "{}-parameters-noise".format(agent), default=False)
|
||||
boolean_flag(parser, "{}-parameters-noise-factorised".format(agent), default=False)
|
||||
|
||||
for agent in ["actor", "critic", "dynamics"]:
|
||||
parser.add_argument('--{}-lr'.format(agent), type=float, default=1e-3)
|
||||
parser.add_argument('--{}-lr-end'.format(agent), type=float, default=5e-5)
|
||||
|
||||
@@ -157,11 +158,9 @@ def train(args, model_fn, act_update_fns, multi_thread, train_single, play_singl
|
||||
args.n_action = env.action_space.shape[0]
|
||||
args.n_observation = env.observation_space.shape[0]
|
||||
|
||||
args.actor_layers = str2params(args.actor_layers)
|
||||
args.critic_layers = str2params(args.critic_layers)
|
||||
args.base_layers = str2params(args.base_layers)
|
||||
|
||||
args.actor_activation = activations[args.actor_activation]
|
||||
args.critic_activation = activations[args.critic_activation]
|
||||
args.base_activation = activations[args.base_activation]
|
||||
|
||||
actor, critic, dynamics = model_fn(args)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user