[RLlib] Extend on_learn_on_batch callback to allow for custom metrics to be added. (#13584)

This commit is contained in:
Sven Mika
2021-02-08 15:02:19 +01:00
committed by GitHub
parent ebeee1d59a
commit eb0038612f
20 changed files with 116 additions and 73 deletions
+16 -16
View File
@@ -1466,29 +1466,29 @@ py_test(
args = ["TestSupportedMultiAgentPG"]
)
#py_test(
# name = "tests/test_supported_multi_agent_off_policy",
# main = "tests/test_supported_multi_agent.py",
# tags = ["tests_dir", "tests_dir_S"],
# size = "medium",
# srcs = ["tests/test_supported_multi_agent.py"],
# args = ["TestSupportedMultiAgentOffPolicy"]
#)
py_test(
name = "tests/test_supported_spaces_pg",
main = "tests/test_supported_spaces.py",
name = "tests/test_supported_multi_agent_off_policy",
main = "tests/test_supported_multi_agent.py",
tags = ["tests_dir", "tests_dir_S"],
size = "enormous",
srcs = ["tests/test_supported_spaces.py"],
args = ["TestSupportedSpacesPG"]
size = "medium",
srcs = ["tests/test_supported_multi_agent.py"],
args = ["TestSupportedMultiAgentOffPolicy"]
)
# py_test(
# name = "tests/test_supported_spaces_pg",
# main = "tests/test_supported_spaces.py",
# tags = ["tests_dir", "tests_dir_S"],
# size = "enormous",
# srcs = ["tests/test_supported_spaces.py"],
# args = ["TestSupportedSpacesPG"]
# )
py_test(
name = "tests/test_supported_spaces_off_policy",
main = "tests/test_supported_spaces.py",
tags = ["tests_dir", "tests_dir_S"],
size = "enormous",
size = "medium",
srcs = ["tests/test_supported_spaces.py"],
args = ["TestSupportedSpacesOffPolicy"]
)
@@ -1497,7 +1497,7 @@ py_test(
name = "tests/test_supported_spaces_evolution_algos",
main = "tests/test_supported_spaces.py",
tags = ["tests_dir", "tests_dir_S"],
size = "large",
size = "medium",
srcs = ["tests/test_supported_spaces.py"],
args = ["TestSupportedSpacesEvolutionAlgos"]
)
+2 -14
View File
@@ -7,7 +7,6 @@ from ray.rllib.evaluation import MultiAgentEpisode
from ray.rllib.utils.annotations import PublicAPI
from ray.rllib.utils.deprecation import deprecation_warning
from ray.rllib.utils.typing import AgentID, PolicyID
from ray.util.debug import log_once
if TYPE_CHECKING:
from ray.rllib.evaluation import RolloutWorker
@@ -56,10 +55,6 @@ class DefaultCallbacks:
kwargs: Forward compatibility placeholder.
"""
if env_index is not None:
if log_once("callbacks_env_index_deprecated"):
deprecation_warning("env_index", "episode.env_id", error=False)
if self.legacy_callbacks.get("on_episode_start"):
self.legacy_callbacks["on_episode_start"]({
"env": base_env,
@@ -89,10 +84,6 @@ class DefaultCallbacks:
kwargs: Forward compatibility placeholder.
"""
if env_index is not None:
if log_once("callbacks_env_index_deprecated"):
deprecation_warning("env_index", "episode.env_id", error=False)
if self.legacy_callbacks.get("on_episode_step"):
self.legacy_callbacks["on_episode_step"]({
"env": base_env,
@@ -124,10 +115,6 @@ class DefaultCallbacks:
kwargs: Forward compatibility placeholder.
"""
if env_index is not None:
if log_once("callbacks_env_index_deprecated"):
deprecation_warning("env_index", "episode.env_id", error=False)
if self.legacy_callbacks.get("on_episode_end"):
self.legacy_callbacks["on_episode_end"]({
"env": base_env,
@@ -188,7 +175,7 @@ class DefaultCallbacks:
})
def on_learn_on_batch(self, *, policy: Policy, train_batch: SampleBatch,
**kwargs) -> None:
result: dict, **kwargs) -> None:
"""Called at the beginning of Policy.learn_on_batch().
Note: This is called before 0-padding via
@@ -198,6 +185,7 @@ class DefaultCallbacks:
policy (Policy): Reference to the current Policy object.
train_batch (SampleBatch): SampleBatch to be trained on. You can
mutate this object to modify the samples generated.
result (dict): A results dict to add custom metrics to.
kwargs: Forward compatibility placeholder.
"""
+1 -1
View File
@@ -51,7 +51,7 @@ class TestMARWIL(unittest.TestCase):
min_reward = 70.0
# Test for all frameworks.
for _ in framework_iterator(config):
for _ in framework_iterator(config, frameworks=("tf", "torch")):
trainer = marwil.MARWILTrainer(config=config, env="CartPole-v0")
learnt = False
for i in range(num_iterations):
+2
View File
@@ -231,6 +231,8 @@ class SACTFModel(TFModelV2):
if isinstance(net.obs_space, Box):
if isinstance(model_out, (list, tuple)):
model_out = tf.concat(model_out, axis=-1)
elif isinstance(model_out, dict):
model_out = tf.concat(list(model_out.values()), axis=-1)
elif isinstance(model_out, dict):
model_out = list(model_out.values())
+2
View File
@@ -237,6 +237,8 @@ class SACTorchModel(TorchModelV2, nn.Module):
if isinstance(net.obs_space, Box):
if isinstance(model_out, (list, tuple)):
model_out = torch.cat(model_out, dim=-1)
elif isinstance(model_out, dict):
model_out = torch.cat(list(model_out.values()), dim=-1)
elif isinstance(model_out, dict):
model_out = list(model_out.values())
-1
View File
@@ -17,7 +17,6 @@ from ray.rllib.utils.typing import MultiAgentDict, EnvInfoDict, EnvObsType, \
EnvActionType
logger = logging.getLogger(__name__)
logger.setLevel("INFO") # TODO(ekl) seems to be needed for cartpole_client.py
try:
import requests # `requests` is not part of stdlib.
-1
View File
@@ -13,7 +13,6 @@ from ray.rllib.env.policy_client import PolicyClient, \
from ray.rllib.utils.annotations import override, PublicAPI
logger = logging.getLogger(__name__)
logger.setLevel("INFO") # TODO(ekl) this is needed for cartpole_server.py
class PolicyServerInput(ThreadingMixIn, HTTPServer, InputReader):
+14 -1
View File
@@ -1,7 +1,7 @@
import logging
import numpy as np
import collections
from typing import List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union
import ray
from ray.rllib.evaluation.rollout_metrics import RolloutMetrics
@@ -14,6 +14,19 @@ from ray.rllib.utils.typing import GradInfoDict, LearnerStatsDict, ResultDict
logger = logging.getLogger(__name__)
def extract_stats(stats: Dict, key: str) -> Dict[str, Any]:
    """Return the entry stored under `key` in a flat or nested stats dict.

    If `key` is present at the top level of `stats`, the value stored
    there is returned as-is. Otherwise, `stats` is treated as a mapping
    from some id (presumably policy ids in the multi-agent case) to per-id
    stats dicts, and the `key` entry of each such sub-dict is collected.

    Args:
        stats (Dict): Either a stats dict containing `key` directly, or a
            dict whose values are themselves stats dicts.
        key (str): The key to extract.

    Returns:
        Dict[str, Any]: `stats[key]` if `key` is a top-level key. Otherwise
            a new dict mapping each top-level key `k` to `stats[k][key]`
            for all dict-valued entries that contain `key` (empty if there
            are none).
    """
    if key in stats:
        return stats[key]

    # Values that are not dicts (e.g. plain numbers) or that lack `key`
    # are skipped.
    return {
        k: v[key]
        for k, v in stats.items() if isinstance(v, dict) and key in v
    }
@DeveloperAPI
def get_learner_stats(grad_info: GradInfoDict) -> LearnerStatsDict:
"""Return optimization stats reported from the policy.
@@ -25,7 +25,7 @@ from ray.rllib.utils.test_utils import framework_iterator, check
class MyCallbacks(DefaultCallbacks):
@override(DefaultCallbacks)
def on_learn_on_batch(self, *, policy, train_batch, **kwargs):
def on_learn_on_batch(self, *, policy, train_batch, result, **kwargs):
assert train_batch.count == 201
assert sum(train_batch.seq_lens) == 201
for k, v in train_batch.data.items():
+12 -1
View File
@@ -59,6 +59,12 @@ class MyCallbacks(DefaultCallbacks):
# you can mutate the result dict to add new fields to return
result["callback_ok"] = True
def on_learn_on_batch(self, *, policy: Policy, train_batch: SampleBatch,
result: dict, **kwargs) -> None:
result["sum_actions_in_train_batch"] = np.sum(train_batch["actions"])
print("policy.learn_on_batch() result: {} -> sum actions: {}".format(
policy, result["sum_actions_in_train_batch"]))
def on_postprocess_trajectory(
self, *, worker: RolloutWorker, episode: MultiAgentEpisode,
agent_id: str, policy_id: str, policies: Dict[str, Policy],
@@ -88,7 +94,7 @@ if __name__ == "__main__":
"num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
}).trials
# verify custom metrics for integration tests
# Verify episode-related custom metrics are there.
custom_metrics = trials[0].last_result["custom_metrics"]
print(custom_metrics)
assert "pole_angle_mean" in custom_metrics
@@ -96,3 +102,8 @@ if __name__ == "__main__":
assert "pole_angle_max" in custom_metrics
assert "num_batches_mean" in custom_metrics
assert "callback_ok" in trials[0].last_result
# Verify `on_learn_on_batch` custom metrics are there (per policy).
info_custom_metrics = custom_metrics["default_policy"]
print(info_custom_metrics)
assert "sum_actions_in_train_batch" in info_custom_metrics
+1 -1
View File
@@ -17,7 +17,7 @@ parser = argparse.ArgumentParser()
parser.add_argument(
"--no-train", action="store_true", help="Whether to disable training.")
parser.add_argument(
"--inference-mode", type=str, required=True, choices=["local", "remote"])
"--inference-mode", type=str, default="local", choices=["local", "remote"])
parser.add_argument(
"--off-policy",
action="store_true",
@@ -13,6 +13,7 @@ import ray
from ray.rllib.agents.dqn import DQNTrainer
from ray.rllib.agents.ppo import PPOTrainer
from ray.rllib.env.policy_server_input import PolicyServerInput
from ray.rllib.examples.custom_metrics_and_callbacks import MyCallbacks
from ray.tune.logger import pretty_print
SERVER_ADDRESS = "localhost"
@@ -43,6 +44,7 @@ if __name__ == "__main__":
"num_workers": 0,
# Disable OPE, since the rollouts are coming from online clients.
"input_evaluation": [],
"callbacks": MyCallbacks,
}
if args.run == "DQN":
+3
View File
@@ -88,6 +88,7 @@ class CollectMetrics:
# Add in iterator metrics.
metrics = _get_shared_metrics()
custom_metrics_from_info = metrics.info.pop("custom_metrics", {})
timers = {}
counters = {}
info = {}
@@ -106,6 +107,8 @@ class CollectMetrics:
res["timers"] = timers
res["info"] = info
res["info"].update(counters)
res["custom_metrics"] = res.get("custom_metrics", {})
res["custom_metrics"].update(custom_metrics_from_info)
return res
+13 -5
View File
@@ -5,7 +5,8 @@ import math
from typing import List, Tuple, Any
import ray
from ray.rllib.evaluation.metrics import get_learner_stats, LEARNER_STATS_KEY
from ray.rllib.evaluation.metrics import extract_stats, get_learner_stats, \
LEARNER_STATS_KEY
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.execution.common import \
STEPS_SAMPLED_COUNTER, STEPS_TRAINED_COUNTER, LEARNER_INFO, \
@@ -58,18 +59,25 @@ class TrainOneStep:
learn_timer = metrics.timers[LEARN_ON_BATCH_TIMER]
with learn_timer:
if self.num_sgd_iter > 1 or self.sgd_minibatch_size > 0:
w = self.workers.local_worker()
lw = self.workers.local_worker()
info = do_minibatch_sgd(
batch, {p: w.get_policy(p)
for p in self.policies}, w, self.num_sgd_iter,
batch, {pid: lw.get_policy(pid)
for pid in self.policies}, lw, self.num_sgd_iter,
self.sgd_minibatch_size, [])
# TODO(ekl) shouldn't be returning learner stats directly here
# TODO(sven): This branch drops the `custom_metrics` collected by the
# on_learn_on_batch callback; they should be propagated here as well.
metrics.info[LEARNER_INFO] = info
else:
info = self.workers.local_worker().learn_on_batch(batch)
metrics.info[LEARNER_INFO] = get_learner_stats(info)
metrics.info[LEARNER_INFO] = extract_stats(
info, LEARNER_STATS_KEY)
metrics.info["custom_metrics"] = extract_stats(
info, "custom_metrics")
learn_timer.push_units_processed(batch.count)
metrics.counters[STEPS_TRAINED_COUNTER] += batch.count
# Update weights - after learning on the local worker - on all remote
# workers.
if self.workers.remote_workers():
with metrics.timers[WORKER_UPDATE_TIMER]:
weights = ray.put(self.workers.local_worker().get_weights(
+7 -2
View File
@@ -320,8 +320,11 @@ def build_eager_tf_policy(name,
@override(Policy)
def learn_on_batch(self, postprocessed_batch):
# Callback handling.
learn_stats = {}
self.callbacks.on_learn_on_batch(
policy=self, train_batch=postprocessed_batch)
policy=self,
train_batch=postprocessed_batch,
result=learn_stats)
pad_batch_to_sequences_of_same_size(
postprocessed_batch,
@@ -333,7 +336,9 @@ def build_eager_tf_policy(name,
self._is_training = True
postprocessed_batch["is_training"] = True
return self._learn_on_batch_eager(postprocessed_batch)
stats = self._learn_on_batch_eager(postprocessed_batch)
stats.update({"custom_metrics": learn_stats})
return stats
@convert_eager_inputs
@convert_eager_outputs
+10 -5
View File
@@ -423,9 +423,18 @@ class TFPolicy(Policy):
def learn_on_batch(
self, postprocessed_batch: SampleBatch) -> Dict[str, TensorType]:
assert self.loss_initialized()
builder = TFRunBuilder(self._sess, "learn_on_batch")
# Callback handling.
learn_stats = {}
self.callbacks.on_learn_on_batch(
policy=self, train_batch=postprocessed_batch, result=learn_stats)
fetches = self._build_learn_on_batch(builder, postprocessed_batch)
return builder.get(fetches)
stats = builder.get(fetches)
stats.update({"custom_metrics": learn_stats})
return stats
@override(Policy)
@DeveloperAPI
@@ -841,10 +850,6 @@ class TFPolicy(Policy):
def _build_learn_on_batch(self, builder, postprocessed_batch):
self._debug_vars()
# Callback handling.
self.callbacks.on_learn_on_batch(
policy=self, train_batch=postprocessed_batch)
builder.add_feed_dict(self.extra_compute_grad_feed_dict())
builder.add_feed_dict(
self._get_loss_inputs_dict(postprocessed_batch, shuffle=False))
+3 -1
View File
@@ -347,8 +347,9 @@ class TorchPolicy(Policy):
if self.model:
self.model.train()
# Callback handling.
learn_stats = {}
self.callbacks.on_learn_on_batch(
policy=self, train_batch=postprocessed_batch)
policy=self, train_batch=postprocessed_batch, result=learn_stats)
# Compute gradients (will calculate all losses and `backward()`
# them to get the grads).
@@ -360,6 +361,7 @@ class TorchPolicy(Policy):
if self.model:
fetches["model"] = self.model.metrics()
fetches.update({"custom_metrics": learn_stats})
return fetches
+4 -1
View File
@@ -66,7 +66,7 @@ class TestSupportedMultiAgentPG(unittest.TestCase):
class TestSupportedMultiAgentOffPolicy(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
ray.init(num_cpus=4)
ray.init(num_cpus=6)
@classmethod
def tearDownClass(cls) -> None:
@@ -82,6 +82,9 @@ class TestSupportedMultiAgentOffPolicy(unittest.TestCase):
"min_iter_time_s": 1,
"learning_starts": 10,
"target_network_update_freq": 100,
"optimizer": {
"num_replay_buffer_shards": 1,
},
})
def test_apex_ddpg_multiagent(self):
+17 -16
View File
@@ -47,6 +47,8 @@ OBSERVATION_SPACES_TO_TEST = {
def check_support(alg, config, train=True, check_bounds=False, tfe=False):
config["log_level"] = "ERROR"
config["train_batch_size"] = 10
config["rollout_fragment_length"] = 10
def _do_check(alg, config, a_name, o_name):
fw = config["framework"]
@@ -88,25 +90,24 @@ def check_support(alg, config, train=True, check_bounds=False, tfe=False):
frameworks = ("tf", "torch")
if tfe:
frameworks += ("tfe", )
frameworks += ("tf2", "tfe")
for _ in framework_iterator(config, frameworks=frameworks):
# Check all action spaces (using a discrete obs-space).
for a_name in ACTION_SPACES_TO_TEST.keys():
_do_check(alg, config, a_name, "discrete")
# Check all obs spaces (using a supported action-space).
for o_name in OBSERVATION_SPACES_TO_TEST.keys():
# We already tested discrete observation spaces against all action
# spaces above -> skip.
if o_name == "discrete":
continue
a_name = "discrete" if alg not in ["DDPG", "SAC"] else "vector"
# Zip through action- and obs-spaces.
for a_name, o_name in zip(ACTION_SPACES_TO_TEST.keys(),
OBSERVATION_SPACES_TO_TEST.keys()):
_do_check(alg, config, a_name, o_name)
# Do the remaining obs spaces.
assert len(OBSERVATION_SPACES_TO_TEST) >= len(ACTION_SPACES_TO_TEST)
for i, o_name in enumerate(OBSERVATION_SPACES_TO_TEST.keys()):
if i < len(ACTION_SPACES_TO_TEST):
continue
_do_check(alg, config, "discrete", o_name)
class TestSupportedSpacesPG(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
ray.init(num_cpus=4)
ray.init(num_cpus=6)
@classmethod
def tearDownClass(cls) -> None:
@@ -125,11 +126,11 @@ class TestSupportedSpacesPG(unittest.TestCase):
def test_ppo(self):
config = {
"num_workers": 1,
"num_sgd_iter": 1,
"train_batch_size": 10,
"num_workers": 0,
"train_batch_size": 100,
"rollout_fragment_length": 10,
"sgd_minibatch_size": 1,
"num_sgd_iter": 1,
"sgd_minibatch_size": 10,
}
check_support("PPO", config, check_bounds=True, tfe=True)
+6 -6
View File
@@ -104,12 +104,12 @@ def do_minibatch_sgd(samples, policies, local_worker, num_sgd_iter,
"""Execute minibatch SGD.
Args:
samples (SampleBatch): batch of samples to optimize.
policies (dict): dictionary of policies to optimize.
local_worker (RolloutWorker): master rollout worker instance.
num_sgd_iter (int): number of epochs of optimization to take.
sgd_minibatch_size (int): size of minibatches to use for optimization.
standardize_fields (list): list of sample field names that should be
samples (SampleBatch): Batch of samples to optimize.
policies (dict): Dictionary of policies to optimize.
local_worker (RolloutWorker): Master rollout worker instance.
num_sgd_iter (int): Number of epochs of optimization to take.
sgd_minibatch_size (int): Size of minibatches to use for optimization.
standardize_fields (list): List of sample field names that should be
normalized prior to optimization.
Returns: