From e5724a9cfefd5c11467ff7975397ddfb5d2fe0e7 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Wed, 20 Jun 2018 13:22:39 -0700 Subject: [PATCH] [rllib] Add a simple REST policy server and client example (#2232) * wip * cls * re * wip * wip * a3c working * torch support * pg works * lint * rm v2 * consumer id * clean up pg * clean up more * fix python 2.7 * tf session management * docs * dqn wip * fix compile * dqn * apex runs * up * impotrs * ddpg * quotes * fix tests * fix last r * fix tests * lint * pass checkpoint restore * kwar * nits * policy graph * fix yapf * com * class * pyt * vectorization * update * test cpe * unit test * fix ddpg2 * changes * wip * args * faster test * common * fix * add alg option * batch mode and policy serving * multi serving test * todo * wip * serving test * doc async env * num envs * comments * thread * remove init hook * update * policy serve * spaces * checkpoint * no train * fix ppo * comments1 * fix * updates * add jenkins tests * fix * fix pytorch * fix * fixes * fix a3c policy * fix squeeze * fix trunc on apex * fix squeezing for real * update * remove horizon test for now * fix race condition * update * com * updat * add test * Update run_multi_node_tests.sh * use curl * curl * kill * Update run_multi_node_tests.sh * Update run_multi_node_tests.sh * fix import * update --- python/ray/rllib/examples/__init__.py | 5 - .../examples/legacy_multiagent/__init__.py | 0 .../multiagent_mountaincar.py | 4 +- .../multiagent_mountaincar_env.py | 0 .../multiagent_pendulum.py | 4 +- .../multiagent_pendulum_env.py | 0 .../rllib/examples/serving/cartpole_client.py | 55 +++++++++ .../rllib/examples/serving/cartpole_server.py | 66 ++++++++++ python/ray/rllib/examples/serving/test.sh | 12 ++ python/ray/rllib/test/test_serving_env.py | 39 +++--- python/ray/rllib/utils/policy_client.py | 116 ++++++++++++++++++ python/ray/rllib/utils/policy_server.py | 62 ++++++++++ python/ray/rllib/utils/sampler.py | 37 +++--- python/ray/rllib/utils/serving_env.py | 58 ++++----- test/jenkins_tests/run_multi_node_tests.sh | 4 +- 15 files changed, 384 insertions(+), 78 deletions(-) create mode 100644 python/ray/rllib/examples/legacy_multiagent/__init__.py rename python/ray/rllib/examples/{ => legacy_multiagent}/multiagent_mountaincar.py (91%) rename python/ray/rllib/examples/{ => legacy_multiagent}/multiagent_mountaincar_env.py (100%) rename python/ray/rllib/examples/{ => legacy_multiagent}/multiagent_pendulum.py (92%) rename python/ray/rllib/examples/{ => legacy_multiagent}/multiagent_pendulum_env.py (100%) create mode 100755 python/ray/rllib/examples/serving/cartpole_client.py create mode 100755 python/ray/rllib/examples/serving/cartpole_server.py create mode 100755 python/ray/rllib/examples/serving/test.sh create mode 100644 python/ray/rllib/utils/policy_client.py create mode 100644 python/ray/rllib/utils/policy_server.py diff --git a/python/ray/rllib/examples/__init__.py b/python/ray/rllib/examples/__init__.py index bcedb9af0..e69de29bb 100644 --- a/python/ray/rllib/examples/__init__.py +++ b/python/ray/rllib/examples/__init__.py @@ -1,5 +0,0 @@ -# flake8: noqa -from ray.rllib.examples.multiagent_mountaincar_env \ - import MultiAgentMountainCarEnv -from ray.rllib.examples.multiagent_pendulum_env \ - import MultiAgentPendulumEnv diff --git a/python/ray/rllib/examples/legacy_multiagent/__init__.py b/python/ray/rllib/examples/legacy_multiagent/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/python/ray/rllib/examples/multiagent_mountaincar.py b/python/ray/rllib/examples/legacy_multiagent/multiagent_mountaincar.py similarity index 91% rename from python/ray/rllib/examples/multiagent_mountaincar.py rename to python/ray/rllib/examples/legacy_multiagent/multiagent_mountaincar.py index 29a7590b3..e3e20344b 100644 --- a/python/ray/rllib/examples/multiagent_mountaincar.py +++ b/python/ray/rllib/examples/legacy_multiagent/multiagent_mountaincar.py @@ -21,7 +21,9 @@ def pass_params_to_gym(env_name): register( id=env_name, - entry_point='ray.rllib.examples:' + "MultiAgentMountainCarEnv", + entry_point=( + "ray.rllib.examples.legacy_multiagent.multiagent_mountaincar_env:" + "MultiAgentMountainCarEnv"), max_episode_steps=200, kwargs={} ) diff --git a/python/ray/rllib/examples/multiagent_mountaincar_env.py b/python/ray/rllib/examples/legacy_multiagent/multiagent_mountaincar_env.py similarity index 100% rename from python/ray/rllib/examples/multiagent_mountaincar_env.py rename to python/ray/rllib/examples/legacy_multiagent/multiagent_mountaincar_env.py diff --git a/python/ray/rllib/examples/multiagent_pendulum.py b/python/ray/rllib/examples/legacy_multiagent/multiagent_pendulum.py similarity index 92% rename from python/ray/rllib/examples/multiagent_pendulum.py rename to python/ray/rllib/examples/legacy_multiagent/multiagent_pendulum.py index 9754f681e..d4cf8e5bf 100644 --- a/python/ray/rllib/examples/multiagent_pendulum.py +++ b/python/ray/rllib/examples/legacy_multiagent/multiagent_pendulum.py @@ -21,7 +21,9 @@ def pass_params_to_gym(env_name): register( id=env_name, - entry_point='ray.rllib.examples:' + "MultiAgentPendulumEnv", + entry_point=( + "ray.rllib.examples.legacy_multiagent.multiagent_pendulum_env:" + "MultiAgentPendulumEnv"), max_episode_steps=100, kwargs={} ) diff --git a/python/ray/rllib/examples/multiagent_pendulum_env.py b/python/ray/rllib/examples/legacy_multiagent/multiagent_pendulum_env.py similarity index 100% rename from python/ray/rllib/examples/multiagent_pendulum_env.py rename to python/ray/rllib/examples/legacy_multiagent/multiagent_pendulum_env.py diff --git a/python/ray/rllib/examples/serving/cartpole_client.py b/python/ray/rllib/examples/serving/cartpole_client.py new file mode 100755 index 000000000..fb27e8567 --- /dev/null +++ b/python/ray/rllib/examples/serving/cartpole_client.py @@ -0,0 +1,55 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +"""Example of querying a policy server. Copy this file for your use case. + +To try this out, in two separate shells run: + $ python cartpole_server.py + $ python cartpole_client.py +""" + +import argparse +import gym + +from ray.rllib.utils.policy_client import PolicyClient + + +parser = argparse.ArgumentParser() +parser.add_argument( + "--no-train", action="store_true", help="Whether to disable training.") +parser.add_argument( + "--off-policy", action="store_true", + help="Whether to take random instead of on-policy actions.") +parser.add_argument( + "--stop-at-reward", type=int, default=9999, + help="Stop once the specified reward is reached.") + + +if __name__ == "__main__": + args = parser.parse_args() + env = gym.make("CartPole-v0") + client = PolicyClient("http://localhost:8900") + + eid = client.start_episode(training_enabled=not args.no_train) + obs = env.reset() + rewards = 0 + + while True: + if args.off_policy: + action = env.action_space.sample() + client.log_action(eid, obs, action) + else: + action = client.get_action(eid, obs) + obs, reward, done, info = env.step(action) + rewards += reward + client.log_returns(eid, reward, info=info) + if done: + print("Total reward:", rewards) + if rewards >= args.stop_at_reward: + print("Target reward achieved, exiting") + exit(0) + rewards = 0 + client.end_episode(eid, obs) + obs = env.reset() + eid = client.start_episode(training_enabled=not args.no_train) diff --git a/python/ray/rllib/examples/serving/cartpole_server.py b/python/ray/rllib/examples/serving/cartpole_server.py new file mode 100755 index 000000000..ffbf9f6c6 --- /dev/null +++ b/python/ray/rllib/examples/serving/cartpole_server.py @@ -0,0 +1,66 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +"""Example of running a policy server. Copy this file for your use case. + +To try this out, in two separate shells run: + $ python cartpole_server.py + $ python cartpole_client.py +""" + +import os +from gym import spaces + +import ray +from ray.rllib.dqn import DQNAgent +from ray.rllib.utils.serving_env import ServingEnv +from ray.rllib.utils.policy_server import PolicyServer +from ray.tune.logger import pretty_print +from ray.tune.registry import register_env + +SERVER_ADDRESS = "localhost" +SERVER_PORT = 8900 +CHECKPOINT_FILE = "last_checkpoint.out" + + +class CartpoleServing(ServingEnv): + def __init__(self): + ServingEnv.__init__( + self, spaces.Discrete(2), spaces.Box(low=-10, high=10, shape=(4,))) + + def run(self): + print("Starting policy server at {}:{}".format( + SERVER_ADDRESS, SERVER_PORT)) + server = PolicyServer(self, SERVER_ADDRESS, SERVER_PORT) + server.serve_forever() + + +if __name__ == "__main__": + ray.init() + register_env("srv", lambda _: CartpoleServing()) + + # We use DQN since it supports off-policy actions, but you can choose and + # configure any agent. + dqn = DQNAgent(env="srv", config={ + # Use a single process to avoid needing to set up a load balancer + "num_workers": 0, + # Configure the agent to run short iterations for debugging + "exploration_fraction": 0.01, + "learning_starts": 100, + "timesteps_per_iteration": 200, + }) + + # Attempt to restore from checkpoint if possible. + if os.path.exists(CHECKPOINT_FILE): + checkpoint_path = open(CHECKPOINT_FILE).read() + print("Restoring from checkpoint path", checkpoint_path) + dqn.restore(checkpoint_path) + + # Serving and training loop + while True: + print(pretty_print(dqn.train())) + checkpoint_path = dqn.save() + print("Last checkpoint", checkpoint_path) + with open(CHECKPOINT_FILE, "w") as f: + f.write(checkpoint_path) diff --git a/python/ray/rllib/examples/serving/test.sh b/python/ray/rllib/examples/serving/test.sh new file mode 100755 index 000000000..d443a44a4 --- /dev/null +++ b/python/ray/rllib/examples/serving/test.sh @@ -0,0 +1,12 @@ +#!/bin/bash + +pkill -f cartpole_server.py +(python cartpole_server.py 2>&1 | grep -v 200) & +pid=$! + +while ! curl localhost:8900; do + sleep 1 +done + +python cartpole_client.py --stop-at-reward=100 +kill $pid diff --git a/python/ray/rllib/test/test_serving_env.py b/python/ray/rllib/test/test_serving_env.py index 000cfac9f..4c4613a6d 100644 --- a/python/ray/rllib/test/test_serving_env.py +++ b/python/ray/rllib/test/test_serving_env.py @@ -24,16 +24,16 @@ class SimpleServing(ServingEnv): self.env = env def run(self): - self.start_episode() + eid = self.start_episode() obs = self.env.reset() while True: - action = self.get_action(obs) + action = self.get_action(eid, obs) obs, reward, done, info = self.env.step(action) - self.log_returns(reward, info=info) + self.log_returns(eid, reward, info=info) if done: - self.end_episode(obs) + self.end_episode(eid, obs) obs = self.env.reset() - self.start_episode() + eid = self.start_episode() class PartOffPolicyServing(ServingEnv): @@ -43,20 +43,20 @@ class PartOffPolicyServing(ServingEnv): self.off_pol_frac = off_pol_frac def run(self): - self.start_episode() + eid = self.start_episode() obs = self.env.reset() while True: if random.random() < self.off_pol_frac: action = self.env.action_space.sample() - self.log_action(obs, action) + self.log_action(eid, obs, action) else: - action = self.get_action(obs) + action = self.get_action(eid, obs) obs, reward, done, info = self.env.step(action) - self.log_returns(reward, info=info) + self.log_returns(eid, reward, info=info) if done: - self.end_episode(obs) + self.end_episode(eid, obs) obs = self.env.reset() - self.start_episode() + eid = self.start_episode() class SimpleOffPolicyServing(ServingEnv): @@ -65,18 +65,18 @@ class SimpleOffPolicyServing(ServingEnv): self.env = env def run(self): - self.start_episode() + eid = self.start_episode() obs = self.env.reset() while True: # Take random actions action = self.env.action_space.sample() - self.log_action(obs, action) + self.log_action(eid, obs, action) obs, reward, done, info = self.env.step(action) - self.log_returns(reward, info=info) + self.log_returns(eid, reward, info=info) if done: - self.end_episode(obs) + self.end_episode(eid, obs) obs = self.env.reset() - self.start_episode() + eid = self.start_episode() class MultiServing(ServingEnv): @@ -98,14 +98,13 @@ class MultiServing(ServingEnv): self.start_episode(episode_id=eids[i]) cur_obs[i] = envs[i].reset() actions = [ - self.get_action( - cur_obs[i], episode_id=eids[i]) for i in active] + self.get_action(eids[i], cur_obs[i]) for i in active] for i, action in zip(active, actions): obs, reward, done, _ = envs[i].step(action) cur_obs[i] = obs - self.log_returns(reward, episode_id=eids[i]) + self.log_returns(eids[i], reward) if done: - self.end_episode(obs, episode_id=eids[i]) + self.end_episode(eids[i], obs) del cur_obs[i] diff --git a/python/ray/rllib/utils/policy_client.py b/python/ray/rllib/utils/policy_client.py new file mode 100644 index 000000000..623d32c1e --- /dev/null +++ b/python/ray/rllib/utils/policy_client.py @@ -0,0 +1,116 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import pickle + +try: + import requests # `requests` is not part of stdlib. +except ImportError: + requests = None + print("Couldn't import `requests` library. Be sure to install it on" + " the client side.") + + +class PolicyClient(object): + """Client to interact with a RLlib policy server.""" + + START_EPISODE = "START_EPISODE" + GET_ACTION = "GET_ACTION" + LOG_ACTION = "LOG_ACTION" + LOG_RETURNS = "LOG_RETURNS" + END_EPISODE = "END_EPISODE" + + def __init__(self, address): + self._address = address + + def start_episode(self, episode_id=None, training_enabled=True): + """Record the start of an episode. + + Arguments: + episode_id (str): Unique string id for the episode or None for + it to be auto-assigned. + training_enabled (bool): Whether to use experiences for this + episode to improve the policy. + + Returns: + episode_id (str): Unique string id for the episode. + """ + + return self._send({ + "episode_id": episode_id, + "command": PolicyClient.START_EPISODE, + "training_enabled": training_enabled, + })["episode_id"] + + def get_action(self, episode_id, observation): + """Record an observation and get the on-policy action. + + Arguments: + episode_id (str): Episode id returned from start_episode(). + observation (obj): Current environment observation. + + Returns: + action (obj): Action from the env action space. + """ + return self._send({ + "command": PolicyClient.GET_ACTION, + "observation": observation, + "episode_id": episode_id, + })["action"] + + def log_action(self, episode_id, observation, action): + """Record an observation and (off-policy) action taken. + + Arguments: + episode_id (str): Episode id returned from start_episode(). + observation (obj): Current environment observation. + action (obj): Action for the observation. + """ + self._send({ + "command": PolicyClient.LOG_ACTION, + "observation": observation, + "action": action, + "episode_id": episode_id, + }) + + def log_returns(self, episode_id, reward, info=None): + """Record returns from the environment. + + The reward will be attributed to the previous action taken by the + episode. Rewards accumulate until the next action. If no reward is + logged before the next action, a reward of 0.0 is assumed. + + Arguments: + episode_id (str): Episode id returned from start_episode(). + reward (float): Reward from the environment. + """ + self._send({ + "command": PolicyClient.LOG_RETURNS, + "reward": reward, + "info": info, + "episode_id": episode_id, + }) + + def end_episode(self, episode_id, observation): + """Record the end of an episode. + + Arguments: + episode_id (str): Episode id returned from start_episode(). + observation (obj): Current environment observation. + """ + self._send({ + "command": PolicyClient.END_EPISODE, + "observation": observation, + "episode_id": episode_id, + }) + + def _send(self, data): + payload = pickle.dumps(data) + response = requests.post(self._address, data=payload) + if response.status_code != 200: + print("Request failed", data) + print(response.text) + response.raise_for_status() + parsed = pickle.loads(response.content) + return parsed diff --git a/python/ray/rllib/utils/policy_server.py b/python/ray/rllib/utils/policy_server.py new file mode 100644 index 000000000..708b14e05 --- /dev/null +++ b/python/ray/rllib/utils/policy_server.py @@ -0,0 +1,62 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import pickle +import sys +import traceback + +from ray.rllib.utils.policy_client import PolicyClient + +if sys.version_info[0] == 2: + from SimpleHTTPServer import SimpleHTTPRequestHandler + from SocketServer import TCPServer as HTTPServer + from SocketServer import ThreadingMixIn +elif sys.version_info[0] == 3: + from http.server import SimpleHTTPRequestHandler, HTTPServer + from socketserver import ThreadingMixIn + + +class PolicyServer(ThreadingMixIn, HTTPServer): + def __init__(self, serving_env, address, port): + handler = _make_handler(serving_env) + HTTPServer.__init__(self, (address, port), handler) + + +def _make_handler(serving_env): + class Handler(SimpleHTTPRequestHandler): + def do_POST(self): + content_len = int(self.headers.get('Content-Length'), 0) + raw_body = self.rfile.read(content_len) + parsed_input = pickle.loads(raw_body) + try: + response = self.execute_command(parsed_input) + self.send_response(200) + self.end_headers() + self.wfile.write(pickle.dumps(response)) + except Exception: + self.send_error(500, traceback.format_exc()) + + def execute_command(self, args): + command = args["command"] + response = {} + if command == PolicyClient.START_EPISODE: + response["episode_id"] = serving_env.start_episode( + args["episode_id"], args["training_enabled"]) + elif command == PolicyClient.GET_ACTION: + response["action"] = serving_env.get_action( + args["episode_id"], args["observation"]) + elif command == PolicyClient.LOG_ACTION: + serving_env.log_action( + args["episode_id"], args["observation"], args["action"]) + elif command == PolicyClient.LOG_RETURNS: + serving_env.log_returns( + args["episode_id"], args["reward"], args["info"]) + elif command == PolicyClient.END_EPISODE: + serving_env.end_episode( + args["episode_id"], args["observation"]) + else: + raise Exception("Unknown command: {}".format(command)) + return response + + return Handler diff --git a/python/ray/rllib/utils/sampler.py b/python/ray/rllib/utils/sampler.py index 7ae66acd7..3a3364f8e 100644 --- a/python/ray/rllib/utils/sampler.py +++ b/python/ray/rllib/utils/sampler.py @@ -188,7 +188,7 @@ def _env_runner( while True: # Get observations from ready envs - unfiltered_obs, rewards, dones, _, off_policy_actions = \ + unfiltered_obs, rewards, dones, infos, off_policy_actions = \ async_vector_env.poll() ready_eids = [] ready_obs = [] @@ -216,24 +216,25 @@ def _env_runner( else: done = False - episode.batch_builder.add_values( - obs=episode.last_observation, - actions=episode.last_action_flat(), - rewards=rewards[eid], - dones=done, - new_obs=filtered_obs, - **episode.last_pi_info) + if infos[eid].get("training_enabled", True): + episode.batch_builder.add_values( + obs=episode.last_observation, + actions=episode.last_action_flat(), + rewards=rewards[eid], + dones=done, + new_obs=filtered_obs, + **episode.last_pi_info) - # Cut the batch if we're not packing multiple episodes into one, - # or if we've exceeded the requested batch size. - if (done and not pack) or \ - episode.batch_builder.count >= num_local_steps: - yield episode.batch_builder.build_and_reset( - policy.postprocess_trajectory) - elif done: - # Make sure postprocessor never goes across episode boundaries - episode.batch_builder.postprocess_batch_so_far( - policy.postprocess_trajectory) + # Cut the batch if we're not packing multiple episodes into + # one, or if we've exceeded the requested batch size. + if (done and not pack) or \ + episode.batch_builder.count >= num_local_steps: + yield episode.batch_builder.build_and_reset( + policy.postprocess_trajectory) + elif done: + # Make sure postprocessor never crosses episode boundaries + episode.batch_builder.postprocess_batch_so_far( + policy.postprocess_trajectory) if done: # Handle episode termination diff --git a/python/ray/rllib/utils/serving_env.py b/python/ray/rllib/utils/serving_env.py index cf09c5244..827a725b3 100644 --- a/python/ray/rllib/utils/serving_env.py +++ b/python/ray/rllib/utils/serving_env.py @@ -4,6 +4,7 @@ from __future__ import print_function from six.moves import queue import threading +import uuid from ray.rllib.utils.async_vector_env import AsyncVectorEnv @@ -26,8 +27,6 @@ class ServingEnv(threading.Thread): This env is thread-safe, but individual episodes must be executed serially. - TODO: Provide a HTTP server/client example based on ServingEnv. - Examples: >>> register_env("my_env", lambda config: YourServingEnv(config)) >>> agent = DQNAgent(env="my_env") @@ -51,8 +50,6 @@ class ServingEnv(threading.Thread): self.observation_space = observation_space self._episodes = {} self._finished = set() - self._num_episodes = 0 - self._cur_default_episode_id = None self._results_avail_condition = threading.Condition() self._max_concurrent_episodes = max_concurrent @@ -70,24 +67,21 @@ class ServingEnv(threading.Thread): """ raise NotImplementedError - def start_episode(self, episode_id=None): + def start_episode(self, episode_id=None, training_enabled=True): """Record the start of an episode. Arguments: episode_id (str): Unique string id for the episode or None for - it to be auto-assigned. Auto-assignment only works if there - is at most one active episode at a time. + it to be auto-assigned. + training_enabled (bool): Whether to use experiences for this + episode to improve the policy. + + Returns: + episode_id (str): Unique string id for the episode. """ if episode_id is None: - if self._cur_default_episode_id: - raise ValueError( - "An existing episode is still active. You must pass " - "`episode_id` if there are going to be multiple active " - "episodes at once.") - episode_id = "default_{}".format(self._num_episodes) - self._cur_default_episode_id = episode_id - self._num_episodes += 1 + episode_id = uuid.uuid4().hex if episode_id in self._finished: raise ValueError( @@ -98,14 +92,16 @@ class ServingEnv(threading.Thread): "Episode {} is already started".format(episode_id)) self._episodes[episode_id] = _Episode( - episode_id, self._results_avail_condition) + episode_id, self._results_avail_condition, training_enabled) - def get_action(self, observation, episode_id=None): + return episode_id + + def get_action(self, episode_id, observation): """Record an observation and get the on-policy action. Arguments: + episode_id (str): Episode id returned from start_episode(). observation (obj): Current environment observation. - episode_id (str): Episode id passed to start_episode() or None. Returns: action (obj): Action from the env action space. @@ -114,19 +110,19 @@ class ServingEnv(threading.Thread): episode = self._get(episode_id) return episode.wait_for_action(observation) - def log_action(self, observation, action, episode_id=None): + def log_action(self, episode_id, observation, action): """Record an observation and (off-policy) action taken. Arguments: + episode_id (str): Episode id returned from start_episode(). observation (obj): Current environment observation. action (obj): Action for the observation. - episode_id (str): Episode id passed to start_episode() or None. """ episode = self._get(episode_id) episode.log_action(observation, action) - def log_returns(self, reward, info=None, episode_id=None): + def log_returns(self, episode_id, reward, info=None): """Record returns from the environment. The reward will be attributed to the previous action taken by the @@ -134,34 +130,31 @@ class ServingEnv(threading.Thread): logged before the next action, a reward of 0.0 is assumed. Arguments: - episode_id (str): Episode id passed to start_episode() or None. + episode_id (str): Episode id returned from start_episode(). reward (float): Reward from the environment. + info (dict): Optional info dict. """ episode = self._get(episode_id) episode.cur_reward += reward if info: - episode.cur_info = info + episode.cur_info = info or {} - def end_episode(self, observation, episode_id=None): + def end_episode(self, episode_id, observation): """Record the end of an episode. Arguments: - episode_id (str): Episode id passed by start_episode() or None. + episode_id (str): Episode id returned from start_episode(). observation (obj): Current environment observation. """ episode = self._get(episode_id) self._finished.add(episode.episode_id) - self._cur_default_episode_id = None episode.done(observation) - def _get(self, episode_id=None): + def _get(self, episode_id): """Get a started episode or raise an error.""" - if episode_id is None: - episode_id = self._cur_default_episode_id - if episode_id in self._finished: raise ValueError( "Episode {} has already completed.".format(episode_id)) @@ -217,9 +210,10 @@ class _ServingEnvToAsync(AsyncVectorEnv): class _Episode(object): """Tracked state for each active episode.""" - def __init__(self, episode_id, results_avail_condition): + def __init__(self, episode_id, results_avail_condition, training_enabled): self.episode_id = episode_id self.results_avail_condition = results_avail_condition + self.training_enabled = training_enabled self.data_queue = queue.Queue() self.action_queue = queue.Queue() self.new_observation = None @@ -258,6 +252,8 @@ class _Episode(object): } if self.new_action is not None: item["off_policy_action"] = self.new_action + if not self.training_enabled: + item["info"]["training_enabled"] = False self.new_observation = None self.new_action = None self.cur_reward = 0.0 diff --git a/test/jenkins_tests/run_multi_node_tests.sh b/test/jenkins_tests/run_multi_node_tests.sh index bcdc5cfcc..d29077ac6 100755 --- a/test/jenkins_tests/run_multi_node_tests.sh +++ b/test/jenkins_tests/run_multi_node_tests.sh @@ -253,7 +253,7 @@ docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ --smoke-test docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ - python /ray/python/ray/rllib/examples/multiagent_mountaincar.py + python /ray/python/ray/rllib/examples/legacy_multiagent/multiagent_mountaincar.py docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ - python /ray/python/ray/rllib/examples/multiagent_pendulum.py + python /ray/python/ray/rllib/examples/legacy_multiagent/multiagent_pendulum.py