[RLlib] IMPALA PyTorch (#8287)

This PR adds an IMPALA PyTorch implementation.

- adds compilation tests for LSTM and w/o LSTM.
- adds learning test for CartPole.
This commit is contained in:
Sven Mika
2020-05-03 13:44:25 +02:00
committed by GitHub
parent 1228369a87
commit 166bb5d690
17 changed files with 335 additions and 45 deletions
+1 -1
View File
@@ -201,7 +201,7 @@ matrix:
- travis_wait 60 bazel test --config=ci --build_tests_only --test_tag_filters=quick_train rllib/...
# Test everything that does not have any of the "main" labels:
# "learning_tests|quick_train|examples|tests_dir".
- ./ci/keep_alive bazel test --config=ci --build_tests_only --test_tag_filters=-learning_tests,-quick_train,-examples,-tests_dir rllib/...
- ./ci/keep_alive bazel test --config=ci --build_tests_only --test_tag_filters=-learning_tests_tf,-learning_tests_torch,-quick_train,-examples,-tests_dir rllib/...
# RLlib: Everything in rllib/examples/ directory.
- os: linux
+2 -2
View File
@@ -18,7 +18,7 @@ Algorithm Frameworks Discrete Actions Continuous Actions Multi-
`APEX-DDPG`_ tf No **Yes** **Yes**
`DQN`_, `Rainbow`_ tf + torch **Yes** `+parametric`_ No **Yes**
`APEX-DQN`_ tf + torch **Yes** `+parametric`_ No **Yes**
`IMPALA`_ tf **Yes** `+parametric`_ **Yes** **Yes** `+RNN`_, `+autoreg`_
`IMPALA`_ tf + torch **Yes** `+parametric`_ **Yes** **Yes** `+RNN`_, `+autoreg`_
`MARWIL`_ tf + torch **Yes** `+parametric`_ **Yes** **Yes** `+RNN`_
`PG`_ tf + torch **Yes** `+parametric`_ **Yes** **Yes** `+RNN`_, `+autoreg`_
`PPO`_, `APPO`_ tf + torch **Yes** `+parametric`_ **Yes** **Yes** `+RNN`_, `+autoreg`_
@@ -94,7 +94,7 @@ SpaceInvaders 646 ~300
Importance Weighted Actor-Learner Architecture (IMPALA)
-------------------------------------------------------
|tensorflow|
|pytorch| |tensorflow|
`[paper] <https://arxiv.org/abs/1802.01561>`__
`[implementation] <https://github.com/ray-project/ray/blob/master/rllib/agents/impala/impala.py>`__
In IMPALA, a central learner runs SGD in a tight loop while asynchronously pulling sample batches from many actor processes. RLlib's IMPALA implementation uses DeepMind's reference `V-trace code <https://github.com/deepmind/scalable_agent/blob/master/vtrace.py>`__. Note that we do not provide a deep residual network out of the box, but one can be plugged in as a `custom model <rllib-models.html#custom-models-tensorflow>`__. Multiple learner GPUs and experience replay are also supported.
+1 -1
View File
@@ -89,7 +89,7 @@ Algorithms
- |pytorch| |tensorflow| :ref:`Distributed Prioritized Experience Replay (Ape-X) <apex>`
- |tensorflow| :ref:`Importance Weighted Actor-Learner Architecture (IMPALA) <impala>`
- |pytorch| |tensorflow| :ref:`Importance Weighted Actor-Learner Architecture (IMPALA) <impala>`
- |pytorch| |tensorflow| :ref:`Asynchronous Proximal Policy Optimization (APPO) <appo>`
+6
View File
@@ -136,6 +136,12 @@ py_test(
size = "small",
srcs = ["agents/impala/tests/test_vtrace.py"]
)
py_test(
name = "test_impala",
tags = ["agents_dir"],
size = "medium",
srcs = ["agents/impala/tests/test_impala.py"]
)
# MARWILTrainer
py_test(
+29 -18
View File
@@ -11,7 +11,7 @@ from ray.tune.resources import Resources
# yapf: disable
# __sphinx_doc_begin__
DEFAULT_CONFIG = with_common_config({
# V-trace params (see vtrace_tf.py).
# V-trace params (see vtrace_tf/torch.py).
"vtrace": True,
"vtrace_clip_rho_threshold": 1.0,
"vtrace_clip_pg_rho_threshold": 1.0,
@@ -83,22 +83,6 @@ DEFAULT_CONFIG = with_common_config({
# yapf: enable
def choose_policy(config):
if config["vtrace"]:
return VTraceTFPolicy
else:
return A3CTFPolicy
def validate_config(config):
# PyTorch check.
if config["use_pytorch"]:
raise ValueError(
"IMPALA does not support PyTorch yet! Use tf instead.")
if config["entropy_coeff"] < 0:
raise DeprecationWarning("entropy_coeff must be >= 0")
def defer_make_workers(trainer, env_creator, policy, config):
# Defer worker creation to after the optimizer has been created.
return trainer._make_workers(env_creator, policy, config, 0)
@@ -157,12 +141,39 @@ class OverrideDefaultResourceRequest:
cf["num_workers"])
def get_policy_class(config):
if config["use_pytorch"]:
if config["vtrace"]:
from ray.rllib.agents.impala.vtrace_torch_policy import \
VTraceTorchPolicy
return VTraceTorchPolicy
else:
from ray.rllib.agents.a3c.a3c_torch_policy import \
A3CTorchPolicy
return A3CTorchPolicy
else:
if config["vtrace"]:
return VTraceTFPolicy
else:
return A3CTFPolicy
def validate_config(config):
if config["entropy_coeff"] < 0.0:
raise DeprecationWarning("`entropy_coeff` must be >= 0.0!")
if config["vtrace"] and not config["in_evaluation"]:
if config["batch_mode"] != "truncate_episodes":
raise ValueError(
"Must use `batch_mode`=truncate_episodes if `vtrace` is True.")
ImpalaTrainer = build_trainer(
name="IMPALA",
default_config=DEFAULT_CONFIG,
default_policy=VTraceTFPolicy,
validate_config=validate_config,
get_policy_class=choose_policy,
get_policy_class=get_policy_class,
make_workers=defer_make_workers,
make_policy_optimizer=make_aggregators_and_optimizer,
mixins=[OverrideDefaultResourceRequest])
+46
View File
@@ -0,0 +1,46 @@
import unittest
import ray
import ray.rllib.agents.impala as impala
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.test_utils import framework_iterator
tf = try_import_tf()
class TestIMPALA(unittest.TestCase):
@classmethod
def setUpClass(cls):
ray.init()
@classmethod
def tearDownClass(cls):
ray.shutdown()
def test_impala_compilation(self):
"""Test whether an ImpalaTrainer can be built with both frameworks."""
config = impala.DEFAULT_CONFIG.copy()
num_iterations = 1
for _ in framework_iterator(config, frameworks=("torch", "tf")):
local_cfg = config.copy()
for env in ["Pendulum-v0", "CartPole-v0"]:
print("Env={}".format(env))
print("w/ LSTM")
# Test w/o LSTM.
trainer = impala.ImpalaTrainer(config=local_cfg, env=env)
for i in range(num_iterations):
print(trainer.train())
# Test w/ LSTM.
print("w/o LSTM")
local_cfg["model"]["use_lstm"] = True
trainer = impala.ImpalaTrainer(config=local_cfg, env=env)
for i in range(num_iterations):
print(trainer.train())
if __name__ == "__main__":
import pytest
import sys
sys.exit(pytest.main(["-v", __file__]))
-7
View File
@@ -251,12 +251,6 @@ def postprocess_trajectory(policy,
return sample_batch
def validate_config(policy, obs_space, action_space, config):
if config["vtrace"] and not config["in_evaluation"]:
assert config["batch_mode"] == "truncate_episodes", \
"Must use `truncate_episodes` batch mode with V-trace."
def choose_optimizer(policy, config):
if policy.config["opt_type"] == "adam":
return tf.train.AdamOptimizer(policy.cur_lr)
@@ -289,7 +283,6 @@ VTraceTFPolicy = build_tf_policy(
postprocess_fn=postprocess_trajectory,
optimizer_fn=choose_optimizer,
gradients_fn=clip_gradients,
before_init=validate_config,
before_loss_init=setup_mixins,
mixins=[LearningRateSchedule, EntropyCoeffSchedule],
get_batch_divisibility_req=lambda p: p.config["rollout_fragment_length"])
+2
View File
@@ -32,6 +32,7 @@ tensors.
from ray.rllib.agents.impala.vtrace_tf import VTraceFromLogitsReturns, \
VTraceReturns
from ray.rllib.models.torch.torch_action_dist import TorchCategorical
from ray.rllib.utils import force_list
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.torch_ops import convert_to_torch_tensor
@@ -217,6 +218,7 @@ def multi_from_logits(behaviour_policy_logits,
behaviour_action_log_probs = (multi_log_probs_from_logits_and_actions(
behaviour_policy_logits, actions, dist_class, model))
behaviour_action_log_probs = force_list(behaviour_action_log_probs)
log_rhos = get_log_rhos(target_action_log_probs,
behaviour_action_log_probs)
+215
View File
@@ -1,12 +1,187 @@
import gym
import logging
import numpy as np
import ray
from ray.rllib.agents.a3c.a3c_torch_policy import apply_grad_clipping
from ray.rllib.agents.impala.vtrace_tf_policy import postprocess_trajectory
import ray.rllib.agents.impala.vtrace_torch as vtrace
from ray.rllib.models.torch.torch_action_dist import TorchCategorical
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.torch_policy import LearningRateSchedule, \
EntropyCoeffSchedule
from ray.rllib.policy.torch_policy_template import build_torch_policy
from ray.rllib.utils.explained_variance import explained_variance
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.torch_ops import global_norm, sequence_mask
torch, nn = try_import_torch()
logger = logging.getLogger(__name__)
class VTraceLoss:
def __init__(self,
actions,
actions_logp,
actions_entropy,
dones,
behaviour_action_logp,
behaviour_logits,
target_logits,
discount,
rewards,
values,
bootstrap_value,
dist_class,
model,
valid_mask,
config,
vf_loss_coeff=0.5,
entropy_coeff=0.01,
clip_rho_threshold=1.0,
clip_pg_rho_threshold=1.0):
"""Policy gradient loss with vtrace importance weighting.
VTraceLoss takes tensors of shape [T, B, ...], where `B` is the
batch_size. The reason we need to know `B` is for V-trace to properly
handle episode cut boundaries.
Args:
actions: An int|float32 tensor of shape [T, B, ACTION_SPACE].
actions_logp: A float32 tensor of shape [T, B].
actions_entropy: A float32 tensor of shape [T, B].
dones: A bool tensor of shape [T, B].
behaviour_action_logp: Tensor of shape [T, B].
behaviour_logits: A list with length of ACTION_SPACE of float32
tensors of shapes
[T, B, ACTION_SPACE[0]],
...,
[T, B, ACTION_SPACE[-1]]
target_logits: A list with length of ACTION_SPACE of float32
tensors of shapes
[T, B, ACTION_SPACE[0]],
...,
[T, B, ACTION_SPACE[-1]]
discount: A float32 scalar.
rewards: A float32 tensor of shape [T, B].
values: A float32 tensor of shape [T, B].
bootstrap_value: A float32 tensor of shape [B].
dist_class: action distribution class for logits.
valid_mask: A bool tensor of valid RNN input elements (#2992).
config: Trainer config dict.
"""
if valid_mask is None:
valid_mask = torch.ones_like(actions_logp)
# Compute vtrace on the CPU for better perf
# (devices handled inside `vtrace.multi_from_logits`).
self.vtrace_returns = vtrace.multi_from_logits(
behaviour_action_log_probs=behaviour_action_logp,
behaviour_policy_logits=behaviour_logits,
target_policy_logits=target_logits,
actions=torch.unbind(actions, dim=2),
discounts=(1.0 - dones.float()) * discount,
rewards=rewards,
values=values,
bootstrap_value=bootstrap_value,
dist_class=dist_class,
model=model,
clip_rho_threshold=clip_rho_threshold,
clip_pg_rho_threshold=clip_pg_rho_threshold)
self.value_targets = self.vtrace_returns.vs
# The policy gradients loss
self.pi_loss = -torch.sum(
actions_logp * self.vtrace_returns.pg_advantages * valid_mask)
# The baseline loss
delta = (values - self.vtrace_returns.vs) * valid_mask
self.vf_loss = 0.5 * torch.sum(torch.pow(delta, 2.0))
# The entropy loss
self.entropy = torch.sum(actions_entropy * valid_mask)
# The summed weighted loss
self.total_loss = (self.pi_loss + self.vf_loss * vf_loss_coeff -
self.entropy * entropy_coeff)
def build_vtrace_loss(policy, model, dist_class, train_batch):
model_out, _ = model.from_batch(train_batch)
action_dist = dist_class(model_out, model)
if isinstance(policy.action_space, gym.spaces.Discrete):
is_multidiscrete = False
output_hidden_shape = [policy.action_space.n]
elif isinstance(policy.action_space,
gym.spaces.multi_discrete.MultiDiscrete):
is_multidiscrete = True
output_hidden_shape = policy.action_space.nvec.astype(np.int32)
else:
is_multidiscrete = False
output_hidden_shape = 1
def _make_time_major(*args, **kw):
return make_time_major(policy, train_batch.get("seq_lens"), *args,
**kw)
actions = train_batch[SampleBatch.ACTIONS]
dones = train_batch[SampleBatch.DONES]
rewards = train_batch[SampleBatch.REWARDS]
behaviour_action_logp = train_batch[SampleBatch.ACTION_LOGP]
behaviour_logits = train_batch[SampleBatch.ACTION_DIST_INPUTS]
if isinstance(output_hidden_shape, list):
unpacked_behaviour_logits = torch.split(
behaviour_logits, output_hidden_shape, dim=1)
unpacked_outputs = torch.split(model_out, output_hidden_shape, dim=1)
else:
unpacked_behaviour_logits = torch.chunk(
behaviour_logits, output_hidden_shape, dim=1)
unpacked_outputs = torch.chunk(model_out, output_hidden_shape, dim=1)
values = model.value_function()
if policy.is_recurrent():
max_seq_len = torch.max(train_batch["seq_lens"]) - 1
mask = sequence_mask(train_batch["seq_lens"], max_seq_len)
mask = torch.reshape(mask, [-1])
else:
mask = torch.ones_like(rewards)
# Prepare actions for loss.
loss_actions = actions if is_multidiscrete else torch.unsqueeze(
actions, dim=1)
# Inputs are reshaped from [B * T] => [T - 1, B] for V-trace calc.
policy.loss = VTraceLoss(
actions=_make_time_major(loss_actions, drop_last=True),
actions_logp=_make_time_major(
action_dist.logp(actions), drop_last=True),
actions_entropy=_make_time_major(
action_dist.multi_entropy(), drop_last=True),
dones=_make_time_major(dones, drop_last=True),
behaviour_action_logp=_make_time_major(
behaviour_action_logp, drop_last=True),
behaviour_logits=_make_time_major(
unpacked_behaviour_logits, drop_last=True),
target_logits=_make_time_major(unpacked_outputs, drop_last=True),
discount=policy.config["gamma"],
rewards=_make_time_major(rewards, drop_last=True),
values=_make_time_major(values, drop_last=True),
bootstrap_value=_make_time_major(values)[-1],
dist_class=TorchCategorical if is_multidiscrete else dist_class,
model=model,
valid_mask=_make_time_major(mask, drop_last=True),
config=policy.config,
vf_loss_coeff=policy.config["vf_loss_coeff"],
entropy_coeff=policy.entropy_coeff,
clip_rho_threshold=policy.config["vtrace_clip_rho_threshold"],
clip_pg_rho_threshold=policy.config["vtrace_clip_pg_rho_threshold"])
return policy.loss.total_loss
def make_time_major(policy, seq_lens, tensor, drop_last=False):
"""Swaps batch and trajectory axis.
@@ -44,6 +219,27 @@ def make_time_major(policy, seq_lens, tensor, drop_last=False):
return res
def stats(policy, train_batch):
values_batched = make_time_major(
policy,
train_batch.get("seq_lens"),
policy.model.value_function(),
drop_last=policy.config["vtrace"])
return {
"cur_lr": policy.cur_lr,
"policy_loss": policy.loss.pi_loss,
"entropy": policy.loss.entropy,
"entropy_coeff": policy.entropy_coeff,
"var_gnorm": global_norm(policy.model.trainable_variables()),
"vf_loss": policy.loss.vf_loss,
"vf_explained_var": explained_variance(
torch.reshape(policy.loss.value_targets, [-1]),
torch.reshape(values_batched, [-1]),
framework="torch"),
}
def choose_optimizer(policy, config):
if policy.config["opt_type"] == "adam":
return torch.optim.Adam(
@@ -55,3 +251,22 @@ def choose_optimizer(policy, config):
weight_decay=config["decay"],
momentum=config["momentum"],
eps=config["epsilon"])
def setup_mixins(policy, obs_space, action_space, config):
EntropyCoeffSchedule.__init__(policy, config["entropy_coeff"],
config["entropy_coeff_schedule"])
LearningRateSchedule.__init__(policy, config["lr"], config["lr_schedule"])
VTraceTorchPolicy = build_torch_policy(
name="VTraceTorchPolicy",
loss_fn=build_vtrace_loss,
get_default_config=lambda: ray.rllib.agents.impala.impala.DEFAULT_CONFIG,
stats_fn=stats,
postprocess_fn=postprocess_trajectory,
extra_grad_process_fn=apply_grad_clipping,
optimizer_fn=choose_optimizer,
before_init=setup_mixins,
mixins=[LearningRateSchedule, EntropyCoeffSchedule],
get_batch_divisibility_req=lambda p: p.config["rollout_fragment_length"])
+2 -6
View File
@@ -1,6 +1,7 @@
from ray.rllib.agents.impala.impala import validate_config
from ray.rllib.agents.ppo.appo_tf_policy import AsyncPPOTFPolicy
from ray.rllib.agents.trainer import with_base_config
from ray.rllib.agents.ppo.ppo import update_kl
from ray.rllib.agents.trainer import with_base_config
from ray.rllib.agents import impala
# yapf: disable
@@ -89,11 +90,6 @@ def get_policy_class(config):
return AsyncPPOTFPolicy
def validate_config(config):
if config["entropy_coeff"] < 0:
raise ValueError("`entropy_coeff` must be >= 0.0!")
APPOTrainer = impala.ImpalaTrainer.with_updates(
name="APPO",
default_config=DEFAULT_CONFIG,
+1 -2
View File
@@ -8,7 +8,7 @@ import gym
from ray.rllib.agents.impala import vtrace_tf as vtrace
from ray.rllib.agents.impala.vtrace_tf_policy import _make_time_major, \
clip_gradients, validate_config, choose_optimizer
clip_gradients, choose_optimizer
from ray.rllib.evaluation.postprocessing import Postprocessing
from ray.rllib.models.tf.tf_action_dist import Categorical
from ray.rllib.policy.sample_batch import SampleBatch
@@ -449,7 +449,6 @@ AsyncPPOTFPolicy = build_tf_policy(
optimizer_fn=choose_optimizer,
gradients_fn=clip_gradients,
extra_action_fetches_fn=add_values,
before_init=validate_config,
before_loss_init=setup_mixins,
after_init=setup_late_mixins,
mixins=[
-3
View File
@@ -528,9 +528,6 @@ class ModelCatalog:
FCNet)
from ray.rllib.models.torch.visionnet import (VisionNetwork as
VisionNet)
if model_config.get("use_lstm"):
raise NotImplementedError(
"LSTM auto-wrapping not implemented for torch")
else:
from ray.rllib.models.tf.fcnet_v2 import \
FullyConnectedNetwork as FCNet
+10
View File
@@ -237,6 +237,11 @@ class TorchSquashedGaussian(TorchDistributionWrapper):
unsquashed = atanh(save_normed_values)
return unsquashed
@staticmethod
@override(ActionDistribution)
def required_model_output_shape(action_space, model_config):
return np.prod(action_space.shape) * 2
class TorchBeta(TorchDistributionWrapper):
"""
@@ -285,6 +290,11 @@ class TorchBeta(TorchDistributionWrapper):
def _unsquash(self, values):
return (values - self.low) / (self.high - self.low)
@staticmethod
@override(ActionDistribution)
def required_model_output_shape(action_space, model_config):
return np.prod(action_space.shape) * 2
class TorchDeterministic(TorchDistributionWrapper):
"""Action distribution that returns the input values directly.
@@ -0,0 +1,9 @@
cartpole-impala-tf:
env: CartPole-v0
run: IMPALA
stop:
episode_reward_mean: 150
timesteps_total: 500000
config:
use_pytorch: false
num_gpus: 0
@@ -0,0 +1,9 @@
cartpole-impala-torch:
env: CartPole-v0
run: IMPALA
stop:
episode_reward_mean: 150
timesteps_total: 500000
config:
use_pytorch: true
num_gpus: 0
@@ -30,10 +30,7 @@ def do_test_explorations(run,
# Test all frameworks.
for fw in framework_iterator(core_config):
if fw == "torch" and \
run in [impala.ImpalaTrainer, sac.SACTrainer]:
continue
elif fw == "eager" and run in [
if fw == "eager" and run in [
ddpg.DDPGTrainer, sac.SACTrainer, td3.TD3Trainer
]:
continue
+1 -1
View File
@@ -47,7 +47,7 @@ def framework_iterator(config=None,
logger.warning(
"framework_iterator skipping torch (not installed)!")
continue
elif not tf:
if fw != "torch" and not tf:
logger.warning("framework_iterator skipping {} (tf not "
"installed)!".format(fw))
continue