mirror of
https://github.com/wassname/ray.git
synced 2026-07-02 06:12:09 +08:00
[rllib] Add a simple REST policy server and client example (#2232)
* wip * cls * re * wip * wip * a3c working * torch support * pg works * lint * rm v2 * consumer id * clean up pg * clean up more * fix python 2.7 * tf session management * docs * dqn wip * fix compile * dqn * apex runs * up * impotrs * ddpg * quotes * fix tests * fix last r * fix tests * lint * pass checkpoint restore * kwar * nits * policy graph * fix yapf * com * class * pyt * vectorization * update * test cpe * unit test * fix ddpg2 * changes * wip * args * faster test * common * fix * add alg option * batch mode and policy serving * multi serving test * todo * wip * serving test * doc async env * num envs * comments * thread * remove init hook * update * policy serve * spaces * checkpoint * no train * fix ppo * comments1 * fix * updates * add jenkins tests * fix * fix pytorch * fix * fixes * fix a3c policy * fix squeeze * fix trunc on apex * fix squeezing for real * update * remove horizon test for now * fix race condition * update * com * updat * add test * Update run_multi_node_tests.sh * use curl * curl * kill * Update run_multi_node_tests.sh * Update run_multi_node_tests.sh * fix import * update
This commit is contained in:
@@ -1,5 +0,0 @@
|
||||
# flake8: noqa
|
||||
from ray.rllib.examples.multiagent_mountaincar_env \
|
||||
import MultiAgentMountainCarEnv
|
||||
from ray.rllib.examples.multiagent_pendulum_env \
|
||||
import MultiAgentPendulumEnv
|
||||
|
||||
+3
-1
@@ -21,7 +21,9 @@ def pass_params_to_gym(env_name):
|
||||
|
||||
register(
|
||||
id=env_name,
|
||||
entry_point='ray.rllib.examples:' + "MultiAgentMountainCarEnv",
|
||||
entry_point=(
|
||||
"ray.rllib.examples.legacy_multiagent.multiagent_mountaincar_env:"
|
||||
"MultiAgentMountainCarEnv"),
|
||||
max_episode_steps=200,
|
||||
kwargs={}
|
||||
)
|
||||
+3
-1
@@ -21,7 +21,9 @@ def pass_params_to_gym(env_name):
|
||||
|
||||
register(
|
||||
id=env_name,
|
||||
entry_point='ray.rllib.examples:' + "MultiAgentPendulumEnv",
|
||||
entry_point=(
|
||||
"ray.rllib.examples.legacy_multiagent.multiagent_pendulum_env:"
|
||||
"MultiAgentPendulumEnv"),
|
||||
max_episode_steps=100,
|
||||
kwargs={}
|
||||
)
|
||||
+55
@@ -0,0 +1,55 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
"""Example of querying a policy server. Copy this file for your use case.
|
||||
|
||||
To try this out, in two separate shells run:
|
||||
$ python cartpole_server.py
|
||||
$ python cartpole_client.py
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import gym
|
||||
|
||||
from ray.rllib.utils.policy_client import PolicyClient
|
||||
|
||||
|
||||
def _build_parser():
    """Create the command-line parser used by the CartPole client script.

    Flags:
        --no-train: boolean switch, False unless given.
        --off-policy: boolean switch, False unless given.
        --stop-at-reward: integer threshold, 9999 unless given.
    """
    cli = argparse.ArgumentParser()
    flag_specs = [
        ("--no-train", {
            "action": "store_true",
            "help": "Whether to disable training.",
        }),
        ("--off-policy", {
            "action": "store_true",
            "help": "Whether to take random instead of on-policy actions.",
        }),
        ("--stop-at-reward", {
            "type": int,
            "default": 9999,
            "help": "Stop once the specified reward is reached.",
        }),
    ]
    # Register the flags in the same order they appear in --help output.
    for flag, options in flag_specs:
        cli.add_argument(flag, **options)
    return cli


parser = _build_parser()
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Drive a local CartPole environment, delegating action selection and
    # experience logging to the policy server at localhost:8900.
    args = parser.parse_args()
    env = gym.make("CartPole-v0")
    client = PolicyClient("http://localhost:8900")

    # Every environment episode is paired with an episode id from the server.
    eid = client.start_episode(training_enabled=not args.no_train)
    obs = env.reset()
    rewards = 0

    while True:
        if args.off_policy:
            # Pick a random action locally and report it to the server.
            action = env.action_space.sample()
            client.log_action(eid, obs, action)
        else:
            # Ask the server for the action to take.
            action = client.get_action(eid, obs)
        obs, reward, done, info = env.step(action)
        rewards += reward
        client.log_returns(eid, reward, info=info)
        if not done:
            continue

        print("Total reward:", rewards)
        # End the finished episode *before* deciding whether to stop, so the
        # final episode is not left open on the server when the script exits.
        client.end_episode(eid, obs)
        if rewards >= args.stop_at_reward:
            print("Target reward achieved, exiting")
            # Falling off the end of the script exits with status 0 and does
            # not rely on the site-provided `exit` builtin.
            break
        rewards = 0
        obs = env.reset()
        eid = client.start_episode(training_enabled=not args.no_train)
|
||||
+66
@@ -0,0 +1,66 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
"""Example of running a policy server. Copy this file for your use case.
|
||||
|
||||
To try this out, in two separate shells run:
|
||||
$ python cartpole_server.py
|
||||
$ python cartpole_client.py
|
||||
"""
|
||||
|
||||
import os
|
||||
from gym import spaces
|
||||
|
||||
import ray
|
||||
from ray.rllib.dqn import DQNAgent
|
||||
from ray.rllib.utils.serving_env import ServingEnv
|
||||
from ray.rllib.utils.policy_server import PolicyServer
|
||||
from ray.tune.logger import pretty_print
|
||||
from ray.tune.registry import register_env
|
||||
|
||||
SERVER_ADDRESS = "localhost"
|
||||
SERVER_PORT = 8900
|
||||
CHECKPOINT_FILE = "last_checkpoint.out"
|
||||
|
||||
|
||||
class CartpoleServing(ServingEnv):
    """Serving environment that exposes a CartPole-shaped problem.

    The spaces are declared by hand: a two-way discrete space and a
    4-element box bounded by [-10, 10]. When run, the environment starts a
    PolicyServer bound to SERVER_ADDRESS:SERVER_PORT and serves forever.
    """

    def __init__(self):
        # NOTE(review): positional order is (Discrete, Box); presumably this
        # is (action space, observation space) -- confirm against ServingEnv.
        discrete_space = spaces.Discrete(2)
        box_space = spaces.Box(low=-10, high=10, shape=(4,))
        ServingEnv.__init__(self, discrete_space, box_space)

    def run(self):
        """Announce the listen address, then block serving requests."""
        banner = "Starting policy server at {}:{}".format(
            SERVER_ADDRESS, SERVER_PORT)
        print(banner)
        policy_server = PolicyServer(self, SERVER_ADDRESS, SERVER_PORT)
        policy_server.serve_forever()
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Start Ray and register the serving environment under the name "srv".
    ray.init()
    register_env("srv", lambda _: CartpoleServing())

    # We use DQN since it supports off-policy actions, but you can choose and
    # configure any agent.
    dqn = DQNAgent(env="srv", config={
        # Use a single process to avoid needing to set up a load balancer
        "num_workers": 0,
        # Configure the agent to run short iterations for debugging
        "exploration_fraction": 0.01,
        "learning_starts": 100,
        "timesteps_per_iteration": 200,
    })

    # Attempt to restore from checkpoint if possible.
    if os.path.exists(CHECKPOINT_FILE):
        # Read inside a context manager so the file handle is closed
        # promptly instead of being left to the garbage collector.
        with open(CHECKPOINT_FILE) as f:
            checkpoint_path = f.read()
        print("Restoring from checkpoint path", checkpoint_path)
        dqn.restore(checkpoint_path)

    # Serving and training loop
    while True:
        print(pretty_print(dqn.train()))
        checkpoint_path = dqn.save()
        print("Last checkpoint", checkpoint_path)
        # Record the newest checkpoint path so a restart can resume from it.
        with open(CHECKPOINT_FILE, "w") as f:
            f.write(checkpoint_path)
|
||||
Executable
+12
@@ -0,0 +1,12 @@
|
||||
#!/bin/bash
# Smoke test for the policy server example: start the server in the
# background, wait until it answers HTTP requests, run the client until it
# reaches the target reward, then shut the server down.
# Exits with the client's exit status.

# Make sure no server from a previous run is still holding the port.
pkill -f cartpole_server.py

# Run the server in the background, dropping output lines that contain
# "200" (presumably per-request HTTP 200 access logs -- confirm).
(python cartpole_server.py 2>&1 | grep -v 200) &
pid=$!  # pid of the background subshell, not of the python process itself

# Poll until the server accepts connections.
while ! curl localhost:8900; do
    sleep 1
done

python cartpole_client.py --stop-at-reward=100
status=$?  # keep the client's result; cleanup below must not mask it

kill "$pid"
# Killing the subshell may leave the python server running; clean it up the
# same way the start of this script does.
pkill -f cartpole_server.py

exit "$status"
|
||||
Reference in New Issue
Block a user