[rllib] train-eval loop implementation for rllib.Trainer class (#4647)

This commit is contained in:
Andrew
2019-04-21 22:08:04 +03:00
committed by Eric Liang
parent d951eb740f
commit 06c768823c
3 changed files with 104 additions and 41 deletions
-24
View File
@@ -48,15 +48,6 @@ DEFAULT_CONFIG = with_common_config({
# N-step Q learning
"n_step": 1,
# === Evaluation ===
# Evaluate with epsilon=0 every `evaluation_interval` training iterations.
# The evaluation stats will be reported under the "evaluation" metric key.
# Note that evaluation is currently not parallelized, and that for Ape-X
# metrics are already only reported for the lowest epsilon workers.
"evaluation_interval": None,
# Number of episodes to run per evaluation period.
"evaluation_num_episodes": 10,
# === Exploration ===
# Max num timesteps for annealing schedules. Exploration is annealed from
# 1.0 to exploration_fraction over this number of timesteps scaled by
@@ -208,16 +199,6 @@ class DQNTrainer(Trainer):
self.local_evaluator = self.make_local_evaluator(
env_creator, self._policy_graph)
if config["evaluation_interval"]:
self.evaluation_ev = self.make_local_evaluator(
env_creator,
self._policy_graph,
extra_config={
"batch_mode": "complete_episodes",
"batch_steps": 1,
})
self.evaluation_metrics = self._evaluate()
def create_remote_evaluators():
return self.make_remote_evaluators(env_creator, self._policy_graph,
config["num_workers"])
@@ -277,11 +258,6 @@ class DQNTrainer(Trainer):
"num_target_updates": self.num_target_updates,
}, **self.optimizer.stats()))
if self.config["evaluation_interval"]:
if self.iteration % self.config["evaluation_interval"] == 0:
self.evaluation_metrics = self._evaluate()
result.update(self.evaluation_metrics)
return result
def update_target_if_needed(self):
+65 -1
View File
@@ -21,6 +21,7 @@ from ray.rllib.models import MODEL_DEFAULTS
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator, \
_validate_multiagent_config
from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.evaluation.metrics import collect_metrics
from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
from ray.rllib.utils.annotations import override, PublicAPI, DeveloperAPI
from ray.rllib.utils import FilterManager, deep_update, merge_dicts
@@ -92,6 +93,20 @@ COMMON_CONFIG = {
# Whether to use rllib or deepmind preprocessors by default
"preprocessor_pref": "deepmind",
# === Evaluation ===
# Evaluate every `evaluation_interval` training iterations.
# The evaluation stats will be reported under the "evaluation" metric key.
# Note that evaluation is currently not parallelized, and that for Ape-X
# metrics are already only reported for the lowest epsilon workers.
"evaluation_interval": None,
# Number of episodes to run per evaluation period.
"evaluation_num_episodes": 10,
# Extra arguments to pass to evaluation workers.
# Typical usage is to pass extra args to the evaluation env creator
# and to disable exploration by computing deterministic actions.
# TODO(kismuz): implement deterministic actions and include hints for
# the relevant config keys
"evaluation_config": {},
# === Resources ===
# Number of actors used for parallelism
"num_workers": 2,
@@ -250,7 +265,7 @@ class Trainer(Trainable):
_allow_unknown_configs = False
_allow_unknown_subkeys = [
"tf_session_args", "env_config", "model", "optimizer", "multiagent",
"custom_resources_per_worker"
"custom_resources_per_worker", "evaluation_config"
]
@PublicAPI
@@ -352,6 +367,14 @@ class Trainer(Trainable):
if self._has_policy_optimizer():
result["num_healthy_workers"] = len(
self.optimizer.remote_evaluators)
if self.config["evaluation_interval"]:
if self._iteration % self.config["evaluation_interval"] == 0:
evaluation_metrics = self._evaluate()
assert isinstance(evaluation_metrics, dict), \
"_evaluate() needs to return a dict."
result.update(evaluation_metrics)
return result
@override(Trainable)
@@ -393,6 +416,23 @@ class Trainer(Trainable):
with tf.Graph().as_default():
self._init(self.config, self.env_creator)
# Evaluation related
if self.config.get("evaluation_interval"):
# Build the evaluator config overrides from evaluation_config:
extra_config = copy.deepcopy(self.config["evaluation_config"])
extra_config.update({
"batch_mode": "complete_episodes",
"batch_steps": 1,
})
logger.debug(
"using evaluation_config: {}".format(extra_config))
# Make the local evaluator used for evaluation
self.evaluation_ev = self.make_local_evaluator(
self.env_creator,
self._policy_graph,
extra_config=extra_config)
self.evaluation_metrics = self._evaluate()
@override(Trainable)
def _stop(self):
# Call stop on all evaluators to release resources
@@ -428,6 +468,30 @@ class Trainer(Trainable):
raise NotImplementedError
@DeveloperAPI
def _evaluate(self):
    """Evaluate the current policy using the `evaluation_config` overrides.

    The evaluation evaluator is first restored from the state saved by
    the local evaluator, then sampled `evaluation_num_episodes` times,
    and the metrics collected from it are returned.

    This default implementation does nothing beyond running the policy
    with evaluation_config merged into the normal trainer config.

    Returns:
        dict: the collected metrics under the single key "evaluation".

    Raises:
        ValueError: if `evaluation_config` is empty.
    """
    eval_overrides = self.config["evaluation_config"]
    if not eval_overrides:
        raise ValueError(
            "No evaluation_config specified. It doesn't make sense "
            "to enable evaluation without specifying any config "
            "overrides, since the results will be the "
            "same as reported during normal policy evaluation.")

    num_episodes = self.config["evaluation_num_episodes"]
    logger.info(
        "Evaluating current policy for {} episodes".format(num_episodes))

    # Bring the evaluation evaluator in sync with the local evaluator
    # before collecting any samples.
    local_state = self.local_evaluator.save()
    self.evaluation_ev.restore(local_state)

    # One sample() call per requested episode (presumably each call
    # yields a complete episode -- depends on the evaluator's batch mode).
    for _ in range(num_episodes):
        self.evaluation_ev.sample()

    return {"evaluation": collect_metrics(self.evaluation_ev)}
@PublicAPI
def compute_action(self,
observation,
+39 -16
View File
@@ -6,11 +6,14 @@ import unittest
import ray
from ray.rllib.agents.dqn import DQNTrainer
from ray.rllib.agents.a3c import A3CTrainer
from ray.rllib.agents.dqn.dqn_policy_graph import _adjust_nstep
from ray.tune.registry import register_env
import gym
class DQNTest(unittest.TestCase):
def testNStep(self):
class EvalTest(unittest.TestCase):
def testDqnNStep(self):
obs = [1, 2, 3, 4, 5, 6, 7]
actions = ["a", "b", "a", "a", "a", "b", "a"]
rewards = [10.0, 0.0, 100.0, 100.0, 100.0, 100.0, 100.0]
@@ -25,20 +28,40 @@ class DQNTest(unittest.TestCase):
[91.0, 171.0, 271.0, 271.0, 271.0, 190.0, 100.0])
def testEvaluationOption(self):
ray.init()
agent = DQNTrainer(
env="CartPole-v0", config={"evaluation_interval": 2})
r0 = agent.train()
r1 = agent.train()
r2 = agent.train()
r3 = agent.train()
r4 = agent.train()
self.assertTrue("evaluation" in r0)
self.assertTrue("episode_reward_mean" in r0["evaluation"])
self.assertEqual(r0["evaluation"], r1["evaluation"])
self.assertNotEqual(r1["evaluation"], r2["evaluation"])
self.assertEqual(r2["evaluation"], r3["evaluation"])
self.assertNotEqual(r3["evaluation"], r4["evaluation"])
def env_creator(env_config):
return gym.make("CartPole-v0")
agent_classes = [DQNTrainer, A3CTrainer]
for agent_cls in agent_classes:
ray.init()
register_env("CartPoleWrapped-v0", env_creator)
agent = agent_cls(
env="CartPoleWrapped-v0",
config={
"evaluation_interval": 2,
"evaluation_num_episodes": 2,
"evaluation_config": {
"gamma": 0.98,
"env_config": {
"fake_arg": True
}
},
})
# Given evaluation_interval=2, r0 and r2 should not contain
# evaluation metrics, while r1 and r3 should.
r0 = agent.train()
r1 = agent.train()
r2 = agent.train()
r3 = agent.train()
self.assertTrue("evaluation" in r1)
self.assertTrue("evaluation" in r3)
self.assertFalse("evaluation" in r0)
self.assertFalse("evaluation" in r2)
self.assertTrue("episode_reward_mean" in r1["evaluation"])
self.assertNotEqual(r1["evaluation"], r3["evaluation"])
ray.shutdown()
if __name__ == "__main__":