diff --git a/rllib/BUILD b/rllib/BUILD index 977c78df4..9f3d0b68f 100644 --- a/rllib/BUILD +++ b/rllib/BUILD @@ -1005,6 +1005,13 @@ py_test( srcs = ["tests/test_dependency.py"] ) +py_test( + name = "tests/test_dependency_torch", + tags = ["tests_dir", "tests_dir_D"], + size = "small", + srcs = ["tests/test_dependency_torch.py"] +) + py_test( name = "tests/test_eager_support", tags = ["tests_dir", "tests_dir_E"], diff --git a/rllib/agents/qmix/qmix_policy.py b/rllib/agents/qmix/qmix_policy.py index 50febd1bd..13430a88b 100644 --- a/rllib/agents/qmix/qmix_policy.py +++ b/rllib/agents/qmix/qmix_policy.py @@ -1,8 +1,6 @@ from gym.spaces import Tuple, Discrete, Dict import logging import numpy as np -from torch.optim import RMSprop -from torch.distributions import Categorical import ray from ray.rllib.agents.qmix.mixers import VDNMixer, QMixer @@ -244,6 +242,7 @@ class QMixTorchPolicy(Policy): self.loss = QMixLoss(self.model, self.target_model, self.mixer, self.target_mixer, self.n_agents, self.n_actions, self.config["double_q"], self.config["gamma"]) + from torch.optim import RMSprop self.optimiser = RMSprop( params=self.params, lr=config["lr"], @@ -283,6 +282,7 @@ class QMixTorchPolicy(Policy): random_numbers = torch.rand_like(q_values[:, :, 0]) pick_random = (random_numbers < (self.cur_epsilon if explore else 0.0)).long() + from torch.distributions import Categorical random_actions = Categorical(avail).sample().long() actions = (pick_random * random_actions + (1 - pick_random) * masked_q_values.argmax(dim=2)) diff --git a/rllib/tests/test_dependency_torch.py b/rllib/tests/test_dependency_torch.py new file mode 100755 index 000000000..59198bdc1 --- /dev/null +++ b/rllib/tests/test_dependency_torch.py @@ -0,0 +1,22 @@ +#!/usr/bin/env python + +import os +import sys + +if __name__ == "__main__": + # Do not import torch for testing purposes. + os.environ["RLLIB_TEST_NO_TORCH_IMPORT"] = "1" + + from ray.rllib.agents.a3c import A2CTrainer + assert "torch" not in sys.modules, \ + "Torch initially present, when it shouldn't." + + # note: no ray.init(), to test it works without Ray + trainer = A2CTrainer( + env="CartPole-v0", config={ + "use_pytorch": False, + "num_workers": 0 + }) + trainer.train() + + assert "torch" not in sys.modules, "Torch should not be imported" diff --git a/rllib/utils/framework.py b/rllib/utils/framework.py index 342ae4eab..3de14c6f1 100644 --- a/rllib/utils/framework.py +++ b/rllib/utils/framework.py @@ -101,7 +101,10 @@ def try_import_tfp(error=False): # Fake module for torch.nn. class NNStub: - pass + def __init__(self, *a, **kw): + # Fake nn.functional module within torch.nn. + self.functional = None + self.Module = ModuleStub # Fake class for torch.nn.Module to allow it to be inherited from. @@ -120,7 +123,7 @@ def try_import_torch(error=False): """ if "RLLIB_TEST_NO_TORCH_IMPORT" in os.environ: logger.warning("Not importing Torch for test purposes.") - return None, None + return _torch_stubs() try: import torch @@ -129,10 +132,12 @@ def try_import_torch(error=False): except ImportError as e: if error: raise e + return _torch_stubs() - nn = NNStub() - nn.Module = ModuleStub - return None, nn + +def _torch_stubs(): + nn = NNStub() + return None, nn def get_variable(value, @@ -165,7 +170,7 @@ def get_variable(value, return tf.compat.v1.get_variable( tf_name, initializer=value, dtype=dtype, trainable=trainable) elif framework == "torch" and torch_tensor is True: - import torch + torch, _ = try_import_torch() var_ = torch.from_numpy(value) var_.requires_grad = trainable return var_