mirror of
https://github.com/wassname/Run-Skeleton-Run.git
synced 2026-06-27 16:13:51 +08:00
38 lines
1015 B
Python
38 lines
1015 B
Python
import torch
|
|
from torch.autograd import Variable
|
|
|
|
# True when torch reports a usable CUDA runtime in this process.
USE_CUDA = torch.cuda.is_available()

# Default floating-point tensor type: the CUDA variant when CUDA is
# available, the host variant otherwise.
if USE_CUDA:
    FLOAT = torch.cuda.FloatTensor
else:
    FLOAT = torch.FloatTensor
|
|
|
|
|
|
def to_numpy(var):
    """Return the values held by a tensor (or legacy ``Variable``) as a NumPy array.

    Args:
        var: Tensor-like object exposing ``.data``; it may live on the CPU
            or on a CUDA device.

    Returns:
        The ``numpy.ndarray`` produced by calling ``.numpy()`` on the
        host-memory copy of ``var.data``.
    """
    # ``.data`` strips autograd tracking first, so the device transfer is not
    # recorded in the graph. ``.cpu()`` returns the tensor itself when it is
    # already in host memory, so no branch on the global CUDA flag is needed.
    return var.data.cpu().numpy()
|
|
|
|
|
|
def to_tensor(ndarray, volatile=False, requires_grad=False, dtype=FLOAT):
    """Convert a NumPy array into a ``Variable`` of the requested tensor type.

    Args:
        ndarray: NumPy array handed to ``torch.from_numpy``.
        volatile: Forwarded unchanged to ``Variable``.
            NOTE(review): recent PyTorch releases are understood to ignore
            this flag and emit a warning -- confirm against the installed
            version.
        requires_grad: Forwarded unchanged to ``Variable``.
        dtype: Tensor type passed to ``Tensor.type``; defaults to the
            module-level ``FLOAT``.

    Returns:
        A ``Variable`` wrapping the array's values after the cast to
        ``dtype``.
    """
    # Cast BEFORE wrapping. If ``requires_grad`` is set first and the cast
    # happens afterwards, the cast is recorded as an autograd op and the
    # caller receives a non-leaf result whose ``.grad`` is never populated;
    # integer arrays would also be rejected before they reach the cast.
    data = torch.from_numpy(ndarray).type(dtype)
    return Variable(data, volatile=volatile, requires_grad=requires_grad)
|
|
|
|
|
|
def soft_update(target, source, tau):
    """Blend the parameters of ``source`` into ``target`` in place.

    Parameters are paired in the order ``parameters()`` yields them, and each
    target parameter is overwritten with
    ``target * (1.0 - tau) + source * tau``.

    Args:
        target: Object whose ``parameters()`` are updated in place.
        source: Object whose ``parameters()`` supply the new values.
        tau: Interpolation weight given to ``source``.
    """
    paired = zip(target.parameters(), source.parameters())
    for dst, src in paired:
        retained = dst.data * (1.0 - tau)
        incoming = src.data * tau
        dst.data.copy_(retained + incoming)
|
|
|
|
|
|
def hard_update(target, source):
    """Copy every parameter of ``source`` over the matching one in ``target``.

    Parameters are paired in the order ``parameters()`` yields them and the
    values are copied in place; ``source`` is left untouched.

    Args:
        target: Object whose ``parameters()`` are overwritten.
        source: Object whose ``parameters()`` supply the values.
    """
    paired = zip(target.parameters(), source.parameters())
    for dst, src in paired:
        dst.data.copy_(src.data)
|
|
|
|
|
|
# Lookup table from a lowercase activation name to the corresponding
# ``torch.nn`` module class (the class itself, not an instance).
activations = dict(
    relu=torch.nn.ReLU,
    elu=torch.nn.ELU,
    leakyrelu=torch.nn.LeakyReLU,
    selu=torch.nn.SELU,
    sigmoid=torch.nn.Sigmoid,
    tanh=torch.nn.Tanh,
)
|