diff --git a/rllib/BUILD b/rllib/BUILD index 05c09d85d..431f6b75a 100644 --- a/rllib/BUILD +++ b/rllib/BUILD @@ -1466,29 +1466,29 @@ py_test( args = ["TestSupportedMultiAgentPG"] ) -#py_test( -# name = "tests/test_supported_multi_agent_off_policy", -# main = "tests/test_supported_multi_agent.py", -# tags = ["tests_dir", "tests_dir_S"], -# size = "medium", -# srcs = ["tests/test_supported_multi_agent.py"], -# args = ["TestSupportedMultiAgentOffPolicy"] -#) - py_test( - name = "tests/test_supported_spaces_pg", - main = "tests/test_supported_spaces.py", + name = "tests/test_supported_multi_agent_off_policy", + main = "tests/test_supported_multi_agent.py", tags = ["tests_dir", "tests_dir_S"], - size = "enormous", - srcs = ["tests/test_supported_spaces.py"], - args = ["TestSupportedSpacesPG"] + size = "medium", + srcs = ["tests/test_supported_multi_agent.py"], + args = ["TestSupportedMultiAgentOffPolicy"] ) +# py_test( +# name = "tests/test_supported_spaces_pg", +# main = "tests/test_supported_spaces.py", +# tags = ["tests_dir", "tests_dir_S"], +# size = "enormous", +# srcs = ["tests/test_supported_spaces.py"], +# args = ["TestSupportedSpacesPG"] +# ) + py_test( name = "tests/test_supported_spaces_off_policy", main = "tests/test_supported_spaces.py", tags = ["tests_dir", "tests_dir_S"], - size = "enormous", + size = "medium", srcs = ["tests/test_supported_spaces.py"], args = ["TestSupportedSpacesOffPolicy"] ) @@ -1497,7 +1497,7 @@ py_test( name = "tests/test_supported_spaces_evolution_algos", main = "tests/test_supported_spaces.py", tags = ["tests_dir", "tests_dir_S"], - size = "large", + size = "medium", srcs = ["tests/test_supported_spaces.py"], args = ["TestSupportedSpacesEvolutionAlgos"] ) diff --git a/rllib/agents/callbacks.py b/rllib/agents/callbacks.py index e84cf4148..1972fabec 100644 --- a/rllib/agents/callbacks.py +++ b/rllib/agents/callbacks.py @@ -7,7 +7,6 @@ from ray.rllib.evaluation import MultiAgentEpisode from ray.rllib.utils.annotations import PublicAPI from ray.rllib.utils.deprecation import deprecation_warning from ray.rllib.utils.typing import AgentID, PolicyID -from ray.util.debug import log_once if TYPE_CHECKING: from ray.rllib.evaluation import RolloutWorker @@ -56,10 +55,6 @@ class DefaultCallbacks: kwargs: Forward compatibility placeholder. """ - if env_index is not None: - if log_once("callbacks_env_index_deprecated"): - deprecation_warning("env_index", "episode.env_id", error=False) - if self.legacy_callbacks.get("on_episode_start"): self.legacy_callbacks["on_episode_start"]({ "env": base_env, @@ -89,10 +84,6 @@ class DefaultCallbacks: kwargs: Forward compatibility placeholder. """ - if env_index is not None: - if log_once("callbacks_env_index_deprecated"): - deprecation_warning("env_index", "episode.env_id", error=False) - if self.legacy_callbacks.get("on_episode_step"): self.legacy_callbacks["on_episode_step"]({ "env": base_env, @@ -124,10 +115,6 @@ class DefaultCallbacks: kwargs: Forward compatibility placeholder. """ - if env_index is not None: - if log_once("callbacks_env_index_deprecated"): - deprecation_warning("env_index", "episode.env_id", error=False) - if self.legacy_callbacks.get("on_episode_end"): self.legacy_callbacks["on_episode_end"]({ "env": base_env, @@ -188,7 +175,7 @@ class DefaultCallbacks: }) def on_learn_on_batch(self, *, policy: Policy, train_batch: SampleBatch, - **kwargs) -> None: + result: dict, **kwargs) -> None: """Called at the beginning of Policy.learn_on_batch(). Note: This is called before 0-padding via @@ -198,6 +185,7 @@ class DefaultCallbacks: policy (Policy): Reference to the current Policy object. train_batch (SampleBatch): SampleBatch to be trained on. You can mutate this object to modify the samples generated. + result (dict): A results dict to add custom metrics to. kwargs: Forward compatibility placeholder. """ diff --git a/rllib/agents/marwil/tests/test_marwil.py b/rllib/agents/marwil/tests/test_marwil.py index afb3ec9ee..a0b3caa10 100644 --- a/rllib/agents/marwil/tests/test_marwil.py +++ b/rllib/agents/marwil/tests/test_marwil.py @@ -51,7 +51,7 @@ class TestMARWIL(unittest.TestCase): min_reward = 70.0 # Test for all frameworks. - for _ in framework_iterator(config): + for _ in framework_iterator(config, frameworks=("tf", "torch")): trainer = marwil.MARWILTrainer(config=config, env="CartPole-v0") learnt = False for i in range(num_iterations): diff --git a/rllib/agents/sac/sac_tf_model.py b/rllib/agents/sac/sac_tf_model.py index e2c56b521..b457f1e94 100644 --- a/rllib/agents/sac/sac_tf_model.py +++ b/rllib/agents/sac/sac_tf_model.py @@ -231,6 +231,8 @@ class SACTFModel(TFModelV2): if isinstance(net.obs_space, Box): if isinstance(model_out, (list, tuple)): model_out = tf.concat(model_out, axis=-1) + elif isinstance(model_out, dict): + model_out = tf.concat(list(model_out.values()), axis=-1) elif isinstance(model_out, dict): model_out = list(model_out.values()) diff --git a/rllib/agents/sac/sac_torch_model.py b/rllib/agents/sac/sac_torch_model.py index f3fe34e23..1288d20da 100644 --- a/rllib/agents/sac/sac_torch_model.py +++ b/rllib/agents/sac/sac_torch_model.py @@ -237,6 +237,8 @@ class SACTorchModel(TorchModelV2, nn.Module): if isinstance(net.obs_space, Box): if isinstance(model_out, (list, tuple)): model_out = torch.cat(model_out, dim=-1) + elif isinstance(model_out, dict): + model_out = torch.cat(list(model_out.values()), dim=-1) elif isinstance(model_out, dict): model_out = list(model_out.values()) diff --git a/rllib/env/policy_client.py b/rllib/env/policy_client.py index 232f74f1a..39a85a5cf 100644 --- a/rllib/env/policy_client.py +++ b/rllib/env/policy_client.py @@ -17,7 +17,6 @@ from ray.rllib.utils.typing import MultiAgentDict, EnvInfoDict, EnvObsType, \ EnvActionType logger = logging.getLogger(__name__) -logger.setLevel("INFO") # TODO(ekl) seems to be needed for cartpole_client.py try: import requests # `requests` is not part of stdlib. diff --git a/rllib/env/policy_server_input.py b/rllib/env/policy_server_input.py index 45c2a00d2..952130ac5 100644 --- a/rllib/env/policy_server_input.py +++ b/rllib/env/policy_server_input.py @@ -13,7 +13,6 @@ from ray.rllib.env.policy_client import PolicyClient, \ from ray.rllib.utils.annotations import override, PublicAPI logger = logging.getLogger(__name__) -logger.setLevel("INFO") # TODO(ekl) this is needed for cartpole_server.py class PolicyServerInput(ThreadingMixIn, HTTPServer, InputReader): diff --git a/rllib/evaluation/metrics.py b/rllib/evaluation/metrics.py index 6ed723b15..e44b301f4 100644 --- a/rllib/evaluation/metrics.py +++ b/rllib/evaluation/metrics.py @@ -1,7 +1,7 @@ import logging import numpy as np import collections -from typing import List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union import ray from ray.rllib.evaluation.rollout_metrics import RolloutMetrics @@ -14,6 +14,19 @@ from ray.rllib.utils.typing import GradInfoDict, LearnerStatsDict, ResultDict logger = logging.getLogger(__name__) +def extract_stats(stats: Dict, key: str) -> Dict[str, Any]: + if key in stats: + return stats[key] + + multiagent_stats = {} + for k, v in stats.items(): + if isinstance(v, dict): + if key in v: + multiagent_stats[k] = v[key] + + return multiagent_stats + + @DeveloperAPI def get_learner_stats(grad_info: GradInfoDict) -> LearnerStatsDict: """Return optimization stats reported from the policy. diff --git a/rllib/evaluation/tests/test_trajectory_view_api.py b/rllib/evaluation/tests/test_trajectory_view_api.py index 1601e07f3..1c56ef2b9 100644 --- a/rllib/evaluation/tests/test_trajectory_view_api.py +++ b/rllib/evaluation/tests/test_trajectory_view_api.py @@ -25,7 +25,7 @@ from ray.rllib.utils.test_utils import framework_iterator, check class MyCallbacks(DefaultCallbacks): @override(DefaultCallbacks) - def on_learn_on_batch(self, *, policy, train_batch, **kwargs): + def on_learn_on_batch(self, *, policy, train_batch, result, **kwargs): assert train_batch.count == 201 assert sum(train_batch.seq_lens) == 201 for k, v in train_batch.data.items(): diff --git a/rllib/examples/custom_metrics_and_callbacks.py b/rllib/examples/custom_metrics_and_callbacks.py index 745a94029..ecbe99bd7 100644 --- a/rllib/examples/custom_metrics_and_callbacks.py +++ b/rllib/examples/custom_metrics_and_callbacks.py @@ -59,6 +59,12 @@ class MyCallbacks(DefaultCallbacks): # you can mutate the result dict to add new fields to return result["callback_ok"] = True + def on_learn_on_batch(self, *, policy: Policy, train_batch: SampleBatch, + result: dict, **kwargs) -> None: + result["sum_actions_in_train_batch"] = np.sum(train_batch["actions"]) + print("policy.learn_on_batch() result: {} -> sum actions: {}".format( + policy, result["sum_actions_in_train_batch"])) + def on_postprocess_trajectory( self, *, worker: RolloutWorker, episode: MultiAgentEpisode, agent_id: str, policy_id: str, policies: Dict[str, Policy], @@ -88,7 +94,7 @@ if __name__ == "__main__": "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")), }).trials - # verify custom metrics for integration tests + # Verify episode-related custom metrics are there. custom_metrics = trials[0].last_result["custom_metrics"] print(custom_metrics) assert "pole_angle_mean" in custom_metrics @@ -96,3 +102,8 @@ if __name__ == "__main__": assert "pole_angle_max" in custom_metrics assert "num_batches_mean" in custom_metrics assert "callback_ok" in trials[0].last_result + + # Verify `on_learn_on_batch` custom metrics are there (per policy). + info_custom_metrics = custom_metrics["default_policy"] + print(info_custom_metrics) + assert "sum_actions_in_train_batch" in info_custom_metrics diff --git a/rllib/examples/serving/cartpole_client.py b/rllib/examples/serving/cartpole_client.py index 3541e0f6f..f2d45b5b3 100755 --- a/rllib/examples/serving/cartpole_client.py +++ b/rllib/examples/serving/cartpole_client.py @@ -17,7 +17,7 @@ parser = argparse.ArgumentParser() parser.add_argument( "--no-train", action="store_true", help="Whether to disable training.") parser.add_argument( - "--inference-mode", type=str, required=True, choices=["local", "remote"]) + "--inference-mode", type=str, default="local", choices=["local", "remote"]) parser.add_argument( "--off-policy", action="store_true", diff --git a/rllib/examples/serving/cartpole_server.py b/rllib/examples/serving/cartpole_server.py index 297320422..f76a34a91 100755 --- a/rllib/examples/serving/cartpole_server.py +++ b/rllib/examples/serving/cartpole_server.py @@ -13,6 +13,7 @@ import ray from ray.rllib.agents.dqn import DQNTrainer from ray.rllib.agents.ppo import PPOTrainer from ray.rllib.env.policy_server_input import PolicyServerInput +from ray.rllib.examples.custom_metrics_and_callbacks import MyCallbacks from ray.tune.logger import pretty_print SERVER_ADDRESS = "localhost" @@ -43,6 +44,7 @@ if __name__ == "__main__": "num_workers": 0, # Disable OPE, since the rollouts are coming from online clients. "input_evaluation": [], + "callbacks": MyCallbacks, } if args.run == "DQN": diff --git a/rllib/execution/metric_ops.py b/rllib/execution/metric_ops.py index 70ae38e3f..06857f674 100644 --- a/rllib/execution/metric_ops.py +++ b/rllib/execution/metric_ops.py @@ -88,6 +88,7 @@ class CollectMetrics: # Add in iterator metrics. metrics = _get_shared_metrics() + custom_metrics_from_info = metrics.info.pop("custom_metrics", {}) timers = {} counters = {} info = {} @@ -106,6 +107,8 @@ class CollectMetrics: res["timers"] = timers res["info"] = info res["info"].update(counters) + res["custom_metrics"] = res.get("custom_metrics", {}) + res["custom_metrics"].update(custom_metrics_from_info) return res diff --git a/rllib/execution/train_ops.py b/rllib/execution/train_ops.py index e2411ed32..fe8e7b95b 100644 --- a/rllib/execution/train_ops.py +++ b/rllib/execution/train_ops.py @@ -5,7 +5,8 @@ import math from typing import List, Tuple, Any import ray -from ray.rllib.evaluation.metrics import get_learner_stats, LEARNER_STATS_KEY +from ray.rllib.evaluation.metrics import extract_stats, get_learner_stats, \ + LEARNER_STATS_KEY from ray.rllib.evaluation.worker_set import WorkerSet from ray.rllib.execution.common import \ STEPS_SAMPLED_COUNTER, STEPS_TRAINED_COUNTER, LEARNER_INFO, \ @@ -58,18 +59,25 @@ class TrainOneStep: learn_timer = metrics.timers[LEARN_ON_BATCH_TIMER] with learn_timer: if self.num_sgd_iter > 1 or self.sgd_minibatch_size > 0: - w = self.workers.local_worker() + lw = self.workers.local_worker() info = do_minibatch_sgd( - batch, {p: w.get_policy(p) - for p in self.policies}, w, self.num_sgd_iter, + batch, {pid: lw.get_policy(pid) + for pid in self.policies}, lw, self.num_sgd_iter, self.sgd_minibatch_size, []) # TODO(ekl) shouldn't be returning learner stats directly here + # TODO(sven): Skips `custom_metrics` key from on_learn_on_batch + # callback (shouldn't). metrics.info[LEARNER_INFO] = info else: info = self.workers.local_worker().learn_on_batch(batch) - metrics.info[LEARNER_INFO] = get_learner_stats(info) + metrics.info[LEARNER_INFO] = extract_stats( + info, LEARNER_STATS_KEY) + metrics.info["custom_metrics"] = extract_stats( + info, "custom_metrics") learn_timer.push_units_processed(batch.count) metrics.counters[STEPS_TRAINED_COUNTER] += batch.count + # Update weights - after learning on the local worker - on all remote + # workers. if self.workers.remote_workers(): with metrics.timers[WORKER_UPDATE_TIMER]: weights = ray.put(self.workers.local_worker().get_weights( diff --git a/rllib/policy/eager_tf_policy.py b/rllib/policy/eager_tf_policy.py index 1e1f42c05..050e655ca 100644 --- a/rllib/policy/eager_tf_policy.py +++ b/rllib/policy/eager_tf_policy.py @@ -320,8 +320,11 @@ def build_eager_tf_policy(name, @override(Policy) def learn_on_batch(self, postprocessed_batch): # Callback handling. + learn_stats = {} self.callbacks.on_learn_on_batch( - policy=self, train_batch=postprocessed_batch) + policy=self, + train_batch=postprocessed_batch, + result=learn_stats) pad_batch_to_sequences_of_same_size( postprocessed_batch, @@ -333,7 +336,9 @@ def build_eager_tf_policy(name, self._is_training = True postprocessed_batch["is_training"] = True - return self._learn_on_batch_eager(postprocessed_batch) + stats = self._learn_on_batch_eager(postprocessed_batch) + stats.update({"custom_metrics": learn_stats}) + return stats @convert_eager_inputs @convert_eager_outputs diff --git a/rllib/policy/tf_policy.py b/rllib/policy/tf_policy.py index 3ac644415..f16f3f72a 100644 --- a/rllib/policy/tf_policy.py +++ b/rllib/policy/tf_policy.py @@ -423,9 +423,18 @@ class TFPolicy(Policy): def learn_on_batch( self, postprocessed_batch: SampleBatch) -> Dict[str, TensorType]: assert self.loss_initialized() + builder = TFRunBuilder(self._sess, "learn_on_batch") + + # Callback handling. + learn_stats = {} + self.callbacks.on_learn_on_batch( + policy=self, train_batch=postprocessed_batch, result=learn_stats) + fetches = self._build_learn_on_batch(builder, postprocessed_batch) - return builder.get(fetches) + stats = builder.get(fetches) + stats.update({"custom_metrics": learn_stats}) + return stats @override(Policy) @DeveloperAPI @@ -841,10 +850,6 @@ class TFPolicy(Policy): def _build_learn_on_batch(self, builder, postprocessed_batch): self._debug_vars() - # Callback handling. - self.callbacks.on_learn_on_batch( - policy=self, train_batch=postprocessed_batch) - builder.add_feed_dict(self.extra_compute_grad_feed_dict()) builder.add_feed_dict( self._get_loss_inputs_dict(postprocessed_batch, shuffle=False)) diff --git a/rllib/policy/torch_policy.py b/rllib/policy/torch_policy.py index e492a5048..7ff26dfda 100644 --- a/rllib/policy/torch_policy.py +++ b/rllib/policy/torch_policy.py @@ -347,8 +347,9 @@ class TorchPolicy(Policy): if self.model: self.model.train() # Callback handling. + learn_stats = {} self.callbacks.on_learn_on_batch( - policy=self, train_batch=postprocessed_batch) + policy=self, train_batch=postprocessed_batch, result=learn_stats) # Compute gradients (will calculate all losses and `backward()` # them to get the grads). @@ -360,6 +361,7 @@ class TorchPolicy(Policy): if self.model: fetches["model"] = self.model.metrics() + fetches.update({"custom_metrics": learn_stats}) return fetches diff --git a/rllib/tests/test_supported_multi_agent.py b/rllib/tests/test_supported_multi_agent.py index 933c2d608..0f4063bb2 100644 --- a/rllib/tests/test_supported_multi_agent.py +++ b/rllib/tests/test_supported_multi_agent.py @@ -66,7 +66,7 @@ class TestSupportedMultiAgentPG(unittest.TestCase): class TestSupportedMultiAgentOffPolicy(unittest.TestCase): @classmethod def setUpClass(cls) -> None: - ray.init(num_cpus=4) + ray.init(num_cpus=6) @classmethod def tearDownClass(cls) -> None: @@ -82,6 +82,9 @@ class TestSupportedMultiAgentOffPolicy(unittest.TestCase): "min_iter_time_s": 1, "learning_starts": 10, "target_network_update_freq": 100, + "optimizer": { + "num_replay_buffer_shards": 1, + }, }) def test_apex_ddpg_multiagent(self): diff --git a/rllib/tests/test_supported_spaces.py b/rllib/tests/test_supported_spaces.py index 05b90cba5..9da624927 100644 --- a/rllib/tests/test_supported_spaces.py +++ b/rllib/tests/test_supported_spaces.py @@ -47,6 +47,8 @@ OBSERVATION_SPACES_TO_TEST = { def check_support(alg, config, train=True, check_bounds=False, tfe=False): config["log_level"] = "ERROR" + config["train_batch_size"] = 10 + config["rollout_fragment_length"] = 10 def _do_check(alg, config, a_name, o_name): fw = config["framework"] @@ -88,25 +90,24 @@ def check_support(alg, config, train=True, check_bounds=False, tfe=False): frameworks = ("tf", "torch") if tfe: - frameworks += ("tfe", ) + frameworks += ("tf2", "tfe") for _ in framework_iterator(config, frameworks=frameworks): - # Check all action spaces (using a discrete obs-space). - for a_name in ACTION_SPACES_TO_TEST.keys(): - _do_check(alg, config, a_name, "discrete") - # Check all obs spaces (using a supported action-space). - for o_name in OBSERVATION_SPACES_TO_TEST.keys(): - # We already tested discrete observation spaces against all action - # spaces above -> skip. - if o_name == "discrete": - continue - a_name = "discrete" if alg not in ["DDPG", "SAC"] else "vector" + # Zip through action- and obs-spaces. + for a_name, o_name in zip(ACTION_SPACES_TO_TEST.keys(), + OBSERVATION_SPACES_TO_TEST.keys()): _do_check(alg, config, a_name, o_name) + # Do the remaining obs spaces. + assert len(OBSERVATION_SPACES_TO_TEST) >= len(ACTION_SPACES_TO_TEST) + for i, o_name in enumerate(OBSERVATION_SPACES_TO_TEST.keys()): + if i < len(ACTION_SPACES_TO_TEST): + continue + _do_check(alg, config, "discrete", o_name) class TestSupportedSpacesPG(unittest.TestCase): @classmethod def setUpClass(cls) -> None: - ray.init(num_cpus=4) + ray.init(num_cpus=6) @classmethod def tearDownClass(cls) -> None: @@ -125,11 +126,11 @@ class TestSupportedSpacesPG(unittest.TestCase): def test_ppo(self): config = { - "num_workers": 1, - "num_sgd_iter": 1, - "train_batch_size": 10, + "num_workers": 0, + "train_batch_size": 100, "rollout_fragment_length": 10, - "sgd_minibatch_size": 1, + "num_sgd_iter": 1, + "sgd_minibatch_size": 10, } check_support("PPO", config, check_bounds=True, tfe=True) diff --git a/rllib/utils/sgd.py b/rllib/utils/sgd.py index b5b72d44d..787b885cd 100644 --- a/rllib/utils/sgd.py +++ b/rllib/utils/sgd.py @@ -104,12 +104,12 @@ def do_minibatch_sgd(samples, policies, local_worker, num_sgd_iter, """Execute minibatch SGD. Args: - samples (SampleBatch): batch of samples to optimize. - policies (dict): dictionary of policies to optimize. - local_worker (RolloutWorker): master rollout worker instance. - num_sgd_iter (int): number of epochs of optimization to take. - sgd_minibatch_size (int): size of minibatches to use for optimization. - standardize_fields (list): list of sample field names that should be + samples (SampleBatch): Batch of samples to optimize. + policies (dict): Dictionary of policies to optimize. + local_worker (RolloutWorker): Master rollout worker instance. + num_sgd_iter (int): Number of epochs of optimization to take. + sgd_minibatch_size (int): Size of minibatches to use for optimization. + standardize_fields (list): List of sample field names that should be normalized prior to optimization. Returns: