Files
pytorch-a2c-ppo-acktr/model.py
T
Ilya Kostrikov 59890378f4 Initial commit
2017-09-07 19:45:57 -04:00

53 lines
1.5 KiB
Python
Executable File

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
def weights_init(m):
    """Orthogonally initialise a Conv*/Linear* module in place.

    Selection is by class name: any module whose class name contains
    ``Conv`` or ``Linear`` gets an orthogonal weight matrix and, when it has
    a bias, a zero bias.  Every other module is left untouched, which makes
    the function suitable for ``nn.Module.apply``.

    Args:
        m: The module to (possibly) initialise.
    """
    classname = m.__class__.__name__
    if 'Conv' not in classname and 'Linear' not in classname:
        return

    # ``nn.init.orthogonal`` (no trailing underscore) is the deprecated
    # spelling; ``orthogonal_`` is the supported in-place initialiser.
    nn.init.orthogonal_(m.weight.data)
    # Layers built with ``bias=False`` expose ``bias`` as None.
    if m.bias is not None:
        m.bias.data.fill_(0)
class ActorCritic(torch.nn.Module):
    """Convolutional actor-critic network with a shared trunk and two heads.

    Three conv layers followed by one fully connected layer feed both a
    single-output critic head and an actor head with ``action_space.n``
    outputs.  The fully connected layer is sized for a 64 x 7 x 7 conv
    output, which is what 84 x 84 inputs produce with the kernel sizes and
    strides used here.
    """

    # (channels, height, width) that conv3 must emit for linear1 to fit.
    _CONV_OUT_SHAPE = (64, 7, 7)
    _FLAT_FEATURES = 64 * 7 * 7

    def __init__(self, num_inputs, action_space):
        """Create the layers and initialise their parameters.

        Args:
            num_inputs: Number of input channels of the first conv layer.
            action_space: Object exposing ``n``; its value is the width of
                the actor head (presumably a discrete gym action space --
                confirm against callers).
        """
        super(ActorCritic, self).__init__()
        self.conv1 = nn.Conv2d(num_inputs, 32, 8, stride=4)
        self.conv2 = nn.Conv2d(32, 64, 4, stride=2)
        self.conv3 = nn.Conv2d(64, 64, 3, stride=1)
        self.linear1 = nn.Linear(self._FLAT_FEATURES, 512)

        num_outputs = action_space.n
        self.critic_linear = nn.Linear(512, 1)
        self.actor_linear = nn.Linear(512, num_outputs)

        self.apply(weights_init)

        # Layers that are followed by a ReLU get their freshly initialised
        # weights scaled by sqrt(2); the two heads are left unscaled.
        relu_gain = math.sqrt(2)
        for layer in (self.conv1, self.conv2, self.conv3, self.linear1):
            layer.weight.data.mul_(relu_gain)

        self.train()

    def forward(self, inputs):
        """Run the network.

        Args:
            inputs: Image batch with ``num_inputs`` channels.  Values are
                divided by 255 before the first conv layer (presumably raw
                0-255 pixels -- confirm against callers).

        Returns:
            A ``(critic, actor)`` tuple: the critic head output with one
            column and the actor head output with ``action_space.n``
            columns.  No softmax is applied here.

        Raises:
            RuntimeError: If the conv stack does not produce a 64 x 7 x 7
                feature map for the given input size.
        """
        x = F.relu(self.conv1(inputs / 255.0))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))

        # A bare ``view(-1, 3136)`` would silently fold extra spatial
        # features into the batch dimension for wrongly sized inputs, so
        # check the feature-map shape explicitly before flattening.
        if tuple(x.shape[-3:]) != self._CONV_OUT_SHAPE:
            raise RuntimeError(
                'ActorCritic expects a conv output of shape {} per sample, '
                'got {} (input shape {})'.format(
                    self._CONV_OUT_SHAPE,
                    tuple(x.shape[-3:]),
                    tuple(inputs.shape)))

        x = x.view(-1, self._FLAT_FEATURES)
        x = F.relu(self.linear1(x))
        return self.critic_linear(x), self.actor_linear(x)