mirror of
https://github.com/wassname/ray.git
synced 2026-06-29 13:49:45 +08:00
[RLlib] IMPALA PyTorch (#8287)
This PR adds an IMPALA PyTorch implementation. - adds compilation tests for LSTM and w/o LSTM. - adds learning test for CartPole.
This commit is contained in:
+1
-1
@@ -201,7 +201,7 @@ matrix:
|
||||
- travis_wait 60 bazel test --config=ci --build_tests_only --test_tag_filters=quick_train rllib/...
|
||||
# Test everything that does not have any of the "main" labels:
|
||||
# "learning_tests|quick_train|examples|tests_dir".
|
||||
- ./ci/keep_alive bazel test --config=ci --build_tests_only --test_tag_filters=-learning_tests,-quick_train,-examples,-tests_dir rllib/...
|
||||
- ./ci/keep_alive bazel test --config=ci --build_tests_only --test_tag_filters=-learning_tests_tf,-learning_tests_torch,-quick_train,-examples,-tests_dir rllib/...
|
||||
|
||||
# RLlib: Everything in rllib/examples/ directory.
|
||||
- os: linux
|
||||
|
||||
@@ -18,7 +18,7 @@ Algorithm Frameworks Discrete Actions Continuous Actions Multi-
|
||||
`APEX-DDPG`_ tf No **Yes** **Yes**
|
||||
`DQN`_, `Rainbow`_ tf + torch **Yes** `+parametric`_ No **Yes**
|
||||
`APEX-DQN`_ tf + torch **Yes** `+parametric`_ No **Yes**
|
||||
`IMPALA`_ tf **Yes** `+parametric`_ **Yes** **Yes** `+RNN`_, `+autoreg`_
|
||||
`IMPALA`_ tf + torch **Yes** `+parametric`_ **Yes** **Yes** `+RNN`_, `+autoreg`_
|
||||
`MARWIL`_ tf + torch **Yes** `+parametric`_ **Yes** **Yes** `+RNN`_
|
||||
`PG`_ tf + torch **Yes** `+parametric`_ **Yes** **Yes** `+RNN`_, `+autoreg`_
|
||||
`PPO`_, `APPO`_ tf + torch **Yes** `+parametric`_ **Yes** **Yes** `+RNN`_, `+autoreg`_
|
||||
@@ -94,7 +94,7 @@ SpaceInvaders 646 ~300
|
||||
|
||||
Importance Weighted Actor-Learner Architecture (IMPALA)
|
||||
-------------------------------------------------------
|
||||
|tensorflow|
|
||||
|pytorch| |tensorflow|
|
||||
`[paper] <https://arxiv.org/abs/1802.01561>`__
|
||||
`[implementation] <https://github.com/ray-project/ray/blob/master/rllib/agents/impala/impala.py>`__
|
||||
In IMPALA, a central learner runs SGD in a tight loop while asynchronously pulling sample batches from many actor processes. RLlib's IMPALA implementation uses DeepMind's reference `V-trace code <https://github.com/deepmind/scalable_agent/blob/master/vtrace.py>`__. Note that we do not provide a deep residual network out of the box, but one can be plugged in as a `custom model <rllib-models.html#custom-models-tensorflow>`__. Multiple learner GPUs and experience replay are also supported.
|
||||
|
||||
@@ -89,7 +89,7 @@ Algorithms
|
||||
|
||||
- |pytorch| |tensorflow| :ref:`Distributed Prioritized Experience Replay (Ape-X) <apex>`
|
||||
|
||||
- |tensorflow| :ref:`Importance Weighted Actor-Learner Architecture (IMPALA) <impala>`
|
||||
- |pytorch| |tensorflow| :ref:`Importance Weighted Actor-Learner Architecture (IMPALA) <impala>`
|
||||
|
||||
- |pytorch| |tensorflow| :ref:`Asynchronous Proximal Policy Optimization (APPO) <appo>`
|
||||
|
||||
|
||||
@@ -136,6 +136,12 @@ py_test(
|
||||
size = "small",
|
||||
srcs = ["agents/impala/tests/test_vtrace.py"]
|
||||
)
|
||||
py_test(
|
||||
name = "test_impala",
|
||||
tags = ["agents_dir"],
|
||||
size = "medium",
|
||||
srcs = ["agents/impala/tests/test_impala.py"]
|
||||
)
|
||||
|
||||
# MARWILTrainer
|
||||
py_test(
|
||||
|
||||
@@ -11,7 +11,7 @@ from ray.tune.resources import Resources
|
||||
# yapf: disable
|
||||
# __sphinx_doc_begin__
|
||||
DEFAULT_CONFIG = with_common_config({
|
||||
# V-trace params (see vtrace_tf.py).
|
||||
# V-trace params (see vtrace_tf/torch.py).
|
||||
"vtrace": True,
|
||||
"vtrace_clip_rho_threshold": 1.0,
|
||||
"vtrace_clip_pg_rho_threshold": 1.0,
|
||||
@@ -83,22 +83,6 @@ DEFAULT_CONFIG = with_common_config({
|
||||
# yapf: enable
|
||||
|
||||
|
||||
def choose_policy(config):
|
||||
if config["vtrace"]:
|
||||
return VTraceTFPolicy
|
||||
else:
|
||||
return A3CTFPolicy
|
||||
|
||||
|
||||
def validate_config(config):
|
||||
# PyTorch check.
|
||||
if config["use_pytorch"]:
|
||||
raise ValueError(
|
||||
"IMPALA does not support PyTorch yet! Use tf instead.")
|
||||
if config["entropy_coeff"] < 0:
|
||||
raise DeprecationWarning("entropy_coeff must be >= 0")
|
||||
|
||||
|
||||
def defer_make_workers(trainer, env_creator, policy, config):
|
||||
# Defer worker creation to after the optimizer has been created.
|
||||
return trainer._make_workers(env_creator, policy, config, 0)
|
||||
@@ -157,12 +141,39 @@ class OverrideDefaultResourceRequest:
|
||||
cf["num_workers"])
|
||||
|
||||
|
||||
def get_policy_class(config):
|
||||
if config["use_pytorch"]:
|
||||
if config["vtrace"]:
|
||||
from ray.rllib.agents.impala.vtrace_torch_policy import \
|
||||
VTraceTorchPolicy
|
||||
return VTraceTorchPolicy
|
||||
else:
|
||||
from ray.rllib.agents.a3c.a3c_torch_policy import \
|
||||
A3CTorchPolicy
|
||||
return A3CTorchPolicy
|
||||
else:
|
||||
if config["vtrace"]:
|
||||
return VTraceTFPolicy
|
||||
else:
|
||||
return A3CTFPolicy
|
||||
|
||||
|
||||
def validate_config(config):
|
||||
if config["entropy_coeff"] < 0.0:
|
||||
raise DeprecationWarning("`entropy_coeff` must be >= 0.0!")
|
||||
|
||||
if config["vtrace"] and not config["in_evaluation"]:
|
||||
if config["batch_mode"] != "truncate_episodes":
|
||||
raise ValueError(
|
||||
"Must use `batch_mode`=truncate_episodes if `vtrace` is True.")
|
||||
|
||||
|
||||
ImpalaTrainer = build_trainer(
|
||||
name="IMPALA",
|
||||
default_config=DEFAULT_CONFIG,
|
||||
default_policy=VTraceTFPolicy,
|
||||
validate_config=validate_config,
|
||||
get_policy_class=choose_policy,
|
||||
get_policy_class=get_policy_class,
|
||||
make_workers=defer_make_workers,
|
||||
make_policy_optimizer=make_aggregators_and_optimizer,
|
||||
mixins=[OverrideDefaultResourceRequest])
|
||||
|
||||
@@ -0,0 +1,46 @@
|
||||
import unittest
|
||||
|
||||
import ray
|
||||
import ray.rllib.agents.impala as impala
|
||||
from ray.rllib.utils.framework import try_import_tf
|
||||
from ray.rllib.utils.test_utils import framework_iterator
|
||||
|
||||
tf = try_import_tf()
|
||||
|
||||
|
||||
class TestIMPALA(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
ray.init()
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
ray.shutdown()
|
||||
|
||||
def test_impala_compilation(self):
|
||||
"""Test whether an ImpalaTrainer can be built with both frameworks."""
|
||||
config = impala.DEFAULT_CONFIG.copy()
|
||||
num_iterations = 1
|
||||
|
||||
for _ in framework_iterator(config, frameworks=("torch", "tf")):
|
||||
local_cfg = config.copy()
|
||||
for env in ["Pendulum-v0", "CartPole-v0"]:
|
||||
print("Env={}".format(env))
|
||||
print("w/ LSTM")
|
||||
# Test w/o LSTM.
|
||||
trainer = impala.ImpalaTrainer(config=local_cfg, env=env)
|
||||
for i in range(num_iterations):
|
||||
print(trainer.train())
|
||||
|
||||
# Test w/ LSTM.
|
||||
print("w/o LSTM")
|
||||
local_cfg["model"]["use_lstm"] = True
|
||||
trainer = impala.ImpalaTrainer(config=local_cfg, env=env)
|
||||
for i in range(num_iterations):
|
||||
print(trainer.train())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import pytest
|
||||
import sys
|
||||
sys.exit(pytest.main(["-v", __file__]))
|
||||
@@ -251,12 +251,6 @@ def postprocess_trajectory(policy,
|
||||
return sample_batch
|
||||
|
||||
|
||||
def validate_config(policy, obs_space, action_space, config):
|
||||
if config["vtrace"] and not config["in_evaluation"]:
|
||||
assert config["batch_mode"] == "truncate_episodes", \
|
||||
"Must use `truncate_episodes` batch mode with V-trace."
|
||||
|
||||
|
||||
def choose_optimizer(policy, config):
|
||||
if policy.config["opt_type"] == "adam":
|
||||
return tf.train.AdamOptimizer(policy.cur_lr)
|
||||
@@ -289,7 +283,6 @@ VTraceTFPolicy = build_tf_policy(
|
||||
postprocess_fn=postprocess_trajectory,
|
||||
optimizer_fn=choose_optimizer,
|
||||
gradients_fn=clip_gradients,
|
||||
before_init=validate_config,
|
||||
before_loss_init=setup_mixins,
|
||||
mixins=[LearningRateSchedule, EntropyCoeffSchedule],
|
||||
get_batch_divisibility_req=lambda p: p.config["rollout_fragment_length"])
|
||||
|
||||
@@ -32,6 +32,7 @@ tensors.
|
||||
from ray.rllib.agents.impala.vtrace_tf import VTraceFromLogitsReturns, \
|
||||
VTraceReturns
|
||||
from ray.rllib.models.torch.torch_action_dist import TorchCategorical
|
||||
from ray.rllib.utils import force_list
|
||||
from ray.rllib.utils.framework import try_import_torch
|
||||
from ray.rllib.utils.torch_ops import convert_to_torch_tensor
|
||||
|
||||
@@ -217,6 +218,7 @@ def multi_from_logits(behaviour_policy_logits,
|
||||
behaviour_action_log_probs = (multi_log_probs_from_logits_and_actions(
|
||||
behaviour_policy_logits, actions, dist_class, model))
|
||||
|
||||
behaviour_action_log_probs = force_list(behaviour_action_log_probs)
|
||||
log_rhos = get_log_rhos(target_action_log_probs,
|
||||
behaviour_action_log_probs)
|
||||
|
||||
|
||||
@@ -1,12 +1,187 @@
|
||||
import gym
|
||||
import logging
|
||||
import numpy as np
|
||||
|
||||
import ray
|
||||
from ray.rllib.agents.a3c.a3c_torch_policy import apply_grad_clipping
|
||||
from ray.rllib.agents.impala.vtrace_tf_policy import postprocess_trajectory
|
||||
import ray.rllib.agents.impala.vtrace_torch as vtrace
|
||||
from ray.rllib.models.torch.torch_action_dist import TorchCategorical
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.policy.torch_policy import LearningRateSchedule, \
|
||||
EntropyCoeffSchedule
|
||||
from ray.rllib.policy.torch_policy_template import build_torch_policy
|
||||
from ray.rllib.utils.explained_variance import explained_variance
|
||||
from ray.rllib.utils.framework import try_import_torch
|
||||
from ray.rllib.utils.torch_ops import global_norm, sequence_mask
|
||||
|
||||
torch, nn = try_import_torch()
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class VTraceLoss:
|
||||
def __init__(self,
|
||||
actions,
|
||||
actions_logp,
|
||||
actions_entropy,
|
||||
dones,
|
||||
behaviour_action_logp,
|
||||
behaviour_logits,
|
||||
target_logits,
|
||||
discount,
|
||||
rewards,
|
||||
values,
|
||||
bootstrap_value,
|
||||
dist_class,
|
||||
model,
|
||||
valid_mask,
|
||||
config,
|
||||
vf_loss_coeff=0.5,
|
||||
entropy_coeff=0.01,
|
||||
clip_rho_threshold=1.0,
|
||||
clip_pg_rho_threshold=1.0):
|
||||
"""Policy gradient loss with vtrace importance weighting.
|
||||
|
||||
VTraceLoss takes tensors of shape [T, B, ...], where `B` is the
|
||||
batch_size. The reason we need to know `B` is for V-trace to properly
|
||||
handle episode cut boundaries.
|
||||
|
||||
Args:
|
||||
actions: An int|float32 tensor of shape [T, B, ACTION_SPACE].
|
||||
actions_logp: A float32 tensor of shape [T, B].
|
||||
actions_entropy: A float32 tensor of shape [T, B].
|
||||
dones: A bool tensor of shape [T, B].
|
||||
behaviour_action_logp: Tensor of shape [T, B].
|
||||
behaviour_logits: A list with length of ACTION_SPACE of float32
|
||||
tensors of shapes
|
||||
[T, B, ACTION_SPACE[0]],
|
||||
...,
|
||||
[T, B, ACTION_SPACE[-1]]
|
||||
target_logits: A list with length of ACTION_SPACE of float32
|
||||
tensors of shapes
|
||||
[T, B, ACTION_SPACE[0]],
|
||||
...,
|
||||
[T, B, ACTION_SPACE[-1]]
|
||||
discount: A float32 scalar.
|
||||
rewards: A float32 tensor of shape [T, B].
|
||||
values: A float32 tensor of shape [T, B].
|
||||
bootstrap_value: A float32 tensor of shape [B].
|
||||
dist_class: action distribution class for logits.
|
||||
valid_mask: A bool tensor of valid RNN input elements (#2992).
|
||||
config: Trainer config dict.
|
||||
"""
|
||||
|
||||
if valid_mask is None:
|
||||
valid_mask = torch.ones_like(actions_logp)
|
||||
|
||||
# Compute vtrace on the CPU for better perf
|
||||
# (devices handled inside `vtrace.multi_from_logits`).
|
||||
self.vtrace_returns = vtrace.multi_from_logits(
|
||||
behaviour_action_log_probs=behaviour_action_logp,
|
||||
behaviour_policy_logits=behaviour_logits,
|
||||
target_policy_logits=target_logits,
|
||||
actions=torch.unbind(actions, dim=2),
|
||||
discounts=(1.0 - dones.float()) * discount,
|
||||
rewards=rewards,
|
||||
values=values,
|
||||
bootstrap_value=bootstrap_value,
|
||||
dist_class=dist_class,
|
||||
model=model,
|
||||
clip_rho_threshold=clip_rho_threshold,
|
||||
clip_pg_rho_threshold=clip_pg_rho_threshold)
|
||||
self.value_targets = self.vtrace_returns.vs
|
||||
|
||||
# The policy gradients loss
|
||||
self.pi_loss = -torch.sum(
|
||||
actions_logp * self.vtrace_returns.pg_advantages * valid_mask)
|
||||
|
||||
# The baseline loss
|
||||
delta = (values - self.vtrace_returns.vs) * valid_mask
|
||||
self.vf_loss = 0.5 * torch.sum(torch.pow(delta, 2.0))
|
||||
|
||||
# The entropy loss
|
||||
self.entropy = torch.sum(actions_entropy * valid_mask)
|
||||
|
||||
# The summed weighted loss
|
||||
self.total_loss = (self.pi_loss + self.vf_loss * vf_loss_coeff -
|
||||
self.entropy * entropy_coeff)
|
||||
|
||||
|
||||
def build_vtrace_loss(policy, model, dist_class, train_batch):
|
||||
model_out, _ = model.from_batch(train_batch)
|
||||
action_dist = dist_class(model_out, model)
|
||||
|
||||
if isinstance(policy.action_space, gym.spaces.Discrete):
|
||||
is_multidiscrete = False
|
||||
output_hidden_shape = [policy.action_space.n]
|
||||
elif isinstance(policy.action_space,
|
||||
gym.spaces.multi_discrete.MultiDiscrete):
|
||||
is_multidiscrete = True
|
||||
output_hidden_shape = policy.action_space.nvec.astype(np.int32)
|
||||
else:
|
||||
is_multidiscrete = False
|
||||
output_hidden_shape = 1
|
||||
|
||||
def _make_time_major(*args, **kw):
|
||||
return make_time_major(policy, train_batch.get("seq_lens"), *args,
|
||||
**kw)
|
||||
|
||||
actions = train_batch[SampleBatch.ACTIONS]
|
||||
dones = train_batch[SampleBatch.DONES]
|
||||
rewards = train_batch[SampleBatch.REWARDS]
|
||||
behaviour_action_logp = train_batch[SampleBatch.ACTION_LOGP]
|
||||
behaviour_logits = train_batch[SampleBatch.ACTION_DIST_INPUTS]
|
||||
if isinstance(output_hidden_shape, list):
|
||||
unpacked_behaviour_logits = torch.split(
|
||||
behaviour_logits, output_hidden_shape, dim=1)
|
||||
unpacked_outputs = torch.split(model_out, output_hidden_shape, dim=1)
|
||||
else:
|
||||
unpacked_behaviour_logits = torch.chunk(
|
||||
behaviour_logits, output_hidden_shape, dim=1)
|
||||
unpacked_outputs = torch.chunk(model_out, output_hidden_shape, dim=1)
|
||||
values = model.value_function()
|
||||
|
||||
if policy.is_recurrent():
|
||||
max_seq_len = torch.max(train_batch["seq_lens"]) - 1
|
||||
mask = sequence_mask(train_batch["seq_lens"], max_seq_len)
|
||||
mask = torch.reshape(mask, [-1])
|
||||
else:
|
||||
mask = torch.ones_like(rewards)
|
||||
|
||||
# Prepare actions for loss.
|
||||
loss_actions = actions if is_multidiscrete else torch.unsqueeze(
|
||||
actions, dim=1)
|
||||
|
||||
# Inputs are reshaped from [B * T] => [T - 1, B] for V-trace calc.
|
||||
policy.loss = VTraceLoss(
|
||||
actions=_make_time_major(loss_actions, drop_last=True),
|
||||
actions_logp=_make_time_major(
|
||||
action_dist.logp(actions), drop_last=True),
|
||||
actions_entropy=_make_time_major(
|
||||
action_dist.multi_entropy(), drop_last=True),
|
||||
dones=_make_time_major(dones, drop_last=True),
|
||||
behaviour_action_logp=_make_time_major(
|
||||
behaviour_action_logp, drop_last=True),
|
||||
behaviour_logits=_make_time_major(
|
||||
unpacked_behaviour_logits, drop_last=True),
|
||||
target_logits=_make_time_major(unpacked_outputs, drop_last=True),
|
||||
discount=policy.config["gamma"],
|
||||
rewards=_make_time_major(rewards, drop_last=True),
|
||||
values=_make_time_major(values, drop_last=True),
|
||||
bootstrap_value=_make_time_major(values)[-1],
|
||||
dist_class=TorchCategorical if is_multidiscrete else dist_class,
|
||||
model=model,
|
||||
valid_mask=_make_time_major(mask, drop_last=True),
|
||||
config=policy.config,
|
||||
vf_loss_coeff=policy.config["vf_loss_coeff"],
|
||||
entropy_coeff=policy.entropy_coeff,
|
||||
clip_rho_threshold=policy.config["vtrace_clip_rho_threshold"],
|
||||
clip_pg_rho_threshold=policy.config["vtrace_clip_pg_rho_threshold"])
|
||||
|
||||
return policy.loss.total_loss
|
||||
|
||||
|
||||
def make_time_major(policy, seq_lens, tensor, drop_last=False):
|
||||
"""Swaps batch and trajectory axis.
|
||||
|
||||
@@ -44,6 +219,27 @@ def make_time_major(policy, seq_lens, tensor, drop_last=False):
|
||||
return res
|
||||
|
||||
|
||||
def stats(policy, train_batch):
|
||||
values_batched = make_time_major(
|
||||
policy,
|
||||
train_batch.get("seq_lens"),
|
||||
policy.model.value_function(),
|
||||
drop_last=policy.config["vtrace"])
|
||||
|
||||
return {
|
||||
"cur_lr": policy.cur_lr,
|
||||
"policy_loss": policy.loss.pi_loss,
|
||||
"entropy": policy.loss.entropy,
|
||||
"entropy_coeff": policy.entropy_coeff,
|
||||
"var_gnorm": global_norm(policy.model.trainable_variables()),
|
||||
"vf_loss": policy.loss.vf_loss,
|
||||
"vf_explained_var": explained_variance(
|
||||
torch.reshape(policy.loss.value_targets, [-1]),
|
||||
torch.reshape(values_batched, [-1]),
|
||||
framework="torch"),
|
||||
}
|
||||
|
||||
|
||||
def choose_optimizer(policy, config):
|
||||
if policy.config["opt_type"] == "adam":
|
||||
return torch.optim.Adam(
|
||||
@@ -55,3 +251,22 @@ def choose_optimizer(policy, config):
|
||||
weight_decay=config["decay"],
|
||||
momentum=config["momentum"],
|
||||
eps=config["epsilon"])
|
||||
|
||||
|
||||
def setup_mixins(policy, obs_space, action_space, config):
|
||||
EntropyCoeffSchedule.__init__(policy, config["entropy_coeff"],
|
||||
config["entropy_coeff_schedule"])
|
||||
LearningRateSchedule.__init__(policy, config["lr"], config["lr_schedule"])
|
||||
|
||||
|
||||
VTraceTorchPolicy = build_torch_policy(
|
||||
name="VTraceTorchPolicy",
|
||||
loss_fn=build_vtrace_loss,
|
||||
get_default_config=lambda: ray.rllib.agents.impala.impala.DEFAULT_CONFIG,
|
||||
stats_fn=stats,
|
||||
postprocess_fn=postprocess_trajectory,
|
||||
extra_grad_process_fn=apply_grad_clipping,
|
||||
optimizer_fn=choose_optimizer,
|
||||
before_init=setup_mixins,
|
||||
mixins=[LearningRateSchedule, EntropyCoeffSchedule],
|
||||
get_batch_divisibility_req=lambda p: p.config["rollout_fragment_length"])
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from ray.rllib.agents.impala.impala import validate_config
|
||||
from ray.rllib.agents.ppo.appo_tf_policy import AsyncPPOTFPolicy
|
||||
from ray.rllib.agents.trainer import with_base_config
|
||||
from ray.rllib.agents.ppo.ppo import update_kl
|
||||
from ray.rllib.agents.trainer import with_base_config
|
||||
from ray.rllib.agents import impala
|
||||
|
||||
# yapf: disable
|
||||
@@ -89,11 +90,6 @@ def get_policy_class(config):
|
||||
return AsyncPPOTFPolicy
|
||||
|
||||
|
||||
def validate_config(config):
|
||||
if config["entropy_coeff"] < 0:
|
||||
raise ValueError("`entropy_coeff` must be >= 0.0!")
|
||||
|
||||
|
||||
APPOTrainer = impala.ImpalaTrainer.with_updates(
|
||||
name="APPO",
|
||||
default_config=DEFAULT_CONFIG,
|
||||
|
||||
@@ -8,7 +8,7 @@ import gym
|
||||
|
||||
from ray.rllib.agents.impala import vtrace_tf as vtrace
|
||||
from ray.rllib.agents.impala.vtrace_tf_policy import _make_time_major, \
|
||||
clip_gradients, validate_config, choose_optimizer
|
||||
clip_gradients, choose_optimizer
|
||||
from ray.rllib.evaluation.postprocessing import Postprocessing
|
||||
from ray.rllib.models.tf.tf_action_dist import Categorical
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
@@ -449,7 +449,6 @@ AsyncPPOTFPolicy = build_tf_policy(
|
||||
optimizer_fn=choose_optimizer,
|
||||
gradients_fn=clip_gradients,
|
||||
extra_action_fetches_fn=add_values,
|
||||
before_init=validate_config,
|
||||
before_loss_init=setup_mixins,
|
||||
after_init=setup_late_mixins,
|
||||
mixins=[
|
||||
|
||||
@@ -528,9 +528,6 @@ class ModelCatalog:
|
||||
FCNet)
|
||||
from ray.rllib.models.torch.visionnet import (VisionNetwork as
|
||||
VisionNet)
|
||||
if model_config.get("use_lstm"):
|
||||
raise NotImplementedError(
|
||||
"LSTM auto-wrapping not implemented for torch")
|
||||
else:
|
||||
from ray.rllib.models.tf.fcnet_v2 import \
|
||||
FullyConnectedNetwork as FCNet
|
||||
|
||||
@@ -237,6 +237,11 @@ class TorchSquashedGaussian(TorchDistributionWrapper):
|
||||
unsquashed = atanh(save_normed_values)
|
||||
return unsquashed
|
||||
|
||||
@staticmethod
|
||||
@override(ActionDistribution)
|
||||
def required_model_output_shape(action_space, model_config):
|
||||
return np.prod(action_space.shape) * 2
|
||||
|
||||
|
||||
class TorchBeta(TorchDistributionWrapper):
|
||||
"""
|
||||
@@ -285,6 +290,11 @@ class TorchBeta(TorchDistributionWrapper):
|
||||
def _unsquash(self, values):
|
||||
return (values - self.low) / (self.high - self.low)
|
||||
|
||||
@staticmethod
|
||||
@override(ActionDistribution)
|
||||
def required_model_output_shape(action_space, model_config):
|
||||
return np.prod(action_space.shape) * 2
|
||||
|
||||
|
||||
class TorchDeterministic(TorchDistributionWrapper):
|
||||
"""Action distribution that returns the input values directly.
|
||||
|
||||
@@ -0,0 +1,9 @@
|
||||
cartpole-impala-tf:
|
||||
env: CartPole-v0
|
||||
run: IMPALA
|
||||
stop:
|
||||
episode_reward_mean: 150
|
||||
timesteps_total: 500000
|
||||
config:
|
||||
use_pytorch: false
|
||||
num_gpus: 0
|
||||
@@ -0,0 +1,9 @@
|
||||
cartpole-impala-torch:
|
||||
env: CartPole-v0
|
||||
run: IMPALA
|
||||
stop:
|
||||
episode_reward_mean: 150
|
||||
timesteps_total: 500000
|
||||
config:
|
||||
use_pytorch: true
|
||||
num_gpus: 0
|
||||
@@ -30,10 +30,7 @@ def do_test_explorations(run,
|
||||
|
||||
# Test all frameworks.
|
||||
for fw in framework_iterator(core_config):
|
||||
if fw == "torch" and \
|
||||
run in [impala.ImpalaTrainer, sac.SACTrainer]:
|
||||
continue
|
||||
elif fw == "eager" and run in [
|
||||
if fw == "eager" and run in [
|
||||
ddpg.DDPGTrainer, sac.SACTrainer, td3.TD3Trainer
|
||||
]:
|
||||
continue
|
||||
|
||||
@@ -47,7 +47,7 @@ def framework_iterator(config=None,
|
||||
logger.warning(
|
||||
"framework_iterator skipping torch (not installed)!")
|
||||
continue
|
||||
elif not tf:
|
||||
if fw != "torch" and not tf:
|
||||
logger.warning("framework_iterator skipping {} (tf not "
|
||||
"installed)!".format(fw))
|
||||
continue
|
||||
|
||||
Reference in New Issue
Block a user