diff --git a/python/ray/experimental/tf_utils.py b/python/ray/experimental/tf_utils.py index 8fa0d4552..6bc1255a1 100644 --- a/python/ray/experimental/tf_utils.py +++ b/python/ray/experimental/tf_utils.py @@ -1,7 +1,8 @@ from collections import deque, OrderedDict import numpy as np -from ray.rllib.utils import try_import_tf +from ray.rllib.utils import force_list +from ray.rllib.utils.framework import try_import_tf tf = try_import_tf() @@ -47,8 +48,7 @@ class TensorFlowVariables: list. """ self.sess = sess - if not isinstance(output, (list, tuple)): - output = [output] + output = force_list(output) queue = deque(output) variable_names = [] explored_inputs = set(output) diff --git a/rllib/agents/ars/ars_tf_policy.py b/rllib/agents/ars/ars_tf_policy.py index c201000b7..002b97490 100644 --- a/rllib/agents/ars/ars_tf_policy.py +++ b/rllib/agents/ars/ars_tf_policy.py @@ -64,6 +64,12 @@ class ARSTFPolicy: action += np.random.randn(*action.shape) * self.action_noise_std return action + def get_state(self): + return {"state": self.get_flat_weights()} + + def set_state(self, state): + return self.set_flat_weights(state["state"]) + def set_flat_weights(self, x): self.variables.set_flat(x) diff --git a/rllib/agents/es/es_tf_policy.py b/rllib/agents/es/es_tf_policy.py index 166129b73..7c34b25b8 100644 --- a/rllib/agents/es/es_tf_policy.py +++ b/rllib/agents/es/es_tf_policy.py @@ -120,6 +120,12 @@ class ESTFPolicy: self.action_noise_std return single_action + def get_state(self): + return {"state": self.get_flat_weights()} + + def set_state(self, state): + return self.set_flat_weights(state["state"]) + def set_flat_weights(self, x): self.variables.set_flat(x) diff --git a/rllib/policy/eager_tf_policy.py b/rllib/policy/eager_tf_policy.py index 162731de9..8dd15584f 100644 --- a/rllib/policy/eager_tf_policy.py +++ b/rllib/policy/eager_tf_policy.py @@ -332,10 +332,12 @@ def build_eager_tf_policy(name, "is_training": tf.constant(False), } if obs_include_prev_action_reward: - input_dict[SampleBatch.PREV_ACTIONS] = \ - tf.convert_to_tensor(prev_action_batch) - input_dict[SampleBatch.PREV_REWARDS] = \ - tf.convert_to_tensor(prev_reward_batch) + if prev_action_batch is not None: + input_dict[SampleBatch.PREV_ACTIONS] = \ + tf.convert_to_tensor(prev_action_batch) + if prev_reward_batch is not None: + input_dict[SampleBatch.PREV_REWARDS] = \ + tf.convert_to_tensor(prev_reward_batch) # Use Exploration object. with tf.variable_creator_scope(_disallow_var_creation): @@ -464,6 +466,29 @@ def build_eager_tf_policy(name, for v, w in zip(variables, weights): v.assign(w) + @override(Policy) + def get_state(self): + state = {"_state": super().get_state()} + state["_optimizer_variables"] = self._optimizer.variables() + return state + + @override(Policy) + def set_state(self, state): + state = state.copy() # shallow copy + # Set optimizer vars first. + optimizer_vars = state.pop("_optimizer_variables", None) + if optimizer_vars and self._optimizer.variables(): + logger.warning( + "Cannot restore an optimizer's state for tf eager! Keras " + "is not able to save the v1.x optimizers (from " + "tf.compat.v1.train) since they aren't compatible with " + "checkpoints.") + for opt_var, value in zip(self._optimizer.variables(), + optimizer_vars): + opt_var.assign(value) + # Then the Policy's (NN) weights. + super().set_state(state["_state"]) + def variables(self): """Return the list of all savable variables for this policy.""" return self.model.variables() diff --git a/rllib/policy/tf_policy.py b/rllib/policy/tf_policy.py index 58305ea40..d3b00a2d8 100644 --- a/rllib/policy/tf_policy.py +++ b/rllib/policy/tf_policy.py @@ -139,23 +139,11 @@ class TFPolicy(Policy): self._action_input = action_input # For logp calculations. self._dist_inputs = dist_inputs self.dist_class = dist_class - self._log_likelihood = log_likelihood + self._state_inputs = state_inputs or [] self._state_outputs = state_outputs or [] self._seq_lens = seq_lens self._max_seq_len = max_seq_len - self._batch_divisibility_req = batch_divisibility_req - self._update_ops = update_ops - self._stats_fetches = {} - self._loss_input_dict = None - self._timestep = timestep if timestep is not None else \ - tf.placeholder(tf.int32, (), name="timestep") - - if loss is not None: - self._initialize_loss(loss, loss_inputs) - else: - self._loss = None - if len(self._state_inputs) != len(self._state_outputs): raise ValueError( "Number of state input and output tensors must match, got: " @@ -169,9 +157,34 @@ class TFPolicy(Policy): raise ValueError( "seq_lens tensor must be given if state inputs are defined") + self._batch_divisibility_req = batch_divisibility_req + self._update_ops = update_ops + self._apply_op = None + self._stats_fetches = {} + self._timestep = timestep if timestep is not None else \ + tf.placeholder(tf.int32, (), name="timestep") + + self._optimizer = None + self._grads_and_vars = None + self._grads = None + # Policy tf-variables (weights), whose values to get/set via + # get_weights/set_weights. + self._variables = None + # Local optimizer's tf-variables (e.g. state vars for Adam). + # Will be stored alongside `self._variables` when checkpointing. + self._optimizer_variables = None + + # The loss tf-op. + self._loss = None + # A batch dict passed into loss function as input. + self._loss_input_dict = None + if loss is not None: + self._initialize_loss(loss, loss_inputs) + # The log-likelihood calculator op. - self._log_likelihood = None - if self._dist_inputs is not None and self.dist_class is not None: + self._log_likelihood = log_likelihood + if self._log_likelihood is None and self._dist_inputs is not None and \ + self.dist_class is not None: self._log_likelihood = self.dist_class( self._dist_inputs, self.model).logp(self._action_input) @@ -250,6 +263,11 @@ class TFPolicy(Policy): summarize(self._loss_input_dict))) self._sess.run(tf.global_variables_initializer()) + self._optimizer_variables = None + if self._optimizer: + self._optimizer_variables = \ + ray.experimental.tf_utils.TensorFlowVariables( + self._optimizer.variables(), self._sess) @override(Policy) def compute_actions(self, @@ -355,6 +373,26 @@ class TFPolicy(Policy): def set_weights(self, weights): return self._variables.set_weights(weights) + @override(Policy) + def get_state(self): + # For tf Policies, return Policy weights and optimizer var values. + state = super().get_state() + if self._optimizer_variables and \ + len(self._optimizer_variables.variables) > 0: + state["_optimizer_variables"] = \ + self._sess.run(self._optimizer_variables.variables) + return state + + @override(Policy) + def set_state(self, state): + state = state.copy() # shallow copy + # Set optimizer vars first. + optimizer_vars = state.pop("_optimizer_variables", None) + if optimizer_vars: + self._optimizer_variables.set_weights(optimizer_vars) + # Then the Policy's (NN) weights. + super().set_state(state) + @override(Policy) def export_model(self, export_dir): """Export tensorflow graph to export_dir for serving.""" @@ -441,7 +479,7 @@ class TFPolicy(Policy): def optimizer(self): """TF optimizer to use for policy optimization.""" if hasattr(self, "config"): - return tf.train.AdamOptimizer(self.config["lr"]) + return tf.train.AdamOptimizer(learning_rate=self.config["lr"]) else: return tf.train.AdamOptimizer() @@ -686,7 +724,7 @@ class LearningRateSchedule: @override(TFPolicy) def optimizer(self): - return tf.train.AdamOptimizer(self.cur_lr) + return tf.train.AdamOptimizer(learning_rate=self.cur_lr) @DeveloperAPI diff --git a/rllib/policy/torch_policy.py b/rllib/policy/torch_policy.py index a8f8be298..c13ff5f8a 100644 --- a/rllib/policy/torch_policy.py +++ b/rllib/policy/torch_policy.py @@ -323,6 +323,26 @@ class TorchPolicy(Policy): weights = convert_to_torch_tensor(weights, device=self.device) self.model.load_state_dict(weights) + @override(Policy) + def get_state(self): + state = super().get_state() + state["_optimizer_variables"] = [] + for i, o in enumerate(self._optimizers): + state["_optimizer_variables"].append(o.state_dict()) + return state + + @override(Policy) + def set_state(self, state): + state = state.copy() # shallow copy + # Set optimizer vars first. + optimizer_vars = state.pop("_optimizer_variables", None) + if optimizer_vars: + assert len(optimizer_vars) == len(self._optimizers) + for o, s in zip(self._optimizers, optimizer_vars): + o.load_state_dict(s) + # Then the Policy's (NN) weights. + super().set_state(state) + @override(Policy) def is_recurrent(self): return len(self.model.get_initial_state()) > 0 diff --git a/rllib/tests/test_checkpoint_restore.py b/rllib/tests/test_checkpoint_restore.py index b05c59608..994a07b84 100644 --- a/rllib/tests/test_checkpoint_restore.py +++ b/rllib/tests/test_checkpoint_restore.py @@ -5,7 +5,7 @@ import unittest import ray from ray.rllib.agents.registry import get_agent_class -from ray.rllib.utils.test_utils import framework_iterator +from ray.rllib.utils.test_utils import check, framework_iterator def get_mean_action(alg, obs): @@ -63,45 +63,67 @@ CONFIGS = { } -def ckpt_restore_test(use_object_store, alg_name, failures, framework="tf"): - cls = get_agent_class(alg_name) +def ckpt_restore_test(alg_name, tfe=False): config = CONFIGS[alg_name] - config["framework"] = framework - if "DDPG" in alg_name or "SAC" in alg_name: - alg1 = cls(config=config, env="Pendulum-v0") - alg2 = cls(config=config, env="Pendulum-v0") - else: - alg1 = cls(config=config, env="CartPole-v0") - alg2 = cls(config=config, env="CartPole-v0") + frameworks = (["tfe"] if tfe else []) + ["torch", "tf"] + for fw in framework_iterator(config, frameworks=frameworks): + for use_object_store in [False, True]: + print("use_object_store={}".format(use_object_store)) + cls = get_agent_class(alg_name) + if "DDPG" in alg_name or "SAC" in alg_name: + alg1 = cls(config=config, env="Pendulum-v0") + alg2 = cls(config=config, env="Pendulum-v0") + else: + alg1 = cls(config=config, env="CartPole-v0") + alg2 = cls(config=config, env="CartPole-v0") - policy1 = alg1.get_policy() + policy1 = alg1.get_policy() - for _ in range(1): - res = alg1.train() - print("current status: " + str(res)) + for _ in range(1): + res = alg1.train() + print("current status: " + str(res)) - # Sync the models - if use_object_store: - alg2.restore_from_object(alg1.save_to_object()) - else: - alg2.restore(alg1.save()) + # Check optimizer state as well. + optim_state = policy1.get_state().get("_optimizer_variables") - for _ in range(1): - if "DDPG" in alg_name or "SAC" in alg_name: - obs = np.clip( - np.random.uniform(size=3), - policy1.observation_space.low, - policy1.observation_space.high) - else: - obs = np.clip( - np.random.uniform(size=4), - policy1.observation_space.low, - policy1.observation_space.high) - a1 = get_mean_action(alg1, obs) - a2 = get_mean_action(alg2, obs) - print("Checking computed actions", alg1, obs, a1, a2) - if abs(a1 - a2) > .1: - failures.append((alg_name, [a1, a2])) + # Sync the models + if use_object_store: + alg2.restore_from_object(alg1.save_to_object()) + else: + alg2.restore(alg1.save()) + + # Compare optimizer state with re-loaded one. + if optim_state: + s2 = alg2.get_policy().get_state().get("_optimizer_variables") + # Tf -> Compare states 1:1. + if fw in ["tf", "tfe"]: + check(s2, optim_state) + # For torch, optimizers have state_dicts with keys=params, + # which are different for the two models (ignore these + # different keys, but compare all values nevertheless). + else: + for i, s2_ in enumerate(s2): + check( + list(s2_["state"].values()), + list(optim_state[i]["state"].values())) + + for _ in range(1): + if "DDPG" in alg_name or "SAC" in alg_name: + obs = np.clip( + np.random.uniform(size=3), + policy1.observation_space.low, + policy1.observation_space.high) + else: + obs = np.clip( + np.random.uniform(size=4), + policy1.observation_space.low, + policy1.observation_space.high) + a1 = get_mean_action(alg1, obs) + a2 = get_mean_action(alg2, obs) + print("Checking computed actions", alg1, obs, a1, a2) + if abs(a1 - a2) > .1: + raise AssertionError("algo={} [a1={} a2={}]".format( + alg_name, a1, a2)) class TestCheckpointRestore(unittest.TestCase): @@ -113,21 +135,29 @@ class TestCheckpointRestore(unittest.TestCase): def tearDownClass(cls): ray.shutdown() - def test_checkpoint_restore(self): - failures = [] - for fw in framework_iterator(frameworks=("tf", "torch")): - for use_object_store in [False, True]: - for name in [ - "A3C", "APEX_DDPG", "ARS", "DDPG", "DQN", "ES", "PPO", - "SAC" - ]: - print("Testing algo={} (use_object_store={})".format( - name, use_object_store)) - ckpt_restore_test( - use_object_store, name, failures, framework=fw) + def test_a3c_checkpoint_restore(self): + ckpt_restore_test("A3C") - assert not failures, failures - print("All checkpoint restore tests passed!") + def test_apex_ddpg_checkpoint_restore(self): + ckpt_restore_test("APEX_DDPG") + + def test_ars_checkpoint_restore(self): + ckpt_restore_test("ARS") + + def test_ddpg_checkpoint_restore(self): + ckpt_restore_test("DDPG") + + def test_dqn_checkpoint_restore(self): + ckpt_restore_test("DQN") + + def test_es_checkpoint_restore(self): + ckpt_restore_test("ES") + + def test_ppo_checkpoint_restore(self): + ckpt_restore_test("PPO") + + def test_sac_checkpoint_restore(self): + ckpt_restore_test("SAC") if __name__ == "__main__": diff --git a/rllib/tests/test_local.py b/rllib/tests/test_local.py index c971bc632..aabce3c3c 100644 --- a/rllib/tests/test_local.py +++ b/rllib/tests/test_local.py @@ -14,9 +14,10 @@ class LocalModeTest(unittest.TestCase): def test_local(self): cf = DEFAULT_CONFIG.copy() - for fw in framework_iterator(cf): + for _ in framework_iterator(cf): agent = PPOTrainer(cf, "CartPole-v0") print(agent.train()) + agent.stop() if __name__ == "__main__": diff --git a/rllib/utils/torch_ops.py b/rllib/utils/torch_ops.py index 52bdc7cff..ef02f8898 100644 --- a/rllib/utils/torch_ops.py +++ b/rllib/utils/torch_ops.py @@ -58,9 +58,9 @@ def minimize_and_clip(optimizer, clip_val=10): torch.nn.utils.clip_grad_norm_(p.grad, clip_val) -def sequence_mask(lengths, maxlen, dtype=None): - """ - Exact same behavior as tf.sequence_mask. +def sequence_mask(lengths, maxlen=None, dtype=None): + """Offers same behavior as tf.sequence_mask for torch. + Thanks to Dimitris Papatheodorou (https://discuss.pytorch.org/t/pytorch-equivalent-for-tf-sequence-mask/ 39036).