mirror of
https://github.com/wassname/ray.git
synced 2026-07-01 22:23:13 +08:00
[rllib] Rename sample_batch_size => rollout_fragment_length (#7503)
* bulk rename * deprecation warn * update doc * update fig * line length * rename * make pytest comptaible * fix test * fi sys * rename * wip * fix more * lint * update svg * comments * lint * fix use of batch steps
This commit is contained in:
@@ -40,7 +40,7 @@ run_experiments({
|
||||
"num_gpus": 0,
|
||||
"buffer_size": 10000,
|
||||
"learning_starts": 0,
|
||||
"sample_batch_size": 1,
|
||||
"rollout_fragment_length": 1,
|
||||
"train_batch_size": 1,
|
||||
"min_iter_time_s": 10,
|
||||
"timesteps_per_iteration": 10,
|
||||
|
||||
@@ -41,7 +41,7 @@ run_experiments({
|
||||
"num_envs_per_worker": 5,
|
||||
"remote_worker_envs": True,
|
||||
"remote_env_batch_wait_ms": 99999999,
|
||||
"sample_batch_size": 50,
|
||||
"rollout_fragment_length": 50,
|
||||
"train_batch_size": 100,
|
||||
},
|
||||
},
|
||||
|
||||
@@ -12,7 +12,7 @@ atari-impala:
|
||||
stop:
|
||||
time_total_s: 3600
|
||||
config:
|
||||
sample_batch_size: 50
|
||||
rollout_fragment_length: 50
|
||||
train_batch_size: 500
|
||||
num_workers: 10
|
||||
num_envs_per_worker: 5
|
||||
@@ -36,7 +36,7 @@ atari-ppo-tf:
|
||||
vf_clip_param: 10.0
|
||||
entropy_coeff: 0.01
|
||||
train_batch_size: 5000
|
||||
sample_batch_size: 100
|
||||
rollout_fragment_length: 100
|
||||
sgd_minibatch_size: 500
|
||||
num_sgd_iter: 10
|
||||
num_workers: 10
|
||||
@@ -60,7 +60,7 @@ atari-ppo-torch:
|
||||
vf_clip_param: 10.0
|
||||
entropy_coeff: 0.01
|
||||
train_batch_size: 5000
|
||||
sample_batch_size: 100
|
||||
rollout_fragment_length: 100
|
||||
sgd_minibatch_size: 500
|
||||
num_sgd_iter: 10
|
||||
num_workers: 10
|
||||
@@ -94,7 +94,7 @@ apex:
|
||||
num_gpus: 1
|
||||
num_workers: 8
|
||||
num_envs_per_worker: 8
|
||||
sample_batch_size: 20
|
||||
rollout_fragment_length: 20
|
||||
train_batch_size: 512
|
||||
target_network_update_freq: 50000
|
||||
timesteps_per_iteration: 25000
|
||||
@@ -105,7 +105,7 @@ atari-a2c:
|
||||
stop:
|
||||
time_total_s: 3600
|
||||
config:
|
||||
sample_batch_size: 20
|
||||
rollout_fragment_length: 20
|
||||
clip_rewards: True
|
||||
num_workers: 5
|
||||
num_envs_per_worker: 5
|
||||
@@ -133,7 +133,7 @@ atari-basic-dqn:
|
||||
hiddens: [512]
|
||||
learning_starts: 20000
|
||||
buffer_size: 1000000
|
||||
sample_batch_size: 4
|
||||
rollout_fragment_length: 4
|
||||
train_batch_size: 32
|
||||
exploration_config:
|
||||
epsilon_timesteps: 200000
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
# Taken from rllib/tuned_examples/atari_impala_large.yaml
|
||||
|
||||
# Runs on a g3.16xl node with 5 m5.24xl workers
|
||||
# Takes roughly 10 minutes. x10?
|
||||
atari-impala:
|
||||
@@ -13,7 +12,7 @@ atari-impala:
|
||||
stop:
|
||||
timesteps_total: 30000000
|
||||
config:
|
||||
sample_batch_size: 50
|
||||
rollout_fragment_length: 50
|
||||
train_batch_size: 500
|
||||
num_workers: 128
|
||||
num_envs_per_worker: 5
|
||||
@@ -21,4 +20,4 @@ atari-impala:
|
||||
lr_schedule: [
|
||||
[0, 0.0005],
|
||||
[20000000, 0.000000000001],
|
||||
]
|
||||
]
|
||||
|
||||
@@ -1,6 +1,35 @@
|
||||
RLlib Algorithms
|
||||
================
|
||||
|
||||
.. tip::
|
||||
|
||||
Check out the `environments <rllib-env.html>`__ page to learn more about different environment types.
|
||||
|
||||
Feature Compatibility Matrix
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
============= ======================= ================== =========== ===========================
|
||||
Algorithm Discrete Actions Continuous Multi-Agent Model Support
|
||||
============= ======================= ================== =========== ===========================
|
||||
A2C, A3C **Yes** `+parametric`_ **Yes** **Yes** `+RNN`_, `+autoreg`_
|
||||
PPO, APPO **Yes** `+parametric`_ **Yes** **Yes** `+RNN`_, `+autoreg`_
|
||||
PG **Yes** `+parametric`_ **Yes** **Yes** `+RNN`_, `+autoreg`_
|
||||
IMPALA **Yes** `+parametric`_ **Yes** **Yes** `+RNN`_, `+autoreg`_
|
||||
DQN, Rainbow **Yes** `+parametric`_ No **Yes**
|
||||
DDPG, TD3 No **Yes** **Yes**
|
||||
APEX-DQN **Yes** `+parametric`_ No **Yes**
|
||||
APEX-DDPG No **Yes** **Yes**
|
||||
SAC **Yes** **Yes** **Yes**
|
||||
ES **Yes** **Yes** No
|
||||
ARS **Yes** **Yes** No
|
||||
QMIX **Yes** No **Yes** `+RNN`_
|
||||
MARWIL **Yes** `+parametric`_ **Yes** **Yes** `+RNN`_
|
||||
============= ======================= ================== =========== ===========================
|
||||
|
||||
.. _`+parametric`: rllib-models.html#variable-length-parametric-action-spaces
|
||||
.. _`+RNN`: rllib-models.html#recurrent-models
|
||||
.. _`+autoreg`: rllib-models.html#autoregressive-action-distributions
|
||||
|
||||
High-throughput architectures
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
@@ -321,7 +350,7 @@ Soft Actor Critic (SAC)
|
||||
|
||||
SAC architecture (same as DQN)
|
||||
|
||||
RLlib's soft-actor critic implementation is ported from the `official SAC repo <https://github.com/rail-berkeley/softlearning>`__ to better integrate with RLlib APIs. Note that SAC has two fields to configure for custom models: ``policy_model`` and ``Q_model``, and currently has no support for non-continuous action distributions.
|
||||
RLlib's soft-actor critic implementation is ported from the `official SAC repo <https://github.com/rail-berkeley/softlearning>`__ to better integrate with RLlib APIs. Note that SAC has two fields to configure for custom models: ``policy_model`` and ``Q_model``.
|
||||
|
||||
Tuned examples: `Pendulum-v0 <https://github.com/ray-project/ray/blob/master/rllib/tuned_examples/regression_tests/pendulum-sac.yaml>`__, `HalfCheetah-v3 <https://github.com/ray-project/ray/blob/master/rllib/tuned_examples/halfcheetah-sac.yaml>`__
|
||||
|
||||
|
||||
@@ -233,7 +233,7 @@ The ``choose_policy_optimizer`` function chooses which `Policy Optimizer <#polic
|
||||
sgd_batch_size=config["sgd_minibatch_size"],
|
||||
num_sgd_iter=config["num_sgd_iter"],
|
||||
num_gpus=config["num_gpus"],
|
||||
sample_batch_size=config["sample_batch_size"],
|
||||
rollout_fragment_length=config["rollout_fragment_length"],
|
||||
num_envs_per_worker=config["num_envs_per_worker"],
|
||||
train_batch_size=config["train_batch_size"],
|
||||
standardize_fields=["advantages"],
|
||||
|
||||
File diff suppressed because one or more lines are too long
|
Before Width: | Height: | Size: 74 KiB After Width: | Height: | Size: 87 KiB |
@@ -3,33 +3,12 @@ RLlib Environments
|
||||
|
||||
RLlib works with several different types of environments, including `OpenAI Gym <https://gym.openai.com/>`__, user-defined, multi-agent, and also batched environments.
|
||||
|
||||
.. tip::
|
||||
|
||||
Not all environments work with all algorithms. Check out the algorithm `feature compatibility matrix <rllib-algorithms.html#feature-compatibility-matrix>`__ for more information.
|
||||
|
||||
.. image:: rllib-envs.svg
|
||||
|
||||
Feature Compatibility Matrix
|
||||
----------------------------
|
||||
|
||||
============= ======================= ================== =========== ===========================
|
||||
Algorithm Discrete Actions Continuous Multi-Agent Model Support
|
||||
============= ======================= ================== =========== ===========================
|
||||
A2C, A3C **Yes** `+parametric`_ **Yes** **Yes** `+RNN`_, `+autoreg`_
|
||||
PPO, APPO **Yes** `+parametric`_ **Yes** **Yes** `+RNN`_, `+autoreg`_
|
||||
PG **Yes** `+parametric`_ **Yes** **Yes** `+RNN`_, `+autoreg`_
|
||||
IMPALA **Yes** `+parametric`_ **Yes** **Yes** `+RNN`_, `+autoreg`_
|
||||
DQN, Rainbow **Yes** `+parametric`_ No **Yes**
|
||||
DDPG, TD3 No **Yes** **Yes**
|
||||
APEX-DQN **Yes** `+parametric`_ No **Yes**
|
||||
APEX-DDPG No **Yes** **Yes**
|
||||
SAC **Yes** **Yes** **Yes**
|
||||
ES **Yes** **Yes** No
|
||||
ARS **Yes** **Yes** No
|
||||
QMIX **Yes** No **Yes** `+RNN`_
|
||||
MARWIL **Yes** `+parametric`_ **Yes** **Yes** `+RNN`_
|
||||
============= ======================= ================== =========== ===========================
|
||||
|
||||
.. _`+parametric`: rllib-models.html#variable-length-parametric-action-spaces
|
||||
.. _`+RNN`: rllib-models.html#recurrent-models
|
||||
.. _`+autoreg`: rllib-models.html#autoregressive-action-distributions
|
||||
|
||||
Configuring Environments
|
||||
------------------------
|
||||
|
||||
|
||||
@@ -37,7 +37,7 @@ The ``rllib train`` command (same as the ``train.py`` script in the repo) has a
|
||||
The most important options are for choosing the environment
|
||||
with ``--env`` (any OpenAI gym environment including ones registered by the user
|
||||
can be used) and for choosing the algorithm with ``--run``
|
||||
(available options are ``SAC``, ``PPO``, ``PG``, ``A2C``, ``A3C``, ``IMPALA``, ``ES``, ``DDPG``, ``DQN``, ``MARWIL``, ``APEX``, and ``APEX_DDPG``).
|
||||
(available options include ``SAC``, ``PPO``, ``PG``, ``A2C``, ``A3C``, ``IMPALA``, ``ES``, ``DDPG``, ``DQN``, ``MARWIL``, ``APEX``, and ``APEX_DDPG``).
|
||||
|
||||
Evaluating Trained Policies
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
@@ -83,6 +83,7 @@ Specifying Resources
|
||||
|
||||
You can control the degree of parallelism used by setting the ``num_workers`` hyperparameter for most algorithms. The number of GPUs the driver should use can be set via the ``num_gpus`` option. Similarly, the resource allocation to workers can be controlled via ``num_cpus_per_worker``, ``num_gpus_per_worker``, and ``custom_resources_per_worker``. The number of GPUs can be a fractional quantity to allocate only a fraction of a GPU. For example, with DQN you can pack five trainers onto one GPU by setting ``num_gpus: 0.2``.
|
||||
|
||||
.. Original image: https://docs.google.com/drawings/d/14QINFvx3grVyJyjAnjggOCEVN-Iq6pYVJ3jA2S6j8z0/edit?usp=sharing
|
||||
.. image:: rllib-config.svg
|
||||
|
||||
Common Parameters
|
||||
@@ -632,8 +633,9 @@ The following are example excerpts from different Trainers' configs
|
||||
# a) DQN: see rllib/agents/dqn/dqn.py
|
||||
"explore": True,
|
||||
"exploration_config": {
|
||||
"type": "EpsilonGreedy", # <- Exploration sub-class by name or full path to module+class
|
||||
# (e.g. “ray.rllib.utils.exploration.epsilon_greedy.EpsilonGreedy”)
|
||||
# Exploration sub-class by name or full path to module+class
|
||||
# (e.g. “ray.rllib.utils.exploration.epsilon_greedy.EpsilonGreedy”)
|
||||
"type": "EpsilonGreedy",
|
||||
# Parameters for the Exploration class' constructor:
|
||||
"initial_epsilon": 1.0,
|
||||
"final_epsilon": 0.02,
|
||||
@@ -748,6 +750,19 @@ Note that in the ``on_postprocess_traj`` callback you have full access to the tr
|
||||
* Backdating rewards to previous time steps (e.g., based on values in ``info``).
|
||||
* Adding model-based curiosity bonuses to rewards (you can train the model with a `custom model supervised loss <rllib-models.html#supervised-model-losses>`__).
|
||||
|
||||
To access the policy / model (``policy.model``) in the callbacks, note that ``info['pre_batch']`` returns a tuple where the first element is a policy and the second one is the batch itself. You can also access all the rollout worker state using the following call:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from ray.rllib.evaluation.rollout_worker import get_global_worker
|
||||
|
||||
# You can use this from any callback to get a reference to the
|
||||
# RolloutWorker running in the process, which in turn has references to
|
||||
# all the policies, etc: see rollout_worker.py for more info.
|
||||
rollout_worker = get_global_worker()
|
||||
|
||||
Policy losses are defined over the ``post_batch`` data, so you can mutate that in the callbacks to change what data the policy loss function sees.
|
||||
|
||||
Curriculum Learning
|
||||
~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
|
||||
@@ -64,7 +64,7 @@ Policies can be implemented using `any framework <https://github.com/ray-project
|
||||
Sample Batches
|
||||
~~~~~~~~~~~~~~
|
||||
|
||||
Whether running in a single process or `large cluster <rllib-training.html#specifying-resources>`__, all data interchange in RLlib is in the form of `sample batches <https://github.com/ray-project/ray/blob/master/rllib/policy/sample_batch.py>`__. Sample batches encode one or more fragments of a trajectory. Typically, RLlib collects batches of size ``sample_batch_size`` from rollout workers, and concatenates one or more of these batches into a batch of size ``train_batch_size`` that is the input to SGD.
|
||||
Whether running in a single process or `large cluster <rllib-training.html#specifying-resources>`__, all data interchange in RLlib is in the form of `sample batches <https://github.com/ray-project/ray/blob/master/rllib/policy/sample_batch.py>`__. Sample batches encode one or more fragments of a trajectory. Typically, RLlib collects batches of size ``rollout_fragment_length`` from rollout workers, and concatenates one or more of these batches into a batch of size ``train_batch_size`` that is the input to SGD.
|
||||
|
||||
A typical sample batch looks something like the following when summarized. Since all values are kept in arrays, this allows for efficient encoding and transmission across the network:
|
||||
|
||||
|
||||
+9
-9
@@ -428,7 +428,7 @@ py_test(
|
||||
"--env", "PongDeterministic-v4",
|
||||
"--run", "DQN",
|
||||
"--stop", "'{\"training_iteration\": 1}'",
|
||||
"--config", "'{\"lr\": 1e-4, \"exploration_config\": {\"epsilon_timesteps\": 200000, \"final_epsilon\": 0.01}, \"buffer_size\": 10000, \"sample_batch_size\": 4, \"learning_starts\": 10000, \"target_network_update_freq\": 1000, \"gamma\": 0.99, \"prioritized_replay\": true}'"
|
||||
"--config", "'{\"lr\": 1e-4, \"exploration_config\": {\"epsilon_timesteps\": 200000, \"final_epsilon\": 0.01}, \"buffer_size\": 10000, \"rollout_fragment_length\": 4, \"learning_starts\": 10000, \"target_network_update_freq\": 1000, \"gamma\": 0.99, \"prioritized_replay\": true}'"
|
||||
]
|
||||
)
|
||||
|
||||
@@ -549,7 +549,7 @@ py_test(
|
||||
"--run", "IMPALA",
|
||||
"--stop", "'{\"timesteps_total\": 40000}'",
|
||||
"--ray-object-store-memory=1000000000",
|
||||
"--config", "'{\"num_workers\": 1, \"num_gpus\": 0, \"num_envs_per_worker\": 32, \"sample_batch_size\": 50, \"train_batch_size\": 50, \"learner_queue_size\": 1}'"
|
||||
"--config", "'{\"num_workers\": 1, \"num_gpus\": 0, \"num_envs_per_worker\": 32, \"rollout_fragment_length\": 50, \"train_batch_size\": 50, \"learner_queue_size\": 1}'"
|
||||
]
|
||||
)
|
||||
|
||||
@@ -580,7 +580,7 @@ py_test(
|
||||
"--env", "FrozenLake-v0",
|
||||
"--run", "PG",
|
||||
"--stop", "'{\"training_iteration\": 1}'",
|
||||
"--config", "'{\"sample_batch_size\": 500, \"num_workers\": 1}'"
|
||||
"--config", "'{\"rollout_fragment_length\": 500, \"num_workers\": 1}'"
|
||||
]
|
||||
)
|
||||
|
||||
@@ -594,7 +594,7 @@ py_test(
|
||||
"--env", "FrozenLake-v0",
|
||||
"--run", "PG",
|
||||
"--stop", "'{\"training_iteration\": 1}'",
|
||||
"--config", "'{\"sample_batch_size\": 500, \"num_workers\": 1}'"
|
||||
"--config", "'{\"rollout_fragment_length\": 500, \"num_workers\": 1}'"
|
||||
]
|
||||
)
|
||||
|
||||
@@ -607,7 +607,7 @@ py_test(
|
||||
"--env", "CartPole-v0",
|
||||
"--run", "PG",
|
||||
"--stop", "'{\"training_iteration\": 1}'",
|
||||
"--config", "'{\"sample_batch_size\": 500, \"num_workers\": 1}'"
|
||||
"--config", "'{\"rollout_fragment_length\": 500, \"num_workers\": 1}'"
|
||||
]
|
||||
)
|
||||
|
||||
@@ -620,7 +620,7 @@ py_test(
|
||||
"--env", "CartPole-v0",
|
||||
"--run", "PG",
|
||||
"--stop", "'{\"training_iteration\": 1}'",
|
||||
"--config", "'{\"sample_batch_size\": 500}'"
|
||||
"--config", "'{\"rollout_fragment_length\": 500}'"
|
||||
]
|
||||
)
|
||||
|
||||
@@ -632,7 +632,7 @@ py_test(
|
||||
"--env", "CartPole-v0",
|
||||
"--run", "PG",
|
||||
"--stop", "'{\"training_iteration\": 1}'",
|
||||
"--config", "'{\"sample_batch_size\": 500, \"num_workers\": 1, \"model\": {\"use_lstm\": true, \"max_seq_len\": 100}}'"
|
||||
"--config", "'{\"rollout_fragment_length\": 500, \"num_workers\": 1, \"model\": {\"use_lstm\": true, \"max_seq_len\": 100}}'"
|
||||
]
|
||||
)
|
||||
|
||||
@@ -645,7 +645,7 @@ py_test(
|
||||
"--env", "CartPole-v0",
|
||||
"--run", "PG",
|
||||
"--stop", "'{\"training_iteration\": 1}'",
|
||||
"--config", "'{\"sample_batch_size\": 500, \"num_workers\": 1, \"num_envs_per_worker\": 10}'"
|
||||
"--config", "'{\"rollout_fragment_length\": 500, \"num_workers\": 1, \"num_envs_per_worker\": 10}'"
|
||||
]
|
||||
)
|
||||
|
||||
@@ -658,7 +658,7 @@ py_test(
|
||||
"--env", "Pong-v0",
|
||||
"--run", "PG",
|
||||
"--stop", "'{\"training_iteration\": 1}'",
|
||||
"--config", "'{\"sample_batch_size\": 500, \"num_workers\": 1}'"
|
||||
"--config", "'{\"rollout_fragment_length\": 500, \"num_workers\": 1}'"
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@ from ray.rllib.utils.experimental_dsl import (
|
||||
A2C_DEFAULT_CONFIG = merge_dicts(
|
||||
A3C_CONFIG,
|
||||
{
|
||||
"sample_batch_size": 20,
|
||||
"rollout_fragment_length": 20,
|
||||
"min_iter_time_s": 10,
|
||||
"sample_async": False,
|
||||
|
||||
|
||||
@@ -19,7 +19,7 @@ DEFAULT_CONFIG = with_common_config({
|
||||
# with a value function, see https://arxiv.org/pdf/1506.02438.pdf.
|
||||
"use_gae": True,
|
||||
# Size of rollout batch
|
||||
"sample_batch_size": 10,
|
||||
"rollout_fragment_length": 10,
|
||||
# GAE(gamma) parameter
|
||||
"lambda": 1.0,
|
||||
# Max global norm for each gradient calculated by worker
|
||||
@@ -35,7 +35,7 @@ DEFAULT_CONFIG = with_common_config({
|
||||
# Min time per iteration
|
||||
"min_iter_time_s": 5,
|
||||
# Workers sample async. Note that this increases the effective
|
||||
# sample_batch_size by up to 5x due to async buffering of batches.
|
||||
# rollout_fragment_length by up to 5x due to async buffering of batches.
|
||||
"sample_async": True,
|
||||
# Use the execution plan API instead of policy optimizers.
|
||||
"use_exec_api": True,
|
||||
|
||||
@@ -100,13 +100,13 @@ class Worker:
|
||||
return return_filters
|
||||
|
||||
def rollout(self, timestep_limit, add_noise=False):
|
||||
rollout_rewards, rollout_length = policies.rollout(
|
||||
rollout_rewards, rollout_fragment_length = policies.rollout(
|
||||
self.policy,
|
||||
self.env,
|
||||
timestep_limit=timestep_limit,
|
||||
add_noise=add_noise,
|
||||
offset=self.config["offset"])
|
||||
return rollout_rewards, rollout_length
|
||||
return rollout_rewards, rollout_fragment_length
|
||||
|
||||
def do_rollouts(self, params, timestep_limit=None):
|
||||
# Set the network weights.
|
||||
|
||||
@@ -19,7 +19,7 @@ APEX_DDPG_DEFAULT_CONFIG = DDPGTrainer.merge_trainer_configs(
|
||||
"buffer_size": 2000000,
|
||||
"learning_starts": 50000,
|
||||
"train_batch_size": 512,
|
||||
"sample_batch_size": 50,
|
||||
"rollout_fragment_length": 50,
|
||||
"target_network_update_freq": 500000,
|
||||
"timesteps_per_iteration": 25000,
|
||||
"worker_side_prioritization": True,
|
||||
|
||||
@@ -128,7 +128,7 @@ DEFAULT_CONFIG = with_common_config({
|
||||
"learning_starts": 1500,
|
||||
# Update the replay buffer with this many samples at once. Note that this
|
||||
# setting applies per-worker if num_workers > 1.
|
||||
"sample_batch_size": 1,
|
||||
"rollout_fragment_length": 1,
|
||||
# Size of a batched sampled from replay buffer for training. Note that
|
||||
# if async_updates is set, then each worker returns gradients for a
|
||||
# batch of this size.
|
||||
|
||||
@@ -30,7 +30,7 @@ APEX_DEFAULT_CONFIG = merge_dicts(
|
||||
"buffer_size": 2000000,
|
||||
"learning_starts": 50000,
|
||||
"train_batch_size": 512,
|
||||
"sample_batch_size": 50,
|
||||
"rollout_fragment_length": 50,
|
||||
"target_network_update_freq": 500000,
|
||||
"timesteps_per_iteration": 25000,
|
||||
"exploration_config": {"type": "PerWorkerEpsilonGreedy"},
|
||||
@@ -62,7 +62,7 @@ def make_async_optimizer(workers, config):
|
||||
learning_starts=config["learning_starts"],
|
||||
buffer_size=config["buffer_size"],
|
||||
train_batch_size=config["train_batch_size"],
|
||||
sample_batch_size=config["sample_batch_size"],
|
||||
rollout_fragment_length=config["rollout_fragment_length"],
|
||||
**extra_config)
|
||||
workers.add_workers(config["num_workers"])
|
||||
opt._set_workers(workers.remote_workers())
|
||||
|
||||
@@ -101,7 +101,7 @@ DEFAULT_CONFIG = with_common_config({
|
||||
"learning_starts": 1000,
|
||||
# Update the replay buffer with this many samples at once. Note that
|
||||
# this setting applies per-worker if num_workers > 1.
|
||||
"sample_batch_size": 4,
|
||||
"rollout_fragment_length": 4,
|
||||
# Size of a batched sampled from replay buffer for training. Note that
|
||||
# if async_updates is set, then each worker returns gradients for a
|
||||
# batch of this size.
|
||||
@@ -157,7 +157,7 @@ def make_policy_optimizer(workers, config):
|
||||
def validate_config_and_setup_param_noise(config):
|
||||
"""Checks and updates the config based on settings.
|
||||
|
||||
Rewrites sample_batch_size to take into account n_step truncation.
|
||||
Rewrites rollout_fragment_length to take into account n_step truncation.
|
||||
"""
|
||||
# PyTorch check.
|
||||
if config["use_pytorch"]:
|
||||
@@ -223,9 +223,9 @@ def validate_config_and_setup_param_noise(config):
|
||||
}
|
||||
|
||||
# Update effective batch size to include n-step
|
||||
adjusted_batch_size = max(config["sample_batch_size"],
|
||||
adjusted_batch_size = max(config["rollout_fragment_length"],
|
||||
config.get("n_step", 1))
|
||||
config["sample_batch_size"] = adjusted_batch_size
|
||||
config["rollout_fragment_length"] = adjusted_batch_size
|
||||
|
||||
# Setup parameter noise.
|
||||
if config.get("parameter_noise", False):
|
||||
|
||||
@@ -104,12 +104,12 @@ class Worker:
|
||||
return return_filters
|
||||
|
||||
def rollout(self, timestep_limit, add_noise=True):
|
||||
rollout_rewards, rollout_length = policies.rollout(
|
||||
rollout_rewards, rollout_fragment_length = policies.rollout(
|
||||
self.policy,
|
||||
self.env,
|
||||
timestep_limit=timestep_limit,
|
||||
add_noise=add_noise)
|
||||
return rollout_rewards, rollout_length
|
||||
return rollout_rewards, rollout_fragment_length
|
||||
|
||||
def do_rollouts(self, params, timestep_limit=None):
|
||||
# Set the network weights.
|
||||
|
||||
@@ -19,15 +19,15 @@ DEFAULT_CONFIG = with_common_config({
|
||||
#
|
||||
# == Overview of data flow in IMPALA ==
|
||||
# 1. Policy evaluation in parallel across `num_workers` actors produces
|
||||
# batches of size `sample_batch_size * num_envs_per_worker`.
|
||||
# batches of size `rollout_fragment_length * num_envs_per_worker`.
|
||||
# 2. If enabled, the replay buffer stores and produces batches of size
|
||||
# `sample_batch_size * num_envs_per_worker`.
|
||||
# `rollout_fragment_length * num_envs_per_worker`.
|
||||
# 3. If enabled, the minibatch ring buffer stores and replays batches of
|
||||
# size `train_batch_size` up to `num_sgd_iter` times per batch.
|
||||
# 4. The learner thread executes data parallel SGD across `num_gpus` GPUs
|
||||
# on batches of size `train_batch_size`.
|
||||
#
|
||||
"sample_batch_size": 50,
|
||||
"rollout_fragment_length": 50,
|
||||
"train_batch_size": 500,
|
||||
"min_iter_time_s": 10,
|
||||
"num_workers": 2,
|
||||
@@ -45,7 +45,7 @@ DEFAULT_CONFIG = with_common_config({
|
||||
# a p:1 proportion to new data samples.
|
||||
"replay_proportion": 0.0,
|
||||
# number of sample batches to store for replay. The number of transitions
|
||||
# saved total will be (replay_buffer_num_slots * sample_batch_size).
|
||||
# saved total will be (replay_buffer_num_slots * rollout_fragment_length).
|
||||
"replay_buffer_num_slots": 0,
|
||||
# max queue size for train batches feeding into the learner
|
||||
"learner_queue_size": 16,
|
||||
@@ -117,7 +117,7 @@ def make_aggregators_and_optimizer(workers, config):
|
||||
workers,
|
||||
lr=config["lr"],
|
||||
num_gpus=config["num_gpus"],
|
||||
sample_batch_size=config["sample_batch_size"],
|
||||
rollout_fragment_length=config["rollout_fragment_length"],
|
||||
train_batch_size=config["train_batch_size"],
|
||||
replay_buffer_num_slots=config["replay_buffer_num_slots"],
|
||||
replay_proportion=config["replay_proportion"],
|
||||
|
||||
@@ -136,7 +136,7 @@ def _make_time_major(policy, seq_lens, tensor, drop_last=False):
|
||||
else:
|
||||
# Important: chop the tensor into batches at known episode cut
|
||||
# boundaries. TODO(ekl) this is kind of a hack
|
||||
T = policy.config["sample_batch_size"]
|
||||
T = policy.config["rollout_fragment_length"]
|
||||
B = tf.shape(tensor)[0] // T
|
||||
rs = tf.reshape(tensor, tf.concat([[B, T], tf.shape(tensor)[1:]], axis=0))
|
||||
|
||||
@@ -299,4 +299,4 @@ VTraceTFPolicy = build_tf_policy(
|
||||
before_init=validate_config,
|
||||
before_loss_init=setup_mixins,
|
||||
mixins=[LearningRateSchedule, EntropyCoeffSchedule],
|
||||
get_batch_divisibility_req=lambda p: p.config["sample_batch_size"])
|
||||
get_batch_divisibility_req=lambda p: p.config["rollout_fragment_length"])
|
||||
|
||||
@@ -29,7 +29,7 @@ DEFAULT_CONFIG = with_base_config(impala.DEFAULT_CONFIG, {
|
||||
"kl_target": 0.01,
|
||||
|
||||
# == IMPALA optimizer params (see documentation in impala.py) ==
|
||||
"sample_batch_size": 50,
|
||||
"rollout_fragment_length": 50,
|
||||
"train_batch_size": 500,
|
||||
"min_iter_time_s": 10,
|
||||
"num_workers": 2,
|
||||
|
||||
@@ -454,4 +454,4 @@ AsyncPPOTFPolicy = build_tf_policy(
|
||||
LearningRateSchedule, KLCoeffMixin, TargetNetworkMixin,
|
||||
ValueNetworkMixin
|
||||
],
|
||||
get_batch_divisibility_req=lambda p: p.config["sample_batch_size"])
|
||||
get_batch_divisibility_req=lambda p: p.config["rollout_fragment_length"])
|
||||
|
||||
@@ -21,8 +21,8 @@ Note that unlike the paper, we currently do not implement straggler mitigation.
|
||||
# __sphinx_doc_begin__
|
||||
DEFAULT_CONFIG = with_base_config(ppo.DEFAULT_CONFIG, {
|
||||
# During the sampling phase, each rollout worker will collect a batch
|
||||
# `sample_batch_size * num_envs_per_worker` steps in size.
|
||||
"sample_batch_size": 100,
|
||||
# `rollout_fragment_length * num_envs_per_worker` steps in size.
|
||||
"rollout_fragment_length": 100,
|
||||
# Vectorize the env (should enable by default since each worker has a GPU).
|
||||
"num_envs_per_worker": 5,
|
||||
# During the SGD phase, workers iterate over minibatches of this size.
|
||||
@@ -49,10 +49,11 @@ def validate_config(config):
|
||||
if config["train_batch_size"] == -1:
|
||||
# Auto set.
|
||||
config["train_batch_size"] = (
|
||||
config["sample_batch_size"] * config["num_envs_per_worker"])
|
||||
config["rollout_fragment_length"] * config["num_envs_per_worker"])
|
||||
else:
|
||||
raise ValueError(
|
||||
"Set sample_batch_size instead of train_batch_size for DDPPO.")
|
||||
"Set rollout_fragment_length instead of train_batch_size "
|
||||
"for DDPPO.")
|
||||
ppo.validate_config(config)
|
||||
|
||||
|
||||
@@ -73,7 +74,7 @@ def make_distributed_allreduce_optimizer(workers, config):
|
||||
|
||||
return TorchDistributedDataParallelOptimizer(
|
||||
workers,
|
||||
expected_batch_size=config["sample_batch_size"] *
|
||||
expected_batch_size=config["rollout_fragment_length"] *
|
||||
config["num_envs_per_worker"],
|
||||
num_sgd_iter=config["num_sgd_iter"],
|
||||
sgd_minibatch_size=config["sgd_minibatch_size"],
|
||||
|
||||
@@ -24,7 +24,7 @@ DEFAULT_CONFIG = with_common_config({
|
||||
# Initial coefficient for KL divergence.
|
||||
"kl_coeff": 0.2,
|
||||
# Size of batches collected from each worker.
|
||||
"sample_batch_size": 200,
|
||||
"rollout_fragment_length": 200,
|
||||
# Number of timesteps collected for each SGD round. This defines the size
|
||||
# of each SGD epoch.
|
||||
"train_batch_size": 4000,
|
||||
@@ -88,7 +88,7 @@ def choose_policy_optimizer(workers, config):
|
||||
sgd_batch_size=config["sgd_minibatch_size"],
|
||||
num_sgd_iter=config["num_sgd_iter"],
|
||||
num_gpus=config["num_gpus"],
|
||||
sample_batch_size=config["sample_batch_size"],
|
||||
rollout_fragment_length=config["rollout_fragment_length"],
|
||||
num_envs_per_worker=config["num_envs_per_worker"],
|
||||
train_batch_size=config["train_batch_size"],
|
||||
standardize_fields=["advantages"],
|
||||
|
||||
@@ -21,7 +21,7 @@ APEX_QMIX_DEFAULT_CONFIG = merge_dicts(
|
||||
"buffer_size": 2000000,
|
||||
"learning_starts": 50000,
|
||||
"train_batch_size": 512,
|
||||
"sample_batch_size": 50,
|
||||
"rollout_fragment_length": 50,
|
||||
"target_network_update_freq": 500000,
|
||||
"timesteps_per_iteration": 25000,
|
||||
"per_worker_exploration": True,
|
||||
|
||||
@@ -59,7 +59,7 @@ DEFAULT_CONFIG = with_common_config({
|
||||
"learning_starts": 1000,
|
||||
# Update the replay buffer with this many samples at once. Note that
|
||||
# this setting applies per-worker if num_workers > 1.
|
||||
"sample_batch_size": 4,
|
||||
"rollout_fragment_length": 4,
|
||||
# Size of a batched sampled from replay buffer for training. Note that
|
||||
# if async_updates is set, then each worker returns gradients for a
|
||||
# batch of this size.
|
||||
|
||||
@@ -4,8 +4,8 @@ from ray.rllib.agents.sac.sac_policy import SACTFPolicy
|
||||
|
||||
OPTIMIZER_SHARED_CONFIGS = [
|
||||
"buffer_size", "prioritized_replay", "prioritized_replay_alpha",
|
||||
"prioritized_replay_beta", "prioritized_replay_eps", "sample_batch_size",
|
||||
"train_batch_size", "learning_starts"
|
||||
"prioritized_replay_beta", "prioritized_replay_eps",
|
||||
"rollout_fragment_length", "train_batch_size", "learning_starts"
|
||||
]
|
||||
|
||||
# yapf: disable
|
||||
@@ -72,7 +72,7 @@ DEFAULT_CONFIG = with_common_config({
|
||||
"learning_starts": 1500,
|
||||
# Update the replay buffer with this many samples at once. Note that this
|
||||
# setting applies per-worker if num_workers > 1.
|
||||
"sample_batch_size": 1,
|
||||
"rollout_fragment_length": 1,
|
||||
# Size of a batched sampled from replay buffer for training. Note that
|
||||
# if async_updates is set, then each worker returns gradients for a
|
||||
# batch of this size.
|
||||
|
||||
+29
-13
@@ -26,6 +26,7 @@ from ray.tune.resources import Resources
|
||||
from ray.tune.logger import UnifiedLogger
|
||||
from ray.tune.result import DEFAULT_RESULTS_DIR
|
||||
from ray.rllib.env.normalize_actions import NormalizeActionWrapper
|
||||
from ray.rllib.utils.deprecation import DEPRECATED_VALUE, deprecation_warning
|
||||
|
||||
tf = try_import_tf()
|
||||
|
||||
@@ -46,22 +47,26 @@ COMMON_CONFIG = {
|
||||
# model inference batching, which can improve performance for inference
|
||||
# bottlenecked workloads.
|
||||
"num_envs_per_worker": 1,
|
||||
# Default sample batch size (unroll length). Batches of this size are
|
||||
# collected from rollout workers until train_batch_size is met. When using
|
||||
# multiple envs per worker, this is multiplied by num_envs_per_worker.
|
||||
# Divide episodes into fragments of this many steps each during rollouts.
|
||||
# Sample batches of this size are collected from rollout workers and
|
||||
# combined into a larger batch of `train_batch_size` for learning.
|
||||
#
|
||||
# For example, given sample_batch_size=100 and train_batch_size=1000:
|
||||
# 1. RLlib will collect 10 batches of size 100 from the rollout workers.
|
||||
# 2. These batches are concatenated and we perform an epoch of SGD.
|
||||
# For example, given rollout_fragment_length=100 and train_batch_size=1000:
|
||||
# 1. RLlib collects 10 fragments of 100 steps each from rollout workers.
|
||||
# 2. These fragments are concatenated and we perform an epoch of SGD.
|
||||
#
|
||||
# If we further set num_envs_per_worker=5, then the sample batches will be
|
||||
# of size 5*100 = 500, and RLlib will only collect 2 batches per epoch.
|
||||
# When using multiple envs per worker, the fragment size is multiplied by
|
||||
# `num_envs_per_worker`. This is since we are collecting steps from
|
||||
# multiple envs in parallel. For example, if num_envs_per_worker=5, then
|
||||
# rollout workers will return experiences in chunks of 5*100 = 500 steps.
|
||||
#
|
||||
# The exact workflow here can vary per algorithm. For example, PPO further
|
||||
# The dataflow here can vary per algorithm. For example, PPO further
|
||||
# divides the train batch into minibatches for multi-epoch SGD.
|
||||
"sample_batch_size": 200,
|
||||
"rollout_fragment_length": 200,
|
||||
# Deprecated; renamed to `rollout_fragment_length` in 0.8.4.
|
||||
"sample_batch_size": DEPRECATED_VALUE,
|
||||
# Whether to rollout "complete_episodes" or "truncate_episodes" to
|
||||
# `sample_batch_size` length unrolls. Episode truncation guarantees more
|
||||
# `rollout_fragment_length` length unrolls. Episode truncation guarantees
|
||||
# evenly sized batches, but increases variance as the reward-to-go will
|
||||
# need to be estimated at truncation boundaries.
|
||||
"batch_mode": "truncate_episodes",
|
||||
@@ -71,7 +76,7 @@ COMMON_CONFIG = {
|
||||
# algorithms can take advantage of trainer GPUs. This can be fractional
|
||||
# (e.g., 0.3 GPUs).
|
||||
"num_gpus": 0,
|
||||
# Training batch size, if applicable. Should be >= sample_batch_size.
|
||||
# Training batch size, if applicable. Should be >= rollout_fragment_length.
|
||||
# Samples batches will be concatenated together to a batch of this size,
|
||||
# which is then passed to SGD.
|
||||
"train_batch_size": 200,
|
||||
@@ -599,7 +604,7 @@ class Trainer(Trainable):
|
||||
extra_config["in_evaluation"] is True
|
||||
extra_config.update({
|
||||
"batch_mode": "complete_episodes",
|
||||
"batch_steps": 1,
|
||||
"rollout_fragment_length": 1,
|
||||
"in_evaluation": True,
|
||||
})
|
||||
logger.debug(
|
||||
@@ -883,6 +888,17 @@ class Trainer(Trainable):
|
||||
@classmethod
|
||||
def merge_trainer_configs(cls, config1, config2):
|
||||
config1 = copy.deepcopy(config1)
|
||||
# Error if trainer default has deprecated value.
|
||||
if config1["sample_batch_size"] != DEPRECATED_VALUE:
|
||||
deprecation_warning(
|
||||
"sample_batch_size", new="rollout_fragment_length", error=True)
|
||||
# Warning if user override config has deprecated value.
|
||||
if ("sample_batch_size" in config2
|
||||
and config2["sample_batch_size"] != DEPRECATED_VALUE):
|
||||
deprecation_warning(
|
||||
"sample_batch_size", new="rollout_fragment_length")
|
||||
config2["rollout_fragment_length"] = config2["sample_batch_size"]
|
||||
del config2["sample_batch_size"]
|
||||
return deep_update(config1, config2, cls._allow_unknown_configs,
|
||||
cls._allow_unknown_subkeys,
|
||||
cls._override_all_subkeys_if_type_changes)
|
||||
|
||||
@@ -33,7 +33,7 @@ def on_episode_start(info):
|
||||
# __sphinx_doc_begin__
|
||||
DEFAULT_CONFIG = with_common_config({
|
||||
# Size of batches collected from each worker
|
||||
"sample_batch_size": 200,
|
||||
"rollout_fragment_length": 200,
|
||||
# Number of timesteps collected for each SGD round
|
||||
"train_batch_size": 4000,
|
||||
# Total SGD batch size across all devices for SGD
|
||||
@@ -160,10 +160,9 @@ class AlphaZeroPolicyWrapperClass(AlphaZeroPolicy):
|
||||
def mcts_creator():
|
||||
return MCTS(model, config["mcts_config"])
|
||||
|
||||
super().__init__(
|
||||
obs_space, action_space, config, model, alpha_zero_loss,
|
||||
TorchCategorical, mcts_creator, _env_creator
|
||||
)
|
||||
super().__init__(obs_space, action_space, config, model,
|
||||
alpha_zero_loss, TorchCategorical, mcts_creator,
|
||||
_env_creator)
|
||||
|
||||
|
||||
AlphaZeroTrainer = build_trainer(
|
||||
|
||||
@@ -25,7 +25,7 @@ if __name__ == "__main__":
|
||||
config={
|
||||
"env": CartPole,
|
||||
"num_workers": args.num_workers,
|
||||
"sample_batch_size": 50,
|
||||
"rollout_fragment_length": 50,
|
||||
"train_batch_size": 500,
|
||||
"sgd_minibatch_size": 64,
|
||||
"lr": 1e-4,
|
||||
|
||||
@@ -85,7 +85,7 @@ DEFAULT_CONFIG = with_common_config({
|
||||
"learning_starts": 1024 * 25,
|
||||
# Update the replay buffer with this many samples at once. Note that this
|
||||
# setting applies per-worker if num_workers > 1.
|
||||
"sample_batch_size": 100,
|
||||
"rollout_fragment_length": 100,
|
||||
# Size of a batched sampled from replay buffer for training. Note that
|
||||
# if async_updates is set, then each worker returns gradients for a
|
||||
# batch of this size.
|
||||
|
||||
@@ -119,7 +119,7 @@ class RolloutWorker(EvaluatorInterface, ParallelIteratorWorker):
|
||||
policy_mapping_fn=None,
|
||||
policies_to_train=None,
|
||||
tf_session_creator=None,
|
||||
batch_steps=100,
|
||||
rollout_fragment_length=100,
|
||||
batch_mode="truncate_episodes",
|
||||
episode_horizon=None,
|
||||
preprocessor_pref="deepmind",
|
||||
@@ -165,20 +165,21 @@ class RolloutWorker(EvaluatorInterface, ParallelIteratorWorker):
|
||||
or None for all policies.
|
||||
tf_session_creator (func): A function that returns a TF session.
|
||||
This is optional and only useful with TFPolicy.
|
||||
batch_steps (int): The target number of env transitions to include
|
||||
in each sample batch returned from this worker.
|
||||
rollout_fragment_length (int): The target number of env transitions
|
||||
to include in each sample batch returned from this worker.
|
||||
batch_mode (str): One of the following batch modes:
|
||||
"truncate_episodes": Each call to sample() will return a batch
|
||||
of at most `batch_steps * num_envs` in size. The batch will
|
||||
be exactly `batch_steps * num_envs` in size if
|
||||
of at most `rollout_fragment_length * num_envs` in size.
|
||||
The batch will be exactly
|
||||
`rollout_fragment_length * num_envs` in size if
|
||||
postprocessing does not change batch sizes. Episodes may be
|
||||
truncated in order to meet this size requirement.
|
||||
"complete_episodes": Each call to sample() will return a batch
|
||||
of at least `batch_steps * num_envs` in size. Episodes will
|
||||
not be truncated, but multiple episodes may be packed
|
||||
within one batch to meet the batch size. Note that when
|
||||
`num_envs > 1`, episode steps will be buffered until the
|
||||
episode completes, and hence batches may contain
|
||||
of at least `rollout_fragment_length * num_envs` in size.
|
||||
Episodes will not be truncated, but multiple episodes may
|
||||
be packed within one batch to meet the batch size. Note
|
||||
that when `num_envs > 1`, episode steps will be buffered
|
||||
until the episode completes, and hence batches may contain
|
||||
significant amounts of off-policy data.
|
||||
episode_horizon (int): Whether to stop episodes at this horizon.
|
||||
preprocessor_pref (str): Whether to prefer RLlib preprocessors
|
||||
@@ -273,7 +274,7 @@ class RolloutWorker(EvaluatorInterface, ParallelIteratorWorker):
|
||||
if not callable(policy_mapping_fn):
|
||||
raise ValueError("Policy mapping function not callable?")
|
||||
self.env_creator = env_creator
|
||||
self.sample_batch_size = batch_steps * num_envs
|
||||
self.rollout_fragment_length = rollout_fragment_length * num_envs
|
||||
self.batch_mode = batch_mode
|
||||
self.compress_observations = compress_observations
|
||||
self.preprocessing_enabled = True
|
||||
@@ -399,10 +400,9 @@ class RolloutWorker(EvaluatorInterface, ParallelIteratorWorker):
|
||||
self.num_envs = num_envs
|
||||
|
||||
if self.batch_mode == "truncate_episodes":
|
||||
unroll_length = batch_steps
|
||||
pack_episodes = True
|
||||
elif self.batch_mode == "complete_episodes":
|
||||
unroll_length = float("inf") # never cut episodes
|
||||
rollout_fragment_length = float("inf") # never cut episodes
|
||||
pack_episodes = False # sampler will return 1 episode per poll
|
||||
else:
|
||||
raise ValueError("Unsupported batch mode: {}".format(
|
||||
@@ -435,7 +435,7 @@ class RolloutWorker(EvaluatorInterface, ParallelIteratorWorker):
|
||||
self.preprocessors,
|
||||
self.filters,
|
||||
clip_rewards,
|
||||
unroll_length,
|
||||
rollout_fragment_length,
|
||||
self.callbacks,
|
||||
horizon=episode_horizon,
|
||||
pack=pack_episodes,
|
||||
@@ -453,7 +453,7 @@ class RolloutWorker(EvaluatorInterface, ParallelIteratorWorker):
|
||||
self.preprocessors,
|
||||
self.filters,
|
||||
clip_rewards,
|
||||
unroll_length,
|
||||
rollout_fragment_length,
|
||||
self.callbacks,
|
||||
horizon=episode_horizon,
|
||||
pack=pack_episodes,
|
||||
@@ -484,7 +484,7 @@ class RolloutWorker(EvaluatorInterface, ParallelIteratorWorker):
|
||||
|
||||
if log_once("sample_start"):
|
||||
logger.info("Generating sample batch of size {}".format(
|
||||
self.sample_batch_size))
|
||||
self.rollout_fragment_length))
|
||||
|
||||
batches = [self.input_reader.next()]
|
||||
steps_so_far = batches[0].count
|
||||
@@ -496,8 +496,8 @@ class RolloutWorker(EvaluatorInterface, ParallelIteratorWorker):
|
||||
else:
|
||||
max_batches = float("inf")
|
||||
|
||||
while steps_so_far < self.sample_batch_size and len(
|
||||
batches) < max_batches:
|
||||
while (steps_so_far < self.rollout_fragment_length
|
||||
and len(batches) < max_batches):
|
||||
batch = self.input_reader.next()
|
||||
steps_so_far += batch.count
|
||||
batches.append(batch)
|
||||
|
||||
+22
-19
@@ -66,7 +66,7 @@ class SyncSampler(SamplerInput):
|
||||
preprocessors,
|
||||
obs_filters,
|
||||
clip_rewards,
|
||||
unroll_length,
|
||||
rollout_fragment_length,
|
||||
callbacks,
|
||||
horizon=None,
|
||||
pack=False,
|
||||
@@ -75,7 +75,7 @@ class SyncSampler(SamplerInput):
|
||||
soft_horizon=False,
|
||||
no_done_at_end=False):
|
||||
self.base_env = BaseEnv.to_base_env(env)
|
||||
self.unroll_length = unroll_length
|
||||
self.rollout_fragment_length = rollout_fragment_length
|
||||
self.horizon = horizon
|
||||
self.policies = policies
|
||||
self.policy_mapping_fn = policy_mapping_fn
|
||||
@@ -85,7 +85,7 @@ class SyncSampler(SamplerInput):
|
||||
self.perf_stats = PerfStats()
|
||||
self.rollout_provider = _env_runner(
|
||||
self.base_env, self.extra_batches.put, self.policies,
|
||||
self.policy_mapping_fn, self.unroll_length, self.horizon,
|
||||
self.policy_mapping_fn, self.rollout_fragment_length, self.horizon,
|
||||
self.preprocessors, self.obs_filters, clip_rewards, clip_actions,
|
||||
pack, callbacks, tf_sess, self.perf_stats, soft_horizon,
|
||||
no_done_at_end)
|
||||
@@ -127,7 +127,7 @@ class AsyncSampler(threading.Thread, SamplerInput):
|
||||
preprocessors,
|
||||
obs_filters,
|
||||
clip_rewards,
|
||||
unroll_length,
|
||||
rollout_fragment_length,
|
||||
callbacks,
|
||||
horizon=None,
|
||||
pack=False,
|
||||
@@ -144,7 +144,7 @@ class AsyncSampler(threading.Thread, SamplerInput):
|
||||
self.queue = queue.Queue(5)
|
||||
self.extra_batches = queue.Queue()
|
||||
self.metrics_queue = queue.Queue()
|
||||
self.unroll_length = unroll_length
|
||||
self.rollout_fragment_length = rollout_fragment_length
|
||||
self.horizon = horizon
|
||||
self.policies = policies
|
||||
self.policy_mapping_fn = policy_mapping_fn
|
||||
@@ -179,7 +179,7 @@ class AsyncSampler(threading.Thread, SamplerInput):
|
||||
lambda x: self.extra_batches.put(x, timeout=600.0))
|
||||
rollout_provider = _env_runner(
|
||||
self.base_env, extra_batches_putter, self.policies,
|
||||
self.policy_mapping_fn, self.unroll_length, self.horizon,
|
||||
self.policy_mapping_fn, self.rollout_fragment_length, self.horizon,
|
||||
self.preprocessors, self.obs_filters, self.clip_rewards,
|
||||
self.clip_actions, self.pack, self.callbacks, self.tf_sess,
|
||||
self.perf_stats, self.soft_horizon, self.no_done_at_end)
|
||||
@@ -225,7 +225,7 @@ class AsyncSampler(threading.Thread, SamplerInput):
|
||||
|
||||
|
||||
def _env_runner(base_env, extra_batch_callback, policies, policy_mapping_fn,
|
||||
unroll_length, horizon, preprocessors, obs_filters,
|
||||
rollout_fragment_length, horizon, preprocessors, obs_filters,
|
||||
clip_rewards, clip_actions, pack, callbacks, tf_sess,
|
||||
perf_stats, soft_horizon, no_done_at_end):
|
||||
"""This implements the common experience collection logic.
|
||||
@@ -237,8 +237,9 @@ def _env_runner(base_env, extra_batch_callback, policies, policy_mapping_fn,
|
||||
policy_mapping_fn (func): Function that maps agent ids to policy ids.
|
||||
This is called when an agent first enters the environment. The
|
||||
agent is then "bound" to the returned policy for the episode.
|
||||
unroll_length (int): Number of episode steps before `SampleBatch` is
|
||||
yielded. Set to infinity to yield complete episodes.
|
||||
rollout_fragment_length (int): Number of episode steps before
|
||||
`SampleBatch` is yielded. Set to infinity to yield complete
|
||||
episodes.
|
||||
horizon (int): Horizon of the episode.
|
||||
preprocessors (dict): Map of policy id to preprocessor for the
|
||||
observations prior to filtering.
|
||||
@@ -246,7 +247,8 @@ def _env_runner(base_env, extra_batch_callback, policies, policy_mapping_fn,
|
||||
observations for the policy.
|
||||
clip_rewards (bool): Whether to clip rewards before postprocessing.
|
||||
pack (bool): Whether to pack multiple episodes into each batch. This
|
||||
guarantees batches will be exactly `unroll_length` in size.
|
||||
guarantees batches will be exactly `rollout_fragment_length` in
|
||||
size.
|
||||
clip_actions (bool): Whether to clip actions to the space range.
|
||||
callbacks (dict): User callbacks to run on episode events.
|
||||
tf_sess (Session|None): Optional tensorflow session to use for batching
|
||||
@@ -332,8 +334,8 @@ def _env_runner(base_env, extra_batch_callback, policies, policy_mapping_fn,
|
||||
active_envs, to_eval, outputs = _process_observations(
|
||||
base_env, policies, batch_builder_pool, active_episodes,
|
||||
unfiltered_obs, rewards, dones, infos, off_policy_actions, horizon,
|
||||
preprocessors, obs_filters, unroll_length, pack, callbacks,
|
||||
soft_horizon, no_done_at_end)
|
||||
preprocessors, obs_filters, rollout_fragment_length, pack,
|
||||
callbacks, soft_horizon, no_done_at_end)
|
||||
perf_stats.processing_time += time.time() - t1
|
||||
for o in outputs:
|
||||
yield o
|
||||
@@ -361,8 +363,8 @@ def _env_runner(base_env, extra_batch_callback, policies, policy_mapping_fn,
|
||||
def _process_observations(base_env, policies, batch_builder_pool,
|
||||
active_episodes, unfiltered_obs, rewards, dones,
|
||||
infos, off_policy_actions, horizon, preprocessors,
|
||||
obs_filters, unroll_length, pack, callbacks,
|
||||
soft_horizon, no_done_at_end):
|
||||
obs_filters, rollout_fragment_length, pack,
|
||||
callbacks, soft_horizon, no_done_at_end):
|
||||
"""Record new data from the environment and prepare for policy evaluation.
|
||||
|
||||
Returns:
|
||||
@@ -374,8 +376,8 @@ def _process_observations(base_env, policies, batch_builder_pool,
|
||||
active_envs = set()
|
||||
to_eval = defaultdict(list)
|
||||
outputs = []
|
||||
large_batch_threshold = max(1000, unroll_length * 10) if \
|
||||
unroll_length != float("inf") else 5000
|
||||
large_batch_threshold = max(1000, rollout_fragment_length * 10) if \
|
||||
rollout_fragment_length != float("inf") else 5000
|
||||
|
||||
# For each environment
|
||||
for env_id, agent_obs in unfiltered_obs.items():
|
||||
@@ -395,8 +397,9 @@ def _process_observations(base_env, policies, batch_builder_pool,
|
||||
"the sampler. If this is more than you expected, check that "
|
||||
"that you set a horizon on your environment correctly and that"
|
||||
" it terminates at some point. "
|
||||
"Note: In multi-agent environments, `sample_batch_size` sets "
|
||||
"the batch size based on environment steps, not the steps of "
|
||||
"Note: In multi-agent environments, `rollout_fragment_length` "
|
||||
"sets the batch size based on environment steps, not the "
|
||||
"steps of "
|
||||
"individual agents, which can result in unexpectedly large "
|
||||
"batches. Also, you may be in evaluation waiting for your Env "
|
||||
"to terminate (batch_mode=`complete_episodes`). Make sure it "
|
||||
@@ -480,7 +483,7 @@ def _process_observations(base_env, policies, batch_builder_pool,
|
||||
if dones[env_id]["__all__"] and not no_done_at_end:
|
||||
episode.batch_builder.check_missing_dones()
|
||||
if (all_done and not pack) or \
|
||||
episode.batch_builder.count >= unroll_length:
|
||||
episode.batch_builder.count >= rollout_fragment_length:
|
||||
outputs.append(episode.batch_builder.build_and_reset(episode))
|
||||
elif all_done:
|
||||
# Make sure postprocessor stays within one episode
|
||||
|
||||
@@ -240,7 +240,7 @@ class WorkerSet:
|
||||
policies_to_train=config["multiagent"]["policies_to_train"],
|
||||
tf_session_creator=(session_creator
|
||||
if config["tf_session_args"] else None),
|
||||
batch_steps=config["sample_batch_size"],
|
||||
rollout_fragment_length=config["rollout_fragment_length"],
|
||||
batch_mode=config["batch_mode"],
|
||||
episode_horizon=config["horizon"],
|
||||
preprocessor_pref=config["preprocessor_pref"],
|
||||
|
||||
@@ -62,7 +62,7 @@ if __name__ == "__main__":
|
||||
"num_data_loader_buffers": 1,
|
||||
"num_aggregation_workers": 1,
|
||||
"broadcast_interval": 50,
|
||||
"sample_batch_size": 100,
|
||||
"rollout_fragment_length": 100,
|
||||
"train_batch_size": sample_from(
|
||||
lambda spec: 1000 * max(1, spec.config.num_gpus)),
|
||||
"_fake_sampler": True,
|
||||
|
||||
@@ -166,7 +166,7 @@ def run_heuristic_vs_learned(use_lstm=False, trainer="PG"):
|
||||
"gamma": 0.9,
|
||||
"num_workers": 0,
|
||||
"num_envs_per_worker": 4,
|
||||
"sample_batch_size": 10,
|
||||
"rollout_fragment_length": 10,
|
||||
"train_batch_size": 200,
|
||||
"multiagent": {
|
||||
"policies_to_train": ["learned"],
|
||||
|
||||
@@ -182,7 +182,7 @@ if __name__ == "__main__":
|
||||
group = False
|
||||
elif args.run == "QMIX":
|
||||
config = {
|
||||
"sample_batch_size": 4,
|
||||
"rollout_fragment_length": 4,
|
||||
"train_batch_size": 32,
|
||||
"exploration_fraction": .4,
|
||||
"exploration_final_eps": 0.0,
|
||||
@@ -205,7 +205,7 @@ if __name__ == "__main__":
|
||||
"buffer_size": 1000,
|
||||
"learning_starts": 1000,
|
||||
"train_batch_size": 128,
|
||||
"sample_batch_size": 32,
|
||||
"rollout_fragment_length": 32,
|
||||
"target_network_update_freq": 500,
|
||||
"timesteps_per_iteration": 1000,
|
||||
"env_config": {
|
||||
|
||||
@@ -53,7 +53,8 @@ class AggregationWorkerBase:
|
||||
|
||||
def __init__(self, initial_weights_obj_id, remote_workers,
|
||||
max_sample_requests_in_flight_per_worker, replay_proportion,
|
||||
replay_buffer_num_slots, train_batch_size, sample_batch_size):
|
||||
replay_buffer_num_slots, train_batch_size,
|
||||
rollout_fragment_length):
|
||||
"""Initialize an aggregator.
|
||||
|
||||
Arguments:
|
||||
@@ -65,20 +66,22 @@ class AggregationWorkerBase:
|
||||
replay_buffer_num_slots (int): max number of sample batches to
|
||||
store in the replay buffer
|
||||
train_batch_size (int): size of batches to learn on
|
||||
sample_batch_size (int): size of batches to sample from workers
|
||||
rollout_fragment_length (int): size of batches to sample from
|
||||
workers.
|
||||
"""
|
||||
|
||||
self.broadcasted_weights = initial_weights_obj_id
|
||||
self.remote_workers = remote_workers
|
||||
self.sample_batch_size = sample_batch_size
|
||||
self.rollout_fragment_length = rollout_fragment_length
|
||||
self.train_batch_size = train_batch_size
|
||||
|
||||
if replay_proportion:
|
||||
if replay_buffer_num_slots * sample_batch_size <= train_batch_size:
|
||||
if (replay_buffer_num_slots * rollout_fragment_length <=
|
||||
train_batch_size):
|
||||
raise ValueError(
|
||||
"Replay buffer size is too small to produce train, "
|
||||
"please increase replay_buffer_num_slots.",
|
||||
replay_buffer_num_slots, sample_batch_size,
|
||||
replay_buffer_num_slots, rollout_fragment_length,
|
||||
train_batch_size)
|
||||
|
||||
# Kick off async background sampling
|
||||
@@ -159,7 +162,7 @@ class AggregationWorkerBase:
|
||||
def _augment_with_replay(self, sample_futures):
|
||||
def can_replay():
|
||||
num_needed = int(
|
||||
np.ceil(self.train_batch_size / self.sample_batch_size))
|
||||
np.ceil(self.train_batch_size / self.rollout_fragment_length))
|
||||
return len(self.replay_batches) > num_needed
|
||||
|
||||
for ev, sample_batch in sample_futures:
|
||||
@@ -184,7 +187,7 @@ class SimpleAggregator(AggregationWorkerBase, Aggregator):
|
||||
replay_proportion=0.0,
|
||||
replay_buffer_num_slots=0,
|
||||
train_batch_size=500,
|
||||
sample_batch_size=50,
|
||||
rollout_fragment_length=50,
|
||||
broadcast_interval=5):
|
||||
self.workers = workers
|
||||
self.local_worker = workers.local_worker()
|
||||
@@ -193,7 +196,7 @@ class SimpleAggregator(AggregationWorkerBase, Aggregator):
|
||||
AggregationWorkerBase.__init__(
|
||||
self, self.broadcasted_weights, self.workers.remote_workers(),
|
||||
max_sample_requests_in_flight_per_worker, replay_proportion,
|
||||
replay_buffer_num_slots, train_batch_size, sample_batch_size)
|
||||
replay_buffer_num_slots, train_batch_size, rollout_fragment_length)
|
||||
|
||||
@override(Aggregator)
|
||||
def broadcast_new_weights(self):
|
||||
|
||||
@@ -31,7 +31,7 @@ class TreeAggregator(Aggregator):
|
||||
replay_proportion=0.0,
|
||||
replay_buffer_num_slots=0,
|
||||
train_batch_size=500,
|
||||
sample_batch_size=50,
|
||||
rollout_fragment_length=50,
|
||||
broadcast_interval=5):
|
||||
"""Initialize a tree aggregator.
|
||||
|
||||
@@ -45,7 +45,8 @@ class TreeAggregator(Aggregator):
|
||||
replay_buffer_num_slots (int): max number of sample batches to
|
||||
store in the replay buffer
|
||||
train_batch_size (int): size of batches to learn on
|
||||
sample_batch_size (int): size of batches to sample from workers
|
||||
rollout_fragment_length (int): size of batches to sample from
|
||||
workers.
|
||||
broadcast_interval (int): max number of workers to send the
|
||||
same set of weights to
|
||||
"""
|
||||
@@ -55,7 +56,7 @@ class TreeAggregator(Aggregator):
|
||||
max_sample_requests_in_flight_per_worker
|
||||
self.replay_proportion = replay_proportion
|
||||
self.replay_buffer_num_slots = replay_buffer_num_slots
|
||||
self.sample_batch_size = sample_batch_size
|
||||
self.rollout_fragment_length = rollout_fragment_length
|
||||
self.train_batch_size = train_batch_size
|
||||
self.broadcast_interval = broadcast_interval
|
||||
self.broadcasted_weights = ray.put(
|
||||
@@ -82,11 +83,11 @@ class TreeAggregator(Aggregator):
|
||||
|
||||
self.aggregators = aggregators
|
||||
for i, agg in enumerate(self.aggregators):
|
||||
agg.init.remote(self.broadcasted_weights, assigned_workers[i],
|
||||
self.max_sample_requests_in_flight_per_worker,
|
||||
self.replay_proportion,
|
||||
self.replay_buffer_num_slots,
|
||||
self.train_batch_size, self.sample_batch_size)
|
||||
agg.init.remote(
|
||||
self.broadcasted_weights, assigned_workers[i],
|
||||
self.max_sample_requests_in_flight_per_worker,
|
||||
self.replay_proportion, self.replay_buffer_num_slots,
|
||||
self.train_batch_size, self.rollout_fragment_length)
|
||||
|
||||
self.agg_tasks = TaskPool()
|
||||
for agg in self.aggregators:
|
||||
@@ -140,7 +141,8 @@ class AggregationWorker(AggregationWorkerBase):
|
||||
|
||||
def init(self, initial_weights_obj_id, remote_workers,
|
||||
max_sample_requests_in_flight_per_worker, replay_proportion,
|
||||
replay_buffer_num_slots, train_batch_size, sample_batch_size):
|
||||
replay_buffer_num_slots, train_batch_size,
|
||||
rollout_fragment_length):
|
||||
"""Deferred init that assigns sub-workers to this aggregator."""
|
||||
|
||||
logger.info("Assigned workers {} to aggregation worker {}".format(
|
||||
@@ -149,7 +151,7 @@ class AggregationWorker(AggregationWorkerBase):
|
||||
AggregationWorkerBase.__init__(
|
||||
self, initial_weights_obj_id, remote_workers,
|
||||
max_sample_requests_in_flight_per_worker, replay_proportion,
|
||||
replay_buffer_num_slots, train_batch_size, sample_batch_size)
|
||||
replay_buffer_num_slots, train_batch_size, rollout_fragment_length)
|
||||
self.initialized = True
|
||||
|
||||
def set_weights(self, weights):
|
||||
|
||||
@@ -56,7 +56,7 @@ class AsyncReplayOptimizer(PolicyOptimizer):
|
||||
prioritized_replay_beta=0.4,
|
||||
prioritized_replay_eps=1e-6,
|
||||
train_batch_size=512,
|
||||
sample_batch_size=50,
|
||||
rollout_fragment_length=50,
|
||||
num_replay_buffer_shards=1,
|
||||
max_weight_sync_delay=400,
|
||||
debug=False,
|
||||
@@ -73,7 +73,8 @@ class AsyncReplayOptimizer(PolicyOptimizer):
|
||||
prioritized_replay_beta (float): replay beta hyperparameter
|
||||
prioritized_replay_eps (float): replay eps hyperparameter
|
||||
train_batch_size (int): size of batches to learn on
|
||||
sample_batch_size (int): size of batches to sample from workers
|
||||
rollout_fragment_length (int): size of batches to sample from
|
||||
workers.
|
||||
num_replay_buffer_shards (int): number of actors to use to store
|
||||
replay samples
|
||||
max_weight_sync_delay (int): update the weights of a rollout worker
|
||||
|
||||
@@ -22,10 +22,11 @@ class AsyncSamplesOptimizer(PolicyOptimizer):
|
||||
This class coordinates the data transfers between the learner thread
|
||||
and remote workers (IMPALA actors).
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
workers,
|
||||
train_batch_size=500,
|
||||
sample_batch_size=50,
|
||||
rollout_fragment_length=50,
|
||||
num_envs_per_worker=1,
|
||||
num_gpus=0,
|
||||
lr=0.0005,
|
||||
@@ -90,7 +91,7 @@ class AsyncSamplesOptimizer(PolicyOptimizer):
|
||||
max_sample_requests_in_flight_per_worker),
|
||||
replay_buffer_num_slots=replay_buffer_num_slots,
|
||||
train_batch_size=train_batch_size,
|
||||
sample_batch_size=sample_batch_size,
|
||||
rollout_fragment_length=rollout_fragment_length,
|
||||
broadcast_interval=broadcast_interval)
|
||||
else:
|
||||
self.aggregator = SimpleAggregator(
|
||||
@@ -100,7 +101,7 @@ class AsyncSamplesOptimizer(PolicyOptimizer):
|
||||
max_sample_requests_in_flight_per_worker),
|
||||
replay_buffer_num_slots=replay_buffer_num_slots,
|
||||
train_batch_size=train_batch_size,
|
||||
sample_batch_size=sample_batch_size,
|
||||
rollout_fragment_length=rollout_fragment_length,
|
||||
broadcast_interval=broadcast_interval)
|
||||
|
||||
def add_stat_val(self, key, val):
|
||||
|
||||
@@ -41,7 +41,7 @@ class LocalMultiGPUOptimizer(PolicyOptimizer):
|
||||
workers,
|
||||
sgd_batch_size=128,
|
||||
num_sgd_iter=10,
|
||||
sample_batch_size=200,
|
||||
rollout_fragment_length=200,
|
||||
num_envs_per_worker=1,
|
||||
train_batch_size=1024,
|
||||
num_gpus=0,
|
||||
@@ -53,7 +53,8 @@ class LocalMultiGPUOptimizer(PolicyOptimizer):
|
||||
workers (WorkerSet): all workers
|
||||
sgd_batch_size (int): SGD minibatch size within train batch size
|
||||
num_sgd_iter (int): number of passes to learn on per train batch
|
||||
sample_batch_size (int): size of batches to sample from workers
|
||||
rollout_fragment_length (int): size of batches to sample from
|
||||
workers.
|
||||
num_envs_per_worker (int): num envs in each rollout worker
|
||||
train_batch_size (int): size of batches to learn on
|
||||
num_gpus (int): number of GPUs to use for data-parallel SGD
|
||||
@@ -67,7 +68,7 @@ class LocalMultiGPUOptimizer(PolicyOptimizer):
|
||||
self.batch_size = sgd_batch_size
|
||||
self.num_sgd_iter = num_sgd_iter
|
||||
self.num_envs_per_worker = num_envs_per_worker
|
||||
self.sample_batch_size = sample_batch_size
|
||||
self.rollout_fragment_length = rollout_fragment_length
|
||||
self.train_batch_size = train_batch_size
|
||||
self.shuffle_sequences = shuffle_sequences
|
||||
if not num_gpus:
|
||||
@@ -132,9 +133,10 @@ class LocalMultiGPUOptimizer(PolicyOptimizer):
|
||||
|
||||
with self.sample_timer:
|
||||
if self.workers.remote_workers():
|
||||
samples = collect_samples(
|
||||
self.workers.remote_workers(), self.sample_batch_size,
|
||||
self.num_envs_per_worker, self.train_batch_size)
|
||||
samples = collect_samples(self.workers.remote_workers(),
|
||||
self.rollout_fragment_length,
|
||||
self.num_envs_per_worker,
|
||||
self.train_batch_size)
|
||||
if samples.count > self.train_batch_size * 2:
|
||||
logger.info(
|
||||
"Collected more training samples than expected "
|
||||
|
||||
@@ -7,7 +7,7 @@ from ray.rllib.utils.memory import ray_get_and_free
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def collect_samples(agents, sample_batch_size, num_envs_per_worker,
|
||||
def collect_samples(agents, rollout_fragment_length, num_envs_per_worker,
|
||||
train_batch_size):
|
||||
"""Collects at least train_batch_size samples, never discarding any."""
|
||||
|
||||
@@ -27,7 +27,8 @@ def collect_samples(agents, sample_batch_size, num_envs_per_worker,
|
||||
trajectories.append(next_sample)
|
||||
|
||||
# Only launch more tasks if we don't already have enough pending
|
||||
pending = len(agent_dict) * sample_batch_size * num_envs_per_worker
|
||||
pending = len(
|
||||
agent_dict) * rollout_fragment_length * num_envs_per_worker
|
||||
if num_timesteps_so_far + pending < train_batch_size:
|
||||
fut_sample2 = agent.sample.remote()
|
||||
agent_dict[fut_sample2] = agent
|
||||
|
||||
@@ -68,7 +68,7 @@ class PPOCollectTest(unittest.TestCase):
|
||||
ppo = PPOTrainer(
|
||||
env="CartPole-v0",
|
||||
config={
|
||||
"sample_batch_size": 200,
|
||||
"rollout_fragment_length": 200,
|
||||
"train_batch_size": 128,
|
||||
"num_workers": 3,
|
||||
})
|
||||
@@ -80,7 +80,7 @@ class PPOCollectTest(unittest.TestCase):
|
||||
ppo = PPOTrainer(
|
||||
env="CartPole-v0",
|
||||
config={
|
||||
"sample_batch_size": 200,
|
||||
"rollout_fragment_length": 200,
|
||||
"train_batch_size": 900,
|
||||
"num_workers": 3,
|
||||
})
|
||||
@@ -92,7 +92,7 @@ class PPOCollectTest(unittest.TestCase):
|
||||
ppo = PPOTrainer(
|
||||
env="CartPole-v0",
|
||||
config={
|
||||
"sample_batch_size": 200,
|
||||
"rollout_fragment_length": 200,
|
||||
"num_envs_per_worker": 2,
|
||||
"train_batch_size": 900,
|
||||
"num_workers": 3,
|
||||
@@ -150,7 +150,7 @@ class AsyncSamplesOptimizerTest(unittest.TestCase):
|
||||
workers,
|
||||
minibatch_buffer_size=10,
|
||||
num_sgd_iter=10,
|
||||
sample_batch_size=10,
|
||||
rollout_fragment_length=10,
|
||||
train_batch_size=50)
|
||||
self._wait_for(optimizer, 1000, 10000)
|
||||
self.assertLess(optimizer.stats()["num_steps_sampled"], 5000)
|
||||
@@ -163,7 +163,7 @@ class AsyncSamplesOptimizerTest(unittest.TestCase):
|
||||
workers,
|
||||
replay_buffer_num_slots=100,
|
||||
replay_proportion=10,
|
||||
sample_batch_size=10,
|
||||
rollout_fragment_length=10,
|
||||
train_batch_size=10,
|
||||
)
|
||||
self._wait_for(optimizer, 1000, 1000)
|
||||
@@ -182,7 +182,7 @@ class AsyncSamplesOptimizerTest(unittest.TestCase):
|
||||
num_sgd_iter=10,
|
||||
replay_buffer_num_slots=100,
|
||||
replay_proportion=10,
|
||||
sample_batch_size=10,
|
||||
rollout_fragment_length=10,
|
||||
train_batch_size=10)
|
||||
self._wait_for(optimizer, 1000, 1000)
|
||||
|
||||
@@ -219,21 +219,21 @@ class AsyncSamplesOptimizerTest(unittest.TestCase):
|
||||
workers,
|
||||
num_gpus=1,
|
||||
train_batch_size=100,
|
||||
sample_batch_size=50,
|
||||
rollout_fragment_length=50,
|
||||
_fake_gpus=True)
|
||||
self._wait_for(optimizer, 1000, 1000)
|
||||
optimizer = AsyncSamplesOptimizer(
|
||||
workers,
|
||||
num_gpus=1,
|
||||
train_batch_size=100,
|
||||
sample_batch_size=25,
|
||||
rollout_fragment_length=25,
|
||||
_fake_gpus=True)
|
||||
self._wait_for(optimizer, 1000, 1000)
|
||||
optimizer = AsyncSamplesOptimizer(
|
||||
workers,
|
||||
num_gpus=1,
|
||||
train_batch_size=100,
|
||||
sample_batch_size=74,
|
||||
rollout_fragment_length=74,
|
||||
_fake_gpus=True)
|
||||
self._wait_for(optimizer, 1000, 1000)
|
||||
|
||||
@@ -242,7 +242,7 @@ class AsyncSamplesOptimizerTest(unittest.TestCase):
|
||||
workers = WorkerSet._from_existing(local, remotes)
|
||||
optimizer = AsyncSamplesOptimizer(
|
||||
workers,
|
||||
sample_batch_size=1000,
|
||||
rollout_fragment_length=1000,
|
||||
train_batch_size=1000,
|
||||
learner_queue_timeout=1)
|
||||
self.assertRaises(AssertionError,
|
||||
|
||||
@@ -124,7 +124,7 @@ class TestExternalEnv(unittest.TestCase):
|
||||
ev = RolloutWorker(
|
||||
env_creator=lambda _: SimpleServing(MockEnv(25)),
|
||||
policy=MockPolicy,
|
||||
batch_steps=40,
|
||||
rollout_fragment_length=40,
|
||||
batch_mode="complete_episodes")
|
||||
for _ in range(3):
|
||||
batch = ev.sample()
|
||||
@@ -134,7 +134,7 @@ class TestExternalEnv(unittest.TestCase):
|
||||
ev = RolloutWorker(
|
||||
env_creator=lambda _: SimpleServing(MockEnv(25)),
|
||||
policy=MockPolicy,
|
||||
batch_steps=40,
|
||||
rollout_fragment_length=40,
|
||||
batch_mode="truncate_episodes")
|
||||
for _ in range(3):
|
||||
batch = ev.sample()
|
||||
@@ -144,7 +144,7 @@ class TestExternalEnv(unittest.TestCase):
|
||||
ev = RolloutWorker(
|
||||
env_creator=lambda _: SimpleOffPolicyServing(MockEnv(25), 42),
|
||||
policy=MockPolicy,
|
||||
batch_steps=40,
|
||||
rollout_fragment_length=40,
|
||||
batch_mode="complete_episodes")
|
||||
for _ in range(3):
|
||||
batch = ev.sample()
|
||||
@@ -157,7 +157,7 @@ class TestExternalEnv(unittest.TestCase):
|
||||
env_creator=lambda _: SimpleServing(MockEnv(25)),
|
||||
policy=BadPolicy,
|
||||
sample_async=True,
|
||||
batch_steps=40,
|
||||
rollout_fragment_length=40,
|
||||
batch_mode="truncate_episodes")
|
||||
self.assertRaises(Exception, lambda: ev.sample())
|
||||
|
||||
@@ -206,7 +206,7 @@ class TestExternalEnv(unittest.TestCase):
|
||||
env_creator=lambda _: SimpleServing(MockEnv(25)),
|
||||
policy=MockPolicy,
|
||||
episode_horizon=20,
|
||||
batch_steps=10,
|
||||
rollout_fragment_length=10,
|
||||
batch_mode="complete_episodes")
|
||||
self.assertRaises(ValueError, lambda: ev.sample())
|
||||
|
||||
|
||||
@@ -29,7 +29,7 @@ class TestExternalMultiAgentEnv(unittest.TestCase):
|
||||
ev = RolloutWorker(
|
||||
env_creator=lambda _: SimpleMultiServing(BasicMultiAgent(agents)),
|
||||
policy=MockPolicy,
|
||||
batch_steps=40,
|
||||
rollout_fragment_length=40,
|
||||
batch_mode="complete_episodes")
|
||||
for _ in range(3):
|
||||
batch = ev.sample()
|
||||
@@ -41,7 +41,7 @@ class TestExternalMultiAgentEnv(unittest.TestCase):
|
||||
ev = RolloutWorker(
|
||||
env_creator=lambda _: SimpleMultiServing(BasicMultiAgent(agents)),
|
||||
policy=MockPolicy,
|
||||
batch_steps=40,
|
||||
rollout_fragment_length=40,
|
||||
batch_mode="truncate_episodes")
|
||||
for _ in range(3):
|
||||
batch = ev.sample()
|
||||
@@ -59,7 +59,7 @@ class TestExternalMultiAgentEnv(unittest.TestCase):
|
||||
"p1": (MockPolicy, obs_space, act_space, {}),
|
||||
},
|
||||
policy_mapping_fn=lambda agent_id: "p{}".format(agent_id % 2),
|
||||
batch_steps=50)
|
||||
rollout_fragment_length=50)
|
||||
batch = ev.sample()
|
||||
self.assertEqual(batch.count, 50)
|
||||
|
||||
@@ -77,7 +77,7 @@ class TestExternalMultiAgentEnv(unittest.TestCase):
|
||||
env_creator=lambda _: MultiCartpole(n),
|
||||
policy=policies,
|
||||
policy_mapping_fn=lambda agent_id: random.choice(policy_ids),
|
||||
batch_steps=100)
|
||||
rollout_fragment_length=100)
|
||||
optimizer = SyncSamplesOptimizer(WorkerSet._from_existing(ev))
|
||||
for i in range(100):
|
||||
optimizer.step()
|
||||
|
||||
@@ -91,7 +91,7 @@ class IgnoresWorkerFailure(unittest.TestCase):
|
||||
"PPO", {
|
||||
"num_sgd_iter": 1,
|
||||
"train_batch_size": 10,
|
||||
"sample_batch_size": 10,
|
||||
"rollout_fragment_length": 10,
|
||||
"sgd_minibatch_size": 1,
|
||||
})
|
||||
|
||||
|
||||
@@ -44,7 +44,7 @@ class AgentIOTest(unittest.TestCase):
|
||||
env="CartPole-v0",
|
||||
config={
|
||||
"output": output,
|
||||
"sample_batch_size": 250,
|
||||
"rollout_fragment_length": 250,
|
||||
})
|
||||
agent.train()
|
||||
return agent
|
||||
@@ -127,7 +127,7 @@ class AgentIOTest(unittest.TestCase):
|
||||
config={
|
||||
"input": glob.glob(self.test_dir + "/*.json"),
|
||||
"input_evaluation": [],
|
||||
"sample_batch_size": 99,
|
||||
"rollout_fragment_length": 99,
|
||||
})
|
||||
result = agent.train()
|
||||
self.assertEqual(result["timesteps_total"], 250) # read from input
|
||||
|
||||
@@ -198,7 +198,7 @@ class TestRNNSequencing(unittest.TestCase):
|
||||
env="counter",
|
||||
config={
|
||||
"num_workers": 0,
|
||||
"sample_batch_size": 10,
|
||||
"rollout_fragment_length": 10,
|
||||
"train_batch_size": 10,
|
||||
"sgd_minibatch_size": 10,
|
||||
"vf_share_layers": True,
|
||||
@@ -256,7 +256,7 @@ class TestRNNSequencing(unittest.TestCase):
|
||||
config={
|
||||
"shuffle_sequences": False, # for deterministic testing
|
||||
"num_workers": 0,
|
||||
"sample_batch_size": 20,
|
||||
"rollout_fragment_length": 20,
|
||||
"train_batch_size": 20,
|
||||
"sgd_minibatch_size": 10,
|
||||
"vf_share_layers": True,
|
||||
|
||||
@@ -336,7 +336,7 @@ class TestMultiAgentEnv(unittest.TestCase):
|
||||
"p1": (MockPolicy, obs_space, act_space, {}),
|
||||
},
|
||||
policy_mapping_fn=lambda agent_id: "p{}".format(agent_id % 2),
|
||||
batch_steps=50)
|
||||
rollout_fragment_length=50)
|
||||
batch = ev.sample()
|
||||
self.assertEqual(batch.count, 50)
|
||||
self.assertEqual(batch.policy_batches["p0"].count, 150)
|
||||
@@ -356,7 +356,7 @@ class TestMultiAgentEnv(unittest.TestCase):
|
||||
"p1": (MockPolicy, obs_space, act_space, {}),
|
||||
},
|
||||
policy_mapping_fn=lambda agent_id: "p{}".format(agent_id % 2),
|
||||
batch_steps=50,
|
||||
rollout_fragment_length=50,
|
||||
num_envs=4,
|
||||
remote_worker_envs=True,
|
||||
remote_env_batch_wait_ms=99999999)
|
||||
@@ -375,7 +375,7 @@ class TestMultiAgentEnv(unittest.TestCase):
|
||||
"p1": (MockPolicy, obs_space, act_space, {}),
|
||||
},
|
||||
policy_mapping_fn=lambda agent_id: "p{}".format(agent_id % 2),
|
||||
batch_steps=50,
|
||||
rollout_fragment_length=50,
|
||||
num_envs=4,
|
||||
remote_worker_envs=True)
|
||||
batch = ev.sample()
|
||||
@@ -392,7 +392,7 @@ class TestMultiAgentEnv(unittest.TestCase):
|
||||
},
|
||||
policy_mapping_fn=lambda agent_id: "p{}".format(agent_id % 2),
|
||||
episode_horizon=10, # test with episode horizon set
|
||||
batch_steps=50)
|
||||
rollout_fragment_length=50)
|
||||
batch = ev.sample()
|
||||
self.assertEqual(batch.count, 50)
|
||||
|
||||
@@ -407,7 +407,7 @@ class TestMultiAgentEnv(unittest.TestCase):
|
||||
},
|
||||
policy_mapping_fn=lambda agent_id: "p{}".format(agent_id % 2),
|
||||
batch_mode="complete_episodes",
|
||||
batch_steps=1)
|
||||
rollout_fragment_length=1)
|
||||
self.assertRaisesRegexp(ValueError,
|
||||
".*don't have a last observation.*",
|
||||
lambda: ev.sample())
|
||||
@@ -421,7 +421,7 @@ class TestMultiAgentEnv(unittest.TestCase):
|
||||
"p0": (MockPolicy, obs_space, act_space, {}),
|
||||
},
|
||||
policy_mapping_fn=lambda agent_id: "p0",
|
||||
batch_steps=50)
|
||||
rollout_fragment_length=50)
|
||||
batch = ev.sample()
|
||||
self.assertEqual(batch.count, 50)
|
||||
# since we round robin introduce agents into the env, some of the env
|
||||
@@ -469,7 +469,7 @@ class TestMultiAgentEnv(unittest.TestCase):
|
||||
ev = RolloutWorker(
|
||||
env_creator=lambda _: gym.make("CartPole-v0"),
|
||||
policy=StatefulPolicy,
|
||||
batch_steps=5)
|
||||
rollout_fragment_length=5)
|
||||
batch = ev.sample()
|
||||
self.assertEqual(batch.count, 5)
|
||||
self.assertEqual(batch["state_in_0"][0], {})
|
||||
@@ -518,7 +518,7 @@ class TestMultiAgentEnv(unittest.TestCase):
|
||||
"p1": (ModelBasedPolicy, obs_space, act_space, {}),
|
||||
},
|
||||
policy_mapping_fn=lambda agent_id: "p0",
|
||||
batch_steps=5)
|
||||
rollout_fragment_length=5)
|
||||
batch = ev.sample()
|
||||
self.assertEqual(batch.count, 5)
|
||||
self.assertEqual(batch.policy_batches["p0"].count, 10)
|
||||
@@ -599,7 +599,7 @@ class TestMultiAgentEnv(unittest.TestCase):
|
||||
env_creator=lambda _: MultiCartpole(n),
|
||||
policy=policies,
|
||||
policy_mapping_fn=lambda agent_id: ["p1", "p2"][agent_id % 2],
|
||||
batch_steps=50)
|
||||
rollout_fragment_length=50)
|
||||
if optimizer_cls == AsyncGradientsOptimizer:
|
||||
|
||||
def policy_mapper(agent_id):
|
||||
@@ -610,7 +610,7 @@ class TestMultiAgentEnv(unittest.TestCase):
|
||||
env_creator=lambda _: MultiCartpole(n),
|
||||
policy=policies,
|
||||
policy_mapping_fn=policy_mapper,
|
||||
batch_steps=50)
|
||||
rollout_fragment_length=50)
|
||||
]
|
||||
else:
|
||||
remote_workers = []
|
||||
@@ -659,7 +659,7 @@ class TestMultiAgentEnv(unittest.TestCase):
|
||||
env_creator=lambda _: MultiCartpole(n),
|
||||
policy=policies,
|
||||
policy_mapping_fn=lambda agent_id: random.choice(policy_ids),
|
||||
batch_steps=100)
|
||||
rollout_fragment_length=100)
|
||||
workers = WorkerSet._from_existing(worker, [])
|
||||
optimizer = SyncSamplesOptimizer(workers)
|
||||
for i in range(100):
|
||||
|
||||
@@ -251,7 +251,7 @@ class NestedSpacesTest(unittest.TestCase):
|
||||
env="nested",
|
||||
config={
|
||||
"num_workers": 0,
|
||||
"sample_batch_size": 5,
|
||||
"rollout_fragment_length": 5,
|
||||
"train_batch_size": 5,
|
||||
"model": {
|
||||
"custom_model": "composite",
|
||||
@@ -280,7 +280,7 @@ class NestedSpacesTest(unittest.TestCase):
|
||||
env="nested2",
|
||||
config={
|
||||
"num_workers": 0,
|
||||
"sample_batch_size": 5,
|
||||
"rollout_fragment_length": 5,
|
||||
"train_batch_size": 5,
|
||||
"model": {
|
||||
"custom_model": "composite2",
|
||||
@@ -340,7 +340,7 @@ class NestedSpacesTest(unittest.TestCase):
|
||||
env="nested_ma",
|
||||
config={
|
||||
"num_workers": 0,
|
||||
"sample_batch_size": 5,
|
||||
"rollout_fragment_length": 5,
|
||||
"train_batch_size": 5,
|
||||
"multiagent": {
|
||||
"policies": {
|
||||
@@ -404,7 +404,7 @@ class NestedSpacesTest(unittest.TestCase):
|
||||
config={
|
||||
"num_workers": 0,
|
||||
"use_pytorch": True,
|
||||
"sample_batch_size": 5,
|
||||
"rollout_fragment_length": 5,
|
||||
"train_batch_size": 5,
|
||||
"model": {
|
||||
"custom_model": "composite",
|
||||
|
||||
@@ -24,7 +24,7 @@ class TestPerf(unittest.TestCase):
|
||||
ev = RolloutWorker(
|
||||
env_creator=lambda _: gym.make("CartPole-v0"),
|
||||
policy=MockPolicy,
|
||||
batch_steps=100)
|
||||
rollout_fragment_length=100)
|
||||
start = time.time()
|
||||
count = 0
|
||||
while time.time() - start < 1:
|
||||
|
||||
@@ -205,7 +205,7 @@ class TestRolloutWorker(unittest.TestCase):
|
||||
pg = PGTrainer(
|
||||
env="CartPole-v0", config={
|
||||
"num_workers": 0,
|
||||
"sample_batch_size": 50,
|
||||
"rollout_fragment_length": 50,
|
||||
"train_batch_size": 50,
|
||||
"callbacks": {
|
||||
"on_episode_start": lambda x: counts.update({"start": 1}),
|
||||
@@ -231,12 +231,13 @@ class TestRolloutWorker(unittest.TestCase):
|
||||
env="test",
|
||||
config={
|
||||
"num_workers": 2,
|
||||
"sample_batch_size": 5,
|
||||
"rollout_fragment_length": 5,
|
||||
"num_envs_per_worker": 2,
|
||||
})
|
||||
results = pg.workers.foreach_worker(lambda ev: ev.sample_batch_size)
|
||||
results = pg.workers.foreach_worker(
|
||||
lambda ev: ev.rollout_fragment_length)
|
||||
results2 = pg.workers.foreach_worker_with_index(
|
||||
lambda ev, i: (i, ev.sample_batch_size))
|
||||
lambda ev, i: (i, ev.rollout_fragment_length))
|
||||
results3 = pg.workers.foreach_worker(
|
||||
lambda ev: ev.foreach_env(lambda env: 1))
|
||||
self.assertEqual(results, [10, 10, 10])
|
||||
@@ -269,7 +270,7 @@ class TestRolloutWorker(unittest.TestCase):
|
||||
env_creator=lambda _: MockEnv2(episode_length=10),
|
||||
policy=MockPolicy,
|
||||
batch_mode="complete_episodes",
|
||||
batch_steps=10,
|
||||
rollout_fragment_length=10,
|
||||
episode_horizon=4,
|
||||
soft_horizon=False)
|
||||
samples = ev.sample()
|
||||
@@ -287,7 +288,7 @@ class TestRolloutWorker(unittest.TestCase):
|
||||
env_creator=lambda _: gym.make("CartPole-v0"),
|
||||
policy=MockPolicy,
|
||||
batch_mode="complete_episodes",
|
||||
batch_steps=10,
|
||||
rollout_fragment_length=10,
|
||||
episode_horizon=6,
|
||||
soft_horizon=False)
|
||||
samples = ev.sample()
|
||||
@@ -307,7 +308,7 @@ class TestRolloutWorker(unittest.TestCase):
|
||||
env_creator=lambda _: MockEnv(episode_length=10),
|
||||
policy=MockPolicy,
|
||||
batch_mode="complete_episodes",
|
||||
batch_steps=10,
|
||||
rollout_fragment_length=10,
|
||||
episode_horizon=4,
|
||||
soft_horizon=True)
|
||||
samples = ev.sample()
|
||||
@@ -348,7 +349,7 @@ class TestRolloutWorker(unittest.TestCase):
|
||||
env_creator=lambda cfg: MockEnv(episode_length=20, config=cfg),
|
||||
policy=MockPolicy,
|
||||
batch_mode="truncate_episodes",
|
||||
batch_steps=2,
|
||||
rollout_fragment_length=2,
|
||||
num_envs=8)
|
||||
for _ in range(8):
|
||||
batch = ev.sample()
|
||||
@@ -371,7 +372,7 @@ class TestRolloutWorker(unittest.TestCase):
|
||||
env_creator=lambda _: MockEnv(episode_length=8),
|
||||
policy=MockPolicy,
|
||||
batch_mode="truncate_episodes",
|
||||
batch_steps=4,
|
||||
rollout_fragment_length=4,
|
||||
num_envs=4)
|
||||
batch = ev.sample()
|
||||
self.assertEqual(batch.count, 16)
|
||||
@@ -386,7 +387,7 @@ class TestRolloutWorker(unittest.TestCase):
|
||||
env_creator=lambda _: MockVectorEnv(episode_length=20, num_envs=8),
|
||||
policy=MockPolicy,
|
||||
batch_mode="truncate_episodes",
|
||||
batch_steps=10)
|
||||
rollout_fragment_length=10)
|
||||
for _ in range(8):
|
||||
batch = ev.sample()
|
||||
self.assertEqual(batch.count, 10)
|
||||
@@ -402,7 +403,7 @@ class TestRolloutWorker(unittest.TestCase):
|
||||
ev = RolloutWorker(
|
||||
env_creator=lambda _: MockEnv(10),
|
||||
policy=MockPolicy,
|
||||
batch_steps=15,
|
||||
rollout_fragment_length=15,
|
||||
batch_mode="truncate_episodes")
|
||||
batch = ev.sample()
|
||||
self.assertEqual(batch.count, 15)
|
||||
@@ -411,7 +412,7 @@ class TestRolloutWorker(unittest.TestCase):
|
||||
ev = RolloutWorker(
|
||||
env_creator=lambda _: MockEnv(10),
|
||||
policy=MockPolicy,
|
||||
batch_steps=5,
|
||||
rollout_fragment_length=5,
|
||||
batch_mode="complete_episodes")
|
||||
batch = ev.sample()
|
||||
self.assertEqual(batch.count, 10)
|
||||
@@ -420,7 +421,7 @@ class TestRolloutWorker(unittest.TestCase):
|
||||
ev = RolloutWorker(
|
||||
env_creator=lambda _: MockEnv(10),
|
||||
policy=MockPolicy,
|
||||
batch_steps=15,
|
||||
rollout_fragment_length=15,
|
||||
batch_mode="complete_episodes")
|
||||
batch = ev.sample()
|
||||
self.assertEqual(batch.count, 20)
|
||||
|
||||
@@ -215,7 +215,7 @@ class ModelSupportedSpaces(unittest.TestCase):
|
||||
"num_workers": 1,
|
||||
"num_sgd_iter": 1,
|
||||
"train_batch_size": 10,
|
||||
"sample_batch_size": 10,
|
||||
"rollout_fragment_length": 10,
|
||||
"sgd_minibatch_size": 1,
|
||||
}
|
||||
check_support("PPO", config, self.stats, check_bounds=True)
|
||||
@@ -283,7 +283,7 @@ class ModelSupportedSpaces(unittest.TestCase):
|
||||
"num_workers": 1,
|
||||
"num_sgd_iter": 1,
|
||||
"train_batch_size": 10,
|
||||
"sample_batch_size": 10,
|
||||
"rollout_fragment_length": 10,
|
||||
"sgd_minibatch_size": 1,
|
||||
})
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ atari-a2c:
|
||||
- SpaceInvadersNoFrameskip-v4
|
||||
run: A2C
|
||||
config:
|
||||
sample_batch_size: 20
|
||||
rollout_fragment_length: 20
|
||||
clip_rewards: True
|
||||
num_workers: 5
|
||||
num_envs_per_worker: 5
|
||||
|
||||
@@ -29,7 +29,7 @@ apex:
|
||||
# APEX
|
||||
num_workers: 8
|
||||
num_envs_per_worker: 8
|
||||
sample_batch_size: 20
|
||||
rollout_fragment_length: 20
|
||||
train_batch_size: 512
|
||||
target_network_update_freq: 50000
|
||||
timesteps_per_iteration: 25000
|
||||
|
||||
@@ -12,7 +12,7 @@ atari-ddppo:
|
||||
num_gpus_per_worker: 1
|
||||
# Each worker will sample 100 * 5 envs per worker steps = 500 steps
|
||||
# per optimization round. This is 5000 steps summed across workers.
|
||||
sample_batch_size: 100
|
||||
rollout_fragment_length: 100
|
||||
num_envs_per_worker: 5
|
||||
# Each worker will take a minibatch of 50. There are 10 workers total,
|
||||
# so the effective minibatch size will be 500.
|
||||
|
||||
@@ -19,7 +19,7 @@ atari-dist-dqn:
|
||||
hiddens: [512]
|
||||
learning_starts: 20000
|
||||
buffer_size: 1000000
|
||||
sample_batch_size: 4
|
||||
rollout_fragment_length: 4
|
||||
train_batch_size: 32
|
||||
exploration_config:
|
||||
epsilon_timesteps: 200000
|
||||
|
||||
@@ -21,7 +21,7 @@ atari-basic-dqn:
|
||||
hiddens: [512]
|
||||
learning_starts: 20000
|
||||
buffer_size: 1000000
|
||||
sample_batch_size: 4
|
||||
rollout_fragment_length: 4
|
||||
train_batch_size: 32
|
||||
exploration_config:
|
||||
epsilon_timesteps: 200000
|
||||
|
||||
@@ -21,7 +21,7 @@ dueling-ddqn:
|
||||
hiddens: [512]
|
||||
learning_starts: 20000
|
||||
buffer_size: 1000000
|
||||
sample_batch_size: 4
|
||||
rollout_fragment_length: 4
|
||||
train_batch_size: 32
|
||||
exploration_config:
|
||||
epsilon_timesteps: 200000
|
||||
|
||||
@@ -11,7 +11,7 @@ atari-impala:
|
||||
stop:
|
||||
timesteps_total: 3000000
|
||||
config:
|
||||
sample_batch_size: 50
|
||||
rollout_fragment_length: 50
|
||||
train_batch_size: 500
|
||||
num_workers: 128
|
||||
num_envs_per_worker: 5
|
||||
|
||||
@@ -9,7 +9,7 @@ atari-impala:
|
||||
- SpaceInvadersNoFrameskip-v4
|
||||
run: IMPALA
|
||||
config:
|
||||
sample_batch_size: 50
|
||||
rollout_fragment_length: 50
|
||||
train_batch_size: 500
|
||||
num_workers: 32
|
||||
num_envs_per_worker: 5
|
||||
|
||||
@@ -16,7 +16,7 @@ atari-ppo:
|
||||
vf_clip_param: 10.0
|
||||
entropy_coeff: 0.01
|
||||
train_batch_size: 5000
|
||||
sample_batch_size: 100
|
||||
rollout_fragment_length: 100
|
||||
sgd_minibatch_size: 500
|
||||
num_sgd_iter: 10
|
||||
num_workers: 10
|
||||
|
||||
@@ -12,7 +12,7 @@ atari-impala:
|
||||
stop:
|
||||
time_total_s: 3600
|
||||
config:
|
||||
sample_batch_size: 50
|
||||
rollout_fragment_length: 50
|
||||
train_batch_size: 500
|
||||
num_workers: 10
|
||||
num_envs_per_worker: 5
|
||||
@@ -36,7 +36,7 @@ atari-ppo-tf:
|
||||
vf_clip_param: 10.0
|
||||
entropy_coeff: 0.01
|
||||
train_batch_size: 5000
|
||||
sample_batch_size: 100
|
||||
rollout_fragment_length: 100
|
||||
sgd_minibatch_size: 500
|
||||
num_sgd_iter: 10
|
||||
num_workers: 10
|
||||
@@ -60,7 +60,7 @@ atari-ppo-torch:
|
||||
vf_clip_param: 10.0
|
||||
entropy_coeff: 0.01
|
||||
train_batch_size: 5000
|
||||
sample_batch_size: 100
|
||||
rollout_fragment_length: 100
|
||||
sgd_minibatch_size: 500
|
||||
num_sgd_iter: 10
|
||||
num_workers: 10
|
||||
@@ -94,7 +94,7 @@ apex:
|
||||
num_gpus: 1
|
||||
num_workers: 8
|
||||
num_envs_per_worker: 8
|
||||
sample_batch_size: 20
|
||||
rollout_fragment_length: 20
|
||||
train_batch_size: 512
|
||||
target_network_update_freq: 50000
|
||||
timesteps_per_iteration: 25000
|
||||
@@ -105,7 +105,7 @@ atari-a2c:
|
||||
stop:
|
||||
time_total_s: 3600
|
||||
config:
|
||||
sample_batch_size: 20
|
||||
rollout_fragment_length: 20
|
||||
clip_rewards: True
|
||||
num_workers: 5
|
||||
num_envs_per_worker: 5
|
||||
@@ -133,7 +133,7 @@ atari-basic-dqn:
|
||||
hiddens: [512]
|
||||
learning_starts: 20000
|
||||
buffer_size: 1000000
|
||||
sample_batch_size: 4
|
||||
rollout_fragment_length: 4
|
||||
train_batch_size: 32
|
||||
exploration_config:
|
||||
epsilon_timesteps: 200000
|
||||
|
||||
@@ -9,7 +9,7 @@ halfcheetah-appo:
|
||||
vtrace: True
|
||||
gamma: 0.99
|
||||
lambda: 0.95
|
||||
sample_batch_size: 512
|
||||
rollout_fragment_length: 512
|
||||
train_batch_size: 4096
|
||||
num_workers: 16
|
||||
num_gpus: 1
|
||||
|
||||
@@ -42,7 +42,7 @@ halfcheetah-ddpg:
|
||||
huber_threshold: 1.0
|
||||
l2_reg: 0.000001
|
||||
learning_starts: 500
|
||||
sample_batch_size: 1
|
||||
rollout_fragment_length: 1
|
||||
train_batch_size: 64
|
||||
|
||||
# === Parallelism ===
|
||||
|
||||
@@ -17,7 +17,7 @@ halfcheetah_sac:
|
||||
target_entropy: auto
|
||||
no_done_at_end: True
|
||||
n_step: 1
|
||||
sample_batch_size: 1
|
||||
rollout_fragment_length: 1
|
||||
prioritized_replay: False
|
||||
train_batch_size: 256
|
||||
target_network_update_freq: 1
|
||||
|
||||
@@ -43,7 +43,7 @@ mountaincarcontinuous-ddpg:
|
||||
huber_threshold: 1.0
|
||||
l2_reg: 0.00001
|
||||
learning_starts: 1000
|
||||
sample_batch_size: 1
|
||||
rollout_fragment_length: 1
|
||||
train_batch_size: 64
|
||||
|
||||
# === Parallelism ===
|
||||
|
||||
@@ -43,7 +43,7 @@ pendulum-ddpg:
|
||||
huber_threshold: 1.0
|
||||
l2_reg: 0.000001
|
||||
learning_starts: 500
|
||||
sample_batch_size: 1
|
||||
rollout_fragment_length: 1
|
||||
train_batch_size: 64
|
||||
|
||||
# === Parallelism ===
|
||||
|
||||
@@ -18,7 +18,7 @@ pendulum_sac:
|
||||
target_entropy: auto
|
||||
no_done_at_end: True
|
||||
n_step: 1
|
||||
sample_batch_size: 1
|
||||
rollout_fragment_length: 1
|
||||
prioritized_replay: False
|
||||
train_batch_size: 256
|
||||
target_network_update_freq: 1
|
||||
|
||||
@@ -3,7 +3,7 @@ pong-a3c-pytorch-cnn:
|
||||
run: A3C
|
||||
config:
|
||||
num_workers: 16
|
||||
sample_batch_size: 20
|
||||
rollout_fragment_length: 20
|
||||
use_pytorch: true
|
||||
vf_loss_coeff: 0.5
|
||||
entropy_coeff: 0.01
|
||||
|
||||
@@ -5,7 +5,7 @@ pong-a3c:
|
||||
run: A3C
|
||||
config:
|
||||
num_workers: 16
|
||||
sample_batch_size: 20
|
||||
rollout_fragment_length: 20
|
||||
use_pytorch: false
|
||||
vf_loss_coeff: 0.5
|
||||
entropy_coeff: 0.01
|
||||
|
||||
@@ -12,7 +12,7 @@ pong-appo:
|
||||
config:
|
||||
vtrace: True
|
||||
use_kl_loss: False
|
||||
sample_batch_size: 50
|
||||
rollout_fragment_length: 50
|
||||
train_batch_size: 750
|
||||
num_workers: 32
|
||||
broadcast_interval: 1
|
||||
|
||||
@@ -11,7 +11,7 @@ pong-deterministic-dqn:
|
||||
lr: .0001
|
||||
learning_starts: 10000
|
||||
buffer_size: 50000
|
||||
sample_batch_size: 4
|
||||
rollout_fragment_length: 4
|
||||
train_batch_size: 32
|
||||
exploration_config:
|
||||
epsilon_timesteps: 200000
|
||||
|
||||
@@ -7,7 +7,7 @@ pong-impala-fast:
|
||||
env: PongNoFrameskip-v4
|
||||
run: IMPALA
|
||||
config:
|
||||
sample_batch_size: 50
|
||||
rollout_fragment_length: 50
|
||||
train_batch_size: 1000
|
||||
num_workers: 128
|
||||
num_envs_per_worker: 5
|
||||
|
||||
@@ -5,7 +5,7 @@ pong-impala-vectorized:
|
||||
env: PongNoFrameskip-v4
|
||||
run: IMPALA
|
||||
config:
|
||||
sample_batch_size: 50
|
||||
rollout_fragment_length: 50
|
||||
train_batch_size: 500
|
||||
num_workers: 32
|
||||
num_envs_per_worker: 10
|
||||
|
||||
@@ -7,7 +7,7 @@ pong-impala:
|
||||
env: PongNoFrameskip-v4
|
||||
run: IMPALA
|
||||
config:
|
||||
sample_batch_size: 50
|
||||
rollout_fragment_length: 50
|
||||
train_batch_size: 500
|
||||
num_workers: 128
|
||||
num_envs_per_worker: 1
|
||||
|
||||
@@ -13,7 +13,7 @@ pong-ppo:
|
||||
vf_clip_param: 10.0
|
||||
entropy_coeff: 0.01
|
||||
train_batch_size: 5000
|
||||
sample_batch_size: 20
|
||||
rollout_fragment_length: 20
|
||||
sgd_minibatch_size: 500
|
||||
num_sgd_iter: 10
|
||||
num_workers: 32
|
||||
|
||||
@@ -11,7 +11,7 @@ pong-deterministic-rainbow:
|
||||
hiddens: [512]
|
||||
learning_starts: 10000
|
||||
buffer_size: 50000
|
||||
sample_batch_size: 4
|
||||
rollout_fragment_length: 4
|
||||
train_batch_size: 32
|
||||
exploration_config:
|
||||
epsilon_timesteps: 2
|
||||
|
||||
@@ -5,7 +5,7 @@ cartpole-appo-vt:
|
||||
episode_reward_mean: 100
|
||||
timesteps_total: 100000
|
||||
config:
|
||||
sample_batch_size: 10
|
||||
rollout_fragment_length: 10
|
||||
train_batch_size: 10
|
||||
num_envs_per_worker: 5
|
||||
num_workers: 1
|
||||
|
||||
@@ -5,7 +5,7 @@ cartpole-appo:
|
||||
episode_reward_mean: 100
|
||||
timesteps_total: 100000
|
||||
config:
|
||||
sample_batch_size: 10
|
||||
rollout_fragment_length: 10
|
||||
train_batch_size: 10
|
||||
num_envs_per_worker: 5
|
||||
num_workers: 1
|
||||
|
||||
@@ -85,12 +85,12 @@ def ParallelRollouts(workers: WorkerSet, mode="bulk_sync",
|
||||
>>> rollouts = ParallelRollouts(workers, mode="async")
|
||||
>>> batch = next(rollouts)
|
||||
>>> print(batch.count)
|
||||
50 # config.sample_batch_size
|
||||
50 # config.rollout_fragment_length
|
||||
|
||||
>>> rollouts = ParallelRollouts(workers, mode="bulk_sync")
|
||||
>>> batch = next(rollouts)
|
||||
>>> print(batch.count)
|
||||
200 # config.sample_batch_size * config.num_workers
|
||||
200 # config.rollout_fragment_length * config.num_workers
|
||||
|
||||
Updates the STEPS_SAMPLED_COUNTER counter in the local iterator context.
|
||||
"""
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import numpy as np
|
||||
import sys
|
||||
from tensorflow.python.eager.context import eager_mode
|
||||
import unittest
|
||||
|
||||
@@ -188,5 +189,4 @@ class TestExplorations(unittest.TestCase):
|
||||
|
||||
if __name__ == "__main__":
|
||||
import pytest
|
||||
import sys
|
||||
sys.exit(pytest.main(["-v", __file__]))
|
||||
|
||||
Reference in New Issue
Block a user