diff --git a/rllib/BUILD b/rllib/BUILD index eea7f3c9b..a7f035bda 100644 --- a/rllib/BUILD +++ b/rllib/BUILD @@ -883,7 +883,7 @@ py_test( py_test( name = "policy/tests/test_compute_log_likelihoods", tags = ["policy"], - size = "small", + size = "medium", srcs = ["policy/tests/test_compute_log_likelihoods.py"] ) @@ -966,7 +966,7 @@ py_test( py_test( name = "tests/test_eager_support", tags = ["tests_dir", "tests_dir_E"], - size = "large", + size = "enormous", srcs = ["tests/test_eager_support.py"] ) @@ -1361,7 +1361,8 @@ sh_test( ) py_test( - name = "examples/rock_paper_scissors_multiagent", main = "examples/rock_paper_scissors_multiagent.py", + name = "examples/rock_paper_scissors_multiagent", + main = "examples/rock_paper_scissors_multiagent.py", tags = ["examples", "examples_R"], size = "large", srcs = ["examples/rock_paper_scissors_multiagent.py"], diff --git a/rllib/agents/dqn/tests/test_dqn.py b/rllib/agents/dqn/tests/test_dqn.py index a3b6b49a8..aab555c0a 100644 --- a/rllib/agents/dqn/tests/test_dqn.py +++ b/rllib/agents/dqn/tests/test_dqn.py @@ -33,7 +33,7 @@ class TestDQN(unittest.TestCase): tf_config = config.copy() tf_config["eager"] = False trainer = dqn.DQNTrainer(config=tf_config, env="CartPole-v0") - num_iterations = 2 + num_iterations = 1 for i in range(num_iterations): results = trainer.train() print(results) @@ -44,7 +44,7 @@ class TestDQN(unittest.TestCase): eager_ctx = eager_mode() eager_ctx.__enter__() trainer = dqn.DQNTrainer(config=eager_config, env="CartPole-v0") - num_iterations = 2 + num_iterations = 1 for i in range(num_iterations): results = trainer.train() print(results) @@ -58,14 +58,21 @@ class TestDQN(unittest.TestCase): obs = np.array(0) # Test against all frameworks. - for fw in ["eager", "tf", "torch"]: + for fw in ["tf", "eager", "torch"]: if fw == "torch": continue print("framework={}".format(fw)) - config["eager"] = True if fw == "eager" else False - config["use_pytorch"] = True if fw == "torch" else False + eager_mode_ctx = None + if fw == "tf": + assert not tf.executing_eagerly() + else: + eager_mode_ctx = eager_mode() + eager_mode_ctx.__enter__() + + config["eager"] = fw == "eager" + config["use_pytorch"] = fw == "torch" # Default EpsilonGreedy setup. trainer = dqn.DQNTrainer(config=config, env="FrozenLake-v0") @@ -122,5 +129,6 @@ class TestDQN(unittest.TestCase): if __name__ == "__main__": - import unittest - unittest.main(verbosity=1) + import pytest + import sys + sys.exit(pytest.main(["-v", __file__])) diff --git a/rllib/agents/trainer.py b/rllib/agents/trainer.py index 61d70358c..5f47f3c8a 100644 --- a/rllib/agents/trainer.py +++ b/rllib/agents/trainer.py @@ -419,7 +419,8 @@ class Trainer(Trainable): config = config or {} if tf and config.get("eager"): - tf.enable_eager_execution() + if not tf.executing_eagerly(): + tf.enable_eager_execution() logger.info("Executing eagerly, with eager_tracing={}".format( "True" if config.get("eager_tracing") else "False")) diff --git a/rllib/evaluation/sample_batch_builder.py b/rllib/evaluation/sample_batch_builder.py index 073b61f0e..e858d31c0 100644 --- a/rllib/evaluation/sample_batch_builder.py +++ b/rllib/evaluation/sample_batch_builder.py @@ -151,6 +151,9 @@ class MultiAgentSampleBatchBuilder: "from a single trajectory.", pre_batch) post_batches[agent_id] = policy.postprocess_trajectory( pre_batch, other_batches, episode) + # Call the Policy's Exploration's postprocess method. + policy.exploration.postprocess_trajectory( + policy, post_batches[agent_id], getattr(policy, "_sess", None)) if log_once("after_post"): logger.info( diff --git a/rllib/evaluation/sampler.py b/rllib/evaluation/sampler.py index 7f413b0c5..9358f289a 100644 --- a/rllib/evaluation/sampler.py +++ b/rllib/evaluation/sampler.py @@ -306,6 +306,14 @@ def _env_runner(base_env, extra_batch_callback, policies, policy_mapping_fn, def new_episode(): episode = MultiAgentEpisode(policies, policy_mapping_fn, get_batch_builder, extra_batch_callback) + # Call each policy's Exploration.on_episode_start method. + for p in policies.values(): + p.exploration.on_episode_start( + policy=p, + environment=base_env, + episode=episode, + tf_sess=getattr(p, "_sess", None)) + # Call custom on_episode_start callback. if callbacks.get("on_episode_start"): callbacks["on_episode_start"]({ "env": base_env, @@ -492,6 +500,14 @@ def _process_observations(base_env, policies, batch_builder_pool, if all_done: # Handle episode termination batch_builder_pool.append(episode.batch_builder) + # Call each policy's Exploration.on_episode_end method. + for p in policies.values(): + p.exploration.on_episode_end( + policy=p, + environment=base_env, + episode=episode, + tf_sess=getattr(p, "_sess", None)) + # Call custom on_episode_end callback. if callbacks.get("on_episode_end"): callbacks["on_episode_end"]({ "env": base_env, @@ -558,15 +574,15 @@ def _do_policy_eval(tf_sess, to_eval, policies, active_episodes): policy = _get_or_raise(policies, policy_id) if builder and (policy.compute_actions.__code__ is TFPolicy.compute_actions.__code__): - rnn_in_cols = _to_column_format(rnn_in) + + obs_batch = [t.obs for t in eval_data] + state_batches = _to_column_format(rnn_in) + # TODO(ekl): how can we make info batch available to TF code? - # TODO(sven): Return dict from _build_compute_actions. - # it's becoming more and more unclear otherwise, what's where in - # the return tuple. pending_fetches[policy_id] = policy._build_compute_actions( builder, - obs_batch=[t.obs for t in eval_data], - state_batches=rnn_in_cols, + obs_batch=obs_batch, + state_batches=state_batches, prev_action_batch=[t.prev_action for t in eval_data], prev_reward_batch=[t.prev_reward for t in eval_data], timestep=policy.global_timestep) diff --git a/rllib/policy/eager_tf_policy.py b/rllib/policy/eager_tf_policy.py index bed04b468..9c02f9155 100644 --- a/rllib/policy/eager_tf_policy.py +++ b/rllib/policy/eager_tf_policy.py @@ -261,15 +261,16 @@ def build_eager_tf_policy(name, @override(Policy) def postprocess_trajectory(self, - samples, + sample_batch, other_agent_batches=None, episode=None): assert tf.executing_eagerly() + # Call super's postprocess_trajectory first. + sample_batch = Policy.postprocess_trajectory(self, sample_batch) if postprocess_fn: - return postprocess_fn(self, samples, other_agent_batches, + return postprocess_fn(self, sample_batch, other_agent_batches, episode) - else: - return samples + return sample_batch @override(Policy) @convert_eager_inputs @@ -305,6 +306,8 @@ def build_eager_tf_policy(name, explore = explore if explore is not None else \ self.config["explore"] + timestep = timestep if timestep is not None else \ + self.global_timestep # TODO: remove python side effect to cull sources of bugs. self._is_training = False @@ -339,19 +342,20 @@ def build_eager_tf_policy(name, self.action_space, explore, self.config, - timestep=timestep - if timestep is not None else self.global_timestep) + timestep=timestep) # Use Exploration object. else: with tf.variable_creator_scope(_disallow_var_creation): + # Call the exploration before_compute_actions hook. + self.exploration.before_compute_actions(timestep=timestep) + model_out, state_out = self.model(input_dict, state_batches, seq_lens) action, logp = self.exploration.get_exploration_action( model_out, self.dist_class, self.model, - timestep=timestep - if timestep is not None else self.global_timestep, + timestep=timestep, explore=explore) extra_fetches = {} diff --git a/rllib/policy/policy.py b/rllib/policy/policy.py index 44f5446b3..a3b20b2a6 100644 --- a/rllib/policy/policy.py +++ b/rllib/policy/policy.py @@ -284,27 +284,6 @@ class Policy(metaclass=ABCMeta): """ return self.exploration.get_info() - @DeveloperAPI - def get_exploration_state(self): - """Returns the current exploration state of this policy. - - This state depends on the policy's Exploration object. - - Returns: - any: Serializable copy or view of the current exploration state. - """ - raise NotImplementedError - - @DeveloperAPI - def set_exploration_state(self, exploration_state): - """Sets the current exploration state of this Policy. - - Arguments: - exploration_state (any): Serializable copy or view of the new - exploration state. - """ - raise NotImplementedError - @DeveloperAPI def is_recurrent(self): """Whether this Policy holds a recurrent Model. diff --git a/rllib/policy/tests/test_compute_log_likelihoods.py b/rllib/policy/tests/test_compute_log_likelihoods.py index 645aa4cf6..3017dc851 100644 --- a/rllib/policy/tests/test_compute_log_likelihoods.py +++ b/rllib/policy/tests/test_compute_log_likelihoods.py @@ -45,14 +45,15 @@ def do_test_log_likelihood(run, config["use_pytorch"] = fw == "torch" eager_ctx = None - if fw == "eager": + if fw == "tf": + assert not tf.executing_eagerly() + elif fw == "eager": eager_ctx = eager_mode() eager_ctx.__enter__() assert tf.executing_eagerly() - elif fw == "tf": - assert not tf.executing_eagerly() trainer = run(config=config, env=env) + policy = trainer.get_policy() vars = policy.get_weights() # Sample n actions, then roughly check their logp against their diff --git a/rllib/policy/tf_policy.py b/rllib/policy/tf_policy.py index d6310ae88..188b0aa9e 100644 --- a/rllib/policy/tf_policy.py +++ b/rllib/policy/tf_policy.py @@ -253,6 +253,7 @@ class TFPolicy(Policy): timestep=None, **kwargs): explore = explore if explore is not None else self.config["explore"] + builder = TFRunBuilder(self._sess, "compute_actions") fetches = self._build_compute_actions( builder, @@ -528,11 +529,16 @@ class TFPolicy(Policy): explore = explore if explore is not None else self.config["explore"] + # Call the exploration before_compute_actions hook. + self.exploration.before_compute_actions( + timestep=self.global_timestep, tf_sess=self.get_session()) + state_batches = state_batches or [] if len(self._state_inputs) != len(state_batches): raise ValueError( "Must pass in RNN state batches for placeholders {}, got {}". format(self._state_inputs, state_batches)) + builder.add_feed_dict(self.extra_compute_action_feed_dict()) builder.add_feed_dict({self._obs_input: obs_batch}) if state_batches: diff --git a/rllib/policy/tf_policy_template.py b/rllib/policy/tf_policy_template.py index e704410fe..0eefc9da3 100644 --- a/rllib/policy/tf_policy_template.py +++ b/rllib/policy/tf_policy_template.py @@ -151,10 +151,10 @@ def build_tf_policy(name, sample_batch, other_agent_batches=None, episode=None): - if not postprocess_fn: - return sample_batch - return postprocess_fn(self, sample_batch, other_agent_batches, - episode) + if postprocess_fn: + return postprocess_fn(self, sample_batch, other_agent_batches, + episode) + return sample_batch @override(TFPolicy) def optimizer(self): diff --git a/rllib/policy/torch_policy.py b/rllib/policy/torch_policy.py index 4e51f70e9..432448738 100644 --- a/rllib/policy/torch_policy.py +++ b/rllib/policy/torch_policy.py @@ -71,6 +71,7 @@ class TorchPolicy(Policy): **kwargs): explore = explore if explore is not None else self.config["explore"] + timestep = timestep if timestep is not None else self.global_timestep with torch.no_grad(): input_dict = self._lazy_tensor_dict({ @@ -81,6 +82,10 @@ class TorchPolicy(Policy): if prev_reward_batch: input_dict[SampleBatch.PREV_REWARDS] = prev_reward_batch state_batches = [self._convert_to_tensor(s) for s in state_batches] + + # Call the exploration before_compute_actions hook. + self.exploration.before_compute_actions(timestep=timestep) + model_out = self.model(input_dict, state_batches, self._convert_to_tensor([1])) logits, state = model_out @@ -88,8 +93,7 @@ class TorchPolicy(Policy): actions, logp = \ self.exploration.get_exploration_action( logits, self.dist_class, self.model, - timestep if timestep is not None else - self.global_timestep, explore) + timestep, explore) input_dict[SampleBatch.ACTIONS] = actions extra_action_out = self.extra_action_out(input_dict, state_batches, @@ -100,8 +104,8 @@ class TorchPolicy(Policy): ACTION_PROB: np.exp(logp), ACTION_LOGP: logp }) - return convert_to_non_torch_type( - (actions, state, extra_action_out)) + return convert_to_non_torch_type((actions, state, + extra_action_out)) @override(Policy) def compute_log_likelihoods(self, diff --git a/rllib/policy/torch_policy_template.py b/rllib/policy/torch_policy_template.py index df67c4f13..ac45341a9 100644 --- a/rllib/policy/torch_policy_template.py +++ b/rllib/policy/torch_policy_template.py @@ -86,10 +86,8 @@ def build_torch_policy(name, self.config["model"], framework="torch") - TorchPolicy.__init__( - self, obs_space, action_space, config, self.model, - loss_fn, self.dist_class - ) + TorchPolicy.__init__(self, obs_space, action_space, config, + self.model, loss_fn, self.dist_class) if after_init: after_init(self, obs_space, action_space, config) @@ -117,17 +115,18 @@ def build_torch_policy(name, return TorchPolicy.extra_grad_process(self) @override(TorchPolicy) - def extra_action_out(self, input_dict, state_batches, model, + def extra_action_out(self, + input_dict, + state_batches, + model, action_dist=None): with torch.no_grad(): if extra_action_out_fn: stats_dict = extra_action_out_fn( - self, input_dict, state_batches, model, action_dist - ) + self, input_dict, state_batches, model, action_dist) else: stats_dict = TorchPolicy.extra_action_out( - self, input_dict, state_batches, model, action_dist - ) + self, input_dict, state_batches, model, action_dist) return convert_to_non_torch_type(stats_dict) @override(TorchPolicy) diff --git a/rllib/tests/test_eager_support.py b/rllib/tests/test_eager_support.py index f562930c7..2d2dc9853 100644 --- a/rllib/tests/test_eager_support.py +++ b/rllib/tests/test_eager_support.py @@ -87,8 +87,9 @@ class TestEagerSupport(unittest.TestCase): }, }) - def test_sac(self): - check_support("SAC", {"num_workers": 0}) + # TODO(sven): Add this once SAC supports eager. + # def test_sac(self): + # check_support("SAC", {"num_workers": 0, "learning_starts": 0}) if __name__ == "__main__": diff --git a/rllib/utils/exploration/exploration.py b/rllib/utils/exploration/exploration.py index 78f1eaa95..1a4fe15a6 100644 --- a/rllib/utils/exploration/exploration.py +++ b/rllib/utils/exploration/exploration.py @@ -1,12 +1,15 @@ from gym.spaces import Space +from typing import Union + from ray.rllib.utils.framework import check_framework, try_import_tf, \ TensorType from ray.rllib.models.modelv2 import ModelV2 -from typing import Union +from ray.rllib.utils.annotations import DeveloperAPI tf = try_import_tf() +@DeveloperAPI class Exploration: """Implements an exploration strategy for Policies. @@ -32,6 +35,24 @@ class Exploration: self.worker_index = worker_index self.framework = check_framework(framework) + @DeveloperAPI + def before_compute_actions(self, + *, + timestep=None, + explore=None, + tf_sess=None, + **kwargs): + """Hook for preparations before policy.compute_actions() is called. + + Args: + timestep (Optional[TensorType]): An optional timestep tensor. + explore (Optional[TensorType]): An optional explore boolean flag. + tf_sess (Optional[tf.Session]): The tf-session object to use. + **kwargs: Forward compatibility kwargs. + """ + pass + + @DeveloperAPI def get_exploration_action(self, distribution_inputs: TensorType, action_dist_class: type, @@ -64,25 +85,55 @@ class Exploration: """ pass - def get_loss_exploration_term(self, - model_output: TensorType, - model: ModelV2, - action_dist: type, - action_sample: TensorType = None): - """Returns an extra loss term to be added to a loss. + @DeveloperAPI + def on_episode_start(self, + policy, + *, + environment=None, + episode=None, + tf_sess=None): + """Handles necessary exploration logic at the beginning of an episode. Args: - model_output (TensorType): The Model's output Tensor(s). - model (ModelV2): The Model object. - action_dist: The ActionDistribution object resulting from - `model_output`. TODO: Or the class? - action_sample (TensorType): An optional action sample. - - Returns: - TensorType: The extra loss term to add to the loss. + policy (Policy): The Policy object that holds this Exploration. + environment (BaseEnv): The environment object we are acting in. + episode (int): The number of the episode that is starting. + tf_sess (Optional[tf.Session]): In case of tf, the session object. """ - pass # TODO(sven): implement for some example Exploration class. + pass + @DeveloperAPI + def on_episode_end(self, + policy, + *, + environment=None, + episode=None, + tf_sess=None): + """Handles necessary exploration logic at the end of an episode. + + Args: + policy (Policy): The Policy object that holds this Exploration. + environment (BaseEnv): The environment object we are acting in. + episode (int): The number of the episode that is starting. + tf_sess (Optional[tf.Session]): In case of tf, the session object. + """ + pass + + @DeveloperAPI + def postprocess_trajectory(self, policy, sample_batch, tf_sess=None): + """Handles post-processing of done episode trajectories. + + Changes the given batch in place. This callback is invoked by the + sampler after policy.postprocess_trajectory() is called. + + Args: + policy (Policy): The owning policy object. + sample_batch (SampleBatch): The SampleBatch object to post-process. + tf_sess (Optional[tf.Session]): An optional tf.Session object. + """ + return sample_batch + + @DeveloperAPI def get_info(self): """Returns a description of the current exploration state. diff --git a/rllib/utils/exploration/parameter_noise.py b/rllib/utils/exploration/parameter_noise.py new file mode 100644 index 000000000..0c5d0ff42 --- /dev/null +++ b/rllib/utils/exploration/parameter_noise.py @@ -0,0 +1,382 @@ +from gym.spaces import Discrete +import numpy as np + +from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.models.modelv2 import ModelV2 +from ray.rllib.models.tf.tf_action_dist import Categorical +from ray.rllib.utils.annotations import override +from ray.rllib.utils.exploration.exploration import Exploration +from ray.rllib.utils.framework import try_import_tf, try_import_torch +from ray.rllib.utils.framework import get_variable +from ray.rllib.utils.from_config import from_config +from ray.rllib.utils.numpy import softmax, SMALL_NUMBER + +tf = try_import_tf() +torch, _ = try_import_torch() + + +class ParameterNoise(Exploration): + """An exploration that changes a Model's parameters. + + Implemented based on: + [1] https://blog.openai.com/better-exploration-with-parameter-noise/ + [2] https://arxiv.org/pdf/1706.01905.pdf + + At the beginning of an episode, Gaussian noise is added to all weights + of the model. At the end of the episode, the noise is undone and an action + diff (pi-delta) is calculated, from which we determine the changes in the + noise's stddev for the next episode. + """ + + def __init__(self, + action_space, + *, + framework: str, + policy_config: dict, + model: ModelV2, + initial_stddev=1.0, + random_timesteps=10000, + sub_exploration=None, + **kwargs): + """Initializes a ParameterNoise Exploration object. + + Args: + initial_stddev (float): The initial stddev to use for the noise. + random_timesteps (int): The number of timesteps to act completely + randomly (see [1]). + sub_exploration (Optional[dict]): Optional sub-exploration config. + None for auto-detection/setup. + """ + assert framework is not None + super().__init__(action_space, framework=framework, **kwargs) + + # TODO(sven): Move these to base-Exploration class. + self.policy_config = policy_config, + self.model = model, + + self.stddev = get_variable( + initial_stddev, framework=self.framework, tf_name="stddev") + self.stddev_val = initial_stddev # Out-of-graph tf value holder. + + # The weight variables of the Model where noise should be applied to. + # This excludes any variable, whose name contains "LayerNorm" (those + # are BatchNormalization layers, which should not be perturbed). + self.model_variables = [ + v for v in self.model.variables() if "LayerNorm" not in v.name + ] + # Our noise to be added to the weights. Each item in `self.noise` + # corresponds to one Model variable and holding the Gaussian noise to + # be added to that variable (weight). + self.noise = [] + for var in self.model_variables: + self.noise.append( + get_variable( + np.zeros(var.shape, dtype=np.float32), + framework=self.framework, + tf_name=var.name.split(":")[0] + "_noisy")) + + # tf-specific ops to sample, assign and remove noise. + if self.framework == "tf" and not tf.executing_eagerly(): + self.tf_sample_new_noise_op = \ + self._tf_sample_new_noise_op() + self.tf_add_stored_noise_op = \ + self._tf_add_stored_noise_op() + self.tf_remove_noise_op = \ + self._tf_remove_noise_op() + # Create convenience sample+add op for tf. + with tf.control_dependencies([self.tf_sample_new_noise_op]): + add_op = self._tf_add_stored_noise_op() + with tf.control_dependencies([add_op]): + self.tf_sample_new_noise_and_add_op = tf.no_op() + + # Whether the Model's weights currently have noise added or not. + self.weights_are_currently_noisy = False + + # Auto-detection of underlying exploration functionality. + if sub_exploration is None: + # For discrete action spaces, use an underlying EpsilonGreedy with + # a special schedule. + if isinstance(self.action_space, Discrete): + sub_exploration = { + "type": "EpsilonGreedy", + "epsilon_schedule": { + "type": "PiecewiseSchedule", + # Step function (see [2]). + "endpoints": [(0, 1.0), (random_timesteps + 1, 1.0), + (random_timesteps + 2, 0.01)], + "outside_value": 0.01 + } + } + # TODO(sven): Implement for any action space. + else: + raise NotImplementedError + + self.sub_exploration = from_config( + Exploration, + sub_exploration, + framework=self.framework, + action_space=self.action_space, + **kwargs) + + # Store the default setting for `explore`. + self.default_explore = policy_config["explore"] + # Whether we need to call `self._delayed_on_episode_start` before + # the forward pass. + self.episode_started = False + + @override(Exploration) + def before_compute_actions(self, + *, + timestep=None, + explore=None, + tf_sess=None): + # Is this the first forward pass in the new episode? If yes, do the + # noise re-sampling and add to weights. + if self.episode_started: + self._delayed_on_episode_start(tf_sess) + + explore = explore if explore is not None else \ + self.policy_config["explore"] + # Add noise if necessary. + if explore and not self.weights_are_currently_noisy: + self._add_stored_noise(tf_sess=tf_sess) + # Remove noise if necessary. + elif not explore and self.weights_are_currently_noisy: + self._remove_noise(tf_sess=tf_sess) + + @override(Exploration) + def get_exploration_action(self, + *, + distribution_inputs, + action_dist_class, + timestep, + explore=True): + # Use our sub-exploration object to handle the final exploration + # action (depends on the algo-type/action-space/etc..). + return self.sub_exploration.get_exploration_action( + distribution_inputs=distribution_inputs, + action_dist_class=action_dist_class, + timestep=timestep, + explore=explore) + + @override(Exploration) + def on_episode_start(self, + policy, + *, + environment=None, + episode=None, + tf_sess=None): + # We have to delay the noise-adding step by one forward call. + # This is due to the fact that the optimizer does it's step right + # after the episode was reset (and hence the noise was already added!). + # We don't want to update into a noisy net. + self.episode_started = True + + def _delayed_on_episode_start(self, tf_sess): + # Sample fresh noise and add to weights. + if self.default_explore: + self._sample_new_noise_and_add(tf_sess=tf_sess, override=True) + # Only sample, don't apply anything to the weights. + else: + self._sample_new_noise(tf_sess=tf_sess) + self.episode_started = False + + @override(Exploration) + def on_episode_end(self, + policy, + *, + environment=None, + episode=None, + tf_sess=None): + # Remove stored noise from weights (only if currently noisy). + if self.weights_are_currently_noisy: + self._remove_noise(tf_sess=tf_sess) + + @override(Exploration) + def postprocess_trajectory(self, policy, sample_batch, tf_sess=None): + noisy_action_dist = noise_free_action_dist = None + # Adjust the stddev depending on the action (pi)-distance. + # Also see [1] for details. + distribution = policy.compute_action_distribution( + obs_batch=sample_batch[SampleBatch.CUR_OBS], + # TODO(sven): What about state-ins and seq-lens? + prev_action_batch=sample_batch.get(SampleBatch.PREV_ACTIONS), + prev_reward_batch=sample_batch.get(SampleBatch.PREV_REWARDS), + explore=self.weights_are_currently_noisy) + + # Categorical case (e.g. DQN). + if isinstance(distribution, Categorical): + action_dist = softmax(distribution.inputs) + else: # TODO(sven): Other action-dist cases. + raise NotImplementedError + + if self.weights_are_currently_noisy: + noisy_action_dist = action_dist + else: + noise_free_action_dist = action_dist + + distribution = policy.compute_action_distribution( + obs_batch=sample_batch[SampleBatch.CUR_OBS], + # TODO(sven): What about state-ins and seq-lens? + prev_action_batch=sample_batch.get(SampleBatch.PREV_ACTIONS), + prev_reward_batch=sample_batch.get(SampleBatch.PREV_REWARDS), + explore=not self.weights_are_currently_noisy) + + # Categorical case (e.g. DQN). + if isinstance(distribution, Categorical): + action_dist = softmax(distribution.inputs) + + if not self.weights_are_currently_noisy: + noisy_action_dist = action_dist + else: + noise_free_action_dist = action_dist + + # Categorical case (e.g. DQN). + if isinstance(distribution, Categorical): + # Calculate KL-divergence (DKL(clean||noisy)) according to [2]. + # TODO(sven): Allow KL-divergence to be calculated by our + # Distribution classes (don't support off-graph/numpy yet). + kl_divergence = np.nanmean( + np.sum( + noise_free_action_dist * + np.log(noise_free_action_dist / + (noisy_action_dist + SMALL_NUMBER)), 1)) + current_epsilon = self.sub_exploration.get_info()["cur_epsilon"] + if tf_sess is not None: + current_epsilon = tf_sess.run(current_epsilon) + delta = -np.log(1 - current_epsilon + + current_epsilon / self.action_space.n) + if kl_divergence <= delta: + self.stddev_val *= 1.01 + else: + self.stddev_val /= 1.01 + + # Set self.stddev to calculated value. + if self.framework == "tf": + self.stddev.load(self.stddev_val, session=tf_sess) + else: + self.stddev = self.stddev_val + + return sample_batch + + def _sample_new_noise(self, *, tf_sess=None): + """Samples new noise and stores it in `self.noise`.""" + if self.framework == "tf": + if tf.executing_eagerly(): + self._tf_sample_new_noise_op() + else: + tf_sess.run(self.tf_sample_new_noise_op) + else: + for i in range(len(self.noise)): + self.noise[i] = torch.normal( + 0.0, self.stddev, size=self.noise[i].size) + + def _tf_sample_new_noise_op(self): + added_noises = [] + for noise in self.noise: + added_noises.append( + tf.assign( + noise, + tf.random_normal( + shape=noise.shape, + stddev=self.stddev, + dtype=tf.float32))) + return tf.group(*added_noises) + + def _sample_new_noise_and_add(self, *, tf_sess=None, override=False): + if self.framework == "tf" and not tf.executing_eagerly(): + if override and self.weights_are_currently_noisy: + tf_sess.run(self.tf_remove_noise_op) + tf_sess.run(self.tf_sample_new_noise_and_add_op) + else: + if override and self.weights_are_currently_noisy: + self._remove_noise() + self._sample_new_noise() + self._add_stored_noise() + + self.weights_are_currently_noisy = True + + def _add_stored_noise(self, *, tf_sess=None): + """Adds the stored `self.noise` to the model's parameters. + + Note: No new sampling of noise here. + + Args: + tf_sess (Optional[tf.Session]): The tf-session to use to add the + stored noise to the (currently noise-free) weights. + override (bool): If True, undo any currently applied noise first, + then add the currently stored noise. + """ + # Make sure we only add noise to currently noise-free weights. + assert self.weights_are_currently_noisy is False + + if self.framework == "tf": + if tf.executing_eagerly(): + self._tf_add_stored_noise_op() + else: + tf_sess.run(self.tf_add_stored_noise_op) + # Add stored noise to the model's parameters. + else: + for i in range(len(self.noise)): + # Add noise to weights in-place. + torch.add_(self.model_variables[i], self.noise[i]) + + self.weights_are_currently_noisy = True + + def _tf_add_stored_noise_op(self): + """Generates tf-op that assigns the stored noise to weights. + + Also used by tf-eager. + + Returns: + tf.op: The tf op to apply the already stored noise to the NN. + """ + add_noise_ops = list() + for var, noise in zip(self.model_variables, self.noise): + add_noise_ops.append(tf.assign_add(var, noise)) + ret = tf.group(*tuple(add_noise_ops)) + with tf.control_dependencies([ret]): + return tf.no_op() + + def _remove_noise(self, *, tf_sess=None): + """ + Removes the current action noise from the model parameters. + + Args: + tf_sess (Optional[tf.Session]): The tf-session to use to remove + the noise from the (currently noisy) weights. + """ + # Make sure we only remove noise iff currently noisy. + assert self.weights_are_currently_noisy is True + + if self.framework == "tf": + if tf.executing_eagerly(): + self._tf_remove_noise_op() + else: + tf_sess.run(self.tf_remove_noise_op) + else: + # Removes the stored noise from the model's parameters. + for var, noise in zip(self.model_variables, self.noise): + # Remove noise from weights in-place. + torch.add_(var, -noise) + + self.weights_are_currently_noisy = False + + def _tf_remove_noise_op(self): + """Generates a tf-op for removing noise from the model's weights. + + Also used by tf-eager. + + Returns: + tf.op: The tf op to remve the currently stored noise from the NN. + """ + remove_noise_ops = list() + for var, noise in zip(self.model_variables, self.noise): + remove_noise_ops.append(tf.assign_add(var, -noise)) + ret = tf.group(*tuple(remove_noise_ops)) + with tf.control_dependencies([ret]): + return tf.no_op() + + @override(Exploration) + def get_info(self): + return {"cur_stddev": self.stddev} diff --git a/rllib/utils/exploration/tests/test_explorations.py b/rllib/utils/exploration/tests/test_explorations.py index c136c9f25..71d67621e 100644 --- a/rllib/utils/exploration/tests/test_explorations.py +++ b/rllib/utils/exploration/tests/test_explorations.py @@ -12,7 +12,9 @@ import ray.rllib.agents.impala as impala import ray.rllib.agents.pg as pg import ray.rllib.agents.ppo as ppo import ray.rllib.agents.sac as sac -from ray.rllib.utils import check +from ray.rllib.utils import check, try_import_tf + +tf = try_import_tf() def do_test_explorations(run, @@ -53,6 +55,9 @@ def do_test_explorations(run, eager_mode_ctx = eager_mode() if fw == "eager": eager_mode_ctx.__enter__() + assert tf.executing_eagerly() + elif fw == "tf": + assert not tf.executing_eagerly() trainer = run(config=config, env=env) diff --git a/rllib/utils/framework.py b/rllib/utils/framework.py index 4b85e8ad4..ece6ebac3 100644 --- a/rllib/utils/framework.py +++ b/rllib/utils/framework.py @@ -134,7 +134,12 @@ def get_variable(value, framework="tf", tf_name="unnamed-variable"): """ if framework == "tf": import tensorflow as tf - return tf.compat.v1.get_variable(tf_name, initializer=value) + dtype = getattr( + value, "dtype", tf.float32 + if isinstance(value, float) else tf.int32 + if isinstance(value, int) else None) + return tf.compat.v1.get_variable( + tf_name, initializer=value, dtype=dtype) # torch or None: Return python primitive. return value