mirror of
https://github.com/wassname/ray.git
synced 2026-07-02 05:52:36 +08:00
@@ -1,7 +1,6 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
"""Example of querying a policy server. Copy this file for your use case.
|
||||
|
||||
To try this out, in two separate shells run:
|
||||
@@ -14,18 +13,19 @@ import gym
|
||||
|
||||
from ray.rllib.utils.policy_client import PolicyClient
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--no-train", action="store_true", help="Whether to disable training.")
|
||||
parser.add_argument(
|
||||
"--off-policy", action="store_true",
|
||||
"--off-policy",
|
||||
action="store_true",
|
||||
help="Whether to take random instead of on-policy actions.")
|
||||
parser.add_argument(
|
||||
"--stop-at-reward", type=int, default=9999,
|
||||
"--stop-at-reward",
|
||||
type=int,
|
||||
default=9999,
|
||||
help="Stop once the specified reward is reached.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
env = gym.make("CartPole-v0")
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
"""Example of running a policy server. Copy this file for your use case.
|
||||
|
||||
To try this out, in two separate shells run:
|
||||
@@ -26,12 +25,12 @@ CHECKPOINT_FILE = "last_checkpoint.out"
|
||||
|
||||
class CartpoleServing(ServingEnv):
|
||||
def __init__(self):
|
||||
ServingEnv.__init__(
|
||||
self, spaces.Discrete(2), spaces.Box(low=-10, high=10, shape=(4,)))
|
||||
ServingEnv.__init__(self, spaces.Discrete(2),
|
||||
spaces.Box(low=-10, high=10, shape=(4, )))
|
||||
|
||||
def run(self):
|
||||
print("Starting policy server at {}:{}".format(
|
||||
SERVER_ADDRESS, SERVER_PORT))
|
||||
print("Starting policy server at {}:{}".format(SERVER_ADDRESS,
|
||||
SERVER_PORT))
|
||||
server = PolicyServer(self, SERVER_ADDRESS, SERVER_PORT)
|
||||
server.serve_forever()
|
||||
|
||||
@@ -42,14 +41,16 @@ if __name__ == "__main__":
|
||||
|
||||
# We use DQN since it supports off-policy actions, but you can choose and
|
||||
# configure any agent.
|
||||
dqn = DQNAgent(env="srv", config={
|
||||
# Use a single process to avoid needing to set up a load balancer
|
||||
"num_workers": 0,
|
||||
# Configure the agent to run short iterations for debugging
|
||||
"exploration_fraction": 0.01,
|
||||
"learning_starts": 100,
|
||||
"timesteps_per_iteration": 200,
|
||||
})
|
||||
dqn = DQNAgent(
|
||||
env="srv",
|
||||
config={
|
||||
# Use a single process to avoid needing to set up a load balancer
|
||||
"num_workers": 0,
|
||||
# Configure the agent to run short iterations for debugging
|
||||
"exploration_fraction": 0.01,
|
||||
"learning_starts": 100,
|
||||
"timesteps_per_iteration": 200,
|
||||
})
|
||||
|
||||
# Attempt to restore from checkpoint if possible.
|
||||
if os.path.exists(CHECKPOINT_FILE):
|
||||
|
||||
Reference in New Issue
Block a user