mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 22:53:20 +08:00
[rllib] rollout.py should reduce num workers (#3263)
## What do these changes do? Don't create an excessive amount of workers for rollout.py, and also fix up the env wrapping to be consistent with the internal agent wrapper. ## Related issue number Closes #3260.
This commit is contained in:
@@ -1,19 +0,0 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from ray.rllib.models import ModelCatalog
|
||||
from ray.rllib.env.atari_wrappers import wrap_deepmind
|
||||
|
||||
|
||||
def wrap_dqn(env, options):
|
||||
"""Apply a common set of wrappers for DQN."""
|
||||
|
||||
is_atari = hasattr(env.unwrapped, "ale")
|
||||
|
||||
# Override atari default to use the deepmind wrappers.
|
||||
# TODO(ekl) this logic should be pushed to the catalog.
|
||||
if is_atari and not options.get("custom_preprocessor"):
|
||||
return wrap_deepmind(env, dim=options.get("dim", 84))
|
||||
|
||||
return ModelCatalog.get_preprocessor_as_wrapper(env, options)
|
||||
@@ -12,7 +12,6 @@ import pickle
|
||||
import gym
|
||||
import ray
|
||||
from ray.rllib.agents.agent import get_agent_class
|
||||
from ray.rllib.agents.dqn.common.wrappers import wrap_dqn
|
||||
from ray.rllib.models import ModelCatalog
|
||||
|
||||
EXAMPLE_USAGE = """
|
||||
@@ -66,7 +65,8 @@ def create_parser(parser_creator=None):
|
||||
|
||||
|
||||
def run(args, parser):
|
||||
if not args.config:
|
||||
config = args.config
|
||||
if not config:
|
||||
# Load configuration from file
|
||||
config_dir = os.path.dirname(args.checkpoint)
|
||||
config_path = os.path.join(config_dir, "params.json")
|
||||
@@ -77,23 +77,24 @@ def run(args, parser):
|
||||
"Could not find params.json in either the checkpoint dir or "
|
||||
"its parent directory.")
|
||||
with open(config_path) as f:
|
||||
args.config = json.load(f)
|
||||
config = json.load(f)
|
||||
if "num_workers" in config:
|
||||
config["num_workers"] = min(2, config["num_workers"])
|
||||
|
||||
if not args.env:
|
||||
if not args.config.get("env"):
|
||||
if not config.get("env"):
|
||||
parser.error("the following arguments are required: --env")
|
||||
args.env = args.config.get("env")
|
||||
args.env = config.get("env")
|
||||
|
||||
ray.init()
|
||||
|
||||
cls = get_agent_class(args.run)
|
||||
agent = cls(env=args.env, config=args.config)
|
||||
agent = cls(env=args.env, config=config)
|
||||
agent.restore(args.checkpoint)
|
||||
num_steps = int(args.steps)
|
||||
|
||||
if args.run == "DQN":
|
||||
env = gym.make(args.env)
|
||||
env = wrap_dqn(env, args.config.get("model", {}))
|
||||
if hasattr(agent, "local_evaluator"):
|
||||
env = agent.local_evaluator.env
|
||||
else:
|
||||
env = ModelCatalog.get_preprocessor_as_wrapper(gym.make(args.env))
|
||||
if args.out is not None:
|
||||
|
||||
Reference in New Issue
Block a user