[RLlib] Attention Net prep PR #1: Smaller cleanups. (#12447)

* WIP.

* Fix.

* Fix.

* Fix.
This commit is contained in:
Sven Mika
2020-11-28 01:25:47 +01:00
committed by GitHub
parent 569eee5e71
commit 0df55a139c
20 changed files with 144 additions and 42 deletions
+1 -1
View File
@@ -274,7 +274,7 @@ DDPGTorchPolicy = build_torch_policy(
optimizer_fn=make_ddpg_optimizers,
validate_spaces=validate_spaces,
before_init=before_init_fn,
after_init=setup_late_mixins,
before_loss_init=setup_late_mixins,
action_distribution_fn=get_distribution_inputs_and_class,
make_model_and_action_dist=build_ddpg_models_and_action_dist,
apply_gradients_fn=apply_gradients_fn,
+4 -4
View File
@@ -317,9 +317,9 @@ def setup_early_mixins(policy: Policy, obs_space, action_space,
LearningRateSchedule.__init__(policy, config["lr"], config["lr_schedule"])
def after_init(policy: Policy, obs_space: gym.spaces.Space,
action_space: gym.spaces.Space,
config: TrainerConfigDict) -> None:
def before_loss_init(policy: Policy, obs_space: gym.spaces.Space,
action_space: gym.spaces.Space,
config: TrainerConfigDict) -> None:
ComputeTDErrorMixin.__init__(policy)
TargetNetworkMixin.__init__(policy, obs_space, action_space, config)
# Move target net to device (this is done automatically for the
@@ -397,7 +397,7 @@ DQNTorchPolicy = build_torch_policy(
extra_learn_fetches_fn=lambda policy: {"td_error": policy.q_loss.td_error},
extra_action_out_fn=extra_action_out_fn,
before_init=setup_early_mixins,
after_init=after_init,
before_loss_init=before_loss_init,
mixins=[
TargetNetworkMixin,
ComputeTDErrorMixin,
+1 -1
View File
@@ -81,5 +81,5 @@ MARWILTorchPolicy = build_torch_policy(
get_default_config=lambda: ray.rllib.agents.marwil.marwil.DEFAULT_CONFIG,
stats_fn=stats,
postprocess_fn=postprocess_advantages,
after_init=setup_mixins,
before_loss_init=setup_mixins,
mixins=[ValueNetworkMixin])
+1 -1
View File
@@ -331,7 +331,7 @@ AsyncPPOTorchPolicy = build_torch_policy(
extra_grad_process_fn=apply_grad_clipping,
optimizer_fn=choose_optimizer,
before_init=setup_early_mixins,
after_init=setup_late_mixins,
before_loss_init=setup_late_mixins,
make_model=make_appo_model,
mixins=[
LearningRateSchedule, KLCoeffMixin, TargetNetworkMixin,
+6 -1
View File
@@ -47,7 +47,12 @@ def ppo_surrogate_loss(
# RNN case: Mask away 0-padded chunks at end of time axis.
if state:
max_seq_len = tf.reduce_max(train_batch["seq_lens"])
# Derive max_seq_len from the data itself, not from the seq_lens
# tensor. This is in case e.g. seq_lens=[2, 3], but the data is still
# 0-padded up to T=5 (as is the case for attention nets).
B = tf.shape(train_batch["seq_lens"])[0]
max_seq_len = tf.shape(logits)[0] // B
mask = tf.sequence_mask(train_batch["seq_lens"], max_seq_len)
mask = tf.reshape(mask, [-1])
+1 -1
View File
@@ -265,7 +265,7 @@ PPOTorchPolicy = build_torch_policy(
postprocess_fn=postprocess_ppo_gae,
extra_grad_process_fn=apply_grad_clipping,
before_init=setup_config,
after_init=setup_mixins,
before_loss_init=setup_mixins,
mixins=[
LearningRateSchedule, EntropyCoeffSchedule, KLCoeffMixin,
ValueNetworkMixin
+1
View File
@@ -24,6 +24,7 @@ class TestAPPO(unittest.TestCase):
for _ in framework_iterator(config):
print("w/o v-trace")
_config = config.copy()
_config["vtrace"] = False
trainer = ppo.APPOTrainer(config=_config, env="CartPole-v0")
for i in range(num_iterations):
print(trainer.train())
+1 -1
View File
@@ -489,7 +489,7 @@ SACTorchPolicy = build_torch_policy(
extra_grad_process_fn=apply_grad_clipping,
optimizer_fn=optimizer_fn,
validate_spaces=validate_spaces,
after_init=setup_late_mixins,
before_loss_init=setup_late_mixins,
make_model_and_action_dist=build_sac_model_and_action_dist,
mixins=[TargetNetworkMixin, ComputeTDErrorMixin],
action_distribution_fn=action_distribution_fn,
+1 -1
View File
@@ -1,8 +1,8 @@
import ray
from ray.rllib.agents.dqn.dqn_tf_policy import minimize_and_clip, _adjust_nstep
from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.models import ModelCatalog
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.annotations import override
from ray.rllib.utils.error import UnsupportedSpaceException
from ray.rllib.policy.policy import Policy
@@ -157,6 +157,10 @@ class _AgentCollector:
batch = SampleBatch(batch_data)
if SampleBatch.UNROLL_ID not in batch.data:
# TODO: (sven) Once we have the additional
# model.preprocess_train_batch in place (attention net PR), we
# should not even need UNROLL_ID anymore:
# Add "if SampleBatch.UNROLL_ID in view_requirements:" here.
batch.data[SampleBatch.UNROLL_ID] = np.repeat(
_AgentCollector._next_unroll_id, batch.count)
_AgentCollector._next_unroll_id += 1
@@ -238,7 +242,7 @@ class _PolicyCollector:
"""
for view_col, data in batch.items():
# Skip columns that are not used for training.
if view_col in view_requirements and \
if view_col not in view_requirements or \
not view_requirements[view_col].used_for_training:
continue
self.buffers[view_col].extend(data)
@@ -465,8 +469,7 @@ class _SimpleListCollector(_SampleCollector):
pre_batch = collector.build(policy.view_requirements)
pre_batches[agent_id] = (policy, pre_batch)
# Apply postprocessor.
post_batches = {}
# Apply reward clipping before calling postprocessing functions.
if self.clip_rewards is True:
for _, (_, pre_batch) in pre_batches.items():
pre_batch["rewards"] = np.sign(pre_batch["rewards"])
@@ -477,6 +480,7 @@ class _SimpleListCollector(_SampleCollector):
a_min=-self.clip_rewards,
a_max=self.clip_rewards)
post_batches = {}
for agent_id, (_, pre_batch) in pre_batches.items():
# Entire episode is said to be done.
# Error if no DONE at end of this agent's trajectory.
+3 -2
View File
@@ -257,7 +257,8 @@ class RolloutWorker(ParallelIteratorWorker):
directory if specified.
log_dir (str): Directory where logs can be placed.
log_level (str): Set the root log level on creation.
callbacks (DefaultCallbacks): Custom training callbacks.
callbacks (Type[DefaultCallbacks]): Custom sub-class of
DefaultCallbacks for training/policy/rollout-worker callbacks.
input_creator (Callable[[IOContext], InputReader]): Function that
returns an InputReader object for loading previous generated
experiences.
@@ -340,7 +341,7 @@ class RolloutWorker(ParallelIteratorWorker):
self.callbacks: "DefaultCallbacks" = callbacks()
else:
from ray.rllib.agents.callbacks import DefaultCallbacks
self.callbacks: "DefaultCallbacks" = DefaultCallbacks()
self.callbacks: DefaultCallbacks = DefaultCallbacks()
self.worker_index: int = worker_index
self.num_workers: int = num_workers
model_config: ModelConfigDict = model_config or {}
+11 -4
View File
@@ -1033,7 +1033,9 @@ def _process_observations_w_trajectory_view_api(
agent_id)
episode._set_last_observation(agent_id, filtered_obs)
episode._set_last_raw_obs(agent_id, raw_obs)
episode._set_last_info(agent_id, infos[env_id].get(agent_id, {}))
# Infos from the environment.
agent_infos = infos[env_id].get(agent_id, {})
episode._set_last_info(agent_id, agent_infos)
# Record transition info if applicable.
if last_observation is None:
@@ -1058,15 +1060,20 @@ def _process_observations_w_trajectory_view_api(
"new_obs": filtered_obs,
}
# Add extra-action-fetches to collectors.
values_dict.update(**episode.last_pi_info_for(agent_id))
pol = policies[policy_id]
for key, value in episode.last_pi_info_for(agent_id).items():
values_dict[key] = value
# Env infos for this agent.
if "infos" in pol.view_requirements:
values_dict["infos"] = agent_infos
_sample_collector.add_action_reward_next_obs(
episode.episode_id, agent_id, env_id, policy_id,
agent_done, values_dict)
if not agent_done:
item = PolicyEvalData(
env_id, agent_id, filtered_obs, infos[env_id].get(
agent_id, {}), None if last_observation is None else
env_id, agent_id, filtered_obs, agent_infos, None
if last_observation is None else
episode.rnn_state_for(agent_id), None
if last_observation is None else
episode.last_action_for(agent_id),
@@ -10,7 +10,7 @@ import ray.rllib.agents.ppo as ppo
from ray.rllib.examples.env.debug_counter_env import MultiAgentDebugCounterEnv
from ray.rllib.evaluation.rollout_worker import RolloutWorker
from ray.rllib.examples.policy.episode_env_aware_policy import \
EpisodeEnvAwarePolicy
EpisodeEnvAwareLSTMPolicy
from ray.rllib.policy.rnn_sequencing import pad_batch_to_sequences_of_same_size
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.view_requirement import ViewRequirement
@@ -121,7 +121,6 @@ class TestTrajectoryViewAPI(unittest.TestCase):
obs_space = Box(-1.0, 1.0, shape=(700, ))
from ray.rllib.examples.env.random_env import RandomMultiAgentEnv
from ray.tune import register_env
register_env("ma_env", lambda c: RandomMultiAgentEnv({
"num_agents": 2,
@@ -147,7 +146,6 @@ class TestTrajectoryViewAPI(unittest.TestCase):
"policy_mapping_fn": policy_fn,
}
num_iterations = 2
# Only works in torch so far.
for _ in framework_iterator(config, frameworks="torch"):
print("w/ traj. view API")
config["_use_trajectory_view_api"] = True
@@ -253,7 +251,7 @@ class TestTrajectoryViewAPI(unittest.TestCase):
rollout_fragment_length = 200
assert rollout_fragment_length % max_seq_len == 0
policies = {
"pol0": (EpisodeEnvAwarePolicy, obs_space, action_space, {}),
"pol0": (EpisodeEnvAwareLSTMPolicy, obs_space, action_space, {}),
}
def policy_fn(agent_id):
@@ -316,8 +314,8 @@ def analyze_rnn_batch(batch, max_seq_len):
state_in_1 = batch["state_in_1"][idx]
# Check postprocessing outputs.
if "postprocessed_column" in batch:
postprocessed_col_t = batch["postprocessed_column"][idx]
if "2xobs" in batch:
postprocessed_col_t = batch["2xobs"][idx]
assert (obs_t == postprocessed_col_t / 2.0).all()
# Check state-in/out and next-obs values.
@@ -386,8 +384,8 @@ def analyze_rnn_batch(batch, max_seq_len):
r_t = batch["rewards"][k]
# Check postprocessing outputs.
if "postprocessed_column" in batch:
postprocessed_col_t = batch["postprocessed_column"][k]
if "2xobs" in batch:
postprocessed_col_t = batch["2xobs"][k]
assert (obs_t == postprocessed_col_t / 2.0).all()
# Check state-in/out and next-obs values.
@@ -8,7 +8,7 @@ from ray.rllib.policy.view_requirement import ViewRequirement
from ray.rllib.utils.annotations import override
class EpisodeEnvAwarePolicy(RandomPolicy):
class EpisodeEnvAwareLSTMPolicy(RandomPolicy):
"""A Policy that always knows the current EpisodeID and EnvID and
returns these in its actions."""
@@ -78,5 +78,67 @@ class EpisodeEnvAwarePolicy(RandomPolicy):
sample_batch,
other_agent_batches=None,
episode=None):
sample_batch["postprocessed_column"] = sample_batch["obs"] * 2.0
sample_batch["2xobs"] = sample_batch["obs"] * 2.0
return sample_batch
class EpisodeEnvAwareAttentionPolicy(RandomPolicy):
    """A Policy that always knows the current agent index, EpisodeID and
    EnvID and returns these in its actions.

    The policy reports itself as recurrent and uses the current timestep
    ("t") of each batch item as its single state output. Its "state_in_0"
    view requirement asks for a window of previous "state_out_0" values
    (relative positions "-50:-1") instead of only the last one.
    """

    def __init__(self, *args, **kwargs):
        """Sets up state space, model config and view requirements.

        All positional and keyword args are passed through unchanged to the
        parent class' constructor.
        """
        super().__init__(*args, **kwargs)
        # One scalar per state entry (filled with the timestep, see
        # `compute_actions_from_input_dict`).
        self.state_space = Box(-1.0, 1.0, (1, ))
        self.config["model"] = {"max_seq_len": 50}

        # Minimal stand-in object that only carries the
        # `inference_view_requirements` attribute set below.
        class _fake_model:
            pass

        self.model = _fake_model()
        self.model.inference_view_requirements = {
            SampleBatch.AGENT_INDEX: ViewRequirement(),
            SampleBatch.EPS_ID: ViewRequirement(),
            "env_id": ViewRequirement(),
            "t": ViewRequirement(),
            SampleBatch.OBS: ViewRequirement(),
            "state_in_0": ViewRequirement(
                "state_out_0",
                # Provide state outs -50 to -1 as "state-in".
                data_rel_pos="-50:-1",
                # Repeat the incoming state every n time steps (usually max
                # seq len).
                batch_repeat_value=self.config["model"]["max_seq_len"],
                space=self.state_space)
        }

        # Policy defaults first, then the model's requirements on top (the
        # model's entries win on key collisions).
        self.view_requirements = dict(super()._get_default_view_requirements(),
                                      **self.model.inference_view_requirements)

    @override(Policy)
    def is_recurrent(self):
        """Returns True: this policy always produces a state output."""
        return True

    @override(Policy)
    def compute_actions_from_input_dict(self,
                                        input_dict,
                                        explore=None,
                                        timestep=None,
                                        **kwargs):
        """Returns identity-revealing actions and the timestep as state.

        Args:
            input_dict: Must provide the keys "t", "obs", "env_id",
                SampleBatch.AGENT_INDEX and SampleBatch.EPS_ID, each
                indexable per batch item.
            explore: Ignored.
            timestep: Ignored.
            **kwargs: Ignored.

        Returns:
            Tuple of (actions, state_outs, extra_fetches):
            actions is an array with one
            [agent_index, episode_id, env_id] row per batch item;
            state_outs is a single-element list holding an array with one
            [t] row per batch item; extra_fetches is an empty dict.
        """
        ts = input_dict["t"]
        batch_size = len(input_dict["obs"])

        # Always return [agentIndex, episodeID, envID] as actions.
        actions = np.array([[
            input_dict[SampleBatch.AGENT_INDEX][i],
            input_dict[SampleBatch.EPS_ID][i], input_dict["env_id"][i]
        ] for i in range(batch_size)])
        # The state output is simply the current timestep of each item.
        states = [np.array([[ts[i]] for i in range(batch_size)])]
        self.global_timestep += 1
        return actions, states, {}

    @override(Policy)
    def postprocess_trajectory(self,
                               sample_batch,
                               other_agent_batches=None,
                               episode=None):
        """Adds a "3xobs" column (observations times 3.0) to the batch.

        Args:
            sample_batch: The batch to postprocess; modified in place.
            other_agent_batches: Ignored.
            episode: Ignored.

        Returns:
            The same `sample_batch` object with the "3xobs" column added.
        """
        sample_batch["3xobs"] = sample_batch["obs"] * 3.0
        return sample_batch
+1
View File
@@ -572,6 +572,7 @@ class Policy(metaclass=ABCMeta):
SampleBatch.INFOS: ViewRequirement(),
SampleBatch.EPS_ID: ViewRequirement(),
SampleBatch.AGENT_INDEX: ViewRequirement(),
SampleBatch.UNROLL_ID: ViewRequirement(),
"t": ViewRequirement(),
}
+1 -1
View File
@@ -81,7 +81,7 @@ class SampleBatch:
if self.seq_lens is not None and len(self.seq_lens) > 0:
self.count = sum(self.seq_lens)
else:
self.count = len(self.data[k])
self.count = len(next(iter(self.data.values())))
# Keeps track of new columns added after initial ones.
self.new_columns = []
+4
View File
@@ -354,6 +354,8 @@ class TorchPolicy(Policy):
)
train_batch = self._lazy_tensor_dict(postprocessed_batch)
# Calculate the actual policy loss.
loss_out = force_list(
self._loss(self, self.model, self.dist_class, train_batch))
@@ -369,6 +371,7 @@ class TorchPolicy(Policy):
assert len(loss_out) == len(self._optimizers)
# assert not any(torch.isnan(l) for l in loss_out)
fetches = self.extra_compute_grad_fetches()
# Loop through all optimizers.
@@ -376,6 +379,7 @@ class TorchPolicy(Policy):
all_grads = []
for i, opt in enumerate(self._optimizers):
# Erase gradients in all vars of this optimizer.
opt.zero_grad()
# Recompute gradients of loss over all variables.
loss_out[i].backward(retain_graph=(i < len(self._optimizers) - 1))
+1 -1
View File
@@ -184,7 +184,7 @@ class Exploration:
Policy's own loss function and maybe the Model's custom loss.
train_batch (SampleBatch): The training data to calculate the
loss(es) for. This train data has already gone through
this Exploration's `preprocess_train_batch()` method.
this Exploration's `postprocess_trajectory()` method.
Returns:
List[TensorType]: The updated list of loss terms.
+1 -1
View File
@@ -66,7 +66,7 @@ def minibatches(samples, sgd_minibatch_size):
# Replace with `if samples.seq_lens` check.
if "state_in_0" in samples.data or "state_out_0" in samples.data:
if log_once("not_shuffling_rnn_data_in_simple_mode"):
logger.warning("Not shuffling RNN data for SGD in simple mode")
logger.warning("Not time-shuffling RNN data for SGD.")
else:
samples.shuffle()
+28 -9
View File
@@ -37,14 +37,14 @@ def explained_variance(y, pred):
return tf.maximum(-1.0, 1 - (diff_var / y_var))
def get_placeholder(*, space=None, value=None, name=None):
def get_placeholder(*, space=None, value=None, name=None, time_axis=False):
from ray.rllib.models.catalog import ModelCatalog
if space is not None:
if isinstance(space, (gym.spaces.Dict, gym.spaces.Tuple)):
return ModelCatalog.get_action_placeholder(space, None)
return tf1.placeholder(
shape=(None, ) + space.shape,
shape=(None, ) + ((None, ) if time_axis else ()) + space.shape,
dtype=tf.float32 if space.dtype == np.float64 else space.dtype,
name=name,
)
@@ -52,8 +52,9 @@ def get_placeholder(*, space=None, value=None, name=None):
assert value is not None
shape = value.shape[1:]
return tf1.placeholder(
shape=(None, ) + (shape if isinstance(shape, tuple) else tuple(
shape.as_list())),
shape=(None, ) + ((None, )
if time_axis else ()) + (shape if isinstance(
shape, tuple) else tuple(shape.as_list())),
dtype=tf.float32 if value.dtype == np.float64 else value.dtype,
name=name,
)
@@ -132,10 +133,11 @@ def make_tf_callable(session_or_none, dynamic_shape=False):
def make_wrapper(fn):
if session_or_none:
placeholders = []
args_placeholders = []
kwargs_placeholders = {}
symbolic_out = [None]
def call(*args):
def call(*args, **kwargs):
args_flat = []
for a in args:
if type(a) is list:
@@ -153,13 +155,30 @@ def make_tf_callable(session_or_none, dynamic_shape=False):
shape = ()
else:
shape = v.shape
placeholders.append(
args_placeholders.append(
tf1.placeholder(
dtype=v.dtype,
shape=shape,
name="arg_{}".format(i)))
symbolic_out[0] = fn(*placeholders)
feed_dict = dict(zip(placeholders, args))
for k, v in kwargs.items():
if dynamic_shape:
if len(v.shape) > 0:
shape = (None, ) + v.shape[1:]
else:
shape = ()
else:
shape = v.shape
kwargs_placeholders[k] = \
tf1.placeholder(
dtype=v.dtype,
shape=shape,
name="kwarg_{}".format(k))
symbolic_out[0] = fn(*args_placeholders,
**kwargs_placeholders)
feed_dict = dict(zip(args_placeholders, args))
feed_dict.update(
{kwargs_placeholders[k]: kwargs[k]
for k in kwargs.keys()})
ret = session_or_none.run(symbolic_out[0], feed_dict)
return ret