[rllib] Add a simple REST policy server and client example (#2232)

* wip

* cls

* re

* wip

* wip

* a3c working

* torch support

* pg works

* lint

* rm v2

* consumer id

* clean up pg

* clean up more

* fix python 2.7

* tf session management

* docs

* dqn wip

* fix compile

* dqn

* apex runs

* up

* impotrs

* ddpg

* quotes

* fix tests

* fix last r

* fix tests

* lint

* pass checkpoint restore

* kwar

* nits

* policy graph

* fix yapf

* com

* class

* pyt

* vectorization

* update

* test cpe

* unit test

* fix ddpg2

* changes

* wip

* args

* faster test

* common

* fix

* add alg option

* batch mode and policy serving

* multi serving test

* todo

* wip

* serving test

* doc async env

* num envs

* comments

* thread

* remove init hook

* update

* policy serve

* spaces

* checkpoint

* no train

* fix ppo

* comments1

* fix

* updates

* add jenkins tests

* fix

* fix pytorch

* fix

* fixes

* fix a3c policy

* fix squeeze

* fix trunc on apex

* fix squeezing for real

* update

* remove horizon test for now

* fix race condition

* update

* com

* updat

* add test

* Update run_multi_node_tests.sh

* use curl

* curl

* kill

* Update run_multi_node_tests.sh

* Update run_multi_node_tests.sh

* fix import

* update
This commit is contained in:
Eric Liang
2018-06-20 13:22:39 -07:00
committed by GitHub
parent 418cd6804a
commit e5724a9cfe
15 changed files with 384 additions and 78 deletions
-5
View File
@@ -1,5 +0,0 @@
# flake8: noqa
from ray.rllib.examples.multiagent_mountaincar_env \
import MultiAgentMountainCarEnv
from ray.rllib.examples.multiagent_pendulum_env \
import MultiAgentPendulumEnv
@@ -21,7 +21,9 @@ def pass_params_to_gym(env_name):
register(
id=env_name,
entry_point='ray.rllib.examples:' + "MultiAgentMountainCarEnv",
entry_point=(
"ray.rllib.examples.legacy_multiagent.multiagent_mountaincar_env:"
"MultiAgentMountainCarEnv"),
max_episode_steps=200,
kwargs={}
)
@@ -21,7 +21,9 @@ def pass_params_to_gym(env_name):
register(
id=env_name,
entry_point='ray.rllib.examples:' + "MultiAgentPendulumEnv",
entry_point=(
"ray.rllib.examples.legacy_multiagent.multiagent_pendulum_env:"
"MultiAgentPendulumEnv"),
max_episode_steps=100,
kwargs={}
)
+55
View File
@@ -0,0 +1,55 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
"""Example of querying a policy server. Copy this file for your use case.
To try this out, in two separate shells run:
$ python cartpole_server.py
$ python cartpole_client.py
"""
import argparse
import gym
from ray.rllib.utils.policy_client import PolicyClient
parser = argparse.ArgumentParser()
parser.add_argument(
"--no-train", action="store_true", help="Whether to disable training.")
parser.add_argument(
"--off-policy", action="store_true",
help="Whether to take random instead of on-policy actions.")
parser.add_argument(
"--stop-at-reward", type=int, default=9999,
help="Stop once the specified reward is reached.")
if __name__ == "__main__":
args = parser.parse_args()
env = gym.make("CartPole-v0")
client = PolicyClient("http://localhost:8900")
eid = client.start_episode(training_enabled=not args.no_train)
obs = env.reset()
rewards = 0
while True:
if args.off_policy:
action = env.action_space.sample()
client.log_action(eid, obs, action)
else:
action = client.get_action(eid, obs)
obs, reward, done, info = env.step(action)
rewards += reward
client.log_returns(eid, reward, info=info)
if done:
print("Total reward:", rewards)
if rewards >= args.stop_at_reward:
print("Target reward achieved, exiting")
exit(0)
rewards = 0
client.end_episode(eid, obs)
obs = env.reset()
eid = client.start_episode(training_enabled=not args.no_train)
+66
View File
@@ -0,0 +1,66 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
"""Example of running a policy server. Copy this file for your use case.
To try this out, in two separate shells run:
$ python cartpole_server.py
$ python cartpole_client.py
"""
import os
from gym import spaces
import ray
from ray.rllib.dqn import DQNAgent
from ray.rllib.utils.serving_env import ServingEnv
from ray.rllib.utils.policy_server import PolicyServer
from ray.tune.logger import pretty_print
from ray.tune.registry import register_env
SERVER_ADDRESS = "localhost"
SERVER_PORT = 8900
CHECKPOINT_FILE = "last_checkpoint.out"
class CartpoleServing(ServingEnv):
def __init__(self):
ServingEnv.__init__(
self, spaces.Discrete(2), spaces.Box(low=-10, high=10, shape=(4,)))
def run(self):
print("Starting policy server at {}:{}".format(
SERVER_ADDRESS, SERVER_PORT))
server = PolicyServer(self, SERVER_ADDRESS, SERVER_PORT)
server.serve_forever()
if __name__ == "__main__":
ray.init()
register_env("srv", lambda _: CartpoleServing())
# We use DQN since it supports off-policy actions, but you can choose and
# configure any agent.
dqn = DQNAgent(env="srv", config={
# Use a single process to avoid needing to set up a load balancer
"num_workers": 0,
# Configure the agent to run short iterations for debugging
"exploration_fraction": 0.01,
"learning_starts": 100,
"timesteps_per_iteration": 200,
})
# Attempt to restore from checkpoint if possible.
if os.path.exists(CHECKPOINT_FILE):
checkpoint_path = open(CHECKPOINT_FILE).read()
print("Restoring from checkpoint path", checkpoint_path)
dqn.restore(checkpoint_path)
# Serving and training loop
while True:
print(pretty_print(dqn.train()))
checkpoint_path = dqn.save()
print("Last checkpoint", checkpoint_path)
with open(CHECKPOINT_FILE, "w") as f:
f.write(checkpoint_path)
+12
View File
@@ -0,0 +1,12 @@
#!/bin/bash
pkill -f cartpole_server.py
(python cartpole_server.py 2>&1 | grep -v 200) &
pid=$!
while ! curl localhost:8900; do
sleep 1
done
python cartpole_client.py --stop-at-reward=100
kill $pid