[rllib] rollout.py should reduce num workers (#3263)

## What do these changes do?

Don't create an excessive amount of workers for rollout.py, and also fix up the env wrapping to be consistent with the internal agent wrapper.

## Related issue number

Closes #3260.
This commit is contained in:
Eric Liang
2018-11-09 12:29:16 -08:00
committed by Richard Liaw
parent 22113be04c
commit 9dd3eedbac
4 changed files with 16 additions and 39 deletions
@@ -1,19 +0,0 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from ray.rllib.models import ModelCatalog
from ray.rllib.env.atari_wrappers import wrap_deepmind
def wrap_dqn(env, options):
"""Apply a common set of wrappers for DQN."""
is_atari = hasattr(env.unwrapped, "ale")
# Override atari default to use the deepmind wrappers.
# TODO(ekl) this logic should be pushed to the catalog.
if is_atari and not options.get("custom_preprocessor"):
return wrap_deepmind(env, dim=options.get("dim", 84))
return ModelCatalog.get_preprocessor_as_wrapper(env, options)
+10 -9
View File
@@ -12,7 +12,6 @@ import pickle
import gym
import ray
from ray.rllib.agents.agent import get_agent_class
from ray.rllib.agents.dqn.common.wrappers import wrap_dqn
from ray.rllib.models import ModelCatalog
EXAMPLE_USAGE = """
@@ -66,7 +65,8 @@ def create_parser(parser_creator=None):
def run(args, parser):
if not args.config:
config = args.config
if not config:
# Load configuration from file
config_dir = os.path.dirname(args.checkpoint)
config_path = os.path.join(config_dir, "params.json")
@@ -77,23 +77,24 @@ def run(args, parser):
"Could not find params.json in either the checkpoint dir or "
"its parent directory.")
with open(config_path) as f:
args.config = json.load(f)
config = json.load(f)
if "num_workers" in config:
config["num_workers"] = min(2, config["num_workers"])
if not args.env:
if not args.config.get("env"):
if not config.get("env"):
parser.error("the following arguments are required: --env")
args.env = args.config.get("env")
args.env = config.get("env")
ray.init()
cls = get_agent_class(args.run)
agent = cls(env=args.env, config=args.config)
agent = cls(env=args.env, config=config)
agent.restore(args.checkpoint)
num_steps = int(args.steps)
if args.run == "DQN":
env = gym.make(args.env)
env = wrap_dqn(env, args.config.get("model", {}))
if hasattr(agent, "local_evaluator"):
env = agent.local_evaluator.env
else:
env = ModelCatalog.get_preprocessor_as_wrapper(gym.make(args.env))
if args.out is not None: