[rllib] Rename ServingEnv => ExternalEnv (#3302)

This commit is contained in:
Eric Liang
2018-11-12 16:31:27 -08:00
committed by GitHub
parent e37891d79d
commit bd0dbde149
16 changed files with 313 additions and 301 deletions
+2 -2
View File
@@ -13,7 +13,7 @@ from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph
from ray.rllib.env.async_vector_env import AsyncVectorEnv
from ray.rllib.env.multi_agent_env import MultiAgentEnv
from ray.rllib.env.vector_env import VectorEnv
from ray.rllib.env.serving_env import ServingEnv
from ray.rllib.env.external_env import ExternalEnv
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator
from ray.rllib.evaluation.sample_batch import SampleBatch
@@ -51,5 +51,5 @@ __all__ = [
"AsyncVectorEnv",
"MultiAgentEnv",
"VectorEnv",
"ServingEnv",
"ExternalEnv",
]
+3 -1
View File
@@ -1,9 +1,11 @@
from ray.rllib.env.async_vector_env import AsyncVectorEnv
from ray.rllib.env.multi_agent_env import MultiAgentEnv
from ray.rllib.env.external_env import ExternalEnv
from ray.rllib.env.serving_env import ServingEnv
from ray.rllib.env.vector_env import VectorEnv
from ray.rllib.env.env_context import EnvContext
__all__ = [
"AsyncVectorEnv", "MultiAgentEnv", "ServingEnv", "VectorEnv", "EnvContext"
"AsyncVectorEnv", "MultiAgentEnv", "ExternalEnv", "VectorEnv",
"ServingEnv", "EnvContext"
]
+21 -21
View File
@@ -2,7 +2,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from ray.rllib.env.serving_env import ServingEnv
from ray.rllib.env.external_env import ExternalEnv
from ray.rllib.env.vector_env import VectorEnv
from ray.rllib.env.multi_agent_env import MultiAgentEnv
@@ -20,7 +20,7 @@ class AsyncVectorEnv(object):
gym.Env => rllib.VectorEnv => rllib.AsyncVectorEnv
rllib.MultiAgentEnv => rllib.AsyncVectorEnv
rllib.ServingEnv => rllib.AsyncVectorEnv
rllib.ExternalEnv => rllib.AsyncVectorEnv
Attributes:
action_space (gym.Space): Action space. This must be defined for
@@ -70,11 +70,11 @@ class AsyncVectorEnv(object):
if isinstance(env, MultiAgentEnv):
env = _MultiAgentEnvToAsync(
make_env=make_env, existing_envs=[env], num_envs=num_envs)
elif isinstance(env, ServingEnv):
elif isinstance(env, ExternalEnv):
if num_envs != 1:
raise ValueError(
"ServingEnv does not currently support num_envs > 1.")
env = _ServingEnvToAsync(env)
"ExternalEnv does not currently support num_envs > 1.")
env = _ExternalEnvToAsync(env)
elif isinstance(env, VectorEnv):
env = _VectorEnvToAsync(env)
else:
@@ -145,40 +145,40 @@ def _with_dummy_agent_id(env_id_to_values, dummy_id=_DUMMY_AGENT_ID):
return {k: {dummy_id: v} for (k, v) in env_id_to_values.items()}
class _ServingEnvToAsync(AsyncVectorEnv):
"""Internal adapter of ServingEnv to AsyncVectorEnv."""
class _ExternalEnvToAsync(AsyncVectorEnv):
"""Internal adapter of ExternalEnv to AsyncVectorEnv."""
def __init__(self, serving_env, preprocessor=None):
self.serving_env = serving_env
def __init__(self, external_env, preprocessor=None):
self.external_env = external_env
self.prep = preprocessor
self.action_space = serving_env.action_space
self.action_space = external_env.action_space
if preprocessor:
self.observation_space = preprocessor.observation_space
else:
self.observation_space = serving_env.observation_space
serving_env.start()
self.observation_space = external_env.observation_space
external_env.start()
def poll(self):
with self.serving_env._results_avail_condition:
with self.external_env._results_avail_condition:
results = self._poll()
while len(results[0]) == 0:
self.serving_env._results_avail_condition.wait()
self.external_env._results_avail_condition.wait()
results = self._poll()
if not self.serving_env.isAlive():
if not self.external_env.isAlive():
raise Exception("Serving thread has stopped.")
limit = self.serving_env._max_concurrent_episodes
limit = self.external_env._max_concurrent_episodes
assert len(results[0]) < limit, \
("Too many concurrent episodes, were some leaked? This ServingEnv "
"was created with max_concurrent={}".format(limit))
("Too many concurrent episodes, were some leaked? This "
"ExternalEnv was created with max_concurrent={}".format(limit))
return results
def _poll(self):
all_obs, all_rewards, all_dones, all_infos = {}, {}, {}, {}
off_policy_actions = {}
for eid, episode in self.serving_env._episodes.copy().items():
for eid, episode in self.external_env._episodes.copy().items():
data = episode.get_data()
if episode.cur_done:
del self.serving_env._episodes[eid]
del self.external_env._episodes[eid]
if data:
if self.prep:
all_obs[eid] = self.prep.transform(data["obs"])
@@ -197,7 +197,7 @@ class _ServingEnvToAsync(AsyncVectorEnv):
def send_actions(self, action_dict):
for eid, action in action_dict.items():
self.serving_env._episodes[eid].action_queue.put(
self.external_env._episodes[eid].action_queue.put(
action[_DUMMY_AGENT_ID])
+226
View File
@@ -0,0 +1,226 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from six.moves import queue
import threading
import uuid
class ExternalEnv(threading.Thread):
    """An environment that interfaces with external agents.

    Unlike simulator envs, control is inverted. The environment queries the
    policy to obtain actions and logs observations and rewards for training.
    This is in contrast to gym.Env, where the algorithm drives the simulation
    through env.step() calls.

    You can use ExternalEnv as the backend for policy serving (by serving HTTP
    requests in the run loop), for ingesting offline logs data (by reading
    offline transitions in the run loop), or other custom use cases not easily
    expressed through gym.Env.

    ExternalEnv supports both on-policy actions (through self.get_action()),
    and off-policy actions (through self.log_action()).

    This env is thread-safe, but individual episodes must be executed serially.

    Attributes:
        action_space (gym.Space): Action space.
        observation_space (gym.Space): Observation space.

    Examples:
        >>> register_env("my_env", lambda config: YourExternalEnv(config))
        >>> agent = DQNAgent(env="my_env")
        >>> while True:
                print(agent.train())
    """

    def __init__(self, action_space, observation_space, max_concurrent=100):
        """Initialize an external env.

        ExternalEnv subclasses must call this during their __init__.

        Arguments:
            action_space (gym.Space): Action space of the env.
            observation_space (gym.Space): Observation space of the env.
            max_concurrent (int): Max number of active episodes to allow at
                once. Exceeding this limit raises an error.
        """
        threading.Thread.__init__(self)
        # Daemon thread: it does not keep the interpreter alive on exit.
        self.daemon = True
        self.action_space = action_space
        self.observation_space = observation_space
        # Active episodes, keyed by episode id.
        self._episodes = {}
        # Ids of episodes that have ended; used to reject reuse of an id.
        self._finished = set()
        # Handed to every episode, which notifies it when data is queued.
        self._results_avail_condition = threading.Condition()
        self._max_concurrent_episodes = max_concurrent

    def run(self):
        """Override this to implement the run loop.

        Your loop should continuously:
            1. Call self.start_episode(episode_id)
            2. Call self.get_action(episode_id, obs)
                    -or-
                    self.log_action(episode_id, obs, action)
            3. Call self.log_returns(episode_id, reward)
            4. Call self.end_episode(episode_id, obs)
            5. Wait if nothing to do.

        Multiple episodes may be started at the same time.
        """
        raise NotImplementedError

    def start_episode(self, episode_id=None, training_enabled=True):
        """Record the start of an episode.

        Arguments:
            episode_id (str): Unique string id for the episode or None for
                it to be auto-assigned.
            training_enabled (bool): Whether to use experiences for this
                episode to improve the policy.

        Returns:
            episode_id (str): Unique string id for the episode.

        Raises:
            ValueError: If the id belongs to an episode that has already
                completed or that is currently active.
        """
        if episode_id is None:
            episode_id = uuid.uuid4().hex

        if episode_id in self._finished:
            raise ValueError(
                "Episode {} has already completed.".format(episode_id))

        if episode_id in self._episodes:
            raise ValueError(
                "Episode {} is already started".format(episode_id))

        self._episodes[episode_id] = _ExternalEnvEpisode(
            episode_id, self._results_avail_condition, training_enabled)

        return episode_id

    def get_action(self, episode_id, observation):
        """Record an observation and get the on-policy action.

        Arguments:
            episode_id (str): Episode id returned from start_episode().
            observation (obj): Current environment observation.

        Returns:
            action (obj): Action from the env action space.

        Raises:
            ValueError: If the episode is unknown or already completed.
        """
        episode = self._get(episode_id)
        return episode.wait_for_action(observation)

    def log_action(self, episode_id, observation, action):
        """Record an observation and (off-policy) action taken.

        Arguments:
            episode_id (str): Episode id returned from start_episode().
            observation (obj): Current environment observation.
            action (obj): Action for the observation.

        Raises:
            ValueError: If the episode is unknown or already completed.
        """
        episode = self._get(episode_id)
        episode.log_action(observation, action)

    def log_returns(self, episode_id, reward, info=None):
        """Record returns from the environment.

        The reward will be attributed to the previous action taken by the
        episode. Rewards accumulate until the next action. If no reward is
        logged before the next action, a reward of 0.0 is assumed.

        Arguments:
            episode_id (str): Episode id returned from start_episode().
            reward (float): Reward from the environment.
            info (dict): Optional info dict. A falsy value (None or an
                empty dict) leaves the episode's current info untouched.

        Raises:
            ValueError: If the episode is unknown or already completed.
        """
        episode = self._get(episode_id)
        episode.cur_reward += reward
        # Only a non-empty info replaces the stored one, so no fallback
        # value is needed inside this branch.
        if info:
            episode.cur_info = info

    def end_episode(self, episode_id, observation):
        """Record the end of an episode.

        Arguments:
            episode_id (str): Episode id returned from start_episode().
            observation (obj): Current environment observation.

        Raises:
            ValueError: If the episode is unknown or already completed.
        """
        episode = self._get(episode_id)
        # Mark the id as finished first so later calls for it are rejected.
        self._finished.add(episode.episode_id)
        episode.done(observation)

    def _get(self, episode_id):
        """Get a started episode or raise an error.

        Raises:
            ValueError: If the episode has already completed, or if no
                episode with this id is active.
        """
        if episode_id in self._finished:
            raise ValueError(
                "Episode {} has already completed.".format(episode_id))

        if episode_id not in self._episodes:
            raise ValueError("Episode {} not found.".format(episode_id))

        return self._episodes[episode_id]
class _ExternalEnvEpisode(object):
"""Tracked state for each active episode."""
def __init__(self, episode_id, results_avail_condition, training_enabled):
self.episode_id = episode_id
self.results_avail_condition = results_avail_condition
self.training_enabled = training_enabled
self.data_queue = queue.Queue()
self.action_queue = queue.Queue()
self.new_observation = None
self.new_action = None
self.cur_reward = 0.0
self.cur_done = False
self.cur_info = {}
def get_data(self):
if self.data_queue.empty():
return None
return self.data_queue.get_nowait()
def log_action(self, observation, action):
self.new_observation = observation
self.new_action = action
self._send()
self.action_queue.get(True, timeout=60.0)
def wait_for_action(self, observation):
self.new_observation = observation
self._send()
return self.action_queue.get(True, timeout=60.0)
def done(self, observation):
self.new_observation = observation
self.cur_done = True
self._send()
def _send(self):
item = {
"obs": self.new_observation,
"reward": self.cur_reward,
"done": self.cur_done,
"info": self.cur_info,
}
if self.new_action is not None:
item["off_policy_action"] = self.new_action
if not self.training_enabled:
item["info"]["training_enabled"] = False
self.new_observation = None
self.new_action = None
self.cur_reward = 0.0
with self.results_avail_condition:
self.data_queue.put_nowait(item)
self.results_avail_condition.notify()
+3 -221
View File
@@ -2,225 +2,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from six.moves import queue
import threading
import uuid
from ray.rllib.env.external_env import ExternalEnv
class ServingEnv(threading.Thread):
"""An environment that provides policy serving.
Unlike simulator envs, control is inverted. The environment queries the
policy to obtain actions and logs observations and rewards for training.
This is in contrast to gym.Env, where the algorithm drives the simulation
through env.step() calls.
You can use ServingEnv as the backend for policy serving (by serving HTTP
requests in the run loop), for ingesting offline logs data (by reading
offline transitions in the run loop), or other custom use cases not easily
expressed through gym.Env.
ServingEnv supports both on-policy serving (through self.get_action()), and
off-policy serving (through self.log_action()).
This env is thread-safe, but individual episodes must be executed serially.
Attributes:
action_space (gym.Space): Action space.
observation_space (gym.Space): Observation space.
Examples:
>>> register_env("my_env", lambda config: YourServingEnv(config))
>>> agent = DQNAgent(env="my_env")
>>> while True:
print(agent.train())
"""
def __init__(self, action_space, observation_space, max_concurrent=100):
"""Initialize a serving env.
ServingEnv subclasses must call this during their __init__.
Arguments:
action_space (gym.Space): Action space of the env.
observation_space (gym.Space): Observation space of the env.
max_concurrent (int): Max number of active episodes to allow at
once. Exceeding this limit raises an error.
"""
threading.Thread.__init__(self)
self.daemon = True
self.action_space = action_space
self.observation_space = observation_space
self._episodes = {}
self._finished = set()
self._results_avail_condition = threading.Condition()
self._max_concurrent_episodes = max_concurrent
def run(self):
"""Override this to implement the run loop.
Your loop should continuously:
1. Call self.start_episode(episode_id)
2. Call self.get_action(episode_id, obs)
-or-
self.log_action(episode_id, obs, action)
3. Call self.log_returns(episode_id, reward)
4. Call self.end_episode(episode_id, obs)
5. Wait if nothing to do.
Multiple episodes may be started at the same time.
"""
raise NotImplementedError
def start_episode(self, episode_id=None, training_enabled=True):
"""Record the start of an episode.
Arguments:
episode_id (str): Unique string id for the episode or None for
it to be auto-assigned.
training_enabled (bool): Whether to use experiences for this
episode to improve the policy.
Returns:
episode_id (str): Unique string id for the episode.
"""
if episode_id is None:
episode_id = uuid.uuid4().hex
if episode_id in self._finished:
raise ValueError(
"Episode {} has already completed.".format(episode_id))
if episode_id in self._episodes:
raise ValueError(
"Episode {} is already started".format(episode_id))
self._episodes[episode_id] = _ServingEnvEpisode(
episode_id, self._results_avail_condition, training_enabled)
return episode_id
def get_action(self, episode_id, observation):
"""Record an observation and get the on-policy action.
Arguments:
episode_id (str): Episode id returned from start_episode().
observation (obj): Current environment observation.
Returns:
action (obj): Action from the env action space.
"""
episode = self._get(episode_id)
return episode.wait_for_action(observation)
def log_action(self, episode_id, observation, action):
"""Record an observation and (off-policy) action taken.
Arguments:
episode_id (str): Episode id returned from start_episode().
observation (obj): Current environment observation.
action (obj): Action for the observation.
"""
episode = self._get(episode_id)
episode.log_action(observation, action)
def log_returns(self, episode_id, reward, info=None):
"""Record returns from the environment.
The reward will be attributed to the previous action taken by the
episode. Rewards accumulate until the next action. If no reward is
logged before the next action, a reward of 0.0 is assumed.
Arguments:
episode_id (str): Episode id returned from start_episode().
reward (float): Reward from the environment.
info (dict): Optional info dict.
"""
episode = self._get(episode_id)
episode.cur_reward += reward
if info:
episode.cur_info = info or {}
def end_episode(self, episode_id, observation):
"""Record the end of an episode.
Arguments:
episode_id (str): Episode id returned from start_episode().
observation (obj): Current environment observation.
"""
episode = self._get(episode_id)
self._finished.add(episode.episode_id)
episode.done(observation)
def _get(self, episode_id):
"""Get a started episode or raise an error."""
if episode_id in self._finished:
raise ValueError(
"Episode {} has already completed.".format(episode_id))
if episode_id not in self._episodes:
raise ValueError("Episode {} not found.".format(episode_id))
return self._episodes[episode_id]
class _ServingEnvEpisode(object):
"""Tracked state for each active episode."""
def __init__(self, episode_id, results_avail_condition, training_enabled):
self.episode_id = episode_id
self.results_avail_condition = results_avail_condition
self.training_enabled = training_enabled
self.data_queue = queue.Queue()
self.action_queue = queue.Queue()
self.new_observation = None
self.new_action = None
self.cur_reward = 0.0
self.cur_done = False
self.cur_info = {}
def get_data(self):
if self.data_queue.empty():
return None
return self.data_queue.get_nowait()
def log_action(self, observation, action):
self.new_observation = observation
self.new_action = action
self._send()
self.action_queue.get(True, timeout=60.0)
def wait_for_action(self, observation):
self.new_observation = observation
self._send()
return self.action_queue.get(True, timeout=60.0)
def done(self, observation):
self.new_observation = observation
self.cur_done = True
self._send()
def _send(self):
item = {
"obs": self.new_observation,
"reward": self.cur_reward,
"done": self.cur_done,
"info": self.cur_info,
}
if self.new_action is not None:
item["off_policy_action"] = self.new_action
if not self.training_enabled:
item["info"]["training_enabled"] = False
self.new_observation = None
self.new_action = None
self.cur_reward = 0.0
with self.results_avail_condition:
self.data_queue.put_nowait(item)
self.results_avail_condition.notify()
# renamed to ExternalEnv in 0.6
ServingEnv = ExternalEnv
@@ -29,7 +29,7 @@ parser.add_argument(
if __name__ == "__main__":
args = parser.parse_args()
env = gym.make("CartPole-v0")
client = PolicyClient("http://localhost:8900")
client = PolicyClient("http://localhost:9900")
eid = client.start_episode(training_enabled=not args.no_train)
obs = env.reset()
@@ -14,19 +14,19 @@ import numpy as np
import ray
from ray.rllib.agents.dqn import DQNAgent
from ray.rllib.env.serving_env import ServingEnv
from ray.rllib.env.external_env import ExternalEnv
from ray.rllib.utils.policy_server import PolicyServer
from ray.tune.logger import pretty_print
from ray.tune.registry import register_env
SERVER_ADDRESS = "localhost"
SERVER_PORT = 8900
SERVER_PORT = 9900
CHECKPOINT_FILE = "last_checkpoint.out"
class CartpoleServing(ServingEnv):
class CartpoleServing(ExternalEnv):
def __init__(self):
ServingEnv.__init__(
ExternalEnv.__init__(
self, spaces.Discrete(2),
spaces.Box(low=-10, high=10, shape=(4, ), dtype=np.float32))
+1 -1
View File
@@ -4,7 +4,7 @@ pkill -f cartpole_server.py
(python cartpole_server.py 2>&1 | grep -v 200) &
pid=$!
while ! curl localhost:8900; do
while ! curl localhost:9900; do
sleep 1
done
+6 -6
View File
@@ -11,8 +11,8 @@ from functools import partial
from ray.tune.registry import RLLIB_MODEL, RLLIB_PREPROCESSOR, \
_global_registry
from ray.rllib.env.async_vector_env import _ServingEnvToAsync
from ray.rllib.env.serving_env import ServingEnv
from ray.rllib.env.async_vector_env import _ExternalEnvToAsync
from ray.rllib.env.external_env import ExternalEnv
from ray.rllib.env.vector_env import VectorEnv
from ray.rllib.models.action_dist import (
Categorical, Deterministic, DiagGaussian, MultiActionDistribution,
@@ -270,7 +270,7 @@ class ModelCatalog(object):
"""Returns a suitable processor for the given environment.
Args:
env (gym.Env|VectorEnv|ServingEnv): The environment to wrap.
env (gym.Env|VectorEnv|ExternalEnv): The environment to wrap.
options (dict): Options to pass to the preprocessor.
Returns:
@@ -300,7 +300,7 @@ class ModelCatalog(object):
"""Returns a preprocessor as a gym observation wrapper.
Args:
env (gym.Env|VectorEnv|ServingEnv): The environment to wrap.
env (gym.Env|VectorEnv|ExternalEnv): The environment to wrap.
options (dict): Options to pass to the preprocessor.
Returns:
@@ -313,8 +313,8 @@ class ModelCatalog(object):
return _RLlibPreprocessorWrapper(env, preprocessor)
elif isinstance(env, VectorEnv):
return _RLlibVectorPreprocessorWrapper(env, preprocessor)
elif isinstance(env, ServingEnv):
return _ServingEnvToAsync(env, preprocessor)
elif isinstance(env, ExternalEnv):
return _ExternalEnvToAsync(env, preprocessor)
else:
raise ValueError("Don't know how to wrap {}".format(env))
@@ -12,15 +12,15 @@ import ray
from ray.rllib.agents.dqn import DQNAgent
from ray.rllib.agents.pg import PGAgent
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator
from ray.rllib.env.serving_env import ServingEnv
from ray.rllib.env.external_env import ExternalEnv
from ray.rllib.test.test_policy_evaluator import BadPolicyGraph, \
MockPolicyGraph, MockEnv
from ray.tune.registry import register_env
class SimpleServing(ServingEnv):
class SimpleServing(ExternalEnv):
def __init__(self, env):
ServingEnv.__init__(self, env.action_space, env.observation_space)
ExternalEnv.__init__(self, env.action_space, env.observation_space)
self.env = env
def run(self):
@@ -36,9 +36,9 @@ class SimpleServing(ServingEnv):
eid = self.start_episode()
class PartOffPolicyServing(ServingEnv):
class PartOffPolicyServing(ExternalEnv):
def __init__(self, env, off_pol_frac):
ServingEnv.__init__(self, env.action_space, env.observation_space)
ExternalEnv.__init__(self, env.action_space, env.observation_space)
self.env = env
self.off_pol_frac = off_pol_frac
@@ -59,9 +59,9 @@ class PartOffPolicyServing(ServingEnv):
eid = self.start_episode()
class SimpleOffPolicyServing(ServingEnv):
class SimpleOffPolicyServing(ExternalEnv):
def __init__(self, env, fixed_action):
ServingEnv.__init__(self, env.action_space, env.observation_space)
ExternalEnv.__init__(self, env.action_space, env.observation_space)
self.env = env
self.fixed_action = fixed_action
@@ -79,12 +79,12 @@ class SimpleOffPolicyServing(ServingEnv):
eid = self.start_episode()
class MultiServing(ServingEnv):
class MultiServing(ExternalEnv):
def __init__(self, env_creator):
self.env_creator = env_creator
self.env = env_creator()
ServingEnv.__init__(self, self.env.action_space,
self.env.observation_space)
ExternalEnv.__init__(self, self.env.action_space,
self.env.observation_space)
def run(self):
envs = [self.env_creator() for _ in range(5)]
@@ -107,8 +107,8 @@ class MultiServing(ServingEnv):
del cur_obs[i]
class TestServingEnv(unittest.TestCase):
def testServingEnvCompleteEpisodes(self):
class TestExternalEnv(unittest.TestCase):
def testExternalEnvCompleteEpisodes(self):
ev = PolicyEvaluator(
env_creator=lambda _: SimpleServing(MockEnv(25)),
policy_graph=MockPolicyGraph,
@@ -118,7 +118,7 @@ class TestServingEnv(unittest.TestCase):
batch = ev.sample()
self.assertEqual(batch.count, 50)
def testServingEnvTruncateEpisodes(self):
def testExternalEnvTruncateEpisodes(self):
ev = PolicyEvaluator(
env_creator=lambda _: SimpleServing(MockEnv(25)),
policy_graph=MockPolicyGraph,
@@ -128,7 +128,7 @@ class TestServingEnv(unittest.TestCase):
batch = ev.sample()
self.assertEqual(batch.count, 40)
def testServingEnvOffPolicy(self):
def testExternalEnvOffPolicy(self):
ev = PolicyEvaluator(
env_creator=lambda _: SimpleOffPolicyServing(MockEnv(25), 42),
policy_graph=MockPolicyGraph,
@@ -140,7 +140,7 @@ class TestServingEnv(unittest.TestCase):
self.assertEqual(batch["actions"][0], 42)
self.assertEqual(batch["actions"][-1], 42)
def testServingEnvBadActions(self):
def testExternalEnvBadActions(self):
ev = PolicyEvaluator(
env_creator=lambda _: SimpleServing(MockEnv(25)),
policy_graph=BadPolicyGraph,
@@ -185,7 +185,7 @@ class TestServingEnv(unittest.TestCase):
return
raise Exception("failed to improve reward")
def testServingEnvHorizonNotSupported(self):
def testExternalEnvHorizonNotSupported(self):
ev = PolicyEvaluator(
env_creator=lambda _: SimpleServing(MockEnv(25)),
policy_graph=MockPolicyGraph,
+1 -1
View File
@@ -17,7 +17,7 @@ from ray.rllib.env.async_vector_env import AsyncVectorEnv
from ray.rllib.env.vector_env import VectorEnv
from ray.rllib.models import ModelCatalog
from ray.rllib.models.model import Model
from ray.rllib.test.test_serving_env import SimpleServing
from ray.rllib.test.test_external_env import SimpleServing
from ray.tune.registry import register_env
DICT_SPACE = spaces.Dict({
+14 -14
View File
@@ -18,15 +18,15 @@ elif sys.version_info[0] == 3:
class PolicyServer(ThreadingMixIn, HTTPServer):
"""REST server than can be launched from a ServingEnv.
"""REST server than can be launched from a ExternalEnv.
This launches a multi-threaded server that listens on the specified host
and port to serve policy requests and forward experiences to RLlib.
Examples:
>>> class CartpoleServing(ServingEnv):
>>> class CartpoleServing(ExternalEnv):
def __init__(self):
ServingEnv.__init__(
ExternalEnv.__init__(
self, spaces.Discrete(2),
spaces.Box(
low=-10,
@@ -50,12 +50,12 @@ class PolicyServer(ThreadingMixIn, HTTPServer):
>>> client.log_returns(eps_id, reward)
"""
def __init__(self, serving_env, address, port):
handler = _make_handler(serving_env)
def __init__(self, external_env, address, port):
handler = _make_handler(external_env)
HTTPServer.__init__(self, (address, port), handler)
def _make_handler(serving_env):
def _make_handler(external_env):
class Handler(SimpleHTTPRequestHandler):
def do_POST(self):
content_len = int(self.headers.get('Content-Length'), 0)
@@ -73,20 +73,20 @@ def _make_handler(serving_env):
command = args["command"]
response = {}
if command == PolicyClient.START_EPISODE:
response["episode_id"] = serving_env.start_episode(
response["episode_id"] = external_env.start_episode(
args["episode_id"], args["training_enabled"])
elif command == PolicyClient.GET_ACTION:
response["action"] = serving_env.get_action(
response["action"] = external_env.get_action(
args["episode_id"], args["observation"])
elif command == PolicyClient.LOG_ACTION:
serving_env.log_action(args["episode_id"], args["observation"],
args["action"])
external_env.log_action(args["episode_id"],
args["observation"], args["action"])
elif command == PolicyClient.LOG_RETURNS:
serving_env.log_returns(args["episode_id"], args["reward"],
args["info"])
external_env.log_returns(args["episode_id"], args["reward"],
args["info"])
elif command == PolicyClient.END_EPISODE:
serving_env.end_episode(args["episode_id"],
args["observation"])
external_env.end_episode(args["episode_id"],
args["observation"])
else:
raise Exception("Unknown command: {}".format(command))
return response