[rllib] Remove "Common", cleanup some code (#2348)

This commit is contained in:
Richard Liaw
2018-07-08 13:03:53 -07:00
committed by GitHub
parent 1d05cd7077
commit 4d7da9f668
15 changed files with 58 additions and 95 deletions
+2 -2
View File
@@ -12,7 +12,7 @@ from ray.rllib.env.async_vector_env import AsyncVectorEnv
from ray.rllib.env.multi_agent_env import MultiAgentEnv
from ray.rllib.env.vector_env import VectorEnv
from ray.rllib.env.serving_env import ServingEnv
from ray.rllib.evaluation.common_policy_evaluator import CommonPolicyEvaluator
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator
from ray.rllib.evaluation.sample_batch import SampleBatch
@@ -27,6 +27,6 @@ def _register_all():
_register_all()
__all__ = [
"PolicyGraph", "TFPolicyGraph", "CommonPolicyEvaluator", "SampleBatch",
"PolicyGraph", "TFPolicyGraph", "PolicyEvaluator", "SampleBatch",
"AsyncVectorEnv", "MultiAgentEnv", "VectorEnv", "ServingEnv",
]
+3 -3
View File
@@ -9,7 +9,7 @@ import os
import pickle
import tensorflow as tf
from ray.rllib.evaluation.common_policy_evaluator import CommonPolicyEvaluator
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator
from ray.tune.registry import ENV_CREATOR, _global_registry
from ray.tune.result import TrainingResult
from ray.tune.trainable import Trainable
@@ -115,13 +115,13 @@ class Agent(Trainable):
"""Convenience method to return configured local evaluator."""
return self._make_evaluator(
CommonPolicyEvaluator, env_creator, policy_graph, 0)
PolicyEvaluator, env_creator, policy_graph, 0)
def make_remote_evaluators(
self, env_creator, policy_graph, count, remote_args):
"""Convenience method to return a number of remote evaluators."""
cls = CommonPolicyEvaluator.as_remote(**remote_args).remote
cls = PolicyEvaluator.as_remote(**remote_args).remote
return [
self._make_evaluator(cls, env_creator, policy_graph, i+1)
for i in range(count)]
+2 -2
View File
@@ -8,11 +8,11 @@ from six.moves import queue
import ray
from ray.rllib.agents.bc.experience_dataset import ExperienceDataset
from ray.rllib.agents.bc.policy import BCPolicy
from ray.rllib.evaluation.interface import PolicyEvaluator
from ray.rllib.evaluation.interface import EvaluatorInterface
from ray.rllib.models import ModelCatalog
class BCEvaluator(PolicyEvaluator):
class BCEvaluator(EvaluatorInterface):
def __init__(self, env_creator, config, logdir):
env = ModelCatalog.get_preprocessor_as_wrapper(env_creator(
config["env_config"]), config["model"])
+1 -1
View File
@@ -16,7 +16,7 @@ class AsyncVectorEnv(object):
can be sent back via send_actions().
All other env types can be adapted to AsyncVectorEnv. RLlib handles these
conversions internally in CommonPolicyEvaluator, for example:
conversions internally in PolicyEvaluator, for example:
gym.Env => rllib.VectorEnv => rllib.AsyncVectorEnv
rllib.MultiAgentEnv => rllib.AsyncVectorEnv
+6 -3
View File
@@ -1,14 +1,17 @@
from ray.rllib.evaluation.common_policy_evaluator import CommonPolicyEvaluator
from ray.rllib.evaluation.interface import PolicyEvaluator
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator
from ray.rllib.evaluation.interface import EvaluatorInterface
from ray.rllib.evaluation.policy_graph import PolicyGraph
from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph
from ray.rllib.evaluation.torch_policy_graph import TorchPolicyGraph
from ray.rllib.evaluation.sample_batch import SampleBatch, MultiAgentBatch, \
SampleBatchBuilder, MultiAgentSampleBatchBuilder
from ray.rllib.evaluation.sampler import SyncSampler, AsyncSampler
from ray.rllib.evaluation.postprocessing import compute_advantages
from ray.rllib.evaluation.metrics import collect_metrics
__all__ = [
"PolicyEvaluator", "CommonPolicyEvaluator", "PolicyGraph", "TFPolicyGraph",
"EvaluatorInterface", "PolicyEvaluator", "PolicyGraph", "TFPolicyGraph",
"TorchPolicyGraph", "SampleBatch", "MultiAgentBatch", "SampleBatchBuilder",
"MultiAgentSampleBatchBuilder", "SyncSampler", "AsyncSampler",
"compute_advantages", "collect_metrics"
]
+2 -43
View File
@@ -5,10 +5,10 @@ from __future__ import print_function
import os
class PolicyEvaluator(object):
class EvaluatorInterface(object):
"""This is the interface between policy optimizers and policy evaluation.
See also: CommonPolicyEvaluator
See also: PolicyEvaluator
"""
def sample(self):
@@ -109,44 +109,3 @@ class PolicyEvaluator(object):
"""Apply the given function to this evaluator instance."""
return func(self, *args)
class TFMultiGPUSupport(PolicyEvaluator):
"""The multi-GPU TF optimizer requires additional TF-specific support.
Attributes:
sess (Session): the tensorflow session associated with this evaluator.
"""
def tf_loss_inputs(self):
"""Returns a list of the input placeholders required for the loss.
For example, the following calls should work:
Returns:
list: a (name, placeholder) tuple for each loss input argument.
Each placeholder name must correspond to one of the SampleBatch
column keys returned by sample().
Examples:
>>> print(ev.tf_loss_inputs())
[("action", action_placeholder), ("reward", reward_placeholder)]
>>> print(ev.sample()[0].data.keys())
["action", "reward"]
"""
raise NotImplementedError
def build_tf_loss(self, input_placeholders):
"""Returns a new loss tensor graph for the specified inputs.
The graph must share vars with this Evaluator's policy model, so that
the multi-gpu optimizer can update the weights.
Examples:
>>> loss_inputs = ev.tf_loss_inputs()
>>> ev.build_tf_loss([ph for _, ph in loss_inputs])
"""
raise NotImplementedError
+1 -1
View File
@@ -10,7 +10,7 @@ from ray.tune.result import TrainingResult
def collect_metrics(local_evaluator, remote_evaluators=[]):
"""Gathers episode metrics from CommonPolicyEvaluator instances."""
"""Gathers episode metrics from PolicyEvaluator instances."""
episode_rewards = []
episode_lengths = []
@@ -14,7 +14,7 @@ from ray.rllib.env.env_context import EnvContext
from ray.rllib.env.serving_env import ServingEnv
from ray.rllib.env.vector_env import VectorEnv
from ray.rllib.env.multi_agent_env import MultiAgentEnv
from ray.rllib.evaluation.interface import PolicyEvaluator
from ray.rllib.evaluation.interface import EvaluatorInterface
from ray.rllib.evaluation.sample_batch import MultiAgentBatch, \
DEFAULT_POLICY_ID
from ray.rllib.evaluation.sampler import AsyncSampler, SyncSampler
@@ -25,7 +25,7 @@ from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph
from ray.rllib.utils.tf_run_builder import TFRunBuilder
class CommonPolicyEvaluator(PolicyEvaluator):
class PolicyEvaluator(EvaluatorInterface):
"""Common ``EvaluatorInterface`` implementation that wraps a ``PolicyGraph``.
This class wraps a policy graph instance and an environment class to
@@ -37,7 +37,7 @@ class CommonPolicyEvaluator(PolicyEvaluator):
Examples:
>>> # Create a policy evaluator and use it to collect experiences.
>>> evaluator = CommonPolicyEvaluator(
>>> evaluator = PolicyEvaluator(
... env_creator=lambda _: gym.make("CartPole-v0"),
... policy_graph=PGPolicyGraph)
>>> print(evaluator.sample())
@@ -47,7 +47,7 @@ class CommonPolicyEvaluator(PolicyEvaluator):
>>> # Creating policy evaluators using optimizer_cls.make().
>>> optimizer = SyncSamplesOptimizer.make(
... evaluator_cls=CommonPolicyEvaluator,
... evaluator_cls=PolicyEvaluator,
... evaluator_args={
... "env_creator": lambda _: gym.make("CartPole-v0"),
... "policy_graph": PGPolicyGraph,
@@ -56,7 +56,7 @@ class CommonPolicyEvaluator(PolicyEvaluator):
>>> for _ in range(10): optimizer.step()
>>> # Creating a multi-agent policy evaluator
>>> evaluator = CommonPolicyEvaluator(
>>> evaluator = PolicyEvaluator(
... env_creator=lambda _: MultiAgentTrafficGrid(num_cars=25),
... policy_graphs={
... # Use an ensemble of two policies for car agents
+1 -1
View File
@@ -25,7 +25,7 @@ class PolicyGraph(object):
"""Initialize the graph.
This is the standard constructor for policy graphs. The policy graph
class you pass into CommonPolicyEvaluator will be constructed with
class you pass into PolicyEvaluator will be constructed with
these arguments.
Args:
+1 -1
View File
@@ -7,6 +7,6 @@ from ray.rllib.models.fcnet import FullyConnectedNetwork
from ray.rllib.models.lstm import LSTM
__all__ = ["ActionDistribution", "ActionDistribution", "Categorical",
__all__ = ["ActionDistribution", "Categorical",
"DiagGaussian", "Deterministic", "ModelCatalog", "Model",
"Preprocessor", "FullyConnectedNetwork", "LSTM"]
+2 -1
View File
@@ -9,4 +9,5 @@ from ray.rllib.optimizers.multi_gpu_optimizer import LocalMultiGPUOptimizer
__all__ = [
"PolicyOptimizer", "AsyncSamplesOptimizer", "AsyncGradientsOptimizer",
"SyncSamplesOptimizer", "SyncReplayOptimizer", "LocalMultiGPUOptimizer"]
"SyncSamplesOptimizer", "SyncReplayOptimizer", "LocalMultiGPUOptimizer"
]
@@ -12,9 +12,9 @@ from ray.rllib.agents.pg.pg_policy_graph import PGPolicyGraph
from ray.rllib.agents.dqn.dqn_policy_graph import DQNPolicyGraph
from ray.rllib.optimizers import SyncSamplesOptimizer, \
SyncReplayOptimizer, AsyncGradientsOptimizer
from ray.rllib.test.test_common_policy_evaluator import MockEnv, MockEnv2, \
from ray.rllib.test.test_policy_evaluator import MockEnv, MockEnv2, \
MockPolicyGraph
from ray.rllib.evaluation.common_policy_evaluator import CommonPolicyEvaluator
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator
from ray.rllib.evaluation.metrics import collect_metrics
from ray.rllib.env.async_vector_env import _MultiAgentEnvToAsync
from ray.rllib.env.multi_agent_env import MultiAgentEnv
@@ -205,7 +205,7 @@ class TestMultiAgentEnv(unittest.TestCase):
def testMultiAgentSample(self):
act_space = gym.spaces.Discrete(2)
obs_space = gym.spaces.Discrete(2)
ev = CommonPolicyEvaluator(
ev = PolicyEvaluator(
env_creator=lambda _: BasicMultiAgent(5),
policy_graph={
"p0": (MockPolicyGraph, obs_space, act_space, {}),
@@ -224,7 +224,7 @@ class TestMultiAgentEnv(unittest.TestCase):
def testMultiAgentSampleRoundRobin(self):
act_space = gym.spaces.Discrete(2)
obs_space = gym.spaces.Discrete(2)
ev = CommonPolicyEvaluator(
ev = PolicyEvaluator(
env_creator=lambda _: RoundRobinMultiAgent(5, increment_obs=True),
policy_graph={
"p0": (MockPolicyGraph, obs_space, act_space, {}),
@@ -283,13 +283,13 @@ class TestMultiAgentEnv(unittest.TestCase):
"p1": (PGPolicyGraph, obs_space, act_space, {}),
"p2": (DQNPolicyGraph, obs_space, act_space, dqn_config),
}
ev = CommonPolicyEvaluator(
ev = PolicyEvaluator(
env_creator=lambda _: MultiCartpole(n),
policy_graph=policies,
policy_mapping_fn=lambda agent_id: ["p1", "p2"][agent_id % 2],
batch_steps=50)
if optimizer_cls == AsyncGradientsOptimizer:
remote_evs = [CommonPolicyEvaluator.as_remote().remote(
remote_evs = [PolicyEvaluator.as_remote().remote(
env_creator=lambda _: MultiCartpole(n),
policy_graph=policies,
policy_mapping_fn=lambda agent_id: ["p1", "p2"][agent_id % 2],
@@ -333,7 +333,7 @@ class TestMultiAgentEnv(unittest.TestCase):
policies["pg_{}".format(i)] = (
PGPolicyGraph, obs_space, act_space, {})
policy_ids = list(policies.keys())
ev = CommonPolicyEvaluator(
ev = PolicyEvaluator(
env_creator=lambda _: MultiCartpole(n),
policy_graph=policies,
policy_mapping_fn=lambda agent_id: random.choice(policy_ids),
@@ -8,7 +8,7 @@ import unittest
import ray
from ray.rllib.agents.pg import PGAgent
from ray.rllib.evaluation.common_policy_evaluator import CommonPolicyEvaluator
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator
from ray.rllib.evaluation.metrics import collect_metrics
from ray.rllib.evaluation.policy_graph import PolicyGraph
from ray.rllib.evaluation.postprocessing import compute_advantages
@@ -89,9 +89,9 @@ class MockVectorEnv(VectorEnv):
return obs_batch, rew_batch, done_batch, info_batch
class TestCommonPolicyEvaluator(unittest.TestCase):
class TestPolicyEvaluator(unittest.TestCase):
def testBasic(self):
ev = CommonPolicyEvaluator(
ev = PolicyEvaluator(
env_creator=lambda _: gym.make("CartPole-v0"),
policy_graph=MockPolicyGraph)
batch = ev.sample()
@@ -110,10 +110,10 @@ class TestCommonPolicyEvaluator(unittest.TestCase):
self.assertEqual(results2, [(0, 5), (1, 5), (2, 5)])
def testMetrics(self):
ev = CommonPolicyEvaluator(
ev = PolicyEvaluator(
env_creator=lambda _: MockEnv(episode_length=10),
policy_graph=MockPolicyGraph, batch_mode="complete_episodes")
remote_ev = CommonPolicyEvaluator.as_remote().remote(
remote_ev = PolicyEvaluator.as_remote().remote(
env_creator=lambda _: MockEnv(episode_length=10),
policy_graph=MockPolicyGraph, batch_mode="complete_episodes")
ev.sample()
@@ -123,7 +123,7 @@ class TestCommonPolicyEvaluator(unittest.TestCase):
self.assertEqual(result.episode_reward_mean, 10)
def testAsync(self):
ev = CommonPolicyEvaluator(
ev = PolicyEvaluator(
env_creator=lambda _: gym.make("CartPole-v0"),
sample_async=True,
policy_graph=MockPolicyGraph)
@@ -133,7 +133,7 @@ class TestCommonPolicyEvaluator(unittest.TestCase):
self.assertGreater(batch["advantages"][0], 1)
def testAutoConcat(self):
ev = CommonPolicyEvaluator(
ev = PolicyEvaluator(
env_creator=lambda _: MockEnv(episode_length=40),
policy_graph=MockPolicyGraph,
sample_async=True,
@@ -145,7 +145,7 @@ class TestCommonPolicyEvaluator(unittest.TestCase):
self.assertEqual(batch.count, 40) # auto-concat up to 5 episodes
def testAutoVectorization(self):
ev = CommonPolicyEvaluator(
ev = PolicyEvaluator(
env_creator=lambda _: MockEnv(episode_length=20),
policy_graph=MockPolicyGraph,
batch_mode="truncate_episodes",
@@ -164,14 +164,14 @@ class TestCommonPolicyEvaluator(unittest.TestCase):
def testBatchDivisibilityCheck(self):
self.assertRaises(
ValueError,
lambda: CommonPolicyEvaluator(
lambda: PolicyEvaluator(
env_creator=lambda _: MockEnv(episode_length=8),
policy_graph=MockPolicyGraph,
batch_mode="truncate_episodes",
batch_steps=15, num_envs=4))
def testBatchesSmallerWhenVectorized(self):
ev = CommonPolicyEvaluator(
ev = PolicyEvaluator(
env_creator=lambda _: MockEnv(episode_length=8),
policy_graph=MockPolicyGraph,
batch_mode="truncate_episodes",
@@ -185,7 +185,7 @@ class TestCommonPolicyEvaluator(unittest.TestCase):
self.assertEqual(result.episodes_total, 4)
def testVectorEnvSupport(self):
ev = CommonPolicyEvaluator(
ev = PolicyEvaluator(
env_creator=lambda _: MockVectorEnv(
episode_length=20, num_envs=8),
policy_graph=MockPolicyGraph,
@@ -203,7 +203,7 @@ class TestCommonPolicyEvaluator(unittest.TestCase):
self.assertEqual(result.episodes_total, 8)
def testTruncateEpisodes(self):
ev = CommonPolicyEvaluator(
ev = PolicyEvaluator(
env_creator=lambda _: MockEnv(10),
policy_graph=MockPolicyGraph,
batch_steps=15,
@@ -212,7 +212,7 @@ class TestCommonPolicyEvaluator(unittest.TestCase):
self.assertEqual(batch.count, 15)
def testCompleteEpisodes(self):
ev = CommonPolicyEvaluator(
ev = PolicyEvaluator(
env_creator=lambda _: MockEnv(10),
policy_graph=MockPolicyGraph,
batch_steps=5,
@@ -221,7 +221,7 @@ class TestCommonPolicyEvaluator(unittest.TestCase):
self.assertEqual(batch.count, 10)
def testCompleteEpisodesPacking(self):
ev = CommonPolicyEvaluator(
ev = PolicyEvaluator(
env_creator=lambda _: MockEnv(10),
policy_graph=MockPolicyGraph,
batch_steps=15,
@@ -233,7 +233,7 @@ class TestCommonPolicyEvaluator(unittest.TestCase):
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
def testFilterSync(self):
ev = CommonPolicyEvaluator(
ev = PolicyEvaluator(
env_creator=lambda _: gym.make("CartPole-v0"),
policy_graph=MockPolicyGraph,
sample_async=True,
@@ -246,7 +246,7 @@ class TestCommonPolicyEvaluator(unittest.TestCase):
self.assertNotEqual(obs_f.buffer.n, 0)
def testGetFilters(self):
ev = CommonPolicyEvaluator(
ev = PolicyEvaluator(
env_creator=lambda _: gym.make("CartPole-v0"),
policy_graph=MockPolicyGraph,
sample_async=True,
@@ -261,7 +261,7 @@ class TestCommonPolicyEvaluator(unittest.TestCase):
self.assertGreaterEqual(obs_f2.buffer.n, obs_f.buffer.n)
def testSyncFilter(self):
ev = CommonPolicyEvaluator(
ev = PolicyEvaluator(
env_creator=lambda _: gym.make("CartPole-v0"),
policy_graph=MockPolicyGraph,
sample_async=True,
+7 -7
View File
@@ -11,9 +11,9 @@ import uuid
import ray
from ray.rllib.agents.dqn import DQNAgent
from ray.rllib.agents.pg import PGAgent
from ray.rllib.evaluation.common_policy_evaluator import CommonPolicyEvaluator
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator
from ray.rllib.env.serving_env import ServingEnv
from ray.rllib.test.test_common_policy_evaluator import BadPolicyGraph, \
from ray.rllib.test.test_policy_evaluator import BadPolicyGraph, \
MockPolicyGraph, MockEnv
from ray.tune.registry import register_env
@@ -110,7 +110,7 @@ class MultiServing(ServingEnv):
class TestServingEnv(unittest.TestCase):
def testServingEnvCompleteEpisodes(self):
ev = CommonPolicyEvaluator(
ev = PolicyEvaluator(
env_creator=lambda _: SimpleServing(MockEnv(25)),
policy_graph=MockPolicyGraph,
batch_steps=40,
@@ -120,7 +120,7 @@ class TestServingEnv(unittest.TestCase):
self.assertEqual(batch.count, 50)
def testServingEnvTruncateEpisodes(self):
ev = CommonPolicyEvaluator(
ev = PolicyEvaluator(
env_creator=lambda _: SimpleServing(MockEnv(25)),
policy_graph=MockPolicyGraph,
batch_steps=40,
@@ -130,7 +130,7 @@ class TestServingEnv(unittest.TestCase):
self.assertEqual(batch.count, 40)
def testServingEnvOffPolicy(self):
ev = CommonPolicyEvaluator(
ev = PolicyEvaluator(
env_creator=lambda _: SimpleOffPolicyServing(MockEnv(25), 42),
policy_graph=MockPolicyGraph,
batch_steps=40,
@@ -142,7 +142,7 @@ class TestServingEnv(unittest.TestCase):
self.assertEqual(batch["actions"][-1], 42)
def testServingEnvBadActions(self):
ev = CommonPolicyEvaluator(
ev = PolicyEvaluator(
env_creator=lambda _: SimpleServing(MockEnv(25)),
policy_graph=BadPolicyGraph,
sample_async=True,
@@ -188,7 +188,7 @@ class TestServingEnv(unittest.TestCase):
raise Exception("failed to improve reward")
def testServingEnvHorizonNotSupported(self):
ev = CommonPolicyEvaluator(
ev = PolicyEvaluator(
env_creator=lambda _: SimpleServing(MockEnv(25)),
policy_graph=MockPolicyGraph,
episode_horizon=20,