mirror of
https://github.com/wassname/ray.git
synced 2026-07-01 07:09:33 +08:00
f4b313eaad
* Moved clip_action into policy_graph; Clip actions in compute_single_action * Update policy_graph.py * Changed formatting * Updated codebase for convenience
182 lines
6.4 KiB
Python
Executable File
182 lines
6.4 KiB
Python
Executable File
#!/usr/bin/env python
|
|
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
import argparse
|
|
import json
|
|
import os
|
|
import pickle
|
|
|
|
import gym
|
|
import ray
|
|
from ray.rllib.agents.registry import get_agent_class
|
|
from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID
|
|
from ray.tune.util import merge_dicts
|
|
|
|
# Usage examples for this script; passed as the `epilog` of the argument
# parser built in create_parser(), so it is shown at the end of --help output.
EXAMPLE_USAGE = """
|
|
Example Usage via RLlib CLI:
|
|
rllib rollout /tmp/ray/checkpoint_dir/checkpoint-0 --run DQN
|
|
--env CartPole-v0 --steps 1000000 --out rollouts.pkl
|
|
|
|
Example Usage via executable:
|
|
./rollout.py /tmp/ray/checkpoint_dir/checkpoint-0 --run DQN
|
|
--env CartPole-v0 --steps 1000000 --out rollouts.pkl
|
|
"""
|
|
|
|
# Note: if you use any custom models or envs, register them here first, e.g.:
|
|
#
|
|
# ModelCatalog.register_custom_model("pa_model", ParametricActionsModel)
|
|
# register_env("pa_cartpole", lambda _: ParametricActionCartpole(10))
|
|
|
|
|
|
def create_parser(parser_creator=None):
    """Build the command-line parser for the rollout script.

    Args:
        parser_creator: Optional callable used instead of
            ``argparse.ArgumentParser`` to construct the parser. It is
            called with the ``formatter_class``, ``description`` and
            ``epilog`` keyword arguments.

    Returns:
        The parser, with the positional ``checkpoint`` argument and the
        ``--run``, ``--env``, ``--no-render``, ``--steps``, ``--out`` and
        ``--config`` options registered.
    """
    parser_creator = parser_creator or argparse.ArgumentParser
    parser = parser_creator(
        # Raw formatter keeps the line breaks of EXAMPLE_USAGE in --help.
        formatter_class=argparse.RawDescriptionHelpFormatter,
        description="Roll out a reinforcement learning agent "
        "given a checkpoint.",
        epilog=EXAMPLE_USAGE)

    parser.add_argument(
        "checkpoint", type=str, help="Checkpoint from which to roll out.")

    required_named = parser.add_argument_group("required named arguments")
    required_named.add_argument(
        "--run",
        type=str,
        required=True,
        help="The algorithm or model to train. This may refer to the name "
        "of a built-on algorithm (e.g. RLLib's DQN or PPO), or a "
        "user-defined trainable function or class registered in the "
        "tune registry.")
    # Listed with the required arguments but not enforced by argparse:
    # run() falls back to the "env" entry of the configuration when the
    # flag is omitted.
    required_named.add_argument(
        "--env", type=str, help="The gym environment to use.")

    parser.add_argument(
        "--no-render",
        default=False,
        action="store_true",
        help="Surpress rendering of the environment.")
    # type=int makes argparse reject non-integer values with a usage error
    # and guarantees args.steps is an int whether or not the flag is given.
    parser.add_argument(
        "--steps",
        default=10000,
        type=int,
        help="Number of steps to roll out.")
    parser.add_argument("--out", default=None, help="Output filename.")
    # The string default "{}" is passed through json.loads by argparse, so
    # args.config is an empty dict when the flag is not given.
    parser.add_argument(
        "--config",
        default="{}",
        type=json.loads,
        help="Algorithm-specific configuration (e.g. env, hyperparams). "
        "Surpresses loading of configuration from checkpoint.")
    return parser
|
|
|
|
|
|
def run(args, parser):
    """Restore an agent from a checkpoint and roll it out.

    The agent configuration is read from ``params.pkl`` located next to the
    checkpoint file or, failing that, in the parent directory. Values given
    via ``args.config`` are merged on top of the loaded configuration.

    Args:
        args: Parsed command-line arguments; reads ``checkpoint``,
            ``config``, ``env``, ``run``, ``steps``, ``out`` and
            ``no_render`` (``env`` may be filled in from the configuration).
        parser: The argument parser; its ``error`` method is used to report
            a missing environment.

    Raises:
        ValueError: If no ``params.pkl`` is found and ``args.config`` is
            empty.
    """
    checkpoint_dir = os.path.dirname(args.checkpoint)
    params_path = os.path.join(checkpoint_dir, "params.pkl")
    if not os.path.exists(params_path):
        params_path = os.path.join(checkpoint_dir, "../params.pkl")

    config = {}
    if os.path.exists(params_path):
        # NOTE(review): unpickling executes arbitrary code embedded in the
        # file; only roll out checkpoints that come from a trusted source.
        with open(params_path, 'rb') as params_file:
            config = pickle.load(params_file)
    elif not args.config:
        raise ValueError(
            "Could not find params.pkl in either the checkpoint dir or "
            "its parent directory.")

    # Use at most two workers for the rollout, whatever the stored setting.
    if "num_workers" in config:
        config["num_workers"] = min(2, config["num_workers"])
    config = merge_dicts(config, args.config)

    if not args.env:
        env_from_config = config.get("env")
        if not env_from_config:
            parser.error("the following arguments are required: --env")
        args.env = env_from_config

    ray.init()

    agent_cls = get_agent_class(args.run)
    agent = agent_cls(env=args.env, config=config)
    agent.restore(args.checkpoint)
    rollout(agent, args.env, int(args.steps), args.out, args.no_render)
