Files
Run-Skeleton-Run/common/torch_util.py
Kolesnikov Sergey 7401266fe7 pytorch version
2017-11-15 22:18:46 +03:00

38 lines
1015 B
Python

import torch
from torch.autograd import Variable
# Evaluated once at import time: whether torch reports CUDA as available.
USE_CUDA = torch.cuda.is_available()

# Default float tensor type used by the helpers in this module: the CUDA
# variant when CUDA is available, the CPU variant otherwise.
if USE_CUDA:
    FLOAT = torch.cuda.FloatTensor
else:
    FLOAT = torch.FloatTensor
def to_numpy(var):
    """Return the values held by ``var`` as a NumPy array.

    Args:
        var: A ``Variable``/tensor exposing ``.data``; it may live on the
            CPU or on a CUDA device.

    Returns:
        The result of ``.numpy()`` on the host-side data of ``var``.
    """
    # ``.cpu()`` hands back the tensor unchanged when it is already in host
    # memory, so there is no need to branch on the module-level USE_CUDA
    # flag. Taking ``.data`` first keeps the device copy out of autograd.
    return var.data.cpu().numpy()
def to_tensor(ndarray, volatile=False, requires_grad=False, dtype=FLOAT):
    """Wrap a NumPy array in a ``Variable`` of the requested tensor type.

    The array is converted to ``dtype`` *before* it is wrapped. Converting
    afterwards makes the returned object the output of a cast operation
    rather than an autograd leaf whenever the cast actually changes the
    tensor type, so with ``requires_grad=True`` its ``.grad`` would never
    be populated.

    Args:
        ndarray: NumPy array handed to ``torch.from_numpy``.
        volatile: Forwarded unchanged to ``Variable``.
        requires_grad: Forwarded unchanged to ``Variable``.
        dtype: Tensor type passed to ``Tensor.type``; defaults to the
            module-level ``FLOAT``.

    Returns:
        A ``Variable`` holding the array's values as ``dtype``.
    """
    tensor = torch.from_numpy(ndarray).type(dtype)
    return Variable(tensor, volatile=volatile, requires_grad=requires_grad)
def soft_update(target, source, tau):
    """Blend the parameters of ``source`` into ``target`` in place.

    Every target parameter is overwritten with
    ``target * (1 - tau) + source * tau``. Parameters are paired in the
    order returned by each object's ``parameters()``; pairing stops at the
    shorter of the two sequences.

    Args:
        target: Object whose ``parameters()`` are updated.
        source: Object whose ``parameters()`` supply the new values.
        tau: Interpolation weight given to the source parameters.
    """
    pairs = zip(target.parameters(), source.parameters())
    for dst, src in pairs:
        retained = dst.data * (1.0 - tau)
        incoming = src.data * tau
        dst.data.copy_(retained + incoming)
def hard_update(target, source):
    """Copy every parameter of ``source`` into ``target`` in place.

    Parameters are paired in the order returned by each object's
    ``parameters()``; pairing stops at the shorter of the two sequences.

    Args:
        target: Object whose ``parameters()`` are overwritten.
        source: Object whose ``parameters()`` supply the values.
    """
    pairs = zip(target.parameters(), source.parameters())
    for dst, src in pairs:
        dst.data.copy_(src.data)
# Lookup table from a lowercase activation name to the corresponding
# ``torch.nn`` module class (the class itself, not an instance).
activations = dict((
    ("relu", torch.nn.ReLU),
    ("elu", torch.nn.ELU),
    ("leakyrelu", torch.nn.LeakyReLU),
    ("selu", torch.nn.SELU),
    ("sigmoid", torch.nn.Sigmoid),
    ("tanh", torch.nn.Tanh),
))