[rllib] rollout.py should reduce num workers (#3263)

## What do these changes do?

Don't create an excessive amount of workers for rollout.py, and also fix up the env wrapping to be consistent with the internal agent wrapper.

## Related issue number

Closes #3260.
This commit is contained in:
Eric Liang
2018-11-09 12:29:16 -08:00
committed by Richard Liaw
parent 22113be04c
commit 9dd3eedbac
4 changed files with 16 additions and 39 deletions
+5 -10
View File
@@ -199,9 +199,9 @@ There is a full example of this in the `example training script <https://github.
Implementing a Centralized Critic
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Implementing a shared critic between multiple policies requires the definition of custom policy graphs. It can be done as follows:
Implementing a centralized critic that takes as input the observations and actions of other concurrent agents requires the definition of custom policy graphs. It can be done as follows:
1. Querying the critic: this can be done in the ``postprocess_trajectory`` method of a custom policy graph, which has full access to the policies and observations of concurrent agents via the ``other_agent_batches`` and ``episode`` arguments. This assumes you use variable sharing to access the critic network from multiple policies. The critic predictions can then be added to the postprocessed trajectory. Here's an example:
1. Querying the critic: this can be done in the ``postprocess_trajectory`` method of a custom policy graph, which has full access to the policies and observations of concurrent agents via the ``other_agent_batches`` and ``episode`` arguments. The batch of critic predictions can then be added to the postprocessed trajectory. Here's an example:
.. code-block:: python
@@ -212,16 +212,11 @@ Implementing a shared critic between multiple policies requires the definition o
axis=1)
# add the global obs and global critic value
sample_batch["global_obs"] = global_obs_batch
sample_batch["global_vf"] = self.sess.run(
self.global_critic_network, feed_dict={"obs": global_obs_batch})
# metrics like "global reward" can be retrieved from the info return of the environment
sample_batch["global_rewards"] = [
info["global_reward"] for info in sample_batch["infos"]]
sample_batch["central_vf"] = self.sess.run(
self.critic_network, feed_dict={"obs": global_obs_batch})
return sample_batch
2. Updating the critic: the centralized critic loss can be added to the loss of some arbitrary policy graph. The policy graph that is chosen must add the inputs for the critic loss to its postprocessed trajectory batches.
For an example of defining loss inputs, see the `PGPolicyGraph example <https://github.com/ray-project/ray/blob/master/python/ray/rllib/agents/pg/pg_policy_graph.py>`__.
2. Updating the critic: the centralized critic loss can be added to the loss of the custom policy graph, the same as with any other value function. For an example of defining loss inputs, see the `PGPolicyGraph example <https://github.com/ray-project/ray/blob/master/python/ray/rllib/agents/pg/pg_policy_graph.py>`__.
Agent-Driven
------------
+1 -1
View File
@@ -5,7 +5,7 @@ RLlib is an open-source library for reinforcement learning that offers both a co
.. image:: rllib-stack.svg
Learn more about RLlib's design by reading the `ICML paper <https://arxiv.org/abs/1712.09381>`__.
RLlib is built on `Ray <https://github.com/ray-project/ray>`__. Learn more about RLlib's design by reading the `ICML paper <https://arxiv.org/abs/1712.09381>`__.
Installation
------------
@@ -1,19 +0,0 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from ray.rllib.models import ModelCatalog
from ray.rllib.env.atari_wrappers import wrap_deepmind
def wrap_dqn(env, options):
"""Apply a common set of wrappers for DQN."""
is_atari = hasattr(env.unwrapped, "ale")
# Override atari default to use the deepmind wrappers.
# TODO(ekl) this logic should be pushed to the catalog.
if is_atari and not options.get("custom_preprocessor"):
return wrap_deepmind(env, dim=options.get("dim", 84))
return ModelCatalog.get_preprocessor_as_wrapper(env, options)
+10 -9
View File
@@ -12,7 +12,6 @@ import pickle
import gym
import ray
from ray.rllib.agents.agent import get_agent_class
from ray.rllib.agents.dqn.common.wrappers import wrap_dqn
from ray.rllib.models import ModelCatalog
EXAMPLE_USAGE = """
@@ -66,7 +65,8 @@ def create_parser(parser_creator=None):
def run(args, parser):
if not args.config:
config = args.config
if not config:
# Load configuration from file
config_dir = os.path.dirname(args.checkpoint)
config_path = os.path.join(config_dir, "params.json")
@@ -77,23 +77,24 @@ def run(args, parser):
"Could not find params.json in either the checkpoint dir or "
"its parent directory.")
with open(config_path) as f:
args.config = json.load(f)
config = json.load(f)
if "num_workers" in config:
config["num_workers"] = min(2, config["num_workers"])
if not args.env:
if not args.config.get("env"):
if not config.get("env"):
parser.error("the following arguments are required: --env")
args.env = args.config.get("env")
args.env = config.get("env")
ray.init()
cls = get_agent_class(args.run)
agent = cls(env=args.env, config=args.config)
agent = cls(env=args.env, config=config)
agent.restore(args.checkpoint)
num_steps = int(args.steps)
if args.run == "DQN":
env = gym.make(args.env)
env = wrap_dqn(env, args.config.get("model", {}))
if hasattr(agent, "local_evaluator"):
env = agent.local_evaluator.env
else:
env = ModelCatalog.get_preprocessor_as_wrapper(gym.make(args.env))
if args.out is not None: