From 166bb5d690b74cecf62fbe360731b79c7765a97d Mon Sep 17 00:00:00 2001 From: Sven Mika Date: Sun, 3 May 2020 13:44:25 +0200 Subject: [PATCH] [RLlib] IMPALA PyTorch (#8287) This PR adds an IMPALA PyTorch implementation. - adds compilation tests for LSTM and w/o LSTM. - adds learning test for CartPole. --- .travis.yml | 2 +- doc/source/rllib-algorithms.rst | 4 +- doc/source/rllib-toc.rst | 2 +- rllib/BUILD | 6 + rllib/agents/impala/impala.py | 47 ++-- rllib/agents/impala/tests/test_impala.py | 46 ++++ rllib/agents/impala/vtrace_tf_policy.py | 7 - rllib/agents/impala/vtrace_torch.py | 2 + rllib/agents/impala/vtrace_torch_policy.py | 215 ++++++++++++++++++ rllib/agents/ppo/appo.py | 8 +- rllib/agents/ppo/appo_tf_policy.py | 3 +- rllib/models/catalog.py | 3 - rllib/models/torch/torch_action_dist.py | 10 + .../regression_tests/cartpole-impala-tf.yaml | 9 + .../cartpole-impala-torch.yaml | 9 + .../exploration/tests/test_explorations.py | 5 +- rllib/utils/test_utils.py | 2 +- 17 files changed, 335 insertions(+), 45 deletions(-) create mode 100644 rllib/agents/impala/tests/test_impala.py create mode 100644 rllib/tuned_examples/regression_tests/cartpole-impala-tf.yaml create mode 100644 rllib/tuned_examples/regression_tests/cartpole-impala-torch.yaml diff --git a/.travis.yml b/.travis.yml index 0fa066250..50801ddf3 100644 --- a/.travis.yml +++ b/.travis.yml @@ -201,7 +201,7 @@ matrix: - travis_wait 60 bazel test --config=ci --build_tests_only --test_tag_filters=quick_train rllib/... # Test everything that does not have any of the "main" labels: # "learning_tests|quick_train|examples|tests_dir". - - ./ci/keep_alive bazel test --config=ci --build_tests_only --test_tag_filters=-learning_tests,-quick_train,-examples,-tests_dir rllib/... + - ./ci/keep_alive bazel test --config=ci --build_tests_only --test_tag_filters=-learning_tests_tf,-learning_tests_torch,-quick_train,-examples,-tests_dir rllib/... # RLlib: Everything in rllib/examples/ directory. - os: linux diff --git a/doc/source/rllib-algorithms.rst b/doc/source/rllib-algorithms.rst index 9e490facd..cb150ad1d 100644 --- a/doc/source/rllib-algorithms.rst +++ b/doc/source/rllib-algorithms.rst @@ -18,7 +18,7 @@ Algorithm Frameworks Discrete Actions Continuous Actions Multi- `APEX-DDPG`_ tf No **Yes** **Yes** `DQN`_, `Rainbow`_ tf + torch **Yes** `+parametric`_ No **Yes** `APEX-DQN`_ tf + torch **Yes** `+parametric`_ No **Yes** -`IMPALA`_ tf **Yes** `+parametric`_ **Yes** **Yes** `+RNN`_, `+autoreg`_ +`IMPALA`_ tf + torch **Yes** `+parametric`_ **Yes** **Yes** `+RNN`_, `+autoreg`_ `MARWIL`_ tf + torch **Yes** `+parametric`_ **Yes** **Yes** `+RNN`_ `PG`_ tf + torch **Yes** `+parametric`_ **Yes** **Yes** `+RNN`_, `+autoreg`_ `PPO`_, `APPO`_ tf + torch **Yes** `+parametric`_ **Yes** **Yes** `+RNN`_, `+autoreg`_ @@ -94,7 +94,7 @@ SpaceInvaders 646 ~300 Importance Weighted Actor-Learner Architecture (IMPALA) ------------------------------------------------------- -|tensorflow| +|pytorch| |tensorflow| `[paper] `__ `[implementation] `__ In IMPALA, a central learner runs SGD in a tight loop while asynchronously pulling sample batches from many actor processes. RLlib's IMPALA implementation uses DeepMind's reference `V-trace code `__. Note that we do not provide a deep residual network out of the box, but one can be plugged in as a `custom model `__. Multiple learner GPUs and experience replay are also supported. diff --git a/doc/source/rllib-toc.rst b/doc/source/rllib-toc.rst index 0c1bf6a87..6d0acc0cd 100644 --- a/doc/source/rllib-toc.rst +++ b/doc/source/rllib-toc.rst @@ -89,7 +89,7 @@ Algorithms - |pytorch| |tensorflow| :ref:`Distributed Prioritized Experience Replay (Ape-X) ` - - |tensorflow| :ref:`Importance Weighted Actor-Learner Architecture (IMPALA) ` + - |pytorch| |tensorflow| :ref:`Importance Weighted Actor-Learner Architecture (IMPALA) ` - |pytorch| |tensorflow| :ref:`Asynchronous Proximal Policy Optimization (APPO) ` diff --git a/rllib/BUILD b/rllib/BUILD index 3aaed4b32..876d14155 100644 --- a/rllib/BUILD +++ b/rllib/BUILD @@ -136,6 +136,12 @@ py_test( size = "small", srcs = ["agents/impala/tests/test_vtrace.py"] ) +py_test( + name = "test_impala", + tags = ["agents_dir"], + size = "medium", + srcs = ["agents/impala/tests/test_impala.py"] +) # MARWILTrainer py_test( diff --git a/rllib/agents/impala/impala.py b/rllib/agents/impala/impala.py index 193bddd60..8b32a9e08 100644 --- a/rllib/agents/impala/impala.py +++ b/rllib/agents/impala/impala.py @@ -11,7 +11,7 @@ from ray.tune.resources import Resources # yapf: disable # __sphinx_doc_begin__ DEFAULT_CONFIG = with_common_config({ - # V-trace params (see vtrace_tf.py). + # V-trace params (see vtrace_tf/torch.py). "vtrace": True, "vtrace_clip_rho_threshold": 1.0, "vtrace_clip_pg_rho_threshold": 1.0, @@ -83,22 +83,6 @@ DEFAULT_CONFIG = with_common_config({ # yapf: enable -def choose_policy(config): - if config["vtrace"]: - return VTraceTFPolicy - else: - return A3CTFPolicy - - -def validate_config(config): - # PyTorch check. - if config["use_pytorch"]: - raise ValueError( - "IMPALA does not support PyTorch yet! Use tf instead.") - if config["entropy_coeff"] < 0: - raise DeprecationWarning("entropy_coeff must be >= 0") - - def defer_make_workers(trainer, env_creator, policy, config): # Defer worker creation to after the optimizer has been created. return trainer._make_workers(env_creator, policy, config, 0) @@ -157,12 +141,39 @@ class OverrideDefaultResourceRequest: cf["num_workers"]) +def get_policy_class(config): + if config["use_pytorch"]: + if config["vtrace"]: + from ray.rllib.agents.impala.vtrace_torch_policy import \ + VTraceTorchPolicy + return VTraceTorchPolicy + else: + from ray.rllib.agents.a3c.a3c_torch_policy import \ + A3CTorchPolicy + return A3CTorchPolicy + else: + if config["vtrace"]: + return VTraceTFPolicy + else: + return A3CTFPolicy + + +def validate_config(config): + if config["entropy_coeff"] < 0.0: + raise DeprecationWarning("`entropy_coeff` must be >= 0.0!") + + if config["vtrace"] and not config["in_evaluation"]: + if config["batch_mode"] != "truncate_episodes": + raise ValueError( + "Must use `batch_mode`=truncate_episodes if `vtrace` is True.") + + ImpalaTrainer = build_trainer( name="IMPALA", default_config=DEFAULT_CONFIG, default_policy=VTraceTFPolicy, validate_config=validate_config, - get_policy_class=choose_policy, + get_policy_class=get_policy_class, make_workers=defer_make_workers, make_policy_optimizer=make_aggregators_and_optimizer, mixins=[OverrideDefaultResourceRequest]) diff --git a/rllib/agents/impala/tests/test_impala.py b/rllib/agents/impala/tests/test_impala.py new file mode 100644 index 000000000..ce9b76477 --- /dev/null +++ b/rllib/agents/impala/tests/test_impala.py @@ -0,0 +1,46 @@ +import unittest + +import ray +import ray.rllib.agents.impala as impala +from ray.rllib.utils.framework import try_import_tf +from ray.rllib.utils.test_utils import framework_iterator + +tf = try_import_tf() + + +class TestIMPALA(unittest.TestCase): + @classmethod + def setUpClass(cls): + ray.init() + + @classmethod + def tearDownClass(cls): + ray.shutdown() + + def test_impala_compilation(self): + """Test whether an ImpalaTrainer can be built with both frameworks.""" + config = impala.DEFAULT_CONFIG.copy() + num_iterations = 1 + + for _ in framework_iterator(config, frameworks=("torch", "tf")): + local_cfg = config.copy() + for env in ["Pendulum-v0", "CartPole-v0"]: + print("Env={}".format(env)) + print("w/ LSTM") + # Test w/o LSTM. + trainer = impala.ImpalaTrainer(config=local_cfg, env=env) + for i in range(num_iterations): + print(trainer.train()) + + # Test w/ LSTM. + print("w/o LSTM") + local_cfg["model"]["use_lstm"] = True + trainer = impala.ImpalaTrainer(config=local_cfg, env=env) + for i in range(num_iterations): + print(trainer.train()) + + +if __name__ == "__main__": + import pytest + import sys + sys.exit(pytest.main(["-v", __file__])) diff --git a/rllib/agents/impala/vtrace_tf_policy.py b/rllib/agents/impala/vtrace_tf_policy.py index 05a8e91d1..3b8fac211 100644 --- a/rllib/agents/impala/vtrace_tf_policy.py +++ b/rllib/agents/impala/vtrace_tf_policy.py @@ -251,12 +251,6 @@ def postprocess_trajectory(policy, return sample_batch -def validate_config(policy, obs_space, action_space, config): - if config["vtrace"] and not config["in_evaluation"]: - assert config["batch_mode"] == "truncate_episodes", \ - "Must use `truncate_episodes` batch mode with V-trace." - - def choose_optimizer(policy, config): if policy.config["opt_type"] == "adam": return tf.train.AdamOptimizer(policy.cur_lr) @@ -289,7 +283,6 @@ VTraceTFPolicy = build_tf_policy( postprocess_fn=postprocess_trajectory, optimizer_fn=choose_optimizer, gradients_fn=clip_gradients, - before_init=validate_config, before_loss_init=setup_mixins, mixins=[LearningRateSchedule, EntropyCoeffSchedule], get_batch_divisibility_req=lambda p: p.config["rollout_fragment_length"]) diff --git a/rllib/agents/impala/vtrace_torch.py b/rllib/agents/impala/vtrace_torch.py index b74451334..540f3bf0b 100644 --- a/rllib/agents/impala/vtrace_torch.py +++ b/rllib/agents/impala/vtrace_torch.py @@ -32,6 +32,7 @@ tensors. from ray.rllib.agents.impala.vtrace_tf import VTraceFromLogitsReturns, \ VTraceReturns from ray.rllib.models.torch.torch_action_dist import TorchCategorical +from ray.rllib.utils import force_list from ray.rllib.utils.framework import try_import_torch from ray.rllib.utils.torch_ops import convert_to_torch_tensor @@ -217,6 +218,7 @@ def multi_from_logits(behaviour_policy_logits, behaviour_action_log_probs = (multi_log_probs_from_logits_and_actions( behaviour_policy_logits, actions, dist_class, model)) + behaviour_action_log_probs = force_list(behaviour_action_log_probs) log_rhos = get_log_rhos(target_action_log_probs, behaviour_action_log_probs) diff --git a/rllib/agents/impala/vtrace_torch_policy.py b/rllib/agents/impala/vtrace_torch_policy.py index 4d1371526..e5095a867 100644 --- a/rllib/agents/impala/vtrace_torch_policy.py +++ b/rllib/agents/impala/vtrace_torch_policy.py @@ -1,12 +1,187 @@ +import gym import logging +import numpy as np +import ray +from ray.rllib.agents.a3c.a3c_torch_policy import apply_grad_clipping +from ray.rllib.agents.impala.vtrace_tf_policy import postprocess_trajectory +import ray.rllib.agents.impala.vtrace_torch as vtrace +from ray.rllib.models.torch.torch_action_dist import TorchCategorical +from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.policy.torch_policy import LearningRateSchedule, \ + EntropyCoeffSchedule +from ray.rllib.policy.torch_policy_template import build_torch_policy +from ray.rllib.utils.explained_variance import explained_variance from ray.rllib.utils.framework import try_import_torch +from ray.rllib.utils.torch_ops import global_norm, sequence_mask torch, nn = try_import_torch() logger = logging.getLogger(__name__) +class VTraceLoss: + def __init__(self, + actions, + actions_logp, + actions_entropy, + dones, + behaviour_action_logp, + behaviour_logits, + target_logits, + discount, + rewards, + values, + bootstrap_value, + dist_class, + model, + valid_mask, + config, + vf_loss_coeff=0.5, + entropy_coeff=0.01, + clip_rho_threshold=1.0, + clip_pg_rho_threshold=1.0): + """Policy gradient loss with vtrace importance weighting. + + VTraceLoss takes tensors of shape [T, B, ...], where `B` is the + batch_size. The reason we need to know `B` is for V-trace to properly + handle episode cut boundaries. + + Args: + actions: An int|float32 tensor of shape [T, B, ACTION_SPACE]. + actions_logp: A float32 tensor of shape [T, B]. + actions_entropy: A float32 tensor of shape [T, B]. + dones: A bool tensor of shape [T, B]. + behaviour_action_logp: Tensor of shape [T, B]. + behaviour_logits: A list with length of ACTION_SPACE of float32 + tensors of shapes + [T, B, ACTION_SPACE[0]], + ..., + [T, B, ACTION_SPACE[-1]] + target_logits: A list with length of ACTION_SPACE of float32 + tensors of shapes + [T, B, ACTION_SPACE[0]], + ..., + [T, B, ACTION_SPACE[-1]] + discount: A float32 scalar. + rewards: A float32 tensor of shape [T, B]. + values: A float32 tensor of shape [T, B]. + bootstrap_value: A float32 tensor of shape [B]. + dist_class: action distribution class for logits. + valid_mask: A bool tensor of valid RNN input elements (#2992). + config: Trainer config dict. + """ + + if valid_mask is None: + valid_mask = torch.ones_like(actions_logp) + + # Compute vtrace on the CPU for better perf + # (devices handled inside `vtrace.multi_from_logits`). + self.vtrace_returns = vtrace.multi_from_logits( + behaviour_action_log_probs=behaviour_action_logp, + behaviour_policy_logits=behaviour_logits, + target_policy_logits=target_logits, + actions=torch.unbind(actions, dim=2), + discounts=(1.0 - dones.float()) * discount, + rewards=rewards, + values=values, + bootstrap_value=bootstrap_value, + dist_class=dist_class, + model=model, + clip_rho_threshold=clip_rho_threshold, + clip_pg_rho_threshold=clip_pg_rho_threshold) + self.value_targets = self.vtrace_returns.vs + + # The policy gradients loss + self.pi_loss = -torch.sum( + actions_logp * self.vtrace_returns.pg_advantages * valid_mask) + + # The baseline loss + delta = (values - self.vtrace_returns.vs) * valid_mask + self.vf_loss = 0.5 * torch.sum(torch.pow(delta, 2.0)) + + # The entropy loss + self.entropy = torch.sum(actions_entropy * valid_mask) + + # The summed weighted loss + self.total_loss = (self.pi_loss + self.vf_loss * vf_loss_coeff - + self.entropy * entropy_coeff) + + +def build_vtrace_loss(policy, model, dist_class, train_batch): + model_out, _ = model.from_batch(train_batch) + action_dist = dist_class(model_out, model) + + if isinstance(policy.action_space, gym.spaces.Discrete): + is_multidiscrete = False + output_hidden_shape = [policy.action_space.n] + elif isinstance(policy.action_space, + gym.spaces.multi_discrete.MultiDiscrete): + is_multidiscrete = True + output_hidden_shape = policy.action_space.nvec.astype(np.int32) + else: + is_multidiscrete = False + output_hidden_shape = 1 + + def _make_time_major(*args, **kw): + return make_time_major(policy, train_batch.get("seq_lens"), *args, + **kw) + + actions = train_batch[SampleBatch.ACTIONS] + dones = train_batch[SampleBatch.DONES] + rewards = train_batch[SampleBatch.REWARDS] + behaviour_action_logp = train_batch[SampleBatch.ACTION_LOGP] + behaviour_logits = train_batch[SampleBatch.ACTION_DIST_INPUTS] + if isinstance(output_hidden_shape, list): + unpacked_behaviour_logits = torch.split( + behaviour_logits, output_hidden_shape, dim=1) + unpacked_outputs = torch.split(model_out, output_hidden_shape, dim=1) + else: + unpacked_behaviour_logits = torch.chunk( + behaviour_logits, output_hidden_shape, dim=1) + unpacked_outputs = torch.chunk(model_out, output_hidden_shape, dim=1) + values = model.value_function() + + if policy.is_recurrent(): + max_seq_len = torch.max(train_batch["seq_lens"]) - 1 + mask = sequence_mask(train_batch["seq_lens"], max_seq_len) + mask = torch.reshape(mask, [-1]) + else: + mask = torch.ones_like(rewards) + + # Prepare actions for loss. + loss_actions = actions if is_multidiscrete else torch.unsqueeze( + actions, dim=1) + + # Inputs are reshaped from [B * T] => [T - 1, B] for V-trace calc. + policy.loss = VTraceLoss( + actions=_make_time_major(loss_actions, drop_last=True), + actions_logp=_make_time_major( + action_dist.logp(actions), drop_last=True), + actions_entropy=_make_time_major( + action_dist.multi_entropy(), drop_last=True), + dones=_make_time_major(dones, drop_last=True), + behaviour_action_logp=_make_time_major( + behaviour_action_logp, drop_last=True), + behaviour_logits=_make_time_major( + unpacked_behaviour_logits, drop_last=True), + target_logits=_make_time_major(unpacked_outputs, drop_last=True), + discount=policy.config["gamma"], + rewards=_make_time_major(rewards, drop_last=True), + values=_make_time_major(values, drop_last=True), + bootstrap_value=_make_time_major(values)[-1], + dist_class=TorchCategorical if is_multidiscrete else dist_class, + model=model, + valid_mask=_make_time_major(mask, drop_last=True), + config=policy.config, + vf_loss_coeff=policy.config["vf_loss_coeff"], + entropy_coeff=policy.entropy_coeff, + clip_rho_threshold=policy.config["vtrace_clip_rho_threshold"], + clip_pg_rho_threshold=policy.config["vtrace_clip_pg_rho_threshold"]) + + return policy.loss.total_loss + + def make_time_major(policy, seq_lens, tensor, drop_last=False): """Swaps batch and trajectory axis. @@ -44,6 +219,27 @@ def make_time_major(policy, seq_lens, tensor, drop_last=False): return res +def stats(policy, train_batch): + values_batched = make_time_major( + policy, + train_batch.get("seq_lens"), + policy.model.value_function(), + drop_last=policy.config["vtrace"]) + + return { + "cur_lr": policy.cur_lr, + "policy_loss": policy.loss.pi_loss, + "entropy": policy.loss.entropy, + "entropy_coeff": policy.entropy_coeff, + "var_gnorm": global_norm(policy.model.trainable_variables()), + "vf_loss": policy.loss.vf_loss, + "vf_explained_var": explained_variance( + torch.reshape(policy.loss.value_targets, [-1]), + torch.reshape(values_batched, [-1]), + framework="torch"), + } + + def choose_optimizer(policy, config): if policy.config["opt_type"] == "adam": return torch.optim.Adam( @@ -55,3 +251,22 @@ def choose_optimizer(policy, config): weight_decay=config["decay"], momentum=config["momentum"], eps=config["epsilon"]) + + +def setup_mixins(policy, obs_space, action_space, config): + EntropyCoeffSchedule.__init__(policy, config["entropy_coeff"], + config["entropy_coeff_schedule"]) + LearningRateSchedule.__init__(policy, config["lr"], config["lr_schedule"]) + + +VTraceTorchPolicy = build_torch_policy( + name="VTraceTorchPolicy", + loss_fn=build_vtrace_loss, + get_default_config=lambda: ray.rllib.agents.impala.impala.DEFAULT_CONFIG, + stats_fn=stats, + postprocess_fn=postprocess_trajectory, + extra_grad_process_fn=apply_grad_clipping, + optimizer_fn=choose_optimizer, + before_init=setup_mixins, + mixins=[LearningRateSchedule, EntropyCoeffSchedule], + get_batch_divisibility_req=lambda p: p.config["rollout_fragment_length"]) diff --git a/rllib/agents/ppo/appo.py b/rllib/agents/ppo/appo.py index 535a2013a..01d0b0ade 100644 --- a/rllib/agents/ppo/appo.py +++ b/rllib/agents/ppo/appo.py @@ -1,6 +1,7 @@ +from ray.rllib.agents.impala.impala import validate_config from ray.rllib.agents.ppo.appo_tf_policy import AsyncPPOTFPolicy -from ray.rllib.agents.trainer import with_base_config from ray.rllib.agents.ppo.ppo import update_kl +from ray.rllib.agents.trainer import with_base_config from ray.rllib.agents import impala # yapf: disable @@ -89,11 +90,6 @@ def get_policy_class(config): return AsyncPPOTFPolicy -def validate_config(config): - if config["entropy_coeff"] < 0: - raise ValueError("`entropy_coeff` must be >= 0.0!") - - APPOTrainer = impala.ImpalaTrainer.with_updates( name="APPO", default_config=DEFAULT_CONFIG, diff --git a/rllib/agents/ppo/appo_tf_policy.py b/rllib/agents/ppo/appo_tf_policy.py index 1cddbc6fa..a6ea1f769 100644 --- a/rllib/agents/ppo/appo_tf_policy.py +++ b/rllib/agents/ppo/appo_tf_policy.py @@ -8,7 +8,7 @@ import gym from ray.rllib.agents.impala import vtrace_tf as vtrace from ray.rllib.agents.impala.vtrace_tf_policy import _make_time_major, \ - clip_gradients, validate_config, choose_optimizer + clip_gradients, choose_optimizer from ray.rllib.evaluation.postprocessing import Postprocessing from ray.rllib.models.tf.tf_action_dist import Categorical from ray.rllib.policy.sample_batch import SampleBatch @@ -449,7 +449,6 @@ AsyncPPOTFPolicy = build_tf_policy( optimizer_fn=choose_optimizer, gradients_fn=clip_gradients, extra_action_fetches_fn=add_values, - before_init=validate_config, before_loss_init=setup_mixins, after_init=setup_late_mixins, mixins=[ diff --git a/rllib/models/catalog.py b/rllib/models/catalog.py index 92f96ac15..c00496d79 100644 --- a/rllib/models/catalog.py +++ b/rllib/models/catalog.py @@ -528,9 +528,6 @@ class ModelCatalog: FCNet) from ray.rllib.models.torch.visionnet import (VisionNetwork as VisionNet) - if model_config.get("use_lstm"): - raise NotImplementedError( - "LSTM auto-wrapping not implemented for torch") else: from ray.rllib.models.tf.fcnet_v2 import \ FullyConnectedNetwork as FCNet diff --git a/rllib/models/torch/torch_action_dist.py b/rllib/models/torch/torch_action_dist.py index 424188e18..ce9c19038 100644 --- a/rllib/models/torch/torch_action_dist.py +++ b/rllib/models/torch/torch_action_dist.py @@ -237,6 +237,11 @@ class TorchSquashedGaussian(TorchDistributionWrapper): unsquashed = atanh(save_normed_values) return unsquashed + @staticmethod + @override(ActionDistribution) + def required_model_output_shape(action_space, model_config): + return np.prod(action_space.shape) * 2 + class TorchBeta(TorchDistributionWrapper): """ @@ -285,6 +290,11 @@ class TorchBeta(TorchDistributionWrapper): def _unsquash(self, values): return (values - self.low) / (self.high - self.low) + @staticmethod + @override(ActionDistribution) + def required_model_output_shape(action_space, model_config): + return np.prod(action_space.shape) * 2 + class TorchDeterministic(TorchDistributionWrapper): """Action distribution that returns the input values directly. diff --git a/rllib/tuned_examples/regression_tests/cartpole-impala-tf.yaml b/rllib/tuned_examples/regression_tests/cartpole-impala-tf.yaml new file mode 100644 index 000000000..8233220fb --- /dev/null +++ b/rllib/tuned_examples/regression_tests/cartpole-impala-tf.yaml @@ -0,0 +1,9 @@ +cartpole-impala-tf: + env: CartPole-v0 + run: IMPALA + stop: + episode_reward_mean: 150 + timesteps_total: 500000 + config: + use_pytorch: false + num_gpus: 0 diff --git a/rllib/tuned_examples/regression_tests/cartpole-impala-torch.yaml b/rllib/tuned_examples/regression_tests/cartpole-impala-torch.yaml new file mode 100644 index 000000000..e29d2caad --- /dev/null +++ b/rllib/tuned_examples/regression_tests/cartpole-impala-torch.yaml @@ -0,0 +1,9 @@ +cartpole-impala-torch: + env: CartPole-v0 + run: IMPALA + stop: + episode_reward_mean: 150 + timesteps_total: 500000 + config: + use_pytorch: true + num_gpus: 0 diff --git a/rllib/utils/exploration/tests/test_explorations.py b/rllib/utils/exploration/tests/test_explorations.py index 3e0d87997..211c19d3c 100644 --- a/rllib/utils/exploration/tests/test_explorations.py +++ b/rllib/utils/exploration/tests/test_explorations.py @@ -30,10 +30,7 @@ def do_test_explorations(run, # Test all frameworks. for fw in framework_iterator(core_config): - if fw == "torch" and \ - run in [impala.ImpalaTrainer, sac.SACTrainer]: - continue - elif fw == "eager" and run in [ + if fw == "eager" and run in [ ddpg.DDPGTrainer, sac.SACTrainer, td3.TD3Trainer ]: continue diff --git a/rllib/utils/test_utils.py b/rllib/utils/test_utils.py index 49bf4269d..600040d91 100644 --- a/rllib/utils/test_utils.py +++ b/rllib/utils/test_utils.py @@ -47,7 +47,7 @@ def framework_iterator(config=None, logger.warning( "framework_iterator skipping torch (not installed)!") continue - elif not tf: + if fw != "torch" and not tf: logger.warning("framework_iterator skipping {} (tf not " "installed)!".format(fw)) continue