mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 21:53:18 +08:00
[rllib] Rename ServingEnv => ExternalEnv (#3302)
This commit is contained in:
@@ -13,7 +13,7 @@ from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph
|
||||
from ray.rllib.env.async_vector_env import AsyncVectorEnv
|
||||
from ray.rllib.env.multi_agent_env import MultiAgentEnv
|
||||
from ray.rllib.env.vector_env import VectorEnv
|
||||
from ray.rllib.env.serving_env import ServingEnv
|
||||
from ray.rllib.env.external_env import ExternalEnv
|
||||
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator
|
||||
from ray.rllib.evaluation.sample_batch import SampleBatch
|
||||
|
||||
@@ -51,5 +51,5 @@ __all__ = [
|
||||
"AsyncVectorEnv",
|
||||
"MultiAgentEnv",
|
||||
"VectorEnv",
|
||||
"ServingEnv",
|
||||
"ExternalEnv",
|
||||
]
|
||||
|
||||
Vendored
+3
-1
@@ -1,9 +1,11 @@
|
||||
from ray.rllib.env.async_vector_env import AsyncVectorEnv
|
||||
from ray.rllib.env.multi_agent_env import MultiAgentEnv
|
||||
from ray.rllib.env.external_env import ExternalEnv
|
||||
from ray.rllib.env.serving_env import ServingEnv
|
||||
from ray.rllib.env.vector_env import VectorEnv
|
||||
from ray.rllib.env.env_context import EnvContext
|
||||
|
||||
__all__ = [
|
||||
"AsyncVectorEnv", "MultiAgentEnv", "ServingEnv", "VectorEnv", "EnvContext"
|
||||
"AsyncVectorEnv", "MultiAgentEnv", "ExternalEnv", "VectorEnv",
|
||||
"ServingEnv", "EnvContext"
|
||||
]
|
||||
|
||||
+21
-21
@@ -2,7 +2,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from ray.rllib.env.serving_env import ServingEnv
|
||||
from ray.rllib.env.external_env import ExternalEnv
|
||||
from ray.rllib.env.vector_env import VectorEnv
|
||||
from ray.rllib.env.multi_agent_env import MultiAgentEnv
|
||||
|
||||
@@ -20,7 +20,7 @@ class AsyncVectorEnv(object):
|
||||
|
||||
gym.Env => rllib.VectorEnv => rllib.AsyncVectorEnv
|
||||
rllib.MultiAgentEnv => rllib.AsyncVectorEnv
|
||||
rllib.ServingEnv => rllib.AsyncVectorEnv
|
||||
rllib.ExternalEnv => rllib.AsyncVectorEnv
|
||||
|
||||
Attributes:
|
||||
action_space (gym.Space): Action space. This must be defined for
|
||||
@@ -70,11 +70,11 @@ class AsyncVectorEnv(object):
|
||||
if isinstance(env, MultiAgentEnv):
|
||||
env = _MultiAgentEnvToAsync(
|
||||
make_env=make_env, existing_envs=[env], num_envs=num_envs)
|
||||
elif isinstance(env, ServingEnv):
|
||||
elif isinstance(env, ExternalEnv):
|
||||
if num_envs != 1:
|
||||
raise ValueError(
|
||||
"ServingEnv does not currently support num_envs > 1.")
|
||||
env = _ServingEnvToAsync(env)
|
||||
"ExternalEnv does not currently support num_envs > 1.")
|
||||
env = _ExternalEnvToAsync(env)
|
||||
elif isinstance(env, VectorEnv):
|
||||
env = _VectorEnvToAsync(env)
|
||||
else:
|
||||
@@ -145,40 +145,40 @@ def _with_dummy_agent_id(env_id_to_values, dummy_id=_DUMMY_AGENT_ID):
|
||||
return {k: {dummy_id: v} for (k, v) in env_id_to_values.items()}
|
||||
|
||||
|
||||
class _ServingEnvToAsync(AsyncVectorEnv):
|
||||
"""Internal adapter of ServingEnv to AsyncVectorEnv."""
|
||||
class _ExternalEnvToAsync(AsyncVectorEnv):
|
||||
"""Internal adapter of ExternalEnv to AsyncVectorEnv."""
|
||||
|
||||
def __init__(self, serving_env, preprocessor=None):
|
||||
self.serving_env = serving_env
|
||||
def __init__(self, external_env, preprocessor=None):
|
||||
self.external_env = external_env
|
||||
self.prep = preprocessor
|
||||
self.action_space = serving_env.action_space
|
||||
self.action_space = external_env.action_space
|
||||
if preprocessor:
|
||||
self.observation_space = preprocessor.observation_space
|
||||
else:
|
||||
self.observation_space = serving_env.observation_space
|
||||
serving_env.start()
|
||||
self.observation_space = external_env.observation_space
|
||||
external_env.start()
|
||||
|
||||
def poll(self):
|
||||
with self.serving_env._results_avail_condition:
|
||||
with self.external_env._results_avail_condition:
|
||||
results = self._poll()
|
||||
while len(results[0]) == 0:
|
||||
self.serving_env._results_avail_condition.wait()
|
||||
self.external_env._results_avail_condition.wait()
|
||||
results = self._poll()
|
||||
if not self.serving_env.isAlive():
|
||||
if not self.external_env.isAlive():
|
||||
raise Exception("Serving thread has stopped.")
|
||||
limit = self.serving_env._max_concurrent_episodes
|
||||
limit = self.external_env._max_concurrent_episodes
|
||||
assert len(results[0]) < limit, \
|
||||
("Too many concurrent episodes, were some leaked? This ServingEnv "
|
||||
"was created with max_concurrent={}".format(limit))
|
||||
("Too many concurrent episodes, were some leaked? This "
|
||||
"ExternalEnv was created with max_concurrent={}".format(limit))
|
||||
return results
|
||||
|
||||
def _poll(self):
|
||||
all_obs, all_rewards, all_dones, all_infos = {}, {}, {}, {}
|
||||
off_policy_actions = {}
|
||||
for eid, episode in self.serving_env._episodes.copy().items():
|
||||
for eid, episode in self.external_env._episodes.copy().items():
|
||||
data = episode.get_data()
|
||||
if episode.cur_done:
|
||||
del self.serving_env._episodes[eid]
|
||||
del self.external_env._episodes[eid]
|
||||
if data:
|
||||
if self.prep:
|
||||
all_obs[eid] = self.prep.transform(data["obs"])
|
||||
@@ -197,7 +197,7 @@ class _ServingEnvToAsync(AsyncVectorEnv):
|
||||
|
||||
def send_actions(self, action_dict):
|
||||
for eid, action in action_dict.items():
|
||||
self.serving_env._episodes[eid].action_queue.put(
|
||||
self.external_env._episodes[eid].action_queue.put(
|
||||
action[_DUMMY_AGENT_ID])
|
||||
|
||||
|
||||
|
||||
Vendored
+226
@@ -0,0 +1,226 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from six.moves import queue
|
||||
import threading
|
||||
import uuid
|
||||
|
||||
|
||||
class ExternalEnv(threading.Thread):
|
||||
"""An environment that interfaces with external agents.
|
||||
|
||||
Unlike simulator envs, control is inverted. The environment queries the
|
||||
policy to obtain actions and logs observations and rewards for training.
|
||||
This is in contrast to gym.Env, where the algorithm drives the simulation
|
||||
through env.step() calls.
|
||||
|
||||
You can use ExternalEnv as the backend for policy serving (by serving HTTP
|
||||
requests in the run loop), for ingesting offline logs data (by reading
|
||||
offline transitions in the run loop), or other custom use cases not easily
|
||||
expressed through gym.Env.
|
||||
|
||||
ExternalEnv supports both on-policy actions (through self.get_action()),
|
||||
and off-policy actions (through self.log_action()).
|
||||
|
||||
This env is thread-safe, but individual episodes must be executed serially.
|
||||
|
||||
Attributes:
|
||||
action_space (gym.Space): Action space.
|
||||
observation_space (gym.Space): Observation space.
|
||||
|
||||
Examples:
|
||||
>>> register_env("my_env", lambda config: YourExternalEnv(config))
|
||||
>>> agent = DQNAgent(env="my_env")
|
||||
>>> while True:
|
||||
print(agent.train())
|
||||
"""
|
||||
|
||||
def __init__(self, action_space, observation_space, max_concurrent=100):
|
||||
"""Initialize an external env.
|
||||
|
||||
ExternalEnv subclasses must call this during their __init__.
|
||||
|
||||
Arguments:
|
||||
action_space (gym.Space): Action space of the env.
|
||||
observation_space (gym.Space): Observation space of the env.
|
||||
max_concurrent (int): Max number of active episodes to allow at
|
||||
once. Exceeding this limit raises an error.
|
||||
"""
|
||||
|
||||
threading.Thread.__init__(self)
|
||||
self.daemon = True
|
||||
self.action_space = action_space
|
||||
self.observation_space = observation_space
|
||||
self._episodes = {}
|
||||
self._finished = set()
|
||||
self._results_avail_condition = threading.Condition()
|
||||
self._max_concurrent_episodes = max_concurrent
|
||||
|
||||
def run(self):
|
||||
"""Override this to implement the run loop.
|
||||
|
||||
Your loop should continuously:
|
||||
1. Call self.start_episode(episode_id)
|
||||
2. Call self.get_action(episode_id, obs)
|
||||
-or-
|
||||
self.log_action(episode_id, obs, action)
|
||||
3. Call self.log_returns(episode_id, reward)
|
||||
4. Call self.end_episode(episode_id, obs)
|
||||
5. Wait if nothing to do.
|
||||
|
||||
Multiple episodes may be started at the same time.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def start_episode(self, episode_id=None, training_enabled=True):
|
||||
"""Record the start of an episode.
|
||||
|
||||
Arguments:
|
||||
episode_id (str): Unique string id for the episode or None for
|
||||
it to be auto-assigned.
|
||||
training_enabled (bool): Whether to use experiences for this
|
||||
episode to improve the policy.
|
||||
|
||||
Returns:
|
||||
episode_id (str): Unique string id for the episode.
|
||||
"""
|
||||
|
||||
if episode_id is None:
|
||||
episode_id = uuid.uuid4().hex
|
||||
|
||||
if episode_id in self._finished:
|
||||
raise ValueError(
|
||||
"Episode {} has already completed.".format(episode_id))
|
||||
|
||||
if episode_id in self._episodes:
|
||||
raise ValueError(
|
||||
"Episode {} is already started".format(episode_id))
|
||||
|
||||
self._episodes[episode_id] = _ExternalEnvEpisode(
|
||||
episode_id, self._results_avail_condition, training_enabled)
|
||||
|
||||
return episode_id
|
||||
|
||||
def get_action(self, episode_id, observation):
|
||||
"""Record an observation and get the on-policy action.
|
||||
|
||||
Arguments:
|
||||
episode_id (str): Episode id returned from start_episode().
|
||||
observation (obj): Current environment observation.
|
||||
|
||||
Returns:
|
||||
action (obj): Action from the env action space.
|
||||
"""
|
||||
|
||||
episode = self._get(episode_id)
|
||||
return episode.wait_for_action(observation)
|
||||
|
||||
def log_action(self, episode_id, observation, action):
|
||||
"""Record an observation and (off-policy) action taken.
|
||||
|
||||
Arguments:
|
||||
episode_id (str): Episode id returned from start_episode().
|
||||
observation (obj): Current environment observation.
|
||||
action (obj): Action for the observation.
|
||||
"""
|
||||
|
||||
episode = self._get(episode_id)
|
||||
episode.log_action(observation, action)
|
||||
|
||||
def log_returns(self, episode_id, reward, info=None):
|
||||
"""Record returns from the environment.
|
||||
|
||||
The reward will be attributed to the previous action taken by the
|
||||
episode. Rewards accumulate until the next action. If no reward is
|
||||
logged before the next action, a reward of 0.0 is assumed.
|
||||
|
||||
Arguments:
|
||||
episode_id (str): Episode id returned from start_episode().
|
||||
reward (float): Reward from the environment.
|
||||
info (dict): Optional info dict.
|
||||
"""
|
||||
|
||||
episode = self._get(episode_id)
|
||||
episode.cur_reward += reward
|
||||
if info:
|
||||
episode.cur_info = info or {}
|
||||
|
||||
def end_episode(self, episode_id, observation):
|
||||
"""Record the end of an episode.
|
||||
|
||||
Arguments:
|
||||
episode_id (str): Episode id returned from start_episode().
|
||||
observation (obj): Current environment observation.
|
||||
"""
|
||||
|
||||
episode = self._get(episode_id)
|
||||
self._finished.add(episode.episode_id)
|
||||
episode.done(observation)
|
||||
|
||||
def _get(self, episode_id):
|
||||
"""Get a started episode or raise an error."""
|
||||
|
||||
if episode_id in self._finished:
|
||||
raise ValueError(
|
||||
"Episode {} has already completed.".format(episode_id))
|
||||
|
||||
if episode_id not in self._episodes:
|
||||
raise ValueError("Episode {} not found.".format(episode_id))
|
||||
|
||||
return self._episodes[episode_id]
|
||||
|
||||
|
||||
class _ExternalEnvEpisode(object):
|
||||
"""Tracked state for each active episode."""
|
||||
|
||||
def __init__(self, episode_id, results_avail_condition, training_enabled):
|
||||
self.episode_id = episode_id
|
||||
self.results_avail_condition = results_avail_condition
|
||||
self.training_enabled = training_enabled
|
||||
self.data_queue = queue.Queue()
|
||||
self.action_queue = queue.Queue()
|
||||
self.new_observation = None
|
||||
self.new_action = None
|
||||
self.cur_reward = 0.0
|
||||
self.cur_done = False
|
||||
self.cur_info = {}
|
||||
|
||||
def get_data(self):
|
||||
if self.data_queue.empty():
|
||||
return None
|
||||
return self.data_queue.get_nowait()
|
||||
|
||||
def log_action(self, observation, action):
|
||||
self.new_observation = observation
|
||||
self.new_action = action
|
||||
self._send()
|
||||
self.action_queue.get(True, timeout=60.0)
|
||||
|
||||
def wait_for_action(self, observation):
|
||||
self.new_observation = observation
|
||||
self._send()
|
||||
return self.action_queue.get(True, timeout=60.0)
|
||||
|
||||
def done(self, observation):
|
||||
self.new_observation = observation
|
||||
self.cur_done = True
|
||||
self._send()
|
||||
|
||||
def _send(self):
|
||||
item = {
|
||||
"obs": self.new_observation,
|
||||
"reward": self.cur_reward,
|
||||
"done": self.cur_done,
|
||||
"info": self.cur_info,
|
||||
}
|
||||
if self.new_action is not None:
|
||||
item["off_policy_action"] = self.new_action
|
||||
if not self.training_enabled:
|
||||
item["info"]["training_enabled"] = False
|
||||
self.new_observation = None
|
||||
self.new_action = None
|
||||
self.cur_reward = 0.0
|
||||
with self.results_avail_condition:
|
||||
self.data_queue.put_nowait(item)
|
||||
self.results_avail_condition.notify()
|
||||
Vendored
+3
-221
@@ -2,225 +2,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from six.moves import queue
|
||||
import threading
|
||||
import uuid
|
||||
from ray.rllib.env.external_env import ExternalEnv
|
||||
|
||||
|
||||
class ServingEnv(threading.Thread):
|
||||
"""An environment that provides policy serving.
|
||||
|
||||
Unlike simulator envs, control is inverted. The environment queries the
|
||||
policy to obtain actions and logs observations and rewards for training.
|
||||
This is in contrast to gym.Env, where the algorithm drives the simulation
|
||||
through env.step() calls.
|
||||
|
||||
You can use ServingEnv as the backend for policy serving (by serving HTTP
|
||||
requests in the run loop), for ingesting offline logs data (by reading
|
||||
offline transitions in the run loop), or other custom use cases not easily
|
||||
expressed through gym.Env.
|
||||
|
||||
ServingEnv supports both on-policy serving (through self.get_action()), and
|
||||
off-policy serving (through self.log_action()).
|
||||
|
||||
This env is thread-safe, but individual episodes must be executed serially.
|
||||
|
||||
Attributes:
|
||||
action_space (gym.Space): Action space.
|
||||
observation_space (gym.Space): Observation space.
|
||||
|
||||
Examples:
|
||||
>>> register_env("my_env", lambda config: YourServingEnv(config))
|
||||
>>> agent = DQNAgent(env="my_env")
|
||||
>>> while True:
|
||||
print(agent.train())
|
||||
"""
|
||||
|
||||
def __init__(self, action_space, observation_space, max_concurrent=100):
|
||||
"""Initialize a serving env.
|
||||
|
||||
ServingEnv subclasses must call this during their __init__.
|
||||
|
||||
Arguments:
|
||||
action_space (gym.Space): Action space of the env.
|
||||
observation_space (gym.Space): Observation space of the env.
|
||||
max_concurrent (int): Max number of active episodes to allow at
|
||||
once. Exceeding this limit raises an error.
|
||||
"""
|
||||
|
||||
threading.Thread.__init__(self)
|
||||
self.daemon = True
|
||||
self.action_space = action_space
|
||||
self.observation_space = observation_space
|
||||
self._episodes = {}
|
||||
self._finished = set()
|
||||
self._results_avail_condition = threading.Condition()
|
||||
self._max_concurrent_episodes = max_concurrent
|
||||
|
||||
def run(self):
|
||||
"""Override this to implement the run loop.
|
||||
|
||||
Your loop should continuously:
|
||||
1. Call self.start_episode(episode_id)
|
||||
2. Call self.get_action(episode_id, obs)
|
||||
-or-
|
||||
self.log_action(episode_id, obs, action)
|
||||
3. Call self.log_returns(episode_id, reward)
|
||||
4. Call self.end_episode(episode_id, obs)
|
||||
5. Wait if nothing to do.
|
||||
|
||||
Multiple episodes may be started at the same time.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def start_episode(self, episode_id=None, training_enabled=True):
|
||||
"""Record the start of an episode.
|
||||
|
||||
Arguments:
|
||||
episode_id (str): Unique string id for the episode or None for
|
||||
it to be auto-assigned.
|
||||
training_enabled (bool): Whether to use experiences for this
|
||||
episode to improve the policy.
|
||||
|
||||
Returns:
|
||||
episode_id (str): Unique string id for the episode.
|
||||
"""
|
||||
|
||||
if episode_id is None:
|
||||
episode_id = uuid.uuid4().hex
|
||||
|
||||
if episode_id in self._finished:
|
||||
raise ValueError(
|
||||
"Episode {} has already completed.".format(episode_id))
|
||||
|
||||
if episode_id in self._episodes:
|
||||
raise ValueError(
|
||||
"Episode {} is already started".format(episode_id))
|
||||
|
||||
self._episodes[episode_id] = _ServingEnvEpisode(
|
||||
episode_id, self._results_avail_condition, training_enabled)
|
||||
|
||||
return episode_id
|
||||
|
||||
def get_action(self, episode_id, observation):
|
||||
"""Record an observation and get the on-policy action.
|
||||
|
||||
Arguments:
|
||||
episode_id (str): Episode id returned from start_episode().
|
||||
observation (obj): Current environment observation.
|
||||
|
||||
Returns:
|
||||
action (obj): Action from the env action space.
|
||||
"""
|
||||
|
||||
episode = self._get(episode_id)
|
||||
return episode.wait_for_action(observation)
|
||||
|
||||
def log_action(self, episode_id, observation, action):
|
||||
"""Record an observation and (off-policy) action taken.
|
||||
|
||||
Arguments:
|
||||
episode_id (str): Episode id returned from start_episode().
|
||||
observation (obj): Current environment observation.
|
||||
action (obj): Action for the observation.
|
||||
"""
|
||||
|
||||
episode = self._get(episode_id)
|
||||
episode.log_action(observation, action)
|
||||
|
||||
def log_returns(self, episode_id, reward, info=None):
|
||||
"""Record returns from the environment.
|
||||
|
||||
The reward will be attributed to the previous action taken by the
|
||||
episode. Rewards accumulate until the next action. If no reward is
|
||||
logged before the next action, a reward of 0.0 is assumed.
|
||||
|
||||
Arguments:
|
||||
episode_id (str): Episode id returned from start_episode().
|
||||
reward (float): Reward from the environment.
|
||||
info (dict): Optional info dict.
|
||||
"""
|
||||
|
||||
episode = self._get(episode_id)
|
||||
episode.cur_reward += reward
|
||||
if info:
|
||||
episode.cur_info = info or {}
|
||||
|
||||
def end_episode(self, episode_id, observation):
|
||||
"""Record the end of an episode.
|
||||
|
||||
Arguments:
|
||||
episode_id (str): Episode id returned from start_episode().
|
||||
observation (obj): Current environment observation.
|
||||
"""
|
||||
|
||||
episode = self._get(episode_id)
|
||||
self._finished.add(episode.episode_id)
|
||||
episode.done(observation)
|
||||
|
||||
def _get(self, episode_id):
|
||||
"""Get a started episode or raise an error."""
|
||||
|
||||
if episode_id in self._finished:
|
||||
raise ValueError(
|
||||
"Episode {} has already completed.".format(episode_id))
|
||||
|
||||
if episode_id not in self._episodes:
|
||||
raise ValueError("Episode {} not found.".format(episode_id))
|
||||
|
||||
return self._episodes[episode_id]
|
||||
|
||||
|
||||
class _ServingEnvEpisode(object):
|
||||
"""Tracked state for each active episode."""
|
||||
|
||||
def __init__(self, episode_id, results_avail_condition, training_enabled):
|
||||
self.episode_id = episode_id
|
||||
self.results_avail_condition = results_avail_condition
|
||||
self.training_enabled = training_enabled
|
||||
self.data_queue = queue.Queue()
|
||||
self.action_queue = queue.Queue()
|
||||
self.new_observation = None
|
||||
self.new_action = None
|
||||
self.cur_reward = 0.0
|
||||
self.cur_done = False
|
||||
self.cur_info = {}
|
||||
|
||||
def get_data(self):
|
||||
if self.data_queue.empty():
|
||||
return None
|
||||
return self.data_queue.get_nowait()
|
||||
|
||||
def log_action(self, observation, action):
|
||||
self.new_observation = observation
|
||||
self.new_action = action
|
||||
self._send()
|
||||
self.action_queue.get(True, timeout=60.0)
|
||||
|
||||
def wait_for_action(self, observation):
|
||||
self.new_observation = observation
|
||||
self._send()
|
||||
return self.action_queue.get(True, timeout=60.0)
|
||||
|
||||
def done(self, observation):
|
||||
self.new_observation = observation
|
||||
self.cur_done = True
|
||||
self._send()
|
||||
|
||||
def _send(self):
|
||||
item = {
|
||||
"obs": self.new_observation,
|
||||
"reward": self.cur_reward,
|
||||
"done": self.cur_done,
|
||||
"info": self.cur_info,
|
||||
}
|
||||
if self.new_action is not None:
|
||||
item["off_policy_action"] = self.new_action
|
||||
if not self.training_enabled:
|
||||
item["info"]["training_enabled"] = False
|
||||
self.new_observation = None
|
||||
self.new_action = None
|
||||
self.cur_reward = 0.0
|
||||
with self.results_avail_condition:
|
||||
self.data_queue.put_nowait(item)
|
||||
self.results_avail_condition.notify()
|
||||
# renamed to ExternalEnv in 0.6
|
||||
ServingEnv = ExternalEnv
|
||||
|
||||
@@ -29,7 +29,7 @@ parser.add_argument(
|
||||
if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
env = gym.make("CartPole-v0")
|
||||
client = PolicyClient("http://localhost:8900")
|
||||
client = PolicyClient("http://localhost:9900")
|
||||
|
||||
eid = client.start_episode(training_enabled=not args.no_train)
|
||||
obs = env.reset()
|
||||
|
||||
@@ -14,19 +14,19 @@ import numpy as np
|
||||
|
||||
import ray
|
||||
from ray.rllib.agents.dqn import DQNAgent
|
||||
from ray.rllib.env.serving_env import ServingEnv
|
||||
from ray.rllib.env.external_env import ExternalEnv
|
||||
from ray.rllib.utils.policy_server import PolicyServer
|
||||
from ray.tune.logger import pretty_print
|
||||
from ray.tune.registry import register_env
|
||||
|
||||
SERVER_ADDRESS = "localhost"
|
||||
SERVER_PORT = 8900
|
||||
SERVER_PORT = 9900
|
||||
CHECKPOINT_FILE = "last_checkpoint.out"
|
||||
|
||||
|
||||
class CartpoleServing(ServingEnv):
|
||||
class CartpoleServing(ExternalEnv):
|
||||
def __init__(self):
|
||||
ServingEnv.__init__(
|
||||
ExternalEnv.__init__(
|
||||
self, spaces.Discrete(2),
|
||||
spaces.Box(low=-10, high=10, shape=(4, ), dtype=np.float32))
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ pkill -f cartpole_server.py
|
||||
(python cartpole_server.py 2>&1 | grep -v 200) &
|
||||
pid=$!
|
||||
|
||||
while ! curl localhost:8900; do
|
||||
while ! curl localhost:9900; do
|
||||
sleep 1
|
||||
done
|
||||
|
||||
|
||||
@@ -11,8 +11,8 @@ from functools import partial
|
||||
from ray.tune.registry import RLLIB_MODEL, RLLIB_PREPROCESSOR, \
|
||||
_global_registry
|
||||
|
||||
from ray.rllib.env.async_vector_env import _ServingEnvToAsync
|
||||
from ray.rllib.env.serving_env import ServingEnv
|
||||
from ray.rllib.env.async_vector_env import _ExternalEnvToAsync
|
||||
from ray.rllib.env.external_env import ExternalEnv
|
||||
from ray.rllib.env.vector_env import VectorEnv
|
||||
from ray.rllib.models.action_dist import (
|
||||
Categorical, Deterministic, DiagGaussian, MultiActionDistribution,
|
||||
@@ -270,7 +270,7 @@ class ModelCatalog(object):
|
||||
"""Returns a suitable processor for the given environment.
|
||||
|
||||
Args:
|
||||
env (gym.Env|VectorEnv|ServingEnv): The environment to wrap.
|
||||
env (gym.Env|VectorEnv|ExternalEnv): The environment to wrap.
|
||||
options (dict): Options to pass to the preprocessor.
|
||||
|
||||
Returns:
|
||||
@@ -300,7 +300,7 @@ class ModelCatalog(object):
|
||||
"""Returns a preprocessor as a gym observation wrapper.
|
||||
|
||||
Args:
|
||||
env (gym.Env|VectorEnv|ServingEnv): The environment to wrap.
|
||||
env (gym.Env|VectorEnv|ExternalEnv): The environment to wrap.
|
||||
options (dict): Options to pass to the preprocessor.
|
||||
|
||||
Returns:
|
||||
@@ -313,8 +313,8 @@ class ModelCatalog(object):
|
||||
return _RLlibPreprocessorWrapper(env, preprocessor)
|
||||
elif isinstance(env, VectorEnv):
|
||||
return _RLlibVectorPreprocessorWrapper(env, preprocessor)
|
||||
elif isinstance(env, ServingEnv):
|
||||
return _ServingEnvToAsync(env, preprocessor)
|
||||
elif isinstance(env, ExternalEnv):
|
||||
return _ExternalEnvToAsync(env, preprocessor)
|
||||
else:
|
||||
raise ValueError("Don't know how to wrap {}".format(env))
|
||||
|
||||
|
||||
+16
-16
@@ -12,15 +12,15 @@ import ray
|
||||
from ray.rllib.agents.dqn import DQNAgent
|
||||
from ray.rllib.agents.pg import PGAgent
|
||||
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator
|
||||
from ray.rllib.env.serving_env import ServingEnv
|
||||
from ray.rllib.env.external_env import ExternalEnv
|
||||
from ray.rllib.test.test_policy_evaluator import BadPolicyGraph, \
|
||||
MockPolicyGraph, MockEnv
|
||||
from ray.tune.registry import register_env
|
||||
|
||||
|
||||
class SimpleServing(ServingEnv):
|
||||
class SimpleServing(ExternalEnv):
|
||||
def __init__(self, env):
|
||||
ServingEnv.__init__(self, env.action_space, env.observation_space)
|
||||
ExternalEnv.__init__(self, env.action_space, env.observation_space)
|
||||
self.env = env
|
||||
|
||||
def run(self):
|
||||
@@ -36,9 +36,9 @@ class SimpleServing(ServingEnv):
|
||||
eid = self.start_episode()
|
||||
|
||||
|
||||
class PartOffPolicyServing(ServingEnv):
|
||||
class PartOffPolicyServing(ExternalEnv):
|
||||
def __init__(self, env, off_pol_frac):
|
||||
ServingEnv.__init__(self, env.action_space, env.observation_space)
|
||||
ExternalEnv.__init__(self, env.action_space, env.observation_space)
|
||||
self.env = env
|
||||
self.off_pol_frac = off_pol_frac
|
||||
|
||||
@@ -59,9 +59,9 @@ class PartOffPolicyServing(ServingEnv):
|
||||
eid = self.start_episode()
|
||||
|
||||
|
||||
class SimpleOffPolicyServing(ServingEnv):
|
||||
class SimpleOffPolicyServing(ExternalEnv):
|
||||
def __init__(self, env, fixed_action):
|
||||
ServingEnv.__init__(self, env.action_space, env.observation_space)
|
||||
ExternalEnv.__init__(self, env.action_space, env.observation_space)
|
||||
self.env = env
|
||||
self.fixed_action = fixed_action
|
||||
|
||||
@@ -79,12 +79,12 @@ class SimpleOffPolicyServing(ServingEnv):
|
||||
eid = self.start_episode()
|
||||
|
||||
|
||||
class MultiServing(ServingEnv):
|
||||
class MultiServing(ExternalEnv):
|
||||
def __init__(self, env_creator):
|
||||
self.env_creator = env_creator
|
||||
self.env = env_creator()
|
||||
ServingEnv.__init__(self, self.env.action_space,
|
||||
self.env.observation_space)
|
||||
ExternalEnv.__init__(self, self.env.action_space,
|
||||
self.env.observation_space)
|
||||
|
||||
def run(self):
|
||||
envs = [self.env_creator() for _ in range(5)]
|
||||
@@ -107,8 +107,8 @@ class MultiServing(ServingEnv):
|
||||
del cur_obs[i]
|
||||
|
||||
|
||||
class TestServingEnv(unittest.TestCase):
|
||||
def testServingEnvCompleteEpisodes(self):
|
||||
class TestExternalEnv(unittest.TestCase):
|
||||
def testExternalEnvCompleteEpisodes(self):
|
||||
ev = PolicyEvaluator(
|
||||
env_creator=lambda _: SimpleServing(MockEnv(25)),
|
||||
policy_graph=MockPolicyGraph,
|
||||
@@ -118,7 +118,7 @@ class TestServingEnv(unittest.TestCase):
|
||||
batch = ev.sample()
|
||||
self.assertEqual(batch.count, 50)
|
||||
|
||||
def testServingEnvTruncateEpisodes(self):
|
||||
def testExternalEnvTruncateEpisodes(self):
|
||||
ev = PolicyEvaluator(
|
||||
env_creator=lambda _: SimpleServing(MockEnv(25)),
|
||||
policy_graph=MockPolicyGraph,
|
||||
@@ -128,7 +128,7 @@ class TestServingEnv(unittest.TestCase):
|
||||
batch = ev.sample()
|
||||
self.assertEqual(batch.count, 40)
|
||||
|
||||
def testServingEnvOffPolicy(self):
|
||||
def testExternalEnvOffPolicy(self):
|
||||
ev = PolicyEvaluator(
|
||||
env_creator=lambda _: SimpleOffPolicyServing(MockEnv(25), 42),
|
||||
policy_graph=MockPolicyGraph,
|
||||
@@ -140,7 +140,7 @@ class TestServingEnv(unittest.TestCase):
|
||||
self.assertEqual(batch["actions"][0], 42)
|
||||
self.assertEqual(batch["actions"][-1], 42)
|
||||
|
||||
def testServingEnvBadActions(self):
|
||||
def testExternalEnvBadActions(self):
|
||||
ev = PolicyEvaluator(
|
||||
env_creator=lambda _: SimpleServing(MockEnv(25)),
|
||||
policy_graph=BadPolicyGraph,
|
||||
@@ -185,7 +185,7 @@ class TestServingEnv(unittest.TestCase):
|
||||
return
|
||||
raise Exception("failed to improve reward")
|
||||
|
||||
def testServingEnvHorizonNotSupported(self):
|
||||
def testExternalEnvHorizonNotSupported(self):
|
||||
ev = PolicyEvaluator(
|
||||
env_creator=lambda _: SimpleServing(MockEnv(25)),
|
||||
policy_graph=MockPolicyGraph,
|
||||
@@ -17,7 +17,7 @@ from ray.rllib.env.async_vector_env import AsyncVectorEnv
|
||||
from ray.rllib.env.vector_env import VectorEnv
|
||||
from ray.rllib.models import ModelCatalog
|
||||
from ray.rllib.models.model import Model
|
||||
from ray.rllib.test.test_serving_env import SimpleServing
|
||||
from ray.rllib.test.test_external_env import SimpleServing
|
||||
from ray.tune.registry import register_env
|
||||
|
||||
DICT_SPACE = spaces.Dict({
|
||||
|
||||
@@ -18,15 +18,15 @@ elif sys.version_info[0] == 3:
|
||||
|
||||
|
||||
class PolicyServer(ThreadingMixIn, HTTPServer):
|
||||
"""REST server than can be launched from a ServingEnv.
|
||||
"""REST server than can be launched from a ExternalEnv.
|
||||
|
||||
This launches a multi-threaded server that listens on the specified host
|
||||
and port to serve policy requests and forward experiences to RLlib.
|
||||
|
||||
Examples:
|
||||
>>> class CartpoleServing(ServingEnv):
|
||||
>>> class CartpoleServing(ExternalEnv):
|
||||
def __init__(self):
|
||||
ServingEnv.__init__(
|
||||
ExternalEnv.__init__(
|
||||
self, spaces.Discrete(2),
|
||||
spaces.Box(
|
||||
low=-10,
|
||||
@@ -50,12 +50,12 @@ class PolicyServer(ThreadingMixIn, HTTPServer):
|
||||
>>> client.log_returns(eps_id, reward)
|
||||
"""
|
||||
|
||||
def __init__(self, serving_env, address, port):
|
||||
handler = _make_handler(serving_env)
|
||||
def __init__(self, external_env, address, port):
|
||||
handler = _make_handler(external_env)
|
||||
HTTPServer.__init__(self, (address, port), handler)
|
||||
|
||||
|
||||
def _make_handler(serving_env):
|
||||
def _make_handler(external_env):
|
||||
class Handler(SimpleHTTPRequestHandler):
|
||||
def do_POST(self):
|
||||
content_len = int(self.headers.get('Content-Length'), 0)
|
||||
@@ -73,20 +73,20 @@ def _make_handler(serving_env):
|
||||
command = args["command"]
|
||||
response = {}
|
||||
if command == PolicyClient.START_EPISODE:
|
||||
response["episode_id"] = serving_env.start_episode(
|
||||
response["episode_id"] = external_env.start_episode(
|
||||
args["episode_id"], args["training_enabled"])
|
||||
elif command == PolicyClient.GET_ACTION:
|
||||
response["action"] = serving_env.get_action(
|
||||
response["action"] = external_env.get_action(
|
||||
args["episode_id"], args["observation"])
|
||||
elif command == PolicyClient.LOG_ACTION:
|
||||
serving_env.log_action(args["episode_id"], args["observation"],
|
||||
args["action"])
|
||||
external_env.log_action(args["episode_id"],
|
||||
args["observation"], args["action"])
|
||||
elif command == PolicyClient.LOG_RETURNS:
|
||||
serving_env.log_returns(args["episode_id"], args["reward"],
|
||||
args["info"])
|
||||
external_env.log_returns(args["episode_id"], args["reward"],
|
||||
args["info"])
|
||||
elif command == PolicyClient.END_EPISODE:
|
||||
serving_env.end_episode(args["episode_id"],
|
||||
args["observation"])
|
||||
external_env.end_episode(args["episode_id"],
|
||||
args["observation"])
|
||||
else:
|
||||
raise Exception("Unknown command: {}".format(command))
|
||||
return response
|
||||
|
||||
Reference in New Issue
Block a user