mirror of
https://github.com/wassname/ray.git
synced 2026-06-30 09:41:11 +08:00
[rllib] Remove "Common", cleanup some code (#2348)
This commit is contained in:
@@ -12,7 +12,7 @@ from ray.rllib.env.async_vector_env import AsyncVectorEnv
|
||||
from ray.rllib.env.multi_agent_env import MultiAgentEnv
|
||||
from ray.rllib.env.vector_env import VectorEnv
|
||||
from ray.rllib.env.serving_env import ServingEnv
|
||||
from ray.rllib.evaluation.common_policy_evaluator import CommonPolicyEvaluator
|
||||
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator
|
||||
from ray.rllib.evaluation.sample_batch import SampleBatch
|
||||
|
||||
|
||||
@@ -27,6 +27,6 @@ def _register_all():
|
||||
_register_all()
|
||||
|
||||
__all__ = [
|
||||
"PolicyGraph", "TFPolicyGraph", "CommonPolicyEvaluator", "SampleBatch",
|
||||
"PolicyGraph", "TFPolicyGraph", "PolicyEvaluator", "SampleBatch",
|
||||
"AsyncVectorEnv", "MultiAgentEnv", "VectorEnv", "ServingEnv",
|
||||
]
|
||||
|
||||
@@ -9,7 +9,7 @@ import os
|
||||
import pickle
|
||||
|
||||
import tensorflow as tf
|
||||
from ray.rllib.evaluation.common_policy_evaluator import CommonPolicyEvaluator
|
||||
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator
|
||||
from ray.tune.registry import ENV_CREATOR, _global_registry
|
||||
from ray.tune.result import TrainingResult
|
||||
from ray.tune.trainable import Trainable
|
||||
@@ -115,13 +115,13 @@ class Agent(Trainable):
|
||||
"""Convenience method to return configured local evaluator."""
|
||||
|
||||
return self._make_evaluator(
|
||||
CommonPolicyEvaluator, env_creator, policy_graph, 0)
|
||||
PolicyEvaluator, env_creator, policy_graph, 0)
|
||||
|
||||
def make_remote_evaluators(
|
||||
self, env_creator, policy_graph, count, remote_args):
|
||||
"""Convenience method to return a number of remote evaluators."""
|
||||
|
||||
cls = CommonPolicyEvaluator.as_remote(**remote_args).remote
|
||||
cls = PolicyEvaluator.as_remote(**remote_args).remote
|
||||
return [
|
||||
self._make_evaluator(cls, env_creator, policy_graph, i+1)
|
||||
for i in range(count)]
|
||||
|
||||
@@ -8,11 +8,11 @@ from six.moves import queue
|
||||
import ray
|
||||
from ray.rllib.agents.bc.experience_dataset import ExperienceDataset
|
||||
from ray.rllib.agents.bc.policy import BCPolicy
|
||||
from ray.rllib.evaluation.interface import PolicyEvaluator
|
||||
from ray.rllib.evaluation.interface import EvaluatorInterface
|
||||
from ray.rllib.models import ModelCatalog
|
||||
|
||||
|
||||
class BCEvaluator(PolicyEvaluator):
|
||||
class BCEvaluator(EvaluatorInterface):
|
||||
def __init__(self, env_creator, config, logdir):
|
||||
env = ModelCatalog.get_preprocessor_as_wrapper(env_creator(
|
||||
config["env_config"]), config["model"])
|
||||
|
||||
+1
-1
@@ -16,7 +16,7 @@ class AsyncVectorEnv(object):
|
||||
can be sent back via send_actions().
|
||||
|
||||
All other env types can be adapted to AsyncVectorEnv. RLlib handles these
|
||||
conversions internally in CommonPolicyEvaluator, for example:
|
||||
conversions internally in PolicyEvaluator, for example:
|
||||
|
||||
gym.Env => rllib.VectorEnv => rllib.AsyncVectorEnv
|
||||
rllib.MultiAgentEnv => rllib.AsyncVectorEnv
|
||||
|
||||
@@ -1,14 +1,17 @@
|
||||
from ray.rllib.evaluation.common_policy_evaluator import CommonPolicyEvaluator
|
||||
from ray.rllib.evaluation.interface import PolicyEvaluator
|
||||
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator
|
||||
from ray.rllib.evaluation.interface import EvaluatorInterface
|
||||
from ray.rllib.evaluation.policy_graph import PolicyGraph
|
||||
from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph
|
||||
from ray.rllib.evaluation.torch_policy_graph import TorchPolicyGraph
|
||||
from ray.rllib.evaluation.sample_batch import SampleBatch, MultiAgentBatch, \
|
||||
SampleBatchBuilder, MultiAgentSampleBatchBuilder
|
||||
from ray.rllib.evaluation.sampler import SyncSampler, AsyncSampler
|
||||
from ray.rllib.evaluation.postprocessing import compute_advantages
|
||||
from ray.rllib.evaluation.metrics import collect_metrics
|
||||
|
||||
__all__ = [
|
||||
"PolicyEvaluator", "CommonPolicyEvaluator", "PolicyGraph", "TFPolicyGraph",
|
||||
"EvaluatorInterface", "PolicyEvaluator", "PolicyGraph", "TFPolicyGraph",
|
||||
"TorchPolicyGraph", "SampleBatch", "MultiAgentBatch", "SampleBatchBuilder",
|
||||
"MultiAgentSampleBatchBuilder", "SyncSampler", "AsyncSampler",
|
||||
"compute_advantages", "collect_metrics"
|
||||
]
|
||||
|
||||
@@ -5,10 +5,10 @@ from __future__ import print_function
|
||||
import os
|
||||
|
||||
|
||||
class PolicyEvaluator(object):
|
||||
class EvaluatorInterface(object):
|
||||
"""This is the interface between policy optimizers and policy evaluation.
|
||||
|
||||
See also: CommonPolicyEvaluator
|
||||
See also: PolicyEvaluator
|
||||
"""
|
||||
|
||||
def sample(self):
|
||||
@@ -109,44 +109,3 @@ class PolicyEvaluator(object):
|
||||
"""Apply the given function to this evaluator instance."""
|
||||
|
||||
return func(self, *args)
|
||||
|
||||
|
||||
class TFMultiGPUSupport(PolicyEvaluator):
|
||||
"""The multi-GPU TF optimizer requires additional TF-specific support.
|
||||
|
||||
Attributes:
|
||||
sess (Session): the tensorflow session associated with this evaluator.
|
||||
"""
|
||||
|
||||
def tf_loss_inputs(self):
|
||||
"""Returns a list of the input placeholders required for the loss.
|
||||
|
||||
For example, the following calls should work:
|
||||
|
||||
Returns:
|
||||
list: a (name, placeholder) tuple for each loss input argument.
|
||||
Each placeholder name must correspond to one of the SampleBatch
|
||||
column keys returned by sample().
|
||||
|
||||
Examples:
|
||||
>>> print(ev.tf_loss_inputs())
|
||||
[("action", action_placeholder), ("reward", reward_placeholder)]
|
||||
|
||||
>>> print(ev.sample()[0].data.keys())
|
||||
["action", "reward"]
|
||||
"""
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
def build_tf_loss(self, input_placeholders):
|
||||
"""Returns a new loss tensor graph for the specified inputs.
|
||||
|
||||
The graph must share vars with this Evaluator's policy model, so that
|
||||
the multi-gpu optimizer can update the weights.
|
||||
|
||||
Examples:
|
||||
>>> loss_inputs = ev.tf_loss_inputs()
|
||||
>>> ev.build_tf_loss([ph for _, ph in loss_inputs])
|
||||
"""
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -10,7 +10,7 @@ from ray.tune.result import TrainingResult
|
||||
|
||||
|
||||
def collect_metrics(local_evaluator, remote_evaluators=[]):
|
||||
"""Gathers episode metrics from CommonPolicyEvaluator instances."""
|
||||
"""Gathers episode metrics from PolicyEvaluator instances."""
|
||||
|
||||
episode_rewards = []
|
||||
episode_lengths = []
|
||||
|
||||
+5
-5
@@ -14,7 +14,7 @@ from ray.rllib.env.env_context import EnvContext
|
||||
from ray.rllib.env.serving_env import ServingEnv
|
||||
from ray.rllib.env.vector_env import VectorEnv
|
||||
from ray.rllib.env.multi_agent_env import MultiAgentEnv
|
||||
from ray.rllib.evaluation.interface import PolicyEvaluator
|
||||
from ray.rllib.evaluation.interface import EvaluatorInterface
|
||||
from ray.rllib.evaluation.sample_batch import MultiAgentBatch, \
|
||||
DEFAULT_POLICY_ID
|
||||
from ray.rllib.evaluation.sampler import AsyncSampler, SyncSampler
|
||||
@@ -25,7 +25,7 @@ from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph
|
||||
from ray.rllib.utils.tf_run_builder import TFRunBuilder
|
||||
|
||||
|
||||
class CommonPolicyEvaluator(PolicyEvaluator):
|
||||
class PolicyEvaluator(EvaluatorInterface):
|
||||
"""Common ``PolicyEvaluator`` implementation that wraps a ``PolicyGraph``.
|
||||
|
||||
This class wraps a policy graph instance and an environment class to
|
||||
@@ -37,7 +37,7 @@ class CommonPolicyEvaluator(PolicyEvaluator):
|
||||
|
||||
Examples:
|
||||
>>> # Create a policy evaluator and using it to collect experiences.
|
||||
>>> evaluator = CommonPolicyEvaluator(
|
||||
>>> evaluator = PolicyEvaluator(
|
||||
... env_creator=lambda _: gym.make("CartPole-v0"),
|
||||
... policy_graph=PGPolicyGraph)
|
||||
>>> print(evaluator.sample())
|
||||
@@ -47,7 +47,7 @@ class CommonPolicyEvaluator(PolicyEvaluator):
|
||||
|
||||
>>> # Creating policy evaluators using optimizer_cls.make().
|
||||
>>> optimizer = SyncSamplesOptimizer.make(
|
||||
... evaluator_cls=CommonPolicyEvaluator,
|
||||
... evaluator_cls=PolicyEvaluator,
|
||||
... evaluator_args={
|
||||
... "env_creator": lambda _: gym.make("CartPole-v0"),
|
||||
... "policy_graph": PGPolicyGraph,
|
||||
@@ -56,7 +56,7 @@ class CommonPolicyEvaluator(PolicyEvaluator):
|
||||
>>> for _ in range(10): optimizer.step()
|
||||
|
||||
>>> # Creating a multi-agent policy evaluator
|
||||
>>> evaluator = CommonPolicyEvaluator(
|
||||
>>> evaluator = PolicyEvaluator(
|
||||
... env_creator=lambda _: MultiAgentTrafficGrid(num_cars=25),
|
||||
... policy_graphs={
|
||||
... # Use an ensemble of two policies for car agents
|
||||
@@ -25,7 +25,7 @@ class PolicyGraph(object):
|
||||
"""Initialize the graph.
|
||||
|
||||
This is the standard constructor for policy graphs. The policy graph
|
||||
class you pass into CommonPolicyEvaluator will be constructed with
|
||||
class you pass into PolicyEvaluator will be constructed with
|
||||
these arguments.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -7,6 +7,6 @@ from ray.rllib.models.fcnet import FullyConnectedNetwork
|
||||
from ray.rllib.models.lstm import LSTM
|
||||
|
||||
|
||||
__all__ = ["ActionDistribution", "ActionDistribution", "Categorical",
|
||||
__all__ = ["ActionDistribution", "Categorical",
|
||||
"DiagGaussian", "Deterministic", "ModelCatalog", "Model",
|
||||
"Preprocessor", "FullyConnectedNetwork", "LSTM"]
|
||||
|
||||
@@ -9,4 +9,5 @@ from ray.rllib.optimizers.multi_gpu_optimizer import LocalMultiGPUOptimizer
|
||||
|
||||
__all__ = [
|
||||
"PolicyOptimizer", "AsyncSamplesOptimizer", "AsyncGradientsOptimizer",
|
||||
"SyncSamplesOptimizer", "SyncReplayOptimizer", "LocalMultiGPUOptimizer"]
|
||||
"SyncSamplesOptimizer", "SyncReplayOptimizer", "LocalMultiGPUOptimizer"
|
||||
]
|
||||
|
||||
@@ -12,9 +12,9 @@ from ray.rllib.agents.pg.pg_policy_graph import PGPolicyGraph
|
||||
from ray.rllib.agents.dqn.dqn_policy_graph import DQNPolicyGraph
|
||||
from ray.rllib.optimizers import SyncSamplesOptimizer, \
|
||||
SyncReplayOptimizer, AsyncGradientsOptimizer
|
||||
from ray.rllib.test.test_common_policy_evaluator import MockEnv, MockEnv2, \
|
||||
from ray.rllib.test.test_policy_evaluator import MockEnv, MockEnv2, \
|
||||
MockPolicyGraph
|
||||
from ray.rllib.evaluation.common_policy_evaluator import CommonPolicyEvaluator
|
||||
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator
|
||||
from ray.rllib.evaluation.metrics import collect_metrics
|
||||
from ray.rllib.env.async_vector_env import _MultiAgentEnvToAsync
|
||||
from ray.rllib.env.multi_agent_env import MultiAgentEnv
|
||||
@@ -205,7 +205,7 @@ class TestMultiAgentEnv(unittest.TestCase):
|
||||
def testMultiAgentSample(self):
|
||||
act_space = gym.spaces.Discrete(2)
|
||||
obs_space = gym.spaces.Discrete(2)
|
||||
ev = CommonPolicyEvaluator(
|
||||
ev = PolicyEvaluator(
|
||||
env_creator=lambda _: BasicMultiAgent(5),
|
||||
policy_graph={
|
||||
"p0": (MockPolicyGraph, obs_space, act_space, {}),
|
||||
@@ -224,7 +224,7 @@ class TestMultiAgentEnv(unittest.TestCase):
|
||||
def testMultiAgentSampleRoundRobin(self):
|
||||
act_space = gym.spaces.Discrete(2)
|
||||
obs_space = gym.spaces.Discrete(2)
|
||||
ev = CommonPolicyEvaluator(
|
||||
ev = PolicyEvaluator(
|
||||
env_creator=lambda _: RoundRobinMultiAgent(5, increment_obs=True),
|
||||
policy_graph={
|
||||
"p0": (MockPolicyGraph, obs_space, act_space, {}),
|
||||
@@ -283,13 +283,13 @@ class TestMultiAgentEnv(unittest.TestCase):
|
||||
"p1": (PGPolicyGraph, obs_space, act_space, {}),
|
||||
"p2": (DQNPolicyGraph, obs_space, act_space, dqn_config),
|
||||
}
|
||||
ev = CommonPolicyEvaluator(
|
||||
ev = PolicyEvaluator(
|
||||
env_creator=lambda _: MultiCartpole(n),
|
||||
policy_graph=policies,
|
||||
policy_mapping_fn=lambda agent_id: ["p1", "p2"][agent_id % 2],
|
||||
batch_steps=50)
|
||||
if optimizer_cls == AsyncGradientsOptimizer:
|
||||
remote_evs = [CommonPolicyEvaluator.as_remote().remote(
|
||||
remote_evs = [PolicyEvaluator.as_remote().remote(
|
||||
env_creator=lambda _: MultiCartpole(n),
|
||||
policy_graph=policies,
|
||||
policy_mapping_fn=lambda agent_id: ["p1", "p2"][agent_id % 2],
|
||||
@@ -333,7 +333,7 @@ class TestMultiAgentEnv(unittest.TestCase):
|
||||
policies["pg_{}".format(i)] = (
|
||||
PGPolicyGraph, obs_space, act_space, {})
|
||||
policy_ids = list(policies.keys())
|
||||
ev = CommonPolicyEvaluator(
|
||||
ev = PolicyEvaluator(
|
||||
env_creator=lambda _: MultiCartpole(n),
|
||||
policy_graph=policies,
|
||||
policy_mapping_fn=lambda agent_id: random.choice(policy_ids),
|
||||
|
||||
+17
-17
@@ -8,7 +8,7 @@ import unittest
|
||||
|
||||
import ray
|
||||
from ray.rllib.agents.pg import PGAgent
|
||||
from ray.rllib.evaluation.common_policy_evaluator import CommonPolicyEvaluator
|
||||
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator
|
||||
from ray.rllib.evaluation.metrics import collect_metrics
|
||||
from ray.rllib.evaluation.policy_graph import PolicyGraph
|
||||
from ray.rllib.evaluation.postprocessing import compute_advantages
|
||||
@@ -89,9 +89,9 @@ class MockVectorEnv(VectorEnv):
|
||||
return obs_batch, rew_batch, done_batch, info_batch
|
||||
|
||||
|
||||
class TestCommonPolicyEvaluator(unittest.TestCase):
|
||||
class TestPolicyEvaluator(unittest.TestCase):
|
||||
def testBasic(self):
|
||||
ev = CommonPolicyEvaluator(
|
||||
ev = PolicyEvaluator(
|
||||
env_creator=lambda _: gym.make("CartPole-v0"),
|
||||
policy_graph=MockPolicyGraph)
|
||||
batch = ev.sample()
|
||||
@@ -110,10 +110,10 @@ class TestCommonPolicyEvaluator(unittest.TestCase):
|
||||
self.assertEqual(results2, [(0, 5), (1, 5), (2, 5)])
|
||||
|
||||
def testMetrics(self):
|
||||
ev = CommonPolicyEvaluator(
|
||||
ev = PolicyEvaluator(
|
||||
env_creator=lambda _: MockEnv(episode_length=10),
|
||||
policy_graph=MockPolicyGraph, batch_mode="complete_episodes")
|
||||
remote_ev = CommonPolicyEvaluator.as_remote().remote(
|
||||
remote_ev = PolicyEvaluator.as_remote().remote(
|
||||
env_creator=lambda _: MockEnv(episode_length=10),
|
||||
policy_graph=MockPolicyGraph, batch_mode="complete_episodes")
|
||||
ev.sample()
|
||||
@@ -123,7 +123,7 @@ class TestCommonPolicyEvaluator(unittest.TestCase):
|
||||
self.assertEqual(result.episode_reward_mean, 10)
|
||||
|
||||
def testAsync(self):
|
||||
ev = CommonPolicyEvaluator(
|
||||
ev = PolicyEvaluator(
|
||||
env_creator=lambda _: gym.make("CartPole-v0"),
|
||||
sample_async=True,
|
||||
policy_graph=MockPolicyGraph)
|
||||
@@ -133,7 +133,7 @@ class TestCommonPolicyEvaluator(unittest.TestCase):
|
||||
self.assertGreater(batch["advantages"][0], 1)
|
||||
|
||||
def testAutoConcat(self):
|
||||
ev = CommonPolicyEvaluator(
|
||||
ev = PolicyEvaluator(
|
||||
env_creator=lambda _: MockEnv(episode_length=40),
|
||||
policy_graph=MockPolicyGraph,
|
||||
sample_async=True,
|
||||
@@ -145,7 +145,7 @@ class TestCommonPolicyEvaluator(unittest.TestCase):
|
||||
self.assertEqual(batch.count, 40) # auto-concat up to 5 episodes
|
||||
|
||||
def testAutoVectorization(self):
|
||||
ev = CommonPolicyEvaluator(
|
||||
ev = PolicyEvaluator(
|
||||
env_creator=lambda _: MockEnv(episode_length=20),
|
||||
policy_graph=MockPolicyGraph,
|
||||
batch_mode="truncate_episodes",
|
||||
@@ -164,14 +164,14 @@ class TestCommonPolicyEvaluator(unittest.TestCase):
|
||||
def testBatchDivisibilityCheck(self):
|
||||
self.assertRaises(
|
||||
ValueError,
|
||||
lambda: CommonPolicyEvaluator(
|
||||
lambda: PolicyEvaluator(
|
||||
env_creator=lambda _: MockEnv(episode_length=8),
|
||||
policy_graph=MockPolicyGraph,
|
||||
batch_mode="truncate_episodes",
|
||||
batch_steps=15, num_envs=4))
|
||||
|
||||
def testBatchesSmallerWhenVectorized(self):
|
||||
ev = CommonPolicyEvaluator(
|
||||
ev = PolicyEvaluator(
|
||||
env_creator=lambda _: MockEnv(episode_length=8),
|
||||
policy_graph=MockPolicyGraph,
|
||||
batch_mode="truncate_episodes",
|
||||
@@ -185,7 +185,7 @@ class TestCommonPolicyEvaluator(unittest.TestCase):
|
||||
self.assertEqual(result.episodes_total, 4)
|
||||
|
||||
def testVectorEnvSupport(self):
|
||||
ev = CommonPolicyEvaluator(
|
||||
ev = PolicyEvaluator(
|
||||
env_creator=lambda _: MockVectorEnv(
|
||||
episode_length=20, num_envs=8),
|
||||
policy_graph=MockPolicyGraph,
|
||||
@@ -203,7 +203,7 @@ class TestCommonPolicyEvaluator(unittest.TestCase):
|
||||
self.assertEqual(result.episodes_total, 8)
|
||||
|
||||
def testTruncateEpisodes(self):
|
||||
ev = CommonPolicyEvaluator(
|
||||
ev = PolicyEvaluator(
|
||||
env_creator=lambda _: MockEnv(10),
|
||||
policy_graph=MockPolicyGraph,
|
||||
batch_steps=15,
|
||||
@@ -212,7 +212,7 @@ class TestCommonPolicyEvaluator(unittest.TestCase):
|
||||
self.assertEqual(batch.count, 15)
|
||||
|
||||
def testCompleteEpisodes(self):
|
||||
ev = CommonPolicyEvaluator(
|
||||
ev = PolicyEvaluator(
|
||||
env_creator=lambda _: MockEnv(10),
|
||||
policy_graph=MockPolicyGraph,
|
||||
batch_steps=5,
|
||||
@@ -221,7 +221,7 @@ class TestCommonPolicyEvaluator(unittest.TestCase):
|
||||
self.assertEqual(batch.count, 10)
|
||||
|
||||
def testCompleteEpisodesPacking(self):
|
||||
ev = CommonPolicyEvaluator(
|
||||
ev = PolicyEvaluator(
|
||||
env_creator=lambda _: MockEnv(10),
|
||||
policy_graph=MockPolicyGraph,
|
||||
batch_steps=15,
|
||||
@@ -233,7 +233,7 @@ class TestCommonPolicyEvaluator(unittest.TestCase):
|
||||
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
|
||||
|
||||
def testFilterSync(self):
|
||||
ev = CommonPolicyEvaluator(
|
||||
ev = PolicyEvaluator(
|
||||
env_creator=lambda _: gym.make("CartPole-v0"),
|
||||
policy_graph=MockPolicyGraph,
|
||||
sample_async=True,
|
||||
@@ -246,7 +246,7 @@ class TestCommonPolicyEvaluator(unittest.TestCase):
|
||||
self.assertNotEqual(obs_f.buffer.n, 0)
|
||||
|
||||
def testGetFilters(self):
|
||||
ev = CommonPolicyEvaluator(
|
||||
ev = PolicyEvaluator(
|
||||
env_creator=lambda _: gym.make("CartPole-v0"),
|
||||
policy_graph=MockPolicyGraph,
|
||||
sample_async=True,
|
||||
@@ -261,7 +261,7 @@ class TestCommonPolicyEvaluator(unittest.TestCase):
|
||||
self.assertGreaterEqual(obs_f2.buffer.n, obs_f.buffer.n)
|
||||
|
||||
def testSyncFilter(self):
|
||||
ev = CommonPolicyEvaluator(
|
||||
ev = PolicyEvaluator(
|
||||
env_creator=lambda _: gym.make("CartPole-v0"),
|
||||
policy_graph=MockPolicyGraph,
|
||||
sample_async=True,
|
||||
@@ -11,9 +11,9 @@ import uuid
|
||||
import ray
|
||||
from ray.rllib.agents.dqn import DQNAgent
|
||||
from ray.rllib.agents.pg import PGAgent
|
||||
from ray.rllib.evaluation.common_policy_evaluator import CommonPolicyEvaluator
|
||||
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator
|
||||
from ray.rllib.env.serving_env import ServingEnv
|
||||
from ray.rllib.test.test_common_policy_evaluator import BadPolicyGraph, \
|
||||
from ray.rllib.test.test_policy_evaluator import BadPolicyGraph, \
|
||||
MockPolicyGraph, MockEnv
|
||||
from ray.tune.registry import register_env
|
||||
|
||||
@@ -110,7 +110,7 @@ class MultiServing(ServingEnv):
|
||||
|
||||
class TestServingEnv(unittest.TestCase):
|
||||
def testServingEnvCompleteEpisodes(self):
|
||||
ev = CommonPolicyEvaluator(
|
||||
ev = PolicyEvaluator(
|
||||
env_creator=lambda _: SimpleServing(MockEnv(25)),
|
||||
policy_graph=MockPolicyGraph,
|
||||
batch_steps=40,
|
||||
@@ -120,7 +120,7 @@ class TestServingEnv(unittest.TestCase):
|
||||
self.assertEqual(batch.count, 50)
|
||||
|
||||
def testServingEnvTruncateEpisodes(self):
|
||||
ev = CommonPolicyEvaluator(
|
||||
ev = PolicyEvaluator(
|
||||
env_creator=lambda _: SimpleServing(MockEnv(25)),
|
||||
policy_graph=MockPolicyGraph,
|
||||
batch_steps=40,
|
||||
@@ -130,7 +130,7 @@ class TestServingEnv(unittest.TestCase):
|
||||
self.assertEqual(batch.count, 40)
|
||||
|
||||
def testServingEnvOffPolicy(self):
|
||||
ev = CommonPolicyEvaluator(
|
||||
ev = PolicyEvaluator(
|
||||
env_creator=lambda _: SimpleOffPolicyServing(MockEnv(25), 42),
|
||||
policy_graph=MockPolicyGraph,
|
||||
batch_steps=40,
|
||||
@@ -142,7 +142,7 @@ class TestServingEnv(unittest.TestCase):
|
||||
self.assertEqual(batch["actions"][-1], 42)
|
||||
|
||||
def testServingEnvBadActions(self):
|
||||
ev = CommonPolicyEvaluator(
|
||||
ev = PolicyEvaluator(
|
||||
env_creator=lambda _: SimpleServing(MockEnv(25)),
|
||||
policy_graph=BadPolicyGraph,
|
||||
sample_async=True,
|
||||
@@ -188,7 +188,7 @@ class TestServingEnv(unittest.TestCase):
|
||||
raise Exception("failed to improve reward")
|
||||
|
||||
def testServingEnvHorizonNotSupported(self):
|
||||
ev = CommonPolicyEvaluator(
|
||||
ev = PolicyEvaluator(
|
||||
env_creator=lambda _: SimpleServing(MockEnv(25)),
|
||||
policy_graph=MockPolicyGraph,
|
||||
episode_horizon=20,
|
||||
|
||||
Reference in New Issue
Block a user