mirror of
https://github.com/wassname/ray.git
synced 2026-06-30 06:30:33 +08:00
[rllib] rollout.py should reduce num workers (#3263)
## What do these changes do? Don't create an excessive amount of workers for rollout.py, and also fix up the env wrapping to be consistent with the internal agent wrapper. ## Related issue number Closes #3260.
This commit is contained in:
@@ -199,9 +199,9 @@ There is a full example of this in the `example training script <https://github.
|
||||
Implementing a Centralized Critic
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
Implementing a shared critic between multiple policies requires the definition of custom policy graphs. It can be done as follows:
|
||||
Implementing a centralized critic that takes as input the observations and actions of other concurrent agents requires the definition of custom policy graphs. It can be done as follows:
|
||||
|
||||
1. Querying the critic: this can be done in the ``postprocess_trajectory`` method of a custom policy graph, which has full access to the policies and observations of concurrent agents via the ``other_agent_batches`` and ``episode`` arguments. This assumes you use variable sharing to access the critic network from multiple policies. The critic predictions can then be added to the postprocessed trajectory. Here's an example:
|
||||
1. Querying the critic: this can be done in the ``postprocess_trajectory`` method of a custom policy graph, which has full access to the policies and observations of concurrent agents via the ``other_agent_batches`` and ``episode`` arguments. The batch of critic predictions can then be added to the postprocessed trajectory. Here's an example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
@@ -212,16 +212,11 @@ Implementing a shared critic between multiple policies requires the definition o
|
||||
axis=1)
|
||||
# add the global obs and global critic value
|
||||
sample_batch["global_obs"] = global_obs_batch
|
||||
sample_batch["global_vf"] = self.sess.run(
|
||||
self.global_critic_network, feed_dict={"obs": global_obs_batch})
|
||||
# metrics like "global reward" can be retrieved from the info return of the environment
|
||||
sample_batch["global_rewards"] = [
|
||||
info["global_reward"] for info in sample_batch["infos"]]
|
||||
sample_batch["central_vf"] = self.sess.run(
|
||||
self.critic_network, feed_dict={"obs": global_obs_batch})
|
||||
return sample_batch
|
||||
|
||||
2. Updating the critic: the centralized critic loss can be added to the loss of some arbitrary policy graph. The policy graph that is chosen must add the inputs for the critic loss to its postprocessed trajectory batches.
|
||||
|
||||
For an example of defining loss inputs, see the `PGPolicyGraph example <https://github.com/ray-project/ray/blob/master/python/ray/rllib/agents/pg/pg_policy_graph.py>`__.
|
||||
2. Updating the critic: the centralized critic loss can be added to the loss of the custom policy graph, the same as with any other value function. For an example of defining loss inputs, see the `PGPolicyGraph example <https://github.com/ray-project/ray/blob/master/python/ray/rllib/agents/pg/pg_policy_graph.py>`__.
|
||||
|
||||
Agent-Driven
|
||||
------------
|
||||
|
||||
@@ -5,7 +5,7 @@ RLlib is an open-source library for reinforcement learning that offers both a co
|
||||
|
||||
.. image:: rllib-stack.svg
|
||||
|
||||
Learn more about RLlib's design by reading the `ICML paper <https://arxiv.org/abs/1712.09381>`__.
|
||||
RLlib is built on `Ray <https://github.com/ray-project/ray>`__. Learn more about RLlib's design by reading the `ICML paper <https://arxiv.org/abs/1712.09381>`__.
|
||||
|
||||
Installation
|
||||
------------
|
||||
|
||||
@@ -1,19 +0,0 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from ray.rllib.models import ModelCatalog
|
||||
from ray.rllib.env.atari_wrappers import wrap_deepmind
|
||||
|
||||
|
||||
def wrap_dqn(env, options):
|
||||
"""Apply a common set of wrappers for DQN."""
|
||||
|
||||
is_atari = hasattr(env.unwrapped, "ale")
|
||||
|
||||
# Override atari default to use the deepmind wrappers.
|
||||
# TODO(ekl) this logic should be pushed to the catalog.
|
||||
if is_atari and not options.get("custom_preprocessor"):
|
||||
return wrap_deepmind(env, dim=options.get("dim", 84))
|
||||
|
||||
return ModelCatalog.get_preprocessor_as_wrapper(env, options)
|
||||
@@ -12,7 +12,6 @@ import pickle
|
||||
import gym
|
||||
import ray
|
||||
from ray.rllib.agents.agent import get_agent_class
|
||||
from ray.rllib.agents.dqn.common.wrappers import wrap_dqn
|
||||
from ray.rllib.models import ModelCatalog
|
||||
|
||||
EXAMPLE_USAGE = """
|
||||
@@ -66,7 +65,8 @@ def create_parser(parser_creator=None):
|
||||
|
||||
|
||||
def run(args, parser):
|
||||
if not args.config:
|
||||
config = args.config
|
||||
if not config:
|
||||
# Load configuration from file
|
||||
config_dir = os.path.dirname(args.checkpoint)
|
||||
config_path = os.path.join(config_dir, "params.json")
|
||||
@@ -77,23 +77,24 @@ def run(args, parser):
|
||||
"Could not find params.json in either the checkpoint dir or "
|
||||
"its parent directory.")
|
||||
with open(config_path) as f:
|
||||
args.config = json.load(f)
|
||||
config = json.load(f)
|
||||
if "num_workers" in config:
|
||||
config["num_workers"] = min(2, config["num_workers"])
|
||||
|
||||
if not args.env:
|
||||
if not args.config.get("env"):
|
||||
if not config.get("env"):
|
||||
parser.error("the following arguments are required: --env")
|
||||
args.env = args.config.get("env")
|
||||
args.env = config.get("env")
|
||||
|
||||
ray.init()
|
||||
|
||||
cls = get_agent_class(args.run)
|
||||
agent = cls(env=args.env, config=args.config)
|
||||
agent = cls(env=args.env, config=config)
|
||||
agent.restore(args.checkpoint)
|
||||
num_steps = int(args.steps)
|
||||
|
||||
if args.run == "DQN":
|
||||
env = gym.make(args.env)
|
||||
env = wrap_dqn(env, args.config.get("model", {}))
|
||||
if hasattr(agent, "local_evaluator"):
|
||||
env = agent.local_evaluator.env
|
||||
else:
|
||||
env = ModelCatalog.get_preprocessor_as_wrapper(gym.make(args.env))
|
||||
if args.out is not None:
|
||||
|
||||
Reference in New Issue
Block a user