|
|
|
|
|
|
def rollout(agent, env_name, num_steps, out=None, no_render=True):
    """Step an agent through its environment, optionally recording the result.

    Episodes are run back to back until ``num_steps`` environment steps have
    been taken in total; the last episode is cut short when the budget runs
    out. The total reward of every episode is printed.

    Args:
        agent: Object exposing ``compute_action``. If it has a
            ``local_evaluator`` attribute, the environment, the multi-agent
            flag and the policy map are taken from it.
        env_name: Environment id passed to ``gym.make``; only used when
            ``agent`` has no ``local_evaluator``.
        num_steps: Maximum number of environment steps over all episodes.
            A falsy value (``0`` or ``None``) removes the limit, in which
            case the function never returns.
        out: Optional path. When given, a list of episodes is pickled to
            this file; each episode is a list of
            ``[state, action, next_state, reward, done]`` entries.
        no_render: If false, ``env.render()`` is called after every step.
    """
    if hasattr(agent, "local_evaluator"):
        evaluator = agent.local_evaluator
        env = evaluator.env
        multiagent = evaluator.multiagent
        if multiagent:
            policy_agent_mapping = agent.config["multiagent"][
                "policy_mapping_fn"]
            # Remembers which policy each agent id was mapped to.
            mapping_cache = {}
        policy_map = evaluator.policy_map
        # A policy counts as recurrent when its initial state is non-empty.
        use_lstm = {
            policy_id: len(policy.get_initial_state()) > 0
            for policy_id, policy in policy_map.items()
        }
    else:
        env = gym.make(env_name)
        multiagent = False
        policy_map = {}
        use_lstm = {DEFAULT_POLICY_ID: False}

    recording = out is not None
    rollouts = []
    steps = 0
    # `num_steps or steps + 1` keeps the condition true when no limit is set.
    while steps < (num_steps or steps + 1):
        episode = []
        # Start every episode from the policies' initial recurrent state so
        # that hidden state does not leak across env.reset().
        policy_states = {
            policy_id: policy.get_initial_state()
            for policy_id, policy in policy_map.items()
        }
        state = env.reset()
        done = False
        reward_total = 0.0
        while not done and steps < (num_steps or steps + 1):
            if multiagent:
                action = {}
                for agent_id, a_state in state.items():
                    if a_state is None:
                        continue
                    policy_id = mapping_cache.setdefault(
                        agent_id, policy_agent_mapping(agent_id))
                    if use_lstm[policy_id]:
                        a_action, next_policy_state, _ = (
                            agent.compute_action(
                                a_state,
                                state=policy_states[policy_id],
                                policy_id=policy_id))
                        policy_states[policy_id] = next_policy_state
                    else:
                        a_action = agent.compute_action(
                            a_state, policy_id=policy_id)
                    action[agent_id] = a_action
            elif use_lstm[DEFAULT_POLICY_ID]:
                # Pass only the default policy's state, not the whole
                # per-policy dict, mirroring the multi-agent branch.
                action, next_policy_state, _ = agent.compute_action(
                    state, state=policy_states[DEFAULT_POLICY_ID])
                policy_states[DEFAULT_POLICY_ID] = next_policy_state
            else:
                action = agent.compute_action(state)

            next_state, reward, done, _ = env.step(action)

            if multiagent:
                # Multi-agent envs report per-agent dicts; "__all__" marks
                # the end of the episode for every agent.
                done = done["__all__"]
                reward_total += sum(reward.values())
            else:
                reward_total += reward
            if not no_render:
                env.render()
            if recording:
                episode.append([state, action, next_state, reward, done])
            steps += 1
            state = next_state
        if recording:
            rollouts.append(episode)
        print("Episode reward", reward_total)

    if recording:
        # Context manager guarantees the file is flushed and closed.
        with open(out, "wb") as out_file:
            pickle.dump(rollouts, out_file)
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: parse the command line and hand the parsed
    # arguments, together with the parser for error reporting, to run().
    cli_parser = create_parser()
    run(cli_parser.parse_args(), cli_parser)
|