mirror of
https://github.com/wassname/ray.git
synced 2026-06-29 14:39:44 +08:00
[rllib] train-eval loop implementation for rllib.Trainer class (#4647)
This commit is contained in:
@@ -48,15 +48,6 @@ DEFAULT_CONFIG = with_common_config({
|
||||
# N-step Q learning
|
||||
"n_step": 1,
|
||||
|
||||
# === Evaluation ===
|
||||
# Evaluate with epsilon=0 every `evaluation_interval` training iterations.
|
||||
# The evaluation stats will be reported under the "evaluation" metric key.
|
||||
# Note that evaluation is currently not parallelized, and that for Ape-X
|
||||
# metrics are already only reported for the lowest epsilon workers.
|
||||
"evaluation_interval": None,
|
||||
# Number of episodes to run per evaluation period.
|
||||
"evaluation_num_episodes": 10,
|
||||
|
||||
# === Exploration ===
|
||||
# Max num timesteps for annealing schedules. Exploration is annealed from
|
||||
# 1.0 to exploration_fraction over this number of timesteps scaled by
|
||||
@@ -208,16 +199,6 @@ class DQNTrainer(Trainer):
|
||||
self.local_evaluator = self.make_local_evaluator(
|
||||
env_creator, self._policy_graph)
|
||||
|
||||
if config["evaluation_interval"]:
|
||||
self.evaluation_ev = self.make_local_evaluator(
|
||||
env_creator,
|
||||
self._policy_graph,
|
||||
extra_config={
|
||||
"batch_mode": "complete_episodes",
|
||||
"batch_steps": 1,
|
||||
})
|
||||
self.evaluation_metrics = self._evaluate()
|
||||
|
||||
def create_remote_evaluators():
|
||||
return self.make_remote_evaluators(env_creator, self._policy_graph,
|
||||
config["num_workers"])
|
||||
@@ -277,11 +258,6 @@ class DQNTrainer(Trainer):
|
||||
"num_target_updates": self.num_target_updates,
|
||||
}, **self.optimizer.stats()))
|
||||
|
||||
if self.config["evaluation_interval"]:
|
||||
if self.iteration % self.config["evaluation_interval"] == 0:
|
||||
self.evaluation_metrics = self._evaluate()
|
||||
result.update(self.evaluation_metrics)
|
||||
|
||||
return result
|
||||
|
||||
def update_target_if_needed(self):
|
||||
|
||||
@@ -21,6 +21,7 @@ from ray.rllib.models import MODEL_DEFAULTS
|
||||
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator, \
|
||||
_validate_multiagent_config
|
||||
from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID
|
||||
from ray.rllib.evaluation.metrics import collect_metrics
|
||||
from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
|
||||
from ray.rllib.utils.annotations import override, PublicAPI, DeveloperAPI
|
||||
from ray.rllib.utils import FilterManager, deep_update, merge_dicts
|
||||
@@ -92,6 +93,20 @@ COMMON_CONFIG = {
|
||||
# Whether to use rllib or deepmind preprocessors by default
|
||||
"preprocessor_pref": "deepmind",
|
||||
|
||||
# === Evaluation ===
|
||||
# Evaluate every `evaluation_interval` training iterations.
|
||||
# The evaluation stats will be reported under the "evaluation" metric key.
|
||||
# Note that evaluation is currently not parallelized, and that for Ape-X
|
||||
# metrics are already only reported for the lowest epsilon workers.
|
||||
"evaluation_interval": None,
|
||||
# Number of episodes to run per evaluation period.
|
||||
"evaluation_num_episodes": 10,
|
||||
# Extra arguments to pass to evaluation workers.
|
||||
# Typical usage is to pass extra args to evaluation env creator
|
||||
# and to disable exploration by computing deterministic actions
|
||||
# TODO(kismuz): implement determ. actions and include relevant keys hints
|
||||
"evaluation_config": {},
|
||||
|
||||
# === Resources ===
|
||||
# Number of actors used for parallelism
|
||||
"num_workers": 2,
|
||||
@@ -250,7 +265,7 @@ class Trainer(Trainable):
|
||||
_allow_unknown_configs = False
|
||||
_allow_unknown_subkeys = [
|
||||
"tf_session_args", "env_config", "model", "optimizer", "multiagent",
|
||||
"custom_resources_per_worker"
|
||||
"custom_resources_per_worker", "evaluation_config"
|
||||
]
|
||||
|
||||
@PublicAPI
|
||||
@@ -352,6 +367,14 @@ class Trainer(Trainable):
|
||||
if self._has_policy_optimizer():
|
||||
result["num_healthy_workers"] = len(
|
||||
self.optimizer.remote_evaluators)
|
||||
|
||||
if self.config["evaluation_interval"]:
|
||||
if self._iteration % self.config["evaluation_interval"] == 0:
|
||||
evaluation_metrics = self._evaluate()
|
||||
assert isinstance(evaluation_metrics, dict), \
|
||||
"_evaluate() needs to return a dict."
|
||||
result.update(evaluation_metrics)
|
||||
|
||||
return result
|
||||
|
||||
@override(Trainable)
|
||||
@@ -393,6 +416,23 @@ class Trainer(Trainable):
|
||||
with tf.Graph().as_default():
|
||||
self._init(self.config, self.env_creator)
|
||||
|
||||
# Evaluation related
|
||||
if self.config.get("evaluation_interval"):
|
||||
# Build the evaluator config overrides from `evaluation_config`:
|
||||
extra_config = copy.deepcopy(self.config["evaluation_config"])
|
||||
extra_config.update({
|
||||
"batch_mode": "complete_episodes",
|
||||
"batch_steps": 1,
|
||||
})
|
||||
logger.debug(
|
||||
"using evaluation_config: {}".format(extra_config))
|
||||
# Make the local evaluation evaluator
|
||||
self.evaluation_ev = self.make_local_evaluator(
|
||||
self.env_creator,
|
||||
self._policy_graph,
|
||||
extra_config=extra_config)
|
||||
self.evaluation_metrics = self._evaluate()
|
||||
|
||||
@override(Trainable)
|
||||
def _stop(self):
|
||||
# Call stop on all evaluators to release resources
|
||||
@@ -428,6 +468,30 @@ class Trainer(Trainable):
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
@DeveloperAPI
|
||||
def _evaluate(self):
    """Run evaluation episodes and return their metrics.

    The evaluation evaluator is first restored from the state saved by
    the local evaluator, then sampled once per configured episode
    (`evaluation_num_episodes`).

    Returns:
        dict: A single-entry dict mapping "evaluation" to the result of
            `collect_metrics` on the evaluation evaluator.

    Raises:
        ValueError: If `evaluation_config` is empty or unset.
    """
    # Evaluation without any config overrides is rejected up front.
    if not self.config["evaluation_config"]:
        raise ValueError(
            "No evaluation_config specified. It doesn't make sense "
            "to enable evaluation without specifying any config "
            "overrides, since the results will be the "
            "same as reported during normal policy evaluation.")

    num_episodes = self.config["evaluation_num_episodes"]
    logger.info(
        "Evaluating current policy for {} episodes".format(num_episodes))

    # Copy the local evaluator's saved state into the evaluation
    # evaluator before sampling.
    evaluator = self.evaluation_ev
    evaluator.restore(self.local_evaluator.save())

    episodes_left = num_episodes
    while episodes_left > 0:
        evaluator.sample()
        episodes_left -= 1

    return {"evaluation": collect_metrics(evaluator)}
|
||||
|
||||
@PublicAPI
|
||||
def compute_action(self,
|
||||
observation,
|
||||
|
||||
@@ -6,11 +6,14 @@ import unittest
|
||||
|
||||
import ray
|
||||
from ray.rllib.agents.dqn import DQNTrainer
|
||||
from ray.rllib.agents.a3c import A3CTrainer
|
||||
from ray.rllib.agents.dqn.dqn_policy_graph import _adjust_nstep
|
||||
from ray.tune.registry import register_env
|
||||
import gym
|
||||
|
||||
|
||||
class DQNTest(unittest.TestCase):
|
||||
def testNStep(self):
|
||||
class EvalTest(unittest.TestCase):
|
||||
def testDqnNStep(self):
|
||||
obs = [1, 2, 3, 4, 5, 6, 7]
|
||||
actions = ["a", "b", "a", "a", "a", "b", "a"]
|
||||
rewards = [10.0, 0.0, 100.0, 100.0, 100.0, 100.0, 100.0]
|
||||
@@ -25,20 +28,40 @@ class DQNTest(unittest.TestCase):
|
||||
[91.0, 171.0, 271.0, 271.0, 271.0, 190.0, 100.0])
|
||||
|
||||
def testEvaluationOption(self):
|
||||
ray.init()
|
||||
agent = DQNTrainer(
|
||||
env="CartPole-v0", config={"evaluation_interval": 2})
|
||||
r0 = agent.train()
|
||||
r1 = agent.train()
|
||||
r2 = agent.train()
|
||||
r3 = agent.train()
|
||||
r4 = agent.train()
|
||||
self.assertTrue("evaluation" in r0)
|
||||
self.assertTrue("episode_reward_mean" in r0["evaluation"])
|
||||
self.assertEqual(r0["evaluation"], r1["evaluation"])
|
||||
self.assertNotEqual(r1["evaluation"], r2["evaluation"])
|
||||
self.assertEqual(r2["evaluation"], r3["evaluation"])
|
||||
self.assertNotEqual(r3["evaluation"], r4["evaluation"])
|
||||
def env_creator(env_config):
|
||||
return gym.make("CartPole-v0")
|
||||
|
||||
agent_classes = [DQNTrainer, A3CTrainer]
|
||||
|
||||
for agent_cls in agent_classes:
|
||||
ray.init()
|
||||
register_env("CartPoleWrapped-v0", env_creator)
|
||||
agent = agent_cls(
|
||||
env="CartPoleWrapped-v0",
|
||||
config={
|
||||
"evaluation_interval": 2,
|
||||
"evaluation_num_episodes": 2,
|
||||
"evaluation_config": {
|
||||
"gamma": 0.98,
|
||||
"env_config": {
|
||||
"fake_arg": True
|
||||
}
|
||||
},
|
||||
})
|
||||
# Given evaluation_interval=2, r0 and r2 should not contain
|
||||
# evaluation metrics, while r1 and r3 should.
|
||||
r0 = agent.train()
|
||||
r1 = agent.train()
|
||||
r2 = agent.train()
|
||||
r3 = agent.train()
|
||||
|
||||
self.assertTrue("evaluation" in r1)
|
||||
self.assertTrue("evaluation" in r3)
|
||||
self.assertFalse("evaluation" in r0)
|
||||
self.assertFalse("evaluation" in r2)
|
||||
self.assertTrue("episode_reward_mean" in r1["evaluation"])
|
||||
self.assertNotEqual(r1["evaluation"], r3["evaluation"])
|
||||
ray.shutdown()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user